# SAMPIC Noise Filtering & Average Shape

Use this notebook to identify and remove purely noisy waveforms from SAMPIC data, inspect summary statistics, and visualise the aggregate shape of the remaining signals.

## Workflow

1. Load the unpacked ROOT file and required libraries.
2. Classify each hit waveform using configurable noise rejection criteria.
3. Report statistics (overall and per channel) on filtered vs retained hits.
4. Visualise channel-level filtering and build a 2D heatmap plus average waveform for the surviving signals.

Tune the thresholds in the **Filtering Parameters** cell to match your dataset.

In [None]:
import ROOT
import numpy as np
import matplotlib.pyplot as plt
import os
from collections import defaultdict

# Notebook-wide plotting defaults: larger figures, a faint grid and a readable font.
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'axes.grid': True,
    'grid.alpha': 0.25,
    'font.size': 12,
})

In [None]:
# Input file and shared-library location. Both can be overridden from the
# environment so the notebook runs against other outputs/build trees unedited.
DATA_FILE = os.getenv('SAMPIC_OUTPUT', '../../output.root')
LIB_DIR = os.getenv('SAMPIC_LIB_DIR', '../../build/lib')
LIBS_TO_LOAD = [
    'libanalysis_pipeline_core.so',
    'libunpacker_data_products_core.so',
    'libunpacker_data_products_sampic.so',
]

for lib in LIBS_TO_LOAD:
    path = os.path.join(LIB_DIR, lib)
    if not os.path.exists(path):
        print(f'Warning: {lib} not found in {LIB_DIR}')
        continue
    # TSystem::Load returns 0 on success, 1 if the library was already loaded
    # and a negative value on failure -- report failures instead of ignoring them.
    status = ROOT.gSystem.Load(path)
    if status < 0:
        print(f'Warning: failed to load {path} (gSystem.Load returned {status})')

In [None]:
# Open the unpacked output and fetch its event tree, failing fast if either is unavailable.
file = ROOT.TFile.Open(DATA_FILE)
file_ok = bool(file) and not file.IsZombie()
if not file_ok:
    raise RuntimeError(f'Unable to open ROOT file: {DATA_FILE}')

tree = file.Get('events')
if not tree:
    raise RuntimeError("Tree 'events' not found in file")

print(f'Total events available: {tree.GetEntries()}')
branch_names = [br.GetName() for br in tree.GetListOfBranches()]
print('Branches:', branch_names)

## Filtering Parameters

Adjust these thresholds to control how aggressively noise is removed. By default a waveform is treated as noise when **any** of the following holds:

- Its dynamic range (max - min) is below `AMP_THRESHOLD` **and** the magnitude of its stored peak is below `PEAK_THRESHOLD`.
- Its standard deviation about the baseline is below `STD_THRESHOLD`.
- Its time-over-threshold is at or below `TOT_THRESHOLD`.

All surviving waveforms are baseline-subtracted (baseline = mean of the first `BASELINE_SAMPLES` samples) before the heatmap and average waveform are built.

In [None]:
# --- Noise-rejection thresholds: hits falling below these look like noise ---
AMP_THRESHOLD = 25.0     # dynamic range (max - min) cut [ADC counts]
PEAK_THRESHOLD = 25.0    # |stored peak| cut [ADC counts]
STD_THRESHOLD = 2.0      # std dev about the baseline cut [ADC counts]
TOT_THRESHOLD = 0.5      # time-over-threshold cut [ns]

# --- Scan / preprocessing settings ---
BASELINE_SAMPLES = 8     # leading samples averaged to estimate each waveform's baseline
MAX_EVENTS = None        # None scans every entry; an integer (e.g. 5000) caps the scan

print('Filtering thresholds set – adjust as needed and re-run downstream cells if you change them.')

In [None]:
def extract_waveform(hit):
    """Return ``hit.corrected_waveform`` as a float32 numpy array.

    The container is consumed by iteration (no per-index lookups), and its
    length is passed as ``count`` so numpy can preallocate the result.
    An empty container yields an empty float32 array.
    """
    wf = hit.corrected_waveform
    return np.fromiter(wf, dtype=np.float32, count=len(wf))


