# Incremental learning on image classification

## Libraries and packages


In [0]:
# Install the pinned framework versions used by this notebook.
!pip3 install 'torch==1.4.0'
!pip3 install 'torchvision==0.5.0'
# NOTE(review): Pillow-SIMD and tqdm are installed without a version pin, so a
# fresh run may pull a different release than the one originally used.
!pip3 install 'Pillow-SIMD'
!pip3 install 'tqdm'

In [0]:
import os
import urllib
import logging

import numpy as np

import torch
import torch.nn as nn
import torch.nn.init as init
import torch.optim as optim
from torch.utils.data import Subset, DataLoader, ConcatDataset
from torch.backends import cudnn

import torchvision
from torchvision import transforms
from torchvision.models import resnet34

from PIL import Image
from tqdm import tqdm

from copy import deepcopy

In [3]:
# GitHub credentials for cloning private repository
username = ''
password = ''

# Download packages from repository
password = urllib.parse.quote(password)
!git clone https://$username:$password@github.com/manuelemacchia/incremental-learning-image-classification.git
password = ''

!mv -v incremental-learning-image-classification/* .
!rm -rf incremental-learning-image-classification README.md

Cloning into 'incremental-learning-image-classification'...
remote: Enumerating objects: 138, done.[K
remote: Counting objects: 100% (138/138), done.[K
remote: Compressing objects: 100% (118/118), done.[K
remote: Total 369 (delta 66), reused 49 (delta 17), pack-reused 231[K
Receiving objects: 100% (369/369), 2.29 MiB | 12.08 MiB/s, done.
Resolving deltas: 100% (179/179), done.
renamed 'incremental-learning-image-classification/data' -> './data'
renamed 'incremental-learning-image-classification/model' -> './model'
renamed 'incremental-learning-image-classification/notebook.ipynb' -> './notebook.ipynb'
renamed 'incremental-learning-image-classification/README.md' -> './README.md'
renamed 'incremental-learning-image-classification/report' -> './report'
renamed 'incremental-learning-image-classification/utils' -> './utils'


In [4]:
from data.cifar100 import Cifar100
from model.resnet_cifar import resnet32
from model.manager import Manager
from utils import plot

  import pandas.util.testing as tm


## Arguments

In [0]:
# Directories
DATA_DIR = 'data'       # Directory where the dataset will be downloaded

# Settings
DEVICE = 'cuda'         # Device handed to the training classes (Manager / LWF)

# Dataset
RANDOM_STATE = 420      # Base seed: the datasets of run r are built with RANDOM_STATE + r
                        # Note: different random states give very different
                        # splits and therefore very different results.

NUM_CLASSES = 100       # Total number of classes
NUM_BATCHES = 10        # NOTE(review): presumably the number of incremental steps
                        # (NUM_CLASSES / CLASS_BATCH_SIZE) - confirm
CLASS_BATCH_SIZE = 10   # Size of batch of classes for incremental learning

VAL_SIZE = 0.1          # Proportion of validation set with respect to training set (between 0 and 1)

# Training
BATCH_SIZE = 64         # Batch size (iCaRL sets this to 128)
LR = 2                  # Initial learning rate
                        # iCaRL sets LR = 2, which is feasible with a binary cross-entropy loss.
                        # NOTE(review): the training cells of this notebook build
                        # nn.BCEWithLogitsLoss, so the earlier remark that "we use
                        # CrossEntropy loss" appears to be stale - confirm.
MOMENTUM = 0.9          # Momentum for stochastic gradient descent (SGD)
WEIGHT_DECAY = 1e-5     # Weight decay from iCaRL

NUM_RUNS = 3            # Number of runs of every method
                        # Note: this should be at least 3 to have a fair benchmark

NUM_EPOCHS = 70         # Total number of training epochs
MILESTONES = [49, 63]   # Step down policy from iCaRL (MultiStepLR)
                        # Decrease the learning rate by gamma at each milestone
GAMMA = 0.2             # Gamma factor from iCaRL

## Data preparation

In [0]:
# Training pipeline: random 32x32 crop (with 4 px padding) and random
# horizontal flip as augmentation, followed by tensor conversion and
# per-channel normalization with mean 0.5 and std 0.5.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # PIL Image -> torch.Tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Evaluation pipeline: no augmentation, same tensor conversion and normalization.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [0]:
# One list of per-split dataloaders for every run:
# <kind>_dataloaders[run][split] -> DataLoader
train_dataloaders = [[] for _ in range(NUM_RUNS)]
val_dataloaders = [[] for _ in range(NUM_RUNS)]
test_dataloaders = [[] for _ in range(NUM_RUNS)]

for run_i in range(NUM_RUNS):

    # One iteration per incremental step. The loop counts splits, so it is
    # driven by NUM_BATCHES rather than by the number of classes per split.
    for split_i in range(NUM_BATCHES):

        # Download dataset only at first instantiation
        download = (run_i == 0 and split_i == 0)

        # Create CIFAR100 dataset (each run uses its own random state)
        train_dataset = Cifar100(DATA_DIR, train=True, download=download,
                                 random_state=RANDOM_STATE+run_i, transform=train_transform)
        test_dataset = Cifar100(DATA_DIR, train=False, download=False,
                                random_state=RANDOM_STATE+run_i, transform=test_transform)

        # Training data: only the classes of the current split
        train_dataset.set_classes_batch(train_dataset.batch_splits[split_i])
        # Test data: every split seen up to and including the current one
        test_dataset.set_classes_batch(
            [test_dataset.batch_splits[i] for i in range(0, split_i+1)])

        # Define train and validation indices
        train_indices, val_indices = train_dataset.train_val_split(
            VAL_SIZE, RANDOM_STATE)

        # Training batches are shuffled; the incomplete last batch is dropped
        train_dataloaders[run_i].append(DataLoader(Subset(train_dataset, train_indices),
                                                   batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))

        # Evaluate on every validation sample: no shuffling, and keep the last
        # (possibly smaller) batch instead of silently discarding it.
        # NOTE(review): the validation subset shares train_dataset and therefore
        # its train_transform (random augmentation) - confirm this is intended.
        val_dataloaders[run_i].append(DataLoader(Subset(train_dataset, val_indices),
                                                 batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=False))

        # Dataset with all seen classes. Not shuffled, so that predictions
        # collected batch by batch stay aligned with test_dataset.targets.
        test_dataloaders[run_i].append(DataLoader(test_dataset,
                                                  batch_size=BATCH_SIZE, shuffle=False, num_workers=4))

In [0]:
# Sanity check: visualize a batch of images
# test_dataloaders[0][5] is the test loader of the first run at the sixth
# incremental step.
dataiter = iter(test_dataloaders[0][5])
# Use the built-in next(): a `.next()` method is not part of the Python 3
# iterator protocol and is not guaranteed to exist on the loader iterator.
images, labels = next(dataiter)

plot.image_grid(images, one_channel=False)
# Labels present in the batch and how many times each one occurs
unique_labels = np.unique(labels, return_counts=True)
unique_labels

## Fine tuning

In [0]:
# @todo try xavier initialization 

In [0]:
# Score histories: one dict per run, keyed by split name ("Split 0", ...)
train_loss_history = []
train_accuracy_history = []
val_loss_history = []
val_accuracy_history = []
test_accuracy_history = []

# Iterate over runs: every zipped element is the list of per-split loaders of a run
for train_dataloader, val_dataloader, test_dataloader in zip(
        train_dataloaders, val_dataloaders, test_dataloaders):

    # Open an empty record for this run in every history
    train_loss_history.append({})
    train_accuracy_history.append({})
    val_loss_history.append({})
    val_accuracy_history.append({})
    test_accuracy_history.append({})

    net = resnet32()  # Fresh network for every run
    criterion = nn.BCEWithLogitsLoss()  # Loss shared by all splits of the run

    # Iterate over the incremental splits of the current run
    for i, (train_split, val_split, test_split) in enumerate(
            zip(train_dataloader, val_dataloader, test_dataloader)):

        current_split = "Split %i" % i
        print(current_split)

        # A new optimizer and scheduler are created for every split
        optimizer = optim.SGD(net.parameters(), lr=LR,
                              momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=MILESTONES, gamma=GAMMA)

        # Manager wraps the network together with its loss, optimizer,
        # scheduler and the three loaders of this split
        manager = Manager(DEVICE, net, criterion, optimizer, scheduler,
                          train_split, val_split, test_split)

        scores = manager.train(NUM_EPOCHS)  # train the model

        # scores is indexed as (train loss, train accuracy, val loss, val accuracy)
        train_loss_history[-1][current_split] = scores[0]
        train_accuracy_history[-1][current_split] = scores[1]
        val_loss_history[-1][current_split] = scores[2]
        val_accuracy_history[-1][current_split] = scores[3]

        # Test the model on classes seen until now
        test_accuracy, all_preds = manager.test()
        test_accuracy_history[-1][current_split] = test_accuracy

        # Make room for the next batch of classes (per the original note:
        # adds 10 nodes to the last FC layer)
        manager.increment_classes(n=10)
        

In [0]:
# Confusion matrix over last run test predictions
# `test_dataset` is the variable left over from the data-preparation loop (the
# test set built in its final iteration) and `all_preds` comes from the last
# manager.test() call of the training loop, so this cell depends on both of
# those cells having been run immediately before it.
# NOTE(review): targets are read in dataset order, so the matrix is only
# meaningful if the test DataLoader yields samples in that same order (i.e.
# without shuffling) - confirm how all_preds is ordered.
targets = test_dataset.targets
preds = all_preds.to('cpu').numpy()

plot.heatmap_cm(targets, preds)

In [0]:
def _mean_std_per_split(history, keys):
  """Return an array with one [mean, std] row per split, computed across runs.

  Args:
      history: list with one dict per run, mapping split name -> score.
      keys: split names, in the order the rows should appear.
  """
  rows = []
  for key in keys:
    # Scores of this split gathered from every run in the history
    values = np.array([run_scores[key] for run_scores in history])
    rows.append([values.mean(), values.std()])
  return np.array(rows)


def mean_std_scores(train_loss_history, train_accuracy_history,
                   val_loss_history, val_accuracy_history, test_accuracy_history):
  '''
      Average the scores of the different runs, split by split.

      Every history is a list with one dict per run, mapping a split name
      (e.g. 'Split 0') to the score obtained on that split. The number of
      runs is taken from the histories themselves, so histories with any
      number of runs (for instance loaded from disk) can be aggregated.

      Returns:
          Tuple (train_loss, train_accuracy, val_loss, val_accuracy,
          test_accuracy) of numpy arrays. Each array has one row per split,
          in the key order of the first run of train_loss_history, holding
          [mean, std] over the runs.
  '''
  # keys = names of the splits, taken from the first run
  keys = list(train_loss_history[0].keys())

  histories = (train_loss_history, train_accuracy_history,
               val_loss_history, val_accuracy_history, test_accuracy_history)

  # Return averaged scores
  return tuple(_mean_std_per_split(history, keys) for history in histories)

In [0]:
# Collapse the per-run histories into (mean, std) arrays, one row per split
(train_loss, train_accuracy, val_loss,
 val_accuracy, test_accuracy) = mean_std_scores(
    train_loss_history, train_accuracy_history,
    val_loss_history, val_accuracy_history, test_accuracy_history)

In [0]:
# Plot the averaged training and validation loss/accuracy per split.
# NOTE(review): the meaning of the trailing None depends on utils.plot - confirm.
plot.train_val_scores(train_loss, train_accuracy, val_loss, val_accuracy, None)

In [0]:
# Plot the averaged test accuracy per split.
# NOTE(review): the meaning of the trailing None depends on utils.plot - confirm.
plot.test_scores(test_accuracy, None)

In [0]:
# @todo: create utils package for functions

import ast

def load_json_scores(root):
  """Load the five score histories stored as JSON files inside `root`.

  The files read are '<name>.json' for train_loss_history,
  train_accuracy_history, val_loss_history, val_accuracy_history and
  test_accuracy_history, i.e. the files written by save_json_scores.

  The files are parsed with the json module rather than ast.literal_eval:
  JSON documents may contain tokens such as NaN, Infinity, null or true
  (NaN is what json.dump writes for a diverged loss), none of which are
  valid Python literals.

  Args:
      root: directory containing the JSON files.
  Returns:
      Tuple (train_loss_history, train_accuracy_history, val_loss_history,
      val_accuracy_history, test_accuracy_history).
  Raises:
      FileNotFoundError: if one of the files is missing.
  """
  # Local import keeps the function usable even if the cell that imports
  # json at notebook level has not been executed yet.
  import json

  names = ('train_loss_history', 'train_accuracy_history', 'val_loss_history',
           'val_accuracy_history', 'test_accuracy_history')

  histories = []
  for name in names:
    with open(os.path.join(root, name + '.json')) as f:
      histories.append(json.load(f))

  return tuple(histories)

In [0]:
# @todo: create utils package for functions
import json

def save_json_scores(root, train_loss_history, train_accuracy_history,
                   val_loss_history, val_accuracy_history, test_accuracy_history):
    """Write every score history to '<root>/<name>.json'.

    The destination directory is created when it does not exist yet, so the
    function also works on a fresh working directory.

    Args:
        root: destination directory.
        train_loss_history, train_accuracy_history, val_loss_history,
        val_accuracy_history, test_accuracy_history: JSON-serializable
            histories, each written to the file named after the argument.
    """
    # Without this, open(..., 'w') raises FileNotFoundError on a missing directory
    os.makedirs(root, exist_ok=True)

    histories = {
        'train_loss_history': train_loss_history,
        'train_accuracy_history': train_accuracy_history,
        'val_loss_history': val_loss_history,
        'val_accuracy_history': val_accuracy_history,
        'test_accuracy_history': test_accuracy_history,
    }

    for name, history in histories.items():
        with open(os.path.join(root, name + '.json'), 'w') as fout:
            json.dump(history, fout)

In [0]:
# Persist the score histories of all runs as JSON files in the 'scores' directory
save_json_scores('scores', train_loss_history, train_accuracy_history,
                   val_loss_history, val_accuracy_history, test_accuracy_history)

In [0]:
# Colab-only cell: `google.colab` is available only inside Google Colab.
# Zip the 'scores' directory and hand the archive to files.download.
from google.colab import files
!zip -r scores.zip scores
files.download("scores.zip")

## Learning Without Forgetting

### Arguments

In [0]:
# Training settings for Learning Without Forgetting
# NOTE: these assignments overwrite the global values set in the Arguments
# cell, so every cell executed afterwards (including a re-run of the
# fine-tuning section) sees them.
RANDOM_STATE = 420      # Same base seed as in the Arguments cell
BATCH_SIZE = 128        # Replaces the batch size of 64 set in the Arguments cell
LR = 2                  # Same initial learning rate as in the Arguments cell

### Data preparation

In [0]:
# Transformations for Learning Without Forgetting
# Training: random crop (32x32, 4 px padding) and random horizontal flip,
# then conversion to tensor and normalization with mean 0.5 / std 0.5.
train_transform = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),  # PIL Image -> torch.Tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Evaluation: conversion to tensor and the same normalization, no augmentation.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [0]:
# One list of per-split dataloaders for every run:
# <kind>_dataloaders[run][split] -> DataLoader
train_dataloaders = [[] for _ in range(NUM_RUNS)]
val_dataloaders = [[] for _ in range(NUM_RUNS)]
test_dataloaders = [[] for _ in range(NUM_RUNS)]

for run_i in range(NUM_RUNS):

  # One iteration per incremental step. The loop counts splits, so it is
  # driven by NUM_BATCHES rather than by the number of classes per split.
  for split_i in range(NUM_BATCHES):

    # Download dataset only at first instantiation
    download = (run_i == 0 and split_i == 0)

    # Create CIFAR100 dataset (each run uses its own random state)
    train_dataset = Cifar100(DATA_DIR, train=True, download=download,
                             random_state=RANDOM_STATE+run_i, transform=train_transform)
    test_dataset = Cifar100(DATA_DIR, train=False, download=False,
                            random_state=RANDOM_STATE+run_i, transform=test_transform)

    # Training data: only the classes of the current split
    train_dataset.set_classes_batch(train_dataset.batch_splits[split_i])
    # Test data: every split seen up to and including the current one
    test_dataset.set_classes_batch(
        [test_dataset.batch_splits[i] for i in range(0, split_i+1)])

    # Define train and validation indices
    train_indices, val_indices = train_dataset.train_val_split(VAL_SIZE, RANDOM_STATE)

    # Training batches are shuffled; the incomplete last batch is dropped
    train_dataloaders[run_i].append(DataLoader(Subset(train_dataset, train_indices),
                               batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True))

    # Evaluate on every validation sample: no shuffling, and keep the last
    # (possibly smaller) batch instead of silently discarding it.
    # NOTE(review): the validation subset shares train_dataset and therefore
    # its train_transform (random augmentation) - confirm this is intended.
    val_dataloaders[run_i].append(DataLoader(Subset(train_dataset, val_indices),
                                batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=False))

    # Dataset with all seen classes. Not shuffled, so that predictions
    # collected batch by batch stay aligned with test_dataset.targets.
    test_dataloaders[run_i].append(DataLoader(test_dataset,
                               batch_size=BATCH_SIZE, shuffle=False, num_workers=4))

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar-100-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/cifar-100-python.tar.gz to data


In [0]:
# Sanity check: visualize a batch of images
# test_dataloaders[0][5] is the test loader of the first run at the sixth
# incremental step.
dataiter = iter(test_dataloaders[0][5])
# Use the built-in next(): a `.next()` method is not part of the Python 3
# iterator protocol and is not guaranteed to exist on the loader iterator.
images, labels = next(dataiter)

plot.image_grid(images, one_channel=False)
# Labels present in the batch and how many times each one occurs
unique_labels = np.unique(labels, return_counts=True)
unique_labels

In [0]:
from torch.nn import BCEWithLogitsLoss
from copy import deepcopy

'''BCE formulation:
 let x = logits, z = labels. The logistic loss is

  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
'''

   
# Number of classes added at each incremental step (re-declared here with the
# same value as in the Arguments cell).
CLASS_BATCH_SIZE = 10


class LWF():
  """Learning Without Forgetting (LwF) trainer for one incremental step.

  The network is trained with a loss computed between its logits for the
  classes seen so far and a target made of two parts: the one-hot labels of
  the newly added classes and, when there are old classes, the sigmoid
  outputs of the previous network for those old classes.

  Args:
      device: device on which batches and networks are placed.
      net: network to train. `increment_classes` expects its last layer to
          be reachable as `net.fc`.
      old_net: copy of the network from the previous step, or None for the
          first batch of classes.
      criterion: loss applied to (logits, targets). When None, a
          BCEWithLogitsLoss is created.
      optimizer: optimizer over the parameters of `net`.
      scheduler: learning-rate scheduler, stepped once per epoch.
      train_dataloader: iterable of (images, labels) used for training.
      val_dataloader: iterable of (images, labels) used for validation.
      test_dataloader: iterable of (images, labels) used for testing.
      num_classes (int): number of classes seen so far, new ones included.
      class_batch_size (int): number of classes added at each step.
  """

  def __init__(self, device, net, old_net, criterion, optimizer, scheduler,
               train_dataloader, val_dataloader, test_dataloader, num_classes=10,
               class_batch_size=10):

    self.device = device

    self.net = net
    self.best_net = self.net
    self.old_net = old_net # None for first batch of classes

    # Honour the criterion supplied by the caller instead of silently
    # replacing it; fall back to BCE with logits when none is given.
    self.criterion = criterion if criterion is not None else BCEWithLogitsLoss()
    self.optimizer = optimizer
    self.scheduler = scheduler

    self.train_dataloader = train_dataloader
    self.val_dataloader = val_dataloader
    self.test_dataloader = test_dataloader

    self.num_classes = num_classes # can be incremented outside methods in the main, or inside methods
    self.class_batch_size = class_batch_size
    self.order = np.arange(100) # output column of every class (identity mapping)

    self.sigmoid = nn.Sigmoid()


  def warm_up(self):
    """Placeholder, currently does nothing."""
    pass

  def increment_classes(self, n=10):
    """Add n classes in the final fully connected layer.

    The weights and the bias learned for the existing classes are copied
    into the new layer, which is created on the same device as the old one.
    `self.num_classes` is not modified.
    """

    old_fc = self.net.fc
    in_features = old_fc.in_features  # size of each input sample
    out_features = old_fc.out_features  # size of each output sample

    new_fc = nn.Linear(in_features, out_features+n).to(old_fc.weight.device)

    with torch.no_grad():
      new_fc.weight[:out_features] = old_fc.weight
      # Keep the bias of the old classes as well: dropping it would shift
      # their logits after every increment.
      if old_fc.bias is not None:
        new_fc.bias[:out_features] = old_fc.bias

    self.net.fc = new_fc

  def to_onehot(self, targets):
    '''
    Args:
    targets : class labels (integer tensor) of the new task images
    Returns:
    float tensor of size [len(targets), self.num_classes] on self.device
    '''
    # Build the identity matrix directly on the target device so that it can
    # be indexed by labels living on that device.
    return torch.eye(self.num_classes, device=self.device)[targets]

  def _select_columns(self, tensor, start, stop):
    """Return the columns of `tensor` assigned to classes order[start:stop]."""
    columns = torch.as_tensor(self.order[start:stop], dtype=torch.long,
                              device=tensor.device)
    return tensor.index_select(1, columns)

  def _forward(self, batch, labels):
    """Run one batch through the network(s) and compute loss and corrects.

    No optimizer step is performed here.

    Returns:
        loss: criterion value (tensor) for the batch
        corrects (int): number of samples whose arg-max prediction equals
            the label
    """
    batch = batch.to(self.device)
    labels = labels.to(self.device) # new classes labels

    num_old = self.num_classes - self.class_batch_size

    # One hot encoding of new task labels, restricted to the new classes
    one_hot_labels = self._select_columns(self.to_onehot(labels),
                                          num_old, self.num_classes)

    if num_old > 0:
      # Old net forward pass. Its outputs are fixed targets, so no graph is
      # built and no gradient reaches the old network.
      with torch.no_grad():
        old_outputs = self.sigmoid(self.old_net(batch))
      old_outputs = self._select_columns(old_outputs, 0, num_old)

      # Combine new and old class targets
      targets = torch.cat((old_outputs, one_hot_labels), 1)
    else:
      targets = one_hot_labels

    # New net forward pass, restricted to the classes seen so far
    outputs = self._select_columns(self.net(batch), 0, self.num_classes)

    # The criterion receives logits; the old-class targets above went
    # through the sigmoid explicitly
    loss = self.criterion(outputs, targets)

    # Get predictions
    _, preds = torch.max(outputs.data, 1)

    # Accuracy over the images of the batch (new images only)
    corrects = torch.sum(preds == labels.data).data.item()

    return loss, corrects

  def _train_step(self, batch, labels):
    """Forward pass, backward pass and optimizer step for one batch."""
    # Zero-ing the gradients
    self.optimizer.zero_grad()

    loss, corrects = self._forward(batch, labels)

    # Backward pass: computes gradients
    loss.backward()

    self.optimizer.step()

    return loss, corrects

  def do_first_batch(self, batch, labels):
    """Training step for the first batch of classes (no old classes).

    Returns:
        loss (tensor) and number of correctly classified samples (int).
    """
    return self._train_step(batch, labels)


  def do_batch(self, batch, labels):
    """Training step when old classes exist: the targets of the old classes
    are the sigmoid outputs of `self.old_net`.

    Returns:
        loss (tensor) and number of correctly classified samples (int).
    """
    return self._train_step(batch, labels)


  def do_epoch(self, current_epoch):
    """Train for one epoch over train_dataloader and step the scheduler.

    Returns:
        train_loss: loss averaged over the batches
        train_accuracy: fraction of correctly classified samples
    """

    self.net.train()

    running_train_loss = 0
    running_corrects = 0
    total = 0
    batch_idx = 0

    print(f"Epoch: {current_epoch}, LR: {self.scheduler.get_last_lr()}")

    for images, labels in self.train_dataloader:

      if self.num_classes == self.class_batch_size:
        loss, corrects = self.do_first_batch(images, labels)
      else:
        loss, corrects = self.do_batch(images, labels)

      running_train_loss += loss.item()
      running_corrects += corrects
      total += labels.size(0)
      batch_idx += 1

    self.scheduler.step()

    # Calculate average scores
    train_loss = running_train_loss / batch_idx # Average over all batches
    train_accuracy = running_corrects / float(total) # Average over all samples

    print(f"Train loss: {train_loss}, Train accuracy: {train_accuracy}")

    return (train_loss, train_accuracy)


  def train(self, num_epochs):
    """Train the network for a specified number of epochs, and save
    the best performing model on the validation set.
    
    Args:
        num_epochs (int): number of epochs for training the network.
    Returns:
        train_loss: loss computed on the last epoch
        train_accuracy: accuracy computed on the last epoch
        val_loss: average loss on the validation set of the last epoch
        val_accuracy: accuracy on the validation set of the last epoch
    """

    # @todo: is the return behaviour intended? (scores of the last epoch)

    self.net = self.net.to(self.device)
    if self.old_net is not None:
      self.old_net = self.old_net.to(self.device)
      self.old_net.train(False)

    # The flag has to be assigned: merely referencing it has no effect
    cudnn.benchmark = True

    self.best_loss = float("inf")
    self.best_epoch = 0

    for epoch in range(num_epochs):
        # Run an epoch (start counting from 1)
        train_loss, train_accuracy = self.do_epoch(epoch+1)
    
        # Validate after each epoch 
        val_loss, val_accuracy = self.validate()    

        # Best validation model
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_net = deepcopy(self.net)
            self.best_epoch = epoch
            print("Best model updated")

        print("")

    return (train_loss, train_accuracy,
            val_loss, val_accuracy)


  def validate(self):
    """Validate the model.
    
    Returns:
        val_loss: average loss function computed on the network outputs
            of the validation set (val_dataloader).
        val_accuracy: accuracy computed on the validation set.
    """

    self.net.train(False)

    running_val_loss = 0
    running_corrects = 0
    total = 0
    batch_idx = 0

    # No gradient is needed for evaluation
    with torch.no_grad():
      for batch, labels in self.val_dataloader:
        loss, corrects = self._forward(batch, labels)

        # Update the number of correctly classified validation samples
        running_corrects += corrects
        running_val_loss += loss.item()
        total += labels.size(0)

        batch_idx += 1

    # Calculate scores
    val_loss = running_val_loss / batch_idx
    val_accuracy = running_corrects / float(total)

    print(f"Validation loss: {val_loss}, Validation accuracy: {val_accuracy}")

    return (val_loss, val_accuracy)


  def test(self):
    """Test the best model found so far (self.best_net) on test_dataloader.

    Returns:
        accuracy (float): accuracy of the model on the test set
        all_preds: predicted labels, concatenated in the order the
            dataloader yields the batches
    """

    self.best_net.train(False)  # Set Network to evaluation mode

    running_corrects = 0
    total = 0

    all_preds = torch.tensor([]) # to store all predictions
    all_preds = all_preds.type(torch.LongTensor)

    # No gradient is needed for evaluation
    with torch.no_grad():
      for images, labels in self.test_dataloader:
        images = images.to(self.device)
        labels = labels.to(self.device)
        total += labels.size(0)

        # Forward Pass
        outputs = self.best_net(images)

        # Get predictions
        _, preds = torch.max(outputs.data, 1)

        # Update Corrects
        running_corrects += torch.sum(preds == labels.data).data.item()

        # Append batch predictions
        all_preds = torch.cat(
            (all_preds.to(self.device), preds.to(self.device)), dim=0
        )

    # Calculate accuracy
    accuracy = running_corrects / float(total)  

    print(f"Test accuracy: {accuracy}")

    return (accuracy, all_preds)

In [0]:
# Score histories: one dict per run, keyed by split name ("Split 0", ...)
train_loss_history = []
train_accuracy_history = []
val_loss_history = []
val_accuracy_history = []
test_accuracy_history = []

# Iterate over runs: every zipped element is the list of per-split loaders of a run
for train_dataloader, val_dataloader, test_dataloader in zip(
        train_dataloaders, val_dataloaders, test_dataloaders):

    # Open an empty record for this run in every history
    train_loss_history.append({})
    train_accuracy_history.append({})
    val_loss_history.append({})
    val_accuracy_history.append({})
    test_accuracy_history.append({})

    net = resnet32()  # Fresh network for every run
    criterion = nn.BCEWithLogitsLoss()  # Loss shared by all splits of the run

    # There is no previous network at the first split of a run
    old_net = None

    for i, (train_split, val_split, test_split) in enumerate(
            zip(train_dataloader, val_dataloader, test_dataloader)):

        # Redefine optimizer at each split (pass by reference issue)
        parameters_to_optimize = net.parameters()
        optimizer = optim.SGD(parameters_to_optimize, lr=LR,
                              momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=MILESTONES, gamma=GAMMA)

        current_split = "Split %i" % i
        print(current_split)

        # Classes seen so far, the ones of the current split included
        num_classes = CLASS_BATCH_SIZE*(i+1)

        # old_net is None at the first split and the previous step's copy afterwards
        lwf = LWF(DEVICE, net, old_net, criterion, optimizer, scheduler,
                  train_split, val_split, test_split, num_classes)

        scores = lwf.train(NUM_EPOCHS)  # train the model

        # scores = (train loss, train accuracy, val loss, val accuracy) of the last epoch
        train_loss_history[-1][current_split] = scores[0]
        train_accuracy_history[-1][current_split] = scores[1]
        val_loss_history[-1][current_split] = scores[2]
        val_accuracy_history[-1][current_split] = scores[3]

        # Test the model on classes seen until now
        test_accuracy, all_preds = lwf.test()
        test_accuracy_history[-1][current_split] = test_accuracy

        # Snapshot the trained network as the "old" one for the next split,
        # then add output nodes for the next batch of classes
        old_net = deepcopy(lwf.net)
        lwf.increment_classes()

## iCaRL

In [0]:
# @todo: move to package once finished
# @todo: provide an uncommented .py for easier editing of class for ablation studies

from math import floor
from copy import deepcopy

# Module-level sigmoid layer.
# NOTE(review): presumably used by the iCaRL code below to map logits to
# probabilities - confirm against its usage.
sigmoid = nn.Sigmoid() # Sigmoid function

class Exemplars(torch.utils.data.Dataset):
    """In-memory dataset of exemplar samples and their targets.

    Samples are stored as given to `add_exemplar`; the optional transform is
    applied every time an item is read.

    Args:
        transform (callable, optional): applied to a sample when it is
            retrieved through indexing.
    """

    def __init__(self, transform=None):
        self.transform = transform

        # Per-instance storage. Declaring these lists on the class would make
        # every Exemplars object (and every run) share the same exemplars.
        self.dataset = []
        self.targets = []

    def add_exemplar(self, sample, target):
        """Append one (sample, target) pair to the dataset."""
        self.dataset.append(sample)
        self.targets.append(target)

    def __getitem__(self, index):
        """Return the (possibly transformed) sample at `index` and its target."""
        image = self.dataset[index]
        target = self.targets[index]

        if self.transform is not None:
            image = self.transform(image)

        return image, target

    def __len__(self):
        """Number of stored exemplars."""
        return len(self.targets)

class iCaRL:
    """Implement iCaRL, a strategy for simultaneously learning classifiers and a
    feature representation in the class-incremental setting.

    Args:
        device: device that networks and batches are moved to.
        net: network to train; the class reads and replaces its `fc` layer
            and calls a `features` method on copies of it.
        lr: learning rate of the SGD optimizer.
        momentum: momentum of the SGD optimizer.
        weight_decay: weight decay of the SGD optimizer.
        milestones: epochs at which MultiStepLR decays the learning rate.
        gamma: multiplicative decay factor used by MultiStepLR.
        num_epochs: number of training epochs per incremental step.
        batch_size: batch size of the data loaders.
        train_transform: transformation applied to stored images before
            they are fed to the network.
    """

    # Maximum number of exemplars kept in memory over all classes
    memory_size = 2000

    # List of exemplar sets. Each set contains memory_size/num_classes exemplars
    # with num_classes the number of classes seen until now by the network.
    # NOTE(review): this is a mutable class attribute, so it is shared by every
    # instance that does not assign its own `exemplars` attribute.
    exemplars = []

    # Default for the copy of the old network, used to compute outputs of the
    # previous network for the distillation loss. It is None until a first
    # training round has completed, which lets the training code detect that
    # the network is being trained for the first time.
    old_net = None

    # Loss function (binary cross-entropy applied to raw network outputs)
    criterion = nn.BCEWithLogitsLoss() # @todo: should this be reduction='sum'?

    def __init__(self, device, net, lr, momentum, weight_decay, milestones, gamma, num_epochs, batch_size, train_transform):
        self.device = device
        self.net = net

        # Set hyper-parameters
        self.LR = lr
        self.MOMENTUM = momentum
        self.WEIGHT_DECAY = weight_decay
        self.MILESTONES = milestones
        self.GAMMA = gamma
        self.NUM_EPOCHS = num_epochs
        self.BATCH_SIZE = batch_size
        
        self.train_transform = train_transform

    def classify(self, image):
        """Mean-of-exemplars classifier used to classify images into the set of
        classes observed so far.

        Args:
            image: image to classify (three channels, RGB)
        Returns:
            label (int): class label assigned to the image
        """

        features = self.extract_features(image)

        # @todo (?): post di Cermelli su forum:
        # 3) Compute the means for the testing with all the data available.
        # That means: don't compute the means on the new exemplars only,
        # but on all the images you have available for that step (e.g in
        # the last step you have 2000/90 images for each previous class and
        # all the images for the 10 novel ones.  Don't compute the means on
        # the novel exemplars which are only 20 per class, but on all the
        # data you have - this will remove noise from the mean estimate,
        # improving the results).

        # Compute the mean of exemplars for all available classes.
        # means is a 2-dimensional tensor, the i-th row containing the mean of 
        # exemplars in the i-th class.
        means = torch.stack([self.mean_of_exemplars(y) for y in range(len(self.exemplars))])
        
        f_arg = torch.norm(features - means, dim=1)

        label = torch.argmin(f_arg)
        
        return label
    
    def mean_of_exemplars(self, class_id):
        """

        Args:
            ...
        """
        
        cardinality = len(self.exemplars[class_id])

        # Initialize the sum over exemplar to a tensor of zeros
        exemplar_sum = torch.tensor([0 for _ in range(cardinality)])

        # See iCaRL algorithm 1
        for i in range(cardinality):
            exemplar_features = self.extract_features(self.exemplars[class_id][i]).to(self.device)
            exemplar_sum += exemplar_features

        mean = 1/cardinality * exemplar_sum

        return mean

    def incremental_train(self, split, train_dataset, val_dataset):
        """Adjust internal knowledge based on the additional information
        available in the new observations.

        Args:
            ...
        Returns:
            ...
        """

        if split is not 0:
            # Increment the number of output nodes for the new network by 10
            self.increment_classes(10)

        # Improve network parameters upon receiving new classes. Effectively
        # train a new network starting from the current network parameters.
        self.update_representation(train_dataset, val_dataset)

        # Compute the number of exemplars per class
        num_classes = self.output_neurons_count()
        m = floor(self.memory_size / num_classes)

        print(f"Target number of exemplars per class: {m}")
        print(f"Total number of exemplars: {m*num_classes}")

        # Reduce pre-existing exemplar sets in order to fit new exemplars
        for y in range(len(self.exemplars)):
            self.exemplars[y] = self.reduce_exemplar_set(self.exemplars[y], m)

        # Construct exemplar set for new classes
        new_exemplars = self.construct_exemplar_set(train_dataset, m)
        self.exemplars.extend(new_exemplars)

    def update_representation(self, train_dataset, val_dataset):
        """

        Args:
            ...
        """

        # Define the optimization algorithm
        parameters_to_optimize = self.net.parameters()
        self.optimizer = optim.SGD(parameters_to_optimize, 
                                   lr=self.LR,
                                   momentum=self.MOMENTUM,
                                   weight_decay=self.WEIGHT_DECAY)
        
        # Define the learning rate decaying policy
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                        milestones=self.MILESTONES,
                                                        gamma=self.GAMMA)

        # Combine the new training data with existing exemplars.
        # create exemplars dataset
        exemplars_dataset = self.create_exemplars_dataset()
        train_dataset_with_exemplars = ConcatDataset([train_dataset, exemplars_dataset])

        # Train the network on combined dataset
        self.train(train_dataset_with_exemplars, val_dataset) # @todo: include exemplars in validation set?

        # Keep a copy of the current network in order to compute its outputs for
        # the distillation loss while the new network is being trained.
        self.old_net = deepcopy(self.net)
    
    def construct_exemplar_set(self, dataset, m):
        """

        Args:
            dataset: dataset containing a split (samples from 10 classes) from
                which to take exemplars.
            m (int): target number of exemplars per class.
        """

        # Disable transformations when getting items. We will need original PIL
        # images to store in our exemplar set.
        dataset.dataset.disable_transform()

        # Create a list that stores all images and labels in the dataset by
        # their class index in an ordered fashion, structured as follows:
        #
        # samples = [
        #     [image0, image1, ...] # Class 0
        #     [image0, image1, ...] # Class 1
        #     ...
        #     [image0, image1, ...] # Class 9
        # ]
        samples = [[] for _ in range(10)]
        for image, label in dataset:
            label = label % 10 # Map labels to 0-9 range
            samples[label].append(image)

        # Re-enable transformations on the dataset
        dataset.dataset.enable_transform()

        # Initialize exemplar set
        exemplars = [[] for _ in range(10)]
        
        # Iterate over classes
        for y in range(10):
            print(f"Extracting exemplars from class {y} of current split...")

            # Extract features from samples of the current class
            samples_features = [self.extract_features(samples[y][i]).to(self.device) for i in range(len(samples[y]))]

            # Compute the feature mean of the current class
            features_mean = torch.stack(samples_features).mean(dim=0)

            # See iCaRL algorithm 4
            for k in range(1, m+1): # k = 1, ..., m -- Choose m exemplars
                if k == 1: # No exemplars chosen yet, sum to 0 vector
                    f_sum = torch.tensor([0 for _ in range(samples_features[0].size(0))])
                else: # Sum of features of all exemplars chosen until now (j = 1, ..., k-1)
                    exemplars_features = torch.stack([self.extract_features(exemplars[y][j]) for j in range(len(exemplars[y]))]).to(self.device)
                    f_sum = exemplars_features.sum(dim=0)
                
                f_sum = f_sum.to(self.device)

                f_arg = torch.norm(features_mean - 1/k * (torch.stack(samples_features) + f_sum), dim=1)

                exemplar_idx = torch.argmin(f_arg)

                exemplars[y].append(samples[y][exemplar_idx])

                # Remove sample from list after adding it to the exemplar set.
                # This is done to avoid creating duplicate exemplars.
                del samples[y][exemplar_idx]
                del samples_features[exemplar_idx]

            print(f"Extracted {len(exemplars[y])} exemplars.")
        
        return exemplars

    def extract_features(self, sample):
        """Extract features from PIL image.

        Args:
            sample: PIL image from which to extract features.
        Returns:
            tensor: features extracted from the sample.
        """

        self.best_net.train(False)

        # Transform sample and send it to device to feed it to the network
        sample = self.train_transform(sample)
        sample = sample.unsqueeze(0) # https://stackoverflow.com/a/59566009/6486336
        sample = sample.to(self.device)

        return self.best_net.features(sample)[0]

    def reduce_exemplar_set(self, exemplar_set, m):
        """Procedure for removing exemplars from a given set.

        Args:
            exemplar_set (set): set of exemplars belonging to a certain class.
            m (int): target number of exemplars.
        """

        return exemplar_set[:m]

    def train(self, train_dataset, val_dataset):
        """Train the network for a specified number of epochs, and save
        the best performing model on the validation set.
        
        Args:
            ...
        Returns:
            train_loss: loss computed on the last epoch
            train_accuracy: accuracy computed on the last epoch
            val_loss: average loss on the validation set of the last epoch
            val_accuracy: accuracy on the validation set of the last epoch
        """

        # Define the optimization algorithm
        parameters_to_optimize = self.net.parameters()
        self.optimizer = optim.SGD(parameters_to_optimize, 
                                   lr=self.LR,
                                   momentum=self.MOMENTUM,
                                   weight_decay=self.WEIGHT_DECAY)
        
        # Define the learning rate decaying policy
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer,
                                                        milestones=self.MILESTONES,
                                                        gamma=self.GAMMA)

        # Create DataLoaders for training and validation
        self.train_dataloader = DataLoader(train_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
        self.val_dataloader = DataLoader(val_dataset, batch_size=self.BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

        # @todo: is the return behaviour intended? (scores of the last epoch)

        # Send the old network to the chosen device and set it to evaluation mode
        if self.old_net is not None:
            self.old_net = self.old_net.to(self.device)
            self.old_net.train(False)

        # Send the current network to the chosen device
        self.net.to(self.device)
        cudnn.benchmark  # Calling this optimizes runtime

        self.best_accuracy = 0 # @todo: should we use best_loss instead?
        self.best_epoch = 0

        for epoch in range(self.NUM_EPOCHS):
            # Run an epoch (start counting form 1)
            train_loss, train_accuracy = self.do_epoch(epoch+1)
        
            # Validate after each epoch 
            val_loss, val_accuracy = self.validate()    

            # Best validation model
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                self.best_net = deepcopy(self.net)
                self.best_epoch = epoch
                print("Best model updated")

        print(f"Best model found at epoch {self.best_epoch}")

        return (train_loss, train_accuracy,
                val_loss, val_accuracy)
    
    def do_epoch(self, current_epoch):
        """Trains model for one epoch.
        
        Args:
            current_epoch (int): current epoch number (begins from 1)
        Returns:
            train_loss: average training loss over all batches of the
                current epoch.
            train_accuracy: training accuracy of the current epoch over
                all samples.
        """

        self.net.train()  # Set network in training mode

        running_train_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        print(f"Epoch: {current_epoch}, LR: {self.scheduler.get_last_lr()}")

        for images, labels in self.train_dataloader:
            loss, corrects = self.do_batch(images, labels)

            running_train_loss += loss.item()
            running_corrects += corrects
            total += labels.size(0)
            batch_idx += 1

        self.scheduler.step()

        # Calculate average scores
        train_loss = running_train_loss / batch_idx # Average over all batches
        train_accuracy = running_corrects / float(total) # Average over all samples

        print(f"Train loss: {train_loss}, Train accuracy: {train_accuracy}")

        return (train_loss, train_accuracy)

    def do_batch(self, batch, labels):
        batch = batch.to(self.device)
        labels = labels.to(self.device)

        # Zero-ing the gradients
        self.optimizer.zero_grad()
        
        # One-hot encoding of labels of the new training data (new classes)
        # Size: batch size (rows) by number of classes seen until now (columns)
        #
        # e.g., suppose we have four images in a batch, and each incremental
        #   step adds three new classes. At the second step, the one-hot
        #   encoding may return the following tensor:
        #
        #       tensor([[0., 0., 0., 1., 0., 0.],   # image 0 (label 3)
        #               [0., 0., 0., 0., 1., 0.],   # image 1 (label 4)
        #               [0., 0., 0., 0., 0., 1.],   # image 2 (label 5)
        #               [0., 0., 0., 0., 1., 0.]])  # image 3 (label 4)
        #
        #   The first three elements of each vector will always be 0, as the
        #   new training batch does not contain images belonging to classes
        #   already seen in previous steps.
        #
        #   The last three elements of each vector will contain the actual
        #   information about the class of each image (one-hot encoding of the
        #   label). Therefore, we slice the tensor and remove the columns 
        #   related to old classes (all zeros).# Number of classes seen until now, including new classes
        num_classes = self.output_neurons_count() 
        one_hot_labels = self.to_onehot(labels)[:, num_classes-10:num_classes]

        if self.old_net is None: # Network is training for the first time
            # We only apply the classification loss
            targets = one_hot_labels

            # Forward pass
            outputs = self.net(batch)

            loss = self.criterion(outputs, targets)

        else:
            # Old net forward pass. We compute the outputs of the old network
            # and apply a sigmoid function. These are used in the distillation
            # loss. We discard the output of the new neurons, as they are not
            # considered in the distillation loss.
            old_net_outputs = sigmoid(self.old_net(batch))[:, :num_classes-10]

            # Concatenate the outputs of the old network and the one-hot encoded
            # labels along dimension 1 (columns).
            # 
            # Each row refers to an image in the training set, and contains:
            # - the output of the old network for that image, used by the
            #   distillation loss
            # - the one-hot label of the image, used by the classification loss
            targets = torch.cat((old_net_outputs, one_hot_labels), dim=1)

            # New net forward pass
            # Size: batch size (rows) by number of classes seen until now, 
            # including new ones (columns)
            outputs = self.net(batch)

            loss = self.criterion(outputs, targets)

        # Get predictions
        _, preds = torch.max(outputs.data, 1)

        # Accuracy over NEW IMAGES, not over all images
        running_corrects = \
            torch.sum(preds == labels.data).data.item() 

        # Backward pass: computes gradients
        loss.backward()

        self.optimizer.step()

        return loss, running_corrects

    def validate(self):
        """Validate the model.
        
        Returns:
            val_loss: average loss function computed on the network outputs
                of the validation set (val_dataloader).
            val_accuracy: accuracy computed on the validation set.
        """

        self.net.train(False)

        running_val_loss = 0
        running_corrects = 0
        total = 0
        batch_idx = 0

        for images, labels in self.val_dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)
            total += labels.size(0)

            # One hot encoding of new task labels 
            one_hot_labels = self.to_onehot(labels)

            # New net forward pass
            outputs = self.net(images)  
            loss = self.criterion(outputs, one_hot_labels) # BCE Loss with sigmoids over outputs

            running_val_loss += loss.item()

            # Get predictions
            _, preds = torch.max(outputs.data, 1)

            # Update the number of correctly classified validation samples
            running_corrects += torch.sum(preds == labels.data).data.item()

            batch_idx += 1

        # Calculate scores
        val_loss = running_val_loss / batch_idx
        val_accuracy = running_corrects / float(total)

        print(f"Validation loss: {val_loss}, Validation accuracy: {val_accuracy}")

        return val_loss, val_accuracy

    def test(self, test_dataset):
        """Test the model.

        Returns:
            accuracy (float): accuracy of the model on the test set
        """

        self.test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

        self.best_net.train(False)  # Set Network to evaluation mode

        running_corrects = 0
        total = 0

        all_preds = torch.tensor([]) # to store all predictions
        all_preds = all_preds.type(torch.LongTensor)
        
        for images, labels in self.test_dataloader:
            images = images.to(self.device)
            labels = labels.to(self.device)
            total += labels.size(0)
            
            # Classify each image of the batch and update corrects
            for i, image in enumerate(images):
                pred = self.classify(image)

                if pred == labels[i]:
                    running_corrects += 1

        # Calculate accuracy
        accuracy = running_corrects / float(total)  

        print(f"Test accuracy: {accuracy}")

        return accuracy
    
    def increment_classes(self, n=10):
        """Add n classes in the final fully connected layer."""

        in_features = self.net.fc.in_features  # size of each input sample
        out_features = self.net.fc.out_features  # size of each output sample
        weight = self.net.fc.weight.data

        self.net.fc = nn.Linear(in_features, out_features+n)
        self.net.fc.weight.data[:out_features] = weight
    
    def output_neurons_count(self):
        """Return the number of output neurons of the current network."""

        return self.net.fc.out_features
    
    def to_onehot(self, targets): 
        """

        Args:
            targets: dataloader.dataset.targets of the new task images
        """
        num_classes = self.net.fc.out_features
        one_hot_targets = torch.eye(num_classes)[targets]

        return one_hot_targets.to(self.device)
    
    def create_exemplars_dataset(self):
        """Wrap the stored exemplars in an Exemplars dataset.

        Each exemplar is labelled with the index of its class in
        `self.exemplars`; the training transformation is handed to the
        dataset as its transform.

        Returns:
            Exemplars: dataset to which every stored exemplar has been added.
        """
        exemplars_dataset = Exemplars(transform=self.train_transform)

        # Walk the exemplar sets class by class, in storage order
        for class_label, class_exemplars in enumerate(self.exemplars):
            for exemplar in class_exemplars:
                exemplars_dataset.add_exemplar(exemplar, class_label)

        return exemplars_dataset

### Data preparation

In [0]:
# Image pipelines: random crop and horizontal flip are applied at training
# time only; both pipelines convert the PIL image to a tensor and normalize
# each channel with mean 0.5 and standard deviation 0.5.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), # Turn PIL Image to torch.Tensor
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [8]:
# Per-run lists of datasets, indexed as <name>_subsets[run_i][split_i]
train_subsets = [[] for _ in range(NUM_RUNS)]
val_subsets = [[] for _ in range(NUM_RUNS)]
test_subsets = [[] for _ in range(NUM_RUNS)]

for run_i in range(NUM_RUNS):
    for split_i in range(CLASS_BATCH_SIZE):
        # Download dataset only at first instantiation
        download = (run_i + split_i == 0)

        # Create CIFAR100 datasets; every run uses its own random state
        train_dataset = Cifar100(DATA_DIR, train=True, download=download,
                                 random_state=RANDOM_STATE+run_i, transform=train_transform)
        test_dataset = Cifar100(DATA_DIR, train=False, download=False,
                                random_state=RANDOM_STATE+run_i, transform=test_transform)

        # Training data: classes of the current split only
        train_dataset.set_classes_batch(train_dataset.batch_splits[split_i])

        # Test data: classes of all splits up to and including the current one
        seen_splits = [test_dataset.batch_splits[seen_i] for seen_i in range(split_i + 1)]
        test_dataset.set_classes_batch(seen_splits)

        # Define train and validation indices
        train_indices, val_indices = train_dataset.train_val_split(VAL_SIZE, RANDOM_STATE)

        # Store the subsets of this split
        train_subsets[run_i].append(Subset(train_dataset, train_indices))
        val_subsets[run_i].append(Subset(train_dataset, val_indices))
        test_subsets[run_i].append(test_dataset)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar-100-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/cifar-100-python.tar.gz to data


In [0]:
# train_subsets[0][0].dataset.enable_transform()
# print(train_subsets[0][0][0][0][0].size())
# train_subsets[0][0].dataset.disable_transform()
# print(train_subsets[0][0][0])

# dataloader = DataLoader(train_subsets[0][0], batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
# dataloader = DataLoader(ConcatDataset([train_subsets[0][0], train_subsets[0][1]]), batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)

# train_subsets[0][0].dataset.disable_transform()
# imgs = [train_subsets[0][0][i][0] for i in range(0, 100)]
# train_subsets[0][0].dataset.enable_transform()

### Execution

In [0]:
# iCaRL hyperparameters (passed to the iCaRL constructor in the execution cell)
LR = 2  # learning rate of the SGD optimizer
MOMENTUM = 0.9  # momentum of the SGD optimizer
WEIGHT_DECAY = 0.00001  # weight decay of the SGD optimizer
MILESTONES = [49, 63]  # epochs at which MultiStepLR decays the learning rate
GAMMA = 0.2  # multiplicative decay factor used by MultiStepLR
NUM_EPOCHS = 70  # training epochs for each incremental split
BATCH_SIZE = 128  # batch size of the data loaders

In [10]:
# Initialize logs
logs = []

# Define loss function
criterion = nn.BCEWithLogitsLoss()

for run_i in range(NUM_RUNS):
    # Every run starts from a fresh network and a fresh iCaRL learner
    net = resnet32()
    icarl = iCaRL(device=DEVICE,
                  net=net,
                  lr=LR,
                  momentum=MOMENTUM,
                  weight_decay=WEIGHT_DECAY,
                  milestones=MILESTONES,
                  gamma=GAMMA,
                  num_epochs=NUM_EPOCHS,
                  batch_size=BATCH_SIZE,
                  train_transform=train_transform)

    for split_i in range(10):
        print(f"## Split {split_i} of run {run_i} ##")

        # Learn the classes of the current split
        train_split = train_subsets[run_i][split_i]
        val_split = val_subsets[run_i][split_i]
        icarl.incremental_train(split_i, train_split, val_split)

        # Evaluate on all classes seen so far
        test_split = test_subsets[run_i][split_i]
        icarl.test(test_split)

        # @todo: implement logging and plots

## Split 0 of run 0 ##
Epoch: 1, LR: [2]
Train loss: 0.3998215079307556, Train accuracy: 0.13258928571428572
Validation loss: 0.5350239177544912, Validation accuracy: 0.125
Best model updated

Epoch: 2, LR: [2]
Train loss: 0.3167587152549199, Train accuracy: 0.15290178571428573
Validation loss: 0.32000892361005145, Validation accuracy: 0.1328125
Best model updated

Epoch: 3, LR: [2]
Train loss: 0.31612454652786254, Train accuracy: 0.15803571428571428
Validation loss: 0.31432344516118366, Validation accuracy: 0.1484375
Best model updated

Epoch: 4, LR: [2]
Train loss: 0.3119645374161856, Train accuracy: 0.17254464285714285
Validation loss: 0.3255065679550171, Validation accuracy: 0.15625
Best model updated

Epoch: 5, LR: [2]
Train loss: 0.3040930552142007, Train accuracy: 0.19620535714285714
Validation loss: 0.3026558856169383, Validation accuracy: 0.234375
Best model updated

Epoch: 6, LR: [2]
Train loss: 0.29576962079320634, Train accuracy: 0.22834821428571428
Validation loss: 0.31097

KeyboardInterrupt: ignored