In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import ConcatDataset
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler

from sklearn.utils import class_weight
from sklearn.model_selection import KFold

import wandb

from math import log10, floor
import random
from datetime import datetime

from process_input.input_to_1D_feature_matrix import parse_matrices, FeatureDataset
from model_skeletons.fc_basemodel.base_fc_v3 import base_fc, model_version
from compute_accuracy import compute_accuracy
from class_weights_cv import get_class_weights

model_script=os.path.join("model_skeletons", "fc_basemodel", "{}.py".format(model_version)) #Used to read the file for logs
assert os.path.isfile(model_script), "Could not find file '{}'".format(model_script)
new_files_creation="create_new_data_files.py" #Used to read the file for logs
input_to_matrix=os.path.join("process_input", "input_to_1D_feature_matrix.py") #Used to read the file for logs
assert os.path.isfile(input_to_matrix), "Could not find file '{}'".format(input_to_matrix)
compute_accuracy_file = "compute_accuracy.py"
assert os.path.isfile(compute_accuracy_file), "Could not find file '{}'".format(compute_accuracy_file)
valid_batch_size=5
log_file=os.path.join("logs", "basemodel_fc", "{}.log".format(model_version))
cv_log_file=os.path.join("logs", "basemodel_fc", "cv_{}.log".format(model_version))
last_run_log_file=os.path.join("logs", "basemodel_fc", "current_run_log_{}.log".format(model_version))
last_run_log_file_final_runs=os.path.join("logs", "basemodel_fc", "{}_all_runs.log".format(model_version))
model_path=os.path.join("saved_models", "basemodel_fc", "{}.pth".format(model_version))
cv=False
train=False

In [2]:
#device = torch.device('cpu')
device = torch.device('cuda')
many_classes=True #ACGT order in onehot

In [3]:
def write_to_log_file(learning_rate, optimizer_type, weight_decay, momentum, is_scheduler, T_max, smoothing, 
                      train_batch_size, valid_batch_size, log_file, model_version, 
                      input_to_matrix, new_greatest_valid_acc, model_path, many_classes,
                      compute_accuracy_file, model_script, run_name, dt_string, acc_list=None):
    """Called, when new highest validation accuracy is found.
    
    Writes everything important information to a log file."""
    with open(log_file, 'w+') as fw:
        fw.write("Created: {}\nModel version: {}\nPath: {}\nRun name: {}\nAccuracy: {}\n\n".format(dt_string, model_version, 
                                                                                                   model_path, run_name, 
                                                                                                   new_greatest_valid_acc))
        if acc_list:
            for k, accuracy in enumerate(acc_list):
                fw.write("Fold {}: {}.\n".format(k+1, accuracy))
            fw.write('\n')
        fw.write("Hyperparameters:\nOptimizer: {}\n".format(optimizer_type))
        if optimizer_type=="SGD":
            fw.write("Momentum: {}\n".format(momentum))
        fw.write("Learning rate: {}\nWeight decay: {}\n".format(learning_rate, weight_decay))
        if is_scheduler:
            fw.write("Used CosineAnnealingLR scheduler with T_max {}\n".format(T_max))
        if many_classes:
            fw.write("Used CrossEntropyLoss with label smoothing {}\n".format(smoothing))
            fw.write("Balanced classes\n")
        else:
            fw.write("Used BCELoss without label smoothing\n")
        fw.write("Train batch size: {}\nValidation batch size: {}\n\n".format(train_batch_size, valid_batch_size))
        
        fw.write("\n--------------------------Script of the model can be seen below.---------------------------\n")
        with open(model_script, 'r') as fr:
            fw.write(fr.read())
        fw.write("\n-------------------------------------------------------------------------------------------")
        fw.write("\n\n\n")
    
        fw.write("\n------------------------------Created input matrices with script:------------------------------\n")
        with open(input_to_matrix, 'r') as fr:
            fw.write(fr.read())
        fw.write("\n-------------------------------------------------------------------------------------------")
        fw.write("\n\n\n")
        
        fw.write("\n------------------------------Computed accuracies with script:-----------------------------\n")
        with open(compute_accuracy_file, 'r') as fr:
            fw.write(fr.read())
        fw.write("\n-------------------------------------------------------------------------------------------\n")

