In [11]:
import pandas as pd 
import torch
import numpy as np
from matplotlib import pyplot as plt
#from torchmetrics import JaccardIndex
import seaborn as sbn
from sklearn.metrics import roc_auc_score, accuracy_score, r2_score
import time
import torchvision
import torchvision.transforms as transforms
import copy 
from scipy.stats import spearmanr
import h5py

import sys 
sys.path.append('./src/')
from data_loading import load_tabular_data, preprocess_data, corrupt_label
from DVGS import DVGS
from DVRL import DVRL
from utils import get_corruption_scores
from NN import NN
from AE import AE
import similarities 
import DShap
from LOO import LOO
from sklearn.utils.class_weight import compute_class_weight
from utils import load_data, get_filtered_scores
from LincsEmbNN import LincsEmbNN

from sklearn.linear_model import LogisticRegression# CV
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [66]:
osig_idx = pd.read_csv('./data/processed/osig_indexed.tsv', sep='\t')
osig_idx.head()

Unnamed: 0,sig_id,pert_idx,cell_idx,log_conc,z_time
0,ABY001_A375_XH:BRD-A61304759:0.625:24,1425,4,-0.20412,0.472838
1,ABY001_A375_XH:BRD-A61304759:0.625:3,1425,4,-0.20412,-2.140817
2,ABY001_A375_XH:BRD-A61304759:10:24,1425,4,1.0,0.472838
3,ABY001_A375_XH:BRD-A61304759:10:3,1425,4,1.0,-2.140817
4,ABY001_A375_XH:BRD-A61304759:2.5:24,1425,4,0.39794,0.472838


In [67]:
_idxs = (~osig_idx.pert_idx.isna()).values.nonzero()[0]
_idxs

array([     0,      1,      2, ..., 509984, 509985, 509986])

In [68]:
f = h5py.File('./data/processed/lincs.h5')
y = torch.tensor(f['data'][_idxs, ...], dtype=torch.float32)
y.size()

torch.Size([509987, 978])

In [69]:
pert_idx = torch.tensor(osig_idx.pert_idx.values[_idxs], dtype=torch.long)
cell_idx = torch.tensor(osig_idx.cell_idx.values[_idxs], dtype=torch.long)
log_conc = torch.tensor(osig_idx.log_conc.values[_idxs], dtype=torch.float32)
z_time = torch.tensor(osig_idx.z_time.values[_idxs], dtype=torch.float32)

In [70]:
model = LincsEmbNN(cell_channels    = 64, 
                   num_lines        = osig_idx.cell_idx.max() + 1, 
                   pert_channels    = 64, 
                   num_perts        = osig_idx.pert_idx.max() + 1, 
                   out_channels     = 978, 
                   num_layers       = 2, 
                   hidden_channels  = 400, 
                   norm             = True, 
                   dropout          = 0.05, 
                   act              = torch.nn.Mish)

In [71]:
device = 'cuda'

model.to(device)

batch_size  = 1024
epochs      = 100
lr          = 1e-2 

optim = torch.optim.Adam(model.parameters(), lr=lr)
crit = torch.nn.MSELoss() 


for epoch in range(epochs): 
    _losses = []
    ii=0
    _r2s = []
    for batch_idx in torch.split(torch.randperm(y.size(0)), batch_size): 

        pert_batch = pert_idx[batch_idx].to(device)
        cell_batch = cell_idx[batch_idx].to(device)
        time_batch = z_time[batch_idx].to(device)
        conc_batch = log_conc[batch_idx].to(device)

        y_batch = y[batch_idx, :].to(device)

        yhat = model(pert_idx=pert_batch, cell_idx=cell_batch, z_time=time_batch, log_conc=conc_batch)

        optim.zero_grad()
        loss = crit(yhat, y_batch)
        loss.backward() 
        optim.step()
        
        _losses.append(loss.item())
        _r2s.append(r2_score(y_batch.detach().cpu().numpy(), yhat.detach().cpu().numpy(), multioutput='uniform_average'))
        print(f'[epoch progress: {ii}/{1+int(y.size(0)/batch_size)}]', end='\r')
        ii+=1

    print(f'epoch: {epoch} || avg. loss: {np.mean(_losses):.4f} || avg. R2: {np.mean(_r2s):.4f}')



