In [205]:
import torch
import torch.nn as nn
import torchvision
import os
from os import path
import copy
import numpy as np
import torch.utils.data as data
from torchvision import transforms
from collections import OrderedDict

In [206]:
batch_size = 128
repeat = 10
epoches = 1
alpha = 0.01
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# DataLoader

In [207]:
class CacheClassLabel(data.Dataset):
    """
    A dataset wrapper that has a quick access to all labels of data.
    """
    def __init__(self, dataset):
        super(CacheClassLabel, self).__init__()
        self.dataset = dataset
        self.labels = torch.LongTensor(len(dataset)).fill_(-1)
        label_cache_filename = path.join(dataset.root, str(type(dataset))+'_'+str(len(dataset))+'.pth')
        if path.exists(label_cache_filename):
            self.labels = torch.load(label_cache_filename)
        else:
            for i, data in enumerate(dataset):
                self.labels[i] = data[1]
            torch.save(self.labels, label_cache_filename)
        self.number_classes = len(torch.unique(self.labels))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img,target = self.dataset[index]
        return img, target
    
class AppendName(data.Dataset):
    """
    A dataset wrapper that also return the name of the dataset/task
    """
    def __init__(self, dataset, name, first_class_ind=0):
        super(AppendName,self).__init__()
        self.dataset = dataset
        self.name = name
        self.first_class_ind = first_class_ind  # For remapping the class index

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img,target = self.dataset[index]
        target = target + self.first_class_ind
        return img, target, self.name
    
class Subclass(data.Dataset):
    """
    A dataset wrapper that return the task name and remove the offset of labels (Let the labels start from 0)
    """
    def __init__(self, dataset, class_list, remap=True):
        '''
        :param dataset: (CacheClassLabel)
        :param class_list: (list) A list of integers
        :param remap: (bool) Ex: remap class [2,4,6 ...] to [0,1,2 ...]
        '''
        super(Subclass,self).__init__()
        assert isinstance(dataset, CacheClassLabel), 'dataset must be wrapped by CacheClassLabel'
        self.dataset = dataset
        self.class_list = class_list
        self.remap = remap
        self.indices = []
        for c in class_list:
            self.indices.extend((dataset.labels==c).nonzero().flatten().tolist())
        if remap:
            self.class_mapping = {c: i for i, c in enumerate(class_list)}

    def __len__(self):
        return len(self.indices)
    def __getitem__(self, index):
        img,target = self.dataset[self.indices[index]]
        if self.remap:
            raw_target = target.item() if isinstance(target,torch.Tensor) else target
            target = self.class_mapping[raw_target]
        return img, target

In [208]:
def SplitGen(train_dataset, val_dataset, first_split_sz=2, other_split_sz=2, rand_split=False, remap_class=False):
    assert train_dataset.number_classes==val_dataset.number_classes,'Train/Val has different number of classes'
    num_classes =  train_dataset.number_classes

    # Calculate the boundary index of classes for splits
    # Ex: [0,2,4,6,8,10] or [0,50,60,70,80,90,100]
    split_boundaries = [0, first_split_sz]
    while split_boundaries[-1]<num_classes:
        split_boundaries.append(split_boundaries[-1]+other_split_sz)
    print('split_boundaries:',split_boundaries)
    assert split_boundaries[-1]==num_classes,'Invalid split size'

    # Assign classes to each splits
    # Create the dict: {split_name1:[2,6,7], split_name2:[0,3,9], ...}
    if not rand_split:
        class_lists = {str(i):list(range(split_boundaries[i-1],split_boundaries[i])) for i in range(1,len(split_boundaries))}
    else:
        randseq = torch.randperm(num_classes)
        class_lists = {str(i):randseq[list(range(split_boundaries[i-1],split_boundaries[i]))].tolist() for i in range(1,len(split_boundaries))}
    print(class_lists)

    # Generate the dicts of splits
    # Ex: {split_name1:dataset_split1, split_name2:dataset_split2, ...}
    train_dataset_splits = {}
    val_dataset_splits = {}
    task_output_space = {}
    for name,class_list in class_lists.items():
        train_dataset_splits[name] = AppendName(Subclass(train_dataset, class_list, remap_class), name)
        val_dataset_splits[name] = AppendName(Subclass(val_dataset, class_list, remap_class), name)
        task_output_space[name] = len(class_list)

    return train_dataset_splits, val_dataset_splits, task_output_space