In [4]:
class_types=['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', 'GA', 'GC', 'GG',
            'GT', 'TA', 'TC', 'TG', 'TT']
classes={'CA': 0, 'CC': 1, 'CG': 2, 'CT': 3, 'TA': 4, 'TC': 5, 'TG': 6, 'TT': 7}

In [8]:
train_file="data/20220214/sompred_crc9_clu1_pyri_mut_combined_train.matrix"
valid_file="data/20220214/sompred_crc9_clu1_pyri_mut_combined_valid.matrix"
test_file="data/20220214/sompred_crc9_clu1_pyri_mut_combined_test.matrix"

In [9]:
train_input, train_target = parse_matrices(train_file, classes)
valid_input, valid_target = parse_matrices(valid_file, classes)
test_input, test_target = parse_matrices(test_file, classes)

In [10]:
train_dataset=FeatureDataset(data=train_input, labels=train_target)
valid_dataset=FeatureDataset(data=valid_input, labels=valid_target)
test_dataset=FeatureDataset(data=test_input, labels=test_target)
combined_valid_train = ConcatDataset([train_dataset, valid_dataset]) #Combines validation and training datasets for cross validation

In [11]:
print(train_target.shape)
print(torch.sum(train_target!=1),torch.sum(train_target!=7))
print("Train neg, pos:",train_target.shape[0]-torch.sum(train_target==1)-torch.sum(train_target==7),
      torch.sum(train_target==1)+torch.sum(train_target==7))
print("Validation neg, pos:",valid_target.shape[0]-torch.sum(valid_target==1)-torch.sum(valid_target==7),
      torch.sum(valid_target==1)+torch.sum(valid_target==7))
print("Test neg, pos:",test_target.shape[0]-torch.sum(test_target==1)-torch.sum(test_target==7),
      torch.sum(test_target==1)+torch.sum(test_target==7))

torch.Size([9244])
tensor(6152) tensor(4943)
Train neg, pos: tensor(1851) tensor(7393)
Validation neg, pos: tensor(323) tensor(1303)
Test neg, pos: tensor(98) tensor(98)


In [12]:
def cv_train_network(net,criterion, valid_criterion,epochs,optimizer,trainloader, validloader,
                  is_scheduler,scheduler, wandb, early_stop=100):
    j=0
    greatest_acc=0
    min_tot_loss=float('inf')
    tot_loss=0
    tot_items=0
    for i in range(epochs):
        net.train()
        for sequences, labels in trainloader:
            sequences, labels = sequences.to(device), labels.to(device)
            optimizer.zero_grad()
            out=net(sequences)
            out=out.squeeze()
            loss=criterion(out,labels)
            tot_loss+=loss.item()
            tot_items+=len(labels)
            loss.backward()
            if torch.isnan(loss):
                raise RuntimeError("NAN!")
            optimizer.step()
        if is_scheduler:
            scheduler.step()
        tot_loss/=tot_items
        accuracy, tot_valid_loss = compute_accuracy(device, net, validloader, valid_criterion, 
                                                "VALID", verbose = False, cv=True)
        if wandb!=None:
            wandb.log({"Training loss": tot_loss,
                       "Validation loss": tot_valid_loss,
                       "Valid Accuracy": accuracy,
            #           "Test loss": test_loss,
            #           "Test Accuracy": test_accuracy,
            #           "Pooled test recall": fake_recall_test,
            #           "Pooled test precision": fake_precision_test,
            #           "Learning rate": optimizer.param_groups[0]['lr'],
            #           "Scheduler": is_scheduler,
                       "Epoch": i})
        if round(accuracy,3)<=round(greatest_acc,3):
            pass
        else:
            greatest_acc=accuracy
        if round(tot_valid_loss,3)>=round(min_tot_loss,3):
            j+=1
            if j>=early_stop and i>100:
                break
        else:
            j=0
            min_tot_loss=tot_valid_loss
        
    return greatest_acc

