In [1]:
import os
import sys
import argparse
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from utils import *
from Model.DRPreter import DRPreter
from Model.Similarity import Similarity
from torch_scatter import scatter_add

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def str2bool(value):
    """Convert a command-line flag value to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--sim False`` would silently *enable* the flag.
    This converter accepts the common spellings explicitly.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


parser = argparse.ArgumentParser()
# The Jupyter kernel is launched with `-f <connection-file>`; accept it so parsing works in a notebook.
parser.add_argument('-f')
parser.add_argument('--seed', type=int, default=42, help='seed')
parser.add_argument('--device', type=int, default=0, help='device')
parser.add_argument('--batch_size', type=int, default=128, help='batch size (default: 128)')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate (default: 0.0001)')
parser.add_argument('--layer', type=int, default=3, help='Number of cell layers')
parser.add_argument('--hidden_dim', type=int, default=8, help='hidden dim for cell')
parser.add_argument('--layer_drug', type=int, default=3, help='Number of drug layers')
parser.add_argument('--dim_drug', type=int, default=128, help='hidden dim for drug (default: 128)')
parser.add_argument('--dim_drug_cell', type=int, default=256, help='hidden dim for drug and cell (default: 256)')
parser.add_argument('--dropout_ratio', type=float, default=0.1, help='Dropout ratio (default: 0.1)')
parser.add_argument('--epochs', type=int, default=300, help='Maximum number of epochs (default: 300)')
parser.add_argument('--patience', type=int, default=100, help='patience for early stopping (default: 100)')
parser.add_argument('--mode', type=str, default='train', help='train, test')
parser.add_argument('--edge', type=str, default='STRING', help='STRING, BIOGRID') # BIOGRID: removed
# Integer edge-score threshold; it is embedded in the edge-index file name (e.g. PPI_990).
parser.add_argument('--string_edge', type=int, default=990, help='Threshold for edges of cell line graph')
parser.add_argument('--dataset', type=str, default='2369disjoint', help='2369joint, 2369disjoint, COSMIC')
parser.add_argument('--trans', type=str2bool, default=True, help='Use Transformer or not')
parser.add_argument('--sim', type=str2bool, default=False, help='Construct homogeneous similarity networks or not')
# parse_known_args tolerates any extra arguments injected by the notebook/kernel launcher.
args, _unknown_args = parser.parse_known_args()
args

Namespace(f='/home/yingfei/.local/share/jupyter/runtime/kernel-e533e146-4fb6-48f1-ae99-6fbff83c0776.json', seed=42, device=0, batch_size=128, lr=0.0001, layer=3, hidden_dim=8, layer_drug=3, dim_drug=128, dim_drug_cell=256, dropout_ratio=0.1, epochs=300, patience=100, mode='train', edge='STRING', string_edge=990, dataset='2369disjoint', trans=True, sim=False)

In [3]:
# Override the CLI default dataset ('2369disjoint'). This value is interpolated into the
# edge-index and cell-feature file names loaded in the next cell.
args.dataset = "disjoint"

In [4]:
# Turn the integer --device index into a CUDA device string. The guard makes the cell
# idempotent: re-running it no longer turns 'cuda:0' into 'cuda:cuda:0'.
if isinstance(args.device, int):
    args.device = 'cuda:{}'.format(args.device)
rpath = '../'
result_path = os.path.join(rpath, 'Result_0308/Result_4/')  ### Adjust folder name
data_dir = os.path.join(rpath, 'DRPreter', 'Data')

print(f'seed: {args.seed}')
set_random_seed(args.seed)

# The edge-index file name encodes the edge source (e.g. PPI_990) and the dataset split.
edge_type = 'PPI_' + str(args.string_edge) if args.edge == 'STRING' else args.edge
edge_index = np.load(os.path.join(data_dir, 'Cell', f'edge_index_{edge_type}_{args.dataset}_seg_cnv.npy'))  ### Adjust file name

data = pd.read_csv(os.path.join(data_dir, 'sorted_AUC.csv'))

### Adjust file name
# allow_pickle=True unpickles arbitrary objects: only load trusted, locally generated files.
drug_dict = np.load(os.path.join(data_dir, 'Drug', 'drug_feature_graph.npy'), allow_pickle=True).item()  # pyg format of drug graph
cell_dict = np.load(os.path.join(data_dir, 'Cell', f'cell_feature_std_{args.dataset}_seg_cnv.npy'), allow_pickle=True).item()  # pyg data format of cell graph

# Use one cell line as a template to read the node-feature width and the number of gene nodes.
example = cell_dict['ACH-000001']
args.num_feature = example.x.shape[1]
args.num_genes = example.x.shape[0]

seed: 42


In [5]:
### Added
# Count gene nodes per index of x_mask (one count per pathway, judging by n_pathways below).
gene_list = scatter_add(torch.ones_like(example.x.squeeze()), example.x_mask.to(torch.int64)).to(torch.int)
args.max_gene = gene_list.max().item()
# Exclusive prefix sum of the per-pathway counts: starts at 0 and ends at the total gene count.
args.cum_num_nodes = torch.cat([gene_list.new_zeros(1), gene_list.cumsum(dim=0)], dim=0)
args.n_pathways = gene_list.size(0)
assert args.cum_num_nodes[-1].item() == args.num_genes, 'pathway gene counts do not cover every gene node'

num_edges = edge_index.shape[1]
print(f'num_genes:{args.num_genes}, num_edges:{num_edges}')
print(f'gene distribution: {gene_list}')
print(f'mean degree:{num_edges / args.num_genes}')

num_genes:4787, num_edges:12244
gene distribution: tensor([142,  94, 137, 294,  74, 155, 218,  79, 146, 190, 136, 293, 167,  73,
        232,  59, 239, 109, 104,  85, 126, 101, 221,  59,  56,  71,  44, 351,
        162,  96, 102,  41,  88, 102, 141], dtype=torch.int32)
mean degree:2.557760601629413


In [6]:
# Hard-code GPU 0 for everything that reads args.device.
# NOTE(review): this discards the --device CLI value; remove it to honour --device.
args.device = "cuda:0" # 'cuda:{}'.format(args.device)

In [7]:
# Build train/val/test loaders with load_data (from utils) over drug and cell-line graphs.
edge_index_tensor = torch.tensor(edge_index, dtype=torch.long)
train_loader, val_loader, test_loader = load_data(data, drug_dict, cell_dict, edge_index_tensor, args)
print(f'total: {len(data)}, train: {len(train_loader.dataset)}, val: {len(val_loader.dataset)}, test: {len(test_loader.dataset)}')
# Place the model on the configured device (not a hard-coded one) so it agrees with the
# input batches that later cells move to args.device.
model = DRPreter(args).to(args.device)

total: 227436, train: 184020, val: 20472, test: 22944


In [8]:
# Checkpoint path: the file name carries a '_sim' suffix when args.sim is set.
weights_dir = 'weights_0308/weights_4'  ### Adjust folder name
sim_tag = '_sim' if args.sim else ''
state_dict_name = f'{weights_dir}/weight{sim_tag}_seed{args.seed}.pth'
state_dict_name

'weights_0308/weights_4/weight_seed42.pth'

In [9]:
# NOTE: torch.load unpickles the checkpoint, so only load trusted files.
checkpoint = torch.load(state_dict_name, map_location=args.device)
load_result = model.load_state_dict(checkpoint['model_state_dict'])
# Inference only: switch off training-mode behaviour (e.g. dropout, args.dropout_ratio)
# so the embeddings extracted below are deterministic.
model.eval()
load_result

<All keys matched successfully>

In [10]:
# Inspect the (drug embedding, cell-line embedding) pair for one drug / cell-line combination.
embedding(model, "JW-7-24-1", "ACH-000001", drug_dict, cell_dict, edge_index, args) # drug embedding and cell line embedding

(tensor([[0.0000e+00, 1.8140e+00, 2.0742e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 5.5125e-01, 7.2253e-01, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 6.2656e-01, 2.2766e+00, 0.0000e+00, 2.7307e+00, 5.0410e-01,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 2.6944e+00, 0.0000e+00, 1.5380e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.3025e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.0905e+00, 1.0375e+00,
          1.6871e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.1949e-01,
          0.0000e+00, 0.0000e+00, 7.3568e-01, 0.0000e+00, 5.3732e-02, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.1285e-01,
          1.8296e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          4.2197

In [11]:
# Keep only the cell-line part (index 1) of the (drug, cell-line) embedding pair.
pair_embeddings = embedding(model, "JW-7-24-1", "ACH-000001", drug_dict, cell_dict, edge_index, args)
cell_line_embedding = pair_embeddings[1]
cell_line_embedding.shape

torch.Size([1, 35, 256])

In [12]:
# First pathway's embedding vector for the single batch element (embedding shape is [1, 35, 256]).
cell_line_embedding[0][0] # pathway embedding

tensor([0.0000, 0.0000, 0.0225, 0.0000, 0.0000, 0.0000, 0.0688, 0.0240, 0.0000,
        0.0000, 0.0121, 0.0101, 0.0324, 0.0000, 0.0000, 0.0000, 0.0000, 0.1403,
        0.0014, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0776, 0.0000, 0.0000, 0.0000, 0.0308, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0255, 0.0325, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0057, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0505, 0.0000, 0.0157, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0104, 0.0000, 0.0000, 0.0053, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.2740, 0.0000, 0.0000,
        0.0000, 0.0146, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0700, 0.0000,
        0.0000, 0.0157, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0053,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [13]:
# Drug embedding of JW-7-24-1 when paired with cell line ACH-000001.
embedding(model, "JW-7-24-1", "ACH-000001", drug_dict, cell_dict, edge_index, args)[0] # drug embedding

tensor([[0.0000e+00, 1.8140e+00, 2.0742e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 5.5125e-01, 7.2253e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 6.2656e-01, 2.2766e+00, 0.0000e+00, 2.7307e+00, 5.0410e-01,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 2.6944e+00, 0.0000e+00, 1.5380e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.3025e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.0905e+00, 1.0375e+00,
         1.6871e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.1949e-01,
         0.0000e+00, 0.0000e+00, 7.3568e-01, 0.0000e+00, 5.3732e-02, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.1285e-01,
         1.8296e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         4.2197e+00, 1.8371e

In [14]:
# Same drug paired with a different cell line: the visible output matches the ACH-000001
# result, suggesting the drug embedding does not depend on the cell line.
embedding(model, "JW-7-24-1", "ACH-000002", drug_dict, cell_dict, edge_index, args)[0] # drug embedding

tensor([[0.0000e+00, 1.8140e+00, 2.0742e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 5.5125e-01, 7.2253e-01, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 6.2656e-01, 2.2766e+00, 0.0000e+00, 2.7307e+00, 5.0410e-01,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 2.6944e+00, 0.0000e+00, 1.5380e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.3025e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.0905e+00, 1.0375e+00,
         1.6871e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 5.1949e-01,
         0.0000e+00, 0.0000e+00, 7.3568e-01, 0.0000e+00, 5.3732e-02, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.1285e-01,
         1.8296e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         4.2197e+00, 1.8371e

In [15]:
# Debugging aid: dump the cell encoder's attributes (GATConv stack, hidden size, node count).
model.CellEncoder.__dict__

{'training': False,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('convs_cell',
               ModuleList(
                 (0): GATConv(1, 8, heads=1)
                 (1): GATConv(8, 8, heads=1)
                 (2): GATConv(8, 8, heads=1)
               ))]),
 'num_feature': 1,
 'layer_cell': 3,
 'dim_cell': 8,
 'final_node': 4787}

In [16]:
# Node features of the template cell line (args.num_feature columns per gene node).
example.x

tensor([[ 1.6609],
        [-0.0702],
        [-1.4185],
        ...,
        [-1.5888],
        [ 1.8395],
        [ 1.2880]])

In [17]:
# Wrap a single cell-line graph into a batch on the target device.
# NOTE(review): `Batch` is not imported explicitly here; presumably it comes from `from utils import *` — confirm.
cell = Batch.from_data_list([cell_dict["ACH-000001"]]).to(args.device)

In [18]:
# Show the batch summary (node features, mask, edges, batch vector).
cell

DataBatch(x=[4787, 1], x_mask=[4787], edge_index=[2, 12244], batch=[4787], ptr=[2])

In [19]:
# Run only the cell-line encoder; no_grad avoids building an autograd graph for a pure inspection.
with torch.no_grad():
    cell_encoding = model.CellEncoder(cell)
# Flattened output; 38296 == 4787 gene nodes * hidden_dim 8 in the run shown.
cell_encoding.shape

torch.Size([1, 38296])

In [20]:
# Encode cell line ACH-000001 without tracking gradients (inspection only).
cell = Batch.from_data_list([cell_dict["ACH-000001"]]).to(args.device)
with torch.no_grad():
    cell_encoding = model.CellEncoder(cell)
cell_encoding

tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.4808, 0.4501]],
       device='cuda:0', grad_fn=<ReshapeAliasBackward0>)

In [21]:
# First 8 encoder outputs for ACH-000001, computed without an autograd graph.
cell = Batch.from_data_list([cell_dict["ACH-000001"]]).to(args.device)
with torch.no_grad():
    cell_encoding = model.CellEncoder(cell)
cell_encoding[0][:8]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0439, 0.0000, 0.0091, 0.0000],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [22]:
# First 8 encoder outputs for ACH-000002, for comparison with ACH-000001.
cell = Batch.from_data_list([cell_dict["ACH-000002"]]).to(args.device)
with torch.no_grad():
    cell_encoding = model.CellEncoder(cell)
cell_encoding[0][:8]

tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0013, 0.0000, 0.0000],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [23]:
# Sanity check: the per-pathway gene counts add up to the number of gene nodes.
gene_list.sum()

tensor(4787)

In [24]:
import pickle

# Pathway gene-set mapping from a trusted local pickle (pickle.load can execute code).
# NOTE(review): the file name says 34 pathways while args.n_pathways is 35 — confirm.
kegg_path = 'Data/Cell/34pathway_score990_isolated.pkl'
with open(kegg_path, mode='rb') as kegg_file:
    kegg = pickle.Unpickler(kegg_file).load()

In [25]:
# Total size of all gene sets; compare with the gene-node count above.
sum(len(genes) for genes in kegg.values())

4787

In [26]:
# First pathway embedding of ACH-000001 when paired with JW-7-24-1.
embedding(model, "JW-7-24-1", "ACH-000001", drug_dict, cell_dict, edge_index, args)[1][0][0] # pathway embedding

tensor([0.0000, 0.0000, 0.0225, 0.0000, 0.0000, 0.0000, 0.0688, 0.0240, 0.0000,
        0.0000, 0.0121, 0.0101, 0.0324, 0.0000, 0.0000, 0.0000, 0.0000, 0.1403,
        0.0014, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0776, 0.0000, 0.0000, 0.0000, 0.0308, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0255, 0.0325, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0057, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0505, 0.0000, 0.0157, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0104, 0.0000, 0.0000, 0.0053, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.2740, 0.0000, 0.0000,
        0.0000, 0.0146, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0700, 0.0000,
        0.0000, 0.0157, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0053,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [27]:
# Different drug, same cell line: the visible output matches the JW-7-24-1 result,
# suggesting the cell-line embedding does not depend on the drug.
embedding(model, "KIN001-260", "ACH-000001", drug_dict, cell_dict, edge_index, args)[1][0][0]

tensor([0.0000, 0.0000, 0.0225, 0.0000, 0.0000, 0.0000, 0.0688, 0.0240, 0.0000,
        0.0000, 0.0121, 0.0101, 0.0324, 0.0000, 0.0000, 0.0000, 0.0000, 0.1403,
        0.0014, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0776, 0.0000, 0.0000, 0.0000, 0.0308, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0255, 0.0325, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0057, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0505, 0.0000, 0.0157, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0104, 0.0000, 0.0000, 0.0053, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.2740, 0.0000, 0.0000,
        0.0000, 0.0146, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0700, 0.0000,
        0.0000, 0.0157, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0053,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [28]:
# Same drug, different cell line: the pathway embedding changes with the cell line.
embedding(model, "JW-7-24-1", "ACH-000002", drug_dict, cell_dict, edge_index, args)[1][0][0]

tensor([0.0000, 0.0000, 0.0253, 0.0000, 0.0000, 0.0000, 0.0710, 0.0195, 0.0000,
        0.0000, 0.0230, 0.0177, 0.0339, 0.0000, 0.0000, 0.0000, 0.0000, 0.1240,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0792, 0.0008, 0.0000, 0.0000, 0.0340, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0387, 0.0252, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0567, 0.0000, 0.0094, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0162, 0.0000, 0.0000, 0.0142, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0121, 0.0000, 0.0000, 0.2872, 0.0000, 0.0064,
        0.0007, 0.0136, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0713, 0.0000,
        0.0000, 0.0216, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0086,
        0.0016, 0.0000, 0.0000, 0.0000, 

In [29]:
# Number of distinct cell lines (DepMap IDs) in the response table.
len(data.DepMap_ID.unique())

692

In [30]:
# Shape of one cell-line embedding: [batch, pathways, embedding dim].
embedding(model, "JW-7-24-1", "ACH-000002", drug_dict, cell_dict, edge_index, args)[1].shape

torch.Size([1, 35, 256])

In [31]:
# Cell-line embeddings appear not to depend on the drug (compare the KIN001-260 and
# JW-7-24-1 pathway embeddings above), so a single reference drug is used for every line.
REFERENCE_DRUG = "KIN001-260"
# no_grad: pure inference. detach().cpu(): the pickled tensors carry no autograd graph
# and can be unpickled on a machine without CUDA.
with torch.no_grad():
    cell_line_embedding_dict = {
        cell_line: embedding(model, REFERENCE_DRUG, cell_line,
                             drug_dict, cell_dict, edge_index, args)[1].detach().cpu()
        for cell_line in data.DepMap_ID.unique()
    }

In [32]:
# Persist the per-cell-line embeddings for downstream notebooks.
with open('cell_line_embedding_dict_seg_cnv.pickle', mode='wb') as out_file:
    pickle.Pickler(out_file).dump(cell_line_embedding_dict)

In [33]:
# Reload the saved embeddings (round-trip check; only unpickle trusted files).
with open('cell_line_embedding_dict_seg_cnv.pickle', mode='rb') as in_file:
    cell_line_embedding_dict = pickle.Unpickler(in_file).load()

In [34]:
# Shape check on a reloaded cell-line embedding.
cell_line_embedding_dict['ACH-000001'].shape

torch.Size([1, 35, 256])

In [35]:
# Number of distinct drugs in the response table.
len(data["Drug name"].unique())

385

In [36]:
# Drug embeddings appear not to depend on the cell line (see the JW-7-24-1 outputs for
# ACH-000001 vs ACH-000002 above), so a single reference cell line is used for every drug.
REFERENCE_CELL_LINE = 'ACH-000001'
# no_grad: pure inference. detach().cpu(): the pickled tensors carry no autograd graph
# and can be unpickled on a machine without CUDA.
with torch.no_grad():
    drug_embedding_dict = {
        drug: embedding(model, drug, REFERENCE_CELL_LINE,
                        drug_dict, cell_dict, edge_index, args)[0].detach().cpu()
        for drug in data["Drug name"].unique()
    }

In [37]:
# Persist the per-drug embeddings for downstream notebooks.
with open('drug_embedding_dict_seg_cnv.pickle', mode='wb') as out_file:
    pickle.Pickler(out_file).dump(drug_embedding_dict)

In [38]:
# Reload the saved drug embeddings (round-trip check; only unpickle trusted files).
with open('drug_embedding_dict_seg_cnv.pickle', mode='rb') as in_file:
    drug_embedding_dict = pickle.Unpickler(in_file).load()

In [39]:
# Shape check on a reloaded drug embedding.
drug_embedding_dict['JW-7-24-1'].shape

torch.Size([1, 256])