In [1]:
import torch
import numpy as np

import random
import copy
import os
import pickle

import cl_gym as cl
from metrics import MetricCollector2, FairMetricCollector
from configs import make_params

EER_dataset = ["MNIST", "FashionMNIST", "CIFAR10", "CIFAR100"]

In [74]:
log_path = f"scripts_output"
out_path = f"outputs"

def is_number(value):
    flag=True
    try:
        num=float(value)
        flag = num == num
    except ValueError:
        flag=False
    return flag

def load(dataset, seed, epoch, lr, tau, alpha, lmbd, method, fair_metric, randinx = False, verbose=2):
    path = f"dataset={dataset}"
    if randinx:
        path+="_randidx"
    if method != "FSW":
        fair_metric = "no_metrics"
    path = os.path.join(path, method)
    path = os.path.join(path, fair_metric)

    runs = list()
    for d in os.listdir(os.path.join(log_path, path)):
        info_dict = dict()
        info_dict['path'] = d
        for elem in d.split("_"):
            k, v = elem.split("=")
            info_dict[k] = int(v) if v.isdigit() else float(v) if is_number(v) else v
            info_dict['path'] = d
        runs.append(info_dict)

    avails = list()
    for run in runs:
        seed_flag = seed == run.get("seed", None)
        epoch_flag = epoch == run.get("epoch", None)
        lr_flag = lr == run.get("lr", None)
        tau_flag = tau == run.get("tau", 0)
        alpha_flag = alpha == run.get("alpha", 0)
        lmbd_flag = lmbd == run.get("lmbd", 0)
        flag = seed_flag & epoch_flag & lr_flag & tau_flag & alpha_flag & lmbd_flag
        if flag:
            avails.append(run)
    
    if len(avails) == 0:
        if verbose > 2:
            print(runs)
        return False
    
    path = os.path.join(path, run['path'])

    out = os.path.join(os.path.join(log_path, path), "log.out")
    err = os.path.join(os.path.join(log_path, path), "log.err")
    if os.path.exists(err):
        with open(err, "r") as f:
            lines = f.readlines()
        if len(lines):
            if verbose:
                print(f"Error in {err} - error during running")
                for line in lines:
                    print(line)
            return False
    else:
        if verbose:
            print(f"error in {os.path.join(log_path, path)} - not exists")
        return False
    if os.path.exists(out):
        with open(out, "r") as f:
            lines = f.readlines()
        if len(lines) == 0:
            if verbose:
                print(f"{os.path.join(log_path, path)} - currently running")
            return False
    
    return path

def print_log(path, option='err'):
    out = os.path.join(os.path.join(log_path, path), "log.out")
    err = os.path.join(os.path.join(log_path, path), "log.err")    
    if option == "err":
        if os.path.exists(err):
            with open(out, "r") as f:
                lines = f.readlines()
                print(lines)
    elif option == "out":
        if os.path.exists(out):
            with open(out, "r") as f:
                lines = f.readlines()
                print(lines)

def load_metrics(path, verbose=0):
    metrics_dir = os.path.join(os.path.join(out_path, path), "metrics/metrics.pickle")
    with open(metrics_dir, "rb") as f:
        metric_manager_callback = pickle.load(f)

    if verbose>0:
        print(f"{metric_manager_callback.meters.keys()}")
    return metric_manager_callback

