In [1]:
import yaml
import os

# Absolute path of the parent of the current working directory. The paths
# written by create_config_local (checkpoints, cache, explanations, datasets)
# are all built relative to this directory.
dualda_path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

In [None]:
# Settings shared by every generated config. Each config starts as a deep copy
# of this dict and is then specialised per dataset / dataset type / XAI method.
base_config = {
    # validation_size, num_files and num_batches_per_file (together with the
    # per-method batch_size) determine how many shard configs
    # create_config_cluster_parallel emits.
    "validation_size": 2000,
    "accuracy": True,
    "num_batches_per_file": 1,
    "start_file": 0,
    "num_files": 1,  # the comma after this entry was missing (SyntaxError)
    # presumably a Hugging Face Hub model id -- TODO confirm against the consumer
    "hf_id": "MoritzWeckbecker/gpt2-large_ag-news_full",
}

In [8]:
def create_config_local(config, config_name):
    """Fill ``config`` with local-machine paths and save it as a "local" config.

    ``config`` is modified in place (pass a copy if the caller still needs the
    original). It must already contain ``xai_method``, ``dataset_name``,
    ``dataset_type``, ``model_name`` and ``C``.

    Directory naming:
      * ``graddot`` shares the ``tracin`` directory for both cache and output.
      * ``representer`` uses a cache directory without the ``C`` suffix but an
        output directory with it.
      * every other method appends ``_<C>`` to both when ``C`` is not None.

    Args:
        config: mutable mapping describing one experiment.
        config_name: base name (without extension) of the YAML file to write.
    """
    method = config["xai_method"]

    # dualda variants always start at file 0 and cover a fixed number of files.
    if "dualda" in method:
        config['start_file'] = 0
        config['num_files'] = 38 if config["dataset_name"] == "ag_news" else 10

    config['device'] = "cuda"
    config['data_root'] = f"{dualda_path}/../Datasets"

    dataset = config['dataset_name']
    variant = config['dataset_type']
    cache_base = f"{dualda_path}/cache/{dataset}/{variant}"

    config['model_path'] = f"{dualda_path}/checkpoints/{dataset}/{variant}/{dataset}_{config['model_name']}_best"
    config['grad_dir'] = f"{cache_base}/tracin/best"
    config['features_dir'] = f"{cache_base}/features"

    # Pick the sub-directory names for the cache and for the saved explanations.
    c_value = config['C']
    if method == "graddot":
        cache_id = save_id = "tracin"
    elif method == "representer":
        cache_id = method
        save_id = f"{method}_{c_value}"
    elif c_value is not None:
        cache_id = save_id = f"{method}_{c_value}"
    else:
        cache_id = save_id = method

    config['cache_dir'] = f"{cache_base}/{cache_id}"
    config['save_dir'] = f"{dualda_path}/explanations/{dataset}/{variant}/{save_id}"
    save_config(config, config_name, "local")

def save_config(config, config_name, config_type):
    """Write ``config`` to ``<config_type>/explain/<dataset_name>/<config_name>.yaml``.

    Entries whose value is None are left out of the file. The target directory
    is created if it does not exist; the path is relative to the current
    working directory.

    Args:
        config: mapping to serialise; must contain a non-None ``dataset_name``.
        config_name: file name without the ``.yaml`` extension.
        config_type: top-level output folder (the callers in this file pass
            "local" or "cluster").
    """
    # Keep only the keys that actually carry a value.
    present = {}
    for key, value in config.items():
        if value is not None:
            present[key] = value

    target_dir = f"{config_type}/explain/{present['dataset_name']}"
    os.makedirs(target_dir, exist_ok=True)

    yaml_path = f"{target_dir}/{config_name}.yaml"
    with open(yaml_path, "w") as handle:
        yaml.dump(present, handle, default_flow_style=False)
        
def create_config_cluster(config, config_name):
    """Fill ``config`` with cluster (``/mnt/...``) paths and save one "cluster" config.

    ``config`` is modified in place. ``num_files`` is forced to 100. The cache
    sub-directory is ``tracin`` for ``graddot``; otherwise it is the method
    name, suffixed with ``_<C>`` whenever ``C`` is not None.

    Args:
        config: mutable mapping with ``xai_method``, ``C``, ``dataset_name``,
            ``dataset_type`` and ``model_name`` already set.
        config_name: base name (without extension) of the YAML file to write.
    """
    config['device'] = "cuda"
    config['data_root'] = "/mnt/dataset/"
    config['save_dir'] = "/mnt/outputs/"
    config['num_files'] = 100

    method = config['xai_method']
    c_value = config['C']
    if method == "graddot":
        cache_id = "tracin"
    elif c_value is not None:
        cache_id = f"{method}_{c_value}"
    else:
        cache_id = method

    dataset = config['dataset_name']
    variant = config['dataset_type']
    cache_base = f"/mnt/cache/{dataset}/{variant}"

    config['cache_dir'] = f"{cache_base}/{cache_id}"
    config['grad_dir'] = f"{cache_base}/tracin/best"
    config['features_dir'] = f"{cache_base}/features"
    config['model_path'] = f"/mnt/checkpoints/{dataset}/{variant}/{dataset}_{config['model_name']}_best"
    save_config(config, config_name, "cluster")