In [13]:
def train_network(net,criterion, valid_criterion,epochs,optimizer,trainloader, validloader,
                  is_scheduler,scheduler, greatest_acc_overall, model_path, wandb, early_stop=50):
    j=0
    greatest_acc=0
    tot_loss=0
    tot_items=0
    for i in range(epochs):
        net.train()
        for sequences, labels in trainloader:
            sequences, labels = sequences.to(device), labels.to(device)
            optimizer.zero_grad()
            out=net(sequences)
            out=out.squeeze()
            loss=criterion(out,labels)
            tot_loss+=loss.item()
            tot_items+=len(labels)
            loss.backward()
            if torch.isnan(loss):
                raise RuntimeError("NAN!")
            optimizer.step()
        if is_scheduler:
            scheduler.step()
        tot_loss/=tot_items
        accuracy, tot_valid_loss = compute_accuracy(device, net, validloader, valid_criterion, "VALID", verbose = False, cv = True)
        
        if wandb!=None:
            wandb.log({"Training loss": tot_loss,
                       "Validation loss": tot_valid_loss,
                       "Valid Accuracy": accuracy,
            #           "Test loss": test_loss,
            #           "Test Accuracy": test_accuracy,
            #           "Pooled test recall": fake_recall_test,
            #           "Pooled test precision": fake_precision_test,
            #           "Learning rate": optimizer.param_groups[0]['lr'],
            #           "Scheduler": is_scheduler,
                       "Epoch": i})
            
        if round(accuracy,3)<=round(greatest_acc,3):
            if early_stop:
                j+=1
                if j>=early_stop and i>100:
                    print("Greates validation acc: {}".format(greatest_acc))
                    break
        else:
            if accuracy>greatest_acc_overall:
                torch.save(net.state_dict(), model_path)
                greatest_acc_overall=accuracy
            j=0
            greatest_acc=accuracy
    print("Greatest accuracy on run: {}".format(greatest_acc))
    return greatest_acc

In [14]:
def get_earlier_accuracy(log_file):
    with open(log_file, 'r') as fr:
        for line in fr:
            if "Accuracy:" in line:
                return float(line.strip().split(' ')[1]) #Accuracy is written as Accuracy: <acc>

In [15]:
greatest_avg_valid_acc = 0
if os.path.isfile(cv_log_file):
    greatest_avg_valid_acc = get_earlier_accuracy(cv_log_file)
print(greatest_avg_valid_acc)

0.2627335095525115


In [16]:
if not os.path.isfile(last_run_log_file):
    with open(last_run_log_file, 'w+') as fw:
        fw.write("Run log.\n\n")

