In [2]:
from transformers import EsmForMaskedLM

esm = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
esm_sd = esm.state_dict()

for k, v in esm_sd.items():
  print(k, v.shape)

esm.embeddings.word_embeddings.weight torch.Size([33, 320])
esm.embeddings.position_embeddings.weight torch.Size([1026, 320])
esm.encoder.layer.0.attention.self.query.weight torch.Size([320, 320])
esm.encoder.layer.0.attention.self.query.bias torch.Size([320])
esm.encoder.layer.0.attention.self.key.weight torch.Size([320, 320])
esm.encoder.layer.0.attention.self.key.bias torch.Size([320])
esm.encoder.layer.0.attention.self.value.weight torch.Size([320, 320])
esm.encoder.layer.0.attention.self.value.bias torch.Size([320])
esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq torch.Size([8])
esm.encoder.layer.0.attention.output.dense.weight torch.Size([320, 320])
esm.encoder.layer.0.attention.output.dense.bias torch.Size([320])
esm.encoder.layer.0.attention.LayerNorm.weight torch.Size([320])
esm.encoder.layer.0.attention.LayerNorm.bias torch.Size([320])
esm.encoder.layer.0.intermediate.dense.weight torch.Size([1280, 320])
esm.encoder.layer.0.intermediate.dense.bias torch.Size([12

In [1]:
import torch
from torch.utils.data import DataLoader
from transformers import EsmTokenizer

from data import ShardedMLMDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
T = 510
TOKENS_PER_BATCH = 4096

train_loader = DataLoader(ShardedMLMDataset(crop_len=T, tokens_per_batch=TOKENS_PER_BATCH, split='train'), batch_size=None)
item = next(iter(train_loader))
print(item[0].shape, item[1].shape, item[2].shape)
print(item[0].numel())

torch.Size([47, 87]) torch.Size([47, 87]) torch.Size([47, 87])
4089


In [3]:
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
print(tokenizer.vocab_size)
print(tokenizer.get_vocab())
example_seq = tokenizer.decode(item[0][0])
example_label = tokenizer.decode(item[1][0])
example_mask = ','.join(map(str, item[2][0].tolist()))
print(example_seq)
print(example_label)
print(example_mask)

33
{'<cls>': 0, '<pad>': 1, '<eos>': 2, '<unk>': 3, 'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10, 'T': 11, 'I': 12, 'D': 13, 'P': 14, 'K': 15, 'Q': 16, 'N': 17, 'F': 18, 'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23, 'X': 24, 'B': 25, 'U': 26, 'Z': 27, 'O': 28, '.': 29, '-': 30, '<null_1>': 31, '<mask>': 32}
<cls> M R E <mask> Q L <mask> <mask> Y I K I H K L F L L K K <mask> E V V S <mask> E H E R F Y C Y F D <mask> <mask> <mask> T T I K G T T S K U V D <mask> K <mask> I W V E A R P E I R K <mask> D I D <mask> N P S G F I N A G E L <mask> P E <mask> <eos>
<unk> <unk> <unk> <unk> K <unk> <unk> M E <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> M <unk> <unk> <unk> <unk> D <unk> <unk> <unk> S <unk> <unk> <unk> <unk> <unk> <unk> I N Y <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> I <unk> <unk> N <unk> V <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> K <unk> <unk> <unk> S <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>

In [6]:
from model import ESM, ESMConfig
from transformers import EsmForMaskedLM

def generate_sequence(logits):
    pred_ids = torch.argmax(logits, dim=-1)
    return tokenizer.convert_ids_to_tokens(pred_ids)

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
print(f"using device: {device}")

esm_model = EsmForMaskedLM.from_pretrained("facebook/esm2_t12_35M_UR50D")
esm_model.to("cpu")
esm_model.eval()
for p in esm_model.parameters():
    p.requires_grad_(False)
model = ESM(ESMConfig())
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i, batch in enumerate(train_loader):
    inputs, labels, mask = batch
    inputs = inputs.to(device)
    labels = labels.to(device)
    mask = mask.to(device)

    optimizer.zero_grad()
    logits, loss = model(inputs, mask, labels)
    loss.backward()
    optimizer.step()
    
    if i % 10 == 0:
        inputs = inputs.to("cpu")
        labels = labels.to("cpu")
        mask = mask.to("cpu")
        esm_output = esm_model(input_ids=inputs, attention_mask=mask, labels=labels)
        esm_logits = esm_output.logits
        esm_loss = esm_output.loss
        print(f"step {i} | loss: {loss.detach().item()}")
        print(f"step {i} | esm_loss: {esm_loss.detach().item()}")
        print(f"step {i} | lab ==> {tokenizer.decode(labels[0][labels[0] != -100])}")
        print(f"step {i} | gen ==> {' '.join(generate_sequence(logits[0][labels[0] != -100]))}")
        print(f"step {i} | esm_gen ==> {' '.join(generate_sequence(esm_logits[0][labels[0] != -100]))}")
    
    if i >= 49:
        break

using device: mps
step 0 | loss: 3.5479190349578857
step 0 | esm_loss: 2.491105318069458
step 0 | lab ==> L E V D T S R K A
step 0 | gen ==> Z Z Z Z Z Z Z I Z
step 0 | esm_gen ==> D K E K K I I K L
step 10 | loss: 2.9885172843933105
step 10 | esm_loss: 2.318218231201172
step 10 | lab ==> F S L S L F E W G R
step 10 | gen ==> L L L L L L L L L L
step 10 | esm_gen ==> F L L L L L L L K K
step 20 | loss: 2.9062085151672363
step 20 | esm_loss: 2.2791340351104736
step 20 | lab ==> L L L Y H V I S A R H
step 20 | gen ==> L L L L L L L L L L L
step 20 | esm_gen ==> L L P L H P I L L R R
step 30 | loss: 2.9326579570770264
step 30 | esm_loss: 2.427870512008667
step 30 | lab ==> I G Y S G R I K W K
step 30 | gen ==> L L L L L L L L L L
step 30 | esm_gen ==> I G G G L K A Y K K
step 40 | loss: 2.944211483001709
step 40 | esm_loss: 2.4567716121673584
step 40 | lab ==> A I P I R S P S G
step 40 | gen ==> A A A A A A A A A
step 40 | esm_gen ==> A A A I A S A S S
