In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.stats as ss
import seaborn as sns
sc.settings.set_figure_params(dpi=100)
print(sc.__version__)

In [None]:
import blosum as bl
# perform encoding by direct, BCP, BLOSUM
# the 20 amino-acid one-letter codes; this order fixes the column order of the
# direct (one-hot) and BLOSUM encodings built from it
vocab = ['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y']
# direct encoding
# one-hot vector per residue: 1 at the residue's own position in vocab, 0 elsewhere
map_direct = {x:[1 * (x == y) for y in vocab] for x in vocab}
# bcp encoding
# per-residue hydrophobicity score
# NOTE(review): the values look like the Kyte-Doolittle hydropathy scale - confirm the source
aa_hydrophobicity = {
    'A': 1.8,  # Alanine
    'R': -4.5,  # Arginine
    'N': -3.5,  # Asparagine
    'D': -3.5,  # Aspartic Acid
    'C': 2.5,  # Cysteine
    'E': -3.5,  # Glutamic Acid
    'Q': -3.5,  # Glutamine
    'G': -0.4,  # Glycine
    'H': -3.2,  # Histidine
    'I': 4.5,  # Isoleucine
    'L': 3.8,  # Leucine
    'K': -3.9,  # Lysine
    'M': 1.9,  # Methionine
    'F': 2.8,  # Phenylalanine
    'P': -1.6,  # Proline
    'S': -0.8,  # Serine
    'T': -0.7,  # Threonine
    'W': -0.9,  # Tryptophan
    'Y': -1.3,  # Tyrosine
    'V': 4.2,  # Valine
}
# https://www.imgt.org/IMGTeducation/Aide-memoire/_UK/aminoacids/IMGTclasses.html
# per-residue volume
# NOTE(review): presumably taken from the IMGT table linked above - confirm the units
aa_volume = {
    'A': 88.6,   # Alanine
    'R': 173.4,  # Arginine
    'N': 114.1,  # Asparagine
    'D': 111.1,  # Aspartic Acid
    'C': 108.5,  # Cysteine
    'E': 138.4,  # Glutamic Acid
    'Q': 143.8,  # Glutamine
    'G': 60.1,   # Glycine
    'H': 153.2,  # Histidine
    'I': 166.7,  # Isoleucine
    'L': 166.7,  # Leucine
    'K': 168.6,  # Lysine
    'M': 162.9,  # Methionine
    'F': 189.9,  # Phenylalanine
    'P': 112.7,  # Proline
    'S': 89.0,   # Serine
    'T': 116.1,  # Threonine
    'W': 227.8,  # Tryptophan
    'Y': 193.6,  # Tyrosine
    'V': 140.0,  # Valine
}
# hydrogen-bonding score of the residue:
# 1 = donor and acceptor, 0.5 = only donor or acceptor
aa_hbond = {
    'A': 0,    # Alanine
    'R': 0.5,  # Arginine
    'N': 1,    # Asparagine
    'D': 0.5,  # Aspartic Acid
    'C': 0,    # Cysteine
    'E': 0.5,  # Glutamic Acid
    'Q': 1,    # Glutamine
    'G': 0,    # Glycine
    'H': 1,    # Histidine
    'I': 0,    # Isoleucine
    'L': 0,    # Leucine
    'K': 0.5,  # Lysine
    'M': 0,    # Methionine
    'F': 0,    # Phenylalanine
    'P': 0,    # Proline
    'S': 1,    # Serine
    'T': 1,    # Threonine
    'W': 0.5,  # Tryptophan
    'Y': 1,    # Tyrosine
    'V': 0,    # Valine
}
# residue classes; membership in each list becomes one 0/1 feature of the BCP embedding
has_sulfur = ['C','M']
is_aromatic = ['F','Y','W']
is_aliphatic = ['A','G','I','L','P','V']
is_basic = ['R','H','K']
is_acidic = ['D','E']
has_amide = ['N','Q']
# names of the nine BCP features, listed in the same order bcp_translation appends them
vocab_bcp = ['hydrophobicity','volume','hbond','has_sulfur','is_aromatic',
             'is_aliphatic','is_basic','is_acidic','has_amide']
# > build the BCP-space embedding for a single amino acid
def bcp_translation(aa):
    """Return the nine-element BCP feature list for the one-letter residue `aa`.

    The first three entries are looked up in the hydrophobicity, volume and
    hydrogen-bond tables; the remaining six are 0/1 flags for membership in the
    sulfur, aromatic, aliphatic, basic, acidic and amide residue lists.
    Raises KeyError when `aa` is missing from the lookup tables.
    """
    continuous = [aa_hydrophobicity[aa], aa_volume[aa], aa_hbond[aa]]
    residue_classes = (has_sulfur, is_aromatic, is_aliphatic, is_basic, is_acidic, has_amide)
    flags = [int(aa in members) for members in residue_classes]
    return continuous + flags
# tabulate the BCP embedding for every residue in the vocabulary
map_bcp = {x:bcp_translation(x) for x in vocab}
# blosum encoding
# build the BLOSUM62 matrix once and reuse it; the previous form re-created it
# for every one of the len(vocab)**2 residue pairs
blosum62 = bl.BLOSUM(62)
# each residue is represented by its row of substitution scores, in vocab order
map_blosum = {x:[blosum62[x][y] for y in vocab] for x in vocab}

### Read in the Data

In [None]:
# read in the data
# df: interaction table; later cells filter it on the 'AG' (epitope) column and read 'TRB'
df = pd.read_csv('../outs/df.int.clean.csv', index_col=0)
# a_trb: AnnData whose obs index holds TRB sequences; later cells use .X as the
# per-sequence coordinates fed to the flow and .obsm['X_umap'] as the 2-D layout
a_trb = sc.read_h5ad('../outs/adata.trb.h5ad')

### kNN Model to Map Sampled Z Coordinates to UMAP

In [None]:
from sklearn.neighbors import KNeighborsRegressor
# rough mapping of UMAP via kNN
# regress the UMAP coordinates on a_trb.X so that new points in that space
# (e.g. samples drawn from the flow) can be placed on the existing UMAP layout;
# predictions are based on the 5 nearest training points
neigh = KNeighborsRegressor(n_neighbors=5)
neigh.fit(a_trb.X, a_trb.obsm['X_umap'])

In [None]:
from sklearn.metrics import root_mean_squared_error
# validate the kNN mapping on the first UMAP axis
np.random.seed(0)
idxs = np.random.choice(a_trb.obs.index, size=1000, replace=False)
# NOTE(review): the sampled points were part of the kNN fit, so this measures
# in-sample agreement rather than held-out error
fig, ax = plt.subplots()
ax.grid(False)
x = neigh.predict(a_trb[idxs].X)[:, 0]
y = a_trb[idxs].obsm['X_umap'][:, 0]
ax.scatter(x, y, color='dodgerblue', s=5, alpha=0.25)
# remember the data-driven limits so the trend line does not rescale the axes
xlim = ax.get_xlim()
ylim = ax.get_ylim()
# degree-1 least-squares fit of true vs. predicted, drawn across the visible x-range
model = np.polynomial.Polynomial.fit(x, y, 1)
xl, yl = model.linspace(domain=xlim)
ax.plot(xl, yl, color='k', linestyle='--')
ax.set_xlim(*xlim)
ax.set_ylim(*ylim)
ax.set(xlabel='Predicted UMAP 1', ylabel='True UMAP 1')
# displayed cell output: Pearson correlation and RMSE between prediction and truth
ss.pearsonr(x, y), root_mean_squared_error(x, y)

In [None]:
# check the difference
# repeat the UMAP 1 comparison on five seeded random subsets and collect the RMSEs
ys = []
for seed in range(5):
    np.random.seed(seed)
    idxs = np.random.choice(a_trb.obs.index, size=1000, replace=False)
    x = neigh.predict(a_trb[idxs].X)[:, 0]
    y = a_trb[idxs].obsm['X_umap'][:, 0]
    ys.append(root_mean_squared_error(x, y))
# narrow panel: box summary with the individual seeds overlaid
fig, ax = plt.subplots(figsize=[1.5, 4])
ax.grid(False)
sns.boxplot(y=ys, linewidth=1.5, saturation=1, showfliers=False, linecolor='dodgerblue', color='skyblue')
sns.stripplot(y=ys, linewidth=1.5, s=6, alpha=0.5, color='skyblue', edgecolor='dodgerblue')
ax.set_xlim(-1, 1)
ax.set_ylabel('RMSE UMAP 1')