In [17]:
#k_fold cross_validation for hyperparameters
if cv:
    k_folds = 5
    epochs=5000
    run_name="None"
    testloader=None

    kfold = KFold(n_splits=k_folds, shuffle=True) #batch size affects the size of datasets
    for i in range(150): #Test with 30 different hyperparameter combinations
        valid_accuracies = list()
        mom_text=""
        scheduler_text=""
        momentum=None
        T_max=None
        is_scheduler=False
        scheduler=None

        learning_rate=random.sample([0.01, 0.001, 0.0001, 0.00001], 1)[0]
        lr_text=str(learning_rate).replace(".","d")
        train_batch_size=random.sample([32, 64, 128], 1)[0]

        optimizer_type=random.sample(["Adam","SGD"], 1)[0]
        if optimizer_type=="SGD":
            momentum= random.sample([0, np.random.uniform()], 1)[0]
            if momentum!=0: momentum = round(momentum, -int(floor(log10(momentum))) + 2)
            mom_text="_mom"+str(round(momentum,2)).replace(".","d")
            is_scheduler=random.sample([True, False], 1)[0]
            if is_scheduler:
                T_max=random.sample([1, np.random.uniform(low=0.2)], 1)[0]
                T_max = round(T_max, -int(floor(log10(T_max))) + 2)
        weight_decay=random.sample([0, 0.00001, 0.0001, 0.001], 1)[0]
        if weight_decay!=0: weight_decay = round(weight_decay, -int(floor(log10(weight_decay))) + 2)
        decay_text="_wdecay"+str(weight_decay).replace(".","d")
        smoothing=random.sample([0, np.random.uniform(high=0.05)], 1)[0]
        if smoothing!=0: smoothing = round(smoothing, -int(floor(log10(smoothing))) + 2)

        for (train_ids, test_ids) in kfold.split(combined_valid_train):
            # Sample elements randomly from a given list of ids, no replacement.
            train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
            test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
            trainloader = torch.utils.data.DataLoader(
                          combined_valid_train, 
                          batch_size=train_batch_size, sampler=train_subsampler)
            validloader = torch.utils.data.DataLoader(
                              combined_valid_train,
                              batch_size=valid_batch_size, sampler=test_subsampler)
            net = base_fc().to(device)
            if optimizer_type=="Adam":
                optimizer=torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
            elif optimizer_type=="SGD":
                optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, 
                                            momentum=momentum, weight_decay=weight_decay)
                if is_scheduler:
                    scheduler=lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max*epochs)
            else:
                raise RuntimeError("WRONG OPTIMIZER: {}".format(optimizer_type))


            train_class_weights, valid_class_weights = get_class_weights(device, trainloader, validloader)
            criterion=nn.CrossEntropyLoss(weight=train_class_weights, label_smoothing=smoothing)
            valid_criterion=nn.CrossEntropyLoss(weight=valid_class_weights)
            
            os.system("rm -rf masters_thesis/base_model_fc/wandb/run-*")
            run = wandb.init(project='MLP')
            run_name = "batch"+str(train_batch_size)+"_lr-"+lr_text \
            +"smooth"+str(round(smoothing,2)).replace('.','d')+"optim-"+optimizer_type+decay_text \
            +"_sch-"+str(is_scheduler)+scheduler_text+mom_text
            wandb.run.name = run_name
            config = wandb.config
            config.batch_size=train_batch_size
            #config.optimizer_type=optimizer_type
            #config.learning_rate = learning_rate
            #config.weight_decay=weight_decay
            #config.is_scheduler=is_scheduler

            valid_acc = cv_train_network(net,criterion,valid_criterion,epochs,optimizer,trainloader, validloader, 
                                      is_scheduler,scheduler, wandb)
            valid_accuracies.append(valid_acc)
        avg_valid_acc = sum(valid_accuracies) / len(valid_accuracies)

        print("Optimizer: {}\nLearning Rate: {}\nScheduler: {}".format(optimizer_type, learning_rate, is_scheduler))
        print("Weight decay: {}\nSmoothing: {}".format(weight_decay, smoothing))
        if optimizer_type=="SGD":
            print("Momentum: {}".format(momentum))
            if is_scheduler:
                print("T_max: {}\n".format(T_max))
        print("Average validation accuracy: {}".format(avg_valid_acc))
        for k, accuracy in enumerate(valid_accuracies):
            print("Fold {}: {}.".format(k+1, accuracy))

        with open(last_run_log_file, 'a+') as fw:
            fw.write("Time: {}\n".format(datetime.now().strftime("%d.%m.%Y %H:%M:%S")))
            fw.write("Average validation accuracy: {}\n".format(avg_valid_acc))
            fw.write("\n".join(["Fold {}: {}.".format(k+1, accuracy) for k, accuracy in enumerate(valid_accuracies)]))
            fw.write("\nOptimizer: {}\nLearning Rate: {}\nScheduler: {}\n".format(optimizer_type, learning_rate, is_scheduler))
            fw.write("Weight decay: {}\nSmoothing: {}\nBatch size: {}\n".format(weight_decay, smoothing, train_batch_size))
            if optimizer_type=="SGD":
                fw.write("Momentum: {}\n".format(momentum))
                if is_scheduler:
                    fw.write("T_max: {}\n".format(T_max))
            fw.write('\n\n')

        if avg_valid_acc>greatest_avg_valid_acc:
            now = datetime.now()
            dt_string = now.strftime("%d.%m.%Y %H:%M:%S")
            write_to_log_file(learning_rate, optimizer_type, weight_decay, momentum, is_scheduler, T_max,
                             smoothing, train_batch_size, valid_batch_size, cv_log_file, model_version, 
                              input_to_matrix, 
                              avg_valid_acc, model_path, many_classes, compute_accuracy_file,
                              model_script, run_name, dt_string, acc_list=valid_accuracies)
            greatest_avg_valid_acc=avg_valid_acc

