In [1]:
import torch
import torch.nn as nn
import torchvision
from os import path
import numpy as np
import torch.utils.data as data
from torchvision import transforms
from collections import OrderedDict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
batch_size = 128
repeat = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DataLoader

In [3]:
class CacheClassLabel(data.Dataset):
    """
    A dataset wrapper that has a quick access to all labels of data.
    """
    def __init__(self, dataset):
        super(CacheClassLabel, self).__init__()
        self.dataset = dataset
        self.labels = torch.LongTensor(len(dataset)).fill_(-1)
        label_cache_filename = path.join(dataset.root, str(type(dataset))+'_'+str(len(dataset))+'.pth')
        if path.exists(label_cache_filename):
            self.labels = torch.load(label_cache_filename)
        else:
            for i, data in enumerate(dataset):
                self.labels[i] = data[1]
            torch.save(self.labels, label_cache_filename)
        self.number_classes = len(torch.unique(self.labels))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img,target = self.dataset[index]
        return img, target
    
class AppendName(data.Dataset):
    """
    A dataset wrapper that also return the name of the dataset/task
    """
    def __init__(self, dataset, name, first_class_ind=0):
        super(AppendName,self).__init__()
        self.dataset = dataset
        self.name = name
        self.first_class_ind = first_class_ind  # For remapping the class index

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img,target = self.dataset[index]
        target = target + self.first_class_ind
        return img, target, self.name
    
class Subclass(data.Dataset):
    """
    A dataset wrapper that return the task name and remove the offset of labels (Let the labels start from 0)
    """
    def __init__(self, dataset, class_list, remap=True):
        '''
        :param dataset: (CacheClassLabel)
        :param class_list: (list) A list of integers
        :param remap: (bool) Ex: remap class [2,4,6 ...] to [0,1,2 ...]
        '''
        super(Subclass,self).__init__()
        assert isinstance(dataset, CacheClassLabel), 'dataset must be wrapped by CacheClassLabel'
        self.dataset = dataset
        self.class_list = class_list
        self.remap = remap
        self.indices = []
        for c in class_list:
            self.indices.extend((dataset.labels==c).nonzero().flatten().tolist())
        if remap:
            self.class_mapping = {c: i for i, c in enumerate(class_list)}

    def __len__(self):
        return len(self.indices)
    def __getitem__(self, index):
        img,target = self.dataset[self.indices[index]]
        if self.remap:
            raw_target = target.item() if isinstance(target,torch.Tensor) else target
            target = self.class_mapping[raw_target]
        return img, target

In [4]:
def SplitGen(train_dataset, val_dataset, first_split_sz=2, other_split_sz=2, rand_split=False, remap_class=False):
    assert train_dataset.number_classes==val_dataset.number_classes,'Train/Val has different number of classes'
    num_classes =  train_dataset.number_classes

    # Calculate the boundary index of classes for splits
    # Ex: [0,2,4,6,8,10] or [0,50,60,70,80,90,100]
    split_boundaries = [0, first_split_sz]
    while split_boundaries[-1]<num_classes:
        split_boundaries.append(split_boundaries[-1]+other_split_sz)
    print('split_boundaries:',split_boundaries)
    assert split_boundaries[-1]==num_classes,'Invalid split size'

    # Assign classes to each splits
    # Create the dict: {split_name1:[2,6,7], split_name2:[0,3,9], ...}
    if not rand_split:
        class_lists = {str(i):list(range(split_boundaries[i-1],split_boundaries[i])) for i in range(1,len(split_boundaries))}
    else:
        randseq = torch.randperm(num_classes)
        class_lists = {str(i):randseq[list(range(split_boundaries[i-1],split_boundaries[i]))].tolist() for i in range(1,len(split_boundaries))}
    print(class_lists)

    # Generate the dicts of splits
    # Ex: {split_name1:dataset_split1, split_name2:dataset_split2, ...}
    train_dataset_splits = {}
    val_dataset_splits = {}
    task_output_space = {}
    for name,class_list in class_lists.items():
        train_dataset_splits[name] = AppendName(Subclass(train_dataset, class_list, remap_class), name)
        val_dataset_splits[name] = AppendName(Subclass(val_dataset, class_list, remap_class), name)
        task_output_space[name] = len(class_list)

    return train_dataset_splits, val_dataset_splits, task_output_space

In [5]:
def MNIST(dataroot, train_aug=False):
    # Add padding to make 32x32
    #normalize = transforms.Normalize(mean=(0.1307,), std=(0.3081,))  # for 28x28
    normalize = transforms.Normalize(mean=(0.1000,), std=(0.2752,))  # for 32x32

    val_transform = transforms.Compose([
        transforms.Pad(2, fill=0, padding_mode='constant'),
        transforms.ToTensor(),
        normalize,
    ])
    train_transform = val_transform
    if train_aug:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            normalize,
        ])

    train_dataset = torchvision.datasets.MNIST(
        root=dataroot,
        train=True,
        download=True,
        transform=train_transform
    )
    train_dataset = CacheClassLabel(train_dataset)

    val_dataset = torchvision.datasets.MNIST(
        dataroot,
        train=False,
        transform=val_transform
    )
    val_dataset = CacheClassLabel(val_dataset)

    return train_dataset, val_dataset

