<a href="https://colab.research.google.com/github/gianluigilopardo/Open-World-Recognition/blob/main/baseline/ICARL_Main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Clone the GitHub repository

In [1]:
import os
import logging
import sys

In [2]:
import torch
# Sanity check: displays whether PyTorch can see a CUDA device in this runtime.
torch.cuda.is_available()

True

In [3]:
# Fetch the project sources only once: the clone is skipped when ./owr already exists,
# so re-running this cell is harmless.
# The checkout is renamed to 'owr' because the later cells import it as the `owr`
# package (a hyphenated directory name cannot be used in an import statement).
if not os.path.isdir('./owr'):
  !git clone https://github.com/gianluigilopardo/Open-World-Recognition.git
  !mv 'Open-World-Recognition' 'owr'

Cloning into 'Open-World-Recognition'...
remote: Enumerating objects: 443, done.[K
remote: Counting objects: 100% (178/178), done.[K
remote: Compressing objects: 100% (73/73), done.[K
remote: Total 443 (delta 118), reused 143 (delta 103), pack-reused 265[K
Receiving objects: 100% (443/443), 2.06 MiB | 5.54 MiB/s, done.
Resolving deltas: 100% (269/269), done.


# Import packages

In [5]:
from owr.baseline import ResNet
from owr.baseline.dataset import Subset
from owr.baseline.icarl import classify
from owr.baseline.icarl import incremental_train
from owr.baseline.icarl import update_representation
from owr.baseline.icarl import construct_exemplar_set
from owr.baseline.icarl import reduce_exemplars
from owr.baseline.icarl import generate_new_exemplars
from owr.baseline.models import compute_loss
from owr.baseline.models import train_network
from owr.baseline import params
from owr.baseline import utils
from owr.baseline.utils import get_classes_names
from owr.baseline.utils import get_task_indexes
from owr.baseline.utils import splitter
from owr.baseline.utils import map_splits
from owr.baseline.utils import get_classes
from owr.baseline.utils import get_indexes
import owr.baseline.models

In [None]:
from torchvision import models
import torch.nn as nn
import torch
import torch.optim as optim
import torchvision
import numpy as np
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.nn import functional as F
import random

In [None]:
# Show the configured random seed and DataLoader worker count, one value per line.
print(params.SEED, params.NUM_WORKERS, sep='\n')

In [None]:
# Show the NUM_EPOCHS value configured in the project's params module.
print(params.NUM_EPOCHS)

# Define Datasets

In [None]:
# Training pipeline: random 32x32 crop (after 4 px padding) and random horizontal
# flip as augmentation, then tensor conversion and per-channel normalisation.
train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
train_transformer = transforms.Compose(train_steps)

# Test pipeline: no augmentation, same tensor conversion and normalisation.
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
]
test_transformer = transforms.Compose(test_steps)

In [None]:
from torchvision import datasets

# CIFAR-100 train / test splits, downloaded into ./data when not already present.
# Each split is paired with its own transform pipeline.
trainDS = datasets.CIFAR100(
    root='data',
    train=True,
    download=True,
    transform=train_transformer,
)
testDS = datasets.CIFAR100(
    root='data',
    train=False,
    download=True,
    transform=test_transformer,
)

In [None]:
# Class splits produced by the project's `splitter` helper; the incremental loop
# slices this sequence by task index and concatenates the selected entries.
# NOTE(review): presumably one group of class labels per task — confirm in owr.baseline.utils.
splits = splitter()

# Define Network

In [None]:
# Build the project's ResNet-32 with a 100-way output layer and move it to the
# device selected in params.
ICaRL = ResNet.resnet32(num_classes=100).to(params.DEVICE)

In [None]:
# Classification rule used at evaluation time; 'nme' makes the loop remap labels
# through map_splits before comparing them with the predictions.
classifier = 'nme'

# One exemplar slot per class, all empty until incremental_train returns updated ones.
exemplars = [None for _ in range(params.NUM_CLASSES)]

# Test-set indexes accumulated task after task, and per-task accuracy histories.
test_indexes = list()
accs_train, accs_test = [], []

In [None]:
#classifier = KNeighborsClassifier()
#clf_params = {'n_neighbors' : np.arange(3,13,2)}

In [None]:
def _evaluate(loader, model, exemplar_sets, task, dataset, label_order,
              remap_labels, mean=None, refresh_mean=False):
    """Classify every batch of `loader` with `classify` and measure accuracy.

    Args:
        loader: iterable yielding `(images, labels, _)` batches.
        model: network forwarded to `classify`.
        exemplar_sets: exemplar collection forwarded to `classify`.
        task: offset of the current task, forwarded to `classify`.
        dataset: dataset object forwarded to `classify`.
        label_order: array handed to `map_splits` when labels are remapped.
        remap_labels: if True, ground-truth labels go through `map_splits`
            before being compared with the predictions; otherwise they are
            compared as they come out of the loader.
        mean: value forwarded as the last argument of `classify`
            (None on the very first call of a task).
        refresh_mean: if True, the second value returned by `classify`
            replaces `mean` after every batch; if False it is discarded and
            the `mean` passed in is reused unchanged for every batch.

    Returns:
        `(accuracy, mean, batch_preds, batch_labels)` where `accuracy` is the
        fraction of correctly classified samples, `mean` is the last value
        kept according to `refresh_mean`, and the two lists hold one NumPy
        array of predictions / labels per batch.

    Raises:
        ZeroDivisionError: if `loader` yields no samples.
    """
    seen = 0
    hits = 0
    batch_preds = []
    batch_labels = []
    # Evaluation only: gradients are never used here, so autograd bookkeeping
    # is disabled to save memory and time.
    with torch.no_grad():
        for img, lbl, _ in loader:
            img = img.float().to(params.DEVICE)
            preds, returned_mean = classify(img, exemplar_sets, model, task, dataset, mean)
            if refresh_mean:
                mean = returned_mean
            preds = preds.to(params.DEVICE)
            if remap_labels:
                labels = map_splits(lbl, label_order).to(params.DEVICE)
            else:
                labels = lbl.to(params.DEVICE)
            batch_preds.append(preds.detach().cpu().numpy())
            batch_labels.append(labels.detach().cpu().numpy())
            seen += len(lbl)
            hits += (preds == labels).sum().item()
    # Alternative (not used): classify with the network's fc layer via
    # `torch.max(model(img), 1)` instead of `classify`.
    return hits / seen, mean, batch_preds, batch_labels


for task in range(0, params.NUM_TASKS * params.TASK_SIZE, params.TASK_SIZE):
    # Training uses only the current task's classes; testing accumulates the
    # classes of every task seen so far.
    train_indexes = get_task_indexes(trainDS, task)
    test_indexes = test_indexes + get_task_indexes(testDS, task)

    train_dataset = Subset(trainDS, train_indexes, transform=train_transformer)
    test_dataset = Subset(testDS, test_indexes, transform=test_transformer)

    train_loader = DataLoader(train_dataset, num_workers=2, batch_size=params.BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, num_workers=2, batch_size=params.BATCH_SIZE, shuffle=True)
    print(task)
    ICaRL, exemplars = incremental_train(trainDS, ICaRL, exemplars, task, train_transformer)

    # Flattened, integer-typed concatenation of the splits of all tasks seen so
    # far. The task index is derived from params.TASK_SIZE (the loop step)
    # rather than a hard-coded 10, so it stays correct for any task size.
    seen_splits = splits[: task // params.TASK_SIZE + 1]
    col = np.concatenate([np.asarray(split) for split in seen_splits], axis=None).astype(int)

    remap_labels = classifier == 'nme'

    # Train pass: starts with mean=None and keeps whatever `classify` returns
    # so that later batches (and the test pass) can reuse it.
    accuracy, mean, _, _ = _evaluate(
        train_loader, ICaRL, exemplars, task, trainDS, col,
        remap_labels, mean=None, refresh_mean=True,
    )
    print(f'task: {task}', f'train accuracy = {accuracy}')
    accs_train.append(accuracy)

    # Test pass: reuses the mean obtained on the train pass and also collects
    # every prediction / label for later analysis.
    accuracy, _, batch_preds, batch_labels = _evaluate(
        test_loader, ICaRL, exemplars, task, trainDS, col,
        remap_labels, mean=mean, refresh_mean=False,
    )
    # The leading empty list keeps the float dtype the arrays had when they
    # were grown batch by batch from `[]`.
    tot_preds = np.concatenate([[]] + batch_preds)
    tot_lab = np.concatenate([[]] + batch_labels)
    print(f'task: {task}', f'test accuracy = {accuracy}')
    accs_test.append(accuracy)