In [None]:
# Download the files
!wget https://github.com/lacykaltgr/continual-learning-ait/archive/refs/heads/main.zip
!unzip main.zip
!find continual-learning-ait-main -type f ! -name "main.ipynb" ! -name "main.py" -exec cp {} . \;

--2023-04-11 14:57:07--  https://github.com/lacykaltgr/continual-learning-ait/archive/refs/heads/main.zip
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://codeload.github.com/lacykaltgr/continual-learning-ait/zip/refs/heads/main [following]
--2023-04-11 14:57:07--  https://codeload.github.com/lacykaltgr/continual-learning-ait/zip/refs/heads/main
Resolving codeload.github.com (codeload.github.com)... 140.82.112.10
Connecting to codeload.github.com (codeload.github.com)|140.82.112.10|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: unspecified [application/zip]
Saving to: ‘main.zip’

main.zip                [ <=>                ] 154.08K  --.-KB/s    in 0.05s   

2023-04-11 14:57:07 (3.06 MB/s) - ‘main.zip’ saved [157773]

Archive:  main.zip
9d267e6b95e57ad02d2e58c6e1d1f16dcffdefcb
   creating: continual-learning-ait-main/
  i

In [None]:
import tensorflow as tf
import numpy as np
from data_preparation import load_dataset
from data_preparation import CLDataLoader

In [None]:
# Load CIFAR-100 through the project's `load_dataset` helper; the two returned
# objects are bound as dpt_train and dpt_test (presumably the train / test
# splits - confirm against data_preparation.load_dataset).
dpt_train, dpt_test = load_dataset('cifar-100')

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz


In [None]:
# Wrap each split in the project's continual-learning loader
# (64 is presumably the batch size - confirm against CLDataLoader).
# The evaluation loader has to be built from dpt_test: it was previously built
# from dpt_train, so "test" metrics were measured on the training data and
# dpt_test was never used.
train_loader = CLDataLoader(dpt_train, 64, train=True)
test_loader = CLDataLoader(dpt_test, 64, train=False)

In [None]:
import os

# Root folder for everything a run writes to disk, plus one sub-folder per
# artefact type (generated samples, reconstructions, MIR retrievals).
result_dir = 'results/'

result_path = os.path.join(result_dir)
sample_path = os.path.join(result_dir, 'samples')
recon_path = os.path.join(result_dir, 'reconstructions')
mir_path = os.path.join(result_dir, 'mir')

# makedirs(exist_ok=True) replaces the repeated "if not exists: mkdir" checks:
# it is idempotent on re-runs and has no window between the check and the create.
for _path in (result_path, sample_path, recon_path, mir_path):
    os.makedirs(_path, exist_ok=True)

In [None]:
# Experiment hyper-parameters; the training functions read them via agent.params.
params = {
    # tf.test.is_gpu_available() is deprecated; listing the physical GPUs is the
    # documented replacement (an empty list is falsy).
    # NOTE(review): 'cuda:0' / 'cpu' look like PyTorch device strings - confirm
    # what the consumers of params["device"] expect.
    "device": 'cuda:0' if tf.config.list_physical_devices('GPU') else 'cpu',
    "n_runs": 1,
    "n_tasks": 10,
    "n_epochs": 100,
    "n_classes": 10,
    "input_size": (3,32,32),
    # run_epoch reads this key unconditionally, so it must exist. A value <= 0
    # never satisfies its "sample_amt > samples_per_task > 0" cut-off, i.e. no
    # per-task cap. (A cap of 100 was previously commented out here.)
    "samples_per_task": -1,
    "batch_size": 64,

    "gen_depth": 6,
    "cls_mir_gen": 1,
    "gen_mir_gen": 1,
    "cls_iters": 100,
    "gen_iters": 10,
    "loss_fn": "mse",

    "lr": 0.001,
    "warmup": 0,
    "max_beta": 1,
    "reuse_samples": True,
    "print_every": 5,
    "mem_coeff": 0.12,
    "n_mem": 10,
    "mir_init_prior": 10,
    "z_size": 10,
    "mir_iters": 100,
    "gen_kl_coeff": 0.5,
    "gen_rec_coeff": 0.5,
    "gen_ent_coeff": 0.5,
    "gen_div_coeff": 0.5,
    "gen_shell_coeff": 0.5
}

