# Generate and store benchmark environments

This notebook can be used to generate benchmark environments and to derive variants of existing ones with a restricted set of intervenable nodes.

In [None]:
%load_ext autoreload
%autoreload 2
import os
import shutil

import torch.distributions as dist

from src.environments.generic_environments import *

Generate the benchmark environments and store them on disk.

In [None]:
# benchmark parameters
# NOTE: CRGraph and InterventionalDistributionsQuery are not defined in this cell;
# presumably they come from the generic_environments star import in the setup cell.
env_class = CRGraph
num_envs = 5  # number of environments generated per entry of n_list
n_list = [5]  # graph sizes (number of nodes) to generate benchmarks for
frac_non_intervenable_nodes = None
non_intervenable_nodes = set()
# non_intervenable_nodes = set(["X2"])
num_test_samples_per_intervention = 50
num_test_queries = 50
# interventional_queries = None
interventional_queries = [InterventionalDistributionsQuery(['X4'], {'X2': dist.Uniform(2., 5.)})]

descriptor = ''  # suffix appended to the benchmark directory name
# descriptor = '_X2'

# generation setup
env_dir = '../data/' + env_class.__name__ + descriptor + '/'  # dir where to store the generated envs
delete_existing = False  # delete existing benchmarks

# Generate the benchmark envs: one sub-directory '<num_nodes>_nodes/' per graph size,
# each holding `num_envs` environments saved as '<env.name>.pth'.
i = 0  # number of environments generated so far (drives the progress line)
total_graphs = num_envs * len(n_list)
for num_nodes in n_list:
    # generate/empty folder for envs of same type
    n_dir = env_dir + f'{num_nodes}_nodes/'
    if os.path.isdir(n_dir):
        if not delete_existing:
            print(f"\nDirectory '{n_dir}' already exists, not generating benchmarks...")
            # These envs will not be generated, so drop them from the progress total.
            total_graphs -= num_envs
            continue

        print(f"\nDirectory '{n_dir}' already exists, delete existing benchmarks...")
        # Remove the whole tree in one call; the directory is recreated right below.
        shutil.rmtree(n_dir)

    os.makedirs(n_dir, exist_ok=True)

    # generate benchmark envs
    for _ in range(num_envs):
        env = env_class(num_nodes=num_nodes,
                        frac_non_intervenable_nodes=frac_non_intervenable_nodes,
                        non_intervenable_nodes=non_intervenable_nodes,
                        num_test_samples_per_intervention=num_test_samples_per_intervention,
                        num_test_queries=num_test_queries,
                        interventional_queries=interventional_queries)
        # NOTE(review): files are keyed by env.name; two envs with the same name would
        # overwrite each other - confirm generated names are unique.
        env.save(n_dir + env.name + '.pth')
        # Count an env only once it has actually been stored.
        i += 1
        print(f'\rGenerated {i}/{total_graphs} environments.', end='')

# Terminate the '\r' progress line.
print('')


Take existing environments, restrict their set of intervenable nodes, and store them separately (in a target directory whose name carries the chosen `descriptor` suffix).

In [None]:
# benchmark parameters
env_class = CRGraph
n_list = [5]  # graph sizes (number of nodes) whose stored envs are processed
# non_intervenable_nodes= set()
non_intervenable_nodes = set(["X2"])  # nodes marked as non-intervenable in the stored copies

# descriptor = ''
descriptor = '_X2'  # target-directory suffix; an empty suffix would target the source directory

# generation setup
source_env_dir = '../data/' + env_class.__name__ + '/'  # dir where original envs are stored
target_env_dir = '../data/' + env_class.__name__ + descriptor + '/'  # dir where to store the generated envs
delete_existing = False  # delete existing benchmarks

for num_nodes in n_list:
    # check if source envs available
    source_n_dir = source_env_dir + f'{num_nodes}_nodes/'
    if not os.path.isdir(source_n_dir):
        print(f'Source directory {source_n_dir} does not exist!')
        continue

    target_n_dir = target_env_dir + f'{num_nodes}_nodes/'
    # Guard against target == source (e.g. descriptor = ''): emptying the target below
    # would delete the original environments before they are read.
    if os.path.realpath(source_n_dir) == os.path.realpath(target_n_dir):
        print(f"\nTarget directory '{target_n_dir}' is the source directory, not overwriting the original benchmarks...")
        continue

    # generate/empty folder for target envs
    if os.path.isdir(target_n_dir):
        if not delete_existing:
            print(f"\nTarget directory '{target_n_dir}' already exists, not generating benchmarks...")
            continue

        print(f"\nTarget directory '{target_n_dir}' already exists, delete existing benchmarks...")
        # Remove the whole tree in one call; the directory is recreated right below.
        shutil.rmtree(target_n_dir)

    os.makedirs(target_n_dir, exist_ok=True)

    # Collect the stored source envs ('*.pth' files), sorted by name for a deterministic order.
    with os.scandir(source_n_dir) as entries:
        env_files = sorted(
            (entry for entry in entries if entry.is_file() and entry.name.endswith('.pth')),
            key=lambda entry: entry.name,
        )

    for count, entry in enumerate(env_files, start=1):
        env = env_class.load(os.path.abspath(entry.path))
        node_labels = set(env.node_labels)
        # Give each env its own copy instead of aliasing the shared config set.
        env.non_intervenable_nodes = set(non_intervenable_nodes)
        env.intervenable_nodes = node_labels - env.non_intervenable_nodes
        # Labels unknown to this env have no effect on the restriction; surface them.
        unknown_nodes = env.non_intervenable_nodes - node_labels
        if unknown_nodes:
            print(f'\nWarning: {entry.name} has no nodes {sorted(unknown_nodes)}.')
        env.save(target_n_dir + entry.name)

        print(f'\rProcessed {count}/{len(env_files)} environments in {source_n_dir}.', end='')
    # Terminate the '\r' progress line.
    print('')