epoch: 0 || avg. loss: 0.9021 || avg. R2: -0.0104
epoch: 1 || avg. loss: 0.8926 || avg. R2: 0.0001
epoch: 2 || avg. loss: 0.8881 || avg. R2: 0.0053
epoch: 3 || avg. loss: 0.8824 || avg. R2: 0.0121
epoch: 4 || avg. loss: 0.8777 || avg. R2: 0.0182
epoch: 5 || avg. loss: 0.8728 || avg. R2: 0.0229
epoch: 6 || avg. loss: 0.8679 || avg. R2: 0.0279
epoch: 7 || avg. loss: 0.8641 || avg. R2: 0.0325
epoch: 8 || avg. loss: 0.8602 || avg. R2: 0.0358
epoch: 9 || avg. loss: 0.8566 || avg. R2: 0.0400
epoch: 10 || avg. loss: 0.8545 || avg. R2: 0.0423
epoch: 11 || avg. loss: 0.8508 || avg. R2: 0.0460
epoch: 12 || avg. loss: 0.8486 || avg. R2: 0.0492
epoch: 13 || avg. loss: 0.8461 || avg. R2: 0.0520
epoch: 14 || avg. loss: 0.8438 || avg. R2: 0.0547
epoch: 15 || avg. loss: 0.8409 || avg. R2: 0.0574
epoch: 16 || avg. loss: 0.8390 || avg. R2: 0.0596
epoch: 17 || avg. loss: 0.8369 || avg. R2: 0.0618
epoch: 18 || avg. loss: 0.8356 || avg. R2: 0.0643
epoch: 19 || avg. loss: 0.8339 || avg. R2: 0.0661
epoch: 20

In [72]:
pert2targ = pd.read_csv('./data/processed/pert2targets.tsv', sep='\t')
pert2targ.head()

Unnamed: 0,pert_id,target,pert_idx,DRD2,NR3C1,HTR2A,KDR,PTGS2,PTGS1,HRH1,...,SSTR5,ADIPOR2,SLC6A12,MAPK9,SSTR1,SLC29A1,MKNK1,MGMT,SPHK1,SLC52A2
0,BRD-A00077618,['PRKG1'],0,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
1,BRD-A00100033,[nan],1,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
2,BRD-A00147595,"['PPARG', 'PPARG']",2,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
3,BRD-A00150179,[nan],3,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False
4,BRD-A00218260,[nan],4,False,False,False,False,False,False,False,...,False,False,False,False,False,False,False,False,False,False


In [73]:
targspace = pert2targ.columns[3:]
len(targspace)

606

In [74]:
targspace[0]

'DRD2'

In [75]:
def get_pert_idx_pairs_with_shared_targets(pert2targ): 

    pos_class_idx1 = []
    pos_class_idx2 = []

    targspace = pert2targ.columns[3:]

    for i,t in enumerate(targspace): 
        print(f'progress: {i}/{len(targspace)}', end='\r')
        perts_sharing_targ = pert2targ[t].values.nonzero()[0]
        for j,p1 in enumerate(perts_sharing_targ): 
            for p2 in perts_sharing_targ[(j+1):]: 
                pos_class_idx1.append(p1)
                pos_class_idx2.append(p2)

    pos_class_idx1 = torch.tensor(pos_class_idx1, dtype=torch.long)
    pos_class_idx2 = torch.tensor(pos_class_idx2, dtype=torch.long)

    return pos_class_idx1, pos_class_idx2

pos_class_idx1, pos_class_idx2 = get_pert_idx_pairs_with_shared_targets(pert2targ)

neg_class_idx1 = torch.randint(osig_idx.pert_idx.max() + 1, size=(1000000,))  # grand majority of combinations are negative class ; random pairings should estimate neg
neg_class_idx2 = torch.randint(osig_idx.pert_idx.max() + 1, size=(1000000,))

progress: 605/606

In [77]:

