In [1]:
import pandas as pd
import numpy as np
import glob
import os
from pathlib import Path
from PIL import Image, ExifTags

from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.utils import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error, accuracy_score
import scipy

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from tensorflow.keras.preprocessing.image import img_to_array, load_img, ImageDataGenerator
from tensorflow.keras.layers import Activation, Dense, GlobalAveragePooling2D, GlobalMaxPooling2D
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras import optimizers
import tensorflow.keras.backend as K
from tensorflow.keras.callbacks import LambdaCallback

#salmon-scales
#from train_util import read_images, load_xy, get_checkpoint_tensorboard, create_model_grayscale, get_fresh_weights, base_output, dense1_linear_output, train_validate_test_split


In [15]:
#!pip install plotly
#!pip install torch
#!pip install loguru
#!pip install timm #PyTorch Image Models
#!pip install albumentations #augmentation
#!pip install colorama #color terminal

Collecting colorama
  Downloading colorama-0.4.4-py2.py3-none-any.whl (16 kB)
Installing collected packages: colorama
Successfully installed colorama-0.4.4


In [16]:
import os
import gc
import copy
import time
import random

import numpy as np
import pandas as pd
import plotly.graph_objects as go

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch.cuda import amp

from tqdm import tqdm
from collections import defaultdict

from loguru import logger

from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

from colorama import Fore
b_ = Fore.BLUE

### Train Configuration

