In [None]:
!pip install fair-esm==1.0.2

Collecting fair-esm==1.0.2
  Downloading fair_esm-1.0.2-1-py3-none-any.whl.metadata (33 kB)
Downloading fair_esm-1.0.2-1-py3-none-any.whl (76 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/76.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.4/76.4 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-1.0.2


In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForMaskedLM, AutoTokenizer, RobertaConfig
from glm_dataset import FastaBatchedContigDataset, get_collate_fn, contig_collate_fn, ContigDataset
from glm_model import gLM
from glm_utils import save_results

In [1]:
model_id = 'facebook/esm2_t33_650M_UR50D'
tokenizer = AutoTokenizer.from_pretrained(model_id)
fasta_file = '../data/example_data/inference_example/test.fa'

In [7]:
MAX_LEN = 12290

cds = FastaBatchedContigDataset.from_file(fasta_file)
cds.from_contig_file('../data/example_data/inference_example/contig_to_prots.tsv')
batches = cds.get_batch_indices(MAX_LEN, extra_toks_per_seq=1)

data_loader = DataLoader(
    cds,
    collate_fn=get_collate_fn(tokenizer),
    batch_sampler=batches
)

In [None]:
batch = next(iter(data_loader))
batch

In [None]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
esm_model = AutoModelForMaskedLM.from_pretrained(model_id, device_map=device)
esm_model.eval()

Some weights of the model checkpoint at facebook/esm2_t33_650M_UR50D were not used when initializing EsmForMaskedLM: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          )
          (intermediate): Esm

In [None]:
def get_plm_embs(esm_model, data_loader):
    plm_embs = []

    torch.cuda.empty_cache()

    with torch.no_grad():
        for batch in data_loader:
            lens = batch['attention_mask'].sum(dim=1)-2 # removes special tokens from count

            batch['input_ids'] = batch['input_ids'].to(device)
            batch['attention_mask'] = batch['attention_mask'].to(device)
            output = esm_model(input_ids=batch['input_ids'],
                              attention_mask=batch['attention_mask'],
                              output_hidden_states=True)
            states = output['hidden_states'][-1].detach().cpu()
            del output
            torch.cuda.empty_cache()

            for pid, ori, seq_size, state in zip(batch['prot_ids'], batch['prot_oris'], lens, states):
                truncate_len = min(MAX_LEN-2, seq_size)
                plm_embs.append((pid, ori, state[1:truncate_len+1].mean(0).detach().cpu()))
                
    return plm_embs

In [41]:
plm_embs = get_plm_embs(esm_model, data_loader)
nds = ContigDataset(cds, plm_embs, './preproc/norm_factors.pt', './preproc/pca_parms.pt')
ndl = DataLoader(nds, batch_size=1, collate_fn=contig_collate_fn)

In [None]:
batch = next(iter(ndl))
batch

In [22]:
!mkdir model
!wget https://zenodo.org/record/7855545/files/glm.bin --output-document=./model/glm.bin

--2025-06-23 18:37:50--  https://zenodo.org/record/7855545/files/glm.bin
Resolving zenodo.org (zenodo.org)... 188.185.48.194, 188.185.43.25, 188.185.45.92, ...
Connecting to zenodo.org (zenodo.org)|188.185.48.194|:443... connected.
HTTP request sent, awaiting response... 301 MOVED PERMANENTLY
Location: /records/7855545/files/glm.bin [following]
--2025-06-23 18:37:50--  https://zenodo.org/records/7855545/files/glm.bin
Reusing existing connection to zenodo.org:443.
HTTP request sent, awaiting response... 200 OK
Length: 3819298935 (3.6G) [application/octet-stream]
Saving to: ‘glm.bin’


2025-06-23 18:40:58 (19.4 MB/s) - ‘glm.bin’ saved [3819298935/3819298935]



In [14]:
HIDDEN_SIZE = 1280
HALF = True
EMB_DIM = 1281
NUM_PC_LABEL = 100

