In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.stats as ss
import seaborn as sns
sc.settings.set_figure_params(dpi=100)
print(sc.__version__)

### Read in the Data

In [None]:
import pickle as pkl
# read in the aggregated values
# NOTE(review): pickle.load executes arbitrary code embedded in the file; only run this on
# files from a trusted source
with open('../external_data/db.ags.pkl', 'rb') as f: ags = pkl.load(f)
with open('../external_data/db.tras.pkl', 'rb') as f: tras = pkl.load(f)
with open('../external_data/db.trbs.pkl', 'rb') as f: trbs = pkl.load(f)
with open('../external_data/db.paired_tcrs.pkl', 'rb') as f: paired_tcrs = pkl.load(f)
# wrap the three sequence collections as Series (paired_tcrs is left as loaded) so that
# later cells can use .map / .iloc / .apply on them
ags, tras, trbs = pd.Series(ags), pd.Series(tras), pd.Series(trbs)
# report how many entries each collection holds
print(len(ags)); print(len(tras)); print(len(trbs))

### Setup the Model Core Functions

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

In [None]:
def train(epoch, loss_func):
    """Run one optimisation pass over ``train_loader`` and report the mean loss.

    Relies on the notebook-level globals ``model``, ``optimizer``,
    ``train_loader`` and ``device``.

    Parameters
    ----------
    epoch : int
        Epoch number; only used in the progress message.
    loss_func : int
        Objective selector: 1 -> ``loss_function1``, 2 -> ``loss_function2``,
        3 -> ``loss_function3``.

    Returns
    -------
    float
        Sum of the batch losses divided by ``len(train_loader.dataset)``.

    Raises
    ------
    ValueError
        If ``loss_func`` is not 1, 2 or 3.
    """
    # resolve the objective once, up front, so an invalid selector fails loudly
    # instead of surfacing as a NameError on `loss` in the middle of an epoch
    loss_functions = {1: loss_function1, 2: loss_function2, 3: loss_function3}
    if loss_func not in loss_functions:
        raise ValueError('loss_func must be 1, 2 or 3, got {!r}'.format(loss_func))
    compute_loss = loss_functions[loss_func]

    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        # each batch is a one-element sequence holding the input tensor
        data = data[0].to(device)
        optimizer.zero_grad()
        (recon_batch, recon_len), mu, logvar = model(data)
        loss = compute_loss(recon_batch, recon_len, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    # the losses are summed per batch, so normalise by the number of samples
    avg_loss = train_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, avg_loss))
    return avg_loss
    
def test(epoch, loss_func):
    """Evaluate the model on ``test_loader`` without updating any weights.

    Relies on the notebook-level globals ``model``, ``test_loader`` and
    ``device``.

    Parameters
    ----------
    epoch : int
        Accepted for symmetry with ``train``; not used in the computation.
    loss_func : int
        Objective selector: 1 -> ``loss_function1``, 2 -> ``loss_function2``,
        3 -> ``loss_function3``.

    Returns
    -------
    float
        Sum of the batch losses divided by ``len(test_loader.dataset)``.

    Raises
    ------
    ValueError
        If ``loss_func`` is not 1, 2 or 3 (previously this silently reported
        a loss of 0).
    """
    loss_functions = {1: loss_function1, 2: loss_function2, 3: loss_function3}
    if loss_func not in loss_functions:
        raise ValueError('loss_func must be 1, 2 or 3, got {!r}'.format(loss_func))
    compute_loss = loss_functions[loss_func]

    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            # each batch is a one-element sequence holding the input tensor
            data = data[0].to(device)
            (recon_batch, recon_len), mu, logvar = model(data)
            test_loss += compute_loss(recon_batch, recon_len, data, mu, logvar).item()

    # the losses are summed per batch, so normalise by the number of samples
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss

In [None]:
# architecture hyper-parameters shared by the model and the loss functions
init_embed_size = 50-1          # channels fed to the encoder (the last input channel is dropped)
protein_len = 48                # fixed sequence axis length the linear layers are sized for
init_kernel_size = 3
init_cnn_filters = 256
init_kernel_stride = 1
init_kernel_padding = 1
secn_cnn_filters = 256
latent_dim = 32
vocab = ['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y']
# the decoder emits one channel per amino acid so a sequence can be read back out
out_embed_size = len(vocab)
n_nodes_len = 32                # hidden width of the scalar (length) head


