In [1]:
import os
from typing import Dict

import torch
import numpy as np
import pandas as pd
import hickle as hkl
from tqdm import tqdm
import deepchem as dc
import torch.nn as nn
from rdkit import Chem
import scipy.sparse as sp
import torch.nn.functional as F
from torch_geometric.data import Data as GraphData
from torch.utils.data import Dataset, DataLoader
from torch_geometric.utils import dense_to_sparse
from torch_geometric.nn import GCNConv, global_mean_pool, BatchNorm, Sequential as GraphSequential
from torch_geometric.loader import DataLoader as GraphDataLoader

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
No normalization for NumAmideBonds. Feature removed!
No normalization for NumAtomStereoCenters. Feature removed!
No normalization for NumBridgeheadAtoms. Feature removed!
No normalization for NumHeterocycles. Feature removed!
No normalization for NumSpiroAtoms. Feature removed!
No normalization for NumUnspecifiedAtomStereoCenters. Feature removed!
No normalization for Phi. Feature removed!


Instructions for updating:
experimental_relax_shapes is deprecated, use reduce_retracing instead


Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'dgl'
Skipped loading modules with transformers dependency. No module named 'transformers'
cannot import name 'HuggingFaceModel' from 'deepchem.models.torch_models' (/Users/ericmonzon/mambaforge/envs/tensorflow/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading modules with pytorch-lightning dependency, missing a dependency. No module named 'lightning'
Skipped loading some Jax models, missing a dependency. No module named 'jax'


In [2]:
class MultiOmicsDataset(Dataset):
    """Map-style dataset pairing a drug with a cell line's omics features.

    The CSV at ``table_path`` must contain the columns ``drug_id`` and
    ``cell_line_id`` plus the target column of the selected mode
    (``label`` for classification, ``ic50`` for regression).

    Expected directory layout::

        <drug_dir>/<drug_id>/drug-feature.npy
        <drug_dir>/<drug_id>/drug-edge-list.npy
        <cell_line_dir>/<cell_line_id>/gene-feature.npy
        <cell_line_dir>/<cell_line_id>/methylation-feature.npy
        <cell_line_dir>/<cell_line_id>/mutation-feature.npy

    Args:
        table_path: Path of the CSV listing the (drug, cell line) samples.
        drug_dir: Directory holding one sub-directory per drug id.
        cell_line_dir: Directory holding one sub-directory per cell line id.
        mode: ``"classification"`` or ``"regression"``; selects the target
            column.

    Raises:
        AssertionError: If ``mode`` is not one of the supported values.
    """

    def __init__(
        self,
        table_path: str,
        drug_dir: str,
        cell_line_dir: str,
        mode: str = "classification"
        ):
        super().__init__()
        valid_modes = ["classification", "regression"]
        assert mode in valid_modes, f"mode must be one of {valid_modes}"

        self.table = pd.read_csv(table_path)
        self.drug_dir = drug_dir
        self.cell_line_dir = cell_line_dir
        self.mode = mode

    def __len__(self):
        """Return the number of rows (samples) in the table."""
        return len(self.table)

    def __getitem__(self, idx):
        """Return ``(drug_dict, cell_line_dict, target)`` for row ``idx``.

        ``drug_dict`` only carries the *paths* of the drug's node-feature and
        edge-list files; the arrays themselves are not loaded here.
        ``cell_line_dict`` carries the three loaded omics arrays.
        """
        # Read every field from its own column rather than materialising the
        # row with ``self.table.iloc[idx]``: a row Series has one dtype, so a
        # table whose columns are all numeric would upcast integer ids to
        # float (directory "42.0" instead of "42") and integer labels to
        # float.
        drug_id = str(self.table["drug_id"].iloc[idx])
        # Cast like drug_id so numeric cell line ids also form valid paths.
        cell_line_id = str(self.table["cell_line_id"].iloc[idx])
        target_column = "label" if self.mode == "classification" else "ic50"
        target = self.table[target_column].iloc[idx]

        drug_dir = os.path.join(self.drug_dir, drug_id)
        cell_line_dir = os.path.join(self.cell_line_dir, cell_line_id)

        drug_dict = {
            "feature_path": os.path.join(drug_dir, "drug-feature.npy"),
            "edge_list_path": os.path.join(drug_dir, "drug-edge-list.npy")
        }

        cell_line_dict = {
            "gene_expression": np.load(
                os.path.join(cell_line_dir, "gene-feature.npy")
            ),
            "methylation": np.load(
                os.path.join(cell_line_dir, "methylation-feature.npy")
            ),
            "mutation": np.load(
                os.path.join(cell_line_dir, "mutation-feature.npy")
            )
        }

        return drug_dict, cell_line_dict, target

In [3]:
# Root of the cleaned dataset, given relative to the current working directory.
data_dir = os.path.join("..", "data", "cleaned")

# Sample table plus the per-drug and per-cell-line feature directories.
table_path = os.path.join(data_dir, "train.csv")
drug_dir = os.path.join(data_dir, "drugs")
cell_line_dir = os.path.join(data_dir, "cell-line")

# `mode` is left at its default, "classification".
dataset = MultiOmicsDataset(table_path, drug_dir, cell_line_dir)

In [4]:
# Plain torch DataLoader with a small batch, used here to inspect one batch.
loader = DataLoader(dataset, batch_size=5)

# Pull a single collated batch: drug paths, omics features and targets.
drug_dict, cell_line_dict, target = next(iter(loader))

In [5]:
# After collation each key holds a list of file paths, one per sample.
drug_dict

{'feature_path': ['../data/cleaned/drugs/11282283/drug-feature.npy',
  '../data/cleaned/drugs/216326/drug-feature.npy',
  '../data/cleaned/drugs/6918289/drug-feature.npy',
  '../data/cleaned/drugs/56965967/drug-feature.npy',
  '../data/cleaned/drugs/300471/drug-feature.npy'],
 'edge_list_path': ['../data/cleaned/drugs/11282283/drug-edge-list.npy',
  '../data/cleaned/drugs/216326/drug-edge-list.npy',
  '../data/cleaned/drugs/6918289/drug-edge-list.npy',
  '../data/cleaned/drugs/56965967/drug-edge-list.npy',
  '../data/cleaned/drugs/300471/drug-edge-list.npy']}

In [6]:
# After collation each omics modality is one tensor stacked over the batch.
cell_line_dict

{'gene_expression': tensor([[4.9593, 2.1110, 4.0179,  ..., 6.9226, 4.7301, 0.6690],
         [6.0851, 0.0000, 3.3730,  ..., 5.2079, 5.9896, 1.0496],
         [5.8585, 2.6622, 3.7015,  ..., 7.8377, 6.5159, 2.3132],
         [5.9741, 3.1554, 4.1383,  ..., 6.0906, 5.6496, 0.7312],
         [6.5563, 0.7049, 4.0339,  ..., 5.5110, 4.4296, 1.1309]]),
 'methylation': tensor([[0.0000, 0.0000, 0.1769,  ..., 0.0020, 0.0020, 0.0085],
         [0.0027, 0.0246, 0.0785,  ..., 0.0084, 0.0084, 0.0026],
         [0.0042, 0.3307, 0.0140,  ..., 0.0006, 0.0006, 0.0021],
         [0.0000, 0.0009, 0.0706,  ..., 0.0255, 0.0255, 0.0045],
         [0.0000, 0.0045, 0.5713,  ..., 0.0057, 0.0057, 0.0000]]),
 'mutation': tensor([[[0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.]],
 
         [[0., 0., 0.,  ..., 0., 0., 0.]]])}

In [7]:
# One target value per sample in the batch.
target.shape

torch.Size([5])

In [8]:
def extract_graph(drug_dict: Dict[str, list]) -> list:
    """Load the drug graphs referenced by a collated batch of drug paths.

    Args:
        drug_dict: Mapping with the keys ``"feature_path"`` and
            ``"edge_list_path"``. Each value is a list with one ``.npy``
            path per sample, i.e. the form the per-sample path dictionaries
            take after a ``DataLoader`` has collated them.

    Returns:
        A list with one ``GraphData`` per sample, in input order. ``x`` is
        the node-feature array as a float32 tensor and ``edge_index`` is the
        edge-list array as an int64 tensor.

    Raises:
        ValueError: If the two path lists differ in length.
    """
    feature_paths = drug_dict["feature_path"]
    edge_list_paths = drug_dict["edge_list_path"]

    # zip() would silently drop the surplus entries of the longer list.
    if len(feature_paths) != len(edge_list_paths):
        raise ValueError(
            "feature_path and edge_list_path must have the same length, "
            f"got {len(feature_paths)} and {len(edge_list_paths)}"
        )

    graphs = []
    for feature_path, edge_list_path in zip(feature_paths, edge_list_paths):
        node_features = torch.tensor(np.load(feature_path), dtype=torch.float32)
        edge_index = torch.tensor(np.load(edge_list_path), dtype=torch.long)
        graphs.append(GraphData(x=node_features, edge_index=edge_index))

    return graphs

In [9]:
# Load the batch's drug graphs from the collated path lists.
graphs = extract_graph(drug_dict)

graphs

[Data(x=[32, 75], edge_index=[2, 74]),
 Data(x=[19, 75], edge_index=[2, 42]),
 Data(x=[73, 75], edge_index=[2, 152]),
 Data(x=[38, 75], edge_index=[2, 84]),
 Data(x=[27, 75], edge_index=[2, 56])]

In [10]:
# batch_size equals the number of graphs, so the loader yields exactly one
# batch that contains all of them, in their original order (shuffle=False).
graph_loader = GraphDataLoader(graphs, batch_size=len(graphs), shuffle=False)

batch = next(iter(graph_loader))

In [11]:
# Node features of all graphs in one tensor: 32 + 19 + 73 + 38 + 27 = 189 nodes.
batch.x.shape

torch.Size([189, 75])

In [12]:
class DrugGCN(nn.Module):
    """Graph convolutional encoder producing one embedding per drug graph.

    The network is a stack of identical blocks
    (``GCNConv -> ReLU -> BatchNorm -> Dropout``) followed by mean pooling
    over the nodes of each graph.

    Args:
        in_dim: Size of the input node features.
        hidden_dim: Width of the intermediate blocks.
        num_hidden: Total number of GCN blocks, counting the input and the
            output block; ``num_hidden - 2`` blocks sit between them.
            Must be at least 2.
        output_dim: Size of the returned graph embedding.
        dropout_prob: Dropout probability used in every block.

    Raises:
        ValueError: If ``num_hidden`` is smaller than 2.
    """

    def __init__(
        self,
        in_dim: int = 75,
        hidden_dim: int = 256,
        num_hidden: int = 3,
        output_dim: int = 100,
        dropout_prob: float = 0.1
        ):
        super().__init__()

        if num_hidden < 2:
            raise ValueError(
                f"num_hidden must be at least 2 (input and output block), got {num_hidden}"
            )

        self.embedding_layer = GraphSequential(
            "x, edge_index",
            self._gcn_block(in_dim, hidden_dim, dropout_prob)
        )

        # All intermediate blocks live flat inside a single Sequential.
        hidden_layers = []
        for _ in range(num_hidden - 2):
            hidden_layers.extend(self._gcn_block(hidden_dim, hidden_dim, dropout_prob))

        # PyG's Sequential rejects an empty module list, so with
        # num_hidden == 2 the intermediate stage is skipped altogether.
        self.hidden_layer = (
            GraphSequential("x, edge_index", hidden_layers)
            if hidden_layers
            else None
        )

        self.output_layer = GraphSequential(
            "x, edge_index",
            self._gcn_block(hidden_dim, output_dim, dropout_prob)
        )

    @staticmethod
    def _gcn_block(in_dim: int, out_dim: int, dropout_prob: float) -> list:
        """Return the module list of one GCNConv -> ReLU -> BatchNorm -> Dropout block."""
        return [
            (GCNConv(in_dim, out_dim), "x, edge_index -> x"),
            nn.ReLU(),
            BatchNorm(out_dim),
            nn.Dropout(p=dropout_prob)
        ]

    def forward(self, x, edge_index, batch):
        """Encode a batch of graphs.

        Args:
            x: Node features of all graphs in the batch.
            edge_index: Edge index of the batched graph.
            batch: Per-node graph assignment passed to ``global_mean_pool``.

        Returns:
            The mean-pooled node embeddings, one row per graph.
        """
        x = self.embedding_layer(x, edge_index)

        if self.hidden_layer is not None:
            x = self.hidden_layer(x, edge_index)

        x = self.output_layer(x, edge_index)

        embedding = global_mean_pool(x, batch)

        return embedding

In [13]:
# Smoke test: push the batched drug graphs through a default-configured encoder.
graph_net = DrugGCN()

out = graph_net(batch.x, batch.edge_index, batch.batch)

In [14]:
# One pooled embedding per graph in the batch.
out.shape

torch.Size([5, 100])

In [15]:
# Display the module structure of the encoder.
graph_net

DrugGCN(
  (embedding_layer): Sequential(
    (0) - GCNConv(75, 256): x, edge_index -> x
    (1) - ReLU(): x -> x
    (2) - BatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): x -> x
    (3) - Dropout(p=0.1, inplace=False): x -> x
  )
  (hidden_layer): Sequential(
    (0) - GCNConv(256, 256): x, edge_index -> x
    (1) - ReLU(): x -> x
    (2) - BatchNorm(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): x -> x
    (3) - Dropout(p=0.1, inplace=False): x -> x
  )
  (output_layer): Sequential(
    (0) - GCNConv(256, 100): x, edge_index -> x
    (1) - ReLU(): x -> x
    (2) - BatchNorm(100, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True): x -> x
    (3) - Dropout(p=0.1, inplace=False): x -> x
  )
)