def batched_cosine_similarity(embedding, idx1, idx2, batch_size=1024, normalize=True): 
    ''''''

    if normalize: 
        embedding = copy.deepcopy(embedding)
        embedding.weight.data = embedding.weight.data - embedding.weight.data.mean(dim=0)
        embedding.weight.data = embedding.weight.data / embedding.weight.data.std(dim=0)

    cos_sim = torch.nn.CosineSimilarity(dim=1)
    s = []
    for batch_idx in torch.split(torch.arange(len(idx1)), batch_size):
        with torch.no_grad(): 
            z1 = embedding(idx1[batch_idx])
            z2 = embedding(idx2[batch_idx])
            s.append( cos_sim(z1, z2) )

    return torch.cat(s, dim=-1).mean()

pos_sim = batched_cosine_similarity(model.pert_embedding.cpu(), pos_class_idx1, pos_class_idx2)
neg_sim = batched_cosine_similarity(model.pert_embedding.cpu(), neg_class_idx1, neg_class_idx2)

print(pos_sim)
print(neg_sim)

tensor(0.0118)
tensor(0.0007)


In [59]:
model.pert_embedding.weight.data

tensor([[ 3.5895, -0.3356,  0.4950,  ..., -1.9853, -1.8826, -0.3900],
        [-0.2993, -1.4494,  6.4771,  ...,  4.3259,  3.4183, -3.4906],
        [ 3.9816,  0.4846,  0.8419,  ...,  6.7406,  4.3670, -5.4558],
        ...,
        [-4.4173,  1.0471, -0.6121,  ...,  3.1440,  3.1208,  3.6972],
        [-0.3667,  2.0897,  4.7883,  ...,  1.3937,  2.4289,  1.3639],
        [-4.7490, -1.3679, -1.9048,  ..., -1.7002, -1.8848,  0.3995]])

In [20]:
model.cpu()

x = model.pert_embedding(torch.tensor(pert2targ.pert_idx.values.astype(int), dtype=torch.long)).detach().cpu().numpy()

#scaler = StandardScaler()
#scaler.fit(x)
#x = scaler.transform(x)

print('x shape', x.shape)
out = {'target': [], 'auroc':[], 'null_auroc_95ci':[], 'null_auroc_05ci':[]}

for i,targ in enumerate(pert2targ.columns[3:]): 
    
    y = pert2targ[targ].values*1.

    #print('y shape', y.shape)
    #print('y pos class:', y.mean())

    cv = StratifiedKFold(n_splits=5, shuffle=True)
    dti_model = LogisticRegression(class_weight='balanced', max_iter=1000)#, penalty='l2', solver='newton-cholesky') 
    #dti_model = SVC(kernel='linear', class_weight='balanced', max_iter=1000) 
    scores = cross_val_score(dti_model, x, y, scoring='roc_auc', cv=cv, n_jobs=5)

    # This really doesn't need to be done every time
    # since our performance is relative, doesn't necessarily need to be done at all
    # nice as a check thought...
    rand_scores = []
    for j in range(10): 
        y_rand = np.random.permutation(y)
        rand_scores.append(cross_val_score(dti_model, x, y_rand, scoring='roc_auc', cv=cv, n_jobs=5).mean())
    rand_scores = np.array(rand_scores)

    out['target'].append(targ)
    out['auroc'].append(scores.mean())
    out['null_auroc_05ci'].append(np.quantile(rand_scores, q=0.05))
    out['null_auroc_95ci'].append(np.quantile(rand_scores, q=0.95))

    print(f'[{i}] targ: {targ} \t|| roc: {scores.mean():.4f} \t|| null model roc ci: {np.quantile(rand_scores, q=0.05):.4f}, {np.quantile(rand_scores, q=0.95):.4f}', end='\r')

print()

out = pd.DataFrame(out)
out.head()

x shape (30519, 64)
[49] targ: JAK2 	|| roc: 0.7517 	|| null model roc ci: 0.3743, 0.5592882


Unnamed: 0,target,auroc,null_auroc_95ci,null_auroc_05ci
0,DRD2,0.587785,0.55727,0.418583
1,NR3C1,0.611947,0.566616,0.433207
2,HTR2A,0.610038,0.54826,0.444339
3,KDR,0.585589,0.580446,0.452139
4,PTGS2,0.711895,0.563742,0.405501


In [26]:
res = out[['auroc', 'null_auroc_95ci']].mean().to_frame().T
res = res.assign(q=0)
res.head()

Unnamed: 0,auroc,null_auroc_95ci,q
0,0.595156,0.583986,0