In [None]:
# validate the kNN mapping on the second UMAP axis
np.random.seed(0)
# sample WITHOUT replacement so the 1000 validation points are distinct; the
# UMAP 1 validation and both multi-seed RMSE checks already sample this way
idxs = np.random.choice(a_trb.obs.index, size=1000, replace=False)
fig, ax = plt.subplots(); ax.grid(False)
# column 1 of both the prediction and the stored layout is the second UMAP axis
x, y = neigh.predict(a_trb[idxs].X)[:, 1], a_trb[idxs].obsm['X_umap'][:, 1]
ax.scatter(x, y, color='dodgerblue', s=5, alpha=0.25)
# remember the data-driven limits so the trend line does not rescale the axes
xlim, ylim = ax.get_xlim(), ax.get_ylim()
# model the lines: degree-1 least-squares fit of true vs. predicted
model = np.polynomial.Polynomial.fit(x, y, 1)
xl, yl = model.linspace(domain=xlim)
ax.plot(xl, yl, color='k', linestyle='--')
ax.set_xlim(*xlim); ax.set_ylim(*ylim)
ax.set(xlabel='Predicted UMAP 2', ylabel='True UMAP 2')
# displayed cell output: Pearson correlation and RMSE between prediction and truth
ss.pearsonr(x, y), root_mean_squared_error(x, y)

In [None]:
# check the difference
# repeat the UMAP 2 comparison on five seeded random subsets and collect the RMSEs
ys = []
for seed in range(5):
    np.random.seed(seed)
    idxs = np.random.choice(a_trb.obs.index, size=1000, replace=False)
    x = neigh.predict(a_trb[idxs].X)[:, 1]
    y = a_trb[idxs].obsm['X_umap'][:, 1]
    ys.append(root_mean_squared_error(x, y))
# narrow panel: box summary with the individual seeds overlaid
fig, ax = plt.subplots(figsize=[1.5, 4])
ax.grid(False)
sns.boxplot(y=ys, linewidth=1.5, saturation=1, showfliers=False, linecolor='dodgerblue', color='skyblue')
sns.stripplot(y=ys, linewidth=1.5, s=6, alpha=0.5, color='skyblue', edgecolor='dodgerblue')
ax.set_xlim(-1, 1)
ax.set_ylabel('RMSE UMAP 2')

### AgFlow Model Based on Normalizing Flows

In [None]:
# https://paperswithcode.com/paper/density-estimation-using-real-nvp
# https://colab.research.google.com/github/senya-ashukha/real-nvp-pytorch/blob/master/real-nvp-pytorch.ipynb#scrollTo=nKXQrDNFZG8D

In [None]:
import torch
from torch import nn
from torch import distributions
from torch.nn.parameter import Parameter

In [None]:
# define the realNVP module from literature
class RealNVP(nn.Module):
    """Real NVP normalizing flow made of masked affine coupling layers.

    Parameters
    ----------
    nets : callable
        Zero-argument factory returning a new scale network; called once per layer.
    nett : callable
        Zero-argument factory returning a new translation network; called once per layer.
    mask : torch.Tensor
        One row per coupling layer. Entries equal to 1 mark features that the
        layer passes through unchanged (and feeds to its networks); entries equal
        to 0 mark the features the layer transforms.
    prior : object
        Base distribution providing `sample(shape)` and `log_prob(z)`.
    """

    def __init__(self, nets, nett, mask, prior):
        super(RealNVP, self).__init__()

        self.prior = prior
        # stored as a non-trainable parameter so it follows the module across devices
        self.mask = nn.Parameter(mask, requires_grad=False)
        # one network pair per mask row; sized from the `mask` argument itself
        # (not from a same-named notebook global) so the class is self-contained
        n_layers = len(mask)
        self.t = torch.nn.ModuleList([nett() for _ in range(n_layers)])
        self.s = torch.nn.ModuleList([nets() for _ in range(n_layers)])

    def g(self, z):
        """Map base-space points `z` to data space (the generative direction)."""
        x = z
        for i in range(len(self.t)):
            x_ = x * self.mask[i]
            s = self.s[i](x_) * (1 - self.mask[i])
            t = self.t[i](x_) * (1 - self.mask[i])
            x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)
        return x

    def f(self, x):
        """Map data `x` to base space; returns `(z, log_det_J)` with one log-determinant per row of `x`."""
        log_det_J, z = x.new_zeros(x.shape[0]), x
        # undo the layers in the opposite order to the one used by g
        for i in reversed(range(len(self.t))):
            z_ = self.mask[i] * z
            s = self.s[i](z_) * (1 - self.mask[i])
            t = self.t[i](z_) * (1 - self.mask[i])
            z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
            log_det_J -= s.sum(dim=1)
        return z, log_det_J

    def log_prob(self,x):
        """Log-density of `x`: prior log-density of f(x) plus the log-determinant term."""
        z, logp = self.f(x)
        return self.prior.log_prob(z) + logp

    def sample(self, batchSize):
        """Draw `batchSize` prior samples and push them through g.

        The prior is sampled with shape `(batchSize, 1)`, so the result keeps a
        singleton second axis; callers in this notebook index it as `x[:, 0, :]`.
        """
        z = self.prior.sample((batchSize, 1))
        return self.g(z)

In [None]:
# factories and fixed tensors consumed by RealNVP:
# - nets / nett build the scale / translation networks: 32 inputs and outputs with two
#   hidden layers of width 256; the scale network additionally ends in Tanh
# - masks alternates [0, 1]*16 and [1, 0]*16 four times, i.e. 8 rows of length 32
# - prior is a 32-dimensional normal with zero mean and identity covariance
# everything is created directly on the 'cuda' device
nets = lambda: nn.Sequential(nn.Linear(32, 256), nn.LeakyReLU(), nn.Linear(256, 256), nn.LeakyReLU(), nn.Linear(256, 32), nn.Tanh()).to('cuda')
nett = lambda: nn.Sequential(nn.Linear(32, 256), nn.LeakyReLU(), nn.Linear(256, 256), nn.LeakyReLU(), nn.Linear(256, 32)).to('cuda')
masks = torch.from_numpy(np.array([[0, 1]*16, [1, 0]*16] * 4).astype(np.float32)).to('cuda')
prior = distributions.MultivariateNormal(torch.zeros(32).to('cuda'), torch.eye(32).to('cuda'))

In [None]:
# define optimizer
torch.manual_seed(0); np.random.seed(0)
flow = RealNVP(nets, nett, masks, prior)
optimizer = torch.optim.Adam([p for p in flow.parameters() if p.requires_grad], lr=1e-5)
# derive the mapping dataset: TRBs annotated for YLQPRTFLL that are present in a_trb
trbs = df.loc[df['AG'] == 'YLQPRTFLL', 'TRB']
trbs = trbs[trbs.isin(a_trb.obs.index)]
X = a_trb[trbs].X.astype(np.float32).copy()
# the training matrix never changes, so copy it to the GPU once instead of every iteration
X_cuda = torch.from_numpy(X).to('cuda')
# track the loss values
losses = []
for t in range(10001):
    # derive the loss: negative mean log-likelihood over the full batch
    loss = -flow.log_prob(X_cuda).mean()
    # optimize accordingly
    optimizer.zero_grad()
    # a fresh graph is built by every forward pass, so it does not need to be retained
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if t % 1000 == 0:
        print('iter %s:' % t, 'loss = %.3f' % losses[-1])

In [None]:
# plot the loss curve
# every entry of `losses` is one full-batch optimizer step
fig, ax = plt.subplots()
ax.grid(False)
ax.plot(losses, color='dodgerblue')
ax.set(xlabel='Epochs', ylabel='Loss')

In [None]:
# draw 1000 samples from the trained flow (seeded for reproducibility)
torch.manual_seed(0)
np.random.seed(0)
x = flow.sample(1000).clone().cpu().detach().numpy()
# the samples carry a singleton middle axis; drop it before projecting onto the UMAP
umap = neigh.predict(x[:, 0, :])

In [None]:
# derive the predicted: flow samples projected onto the UMAP, coloured by KDE density
background = a_trb.obsm['X_umap']
fig, ax = plt.subplots()
ax.grid(False)
ax.scatter(background[:, 0], background[:, 1], s=0.0001, color='lightgray')
z = ss.gaussian_kde(umap.T)(umap.T)
# draw the densest points last so they end up on top
idxs = np.argsort(z)
ax.scatter(umap[idxs, 0], umap[idxs, 1], c=z[idxs], alpha=0.5, s=4, cmap='Blues')
ax.set_title(r'$\hat{x} = g(z)$')

# derive the ground truth
# compare with the known: the same view for the TCRs the flow was trained on
known = a_trb[trbs].obsm['X_umap']
fig, ax = plt.subplots()
ax.grid(False)
ax.scatter(background[:, 0], background[:, 1], s=0.0001, color='lightgray')
z = ss.gaussian_kde(known.T)(known.T)
idxs = np.argsort(z)
ax.scatter(known[idxs, 0], known[idxs, 1], c=z[idxs], alpha=0.5, s=4, cmap='Blues')
ax.set_title(r'$x$')

### Statistically Examine AgFlow

