In [None]:
import torchvision
import torch
import os
import random
import numpy as np
from torch.backends import cudnn
import copy

In [None]:
# Augmentation-only pipeline: random 32x32 crop (after 4px padding), random
# horizontal flip, conversion to tensor. No normalization is applied here.
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(size=32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.ToTensor(),
])

# Training pipeline: same augmentation followed by per-channel normalization.
# The mean/std constants are presumably CIFAR-100 channel statistics — TODO confirm.
transform1 = torchvision.transforms.Compose([
    torchvision.transforms.RandomCrop(size=32, padding=4),
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.5071, 0.4865, 0.4409],
        std=[0.2673, 0.2564, 0.2761],
    ),
])

# Evaluation pipeline: no augmentation, only tensor conversion and the same normalization.
transform2 = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.5071, 0.4865, 0.4409],
        std=[0.2673, 0.2564, 0.2761],
    ),
])

In [None]:
# CIFAR-100 training split (augmented + normalized); downloaded on first use.
data_train = torchvision.datasets.CIFAR100(
    root='./ICARL/data',
    train=True,
    download=True,
    transform=transform1,
    target_transform=None,
)
# CIFAR-100 test split (normalized only).
data_test = torchvision.datasets.CIFAR100(
    root='./ICARL/data',
    train=False,
    download=True,
    transform=transform2,
    target_transform=None,
)

In [None]:
#@title
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

"""
Credits to @hshustc
Taken from https://github.com/hshustc/CVPR19_Incremental_Learning/tree/master/cifar100-class-incremental
"""


def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm pairs.

    The input (optionally passed through ``downsample``) is added to the
    output of the second batch-norm before the final ReLU.
    """

    # Output channels = planes * expansion; read by ResNet._make_layer.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Only the first convolution applies the stride.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # Optional module applied to the shortcut so its shape matches the main branch.
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum with the shortcut."""
        main = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(main))

        shortcut = x if self.downsample is None else self.downsample(x)

        main += shortcut
        return self.relu(main)


class ResNet(nn.Module):
    """Three-stage ResNet with a growable list of linear classification heads.

    The backbone is a 3x3 stem (3 -> 16 channels) followed by three stages of
    16, 32 and 64 planes (the last two with stride 2) and an 8x8 average pool.
    ``self.fc`` is an ``nn.ModuleList`` of linear heads; ``forward`` concatenates
    their outputs along dim 1, so every call to ``updatemodel`` widens the
    output. When the flag set via ``set_flag`` is False, ``forward`` returns
    the flattened pooled features instead of logits.

    Assumes 32x32 inputs so that the 8x8 pool leaves a 1x1 map — TODO confirm
    for other input sizes.

    Args:
        block: residual block class exposing ``expansion`` and the constructor
            signature ``(inplanes, planes, stride=1, downsample=None)``.
        layers: number of blocks in each of the three stages.
        num_classes: output width of the first head (default 10).
    """

    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        # True -> forward returns logits; False -> forward returns features.
        self.flag = True
        self.inplanes = 16
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(block, 16, layers[0])
        self.layer2 = self._make_layer(block, 32, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 64, layers[2], stride=2)
        self.avgpool = nn.AvgPool2d(8, stride=1)
        # One linear head per incremental step; the first one honours num_classes.
        self.fc = nn.ModuleList([nn.Linear(64 * block.expansion, num_classes)])

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks with ``planes`` planes."""
        downsample = None
        out_channels = planes * block.expansion
        # A 1x1 conv + BN on the shortcut is needed whenever the shape changes.
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_channels
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def updatemodel(self, num_new_classes=10):
        """Append a new linear head with ``num_new_classes`` outputs.

        The head is created on the same device and dtype as the first head, so
        the method works for CPU and GPU models alike.
        """
        reference = self.fc[0].weight
        new_head = nn.Linear(self.fc[0].in_features, num_new_classes)
        self.fc.append(new_head.to(device=reference.device, dtype=reference.dtype))

    def set_flag(self, b):
        """Select the forward output: True for logits, False for pooled features."""
        self.flag = b

    def forward(self, x):
        """Return concatenated head logits, or flattened features if the flag is False."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        if not self.flag:
            return x
        # Every head sees the same features; outputs are stacked in head order.
        return torch.cat([head(x) for head in self.fc], 1)

