In [37]:
import torch
import torch.nn as nn
import torchvision
import os
from os import path
import copy
import numpy as np
import torch.utils.data as data
from torchvision import transforms
from collections import OrderedDict

In [38]:
# Run-wide settings read by the cells below.
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# batch_size: mini-batch size of every DataLoader; epoches: passes over each task's data.
# NOTE(review): `repeat` is not referenced by any visible cell - confirm it is still needed.
batch_size, repeat, epoches = 128, 10, 1

In [39]:
class CacheClassLabel(data.Dataset):
    """
    Dataset wrapper that keeps the label of every sample in one tensor (``self.labels``).

    The labels are loaded from a cache file under ``dataset.root`` when that file
    exists; otherwise they are collected by iterating over ``dataset`` once and
    the result is saved to that file. ``self.number_classes`` holds the number
    of distinct label values.
    """
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        num_samples = len(dataset)
        self.labels = torch.LongTensor(num_samples).fill_(-1)
        print(dataset.root)
        # The cache file name depends only on the dataset root and its length.
        cache_file = dataset.root + '/_' + str(num_samples) + '.pth'
        if not path.exists(cache_file):
            # No cache yet: read the label (second element) of every sample.
            for position, sample in enumerate(dataset):
                self.labels[position] = sample[1]
            torch.save(self.labels, cache_file)
        else:
            self.labels = torch.load(cache_file)
        self.number_classes = len(torch.unique(self.labels))

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # Samples are passed through from the wrapped dataset unchanged.
        sample, label = self.dataset[index]
        return sample, label
    
class AppendName(data.Dataset):
    """
    Dataset wrapper whose samples also carry the name of the dataset/task.

    Every item is ``(img, target + first_class_ind, name)``.
    """
    def __init__(self, dataset, name, first_class_ind=0):
        super().__init__()
        self.dataset, self.name = dataset, name
        # Offset added to every target; used for remapping the class index.
        self.first_class_ind = first_class_ind

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, raw_target = self.dataset[index]
        return img, raw_target + self.first_class_ind, self.name
    
class Subclass(data.Dataset):
    """
    View of a CacheClassLabel dataset restricted to the classes in ``class_list``.

    With ``remap`` enabled each returned target is the position of its class in
    ``class_list`` (so labels start from 0); otherwise the wrapped dataset's
    target is returned unchanged.
    """
    def __init__(self, dataset, class_list, remap=True):
        '''
        :param dataset: (CacheClassLabel)
        :param class_list: (list) A list of integers
        :param remap: (bool) Ex: remap class [2,4,6 ...] to [0,1,2 ...]
        '''
        super().__init__()
        assert isinstance(dataset, CacheClassLabel), 'dataset must be wrapped by CacheClassLabel'
        self.dataset = dataset
        self.class_list = class_list
        self.remap = remap
        # Positions of the selected samples, grouped class by class in class_list order.
        self.indices = [
            sample_idx
            for cls in class_list
            for sample_idx in (dataset.labels == cls).nonzero().flatten().tolist()
        ]
        if remap:
            # Original class id -> position within class_list.
            self.class_mapping = dict((cls, pos) for pos, cls in enumerate(class_list))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, index):
        img, target = self.dataset[self.indices[index]]
        if not self.remap:
            return img, target
        # Tensor targets are converted to plain Python numbers before the lookup.
        lookup_key = target.item() if isinstance(target, torch.Tensor) else target
        return img, self.class_mapping[lookup_key]

