In [None]:
%load_ext blackcellmagic
# autoreload: re-import edited local modules (e.g. the `magma` package) before running each cell
%load_ext autoreload
%autoreload 2

import tqdm
from sklearn import metrics 
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import holoviews as hv 
hv.extension('bokeh')
# Registers the `.hvplot` plotting accessor used on DataFrames later in the notebook
import hvplot.pandas
import seaborn as sns

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
from torch.nn import BatchNorm1d
from torch.utils.data import Dataset

from torch_geometric.nn import global_add_pool, global_mean_pool

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import IPythonConsole
from rdkit.Chem import Draw
from torch_geometric.nn import GCNConv
#from torch_geometric.nn import ChebConv

import warnings 
# NOTE(review): this silences *every* warning, including deprecation warnings from
# torch_geometric / pandas / sklearn; consider filtering specific categories instead.
warnings.filterwarnings('ignore')

from torch_geometric.data import Batch
# NOTE(review): torch_geometric >= 2.0 moved DataLoader to `torch_geometric.loader`
# and deprecates this import path -- confirm against the installed version.
from torch_geometric.data import DataLoader as geom_dataloader
from torch_geometric.data import Data as geom_data

from torchvision import transforms
from torch.utils.data import IterableDataset, DataLoader

import anndata as ad

In [None]:
# Fixed seed for torch's and numpy's global RNGs; `seed` is also reused later
# (e.g. as the random_state of the train/validation split).
seed = 78364
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(seed)

In [None]:
# Project-local package `magma`: molecule featurisation (mc, e.g. mc.mol2tensors),
# datasets / metrics / plotting helpers (mu) and model classes (mm)
from magma import chemspace as mc
from magma import utils as mu
from magma import models as mm
# Presumably applies the project's matplotlib plotting style (per its name) -- TODO confirm
mu.set_plotting_style_plt()

All right! We're finally at the last step of model building: using contrastive learning to construct an inner-product space, a.k.a. the joint embedding, in which cells and molecules live together. Operationally, we will use the same datasets as before, so we're good to go.

### Load cell dataset

First, let's load the cell dataset and run the same pipeline we've been using to make the train-test split.

In [None]:
# Write down the path to the cell dataset
path = '../../thomson_lab/data/drugbank/'

In [None]:
a = ad.read_h5ad(path + 'mult_cd3_100_train.h5ad')

In [None]:
%%time
n_samples = 500

sampling_ix = (
    a.obs.groupby(["sample_id"])
    .apply(
        lambda group_df: group_df.sample(
            group_df.shape[0] if group_df.shape[0] < n_samples else n_samples,
            replace = False)
    )
    .index.get_level_values(1) # Get the numerical index :) 
)

In [None]:
ada = a[sampling_ix].copy()

In [None]:
ada.obs.sample_id.value_counts(False).head()

In [None]:
ada.obs.sample_id.value_counts(False).tail()

In [None]:
codes, uniques = pd.factorize(ada.obs.drug_name)

In [None]:
ada.obs['sample_code'] = codes

Now, let's make a dictionary that maps each numerical drug code to the corresponding drug name. This will help us retrieve the drug for each cell during training.

In [None]:
code_to_name = dict(ada.obs[['sample_code', 'drug_name']].values)

In [None]:
code_to_name[1]

In [None]:
code_to_name[99]

In [None]:
uniques = uniques.to_list()

In [None]:
uniques[:5]

In [None]:
len(uniques)

### Load molecules dataset 

In [None]:
path

In [None]:
drugs = pd.read_csv(path + 'thomsonlab_drugs_smiles.csv')

We will make the RDKit mol objects as in the GCN example. 

In [None]:
drugs['mol'] = drugs.SMILES.apply(Chem.MolFromSmiles)

In [None]:
df_drugs = drugs[drugs.name.isin(uniques)]

In [None]:
df_drugs.head()

In [None]:
uniques[:5]

In [None]:
drugs['drug_name'] = drugs.name.str.lower()

In [None]:
uniques[:5]

In [None]:
df_drugs = drugs[drugs.drug_name.isin(uniques)]

In [None]:
df_drugs.shape

In [None]:
# Out-of-bag (OOB)
drugs_oob = drugs[~drugs.drug_name.isin(uniques)]

In [None]:
drugs_oob.shape

In [None]:
df_drugs.shape

Finally, we need a helper function that, given a list of numerical drug codes, returns the corresponding drugs as molecular graphs in the `torch_geometric` data format.

In [None]:
def get_drug_batch(labels_batch, code_map=None, mol_map=None) -> list:
    """Return a list of torch_geometric Data objects, one per drug code in `labels_batch`.

    Parameters
    ----------
    labels_batch : iterable of int-like
        Drug codes, e.g. the LongTensor of targets from the cell DataLoader, or plain ints.
    code_map : dict, optional
        Drug code -> drug name. Defaults to the notebook-level ``code_to_name``.
    mol_map : dict, optional
        Drug name -> RDKit Mol. Defaults to the notebook-level ``name_to_mol``.

    Returns
    -------
    list
        ``mc.mol2tensors`` output for each code, in input order.

    Raises
    ------
    KeyError
        If a code or drug name is missing from the lookup tables.
    """
    # Resolve defaults at call time: the lookup tables are built in later cells.
    if code_map is None:
        code_map = code_to_name
    if mol_map is None:
        mol_map = name_to_mol
    # int() accepts 0-d tensors, numpy integers and plain Python ints alike
    return [mc.mol2tensors(mol_map[code_map[int(code)]]) for code in labels_batch]

In [None]:
code_to_name[1]

In [None]:
name_to_mol = dict(df_drugs[['drug_name', 'mol']].values)

In [None]:
# make a test
get_drug_batch(torch.LongTensor([1,2,3]))

In [None]:
from sklearn.model_selection import StratifiedShuffleSplit

# A single stratified 60/40 split. With a two-column target, sklearn stratifies on the
# combined (cell_type, sample_code) label, so both partitions keep roughly the same
# mix of cell types per drug.
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.4, random_state=seed)
train_ix, val_ix = next(
    splitter.split(ada.X, ada.obs[['cell_type', 'sample_code']])
)

### Make cell dataset and dataloader

In [None]:
train_adata = ada[train_ix].copy()
test_adata = ada[val_ix].copy()

# Initialize torch dataset 
train_dataset = mu.adata_torch_dataset(
    train_adata, transform = transforms.ToTensor(), supervised = True, target_col = 'sample_code'
)

test_dataset = mu.adata_torch_dataset(
    test_adata, transform = transforms.ToTensor(), supervised = True, target_col = 'sample_code'
)

In [None]:
train_dataset[0]

In [None]:
# Also the number of (molecule, cell) pairs contrasted against each other per step
batch_size = 32

# drop_last keeps every batch exactly `batch_size` long; the training loop reshapes with
# cell_batch.view(batch_size, -1) and uses fixed-size ordering labels.
# Shuffle the training set so each epoch sees different batches (and in-batch negatives).
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=4
)

# Validation order doesn't matter; keep it deterministic
val_loader = DataLoader(
    test_dataset, batch_size=batch_size, drop_last=True, shuffle=False, num_workers=4
)

### Initialize models with pre-trained weights 

Now, to make the joint embedding model, we initialize the cell and molecule encoders.

In [None]:
n_cats = len(uniques)
n_genes = a.n_vars
clf_dims = [n_genes, 512, 256, 64, n_cats]

In [None]:
cell_encoder = mm.supervised_model(clf_dims, model = 'multiclass', dropout = False)

In [None]:
!ls {path + 'models/' }

In [None]:
# the map_location is needed if the model was trained on a GPU 
# and this model is run on a gpu
cell_classifier_wts = torch.load(
    path + 'models/droog_100.pt',
    map_location=torch.device('cpu')
)

In [None]:
cell_encoder.load_state_dict(cell_classifier_wts)

In [None]:
cell_encoder

In [None]:
dims_conv = [18, 256, 128]
dims_linear = [128, 64, 37]
molecule_encoder = mm.GraphConvNetwork(dims_conv, dims_linear)

In [None]:
mol_encoder_wts = torch.load(path + 'models/gcn_drugbank_chem.pt')

In [None]:
molecule_encoder.load_state_dict(mol_encoder_wts)

### Initialize `JointEmbedding` model 

The `JointEmbedding` model takes the molecule encoder and the cell encoder as arguments, in that order.

In [None]:
# Wrap the two pre-trained encoders; argument order is (molecule encoder, cell encoder)
joint_embedding = mm.JointEmbedding(molecule_encoder, cell_encoder)

In [None]:
joint_embedding

We can access the individual encoders because they're now attributes of this model. 

In [None]:
# Molecule encoder stored on the joint model (presumably the GraphConvNetwork passed in above)
joint_embedding.molecule_encoder

We can also access submodules of the molecule encoder, and so on.

In [None]:
# Graph-convolution submodule of the molecule encoder
joint_embedding.molecule_encoder.conv_encoder

### Train! 

The training procedure can be visualized in the following diagram. 

In [None]:
from IPython.display import Image

In [None]:
# Diagram of the contrastive training procedure, loaded from the GitHub URL below
Image(
    url = 'https://github.com/manuflores/sandbox/blob/master/figs/diag.png?raw=true',
    format = 'png'
)

In math notation we have the following: 

$$
\mathbf{C} \in \mathbb{R}^{\text {cells} \times \text {genes} } \text { (count matrix)}\\[1.em]
\mathbf {\Psi} = f(\mathbf{C}) \in \mathbb{R}^{\text {cells} \times k } \text { (cell embeddings)} \\[1.em]
\mathbf{M} \in \mathbb{R}^{\text {molecules} \times \text {atom features} } \text { (molecule feature matrix)} \\[1.em]
\mathbf {\Phi} = g(\mathbf{M}) \in \mathbb{R}^{\text {molecules} \times k } \text { (molecule embeddings)} \\[1.em]
\mathbf {\Lambda} = \tilde{\mathbf {\Phi}} \tilde{\mathbf {\Psi}} ^\top \in \mathbb{R}^{\text {molecules} \times \text {cells}  } \\[1.em]
\mathbf {\Lambda}_{ij} = \mathrm{cos}(\theta)_ {\text {molecule}_i \text {cell}_j}
$$


$$
\text{where } \tilde{\mathbf {\Phi}} \text{ and } \tilde{\mathbf {\Psi}} \text{ are the molecule and cell embedding matrices with every row normalized to unit length, i.e.} \\[1em]
\tilde{\mathbf {\phi}}_{i} = \frac{\phi_i}{ \sqrt{ \phi_i \phi_i^\top }}
$$


$$\text{where } \phi_i \in \mathbb{R}^{1 \times k} \text{ is the row vector embedding of molecule } i \text{ in the original embedding matrix } \mathbf {\Phi} \text{ (the rows of } \mathbf {\Psi} \text{ are normalized the same way).}$$

Here $f$ is the cell encoder and $g$ is the molecule encoder. Applying a softmax along each row, and separately along each column, then turns $\Lambda$ into probability distributions: for each molecule over the cells, and for each cell over the molecules.

We will go through the whole training loop in the notebook to make it more explicit as this is important. First, let's set up our training parameters.

In [None]:
optimizer = torch.optim.Adam(joint_embedding.parameters(), lr = 1e-3, weight_decay = 0)

n_epochs = 3
# How many times to print the training loss per epoch
train_prints_per_epoch = 4

# Minibatches between loss prints (len(train_loader) == n_obs // batch_size with drop_last).
# At least 1, so the modulo check in the training loop is always defined.
print_every = max(1, len(train_loader) // train_prints_per_epoch)

train_loss_vector = []  # mean training loss over each print window
val_loss_vector = np.full(n_epochs, np.nan)  # mean validation loss per epoch (NaN until filled)

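In the current implementation we don't use one-hot vectors: `NLLLoss()` takes integer class indices as targets. We pair it with `log_softmax` (which together amount to the cross-entropy loss) because computing the log-probabilities directly is more numerically stable than taking the log of a softmax. Our ordering labels are therefore just the integers $(0, 1, \dots, \text{batch size} - 1)$: the $i$-th molecule in a batch belongs to the $i$-th cell.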

In [None]:
# The ordering labels will be the ranking indicators
ordering_labels = torch.arange(batch_size)
#ordering_labels_hot = torch.eye(batch_size)
#F.one_hot(torch.from_numpy(ordering_labels))

ordering_labels

In [None]:
criterion = nn.NLLLoss()

Now, we can finally make the training loop. 

In [None]:
def contrastive_forward(molecule_batch, cell_batch):
    """Score one minibatch of (molecule, cell) pairs in both directions.

    Returns
    -------
    log_p_mols : tensor (batch_size, batch_size)
        Row i = log-probabilities over cells for molecule i.
    log_p_cells : tensor (batch_size, batch_size)
        Row j = log-probabilities over molecules for cell j.
    """
    # Similarity matrix of shape (mols, cells)
    logits = joint_embedding(molecule_batch, cell_batch.view(batch_size, -1).float())
    log_p_mols = F.log_softmax(logits, dim=1)
    # Transpose before normalizing so each cell is a row; NLLLoss and argmax(dim=1) then
    # treat both directions identically. The NLL value equals that of
    # log_softmax(logits, dim=0), since both pick the diagonal entries.
    log_p_cells = F.log_softmax(logits.t(), dim=1)
    return log_p_mols, log_p_cells


def contrastive_loss_and_acc(log_p_mols, log_p_cells):
    """Symmetric contrastive loss and mean matching accuracy for one minibatch.

    The target of row i is i (``ordering_labels``), i.e. the diagonal pairing.
    NOTE(review): two cells in one batch treated with the same drug still count as
    negatives for each other's (identical) molecule.
    """
    loss_mols = criterion(log_p_mols, ordering_labels)
    loss_cells = criterion(log_p_cells, ordering_labels)
    # Predicted cell for each molecule, and predicted molecule for each cell
    mol_acc = mu.accuracy(log_p_mols.argmax(dim=1), ordering_labels)
    cell_acc = mu.accuracy(log_p_cells.argmax(dim=1), ordering_labels)
    return (loss_mols + loss_cells) / 2, (cell_acc + mol_acc) / 2


# Minibatches between loss prints; never 0, so the modulo below is always defined
print_every_n = max(1, int(print_every))

for epoch in range(n_epochs):
    # TRAINING LOOP
    joint_embedding.train()
    running_loss = 0.0
    # One entry per minibatch (drop_last=True, so every batch is full)
    train_acc_vector = np.zeros(len(train_loader))

    # y_true are the drug numerical labels of the cells in the batch
    for ix, (cell_batch, y_true) in enumerate(tqdm.tqdm(train_loader)):
        optimizer.zero_grad()

        # Make batch of molecular graphs: molecule i is the drug of cell i
        molecule_batch = Batch.from_data_list(get_drug_batch(y_true))

        log_p_mols, log_p_cells = contrastive_forward(molecule_batch, cell_batch)
        loss, batch_acc = contrastive_loss_and_acc(log_p_mols, log_p_cells)

        # Backprop and update weights
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        train_acc_vector[ix] = batch_acc

        if ix % print_every_n == print_every_n - 1:
            # Print and record the average loss over the last print window
            print('[%d, %5d] Loss: %.3f' %
                  (epoch + 1, ix + 1, running_loss / print_every_n))
            train_loss_vector.append(running_loss / print_every_n)
            running_loss = 0.0

    print('Mean training accuracy : %.3f' % (np.mean(train_acc_vector) * 100))

    # VALIDATION LOOP
    # eval() switches any dropout / batch-norm layers to inference behaviour; without it,
    # batch-norm running statistics would keep updating on validation data even under no_grad.
    joint_embedding.eval()
    validation_loss = []
    val_acc = np.zeros(len(val_loader))

    with torch.no_grad():
        for ix, (cell_batch, y_true) in enumerate(tqdm.tqdm(val_loader)):
            # Make batch of molecular graphs
            molecule_batch = Batch.from_data_list(get_drug_batch(y_true))

            log_p_mols, log_p_cells = contrastive_forward(molecule_batch, cell_batch)
            val_loss, val_acc[ix] = contrastive_loss_and_acc(log_p_mols, log_p_cells)

            # Store a Python float rather than a tensor
            validation_loss.append(val_loss.item())

    mean_val_loss = float(np.mean(validation_loss))
    print('Val. loss %.3f' % mean_val_loss)

    # Record the epoch's mean validation loss
    val_loss_vector[epoch] = mean_val_loss

    print('Val. accuracy : %.3f' % (np.mean(val_acc) * 100))

Save model. 

In [None]:
# Flip to True to persist the trained joint-embedding weights next to the encoder checkpoints
SAVE_MODEL = False
if SAVE_MODEL:
    torch.save(joint_embedding.state_dict(), path + 'models/joint_emb.pt')

### Visualize molecules in the new embedding space

In [None]:
#drugs_oob.shape
df_drugs.head(2)

In [None]:
all_data = [mc.mol2tensors(m) for m in tqdm.tqdm(df_drugs.mol.values)]
proj_loader = geom_dataloader(all_data, batch_size = 32, shuffle = False, drop_last = False)

In [None]:
type(proj_loader)

In [None]:
# Make projection to second to last layer
with torch.no_grad():
    joint_embedding.eval()
    projection_arr = np.array(
        list(
            joint_embedding.molecule_encoder.project_to_latent_space(
                proj_loader, dims_conv[0], dims_linear[-2]
            )
        )
    )

In [None]:
df_proj = pd.DataFrame(
    projection_arr, columns = ['dim_' + str(i) for i in range(1, dims_linear[-2] +1)]
)

df_viz = pd.concat(
    [df_drugs, df_proj.set_index(df_drugs.index)], axis = 1
)

In [None]:
df_viz.head(2)

In [None]:
from sklearn.decomposition import PCA 

pca_obj = PCA()

pca_obj.fit(projection_arr)

In [None]:
var_ratio = np.cumsum(pca_obj.explained_variance_ratio_)

plt.plot(var_ratio)
plt.xlabel('number of components')
plt.ylabel('fraction of explained variance')

In [None]:
pcs = PCA(2).fit_transform(projection_arr)

In [None]:
df_viz['pc1'], df_viz['pc2'] = pcs.T

In [None]:
df_viz.hvplot.scatter(
    x = 'pc1', 
    y = 'pc2',
    #y = ['dim_2', 'dim_3'],
    c= 'drug_class',
    hover_cols = ['name'],
    #subplots = True,
    width = 700,
    height = 480, 
    size = 60, 
    alpha =.9
)

In [None]:
df_viz.hvplot.scatter(
    x = 'dim_1', 
    y = 'dim_64',
    #y = ['dim_2', 'dim_3'],
    c= 'drug_class',
    hover_cols = ['name'],
    #subplots = True,
    width = 700,
    height = 480, 
    size = 60, 
    alpha =.9
)

### Nearest neighbor search

We can compare the chemistry-only embedding with the joint embedding by nearest-neighbor search: if training changed the encoders substantially, each molecule's nearest neighbors should differ between the two embeddings.

Here we run the nearest-neighbor search over the molecules in our dataset, but in principle we could encode all drugs from ChEMBL or DrugBank and search this expanded chemical space.

In [None]:
from sklearn.neighbors import NearestNeighbors

knn = NearestNeighbors(n_neighbors = 10).fit(projection_arr)

In [None]:
%%time
neighbors = knn.kneighbors()

In [None]:
neighbor_ixs = neighbors[1]

In [None]:
df_viz.reset_index(drop=True, inplace=True)

In [None]:
def get_neighbors(df, sample_name, sort = True, neighbor_indices = None):
    """Return the precomputed nearest neighbors of one drug as rows of `df`.

    Parameters
    ----------
    df : DataFrame
        Rows in the same order as the embedding the neighbors were computed on;
        needs a ``name`` column (and ``subclass`` when ``sort`` is True).
    sample_name : str
        Value of ``df['name']`` to query; the first matching row is used.
    sort : bool
        If True, sort the neighbor rows by ``subclass``; otherwise keep neighbor order.
    neighbor_indices : array of shape (n_rows, k), optional
        Positional neighbor indices per row. Defaults to the notebook-level ``neighbor_ixs``.

    Raises
    ------
    KeyError
        If no row has ``name == sample_name``.
    """
    if neighbor_indices is None:
        neighbor_indices = neighbor_ixs
    # Positional lookup: neighbor_indices rows follow df's row order, whatever its index labels
    matches = np.flatnonzero((df['name'] == sample_name).to_numpy())
    if matches.size == 0:
        raise KeyError('%r not found in column "name"' % (sample_name,))
    neighbors_df = df.iloc[neighbor_indices[matches[0]]]
    return neighbors_df.sort_values(by = 'subclass') if sort else neighbors_df

In [None]:
# Neighbors kept in the order returned by knn.kneighbors (sort=False)
get_neighbors(df_viz, 'Dasatinib', sort = False)

In [None]:
get_neighbors(df_viz, 'Nilotinib', sort = False)

In [None]:
get_neighbors(df_viz, 'Vorinostat', sort = False)

### Save embeddings

In [None]:
# Flip to True to write the drug table with its latent dimensions and PCs to disk
SAVE_EMBEDDINGS = False
if SAVE_EMBEDDINGS:
    df_viz.to_csv(path + 'joint_emb.csv', index = False)

### Reproducibility

In [None]:
# Record machine info (-m), Python version (-v) and package versions (-p) for reproducibility
%load_ext watermark
%watermark -m -v -p numpy,torch,anndata,sklearn,holoviews,hvplot,matplotlib,rdkit