# Scanning Astra evaluations

In [1]:
from astra import Astra
from astra.evaluate import evaluate_astra
from tempfile import mkdtemp
from concurrent.futures import ProcessPoolExecutor as Executor
from itertools import product
import matplotlib.pyplot as plt
import numpy as np
import os
import logging
import json

In [2]:
ASTRA_INPUT_FILE = 'astra.in'
# Input parsed from ASTRA_INPUT_FILE, used as the baseline for the scan.
ASTRA_INPUT_DEFAULT_PARAMETERS = Astra(ASTRA_INPUT_FILE).input
# Scanned parameters: name -> [start, stop, step]. generate_parameter_values expands
# each entry with np.arange, so `stop` itself is NOT part of the scan.
ASTRA_INPUT_VARIABLE_PARAMETERS = {
    'maxb(2)': [0., 0.1, 0.01],
    'maxe(4)': [-15., 0., 0.1],
}
# A fixed directory (presumably rather than a fresh mkdtemp) so that archives, logs
# and the cache are reused across runs.
SCRATCH_DIR = 'tmp700uv792'
# logging.FileHandler(LOG_FILE) and the cache writes fail with FileNotFoundError when
# the directory is missing (e.g. on a fresh checkout), so make sure it exists.
os.makedirs(SCRATCH_DIR, exist_ok=True)
LOG_FILE = os.path.join(SCRATCH_DIR, 'evaluation_logs.log')
CACHE_FILE = os.path.join(SCRATCH_DIR, 'evaluation_cache.json')

