# Introduction

Continual learning for semantic segmentation is a recent and quickly evolving field. As a newcomer, I wanted to get in touch with the popular approaches quickly and see some concrete results for myself. However, rerunning the experiments from recent papers, which often use medium- to large-scale datasets such as Pascal-VOC, ADE20K or COCO, requires several hours of training.

Therefore, I took inspiration from the [Simple Deep Learning project](https://awaywithideas.com/mnist-extended-a-dataset-for-semantic-segmentation-and-object-detection/) by Luke Tonin, in which he built MNIST-Extended, a semantic segmentation dataset made from MNIST. While this is obviously a toy dataset, it has the benefit of giving quick feedback when tinkering with models.

This also makes it possible to explore common challenges of continual semantic segmentation, such as catastrophic forgetting, background shift, and the various combinations of setups in which past and future classes are or are not present in images and labeled as background.

# TODO
1. [X] Clean Evaluater
2. [X] Test Evaluater
3. [X] Design metrics - How to measure catastrophic forgetting? Visualisation
4. [X] Clean confusion matrix printing
5. [X] Experiments illustrating catastrophic forgetting
6. [X] Implement Memory
7. Experiments with vs without memory
8. Explain background shift
9. Experiments illustrating background shift

In [76]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [77]:
# some initial imports
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import sys
from tqdm import tqdm
import pandas as pd

os.chdir("/content/gdrive/MyDrive/Colab Notebooks/mnist_continual_seg")
sys.path.append("/content/gdrive/MyDrive/Colab Notebooks/simple_deep_learning")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Now that the segmentation dataset is loaded, we need to adapt it to reflect the continual setup where classes are seen sequentially. To do so, we will implement a processing function that only keeps the ground-truth masks for the current learning step. For instance, if we divide the 10 classes from MNIST into 5 learning steps of 2 classes, then at step #0 the model is only trained to segment digits 0 and 1, while other digits in the image are labeled as background (i.e. label '0'). Then, at step #1, where the model must learn to segment 2's and 3's, digits 0-1 and 4-9 will be labeled as background, and so on.

This scenario reflects the common experimental setup in which segmentation masks are only available for current classes, while objects that belong to past and future classes can still appear in scene images (e.g. from Pascal-VOC) but are labeled as background. This is a challenge specific to continual semantic segmentation known as background shift, which is addressed in several recent works (MiB, PLOP, SSUL, RECALL, etc.).
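
The label filtering described above can be sketched in a few lines of NumPy. This is only an illustration, not the code used by `ContinualMnist`: the function name `keep_current_task_labels` is hypothetical, and the sketch assumes that masks store background as label 0 and digit `d` as label `d + 1`.

```python
import numpy as np

def keep_current_task_labels(mask, current_digits, background=0):
    """Relabel every digit outside `current_digits` as background.

    Assumption: background is label 0 and digit d is stored as label d + 1.
    """
    mask = np.asarray(mask)
    current_labels = np.asarray(current_digits) + 1  # digit d -> label d + 1
    keep = np.isin(mask, current_labels)             # True where the pixel belongs to a current class
    return np.where(keep, mask, background)

# Toy 3x4 mask containing digits 0, 1, 2 and 7 (labels 1, 2, 3 and 8).
mask = np.array([[0, 1, 1, 0],
                 [2, 2, 3, 3],
                 [0, 8, 8, 0]])

step0 = keep_current_task_labels(mask, current_digits=[0, 1])  # only digits 0 and 1 stay labeled
step1 = keep_current_task_labels(mask, current_digits=[2, 3])  # only digits 2 and 3 stay labeled
print(step0)
print(step1)
```

In the toy mask, the same pixels of digit 2 are foreground at step #1 but background at step #0, which is exactly the background shift discussed above.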

In [78]:
from trainer import Trainer, GeneticTrainer
from metrics import EvaluaterCallback
from scenarios import ContinualMnist

In [79]:
def increment_task(scenario, trainer, evaluater, memory):
  scenario.next_task()
  task_id = scenario.train_data.curr_task_id
  trainer.next_task(scenario.n_classes_per_task[:task_id+1])
  optimizer = torch.optim.Adam(lr = 0.0005, params=trainer.model.parameters())
  trainer.set_optim(optimizer)
  evaluater.set_model(trainer.model)
  trainer.set_callbacks([evaluater])
  if memory is not None:
    memory.refactor_memory(scenario.n_classes_per_task[task_id])

def meta_train(n_tasks,
               epochs, 
               scenario,
               trainer,
               evaluater,
               memory=None, 
               pass_first_step=False,
               animation_path="animation"):
  def _print_header(scenario):
    print("*******")
    print(f"Task #{t}")
    print("*******")
    print("Classes to learn:")
    print(*[c-1 for c in scenario.train_data.curr_classes])
    print("*******")
  def _print_results():
    print()
    #evaluater.print_metrics()
    res = evaluater.metrics[-1].get_results()
    print("Overall stats")
    df = pd.DataFrame([res["Overall Acc"], res["Mean Acc"], res["Mean IoU"]]).T
    df.columns = ["Overall Acc", "Mean Acc", "Mean IoU"]
    print(df.to_markdown())

    print("Class IoU")
    df = pd.DataFrame(res["Class IoU"].values()).T
    df.columns = np.arange(-1,len(res["Class IoU"].values())-1)
    print(df.to_markdown())
    #print(tabulate(df))

    print("Class Acc")
    df = pd.DataFrame(res["Class Acc"].values()).T
    df.columns = np.arange(-1,len(res["Class Acc"].values())-1)
    print(df.to_markdown())

  def _print_new_task():
    print()
    print("####################################")
    print("Next Task")
    print("####################################")

  sample_memory=False
  
  for t in range(n_tasks):
    _print_header(scenario)
    if t > 0 :
      sample_memory = True
    
    if t > 0 or not pass_first_step:
      for i in tqdm(range(epochs)):
        trainer.train(i, scenario, memory, sample_memory=sample_memory)

    _print_results()
    
    # Save the trainer's current model rather than relying on a global variable
    torch.save(trainer.model.state_dict(), f"checkpoints/task-{trainer.curr_task}.pth")

    if t < n_tasks:
      _print_new_task()    
      increment_task(scenario = scenario, trainer = trainer, evaluater = evaluater, memory = memory)
  evaluater.create_animation(animation_path, sample_freq=4)

  

# Experiments for Fine-tuning model (naive approach)

In [82]:
from models import simple_seg_model
# Initialize dataloader, optimizer and trainer
_tasks = {0: [0,1], 1: [2,3], 2: [4,5], 3: [6,7], 4: [8,9]}
_offline = {0: list(np.arange(10))}

model = simple_seg_model(n_classes_per_task=[len(_tasks[0])+1])
model = model.cuda()
optimizer = torch.optim.Adam(lr = 0.0005, params=model.parameters())

continual_mnist = ContinualMnist(n_train=1000, n_test=500, batch_size=72, tasks=_tasks)

evaluater = EvaluaterCallback(model, ["acc", "iou", "confusion_matrix"], callback_frequency="step", n_classes=11, save_matrices=True)

trainer = Trainer(model,
                  n_classes=[3],
                  optim=optimizer,
                  callbacks=[evaluater])

In [83]:
meta_train(n_tasks = len(_tasks),
           epochs = 200,
           scenario = continual_mnist, 
           trainer=trainer,
           evaluater=evaluater,
           memory=None,
           animation_path='naive')

*******
Task #0
*******
Classes to learn:
-1 0 1
*******


100%|██████████| 200/200 [01:36<00:00,  2.07it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.988447 |    0.91122 |   0.848985 |
Class IoU
|    |       -1 |        0 |        1 | 2   | 3   | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|---------:|:----|:----|:----|:----|:----|:----|:----|:----|
|  0 | 0.988105 | 0.785866 | 0.772983 | X   | X   | X   | X   | X   | X   | X   | X   |
Class Acc
|    |       -1 |        0 |        1 | 2   | 3   | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|---------:|:----|:----|:----|:----|:----|:----|:----|:----|
|  0 | 0.994384 | 0.873579 | 0.865696 | X   | X   | X   | X   | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #1
*******
Classes to learn:
-1 2 3
*******


100%|██████████| 200/200 [02:19<00:00,  1.43it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.956647 |   0.522564 |   0.456441 |
Class IoU
|    |       -1 |   0 |   1 |        2 |       3 | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|----:|----:|---------:|--------:|:----|:----|:----|:----|:----|:----|
|  0 | 0.956668 |   0 |   0 | 0.627787 | 0.69775 | X   | X   | X   | X   | X   | X   |
Class Acc
|    |       -1 |   0 |   1 |        2 |        3 | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|----:|----:|---------:|---------:|:----|:----|:----|:----|:----|:----|
|  0 | 0.992808 |   0 |   0 | 0.776251 | 0.843761 | X   | X   | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #2
*******
Classes to learn:
-1 4 5
*******


100%|██████████| 200/200 [02:39<00:00,  1.26it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.935113 |   0.377915 |   0.320879 |
Class IoU
|    |       -1 |   0 |   1 |   2 |   3 |        4 |        5 | 6   | 7   | 8   | 9   |
|---:|---------:|----:|----:|----:|----:|---------:|---------:|:----|:----|:----|:----|
|  0 | 0.936402 |   0 |   0 |   0 |   0 | 0.663342 | 0.646408 | X   | X   | X   | X   |
Class Acc
|    |       -1 |   0 |   1 |   2 |   3 |        4 |        5 | 6   | 7   | 8   | 9   |
|---:|---------:|----:|----:|----:|----:|---------:|---------:|:----|:----|:----|:----|
|  0 | 0.996012 |   0 |   0 |   0 |   0 | 0.860293 | 0.789098 | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #3
*******
Classes to learn:
-1 6 7
*******


100%|██████████| 200/200 [02:56<00:00,  1.13it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.916286 |     0.2608 |   0.238806 |
Class IoU
|    |       -1 |   0 |   1 |   2 |   3 |   4 |   5 |        6 |        7 | 8   | 9   |
|---:|---------:|----:|----:|----:|----:|----:|----:|---------:|---------:|:----|:----|
|  0 | 0.915858 |   0 |   0 |   0 |   0 |   0 |   0 | 0.676448 | 0.556944 | X   | X   |
Class Acc
|    |      -1 |   0 |   1 |   2 |   3 |   4 |   5 |        6 |        7 | 8   | 9   |
|---:|--------:|----:|----:|----:|----:|----:|----:|---------:|---------:|:----|:----|
|  0 | 0.99859 |   0 |   0 |   0 |   0 |   0 |   0 | 0.736691 | 0.611921 | X   | X   |

####################################
Next Task
####################################
*******
Task #4
*******
Classes to learn:
-1 8 9
*******


100%|██████████| 200/200 [03:00<00:00,  1.11it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |       0.89443 |   0.241402 |   0.178254 |
Class IoU
|    |       -1 |   0 |   1 |   2 |   3 |   4 |   5 |   6 |   7 |        8 |        9 |
|---:|---------:|----:|----:|----:|----:|----:|----:|----:|----:|---------:|---------:|
|  0 | 0.902804 |   0 |   0 |   0 |   0 |   0 |   0 |   0 |   0 | 0.499783 | 0.558208 |
Class Acc
|    |       -1 |   0 |   1 |   2 |   3 |   4 |   5 |   6 |   7 |        8 |       9 |
|---:|---------:|----:|----:|----:|----:|----:|----:|----:|----:|---------:|--------:|
|  0 | 0.996378 |   0 |   0 |   0 |   0 |   0 |   0 |   0 |   0 | 0.916309 | 0.74274 |

####################################
Next Task
####################################
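
One possible way to put a number on the forgetting visible in the tables above is to compare, for each class, the best IoU it reached at an earlier step with its IoU after the last step. The sketch below does this for the naive run. The IoU values are copied from the tables above; the function name `forgetting_per_class` and the choice of this particular measure are assumptions made for illustration, not part of the `EvaluaterCallback` code.

```python
import numpy as np

nan = np.nan
# Class IoU of the naive run, copied from the tables above.
# Rows: after tasks 0-4. Columns: digits 0-9. NaN replaces the "X" of unseen classes.
iou_history = np.array([
    [0.785866, 0.772983, nan, nan, nan, nan, nan, nan, nan, nan],
    [0.0, 0.0, 0.627787, 0.69775, nan, nan, nan, nan, nan, nan],
    [0.0, 0.0, 0.0, 0.0, 0.663342, 0.646408, nan, nan, nan, nan],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.676448, 0.556944, nan, nan],
    [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.499783, 0.558208],
])

def forgetting_per_class(history):
    """Best IoU reached before the last step minus the IoU at the last step.

    Only classes that were already learned before the last step are returned.
    """
    past, final = history[:-1], history[-1]
    seen_before = ~np.isnan(past).all(axis=0)              # classes learned before the last task
    best_before = np.nanmax(past[:, seen_before], axis=0)  # best earlier IoU per class
    return best_before - final[seen_before]

per_class = forgetting_per_class(iou_history)
print("Forgetting for digits 0-7:", np.round(per_class, 3))
print("Mean forgetting:", round(float(per_class.mean()), 3))
```

Because the naive model drops to an IoU of 0 on every past class, the forgetting of each digit here is simply the IoU it had right after being learned.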


# Experiments with Memory model

In [84]:
from models.memory import memory

In [85]:
from models import simple_seg_model
# Initialize dataloader, optimizer and trainer
_tasks = {0: [0,1], 1: [2,3], 2: [4,5], 3: [6,7], 4: [8,9]}
_offline = {0: list(np.arange(10))}

model = simple_seg_model(n_classes_per_task=[len(_tasks[0])+1])
model = model.cuda()
optimizer = torch.optim.Adam(lr = 0.0005, params=model.parameters())

#continual_mnist = ContinualMnist(n_train=1000, n_test=500, batch_size=72, tasks=_tasks)

evaluater = EvaluaterCallback(model, ["acc", "iou", "confusion_matrix"], callback_frequency="step", n_classes=11, save_matrices=True)

mem = memory(images_shape=(1,60,60),
             masks_shape=(60,60),
             n_classes = 2,
             batch_size=72,
             memory_size=1000)

trainer = Trainer(model,
                  n_classes=[3],
                  optim=optimizer,
                  callbacks=[evaluater])

In [None]:
meta_train(n_tasks = len(_tasks),
           epochs = 200,
           scenario = continual_mnist, 
           trainer=trainer,
           evaluater=evaluater,
           memory=mem,
           animation_path="memory_100e_72b_72bm")

*******
Task #0
*******
Classes to learn:
-1 0 1
*******


100%|██████████| 200/200 [03:13<00:00,  1.04it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.894772 |   0.245599 |   0.212156 |
Class IoU
|    |      -1 |        0 |        1 |   2 |   3 |   4 |   5 |   6 |   7 |   8 |   9 |
|---:|--------:|---------:|---------:|----:|----:|----:|----:|----:|----:|----:|----:|
|  0 | 0.89533 | 0.731385 | 0.706998 |   0 |   0 |   0 |   0 |   0 |   0 |   0 |   0 |
Class Acc
|    |       -1 |        0 |        1 |   2 |   3 |   4 |   5 |   6 |   7 |   8 |   9 |
|---:|---------:|---------:|---------:|----:|----:|----:|----:|----:|----:|----:|----:|
|  0 | 0.998337 | 0.888966 | 0.814283 |   0 |   0 |   0 |   0 |   0 |   0 |   0 |   0 |

####################################
Next Task
####################################
*******
Task #1
*******
Classes to learn:
-1 2 3
*******


100%|██████████| 200/200 [03:48<00:00,  1.14s/it]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.904283 |   0.294001 |   0.275561 |
Class IoU
|    |       -1 |        0 |        1 |        2 |        3 |   4 |   5 |   6 |   7 |   8 |   9 |
|---:|---------:|---------:|---------:|---------:|---------:|----:|----:|----:|----:|----:|----:|
|  0 | 0.903418 | 0.663515 | 0.547239 | 0.432992 | 0.484005 |   0 |   0 |   0 |   0 |   0 |   0 |
Class Acc
|    |       -1 |        0 |        1 |       2 |        3 |   4 |   5 |   6 |   7 |   8 |   9 |
|---:|---------:|---------:|---------:|--------:|---------:|----:|----:|----:|----:|----:|----:|
|  0 | 0.999185 | 0.686422 | 0.564824 | 0.46302 | 0.520563 |   0 |   0 |   0 |   0 |   0 |   0 |

####################################
Next Task
####################################
*******
Task #2
*******
Classes to learn:
-1 4 5
*******


100%|██████████| 200/200 [03:51<00:00,  1.16s/it]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.903197 |   0.302314 |   0.270565 |
Class IoU
|    |      -1 |        0 |        1 |        2 |        3 |      4 |        5 |   6 |   7 |   8 |   9 |
|---:|--------:|---------:|---------:|---------:|---------:|-------:|---------:|----:|----:|----:|----:|
|  0 | 0.90572 | 0.410376 | 0.140271 | 0.276468 | 0.180908 | 0.5391 | 0.523378 |   0 |   0 |   0 |   0 |
Class Acc
|    |       -1 |        0 |        1 |        2 |        3 |       4 |        5 |   6 |   7 |   8 |   9 |
|---:|---------:|---------:|---------:|---------:|---------:|--------:|---------:|----:|----:|----:|----:|
|  0 | 0.998396 | 0.413438 | 0.140731 | 0.315535 | 0.192333 | 0.59892 | 0.666106 |   0 |   0 |   0 |   0 |

####################################
Next Task
####################################
*******
Task #3
*******
Classes to learn:
-1 6 7
*******


 86%|████████▋ | 173/200 [03:26<00:32,  1.19s/it]

# Distillation
## Output-level distillation loss
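
As a rough illustration of what an output-level distillation loss can look like, the sketch below computes a Learning-without-Forgetting-style term: the cross-entropy between the previous model's softmax predictions and the current model's predictions, restricted to the channels of the old classes. It is written with NumPy so that it runs anywhere, and the same expression carries over to torch tensors. The function names are hypothetical, and this is not necessarily what `Trainer_distillation` implements, since its source is not shown here.

```python
import numpy as np

def log_softmax(logits, axis):
    """Numerically stable log-softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def output_distillation_loss(new_logits, old_logits):
    """Cross-entropy between old-model and new-model per-pixel predictions.

    new_logits: (N, C_new, H, W) from the model being trained.
    old_logits: (N, C_old, H, W) from the frozen previous model, C_old <= C_new.
    Only the first C_old channels of the new model are compared.
    """
    n_old = old_logits.shape[1]
    old_prob = np.exp(log_softmax(old_logits, axis=1))         # soft targets from the old model
    new_log_prob = log_softmax(new_logits[:, :n_old], axis=1)  # new model, old classes only
    return float(-(old_prob * new_log_prob).sum(axis=1).mean())

rng = np.random.default_rng(0)
old_logits = rng.normal(size=(2, 3, 4, 4))                     # 3 old classes
new_logits = np.concatenate([old_logits, rng.normal(size=(2, 2, 4, 4))], axis=1)

# The new model still agrees with the old one on the old classes.
loss_same = output_distillation_loss(new_logits, old_logits)

# The new model has drifted away from the old predictions.
drifted_logits = new_logits.copy()
drifted_logits[:, :3] += rng.normal(size=(2, 3, 4, 4))
loss_drifted = output_distillation_loss(drifted_logits, old_logits)

print(loss_same, loss_drifted)
```

In training, a term like this would typically be added to the usual segmentation loss with a weight, for example `total = ce_loss + lambda_distill * distill_loss`, where `ce_loss` and `distill_loss` are placeholder names for the two terms.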

In [None]:
from trainer import Trainer_distillation
from models import simple_seg_model
_tasks = {0: [0,1], 1: [2,3], 2: [4,5], 3: [6,7], 4: [8,9]}
_offline = {0: list(np.arange(10))}

model = simple_seg_model(n_classes_per_task=[len(_tasks[0])+1])
model = model.cuda()
optimizer = torch.optim.Adam(lr = 0.0005, params=model.parameters())

#continual_mnist = ContinualMnist(n_train=1000, n_test=500, batch_size=72, tasks=_tasks)

evaluater = EvaluaterCallback(model, ["acc", "iou", "confusion_matrix"], callback_frequency="step", n_classes=11, save_matrices=True)

trainer = Trainer_distillation(model,
                  n_classes=[3],
                  optim=optimizer,
                  from_new_class = 0,
                  lambda_distill=10,
                  callbacks=[evaluater])

In [None]:
meta_train(n_tasks = len(_tasks),
           epochs = 200,
           scenario = continual_mnist, 
           trainer=trainer,
           evaluater=evaluater,
           animation_path="distillation")

*******
Task #0
*******
Classes to learn:
-1 0 1
*******


100%|██████████| 200/200 [00:39<00:00,  5.00it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |       0.98498 |   0.913485 |   0.815281 |
Class IoU
|    |      -1 |        0 |        1 | 2   | 3   | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|--------:|---------:|---------:|:----|:----|:----|:----|:----|:----|:----|:----|
|  0 | 0.98449 | 0.758311 | 0.703042 | X   | X   | X   | X   | X   | X   | X   | X   |
Class Acc
|    |       -1 |        0 |        1 | 2   | 3   | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|---------:|:----|:----|:----|:----|:----|:----|:----|:----|
|  0 | 0.990254 | 0.896697 | 0.853504 | X   | X   | X   | X   | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #1
*******
Classes to learn:
-1 2 3
*******


100%|██████████| 200/200 [01:11<00:00,  2.78it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.965461 |   0.872689 |   0.662218 |
Class IoU
|    |       -1 |       0 |       1 |       2 |        3 | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|--------:|--------:|--------:|---------:|:----|:----|:----|:----|:----|:----|
|  0 | 0.968091 | 0.67949 | 0.52815 | 0.54181 | 0.593548 | X   | X   | X   | X   | X   | X   |
Class Acc
|    |       -1 |        0 |        1 |        2 |        3 | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|---------:|---------:|---------:|:----|:----|:----|:----|:----|:----|
|  0 | 0.974173 | 0.946052 | 0.940143 | 0.727951 | 0.775125 | X   | X   | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #2
*******
Classes to learn:
-1 4 5
*******


100%|██████████| 200/200 [01:18<00:00,  2.56it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.926176 |   0.802642 |   0.492761 |
Class IoU
|    |       -1 |        0 |        1 |       2 |        3 |        4 |       5 | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|---------:|--------:|---------:|---------:|--------:|:----|:----|:----|:----|
|  0 | 0.935961 | 0.514151 | 0.350042 | 0.40077 | 0.441187 | 0.407444 | 0.39977 | X   | X   | X   | X   |
Class Acc
|    |       -1 |        0 |        1 |        2 |        3 |       4 |        5 | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|---------:|---------:|---------:|--------:|---------:|:----|:----|:----|:----|
|  0 | 0.939084 | 0.958495 | 0.916475 | 0.723172 | 0.733594 | 0.71685 | 0.630822 | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #3
*******
Classes to learn:
-1 6 7
*******


100%|██████████| 200/200 [01:27<00:00,  2.29it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.887402 |   0.719497 |   0.400004 |
Class IoU
|    |       -1 |        0 |        1 |        2 |        3 |        4 |        5 |        6 |        7 | 8   | 9   |
|---:|---------:|---------:|---------:|---------:|---------:|---------:|---------:|---------:|---------:|:----|:----|
|  0 | 0.907854 | 0.461622 | 0.219697 | 0.328574 | 0.359816 | 0.371722 | 0.344034 | 0.309327 | 0.297388 | X   | X   |
Class Acc
|    |      -1 |        0 |        1 |        2 |        3 |        4 |        5 |        6 |        7 | 8   | 9   |
|---:|--------:|---------:|---------:|---------:|---------:|---------:|---------:|---------:|---------:|:----|:----|
|  0 | 0.90888 | 0.921671 | 0.882031 | 0.549024 | 0.716758 | 0.616203 | 0.572393 | 0.717624 | 0.590892 | X   | X   |

####################################
Next Task
####################################
*******
Task #4

100%|██████████| 200/200 [01:27<00:00,  2.28it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.824249 |   0.588334 |   0.316859 |
Class IoU
|    |       -1 |        0 |         1 |        2 |        3 |        4 |        5 |        6 |        7 |       8 |        9 |
|---:|---------:|---------:|----------:|---------:|---------:|---------:|---------:|---------:|---------:|--------:|---------:|
|  0 | 0.859879 | 0.415765 | 0.0954831 | 0.281709 | 0.300046 | 0.324381 | 0.254974 | 0.193044 | 0.265875 | 0.26465 | 0.229645 |
Class Acc
|    |       -1 |        0 |        1 |        2 |        3 |        4 |        5 |        6 |        7 |        8 |       9 |
|---:|---------:|---------:|---------:|---------:|---------:|---------:|---------:|---------:|---------:|---------:|--------:|
|  0 | 0.860122 | 0.903007 | 0.861227 | 0.489431 | 0.706751 | 0.492804 | 0.422249 | 0.242344 | 0.411125 | 0.639549 | 0.44307 |

####################################
Next Task
#####

## Intermediate-level distillation

In [None]:
trainer.from_new_class

6