In [1]:
import os

In [2]:
# param_path = '/pscratch/sd/j/joeschm/NSXTU_discharges/132588/r_0.736_q4_MTM_mode/convergence_check/nz0_hpyz_edgeopt_scans/parameters'
param_path = '/global/homes/j/joeschm/tools/GENE_sim_tools/GENE_sim_writer/tests/parameters'

# Param IO

In [None]:
def read_file(file_name):
    file_path = os.path.join(param_path, file_name)
    with open(file_path, 'r') as file:
        file_content = file.read()
    return file_content



def write_file(file_content, i:int=0):
    dir_name = f"param_batch/test_{i}"
    if not os.path.exists(dir_name):
        os.makedirs(dir_name)
    file_path = os.path.join(dir_name, 'parameters')
    with open(file_path, 'w') as file:
        file.write(file_content)


# Modify param files

In [130]:

# Clean up param file functions

def remove_scan_line(file_content):
    processed_lines = []

    lines = file_content.split('\n')
    for line in lines:
        if '!scan' in line:
            index = line.index('!scan')
            processed_line = line[:index]
            processed_lines.append(processed_line.rstrip())
        else:
            processed_lines.append(line.rstrip())
    return '\n'.join(processed_lines)




# Modify param values functions


def convert_type(var_value):
    try:
        return int(var_value)
    except ValueError:
        try:
            return float(var_value)
        except ValueError:
            return var_value


def change_var_value(file_content, input_string, value):
    lines = file_content.split('\n')
    for i, line in enumerate(lines):
        if ('=' in line) and (not line.strip().startswith('!')):
            line_parts = line.split('!')
            var_name_line = line_parts[0]

            var_name, var_value = var_name_line.split('=')
            var_value = convert_type(var_value.strip())
            
            # print(var_value, type(var_value))
            if (var_name.strip() == input_string):
                if (type(value) == type(var_value)):
                    lines[i] = var_name + ' = ' + str(value)
                else:
                    override_input = input(f"Value {value} is not of the same type as the original variable {var_name}. Would you like to override the type check? (y/n): ")
                    if override_input.lower() == 'y':
                        lines[i] = var_name + ' = ' + str(value)
                    else:
                        raise ValueError("Type mismatch. Exiting program.")

    return '\n'.join(lines)







# Scan line modification functions


def append_scan_line(file_content, input_string, values, scan_type='scanlist'):
    scan_added = False
    processed_lines = []

    if not scan_type.startswith('!'):
        scan_type = '!' + scan_type


    lines = file_content.split('\n')
    for line in lines:
        if '=' in line and line.split('=')[0].strip() == input_string:
            line_parts = line.split('!')
            if len(line_parts) > 1:
                # If there's a comment, insert the scan before it
                processed_line = line_parts[0].rstrip() + ' ' + scan_type + ': ' + ', '.join(map(str, values)) + ' !' + '!'.join(line_parts[1:])
            else:
                # If there's no comment, just append the scan at the end
                processed_line = line.rstrip() + ' ' + scan_type + ': ' + ', '.join(map(str, values))
            processed_lines.append(processed_line)
            scan_added = True
        else:
            processed_lines.append(line.rstrip())
    return '\n'.join(processed_lines), scan_added