In [18]:
greatest_acc_overall=0
if os.path.isfile(log_file):
    greatest_acc_overall=get_earlier_accuracy(log_file)
print(greatest_acc_overall)
train_batch_size=64

0.2596905056830896


In [19]:
trainloader = DataLoader(train_dataset
    ,batch_size=train_batch_size
    ,shuffle=True
    ,drop_last=True
)
validloader = DataLoader(valid_dataset
    ,batch_size=5
    ,shuffle=False
)

In [20]:
if many_classes:
    class_weights=class_weight.compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_target),
        y=np.array(train_target))
    class_weights=torch.tensor(class_weights,dtype=torch.float)
    class_weights=class_weights.to(device)
    
    valid_class_weights=class_weight.compute_class_weight(
        class_weight='balanced',
        classes=np.unique(train_target),
        y=np.array(train_target))
    valid_class_weights=torch.tensor(valid_class_weights,dtype=torch.float)
    valid_class_weights=valid_class_weights.to(device)
    
else:
    print(np.array(train_target))
print(class_weights, valid_class_weights)

tensor([2.4275, 0.3737, 9.0984, 1.8027, 6.3840, 3.8135, 9.3943, 0.2687],
       device='cuda:0') tensor([2.4275, 0.3737, 9.0984, 1.8027, 6.3840, 3.8135, 9.3943, 0.2687],
       device='cuda:0')


In [21]:
if not os.path.isfile(last_run_log_file_final_runs):
    with open(last_run_log_file_final_runs, 'w+') as fw:
        fw.write("Run log.\n\n")

In [22]:
if train:
    learning_rate=0.001
    momentum=None
    T_max=None
    is_scheduler=False
    optimizer_type="Adam"
    weight_decay=0.0001
    smoothing=0.0371
    scheduler=None
    epochs=5000
    criterion=nn.CrossEntropyLoss(weight=class_weights, label_smoothing=smoothing)
    valid_criterion=nn.CrossEntropyLoss(weight=valid_class_weights)
    lr_text=str(learning_rate).replace(".","d")
    decay_text="_wdecay"+str(weight_decay).replace(".","d")
    mom_text=""
    scheduler_text=""
    for i in range(1000):
        net = base_fc().to(device)
        if optimizer_type=="Adam":
            optimizer=torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
        elif optimizer_type=="SGD":
            optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, 
                                        momentum=momentum, weight_decay=weight_decay)
            mom_text="_mom"+str(round(momentum,2)).replace(".","d")
        else:
            raise RuntimeError("WRONG OPTIMIZER.")
        if is_scheduler:
            scheduler=lr_scheduler.CosineAnnealingLR(optimizer, T_max=T_max*epochs)
            scheduler_text="-cos"+str(T_max).replace(".","d")
        run = wandb.init(project='mlp_basemodel_final')
        run_name="run_{}".format(i)
        wandb.run.name = run_name
        config = wandb.config
        config.batch_size=train_batch_size
        config.optimizer_type=optimizer_type
        config.learning_rate = learning_rate
        config.weight_decay=weight_decay
        config.is_scheduler=is_scheduler
        if is_scheduler:
            config.T_max=T_max
            config.scheduler="Cosine"
        if optimizer_type=="SGD":
            config.momentum = momentum

        greatest_acc = train_network(net,criterion, valid_criterion,epochs,optimizer,trainloader, validloader,
                          is_scheduler,scheduler, greatest_acc_overall, model_path, wandb, early_stop=None)

        with open(last_run_log_file_final_runs, 'a+') as fw:
                fw.write("Time: {}\n".format(datetime.now().strftime("%d.%m.%Y %H:%M:%S")))
                fw.write("Accuracy: {}\n".format(greatest_acc))
                fw.write("\nOptimizer: {}\nLearning Rate: {}\nScheduler: {}\n".format(optimizer_type, learning_rate, is_scheduler))
                fw.write("Weight decay: {}\nSmoothing: {}\nBatch size: {}\n".format(weight_decay, smoothing, train_batch_size))
                if optimizer_type=="SGD":
                    fw.write("Momentum: {}\n".format(momentum))
                    if is_scheduler:
                        fw.write("T_max: {}\n".format(T_max))
                fw.write('\n\n')


        if greatest_acc>greatest_acc_overall:
            now = datetime.now()
            dt_string = now.strftime("%d.%m.%Y %H:%M:%S")
            write_to_log_file(learning_rate, optimizer_type, weight_decay, momentum, is_scheduler, T_max,
                             smoothing, train_batch_size, valid_batch_size, log_file, model_version, 
                              input_to_matrix, 
                              greatest_acc, model_path, many_classes, compute_accuracy_file,
                              model_script, run_name, dt_string)
            greatest_acc_overall=greatest_acc
                #raise Exception("LOL")
                        #run.finish()
    #learning_rate=0.001
    #momentum=0.9
    #gamma=2
    #is_scheduler=True
    #optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)#, weight_decay=0.1-0.0001
    #scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=gamma)
    #scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=250)

