In [3]:
import os
import sys
import hashlib
import warnings
sys.dont_write_bytecode=True

import numpy as np
import pandas as pd
#import matplotlib.pyplot as plt

import torch

from tokenization_esm import EsmTokenizer
from modeling_esm import EsmForSequenceClassificationMHACustom


  from .autonotebook import tqdm as notebook_tqdm


## Load data: example kinase and peptide sequences

In [4]:
# Example inputs for a single prediction, written inline as plain strings of
# one-letter codes (presumably amino-acid sequences — confirm against the data source).
# NOTE(review): despite the plural names, each variable holds ONE sequence here;
# `peptides` is 15 characters long.
kinases = 'YRIRGEIGSGNFSQVKLGIHSLTKEKVAIKILDKTKLDQKTQRLLSREISSMEKLHHPNIIRLYEVVETLSKLHLVMEYAGGGELFGKISTEGKLSEPESKLIFSQIVSAVKHMHENQIIHRDLKAENVFYTSNTCVKVGDFGFSTVSKKGEMLNTFCGSPPYAAPELFRDEHYIGIYVDIWALGVLLYFMVTGTMPFRAETVAKLKKSILEGTYSVPPHVSEPCHRLIRGVLQQIPTERYGIDCIMNDEWM'
peptides = 'SSLRRHGSMVSLVSG'


## load model

In [5]:
# Relative path handed to `from_pretrained` for both the tokenizer and the model.
model_dir = 'model/'

# Load the tokenizer and the project-local ESM classifier from the same directory.
# `num_labels=2` is forwarded to `from_pretrained`; presumably it configures a
# two-class head (class index 1 is what run_model reports as the probability).
tokenizer = EsmTokenizer.from_pretrained(model_dir)
model     = EsmForSequenceClassificationMHACustom.from_pretrained(model_dir, num_labels=2)


In [6]:
# Bare expression: displays the loaded model's module tree as the cell output.
model

EsmForSequenceClassificationMHACustom(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 640, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 640, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0): EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=640, out_features=640, bias=True)
              (key): Linear(in_features=640, out_features=640, bias=True)
              (value): Linear(in_features=640, out_features=640, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=640, out_features=640, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((640,), eps=1e-

## run model

In [7]:
def run_model(peptides, kinases, model=model, tokenizer=tokenizer, device='cpu', batch_size=50, output_hidden_states=True, output_attentions=True, classifier_mask_offset=16):
    """Run the classifier on peptide/kinase input and collect its outputs.

    The inputs are tokenized together as ``tokenizer(peptides, kinases,
    padding=True, return_tensors='pt')``, moved to ``device`` and passed through
    ``model`` in a single forward call with gradients disabled.

    Parameters
    ----------
    peptides, kinases
        First and second arguments handed to the tokenizer (a single string
        each in this notebook; presumably parallel lists also work — confirm
        against the tokenizer).
    model, tokenizer
        Model and tokenizer to use. NOTE: the defaults are bound to the
        notebook globals at the time this cell is executed; re-run this cell
        after reloading the model, or pass them explicitly.
    device
        Device the model and the tokenized inputs are moved to.
    batch_size
        Accepted for backward compatibility but currently NOT used: all inputs
        go through the model in one forward pass.
    output_hidden_states, output_attentions
        Forwarded to the model; also control whether the ``'embedding'`` and
        ``'attention'`` entries are added to the result.
    classifier_mask_offset
        Number of leading attention-mask positions dropped before the mask is
        applied to the classifier attention weights. Defaults to 16, the value
        previously hard-coded here.
        NOTE(review): 16 presumably equals the 15-residue example peptide plus
        one leading special token — confirm and adjust for other peptide lengths.

    Returns
    -------
    dict
        ``'probability'``: NumPy array with the softmax (over dim 1) value of
        class index 1 for each input row.
        ``'embedding'`` (only if ``output_hidden_states``): list with, per input
        row, the last hidden state restricted to unmasked positions.
        ``'attention'`` (only if ``output_attentions``): list with, per input
        row, the last attention entry restricted to unmasked positions on its
        last two axes.
        ``'classifier_attn_outputs'``: second value returned by the model, on CPU.
        ``'classifier_attn_output_weights'``: list with, per input row, the
        third value returned by the model with its second axis restricted by
        ``attention_mask[classifier_mask_offset:]``.
    """
    torch.cuda.empty_cache()

    model.eval()
    model = model.to(device)

    ids = tokenizer(peptides, kinases, padding=True, return_tensors='pt')
    ids = ids.to(device)

    output = dict()
    with torch.no_grad():
        # The custom model returns a 3-tuple: the standard results mapping plus
        # the classifier attention outputs and their weights.
        results, classifier_attn_outputs, classifier_attn_output_weights = model(
            ids['input_ids'],
            attention_mask=ids['attention_mask'],
            output_hidden_states=output_hidden_states,
            output_attentions=output_attentions,
        )

        # Boolean padding mask on CPU. A NumPy copy is kept for indexing NumPy
        # arrays, so that indexing does not depend on implicit tensor->array
        # coercion; the torch version is used for torch tensors.
        attention_mask = ids['attention_mask'].cpu().type(torch.bool)
        attention_mask_np = attention_mask.numpy()

        output['probability'] = results['logits'].softmax(1)[:, 1].cpu().numpy()

        if output_hidden_states:
            last_embeddings = results['hidden_states'][-1].cpu().numpy()
            # Drop padded positions from each row's embedding.
            output['embedding'] = [
                emb[mask] for emb, mask in zip(last_embeddings, attention_mask_np)
            ]

        if output_attentions:
            last_attentions = results['attentions'][-1].cpu().numpy()
            # Drop padded positions on both of the last two axes.
            output['attention'] = [
                attn[:, mask, :][:, :, mask]
                for attn, mask in zip(last_attentions, attention_mask_np)
            ]

        output['classifier_attn_outputs'] = classifier_attn_outputs.cpu()

        classifier_attn_output_weights = classifier_attn_output_weights.cpu()
        # The weights' second axis is indexed with the mask minus its first
        # `classifier_mask_offset` positions.
        output['classifier_attn_output_weights'] = [
            weights[:, mask[classifier_mask_offset:]]
            for weights, mask in zip(classifier_attn_output_weights, attention_mask)
        ]

    return output

In [8]:
# Score the example pair. Hidden states and attention maps are not requested,
# so only the probability and the classifier-attention entries are returned.
# Note: run_model accepts batch_size but does not use it.
output = run_model(
    peptides,
    kinases,
    model=model,
    tokenizer=tokenizer,
    output_hidden_states=False,
    output_attentions=False,
    batch_size=1,
)

# prediction score
output["probability"]

array([0.98041224], dtype=float32)