In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import ConcatDataset
from torch.optim import lr_scheduler

from sklearn.utils import class_weight
from sklearn.model_selection import KFold

from tqdm import tqdm

import wandb

import collections
from math import log10, floor
import random
from datetime import datetime

from process_input.input_to_2Dmatrix_transformer import parse_matrices, FeatureDataset, create_class_weights
from model_skeletons.transformer.transformer_v12 import Transformer, model_version, WeightedMSELoss
from compute_accuracy_transformer import compute_accuracy

norm=True
norm_type=None # "sum", "max"
target_mode="whole_matrix" #["whole_matrix", "whole_DNA_seq", "only_target_base", "target_with_landscape"]
if target_mode in ["whole_matrix", "whole_DNA_seq"]:
    correct_label_index=16
else:
    correct_label_index=0
model_script=os.path.join("model_skeletons", "transformer", "{}.py".format(model_version)) #Used to read the file for logs
assert os.path.isfile(model_script), "Could not find file '{}'".format(model_script)
new_files_creation="create_new_data_files.py" #Used to read the file for logs
input_to_matrix=os.path.join("process_input","input_to_2Dmatrix_transformer.py") #Used to read the file for logs
assert os.path.isfile(input_to_matrix), "Could not find file '{}'".format(input_to_matrix)
compute_accuracy_file = "compute_accuracy_transformer.py"
assert os.path.isfile(compute_accuracy_file), "Could not find file '{}'".format(compute_accuracy_file)
valid_batch_size=5
log_file=os.path.join("logs", "transformer", "{}.log".format(model_version))
cv_log_file=os.path.join("logs", "transformer", "cv_{}.log".format(model_version))
last_run_log_file=os.path.join("logs", "transformer", "{}_cv_all_runs.log".format(model_version))
last_run_log_file_final_runs=os.path.join("logs", "transformer", "{}_all_runs.log".format(model_version))
model_path=os.path.join("saved_models", "transformer", "{}.pth".format(model_version))
cv=False
train=False
indices_of_interest=[80,81,82,83]


In [2]:
#device = torch.device('cpu')
device = torch.device('cuda')
many_classes=True #ACGT order in onehot

In [3]:
def write_to_log_file(learning_rate, optimizer_type, weight_decay, 
                      train_batch_size, valid_batch_size, log_file, model_version, 
                      input_to_matrix, new_greatest_valid_acc, model_path, many_classes,
                      compute_accuracy_file, model_script, run_name, dt_string, target_mode,
                      nhead, num_encoder_layers, num_decoder_layers, norm, norm_type, acc_list=None):
    """Called, when new highest validation accuracy is found.
    
    Writes everything important information to a log file."""
    with open(log_file, 'w+') as fw:
        fw.write("Created: {}\nModel version: {}\nPath: {}\nRun name: {}\nAccuracy: {}\n\n".format(dt_string, model_version, 
                                                                                                   model_path, run_name, 
                                                                                                   new_greatest_valid_acc))
        if acc_list:
            for k, accuracy in enumerate(acc_list):
                fw.write("Fold {}: {}.\n".format(k+1, accuracy))
            fw.write('\n')
        fw.write("Hyperparameters:\nOptimizer: {}\n".format(optimizer_type))
        fw.write("Learning rate: {}\nWeight decay: {}\n".format(learning_rate, weight_decay))
        fw.write("head: {}\nnum_encoder_layers: {}\nnum_decoder_layers: {}\n".format(nhead, num_encoder_layers, num_decoder_layers))
        fw.write("Used MSELoss\n")
        fw.write("Balanced classes\n")
        fw.write("Data normalized: {}\n".format(norm))
        fw.write("Norm type: {}\n".format(norm_type))
        fw.write("Target mode: {}\n".format(target_mode))
        fw.write("Train batch size: {}\nValidation batch size: {}\n\n".format(train_batch_size, valid_batch_size))
        
        fw.write("\n--------------------------Script of the model can be seen below.---------------------------\n")
        with open(model_script, 'r') as fr:
            fw.write(fr.read())
        fw.write("\n-------------------------------------------------------------------------------------------")
        fw.write("\n\n\n")
    
        fw.write("\n------------------------------Created input matrices with script:------------------------------\n")
        with open(input_to_matrix, 'r') as fr:
            fw.write(fr.read())
        fw.write("\n-------------------------------------------------------------------------------------------")
        fw.write("\n\n\n")
        
        fw.write("\n------------------------------Computed accuracies with script:-----------------------------\n")
        with open(compute_accuracy_file, 'r') as fr:
            fw.write(fr.read())
        fw.write("\n-------------------------------------------------------------------------------------------\n")

