In [None]:
import torchvision
import torch
import os
import random
import copy
from torch.backends import cudnn

In [None]:
# Training-time augmentation: random 32x32 crop taken after padding each side
# by 4 pixels, a horizontal flip applied with probability 0.5, then conversion
# to a tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(size=32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.ToTensor(),
])

In [None]:
# CIFAR-100 splits (downloaded into ./ICARL/data when missing).
# The training split goes through the augmentation pipeline `transform`;
# the test split is only converted to tensors.
data_train = torchvision.datasets.CIFAR100(
    root='./ICARL/data',
    train=True,
    download=True,
    transform=transform,
)
data_test = torchvision.datasets.CIFAR100(
    root='./ICARL/data',
    train=False,
    download=True,
    transform=torchvision.transforms.ToTensor(),
)

In [None]:
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

"""
Credits to @hshustc
Taken from https://github.com/hshustc/CVPR19_Incremental_Learning/tree/master/cifar100-class-incremental
"""


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The block output is ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)``
    where ``shortcut`` is ``x`` itself, or ``downsample(x)`` when a
    ``downsample`` module is supplied.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        """
        Args:
            inplanes: number of channels of the input.
            planes: number of channels produced by both convolutions.
            stride: stride of the first convolution.
            downsample: optional module applied to the input to build the
                shortcut branch; ``None`` means the input is used as-is.
        """
        super().__init__()
        # Sub-modules are registered in the same order as before so that
        # parameter ordering (and therefore state_dict layout) is unchanged.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the residual block to ``x`` and return the activated sum."""
        # Main branch: conv -> bn -> relu -> conv -> bn.
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        # Shortcut branch: identity unless a projection module was given.
        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)