def resnet32(pretrained=False, **kwargs):
    """Build a ResNet of BasicBlocks with five blocks in each of the three stages.

    ``pretrained`` is accepted but not used; extra keyword arguments are
    forwarded to the ``ResNet`` constructor.
    """
    blocks_per_stage = 5
    return ResNet(BasicBlock, [blocks_per_stage] * 3, **kwargs)

In [None]:
def obtain_list_of_subset(seed, dataset, num_classes=100, group_size=10):
    """Split sample indices into class groups drawn from a seeded class shuffle.

    The class ids ``0..num_classes-1`` are shuffled with ``random.Random(seed)``
    and cut into consecutive groups of ``group_size`` classes. Every sample of
    ``dataset`` is assigned to the group that contains its target.

    Args:
        seed: seed for the class shuffle.
        dataset: object exposing a ``targets`` sequence of class ids.
        num_classes: number of class ids to shuffle (default 100).
        group_size: classes per group (default 10).

    Returns:
        A list with one list of sample indices per group, in group order.

    Raises:
        ValueError: if a target is outside ``0..num_classes-1``.
    """
    class_order = list(range(0, num_classes))
    random.Random(seed).shuffle(class_order)
    groups = [class_order[start:start + group_size]
              for start in range(0, num_classes, group_size)]

    # Direct class -> group lookup instead of scanning every group per sample.
    group_of_class = {cls: g for g, members in enumerate(groups) for cls in members}

    indices_per_group = [[] for _ in groups]
    for position, target in enumerate(dataset.targets):
        cls = int(target)
        if cls not in group_of_class:
            raise ValueError("target %r at index %d is not a known class" % (target, position))
        indices_per_group[group_of_class[cls]].append(position)
    return indices_per_group

In [None]:
def accuracy(test_data, model, label):
    """Return the fraction of samples whose arg-max prediction equals ``label``.

    Args:
        test_data: dataset yielding ``(input, target)`` pairs; the targets it
            yields are ignored.
        model: callable mapping an input batch to a score tensor; the
            predicted class is the arg-max over dim 1.
        label: sequence of expected classes, indexed in dataset order.

    Returns:
        Correct predictions divided by the number of samples, or 0.0 when the
        dataset is empty.
    """
    predictions = []
    loader = torch.utils.data.DataLoader(test_data, batch_size=128)
    # Pure inference: no autograd graph is needed, which also keeps memory low.
    with torch.no_grad():
        for input_data, _ in loader:
            output = model(input_data.to(DEVICE))
            _, classes = torch.max(output, 1)
            predictions.extend(classes.tolist())

    if not predictions:
        return 0.0

    correct = sum(1 for i, predicted in enumerate(predictions) if predicted == label[i])
    return correct / len(predictions)

In [None]:
def new_label_exemplar(label, target_t_prec, target_t_old, index):
    """Remap original class ids of a batch to positions in the model output.

    A class of the current step maps to ``position_in(target_t_old) + index*10``;
    a class of a previous step ``j`` maps to ``position_in(target_t_prec[j]) + j*10``.
    Current-step classes take precedence over previous ones.

    Args:
        label: batch labels (list, or an object with ``tolist()`` such as a tensor).
        target_t_prec: sequence of class tuples, one per previous step.
        target_t_old: class tuple of the current step, or None to infer it as
            the labels of this batch that belong to no previous step.
        index: number of the current step.

    Returns:
        ``(target_t_old, new_target)`` where ``new_target`` has one remapped
        id per input label, in order.

    Raises:
        ValueError: if a label belongs to neither the current nor a previous
            step (it would otherwise be dropped and misalign the targets).
    """
    labels = label.tolist() if hasattr(label, "tolist") else list(label)

    if target_t_old is None:
        seen_before = {cls for group in target_t_prec for cls in group}
        target_t_old = tuple(set(labels) - seen_before)

    # class id -> output position; setdefault keeps the first occurrence,
    # and the current step is inserted first so it wins over previous steps.
    position_of = {}
    for pos, cls in enumerate(target_t_old):
        position_of.setdefault(cls, pos + index * 10)
    for step, group in enumerate(target_t_prec):
        for pos, cls in enumerate(group):
            position_of.setdefault(cls, pos + step * 10)

    new_target = []
    for cls in labels:
        if cls not in position_of:
            raise ValueError("label %r belongs to no known class group" % (cls,))
        new_target.append(position_of[cls])

    return target_t_old, new_target