In [3]:
# Send every record at DEBUG and above to LOG_FILE only; no console handler is added.
# NOTE(review): basicConfig does nothing if the root logger already has handlers
# (e.g. when this cell is re-run in the same kernel); pass force=True if that matters.
logging.basicConfig(
    handlers=[logging.FileHandler(LOG_FILE)],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(name=__name__)

In [4]:
def load_cache(cache_file):
    """Return the evaluation cache stored as JSON in ``cache_file``.

    Parameters
    ----------
    cache_file : str
        Path of the JSON cache file.

    Returns
    -------
    dict
        The decoded JSON content, or a new empty dict when ``cache_file`` does
        not exist. Decoding errors from ``json.load`` propagate to the caller.
    """
    if not os.path.exists(cache_file):
        logger.info("No cache file found. Starting with an empty cache.")
        return {}

    with open(cache_file, 'r') as handle:
        cache = json.load(handle)
    logger.info("Cache loaded successfully.")
    return cache

In [5]:
def save_cache(evaluation_cache, cache_file):
    """Write ``evaluation_cache`` to ``cache_file`` as JSON, replacing it atomically.

    The JSON is first written to a temporary file in the same directory and then
    moved over ``cache_file`` with ``os.replace`` (atomic on POSIX). A concurrent
    reader, such as another worker calling ``load_cache``, therefore sees either
    the previous file or the complete new one, never a truncated one. A failed
    write leaves the old cache untouched.

    Parameters
    ----------
    evaluation_cache : dict
        JSON-serialisable mapping to store.
    cache_file : str
        Destination path.

    Raises
    ------
    OSError
        If the temporary file cannot be created/written or cannot be moved.
    TypeError
        If ``evaluation_cache`` is not JSON-serialisable.
    In both cases the temporary file is removed and ``cache_file`` is unchanged.
    """
    import tempfile

    cache_dir = os.path.dirname(os.path.abspath(cache_file))
    # NOTE(review): mkstemp creates the file with mode 0o600 and the cache file keeps
    # that mode after os.replace; add os.chmod if other users must read the cache.
    fd, tmp_path = tempfile.mkstemp(
        dir=cache_dir, prefix=os.path.basename(cache_file) + '.', suffix='.tmp'
    )
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(evaluation_cache, f)
        os.replace(tmp_path, cache_file)
    except BaseException:
        # Never leave a half-written temporary file behind.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    logger.info("Cache saved successfully.")

In [6]:
def find_key_recursive(d, key):
    """Return the value stored under ``key`` in ``d`` or in any nested dict.

    The top level is checked first. Otherwise nested mappings (any ``dict``
    subclass, e.g. ``OrderedDict``) are searched depth-first in iteration order,
    and the first non-None match is returned.

    Parameters
    ----------
    d : dict
        Possibly nested mapping to search.
    key : hashable
        Key to look for.

    Returns
    -------
    The value found, or ``None`` when ``key`` does not occur anywhere.
    """
    if key in d:
        return d[key]
    for value in d.values():
        # isinstance (not `type(...) is dict`) so dict subclasses are searched too.
        if isinstance(value, dict):
            found = find_key_recursive(value, key)
            if found is not None:
                return found
    return None


def initialize_cache_from_existing_archives(evaluation_cache, scratch_dir, variable_parameters=None):
    """Seed ``evaluation_cache`` with the ``.h5`` archives already in ``scratch_dir``.

    Each archive is loaded with ``Astra.from_archive``. The values of the scanned
    parameters are looked up (at any nesting depth) in its input and turned into a
    key with ``str(dict)``, mirroring the ``str(settings)`` key used by ``evaluate``.
    Existing cache entries are never overwritten.

    Parameters
    ----------
    evaluation_cache : dict
        Mapping of settings key -> archive file name; updated in place.
    scratch_dir : str
        Directory to scan for ``.h5`` archives.
    variable_parameters : mapping, optional
        Parameters whose values form the key (only its keys and their order are
        used). Defaults to ``ASTRA_INPUT_VARIABLE_PARAMETERS``.

    Returns
    -------
    dict
        ``evaluation_cache`` itself.
    """
    if variable_parameters is None:
        variable_parameters = ASTRA_INPUT_VARIABLE_PARAMETERS

    # os.listdir order is arbitrary. Sorting makes the result reproducible when
    # several archives map to the same key: the first name in sort order wins.
    archives = sorted(f for f in os.listdir(scratch_dir) if f.endswith('.h5'))
    for archive in archives:
        try:
            astra = Astra.from_archive(os.path.join(scratch_dir, archive))
            values = {key: find_key_recursive(astra.input, key) for key in variable_parameters}
        except Exception as e:
            # Best effort: one unreadable archive must not abort the whole scan.
            logger.exception(f"Error processing archive {archive}: {e}")
            continue

        missing = [key for key, value in values.items() if value is None]
        if missing:
            # A key containing None can never match a generated scan setting.
            logger.warning(f"Skipping {archive}: parameters {missing} not found in its input")
            continue

        # NOTE(review): evaluate() keys are str() of np.arange values. Confirm their
        # text (numpy scalar repr, float rounding such as 0.030000000000000002) matches
        # the values read back from the archive, or these entries will never be hit.
        # Also, this stores the bare file name while evaluate() stores output['archive'].
        settings_key = str(values)
        if settings_key not in evaluation_cache:
            evaluation_cache[settings_key] = archive
            logger.info(f"Added {archive} to cache with key: {settings_key}")
    return evaluation_cache

In [7]:
# Load previously recorded settings -> archive entries (empty dict if no cache file yet).
evaluation_cache = load_cache(cache_file=CACHE_FILE)

In [8]:
# Optional: uncomment to reload the cache file and add entries for any .h5 archives
# already present in SCRATCH_DIR (e.g. after the cache file was lost or deleted).
# evaluation_cache = initialize_cache_from_existing_archives(load_cache(CACHE_FILE), SCRATCH_DIR)

In [9]:
# Persist the cache as loaded (or rebuilt) so CACHE_FILE exists before the scan starts.
save_cache(evaluation_cache=evaluation_cache, cache_file=CACHE_FILE)

In [10]:
def evaluate(settings, evaluation_cache):
    """Run one Astra evaluation for ``settings``, or reuse the cached result.

    Parameters
    ----------
    settings : dict
        Astra input parameter name -> value; ``str(settings)`` is the cache key.
    evaluation_cache : dict
        Settings key -> archive. Updated in place when a new evaluation succeeds.
        When this runs in a process-pool worker, the dict is that worker's own
        pickled copy, so changes are not visible to the parent process.

    Returns
    -------
    The archive for these settings (cached or freshly produced via
    ``output['archive']``), or ``None`` if the evaluation raised. The return value
    is how a parallel caller learns the results.
    """
    logger.debug(f"Evaluating settings: {settings}")

    settings_key = str(settings)

    if settings_key in evaluation_cache:
        logger.info(f"Using cached result for settings: {settings}")
        return evaluation_cache[settings_key]

    try:
        output = evaluate_astra(settings, astra_input_file=ASTRA_INPUT_FILE, archive_path=SCRATCH_DIR)
        archive = output['archive']
    except Exception as e:
        # Best effort per setting: log with traceback and let the scan continue.
        logger.exception(f"Error evaluating settings: {settings}, Exception: {e}")
        return None

    logger.info(f"Evaluation successful for settings: {settings}, Archive: {archive}")
    evaluation_cache[settings_key] = archive
    try:
        # Merge with what is on disk before writing. Other workers may have saved
        # entries since this copy of the cache was made, and writing only our own
        # copy would overwrite them. (Not locked: a concurrent save in the short
        # load/save window can still be lost.)
        merged = load_cache(CACHE_FILE)
        merged.update(evaluation_cache)
        save_cache(merged, CACHE_FILE)
    except Exception:
        logger.exception(f"Could not persist cache entry for settings: {settings}")
    return archive

In [11]:
def generate_parameter_values(parameters):
    """Expand each ``[start, stop, step]`` entry into the array of values to scan.

    Parameters
    ----------
    parameters : dict
        Parameter name -> indexable ``(start, stop, step)``.

    Returns
    -------
    dict
        Same keys, in the same order, mapped to ``np.arange(start, stop, step)``.
        As with ``np.arange``, ``stop`` is excluded.
    """
    return {
        name: np.arange(bounds[0], bounds[1], bounds[2])
        for name, bounds in parameters.items()
    }

In [13]:
def generate_settings_combinations(param_values):
    """Return one settings dict per point of the Cartesian grid of ``param_values``.

    Parameters
    ----------
    param_values : dict
        Parameter name -> iterable of values.

    Returns
    -------
    list of dict
        Each dict maps every parameter name to one value. Points are ordered as
        ``itertools.product`` yields them, with the last parameter varying fastest.
    """
    names = list(param_values)
    return [dict(zip(names, point)) for point in product(*param_values.values())]

In [14]:
# Build the full Cartesian grid of scan points from the configured ranges.
param_values = generate_parameter_values(parameters=ASTRA_INPUT_VARIABLE_PARAMETERS)
settings_combinations = generate_settings_combinations(param_values=param_values)

In [None]:
from itertools import repeat

# NOTE(review): ProcessPoolExecutor workers must be able to import `evaluate` from
# __main__. Under the 'spawn' start method (default on macOS and Windows), functions
# defined in a notebook are not importable, so this presumably relies on 'fork'
# (the Linux default). Confirm, or move `evaluate` into an importable module.

# Dispatch only the settings that are not cached yet.
pending = [settings for settings in settings_combinations if str(settings) not in evaluation_cache]
logger.info(
    f"{len(settings_combinations) - len(pending)} of {len(settings_combinations)} "
    f"settings already cached; evaluating {len(pending)}."
)

# The cache must be repeated for every task. Passing the dict itself made map()
# zip each setting with a cache *key* string. It also stopped at the shorter
# iterable, so nothing ran at all when the cache was empty.
with Executor() as executor:
    outputs = list(executor.map(evaluate, pending, repeat(evaluation_cache)))

# Workers only updated their own pickled copies of the cache. Rebuild this process's
# view from what they persisted plus any archives they returned, then save once.
try:
    persisted = load_cache(CACHE_FILE)
except json.JSONDecodeError:
    logger.exception("Cache file unreadable after the scan; keeping the in-memory cache only.")
    persisted = {}
evaluation_cache = {**persisted, **evaluation_cache}
for settings, archive in zip(pending, outputs):
    if archive is not None:  # skip tasks that produced no archive
        evaluation_cache[str(settings)] = archive
save_cache(evaluation_cache, CACHE_FILE)