In [1]:
# Download the files
!wget https://github.com/lacykaltgr/continual-learning-ait/archive/refs/heads/main.zip
!unzip main.zip
!find continual-learning-ait-main -type f ! -name "main.ipynb" ! -name "main.py" -exec cp {} . \;

--2023-04-11 14:57:07--  https://github.com/lacykaltgr/continual-learning-ait/archive/refs/heads/main.zip
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://codeload.github.com/lacykaltgr/continual-learning-ait/zip/refs/heads/main [following]
--2023-04-11 14:57:07--  https://codeload.github.com/lacykaltgr/continual-learning-ait/zip/refs/heads/main
Resolving codeload.github.com (codeload.github.com)... 140.82.112.10
Connecting to codeload.github.com (codeload.github.com)|140.82.112.10|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/zip]
Saving to: ‘main.zip’

main.zip                [ <=>                ] 154.08K  --.-KB/s    in 0.05s   

2023-04-11 14:57:07 (3.06 MB/s) - ‘main.zip’ saved [157773]

Archive:  main.zip
9d267e6b95e57ad02d2e58c6e1d1f16dcffdefcb
   creating: continual-learning-ait-main/
  i

In [2]:
import tensorflow as tf
from data_preparation import load_dataset
from data_preparation import CLDataLoader

In [3]:
# Load the 'cifar-100' dataset through the project helper. The two returned
# objects are presumably the train and test splits (going by the names) —
# TODO confirm against data_preparation.load_dataset.
dpt_train, dpt_test = load_dataset('cifar-100')

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz


In [4]:
# Wrap each split in the project's continual-learning loader. The second
# argument (64) is presumably the batch size — it matches params["batch_size"].
train_loader = CLDataLoader(dpt_train, 64, train=True)
# The evaluation loader must read the held-out split; building it from
# dpt_train would report accuracy on data the model was trained on.
test_loader = CLDataLoader(dpt_test, 64, train=False)

In [6]:
import os

# Root folder for everything this notebook writes to disk.
result_dir = 'results/'

# Output locations: the root plus one sub-folder per artefact type.
result_path = os.path.join(result_dir)
sample_path = os.path.join(result_dir, 'samples')
recon_path = os.path.join(result_dir, 'reconstructions')
mir_path = os.path.join(result_dir, 'mir')

# makedirs(exist_ok=True) is idempotent on re-run, creates missing parents and
# avoids the check-then-create race of `if not exists: mkdir`.
for _output_dir in (result_path, sample_path, recon_path, mir_path):
    os.makedirs(_output_dir, exist_ok=True)

In [None]:
# Experiment configuration shared by the agent and the training functions.
params = {
    # tf.test.is_gpu_available() is deprecated; an empty device list is falsy,
    # so the selected value is unchanged.
    # NOTE(review): 'cuda:0' / 'cpu' are PyTorch-style device names; no visible
    # cell consumes this key — confirm what the downstream code expects.
    "device": 'cuda:0' if tf.config.list_physical_devices('GPU') else 'cpu',
    "n_runs": 1,
    "n_tasks": 10,
    "n_epochs": 100,
    "n_classes": 10,
    "input_size": (3,32,32),
    #"samples_per_task": 100,
    "batch_size": 64,

    "gen_depth": 6,
    "cls_mir_gen": 1,
    "gen_mir_gen": 1,
    # Number of optimisation steps train_classifier / train_generator run per batch.
    "cls_iters": 100,
    "gen_iters": 10,

    # Learning rate of the classifier's SGD optimizer (see Agent.set_cls).
    "lr": 0.001,
    "warmup": 0,
    "max_beta": 1,
    # When True, replay samples are retrieved once per batch (first iteration)
    # and reused for the remaining iterations.
    "reuse_samples": True,
    # The training functions log when the batch counter is a multiple of this.
    "print_every": 5,
    # Weight of the replay ("memory") loss added to the current-batch loss.
    "mem_coeff": 0.12
}

