<a href="https://colab.research.google.com/github/kkim14172/pMHC_specificity_prediction/blob/main/siamese_SBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install sentence-transformers

In [2]:
%%capture
!pip install datasets

In [166]:
from torch.utils.data import DataLoader
import torch
import math
from zipfile import ZipFile

from sentence_transformers import SentenceTransformer, SentencesDataset, losses, models, util
from sentence_transformers.evaluation import TripletEvaluator, EmbeddingSimilarityEvaluator, BinaryClassificationEvaluator
from sentence_transformers.readers import STSBenchmarkDataReader, InputExample
from sentence_transformers.cross_encoder import CrossEncoder
from sentence_transformers.cross_encoder.evaluation import CECorrelationEvaluator

import logging
import random
from datetime import datetime
import sys
import os
import gzip
import csv
import pandas as pd
import numpy as np
from tqdm.notebook import tqdm
import seaborn as sns
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

## Prepare triplet training data

In [11]:
aa2res = {
    'A':'ALA', 'R':'ARG', 'N':'ASN', 'D':'ASP', 'C':'CYS',
    'E':'GLU', 'Q':'GLN', 'G':'GLY', 'H':'HIS', 'I':'ILE',
    'L':'LEU', 'K':'LYS', 'M':'MET', 'F':'PHE', 'P':'PRO',
    'S':'SER', 'T':'THR', 'W':'TRP', 'Y':'TYR', 'V':'VAL'
}
def marked_text(peptide:str):
    text = " ".join(aa2res.get(aa) for aa in peptide)
    #update = "[CLS] " + text + " [SEP]"
    return(text)

## Bi-encoder (sentence-transformers)

In [82]:
model_name = 'bert-base-uncased' #You can specify any pre-trained model here
word_embedding_model = models.Transformer(model_name)

# Apply mean pooling to get one fixed sized sentence vector
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(),
                               pooling_mode_mean_tokens=True,
                               pooling_mode_cls_token=False,
                               pooling_mode_max_tokens=False)

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

In [181]:
###### Configuration ######
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Run on: {}".format(device))
batch_size = 5
num_epochs = 5
output_path = (
    "output/finetune-batch-all-keskin-"
    + model_name
    + "-"
    + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)

Run on: cuda


In [141]:
data = pd.read_csv("siamese_training_data.csv")
anchors = data['anchor'].drop_duplicates()

In [142]:
random.seed(2307200650)
eval_anchors = np.random.choice(anchors, round(len(anchors)*0.1), replace=False)

In [146]:
train_data = data.drop(data[data['anchor'].isin(eval_anchors)].index)
eval_data = data[data['anchor'].isin(eval_anchors)]

In [147]:
print(train_data.shape)
print(eval_data.shape)

(546482, 8)
(69150, 8)


In [151]:
def intoDataset(data):
  anchors = data["anchor"].drop_duplicates()
  poss = data["true"].drop_duplicates()
  negs = data["false"].drop_duplicates()

  examples = []

  for anchor in anchors:
    examples.append(InputExample(texts=[marked_text(anchor)], label=0))
    #examples.append(InputExample(texts=[marked_text()], label=0))

  for neg in negs:
    examples.append(InputExample(texts=[marked_text(neg)], label=1))

  return(examples)


In [175]:
def intoTriplets(data):

  triplets = []

  for ind, row in data.iterrows():
    triplets.append(InputExample(texts=[marked_text(row['anchor']),
                                        marked_text(row['true']),
                                        marked_text(row['false'])]))
  return triplets

In [179]:
####  Train the SBERT model ####
# Define the dataloader
train_set = intoDataset(train_data)
train_dataset = SentencesDataset(train_set, model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
warmup_steps = int(len(train_dataset) * num_epochs / batch_size * 0.1) # 10% of train data for warm-up
print("Warmup-steps: {}".format(warmup_steps))

Warmup-steps: 229


In [186]:
eval_triplets = intoTriplets(eval_data)
eval_evaluator = TripletEvaluator.from_input_examples(eval_triplets, name='keskin-eval')
#eval_evaluator(model)

In [187]:
# Define loss function

train_loss = losses.BatchAllTripletLoss(model=model)
#train_loss = losses.BatchHardTripletLoss(model=model)
#train_loss = losses.BatchHardSoftMarginTripletLoss(model=model)
#train_loss = losses.BatchSemiHardTripletLoss(model=model)

In [188]:
# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=eval_evaluator,
    epochs=num_epochs,
    evaluation_steps=1000,
    warmup_steps=warmup_steps,
    output_path=output_path,
)

Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

Iteration:   0%|          | 0/460 [00:00<?, ?it/s]

Iteration:   0%|          | 0/460 [00:00<?, ?it/s]

Iteration:   0%|          | 0/460 [00:00<?, ?it/s]

Iteration:   0%|          | 0/460 [00:00<?, ?it/s]

Iteration:   0%|          | 0/460 [00:00<?, ?it/s]

In [195]:
from google.colab import files
#!zip -r /content/output.zip /content/output

files.download("/content/output.zip")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [188]:
# Load Metrics
metric = load_metric(metrics_name) # e.g. "f1"

# Create Metrics
def compute_metrics(eval_pred):

  predictions, labels = eval_pred
  predictions = np.argmax(predictions, axis=1)

  # 'micro', 'macro', etc. are for multi-label classification. If you are running a binary classification, leave it as default or specify "binary" for average
  return metric.compute(predictions=predictions, references=labels, average="micro")

In [197]:
test_set = pd.read_csv("iedb_A0101_withMHCinfo.csv")

In [200]:
sentences = [marked_text(pep) for pep in test_set['peptide']]

In [209]:
import pickle
#Sentences are encoded by calling model.encode()
def encode_sentences(model_name, sentences, output_path):
  model = SentenceTransformer(model_name)
  embeddings = model.encode(sentences, device='cuda')
  #Store sentences & embeddings on disc
  with open(f'{output_path}/embeddings_{os.path.basename(model_name)}.pkl', "wb") as fOut:
    pickle.dump({'sentences': sentences, 'embeddings': embeddings}, fOut, protocol=pickle.HIGHEST_PROTOCOL)

In [211]:
model_name='bert-base-uncased'
#encode_sentences(model_name, sentences, output_path)
files.download(f'{output_path}/embeddings_{os.path.basename(model_name)}.pkl')

model_name='/content/output/finetune-batch-all-keskin-bert-base-uncased-2023-07-20_12-33-20'
#encode_sentences(model_name, sentences, output_path)
files.download(f'{output_path}/embeddings_{os.path.basename(model_name)}.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>