In [1]:
import sys
from pathlib import Path
from sklearn.model_selection import train_test_split

In [2]:
import torch
from torch.utils.data import random_split
from torch.utils.data import DataLoader
import torch.nn as nn

In [5]:
# local libs
from dataset import AllVertices
from nn_model import AmberNN
from src.logger import Logger
from run_opts import config_runtime


In [6]:
_ = torch.manual_seed(142)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
n_cls = 3  # number of classes
conf_dev = config_runtime['device']
train_f = config_runtime['train_frac']
seed = config_runtime['seed']
learning_rate = config_runtime['learning_rate']
batch_size = config_runtime['batch_size']
hid_size = config_runtime['hidden_size']
log_step = config_runtime['log_step']
epochs = config_runtime['num_epochs']
run_name = config_runtime['run_name']

In [61]:
learning_rate = 0.01

In [10]:
# Read list of the proteins
proteins = []
with open("../../data/lists/train_chrg.txt", 'r') as iFile:
    for i in iFile:
        if i[0] != '#':
            proteins.append(i.strip())

In [15]:
train_prots, test_prots = train_test_split(proteins, train_size=train_f, random_state=412)

In [16]:
print("\n")
print("Data")
print("-------------------------")
print(f"Proteins: train {len(train_prots)}   test {len(test_prots)}")
print(test_prots)



Data
-------------------------
Proteins: train 54   test 7
['4g1q_B', '6TVP_AB', '1rzh_H', '2j8b_A', '1GP2_BG', '2bln_A', '6KL0_A']


In [17]:
train_dataset = AllVertices(train_prots)
test_dataset = AllVertices(test_prots)
n_features = train_dataset[0][0].shape[0]

In [18]:
print(f"Vertices: train {len(train_dataset)}   test {len(test_dataset)}")
print(f"Features: {n_features}")

Vertices: train 739661   test 109146
Features: 106


In [41]:
# Create data loaders.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [53]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(n_features, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.Dropout(),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, n_cls),
        )

    def forward(self, x):
        logits = self.linear_relu_stack(x)
        return logits

In [62]:
model = NeuralNetwork().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [63]:
def train_loop(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X,y) in enumerate(dataloader):
        # Compute prediction and loss
        X = X.to(device)
        y = y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 1000 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def test_loop(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            
            X = X.to(device)
            y = y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
        test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [64]:
epochs = 15
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train_loop(train_dataloader, model, loss_fn, optimizer)
    test_loop(test_dataloader, model, loss_fn)
print("Done!")

Epoch 1
-------------------------------
loss: 1.106787  [  128/739661]
loss: 1.047741  [128128/739661]
loss: 0.988061  [256128/739661]
loss: 0.940335  [384128/739661]
loss: 0.916842  [512128/739661]
loss: 1.007598  [640128/739661]
Test Error: 
 Accuracy: 51.5%, Avg loss: 0.967856 

Epoch 2
-------------------------------
loss: 1.008540  [  128/739661]
loss: 0.872582  [128128/739661]
loss: 0.888423  [256128/739661]
loss: 0.985062  [384128/739661]
loss: 0.885042  [512128/739661]
loss: 0.905500  [640128/739661]
Test Error: 
 Accuracy: 51.4%, Avg loss: 0.968624 

Epoch 3
-------------------------------
loss: 0.905877  [  128/739661]
loss: 0.976274  [128128/739661]
loss: 0.886835  [256128/739661]
loss: 0.895572  [384128/739661]
loss: 0.957664  [512128/739661]
loss: 0.947634  [640128/739661]
Test Error: 
 Accuracy: 51.5%, Avg loss: 0.962167 

Epoch 4
-------------------------------
loss: 0.900497  [  128/739661]
loss: 0.954361  [128128/739661]
loss: 0.968699  [256128/739661]
loss: 0.878765  