In [1]:
import yaml
import os

# Parent of the current working directory. create_config_local uses it as the
# root for the local cache/, checkpoints/, explanations/ and test_output/ paths.
dualview_path = os.path.abspath(os.pardir)

# Local dataset location, written into the "local" configs as data_root.
# The "~" is stored verbatim (nothing here calls os.path.expanduser).
# NOTE(review): presumably the consumer of the YAML expands it — confirm.
data_path = "~/Documents/Code/Datasets"

In [2]:
# Settings copied into every generated config before anything else is added.
base_config = dict(validation_size=2000)

# Per-dataset training hyper-parameters. The generation loop merges the entry
# matching the dataset name into a config for the metrics that need them.
training_dict_MNIST = dict(
    epochs=200,
    loss="cross_entropy",
    lr=5e-3/3.,
    momentum=0.9,
    optimizer="sgd",
    scheduler="constant",
    weight_decay=0,
    augmentation="crop",
)

training_dict_CIFAR = dict(
    epochs=190,
    loss="cross_entropy",
    lr=1e-3/3.,
    momentum=0.9,
    optimizer="sgd",
    scheduler="constant",
    weight_decay=0.01,
    augmentation="crop_flip",
)

training_dict_AWA = dict(
    epochs=30,
    loss="cross_entropy",
    lr=0.001,
    momentum=0.9,
    optimizer="sgd",
    scheduler="constant",
    weight_decay=0,
    augmentation="flip",
)

# Same keys as the dictionaries above, every value None. Merging this into a
# config marks the training keys as unset; save_config drops None-valued keys
# before writing, so they never reach the YAML file.
training_dict_empty = dict.fromkeys(training_dict_MNIST)

In [3]:
# Metric name -> directory segment used when building config['xpl_root'].
# Metrics that are not listed fall back to "std" at the lookup sites.
dstype_dict = {
    "std": "std",
    "stdk": "std",
    "group": "group",
    "groupk": "group",
    "mark": "mark",
    "switched": "switched_one_file",
}


def save_config(config, config_name, config_type):
    """Write ``config`` to ``<config_type>/evaluate/<dataset_name>/<config_name>.yaml``.

    Entries whose value is ``None`` are left out of the file. The caller's
    dictionary is not modified. The target directory is created when missing,
    and the path is relative to the current working directory.

    Args:
        config: Mapping of config keys to values; must contain a non-None
            ``'dataset_name'``.
        config_name: File name without the ``.yaml`` extension.
        config_type: Top-level output folder (the callers in this notebook
            pass ``"local"`` or ``"cluster"``).

    Raises:
        KeyError: If ``'dataset_name'`` is absent or ``None``.
    """
    populated = dict(entry for entry in config.items() if entry[1] is not None)
    target_dir = f"{config_type}/evaluate/{populated['dataset_name']}"
    os.makedirs(target_dir, exist_ok=True)
    with open(f"{target_dir}/{config_name}.yaml", "w") as handle:
        # Block style: one "key: value" pair per line.
        yaml.dump(populated, handle, default_flow_style=False)

def _cache_dir_id(config):
    """Return the cache sub-directory name for the configured attribution method.

    "graddot" always uses the "tracin" directory. Every other method uses its
    own name, suffixed with the C value (with '-' replaced by '_') when
    ``config`` carries a non-None ``'C'``. A missing ``'C'`` key is treated
    like ``None``.
    """
    method = config['xai_method']
    if method == "graddot":
        return "tracin"
    c_value = config.get('C')
    if c_value is None:
        return method
    return f"{method}_{str(c_value).replace('-','_')}"


def _fill_environment(config, device, data_root, save_dir, storage_root):
    """Add the device and all environment-dependent paths to ``config`` in place.

    ``storage_root`` is the directory that contains the ``cache``,
    ``checkpoints`` and ``explanations`` trees. Required keys are read before
    anything is written, so a missing key leaves ``config`` untouched.
    """
    dataset = config['dataset_name']
    dataset_type = config['dataset_type']
    metric = config['metric']
    method = config['xai_method']
    model = config['model_name']
    dir_id = _cache_dir_id(config)
    cache_root = f"{storage_root}/cache/{dataset}"

    config['device'] = device
    config['data_root'] = data_root
    config['save_dir'] = save_dir
    config['grad_dir'] = f"{cache_root}/{dataset_type}/tracin/best"
    config['features_dir'] = f"{cache_root}/{dataset_type}/features"
    config['lds_cache_dir'] = f"{cache_root}/{metric}"
    config['cache_dir'] = f"{cache_root}/{dataset_type}/{dir_id}"
    config['model_path'] = f"{storage_root}/checkpoints/{dataset}/{dataset_type}/{dataset}_{model}_best"

    # The explanation tree is keyed by dstype_dict (metric -> folder, default
    # "std"), not by config['dataset_type'].
    xpl_dstype = dstype_dict.get(metric, "std")
    # Only dualview explanations are split per C value; every other method
    # (including graddot) uses its plain method name here.
    xpl_root_id = dir_id if method == "dualview" else method
    config['xpl_root'] = f"{storage_root}/explanations/{dataset}/{xpl_dstype}/{xpl_root_id}/"


