<a href="https://colab.research.google.com/github/gianluigilopardo/Open-World-Recognition/blob/main/ablation_study/ICARL_Main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Clone the GitHub repository

In [8]:
import os
import logging
import sys

In [9]:
import torch
# Sanity check: True means PyTorch can see a CUDA GPU for training.
torch.cuda.is_available()

True

In [10]:
if not os.path.isdir('./owr'):
  !git clone https://github.com/gianluigilopardo/Open-World-Recognition.git
  !mv 'Open-World-Recognition' 'owr'

# Import packages

In [11]:
from owr.ablation_study import ResNet
from owr.ablation_study.dataset import Subset
from owr.ablation_study.icarl import classify
from owr.ablation_study.icarl import incremental_train
from owr.ablation_study.icarl import update_representation
from owr.ablation_study.icarl import construct_exemplar_set
from owr.ablation_study.icarl import reduce_exemplars
from owr.ablation_study.icarl import generate_new_exemplars
from owr.ablation_study.models import compute_loss
from owr.ablation_study.models import train_network
from owr.ablation_study import params
from owr.ablation_study import utils
from owr.ablation_study.utils import get_classes_names
from owr.ablation_study.utils import get_task_indexes
from owr.ablation_study.utils import splitter
from owr.ablation_study.utils import map_splits
from owr.ablation_study.utils import get_classes
from owr.ablation_study.utils import get_indexes
import owr.ablation_study.models

In [12]:
from torchvision import models
import torch.nn as nn
import torch
import torch.optim as optim
import torchvision
import numpy as np
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.nn import functional as F
import random

In [13]:
# Seed every RNG this notebook relies on so that runs are repeatable.
# NOTE(review): the owr package may also seed internally (not visible here), and
# cuDNN kernels can still be non-deterministic on GPU.
random.seed(params.SEED)
np.random.seed(params.SEED)
torch.manual_seed(params.SEED)

print(params.SEED)
print(params.NUM_WORKERS)

42
2


In [14]:
# Configured epoch count; presumably the per-task training budget used by
# incremental_train — confirm in owr.ablation_study.
print(params.NUM_EPOCHS)

70


# Define transforms and datasets

In [15]:
# Training pipeline: augment with random 32x32 crops of a 4-pixel padded image
# and random horizontal flips, then convert to a tensor and normalize.
train_transformer = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Evaluation pipeline: no augmentation, same tensor conversion and per-channel
# normalization (mean 0.5, std 0.5 maps ToTensor's [0, 1] range to [-1, 1]).
test_transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

In [16]:
from torchvision import datasets

# CIFAR-100 train/test splits stored under ./data. The archive is downloaded on
# first use; later calls only verify it.
trainDS = datasets.CIFAR100(root='data', train=True, download=True,
                            transform=train_transformer)
testDS = datasets.CIFAR100(root='data', train=False, download=True,
                           transform=test_transformer)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar-100-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=169001437.0), HTML(value='')))


Extracting data/cifar-100-python.tar.gz to data
Files already downloaded and verified


In [17]:
# Partition of the class ids into task groups. Assumed (from how the training
# loop slices it and flattens the prefix into `col`) to hold one group of class
# ids per task, in task order — confirm in owr.ablation_study.utils.splitter.
splits = splitter()

# Define the network (ResNet-32)

In [18]:
# Project ResNet-32 backbone moved to the configured device. num_classes
# presumably sets the width of the output layer; it uses params.NUM_CLASSES,
# the same constant that sizes the `exemplars` store, so the two cannot drift
# apart.
ICaRL = ResNet.resnet32(num_classes=params.NUM_CLASSES).to(params.DEVICE)

In [19]:
# Run-wide state carried across the incremental tasks of the training loop.
# Exemplar store with one initially-empty (None) slot per class; presumably
# indexed by class id — confirm in icarl.incremental_train.
exemplars = [None for _ in range(params.NUM_CLASSES)]

# Test indexes accumulate task after task; accuracy lists get one entry per task.
test_indexes, accs_train, accs_test = [], [], []

# 'nme' makes the loop map raw labels through map_splits before comparing them
# with predictions; any other value compares the raw labels directly.
classifier = 'nme'

In [20]:
# Disabled alternative to the 'nme' classifier: k-NN over odd k values 3..11.
# clf_params looks like a parameter grid, presumably meant for GridSearchCV;
# neither name is used anywhere else in the visible notebook.
#classifier = KNeighborsClassifier()
#clf_params = {'n_neighbors' : np.arange(3,13,2)}

