In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM, PreTrainedTokenizerFast, DataCollatorForLanguageModeling
import torch
import torch.nn as nn
import sys
import numpy as np
sys.path.append("../VAE_standard")
from models import DNADataset, ALPHABET, SEQ_LENGTH, LATENT_DIM, VAE

from matplotlib import pyplot as plt

sys.path.append("..")
import utils

import Bio.Data.CodonTable

from devinterp.utils import (
    EvaluateFn,
    EvalResults,
)

from BIF_sampler import (
    BIFEstimator,
    estimate_bif
)

import pandas as pd

In [2]:
MAX_TOKEN_LENGTH = 510   # cap on codons translated per sequence, hence on residues/tokens per sequence
BATCH_SIZE = 60          # batch size of the SGLD dataloader
num_masks = 3            # NOTE(review): the visible cells do not use this for masking — TODO confirm its intent

TEST_SEQ = 1             # not referenced in the visible cells; presumably used later — TODO confirm
TRAIN_CUTOFF = 3000      # number of leading sequences tokenized into train_data
TEST_TOKEN = 0           # not referenced in the visible cells; presumably used later — TODO confirm

# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG so the random MLM masks drawn by the data collator are reproducible.
SEED = 0
torch.manual_seed(SEED)

In [3]:
# Tokenizer outputs (Hugging Face conventions):
#   input_ids      - torch.LongTensor of shape (batch_size, sequence_length)
#   attention_mask - torch.Tensor of shape (batch_size, sequence_length), values in {0, 1},
#                    where 0 := masked, 1 := not masked
ESM_CHECKPOINT = "facebook/esm2_t6_8M_UR50D"
MLM_PROBABILITY = 0.15  # fraction of tokens the collator selects for masked-LM prediction

tokenizer = AutoTokenizer.from_pretrained(ESM_CHECKPOINT)
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm_probability=MLM_PROBABILITY,
    return_tensors="pt",
)

# Masked-LM head on top of the ESM-2 encoder, moved to the configured device.
model = AutoModelForMaskedLM.from_pretrained(ESM_CHECKPOINT).to(DEVICE)



In [4]:
# Decode every training sequence and collect its name in a single pass over the dataset.
dataset = DNADataset("../data/training_spike.fasta")
alphabet_index = np.arange(len(ALPHABET))  # assumes item[0] is one-hot over ALPHABET, so the dot gives symbol indices — TODO confirm
sequences, names = [], []
for item in dataset:
    sequences.append(utils.get_genome(np.dot(item[0], alphabet_index)))
    names.append(item[1])
print("done extracting sequences!")

codon_table = Bio.Data.CodonTable.standard_dna_table
CODON_LEN = 3  # nucleotides per codon

# Ungapped nucleotide strings ("-" removed).
str_seqs = ["".join(seq).replace("-", "") for seq in sequences]
# Split into consecutive, non-overlapping codons; a trailing partial codon is discarded.
codons = [
    [nt[CODON_LEN * i:CODON_LEN * (i + 1)] for i in range(len(nt) // CODON_LEN)]
    for nt in str_seqs
]
# Translate at most the first MAX_TOKEN_LENGTH codons of each sequence. Codons absent from
# forward_table (e.g. stop codons, or codons with non-ACGT characters) map to "" and are
# dropped, so a sequence can come out shorter than MAX_TOKEN_LENGTH residues.
aa_drop_na = [
    list("".join(codon_table.forward_table.get(codon, "") for codon in seq[:MAX_TOKEN_LENGTH]))
    for seq in codons
]
print("done extracting AAs!")

metadata = pd.read_csv("../data/all_data/all_metadata.tsv", sep="\t")
# Build a name -> clade lookup once. Keep the first row per name, as the previous per-name
# scan did (`.values[0]`). A name missing from the metadata raises KeyError.
clade_by_name = (
    metadata.drop_duplicates(subset="name", keep="first")
    .set_index("name")["clade_membership"]
)
clade_labels = clade_by_name.loc[names].tolist()
print("done getting clade labels!")

done extracting sequences!
done extracting AAs!
done getting clade labels!


In [5]:
# One amino-acid string per sequence (aa_drop_na stores residues as lists of characters).
aa_strings = ["".join(residues) for residues in aa_drop_na]
# Sorted distinct amino-acid sequences across the whole dataset.
unique_aa_seqs = list(np.unique(aa_strings))

# Shared tokenizer settings: no special tokens, no truncation, pad to the longest sequence.
_tokenize_kwargs = dict(return_tensors="pt", add_special_tokens=False, truncation=False, padding=True)
train_data = tokenizer(text=aa_strings[:TRAIN_CUTOFF], **_tokenize_kwargs)["input_ids"]
bif_data = tokenizer(text=unique_aa_seqs, **_tokenize_kwargs)["input_ids"]

In [6]:
def collate_fn(batch, data_collator=data_collator, device=DEVICE):
    """Stack token-id rows into one batch and apply masked-LM corruption.

    Args:
        batch: list of 1-D token-id tensors (rows of the dataset wrapped by the
            DataLoader); assumed to share one padded length so they can be stacked.
        data_collator: collator whose ``torch_mask_tokens`` produces the masked
            inputs and the corresponding labels.
        device: device that BOTH returned tensors are moved to.

    Returns:
        Tuple ``(inputs, labels)`` on ``device``. Per the Hugging Face docs,
        ``labels`` is -100 at positions not selected for prediction.
    """
    # torch.stack allocates a new tensor. torch_mask_tokens may modify its argument
    # in place, and this keeps that from touching the source dataset tensor.
    stacked = torch.stack(batch, dim=0)
    inputs, labels = data_collator.torch_mask_tokens(stacked)
    # Honour the `device` argument for both tensors; labels previously used the global DEVICE.
    return inputs.to(device), labels.to(device)

# Sequential (unshuffled) loader over the tokenized training rows; collate_fn masks each batch.
sgld_dataloader = torch.utils.data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
)
# TODO: bif_dataloader (presumably over bif_data) has not been built yet.

In [7]:
# Peek at one collated batch: (masked input_ids, MLM labels).
first_batch = next(iter(sgld_dataloader))
print(first_batch)

(tensor([[20, 32,  7,  ..., 19, 10,  7],
        [32, 18,  7,  ...,  8, 18,  9],
        [20, 32,  7,  ...,  9,  6, 14],
        ...,
        [20, 18,  7,  ...,  4,  8, 18],
        [20, 18,  7,  ...,  4,  8, 18],
        [20, 18,  7,  ...,  4,  8, 18]], device='cuda:0'), tensor([[-100,   18, -100,  ..., -100, -100, -100],
        [  20, -100, -100,  ..., -100, -100, -100],
        [-100,   18, -100,  ...,   23, -100, -100],
        ...,
        [-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100],
        [-100, -100, -100,  ..., -100, -100, -100]], device='cuda:0'))
