In [1]:
import os
import sys
import argparse
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from utils import *
from Model.DRPreter import DRPreter
from Model.Similarity import Similarity
from torch_scatter import scatter_add

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('-f')
parser.add_argument('--seed', type=int, default=42, help='seed')
parser.add_argument('--device', type=int, default=0, help='device')
parser.add_argument('--batch_size', type=int, default=128, help='batch size (default: 128)')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate (default: 0.0001)')
parser.add_argument('--layer', type=int, default=3, help='Number of cell layers')
parser.add_argument('--hidden_dim', type=int, default=8, help='hidden dim for cell')
parser.add_argument('--layer_drug', type=int, default=3, help='Number of drug layers')
parser.add_argument('--dim_drug', type=int, default=128, help='hidden dim for drug (default: 128)')
parser.add_argument('--dim_drug_cell', type=int, default=256, help='hidden dim for drug and cell (default: 256)')
parser.add_argument('--dropout_ratio', type=float, default=0.1, help='Dropout ratio (default: 0.1)')
parser.add_argument('--epochs', type=int, default=300, help='Maximum number of epochs (default: 300)')
parser.add_argument('--patience', type=int, default=100, help='patience for early stopping (default: 10)')
parser.add_argument('--mode', type=str, default='train', help='train, test')
parser.add_argument('--edge', type=str, default='STRING', help='STRING, BIOGRID') # BIOGRID: removed
# parser.add_argument('--string_edge', type=float, default=0.99, help='Threshold for edges of cell line graph')
parser.add_argument('--string_edge', type=int, default=990, help='Threshold for edges of cell line graph')
parser.add_argument('--dataset', type=str, default='2369disjoint', help='2369joint, 2369disjoint, COSMIC')
parser.add_argument('--trans', type=bool, default=True, help='Use Transformer or not')
parser.add_argument('--sim', type=bool, default=False, help='Construct homogeneous similarity networks or not')
args = parser.parse_args()
args

Namespace(f='/home/yingfei/.local/share/jupyter/runtime/kernel-b9324725-e018-4094-9184-b2e6eacf0b92.json', seed=42, device=0, batch_size=128, lr=0.0001, layer=3, hidden_dim=8, layer_drug=3, dim_drug=128, dim_drug_cell=256, dropout_ratio=0.1, epochs=300, patience=100, mode='train', edge='STRING', string_edge=990, dataset='2369disjoint', trans=True, sim=False)

In [3]:
# Override the parser default ('2369disjoint'); this value is interpolated into
# the edge-index and cell-feature file names loaded in the next cell.
args.dataset = "disjoint"

In [4]:
# Normalise the device index into a torch device string. Guarded so that
# re-running this cell does not turn 'cuda:0' into 'cuda:cuda:0'.
if isinstance(args.device, int):
    args.device = 'cuda:{}'.format(args.device)
rpath = '../'
result_path = rpath + 'Result_0308/Result_2/' ### Adjust folder name
data_dir = os.path.join(rpath, 'DRPreter', 'Data')

print(f'seed: {args.seed}')
set_random_seed(args.seed)

# Edge files are keyed by the STRING score threshold (e.g. PPI_990) or by the edge source name.
edge_type = 'PPI_'+str(args.string_edge) if args.edge=='STRING' else args.edge
edge_index = np.load(os.path.join(data_dir, 'Cell', f'edge_index_{edge_type}_{args.dataset}_cnv.npy')) ### Adjust file name

# Alternative response table: Data/sorted_IC50_82833_580_170.csv
data = pd.read_csv(os.path.join(data_dir, 'sorted_AUC.csv'))

### Adjust file name
# allow_pickle=True executes pickled objects: only load these trusted, locally produced files.
drug_dict = np.load(os.path.join(data_dir, 'Drug', 'drug_feature_graph.npy'), allow_pickle=True).item() # pyg format of drug graph
cell_dict = np.load(os.path.join(data_dir, 'Cell', f'cell_feature_std_{args.dataset}_cnv.npy'), allow_pickle=True).item() # pyg data format of cell graph

example = cell_dict['ACH-000001']
args.num_feature = example.x.shape[1]  # input features per gene node (1 in the saved outputs)
args.num_genes = example.x.shape[0]    # gene nodes per cell-line graph (4787 in the saved outputs)

seed: 42


