In [1]:
from collections import Counter

import pandas as pd
import polars as pl

import abutils
import abstar

Import and Align SARS-CoV-2 mAbs

In [2]:
# Load the raw mAb table; the TXG rows and their vj_seq1 / vj_seq2 columns are used below.
txg_csv_path = '../TXG-20220218.csv'
df = pd.read_csv(txg_csv_path)

In [3]:
# Keep only the TXG antibodies and wrap each chain as an abutils Sequence named by
# the antibody's Identifier (heavy chain from vj_seq1, light chain from vj_seq2),
# so both chains of one antibody carry the same ID.
is_txg = df["Identifier"].str.startswith("TXG")
txg_df = df[is_txg]
txg_ids = txg_df["Identifier"]
heavies = [abutils.Sequence(heavy_seq, id=mab_id) for heavy_seq, mab_id in zip(txg_df["vj_seq1"], txg_ids)]
lights = [abutils.Sequence(light_seq, id=mab_id) for light_seq, mab_id in zip(txg_df["vj_seq2"], txg_ids)]

In [6]:
# Annotate heavy and light chains in a single abstar run (AIRR output), then group
# the annotations into heavy/light pairs. Both chains of an antibody were given the
# same Identifier above — presumably that shared ID is what assign_pairs pairs on
# (verify against the abutils docs).
seqs = abstar.run([*heavies, *lights], output_type="airr")
pairs = abutils.core.pair.assign_pairs(seqs)


Running abstar...
(1/1) ||||||||||||||||||||||||||||||||||||||||||||||||||||  100%

478 sequences contained an identifiable rearrangement
abstar completed in 10.88 seconds



In [7]:
# Frequency of each heavy / light amino-acid sequence across the pairs
# (not used further in the visible cells).
counterh = Counter(pair.heavy["sequence_aa"] for pair in pairs)
counterl = Counter(pair.light["sequence_aa"] for pair in pairs)

# germ_aa_l    : one paired germline string "<heavy><cls><cls><light>" per pair, in `pairs` order.
# seq_aa_l     : unique paired observed strings, in first-seen order.
# h_alignments / l_alignments : observed-vs-germline alignments, parallel to seq_aa_l.
germ_aa_l = []
seq_aa_l = []
h_alignments = []
l_alignments = []

for pair in pairs:
    # Translate the germline and observed alignments of each chain to amino acids.
    hgerm_aa, hseq_aa = (
        abutils.tl.translate(pair.heavy[field])
        for field in ("germline_alignment", "sequence_alignment")
    )
    lgerm_aa, lseq_aa = (
        abutils.tl.translate(pair.light[field])
        for field in ("germline_alignment", "sequence_alignment")
    )

    germ_aa_l.append(f'{hgerm_aa}<cls><cls>{lgerm_aa}')

    seq_aa = f'{hseq_aa}<cls><cls>{lseq_aa}'
    if seq_aa in seq_aa_l:
        # This exact paired sequence has already been aligned; skip the duplicate.
        continue
    # gap_open=25 is a steep gap-opening penalty — presumably to discourage gapped
    # alignments against the germline (confirm intent).
    h_alignments.append(abutils.tl.global_alignment(hseq_aa, hgerm_aa, gap_open=25))
    l_alignments.append(abutils.tl.global_alignment(lseq_aa, lgerm_aa, gap_open=25))
    seq_aa_l.append(seq_aa)

Retrieve Masked-Position Scores from BALM

In [8]:
from datetime import date
import os
from tqdm.notebook import tqdm
import torch

from ..balm.config import BalmConfig, BalmMoEConfig
from ..balm.data import load_dataset, DataCollator
from ..balm.models import (
    BalmForMaskedLM,
    BalmModel,
    BalmMoEForMaskedLM,
)
from ..balm.tokenizer import Tokenizer

from Bio import SeqIO
import re

In [9]:
# Load the trained BALM-MoE masked-LM checkpoint and move it to the GPU.
balm_checkpoint_dir = "../training_runs/balmMoE_expertchoiceBig_1shared_altern_052924/model/"
model = BalmMoEForMaskedLM.from_pretrained(balm_checkpoint_dir)
model = model.to('cuda')

All model checkpoint weights were used when initializing BalmMoEForMaskedLM.