In [None]:
def target_for_test(label, target_t_prec, target_t_old, index):
  """Remap original class ids to positions in the model output.

  A class in ``target_t_old`` maps to ``position + index*10``; a class in
  ``target_t_prec[j]`` maps to ``position + j*10``. ``target_t_old`` takes
  precedence.

  Args:
      label: labels to remap (list, or an object with ``tolist()`` such as a tensor).
      target_t_prec: sequence of class tuples, one per previous step.
      target_t_old: class tuple of the current step.
      index: number of the current step.

  Returns:
      A list with one remapped id per input label, in order.

  Raises:
      ValueError: if a label belongs to none of the groups (it would
          otherwise be dropped and shift every following entry).
  """
  # Converting once avoids per-element tensor comparisons (and device syncs
  # when the labels live on the GPU).
  labels = label.tolist() if hasattr(label, "tolist") else list(label)

  # class id -> output position; first occurrence wins, current step first.
  position_of = {}
  for pos, cls in enumerate(target_t_old):
    position_of.setdefault(cls, pos + index * 10)
  for step, group in enumerate(target_t_prec):
    for pos, cls in enumerate(group):
      position_of.setdefault(cls, pos + step * 10)

  new_label = []
  for cls in labels:
    if cls not in position_of:
      raise ValueError("label %r belongs to no known class group" % (cls,))
    new_label.append(position_of[cls])

  return new_label


In [None]:
def crea_target_classification(BatchDati,Target,i):
    """Build one-hot targets restricted to the classes of step ``i``.

    Returns a zero tensor of shape ``(BatchDati.size(0), (i+1)*10)`` where row
    ``j`` has a 1 in column ``Target[j]`` only when that column lies in
    ``[i*10, (i+1)*10)``; rows whose target is outside that range stay zero.
    """
    rows = BatchDati.size(0)
    first_col = i * 10
    width = i * 10 + 10
    one_hot = torch.zeros(rows, width)
    for row in range(rows):
        if first_col <= Target[row] < width:
            one_hot[row, Target[row]] = 1
    return one_hot

In [None]:
def crea_target_distillation(output_distillation, num_new_classes=10):
    """Turn previous-model logits into distillation targets.

    Applies a sigmoid to ``output_distillation`` and appends
    ``num_new_classes`` zero columns on the right.

    The result is detached from the autograd graph: it is used as a constant
    target, so no gradient should flow back into the model that produced the
    logits. The zero padding is created with the dtype and device of the input.

    Args:
        output_distillation: tensor of shape ``(N, C_old)`` with raw logits.
        num_new_classes: number of zero columns to append (default 10).

    Returns:
        Tensor of shape ``(N, C_old + num_new_classes)``.
    """
    soft_targets = torch.sigmoid(output_distillation.detach())
    padding = soft_targets.new_zeros(soft_targets.size(0), num_new_classes)
    return torch.cat((soft_targets, padding), 1)

In [None]:
def crea_label_classification_distillation(label_distillation, label_classification):
    """Combine distillation and classification targets with the ``+`` operator."""
    combined = label_distillation + label_classification
    return combined