class ResNet(nn.Module):
    """CIFAR-style ResNet whose classifier is a growing list of linear heads.

    The backbone is a 3x3 stem (3 -> 16 channels) followed by three stages of
    16, 32 and 64 planes (the last two with stride 2) and an 8x8 average pool.
    The classifier ``self.fc`` is an ``nn.ModuleList`` of ``nn.Linear`` heads;
    ``forward`` concatenates the logits of every head along dimension 1, and
    ``updatemodel`` appends one more head.

    The heads take exactly ``64 * block.expansion`` features, so the pooled
    feature map must be 1x1 (i.e. 32x32 inputs with the strides used here).
    """

    def __init__(self, block, layers, num_classes=10):
        """
        Args:
            block: residual block class; must expose an ``expansion``
                attribute and accept ``(inplanes, planes, stride, downsample)``.
            layers: sequence with the number of blocks in each of the 3 stages.
            num_classes: number of outputs of the first classification head
                (default 10, the previous hard-coded value).
        """
        super(ResNet, self).__init__()
        # Running count of channels entering the next block; updated by
        # ``_make_layer`` while the stages are built.
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AvgPool2d(8, stride=1)
        # Width of the flattened feature vector fed to every head.
        self.feature_dim = 64 * block.expansion
        # `num_classes` used to be silently ignored (the head was always 10-way).
        self.fc = nn.ModuleList([nn.Linear(self.feature_dim, num_classes)])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage made of ``blocks`` instances of ``block``.

        Only the first block receives ``stride``; it also gets a 1x1-conv +
        batch-norm projection on its shortcut whenever the stride or the
        channel count changes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def updatemodel(self, num_new_classes=None):
        """Append a new linear head for the next group of classes.

        Args:
            num_new_classes: number of outputs of the new head. Defaults to
                the width of the first head (10 with the default constructor).

        The head is created on the device of the existing heads instead of
        unconditionally calling ``.cuda()``, so the model also works on CPU.
        """
        first_head = self.fc[0]
        if num_new_classes is None:
            num_new_classes = first_head.out_features
        new_head = nn.Linear(first_head.in_features, num_new_classes)
        self.fc.append(new_head.to(first_head.weight.device))

    def forward(self, x):
        """Return the logits of all heads concatenated along dimension 1."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        features = x.view(x.size(0), -1)
        # Every head sees the same features; heads are ordered by creation.
        return torch.cat([head(features) for head in self.fc], 1)

def resnet32(pretrained=False, **kwargs):
    """Build a ``ResNet`` of ``BasicBlock``s with five blocks in each of its three stages.

    Args:
        pretrained: accepted but not used by this function.
        **kwargs: forwarded unchanged to the ``ResNet`` constructor.
    """
    blocks_per_stage = 5
    return ResNet(BasicBlock, [blocks_per_stage] * 3, **kwargs)

In [None]:
def obtain_list_of_subset(seed, dataset, num_classes=100, classes_per_group=10):
    """Split the sample indices of ``dataset`` into groups of classes.

    The class ids ``0 .. num_classes-1`` are shuffled with ``seed`` and cut
    into consecutive chunks of ``classes_per_group`` classes. Every sample
    index is then placed in the list of the chunk its target belongs to.

    Args:
        seed: seed of the private ``random.Random`` used for the class shuffle
            (the same seed always yields the same class groups).
        dataset: object exposing ``targets``, an iterable of integer class ids
            aligned with the sample indices.
        num_classes: total number of classes (default 100).
        classes_per_group: number of classes per group (default 10).

    Returns:
        A list with one list of sample indices per class group, in group
        order; indices inside a group keep the dataset order.

    Raises:
        ValueError: if ``classes_per_group`` is not positive.
        IndexError: if a target is outside ``0 .. num_classes-1``.
    """
    if classes_per_group <= 0:
        raise ValueError("classes_per_group must be a positive integer")

    class_order = list(range(num_classes))
    random.Random(seed).shuffle(class_order)

    # One lookup table instead of scanning every group for every sample.
    group_of_class = {cls: pos // classes_per_group
                      for pos, cls in enumerate(class_order)}
    num_groups = (num_classes + classes_per_group - 1) // classes_per_group
    groups = [[] for _ in range(num_groups)]

    for sample_idx, target in enumerate(dataset.targets):
        group = group_of_class.get(int(target))
        if group is None:
            raise IndexError(
                "target %r of sample %d is not in range(%d)"
                % (target, sample_idx, num_classes))
        groups[group].append(sample_idx)
    return groups

In [None]:
### This function creates a mapping between the real labels of the classes and
### their remapped values, based on the batch in which they are introduced.
def new_label2(label, index, target_t_old = None):
    """Remap raw class ids to consecutive ids starting at ``index``.

    Args:
        label: raw class ids, as a list or any object with ``tolist()``
            (e.g. a tensor).
        index: offset added to every remapped id.
        target_t_old: optional sequence of raw class ids fixing the mapping;
            a class at position ``p`` is remapped to ``p + index``. When
            ``None`` the mapping is built from the distinct values of
            ``label`` in ascending order.

    Returns:
        ``(mapping, new_labels)`` where ``mapping`` is ``target_t_old`` itself
        when it was given (otherwise the newly built tuple) and ``new_labels``
        is a list of ints aligned with ``label``.

    Raises:
        ValueError: if a label is not present in ``target_t_old``.
    """
    # Normalise in both branches so tensors are never compared element-wise.
    if not isinstance(label, list):
        label = label.tolist() if hasattr(label, "tolist") else list(label)

    if target_t_old is None:
        # sorted(): a bare set's iteration order depends on insertion order and
        # hash collisions, so the same classes could get different ids.
        target_t = tuple(sorted(set(label)))
    else:
        target_t = target_t_old

    return target_t, [(target_t.index(el) + index) for el in label]

In [None]:
## This function creates the one-hot target tensor used to compute the loss.
def crea_target_finetuning(BatchDati,Target,C):
    """Build the one-hot target matrix for a batch.

    Args:
        BatchDati: batch tensor; only its first dimension ``N`` is used.
        Target: sequence (list or tensor) with at least ``N`` integer class
            ids; entries beyond the first ``N`` are ignored.
        C: number of columns (classes) of the result.

    Returns:
        A float tensor of shape ``(N, C)``, created by ``torch.zeros(N, C)``,
        with a single 1 per row at column ``Target[i]``.

    Raises:
        IndexError: if ``Target`` has fewer than ``N`` entries or contains a
            column outside the ``C`` available ones.
    """
    N = BatchDati.size(0)
    NewTarget = torch.zeros(N, C)
    if N == 0:
        return NewTarget

    columns = torch.as_tensor(Target, dtype=torch.long, device=NewTarget.device)[:N]
    if columns.numel() < N:
        raise IndexError("Target has %d entries but the batch has %d samples"
                         % (columns.numel(), N))
    # One vectorised assignment instead of a Python loop over the rows.
    NewTarget[torch.arange(N), columns] = 1
    return NewTarget

In [None]:
# Binary cross-entropy on raw logits, averaged over all elements; it is applied
# to the network outputs and the one-hot targets in the training loop.
loss_function = torch.nn.BCEWithLogitsLoss(reduction='mean')

In [None]:
### Returns the accuracy of the model over the test dataset.
def accuracy(test_data, model, label, device=None):
    """Return the fraction of samples of ``test_data`` classified correctly.

    The targets stored in ``test_data`` are ignored: the prediction for the
    i-th sample (arg-max over the model outputs) is compared with
    ``label[i]``.

    Args:
        test_data: dataset yielding ``(input, target)`` pairs; it is read in
            order with a batch size of 128.
        model: the network to evaluate. Its train/eval mode is not changed
            here, so the caller is expected to call ``model.eval()`` first.
        label: sequence of expected class ids aligned with ``test_data``.
        device: device the inputs are moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters), so
            the function no longer depends on a global ``DEVICE``.

    Returns:
        ``correct / total`` as a float, or ``0.0`` for an empty dataset.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    loader = torch.utils.data.DataLoader(test_data, batch_size = 128)
    predictions = []
    # no_grad: inference only, so no autograd graph is kept alive per batch.
    with torch.no_grad():
        for input_data, _ in loader:
            output = model(input_data.to(device))
            _, classes = torch.max(output, 1)
            predictions.extend(classes.tolist())

    if not predictions:
        return 0.0

    correct = sum(1 for pos, predicted in enumerate(predictions)
                  if predicted == label[pos])
    return correct / len(predictions)

