In [1]:
import random
import argparse
import numpy as np
import datetime as dt

import torch
import torchvision
import avalanche
import torch.backends.cudnn as cudnn

from avalanche.logging import InteractiveLogger, TensorboardLogger
from avalanche.evaluation.metrics import ExperienceAccuracy, EpochAccuracy, StreamAccuracy

from avalanche.training.plugins import ReplayPlugin, EvaluationPlugin
from avalanche.training.storage_policy import ClassBalancedBuffer
from avalanche.training.storage_policy import ReservoirSamplingBuffer

### Arguments

In [2]:
# Run configuration. parse_args gets an explicit empty list so the Jupyter
# kernel's own sys.argv is ignored; change a value by editing its default.
parser = argparse.ArgumentParser()

# Model / device
parser.add_argument('--num_class', default=10, type=int)
parser.add_argument('--device', default='0', type=str)  # GPU index string; converted to a torch.device later

# Optimiser and LwF hyper-parameters
parser.add_argument('--lr', '--learning_rate', default=0.001, type=float)
parser.add_argument('--alpha', default=1., type=float)
parser.add_argument('--temperature', default=2., type=float)

# Training / evaluation loop sizes
parser.add_argument('--train_mb', default=512, type=int)
parser.add_argument('--eval_mb', default=256, type=int)
parser.add_argument('--epoch', default=10, type=int)

# Replay memory
parser.add_argument('--memory_size', default=500, type=int)
# NOTE(review): Avalanche's ReservoirSamplingBuffer draws its own random
# weights internally; its constructor argument is the buffer's max_size,
# which should come from --memory_size, not from this flag.
parser.add_argument('--buffer_weights', default=0, type=float,
                    help='random uniform value in [0, 1]')

args = parser.parse_args([])

### Set Seed & Device

In [3]:
# Seed every RNG the run touches (Python, NumPy, PyTorch CPU and CUDA) so that
# Restart & Run All reproduces the same numbers.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# cuDNN is switched off entirely: slower, but avoids its non-deterministic
# convolution algorithms. `deterministic` only matters if it is re-enabled.
cudnn.enabled = False
cudnn.deterministic = True

# args.device arrives as a GPU index string (e.g. '0') and is replaced by a
# torch.device here. The isinstance guard keeps the cell idempotent: without
# it, a second run would evaluate 'cuda:' + torch.device(...) -> TypeError.
if not isinstance(args.device, torch.device):
    args.device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
device = args.device

### Load Dataset

In [4]:
# Class-incremental Split CIFAR-10. No train/eval transforms are passed, so
# Avalanche's default transforms are used for both splits. With
# n_experiences=10 and CIFAR-10's 10 classes, each experience holds one class.
# The class order (shuffle=True) is driven by the notebook-wide `seed` rather
# than a hard-coded 0, so changing the seed in the config cell changes it too.
benchmark = avalanche.benchmarks.SplitCIFAR10(n_experiences=10,
                                              return_task_id=False,
                                              seed=seed,
                                              shuffle=True)

Files already downloaded and verified
Files already downloaded and verified


### Evaluation

In [5]:
# Timestamp each run so TensorBoard logs from separate runs land in separate
# directories under logs/.
date = dt.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
log_dir = "logs/" + date

interactive_logger = InteractiveLogger()
tensor_logger = TensorboardLogger(log_dir)

# Accuracy at three granularities (training epoch, experience, whole stream),
# reported both to the console and to TensorBoard.
eval_plugin = EvaluationPlugin(
    EpochAccuracy(),
    ExperienceAccuracy(),
    StreamAccuracy(),
    loggers=[interactive_logger, tensor_logger],
)

### Backbone Model, Loss, Optimizer & Replay Buffer

In [6]:
# ResNet-18 trained from scratch with args.num_class output logits.
# NOTE(review): `pretrained=` is deprecated in torchvision >= 0.13
# (use `weights=None` there); kept for compatibility with older versions.
model = torchvision.models.resnet18(pretrained=False, num_classes=args.num_class)
model.to(args.device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)

# Replay memory. ReservoirSamplingBuffer's only constructor argument is the
# buffer capacity (max_size); it generates its per-sample random weights
# itself. args.buffer_weights (0) must not be passed here: it would create a
# zero-capacity buffer that disagrees with ReplayPlugin's mem_size. Both are
# therefore sized from args.memory_size.
# Alternative policy: ClassBalancedBuffer(args.memory_size, adaptive_size=True)
storage_policy = ReservoirSamplingBuffer(max_size=args.memory_size)
replay_plugin = ReplayPlugin(mem_size=args.memory_size, storage_policy=storage_policy)



In [8]:
# Learning without Forgetting strategy. Per the Avalanche LwF docs, `alpha`
# weights the distillation penalty and `temperature` softens the logits used
# for it. replay_plugin supplies stored samples during training, and
# eval_plugin handles metrics and logging.
cl_strategy = avalanche.training.LwF(
    # what to train and how
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    # LwF distillation hyper-parameters
    alpha=args.alpha,
    temperature=args.temperature,
    # loop sizes
    train_epochs=args.epoch,
    train_mb_size=args.train_mb,
    eval_mb_size=args.eval_mb,
    # runtime: device, evaluation/logging, extra plugins
    device=args.device,
    evaluator=eval_plugin,
    plugins=[replay_plugin],
)

In [9]:
# Train on each experience in order. After each one, evaluate on the whole
# test stream (every experience, including those not yet trained on). All
# evaluations are kept in `results` so accuracy and forgetting can be traced
# step by step; `res` still holds the final evaluation, as before.
results = []
res = None
for experience in benchmark.train_stream:
    cl_strategy.train(experience)
    res = cl_strategy.eval(benchmark.test_stream)
    results.append(res)

-- >> Start of training phase << --
0it [00:00, ?it/s]

100%|██████████| 10/10 [00:06<00:00,  1.64it/s]
Epoch 0 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.4822
100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Epoch 1 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.9976
100%|██████████| 10/10 [00:05<00:00,  1.90it/s]
Epoch 2 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 1.0000
100%|██████████| 10/10 [00:05<00:00,  1.87it/s]
Epoch 3 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 1.0000
100%|██████████| 10/10 [00:05<00:00,  1.88it/s]
Epoch 4 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 1.0000
100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Epoch 5 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 1.0000
100%|██████████| 10/10 [00:05<00:00,  1.89it/s]
Epoch 6 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 1.0000
100%|██████████| 10/10 [00:05<00:00,  1.90it/s]
Epoch 7 ended.
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 1.0000
100%|██████████| 10/10 [00:05<00