In [8]:
import json
import numpy as np

In [9]:
def generate_parameter_list(inputs, intermediate_results=None):
    """Expand a parameter grid into the list of all parameter combinations.

    Every combination is a dict that maps each parameter name to exactly one of
    its candidate values (the Cartesian product of the value lists).

    Neither ``inputs`` nor ``intermediate_results`` is modified, so the same
    grid can be expanded repeatedly (e.g. when a notebook cell is re-run).

    Args:
        inputs (dict<str, list>): Dict with keys as parameter names, values are
            lists of possible parameter values.
        intermediate_results (list, optional): Partial combinations (list of
            dicts) that every generated combination extends. Defaults to None,
            meaning "start from scratch".

    Returns:
        list(dict<str, any>): List of all possible dicts with keys as parameter
        names and values as parameter values. Keys are processed from the last
        inserted to the first, so the last key of ``inputs`` varies fastest.
        Returns an empty list when ``inputs`` is empty and no intermediate
        results are given, or when any parameter has no candidate values.
    """
    # Copy the seed combinations so the caller's list is never aliased.
    partial = list(intermediate_results) if intermediate_results else []
    if not inputs:  # nothing to expand
        return partial

    # A single empty dict is the neutral element of the product: extending it
    # with {key: val} gives the first layer of combinations.
    combos = partial if partial else [{}]

    # Walk the keys last-to-first without popping them from ``inputs``; this is
    # the same order a LIFO ``popitem`` loop would visit them in.
    for key in reversed(list(inputs)):
        vals = inputs[key]
        # An empty ``vals`` empties ``combos`` and it stays empty afterwards,
        # so the result does not depend on where the empty list sits.
        combos = [{**res, key: val} for val in vals for res in combos]
    return combos

In [10]:
# Seed for the generator that draws the per-run "seed" values below. Fixing it
# makes the parameter grid identical on every Restart & Run All.
SEED_SAMPLING_SEED = 20240915
seed_rng = np.random.default_rng(SEED_SAMPLING_SEED)

# Parameter grid: each key is a parameter name, each value the list of
# candidate values to combine.
parameter_inputs = {
    "task": ["CXT", "2AF", "DMS"],
    "test_data_folder": ["/mnt/home/hlethi/ceph/data/feature_learning_RNN"],
    "save_folder": ["/mnt/home/hlethi/ceph/feature_learning_RNN/feature_learning_RNN_20240915"],
    # 50 distinct seeds from 0..99; cast to built-in int because numpy
    # integers are not JSON serializable.
    "seed": [int(i) for i in seed_rng.choice(100, size=50, replace=False)],
}

In [11]:
# Generate parameter list
# Hand over a shallow copy of the grid: this keeps `parameter_inputs` intact
# even if the generator consumes the dict it receives, so re-running this cell
# prints the same count instead of 0.
parameter_list = generate_parameter_list(dict(parameter_inputs))
print(len(parameter_list))

150


In [12]:
# Save parameter list into json file
filename = "params/params_20240915.json"
# Write through a context manager so the file is flushed and closed as soon as
# the dump finishes, rather than whenever the handle gets garbage collected.
with open(filename, "w") as json_file:
    json.dump(parameter_list, json_file)
print(f"Save parameter list to json file at {filename}")

Save parameter list to json file at params/params_20240915.json


In [13]:
# Number of parameter sets handled by one array task.
n_thread = 5
num_jobs = len(parameter_list)

# Highest 0-based array index needed to cover every job: the count of full
# groups, minus one when the jobs divide evenly (no partial trailing group).
full_groups, leftover = divmod(num_jobs, n_thread)
if leftover == 0:
    job_array_length = full_groups - 1
else:
    job_array_length = full_groups
print(f"Job array length: 0-{job_array_length}")

Job array length: 0-29
