In [1]:
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from avalanche.benchmarks.classic import SplitMNIST

from avalanche.evaluation.metrics import forgetting_metrics, \
accuracy_metrics, loss_metrics, timing_metrics, cpu_usage_metrics, \
confusion_matrix_metrics, disk_usage_metrics, bwt_metrics
from avalanche.models import SimpleMLP
from avalanche.logging import InteractiveLogger, TextLogger, TensorboardLogger, WandBLogger
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training import Naive,EWC
from torchvision.transforms import Compose, ToTensor, Normalize, RandomCrop, Resize
import numpy as np
import random
import wandb
import timm
import torch
import avalanche

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fix a single seed for every random number generator used in this notebook.
# The same value is also handed to the benchmark constructor further down.
seed = 2023
random.seed(seed)        # Python standard-library RNG
np.random.seed(seed)     # NumPy global RNG
torch.manual_seed(seed)  # PyTorch default generator
# The model is wrapped in DataParallel later on, so more than one GPU may be in
# use: seed every CUDA device instead of only the current one. This call is
# silently ignored when CUDA is not available.
torch.cuda.manual_seed_all(seed)


In [3]:

# Training and evaluation use the same two-step preprocessing, Resize(224)
# followed by ToTensor(); no augmentation or normalisation is applied.
train_transform = Compose([Resize(224), ToTensor()])
test_transform = Compose([Resize(224), ToTensor()])

# SplitCIFAR100 benchmark with 10 experiences, return_task_id=False, and the
# notebook-wide seed passed through to the benchmark.
benchmark = avalanche.benchmarks.SplitCIFAR100(
    n_experiences=10,
    return_task_id=False,
    seed=seed,
    train_transform=train_transform,
    eval_transform=test_transform,
)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
# Show, for each experience, the set of original class ids assigned to it.
benchmark.original_classes_in_exp

[{5, 9, 23, 37, 39, 48, 58, 62, 86, 96},
 {7, 18, 19, 28, 33, 41, 67, 69, 78, 99},
 {17, 29, 36, 44, 47, 66, 74, 80, 83, 88},
 {6, 20, 21, 27, 43, 51, 81, 82, 85, 90},
 {1, 4, 14, 63, 65, 73, 84, 91, 93, 95},
 {2, 10, 12, 32, 53, 54, 59, 75, 89, 92},
 {31, 34, 35, 40, 42, 45, 55, 64, 97, 98},
 {0, 8, 24, 38, 49, 52, 68, 71, 79, 87},
 {15, 16, 22, 25, 30, 50, 56, 57, 61, 94},
 {3, 11, 13, 26, 46, 60, 70, 72, 76, 77}]

In [8]:
# Display the benchmark's reproducibility data (class mapping, class order, ...).
benchmark.get_reproducibility_data()

