# Code for 10-fold cross-validation on the 3-class and 5-class data subsets from Aliper et al.

## requires fastai version 0.7

#### Install instructions here: https://forums.fast.ai/t/fastai-v0-7-install-issues-thread/24652

In [1]:
import matplotlib.pyplot as plt
%reload_ext autoreload
%autoreload 2
%matplotlib inline
from fastai.imports import *
from fastai.torch_imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
from random import sample
from itertools import chain
from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, average_precision_score
from sklearn.metrics import matthews_corrcoef, balanced_accuracy_score, accuracy_score
from sklearn.preprocessing import OneHotEncoder, LabelBinarizer, MultiLabelBinarizer
from sklearn.preprocessing import label_binarize
from sklearn.model_selection import StratifiedKFold
import shutil
import csv
import time

#### set some variables and set your GPU if you have multiple

In [2]:
torch.cuda.set_device(1)   ### selects GPU index 1 (the SECOND GPU); change to 0 to use the first/only GPU
PATH = "data/data_alipermodel_copy/pics/" ### directory holding the 'train' image folder, the label .csv files and the per-fold validation-index .csv files
sz = 150  ### image size in px (sz by sz) passed to the transforms and resize step
arch = resnext101_64  ### pre-trained network choice (fastai 0.7 model constructor)

## functions


In [3]:
def get_val_idx_fromfile(validx_csv):
    """Return the first column of a header-less CSV file as a Python list.

    The CV cells pass the result straight to ``get_data`` as ``val_idxs``,
    so each line is expected to hold one validation row index.
    """
    first_column = pd.read_csv(validx_csv, header=None).iloc[:, 0]
    return first_column.tolist()
    
def get_data(sz, bs, val_idxs, label_csv): # sz: image size, bs: batch size
    """Build the fastai 0.7 ``ImageClassifierData`` for one CV fold.

    Args:
        sz: target image size in px; used for the transforms and, when
            ``sz <= 300``, for the pre-resize step.
        bs: batch size.
        val_idxs: row indices of ``label_csv`` to hold out as the validation set.
        label_csv: path of the CSV mapping file names in ``PATH/train`` to labels.

    Reads the module-level ``arch`` and ``PATH``.
    """
    # top-down augmentations plus up to 10% zoom
    tfms = tfms_from_model(arch, sz, aug_tfms=transforms_top_down, max_zoom=1.1)
    data = ImageClassifierData.from_csv(PATH, 'train', label_csv, val_idxs=val_idxs,
                                        suffix='.png',
                                        tfms=tfms,
                                        num_workers=1,
                                        bs=bs)
    # Small images are pre-resized into a 'tmp' folder (the CV loops delete
    # PATH/tmp before each fold). Resize to the requested ``sz``: the previous
    # hard-coded 150 silently ignored the argument for any other size.
    return data if sz > 300 else data.resize(sz, 'tmp')

### Function to compute multiclass ROC AUC score from model predictions
def multiclass_roc_auc_score(y_test, y_probs, AVERAGE="weighted"):
    """Return the averaged one-vs-rest ROC AUC of class-probability predictions.

    Args:
        y_test: integer class labels, one per sample; label ``k`` corresponds
            to column ``k`` of ``y_probs``.
        y_probs: per-sample class scores, shape (n_samples, n_classes).
        AVERAGE: forwarded to ``roc_auc_score(average=...)``.

    Returns:
        The averaged ROC AUC as returned by ``sklearn.metrics.roc_auc_score``.
    """
    y_probs = np.asarray(y_probs)
    labels = np.asarray(y_test, dtype=int).ravel()
    # One-hot encode against the full column count of y_probs, so indicator
    # column k always lines up with probability column k even when a class is
    # missing from this fold. OneHotEncoder(categories='auto') learned its
    # columns from y_test alone, and its `sparse` keyword no longer exists in
    # recent scikit-learn. AUC stays undefined for a class with no samples;
    # roc_auc_score rejects that case with a ValueError.
    yt = np.eye(y_probs.shape[1])[labels]
    return roc_auc_score(yt, y_probs, average=AVERAGE)

### File-based validation sets won't work for nested CV
#### Make the new subfiles with nested CV label csv and adjusted indexes