All the weights of BalmMoEForMaskedLM were initialized from the model checkpoint at /home/jovyan/shared/simone/BALM_development_51324/training_runs/balmMoE_expertchoiceBig_1shared_altern_052924/balmMoE_expertchoiceBig_1shared_altern_052924/model/model.pt.
If your task is similar to the task the model of the checkpoint was trained on, you can already use BalmMoEForMaskedLM for predictions without further training.


In [10]:
# Model inputs are the paired germline strings, with '*' (presumably stop codons
# from the translation) rewritten as '.'. seq_names[k] is the heavy-chain
# sequence ID of pairs[k], so it lines up with seqs[k].
seqs = [germline.replace('*', '.') for germline in germ_aa_l]
seq_names = [pair.heavy.annotations['sequence_id'] for pair in pairs]

In [11]:
# BALM tokenizer; `vocab` is its token table, reused later to label the
# per-token probability vectors.
vocab_file = "../balm/vocab.json"
tokenizer = Tokenizer(vocab=vocab_file)
vocab = tokenizer.vocab

In [12]:
# Build one masked input per residue of every paired germline string.
# The residue is REPLACED by <mask>: inserting <mask> in front of it would leave
# the true residue visible to the model and shift every later position by one.
# Characters of the "<cls><cls>" chain separator are never masked.
SEPARATOR_CHARS = set('<>cls')  # only occur inside "<cls>"; assumes residues are upper-case, as before
MAX_TOKENS = 320  # padded token length fed to the model

masked_seqs = []
masked_seqs_tok = []
masked_seq_ids = []
for seq, seq_id in tqdm(zip(seqs, seq_names), total=len(seqs)):
    for aa, residue in enumerate(seq):
        if residue in SEPARATOR_CHARS:
            continue
        masked_seq = seq[:aa] + '<mask>' + seq[aa + 1:]
        masked_seqs.append(masked_seq)
        masked_seqs_tok.append(
            tokenizer(
                masked_seq,
                return_tensors='pt',
                max_length=MAX_TOKENS,
                padding='max_length',
            )['input_ids'][0].to('cuda')
        )
        masked_seq_ids.append(seq_id)

# finalize inputs: (sequence ID, masked string, token tensor) triples
inputs = list(zip(masked_seq_ids, masked_seqs, masked_seqs_tok))

  0%|          | 0/239 [00:00<?, ?it/s]

In [13]:
# Score each masked input with BALM-MoE and keep the predicted distribution over
# the vocabulary at the masked position. Keys are "<seq_id>_pos<i>", where i is the
# CHARACTER index of <mask> in the masked string (downstream cells rely on this).
mask_token_id = vocab["<mask>"]  # assumes tokenizer.vocab maps token -> id — TODO confirm
model.eval()  # make sure training-only behaviour (e.g. dropout) is off while scoring

output = {}
with torch.no_grad():
    for name, seq, tokens in tqdm(inputs):
        o = model(tokens.unsqueeze(0), output_hidden_states=True, return_dict=True)
        # logits are indexed [batch, position, vocab]; normalise over the vocabulary
        # axis. Softmax(dim=1) would normalise across sequence positions instead.
        probs = torch.softmax(o["logits"], dim=-1)[0].to(device="cpu").numpy()
        mask_pos = seq.index("<mask>")
        # Locate the <mask> token directly. "character index + 1" is only the token
        # index while no multi-character token precedes the mask; it is wrong for
        # every light-chain position, where "<cls><cls>" (10 chars) is 2 tokens.
        mask_tok_idx = (tokens == mask_token_id).nonzero(as_tuple=True)[0][0].item()
        output[f'{name}_pos{mask_pos}'] = probs[mask_tok_idx]

  0%|          | 0/55229 [00:00<?, ?it/s]

In [14]:
# Save the scores: each column is one masked position ("<id>_pos<i>"), each row
# one vocabulary index.
masked_probs_csv = "./TXG_mAbs_maskedprobs_BALMMoE.csv"
output_df = pd.DataFrame(output)
output_df.to_csv(masked_probs_csv)

Compare Mutation Predictions to Germline Scores

