# gLM Inference

This notebook was developed to allow for easy inference using the pretrained genomic language model (gLM) in Google Colab.

***
For more details about the model, please check the original paper [_Genomic language model predicts protein co-regulation and function_](https://www.biorxiv.org/content/10.1101/2023.04.07.536042v3) by Yunha Hwang, Andre L. Cornman, Elizabeth H. Kellogg, Sergey Ovchinnikov, Peter R. Girguis.
***

The inference process follows these steps:
- loading a FASTA of proteins
- loading a contig (TSV) file with contig-to-protein mappings
- embedding proteins (pLM embeddings) using a pretrained ESM model
- preprocessing pLM embeddings (normalization for inputs, PCA for labels)
- embedding proteins contextually using the pretrained gLM model

## Setup

Apart from Colab's default environment, we only need to install the `fair-esm` package.

In [1]:
!pip install fair-esm==1.0.2

Collecting fair-esm==1.0.2
  Downloading fair_esm-1.0.2-1-py3-none-any.whl.metadata (33 kB)
Downloading fair_esm-1.0.2-1-py3-none-any.whl (76 kB)
Installing collected packages: fair-esm
Successfully installed fair-esm-1.0.2


Next, we download the custom dataset and model classes, along with a few utility functions.

In [2]:
!wget https://raw.githubusercontent.com/dvgodoy/gLM/refs/heads/main/inference/glm_dataset.py
!wget https://raw.githubusercontent.com/dvgodoy/gLM/refs/heads/main/inference/glm_model.py
!wget https://raw.githubusercontent.com/dvgodoy/gLM/refs/heads/main/inference/glm_utils.py

--2025-06-24 15:01:59--  https://raw.githubusercontent.com/dvgodoy/gLM/refs/heads/main/inference/glm_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9940 (9.7K) [text/plain]
Saving to: ‘glm_dataset.py’


2025-06-24 15:01:59 (15.6 MB/s) - ‘glm_dataset.py’ saved [9940/9940]

--2025-06-24 15:02:00--  https://raw.githubusercontent.com/dvgodoy/gLM/refs/heads/main/inference/glm_model.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13378 (13K) [text/plain]
Saving to: ‘glm_model.py’


2025-06-24 15:02:00 (18.3 MB/s) - ‘glm_m

## Imports

In [4]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForMaskedLM, AutoTokenizer, RobertaConfig
from glm_dataset import FastaBatchedContigDataset, get_collate_fn, contig_collate_fn, ContigDataset
from glm_model import gLM
from glm_utils import save_results

## Data

The first file we'll use is a typical FASTA file containing the protein sequences, such as:

```
>lcl|NC_000913.3_prot_NP_414542.1_1 [gene=thrL] [locus_tag=b0001] [db_xref=UniProtKB/Swiss-Prot:P0AD86] [protein=thr operon leader peptide] [protein_id=NP_414542.1] [location=190..255] [gbkey=CDS]
MKRISTTITTTITITTGNGAG
```

For each protein, we'll use its header up to the first space (without the leading `>`) as its identifier, e.g. `lcl|NC_000913.3_prot_NP_414542.1_1` for the protein above.
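This convention fits in a single line of Python. The sketch below, built around a hypothetical `fasta_id` helper, only illustrates the rule; it does not reproduce the parsing done inside `glm_dataset.py`.

```python
def fasta_id(header: str) -> str:
    """Return the protein identifier from a FASTA header line.

    The identifier is everything after the leading '>' up to the first
    whitespace character.
    """
    return header.lstrip(">").split(maxsplit=1)[0]


header = (">lcl|NC_000913.3_prot_NP_414542.1_1 [gene=thrL] [locus_tag=b0001] "
          "[protein=thr operon leader peptide]")
print(fasta_id(header))  # lcl|NC_000913.3_prot_NP_414542.1_1
```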

The second file is a TSV (tab-separated values) file in which each line represents a contig, such as:

```
contig_0\t+lcl|NC_000913.3_prot_NP_414542.1_1;+lcl|NC_000913.3_prot_NP_414543.1_2;+lcl|NC_000913.3_prot_NP_414544.1_3;+lcl|NC_000913.3_prot_NP_414545.1_4;+lcl|NC_000913.3_prot_NP_414546.1_5;-lcl|NC_000913.3_prot_NP_414547.1_6;-lcl|NC_000913.3_prot_NP_414548.1_7;+lcl|NC_000913.3_prot_NP_414549.1_8;+lcl|NC_000913.3_prot_NP_414550.1_9;-lcl|NC_000913.3_prot_NP_414551.1_10;-lcl|NC_000913.3_prot_NP_414552.1_11;+lcl|NC_000913.3_prot_YP_009518733.1_12;-lcl|NC_000913.3_prot_NP_414554.1_13;+lcl|NC_000913.3_prot_NP_414555.1_14;+lcl|NC_000913.3_prot_NP_414556.1_15;+lcl|NC_000913.3_prot_NP_414557.1_16;-lcl|NC_000913.3_prot_NP_414559.1_17;-lcl|NC_000913.3_prot_YP_025292.1_18;+lcl|NC_000913.3_prot_NP_414560.1_19;+lcl|NC_000913.3_prot_NP_414561.1_20;-lcl|NC_000913.3_prot_NP_414562.1_21;-lcl|NC_000913.3_prot_NP_414563.1_22;-lcl|NC_000913.3_prot_NP_414564.1_23;+lcl|NC_000913.3_prot_NP_414565.1_24;+lcl|NC_000913.3_prot_NP_414566.1_25;+lcl|NC_000913.3_prot_NP_414567.1_26;+lcl|NC_000913.3_prot_NP_414568.1_27;+lcl|NC_000913.3_prot_NP_414569.1_28;+lcl|NC_000913.3_prot_NP_414570.1_29;+lcl|NC_000913.3_prot_NP_414571.1_30
```

The contig name, `contig_0` in the example above, is the first field. Everything that follows the tab (`\t`) is the ordered list of all proteins in the contig, separated by "`;`" (notice that the name of the first protein matches the first entry of our FASTA file). Moreover, each protein's ID is preceded by its orientation, either `+` or `-`.
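A minimal parser for one line of this file could look like the sketch below. It assumes exactly one tab between the contig name and the protein list. `parse_contig_line` is a hypothetical helper, not the method used by the notebook's classes.

```python
def parse_contig_line(line: str):
    """Split one contig TSV line into its name and (orientation, protein ID) pairs."""
    name, proteins = line.rstrip("\n").split("\t", 1)
    entries = []
    for token in proteins.split(";"):
        if not token:
            continue  # tolerate a trailing ';'
        ori, prot_id = token[0], token[1:]  # first character is the orientation
        if ori not in ("+", "-"):
            raise ValueError(f"missing orientation in {token!r}")
        entries.append((ori, prot_id))
    return name, entries


print(parse_contig_line("contig_0\t+lcl|a_1;-lcl|b_2"))
# ('contig_0', [('+', 'lcl|a_1'), ('-', 'lcl|b_2')])
```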

In [7]:
!wget https://raw.githubusercontent.com/dvgodoy/gLM/refs/heads/main/data/example_data/inference_example/test.fa
!wget https://raw.githubusercontent.com/dvgodoy/gLM/refs/heads/main/data/example_data/inference_example/contig_to_prots.tsv

--2025-06-24 17:00:25--  https://raw.githubusercontent.com/dvgodoy/gLM/refs/heads/main/data/example_data/inference_example/test.fa
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8001::154, 2606:50c0:8002::154, 2606:50c0:8000::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 33005 (32K) [text/plain]
Saving to: ‘test.fa’


2025-06-24 17:00:25 (11.5 MB/s) - ‘test.fa’ saved [33005/33005]

--2025-06-24 17:00:25--  https://raw.githubusercontent.com/dvgodoy/gLM/refs/heads/main/data/example_data/inference_example/contig_to_prots.tsv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8001::154, 2606:50c0:8002::154, 2606:50c0:8000::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2237 (2.2K) [text/pla

## Dataset for Protein Language Model (pLM) Embeddings

We'll be using a customized version of ESM's `FastaBatchedDataset` that includes contig information. Our class is named `FastaBatchedContigDataset`, and it can load the contigs file using its `from_contig_file()` method.

In [8]:
fasta_file = 'test.fa'
contigs_file = 'contig_to_prots.tsv'
MAX_LEN = 12290

# Loads FASTA file into customized FastaBatchedDataset
cds = FastaBatchedContigDataset.from_file(fasta_file)

# Imports contig-to-proteins mapping
# It will assign sequential IDs to proteins
# and then map the proteins in the contigs to their corresponding IDs,
# while also updating orientation (+/-) information for each protein
cds.from_contig_file(contigs_file)

### Protein Mappings

Our custom class has a few new attributes, such as `prot2id`, `id2prot`, and `prot_oris`:

In [26]:
list(cds.prot2id.keys())[:3]

['lcl|NC_000913.3_prot_NP_414542.1_1',
 'lcl|NC_000913.3_prot_NP_414543.1_2',
 'lcl|NC_000913.3_prot_NP_414544.1_3']

In [21]:
(cds.prot2id['lcl|NC_000913.3_prot_NP_414542.1_1'],
 cds.id2prot[0],
 cds.prot_oris['lcl|NC_000913.3_prot_NP_414542.1_1'])

(0, 'lcl|NC_000913.3_prot_NP_414542.1_1', '+')

### Contig Mappings

It also has a `contigs` attribute, a dictionary of all loaded contigs that maps each contig name to the list of sequential IDs of its proteins.

In [18]:
cds.contigs.keys()

dict_keys(['contig_0', 'contig_1'])

In [19]:
print(cds.contigs['contig_0'])

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29]
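For intuition, the sketch below shows how mappings of this kind could be built from a list of FASTA identifiers and parsed contigs. It assumes that sequential IDs follow FASTA order. That agrees with the outputs above but is not confirmed for `FastaBatchedContigDataset`. `build_mappings` is a hypothetical helper.

```python
def build_mappings(fasta_ids, contig_entries):
    """Build prot2id, id2prot, prot_oris, and contigs dictionaries.

    fasta_ids: protein identifiers in FASTA order
    contig_entries: list of (contig_name, [(orientation, protein_id), ...])
    """
    # Sequential IDs in FASTA order (an assumption about the ordering)
    prot2id = {pid: i for i, pid in enumerate(fasta_ids)}
    id2prot = {i: pid for pid, i in prot2id.items()}
    prot_oris = {}
    contigs = {}
    for name, entries in contig_entries:
        # Each contig becomes the ordered list of its proteins' sequential IDs
        contigs[name] = [prot2id[pid] for _, pid in entries]
        for ori, pid in entries:
            prot_oris[pid] = ori
    return prot2id, id2prot, prot_oris, contigs
```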


## Tokenizer

The ESM model doesn't take raw sequences as input, only tokenized ones. The tokenizer's job is to convert each symbol into its corresponding ID. The vocabulary has only 33 tokens, covering the amino acids and a few special tokens.

In [10]:
# Loads ESM tokenizer to convert proteins into IDs
# The tokenizer pads shorter sequences with the <pad> token (ID 1)
model_id = 'facebook/esm2_t33_650M_UR50D'
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(tokenizer.all_tokens)

['<cls>', '<pad>', '<eos>', '<unk>', 'L', 'A', 'G', 'V', 'S', 'E', 'R', 'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X', 'B', 'U', 'Z', 'O', '.', '-', '<null_1>', '<mask>']
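To make the mapping concrete, the sketch below encodes sequences with the vocabulary printed above: `<cls>` first, one ID per residue, then `<eos>`, with `<pad>` (ID 1) filling out shorter sequences. It is a simplified stand-in for the Hugging Face tokenizer (for instance, it maps any unknown symbol to `<unk>`), and `encode`/`pad_batch` are hypothetical helpers.

```python
# Vocabulary copied from the printed `tokenizer.all_tokens` above
VOCAB = ['<cls>', '<pad>', '<eos>', '<unk>', 'L', 'A', 'G', 'V', 'S', 'E', 'R',
         'T', 'I', 'D', 'P', 'K', 'Q', 'N', 'F', 'Y', 'M', 'H', 'W', 'C', 'X',
         'B', 'U', 'Z', 'O', '.', '-', '<null_1>', '<mask>']
TOK2ID = {tok: i for i, tok in enumerate(VOCAB)}


def encode(seq):
    # <cls> + one ID per residue (unknown symbols map to <unk>) + <eos>
    ids = [TOK2ID['<cls>']]
    ids += [TOK2ID.get(aa, TOK2ID['<unk>']) for aa in seq]
    ids.append(TOK2ID['<eos>'])
    return ids


def pad_batch(seqs):
    # Right-pad every encoded sequence to the longest one and build the attention mask
    encoded = [encode(s) for s in seqs]
    width = max(len(e) for e in encoded)
    input_ids = [e + [TOK2ID['<pad>']] * (width - len(e)) for e in encoded]
    attention_mask = [[1] * len(e) + [0] * (width - len(e)) for e in encoded]
    return input_ids, attention_mask


print(encode("MKRIS"))  # [0, 20, 15, 10, 12, 8, 2]
```

The first IDs, `0, 20, 15`, match the beginning of the first rows of `input_ids` in the mini-batch shown further below.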


## Mini-Batches

In order to send data to the ESM model, we group proteins into mini-batches so that no mini-batch exceeds the configured maximum number of tokens (`MAX_LEN`). We also use a custom collate function that uses the tokenizer to convert the sequences into token IDs, pads them to the same length, and includes additional information such as each protein's ID, orientation, and label (its FASTA header).

In [8]:
# Determines which elements from FASTA file are contained in each batch
batches = cds.get_batch_indices(MAX_LEN, extra_toks_per_seq=1)

# Uses custom collate_fn to produce mini-batches containing
# input_ids, attention_mask, prot_ids, prot_oris, labels
data_loader = DataLoader(
    cds,
    collate_fn=get_collate_fn(tokenizer),
    batch_sampler=batches
)
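`get_batch_indices()` comes from ESM's `FastaBatchedDataset`. To our understanding, it sorts sequences by length and greedily fills each mini-batch as long as the padded size (longest sequence × number of sequences, counting `extra_toks_per_seq` per sequence) stays within the token budget. The standalone sketch below reproduces that idea. `length_batches` is a hypothetical name.

```python
def length_batches(lengths, toks_per_batch, extra_toks_per_seq=0):
    """Group sequence indices so that (longest in batch) x (batch size) <= toks_per_batch."""
    order = sorted((n, i) for i, n in enumerate(lengths))  # shortest sequences first
    batches, buf, max_len = [], [], 0
    for n, i in order:
        n += extra_toks_per_seq
        # Adding this sequence would exceed the padded-token budget: start a new batch
        if buf and max(n, max_len) * (len(buf) + 1) > toks_per_batch:
            batches.append(buf)
            buf, max_len = [], 0
        max_len = max(max_len, n)
        buf.append(i)
    if buf:
        batches.append(buf)
    return batches


print(length_batches([5, 3, 10, 4], toks_per_batch=20, extra_toks_per_seq=1))  # [[1, 3, 0], [2]]
```

Sorting by length keeps padding to a minimum within each mini-batch. A single sequence longer than the budget still ends up alone in its own batch.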

In [9]:
batch = next(iter(data_loader))
batch

{'input_ids': tensor([[ 0, 20, 15,  ...,  1,  1,  1],
        [ 0, 20, 15,  ...,  1,  1,  1],
        [ 0, 20, 17,  ...,  1,  1,  1],
        ...,
        [ 0, 20, 17,  ...,  1,  1,  1],
        [ 0, 20, 16,  ..., 13,  2,  1],
        [ 0, 20, 11,  ...,  4,  4,  2]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 0],
        [1, 1, 1,  ..., 1, 1, 1]]), 'prot_ids': (0, 17, 56, 16, 23, 22, 21, 43, 4, 49, 33, 12, 27, 47, 11, 26, 20, 45, 9, 8, 34, 55, 57, 10, 40, 5, 35, 54, 30, 50, 48, 19, 29, 2, 24, 41, 28, 7), 'prot_oris': ('+', '-', '+', '-', '+', '-', '-', '+', '+', '-', '+', '-', '+', '+', '+', '+', '-', '+', '-', '+', '-', '+', '-', '-', '+', '-', '-', '+', '+', '-', '-', '+', '+', '+', '+', '+', '+', '+'), 'labels': ('lcl|NC_000913.3_prot_NP_414542.1_1 [gene=thrL] [locus_tag=b0001] [db_xref=UniProtKB/Swiss-Prot:P0AD86] [protein=thr op

## pLM Model

Next, we load the `facebook/esm2_t33_650M_UR50D` model from the Hugging Face Hub and send it directly to the device. Its weights alone take roughly 2.6 GB of GPU memory.

In [10]:
# Loads ESM model from HF Hub
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
esm_model = AutoModelForMaskedLM.from_pretrained(model_id, device_map=device)
esm_model.eval()

config.json:   0%|          | 0.00/724 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.61G [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/esm2_t33_650M_UR50D were not used when initializing EsmForMaskedLM: ['esm.embeddings.position_embeddings.weight']
- This IS expected if you are initializing EsmForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 1280, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-32): 33 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=1280, out_features=1280, bias=True)
              (key): Linear(in_features=1280, out_features=1280, bias=True)
              (value): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=1280, out_features=1280, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
          )
          (intermediate): Esm
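As a rough check on the memory figure above, the parameter footprint of a loaded model can be computed from the number of elements and the bytes per element of each parameter. The helper below, `param_memory_gb`, is a hypothetical sketch. It counts parameters only, so activations, buffers, and the CUDA context come on top of it.

```python
from torch import nn


def param_memory_gb(model: nn.Module) -> float:
    """Memory taken by a model's parameters, in GB (10**9 bytes)."""
    # Only parameters are counted: buffers, activations, and the CUDA context add more
    n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return n_bytes / 1e9
```

For `esm_model` in float32 (about 650M parameters × 4 bytes), we would expect `param_memory_gb(esm_model)` to come out close to the 2.61 GB checkpoint size shown in the download bar, and roughly half of that after converting the model to float16.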

## pLM Embeddings

To streamline the process of sending data to the ESM model, aggregating its outputs, and managing GPU memory to avoid OOM (out-of-memory) errors, we can use the `get_plm_embs()` function below:

In [11]:
def get_plm_embs(esm_model, data_loader):
    """
    Extracts protein language model (PLM) embeddings from the final hidden layer of the model
    for each protein sequence in the data loader.

    Args:
        esm_model (nn.Module): A transformer-based model (e.g., ESM or similar) with output_hidden_states enabled.
        data_loader (DataLoader): PyTorch DataLoader providing batches of protein input dictionaries
                                  with keys: 'input_ids', 'attention_mask', 'prot_ids', 'prot_oris'.

    Returns:
        List[Tuple[int, str, torch.Tensor]]: A list of tuples where each tuple contains:
            - prot_id (int): Sequential protein ID assigned by the dataset
            - prot_ori (str): Protein orientation, either '+' or '-'
            - emb (torch.Tensor): Mean-pooled embedding over the residue tokens (shape: [hidden_dim])
    """
    plm_embs = []
    torch.cuda.empty_cache()

    with torch.no_grad():
        for batch in data_loader:
            # Compute sequence lengths by summing the attention mask and subtracting the 2 special tokens (CLS and EOS)
            lens = batch['attention_mask'].sum(dim=1) - 2

            # Move input tensors to GPU
            batch['input_ids'] = batch['input_ids'].to(device)
            batch['attention_mask'] = batch['attention_mask'].to(device)

            # Forward pass to get hidden states
            output = esm_model(
                input_ids=batch['input_ids'],
                attention_mask=batch['attention_mask'],
                output_hidden_states=True
            )

            # Get last hidden layer (shape: [batch_size, seq_len, hidden_dim])
            states = output['hidden_states'][-1].detach().cpu()

            del output
            torch.cuda.empty_cache()

            # Process each sequence in the batch
            for pid, ori, seq_size, state in zip(batch['prot_ids'], batch['prot_oris'], lens, states):
                truncate_len = min(MAX_LEN - 2, seq_size)  # Truncate to max allowed length
                # Exclude special tokens and mean-pool across the valid tokens
                pooled_emb = state[1:truncate_len + 1].mean(0).detach().cpu()
                plm_embs.append((pid, ori, pooled_emb))

    return plm_embs

In [12]:
plm_embs = get_plm_embs(esm_model, data_loader)

The function returns a list of tuples, one for each protein.

In [34]:
len(plm_embs)

60

Each tuple contains the sequential ID assigned to the protein, its orientation, and its corresponding pLM embeddings returned by the ESM model.

In [33]:
plm_embs[0]

(0,
 '+',
 tensor([-0.0149,  0.0149,  0.0323,  ..., -0.0172,  0.0257,  0.0504],
        dtype=torch.float16))
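The per-protein loop in `get_plm_embs()` averages the hidden states over positions `1..L`, skipping the `<cls>` token at position 0 and the `<eos>` token right after the last residue. The vectorized version of the same masked mean sketched below computes all pooled embeddings of a batch in one step. It assumes that sequences are right-padded and were not truncated. `masked_mean_pool` is a hypothetical helper, not part of the downloaded modules.

```python
import torch


def masked_mean_pool(states, attention_mask):
    """Mean-pool residue states, excluding <cls>, <eos>, and padding.

    states: (B, T, D) hidden states; attention_mask: (B, T) with 1 for <cls>, residues, and <eos>
    """
    mask = attention_mask.clone()
    mask[:, 0] = 0  # drop <cls>
    eos_pos = attention_mask.sum(dim=1) - 1  # <eos> sits right after the last residue
    mask[torch.arange(mask.size(0)), eos_pos] = 0  # drop <eos>
    mask = mask.unsqueeze(-1).to(states.dtype)  # (B, T, 1)
    return (states * mask).sum(dim=1) / mask.sum(dim=1)  # (B, D)
```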

## Genome Language Modeling (gLM) Embeddings

We're almost ready to retrieve contextual protein embeddings using the gLM model. But first, we need to preprocess the ESM embeddings we obtained in the previous step.

### Preprocessing Data

There are two main preprocessing steps:

- normalizing the embeddings (required for inference)
- applying PCA to the embeddings (not required for inference, only for training/evaluation)

The parameters for both transformations were already precomputed and are available for download. The custom dataset will load these parameters and apply them accordingly, as we'll see in the next section.

In [35]:
!wget https://github.com/dvgodoy/gLM/raw/refs/heads/main/inference/preproc/norm_factors.pt
!wget https://github.com/dvgodoy/gLM/raw/refs/heads/main/inference/preproc/pca_parms.pt

--2025-06-24 17:05:32--  https://github.com/dvgodoy/gLM/raw/refs/heads/main/inference/preproc/norm_factors.pt
Resolving github.com (github.com)... 140.82.121.3
Connecting to github.com (github.com)|140.82.121.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/dvgodoy/gLM/refs/heads/main/inference/preproc/norm_factors.pt [following]
--2025-06-24 17:05:33--  https://raw.githubusercontent.com/dvgodoy/gLM/refs/heads/main/inference/preproc/norm_factors.pt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8001::154, 2606:50c0:8002::154, 2606:50c0:8000::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11702 (11K) [application/octet-stream]
Saving to: ‘norm_factors.pt’


2025-06-24 17:05:33 (10.5 MB/s) - ‘norm_factors.pt’ saved [11702/11702]

--2025-06-24 17:05:33--  https://github.co

### Dataset

We have implemented yet another dataset, `ContigDataset`, which takes our former `FastaBatchedContigDataset`, the pLM embeddings, and the two files containing the preprocessing parameters.

This is a dataset of contigs: each entry is a contig (like a sentence in a typical language model), and each contig is composed of a sequence of proteins (like a sentence is composed of words/tokens).

The shape of the resulting mini-batch is `(N, L, D)`, where N is the number of elements in the mini-batch (how many contigs), L is the sequence length (how many proteins in the contig), and D is the dimensionality of each protein's representation: 1281 for the inputs (the 1280-dimensional pLM embedding plus an orientation column) and 100 for the labels.

Moreover, just like in typical LMs, if one contig is shorter (has fewer proteins) than the others, it will be padded (with -1 as the protein ID), and the `attention_mask` will reflect that accordingly (containing a zero for each padded protein).
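A minimal sketch of this padding scheme follows. It assumes that each contig arrives as a list of protein IDs plus an `(L_i, D)` tensor of embeddings. `pad_contigs` is hypothetical and is not the `contig_collate_fn` used in this notebook.

```python
import torch


def pad_contigs(contigs, dim):
    """Pad a list of (prot_ids, embeddings) contigs into (N, L, D) tensors.

    Padded positions get zero embeddings, protein ID -1, and attention mask 0.
    """
    max_len = max(len(ids) for ids, _ in contigs)
    n = len(contigs)
    inputs = torch.zeros(n, max_len, dim)
    prot_ids = torch.full((n, max_len), -1, dtype=torch.long)
    attention_mask = torch.zeros(n, max_len, dtype=torch.long)
    for i, (ids, embs) in enumerate(contigs):
        L = len(ids)
        inputs[i, :L] = embs
        prot_ids[i, :L] = torch.tensor(ids)
        attention_mask[i, :L] = 1
    return inputs, prot_ids, attention_mask
```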

The dataset will perform the following steps behind the scenes:

- computing `input_embeds`
    - normalize the pLM embeddings (1280 dimensions) using the provided normalization factors
    - append an additional column containing either `+0.5` or `-0.5` according to the protein orientation, `+` or `-`, respectively (resulting in 1281 dimensions)
- computing `labels` (we don't necessarily need them for inference)
    - make a copy of the normalized pLM embeddings and apply the PCA transformation to it, thus reducing it to 99 dimensions only
    - append an additional column containing either `+0.5` or `-0.5` according to the protein orientation, `+` or `-`, respectively (resulting in 100 dimensions)
- pad shorter contigs with `-1` entries for protein IDs (`prot_ids` key)
- update the `attention_mask` according to the padding

In [45]:
# Loads normalization parameters to preprocess pLM embeddings
# Applies PCA to pLM embeddings to get corresponding labels
# Concatenates either +0.5 or -0.5 to each protein, both in
# embeddings and labels, according to its orientation, + or -
nds = ContigDataset(cds, plm_embs, 'norm_factors.pt', 'pca_parms.pt')

In [51]:
# Uses custom collate function to build mini-batches of N contigs,
# each contig being a sequence of L proteins,
# each protein represented by D dimensions (1281 for embeddings, 100 for PCA-transformed labels)
# (N, L, D) shapes
# If some contigs are shorter than others, the attention masks are updated to reflect padding
# and the padded protein IDs will be -1
ndl = DataLoader(nds, batch_size=2, collate_fn=contig_collate_fn)

batch = next(iter(ndl))
batch

{'inputs_embeds': tensor([[[-7.4164e-01,  6.7392e-01,  7.5514e-01,  ...,  9.5282e-01,
            1.4359e-01,  5.0000e-01],
          [-1.3501e-01, -4.8856e-01,  9.2102e-01,  ...,  5.2716e-02,
            4.1855e-01,  5.0000e-01],
          [-1.0075e+00, -1.3698e+00,  7.2967e-01,  ...,  2.7863e-01,
            8.1399e-01,  5.0000e-01],
          ...,
          [-7.0542e-01,  9.5284e-02,  1.6966e+00,  ..., -1.7093e-02,
            2.1791e+00,  5.0000e-01],
          [ 4.7144e-01, -5.7810e-01, -2.3907e-01,  ...,  7.3569e-01,
            3.2090e+00,  5.0000e-01],
          [ 6.7436e-01, -1.2447e+00,  1.9134e+00,  ..., -9.7436e-01,
           -4.9839e-01,  5.0000e-01]],
 
         [[ 1.6382e-01, -3.3953e-01,  1.0804e+00,  ...,  9.6821e-01,
           -1.1294e+00,  5.0000e-01],
          [ 6.9267e-04,  1.9632e-01, -2.1015e-01,  ...,  7.3679e-01,
            2.8885e-01,  5.0000e-01],
          [ 1.7924e-01,  1.6796e-01, -1.4275e+00,  ...,  7.8929e-01,
            2.0567e+00,  5.0000e-01],
  

### gLM Model

To load the gLM model, we must first build its architecture, which is based on the RoBERTa model.

In [None]:
HIDDEN_SIZE = 1280
HALF = True
EMB_DIM = 1281
NUM_PC_LABEL = 100

num_pred = 4
max_seq_length = 30
num_attention_heads = 10
num_hidden_layers = 19
pos_emb = "relative_key_query"
pred_probs = True

# Creates RoBERTa configuration to load gLM
config = RobertaConfig(
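    vocab_size = 30522,
    # Note: RobertaConfig's own argument is `max_position_embeddings` (plural), so this key is only
    # stored as an extra attribute and the default of 512 is used (see the printed model below);
    # changing it could make the architecture diverge from the one the pretrained weights load into
    max_position_embedding = max_seq_length,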
    hidden_size = HIDDEN_SIZE,
    num_attention_heads = num_attention_heads,
    type_vocab_size = 1,
    tie_word_embeddings = False,
    num_hidden_layers = num_hidden_layers,
    num_pc = NUM_PC_LABEL,
    num_pred = num_pred,
    predict_probs = pred_probs,
    emb_dim = EMB_DIM,
    output_attentions = True,
    output_hidden_states = True,
    position_embedding_type = pos_emb,
    attn_implementation = "eager"
)

# Loads configuration to build architecture
model = gLM(config)
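With `position_embedding_type="relative_key_query"`, Hugging Face's RoBERTa self-attention learns one vector per possible relative distance between two positions, sized to the per-head dimension, and uses it in both the query and key terms. The arithmetic below is a sketch that uses values from this notebook. It reproduces the `distance_embedding: Embedding(1023, 128)` shape that appears in the architecture printed after the weights are loaded.

```python
# Values taken from the configuration cell above
hidden_size = 1280
num_attention_heads = 10
# RobertaConfig's default, matching the printed position_embeddings: Embedding(512, 1280)
max_position_embeddings = 512

head_dim = hidden_size // num_attention_heads
# One learned vector per relative distance in [-(P - 1), P - 1]
num_relative_distances = 2 * max_position_embeddings - 1

print(head_dim, num_relative_distances)  # 128 1023
```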

Now, we need to download its pretrained weights:

In [29]:
# Downloads the pretrained gLM model
# This download can be painfully slow at times!
!mkdir -p model
!wget https://zenodo.org/records/7855545/files/glm.bin --output-document=./model/glm.bin

mkdir: cannot create directory ‘model’: File exists
--2025-06-24 15:12:10--  https://zenodo.org/records/7855545/files/glm.bin
Resolving zenodo.org (zenodo.org)... 188.185.43.25, 188.185.48.194, 188.185.45.92, ...
Connecting to zenodo.org (zenodo.org)|188.185.43.25|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3819298935 (3.6G) [application/octet-stream]
Saving to: ‘./model/glm.bin’


2025-06-24 15:30:38 (3.29 MB/s) - ‘./model/glm.bin’ saved [3819298935/3819298935]



Finally, we load the weights into the model and set it to evaluation mode.

In [30]:
# Loads saved weights
model_path = './model/glm.bin'
model.load_state_dict(torch.load(model_path, map_location=device), strict=False)
model.to(device)
model.eval()

gLM(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(30522, 1280, padding_idx=1)
    (position_embeddings): Embedding(512, 1280, padding_idx=1)
    (token_type_embeddings): Embedding(1, 1280)
    (LayerNorm): LayerNorm((1280,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0-18): 19 x RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSelfAttention(
            (query): Linear(in_features=1280, out_features=1280, bias=True)
            (key): Linear(in_features=1280, out_features=1280, bias=True)
            (value): Linear(in_features=1280, out_features=1280, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
            (distance_embedding): Embedding(1023, 128)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=1280, out_features=1280, bias=True)
            (LayerNorm): LayerNorm((1280
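Because the weights are loaded with `strict=False`, missing or unexpected keys are skipped silently instead of raising an error (size mismatches still raise). `load_state_dict()` returns the skipped keys, so a quick check like the one below can confirm that nothing important was left uninitialized. The example uses a tiny stand-in module, and `load_and_report` is a hypothetical helper.

```python
import torch
from torch import nn


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)


def load_and_report(model, state_dict):
    # strict=False tolerates mismatched keys but returns them instead of raising
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print("missing:", result.missing_keys)
    if result.unexpected_keys:
        print("unexpected:", result.unexpected_keys)
    return result


# A checkpoint that only covers the encoder and carries one extra entry
ckpt = {k: v for k, v in Tiny().state_dict().items() if k.startswith("encoder")}
ckpt["extra.weight"] = torch.zeros(1)
res = load_and_report(Tiny(), ckpt)
```

In this notebook, the equivalent call would be `load_and_report(model, torch.load(model_path, map_location=device))`.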

The model has two heads, `lm_head` and `contact_head`, for predicting the label embeddings and contacts, respectively. Notice that `lm_head` is an instance of `gLMMultiHead`, which produces 400 output features (4 candidate predictions, each with 100 dimensions: the 99 PCA components of the label plus the orientation column), and 4 logits for the candidates' corresponding probabilities.

### gLM Embeddings

To streamline the process of sending data to our gLM model, aggregating its outputs, and managing GPU memory to avoid OOM (out-of-memory) errors, we can use the `get_glm_embs()` function below:

In [51]:
def get_glm_embs(model, data_loader, device, half_precision=True):
    """
    Extracts embeddings and prediction outputs from a gLM model over a given dataset.

    Args:
        model (nn.Module): A gLM model that returns a dictionary of outputs including
                           'logits_all_preds', 'probs', 'last_hidden_state', and 'contacts'.
        data_loader (DataLoader): PyTorch DataLoader yielding batches containing:
                                  'inputs_embeds', 'labels', and 'attention_mask'.
        device (torch.device): The CUDA or CPU device to run the model on.
        half_precision (bool, optional): Whether to use automatic mixed precision (AMP) inference on GPU.
                                         Defaults to True, but is disabled if CUDA is unavailable.

    Returns:
        Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
            - all_batches: Dictionary with concatenated input batches (e.g., inputs_embeds, labels, etc.).
            - all_outputs: Dictionary with concatenated model outputs across all batches.
    """
    if not torch.cuda.is_available():
        half_precision = False  # Disable AMP if CUDA isn't available

    torch.cuda.empty_cache()

    with torch.no_grad():
        all_outputs = {}  # Stores outputs across all batches
        all_batches = {}  # Stores inputs across all batches

        for batch in data_loader:
            # Move batch components to the target device
            batch['inputs_embeds'] = batch['inputs_embeds'].to(device)
            batch['labels'] = batch['labels'].to(device)
            batch['attention_mask'] = batch['attention_mask'].to(device)

            # Run inference with AMP if half_precision is enabled
            if half_precision:
                # Autocast runs eligible ops in float16; no GradScaler is needed since no gradients are computed
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    output = model(
                        inputs_embeds=batch['inputs_embeds'],
                        labels=batch['labels'],
                        attention_mask=batch['attention_mask']
                    )
            else:
                output = model(
                    inputs_embeds=batch['inputs_embeds'],
                    labels=batch['labels'],
                    attention_mask=batch['attention_mask']
                )

            # Accumulate model outputs (detach and move to CPU to free GPU memory)
            for k in ['logits_all_preds', 'probs', 'last_hidden_state', 'contacts']:
                current = all_outputs.get(k, None)
                if current is None:
                    all_outputs[k] = output[k].detach().cpu()
                else:
                    all_outputs[k] = torch.cat([all_outputs[k], output[k].detach().cpu()], dim=0)

            del output

            # Accumulate input batches as well
            for k in batch.keys():
                current = all_batches.get(k, None)
                if current is None:
                    all_batches[k] = batch[k].detach().cpu()
                else:
                    all_batches[k] = torch.cat([all_batches[k], batch[k].detach().cpu()], dim=0)

            torch.cuda.empty_cache()  # Prevent GPU memory buildup

    return all_batches, all_outputs

The function above will return plenty of information:

- from the batches:
    - input embeddings (normalized pLM embeddings with an appended column)
    - label embeddings (PCA-transformed normalized pLM embeddings with an appended column)
    - protein IDs
- from the outputs:
    - contextual gLM embeddings (from the last hidden state in the gLM model)
    - contact prediction based on embeddings and attention scores
    - a set of 4 output embeddings (predicted label embeddings, `logits_all_preds`)
    - a set of 4 probabilities for their corresponding predictions (`probs`)
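`logits_all_preds` and `probs` can be combined to keep, for each protein, the candidate the model considers most likely. The sketch assumes that `logits_all_preds` has shape `(N, L, 400)`, with the 4 candidates stored as consecutive blocks of 100 columns, and that `probs` has shape `(N, L, 4)`. Check the actual layout in `glm_model.py` before relying on it. `best_candidate` is a hypothetical helper.

```python
import torch


def best_candidate(logits_all_preds, probs, num_pred=4):
    """Select, per protein, the predicted label embedding with the highest probability."""
    # Assumed layout: (N, L, num_pred * dim), candidate k in columns k*dim:(k+1)*dim
    n, l, total = logits_all_preds.shape
    cands = logits_all_preds.reshape(n, l, num_pred, total // num_pred)  # (N, L, num_pred, dim)
    best = probs.argmax(dim=-1)  # (N, L) index of the most likely candidate
    idx = best[..., None, None].expand(n, l, 1, cands.size(-1))  # (N, L, 1, dim)
    return cands.gather(2, idx).squeeze(2), best  # (N, L, dim), (N, L)
```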

In [75]:
# Retrieve gLM contextual embeddings
all_batches, all_outputs = get_glm_embs(model, ndl, device, HALF)
all_outputs.keys()

dict_keys(['logits_all_preds', 'probs', 'last_hidden_state', 'contacts'])

### Saving

Finally, we can use the `save_results()` function to organize the outputs, stack them (removing the batch dimension), and save them into three files, each named after the given prefix (`test` in the example below):
- one containing all results (`<prefix>_results.pt`)
- one containing only gLM embeddings (`<prefix>_glm_embs.pt`)
- one containing attention scores/contacts (`<prefix>_attention.pt`)

In [None]:
# Saves results to three files in ./results, prefixed with 'test'
save_results(all_batches, all_outputs, './results', 'test')

In [76]:
glm_embs = torch.load('./results/test_glm_embs.pt', weights_only=False)
glm_embs[0]

(0,
 array([ 0.3616 ,  0.0725 , -0.3357 , ..., -0.2272 ,  0.02904,  0.064  ],
       dtype=float16))

You should see the following result:

`(0, array([ 0.3616 ,  0.0725 , -0.3357 , ..., -0.2272 ,  0.02904,  0.064  ], dtype=float16))`
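To use the saved contextual embeddings downstream, it helps to map the sequential IDs back to FASTA identifiers. The sketch below assumes that `glm_embs` is a sequence of `(protein ID, array)` pairs, as the output above suggests, and that a mapping such as `cds.id2prot` is available. The helper name `embeddings_by_name` is hypothetical. It also casts the float16 arrays to float32, which avoids precision loss in later arithmetic such as similarity computations.

```python
import numpy as np


def embeddings_by_name(glm_embs, id2prot):
    """Return protein names and a (num_proteins, dim) float32 embedding matrix.

    glm_embs: iterable of (sequential protein ID, embedding array) pairs
    id2prot: mapping from sequential ID to FASTA identifier (e.g. cds.id2prot)
    """
    names = [id2prot[pid] for pid, _ in glm_embs]
    matrix = np.stack([np.asarray(emb, dtype=np.float32) for _, emb in glm_embs])
    return names, matrix
```

In this notebook, that would be `names, X = embeddings_by_name(glm_embs, cds.id2prot)`, so that row `i` of `X` holds the contextual embedding of protein `names[i]`.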