In [None]:
import os
import os.path
import logging

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset, DataLoader, random_split, ConcatDataset
from torch.backends import cudnn

from torchvision.datasets import CIFAR100
from torchvision import transforms
from torchsummary import summary

from PIL import Image
from tqdm import tqdm

import copy
import random
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier

In [None]:
# --- Hardware / dataset -------------------------------------------------------
DEVICE = 'cuda'          # target passed to every `.to(DEVICE)` call
NUM_CLASSES = 100        # sizes the per-class index and exemplar lists (CIFAR-100)

# --- Batching and exemplar memory --------------------------------------------
BATCH_SIZE = 128         # batch size of the train/test/feature DataLoaders
K = 2000                 # total exemplar budget, divided by the number of classes seen

# --- Settings consumed by the CE + CE distillation run ------------------------
# (SGD optimizer and MultiStepLR scheduler in icarl_ablation_ce_ce)
LR_CE = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY_CE = 5e-4
MILESTONES_CE = [50]
GAMMA_CE = 0.2
NUM_EPOCHS = 60

# --- Settings not referenced by the cells shown in this section ---------------
# Presumably used by other loss variants of the experiment; confirm before removing.
LR = 2.0
LR_L2 = 0.25
LR_LFC = 0.1
LR_L1 = 0.025
WEIGHT_DECAY = 1e-5
# NOTE(review): the second milestone (63) lies beyond NUM_EPOCHS (60), so it
# would never trigger if this schedule were used with the epoch count above.
MILESTONES = [49, 63]
MILESTONES_L2 = [30, 45]
MILESTONES_LFC = [40]
GAMMA = 0.2
GAMMA_LFC = 0.2

# --- Logging and reproducibility ---------------------------------------------
LOG_FREQUENCY = 10       # print the losses every LOG_FREQUENCY training steps
SEED = 1994              # seed given to `random.seed` before the class split

In [None]:

# Per-channel mean / standard deviation handed to transforms.Normalize by both
# pipelines below.
CIFAR100_MEAN = [0.5071, 0.4865, 0.4409]
CIFAR100_STD = [0.2673, 0.2564, 0.2761]

# Training pipeline: random 32-pixel crop after 4 pixels of padding and a random
# horizontal flip, followed by tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])

# Evaluation pipeline: tensor conversion and normalization only (no augmentation).
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR100_MEAN, std=CIFAR100_STD),
])


In [None]:
#@title
# Clone the ICARL GitHub repository unless ./ICARL already exists. It provides
# the resnet_cifar modules imported below, and its `data` folder is used as the
# CIFAR-100 download root.
if not os.path.isdir('./ICARL'):
  !git clone https://github.com/lbonasera1/ICARL.git

# These imports have to stay after the clone step: the ICARL package is only
# present on disk once the repository has been fetched.
from ICARL.ResNet_CIFAR_100.resnet_cifar import resnet32
from ICARL.ResNet_CIFAR_100.resnet_cifar_norelu import resnet32_norelu

# CIFAR-100 train/test splits, downloaded into ./ICARL/data when missing.
# NOTE(review): train_dataset carries the random-augmentation train_transform,
# so every consumer that reads samples through it (including the exemplar
# feature extraction in constrExemplars and create_classifier) sees augmented
# images — confirm that this is intended.
train_dataset = CIFAR100(root='./ICARL/data', train=True, download=True, transform=train_transform)
test_dataset = CIFAR100(root='./ICARL/data', train=False, download=True, transform=test_transform)

# Number of classes introduced at each incremental step.
CLASSES_PER_BATCH = 10

train_subsets = []   # one training Subset per incremental step
test_subsets = []    # one test Subset per incremental step
batch_classes = []   # class ids introduced at each step
class_indexes = [None] * NUM_CLASSES   # per class: shuffled train-sample indices

random.seed(SEED)
class_order = list(range(NUM_CLASSES))
random.shuffle(class_order)

