# Examples of how to use Prop3D as training data for a language model

### Install prerequisites: PyTorch and Hugging Face Transformers

Uncomment the commands in the cells below if you need to install these packages. For a PyTorch GPU installation, follow the instructions at https://pytorch.org/get-started/locally/

In [None]:
import sys

In [None]:
#!{sys.executable} -m pip install --user torch

In [None]:
#!{sys.executable} -m pip install --user tokenizers transformers

### Imports

In [None]:
from transformers import AutoTokenizer, EsmForTokenClassification
import torch
from Prop3D.ml.datasets.DistributedDomainSequenceDataset import DistributedDomainSequenceDataset
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Define parameters

In [None]:
# Location of the Prop3D HDF5 file to read from.
cath_file = "/projects/Prop3D/Prop3D-20.h5"

# CATH superfamily to train on, written with "/" separators instead of ".".
cath_superfamily = "1/10/10/10"

# Per-residue features the model should predict. Could be charge,
# hydrophobicity, accessibility, 3 types of secondary structure, etc.
predict_features = [
    "is_sheet",
    "is_helix",
    "Unk_SS",
]

### Set up ESM

In [None]:
# Load the tokenizer and the token-classification model from the same ESM-2
# checkpoint on the Hugging Face Hub.
esm_checkpoint = "facebook/esm2_t6_8M_UR50D"

tokenizer = AutoTokenizer.from_pretrained(esm_checkpoint)

# The classification head gets one output class per entry in predict_features.
model = EsmForTokenClassification.from_pretrained(
    esm_checkpoint,
    num_labels=len(predict_features),
)

# The training loop sends each batch to `device`, so the model has to live on
# that device too; otherwise the forward pass fails when CUDA is available.
model = model.to(device)

### Set up Prop3D datasets and dataloaders

In [None]:

# Keyword arguments shared by the training and validation datasets.
_shared_dataset_args = {
    "predict_features": predict_features,
    "split_level": "S100",
}

# Training data: reshuffled every epoch.
dataset_train = DistributedDomainSequenceDataset(
    cath_file,
    cath_superfamily,
    **_shared_dataset_args,
)
training_loader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=128,
    shuffle=True,
)

# Validation data: same settings plus validation=True, iterated in a fixed order.
dataset_val = DistributedDomainSequenceDataset(
    cath_file,
    cath_superfamily,
    validation=True,
    **_shared_dataset_args,
)
val_loader = torch.utils.data.DataLoader(
    dataset_val,
    batch_size=128,
    shuffle=False,
)

### Start training

In [None]:
# SGD with momentum over all of the model's parameters
# (optimizers are provided by the torch.optim package).
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=1e-3,
    momentum=0.9,
)

In [None]:
# Train for 30 epochs. Each epoch makes one pass over the training loader
# (with weight updates) followed by one pass over the validation loader
# (forward passes only), printing the mean loss per reporting window.
report_every = 1000  # number of batches averaged into each printed loss

# Batches are sent to `device` below, so the model must be there as well.
# Module.to() moves the parameters in place, so an optimizer that was created
# earlier keeps referring to the same parameters.
model.to(device)

for epoch in range(30):
    print('EPOCH {}:'.format(epoch + 1))

    for loader, is_train in [(training_loader, True), (val_loader, False)]:
        name = "TRAIN" if is_train else "VALIDATION"

        # from_pretrained() hands back the model in eval mode, so the mode has
        # to be set explicitly for each phase.
        model.train(is_train)

        running_loss = 0.0
        batches_in_window = 0
        for i, data in enumerate(loader):
            # Every data instance is an input + label pair
            inputs, labels = data

            # Tokenize first, then move the resulting tensors to the device.
            # NOTE(review): assumes the dataset yields raw sequence strings,
            # which is what the tokenizer call implies -- confirm.
            inputs = tokenizer(inputs, return_tensors="pt", padding=True).to(device)

            # NOTE(review): the tokenizer may add special tokens and padding,
            # so `labels` needs one class index per token (conventionally -100
            # for positions the loss should ignore) -- confirm that the
            # dataset output lines up with the tokenized batch.
            labels = labels.to(device)

            # Build the autograd graph only while training; validation just
            # needs the loss value.
            with torch.set_grad_enabled(is_train):
                # The tokenizer returns a mapping (input_ids, attention_mask,
                # ...), so unpack it into keyword arguments.
                outputs = model(**inputs, labels=labels)
                loss = outputs.loss

            if is_train:
                # Zero your gradients for every batch!
                optimizer.zero_grad()

                # Compute the gradients and adjust the learning weights
                loss.backward()
                optimizer.step()

            # .item() detaches the value, so the running total does not keep
            # every batch's autograd graph alive.
            running_loss += loss.item()
            batches_in_window += 1

            if batches_in_window == report_every:
                last_loss = running_loss / batches_in_window  # loss per batch
                print('  {} batch {} loss: {}'.format(name, i + 1, last_loss))
                running_loss = 0.0
                batches_in_window = 0

        # Report whatever is left over when the loader length is not a
        # multiple of report_every (including loaders shorter than one window).
        if batches_in_window:
            last_loss = running_loss / batches_in_window  # loss per batch
            print('  {} batch {} loss: {}'.format(name, i + 1, last_loss))

