In [47]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


from sklearn.metrics import confusion_matrix
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.signal import butter, sosfilt, sosfreqz, resample, sosfiltfilt
import scipy.io
import random
import json
from livelossplot import PlotLosses


In [37]:
def butter_bandpass(lowcut, highcut, fs, order=5):
        nyq = 0.5 * fs
        low = lowcut / nyq
        high = highcut / nyq
        sos = butter(order, [low, high], analog=False, btype='band', output='sos')
        return sos

def butter_bandpass_filter(data, lowcut, highcut, fs, order=2):
        sos = butter_bandpass(lowcut, highcut, fs, order=order)
        y = sosfiltfilt(sos, data)
        return y

class MODA_proc(Dataset):
    def __init__(self, input_path = '/scratch/s174411/center_width/1D_MASS_MODA_processed/input/', label_path = '/scratch/s174411/center_width/1D_MASS_MODA_processed/labels/'):
        self.input_path = input_path
        self.label_path = label_path
        self.input_dict = {}
        self.label_dict = {}
        temp_input_list = []
        temp_output_list = []
        for root, dirs, files in os.walk(self.input_path):
            for name in files:
                if name.endswith('npy'):
                    temp_input_list.append(os.path.join(root, name))
                    self.input_dict[int(name[:-4])] = os.path.join(root, name)

        for root, dirs, files in os.walk(self.label_path):
            for name in files:
                if name.endswith('json'):
                    temp_output_list.append(os.path.join(root, name))
                    self.label_dict[int(name[:-5])] = os.path.join(root, name)

        self.master_path_list = []
        
        for in_path in temp_input_list:
            for la_path in temp_output_list:
                if in_path[-16:-3] == la_path[-17:-4]:
                    self.master_path_list.append((in_path,la_path))
                


    def __len__(self):
        return len(self.master_path_list)

    def __getitem__(self, idx):
        #print(self.input_dict[idx])
        model_input, labels = self.master_path_list[idx]
        eeg_input = np.load(model_input)
        eeg_input = resample(eeg_input, 100*115)
        eeg_input = butter_bandpass_filter(eeg_input, 0.3, 30, 100, 10)
        # Standardize
        eeg_input = (eeg_input - np.mean(eeg_input))/np.std(eeg_input)

        eeg_input = torch.FloatTensor(eeg_input)
        eeg_input = eeg_input[None, :]

        #print('dataloader shape')

        #image = np.array(cv2.imread(self.input_dict[idx]))
        #image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
        #image = read_image(self.input_dict[idx])
        f = open(labels)
        
        labels = (json.load(f))
        f.close()


        input_length = int(100*115)
        sumo_label_format = np.zeros(input_length)
        for bbox in labels['boxes']:
            
            box_start = bbox[0] - bbox[1]/2
            box_end = bbox[0] + bbox[1]/2
            box_start_scaled = int(box_start * input_length)
            box_end_scaled = int(box_end * input_length)
            sumo_label_format[box_start_scaled:box_end_scaled] = 1

        spindle = False
        for sample in sumo_label_format:
            if sample == 1:
                spindle = True

        if not spindle:
            print("found")
        #print(type(sumo_label_format))
        sumo_label_format = torch.FloatTensor(sumo_label_format)
        #print(sumo_label_format.shape)

        return eeg_input, sumo_label_format