In [18]:
class CONFIG:
    seed = 42
    model_name = 'tf_efficientnetv2_s_in21k' 
    train_batch_size = 32
    valid_batch_size = 64
    img_size = 512
    epochs = 5
    learning_rate = 1e-4
    min_lr = 1e-6
    weight_decay = 1e-6
    T_max = 10
    scheduler = 'CosineAnnealingLR'
    n_accumulate = 1
    n_fold = 5
    target_size = 1
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
def set_seed(seed = 42):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.'''
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)
    
set_seed(CONFIG.seed)    

### Cod Model

In [29]:
class codModel(nn.Module):
    def __init__(self, model_name, pretrained=True):
        super(codModel, self).__init__()
        ## not 3 channels: https://fastai.github.io/timmdocs/models
        self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=3)
        self.n_features = self.model.classifier.in_features
        self.model.classifier = nn.Linear(self.n_features, CONFIG.target_size)

    def forward(self, x):
        output = self.model(x)
        return output
    
model = codModel(CONFIG.model_name)
model.to(CONFIG.device);

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnetv2_s_21k-6337ad01.pth


RTX A6000 with CUDA capability sm_86 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70.
If you want to use the RTX A6000 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/




In [33]:
def read_jpg_cods2(B4_input_shape = (380, 380, 3), max_dataset_size = 5180, whichExposure='min'):
    #    '''
    #    reads one .jpg file in each folder in structure of folders
    #    returns tensor with images, and 1-1 correspondence with age
    #    '''

    #max_dataset_size = 5156
    #B4_input_shape = (380, 380, 3)
    df_cod = pd.DataFrame(columns=['age', 'path', 'ExposureTime'])
    image_tensor1 = np.empty(shape=(max_dataset_size,)+B4_input_shape)
    image_tensor2 = np.empty(shape=(max_dataset_size,)+B4_input_shape)
    image_tensor3 = np.empty(shape=(max_dataset_size,)+B4_input_shape)

    base_dir = '/gpfs/gpfs0/deep/data/Savannah_Professional_Practice2021_06_10_21/CodOtholiths-MachineLearning/Savannah_Professional_Practice'
    df_cod = pd.DataFrame(columns=['age', 'path', 'ExposureTime'])
    base_dirs_posix = Path(base_dir)

    error_count=0
    add_count = 0
    for some_year_dir in base_dirs_posix.iterdir():
        count = 0
        ## terminate quickly for testing
        if count > 0: 
            break

        if not os.path.isdir( some_year_dir ) or "Extra" in str(some_year_dir):
            continue

        #dir structure: /year/station_number/cod_img_by_age/6 jpeg images of one fish
        stat_nos = [name for name in os.listdir( some_year_dir ) if os.path.isdir(os.path.join(some_year_dir , name))]
        for i in range(0, len(stat_nos)):
            cod_path = os.path.join( some_year_dir, stat_nos[i] )
            yr_station_codage_path = [os.path.join(cod_path , n) for n in os.listdir( cod_path ) 
                            if os.path.isdir(os.path.join(cod_path , n))]
            cod_age = [n for n in os.listdir( cod_path ) 
                            if os.path.isdir(os.path.join(cod_path , n))]

            assert len(yr_station_codage_path) == len(cod_age)
            for j in range(0, len(yr_station_codage_path)):
                #print(onlyfiles)
                onlyfiles = [f for f in os.listdir( yr_station_codage_path[j] ) 
                             if os.path.isfile(os.path.join(yr_station_codage_path[j] , f))]

                #2013/70028/Nr01_age05/Thumbs.db
                #2016/70008/Nr01_age07/Thumbs.db
                if len(onlyfiles) != 6: 
                    #print(str(len(onlyfiles)) + '\t' + str( yr_station_codage_path[j] ) + "\t" +'\t'.join(map(str,onlyfiles)))
                    error_count +=1
                else: 
                    full_path = [os.path.join(yr_station_codage_path[j] , f) 
                             for f in os.listdir( yr_station_codage_path[j] ) 
                         if os.path.isfile(os.path.join(yr_station_codage_path[j] , f))]

                    begin_age = cod_age[j].lower().find('age')
                    #print(cod_age[j])
                    age = cod_age[j][begin_age+3:begin_age+5]
                    try:
                        age = int(age)
                    except ValueError:
                        #print(yr_station_codage_path[j])
                        #print(cod_age[j])
                        #print(age)
                        #print(begin_age)
                        age = 0
                        continue

                    #print(age)

                    full_path.sort()
                    exposures_set = set()
                    exposures_list = []
                    for k in range(0, len(full_path)): #len(full_path) == 6
                        img = Image.open(full_path[k])
                        exif = {ExifTags.TAGS[k]: v for k, v in img._getexif().items() if k in ExifTags.TAGS}
                        #print(exif['ExposureTime'])
                        exposures_set.add( exif['ExposureTime'] )
                        exposures_list.append( exif['ExposureTime'] )


                    #if len(exposures_set) != 3:
                        #print("\t"+str (yr_station_codage_path[j] ) + '\t' + str(exposures_list) ) 
                    #    continue
                    #else:
                    if len(exposures_list) == 6 and len(exposures_set) == 3:

                        expo_args = np.argsort(exposures_list).tolist()
                        #print( "exposures_list"+str(exposures_list) )
                        #print(" argsort: "+str(expo_args) )

                        numpy_images = [0,0,0]
                        file_paths = [0,0,0]
                        imgs_added = 0
                        
                        #use if loading to memory
                        """
                        for k in [0,2,4]:
                            img = Image.open( full_path[ expo_args[k] ] ) 
                            pil_img = load_img(full_path[ expo_args[k] ], target_size=B4_input_shape, grayscale=False)
                            array_img = img_to_array(pil_img, data_format='channels_last')

                            numpy_images[imgs_added] = array_img
                            file_paths[imgs_added] = full_path[ expo_args[k] ]
                            imgs_added += 1
                        """
                        

                        if expo_args != [1, 4, 0, 3, 2, 5]:
                            print( "exposures_list"+str(exposures_list) )
                            print(" argsort: "+str(expo_args) )
                            #print(file_paths)

                        if whichExposure == 'min':
                            #use if loading to memory
                            ##pil_img = load_img(full_path[ expo_args[0] ], target_size=B4_input_shape, grayscale=False)
                            ##array_img = img_to_array(pil_img, data_format='channels_last')
                            ##image_tensor1[add_count] = array_img
                            add_count += 1
                            
                            df_cod = df_cod.append({'age':age, 'path':full_path[expo_args[0]], 'light': 1, 'ExposureTime':exposures_list[expo_args[0]]}, ignore_index=True)
                        if whichExposure == 'middle':
                            #use if loading to memory
                            ##pil_img = load_img(full_path[ expo_args[2] ], target_size=B4_input_shape, grayscale=False)
                            ##array_img = img_to_array(pil_img, data_format='channels_last')
                            ##image_tensor1[add_count] = array_img
                            add_count += 1
                            
                            df_cod = df_cod.append({'age':age, 'path':full_path[expo_args[2]], 'light': 2, 'ExposureTime':exposures_list[expo_args[0]]}, ignore_index=True)
                        if whichExposure == 'max':
                            #use if loading to memory
                            ##pil_img = load_img(full_path[ expo_args[4] ], target_size=B4_input_shape, grayscale=False)
                            ##array_img = img_to_array(pil_img, data_format='channels_last')
                            ##image_tensor1[add_count] = array_img
                            add_count += 1
                            
                            df_cod = df_cod.append({'age':age, 'path':full_path[expo_args[4]], 'light': 3, 'ExposureTime':exposures_list[expo_args[0]]}, ignore_index=True)

                                        

    print("error_count:"+str(error_count))
    print("add_count:"+str(add_count))

    '''
    if whichExposure == 'min':
        return image_tensor1, df_cod.age        
    if whichExposure == 'middle':
        return image_tensor2, df_cod.age
    if whichExposure == 'max':
        return image_tensor3, df_cod.age
    '''    
    return df_cod
    #return None, None



