In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.stats as ss
import seaborn as sns
from matplotlib.cm import get_cmap
from matplotlib.colors import to_hex
from tqdm import tqdm
sc.settings.set_figure_params(dpi=100)

### Read in the Data

In [None]:
import pickle as pkl

def _load_pickle(path):
    """Unpickle and return the object stored at `path`.

    NOTE: pickle can execute arbitrary code on load; only use on trusted files.
    """
    with open(path, 'rb') as handle:
        return pkl.load(handle)

def _trim_ends(seqs, n_trim):
    """Drop `n_trim` residues from both ends of every sequence, discarding empty results."""
    return [core for core in (seq[n_trim:-n_trim] for seq in seqs) if len(core) > 0]

# read in the aggregated values
ags = pd.Series(_load_pickle('../external_data/db.ags.pkl'))
tras = pd.Series(_load_pickle('../external_data/db.tras.pkl'))
trbs = pd.Series(_load_pickle('../external_data/db.trbs.pkl'))
paired_tcrs = _load_pickle('../external_data/db.paired_tcrs.pkl')

# trim the sequences: remove 1-4 residues from each side of every TRB
trbs4 = _trim_ends(trbs, 4)
trbs3 = _trim_ends(trbs, 3)
trbs2 = _trim_ends(trbs, 2)
trbs1 = _trim_ends(trbs, 1)
len(trbs4), len(trbs3), len(trbs2), len(trbs1)

### Perform Embedding

In [None]:
import blosum as bl
# perform encoding by direct, BCP, BLOSUM
# the 20 amino acids (one-letter codes); this order fixes the column order of every encoding below
vocab = ['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y']
# direct encoding: one-hot vector over `vocab` for each amino acid
map_direct = {x:[1 * (x == y) for y in vocab] for x in vocab}
# bcp encoding: per-residue property tables
# NOTE(review): the hydrophobicity values look like the Kyte-Doolittle hydropathy scale — confirm the source
aa_hydrophobicity = {
    'A': 1.8,  # Alanine
    'R': -4.5,  # Arginine
    'N': -3.5,  # Asparagine
    'D': -3.5,  # Aspartic Acid
    'C': 2.5,  # Cysteine
    'E': -3.5,  # Glutamic Acid
    'Q': -3.5,  # Glutamine
    'G': -0.4,  # Glycine
    'H': -3.2,  # Histidine
    'I': 4.5,  # Isoleucine
    'L': 3.8,  # Leucine
    'K': -3.9,  # Lysine
    'M': 1.9,  # Methionine
    'F': 2.8,  # Phenylalanine
    'P': -1.6,  # Proline
    'S': -0.8,  # Serine
    'T': -0.7,  # Threonine
    'W': -0.9,  # Tryptophan
    'Y': -1.3,  # Tyrosine
    'V': 4.2,  # Valine
}
# https://www.imgt.org/IMGTeducation/Aide-memoire/_UK/aminoacids/IMGTclasses.html
# residue volume per amino acid (raw values; min-max normalised further below)
aa_volume = {
    'A': 88.6,   # Alanine
    'R': 173.4,  # Arginine
    'N': 114.1,  # Asparagine
    'D': 111.1,  # Aspartic Acid
    'C': 108.5,  # Cysteine
    'E': 138.4,  # Glutamic Acid
    'Q': 143.8,  # Glutamine
    'G': 60.1,   # Glycine
    'H': 153.2,  # Histidine
    'I': 166.7,  # Isoleucine
    'L': 166.7,  # Leucine
    'K': 168.6,  # Lysine
    'M': 162.9,  # Methionine
    'F': 189.9,  # Phenylalanine
    'P': 112.7,  # Proline
    'S': 89.0,   # Serine
    'T': 116.1,  # Threonine
    'W': 227.8,  # Tryptophan
    'Y': 193.6,  # Tyrosine
    'V': 140.0,  # Valine
}
# hydrogen-bonding capacity: 1 = donor and acceptor, 0.5 = only donor or acceptor, 0 = neither
aa_hbond = {
    'A': 0,    # Alanine
    'R': 0.5,  # Arginine
    'N': 1,    # Asparagine
    'D': 0.5,  # Aspartic Acid
    'C': 0,    # Cysteine
    'E': 0.5,  # Glutamic Acid
    'Q': 1,    # Glutamine
    'G': 0,    # Glycine
    'H': 1,    # Histidine
    'I': 0,    # Isoleucine
    'L': 0,    # Leucine
    'K': 0.5,  # Lysine
    'M': 0,    # Methionine
    'F': 0,    # Phenylalanine
    'P': 0,    # Proline
    'S': 1,    # Serine
    'T': 1,    # Threonine
    'W': 0.5,  # Tryptophan
    'Y': 1,    # Tyrosine
    'V': 0,    # Valine
}
# binary class memberships (each becomes a 0/1 feature in the BCP embedding)
has_sulfur = ['C','M']
is_aromatic = ['F','Y','W']
is_aliphatic = ['A','G','I','L','P','V']
is_basic = ['R','H','K']
is_acidic = ['D','E']
has_amide = ['N','Q']
# names of the nine BCP features, in the order bcp_translation emits them
vocab_bcp = ['hydrophobicity','volume','hbond','has_sulfur','is_aromatic',
             'is_aliphatic','is_basic','is_acidic','has_amide']
# > normalize volume to [0, 1] (min-max) and hydrophobicity to [-1, 1] (divide by the largest |value|)
# NOTE: both tables are overwritten in place, so re-running this part of the cell on its own
#       would normalise already-normalised values; re-run the whole cell instead
vmin, vmax = min(list(aa_volume.values())), max(list(aa_volume.values()))
aa_volume = {k:(v-vmin)/(vmax-vmin) for k,v in aa_volume.items()}
# vmax is reused here for the largest absolute hydrophobicity
vmax = max(abs(np.array(list(aa_hydrophobicity.values()))))
aa_hydrophobicity = {k:v/vmax for k,v in aa_hydrophobicity.items()}
# > define a method to return the embedding for a given amino acid in BCP space
def bcp_translation(aa):
    """Return the 9-element BCP feature list for one amino acid.

    The order follows `vocab_bcp`: normalised hydrophobicity, normalised volume,
    hydrogen-bond score, then six 0/1 class-membership flags.
    """
    continuous = [aa_hydrophobicity[aa], aa_volume[aa], aa_hbond[aa]]
    groups = (has_sulfur, is_aromatic, is_aliphatic, is_basic, is_acidic, has_amide)
    flags = [int(aa in group) for group in groups]
    return continuous + flags
# BCP features for every amino acid in the vocabulary
map_bcp = {aa: bcp_translation(aa) for aa in vocab}
# build the BLOSUM62 matrix once, then scale each substitution score by 1/5
blosum62 = bl.BLOSUM(62)
map_blosum = {aa: [blosum62[aa][other] / 5 for other in vocab] for aa in vocab}

In [None]:
# embed one amino acid as direct (one-hot) + BCP + BLOSUM features, plus a length slot
def embed_aa(aa):
    """Return the concatenated feature list for amino acid `aa`.

    The trailing 0 is a placeholder that `stretch_pep` later overwrites with the
    original sequence length.
    """
    return list(map_direct[aa]) + map_bcp[aa] + map_blosum[aa] + [0]

In [None]:
from tqdm import tqdm
import torch
# define the number of samples per case
n_samples = 100
# define the number of lengths to test
targ_lens = range(5, 101, 5); mses = []
for targ_len in targ_lens:
    # set seed for reproducibility
    np.random.seed(0)
    # track the MSEs
    mse = 0
    for sequence in np.random.choice(trbs4, size=n_samples, replace=False):
        # retrieve the original length
        orig_len = len(sequence)
        # retrieve the embedding
        embedding = np.array([embed for embed in map(embed_aa, list(sequence))])
        tensor = torch.Tensor(embedding.T.reshape(1, 50, orig_len))
        res = torch.nn.functional.interpolate(tensor, size=(targ_len), mode='linear', align_corners=False)[0].T
        res_p = torch.nn.functional.interpolate(res.T.view((1, 50, targ_len)), size=(orig_len), mode='linear', align_corners=False)[0].T
        mse += (res_p - embedding).pow(2).sum()
    mses.append(torch.sqrt(mse / n_samples))
# create the elbow like plot
fig, ax = plt.subplots(); ax.grid(False)
ax.scatter(targ_lens, mses, edgecolor='dodgerblue', color='skyblue', lw=1.5)
ax.set(xlabel='TRB (4AA trimming) stretch length', ylabel='Average reconstruction loss')
trb_mses = [x for x in mses]

In [None]:
# define the number of lengths to test
mses = []
for targ_len in targ_lens:
    # set seed for reproducibility
    np.random.seed(0)
    # track the MSEs
    mse = 0
    for sequence in np.random.choice(trbs3, size=n_samples, replace=False):
        # retrieve the original length
        orig_len = len(sequence)
        # retrieve the embedding
        embedding = np.array([embed for embed in map(embed_aa, list(sequence))])
        tensor = torch.Tensor(embedding.T.reshape(1, 50, orig_len))
        res = torch.nn.functional.interpolate(tensor, size=(targ_len), mode='linear', align_corners=False)[0].T
        res_p = torch.nn.functional.interpolate(res.T.view((1, 50, targ_len)), size=(orig_len), mode='linear', align_corners=False)[0].T
        mse += (res_p - embedding).pow(2).sum()
    mses.append(torch.sqrt(mse / n_samples))
# create the elbow like plot
fig, ax = plt.subplots(); ax.grid(False)
ax.scatter(targ_lens, mses, edgecolor='dodgerblue', color='skyblue', lw=1.5)
ax.set(xlabel='TRB (3AA trimming) stretch length', ylabel='Average reconstruction loss')
trb_mses = [x for x in mses]

In [None]:
# define the number of lengths to test
mses = []
for targ_len in targ_lens:
    # set seed for reproducibility
    np.random.seed(0)
    # track the MSEs
    mse = 0
    for sequence in np.random.choice(trbs2, size=n_samples, replace=False):
        # retrieve the original length
        orig_len = len(sequence)
        # retrieve the embedding
        embedding = np.array([embed for embed in map(embed_aa, list(sequence))])
        tensor = torch.Tensor(embedding.T.reshape(1, 50, orig_len))
        res = torch.nn.functional.interpolate(tensor, size=(targ_len), mode='linear', align_corners=False)[0].T
        res_p = torch.nn.functional.interpolate(res.T.view((1, 50, targ_len)), size=(orig_len), mode='linear', align_corners=False)[0].T
        mse += (res_p - embedding).pow(2).sum()
    mses.append(torch.sqrt(mse / n_samples))
# create the elbow like plot
fig, ax = plt.subplots(); ax.grid(False)
ax.scatter(targ_lens, mses, edgecolor='dodgerblue', color='skyblue', lw=1.5)
ax.set(xlabel='TRB (2AA trimming) stretch length', ylabel='Average reconstruction loss')
trb_mses = [x for x in mses]