In [6]:
train_dataset, val_dataset = MNIST('data', False)

In [7]:
train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(train_dataset, val_dataset,
                                                                          first_split_sz=2,
                                                                          other_split_sz=2,
                                                                          rand_split=False,
                                                                          remap_class=False)

split_boundaries: [0, 2, 4, 6, 8, 10]
{'1': [0, 1], '2': [2, 3], '3': [4, 5], '4': [6, 7], '5': [8, 9]}


# Model

In [8]:
class MLP(nn.Module):
    def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256):
        super(MLP, self).__init__()
        self.in_dim = in_channel*img_sz*img_sz
        self.linear = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            #nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            #nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.last = nn.Linear(hidden_dim, out_dim)  # Subject to be replaced dependent on task

    def features(self, x):
        x = self.linear(x.view(-1,self.in_dim))
        return x

    def logits(self, x):
        x = self.last(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x

In [9]:
def MLP400():
    return MLP(hidden_dim=400)

# Train

In [10]:
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = float(self.sum) / self.count

In [11]:
agent = MLP400()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(agent.parameters(), 0.001)

n_feat = agent.last.in_features

In [12]:
task_names = sorted(list(task_output_space.keys()), key=int)
print('Task order:',task_names)

Task order: ['1', '2', '3', '4', '5']


In [13]:
def criterion_fn(preds, targets, valid_out_dim):
    if valid_out_dim != 0:
        pred = preds[:,:valid_out_dim]
    loss = criterion(pred, targets)
    return loss

In [14]:
def accuracy(output, target, topk=(1,)):
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum().item()
            res.append(correct_k*100.0 / batch_size)

        if len(res)==1:
            return res[0]
        else:
            return res

In [15]:
def accumulate_acc(output, target, meter):
    meter.update(accuracy(output, target), len(target))
    return meter

In [16]:
avg_final_acc = np.zeros(repeat)

for r in range (repeat):
    acc_table = OrderedDict()
    valid_out_dim = 0
    for i in range(len(task_names)):
        valid_out_dim += 2
        train_name = task_names[i]
        train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name],
                                                            batch_size=batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset_splits[train_name],
                                                        batch_size=batch_size, shuffle=False)
        # Train
        for epoch in range(4):
            train_acc = AverageMeter()
            for (input, target, task) in train_loader:
                agent.train()
                input, target = input.to(device), target.to(device)

                output = agent(input)
                loss = criterion_fn(output, target, valid_out_dim)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_acc = accumulate_acc(output, target, train_acc)

        # Eval
        acc_table[train_name] = OrderedDict()

        for j in range(i+1):
            val_name = task_names[j]
            val_data = val_dataset_splits[val_name]
            val_loader = torch.utils.data.DataLoader(val_data, batch_size=128, shuffle=False,)

            agent.eval()
            val_acc = AverageMeter()
            with torch.no_grad():
                for i, (input, target, task) in enumerate(val_loader):
                    input, target = input.to(device), target.to(device)
                    output = agent(input)
                    val_acc = accumulate_acc(output, target, val_acc)

            acc_table[val_name][train_name] = val_acc.avg

    print(acc_table)

    avg_acc_history = [0] * len(task_names)
    for i in range(len(task_names)):
        train_name = task_names[i]
        cls_acc_sum = 0
        for j in range(i + 1):
            val_name = task_names[j]
            cls_acc_sum += acc_table[val_name][train_name]

        avg_acc_history[i] = cls_acc_sum / (i + 1)
        print('Task', train_name, 'average acc:', avg_acc_history[i])
    
    avg_final_acc[r] = avg_acc_history[-1]
    print('===Summary of experiment repeats:',r+1,'/',repeat,'===')
    print(avg_final_acc)
    print('mean:', avg_final_acc.mean(), 'std:', avg_final_acc.std())

OrderedDict([('1', OrderedDict([('1', 99.66903073286052), ('2', 0.0), ('3', 0.0), ('4', 0.0), ('5', 0.0)])), ('2', OrderedDict([('2', 99.70617042115573), ('3', 0.0), ('4', 0.0), ('5', 0.0)])), ('3', OrderedDict([('3', 99.46638207043756), ('4', 0.0), ('5', 0.0)])), ('4', OrderedDict([('4', 99.59718026183283), ('5', 0.0)])), ('5', OrderedDict([('5', 98.73928391326274)]))])
Task 1 average acc: 99.66903073286052
Task 2 average acc: 49.853085210577866
Task 3 average acc: 33.15546069014585
Task 4 average acc: 24.899295065458208
Task 5 average acc: 19.747856782652548
===Summary of experiment repeats: 1 / 10 ===
[19.74785678  0.          0.          0.          0.          0.
  0.          0.          0.          0.        ]
mean: 1.9747856782652549 std: 5.924357034795765
OrderedDict([('1', OrderedDict([('1', 0.0), ('2', 0.0), ('3', 0.0), ('4', 0.0), ('5', 0.0)])), ('2', OrderedDict([('2', 0.0), ('3', 0.0), ('4', 0.0), ('5', 0.0)])), ('3', OrderedDict([('3', 13.180362860192103), ('4', 0.0), ('