In [95]:
pil_img = load_img(df.path[0], target_size=B4_input_shape, grayscale=False)
array_img = img_to_array(pil_img, data_format='channels_last')
image = array_img

#image = np.load(df.path[0],allow_pickle=True).astype(np.float32)
print(image.shape)
image = (image - image.mean(axis=(1,2), keepdims=True)) / image.std(axis=(1,2), keepdims=True)
print(image.shape)
#image = np.vstack(image).transpose((1, 0))
print(image.shape)
label = torch.tensor(df.age[0]).float()
print(image.shape) #(3, 144400)
print(label)

(380, 380, 3)
(380, 380, 3)
(380, 380, 3)
(380, 380, 3)
tensor(2.)


In [96]:
class codDataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        self.file_names = df['path'].values
        self.labels = df['age'].values
        self.transforms = transforms
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, index):
        print("loading filename:"+str(self.file_names[index]))
        #image = np.load(self.file_names[index]).astype(np.float32)
        B4_input_shape = (380, 380, 3)
        pil_img = load_img(df.path[0], target_size=B4_input_shape, grayscale=False)
        image = img_to_array(pil_img, data_format='channels_last')
        
        image = (image - image.mean(axis=(1,2), keepdims=True)) / image.std(axis=(1,2), keepdims=True)
        #image = np.vstack(image).transpose((1, 0))
        label = torch.tensor(self.labels[index]).float()
        
        if self.transforms:
            image = self.transforms(image=image)["image"]
            
        return image, label

In [97]:
B4_input_shape = (380, 380, 3)
B5_input_shape = (456, 456, 3)

df = read_jpg_cods2(B4_input_shape, max_dataset_size = 9180) #5316 #5110

print("len age:"+str( len(df.age) ) )


exposures_list[0.4, 0.2, 0.1, 0.4, 0.2, 0.1]
 argsort: [2, 5, 1, 4, 0, 3]
exposures_list[0.4, 0.2, 0.1, 0.4, 0.2, 0.1]
 argsort: [2, 5, 1, 4, 0, 3]
exposures_list[0.4, 0.2, 0.1, 0.4, 0.2, 0.1]
 argsort: [2, 5, 1, 4, 0, 3]
exposures_list[0.4, 0.2, 0.1, 0.4, 0.2, 0.1]
 argsort: [2, 5, 1, 4, 0, 3]
exposures_list[0.4, 0.2, 0.1, 0.4, 0.2, 0.1]
 argsort: [2, 5, 1, 4, 0, 3]
exposures_list[0.1, 0.4, 0.2, 0.1, 0.4, 0.2]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.2, 0.1, 0.4, 0.1, 0.4, 0.2]
 argsort: [1, 3, 0, 5, 2, 4]
exposures_list[0.1, 0.4, 0.2, 0.1, 0.4, 0.2]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.1, 0.4, 0.2, 0.1, 0.4, 0.2]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.1, 0.4, 0.2, 0.1, 0.4, 0.2]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.1, 0.4, 0.2, 0.1, 0.4, 0.2]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.1, 0.4, 0.2, 0.1, 0.4, 0.2]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.1, 0.4, 0.2, 0.1, 0.4, 0.2]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.1, 0.4, 0.2, 0.1, 0.4

exposures_list[0.16666666666666666, 0.6, 0.3, 0.16666666666666666, 0.6, 0.3]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.16666666666666666, 0.6, 0.3, 0.16666666666666666, 0.6, 0.3]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.16666666666666666, 0.6, 0.3, 0.16666666666666666, 0.6, 0.3]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.16666666666666666, 0.6, 0.3, 0.16666666666666666, 0.6, 0.3]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.16666666666666666, 0.6, 0.3, 0.16666666666666666, 0.6, 0.3]
 argsort: [0, 3, 2, 5, 1, 4]