In [23]:
testloader = DataLoader(test_dataset
    ,batch_size=5
    ,shuffle=False
)

In [24]:
class model_skeleton(nn.Module):
    def __init__(self):
        super(base_fc, self).__init__()
        # YOUR CODE HERE
        self.dropout=nn.Dropout(0.3) #changed from 0.2->0.3 1.6.2022
        self.fc1=nn.Linear(84,120)#150) 
        self.bn1 = nn.BatchNorm1d(120)
        self.fc2=nn.Linear(120,8)
        self.sigmoid=nn.Sigmoid()
        

    def forward(self, x):
        """
        Args:
          x of shape (batch_size, 84): Input sequences.
        
        Returns:
          y of shape (batch_size, 8): Outputs of the network.
        """
        
        
        y=self.fc1(x)
        y=self.dropout(y)
        y=F.relu(y)
        y=self.bn1(y)
        
        y=self.fc2(y)
        return y

In [25]:

#model = model_skeleton(many_classes)
model=base_fc()
#model.load_state_dict(torch.load(model_path))
model.load_state_dict(torch.load(os.path.join("saved_models", "basemodel_fc", "base_fc_v3.pth")))
model.eval()
model.to(device)

base_fc(
  (dropout): Dropout(p=0.3, inplace=False)
  (fc1): Linear(in_features=84, out_features=120, bias=True)
  (bn1): BatchNorm1d(120, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc2): Linear(in_features=120, out_features=8, bias=True)
  (sigmoid): Sigmoid()
)

In [26]:
criterion=nn.CrossEntropyLoss(weight=class_weights)
accuracy, loss = \
compute_accuracy(device, model, testloader, None, "TEST", verbose=True, cv=False)

print("Accuracy for the test data:",accuracy)
#print("Loss for test data:",loss)
# print("F1-score:",f1)
# print("F2-score:",f2)
# print("Precision:",precision)
# print("Recall:",recall)


 TEST
TP: 22 . FN: 76 TP/(TP+FN): 0.22448979591836735 TN: 17 FP: 81 TN/(TN+FP): 0.17346938775510204 Wrong positive class predicted: 64 Wrong negative class predicted: 15
Fake F1-score: 0.6880000000000001 . Fake F2-score: 0.7904411764705882
Fake TP/(TP+FN): 0.5308641975308642 Fake TN/(TN+FP) 0.2831858407079646
Fake precision: 0.5657894736842105 Fake recall: 0.8775510204081632
F1-score: 0.21890547263681592
F2-score: 0.22222222222222224
Precision: 0.21359223300970873
Recall: 0.22448979591836735
Fake accuracy: 0.6020408163265306
Accuracy for the test data: 0.22930155249234196


