In [2]:
import os
import json 
from tqdm import tqdm
import numpy as np
import pickle

This notebook presents the script that generates the ESMIN data from the LLRs of ESM models.
Here, we take the element-wise minimum of the LLRs across the 11 models listed below.
The LLRs of each individual model can be computed with `infer_llrs.py`. Please refer to README.md for the commands to run it.

The individual LLRs are stored in the `base_llrs` folder (download the precomputed LLRs [here](https://huggingface.co/datasets/ntranoslab/vesm_datasets/blob/main/base_llrs.zip)). Different ESMIN variants can be generated by varying the set of models.

In [5]:
# Base ESM models combined into ESMIN: short alias -> full ESM checkpoint name.
# The aliases (not the full names) are what the rest of the notebook uses, e.g. to
# locate each model's LLR file as llrs_<alias>.pkl.
esm_name_dct = dict(
    esm1b='esm1b_t33_650M_UR50S',
    esm1v_1='esm1v_t33_650M_UR90S_1',
    esm1v_2='esm1v_t33_650M_UR90S_2',
    esm1v_3='esm1v_t33_650M_UR90S_3',
    esm1v_4='esm1v_t33_650M_UR90S_4',
    esm1v_5='esm1v_t33_650M_UR90S_5',
    esm2_650m='esm2_t33_650M_UR50D',
    esm2_3b='esm2_t36_3B_UR50D',
    esm2_150m='esm2_t30_150M_UR50D',
    esm2_8m='esm2_t6_8M_UR50D',
    esm2_35m='esm2_t12_35M_UR50D',
)
# Model aliases in insertion order; this order fixes the order of the printed statistics.
models = list(esm_name_dct)

####  Helper Functions

In [3]:
def read_json(fpath):
    """Parse the JSON file at ``fpath`` and return the decoded Python object."""
    with open(fpath, 'r') as handle:
        return json.load(handle)

def read_pkl(fpath):
    """Unpickle and return the object stored in the file at ``fpath``.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file, so only
    use this on pickles from a trusted source (e.g. verify downloaded archives first).
    """
    with open(fpath, 'rb') as handle:
        return pickle.load(handle)

def dict2json(json_pth, dct):
    """Write ``dct`` to ``json_pth`` as JSON indented by 2 spaces, replacing any existing file."""
    with open(json_pth, "w") as sink:
        json.dump(dct, sink, indent=2)

def try_makedir(fdir):
    """Create directory ``fdir`` (including missing parents) unless it already exists.

    ``exist_ok=True`` folds the existence check into the creation call. A separate
    ``os.path.exists`` check followed by ``os.makedirs`` fails if another process creates
    the directory in between.

    Raises:
        FileExistsError: if ``fdir`` exists but is not a directory.
    """
    os.makedirs(fdir, exist_ok=True)

def write_pkl(dct, fpath):
    """Serialise ``dct`` with pickle (default protocol) into ``fpath``, overwriting the file."""
    with open(fpath, 'wb') as sink:
        pickle.dump(dct, sink)
        

In [4]:
# break long sequences into multiple segments
def break_long_sequence(seq_length, model_window=1022):
    """Split positions ``[0, seq_length)`` into overlapping windows of at most ``model_window``.

    Consecutive windows start ``model_window // 2`` positions apart, so neighbours overlap
    by about half a window. The final window is clipped at ``seq_length``. The default of
    1022 is presumably the ESM 1024-token limit minus the two special tokens (TODO confirm).

    Args:
        seq_length: Number of positions to cover.
        model_window: Maximum window length.

    Returns:
        A list of ``[start, end]`` pairs (``end`` exclusive). A sequence that fits in a
        single window yields ``[[0, seq_length]]``.

    Raises:
        ValueError: if splitting is required but ``model_window < 2``. The stride
            ``model_window // 2`` would then be 0 and the window start would never advance.
    """
    if seq_length <= model_window:
        return [[0, seq_length]]

    stride = model_window // 2
    if stride < 1:
        raise ValueError(
            f"model_window must be >= 2 to split a sequence of length {seq_length}, "
            f"got {model_window}"
        )

    windows = []
    start = end = 0
    while end < seq_length:
        end = min(seq_length, start + model_window)
        windows.append([start, end])
        start += stride
    return windows
    
def filter_segments(uid, sequence, scores, segments, max_len=1022):
    """Cut a long protein into windows and keep only the windows whose start is in ``segments``.

    Windows come from ``break_long_sequence(len(sequence), model_window=max_len)``. For each
    kept window ``[start, end)`` the rows ``start:end`` of ``scores`` are taken, presumably
    one row per residue (TODO confirm). The matching slice of ``sequence`` is taken too.

    Returns:
        A dict keyed ``"<uid>_segment_<start>"``, each value being
        ``{"llrs": scores[start:end, :], "sequence": sequence[start:end]}``.
    """
    windows = break_long_sequence(len(sequence), model_window=max_len)
    return {
        f"{uid}_segment_{start}": {
            "llrs": scores[start:end, :],
            "sequence": sequence[start:end],
        }
        for start, end in windows
        if start in segments
    }

def get_stats(data_dct, key=None, return_variance=False):
    """Return the pooled mean over every entry of every array in ``data_dct``.

    Each array's mean is weighted by its number of entries. The result is therefore the
    mean of all entries, not the mean of the per-array means.

    Args:
        data_dct: Mapping whose values are arrays, or, when ``key == 'llrs'``, dicts
            holding the array under ``'llrs'``.
        key: ``'llrs'`` reads ``value['llrs']``. ``'shorten'`` uses ``value[1:-1, 4:24]``,
            i.e. it drops the first and last rows and keeps columns 4-23. This presumably
            strips special tokens and keeps the 20 amino-acid columns (TODO confirm). Any
            other value uses each value as-is.
        return_variance: If True, also return the pooled population variance (ddof=0).
            Defaults to False, which keeps the original mean-only return value.

    Returns:
        The pooled mean, or ``(mean, variance)`` when ``return_variance`` is True.

    Raises:
        ValueError: if there are no array entries to average (empty ``data_dct`` or only
            zero-size arrays).
    """
    total = 0        # sum over arrays of (array mean * array size) == sum of all entries
    total_sq = 0     # sum of squared entries, only accumulated when variance is requested
    count = 0
    for value in data_dct.values():
        if key == 'llrs':
            arr = value['llrs']
        elif key == 'shorten':
            arr = value[1:-1, 4:24]
        else:
            arr = value
        size = int(np.prod(arr.shape))
        if size == 0:
            # An empty array adds nothing to a pooled mean, and its .mean() would be NaN.
            continue
        arr_mean = arr.mean()
        count += size
        total += arr_mean * size
        if return_variance:
            # E[x^2] for this array is std^2 + mean^2 (population std, ddof=0).
            total_sq += (arr.std() ** 2 + arr_mean ** 2) * size

    if count == 0:
        raise ValueError("get_stats: no array entries to average (empty input)")

    mu = total / count
    if return_variance:
        return mu, total_sq / count - mu ** 2
    return mu

## Generating data from base-model LLRs

In [6]:
# Root data folder, and the folder holding each base model's LLR pickle (llrs_<model>.pkl).
data_dir = "."
llr_dir = f"{data_dir}/base_llrs"

# UniProt metadata. Used below: "short_proteins" / "long_proteins" (protein ids),
# "seq_dct" (id -> sequence) and "segment_dct" (id -> window starts to keep for long proteins).
meta_dct = read_pkl(f"{data_dir}/train/UniProtKB_meta_data.pkl")
segment_dct, seq_dct = meta_dct["segment_dct"], meta_dct["seq_dct"]

In [None]:
# Load every base model's LLR table once. They are reused later for per-model statistics.
esm_llr_dict = {name: read_pkl(f"{llr_dir}/llrs_{name}.pkl") for name in models}


def _esmin_llrs(uid):
    """Element-wise minimum of protein ``uid``'s LLR arrays across all models in ``models``."""
    return np.minimum.reduce([esm_llr_dict[name][uid] for name in models])


esmin_dct = {}

# Short proteins fit in a single model window and are stored whole.
for protein in meta_dct["short_proteins"]:
    esmin_llrs = _esmin_llrs(protein)
    esmin_dct[protein] = {
        "sequence": seq_dct[protein],
        "llrs": esmin_llrs,
    }

# Long proteins are split into overlapping windows; only the window starts listed in
# segment_dct are kept, under keys "<protein>_segment_<start>".
for protein in meta_dct["long_proteins"]:
    esmin_dct.update(
        filter_segments(
            protein,
            seq_dct[protein],
            _esmin_llrs(protein),
            segment_dct[protein],
            max_len=1022,
        )
    )

: 

Store the statistics (pooled mean LLR of ESMIN and of each base model) for the adaptive mu, and save the ESMIN data.

In [None]:
# Pooled mean LLR of the ESMIN data and of each base model, printed as "name": value, lines.
print(f'"min": {get_stats(esmin_dct, "llrs")},')
for name in models:
    mean_llr = get_stats(esm_llr_dict[name])
    print(f'"{name}": {mean_llr},')

# Persist the combined ESMIN LLRs next to the training metadata.
write_pkl(esmin_dct, f"{data_dir}/train/ESM11_UniProt_min.pkl")

"min": -10.223638271060164,
"esm1b": -8.279172428875482,
"esm1v_1": -6.651073636222643,
"esm1v_2": -7.217647505916745,
"esm1v_3": -6.587440777276117,
"esm1v_4": -6.823804579465309,
"esm1v_5": -7.280105659452346,
"esm2_650m": -6.9738925928732,
"esm2_3b": -8.36833643554398,
"esm2_150m": -5.512521036469033,
"esm2_8m": -4.265973882034055,
"esm2_35m": -4.763005785707714,