In [None]:
class Agent:
  """Mutable container for one continual-learning experiment.

  Holds the hyper-parameter dict (`params`), a free-form `state` dict that the
  training functions read and write, the current classifier / generator with
  their optimizers, and the snapshots of the previous task's models.
  """

  def __init__(self, params):
    """Store `params`; start with empty state and no previous-task models."""
    self.params = params
    self.state = dict()
    self.prev_cls = None
    self.prev_gen = None

  def set_gen(self, gen):
    """Attach the generator and create its Adam optimizer (default settings)."""
    self.gen = gen
    # Keras optimizers are not built from a parameter list: their first
    # positional argument is the learning rate, so nothing is passed here.
    self.opt_gen = tf.keras.optimizers.Adam()

  def set_cls(self, cls):
    """Attach the classifier and create its SGD optimizer with params["lr"]."""
    self.cls = cls
    # Use this agent's own params (not the notebook-level `params` global) and
    # the `learning_rate` keyword; `lr` is a deprecated alias.
    self.opt = tf.keras.optimizers.SGD(learning_rate=self.params["lr"])

In [None]:
import mir
from utils import get_logger, get_temp_logger, logging_per_task, naive_cross_entropy_loss
from classifier import ResNet18, classifier
from loss import calculate_loss
from stable_diffusion.stable_diffusion import StableDiffusion

In [None]:
def train_generator(agent):
  """Run params["gen_iters"] generator updates on the batch in agent.state["data"].

  Each update computes the loss on the current batch. From the second task
  onwards (task > 0) it adds params["mem_coeff"] times the loss on samples
  returned by mir.retrieve_gen_for_gen, which are fetched on the first
  iteration only when params["reuse_samples"] is true, otherwise every
  iteration. agent.state["mir_tries"] / ["mir_success"] count the retrievals.
  Prints the losses when agent.state["i_example"] is a multiple of
  params["print_every"].
  """
  cfg, state = agent.params, agent.state
  data = state["data"]
  beta = state["beta"]
  task = state["task"]

  for step in range(cfg["gen_iters"]):
    # Loss on the incoming batch.
    x_mean, z_mu, z_var, ldj, z0, zk = agent.gen(data)
    gen_loss, rec, kl, _ = calculate_loss(
        x_mean, data, z_mu, z_var, z0, zk, ldj,
        cfg["input_size"], cfg["loss_fn"], beta=beta)

    tot_gen_loss = 0 + gen_loss

    if task > 0:
      refresh_memory = step == 0 or not cfg["reuse_samples"]
      if refresh_memory:
        retrieved = mir.retrieve_gen_for_gen(
            cfg, data, agent.gen, agent.prev_gen, agent.prev_cls)
        mem_x, mir_worked = retrieved

        state["mir_tries"] += 1
        if mir_worked:
          state["mir_success"] += 1
          # keep for logging later
          gen_x, gen_mem_x = data, mem_x

      # Loss on the retrieved memory samples, weighted into the total.
      mem_x_mean, z_mu, z_var, ldj, z0, zk = agent.gen(mem_x)
      mem_gen_loss, mem_rec, mem_kl, _ = calculate_loss(
          mem_x_mean, mem_x, z_mu, z_var, z0, zk, ldj,
          cfg["input_size"], cfg["loss_fn"], beta=beta)

      tot_gen_loss += cfg["mem_coeff"] * mem_gen_loss

    optimizer = agent.opt_gen
    optimizer.zero_grad()
    tot_gen_loss.backward()
    optimizer.step()

  should_print = state["i_example"] % cfg["print_every"] == 0
  if should_print:
    print(f'current VAE loss = {gen_loss.item():.4f} (rec: {rec.item():.4f} + beta: {beta:.2f} * kl: {kl.item():.2f}')
    if task > 0:
      print(f'memory VAE loss = {mem_gen_loss.item():.4f} (rec: { mem_rec.item():.4f} + beta: {beta:.2f} * kl: {mem_kl.item():.2f})')