{'class_ids_from_zero_from_first_exp': False,
 'class_ids_from_zero_in_each_exp': False,
 'class_mapping': [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99],
 'classes_order': [39,
  48,
  62,
  58,
  37,
  9,
  5,
  86,
  96,
  23,
  41,
  69,
  18,
  7,
  67,
  78,
  99,
  28,
  19,
  33,
  17,
  36,
  29,
  80,
  88,
  44,
  66,
  47,
  83,
  74,
  51,
  81,
  90,
  20,
  82,
  27,
  85,
  6,
  43,
  21,
  95,
  14,
  84,
  1,
  73,
  91,
  4,
  93,
  63

In [10]:
# Pretrained ViT-Tiny (patch 16, 224 input) whose classification head is sized
# to the number of classes in the benchmark.
backbone = timm.models.vit_tiny_patch16_224(
    pretrained=True,
    num_classes=benchmark.n_classes,
)
# Wrap the network in DataParallel; `model` is what the strategy receives.
model = torch.nn.DataParallel(backbone)

In [11]:
# Display the module tree of the DataParallel-wrapped model.
model

DataParallel(
  (module): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 192, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): Sequential(
      (0): Block(
        (norm1): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=192, out_features=576, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=192, out_features=192, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (ls1): Identity()
        (drop_path1): Identity()
        (norm2): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=192, out_features=768, bias=True)
          (act): GELU()
          (drop1): Dropout(p=0.0, inplace=False)
          (fc2): Linear(in_features=768, out_features=192, bias=True)
          (drop2)

In [12]:
# Loggers handed to the EvaluationPlugin below.
loggers = [
    # log to Tensorboard
    TensorboardLogger(),
    # log to text file (opened in append mode)
    TextLogger(open('log.txt', 'a')),
    # print to stdout
    InteractiveLogger(),
    # W&B logger - remove this entry if you don't have a W&B account
    WandBLogger(project_name="avalanche", run_name="ewc-SplitCifar100-Vit"),
]

# Metrics tracked by the plugin. timing_metrics, cpu_usage_metrics and
# disk_usage_metrics are imported but not enabled in this cell.
tracked_metrics = [
    accuracy_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    loss_metrics(minibatch=True, epoch=True, experience=True, stream=True),
    forgetting_metrics(experience=True, stream=True),
    bwt_metrics(experience=True, stream=True),
    confusion_matrix_metrics(
        num_classes=benchmark.n_classes,
        save_image=True,
        stream=True,
    ),
]

eval_plugin = EvaluationPlugin(*tracked_metrics, loggers=loggers)


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: grbagwe. Use `wandb login --relogin` to force relogin


In [13]:
# CREATE THE STRATEGY INSTANCE (EWC)
# Prefer the GPU, but fall back to the CPU so this cell still runs on a machine
# without CUDA instead of failing on a hard-coded 'cuda' device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# EWC strategy: plain SGD on all model parameters, cross-entropy loss,
# mini-batches of 50 for both training and evaluation, 10 epochs per
# experience, and metrics reported through the evaluation plugin.
cl_strategy = EWC(
    model,
    optimizer=SGD(model.parameters(), lr=0.0001),
    criterion=CrossEntropyLoss(),
    ewc_lambda=0.7,
    train_mb_size=50,
    train_epochs=10,
    eval_mb_size=50,
    device=device,
    evaluator=eval_plugin,
)

In [None]:
# TRAINING LOOP
# Train on the experiences of the train stream one after another. After each
# one, evaluate on the whole test stream; eval returns a dictionary with all
# the metric values, which is collected in `results`.
print('Starting experiment...')
results = []
for experience in benchmark.train_stream:
    exp_id = experience.current_experience

    print("Start training on experience ", exp_id)
    cl_strategy.train(experience)
    print("End training on experience", exp_id)

    print("Computing accuracy on the test set")
    eval_metrics = cl_strategy.eval(benchmark.test_stream[:])
    results.append(eval_metrics)
    


Starting experiment...
Start training on experience  0
-- >> Start of training phase << --
100%|██████████████████████████████████████████████████████████████████| 100/100 [00:28<00:00,  3.51it/s]
Epoch 0 ended.
	DiskUsage_Epoch/train_phase/train_stream/Task000 = 205938.8428
	DiskUsage_MB/train_phase/train_stream/Task000 = 205938.8428
	Loss_Epoch/train_phase/train_stream/Task000 = 3.6011
	Loss_MB/train_phase/train_stream/Task000 = 2.7669
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.1400
	Top1_Acc_MB/train_phase/train_stream/Task000 = 0.2200
100%|██████████████████████████████████████████████████████████████████| 100/100 [00:17<00:00,  5.81it/s]
Epoch 1 ended.
	DiskUsage_Epoch/train_phase/train_stream/Task000 = 206113.6230
	DiskUsage_MB/train_phase/train_stream/Task000 = 206113.6230
	Loss_Epoch/train_phase/train_stream/Task000 = 2.0657
	Loss_MB/train_phase/train_stream/Task000 = 1.5154
	Top1_Acc_Epoch/train_phase/train_stream/Task000 = 0.5086
	Top1_Acc_MB/train_phase/train_strea

100%|████████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00,  7.39it/s]
> Eval on experience 8 (Task 0) from test stream ended.
	DiskUsage_Exp/eval_phase/test_stream/Task000/Exp008 = 207629.7979
	Loss_Exp/eval_phase/test_stream/Task000/Exp008 = 6.9010
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp008 = 0.0000
-- Starting eval on experience 9 (Task 0) from test stream --
100%|████████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00,  7.05it/s]
> Eval on experience 9 (Task 0) from test stream ended.
	DiskUsage_Exp/eval_phase/test_stream/Task000/Exp009 = 207640.0049
	Loss_Exp/eval_phase/test_stream/Task000/Exp009 = 6.9569
	Top1_Acc_Exp/eval_phase/test_stream/Task000/Exp009 = 0.0000
-- >> End of eval phase << --
	ConfusionMatrix_Stream/eval_phase/test_stream = <avalanche.evaluation.metric_results.AlternativeValues object at 0x7fd8f3f80df0>
	DiskUsage_Stream/eval_phase/test_stream/Task000 = 207640.0049
	Loss_Stream/eval_pha

In [None]:
# Finish the active W&B run (presumably the one started by the WandBLogger
# above - confirm). wandb is already imported in the first cell, so this
# repeated import is harmless.
import wandb
wandb.finish()