In [4]:
label_csv = PATH + '12cls_aliper.csv'   # 12-class label file; the CV cells below reassign label_csv
valididx_base = '12cls_aliper_10fold'   # file-name stem of the 12-class fold index files

In [5]:
class MyFilter(object):
    """Logging filter that passes only records at or below a given level.

    Used as ``MyFilter(logging.INFO)`` it keeps INFO and DEBUG records and
    drops WARNING, ERROR and CRITICAL ones.
    """

    def __init__(self, level):
        # Highest numeric level still allowed through.
        self.__level = level

    def filter(self, logRecord):
        """Return True unless ``logRecord.levelno`` exceeds the ceiling."""
        ceiling = self.__level
        return not logRecord.levelno > ceiling

In [10]:
import logging
import os
# Results log (appended to). NOTE(review): absolute, user-specific path; it
# appears to point at the same directory as PATH -- confirm before reuse.
LOG_FILENAME = '/home/jgmeyer2/fastai/courses/dl1/data/data_alipermodel_copy/pics/10foldCV_3cls5cls_unfrozen_bs25.log'
logger = logging.getLogger()  # root logger, which the logging.info(...) calls write to

# Re-running this cell used to stack another FileHandler on the root logger,
# duplicating every log line. Drop any handler already writing to this file.
stale_handlers = [h for h in logger.handlers
                  if isinstance(h, logging.FileHandler)
                  and h.baseFilename == os.path.abspath(LOG_FILENAME)]
for stale in stale_handlers:
    logger.removeHandler(stale)
    stale.close()

fhandler = logging.FileHandler(filename=LOG_FILENAME, mode='a')
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fhandler.setFormatter(formatter)
# Only INFO-and-below records reach the file (warnings/errors are filtered out).
# Attach the filter before the handler goes live.
fhandler.addFilter(MyFilter(logging.INFO))
logger.addHandler(fhandler)
logger.setLevel(logging.INFO)
#logging.info('This message should go to the log file')

In [11]:
st = time.time()
### 10-fold CV on the 3-class subset: fold k holds out the row indices listed in
### f'{PATH}3cls_aliper_10fold{k}.csv' and a fresh pretrained model is trained on the rest.
n_classes = 3

label_csv = f'{PATH}3cls_aliper.csv'
outer_valididx_base = '3cls_aliper_10fold'

allfolds = list(range(10))
# Per-fold scores, kept at module level so a later cell can report their spread.
acc = []  # accuracy
roc = []  # weighted one-vs-rest ROC AUC
mcc = []  # Matthews correlation coefficient
bac = []  # balanced accuracy
avp = []  # weighted average precision

logging.info('starting 10-fold CV on '+str(n_classes)+' class aliper data')
for tmptest in allfolds:  ### each fold in turn is the held-out set
    # Delete the resized-image tmp folder so every fold starts clean.
    shutil.rmtree(f'{PATH}tmp', ignore_errors=True)
    val_idxs = get_val_idx_fromfile(f'{PATH}{outer_valididx_base}{tmptest}.csv')
    data = get_data(sz, 25, val_idxs, label_csv)
    learn = ConvLearner.pretrained(arch, data, precompute=False, ps=0.4)
    learn.unfreeze()
    # One learning rate per layer group; 7 SGDR cycles whose length doubles from
    # 1 epoch (1 + 2 + ... + 64 = 127 epochs in total).
    val_loss, val_acc = learn.fit([1e-4, 1e-3, 1e-2], 7, cycle_len=1, cycle_mult=2)
    ## make predictions with test-time augmentation
    log_preds, y = learn.TTA()
    # Average exp(log_preds) over axis 0 -- presumably the augmentation axis of
    # fastai 0.7's TTA output (TODO confirm).
    probs = np.mean(np.exp(log_preds), 0)
    preds = np.argmax(probs, axis=1)
    # record this fold's scores (the former `if tmptest < 12` guard was always
    # true and would have broken the ROC/AP logging below had it ever been false)
    acc.append(accuracy_score(y, preds))
    bac.append(balanced_accuracy_score(y, preds))
    mcc.append(matthews_corrcoef(y, preds))
    roc.append(multiclass_roc_auc_score(y, probs, "weighted"))
    Y = label_binarize(y, classes=list(range(n_classes)))
    avp.append(average_precision_score(Y, probs, average="weighted"))
    # Log the values just appended; [-1] does not rely on folds being 0..9 in order.
    logging.info('accuracy on fold #%s is %s', tmptest, acc[-1])
    logging.info('balanced accuracy on fold #%s is %s', tmptest, bac[-1])
    logging.info('MCC on fold #%s is %s', tmptest, mcc[-1])  # previously logged the ROC AUC
    logging.info('ROC AUC on fold #%s is %s', tmptest, roc[-1])
    logging.info('ave. prec on fold #%s is %s', tmptest, avp[-1])
    logging.info('####################################')

