In [106]:
import esm
import sys, os

import torch.nn as nn
from torch.nn.parallel import DataParallel
import importlib
from ephod import run, utils

importlib.reload(run)
importlib.reload(utils)

<module 'ephod.utils' from '/Users/jgado/Dropbox/research/projects/ephod/ephod_publish/EpHod/ephod/utils.py'>

In [8]:
run.models.ResidualLightAttention

ephod.training.nn_models.ResidualLightAttention

In [9]:
path = '/Users/jgado/Dropbox/research/projects/ephod/ephod_publish/zenodo'

In [10]:
import pandas as pd
import numpy as np
import joblib

In [95]:
ephodsvr = joblib.load(f"{path}/ESM1v-SVR.pkl")

In [96]:
ephodsvr[0]

In [32]:
from ephod.training import nn_models
importlib.reload(nn_models)
from ephod.training.nn_models import ResidualLightAttention

In [50]:
total = 0
for p in model.parameters():
    total += torch.sum(p)
total

tensor(12862.9707, grad_fn=<AddBackward0>)

In [56]:
model = ResidualLightAttention(dim=1280, kernel_size=7, dropout=0.1, res_blocks=4, activation='elu')
model = DataParallel(model)
url = 'https://zenodo.org/records/14252615/files/ESM1v-RLATtr.pt?download=1'
model_dict = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
model.load_state_dict(model_dict)

<All keys matched successfully>

In [58]:
total = 0
for p in model.parameters():
    total += torch.sum(p)
total

tensor(-63.0517, grad_fn=<AddBackward0>)

In [107]:
class EpHodModel():
    
    def __init__(self):

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if self.device != 'cuda':
            print('WARNING: You are not using a GPU which will be slow.')
        self.esm1v_model, self.esm1v_batch_converter = self.load_ESM1v_model()
        self.svr_model, self.svr_stats = self.load_SVR_model()
        self.rlat_model = self.load_RLAT_model()
        _ = self.esm1v_model.eval()
        _ = self.rlat_model.eval()
        
    
    def load_ESM1v_model(self):
        '''Return pretrained ESM1v model weights and batch converter'''
        
        model, alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1()
        model = model.to(self.device)
        batch_converter = alphabet.get_batch_converter()
        
        return model, batch_converter
    
    
    def get_ESM1v_embeddings(self, accs, seqs):
        '''Return per-residue embeddings (padded) for protein sequences from ESM1v model'''

        seqs = [utils.replace_noncanonical(seq, 'X') for seq in seqs]
        data = [(accs[i], seqs[i]) for i in range(len(accs))]
        batch_labels, batch_strs, batch_tokens = self.esm1v_batch_converter(data)
        batch_tokens = batch_tokens.to(device=self.device, non_blocking=True)
        emb = self.esm1v_model(batch_tokens, repr_layers=[33], return_contacts=False)
        emb = emb["representations"][33]
        emb = emb.transpose(2,1) # From (batch, seqlen, features) to (batch, features, seqlen)
        emb = emb.to(self.device)

        return emb
    
    
    def load_RLAT_model(self):
        '''Return residual light attention top model'''
        
        model = ResidualLightAttention(dim=1280, kernel_size=7, dropout=0.1, res_blocks=4, activation='elu')
        model = DataParallel(model)
        url = 'https://zenodo.org/records/14252615/files/ESM1v-RLATtr.pt?download=1'
        model_dict = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
        model.load_state_dict(model_dict)
        model.to(self.device)

        return model

    
    def load_SVR_model(self):
        '''Return SVR top model'''
        
        this_dir, this_filename = os.path.split(__file__)
        path = os.path.join(this_dir, 'data', 'ESM1v-SVR.pkl')
        svr_model, svr_stats = joblib.load(path)

        return svr_model, svr_stats
        

    
    def batch_predict_rlat(self, accs, seqs):
        '''Predict pHopt with EpHod on a batch of sequences'''
        
        with torch.no_grad():
            
            # Get ESM1v embeddings and run RLATtr model
            emb_esm1v = self.get_ESM1v_embeddings(accs, seqs)
            maxlen = emb_esm1v.shape[-1]
            masks = [[1] * len(seqs[i]) + [0] * (maxlen - len(seqs[i])) \
                     for i in range(len(seqs))]
            masks = torch.tensor(masks, dtype=torch.int32)
            masks = masks.to(self.device)
            out = self.rlat_model(emb_esm1v, masks)
            rlat_pred, rlat_emb, rlat_attn = [item.cpu().numpy() for item in out]
        
            # Run SVR
            emb_pool = emb_esm1v.numpy().mean(axis=-1) # (batch, features, seqlen)
            emb_pool = (emb_pool - self.svr_stats[0]) / (self.svr_stats[1] + 1e-8) # Normalize
            svr_pred = self.svr_model.predict(emb_pool)
            ensemble_pred = (rlat_pred + svr_pred) / 2
        outdict = dict(rlat_pred=rlat_pred, rlat_emb=rlat_emb, rlat_attn=rlat_attn, 
                       svr_pred=svr_pred, ensemble_pred=ensemble_pred)

        return outdict
            
                
                