In [4]:
"""In input file, 33 base, 17th is altering, 1st A, 2nd C, 3rd G, 4th T"""
class_types=['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', 'GA', 'GC', 'GG',
            'GT', 'TA', 'TC', 'TG', 'TT']
classes={'CA': 0, 'CC': 1, 'CG': 2, 'CT': 3, 'TA': 4, 'TC': 5, 'TG': 6, 'TT': 7}

In [5]:
train_file="data/20220214/sompred_crc9_clu1_pyri_mut_combined_train.matrix"
valid_file="data/20220214/sompred_crc9_clu1_pyri_mut_combined_valid.matrix"
test_file="data/20220214/sompred_crc9_clu1_pyri_mut_combined_test.matrix"

In [6]:
trainset_input, trainset_target, class_weights_train_whole = parse_matrices(train_file, norm, target_mode, norm_type)
valid_input, valid_target, class_weights_valid_whole = parse_matrices(valid_file, norm, target_mode, norm_type)
test_input, test_target, class_weights_test = parse_matrices(test_file, norm, target_mode, norm_type)

In [7]:
sos_token = torch.zeros((1, 1, trainset_input.shape[2]))
eos_token = torch.ones((1, 1, trainset_input.shape[2]))

trainset_input_with_tokens = torch.cat([trainset_input, eos_token.expand(trainset_input.size(0), 1, -1)], dim=1)
trainset_target_with_tokens = torch.cat([sos_token.expand(trainset_input.size(0), 1, -1), trainset_target, eos_token.expand(trainset_input.size(0), 1, -1)], dim=1)

valid_input_with_tokens = torch.cat([valid_input, eos_token.expand(valid_input.size(0), 1, -1)], dim=1)
valid_target_with_tokens = torch.cat([sos_token.expand(valid_input.size(0), 1, -1), valid_target, eos_token.expand(valid_input.size(0), 1, -1)], dim=1)

test_input_with_tokens = torch.cat([test_input, eos_token.expand(test_input.size(0), 1, -1)], dim=1)
test_target_with_tokens = torch.cat([sos_token.expand(test_input.size(0), 1, -1), test_target, eos_token.expand(test_input.size(0), 1, -1)], dim=1)

In [8]:
train_dataset=FeatureDataset(data=trainset_input_with_tokens, labels=trainset_target_with_tokens)
valid_dataset=FeatureDataset(data=valid_input_with_tokens, labels=valid_target_with_tokens)
test_dataset=FeatureDataset(data=test_input_with_tokens, labels=test_target_with_tokens)
combined_valid_train = ConcatDataset([train_dataset, valid_dataset]) #Combines validation and training datasets for cross validation

In [9]:
def cv_train_network(net,criterion, valid_criterion,epochs,optimizer,trainloader, validloader,
                  correct_label_index, wandb, early_stop=100):
    j=0
    greatest_acc=0
    min_tot_loss=float('inf')
    tot_loss=0
    tot_items=0
    for i in range(epochs):
        net.train()
        for sequences, labels in trainloader:
            sequences, labels = sequences.to(device), labels.to(device)
            
            labels_input = labels[:,:-1]
            labels_expected = labels[:,1:]
            
            sequence_length = labels_input.size(1)
            tgt_mask = net.get_tgt_mask(sequence_length, device)
            
            optimizer.zero_grad()
            out = net(sequences, labels_input, tgt_mask =tgt_mask)
            #out = out.squeeze()
            loss=criterion(sequences, labels_expected, out)
            tot_loss+=loss.item()
            tot_items+=len(labels)
            loss.backward()
            if torch.isnan(loss):
                raise RuntimeError("NAN!")
            optimizer.step()
        tot_loss/=tot_items
        accuracy, tot_valid_loss = compute_accuracy(device, net, validloader, valid_criterion, 
                                                "VALID", verbose = False, cv=True,
                                                   correct_label_index=correct_label_index) #17th/33+Start token
        if wandb!=None:
            wandb.log({"Training loss": tot_loss,
                       "Validation loss": tot_valid_loss,
                       "Valid Accuracy": accuracy,
            #           "Test loss": test_loss,
            #           "Test Accuracy": test_accuracy,
            #           "Pooled test recall": fake_recall_test,
            #           "Pooled test precision": fake_precision_test,
            #           "Learning rate": optimizer.param_groups[0]['lr'],
            #           "Scheduler": is_scheduler,
                       "Epoch": i})
        if round(accuracy,3)<=round(greatest_acc,3):
            pass
        else:
            greatest_acc=accuracy
        if round(tot_valid_loss,3)>=round(min_tot_loss,3):
            j+=1
            if j>=early_stop and i>100:
                break
        else:
            j=0
            min_tot_loss=tot_valid_loss
        
    return greatest_acc

