In [1]:
import importlib
import os
import time
from collections import defaultdict
from pathlib import Path
from typing import Optional
import sys
sys.path.append('./ml4gw')
sys.path.append('./gw-anomaly')


import h5py
import numpy as np
import torch
from ml4gw.transforms import ShiftedPearsonCorrelation, SpectralDensity, Whiten
from ml4gw.utils.slicing import unfold_windows


model_lib = importlib.import_module("scripts.models")
config = importlib.import_module("config")

# NOTE(review): user-specific absolute default; set GWAK_DATA_DIR to override.
# DATA_DIR is not referenced by any cell visible here (the inference cell
# opens "background.hdf5" relative to the working directory).
DATA_DIR = Path(
    os.environ.get(
        "GWAK_DATA_DIR",
        "/n/home00/emoreno/gw-anomaly/output/O3av2/1243382418_1248652818/",
    )
)
MODEL_DIR = Path("output/gwak-paper-final-models/trained/models/")
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# Sampling rate of the data, taken from the project config.
SAMPLE_RATE = config.SAMPLE_RATE
# Highpass frequency handed to the whitening stage (low edge of the config bandpass).
HIGHPASS = config.BANDPASS_LOW
# Length of each model input window, in seconds (assuming SAMPLE_RATE is in Hz).
WINDOW_LENGTH = config.SEG_NUM_TIMESTEPS / config.SAMPLE_RATE
# Spacing between consecutive windows, in seconds: window length minus overlap.
STRIDE = (config.SEG_NUM_TIMESTEPS - config.SEGMENT_OVERLAP) / config.SAMPLE_RATE
# Seconds of preceding data used for each PSD estimate.
PSD_LENGTH = 64
# Whitening filter duration, in seconds.
FDURATION = 2
# Number of strides of data per batch, i.e. windows produced per batch.
BATCH_SIZE = 2048


def data_iterator(dataset: h5py.Dataset, shifts: list[float], batch_size: int = BATCH_SIZE):
    """Yield consecutive, non-overlapping chunks of time-shifted multichannel data.

    Channel ``j`` is read with an offset of ``shifts[j]`` seconds, so giving the
    channels different shifts time-slides them relative to each other. Each
    yielded chunk holds ``batch_size`` strides of data, i.e.
    ``batch_size * int(STRIDE * SAMPLE_RATE)`` samples per channel. Chunk ``i``
    begins exactly where chunk ``i - 1`` ended, so a stateful consumer sees one
    contiguous stream. Trailing samples that do not fill a whole chunk are
    dropped.

    Args:
        dataset: 2D array-like indexed as ``dataset[channel, start:stop]``,
            shape ``(num_channels, num_samples)``.
        shifts: per-channel offsets in seconds; needs at least ``num_channels``
            entries.
        batch_size: number of strides of data per yielded chunk.

    Yields:
        ``torch.Tensor`` of shape ``(num_channels, batch_size * stride_size)``.
    """
    shift_sizes = [int(shift * SAMPLE_RATE) for shift in shifts]
    num_channels, size = dataset.shape

    # Reserve room so the most-shifted channel never reads past the end.
    size -= max(shift_sizes)
    stride_size = int(STRIDE * SAMPLE_RATE)
    update_size = stride_size * batch_size
    # floor(floor(size / stride) / batch) == floor(size / (stride * batch))
    num_batches = size // update_size

    # Reusable read buffer. NOTE: relies on torch.Tensor(x) building a new
    # tensor (float cast from float64) rather than aliasing this buffer.
    x = np.zeros((num_channels, update_size))
    for i in range(num_batches):
        # Chunk i covers [i * update_size, (i + 1) * update_size) before shifting,
        # so consecutive chunks are contiguous (previously `i + update_size`,
        # which advanced only one sample per chunk).
        offset = i * update_size
        for j in range(num_channels):
            start = offset + shift_sizes[j]
            x[j] = dataset[j, start:start + update_size]
        yield torch.Tensor(x)

ModuleNotFoundError: No module named 'gw-anomaly'