In [None]:
def take_index_for_class(data):
    """Group sample positions by class.

    ``data`` is iterated exactly once (each item is read as ``item[1]`` for
    its class), so a dataset with an expensive ``__getitem__`` is traversed a
    single time and one-shot iterables are supported.

    Args:
        data: iterable of items whose element at index 1 is a hashable class id.

    Returns:
        Dict mapping each class id to the list of positions (in iteration
        order) of the items carrying that class.
    """
    grouped = {}
    for position, sample in enumerate(data):
        grouped.setdefault(sample[1], []).append(position)
    # Emit keys in set-iteration order of the distinct classes.
    return {cls: grouped[cls] for cls in tuple(set(list(grouped)))}

In [None]:
def reduceExemplars(exemplars, classes, m):
    """Keep only the first ``m`` exemplars of every class in ``classes``.

    ``exemplars`` is updated in place: each selected entry is rebound to a
    sliced copy. Nothing is returned.
    """
    leading = slice(None, m)
    for class_id in classes:
      kept = exemplars[class_id][leading]
      exemplars[class_id] = kept

In [None]:
def constructExemplar(exemplars,data,classi,classindex,model,m):
    """Select up to ``m`` exemplars per class by greedy herding on model features.

    For each class in ``classi`` the L2-normalized 64-d features of all its
    samples are extracted with the model in feature mode (``set_flag(False)``).
    Exemplars are then picked one at a time: at step ``k`` the sample chosen is
    the one for which the normalized mean of the already chosen features plus
    the candidate is closest (L2) to the normalized class mean.

    The model is switched to eval mode and left there; its flag is restored to
    True even if an error occurs. All tensors are created on the device of the
    model's parameters (CPU if it has none).

    NOTE(review): features come from ``data`` through its own transform; if
    that transform is random, the selection is not deterministic.

    Args:
        exemplars: indexable container; ``exemplars[c]`` is overwritten with
            the chosen sample indices for every class ``c`` in ``classi``.
        data: dataset indexed by the values stored in ``classindex``.
        classi: iterable of class ids to process.
        classindex: mapping class id -> list of sample indices (not modified).
        model: network exposing ``set_flag`` and producing 64-d features.
        m: number of exemplars per class; capped at the class size.

    Returns:
        Tensor of shape ``(len(classi), 64)`` with one normalized class mean
        per processed class, in iteration order.
    """
    reference = next(model.parameters(), None)
    device = reference.device if reference is not None else torch.device("cpu")

    class_means = torch.empty((0, 64), device=device)
    pdist = torch.nn.PairwiseDistance(p=2)
    model.train(False)
    model.set_flag(False)
    try:
        with torch.no_grad():
            print("classi:",classi)
            for c in classi:
                print("classe: ",c)
                # Private copy: entries are popped below as exemplars are chosen.
                indexes = list(classindex[c])
                subset = torch.utils.data.Subset(data, list(indexes))
                dataloader = torch.utils.data.DataLoader(subset, batch_size=128)

                features = torch.empty((0, 64), device=device)
                for image, _ in dataloader:
                    output = model(image.to(device))
                    output = nn.functional.normalize(output, p=2, dim=1)
                    features = torch.cat((features, output))

                class_mean = torch.mean(features, 0)
                class_mean = nn.functional.normalize(class_mean, p=2, dim=0)
                class_mean = class_mean.view(-1, 64)
                class_means = torch.cat((class_means, class_mean))

                current_features = torch.empty((0, 64), device=device)
                exemplars_indexes = []
                # Never ask for more exemplars than the class has samples.
                for k in range(min(m, len(indexes))):
                    # Candidate mean if each remaining sample were added next.
                    current_sum = torch.sum(current_features, 0)
                    current_sum = torch.add(features, current_sum.repeat(features.size(0), 1))
                    current_mean = current_sum * (1.0/(k+1))
                    current_mean = nn.functional.normalize(current_mean, p=2, dim=1)
                    distances = pdist(current_mean, class_mean)
                    # collecting chosen features
                    index = torch.argmin(distances).item()
                    phi = features[index].view(-1, 64)
                    current_features = torch.cat((current_features, phi))
                    # removing chosen features so a sample is picked at most once
                    features = torch.cat((features[:index], features[index+1:]))
                    exemplars_indexes.append(indexes.pop(index))
                exemplars[c] = exemplars_indexes
    finally:
        model.set_flag(True)
    return class_means