In [None]:
# define the number of lengths to test
mses = []
for targ_len in targ_lens:
    # set seed for reproducibility
    np.random.seed(0)
    # track the MSEs
    mse = 0
    for sequence in np.random.choice(trbs1, size=n_samples, replace=False):
        # retrieve the original length
        orig_len = len(sequence)
        # retrieve the embedding
        embedding = np.array([embed for embed in map(embed_aa, list(sequence))])
        tensor = torch.Tensor(embedding.T.reshape(1, 50, orig_len))
        res = torch.nn.functional.interpolate(tensor, size=(targ_len), mode='linear', align_corners=False)[0].T
        res_p = torch.nn.functional.interpolate(res.T.view((1, 50, targ_len)), size=(orig_len), mode='linear', align_corners=False)[0].T
        mse += (res_p - embedding).pow(2).sum()
    mses.append(torch.sqrt(mse / n_samples))
# create the elbow like plot
fig, ax = plt.subplots(); ax.grid(False)
ax.scatter(targ_lens, mses, edgecolor='dodgerblue', color='skyblue', lw=1.5)
ax.set(xlabel='TRB (1AA trimming) stretch length', ylabel='Average reconstruction loss')
trb_mses = [x for x in mses]

In [None]:
# we therefore settle on a stretch length
targ_len = 48

In [None]:
import torch
# define a function to interpolate the protein
def stretch_pep(embedding, targ_len=targ_len):
    """Linearly interpolate a per-residue embedding to a fixed length.

    Parameters
    ----------
    embedding : numpy array of shape (orig_len, n_features)
    targ_len : number of positions in the output (the default is bound to the
        value of the global `targ_len` at the time this function is defined)

    Returns
    -------
    float32 tensor of shape (n_features, targ_len). The last feature row is
    overwritten with the original sequence length at every position.
    """
    orig_len, n_features = embedding.shape
    # interpolate expects (batch, channels, length)
    tensor = torch.Tensor(embedding.T.reshape(1, n_features, orig_len))
    res = torch.nn.functional.interpolate(tensor, size=(targ_len), mode='linear', align_corners=False)[0]
    # store the original length in the reserved last feature row
    res[-1, :] = orig_len
    return res

In [None]:
from tqdm import tqdm
import pickle as pkl
# process the TRBs
trb_to_embed = {}
for sequence in tqdm(trbs4):
    # retrieve the embedding
    embedding = np.array([embed for embed in map(embed_aa, list(sequence))])
    # stretch the embedding
    embedding = stretch_pep(embedding, targ_len=targ_len)
    # save the embedding
    trb_to_embed[sequence] = embedding
# save the embedding maps
with open('../outs/map.trb_to_embed.extended_4aatrim.pkl', 'wb') as f: pkl.dump(trb_to_embed, f)

In [None]:
# process the TRBs
trb_to_embed = {}
for sequence in tqdm(trbs3):
    # retrieve the embedding
    embedding = np.array([embed for embed in map(embed_aa, list(sequence))])
    # stretch the embedding
    embedding = stretch_pep(embedding, targ_len=targ_len)
    # save the embedding
    trb_to_embed[sequence] = embedding
# save the embedding maps
with open('../outs/map.trb_to_embed.extended_3aatrim.pkl', 'wb') as f: pkl.dump(trb_to_embed, f)

In [None]:
# process the TRBs
trb_to_embed = {}
for sequence in tqdm(trbs2):
    # retrieve the embedding
    embedding = np.array([embed for embed in map(embed_aa, list(sequence))])
    # stretch the embedding
    embedding = stretch_pep(embedding, targ_len=targ_len)
    # save the embedding
    trb_to_embed[sequence] = embedding
# save the embedding maps
with open('../outs/map.trb_to_embed.extended_2aatrim.pkl', 'wb') as f: pkl.dump(trb_to_embed, f)

In [None]:
# process the TRBs
trb_to_embed = {}
for sequence in tqdm(trbs1):
    # retrieve the embedding
    embedding = np.array([embed for embed in map(embed_aa, list(sequence))])
    # stretch the embedding
    embedding = stretch_pep(embedding, targ_len=targ_len)
    # save the embedding
    trb_to_embed[sequence] = embedding
# save the embedding maps
with open('../outs/map.trb_to_embed.extended_1aatrim.pkl', 'wb') as f: pkl.dump(trb_to_embed, f)

In [None]:
# write each trimmed TRB list to its own pickle
trimmed_outputs = {
    '../outs/trbs.4aa.pkl': trbs4,
    '../outs/trbs.3aa.pkl': trbs3,
    '../outs/trbs.2aa.pkl': trbs2,
    '../outs/trbs.1aa.pkl': trbs1,
}
for out_path, seqs in trimmed_outputs.items():
    with open(out_path, 'wb') as f:
        pkl.dump(seqs, f)

### Model Running

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

In [None]:
def train(epoch, loss_func):
    model.train()
    train_loss = 0
    for batch_idx, data in enumerate(train_loader):
        data = data[0].to(device)
        optimizer.zero_grad()
        (recon_batch, recon_len), mu, logvar = model(data)
        if loss_func == 1:
            loss = loss_function1(recon_batch, recon_len, data, mu, logvar)
        elif loss_func == 2:
            loss = loss_function2(recon_batch, recon_len, data, mu, logvar)
        elif loss_func == 3:
            loss = loss_function3(recon_batch, recon_len, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))
    return train_loss / len(train_loader.dataset)
    
def test(epoch, loss_func):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, data in enumerate(test_loader):
            data = data[0].to(device)
            (recon_batch, recon_len), mu, logvar = model(data)
            if loss_func == 1:
                test_loss += loss_function1(recon_batch, recon_len, data, mu, logvar).item()
            elif loss_func == 2:
                test_loss += loss_function2(recon_batch, recon_len, data, mu, logvar).item()
            elif loss_func == 3:
                test_loss += loss_function3(recon_batch, recon_len, data, mu, logvar).item()

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    return test_loss

In [None]:
# define the key parameters
init_embed_size = 50-1          # input channels: every feature row except the trailing length row
protein_len = 48                # number of positions after stretching
init_kernel_size = 3
init_cnn_filters = 256
init_kernel_stride = 1
init_kernel_padding = 1
secn_cnn_filters = 256
latent_dim = 32
vocab = ['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y']
# the decoder emits one channel per vocabulary entry; length is predicted by a separate head
out_embed_size = len(vocab)
n_nodes_len = 32