In [None]:
import statsmodels as sm
# look at the range of pvalues
# per latent dimension, two-sided KS test of the flow samples against the training TCRs;
# the AnnData slice is taken once instead of once per dimension
X_truth = a_trb[trbs].X
pvals_truth = [ss.ks_2samp(x[:, 0, idx], X_truth[:, idx], alternative='two-sided')[1]
               for idx in range(a_trb.shape[1])]
# compare against a random subset (three times as many TCRs, drawn with replacement)
np.random.seed(0)
trbs_rand = np.random.choice(a_trb.obs.index, size=len(trbs) * 3, replace=True)
X_rand = a_trb[trbs_rand].X
pvals_rand = [ss.ks_2samp(x[:, 0, idx], X_rand[:, idx], alternative='two-sided')[1]
              for idx in range(a_trb.shape[1])]
# kept for backward compatibility: `pvals` previously ended up bound to the random-set list
pvals = pvals_rand

In [None]:
# compare the p-value for the Kolmogorov-Smirnov
# one group label per p-value; both inputs are Python lists, so the `+` inside
# log10 below is list concatenation (not element-wise addition)
xs = ['Truth']*len(pvals_truth)+['Rand']*len(pvals_rand)
ys = -np.log10(pvals_truth + pvals_rand)
fig, ax = plt.subplots(figsize=[2, 4]); ax.grid(False)
# NOTE(review): both `color` and `palette` are passed without `hue`; check which one
# the installed seaborn version honours
sns.boxplot(x=xs, y=ys, linewidth=1.5, saturation=1, showfliers=False, linecolor='dodgerblue', color='skyblue',
            order=['Rand','Truth'], palette=['lightgray','skyblue'])
# seed numpy so the strip-plot jitter is reproducible
np.random.seed(0)
# NOTE(review): each stripplot is restricted to a single category via `order`; verify the
# points land on the matching box position
sns.stripplot(x=xs, y=ys, jitter=0.4, palette=['dodgerblue'], order=['Truth'], alpha=0.4, s=4)
sns.stripplot(x=xs, y=ys, jitter=0.4, palette=['grey'], order=['Rand'], alpha=0.4, s=4)
ax.set_xlim(-1, 2); ax.set_ylabel('-log$_{10}$(p-value)\nKS test for CDF difference')
# restyle artists by position: child 0 gets a hatch and grey edge, children 1-5 are recoloured grey
# NOTE(review): presumably these are the 'Rand' box and its whisker/cap/median lines;
# the indices depend on seaborn's draw order
ax.get_children()[0].set_hatch('//')
ax.get_children()[0].set_edgecolor('grey')
for idx in range(1, 6):
    ax.get_children()[idx].set_color('grey')
ax.tick_params(axis='x', labelrotation=90)
ax.set_xticklabels(['Model\nvs. Rand','Model\nvs. Truth'])
# displayed cell output: Mann-Whitney U test between the two sets of KS p-values
ss.mannwhitneyu(pvals_truth, pvals_rand)

In [None]:
# compare the per-dimension value distributions of the model samples, the training
# TCRs ('Truth') and the random TCRs ('Rand'); no FDR correction is computed here
# long-format columns: xs = group label, ys = value, zs = latent dimension index
xs, ys, zs = [], [], []
# NOTE(review): these two lists are never filled or read in this cell
xticklabels_mw, xticklabels_ks = [], []
for idx in range(a_trb.shape[1]):
    ys1, ys2, ys3 = x[:, 0, idx].tolist(), a_trb[trbs].X[:, idx].tolist(), a_trb[trbs_rand].X[:, idx].tolist()
    ys_iter = ys1 + ys2 + ys3; ys += ys_iter
    xs_iter = ['Model']*len(ys1) + ['Truth']*len(ys2) + ['Rand']*len(ys3); xs += xs_iter
    zs += [idx] * len(xs_iter)
    
# create the grouped box plot (one group of three boxes per latent dimension);
# no Mann-Whitney test is run in this cell
fig, ax = plt.subplots(figsize=[20, 4]); ax.grid(False)
sns.boxplot(x=zs, y=ys, hue=xs, linewidth=1.5, saturation=1, showfliers=False, linecolor='dodgerblue', color='skyblue',
            hue_order=['Rand','Truth','Model'], palette=['lightgray','lightsteelblue','skyblue'])
ax.set_xlim(-1, a_trb.shape[1])
# relabel the dimensions starting from 1 instead of 0
ax.set_xticklabels([int(float(x.get_text()))+1 for x in ax.get_xticklabels()])
ax.set_yticks(np.arange(-3, 3+1, 1))
ax.set_yticklabels([x.get_text() for x in ax.get_yticklabels()], rotation=90)
ax.set_xlabel('Tarpon Latent Dimension', rotation=180)
ax.set_ylabel('Latent Dimension Value\n(Arbitary Units)')
# restyle artists by position, assuming 6 artists per box (1 patch followed by 5 lines)
# NOTE(review): presumably the first a_trb.shape[1] groups of 6 belong to 'Rand' and the
# next a_trb.shape[1] groups to 'Truth'; this depends on seaborn's draw order
# NOTE: the inner loops reuse the name `idx`; this works because the outer `for`
# pulls its next value from the range iterator, but it is easy to misread
shift = 5
for idx in range(a_trb.shape[1]):
    ax.get_children()[idx+idx*shift].set_hatch('//')
    ax.get_children()[idx+idx*shift].set_edgecolor('grey')
    for idx in range(idx+idx*shift+1, idx+idx*shift+6):
        ax.get_children()[idx].set_color('grey')
# offset that skips the first block of a_trb.shape[1] * 6 artists
shift_ = 6*a_trb.shape[1]
for idx in range(a_trb.shape[1]):
    ax.get_children()[shift_+idx+idx*shift].set_hatch('/')
    ax.get_children()[shift_+idx+idx*shift].set_edgecolor('navy')
    for idx in range(shift_+idx+idx*shift+1, shift_+idx+idx*shift+6):
        ax.get_children()[idx].set_color('navy')
ax.tick_params(axis='x', labelrotation=90)
# replace the automatic hue legend with an empty, frameless one
ax.legend('', frameon=False)

### Convert Z Coordinates to Sequence to Compare Overlap in Repertoires

In [None]:
# read back in the original TRB solver
# hyperparameters of the saved network
# NOTE(review): presumably these must match the values used when the checkpoint was trained
vocab = ['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y']
protein_len = 48
latent_dim = 32
# the encoder drops the last input channel, hence 50-1 channels go into the first convolution
init_embed_size = 50-1
init_cnn_filters = 256
secn_cnn_filters = 256
init_kernel_size = 3
init_kernel_stride = 1
init_kernel_padding = 1
# we want the embedding output to be the vocab with the length to allow for reconstruction
out_embed_size = len(vocab)
n_nodes_len = 32


