# Introduction

Continual learning for semantic segmentation is a field that has emerged recently and is quickly evolving. As a newcomer to the field, I wanted to quickly get in touch with the popular approaches and see some concrete results for myself. However, rerunning the experiments from recent papers, which often use medium- to large-scale datasets such as Pascal-VOC, ADE20K or COCO, would require several hours of training.

Therefore, I took inspiration from the [Simple Deep Learning project](https://awaywithideas.com/mnist-extended-a-dataset-for-semantic-segmentation-and-object-detection/) by Luke Tonin, in which he built MNIST-Extended, a semantic segmentation dataset made from MNIST. While this is obviously a toy dataset, it has the benefit of giving quick feedback when tinkering with models.

This also makes it possible to explore common challenges of continual semantic segmentation, such as catastrophic forgetting, background shift, and the various combinations of setups where past/future classes are or are not present in images and labeled as background.

# TODO
1. [X] Clean Evaluater
2. [X] Test Evaluater
3. [X] Design metrics - How to measure catastrophic forgetting? Visualisation
4. [X] Clean confusion matrix printing
5. [X] Experiments illustrating catastrophic forgetting
6. [X] Implement Memory
7. Experiments with vs without memory 
8. Explain background shift
9. Experiments illustrating background shift

In [1]:
# Mount Google Drive inside the Colab runtime; the paths used below live under /content/gdrive.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [2]:
# some initial imports
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import os
import sys
from tqdm import tqdm
import pandas as pd

# Work from the project folder on Drive so relative paths (e.g. "checkpoints/...") resolve there,
# and make the simple_deep_learning folder visible to the import system.
os.chdir("/content/gdrive/MyDrive/Colab Notebooks/mnist_continual_seg")
sys.path.append("/content/gdrive/MyDrive/Colab Notebooks/simple_deep_learning")

In [3]:
import simple_deep_learning
from simple_deep_learning.mnist_extended import semantic_segmentation

Now that the segmentation dataset is loaded, we need to adapt it to reflect the continual setup where classes are seen sequentially. To do so, we will implement a processing function that only keeps the groundtruth masks for the current learning step. For instance, if we divide the 10 classes from MNIST in 5 learning steps of 2 classes, then at step #0 the model is only trained to segment digits 0 and 1, while other digits on the image are labeled as background (i.e. label '0'). Then, at step #1 where the model must learn to segment 2's and 3's, digits 0-1 and 4-9 will be labeled as background, and so on.

This scenario reflects the common experimental setup in which segmentation masks are only available for current classes while objects that belong to past and future classes can still appear in scene images (e.g. from Pascal-VOC) but are labeled as background. This is a challenge specific to Continual Semantic Segmentation known as Background shift, which is addressed in several recent works (MiB, PLOP, SSUL, RECALL, etc). 

In [25]:
class Trainer:
  """Train a segmentation model one epoch at a time and switch it between tasks.

  The trainer owns the model, a cross-entropy criterion, the optimizer and a
  list of callbacks. Each callback must expose `callbacks(scenario, freq=...)`;
  it is invoked with `freq="step"` after every batch and `freq="epoch"` after
  every epoch.

  Args:
    model: Network to train. Its class is re-instantiated in `next_task` with
      a `n_classes_per_task` keyword argument.
    n_classes: Stored as-is on the instance; not read by this class.
    optim: Optimizer stepping the parameters of `model`.
    curr_task: Index of the current task; used to build checkpoint paths.
    callbacks: Optional list of callback objects. Defaults to a new empty list.
  """

  def __init__(self, model, n_classes, optim, curr_task=0, callbacks=None):
    self.model = model
    self.criterion = nn.CrossEntropyLoss()
    self.n_classes = n_classes
    self.optim = optim
    self.curr_task = curr_task
    # Build a fresh list per instance: a `[]` default in the signature would be
    # shared (and mutated) across every Trainer created without callbacks.
    self.callbacks = [] if callbacks is None else callbacks

  def train(self, cur_epoch, scenario, memory=None, save_memory_flag=True, sample_memory = False):
    """Train for 1 epoch over `scenario.train_stream`.

    Args:
      cur_epoch: Epoch index supplied by the caller (not used in the loop).
      scenario: Object exposing `train_stream`, an iterable of
        (images, labels) batches that also supports `len()`.
      memory: Optional replay memory exposing `sample_batch()` and
        `save_memory(image, label, class_index)`.
      save_memory_flag: When True and `memory` is given, store the samples of
        every batch in the memory.
      sample_memory: When True and `memory` is given, append a batch sampled
        from the memory to every training batch.

    Returns:
      The summed batch losses divided by the number of batches.
    """
    epoch_loss = 0
    for images, labels in scenario.train_stream:
      images, labels = images.cuda().float(), labels.cuda().long()
      images_mem, labels_mem = self._enrich_with_memory(images, labels, memory, sample_memory)
      epoch_loss += self._train_step(images_mem, labels_mem)
      if memory is not None and save_memory_flag:
        # Only the current batch is stored, never the replayed samples.
        self._save_to_memory(images, labels, memory)
      self._apply_callbacks(scenario, freq="step")
    self._apply_callbacks(scenario, freq="epoch")
    return epoch_loss / len(scenario.train_stream)

  def _save_to_memory(self, images, labels, memory):
    """Store each sample of a batch under one randomly chosen non-zero label."""
    for im, lab in zip(images, labels):
      present = [c for c in np.unique(lab.cpu().numpy()) if c != 0]
      if not present:
        # Mask contains only label 0: there is no class to file this sample
        # under, and np.random.choice would raise on an empty list.
        continue
      c = np.random.choice(present)
      # Labels are shifted by one when used as a memory class index
      # (presumably because label 0 is the background -- confirm with memory).
      memory.save_memory(im, lab, c-1)

  def _enrich_with_memory(self, images, labels, memory=None, sample_memory=False):
    """Return the batch, extended with a replay batch when sampling is enabled."""
    if memory is None or not sample_memory:
      return images, labels
    new_images, new_labels = memory.sample_batch()
    new_images, new_labels = new_images.float().cuda(), new_labels.long().cuda()
    all_images = torch.cat([images, new_images])
    all_labels = torch.cat([labels, new_labels])
    return all_images, all_labels

  def _apply_callbacks(self, scenario, freq):
    """Forward the event `freq` ("step" or "epoch") to every callback."""
    for c in self.callbacks:
      c.callbacks(scenario, freq=freq)

  def _train_step(self, images, labels):
    """Perform 1 training iteration and return the detached batch loss."""
    self.optim.zero_grad()
    outputs = self.model(images)
    loss = self._compute_loss(outputs, labels)

    loss.backward()
    self.optim.step()

    # Detach so that accumulating the loss over an epoch does not keep
    # references into the autograd graph of every batch.
    return loss.detach()

  def next_task(self, n_classes_per_task):
    """Switch to next task.

    Builds a new model for `n_classes_per_task`, loads the checkpoint of the
    current task into it, moves it to the GPU and increments `curr_task`.
    """
    self.n_classes_per_task = n_classes_per_task
    new_model = self._load_new_model(self.n_classes_per_task)
    self.model = new_model.cuda()
    self.curr_task += 1

  def _load_new_model(self, n_classes_per_task):
    """
    Helper function to create new model and
     load weights of past training.
    """
    new_model = self._make_model(n_classes_per_task)
    path_weights = self._get_path_weights()
    step_checkpoint = torch.load(path_weights, map_location="cpu")
    # strict=False: keys that do not match between checkpoint and new model
    # are tolerated instead of raising.
    new_model.load_state_dict(step_checkpoint, strict=False)
    return new_model

  def _make_model(self, n_classes_per_task):
    """Helper function to create new model of the same class as the current one."""
    m_constructor = type(self.model)
    new_model = m_constructor(n_classes_per_task=n_classes_per_task)
    return new_model

  def _get_path_weights(self):
    """Helper function to get path to previous model's weights."""
    path = f"checkpoints/task-{self.curr_task}.pth"
    return path

  def _compute_loss(self, outputs, labels):
    """Apply the criterion (cross-entropy) to the model outputs."""
    return self.criterion(outputs, labels)

  def set_optim(self, optim):
    """Replace the optimizer."""
    self.optim = optim

  def set_callbacks(self, callbacks):
    """Replace the list of callbacks."""
    self.callbacks = callbacks

In [26]:
from importlib import reload
import metrics
reload(metrics)
reload(metrics.metrics)
from metrics.metrics import *

class EvaluaterCallback:
  """Callback that evaluates a model on a test stream and feeds a set of metrics.

  Args:
    model: Network to evaluate.
    metrics: List of metric names, each one of "iou", "acc" or
      "confusion_matrix".
    callback_frequency: Event name on which `callbacks` triggers an
      evaluation (the trainer emits "step" and "epoch").
    **kwargs: Forwarded unchanged to the constructor of every metric.
  """
  __implemented_metrics = ["iou", "acc", "confusion_matrix"]

  def __init__(self, model, metrics, callback_frequency, **kwargs):
    self.model = model
    self.kwargs = kwargs
    self.metrics = [self._convert_metrics(m, **kwargs) for m in metrics]
    self.callback_frequency = callback_frequency
    self.init_values()

  def _convert_metrics(self, m, **kwargs):
    """Instantiate the metric object registered under the name `m`."""
    assert m in self.__implemented_metrics, f"Invalid metric, choose from {self.__implemented_metrics}"
    # Map names to classes so only the requested metric is constructed.
    constructors = {"iou": IoU,
                    "acc": Acc,
                    "confusion_matrix": ConfusionMatrix}
    return constructors[m](**kwargs)

  def init_values(self):
    """Reset every metric."""
    for m in self.metrics:
      m.init_values()

  def test(self, test_loader, verbose=False):
    """Evaluate the model on `test_loader` and update all metrics.

    The model is switched to eval mode for the duration of the evaluation and
    put back in its previous mode afterwards, so an evaluation triggered from
    inside the training loop does not leave the model in eval mode.

    Args:
      test_loader: Iterable of (images, labels) batches.
      verbose: When True, print the metrics once evaluation is done.
    """
    was_training = self.model.training
    self.model.eval()
    self.init_values()
    try:
      with torch.no_grad():
        for images, labels in test_loader:
          images, labels = images.cuda().float(), labels.float()
          predictions = self.model(images)
          # Index of the highest score along dim 1 is the predicted class.
          _, class_predictions = torch.max(predictions, dim=1)
          class_predictions = class_predictions.cpu().long()
          self.update(class_predictions, labels)
    finally:
      self.model.train(was_training)

    if verbose:
      self.print_metrics()

  def callbacks(self, scenario, freq, verbose=False):
    """Run an evaluation on `scenario.test_stream` when `freq` matches the configured frequency."""
    if freq == self.callback_frequency:
      self.test(scenario.test_stream, verbose)
      for m in self.metrics:
        m.callbacks()

  def create_animation(self, filepath, sample_freq=1):
    """Ask every metric to create its animation at `filepath`."""
    for m in self.metrics:
      m.create_animation(filepath, sample_freq)

  def update(self, predictions, labels):
    """Feed one batch of predictions and labels to every metric."""
    for m in self.metrics:
      m.update(predictions, labels)

  def print_metrics(self):
    """Call `print()` on every metric."""
    for m in self.metrics:
      m.print()

  def show_metrics(self):
    """Call `show()` on every metric."""
    for m in self.metrics:
      m.show()

  def get_metrics(self):
    """Return the list of `get_values()` results, one per metric."""
    return [m.get_values() for m in self.metrics]

  def set_model(self, model):
    """Replace the evaluated model."""
    self.model = model

In [93]:
from scenarios import ContinualMnist

In [83]:
def increment_task(scenario, trainer, evaluater, memory=None):
  """Move the scenario, trainer, evaluator and optional memory to the next task.

  Args:
    scenario: Exposes `next_task()`, `train_data.curr_task_id` and
      `n_classes_per_task` (a sequence indexed by task id).
    trainer: Exposes `next_task(...)`, `model`, `set_optim(...)` and
      `set_callbacks(...)`.
    evaluater: Exposes `set_model(...)`; it becomes the trainer's only callback.
    memory: Optional replay memory exposing `refactor_memory(...)`. Defaults
      to None so that runs without a memory do not have to pass it.
  """
  scenario.next_task()
  task_id = scenario.train_data.curr_task_id
  # The trainer receives the class counts of every task seen so far.
  trainer.next_task(scenario.n_classes_per_task[:task_id+1])
  # trainer.model was replaced by next_task, so the optimizer must be rebuilt
  # on the new parameters.
  optimizer = torch.optim.Adam(lr = 0.0005, params=trainer.model.parameters())
  trainer.set_optim(optimizer)
  evaluater.set_model(trainer.model)
  trainer.set_callbacks([evaluater])
  if memory is not None:
    memory.refactor_memory(scenario.n_classes_per_task[task_id])

def meta_train(n_tasks,
               epochs, 
               scenario,
               trainer,
               evaluater,
               memory=None, 
               pass_first_step=False,
               animation_path="animation"):
  """Train sequentially on `n_tasks` tasks and report metrics after each one.

  For every task: print a header, train for `epochs` epochs, print the results
  of the last metric of `evaluater`, save the trainer's current model to
  `checkpoints/task-<curr_task>.pth`, then switch to the next task (except
  after the last one). An animation is created from the metrics at the end.

  Args:
    n_tasks: Number of tasks to run.
    epochs: Number of epochs per task.
    scenario: Continual scenario passed to the trainer and to `increment_task`.
    trainer: Trainer whose `train` method runs one epoch.
    evaluater: Evaluator whose last metric exposes `get_results()`.
    memory: Optional replay memory. Samples are always stored when it is
      given; replay sampling is enabled from the second task onwards.
    pass_first_step: When True, training is skipped for the first task.
    animation_path: Path handed to `evaluater.create_animation`.
  """
  def _print_header(task_idx):
    print("*******")
    print(f"Task #{task_idx}")
    print("*******")
    print("Classes to learn:")
    print(*[c-1 for c in scenario.train_data.curr_classes])
    print("*******")

  def _print_per_class(title, per_class):
    print(title)
    df = pd.DataFrame(per_class.values()).T
    # Column labels start at -1: the first entry is shown as class "-1".
    df.columns = np.arange(-1, len(per_class.values())-1)
    print(df.to_markdown())

  def _print_results():
    print()
    res = evaluater.metrics[-1].get_results()
    print("Overall stats")
    df = pd.DataFrame([res["Overall Acc"], res["Mean Acc"], res["Mean IoU"]]).T
    df.columns = ["Overall Acc", "Mean Acc", "Mean IoU"]
    print(df.to_markdown())
    _print_per_class("Class IoU", res["Class IoU"])
    _print_per_class("Class Acc", res["Class Acc"])

  def _print_new_task():
    print()
    print("####################################")
    print("Next Task")
    print("####################################")

  save_memory_flag = True

  for t in range(n_tasks):
    _print_header(t)
    # Nothing has been stored during the first task, so replay starts at t=1.
    sample_memory = t > 0

    if t > 0 or not pass_first_step:
      for i in tqdm(range(epochs)):
        trainer.train(i, scenario, memory, save_memory_flag=save_memory_flag, sample_memory=sample_memory)

    _print_results()

    # Save the model the trainer is actually training: trainer.model is
    # replaced on every task switch, so a module-level `model` would be stale.
    torch.save(trainer.model.state_dict(), f"checkpoints/task-{trainer.curr_task}.pth")

    # There is no task to switch to after the last one.
    if t < n_tasks - 1:
      _print_new_task()
      increment_task(scenario = scenario, trainer = trainer, evaluater = evaluater, memory = memory)
  evaluater.create_animation(animation_path, sample_freq=4)

  

# Experiments with the fine-tuning model (naive approach)

In [51]:
# NOTE(review): this only sets an attribute named `params` on the optimizer object;
# torch optimizers keep their parameters in `param_groups`, so this looks like a no-op -- confirm.
optimizer.params = model.parameters()

In [91]:
# Initialize dataloader, optimizer and trainer
# Five tasks of two digits each.
_tasks = {0: [0,1], 1: [2,3], 2: [4,5], 3: [6,7], 4: [8,9]}
# Single task containing all ten digits; not used in this cell.
_offline = {0: list(np.arange(10))}

# The first task has len(_tasks[0]) + 1 = 3 output classes.
# NOTE(review): simple_seg_model is not imported in the visible cells -- confirm where it comes from.
model = simple_seg_model(n_classes_per_task=[len(_tasks[0])+1])
model = model.cuda()
optimizer = torch.optim.Adam(lr = 0.0005, params=model.parameters())

continual_mnist = ContinualMnist(n_train=1000, n_test=500, batch_size=72, tasks=_tasks)

# callback_frequency="step": the evaluator runs after every training batch.
evaluater = EvaluaterCallback(model, ["acc", "iou", "confusion_matrix"], callback_frequency="step", n_classes=11, save_matrices=True)

trainer = Trainer(model,
                  n_classes=[3],
                  optim=optimizer,
                  callbacks=[evaluater])

In [92]:
# Naive fine-tuning run: one pass over every task in _tasks, 100 epochs each, without replay memory.
meta_train(n_tasks = len(_tasks),
           epochs = 100,
           scenario = continual_mnist, 
           trainer=trainer,
           evaluater=evaluater,
           memory=None,
           animation_path='test_trainer_nexttask')

*******
Task #0
*******
Classes to learn:
-1 0 1
*******


100%|██████████| 100/100 [00:56<00:00,  1.77it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.979959 |   0.852693 |   0.754658 |
Class IoU
|    |       -1 |        0 |       1 | 2   | 3   | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|--------:|:----|:----|:----|:----|:----|:----|:----|:----|
|  0 | 0.980048 | 0.624867 | 0.65906 | X   | X   | X   | X   | X   | X   | X   | X   |
Class Acc
|    |       -1 |        0 |        1 | 2   | 3   | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|---------:|:----|:----|:----|:----|:----|:----|:----|:----|
|  0 | 0.988778 | 0.838315 | 0.730986 | X   | X   | X   | X   | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #1
*******
Classes to learn:
-1 2 3
*******


100%|██████████| 100/100 [01:10<00:00,  1.42it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.957154 |   0.500813 |   0.443761 |
Class IoU
|    |       -1 |   0 |   1 |        2 |        3 | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|----:|----:|---------:|---------:|:----|:----|:----|:----|:----|:----|
|  0 | 0.957342 |   0 |   0 | 0.582069 | 0.679394 | X   | X   | X   | X   | X   | X   |
Class Acc
|    |       -1 |   0 |   1 |        2 |        3 | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|----:|----:|---------:|---------:|:----|:----|:----|:----|:----|:----|
|  0 | 0.994544 |   0 |   0 | 0.703431 | 0.806088 | X   | X   | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #2
*******
Classes to learn:
-1 4 5
*******


100%|██████████| 100/100 [01:23<00:00,  1.20it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.936264 |   0.354422 |   0.298564 |
Class IoU
|    |       -1 |   0 |   1 |   2 |   3 |        4 |        5 | 6   | 7   | 8   | 9   |
|---:|---------:|----:|----:|----:|----:|---------:|---------:|:----|:----|:----|:----|
|  0 | 0.937788 |   0 |   0 |   0 |   0 | 0.568641 | 0.583521 | X   | X   | X   | X   |
Class Acc
|    |       -1 |   0 |   1 |   2 |   3 |       4 |        5 | 6   | 7   | 8   | 9   |
|---:|---------:|----:|----:|----:|----:|--------:|---------:|:----|:----|:----|:----|
|  0 | 0.994636 |   0 |   0 |   0 |   0 | 0.75345 | 0.732866 | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #3
*******
Classes to learn:
-1 6 7
*******


100%|██████████| 100/100 [01:29<00:00,  1.11it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.917565 |   0.277488 |   0.235606 |
Class IoU
|    |       -1 |   0 |   1 |   2 |   3 |   4 |   5 |        6 |        7 | 8   | 9   |
|---:|---------:|----:|----:|----:|----:|----:|----:|---------:|---------:|:----|:----|
|  0 | 0.918912 |   0 |   0 |   0 |   0 |   0 |   0 | 0.626765 | 0.574778 | X   | X   |
Class Acc
|    |       -1 |   0 |   1 |   2 |   3 |   4 |   5 |        6 |        7 | 8   | 9   |
|---:|---------:|----:|----:|----:|----:|----:|----:|---------:|---------:|:----|:----|
|  0 | 0.996954 |   0 |   0 |   0 |   0 |   0 |   0 | 0.770495 | 0.729942 | X   | X   |

####################################
Next Task
####################################
*******
Task #4
*******
Classes to learn:
-1 8 9
*******


100%|██████████| 100/100 [01:33<00:00,  1.07it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.895753 |   0.230925 |   0.181956 |
Class IoU
|    |       -1 |   0 |   1 |   2 |   3 |   4 |   5 |   6 |   7 |        8 |        9 |
|---:|---------:|----:|----:|----:|----:|----:|----:|----:|----:|---------:|---------:|
|  0 | 0.899947 |   0 |   0 |   0 |   0 |   0 |   0 |   0 |   0 | 0.586972 | 0.514596 |
Class Acc
|    |       -1 |   0 |   1 |   2 |   3 |   4 |   5 |   6 |   7 |        8 |        9 |
|---:|---------:|----:|----:|----:|----:|----:|----:|----:|----:|---------:|---------:|
|  0 | 0.996632 |   0 |   0 |   0 |   0 |   0 |   0 |   0 |   0 | 0.847055 | 0.696492 |

####################################
Next Task
####################################


# Experiments with the memory model

In [40]:
# Initialize dataloader, optimizer and trainer
# Five tasks of two digits each.
_tasks = {0: [0,1], 1: [2,3], 2: [4,5], 3: [6,7], 4: [8,9]}
# Single task containing all ten digits; not used in this cell.
_offline = {0: list(np.arange(10))}

# The first task has len(_tasks[0]) + 1 = 3 output classes.
model = simple_seg_model(n_classes_per_task=[len(_tasks[0])+1])
model = model.cuda()
optimizer = torch.optim.Adam(lr = 0.001, params=model.parameters())

continual_mnist = ContinualMnist(n_train=1000, n_test=200, batch_size=32, tasks=_tasks)

# callback_frequency="step": the evaluator runs after every training batch.
evaluater = EvaluaterCallback(model, ["acc", "iou", "confusion_matrix"], callback_frequency="step", n_classes=11, save_matrices=True)

# Replay memory used by the trainer.
# NOTE(review): `memory` (the class) is not imported in the visible cells -- confirm where it comes from.
mem = memory(images_shape=(1,60,60),
             masks_shape=(60,60),
             n_classes = 2,
             batch_size=32,
             memory_size=1000)

trainer = Trainer(model,
                  n_classes=[3],
                  optim=optimizer,
                  callbacks=[evaluater])

In [41]:
# Same protocol as the fine-tuning run, but with the replay memory `mem` passed to the trainer.
meta_train(n_tasks = len(_tasks),
           epochs = 100,
           scenario = continual_mnist, 
           trainer=trainer,
           evaluater=evaluater,
           memory=mem,
           animation_path="memory_100e_32b_32bm")

*******
Task #0
*******
Classes to learn:
-1 0 1
*******


100%|██████████| 100/100 [01:08<00:00,  1.46it/s]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.987487 |   0.915318 |   0.832932 |
Class IoU
|    |       -1 |        0 |       1 | 2   | 3   | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|--------:|:----|:----|:----|:----|:----|:----|:----|:----|
|  0 | 0.987403 | 0.775584 | 0.73581 | X   | X   | X   | X   | X   | X   | X   | X   |
Class Acc
|    |       -1 |        0 |        1 | 2   | 3   | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|---------:|:----|:----|:----|:----|:----|:----|:----|:----|
|  0 | 0.993115 | 0.868223 | 0.884615 | X   | X   | X   | X   | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #1
*******
Classes to learn:
-1 2 3
*******


100%|██████████| 100/100 [01:40<00:00,  1.00s/it]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.972908 |     0.7203 |   0.679482 |
Class IoU
|    |       -1 |        0 |        1 |        2 |        3 | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|---------:|---------:|---------:|:----|:----|:----|:----|:----|:----|
|  0 | 0.972832 | 0.789231 | 0.688654 | 0.455749 | 0.490946 | X   | X   | X   | X   | X   | X   |
Class Acc
|    |       -1 |        0 |        1 |       2 |        3 | 4   | 5   | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|---------:|--------:|---------:|:----|:----|:----|:----|:----|:----|
|  0 | 0.995598 | 0.877162 | 0.717277 | 0.48735 | 0.524111 | X   | X   | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #2
*******
Classes to learn:
-1 4 5
*******


100%|██████████| 100/100 [01:45<00:00,  1.05s/it]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.946849 |   0.474389 |   0.443065 |
Class IoU
|    |       -1 |        0 |       1 |        2 |        3 |       4 |        5 | 6   | 7   | 8   | 9   |
|---:|---------:|---------:|--------:|---------:|---------:|--------:|---------:|:----|:----|:----|:----|
|  0 | 0.946967 | 0.324364 | 0.27187 | 0.114358 | 0.318144 | 0.55003 | 0.575723 | X   | X   | X   | X   |
Class Acc
|    |      -1 |        0 |        1 |        2 |       3 |        4 |        5 | 6   | 7   | 8   | 9   |
|---:|--------:|---------:|---------:|---------:|--------:|---------:|---------:|:----|:----|:----|:----|
|  0 | 0.99759 | 0.326297 | 0.272798 | 0.116918 | 0.33294 | 0.623496 | 0.650683 | X   | X   | X   | X   |

####################################
Next Task
####################################
*******
Task #3
*******
Classes to learn:
-1 6 7
*******


100%|██████████| 100/100 [02:04<00:00,  1.24s/it]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.920554 |   0.293095 |   0.268358 |
Class IoU
|    |       -1 |         0 |         1 |        2 |          3 |         4 |         5 |        6 |        7 | 8   | 9   |
|---:|---------:|----------:|----------:|---------:|-----------:|----------:|----------:|---------:|---------:|:----|:----|
|  0 | 0.921103 | 0.0461882 | 0.0120949 | 0.102899 | 0.00901675 | 0.0114263 | 0.0869496 | 0.669954 | 0.555595 | X   | X   |
Class Acc
|    |       -1 |         0 |         1 |        2 |          3 |         4 |        5 |       6 |        7 | 8   | 9   |
|---:|---------:|----------:|----------:|---------:|-----------:|----------:|---------:|--------:|---------:|:----|:----|
|  0 | 0.998589 | 0.0461882 | 0.0120968 | 0.104945 | 0.00902159 | 0.0114911 | 0.088264 | 0.73129 | 0.635971 | X   | X   |

####################################
Next Task
################################

100%|██████████| 100/100 [02:07<00:00,  1.27s/it]



Overall stats
|    |   Overall Acc |   Mean Acc |   Mean IoU |
|---:|--------------:|-----------:|-----------:|
|  0 |      0.889222 |   0.196276 |   0.179315 |
Class IoU
|    |       -1 |         0 |         1 |   2 |   3 |   4 |         5 |           6 |           7 |        8 |        9 |
|---:|---------:|----------:|----------:|----:|----:|----:|----------:|------------:|------------:|---------:|---------:|
|  0 | 0.888537 | 0.0221469 | 0.0390153 |   0 |   0 |   0 | 0.0104031 | 0.000205023 | 0.000201066 | 0.514574 | 0.497388 |
Class Acc
|    |       -1 |         0 |         1 |   2 |   3 |   4 |         5 |           6 |           7 |        8 |        9 |
|---:|---------:|----------:|----------:|----:|----:|----:|----------:|------------:|------------:|---------:|---------:|
|  0 | 0.998985 | 0.0221836 | 0.0390819 |   0 |   0 |   0 | 0.0104031 | 0.000205044 | 0.000201066 | 0.538153 | 0.549827 |

####################################
Next Task
####################################


In [None]:
# Print every metric held by the evaluator.
evaluater.print_metrics()

In [None]:
# Manually switch to the next task and train on it for 50 epochs without replay memory.
# `memory` is passed explicitly because increment_task requires it.
increment_task(scenario = continual_mnist, trainer = trainer, evaluater = evaluater, memory = None)

for i in tqdm(range(50)):
  trainer.train(i, continual_mnist)

# increment_task replaces trainer.model, so save that one rather than the
# module-level `model`, which still points at the network created at setup.
torch.save(trainer.model.state_dict(), f"checkpoints/task-{trainer.curr_task}.pth")
evaluater.metrics[0].create_animation("conf_task4")
# NOTE(review): `Video` is not imported in the visible cells (presumably IPython.display.Video) -- confirm.
Video("conf_task4.mp4", embed=True)

100%|██████████| 50/50 [00:24<00:00,  2.00it/s]


In [None]:
# NOTE(review): this cell relies on names that are not defined in the visible cells
# (train_data, test_data, data, MetricsManager, Evaluater) and calls trainer.next_task /
# trainer.train with arguments that differ from the Trainer class defined earlier in this
# notebook -- it looks like it targets an earlier version of the API; confirm before running.
train_data.next_task()
train_loader = data.DataLoader(train_data, batch_size=36)
trainer.next_task(new_classes = [3,2,2])#,2,2,2])
test_data.next_task()
test_loader = data.DataLoader(test_data, batch_size=8)
# The trainer's model changed, so a new optimizer is built on its parameters.
optimizer = torch.optim.Adam(lr = 0.0005, params=trainer.model.parameters())
trainer.optim = optimizer
metrics_manager = MetricsManager(["confusion_matrix"], n_classes=7)
evaluater = Evaluater(trainer.model, metrics_manager)

for i in range(100):
  trainer.train(0, train_loader)
evaluater.test(test_loader)
torch.save(model.state_dict(), f"checkpoints/task-{trainer.curr_task}.pth")

In [None]:
# Show the confusion matrix with every row divided by its sum;
# the 1e-6 term keeps rows that sum to zero from dividing by zero.
cm = metrics_manager.metrics[0].confusion_matrix
row_totals = np.sum(cm, axis=1, keepdims=True) + 1e-6
plt.imshow(cm / row_totals, cmap="jet")

In [None]:
# NOTE(review): this cell relies on names that are not defined in the visible cells
# (train_data, test_data, data, MetricsManager, Evaluater) and calls trainer.next_task /
# trainer.train with arguments that differ from the Trainer class defined earlier in this
# notebook -- it looks like it targets an earlier version of the API; confirm before running.
train_data.next_task()
train_loader = data.DataLoader(train_data, batch_size=36)
trainer.next_task(new_classes = [3,2,2,2])#,2,2,2])
test_data.next_task()
test_loader = data.DataLoader(test_data, batch_size=8)
# The trainer's model changed, so a new optimizer is built on its parameters.
optimizer = torch.optim.Adam(lr = 0.0005, params=trainer.model.parameters())
trainer.optim = optimizer
metrics_manager = MetricsManager(["confusion_matrix"], n_classes=9)
evaluater = Evaluater(trainer.model, metrics_manager)

for i in range(100):
  trainer.train(0, train_loader)
evaluater.test(test_loader)
torch.save(model.state_dict(), f"checkpoints/task-{trainer.curr_task}.pth")

In [None]:
# Show the confusion matrix with every row divided by its sum.
# The 1e-6 term keeps rows that sum to zero from dividing by zero (which would
# produce NaN cells), matching the earlier confusion-matrix plot in this notebook.
cm = metrics_manager.metrics[0].confusion_matrix
row_totals = np.sum(cm, axis=1, keepdims=True) + 1e-6
plt.imshow(cm / row_totals, cmap="jet")

In [None]:
# Show the first image of each of the first 16 test batches next to its predicted mask.
trainer.model.eval()
# Inference only: no gradients are needed.
with torch.no_grad():
  for j, (images, labels) in enumerate(test_loader):
    x = trainer.model(images.cuda().float())
    plt.imshow(images[0][0])
    plt.show()
    # Index of the highest score along dim 1 is the predicted class.
    _, pred = torch.max(x, dim=1)
    print(np.unique(pred[0].detach().cpu().numpy()))
    plt.imshow(pred[0].detach().cpu(), vmax=4)
    plt.show()
    print("****")
    if j == 15:
      break