In [40]:
def SplitGen(train_dataset, val_dataset, first_split_sz=2, other_split_sz=2, rand_split=False, remap_class=False):
    """
    Partition the classes of a train/val dataset pair into a sequence of splits.

    The first split holds ``first_split_sz`` classes and every later split holds
    ``other_split_sz`` classes; the sizes must add up to the number of classes
    exactly. Splits are named '1', '2', ...

    :param train_dataset: (CacheClassLabel) training data
    :param val_dataset: (CacheClassLabel) validation data with the same number of classes
    :param first_split_sz: (int) number of classes in the first split
    :param other_split_sz: (int) number of classes in each remaining split
    :param rand_split: (bool) assign classes to splits through a random permutation
    :param remap_class: (bool) forwarded to Subclass as its ``remap`` argument
    :return: (train splits dict, val splits dict, dict split name -> number of classes)
    """
    assert train_dataset.number_classes==val_dataset.number_classes,'Train/Val has different number of classes'
    num_classes = train_dataset.number_classes

    # Class-index boundaries of the splits,
    # Ex: [0,2,4,6,8,10] or [0,50,60,70,80,90,100]
    upper = first_split_sz
    split_boundaries = [0, upper]
    while upper < num_classes:
        upper = upper + other_split_sz
        split_boundaries.append(upper)
    print('split_boundaries:',split_boundaries)
    assert split_boundaries[-1]==num_classes,'Invalid split size'

    # Consecutive (start, stop) pairs, numbered from 1, give the classes of each split.
    # Result: {split_name1:[2,6,7], split_name2:[0,3,9], ...}
    numbered_spans = list(enumerate(zip(split_boundaries[:-1], split_boundaries[1:]), start=1))
    if rand_split:
        randseq = torch.randperm(num_classes)
        class_lists = {str(k): randseq[list(range(lo, hi))].tolist() for k, (lo, hi) in numbered_spans}
    else:
        class_lists = {str(k): list(range(lo, hi)) for k, (lo, hi) in numbered_spans}
    print(class_lists)

    # Wrap every split of both datasets and record its number of classes.
    # Ex: {split_name1:dataset_split1, split_name2:dataset_split2, ...}
    train_dataset_splits, val_dataset_splits, task_output_space = {}, {}, {}
    for split_name in class_lists:
        split_classes = class_lists[split_name]
        train_dataset_splits[split_name] = AppendName(Subclass(train_dataset, split_classes, remap_class), split_name)
        val_dataset_splits[split_name] = AppendName(Subclass(val_dataset, split_classes, remap_class), split_name)
        task_output_space[split_name] = len(split_classes)

    return train_dataset_splits, val_dataset_splits, task_output_space

In [41]:
def MNIST(dataroot, train_aug=False):
    """
    Build the MNIST train/validation datasets, each wrapped in CacheClassLabel.

    The validation pipeline pads each image by 2 pixels per side before
    converting it to a normalized tensor. With ``train_aug=False`` the training
    set uses that same pipeline.

    With ``train_aug=True`` the training set uses a random 32x32 crop taken
    after padding by 4 pixels per side. This keeps training tensors the same
    size as the validation tensors. Previously the augmented branch only
    dropped the padding, so it produced smaller, un-augmented images that did
    not match the validation pipeline or the 32x32 default input of ``MLP``.

    :param dataroot: (str) directory where MNIST is stored / downloaded
    :param train_aug: (bool) enable the random-crop training augmentation
    :return: (train_dataset, val_dataset), both CacheClassLabel instances
    """
    normalize = transforms.Normalize([0.5], [0.5])

    val_transform = transforms.Compose([
        transforms.Pad(2, fill=0, padding_mode='constant'),
        transforms.ToTensor(),
        normalize,
    ])
    train_transform = val_transform
    if train_aug:
        train_transform = transforms.Compose([
            # Pad 4 px per side, then crop a random 32x32 window.
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            normalize,
        ])

    train_dataset = torchvision.datasets.MNIST(
        root=dataroot,
        train=True,
        download=True,
        transform=train_transform
    )
    train_dataset = CacheClassLabel(train_dataset)

    # No download flag here: the training dataset above is created first with
    # download=True for the same root.
    val_dataset = torchvision.datasets.MNIST(
        dataroot,
        train=False,
        transform=val_transform
    )
    val_dataset = CacheClassLabel(val_dataset)

    return train_dataset, val_dataset

In [42]:
# Load MNIST from ./data with the training augmentation switched off.
train_dataset, val_dataset = MNIST(dataroot='./data', train_aug=False)

./data
./data


In [43]:
# Split both datasets into consecutive two-class tasks, in class order and
# without remapping the labels.
train_dataset_splits, val_dataset_splits, task_output_space = SplitGen(
    train_dataset, val_dataset,
    first_split_sz=2, other_split_sz=2,
    rand_split=False, remap_class=False,
)

split_boundaries: [0, 2, 4, 6, 8, 10]
{'1': [0, 1], '2': [2, 3], '3': [4, 5], '4': [6, 7], '5': [8, 9]}