In [None]:
class Agent:
  """Container for the models, optimizers and per-step state of one run.

  Attributes:
    params: configuration dict supplied at construction.
    state: mutable dict the training functions read from and write to.
    prev_cls / prev_gen: initialised to None here; the training functions pass
      them to the `mir` retrieval helpers.
    gen / opt_gen: generator and its Adam optimizer (set by `set_gen`).
    cls / opt: classifier and its SGD optimizer (set by `set_cls`).
  """

  def __init__(self, params):
    self.params = params
    self.state = dict()
    self.prev_cls = None
    self.prev_gen = None

  def set_gen(self, gen):
    """Register the generator and create a fresh Adam optimizer for it."""
    self.gen = gen
    # Keras optimizers are not bound to a parameter list: the first positional
    # argument is the learning rate, so passing `gen.parameters()` (PyTorch
    # style) was wrong. Keras Adam defaults to learning_rate=0.001.
    self.opt_gen = tf.keras.optimizers.Adam()

  def set_cls(self, cls):
    """Register the classifier and create a fresh SGD optimizer for it."""
    self.cls = cls
    # Read the rate from this agent's own config rather than the notebook-global
    # `params`, and use the `learning_rate` keyword (`lr` is a deprecated alias
    # that newer Keras releases reject).
    self.opt = tf.keras.optimizers.SGD(learning_rate=self.params["lr"])

In [None]:
import mir
from utils import get_logger, get_temp_logger, logging_per_task, naive_cross_entropy_loss
from classifier import ResNet18, classifier
from loss import calculate_loss
from stable_diffusion.stable_diffusion import StableDiffusion

In [None]:
def train_generator(agent):
  """Run the generator updates for the batch currently held in `agent.state`.

  Reads `data`, `beta` and `task` from `agent.state` and performs
  `agent.params["gen_iters"]` optimisation steps on `agent.gen`. For `task > 0`
  a replay batch obtained from `mir.retrieve_gen_for_gen` is pushed through the
  generator as well and its loss is added with weight
  `agent.params["mem_coeff"]`.

  Side effects:
    * updates `agent.gen` through `agent.opt_gen`;
    * increments `agent.state["mir_tries"]` / `agent.state["mir_success"]`;
    * stores the last successfully retrieved pair in `agent.state["gen_x"]` /
      `agent.state["gen_mem_x"]` for later logging;
    * prints the losses when the batch counter hits `print_every`.

  NOTE(review): `args` and the batch counter `i` are resolved from the notebook
  globals; neither is defined in the cells visible here — confirm they are set
  before this function is called.
  NOTE(review): the update uses `agent.gen.trainable_variables`, i.e. it assumes
  the generator is a Keras model / tf.Module — TODO confirm.
  """
  data = agent.state["data"]
  beta = agent.state["beta"]
  task = agent.state["task"]
  mem_coeff = agent.params["mem_coeff"]

  for it in range(agent.params["gen_iters"]):

      # Replay retrieval happens outside the gradient tape: it only needs the
      # current batch and the models, not the forward pass recorded below.
      if task > 0 and (it == 0 or not agent.params["reuse_samples"]):
          mem_x, mir_worked = mir.retrieve_gen_for_gen(args, data, agent.gen, agent.prev_gen, agent.prev_cls)

          # The counters live in agent.state; as bare locals `mir_tries += 1`
          # raised UnboundLocalError on the first replay step.
          agent.state["mir_tries"] = agent.state.get("mir_tries", 0) + 1
          if mir_worked:
              agent.state["mir_success"] = agent.state.get("mir_success", 0) + 1
              # keep for logging later
              agent.state["gen_x"], agent.state["gen_mem_x"] = data, mem_x

      # Keras optimizers have no zero_grad()/step(); gradients are taken with
      # a GradientTape and applied explicitly.
      with tf.GradientTape() as tape:
          x_mean, z_mu, z_var, ldj, z0, zk = agent.gen(data)
          gen_loss, rec, kl, _ = calculate_loss(x_mean, data, z_mu, z_var, z0, zk, ldj, args, beta=beta)

          tot_gen_loss = gen_loss

          if task > 0:
              mem_x_mean, z_mu, z_var, ldj, z0, zk = agent.gen(mem_x)
              mem_gen_loss, mem_rec, mem_kl, _ = calculate_loss(mem_x_mean, mem_x, z_mu, z_var, z0, zk, ldj, args, beta=beta)

              tot_gen_loss = tot_gen_loss + mem_coeff * mem_gen_loss

      gen_vars = agent.gen.trainable_variables
      gen_grads = tape.gradient(tot_gen_loss, gen_vars)
      agent.opt_gen.apply_gradients(zip(gen_grads, gen_vars))

  # End Generator Iteration Loop
  #------------------------------

  if i % agent.params["print_every"] == 0:
      # float() converts scalar tensors without the torch-only `.item()`.
      print(f'current VAE loss = {float(gen_loss):.4f} (rec: {float(rec):.4f} + beta: {beta:.2f} * kl: {float(kl):.2f})')
      if task > 0:
          print(f'memory VAE loss = {float(mem_gen_loss):.4f} (rec: {float(mem_rec):.4f} + beta: {beta:.2f} * kl: {float(mem_kl):.2f})')


