In [None]:
import torch
from transformers import EsmTokenizer, EsmModel

# Tokenizer and encoder are loaded from the same pretrained checkpoint.
checkpoint = "facebook/esm2_t6_8M_UR50D"
tokenizer = EsmTokenizer.from_pretrained(checkpoint)
model = EsmModel.from_pretrained(checkpoint)

# Use the first CUDA device when one is visible, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Switch the module to evaluation mode, then move its weights to the chosen device.
model.eval()
model.to(device)
print(f"Putting Model in eval mode on {device}")

# Quick look at the tokenizer configuration.
print(f"Tokenizer tokenizer.vocab_size {tokenizer.vocab_size}")
print(f"Tokenizer tokenizer.max_model_input_sizes {tokenizer.max_model_input_sizes}")

In [None]:
from Bio import SeqIO
from dataclasses import dataclass, field

@dataclass
class Protein:
    """A structure file on disk together with the residue sequence read from it.

    ``name``, ``filename``, ``link`` and ``format`` are supplied by the caller;
    ``seq`` and ``len`` are filled in by ``__post_init__``.
    """

    name: str      # label used when printing results
    filename: str  # path handed to SeqIO.parse
    link: str      # reference URL for the entry
    format: str    # format string handed to SeqIO.parse
    seq: str = field(init=False)  # concatenation of every parsed record's sequence
    len: int = field(init=False)  # number of characters in ``seq``

    def __post_init__(self):
        """Parse ``filename`` and populate ``seq`` and ``len``."""
        records = SeqIO.parse(self.filename, self.format)
        # Every record yielded by the parser is joined into one string, in order.
        record_sequences = (str(record.seq) for record in records)
        self.seq = "".join(record_sequences)
        self.len = len(self.seq)

# Structures to compare. Each entry is parsed from a local file when it is constructed.
proteins = [
    Protein(
        name="1JHF",
        filename="1jhf.pdb",
        link="https://www.rcsb.org/structure/1JHF",
        format="pdb-atom",
    ),
    Protein(
        name="3LS4",
        filename="3ls4.pdb",
        link="https://www.rcsb.org/structure/3LS4",
        format="pdb-atom",
    ),
    Protein(
        name="1K6F",
        filename="1k6f.cif",
        link="https://www.rcsb.org/structure/1K6F",
        format="cif-atom",
    ),
    Protein(
        name="5XR8",
        filename="5xr8.pdb",
        link="https://www.rcsb.org/structure/5XR8",
        format="pdb-atom",
    ),
]

In [None]:
# from collections import Counter

# input_protein = proteins[0].seq
# print(f"Input protein is {input_protein}")
# print(f"Input protein length {len(input_protein)}")
# print(f"Input protein count {Counter(input_protein)}")

In [None]:
# inputs = tokenizer(input_protein, return_tensors="pt")
# inputs['input_ids'] = inputs['input_ids'].to(device)
# inputs['attention_mask'] = inputs['attention_mask'].to(device)

# print(f"Inputs {inputs}")
# print(f"Inputs input_ids shape {inputs['input_ids'].shape}")
# print(f"Inputs input_ids device {inputs['input_ids'].device}")
# print(f"Inputs input_ids type {type(inputs['input_ids'])}")
# print(f"Inputs attention_mask shape {inputs['attention_mask'].shape}")
# print(f"Inputs attention_mask device {inputs['attention_mask'].device}")
# print(f"Inputs attention_mask type {type(inputs['attention_mask'])}")

In [None]:
# Tokenize all sequences as one batch; padding=True pads them to a common length
# so they can be returned as a single tensor.
sequences = [protein.seq for protein in proteins]
inputs = tokenizer(sequences, return_tensors="pt", padding=True)

# The tensors must live on the same device as the model.
for key in ('input_ids', 'attention_mask'):
    inputs[key] = inputs[key].to(device)

# Summary of the token id tensor.
print(f"Inputs input_ids shape {inputs['input_ids'].shape}")
print(f"Inputs input_ids device {inputs['input_ids'].device}")
print(f"Inputs input_ids type {type(inputs['input_ids'])}")

# Summary of the attention mask tensor.
print(f"Inputs attention_mask shape {inputs['attention_mask'].shape}")
print(f"Inputs attention_mask device {inputs['attention_mask'].device}")
print(f"Inputs attention_mask type {type(inputs['attention_mask'])}")

In [None]:
# Inference only: without no_grad() autograd keeps every intermediate activation
# alive and the resulting embeddings carry a grad_fn.
with torch.no_grad():
    outputs = model(**inputs)

# One vector per protein: attention-masked mean of the final hidden states.
# NOTE(review): this replaces outputs.pooler_output. The ESM-2 checkpoints appear to
# be published without pooler weights, in which case EsmModel initialises that layer
# randomly on every load and the similarities below would change between runs.
# Confirm against the "newly initialized" warning printed by from_pretrained.
# NOTE(review): the mask keeps whatever special tokens the tokenizer adds; only
# padding positions are excluded from the average.
hidden_states = outputs.last_hidden_state
token_mask = inputs['attention_mask'].unsqueeze(-1).to(hidden_states.dtype)
token_counts = token_mask.sum(dim=1).clamp(min=1)  # guard against an all-padding row
embedding = (hidden_states * token_mask).sum(dim=1) / token_counts

# print(f"Embedding {embedding}")
print(f"Embedding type {type(embedding)}")
print(f"Embedding shape {embedding.shape}")


In [None]:
# Compare every pair of protein embeddings (including each protein with itself).

# Build the two metric modules once instead of on every inner iteration.
# dim=0 because each operand is a single 1-D embedding vector.
cosine_similarity = torch.nn.CosineSimilarity(dim=0)
pairwise_distance = torch.nn.PairwiseDistance()  # default settings

for iA, proteinA in enumerate(proteins):
    print("\n")
    for iB, proteinB in enumerate(proteins):
        # Named results rather than `_`, which IPython uses for the last cell output.
        similarity = cosine_similarity(embedding[iA, :], embedding[iB, :])
        print(f"CosineSimilarity between {proteinA.name} and {proteinB.name} is {similarity}")

        distance = pairwise_distance(embedding[iA, :], embedding[iB, :])
        print(f"PairwiseDistance between {proteinA.name} and {proteinB.name} is {distance}")