In [None]:
def countProbabilityDistributions(net, dataloader):
    distributions=dict()
    m = nn.Softmax(dim=1)
    predictions=None
    with torch.no_grad():
        for sequences, labels in dataloader:
            sequences, labels = sequences.to(device), labels.to(device)
            outputs = net(sequences)
            _, predicted = torch.max(outputs.data, 1)
            if predictions==None:
                predictions=predicted
                trueLabels=labels
            else:
                predictions=torch.cat((predictions, predicted), 0)
                trueLabels=torch.cat((trueLabels, labels), 0)
            for i in range(outputs.data.shape[1]):
                if not i in distributions:
                    distributions[i]=m(outputs).data[:,i]
                else:
                    distributions[i]=torch.cat((distributions[i], m(outputs).data[:,i]), 0)
                    #print(m(outputs).data[0,:])
                    #print(predicted)
                    #raise RuntimeError("LOL")
    return distributions, predictions, trueLabels
#distributions, predictions, trueLabels= countProbabilityDistributions(net, trainloader)
distributions, predictions, trueLabels= countProbabilityDistributions(net, trainloader)

In [None]:
print(classes)
plt.hist(predictions.to("cpu").numpy(),bins=[0,1,2,3,4,5,6,7,8,9]);

In [None]:
trueLabels[3]

In [None]:
classes

In [None]:
classNames=['CA', 'CC', 'CG', 'CT', 'TA', 'TC', 'TG', 'TT']
for i in range(8):
    plt.figure(figsize=(20,20))
    condition=(predictions==i) & (trueLabels==i)
    plt.hist(distributions[i][condition].to("cpu").numpy(),bins=100, color='r', alpha=0.5);#, density=True);
    condition=(predictions==i) & (trueLabels!=i)
    plt.hist(distributions[i][condition].to("cpu").numpy(),bins=100, color='b', alpha=0.5);#, density=True);
    plt.title(classNames[i]);
    plt.xlim([0,1])
    plt.savefig('probDistr_{}.pdf'.format(classNames[i]))

In [None]:
print(classes)
plt.figure(figsize=(20,20))
condition=(predictions==1)
plt.hist(distributions[1][condition].to("cpu").numpy(),bins=100, color='r', alpha=0.5, density=True);
other=torch.cat((distributions[0][predictions==0],distributions[2][predictions==2],distributions[3][predictions==3]),dim=0)
plt.hist(other.to("cpu").numpy(),bins=100, color='b', alpha=0.5, density=True);

In [None]:
fig,ax = plt.subplots(2,2,constrained_layout = True) # Instantiate figure and axes object
ax[0][0].hist(distributions[0][predictions==0].to("cpu").numpy(),bins=100, color='r', density=True);
ax[0][0].title.set_text('CA 475')
ax[0][1].hist(distributions[1][predictions==1].to("cpu").numpy(),bins=100, color='r', density=True);
ax[0][1].title.set_text('CC 3048')
ax[1][0].hist(distributions[2][predictions==2].to("cpu").numpy(),bins=100, color='r', density=True);
ax[1][0].title.set_text('CG 127')
ax[1][1].hist(distributions[3][predictions==3].to("cpu").numpy(),bins=100, color='r', density=True);
ax[1][1].title.set_text('CT 639')

In [None]:
fig,ax = plt.subplots(2,2,constrained_layout = True) # Instantiate figure and axes object
ax[0][0].hist(distributions[4][predictions==4].to("cpu").numpy(),bins=100, color='r', density=True);
ax[0][0].title.set_text('TA 181')
ax[0][1].hist(distributions[5][predictions==5].to("cpu").numpy(),bins=100, color='r', density=True);
ax[0][1].title.set_text('TC 303')
ax[1][0].hist(distributions[6][predictions==6].to("cpu").numpy(),bins=100, color='r', density=True);
ax[1][0].title.set_text('TG 122')
ax[1][1].hist(distributions[7][predictions==7].to("cpu").numpy(),bins=100, color='r', density=True);
ax[1][1].title.set_text('TT 4285')

In [None]:
fig,ax = plt.subplots(2,4) # Instantiate figure and axes object
for index in range(8):
    ax[int(index>3)][index-(index>3)*4].hist(distributions[index].to("cpu").numpy(), density=True, bins=20)#,
            #bins=[0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9], density=True, histtype="step")
    ax[int(index>3)][index-(index>3)*4].title.set_text(str(index))
