In [88]:
import os
import tqdm
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve, auc
from scipy import interp
from itertools import cycle

In [78]:
def encode_seq(seq, factors=None):
    """Encode a (padded) sequence string as one flat row of per-residue factors.

    Parameters
    ----------
    seq : str
        Sequence to encode. ``"*"`` is treated as padding and contributes an
        all-zero factor vector; any other character is looked up in the index
        of the factor table.
    factors : pandas.DataFrame, optional
        Table indexed by residue letter with one column per factor. When
        omitted, the module-level ``af`` table is used, so existing
        ``encode_seq(seq)`` calls behave exactly as before.

    Returns
    -------
    numpy.ndarray or float
        Array of shape ``(1, len(seq) * n_factors)`` laid out residue by
        residue (all factors of residue 0, then all factors of residue 1, ...).
        ``np.nan`` is returned as soon as a ``"_"`` character is reached.

    Raises
    ------
    KeyError
        If a character other than ``"*"``/``"_"`` is missing from the table.
    """
    table = af if factors is None else factors
    # Derive the vector width from the table instead of hard-coding 5.
    n_factors = table.shape[1]

    rows = []
    for residue in seq:
        if residue == "_":
            # Sequences containing "_" are not encodable; signal with NaN.
            return np.nan
        if residue == "*":
            rows.append(np.zeros(n_factors))
        else:
            rows.append(table.loc[residue].values)
    return np.array(rows).reshape(1, -1)

# Factor lookup table used by encode_seq: the "Amino acid" column becomes the
# row index, leaving only the factor columns as data.
af = pd.read_csv(
    "~/data/project/pMHC-TCR/library/Atchley_factors.csv"
).set_index("Amino acid")

In [86]:
class TCREncodeData(Dataset):
    """Dataset of factor-encoded A/B CDR3 pairs taken from the positive samples.

    Construction pipeline:

    1. read ``file_path`` as CSV (first column is the index);
    2. keep rows whose ``Class`` is ``"positive"`` and drop repeated
       (``AseqCDR3``, ``BseqCDR3``) pairs, keeping the first occurrence;
    3. keep rows where both CDR3 strings are shorter than 40 characters;
    4. right-pad every sequence with ``"*"`` to the longest one of its chain
       and encode it with ``encode_seq``;
    5. drop every row that contains a NaN in any column.

    Attributes
    ----------
    X : torch.Tensor
        One row per sample: the encoded A chain followed by the encoded B chain.
    y : torch.Tensor
        Float labels, 1 where ``Class`` is ``"positive"`` and 0 otherwise.
    Aseq_len, Bseq_len
        Padded sequence length (in residues) of each chain.
    """

    def __init__(self, file_path):
        chains = ["AseqCDR3", "BseqCDR3"]

        table = pd.read_csv(file_path, index_col=0)
        # Only the positive samples are used.
        table = table[table["Class"] == "positive"]
        table = table.drop_duplicates(subset=chains, keep="first")

        # Both chains must be shorter than 40 characters.
        for chain in chains:
            table = table.loc[table[chain].str.len() < 40, :]

        # Longest remaining sequence per chain defines that chain's padded length.
        len_map = {chain: table[chain].apply(len).max() for chain in chains}
        print(len_map)

        for chain, target_len in len_map.items():
            # Pad with "*" and encode in a single pass; encode_seq yields NaN
            # for sequences it rejects, which dropna() below removes.
            table[chain] = table[chain].apply(
                lambda s: encode_seq(s + "*" * (target_len - len(s)))
            )

        table = table.dropna()
        print(table.shape)

        # Start from an empty (n_rows, 0) tensor and append each chain's block.
        pieces = [torch.zeros((len(table), 0))]
        pieces.extend(
            torch.from_numpy(np.vstack(table[chain].values)) for chain in chains
        )
        features = torch.cat(pieces, dim=1)

        labels = table["Class"].eq("positive").astype("int64").values

        self.X = features
        self.y = torch.from_numpy(labels).float()
        self.Aseq_len = len_map["AseqCDR3"]
        self.Bseq_len = len_map["BseqCDR3"]

    def __len__(self):
        """Number of samples (one per label)."""
        return len(self.y)

    def __getitem__(self, idx):
        """Return the ``(features, label)`` pair stored at ``idx``."""
        return self.X[idx], self.y[idx]

