In [7]:
%load_ext tensorboard

In [1]:
from __future__ import print_function
import os
from skimage import io
from collections import OrderedDict

import pandas as pd

from self_supervised import ResNetUNet

from sklearn.metrics import accuracy_score, f1_score

from copy import copy

import pydicom

import torch
import torchvision
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [2]:
for directory in ["../models/", "../log_dir/"]:
    try:
        os.mkdir(directory)
    except FileExistsError:
        pass

In [3]:
class Config:

    def __init__(self,
                 original_dataset_dir='../data/Dataset/',
                 preprocessed_dataset_dir='../data/Preprocessed_Dataset/',
                 labels_file='../data/labels.csv',
                 dataset_begin_at_slice=6,
                 dataset_num_slice=7,
                 image_crop_size=128,
                 label_type='PVWM',
                 n_classes=4,
                 preprocess=True,
                 self_supervised=True,
                 xavier_initialization=False,
                 self_supervised_model_path='../models/ss_models/feature_less_4_transformation_model3(batch2,rotation45).pth.tar',
                 cuda=True,
                 tensorboard_dir='../log_dir/runs',
                 save_best_model=True,
                 save_dir='../models',
                 batch_size=1,
                 lr = 0.008,
                 epochs = 120,
                 verbose_epoch = 20,
                 regularization = 0.0,
                 patience = 25,
                 lr_decay = 0.8,
                 n_folds=10,
                 test_size=0.1,
                 verbose=True):

        """
        Args:
            param original_dataset_dir (string): Path to the directory with all the original images.
            param preprocessed_dataset_dir (string): Path to the directory with all the preprocessed images,
            param labels_file (string): Path to the csv file with PVWM and DWM labels.
            param dataset_begin_at_slice (int): The first slice number to use. This is because we are interested in middle slices.
            param dataset_num_slices (int): Number of slices from each MRI serie to use.
            param image_crop_size (int): The height and width of a center patch extracted from each original image.
            param label_type (string): Assumes values 'PVWM' or 'DWM'.
            param n_classes (int): Number of target classes of classification. Default is four as each biomarkers has four severity grades.
            param preprocess (boolean): Whether to preprocess samples.
            param self_supervised (boolean): Whether to include the self-supervised component. In fact whether to use self-supervised weights as a training initial point.
            param xavier_initialization (boolean): Whether to initialze the network wights with xavier. Mutually exclusive with param "self_supervised".
            param self_supervised_model_path (string): the path to the pretrianed self-supervised model.
            param cuda (boolean): whether to train on gpu.
            param tensorboard_dir (string): the path to the directory where tensorboard logs are to be stored.
            param save_best_model (boolean): whether to save the best perfoming model.
            param save_dir (string): if "save_best_model" is set to true, this param specifies the directory to save the best model.
            param batch_size (int): Batch size of training and validation data ⚠️⚠️⚠️: Don't change this!
            param lr (float): initial learning rate of the training process.
            param epochs (int): maximum epochs to train the model in each cross-validation fold.
            param verbose_epoch (int): the epoch step to print the train and validation loss and f1-score to console.
            param regularization (float): the l1-regularization weight.
            param patience (int): number of epochs with unchanged validatoin loss to wait. this is used as input to LRReduceOnPlateau instace.
            param lr_decay (float): the amount of decay in learning rate after a plateau in validation loss.
            param n_folds (int): number of CV folds.
            param test_size (float): in no test set available this param specifies the portion of the training data as test data.
            param verbose (boolean): used as input to LRReduceOnPlateau instance. to print the changes to learning rate to console.
        """

        self.original_dataset_dir = original_dataset_dir
        self.preprocessed_dataset_dir = preprocessed_dataset_dir
        self.labels_file = labels_file
        self.dataset_begin_at_slice = dataset_begin_at_slice
        self.dataset_num_slice = dataset_num_slice
        self.image_crop_size = image_crop_size
        self.label_type = label_type
        self.preprocess = preprocess
        self.batch_size = batch_size
        self.n_classes = n_classes
        self.cuda = cuda

        self.train_transform = transforms.Compose([transforms.ToPILImage(),
                                      transforms.Grayscale(num_output_channels=1),
                                      transforms.RandomHorizontalFlip(p=0.5),
                                      transforms.RandomRotation(degrees=20),
                                      transforms.CenterCrop(self.image_crop_size),
                                      transforms.ToTensor()
                                      ])

        self.validation_transform = transforms.Compose([transforms.ToPILImage(),
                                           transforms.Grayscale(num_output_channels=1),
                                           transforms.CenterCrop(self.image_crop_size),
                                           transforms.ToTensor()
                                           ])

        self.self_supervised = self_supervised
        self.self_supervised_model_path = self_supervised_model_path
        self.xavier_initialization = xavier_initialization

        self.tensorboard_dir = tensorboard_dir
        self.save_dir = save_dir
        self.save_best_model = save_best_model

        self.lr = lr
        self.epochs = epochs
        self.verbose_epoch = verbose_epoch
        self.regularization = regularization
        self.reduce_lr_patience = patience
        self.lr_decay = lr_decay

        self.n_folds = n_folds
        self.test_size = test_size
        self.verbose = verbose

    def display(self):
        """Display Configuration values."""
        print("\nConfigurations:")
        for a in dir(self):
            if not a.startswith("__") and not callable(getattr(self, a)):
                print("{:30} {}".format(a, getattr(self, a)))
        print("\n")


