In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import one_hot

import numpy as np
import h5py
import pickle

from sklearn.model_selection import train_test_split
import pandas as pd

import sys
sys.path.append('../common')
from dataset import TMH
from models import CNN
from cait_models import cait_S24_224, cait_tmh

### Create dataloader

In [5]:
dataset_path = "/Users/fga/data/tmh"
dataset = TMH(embeddings_path=dataset_path+"/embeddings.h5",
                protein_hashes_path=dataset_path+"/seq_anno_hash.pickle",
                train_ids=dataset_path+"/data_splits/train_prot_id_labels.csv")
dataloader = DataLoader(dataset, batch_size=4)

test_dataset = TMH(embeddings_path=dataset_path+"/embeddings.h5",
                protein_hashes_path=dataset_path+"/seq_anno_hash.pickle",
                train_ids=dataset_path+"/data_splits/test_prot_id_labels.csv")
test_dataloader = DataLoader(test_dataset, batch_size=4)

In [6]:
# net = CNN()
# net = MLP()
# net = cait_S24_224()
net = cait_tmh()

# The function is indicating the performance of the model.
# During the training process this function should be minimized
criterion = nn.CrossEntropyLoss()

# The minimization is achieved through Stochastic Gradient Descent
optimizer = optim.Adagrad(net.parameters(), lr=0.01)

In [7]:
# single step
inputs, labels = iter(dataloader).__next__()

x = inputs.unsqueeze(1)

try:
    outputs = net(x)
except ValueError:
    import ipdb
    ipdb.post_mortem()

loss = criterion(outputs, labels.float())

print(f"{outputs} \n {labels} \n {loss}")

tensor([[ 0.4876, -0.5640, -0.7576, -0.3238],
        [ 0.4875, -0.5639, -0.7575, -0.3241],
        [ 0.4877, -0.5640, -0.7576, -0.3238],
        [ 0.4877, -0.5640, -0.7576, -0.3239]], grad_fn=<AddmmBackward0>) 
 tensor([[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [1, 0, 0, 0]]) 
 1.5701520442962646


In [8]:
for epoch in range(4):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        """
        shape of inputs: torch.Size([4, 3, 32, 32])
            Batchsize: 4
            Channels: 3 (Red, Green, Blue)
            Image size: 32 x 32

        labels: tensor([9, 3, 0, 3])
            9: class of image 0 in batch
            3: class of image 1 in batch
            ...
        """
        inputs, labels = data

        """ zero the parameter gradients after every batch
        This is necessary because the gradients (directions of how the weigths and biases
        will be updated) are accumulated in each backward pass.
        https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch
        """
        optimizer.zero_grad()  # SGD

        # forward + backward + optimize
        # shape outputs: torch.Size([4, 10])
        # for every image a prediction
        # print(f"{inputs}")
        outputs = net(inputs.unsqueeze(1))

        #print(f"{outputs} \t {labels}")

        # the first iteration CrossEntropy: tensor(2.3100, grad_fn=<NllLossBackward0>)

        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()

        # running loss after 3 iterations: 6.894119024276733
        # Why is the loss added?
        running_loss += loss.item()
        if i % 100 == 99:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1}] loss: {(running_loss / 99):.3f}')
            running_loss = 0.0

print('Finished Training')

[1, 100] loss: 1.853
[1, 200] loss: 1.129
[1, 300] loss: 1.098
[1, 400] loss: 1.143
[1, 500] loss: 1.186
[1, 600] loss: 1.182
[1, 700] loss: 1.116
[1, 800] loss: 1.134
[2, 100] loss: 1.111
[2, 200] loss: 1.059
[2, 300] loss: 1.072
[2, 400] loss: 1.115
[2, 500] loss: 1.168
[2, 600] loss: 1.166
[2, 700] loss: 1.104
[2, 800] loss: 1.124
[3, 100] loss: 1.099
[3, 200] loss: 0.915
[3, 300] loss: 0.886
[3, 400] loss: 1.098
[3, 500] loss: 1.093
[3, 600] loss: 1.096
[3, 700] loss: 1.059
[3, 800] loss: 0.745
[4, 100] loss: 0.511
[4, 200] loss: 0.435
[4, 300] loss: 0.475
[4, 400] loss: 0.479
[4, 500] loss: 0.504
[4, 600] loss: 0.501
[4, 700] loss: 0.417
[4, 800] loss: 0.441
Finished Training


In [9]:
torch.save(net.state_dict(), "model.pt")

### Evaluate the model on the test data
This could be done with TorchMetrics but we will do this manually here

In [11]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in test_dataloader:
        images, labels = data
        # calculate outputs by running images through the network
        outputs = net(images.unsqueeze(1))

        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        _, labels = torch.max(labels.data, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the {len(test_dataloader)} test embeddings: {(100 * correct / total):.2f}')

Accuracy of the network on the 91 test embeddings: 86.43


In [13]:
# prepare to count predictions for each class
mapp = {
            'G_SP': 0,
            'G': 1,
            'SP_TM': 2,
            'TM': 3
        }
classes = list(mapp.keys())
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in test_dataloader:
        images, labels = data
        outputs = net(images.unsqueeze(1))
        _, predictions = torch.max(outputs, 1)
        _, labels = torch.max(labels, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print("Accuracy for class {:5s} is: {:.1f} %".format(classname,
                                                         accuracy))

"""
Accuracy for class G_SP  is: 93.2 %
Accuracy for class G     is: 99.4 %
Accuracy for class SP_TM is: 87.6 %
Accuracy for class TM    is: 93.5 %
"""

Accuracy for class G_SP  is: 82.6 %
Accuracy for class G     is: 99.5 %
Accuracy for class SP_TM is: 72.7 %
Accuracy for class TM    is: 0.0 %


'\nAccuracy for class G_SP  is: 93.2 %\nAccuracy for class G     is: 99.4 %\nAccuracy for class SP_TM is: 87.6 %\nAccuracy for class TM    is: 93.5 %\n'

Create hashsum