In [2]:
class BatchGenerator(torch.nn.Module):
    """Stateful preprocessor: whitens streamed data and cuts it into standardized windows.

    Each call to ``forward`` appends a new chunk of data to a rolling state
    buffer. The leading ``psd_length`` seconds of the combined buffer are used
    only to estimate the PSD (``SpectralDensity`` with median averaging). The
    remainder is whitened with that PSD, then unfolded into windows of
    ``kernel_length`` seconds spaced ``1 / inference_sampling_rate`` seconds
    apart. Each window is finally divided by its own standard deviation.

    Args:
        num_channels: number of data channels in the stream.
        sample_rate: sampling rate of the input data (Hz); used for both the
            PSD estimate and the whitening.
        kernel_length: length of each output window, in seconds.
        fftlength: FFT length, in seconds, for the PSD estimate.
        fduration: whitening filter duration in seconds, passed to ``Whiten``;
            the state buffer keeps this much extra context.
        psd_length: seconds of preceding data used to estimate the PSD.
        inference_sampling_rate: rate (Hz) at which windows are produced.
        highpass: optional highpass frequency (Hz), passed to ``Whiten``.
    """

    def __init__(
        self,
        num_channels: int,
        sample_rate: float,
        kernel_length: float,
        fftlength: float,
        fduration: float,
        psd_length: float,
        inference_sampling_rate: float,
        highpass: Optional[float] = None,
    ) -> None:
        super().__init__()
        # Use the constructor argument (not a global constant) so the PSD and
        # whitening stages always agree on the sample rate.
        self.spectral_density = SpectralDensity(
            sample_rate=sample_rate,
            fftlength=fftlength,
            overlap=None,  # defaults to fftlength / 2
            average="median",
            fast=True,  # not accurate for lowest 2 frequency bins, but we don't care about those
        )
        self.whitener = Whiten(
            fduration=fduration,
            sample_rate=sample_rate,
            highpass=highpass,
        )

        # All sizes below are in samples.
        self.step_size = int(sample_rate / inference_sampling_rate)
        self.kernel_size = int(kernel_length * sample_rate)
        self.fsize = int(fduration * sample_rate)
        self.psd_size = int(psd_length * sample_rate)
        self.num_channels = num_channels

    @property
    def state_size(self) -> int:
        """Samples per channel carried over between ``forward`` calls.

        PSD context + one window + filter padding, minus one step. Presumably
        chosen so that ``k * step_size`` new samples yield exactly ``k`` windows
        (assuming ``Whiten`` trims ``fduration`` of data in total) — TODO confirm.
        """
        return self.psd_size + self.kernel_size + self.fsize - self.step_size

    def get_initial_state(self) -> torch.Tensor:
        """Return an all-zero state of shape ``(num_channels, state_size)`` on the CPU.

        NOTE(review): with an all-zero state, the first calls estimate the PSD
        partly (or entirely) from zeros, so early windows may be NaN/inf until
        ``psd_length`` seconds of real data have streamed in. Consider priming
        the state with real data.
        """
        return torch.zeros((self.num_channels, self.state_size))

    def forward(self, X, state):
        """Whiten a new chunk of data and return its windows plus the updated state.

        Args:
            X: new data, presumably ``(num_channels, num_samples)``; its leading
                dimension must match ``state`` for the concatenation.
            state: previous state, ``(num_channels, state_size)``, from
                ``get_initial_state`` or a prior call.

        Returns:
            ``(windows, new_state)``. ``windows`` has shape
            ``(num_windows, num_channels, kernel_size)``. ``new_state`` is the
            last ``state_size`` samples of ``state`` followed by ``X``.
        """
        state = torch.cat([state, X], dim=-1)
        split = [self.psd_size, state.size(-1) - self.psd_size]
        whiten_background, X = torch.split(state, split, dim=-1)

        # Estimate the PSD only from the leading psd_size samples, not from the
        # samples being whitened, to avoid biasing the estimate.
        psd = self.spectral_density(whiten_background.double())
        X = self.whitener(X, psd)
        X = unfold_windows(X, self.kernel_size, self.step_size)
        X = X.reshape(-1, self.num_channels, self.kernel_size)

        # Standardize each window along time; a constant window yields NaN/inf.
        X = X / X.std(dim=-1, keepdim=True)
        return X, state[:, -self.state_size:]

In [3]:
# Shared preprocessing for every model: two input channels, PSD estimated from
# the preceding PSD_LENGTH seconds, one window emitted every STRIDE seconds.
# The module is constructed and moved to DEVICE in one expression.
batcher = BatchGenerator(
    num_channels=2,
    sample_rate=SAMPLE_RATE,
    kernel_length=WINDOW_LENGTH,
    fftlength=2,  # seconds per FFT segment for the PSD estimate
    fduration=FDURATION,
    psd_length=PSD_LENGTH,
    inference_sampling_rate=1 / STRIDE,  # windows per second
    highpass=HIGHPASS,
).to(DEVICE)