In [5]:
### Added
# Genes per pathway: scatter-add a 1 for every gene node into the bin given by
# its x_mask value. x_mask presumably holds each gene's pathway index (the
# number of bins is used as n_pathways below) -- confirm against the data prep.
gene_list = scatter_add(torch.ones_like(example.x.squeeze()), example.x_mask.to(torch.int64)).to(torch.int)
args.max_gene = gene_list.max().item()  # largest per-bin gene count
# Prefix sums with a leading 0: running node offsets per bin (these are pathway
# boundaries only if genes are stored grouped by pathway -- confirm).
args.cum_num_nodes = torch.cat([gene_list.new_zeros(1), gene_list.cumsum(dim=0)], dim=0)
args.n_pathways = gene_list.size(0)
num_edges = len(edge_index[0])
print(f'num_genes:{args.num_genes}, num_edges:{num_edges}')
print(f'gene distribution: {gene_list}')
# NOTE(review): edges / nodes equals the mean degree only if edge_index stores both
# directions of every undirected edge; otherwise the mean degree is twice this value.
print(f'mean degree:{num_edges / args.num_genes}')

num_genes:4787, num_edges:12244
gene distribution: tensor([142,  94, 137, 294,  74, 155, 218,  79, 146, 190, 136, 293, 167,  73,
        232,  59, 239, 109, 104,  85, 126, 101, 221,  59,  56,  71,  44, 351,
        162,  96, 102,  41,  88, 102, 141], dtype=torch.int32)
mean degree:2.557760601629413


In [14]:
# Pin the device explicitly, overriding whatever the setup cell derived from --device.
args.device = "cuda:0" # 'cuda:{}'.format(args.device)

In [15]:
# Build train/val/test loaders from the response table; the single edge_index
# loaded in the setup cell is passed to load_data as a long tensor.
train_loader, val_loader, test_loader = load_data(data, drug_dict, cell_dict, torch.tensor(edge_index, dtype=torch.long), args)
print('total: {}, train: {}, val: {}, test: {}'.format(len(data), len(train_loader.dataset), len(val_loader.dataset), len(test_loader.dataset)))
# Follow the configured device rather than a hard-coded "cuda:0", so the model
# and the inputs built from args.device are always on the same device.
model = DRPreter(args).to(args.device)

total: 227436, train: 184020, val: 20472, test: 22944


In [16]:
# Checkpoint path: similarity-network runs are saved under a 'weight_sim' prefix.
weight_dir = 'weights_0308/weights_2'
weight_prefix = 'weight_sim' if args.sim == True else 'weight'
state_dict_name = f'{weight_dir}/{weight_prefix}_seed{args.seed}.pth'
state_dict_name

'weights_0308/weights_2/weight_seed42.pth'

In [17]:
# Load the checkpoint onto the configured device instead of a hard-coded GPU.
# torch.load unpickles the file, so only load trusted, locally produced checkpoints.
checkpoint = torch.load(state_dict_name, map_location=args.device)
# Every following cell is inference-only: disable dropout and other training-mode behaviour.
model.eval()
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [18]:
# Embeddings for one (drug, cell line) pair; later cells index [0] for the
# drug embedding and [1] for the cell-line embedding.
embedding(model, "JW-7-24-1", "ACH-000001", drug_dict, cell_dict, edge_index, args) # drug embedding and cell line embedding

(tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.4574, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.1913, 0.0000, 0.6364,
          0.0000, 0.0000, 0.4625, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9763,
          1.9686, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.8609, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 2.1265, 0.0000, 0.0000, 0.0000, 0.0000, 2.5022, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.4160, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.3042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.7395,
          0.0000, 0.0000, 0.4691, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3186, 0.0000, 0.0000, 0.0000,
          0.0000

In [19]:
# Keep only the cell-line part of the output. Its shape is shown below;
# presumably (batch, n_pathways=35, dim_drug_cell=256) -- confirm.
cell_line_embedding = embedding(model, "JW-7-24-1", "ACH-000001", drug_dict, cell_dict, edge_index, args)[1]
cell_line_embedding.shape

torch.Size([1, 35, 256])

In [20]:
# Embedding vector of the first pathway for ACH-000001.
cell_line_embedding[0][0] # pathway embedding

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4904, 0.0110, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2863,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.1068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1471, 0.0966, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.1114, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6240, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [21]:
# Drug embedding for JW-7-24-1 when paired with ACH-000001.
embedding(model, "JW-7-24-1", "ACH-000001", drug_dict, cell_dict, edge_index, args)[0] # drug embedding

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.4574, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.1913, 0.0000, 0.6364,
         0.0000, 0.0000, 0.4625, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9763,
         1.9686, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.8609, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 2.1265, 0.0000, 0.0000, 0.0000, 0.0000, 2.5022, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.4160, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.3042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.7395,
         0.0000, 0.0000, 0.4691, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3186, 0.0000, 0.0000, 0.0000,
         0.0000, 0.4200, 0.0

In [22]:
# Same drug paired with a different cell line: the visible values match the
# ACH-000001 output, so the drug embedding does not appear to depend on the cell line.
embedding(model, "JW-7-24-1", "ACH-000002", drug_dict, cell_dict, edge_index, args)[0] # drug embedding

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.4574, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.1913, 0.0000, 0.6364,
         0.0000, 0.0000, 0.4625, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.9763,
         1.9686, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.8609, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 2.1265, 0.0000, 0.0000, 0.0000, 0.0000, 2.5022, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.4160, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.3042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.7395,
         0.0000, 0.0000, 0.4691, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3186, 0.0000, 0.0000, 0.0000,
         0.0000, 0.4200, 0.0

In [23]:
# Debug introspection: the cell encoder's GAT layers and stored size attributes.
model.CellEncoder.__dict__

{'training': False,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('convs_cell',
               ModuleList(
                 (0): GATConv(1, 8, heads=1)
                 (1): GATConv(8, 8, heads=1)
                 (2): GATConv(8, 8, heads=1)
               ))]),
 'num_feature': 1,
 'layer_cell': 3,
 'dim_cell': 8,
 'final_node': 4787}

In [24]:
# Node feature matrix of the ACH-000001 graph (one column per gene node here).
example.x

tensor([[ 0.6484],
        [-0.2951],
        [-2.0614],
        ...,
        [-0.3492],
        [ 0.9511],
        [ 0.5783]])

In [25]:
# Wrap a single cell-line graph in a PyG Batch on the model device so the cell encoder can be called directly.
cell = Batch.from_data_list([cell_dict["ACH-000001"]]).to(args.device)

In [26]:
# Inspect the batch structure (node features, mask, edges, batch vector).
cell

DataBatch(x=[4787, 1], x_mask=[4787], edge_index=[2, 12244], batch=[4787], ptr=[2])

In [27]:
# The encoder output is flattened: 4787 gene nodes x 8 hidden dims = 38296 per graph.
model.CellEncoder(cell).shape

torch.Size([1, 38296])

In [28]:
# Rebuild the ACH-000001 batch so this cell does not depend on the current value of `cell`.
cell = Batch.from_data_list([cell_dict["ACH-000001"]]).to(args.device)
model.CellEncoder(cell)

tensor([[0.0309, 0.0000, 0.2006,  ..., 0.0000, 0.1764, 0.1584]],
       device='cuda:0', grad_fn=<ReshapeAliasBackward0>)

In [29]:
# First 8 output values for ACH-000001 (presumably the hidden features of the first gene node -- confirm flatten order).
cell = Batch.from_data_list([cell_dict["ACH-000001"]]).to(args.device)
model.CellEncoder(cell)[0][:8]

tensor([0.0309, 0.0000, 0.2006, 0.0734, 0.1617, 0.0000, 0.0804, 0.0808],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [30]:
# Same slice for ACH-000002, for comparison with ACH-000001.
cell = Batch.from_data_list([cell_dict["ACH-000002"]]).to(args.device)
model.CellEncoder(cell)[0][:8]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0123, 0.0000, 0.0000],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [31]:
# Total of the per-pathway gene counts; compare with args.num_genes.
gene_list.sum()

tensor(4787)

In [34]:
import pickle

# Presumably maps each pathway to its genes (value lengths are summed in the next cell).
# pickle.load can execute arbitrary code: only open trusted, locally produced files.
kegg_path = 'Data/Cell/34pathway_score990_isolated.pkl'
with open(kegg_path, 'rb') as kegg_fh:
    kegg = pickle.load(kegg_fh)

In [37]:
# Number of genes across all pathway entries.
sum(len(genes) for genes in kegg.values())

4787

