In [1]:
from collections import Counter

import pandas as pd
import polars as pl

import abutils
import abstar

Import and Align SARS-CoV2 mAbs 

In [2]:
df = pd.read_csv('../TXG-20220218.csv')

In [3]:
txg_df = df[df["Identifier"].str.startswith("TXG")]
heavies = [abutils.Sequence(s, id=i) for s, i in zip(txg_df["vj_seq1"], txg_df["Identifier"])]
lights = [abutils.Sequence(s, id=i) for s, i in zip(txg_df["vj_seq2"], txg_df["Identifier"])]

In [4]:
seqs = abstar.run(heavies + lights, output_type="airr")
pairs = abutils.core.pair.assign_pairs(seqs)


Running abstar...
(1/1) ||||||||||||||||||||||||||||||||||||||||||||||||||||  100%

478 sequences contained an identifiable rearrangement
abstar completed in 10.86 seconds



In [5]:
counterh = Counter([p.heavy["sequence_aa"] for p in pairs])
counterl = Counter([p.light["sequence_aa"] for p in pairs])

seq_aa_l = []
germ_aa_l = []

h_alignments = []
l_alignments = []

for p in pairs:
    hgerm_aa = abutils.tl.translate(p.heavy["germline_alignment"])
    hseq_aa = abutils.tl.translate(p.heavy["sequence_alignment"])
    
    lgerm_aa = abutils.tl.translate(p.light["germline_alignment"])
    lseq_aa = abutils.tl.translate(p.light["sequence_alignment"])
    
    germ_aa = f'{hgerm_aa}<cls><cls>{lgerm_aa}'
    germ_aa_l.append(germ_aa)
    
    seq_aa = f'{hseq_aa}<cls><cls>{lseq_aa}'
    if seq_aa not in seq_aa_l:
        haln = abutils.tl.global_alignment(hseq_aa, hgerm_aa, gap_open=25)
        h_alignments.append(haln)
        
        laln = abutils.tl.global_alignment(lseq_aa, lgerm_aa, gap_open=25)
        l_alignments.append(laln)
        seq_aa_l.append(seq_aa)

Retrieve Scores from Balm

In [6]:
from datetime import date
import os
from tqdm.notebook import tqdm
import torch

from transformers import ( #for BALM
    RobertaConfig,
    RobertaTokenizer,
    RobertaForMaskedLM,
    DataCollatorForLanguageModeling,
    AutoTokenizer,
    AutoModelForMaskedLM
)

from Bio import SeqIO
import re

In [7]:
model = AutoModelForMaskedLM.from_pretrained(
    '/models/BALM-paired_LC-coherence_90-5-5-split_122222/'
).to('cuda')

In [8]:
seqs = germ_aa_l
seqs = [s.replace('*', '.') for s in seqs]
seq_names = [p.heavy.annotations['sequence_id'] for p in pairs]

In [9]:
tokenizer = RobertaTokenizer.from_pretrained("/pre-training/balm/tokenizer")
vocab = tokenizer.get_vocab()

In [10]:
masked_seqs = []
masked_seqs_tok = []
masked_seq_ids = []
for s in tqdm(range(len(seqs))):
    seq = seqs[s]
    seq_id = seq_names[s]
    for aa in range(len(seq)):
        #print(seq)
        if (seq[aa] != '<') or (seq[aa] != '>') or (seq[aa] != 'l') or (seq[aa] != 'c') or (seq[aa] != 's'):
            masked_seq = seq[:aa] + '<mask>' + seq[aa+1:]
            masked_seqs.append(masked_seq)
            masked_seqs_tok.append(tokenizer(masked_seq, return_tensors='pt', max_length = 512, padding = 'max_length').to('cuda'))
            masked_seq_ids.append(seq_id)
        else:
            print(seq[aa])
    
# finalize inputs
inputs = list(zip(masked_seq_ids, masked_seqs, masked_seqs_tok))

  0%|          | 0/239 [00:00<?, ?it/s]

In [11]:
output = {}
with torch.no_grad():
    for name, seq, tokens in tqdm(inputs):
        o = model(**tokens, output_hidden_states=True, return_dict=True,)
        softmax = torch.nn.Softmax(dim=1)
        logits = softmax(o["logits"]).to(device="cpu").numpy()[0] #probabilities
        mask_pos = seq.index("<mask>")
        name2 = f'{name}_pos{mask_pos}'
        output[name2] = logits[mask_pos+1]

  0%|          | 0/57619 [00:00<?, ?it/s]

In [12]:
output_df = pd.DataFrame.from_dict(output)
output_df.to_csv("./TXG_mAbs_maskedprobs_BALM.csv")

Compare Mutation Predictions to Germline Scores

In [32]:
#import balm scores
output_df1 = pd.read_csv("./TXG_mAbs_maskedprobs_BALM.csv")
output_df1.index = output_df1.iloc[:, 0]
output_df1 = output_df1.iloc[:, 1:]
output_df1.index = list(vocab.keys())
output_df1= output_df1.T
max_prob_aas = output_df1.idxmax(axis=1)
output_df1["max_prob_aas"] = max_prob_aas
output_df1

