# Mutation masking inference for BALM-unpaired

In [None]:
import copy
import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm

import numpy as np
import pandas as pd

from transformers import RobertaTokenizer, RobertaForMaskedLM

## load the model

In [None]:
# Directory holding the pretrained BALM-unpaired checkpoint.
# Edit this to point at your local copy of the weights.
model_path = './BALM-unpaired/'

In [None]:
# Load the masked-LM checkpoint and place it on the GPU. `.eval()` is spelled
# out for readability; from_pretrained already returns the model in eval mode.
model = RobertaForMaskedLM.from_pretrained(model_path).to('cuda').eval()

## tokenizer

In [None]:
# The tokenizer files are loaded from a directory separate from model_path.
tokenizer = RobertaTokenizer.from_pretrained(
    '../tokenizer/'
)

## inference function

In [None]:
def infer(
    model,
    tokenizer,
    pair_ids,
    inputs,
    labels,
    device='cuda',
):
    '''
    Score every masked position with a masked LM, one row per vocabulary token.

    inputs and labels should already be tokenized

    labels should just be the 'input_ids' data, not the whole tokenized dict

    Parameters
    ----------
    model
        Masked language model called as ``model(**encoding, labels=...)``; the
        result must expose ``.loss`` and ``.logits``.
    tokenizer
        Provides ``mask_token_id``, ``vocab_size`` and ``decode``.
    pair_ids
        One identifier per sequence, copied into the ``pair_id`` column.
    inputs
        Tokenized masked sequences (e.g. ``tokenizer(s, return_tensors='pt')``).
        Only row 0 of each ``input_ids`` is read, so each is presumably a
        single-sequence batch -- TODO confirm.
    labels
        Unmasked ``input_ids`` tensors aligned element-wise with ``inputs``.
    device
        Currently unused; kept for backward compatibility. Model, inputs and
        labels must already be on the same device.

    Returns
    -------
    list of dict
        One dict per (sequence, masked position, vocabulary id) with keys
        ``pair_id``, ``perplexity``, ``loss``, ``mask_position``,
        ``prediction``, ``actual``, ``token``, ``logit`` and ``softmax``.
        ``loss`` is the model loss for the whole sequence, computed with every
        non-masked label set to -100; ``perplexity`` is ``exp(loss)``. Both
        repeat on every row belonging to that sequence.

    Raises
    ------
    ValueError
        If ``pair_ids``, ``inputs`` and ``labels`` differ in length.
    '''
    # zip() would silently drop unmatched trailing items; a length mismatch
    # almost always means the input files are out of sync, so fail loudly.
    pair_ids, inputs, labels = list(pair_ids), list(inputs), list(labels)
    if not len(pair_ids) == len(inputs) == len(labels):
        raise ValueError(
            'pair_ids, inputs and labels must have the same length, got '
            f'{len(pair_ids)}, {len(inputs)} and {len(labels)}'
        )

    mask_id = tokenizer.mask_token_id
    # The decoded vocabulary is the same for every row, so decode it once.
    vocab = [tokenizer.decode(y) for y in range(tokenizer.vocab_size)]

    data = []
    with torch.no_grad():
        for name, i, label_ids in tqdm(list(zip(pair_ids, inputs, labels))):
            is_mask = i.input_ids == mask_id
            mask_positions = is_mask[0].nonzero(as_tuple=True)[0]
            # Keep labels only at masked positions; -100 everywhere else.
            labels_ = torch.where(is_mask, label_ids, -100)
            o = model(**i, labels=labels_)

            loss = o.loss.item()
            perplexity = float(torch.exp(o.loss))

            # One row of logits per masked position.
            logits = o.logits[0, mask_positions]
            softmax = torch.softmax(logits, dim=-1)

            # Decode token-by-token so entry x always belongs to masked
            # position x. Joining decoded strings and indexing characters
            # misaligns as soon as any token decodes to != 1 character.
            actual = [
                tokenizer.decode(t) for t in labels_[0, mask_positions].tolist()
            ]
            preds = [tokenizer.decode(t) for t in logits.argmax(dim=-1).tolist()]

            # Copy scores to the host once per sequence instead of one
            # .item() (a device sync on GPU) per vocabulary entry.
            logit_rows = logits.tolist()
            softmax_rows = softmax.tolist()

            for x, pos in enumerate(mask_positions.tolist()):
                row = {
                    "pair_id": name,
                    "perplexity": perplexity,
                    "loss": loss,
                    "mask_position": pos,
                    "prediction": preds[x],
                    "actual": actual[x],
                }
                for y, token in enumerate(vocab):
                    data.append({
                        **row,
                        "token": token,
                        "logit": logit_rows[x][y],
                        "softmax": softmax_rows[x][y],
                    })
    return data

## load labels & tokenize

In [None]:
# pair ids: one identifier per line, surrounding whitespace stripped
with open('./data/pair_ids.txt') as id_file:
    pair_ids = list(map(str.strip, id_file))

In [None]:
# labels: unmasked sequences, one per line, tokenized into input_ids tensors
# placed on the GPU
with open('./data/heavy_labels.txt') as heavy_file:
    heavy_labels_txt = list(map(str.strip, heavy_file))
heavy_labels = [
    tokenizer(seq, return_tensors='pt')['input_ids'].to('cuda')
    for seq in heavy_labels_txt
]

with open('./data/light_labels.txt') as light_file:
    light_labels_txt = list(map(str.strip, light_file))
light_labels = [
    tokenizer(seq, return_tensors='pt')['input_ids'].to('cuda')
    for seq in light_labels_txt
]

## masked heavy chains

In [None]:
# Score the masked heavy chains and write one CSV row per
# (pair, masked position, vocabulary token).
print('reading masked data...')
with open('./data/heavy-masked.txt') as f:
    hmasked_txt = [line.strip() for line in f]
hmasked = [tokenizer(seq, return_tensors='pt').to("cuda") for seq in hmasked_txt]

# Create the output folder before the (long) inference run so that a missing
# directory doesn't make to_csv fail after all the work is done.
os.makedirs('./outputs/BALM-unpaired/', exist_ok=True)

print('running inference:')
hmasked_data = infer(
    model,
    tokenizer,
    pair_ids,
    hmasked,
    heavy_labels,
)

print('writing output...')
hmasked_df = pd.DataFrame(hmasked_data)
hmasked_df.to_csv('./outputs/BALM-unpaired/heavy-masked.csv', index=False)

## masked light chains

In [None]:
# Score the masked light chains and write one CSV row per
# (pair, masked position, vocabulary token).
print('reading masked data...')
with open('./data/light-masked.txt') as f:
    lmasked_txt = [line.strip() for line in f]
lmasked = [tokenizer(seq, return_tensors='pt').to("cuda") for seq in lmasked_txt]

# Create the output folder before the (long) inference run so that a missing
# directory doesn't make to_csv fail after all the work is done.
os.makedirs('./outputs/BALM-unpaired/', exist_ok=True)

print('running inference:')
lmasked_data = infer(
    model,
    tokenizer,
    pair_ids,
    lmasked,
    light_labels,
)

print('writing output...')
lmasked_df = pd.DataFrame(lmasked_data)
lmasked_df.to_csv('./outputs/BALM-unpaired/light-masked.csv', index=False)