In [1]:
import os
import sys
import argparse
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from utils import *
from Model.DRPreter import DRPreter
from Model.Similarity import Similarity
from torch_scatter import scatter_add

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('-f')
parser.add_argument('--seed', type=int, default=42, help='seed')
parser.add_argument('--device', type=int, default=0, help='device')
parser.add_argument('--batch_size', type=int, default=128, help='batch size (default: 128)')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate (default: 0.0001)')
parser.add_argument('--layer', type=int, default=3, help='Number of cell layers')
parser.add_argument('--hidden_dim', type=int, default=8, help='hidden dim for cell')
parser.add_argument('--layer_drug', type=int, default=3, help='Number of drug layers')
parser.add_argument('--dim_drug', type=int, default=128, help='hidden dim for drug (default: 128)')
parser.add_argument('--dim_drug_cell', type=int, default=256, help='hidden dim for drug and cell (default: 256)')
parser.add_argument('--dropout_ratio', type=float, default=0.1, help='Dropout ratio (default: 0.1)')
parser.add_argument('--epochs', type=int, default=300, help='Maximum number of epochs (default: 300)')
parser.add_argument('--patience', type=int, default=100, help='patience for early stopping (default: 10)')
parser.add_argument('--mode', type=str, default='train', help='train, test')
parser.add_argument('--edge', type=str, default='STRING', help='STRING, BIOGRID') # BIOGRID: removed
# parser.add_argument('--string_edge', type=float, default=0.99, help='Threshold for edges of cell line graph')
parser.add_argument('--string_edge', type=int, default=990, help='Threshold for edges of cell line graph')
parser.add_argument('--dataset', type=str, default='2369disjoint', help='2369joint, 2369disjoint, COSMIC')
parser.add_argument('--trans', type=bool, default=True, help='Use Transformer or not')
parser.add_argument('--sim', type=bool, default=False, help='Construct homogeneous similarity networks or not')
args = parser.parse_args()
args

Namespace(f='/home/yingfei/.local/share/jupyter/runtime/kernel-0d24bc1c-7b1f-4309-ab1f-a3e9cf468b9b.json', seed=42, device=0, batch_size=128, lr=0.0001, layer=3, hidden_dim=8, layer_drug=3, dim_drug=128, dim_drug_cell=256, dropout_ratio=0.1, epochs=300, patience=100, mode='train', edge='STRING', string_edge=990, dataset='2369disjoint', trans=True, sim=False)

In [3]:
# Override the argparse default ('2369disjoint'); the cell-feature and
# edge-index file names loaded later are built from this tag.
setattr(args, 'dataset', 'disjoint')

In [6]:
# Turn the integer GPU index from argparse into a torch device string.
# Guarded so re-running this cell does not produce 'cuda:cuda:0'.
if isinstance(args.device, int):
    args.device = 'cuda:{}'.format(args.device)
rpath = '../'
result_path = rpath + 'Result_new/Result_5/' ### Adjust folder name
# os.path.join avoids the '..//DRPreter' double slash of plain concatenation.
data_dir = os.path.join(rpath, 'DRPreter', 'Data')

print(f'seed: {args.seed}')
set_random_seed(args.seed)

# Edge list of the cell-line graph; with STRING it is tagged by the score threshold.
edge_type = f'PPI_{args.string_edge}' if args.edge == 'STRING' else args.edge
edge_index = np.load(os.path.join(data_dir, 'Cell', f'edge_index_{edge_type}_{args.dataset}_cnv.npy')) ### Adjust file name

# Response table (AUC). An IC50 variant exists: Data/sorted_IC50_82833_580_170.csv
data = pd.read_csv(os.path.join(data_dir, 'sorted_AUC.csv'))

### Adjust file name
# NOTE: allow_pickle=True unpickles arbitrary objects -- only load trusted files.
drug_dict = np.load(os.path.join(data_dir, 'Drug', 'drug_feature_graph.npy'), allow_pickle=True).item()  # pyg format of drug graph
cell_dict = np.load(os.path.join(data_dir, 'Cell', f'cell_feature_std_{args.dataset}_cnv.npy'), allow_pickle=True).item()  # pyg data format of cell graph

# Use one cell line to read off graph dimensions shared by all cell graphs.
example = cell_dict['ACH-000001']
args.num_feature = example.x.shape[1]  # node-feature width (1 in the recorded run)
args.num_genes = example.x.shape[0]    # number of graph nodes (4645 in the recorded run)

seed: 42


