In [None]:
import os
from collections import OrderedDict

import torch

from ray_tools.base.engine import MinMaxRayEngine, RayEngine
from ray_tools.base.backend import *
from ray_tools.base.transform import XYHistogram

from ray_tools.simulation.data_tools import EfficientRandomRayDatasetGenerator
from ray_tools.base.transform import *

%load_ext autoreload
%autoreload 2

In [None]:
# Root directory that the relative output and RML paths are joined onto.
file_root = ''

# Sampling range (lower, upper) for every beamline parameter that is varied,
# keyed as '<element>.<parameter>'. Insertion order is kept (OrderedDict), and
# sampler_func draws one column per entry (len(param_limit_dict)); presumably
# column i corresponds to the i-th key here -- confirm against the engine.
param_limit_dict: OrderedDict[str, tuple[float, float]] = OrderedDict(
    (f'{element}.{parameter}', limits)
    for element, element_limits in (
        ('U41_318eV', (
            # 'U41_318eV.numberRays' (100) is currently disabled.
            ('translationXerror', (-0.25, 0.25)),
            ('translationYerror', (-0.25, 0.25)),
            ('rotationXerror', (-0.05, 0.05)),
            ('rotationYerror', (-0.05, 0.05)),
        )),
        ('ASBL', (
            ('openingWidth', (1.9, 2.1)),
            ('openingHeight', (0.9, 1.1)),
            ('translationXerror', (-0.2, 0.2)),
            ('translationYerror', (-0.2, 0.2)),
        )),
        ('M1_Cylinder', (
            ('radius', (174.06, 174.36)),
            ('rotationXerror', (-0.25, 0.25)),
            ('rotationYerror', (-1.0, 1.0)),
            ('rotationZerror', (-1.0, 1.0)),
            ('translationXerror', (-1.0, 1.0)),
            ('translationYerror', (-1.0, 1.0)),
        )),
        ('SphericalGrating', (
            ('radius', (109741.0, 109841.0)),
            ('rotationYerror', (-1.0, 1.0)),
            ('rotationZerror', (-2.5, 2.5)),
        )),
        ('ExitSlit', (
            ('openingHeight', (0.009, 0.011)),
            ('translationZerror', (-150.0, 150.0)),
            ('rotationZerror', (-0.3, 0.3)),
        )),
        ('E1', (
            ('longHalfAxisA', (20600.0, 20900.0)),
            ('shortHalfAxisB', (300.721702601, 304.721702601)),
            ('rotationXerror', (-1.5, 1.5)),
            ('rotationYerror', (-7.5, 7.5)),
            # Note: not centred on zero.
            ('rotationZerror', (7.0, 14.0)),
            ('translationYerror', (-1.0, 1.0)),
            ('translationZerror', (-1.0, 1.0)),
        )),
        ('E2', (
            ('longHalfAxisA', (4325.0, 4425.0)),
            ('shortHalfAxisB', (96.1560870104, 98.1560870104)),
            ('rotationXerror', (-0.5, 0.5)),
            ('rotationYerror', (-7.5, 7.5)),
            # Note: not centred on zero.
            ('rotationZerror', (22.0, 32.0)),
            ('translationYerror', (-1.0, 1.0)),
            ('translationZerror', (-1.0, 1.0)),
        )),
        ('ImagePlane', (
            ('translationXerror', (-1.0, 1.0)),
            ('translationYerror', (-1.0, 1.0)),
            ('translationZerror', (-33.0, 33.0)),
        )),
    )
    for parameter, limits in element_limits
)


# Local RayX backend. seed=123 is forwarded to the backend (presumably to make
# its tracing reproducible -- confirm). RayBackendLocalRayUI(seed=123) was the
# commented-out alternative previously noted here.
backend = RayBackendLocalRayX(seed=123, verbose=False)

outputs_dir = os.path.join(file_root, 'outputs/')
rml_basefile = os.path.join(
    file_root, 'tests/rml/METRIX_U41_G1_H1_318eV_PS_MLearn_1.15.rml'
)

# Engine over the RML beamline, varying the parameters in param_limit_dict and
# exporting only "ImagePlane". num_workers=-1 is passed through unchanged
# (often "use all cores" -- confirm against MinMaxRayEngine). The ImagePlane
# translation errors are named via manual_transform_*; presumably the engine
# applies them itself instead of via the RML file -- confirm.
engine = MinMaxRayEngine(
    rml_basefile=rml_basefile,
    param_limit_dict=param_limit_dict,
    exported_planes=["ImagePlane"],
    ray_backend=backend,
    num_workers=-1,
    verbose=False,
    manual_transform_plane='ImagePlane',
    manual_transform_xyz=(
        'ImagePlane.translationXerror',
        'ImagePlane.translationYerror',
        'ImagePlane.translationZerror',
    ),
)

def sampler_func(batch_len, *, num_params=None, generator=None):
    """Draw a batch of raw parameter vectors, uniform on [0, 1).

    Args:
        batch_len: Number of parameter vectors (rows) to draw.
        num_params: Number of parameters (columns) per vector. Defaults to
            ``len(param_limit_dict)``, i.e. one column per varied parameter,
            which is the original behavior.
        generator: Optional ``torch.Generator``. Pass a seeded one to make the
            sampled parameters reproducible; with ``None``, ``torch.rand``
            uses torch's global RNG (the ray backend's own seed does not
            affect these samples).

    Returns:
        A tensor of shape ``(batch_len, num_params)`` with values in [0, 1).
        These are presumably rescaled to the ``param_limit_dict`` ranges by
        the engine (``MinMaxRayEngine``) -- confirm.
    """
    if num_params is None:
        # Resolved at call time so later edits to param_limit_dict are seen.
        num_params = len(param_limit_dict)
    return torch.rand(batch_len, num_params, generator=generator)
            
# Histogram transform applied to the traced output; arguments are passed as-is
# (50 bins over x (-10, 10) and y (-2, 2) is the assumed meaning -- confirm
# against XYHistogram).
xy_histogram = XYHistogram(50, (-10,10), (-2,2))

# Dataset generator: samples parameters with sampler_func, traces them through
# the engine, transforms the "ImagePlane" output and stores the results as
# HDF5 files named from "dataset5" under outputs_dir.
gen = EfficientRandomRayDatasetGenerator(
    engine=engine,
    sampler=sampler_func,
    param_limit_dict=param_limit_dict,
    exported_planes=["ImagePlane"],
    transform=xy_histogram,
    fixed_output_size=True,
    h5_datadir=outputs_dir,
    h5_basename="dataset5",
    h5_max_size=100,
)

# Presumably fills HDF5 file index 0 in batches of 5 samples -- confirm.
gen.generate(h5_idx=0, batch_size=5)
