In [20]:
import itertools
import os

import yaml

# Parent directory of the notebook's working directory. It is used below as the
# root for `src/datasets` and `checkpoints`, so this assumes the notebook is run
# from a direct subdirectory of the repository root — TODO confirm.
dualview_path = os.path.normpath(os.path.join(os.getcwd(), os.pardir))

In [21]:
# Hyperparameters shared by every generated training config. Per-run fields
# (dataset, model, optimizer, lr, ...) are layered on top by the generation
# loop, and 'epochs' / 'save_each' are overridden per target (local vs.
# cluster) when each YAML file is written.
base_config = dict(
    base_epoch=0,
    batch_size=64,
    epochs=200,
    save_each=10,
    num_batches_eval=200,
    validation_size=2000,
)

In [22]:
def _write_config_yaml(config, target, config_name):
    """Dump ``config`` as YAML to ``<target>/train/<dataset_name>/<config_name>.yaml``.

    The directory (relative to the current working directory) is created if it
    is missing. An existing file with the same name is overwritten.

    Args:
        config: Config dict to serialise; must contain ``'dataset_name'``.
        target: Top-level output directory, e.g. ``'local'`` or ``'cluster'``.
        config_name: File stem of the YAML file.
    """
    path = f"{target}/train/{config['dataset_name']}"
    os.makedirs(path, exist_ok=True)

    with open(f'{path}/{config_name}.yaml', 'w') as outfile:
        yaml.dump(config, outfile, default_flow_style=False)


def create_config_local(config, config_name):
    """Write a CPU training config for running on this machine.

    A shallow copy of ``config`` is given local device and paths plus a short
    schedule (2 epochs, a checkpoint every epoch; presumably a smoke test).
    The copy is written to ``local/train/<dataset_name>/<config_name>.yaml``.
    ``config`` itself is left unmodified.

    Args:
        config: Base config; must contain ``'dataset_name'``, ``'dataset_type'``
            and ``'model_name'``.
        config_name: File stem of the YAML file.
    """
    # Work on a copy so the caller's (shared) dict does not accumulate
    # target-specific keys between calls.
    local_config = dict(config)
    local_config['device'] = 'cpu'
    local_config['data_root'] = f'{dualview_path}/src/datasets'
    local_config['save_dir'] = (
        f"{dualview_path}/checkpoints/{config['dataset_name']}/{config['dataset_type']}/"
        f"{config['model_name']}_{config['dataset_type']}"
    )
    local_config['epochs'] = 2
    local_config['save_each'] = 1

    _write_config_yaml(local_config, 'local', config_name)


def create_config_cluster(config, config_name):
    """Write a CUDA training config for the cluster.

    A shallow copy of ``config`` is given the cluster device and the fixed
    ``/mnt`` dataset and output paths, plus the full schedule (200 epochs, a
    checkpoint every 10). The copy is written to
    ``cluster/train/<dataset_name>/<config_name>.yaml``. ``config`` itself is
    left unmodified.

    Args:
        config: Base config; must contain ``'dataset_name'``.
        config_name: File stem of the YAML file.
    """
    # Work on a copy so the caller's (shared) dict does not accumulate
    # target-specific keys between calls.
    cluster_config = dict(config)
    cluster_config['device'] = 'cuda'
    cluster_config['data_root'] = '/mnt/dataset/'
    cluster_config['save_dir'] = '/mnt/outputs/'
    cluster_config['epochs'] = 200
    cluster_config['save_each'] = 10

    _write_config_yaml(cluster_config, 'cluster', config_name)

In [23]:
# Search space for the generated training configs. The trailing comments record
# wider grids that are currently commented out (presumably explored earlier or
# planned) — TODO confirm which are still wanted.
dsname_list = ['MNIST', 'CIFAR', 'AWA']
dstype_list = ['std']  # full grid: ['std', 'group', 'corrupt', 'mark']
lr_list = [1e-5, 1e-4]  # full grid: [1e-5, 5e-5, 1e-4, 5e-4]
momentum_list = [0, 0.9]  # full grid: [0, 0.9, 0.99]
weight_decay_list = [0, 1e-4, 1e-2]  # full grid: [0, 1e-5, 1e-4, 1e-3, 1e-2]
optimizer_list = ['sgd', 'adam', 'rmsprop']
scheduler_list = ['constant']  # full grid: ['constant', 'step', 'annealing']
loss_list = ['cross_entropy', 'hinge']

# Every subset of the four augmentation steps, joined with '_' in the fixed
# order crop, flip, eq, rotate. Bits 3..0 of `mask` toggle crop..rotate, so
# counting down from 15 to 0 starts at 'crop_flip_eq_rotate' and ends at ''
# (no augmentation). Other preset names noted earlier: 'cifar', 'imagenet'.
augmentation_list = [
    '_'.join(
        step
        for bit, step in zip((8, 4, 2, 1), ('crop', 'flip', 'eq', 'rotate'))
        if mask & bit
    )
    for mask in range(15, -1, -1)
]

# Architecture and class count per dataset.
# NOTE(review): the original author marked model_dict "remove this";
# presumably the model should become its own search dimension — TODO confirm.
model_dict = dict(MNIST='basic_conv', CIFAR='resnet18', AWA='resnet50')
num_classes_dict = dict(MNIST=10, CIFAR=10, AWA=50)

In [24]:
# Write one cluster and one local YAML config for every combination in the
# search space. `base_config` is updated in place with the per-run fields
# before each write, so afterwards it holds the last combination.
for dsname in dsname_list:
    base_config['dataset_name'] = dsname
    base_config['model_name'] = model_dict[dsname]
    base_config['num_classes'] = num_classes_dict[dsname]
    # Pair consecutive class ids: [[0, 1], [2, 3], ...]; with an odd class
    # count the last class is left out of every group.
    base_config['class_groups'] = [[2 * i, 2 * i + 1] for i in range(base_config['num_classes'] // 2)]
    print(f'Creating config files for dataset {dsname}...')

    # product() varies the rightmost list fastest, matching the original
    # nested-loop order (and therefore the order files are written in).
    grid = itertools.product(
        dstype_list, lr_list, momentum_list, weight_decay_list,
        optimizer_list, scheduler_list, loss_list, augmentation_list,
    )
    for dstype, lr, momentum, weight_decay, optimizer, scheduler, loss, augmentation in grid:
        base_config.update(
            dataset_type=dstype,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            optimizer=optimizer,
            scheduler=scheduler,
            loss=loss,
        )
        if augmentation:
            base_config['augmentation'] = augmentation
        else:
            # '' means "no augmentation": drop the key. Otherwise the previous
            # iteration's augmentation would leak into this config.
            base_config.pop('augmentation', None)

        config_filename = f'{dsname}_{dstype}_{lr}_{momentum}_{weight_decay}_{optimizer}_{scheduler}_{loss}_{augmentation}'
        create_config_cluster(base_config, config_filename)
        create_config_local(base_config, config_filename)

Creating config files for dataset MNIST...
Creating config files for dataset CIFAR...
Creating config files for dataset AWA...