In [83]:
def get_best(dataset, seed_range, epoch, lr_range, tau_range, alpha_range, lambda_range, \
             method, randinx = False, verbose=2):
    randinx = False
    acc_list = list()
    fair_list = list()
    info_list = list()
    disp_metric = "EO"
    if dataset in EER_dataset:
        disp_metric = "EER"
    # MNIST
    metric = disp_metric
    if method in ['joint', 'vanilla', 'finetune']:
        metric = "no_metrics"
    for tau in tau_range:
        for lr in lr_range:
            for alpha in alpha_range:
                for lmbd in lambda_range:
                    cnt = 0
                    acc_sum = 0
                    fair_sum = 0
                    # print(f"{info=}")
                    avail_seed = copy.deepcopy(seed_range)
                    for seed in seed_range:
                        path = load(dataset, seed, epoch, lr, tau, alpha, lmbd, method, metric, randinx=randinx)
                        print(f"{path=}")
                        if not path:
                            if verbose>1:
                                print("Remove")
                                print(f"{dataset=}, {randinx=}, {method=}, {metric=}")
                                print(f"{seed=}, {epoch=}, {lr=}, {tau=}, {alpha=}, {lmbd=}")
                            avail_seed.remove(seed)
                            # print(seed)
                            # print(target_dir)
                            continue
                        mmc = load_metrics(path, verbose=verbose-1)
                        
                        acc = np.mean(mmc.meters['accuracy'].compute_overall())
                        fair = np.mean(mmc.meters[disp_metric].compute_overall())
                        acc_sum+=acc
                        fair_sum+=fair
                        cnt+=1
                    info = f"lr={lr}_tau={tau}_alpha_{alpha}({cnt=}, {avail_seed})"
                    if cnt == 0:
                        print(f"{info=}: check if boom?")
                        continue
                    if not cnt:
                        info_list.append(info)
                        acc_list.append(acc_sum/cnt)
                        fair_list.append(fair_sum/cnt)
                        out = f"{info}\n{acc_sum/cnt}\n{fair_sum/cnt}"
                        if verbose:
                            print(out)
    if verbose:
        print()

    # 대충 정한 measure
    integrated_score = [e - 2*fair_list[i] for i, e in enumerate(acc_list)]
    idx = integrated_score.index(max(integrated_score))
    print(f"{info_list[idx]}")
    accuracy = acc_list[idx]
    fairness = fair_list[idx]

    if accuracy > 1:
        accuracy /= 100
    print(f"acc:{accuracy}")
    print(f"fair:{fairness}")
    return info_list, acc_list, fair_list


In [84]:
import matplotlib.pyplot as plt

def plot(acc_list, fair_list, marker = "o"):
    plt.scatter(acc_list, fair_list, marker=marker, s = 10)
    plt.xlabel('acc')
    plt.ylabel('fairness')

In [85]:
def display(dataset, epoch, size = 20):
    print("joint")
    tau_range = [0.0]
    # lr_range = [0.01, 0.001]
    lr_range = [0.01]
    alpha_range = [0.0]
    lambda_range = [0.0]
    seed_range = [0, 1, 2, 3, 4]

    joint_info_list, joint_acc_list, joint_fair_list = get_best(dataset, seed_range, epoch, lr_range, tau_range, alpha_range, lambda_range, "joint")
    plt.scatter(joint_acc_list, joint_fair_list, marker='o', s = size)

    print("finetune")
    tau_range = [0.0]
    alpha_range = [0.0]
    lambda_range = [0.0]
    seed_range = [0, 1, 2, 3, 4]
    finetune_info_list, finetune_acc_list, finetune_fair_list = get_best(dataset, seed_range, epoch, lr_range, tau_range, alpha_range, lambda_range, "finetune")
    # plt.scatter(finetune_acc_list, finetune_fair_list, marker='x', s = size)


    print("vanilla")
    tau_range = [1.0, 5.0, 10.0]
    alpha_range = [0.0]
    lambda_range = [1.0, 5.0, 10.0]
    seed_range = [0, 1, 2, 3, 4]

    base_info_list, base_acc_list, base_fair_list = get_best(dataset, seed_range, epoch, lr_range, tau_range, alpha_range, lambda_range, "vanilla")
    plt.scatter(base_acc_list, base_fair_list, marker='v', s = size)

    print("FSW")
    tau_range = [1.0, 5.0, 10.0]
    alpha_range = [0.0005, 0.001, 0.002, 0.005, 0.01, 0.02]
    lambda_range = [1.0, 5.0, 10.0]
    seed_range = [0, 1, 2, 3, 4]

    fss_info_list, fss_acc_list, fss_fair_list = get_best(dataset, seed_range, epoch, lr_range, tau_range, alpha_range, lambda_range, "FSW")
    for i, e in enumerate(fss_info_list):
        print(e, fss_acc_list[i], fss_fair_list[i])
    plt.scatter(fss_acc_list, fss_fair_list, marker='x', s = size)

    target_baseline = [e - 2*base_fair_list[i] for i, e in enumerate(base_acc_list)]
    idx = target_baseline.index(max(target_baseline))
    base_fair_max = base_fair_list[idx]
    base_acc_min = base_acc_list[idx]
    plt.xlim([base_acc_min-5, 100])
    if dataset != "BiasedMNIST":
        plt.ylim([0, base_fair_max+5])
    else:
        plt.ylim([0, base_fair_max+0.15])

    plt.xlabel('acc')
    if dataset in EER_dataset:
        plt.ylabel('std')
    else:
        plt.ylabel('eo')

