# Mutation masking inference for BALM-paired

In [None]:
import copy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm

import numpy as np
import pandas as pd

from transformers import RobertaTokenizer, RobertaForMaskedLM

## load the model

In [None]:
# replace with the actual path to the BALM-paired model;
# this value is what gets passed to RobertaForMaskedLM.from_pretrained
model_path = './BALM-paired/'

In [None]:
# load the masked-LM weights from model_path, then move the model to the GPU
model = RobertaForMaskedLM.from_pretrained(model_path)
model = model.to('cuda')

## tokenizer

In [None]:
# replace with the actual tokenizer path (resolved relative to the working directory)
tokenizer = RobertaTokenizer.from_pretrained('../tokenizer/')

## inference function

In [None]:
def infer(
    model,
    tokenizer,
    pair_ids,
    inputs,
    labels,
    germs,
    device='cuda'
):
    '''
    Run masked-token inference and return one row per
    (masked position, vocabulary token) for every sequence.

    inputs and labels should already be tokenized

    labels should just be the 'input_ids' data, not the whole tokenized dict
    (germs is used the same way as labels)

    Parameters
    ----------
    model
        called as ``model(**encoded, labels=...)``; the result must expose
        ``loss`` and ``logits``
    tokenizer
        must provide ``mask_token_id``, ``vocab_size`` and ``decode``
    pair_ids
        one name per sequence, copied into the "pair_id" field
    inputs
        tokenized, masked sequences; only batch row 0 of each one is read
    labels
        token ids of the unmasked sequences, same shape as the input ids
    germs
        token ids of the germline sequences, same shape as the input ids
    device
        currently unused: all tensors are expected to already be on the
        same device as the model. Kept so existing calls keep working.

    Returns
    -------
    list of dict with keys "pair_id", "perplexity", "loss", "mask_position",
    "prediction", "germline", "actual", "token", "logit" and "softmax".

    Notes
    -----
    pair_ids, inputs, labels and germs are zipped together, so iteration
    stops at the shortest of the four.
    '''
    # the decoded string of every vocabulary id is the same for every
    # sequence and masked position, so build the lookup table only once
    vocab_tokens = [tokenizer.decode(y) for y in range(tokenizer.vocab_size)]

    data = []

    with torch.no_grad():
        pbar = tqdm(list(zip(pair_ids, inputs, labels, germs)))
        for name, i, l, g in pbar:
            is_masked = i.input_ids == tokenizer.mask_token_id
            # 1-D tensor holding the indices of the masked tokens in row 0
            mask_positions = is_masked[0].nonzero(as_tuple=True)[0]

            # only masked positions keep their label; everything else is
            # set to -100 (the ignore index of the Hugging Face MLM loss)
            labels_ = torch.where(is_masked, l, -100)
            o = model(**i, labels=labels_)

            # loss
            loss = o.loss.item()

            # PPL
            perplexity = float(torch.exp(o.loss))

            # germlines
            # torch.where is kept so that a germline tensor whose shape does
            # not match the input still raises instead of being misread
            germs_ = torch.where(is_masked, g, -100)

            # Decode token by token and keep one string per masked position.
            # Joining the decoded text and indexing it per character would
            # shift every later position as soon as one token decodes to
            # more than one character (e.g. a special token).
            germ = [tokenizer.decode(t) for t in germs_[0, mask_positions].tolist()]

            # ground truth
            actual = [tokenizer.decode(t) for t in labels_[0, mask_positions].tolist()]

            # logits, shape (number of masked positions, model vocabulary)
            logits = o.logits[0, mask_positions]
            softmax = torch.softmax(logits, dim=-1)

            # predictions
            pred_tokens = logits.argmax(dim=-1).tolist()
            predictions = [tokenizer.decode(t) for t in pred_tokens]

            # convert to Python floats in one go instead of calling .item()
            # for every (position, token) pair
            logit_rows = logits.tolist()
            softmax_rows = softmax.tolist()

            # format and append data
            for x, position in enumerate(mask_positions.tolist()):
                d = {
                    "pair_id": name,
                    "perplexity": perplexity,
                    "loss": loss,
                    "mask_position": position,
                    "prediction": predictions[x],
                    "germline": germ[x],
                    "actual": actual[x],
                }
                for y, token in enumerate(vocab_tokens):
                    data.append({
                        **d,
                        "token": token,
                        "logit": logit_rows[x][y],
                        "softmax": softmax_rows[x][y],
                    })
    return data

## load labels & tokenize

In [None]:
# pair ids: every line of the file becomes one entry, with surrounding
# whitespace (including the newline) removed
with open('./data/pair_ids.txt') as f:
    pair_ids = list(map(str.strip, f))

In [None]:
# paired labels: read one sequence per line
with open('./data/paired_labels.txt') as f:
    paired_labels_txt = list(map(str.strip, f))

# tokenize each sequence on its own, move it to the GPU and keep only the
# token ids
paired_labels = [
    tokenizer(seq, return_tensors='pt').to('cuda')['input_ids']
    for seq in paired_labels_txt
]

