In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import pickle
import os
from torch.utils.data import Dataset
from sklearn import metrics

In [6]:
def featurize(data):
    """Arrange the Atchley-factor columns of ``data`` as one channel per factor.

    The first column of ``data`` is skipped and zero-variance columns are
    removed. Of what is left, only columns whose name contains ``'AF'`` but
    not ``'by'`` are kept. Those are split into five groups by the substrings
    ``'AF1'`` .. ``'AF5'`` and each group becomes one channel.

    Parameters
    ----------
    data : pandas.DataFrame
        Feature table; columns are selected purely by name substring.

    Returns
    -------
    numpy.ndarray
        For non-empty input, shape ``(n_rows, 5, n_columns_per_group)``.
        All five groups must hold the same number of columns, otherwise
        ``np.stack`` raises.
    """
    features = data.iloc[:, 1:]
    features = features.loc[:, features.var() != 0]

    # Single combined name mask: contains 'AF' and does not contain 'by'.
    names = features.columns
    wanted = names.str.contains('AF') & ~names.str.contains('by')
    factors = features.loc[:, wanted]

    groups = [
        factors.loc[:, factors.columns.str.contains(tag)]
        for tag in ('AF1', 'AF2', 'AF3', 'AF4', 'AF5')
    ]

    # (rows, columns_per_group, 5) -> per-row transpose -> (rows, 5, columns_per_group)
    stacked = np.stack(groups, axis=-1)
    return np.array([np.transpose(sample) for sample in stacked])

In [7]:
class finalnet(nn.Module):
    """1-D CNN over a 5-channel sequence combined with one scalar side input.

    Pipeline: ``Conv1d(5 -> 10, kernel 3)`` -> ReLU -> ``AvgPool1d(3, stride 1)``
    -> flatten -> ``Linear(n_nodes)``. The hidden vector is concatenated with
    the side input ``x2``, passed through ReLU, and mapped to a single output.

    Parameters
    ----------
    n_nodes : int
        Width of the hidden linear layer.
    seq_len : int, optional
        Length of the input sequence axis. The default of 58 reproduces the
        original hard-coded flattened size of 540 (10 channels * (58 - 4)),
        so existing checkpoints load unchanged.

    Raises
    ------
    ValueError
        If ``seq_len`` is too short to survive the convolution and pooling.
    """

    def __init__(self, n_nodes, seq_len=58):
        super(finalnet, self).__init__()

        conv_channels, conv_kernel, pool_kernel = 10, 3, 3
        # Unpadded conv and stride-1 pooling each shorten the sequence by (kernel - 1).
        pooled_len = seq_len - (conv_kernel - 1) - (pool_kernel - 1)
        if pooled_len < 1:
            raise ValueError(
                "seq_len=%r is too short; need at least %d positions"
                % (seq_len, conv_kernel + pool_kernel - 1)
            )

        self.conv = nn.Conv1d(5, conv_channels, conv_kernel)
        self.pool = nn.AvgPool1d(pool_kernel, stride=1)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(conv_channels * pooled_len, n_nodes)
        self.linearout = nn.Linear(n_nodes+1, 1)
        self.relu = nn.ReLU()
        # Not used in forward(); kept so the module's attributes stay unchanged.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, x2):
        """Return one raw (pre-sigmoid) output per sample.

        ``x`` is the ``(batch, 5, seq_len)`` sequence input and ``x2`` the
        ``(batch, 1)`` side input. The ReLU is applied after concatenation,
        so it also clips negative values of ``x2``.
        """
        x = self.relu(self.conv(x))
        x = self.pool(x)
        x = self.flatten(x)
        x = self.relu(torch.cat((x2, self.linear1(x)), 1))
        x = self.linearout(x)
        return x