In [None]:
# Binary cross-entropy on raw logits, averaged over all elements.
loss_function = torch.nn.BCEWithLogitsLoss(reduction='mean')

In [None]:
def classifierNCM(train_dataset,input_image, exemplars, labels_old, labels_new, class_means, model,indexofit):
  """Classify a batch with a nearest-class-mean rule on normalized features.

  For every previously learned class the mean of the L2-normalized features
  of its exemplars is recomputed from ``train_dataset``; these means are
  stacked with ``class_means`` (the means of the current classes). Each input
  is assigned to the class whose mean is closest in L2 distance, and the
  result is returned as a one-hot row whose column is given by
  ``target_for_test``.

  The model is switched to eval mode and left there; its flag is restored to
  True even if an error occurs. Tensors are created on the device of
  ``input_image``.

  NOTE(review): the old-class means are recomputed on every call; callers
  that invoke this once per batch repeat that work each time.

  Args:
      train_dataset: dataset indexed by the values stored in ``exemplars``.
      input_image: batch of inputs, already on the model's device.
      exemplars: indexable by class id, giving exemplar sample indices.
      labels_old: sequence of class tuples, one per previous step.
      labels_new: class tuple of the current step.
      class_means: tensor with one 64-d mean per class in ``labels_new``.
      model: network exposing ``set_flag`` and producing 64-d features.
      indexofit: number of the current step.

  Returns:
      Tensor of shape ``(batch, (indexofit+1)*10)`` with a single 1 per row.
  """
  pdist = nn.PairwiseDistance(p=2)
  device = input_image.device
  lbls_old = [item for sublist in labels_old for item in sublist]
  model.train(False)
  model.set_flag(False)
  try:
    with torch.no_grad():
      tensor = torch.zeros((input_image.size(0), (indexofit+1)*10), device=device)

      # One normalized mean per previously learned class, from its exemplars.
      exemplars_mean = torch.empty((0, 64), device=device)
      for c in lbls_old:
        subset = torch.utils.data.Subset(train_dataset, list(exemplars[c]))
        dataLoader = torch.utils.data.DataLoader(subset, batch_size=128)
        mean = torch.empty((0, 64), device=device)
        for image, _ in dataLoader:
          output = model(image.to(device))
          # L2 normalization of feature vector
          output = nn.functional.normalize(output, p=2, dim=1)
          mean = torch.cat((mean, output))
        mean = torch.mean(mean, 0)
        mean = nn.functional.normalize(mean, p=2, dim=0)
        mean = mean.view(-1, 64)
        exemplars_mean = torch.cat((exemplars_mean, mean))

      exemplars_mean = torch.cat((exemplars_mean, class_means.to(device)))
      classes = list(lbls_old) + list(labels_new)
      # Output column for each row of exemplars_mean, resolved once per call.
      columns = [target_for_test([cls], labels_old, labels_new, indexofit)[0]
                 for cls in classes]

      output = model(input_image)
      output = nn.functional.normalize(output, p=2, dim=1)
      for n in range(output.size(0)):
        distances = pdist(output[n], exemplars_mean)
        nearest = torch.argmin(distances).item()
        tensor[n][columns[nearest]] = 1
  finally:
    model.set_flag(True)
  return tensor

In [None]:
# Notebook-level state used by the training cell below.
model = None
model_distillation = None
# Prefer the GPU but fall back to the CPU so the notebook can still be imported/run without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
test = None

In [None]:
torch.set_printoptions(threshold=5000)
import random
from torch.backends import cudnn

# Drop models left over from a previous run of this cell before allocating new ones.
model = None
model_distillation = None
torch.cuda.empty_cache()

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = resnet32()
model = model.to(DEVICE)
# The flag must be assigned to take effect; a bare `cudnn.benchmark` expression does nothing.
cudnn.benchmark = True