logging.info('mean accuracy = %s', np.mean(acc))
logging.info('mean balanced accuracy = %s', np.mean(bac))
logging.info('mean MCC is %s', np.mean(mcc))
logging.info('mean ROC AUC = %s', np.mean(roc))
logging.info('mean ave. prec =%s', np.mean(avp))
et = time.time()
print('total time ='+str(et-st))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.33237    1.182533   0.325581  
    1      1.242234   1.036044   0.44186                  
    2      1.16532    1.03958    0.55814                  
    3      1.128431   1.062678   0.44186                  
    4      1.108025   0.980189   0.581395                 
    5      1.03884    0.836923   0.627907                 
    6      0.969663   0.780915   0.72093                   
    7      0.98726    1.089871   0.465116                  
    8      0.965957   0.936695   0.581395                  
    9      0.906036   0.707208   0.674419                  
    10     0.848988   0.775071   0.604651                  
    11     0.797699   0.641664   0.72093                   
    12     0.74682    0.663385   0.744186                  
    13     0.699122   0.697147   0.674419                  
    14     0.655766   0.679011   0.697674                  
    15     0.648683   0.761751   0.72093                   
  

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.272197   0.982015   0.428571  
    1      1.247568   1.447014   0.47619                  
    2      1.195653   1.110241   0.47619                  
    3      1.129914   1.071004   0.5                      
    4      1.037555   1.024305   0.595238                 
    5      1.010983   1.07072    0.619048                 
    6      0.952051   0.997591   0.619048                  
    7      0.895588   0.95772    0.595238                  
    8      0.880688   1.026903   0.595238                  
    9      0.83905    0.926534   0.642857                  
    10     0.794788   0.61429    0.714286                  
    11     0.748639   0.649726   0.714286                  
    12     0.700538   0.739416   0.714286                  
    13     0.659772   0.680547   0.738095                  
    14     0.621283   0.684346   0.738095                  
    15     0.588622   0.733603   0.714286                  
  

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.331405   1.153692   0.341463  
    1      1.274372   0.959704   0.585366                 
    2      1.170138   0.838468   0.536585                 
    3      1.112892   1.332605   0.439024                 
    4      1.039823   0.731319   0.804878                 
    5      0.954956   0.754532   0.682927                  
    6      0.91918    0.844608   0.585366                  
    7      0.891968   0.783617   0.634146                  
    8      0.870408   0.862247   0.658537                  
    9      0.812559   1.383646   0.634146                  
    10     0.78409    1.107737   0.609756                  
    11     0.725189   1.076337   0.609756                  
    12     0.657428   0.927373   0.682927                  
    13     0.620236   0.902969   0.658537                  
    14     0.564885   0.941991   0.682927                  
    15     0.578276   0.840792   0.658537                  
 

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.35128    1.529121   0.268293  
    1      1.282343   1.334745   0.268293                 
    2      1.189425   0.979046   0.560976                 
    3      1.134534   1.118577   0.560976                 
    4      1.087612   0.946067   0.658537                 
    5      1.018673   1.059555   0.536585                 
    6      0.953649   0.989595   0.560976                  
    7      0.910536   1.416091   0.536585                  
    8      0.868756   1.257787   0.512195                  
    9      0.838897   1.186432   0.463415                  
    10     0.798923   1.074739   0.463415                  
    11     0.737663   0.940311   0.634146                  
    12     0.686051   0.97045    0.560976                  
    13     0.650702   1.002582   0.560976                  
    14     0.605689   1.015134   0.585366                  
    15     0.602867   1.345366   0.536585                  
  

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.376916   1.219812   0.317073  
    1      1.279938   1.211913   0.512195                 
    2      1.178951   1.105218   0.487805                 
    3      1.148831   1.477042   0.317073                 
    4      1.095165   0.885359   0.560976                 
    5      1.007372   0.822445   0.658537                 
    6      0.964884   0.771457   0.634146                  
    7      0.926314   1.472352   0.439024                  
    8      0.880149   0.925962   0.585366                  
    9      0.87011    1.520108   0.512195                  
    10     0.839501   1.037363   0.585366                  
    11     0.779589   1.277415   0.560976                  
    12     0.714335   1.062787   0.585366                  
    13     0.665606   1.039107   0.585366                  
    14     0.61782    1.04952    0.560976                  
    15     0.593324   1.141425   0.634146                  
  

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.330154   1.189337   0.35      
    1      1.259653   0.90024    0.475                    
    2      1.160742   1.152867   0.5                      
    3      1.109053   0.856335   0.6                      
    4      1.045416   0.871251   0.55                     
    5      0.983177   0.99913    0.5                       
    6      0.931654   0.976025   0.525                     
    7      0.867177   1.097114   0.5                       
    8      0.842021   1.232904   0.55                      
    9      0.81394    1.11697    0.575                     
    10     0.760466   1.031366   0.65                      
    11     0.71099    1.024508   0.575                     
    12     0.693658   0.990557   0.575                     
    13     0.645877   0.995818   0.625                     
    14     0.597016   0.954173   0.625                     
    15     0.581032   0.928811   0.6                       
 

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                   

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.42226    1.059665   0.45      
    1      1.291303   1.067947   0.625                    
    2      1.184924   1.143858   0.475                    
    3      1.124593   1.138874   0.475                    
    4      1.066962   1.354706   0.525                    
    5      1.001871   1.16106    0.5                      
    6      0.937518   1.025459   0.5                       
    7      0.890791   0.911921   0.575                     
    8      0.862496   1.09017    0.6                       
    9      0.84628    0.934873   0.625                     
    10     0.786914   0.886677   0.675                     
    11     0.723104   0.846469   0.7                       
    12     0.667805   0.84297    0.725                     
    13     0.636577   0.860721   0.7                       
    14     0.599994   0.874938   0.7                       
    15     0.589089   0.810378   0.725                     
  

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.239882   1.030478   0.525     
    1      1.203588   1.134891   0.45                     
    2      1.13132    1.050904   0.625                    
    3      1.070853   1.49429    0.3                      
    4      1.013703   1.227464   0.425                    
    5      0.927811   0.907901   0.6                       
    6      0.859723   0.913956   0.625                     
    7      0.831176   1.093435   0.55                      
    8      0.802332   1.034066   0.525                     
    9      0.781756   0.953111   0.625                     
    10     0.721426   1.095825   0.65                      
    11     0.666483   1.208193   0.65                      
    12     0.644823   1.205048   0.55                      
    13     0.585405   1.121635   0.6                       
    14     0.536028   1.088688   0.575                     
    15     0.529231   1.221359   0.575                     
 

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                   

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.358997   1.176124   0.45      
    1      1.26569    1.411962   0.475                    
    2      1.211435   1.257789   0.55                     
    3      1.135868   1.487721   0.375                    
    4      1.08575    1.125984   0.4                      
    5      1.01996    0.942169   0.575                    
    6      0.956706   0.888723   0.675                     
    7      0.91005    1.283418   0.475                     
    8      0.883685   1.208238   0.525                     
    9      0.870152   1.240116   0.575                     
    10     0.820728   1.046472   0.625                     
    11     0.77081    0.950528   0.7                       
    12     0.718158   1.068197   0.675                     
    13     0.674082   1.028266   0.725                     
    14     0.634993   1.035295   0.7                       
    15     0.610477   1.339106   0.575                     
  

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                   

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.421583   1.318831   0.375     
    1      1.287296   1.116651   0.375                    
    2      1.181603   1.728942   0.325                    
    3      1.121976   1.218252   0.425                    
    4      1.046076   1.297382   0.325                    
    5      0.982551   1.252276   0.5                       
    6      0.910741   1.272672   0.5                       
    7      0.874247   1.355644   0.325                     
    8      0.834355   1.276115   0.5                       
    9      0.795609   1.225092   0.55                      
    10     0.762797   1.242181   0.4                       
    11     0.676562   1.504611   0.425                     
    12     0.628928   1.425775   0.475                     
    13     0.581008   1.352081   0.525                     
    14     0.560269   1.37573    0.55                      
    15     0.555911   1.299652   0.575                     
 