In [87]:
# Build the encoded dataset from the sequence table. The constructor prints the
# per-chain padded lengths and the shape of the filtered frame.
file_path = "~/data/project/data/seqData/20230228.csv"
TCRData = TCREncodeData(file_path)

{'AseqCDR3': 25, 'BseqCDR3': 21}
(787, 5)


# TCR autoencoder discrete

In [83]:
class TCR_a_autoencoder(nn.Module):
    """1-D convolutional autoencoder for 5-channel inputs of length 25.

    With the defaults (``kernel_size=3, stride=2, padding=1``) the encoder
    maps ``(batch, 5, 25)`` to a ``(batch, 10, 1)`` code and the decoder maps
    it back to ``(batch, 5, 25)``. Every convolution, including the last
    decoder layer, is followed by ``LeakyReLU``.

    ``kernel_size`` only affects the first two encoder layers; ``stride`` and
    ``padding`` only affect the encoder (the decoder uses fixed values).
    ``batch_size`` is stored on the instance but not used by the layers.
    """

    def __init__(self, kernel_size=3, stride=2, padding=1, batch_size=16):
        super().__init__()
        self.batch_size = batch_size

        # Conv1d output length: floor((W + 2P - K) / S) + 1
        encoder_plan = [
            (5, 7, kernel_size),   # (batch, 5, 25) -> (batch, 7, 13)
            (7, 8, kernel_size),   # -> (batch, 8, 7)
            (8, 9, 5),             # -> (batch, 9, 3)
            (9, 10, 5),            # -> (batch, 10, 1)
        ]
        encoder_layers = []
        for c_in, c_out, k in encoder_plan:
            encoder_layers.append(
                nn.Conv1d(c_in, c_out, kernel_size=k, stride=stride, padding=padding)
            )
            encoder_layers.append(nn.LeakyReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # ConvTranspose1d output length: (W - 1) * S - 2P + K
        decoder_plan = [
            (10, 9, 5),   # (batch, 10, 1) -> (batch, 9, 3)
            (9, 8, 5),    # -> (batch, 8, 7)
            (8, 7, 3),    # -> (batch, 7, 13)
            (7, 5, 3),    # -> (batch, 5, 25)
        ]
        decoder_layers = []
        for c_in, c_out, k in decoder_plan:
            decoder_layers.append(
                nn.ConvTranspose1d(c_in, c_out, kernel_size=k, stride=2, padding=1)
            )
            decoder_layers.append(nn.LeakyReLU())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, input):
        """Return ``(encoded, output)`` for a ``(batch, 5, length)`` tensor.

        The input is cast to float32 before encoding.
        """
        latent = self.encoder(input.float()).float()
        return latent, self.decoder(latent)

