In [1]:
import os
import random

import numpy as np
import torch
import torch.backends.cudnn as cudnn

import learners
import dataloaders
from dataloaders.utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Seed the Python, NumPy and PyTorch random number generators with one value
# so that a fresh-kernel re-run starts from the same RNG state.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Seed every CUDA device rather than only the current one; this call is a
# silent no-op when CUDA is unavailable.
torch.cuda.manual_seed_all(seed)
# Requesting deterministic cuDNN kernels is only half of the recipe: the
# auto-tuner must be off as well, otherwise cuDNN may select a different
# algorithm from one run to the next.
cudnn.deterministic = True
cudnn.benchmark = False

In [3]:
dataset = 'CIFAR100'

# prepare dataloader
# Each known dataset name maps to (attribute of `dataloaders`, class count).
# Any other name falls back to the H5 dataset with 100 classes. The attribute
# is looked up by name so that only the selected one is resolved.
known_datasets = {
    'CIFAR10': ('iCIFAR10', 10),
    'CIFAR100': ('iCIFAR100', 100),
    'TinyIMNET': ('iTinyIMNET', 200),
}
dataset_attr, num_classes = known_datasets.get(dataset, ('H5Dataset', 100))
Dataset = getattr(dataloaders, dataset_attr)

In [4]:
# load tasks
# `class_order` is the order in which classes are dealt out to tasks and
# `class_order_logits` always stays in natural order. The shuffle only
# happens for a strictly positive seed, so seed 0 keeps the natural order.
rand_split = True
class_order = list(range(num_classes))
class_order_logits = list(range(num_classes))
if rand_split and seed > 0:
    random.shuffle(class_order)

In [5]:
# Cut the class order into consecutive tasks: the first task takes
# `first_split_size` classes and every later task takes `other_split_size`
# (the final slice may be shorter when the sizes do not divide evenly).
first_split_size = 5
other_split_size = 5
tasks, tasks_logits = [], []

p = 0
while p < num_classes:
    step = first_split_size if p == 0 else other_split_size
    stop = p + step
    tasks.append(class_order[p:stop])
    tasks_logits.append(class_order_logits[p:stop])
    p = stop

num_tasks = len(tasks)
# Task names are the 1-based task indices rendered as strings.
task_names = [str(n) for n in range(1, num_tasks + 1)]

In [6]:
# Third argument handed to TransformK when the datasets are built:
# `k` is used for the unlabeled dataset, `ky` for the labeled one.
k = 2   # original note: append transform image and buffer image
ky = 1  # original note: no appended transform for the memory buffer

# datasets and dataloaders
dataroot = 'data'
labeled_samples = 10000       # original note: images per task of the CIFAR dataset
unlabeled_task_samples = -1
l_dist = 'super'              # original note: with 'super' the tasks are resampled
ul_dist = None
validation = False
repeat = 1

# All three transforms share the dataset name and the augmentation flag;
# only the phase and the hard-augmentation switch differ.
train_aug = True
transform_args = {'dataset': dataset, 'aug': train_aug}
train_transform = dataloaders.utils.get_transform(phase='train', **transform_args)
train_transformb = dataloaders.utils.get_transform(phase='train', hard_aug=True, **transform_args)
test_transform = dataloaders.utils.get_transform(phase='test', **transform_args)

In [7]:
# Keyword arguments that are identical for all three datasets.
shared_args = dict(l_dist=l_dist, ul_dist=ul_dist, tasks=tasks, seed=seed,
                   rand_split=rand_split, validation=validation, kfolds=repeat)
# Leading positional arguments of the two training datasets.
train_args = (dataroot, dataset, labeled_samples, unlabeled_task_samples)

# Labeled training data: both TransformK slots use the standard transform.
train_dataset = Dataset(*train_args, train=True, lab=True, download=True,
                        transform=TransformK(train_transform, train_transform, ky),
                        **shared_args)
# Unlabeled training data: the second slot uses the hard-augmentation transform.
train_dataset_ul = Dataset(*train_args, train=True, lab=False, download=True,
                           transform=TransformK(train_transform, train_transformb, k),
                           **shared_args)