In [86]:
dataset = "MNIST"
epoch = 5
display(dataset, epoch)

joint
[0, 1, 2, 3, 4]
path='dataset=MNIST/joint/no_metrics/seed=4_epoch=5_lr=0.01'
dict_keys(['accuracy', 'std', 'forgetting', 'loss'])


KeyError: 'EER'

In [None]:
dataset = "MNIST"
epoch = 5
display(dataset, epoch)

In [7]:
dataset = "MNIST"
epoch = 15
display(dataset, epoch)

joint
info='tau=0.0/lr_0.01/alpha_0.0(cnt=0, [])': check if boom?


ValueError: max() arg is an empty sequence

In [None]:
print("gss")
dataset = "MNIST"
epoch = 1
tau_range = [1.0, 5.0, 10.0]
alpha_range = [0.0]
lr_range = [0.01, 0.001]
lambda_range = [0.0]
seed_range = [0, 1, 2]

base_info_list, base_acc_list, base_fair_list = get_best(dataset, seed_range, epoch, lr_range, tau_range, alpha_range, lambda_range)
plt.scatter(base_acc_list, base_fair_list, marker='v', s = 10)


In [None]:
tau_range = [1.0, 5.0, 10.0]
lr_range = [0.01, 0.001]
alpha_range = [0.0005, 0.001, 0.002, 0.005, 0.01, 0.02]
lambda_range = [1.0, 5.0, 10.0]
seed_range = [0, 1, 2, 3, 4]


info_list, acc_list, fair_list = get_best(dataset, seed_range, epoch, lr_range, tau_range, alpha_range, lambda_range)
plot(acc_list, fair_list)

In [None]:
tau_range = [1.0, 5.0, 10.0]
lr_range = [0.01, 0.001]
alpha_range = [0.0]
lambda_range = [1.0, 5.0, 10.0]
seed_range = [0, 1, 2, 3, 4]

info_list, acc_list, fair_list = get_best(dataset, seed_range, epoch, lr_range, tau_range, alpha_range, lambda_range)
plot(acc_list, fair_list)

In [None]:
tau_range = [0.0]
lr_range = [0.01, 0.001]
alpha_range = [0.0]
seed_range = [0, 1, 2, 3, 4]

info_list, acc_list, fair_list = get_best(dataset, seed_range, epoch, lr_range, tau_range, alpha_range, lambda_range, joint=True)
plot(acc_list, fair_list)

In [None]:
tau_range = [1.0, 5.0, 10.0]
lr_range = [0.01, 0.001]
alpha_range = [0.0005, 0.001, 0.002, 0.005, 0.01, 0.02]
lambda_range = [1.0, 5.0, 10.0]
seed_range = [0, 1, 2, 3, 4]

dataset = "FashionMNIST"
epoch = 1
lmbd = 1.0

info_list, acc_list, fair_list = get_best(dataset, seed_range, epoch, lr_range, tau_range, alpha_range, lambda_range)

In [None]:
tau_range = [1.0, 5.0, 10.0]
lr_range = [0.01, 0.001]
alpha_range = [0.0]
lambda_range = [1.0, 5.0, 10.0]
seed_range = [0, 1, 2, 3, 4]

dataset = "FashionMNIST"
epoch = 1
lmbd = 1.0

info_list, acc_list, fair_list = get_best(dataset, seed_range, epoch, lr_range, tau_range, alpha_range, lambda_range)