In [62]:
# First pathway embedding of ACH-000001 when paired with drug JW-7-24-1.
embedding(model, "JW-7-24-1", "ACH-000001", drug_dict, cell_dict, edge_index, args)[1][0][0] # pathway embedding

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4904, 0.0110, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2863,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.1068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1471, 0.0966, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.1114, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6240, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [63]:
# Same cell line with a different drug: the visible values match the JW-7-24-1
# output, so the cell-line embedding does not appear to depend on the drug.
embedding(model, "KIN001-260", "ACH-000001", drug_dict, cell_dict, edge_index, args)[1][0][0]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4904, 0.0110, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2863,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.1068, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1471, 0.0966, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.1114, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6240, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [44]:
# First pathway embedding for a different cell line (ACH-000002); the values differ from ACH-000001.
embedding(model, "JW-7-24-1", "ACH-000002", drug_dict, cell_dict, edge_index, args)[1][0][0]

tensor([0.0000e+00, 0.0000e+00, 1.3980e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.4408e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.1533e-02, 3.8876e-03,
        2.7836e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.9001e-02,
        9.4033e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 2.4949e-02, 1.7200e-02, 0.0000e+00,
        0.0000e+00, 2.9670e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        5.4809e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.7424e-02,
        2.7755e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.6844e-02,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 2.3852e-04, 0.0000e+00,
        0.0000e+00, 5.7737e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 4.5195e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        8.2854e-03, 3.2079e-02, 0.0000e+00, 9.1569e-03, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0920e-

In [49]:
# Number of distinct cell lines (NaN, if present, counted once -- same as len(unique())).
data.DepMap_ID.nunique(dropna=False)

692

In [50]:
# Shape check: the ACH-000002 cell-line embedding has the same layout as ACH-000001's.
embedding(model, "JW-7-24-1", "ACH-000002", drug_dict, cell_dict, edge_index, args)[1].shape

torch.Size([1, 35, 256])

In [51]:
# Pre-compute the cell-line embedding for every cell line in the response table.
# The drug argument is a fixed placeholder: the cell-line embedding output was the
# same for JW-7-24-1 and KIN001-260 in the earlier comparison cells.
# Run without autograd and store CPU tensors (assumes `embedding` returns tensors on
# args.device, as the cell-encoder outputs are on cuda:0). This keeps the dict from
# holding one autograd graph per cell line, and lets the pickle written next load
# on a machine without CUDA.
cell_line_embedding_dict = {}
with torch.no_grad():
    for cell_line in data.DepMap_ID.unique():
        cell_line_embedding_dict[cell_line] = embedding(
            model, "KIN001-260", cell_line, drug_dict, cell_dict, edge_index, args
        )[1].detach().cpu()

In [52]:
# Persist the cell-line embeddings for downstream use.
with open('cell_line_embedding_dict.pickle', 'wb') as out_fh:
    pickle.dump(cell_line_embedding_dict, out_fh)

In [53]:
# Round-trip check: reload the embeddings written above (a trusted, self-produced pickle).
with open('cell_line_embedding_dict.pickle', 'rb') as in_fh:
    cell_line_embedding_dict = pickle.load(in_fh)

In [55]:
# Shape of a reloaded cell-line embedding.
cell_line_embedding_dict['ACH-000001'].shape

torch.Size([1, 35, 256])

In [58]:
# Number of distinct drugs (NaN, if present, counted once -- same as len(unique())).
data["Drug name"].nunique(dropna=False)

385

In [68]:
# Pre-compute the drug embedding for every drug in the response table.
# The cell-line argument is a fixed placeholder: the drug embedding output was the
# same for ACH-000001 and ACH-000002 in the earlier comparison cells.
# Run without autograd and store CPU tensors (assumes `embedding` returns tensors on
# args.device). This keeps the dict from holding one autograd graph per drug, and
# lets the pickle written next load on a machine without CUDA.
drug_embedding_dict = {}
with torch.no_grad():
    for drug in data["Drug name"].unique():
        drug_embedding_dict[drug] = embedding(
            model, drug, 'ACH-000001', drug_dict, cell_dict, edge_index, args
        )[0].detach().cpu()

In [69]:
# Persist the drug embeddings for downstream use.
with open('drug_embedding_dict.pickle', 'wb') as out_fh:
    pickle.dump(drug_embedding_dict, out_fh)

In [70]:
# Round-trip check: reload the drug embeddings written above (a trusted, self-produced pickle).
with open('drug_embedding_dict.pickle', 'rb') as in_fh:
    drug_embedding_dict = pickle.load(in_fh)

In [72]:
# Shape of a reloaded drug embedding.
drug_embedding_dict['JW-7-24-1'].shape

torch.Size([1, 256])