In [38]:
class u_net_backbone(nn.Module):
    def __init__(self):
        super().__init__()
        n_groups = 8

        # DOWNSAMPLING

        #self.pad = nn.Conv1d(1, 1, kernel_size = 1, stride = 2)
        
        
        # ENCODER GROUND LEVEL (LEVEL 1)
        self.conv_1_1 = nn.Conv1d(1, 16, kernel_size = 5, dilation = 2, padding = 'same')
        self.batch_1_1 = nn.GroupNorm(n_groups, 16)

        self.conv_1_2 = nn.Conv1d(16, 16, kernel_size = 5, dilation = 2, padding = 'same')
        self.batch_1_2 = nn.GroupNorm(n_groups, 16)

        # LEVEL 2
        self.pool_1 = nn.MaxPool1d(kernel_size = 4)

        self.conv_2_1 = nn.Conv1d(16, 32, kernel_size = 5, dilation = 2, padding = 'same')
        self.batch_2_1 = nn.GroupNorm(n_groups, 32)

        self.conv_2_2 = nn.Conv1d(32, 32, kernel_size = 5, dilation = 2, padding = 'same')
        self.batch_2_2 = nn.GroupNorm(n_groups, 32)


        # ENCODER BOTTOM LEVEL
        self.pool_2 = nn.MaxPool1d(kernel_size = 4)

        self.conv_3_1 = nn.Conv1d(32, 64, kernel_size = 5, dilation = 2, padding = 'same')
        self.batch_3_1 = nn.GroupNorm(n_groups, 64)

        self.conv_3_2 = nn.Conv1d(64, 64, kernel_size = 5, dilation = 2, padding = 'same')
        self.batch_3_2 = nn.GroupNorm(n_groups, 64)

        # DECODER LEVEL 2
        # UPSAMPLING
        self.upsample_2 = nn.Upsample(scale_factor = 4, mode = 'nearest')
        self.conv_2_3 = nn.Conv1d(64, 32, kernel_size = 4, dilation = 1, padding = 'same')

        # 
        self.conv_2_4 = nn.Conv1d(64, 32, kernel_size = 5, dilation = 1, padding = 'same')
        self.batch_2_4 = nn.GroupNorm(n_groups, 32)

        self.conv_2_5 = nn.Conv1d(32, 32, kernel_size = 5, dilation = 1, padding = 'same')
        self.batch_2_5 = nn.GroupNorm(n_groups, 32)

        
        # DECODER GROUND LEVEL (LEVEL 1)
        # UPSAMPLING
        self.upsample_1 = nn.Upsample(scale_factor = 4, mode = 'nearest')
        self.conv_1_3 = nn.Conv1d(32, 16, kernel_size = 4, dilation = 1, padding = 'same')

        # 
        self.conv_1_4 = nn.Conv1d(32, 16, kernel_size = 5, dilation = 1, padding = 'same')
        self.batch_1_4 = nn.GroupNorm(n_groups, 16)

        self.conv_1_5 = nn.Conv1d(16, 16, kernel_size = 5, dilation = 1, padding = 'same')
        self.batch_1_5 = nn.GroupNorm(n_groups, 16)


        self.conv_1_6 = nn.Conv1d(16, 2, kernel_size = 1, dilation = 1)
        
    def forward(self, tensor_list):
        #print(tensor_list.shape)
        # DOWNSAMPLING
        #downsampled_input = self.downsample(tensor_list)
        #extrapolation = int(np.ceil(tensor_list.shape[1] / (4*4*4)) * (4*4*4) - tensor_list.shape[1])
        padded_input = F.pad(tensor_list, (2, 2), mode='reflect')
        #padded_input = tensor_list

        # GROUND LEVEL FORWARD
        level_1 = self.batch_1_1(F.relu(self.conv_1_1(padded_input)))
        level_1 = self.batch_1_2(F.relu(self.conv_1_2(level_1)))

        # POOLING AND LEVEL 2
        level_1_down = self.pool_1(level_1)
        level_2 = self.batch_2_1(F.relu(self.conv_2_1(level_1_down)))
        level_2 = self.batch_2_2(F.relu(self.conv_2_2(level_2)))

        # POOLING AND BOTTOM LEVEL
        level_2_down = self.pool_2(level_2)
        level_3 = self.batch_3_1(F.relu(self.conv_3_1(level_2_down)))
        level_3 = self.batch_3_2(F.relu(self.conv_3_2(level_3)))

        # UPSAMPLING AND FEATURE FUSION (LEVEL 2)
        level_3_upsampled = self.upsample_2(level_3)
        level_2_up = self.conv_2_3(level_3_upsampled)
        
        #print(level_2.shape)
        #print(level_3_upsampled.shape)
        dec_level_2 = torch.cat((level_2, level_2_up), 1)

        dec_level_2 = self.batch_2_4(F.relu(self.conv_2_4(dec_level_2)))
        dec_level_2 = self.batch_2_5(F.relu(self.conv_2_5(dec_level_2)))

        # UPSAMPLING AND FEATURE FUSION (UPPER LEVEL)
        level_2_upsampled = self.upsample_1(dec_level_2)
        level_1_up = self.conv_1_3(level_2_upsampled)
        
        dec_level_1 = torch.cat((level_1, level_1_up), 1)

        dec_level_1 = self.batch_1_4(F.relu(self.conv_1_4(dec_level_1)))
        dec_level_1 = self.batch_1_5(F.relu(self.conv_1_5(dec_level_1)))

        dec_level_1 = self.conv_1_6(dec_level_1)


        diff = dec_level_1.shape[2] - tensor_list.shape[2]
        crop_dims = [diff // 2, diff // 2 + diff % 2]

        if crop_dims[1] == 0:
            dec_level_1 = dec_level_1[:, :, crop_dims[0]:]
        else:
            dec_level_1 = dec_level_1[:, :, crop_dims[0]:-crop_dims[1]]

        # Not used when calculating loss
        #smooth = F.avg_pool1d(dec_level_1, 42, stride=1)

        return dec_level_1


In [39]:
class GeneralizedDiceLoss(nn.Module):
    """
    Compute the generalised Dice loss defined in:
        Sudre, C. et al. (2017) Generalised Dice overlap as a deep learning loss function for highly unbalanced
        segmentations. DLMIA 2017. https://arxiv.org/pdf/1707.03237.pdf
    Adapted from:
        https://github.com/Project-MONAI/MONAI/blob/0.5.2/monai/losses/dice.py#L216

    Parameters
    ----------
    reduction : {'mean', 'sum', 'none'}, optional
        Specifies the reduction to apply to the output. The sum of the output will be divided by the number of
        elements in the output ('mean'), the output will be summed ('sum') or no reduction will be applied ('none').
        Default is 'mean'.
    smooth : float, optional
        A small constant added to the numerator and denominator to avoid zero and nan.
    use_weight: bool, optional
        When true, use class weights as originally proposed by Sudre et al.
    softmax : bool, optional
        When True, apply a softmax function to the prediction.
    """

    def __init__(self, reduction: str = 'mean', smooth: float = 1e-5, use_weight: bool = True, softmax: bool = True):
        super(GeneralizedDiceLoss, self).__init__()

        self.reduction = reduction
        self.smooth = smooth
        self.use_weight = use_weight
        self.softmax = softmax

    @staticmethod
    def get_onehot_encoding(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        shp_x = y_pred.shape
        shp_y = y_true.shape

        with torch.no_grad():
            y_true = y_true.long()

            if shp_x == shp_y:
                return y_true  # y_true is already in one hot encoding
            else:
                y_onehot = F.one_hot(y_true, num_classes=shp_x[1])  # one hot encoding in format [N,K,C]
                y_onehot = y_onehot.permute(0, 2, 1)  # transform to format [N,C,K]

                return y_onehot

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """
        Apply the forward pass of the generalized dice loss function.

        Parameters
        ----------
        y_pred : torch.Tensor
            The predicted logits (or probabilities if `self.softmax` is False) in shape [N,C,K] where N is the batch
            size, C the number of classes and K the number of elements/steps in each observation.
        y_true : torch.Tensor
            The target labels in shape [N,K] or as one hot encoded vector in format [N,C,K].

        Returns
        -------
        loss : torch.Tensor
            The calculated loss as scalar or in shape [N] if `self.reduction` is 'none'.

        Raises
        ------
        ValueError
            If `self.reduction` is not one of {'mean', 'sum', 'none'}.
        """

        if self.softmax:
            y_pred = F.softmax(y_pred, dim=1)

        y_onehot = self.get_onehot_encoding(y_pred, y_true)

        # calculate intersection and union and sum them over the K steps in each observation; shape [N,C]
        intersection = (y_pred * y_onehot).sum(dim=2)
        union = (y_pred + y_onehot).sum(dim=2)

        if self.use_weight:
            # class weights using the inverse of each label volume; shape [N,C]
            w = 1 / y_onehot.sum(dim=2)**2
            # if one class doesn't contain any labels, its weight is set to 1.0 (as if it would contain exactly one
            # label)
            w[torch.isinf(w)] = 1.0

            # apply the weights on intersection and union and sum over the classes; shape [N]
            intersection = (w * intersection).sum(dim=1)
            union = (w * union).sum(dim=1)
        else:
            # sum intersection and union over the classes; shape [N]
            intersection = intersection.sum(dim=1)
            union = union.sum(dim=1)

        # calculate dice coefficient and generalized dice loss using a small number to prevent zero/nan; shape [N]
        dice = (2.0 * intersection + self.smooth) / (union + self.smooth)
        gdl = 1.0 - dice

        if self.reduction == 'mean':
            gdl = gdl.mean()  # average over the batch and channel; scalar
        elif self.reduction == 'sum':
            gdl = gdl.sum()  # sum over the batch and channel; scalar
        elif self.reduction == 'none':
            pass  # unmodified losses per batch; shape [N]
        else:
            raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

        return gdl

In [40]:
def refine_spindle_list(list_of_spindles, target = False):
    refined_spindle_list = []
    for spindle in list_of_spindles:
        start = spindle[0]/(100)
        end = spindle[1]/(100)
        if ((end-start) < 0.3 and not target):
            continue
        else:
            refined_spindle_list.append((start, end))

    return refined_spindle_list



def iou(out,tar):
    out_box_start = out[0]
    out_box_end = out[1]

    tar_box_start = tar[0]
    tar_box_end = tar[1]

    overlap_start = max(out_box_start, tar_box_start)
    overlap_end = min(out_box_end, tar_box_end)
    union_start = min(out_box_start, tar_box_start)
    union_end = max(out_box_end, tar_box_end)

    return ((overlap_end - overlap_start)/(union_end-union_start))

def overlap(out, tar, threshold):
    out_box_start = out[0]
    out_box_end = out[1]

    tar_box_start = tar[0]
    tar_box_end = tar[1]

    overlap_start = max(out_box_start, tar_box_start)
    overlap_end = min(out_box_end, tar_box_end)
    union_start = min(out_box_start, tar_box_start)
    union_end = max(out_box_end, tar_box_end)

    if (overlap_end - overlap_start) >= (threshold * (tar_box_end-tar_box_start)):
        return True
    else:
        return False


In [41]:
def f1_score(outputs, targets):
    
    # Loop through batches to compute F1 score through training.

    
    F1_list = []
    temp_tp = 0
    total_spindle_count = 0
    total_pred_count = 0

    
    
    for i in range(outputs.shape[0]):

        pred = out_to_vector(outputs[i,:,:].cpu())

        pred_spindles = vector_to_spindle_list(pred)
        #print(len(pred_spindles))

        #pred_spindles = refine_spindle_list(pred_spindles)
        #print(len(pred_spindles))

        
        TP = 0

        target = targets[i]
        t_spindles = vector_to_spindle_list(target.cpu())
        #t_spindles = refine_spindle_list(t_spindles)

        total_spindle_count += len(t_spindles)
        batch_spindle_count = len(t_spindles)

        if len(t_spindles) == 0:
            spindle = False
            for l, sample in enumerate(target):
                if sample == 1:
                    spindle = True
                    print(l)
            if spindle:
                print('not found')
                print(vector_to_spindle_list(target.cpu(), debug = True))
                print(len(target))
        batch_pred_count = len(pred_spindles)
        for k in range(len(t_spindles)):
            tar_box = t_spindles[k]
            #print(tar_box)
            
            best_match = -1

            if len(pred_spindles) == 0:
                continue
            
            for j,out_box in enumerate(pred_spindles):

                if iou(out_box, tar_box) > iou(pred_spindles[best_match], tar_box):
                    best_match = j
            #print(pred_spindles[best_match])
            #print(tar_box)
            if iou(pred_spindles[best_match],tar_box) > 0.2:
                TP +=1
            

        FP = batch_pred_count - TP
        FN = batch_spindle_count - TP
        
        if (TP + FP) == 0:
            PRECISION = TP
        else:
            PRECISION = (TP)/(TP + FP)
        
        RECALL = (TP)/(TP+FN)

        if (PRECISION + RECALL) == 0:
            F1_list.append(0)
        else:
            F1_list.append((2 * PRECISION * RECALL)/(PRECISION + RECALL))
        
        temp_tp += TP


    F1_list = np.asarray(F1_list)
    #print("F1 MEAN:", np.mean(F1_list), " F1 STD:", np.std(F1_list), " TP:", temp_tp, " FP:", FP, " Number of spindles:", total_spindle_count)
    return (np.mean(F1_list), np.std(F1_list), temp_tp, FP, total_spindle_count)

In [42]:
def out_to_vector(output):
    moving_avg = 42
    s = moving_avg - 1
    vector = F.pad(output, (s // 2, s // 2 + s % 2), mode='constant', value=0)

    vector_smoothed =  F.avg_pool1d(vector, moving_avg, stride=1)
    vector_softmax = F.softmax(vector_smoothed, dim=1)
    top_p, top_class = vector_softmax.topk(1, dim = 0)

    return top_class[0]


def vector_to_spindle_list(vector, debug = False):
    
    prev_class = 0
    list_of_spindles = []
    vector = vector.numpy()
    for i, instance_class in enumerate(vector):
        if (instance_class == 1 and prev_class == 1):
            prev_class = 1
            if (i+1 == len(vector)):
                spindle.append(i)
                list_of_spindles.append(spindle)
            continue

        if (instance_class == 1 and prev_class == 0):
            spindle = []
            spindle.append(i)
            prev_class = 1
            continue

        if (instance_class == 0 and prev_class == 1):
            spindle.append(i)
            list_of_spindles.append(spindle)
            prev_class = 0
            continue

        prev_class = 0
    return list_of_spindles
        
    #print(vector)

In [None]:
dataloaders = {
    "train": trainloader,
    "validation": testloader
}

In [None]:
dataset_train = MODA_proc(input_path = '/scratch/s174411/full_segments/TRAIN/input/', label_path = '/scratch/s174411/full_segments/TRAIN/labels/')
dataset_val = MODA_proc(input_path = '/scratch/s174411/full_segments/VAL/input/', label_path = '/scratch/s174411/full_segments/VAL/labels/')

data_loader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
data_loader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

In [None]:
def train_model(model, criterion, optimizer, num_epochs=10):
    liveloss = PlotLosses()
    model = model.to(device)
    
    prev_f1 = 0
    
    for epoch in range(num_epochs):
        logs = {}
        f1_mean_run = []
        f1_std_run = []
        TP_run = []
        FP_run = []
        total_spindle_run = []
        for phase in ['train', 'validation']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                _, preds = torch.max(outputs, 1)
                running_loss += loss.detach() * inputs.size(0)
                if (epoch % 100 == 0):
                    f1_mean, f1_std, TP, FP, total_spindle_count = f1_score(outputs, labels)
                    f1_mean_run.append(f1_mean)
                    f1_std_run.append(f1_std)
                    TP_run.append(TP)
                    FP_run.append(FP)
                    total_spindle_run.append(total_spindle_count)
                    prev_f1 = sum(f1_mean_run)/len(f1_mean_run)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.float() / len(dataloaders[phase].dataset)
            
            prefix = ''
            if phase == 'validation':
                prefix = 'val_'

            logs[prefix + 'log loss'] = epoch_loss.item()
            logs[prefix + 'accuracy'] = epoch_acc.item()
        
        liveloss.update(logs)
        liveloss.send()

In [None]:
model = u_net_backbone()
criterion = GeneralizedDiceLoss()
optimizer = optim.Adam(net.parameters(), lr=0.005)
train_model(model, criterion, optimizer, num_epochs=20)

In [45]:


# Loading data, setting up GPU use, setting up variables for model training
def main(BATCH_SIZE = 12, EPOCHS = 801):
    




   

    

    training_loss = []
    validation_loss = []
    for j, epoch in enumerate(range(EPOCHS)):  # loop over the dataset multiple times
        net.train()

        running_loss = []
        
        for i, batch in enumerate(data_loader_train):
            
                
            model_input, labels = batch
            model_input = model_input.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = net(model_input)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss.append(loss.item())

            

        print(f"EPOCH:{epoch}")
        print("TRAINING")
        print("Loss: ", round(sum(running_loss)/len(running_loss), 6))
        training_loss.append(sum(running_loss)/len(running_loss))
        
        if (epoch % 100 == 0):
            print("F1 MEAN:", round(sum(f1_mean_run)/len(f1_mean_run), 6), " F1 STD:", round(sum(f1_std_run)/len(f1_std_run), 6), " TP:", sum(TP_run), " FP:", sum(FP_run),
                " Number of spindles:", sum(total_spindle_run))

        net.eval()
        
        running_loss = []
        print("VALIDATION")
        running_loss = []

        for i, batch in enumerate(data_loader_val): 
            model_input, labels = batch
            model_input = model_input.to(device)
            labels = labels.to(device)

            outputs = net(model_input)

            loss = criterion(outputs, labels)

            running_loss.append(loss.item())

            if (epoch % 100 == 0):
                f1_mean, f1_std, TP, FP, total_spindle_count = f1_score(outputs, labels)
                f1_mean_run.append(f1_mean)
                f1_std_run.append(f1_std)
                TP_run.append(TP)
                FP_run.append(FP)
                total_spindle_run.append(total_spindle_count)


        print("Loss: ", round(sum(running_loss)/len(running_loss), 6))
        if (epoch % 100 == 0):
            print(loss.item())
            print("F1 MEAN:", round(sum(f1_mean_run)/len(f1_mean_run), 6), " F1 STD:", round(sum(f1_std_run)/len(f1_std_run), 6), " TP:", sum(TP_run), " FP:", sum(FP_run),
                " Number of spindles:", sum(total_spindle_run))
        print("")
        
        validation_loss.append(sum(running_loss)/len(running_loss))



main()
#torch.save(net, '/home/marius/Documents/OneDrive/MSc/StartUP/Code/m1_stats_features.pt')
print('Finished Training')

EPOCH:0
TRAINING
Loss:  0.744606
F1 MEAN: 0.290685  F1 STD: 0.256486  TP: 1912  FP: 618  Number of spindles: 4172
VALIDATION
Loss:  0.741345
0.7907912731170654
F1 MEAN: 0.261787  F1 STD: 0.253552  TP: 347  FP: 130  Number of spindles: 842

EPOCH:1
TRAINING
Loss:  0.692779
VALIDATION
Loss:  0.729722

EPOCH:2
TRAINING
Loss:  0.691995
VALIDATION
Loss:  0.726978

EPOCH:3
TRAINING
Loss:  0.69353
VALIDATION
Loss:  0.696772

EPOCH:4
TRAINING
Loss:  0.688948
VALIDATION
Loss:  0.748856

EPOCH:5
TRAINING
Loss:  0.685562
VALIDATION
Loss:  0.737462

EPOCH:6
TRAINING
Loss:  0.683325
VALIDATION
Loss:  0.732291

EPOCH:7
TRAINING
Loss:  0.680534
VALIDATION
Loss:  0.692674

EPOCH:8
TRAINING
Loss:  0.679072
VALIDATION
Loss:  0.702669

EPOCH:9
TRAINING
Loss:  0.674556
VALIDATION
Loss:  0.72547

EPOCH:10
TRAINING
Loss:  0.674944
VALIDATION
Loss:  0.721605

EPOCH:11
TRAINING
Loss:  0.673518
VALIDATION
Loss:  0.737867

EPOCH:12
TRAINING
Loss:  0.674038
VALIDATION
Loss:  0.726156

EPOCH:13
TRAINING
Loss:  0.

KeyboardInterrupt: 