plt.legend()
plt.show()

In [None]:
fig,axes = plt.subplots(2,1) # Instantiate figure and axes object
axes[0].hist(distributions[0].to("cpu").numpy(), label=str(0),
        bins=[0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9], density=True, histtype="step");
axes[1].hist(distributions[1].to("cpu").numpy(), label=str(1),
        bins=[0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9], density=True, histtype="step");
plt.legend()
plt.show()

In [None]:
print(distributions[0][0:11].data)

In [None]:
accuracy, loss, f1, f2, precision, recall, f1_fake, f2_fake, fake_precision, fake_recall = \
compute_accuracy(net, testloader, criterion, "TEST", True)

print("Accuracy for the test data:",accuracy)
print("Loss for test data:",loss)
print("F1-score:",f1)
print("F2-score:",f2)
print("Precision:",precision)
print("Recall:",recall)
config.test_acc=accuracy
config.test_loss=loss
config.test_f1=f1
config.test_f2=f2
config.test_precis=precision
config.test_recall=recall

In [None]:
for sequences, labels in testloader:
    sequences, labels = sequences.to(device), labels.to(device)
    outputs = net(sequences)
    _, predicted = torch.max(outputs.data, 1)
    for sequence, label, prediction, output in zip(sequences, labels, predicted, outputs):
        if label in [0,5,10,15]:
            if label==prediction:
                continue
            else:
                if prediction in [0,5,10,15]:
                    print(class_types[prediction], class_types[label],sequence, output)

In [None]:
for sequences, labels in testloader:
    sequences, labels = sequences.to(device), labels.to(device)
    outputs = net(sequences)
    _, predicted = torch.max(outputs.data, 1)
    for sequence, label, prediction, output in zip(sequences, labels, predicted, outputs):
        if prediction==label:
            print(sequence,'\n\n',output,'\n\n', label,'\n\n', prediction, class_types[label.item()], 
                  class_types[prediction.item()])
            #print("TRUE:",class_types[prediction.item()])
            raise UserWarning('Exit Early')
        else:
            print("FALSE",sequence,'\n\n',output,'\n\n', label,'\n\n', prediction, class_types[label.item()],
                 class_types[prediction.item()])
            raise UserWarning('Exit Early')

In [None]:
for sequences, labels in trainloader:
    for label in labels:
        if class_types[label.item()]=="TT":
            print("TT")

In [None]:
test_tensor=torch.FloatTensor(np.array([[0,0,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0],
                                [0,1,0,1,0,0,0,1,1,1,1,0,0,1,0,0,0,1,1,0,1,1,1,0,1,1,1,0,1,0,1,1,0],
                                [0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0],
                                [1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]]))
test_target=torch.FloatTensor(8) #G->A=8
test_tensor=test_tensor.reshape(1,1,4,33)
out_test=net(test_tensor)
print(out_test)

In [None]:
predicted = torch.argmax(out_test, 1)
total =1
correct = (predicted == torch.argmax(test_target)).sum().item()
print(correct)

In [None]:
test=np.zeros(6)
for luku in test:
    print(luku)

In [None]:
print(torch.FloatTensor(np.array([[0,0,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0],
                                [0,1,0,1,0,0,0,1,1,1,1,0,0,1,0,0,0,1,1,0,1,1,1,0,1,1,1,0,1,0,1,1,0],
                                [0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0],
                                [1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]])))

In [None]:
net.eval()
with torch.no_grad():
    images, labels = iter(testloader).next()
    tests.plot_images(images[:5], n_rows=1)
    
    # Compute predictions
    images = images.to(device)
    y = net(images)

print('Ground truth labels: ', ' '.join('%10s' % classes[labels[j]] for j in range(5)))
print('Predictions:         ', ' '.join('%10s' % classes[j] for j in y.argmax(dim=1)))

In [None]:
# Compute the accuracy on the test set
accuracy = compute_accuracy(net, testloader)
print('Accuracy of the network on the test images: %.3f' % accuracy)
assert accuracy > 0.85, "Poor accuracy {:.3f}".format(accuracy)
print('Success')