In [None]:
def train_classifier(agent):
  """Run params["cls_iters"] classifier updates on the current batch.

  Reads the batch from agent.state["data"] / ["target"]. From the second task
  onwards (task > 0) each update adds params["mem_coeff"] times the loss on
  samples returned by mir.retrieve_gen_for_cls, which are fetched on the first
  iteration only when params["reuse_samples"] is true, otherwise every
  iteration. Retrieval attempts and successes are counted in
  agent.state["mir_tries"] / ["mir_success"], the same counters that
  train_generator updates. Prints accuracies when agent.state["i_example"] is
  a multiple of params["print_every"].

  NOTE(review): zero_grad()/backward()/step() and
  argmax(dim=..., keepdim=...)/eq/view_as/size are PyTorch-style calls, while
  the loss is a tf.keras loss - confirm which framework the models and
  optimizers actually use.
  """
  data = agent.state["data"]
  target = agent.state["target"]
  task = agent.state["task"]

  # Built once: it does not depend on the iteration. The default reduction
  # averages over the batch; 'mean' is not a reduction key TF2 Keras accepts.
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  for it in range(agent.params["cls_iters"]):

    logits = agent.cls(data)
    cls_loss = loss_fn(target, logits)
    tot_cls_loss = 0 + cls_loss

    if task > 0:

      if it == 0 or not agent.params["reuse_samples"]:
        mem_x, mem_y, mir_worked = mir.retrieve_gen_for_cls(agent.params, data, agent.cls, agent.prev_cls, agent.prev_gen)
        # Bare `mir_tries += 1` / `mir_success += 1` are unassigned locals
        # here and raise UnboundLocalError; the counters live in agent.state.
        agent.state["mir_tries"] += 1
        if mir_worked:
          agent.state["mir_success"] += 1

      mem_logits = agent.cls(mem_x)

      mem_cls_loss = naive_cross_entropy_loss(mem_logits, mem_y)

      tot_cls_loss += agent.params["mem_coeff"] * mem_cls_loss

    agent.opt.zero_grad()
    tot_cls_loss.backward()
    agent.opt.step()

  if agent.state["i_example"] % agent.params["print_every"] == 0:
    pred = logits.argmax(dim=1, keepdim=True)
    acc = pred.eq(target.view_as(pred)).sum().item() / pred.size(0)
    print(f'current training accuracy: {acc:.2f}')
    if task > 0:
      pred = mem_logits.argmax(dim=1, keepdim=True)
      # mem_y is reduced with argmax before the comparison
      # (presumably one-hot / soft targets - confirm).
      mem_y = mem_y.argmax(dim=1, keepdim=True)
      acc = pred.eq(mem_y.view_as(pred)).sum().item() / pred.size(0)
      print(f'memory training accuracy: {acc:.2f}')

In [None]:
def run_epoch(agent):
  """Iterate once over agent.state["tr_loader"], training both models per batch.

  For every (data, target) batch it adds the batch size to
  agent.state["sample_amt"], stores the batch, its index and the current beta
  in agent.state, then calls train_generator and train_classifier. Beta is
  sample_amt / max(params["warmup"], 1), capped at params["max_beta"].

  The loop stops early once sample_amt exceeds params["samples_per_task"]; a
  missing or non-positive value means "no cap".

  NOTE(review): data.size(0) and .to(device) are PyTorch-style tensor calls -
  confirm that the loader yields objects supporting them.
  """
  # .get(): the key is optional. Indexing it directly raised KeyError with the
  # notebook's own params dict, where the entry is commented out.
  sample_cap = agent.params.get("samples_per_task", 0)

  for i, (data, target) in enumerate(agent.state["tr_loader"]):

    if sample_cap > 0 and agent.state["sample_amt"] > sample_cap:
      break
    agent.state["sample_amt"] += data.size(0)

    agent.state["data"] = data.to(agent.params["device"])
    agent.state["target"] = target.to(agent.params["device"])
    agent.state["i_example"] = i

    agent.state["beta"] = min([(agent.state["sample_amt"]) / max([agent.params["warmup"], 1.]), agent.params["max_beta"]])

    train_generator(agent)
    train_classifier(agent)

In [None]:
from copy import deepcopy