In [10]:
def train_network(net,criterion, valid_criterion,epochs,optimizer,trainloader, validloader,
                  correct_label_index, greatest_acc_overall, model_path, wandb, early_stop=50):
    j=0
    greatest_acc=0
    tot_loss=0
    tot_items=0
    for i in range(epochs):
        net.train()
        for sequences, labels in trainloader:
            sequences, labels = sequences.to(device), labels.to(device)
            
            labels_input = labels[:,:-1]
            labels_expected = labels[:,1:]
            
            sequence_length = labels_input.size(1)
            tgt_mask = net.get_tgt_mask(sequence_length, device)
            
            optimizer.zero_grad()
            out = net(sequences, labels_input, tgt_mask =tgt_mask)
            #out=out.squeeze()
            loss=criterion(sequences, labels_expected, out)
            tot_loss+=loss.item()
            tot_items+=len(labels)
            loss.backward()
            if torch.isnan(loss):
                raise RuntimeError("NAN!")
            optimizer.step()
        tot_loss/=tot_items
        accuracy, tot_valid_loss = compute_accuracy(device, net, validloader, valid_criterion, "VALID", 
                                                    verbose = False, cv = True, correct_label_index=correct_label_index) #17th/33+1
        
        if wandb!=None:
            wandb.log({"Training loss": tot_loss,
                       "Validation loss": tot_valid_loss,
                       "Valid Accuracy": accuracy,
            #           "Test loss": test_loss,
            #           "Test Accuracy": test_accuracy,
            #           "Pooled test recall": fake_recall_test,
            #           "Pooled test precision": fake_precision_test,
            #           "Learning rate": optimizer.param_groups[0]['lr'],
            #           "Scheduler": is_scheduler,
                       "Epoch": i})
            
        if round(accuracy,4)<=round(greatest_acc,4):
            if early_stop:
                j+=1
                if j>=early_stop and i>150:
                    break
        else:
            if accuracy>greatest_acc_overall:
                torch.save(net.state_dict(), model_path)
                greatest_acc_overall=accuracy
            j=0
            greatest_acc=accuracy
    print("Greatest accuracy on the run: {}".format(greatest_acc))
    return greatest_acc

In [11]:
def get_earlier_accuracy(log_file):
    with open(log_file, 'r') as fr:
        for line in fr:
            if "Accuracy:" in line:
                return float(line.strip().split(' ')[1]) #Accuracy is written as Accuracy: <acc>

In [12]:
greatest_avg_valid_acc = 0
if os.path.isfile(cv_log_file):
    greatest_avg_valid_acc = get_earlier_accuracy(cv_log_file)
print(greatest_avg_valid_acc)

0.3483568324723381


In [13]:
if not os.path.isfile(last_run_log_file):
    with open(last_run_log_file, 'w+') as fw:
        fw.write("Run log.\n\n")

In [14]:
def get_class_weights(device, dataloader, indices_of_interest, norm_type):
    total_amount=0
    class_amounts=collections.Counter()
    bases = {0: 'A', 1: 'C', 2: 'G', 3: 'T'}
    for sequences, labels in dataloader:
        values_of_interest = sequences[:, 16, indices_of_interest]
        _, original_bases = torch.max(values_of_interest, dim=1)

        values_of_interest = labels[:, correct_label_index, indices_of_interest] #
        _, new_bases = torch.max(values_of_interest, dim=1)
        for new_base, original_base in zip(new_bases, original_bases):
            correct_class = classes[bases[original_base.item()]+bases[new_base.item()]]
            class_amounts[correct_class]+=1
            total_amount+=1
            
    return create_class_weights(class_amounts, total_amount, norm_type)