In [16]:
# Shape of the collated mutation features; note the singleton middle axis,
# which MutationConv1d consumes as its single input channel.
cell_line_dict["mutation"].shape

torch.Size([5, 1, 34673])

In [17]:
class MutationConv1d(nn.Module):
    """1-D convolutional encoder for mutation feature vectors.

    Two convolution/pooling stages are followed by a fully connected layer
    that maps the flattened feature map to ``output_dim`` values.

    Args:
        in_channels: Number of input channels.
        output_dim: Size of the returned embedding.
        dropout_prob: Dropout probability applied after the final layer.
        in_length: Length of the input along its last axis. It determines
            the input width of the fully connected layer; the default
            (34673) yields the 2010 features this layer was originally
            hard-coded for.

    Raises:
        ValueError: If ``in_length`` is too short for the convolution and
            pooling windows.
    """

    def __init__(
        self,
        in_channels: int = 1,
        output_dim: int = 100,
        dropout_prob: float = 0.1,
        in_length: int = 34673
        ):
        super().__init__()

        self.conv1 = nn.Sequential(
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=50,
                kernel_size=700,
                stride=5
            ),
            nn.Tanh(),
            nn.MaxPool1d(kernel_size=5)
        )

        self.conv2 = nn.Sequential(
            nn.Conv1d(
                in_channels=50,
                out_channels=30,
                kernel_size=5,
                stride=2
            ),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=10),
            nn.Flatten()
        )

        # Derive the flattened width from the layers above instead of
        # hard-coding it, so other input lengths work as well.
        flattened_dim = self._flattened_dim(
            list(self.conv1) + list(self.conv2), in_length
        )

        self.fc = nn.Sequential(
            nn.Linear(flattened_dim, output_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout_prob)
        )

    @staticmethod
    def _as_int(value):
        """Return the scalar held by a 1-D layer hyper-parameter (int or 1-tuple)."""
        return value[0] if isinstance(value, tuple) else value

    @classmethod
    def _flattened_dim(cls, layers, in_length: int) -> int:
        """Return channels * length after the un-padded conv/pool ``layers``."""
        length = in_length
        channels = None

        for layer in layers:
            if not isinstance(layer, (nn.Conv1d, nn.MaxPool1d)):
                # Activations and Flatten do not change the length.
                continue

            kernel_size = cls._as_int(layer.kernel_size)
            stride = cls._as_int(layer.stride)

            if length < kernel_size:
                raise ValueError(
                    f"in_length={in_length} is too short for the "
                    "convolution/pooling stack"
                )

            # Output length of a window op without padding or dilation.
            length = (length - kernel_size) // stride + 1

            if isinstance(layer, nn.Conv1d):
                channels = layer.out_channels

        return channels * length

    def forward(self, x):
        """Encode ``x`` into an embedding with ``output_dim`` columns.

        ``x`` is fed straight into ``nn.Conv1d``, so its last axis must have
        the ``in_length`` the module was built for.
        """
        x = self.conv1(x)
        x = self.conv2(x)
        embedding = self.fc(x)

        return embedding