def create_config_local(config, config_name):
    """Fill ``config`` with local-machine settings and save it under ``local/``.

    Sets device "cpu", uses ``data_path`` as the data root and places every
    other path below ``dualview_path``. ``config`` is modified in place and
    then written via ``save_config``.

    Args:
        config: Dict containing 'dataset_name', 'dataset_type', 'model_name',
            'metric' and 'xai_method'; 'C' is optional.
        config_name: Output file name without extension.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    _fill_environment(
        config,
        device="cpu",
        data_root=data_path,
        save_dir=f"{dualview_path}/test_output",
        storage_root=dualview_path,
    )
    save_config(config, config_name, "local")


def create_config_cluster(config, config_name):
    """Fill ``config`` with cluster settings and save it under ``cluster/``.

    Sets device "cuda" and places every path below ``/mnt``. ``config`` is
    modified in place and then written via ``save_config``.

    Args:
        config: Dict containing 'dataset_name', 'dataset_type', 'model_name',
            'metric' and 'xai_method'; 'C' is optional.
        config_name: Output file name without extension.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    _fill_environment(
        config,
        device="cuda",
        data_root="/mnt/dataset",
        save_dir="/mnt/outputs/",
        storage_root="/mnt",
    )
    save_config(config, config_name, "cluster")

In [4]:
# Axes of the sweep: one config is generated per (dataset, metric, method).
dsname_list = ["MNIST", "CIFAR", "AWA"]
metric_list = [
    "std",
    "group",
    "stdk",
    "groupk",
    "corrupt",
    "mark",
    "add_batch_in",
    "lindatmod",
    "add_batch_in_neg",
    # currently disabled: "leave_out", "labelflip"
]
xai_method_list = [
    "kronfluence",
    # currently disabled: "representer", "gradcos", "tracin", "trak",
    #                     "dualview", "graddot", "lissa", "arnoldi"
]

# C values swept for the "dualview" method only.
C_value_list = [1e-5, 1e-3, 1e-1]
sparsity_list = [0.85, 0.90, 0.95, 0.97, 0.99]

# Preview of the C values as they appear in cache directory names.
print([text.replace("-", "_") for text in map(str, C_value_list)])

# Per-dataset model name and number of classes.
model_dict = {
    "MNIST": "basic_conv",
    "CIFAR": "resnet18",
    "AWA": "resnet50",
}
num_classes_dict = {
    "MNIST": 10,
    "CIFAR": 10,
    "AWA": 50,
}

# Batch size per attribution method; the generation loop falls back to 16 for
# methods that are not listed.
batch_size_dict = {
    "similarity": 16,
    "influence": 16,
    "tracin": 16,
    "mcsvm": 32,
    "representer": 128,
    "kronfluence": 32,
}

# Metric -> value stored as config['dataset_type'] (default 'std' at the lookup
# site). This is a different table from dstype_dict, which is used for xpl_root.
ds_type_dict = {
    "group": "group",
    "groupk": "group",
    "corrupt": "corrupt",
    "mark": "mark",
    "switched": "switched",
}

['1e_05', '0.001', '0.1']


In [None]:
from copy import deepcopy

# Explicit dataset -> hyper-parameter table, replacing a name lookup through
# globals() built from a formatted string.
training_dicts = {
    "MNIST": training_dict_MNIST,
    "CIFAR": training_dict_CIFAR,
    "AWA": training_dict_AWA,
}

# Metrics whose configs receive the dataset's training hyper-parameters and a
# batch size of 8; all other metrics get the training keys set to None.
metrics_with_training = ("add_batch_in", "add_batch_in_neg", "leave_out", "only_batch", "lds", "labelflip")

for dsname in dsname_list:
    print(f"Creating config files for dataset {dsname}...")
    for metric in metric_list:
        for xai_method in xai_method_list:
            config = deepcopy(base_config)
            config['dataset_name'] = dsname
            config['model_name'] = model_dict[dsname]
            config['num_classes'] = num_classes_dict[dsname]
            config['metric'] = metric

            # dataset_type comes from ds_type_dict (not dstype_dict).
            dstype = ds_type_dict.get(metric, 'std')
            config['dataset_type'] = dstype
            if dstype in ["group", "groupk"]:
                # Pair consecutive class indices: [0, 1], [2, 3], ...
                config['class_groups'] = [[2*i, 2*i+1] for i in range(config['num_classes'] // 2)]
            else:
                config['class_groups'] = None

            config['xai_method'] = xai_method
            config['batch_size'] = batch_size_dict.get(xai_method, 16)
            if metric in metrics_with_training:
                config.update(training_dicts[dsname])
                config["batch_size"] = 8
            else:
                config.update(training_dict_empty)

            # dualview gets one config per C value (C is part of the file
            # name); every other method gets a single config with C unset.
            if xai_method == "dualview":
                variants = [(C_value, f"{dsname}_{metric}_{xai_method}_{C_value}") for C_value in C_value_list]
            else:
                variants = [(None, f"{dsname}_{metric}_{xai_method}")]

            for C_value, config_filename in variants:
                config['C'] = C_value
                # Both writers modify their argument in place; separate copies
                # keep one environment's paths out of the other's file.
                create_config_cluster(deepcopy(config), config_filename)
                create_config_local(deepcopy(config), config_filename)

Creating config files for dataset MNIST...
Creating config files for dataset CIFAR...
Creating config files for dataset AWA...