In [26]:
# Reload the saved BALM-MoE scores and reshape them to one row per masked position
# ("<id>_pos<i>") and one column per vocabulary token. Then record the most
# probable token for each position.
output_df1 = pd.read_csv("./TXG_mAbs_maskedprobs_BALMMoE.csv", index_col=0)
# Label the rows (vocabulary indices) with their tokens in id order. This assumes
# tokenizer.vocab maps token -> integer id. Sorting by id guards against a vocab
# dict whose iteration order differs from id order.
output_df1.index = sorted(vocab, key=vocab.get)
output_df1 = output_df1.T
max_prob_aas = output_df1.idxmax(axis=1)
output_df1["max_prob_aas"] = max_prob_aas

In [27]:
# Clean the germline strings so that each character index matches the
# "<id>_pos<i>" keys built from the masked inputs:
#  - the 10-character "<cls><cls>" separator becomes exactly 10 '_' (a 9-underscore
#    replacement shifts every light-chain position by one),
#  - alignment gaps '-' become '_',
#  - '*' becomes '.' (normally already done when `seqs` was built).
CHAIN_SEPARATOR = '<cls><cls>'
for k, germline in enumerate(seqs):
    seqs[k] = (
        germline.replace(CHAIN_SEPARATOR, '_' * len(CHAIN_SEPARATOR))
        .replace('-', '_')
        .replace('*', '.')
    )

In [28]:
# For every masked position, look up the wild-type (germline) character and the
# model's probability for it. Gap positions ('_') get probability 0.0.
# Build the "<id>_pos<i>" -> residue map once, instead of scanning every sequence
# for every row (previously O(rows * seqs * length)). If IDs repeat, later
# sequences overwrite earlier ones, exactly as before.
wt_lookup = {
    f"{i}_pos{pos}": residue
    for i, s in zip(seq_names, seqs)
    for pos, residue in enumerate(s)
}

wt_dict = {}
wt_prob_d = {}
for row in tqdm(output_df1.index):
    key = str(row)
    if key not in wt_lookup:
        continue
    wt = wt_lookup[key]
    wt_dict[key] = wt
    # Numeric 0.0 (not the string '0') keeps the wt_prob column numeric.
    wt_prob_d[key] = output_df1.loc[row, wt] if wt != '_' else 0.0

  0%|          | 0/55229 [00:00<?, ?it/s]

In [29]:
# Attach the wild-type residue and its probability to the score table, keyed by
# the "<id>_pos<i>" alias. Inner joins keep only positions present in all three.
output_df1 = output_df1.reset_index().rename(columns={'index': 'alias'})
wt_df = (
    pd.DataFrame.from_dict(wt_dict, orient='index', columns=['wt'])
    .rename_axis('alias')
    .reset_index()
)
wt_prob_df = (
    pd.DataFrame.from_dict(wt_prob_d, orient='index', columns=['wt_prob'])
    .rename_axis('alias')
    .reset_index()
)
output_df1 = (
    output_df1
    .merge(wt_df, on='alias', how='inner')
    .merge(wt_prob_df, on='alias', how='inner')
)

In [30]:
# Build the final per-position table: wild-type residue, most probable token, both
# probabilities and their ratio (how much more likely the model's top choice is
# than the germline residue). Gap positions ('_') are dropped.
import math
e = math.e  # not used in this cell

prob_cols = list(vocab.keys())  # one probability column per vocabulary token
# Row-wise max over the token columns only. A max over the whole frame mixes strings
# (alias, wt, max_prob_aas) with floats, which raises TypeError on pandas >= 2.0.
max_probs = output_df1[prob_cols].max(axis=1).to_list()
max_prob_aas = output_df1.max_prob_aas.to_list()
mAbs_pos = output_df1.alias.to_list()
wt_list = output_df1.wt.to_list()
wt_prob_l = pd.to_numeric(output_df1.wt_prob).to_list()  # tolerate legacy '0' strings
df_max2wt = pd.DataFrame({
    'mAbs_pos': mAbs_pos,
    'wt': wt_list,
    'max_prob_aa': max_prob_aas,
    'wt_prob': wt_prob_l,
    'max_prob': max_probs,
})
# .copy() so adding the ratio column modifies a real frame, not a view of a slice.
df_max2wt = df_max2wt[df_max2wt['wt'] != '_'].copy()
df_max2wt['max/wt_ratio'] = df_max2wt['max_prob'] / df_max2wt['wt_prob']
df_max2wt.to_csv("./TXG_mAbs_BALMMoE_maxmask2wt.csv")