In [14]:
import torch
import torchvision.transforms as TF
from cl_lite.head.dynamic_simple import DynamicSimpleHead
import cl_lite.backbone as B
import os
import numpy as np
import sys
sys.path.append('..')
from rdfcil.datamodule import DataModule
from cl_lite.backbone.resnet_cifar import CifarResNet

class ISCF_ResNet(CifarResNet):
    def __init__(self, n=5, nf=16, channels=3, preact=False, zero_residual=True, pooling_config=..., downsampling="stride", final_layer=False, all_attentions=False, last_relu=False, **kwargs):
        super().__init__(n, nf, channels, preact, zero_residual, pooling_config, downsampling, final_layer, all_attentions, last_relu, **kwargs)        
    def forward_feat(self, x):
        x = self.conv_1_3x3(x)
        x = F.relu(self.bn_1(x), inplace=True)

        feats_s1, x1 = self.stage_1(x)
        feats_s2, x2 = self.stage_2(x1)
        feats_s3, x3 = self.stage_3(x2)
        x4 = self.stage_4(x3)

        return x4,[x1, x2, x3]

dataset = "cifar100"
num_classes = 100
data_root = '/data'
class_order= [23, 8, 11, 7, 48, 13, 1, 91, 94, 54, 16, 63, 52, 41, 80, 2, 47, 87, 78, 66, 19, 6, 24, 10, 59, 30, 22, 29, 83, 37, 93, 81, 43, 99, 86, 28, 34, 88, 44, 14, 84, 70, 4, 20, 15, 21, 31, 76, 57, 67, 73, 50, 69, 25, 98, 46, 96, 0, 72, 35, 58, 92, 3, 95, 56, 90, 26, 40, 55, 89, 75, 71, 60, 42, 9, 82, 39, 18, 77, 68, 32, 79, 12, 85, 36, 17, 64, 27, 74, 45, 61, 38, 51, 62, 65, 33, 5, 53, 97, 49]
# class_order = [53, 37, 65, 51, 4, 20, 38, 9, 10, 81, 44, 36, 84, 50, 96, 90, 66, 16, 80, 33, 24, 52, 91, 99, 64, 5, 58, 76, 39, 79, 23, 94, 30, 73, 25, 47, 31, 45, 19, 87, 42, 68, 95, 21, 7, 67, 46, 82, 11, 6, 41, 86, 88, 70, 18, 78, 71, 59, 43, 61, 22, 14, 35, 93, 56, 28, 98, 54, 27, 89, 1, 69, 74, 2, 85, 40, 13, 75, 29, 34, 92, 0, 77, 55, 49, 3, 62, 12, 26, 48, 83, 60, 57, 63, 15, 32, 8, 97, 72, 17]
# class_order = [0, 76, 61, 63, 1, 71, 2, 6, 16, 19, 13, 24, 49, 12, 75, 9, 83, 72, 5, 41, 99, 45, 89, 53, 79, 18, 52, 92, 14, 42, 68, 44, 38, 84, 36, 17, 31, 15, 70, 88, 25, 97, 51, 73, 66, 37, 78, 33, 80, 26, 82, 28, 60, 35, 43, 57, 23, 58, 91, 8, 62, 93, 98, 86, 29, 30, 22, 95, 67, 54, 48, 40, 59, 96, 3, 87, 34, 64, 56, 69, 47, 65, 50, 81, 55, 20, 74, 4, 90, 27, 77, 32, 39, 85, 94, 21, 46, 10, 11, 7]

num_tasks = 20

# Convert class_order to a tensor for faster indexing
class_order_tensor = torch.tensor(class_order).cuda()

# Create a tensor of zeros with the same length as class_order
# This tensor will be used to create a mapping where the index is the class order position
mapping_tensor = torch.zeros(len(class_order), dtype=torch.long).cuda()

