# PyTorch: Prop3D with Graphs (ProteinMPNN)

Here we show how to use Prop3D in a PyTorch model to predict the electrostatic potential using ProteinMPNN.

### Install prerequisites if needed
Uncomment the line below to clone ProteinMPNN (its `training` directory is imported in the next cell).

In [None]:
#!git clone https://github.com/dauparas/ProteinMPNN.git

### Define imports

In [None]:

import os
import sys
sys.path.append("ProteinMPNN/training")

import torch
from torch import nn
from Prop3D.ml.datasets.DistributedProteinMPNNDataset import DistributedProteinMPNNDataset

from model_utils import ProteinMPNN, featurize, get_std_opt, loss_nll

# Seed PyTorch's global RNG, then run on the GPU whenever one is visible.
torch.manual_seed(0)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

### Define the MPNN model
Instead of predicting 21 characters, only predict 3 classes: is_neutral, is_electronegative, is_electropositive (class indices 0, 1 and 2, respectively, in `charge_to_idx`).

### Define parameters

In [None]:
# Connection settings read from the environment by the HSDS client; the
# credentials are the literal string "None" (presumably anonymous, read-only
# access to the public Prop3D server -- confirm).
os.environ.update({
    "HS_ENDPOINT": "http://prop3d-hsds.pods.uvarc.io",
    "HS_USERNAME": "None",
    "HS_PASSWORD": "None",
})

# Path of the Prop3D HDF5 file on the server, and the CATH superfamily to load
# (written with "/" separators rather than the usual "." notation).
cath_file = "/CATH/Prop3D-20.h5"
cath_superfamily = "2/60/40/10"

# Per-residue Prop3D feature(s) to predict. Other choices include charge,
# hydrophobicity, accessibility, the secondary-structure types, etc.
predict_features = ["electrostatic_potential"]

In [None]:
def collate(x):
    """Identity collate function: return the list of samples unchanged.

    No stacking or padding happens here; the raw batch is handed on as-is
    (``process_batch`` later indexes each sample and passes the whole batch
    to ``featurize``).
    """
    samples = x
    return samples

In [None]:
# Arguments shared by the training and validation splits.
_dataset_kwargs = dict(predict_features=predict_features, cluster_level="S100")
_loader_kwargs = dict(batch_size=16, num_workers=64, collate_fn=collate)

# Training split: reshuffled every epoch.
dataset_train = DistributedProteinMPNNDataset(
    cath_file, cath_superfamily, **_dataset_kwargs)
training_loader = torch.utils.data.DataLoader(
    dataset_train, shuffle=True, **_loader_kwargs)

# Validation split: fixed order.
dataset_val = DistributedProteinMPNNDataset(
    cath_file, cath_superfamily, validation=True, **_dataset_kwargs)
val_loader = torch.utils.data.DataLoader(
    dataset_val, shuffle=False, **_loader_kwargs)

In [None]:
# (is_zero, is_negative, is_positive) -> class index:
# 0 = neutral, 1 = negative potential, 2 = positive potential.
charge_to_idx = {(1, 0, 0): 0, (0, 1, 0): 1, (0, 0, 1): 2}


def _charge_class(value):
    """Return the ``charge_to_idx`` class for one per-residue potential value.

    Values that match none of the keys (e.g. NaN, which compares False to
    everything) fall back to class 0, exactly as before.
    """
    # float() unwraps 0-d / 1-element tensors and numpy scalars. A key tuple
    # holding tensors never matches a dict key (tensors hash by identity), so
    # the old lookup silently labelled every such residue as class 0.
    v = float(value)
    return charge_to_idx.get((v == 0, v < 0, v > 0), 0)


