In [2]:
import yaml
import numpy as np
import zstandard as zstd
import os
from TraceSimulator import TraceSimulator

def read_yaml_to_dict(file_path):
    """Parse the YAML file at ``file_path`` and return the loaded object.

    The file is opened in text mode and handed to ``yaml.safe_load``; whatever
    that call produces is returned unchanged.
    """
    with open(file_path, 'r') as stream:
        return yaml.safe_load(stream)

# Load the simulator settings from the archived YAML file and build the single
# TraceSimulator instance (`ts`) that the trace-generation cells call.
# The relative path is resolved against the kernel's current working directory.
config = read_yaml_to_dict('../../archive/config.yaml')
ts = TraceSimulator(config)
def save_traces_to_zstd(traces, output_path, dtype=np.float16, trace_shape=(1, 54, 32768), compression_level=15):
    """
    Save a list of numpy arrays (traces) into a compressed Zstandard (.zst) file.

    Each trace is cast to ``dtype`` and byte-shuffled on its own: byte 0 of
    every element is written first, then byte 1 of every element, and so on.
    The shuffled traces are concatenated in input order and compressed in one
    call. The file therefore stores one group of byte planes *per trace*,
    which is the layout ``load_traces_from_zstd`` has to undo.

    Parameters
    ----------
    traces : iterable of np.ndarray
        Arrays to store; each must have exactly ``trace_shape``.
    output_path : str or os.PathLike
        Destination file. Missing parent directories are created. A bare
        file name (no directory component) is written to the current
        working directory.
    dtype : numpy dtype, optional
        Storage dtype each trace is cast to before shuffling.
    trace_shape : tuple of int, optional
        Shape every trace is validated against.
    compression_level : int, optional
        Level passed to ``zstd.ZstdCompressor``.

    Raises
    ------
    ValueError
        If any trace does not have ``trace_shape``.
    """
    # Accept a list as well as a tuple; ndarray.shape is always a tuple.
    expected_shape = tuple(trace_shape)

    def shuffle_bytes(arr: np.ndarray) -> bytes:
        # Rows = elements, columns = bytes of one element; transposing groups
        # equal-significance bytes together, which compresses better.
        return arr.view(np.uint8).reshape(-1, arr.itemsize).T.tobytes()

    # Shuffle every trace separately and collect the pieces.
    chunks = []
    for trace in traces:
        if trace.shape != expected_shape:
            raise ValueError(f"Trace has wrong shape {trace.shape}, expected {trace_shape}")
        # order="C" guarantees a C-contiguous buffer, so the uint8 view is
        # valid and the byte order is the logical element order even when the
        # caller hands in a transposed or strided array.
        chunks.append(shuffle_bytes(trace.astype(dtype, order="C")))

    # Compress and write to file (join once instead of bytearray -> bytes copy)
    compressor = zstd.ZstdCompressor(level=compression_level)
    compressed_data = compressor.compress(b"".join(chunks))

    # os.path.dirname returns "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, 'wb') as f:
        f.write(compressed_data)

def load_traces_from_zstd(input_path, n_traces, dtype=np.float16, trace_shape=(54, 32768)) -> np.ndarray:
    """
    Load a list of numpy arrays (traces) from a compressed Zstandard (.zst) file and return a single stacked ndarray.

    The file is expected in the layout written by ``save_traces_to_zstd``:
    traces are stored back to back, and *within each trace* the bytes are
    shuffled into planes (byte 0 of every element, then byte 1, ...). The
    un-shuffle is therefore applied per trace, not across the whole buffer.

    Parameters
    ----------
    input_path : str or os.PathLike
        Path of the compressed file.
    n_traces : int
        Number of traces stored in the file.
    dtype : numpy dtype, optional
        Element type the traces were stored with.
    trace_shape : tuple of int, optional
        Shape given to each returned trace. Only its element count has to
        match what was saved, since the size check uses the product only.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_traces,) + trace_shape`` and dtype ``dtype``.

    Raises
    ------
    ValueError
        If the decompressed byte count is not
        ``n_traces * prod(trace_shape) * itemsize``.
    """
    trace_shape = tuple(trace_shape)
    itemsize = np.dtype(dtype).itemsize
    elems_per_trace = int(np.prod(trace_shape))

    decompressor = zstd.ZstdDecompressor()
    with open(input_path, 'rb') as f:
        decompressed = decompressor.decompress(f.read())

    expected_size = n_traces * elems_per_trace * itemsize
    if len(decompressed) != expected_size:
        raise ValueError("Decompressed size does not match expected size")

    # Axis 0: trace, axis 1: byte plane, axis 2: element index within the trace.
    byte_planes = np.frombuffer(decompressed, dtype=np.uint8).reshape(
        n_traces, itemsize, elems_per_trace
    )
    # Swap plane/element axes so each element's bytes are adjacent again; the
    # contiguous copy is required before the bytes can be viewed as `dtype`.
    interleaved = np.ascontiguousarray(byte_planes.transpose(0, 2, 1))
    return interleaved.view(dtype).reshape((n_traces,) + trace_shape)