config = Config()
config.display()


Configurations:
batch_size                     1
cuda                           False
dataset_begin_at_slice         6
dataset_num_slice              7
epochs                         120
image_crop_size                128
label_type                     PVWM
labels_file                    ../data/labels.csv
lr                             0.008
lr_decay                       0.8
n_classes                      4
n_folds                        10
original_dataset_dir           ../data/Dataset/
preprocess                     True
preprocessed_dataset_dir       ../data/Preprocessed_Dataset/
reduce_lr_patience             25
regularization                 0.0
save_best_model                True
save_dir                       ../models
self_supervised                True
self_supervised_model_path     ../models/ss_models/feature_less_4_transformation_model3(batch2,rotation45).pth.tar
tensorboard_dir                ../log_dir/runs
test_size                      0.1
verbose                     

In [4]:
class DementiaDataset(Dataset):
    """Dementia Dataset"""

    def __init__(self, original_dataset_dir, preprocessed_dataset_dir, labels_file, begin_at_slice=6, num_slices=7, image_crop_size=128,
                 label_type='PVWM', preprocess=True, transform=None):

        """
        Args:
            param original_dataset_dir (string): Path to the directory with all the original images.
            param preprocessedd_dataset_dir (string): Path to the directory with all the preprocessed images,
            param labels_file (string): Path to the csv file with PVWM and DWM labels.
            begin_at_slice (int): The first slice number to use. This is because we are interested in middle slices.
            num_slices (int): Number of slices from each MRI serie to use.
            image_crop_size (int): The height and width of a center patch extracted from each original image.
            label_type (string): Assumes values 'PVWM' or 'DWM'.
            preprocess (boolean): Whether to preprocess samples.
            transform (callable, optional): Optional transform to be applied on a sample.
        """

        assert label_type in ['PVWM', 'DWM'], "Invalid label type {}, label type must be one of ['PVWM', 'DWM']".format(label_type)

        self.original_dataset_dir = original_dataset_dir
        self.preprocessed_dataset_dir = preprocessed_dataset_dir
        self.labels = pd.read_csv(labels_file, dtype={"ID": str})
        self.begin_at_slice = begin_at_slice
        self.num_slices = num_slices
        self.image_size = image_crop_size
        self.label_type = label_type
        self.preprocess = preprocess
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        label = self.labels.loc[idx, self.label_type]

        try:
            slices = self.get_original_slices(idx)
        except Exception:
            slices = self.get_original_modified_slices(idx)

        if self.preprocess:
            preprocessed_slices = self.get_preprocessed_slices(idx)
            slices = torch.cat([slices, preprocessed_slices], dim=0)

        sample = {'slices': slices, 'label': label}
        return sample

    def get_original_slices(self, idx):
        slices = torch.zeros((self.num_slices, 1, self.image_size, self.image_size))

        patient_id = self.labels.loc[idx, 'ID'] # patient id in the dataset
        patient_dir = os.path.join(self.original_dataset_dir, patient_id)

        all_slices = sorted(os.listdir(patient_dir))
        middle_slices = all_slices[self.begin_at_slice: self.begin_at_slice + self.num_slices]

        for i, slice_file in enumerate(middle_slices):
            slice_file_path = os.path.join(patient_dir, slice_file)
            slice_data = pydicom.dcmread(slice_file_path).pixel_array.astype(float)
            slice_data *= 255.0 / np.max(slice_data) # normalize all images to [0-255]
            slice_data = slice_data.astype(np.uint8)

            if self.transform:
                slice_data = self.transform(np.expand_dims(slice_data, 2))

            slices[i] = slice_data
        return slices

    def get_preprocessed_slices(self, idx):
        slices = torch.zeros((self.num_slices, 1, self.image_size, self.image_size))

        patient_id = self.labels.loc[idx, 'ID'] # patient id in the dataset
        patient_dir = os.path.join(self.preprocessed_dataset_dir, patient_id)

        all_slices = sorted(os.listdir(patient_dir))
        middle_slices = all_slices[self.begin_at_slice: self.begin_at_slice + self.num_slices]

        for i, slice_file in enumerate(middle_slices):
            slice_file_path = os.path.join(patient_dir, slice_file)

            slice_data = io.imread(slice_file_path)
            slice_data = (slice_data * (255.0 / np.max(slice_data))) # normalize all images to [0-255]
            slice_data = slice_data.astype(np.uint8)

            if self.transform:
                slice_data = self.transform(np.expand_dims(slice_data, 2))

            slices[i] = slice_data
        return slices

    def get_original_modified_slices(self, idx):
        slices = torch.zeros((self.num_slices, 1, self.image_size, self.image_size))

        patient_id = self.labels.loc[idx, 'ID'] # patient id in the dataset
        patient_dir = os.path.join(self.original_dataset_dir, patient_id)

        all_slices = sorted(os.listdir(patient_dir))
        middle_slices = all_slices[self.begin_at_slice: self.begin_at_slice + self.num_slices]

        for i, slice_file in enumerate(middle_slices):
            slice_file_path = os.path.join(patient_dir, slice_file)

            slice_data = io.imread(slice_file_path)
            slice_data = (slice_data * (255.0 / np.max(slice_data))) # normalize all images to [0-255]
            slice_data = slice_data.astype(np.uint8)

            if self.transform:
                slice_data = self.transform(np.expand_dims(slice_data, 2))

            slices[i] = slice_data
        return slices

