In [1]:
import torch
import random
import numpy as np
import torch.backends.cudnn as cudnn

import dataloaders
from dataloaders.utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed every random number generator this notebook relies on so that a
# fresh-kernel run is repeatable.
# Note: the task-split cell only shuffles the class order when seed > 0,
# so seed = 0 also means "keep the natural class order".
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# manual_seed_all seeds every CUDA device rather than only the current one;
# it is silently ignored when CUDA is not available.
torch.cuda.manual_seed_all(seed)
# Ask cuDNN for deterministic kernels and keep its benchmark autotuner off,
# since the autotuner may select different algorithms between runs.
cudnn.deterministic = True
cudnn.benchmark = False

In [3]:
dataset = 'CIFAR100'

# prepare dataloader
# Each known dataset name maps to (attribute name on `dataloaders`, class count).
# Attribute names are stored as strings so that only the selected loader is
# actually looked up on the module.
known_datasets = {
    'CIFAR10': ('iCIFAR10', 10),
    'CIFAR100': ('iCIFAR100', 100),
    'TinyIMNET': ('iTinyIMNET', 200),
}
# Any name not listed above falls back to `H5Dataset` with 100 classes.
loader_attr, num_classes = known_datasets.get(dataset, ('H5Dataset', 100))
Dataset = getattr(dataloaders, loader_attr)

In [4]:
# load tasks
rand_split = True
# Two independent identity orderings of the class ids 0..num_classes-1:
# one that may be shuffled, one that always stays in natural order.
class_order = list(range(num_classes))
class_order_logits = list(range(num_classes))
# The shuffle happens only for a strictly positive seed, so a seed of 0 keeps
# the natural order even when rand_split is True.
if rand_split and seed > 0:
    random.shuffle(class_order)

In [5]:
# Chunk sizes: the first task may use a different size than the later ones.
first_split_size = 5
other_split_size = 5

# Walk a cursor `p` over the class orderings and cut them into consecutive
# chunks; `tasks` and `tasks_logits` are sliced at identical positions.
tasks, tasks_logits = [], []
p = 0
while p < num_classes:
    # Only the chunk that starts at the beginning uses first_split_size.
    inc = first_split_size if p <= 0 else other_split_size
    chunk_end = p + inc
    tasks.append(class_order[p:chunk_end])
    tasks_logits.append(class_order_logits[p:chunk_end])
    p = chunk_end
num_tasks = len(tasks)
# Task names are the 1-based task indices as strings: '1', '2', ...
task_names = [str(idx) for idx in range(1, num_tasks + 1)]

In [6]:
# Counts passed as the third argument of TransformK when the datasets are built.
# Author's notes: k appends the transformed image and the buffer image, while
# ky does not append a transform for the memory buffer.
# NOTE(review): exact meaning depends on TransformK — confirm there.
ky = 1
k = 2

# datasets and dataloaders
dataroot = 'data'
labeled_samples = 10000  # author's note: images per task of the CIFAR dataset
unlabeled_task_samples = -1
l_dist = 'super'  # author's note: if l_dist is super, then resample task
ul_dist = None
validation = False
repeat = 1

# Build the three transforms from one shared set of keyword arguments; only
# `phase` (and `hard_aug` for the second train transform) differs between them.
train_aug = True
make_transform = dataloaders.utils.get_transform
shared_transform_args = {'dataset': dataset, 'aug': train_aug}
train_transform = make_transform(phase='train', **shared_transform_args)
train_transformb = make_transform(phase='train', hard_aug=True, **shared_transform_args)
test_transform = make_transform(phase='test', **shared_transform_args)

In [7]:
# Keyword arguments that all three dataset objects receive unchanged.
split_kwargs = dict(l_dist=l_dist, ul_dist=ul_dist, tasks=tasks, seed=seed,
                    rand_split=rand_split, validation=validation, kfolds=repeat)
# Positional arguments shared by the two training sets.
train_args = (dataroot, dataset, labeled_samples, unlabeled_task_samples)

# Labeled training set: TransformK gets the plain train transform twice, count ky.
train_dataset = Dataset(
    *train_args, train=True, lab=True, download=True,
    transform=TransformK(train_transform, train_transform, ky),
    **split_kwargs)
# Unlabeled training set: TransformK pairs the plain and hard-aug transforms, count k.
train_dataset_ul = Dataset(
    *train_args, train=True, lab=False, download=True,
    transform=TransformK(train_transform, train_transformb, k),
    **split_kwargs)
# Test set: no sample-count arguments, no download, plain test transform.
test_dataset = Dataset(
    dataroot, dataset, train=False, download=False,
    transform=test_transform,
    **split_kwargs)

Files already downloaded and verified
not rand
Files already downloaded and verified
not rand
not rand


In [9]:
# Debug loop: rebuild and print the class-id vector 0..99 once per iteration.
n_prints, n_ids = 20, 100
for t in range(n_prints):
    valid_ul = np.arange(n_ids)
    print(valid_ul)

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
 96 97 98 99]
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26

In [17]:
# Inspect the bookkeeping attributes of the labeled training set.
print(train_dataset.lab)
print(train_dataset.ul_dist)
print(train_dataset.train)
print(np.array(train_dataset.course_targets).shape)
print(np.array(train_dataset.targets).shape)
print(train_dataset.num_classes)
# The recorded output shows valid_ul as a list of sub-lists with unequal
# lengths. Converting such a ragged list without dtype=object triggers a
# deprecation warning on older NumPy and raises ValueError on NumPy >= 1.24.
# An object array keeps the same reported shape.
print(np.array(train_dataset.valid_ul, dtype=object).shape)


(20,)
[2, 11, 35, 46, 98]
[54, 62, 70, 82, 92, 0, 51, 53, 57, 83, 47, 52, 56, 59, 96]
[[12, 17, 37, 68, 76, 23, 33, 49, 60, 71], [8, 13, 48, 58, 90, 41, 69, 81, 85, 89], [3, 42, 43, 88, 97, 15, 19, 21, 31, 38, 34, 63, 64, 66, 75, 27, 29, 44, 78, 93, 36, 50, 65, 74, 80], [2, 11, 35, 46, 98], [4, 30, 55, 72, 95, 1, 32, 67, 73, 91], [54, 62, 70, 82, 92, 0, 51, 53, 57, 83, 47, 52, 56, 59, 96], [3, 42, 43, 88, 97, 15, 19, 21, 31, 38, 34, 63, 64, 66, 75, 27, 29, 44, 78, 93, 36, 50, 65, 74, 80], [54, 62, 70, 82, 92, 0, 51, 53, 57, 83, 47, 52, 56, 59, 96], [9, 10, 16, 28, 61, 22, 39, 40, 86, 87, 5, 20, 25, 84, 94], [12, 17, 37, 68, 76, 23, 33, 49, 60, 71], [9, 10, 16, 28, 61, 22, 39, 40, 86, 87, 5, 20, 25, 84, 94], [6, 7, 14, 18, 24, 26, 45, 77, 79, 99], [54, 62, 70, 82, 92, 0, 51, 53, 57, 83, 47, 52, 56, 59, 96], [8, 13, 48, 58, 90, 41, 69, 81, 85, 89], [9, 10, 16, 28, 61, 22, 39, 40, 86, 87, 5, 20, 25, 84, 94], [3, 42, 43, 88, 97, 15, 19, 21, 31, 38, 34, 63, 64, 66, 75, 27, 29, 44, 78, 93, 3

  print(np.array(train_dataset.valid_ul).shape)
