In [1]:
import torch

# Load model

In [2]:
import os

# Path to the fine-tuned checkpoint. Change it to wherever you downloaded the model, or
# set the CHECKPOINT_PATH environment variable to override it without editing the notebook.
# Consider renaming the .bin file so checkpoints from different runs are not confused.
# Note: all pytorch_model.bin files are currently ~600MB, but we may later choose to
# publish only the LoKr and LoHA adapters, which would probably be ~10-20MB.
CHECKPOINT_PATH = os.environ.get("CHECKPOINT_PATH", "/Users/aryopg/Downloads/pytorch_model.bin")

# Prefer Apple's MPS backend when it is available; otherwise fall back to CPU so the
# notebook also runs on machines without MPS instead of crashing in torch.load.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# torch.load unpickles the file (can execute arbitrary code): only load trusted checkpoints.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)

In [3]:
import argparse
import os
import sys

# Make the repository root (two levels up from the current working directory, i.e. the
# notebook's folder when launched from Jupyter) importable so `dl_ba` can be found.
# Resolve it to an absolute path and add it only once, so re-running this cell does not
# keep appending duplicate entries to sys.path.
REPO_ROOT = os.path.abspath(os.path.join("..", ".."))
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

from dl_ba import common_utils
from dl_ba.configs import Configs
from dl_ba.model import BindingAffinityModel

# Experiment config matching the downloaded checkpoint; change both together.
config_filepath = os.path.join(
    REPO_ROOT,
    "configs",
    "random_seed_experiments",
    "bindingdb_random",
    "esm_lokr_chemberta_loha_cosinemse_1.yaml",
)
configs = Configs(**common_utils.load_yaml(config_filepath))

# Build the model described by the config; the fine-tuned weights are loaded into it
# from `checkpoint` in a later cell.
model = BindingAffinityModel(configs.model_configs)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 128,160 || all params: 148,923,641 || trainable%: 0.08605752527901195
trainable params: 221,184 || all params: 3,648,624 || trainable%: 6.062120952994882


In [4]:
# Load the fine-tuned weights with strict key matching (PyTorch's default), so any
# missing or unexpected parameter name raises instead of being silently skipped.
model.load_state_dict(checkpoint, strict=True)

# train(False) is exactly what eval() does: switch dropout and other mode-dependent
# layers to inference behaviour. It returns the same module object.
model = model.train(False)

# Analyses!

Run all analysis code inside `with torch.no_grad():` so no autograd graph is built (this saves memory and time).

To visualise the attention of the protein model (ESM, a BERT-style encoder), first merge the PEFT adapter into the base model.

In [5]:
# Fold the PEFT adapter into the protein backbone. merge_and_unload() returns the base
# model with the adapter weights merged in (a plain EsmModel, as the next cell displays),
# which attention-visualisation tools can consume directly.
# NOTE(review): PEFT presumably merges in place, so `model.protein_model` likely shares the
# merged weights afterwards — confirm before reusing `model` as an adapter model.
protein_model = getattr(model, "protein_model")
merged_protein_model = protein_model.merge_and_unload()

In [6]:
# Get the sequence to encode: TYK2 (JH2 pseudokinase domain) from the DAVIS DTI dataset.
from tdc.multi_pred import DTI

TARGET_NAME = "TYK2(JH2domain-pseudokinase)"

data = DTI(name="DAVIS")
split = data.get_split()

# Rows whose target ID is TYK2's; only the first match is used to fetch the sequence.
target_sequence_ids = data.entity2_idx.loc[data.entity2_idx == TARGET_NAME]
if target_sequence_ids.empty:
    raise ValueError(f"Target {TARGET_NAME!r} not found in the DAVIS dataset")
target_sequence_id = target_sequence_ids.index[0]

# target_sequence_id is an index *label*, so look it up with .loc; positional .iloc is
# only correct when the index happens to be a default 0..n-1 RangeIndex.
target_sequence = data.entity2.loc[target_sequence_id]
print(target_sequence)

Found local copy...
Loading...
Done!