In [5]:
dataset = DementiaDataset(original_dataset_dir=config.original_dataset_dir,
                          preprocessed_dataset_dir=config.preprocessed_dataset_dir,
                          labels_file=config.labels_file,
                          label_type=config.label_type,
                          preprocess=config.preprocess,
                          begin_at_slice=config.dataset_begin_at_slice,
                          num_slices=config.dataset_num_slice, 
                          image_crop_size=config.image_crop_size
                          )

In [6]:
def get_data_loaders(dataset, train_indices, val_indices, test_indices, batch_size=config.batch_size,
                     train_transform=config.train_transform, validation_transform=config.validation_transform, test_transform=config.validation_transform):
    
    train_set = Subset(dataset, train_indices)
    val_set = Subset(dataset, val_indices)
    test_set = Subset(dataset, test_indices)

    train_set.dataset = copy(dataset)

    train_set.dataset.transform = train_transform
    val_set.dataset.transform = validation_transform
    test_set.dataset.transform = test_transform

    train_targets = torch.from_numpy(dataset.labels[dataset.label_type].values[train_set.indices])
    _, train_class_sample_counts = torch.unique(train_targets, sorted=True, return_counts=True)
    train_weights = 1. / train_class_sample_counts.float()
    train_samples_weights = train_weights[train_targets]
    train_sampler = WeightedRandomSampler(weights=train_samples_weights, num_samples=len(train_samples_weights), replacement=True)

    train_loader = DataLoader(train_set, batch_size=batch_size, sampler=train_sampler)
    validation_loader = DataLoader(val_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)

    return train_loader, validation_loader, test_loader