Unnamed: 0,<s>,</s>,<pad>,<unk>,<mask>,A,C,D,E,F,...,N,P,Q,R,S,T,V,W,Y,max_prob_aas
TXG-0001_pos0,0.000537,0.000090,0.000650,0.000326,0.001674,6.712171e-10,4.991637e-16,1.158105e-11,7.386458e-08,4.245692e-12,...,2.950762e-10,1.893537e-07,1.801331e-02,1.252880e-06,1.139298e-10,1.441798e-10,6.473079e-09,1.292268e-12,3.689247e-10,Q
TXG-0001_pos1,0.000639,0.000114,0.000845,0.000676,0.000499,9.738896e-06,9.418030e-15,5.186145e-10,1.659513e-05,1.576908e-09,...,1.614860e-10,4.455026e-09,4.809455e-10,5.650238e-10,1.477380e-10,3.649209e-09,4.525929e-02,3.678126e-12,3.777115e-11,V
TXG-0001_pos2,0.000522,0.000324,0.000489,0.000354,0.000884,8.813778e-10,2.201597e-15,1.762267e-10,1.580648e-07,1.601118e-11,...,2.591281e-09,6.180144e-08,2.800575e-04,8.414578e-08,2.852198e-10,1.557890e-10,3.811031e-09,1.883095e-12,3.328436e-09,<mask>
TXG-0001_pos3,0.002038,0.000266,0.001731,0.002107,0.002279,5.209519e-08,5.717766e-15,2.655331e-11,3.010905e-09,3.635894e-09,...,1.943725e-09,9.711414e-08,2.404738e-08,6.017458e-09,2.669460e-09,3.526159e-09,6.364827e-05,1.832541e-11,1.316293e-10,<mask>
TXG-0001_pos4,0.001183,0.000502,0.001328,0.001170,0.000808,5.389254e-06,3.751557e-15,1.608833e-10,6.926930e-06,3.513015e-10,...,7.081735e-11,2.015917e-09,4.360423e-10,4.123664e-10,1.460642e-10,2.483344e-09,7.729272e-02,1.291765e-11,3.321982e-11,V
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TXG-0239_pos237,0.001718,0.000399,0.001758,0.001001,0.000497,1.742520e-09,2.227548e-13,3.580414e-13,9.805512e-09,1.601680e-11,...,2.695992e-06,3.435909e-10,1.637434e-09,2.537776e-06,1.220414e-09,1.994869e-09,2.787094e-09,6.018908e-12,3.388221e-09,<pad>
TXG-0239_pos238,0.009470,0.003074,0.004261,0.008495,0.008474,3.185831e-10,2.508774e-14,1.760521e-14,1.166252e-10,1.021620e-10,...,4.833691e-09,6.611445e-11,9.398131e-13,2.646390e-11,3.475650e-11,1.962583e-11,1.884103e-06,3.982638e-12,2.262775e-11,<s>
TXG-0239_pos239,0.002299,0.000053,0.003280,0.000717,0.000697,1.215524e-07,1.145997e-13,4.051752e-10,9.923290e-02,1.039535e-10,...,8.454901e-08,1.413235e-10,3.676097e-09,1.851649e-09,5.463743e-11,6.346846e-10,3.977817e-08,4.982131e-11,2.367515e-09,E
TXG-0239_pos240,0.002806,0.000020,0.001437,0.000612,0.000194,1.245993e-08,1.309502e-12,4.967889e-13,1.786975e-10,2.143679e-08,...,3.133481e-06,2.151430e-09,1.089380e-12,5.694285e-09,5.843372e-08,8.360015e-08,1.458314e-06,5.030363e-13,3.944383e-10,I


In [31]:
#clean and prep sequences
seqs = germ_aa_l
for s in range(len(seqs)):
    seqs[s] = re.sub(r'<cls><cls>', '__________', seqs[s])
    seqs[s] = re.sub(r'-', '_', seqs[s])
    seqs[s] = re.sub(r'\*', '_', seqs[s])
    seqs[s] = re.sub(r'X', '_', seqs[s])

In [27]:
wt_dict = {}
wt_prob_d = {}
for row in tqdm(output_df1.index):
    for i, s in list(zip(seq_names, seqs)):
        for pos in range(len(s)):
            if f"{i}_pos{pos}" == str(row):
                wt = s[pos]
                wt_dict[f"{i}_pos{pos}"] = wt
                if (wt != '_'):
                    wt_prob = output_df1.loc[row, wt]
                else:
                    wt_prob = '0'
                wt_prob_d[f"{i}_pos{pos}"]= wt_prob

  0%|          | 0/57619 [00:00<?, ?it/s]

In [33]:
#merging and cleaning dataframes
wt_df = pd.DataFrame.from_dict(wt_dict, orient='index', columns=['wt'])
wt_df = wt_df.reset_index().rename(columns={'index': 'alias'})
output_df1 = output_df1.reset_index().rename(columns={'index': 'alias'})
output_df1 = pd.merge(output_df1, wt_df, on='alias', how='inner')
wt_prob_df = pd.DataFrame.from_dict(wt_prob_d, orient='index', columns=['wt_prob'])
wt_prob_df = wt_prob_df.reset_index().rename(columns={'index': 'alias'})
output_df1 = pd.merge(output_df1, wt_prob_df, on='alias', how='inner')

In [36]:
#creating final dataframe
import math
e = math.e
max_probs = output_df1.max(axis=1).to_list()
max_prob_aas = output_df1.max_prob_aas.to_list()
mAbs_pos = output_df1.alias.to_list()
wt_list = output_df1.wt.to_list()
wt_prob_l = output_df1.wt_prob.to_list()
df_max2wt = pd.DataFrame({'mAbs_pos' : mAbs_pos, 'wt' : wt_list,'max_prob_aa' : max_prob_aas, 'wt_prob': wt_prob_l, 'max_prob' : max_probs})
df_max2wt = df_max2wt[df_max2wt['wt'] != '_']
df_max2wt['max/wt_ratio'] = (df_max2wt['max_prob'] / df_max2wt['wt_prob'])
df_max2wt.to_csv("./TXG_mAbs_BALM_maxmask2wt.csv")