In [1]:
import yaml
import os

# Absolute path of the parent of the current working directory; the local
# config writer prefixes its dataset, checkpoint and explanation paths with it.
dualview_path = os.path.abspath(os.pardir)

In [2]:
# Settings shared by every generated config. Dataset-, method- and
# environment-specific keys are added on top of these later in the notebook.
base_config = dict(
    validation_size=2000,
    accuracy=True,
    num_batches_per_file=1,
    start_file=0,
    num_files=100,
)

In [3]:
# C values for the "dualview" method (create_config_local refers to this list).
C_value_list = [0.1]

def create_config_local(config, config_name):
    """Write the YAML config for running one explanation job locally on the CPU.

    The file is written to ``local/explain/<dataset_name>/<config_name>.yaml``
    (relative to the current working directory). Keys whose value is ``None``
    are left out of the file.

    Args:
        config: Dict with at least 'dataset_name', 'dataset_type', 'model_name'
            and 'xai_method'. It is not modified; a copy is extended with
            'device', 'data_root', 'model_path' and 'save_dir'.
        config_name: File name (without extension) of the YAML file.

    Raises:
        KeyError: If one of the required keys is missing from ``config``.
    """
    # Work on a copy so the environment-specific keys set here never leak
    # back into the dict the caller keeps reusing for other configs.
    local_config = dict(config)

    local_config['device'] = "cpu"
    local_config['data_root'] = f"{dualview_path}/src/datasets"

    # Sub-path shared by the checkpoint and explanation directories.
    run_subdir = (
        f"{local_config['dataset_name']}/{local_config['dataset_type']}/"
        f"{local_config['model_name']}_{local_config['dataset_type']}"
    )
    local_config['model_path'] = (
        f"{dualview_path}/checkpoints/{run_subdir}/"
        f"{local_config['dataset_name']}_{local_config['model_name']}"
    )

    method_dir = f"{dualview_path}/explanations/{run_subdir}/{local_config['xai_method']}"
    if local_config['xai_method'] == "dualview":
        # Use the C value of *this* config for the output directory. Falling
        # back to the last entry of C_value_list keeps the previous behaviour
        # for configs that carry no 'C' key.
        c_value = local_config.get('C')
        if c_value is None and C_value_list:
            c_value = C_value_list[-1]
        if c_value is not None:
            local_config['save_dir'] = f"{method_dir}_{c_value}"
    else:
        local_config['save_dir'] = method_dir

    # Optional settings are passed as None by the caller; drop them from the file.
    local_config = {k: v for k, v in local_config.items() if v is not None}

    path = f"local/explain/{local_config['dataset_name']}"
    os.makedirs(path, exist_ok=True)

    with open(f"{path}/{config_name}.yaml", "w") as outfile:
        yaml.dump(local_config, outfile, default_flow_style=False)
        
def create_config_cluster(config, config_name):
    """Write a single YAML config for running one explanation job on the cluster.

    The file is written to ``cluster/explain/<dataset_name>/<config_name>.yaml``
    (relative to the current working directory). Keys whose value is ``None``
    are left out of the file.

    Args:
        config: Dict with at least 'dataset_name', 'dataset_type' and
            'model_name'. It is not modified; a copy is extended with 'device',
            'data_root', 'save_dir' and 'model_path' pointing below ``/mnt``.
        config_name: File name (without extension) of the YAML file.

    Raises:
        KeyError: If one of the required keys is missing from ``config``.
    """
    # Work on a copy so the cluster-specific keys never leak back into the
    # dict the caller keeps reusing for other configs.
    cluster_config = dict(config)

    cluster_config['device'] = "cuda"
    cluster_config['data_root'] = "/mnt/dataset/"
    cluster_config['save_dir'] = "/mnt/outputs/"
    cluster_config['model_path'] = (
        f"/mnt/checkpoints/{cluster_config['dataset_name']}/{cluster_config['dataset_type']}/"
        f"{cluster_config['model_name']}_{cluster_config['dataset_type']}/"
        f"{cluster_config['dataset_name']}_{cluster_config['model_name']}"
    )

    # Optional settings are passed as None by the caller; drop them from the file.
    cluster_config = {k: v for k, v in cluster_config.items() if v is not None}

    path = f"cluster/explain/{cluster_config['dataset_name']}"
    os.makedirs(path, exist_ok=True)

    with open(f"{path}/{config_name}.yaml", "w") as outfile:
        yaml.dump(cluster_config, outfile, default_flow_style=False)

