In [4]:
from gwpy.timeseries import TimeSeries
from PIL import Image
import numpy as np
import onnx
import onnxruntime as ort
import torch
from torchvision.transforms import v2

In [5]:
# Load the example strain time series from
# s11.2--LS220_0.1kpc_sim100_SNR29.21.txt into the list `strain`.
import os
from glob import glob

STRAIN_FILE = 's11.2--LS220_0.1kpc_sim100_SNR29.21.txt'
# True: each line holds two whitespace-separated columns and only the second
# one is the strain. False: each line holds a single strain value.
double_column = True


def load_strain(path, double_column=True):
    """
    Read a strain time series from a whitespace-delimited text file.

    Parameters:
    - path: path of the text file to read
    - double_column: if True, take the strain from the second column of each
      line; otherwise parse the whole line as a single number

    Returns:
    - list of float: the parsed strain values, in file order

    Blank lines are ignored. Lines with too few columns or a value that does
    not parse as a float are reported and skipped rather than aborting the load.
    """
    values = []
    with open(path, "r") as file:
        for line in file:
            text = line.strip()
            if not text:
                continue  # e.g. a trailing newline at the end of the file
            try:
                if double_column:
                    number = float(text.split()[1])  # second column only
                else:
                    number = float(text)
            except (ValueError, IndexError):
                # ValueError: non-numeric token (e.g. a header line);
                # IndexError: a line with fewer than two columns.
                print(f"Skipping invalid value: {text}")
                continue
            values.append(number)
    return values


strain = load_strain(STRAIN_FILE, double_column)


In [6]:
def get_qgram(strain, sample_rate=2**14):
    """
    Compute the Q-transform of the input strain and return the computed spectrogram as an image.
    Alternatively you could also save the image locally as a .png

    Parameters:
    - strain: a raw Python list (or 1-D array) containing the strain timeseries
    - sample_rate: sampling rate of `strain` in Hz. Defaults to 2**14 (16384 Hz),
      the rate this notebook was written for; pass your own rate if it differs.

    Returns:
    - An 8-bit greyscale image of type PIL.Image.Image, with frequency
      increasing from the bottom row to the top row. A constant spectrogram
      (zero dynamic range) yields an all-black image.
    """
    # build the gwpy timeseries
    hdata = TimeSeries(strain, sample_rate=sample_rate)
    # band pass filtering the signal (100-2000 Hz)
    hdata_filtered = hdata.bandpass(100, 2000)
    # q-transform
    # NOTE(review): tres=(2-0)/128 gives 128 time bins only for a ~2 s segment;
    # confirm for inputs of a different duration.
    hq = hdata_filtered.q_transform(qrange=(100, 200), frange=(0, 1600),
                                    tres=(2 - 0) / 128, fres=(1600 - 0) / 128,
                                    norm='Median', whiten=False)
    # gwpy spectrograms are indexed (time, frequency): transpose so rows are
    # frequency, then flip up/down so the frequency axis goes up in the image
    hq_array = np.flipud(np.transpose(hq.value))

    # rescale linearly to [0, 255]; guard the zero-range case, which would
    # otherwise divide by zero and produce NaNs before the uint8 cast
    hq_min = np.min(hq_array)
    hq_range = np.max(hq_array) - hq_min
    if hq_range > 0:
        hq_rescale = 255 * (hq_array - hq_min) / hq_range
    else:
        hq_rescale = np.zeros_like(hq_array, dtype=float)

    # convert to image (a 2-D uint8 array becomes a greyscale image)
    hq_image = Image.fromarray(hq_rescale.astype(np.uint8))

    # could save the image locally here if wish

    return hq_image

In [7]:
def prediction_to_category(prediction_array, threshold=0.4):
    """
    Converts predicted probabilities to the categories 'noise' or 'signal' using a threshold.

    Every element is compared with `threshold` independently: values greater
    than or equal to `threshold` become 'signal', all others 'noise'. Because
    the comparison is element-wise, passing a full [noise, signal] probability
    pair labels each entry separately. Pass only the signal probabilities to
    get one label per prediction.

    Parameters:
    - prediction_array (array-like): probabilities to classify; a NumPy array,
      a plain Python list, or a scalar is accepted.
    - threshold (float): The threshold to classify as 'signal'. Default is 0.4.

    Returns:
    - np.ndarray of str with the same shape as `prediction_array`, holding
      'noise' or 'signal' for each element (0-d for a scalar input).
    """
    # np.asarray makes plain lists work too (`list >= float` raises TypeError)
    probabilities = np.asarray(prediction_array)

    is_signal = probabilities >= threshold

    return np.where(is_signal, 'signal', 'noise')


In [8]:
def preprocess_image(image: Image.Image):
    """
    Preprocesses a PIL image for ONNX inference using the test-time transformations.

    The image is converted to 3-channel RGB (a greyscale image has its single
    channel replicated), resized to 224x224, scaled to float32 in [0, 1] and
    normalised with the ImageNet per-channel mean/std.

    Parameters:
    - image (PIL.Image.Image): Input image to be processed.

    Returns:
    - np.ndarray: float32 array of shape (1, 3, 224, 224) (NCHW), ready for ONNX inference.
    """
    # The Normalize step below uses 3-value mean/std, so make the channel count
    # explicit. Without this, a greyscale ('L') image gives a 1-channel tensor
    # that only reaches 3 channels through broadcasting inside Normalize. RGBA
    # (4-channel) or palette ('P') images would fail or give wrong values.
    rgb_image = image.convert("RGB")

    # Define the transformation pipeline for testing
    transform = v2.Compose([
        v2.Resize((224, 224)),  # Resize image to 224x224
        v2.ToImage(),  # PIL image -> uint8 tensor (C, H, W)
        v2.ToDtype(torch.float32, scale=True),  # float32, scaled to [0, 1]
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
    ])

    # Apply transformation to the image
    image_tensor = transform(rgb_image)

    # Add the batch dimension and convert to numpy for ONNX (NCHW format)
    image_np = image_tensor.unsqueeze(0).numpy()  # (1, 3, 224, 224)

    return image_np

In [9]:
# Run the trained CNN on the Q-transform of the example strain.
MODEL_PATH = "CCSNe_QNet_24Jan.onnx"
# probability at or above which the event is labelled 'signal'
SIGNAL_THRESHOLD = 0.4

# load the onnx model (only needed for the optional structural check below)
onnx_model = onnx.load(MODEL_PATH)
#onnx.checker.check_model(onnx_model)

# create inference session; setting the provider explicitly keeps this working
# on GPU-enabled onnxruntime builds, which (since 1.9) require `providers`
session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])

# get q-transform output
qgram = get_qgram(strain)
# pre-process the image into a single-image NCHW batch
input_img = preprocess_image(qgram)

# CNN prediction
input_name = session.get_inputs()[0].name
result = session.run(None, {input_name: input_img})

# result[0] is the first model output; its first row belongs to our single
# image. Column 1 is treated as the 'signal' probability, column 0 as 'noise'.
probabilities = result[0][0]
# Reuse prediction_to_category (signal if >= threshold) rather than a separate
# inline comparison with a different operator.
category = str(prediction_to_category(probabilities[1], threshold=SIGNAL_THRESHOLD))

print("Prediction Probability for 'noise' and 'signal':", probabilities)
print("Model Prediction:", category)

Prediction Probability for 'signal and 'noise': [0.00414355 0.99585646]
Model Prediction: signal