def classify_hit(hit):
    """Baseline-subtract a hit's waveform and decide whether it looks like noise.

    The baseline is the mean of the first ``BASELINE_SAMPLES`` samples (the
    whole waveform when it is shorter, or when ``BASELINE_SAMPLES`` is not
    positive). A hit is flagged as noise when ANY of these holds:

    * dynamic range < ``AMP_THRESHOLD`` and ``|peak|`` < ``PEAK_THRESHOLD``;
    * std dev of the baseline-subtracted samples < ``STD_THRESHOLD``;
    * time-over-threshold <= ``TOT_THRESHOLD``.

    Args:
        hit: hit object exposing ``corrected_waveform``; ``peak`` and
            ``tot_value`` are used when present.

    Returns:
        ``(centered, stats)``: the baseline-subtracted waveform and a dict with
        float ``amplitude``, ``peak``, ``tot``, ``std`` and bool ``is_noise``.
        Empty waveforms are returned unchanged and always flagged as noise.
    """
    waveform = extract_waveform(hit)
    # NOTE(review): hits lacking a ``tot_value`` attribute default to 0.0 and are
    # therefore always classified as noise by the ToT criterion.
    tot = float(getattr(hit, 'tot_value', 0.0))

    if waveform.size == 0:
        return waveform, {
            'amplitude': 0.0,
            'peak': 0.0,
            'tot': tot,
            'std': 0.0,
            'is_noise': True,
        }

    # Slicing past the end already returns the whole array, so short waveforms
    # fall back to their full mean. A non-positive setting would otherwise give
    # an empty slice (NaN baseline) or silently drop trailing samples.
    n_base = BASELINE_SAMPLES if BASELINE_SAMPLES > 0 else waveform.size
    baseline = waveform[:n_base].mean()
    centered = waveform - baseline
    amplitude = float(centered.max() - centered.min())
    # Prefer the peak stored on the hit; fall back to the waveform maximum.
    peak = float(hit.peak) if hasattr(hit, 'peak') else float(centered.max())
    std = float(centered.std())

    noise_like = ((amplitude < AMP_THRESHOLD and abs(peak) < PEAK_THRESHOLD)
                  or std < STD_THRESHOLD
                  or tot <= TOT_THRESHOLD)

    return centered, {
        'amplitude': amplitude,
        'peak': peak,
        'tot': tot,
        'std': std,
        'is_noise': noise_like,
    }

In [None]:
# Scan the tree, classify every hit and keep only the non-noise ones.
# `summary` holds overall counters, `channel_stats` per-channel total/noise counts.
summary = dict(events_scanned=0, total_hits=0, noise_hits=0)
channel_stats = defaultdict(lambda: {'total': 0, 'noise': 0})

# Parallel lists describing each retained (non-noise) hit.
non_noise_waveforms = []
non_noise_channels = []
non_noise_amplitudes = []
non_noise_tot = []

entries = tree.GetEntries()
if MAX_EVENTS is not None:
    entries = min(MAX_EVENTS, entries)

for entry_idx in range(entries):
    tree.GetEntry(entry_idx)
    summary['events_scanned'] += 1
    for hit in tree.sampic_event.hits:
        waveform, stats = classify_hit(hit)
        ch = int(hit.channel)
        ch_counts = channel_stats[ch]
        summary['total_hits'] += 1
        ch_counts['total'] += 1

        if not stats['is_noise']:
            non_noise_waveforms.append(waveform)
            non_noise_channels.append(ch)
            non_noise_amplitudes.append(stats['amplitude'])
            non_noise_tot.append(stats['tot'])
        else:
            summary['noise_hits'] += 1
            ch_counts['noise'] += 1

print(f"Processed {summary['events_scanned']} events")
print(f"Total hits inspected: {summary['total_hits']}")
print(f"Noise-like hits filtered: {summary['noise_hits']} ({summary['noise_hits']/max(summary['total_hits'],1):.2%})")

In [None]:
# Per-channel text summary of how many hits were filtered vs retained.
channels = sorted(channel_stats)
print(f"Channels seen: {channels}")
print('\nPer-channel summary:')
header = f"{'Ch':>3} | {'Hits':>6} | {'Filtered':>8} | {'Retained':>8} | {'Filtered %':>10}"
print(header)
print('-' * len(header))  # separator always matches the header width
for ch in channels:
    totals = channel_stats[ch]
    filtered = totals['noise']
    retained = totals['total'] - filtered
    frac = (filtered / totals['total']) if totals['total'] else 0.0
    # Width 10 matches the 'Filtered %' header column so the rows line up.
    print(f"{ch:>3d} | {totals['total']:>6d} | {filtered:>8d} | {retained:>8d} | {frac:>10.1%}")

