<a href="https://colab.research.google.com/github/gianluigilopardo/Open-World-Recognition/blob/main/baseline/mainLWF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import logging
import sys

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import pandas as pd
import torch.nn as nn
from sklearn.metrics import confusion_matrix
import seaborn as sn


# Fetch the project code (the `owr` package) on first run.
# Clone straight into ./owr (no separate rename step) and use check=True so a
# failed clone raises immediately here; IPython `!` shell commands do not raise
# on failure, which would only surface later as an ImportError on `owr.baseline`.
if not os.path.isdir('./owr'):
    import subprocess
    subprocess.run(
        ['git', 'clone',
         'https://github.com/gianluigilopardo/Open-World-Recognition.git',
         'owr'],
        check=True,
    )

from owr.baseline import ResNet
from owr.baseline import params
from owr.baseline import utils
from owr.baseline import models
from owr.baseline.dataset import *
from owr.baseline import icarl

# When running locally (e.g. in PyCharm) instead of Colab, use these imports:

# import ResNet
# import params
# import utils
# import dataset

# Here we address catastrophic forgetting by adding a distillation loss during training, which aims to preserve knowledge of the old classes, as implemented in the iCaRL paper.

# Preprocessing.
# Both splits are converted to tensors and normalized with mean = std = 0.5 on
# every channel; the training split additionally gets random-crop (32px, 4px
# padding) and horizontal-flip augmentation.
_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

train_transformer = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])
test_transformer = transforms.Compose([
    transforms.ToTensor(),
    _normalize,
])

# Dataset: CIFAR-100 stored under ./data (downloaded if missing).
cifar = datasets.cifar.CIFAR100
train_dataset = cifar(root='data', train=True, download=True,
                      transform=train_transformer)
test_dataset = cifar(root='data', train=False, download=True,
                     transform=test_transformer)

# Class splits: the groups of classes learned at each incremental task
# (printed once for reference).
splits = utils.splitter()
print(f'splits: {splits}')

# Model. Simplification: the network is built with an output for every class
# (params.NUM_CLASSES) up front, then moved to params.DEVICE.
model = ResNet.resnet32(num_classes=params.NUM_CLASSES).to(params.DEVICE)

# SGD with momentum and weight decay; MultiStepLR multiplies the learning rate
# by params.GAMMA at each milestone listed in params.STEP_SIZE.
# NOTE(review): a previous comment said the milestones are epochs 49 and 63 --
# confirm against params.
# NOTE(review): optimizer and scheduler are not referenced again in this
# script (models.trainLWF is not given them) -- confirm they are needed here.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=params.LR,
    momentum=params.MOMENTUM,
    weight_decay=params.WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=params.STEP_SIZE, gamma=params.GAMMA
)


# Run
# Test indexes accumulate across tasks, so after each task the model is
# evaluated on every class seen so far.
test_indexes = []
# One slot per task, pre-sized so each task's result can be stored by index.
metrics = [None] * params.NUM_TASKS

exemplars = [None] * params.NUM_CLASSES  # not used in LwF
accs = []  # NOTE(review): never filled in this script

# Incremental training: `task` is the starting class offset of each group of
# params.TASK_SIZE classes; `task_idx` is the 0-based task number.
for task_idx, task in enumerate(range(0, params.NUM_CLASSES, params.TASK_SIZE)):
    # Training images of the classes belonging to the current task only.
    train_indexes = utils.get_task_indexes(train_dataset, task)
    # Test images: accumulate, because we test on all classes seen so far.
    test_indexes = test_indexes + utils.get_task_indexes(test_dataset, task)

    # Incremental training and testing subsets (project Subset, which accepts
    # a transform argument).
    train_subset = Subset(train_dataset, train_indexes, transform=train_transformer)
    test_subset = Subset(test_dataset, test_indexes, transform=test_transformer)

    train_loader = DataLoader(train_subset, num_workers=params.NUM_WORKERS,
                              batch_size=params.BATCH_SIZE, shuffle=True)
    # Evaluation does not need shuffling; a fixed order keeps test runs
    # reproducible.
    test_loader = DataLoader(test_subset, num_workers=params.NUM_WORKERS,
                             batch_size=params.BATCH_SIZE, shuffle=False)

    if task == 0:
        # Save the freshly built network (before any training) as
        # resNet_task0.pt.
        # NOTE(review): presumably models.trainLWF loads this file -- confirm.
        torch.save(model, 'resNet_task{0}.pt'.format(task))

    models.trainLWF(task, train_loader, splits)
    loss_accuracy = models.testLWF(task, test_loader, splits)
    # Index by task number rather than a hard-coded class count (the former
    # int(task / 10) was only correct when params.TASK_SIZE == 10).
    # NOTE(review): element order of loss_accuracy (loss vs accuracy) is set by
    # models.testLWF -- confirm.
    metrics[task_idx] = loss_accuracy

Cloning into 'Open-World-Recognition'...
remote: Enumerating objects: 507, done.[K
remote: Counting objects: 100% (250/250), done.[K
remote: Compressing objects: 100% (134/134), done.[K
remote: Total 507 (delta 153), reused 145 (delta 107), pack-reused 257[K
Receiving objects: 100% (507/507), 2.09 MiB | 29.70 MiB/s, done.
Resolving deltas: 100% (296/296), done.
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to data/cifar-100-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=169001437.0), HTML(value='')))


Extracting data/cifar-100-python.tar.gz to data
Files already downloaded and verified
splits: [[81, 14, 3, 94, 35, 31, 28, 17, 13, 86], [90, 18, 4, 42, 38, 34, 21, 16, 96, 76], [22, 5, 49, 45, 41, 25, 20, 85, 15, 68], [27, 6, 57, 53, 50, 32, 26, 65, 70, 82], [72, 11, 1, 80, 39, 36, 33, 12, 95, 10], [84, 24, 2, 51, 47, 46, 29, 23, 74, 19], [43, 7, 61, 59, 58, 44, 40, 37, 77, 98], [79, 30, 0, 88, 56, 55, 89, 48, 97, 73], [54, 8, 66, 64, 91, 52, 71, 9, 69, 92], [67, 99, 83, 63, 60, 87, 62, 75, 78, 93]]
task = 0 
train col =  [81 14  3 94 35 31 28 17 13 86]
train col =  [[81 14  3 94 35 31 28 17 13 86]]




At step  0  and at epoch =  0  the loss is =  0.03742843493819237  and accuracy is =  0.0998