In [None]:
def evaluate_icarl(loader, net, exemplar_sets, task, dataset, col, mean=None):
    """Classify every batch of `loader` with `classify` and measure accuracy.

    Ground-truth labels go through `map_splits(lbl, col)` when the notebook-level
    `classifier` setting is 'nme'; otherwise they are compared as-is.

    Args:
        loader: DataLoader yielding `(images, labels, _)` batches.
        net: model forwarded to `classify`.
        exemplar_sets: exemplar store forwarded to `classify`.
        task: class offset of the current task, forwarded to `classify`.
        dataset: dataset forwarded to `classify` (the training set here).
        col: int array of the class ids seen so far, used by `map_splits`.
        mean: value returned by an earlier `classify` call, or None. It is
            threaded through every batch so `classify` can reuse it as a cache.

    Returns:
        `(accuracy, mean, preds, labels)`, where `preds` and `labels` are 1-D
        numpy arrays covering the whole loader. `accuracy` is NaN for an empty
        loader.
    """
    n_seen = 0
    n_correct = 0
    pred_batches, label_batches = [], []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for img, lbl, _ in loader:
            img = img.float().to(params.DEVICE)
            preds, mean = classify(img, exemplar_sets, net, task, dataset, mean)
            preds = preds.to(params.DEVICE)
            if classifier == 'nme':
                labels = map_splits(lbl, col).to(params.DEVICE)
            else:
                labels = lbl.to(params.DEVICE)
            n_seen += len(lbl)
            n_correct += torch.sum(preds == labels).item()
            # Collect per-batch arrays and concatenate once (avoids O(n^2) growth).
            pred_batches.append(preds.detach().cpu().numpy())
            label_batches.append(labels.detach().cpu().numpy())
    accuracy = n_correct / n_seen if n_seen else float('nan')
    all_preds = np.concatenate(pred_batches) if pred_batches else np.empty(0)
    all_labels = np.concatenate(label_batches) if label_batches else np.empty(0)
    return accuracy, mean, all_preds, all_labels


for task in range(0, params.NUM_TASKS * params.TASK_SIZE, params.TASK_SIZE):
    train_indexes = get_task_indexes(trainDS, task)
    # Evaluate on every class seen so far, so the test index list grows each task.
    # (Rebound rather than extended in place so earlier Subsets are not aliased.)
    test_indexes = test_indexes + get_task_indexes(testDS, task)

    train_dataset = Subset(trainDS, train_indexes, transform=train_transformer)
    test_dataset = Subset(testDS, test_indexes, transform=test_transformer)

    train_loader = DataLoader(train_dataset, num_workers=params.NUM_WORKERS,
                              batch_size=params.BATCH_SIZE, shuffle=True)
    # Accuracy does not depend on order, so the test loader is not shuffled.
    test_loader = DataLoader(test_dataset, num_workers=params.NUM_WORKERS,
                             batch_size=params.BATCH_SIZE, shuffle=False)

    print(task)
    ICaRL, exemplars = incremental_train(trainDS, ICaRL, exemplars, task, train_transformer)

    # Class ids of all tasks seen so far. Assumes `splits` holds one class-id
    # group per task in task order; axis=None flattens each group.
    n_seen_tasks = task // params.TASK_SIZE + 1
    col = np.concatenate(splits[:n_seen_tasks], axis=None).astype(int)

    # Train accuracy; the class means it computes are reused for the test pass.
    accuracy, mean, _, _ = evaluate_icarl(train_loader, ICaRL, exemplars, task, trainDS, col)
    print(f'task: {task}', f'train accuracy = {accuracy}')
    accs_train.append(accuracy)

    # Test accuracy on all classes seen so far; keep predictions and labels
    # for later inspection (e.g. a confusion matrix).
    accuracy, _, tot_preds, tot_lab = evaluate_icarl(
        test_loader, ICaRL, exemplars, task, trainDS, col, mean)
    print(f'task: {task}', f'test accuracy = {accuracy}')
    accs_test.append(accuracy)

0
Step: 0, Epoch: 0, Loss: 0.03613166883587837, Accuracy: 0.2172
Step: 0, Epoch: 1, Loss: 0.026905639097094536, Accuracy: 0.4188
Step: 0, Epoch: 2, Loss: 0.024034203961491585, Accuracy: 0.5212
Step: 0, Epoch: 3, Loss: 0.020660076290369034, Accuracy: 0.5336
Step: 0, Epoch: 4, Loss: 0.02451731078326702, Accuracy: 0.593
Step: 0, Epoch: 5, Loss: 0.026991836726665497, Accuracy: 0.6308
Step: 0, Epoch: 6, Loss: 0.01978614740073681, Accuracy: 0.6138
Step: 0, Epoch: 7, Loss: 0.02819369174540043, Accuracy: 0.6452
Step: 0, Epoch: 8, Loss: 0.017216330394148827, Accuracy: 0.654
Step: 0, Epoch: 9, Loss: 0.01488550379872322, Accuracy: 0.697
Step: 0, Epoch: 10, Loss: 0.017645662650465965, Accuracy: 0.7152
Step: 0, Epoch: 11, Loss: 0.014202043414115906, Accuracy: 0.7252
