In [None]:
import mcstasscript as ms
import make_powder_instrument
import quizlib

%matplotlib widget

import numpy as np
import mcstastox
import scipp as sc

from scippneutron.conversion.graph.beamline import beamline
from scippneutron.conversion.graph.tof import elastic

def read_event_union(data, source_name="Source", sample_name="sample_position"):
    """Build a scipp event list from Union-style event monitors in a McStas NeXus file.

    Parameters
    ----------
    data : list
        McStasScript data objects returned by ``instrument.backengine()``. The
        NeXus file is located via ``data[0].original_data_location`` and each
        detector's event columns are looked up by name with ``ms.name_search``.
    source_name, sample_name : str
        Component names whose placements become the ``source_position`` and
        ``sample_position`` coordinates.

    Returns
    -------
    scipp.DataArray
        1-D over ``events`` with the ``p`` column as data (unit counts) and
        coordinates ``position`` (from ``file.transform``, m), ``t`` (s),
        ``source_position`` and ``sample_position`` (m).

    Raises
    ------
    ValueError
        If the file reports no components with ids, i.e. there are no events to read.
    """
    with mcstastox.Read(data[0].original_data_location) as file:
        source_position, _source_rotation = file.get_component_placement(source_name)
        sample_position, _sample_rotation = file.get_component_placement(sample_name)

        # Collect per-detector chunks and concatenate once at the end instead
        # of re-allocating the growing arrays on every detector.
        position_chunks = []
        weight_chunks = []
        time_chunks = []
        for detector in file.get_components_with_ids():
            mon = ms.name_search(detector, data)

            # Only the local y coordinate is read from the events; x and z are
            # set to zero before transforming into the file's global frame.
            # NOTE(review): presumably events lie along the detector's local y axis — confirm.
            y = mon.get_data_column("y")
            local_positions = np.column_stack((np.zeros_like(y), y, np.zeros_like(y)))
            position_chunks.append(file.transform(local_positions, detector))
            weight_chunks.append(mon.get_data_column("p"))
            time_chunks.append(mon.get_data_column("t"))

        if not position_chunks:
            raise ValueError(
                f"No detector components with ids found in {data[0].original_data_location!r}"
            )

        global_positions_all = np.vstack(position_chunks)
        event_weights = np.concatenate(weight_chunks)
        event_times = np.concatenate(time_chunks)

        events = sc.DataArray(
                data=sc.array(dims=['events'], unit=sc.units.counts, values=event_weights),
                coords={
                    'position' : sc.vectors(dims=['events'], values=global_positions_all, unit='m'),
                    't': sc.array(dims=['events'], unit='s', values=event_times),
                    'source_position': sc.vector(source_position, unit='m'),
                    'sample_position': sc.vector(sample_position, unit='m'),
                })

        return events

def read_event_classic(data, source_name="Source", sample_name="sample_position"):
    """Read events from the two banana detectors of a McStas NeXus file.

    The base event list comes from ``file.export_scipp_simple``. If both
    detectors recorded the ``L``, ``v``, ``U1`` and ``U2`` event variables,
    simulation-truth coordinates ``sim_wavelength`` (Å), ``sim_source_time``
    (s), ``sim_scattering_order`` and ``sim_speed`` (m/s) are attached. The
    detector-1 events come first, then detector 2.
    NOTE(review): this assumes export_scipp_simple orders events the same way — confirm.
    Finally the ``x``, ``y`` and ``z`` components of ``position`` are exposed
    as separate coordinates in metres.

    Parameters
    ----------
    data : list
        McStasScript data objects; ``data[0].original_data_location`` is the NeXus file.
    source_name, sample_name : str
        Component names forwarded to ``export_scipp_simple``.

    Returns
    -------
    scipp.DataArray
        The event list with the extra coordinates described above.
    """
    detector_names = ("Banana_large", "Banana_small")
    required_vars = ["L", "v", "U1", "U2"]
    # (coordinate name, event variable, unit); None means "no unit argument",
    # i.e. scipp's default unit is used.
    truth_spec = (
        ("sim_wavelength", "L", "Å"),
        ("sim_source_time", "U1", "s"),
        ("sim_scattering_order", "U2", None),
        ("sim_speed", "v", "m/s"),
    )

    with mcstastox.Read(data[0].original_data_location) as file:
        events = file.export_scipp_simple(source_name, sample_name)

        available_per_detector = [file.get_component_variables(name) for name in detector_names]
        has_truth = all(
            all(var in available for var in required_vars)
            for available in available_per_detector
        )

        if has_truth:
            raw_blocks = [
                file.get_event_data(variables=["id", "v", "L", "U1", "U2"], component_name=name)
                for name in detector_names
            ]
            for coord_name, column, unit in truth_spec:
                stacked = np.concatenate(tuple(block[column] for block in raw_blocks))
                unit_kwargs = {} if unit is None else {"unit": unit}
                events.coords[coord_name] = sc.array(dims=["events"], values=stacked, **unit_kwargs)

    for axis in ("x", "y", "z"):
        component = getattr(events.coords["position"].fields, axis)
        events.coords[axis] = sc.array(dims=["events"], values=component.values, unit="m")

    return events