exposures_list[0.06666666666666667, 0.03333333333333333, 0.016666666666666666, 0.06666666666666667, 0.03333333333333333, 0.016666666666666666]
 argsort: [2, 5, 1, 4, 0, 3]
exposures_list[0.125, 0.06666666666666667, 0.03333333333333333, 0.125, 0.06666666666666667, 0.03333333333333333]
 argsort: [2, 5, 1, 4, 0, 3]
exposures_list[0.125, 0.06666666666666667, 0.03333333333333333, 0.125, 0.06666666666666667, 0.03333333333333333]
 argsort: [2, 5, 1, 4, 0, 3]
error_count:20

### Augmentation

In [35]:
data_transforms = {
    "train": A.Compose([
        A.Resize(CONFIG.img_size, CONFIG.img_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.1, 
                           scale_limit=0.15, 
                           rotate_limit=60, 
                           p=0.5),
        A.CoarseDropout(p=0.5),
        A.Cutout(p=0.5),
        ToTensorV2()], p=1.),
    
    "valid": A.Compose([
        A.Resize(CONFIG.img_size, CONFIG.img_size),
        ToTensorV2()], p=1.)
}


This class has been deprecated. Please use CoarseDropout



In [37]:
df.age


0       2
1       7
2       5
3       1
4       8
       ..
5105    4
5106    8
5107    2
5108    4
5109    5
Name: age, Length: 5110, dtype: object

### Kfolds

In [60]:
#df = pd.read_csv(f"{ROOT_DIR}/train_labels.csv")
skf = StratifiedKFold(n_splits=CONFIG.n_fold, shuffle=True, random_state=CONFIG.seed)
y_train = df.age.values

for fold, ( train_idx, val_idx) in enumerate( skf.split( X=df, y=df.age.values.tolist() ) ):
    df.loc[val_idx , "kfold"] = int(fold)
    #df.loc[train_idx , "kfold"] = int(fold)
    
df['kfold'] = df['kfold'].astype(int)
print(df.kfold)

0       0
1       1
2       1
3       2
4       3
       ..
5105    1
5106    1
5107    1
5108    4
5109    0
Name: kfold, Length: 5110, dtype: int64



The least populated class in y has only 1 members, which is less than n_splits=5.



### Prepare data

In [66]:
def prepare_data(fold):
    df_train = df[df.kfold != fold].reset_index(drop=True)
    df_valid = df[df.kfold == fold].reset_index(drop=True)
    
    train_dataset = codDataset(df_train, transforms=data_transforms['train'])
    valid_dataset = codDataset(df_valid, transforms=data_transforms['valid'])

    train_loader = DataLoader(train_dataset, batch_size=CONFIG.train_batch_size, 
                              num_workers=4, shuffle=True, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=CONFIG.valid_batch_size, 
                              num_workers=4, shuffle=False, pin_memory=True)
    
    return train_loader, valid_loader

### Prepare dataloader

In [68]:
train_loader, valid_loader = prepare_data(fold=0)
print(train_loader)

<torch.utils.data.dataloader.DataLoader object at 0x7fe75c255748>


### Loss function

In [70]:
#def binary_accuracy_for_regression(y_true, y_pred):
#    return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)

loss = nn.MSELoss()
loss_fn = torch.nn.MSELoss(reduction='sum')
learning_rate = 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)

learning_rate = 1e-3
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)

MSELoss()


### Training function

