In [None]:
from pathlib import Path
import h5py
import math
import xraylib
import numpy as np
import json
import sys
import hdf5plugin
from swmr_tools import DataSource, utils, KeyFollower
import blosc

from time import perf_counter,sleep

In [None]:
# Run parameters. No values are assigned in this cell — presumably they are
# injected by the notebook runner (papermill-style parameters cell); TODO confirm.
# Whitespace-separated "Symbol-Line" entries, e.g. "Fe-Ka Cu-Kb";
# supported line names are Ka, Kb, La, Lb and M.
element_list: str
# Width of each element integration window, in MCA channels.
window_width: int
# Input NeXus file name, joined onto /inputs/.
inpath: str
# Output HDF5 file name, joined onto /outputs/.
outpath: str

In [None]:
# Optional live-update messenger: the notebook still runs (with daq = None) when the
# package is missing or the connection fails.
sys.path.append("/dls_sw/i18/software/daqmessenger/daq-messenger")
daq = None
try:
    from daqmessenger import DaqMessenger
    messenger = DaqMessenger("i18-control")
    messenger.connect()
    # Only publish the messenger once connected, so later `if daq:` checks never
    # talk to a half-initialised, unconnected instance.
    daq = messenger
except Exception as ex:
    # Best effort: keep going without live updates, but say why.
    print("no messenger")
    print(f"  reason: {ex!r}")

In [None]:
# Normalise the run parameters.
# split() with no argument splits on any run of whitespace, so doubled or trailing
# spaces do not produce empty entries (which would crash the window calculation).
requested = element_list.split()
width = window_width
# Joining onto a fixed root; an already-absolute path is left unchanged by joinpath,
# so re-running this cell is harmless.
inpath = Path("/inputs/").joinpath(inpath)
outpath = Path("/outputs/").joinpath(outpath)

In [None]:
# Energy calibration of the MCA: keV per channel and the energy of channel 0.
gain = 0.01
zero = 0

# Supported line suffixes -> (xraylib line id, label used in the output group name).
lines = {"Ka": (xraylib.KL3_LINE, "Ka"),
         "Kb": (xraylib.KM3_LINE, "Kb"),
         "La": (xraylib.L3M5_LINE, "La"),
         "Lb": (xraylib.L2M4_LINE, "Lb"),
         "M": (xraylib.M5N7_LINE, "M")}

# One (name, start channel, end channel exclusive) tuple per requested element line,
# centred on the tabulated line energy and `width` channels wide.
line_out = []
for l in requested:
    symbol, sep, line_name = l.partition("-")
    if not sep or line_name not in lines:
        # Malformed entry (no "-") or unsupported line: report and skip it rather
        # than crashing with IndexError or dropping it silently.
        if l:
            print(f"Skipping unsupported element line '{l}'")
        continue
    line_id, label = lines[line_name]
    z = xraylib.SymbolToAtomicNumber(symbol)
    e = xraylib.LineEnergy(z, line_id)
    channel = math.floor((e - zero) / gain)
    first = math.floor(channel - width / 2)
    # Clamp at channel 0: a negative start would make mca[start:end] wrap around
    # and select the wrong channels (or none at all) for low-energy lines.
    start = max(0, first)
    end = first + width
    line_out.append((symbol + "-" + label, start, end))

print("Windows used {}".format(line_out))

In [None]:
# Dataset paths inside the input NeXus file.
key_name = "/entry/diamond_scan/keys"
data = "/entry/Xspress3A/data"
finished = "/entry/diamond_scan/scan_finished"

utils.check_file_readable(inpath, [data])
start = perf_counter()

# Give HDF5 a larger chunk cache (1 GiB) for the output file.
cache_space = 1024*1024*1024
# Seconds between pushes of partial results to the output file / messenger.
flush_interval = 2
# Shape of one raw chunk of the detector dataset; the axis order is assumed to be
# (scan, scan, detector channel, energy bin) — TODO confirm against the writer.
frame_shape = (1, 1, 8, 4096)