In [None]:
# heavy germline labels: taken from the light-masked / heavy-reverted file,
# one sequence per line
with open('./data/light-masked_heavy-reverted.txt') as f:
    lmasked_hreverted_txt = list(map(str.strip, f))

# tokenize each sequence on its own, move it to the GPU and keep only the
# token ids
hgerm_labels = [
    tokenizer(seq, return_tensors='pt').to('cuda')['input_ids']
    for seq in lmasked_hreverted_txt
]

In [None]:
# light germline labels: taken from the heavy-masked / light-reverted file,
# one sequence per line
with open('./data/heavy-masked_light-reverted.txt') as f:
    hmasked_lreverted_txt = list(map(str.strip, f))

# tokenize each sequence on its own, move it to the GPU and keep only the
# token ids
lgerm_labels = [
    tokenizer(seq, return_tensors='pt').to('cuda')['input_ids']
    for seq in hmasked_lreverted_txt
]

## masked heavy chains

#### mutated light chains

In [None]:
# heavy chain masked, light chain mutated: read, run inference, write CSV
print('reading masked data...')
with open('./data/heavy-masked_light-mutated.txt') as f:
    hmasked_lmutated_txt = [line.strip() for line in f]
hmasked_lmutated = [tokenizer(l, return_tensors='pt').to("cuda") for l in hmasked_lmutated_txt]

# DataFrame.to_csv does not create missing parent directories; make sure the
# output directory exists before inference so the results cannot be lost to
# a missing directory at write time
os.makedirs('./outputs/BALM-paired', exist_ok=True)

print('running inference:')
hmasked_lmutated_data = infer(model,
                              tokenizer,
                              pair_ids,
                              hmasked_lmutated,
                              paired_labels,
                              hgerm_labels,
                             )

print('writing output...')
hmasked_lmutated_df = pd.DataFrame(hmasked_lmutated_data)
hmasked_lmutated_df.to_csv('./outputs/BALM-paired/heavy-masked_light-mutated.csv',
                           index=False)

#### germline reverted light chains

In [None]:
# heavy chain masked, light chain germline-reverted: read, run inference, write CSV
print('reading masked data...')
with open('./data/heavy-masked_light-reverted.txt') as f:
    hmasked_lreverted_txt = [line.strip() for line in f]
hmasked_lreverted = [tokenizer(l, return_tensors='pt').to("cuda") for l in hmasked_lreverted_txt]

# DataFrame.to_csv does not create missing parent directories; make sure the
# output directory exists before inference so the results cannot be lost to
# a missing directory at write time
os.makedirs('./outputs/BALM-paired', exist_ok=True)

print('running inference:')
hmasked_lreverted_data = infer(model,
                              tokenizer,
                              pair_ids,
                              hmasked_lreverted,
                              paired_labels,
                              hgerm_labels,
                              )

print('writing output...')
hmasked_lreverted_df = pd.DataFrame(hmasked_lreverted_data)
hmasked_lreverted_df.to_csv('./outputs/BALM-paired/heavy-masked_light-reverted.csv',
                           index=False)

## masked light chains

#### mutated heavy chains

In [None]:
# light chain masked, heavy chain mutated: read, run inference, write CSV
print('reading masked data...')
with open('./data/light-masked_heavy-mutated.txt') as f:
    lmasked_hmutated_txt = [line.strip() for line in f]
lmasked_hmutated = [tokenizer(l, return_tensors='pt').to("cuda") for l in lmasked_hmutated_txt]

# DataFrame.to_csv does not create missing parent directories; make sure the
# output directory exists before inference so the results cannot be lost to
# a missing directory at write time
os.makedirs('./outputs/BALM-paired', exist_ok=True)

print('running inference:')
lmasked_hmutated_data = infer(model,
                              tokenizer,
                              pair_ids,
                              lmasked_hmutated,
                              paired_labels,
                              lgerm_labels,
                             )

print('writing output...')
lmasked_hmutated_df = pd.DataFrame(lmasked_hmutated_data)
lmasked_hmutated_df.to_csv('./outputs/BALM-paired/light-masked_heavy-mutated.csv',
                           index=False)

#### germline reverted heavy chains

In [None]:
# light chain masked, heavy chain germline-reverted: read, run inference, write CSV
print('reading masked data...')
with open('./data/light-masked_heavy-reverted.txt') as f:
    lmasked_hreverted_txt = [line.strip() for line in f]
lmasked_hreverted = [tokenizer(l, return_tensors='pt').to("cuda") for l in lmasked_hreverted_txt]

# DataFrame.to_csv does not create missing parent directories; make sure the
# output directory exists before inference so the results cannot be lost to
# a missing directory at write time
os.makedirs('./outputs/BALM-paired', exist_ok=True)

print('running inference:')
lmasked_hreverted_data = infer(model,
                              tokenizer,
                              pair_ids,
                              lmasked_hreverted,
                              paired_labels,
                              lgerm_labels
                              )

print('writing output...')
lmasked_hreverted_df = pd.DataFrame(lmasked_hreverted_data)
lmasked_hreverted_df.to_csv('./outputs/BALM-paired/light-masked_heavy-reverted.csv',
                           index=False)