In [13]:
# Log the population standard deviation (numpy default, ddof=0) of each
# per-fold metric list filled by the most recently run CV cell.
for _label, _scores in (('stdev accuracy = ', acc),
                        ('stdev balanced accuracy = ', bac),
                        ('stdev MCC is ', mcc),
                        ('stdev ROC AUC = ', roc),
                        ('stdev ave. prec =', avp)):
    logging.info(_label + str(np.std(_scores)))
del _label, _scores

In [14]:
st = time.time()
### 10-fold CV on the 5-class subset: fold k holds out the row indices listed in
### f'{PATH}5cls_aliper_10fold{k}.csv' and a fresh pretrained model is trained on the rest.
n_classes = 5

label_csv = f'{PATH}5cls_aliper.csv'
outer_valididx_base = '5cls_aliper_10fold'

allfolds = list(range(10))
# Per-fold scores, kept at module level so they remain inspectable afterwards.
acc = []  # accuracy
roc = []  # weighted one-vs-rest ROC AUC
mcc = []  # Matthews correlation coefficient
bac = []  # balanced accuracy
avp = []  # weighted average precision

logging.info('starting 10-fold CV on '+str(n_classes)+' class aliper data')
for tmptest in allfolds:  ### each fold in turn is the held-out set
    # Delete the resized-image tmp folder so every fold starts clean.
    shutil.rmtree(f'{PATH}tmp', ignore_errors=True)
    val_idxs = get_val_idx_fromfile(f'{PATH}{outer_valididx_base}{tmptest}.csv')
    data = get_data(sz, 25, val_idxs, label_csv)
    learn = ConvLearner.pretrained(arch, data, precompute=False, ps=0.4)
    learn.unfreeze()
    # One learning rate per layer group; 7 SGDR cycles whose length doubles from
    # 1 epoch (1 + 2 + ... + 64 = 127 epochs in total).
    val_loss, val_acc = learn.fit([1e-4, 1e-3, 1e-2], 7, cycle_len=1, cycle_mult=2)
    ## make predictions with test-time augmentation
    log_preds, y = learn.TTA()
    # Average exp(log_preds) over axis 0 -- presumably the augmentation axis of
    # fastai 0.7's TTA output (TODO confirm).
    probs = np.mean(np.exp(log_preds), 0)
    preds = np.argmax(probs, axis=1)
    # record this fold's scores (the former `if tmptest < 12` guard was always
    # true and would have broken the ROC/AP logging below had it ever been false)
    acc.append(accuracy_score(y, preds))
    bac.append(balanced_accuracy_score(y, preds))
    mcc.append(matthews_corrcoef(y, preds))
    roc.append(multiclass_roc_auc_score(y, probs, "weighted"))
    Y = label_binarize(y, classes=list(range(n_classes)))
    avp.append(average_precision_score(Y, probs, average="weighted"))
    # Log the values just appended; [-1] does not rely on folds being 0..9 in order.
    logging.info('accuracy on fold #%s is %s', tmptest, acc[-1])
    logging.info('balanced accuracy on fold #%s is %s', tmptest, bac[-1])
    logging.info('MCC on fold #%s is %s', tmptest, mcc[-1])  # previously logged the ROC AUC
    logging.info('ROC AUC on fold #%s is %s', tmptest, roc[-1])
    logging.info('ave. prec on fold #%s is %s', tmptest, avp[-1])
    logging.info('####################################')

