**NOTE**: If you need to install `ipywidgets`, you may have to restart the kernel afterwards so that `tqdm.auto` can load `IProgress` and render notebook progress bars.

In [1]:
__author__ = "Matteo Pariset"

In [4]:
import torch, esm
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import os

import sys
sys.path.append("../../")

from utils import relative_path, get_dataset_dir

# ESM embeddings generator
Version: 1.1.4

In [None]:
# Name of the processed dataset; it is also the prefix of every output file name
dataset_name = "developability-thera"
# Dataset directory; incremental_saver writes the .npz embedding files here
root_result = get_dataset_dir(dataset_name)
# The CSV column(s) containing the antibody (Ab) sequences
sequence_names = ["sequence"]
# The compression strategy: a key of `compression_strategies` ("uncompressed", "mean", "bos")
strategy = "bos"
# Target number of sequences per model forward pass
batch_size = 16

In [None]:
# Change the default cache dir
# Change the default torch.hub cache dir to ./cache (resolved relative to the current
# working directory); presumably this is where the pretrained ESM weights get cached — confirm
torch.hub.set_dir("./cache")

In [None]:
# Short name of the protein language model (PLM); it also appears in output file names
pretr_model_name = "esm-1v-1"

# Reject unknown names up front, then load the matching fair-esm checkpoint.
# The loader result is later unpacked as a (model, alphabet) pair.
if pretr_model_name not in ("esm-2", "esm-1v-1"):
    raise ValueError("Unsupported PLM")
pretr_model = (
    esm.pretrained.esm2_t33_650M_UR50D()
    if pretr_model_name == "esm-2"
    else esm.pretrained.esm1v_t33_650M_UR90S_1()
)

def init_tokenizer(max_seq_length):
    """Build the ESM batch converter for the globally loaded model's alphabet.

    ``max_seq_length`` is ignored; it is kept only so the signature stays stable.
    """
    _model, alphabet = pretr_model
    return alphabet.get_batch_converter()

def init_model(max_bpi_length):
    """Return the globally loaded ESM model, moved to ``cuda:0`` and set to eval mode.

    ``max_bpi_length`` (the padded token width of the batches) is accepted for
    interface compatibility but is not used.
    """
    model, _ = pretr_model
    model.to('cuda:0')
    # The fair-esm README calls model.eval() after loading to disable dropout;
    # without it the module may stay in training mode and yield
    # non-deterministic embeddings.
    model.eval()
    return model
    
def _identity(x):
    """Return the per-token representations unchanged."""
    return x


def _token_mean(x):
    """Average the representations over axis 1 (assumed to be the token axis)."""
    return x.mean(axis=1)


def _first_token(x):
    """Keep only position 0 of axis 1 (the BOS token ESM prepends)."""
    return x[:, 0]


# Reductions applied to each batch of model outputs, keyed by the `strategy` name.
# Assumes x is (batch, tokens, features) — TODO confirm for other callers.
# NOTE(review): on padded batches "mean" also averages BOS/EOS and <pad> positions.
compression_strategies = {
    "uncompressed": _identity,
    "mean": _token_mean,
    "bos": _first_token,
}

In [None]:
def insert_in_mold(mold, content, pad_value=1):
    """Fill ``mold`` with padding and copy ``content`` into its top-left corner.

    The mold is modified in place and also returned, so a single pre-allocated
    buffer can be reused for batches of different sizes.

    Args:
        mold: 2-D (batch, feats) array/tensor used as the destination buffer.
        content: 2-D array/tensor that must fit inside ``mold`` along both axes.
        pad_value: Value written to every cell not covered by ``content``. The
            default of 1 is the id of the '<pad>' token in ESM1b (ESM alphabets
            expose it as ``alphabet.padding_idx``).

    Returns:
        ``mold`` itself.

    Raises:
        ValueError: if ``mold`` or ``content`` is not 2-D, or ``content`` does not
            fit inside ``mold``.
    """
    # Explicit checks instead of `assert`, which is stripped under `python -O`
    if len(mold.shape) != 2:
        raise ValueError("Only (batch, feats) molds are supported")
    if len(content.shape) != 2:
        raise ValueError("Only (batch, feats) contents are supported")
    if content.shape[0] > mold.shape[0] or content.shape[1] > mold.shape[1]:
        raise ValueError(
            f"Content of shape {tuple(content.shape)} does not fit in mold of shape {tuple(mold.shape)}"
        )
    mold[:, :] = pad_value
    mold[:content.shape[0], :content.shape[1]] = content
    return mold

