In [1]:
import torch
from torch import optim, nn
from functools import partial
from pathlib import Path
from initialize import *
from classes import *
from data_proc import *
from tqdm import tqdm
import timeit

In [2]:
# Select the compute device: the default CUDA GPU when one is available, otherwise the CPU.
# `device` is always a torch.device so later `.to(device)` calls see a consistent type.
if torch.cuda.is_available():
    device = torch.device("cuda")
    # Only query the GPU name inside this branch: torch.cuda.get_device_name raises on
    # machines without CUDA, which previously made the CPU fallback unreachable.
    print("Device:", device, torch.cuda.get_device_name(0))
else:
    device = torch.device("cpu")
    print("Device:", device)

Device: cuda NVIDIA GeForce RTX 3090


In [3]:
# Build the data loaders that ExperimentTrainer consumes below. Replay is switched on here;
# the results are later saved under a `replay_vit` folder, so keep the two in sync.
USE_REPLAY = True
dataloaders = create_loaders(replay=USE_REPLAY)

100%|██████████| 10/10 [00:35<00:00,  3.57s/it]
100%|██████████| 10/10 [00:08<00:00,  1.14it/s]
100%|██████████| 10/10 [00:09<00:00,  1.05it/s]


In [4]:
# Build the ViT model and move it to the device selected above.
model = create_model("vit").to(device)

# torch.compile's default GPU backend (Inductor -> Triton) only supports CUDA compute
# capability >= 7.0, so pre-Volta cards (e.g. Tesla P100 = 6.0, GTX 1080 Ti = 6.1) must run
# the eager model. Checking the capability instead of a list of card names covers every such
# GPU. CUDA availability is checked first because torch.cuda.get_device_* raises on
# CPU-only machines.
if torch.cuda.is_available() and torch.cuda.get_device_capability(0) >= (7, 0):
    model = torch.compile(model)

