In [4]:
import copy
import time
import torch
import wandb
import atexit
import random
import easydict
import argparse
import numpy as np
import torch.nn as nn
import torch.optim as optim
from datetime import datetime
import torch.nn.functional as F
import utils.data as data_loader
from models.lenet import LeNet

from tqdm import tqdm_notebook as tqdm
from methods.ewc import Fisher_KFAC_reg, Fisher_KFAC_reg_id, Fisher_EKFAC_reg, Fisher_KFAC_reg_uncertainty
from utils.utils import Edl_mse_loss, Edl_digamma_loss, relu_evidence, Edl_log_loss

import os
import numpy as np
from utils.clb.utils.metrics import accuracy, AverageMeter
from utils.dataloaders import splitmnist as dataloader

In [5]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LOADER_KWARGS = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
print(torch.cuda.is_available())

False


In [6]:
# Prepare data for chosen experiment
(train_datasets, valid_datasets, test_datasets), config, classes_per_task = dataloader.get_multitask_experiment(
    name='splitMNIST', scenario='domain', tasks=2, data_dir='./data',
    verbose=True, exception=False)

  --> mnist28: 'train'-dataset consisting of 60000 samples
  --> mnist28: 'test'-dataset consisting of 10000 samples




In [7]:
class MLP(nn.Module):             
    def __init__(self,class_numb, h_hidden):
        super().__init__()
        self.class_numb = class_numb
        self.h_hidden = h_hidden
        self.layer_in = nn.Linear(28*28, h_hidden)
        self.out = nn.ModuleList([nn.Linear(self.h_hidden, self.class_numb),nn.Linear(self.h_hidden, self.class_numb)])
        
    def forward(self, x, task):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.layer_in(x))
        output = self.out[task](x)
        return output

In [8]:
def train(net, optimizer, epoch, train_loader, task):
    net.train()                                                
    for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
        data, target = data.to(DEVICE), target.to(DEVICE)
        net.zero_grad()
        out = net(data, task)
        loss = F.cross_entropy(out, target)
        loss.backward()
        optimizer.step()

def test(net, n_tasks, datasets, current_task=None, split="test"):
    net.eval()
    correct = 0
    test_loss = 0
    acc_current=None
    acc_tasks = []
    if current_task is None:
        for task in range(n_tasks):  
            with torch.no_grad():   
                loader = torch.utils.data.DataLoader(datasets[task], batch_size=64)
                for data, target in loader:
                    data, target = data.to(DEVICE), target.to(DEVICE)
                    output = net(data, task)
                    pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
                    correct += pred.eq(target.view_as(pred)).sum().item()
                print('Task {:.0f}, {} Accuracy: {}/{} ({:.2f}%)\n'.format(task, split, correct, len(loader.dataset),
                    100. * correct / len(loader.dataset)))
                acc_tasks.append(100. * correct / len(loader.dataset))
            correct = 0
            
        wandb.log({
            f"Avv. {split} Accuracy ": np.mean(acc_tasks)})
    else:
        task = current_task
        with torch.no_grad():   
                loader = torch.utils.data.DataLoader(datasets[current_task], batch_size=64)
                for data, target in loader:
                    data, target = data.to(DEVICE), target.to(DEVICE)
                    output = net(data, task)
                    pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
                    correct += pred.eq(target.view_as(pred)).sum().item()
                print('Task {:.0f}, {} Accuracy: {}/{} ({:.2f}%)\n'.format(task, split, correct, len(loader.dataset),
                    100. * correct / len(loader.dataset)))
    
    acc_current = correct / len(loader.dataset)
    return acc_current

In [9]:
def _get_optimizer(model,lr=None):
    if lr is None: lr=LR
    return torch.optim.Adam(model.parameters(),lr=lr)
def get_model(model):
    return copy.deepcopy(model.state_dict())
def set_model_(model,state_dict):
    model.load_state_dict(copy.deepcopy(state_dict))
    return

def wandb_init(net, group, name, config_dict):    
    run = wandb.init(
        project='Growing Net',
        name=name,
        group=group,
        config=config_dict
        #reinit=True,
        #resume="allow",
    )
    wandb.watch(net) 
    print(f"Using wandb. Group name: {group} run name: {name}")
    return run