In [7]:
class Model(nn.Module):
    def __init__(self, n_classes=config.n_classes, resnet_unet_path=None, L=64, D=16, K=1):
        super(Model, self).__init__()

        self.L = L
        self.D = D
        self.K = K

        self.n_classes = n_classes

        self.attention = nn.Sequential(
            nn.Linear(self.L, self.D),
            nn.Tanh(),
            nn.Linear(self.D, self.K)
        )


        self.classifier = nn.Sequential(
            nn.Linear(self.L * self.K, self.n_classes),
            nn.LogSoftmax(dim=1)
        )

        self.resnet_unet = None


        if resnet_unet_path:
            map_location = torch.device('cuda')
            if not config.cuda:
                map_location = torch.device('cpu')
            checkpoint = torch.load(resnet_unet_path, map_location=map_location)
            resnet_unet = ResNetUNet(n_class=1)
            resnet_unet.load_state_dict(checkpoint['state_dict'])
            self.resnet_unet = resnet_unet


        if self.resnet_unet:
            self.init_with_self_supervised_model()

        else: 
            self.init_basic_model()


    def init_basic_model(self):
        
        self.feature_extractor_part1 = nn.Sequential(
            OrderedDict([
                ('conv1', nn.Conv2d(1, 2, kernel_size=3, stride=1)),
                ('relu1', nn.ReLU()),
                ('pool1', nn.MaxPool2d(2)),

                ('conv2', nn.Conv2d(2, 4, kernel_size=3, stride=1)),
                ('relu2', nn.ReLU()),
                ('pool2', nn.MaxPool2d(2)),

                ('conv3', nn.Conv2d(4, 6, kernel_size=3, stride=1)),
                ('relu3', nn.ReLU()),
                ('pool3', nn.MaxPool2d(2)),

                ('conv4', nn.Conv2d(6, 8, kernel_size=3, stride=1)),
                ('relu4', nn.ReLU()),
                ('pool4', nn.MaxPool2d(2))])
        )

        self.feature_extractor_part2 = nn.Sequential(
            OrderedDict([
                ('linear', nn.Linear(8 * 6 * 6, self.L)),
                ('relu', nn.ReLU()),
            ])
        )


  
    def init_with_self_supervised_model(self):

        self.feature_extractor_part1 = nn.Sequential(self.resnet_unet)

        self.feature_extractor_part2 = nn.Sequential(
            OrderedDict([
                ('linear', nn.Linear(16 * 4 * 4, self.L)),
                ('relu', nn.ReLU()),
            ])
        )
        

    def forward(self, x):
        x = x.squeeze(0)

        H = self.feature_extractor_part1(x)
        H = H.view(-1, self.feature_extractor_part2.linear.in_features)
        H = self.feature_extractor_part2(H)  # NxL

        A = self.attention(H)  # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        A = F.softmax(A, dim=1)


        M = torch.mm(A, H)  # KxL

        Y_prob = self.classifier(M)
        Y_hat = torch.argmax(Y_prob)
        
        return Y_prob, Y_hat, A

In [8]:
# Function for Xavier initialization of network weights.

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        print(m)
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0)

In [9]:
# Test everything is fine with the forward pass of our model. Also view number of total parameters as well as trainable parameters of the model.

model = Model(resnet_unet_path=config.self_supervised_model_path)

# model.apply(weights_init)

total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
inp = torch.rand([8, 1, 128, 128])
print(model(inp)[0].shape)

63,912 total parameters.
63,912 training parameters.
torch.Size([1, 4])


In [10]:
resnet_unet_path = config.self_supervised_model_path
def get_model(cuda=config.cuda, self_supervised=config.self_supervised, resnet_unet_path=resnet_unet_path, xavier=config.xavier_initialization):

    print('Init Model')
    
    if self_supervised:
        model = Model(resnet_unet_path=resnet_unet_path)
        
    else:
        model = Model()

    if xavier:
        model.apply(weights_init)


    if cuda:
        model.cuda()

    total_params = sum(p.numel() for p in model.parameters())
    print(f'{total_params:,} total parameters.')
    total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
    print(f'{total_trainable_params:,} training parameters.')

    return model

In [11]:
import shutil
try:
    shutil.rmtree(config.tensorboard_dir)
except FileNotFoundError:
    pass

In [12]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('{}/Dementia_{}_Biomarker_Prediction'.format(config.tensorboard_dir, dataset.label_type))

In [20]:
from textwrap import wrap
import itertools
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import confusion_matrix