In [None]:
def train_classifier(agent):
  """Run the classifier updates for the batch currently held in `agent.state`.

  Reads `data`, `target` and `task` from `agent.state` and performs
  `agent.params["cls_iters"]` optimisation steps on `agent.cls` using sparse
  categorical cross-entropy on the logits. For `task > 0` a replay batch from
  `mir.retrieve_gen_for_cls` is classified as well and its
  `naive_cross_entropy_loss` is added with weight `agent.params["mem_coeff"]`.

  Side effects:
    * updates `agent.cls` through `agent.opt`;
    * increments `agent.state["mir_tries"]` / `agent.state["mir_success"]`;
    * stores the last successfully retrieved pair in `agent.state["cls_x"]` /
      `agent.state["cls_mem_x"]` for later logging;
    * prints training accuracy when the batch counter hits `print_every`.

  NOTE(review): `args` and the batch counter `i` are resolved from the notebook
  globals; neither is defined in the cells visible here — confirm they are set
  before this function is called.
  NOTE(review): the update uses `agent.cls.trainable_variables`, i.e. it assumes
  the classifier is a Keras model / tf.Module — TODO confirm.
  """
  data = agent.state["data"]
  target = agent.state["target"]
  task = agent.state["task"]
  mem_coeff = agent.params["mem_coeff"]

  # Built once instead of on every iteration. The default reduction averages
  # over the batch; 'mean' is not an accepted reduction key in TF2-era Keras.
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  for it in range(agent.params["cls_iters"]):

      # Replay retrieval happens outside the gradient tape: it only needs the
      # current batch and the models, not the forward pass recorded below.
      if task > 0 and (it == 0 or not agent.params["reuse_samples"]):
          mem_x, mem_y, mir_worked = mir.retrieve_gen_for_cls(args, data, agent.cls, agent.prev_cls, agent.prev_gen)

          # The counters live in agent.state; as bare locals `mir_tries += 1`
          # raised UnboundLocalError on the first replay step.
          agent.state["mir_tries"] = agent.state.get("mir_tries", 0) + 1
          if mir_worked:
              agent.state["mir_success"] = agent.state.get("mir_success", 0) + 1
              # keep for logging later
              agent.state["cls_x"], agent.state["cls_mem_x"] = data, mem_x

      # Keras optimizers have no zero_grad()/step(); gradients are taken with
      # a GradientTape and applied explicitly.
      with tf.GradientTape() as tape:
          logits = agent.cls(data)
          cls_loss = loss_fn(target, logits)
          tot_cls_loss = cls_loss

          if task > 0:
              mem_logits = agent.cls(mem_x)

              mem_cls_loss = naive_cross_entropy_loss(mem_logits, mem_y)

              tot_cls_loss = tot_cls_loss + mem_coeff * mem_cls_loss

      cls_vars = agent.cls.trainable_variables
      cls_grads = tape.gradient(tot_cls_loss, cls_vars)
      agent.opt.apply_gradients(zip(cls_grads, cls_vars))

  # End Classifer Iteration Loop
  #-----------------------------

  if i % agent.params["print_every"] == 0:
      # Accuracy = fraction of samples whose arg-max class equals the label.
      pred = tf.argmax(logits, axis=1)
      labels = tf.cast(tf.reshape(target, [-1]), pred.dtype)
      acc = float(tf.reduce_mean(tf.cast(tf.equal(pred, labels), tf.float32)))
      print(f'current training accuracy: {acc:.2f}')
      if task > 0:
          # Replay labels are per-class score vectors, so they are reduced
          # with arg-max as well before comparing.
          mem_pred = tf.argmax(mem_logits, axis=1)
          mem_labels = tf.argmax(mem_y, axis=1)
          acc = float(tf.reduce_mean(tf.cast(tf.equal(mem_pred, mem_labels), tf.float32)))
          print(f'memory training accuracy: {acc:.2f}')