# Assign the new class indices (which are just the indices of class_order_tensor) to the corresponding positions in mapping_tensor
mapping_tensor[class_order_tensor] = torch.arange(len(class_order_tensor)).cuda()

# Use the mapping tensor to map the labels
# fast_mapped_labels = mapping_tensor[labels]

In [18]:
# get forgetting results
total_task_acc=[]
for t in range(1,num_tasks+1):
    print(f"Task {t}")
    
    # get the model
    if dataset.startswith("imagenet"):
        backbone = B.resnet.resnet18()
    else:
        backbone = ISCF_ResNet()

    # prefix = './ImageNet-100/imnet100_version_675_rdfcil_5task_49.44/task_{}'.format(t-1)
    # prefix = './ImageNet-100/version_508_imnet_5task_54.64/task_{}'.format(t-1)

    #10t
    # prefix = '/home/minsoo/workspace/AlwaysBeDreaming-DFCIL/rdfcil/lightning_logs/ijcv/version_816_cifar_10t_43.57/task_{}'.format(t-1)
    # prefix = '/home/minsoo/workspace/AlwaysBeDreaming-DFCIL/rdfcil/lightning_logs/ijcv/version_808_cifar_10t_43.29/task_{}'.format(t-1)
    # prefix = '/home/minsoo/workspace/AlwaysBeDreaming-DFCIL/rdfcil/lightning_logs/ijcv/version_809_cifar_10t_43.26/task_{}'.format(t-1)
    # 5t
    # prefix = '/home/minsoo/workspace/AlwaysBeDreaming-DFCIL/rdfcil/lightning_logs/version_734_cifar100_5t_51.09/task_{}'.format(t-1)
    # prefix = '/home/minsoo/workspace/AlwaysBeDreaming-DFCIL/rdfcil/lightning_logs/ijcv/version_706_cifar_5t_50.22/task_{}'.format(t-1)
    # prefix = '/home/minsoo/workspace/AlwaysBeDreaming-DFCIL/rdfcil/lightning_logs/ijcv/version_726_cifar_5t_51.08/task_{}'.format(t-1)
    # 20t
    prefix = '/home/minsoo/workspace/AlwaysBeDreaming-DFCIL/rdfcil/lightning_logs/ijcv/version_705_cifar_20t_32.54/task_{}'.format(t-1)
    # prefix = '/home/minsoo/workspace/AlwaysBeDreaming-DFCIL/rdfcil/lightning_logs/ijcv/version_708_cifar_20t_32.59/task_{}'.format(t-1)
    # prefix = '/home/minsoo/workspace/AlwaysBeDreaming-DFCIL/rdfcil/lightning_logs/ijcv/version_705_cifar_20t_32.54/task_{}'.format(t-1)
    state_dict = torch.load(os.path.join(prefix,"checkpoints/best_acc.ckpt"))['state_dict']

    # dataload
    data_module = DataModule(root=data_root, 
                            dataset=dataset, 
                            batch_size=128, 
                            num_workers=4,
                            num_tasks=num_tasks,
                            class_order=class_order,
                            current_task=t-1,
                            )
    data_module.setup()
    # head = DynamicSimpleHead(num_classes=data_module.num_classes, num_features=backbone.num_features, bias=False)
    head = DynamicSimpleHead(num_classes=data_module.num_classes, num_features=backbone.num_features, bias=True)
    
    backbone_state= {}
    head_state = {}
    for _ in range(t-1):
        head.append(num_classes//num_tasks)
    for k,v in state_dict.items():
        if k.startswith('backbone'):
            backbone_state[k[9:]] = v
        elif k.startswith('head'):
            head_state[k[5:]] = v
            # head_state[k[17:]] = v
            
        
    backbone.load_state_dict(backbone_state)
    backbone.eval()
    head.load_state_dict(head_state)


    # train_dataloader = data_module.train_dataloader()
    val_dataloader = data_module.val_dataloader()

    backbone.cuda()
    head.cuda()

    task_correct= [0 for _ in range(t)]
    task_total = [0 for _ in range(t)]
    idx=0
    for batch in val_dataloader:
        images, labels = batch
        images = images.cuda()
        labels = labels.cuda()
        with torch.no_grad():
            output = backbone(images)
            output = head(output)
            # print(output[0],labels[0])
            labels = mapping_tensor[labels]
            for i in range(t):
                t_indices = torch.nonzero(torch.bitwise_and(num_classes//num_tasks*(i+1) >= labels, num_classes//num_tasks*(i) < labels) ).view(-1)
                # print(t_indices.view(-1))
                # task accuracy
                labels_t = labels[t_indices] # - i*num_classes//num_tasks
                output_t = output[t_indices]
                # if i==0:
                #     output_t = output_t[:,:num_classes//num_tasks*(i+1)]
                # else: output_t = output_t[:,num_classes//num_tasks*i:num_classes//num_tasks*(i+1)]
                task_correct[i] += (output_t.argmax(dim=1) == labels_t).sum().item()
                task_total[i] += len(labels_t)
        idx+=1
        print('\r idx: {}'.format(idx), end='')
    print()
    task_acc = [float(cc)/ct for cc,ct in zip(task_correct,task_total)]
    print(task_acc)

    for j in range(num_tasks-t):
        task_acc.append(0)
    total_task_acc.append(task_acc)
total_task_acc = np.array(total_task_acc)
print(total_task_acc)
result = []
for i in range(num_tasks):
    if i == 0:
        result.append(0)
    else:
        res = 0
        for j in range(i + 1):
            res += (np.max(total_task_acc[:, j]) - total_task_acc[i][j])
        res = res / i
        result.append(100 * res)

        
print('Forgetting result:')
print(result)
print(sum(result)/len(result))

Task 1
Files already downloaded and verified
Files already downloaded and verified
[23, 8, 11, 7, 48, 13, 1, 91, 94, 54, 16, 63, 52, 41, 80, 2, 47, 87, 78, 66, 19, 6, 24, 10, 59, 30, 22, 29, 83, 37, 93, 81, 43, 99, 86, 28, 34, 88, 44, 14, 84, 70, 4, 20, 15, 21, 31, 76, 57, 67, 73, 50, 69, 25, 98, 46, 96, 0, 72, 35, 58, 92, 3, 95, 56, 90, 26, 40, 55, 89, 75, 71, 60, 42, 9, 82, 39, 18, 77, 68, 32, 79, 12, 85, 36, 17, 64, 27, 74, 45, 61, 38, 51, 62, 65, 33, 5, 53, 97, 49]
 idx: 4
[0.845]
Task 2
Files already downloaded and verified
Files already downloaded and verified
[23, 8, 11, 7, 48, 13, 1, 91, 94, 54, 16, 63, 52, 41, 80, 2, 47, 87, 78, 66, 19, 6, 24, 10, 59, 30, 22, 29, 83, 37, 93, 81, 43, 99, 86, 28, 34, 88, 44, 14, 84, 70, 4, 20, 15, 21, 31, 76, 57, 67, 73, 50, 69, 25, 98, 46, 96, 0, 72, 35, 58, 92, 3, 95, 56, 90, 26, 40, 55, 89, 75, 71, 60, 42, 9, 82, 39, 18, 77, 68, 32, 79, 12, 85, 36, 17, 64, 27, 74, 45, 61, 38, 51, 62, 65, 33, 5, 53, 97, 49]
 idx: 8
[0.794, 0.7025]
Task 3
Files

In [5]:

forgetting = np.max(total_task_acc,axis=0) - np.min(total_task_acc,axis=0)
print(forgetting)
avg_forgetting = np.mean(forgetting)
print(avg_forgetting)

[0.03373684 0.01894737 0.00263158 0.077      0.00578947]
0.027621052631578952