def create_config_cluster_parallel(config, config_name):
    """Save one "cluster" config per shard of the validation set.

    ``config`` is filled with cluster (``/mnt/...``) paths in place, then one
    YAML file named ``<jj>_<config_name>`` is written for every shard index
    ``j``, with ``start_file`` set to ``j``. The number of shards is
    ``validation_size // (num_files * num_batches_per_file * batch_size)``.

    Note that ``config['start_file']`` is left at the last shard index when the
    function returns, so callers that reuse the dict should pass a copy.

    Args:
        config: mutable mapping with ``xai_method``, ``C``, ``dataset_name``,
            ``dataset_type``, ``model_name``, ``validation_size``,
            ``num_files``, ``num_batches_per_file`` and ``batch_size`` set.
        config_name: base name used for every generated YAML file.
    """
    config['device'] = "cuda"
    config['data_root'] = "/mnt/dataset/"
    config['save_dir'] = "/mnt/outputs/"

    # graddot shares tracin's cache; representer never gets the C suffix here.
    method = config['xai_method']
    c_value = config['C']
    if method == "graddot":
        cache_id = "tracin"
    elif c_value is not None and method != "representer":
        cache_id = f"{method}_{c_value}"
    else:
        cache_id = method

    dataset = config['dataset_name']
    variant = config['dataset_type']
    cache_base = f"/mnt/cache/{dataset}/{variant}"

    config['cache_dir'] = f"{cache_base}/{cache_id}"
    config['grad_dir'] = f"{cache_base}/tracin/best"
    config['features_dir'] = f"{cache_base}/features"
    config['model_path'] = f"/mnt/checkpoints/{dataset}/{variant}/{dataset}_{config['model_name']}_best"

    total_samples = config['validation_size']
    samples_per_shard = config['num_files'] * config['num_batches_per_file'] * config['batch_size']
    shard_count = total_samples // samples_per_shard

    for shard in range(shard_count):
        config['start_file'] = shard
        save_config(config, f"{shard:02d}_{config_name}", "cluster")

In [9]:
# --- Which configs to generate -------------------------------------------
# Alternative dataset selection: ["MNIST", "CIFAR", "AWA"]
dsname_list = ["ag_news"]

# Alternative dataset-type selection: ["std", "group", "mark"]  (also "switched")
dstype_list = ["std"]

# Alternative method selections: ['lissa'] or ["graddot", "tracin"]
xai_method_list = ["dualda", "kronfluence", "lissa"]

# Values swept for the `C` entry of the dualda / representer configs.
C_value_list = [1e-05, 1e-03, 1e-01]

# --- Per-dataset lookups --------------------------------------------------
model_dict = {
    "MNIST": "basic_conv",
    "CIFAR": "resnet18",
    "AWA": "resnet50",
    "tweet_sentiment_extraction": None,
    "ag_news": None,
}
num_classes_dict = {
    "MNIST": 10,
    "CIFAR": 10,
    "AWA": 50,
    "tweet_sentiment_extraction": None,
    "ag_news": None,
}

# --- Per-method batch size (methods not listed fall back to 20 at lookup) --
batch_size_dict = {
    "representer": 200,
    "rp_similarity": 20,
    "tracin": 20,
    "trak": 20,
    "dualda": 200,
    "graddot": 20,
    "lissa": 5,
}

In [None]:
from copy import deepcopy

# Generate the cluster (sharded) and local YAML configs for every selected
# dataset / dataset type / XAI method combination.
#
# Both create_config_* helpers modify the dict they receive (paths, device and,
# for the parallel variant, `start_file`). Each call therefore gets its own
# deep copy so that nothing written for the cluster leaks into the local
# config or into the next iteration.
for dsname in dsname_list:
    print(f"Creating config files for dataset {dsname}...")
    for dstype in dstype_list:
        for xai_method in xai_method_list:
            config = deepcopy(base_config)
            config['dataset_name'] = dsname
            config['model_name'] = model_dict[dsname]
            config['num_classes'] = num_classes_dict[dsname]

            config['dataset_type'] = dstype
            if dstype in ["group", "groupk"]:
                # NOTE: num_classes is None for the text datasets in
                # num_classes_dict, so grouped variants of those would raise here.
                config['class_groups'] = [[2*i, 2*i+1] for i in range(config['num_classes'] // 2)]
            else:
                config['class_groups'] = None

            config['xai_method'] = xai_method
            config['batch_size'] = batch_size_dict.get(xai_method, 20)

            if xai_method == "dualda":
                for C_value in C_value_list:
                    config['C'] = C_value
                    config_filename = f"{dsname}_{dstype}_{xai_method}_{C_value}"
                    #create_config_cluster(config, config_filename)
                    create_config_cluster_parallel(deepcopy(config), config_filename)
                    create_config_local(deepcopy(config), config_filename)
            elif xai_method == "representer":
                for sparsity_value in C_value_list:
                    config["C"] = sparsity_value
                    config_filename = f"{dsname}_{dstype}_{xai_method}"
                    if sparsity_value != 0.:
                        config_filename = f"{config_filename}_{sparsity_value}"
                    create_config_cluster_parallel(deepcopy(config), config_filename)
                    create_config_local(deepcopy(config), config_filename)
            else:
                config['C'] = None
                config_filename = f"{dsname}_{dstype}_{xai_method}"
                #create_config_cluster(config, config_filename)
                # Without the copies, the local config inherited the last shard's
                # `start_file` from the parallel call instead of the base value.
                create_config_cluster_parallel(deepcopy(config), config_filename)
                create_config_local(deepcopy(config), config_filename)

Creating config files for dataset ag_news...