def plot_confusion_matrix(y_true, y_pred, labels, title='Confusion matrix'):
    ''' 
    Parameters:
        y_true: These are your true classification categories.
        y_pred: These are you predicted classification categories
        labels: This is a lit of labels which will be used to display the axis labels
        title='Confusion matrix': Title for your matrix

    Returns:
        fig: the figure of confusion matrix 
    '''
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    np.set_printoptions(precision=2)

    fig = plt.figure(figsize=(7, 7), dpi=320, facecolor='w', edgecolor='k')
    ax = fig.add_subplot(1, 1, 1)
    im = ax.imshow(cm, cmap='Oranges')

    classes = labels
    classes = ['\n'.join(wrap(str(l), 40)) for l in classes]

    tick_marks = np.arange(len(classes))

    ax.set_xlabel('Predicted', fontsize=7)
    ax.set_xticks(tick_marks)
    c = ax.set_xticklabels(classes, fontsize=4, rotation=-90,  ha='center')
    ax.xaxis.set_label_position('bottom')
    ax.xaxis.tick_bottom()

    ax.set_ylabel('True Label', fontsize=7)
    ax.set_yticks(tick_marks)
    ax.set_yticklabels(classes, fontsize=4, va ='center')
    ax.yaxis.set_label_position('left')
    ax.yaxis.tick_left()

    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        ax.text(j, i, format(cm[i, j], 'd') if cm[i,j]!=0 else '.', horizontalalignment="center", fontsize=6, verticalalignment='center', color= "black")

    fig.set_tight_layout(True)
    
    return fig

In [21]:
def plot_instance_bag(split, data_loader):
    data_iter = iter(data_loader)
    sample = data_iter.next()

    images = sample['slices'][0]
    label = sample['label']

    img_grid = torchvision.utils.make_grid(images)

    writer.add_image('A bag of class {} instances/{}'.format(label, split), img_grid)

In [22]:
def save_checkpoint(state, split, epoch, save_path=config.save_dir):
    filename = os.path.join(save_path, 'best_{}_model_epoch_{}_split_{}.pth.tar'.format(dataset.label_type, epoch, split))
    torch.save(state, filename)

In [23]:
def report(phase, loss, y_true, y_pred, split, epoch, save_model=config.save_best_model):

    global best_score

    accuracy = accuracy_score(y_true, y_pred)

    writer.add_scalar('Loss/{}-{}'.format(phase, split), loss.item(), epoch)
    writer.add_scalar('Accuracy/{}-{}'.format(phase, split), accuracy, epoch)

    matrix = confusion_matrix(y_true, y_pred)
    class_accuracy = matrix.diagonal()/matrix.sum(axis=1)
    class_accuracy_dict = dict(enumerate(class_accuracy))
    class_accuracy_dict = {str(k):v for k,v in class_accuracy_dict.items()}
    writer.add_scalars('Class Accuracy/{}-{}'.format(phase, split), class_accuracy_dict, epoch)


    macro_f1_score = f1_score(y_true, y_pred, average='macro')
    micro_f1_score = f1_score(y_true, y_pred, average='micro')

    writer.add_scalar('Macro F-1 Score/{}-{}'.format(phase, split), macro_f1_score, epoch)
    writer.add_scalar('Micro F-1 Score/{}-{}'.format(phase, split), micro_f1_score, epoch)

    if epoch % config.verbose_epoch == 0:
        print('--------------------------------')
        print('Split: {}, Epoch: {}, Phase: {}, Loss: {:.4f}, accuracy: {:.4f}'.format(split, epoch, phase.upper(), loss.item(), accuracy))
        print('{} class accuracy: {}'.format(phase, class_accuracy))
        
        if phase == 'Validation':
            print('\n\n')


    if phase == 'Validation':
        is_best = macro_f1_score >= best_score
        best_score = max(best_score, macro_f1_score)
        
        if is_best:
            if save_model and best_score > 0.8:
                save_checkpoint({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'best_score': best_score,
                    'optimizer' : optimizer.state_dict(),
                }, split, epoch)
            conf_matrix_plot = plot_confusion_matrix(y_true, y_pred, np.arange(n_classes))
            writer.add_figure('Confusion Matrix/{}-{}'.format(phase, split), conf_matrix_plot, epoch)
            writer.add_scalar('Best Score/{}-{}'.format(phase, split), best_score, epoch)
            writer.add_scalar('Best Acc/{}-{}'.format(phase, split), accuracy, epoch)