In [None]:
def compute_fast_embeddings(filename, sequence_names, batch_size, compression_func=None, max_bpi_len=None, show_progress=True, incremental_saving=None, saver_func=None):
    """Compute ESM embeddings for one or more sequence columns of a CSV file.

    Each column in ``sequence_names`` is tokenized with the ESM batch converter and
    split into ``len(rows) // batch_size`` batches with ``np.array_split``. Each
    batch is run through the model on ``cuda:0``. All batches are padded to one
    common token width: the longest tokenized batch, or ``max_bpi_len`` if that is
    larger. This lets per-batch outputs be concatenated even for the
    "uncompressed" strategy and for columns of different lengths. The layer-33
    representations are moved to the CPU and reduced with ``compression_func``.

    Args:
        filename: Path of the CSV file to read.
        sequence_names: Names of the CSV columns containing the sequences.
        batch_size: Target number of sequences per forward pass, capped at the
            number of rows. ``np.array_split`` may put extra rows into some batches.
        compression_func: Callable applied to each batch of representations.
            Defaults to the identity (no compression).
        max_bpi_len: Minimum padded token width of every batch; ``None`` means 0.
        show_progress: Whether to print the start message and show a progress bar.
        incremental_saving: If not ``None``, the number of batches after which
            the accumulated results are passed to ``saver_func`` and discarded.
        saver_func: Callback ``saver_func(hierarchy, slice_index)``. It is
            required when ``incremental_saving`` is set.

    Returns:
        When ``incremental_saving`` is ``None``, a list of ``(column_name, array)``
        pairs with float32 arrays concatenated along the batch axis. Otherwise ``None``.

    Raises:
        ValueError: if the file has no rows, ``batch_size`` or
            ``incremental_saving`` is not positive, or ``incremental_saving`` is
            set without a ``saver_func``.
    """
    if compression_func is None:
        # No compression requested: keep the full per-token representations
        compression_func = lambda x: x
    if incremental_saving is not None:
        if incremental_saving < 1:
            raise ValueError("incremental_saving must be a positive number of batches")
        if saver_func is None:
            raise ValueError("saver_func is required when incremental_saving is set")
    if batch_size < 1:
        raise ValueError("batch_size must be positive")

    file_content = pd.read_csv(filename)

    def saving_hierarchy(idx_names, sequence_names, outs):
        # Pair every column name with its batches concatenated into one float32 array
        return [(name, np.concatenate(outs[i], axis=0).astype('float32')) for i, name in zip(idx_names, sequence_names)]

    if show_progress:
        print("Executing ...")

    # Compute the number of batches (using first col of seqs)
    probe_seqs = file_content[sequence_names[0]].astype("string").to_numpy()
    print("INFO: Loaded %d rows" % len(probe_seqs))
    if len(probe_seqs) == 0:
        raise ValueError(f"No rows found in {filename}")
    # Adjust batch size
    batch_size = min(batch_size, probe_seqs.shape[0])
    batch_num = len(probe_seqs) // batch_size

    all_seqs = [file_content[name].astype("string").to_numpy() for name in sequence_names]
    # ESM uses SOS, EOS toks. Columns may differ in length: every batch is padded to
    # the global max_bpi_len below, so no per-column restriction is needed.
    max_seq_len = max(len(x) for seqs in all_seqs for x in seqs) + 2

    input_encoder = init_tokenizer(max_seq_len)
    batches_encs = {}

    if max_bpi_len is None:
        max_bpi_len = 0

    for seq_idx, seqs in enumerate(all_seqs):
        in_batch = np.array_split(seqs, batch_num, axis=0)
        # Element 2 of the batch converter output is the token tensor
        seq_batch_encs = [input_encoder([(("seq_%d" % i), seq) for i, seq in enumerate(batch_seqs)])[2] for batch_seqs in in_batch]
        max_bpi_len = max(max_bpi_len, *(enc.shape[1] for enc in seq_batch_encs))
        batches_encs[seq_idx] = seq_batch_encs

    print("max_seq_len:", max_seq_len, "max_batch_len:", max_bpi_len)

    model = init_model(max_bpi_len)

    names_idxs = list(range(len(sequence_names)))

    # np.array_split can yield batches much larger than batch_size + 1 (e.g. 31 rows
    # with batch_size 16 give a single 31-row batch), so size the reusable mold by
    # the largest batch actually produced.
    max_batch_rows = max(enc.shape[0] for encs in batches_encs.values() for enc in encs)
    mold_input = torch.ones((max_batch_rows, max_bpi_len), dtype=torch.int, device='cuda:0')

    outs = {i: [] for i in names_idxs}
    slice_num = 0

    with torch.no_grad():
        for batch_idx in tqdm(range(batch_num), disable=(not show_progress)):
            for seq_idx in names_idxs:
                batch_tokens = batches_encs[seq_idx][batch_idx]
                # Pad into the shared mold, then keep only this batch's rows
                padded = insert_in_mold(mold_input, batch_tokens)[:batch_tokens.shape[0]]
                reprs = model(padded, repr_layers=[33], return_contacts=False)['representations'][33]
                outs[seq_idx].append(compression_func(reprs.cpu()))

            # Flush after every `incremental_saving` batches, so each slice holds
            # exactly that many batches. The last slice may hold fewer.
            if incremental_saving is not None and (batch_idx + 1) % incremental_saving == 0:
                print(f"[INFO] Saving slice {slice_num} to file")
                saver_func(saving_hierarchy(names_idxs, sequence_names, outs), slice_num)
                slice_num += 1
                outs = {i: [] for i in names_idxs}

    if incremental_saving is None:
        return saving_hierarchy(names_idxs, sequence_names, outs)
    # Save whatever remains after the last full slice
    if len(outs[0]) > 0:
        saver_func(saving_hierarchy(names_idxs, sequence_names, outs), slice_num)
    print("[INFO] Done")