def process_batch(batch):
    """Featurize a batch, replacing the sequence with 3-class charge labels.

    Args:
        batch: list of samples as produced by the DataLoader; each sample is
            indexed with ``"prop3d_features"`` (one value per residue) and the
            whole list is passed to ProteinMPNN's ``featurize``.

    Returns:
        The same 8-tuple as ``featurize``, except that for every protein the
        first ``len(prot["prop3d_features"])`` entries of ``S`` hold charge
        classes (0 = neutral, 1 = negative, 2 = positive). Any later positions
        keep what ``featurize`` wrote (assumed to be masked padding -- confirm).
    """
    X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, device)
    for i, prot in enumerate(batch):
        labels = [_charge_class(value) for value in prot["prop3d_features"]]
        # One vectorized write per protein instead of one (possibly GPU) write per residue.
        S[i, :len(labels)] = torch.as_tensor(labels, dtype=S.dtype, device=S.device)
    return X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all

In [None]:
def loss_smoothed(S, log_probs, mask, weight=0.1, normalizer=2000.0):
    """Label-smoothed negative log-likelihood.

    Args:
        S: integer class targets, shaped like ``log_probs`` without its last
            dimension.
        log_probs: log-probabilities; the last dimension is the number of
            classes (previously hard-coded to 3).
        mask: per-position weights broadcastable to ``S``; positions with
            weight 0 do not contribute to ``loss_av``.
        weight: label-smoothing mass, spread uniformly over all classes before
            the target distribution is renormalised.
        normalizer: constant divisor for the masked loss sum. The default 2000
            keeps the original "fixed" value (presumably mirroring ProteinMPNN's
            training code); being a constant, the loss scale does not depend on
            how many residues the batch contains.

    Returns:
        ``(loss, loss_av)``: the per-position smoothed NLL, and the
        mask-weighted sum of it divided by ``normalizer``.
    """
    num_classes = log_probs.size(-1)
    S_onehot = torch.nn.functional.one_hot(S, num_classes).float()

    # Label smoothing: add weight/K to every class, then renormalise to sum to 1.
    S_onehot = S_onehot + weight / float(num_classes)
    S_onehot = S_onehot / S_onehot.sum(-1, keepdim=True)

    loss = -(S_onehot * log_probs).sum(-1)
    loss_av = torch.sum(loss * mask) / normalizer
    return loss, loss_av

In [None]:
# Output vocabulary of 3 charge classes (neutral / negative / positive)
# instead of ProteinMPNN's 21 sequence letters; Module.to returns the module.
model = ProteinMPNN(num_letters=3, vocab=3).to(device)

# Optimizer from ProteinMPNN's training utilities; 128 is passed as the model
# width (presumably matching ProteinMPNN's default hidden size -- confirm) and
# 0 as the starting step.
optimizer = get_std_opt(model.parameters(), 128, 0)

In [None]:
from tqdm import tqdm  # progress bars below; never imported by the import cell (NameError)

for epoch in range(200):
    for loader, is_train in [(training_loader, True), (val_loader, False)]:
        # train(False) is the same as eval(): toggles dropout and other mode-dependent layers.
        model.train(is_train)
        name = "TRAIN" if is_train else "VALIDATION"

        pbar = tqdm(loader)
        for data in pbar:
            X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = process_batch(data)
            # Positions that count towards the loss: residue mask combined with chain_M.
            mask_for_loss = mask * chain_M

            # Record the autograd graph only while training; validation needs no gradients.
            with torch.set_grad_enabled(is_train):
                if is_train:
                    # Zero your gradients for every batch!
                    optimizer.zero_grad()

                # Make predictions for this batch
                log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)

                if is_train:
                    # The smoothed loss is only the training objective; validation skips it.
                    _, loss_av_smoothed = loss_smoothed(S, log_probs, mask_for_loss)
                    loss_av_smoothed.backward()
                    # Adjust learning weights
                    optimizer.step()

            # Report the unsmoothed NLL. loss_av is loss_nll's masked average (in
            # ProteinMPNN it is sum(loss * mask) / sum(mask)), whereas loss.mean()
            # also averaged over padded positions.
            with torch.no_grad():
                _, loss_av, _ = loss_nll(S, log_probs, mask_for_loss)

            pbar.set_description(f"Epoch {epoch} {name} Loss {loss_av.item():.4f}")