In [15]:
if cv:
    k_folds=5
    epochs=5000
    seq_len=33
    kfold = KFold(n_splits=k_folds, shuffle=True) #batch size affects the size of datasets
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    for i in range(150): #Test with 150 different hyperparameter combinations
        valid_accuracies = list()
        learning_rate=random.sample([0.0001, 0.00001, 0.000001], 1)[0]
        lr_text=str(learning_rate).replace(".","d")
        train_batch_size=random.sample([32, 64, 128, 256], 1)[0]
        norm_type = random.sample(["sum", "max", "None"], 1)[0]
        
        nhead=random.sample([2, 3, 4, 6, 7, 12, 21, 42], 1)[0]
        num_encoder_layers=random.sample([2, 3, 4, 6, 8], 1)[0]
        num_decoder_layers=num_encoder_layers
        
        optimizer_type=random.sample(["Adam","AdamW"], 1)[0] #random.sample(["Adam","SGD"], 1)[0]
        weight_decay=random.sample([0, 0.000001, 0.00000001], 1)[0]
        if weight_decay!=0: weight_decay = round(weight_decay, -int(floor(log10(weight_decay))) + 2)
        decay_text="_wdecay"+str(weight_decay).replace(".","d")
        for (train_ids, test_ids) in kfold.split(combined_valid_train):
            
            train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
            test_subsampler = torch.utils.data.SubsetRandomSampler(test_ids)
            trainloader = torch.utils.data.DataLoader(
                          combined_valid_train, 
                          batch_size=train_batch_size, sampler=train_subsampler)
            validloader = torch.utils.data.DataLoader(
                              combined_valid_train,
                              batch_size=valid_batch_size, sampler=test_subsampler)
            train_class_weights = get_class_weights(device, trainloader, indices_of_interest, norm_type)
            valid_class_weights = get_class_weights(device, validloader, indices_of_interest, norm_type)
            
            net = Transformer(nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers).to(device)
            if optimizer_type=="Adam":
                optimizer=torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
            elif optimizer_type=="AdamW":
                optimizer=torch.optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
            else:
                raise RuntimeError("WRONG OPTIMIZER: {}".format(optimizer_type))
            criterion = WeightedMSELoss(device, train_class_weights, classes, correct_label_index, indices_of_interest)
            valid_criterion=WeightedMSELoss(device, valid_class_weights, classes, correct_label_index, indices_of_interest)
            run_name = None
            valid_acc = cv_train_network(net,criterion,valid_criterion,epochs,optimizer,trainloader, validloader, 
                                      correct_label_index, wandb=None)
            valid_accuracies.append(valid_acc)
        avg_valid_acc = sum(valid_accuracies) / len(valid_accuracies)

        print("Optimizer: {}\nLearning Rate: {}".format(optimizer_type, learning_rate))
        print("Weight decay: {}".format(weight_decay))
        print("Average validation accuracy: {}".format(avg_valid_acc))
        for k, accuracy in enumerate(valid_accuracies):
            print("Fold {}: {}.".format(k+1, accuracy))

        with open(last_run_log_file, 'a+') as fw:
            fw.write("Time: {}\n".format(datetime.now().strftime("%d.%m.%Y %H:%M:%S")))
            fw.write("Average validation accuracy: {}\n".format(avg_valid_acc))
            fw.write("\n".join(["Fold {}: {}.".format(k+1, accuracy) for k, accuracy in enumerate(valid_accuracies)]))
            fw.write("\nOptimizer: {}\nLearning Rate: {}\n".format(optimizer_type, learning_rate))
            fw.write("Weight decay: {}\nBatch size: {}\n".format(weight_decay, train_batch_size))
            fw.write("head: {}\nnum_encoder_layers: {}\nnum_decoder_layers: {}\n".format(nhead, num_encoder_layers, num_decoder_layers))
            fw.write("Norm type: {}\n".format(norm_type))
            fw.write("Target mode: {}\n".format(target_mode))
            fw.write('\n\n')

        if avg_valid_acc>greatest_avg_valid_acc:
            now = datetime.now()
            dt_string = now.strftime("%d.%m.%Y %H:%M:%S")
            write_to_log_file(learning_rate, optimizer_type, weight_decay,
                             train_batch_size, valid_batch_size, cv_log_file, model_version, 
                              input_to_matrix, 
                              avg_valid_acc, model_path, many_classes, compute_accuracy_file,
                              model_script, run_name, dt_string, target_mode, 
                              nhead, num_encoder_layers, num_decoder_layers, acc_list=valid_accuracies, norm=norm,
                                 norm_type=norm_type)
            greatest_avg_valid_acc=avg_valid_acc
            os.path.join("saved_models", "transformer", "test_notebook.pth")
            torch.save(net.state_dict(), model_path)