num_pred = 4
max_seq_length = 30
num_attention_heads = 10
num_hidden_layers= 19
pos_emb = "relative_key_query"
pred_probs = True

config = RobertaConfig(
    vocab_size = 30522,
    max_position_embedding = max_seq_length,
    hidden_size = HIDDEN_SIZE,
    num_attention_heads = num_attention_heads,
    type_vocab_size = 1,
    tie_word_embeddings = False,
    num_hidden_layers = num_hidden_layers,
    num_pc = NUM_PC_LABEL,
    num_pred = num_pred,
    predict_probs = pred_probs,
    emb_dim = EMB_DIM,
    output_attentions = True,
    output_hidden_states = True,
    position_embedding_type = pos_emb,
    attn_implementation = "eager"
)

model_path = './glm.bin'

model =  gLM(config)
model.load_state_dict(torch.load(model_path, map_location=device),strict=False)
model.eval()

gLM(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(30522, 1280, padding_idx=1)
    (position_embeddings): Embedding(512, 1280, padding_idx=1)
    (token_type_embeddings): Embedding(1, 1280)
    (LayerNorm): LayerNorm((1280,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0-18): 19 x RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSelfAttention(
            (query): Linear(in_features=1280, out_features=1280, bias=True)
            (key): Linear(in_features=1280, out_features=1280, bias=True)
            (value): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (distance_embedding): Embedding(1023, 128)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=1280, out_features=1280, bias=True)
            (LayerNorm): LayerNorm((1280

In [6]:
def get_glm_embs(model, data_loader, half_precision=True):
    if not torch.cuda.is_available():
        half_precision = False

    torch.cuda.empty_cache()
    scaler = None

    with torch.no_grad():
        all_outputs = {}
        all_batches = {}
        for batch in data_loader:
            batch['inputs_embeds'] = batch['inputs_embeds'].to(device)
            batch['labels'] = batch['labels'].to(device)
            batch['attention_mask'] = batch['attention_mask'].to(device)

            if half_precision:
                scaler = torch.cuda.amp.GradScaler()
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    output = model(**batch)
            else:
                output = model(**batch)

            for k in ['logits_all_preds', 'probs', 'last_hidden_state', 'contacts']:
                current = all_outputs.get(k, None)
                if current is None:
                    all_outputs[k] = output[k].detach().cpu()
                else:
                    all_outputs[k] = torch.cat([all_outputs[k], output[k].detach().cpu()], dim=0)
            del output
            for k in batch.keys():
                current = all_batches.get(k, None)
                if current is None:
                    all_batches[k] = batch[k].detach().cpu()
                else:
                    all_batches[k] = torch.cat([all_batches[k], batch[k].detach().cpu()], dim=0)
            torch.cuda.empty_cache()
            
    return all_batches, all_outputs

In [17]:
all_batches, all_outputs = get_glm_embs(model, ndl, HALF)
save_results(all_batches, all_outputs, './results', 'test')

In [19]:
glm_embs = torch.load('test_glm_embs.pt', weights_only=False)
glm_embs[0]

array([[ 0.3616 ,  0.0715 , -0.3362 , ..., -0.2272 ,  0.02887,  0.06335],
       [ 0.379  ,  0.0717 , -0.3325 , ..., -0.2264 ,  0.03049,  0.04584],
       [ 0.374  ,  0.427  , -0.319  , ..., -0.5903 ,  0.4385 , -0.2443 ],
       ...,
       [-0.3157 , -0.0497 , -0.5293 , ..., -0.11896,  0.297  ,  0.4146 ],
       [-0.9995 , -0.4336 , -0.4648 , ..., -0.2656 ,  0.7754 ,  0.4424 ],
       [ 0.3442 ,  0.09143, -0.305  , ..., -0.1823 ,  0.02934,  0.0788 ]],
      dtype=float16)