# define the convolutional variational autoencoder
class ConvVAE(nn.Module):
    """1-D convolutional VAE with an extra scalar head that predicts sequence length.

    Layer attribute names (fc1 ... fc8) are part of the saved state_dict and must
    not be renamed.
    """

    def __init__(self):
        super().__init__()
        # every (transposed) convolution shares the same kernel geometry
        conv_kwargs = dict(kernel_size=init_kernel_size, stride=init_kernel_stride,
                           padding=init_kernel_padding)
        flat_dim = secn_cnn_filters * protein_len

        # encoding
        self.fc1 = nn.Conv1d(in_channels=init_embed_size, out_channels=init_cnn_filters, **conv_kwargs)
        self.fc2 = nn.Conv1d(in_channels=init_cnn_filters, out_channels=secn_cnn_filters, **conv_kwargs)
        # variational sampling: mean head, log-variance head, and the projection back
        self.fc31 = nn.Linear(flat_dim, latent_dim)
        self.fc32 = nn.Linear(flat_dim, latent_dim)
        self.fc4 = nn.Linear(latent_dim, flat_dim)
        # decoding
        self.fc5 = nn.ConvTranspose1d(in_channels=secn_cnn_filters, out_channels=init_cnn_filters, **conv_kwargs)
        self.fc6 = nn.ConvTranspose1d(in_channels=init_cnn_filters, out_channels=out_embed_size, **conv_kwargs)
        # length head
        self.fc7 = nn.Linear(init_cnn_filters * protein_len, n_nodes_len)
        self.fc8 = nn.Linear(n_nodes_len, 1)

    def encode(self, x):
        """Return (mu, logvar) for a batch; the last feature row of `x` is dropped first."""
        features = x[:, :-1, :]
        hidden = nn.functional.leaky_relu(self.fc1(features))
        hidden = nn.functional.leaky_relu(self.fc2(hidden))
        flat = torch.flatten(hidden, start_dim=1)
        return self.fc31(flat), self.fc32(flat)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, 1) and std = exp(logvar / 2)."""
        std = torch.exp(0.5*logvar)
        noise = torch.randn_like(std)
        return mu + noise*std

    def decode(self, z):
        """Return (per-position vocabulary scores in (0, 1), predicted length) for latent `z`."""
        expanded = nn.functional.leaky_relu(self.fc4(z))
        grid = expanded.view(-1, secn_cnn_filters, protein_len)
        hidden = nn.functional.leaky_relu(self.fc5(grid))
        recon = torch.sigmoid(self.fc6(hidden))
        hidden_flat = torch.flatten(hidden, start_dim=1)
        length = self.fc8(nn.functional.leaky_relu(self.fc7(hidden_flat)))
        return recon, length

    def forward(self, x):
        """Encode, sample, decode; returns ((recon, length), mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [None]:
# Reconstruction + KL divergence + length losses, each summed over all elements and the batch
def _reconstruction_bce(recon_x, x):
    """Summed binary cross-entropy between the decoder output and the first len(vocab) rows of `x`."""
    return nn.functional.binary_cross_entropy(recon_x, x[:, :len(vocab), :], reduction='sum')

def _kl_divergence(mu, logvar):
    """KL divergence of N(mu, exp(logvar)) from the standard normal, summed over all elements."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

def _length_error(recon_len, x):
    """Summed squared error between the predicted length and the last feature row of `x`.

    `recon_len` is broadcast against every position of that row.
    """
    return (recon_len - x[:, -1, :]).pow(2).sum()

def loss_function1(recon_x, recon_len, x, mu, logvar):
    """Sequence objective: BCE + KLD (the length prediction is ignored)."""
    return _reconstruction_bce(recon_x, x) + _kl_divergence(mu, logvar)

def loss_function2(recon_x, recon_len, x, mu, logvar):
    """Length objective only: total squared length error."""
    return _length_error(recon_len, x)

def loss_function3(recon_x, recon_len, x, mu, logvar):
    """Integrated objective: BCE + KLD + total squared length error."""
    return _reconstruction_bce(recon_x, x) + _kl_divergence(mu, logvar) + _length_error(recon_len, x)

### TRB Trimming 4AAs

In [None]:
import pickle as pkl
# read in the aggregated values
with open('../outs/trbs.4aa.pkl', 'rb') as f:
    trbs = pkl.load(f)
# de-duplicate in sorted order: iterating a set of strings depends on per-process hash
# randomisation, which would make the seeded split below differ between kernel restarts
trbs = pd.Series(sorted(set(trbs))); print(len(trbs))
# embed all of our unique TRBs
with open('../outs/map.trb_to_embed.extended_4aatrim.pkl', 'rb') as f:
    trb_to_embed = pkl.load(f)
X_trbs = torch.stack([trb_to_embed[seq].to(torch.float32) for seq in trbs])
# randomly split the data into train (75%) and test (remaining 25%)
torch.manual_seed(0); np.random.seed(0)
idxs_train = np.random.choice(range(len(X_trbs)), size=round(len(X_trbs)*0.75), replace=False)
idxs_test = np.array(range(len(X_trbs)))
idxs_test = idxs_test[~np.isin(idxs_test, idxs_train)]
X_trbs_train = X_trbs[idxs_train]
X_trbs_test = X_trbs[idxs_test]
# every sample must land in exactly one of the two splits
assert len(X_trbs_train) + len(X_trbs_test) == len(X_trbs)

from torch.utils.data import TensorDataset, DataLoader
# create a latent space for the trbs
batch_size = 2048
train_loader = DataLoader(dataset=TensorDataset(X_trbs_train), batch_size=batch_size, shuffle=True)
# the test loader is not shuffled so predictions line up with idxs_test
test_loader = DataLoader(dataset=TensorDataset(X_trbs_test), batch_size=batch_size, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# initialize the model
model = ConvVAE().to(device)

def run_training_phase(loss_ids, n_epochs):
    """Train and evaluate for `n_epochs`, applying each loss id in `loss_ids` once per epoch.

    For every epoch and every loss id, `train` is called and then `test`.
    Returns (train_history, test_history): one row per epoch, one column per loss id.
    """
    history_train, history_test = [], []
    for epoch_idx in range(1, n_epochs + 1):
        epoch_train, epoch_test = [], []
        for loss_id in loss_ids:
            epoch_train.append(train(epoch_idx, loss_id))
            epoch_test.append(test(epoch_idx, loss_id))
        history_train.append(epoch_train)
        history_test.append(epoch_test)
    return history_train, history_test

def plot_loss_history(history_train, history_test, column, ylabel):
    """Plot one loss column per epoch: solid line for train, dashed line for test."""
    fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
    ax.plot([row[column] for row in history_train], color='dodgerblue')
    ax.plot([row[column] for row in history_test], color='skyblue', linestyle='--')
    ax.set(xlabel='Epochs', ylabel=ylabel)
    return fig, ax

# phase 1: alternate the sequence loss (1) and the length loss (2) within every epoch
# set the seed for training
torch.manual_seed(0); np.random.seed(0)
# set the learning parameters
lr = 0.0005; epochs = 20
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
train_losses, test_losses = run_training_phase([1, 2], epochs)

# examine the KLD and BCE loss
fig, ax = plot_loss_history(train_losses, test_losses, 0, 'KLD + BCE Loss')
# examine the length loss
fig, ax = plot_loss_history(train_losses, test_losses, 1, 'TSE (LEN) Loss')

# phase 2: fine-tune with the integrated loss (3) and a fresh optimizer
# set the seed for training
torch.manual_seed(0); np.random.seed(0)
# set the learning parameters
lr = 0.001; epochs = 40
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
train_losses, test_losses = run_training_phase([3], epochs)

# examine the integrated loss from finetuning
fig, ax = plot_loss_history(train_losses, test_losses, 0, 'KLD + BCE + TSE Loss')

# retrieve the predictions for the held-out set (the loader is unshuffled, so order matches idxs_test)
recon_lens = []; recon_batchs = []
with torch.no_grad():
    for i, data in enumerate(test_loader):
        data = data[0].to(device)
        (recon_batch, recon_len), _, _ = model(data)
        recon_batchs.extend(recon_batch.clone().detach().cpu().numpy())
        recon_lens.extend(recon_len.clone().detach().cpu().tolist())
# each entry is a one-element list; unwrap it
recon_lens = [x[0] for x in recon_lens]
# retrieve the matching test sequences
trbs_test = pd.Series(trbs.iloc[idxs_test])

true_lens = trbs_test.apply(len).to_numpy()
rounded_lens = pd.Series(recon_lens).apply(round).to_numpy()

def plot_length_agreement(pred_lens, true_lens):
    """Scatter predicted against true length with a dashed identity line."""
    fig, ax = plt.subplots(); ax.grid(False)
    ax.scatter(pred_lens, true_lens, s=10, alpha=0.1, zorder=1, color='dodgerblue')
    # remember the data limits so the identity line does not rescale the axes
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    domain = [min(min(xlim), min(ylim)), max(max(xlim), max(ylim))]
    ax.plot(domain, domain, color='skyblue', linestyle='--', zorder=0)
    ax.set_xlim(*xlim); ax.set_ylim(*ylim)
    ax.set(xlabel='Predicted length', ylabel='True length')
    return fig, ax

# there is an ability to reconstruct proper length
fig, ax = plot_length_agreement(recon_lens, true_lens)
print('Non-Rounded', ss.pearsonr(recon_lens, true_lens))
# with rounding the picture becomes more clear, there is a linear relationship but not an exact one
fig, ax = plot_length_agreement(rounded_lens, true_lens)
# correlate the ROUNDED predictions (previously the unrounded values were reused here)
print('Rounded', ss.pearsonr(rounded_lens, true_lens))

# save the model weights (state_dict only) for the 4AA-trimmed TRB model
torch.save(model.state_dict(), '../models/model.convvae_4aa.trb.torch')
# switch to evaluation mode before encoding
model.eval()

from tqdm import tqdm
# get the encoded dimensions
torch.manual_seed(0); np.random.seed(0)
# create a complete loader or else there is an out of memory error
# (unshuffled, so the rows of z_dims follow the order of X_trbs / trbs)
complete_loader = DataLoader(dataset=TensorDataset(X_trbs), batch_size=batch_size, shuffle=False)
# move through each subset in the complete loader
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(complete_loader):
        data = data[0].to(device)
        # encode returns (mu, logvar)
        enc_out = model.encode(data)
        # sampling centers around the mean so we just use mu
        z_dims_per_batch.append(enc_out[0].clone().detach().cpu().numpy())
# stack the per-batch latent means into one (n_sequences, latent_dim) array
z_dims = np.vstack(z_dims_per_batch)
# release the per-batch intermediates
del data, enc_out, z_dims_per_batch

from anndata import AnnData
# NOTE(review): csr_matrix does not appear to be used in the visible code
from scipy.sparse import csr_matrix
# wrap the latent means in an AnnData so the scanpy tools can be applied
adata = AnnData(z_dims, dtype=float)
sc.tl.pca(adata)
# NOTE(review): use_rep='X' builds the neighbour graph on the latent coordinates directly,
#               so the PCA computed above is presumably not used by it — confirm
sc.pp.neighbors(adata, use_rep='X')
sc.tl.umap(adata)
sc.tl.leiden(adata, flavor='igraph', n_iterations=2)
# label every observation with its TRB sequence
adata.obs_names = trbs
# check length: colour the UMAP by sequence length
adata.obs['LEN'] = adata.obs.index.to_series().apply(len)
sc.pl.umap(adata, color=['LEN'], vmin=12, vmax=20)

# save the current data
adata.write('../outs/adata.trb_4aa.h5ad')

# define a function to embed an amino acid with vocab only
# NOTE: from here on map_direct / embed_aa / stretch_pep are re-bound to one-hot-only
#       variants; the 50-feature versions defined earlier in the notebook are shadowed,
#       so re-run those earlier cells before regenerating the full embeddings
map_direct = {x:[1 * (x == y) for y in vocab] for x in vocab}
def embed_aa(aa):
    """Return a fresh one-hot list for amino acid `aa` over `vocab`."""
    return list(map_direct[aa])
# define a function to interpolate the protein
targ_len = 48
def stretch_pep(embedding, targ_len=targ_len):
    """Linearly interpolate an (orig_len, n_features) array to (n_features, targ_len).

    Unlike the earlier variant, no length row is written into the result.
    """
    orig_len, n_features = embedding.shape
    # interpolate expects (batch, channels, length)
    tensor = torch.Tensor(embedding.T.reshape(1, n_features, orig_len))
    return torch.nn.functional.interpolate(tensor, size=(targ_len), mode='linear', align_corners=False)[0]

from Levenshtein import distance as levenshtein
# give examples if we had randomly chosen a "true sequence"
torch.manual_seed(0); np.random.seed(0)
idxs = np.random.choice(range(len(trbs_test)), size=100, replace=False)
# number of positions in the stretched model output
curr_len = 48
# number of random reference sequences averaged per prediction
n_random = 10
# track distances (collected as records and turned into a frame once at the end)
dist_records = []
for idx in idxs:
    # retrieve the sequence of interest
    recon_batch = recon_batchs[idx]
    recon_len = recon_lens[idx]
    # retrieve the model resolved sequence and truth (kept for interactive inspection)
    pred = pd.DataFrame(recon_batch, index=vocab)
    true_seq = trbs_test.iloc[idx]
    true = pd.DataFrame(stretch_pep(np.array([embed_aa(aa) for aa in true_seq])).numpy(), index=vocab)
    # resolve the sequence: shrink the stretched output back to the predicted length
    # (a separate name avoids clobbering the global targ_len; at least one residue is kept
    #  so the resampling grid is always well defined)
    pred_len = max(round(recon_len), 1)
    xp = np.linspace(0, 1, curr_len)
    x = np.linspace(0, 1, pred_len)
    # interpolate each vocabulary channel onto the shorter grid
    res = np.array([np.interp(x, xp, channel) for channel in recon_batch])
    # call the most probable amino acid at every position
    pred_seq = ''.join(pd.DataFrame(res.T, columns=vocab).idxmax(1))
    dist_from_truth = levenshtein(true_seq, pred_seq); dist_from_rand = 0
    for _ in range(n_random):
        # compare against a whole random test sequence (previously only its first character was used)
        dist_from_rand += levenshtein(str(np.random.choice(trbs_test)), pred_seq)
    dist_records.append((dist_from_truth, dist_from_rand / n_random))
df_dist = pd.DataFrame(dist_records, columns=['dist_to_truth','dist_to_rand'])
    
# compare prediction accuracy statistically
fig, ax = plt.subplots(figsize=[2, 4]); ax.grid(False)
dist_long = df_dist.melt()
# one box per column: (column, outline colour, fill colour)
box_styles = (
    ('dist_to_truth', 'dodgerblue', 'skyblue'),
    ('dist_to_rand', 'grey', 'lightgray'),
)
for column, outline, fill in box_styles:
    sns.boxplot(x='variable', y='value', data=dist_long, ax=ax, saturation=1,
                linewidth=1.5, linecolor=outline, order=[column], color=fill)
ax.tick_params(axis='x', labelrotation=90)
ax.set_xticklabels(['Truth','Random'])
ax.set_xlim(-1, 2)
ax.set(xlabel='Prediction vs.', ylabel='Levenshtein distance')
# paired and unpaired tests on the two distance columns
truth_dist, rand_dist = df_dist['dist_to_truth'], df_dist['dist_to_rand']
print(ss.wilcoxon(truth_dist, rand_dist))
print(ss.mannwhitneyu(truth_dist, rand_dist))

### TRB Trimming 3AAs

In [None]:
import pickle as pkl
# read in the aggregated values
with open('../outs/trbs.3aa.pkl', 'rb') as f:
    trbs = pkl.load(f)
# de-duplicate in sorted order: iterating a set of strings depends on per-process hash
# randomisation, which would make the seeded split below differ between kernel restarts
trbs = pd.Series(sorted(set(trbs))); print(len(trbs))
# embed all of our unique TRBs
with open('../outs/map.trb_to_embed.extended_3aatrim.pkl', 'rb') as f:
    trb_to_embed = pkl.load(f)
X_trbs = torch.stack([trb_to_embed[seq].to(torch.float32) for seq in trbs])
# randomly split the data into train (75%) and test (remaining 25%)
torch.manual_seed(0); np.random.seed(0)
idxs_train = np.random.choice(range(len(X_trbs)), size=round(len(X_trbs)*0.75), replace=False)
idxs_test = np.array(range(len(X_trbs)))
idxs_test = idxs_test[~np.isin(idxs_test, idxs_train)]
X_trbs_train = X_trbs[idxs_train]
X_trbs_test = X_trbs[idxs_test]
# every sample must land in exactly one of the two splits
assert len(X_trbs_train) + len(X_trbs_test) == len(X_trbs)

from torch.utils.data import TensorDataset, DataLoader
# create a latent space for the trbs
batch_size = 2048
train_loader = DataLoader(dataset=TensorDataset(X_trbs_train), batch_size=batch_size, shuffle=True)
# the test loader is not shuffled so predictions line up with idxs_test
test_loader = DataLoader(dataset=TensorDataset(X_trbs_test), batch_size=batch_size, shuffle=False)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# initialize the model
model = ConvVAE().to(device)

# set the seed for training
torch.manual_seed(0); np.random.seed(0)
# set the learning parameters
lr = 0.0005; epochs = 20
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# train the model with dual losses to balance between two objectives
train_losses = []; test_losses = []
for epoch in range(1, epochs + 1):
    train_losses_, test_losses_ = [], []
    for loss_func in [1, 2]:
        train_losses_.append(train(epoch, loss_func))
        test_losses_.append(test(epoch, loss_func))
    train_losses.append(train_losses_)
    test_losses.append(test_losses_)
    
# examine the KLD and BCE loss
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
ax.plot([x[0] for x in train_losses], color='dodgerblue')
ax.plot([x[0] for x in test_losses], color='skyblue', linestyle='--')
ax.set(xlabel='Epochs', ylabel='KLD + BCE Loss')
# examine the length loss
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
ax.plot([x[1] for x in train_losses], color='dodgerblue')
ax.plot([x[1] for x in test_losses], color='skyblue', linestyle='--')
ax.set(xlabel='Epochs', ylabel='TSE (LEN) Loss')# set the seed for training
torch.manual_seed(0); np.random.seed(0)

# set the learning parameters
lr = 0.001; epochs = 40
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# train the model with an integrated loss
train_losses = []; test_losses = []; epochs
for epoch in range(1, epochs + 1):
    train_losses_, test_losses_ = [], []
    for loss_func in [3]:
        train_losses_.append(train(epoch, loss_func))
        test_losses_.append(test(epoch, loss_func))
    train_losses.append(train_losses_)
    test_losses.append(test_losses_)

# examine the integrated loss from finetuning
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
ax.plot([x[0] for x in train_losses], color='dodgerblue')
ax.plot([x[0] for x in test_losses], color='skyblue', linestyle='--')
ax.set(xlabel='Epochs', ylabel='KLD + BCE + TSE Loss')

# retrieve the predictions for the held-out set (the loader is unshuffled, so order matches idxs_test)
recon_lens = []; recon_batchs = []
with torch.no_grad():
    for i, data in enumerate(test_loader):
        data = data[0].to(device)
        (recon_batch, recon_len), _, _ = model(data)
        recon_batchs.extend(recon_batch.clone().detach().cpu().numpy())
        recon_lens.extend(recon_len.clone().detach().cpu().tolist())
# each entry is a one-element list; unwrap it
recon_lens = [x[0] for x in recon_lens]
# retrieve the matching test sequences
trbs_test = pd.Series(trbs.iloc[idxs_test])

true_lens = trbs_test.apply(len).to_numpy()
rounded_lens = pd.Series(recon_lens).apply(round).to_numpy()

def plot_length_agreement(pred_lens, true_lens):
    """Scatter predicted against true length with a dashed identity line."""
    fig, ax = plt.subplots(); ax.grid(False)
    ax.scatter(pred_lens, true_lens, s=10, alpha=0.1, zorder=1, color='dodgerblue')
    # remember the data limits so the identity line does not rescale the axes
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    domain = [min(min(xlim), min(ylim)), max(max(xlim), max(ylim))]
    ax.plot(domain, domain, color='skyblue', linestyle='--', zorder=0)
    ax.set_xlim(*xlim); ax.set_ylim(*ylim)
    ax.set(xlabel='Predicted length', ylabel='True length')
    return fig, ax

# there is an ability to reconstruct proper length
fig, ax = plot_length_agreement(recon_lens, true_lens)
print('Non-Rounded', ss.pearsonr(recon_lens, true_lens))
# with rounding the picture becomes more clear, there is a linear relationship but not an exact one
fig, ax = plot_length_agreement(rounded_lens, true_lens)
# correlate the ROUNDED predictions (previously the unrounded values were reused here)
print('Rounded', ss.pearsonr(rounded_lens, true_lens))

# save the model
torch.save(model.state_dict(), '../models/model.convvae_3aa.trb.torch')
model.eval()

from tqdm import tqdm
# get the encoded dimensions
torch.manual_seed(0); np.random.seed(0)
# create a complete loader or else there is an out of memory error
complete_loader = DataLoader(dataset=TensorDataset(X_trbs), batch_size=batch_size, shuffle=False)
# move through each subset in the complete loader
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(complete_loader):
        data = data[0].to(device)
        enc_out = model.encode(data)
        # sampling centers around the mean so we just use mu
        z_dims_per_batch.append(enc_out[0].clone().detach().cpu().numpy())
z_dims = np.vstack(z_dims_per_batch)
del data, enc_out, z_dims_per_batch

from anndata import AnnData
from scipy.sparse import csr_matrix
adata = AnnData(z_dims, dtype=float)
sc.tl.pca(adata)
sc.pp.neighbors(adata, use_rep='X')
sc.tl.umap(adata)
sc.tl.leiden(adata, flavor='igraph', n_iterations=2)
adata.obs_names = trbs
# check length
adata.obs['LEN'] = adata.obs.index.to_series().apply(len)
sc.pl.umap(adata, color=['LEN'], vmin=12, vmax=20)

# save the current data
adata.write('../outs/adata.trb_3aa.h5ad')

# one-hot lookup: each amino acid maps to a 0/1 vector ordered like `vocab`
map_direct = {aa: [int(aa == other) for other in vocab] for aa in vocab}
def embed_aa(aa):
    """Return a fresh list holding the one-hot encoding of amino acid `aa` (KeyError if not in `vocab`)."""
    return list(map_direct[aa])
# define a function to interpolate the protein
targ_len = 48
def stretch_pep(embedding, targ_len=targ_len):
    """Linearly resample a per-residue embedding to a fixed number of positions.

    Parameters
    ----------
    embedding : array-like of shape (orig_len, n_features)
        One row per residue, one column per feature.
    targ_len : int
        Number of positions in the output. The default is the value `targ_len`
        held when this function was defined (48).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (n_features, targ_len).
    """
    embedding = np.asarray(embedding)
    # get the current protein length
    orig_len, n_features = embedding.shape
    # interpolate() expects (batch, channels, length): features become channels
    tensor = torch.tensor(embedding.T.reshape(1, n_features, orig_len), dtype=torch.float32)
    res = torch.nn.functional.interpolate(tensor, size=targ_len, mode='linear', align_corners=False)[0]
    return res

from Levenshtein import distance as levenshtein
# give examples if we had randomly chosen a "true sequence"
torch.manual_seed(0); np.random.seed(0)
# evaluate on at most 100 distinct test sequences
idxs = np.random.choice(range(len(trbs_test)), size=min(100, len(trbs_test)), replace=False)
# pool of sequences the random baseline is drawn from
rand_pool = trbs_test.to_numpy()
n_rand = 10
# track distances; rows are collected first and turned into a numeric frame after the loop
dist_records = []
for idx in idxs:
    # retrieve the reconstruction (one row per vocab entry) and the predicted length
    recon_batch = recon_batchs[idx]
    recon_len = recon_lens[idx]
    # retrieve the model resolved sequence and truth (neither frame is used further in this block)
    pred = pd.DataFrame(recon_batch, index=vocab)
    true_seq = trbs_test.iloc[idx]
    true = pd.DataFrame(stretch_pep(np.array([embed for embed in map(embed_aa, list(true_seq))])).numpy(), index=vocab)
    # resolve the sequence: shrink the stretched reconstruction back to the predicted length
    # (kept in `pred_len` so the notebook-level `targ_len` used by stretch_pep callers is not overwritten)
    curr_len = recon_batch.shape[1]
    pred_len = max(1, round(recon_len))
    xp = np.arange(curr_len) / (curr_len - 1)
    # a single predicted position has no span to divide by, so sample the start only
    x = np.arange(pred_len) / (pred_len - 1) if pred_len > 1 else np.zeros(1)
    # interpolate the results, one vocab row at a time
    res = np.array([np.interp(x, xp, row) for row in recon_batch])
    # the most probable amino acid at every position spells the predicted sequence
    pred_seq = ''.join(pd.DataFrame(res.T, columns=vocab).idxmax(1))
    dist_from_truth = levenshtein(true_seq, pred_seq)
    # baseline: mean distance to whole randomly drawn test sequences (not their first residue)
    dist_from_rand = sum(levenshtein(np.random.choice(rand_pool), pred_seq) for _ in range(n_rand)) / n_rand
    dist_records.append((dist_from_truth, dist_from_rand))
df_dist = pd.DataFrame(dist_records, columns=['dist_to_truth','dist_to_rand'])
    
# compare prediction accuracy statistically
df_long = df_dist.melt()
fig, ax = plt.subplots(figsize=[2, 4])
ax.grid(False)
# both boxes share everything except the column they show and their colours
box_kws = dict(x='variable', y='value', data=df_long, ax=ax, saturation=1, linewidth=1.5)
sns.boxplot(order=['dist_to_truth'], linecolor='dodgerblue', color='skyblue', **box_kws)
sns.boxplot(order=['dist_to_rand'], linecolor='grey', color='lightgray', **box_kws)
# create artificial lines
# for idx in df_dist.index: ax.plot([0, 1], df_dist.loc[idx], color='k', alpha=0.1, lw=0.5, zorder=0, linestyle='--')
ax.tick_params(axis='x', labelrotation=90)
ax.set_xticklabels(['Truth','Random'])
ax.set_xlim(-1, 2)
ax.set(xlabel='Prediction vs.', ylabel='Levenshtein distance')
# paired (Wilcoxon) and unpaired (Mann-Whitney U) comparisons of the two columns
dist_truth, dist_rand = df_dist['dist_to_truth'], df_dist['dist_to_rand']
print(ss.wilcoxon(dist_truth, dist_rand))
print(ss.mannwhitneyu(dist_truth, dist_rand))

### TRB Trimming 2AAs

In [None]:
import pickle as pkl
# read in the aggregated values
with open('../outs/trbs.2aa.pkl', 'rb') as f: trbs = pkl.load(f)
# de-duplicate in sorted order: iteration order of a set of strings changes between
# interpreter sessions (hash randomization), which would make the seeded split below irreproducible
trbs = pd.Series(sorted(set(trbs))); print(len(trbs))
# embed all of our unique TRBs
with open('../outs/map.trb_to_embed.extended_2aatrim.pkl', 'rb') as f: trb_to_embed = pkl.load(f)
X_trbs = torch.stack([x.to(torch.float32) for x in trbs.map(trb_to_embed)])
# randomly split the data into train (75%) and test (the remainder)
torch.manual_seed(0); np.random.seed(0)
idxs_train = np.random.choice(range(len(X_trbs)), size=round(len(X_trbs)*0.75), replace=False)
idxs_test = np.array(range(len(X_trbs)))
idxs_test = idxs_test[~np.isin(idxs_test, idxs_train)]
# the two index sets must partition the data
assert len(idxs_train) + len(idxs_test) == len(X_trbs)
X_trbs_train = X_trbs[idxs_train]
X_trbs_test = X_trbs[idxs_test]

from torch.utils.data import TensorDataset, DataLoader
# batch the train/test tensors; only the training batches are shuffled
batch_size = 2048
train_loader = DataLoader(TensorDataset(X_trbs_train), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(TensorDataset(X_trbs_test), batch_size=batch_size, shuffle=False)
# run on the GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# initialize the model
model = ConvVAE()
model = model.to(device)

# set the seed for training
torch.manual_seed(0)
np.random.seed(0)
# set the learning parameters
lr, epochs = 0.0005, 20
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# train the model with dual losses to balance between two objectives;
# every epoch stores [loss 1, loss 2] for both the train and the test pass
train_losses, test_losses = [], []
for epoch in range(1, epochs + 1):
    train_losses_, test_losses_ = [], []
    for loss_func in (1, 2):
        # train on this objective first, then evaluate it
        train_loss = train(epoch, loss_func)
        test_loss = test(epoch, loss_func)
        train_losses_.append(train_loss)
        test_losses_.append(test_loss)
    train_losses.append(train_losses_)
    test_losses.append(test_losses_)

# one figure per objective: first the KLD and BCE loss, then the length loss
for col, ylabel in enumerate(['KLD + BCE Loss', 'TSE (LEN) Loss']):
    fig, ax = plt.subplots(figsize=[4, 4])
    ax.grid(False)
    ax.plot([per_epoch[col] for per_epoch in train_losses], color='dodgerblue')
    ax.plot([per_epoch[col] for per_epoch in test_losses], color='skyblue', linestyle='--')
    ax.set(xlabel='Epochs', ylabel=ylabel)
# set the seed for fine-tuning
torch.manual_seed(0)
np.random.seed(0)

# set the learning parameters
lr, epochs = 0.001, 40
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# train the model with an integrated loss; each epoch stores a one-element list
train_losses, test_losses = [], []
loss_func = 3
for epoch in range(1, epochs + 1):
    # train first, then evaluate
    train_losses_ = [train(epoch, loss_func)]
    test_losses_ = [test(epoch, loss_func)]
    train_losses.append(train_losses_)
    test_losses.append(test_losses_)

# examine the integrated loss from finetuning
fig, ax = plt.subplots(figsize=[4, 4])
ax.grid(False)
ax.plot([per_epoch[0] for per_epoch in train_losses], color='dodgerblue')
ax.plot([per_epoch[0] for per_epoch in test_losses], color='skyblue', linestyle='--')
ax.set(xlabel='Epochs', ylabel='KLD + BCE + TSE Loss')

# retrieve the predictions for the held-out sequences
# NOTE(review): model.eval() is only called further down, after saving - confirm the
# model is already in inference mode at this point
recon_batchs, recon_lens = [], []
with torch.no_grad():
    for i, data in enumerate(test_loader):
        data = data[0].to(device)
        (recon_batch, recon_len), _, _ = model(data)
        # one reconstructed array per sequence
        recon_batchs.extend(recon_batch.detach().cpu().numpy())
        # each predicted length comes back as a one-element list; keep the scalar
        recon_lens.extend(row[0] for row in recon_len.detach().cpu().tolist())
# retrieve the sequences behind the test indices (same order as the unshuffled test_loader)
trbs_test = trbs.iloc[idxs_test]

# there is an ability to reconstruct proper length
true_lens = trbs_test.apply(len)
fig, ax = plt.subplots(); ax.grid(False)
ax.scatter(recon_lens, true_lens, s=10, alpha=0.1, zorder=1, color='dodgerblue')
# remember the data-driven limits, draw the identity line across both axes, then restore the limits
xlim, ylim = ax.get_xlim(), ax.get_ylim()
domain = [min(min(xlim), min(ylim)), max(max(xlim), max(ylim))]
ax.plot(domain, domain, color='skyblue', linestyle='--', zorder=0)
ax.set_xlim(*xlim); ax.set_ylim(*ylim)
ax.set(xlabel='Predicted length', ylabel='True length')
print('Non-Rounded', ss.pearsonr(recon_lens, true_lens))
# with rounding the picture becomes more clear, there is a linear relationship but not an exact one
# round once so that the scatter and the reported statistic are computed on the same values
recon_lens_rounded = pd.Series(recon_lens).apply(round)
fig, ax = plt.subplots(); ax.grid(False)
ax.scatter(recon_lens_rounded, true_lens, s=10, alpha=0.1, zorder=1, color='dodgerblue')
xlim, ylim = ax.get_xlim(), ax.get_ylim()
domain = [min(min(xlim), min(ylim)), max(max(xlim), max(ylim))]
ax.plot(domain, domain, color='skyblue', linestyle='--', zorder=0)
ax.set_xlim(*xlim); ax.set_ylim(*ylim)
ax.set(xlabel='Predicted length', ylabel='True length')
# correlate the rounded predictions so the number matches its 'Rounded' label
print('Rounded', ss.pearsonr(recon_lens_rounded, true_lens))

# persist the trained weights, then put the model into inference mode
torch.save(model.state_dict(), '../models/model.convvae_2aa.trb.torch')
model.eval()

from tqdm import tqdm
# encode every sequence (train and test) into the latent space
torch.manual_seed(0)
np.random.seed(0)
# go through the data in batches; encoding everything at once runs out of memory
complete_loader = DataLoader(TensorDataset(X_trbs), batch_size=batch_size, shuffle=False)
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(complete_loader):
        data = data[0].to(device)
        enc_out = model.encode(data)
        # keep only the first encoder output (mu) instead of sampling around it
        z_dims_per_batch.append(enc_out[0].detach().cpu().numpy())
# one row per sequence, in the same order as X_trbs (the loader is not shuffled)
z_dims = np.vstack(z_dims_per_batch)
del data, enc_out, z_dims_per_batch

from anndata import AnnData
from scipy.sparse import csr_matrix
# wrap the latent coordinates as observations x latent dimensions; the array is cast
# up-front because the `dtype` argument of AnnData is deprecated
adata = AnnData(z_dims.astype(float))
# PCA is stored on the object, but the neighbor graph is built directly on the
# latent coordinates (use_rep='X')
sc.tl.pca(adata)
sc.pp.neighbors(adata, use_rep='X')
sc.tl.umap(adata)
sc.tl.leiden(adata, flavor='igraph', n_iterations=2)
# name each observation after its sequence (same order as X_trbs / z_dims)
adata.obs_names = trbs
# check length
adata.obs['LEN'] = adata.obs_names.str.len().to_numpy()
sc.pl.umap(adata, color=['LEN'], vmin=12, vmax=20)

# save the current data
adata.write('../outs/adata.trb_2aa.h5ad')

# one-hot lookup: each amino acid maps to a 0/1 vector ordered like `vocab`
map_direct = {aa: [int(aa == other) for other in vocab] for aa in vocab}
def embed_aa(aa):
    """Return a fresh list holding the one-hot encoding of amino acid `aa` (KeyError if not in `vocab`)."""
    return list(map_direct[aa])
# define a function to interpolate the protein
targ_len = 48
def stretch_pep(embedding, targ_len=targ_len):
    """Linearly resample a per-residue embedding to a fixed number of positions.

    Parameters
    ----------
    embedding : array-like of shape (orig_len, n_features)
        One row per residue, one column per feature.
    targ_len : int
        Number of positions in the output. The default is the value `targ_len`
        held when this function was defined (48).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (n_features, targ_len).
    """
    embedding = np.asarray(embedding)
    # get the current protein length
    orig_len, n_features = embedding.shape
    # interpolate() expects (batch, channels, length): features become channels
    tensor = torch.tensor(embedding.T.reshape(1, n_features, orig_len), dtype=torch.float32)
    res = torch.nn.functional.interpolate(tensor, size=targ_len, mode='linear', align_corners=False)[0]
    return res

from Levenshtein import distance as levenshtein
# give examples if we had randomly chosen a "true sequence"
torch.manual_seed(0); np.random.seed(0)
# evaluate on at most 100 distinct test sequences
idxs = np.random.choice(range(len(trbs_test)), size=min(100, len(trbs_test)), replace=False)
# pool of sequences the random baseline is drawn from
rand_pool = trbs_test.to_numpy()
n_rand = 10
# track distances; rows are collected first and turned into a numeric frame after the loop
dist_records = []
for idx in idxs:
    # retrieve the reconstruction (one row per vocab entry) and the predicted length
    recon_batch = recon_batchs[idx]
    recon_len = recon_lens[idx]
    # retrieve the model resolved sequence and truth (neither frame is used further in this block)
    pred = pd.DataFrame(recon_batch, index=vocab)
    true_seq = trbs_test.iloc[idx]
    true = pd.DataFrame(stretch_pep(np.array([embed for embed in map(embed_aa, list(true_seq))])).numpy(), index=vocab)
    # resolve the sequence: shrink the stretched reconstruction back to the predicted length
    # (kept in `pred_len` so the notebook-level `targ_len` used by stretch_pep callers is not overwritten)
    curr_len = recon_batch.shape[1]
    pred_len = max(1, round(recon_len))
    xp = np.arange(curr_len) / (curr_len - 1)
    # a single predicted position has no span to divide by, so sample the start only
    x = np.arange(pred_len) / (pred_len - 1) if pred_len > 1 else np.zeros(1)
    # interpolate the results, one vocab row at a time
    res = np.array([np.interp(x, xp, row) for row in recon_batch])
    # the most probable amino acid at every position spells the predicted sequence
    pred_seq = ''.join(pd.DataFrame(res.T, columns=vocab).idxmax(1))
    dist_from_truth = levenshtein(true_seq, pred_seq)
    # baseline: mean distance to whole randomly drawn test sequences (not their first residue)
    dist_from_rand = sum(levenshtein(np.random.choice(rand_pool), pred_seq) for _ in range(n_rand)) / n_rand
    dist_records.append((dist_from_truth, dist_from_rand))
df_dist = pd.DataFrame(dist_records, columns=['dist_to_truth','dist_to_rand'])
    
# compare prediction accuracy statistically
df_long = df_dist.melt()
fig, ax = plt.subplots(figsize=[2, 4])
ax.grid(False)
# both boxes share everything except the column they show and their colours
box_kws = dict(x='variable', y='value', data=df_long, ax=ax, saturation=1, linewidth=1.5)
sns.boxplot(order=['dist_to_truth'], linecolor='dodgerblue', color='skyblue', **box_kws)
sns.boxplot(order=['dist_to_rand'], linecolor='grey', color='lightgray', **box_kws)
# create artificial lines
# for idx in df_dist.index: ax.plot([0, 1], df_dist.loc[idx], color='k', alpha=0.1, lw=0.5, zorder=0, linestyle='--')
ax.tick_params(axis='x', labelrotation=90)
ax.set_xticklabels(['Truth','Random'])
ax.set_xlim(-1, 2)
ax.set(xlabel='Prediction vs.', ylabel='Levenshtein distance')
# paired (Wilcoxon) and unpaired (Mann-Whitney U) comparisons of the two columns
dist_truth, dist_rand = df_dist['dist_to_truth'], df_dist['dist_to_rand']
print(ss.wilcoxon(dist_truth, dist_rand))
print(ss.mannwhitneyu(dist_truth, dist_rand))

### TRB Trimming 1AAs

In [None]:
import pickle as pkl
# read in the aggregated values
with open('../outs/trbs.1aa.pkl', 'rb') as f: trbs = pkl.load(f)
# de-duplicate in sorted order: iteration order of a set of strings changes between
# interpreter sessions (hash randomization), which would make the seeded split below irreproducible
trbs = pd.Series(sorted(set(trbs))); print(len(trbs))
# embed all of our unique TRBs
with open('../outs/map.trb_to_embed.extended_1aatrim.pkl', 'rb') as f: trb_to_embed = pkl.load(f)
X_trbs = torch.stack([x.to(torch.float32) for x in trbs.map(trb_to_embed)])
# randomly split the data into train (75%) and test (the remainder)
torch.manual_seed(0); np.random.seed(0)
idxs_train = np.random.choice(range(len(X_trbs)), size=round(len(X_trbs)*0.75), replace=False)
idxs_test = np.array(range(len(X_trbs)))
idxs_test = idxs_test[~np.isin(idxs_test, idxs_train)]
# the two index sets must partition the data
assert len(idxs_train) + len(idxs_test) == len(X_trbs)
X_trbs_train = X_trbs[idxs_train]
X_trbs_test = X_trbs[idxs_test]

from torch.utils.data import TensorDataset, DataLoader
# batch the train/test tensors; only the training batches are shuffled
batch_size = 2048
train_loader = DataLoader(TensorDataset(X_trbs_train), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(TensorDataset(X_trbs_test), batch_size=batch_size, shuffle=False)
# run on the GPU when one is available, otherwise on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# initialize the model
model = ConvVAE()
model = model.to(device)

# set the seed for training
torch.manual_seed(0)
np.random.seed(0)
# set the learning parameters
lr, epochs = 0.0005, 20
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# train the model with dual losses to balance between two objectives;
# every epoch stores [loss 1, loss 2] for both the train and the test pass
train_losses, test_losses = [], []
for epoch in range(1, epochs + 1):
    train_losses_, test_losses_ = [], []
    for loss_func in (1, 2):
        # train on this objective first, then evaluate it
        train_loss = train(epoch, loss_func)
        test_loss = test(epoch, loss_func)
        train_losses_.append(train_loss)
        test_losses_.append(test_loss)
    train_losses.append(train_losses_)
    test_losses.append(test_losses_)

# one figure per objective: first the KLD and BCE loss, then the length loss
for col, ylabel in enumerate(['KLD + BCE Loss', 'TSE (LEN) Loss']):
    fig, ax = plt.subplots(figsize=[4, 4])
    ax.grid(False)
    ax.plot([per_epoch[col] for per_epoch in train_losses], color='dodgerblue')
    ax.plot([per_epoch[col] for per_epoch in test_losses], color='skyblue', linestyle='--')
    ax.set(xlabel='Epochs', ylabel=ylabel)
# set the seed for fine-tuning
torch.manual_seed(0)
np.random.seed(0)

# set the learning parameters
lr, epochs = 0.001, 40
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# train the model with an integrated loss; each epoch stores a one-element list
train_losses, test_losses = [], []
loss_func = 3
for epoch in range(1, epochs + 1):
    # train first, then evaluate
    train_losses_ = [train(epoch, loss_func)]
    test_losses_ = [test(epoch, loss_func)]
    train_losses.append(train_losses_)
    test_losses.append(test_losses_)

# examine the integrated loss from finetuning
fig, ax = plt.subplots(figsize=[4, 4])
ax.grid(False)
ax.plot([per_epoch[0] for per_epoch in train_losses], color='dodgerblue')
ax.plot([per_epoch[0] for per_epoch in test_losses], color='skyblue', linestyle='--')
ax.set(xlabel='Epochs', ylabel='KLD + BCE + TSE Loss')

# retrieve the predictions for the held-out sequences
# NOTE(review): model.eval() is only called further down, after saving - confirm the
# model is already in inference mode at this point
recon_batchs, recon_lens = [], []
with torch.no_grad():
    for i, data in enumerate(test_loader):
        data = data[0].to(device)
        (recon_batch, recon_len), _, _ = model(data)
        # one reconstructed array per sequence
        recon_batchs.extend(recon_batch.detach().cpu().numpy())
        # each predicted length comes back as a one-element list; keep the scalar
        recon_lens.extend(row[0] for row in recon_len.detach().cpu().tolist())
# retrieve the sequences behind the test indices (same order as the unshuffled test_loader)
trbs_test = trbs.iloc[idxs_test]

# there is an ability to reconstruct proper length
true_lens = trbs_test.apply(len)
fig, ax = plt.subplots(); ax.grid(False)
ax.scatter(recon_lens, true_lens, s=10, alpha=0.1, zorder=1, color='dodgerblue')
# remember the data-driven limits, draw the identity line across both axes, then restore the limits
xlim, ylim = ax.get_xlim(), ax.get_ylim()
domain = [min(min(xlim), min(ylim)), max(max(xlim), max(ylim))]
ax.plot(domain, domain, color='skyblue', linestyle='--', zorder=0)
ax.set_xlim(*xlim); ax.set_ylim(*ylim)
ax.set(xlabel='Predicted length', ylabel='True length')
print('Non-Rounded', ss.pearsonr(recon_lens, true_lens))
# with rounding the picture becomes more clear, there is a linear relationship but not an exact one
# round once so that the scatter and the reported statistic are computed on the same values
recon_lens_rounded = pd.Series(recon_lens).apply(round)
fig, ax = plt.subplots(); ax.grid(False)
ax.scatter(recon_lens_rounded, true_lens, s=10, alpha=0.1, zorder=1, color='dodgerblue')
xlim, ylim = ax.get_xlim(), ax.get_ylim()
domain = [min(min(xlim), min(ylim)), max(max(xlim), max(ylim))]
ax.plot(domain, domain, color='skyblue', linestyle='--', zorder=0)
ax.set_xlim(*xlim); ax.set_ylim(*ylim)
ax.set(xlabel='Predicted length', ylabel='True length')
# correlate the rounded predictions so the number matches its 'Rounded' label
print('Rounded', ss.pearsonr(recon_lens_rounded, true_lens))

# persist the trained weights, then put the model into inference mode
torch.save(model.state_dict(), '../models/model.convvae_1aa.trb.torch')
model.eval()

from tqdm import tqdm
# encode every sequence (train and test) into the latent space
torch.manual_seed(0)
np.random.seed(0)
# go through the data in batches; encoding everything at once runs out of memory
complete_loader = DataLoader(TensorDataset(X_trbs), batch_size=batch_size, shuffle=False)
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(complete_loader):
        data = data[0].to(device)
        enc_out = model.encode(data)
        # keep only the first encoder output (mu) instead of sampling around it
        z_dims_per_batch.append(enc_out[0].detach().cpu().numpy())
# one row per sequence, in the same order as X_trbs (the loader is not shuffled)
z_dims = np.vstack(z_dims_per_batch)
del data, enc_out, z_dims_per_batch

from anndata import AnnData
from scipy.sparse import csr_matrix
# wrap the latent coordinates as observations x latent dimensions; the array is cast
# up-front because the `dtype` argument of AnnData is deprecated
adata = AnnData(z_dims.astype(float))
# PCA is stored on the object, but the neighbor graph is built directly on the
# latent coordinates (use_rep='X')
sc.tl.pca(adata)
sc.pp.neighbors(adata, use_rep='X')
sc.tl.umap(adata)
sc.tl.leiden(adata, flavor='igraph', n_iterations=2)
# name each observation after its sequence (same order as X_trbs / z_dims)
adata.obs_names = trbs
# check length
adata.obs['LEN'] = adata.obs_names.str.len().to_numpy()
sc.pl.umap(adata, color=['LEN'], vmin=12, vmax=20)

# save the current data
adata.write('../outs/adata.trb_1aa.h5ad')

# one-hot lookup: each amino acid maps to a 0/1 vector ordered like `vocab`
map_direct = {aa: [int(aa == other) for other in vocab] for aa in vocab}
def embed_aa(aa):
    """Return a fresh list holding the one-hot encoding of amino acid `aa` (KeyError if not in `vocab`)."""
    return list(map_direct[aa])
# define a function to interpolate the protein
targ_len = 48
def stretch_pep(embedding, targ_len=targ_len):
    """Linearly resample a per-residue embedding to a fixed number of positions.

    Parameters
    ----------
    embedding : array-like of shape (orig_len, n_features)
        One row per residue, one column per feature.
    targ_len : int
        Number of positions in the output. The default is the value `targ_len`
        held when this function was defined (48).

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (n_features, targ_len).
    """
    embedding = np.asarray(embedding)
    # get the current protein length
    orig_len, n_features = embedding.shape
    # interpolate() expects (batch, channels, length): features become channels
    tensor = torch.tensor(embedding.T.reshape(1, n_features, orig_len), dtype=torch.float32)
    res = torch.nn.functional.interpolate(tensor, size=targ_len, mode='linear', align_corners=False)[0]
    return res

from Levenshtein import distance as levenshtein
# give examples if we had randomly chosen a "true sequence"
torch.manual_seed(0); np.random.seed(0)
# evaluate on at most 100 distinct test sequences
idxs = np.random.choice(range(len(trbs_test)), size=min(100, len(trbs_test)), replace=False)
# pool of sequences the random baseline is drawn from
rand_pool = trbs_test.to_numpy()
n_rand = 10
# track distances; rows are collected first and turned into a numeric frame after the loop
dist_records = []
for idx in idxs:
    # retrieve the reconstruction (one row per vocab entry) and the predicted length
    recon_batch = recon_batchs[idx]
    recon_len = recon_lens[idx]
    # retrieve the model resolved sequence and truth (neither frame is used further in this block)
    pred = pd.DataFrame(recon_batch, index=vocab)
    true_seq = trbs_test.iloc[idx]
    true = pd.DataFrame(stretch_pep(np.array([embed for embed in map(embed_aa, list(true_seq))])).numpy(), index=vocab)
    # resolve the sequence: shrink the stretched reconstruction back to the predicted length
    # (kept in `pred_len` so the notebook-level `targ_len` used by stretch_pep callers is not overwritten)
    curr_len = recon_batch.shape[1]
    pred_len = max(1, round(recon_len))
    xp = np.arange(curr_len) / (curr_len - 1)
    # a single predicted position has no span to divide by, so sample the start only
    x = np.arange(pred_len) / (pred_len - 1) if pred_len > 1 else np.zeros(1)
    # interpolate the results, one vocab row at a time
    res = np.array([np.interp(x, xp, row) for row in recon_batch])
    # the most probable amino acid at every position spells the predicted sequence
    pred_seq = ''.join(pd.DataFrame(res.T, columns=vocab).idxmax(1))
    dist_from_truth = levenshtein(true_seq, pred_seq)
    # baseline: mean distance to whole randomly drawn test sequences (not their first residue)
    dist_from_rand = sum(levenshtein(np.random.choice(rand_pool), pred_seq) for _ in range(n_rand)) / n_rand
    dist_records.append((dist_from_truth, dist_from_rand))
df_dist = pd.DataFrame(dist_records, columns=['dist_to_truth','dist_to_rand'])
    
# compare prediction accuracy statistically
df_long = df_dist.melt()
fig, ax = plt.subplots(figsize=[2, 4])
ax.grid(False)
# both boxes share everything except the column they show and their colours
box_kws = dict(x='variable', y='value', data=df_long, ax=ax, saturation=1, linewidth=1.5)
sns.boxplot(order=['dist_to_truth'], linecolor='dodgerblue', color='skyblue', **box_kws)
sns.boxplot(order=['dist_to_rand'], linecolor='grey', color='lightgray', **box_kws)
# create artificial lines
# for idx in df_dist.index: ax.plot([0, 1], df_dist.loc[idx], color='k', alpha=0.1, lw=0.5, zorder=0, linestyle='--')
ax.tick_params(axis='x', labelrotation=90)
ax.set_xticklabels(['Truth','Random'])
ax.set_xlim(-1, 2)
ax.set(xlabel='Prediction vs.', ylabel='Levenshtein distance')
# paired (Wilcoxon) and unpaired (Mann-Whitney U) comparisons of the two columns
dist_truth, dist_rand = df_dist['dist_to_truth'], df_dist['dist_to_rand']
print(ss.wilcoxon(dist_truth, dist_rand))
print(ss.mannwhitneyu(dist_truth, dist_rand))

### COVID-19 Mapping

In [None]:
import pickle as pkl
# read in the aggregated values
with open('../external_data/dbv2.trbs.pkl', 'rb') as f:
    trbs = pd.Series(pkl.load(f))
# trim the sequences: cut k residues from both ends (k = 4..1) and
# drop any sequence that would be left empty
trbs4, trbs3, trbs2, trbs1 = (
    [seq[k:-k] for seq in trbs if len(seq[k:-k]) > 0]
    for k in (4, 3, 2, 1)
)

In [None]:
# process the TRBs: build and save one sequence -> embedding map per trim width
# the stretched width is pinned to the 48 positions the models work with instead of
# being read from the notebook-level `targ_len`, which other cells may reassign
embed_len = 48
for trim_len, seqs in [(4, trbs4), (3, trbs3), (2, trbs2), (1, trbs1)]:
    trb_to_embed = {}
    for sequence in tqdm(seqs):
        # repeated sequences would only recompute the same entry
        if sequence in trb_to_embed:
            continue
        # one-hot encode every residue, then stretch to the fixed width
        embedding = np.array([embed for embed in map(embed_aa, list(sequence))])
        embedding = stretch_pep(embedding, targ_len=embed_len)
        trb_to_embed[sequence] = embedding
    # save the embedding maps
    with open(f'../outs/mapv2.trb_to_embed.extended_{trim_len}aatrim.pkl', 'wb') as f: pkl.dump(trb_to_embed, f)

In [None]:
# write the TRBs: one pickle per trim width, in the order 4, 3, 2, 1
trimmed_by_path = {
    '../outs/trbsv2.4aa.pkl': trbs4,
    '../outs/trbsv2.3aa.pkl': trbs3,
    '../outs/trbsv2.2aa.pkl': trbs2,
    '../outs/trbsv2.1aa.pkl': trbs1,
}
for path, seqs in trimmed_by_path.items():
    with open(path, 'wb') as f:
        pkl.dump(seqs, f)

In [None]:
# read in the aggregated values
with open('../outs/trbsv2.1aa.pkl', 'rb') as f: trbs = pkl.load(f)
# sorted de-duplication keeps the row order of the exported table stable between runs
# (iteration order of a set of strings is not)
trbs = pd.Series(sorted(set(trbs)))
with open('../outs/mapv2.trb_to_embed.extended_1aatrim.pkl', 'rb') as f: trb_to_embed = pkl.load(f)
X_trbs = torch.stack([x.to(torch.float32) for x in trbs.map(trb_to_embed)])
# initialize the model
model = ConvVAE().to(device)
# map_location lets weights that were saved on a GPU load onto whichever device is in use now
model.load_state_dict(torch.load('../models/model.convvae_1aa.trb.torch', map_location=device, weights_only=True))
model.eval()

# get the encoded dimensions
torch.manual_seed(0); np.random.seed(0)
complete_loader = DataLoader(dataset=TensorDataset(X_trbs), batch_size=batch_size, shuffle=False)
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(complete_loader):
        data = data[0].to(device)
        enc_out = model.encode(data)
        # sampling centers around the mean so we just use mu
        z_dims_per_batch.append(enc_out[0].clone().detach().cpu().numpy())
z_dims = np.vstack(z_dims_per_batch)
del data, enc_out, z_dims_per_batch
# one row of latent coordinates per unique sequence
pd.DataFrame(z_dims, index=trbs).to_csv('../outs/su22.trb_1aatrim.csv')

In [None]:
# read in the aggregated values
with open('../outs/trbsv2.2aa.pkl', 'rb') as f: trbs = pkl.load(f)
# sorted de-duplication keeps the row order of the exported table stable between runs
# (iteration order of a set of strings is not)
trbs = pd.Series(sorted(set(trbs)))
with open('../outs/mapv2.trb_to_embed.extended_2aatrim.pkl', 'rb') as f: trb_to_embed = pkl.load(f)
X_trbs = torch.stack([x.to(torch.float32) for x in trbs.map(trb_to_embed)])
# initialize the model
model = ConvVAE().to(device)
# map_location lets weights that were saved on a GPU load onto whichever device is in use now
model.load_state_dict(torch.load('../models/model.convvae_2aa.trb.torch', map_location=device, weights_only=True))
model.eval()

# get the encoded dimensions
torch.manual_seed(0); np.random.seed(0)
complete_loader = DataLoader(dataset=TensorDataset(X_trbs), batch_size=batch_size, shuffle=False)
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(complete_loader):
        data = data[0].to(device)
        enc_out = model.encode(data)
        # sampling centers around the mean so we just use mu
        z_dims_per_batch.append(enc_out[0].clone().detach().cpu().numpy())
z_dims = np.vstack(z_dims_per_batch)
del data, enc_out, z_dims_per_batch
# one row of latent coordinates per unique sequence
pd.DataFrame(z_dims, index=trbs).to_csv('../outs/su22.trb_2aatrim.csv')

In [None]:
# read in the aggregated values
with open('../outs/trbsv2.3aa.pkl', 'rb') as f: trbs = pkl.load(f)
# sorted de-duplication keeps the row order of the exported table stable between runs
# (iteration order of a set of strings is not)
trbs = pd.Series(sorted(set(trbs)))
with open('../outs/mapv2.trb_to_embed.extended_3aatrim.pkl', 'rb') as f: trb_to_embed = pkl.load(f)
X_trbs = torch.stack([x.to(torch.float32) for x in trbs.map(trb_to_embed)])
# initialize the model
model = ConvVAE().to(device)
# map_location lets weights that were saved on a GPU load onto whichever device is in use now
model.load_state_dict(torch.load('../models/model.convvae_3aa.trb.torch', map_location=device, weights_only=True))
model.eval()

# get the encoded dimensions
torch.manual_seed(0); np.random.seed(0)
complete_loader = DataLoader(dataset=TensorDataset(X_trbs), batch_size=batch_size, shuffle=False)
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(complete_loader):
        data = data[0].to(device)
        enc_out = model.encode(data)
        # sampling centers around the mean so we just use mu
        z_dims_per_batch.append(enc_out[0].clone().detach().cpu().numpy())
z_dims = np.vstack(z_dims_per_batch)
del data, enc_out, z_dims_per_batch
# one row of latent coordinates per unique sequence
pd.DataFrame(z_dims, index=trbs).to_csv('../outs/su22.trb_3aatrim.csv')

In [None]:
# read in the aggregated values
with open('../outs/trbsv2.4aa.pkl', 'rb') as f: trbs = pkl.load(f)
# sorted de-duplication keeps the row order of the exported table stable between runs
# (iteration order of a set of strings is not)
trbs = pd.Series(sorted(set(trbs)))
with open('../outs/mapv2.trb_to_embed.extended_4aatrim.pkl', 'rb') as f: trb_to_embed = pkl.load(f)
X_trbs = torch.stack([x.to(torch.float32) for x in trbs.map(trb_to_embed)])
# initialize the model
model = ConvVAE().to(device)
# map_location lets weights that were saved on a GPU load onto whichever device is in use now
model.load_state_dict(torch.load('../models/model.convvae_4aa.trb.torch', map_location=device, weights_only=True))
model.eval()

# get the encoded dimensions
torch.manual_seed(0); np.random.seed(0)
complete_loader = DataLoader(dataset=TensorDataset(X_trbs), batch_size=batch_size, shuffle=False)
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(complete_loader):
        data = data[0].to(device)
        enc_out = model.encode(data)
        # sampling centers around the mean so we just use mu
        z_dims_per_batch.append(enc_out[0].clone().detach().cpu().numpy())
z_dims = np.vstack(z_dims_per_batch)
del data, enc_out, z_dims_per_batch
# one row of latent coordinates per unique sequence
pd.DataFrame(z_dims, index=trbs).to_csv('../outs/su22.trb_4aatrim.csv')

### Gathering Trimmed Data for Modeling

In [None]:
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import auc, roc_curve, precision_recall_curve, f1_score, balanced_accuracy_score, accuracy_score
import warnings
# silence every warning raised from here on
warnings.filterwarnings('ignore')

# read in the pickled data
with open('../external_data/results.tcr.pkl', 'rb') as f:
    results_tcr = pkl.load(f)
# define the minimum number of cells
min_cells = 2
# derive annotations: 'batch_info' is split on ':' into batch, subbatch and sample columns
su_tcr = results_tcr['SU_CELL2022_COVID19']
su_tcr[['batch','subbatch','sample']] = su_tcr['batch_info'].str.split(':', expand=True)

In [None]:
from sklearn.linear_model import LogisticRegression
# define a function to interrogate the data
def interrogate_with_globals():
    """Score a logistic-regression classifier over repeated stratified shuffle splits.

    Reads the notebook-level globals ``X1``/``y1`` (data the model is fit on),
    ``X2``/``y2`` (data the model is scored on) and ``pred_on_all``. For each of
    10 stratified shuffle splits of ``(X1, y1)`` a ``LogisticRegression`` is fit
    on the split's training rows and scored on ``X2``/``y2``: at the split's
    held-out positions, or on every row of ``X2`` when ``pred_on_all`` is truthy.

    Returns
    -------
    pd.DataFrame
        One row per split with the columns ``auroc``, ``auprc``, ``f1_score``
        and ``balacc``.
    """
    # one (auroc, auprc, f1, balanced accuracy) record per split
    records = []
    # train in a stratified shuffled manner; the fixed random_state keeps the splits reproducible
    skf = StratifiedShuffleSplit(n_splits=10, random_state=0, test_size=1/4)
    for idxs_train, idxs_test in skf.split(X1, y1):
        # fit the logistic regression model using Dataset #1
        clf = LogisticRegression().fit(X1.iloc[idxs_train], y1.iloc[idxs_train])

        # predict on Dataset #2 correcting to all indices if requested
        if pred_on_all:
            idxs_test = range(X2.shape[0])
        # probability of the class labelled 1, flattened to one value per row
        proba = clf.predict_proba(X2.iloc[idxs_test])[:, clf.classes_ == 1][:, 0]
        # binarize into categorical predictions at the 0.50 threshold
        proba_bin = 1 * (proba >= 0.50)
        # retrieve the associated ground truth
        truth = y2.iloc[idxs_test]

        # compute subsequent AUROC and AUPRC related metrics
        fpr, tpr, _ = roc_curve(truth, proba)
        pre, rec, _ = precision_recall_curve(truth, proba)
        # save the relevant statistics
        records.append((auc(fpr, tpr), auc(rec, pre),
                        f1_score(truth, proba_bin, average='binary'),
                        balanced_accuracy_score(truth, proba_bin)))
    df_stat = pd.DataFrame(records, columns=['auroc','auprc','f1_score','balacc'])
    return df_stat

In [None]:
# helpers shared by every trim length
def _mean_embedding_per_tag(cells, group_cols, trim_len, known_trbs, mean_embedding, min_cells):
    """Average the TRB embeddings of all cells sharing a tag.

    Parameters
    ----------
    cells : pd.DataFrame
        Per-cell table holding the columns in `group_cols` plus 'TRB'.
    group_cols : list of str
        Columns joined with ':' to build the tag of each cell.
    trim_len : int
        Number of characters stripped from each side of the TRB string.
    known_trbs : index-like
        Trimmed sequences that have an embedding; cells with any other sequence are dropped.
    mean_embedding : callable
        Receives the Series of trimmed sequences of one tag (duplicates kept, so a sequence
        carried by several cells counts several times) and returns their mean embedding.
    min_cells : int
        Tags with fewer remaining cells are discarded.

    Returns
    -------
    pd.DataFrame
        One row per kept tag (the tag is the index), one column per embedding dimension.
    """
    data = cells[group_cols + ['TRB']].astype(str).copy()
    data['tag'] = data[group_cols].agg(':'.join, axis=1)
    data['TRB'] = data['TRB'].str.slice(trim_len, -trim_len)
    # filter the data more harshly because less assured of quality; a boolean row filter is
    # used instead of chained assignment, which does not modify the frame under copy-on-write
    data = data.loc[data['TRB'].isin(known_trbs)]
    # keep only the tags that have at least min_cells cells
    counts = data['tag'].value_counts()
    tags = counts.index[counts >= min_cells]
    # compile the Xs
    Xs = [pd.Series(mean_embedding(data.loc[data['tag'] == tag, 'TRB']), name=tag)
          for tag in tqdm(tags)]
    return pd.concat(Xs, axis=1).T


def _cd8_labels(X):
    """Return a boolean Series, indexed like X, that is True where the tag contains ':CD8'."""
    return pd.Series(X.index.str.contains(':CD8'), index=X.index)


def _evaluate(X_train, X_test, on_all):
    """Run interrogate_with_globals() for one train/test pairing and return its statistics.

    interrogate_with_globals() takes no arguments and reads X1, y1, X2, y2 and pred_on_all
    from the notebook namespace, so they are assigned as globals right before the call.
    """
    global X1, y1, X2, y2, pred_on_all
    X1, X2 = X_train.copy(), X_test.copy()
    # labels are always derived from the index of the frame they belong to
    y1, y2 = _cd8_labels(X1), _cd8_labels(X2)
    pred_on_all = on_all
    return interrogate_with_globals()


# loop through each trim length
for trim_len in range(1, 4+1, 1):
    # read in the data
    a_trb = sc.read_h5ad(f'../outs/adata.trb_{trim_len}aa.h5ad')
    trb_covid = pd.read_csv(f'../outs/su22.trb_{trim_len}aatrim.csv', index_col=0)
    # compile data: atlas embeddings plus the COVID-19 embeddings, the latter taking
    # precedence for sequences present in both
    trbs_atlas = pd.DataFrame(a_trb.X, index=a_trb.obs.index)
    trbs_atlas = trbs_atlas.loc[~trbs_atlas.index.isin(trb_covid.index)]
    trb_covid.columns = trbs_atlas.columns
    trbs_X = pd.concat([trbs_atlas, trb_covid], axis=0)
    assert trbs_X.index.is_unique

    # the two ways of averaging embeddings: straight from the atlas AnnData object, or from
    # the combined atlas + COVID-19 table (defaults bind the objects of this iteration)
    atlas_mean = lambda seqs, a_trb=a_trb: a_trb[seqs].X.mean(0)
    combined_mean = lambda seqs, trbs_X=trbs_X: trbs_X.loc[seqs].mean(0)

    # fetal dataset, restricted to the T cell clusters of interest
    clusters = ['CD8+T','TREG','CD4+T']
    fetal_cells = results_tcr['SUO_SCIENCE2022_FETAL']
    fetal_cells = fetal_cells.loc[fetal_cells['celltype_annotation'].isin(clusters)]
    og_trb_suo2022_X = _mean_embedding_per_tag(
        fetal_cells, ['donor','celltype_annotation'], trim_len,
        a_trb.obs.index, atlas_mean, min_cells)
    # covid-19 and healthy donors
    og_trb_su2022_X = _mean_embedding_per_tag(
        results_tcr['SU_CELL2022_COVID19'], ['sample','TcellType'], trim_len,
        trbs_X.index, combined_mean, min_cells)
    # pan-cancer types
    og_trb_zheng2021_X = _mean_embedding_per_tag(
        results_tcr['ZHENG_SCIENCE2021_PANCAN'], ['patient','TcellType'], trim_len,
        a_trb.obs.index, atlas_mean, min_cells)
    # "adult" is the concatenation of the covid-19 and pan-cancer datasets
    og_trb_adult_X = pd.concat([og_trb_su2022_X, og_trb_zheng2021_X], axis=0)

    fetal, covid = og_trb_suo2022_X, og_trb_su2022_X
    tumor, adult = og_trb_zheng2021_X, og_trb_adult_X
    # keys without a '2' are SELF PREDICTION (scored on the held-out quarter of each split);
    # 'a2b' keys are CROSS PREDICTION (fit on splits of a, scored on all of b); the key order
    # is kept as-is because it determines the hue order of the downstream plots
    df_stats = {'adult2fetal':_evaluate(adult, fetal, True),
                'covid2fetal':_evaluate(covid, fetal, True),
                'covid2tumor':_evaluate(covid, tumor, True),
                'covid':_evaluate(covid, covid, False),
                'tumor2covid':_evaluate(tumor, covid, True),
                'tumor':_evaluate(tumor, tumor, False),
                'tumor2fetal':_evaluate(tumor, fetal, True),
                'fetal2adult':_evaluate(fetal, adult, True),
                'fetal2covid':_evaluate(fetal, covid, True),
                'fetal2tumor':_evaluate(fetal, tumor, True),
                'fetal':_evaluate(fetal, fetal, False),}
    # save all of the values
    with open(f'../outs/250429v2_cd4vscd8_logisticregression.{trim_len}aa_stripped.pkl', 'wb') as f:
        pkl.dump(df_stats, f)

#### Visualize the Results

In [None]:
# read in the stripped data, to compare with predictions of self
# (insertion order is 4AA, 3AA, 2AA, 1AA, Untrimmed; it is reused as the x-axis order below)
h2d = {}
for n_trim in (4, 3, 2, 1):
    with open(f'../outs/250429v2_cd4vscd8_logisticregression.{n_trim}aa_stripped.pkl', 'rb') as f:
        h2d[f'Trimmed by {n_trim}AA\nOn Each Side'] = pkl.load(f)
# read in the non-stripped data, to compare with predictions of self
with open('../outs/250421v2_cd4vscd8_logisticregression.pkl', 'rb') as f:
    h2d['Untrimmed'] = pkl.load(f)

In [None]:
# define a plotting function
def visualize_on_globals():
    """Draw a grouped box + strip plot of one metric for the self-prediction runs.

    The function takes no arguments; it reads these names from the notebook namespace:
      h2d        -- dict mapping an x-axis label to a dict of {run name: statistics DataFrame}
      key        -- column of the statistics DataFrames to plot
      key_label  -- label of the y axis

    Run names containing '2' are skipped, so only the remaining runs are drawn, one hue per
    run name. The x-axis order follows the insertion order of h2d.
    """
    # assemble the plotting dataframe in one go so that the 'y' column gets a numeric dtype
    records = [(method, v, k)
               for method, df_stats in h2d.items()
               for k, vs in df_stats.items() if '2' not in k
               for v in vs[key]]
    df_plot = pd.DataFrame(records, columns=['x','y','hue'])

    # create the ordered box plots
    fig, ax = plt.subplots(figsize=[5, 4]); ax.grid(False)
    order = list(h2d.keys())
    sns.boxplot(x='x', y='y', hue='hue', data=df_plot, linewidth=1.5, linecolor='dodgerblue',
                color='skyblue', saturation=1, showfliers=False, order=order, ax=ax)
    # legend=False: the boxes already provide one legend entry per hue level, so the points
    # must not register a second set of entries
    sns.stripplot(x='x', y='y', hue='hue', data=df_plot, linewidth=1.5, edgecolor='dodgerblue',
                  color='skyblue', jitter=0.25, order=order, s=5, alpha=0.5, dodge=True,
                  legend=False, ax=ax)
    ax.tick_params(axis='x', labelrotation=90)
    # reference line at 0.5
    ax.axhline(0.5, color='grey', linestyle='--')
    # place the legend outside of the axes, to the right
    ax.legend(title=None, bbox_to_anchor=(1, .5), loc='center left', bbox_transform=ax.transAxes, frameon=False)
    ax.set(xlabel='Datasets Utilized and Method', ylabel=key_label)

In [None]:
# select the AUROC column and its axis label, then draw the comparison
key, key_label = 'auroc', 'AUROC'
visualize_on_globals()

In [None]:
# select the AUPRC column and its axis label, then draw the comparison
key, key_label = 'auprc', 'AUPRC'
visualize_on_globals()