# Test data: no sample counts are passed and nothing is downloaded.
test_dataset = Dataset(dataroot, dataset, train=False, download=False,
                       transform=test_transform, **shared_args)

Files already downloaded and verified
Files already downloaded and verified


In [8]:
# Dump the state of the labeled training dataset right after construction.
print("lab : ", train_dataset.lab)
print("l_dist : ", train_dataset.l_dist)
print("ul_dist : ", train_dataset.ul_dist)
print("train : ", train_dataset.train)
print("course_targets : ", np.array(train_dataset.course_targets).shape)
print("targets : ", np.array(train_dataset.targets).shape)
print("num_classes : ", train_dataset.num_classes)
# NOTE(review): the recorded run emitted a warning on the `valid_ul` and
# `archive` lines, which is consistent with NumPy's ragged-nested-sequence
# deprecation (a ValueError since NumPy 1.24). Requesting dtype=object makes
# the ragged conversion explicit and keeps the reported shape.
print("valid_ul : ",  np.array(train_dataset.valid_ul, dtype=object).shape)
print("task : ", np.array(train_dataset.tasks).shape)
print("num_sample_ul : ", train_dataset.num_sample_ul)
print("archive : ", np.array(train_dataset.archive, dtype=object).shape)
# The coreset is printed twice on purpose: first its shape, then its content.
print("coreset : ", np.array(train_dataset.coreset).shape)
print("coreset : ", train_dataset.coreset)
print("class_mapping : ", len(train_dataset.class_mapping))

lab :  True
l_dist :  super
ul_dist :  super
train :  True
course_targets :  (50000,)
targets :  (50000,)
num_classes :  100
valid_ul :  (20,)
task :  (20, 5)
num_sample_ul :  -1
archive :  (20, 2, 500)
coreset :  (2, 0)
coreset :  (array([], dtype=uint8), array([], dtype=int64))
class_mapping :  101


  print("valid_ul : ",  np.array(train_dataset.valid_ul).shape)
  print("archive : ", np.array(train_dataset.archive).shape)


In [9]:
# Dump the state of the test dataset right after construction.
print("lab : ", test_dataset.lab)
print("l_dist : ", test_dataset.l_dist)
print("ul_dist : ", test_dataset.ul_dist)
print("train : ", test_dataset.train)
print("course_targets : ", np.array(test_dataset.course_targets).shape)
print("targets : ", np.array(test_dataset.targets).shape)
print("num_classes : ", test_dataset.num_classes)
# NOTE(review): the recorded run emitted a warning on this line, which is
# consistent with NumPy's ragged-nested-sequence deprecation (a ValueError
# since NumPy 1.24). dtype=object makes the ragged conversion explicit.
print("valid_ul : ",  np.array(test_dataset.valid_ul, dtype=object).shape)
print("task : ", np.array(test_dataset.tasks).shape)
print("class_mapping : ", len(test_dataset.class_mapping))

lab :  True
l_dist :  super
ul_dist :  super
train :  False
course_targets :  (10000,)
targets :  (10000,)
num_classes :  100
valid_ul :  (20,)
task :  (20, 5)
class_mapping :  101


  print("valid_ul : ",  np.array(test_dataset.valid_ul).shape)


In [10]:
# in case tasks reset
tasks = train_dataset.tasks

# Prepare the Learner (model)
workers = 8
batch_size = 64
ul_batch_size = 128
learner_config = {'num_classes': num_classes,
                    'lr': 0.1,
                    'ul_batch_size': 128,
                    'tpr': 0.05, # tpr for ood calibration of class network
                    'oodtpr': 0.05, # tpr for ood calibration of ood network
                    'momentum': 0.9,
                    'weight_decay': 5e-4,
                #   'schedule': [120, 160, 180, 200], # schedule and epoch(schedule[-1])
                    'schedule': [1, 2, 3, 4],
                    'schedule_type': 'decay',
                    'model_type': "resnet",
                    'model_name': "WideResNet_28_2_cifar",
                    'ood_model_name': 'WideResNet_DC_28_2_cifar',
                    'out_dim': 100,
                    'optimizer': 'SGD',
                    'gpuid': [0],
                    'pl_flag': True, # use pseudo-labeled ul data for DM -> ???
                    'fm_loss': True, # Use fix-match loss with classifier -> Consistency Regularization / eq.4 -> unsupervised loss
                    'weight_aux': 1.0,
                    'memory': 400,
                    'distill_loss': 'C',
                    'co': 1., # out-of-distribution confidence loss ratio
                    'FT': True, # finetune distillation -> 이거 필요한가???
                    'DW': True, # dataset balancing
                    'num_labeled_samples': labeled_samples,
                    'num_unlabeled_samples': unlabeled_task_samples,
                    'super_flag': l_dist == "super",
                    'no_unlabeled_data': False
                    }

