# Learning the difference between two protein conformational states using a feed-forward neural network

In a recent research project (published [here](https://www.biorxiv.org/content/biorxiv/early/2024/02/26/2024.02.22.581541.full.pdf)), I discovered two distinct conformational states of the protein-protein complex of trypsin and fluorinated BPTI variants. In that project, I identified the differences between the two states by applying chemical intuition and thorough MD analysis. In this notebook, I retrospectively explore whether a feed-forward neural network can learn what makes the two states different without prior knowledge of their structural differences, using the backbone and side-chain dihedrals of all amino acids in the whole protein complex as features. In a second step, the predictive power of the trained network is tested with the features of individual amino acids removed, to get an intuition for how important each amino acid is for the difference between the states. The idea for this procedure comes from [this paper](https://pubs.acs.org/doi/10.1021/acs.jctc.1c00924).

The following steps are needed:

1. Run simulations in both states to gather data. Each simulation is started in one of the two states, and it is assumed that the system remains in that state throughout the whole simulation. The data consists of simulation snapshots, labeled 0 or 1 according to the state.
2. Convert the simulation snapshots into feature vectors containing all amino acid dihedrals of the protein complex.
3. Split the data into training and test sets, then build and train the model.
4. Measure the test accuracy with the features of each amino acid removed, one amino acid at a time, to see whether any removal causes a significant drop in accuracy, which would mark that amino acid as especially important for the difference between the two states.

## Data Preparation

The simulations yielded roughly 17000 snapshots per state. The dihedrals of all amino acids were calculated using the mdtraj package and amounted to 1003 features. Because dihedrals are periodic, their raw values jump at the period boundary; to avoid this discontinuity, each dihedral was transformed into its sine and its cosine, yielding 2006 features.
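
As a small illustration of why this encoding helps, here is a sketch with two made-up angles (not values from the simulations), assuming the dihedrals are given in radians. Two dihedrals just on either side of the ±π boundary describe almost the same geometry but are numerically far apart, while their (sin, cos) pairs are close together.

```python
import math

# Two hypothetical dihedral angles (radians) on either side of the +/-pi boundary.
a = math.pi - 0.01
b = -math.pi + 0.01

# Raw difference is almost a full turn, although the geometries are nearly identical.
raw_gap = abs(a - b)

# Euclidean distance between the (sin, cos) encodings of the two angles.
enc_gap = math.dist((math.sin(a), math.cos(a)), (math.sin(b), math.cos(b)))

print(f"raw gap: {raw_gap:.3f}, encoded gap: {enc_gap:.3f}")
```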

In [15]:
import pandas as pd
import numpy as np
import json

from sklearn.model_selection import train_test_split
import torch


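def sin_cos_df(df_file, state):
    """Load a CSV of dihedral angles and return sine/cosine features plus a state label."""
    df = pd.read_csv(df_file)
    # Drop the unnamed index column stored in the CSV
    df.drop('Unnamed: 0', axis=1, inplace=True)

    columns = []
    sin_cos_data = []

    # Encode every dihedral column as a (sin, cos) pair
    for x in df.columns:
        sin_cos_data.append(np.sin(df[x]))
        columns.append(f'{x}-sin')

        sin_cos_data.append(np.cos(df[x]))
        columns.append(f'{x}-cos')

    df_new = pd.concat(sin_cos_data, axis=1)
    df_new.columns = columns

    # Label the snapshots: 0 = fully bound, 1 = pre-bound
    if state == 'fully_bound':
        df_new['State'] = np.zeros(len(df_new))
    elif state == 'pre_bound':
        df_new['State'] = np.ones(len(df_new))
    else:
        raise ValueError(f'Unknown state: {state!r}')

    return df_new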

The data was then shuffled randomly and split into training and test sets, with 25% of the data used as the test set. (This cell does not run here, as the original simulation data is only stored locally.)

In [None]:
from sklearn.model_selection import train_test_split
import torch


X = df_full.drop('State', axis=1).values
y = df_full['State'].values

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, shuffle=True, random_state=2023)

The resulting training and test sets are stored on Google Drive and are loaded from there.

In [16]:
X_train = np.load('drive/MyDrive/FNN_protein_complex_data/X_train.npy')
X_test = np.load('drive/MyDrive/FNN_protein_complex_data/X_test.npy')

y_train = np.load('drive/MyDrive/FNN_protein_complex_data/y_train.npy')
y_test = np.load('drive/MyDrive/FNN_protein_complex_data/y_test.npy')

In [17]:
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)

torch.Size([25525, 2006]) torch.Size([25525])
torch.Size([8509, 2006]) torch.Size([8509])
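
These shapes are consistent with the split described above. The following quick check is a sketch that assumes scikit-learn rounds the test set size up and assigns the remaining samples to the training set when only `test_size` is passed; the last line further assumes that both states contribute the same number of snapshots.

```python
import math

n_train, n_test = 25525, 8509   # sizes printed above
n_total = n_train + n_test      # total number of snapshots

# Assumed rounding rule: test size is ceil(test_size * n), the rest is training data.
expected_test = math.ceil(0.25 * n_total)
expected_train = n_total - expected_test

print(n_total, expected_test, expected_train)
print(n_total / 2)  # snapshots per state, if both states contribute equally
```

Under these assumptions the total of 34034 snapshots corresponds to 17017 per state, in line with the roughly 17000 snapshots per state obtained from the simulations.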


## Build and train the model

The model is a feed-forward neural network with one hidden layer of 128 units and ReLU activation.

In [18]:
import torch.nn as nn
from torch import optim


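class Net(nn.Module):
    """Feed-forward classifier: 2006 features -> 128 hidden units -> 2 outputs."""

    def __init__(self):
        super().__init__()
        torch.manual_seed(2023)  # fix the weight initialisation for reproducibility
        self.net = nn.Sequential(
            nn.Linear(2006, 128),
            nn.ReLU(),
            nn.Linear(128, 2),
            # Note: the outputs are squashed to (0, 1) here, and cross_entropy
            # applies a softmax on top of them during training.
            nn.Sigmoid()
        )

    def forward(self, X):
        return self.net(X)

    def predict(self, X):
        # Returns the per-class outputs; use argmax to obtain class labels
        Y_pred = self.forward(X)
        return Y_pred

def fit(X, y, model, opt, loss_fn, n_epochs = 1000):
    """Train the model, using the full data set as a single batch per epoch."""

    for epoch in range(n_epochs):
        loss = loss_fn(model(X), y)
        loss.backward()
        opt.step()
        opt.zero_grad()  # clear gradients before the next epoch

        # Report the loss every 10 epochs
        if epoch % 10 == 0:
            print(f'Loss (Epoch {epoch}): {loss.item()}')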

The model was initially trained for 100 epochs, but the loss had already converged after about 50 epochs, so the run below stops after epoch 50 (`n_epochs=51`, so that the loss at epoch 50 is still printed).

In [19]:
net = Net()
loss_fn = nn.functional.cross_entropy
opt = torch.optim.Adam(net.parameters(), lr=0.001)

fit(X_train, y_train, net, opt, loss_fn, n_epochs=51)

Loss (Epoch 0): 0.693171501159668
Loss (Epoch 10): 0.5294300317764282
Loss (Epoch 20): 0.4086780250072479
Loss (Epoch 30): 0.3545016646385193
Loss (Epoch 40): 0.33295848965644836
Loss (Epoch 50): 0.32471030950546265
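
A note on reading these numbers, offered as an interpretation rather than a result from the project: `nn.functional.cross_entropy` applies a softmax to its inputs itself, and because the network ends in `nn.Sigmoid`, those inputs are confined to the interval (0, 1). Under that reading, the loss cannot approach zero. The sketch below (the helper `cross_entropy_two_class` is purely illustrative) computes the loss for two equal outputs and the lowest loss reachable when the two outputs sit at 1 and 0.

```python
import math

def cross_entropy_two_class(z_true, z_other):
    """Cross-entropy for one sample: -log softmax(z)[true class]."""
    return -math.log(math.exp(z_true) / (math.exp(z_true) + math.exp(z_other)))

# Equal outputs for both classes (e.g. at initialisation): loss = ln 2.
start = cross_entropy_two_class(0.5, 0.5)

# Best case with sigmoid-bounded outputs: true class -> 1, other class -> 0.
floor = cross_entropy_two_class(1.0, 0.0)

print(f"start: {start:.4f}, floor: {floor:.4f}")
```

The printed losses start near ln 2 ≈ 0.693 and level off at about 0.325, slightly above this floor of about 0.313, so the plateau may reflect the bounded outputs rather than a limit of the data. Dropping the final `nn.Sigmoid` is the usual way to pair a network with `cross_entropy`, but doing so would change the numbers shown here.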


The accuracy on the test set was then evaluated and found to be very high (0.9999).

In [20]:
def accuracy(y_hat, y):
    pred = torch.argmax(y_hat, dim=1)
    return (pred == y).float().mean()


accuracy(net.predict(X_test), y_test)

tensor(0.9999)

## Remove features

Now we remove the features of each amino acid in turn by replacing them with their mean over the test set, so that they no longer carry snapshot-specific information. If the test set accuracy drops, this points to a particular importance of that amino acid for differentiating between the two states.

In [31]:
with open('drive/MyDrive/FNN_protein_complex_data/res_feat_idx.json', 'r') as f:
    res_feat_idx = json.load(f)

accuracies = {}

for x in res_feat_idx:
    idx = res_feat_idx[x]
    X = torch.clone(X_test)
    # Replace this amino acid's features with their test-set mean
    X[:, idx] = X_test.mean(dim=0)[idx]

    y_hat = net.predict(X)
    acc = accuracy(y_hat, y_test)

    accuracies[x] = acc
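
To make the logic of this mean replacement easier to follow, here is a minimal sketch on synthetic data. Everything in it is hypothetical: the three-feature toy data set, the threshold classifier `toy_predict` standing in for the trained network, and the helper `ablated_accuracy`. Only feature 0 carries the class signal, so only its removal should reduce the accuracy.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 3 features, only feature 0 carries the class signal.
n = 1000
y = rng.integers(0, 2, size=n)
X = rng.normal(size=(n, 3))
X[:, 0] += 4.0 * y  # shift feature 0 for class 1

def toy_predict(X):
    """Hypothetical stand-in classifier: thresholds feature 0 at 2.0."""
    return (X[:, 0] > 2.0).astype(int)

def ablated_accuracy(X, y, idx):
    """Accuracy after replacing the columns in idx with their mean over X."""
    X_abl = X.copy()
    X_abl[:, idx] = X.mean(axis=0)[idx]
    return float((toy_predict(X_abl) == y).mean())

baseline = float((toy_predict(X) == y).mean())

# Accuracy drop caused by ablating each feature on its own.
drops = {i: baseline - ablated_accuracy(X, y, [i]) for i in range(3)}
print(baseline, drops)
```

In this toy setting, ablating feature 0 makes every prediction identical and the accuracy falls to roughly the class frequency, while ablating the other two features changes nothing.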


Now we look at the amino acids with the lowest accuracy (the loop below prints the first 11).

In [33]:
accuracies_sorted = dict(sorted(accuracies.items(), key=lambda x:x[1]))

for i, x in enumerate(accuracies_sorted):
    print(x, accuracies_sorted[x])
    if i == 10:
        break

GLU187 tensor(0.9994)
ARG1 tensor(0.9994)
ARG17 tensor(0.9994)
GLY56 tensor(0.9994)
GLN30 tensor(0.9995)
LEU33 tensor(0.9995)
TYR39 tensor(0.9995)
GLY43 tensor(0.9995)
TYR59 tensor(0.9995)
LYS60 tensor(0.9995)
SER61 tensor(0.9995)


We see that the accuracy barely drops at all. This is expected, as the difference between the states is likely to involve multiple amino acids, so the network can still rely on the remaining ones when a single amino acid is removed. Interestingly, ARG17, one of the main amino acids that changes its position during the transition between the states, is among the amino acids whose removal lowers the accuracy most (albeit very slightly).
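
To put these small differences into perspective, the printed accuracies can be converted back into numbers of misclassified test snapshots. This is a rough sketch that assumes the printed tensors are accuracies rounded to four decimals and uses the test set size of 8509; the helper `error_counts` is introduced here only for illustration.

```python
n_test = 8509  # test set size printed earlier in the notebook

def error_counts(printed_acc, n=n_test, digits=4):
    """All misclassification counts k whose accuracy 1 - k/n rounds to printed_acc."""
    return [k for k in range(n + 1) if round(1 - k / n, digits) == printed_acc]

for acc in (0.9999, 0.9995, 0.9994):
    print(acc, error_counts(acc))
```

Under this assumption, the baseline accuracy of 0.9999 corresponds to a single misclassified snapshot, and the lowest accuracies after removing one amino acid (0.9994) correspond to five, so even the largest effect amounts to only a handful of the 8509 test snapshots.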