class ConvVAE(nn.Module):
    """Convolutional variational autoencoder with an extra scalar output head.

    The encoder is two Conv1d layers followed by two linear heads (mean and
    log-variance). The decoder maps a latent vector back through a linear layer
    and two ConvTranspose1d layers to ``out_embed_size`` sigmoid channels, and
    in parallel predicts one scalar per sample (``fc7`` -> ``fc8``).

    With kernel 3, stride 1 and padding 1 the convolutions keep the sequence
    axis length, so inputs must have ``protein_len`` positions for the
    flattened size to match the linear layers.
    """

    def __init__(self):
        super().__init__()
        # every (transposed) convolution shares the same kernel geometry
        conv_kwargs = dict(
            kernel_size=init_kernel_size, stride=init_kernel_stride, padding=init_kernel_padding,
        )
        flat_size = secn_cnn_filters*protein_len

        # encoder
        self.fc1 = nn.Conv1d(in_channels=init_embed_size, out_channels=init_cnn_filters, **conv_kwargs)
        self.fc2 = nn.Conv1d(in_channels=init_cnn_filters, out_channels=secn_cnn_filters, **conv_kwargs)
        # latent heads: fc31 -> mean, fc32 -> log-variance; fc4 expands z back out
        self.fc31 = nn.Linear(flat_size, latent_dim)
        self.fc32 = nn.Linear(flat_size, latent_dim)
        self.fc4 = nn.Linear(latent_dim, flat_size)
        # decoder
        self.fc5 = nn.ConvTranspose1d(in_channels=secn_cnn_filters, out_channels=init_cnn_filters, **conv_kwargs)
        self.fc6 = nn.ConvTranspose1d(in_channels=init_cnn_filters, out_channels=out_embed_size, **conv_kwargs)
        # scalar head
        self.fc7 = nn.Linear(init_cnn_filters*protein_len, n_nodes_len)
        self.fc8 = nn.Linear(n_nodes_len, 1)

    def encode(self, x):
        """Return ``(mu, logvar)`` for ``x``; the last channel of ``x`` is ignored."""
        hidden = nn.functional.leaky_relu(self.fc1(x[:, :-1, :]))
        hidden = nn.functional.leaky_relu(self.fc2(hidden))
        flat = torch.flatten(hidden, start_dim=1)
        return self.fc31(flat), self.fc32(flat)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` standard normal noise."""
        std = torch.exp(0.5*logvar)
        noise = torch.randn_like(std)
        return mu + noise*std

    def decode(self, z):
        """Map latent ``z`` to ``(per-position channel scores in [0, 1], scalar prediction)``."""
        expanded = nn.functional.leaky_relu(self.fc4(z))
        # view(-1, ...) also gives an un-batched latent vector a batch axis of 1
        feature_map = expanded.view(-1, secn_cnn_filters, protein_len)
        feature_map = nn.functional.leaky_relu(self.fc5(feature_map))
        recon = torch.sigmoid(self.fc6(feature_map))
        scalar_hidden = nn.functional.leaky_relu(self.fc7(torch.flatten(feature_map, start_dim=1)))
        return recon, self.fc8(scalar_hidden)

    def forward(self, x):
        """Encode, sample a latent, decode; returns ``((recon, scalar), mu, logvar)``."""
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar

In [None]:
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function1(recon_x, recon_len, x, mu, logvar):
    """Summed binary cross-entropy on the vocab channels plus the KL term.

    ``recon_len`` is accepted only so all three loss functions share one
    signature; it is not used here.
    """
    # only the first len(vocab) channels of the input are reconstruction targets
    target = x[:, :len(vocab), :]
    bce = nn.functional.binary_cross_entropy(recon_x, target, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld
def loss_function2(recon_x, recon_len, x, mu, logvar):
    """Total squared error between ``recon_len`` and the last channel of ``x``.

    ``recon_x``, ``mu`` and ``logvar`` are unused; they are accepted so all
    three loss functions share one signature.
    """
    # NOTE(review): presumably recon_len is (batch, 1) and broadcasts against every
    # position of the last channel - confirm the channel is constant along the sequence
    residual = recon_len - x[:, -1, :]
    return residual.pow(2).sum()
def loss_function3(recon_x, recon_len, x, mu, logvar):
    """Combined objective: reconstruction BCE + KL term + squared error on the last channel."""
    # reconstruction target is restricted to the vocab channels
    target = x[:, :len(vocab), :]
    bce = nn.functional.binary_cross_entropy(recon_x, target, reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # scalar head compared against the last input channel
    residual = recon_len - x[:, -1, :]
    tse = residual.pow(2).sum()
    return bce + kld + tse

### TRB

In [None]:
import pickle as pkl
# look up the precomputed embedding of every unique TRB and stack them into one float32 tensor
with open('../outs/map.trb_to_embed.extended.pkl', 'rb') as f:
    trb_to_embed = pkl.load(f)
X_trbs = torch.stack([embedding.to(torch.float32) for embedding in trbs.map(trb_to_embed)])
# reproducible 75% / 25% random split into train and test
torch.manual_seed(0)
np.random.seed(0)
idxs_train = np.random.choice(range(len(X_trbs)), size=round(len(X_trbs)*0.75), replace=False)
# the test set is every position that was not drawn for training, in ascending order
idxs_all = np.array(range(len(X_trbs)))
idxs_test = idxs_all[~np.isin(idxs_all, idxs_train)]
X_trbs_train, X_trbs_test = X_trbs[idxs_train], X_trbs[idxs_test]

In [None]:
from torch.utils.data import TensorDataset, DataLoader
# create a latent space for the trbs
batch_size = 2048
train_loader = DataLoader(dataset=TensorDataset(X_trbs_train), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=TensorDataset(X_trbs_test), batch_size=batch_size, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# initialize the model
model = ConvVAE().to(device)

In [None]:
# seed both RNGs so the run is repeatable
torch.manual_seed(0)
np.random.seed(0)
# optimisation settings
lr, epochs = 0.0005, 20
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# every epoch makes one full pass on objective 1 and then one full pass on objective 2;
# each history row is [value for objective 1, value for objective 2]
train_losses, test_losses = [], []
for epoch in range(1, epochs + 1):
    train_losses_, test_losses_ = [], []
    for loss_func in (1, 2):
        train_losses_ += [train(epoch, loss_func)]
        test_losses_ += [test(epoch, loss_func)]
    train_losses += [train_losses_]
    test_losses += [test_losses_]

In [None]:
# examine the KLD and BCE loss
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
ax.plot([x[0] for x in train_losses], color='dodgerblue')
ax.plot([x[0] for x in test_losses], color='skyblue', linestyle='--')
ax.set(xlabel='Epochs', ylabel='KLD + BCE Loss')
# examine the length loss
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
ax.plot([x[1] for x in train_losses], color='dodgerblue')
ax.plot([x[1] for x in test_losses], color='skyblue', linestyle='--')
ax.set(xlabel='Epochs', ylabel='TSE (LEN) Loss')

In [None]:
# seed both RNGs so the run is repeatable
torch.manual_seed(0)
np.random.seed(0)
# optimisation settings; a fresh Adam instance is created, so optimiser state starts over
lr, epochs = 0.001, 40
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# continue training the same model on the single combined objective (3);
# each history row is a one-element list
train_losses, test_losses = [], []
for epoch in range(1, epochs + 1):
    train_losses_, test_losses_ = [], []
    for loss_func in (3,):
        train_losses_ += [train(epoch, loss_func)]
        test_losses_ += [test(epoch, loss_func)]
    train_losses += [train_losses_]
    test_losses += [test_losses_]

In [None]:
# examine the integrated loss from finetuning
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
ax.plot([x[0] for x in train_losses], color='dodgerblue')
ax.plot([x[0] for x in test_losses], color='skyblue', linestyle='--')
ax.set(xlabel='Epochs', ylabel='KLD + BCE + TSE Loss')

In [None]:
# run the test set through the model and keep both outputs, in loader order
# (the loader is not shuffled, so position k corresponds to idxs_test[k]);
# forward() samples the latent, so these reconstructions include sampling noise
recon_batchs, recon_lens = [], []
with torch.no_grad():
    for i, data in enumerate(test_loader):
        data = data[0].to(device)
        (recon_batch, recon_len), _, _ = model(data)
        # one array per sequence
        recon_batchs.extend(recon_batch.cpu().numpy())
        # the scalar head yields one-element rows; keep just the number
        recon_lens.extend(row[0] for row in recon_len.cpu().tolist())
# the matching sequences for the test positions
trbs_test = pd.Series(trbs.iloc[idxs_test])

In [None]:
# compare predicted and true lengths twice: first with the raw regression output, then with
# the prediction rounded to the nearest integer; there is a linear relationship but not an
# exact one
# (plain arrays are used so values are paired by position, never by index label)
true_lens = trbs_test.apply(len).to_numpy()
rounded_lens = pd.Series(recon_lens).apply(round).to_numpy()
for label, pred_lens in [('Non-Rounded', recon_lens), ('Rounded', rounded_lens)]:
    fig, ax = plt.subplots(); ax.grid(False)
    ax.scatter(pred_lens, true_lens, s=10, alpha=0.1, zorder=1, color='dodgerblue')
    # draw the identity line over the whole visible range, then restore the limits it stretched
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    domain = [min(min(xlim), min(ylim)), max(max(xlim), max(ylim))]
    ax.plot(domain, domain, color='skyblue', linestyle='--', zorder=0)
    ax.set_xlim(*xlim); ax.set_ylim(*ylim)
    ax.set(xlabel='Predicted length', ylabel='True length')
    # the correlation is computed on the same values that are plotted (the 'Rounded' figure
    # previously reported the un-rounded correlation a second time)
    print(label, ss.pearsonr(pred_lens, true_lens))

In [None]:
# compute the difference between rounded predicted length and true length per test sequence
# subtract as plain arrays: pd.Series(recon_lens) carries a fresh 0..n-1 index while trbs_test
# keeps its labels from trbs, so Series arithmetic would align on labels and pair predictions
# with the wrong sequences (and produce NaN where labels do not overlap)
len_diff = pd.Series(pd.Series(recon_lens).apply(round).to_numpy() - trbs_test.apply(len).to_numpy())
fig, ax = plt.subplots(figsize=[1, 3]); ax.grid(False)
sns.boxplot(y=len_diff, color='dodgerblue', ax=ax)
ax.set(ylabel='Length difference')
# examine the distribution to be sure
fig, ax = plt.subplots(figsize=[1, 3]); ax.grid(False)
sns.violinplot(y=len_diff, color='dodgerblue', ax=ax)
ax.set(ylabel='Length difference')
# provide statistics
len_diff.describe()

In [None]:
# save the model
torch.save(model.state_dict(), '../models/model.convvae.trb.torch')

In [None]:
# load the model
model.load_state_dict(torch.load('../models/model.convvae.trb.torch', weights_only=True))
model.eval()

In [None]:
from tqdm import tqdm
# get the encoded dimensions
torch.manual_seed(0); np.random.seed(0)
# create a complete loader or else there is an out of memory error
complete_loader = DataLoader(dataset=TensorDataset(X_trbs), batch_size=batch_size, shuffle=False)
# move through each subset in the complete loader
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(complete_loader):
        data = data[0].to(device)
        enc_out = model.encode(data)
        # sampling centers around the mean so we just use mu
        z_dims_per_batch.append(enc_out[0].clone().detach().cpu().numpy())
z_dims = np.vstack(z_dims_per_batch)
del data, enc_out, z_dims_per_batch

In [None]:
from anndata import AnnData
from scipy.sparse import csr_matrix
# wrap the latent means (one row per TRB, one column per latent dimension) in an AnnData
adata = AnnData(z_dims, dtype=float)
# PCA is computed and stored on the object; the neighbour graph below is built with
# use_rep='X', i.e. directly on the latent coordinates rather than on the PCA result
sc.tl.pca(adata)
sc.pp.neighbors(adata, use_rep='X')
# 2-D layout and clustering derived from that neighbour graph
sc.tl.umap(adata)
sc.tl.leiden(adata, flavor='igraph', n_iterations=2)
# label every row with its sequence; z_dims was produced from X_trbs in unshuffled order,
# which was itself built from trbs, so rows and sequences line up
adata.obs_names = trbs

In [None]:
# check length
adata.obs['LEN'] = adata.obs.index.to_series().apply(len)
sc.pl.umap(adata, color=['LEN'], vmin=12, vmax=20)

In [None]:
# save the current data
adata.write('../outs/adata.trb.h5ad')

##### Example sampling

In [None]:
# report statistics, averages
means = adata.X.mean(0); means

In [None]:
# report statistics, standard deviations
stds = adata.X.std(0); stds

In [None]:
# one-hot lookup table: amino acid -> list of 0/1 flags in vocab order
map_direct = {aa: [int(aa == other) for other in vocab] for aa in vocab}
def embed_aa(aa):
    """Return a fresh one-hot list for amino acid ``aa``; raises KeyError if it is not in vocab."""
    return list(map_direct[aa])
# define a function to interpolate the protein
targ_len = 48
def stretch_pep(embedding, targ_len=targ_len):
    """Linearly resample a per-residue embedding to a fixed number of positions.

    Parameters
    ----------
    embedding : numpy.ndarray
        2-D array of shape (orig_len, n_features), one row per residue.
    targ_len : int
        Number of positions in the output (defaults to the value of the
        module-level ``targ_len`` at definition time, 48).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (n_features, targ_len).
    """
    # get the current protein length
    orig_len, n_features = embedding.shape
    # interpolate expects (batch, channels, length): features become channels
    tensor = torch.Tensor(embedding.T.reshape(1, n_features, orig_len))
    # resample along the length axis and drop the batch axis again
    res = torch.nn.functional.interpolate(tensor, size=targ_len, mode='linear', align_corners=False)[0]
    return res

In [None]:
from Levenshtein import distance as levenshtein
# pick a random example to demonstrate forward fluidity
torch.manual_seed(0); np.random.seed(0)
idxs = np.random.choice(range(len(trbs_test)), size=5, replace=False)
for idx in idxs:
    # retrieve the sequence of interest
    recon_batch = recon_batchs[idx]
    recon_len = recon_lens[idx]
    # retrieve the model resolved sequence and truth
    pred = pd.DataFrame(recon_batch, index=vocab)
    true_seq = trbs_test.iloc[idx]
    true = pd.DataFrame(stretch_pep(np.array([embed for embed in map(embed_aa, list(true_seq))])).numpy(), index=vocab)
    # plot the predicted vs. truth
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(pred, xticklabels=0, yticklabels=1, cmap='PuBu')
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(true, xticklabels=0, yticklabels=1, cmap='PuBu')
    # highlight any discrepancies
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(pred-true, xticklabels=0, yticklabels=1, cmap='coolwarm', vmin=-1, vmax=1)
    # resolve the sequence
    curr_len = 48; targ_len = round(recon_len)
    xp = np.arange(curr_len) / (curr_len - 1)
    x = np.arange(targ_len) / (targ_len - 1)
    # interpolate the results
    res = np.array([np.interp(x, xp, recon_batch[idx, :]) for idx in range(recon_batch.shape[0])])
    pred_seq = ''.join(pd.DataFrame(res.T, columns=vocab).idxmax(1))
    print(true_seq, pred_seq, levenshtein(true_seq, pred_seq))

In [None]:
# the latent dimensions are roughly centred on 0 with unit spread, so draw a latent vector
# from a standard normal and decode it
torch.manual_seed(0); np.random.seed(0)
# send the latent to the configured device (falls back to CPU) rather than a hard-coded 'cuda'
z = torch.Tensor([np.random.normal(loc=0, scale=1, size=1)[0] for idx in range(32)]).to(device)
tmp_out, tmp_len = model.decode(z)
tmp_out = tmp_out.clone().detach().cpu().numpy()
tmp_len = tmp_len.clone().detach().cpu().numpy()
# decode adds a batch axis of 1; strip it and round the scalar head to a whole length
tmp_out = tmp_out[0]; tmp_len = round(tmp_len[0][0])

In [None]:
# resolve the sequence with a deduction system
fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
data = pd.DataFrame(tmp_out.T, columns=vocab).T
sns.heatmap(data, xticklabels=0, yticklabels=1, cmap='PuBu')

In [None]:
# give examples if we had randomly chosen a "true sequence"
torch.manual_seed(0); np.random.seed(0)
idxs = np.random.choice(range(len(trbs_test)), size=5, replace=False)
for idx in idxs:
    # retrieve the sequence of interest
    recon_batch = recon_batchs[idx]
    recon_len = recon_lens[idx]
    # retrieve the model resolved sequence and truth
    pred = pd.DataFrame(recon_batch, index=vocab)
    true_seq = np.random.choice(trbs_test, size=1)[0]
    true = pd.DataFrame(stretch_pep(np.array([embed for embed in map(embed_aa, list(true_seq))])).numpy(), index=vocab)
    # plot the predicted vs. truth
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(pred, xticklabels=0, yticklabels=1, cmap='PuBu')
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(true, xticklabels=0, yticklabels=1, cmap='PuBu')
    # highlight any discrepancies
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(pred-true, xticklabels=0, yticklabels=1, cmap='coolwarm', vmin=-1, vmax=1)
    # resolve the sequence
    curr_len = 48; targ_len = round(recon_len)
    xp = np.arange(curr_len) / (curr_len - 1)
    x = np.arange(targ_len) / (targ_len - 1)
    # interpolate the results
    res = np.array([np.interp(x, xp, recon_batch[idx, :]) for idx in range(recon_batch.shape[0])])
    pred_seq = ''.join(pd.DataFrame(res.T, columns=vocab).idxmax(1))
    print(true_seq, pred_seq, levenshtein(true_seq, pred_seq))

In [None]:
# compare each prediction against its true sequence and against randomly drawn test sequences
torch.manual_seed(0); np.random.seed(0)
idxs = np.random.choice(range(len(trbs_test)), size=100, replace=False)
# number of random sequences averaged for the baseline distance
n_rand = 10
# track distances
dist_records = []
for idx in idxs:
    # retrieve the sequence of interest
    recon_batch = recon_batchs[idx]
    recon_len = recon_lens[idx]
    true_seq = trbs_test.iloc[idx]
    # resolve the sequence
    curr_len = 48; targ_len = round(recon_len)
    xp = np.arange(curr_len) / (curr_len - 1)
    x = np.arange(targ_len) / (targ_len - 1)
    # interpolate every channel back to the predicted length and take the best residue per position
    res = np.array([np.interp(x, xp, recon_batch[row, :]) for row in range(recon_batch.shape[0])])
    pred_seq = ''.join(pd.DataFrame(res.T, columns=vocab).idxmax(1))
    dist_from_truth = levenshtein(true_seq, pred_seq)
    # baseline: mean distance to whole random test sequences (np.random.choice on the Series
    # already returns one sequence; indexing it with [0] would keep only its first residue)
    dist_from_rand = 0
    for _ in range(n_rand):
        dist_from_rand += levenshtein(np.random.choice(trbs_test), pred_seq)
    dist_records.append((dist_from_truth, dist_from_rand / n_rand))
# build the frame once instead of growing it row by row
df_dist = pd.DataFrame(dist_records, columns=['dist_to_truth','dist_to_rand'])

In [None]:
# compare prediction accuracy statistically
fig, ax = plt.subplots(figsize=[2, 4]); ax.grid(False)
sns.boxplot(x='variable', y='value', data=df_dist.melt(), ax=ax, saturation=1,
            linewidth=1.5, linecolor='dodgerblue', order=['dist_to_truth'], color='skyblue')
sns.boxplot(x='variable', y='value', data=df_dist.melt(), ax=ax, saturation=1,
            linewidth=1.5, linecolor='grey', order=['dist_to_rand'], color='lightgray')
# create artificial lines
# for idx in df_dist.index: ax.plot([0, 1], df_dist.loc[idx], color='k', alpha=0.1, lw=0.5, zorder=0, linestyle='--')
ax.tick_params(axis='x', labelrotation=90); ax.set_xticklabels(['Truth','Random']); ax.set_xlim(-1, 2)
ax.set(xlabel='Prediction vs.', ylabel='Levenshtein distance')
ss.wilcoxon(df_dist['dist_to_truth'], df_dist['dist_to_rand']),\
ss.mannwhitneyu(df_dist['dist_to_truth'], df_dist['dist_to_rand'])

In [None]:
# interpolate back to a sequence using the length
curr_len = 48; targ_len = tmp_len
# compute the x-coordinates of the original
xp = np.arange(curr_len) / (curr_len - 1)
x = np.arange(targ_len) / (targ_len - 1)
# interpolate the results
res = np.array([np.interp(x, xp, tmp_out[idx, :]) for idx in range(tmp_out.shape[0])])

In [None]:
# resolve the sequence with a deduction system
fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
data = pd.DataFrame(res.T, columns=vocab).T
sns.heatmap(data, xticklabels=0, yticklabels=1, cmap='PuBu')

In [None]:
# reveal the sequence
''.join(data.idxmax(0))

In [None]:
# define an engine to derive sequence
def engine(model, z):
    """Decode a latent vector into an amino-acid string.

    Parameters
    ----------
    model : object exposing ``decode(z)`` that returns ``(scores, length)``
        ``scores`` is indexed as (batch, channel, position) and ``length`` as
        (batch, 1); only the first batch element is used.
    z : torch.Tensor
        Latent input passed straight to ``model.decode``.

    Returns
    -------
    str
        One letter from ``vocab`` per position; the number of positions is the
        rounded length prediction. Empty when that length is below 1.
    """
    # derive the embedding and length (no gradients are needed for generation)
    with torch.no_grad():
        tmp_out, tmp_len = model.decode(z)
    tmp_out = tmp_out.cpu().numpy()[0]
    tmp_len = int(round(float(tmp_len.cpu().numpy()[0][0])))
    if tmp_len < 1:
        # nothing to emit for a non-positive predicted length
        return ''

    # interpolate back to a sequence using the length
    curr_len = tmp_out.shape[1]
    # relative coordinates in [0, 1] of the decoder positions ...
    xp = np.arange(curr_len) / (curr_len - 1)
    # ... and of the requested positions; a single position maps to the start,
    # which avoids the 0/0 the general formula gives for a length of 1
    x = np.arange(tmp_len) / (tmp_len - 1) if tmp_len > 1 else np.zeros(1)
    # interpolate the results
    res = np.array([np.interp(x, xp, channel) for channel in tmp_out])

    # derive the sequence: highest-scoring amino acid at every position
    data = pd.DataFrame(res.T, columns=vocab).T
    return ''.join(data.idxmax(0))

In [None]:
# provide more examples: sample every latent dimension around its observed mean and spread
torch.manual_seed(0); np.random.seed(0)
for _ in range(100):
    z_in = [np.random.normal(loc=means[idx], scale=stds[idx], size=1)[0] for idx in range(32)]
    # use the configured device (falls back to CPU) rather than a hard-coded 'cuda'
    z = torch.Tensor(z_in).to(device)
    print(engine(model, z))

### TRA

In [None]:
import pickle as pkl
# embed all of our unique TRAs
with open('../outs/map.tra_to_embed.extended.pkl', 'rb') as f: tra_to_embed = pkl.load(f)
X_tras = torch.stack([x.to(torch.float32) for x in tras.map(tra_to_embed)])
# randomly split the data into train and test
torch.manual_seed(0); np.random.seed(0)
idxs_train = np.random.choice(range(len(X_tras)), size=round(len(X_tras)*0.75), replace=False)
idxs_test = np.array(range(len(X_tras)))
idxs_test = idxs_test[~np.isin(idxs_test, idxs_train)]
X_tras_train = X_tras[idxs_train]
X_tras_test = X_tras[idxs_test]

In [None]:
from torch.utils.data import TensorDataset, DataLoader
# create a latent space for the tras
batch_size = 2048
train_loader = DataLoader(dataset=TensorDataset(X_tras_train), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=TensorDataset(X_tras_test), batch_size=batch_size, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# initialize the model
model = ConvVAE().to(device)

In [None]:
# set the seed for training
torch.manual_seed(0); np.random.seed(0)
# set the learning parameters
lr = 0.0005; epochs = 20
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# train the model with dual losses to balance between two objectives
train_losses = []; test_losses = []
for epoch in range(1, epochs + 1):
    train_losses_, test_losses_ = [], []
    for loss_func in [1, 2]:
        train_losses_.append(train(epoch, loss_func))
        test_losses_.append(test(epoch, loss_func))
    train_losses.append(train_losses_)
    test_losses.append(test_losses_)

In [None]:
# examine the KLD and BCE loss
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
ax.plot([x[0] for x in train_losses], color='dodgerblue')
ax.plot([x[0] for x in test_losses], color='skyblue', linestyle='--')
ax.set(xlabel='Epochs', ylabel='KLD + BCE Loss')
# examine the length loss
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
ax.plot([x[1] for x in train_losses], color='dodgerblue')
ax.plot([x[1] for x in test_losses], color='skyblue', linestyle='--')
ax.set(xlabel='Epochs', ylabel='TSE (LEN) Loss')

In [None]:
# set the seed for training
torch.manual_seed(0); np.random.seed(0)
# set the learning parameters
lr = 0.001; epochs = 40
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# train the model with an integrated loss
train_losses = []; test_losses = []; epochs
for epoch in range(1, epochs + 1):
    train_losses_, test_losses_ = [], []
    for loss_func in [3]:
        train_losses_.append(train(epoch, loss_func))
        test_losses_.append(test(epoch, loss_func))
    train_losses.append(train_losses_)
    test_losses.append(test_losses_)

In [None]:
# examine the integrated loss from finetuning
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
ax.plot([x[0] for x in train_losses], color='dodgerblue')
ax.plot([x[0] for x in test_losses], color='skyblue', linestyle='--')
ax.set(xlabel='Epochs', ylabel='KLD + BCE + TSE Loss')

In [None]:
# retrieve the predictions
recon_lens = []; recon_batchs = []
with torch.no_grad():
    for i, data in enumerate(test_loader):
        data = data[0].to(device)
        (recon_batch, recon_len), _, _ = model(data)
        recon_batchs.extend(recon_batch.clone().detach().cpu().numpy())
        recon_lens.extend(recon_len.clone().detach().cpu().tolist())
recon_lens = [x[0] for x in recon_lens]
# retrieve the indices
tras_test = pd.Series(tras.iloc[idxs_test])

In [None]:
# compare predicted and true lengths twice: first with the raw regression output, then with
# the prediction rounded to the nearest integer; there is a linear relationship but not an
# exact one
# (plain arrays are used so values are paired by position, never by index label)
true_lens = tras_test.apply(len).to_numpy()
rounded_lens = pd.Series(recon_lens).apply(round).to_numpy()
for label, pred_lens in [('Non-Rounded', recon_lens), ('Rounded', rounded_lens)]:
    fig, ax = plt.subplots(); ax.grid(False)
    ax.scatter(pred_lens, true_lens, s=10, alpha=0.1, zorder=1, color='dodgerblue')
    # draw the identity line over the whole visible range, then restore the limits it stretched
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    domain = [min(min(xlim), min(ylim)), max(max(xlim), max(ylim))]
    ax.plot(domain, domain, color='skyblue', linestyle='--', zorder=0)
    ax.set_xlim(*xlim); ax.set_ylim(*ylim)
    ax.set(xlabel='Predicted length', ylabel='True length')
    # the correlation is computed on the same values that are plotted (the 'Rounded' figure
    # previously reported the un-rounded correlation a second time)
    print(label, ss.pearsonr(pred_lens, true_lens))

In [None]:
# compute the difference between rounded predicted length and true length per test sequence
# subtract as plain arrays: pd.Series(recon_lens) carries a fresh 0..n-1 index while tras_test
# keeps its labels from tras, so Series arithmetic would align on labels and pair predictions
# with the wrong sequences (and produce NaN where labels do not overlap)
len_diff = pd.Series(pd.Series(recon_lens).apply(round).to_numpy() - tras_test.apply(len).to_numpy())
fig, ax = plt.subplots(figsize=[1, 3]); ax.grid(False)
sns.boxplot(y=len_diff, color='dodgerblue', ax=ax)
ax.set(ylabel='Length difference')
# examine the distribution to be sure
fig, ax = plt.subplots(figsize=[1, 3]); ax.grid(False)
sns.violinplot(y=len_diff, color='dodgerblue', ax=ax)
ax.set(ylabel='Length difference')
# provide statistics
len_diff.describe()

In [None]:
# save the model
torch.save(model.state_dict(), '../models/model.convvae.tra.torch')

In [None]:
# load the model
model.load_state_dict(torch.load('../models/model.convvae.tra.torch', weights_only=True))
model.eval()

In [None]:
# get the encoded dimensions
torch.manual_seed(0); np.random.seed(0)
# create a complete loader or else there is an out of memory error
complete_loader = DataLoader(dataset=TensorDataset(X_tras), batch_size=batch_size, shuffle=False)
# move through each subset in the complete loader
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(complete_loader):
        data = data[0].to(device)
        enc_out = model.encode(data)
        # sampling centers around the mean so we just use mu
        z_dims_per_batch.append(enc_out[0].clone().detach().cpu().numpy())
z_dims = np.vstack(z_dims_per_batch)
del data, enc_out, z_dims_per_batch

In [None]:
from anndata import AnnData
from scipy.sparse import csr_matrix
# wrap the latent means (one row per TRA, one column per latent dimension) in an AnnData
adata = AnnData(z_dims, dtype=float)
# PCA is computed and stored on the object; the neighbour graph below is built with
# use_rep='X', i.e. directly on the latent coordinates rather than on the PCA result
sc.tl.pca(adata)
sc.pp.neighbors(adata, use_rep='X')
# 2-D layout and clustering derived from that neighbour graph
sc.tl.umap(adata)
# (the closing parenthesis was missing here, which made the whole cell a SyntaxError)
sc.tl.leiden(adata, flavor='igraph', n_iterations=2)
# label every row with its sequence
adata.obs_names = tras

In [None]:
# check length
adata.obs['LEN'] = adata.obs.index.to_series().apply(len)
sc.pl.umap(adata, color=['LEN'], vmin=12, vmax=20)

In [None]:
# save the current data
adata.write('../outs/adata.tra.h5ad')

##### Example sampling

In [None]:
# report statistics, averages
means = adata.X.mean(0); means

In [None]:
# report statistics, standard deviations
stds = adata.X.std(0); stds

In [None]:
# one-hot lookup table: amino acid -> list of 0/1 flags in vocab order
map_direct = {aa: [int(aa == other) for other in vocab] for aa in vocab}
def embed_aa(aa):
    """Return a fresh one-hot list for amino acid ``aa``; raises KeyError if it is not in vocab."""
    return list(map_direct[aa])
# define a function to interpolate the protein
targ_len = 48
def stretch_pep(embedding, targ_len=targ_len):
    """Linearly resample a per-residue embedding to a fixed number of positions.

    Parameters
    ----------
    embedding : numpy.ndarray
        2-D array of shape (orig_len, n_features), one row per residue.
    targ_len : int
        Number of positions in the output (defaults to the value of the
        module-level ``targ_len`` at definition time, 48).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (n_features, targ_len).
    """
    # get the current protein length
    orig_len, n_features = embedding.shape
    # interpolate expects (batch, channels, length): features become channels
    tensor = torch.Tensor(embedding.T.reshape(1, n_features, orig_len))
    # resample along the length axis and drop the batch axis again
    res = torch.nn.functional.interpolate(tensor, size=targ_len, mode='linear', align_corners=False)[0]
    return res

In [None]:
from Levenshtein import distance as levenshtein
# pick a random example to demonstrate forward fluidity
torch.manual_seed(0); np.random.seed(0)
idxs = np.random.choice(range(len(tras_test)), size=5, replace=False)
for idx in idxs:
    # retrieve the sequence of interest
    recon_batch = recon_batchs[idx]
    recon_len = recon_lens[idx]
    # retrieve the model resolved sequence and truth
    pred = pd.DataFrame(recon_batch, index=vocab)
    true_seq = tras_test.iloc[idx]
    true = pd.DataFrame(stretch_pep(np.array([embed for embed in map(embed_aa, list(true_seq))])).numpy(), index=vocab)
    # plot the predicted vs. truth
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(pred, xticklabels=0, yticklabels=1, cmap='PuBu')
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(true, xticklabels=0, yticklabels=1, cmap='PuBu')
    # highlight any discrepancies
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(pred-true, xticklabels=0, yticklabels=1, cmap='coolwarm', vmin=-1, vmax=1)
    # resolve the sequence
    curr_len = 48; targ_len = round(recon_len)
    xp = np.arange(curr_len) / (curr_len - 1)
    x = np.arange(targ_len) / (targ_len - 1)
    # interpolate the results
    res = np.array([np.interp(x, xp, recon_batch[idx, :]) for idx in range(recon_batch.shape[0])])
    pred_seq = ''.join(pd.DataFrame(res.T, columns=vocab).idxmax(1))
    print(true_seq, pred_seq, levenshtein(true_seq, pred_seq))

In [None]:
# the latent dimensions are roughly centred on 0 with unit spread, so draw a latent vector
# from a standard normal and decode it
torch.manual_seed(0); np.random.seed(0)
# send the latent to the configured device (falls back to CPU) rather than a hard-coded 'cuda'
z = torch.Tensor([np.random.normal(loc=0, scale=1, size=1)[0] for idx in range(32)]).to(device)
tmp_out, tmp_len = model.decode(z)
tmp_out = tmp_out.clone().detach().cpu().numpy()
tmp_len = tmp_len.clone().detach().cpu().numpy()
# decode adds a batch axis of 1; strip it and round the scalar head to a whole length
tmp_out = tmp_out[0]; tmp_len = round(tmp_len[0][0])

In [None]:
# resolve the sequence with a deduction system
fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
data = pd.DataFrame(tmp_out.T, columns=vocab).T
sns.heatmap(data, xticklabels=0, yticklabels=1, cmap='PuBu')

In [None]:
# give examples if we had randomly chosen a "true sequence"
torch.manual_seed(0); np.random.seed(0)
idxs = np.random.choice(range(len(tras_test)), size=5, replace=False)
for idx in idxs:
    # retrieve the sequence of interest
    recon_batch = recon_batchs[idx]
    recon_len = recon_lens[idx]
    # retrieve the model resolved sequence and truth
    pred = pd.DataFrame(recon_batch, index=vocab)
    true_seq = np.random.choice(tras_test, size=1)[0]
    true = pd.DataFrame(stretch_pep(np.array([embed for embed in map(embed_aa, list(true_seq))])).numpy(), index=vocab)
    # plot the predicted vs. truth
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(pred, xticklabels=0, yticklabels=1, cmap='PuBu')
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(true, xticklabels=0, yticklabels=1, cmap='PuBu')
    # highlight any discrepancies
    fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
    sns.heatmap(pred-true, xticklabels=0, yticklabels=1, cmap='coolwarm', vmin=-1, vmax=1)
    # resolve the sequence
    curr_len = 48; targ_len = round(recon_len)
    xp = np.arange(curr_len) / (curr_len - 1)
    x = np.arange(targ_len) / (targ_len - 1)
    # interpolate the results
    res = np.array([np.interp(x, xp, recon_batch[idx, :]) for idx in range(recon_batch.shape[0])])
    pred_seq = ''.join(pd.DataFrame(res.T, columns=vocab).idxmax(1))
    print(true_seq, pred_seq, levenshtein(true_seq, pred_seq))

In [None]:
# compare each prediction against its true sequence and against randomly drawn test sequences
torch.manual_seed(0); np.random.seed(0)
idxs = np.random.choice(range(len(tras_test)), size=100, replace=False)
# number of random sequences averaged for the baseline distance
n_rand = 10
# track distances
dist_records = []
for idx in idxs:
    # retrieve the sequence of interest
    recon_batch = recon_batchs[idx]
    recon_len = recon_lens[idx]
    true_seq = tras_test.iloc[idx]
    # resolve the sequence
    curr_len = 48; targ_len = round(recon_len)
    xp = np.arange(curr_len) / (curr_len - 1)
    x = np.arange(targ_len) / (targ_len - 1)
    # interpolate every channel back to the predicted length and take the best residue per position
    res = np.array([np.interp(x, xp, recon_batch[row, :]) for row in range(recon_batch.shape[0])])
    pred_seq = ''.join(pd.DataFrame(res.T, columns=vocab).idxmax(1))
    dist_from_truth = levenshtein(true_seq, pred_seq)
    # baseline: mean distance to whole random test sequences (np.random.choice on the Series
    # already returns one sequence; indexing it with [0] would keep only its first residue)
    dist_from_rand = 0
    for _ in range(n_rand):
        dist_from_rand += levenshtein(np.random.choice(tras_test), pred_seq)
    dist_records.append((dist_from_truth, dist_from_rand / n_rand))
# build the frame once instead of growing it row by row
df_dist = pd.DataFrame(dist_records, columns=['dist_to_truth','dist_to_rand'])

In [None]:
# compare prediction accuracy statistically
fig, ax = plt.subplots(figsize=[2, 4]); ax.grid(False)
sns.boxplot(x='variable', y='value', data=df_dist.melt(), ax=ax, saturation=1,
            linewidth=1.5, linecolor='dodgerblue', order=['dist_to_truth'], color='skyblue')
sns.boxplot(x='variable', y='value', data=df_dist.melt(), ax=ax, saturation=1,
            linewidth=1.5, linecolor='grey', order=['dist_to_rand'], color='lightgray')
# create artificial lines
# for idx in df_dist.index: ax.plot([0, 1], df_dist.loc[idx], color='k', alpha=0.1, lw=0.5, zorder=0, linestyle='--')
ax.tick_params(axis='x', labelrotation=90); ax.set_xticklabels(['Truth','Random']); ax.set_xlim(-1, 2)
ax.set(xlabel='Prediction vs.', ylabel='Levenshtein distance')
ss.wilcoxon(df_dist['dist_to_truth'], df_dist['dist_to_rand']),\
ss.mannwhitneyu(df_dist['dist_to_truth'], df_dist['dist_to_rand'])

In [None]:
# interpolate back to a sequence using the length
curr_len = 48; targ_len = tmp_len
# compute the x-coordinates of the original
xp = np.arange(curr_len) / (curr_len - 1)
x = np.arange(targ_len) / (targ_len - 1)
# interpolate the results
res = np.array([np.interp(x, xp, tmp_out[idx, :]) for idx in range(tmp_out.shape[0])])

In [None]:
# resolve the sequence with a deduction system
fig, ax = plt.subplots(figsize=[8, 4]); ax.grid(False)
data = pd.DataFrame(res.T, columns=vocab).T
sns.heatmap(data, xticklabels=0, yticklabels=1, cmap='PuBu')

In [None]:
# reveal the sequence
''.join(data.idxmax(0))

In [None]:
# provide more examples: sample every latent dimension around its observed mean and spread
torch.manual_seed(0); np.random.seed(0)
for _ in range(100):
    z_in = [np.random.normal(loc=means[idx], scale=stds[idx], size=1)[0] for idx in range(32)]
    # use the configured device (falls back to CPU) rather than a hard-coded 'cuda'
    z = torch.Tensor(z_in).to(device)
    print(engine(model, z))

### Time Inference

In [None]:
import pickle as pkl
# initialize the model
model = ConvVAE().to(device)
# load the model (map_location keeps loading working when the checkpoint was saved on another device)
model.load_state_dict(torch.load('../models/model.convvae.trb.torch', map_location=device, weights_only=True))
model.eval()

import time
from tqdm import tqdm
# benchmark the encoder over increasing input sizes at one fixed batch size
sizes = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000, 20000, 50000, 100000, 200000, 500000, int(1e6)]
batch_size = 5000
df_time = pd.DataFrame(columns=['size','batch_size','seed','duration'])
# CUDA kernels are launched asynchronously, so the clock must only be read after a synchronize
on_cuda = torch.device(device).type == 'cuda'
# get the time for different input sizes
for size in tqdm(sizes):
    for seed in range(5):
        # skip combinations already present in df_time (only relevant if df_time is pre-populated)
        if df_time.loc[(df_time['size'] == size) & (df_time['batch_size'] == batch_size) & (df_time['seed'] == seed)].shape[0] > 0:
            continue
        # draw a seeded random sample (with replacement) of the requested size
        torch.manual_seed(seed); np.random.seed(seed)
        idxs = np.random.choice(len(X_trbs), size=size, replace=True)
        # loop through the number of TRBs
        loader = DataLoader(dataset=TensorDataset(X_trbs[idxs]), batch_size=batch_size, shuffle=False)
        # time the inference; no_grad avoids autograd bookkeeping during pure inference
        with torch.no_grad():
            if on_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()
            for data in loader:
                data = data[0].to(device)
                enc_out = model.encode(data)
            if on_cuda:
                # wait for all queued kernels so the measured duration covers the actual compute
                torch.cuda.synchronize()
            duration = time.perf_counter() - start
        df_time.loc[df_time.shape[0]] = size, batch_size, seed, duration
    # save the dataframe after every size in case of a crash in the next size
    df_time.to_csv('../outs/trb_speed.csv')

In [None]:
# map the duration in seconds per size
fig, ax = plt.subplots(figsize=[6, 4]); ax.grid(False)
# only plot the sizes from the fourth entry of `sizes` onwards
data = df_time.loc[df_time['size'].isin(sizes[3:])]
# errorbar / err_kws replace the deprecated ci / errcolor / errwidth arguments;
# ax is passed explicitly so the bars always land on the figure created above
sns.barplot(x='size', y='duration', data=data, ax=ax, errorbar=('ci', 95), capsize=0.3, saturation=1,
            err_kws={'color': 'dodgerblue', 'linewidth': 1.5},
            edgecolor='dodgerblue', color='skyblue', linewidth=1.5)
# render the tick labels as plain integers
_ = ax.set_xticklabels([int(float(x.get_text())) for x in ax.get_xticklabels()])
ax.tick_params(axis='x', labelrotation=90)
ax.set_yscale('log')
ax.set(xlabel='Number of TCRs', ylabel='Inference Time (Seconds)')

In [None]:
# map the millions of TCRs inference times
fig, ax = plt.subplots()
ax.grid(False)
# band of mean -/+ 1.96 * mean_se, converted to minutes like the mean line
band_lower = (ys['mean'] - ys['mean_se']*1.96) / 60
band_upper = (ys['mean'] + ys['mean_se']*1.96) / 60
millions = xs / 1e6
ax.plot(millions, ys['mean'] / 60, color='dodgerblue', lw=2, zorder=2)
ax.fill_between(millions, band_lower, band_upper,
                color='skyblue', lw=1.5, zorder=1, alpha=0.5)
ax.set(xlabel='Millions of TCRs', ylabel='Inference Time (Minutes)')

### Compare with TCRdist

In [None]:
import pwseqdist as pw
from scipy.spatial.distance import cdist, pdist, squareform
import pickle as pkl
# load the pickled data
# NOTE(review): unpickling can execute arbitrary code -- only load files from a trusted source
with open('../external_data/results.tcr.pkl', 'rb') as f: results_tcr = pkl.load(f)

# read in the data: the AnnData object and the cleaned table (first CSV column as index)
a_trb = sc.read_h5ad('../outs/adata.trb.h5ad')
df = pd.read_csv('../outs/df.int.clean.csv', index_col=0)

In [None]:
# create tracking objects
# accumulator for per-(epitope, method) result frames
datas = list()
# distance methods being compared
methods = ['TCRdist', 'Tarpon']
# epitope sequences to evaluate
epitopes = [
    'YLQPRTFLL', 'NLVPMVATV', 'TPRVTGGGAM', 'GILGFVFTL',
    'GLCTLVAML', 'YVLDHLIVV', 'ELAGIGILTV', 'EAAGIGILTV',
    'SLLMWITQC', 'KLGGALQAK', 'AVFDRKSDAK', 'RAKFKQLL',
    'IVTDFSVIK', 'LLWNGPMAV', 'SPRWYFYYL', 'TTDPSFLGRY',
    'RLRAEAQVK', 'LLLDRLNQL', 'LTDEMIAQY', 'CINGVCWTV',
    'KTFPPTEPK', 'QYIKWPWYI', 'VMTTVLATL', 'DATYQRTRALVR',
    'NQKLIANQF', 'FLCMKALLL',
]

In [None]:
from tqdm import tqdm
# loop through epitopes
for epitope in tqdm(epitopes):
    # reseed per epitope so the random background draw below is reproducible
    np.random.seed(0)
    # grab the binding TCRs, restricted to those present in a_trb
    trbs = df['TRB'][df['AG'] == epitope].value_counts()
    trbs = trbs.loc[trbs.index.isin(a_trb.obs.index)]
    # check AFTER the a_trb filter: an epitope whose binders are all missing from a_trb
    # would otherwise yield empty distance matrices and NaN statistics
    if len(trbs) == 0:
        print(epitope)
        continue
    # trbs = trbs.loc[np.random.choice(trbs, size=min(100, len(trbs)), replace=False)]
    # grab the irrelevant non-binding TCRs (present in a_trb, disjoint from the binders),
    # subsampled without replacement to the same number as the binders
    trbs_rand = df['TRB'][df['AG'] != epitope].value_counts()
    trbs_rand = trbs_rand.loc[trbs_rand.index.isin(a_trb.obs.index)]
    trbs_rand = trbs_rand.loc[~trbs_rand.index.isin(trbs.index)]
    trbs_rand = trbs_rand.loc[np.random.choice(trbs_rand.index, size=trbs.shape[0], replace=False)]
    # confirm and aggregate
    assert trbs.index.unique().shape[0] == trbs.index.shape[0]
    assert trbs_rand.index.unique().shape[0] == trbs_rand.index.shape[0]
    assert trbs.index.union(trbs_rand.index).shape[0] == (trbs.shape[0]+trbs_rand.shape[0])
    # append rather than union: Index.union sorts its result, which would interleave
    # binders and random TCRs and invalidate the positional quadrant slicing below
    trbs_agg = trbs.index.append(trbs_rand.index)
    # the first n_ag rows/columns of every distance matrix are the binders, the rest are random
    n_ag = trbs.shape[0]

    # compute the pairwise distance via tcrdist-like metrics and tarpon
    m2d = {}
    m2d['TCRdist'] = pw.apply_pairwise_rect(seqs1 = trbs_agg.tolist(),
                                            metric = pw.metrics.nb_vector_tcrdist,
                                            ncpus = 5, use_numba = True, uniqify = True)
    m2d['Tarpon'] = squareform(pdist(a_trb[trbs_agg].X))

    # loop through methods
    for method in methods:
        # float copy so the self-distances on the diagonal can be masked with NaN
        dmet = m2d[method].astype(float)
        np.fill_diagonal(dmet, np.nan)

        # extract each pairwise comparison between groups
        ag2rand = dmet[:n_ag, n_ag:].flatten()
        ag2rand = ag2rand[~np.isnan(ag2rand)]
        ag2ag = dmet[:n_ag, :n_ag].flatten()
        ag2ag = ag2ag[~np.isnan(ag2ag)]
        rand2rand = dmet[n_ag:, n_ag:].flatten()
        rand2rand = rand2rand[~np.isnan(rand2rand)]
        # convert to dataframes, labelling every distance with its comparison group
        frames = []
        for comp, dists in [('ag2rand', ag2rand), ('ag2ag', ag2ag), ('rand2rand', rand2rand)]:
            frame = pd.DataFrame(dists, columns=['dist']); frame['comp'] = comp
            frames.append(frame)
        # aggregate and normalize over ALL groups: z-scoring ag2rand on its own would force
        # its mean to zero and make any later per-group mean uninformative
        # NOTE(review): this keeps roughly three times as many rows per frame as ag2rand alone
        data = pd.concat(frames, axis=0, ignore_index=True)
        data['dist'] -= data['dist'].mean()
        data['dist'] /= data['dist'].std()
        # derive metrics
        data['epitope'] = epitope
        data['method'] = method
        datas.append(data)

In [None]:
# derive averages for distance to antigen vs. random
from tqdm import tqdm
df_stat = pd.DataFrame(columns=['mean_a2r','epitope','method'])
for data in tqdm(datas):
    # every row of a frame carries the same epitope / method label, so read the first
    epitope, method = data['epitope'].iloc[0], data['method'].iloc[0]
    # average distance over the antigen-vs-random comparisons only
    is_a2r = data['comp'] == 'ag2rand'
    mean_a2r = data.loc[is_a2r, 'dist'].mean()
    # append as the next row
    df_stat.loc[len(df_stat)] = mean_a2r, epitope, method
# save the summary table and the underlying per-comparison frames
df_stat.to_csv('../outs/250425_compare_with_diffalgos.csv')
with open('../outs/250425_compare_with_diffalgos.pkl', 'wb') as f: pkl.dump(datas, f)

In [None]:
# compare the two relevant populations
df_stat['enrich'] = df_stat['mean_a2r'].copy()
fig, ax = plt.subplots(figsize=[1.5, 4]); ax.grid(False)
# err_kws replaces the deprecated errcolor / errwidth arguments
sns.barplot(x='method', y='enrich', data=df_stat, ax=ax,
            saturation=1, err_kws={'color': 'dodgerblue', 'linewidth': 1.5}, order=['TCRdist','Tarpon'],
            linewidth=1.5, capsize=0.3, edgecolor='dodgerblue', color='skyblue')
# fix the seed so the strip plot jitter is reproducible
np.random.seed(0)
sns.stripplot(x='method', y='enrich', data=df_stat, ax=ax, alpha=0.5, order=['TCRdist','Tarpon'],
              linewidth=1.5, edgecolor='dodgerblue', color='skyblue', jitter=0.35)
ax.tick_params(axis='x', labelrotation=90)
ax.set_ylabel('Normalized Distance Between\nAg-specific and Random TCRs')
# paired test: pairs values by position, so both method subsets are assumed to list
# the epitopes in the same order -- TODO confirm against how df_stat was filled
ss.wilcoxon(df_stat.loc[df_stat['method'] == 'Tarpon', 'enrich'],
            df_stat.loc[df_stat['method'] == 'TCRdist', 'enrich'])