In [None]:
import sys
sys.path.append('..')

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
from DomainPrediction import BaseProtein

In [None]:
# Load the structure file for the GxpS A-T-C construct
# (NOTE(review): "_AF" in the filename presumably means an AlphaFold model - confirm).
protein = BaseProtein(file='../../Data/GxpS_ATC_AF.pdb')

# 0-based sequence indices 538..607 of the T domain (residues 539-608 when counted from 1).
T = list(range(538, 608))

In [None]:
# Show, as the cell output, what get_residues returns for the T-domain indices.
protein.get_residues(T) ## T domain

In [None]:
# Build the model query: every position listed in T becomes a <mask> token,
# all other residues are copied through unchanged.
masked_query = ''.join(
    '<mask>' if idx in T else residue
    for idx, residue in enumerate(protein.sequence)
)

In [None]:
# Display the masked sequence so the <mask> placement can be checked by eye.
masked_query

In [None]:
import torch
import esm

class ESM2():
    """Small inference wrapper around a pretrained ESM checkpoint.

    Holds the model, its alphabet and batch converter, and exposes helpers
    that run a single sequence through the model and return the raw results,
    the logits, or per-position token probabilities.
    """

    def __init__(self, model_path, device='cpu') -> None:
        """Load the checkpoint and switch the model to eval mode.

        Args:
            model_path: Checkpoint location, passed to
                ``esm.pretrained.load_model_and_alphabet``.
            device: ``'gpu'`` moves the model (and, later, the input tokens)
                to CUDA; any other value skips the CUDA transfer.
        """
        self.model, self.alphabet = esm.pretrained.load_model_and_alphabet(model_path)
        self.batch_converter = self.alphabet.get_batch_converter()
        self.model.eval()
        self.device = device

        if self.device == 'gpu':
            self.model.cuda()

        # Token <-> index lookup tables taken from the model alphabet.
        self.tok_to_idx = self.alphabet.tok_to_idx
        self.idx_to_tok = {v:k for k,v in self.tok_to_idx.items()}

    def get_res(self, sequence, return_contacts=True):
        """Run one sequence through the model and return the raw result dict.

        Args:
            sequence: Sequence string; may contain ``<mask>`` tokens.
            return_contacts: Forwarded to the model. Defaults to ``True``
                (the historical behaviour); pass ``False`` when only logits
                are needed so the model does not have to return attention
                maps / contacts.

        Returns:
            Whatever the model call returns (a dict that includes ``'logits'``).
        """
        data = [
            ("protein1", sequence)
        ]
        _, _, batch_tokens = self.batch_converter(data)

        if self.device == 'gpu':
            batch_tokens = batch_tokens.cuda()

        # Ask for the final layer of whichever checkpoint is loaded instead of
        # a fixed 33, which only matches the 33-layer models. Models that do
        # not expose ``num_layers`` keep the previous value.
        final_layer = getattr(self.model, 'num_layers', 33)

        with torch.no_grad():
            results = self.model(
                batch_tokens,
                repr_layers=[final_layer],
                return_contacts=return_contacts,
            )

        return results

    def get_logits(self, sequence):
        """Return the ``'logits'`` entry of the model output for ``sequence``."""
        # Contacts are not used here, so skip them to save memory and time.
        results = self.get_res(sequence, return_contacts=False)
        return results['logits']

    def get_prob(self, sequence):
        """Return per-position token probabilities as a NumPy array.

        Softmax is taken over the last logits axis for the first (only) batch
        entry; the first and last positions are dropped because they are the
        start and end tokens added around the sequence.
        """
        logits = self.get_logits(sequence)
        prob = torch.nn.functional.softmax(logits, dim=-1)[0, 1:-1, :] # 1st and last are start and end tokens

        return prob.cpu().numpy()

In [None]:
import os

# Checkpoint to load. Set the ESM2_MODEL_PATH environment variable to use a
# different copy or model size without editing this cell.
model_path = os.environ.get(
    'ESM2_MODEL_PATH',
    '/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t30_150M_UR50D.pt',
)
# model_path = '/data/users/kgeorge/workspace/esm2/checkpoints/esm2_t6_8M_UR50D.pt'

# ESM2 only moves to CUDA for device='gpu'; fall back to CPU when no GPU is
# visible so the notebook still runs (slowly) on CPU-only machines.
esm2 = ESM2(model_path = model_path, device='gpu' if torch.cuda.is_available() else 'cpu')

In [None]:
# Token probabilities for the masked query: softmax of the model logits with the
# first and last (start/end) positions dropped, returned as a NumPy array.
prob = esm2.get_prob(sequence=masked_query)

In [None]:
from tqdm import tqdm
import numpy as np

def compute_perplexity(model, sequence, mask_token='<mask>'):
    r'''
        Pseudo-perplexity of ``sequence`` under a masked language model.

        pseudoperplexity(x) = exp( -1/L \sum_{i=1}^{L} [log( p(x_{i}|x_{j!=i}) )] )

        Each position is masked in turn (one model call per position) and the
        probability the model assigns to the true token at that position is read
        from ``model.get_prob``.

        Args:
            model: object exposing ``get_prob(sequence=...)`` (one row of token
                probabilities per position) and a ``tok_to_idx`` mapping.
            sequence: the unmasked sequence, one character per residue.
            mask_token: token substituted at the position being scored.

        Returns:
            The pseudo-perplexity as a float64 scalar.

        Raises:
            ValueError: if ``sequence`` is empty or already contains ``mask_token``.
    '''
    if len(sequence) == 0:
        raise ValueError('sequence must contain at least one residue')
    # Checked on the whole sequence: once it is split into single characters a
    # multi-character mask token could never be found.
    if mask_token in sequence:
        raise ValueError(f'sequence must not already contain {mask_token!r}')

    # Accumulate in double precision so rounding does not build up over long sequences.
    sum_log = 0.0
    for pos in tqdm(range(len(sequence))):
        masked_query = list(sequence)
        masked_query[pos] = mask_token
        masked_query = ''.join(masked_query)
        prob = model.get_prob(sequence=masked_query)

        # One probability row per residue, otherwise `pos` would not line up.
        assert prob.shape[0] == len(sequence)

        true_token_idx = model.tok_to_idx[sequence[pos]]
        sum_log += np.log(float(prob[pos, true_token_idx]))

    return np.exp(-1*sum_log/len(sequence))


In [None]:
# Pseudo-perplexity of the full, unmasked sequence; this runs one model
# forward pass per residue, so it can take a while.
perplexity = compute_perplexity(esm2, protein.sequence)