# --- Example for 100 traces with same energy ---

# Generate 100 traces (pseudo-code, replace ts.generate with your real generator)
# energy = 50
# n_sets = 100
# all_traces = []
# for _ in range(n_sets):
#     trace, _ = ts.generate(energy, type_recoil='NR', no_noise=False)
#     all_traces.append(np.asarray(trace, dtype=np.float16))

# Save them
# save_traces_to_zstd(all_traces, f"compressed_traces/traces_energy_{energy}.zst")

# Later load them back
# loaded_traces = load_traces_from_zstd(f"compressed_traces/traces_energy_{energy}.zst", n_traces=n_sets)


In [5]:
# Simulate n_sets traces at a single energy with noise enabled (no_noise=False)
# and write them all to one compressed file.
# NOTE(review): x/y/z are presumably the event position passed to the simulator —
# confirm units and meaning against TraceSimulator.generate.
energy = 0
n_sets = 1000
# Every trace is kept in memory until the single save call below.
all_traces = []
for _ in range(n_sets):
    trace = ts.generate(E=energy, x=-40, y=80, z=-1800, no_noise=False, type_recoil='NR', quantize=True, phonon_only=False)
    
    # Cast to float16 here; save_traces_to_zstd casts to its default dtype (float16) again.
    all_traces.append(np.asarray(trace, dtype=np.float16))

# save_traces_to_zstd checks each trace against its default trace_shape (1, 54, 32768)
# and raises ValueError on a mismatch.
save_traces_to_zstd(all_traces, f"/ceph/dwong/trigger_samples/traces_energy_{energy}.zst")



In [3]:
import numpy as np
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from tqdm import tqdm

# This cell does not define the simulator or the save helper itself:
# `ts` and `save_traces_to_zstd` must already exist in the kernel, so run the cell that defines them first.

def generate_and_save_traces(energy, n_sets=100,
                             output_dir="/ceph/dwong/trigger_samples/clean_traces"):
    """Simulate ``n_sets`` noise-free traces at ``energy`` and save them to one file.

    Each trace comes from the module-level simulator ``ts`` with
    ``no_noise=True`` and is converted to a float16 array. The whole batch is
    then written by ``save_traces_to_zstd`` to
    ``<output_dir>/clean_traces_energy_<energy>.zst``.

    Parameters
    ----------
    energy : number
        Value passed to ``ts.generate`` as ``E`` and embedded in the file name.
    n_sets : int, optional
        Number of traces to generate (default 100, the previous fixed value).
    output_dir : str or os.PathLike, optional
        Directory the file is written into (defaults to the previous
        hard-coded location).
    """
    all_traces = [
        np.asarray(
            ts.generate(E=energy, x=-40, y=80, z=-1800,
                        no_noise=True, type_recoil='NR',
                        quantize=True, phonon_only=False),
            dtype=np.float16,
        )
        for _ in range(n_sets)
    ]

    output_path = Path(output_dir) / f"clean_traces_energy_{energy}.zst"
    save_traces_to_zstd(all_traces, output_path)


def main():
    """Generate and save clean traces for energies 50, 55, ..., 150 in a thread pool.

    One ``generate_and_save_traces`` job is submitted per energy to a pool of
    at most 10 threads, and a tqdm bar advances as jobs finish. Every job is
    allowed to finish; failures are collected and reported together.

    Raises
    ------
    RuntimeError
        If one or more jobs raised. The message lists each failing energy
        with its exception, and the error for the lowest failing energy is
        chained as the cause.
    """
    energy_values = list(range(50, 151, 5))
    max_threads = 10

    failures = {}
    with ThreadPoolExecutor(max_workers=max_threads) as executor:
        futures = {executor.submit(generate_and_save_traces, energy): energy for energy in energy_values}
        for future in tqdm(as_completed(futures), total=len(futures), desc="Generating Traces"):
            # A worker's exception is stored on its future and never surfaces
            # unless it is asked for; without this check a failed energy would
            # still count towards a "100%" progress bar.
            error = future.exception()
            if error is not None:
                failures[futures[future]] = error

    if failures:
        summary = "; ".join(f"E={energy}: {error!r}" for energy, error in sorted(failures.items()))
        first_error = failures[min(failures)]
        raise RuntimeError(
            f"Trace generation failed for {len(failures)} energies: {summary}"
        ) from first_error


# Run the batch only when this code executes as the top-level module; the
# progress-bar output recorded for this cell shows the guard passed in the notebook.
if __name__ == "__main__":
    main()


Generating Traces: 100%|██████████| 21/21 [07:34<00:00, 21.65s/it]
