# PyTorch: Prop3D with Language Models (ESM)

### Install prereqs: PyTorch and Hugging Face Transformers

Uncomment the install cells below if you need to install these packages. For PyTorch GPU installation, follow the instructions at https://pytorch.org/get-started/locally/

In [None]:
import sys

In [None]:
#!{sys.executable} -m pip install --user torch

In [None]:
#!{sys.executable} -m pip install --user tokenizers transformers

### Imports

In [None]:
import os
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, EsmForTokenClassification, DataCollatorForTokenClassification
from Prop3D.ml.datasets.DistributedDomainSequenceDataset import DistributedDomainSequenceDataset

# Seed PyTorch's random number generator with a fixed value.
torch.manual_seed(0)

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Define parameters

In [None]:
# HS_* environment variables: server endpoint plus username/password, the
# latter two set to the literal string "None".
os.environ.update({
    "HS_ENDPOINT": "http://prop3d-hsds.pods.uvarc.io",
    "HS_USERNAME": "None",
    "HS_PASSWORD": "None",
})

# Dataset file and the CATH superfamily to use.
# The superfamily is written with "/" separators instead of ".".
cath_file = "/CATH/Prop3D-20.h5"
cath_superfamily = "1/10/10/10"

# Features to predict. Could be charge, hydrophobicity, accessibility,
# 3 types of secondary structure, etc.
predict_features = ["is_sheet", "is_helix", "Unk_SS"]

### Set up ESM

In [None]:
# The tokenizer and the model are loaded from the same pretrained checkpoint.
esm_checkpoint = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(esm_checkpoint)

# The classification head gets one label per entry in predict_features.
model = EsmForTokenClassification.from_pretrained(
    esm_checkpoint,
    num_labels=len(predict_features),
)

In [None]:
# Collator that assembles tokenized examples (including their labels) into a batch.
data_collator = DataCollatorForTokenClassification(tokenizer)

def collate(x):
    """Turn a list of (sequence, label matrix) pairs into a model-ready batch.

    Args:
        x: Items produced by the dataset. Each item is unpacked as ``(s, l)``:
            ``s`` is passed to the tokenizer and ``l`` is reduced to one class
            index per row with ``argmax`` over axis 1.

    Returns:
        The batch built by ``data_collator``, with ``input_ids``,
        ``attention_mask`` and ``labels`` placed on the model's device.
    """
    examples = []
    for s, l in x:
        encoded = tokenizer(s)
        # NOTE(review): if the tokenizer adds special tokens around the
        # sequence, these labels have fewer entries than input_ids and may be
        # shifted relative to the residues -- confirm against the dataset.
        encoded["labels"] = np.argmax(l, axis=1)
        examples.append(encoded)

    batch = data_collator(examples)

    # Tensor.to() returns a new tensor instead of moving the tensor in place,
    # so the result has to be stored back into the batch. The target is the
    # model's current device so inputs and weights always live together.
    target = model.device
    for key in ("input_ids", "attention_mask", "labels"):
        batch[key] = batch[key].to(target)

    return batch

### Set up Prop3D datasets and dataloaders

In [None]:
# Keyword arguments shared by the training and validation datasets.
shared_dataset_args = dict(
    predict_features=predict_features,
    cluster_level="S100",
)

# Training split: batches of 128, shuffled.
dataset_train = DistributedDomainSequenceDataset(
    cath_file, cath_superfamily, **shared_dataset_args)
training_loader = torch.utils.data.DataLoader(
    dataset_train, batch_size=128, shuffle=True, collate_fn=collate)

# Validation split: same settings, requested with validation=True, not shuffled.
dataset_val = DistributedDomainSequenceDataset(
    cath_file, cath_superfamily, validation=True, **shared_dataset_args)
val_loader = torch.utils.data.DataLoader(
    dataset_val, batch_size=128, shuffle=False, collate_fn=collate)

### Start training

In [None]:
# SGD with momentum (from torch.optim) over all of the model's parameters.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [None]:
# Put the weights on the selected device before any batch is processed.
# Module.to() moves the parameters in place, so the optimizer created earlier
# keeps referring to the same parameter objects.
model.to(device)

for epoch in range(30):
    # Each epoch is one pass over the training data followed by one pass over
    # the validation data.
    for loader, is_train in [(training_loader, True), (val_loader, False)]:
        name = "TRAIN" if is_train else "VALIDATION"

        # from_pretrained() hands back the model in eval mode, so the mode has
        # to be switched explicitly: train mode for the training pass, eval
        # mode for validation.
        model.train(is_train)

        running_loss = 0.0
        pbar = tqdm(loader)

        # Only the training pass needs gradients; validation skips building
        # the autograd graph.
        with torch.set_grad_enabled(is_train):
            for n_batches, batch in enumerate(pbar, start=1):
                # The same forward call serves both passes; passing labels
                # makes the model return the loss along with its predictions.
                outputs = model(
                    input_ids=batch["input_ids"].to(device),
                    attention_mask=batch["attention_mask"].to(device),
                    labels=batch["labels"].to(device))
                loss = outputs.loss

                if is_train:
                    # Zero the gradients for every batch, backpropagate, then
                    # adjust the learning weights.
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Report the mean loss over the batches seen so far in this
                # pass instead of the raw tensor repr of the last batch.
                running_loss += loss.item()
                pbar.set_description(
                    f"Epoch {epoch} {name} Loss {running_loss / n_batches:.4f}")