In [8]:
class pairedDataset(Dataset):
    """Dataset yielding ``(x, x2, y, barcode)`` for each labelled row.

    ``features_file`` and ``scores_file`` are CSVs. Rows with a missing
    ``target`` in the scores file are dropped from both tables. The features
    table is filtered with the mask built from the scores table, so the two
    files are expected to be row-aligned.

    Parameters
    ----------
    features_file : str
        CSV passed through ``featurize`` to build ``x``.
    scores_file : str
        CSV with ``score`` and ``target`` columns, plus ``cell`` or
        ``Unnamed: 0`` as the row identifier.
    use_x2 : bool, optional
        When false, ``x2`` is all zeros instead of the ``score`` column.
    """

    def __init__(self, features_file, scores_file, use_x2=True):
        feature_table = pd.read_csv(features_file)
        score_table = pd.read_csv(scores_file)

        # One mask, taken from the unfiltered scores, applied to both tables.
        has_label = score_table['target'].notna()
        feature_table = feature_table[has_label]
        score_table = score_table[has_label]

        self.x = torch.tensor(featurize(feature_table), dtype=torch.float32)

        # The score column is read even when use_x2 is false (it sets the row count).
        raw_scores = score_table['score'].values
        n_rows = raw_scores.shape[0]
        self.x2 = (
            torch.tensor(raw_scores.reshape(n_rows, 1), dtype=torch.float32)
            if use_x2
            else torch.zeros((n_rows, 1), dtype=torch.float32)
        )

        # Multiplying by 1 turns boolean targets into numbers before the float cast.
        labels = np.array(score_table['target']*1, dtype='float32')
        self.y = torch.tensor(labels[:, np.newaxis], dtype=torch.float32)

        print(score_table.columns)
        id_column = 'cell' if 'cell' in score_table.columns else 'Unnamed: 0'
        self.barcode = score_table[id_column].values

    def __len__(self):
        """Number of labelled rows."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(x, x2, y, barcode)`` tuple for row ``idx``."""
        return self.x[idx], self.x2[idx], self.y[idx], self.barcode[idx]

In [9]:
def score_tcrs(model, loader, batch_size=None):
    """Run ``model`` over every batch of ``loader`` and tabulate the outputs.

    Parameters
    ----------
    model : callable
        Called as ``model(x, x2)``; must return one value per sample.
        The model's train/eval mode is not changed here.
    loader : iterable
        Yields ``(x, x2, y, barcode)`` batches, e.g. a ``DataLoader`` over
        ``pairedDataset``.
    batch_size : int, optional
        Kept for backward compatibility and ignored. Batch sizes are taken
        from the batches themselves, so a value that disagrees with the
        loader can no longer cause a shape error.

    Returns
    -------
    pandas.DataFrame
        Columns ``cell`` (barcode), ``cnn_score`` (raw model output) and
        ``y`` (label), one row per sample in loader order.
    """
    score_parts = []
    label_parts = []
    cell = []

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for x, x2, y, barcode in loader:
            out = model(x, x2)
            label_parts.append(y.detach().cpu().numpy().flatten())
            score_parts.append(out.detach().cpu().numpy().flatten())
            # Numeric barcodes are collated into a tensor; store plain Python
            # values so the 'cell' column does not hold 0-d tensors.
            if torch.is_tensor(barcode):
                cell.extend(barcode.tolist())
            else:
                cell.extend(barcode)

    # float64 matches the dtype of the original preallocated buffers.
    if score_parts:
        outputs = np.concatenate(score_parts).astype(np.float64)
        testy = np.concatenate(label_parts).astype(np.float64)
    else:
        outputs = np.zeros(0)
        testy = np.zeros(0)

    res = pd.DataFrame({'cell': cell, 'cnn_score': outputs, 'y': testy})
    return res

In [10]:
# Target-1 checkpoint: two objects are unpickled in order; the second is kept as
# `dict1` and later passed to model1.load_state_dict.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open('data/mycnn091424_final/cnn091424_target1_batch64_lr0.0003_nnodes10.pkl', 'rb') as f:
    res, dict1 = (pickle.load(f) for _ in range(2))

In [11]:
# Target-2 checkpoint: two objects are unpickled in order; the second is kept as
# `dict2` and later passed to model2.load_state_dict.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open('data/mycnn091424_final/cnn091424_target2_batch64_lr0.0003_nnodes10.pkl', 'rb') as f:
    res, dict2 = (pickle.load(f) for _ in range(2))

In [12]:
# Target-3 checkpoint: two objects are unpickled in order; the second is kept as
# `dict3` and later passed to model3.load_state_dict.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open('data/mycnn091424_final/cnn091424_target3_batch64_lr0.0003_nnodes10.pkl', 'rb') as f:
    res, dict3 = (pickle.load(f) for _ in range(2))

In [13]:
# Target-4 checkpoint: two objects are unpickled in order; the second is kept as
# `dict4` and later passed to model4.load_state_dict.
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open('data/mycnn091424_final/cnn091424_target4_batch64_lr0.0003_nnodes10.pkl', 'rb') as f:
    res, dict4 = (pickle.load(f) for _ in range(2))

In [14]:
# One network per target, all with a 10-node hidden layer, each restored from
# its own unpickled state dict. The last call's result is the cell output.
model1, model2, model3, model4 = (finalnet(n_nodes=10) for _ in range(4))
model1.load_state_dict(dict1)
model2.load_state_dict(dict2)
model3.load_state_dict(dict3)
model4.load_state_dict(dict4)