In [18]:
# Smoke test: a default-configured encoder on the batch's mutation features.
mutation_net = MutationConv1d()

mutation_net(cell_line_dict["mutation"]).shape

torch.Size([5, 100])

In [19]:
# Shape of the collated gene expression features.
cell_line_dict["gene_expression"].shape

torch.Size([5, 697])

In [20]:
class MLP(nn.Module):
    """Two-stage feed-forward encoder for a flat feature vector.

    Stage one is ``Linear -> Tanh -> BatchNorm1d -> Dropout``; stage two is
    ``Linear -> ReLU``.

    Args:
        in_dim: Number of input features.
        embedding_dim: Width of the intermediate representation.
        output_dim: Number of output features.
        dropout_prob: Dropout probability applied at the end of stage one.
    """

    def __init__(
        self,
        in_dim: int,
        embedding_dim: int = 256,
        output_dim: int = 100,
        dropout_prob: float = 0.1
        ):
        super().__init__()

        # Stage one: project the input and regularise it.
        first_stage = [
            nn.Linear(in_dim, embedding_dim),
            nn.Tanh(),
            nn.BatchNorm1d(embedding_dim),
            nn.Dropout(p=dropout_prob),
        ]
        self.fc1 = nn.Sequential(*first_stage)

        # Stage two: map to the output size; ReLU keeps the result non-negative.
        second_stage = [
            nn.Linear(embedding_dim, output_dim),
            nn.ReLU(),
        ]
        self.fc2 = nn.Sequential(*second_stage)

    def forward(self, x):
        """Return the output of both stages applied to ``x`` in sequence."""
        return self.fc2(self.fc1(x))