In [None]:
def as_filename(dataset_name, strategy="", slice=None):
    """Build the .npz file name for a dataset, the current PLM and optional extras.

    The pattern is ``{dataset}_{model}[_{strategy}][_p{slice}].npz``. The strategy
    part is added only when it is a key of ``compression_strategies``.
    """
    segments = [f"{dataset_name}_{pretr_model_name}"]
    known_strategy = strategy != "" and strategy in compression_strategies
    if known_strategy:
        segments.append(f"_{strategy}")
    if slice is not None:
        segments.append(f"_p{slice}")
    segments.append(".npz")
    return "".join(segments)

In [None]:
def incremental_saver(content, i):
    """Write one slice of embeddings to ``root_result`` as a .npz archive.

    ``content`` is a list of ``(column_name, array)`` pairs. A lone column named
    "sequence" is stored positionally (numpy's ``arr_0`` key). Any other layout is
    stored under its column names. ``i`` is the slice index used in the file name.
    """
    anonymous = len(content) <= 1 and content[0][0] == "sequence"
    if anonymous:
        # Anonymous saving
        print("[INFO] Using anonymous hierarchy")
    else:
        # Named saving
        print("[INFO] Using named hierarchy")
    target = os.path.join(root_result, as_filename(dataset_name, strategy, slice=i))
    if anonymous:
        np.savez(target, *[array for _, array in content])
    else:
        np.savez(target, **dict(content))

##### Compute the embeddings and save them incrementally to multiple files

In [None]:
compute_fast_embeddings(
    os.path.join(get_dataset_dir(dataset_name), f"{dataset_name}.csv"),
    sequence_names,
    batch_size=batch_size,
    compression_func=compression_strategies[strategy],
    # Hand partial results to incremental_saver periodically (every ~256 batches)
    incremental_saving=256,
    saver_func=incremental_saver,
)