In [5]:
# Load every trained model whose weights file stem has an entry in config.MODEL,
# keyed by that stem. Weights without a config entry are skipped with a warning.
models = {}
# sorted() gives a deterministic load (and later evaluation/print) order;
# Path.glob order depends on the filesystem.
for fname in sorted(MODEL_DIR.glob("*.pt")):
    model_name = fname.stem
    try:
        model_type = config.MODEL[model_name]
    except KeyError:
        print(f"WARNING: no corresponding model type for weights {fname}")
        continue

    if model_type == "lstm":
        architecture = model_lib.LSTM_AE_SPLIT
    elif model_type == "dense":
        architecture = model_lib.FAT
    else:
        raise ValueError(model_type)

    # Both architectures share the same constructor signature.
    model = architecture(
        num_ifos=config.NUM_IFOS,
        num_timesteps=config.SEG_NUM_TIMESTEPS,
        BOTTLENECK=config.BOTTLENECK[model_name],
    )

    model = model.to(DEVICE)
    # NOTE(review): torch.load unpickles the file. Only load trusted weights;
    # torch >= 1.13 accepts weights_only=True to restrict this.
    model.load_state_dict(torch.load(fname, map_location=DEVICE))
    # Inference only: switch any dropout/batch-norm layers to eval behaviour.
    model.eval()
    models[model_name] = model



In [18]:
# Build a Pearson correlation that we evaluate alongside the autoencoders.
# The argument is presumably the maximum lag in samples (0.01 s), likely chosen
# to cover the inter-detector light-travel time — TODO confirm.
pearson = ShiftedPearsonCorrelation(int(0.01 * SAMPLE_RATE))
# Per-channel offsets (seconds) for data_iterator: the second channel is read
# 1 s later than the first, time-sliding the two channels.
SHIFT = [0, 1]

# Table header for the per-segment timing report below.
columns = ["Segment", "Duration (s)", "Throughput (s' / s)"]
header_tabs = [3, 1, 0]
print(*[col + "\t" * n for col, n in zip(columns, header_tabs)])
print(*["-" * len(col) + "\t" * n for col, n in zip(columns, header_tabs)])

with torch.no_grad(), h5py.File("background.hdf5", "r") as f:
    # Iterate through all the datasets (segments) in the archive.
    for segment, dataset in f.items():
        start_time = time.time()

        # Per-segment predictions keyed by model name (plus "pearson"), with a
        # fresh zero state. Preallocating would be more efficient since the
        # number of steps is known, but lists are fine for now.
        # NOTE(review): `predictions` is rebuilt for each segment, so only the
        # last segment's results survive this loop. Persist them (e.g. to
        # disk) if they are needed downstream.
        predictions = defaultdict(list)
        state = batcher.get_initial_state().to(DEVICE)
        num_preds = 0
        for x in data_iterator(dataset, SHIFT):
            # Move the chunk onto the device, update the state and preprocess.
            x = x.to(DEVICE)
            X, state = batcher(x, state)

            # Per-window reconstruction MSE for every model. Results are keyed
            # by model name (previously by the module object itself).
            for name, model in models.items():
                y = model(X)
                loss = ((y - X) ** 2).mean(dim=[1, 2])
                predictions[name].append(loss.cpu().numpy())

            # Maximum shifted Pearson correlation between the two channels.
            corr = pearson(X[:, :1], X[:, 1:])
            corr = corr.max(dim=0).values[:, 0].cpu().numpy()
            predictions["pearson"].append(corr)

            num_preds += len(X)

        # Concatenate everything on the CPU, then stop the stopwatch. The
        # per-batch .cpu() calls already synchronize with the GPU, so the
        # wall-clock time covers the device work.
        predictions = {k: np.concatenate(v) for k, v in predictions.items()}
        duration = time.time() - start_time
        # Seconds of data processed per second of wall-clock time.
        throughput = num_preds * STRIDE / duration

        row = [segment, f"{duration:0.2f}", f"{throughput:0.2f}"]
        row_tabs = [1, 2, 0]
        print(*[cell + "\t" * n for cell, n in zip(row, row_tabs)])

Segment	Duration	Throughput
-------	--------	----------
1243668368_1243671968	3.25	1083.09
1243891850_1243895450	3.19	1105.68
1244038423_1244042023	3.20	1102.42
1244094823_1244098423	3.23	1092.99
1244153226_1244156826	3.23	1092.62
1244160426_1244164026	3.23	1092.97
1244281375_1244284975	3.25	1085.78
1244531604_1244535204	3.31	1065.92
1244619591_1244623191	3.23	1091.32
1244726076_1244729676	3.38	1043.40
1244850228_1244853828	3.32	1061.06
1244895499_1244899099	3.24	1086.40
1244994812_1244998412	3.30	1068.22
1245080336_1245083936	3.20	1100.15
1245130314_1245133914	3.22	1095.05
1245137514_1245141114	3.62	975.00
1245480475_1245484075	3.27	1078.27
1245889008_1245892608	3.31	1066.00
1246014028_1246017628	3.21	1098.66
1246160199_1246163799	3.22	1095.12
1246338884_1246342484	3.17	1111.92
1246510744_1246514344	3.19	1105.39
1246576222_1246579822	3.18	1107.28
1246637934_1246641534	3.22	1094.81
1247366989_1247370589	3.25	1084.13
1247407441_1247411041	3.23	1092.94
1247442440_1247446040	3.20	1102.12