def analyze(data, union_detectors=False, time_offset_ms=2.86 / 2):
    """Read simulated events and convert them to d-spacing and wavelength.

    Parameters
    ----------
    data : list
        McStasScript data objects from ``instrument.backengine()``.
    union_detectors : bool
        If True read with ``read_event_union``, otherwise ``read_event_classic``.
    time_offset_ms : float
        Offset in milliseconds subtracted from the McStas event time ``t`` to
        form ``tof``. The default ``2.86 / 2`` reproduces the previous
        hard-coded value.
        NOTE(review): presumably half of a 2.86 ms source pulse, i.e. a mean
        emission time — confirm.

    Returns
    -------
    scipp.DataArray
        The event list with ``tof``, ``dspacing`` and ``wavelength``
        coordinates. It also holds whatever intermediate coordinates
        ``transform_coords`` keeps; the notebook later bins on ``two_theta``.
    """
    if union_detectors:
        events = read_event_union(data)
    else:
        events = read_event_classic(data)

    # McStas provides absolute time, not time of flight, so shift by a fixed
    # offset to measure tof from the assumed emission time.
    time_offset = sc.to_unit(sc.scalar(time_offset_ms, unit="ms"), "s")
    events.coords["tof"] = events.coords["t"] - time_offset

    graph = {**beamline(scatter=True), **elastic("tof")}
    events = events.transform_coords("dspacing", graph=graph)
    events = events.transform_coords("wavelength", graph=graph)

    return events

def run(par_dict=None, settings_dict=None, union_detectors=False):
    """Build the powder instrument, run the simulation and return analysed events.

    Parameters
    ----------
    par_dict : dict or None
        Instrument parameters forwarded to ``instrument.set_parameters``.
    settings_dict : dict or None
        Run settings forwarded to ``instrument.settings``. It is copied, so the
        caller's dict is not modified. ``NeXus`` defaults to True because the
        event readers open the simulation output file with mcstastox.
    union_detectors : bool
        Selects the Union-detector instrument variant and the matching event
        reader in ``analyze``.

    Returns
    -------
    scipp.DataArray
        The events returned by ``analyze``.
    """
    par_dict = {} if par_dict is None else par_dict
    # Copy so adding the NeXus default does not leak into the caller's dict.
    settings_dict = {} if settings_dict is None else dict(settings_dict)
    settings_dict.setdefault("NeXus", True)

    # Pass union_detectors to both the instrument and the reader so they agree.
    instrument = make_powder_instrument.make(union_detectors=union_detectors)
    instrument.set_parameters(**par_dict)
    instrument.settings(**settings_dict)
    data = instrument.backengine()

    return analyze(data, union_detectors=union_detectors)

def plot_with_dspacings(events, filename, d_min=0):
    """Histogram events in d-spacing and overlay reference reflection positions.

    Parameters
    ----------
    events : scipp.DataArray
        Events with a ``dspacing`` coordinate, e.g. from ``analyze``.
    filename : str
        Whitespace-separated reflection table readable by ``np.loadtxt``.
        Column 5 is used as d-spacing and column 11 as intensity.
    d_min : float
        Only reflections with d-spacing strictly greater than this are drawn.

    Returns
    -------
    The figure from ``d_hist.plot()``. It has one vertical line per drawn
    reflection, spanning 0 to the histogram maximum, with opacity equal to the
    reflection's intensity relative to the strongest one.
    """
    d_hist = events.hist(dspacing=sc.linspace("dspacing", 0.6, 3.0, num=800, unit="Å", endpoint=False))
    max_val = np.max(d_hist.values)

    # ndmin=2 keeps a single-row table 2-D so column indexing still works.
    reflections = np.loadtxt(filename, ndmin=2)
    d_spacings = reflections[:, 5]
    intensity = reflections[:, 11]

    # An all-zero intensity column would give NaN alphas, which matplotlib
    # rejects. Draw such reflections fully opaque instead.
    peak = intensity.max()
    if peak > 0:
        intensity_normalized = intensity / peak
    else:
        intensity_normalized = np.ones_like(intensity)

    fig = d_hist.plot()

    for d_space, intensity_norm in zip(d_spacings, intensity_normalized):
        if d_space > d_min:
            fig.ax.plot([d_space, d_space], [0, max_val], "-k", alpha=intensity_norm)

    return fig