# define the convolutional variational autoencoder
class ConvVAE(nn.Module):
    """Convolutional VAE with a second decoder head that outputs one scalar per sample.

    The attribute names fc1 ... fc8 are the keys of the saved state dict and must not change.
    """

    def __init__(self):
        super().__init__()
        conv_kws = dict(kernel_size=init_kernel_size, stride=init_kernel_stride,
                        padding=init_kernel_padding)
        flat_size = secn_cnn_filters * protein_len

        # encoding
        self.fc1 = nn.Conv1d(init_embed_size, init_cnn_filters, **conv_kws)
        self.fc2 = nn.Conv1d(init_cnn_filters, secn_cnn_filters, **conv_kws)
        # variational sampling: mean head, log-variance head, and the projection back up
        self.fc31 = nn.Linear(flat_size, latent_dim)
        self.fc32 = nn.Linear(flat_size, latent_dim)
        self.fc4 = nn.Linear(latent_dim, flat_size)
        # decoding
        self.fc5 = nn.ConvTranspose1d(secn_cnn_filters, init_cnn_filters, **conv_kws)
        self.fc6 = nn.ConvTranspose1d(init_cnn_filters, out_embed_size, **conv_kws)
        # scalar head (consumed downstream as the sequence length)
        self.fc7 = nn.Linear(init_cnn_filters*protein_len, n_nodes_len)
        self.fc8 = nn.Linear(n_nodes_len, 1)

    def encode(self, x):
        """Return `(mu, logvar)` for a batch `x`; the last channel of `x` is ignored."""
        act = nn.functional.leaky_relu
        hidden = act(self.fc2(act(self.fc1(x[:, :-1, :]))))
        flat = torch.flatten(hidden, start_dim=1)
        return self.fc31(flat), self.fc32(flat)

    def reparameterize(self, mu, logvar):
        """Draw `mu + eps * std` with `eps` standard normal and `std = exp(logvar / 2)`."""
        std = (0.5*logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise*std

    def decode(self, z):
        """Return `(per-position sigmoid scores, scalar head output)` for latent `z`."""
        act = nn.functional.leaky_relu
        widened = act(self.fc4(z)).view(-1, secn_cnn_filters, protein_len)
        feats = act(self.fc5(widened))
        scores = torch.sigmoid(self.fc6(feats))
        scalar = self.fc8(act(self.fc7(torch.flatten(feats, start_dim=1))))
        return scores, scalar

    def forward(self, x):
        """Encode, sample a latent, decode; returns `(decode(z), mu, logvar)`."""
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar
# load the model
# NOTE: this rebinds `model`, which earlier cells used for the polynomial trend-line fit
model = ConvVAE().to('cuda')
# weights_only=True restricts unpickling to tensor data
model.load_state_dict(torch.load('../models/model.convvae.trb.torch', weights_only=True))
# switch to evaluation mode
model.eval()

In [None]:
# define an engine to derive sequence
def engine(model, z, alphabet=None):
    """Decode one latent vector into a sequence string.

    Parameters
    ----------
    model : object
        Must provide `decode(z)` returning two tensors: per-symbol scores whose
        first element (`[0]`) is laid out as (n_symbols, n_positions), and a
        length prediction read as `[0][0]`.
    z : torch.Tensor
        Latent vector handed to `model.decode`.
    alphabet : sequence of str, optional
        Symbol for each score row; defaults to the notebook-level `vocab`.

    Returns
    -------
    str
        One symbol per position after resampling the scores to the rounded
        predicted length; the empty string when that length is not positive.

    Raises
    ------
    ValueError
        If the number of score rows differs from the alphabet size.
    """
    if alphabet is None:
        alphabet = vocab

    # derive the embedding and length; decoding needs no autograd graph
    with torch.no_grad():
        tmp_out, tmp_len = model.decode(z)
    scores = tmp_out.detach().cpu().numpy()[0]
    targ_len = int(round(float(tmp_len.detach().cpu().numpy()[0][0])))

    if scores.shape[0] != len(alphabet):
        raise ValueError('decoder returned %d score rows for an alphabet of %d symbols'
                         % (scores.shape[0], len(alphabet)))
    if targ_len <= 0:
        return ''

    # interpolate back to a sequence using the length
    # positions are normalised to [0, 1]; max(..., 1) keeps a length of 1 from
    # dividing by zero (it maps to the first position)
    curr_len = scores.shape[1]
    xp = np.arange(curr_len) / max(curr_len - 1, 1)
    x = np.arange(targ_len) / max(targ_len - 1, 1)
    res = np.array([np.interp(x, xp, row) for row in scores])

    # derive the sequence: highest-scoring symbol at every position
    return ''.join(alphabet[i] for i in res.argmax(axis=0))

In [None]:
from tqdm import tqdm
# attempt to generate some peptides: decode every flow sample into a sequence
seqs = []
for idx in tqdm(range(x.shape[0]), total=x.shape[0]):
    seqs.append(engine(model, torch.from_numpy(x[idx, 0, :].astype(np.float32)).to('cuda')))
seqs = pd.Series(seqs)

# hypergeometric where you have M animals, n are dogs, then you choose a random N from the M animals how many are dogs
# in our scenario this is where we have the entire TCRbeta-ome supplemented with generated samples
M  = a_trb.shape[0] + len(seqs[~seqs.isin(a_trb.obs.index)].unique())
n = len(trbs.unique())
N = len(seqs.unique())
k = len(set(trbs) & set(seqs))
# define the distribution
fig, ax = plt.subplots(figsize=[4, 2]); ax.grid(False)
xl = np.arange(0, 26, 1)
yl = ss.hypergeom(M, n, N).pmf(xl)
# the pmf is drawn in percent, so the axis label states the unit
ax.plot(xl, yl*100, color='grey', linestyle='--')
ax.axvline(k, color='dodgerblue', lw=2)
ax.set(xlabel='# of Overlapping Unique TCRs', ylabel='Probability (%)', title='Peptide-Specific TCRs')
# define the p-value: probability of an overlap at least as large as the observed k.
# sf(k - 1) = P(X >= k); the former 1 - cdf(k) = P(X > k) left out the observed
# count itself and loses precision for very small tails
print(ss.hypergeom(M, n, N).sf(k - 1))

# repeat for the random TCRs
M  = a_trb.shape[0] + len(seqs[~seqs.isin(a_trb.obs.index)].unique())
n = len(pd.Series(trbs_rand).unique())
N = len(seqs.unique())
k = len(set(trbs_rand) & set(seqs))
# define the distribution
fig, ax = plt.subplots(figsize=[4, 2]); ax.grid(False)
xl = np.arange(0, 26, 1)
yl = ss.hypergeom(M, n, N).pmf(xl)
ax.plot(xl, yl*100, color='grey', linestyle='--')
ax.axvline(k, color='k', lw=2)
ax.set(xlabel='# of Overlapping Unique TCRs', ylabel='Probability (%)', title='Random TCRs')
# define the p-value (same upper tail, including the observed count)
print(ss.hypergeom(M, n, N).sf(k - 1))

In [None]:
# how often each distinct sequence was generated; show the 25 most frequent
counts = seqs.value_counts()
top_counts = counts.iloc[:25]
fig, ax = plt.subplots(figsize=[8, 4])
ax.grid(False)
ax.bar(top_counts.index, top_counts, edgecolor='dodgerblue', color='skyblue', lw=1.5)
ax.tick_params(axis='x', labelrotation=90)
ax.set(xlabel='Generated TCRs (Top 25)', ylabel='# of TCRs Generated')

### Retrain AgFlow Without CASSPDIEAFF

In [None]:
# define optimizer
torch.manual_seed(0); np.random.seed(0)
flow = RealNVP(nets, nett, masks, prior)
optimizer = torch.optim.Adam([p for p in flow.parameters() if p.requires_grad], lr=1e-5)
# derive the mapping dataset: TRBs annotated for YLQPRTFLL that are present in a_trb
trbs = df.loc[df['AG'] == 'YLQPRTFLL', 'TRB']
trbs = trbs[trbs.isin(a_trb.obs.index)]
# EXCLUSION: hold out CASSPDIEAFF so the flow never sees it during training
trbs = trbs[trbs != 'CASSPDIEAFF']

X = a_trb[trbs].X.astype(np.float32).copy()
# the training matrix never changes, so copy it to the GPU once instead of every iteration
X_cuda = torch.from_numpy(X).to('cuda')
# track the loss values
losses = []
for t in range(10001):
    # derive the loss: negative mean log-likelihood over the full batch
    loss = -flow.log_prob(X_cuda).mean()
    # optimize accordingly
    optimizer.zero_grad()
    # a fresh graph is built by every forward pass, so it does not need to be retained
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if t % 1000 == 0:
        print('iter %s:' % t, 'loss = %.3f' % losses[-1])

In [None]:
# plot the loss curve
# every entry of `losses` is one full-batch optimizer step
fig, ax = plt.subplots()
ax.grid(False)
ax.plot(losses, color='dodgerblue')
ax.set(xlabel='Epochs', ylabel='Loss')

In [None]:
# sample from the dataset and confirm similarity with previous
torch.manual_seed(0)
np.random.seed(0)
x = flow.sample(1000).clone().cpu().detach().numpy()
# the samples carry a singleton middle axis; drop it before projecting onto the UMAP
umap = neigh.predict(x[:, 0, :])

# derive the predicted: flow samples on the UMAP, coloured by KDE density
background = a_trb.obsm['X_umap']
fig, ax = plt.subplots()
ax.grid(False)
ax.scatter(background[:, 0], background[:, 1], s=0.0001, color='lightgray')
z = ss.gaussian_kde(umap.T)(umap.T)
# draw the densest points last so they end up on top
idxs = np.argsort(z)
ax.scatter(umap[idxs, 0], umap[idxs, 1], c=z[idxs], alpha=0.5, s=4, cmap='Blues')
ax.set_title(r'$\hat{x} = g(z)$')

# derive the ground truth
# compare with the known: the same view for the TCRs the flow was trained on
known = a_trb[trbs].obsm['X_umap']
fig, ax = plt.subplots()
ax.grid(False)
ax.scatter(background[:, 0], background[:, 1], s=0.0001, color='lightgray')
z = ss.gaussian_kde(known.T)(known.T)
idxs = np.argsort(z)
ax.scatter(known[idxs, 0], known[idxs, 1], c=z[idxs], alpha=0.5, s=4, cmap='Blues')
ax.set_title(r'$x$')

In [None]:
from tqdm import tqdm
# attempt to generate some peptides: decode every flow sample into a sequence
decoded = []
for idx in tqdm(range(x.shape[0]), total=x.shape[0]):
    latent = torch.from_numpy(x[idx, 0, :].astype(np.float32)).to('cuda')
    decoded.append(engine(model, latent))
seqs = pd.Series(decoded)
# how often each distinct sequence was generated; show the 25 most frequent
counts = seqs.value_counts()
top_counts = counts.iloc[:25]
fig, ax = plt.subplots(figsize=[8, 4])
ax.grid(False)
ax.bar(top_counts.index, top_counts, edgecolor='dodgerblue', color='skyblue', lw=1.5)
ax.tick_params(axis='x', labelrotation=90)
ax.set(xlabel='Generated TCRs (Top 25)', ylabel='# of TCRs Generated')

### Repeat Exercise Robustly via Multiple Random Seeds

In [None]:
# YLQPRTFLL-specific TRBs that exist in a_trb, with the held-out CASSPDIEAFF removed
is_ylq = df['AG'] == 'YLQPRTFLL'
trbs = df.loc[is_ylq, 'TRB']
trbs = trbs[trbs.isin(a_trb.obs.index) & (trbs != 'CASSPDIEAFF')]
# displayed cell output: size of the training pool used by the seed sweep below
len(trbs)

In [None]:
# retrieve the stats
# training pool: YLQPRTFLL-specific TRBs present in a_trb with CASSPDIEAFF held out;
# it is identical for every run, so it is derived once up front
trbs_pool = df.loc[df['AG'] == 'YLQPRTFLL', 'TRB']
trbs_pool = trbs_pool[trbs_pool.isin(a_trb.obs.index)]
trbs_pool = trbs_pool[trbs_pool != 'CASSPDIEAFF']
# one (seed1, seed2, size, pos_per_mil) record per (training seed, sampling seed, subset size);
# rows are collected in a list and turned into a DataFrame once at the end
records = []
# NOTE: 797 is hard-coded as the full pool size and must not exceed len(trbs_pool)
for size in list(range(100, 700+100, 100))+[797]:
    for seed in range(5):
        print('-', end='')
        # define optimizer
        torch.manual_seed(seed); np.random.seed(seed)
        flow = RealNVP(nets, nett, masks, prior)
        optimizer = torch.optim.Adam([p for p in flow.parameters() if p.requires_grad], lr=1e-5)
        # derive the mapping dataset: a seeded random subset of the pool
        trbs = np.random.choice(trbs_pool, size=size, replace=False)

        X = a_trb[trbs].X.astype(np.float32).copy()
        # constant across iterations, so copy it to the GPU once
        X_cuda = torch.from_numpy(X).to('cuda')
        # track the loss values
        losses = []
        for t in range(10001):
            # derive the loss
            loss = -flow.log_prob(X_cuda).mean()
            # optimize accordingly
            optimizer.zero_grad()
            # a fresh graph is built by every forward pass, so it does not need to be retained
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        print('.', end='')

        # sample from the dataset
        for seed2 in range(5):
            torch.manual_seed(seed2); np.random.seed(seed2)
            x = flow.sample(1000).clone().cpu().detach().numpy()
            # attempt to generate some peptides
            seqs = pd.Series([
                engine(model, torch.from_numpy(x[idx, 0, :].astype(np.float32)).to('cuda'))
                for idx in range(x.shape[0])
            ])
            # count how many of the 1000 samples decode to the held-out sequence
            records.append((seed, seed2, size, int((seqs == 'CASSPDIEAFF').sum())))
        print('>', end='')
    print()
df_stat = pd.DataFrame(records, columns=['seed1','seed2','size','pos_per_mil'])

In [None]:
# save the results
df_stat.to_csv('../outs/250420_AgFlow_YLQdetections.positive.csv')
# read the table straight back so the plotting cell below works from the file on disk
df_stat = pd.read_csv('../outs/250420_AgFlow_YLQdetections.positive.csv', index_col=0)

In [None]:
# examine the distribution limiting to seed to seed
# keep one row per training run: the one whose sampling seed equals its training seed
same_seed = df_stat['seed1'] == df_stat['seed2']
mask = same_seed & (df_stat['size'] != 797)
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
# err_kws replaces the errcolor / errwidth arguments, which are deprecated in seaborn >= 0.13
sns.barplot(x='size', y='pos_per_mil', data=df_stat[mask], ax=ax,
            saturation=1, err_kws={'color': 'dodgerblue', 'linewidth': 1.5}, capsize=0.3,
            edgecolor='dodgerblue', linewidth=1.5, color='skyblue')
# reference band: runs trained on the complete pool (size 797)
full_runs = df_stat.loc[same_seed & (df_stat['size'] == 797), 'pos_per_mil']
mean = full_runs.mean()
# normal-approximation 95% interval of the mean, using the actual number of runs
# NOTE(review): with only a handful of runs a t-based interval would be wider
ci95 = full_runs.std() / np.sqrt(len(full_runs)) * 1.96
ax.axhline(mean, color='grey', lw=2, linestyle='--', zorder=0,
           label='Mean Recovery When\nTrained on All (N=797)\nNon-CASSPDIEAFF TCRs')
# widen the axis first so the shaded band spans past the outermost bars
ax.set_xlim(-1, 7)
xmin, xmax = ax.get_xlim()
ax.fill([xmin, xmax, xmax, xmin, xmin],
        [mean-ci95, mean-ci95, mean+ci95, mean+ci95, mean-ci95],
        color='k', lw=2, linestyle='--', alpha=0.25, zorder=0, 
        label='95% Confidence Interval When\nTrained on All (N=797)\nNon-CASSPDIEAFF TCRs')
ax.tick_params(axis='x', labelrotation=90)
ax.set(xlabel='# of Non-CASSPDIEAFF\nYLQ-specific TCRs for Training',
       ylabel='# of Embeddings\nMapped to CASSPDIEAFF')
# final, tighter view
ax.set_xlim(-0.75, 6.75)
ax.legend(bbox_to_anchor=(1, .5), bbox_transform=ax.transAxes, loc='center left', frameon=False)

In [None]:
# retrieve the stats for random background
# one (seed1, seed2, size, pos_per_mil) record per (training seed, sampling seed, subset size);
# rows are collected in a list and turned into a DataFrame once at the end
records = []
for size in list(range(100, 700+100, 100))+[797]:
    for seed in range(5):
        print('-', end='')
        # define optimizer
        torch.manual_seed(seed); np.random.seed(seed)
        flow = RealNVP(nets, nett, masks, prior)
        optimizer = torch.optim.Adam([p for p in flow.parameters() if p.requires_grad], lr=1e-5)
        # derive the mapping dataset: a seeded random subset of ALL TRBs in a_trb
        trbs = np.random.choice(a_trb.obs.index, size=size, replace=False)

        X = a_trb[trbs].X.astype(np.float32).copy()
        # constant across iterations, so copy it to the GPU once
        X_cuda = torch.from_numpy(X).to('cuda')
        # track the loss values
        losses = []
        for t in range(10001):
            # derive the loss
            loss = -flow.log_prob(X_cuda).mean()
            # optimize accordingly
            optimizer.zero_grad()
            # a fresh graph is built by every forward pass, so it does not need to be retained
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        print('.', end='')

        # sample from the dataset
        for seed2 in range(5):
            torch.manual_seed(seed2); np.random.seed(seed2)
            x = flow.sample(1000).clone().cpu().detach().numpy()
            # attempt to generate some peptides
            seqs = pd.Series([
                engine(model, torch.from_numpy(x[idx, 0, :].astype(np.float32)).to('cuda'))
                for idx in range(x.shape[0])
            ])
            # count how many of the 1000 samples decode to CASSPDIEAFF
            records.append((seed, seed2, size, int((seqs == 'CASSPDIEAFF').sum())))
        print('>', end='')
    print()
df_stat = pd.DataFrame(records, columns=['seed1','seed2','size','pos_per_mil'])

In [None]:
# save the results
df_stat.to_csv('../outs/250420_AgFlow_YLQdetections.random.csv')
# read the table straight back so the plotting cell below works from the file on disk
df_stat = pd.read_csv('../outs/250420_AgFlow_YLQdetections.random.csv', index_col=0)

In [None]:
# examine the distribution limiting to seed to seed
# keep one row per training run: the one whose sampling seed equals its training seed
same_seed = df_stat['seed1'] == df_stat['seed2']
mask = same_seed & (df_stat['size'] != 797)
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
# err_kws replaces the errcolor / errwidth arguments, which are deprecated in seaborn >= 0.13
sns.barplot(x='size', y='pos_per_mil', data=df_stat[mask], ax=ax,
            saturation=1, err_kws={'color': 'grey', 'linewidth': 1.5}, capsize=0.3,
            edgecolor='grey', linewidth=1.5, color='lightgray')
# reference band: runs trained on 797 random TCRs
full_runs = df_stat.loc[same_seed & (df_stat['size'] == 797), 'pos_per_mil']
mean = full_runs.mean()
# normal-approximation 95% interval of the mean, using the actual number of runs
# NOTE(review): with only a handful of runs a t-based interval would be wider
ci95 = full_runs.std() / np.sqrt(len(full_runs)) * 1.96
ax.axhline(mean, color='grey', lw=2, linestyle='--', zorder=0,
           label='Mean Recovery When\nTrained on All (N=797)\nRandom TCRs')
# widen the axis first so the shaded band spans past the outermost bars
ax.set_xlim(-1, 7)
xmin, xmax = ax.get_xlim()
ax.fill([xmin, xmax, xmax, xmin, xmin],
        [mean-ci95, mean-ci95, mean+ci95, mean+ci95, mean-ci95],
        color='k', lw=2, linestyle='--', alpha=0.25, zorder=0, 
        label='95% Confidence Interval When\nTrained on All (N=797)\nRandom TCRs')
ax.tick_params(axis='x', labelrotation=90)
ax.set(xlabel='# of Random TCRs for Training',
       ylabel='# of Embeddings\nMapped to CASSPDIEAFF')
# final, tighter view
ax.set_xlim(-0.75, 6.75)
ax.legend(bbox_to_anchor=(1, .5), bbox_transform=ax.transAxes, loc='center left', frameon=False)

### Train AgFlow on Multiple Epitopes

In [None]:
# read in the data
df = pd.read_csv('../outs/df.int.clean.csv', index_col=0)
# walk through each pack (epitope): train a flow, sample it, and score the samples
packs = []
epitopes = ['YLQPRTFLL','NLVPMVATV','TPRVTGGGAM','GILGFVFTL','GLCTLVAML','YVLDHLIVV',
            'ELAGIGILTV','EAAGIGILTV','SLLMWITQC','KLGGALQAK','AVFDRKSDAK','RAKFKQLL',
            'IVTDFSVIK','LLWNGPMAV','SPRWYFYYL','TTDPSFLGRY','RLRAEAQVK','LLLDRLNQL',
            'LTDEMIAQY','CINGVCWTV','KTFPPTEPK','QYIKWPWYI','VMTTVLATL','DATYQRTRALVR','NQKLIANQF','FLCMKALLL']
for epitope in tqdm(epitopes):
    # define optimizer
    torch.manual_seed(0); np.random.seed(0)
    flow = RealNVP(nets, nett, masks, prior)
    optimizer = torch.optim.Adam([p for p in flow.parameters() if p.requires_grad], lr=1e-5)
    # derive the mapping dataset: TRBs annotated for this epitope that are present in a_trb
    trbs = df.loc[df['AG'] == epitope, 'TRB']
    trbs = trbs[trbs.isin(a_trb.obs.index)]
    X = a_trb[trbs].X.astype(np.float32).copy()
    # constant across iterations, so copy it to the GPU once
    X_cuda = torch.from_numpy(X).to('cuda')
    # track the loss values
    losses = []
    for t in range(10001):
        # derive the loss
        loss = -flow.log_prob(X_cuda).mean()
        # optimize accordingly
        optimizer.zero_grad()
        # a fresh graph is built by every forward pass, so it does not need to be retained
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    # sample from the dataset
    torch.manual_seed(0); np.random.seed(0)
    x = flow.sample(1000).clone().cpu().detach().numpy()
    # not used further inside the loop; kept so the last epitope's projection stays inspectable
    umap = neigh.predict(x[:, 0, :])

    # look at the range of pvalues: per-dimension KS test against the training TCRs
    # (the AnnData slice is taken once instead of once per dimension)
    X_truth = a_trb[trbs].X
    pvals_truth = [ss.ks_2samp(x[:, 0, idx], X_truth[:, idx], alternative='two-sided')[1]
                   for idx in range(a_trb.shape[1])]
    # compare against a random subset
    np.random.seed(0)
    trbs_rand = np.random.choice(a_trb.obs.index, size=len(trbs) * 3, replace=True)
    X_rand = a_trb[trbs_rand].X
    pvals_rand = [ss.ks_2samp(x[:, 0, idx], X_rand[:, idx], alternative='two-sided')[1]
                  for idx in range(a_trb.shape[1])]
    pack1 = np.mean(-np.log10(pvals_truth)), np.mean(-np.log10(pvals_rand)), ss.mannwhitneyu(pvals_truth, pvals_rand)[1]

    # attempt to generate some peptides
    seqs = pd.Series([
        engine(model, torch.from_numpy(x[idx, 0, :].astype(np.float32)).to('cuda'))
        for idx in range(x.shape[0])
    ])

    # hypergeometric where you have M animals, n are dogs, then you choose a random N from the M animals how many are dogs
    # in our scenario this is where we have the entire TCRbeta-ome supplemented with generated samples
    M  = a_trb.shape[0] + len(seqs[~seqs.isin(a_trb.obs.index)].unique())
    N = len(seqs.unique())
    n = len(trbs.unique())
    k = len(set(trbs) & set(seqs))
    # define the p-value: sf(k - 1) = P(X >= k), i.e. an overlap at least as large as the
    # observed one; the former 1 - cdf(k) = P(X > k) left out the observed count itself
    p1 = ss.hypergeom(M, n, N).sf(k - 1)

    # repeat for the random TCRs (M and N are unchanged)
    n = len(pd.Series(trbs_rand).unique())
    k = len(set(trbs_rand) & set(seqs))
    # define the p-value
    p2 = ss.hypergeom(M, n, N).sf(k - 1)
    pack2 = p1, p2
    # 'pval_nlog10p' holds the Mann-Whitney p-value comparing the two sets of KS p-values
    pack = pd.Series(pack1 + pack2, index=['mean_nlog10p_vstrue','mean_nlog10p_vsrand','pval_nlog10p','pval_vstruth','pval_vsrand'])
    pack.name = epitope
    print(epitope)
    packs.append(pack)

In [None]:
# combine the per-epitope series into one table (one column per epitope) and persist it
pack = pd.concat(packs, axis=1)
pack.to_csv('../outs/NVP.stats.csv')
# long format of the two overlap p-values: one row per (epitope, comparison)
overlap_pvals = pack.loc[['pval_vstruth','pval_vsrand']].T.reset_index()
pack_melt = overlap_pvals.melt(id_vars='index')
# the small offset keeps log10 finite when a p-value is exactly zero
pack_melt['value_melt'] = -np.log10(pack_melt['value']+1e-20)

In [None]:
# visualize across epitopes
# grouped bars of -log10(overlap p-value): one group per epitope, one bar per comparison
fig, ax = plt.subplots(figsize=[16, 4]); ax.grid(False)
sns.barplot(x='index', y='value_melt', hue='variable', data=pack_melt, saturation=1,
            palette=['skyblue','lightgray'], edgecolor='dodgerblue', linewidth=1.5)
ax.tick_params(axis='x', labelrotation=90)
ax.legend(bbox_to_anchor=(1, .5), bbox_transform=ax.transAxes, frameon=False, title=None, loc='center left')
ax.set(ylabel='-log$_{10}$(p-value) for the\nProbability of Overlap')
# the chained get_children() calls walk the artist tree under ax.get_children()[-2]
# NOTE(review): presumably that child is the legend; the first chain greys the edge of the
# second entry's swatch and the next two rewrite the two entry labels. The path depends on
# matplotlib's internal legend layout - verify after any matplotlib/seaborn upgrade
ax.get_children()[-2].get_children()[0].get_children()[1].get_children()[0].get_children()[1]\
.get_children()[0].get_children()[0].set_edgecolor('grey')
ax.get_children()[-2].get_children()[0].get_children()[1].get_children()[0].get_children()[0]\
.get_children()[1].get_children()[0].set_text('Model vs.\nTruth')
ax.get_children()[-2].get_children()[0].get_children()[1].get_children()[0].get_children()[1]\
.get_children()[1].get_children()[0].set_text('Model vs.\nRandom')
# hard-coded artist slice recoloured with a grey edge
# NOTE(review): presumably the bars of one hue level; the 52:78 range is tied to the
# current number of epitopes and to seaborn's draw order
for rect in ax.get_children()[52:78]:
    rect.set_edgecolor('grey')
ax.set_xlim(-1, len(epitopes))
ax.set(xlabel=None)

In [None]:
# plot the p-values
# the 'pval_nlog10p' row holds, per epitope, the Mann-Whitney p-value that compares the
# KS p-values of model-vs-truth with those of model-vs-random
mwu_nlog10 = -np.log10(pack.loc['pval_nlog10p'])
fig, ax = plt.subplots(figsize=[8, 4])
ax.grid(False)
ax.bar(pack.columns, mwu_nlog10, lw=1.5, edgecolor='dodgerblue', color='skyblue')
ax.tick_params(axis='x', labelrotation=90)
ax.set(ylabel='-log$_{10}$(p-value) for KS \nModel vs. Random\nand Model vs. Truth')

In [None]:
# density of a RANDOM background: 3000 TCRs drawn without replacement from all of a_trb
# NOTE(review): the previous comment called this the ground truth for GILGFVFTL, but no
# epitope filter is applied here - confirm which of the two was intended
np.random.seed(0)
trbs = np.random.choice(a_trb.obs.index, size=3000, replace=False)
# compare with the known
fig, ax = plt.subplots(); ax.grid(False)
# faint layer showing every TCR in the UMAP
ax.scatter(a_trb.obsm['X_umap'][:, 0], a_trb.obsm['X_umap'][:, 1], s=0.0001, color='lightgray')
# KDE density of the sampled points; sorting draws the densest points last (on top)
z = ss.gaussian_kde(a_trb[trbs].obsm['X_umap'].T)(a_trb[trbs].obsm['X_umap'].T); idxs = np.argsort(z)
ax.scatter(a_trb[trbs].obsm['X_umap'][idxs, 0], a_trb[trbs].obsm['X_umap'][idxs, 1], c=z[idxs], alpha=0.5, s=4, cmap='Blues')
ax.set_title(r'$x$')

### Compare AgFlow Generated TCRs with Experimentally Validated TCRs

In [None]:
# read in the data
df = pd.read_csv('../outs/df.int.clean.csv', index_col=0)

# scale (nets) and translation (nett) network factories: 32 inputs and outputs with two
# hidden layers of width 256; masks alternates [0, 1]*16 and [1, 0]*16 four times (8 rows);
# prior is a 32-dimensional normal with zero mean and identity covariance
nets = lambda: nn.Sequential(nn.Linear(32, 256), nn.LeakyReLU(), nn.Linear(256, 256), nn.LeakyReLU(), nn.Linear(256, 32), nn.Tanh()).to('cuda')
nett = lambda: nn.Sequential(nn.Linear(32, 256), nn.LeakyReLU(), nn.Linear(256, 256), nn.LeakyReLU(), nn.Linear(256, 32)).to('cuda')
masks = torch.from_numpy(np.array([[0, 1]*16, [1, 0]*16] * 4).astype(np.float32)).to('cuda')
prior = distributions.MultivariateNormal(torch.zeros(32).to('cuda'), torch.eye(32).to('cuda'))

# define optimizer
torch.manual_seed(0); np.random.seed(0)
flow = RealNVP(nets, nett, masks, prior)
optimizer = torch.optim.Adam([p for p in flow.parameters() if p.requires_grad], lr=1e-5)
# derive the mapping dataset: TRBs annotated for YLQPRTFLL that are present in a_trb
trbs = df.loc[df['AG'] == 'YLQPRTFLL', 'TRB']
trbs = trbs[trbs.isin(a_trb.obs.index)]
X = a_trb[trbs].X.astype(np.float32).copy()
# the training matrix never changes, so copy it to the GPU once instead of every iteration
X_cuda = torch.from_numpy(X).to('cuda')
# track the loss values
losses = []
for t in range(10001):
    # derive the loss: negative mean log-likelihood over the full batch
    loss = -flow.log_prob(X_cuda).mean()
    # optimize accordingly
    optimizer.zero_grad()
    # a fresh graph is built by every forward pass, so it does not need to be retained
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if t % 1000 == 0:
        print('iter %s:' % t, 'loss = %.3f' % losses[-1])

In [None]:
# we have alr settled on a stretch length
targ_len = 48

# define a function to embed an amino acid with direct, bcp, blosum, and length
def embed_aa(aa):
    """Build the feature list for a single amino acid.

    The result is the concatenation of ``map_direct[aa]``, ``map_bcp[aa]`` and
    ``map_blosum[aa]`` followed by one trailing ``0`` (the slot that
    ``stretch_pep`` later overwrites with the sequence length).

    Raises ``KeyError`` if ``aa`` is missing from any of the three lookup tables.
    """
    return [*map_direct[aa], *map_bcp[aa], *map_blosum[aa], 0]

# define a function to interpolate the protein
def stretch_pep(embedding, targ_len=targ_len):
    """Linearly resample a per-residue embedding to a fixed number of positions.

    Parameters
    ----------
    embedding : np.ndarray
        2-D array of shape ``(orig_len, n_features)``, one row per residue.
    targ_len : int
        Number of positions in the output; defaults to the module-level ``targ_len``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(n_features, targ_len)``. Every feature row is the linear
        interpolation of the corresponding input column, except the last row, which
        is overwritten with ``orig_len`` at every position.
    """
    # get the current protein length
    orig_len, n_features = embedding.shape
    # F.interpolate expects (batch, channels, length): features become channels, residues the length axis
    tensor = torch.Tensor(embedding.T.reshape(1, n_features, orig_len))
    # resample all feature channels at once, then drop the batch dimension
    res = torch.nn.functional.interpolate(tensor, size=targ_len, mode='linear', align_corners=False)[0]
    # store the original length in the last feature row (interpolation would otherwise lose it)
    res[-1, :] = orig_len
    return res

In [None]:
# read in experimentally validated data
df = pd.read_csv('../external_data/EXTERNAL_TONTCRVDB/01_05_2025_TCRvdb.csv', index_col=0)
df_ylq = df.loc[df['epitope_aa'] == 'YLQPRTFLL']
# read in the training data from public databases
df_clean = pd.read_csv('../outs/df.int.clean.csv', index_col=0)
df_clean_ylq = df_clean.loc[df_clean['AG'] == 'YLQPRTFLL'].copy()

# process the TRBs: unique experimental CDR3-beta sequences (missing values cannot be embedded)
trb_to_embed = {}
trbs = pd.Series(df_ylq['cdr3_beta_aa'].dropna().unique())
for sequence in tqdm(trbs):
    # retrieve the embedding, one row per residue
    embedding = np.array([embed_aa(aa) for aa in sequence])
    # stretch the embedding
    embedding = stretch_pep(embedding, targ_len=targ_len)
    # save the embedding
    trb_to_embed[sequence] = embedding

# embed all of our unique TRBs
X_trbs = torch.stack([x.to(torch.float32) for x in trbs.map(trb_to_embed)])
# create a latent space for the trbs; shuffle=False keeps the rows in the order of `trbs`
loader = DataLoader(dataset=TensorDataset(X_trbs), batch_size=batch_size, shuffle=False)

# get the encoded dimensions
torch.manual_seed(0); np.random.seed(0)
# move through each subset in the complete loader
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(loader):
        data = data[0].to(device)
        enc_out = model.encode(data)
        # sampling centers around the mean so we just use mu
        z_dims_per_batch.append(enc_out[0].clone().detach().cpu().numpy())
z_dims = np.vstack(z_dims_per_batch)
del data, enc_out, z_dims_per_batch

# label each encoded row with the sequence it was computed from.
# The rows of z_dims follow `trbs` (the experimental CDR3s encoded above); indexing with the
# training TRBs from df_clean_ylq instead would mislabel the rows (or fail on a length
# mismatch) and break the later `tarpon_ylq.loc[...]` lookups by experimental sequence.
assert z_dims.shape[0] == len(trbs), 'one encoded row is expected per experimental TRB'
tarpon_ylq = pd.DataFrame(z_dims, index=pd.Index(trbs.values))

In [None]:
from sklearn.model_selection import ShuffleSplit
# split the data
# target: mean log2FoldChange per experimental CDR3-beta; sequences without a value are dropped
mapping = df_ylq.groupby('cdr3_beta_aa').mean(numeric_only=True)['log2FoldChange']
y = mapping.dropna().copy()
X = tarpon_ylq.loc[y.index]
# number of random 75/25 train/test splits; also the sample size behind the confidence interval below
n_splits = 10
skf = ShuffleSplit(n_splits=n_splits, random_state=0, test_size=1/4)
rmses, prhos, srhos, coefs, preds, pred_tests, true_tests = [], [], [], [], [], [], []
for idxs_train, idxs_test in skf.split(X, y):
    # instantiate the linear regression model
    clf = LinearRegression()
    clf = clf.fit(X.iloc[idxs_train], y.iloc[idxs_train])
    # derive the held-out predictions
    pred = clf.predict(X.iloc[idxs_test])
    true = y.iloc[idxs_test]
    pred_tests.append(pd.Series(pred, index=X.index[idxs_test]))
    true_tests.append(pd.Series(true, index=X.index[idxs_test]))
    rmses.append(np.sqrt(np.mean((pred - true) ** 2)))
    prhos.append(ss.pearsonr(pred, true)[0])
    srhos.append(ss.spearmanr(pred, true)[0])
    # NOTE(review): in this part of the notebook `aylq` is only assembled in a later cell;
    # confirm it is defined before this cell, otherwise a fresh Run All fails here
    preds.append(clf.predict(aylq.X))
    coefs.append(clf.coef_)
# evaluate models
# seed so the stripplot jitter is reproducible
np.random.seed(0)
fig, axs = plt.subplots(1, 3, figsize=[5, 3])
# one panel per metric: box plot with the individual splits overlaid
for ax, values, label in zip(axs, [rmses, prhos, srhos], ['RMSE', 'Pearson Rho', 'Spearman Rho']):
    ax.grid(False)
    sns.boxplot(values, ax=ax, saturation=1, linewidth=1.5, linecolor='dodgerblue', color='skyblue')
    sns.stripplot(values, ax=ax, jitter=0.3, alpha=0.5, linewidth=1.5, edgecolor='dodgerblue', color='skyblue')
    ax.set_ylabel(label)
fig.tight_layout()
# mean Pearson rho and the half-width of its normal-approximation 95% interval across splits
print(np.mean(prhos), np.std(prhos) / np.sqrt(len(prhos)) * 1.96)

In [None]:
# examine the TCRs that are well validated by Ton's dataset
fig, ax = plt.subplots(figsize=[4, 4]); ax.grid(False)
# pool the held-out predictions and measurements of every split
pred, true = pd.concat(pred_tests, axis=0), pd.concat(true_tests, axis=0)
# draw on `ax` explicitly instead of relying on it being the current axes
sns.kdeplot(x=pred, y=true, bw_adjust=0.8, thresh=0.2, alpha=0.8, levels=8, fill=True, cmap='Blues', ax=ax)
ax.scatter(pred, true, alpha=0.25, s=1, color='skyblue')
ax.set(xlabel='Predicted YLQ Specificity', ylabel='Experimental YLQ Specificity')
ax.set_xlim(-1.4, 2.7); ax.set_ylim(-1.7, 4.5)
# determine rhos: one Pearson correlation per split
rhos = [ss.pearsonr(split_pred, split_true)[0] for split_pred, split_true in zip(pred_tests, true_tests)]
# mean and 95% interval half-width, using the actual number of splits rather than a hard-coded 10
np.mean(rhos), np.std(rhos) / np.sqrt(len(rhos)) * 1.96

In [None]:
# sample from the dataset 1M times (100 calls to flow.sample with 10000 each)
torch.manual_seed(0); np.random.seed(0)
xs = [flow.sample(10000).clone().cpu().detach().numpy() for _ in tqdm(range(100))]
x = np.vstack(xs)
# attempt to generate some peptides: pass every sampled vector through `engine`, one at a time
n_samples = x.shape[0]
seqs = pd.Series([
    engine(model, torch.from_numpy(x[idx, 0, :].astype(np.float32)).to('cuda'))
    for idx in tqdm(range(n_samples), total=n_samples)
])
# distinct generated sequences, in order of first appearance
seqs_unique = pd.Series(seqs.unique())

In [None]:
# read in the data
df_int = pd.read_csv('../outs/df.int.clean.csv', index_col=0)
# training set: unique YLQPRTFLL-annotated TRBs that are present in a_trb
is_ylq = df_int['AG'] == 'YLQPRTFLL'
trbs = df_int.loc[is_ylq, 'TRB']
in_a_trb = trbs.isin(a_trb.obs.index)
trbs_training = trbs[in_a_trb].unique()
# generated set: distinct sequences produced by the sampling cell
trbs_generated = seqs_unique
# tested set: sequences that index the experimental embedding table
trbs_tested = tarpon_ylq.index
# get the list of all TRBs (union of the three sets)
trbs = trbs_tested.union(trbs_training).union(trbs_generated)

In [None]:
# process all YLQ TRBs
trbs = pd.Series(trbs)
# sequence -> stretched per-residue embedding
trb_to_embed = {}
for sequence in tqdm(trbs):
    per_residue = np.array([embed_aa(aa) for aa in sequence])
    embedding = stretch_pep(per_residue, targ_len=targ_len)
    trb_to_embed[sequence] = embedding

# embed all of our unique TRBs, stacked in the order of `trbs`
embedded = trbs.map(trb_to_embed)
X_trbs = torch.stack([x.to(torch.float32) for x in embedded])
# shuffle=False keeps the encoded rows aligned with `trbs`
loader = DataLoader(dataset=TensorDataset(X_trbs), batch_size=batch_size, shuffle=False)

# get the encoded dimensions
torch.manual_seed(0); np.random.seed(0)
# move through each subset in the complete loader
z_dims_per_batch = []
with torch.no_grad():
    for data in tqdm(loader):
        data = data[0].to(device)
        enc_out = model.encode(data)
        # sampling centers around the mean so we just use mu (first element of the encoder output)
        mu = enc_out[0]
        z_dims_per_batch.append(mu.clone().detach().cpu().numpy())
z_dims = np.vstack(z_dims_per_batch)
# free the per-batch temporaries
del data, enc_out, mu, z_dims_per_batch

In [None]:
from anndata import AnnData
# assemble the anndata object: one observation per TRB, features are the encoded dimensions
aylq = AnnData(z_dims)
aylq.obs.index = trbs
# derive where the TCR came from (1 = member of the set, 0 = not)
aylq.obs['was_generated'] = aylq.obs.index.isin(trbs_generated).astype(int)
aylq.obs['was_experimentalDB'] = aylq.obs.index.isin(trbs_tested).astype(int)
# neighbor graph directly on X, then UMAP; both seeded for reproducibility
sc.pp.neighbors(aylq, use_rep='X', random_state=0)
sc.tl.umap(aylq, random_state=0)

In [None]:
# plot on the logFC from the experimental data
# log10 of how often each sequence was generated (NaN for sequences that were never generated)
generated_counts = seqs.value_counts()
aylq.obs['enrich_generation'] = np.log10(aylq.obs.index.map(generated_counts))
# negated mean log2FoldChange from `mapping` (NaN for sequences without a measurement)
aylq.obs['enrich_experimental'] = -aylq.obs.index.map(mapping)
mask = aylq.obs['enrich_experimental'].notna()
# background: observations without an experimental value, drawn small
ax = sc.pl.umap(aylq[~mask], show=False, s=0.5)
# foreground: observations with an experimental value, colored by it on the same axes
overlay_style = dict(s=40, vmin=-1, vmax=5, edgecolor='k', linewidth=0.5, alpha=0.9)
sc.pl.umap(aylq[mask], color=['enrich_experimental'], cmap='Blues', ax=ax, **overlay_style)
# write the data
aylq.write('250605_sampled_1M_YLQ_TCRs.h5ad')

In [None]:
# plot on the generated number of YLQs
ax = sc.pl.umap(aylq, show=False)
# remember the axis limits so the density overlay cannot rescale the plot
xlim, ylim = ax.get_xlim(), ax.get_ylim()
# randomly sample the data: draw generated sequences proportionally to how often they were generated
mask = aylq.obs['was_generated'] == 1
np.random.seed(0)
generated_names = aylq.obs.index[mask]
counts = seqs.value_counts().loc[generated_names]
percs = counts / counts.sum()
idxs = np.random.choice(aylq.obs.index[mask], size=10000, replace=True, p=percs)
# UMAP coordinates of the (possibly repeated) sampled sequences
umap_sampled = aylq[idxs].obsm['X_umap']
sns.kdeplot(x=umap_sampled[:, 0], y=umap_sampled[:, 1],
            cmap='Blues', thresh=0.2, fill=True, alpha=0.8, bw_adjust=0.8)
ax.set_xlim(*xlim); ax.set_ylim(*ylim)

In [None]:
# derive the average and 95% confidence interval of experimental by prediction
means, ci95s = [], []
# bin edges in log10(generation count): [0, 1), [1, 2), [2, 3) and an open-ended last bin [3, inf)
xs = np.arange(0, 4+1, 1)
for idx, vmax in enumerate(xs[1:]):
    vmin = xs[idx]
    mask = (aylq.obs['enrich_generation'] >= vmin)
    # the last bin has no upper bound, matching its '≥' tick label; the previous check
    # (`vmax != 10`) never triggered, so sequences at or above the last edge were dropped
    if vmax != xs[-1]:
        mask = mask & (aylq.obs['enrich_generation'] < vmax)
    # only sequences with an experimental measurement contribute to the statistics
    vals = aylq.obs.loc[mask, 'enrich_experimental'].dropna()
    means.append(vals.mean())
    # standard error uses the number of measured values, not the number of rows in the bin
    # (most rows in a bin have no measurement and are ignored by mean/std)
    ci95s.append(vals.std() / np.sqrt(len(vals)) * 1.96)
# compare the percentages
fig, ax = plt.subplots(figsize=[3, 4])
ax.grid(False)
ax.scatter(xs[1:], means, color='dodgerblue')
# hand-drawn error bars: a vertical line plus a short cap at each end
for idx, x in enumerate(xs[1:]):
    mean = means[idx]
    ci95 = ci95s[idx]
    ax.plot([x]*2, [mean-ci95, mean+ci95], color='dodgerblue')
    ax.plot([x-0.05, x+0.05], [mean-ci95]*2, color='dodgerblue')
    ax.plot([x-0.05, x+0.05], [mean+ci95]*2, color='dodgerblue')
ax.set(xlabel='log10(# of CDR3s)\nfrom YLQ AgFlow Model', ylabel='Experimental YLQ Specificity\nfrom Messemaker et al. 2025')
ax.set_xticks(np.arange(1, 4+1, 1))
ax.set_xticklabels([f'[10$^{vmax-1}$, 10$^{vmax}$)' if vmax != 4 else '[10$^3$, ≥10$^4$)' for vmax in np.arange(1, 4+1, 1)])
ax.tick_params(axis='x', labelrotation=90)
# correlation between the bin position and the bin mean
ss.pearsonr(xs[1:], means)