In [6]:
import yaml
import os

# Repository root, taken as the parent of the current working directory
# (assumes the notebook is started from a direct sub-directory of the repo — TODO confirm).
dualview_path = os.path.abspath(os.pardir)

In [7]:
# Settings shared by every generated config; the generation loop adds the
# dataset-, metric- and method-specific keys on top of this dict.
base_config = {"validation_size": 2000}

# Training hyper-parameters per dataset. The generation loop merges the matching
# dict into base_config for the metrics that need training settings.
training_dict_MNIST = dict(
    epochs=200,
    loss="cross_entropy",
    lr=5e-3,
    momentum=0.9,
    optimizer="sgd",
    scheduler="constant",
    weight_decay=0,
    augmentation="crop",
)

training_dict_CIFAR = dict(
    epochs=190,
    loss="cross_entropy",
    lr=1e-3,
    momentum=0.9,
    optimizer="sgd",
    scheduler="constant",
    weight_decay=0.01,
    augmentation="crop_flip",
)

training_dict_AWA = dict(
    epochs=50,
    loss="cross_entropy",
    lr=0.01,
    momentum=0.9,
    optimizer="sgd",
    scheduler="annealing",
    weight_decay=0,
    augmentation="flip",
)

# Same keys as the training dicts above, every value None. Merging it overwrites any
# training settings left in base_config by an earlier metric. It is derived from the
# MNIST dict so the key set cannot drift out of sync.
training_dict_empty = dict.fromkeys(training_dict_MNIST)

In [8]:
# Maps a metric name to the sub-directory (under the xpl root) that holds its
# explanations. Metrics not listed here fall back to "std".
dstype_dict = {"std": "std", "stdk": "std", "group": "group", "groupk": "group",
               "mark": "mark", "switched": "switched_one_file"}


def _write_config(config, config_name, *, device, data_root, save_dir,
                  checkpoint_root, xpl_base, output_root):
    """Complete ``config`` with environment-specific paths and dump it as YAML.

    Shared implementation of ``create_config_local`` and
    ``create_config_cluster``; the two differ only in the paths passed here.

    Args:
        config: Experiment settings. Must contain ``metric``, ``dataset_name``,
            ``dataset_type``, ``model_name`` and ``xai_method``; configs with
            ``xai_method == "dualview"`` must also contain ``C``. The dict is
            not modified.
        config_name: File name (without ``.yaml``) of the written config.
        device: Value stored under ``device``.
        data_root: Value stored under ``data_root``.
        save_dir: Value stored under ``save_dir``.
        checkpoint_root: Root of the ``model_path`` built for the ``mark`` metric.
        xpl_base: Root of the ``xpl_root`` explanation directory.
        output_root: The file is written to
            ``{output_root}/evaluate/{dataset_name}/{config_name}.yaml``
            relative to the current working directory.

    Raises:
        KeyError: if a required key (including ``C`` for dualview) is missing.
    """
    # Work on a copy so environment-specific keys never leak back into the
    # caller's dict (the generation loop reuses one base_config for every file).
    config = dict(config)
    config['device'] = device
    config['data_root'] = data_root
    config['save_dir'] = save_dir

    if config['metric'] == 'mark':
        config['model_path'] = (
            f"{checkpoint_root}/{config['dataset_name']}/{config['dataset_type']}/"
            f"{config['model_name']}_{config['dataset_type']}/"
            f"{config['dataset_name']}_{config['model_name']}"
        )
    else:
        config['model_path'] = None

    dstype = dstype_dict.get(config['metric'], "std")
    method_dir = config['xai_method']
    if method_dir == "dualview":
        # One explanation directory per regularisation value C, matching the
        # per-C config files. Previously the last entry of a global list was used
        # for every C.
        method_dir = f"{method_dir}_{config['C']}"
    config['xpl_root'] = f"{xpl_base}/{config['dataset_name']}/{dstype}/{method_dir}/"

    # TODO(review): confirm the evaluation code actually needs coef_root.
    if config['xai_method'] in ("dualview", "representer"):
        config['coef_root'] = config['xpl_root']
    else:
        config.pop("coef_root", None)

    # None marks "not applicable": such keys are left out of the YAML file.
    config = {k: v for k, v in config.items() if v is not None}

    path = f"{output_root}/evaluate/{config['dataset_name']}"
    os.makedirs(path, exist_ok=True)
    with open(f"{path}/{config_name}.yaml", "w") as outfile:
        yaml.dump(config, outfile, default_flow_style=False)