In [84]:
class TCR_b_autoencoder(nn.Module):
    """1-D convolutional autoencoder for 5-channel inputs of length 21.

    With the defaults (``kernel_size=3, stride=2, padding=1``) the encoder
    maps ``(batch, 5, 21)`` to a ``(batch, 10, 1)`` code and the decoder maps
    it back to ``(batch, 5, 21)``. Every convolution, including the last
    decoder layer, is followed by ``LeakyReLU``.

    ``kernel_size`` only affects the first three encoder layers; ``stride``
    and ``padding`` are shared by all encoder and decoder layers.
    ``batch_size`` is stored on the instance but not used by the layers.
    """

    def __init__(self, kernel_size=3, stride=2, padding=1, batch_size=16):
        super().__init__()
        self.batch_size = batch_size

        # Conv1d output length: floor((W + 2P - K) / S) + 1
        encoder_plan = [
            (5, 7, kernel_size),   # (batch, 5, 21) -> (batch, 7, 11)
            (7, 8, kernel_size),   # -> (batch, 8, 6)
            (8, 9, kernel_size),   # -> (batch, 9, 3)
            (9, 10, 5),            # -> (batch, 10, 1)
        ]
        encoder_layers = []
        for c_in, c_out, k in encoder_plan:
            encoder_layers.append(
                nn.Conv1d(c_in, c_out, kernel_size=k, stride=stride, padding=padding)
            )
            encoder_layers.append(nn.LeakyReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # ConvTranspose1d output length: (W - 1) * S - 2P + K
        decoder_plan = [
            (10, 9, 5),   # (batch, 10, 1) -> (batch, 9, 3)
            (9, 8, 3),    # -> (batch, 8, 5)
            (8, 7, 5),    # -> (batch, 7, 11)
            (7, 5, 3),    # -> (batch, 5, 21)
        ]
        decoder_layers = []
        for c_in, c_out, k in decoder_plan:
            decoder_layers.append(
                nn.ConvTranspose1d(c_in, c_out, kernel_size=k, stride=stride, padding=padding)
            )
            decoder_layers.append(nn.LeakyReLU())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, input):
        """Return ``(encoded, output)`` for a ``(batch, 5, length)`` tensor.

        The input is cast to float32 before encoding.
        """
        latent = self.encoder(input.float()).float()
        return latent, self.decoder(latent)

In [None]:
def train_ae(model, train_loader, optimizer, criterion, epoch, seq_len):
    """Train ``model`` for one epoch and return the mean loss per batch.

    Parameters
    ----------
    model : torch.nn.Module
        Autoencoder whose forward pass returns ``(encoded, output)``.
    train_loader : torch.utils.data.DataLoader
        Yields ``(data, label)`` pairs; the label is ignored. Every sample
        must hold exactly ``5 * seq_len`` values, otherwise the reshape below
        raises.
    optimizer : torch.optim.Optimizer
        Optimizer over the parameters of ``model``.
    criterion : callable
        Reconstruction loss, called as ``criterion(output, data)``.
    epoch : int
        Epoch number, used only in the progress message.
    seq_len : int
        Length of the sequence axis after reshaping.

    Returns
    -------
    float
        Sum of the per-batch losses divided by the number of batches, or
        ``nan`` when the loader yields no batch.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.float()
        # Use the actual batch dimension rather than a global batch size so a
        # shorter final batch (drop_last=False) is reshaped correctly.
        # NOTE(review): encode_seq stores the 5 factors of each residue
        # contiguously, so this view does not place factors on the channel
        # axis; confirm whether view(n, seq_len, 5).transpose(1, 2) was
        # intended (any inference code must use the same layout).
        data = data.view(data.size(0), 5, seq_len)
        optimizer.zero_grad()
        _, output = model(data)
        loss = criterion(output, data)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, batch_idx * len(data), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item()))

    if n_batches == 0:
        # Nothing was trained on; avoid reporting a misleading 0.0 loss.
        return float("nan")
    # Average over batches (the loss was previously divided by the size of
    # the last batch, which is unrelated to how many losses were summed).
    return total_loss / n_batches

# Training configuration.
epochs = 100
batch_size = 16
learning_rate = 1e-3
kernel_size = 3
stride = 2
padding = 1
seq_len_a = TCRData.Aseq_len
seq_len_b = TCRData.Bseq_len
# Seed for the train/test partition so the split is identical on every
# Restart & Run All.
split_seed = 0

# train the autoencoder
model = TCR_a_autoencoder(kernel_size=kernel_size, stride=stride, padding=padding, batch_size=batch_size)
criterion = nn.MSELoss()
# random_split already returns Subset objects over TCRData, so no re-wrapping
# is needed; the dedicated generator leaves the global RNG state untouched.
train_data, test_data = torch.utils.data.random_split(
    TCRData,
    lengths=[0.8, 0.2],
    generator=torch.Generator().manual_seed(split_seed),
)

# drop_last=True: every batch handed to the training loop has batch_size rows.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, drop_last=True)

# plot the loss 
fig, ax = plt.subplots(1, 1, figsize=(5,5))

TCR_encode_losses = []
TCR_accuracy = 0
for epoch in range(1, epochs+1):
    TCR_encode_loss = train_autoencoder