def create_new_scan_line(file_content, input_string, values, scan_type='!scanlist', section_ind_dict=None):
    file_lines = []
    sections = []
    current_section = None

    if scan_type not in ['!scanlist', '!scan']:
        raise ValueError("Invalid scan type. Choose either !scanlist or !scan")


    lines = file_content.split('\n')
    for line in lines:
        file_lines.append(line)
        if line.strip().startswith('&'):
            current_section = line.strip()
            sections.append(current_section)
            
    
    if section_ind_dict is None:
        section_indices = [f"{index}: {section}" for index, section in enumerate(sections)]
        print("Sections:")
        print("\n".join(section_indices))
        
        section = input(f'Enter the index of the section where you want to add the {input_string} scan: ')
        section_ind_dict = {input_string: section}
    else:
        section = section_ind_dict.get(input_string)
        if section is None:
            section_indices = [f"{index}: {section}" for index, section in enumerate(sections)]
            print("Sections:")
            print("\n".join(section_indices))
            
            section = input(f'Enter the index of the section where you want to add the {input_string} scan: ')
            section_ind_dict[input_string] = section
        


    try:
        section_index = int(section)
        if section_index < 0 or section_index >= len(sections):
            raise ValueError("Invalid section index")
        else:
            # Find the index of the section in the processed_lines list
            section_line_index = file_lines.index(sections[section_index])
            # Insert the key with a default value of 0 right after the section divider
            processed_line = f'{input_string} = 0' + ' ' + scan_type + ': ' + ', '.join(map(str, values))
            file_lines.insert(section_line_index + 1, processed_line)
    except ValueError:
        raise ValueError("Invalid section index")

    return '\n'.join(file_lines), section_ind_dict

In [131]:

def get_parallel_sim(parallel_sim, scan_type, key, value):
    if scan_type == 'scanlist':
        parallel_sim *= len(value['values'])
    elif scan_type == 'scan':

        if parallel_sim == 1:                
            parallel_sim = len(value['values'])
        elif parallel_sim != len(value['values']):
            raise ValueError(f"Check '{key}' and ensure all scan values have the same length for 'scan'.")
        else:
            parallel_sim = len(value['values'])

    return parallel_sim






import itertools



def create_param_scan_matrix(scan_dict, scan_type, auto_name_diagdir_path=None, debug:bool=False):
    ind_param_names = []
    ind_param_values = []

    param_scan_dict = {}
    parallel_sim = 1

    for key, value in scan_dict.items():
        split_param_bool = value.get('split_param', False)

        if split_param_bool:
            ind_param_names.append(key)
            ind_param_values.append(value['values'])
        else:
            param_scan_dict[key] = value['values']
            parallel_sim = get_parallel_sim(parallel_sim, scan_type, key, value)
            


    
    # Create a list of dictionaries, where each dictionary represents a point in the matrix
    # The keys of the dictionary are the variable names, and the values are the coordinates
    product = list(itertools.product(*ind_param_values))
    param_file_values_list = [dict(zip(ind_param_names, coordinates)) for coordinates in product]


    # Add additional parameters to files for running
    for param_file_values in param_file_values_list:
        param_file_values['n_parallel_sims'] = parallel_sim

        ignore_names = ['n_parallel_sims', 'geomfile']
        if auto_name_diagdir_path:
            auto_name_dir = '_'.join([f"{key}-{value}" for key, value in param_file_values.items() if key not in ignore_names])
            diagdir_path = os.path.join(auto_name_diagdir_path, auto_name_dir)
            param_file_values['diagdir'] = diagdir_path



    if param_file_values_list[0] == {}:
        param_file_values_list = []

    if debug:
        print(len(param_file_values_list))
        print("Adding scan parameters for:", '\n',  param_scan_dict, '\n')

        for dict_param in param_file_values_list:
            print("Changing param values:", dict_param)
            


    return param_file_values_list, param_scan_dict


In [132]:

scan_type = 'scanlist'

scan_dict = {
    'nz0': {'values': [1,2,3], 'split_param': True},
    'nv0': {'values': [3,1,2, 4], 'split_param': False},
    'kymin': {'values': [20,30,40,4]},
    'theta0': {'values': [10,20,34]},
    'geomfile': {'values': ['g1','g2','g3'], 'split_param':True}
}


# scan_dict = {
#     'nz0': {'values': [1,2], 'split_param': True},
#     'nv0': {'values': [1,2,3], 'split_param': True}
# }

param_file_values_list, param_scan_dict = create_param_scan_matrix(scan_dict, scan_type, '/dsa')

print(len(param_file_values_list))
print("Adding scan parameters for each unique parameters file:", '\n',  param_scan_dict, '\n')

for dict_param in param_file_values_list:
    print("Changing parameter values:", dict_param)