logging.info('mean accuracy = %s', np.mean(acc))
logging.info('mean balanced accuracy = %s', np.mean(bac))
logging.info('mean MCC is %s', np.mean(mcc))
logging.info('mean ROC AUC = %s', np.mean(roc))
logging.info('mean ave. prec =%s', np.mean(avp))
et = time.time()
print('total time ='+str(et-st))
# Population standard deviation (numpy default, ddof=0) across the 10 folds.
logging.info('stdev accuracy = %s', np.std(acc))
logging.info('stdev balanced accuracy = %s', np.std(bac))
logging.info('stdev MCC is %s', np.std(mcc))
logging.info('stdev ROC AUC = %s', np.std(roc))
logging.info('stdev ave. prec =%s', np.std(avp))

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.889645   1.637799   0.278689  
    1      1.76813    1.405448   0.409836                 
    2      1.609396   1.315846   0.459016                 
    3      1.509758   1.42346    0.491803                 
    4      1.421716   1.284302   0.491803                 
    5      1.340893   1.237087   0.557377                 
    6      1.251067   1.208587   0.57377                  
    7      1.215145   1.382799   0.508197                 
    8      1.188842   1.360063   0.491803                 
    9      1.127268   1.272044   0.557377                 
    10     1.064204   1.266763   0.57377                  
    11     0.998076   1.218138   0.639344                  
    12     0.920691   1.148157   0.590164                  
    13     0.867362   1.114258   0.590164                  
    14     0.803558   1.109565   0.590164                  
    15     0.790104   1.315387   0.57377                   
    16 

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.86224    1.663197   0.288136  
    1      1.821454   1.963716   0.355932                 
    2      1.664876   1.61977    0.440678                 
    3      1.590297   1.593113   0.389831                 
    4      1.500241   1.421755   0.474576                 
    5      1.39116    1.402345   0.355932                 
    6      1.290175   1.366235   0.389831                 
    7      1.256614   1.540705   0.40678                  
    8      1.238264   1.547811   0.40678                  
    9      1.205781   1.517795   0.440678                 
    10     1.145049   1.422203   0.542373                 
    11     1.05601    1.300098   0.525424                 
    12     0.98113    1.336449   0.525424                  
    13     0.917456   1.323893   0.542373                  
    14     0.858866   1.300095   0.542373                  
    15     0.860163   1.569393   0.542373                  
    16  

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.832946   1.436189   0.310345  
    1      1.722388   1.347572   0.5                      
    2      1.59294    1.294645   0.448276                 
    3      1.532321   1.435608   0.465517                 
    4      1.466869   1.845451   0.37931                  
    5      1.371083   1.322512   0.517241                 
    6      1.273347   1.329843   0.551724                 
    7      1.2213     1.2436     0.465517                 
    8      1.191385   1.107704   0.603448                 
    9      1.158168   1.264614   0.568966                 
    10     1.094347   1.002343   0.603448                 
    11     1.008053   1.087994   0.534483                 
    12     0.939765   1.026714   0.603448                  
    13     0.873602   1.006333   0.586207                  
    14     0.829485   1.008974   0.603448                  
    15     0.818616   1.243298   0.586207                  
    16  

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.886982   1.459103   0.37931   
    1      1.730763   1.79106    0.396552                 
    2      1.591658   1.575943   0.413793                 
    3      1.529785   1.779967   0.344828                 
    4      1.456125   1.67585    0.413793                 
    5      1.368394   1.551543   0.482759                 
    6      1.289537   1.507804   0.465517                 
    7      1.255648   1.668459   0.482759                 
    8      1.231481   1.634015   0.448276                 
    9      1.17415    1.40565    0.5                      
    10     1.126247   1.379995   0.5                      
    11     1.048039   1.443368   0.534483                 
    12     0.942483   1.374192   0.5                       
    13     0.874931   1.407154   0.534483                  
    14     0.82572    1.335848   0.534483                  
    15     0.791775   1.431308   0.482759                  
    16  

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.856968   1.706166   0.258621  
    1      1.766647   1.734242   0.293103                 
    2      1.666246   1.802086   0.327586                 
    3      1.59432    1.619456   0.275862                 
    4      1.490131   1.375058   0.517241                 
    5      1.376918   1.310986   0.396552                 
    6      1.288555   1.291183   0.413793                 
    7      1.258652   1.434204   0.5                      
    8      1.234799   1.38731    0.448276                 
    9      1.156071   1.429748   0.534483                 
    10     1.106765   1.228397   0.551724                 
    11     1.022862   1.222299   0.603448                 
    12     0.933354   1.323115   0.534483                  
    13     0.871791   1.28578    0.586207                  
    14     0.834364   1.272107   0.551724                  
    15     0.826736   1.356907   0.431034                  
    16  

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                   

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.85146    1.459736   0.473684  
    1      1.779969   1.893673   0.280702                 
    2      1.625536   1.322558   0.491228                 
    3      1.521959   1.490816   0.438596                 
    4      1.446566   1.237725   0.45614                  
    5      1.351388   1.246888   0.508772                 
    6      1.242457   1.255843   0.526316                 
    7      1.199439   1.251838   0.54386                  
    8      1.176239   1.519659   0.45614                  
    9      1.143298   1.079318   0.578947                 
    10     1.089635   1.416586   0.526316                 
    11     0.980762   1.323693   0.54386                   
    12     0.894956   1.275846   0.596491                  
    13     0.834575   1.195875   0.578947                  
    14     0.790185   1.221633   0.526316                  
    15     0.778406   1.403723   0.508772                  
    16 

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.840385   1.686902   0.333333  
    1      1.720663   1.627987   0.315789                 
    2      1.597795   1.411192   0.45614                  
    3      1.549256   1.596984   0.368421                 
    4      1.454822   1.464786   0.421053                 
    5      1.351838   1.32896    0.438596                 
    6      1.260128   1.219371   0.491228                 
    7      1.228867   1.304281   0.54386                  
    8      1.177685   1.239626   0.508772                 
    9      1.172891   1.096175   0.526316                 
    10     1.121852   1.294653   0.491228                 
    11     1.03207    1.089663   0.631579                 
    12     0.959958   1.016267   0.649123                  
    13     0.896323   0.996335   0.631579                  
    14     0.828995   1.006946   0.631579                  
    15     0.842208   1.596009   0.421053                  
    16  

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                   

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.887508   1.784997   0.263158  
    1      1.780017   1.478469   0.368421                 
    2      1.632637   1.441889   0.333333                 
    3      1.525198   1.650553   0.368421                 
    4      1.399484   1.452921   0.45614                  
    5      1.284086   1.382843   0.45614                  
    6      1.186512   1.376789   0.438596                 
    7      1.155203   1.497422   0.438597                 
    8      1.11151    1.509397   0.438596                 
    9      1.093637   1.316698   0.54386                  
    10     1.012071   1.326837   0.491228                 
    11     0.949062   1.312884   0.45614                   
    12     0.87529    1.316353   0.473684                  
    13     0.830964   1.378023   0.473684                  
    14     0.776814   1.395588   0.45614                   
    15     0.771694   1.649031   0.508772                  
    16 

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                   

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.870314   1.680985   0.315789  
    1      1.771994   1.888845   0.333333                 
    2      1.602143   1.617894   0.45614                  
    3      1.494313   1.674847   0.368421                 
    4      1.445841   1.635487   0.491228                 
    5      1.363792   1.296807   0.561404                 
    6      1.276425   1.235157   0.578947                 
    7      1.236872   1.201212   0.596491                 
    8      1.21025    1.094104   0.578947                 
    9      1.159604   1.034381   0.561404                 
    10     1.094035   1.055413   0.578947                 
    11     0.999053   1.310094   0.526316                  
    12     0.92314    1.148141   0.578947                  
    13     0.855987   1.128658   0.561404                  
    14     0.798165   1.090423   0.561404                  
    15     0.819711   1.391928   0.438596                  
    16 

HBox(children=(IntProgress(value=0, max=6), HTML(value='')))

                                                  

HBox(children=(IntProgress(value=0, description='Epoch', max=127), HTML(value='')))

epoch      trn_loss   val_loss   accuracy                 
    0      1.887723   1.739378   0.298246  
    1      1.721952   1.666502   0.350877                 
    2      1.614175   1.881739   0.368421                 
    3      1.532763   1.765835   0.385965                 
    4      1.430803   1.593021   0.368421                 
    5      1.328106   1.551855   0.403509                 
    6      1.234365   1.536317   0.45614                  
    7      1.169341   1.48802    0.438596                 
    8      1.125476   1.76174    0.45614                  
    9      1.091227   1.548561   0.508772                 
    10     1.021756   1.42324    0.473684                 
    11     0.927802   1.400114   0.45614                   
    12     0.854283   1.416968   0.508772                  
    13     0.773117   1.428227   0.473684                  
    14     0.71949    1.42803    0.473684                  
    15     0.699908   1.716186   0.421053                  
    16 