In [1]:
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10
import torchvision
import numpy as np
import random

import torch
import torch.nn.functional as F
import cl_gym as cl

import sys
import os

# Global seed shared by every RNG below and recorded in the experiment params.
seed = 0


def _seed_everything(value: int) -> None:
    """Seed Python's `random`, NumPy's global RNG and torch (CPU and CUDA).

    Each generator is seeded exactly once; the earlier version of this cell
    repeated the NumPy and torch calls, which had no additional effect.
    """
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed_all(value)


_seed_everything(seed)

# cuDNN settings: request deterministic algorithms, turn off the autotuning
# benchmark, and disable cuDNN itself.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.enabled = False

# Cap the number of intra-op CPU threads torch uses.
torch.set_num_threads(8)

def _pick_device(preferred_index: int = 7) -> torch.device:
    """Return the torch device the experiment should run on.

    Returns ``cuda:<preferred_index>`` when CUDA is available and that GPU
    index exists, ``cuda:0`` (with a warning) when CUDA is available but the
    host has fewer GPUs, and ``cpu`` when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return torch.device('cpu')
    if torch.cuda.device_count() > preferred_index:
        return torch.device(f'cuda:{preferred_index}')

    import warnings
    warnings.warn(
        f"cuda:{preferred_index} is not present on this host "
        f"({torch.cuda.device_count()} GPU(s) visible); using cuda:0 instead."
    )
    return torch.device('cuda:0')


def make_params() -> dict:
    """Build the experiment configuration and create its output directory.

    Reads the module-level ``seed``. Besides the hyper-parameters listed
    below, the returned dict carries two derived entries:

    * ``trial_id``   - a path-like identifier built from the dataset, seed,
      epochs, learning rate, tau and alpha (plus both lambda values when
      ``lambda`` is non-zero);
    * ``output_dir`` - ``./outputs/<trial_id>``.

    Side effects: prints the output directory and creates it (including
    parents) if it does not exist yet.

    Returns:
        dict: the parameter dictionary.
    """
    from pathlib import Path

    params = {
        # dataset
        'dataset': "Drug",
        'fairness_agg': 'mean',
        # 'model': 'MLP',

        # benchmark
        'seed': seed,
        'num_tasks': 3,
        'epochs_per_task': 25,
        'per_task_examples': np.inf,
        # 'per_task_examples': 10000,
        'per_task_memory_examples': 64,
        'batch_size_train': 64,
        'batch_size_memory': 64,
        'batch_size_validation': 256,
        'tau': 1.0,

        # algorithm
        'optimizer': 'sgd',
        'learning_rate': 0.1,
        'momentum': 0.9,
        'learning_rate_decay': 1.0,
        # 'criterion': torch.nn.CrossEntropyLoss(),
        'criterion': torch.nn.BCEWithLogitsLoss(),

        # Prefer GPU 7, but do not fail on hosts that have fewer GPUs.
        'device': _pick_device(7),

        # sample selection
        'alpha': 0.00,
        'metric': "DP",
        'lambda': 0.0,
        'lambda_old': 0.0,

        # postprocessing
    }

    # Deterministic, human-readable trial id. (A random uuid4 string could be
    # used instead when every run needs its own directory.)
    trial_id = f"demo/dataset={params['dataset']}/seed={params['seed']}_epoch={params['epochs_per_task']}_lr={params['learning_rate']}_tau={params['tau']}_alpha={params['alpha']}"
    if params['lambda'] != 0:
        trial_id += f"_lmbd_{params['lambda']}_lmbdold_{params['lambda_old']}"
    params['trial_id'] = trial_id

    # The previous single-argument os.path.join() returned this string as-is.
    params['output_dir'] = f"./outputs/{trial_id}"
    print(f"output_dir={params['output_dir']}")
    Path(params['output_dir']).mkdir(parents=True, exist_ok=True)

    return params

params = make_params()

output_dir=./outputs/demo/dataset=Drug/seed=0_epoch=25_lr=0.1_tau=1.0_alpha=0.0


In [2]:
from datasets import Drug

# Build the benchmark for the configured dataset and derive the values the
# backbone needs from it.
if params['dataset'] in ["Drug"]:
    benchmark = Drug(num_tasks=params['num_tasks'],
                     per_task_memory_examples=params['per_task_memory_examples'],
                     per_task_examples=params['per_task_examples'],
                     random_class_idx=False)
    # Plain int (the former "(12)" was not a tuple); passed to the backbone as
    # input_dim. TODO confirm it matches the Drug feature count.
    input_dim = 12
    class_idx = benchmark.class_idx
    num_classes = len(class_idx)
else:
    # Fail here rather than with a NameError on `benchmark` / `input_dim`
    # in a later cell.
    raise NotImplementedError(f"Unsupported dataset: {params['dataset']!r}")

In [3]:
from trainers import FairContinualTrainer
from trainers.fair_trainer import FairContinualTrainer2
from metrics import FairMetricCollector
from metrics import MetricCollector2

from algorithms import Heuristic3
from algorithms.fairl import FaIRL
from algorithms.icarl import iCaRL
from backbones import MLP2Layers2

# --- Backbone ---------------------------------------------------------------
# MLP2Layers2 with two 256-wide hidden dimensions and one output per class,
# moved to the configured device.
backbone = MLP2Layers2(
    input_dim=input_dim,
    hidden_dim_1=256,
    hidden_dim_2=256,
    output_dim=num_classes,
    class_idx=class_idx,
    config=params,
)
backbone = backbone.to(params['device'])
# Alternative backbone:
# backbone = ResNet18Small2(
#         input_dim=input_dim,
#         output_dim=num_classes,
#         class_idx=class_idx,
#         config=params
#     ).to(params['device'])

# --- Algorithm --------------------------------------------------------------
# NOTE(review): requires_memory=True is paired with BaseContinualTrainer (not
# the memory variant) further down - confirm that this is intended.
algorithm = iCaRL(backbone, benchmark, params, requires_memory=True)
# Alternative algorithm:
# algorithm = Heuristic3(backbone, benchmark, params, requires_memory=True)

# --- Metric collection ------------------------------------------------------
# Collector callback configured to evaluate at 'epoch' intervals.
metric_manager_callback = FairMetricCollector(
    num_tasks=params['num_tasks'],
    eval_interval='epoch',
    epochs_per_task=params['epochs_per_task'],
)
# Alternative collector:
# metric_manager_callback = MetricCollector2(num_tasks=params['num_tasks'],
#                                            eval_interval='epoch',
#                                            epochs_per_task=params['epochs_per_task'])

# --- Trainer ----------------------------------------------------------------
# from trainers.baselines import BaseMemoryContinualTrainer as ContinualTrainer
from trainers.baselines import BaseContinualTrainer as ContinualTrainer

trainer = ContinualTrainer(
    algorithm,
    params,
    callbacks=[metric_manager_callback],
)
# Alternative trainer:
# trainer = FairContinualTrainer2(algorithm, params, callbacks=[metric_manager_callback])


iCaRL


In [5]:
# Pick the aggregation function for the fairness meters: np.mean for "mean",
# np.max for "max"; any other setting is rejected.
if params['fairness_agg'] not in ("mean", "max"):
    raise NotImplementedError
agg = np.mean if params['fairness_agg'] == "mean" else np.max

# Install the chosen aggregator on every collected meter that is one of the
# fairness metrics; all other meters are left untouched.
fairness_metrics = ["std", "EER", "EO", "DP"]
for metric in metric_manager_callback.meters:
    if metric not in fairness_metrics:
        continue
    metric_manager_callback.meters[metric].agg = agg


In [6]:
# Run the trainer, then print the final value of the accuracy and forgetting
# meters.
trainer.run()
for label, meter_key in (("final avg-acc", 'accuracy'),
                         ("final avg-forget", 'forgetting')):
    print(label, metric_manager_callback.meters[meter_key].compute_final())

---------------------------- Task 1 -----------------------
[1] Eval metrics for task 1 >> {'accuracy': 0.618000687049124, 'loss': 0.0, 'std': 0.04053589831672966, 'EER': -1, 'EO': [0.20526315789473681, 0.3593004769475357], 'DP': -1, 'accuracy_s0': 0.6323529411764706, 'accuracy_s1': 0.5553342816500711, 'classwise_accuracy': {0: array([ 81, 123]), 1: array([41, 71])}, 'DP_ingredients': {'class_pred_count_s0': {1: 40, 0: 22}, 'class_pred_count_s1': {1: 43, 0: 89}, 'class_pred_count': {1: 83, 0: 111}, 'count_s0': 62, 'count_s1': 132, 'count': 194}}
[2] Eval metrics for task 1 >> {'accuracy': 0.607981220657277, 'loss': 0.0, 'std': 0.05868544600938963, 'EER': -1, 'EO': [0.16954887218045112, 0.41335453100158975], 'DP': -1, 'accuracy_s0': 0.6502100840336134, 'accuracy_s1': 0.5283072546230441, 'classwise_accuracy': {0: array([ 82, 123]), 1: array([39, 71])}, 'DP_ingredients': {'class_pred_count_s0': {1: 39, 0: 23}, 'class_pred_count_s1': {0: 91, 1: 41}, 'class_pred_count': {0: 114, 1: 80}, 'co

In [8]:
# Show the raw data held by the accuracy meter (a 3x3 array in the output
# below). Presumably rows are training stages and columns are evaluated
# tasks - confirm against the meter implementation.
metric_manager_callback.meters['accuracy'].get_data()

array([[0.614, 0.   , 0.   ],
       [0.461, 0.442, 0.   ],
       [0.448, 0.213, 0.367]])

In [9]:
# Mean of the values returned by the accuracy meter's compute_overall().
np.mean(metric_manager_callback.meters['accuracy'].compute_overall())

0.4692979562547006

In [10]:
# Values returned by the 'EO' meter's compute_overall(), rounded to 3 decimals
# (presumably one value per task - confirm).
[np.round(x, 3) for x in metric_manager_callback.meters['EO'].compute_overall()]

[0.128, 0.116, 0.182]

In [11]:
# Mean of the values returned by the 'EO' meter's compute_overall().
np.mean(metric_manager_callback.meters['EO'].compute_overall())

0.1421684214233889

In [12]:
# Values returned by the 'DP' meter's compute_overall(), rounded to 3 decimals
# (presumably one value per task - confirm).
[np.round(x, 3) for x in metric_manager_callback.meters['DP'].compute_overall()]

[0.085, 0.046, 0.066]

In [13]:
# Mean of the values returned by the 'DP' meter's compute_overall().
np.mean(metric_manager_callback.meters['DP'].compute_overall())

0.0656115636874489