In [75]:
def train_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    model.train()
    scaler = amp.GradScaler()
    
    dataset_size = 0
    running_loss = 0.0
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, (images, labels) in bar:         
        images = images.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        
        inputs, targets_a, targets_b, lam = mixup_data(images, labels.view(-1, 1))
        
        inputs = inputs.to(device, dtype=torch.float)
        targets_a = targets_a.to(device, dtype=torch.float)
        targets_b = targets_b.to(device, dtype=torch.float)
        
        batch_size = images.size(0)
        
        with amp.autocast(enabled=True):
            outputs = model(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            loss = loss / CONFIG.n_accumulate
            
        scaler.scale(loss).backward()
        
        if (step + 1) % CONFIG.n_accumulate == 0:
            scaler.step(optimizer)
            scaler.update()
            
            # zero the parameter gradients
            optimizer.zero_grad()
                
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss/dataset_size
        
        bar.set_postfix(Epoch=epoch, Train_Loss=epoch_loss,
                        LR=optimizer.param_groups[0]['lr'])
    gc.collect()
    
    return epoch_loss

### Validation function

In [76]:
@torch.no_grad()
def valid_one_epoch(model, optimizer, scheduler, dataloader, device, epoch):
    model.eval()
    
    dataset_size = 0
    running_loss = 0.0
    
    TARGETS = []
    PREDS = []
    
    bar = tqdm(enumerate(dataloader), total=len(dataloader))
    for step, (images, labels) in bar:        
        images = images.to(device)
        labels = labels.to(device)
        
        batch_size = images.size(0)
        
        outputs = model(images)
        loss = criterion(outputs.view(-1), labels)
        
        running_loss += (loss.item() * batch_size)
        dataset_size += batch_size
        
        epoch_loss = running_loss/dataset_size
        
        PREDS.append(outputs.sigmoid().cpu().detach().numpy())
        TARGETS.append(labels.view(-1).cpu().detach().numpy())
        
        bar.set_postfix(Epoch=epoch, Valid_Loss=epoch_loss,
                        LR=optimizer.param_groups[0]['lr'])   
    
    TARGETS = np.concatenate(TARGETS)
    PREDS = np.concatenate(PREDS)
    val_auc = roc_auc_score(TARGETS, PREDS)
    gc.collect()
    
    return epoch_loss, val_auc

### Run

In [72]:
@logger.catch
def run(model, optimizer, scheduler, device, num_epochs):    
    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_epoch_auc = 0
    history = defaultdict(list)
    
    for epoch in range(1, num_epochs + 1): 
        gc.collect()
        train_epoch_loss = train_one_epoch(model, optimizer, scheduler, 
                                           dataloader=train_loader, 
                                           device=CONFIG.device, epoch=epoch)
        
        valid_epoch_loss, valid_epoch_auc = valid_one_epoch(model, optimizer, scheduler,
                                                            dataloader=valid_loader, 
                                                            device=CONFIG.device, epoch=epoch)
    
        history['Train Loss'].append(train_epoch_loss)
        history['Valid Loss'].append(valid_epoch_loss)
        history['Valid AUC'].append(valid_epoch_auc)
        
        print(f'Valid AUC: {valid_epoch_auc}')
        
        if scheduler is not None:
            scheduler.step()
        
        # deep copy the model
        if valid_epoch_auc >= best_epoch_auc:
            print(f"{b_}Validation AUC Improved ({best_epoch_auc} ---> {valid_epoch_auc})")
            best_epoch_auc = valid_epoch_auc
            best_model_wts = copy.deepcopy(model.state_dict())
            PATH = "AUC{:.4f}_epoch{:.0f}.bin".format(best_epoch_auc, epoch)
            torch.save(model.state_dict(), PATH)
            print("Model Saved")
            
        print()
    
    end = time.time()
    time_elapsed = end - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))
    print("Best AUC: {:.4f}".format(best_epoch_auc))
    
    # load best model weights
    model.load_state_dict(best_model_wts)
    
    return model, history

### Train fold 0

In [79]:
#torch.cuda.is_available() 

True

In [98]:
def fetch_scheduler(optimizer):
    if CONFIG.scheduler == 'CosineAnnealingLR':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG.T_max, eta_min=CONFIG.min_lr)
    elif CONFIG.scheduler == 'CosineAnnealingWarmRestarts':
        scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=CONFIG.T_0, T_mult=1, eta_min=CONFIG.min_lr)
    elif CONFIG.scheduler == None:
        return None
        
    return scheduler

optimizer = optim.Adam(model.parameters(), lr=CONFIG.learning_rate, weight_decay=CONFIG.weight_decay)
scheduler = fetch_scheduler(optimizer)

print(CONFIG.device)
print(CONFIG.epochs)
model, history = run(model, optimizer, scheduler=scheduler, device=CONFIG.device, num_epochs=CONFIG.epochs)

cuda:0
5


  0%|          | 0/128 [00:00<?, ?it/s]
2021-11-11 14:35:53.201 | ERROR    | __main__:<module>:16 - An error has been caught in function '<module>', process 'MainProcess' (2576), thread 'MainThread' (140636972394304):
Traceback (most recent call last):

  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
                └ ModuleSpec(name='ipykernel_launcher', loader=<_frozen_importlib_external.SourceFileLoader object at 0x7fe897647d30>, origin='...
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
         │     └ {'__name__': '__main__', '__doc__': 'Entry point for launching an IPython kernel.\n\nThis is separate from the ipykernel pack...
         └ <code object <module> at 0x7fe8976a04b0, file "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py", line 5>

  File "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
    │   └ <bound

TypeError: 'NoneType' object is not iterable