# Group the sample indices by label with a single pass over each split instead
# of rescanning the full target list once per class. Indices stay in ascending
# order inside each group, exactly as a per-class scan would produce them, so
# the shuffles below consume the random stream identically.
train_indices_by_class = {}
for idx, target in enumerate(train_dataset.targets):
  train_indices_by_class.setdefault(target, []).append(idx)

test_indices_by_class = {}
for idx, target in enumerate(test_dataset.targets):
  test_indices_by_class.setdefault(target, []).append(idx)

for i in range(NUM_CLASSES // CLASSES_PER_BATCH):
  # draw the next group of classes from the shuffled class order
  batch_classes.append([class_order.pop() for _ in range(CLASSES_PER_BATCH)])

  # collect the train indices of the i-th batch and, per class, a shuffled copy
  # of its indices (the candidate pool later used for exemplar selection)
  train_indices = []
  for target in batch_classes[i]:
    per_class = list(train_indices_by_class.get(target, []))
    train_indices.extend(per_class)
    random.shuffle(per_class)
    class_indexes[target] = per_class

  random.shuffle(train_indices)
  train_subsets.append(Subset(train_dataset, train_indices))

  # collect the test indices of the i-th batch
  test_indices = []
  for target in batch_classes[i]:
    test_indices.extend(test_indices_by_class.get(target, []))

  random.shuffle(test_indices)
  test_subsets.append(Subset(test_dataset, test_indices))

In [None]:
def load_checkpoint(filepath):
  """Reload a fully pickled model and freeze it for use as a distillation teacher.

  The model trained on the previous batch of classes is loaded from `filepath`,
  every parameter gets `requires_grad = False`, and the module is switched to
  evaluation mode.

  Args:
    filepath: path of a file written with `torch.save(model, filepath)`.

  Returns:
    The restored module, frozen and in eval mode.
  """
  # NOTE(security): this unpickles a whole Python object, which can execute
  # arbitrary code. Only use it on checkpoints written by this notebook.
  try:
    # Since PyTorch 2.6 `torch.load` defaults to weights_only=True, which
    # refuses to unpickle a complete nn.Module; request full unpickling.
    model = torch.load(filepath, weights_only=False)
  except TypeError:
    # Older PyTorch releases do not accept the `weights_only` keyword.
    model = torch.load(filepath)
  for parameter in model.parameters():
    parameter.requires_grad = False
  model.eval()
  return model

In [None]:
def constrExemplars(exemplars, classes, class_indexes, model, m):
  """Pick up to `m` exemplars per class by herding and return the class means.

  For every class in `classes`, the L2-normalized 64-dimensional model outputs
  of all its candidate samples are computed. Exemplars are then chosen
  greedily: at each step the candidate is taken whose inclusion brings the
  normalized mean of the selected features closest to the normalized class mean.

  Args:
    exemplars: list indexed by class id; `exemplars[c]` is overwritten with the
      chosen dataset indices, in selection order.
    classes: class ids to process.
    class_indexes: list indexed by class id holding the candidate indices into
      `train_dataset`; it is not modified.
    model: network exposing `set_flag`. `set_flag(False)` is called before the
      forward passes (presumably switching the network to feature output —
      confirm in resnet_cifar) and `set_flag(True)` afterwards.
    m: number of exemplars requested per class; capped at the number of
      available candidates.

  Returns:
    Tensor of shape (len(classes), 64) with one normalized class mean per row.

  Side effects:
    The model is left in training mode with `set_flag(True)`, also when an
    exception interrupts the selection.
  """
  pdist = nn.PairwiseDistance(p=2)
  model.train(False)
  model.set_flag(False)
  # Allocate on DEVICE (not a hard-coded .cuda()) so buffers always live where
  # the images and model outputs are.
  class_means = torch.empty((0, 64), device=DEVICE)
  try:
    with torch.no_grad():
      for c in classes:
        # Work on a copy: candidates are popped from it as they get selected.
        indexes = list(class_indexes[c])
        subset = Subset(train_dataset, indexes)
        dataLoader = DataLoader(subset, batch_size=BATCH_SIZE)

        # L2-normalized features of every candidate, in the order of `indexes`.
        feature_batches = [torch.empty((0, 64), device=DEVICE)]
        for image, _ in dataLoader:
          output = model(image.to(DEVICE))
          feature_batches.append(nn.functional.normalize(output, p=2, dim=1))
        features = torch.cat(feature_batches)

        class_mean = torch.mean(features, 0)
        class_mean = nn.functional.normalize(class_mean, p=2, dim=0).view(-1, 64)
        class_means = torch.cat((class_means, class_mean))

        current_features = torch.empty((0, 64), device=DEVICE)
        exemplars_indexes = []
        # Never ask for more exemplars than there are candidates: argmin over
        # an empty distance tensor would raise.
        for k in range(min(m, len(indexes))):
          selected_sum = torch.sum(current_features, 0)
          # Mean of the selected set if each remaining candidate were added.
          candidate_means = (features + selected_sum) * (1.0/(k+1))
          candidate_means = nn.functional.normalize(candidate_means, p=2, dim=1)
          distances = pdist(candidate_means, class_mean)
          index = torch.argmin(distances).item()
          # Move the chosen feature from the candidate pool to the selected set.
          current_features = torch.cat((current_features, features[index].view(-1, 64)))
          features = torch.cat((features[:index], features[index+1:]))
          exemplars_indexes.append(indexes.pop(index))
        exemplars[c] = exemplars_indexes
  finally:
    model.set_flag(True)
    model.train()
  return class_means

In [None]:
def reduceExemplars(exemplars, classes, m):
  """Keep only the first `m` stored exemplars of every class in `classes`.

  Args:
    exemplars: list indexed by class id; each touched entry is replaced by a
      new list holding its first `m` elements.
    classes: class ids whose exemplar lists are shortened.
    m: number of exemplars to keep per class.
  """
  for class_id in classes:
    kept = exemplars[class_id][:m]
    exemplars[class_id] = kept

In [None]:
def distillationLossCE(outputs, outputs_old, labels_old):
  """Cross-entropy distillation loss restricted to the previously learned classes.

  The softmax of `outputs_old` is used as soft target for the log-softmax of
  `outputs`. The element-wise terms are averaged over the batch dimension and
  the per-class averages of the classes in `labels_old` are summed.

  Args:
    outputs: (batch, classes) logits of the network being trained.
    outputs_old: (batch, classes) logits of the previous network.
    labels_old: sequence of class ids to include; an empty sequence yields 0.

  Returns:
    Scalar tensor on the same device as `outputs`.
  """
  weights = torch.nn.functional.softmax(outputs_old, dim=1)
  logs = torch.nn.functional.log_softmax(outputs, dim=1)
  per_class = torch.mean(weights * -logs, dim=0)
  # Build the index on the device of the data rather than assuming CUDA, and
  # gather all old-class terms in one call instead of concatenating one by one.
  old_idx = torch.as_tensor(labels_old, dtype=torch.long, device=per_class.device)
  return per_class.index_select(0, old_idx).sum()

In [None]:
def create_classifier(labels_old,labels_new,exemplars,net, num):
    """Fit a k-NN classifier on the network outputs of the stored exemplars.

    Args:
        labels_old: list of class ids learned in earlier batches.
        labels_new: list of class ids of the current batch.
        exemplars: list indexed by class id holding `train_dataset` indices.
        net: network exposing `set_flag`; `set_flag(False)` is called before the
            forward passes (presumably selecting feature output — confirm in
            resnet_cifar). The outputs are expected to be 64-dimensional.
        num: number of neighbours of the KNeighborsClassifier.

    Returns:
        The fitted KNeighborsClassifier.

    Side effects:
        Leaves `net` in eval mode with `set_flag(False)` and prints the shapes
        of the training matrices.
    """
    KNN = KNeighborsClassifier(n_neighbors = num)
    classi = labels_old + labels_new
    exemplars_index = [idx for c in classi for idx in exemplars[c]]
    subsetforclassifier = Subset(train_dataset, exemplars_index)
    net.set_flag(False)
    net.eval()
    dlfc = DataLoader(subsetforclassifier, batch_size = BATCH_SIZE)
    # Seed with empty tensors so that the concatenation below keeps the
    # expected shape and dtype even for an empty exemplar set.
    feature_batches = [torch.empty((0,64))]
    label_batches = [torch.empty((0),dtype = torch.long)]
    # Inference only: without no_grad every batch would keep its autograd
    # graph alive until the features are detached.
    with torch.no_grad():
        for images, labels in dlfc:
            feature_batches.append(net(images.to(DEVICE)).to('cpu'))
            label_batches.append(labels)
    X_train = torch.cat(feature_batches).numpy()
    Y_train = torch.cat(label_batches).numpy()
    print("X_train:",X_train.shape)
    print("Y_train:",Y_train.shape)

    KNN.fit(X_train,Y_train)
    # No GPU tensor is referenced any more at this point, so the cached blocks
    # can be returned to the driver.
    torch.cuda.empty_cache()
    net.set_flag(False)
    return KNN
    


In [None]:

def _plot_train_losses(train_clf_loss, train_dist_loss):
  """Plot the per-epoch classification loss and, when present, distillation loss."""
  fig, ax = plt.subplots(figsize=(8, 5))
  ax.plot(np.arange(0, len(train_clf_loss)), train_clf_loss, c='blue', linestyle='-', label='Classification loss')
  if train_dist_loss:
    ax.plot(np.arange(0, len(train_dist_loss)), train_dist_loss, c='red', linestyle='-', label='Distillation loss')
  plt.title('Train loss')
  plt.grid()
  plt.legend()
  plt.tight_layout()
  plt.show()
  plt.close()


def _knn_accuracy(net, classifier, dataloader, num_samples):
  """Return the fraction of `num_samples` test samples the k-NN classifier gets right."""
  running_corrects = 0
  with torch.no_grad():
    for images, labels in tqdm(dataloader):
      out = net(images.to(DEVICE)).cpu().numpy()
      pred = classifier.predict(out)
      # Compare the whole batch at once instead of sample by sample.
      running_corrects += int((pred == labels.numpy()).sum())
  return running_corrects / float(num_samples)


def icarl_ablation_ce_ce(train_subsets, test_subsets, batch_classes, class_indexes, criterion_dist):
  """Class-incremental training with CE classification, a pluggable distillation loss and k-NN evaluation.

  For each batch of classes the function
    1. trains a ResNet-32 on the new classes plus the stored exemplars of the
       old ones, adding `criterion_dist` against the frozen network of the
       previous batch from the second batch on;
    2. shrinks the old exemplar sets and selects exemplars for the new classes
       so that the total stays within the budget `K`;
    3. fits a k-NN classifier on the exemplar features and measures its
       accuracy on the test samples of all classes seen so far;
    4. saves the network to './ICARL/prev_net.pt' as teacher for the next batch.

  Args:
    train_subsets: one training dataset per class batch.
    test_subsets: one test dataset per class batch.
    batch_classes: per batch, the list of class ids it introduces.
    class_indexes: list indexed by class id with that class's train indices.
    criterion_dist: callable `(outputs, outputs_old, labels_old) -> scalar tensor`.

  Returns:
    (batches_accuracy, classifier): the accuracy after every batch and the
    k-NN classifier fitted after the last one (None when there are no batches).
  """
  net = resnet32()
  net = net.to(DEVICE)
  # The bare expression `cudnn.benchmark` had no effect; actually enable the
  # cuDNN autotuner.
  cudnn.benchmark = True
  criterion_clf = nn.CrossEntropyLoss()
  batches_accuracy = []
  classes_seen = []
  labels_old = []
  test_subList = []
  exemplars = [None] * NUM_CLASSES
  classifier = None

  # iterate over class batches
  for i in range(len(train_subsets)):
    train_clf_loss = []
    train_dist_loss = []
    # evaluate on every class seen so far
    test_subList.append(test_subsets[i])
    test_subset = ConcatDataset(test_subList)
    train_subset = train_subsets[i]
    prev_net = None
    if i > 0:
      # classes of the previous batch become "old" classes
      labels_old.extend(batch_classes[i-1])
      # add the stored exemplars of all old classes to the training data
      train_subList = [idx for k in labels_old for idx in exemplars[k]]
      random.shuffle(train_subList)
      subset = Subset(train_dataset, train_subList)
      train_subset = ConcatDataset([train_subset, subset])
      # The checkpoint does not change while this batch is being trained, so
      # the frozen teacher is loaded once here rather than at every step.
      prev_net = load_checkpoint('./ICARL/prev_net.pt')

    train_dataloader = DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
    test_dataloader = DataLoader(test_subset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

    # optimizer and schedule are restarted for every batch of classes
    optimizer = optim.SGD(net.parameters(), lr=LR_CE, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY_CE)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES_CE, gamma=GAMMA_CE)

    current_step = 0
    for epoch in range(NUM_EPOCHS):
      print('Starting epoch {}/{}, LR = {}, Batch {}'.format(epoch+1, NUM_EPOCHS, scheduler.get_last_lr(), i+1))
      net.train()
      epoch_clf_losses = []
      epoch_dist_losses = []
      for images, labels in train_dataloader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = net(images)

        loss = criterion_clf(outputs, labels)
        epoch_clf_losses.append(loss.item())

        if prev_net is not None:
          with torch.no_grad():
            outputs_old = prev_net(images)
          loss_dist = criterion_dist(outputs, outputs_old, labels_old)
          if current_step % LOG_FREQUENCY == 0:
            print('Step {}, Classification Loss {}, Distillation Loss {}'.format(current_step, loss.item(), loss_dist.item()))
          epoch_dist_losses.append(loss_dist.item())
          loss = loss + loss_dist

        loss.backward()
        optimizer.step()
        current_step += 1

      train_clf_loss.append(np.mean(epoch_clf_losses))
      if i > 0:
        train_dist_loss.append(np.mean(epoch_dist_losses))
      scheduler.step()

    _plot_train_losses(train_clf_loss, train_dist_loss)

    # Exemplars per class: the budget K split over all classes seen so far
    # (equals K // (10*(i+1)) for batches of ten classes).
    num_seen = len(labels_old) + len(batch_classes[i])
    m = K // num_seen
    if i > 0:
      reduceExemplars(exemplars=exemplars, classes=labels_old, m=m)

    # select exemplars for the classes of the current batch
    constrExemplars(exemplars, batch_classes[i], class_indexes, net, m)

    # constrExemplars hands the network back in train mode with the flag set;
    # switch to eval mode with set_flag(False) for the k-NN fit and the test.
    net.set_flag(False)
    net.train(False)

    # at least one neighbour, even for a very small exemplar budget
    classifier = create_classifier(labels_old,batch_classes[i],exemplars,net,max(1, m//2))

    score = _knn_accuracy(net, classifier, test_dataloader, len(test_subset))
    net.set_flag(True)
    # save the network of this batch: it is the distillation teacher of the next
    torch.save(net, './ICARL/prev_net.pt')
    print("Test accuracy of batch {} equal to: {}".format(i+1, score))
    batches_accuracy.append(score)
    classes_seen.append(num_seen)

  # Accuracy against the number of classes learned so far. The x values are the
  # cumulative class counts (10, 20, ...) rather than 0, 10, ..., 90.
  fig, ax = plt.subplots(figsize=(8, 5))
  ax.plot(classes_seen, batches_accuracy, c='blue', linestyle='-', marker='.')
  ax.set_xlabel('Number of classes')
  ax.set_ylabel('Accuracy')
  plt.title('Accuracy graph vs Number of classes')
  plt.tight_layout()
  plt.grid()
  plt.show()
  plt.close()

  return batches_accuracy, classifier

In [None]:
# Run the incremental experiment with the cross-entropy distillation loss.
# `scores` holds the accuracy after each batch of classes, `classificatore`
# the k-NN classifier fitted after the last batch.
scores, classificatore = icarl_ablation_ce_ce(
    train_subsets=train_subsets,
    test_subsets=test_subsets,
    batch_classes=batch_classes,
    class_indexes=class_indexes,
    criterion_dist=distillationLossCE,
)