def run_and_plot(par_dict=None, settings_dict=None, union_detectors=False, d_min=0.5):
    """Run the instrument and plot the d-spacing histogram with reference reflections.

    Parameters
    ----------
    par_dict, settings_dict, union_detectors
        Forwarded unchanged to ``run``.
    d_min : float
        Forwarded to ``plot_with_dspacings``. The default 0.5 keeps the previous
        behavior.

    Returns
    -------
    The figure returned by ``plot_with_dspacings``.

    Notes
    -----
    The reflection file is read from ``par_dict["reflections"]`` with
    surrounding double quotes removed. The notebook passes it as ``'"Fe.laz"'``.
    Otherwise ``"Na2Ca3Al2F14.laz"`` is used.
    NOTE(review): presumably this matches the instrument's default reflections — confirm.
    """
    events = run(par_dict, settings_dict, union_detectors)

    # par_dict defaults to None, so don't assume it is a dict here.
    reflections = (par_dict or {}).get("reflections")
    if reflections is not None:
        filename = reflections.strip('"')
    else:
        filename = "Na2Ca3Al2F14.laz"

    return plot_with_dspacings(events, filename, d_min=d_min)

In [None]:
# Build the instrument variant without Union detectors and list its settable parameters.
instrument = make_powder_instrument.make(union_detectors=False)
instrument.show_parameters()

### Quick run for testing

In [None]:
# The reflections value carries inner double quotes; run_and_plot strips them
# again to locate the reflection file for the overlay.
pars = dict(reflections='"Fe.laz"', frequency_multiplier=3, guide_curve_deg=0, detector_height=1.5)
# 1E8 rays on 10 MPI processes; reduce ncount/mpi on smaller machines.
settings = dict(suppress_output=True, ncount=1E8, mpi=10)

fig = run_and_plot(pars, settings)
fig

### Run and work with scipp object directly

In [None]:
# Halve the ray count for this run. This modifies the `settings` dict from the previous cell.
settings["ncount"]=5E7
events = run(pars, settings)

In [None]:
# NOTE(review): analyze() (called inside run()) already adds a "wavelength"
# coordinate with this same graph, so this transform looks redundant — confirm.
graph = {**beamline(scatter=True), **elastic("tof")}
events = events.transform_coords("wavelength", graph=graph)

events

In [None]:
# NOTE(review): this import would normally live in the first (imports) cell.
import plopp as pp

# 3-D scatter of every third event, placed at its "position" coordinate.
pp.scatter3d(events[0::3], pos='position', size=0.02, cbar=True, norm="linear")

In [None]:
# Events binned in reconstructed wavelength vs two_theta, on a log colour scale.
# two_theta is presumably an intermediate coordinate kept by transform_coords in analyze() — confirm.
events_binned = events.bin(wavelength=120, two_theta=800)
pp.slicer(events_binned.hist(), keep=["wavelength", "two_theta"], norm="log", vmax=1E6, vmin=1E-3)

In [None]:
# Same map using the simulated wavelength recorded by McStas. sim_wavelength only
# exists when read_event_classic found the L/v/U1/U2 event variables.
events_binned = events.bin(sim_wavelength=120, two_theta=800)
pp.slicer(events_binned.hist(), keep=["sim_wavelength", "two_theta"], norm="log", vmax=1E6, vmin=1E-3)

In [None]:
# Histogram of the raw McStas event time t (absolute time, before the tof offset).
events.hist(t=200).plot()

### Run the instrument directly, without the helper functions

In [None]:
# Build the instrument again so this section does not rely on the helper functions.
instrument = make_powder_instrument.make(union_detectors=False)

In [None]:
# List the parameters that can be set on this instrument.
instrument.show_parameters()

In [None]:
# NeXus=True is needed because analyze() reads the events back with mcstastox.
# custom_flags presumably forwards --bufsiz to McStas (event buffer size) — confirm.
instrument.set_parameters(frequency_multiplier=3, reflections='"Fe.laz"')
instrument.settings(ncount=3e7, mpi=10, suppress_output=False, NeXus=True, output_path="first_run", custom_flags="--bufsiz=100000")

In [None]:
# Run the simulation; the returned data objects are indexed as data[0], data[1] below.
data = instrument.backengine()

In [None]:
# Display the data objects returned by backengine().
data

In [None]:
# Plot every monitor on a log scale limited to 5 orders of magnitude.
ms.make_sub_plot(data, log=True, orders_of_mag=5)

In [None]:
# Re-histogram the first two data objects as th vs y maps (presumably built from their event columns — confirm).
ms.make_sub_plot([data[0].make_2d("th", "y"), data[1].make_2d("th", "y")], log=True, orders_of_mag=5)

### Visualize instrument (opens in new window)

In [None]:
# Disabled on purpose: uncomment to open the instrument view in a separate window.
#instrument.show_instrument(format="window")

In [None]:
# Show the instrument file (presumably the generated McStas instrument description — confirm).
instrument.show_instrument_file()

In [None]:
# Apply the same event analysis that run() uses to this direct simulation run.
events = analyze(data)
events