In [212]:
import torch
import gzip
from pathlib import Path
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from src.data import ProteinsDataset, AllVertices, ProteinRecord
from src.nn_model import Amber_NN



In [84]:
_ = torch.manual_seed(142)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [173]:
protein = ["1FZW_B", "1GP2_BG", "2E7J_B", "2E89_A", "2EG5_E", "2NYZ_A",
           "2OOR_A", "3AAA_AB",
           "7FCT_A", "7MX9_A", "7N8G_A", "7QRR_A", "7WUG_145"] 
# "6BOY_BC", <- left it for testing
pd = AllVertices(protein)

n_features = pd[0][0].shape[0]

train_dataset, test_dataset = random_split(pd, [0.85, 0.15],
             generator=torch.Generator().manual_seed(32))

In [174]:
print(len(pd))
print(n_features)


230080
111


In [175]:
# test_dataset[2]

In [227]:
class NeuralNet(nn.Module):
    def __init__(self, input_size, num_classes, hidden_size, p = 0.1):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, hidden_size)
        self.fc4 = nn.Linear(hidden_size, num_classes)
        self.dropout = nn.Dropout(p) 

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        out = self.dropout(out)
        out = F.relu(self.fc3(out))
        out = self.dropout(out)
        out = F.relu(self.fc4(out))
        return out



class Amber_NN(nn.Module):

    def __init__(self, nfeatures, nclasses):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(nfeatures, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, nclasses)
        )

    def forward(self, x):
        # x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

In [228]:
learning_rate = 5e-3
batch_size = 128

# Create data loaders.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

model = NeuralNet(n_features, 3, 256).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [229]:
X = torch.stack([train_dataset[333][0], train_dataset[235][0]], axis=0)
X = X.to(device)

In [230]:
logits = model(X)
pred_probab = nn.Softmax(dim=1)(logits)
y_pred = pred_probab.argmax(1)
print(f"Predicted class: {y_pred}")

Predicted class: tensor([0, 0], device='cuda:0')


In [231]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X,y) in enumerate(dataloader):
        # Compute prediction and loss
        X = X.to(device)
        y = y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 1000 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            
            X = X.to(device)
            y = y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [232]:
epochs = 50
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 1.077596  [  128/195568]
loss: 0.933999  [128128/195568]
Test Error: 
 Accuracy: 64.7%, Avg loss: 0.811236 

Epoch 2
-------------------------------
loss: 0.853248  [  128/195568]
loss: 0.927002  [128128/195568]
Test Error: 
 Accuracy: 65.0%, Avg loss: 0.791923 

Epoch 3
-------------------------------
loss: 0.839897  [  128/195568]
loss: 0.850629  [128128/195568]
Test Error: 
 Accuracy: 66.0%, Avg loss: 0.759220 

Epoch 4
-------------------------------
loss: 0.815682  [  128/195568]
loss: 0.838163  [128128/195568]
Test Error: 
 Accuracy: 66.4%, Avg loss: 0.751968 

Epoch 5
-------------------------------
loss: 0.799493  [  128/195568]
loss: 0.831783  [128128/195568]
Test Error: 
 Accuracy: 66.7%, Avg loss: 0.745294 

Epoch 6
-------------------------------
loss: 0.801085  [  128/195568]
loss: 0.832500  [128128/195568]
Test Error: 
 Accuracy: 66.9%, Avg loss: 0.742281 

Epoch 7
-------------------------------
loss: 0.791491  [  128/195568]

## Test the model

In [233]:
def save_for_paraview(filename, xyz_tensor, color_tensor, z_shift = 0):
    
    tosave = torch.cat([xyz_tensor, color_tensor[:,None]], dim=1)
    qc = tosave.cpu().detach().numpy()
    with open(filename, 'w') as iFile:
        iFile.write("x,y,z,c\n")
        for i in qc:
            x,y,z,c = i
            iFile.write(f'{x},{y},{z+z_shift},{c}\n')
    
    return 0

In [234]:
Trial_pdb = "6BOY_B"

In [235]:
data = ProteinRecord(Trial_pdb)

In [236]:
X = data.f.to(device)
p = data.p.to(device)
y = data.y_aux.to(device)

In [237]:
with torch.no_grad():
    logits = model(X)
    pred_probab = nn.Softmax(dim=1)(logits)
    y_pred = pred_probab.argmax(1)
    # print(f"Predicted class: {y_pred}")

In [238]:
save_for_paraview(f"predicted_{Trial_pdb}.csv", p, y_pred)
save_for_paraview(f"real_{Trial_pdb}.csv", p, y)
        

0

In [241]:
y[580:680]

tensor([1, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 0, 0, 1, 1, 1, 2, 2,
        2, 2, 1, 1, 1, 1, 0, 0, 1, 1, 2, 2, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1,
        1, 1, 2, 2, 1, 1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1,
        1, 1, 1, 1], device='cuda:0')