In [24]:
def train(split, epoch):
    model.train()

    y_true = []
    y_pred = []

    train_loss = 0.0

    optimizer.zero_grad()
    
    for batch_idx, sample in enumerate(train_loader):
        data = sample['slices']
        bag_label = sample['label']

        if cuda:
            data, bag_label = data.cuda(), bag_label.cuda()

        y_prob, y_hat, _ = model(data)

        loss = criterion(y_prob, bag_label)
        train_loss += loss

        y_true.append(bag_label[0].item())
        y_pred.append(y_hat.item())

    train_loss.backward()
    optimizer.step()


    train_loss /= len(y_true)
    
    report(phase='Training', loss=train_loss, y_true=y_true, y_pred=y_pred, split=split, epoch=epoch)

    validate(split, epoch)

In [25]:
def validate(split, epoch):

    model.eval()

    y_true = []
    y_pred = []

    val_loss = 0.0

    with torch.no_grad():
        for batch_idx, sample in enumerate(validation_loader):
            data = sample['slices']
            bag_label = sample['label']
            if cuda:
                data, bag_label = data.cuda(), bag_label.cuda()
        
            y_prob, y_hat, _ = model(data)

            loss = criterion(y_prob, bag_label)
            val_loss += loss

            y_true.append(bag_label[0].item())
            y_pred.append(y_hat.item())

    val_loss /= len(y_true)

    lr_scheduler.step(val_loss)

    report(phase='Validation', loss=val_loss, y_true=y_true, y_pred=y_pred, split=split, epoch=epoch)

In [50]:
%tensorboard --logdir=log_dir/runs

UsageError: Line magic function `%tensorboard` not found.


In [26]:
from sklearn.model_selection import KFold
from sklearn.model_selection import train_test_split


kf = KFold(n_splits=config.n_folds, shuffle=True)
dataset_size = len(dataset)

best_scores = []

cuda = False
if cuda:
    print('\nGPU is ON!')

indices = np.arange(dataset_size)
train_indices, test_indices = train_test_split(indices, test_size=config.test_size)

for split, (train_indices, val_indices) in enumerate(kf.split(indices)):
    
    print('\n\n\n-----------SPLIT {} ------------\n'.format(split))

    batch_size = config.batch_size
    lr = config.lr
    epochs = config.epochs
    verbose_epoch = config.verbose_epoch
    weight_decay = config.regularization
    patience = config.reduce_lr_patience
    lr_decay = config.lr_decay

    n_classes = np.unique(dataset.labels[dataset.label_type].values).shape[0]

    train_loader, validation_loader, test_loader = get_data_loaders(dataset=dataset, batch_size=batch_size,
                                                                    train_indices=train_indices, val_indices=val_indices, test_indices=val_indices)

    plot_instance_bag(split, train_loader)

    model = get_model(cuda=cuda, self_supervised=config.self_supervised)
    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
    criterion = nn.NLLLoss()
    lr_scheduler = ReduceLROnPlateau(optimizer, "min", factor=lr_decay, patience=patience, verbose=config.verbose)

    best_score = 0.0

    print('Start Training')

    for epoch in range(1, epochs):
        train(split, epoch)
    
    best_scores.append(best_score)






-----------SPLIT 0 ------------

Init Model
63,912 total parameters.
63,912 training parameters.
Start Training


  # This is added back by InteractiveShellApp.init_path()


--------------------------------
Split: 0, Epoch: 20, Phase: TRAINING, Loss: 0.5737, accuracy: 0.7778
Training class accuracy: [1.  0.5 1.  0. ]
--------------------------------
Split: 0, Epoch: 20, Phase: VALIDATION, Loss: 0.1162, accuracy: 1.0000
Validation class accuracy: [1.]



--------------------------------
Split: 0, Epoch: 40, Phase: TRAINING, Loss: 0.2535, accuracy: 0.8889
Training class accuracy: [0.67 1.   1.  ]
--------------------------------
Split: 0, Epoch: 40, Phase: VALIDATION, Loss: 3.4138, accuracy: 0.0000
Validation class accuracy: [ 0. nan]



Epoch    45: reducing learning rate of group 0 to 6.4000e-03.
--------------------------------
Split: 0, Epoch: 60, Phase: TRAINING, Loss: 0.4022, accuracy: 0.7778
Training class accuracy: [0. 1. 1. 1.]
--------------------------------
Split: 0, Epoch: 60, Phase: VALIDATION, Loss: 2.8272, accuracy: 0.0000
Validation class accuracy: [ 0. nan]



Epoch    71: reducing learning rate of group 0 to 5.1200e-03.


KeyboardInterrupt: 

In [27]:
print(np.mean(best_scores))

nan


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)