In [5]:
# Optimizer and LR scheduler come from the project's setup helpers.
# The optimizer is named `optimizer` rather than `optim` so it does not shadow the
# `torch.optim` module imported in the first cell (re-running cells would otherwise see the
# optimizer instance where the module is expected).
optimizer = setup_optimizer(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = setup_scheduler(optimizer)

trainer = ExperimentTrainer(
    dataloaders,
    model,
    optimizer,
    scheduler,
    # NOTE(review): nn.NLLLoss expects log-probabilities, so this assumes the "vit" model's
    # output ends in log_softmax -- TODO confirm; with raw logits use nn.CrossEntropyLoss.
    nn.NLLLoss(),
    device,
    # NOTE(review): the meaning of this positional `10` is defined by ExperimentTrainer
    # (possibly the number of tasks) -- TODO confirm and pass it by keyword.
    10,
)

In [None]:
# Run the class-incremental training loop. The trainer trains each task for this many epochs
# (each task's progress bar in the output runs Epoch 0..14) and returns the loss, train
# accuracy and validation accuracy histories that are saved in the next cell.
EPOCHS_PER_TASK = 15
training_history = trainer.train_class_inc(EPOCHS_PER_TASK)
loss, acc_train, acc_val = training_history

  7%|▋         | 1/15 [09:40<2:15:27, 580.55s/it]

Epoch 0: loss =  2.64, train_acc = 0.22, val_acc = 0.32


 13%|█▎        | 2/15 [14:52<1:31:33, 422.59s/it]

Epoch 1: loss =  2.03, train_acc = 0.28, val_acc = 0.41


 20%|██        | 3/15 [20:04<1:14:24, 372.02s/it]

Epoch 2: loss =  1.55, train_acc = 0.34, val_acc = 0.52


 27%|██▋       | 4/15 [25:16<1:03:53, 348.45s/it]

Epoch 3: loss =  1.93, train_acc = 0.39, val_acc = 0.57


 33%|███▎      | 5/15 [30:28<55:53, 335.34s/it]  

Epoch 4: loss =  2.13, train_acc = 0.41, val_acc = 0.54


 40%|████      | 6/15 [35:40<49:05, 327.32s/it]

Epoch 5: loss =  1.30, train_acc = 0.44, val_acc = 0.62


 47%|████▋     | 7/15 [40:52<42:58, 322.32s/it]

Epoch 6: loss =  1.77, train_acc = 0.47, val_acc = 0.62


 53%|█████▎    | 8/15 [46:04<37:13, 319.02s/it]

Epoch 7: loss =  1.71, train_acc = 0.49, val_acc = 0.66


 60%|██████    | 9/15 [51:15<31:38, 316.36s/it]

Epoch 8: loss =  1.79, train_acc = 0.50, val_acc = 0.68


 67%|██████▋   | 10/15 [56:25<26:12, 314.53s/it]

Epoch 9: loss =  1.55, train_acc = 0.50, val_acc = 0.65


 73%|███████▎  | 11/15 [1:01:35<20:52, 313.17s/it]

Epoch 10: loss =  1.16, train_acc = 0.52, val_acc = 0.66


 80%|████████  | 12/15 [1:06:46<15:37, 312.38s/it]

Epoch 11: loss =  1.61, train_acc = 0.54, val_acc = 0.71


 87%|████████▋ | 13/15 [1:11:57<10:23, 311.96s/it]

Epoch 12: loss =  1.71, train_acc = 0.56, val_acc = 0.70


 93%|█████████▎| 14/15 [1:17:06<05:11, 311.18s/it]

Epoch 13: loss =  2.35, train_acc = 0.57, val_acc = 0.70


100%|██████████| 15/15 [1:22:18<00:00, 329.25s/it]


Epoch 14: loss =  1.73, train_acc = 0.58, val_acc = 0.72
Finished training on task 0... val_accuracy on task 0 = 0.716


  7%|▋         | 1/15 [05:55<1:22:56, 355.46s/it]

Epoch 0: loss =  2.35, train_acc = 0.35, val_acc = 0.47


 13%|█▎        | 2/15 [11:20<1:13:08, 337.58s/it]

Epoch 1: loss =  1.78, train_acc = 0.38, val_acc = 0.51


 20%|██        | 3/15 [16:56<1:07:20, 336.71s/it]

Epoch 2: loss =  2.32, train_acc = 0.41, val_acc = 0.56


 27%|██▋       | 4/15 [22:20<1:00:49, 331.75s/it]

Epoch 3: loss =  1.77, train_acc = 0.42, val_acc = 0.57


 33%|███▎      | 5/15 [31:21<1:07:53, 407.35s/it]

Epoch 4: loss =  2.13, train_acc = 0.46, val_acc = 0.61


 40%|████      | 6/15 [40:42<1:08:55, 459.55s/it]

Epoch 5: loss =  1.46, train_acc = 0.47, val_acc = 0.61


 47%|████▋     | 7/15 [50:04<1:05:44, 493.10s/it]

Epoch 6: loss =  1.34, train_acc = 0.49, val_acc = 0.66


 53%|█████▎    | 8/15 [59:27<1:00:06, 515.21s/it]

Epoch 7: loss =  1.42, train_acc = 0.50, val_acc = 0.65


 60%|██████    | 9/15 [1:08:49<52:58, 529.80s/it]

Epoch 8: loss =  1.12, train_acc = 0.51, val_acc = 0.67


 67%|██████▋   | 10/15 [1:18:10<44:58, 539.67s/it]

Epoch 9: loss =  1.67, train_acc = 0.53, val_acc = 0.69


 73%|███████▎  | 11/15 [1:27:31<36:24, 546.02s/it]

Epoch 10: loss =  1.37, train_acc = 0.53, val_acc = 0.68


 80%|████████  | 12/15 [1:36:53<27:32, 550.97s/it]

Epoch 11: loss =  1.78, train_acc = 0.55, val_acc = 0.72


 87%|████████▋ | 13/15 [1:46:14<18:28, 554.00s/it]

Epoch 12: loss =  1.87, train_acc = 0.56, val_acc = 0.72


 93%|█████████▎| 14/15 [1:55:36<09:16, 556.32s/it]

Epoch 13: loss =  1.37, train_acc = 0.56, val_acc = 0.72


100%|██████████| 15/15 [2:03:25<00:00, 493.67s/it]


Epoch 14: loss =  1.49, train_acc = 0.57, val_acc = 0.72
Finished training on task 1... val_accuracy on task 0 = 0.773


  7%|▋         | 1/15 [05:42<1:19:57, 342.65s/it]

Epoch 0: loss =  2.04, train_acc = 0.48, val_acc = 0.63


 13%|█▎        | 2/15 [11:24<1:14:11, 342.41s/it]

Epoch 1: loss =  1.84, train_acc = 0.49, val_acc = 0.63


 20%|██        | 3/15 [17:07<1:08:32, 342.71s/it]

Epoch 2: loss =  1.78, train_acc = 0.51, val_acc = 0.67


 27%|██▋       | 4/15 [22:50<1:02:50, 342.74s/it]

Epoch 3: loss =  1.16, train_acc = 0.52, val_acc = 0.67


 33%|███▎      | 5/15 [28:34<57:09, 342.94s/it]  

Epoch 4: loss =  1.77, train_acc = 0.54, val_acc = 0.68


 40%|████      | 6/15 [34:17<51:29, 343.23s/it]

Epoch 5: loss =  2.00, train_acc = 0.55, val_acc = 0.70


 47%|████▋     | 7/15 [40:01<45:47, 343.43s/it]

Epoch 6: loss =  1.47, train_acc = 0.56, val_acc = 0.71


 53%|█████▎    | 8/15 [45:45<40:04, 343.55s/it]

Epoch 7: loss =  1.64, train_acc = 0.57, val_acc = 0.71


 60%|██████    | 9/15 [51:28<34:21, 343.51s/it]

Epoch 8: loss =  1.80, train_acc = 0.58, val_acc = 0.72


 67%|██████▋   | 10/15 [57:12<28:37, 343.59s/it]

Epoch 9: loss =  1.53, train_acc = 0.59, val_acc = 0.74


 73%|███████▎  | 11/15 [1:02:56<22:54, 343.60s/it]

Epoch 10: loss =  1.91, train_acc = 0.59, val_acc = 0.73


 80%|████████  | 12/15 [1:08:39<17:10, 343.54s/it]

Epoch 11: loss =  1.34, train_acc = 0.60, val_acc = 0.73


 87%|████████▋ | 13/15 [1:14:23<11:27, 343.52s/it]

Epoch 12: loss =  1.63, train_acc = 0.61, val_acc = 0.75


 93%|█████████▎| 14/15 [1:20:06<05:43, 343.54s/it]

Epoch 13: loss =  1.57, train_acc = 0.62, val_acc = 0.74


100%|██████████| 15/15 [1:25:50<00:00, 343.34s/it]


Epoch 14: loss =  1.48, train_acc = 0.62, val_acc = 0.75
Finished training on task 2... val_accuracy on task 0 = 0.764


  7%|▋         | 1/15 [05:59<1:23:56, 359.75s/it]

Epoch 0: loss =  1.98, train_acc = 0.52, val_acc = 0.63


 13%|█▎        | 2/15 [12:00<1:18:05, 360.42s/it]

Epoch 1: loss =  1.69, train_acc = 0.54, val_acc = 0.67


 20%|██        | 3/15 [18:00<1:12:03, 360.33s/it]

Epoch 2: loss =  1.98, train_acc = 0.55, val_acc = 0.69


 27%|██▋       | 4/15 [24:01<1:06:02, 360.27s/it]

Epoch 3: loss =  1.59, train_acc = 0.57, val_acc = 0.70


 33%|███▎      | 5/15 [30:01<1:00:04, 360.46s/it]

Epoch 4: loss =  1.61, train_acc = 0.58, val_acc = 0.69


 40%|████      | 6/15 [36:03<54:07, 360.84s/it]  

Epoch 5: loss =  1.79, train_acc = 0.62, val_acc = 0.72


 47%|████▋     | 7/15 [42:04<48:07, 360.94s/it]

Epoch 6: loss =  1.22, train_acc = 0.62, val_acc = 0.73


 53%|█████▎    | 8/15 [48:05<42:05, 360.80s/it]

Epoch 7: loss =  1.31, train_acc = 0.63, val_acc = 0.74


 60%|██████    | 9/15 [54:05<36:04, 360.73s/it]

Epoch 8: loss =  1.39, train_acc = 0.64, val_acc = 0.74


 67%|██████▋   | 10/15 [1:00:06<30:04, 360.88s/it]

Epoch 9: loss =  1.06, train_acc = 0.65, val_acc = 0.75


 73%|███████▎  | 11/15 [1:06:07<24:03, 360.79s/it]

Epoch 10: loss =  1.51, train_acc = 0.64, val_acc = 0.74


 80%|████████  | 12/15 [1:12:07<18:01, 360.49s/it]

Epoch 11: loss =  1.32, train_acc = 0.65, val_acc = 0.75


 87%|████████▋ | 13/15 [1:18:07<12:00, 360.35s/it]

Epoch 12: loss =  1.56, train_acc = 0.66, val_acc = 0.76


 93%|█████████▎| 14/15 [1:24:07<06:00, 360.23s/it]

Epoch 13: loss =  1.08, train_acc = 0.66, val_acc = 0.76


100%|██████████| 15/15 [1:30:06<00:00, 360.44s/it]


Epoch 14: loss =  1.29, train_acc = 0.67, val_acc = 0.76
Finished training on task 3... val_accuracy on task 0 = 0.797


  7%|▋         | 1/15 [06:16<1:27:49, 376.43s/it]

Epoch 0: loss =  1.55, train_acc = 0.58, val_acc = 0.69


 13%|█▎        | 2/15 [12:32<1:21:31, 376.26s/it]

Epoch 1: loss =  1.62, train_acc = 0.61, val_acc = 0.70


 20%|██        | 3/15 [18:48<1:15:15, 376.32s/it]

Epoch 2: loss =  0.98, train_acc = 0.62, val_acc = 0.71


 27%|██▋       | 4/15 [25:05<1:08:58, 376.24s/it]

Epoch 3: loss =  1.25, train_acc = 0.62, val_acc = 0.72


 33%|███▎      | 5/15 [31:21<1:02:43, 376.31s/it]

Epoch 4: loss =  2.20, train_acc = 0.63, val_acc = 0.72


 40%|████      | 6/15 [37:38<56:27, 376.39s/it]  

Epoch 5: loss =  1.39, train_acc = 0.63, val_acc = 0.73


 47%|████▋     | 7/15 [43:53<50:09, 376.17s/it]

Epoch 6: loss =  1.63, train_acc = 0.65, val_acc = 0.74


 53%|█████▎    | 8/15 [50:09<43:53, 376.16s/it]

Epoch 7: loss =  1.25, train_acc = 0.65, val_acc = 0.73


 60%|██████    | 9/15 [56:26<37:37, 376.23s/it]

Epoch 8: loss =  2.00, train_acc = 0.65, val_acc = 0.73


 67%|██████▋   | 10/15 [1:02:44<31:23, 376.74s/it]

Epoch 9: loss =  1.32, train_acc = 0.66, val_acc = 0.74


 73%|███████▎  | 11/15 [1:09:02<25:09, 377.32s/it]

Epoch 10: loss =  2.19, train_acc = 0.66, val_acc = 0.75


 80%|████████  | 12/15 [1:15:22<18:53, 377.92s/it]

Epoch 11: loss =  1.75, train_acc = 0.67, val_acc = 0.74


 87%|████████▋ | 13/15 [1:21:41<12:36, 378.38s/it]

Epoch 12: loss =  2.19, train_acc = 0.67, val_acc = 0.75


 93%|█████████▎| 14/15 [1:28:00<06:18, 378.59s/it]

Epoch 13: loss =  1.97, train_acc = 0.68, val_acc = 0.75


100%|██████████| 15/15 [1:34:19<00:00, 377.31s/it]


Epoch 14: loss =  2.07, train_acc = 0.68, val_acc = 0.75
Finished training on task 4... val_accuracy on task 0 = 0.801


  7%|▋         | 1/15 [06:36<1:32:25, 396.08s/it]

Epoch 0: loss =  1.52, train_acc = 0.62, val_acc = 0.71


 13%|█▎        | 2/15 [13:12<1:25:49, 396.14s/it]

Epoch 1: loss =  1.61, train_acc = 0.63, val_acc = 0.72


 20%|██        | 3/15 [19:48<1:19:11, 395.98s/it]

Epoch 2: loss =  1.14, train_acc = 0.64, val_acc = 0.71


 27%|██▋       | 4/15 [26:23<1:12:35, 395.96s/it]

Epoch 3: loss =  1.14, train_acc = 0.65, val_acc = 0.73


 33%|███▎      | 5/15 [33:00<1:06:00, 396.04s/it]

Epoch 4: loss =  1.06, train_acc = 0.66, val_acc = 0.73


 40%|████      | 6/15 [39:36<59:24, 396.09s/it]  

Epoch 5: loss =  2.08, train_acc = 0.66, val_acc = 0.74


 47%|████▋     | 7/15 [46:12<52:47, 395.98s/it]

Epoch 6: loss =  1.74, train_acc = 0.66, val_acc = 0.74


 53%|█████▎    | 8/15 [52:48<46:12, 396.00s/it]

Epoch 7: loss =  2.45, train_acc = 0.67, val_acc = 0.74


 60%|██████    | 9/15 [59:24<39:36, 396.00s/it]

Epoch 8: loss =  0.87, train_acc = 0.68, val_acc = 0.74


 67%|██████▋   | 10/15 [1:05:59<32:59, 395.94s/it]

Epoch 9: loss =  1.23, train_acc = 0.68, val_acc = 0.73


 73%|███████▎  | 11/15 [1:12:35<26:23, 395.88s/it]

Epoch 10: loss =  1.09, train_acc = 0.69, val_acc = 0.75


 80%|████████  | 12/15 [1:19:11<19:47, 395.94s/it]

Epoch 11: loss =  1.10, train_acc = 0.69, val_acc = 0.75


 87%|████████▋ | 13/15 [1:25:47<13:11, 395.99s/it]

Epoch 12: loss =  1.18, train_acc = 0.69, val_acc = 0.75


 93%|█████████▎| 14/15 [1:32:23<06:35, 395.94s/it]

Epoch 13: loss =  0.76, train_acc = 0.69, val_acc = 0.76


100%|██████████| 15/15 [1:38:59<00:00, 395.97s/it]


Epoch 14: loss =  2.13, train_acc = 0.70, val_acc = 0.75
Finished training on task 5... val_accuracy on task 0 = 0.796


  7%|▋         | 1/15 [06:52<1:36:16, 412.61s/it]

Epoch 0: loss =  1.60, train_acc = 0.64, val_acc = 0.70


 13%|█▎        | 2/15 [13:45<1:29:24, 412.63s/it]

Epoch 1: loss =  1.26, train_acc = 0.66, val_acc = 0.71


 20%|██        | 3/15 [20:37<1:22:29, 412.46s/it]

Epoch 2: loss =  1.48, train_acc = 0.66, val_acc = 0.72


 27%|██▋       | 4/15 [27:30<1:15:37, 412.51s/it]

Epoch 3: loss =  1.22, train_acc = 0.67, val_acc = 0.73


 33%|███▎      | 5/15 [34:22<1:08:45, 412.55s/it]

Epoch 4: loss =  1.16, train_acc = 0.68, val_acc = 0.73


 40%|████      | 6/15 [41:14<1:01:51, 412.41s/it]

Epoch 5: loss =  1.34, train_acc = 0.68, val_acc = 0.73


 47%|████▋     | 7/15 [48:07<54:59, 412.40s/it]  

Epoch 6: loss =  0.80, train_acc = 0.68, val_acc = 0.73


 53%|█████▎    | 8/15 [54:59<48:07, 412.48s/it]

Epoch 7: loss =  1.68, train_acc = 0.69, val_acc = 0.74


 60%|██████    | 9/15 [1:01:52<41:14, 412.44s/it]

Epoch 8: loss =  1.56, train_acc = 0.70, val_acc = 0.74


 67%|██████▋   | 10/15 [1:08:42<34:18, 411.71s/it]

Epoch 9: loss =  1.62, train_acc = 0.71, val_acc = 0.74


 73%|███████▎  | 11/15 [1:15:29<27:21, 410.31s/it]

Epoch 10: loss =  1.22, train_acc = 0.72, val_acc = 0.75


 80%|████████  | 12/15 [1:22:16<20:28, 409.35s/it]

Epoch 11: loss =  1.25, train_acc = 0.73, val_acc = 0.75


 87%|████████▋ | 13/15 [1:29:03<13:37, 408.55s/it]

Epoch 12: loss =  0.96, train_acc = 0.73, val_acc = 0.75


 93%|█████████▎| 14/15 [1:35:50<06:48, 408.05s/it]

Epoch 13: loss =  0.82, train_acc = 0.74, val_acc = 0.76


100%|██████████| 15/15 [1:42:37<00:00, 410.48s/it]


Epoch 14: loss =  1.04, train_acc = 0.74, val_acc = 0.76
Finished training on task 6... val_accuracy on task 0 = 0.782


  7%|▋         | 1/15 [07:03<1:38:47, 423.40s/it]

Epoch 0: loss =  1.80, train_acc = 0.66, val_acc = 0.69


 13%|█▎        | 2/15 [14:07<1:31:46, 423.61s/it]

Epoch 1: loss =  1.60, train_acc = 0.68, val_acc = 0.71


 20%|██        | 3/15 [21:11<1:24:45, 423.79s/it]

Epoch 2: loss =  1.45, train_acc = 0.69, val_acc = 0.72


In [None]:
# Persist the histories returned by train_class_inc so they can be analysed without re-running
# the multi-hour training cell.
path = Path("../results/notebook/replay_vit")
# torch.save does not create missing parent directories; create the output folder first so a
# fresh checkout does not lose the results at the very end of training.
path.mkdir(parents=True, exist_ok=True)
# Joining with `/` on a Path avoids the `replay_vit//loss.pt` paths the string version built.
torch.save(loss, path / "loss.pt")
torch.save(acc_train, path / "acc_train.pt")
torch.save(acc_val, path / "acc_val.pt")