In [None]:
ephod_model = EpHodModel()

In [78]:
accs = ['1', '2']
seqs = ['MNTDVRIEKDFLGEKEIPKDAYYGVQTIRATENFPITGYRIHPELIKSLGIVKKSAALANMEVGLLDKEVGQYIVKAADEVIEGKWNDQFIVDPIQGGAGTSINMNANEVIANRALELMGEEKGNYSKISPNSHVNMSQSTNDAFPTATHIAVLSLLNQLIETTKYMQQEFMKKADEFAGVIKMGRTHLQDAVPILLGQEFEAYARVIARDIERIANTRNNLYDINMGATAVGTGLNADPEYISIVTEHLAKFSGHPLRSAQHLVDATQNTDCYTEVSSALKVCMINMSKIANDLRLMASGPRAGLSEIVLPARQPGSSIMPGKVNPVMPEVMNQVAFQVFGNDLTITSASEAGQFELNVMEPVLFFNLIQSISIMTNVFKSFTENCLKGIKANEERMKEYVEKSIGIITAINPHVGYETAAKLAREAYLTGESIRELCIKYGVLTEEQLNEILNPYEMTHPGIAGRK', 
        'MTAIIDIVGREILDSRGNPTVEVDVVLEDGSFGRAAVPSGASTGAHEAVELRDGGSRYLGKGVEKAVEVVNGKIFDAIAGMDAESQLLIDQTLIDLDGSANKGNLGANAILGVSLAVAKAAAQASGLPLYRYVGGTNAHVLPVPMMNIINGGAHADNPIDFQEFMILPVGATSIREAVRYGSEVFHTLKKRLKDAGHNTNVGDEGGFAPNLKNAQAALDFIMESIEKAGFKPGEDIALGLDCAATEFFKDGNYVYEGERKTRDPKAQAKYLAKLASDYPIVTIEDGMAEDDWEGWKYLTDLIGNKCQLVGDDLFVTNSARLRDGIRLGVANSILVKVNQIGSLSETLDAVETAHKAGYTAVMSHRSGETEDSTIADLAVATNCGQIKTGSLARSDRTAKYNQLIRIEEELGKQARYAGRSALKLL'
       ]
        

In [79]:
out = ephod_model.batch_predict(accs, seqs)
out[0]

tensor([7.8687, 7.2128])

In [98]:
out[2].shape

torch.Size([2, 1280, 470])

In [115]:
from ephod import run, utils
importlib.reload(utils)
importlib.reload(run)


<module 'ephod.run' from '/Users/jgado/Dropbox/research/projects/ephod/ephod_publish/EpHod/ephod/run.py'>

In [116]:
ephod_model = run.EpHodModel()



In [None]:
accs = ['1', '2']
seqs = ['MNTDVRIEKDFLGEKEIPKDAYYGVQTIRATENFPITGYRIHPELIKSLGIVKKSAALANMEVGLLDKEVGQYIVKAADEVIEGKWNDQFIVDPIQGGAGTSINMNANEVIANRALELMGEEKGNYSKISPNSHVNMSQSTNDAFPTATHIAVLSLLNQLIETTKYMQQEFMKKADEFAGVIKMGRTHLQDAVPILLGQEFEAYARVIARDIERIANTRNNLYDINMGATAVGTGLNADPEYISIVTEHLAKFSGHPLRSAQHLVDATQNTDCYTEVSSALKVCMINMSKIANDLRLMASGPRAGLSEIVLPARQPGSSIMPGKVNPVMPEVMNQVAFQVFGNDLTITSASEAGQFELNVMEPVLFFNLIQSISIMTNVFKSFTENCLKGIKANEERMKEYVEKSIGIITAINPHVGYETAAKLAREAYLTGESIRELCIKYGVLTEEQLNEILNPYEMTHPGIAGRK', 
        'MTAIIDIVGREILDSRGNPTVEVDVVLEDGSFGRAAVPSGASTGAHEAVELRDGGSRYLGKGVEKAVEVVNGKIFDAIAGMDAESQLLIDQTLIDLDGSANKGNLGANAILGVSLAVAKAAAQASGLPLYRYVGGTNAHVLPVPMMNIINGGAHADNPIDFQEFMILPVGATSIREAVRYGSEVFHTLKKRLKDAGHNTNVGDEGGFAPNLKNAQAALDFIMESIEKAGFKPGEDIALGLDCAATEFFKDGNYVYEGERKTRDPKAQAKYLAKLASDYPIVTIEDGMAEDDWEGWKYLTDLIGNKCQLVGDDLFVTNSARLRDGIRLGVANSILVKVNQIGSLSETLDAVETAHKAGYTAVMSHRSGETEDSTIADLAVATNCGQIKTGSLARSDRTAKYNQLIRIEEELGKQARYAGRSALKLL'
       ]
out = ephod_model.batch_predict(accs, seqs)