In [None]:
# Placeholders so that the training cell can test `model is not None`
# (and free a previous model) even on its first run.
model = model_distillation = None

In [None]:
# Class-incremental fine-tuning baseline: the classes are split into
# NUM_TASKS groups; for every group a new head is trained and the model is then
# evaluated on the test samples of all the groups seen so far.

# Drop the model left over from a previous run of this cell before building a new one.
if model is not None:
    del model
    torch.cuda.empty_cache()

from torch.backends import cudnn

# Fall back to the CPU when no GPU is visible instead of failing on the first `.to`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hyper-parameters (previously literals scattered through the loop).
SPLIT_SEED = 1992          # seed of the class shuffle, shared by train and test
NUM_TASKS = 10
CLASSES_PER_TASK = 10
BATCH_SIZE = 128
LR = 2
EPOCHS = 70
LR_MILESTONES = [49, 63]
LR_GAMMA = 0.2
LOG_EVERY = 39             # log when the batch counter is a non-zero multiple

model = resnet32()
model = model.to(DEVICE)
# The original line was the bare expression `cudnn.benchmark`, which does nothing.
cudnn.benchmark = True

list_of_image = obtain_list_of_subset(SPLIT_SEED, data_train)
list_of_image_test = obtain_list_of_subset(SPLIT_SEED, data_test)
test = None
test_label = []
accuracy_test = []
for i in range(NUM_TASKS):
    first_new_label = i * CLASSES_PER_TASK
    data = torch.utils.data.Subset(data_train, list_of_image[i])

    # Task 0 trains the whole network; later tasks only train the newest head.
    trainable = model.parameters() if i == 0 else model.fc[i].parameters()
    optimizer = torch.optim.SGD(trainable, lr=LR)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=LR_MILESTONES, gamma=LR_GAMMA, last_epoch=-1)

    # shuffle=True: the loader used to replay the same batches in the same
    # order at every epoch.
    dl = torch.utils.data.DataLoader(data, batch_size=BATCH_SIZE, shuffle=True,
                                     num_workers=4)

    # Fix the class -> output-position mapping from all the labels of the task,
    # not from whichever classes happen to appear in the first mini-batch.
    task_targets = [data_train.targets[j] for j in list_of_image[i]]
    target_t_old, _ = new_label2(task_targets, first_new_label)

    model.train()
    for epoch in range(EPOCHS):
        for counter, (images, label) in enumerate(dl):
            images = images.to(DEVICE)
            output = model(images)

            _, new_target = new_label2(label.tolist(), first_new_label, target_t_old)
            target_loss = crea_target_finetuning(
                images, new_target, first_new_label + CLASSES_PER_TASK)
            target_loss = target_loss.to(DEVICE)
            loss_value = loss_function(output, target_loss)

            if counter % LOG_EVERY == 0 and counter != 0:
                print("epoch:", epoch, "loss", loss_value.item(),
                      "lr", scheduler.get_last_lr())

            loss_value.backward()
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()

    print("entro in modalità test")
    model.eval()

    # Grow the evaluation set with the test samples of the current classes.
    subset = torch.utils.data.Subset(data_test, list_of_image_test[i])
    test = subset if test is None else torch.utils.data.ConcatDataset([test, subset])
    # Read the labels from `targets` instead of decoding every image.
    target = [data_test.targets[j] for j in list_of_image_test[i]]
    _, label = new_label2(target, first_new_label, target_t_old)
    test_label = test_label + label
    with torch.no_grad():
        accuracy_test.append(accuracy(test, model, test_label))

    print(target_loss.size())
    print(loss_value.item(), epoch)
    del target_loss
    torch.cuda.empty_cache()
    print("accuracy:", accuracy_test[i])

    # Only add a head when another task follows; otherwise the final model
    # would end with an extra, never-trained head.
    if i < NUM_TASKS - 1:
        model.updatemodel()
        print("model update")