In [209]:
def MNIST(dataroot, train_aug=False):
    val_transform = transforms.Compose([
        transforms.Pad(2, fill=0, padding_mode='constant'),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    train_transform = val_transform
    if train_aug:
        train_transform = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    train_dataset = torchvision.datasets.MNIST(
        root=dataroot,
        train=True,
        download=True,
        transform=train_transform
    )
    train_dataset = CacheClassLabel(train_dataset)

    val_dataset = torchvision.datasets.MNIST(
        dataroot,
        train=False,
        transform=val_transform
    )
    val_dataset = CacheClassLabel(val_dataset)

    return train_dataset, val_dataset

In [210]:
train_dataset, val_dataset = MNIST('data', False)

In [211]:
train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(train_dataset, val_dataset,
                                                                          first_split_sz=2,
                                                                          other_split_sz=2,
                                                                          rand_split=False,
                                                                          remap_class=False)

split_boundaries: [0, 2, 4, 6, 8, 10]
{'1': [0, 1], '2': [2, 3], '3': [4, 5], '4': [6, 7], '5': [8, 9]}


# Model

In [212]:
class MLP(nn.Module):
    def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256):
        super(MLP, self).__init__()
        self.in_dim = in_channel*img_sz*img_sz
        self.linear = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.last = nn.Linear(hidden_dim, out_dim)

    def features(self, x):
        x = self.linear(x.view(-1,self.in_dim))
        return x

    def logits(self, x):
        x = self.last(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x

In [213]:
def MLP400():
    return MLP(hidden_dim=400)

# Train

In [214]:
class AverageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = float(self.sum) / self.count

In [215]:
def accuracy(output, target):
    with torch.no_grad():
        _, predicted = torch.max(output.data, 1)
        batch_size = target.size(0)
        correct = (predicted == target).sum().item() * 100
    return correct / batch_size

In [216]:
def accumulate_acc(output, target, meter):
    acc = accuracy(output, target)
    meter.update(acc, len(target))
    return meter

In [217]:
def criterion_fn(criterion, preds, targets, valid_out_dim):
    if valid_out_dim != 0:
        pred = preds[:,:valid_out_dim]
    loss = criterion(pred, targets)
    return loss

In [218]:
def train_on_task(model, train_loader, optimizer, criterion, 
                  valid_out_dim, best_model_wts, best_loss, task_num, task_names):
    leader = MLP400().to(device)
    if (best_model_wts):
        leader.load_state_dict(best_model_wts)

    for epoch in range(epoches):
        train_acc = AverageMeter()
        batch_num = 0
        for images, labels, _ in train_loader:
            model.train()
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            reg_loss = 0
            for lead_para, follower_para in zip(leader.parameters(), model.parameters()):
                reg_loss += torch.norm(follower_para - lead_para, p = 2)
            
            c_loss = criterion_fn(criterion, outputs, labels, valid_out_dim)
            loss = c_loss + 5 * reg_loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_acc = accumulate_acc(outputs, labels, train_acc)

            model.eval()
            with torch.no_grad():
                val_loss = AverageMeter()

                for task in range(task_num + 1):
                    val_name = task_names[task]
                    val_data = val_dataset_splits[val_name]
                    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

                    for i, (input, target, _) in enumerate(val_loader):
                        input, target = input.to(device), target.to(device)
                        output = model(input)
                        loss = criterion(output, target).item()

                        val_loss.update(loss, len(target))

                    if val_loss.avg < best_loss:
                        best_loss = val_loss.avg
                        best_model_wts = copy.deepcopy(model.state_dict())
                        leader.load_state_dict(best_model_wts) 
            print(f"batch_num: {batch_num}, c_loss:{c_loss.item():.4f}, val_loss:{val_loss.avg: .4f}, loss:{loss:.4f}")
    return best_model_wts, best_loss

In [219]:
def train(task_names):
    acc_table = OrderedDict()
    valid_out_dim = 0

    model = MLP400().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 0.0005)

    best_model_wts = None
    best_loss = float('inf')
    for i in range(len(task_names)):
        valid_out_dim += 2
        train_name = task_names[i]
        train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name], batch_size=batch_size, shuffle=True)
        
        print(f'=====Task: {train_name}=====')
        best_model_wts, best_loss = train_on_task(model, train_loader, optimizer, criterion, valid_out_dim, best_model_wts, best_loss, i, task_names)
    
        acc_table[train_name] = OrderedDict()

        for j in range(i+1):
            val_name = task_names[j]
            val_data = val_dataset_splits[val_name]
            val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)
            model.eval()
            val_acc = AverageMeter()
            with torch.no_grad():
                for i, (input, target, _) in enumerate(val_loader):
                    input, target = input.to(device), target.to(device)
                    output = model(input)
                    val_acc = accumulate_acc(output, target, val_acc)

            acc_table[val_name][train_name] = val_acc.avg

        print(acc_table)

    avg_acc_history = [0] * len(task_names)
    for i in range(len(task_names)):
        train_name = task_names[i]
        cls_acc_sum = 0
        for j in range(i + 1):
            val_name = task_names[j]
            cls_acc_sum += acc_table[val_name][train_name]

        avg_acc_history[i] = cls_acc_sum / (i + 1)
        print('Task', train_name, 'average acc:', avg_acc_history[i])
    
    return avg_acc_history