<All keys matched successfully>

In [22]:
# Target-1 inputs: the features CSV plus the matching score/target file.
scores_file = 'stephenson_x2files/target1_x2file.csv'
sdata1 = pairedDataset(
    features_file='data/stephenson_TCRs_ftzdM_scaled_nocex_061224.csv',
    scores_file=scores_file,
    use_x2=True,
)

Index(['Unnamed: 0', 'cell', 'score', 'target'], dtype='object')


In [23]:
# Score every target-1 row in file order (no shuffling) and report the ROC AUC
# of the raw model outputs against the labels.
loader1 = torch.utils.data.DataLoader(sdata1, batch_size=256, shuffle=False)
steph_scores1 = score_tcrs(model1, loader1, loader1.batch_size)
metrics.roc_auc_score(y_true=steph_scores1['y'], y_score=steph_scores1['cnn_score'])


0.836487692980701

In [26]:
# Target-2 inputs.
# NOTE(review): this features path has no 'data/' prefix, unlike the target-1
# cell — confirm which location is intended.
scores_file = 'stephenson_x2files/target2_x2file.csv'
sdata2 = pairedDataset(
    features_file='stephenson_TCRs_ftzdM_scaled_nocex_061224.csv',
    scores_file=scores_file,
    use_x2=True,
)

Index(['Unnamed: 0', 'cell', 'score', 'target'], dtype='object')


In [27]:
# Score every target-2 row in file order (no shuffling) and report the ROC AUC
# of the raw model outputs against the labels.
loader2 = torch.utils.data.DataLoader(sdata2, batch_size=256, shuffle=False)
steph_scores2 = score_tcrs(model2, loader2, loader2.batch_size)
metrics.roc_auc_score(y_true=steph_scores2['y'], y_score=steph_scores2['cnn_score'])


0.7628802552810205

In [29]:
# Target-3 inputs.
# NOTE(review): this features path has no 'data/' prefix, unlike the target-1
# cell — confirm which location is intended.
scores_file = 'stephenson_x2files/target3_x2file.csv'
sdata3 = pairedDataset(
    features_file='stephenson_TCRs_ftzdM_scaled_nocex_061224.csv',
    scores_file=scores_file,
    use_x2=True,
)

Index(['Unnamed: 0', 'cell', 'score', 'target'], dtype='object')


In [30]:
# Score every target-3 row in file order (no shuffling) and report the ROC AUC.
# NOTE(review): the recorded AUC for this target (~0.56) is close to chance.
loader3 = torch.utils.data.DataLoader(sdata3, batch_size=256, shuffle=False)
steph_scores3 = score_tcrs(model3, loader3, loader3.batch_size)
metrics.roc_auc_score(y_true=steph_scores3['y'], y_score=steph_scores3['cnn_score'])


0.5610634373947461

In [31]:
# Target-4 inputs.
# NOTE(review): this features path has no 'data/' prefix, unlike the target-1
# cell — confirm which location is intended.
scores_file = 'stephenson_x2files/target4_x2file.csv'
sdata4 = pairedDataset(
    features_file='stephenson_TCRs_ftzdM_scaled_nocex_061224.csv',
    scores_file=scores_file,
    use_x2=True,
)

Index(['Unnamed: 0', 'cell', 'score', 'target'], dtype='object')


In [32]:
# Score every target-4 row in file order (no shuffling) and report the ROC AUC.
# NOTE(review): the recorded AUC for this target (~0.53) is close to chance.
loader4 = torch.utils.data.DataLoader(sdata4, batch_size=256, shuffle=False)
steph_scores4 = score_tcrs(model4, loader4, loader4.batch_size)
metrics.roc_auc_score(y_true=steph_scores4['y'], y_score=steph_scores4['cnn_score'])


0.5333481349746946

In [37]:
# Write each per-target score table (cell, cnn_score, y) to CSV. The DataFrame
# index is written as well, since to_csv is called with its defaults.
steph_scores1.to_csv(path_or_buf="data/stephenson_CNNtarget1scores_091524.csv")
steph_scores2.to_csv(path_or_buf="data/stephenson_CNNtarget2scores_091524.csv")
steph_scores3.to_csv(path_or_buf="data/stephenson_CNNtarget3scores_091524.csv")
steph_scores4.to_csv(path_or_buf="data/stephenson_CNNtarget4scores_091524.csv")