In [10]:
# Training settings
config = easydict.EasyDict({
    'n_tasks': 2,
    'lr': 0.001,
    'epochs': 5,
    #'lr_factor': 0.1,
    #lr_min=0.00001
    #LR_PATIENCE = 3
})

In [13]:
def main():
    for h_hidden in [10,100,1000,10000, 100000]:
        best_acc = 0
        net = MLP(5,h_hidden)
        optimizer = _get_optimizer(net, config.lr)
        wandb_init(net, f'hs_{h_hidden}', 'run', config)
        try:
            for task in range(config.n_tasks):
                print("Training on task: " + str(task))
                train_loader = torch.utils.data.DataLoader(train_datasets[task], batch_size=64)
                for epoch in range(config.epochs):
                    print("Epoch: "+str(epoch))
                    train(net, optimizer, epoch, train_loader, task)
                    valid_acc = test(net, 2, valid_datasets, current_task=task, split="valid")

                    if valid_acc>best_acc:
                            best_acc=valid_acc
                            best_model=get_model(net)
                set_model_(net, best_model)
                _ = test(net, task+1, test_datasets, current_task=None, split="test")
        except KeyboardInterrupt:
            pass

        # set prior to posterior


In [14]:
main()

[34m[1mwandb[0m: Wandb version 0.8.36 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Using wandb. Group name: hs_10 run name: run
Training on task: 0
Epoch: 0


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_idx, (data, target) in enumerate(tqdm(train_loader)):


HBox(children=(FloatProgress(value=0.0, max=378.0), HTML(value='')))


Task 0, valid Accuracy: 5613/5982 (93.83%)

Epoch: 1


HBox(children=(FloatProgress(value=0.0, max=378.0), HTML(value='')))


Task 0, valid Accuracy: 5665/5982 (94.70%)

Epoch: 2


HBox(children=(FloatProgress(value=0.0, max=378.0), HTML(value='')))


Task 0, valid Accuracy: 5690/5982 (95.12%)

Epoch: 3


HBox(children=(FloatProgress(value=0.0, max=378.0), HTML(value='')))


Task 0, valid Accuracy: 5702/5982 (95.32%)

Epoch: 4


HBox(children=(FloatProgress(value=0.0, max=378.0), HTML(value='')))


Task 0, valid Accuracy: 5710/5982 (95.45%)

Task 0, test Accuracy: 4748/4979 (95.36%)

Training on task: 1
Epoch: 0


HBox(children=(FloatProgress(value=0.0, max=373.0), HTML(value='')))


Task 1, valid Accuracy: 5747/6018 (95.50%)

Epoch: 1


HBox(children=(FloatProgress(value=0.0, max=373.0), HTML(value='')))


Task 1, valid Accuracy: 5810/6018 (96.54%)

Epoch: 2


HBox(children=(FloatProgress(value=0.0, max=373.0), HTML(value='')))


Task 1, valid Accuracy: 5832/6018 (96.91%)

Epoch: 3


HBox(children=(FloatProgress(value=0.0, max=373.0), HTML(value='')))


Task 1, valid Accuracy: 5842/6018 (97.08%)

Epoch: 4


HBox(children=(FloatProgress(value=0.0, max=373.0), HTML(value='')))


Task 1, valid Accuracy: 5853/6018 (97.26%)

Task 0, test Accuracy: 2137/4979 (42.92%)

Task 1, test Accuracy: 4895/5021 (97.49%)



[34m[1mwandb[0m: Wandb version 0.8.36 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Using wandb. Group name: hs_100 run name: run
Training on task: 0
Epoch: 0


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_idx, (data, target) in enumerate(tqdm(train_loader)):


HBox(children=(FloatProgress(value=0.0, max=378.0), HTML(value='')))




[34m[1mwandb[0m: Wandb version 0.8.36 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Using wandb. Group name: hs_1000 run name: run
Training on task: 0
Epoch: 0


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for batch_idx, (data, target) in enumerate(tqdm(train_loader)):


HBox(children=(FloatProgress(value=0.0, max=378.0), HTML(value='')))

KeyboardInterrupt: 