In [220]:
task_names = sorted(list(task_output_space.keys()), key=int)
print('Task order:',task_names)

Task order: ['1', '2', '3', '4', '5']


In [221]:
avg_acc_history = train(task_names)

=====Task: 1=====
batch_num: 0, c_loss:0.7457, val_loss: 2.2492, loss:2.2799
batch_num: 0, c_loss:0.3090, val_loss: 2.1597, loss:2.1512
batch_num: 0, c_loss:0.1448, val_loss: 2.0643, loss:2.0230
batch_num: 0, c_loss:0.0816, val_loss: 1.9663, loss:1.8981
batch_num: 0, c_loss:0.0644, val_loss: 1.8712, loss:1.7748
batch_num: 0, c_loss:0.0495, val_loss: 1.7770, loss:1.6600
batch_num: 0, c_loss:0.0477, val_loss: 1.6845, loss:1.5532
batch_num: 0, c_loss:0.0347, val_loss: 1.5941, loss:1.4521
batch_num: 0, c_loss:0.0181, val_loss: 1.5052, loss:1.3543
batch_num: 0, c_loss:0.0210, val_loss: 1.4181, loss:1.2602
batch_num: 0, c_loss:0.0187, val_loss: 1.3311, loss:1.1771
batch_num: 0, c_loss:0.0165, val_loss: 1.2499, loss:1.0941
batch_num: 0, c_loss:0.0312, val_loss: 1.1748, loss:1.0136
batch_num: 0, c_loss:0.0129, val_loss: 1.1004, loss:0.9402
batch_num: 0, c_loss:0.0204, val_loss: 1.0268, loss:0.8822
batch_num: 0, c_loss:0.0100, val_loss: 0.9599, loss:0.8242
batch_num: 0, c_loss:0.0203, val_loss:

In [222]:
'''avg_final_acc = np.zeros(repeat)

for r in range (repeat):
    acc_table = OrderedDict()
    valid_out_dim = 0
    for i in range(len(task_names)):
        valid_out_dim += 2
        train_name = task_names[i]
        train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name],
                                                            batch_size=batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_dataset_splits[train_name],
                                                        batch_size=batch_size, shuffle=False)
        # Train
        for epoch in range(4):
            train_acc = AverageMeter()
            for (input, target, task) in train_loader:
                agent.train()
                input, target = input.to(device), target.to(device)

                output = agent(input)
                loss = criterion_fn(criterion, output, target, valid_out_dim)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                train_acc = accumulate_acc(output, target, train_acc)

        # Eval
        acc_table[train_name] = OrderedDict()

        for j in range(i+1):
            val_name = task_names[j]
            val_data = val_dataset_splits[val_name]
            val_loader = torch.utils.data.DataLoader(val_data, batch_size=128, shuffle=False,)

            agent.eval()
            val_acc = AverageMeter()
            with torch.no_grad():
                for i, (input, target, task) in enumerate(val_loader):
                    input, target = input.to(device), target.to(device)
                    output = agent(input)
                    val_acc = accumulate_acc(output, target, val_acc)

            acc_table[val_name][train_name] = val_acc.avg

    print(acc_table)

    avg_acc_history = [0] * len(task_names)
    for i in range(len(task_names)):
        train_name = task_names[i]
        cls_acc_sum = 0
        for j in range(i + 1):
            val_name = task_names[j]
            cls_acc_sum += acc_table[val_name][train_name]

        avg_acc_history[i] = cls_acc_sum / (i + 1)
        print('Task', train_name, 'average acc:', avg_acc_history[i])
    
    avg_final_acc[r] = avg_acc_history[-1]
    print('===Summary of experiment repeats:',r+1,'/',repeat,'===')
    print(avg_final_acc)
    print('mean:', avg_final_acc.mean(), 'std:', avg_final_acc.std())'''

"avg_final_acc = np.zeros(repeat)\n\nfor r in range (repeat):\n    acc_table = OrderedDict()\n    valid_out_dim = 0\n    for i in range(len(task_names)):\n        valid_out_dim += 2\n        train_name = task_names[i]\n        train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name],\n                                                            batch_size=batch_size, shuffle=True)\n        val_loader = torch.utils.data.DataLoader(val_dataset_splits[train_name],\n                                                        batch_size=batch_size, shuffle=False)\n        # Train\n        for epoch in range(4):\n            train_acc = AverageMeter()\n            for (input, target, task) in train_loader:\n                agent.train()\n                input, target = input.to(device), target.to(device)\n\n                output = agent(input)\n                loss = criterion_fn(criterion, output, target, valid_out_dim)\n\n                optimizer.zero_grad()\n              