In [None]:
# Supported pulse-injection layouts for generate_toy_data / _inject_traces.
_FILL_MODES = ('csv', '2_2_OF')


def _receives_event(ich, j, fill_mode):
    """Return True if channel ``ich`` should get a trace for event number ``j`` under ``fill_mode``."""
    if fill_mode == 'csv':
        return True
    # '2_2_OF': channels 0 and 1 get every event; channels 2 and 3 alternate (out of phase).
    if ich in (0, 1):
        return True
    if ich == 2:
        return j % 2 == 0
    if ich == 3:
        return j % 2 == 1
    return False


def _inject_traces(waveform, traces_array, event_times, template_length, baseline_lift,
                   fill_mode='2_2_OF', max_events=100):
    """
    Add pulse traces to selected channels of ``waveform`` (in place) and lift the baseline elsewhere.

    Only channels with index 0-14 or 35-49 are touched. For each of the first ``max_events`` entries of
    ``event_times``, the trace in row ``(channel_index * 10 + event_index) % len(traces_array)`` is added
    starting at that sample. It is truncated to ``template_length`` samples and to the end of the waveform.
    Samples of a touched channel that received no trace get ``baseline_lift`` added.

    :param waveform: 2D array indexed as [channel, sample]; modified in place
    :param traces_array: 2D array with one trace per row
    :param event_times: integer sample indices where events start
    :param template_length: maximum number of samples injected per event
    :param baseline_lift: value added to samples of a touched channel not covered by an injected trace
    :param fill_mode: 'csv' injects every event into every touched channel. '2_2_OF' injects every event
        into channels 0 and 1, even-numbered events into channel 2 and odd-numbered events into channel 3
        (out of phase), and no events into the other touched channels.
    :param max_events: maximum number of events injected per channel
    :return: ``waveform``
    :raises ValueError: on an unknown ``fill_mode`` or an empty ``traces_array`` when events are requested
    """
    if fill_mode not in _FILL_MODES:
        raise ValueError(f'Unknown fill_mode {fill_mode!r}; expected one of {_FILL_MODES}.')

    n_rows, n_samples = waveform.shape
    n_traces = traces_array.shape[0]
    n_inject = min(max_events, len(event_times))
    if n_inject > 0 and n_traces == 0:
        raise ValueError('traces_array is empty; cannot inject events.')

    for ich in range(n_rows):
        if not (0 <= ich <= 14 or 35 <= ich <= 49):
            continue
        # Marks samples that carry an injected trace; the traces are assumed to include their own
        # baseline, so only the remaining samples get baseline_lift (TODO confirm against the CSV).
        coverage = np.zeros(n_samples, dtype=bool)
        for j in range(n_inject):
            if not _receives_event(ich, j, fill_mode):
                continue
            start = int(event_times[j])
            trace = traces_array[(ich * 10 + j) % n_traces]
            # Clip to the trace length, the template length and the end of the waveform, so a
            # truncated window never causes a shape mismatch.
            n = min(len(trace), template_length, n_samples - start)
            if n <= 0:
                continue
            waveform[ich, start:start + n] += trace[:n]
            coverage[start:start + n] = True
        waveform[ich, ~coverage] += baseline_lift
    return waveform