In [16]:
greatest_acc_overall=0
if os.path.isfile(log_file):
    greatest_acc_overall=get_earlier_accuracy(log_file)
print(greatest_acc_overall)
train_batch_size=32

0.49964248632637526


In [17]:
trainloader = torch.utils.data.DataLoader(train_dataset
    ,batch_size=train_batch_size
    ,shuffle=True
    ,drop_last=True
)
validloader = torch.utils.data.DataLoader(valid_dataset
    ,batch_size=5
    ,shuffle=False
)

In [18]:
if train:
    learning_rate=0.0001
    optimizer_type="AdamW"
    weight_decay=0.000001
    nhead=42
    num_encoder_layers=6
    num_decoder_layers=6
    epochs=5000
    norm_type="None"
    criterion = WeightedMSELoss(device, class_weights_train_whole, classes, correct_label_index, indices_of_interest)
    valid_criterion=WeightedMSELoss(device, class_weights_valid_whole, classes, correct_label_index, indices_of_interest)
    lr_text=str(learning_rate).replace(".","d")
    decay_text="_wdecay"+str(weight_decay).replace(".","d")
    for i in range(1000):
        net = Transformer(nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers).to(device)
        if optimizer_type=="Adam":
            optimizer=torch.optim.Adam(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
        elif optimizer_type=="AdamW":
            optimizer=torch.optim.AdamW(net.parameters(), lr=learning_rate, weight_decay=weight_decay)
        else:
            raise RuntimeError("WRONG OPTIMIZER.")
        run = wandb.init(project='transformer_final')
        run_name="run_{}".format(i)
        wandb.run.name = run_name
        config = wandb.config
        config.batch_size=train_batch_size
        config.optimizer_type=optimizer_type
        config.learning_rate = learning_rate
        config.weight_decay=weight_decay
        config.is_scheduler=is_scheduler
        if is_scheduler:
            config.T_max=T_max
            config.scheduler="Cosine"
        if optimizer_type=="SGD":
            config.momentum = momentum

        greatest_acc = train_network(net,criterion, valid_criterion,epochs,optimizer,trainloader, validloader,
                          correct_label_index, greatest_acc_overall, model_path, wandb=None, early_stop=100)
                
        with open(last_run_log_file_final_runs, 'a+') as fw:
            fw.write("Time: {}\n".format(datetime.now().strftime("%d.%m.%Y %H:%M:%S")))
            fw.write("Validation accuracy: {}\n".format(greatest_acc))
            fw.write("\nOptimizer: {}\nLearning Rate: {}\n".format(optimizer_type, learning_rate))
            fw.write("Weight decay: {}\nBatch size: {}\n".format(weight_decay, train_batch_size))
            fw.write("head: {}\nnum_encoder_layers: {}\nnum_decoder_layers: {}\n".format(nhead, num_encoder_layers, num_decoder_layers))
            fw.write("Norm type: {}\n".format(norm_type))
            fw.write("Target mode: {}\n".format(target_mode))
            fw.write('\n\n')
       
        
        if greatest_acc>greatest_acc_overall:
            now = datetime.now()
            dt_string = now.strftime("%d.%m.%Y %H:%M:%S")
            write_to_log_file(learning_rate, optimizer_type, weight_decay,
                             train_batch_size, valid_batch_size, log_file, model_version, 
                              input_to_matrix, 
                              greatest_acc, model_path, many_classes, compute_accuracy_file,
                              model_script, run_name, dt_string, target_mode, 
                              nhead, num_encoder_layers, num_decoder_layers, norm=norm,
                                 norm_type=norm_type)
            greatest_acc_overall=greatest_acc

In [19]:
testloader = torch.utils.data.DataLoader(test_dataset
    ,batch_size=5
    ,shuffle=False
)

In [20]:
nhead=42
num_encoder_layers=6
num_decoder_layers=6

model=Transformer(nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)
#model.load_state_dict(torch.load(model_path))
model_path = os.path.join("saved_models", "transformer", "transformer_v10_masters_thesis.pth".format(model_version))
model.load_state_dict(torch.load(model_path))
model.eval()
model.to(device)

Transformer(
  (positional_encoder): PositionalEncoding()
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=84, out_features=84, bias=True)
          )
          (linear1): Linear(in_features=84, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=84, bias=True)
          (norm1): LayerNorm((84,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((84,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((84,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-5): 6 x TransformerDecoderLayer(
     

In [21]:
criterion=WeightedMSELoss(device, class_weights_test, classes, correct_label_index, indices_of_interest)
accuracy, tot_valid_loss = compute_accuracy(device, model, testloader, None, "TEST", 
                                                    verbose = True, cv = False, correct_label_index=correct_label_index)
print("Test accuracy:",accuracy)

  return torch._native_multi_head_attention(


Class CA: Correct 10, total 16, acc 0.625
Class CC: Correct 29, total 42, acc 0.6904761904761905
Class CG: Correct 3, total 9, acc 0.3333333333333333
Class CT: Correct 10, total 19, acc 0.5263157894736842
Class TA: Correct 5, total 14, acc 0.35714285714285715
Class TC: Correct 1, total 13, acc 0.07692307692307693
Class TG: Correct 15, total 27, acc 0.5555555555555556
Class TT: Correct 30, total 56, acc 0.5357142857142857

 TEST
TP: 44 . FN: 54 TP/(TP+FN): 0.4489795918367347 TN: 59 FP: 39 TN/(TN+FP): 0.6020408163265306 Wrong positive class predicted: 42 Wrong negative class predicted: 4
Fake F1-score: 0.7853881278538812 . Fake F2-score: 0.8382066276803118
Fake TP/(TP+FN): 0.6142857142857143 Fake TN/(TN+FP) 0.6176470588235294
Fake precision: 0.7107438016528925 Fake recall: 0.8775510204081632
F1-score: 0.48618784530386744
F2-score: 0.4631578947368421
Precision: 0.5301204819277109
Recall: 0.4489795918367347
Fake accuracy: 0.7602040816326531
Test accuracy: 0.4625576360773729


In [22]:
accuracy, tot_valid_loss = compute_accuracy(device, model, validloader, None, "VALID", 
                                                    verbose = True, cv = False, correct_label_index=correct_label_index)
print("Valid accuracy:",accuracy)

Class CA: Correct 35, total 84, acc 0.4166666666666667
Class CC: Correct 325, total 545, acc 0.5963302752293578
Class CG: Correct 10, total 22, acc 0.45454545454545453
Class CT: Correct 63, total 112, acc 0.5625
Class TA: Correct 18, total 31, acc 0.5806451612903226
Class TC: Correct 17, total 53, acc 0.32075471698113206
Class TG: Correct 12, total 21, acc 0.5714285714285714
Class TT: Correct 369, total 758, acc 0.4868073878627968

 VALID
TP: 155 . FN: 168 TP/(TP+FN): 0.47987616099071206 TN: 694 FP: 609 TN/(TN+FP): 0.5326170376055257 Wrong positive class predicted: 115 Wrong negative class predicted: 104
Fake F1-score: 0.4918032786885246 . Fake F2-score: 0.6531204644412192
Fake TP/(TP+FN): 0.6164383561643836 Fake TN/(TN+FP) 0.5671641791044776
Fake precision: 0.34838709677419355 Fake recall: 0.8359133126934984
F1-score: 0.28518859245630174
F2-score: 0.3769455252918288
Precision: 0.20287958115183247
Recall: 0.47987616099071206
Fake accuracy: 0.6568265682656826
Valid accuracy: 0.498709779

In [23]:
test_input, test_target, class_weights_test = parse_matrices(test_file, norm, target_mode, norm_type)

In [24]:
test_dataset=FeatureDataset(data=test_input, labels=test_target)
testloader = DataLoader(test_dataset
                           ,batch_size=5
                            ,shuffle=False
                        )

NameError: name 'DataLoader' is not defined

In [None]:
model_path="/u/77/jarvint12/unix/huslab_timo_dev/masters_thesis/saved_models/transformer/{}.pth".format(model_version)
nhead=42
num_encoder_layers=6
num_decoder_layers=6
model=Transformer(nhead=nhead, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)
model.load_state_dict(torch.load(model_path))
model.eval()
model.to(device)
print(model_path)
accuracy, tot_valid_loss = compute_accuracy(device, model, testloader, None, "TEST", 
                                                    verbose = True, cv = False, correct_label_index=correct_label_index)
print("Test accuracy:",accuracy)

In [None]:
def countProbabilityDistributions(net, dataloader):
    distributions=dict()
    x=None
    m = nn.Softmax(dim=1)
    predictions=None
    with torch.no_grad():
        tot_loss=0
        for sequences, labels in dataloader:
            sequences, labels = sequences.to(device), labels.to(device)
            outputs = net(sequences)
            _, predicted = torch.max(outputs.data, 1)
            if predictions==None:
                predictions=predicted
                trueLabels=labels
            else:
                predictions=torch.cat((predictions, predicted), 0)
                trueLabels=torch.cat((trueLabels, labels), 0)
            for i in range(outputs.data.shape[1]):
                if not i in distributions:
                    distributions[i]=m(outputs).data[:,i]
                else:
                    distributions[i]=torch.cat((distributions[i], m(outputs).data[:,i]), 0)
                    #print(m(outputs).data[0,:])
                    #print(predicted)
                    #raise RuntimeError("LOL")
    return distributions, predictions, trueLabels
#distributions, predictions, trueLabels= countProbabilityDistributions(net, trainloader)
distributions, predictions, trueLabels= countProbabilityDistributions(net, trainloader)

In [None]:
print(classes)
plt.hist(predictions.to("cpu").numpy(),bins=[0,1,2,3,4,5,6,7,8,9]);

In [None]:
trueLabels[3]

In [None]:
classes

In [None]:
classNames=['CA', 'CC', 'CG', 'CT', 'TA', 'TC', 'TG', 'TT']
for i in range(8):
    plt.figure(figsize=(20,20))
    condition=(predictions==i) & (trueLabels==i)
    plt.hist(distributions[i][condition].to("cpu").numpy(),bins=100, color='r', alpha=0.5);#, density=True);
    condition=(predictions==i) & (trueLabels!=i)
    plt.hist(distributions[i][condition].to("cpu").numpy(),bins=100, color='b', alpha=0.5);#, density=True);
    plt.title(classNames[i]);
    plt.xlim([0,1])
    plt.savefig('probDistr_{}.pdf'.format(classNames[i]))

In [None]:
print(classes)
plt.figure(figsize=(20,20))
condition=(predictions==1)
plt.hist(distributions[1][condition].to("cpu").numpy(),bins=100, color='r', alpha=0.5, density=True);
other=torch.cat((distributions[0][predictions==0],distributions[2][predictions==2],distributions[3][predictions==3]),dim=0)
plt.hist(other.to("cpu").numpy(),bins=100, color='b', alpha=0.5, density=True);

In [None]:
fig,ax = plt.subplots(2,2,constrained_layout = True) # Instantiate figure and axes object
ax[0][0].hist(distributions[0][predictions==0].to("cpu").numpy(),bins=100, color='r', density=True);
ax[0][0].title.set_text('CA 475')
ax[0][1].hist(distributions[1][predictions==1].to("cpu").numpy(),bins=100, color='r', density=True);
ax[0][1].title.set_text('CC 3048')
ax[1][0].hist(distributions[2][predictions==2].to("cpu").numpy(),bins=100, color='r', density=True);
ax[1][0].title.set_text('CG 127')
ax[1][1].hist(distributions[3][predictions==3].to("cpu").numpy(),bins=100, color='r', density=True);
ax[1][1].title.set_text('CT 639')

In [None]:
fig,ax = plt.subplots(2,2,constrained_layout = True) # Instantiate figure and axes object
ax[0][0].hist(distributions[4][predictions==4].to("cpu").numpy(),bins=100, color='r', density=True);
ax[0][0].title.set_text('TA 181')
ax[0][1].hist(distributions[5][predictions==5].to("cpu").numpy(),bins=100, color='r', density=True);
ax[0][1].title.set_text('TC 303')
ax[1][0].hist(distributions[6][predictions==6].to("cpu").numpy(),bins=100, color='r', density=True);
ax[1][0].title.set_text('TG 122')
ax[1][1].hist(distributions[7][predictions==7].to("cpu").numpy(),bins=100, color='r', density=True);
ax[1][1].title.set_text('TT 4285')

In [None]:
fig,ax = plt.subplots(2,4) # Instantiate figure and axes object
for index in range(8):
    ax[int(index>3)][index-(index>3)*4].hist(distributions[index].to("cpu").numpy(), density=True, bins=20)#,
            #bins=[0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9], density=True, histtype="step")
    ax[int(index>3)][index-(index>3)*4].title.set_text(str(index))
plt.legend()
plt.show()

In [None]:
fig,axes = plt.subplots(2,1) # Instantiate figure and axes object
axes[0].hist(distributions[0].to("cpu").numpy(), label=str(0),
        bins=[0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9], density=True, histtype="step");
axes[1].hist(distributions[1].to("cpu").numpy(), label=str(1),
        bins=[0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9], density=True, histtype="step");
plt.legend()
plt.show()

In [None]:
print(distributions[0][0:11].data)

In [None]:
accuracy, loss, f1, f2, precision, recall, f1_fake, f2_fake, fake_precision, fake_recall = \
compute_accuracy(device, net, testloader, criterion, many_classes, "TEST", True, correct_label_index=1)

print("Accuracy for the test data:",accuracy)
print("Loss for test data:",loss)
print("F1-score:",f1)
print("F2-score:",f2)
print("Precision:",precision)
print("Recall:",recall)
config.test_acc=accuracy
config.test_loss=loss
config.test_f1=f1
config.test_f2=f2
config.test_precis=precision
config.test_recall=recall

In [None]:
for sequences, labels in testloader:
    sequences, labels = sequences.to(device), labels.to(device)
    outputs = net(sequences)
    _, predicted = torch.max(outputs.data, 1)
    for sequence, label, prediction, output in zip(sequences, labels, predicted, outputs):
        if label in [0,5,10,15]:
            if label==prediction:
                continue
            else:
                if prediction in [0,5,10,15]:
                    print(class_types[prediction], class_types[label],sequence, output)

In [None]:
for sequences, labels in testloader:
    sequences, labels = sequences.to(device), labels.to(device)
    outputs = net(sequences)
    _, predicted = torch.max(outputs.data, 1)
    for sequence, label, prediction, output in zip(sequences, labels, predicted, outputs):
        if prediction==label:
            print(sequence,'\n\n',output,'\n\n', label,'\n\n', prediction, class_types[label.item()], 
                  class_types[prediction.item()])
            #print("TRUE:",class_types[prediction.item()])
            raise UserWarning('Exit Early')
        else:
            print("FALSE",sequence,'\n\n',output,'\n\n', label,'\n\n', prediction, class_types[label.item()],
                 class_types[prediction.item()])
            raise UserWarning('Exit Early')

In [None]:
for sequences, labels in trainloader:
    for label in labels:
        if class_types[label.item()]=="TT":
            print("TT")

In [None]:
test_tensor=torch.FloatTensor(np.array([[0,0,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0],
                                [0,1,0,1,0,0,0,1,1,1,1,0,0,1,0,0,0,1,1,0,1,1,1,0,1,1,1,0,1,0,1,1,0],
                                [0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0],
                                [1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]]))
test_target=torch.FloatTensor(8) #G->A=8
test_tensor=test_tensor.reshape(1,1,4,33)
out_test=net(test_tensor)
print(out_test)

In [None]:
predicted = torch.argmax(out_test, 1)
total =1
correct = (predicted == torch.argmax(test_target)).sum().item()
print(correct)

In [None]:
test=np.zeros(6)
for luku in test:
    print(luku)

In [None]:
print(torch.FloatTensor(np.array([[0,0,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,0],
                                [0,1,0,1,0,0,0,1,1,1,1,0,0,1,0,0,0,1,1,0,1,1,1,0,1,1,1,0,1,0,1,1,0],
                                [0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0],
                                [1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1]])))

In [None]:
net.eval()
with torch.no_grad():
    images, labels = iter(testloader).next()
    tests.plot_images(images[:5], n_rows=1)
    
    # Compute predictions
    images = images.to(device)
    y = net(images)

print('Ground truth labels: ', ' '.join('%10s' % classes[labels[j]] for j in range(5)))
print('Predictions:         ', ' '.join('%10s' % classes[j] for j in y.argmax(dim=1)))

In [None]:
# Compute the accuracy on the test set
accuracy = compute_accuracy(net, testloader)
print('Accuracy of the network on the test images: %.3f' % accuracy)
assert accuracy > 0.85, "Poor accuracy {:.3f}".format(accuracy)
print('Success')