# Lightning Detection with NCD
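This notebook demonstrates compression-based lightning detection using **Normalised Compression Distance** (NCD). We compare it against a simple amplitude-threshold baseline and, in the final section, against an STA/LTA (short-term / long-term average) detector evaluated on a separately simulated 60 s signal.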

In [44]:
import json, numpy as np, matplotlib.pyplot as plt
from pathlib import Path
from pandas import Series
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from tqdm import tqdm
import seaborn as sns
from leela_ml.datamodules_npy import StrikeDataset
from leela_ml.ncd import ncd_adjacent, ncd_first


ImportError: cannot import name 'ncd_first' from 'leela_ml.ncd' (/Users/johngoodacre/leela-ml/leela_ml/ncd.py)

## 1. Generate synthetic data

In [None]:
from leela_ml.signal_sim.simulator import simulate

# All simulator outputs share this prefix; derive the file names from it so the
# prefix only needs to change in one place.
out_prefix = Path('data/demo')
# NOTE(review): it is not visible here whether `simulate` creates missing
# directories, so make sure the output folder exists before calling it.
out_prefix.parent.mkdir(parents=True, exist_ok=True)
# NOTE(review): the meaning of the first positional argument (1) is not visible
# here — confirm against leela_ml.signal_sim.simulator.
simulate(1, str(out_prefix), seed=0)
npy = f'{out_prefix}_LON.npy'       # array file loaded by StrikeDataset below
meta = f'{out_prefix}_meta.json'    # metadata file loaded by StrikeDataset below


## 2. Load dataset

In [None]:
# Slice the simulated record into windows (chunk_size=512, overlap=0.9).
ds = StrikeDataset(npy, meta, chunk_size=512, overlap=0.9)
fs, hop = ds.fs, ds.hop   # presumably sample rate and hop (in samples) — used below as 0.01*fs/hop
# NOTE(review): `_windows` is a private StrikeDataset attribute; prefer a public accessor if one exists.
win = ds._windows.astype(np.float32, copy=False)   # window matrix as float32 (no copy if already float32)
lab = ds.labels.astype(bool)                       # per-window ground truth as booleans
print("windows", ds.n_win, "positives", int(lab.sum()))


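## 3. NCD computation
Two NCD scores are computed per window: `ncd_adjacent` and `ncd_first` (with `baseline_idx=0`). A window is flagged when its score exceeds a rolling median + 6 × MAD (median absolute deviation), computed over roughly 10 ms of window hops.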

In [None]:
# Per-window NCD score from leela_ml.ncd.ncd_adjacent (presumably NCD between adjacent windows).
err = ncd_adjacent(win, per_win_norm=True)

# Adaptive threshold: rolling median + K_MAD · rolling MAD over ~10 ms worth of hops.
K_MAD = 6
win_len = max(1, int(0.01 * fs / hop))
rolling = Series(err).rolling(win_len, center=True, min_periods=1)   # built once, used for both stats
thr = rolling.median() + K_MAD * rolling.apply(
    lambda v: np.median(np.abs(v - np.median(v))), raw=True)
mask = err > thr.values

# labels= pins the 2×2 layout so ravel() always yields 4 counts, even if one class is absent.
tn, fp, fn, tp = confusion_matrix(lab, mask, labels=[False, True]).ravel()
# zero_division=0 returns the same 0.0 sklearn would, without the UndefinedMetricWarning.
P, R, F, _ = precision_recall_fscore_support(lab, mask, average='binary', zero_division=0)
metrics_ncd = dict(P=float(P), R=float(R), F1=float(F), TP=int(tp), FP=int(fp), FN=int(fn), TN=int(tn))


In [None]:
# NOTE(review): in the recorded run `ncd_first` could not be imported from
# leela_ml.ncd (ImportError in the imports cell); this cell cannot run until it exists.
err_first = ncd_first(win, baseline_idx=0, per_win_norm=True)

# Same adaptive threshold as the adjacent-window score: rolling median + 6 · rolling MAD.
rolling_first = Series(err_first).rolling(win_len, center=True, min_periods=1)
thr_first = rolling_first.median() + 6 * rolling_first.apply(
    lambda v: np.median(np.abs(v - np.median(v))), raw=True)
mask_first = err_first > thr_first.values

# labels= pins the 2×2 layout so ravel() always yields 4 counts, even if one class is absent.
tn, fp, fn, tp = confusion_matrix(lab, mask_first, labels=[False, True]).ravel()
P1, R1, F1, _ = precision_recall_fscore_support(lab, mask_first, average='binary', zero_division=0)
metrics_first = dict(P=float(P1), R=float(R1), F1=float(F1), TP=int(tp), FP=int(fp), FN=int(fn), TN=int(tn))


## 4. Simple amplitude threshold baseline

In [None]:
# Baseline: per-window RMS amplitude, thresholded with the same rolling median + 6 · MAD rule.
amp = np.sqrt((win**2).mean(axis=1))
rolling_amp = Series(amp).rolling(win_len, center=True, min_periods=1)   # built once, used for both stats
thr_amp = rolling_amp.median() + 6 * rolling_amp.apply(
    lambda v: np.median(np.abs(v - np.median(v))), raw=True)
mask_amp = amp > thr_amp.values

# labels= pins the 2×2 layout so ravel() always yields 4 counts, even if one class is absent.
tn, fp, fn, tp = confusion_matrix(lab, mask_amp, labels=[False, True]).ravel()
Pa, Ra, Fa, _ = precision_recall_fscore_support(lab, mask_amp, average='binary', zero_division=0)
metrics_amp = dict(P=float(Pa), R=float(Ra), F1=float(Fa), TP=int(tp), FP=int(fp), FN=int(fn), TN=int(tn))


## 5. Compare

In [None]:
# Report every detector scored on the same windows, including the NCD-vs-first-window
# variant that was previously computed but never shown.
print('NCD metrics', metrics_ncd)
print('NCD-first metrics', metrics_first)
print('Amplitude metrics', metrics_amp)


### Plot NCD score and its adaptive threshold

In [None]:
# NCD score per window with its adaptive (rolling median + 6 · MAD) threshold.
# NOTE(review): a later cell rebinds `thr` to a scalar; re-running this cell after it will misplot.
fig, ax = plt.subplots(figsize=(15, 4))
ax.plot(err, label='NCD', lw=0.4)
ax.plot(thr, '--', label='threshold', lw=0.8)
ax.set(title='NCD curve', xlabel='window index', ylabel='NCD')
ax.legend();


In [51]:
import numpy as np, zlib, pandas as pd
from tqdm import tqdm
from sklearn.ensemble import IsolationForest   # kept for comparison
from scipy.signal import welch
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
from sklearn.preprocessing import RobustScaler

# ─── 1. simulate 60 s @100 kHz with 20 flashes ─────────────────────────────
fs, dur, n_flashes = 100_000, 60, 20
N = fs*dur
flash_len = int(0.003*fs)                # 3 ms
tau, f_carrier = 0.001, 4e3              # 1 ms exponential decay, 4 kHz carrier
np.random.seed(42)
signal = 0.2*np.random.randn(N).astype(np.float32)   # Gaussian background noise, σ = 0.2
labels = np.zeros(N, bool)
starts = np.sort(np.random.choice(N-flash_len, n_flashes, replace=False))
# Every flash has the same shape, so build the decaying burst once outside the loop.
t = np.arange(flash_len)/fs
flash = np.exp(-t/tau)*np.cos(2*np.pi*f_carrier*t)
for idx in starts:
    signal[idx:idx+flash_len] += flash
    labels[idx:idx+flash_len] = True

# ─── 2. windowing ───────────────────────────────────────────────────────────
win, hop = 1024, 256                     # 75 % overlap
n_win = (N-win)//hop + 1
# A window [i*hop, i*hop+win) is positive if any sample in it belongs to a flash.
# sliding_window_view(...)[::hop] yields exactly the n_win windows starting at 0, hop, 2·hop, …
# (a strided view — no copy, no Python-level loop).
win_lab = np.lib.stride_tricks.sliding_window_view(labels, win)[::hop].any(axis=1)
assert len(win_lab) == n_win

# ─── 3. STA / LTA ratio per window ──────────────────────────────────────────
# Short-term (2 ms) and long-term (50 ms) moving averages of |signal|; the ratio
# rises when the short-term level jumps above the long-term background.
abs_sig = np.abs(signal)
n_sta, n_lta = int(0.002*fs), int(0.05*fs)
# NOTE(review): np.convolve is a direct O(N·M) convolution — slow for M = n_lta.
# An FFT or cumsum moving average would be faster; check 'same' alignment for even M.
sta = np.convolve(abs_sig, np.ones(n_sta)/n_sta, mode='same')
lta = np.convolve(abs_sig, np.ones(n_lta)/n_lta, mode='same') + 1e-6   # avoid divide-by-zero
sta_lta = sta/lta
# For each window [i*hop, i*hop+win), keep the peak STA/LTA (same windows as `win_lab`).
ratio_win = np.lib.stride_tricks.sliding_window_view(sta_lta, win)[::hop].max(axis=1)

# ─── 4. global threshold (k·σ above mean) ──────────────────────────────────
# NOTE: mean/std are taken over all windows, flashes included, so large events
# inflate σ — this is not a robust statistic (median/MAD would be).
k = 6                                 # tuned once; still unsupervised
thr = ratio_win.mean() + k*ratio_win.std()
pred_win = ratio_win > thr               # boolean per window

# ─── 5. window‑level metrics ───────────────────────────────────────────────
P,R,F,_ = precision_recall_fscore_support(win_lab, pred_win, average='binary', zero_division=0)
# labels= pins the 2×2 layout so ravel() always yields 4 counts, even if one class is absent.
tn,fp,fn,tp = confusion_matrix(win_lab, pred_win, labels=[False, True]).ravel()
window_metrics = dict(P=float(P), R=float(R), F1=float(F),
                      TP=int(tp), FP=int(fp), FN=int(fn), TN=int(tn))
# → {'P': 0.92, 'R': 0.92, 'F1': 0.92,  TP=93, FP=8, FN=8, TN=23 325}

# ─── 6. event‑level scoring (merge consecutive detections) ──────────────────
def windows_to_events(flags):
    """Merge runs of consecutive truthy flags into events.

    Args:
        flags: iterable of per-window truth values.

    Returns:
        List of ``(first, last)`` tuples of inclusive window indices, one per
        maximal run of truthy flags, in order of appearance.
    """
    events = []
    start = None
    for pos, flag in enumerate(flags):
        if flag:
            if start is None:
                start = pos          # a new run begins here
            end = pos                # extend the current run
        elif start is not None:
            events.append((start, end))
            start = None
    # Close a run that reaches the final window.
    if start is not None:
        events.append((start, end))
    return events

det_evt = windows_to_events(pred_win)

# Ground-truth span of each flash in window indices: all windows i whose range
# [i*hop, i*hop+win) overlaps [idx, idx+flash_len). This is the same overlap rule
# that defines `win_lab`, so the event spans agree with the window labels.
true_evt = [(max(0, (idx - win)//hop + 1),
             min(n_win - 1, (idx + flash_len - 1)//hop))
            for idx in starts]
assert all(win_lab[gs:ge + 1].all() for gs, ge in true_evt)

def _spans_overlap(a, b):
    """True if inclusive index ranges a=(a0, a1) and b=(b0, b1) share an index."""
    return a[0] <= b[1] and b[0] <= a[1]

tp_e = sum(any(_spans_overlap(d, g) for d in det_evt) for g in true_evt)
fn_e = len(true_evt) - tp_e
fp_e = sum(not any(_spans_overlap(d, g) for g in true_evt) for d in det_evt)

# Guard every ratio so an empty detection / ground-truth set yields 0.0 instead of raising.
event_metrics = dict(P=tp_e/(tp_e+fp_e) if tp_e+fp_e else 0.0,
                     R=tp_e/(tp_e+fn_e) if tp_e+fn_e else 0.0,
                     F1=2*tp_e/max(1, 2*tp_e+fp_e+fn_e),
                     TP=tp_e, FP=fp_e, FN=fn_e)
# recorded with the earlier (narrower) ground-truth spans — re-run to refresh:
# → {'P': 0.91, 'R': 1.00, 'F1': 0.95, TP=20, FP=2, FN=0}


In [52]:
# Display the event-level scores (P, R, F1 and TP/FP/FN counts) as this cell's output.
event_metrics

{'P': 0.9090909090909091,
 'R': 1.0,
 'F1': 0.9523809523809523,
 'TP': 20,
 'FP': 2,
 'FN': 0}