In [1]:
import torch
from copy import deepcopy

from torch import nn
from tqdm import tqdm
from torch.nn import functional as F
import torch.optim as optim
import torch.utils.data

from torchvision import datasets, transforms
import random
import PIL.Image as Image


In [83]:
# Global hyper-parameters read by the training and evaluation cells below.
epochs = 10  # training epochs per task
lr = 1e-3  # learning rate handed to optim.Adam
batch_size = 64  # mini-batch size for every DataLoader built in get_mnist()
sample_size = 200  # NOTE(review): not referenced anywhere in the visible code — confirm before removing
hidden_size = 256  # NOTE(review): MLP() is constructed with its own default; this global is not passed in
num_task = 5  # number of SplitMNIST tasks iterated over
epochs_interval = 1  # the loss is printed every `epochs_interval` epochs
seed = 42  # value given to torch.manual_seed / torch.cuda.manual_seed

In [12]:
class SplitMNIST(datasets.MNIST):
    """MNIST restricted to the digit classes of one continual-learning task.

    The ten digits are split into five two-class tasks (see ``tasks``).

    Args:
        root: Directory the MNIST files are stored in / downloaded to.
        train: Use the training split when True, the test split otherwise.
        task: Key of ``tasks`` selecting which digit pair to expose.
        cum: Only matters for the test split (``train=False``). When True the
            dataset holds the classes of every task up to and including
            ``task``; when False only the classes of ``task`` itself.

    Each item is ``(image, target)`` where ``image`` is the flattened output
    of ``transforms.ToTensor()``. On the training split the target is shifted
    by ``task * 2`` so labels start at 0 within each task; test targets keep
    the original digit label.
    """

    tasks = {
        0: [0, 1],
        1: [2, 3],
        2: [4, 5],
        3: [6, 7],
        4: [8, 9],
    }

    def __init__(self, root="/vandal/datasets", train=True, task=0, cum=True):
        super().__init__(root, train, download=True)
        if not train and cum:
            classes = [i for t in range(task + 1) for i in SplitMNIST.tasks[t]]
        else:
            classes = list(SplitMNIST.tasks[task])

        # Convert the labels to plain ints once: membership tests against a set
        # are far cheaper than comparing a 0-dim tensor with every list entry.
        wanted = set(classes)
        labels = torch.as_tensor(self.targets).tolist()
        self.idx = [i for i, label in enumerate(labels) if label in wanted]

        self.transform = transforms.ToTensor()
        self.task = task
        self.train = train

    def __len__(self):
        """Number of samples belonging to the selected classes."""
        return len(self.idx)

    def __getitem__(self, index):
        """Return the ``index``-th selected sample as ``(flat_image, target)``."""
        position = self.idx[index]
        img, target = self.data[position], self.targets[position]
        img = Image.fromarray(img.numpy(), mode='L')
        img = self.transform(img)

        if self.train:
            # Use the task stored on the instance. Reading a module-level
            # `task` only worked while the caller's loop variable happened to
            # have that name and value (and breaks in spawned worker processes).
            target = target - self.task * 2

        return img.view(-1), target


In [13]:
class MLP(nn.Module):
    """Two-hidden-layer perceptron with a separate linear head per task.

    Args:
        hidden_size: Width of both hidden layers.
        tasks: Number of output heads.
        task_size: Number of logits produced by each head.
    """

    def __init__(self, hidden_size=256, tasks=5, task_size=2):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        heads = []
        for _ in range(tasks):
            heads.append(nn.Linear(hidden_size, task_size))
        self.classifier = nn.ModuleList(heads)

    def _trunk(self, input):
        """Shared feature extractor: fc1 -> ReLU -> fc2 -> ReLU."""
        hidden = F.relu(self.fc1(input))
        return F.relu(self.fc2(hidden))

    def forward(self, input, task):
        """Return the logits of head ``task`` only."""
        head = self.classifier[task]
        return head(self._trunk(input))

    def predict(self, input):
        """Return the logits of all heads concatenated along the last dim."""
        features = self._trunk(input)
        per_head = [head(features) for head in self.classifier]
        return torch.cat(per_head, dim=-1)


In [96]:
class OnlineEWC(object):
    """Fisher bookkeeping and quadratic penalty for online EWC.

    Args:
        model: Network being trained; called as ``model(input, task)``.
        model_old: Snapshot of the network whose weights anchor the penalty.
        device: Device the Fisher tensors and batches are moved to.
        alpha: Accepted for signature compatibility; not used by this class.
        fisher: Optional ``{param_name: tensor}`` dict of importances from
            earlier tasks. It is updated in place (entries are moved to
            ``device`` and marked as not requiring grad) and used as the
            penalty weights. When None, ``penalty()`` returns ``0.``.
    """

    def __init__(self, model, model_old, device, alpha=0.01, fisher=None):
        self.model = model
        self.model_old = model_old
        self.model_old_dict = self.model_old.state_dict()
        self.device = device

        self.fisher = {}
        if fisher is None:
            # No history yet: start a zero accumulator for every trainable parameter.
            self.fisher_old = None
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    self.fisher[name] = torch.zeros_like(
                        param, device=device, requires_grad=False)
        else:
            # Reuse the caller's dict as the penalty weights and open a fresh
            # zero accumulator of matching shape for the current task.
            self.fisher_old = fisher
            for key in fisher:
                previous = fisher[key]
                previous.requires_grad = False
                fisher[key] = previous.to(device)
                self.fisher[key] = torch.zeros_like(fisher[key], device=device)

    def update(self, dataloader, task):
        """Accumulate squared cross-entropy gradients over ``dataloader``.

        Each batch contributes ``grad ** 2 / len(dataloader)`` to the Fisher
        accumulator, where ``grad`` is the gradient of the batch cross-entropy
        loss (the negative log-likelihood of the labels). The model is left in
        eval mode.
        """
        self.model.eval()
        for input, target in dataloader:
            self.model.zero_grad()
            logits = self.model(input.to(self.device), task)
            nll = F.cross_entropy(logits, target.to(self.device))
            nll.backward()

            for name, param in self.model.named_parameters():
                if param.grad is None:
                    continue
                squared = param.grad.detach().pow(2)
                self.fisher[name].add_(squared / len(dataloader))

    def get_fisher(self):
        """Return the accumulator filled by ``update`` (the new Fisher)."""
        return self.fisher

    def penalty(self):
        """Sum of ``fisher_old * (param - old_param) ** 2`` over all parameters.

        Returns the float ``0.`` when no previous Fisher was supplied.
        """
        if self.fisher_old is None:
            return 0.
        total = 0
        for name, param in self.model.named_parameters():
            drift = param - self.model_old_dict[name]
            total += (self.fisher_old[name] * drift.pow(2)).sum()
        return total

In [59]:
EPS = 1e-20  # keeps the denominator non-zero when every entry is equal


def normalize_fn(fisher):
    """Min-max rescale ``fisher``: ``(x - min) / (max - min + EPS)``."""
    lowest = fisher.min()
    spread = fisher.max() - lowest + EPS
    return (fisher - lowest) / spread

class EWCpp(object):
    """EWC++ helper: moving-average Fisher estimate plus quadratic penalty.

    Args:
        model: Network being trained.
        model_old: Snapshot of the network whose weights anchor the penalty.
        device: Device the Fisher tensors are moved to.
        alpha: Weight of the newest squared gradient in the moving average.
        fisher: Optional ``{param_name: tensor}`` dict from earlier tasks. It
            is updated in place (entries moved to ``device``, grad disabled).
            A deep copy seeds the running estimate; the penalty uses either
            the dict itself or, when ``normalize`` is true, a min-max
            normalised copy of it.
        normalize: Whether to pass the penalty weights through ``normalize_fn``.
    """

    def __init__(self, model, model_old, device, alpha=0.9, fisher=None, normalize=True):
        self.model = model
        self.model_old = model_old
        self.model_old_dict = self.model_old.state_dict()

        self.device = device
        self.alpha = alpha
        self.normalize = normalize

        if fisher is None:
            # First task: nothing to penalise yet, start the estimate at zero.
            self.fisher_old = None
            self.fisher = {}
            for name, param in self.model.named_parameters():
                if param.requires_grad:
                    self.fisher[name] = torch.zeros_like(
                        param, device=device, requires_grad=False)
            return

        self.fisher_old = fisher
        for key in fisher:
            previous = fisher[key]
            previous.requires_grad = False
            fisher[key] = previous.to(device)
        # The running estimate continues from the raw (un-normalised) values.
        self.fisher = deepcopy(fisher)
        if normalize:
            self.fisher_old = {
                key: normalize_fn(self.fisher_old[key]) for key in self.fisher_old
            }

    def update(self):
        """Fold the gradients currently stored on the model into the estimate.

        Expects ``backward()`` to have been called already; parameters whose
        ``.grad`` is None are skipped.
        """
        for name, param in self.model.named_parameters():
            if param.grad is None:
                continue
            newest = param.grad.data.pow(2)
            self.fisher[name] = (self.alpha * newest) + ((1 - self.alpha) * self.fisher[name])

    def get_fisher(self):
        """Return the running Fisher estimate."""
        return self.fisher

    def penalty(self):
        """Sum of ``fisher_old * (param - old_param) ** 2`` over all parameters.

        Returns the float ``0.`` when no previous Fisher was supplied.
        """
        if self.fisher_old is None:
            return 0.
        total = 0
        for name, param in self.model.named_parameters():
            drift = param - self.model_old_dict[name]
            total += (self.fisher_old[name] * drift.pow(2)).sum()
        return total


# Start of the training procedure

In [61]:
def get_mnist():
    """Create the SplitMNIST data loaders for every task.

    Returns:
        ``(train_loader, test_loader, test_loader_no_cum)``, three dicts keyed
        by task index. ``test_loader[i]`` uses the cumulative test set (the
        ``cum`` default of SplitMNIST), ``test_loader_no_cum[i]`` only the
        classes of task ``i``.
    """
    make_loader = torch.utils.data.DataLoader
    train_loader, test_loader, test_loader_no_cum = {}, {}, {}

    for task_id in range(num_task):
        # No shuffle argument is passed, so batches follow dataset order.
        train_loader[task_id] = make_loader(
            SplitMNIST(train=True, task=task_id),
            batch_size=batch_size,
            num_workers=4,
        )
        test_loader[task_id] = make_loader(
            SplitMNIST(train=False, task=task_id),
            batch_size=batch_size,
        )
        test_loader_no_cum[task_id] = make_loader(
            SplitMNIST(train=False, task=task_id, cum=False),
            batch_size=batch_size,
        )

    return train_loader, test_loader, test_loader_no_cum

def test(model: nn.Module, data_loader: torch.utils.data.DataLoader):
    model.eval()
    correct = 0.
    size = float(0.)
    for input, target in data_loader:
        input, target = input.cuda(), target.cuda()
        output = model.predict(input)
        _, prediction = output.max(1)
        correct += torch.sum(prediction.eq(target)).float()
        size += len(target)
    return correct / size

# Build the loaders once at module level; the training and evaluation cells read these globals.
train_loader, test_loader, test_loader_no_cum = get_mnist()

# Vanilla

In [88]:
def normal_train(model, optimizer, data_loader, task):
    """Run one epoch of plain cross-entropy training on head ``task``.

    Batches are moved to the device of the model's first parameter (left
    untouched if the model has no parameters) instead of unconditionally to
    CUDA.

    Args:
        model: Network called as ``model(input, task)``.
        optimizer: Optimizer over the model's parameters.
        data_loader: Sized iterable of ``(input, target)`` batches.
        task: Index of the output head to train.

    Returns:
        The mean batch loss of the epoch as a tensor without autograd history.

    Raises:
        ZeroDivisionError: If ``data_loader`` is empty.
    """
    model.train()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    epoch_loss = 0
    for input, target in data_loader:
        if device is not None:
            input, target = input.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(input, task)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        # Accumulate a detached copy: summing the live loss would chain the
        # autograd history of every batch into the running total.
        epoch_loss += loss.detach()
    return epoch_loss / len(data_loader)

def standard_process(model, epochs, task):
    """Train ``model`` on one task without any regulariser and report accuracy.

    A fresh Adam optimizer (learning rate ``lr``) is created for the task. The
    loss is printed every ``epochs_interval`` epochs, then the accuracy on the
    task's test loader is printed.

    Returns:
        The trained model (the same object that was passed in).
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        loss = normal_train(model, optimizer, train_loader[task], task)
        if epoch % epochs_interval:
            continue
        print(f"Epoch {epoch + 1}: Loss {loss}")

    print(f"Acc task {task} is {test(model, test_loader[task])}")
    return model


In [91]:
# Vanilla baseline: train the tasks one after another with no forgetting mitigation.
# Seed the CPU and CUDA generators before the model is built so the initial weights are reproducible.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
model = MLP().cuda()
EPS = 1e-20  # not read by this cell or by standard_process

# NOTE(review): SplitMNIST.__getitem__ as originally written reads a module-level `task`;
# keep this loop variable's name unless that lookup has been changed to `self.task`.
for task in range(num_task):
    print("")
    model = standard_process(model, epochs, task=task)
    


Epoch 1: Loss 0.026549091562628746
Epoch 2: Loss 0.0031973535660654306
Epoch 3: Loss 0.0021754407789558172
Epoch 4: Loss 0.002085939748212695
Epoch 5: Loss 0.0027679249178618193
Epoch 6: Loss 0.0017462905962020159
Epoch 7: Loss 0.0006506486097350717
Epoch 8: Loss 0.0009288053261116147
Epoch 9: Loss 0.00026676274137571454
Epoch 10: Loss 0.0009421079885214567
Acc task 0 is 0.9985815286636353

Epoch 1: Loss 0.11990812420845032
Epoch 2: Loss 0.03745390102267265
Epoch 3: Loss 0.018878968432545662
Epoch 4: Loss 0.011726373806595802
Epoch 5: Loss 0.008084713481366634
Epoch 6: Loss 0.006686045788228512
Epoch 7: Loss 0.005309024825692177
Epoch 8: Loss 0.004606361500918865
Epoch 9: Loss 0.004533804953098297
Epoch 10: Loss 0.0033910039346665144
Acc task 1 is 0.7825354337692261

Epoch 1: Loss 0.05751190707087517
Epoch 2: Loss 0.0064352150075137615
Epoch 3: Loss 0.0018309748265892267
Epoch 4: Loss 0.0004091352748218924
Epoch 5: Loss 0.00011812253796961159
Epoch 6: Loss 5.470484757097438e-05
Epoch 

In [92]:
# PER TASK ACCURACY
# Each task is scored on its own test classes only (non-cumulative loaders).
print("Per task Acc")
for t in range(num_task):
    print(f"{t} : {test(model, test_loader_no_cum[t]).item() :.3f}")

# TOTAL ACCURACY
# The cumulative loader of the last task covers the classes of every task. Index it through
# num_task rather than a hard-coded 4 so the cell follows the configuration.
print("Cumulative Acc")
print(f"{test(model, test_loader[num_task - 1]).item() :.3f}")

Per task Acc
0 : 0.283
1 : 0.356
2 : 0.585
3 : 0.418
4 : 0.366
Cumulative Acc
0.397


# Online EWC
From: Schwarz et al., "Progress & Compress: A scalable framework for continual learning" (ICML 2018), https://arxiv.org/pdf/1805.06370.pdf

In [93]:
def online_ewc_train(model, optimizer, data_loader, ewc, importance, task):
    """Run one epoch of cross-entropy training plus the weighted EWC penalty.

    For every batch the cross-entropy gradient is computed first, then the
    gradient of ``importance * ewc.penalty()`` is accumulated on top before
    the optimizer step. Batches are moved to the device of the model's first
    parameter (left untouched if the model has no parameters) instead of
    unconditionally to CUDA.

    Args:
        model: Network called as ``model(input, task)``.
        optimizer: Optimizer over the model's parameters.
        data_loader: Sized iterable of ``(input, target)`` batches.
        ewc: Object exposing ``penalty()``; it may return a plain number when
            there is nothing to penalise.
        importance: Multiplier applied to the penalty.
        task: Index of the output head to train.

    Returns:
        The mean cross-entropy loss of the epoch (penalty excluded) as a
        tensor without autograd history.

    Raises:
        ZeroDivisionError: If ``data_loader`` is empty.
    """
    model.train()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    epoch_loss = 0
    for input, target in data_loader:
        if device is not None:
            input, target = input.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(input, task)

        loss = F.cross_entropy(output, target)
        loss.backward()

        loss_ewc = importance * ewc.penalty()
        # Back-propagate only a differentiable tensor. Checking the type avoids
        # comparing a device tensor with 0, which forces a sync every batch.
        if torch.is_tensor(loss_ewc) and loss_ewc.requires_grad:
            loss_ewc.backward()

        # Detach so the running total does not keep every batch's autograd history.
        epoch_loss += loss.detach()

        optimizer.step()
    return epoch_loss / len(data_loader)


def online_ewc_process(model, ewc, epochs, importance, task):
    """Train one task with the online-EWC penalty, then refresh the Fisher.

    A fresh Adam optimizer (learning rate ``lr``) is created for the task. The
    loss is printed every ``epochs_interval`` epochs. After training,
    ``ewc.update`` is run over the task's training loader and the accuracy on
    the task's test loader is printed.

    Returns:
        ``(model, ewc)``, the same objects that were passed in.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        loss = online_ewc_train(model, optimizer, train_loader[task], ewc, importance, task)
        if epoch % epochs_interval:
            continue
        print(f"Epoch {epoch + 1}: Loss {loss}")

    # Estimate the Fisher of the task that was just learned.
    ewc.update(train_loader[task], task)
    # Debug aid: inspect per-tensor statistics of the new Fisher.
    # print({k:v.mean().item() for k,v in ewc.get_fisher().items()})
    # print({k:v.max().item() for k,v in ewc.get_fisher().items()})

    print(f"Acc task {task} is {test(model, test_loader[task])}")
    return model, ewc

In [97]:
# Online EWC experiment: same seed and architecture as the vanilla run.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
model = MLP().cuda()
fisher = None  # accumulated Fisher carried across tasks; None until the first task has finished
importance = 75000  # multiplier applied to ewc.penalty() inside online_ewc_train
EPS = 1e-20  # guards the min-max normalisation below against a zero range
phi = 0.95  # decay applied to the accumulated Fisher before the new task's Fisher is added

# NOTE(review): SplitMNIST.__getitem__ as originally written reads a module-level `task`;
# keep this loop variable's name unless that lookup has been changed to `self.task`.
for task in range(num_task):
    # Frozen snapshot of the weights before this task; the penalty is measured against it.
    model_old = deepcopy(model)
    for p in model_old.parameters():
        p.requires_grad = False

    print("")
    ewc = OnlineEWC(model, model_old, "cuda", fisher=fisher)
    model, ewc = online_ewc_process(model, ewc, epochs, task=task, importance=importance)
    
    if fisher is None:
        # First task: the accumulated Fisher is the min-max normalised Fisher of this task.
        fisher = deepcopy(ewc.get_fisher())
       
        fisher = {n: (fisher[n] - fisher[n].min()) / (fisher[n].max() - fisher[n].min() + EPS) for n in fisher}
        
        print("\n New fisher (normalized):")
        print({n:(p.min().item(), p.median().item(), p.max().item()) for n,p in fisher.items()})
    else:
        # Later tasks: normalise the new Fisher per tensor (overwriting ewc's own entries),
        # then add it to the decayed running total. The sum itself is not re-normalised,
        # so accumulated values can exceed 1.
        new_fisher = ewc.get_fisher()
        for n in fisher:
            new_fisher[n] = (new_fisher[n] - new_fisher[n].min()) / (new_fisher[n].max() - new_fisher[n].min() + EPS)
            fisher[n] = phi*fisher[n] + new_fisher[n]
        print("\n New fisher (normalized):")
        print({n:(p.min().item(), p.median().item(), p.max().item()) for n,p in fisher.items()})



Epoch 1: Loss 0.026549091562628746
Epoch 2: Loss 0.0031973535660654306
Epoch 3: Loss 0.0021754407789558172
Epoch 4: Loss 0.002085939748212695
Epoch 5: Loss 0.0027679249178618193
Epoch 6: Loss 0.0017462905962020159
Epoch 7: Loss 0.0006506486097350717
Epoch 8: Loss 0.0009288053261116147
Epoch 9: Loss 0.00026676274137571454
Epoch 10: Loss 0.0009421079885214567
Acc task 0 is 0.9985815286636353

 New fisher (normalized):
{'fc1.weight': (0.0, 2.158017391697066e-15, 1.0), 'fc1.bias': (0.0, 0.013740774244070053, 1.0), 'fc2.weight': (0.0, 3.183357193847769e-06, 1.0), 'fc2.bias': (0.0, 0.015340087935328484, 1.0), 'classifier.0.weight': (0.0, 0.0058732349425554276, 1.0), 'classifier.0.bias': (0.0, 0.0, 1.0), 'classifier.1.weight': (0.0, 0.0, 0.0), 'classifier.1.bias': (0.0, 0.0, 0.0), 'classifier.2.weight': (0.0, 0.0, 0.0), 'classifier.2.bias': (0.0, 0.0, 0.0), 'classifier.3.weight': (0.0, 0.0, 0.0), 'classifier.3.bias': (0.0, 0.0, 0.0), 'classifier.4.weight': (0.0, 0.0, 0.0), 'classifier.4.bias

In [98]:
# PER TASK ACCURACY
# Each task is scored on its own test classes only (non-cumulative loaders).
print("Per task Acc")
for t in range(num_task):
    print(f"{t} : {test(model, test_loader_no_cum[t]).item() :.3f}")

# TOTAL ACCURACY
# The cumulative loader of the last task covers the classes of every task. Index it through
# num_task rather than a hard-coded 4 so the cell follows the configuration.
print("Cumulative Acc")
print(f"{test(model, test_loader[num_task - 1]).item() :.3f}")

Per task Acc
0 : 0.908
1 : 0.663
2 : 0.737
3 : 0.266
4 : 0.102
Cumulative Acc
0.539


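# EWC++
From: Chaudhry et al., "Riemannian Walk for Incremental Learning: Understanding Forgetting and Intransigence" (ECCV 2018), http://openaccess.thecvf.com/content_ECCV_2018/papers/Arslan_Chaudhry__Riemannian_Walk_ECCV_2018_paper.pdf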

In [99]:
def ewcpp_train(model, optimizer, data_loader, ewc, importance, task):
    """Run one epoch of EWC++ training.

    Per batch: back-propagate the cross-entropy loss, let ``ewc.update()``
    read those gradients (before the penalty gradient is added), accumulate
    the gradient of ``importance * ewc.penalty()``, then step the optimizer.
    Batches are moved to the device of the model's first parameter (left
    untouched if the model has no parameters) instead of unconditionally to
    CUDA.

    Args:
        model: Network called as ``model(input, task)``.
        optimizer: Optimizer over the model's parameters.
        data_loader: Sized iterable of ``(input, target)`` batches.
        ewc: Object exposing ``update()`` and ``penalty()``; ``penalty`` may
            return a plain number when there is nothing to penalise.
        importance: Multiplier applied to the penalty.
        task: Index of the output head to train.

    Returns:
        The mean cross-entropy loss of the epoch (penalty excluded) as a
        tensor without autograd history.

    Raises:
        ZeroDivisionError: If ``data_loader`` is empty.
    """
    model.train()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    epoch_loss = 0
    for input, target in data_loader:
        if device is not None:
            input, target = input.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(input, task)
        loss = F.cross_entropy(output, target)
        # Detach so the running total does not keep every batch's autograd history.
        epoch_loss += loss.detach()
        loss.backward()
        # Must run while .grad holds the cross-entropy gradient only.
        ewc.update()
        loss_ewc = importance * ewc.penalty()
        # Back-propagate only a differentiable tensor. Checking the type avoids
        # comparing a device tensor with 0, which forces a sync every batch.
        if torch.is_tensor(loss_ewc) and loss_ewc.requires_grad:
            loss_ewc.backward()
        optimizer.step()
    return epoch_loss / len(data_loader)

def ewcpp_process(model, ewc, epochs, importance, task):
    """Train one task with the EWC++ penalty and report its test accuracy.

    A fresh Adam optimizer (learning rate ``lr``) is created for the task. The
    loss is printed every ``epochs_interval`` epochs, then the accuracy on the
    task's test loader is printed. The Fisher estimate is updated inside
    ``ewcpp_train``, so no separate Fisher pass is run here.

    Returns:
        ``(model, ewc)``, the same objects that were passed in.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        loss = ewcpp_train(model, optimizer, train_loader[task], ewc, importance, task)
        if epoch % epochs_interval:
            continue
        print(f"Epoch {epoch + 1}: Loss {loss}")

    print(f"Acc task {task} is {test(model, test_loader[task])}")
    return model, ewc

In [100]:
# EWC++ experiment: same seed and architecture as the previous runs.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

model = MLP().cuda()
fisher = None  # raw running Fisher carried across tasks; None for the first task
importance = 75000  # multiplier applied to ewc.penalty() inside ewcpp_train
EPS = 1e-20  # same value normalize_fn reads for its denominator guard
alpha = 0.9  # weight of the newest squared gradient in EWCpp.update
normalize = True  # EWCpp min-max normalises its penalty weights when this is set

# NOTE(review): SplitMNIST.__getitem__ as originally written reads a module-level `task`;
# keep this loop variable's name unless that lookup has been changed to `self.task`.
for task in range(num_task):
    # Frozen snapshot of the weights before this task; the penalty is measured against it.
    model_old = deepcopy(model)
    for p in model_old.parameters():
        p.requires_grad = False

    print("")
    ewc = EWCpp(model, model_old, "cuda", fisher=fisher, alpha=alpha, normalize=normalize)
    model, ewc = ewcpp_process(model, ewc, epochs, task=task, importance=importance)

    # Carry the un-normalised running estimate forward; EWCpp normalises its own copy.
    fisher = deepcopy(ewc.get_fisher())

    print("\n New fisher (not normalized):")
    print({n:(p.min().item(), p.median().item(), p.max().item()) for n,p in fisher.items()})



Epoch 1: Loss 0.026549091562628746
Epoch 2: Loss 0.0031973535660654306
Epoch 3: Loss 0.0021754407789558172
Epoch 4: Loss 0.002085939748212695
Epoch 5: Loss 0.0027679249178618193
Epoch 6: Loss 0.0017462905962020159
Epoch 7: Loss 0.0006506486097350717
Epoch 8: Loss 0.0009288053261116147
Epoch 9: Loss 0.00026676274137571454
Epoch 10: Loss 0.0009421079885214567
Acc task 0 is 0.9985815286636353

 New fisher (not normalized):
{'fc1.weight': (0.0, 0.0, 1.955451356394633e-09), 'fc1.bias': (0.0, 1.7514060016312174e-15, 1.9745394208570133e-09), 'fc2.weight': (0.0, 1.1023583941414215e-22, 5.509139255899242e-10), 'fc2.bias': (0.0, 1.504856352377293e-14, 1.5372128570056987e-10), 'classifier.0.weight': (0.0, 2.1230128657671876e-13, 1.3325110792550277e-08), 'classifier.0.bias': (3.8616163600124764e-09, 3.8616163600124764e-09, 3.863255493286033e-09), 'classifier.1.weight': (0.0, 0.0, 0.0), 'classifier.1.bias': (0.0, 0.0, 0.0), 'classifier.2.weight': (0.0, 0.0, 0.0), 'classifier.2.bias': (0.0, 0.0, 0.

In [102]:
# PER TASK ACCURACY
# Each task is scored on its own test classes only (non-cumulative loaders).
print("Per task Acc")
for t in range(num_task):
    print(f"{t} : {test(model, test_loader_no_cum[t]).item() :.3f}")

# TOTAL ACCURACY
# The cumulative loader of the last task covers the classes of every task. Index it through
# num_task rather than a hard-coded 4 so the cell follows the configuration.
print("Cumulative Acc")
print(f"{test(model, test_loader[num_task - 1]).item() :.3f}")

Per task Acc
0 : 0.468
1 : 0.667
2 : 0.747
3 : 0.364
4 : 0.380
Cumulative Acc
0.523