In [7]:
### Added
# Number of nodes (genes) per pathway: x_mask appears to hold each node's
# pathway index, so a bincount over it gives the per-pathway sizes. Unlike
# scatter_add over x.squeeze(), this does not depend on the feature width.
gene_list = torch.bincount(example.x_mask.to(torch.int64)).to(torch.int)
args.max_gene = gene_list.max().item()
# CSR-style offsets [0, n0, n0+n1, ...]; pathway i would span nodes
# cum_num_nodes[i]:cum_num_nodes[i+1] -- assumes nodes are ordered by pathway (TODO confirm).
args.cum_num_nodes = torch.cat([gene_list.new_zeros(1), gene_list.cumsum(dim=0)], dim=0)
args.n_pathways = gene_list.size(0)
num_edges = edge_index.shape[1]
print('num_genes:{}, num_edges:{}'.format(args.num_genes, num_edges))
print('gene distribution: {}'.format(gene_list))
# Stored edges per node (average out-degree of the edge list as stored).
print('mean degree:{}'.format(num_edges / args.num_genes))

num_genes:4645, num_edges:12128
gene distribution: tensor([141,  94, 137, 294,  74, 155, 218,  79, 146, 190, 136, 293, 167,  73,
        232,  59, 239, 109, 104,  85, 126, 101, 221,  59,  56,  71,  44, 351,
        162,  96, 102,  41,  88, 102], dtype=torch.int32)
mean degree:2.610979547900969


In [11]:
# Only format an integer GPU index; re-running after the string has already
# been built would otherwise yield an invalid device like 'cuda:cuda:0'.
if isinstance(args.device, int):
    args.device = 'cuda:{}'.format(args.device)

In [12]:
# Split the response table into train/val/test loaders; the cell-graph edge
# list goes in as a long tensor. Then build the model on the chosen device.
train_loader, val_loader, test_loader = load_data(
    data, drug_dict, cell_dict, torch.tensor(edge_index, dtype=torch.long), args
)
print(f'total: {len(data)}, train: {len(train_loader.dataset)}, '
      f'val: {len(val_loader.dataset)}, test: {len(test_loader.dataset)}')
model = DRPreter(args).to(args.device)

total: 227436, train: 184020, val: 20472, test: 22944


In [22]:
# Checkpoint to restore; runs with similarity networks (args.sim) use a 'sim_' tag.
weights_dir = 'weights_new/weights_3'
sim_tag = 'sim_' if args.sim else ''
state_dict_name = f'{weights_dir}/weight_{sim_tag}seed{args.seed}.pth'
state_dict_name

'weights_new/weights_3/weight_seed42.pth'

In [24]:
# Restore trained weights from the checkpoint's 'model_state_dict' entry.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
model.load_state_dict(
    torch.load(state_dict_name, map_location=args.device)['model_state_dict']
)

<All keys matched successfully>

In [28]:
embedding(
    model,
    "JW-7-24-1",   # drug, presumably a drug_dict key
    "ACH-000001",  # cell line (a cell_dict key)
    drug_dict,
    cell_dict,
    edge_index,
    args,
)  # -> (drug embedding, cell line embedding)

(tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 2.1086, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          1.5615, 0.0000, 0.0000, 0.0000, 0.2116, 0.0000, 0.0000, 0.2023, 0.0000,
          3.1867, 0.0000, 0.0000, 1.0397, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4473, 0.0000, 0.0000,
          0.0000, 4.9132, 2.4176, 0.0000, 0.0000, 0.0000, 0.1065, 0.0000, 0.0821,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 1.1596, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.1694, 0.2488, 0.0000, 0.0000, 0.2249, 2.5764,
          0.0000, 0.0000, 0.0000, 0.0000, 1.3447, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0607, 0.0000, 0.1639, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 2.7443, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000

In [39]:
# Inference only: no_grad stops autograd recording (CellEncoder outputs carry a
# grad_fn), so the stored tensor does not keep the forward graph and its
# activations alive -- harmless if embedding() already disables grad.
with torch.no_grad():
    # Element [1] is the cell-line embedding; [0] is the drug embedding.
    cell_line_embedding = embedding(model, "JW-7-24-1", "ACH-000001", drug_dict, cell_dict, edge_index, args)[1]
cell_line_embedding.shape

torch.Size([1, 34, 256])

In [43]:
cell_line_embedding[0, 0]  # sample 0, pathway 0: one pathway embedding vector

tensor([0.0000e+00, 0.0000e+00, 5.6932e-03, 0.0000e+00, 1.6928e-02, 0.0000e+00,
        6.4673e-02, 1.9251e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.6711e-02,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 4.2571e-02,
        2.4963e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 2.5433e-02, 3.8864e-03, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.6343e-03, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        3.0692e-02, 5.2445e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 5.6996e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 2.6161e-02, 0.0000e+00, 2.4850e-02, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 2.0478e-

In [70]:
# Only the drug half ([0]) of the (drug, cell line) embedding pair.
embedding(
    model, "JW-7-24-1", "ACH-000001",
    drug_dict, cell_dict, edge_index, args,
)[0]

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 2.1086, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         1.5615, 0.0000, 0.0000, 0.0000, 0.2116, 0.0000, 0.0000, 0.2023, 0.0000,
         3.1867, 0.0000, 0.0000, 1.0397, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4473, 0.0000, 0.0000,
         0.0000, 4.9132, 2.4176, 0.0000, 0.0000, 0.0000, 0.1065, 0.0000, 0.0821,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 1.1596, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.1694, 0.2488, 0.0000, 0.0000, 0.2249, 2.5764,
         0.0000, 0.0000, 0.0000, 0.0000, 1.3447, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0607, 0.0000, 0.1639, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 2.7443, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0

In [81]:
# Same drug, different cell line: in the recorded run the drug embedding ([0])
# is identical to the one computed for ACH-000001.
embedding(
    model, "JW-7-24-1", "ACH-000002",
    drug_dict, cell_dict, edge_index, args,
)[0]

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 2.1086, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         1.5615, 0.0000, 0.0000, 0.0000, 0.2116, 0.0000, 0.0000, 0.2023, 0.0000,
         3.1867, 0.0000, 0.0000, 1.0397, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4473, 0.0000, 0.0000,
         0.0000, 4.9132, 2.4176, 0.0000, 0.0000, 0.0000, 0.1065, 0.0000, 0.0821,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 1.1596, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.1694, 0.2488, 0.0000, 0.0000, 0.2249, 2.5764,
         0.0000, 0.0000, 0.0000, 0.0000, 1.3447, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0607, 0.0000, 0.1639, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 2.7443, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0

In [46]:
vars(model.CellEncoder)  # instance attributes of the cell-graph encoder module

{'training': False,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('convs_cell',
               ModuleList(
                 (0): GATConv(1, 8, heads=1)
                 (1): GATConv(8, 8, heads=1)
                 (2): GATConv(8, 8, heads=1)
               ))]),
 'num_feature': 1,
 'layer_cell': 3,
 'dim_cell': 8,
 'final_node': 4645}

In [50]:
getattr(example, 'x')  # node feature matrix of the example cell-line graph

tensor([[ 0.6484],
        [-0.2951],
        [-2.0614],
        ...,
        [ 0.8302],
        [ 1.0838],
        [-0.2596]])

In [61]:
# Wrap a single cell-line graph into a one-element batch on the model's device.
cell = (
    Batch.from_data_list([cell_dict["ACH-000001"]])
    .to(args.device)
)

In [62]:
cell  # one-graph DataBatch for the selected cell line

DataBatch(x=[4645, 1], x_mask=[4645], edge_index=[2, 12128], batch=[4645], ptr=[2])

In [63]:
model.CellEncoder(cell).size()  # flattened per-node encoder output for the batch

torch.Size([1, 37160])

In [66]:
# Encode the ACH-000001 graph and show the raw encoder output.
cell = (
    Batch.from_data_list([cell_dict["ACH-000001"]])
    .to(args.device)
)
model.CellEncoder(cell)

tensor([[0.0000, 0.0000, 0.0602,  ..., 0.0000, 0.0000, 0.0000]],
       device='cuda:0', grad_fn=<ReshapeAliasBackward0>)

In [73]:
# First 8 encoder outputs for ACH-000001.
cell = (
    Batch.from_data_list([cell_dict["ACH-000001"]])
    .to(args.device)
)
model.CellEncoder(cell)[0][:8]

tensor([0.0000, 0.0000, 0.0602, 0.0141, 0.0614, 0.0000, 0.0085, 0.0210],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [84]:
# First 8 encoder outputs for ACH-000002 (all zero in the recorded run).
cell = (
    Batch.from_data_list([cell_dict["ACH-000002"]])
    .to(args.device)
)
model.CellEncoder(cell)[0][:8]

tensor([0., 0., 0., 0., 0., 0., 0., 0.], device='cuda:0',
       grad_fn=<SliceBackward0>)

In [87]:
torch.sum(gene_list)  # total nodes across all pathways

tensor(4645)

In [94]:
import pickle

# Pathway membership file (34 pathways, STRING score 990); values are
# presumably per-pathway gene collections -- TODO confirm the schema.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
KEGG_PATHWAY_FILE = 'Data/Cell/34pathway_score990.pkl'
with open(KEGG_PATHWAY_FILE, 'rb') as file:
    kegg = pickle.load(file)

In [99]:
# Total pathway-membership entries. NOTE(review): this is 4646 in the recorded
# run while gene_list.sum() (graph nodes) is 4645 -- presumably one gene is
# counted in two pathways or was dropped when building the graph; confirm.
sum(len(genes) for genes in kegg.values())

4646