In [1]:
# Hugging Face model/tokenizer loaders, PyTorch and NumPy.
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import sys
import numpy as np
# `models` lives in the sibling VAE_standard directory. The path is relative to the
# notebook's working directory and must be appended before the import below.
sys.path.append("../VAE_standard")
from models import DNADataset, ALPHABET, SEQ_LENGTH, LATENT_DIM, VAE

# The parent directory provides the shared `utils` helpers (e.g. utils.get_genome).
sys.path.append("..")
import utils

# Biopython genetic-code tables, used below to translate DNA codons into amino acids.
import Bio.Data.CodonTable

In [2]:
# Load the training spike sequences and decode each encoded example back to nucleotides.
# Dotting each position's vector with 0..len(ALPHABET)-1 yields an alphabet index
# (presumably example[0] is one-hot over ALPHABET -- TODO confirm).
training_fasta = "../data/training_spike.fasta"
dataset = DNADataset(training_fasta)
alphabet_index = np.arange(len(ALPHABET))
sequences = []
for example in dataset:
    sequences.append(utils.get_genome(np.dot(example[0], alphabet_index)))

In [3]:
# ESM-2 masked protein language model and its matching tokenizer.
ESM_CHECKPOINT = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(ESM_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(ESM_CHECKPOINT)

# Longest residue sequence fed to the model. The model's position-embedding table
# holds 1026 entries, i.e. 1024 residues plus room for the <cls> and <eos> tokens.
MAX_TOKEN_LENGTH = 1024



In [4]:
# Standard genetic code as a DNA codon table (codon -> one-letter amino acid).
standard_table = Bio.Data.CodonTable.standard_dna_table
x = standard_table  # short alias kept because the translation cell refers to `x`
print(standard_table)

Table 1 Standard, SGC0

  |  T      |  C      |  A      |  G      |
--+---------+---------+---------+---------+--
T | TTT F   | TCT S   | TAT Y   | TGT C   | T
T | TTC F   | TCC S   | TAC Y   | TGC C   | C
T | TTA L   | TCA S   | TAA Stop| TGA Stop| A
T | TTG L(s)| TCG S   | TAG Stop| TGG W   | G
--+---------+---------+---------+---------+--
C | CTT L   | CCT P   | CAT H   | CGT R   | T
C | CTC L   | CCC P   | CAC H   | CGC R   | C
C | CTA L   | CCA P   | CAA Q   | CGA R   | A
C | CTG L(s)| CCG P   | CAG Q   | CGG R   | G
--+---------+---------+---------+---------+--
A | ATT I   | ACT T   | AAT N   | AGT S   | T
A | ATC I   | ACC T   | AAC N   | AGC S   | C
A | ATA I   | ACA T   | AAA K   | AGA R   | A
A | ATG M(s)| ACG T   | AAG K   | AGG R   | G
--+---------+---------+---------+---------+--
G | GTT V   | GCT A   | GAT D   | GGT G   | T
G | GTC V   | GCC A   | GAC D   | GGC G   | C
G | GTA V   | GCA A   | GAA E   | GGA G   | A
G | GTG V   | GCG A   | GAG E   | GGG G   | G
--+---------

In [5]:
def translate_cds(dna, codon_table, unknown_residue="X"):
    """Translate a coding DNA sequence into a one-letter protein string.

    The sequence is read in frame from position 0, and only complete codons are
    used; a trailing partial codon is ignored. A single terminal stop codon is
    dropped, so a full coding sequence yields exactly its protein.

    Any other codon missing from ``codon_table.forward_table`` becomes
    ``unknown_residue``. This covers internal stop codons and codons containing
    non-ACGT symbols. The replacement keeps exactly one character per codon.

    Args:
        dna: nucleotide string.
        codon_table: a Biopython CodonTable (``forward_table`` and ``stop_codons`` are used).
        unknown_residue: single character emitted for untranslatable codons. "X" (the
            conventional unknown-amino-acid letter) is in the ESM tokenizer vocabulary.

    Returns:
        The protein sequence as a string of length (#complete codons - trailing stop).
    """
    codons = [dna[3 * i:3 * i + 3] for i in range(len(dna) // 3)]
    # Strip the terminal stop explicitly instead of always dropping the last codon.
    if codons and codons[-1] in codon_table.stop_codons:
        codons.pop()
    return "".join(codon_table.forward_table.get(codon, unknown_residue) for codon in codons)


seq = "".join(sequences[0])
# Despite the name, `codon_seq` holds amino acids: one residue per codon.
codon_seq = translate_cds(seq, x)
# ESM-2 takes at most MAX_TOKEN_LENGTH residues here, so longer proteins are truncated.
# For the 1273-residue spike, this drops the C-terminal residues.
codon_subseq = codon_seq[:MAX_TOKEN_LENGTH]

In [6]:
# Nucleotide length, translated protein length, and the truncated length fed to ESM.
print(len(seq), len(codon_seq), len(codon_subseq), sep="\n")

3822
1273
1024


In [52]:
# Every token in the tokenizer vocabulary, alphabetically (iterating the dict yields its keys).
k = sorted(tokenizer.get_vocab())
print(k)

['-', '.', '<cls>', '<eos>', '<mask>', '<null_1>', '<pad>', '<unk>', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']


In [18]:
# Exploratory: list the tokenizer's attributes and methods.
tokenizer_attributes = dir(tokenizer)
print(tokenizer_attributes)



In [16]:
# Tokenize the truncated protein without special tokens. The vocabulary above is
# single-character, so presumably token i corresponds to residue i.
# NOTE(review): ESM inputs are normally wrapped in <cls> ... <eos> (both are in the vocab).
# Omitting them may change predictions near the sequence ends -- confirm this is intended.
test_input = tokenizer(
    codon_subseq,
    add_special_tokens=False,
    return_tensors="pt",
)
mask = test_input["attention_mask"]
print(mask)

tensor([[1, 1, 1,  ..., 1, 1, 1]])


In [10]:
# Forward pass with autograd disabled; only the output logits are inspected.
# NOTE(review): the saved execution counts show this cell (In [10]) ran before the
# tokenization cell (In [16]). Re-run top to bottom so `test_input` is the current one.
with torch.set_grad_enabled(False):
    test = model(**test_input)

In [11]:
# Show the full masked-LM output object (logits plus any unset optional fields).
print(str(test))

MaskedLMOutput(loss=None, logits=tensor([[[-11.1755, -21.5964, -12.7069,  ..., -16.2963, -16.2876, -21.5850],
         [-11.0088, -21.2710, -11.9138,  ..., -16.1666, -16.2621, -21.2622],
         [-10.3287, -22.3205, -11.6068,  ..., -16.4872, -16.5031, -22.3068],
         ...,
         [-10.8905, -20.7800,  -9.2887,  ..., -16.4550, -16.3657, -20.7865],
         [ -9.2913, -19.5699,  -8.0398,  ..., -16.0253, -16.2629, -19.5651],
         [ -8.5411, -18.0649,  -4.8223,  ..., -16.0127, -16.0276, -18.0695]]]), hidden_states=None, attentions=None)


In [54]:
# Print the module tree to inspect the embedding and encoder layer configuration.
print(str(model))

EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-5): 6 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_a