In [1]:
import torch
import torch.nn as nn

from utils.setup import GetCustomProteinDatasetPadded, GetCVProteins

import numpy as np

import utils.metrics_utils as mu

# import matplotlib.pyplot as plt
# from sklearn.preprocessing import OneHotEncoder
# from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset

# from torchvision import datasets
# from torchvision.transforms import ToTensor
from sklearn import metrics
import torch.optim as optim

# Dimension handed to the padded-dataset factory below.
encode_length = 1500
# When True, train_model/test_model additionally print every
# predicted/target topology pair returned by mu.confusionMatrix.
print_error_type_pairs = False

# Create the dataset constructor (use encode_length to set the dimension of the encoded proteins)
CustomProteinDataset = GetCustomProteinDatasetPadded(encode_length)
# Cross-validation partitions; indexed further down with the keys "cv0", "cv1", ...
CVProteins = GetCVProteins()


# Model
class Model(nn.Module):
    """Two-layer 1D convolutional network producing per-position class logits.

    Both convolutions use "same"-style padding (kernel 9 / padding 4 and
    kernel 7 / padding 3), so the output has the same length as the input.

    Args:
        num_classes: Number of output channels of the last convolution
            (one logit per class and position).
        channels: Number of input channels expected by the first
            convolution. Defaults to 512, the value previously hard-coded.
        hidden1: Number of channels between the two convolutions.
            Defaults to 128, the value previously hard-coded.
    """

    def __init__(self, num_classes, channels=512, hidden1=128):
        super().__init__()
        self.num_classes = num_classes
        self.channels = channels
        # Not used by any layer below; kept so existing code reading
        # ``model.length`` keeps working.
        self.length = 1500
        self.hidden1 = hidden1

        self.net = nn.Sequential(
            nn.Conv1d(self.channels, self.hidden1, 9, padding=4),
            nn.BatchNorm1d(self.hidden1),
            nn.ReLU(),
            nn.Conv1d(self.hidden1, self.num_classes, 7, padding=3),
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, channels, length) to logits of shape (batch, num_classes, length)."""
        return self.net(x)


# Accuracy
def accuracy(target, pred):
    """Return sklearn's ``accuracy_score`` for two label tensors.

    Both tensors are detached from the autograd graph, moved to the CPU and
    converted to NumPy arrays before being scored.
    """
    y_true = target.detach().cpu().numpy()
    y_pred = pred.detach().cpu().numpy()
    return metrics.accuracy_score(y_true, y_pred)


def _score_batch(inputs, targets, predictions):
    """Score every protein of a batch on the leading part of its sequence.

    Args:
        inputs: Batch as yielded by the data loader. Channel ``-1`` of each
            protein is summed to obtain the number of positions to score
            (presumably a 0/1 padding mask -- TODO confirm against the
            dataset class).
        targets: Label tensor indexed as ``targets[protein][position]``.
        predictions: Predicted labels with the same indexing as ``targets``.

    Returns:
        Tuple ``(accuracies, trimmed_predictions, trimmed_targets)``; each is
        a list with one entry per protein in the batch.
    """
    accuracies = []
    trimmed_predictions = []
    trimmed_targets = []
    for idx in range(predictions.shape[0]):
        target_len = int(torch.sum(inputs[idx, -1, :]))
        protein_targets = targets[idx][:target_len]
        protein_predictions = predictions[idx][:target_len]
        accuracies.append(accuracy(protein_targets, protein_predictions))
        trimmed_predictions.append(protein_predictions)
        trimmed_targets.append(protein_targets)
    return accuracies, trimmed_predictions, trimmed_targets


def _validate(model, validation_loader):
    """Run ``model`` over ``validation_loader`` without gradients.

    The model is switched to eval mode for the pass and back to train mode
    afterwards.

    Returns:
        Tuple ``(accuracies, prediction_labels, target_labels)`` accumulated
        over all validation proteins.
    """
    accuracies = []
    prediction_labels = []
    target_labels = []

    with torch.no_grad():
        model.eval()

        for inputs_val, targets_val in validation_loader:
            # The last input channel is not fed to the network.
            output_val = model(inputs_val[:, :-1, :])
            predictions_val = output_val.max(1)[1]

            batch_accuracies, batch_predictions, batch_targets = _score_batch(
                inputs_val, targets_val, predictions_val
            )
            accuracies.extend(batch_accuracies)
            prediction_labels.extend(batch_predictions)
            target_labels.extend(batch_targets)

        model.train()

    return accuracies, prediction_labels, target_labels


def _print_validation_report(
    step, train_accuracy, valid_accuracy, prediction_labels, target_labels
):
    """Print the accuracies for ``step`` plus the mu.confusionMatrix breakdown."""
    print(f"Step {step:<5}")
    print(f"  training accuracy:    {train_accuracy}")
    print(f"  validation accuracy:  {valid_accuracy}")

    (
        error_type_pairs,
        confusion_matrix,
        type_accuracy,
        detailed_type_accuracy,
    ) = mu.confusionMatrix(prediction_labels, target_labels)

    # type accuracy is average of per type, topology accuracy
    print(f"  type accuracy (val):  {type_accuracy}")

    # detailed type accuracies
    for key, fields in detailed_type_accuracy.items():
        print(f"  {key}")
        for field, value in fields.items():
            print(f"    {field:<9}: {value}")

    print(confusion_matrix)

    if print_error_type_pairs:
        for error_pair in error_type_pairs:
            print("  Predicted topology:", error_pair["predicted topology"])
            print("  Target topology:   ", error_pair["target topology"])

    print("")


def train_model(model, train_dataset, validation_dataset):
    """Train ``model`` and print validation metrics while doing so.

    Runs 10 epochs of Adam (lr 1e-3, weight decay 1e-6) with batches of 32
    and a cross-entropy loss. After every ``validation_every_steps``
    optimizer steps (currently 1) the whole validation set is evaluated and
    a report is printed.

    Args:
        model: Network to train; it is updated in place.
        train_dataset: Dataset yielding ``(inputs, targets)`` pairs used for
            the optimizer steps.
        validation_dataset: Dataset of the same form used only for the
            periodic evaluation.

    Returns:
        The same ``model`` instance, left in train mode.
    """
    # Targets equal to -1 do not contribute to the loss (presumably the
    # padded positions -- TODO confirm against the dataset class).
    loss_fn = nn.CrossEntropyLoss(ignore_index=-1)

    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)

    batch_size = 32

    # NOTE(review): the training loader does not shuffle, so every epoch
    # sees the batches in the same order. Left unchanged on purpose here.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )

    validation_loader = torch.utils.data.DataLoader(
        validation_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )

    num_epochs = 10
    validation_every_steps = 1

    step = 0
    model.train()

    train_accuracies = []
    valid_accuracies = []

    for _epoch in range(num_epochs):
        train_accuracies_batches = []

        for inputs_train, targets_train in train_loader:
            # Forward pass (without the last input channel), compute
            # gradients, perform one training step.
            output_train = model(inputs_train[:, :-1, :])
            loss = loss_fn(output_train, targets_train)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step += 1

            # Training accuracy uses the logits computed before the update.
            predictions_train = output_train.max(1)[1]
            batch_accuracies, _, _ = _score_batch(
                inputs_train, targets_train, predictions_train
            )
            train_accuracies_batches.extend(batch_accuracies)

            if step % validation_every_steps != 0:
                continue

            # Average training accuracy since the last report.
            train_accuracies.append(np.mean(train_accuracies_batches))
            train_accuracies_batches = []

            (
                validation_accuracies_batches,
                prediction_labels_list,
                target_labels_list,
            ) = _validate(model, validation_loader)

            valid_accuracies.append(
                np.sum(validation_accuracies_batches) / len(validation_dataset)
            )

            _print_validation_report(
                step,
                train_accuracies[-1],
                valid_accuracies[-1],
                prediction_labels_list,
                target_labels_list,
            )

    print("Done training a model.")

    return model


def test_model(model, test_dataset):
    """Evaluate ``model`` on ``test_dataset`` and print the resulting metrics.

    The whole dataset is pushed through the model as a single batch with
    gradients disabled. The model is put in eval mode and left there.
    Nothing is returned; all results are printed.
    """
    loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=len(test_dataset),
        shuffle=False,
        num_workers=0,
        drop_last=False,
    )

    # Per-protein results gathered over the test set
    per_protein_accuracy = []
    predicted_labels = []
    true_labels = []

    with torch.no_grad():
        model.eval()
        for inputs, targets in loader:
            # The last input channel is not fed to the network
            logits = model(inputs[:, :-1, :])
            predictions = logits.max(1)[1]

            for row in range(predictions.shape[0]):
                # Sum of the last input channel gives the number of positions to score
                n_valid = int(torch.sum(inputs[row, -1, :]))
                row_targets = targets[row][0:n_valid]
                row_predictions = predictions[row][0:n_valid]

                per_protein_accuracy.append(accuracy(row_targets, row_predictions))
                predicted_labels.append(row_predictions)
                true_labels.append(row_targets)

    mean_accuracy = np.sum(per_protein_accuracy) / len(test_dataset)
    print("  Test accuracy: " + str(mean_accuracy))

    # Extra accuracies
    report = mu.confusionMatrix(predicted_labels, true_labels)
    error_type_pairs, confusion_matrix, type_accuracy, detailed_type_accuracy = report

    # type accuracy is average of per type, topology accuracy
    print(f"  type accuracy (test): {type_accuracy}")

    # detailed type accuracies
    for key, fields in detailed_type_accuracy.items():
        print(f"  {key}")
        for field, value in fields.items():
            print(f"    {field:<9}: {value}")

    print(confusion_matrix)

    if not print_error_type_pairs:
        return

    for error_pair in error_type_pairs:
        print("  Predicted topology:", error_pair["predicted topology"])
        print("  Target topology:   ", error_pair["target topology"])


# Number of cross-validation partitions; drives the rotation below.
n_cv = len(CVProteins)

# Each rotation uses 3 training folds, 1 validation fold and 1 test fold.
# With fewer than 5 partitions the modular indexing would reuse a fold in
# more than one role.
if n_cv < 5:
    raise ValueError(
        f"Need at least 5 cross-validation partitions, got {n_cv}"
    )

# Number of output classes of the model.
n_unique_labels = 7

# Only the first `proteins_per_fold` entries of every partition are used.
proteins_per_fold = 10

separator = "---------------------------------------------------------------------------"

for loop in range(n_cv):
    # Fold roles rotate with the loop index: folds loop..loop+2 train,
    # loop+3 validates, loop+4 tests (all modulo the number of partitions).
    train_folds = [(loop + offset) % n_cv for offset in range(3)]
    validation_fold = (loop + 3) % n_cv
    test_fold = (loop + 4) % n_cv

    print(separator)
    print(
        f"-                           Loop {loop:<8}                                 -"
    )
    print(
        f"-  train sets: cv{train_folds[0]}, cv{train_folds[1]}, cv{train_folds[2]}"
    )
    print(f"-  validation sets: cv{validation_fold}")
    print(f"-  test sets: cv{test_fold}")
    print(separator)

    train_dataset = CustomProteinDataset(
        CVProteins[f"cv{train_folds[0]}"][0:proteins_per_fold]
        + CVProteins[f"cv{train_folds[1]}"][0:proteins_per_fold]
        + CVProteins[f"cv{train_folds[2]}"][0:proteins_per_fold]
    )
    validation_dataset = CustomProteinDataset(
        CVProteins[f"cv{validation_fold}"][0:proteins_per_fold]
    )
    test_dataset = CustomProteinDataset(
        CVProteins[f"cv{test_fold}"][0:proteins_per_fold]
    )

    # A fresh model is trained for every rotation.
    model = Model(n_unique_labels)
    trained_model = train_model(model, train_dataset, validation_dataset)
    test_model(trained_model, test_dataset)

  from .autonotebook import tqdm as notebook_tqdm


encoding proteins


30it [00:00, 349.13it/s]


encoding proteins


10it [00:00, 430.05it/s]


encoding proteins


10it [00:00, 332.12it/s]


encoding proteins


30it [00:00, 395.62it/s]


encoding proteins


10it [00:00, 326.16it/s]

encoding proteins



10it [00:00, 424.89it/s]


encoding proteins


30it [00:00, 375.52it/s]

encoding proteins



10it [00:00, 396.25it/s]


encoding proteins


10it [00:00, 353.31it/s]


encoding proteins


30it [00:00, 403.27it/s]


encoding proteins


10it [00:00, 266.17it/s]


encoding proteins


10it [00:00, 468.41it/s]


encoding proteins


30it [00:00, 191.54it/s]


encoding proteins


10it [00:00, 174.58it/s]


encoding proteins


10it [00:00, 323.37it/s]


---------------------------------------------------------------------------
-                            Loop 0                                       -
-  train sets: cv0, cv1, cv2
-  validation sets: cv3
-  test sets: cv4
---------------------------------------------------------------------------
encoding proteins


30it [00:00, 336.76it/s]


encoding proteins


10it [00:00, 259.07it/s]


encoding proteins


10it [00:00, 466.28it/s]


Step 1    
  training accuracy:    0.06344225596407861
  test accuracy:        0.5515365410095044
  type accuracy (test): 0.5600000023841858
  tm
    type     : 0.0
    topology : 0.0
  sptm
    type     : 1
    topology : 1
  sp
    type     : 0.0
    topology : 0.0
  glob
    type     : 0.800000011920929
    topology : 0.800000011920929
  beta
    type     : 1
    topology : 1
tensor([[0, 0, 0, 2, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 2, 0, 0, 1],
        [0, 0, 0, 0, 0, 4, 1],
        [0, 0, 0, 0, 0, 0, 0]])

Step 2    
  training accuracy:    0.5497283729526447
  test accuracy:        0.5502740170144208
  type accuracy (test): 0.48000001907348633
  tm
    type     : 0.0
    topology : 0.0
  sptm
    type     : 1
    topology : 1
  sp
    type     : 0.0
    topology : 0.0
  glob
    type     : 0.4000000059604645
    topology : 0.4000000059604645
  beta
    type     : 1
    topology : 1
tensor([[0, 0, 0, 1, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 

30it [00:00, 326.23it/s]


encoding proteins


10it [00:00, 299.32it/s]


encoding proteins


10it [00:00, 336.28it/s]


Step 1    
  training accuracy:    0.11458695622210174
  test accuracy:        0.5902783915156873
  type accuracy (test): 0.6000000238418579
  tm
    type     : 0.0
    topology : 0.0
  sptm
    type     : 1
    topology : 1
  sp
    type     : 0.0
    topology : 0.0
  glob
    type     : 1.0
    topology : 1.0
  beta
    type     : 1
    topology : 1
tensor([[0, 0, 0, 2, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 3, 0, 0, 0],
        [0, 0, 0, 0, 0, 5, 0],
        [0, 0, 0, 0, 0, 0, 0]])

Step 2    
  training accuracy:    0.5405344805730264
  test accuracy:        0.5921485128746264
  type accuracy (test): 0.5600000023841858
  tm
    type     : 0.0
    topology : 0.0
  sptm
    type     : 1
    topology : 1
  sp
    type     : 0.0
    topology : 0.0
  glob
    type     : 0.800000011920929
    topology : 0.800000011920929
  beta
    type     : 1
    topology : 1
tensor([[0, 0, 0, 2, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 2, 0, 0, 1],
        [0, 0, 

30it [00:00, 316.83it/s]


encoding proteins


10it [00:00, 605.54it/s]

encoding proteins



10it [00:00, 319.48it/s]


Step 1    
  training accuracy:    0.21470108479445596
  test accuracy:        0.4883313817330211
  type accuracy (test): 0.20000000298023224
  tm
    type     : 0.0
    topology : 0.0
  sptm
    type     : 0.0
    topology : 0.0
  sp
    type     : 0.0
    topology : 0.0
  glob
    type     : 1.0
    topology : 1.0
  beta
    type     : 0.0
    topology : 0.0
tensor([[0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 3, 0, 0, 0],
        [0, 0, 0, 0, 0, 4, 0],
        [0, 0, 0, 1, 0, 0, 0]])

Step 2    
  training accuracy:    0.6075213742751137
  test accuracy:        0.4883313817330211
  type accuracy (test): 0.20000000298023224
  tm
    type     : 0.0
    topology : 0.0
  sptm
    type     : 0.0
    topology : 0.0
  sp
    type     : 0.0
    topology : 0.0
  glob
    type     : 1.0
    topology : 1.0
  beta
    type     : 0.0
    topology : 0.0
tensor([[0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 3, 0, 0, 0],
        [0, 0, 0, 0, 0, 4

30it [00:00, 306.51it/s]


encoding proteins


10it [00:00, 265.76it/s]

encoding proteins



10it [00:00, 302.36it/s]


Step 1    
  training accuracy:    0.3001464175065683
  test accuracy:        0.44423377374410755
  type accuracy (test): 0.20000000298023224
  tm
    type     : 0.0
    topology : 0.0
  sptm
    type     : 0.0
    topology : 0.0
  sp
    type     : 0.0
    topology : 0.0
  glob
    type     : 1.0
    topology : 1.0
  beta
    type     : 0.0
    topology : 0.0
tensor([[0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 2, 0, 0, 1],
        [0, 0, 0, 0, 0, 4, 0],
        [0, 0, 0, 0, 0, 0, 1]])

Step 2    
  training accuracy:    0.5814505703422939
  test accuracy:        0.4526932916977421
  type accuracy (test): 0.20000000298023224
  tm
    type     : 0.0
    topology : 0.0
  sptm
    type     : 0.0
    topology : 0.0
  sp
    type     : 0.0
    topology : 0.0
  glob
    type     : 1.0
    topology : 1.0
  beta
    type     : 0.0
    topology : 0.0
tensor([[0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 2, 0, 0, 1],
        [0, 0, 0, 0, 0, 4

30it [00:00, 230.56it/s]


encoding proteins


10it [00:00, 683.81it/s]


encoding proteins


10it [00:00, 226.36it/s]


Step 1    
  training accuracy:    0.04628377883907646
  test accuracy:        0.6065485040578575
  type accuracy (test): 0.5600000023841858
  tm
    type     : 0.0
    topology : 0.0
  sptm
    type     : 1
    topology : 1
  sp
    type     : 0.0
    topology : 0.0
  glob
    type     : 0.800000011920929
    topology : 0.800000011920929
  beta
    type     : 1
    topology : 1
tensor([[0, 0, 0, 0, 0, 0, 2],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 3],
        [0, 0, 0, 0, 0, 4, 1],
        [0, 0, 0, 0, 0, 0, 0]])

Step 2    
  training accuracy:    0.6113271520702597
  test accuracy:        0.6118019649140727
  type accuracy (test): 0.5600000023841858
  tm
    type     : 0.0
    topology : 0.0
  sptm
    type     : 1
    topology : 1
  sp
    type     : 0.0
    topology : 0.0
  glob
    type     : 0.800000011920929
    topology : 0.800000011920929
  beta
    type     : 1
    topology : 1
tensor([[0, 0, 0, 0, 0, 0, 2],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 