def run_task(agent):
  """Train both models on the current task, snapshot them, then evaluate.

  Reads agent.params["n_epochs"], agent.state["task"], agent.state["beta"]
  (written by run_epoch) and the notebook-level `test_loader`. Writes
  agent.state["sample_amt"], agent.prev_cls and agent.prev_gen (deep copies of
  the models after training). Evaluation covers the test loaders of tasks
  0..agent.state["task"]. Returns None: the per-task metrics are only
  collected locally because the logging calls below are still disabled.

  NOTE(review): .train()/.eval(), .to(device) and
  argmax(dim=..., keepdim=...)/eq/view_as/size are PyTorch-style calls -
  confirm which framework the models and loaders actually use.
  """
  agent.cls = agent.cls.train()
  agent.gen = agent.gen.train()

  agent.state["sample_amt"] = 0

  for _ in range(agent.params["n_epochs"]):
    run_epoch(agent)

  # Evaluation. Deliberately NOT wrapped in tf.GradientTape: a tape records
  # operations for differentiation (the opposite of a "no grad" context), and
  # a persistent one keeps everything it recorded alive.
  agent.cls = agent.cls.eval()
  agent.prev_cls = deepcopy(agent.cls)

  agent.gen = agent.gen.eval()
  agent.prev_gen = deepcopy(agent.gen)

  # TODO: image dumps, disabled until the referenced tensors are available here.
  #
  # # save some training reconstructions:
  # recon_path_ = os.path.join(recon_path, f'task{agent.state["task"]}.png')
  # # TODO: data
  # recons = tf.concat([data.to('cpu'), x_mean.to('cpu')])
  # save_image(recons, recon_path_, nrow=agent.params["batch_size"])
  #
  # # save some pretty images:
  # gen_images = agent.gen.generate(25).to('cpu')
  # sample_path_ = os.path.join(sample_path, f'task{agent.state["task"]}.png')
  # save_image(gen_images, sample_path_, nrow=5)
  #
  # # save some MIR samples:
  # if agent.state["task"] > 0:
  #     mir_images = tf.concat([cls_x.to('cpu'), cls_mem_x.to('cpu')])
  #     mir_path_ = os.path.join(mir_path, f'cls_task{task}.png')
  #     save_image(mir_images, mir_path_, nrow=10)
  #
  #     mir_images = tf.concat([gen_x.to('cpu'), gen_mem_x.to('cpu')])
  #     mir_path_ = os.path.join(mir_path, f'gen_task{task}.png')
  #     save_image(mir_images, mir_path_, nrow=10)

  # Default reduction averages over the batch; 'mean' is not a reduction key
  # TF2 Keras accepts.
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  for task_t, te_loader in enumerate(test_loader):
    # Only tasks seen so far are evaluated.
    if task_t > agent.state["task"]:
      break
    LOG_temp = get_temp_logger(None, ['gen_loss', 'cls_loss', 'acc'])

    # Iterate this task's loader; looping over `test_loader` again here
    # would re-enumerate the task loaders instead of yielding batches.
    for data, target in te_loader:

      data, target = data.to(agent.params["device"]), target.to(agent.params["device"])

      logits = agent.cls(data)

      loss = loss_fn(target, logits)
      pred = logits.argmax(dim=1, keepdim=True)

      LOG_temp['acc'] += [pred.eq(target.view_as(pred)).sum().item() / pred.size(0)]
      LOG_temp['cls_loss'] += [loss.item()]

      x_mean, z_mu, z_var, ldj, z0, zk = agent.gen(data)
      gen_loss, rec, kl, bpd = calculate_loss(x_mean, data, z_mu, z_var, z0,
              zk, ldj, agent.params["input_size"], agent.params["loss_fn"], beta=agent.state["beta"])
      LOG_temp['gen_loss'] += [gen_loss.item()]

    # TODO: per-task logging, disabled (wandb / LOG / run / mode are not
    # defined in this notebook).
    #
    # logging_per_task(wandb, LOG, run, mode, 'acc', task, task_t,
    #           np.round(np.mean(LOG_temp['acc']), 2))
    # logging_per_task(wandb, LOG, run, mode, 'cls_loss', task, task_t,
    #           np.round(np.mean(LOG_temp['cls_loss']), 2))
    # logging_per_task(wandb, LOG, run, mode, 'gen_loss', task, task_t,
    #           np.round(np.mean(LOG_temp['gen_loss']), 2))

  # print(f'\n{mode}:')
  # print(LOG[run][mode]['acc'])

