In [1]:
from dataclasses import dataclass
import gc
import os
import pickle
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm

import numpy as np
import pandas as pd

import umap.umap_ as umap
from sklearn.manifold import TSNE

import seaborn as sns
import matplotlib.pyplot as plt

#from transformers import RobertaTokenizer, RobertaForMaskedLM

from datasets import load_dataset

from ..balm.config import BalmConfig, BalmMoEConfig
from ..balm.data import load_dataset, DataCollator
from ..balm.models import (
    BalmForMaskedLM,
    BalmModel,
    BalmMoEForMaskedLM,
)
from ..balm.tokenizer import Tokenizer

## load the model

In [34]:
model = BalmMoEForMaskedLM.from_pretrained("../training_runs/balmMoE_expertchoice_1shared_altern_052924/model")
model = model.to('cuda')

## tokenizer

In [6]:
tokenizer = Tokenizer(vocab="../balm/vocab.json")

## load data

In [10]:
 df = pd.read_csv('./lc-coherence_test-unique_annotated.csv')
df['chain'] = ['heavy' if l == 'IGH' else 'light' for l in df['locus']]

In [11]:
seq_df = df.pivot(
    index='pair_id', 
    columns='chain',
    values='sequence_aa'
)

In [16]:
seqs = []
for h, l in zip(seq_df['heavy'], seq_df['light']):
    seqs.append("{}<cls><cls>{}".format(h, l))
    
seq_names = list(seq_df.index.values)

In [17]:
seqs[0]

'EVQLWESGGGLVQPGGSLRLSCAASGFIFSSYAMIWVRQAPGKGLEWVSGISSSGGSTYYADSVKGRFTISRDNSKNTVYLQMNSLRTEDTAVYYCAKTNGAGSGKGYYYYGMDVWGQGTTVTVSS<cls><cls>EIVLTQSPGTLSLSPGESATLSCRASQSVSSTYLVWYQQKPGQAPRLLIYGASSRATGIPDRFSGSGSGTDFTLTISRLEPEDFAVYYCQQYGPSPLYTFGQGTKLEIR'

## tokenize data

In [32]:
tokenized_data = []
for s in tqdm(seqs):
    tokenized_data.append(tokenizer(s, return_tensors='pt', padding=True, truncation=True, max_length=320)['input_ids'][0])

  0%|          | 0/64516 [00:00<?, ?it/s]

In [None]:
i = {'input_ids': [t['input_ids'][0] for t in tokenized_data],
     'attention_mask': [t['attention_mask'] for t in tokenized_data]}

## inference

In [24]:
@dataclass
class ModelOutput:
    '''
    
    '''
    name: str
    chain: str
    mean_final_layer_embedding: np.ndarray


In [29]:
#inputs

In [35]:
outputs = []

batch = 0
batchsize = 20000

x = slice(batch*batchsize, (batch+1)*batchsize)
inputs = list(zip(seq_names[x], seqs[x], tokenized_data[x]))

with torch.no_grad():
    for name, seq, i in tqdm(inputs):
        o = model(
            torch.tensor(i).unsqueeze(0).to('cuda'),
            output_hidden_states=True,
            return_dict=False,
        )
        final_layer_hidden_state = np.array(o[1][6][0].to('cpu'))
        h, l = seq.split('<cls><cls>')
        h_state = final_layer_hidden_state[:len(h)]
        l_state = final_layer_hidden_state[-len(l):]
        outputs.append(ModelOutput(name, 'heavy', h_state.mean(axis=0)))
        outputs.append(ModelOutput(name, 'light', l_state.mean(axis=0)))

with open('./balmMoE_outputs_20k.pkl', 'wb') as f:
    pickle.dump(outputs, f)


  0%|          | 0/20000 [00:00<?, ?it/s]

  torch.tensor(i).unsqueeze(0).to('cuda'),