In [21]:
# Repeats an earlier shape check; the feature width shown here (697) is the
# in_dim passed to the gene expression MLP.
cell_line_dict["gene_expression"].shape

torch.Size([5, 697])

In [22]:
# Smoke test: MLP sized for the 697 gene expression features.
gene_net = MLP(in_dim=697)

gene_net(cell_line_dict["gene_expression"]).shape

torch.Size([5, 100])

In [23]:
# Shape of the collated methylation features.
cell_line_dict["methylation"].shape

torch.Size([5, 808])

In [24]:
# Smoke test: MLP sized for the 808 methylation features.
methylation_net = MLP(in_dim=808)

methylation_net(cell_line_dict["methylation"]).shape

torch.Size([5, 100])

In [25]:
class DeepCDR(nn.Module):
    """Drug response model combining a drug graph with three omics inputs.

    Four encoders (drug graph, gene expression, methylation, mutation) each
    produce an ``output_dim``-wide embedding. The embeddings are
    concatenated, projected to 300 features, passed through a 1-D
    convolutional stack and mapped to the prediction by a linear layer.

    Args:
        mode: ``"classification"`` (two outputs) or ``"regression"`` (one
            output).
        output_dim: Embedding size produced by each of the four encoders.
        dropout_prob: Dropout probability used by the encoders and the
            fusion layers (the final dropout is fixed at 0.2).
        gene_dim: Number of gene expression features per sample.
        methylation_dim: Number of methylation features per sample.

    Raises:
        AssertionError: If ``mode`` is not one of the supported values.
    """

    def __init__(
        self,
        mode: str = "classification",
        output_dim: int = 100,
        dropout_prob: float = 0.1,
        gene_dim: int = 697,
        methylation_dim: int = 808,
        ):
        super().__init__()

        valid_modes = ["classification", "regression"]
        assert mode in valid_modes, f"mode must be one of {valid_modes}"

        self.drug_net = DrugGCN(
            output_dim=output_dim,
            dropout_prob=dropout_prob
            )

        self.gene_net = MLP(
            in_dim=gene_dim,
            output_dim=output_dim,
            dropout_prob=dropout_prob
            )

        self.methylation_net = MLP(
            in_dim=methylation_dim,
            output_dim=output_dim,
            dropout_prob=dropout_prob
            )

        self.mutation_net = MutationConv1d(
            output_dim=output_dim,
            dropout_prob=dropout_prob
            )

        # Fuse the four concatenated embeddings into a 300-wide vector.
        self.projection = nn.Sequential(
            nn.Linear(output_dim*4, 300),
            nn.Tanh(),
            nn.Dropout(p=dropout_prob),
        )

        # Length along the last axis for the 300-wide projection:
        # 300 -> 151 -> 75 -> 71 -> 23 -> 19 -> 6, with 5 channels at the
        # end, i.e. 5 * 6 = 30 features after Flatten.
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels=1, out_channels=30, kernel_size=150, stride=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=2),
            nn.Conv1d(in_channels=30, out_channels=10, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3),
            nn.Conv1d(in_channels=10, out_channels=5, kernel_size=5, stride=1),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=3),
            nn.Dropout(p=dropout_prob),
            nn.Flatten(),
            nn.Dropout(p=0.2)
        )

        # 30 matches the flattened width of self.conv (see above).
        self.fc = nn.Linear(30, 2 if mode == "classification" else 1)

    def forward(
        self,
        drug_graphs,
        gene_expression_features,
        methylation_features,
        mutation_features
        ):
        """Predict the response for a batch of (drug, cell line) pairs.

        Args:
            drug_graphs: Batched graph object exposing ``x``, ``edge_index``
                and ``batch``.
            gene_expression_features: Input of the gene expression encoder.
            methylation_features: Input of the methylation encoder.
            mutation_features: Input of the mutation encoder.

        Returns:
            The raw output of the final linear layer (no softmax or sigmoid
            is applied here): two columns in classification mode, one in
            regression mode.
        """
        drug_embedding = self.drug_net(
            drug_graphs.x,
            drug_graphs.edge_index,
            drug_graphs.batch
        )

        gene_expression_embedding = self.gene_net(
            gene_expression_features
        )

        methylation_embedding = self.methylation_net(
            methylation_features
        )

        mutation_embedding = self.mutation_net(
            mutation_features
        )

        combined_embedding = torch.cat([
            drug_embedding,
            gene_expression_embedding,
            methylation_embedding,
            mutation_embedding
        ], dim=-1)

        x = self.projection(combined_embedding)
        # Conv1d expects a channel axis: insert a single channel at dim 1.
        x = x.unsqueeze(1)
        x = self.conv(x)

        out = self.fc(x)

        return out


In [26]:
# Sanity check: concatenating four 100-wide embeddings along the last axis
# gives the 400-wide input of DeepCDR's projection layer.
sample_embeds = [torch.randn(10, 100) for _ in range(4)]

torch.cat(sample_embeds, dim=-1).shape

torch.Size([10, 400])

In [27]:
# Smoke test of the full model (default mode: classification).
multi_net = DeepCDR()

# NOTE: this rebinds `out`, which previously held the DrugGCN output.
out = multi_net(batch, cell_line_dict["gene_expression"], cell_line_dict["methylation"], cell_line_dict["mutation"])

In [28]:
# Two output values per sample in classification mode.
out.shape

torch.Size([5, 2])