def generate_toy_data(run, duration, directory='toy_data', record_length=hx.DEFAULT_RECORD_LENGTH,
                      sampling_dt=hx.DEFAULT_SAMPLING_DT, template_length=hx.DEFAULT_TEMPLATE_LENGTH,
                      channel_map=hx.DEFAULT_CHANNEL_MAP, noise_std=3, event_rate=1, overwrite=False,
                      helix_data_dir='test_helix_data', baseline_step=0, traces_file='traces.csv',
                      fill_mode='2_2_OF'):
    """
    Generate and save toy raw data built from noise plus pulse traces loaded from a CSV file.

    Noise comes from ``generate_silent_traces`` using a pink-noise PSD from ``get_pink_psd``. Traces from
    ``traces_file`` are injected at evenly spaced times (see ``_inject_traces``). Each record is written as
    an lz4-compressed int16 array of shape (channels, record_length) to ``<directory>/<run>/<run>-NNNNN``.
    A ``<run>-metadata.json`` file with start/end timestamps is written to ``helix_data_dir``.

    :param run: run id
    :param duration: run duration in seconds. It caps the number of records written and sets the metadata
        end time. NOTE: the number of batches is currently fixed to 1, so at most one record is written.
    :param directory: output directory, where a directory with the run id name will be created
    :param record_length: length of records in each file in time samples
    :param sampling_dt: sampling time, in helix ``units`` (divided by ``units.s`` to get seconds)
    :param template_length: used to place event start times and as the max samples injected per event
    :param channel_map: dictionary of channel types and channel number ranges
    :param noise_std: standard deviation of the noise (passed to ``get_pink_psd``)
    :param event_rate: rate of injected events in Hz; each batch gets round(event_rate * 31) events
    :param overwrite: whether to overwrite existing data. If False, a RuntimeError is raised when a
        directory with the same run id exists.
    :param helix_data_dir: directory to save the run metadata; should be the helix output directory
    :param baseline_step: add a baseline to each channel, equal to baseline_step * channel_index
    :param traces_file: CSV file of pulse traces, loaded with ``load_traces_from_csv``
    :param fill_mode: trace-injection layout, 'csv' or '2_2_OF' (default, the previous behavior)
    :raises RuntimeError: if the run directory exists and ``overwrite`` is False
    :raises ValueError: if ``fill_mode`` is not supported
    """
    import json
    from datetime import datetime, timedelta

    # Validate before touching the filesystem.
    if fill_mode not in _FILL_MODES:
        raise ValueError(f'Unknown fill_mode {fill_mode!r}; expected one of {_FILL_MODES}.')

    run_dir = os.path.join(directory, run)
    if os.path.exists(run_dir):
        if not overwrite:
            raise RuntimeError(f'Directory {run_dir} already exists.')
        shutil.rmtree(run_dir)
    os.makedirs(run_dir)

    traces_array = load_traces_from_csv(traces_file)

    record_length_s = record_length * sampling_dt / units.s
    n_records = int(duration / record_length_s)
    channels = hx.Channels(channel_map)
    n_channels = len(channels)
    batch_size = 1  # number of records per batch
    sampling_frequency = 1 / (sampling_dt / units.s)

    _, psd = get_pink_psd(record_length * batch_size, sampling_dt, noise_std)

    # Per-channel offset, shaped as a column so it broadcasts over the sample axis.
    baseline = (baseline_step * np.arange(n_channels))[:, np.newaxis]

    # NOTE(review): batch count and batch length are hard-coded rather than derived from `duration`
    # and `record_length_s`; kept as-is to preserve the current output.
    n_batches = 1
    batch_length_s = 31
    n_events = np.full(n_batches, int(round(event_rate * batch_length_s)), dtype=int)

    for i in tqdm(range(n_batches)):
        # Assumes generate_silent_traces returns an array of shape (n_channels, samples) — TODO confirm.
        waveform = generate_silent_traces(n_channels, psd, sampling_frequency)
        event_times = np.linspace(0, batch_size * record_length - template_length, num=n_events[i], dtype=int)
        waveform = _inject_traces(waveform, traces_array, event_times, template_length,
                                  hx.DEFAULT_MMC_BASELINE_LIFT, fill_mode=fill_mode)
        # Not in-place, so an integer waveform is not force-cast when baseline_step is a float.
        waveform = waveform + baseline

        for j in range(batch_size):
            i_record = i * batch_size + j
            if i_record >= n_records:
                break
            fn = os.path.join(run_dir, f'{run}-{i_record:05d}')
            data = np.ascontiguousarray(waveform[:, j * record_length:(j + 1) * record_length], dtype=np.int16)
            # NOTE(review): `lz4.compress` is the legacy top-level python-lz4 API; newer releases expose
            # `lz4.block.compress` / `lz4.frame.compress` — confirm which `lz4` this notebook imports.
            with open(fn, mode='wb') as f:
                f.write(lz4.compress(data))

    os.makedirs(helix_data_dir, exist_ok=True)

    metadata_path = os.path.join(helix_data_dir, f"{run}-metadata.json")
    start = datetime.now().replace(microsecond=0)
    end = start + timedelta(seconds=duration)
    metadata = {'start': start.isoformat(), 'end': end.isoformat()}

    with open(metadata_path, 'w') as f:
        json.dump(metadata, f)


In [None]:
import numpy as np
from bson import json_util
import helix as hx
from helix import units
import numpy as np
import strax as sx                    
from matplotlib import pyplot as plt
from glob import glob 
import os
import shutil

# --- Output locations ---
raw_data_dir = "toy_data"  # raw toy records go to <raw_data_dir>/<run>/
helix_data_dir = "test_helix_data"  # run metadata (<run>-metadata.json) is written here

# --- Run parameters ---
run = "run10"
duration = 10  # run length in seconds
baseline_step = 0  # forwarded as generate_toy_data's baseline_step (per-channel baseline = step * channel index)

In [None]:
# Start from an empty metadata directory: delete every entry in it (sub-directories recursively).
for entry in glob(f'{helix_data_dir}/*'):
    remove_entry = shutil.rmtree if os.path.isdir(entry) else os.remove
    remove_entry(entry)

generate_toy_data(
    run,
    duration,
    raw_data_dir,
    helix_data_dir=helix_data_dir,
    overwrite=True,
    baseline_step=baseline_step,
)