In [4]:
from transformers import EsmForMaskedLM

# Load the `facebook/esm2_t6_8M_UR50D` masked-LM checkpoint through
# Hugging Face transformers and keep its state dict around for inspection.
esm = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
esm_sd = esm.state_dict()

# List every state-dict entry as "<name> <shape>", one per line.
for k, v in esm_sd.items():
    print(f"{k} {v.shape}")

  from .autonotebook import tqdm as notebook_tqdm


esm.embeddings.word_embeddings.weight torch.Size([33, 320])
esm.embeddings.position_embeddings.weight torch.Size([1026, 320])
esm.encoder.layer.0.attention.self.query.weight torch.Size([320, 320])
esm.encoder.layer.0.attention.self.query.bias torch.Size([320])
esm.encoder.layer.0.attention.self.key.weight torch.Size([320, 320])
esm.encoder.layer.0.attention.self.key.bias torch.Size([320])
esm.encoder.layer.0.attention.self.value.weight torch.Size([320, 320])
esm.encoder.layer.0.attention.self.value.bias torch.Size([320])
esm.encoder.layer.0.attention.self.rotary_embeddings.inv_freq torch.Size([8])
esm.encoder.layer.0.attention.output.dense.weight torch.Size([320, 320])
esm.encoder.layer.0.attention.output.dense.bias torch.Size([320])
esm.encoder.layer.0.attention.LayerNorm.weight torch.Size([320])
esm.encoder.layer.0.attention.LayerNorm.bias torch.Size([320])
esm.encoder.layer.0.intermediate.dense.weight torch.Size([1280, 320])
esm.encoder.layer.0.intermediate.dense.bias torch.Size([12

In [10]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import EsmTokenizer

from data import ShardedMLMDataset

In [6]:
# Crop length handed to the dataset and the token budget requested per batch.
T = 510
TOKENS_PER_BATCH = 4096

# batch_size=None switches off DataLoader's automatic batching, so every item
# the loader yields is a single item produced by the dataset itself.
train_dataset = ShardedMLMDataset(
    crop_len=T,
    tokens_per_batch=TOKENS_PER_BATCH,
    split='train',
)
train_loader = DataLoader(train_dataset, batch_size=None)

# Pull one item and report the shapes of its first three tensors, followed by
# the total number of elements in the first one.
item = next(iter(train_loader))
print(*(item[idx].shape for idx in range(3)))
print(item[0].numel())

torch.Size([33, 122]) torch.Size([33, 122]) torch.Size([33, 122])
4026


In [8]:
# Tokenizer that ships with the `facebook/esm2_t6_8M_UR50D` checkpoint.
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
print(tokenizer.vocab_size)
print(tokenizer.get_vocab())

# Render row 0 of each tensor in `item` (the item fetched from `train_loader`):
# tensors 0 and 1 are decoded back to token text, tensor 2 is shown as a
# comma-separated list of its raw values.
# NOTE(review): ids in item[1] that are not in the vocabulary are rendered by
# decode() however the tokenizer handles unknown ids -- confirm this is the
# intended way to display the labels.
example_seq = tokenizer.decode(item[0][0])
example_label = tokenizer.decode(item[1][0])
example_mask = ','.join(str(flag) for flag in item[2][0].tolist())
print(example_seq, example_label, example_mask, sep="\n")

33
{'<cls>': 0, '<pad>': 1, '<eos>': 2, '<unk>': 3, 'L': 4, 'A': 5, 'G': 6, 'V': 7, 'S': 8, 'E': 9, 'R': 10, 'T': 11, 'I': 12, 'D': 13, 'P': 14, 'K': 15, 'Q': 16, 'N': 17, 'F': 18, 'Y': 19, 'M': 20, 'H': 21, 'W': 22, 'C': 23, 'X': 24, 'B': 25, 'U': 26, 'Z': 27, 'O': 28, '.': 29, '-': 30, '<null_1>': 31, '<mask>': 32}
<cls> M I Q Q T L L L Y G Y P F G T <unk> V L K D N L <mask> V G Q M R T <mask> R S I D Q <mask> <mask> S R E W M V D <mask> C T E W L A I V T F S P G V C R K E S Q T <null_1> F C S <mask> I G W C P K N E S <mask> <mask> I L A G G Q A K T T L E G S S L F P S I <mask> G L E N C R <mask> E F T E S T N L R E <eos>
<unk> <unk> <unk> <unk> <unk> <unk> L <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> D <unk> <unk> <unk> <unk> <unk> <unk> F <unk> <unk> <unk> <unk> <unk> <unk> Q <unk> <unk> <unk> <unk> <unk> Y E <unk> <unk> <unk> <unk> <unk> <unk> <unk> R <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> 

In [13]:
from model import ESM, ESMConfig
from transformers import EsmForMaskedLM
import esm

def generate_sequence(logits):
    """Turn logits into tokens by greedy (argmax) selection.

    Takes the argmax of ``logits`` along its last dimension and maps the
    resulting ids to token strings with the notebook-level ``tokenizer``.

    Args:
        logits: tensor of scores; assumes the last dimension indexes the
            tokenizer vocabulary -- TODO confirm against callers.

    Returns:
        Whatever ``tokenizer.convert_ids_to_tokens`` returns for the
        selected ids.
    """
    return tokenizer.convert_ids_to_tokens(logits.argmax(dim=-1))

# Choose the compute device: CUDA when available, otherwise Apple's MPS
# backend when this torch build exposes it and reports it available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"using device: {device}")

# Reference model 1: the Hugging Face port of `esm2_t12_35M_UR50D`.
# Kept on the CPU, switched to eval mode, and frozen (no parameter requires
# gradients).
esm_model = EsmForMaskedLM.from_pretrained("facebook/esm2_t12_35M_UR50D")
esm_model.to("cpu").eval()
esm_model.requires_grad_(False)

# Reference model 2: the same-named checkpoint loaded through the `esm`
# package. The second value returned by the loader is not used here.
fair_esm_model, _ = esm.pretrained.esm2_t12_35M_UR50D()
fair_esm_model.to("cpu").eval()
fair_esm_model.requires_grad_(False)

# Loop controls: how many batches to consume from `train_loader`, and how
# often (in steps) to report the reference models' losses.
MAX_STEPS = 50
LOG_EVERY = 10

# Local model and its optimizer. The training step inside the loop is
# currently commented out, so the optimizer is created but not stepped.
model = ESM(ESMConfig())
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

for i, batch in enumerate(train_loader):
    # CPU views of the batch for the CPU-resident reference models. They are
    # taken straight from the loader output so evaluation never has to copy
    # tensors back from the accelerator.
    cpu_inputs, cpu_labels, cpu_mask = (t.to("cpu") for t in batch)

    # Device copies consumed by the (currently disabled) training step below.
    inputs = cpu_inputs.to(device)
    labels = cpu_labels.to(device)
    mask = cpu_mask.to(device)

    # optimizer.zero_grad()
    # logits, loss = model(inputs, mask, labels)
    # loss.backward()
    # optimizer.step()

    if i % LOG_EVERY == 0:
        # Pure evaluation: no autograd bookkeeping is needed for the
        # reference models.
        with torch.no_grad():
            esm_output = esm_model(input_ids=cpu_inputs, attention_mask=cpu_mask, labels=cpu_labels)
            fair_esm_logits = fair_esm_model(cpu_inputs, repr_layers=[], return_contacts=False)["logits"]
            # Flatten to (positions, vocab) vs (positions,); targets equal to
            # -100 are excluded from the loss. reshape() is used instead of
            # view() so non-contiguous tensors are accepted as well.
            fair_esm_loss = F.cross_entropy(
                fair_esm_logits.reshape(-1, fair_esm_logits.size(-1)),
                cpu_labels.reshape(-1),
                ignore_index=-100,
            )
        esm_logits = esm_output.logits
        esm_loss = esm_output.loss
        # print(f"step {i} | loss: {loss.detach().item()}")
        print(f"step {i} | esm_loss: {esm_loss.detach().item()}")
        print(f"step {i} | fair_esm_loss: {fair_esm_loss.detach().item()}")
        # print(f"step {i} | lab ==> {tokenizer.decode(labels[0][labels[0] != -100])}")
        # print(f"step {i} | gen ==> {' '.join(generate_sequence(logits[0][labels[0] != -100]))}")
        # print(f"step {i} | esm_gen ==> {' '.join(generate_sequence(esm_logits[0][labels[0] != -100]))}")

    # Checked at the end of the iteration so the loader is not asked for a
    # batch beyond the budget.
    if i + 1 >= MAX_STEPS:
        break

using device: mps
step 0 | esm_loss: 2.3216257095336914
step 0 | fair_esm_loss: 2.321367025375366
step 10 | esm_loss: 2.06491756439209
step 10 | fair_esm_loss: 2.0659947395324707
step 20 | esm_loss: 2.179452896118164
step 20 | fair_esm_loss: 2.1791834831237793
step 30 | esm_loss: 2.168743133544922
step 30 | fair_esm_loss: 2.1677305698394775
step 40 | esm_loss: 2.153233766555786
step 40 | fair_esm_loss: 2.15655255317688