In [None]:
if channels:
    # Side-by-side bars: retained vs filtered hit counts for every channel seen.
    filtered_counts = [channel_stats[ch]['noise'] for ch in channels]
    retained_counts = [channel_stats[ch]['total'] - n_noise
                       for ch, n_noise in zip(channels, filtered_counts)]

    x = np.arange(len(channels))
    width = 0.35  # two bars share each channel slot

    fig, ax = plt.subplots(figsize=(14, 5))
    ax.bar(x - width/2, retained_counts, width, label='Retained')
    ax.bar(x + width/2, filtered_counts, width, label='Filtered (noise)')
    ax.set_xticks(x)
    ax.set_xticklabels(channels)
    ax.set_xlabel('Channel')
    ax.set_ylabel('Hit count')
    ax.set_title('Per-channel retained vs filtered hits')
    ax.legend()
    fig.tight_layout()
    plt.show()
else:
    # channel_stats counts every hit, noisy or not, so an empty channel list
    # means no hits were read at all -- the noise thresholds cannot cause this.
    print('No hits found – check the input file or MAX_EVENTS.')

In [None]:
if non_noise_waveforms:
    # Stack retained waveforms on a common sample axis (truncated to the shortest).
    lengths = [len(wf) for wf in non_noise_waveforms]
    min_len = min(lengths)
    if min_len == 0:
        raise RuntimeError('Encountered zero-length waveform; cannot build heatmap.')
    max_len = max(lengths)
    if max_len != min_len:
        # Truncation discards samples; make that visible instead of silent.
        n_truncated = sum(length > min_len for length in lengths)
        print(f'Note: waveform lengths vary ({min_len}-{max_len} samples); '
              f'truncating {n_truncated} waveform(s) to {min_len} samples.')

    trimmed = np.stack([wf[:min_len] for wf in non_noise_waveforms])

    # Flatten row-major: each row contributes sample indices 0..min_len-1.
    sample_indices = np.tile(np.arange(min_len), trimmed.shape[0])
    values = trimmed.reshape(-1)

    # Restrict the amplitude axis to the 1st-99th percentile so rare outliers
    # do not compress the bulk of the distribution into a few bins.
    v_min, v_max = np.percentile(values, [1, 99])
    bins = [min_len, 120]
    hist, xedges, yedges = np.histogram2d(sample_indices, values, bins=bins,
                                          range=[[0, min_len], [v_min, v_max]])

    fig, ax = plt.subplots(figsize=(14, 6))
    im = ax.imshow(hist.T, origin='lower', aspect='auto',
                   extent=[xedges[0], xedges[-1], yedges[0], yedges[-1]], cmap='magma')
    ax.set_xlabel('Sample index')
    ax.set_ylabel('Amplitude [ADC] (baseline-subtracted)')
    ax.set_title('Heatmap of retained waveform shapes')
    fig.colorbar(im, ax=ax, label='Counts per bin')
    plt.show()

    # Per-sample mean and (population) standard deviation across waveforms.
    mean_wave = trimmed.mean(axis=0)
    std_wave = trimmed.std(axis=0)

    fig, ax = plt.subplots(figsize=(14, 4))
    x = np.arange(min_len)
    ax.plot(x, mean_wave, color='tab:blue', label='Mean waveform')
    ax.fill_between(x, mean_wave - std_wave, mean_wave + std_wave,
                    color='tab:blue', alpha=0.25, label='±1σ envelope')
    ax.set_xlabel('Sample index')
    ax.set_ylabel('Amplitude [ADC]')
    ax.set_title('Average retained waveform (baseline-subtracted)')
    ax.legend()
    plt.show()
else:
    print('No non-noise waveforms collected – relax thresholds or check data file.')

## Next Steps

- Fine-tune `AMP_THRESHOLD`, `PEAK_THRESHOLD`, `STD_THRESHOLD`, and `TOT_THRESHOLD` to match your definition of noise.
- Export the retained waveforms for downstream analysis (e.g., using `np.save`).
- Compare heatmaps generated with different run conditions to spot drifts or hardware issues.