In [None]:
def run(agent):
  """Execute one full run: fresh models, then every task in `train_loader` in order.

  Resets the MIR counters in agent.state, installs a new ResNet18 classifier
  and a new StableDiffusion generator on the agent, and for each task loader
  records the task index and loader in agent.state before calling run_task.
  """
  # MIR bookkeeping starts from zero for every run.
  agent.state["mir_tries"] = 0
  agent.state["mir_success"] = 0

  classifier_net = ResNet18(
      agent.params["n_classes"],
      nf=20,
      input_size=agent.params["input_size"],
  )
  agent.set_cls(classifier_net)

  generator_net = StableDiffusion()
  agent.set_gen(generator_net)

  task_index = 0
  for task_loader in train_loader:
    agent.state["task"] = task_index
    agent.state["tr_loader"] = task_loader
    run_task(agent)
    task_index += 1

'''

  # accuracy
  final_accs = LOG[agent.state["run"]]['acc'][:, agent.state["task"]]
  logging_per_task(wandb, LOG, run, mode, 'final_acc', task, value=np.round(np.mean(final_accs),2))

  # forgetting
  best_acc = np.max(LOG[run][mode]['acc'], 1)
  final_forgets = best_acc - LOG[run][mode]['acc'][:, task]
  logging_per_task(wandb, LOG, run, mode, 'final_forget', task, value=np.round(np.mean(final_forgets[:-1]),2))

  # VAE loss
  final_elbos = LOG[run][mode]['gen_loss'][:, task]
  logging_per_task(wandb, LOG, run, mode, 'final_elbo', task, value=np.round(np.mean(final_elbos), 2))

  print(f'\n{mode}:')
  print(f'final accuracy: {final_accs}')
  print(f'average: {LOG[run][mode]["final_acc"]}')
  print(f'final forgetting: {final_forgets}')
  print(f'average: {LOG[run][mode]["final_forget"]}')
  print(f'final VAE loss: {final_elbos}')
  print(f'average: {LOG[run][mode]["final_elbo"]}\n')

  try:
      mir_worked_frac = mir_success/ (mir_tries)
      logging_per_task(wandb, LOG, run, mode, 'final_mir_worked_frac', task, mir_worked_frac)
      print('mir worked \n', mir_worked_frac)
  except:
      pass

'''

In [None]:
# Entry point: build one Agent from the hyper-parameters and execute
# params["n_runs"] independent runs, recording the run index in agent.state
# before each call to run().
agent = Agent(params)
for r in range(agent.params["n_runs"]):
  agent.state["run"] = r
  run(agent)

In [None]:
'''
n_runs = agent.params["n_runs"]

final_accs = [LOG[x]['final_acc'] for x in range(n_runs)]
final_acc_avg = np.mean(final_accs)
final_acc_se = np.std(final_accs) / np.sqrt(n_runs)

# forgetting
final_forgets = [LOG[x]['final_forget'] for x in range(n_runs)]
final_forget_avg = np.mean(final_forgets)
final_forget_se = np.std(final_forgets) / np.sqrt(n_runs)

# VAE loss
final_elbos = [LOG[x]['final_elbo'] for x in range(n_runs)]
final_elbo_avg = np.mean(final_elbos)
final_elbo_se = np.std(final_elbos) / np.sqrt(n_runs)

# MIR worked
try:
    final_mir_worked_frac = [LOG[x]['final_mir_worked_frac'] for x in range(n_runs)]
    final_mir_worked_avg = np.mean(final_mir_worked_frac)
except:
    pass

print(f'\nFinal Accuracy: {final_acc_avg:.3f} +/- {final_acc_se:.3f}')
print(f'\nFinal Forget: {final_forget_avg:.3f} +/- {final_forget_se:.3f}')
print(f'\nFinal ELBO: {final_elbo_avg:.3f} +/- {final_elbo_se:.3f}')
'''