MPLRHWGMARGSKPVGDGAQPMAAMGGLKVLLHWAGPGGGEPWVTFSESSLTAEEVCIHIAHKVGITPPCFNLFALFDAQAQVWLPPNHILEIPRDASLMLYFRIRFYFRNWHGMNPREPAVYRCGPPGTEASSDQTAQGMQLLDPASFEYLFEQGKHEFVNDVASLWELSTEEEIHHFKNESLGMAFLHLCHLALRHGIPLEEVAKKTSFKDCIPRSFRRHIRQHSALTRLRLRNVFRRFLRDFQPGRLSQQMVMVKYLATLERLAPRFGTERVPVCHLRLLAQAEGEPCYIRDSGVAPTDPGPESAAGPPTHEVLVTGTGGIQWWPVEEEVNKEEGSSGSSGRNPQASLFGKKAKAHKAVGQPADRPREPLWAYFCDFRDITHVVLKEHCVSIHRQDNKCLELSLPSRAAALSFVSLVDGYFRLTADSSHYLCHEVAPPRLVMSIRDGIHGPLLEPFVQAKLRPEDGLYLIHWSTSHPYRLILTVAQRSQAPDGMQSLRLRKFPIEQQDGAFVLEGWGRSFPSVRELGAALQGCLLRAGDDCFSLRRCCLPQPGETSNLIIMRGARASPRTLNLSQLSFHRVDQKEITQLSHLGQGTRTNVYEGRLRVEGSGDPEEGKMDDEDPLVPGRDRGQELRVVLKVLDPSHHDIALAFYETASLMSQVSHTHLAFVHGVCVRGPENIMVTEYVEHGPLDVWLRRERGHVPMAWKMVVAQQLASALSYLENKNLVHGNVCGRNILLARLGLAEGTSPFIKLSDPGVGLGALSREERVERIPWLAPECLPGGANSLSTAMDKWGFGATLLEICFDGEAPLQSRSPSEKEHFYQRQHRLPEPSCPQLATLTSQCLTYEPTQRPSFRTILRDLTRLQPHNLADVLTVNPDSPASDPTVFHKRYLKKIRDLGEGHFGKVSLYCYDPTNDGTGEMVAVKALKADCGPQHRSGWKQEIDILRTLYHEHIIKYKGCCEDQGEKSLQLVMEYVPLGSLRDYLPRHSIGLAQL

In [7]:
from transformers import AutoTokenizer

# Use the tokenizer that matches the protein backbone named in the experiment config.
protein_model_name = configs.model_configs.protein_model_name_or_path
protein_tokenizer = AutoTokenizer.from_pretrained(protein_model_name)

In [8]:
# Display the merged module tree; the attention projections should show up as plain
# Linear layers, confirming the adapter wrappers were removed by merge_and_unload().
merged_protein_model

EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 640, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
    (position_embeddings): Embedding(1026, 640, padding_idx=1)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0-29): 30 x EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=640, out_features=640, bias=True)
            (key): Linear(in_features=640, out_features=640, bias=True)
            (value): Linear(in_features=640, out_features=640, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=640, out_features=640, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((640,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
  

In [None]:
from bertviz import head_view, model_view

# Tokenise the target sequence with the ESM tokenizer (one unpadded sequence).
inputs = protein_tokenizer(target_sequence, return_tensors="pt")
input_ids = inputs["input_ids"]

with torch.no_grad():
    # One attention tensor per encoder layer, each (batch, heads, seq_len, seq_len) per the
    # Transformers docs, e.g. torch.Size([1, 20, 1189, 1189]) for TYK2.
    attentions = merged_protein_model(
        input_ids,
        attention_mask=inputs.get("attention_mask"),
        output_attentions=True,
    )["attentions"]

print(f"Layers: {len(attentions)}")
print(f"Size: {attentions[0].size()}")

tokens = protein_tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

# model_view(attentions, tokens) renders every layer and is too heavy for a laptop at this
# sequence length, so only show the first, middle and last layers (0, 15, 29 for 30 layers).
num_layers = len(attentions)
layers_to_show = sorted({0, num_layers // 2, num_layers - 1})
head_view(attentions, tokens, include_layers=layers_to_show)

Layers: 30
Size: torch.Size([1, 20, 1189, 1189])