# Per-step sample indices for train and test; both use the same seed.
list_of_image = obtain_list_of_subset(1994,data_train)
list_of_image_test = obtain_list_of_subset(1994,data_test)
# class id -> indices of its training samples
dc = take_index_for_class(data_train)
test = None            # cumulative test set, extended at every step
test_label = []
accuracy_test = []     # one score per completed step
exemplars = [[] for _ in range(100)]   # exemplar indices, one list per class id
target_prec = []       # class tuples of the steps already learned
M = 2000               # total exemplar budget shared by all seen classes
LR = 2.0
EPOCHS = 70

for i in range(10):
    # ---- training data for step i: new-class samples plus stored exemplars ----
    a = set(([item for sublist in exemplars for item in sublist]))
    somma = len(a)
    print("num exemplari:",somma)
    l = list_of_image[i]+[item for sublist in exemplars for item in sublist]
    random.Random(35).shuffle(l)

    data = torch.utils.data.Subset(data_train,l)

    optimizer = torch.optim.SGD(model.parameters(), lr=LR, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[49,63], gamma=0.2, last_epoch=-1)

    dl = torch.utils.data.DataLoader(data,batch_size=128,shuffle=True)

    # Classes introduced at this step, taken from the whole step rather than
    # from the first mini-batch (which may not contain every new class).
    target_t_old = tuple(set(data_train.targets[idx] for idx in list_of_image[i]))

    if model_distillation is not None:
        model_distillation.eval()

    # ---- training ----
    model.train()
    for epoch in range(EPOCHS):
        model.train()
        counter = 0

        for images, label in dl:
            images = images.to(DEVICE)

            optimizer.zero_grad()
            output = model(images)

            target_t_old, new_target = new_label_exemplar(label, target_prec, target_t_old, i)

            target_loss = crea_target_classification(images,new_target,i)
            target_loss = target_loss.to(DEVICE)

            if model_distillation is not None:
                # The previous model only provides targets: no graph is needed.
                with torch.no_grad():
                    output_distillation = model_distillation(images)

                label_distillation = crea_target_distillation(output_distillation)

                target_loss = crea_label_classification_distillation(label_distillation, target_loss)

            loss_value = loss_function(output,target_loss)
            if counter % 39 == 0 and counter != 0:
                print("epoch:",epoch,"loss",loss_value,"lr",scheduler.get_last_lr())
            counter = counter + 1
            loss_value.backward()
            optimizer.step()
            del output
            del images
            del label

        scheduler.step()

    # ---- evaluation on all classes seen so far ----
    print("entro in modalità test")
    model_distillation = None
    torch.cuda.empty_cache()
    model.eval()

    m = M//(i*10+10)   # exemplars per class under the fixed total budget
    print("m:",m)

    subset = torch.utils.data.Subset(data_test,list_of_image_test[i])
    if test is None:
        test = subset
    else:
        test = torch.utils.data.ConcatDataset([test,subset])
        # Shrink the exemplar sets of earlier classes to the new per-class size.
        reduceExemplars(exemplars,dc.keys(),m)

    class_means = constructExemplar(exemplars,data_train,target_t_old,dc,model,m)
    dltest = torch.utils.data.DataLoader(test,batch_size=128, shuffle=False, num_workers=4)

    running_corrects = 0
    with torch.no_grad():
        for images, labels in dltest:
            images = images.to(DEVICE)
            outputs = classifierNCM(data_train,images, exemplars, target_prec, target_t_old, class_means, model,i)
            _, preds = torch.max(outputs, 1)
            # Labels stay on the CPU as a plain list; only predictions are copied back.
            labels = target_for_test(labels.tolist(), target_prec, target_t_old, i)
            running_corrects += sum(1 for p, t in zip(preds.tolist(), labels) if p == t)

    score = running_corrects / float(len(test))
    accuracy_test.append(score)

    print(target_loss.size())
    print(loss_value, epoch)
    del target_loss
    torch.cuda.empty_cache()
    print("accuracy:",accuracy_test[i])
    # Freeze a copy of the current model as the distillation teacher, then
    # add the head for the next step.
    model_distillation = copy.deepcopy(model)
    target_prec.append(target_t_old)
    model.updatemodel()
    print("model update")