<a href="https://colab.research.google.com/github/luciainnocenti/IncrementalLearning/blob/newICaRL/ICaRLMain.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Clone the GitHub repository

In [0]:
import os
import logging
import sys

In [2]:
if not os.path.isdir('./DatasetCIFAR'):
  !git clone -b newICaRL https://github.com/luciainnocenti/IncrementalLearning.git
  !mv 'IncrementalLearning' 'DatasetCIFAR'

Cloning into 'IncrementalLearning'...
remote: Enumerating objects: 89, done.[K
remote: Counting objects: 100% (89/89), done.[K
remote: Compressing objects: 100% (89/89), done.[K
remote: Total 1363 (delta 56), reused 0 (delta 0), pack-reused 1274[K
Receiving objects: 100% (1363/1363), 759.88 KiB | 17.27 MiB/s, done.
Resolving deltas: 100% (870/870), done.


# Import packages

In [0]:
from DatasetCIFAR.data_set import Dataset 
from DatasetCIFAR import ResNet
from DatasetCIFAR import utils
from DatasetCIFAR import params
from DatasetCIFAR import ICaRLModel
from torchvision import models
import torch.nn as nn
import torch
import torch.optim as optim
import torchvision
import numpy as np

from torchvision import transforms
from torch.utils.data import Subset, DataLoader
from torch.nn import functional as F
import random
# Seed every RNG the notebook may draw from (Python, NumPy, PyTorch) so runs are repeatable.
# Seeding only `random` leaves e.g. PyTorch weight initialisation and augmentation unseeded.
random.seed(params.SEED)
np.random.seed(params.SEED)
torch.manual_seed(params.SEED)

# Define transforms and datasets

In [0]:
# Per-channel normalisation shared by both pipelines: (x - 0.5) / 0.5 on each of the 3 channels.
_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training pipeline: random 32x32 crop (4px padding) and horizontal flip for augmentation.
train_transformer = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])

# Evaluation pipeline: no augmentation, only tensor conversion and normalisation.
test_transformer = transforms.Compose([
    transforms.ToTensor(),
    _normalize,
])

In [5]:
# Train/test datasets from the project-local Dataset wrapper (the output below shows it downloading CIFAR-100).
trainDS = Dataset(transform=train_transformer, train=True)
testDS = Dataset(transform=test_transformer, train=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar-100-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/cifar-100-python.tar.gz to data
Files already downloaded and verified


In [0]:
# Class-group splits exposed by each dataset; train_splits is used later to build the label map.
train_splits, test_splits = trainDS.splits, testDS.splits

# Define Network

In [0]:
# ResNet-32 backbone with a 100-way output, moved to the configured device.
ICaRL = ResNet.resnet32(num_classes=100).to(params.DEVICE)

In [0]:
# 100 empty exemplar slots (presumably one per class); replaced by ICaRLModel.incrementalTrain each task.
exemplars = [None for _ in range(100)]

# Test indexes accumulate across tasks in the training loop below.
test_indexes = list()

In [0]:
def _seen_classes(task):
  """Return the label values of every train split seen up to and including ``task``.

  ``task`` is the loop offset (0, TASK_SIZE, 2*TASK_SIZE, ...), so
  ``task // params.TASK_SIZE + 1`` splits have been seen so far. The splits are
  flattened, concatenated in order and cast to int; the result is the ``col``
  argument passed to ``utils.mapFunction``.
  """
  # Derive the number of seen splits from TASK_SIZE instead of assuming groups of 10.
  n_seen = task // params.TASK_SIZE + 1
  groups = [np.asarray(split) for split in train_splits[:n_seen]]
  if not groups:
    return np.array([], dtype=int)
  # axis=None flattens each split before concatenating.
  return np.concatenate(groups, axis=None).astype(int)


def _accuracy(loader, net, exemplars, col, task):
  """Return the fraction of samples in ``loader`` that ``net`` classifies correctly.

  Predictions come from ``ICaRLModel.classify(img, exemplars, net, task)``.
  Raw labels are remapped with ``utils.mapFunction(lbl, col)`` (presumably into
  the index space ``classify`` predicts in) before comparison.
  Returns NaN when ``loader`` yields no samples instead of dividing by zero.
  """
  total = 0
  correct = 0
  # Evaluation only: no autograd graph is needed for these forward passes.
  with torch.no_grad():
    for img, lbl, _ in loader:
      img = img.float().to(params.DEVICE)
      preds = ICaRLModel.classify(img, exemplars, net, task).to(params.DEVICE)
      labels = utils.mapFunction(lbl, col).to(params.DEVICE)
      total += len(lbl)
      correct += torch.sum(preds == labels).item()
  return correct / total if total else float('nan')


for task in range(0, 100, params.TASK_SIZE):
  train_indexes = trainDS.__getIndexesGroups__(task)
  # The test subset accumulates the indexes of every task so far.
  test_indexes = test_indexes + testDS.__getIndexesGroups__(task)

  train_dataset = Subset(trainDS, train_indexes)
  test_dataset = Subset(testDS, test_indexes)

  train_loader = DataLoader(train_dataset, num_workers=params.NUM_WORKERS, batch_size=params.BATCH_SIZE)
  test_loader = DataLoader(test_dataset, num_workers=params.NUM_WORKERS, batch_size=params.BATCH_SIZE)

  ICaRL, exemplars = ICaRLModel.incrementalTrain(task, trainDS, ICaRL, exemplars)

  col = _seen_classes(task)

  # NOTE(review): trainDS applies train_transformer, so this "train accuracy" is
  # measured on randomly cropped/flipped images rather than the raw train set.
  accuracy = _accuracy(train_loader, ICaRL, exemplars, col, task)
  print(f'task: {task}', f'train accuracy = {accuracy}')

  accuracy = _accuracy(test_loader, ICaRL, exemplars, col, task)
  print(f'task: {task}', f'test accuracy = {accuracy}')

col =  [94 63 74 21 35 56 91 96 87 48]
col[:10] [94 63 74 21 35 56 91 96 87 48]
At step  0  and at epoch =  0  the loss is =  0.03136654198169708  and accuracy is =  0.1816
At step  0  and at epoch =  1  the loss is =  0.027404727414250374  and accuracy is =  0.387
At step  0  and at epoch =  2  the loss is =  0.015392475761473179  and accuracy is =  0.5378
At step  0  and at epoch =  3  the loss is =  0.016766458749771118  and accuracy is =  0.6192
At step  0  and at epoch =  4  the loss is =  0.02472279965877533  and accuracy is =  0.6684
At step  0  and at epoch =  5  the loss is =  0.010973197408020496  and accuracy is =  0.7034
At step  0  and at epoch =  6  the loss is =  0.012941603548824787  and accuracy is =  0.7548
At step  0  and at epoch =  7  the loss is =  0.013217004016041756  and accuracy is =  0.7886
At step  0  and at epoch =  8  the loss is =  0.01770365983247757  and accuracy is =  0.8026
At step  0  and at epoch =  9  the loss is =  0.014674614183604717  and accura

In [0]:
# Scratch check: the first two train splits concatenated in split order.
col = np.concatenate([train_splits[0], train_splits[1]])
print(col)

In [0]:
# Scratch check: entries of `col` from index 10 onward (presumably the second split's labels when TASK_SIZE == 10).
col[10:]