In [44]:
class MLP(nn.Module):
    """
    Fully connected classifier: two hidden blocks followed by a linear output layer.

    Each hidden block is Linear -> BatchNorm1d -> ReLU. The input is flattened
    to ``in_channel * img_sz * img_sz`` features per sample.

    :param out_dim: (int) number of output logits
    :param in_channel: (int) number of input channels
    :param img_sz: (int) height and width of the (square) input
    :param hidden_dim: (int) width of both hidden blocks
    """
    def __init__(self, out_dim=10, in_channel=1, img_sz=32, hidden_dim=256):
        super().__init__()
        self.in_dim = in_channel * img_sz * img_sz
        # Both hidden blocks share one layout and differ only in input width.
        hidden_layers = []
        for width_in in (self.in_dim, hidden_dim):
            hidden_layers += [
                nn.Linear(width_in, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(inplace=True),
            ]
        self.linear = nn.Sequential(*hidden_layers)
        self.last = nn.Linear(hidden_dim, out_dim)

    def features(self, x):
        # Flatten every sample to a vector before the hidden blocks.
        flat = x.view(-1, self.in_dim)
        return self.linear(flat)

    def logits(self, x):
        return self.last(x)

    def forward(self, x):
        return self.logits(self.features(x))

In [45]:
def MLP400():
    """Build an MLP with 400 units per hidden layer; all other arguments keep MLP's defaults."""
    hidden_units = 400
    return MLP(hidden_dim=hidden_units)

In [46]:
class AverageMeter(object):
    """
    Running weighted average.

    Attributes: ``val`` (last value passed to update), ``sum`` (weighted sum),
    ``count`` (total weight) and ``avg`` (``sum / count`` as a float).
    """
    def __init__(self):
        self.reset()

    def reset(self):
        # Back to the initial, empty state.
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        # `val` is weighted by `n`, e.g. a batch mean together with the batch size.
        self.val = val
        contribution = val * n
        self.sum += contribution
        self.count += n
        self.avg = float(self.sum) / self.count

In [47]:
def accuracy(output, target):
    """
    Percentage (0-100) of rows in ``output`` whose highest-scoring column equals ``target``.

    :param output: 2-D tensor of scores, one row per sample
    :param target: 1-D tensor of class indices
    :return: (float) accuracy in percent
    """
    with torch.no_grad():
        # Index of the largest score along dimension 1.
        predicted = output.data.max(dim=1)[1]
        hits = (predicted == target).sum().item()
        total = target.size(0)
    return hits * 100 / total

In [48]:
def accumulate_acc(output, target, meter):
    """
    Add the accuracy of one batch to ``meter``, weighted by the batch size.

    :return: the same ``meter`` object, after the update
    """
    meter.update(accuracy(output, target), len(target))
    return meter

In [49]:
def criterion_fn(criterion, preds, targets, valid_out_dim):
    """
    Apply ``criterion`` to the first ``valid_out_dim`` output columns.

    :param criterion: callable ``criterion(pred, targets)`` returning the loss
    :param preds: 2-D predictions, one row per sample
    :param targets: targets forwarded unchanged to ``criterion``
    :param valid_out_dim: (int) number of leading columns of ``preds`` to keep;
        0 means "no restriction" and the full ``preds`` is used
    :return: whatever ``criterion`` returns
    """
    # Default to all outputs so that `pred` is always bound; the original
    # raised UnboundLocalError when valid_out_dim was 0.
    pred = preds
    if valid_out_dim != 0:
        pred = preds[:, :valid_out_dim]
    return criterion(pred, targets)

In [50]:
def train_on_task(follower, train_loader, optimizer_F, optimizer_L, criterion,
                  valid_out_dim, best_model_wts, task_num, task_names):
    """
    Train the follower and a freshly built leader network side by side on one task.

    For every batch the follower is updated first: its loss is the
    classification loss plus the mean absolute difference between its outputs
    and the leader's outputs. The leader is then updated with the same kind of
    loss, computed against the follower's (detached) outputs from before the
    follower's step.

    :param follower: model updated in place through ``optimizer_F``
    :param train_loader: iterable yielding ``(images, labels, name)`` batches
    :param optimizer_F: optimizer bound to the follower's parameters
    :param optimizer_L: ignored; see the comment where the leader is created
    :param criterion: loss callable handed to ``criterion_fn``
    :param valid_out_dim: (int) number of leading outputs used by the classification loss
    :param best_model_wts: state dict used to initialise the leader; when falsy
        (e.g. None) the leader starts from the follower's current weights
    :param task_num: unused, kept so existing callers keep working
    :param task_names: unused, kept so existing callers keep working
    :return: ``(leader.state_dict(), best_loss)``, where ``best_loss`` is the
        lowest per-epoch mean leader loss (``inf`` if no batch was processed);
        previously it was never updated and always came back as ``inf``
    """
    leader = MLP400().to(device)
    # The leader is rebuilt on every call, so an optimizer created by the caller
    # could not reference its parameters. The `optimizer_L` argument is therefore
    # replaced by a new Adam optimizer bound to this leader.
    optimizer_L = torch.optim.Adam(leader.parameters(), 0.0005)
    best_loss = float('inf')
    if best_model_wts:
        leader.load_state_dict(best_model_wts)
    else:
        leader.load_state_dict(follower.state_dict())

    for epoch in range(epoches):
        print(f'Epoch: [ {epoch} / {epoches} ]')
        train_acc = AverageMeter()
        leader_loss = AverageMeter()
        for images, labels, _ in train_loader:
            images, labels = images.to(device), labels.to(device)

            # ---- follower step ----
            follower.train()
            # The leader only provides a fixed target here, so no graph is built.
            # NOTE(review): the leader is never switched to eval mode in this
            # function, so this extra forward pass presumably also updates its
            # BatchNorm running statistics - confirm this is intended.
            with torch.no_grad():
                leader_outputs = leader(images)

            follower_outputs = follower(images)

            c_loss_F = criterion_fn(criterion, follower_outputs, labels, valid_out_dim)
            reg_F = torch.mean(torch.abs(follower_outputs - leader_outputs))
            loss_F = c_loss_F + reg_F

            optimizer_F.zero_grad()
            loss_F.backward()
            optimizer_F.step()

            # ---- leader step ----
            leader.train()
            leader_outputs = leader(images)

            c_loss_L = criterion_fn(criterion, leader_outputs, labels, valid_out_dim)
            # detach(): the follower must not receive gradients from the leader's loss.
            reg_L = torch.mean(torch.abs(follower_outputs.detach() - leader_outputs))
            loss_L = c_loss_L + reg_L

            optimizer_L.zero_grad()
            loss_L.backward()
            optimizer_L.step()

            train_acc = accumulate_acc(follower_outputs, labels, train_acc)
            leader_loss.update(loss_L.item(), len(labels))

        # Skip the summary for an empty loader, where no average exists.
        if leader_loss.count > 0:
            best_loss = min(best_loss, leader_loss.avg)
            print(f'follower train acc: {train_acc.avg:.2f} | leader loss: {leader_loss.avg:.4f}')

    return leader.state_dict(), best_loss

In [51]:
def eval(acc_table, model, train_name, task_names, task_index):
    """
    Measure ``model``'s accuracy on the validation split of every task up to ``task_index``.

    A fresh entry ``acc_table[train_name]`` is created, then for each validated
    task ``acc_table[val_name][train_name]`` is set to the accuracy in percent.
    The entries of earlier tasks must therefore already exist in ``acc_table``.
    Note that this function shadows the built-in ``eval``; the name is kept
    because other cells call it.

    :param acc_table: (dict) val task name -> {train task name -> accuracy}
    :param model: network to evaluate; it is switched to eval mode
    :param train_name: name of the task that has just been trained
    :param task_names: (list) task names in training order
    :param task_index: (int) position of ``train_name`` in ``task_names``
    :return: the updated ``acc_table``
    """
    acc_table[train_name] = OrderedDict()

    for task_pos in range(task_index + 1):
        val_name = task_names[task_pos]
        loader = torch.utils.data.DataLoader(
            val_dataset_splits[val_name], batch_size=batch_size, shuffle=False
        )
        model.eval()
        meter = AverageMeter()
        with torch.no_grad():
            for images, labels, _ in loader:
                images, labels = images.to(device), labels.to(device)
                meter = accumulate_acc(model(images), labels, meter)

        acc_table[val_name][train_name] = meter.avg

    return acc_table


In [52]:
def _average_acc_history(acc_table, task_names, label):
    """Print and return, for each task in order, the mean accuracy over all tasks seen so far."""
    history = [0] * len(task_names)
    for i, train_name in enumerate(task_names):
        seen_tasks = task_names[:i + 1]
        history[i] = sum(acc_table[val_name][train_name] for val_name in seen_tasks) / (i + 1)
        print(label, 'Task', train_name, 'average acc:', history[i])
    return history


def train(task_names):
    """
    Train on the tasks one after another and report accuracy on all tasks seen so far.

    After each task both the follower (the persistent model) and the leader
    (rebuilt from the weights returned by ``train_on_task``) are evaluated on
    the validation splits of every task trained so far.

    :param task_names: (list) task names in training order; each must be a key
        of ``train_dataset_splits``, ``val_dataset_splits`` and ``task_output_space``
    :return: ``(avg_acc_history, leader_avg_acc_history)``: for the follower
        and the leader respectively, one running-average accuracy per task
    """
    leader_acc_table = OrderedDict()
    follower_acc_table = OrderedDict()
    valid_out_dim = 0

    model = MLP400().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(model.parameters(), 0.0005)

    best_model_wts = None
    for i, train_name in enumerate(task_names):
        # Grow the active output range by this task's class count instead of a
        # hard-coded 2, so splits of other sizes are handled correctly.
        valid_out_dim += task_output_space[train_name]
        train_loader = torch.utils.data.DataLoader(train_dataset_splits[train_name], batch_size=batch_size, shuffle=True)

        print(f'=====Task: {train_name}=====')
        best_model_wts, _ = train_on_task(model, train_loader, optimizer_F, None, criterion, valid_out_dim, best_model_wts, i, task_names)

        # Rebuild the leader from the returned weights so it can be evaluated.
        leader = MLP400().to(device)
        leader.load_state_dict(best_model_wts)
        eval(follower_acc_table, model, train_name, task_names, i)
        eval(leader_acc_table, leader, train_name, task_names, i)

        print(follower_acc_table)
        print(leader_acc_table)

    avg_acc_history = _average_acc_history(follower_acc_table, task_names, 'follower')
    leader_avg_acc_history = _average_acc_history(leader_acc_table, task_names, 'leader')

    return avg_acc_history, leader_avg_acc_history

In [53]:
# Task names are numeric strings, so order them by their integer value.
task_names = sorted(task_output_space, key=int)
print('Task order:',task_names)

Task order: ['1', '2', '3', '4', '5']


In [54]:
# Run the whole task sequence and keep the follower's and the leader's
# running-average accuracies.
avg_acc_history, leader_avg_acc_history = train(task_names=task_names)

=====Task: 1=====
Epoch: [ 0 / 1 ]


OrderedDict([('1', OrderedDict([('1', 99.90543735224587)]))])
OrderedDict([('1', OrderedDict([('1', 99.90543735224587)]))])
=====Task: 2=====
Epoch: [ 0 / 1 ]
OrderedDict([('1', OrderedDict([('1', 99.90543735224587), ('2', 0.0)])), ('2', OrderedDict([('2', 98.92262487757101)]))])
OrderedDict([('1', OrderedDict([('1', 99.90543735224587), ('2', 0.0)])), ('2', OrderedDict([('2', 99.06953966699314)]))])
=====Task: 3=====
Epoch: [ 0 / 1 ]
OrderedDict([('1', OrderedDict([('1', 99.90543735224587), ('2', 0.0), ('3', 0.0)])), ('2', OrderedDict([('2', 98.92262487757101), ('3', 0.0)])), ('3', OrderedDict([('3', 99.73319103521878)]))])
OrderedDict([('1', OrderedDict([('1', 99.90543735224587), ('2', 0.0), ('3', 0.0)])), ('2', OrderedDict([('2', 99.06953966699314), ('3', 0.0)])), ('3', OrderedDict([('3', 99.73319103521878)]))])
=====Task: 4=====
Epoch: [ 0 / 1 ]
OrderedDict([('1', OrderedDict([('1', 99.90543735224587), ('2', 0.0), ('3', 0.0), ('4', 0.0)])), ('2', OrderedDict([('2', 98.92262487757101