In [11]:
# Build the DistillMatch learner from the configuration dictionary above.
learner = learners.distillmatch.DistillMatch(learner_config)

In [12]:
# Dump the learner's freshly initialised state, one "name :  value" line per
# attribute, in a fixed order.
for attr_name in ("first_task", "oodtpr", "tpr", "num_classes", "pl_flag",
                  "prob_threshold_class", "prob_threshold_ood", "fm"):
    print(attr_name + " : ", getattr(learner, attr_name))

# The classifier head sits two attributes deep, so it is printed explicitly.
classifier_head = learner.model.last
print("model.last.in_features : ", classifier_head.in_features)
print("model.last : ", classifier_head)

for attr_name in ("reset_optimizer", "dw", "dw_thresh", "last_valid_out_dim",
                  "valid_out_dim", "memory_size", "task_count", "weight_aux",
                  "schedule_type", "ft", "schedule", "distf", "tasks",
                  "past_tasks", "ood_holdout_ratio", "dc_eps_values",
                  "grad_clip", "num_deltas", "num_delta_loop"):
    print(attr_name + " : ", getattr(learner, attr_name))

first_task :  True
oodtpr :  0.05
tpr :  0.05
num_classes :  100
pl_flag :  True
prob_threshold_class :  0.0
prob_threshold_ood :  0.0
fm :  {'thresh': 0.85}
model.last.in_features :  128
model.last :  Linear(in_features=128, out_features=100, bias=True)
reset_optimizer :  True
dw :  True
dw_thresh :  10.0
last_valid_out_dim :  0
valid_out_dim :  0
memory_size :  400
task_count :  0
weight_aux :  1.0
schedule_type :  decay
ft :  True
schedule :  [0, 2, 2]
distf :  KLDivLoss()
tasks :  0
past_tasks :  []
ood_holdout_ratio :  0.5
dc_eps_values :  [0.0025, 0.005, 0.001, 0.002, 0.004, 0.08]
grad_clip :  1
num_deltas :  100
num_delta_loop :  10


In [13]:
from collections import OrderedDict
from torch.utils.data import DataLoader

In [14]:
# Result containers and output locations for the run.
acc_table = OrderedDict()
acc_table_pt = OrderedDict()
run_ood = {}

log_dir = "outputs/CIFAR100-10k/realistic/dm"
save_table = []
save_table_pc = -1 * np.ones((num_tasks,num_tasks))
pl_table = [[],[],[],[]]
temp_dir = log_dir + '/temp'
# exist_ok avoids the check-then-create race and makes re-runs idempotent.
os.makedirs(temp_dir, exist_ok=True)

# Training
# A non-positive max_task means "train on every task"; a positive value is
# capped at the number of available tasks.
max_task = -1
if max_task > 0:
    max_task = min(max_task, len(task_names))
else:
    max_task = len(task_names)

# Each of the training loaders gets half of the worker budget.
train_workers = workers // 2