def create_config_cluster_parallel(config, config_name):
    """Write one cluster YAML config per file index.

    Every generated config has 'num_files' set to 1 and 'start_file' set to its
    index ``j``. The files are written to
    ``cluster/explain/<dataset_name>/<jj>_<config_name>.yaml`` (relative to the
    current working directory), with ``j`` zero-padded to at least two digits.
    Keys whose value is ``None`` are left out of the files.

    Args:
        config: Dict with at least 'dataset_name', 'dataset_type',
            'model_name', 'validation_size' and 'batch_size'. It is not
            modified; a copy is extended with the cluster-specific keys.
        config_name: Base file name (without index prefix and extension).

    Raises:
        KeyError: If one of the required keys is missing from ``config``.
    """
    # Work on a copy: the caller reuses its dict for other writers, and
    # overriding 'num_files' in place would silently change every config
    # written from that dict afterwards.
    job_config = dict(config)

    job_config['device'] = "cuda"
    job_config['data_root'] = "/mnt/dataset/"
    job_config['save_dir'] = "/mnt/outputs/"
    job_config['model_path'] = (
        f"/mnt/checkpoints/{job_config['dataset_name']}/{job_config['dataset_type']}/"
        f"{job_config['model_name']}_{job_config['dataset_type']}/"
        f"{job_config['dataset_name']}_{job_config['model_name']}"
    )

    # Each generated config covers exactly one file.
    job_config['num_files'] = 1

    # Optional settings are passed as None by the caller; drop them from the files.
    job_config = {k: v for k, v in job_config.items() if v is not None}

    path = f"cluster/explain/{job_config['dataset_name']}"
    os.makedirs(path, exist_ok=True)

    # NOTE(review): the job count ignores 'num_batches_per_file' and drops a
    # trailing partial batch; confirm this matches how the consumer of these
    # configs maps file indices to samples.
    num_jobs = job_config['validation_size'] // job_config['batch_size']
    for j in range(num_jobs):
        job_config['start_file'] = j

        with open(f"{path}/{j:02d}_{config_name}.yaml", "w") as outfile:
            yaml.dump(job_config, outfile, default_flow_style=False)

In [4]:
# Axes of the config sweep: datasets, dataset variants and explanation methods.
dsname_list = ["MNIST", "CIFAR", "AWA"]
dstype_list = ["std", "group", "corrupt", "mark", "switched"]
xai_method_list = [
    "representer",
    "rp_similarity",
    "tracin",
    "trak",
    "dualview",
    "graddot",
    "influence",
]
# C values for which "dualview" configs are generated.
C_value_list = [0.1]

# Per-dataset model name and number of classes.
model_dict = dict(MNIST="basic_conv", CIFAR="resnet18", AWA="resnet50")
num_classes_dict = dict(MNIST=10, CIFAR=10, AWA=50)

# Batch size per explanation method (currently the same for all of them).
batch_size_dict = dict.fromkeys(xai_method_list, 20)

In [5]:
# Write the local config and the per-file cluster configs for every
# (dataset, dataset type, XAI method) combination; "dualview" additionally
# gets one set of configs per C value.
for dsname in dsname_list:
    print(f"Creating config files for dataset {dsname}...")
    num_classes = num_classes_dict[dsname]

    for dstype in dstype_list:
        # Grouped dataset types pair up consecutive class indices:
        # [0, 1], [2, 3], ...
        if dstype in ["group", "groupk"]:
            class_groups = [[2*i, 2*i+1] for i in range(num_classes // 2)]
        else:
            class_groups = None

        for xai_method in xai_method_list:
            is_dualview = xai_method == "dualview"

            # Every other method is written once with C=None; None-valued keys
            # are dropped by the writer functions before dumping.
            for C_value in (C_value_list if is_dualview else [None]):
                # Build each run from a fresh dict so base_config itself is
                # never modified and this cell can be re-run safely.
                run_config = {
                    **base_config,
                    'dataset_name': dsname,
                    'model_name': model_dict[dsname],
                    'num_classes': num_classes,
                    'dataset_type': dstype,
                    'class_groups': class_groups,
                    'xai_method': xai_method,
                    'batch_size': batch_size_dict.get(xai_method, 20),
                    'C': C_value,
                }

                config_filename = f"{dsname}_{dstype}_{xai_method}"
                if is_dualview:
                    config_filename += f"_{C_value}"

                # Give each writer its own copy: the keys one of them sets
                # (e.g. the per-file 'num_files' override of the parallel
                # cluster writer) must not end up in the other's output.
                create_config_cluster_parallel(dict(run_config), config_filename)
                create_config_local(dict(run_config), config_filename)

Creating config files for dataset MNIST...
Creating config files for dataset CIFAR...
Creating config files for dataset AWA...