def create_config_local(config, config_name):
    """Write ``config`` as a YAML file for a local CPU run.

    Data, checkpoint, output and explanation paths are rooted at
    ``dualview_path``. The file is written to
    ``local/evaluate/{dataset_name}/{config_name}.yaml`` relative to the CWD.
    ``config`` itself is not modified.
    """
    _write_config(
        config, config_name,
        device="cpu",
        data_root=f"{dualview_path}/src/datasets",
        save_dir=f"{dualview_path}/test_output",
        checkpoint_root=f"{dualview_path}/checkpoints",
        xpl_base=f"{dualview_path}/xpl",
        output_root="local",
    )


def create_config_cluster(config, config_name):
    """Write ``config`` as a YAML file for a CUDA run on the cluster.

    Paths are rooted at the ``/mnt`` mounts. The file is written to
    ``cluster/evaluate/{dataset_name}/{config_name}.yaml`` relative to the CWD.
    ``config`` itself is not modified.
    """
    _write_config(
        config, config_name,
        device="cuda",
        data_root="/mnt/dataset",
        save_dir="/mnt/outputs/",
        checkpoint_root="/mnt/checkpoints",
        xpl_base="/mnt/xpl",
        output_root="cluster",
    )

In [9]:
# Experiment grid: one config file is written per
# (dataset, metric, xai_method[, C]) combination.
dsname_list = ["MNIST", "CIFAR", "AWA"]
# Each metric appears once; a repeated entry would only regenerate identical files.
metric_list = ["std", "group", "stdk", "groupk", "corrupt", "mark", "switched", "add_batch_in",
               "add_batch_in_neg", "leave_out", "only_batch", "lds", "labelflip"]
xai_method_list = ["representer", "gradcos", "tracin", "trak", "dualview", "graddot", "influence"]
# Regularisation values swept for dualview; each value gets its own config file.
# TODO(review): this list used to be [1e-2, 1e-2, 1]. The repeated 1e-2 looks like a
# typo for a third value (an earlier cell used 1e-1); confirm and add it if intended.
C_value_list = [1e-2, 1]

model_dict = {"MNIST": "basic_conv", "CIFAR": "resnet18", "AWA": "resnet50"}
num_classes_dict = {"MNIST": 10, "CIFAR": 10, "AWA": 50}
# Batch size per XAI method; the generation loop defaults to 16 for unlisted methods.
batch_size_dict = {"similarity": 16, "influence": 16, "tracin": 16, "mcsvm": 32, "representer": 128}
# dataset_type stored in each config; the generation loop defaults to 'std' for
# unlisted metrics.
ds_type_dict = {"group": "group", "groupk": "group", "corrupt": "corrupt", "mark": "mark", "switched": "switched"}

In [10]:
# Metrics whose configs need the dataset's training hyper-parameters.
retraining_metrics = {"add_batch_in", "add_batch_in_neg", "leave_out", "only_batch", "lds", "labelflip"}
# Explicit mapping rather than a globals() lookup by name, so the dependency on
# the training dicts is visible to readers and tools.
training_dicts = {"MNIST": training_dict_MNIST, "CIFAR": training_dict_CIFAR, "AWA": training_dict_AWA}

for dsname in dsname_list:
    base_config['dataset_name'] = dsname
    base_config['model_name'] = model_dict[dsname]
    base_config['num_classes'] = num_classes_dict[dsname]
    print(f"Creating config files for dataset {dsname}...")

    for metric in metric_list:
        base_config['metric'] = metric
        base_config['dataset_type'] = ds_type_dict.get(metric, 'std')
        if metric in ["group", "groupk"]:
            # Pair consecutive class indices: [[0, 1], [2, 3], ...].
            base_config['class_groups'] = [[2*i, 2*i+1] for i in range(base_config['num_classes'] // 2)]
        else:
            base_config['class_groups'] = None
        # training_dict_empty has the same keys, all None, so merging it
        # overwrites settings left over from a previous retraining metric.
        if metric in retraining_metrics:
            base_config.update(training_dicts[dsname])
        else:
            base_config.update(training_dict_empty)

        for xai_method in xai_method_list:
            base_config['xai_method'] = xai_method
            base_config['batch_size'] = batch_size_dict.get(xai_method, 16)

            # Only dualview is swept over C. The other methods get one config
            # with C=None and no C suffix in the file name.
            c_values = C_value_list if xai_method == "dualview" else [None]
            for C_value in c_values:
                base_config['C'] = C_value
                suffix = "" if C_value is None else f"_{C_value}"
                config_filename = f"{dsname}_{metric}_{xai_method}{suffix}"
                create_config_cluster(base_config, config_filename)
                create_config_local(base_config, config_filename)

Creating config files for dataset MNIST...
Creating config files for dataset CIFAR...
Creating config files for dataset AWA...