for i in range(max_task):
    # set seeds: every task is re-seeded with its own index
    random.seed(i)
    np.random.seed(i)
    torch.manual_seed(i)
    torch.cuda.manual_seed(i)

    train_name = task_names[i]
    print('======================', train_name, '=======================')

    # load dataset for task
    task = tasks_logits[i]
    # Sorted, de-duplicated classes of all earlier tasks. The comprehension
    # uses its own names so it cannot be confused with the notebook-level
    # `k` and `task`.
    prev = sorted({cls for seen_task in tasks_logits[:i] for cls in seen_task})

    train_dataset.load_dataset(prev, i, train=True)
    train_dataset_ul.load_dataset(prev, i, train=True)
    out_dim_add = len(task)

    # load dataset with memory
    train_dataset.append_coreset(only=False)

    # load dataloader
    train_loader_l = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=train_workers)
    train_loader_ul = DataLoader(train_dataset_ul, batch_size=ul_batch_size, shuffle=True, drop_last=False, num_workers=train_workers)
    train_loader_ul_task = DataLoader(train_dataset_ul, batch_size=batch_size, shuffle=True, drop_last=False, num_workers=train_workers)
    train_loader = dataloaders.SSLDataLoader(train_loader_l, train_loader_ul) # return labeled data, unlabeled data

    # add valid class to classifier
    learner.add_valid_output_dim(out_dim_add) # return number of classes learned to the current task

    # Learn
    test_dataset.load_dataset(prev, i, train=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=False, num_workers=workers)

    # The trailing '/' is kept because the path is handed to the learner
    # as a plain string.
    model_save_dir = log_dir + '/models/repeat-'+str(seed+1)+'/task-'+task_names[i]+'/'
    os.makedirs(model_save_dir, exist_ok=True)

    learner.learn_batch(train_loader, train_dataset, train_dataset_ul, model_save_dir, test_loader)

Incremental class: Old valid output dimension: 0
Incremental class: New Valid output dimension: 5
Optimizer is reset!
*************************
num seen:[100. 100. 100. 100. 100.]
*************************
 * Val Acc 20.000, Total time 0.60
=> Load Done
Searching the best perturbation magnitude on in-domain data. Magnitude: [0.0025, 0.005, 0.001, 0.002, 0.004, 0.08]
Magnitude: 0.0025 loss: 0.13288981628417967
Magnitude: 0.005 loss: 0.13331765747070312
Magnitude: 0.001 loss: 0.1322540740966797
Magnitude: 0.002 loss: 0.13312376403808593
Magnitude: 0.004 loss: 0.1339483642578125
Magnitude: 0.08 loss: 0.14960356140136719
New Threshold OOD: 0.1405347958 | TPR: 0.0500
New Threshold Class: 0.1405347958 | TPR: 0.0500
k :  4
data shape :  (1, 80, 32, 32, 3)
k :  3
data shape :  (2, 80, 32, 32, 3)
k :  2
data shape :  (3, 80, 32, 32, 3)
k :  1
data shape :  (4, 80, 32, 32, 3)
k :  0
data shape :  (5, 80, 32, 32, 3)
Incremental class: Old valid output dimension: 5
Incremental class: New Valid out

  print('data shape : ', np.array(data).shape)


Incremental class: Old valid output dimension: 15
Incremental class: New Valid output dimension: 20
Optimizer is reset!
*************************
num seen:[130. 130. 130. 130. 130. 135. 135. 135. 135. 135. 135. 135. 135. 135.
 135. 500. 500. 500. 500. 500.]
*************************
 * Val Acc 15.500, Total time 0.80
=> Load Done
Searching the best perturbation magnitude on in-domain data. Magnitude: [0.0025, 0.005, 0.001, 0.002, 0.004, 0.08]
Magnitude: 0.0025 loss: 0.9298486328125
Magnitude: 0.005 loss: 0.9412433539496528
Magnitude: 0.001 loss: 1.013521728515625
Magnitude: 0.002 loss: 0.8575259060329861
Magnitude: 0.004 loss: 0.8704815673828125
Magnitude: 0.08 loss: 1.5289657931857639
New Threshold OOD: 3.3981684446 | TPR: 0.0501
New Threshold Class: 3.3981684446 | TPR: 0.0501
k :  19
data shape :  (1, 20, 32, 32, 3)
k :  18
data shape :  (2, 20, 32, 32, 3)
k :  17
data shape :  (3, 20, 32, 32, 3)
k :  16
data shape :  (4, 20, 32, 32, 3)
k :  15
data shape :  (5, 20, 32, 32, 3)
k :  1

KeyboardInterrupt: 