9
Adding scan parameters for each unique parameters file: 
 {'nv0': [3, 1, 2, 4], 'kymin': [20, 30, 40, 4], 'theta0': [10, 20, 34]} 

Changing parameter values: {'nz0': 1, 'geomfile': 'g1', 'n_parallel_sims': 48, 'diagdir': '/dsa/nz0-1'}
Changing parameter values: {'nz0': 1, 'geomfile': 'g2', 'n_parallel_sims': 48, 'diagdir': '/dsa/nz0-1'}
Changing parameter values: {'nz0': 1, 'geomfile': 'g3', 'n_parallel_sims': 48, 'diagdir': '/dsa/nz0-1'}
Changing parameter values: {'nz0': 2, 'geomfile': 'g1', 'n_parallel_sims': 48, 'diagdir': '/dsa/nz0-2'}
Changing parameter values: {'nz0': 2, 'geomfile': 'g2', 'n_parallel_sims': 48, 'diagdir': '/dsa/nz0-2'}
Changing parameter values: {'nz0': 2, 'geomfile': 'g3', 'n_parallel_sims': 48, 'diagdir': '/dsa/nz0-2'}
Changing parameter values: {'nz0': 3, 'geomfile': 'g1', 'n_parallel_sims': 48, 'diagdir': '/dsa/nz0-3'}
Changing parameter values: {'nz0': 3, 'geomfile': 'g2', 'n_parallel_sims': 48, 'diagdir': '/dsa/nz0-3'}
Changing parameter values: {'nz0':

In [136]:

def change_param_values(param_file_content, dict_param):
    for key, value in dict_param.items():
        if isinstance(value, str):
            value = f"'{value}'"
        param_file_content = change_var_value(param_file_content, key, value)
    return param_file_content





def add_scan_lines_to_param_file(param_file_content, param_scan_dict, scan_type='!scanlist',section_ind_dict=None):

    for param_var, param_values in param_scan_dict.items():
        
        param_file_content, scan_line_added = append_scan_line(param_file_content, param_var, param_values, scan_type=scan_type)

        if not scan_line_added:
            param_file_content, section_ind_dict = create_new_scan_line(param_file_content, param_var, param_values, scan_type=scan_type, section_ind_dict=section_ind_dict)

    return param_file_content, section_ind_dict
    


In [137]:


def create_param_files(param_path, scan_dict, scan_type='scanlist'):

    file_content = read_file(param_path)
    file_mod_content = remove_scan_line(file_content)

    param_file_values_list, param_scan_dict = create_param_scan_matrix(scan_dict, scan_type)
    section_ind_dict = None

    if len(param_file_values_list) == 0:
        new_param_file = add_scan_lines_to_param_file(file_mod_content, param_scan_dict)
        write_file(0, new_param_file)
    else:
        
        for param_ind, dict_param in enumerate(param_file_values_list):

            new_param_file = change_param_values(file_mod_content, dict_param)

            if param_scan_dict != {}:
                section_ind_dict = section_ind_dict if section_ind_dict is not None else None
                new_param_file, section_ind_dict = add_scan_lines_to_param_file(new_param_file, param_scan_dict, section_ind_dict=section_ind_dict)
        
            write_file(new_param_file, i=param_ind)


In [138]:
scan_dict = {
    'nz0': {'values': [1,2,3], 'split_param': True},
    'nv0': {'values': [3,1,2, 4], 'split_param': False},
    'kymin': {'values': [20,30,40,4]},
    'theta0': {'values': [10,20,34]},
    'geomfile': {'values': ['g1','g2','g3'], 'split_param':True}
}

create_param_files(param_path, scan_dict, scan_type='scanlist')

1
'g1'
48
Sections:
0: &parallelization
1: &box
2: &in_out
3: &general
4: &geometry
5: &species
6: &species
7: &species
8: &units
9: &scan
1
'g2'
48
1
'g3'
48
2
'g1'
48
2
'g2'
48
2
'g3'
48
3
'g1'
48
3
'g2'
48
3
'g3'
48
