In [1]:
import torch
from torch import optim, nn
from functools import partial
from initialize import *
from classes import *
from train import *
from data_proc import *
from tqdm import tqdm
from torchinfo import summary
from torch.utils.data import Dataset, DataLoader
import wandb 


if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = "cpu"

In [2]:
loaders = create_loaders()

100%|██████████| 10/10 [00:08<00:00,  1.22it/s]
100%|██████████| 10/10 [00:07<00:00,  1.35it/s]
100%|██████████| 10/10 [00:01<00:00,  6.16it/s]


In [15]:
from collections import defaultdict

counts = defaultdict(lambda: -1, {})
for data, label in loaders["val"][5]:
    label = label 
    for l in label:
        counts[int(l)] += 1

counts

defaultdict(<function __main__.<lambda>()>,
            {50: 49,
             54: 58,
             53: 54,
             55: 38,
             58: 50,
             57: 47,
             56: 44,
             51: 50,
             59: 45,
             52: 43})

In [72]:
data_path = "/nfs/scistore23/chlgrp/avolkova/rotation1/pipeline/data/"

train_data = datasets.CIFAR100(
    root=data_path,
    train=False,
    download=False,
    transform=nn.Sequential(T.ToImage())
    )


label_to_indices = defaultdict(lambda: torch.empty(0, dtype=int), {})
for idx, (_, label) in enumerate(train_data):
    label_to_indices[label] = torch.cat((label_to_indices[label], torch.tensor([idx], dtype=int)))


val_per_label = len(train_data) // 10 // 100
label_to_indx_val = defaultdict(lambda: torch.empty(0, dtype=int), {})
label_to_indx_train = defaultdict(lambda: torch.empty(0, dtype=int), {})

for label, indx in label_to_indices.items():
    perm = torch.randperm(len(indx))
    label_to_indx_val[label] = torch.cat((label_to_indx_val[label], indx[perm[:val_per_label]]))
    label_to_indx_train[label] = torch.cat((label_to_indx_train[label], indx[perm[val_per_label:]]))

label_to_indx_val[0].shape, label_to_indx_train[0].shape

task_to_indx = defaultdict(lambda: torch.empty(0, dtype=int), {})

for label, ind in label_to_indx_val.items():
    task_id = label // 10
    task_to_indx[task_id] = torch.cat((task_to_indx[task_id], ind))

task_to_indx

defaultdict(<function __main__.<lambda>()>,
            {4: tensor([5695, 7217, 6072, 5192, 8202, 9786, 8807, 3267, 4160,  990,   14, 9569,
                      302, 6323, 7612, 8558, 3967,  794, 3850, 3794, 6448, 1433,   66, 4711,
                     9314, 8905, 4266, 7263, 8545, 5979, 1210, 2126, 6743, 5899, 4047, 7058,
                     8371, 2142, 4903, 9776, 6401, 4677, 4278, 3346, 5472, 8774, 1324, 1843,
                     2570, 6061, 1608, 5620, 2722, 9320, 6708, 5919, 1497, 4089, 8408, 6174,
                     2705,  346, 9408, 4852, 5903, 9576, 6058, 1848, 5602, 2755, 8577, 1079,
                     5546, 8237, 3619, 2605, 7732, 9875, 5382, 8485, 5759, 3883, 3047, 7094,
                     3505, 2852, 1772, 5589, 5330, 1697, 9654, 2033, 3240, 3054, 1404, 4028,
                     6954, 6382, 3732, 1023]),
             3: tensor([3944, 6444, 1817, 6308, 4712,  863, 5052, 9702, 1306,  617, 9421, 8453,
                     3693, 5577, 2023, 5505, 4809, 5740, 2562, 621

In [33]:
label_to_indices[0]

[9,
 113,
 226,
 235,
 377,
 469,
 484,
 614,
 623,
 655,
 715,
 779,
 1014,
 1027,
 1136,
 1304,
 1308,
 1373,
 1738,
 1925,
 2077,
 2239,
 2271,
 2375,
 2426,
 2488,
 2512,
 2648,
 2865,
 2960,
 2988,
 3005,
 3066,
 3413,
 3492,
 3510,
 3518,
 3697,
 3712,
 3725,
 3751,
 3752,
 3775,
 4100,
 4144,
 4289,
 4340,
 4346,
 4433,
 4640,
 4666,
 4676,
 4760,
 4923,
 4965,
 5170,
 5306,
 5370,
 5496,
 5675,
 5679,
 5731,
 5807,
 5885,
 5900,
 5913,
 5998,
 6099,
 6212,
 6253,
 6415,
 6806,
 6827,
 6872,
 6922,
 7219,
 7365,
 7638,
 7655,
 7981,
 8019,
 8045,
 8131,
 8211,
 8250,
 8302,
 8314,
 8327,
 8341,
 8537,
 8578,
 8767,
 9082,
 9097,
 9200,
 9221,
 9387,
 9849,
 9872,
 9904]

In [36]:
perm = torch.randperm(len(label_to_indices[0]))
idx = perm[:5]
samples = label_to_indices[0][idx]

TypeError: only integer tensors of a single element can be converted to an index

In [21]:
train_data[0][1]

49

In [3]:
resnet = create_model("resnet").to(device)

In [6]:
optimizer = setup_optimizer(resnet.parameters(), lr=1e-3)
scheduler = setup_scheduler(optimizer)

trainer = ExperimentTrainer(loaders,
                            resnet, 
                            optimizer,
                            scheduler,
                            nn.NLLLoss(),
                            device,
                            1
                               )

In [7]:
wandb.init(project="test")
trainer.train([0], 1)

0,1
loss,▁
train_acc_0,▁
val_acc_0,▁

0,1
loss,2.66437
train_acc_0,0.10352
val_acc_0,0.09854


100%|██████████| 1/1 [00:05<00:00,  5.77s/it]

Epoch 0: loss =  2.37, train_acc = 0.14, val_acc = 0.15
torch.Size([0, 1]) torch.Size([1, 1])
Finished training on task 0... val_accuracy on task 0 = 0.150





(tensor([2.3750], device='cuda:0'),
 tensor([[0.1424]], device='cuda:0'),
 tensor([[0.1496]], device='cuda:0'))