with h5py.File(inpath, 'r', libver='latest', swmr=True) as fh, h5py.File(outpath, 'w', libver='latest', swmr=True, rdcc_nbytes=cache_space) as ofh:

    kf = KeyFollower(fh, [key_name], timeout=60, finished_dataset=finished)
    kf.check_datasets()

    e = utils.create_nxentry(ofh, "processed")
    nxdm = utils.create_nxdata(e, "mca", default=False)
    nxdm.attrs["signal"] = "data"

    nxdm.create_dataset("energy", data=np.arange(4096)*gain)
    nxdm.attrs["energy_indices"] = 2

    utils.copy_nexus_axes(fh["/entry/Xspress3A"], nxdm, kf.scan_rank, frame_axes=["energy"])

    # One NXdata group per element window.
    nxdata = {}
    for l in line_out:
        nxdw = utils.create_nxdata(e, l[0])
        utils.copy_nexus_axes(fh["/entry/Xspress3A"], nxdw, kf.scan_rank)
        nxdw.attrs["signal"] = "data"
        nxdata[l[0]] = nxdw

    mca_ds = None  # output MCA dataset, created from the first frame
    window = {}    # window name -> output HDF5 dataset
    winnp = {}     # window name -> in-memory map, copied into `window` on flush

    ds = fh[data]
    maxs = ds.maxshape
    imshape = (maxs[0], maxs[1])

    flush_timer = perf_counter()
    for k in kf:
        # NOTE(review): relies on a private KeyFollower attribute — confirm it is stable.
        if not kf._prelim_finished_check:
            sleep(0.001)
        shape = ds.shape
        index = k

        try:
            # might fail if dataset is cached
            pos = utils.get_position(index, shape, kf.scan_rank)
        except ValueError:
            # refresh dataset and try again
            if hasattr(ds, "refresh"):
                ds.refresh()
            shape = ds.shape
            pos = utils.get_position(index, shape, kf.scan_rank)

        rank = len(shape)
        slices = [slice(0, None, 1)] * rank
        for i in range(len(pos)):
            slices[i] = slice(pos[i], pos[i] + 1)
        slicemd = tuple(slices[: kf.scan_rank])

        chunk = (pos[0], pos[1], 0, 0)
        try:
            out = ds.id.read_direct_chunk(chunk)
        except Exception:
            # The chunk may not be visible to this reader yet: wait, refresh, retry once.
            sleep(1)
            if hasattr(ds, "refresh"):
                ds.refresh()
            out = ds.id.read_direct_chunk(chunk)

        # out[1] holds the raw (blosc-compressed) chunk bytes.
        decom = blosc.decompress(out[1])
        frame = np.frombuffer(decom, dtype=np.float64, count=-1).reshape(frame_shape)
        # Sum over every non-energy axis -> a single spectrum of frame_shape[-1] bins.
        mca = frame.reshape(-1, frame_shape[-1]).sum(axis=0)

        if mca_ds is None:
            mca_ds = utils.create_dataset(mca, kf.maxshape, nxdm, "data", chunks=(1, 1, 4096), compression="lzf")
            for l in line_out:
                w = mca[l[1]:l[2]].sum()
                winnpd = np.zeros(imshape)
                winnpd[pos] = w
                winnp[l[0]] = winnpd
                window_ds = nxdata[l[0]].create_dataset("data", data=winnpd, chunks=imshape)
                window[l[0]] = window_ds

            # All datasets exist now, so SWMR writing can start.
            ofh.swmr_mode = True
            if daq:
                daq.send_start(outpath)
        else:
            utils.append_data(mca, slicemd, mca_ds)
            for l in line_out:
                w = mca[l[1]:l[2]].sum()
                utils.append_data(w, slicemd, winnp[l[0]])

        now = perf_counter()
        if (now - flush_timer) > flush_interval:
            mca_ds.flush()
            for l in line_out:
                window[l[0]][...] = winnp[l[0]]
                window[l[0]].flush()
            # Reset once per flush regardless of how many windows exist; inside the
            # window loop it was never reset when no windows were requested.
            flush_timer = now
            if daq:
                daq.send_update(outpath)

    if mca_ds is None:
        # No frames arrived before the follower stopped: nothing was created to flush.
        print("No frames were processed")
    else:
        mca_ds.flush()
        for l in line_out:
            window[l[0]][...] = winnp[l[0]]
            window[l[0]].flush()

if daq:
    daq.send_finished(outpath)

print(f"Completed in {(perf_counter()-start):.1f} seconds")