In [9]:
import numpy as np
import onnxruntime as rt
import polars as pl

In [48]:
columns = [f"ps1_mag_{b}" for b in "grizy"]
n_rows = 1_000
df = pl.read_parquet("data/ps1_stars.parquet", columns=columns, n_rows=n_rows)

In [None]:
session = rt.InferenceSession("models/phot-transformation/DES_r-PS1_g--r--i--z--y.onnx", providers=rt.get_available_providers())

def model(x):
    return session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: x})[0].squeeze()

In [53]:
def cov(i, j):
    n_samples = 1_000_000
    delta_scale = 1e-3

    rng = np.random.default_rng(0)
    deltas_i = rng.normal(loc=0, scale=delta_scale, size=(n_samples, len(columns)))
    # should j be the same as i?
    # deltas_j = rng.normal(loc=0, scale=delta_scale, size=(n_samples, len(columns)))
    deltas_j = deltas_i

    X_i = df[i].to_numpy() + deltas_i
    X_j = df[j].to_numpy() + deltas_j

    y_i = model(X_i.astype(np.float32))
    y_j = model(X_j.astype(np.float32))

    return np.mean((y_i - np.mean(y_i)) * (y_j - np.mean(y_j))) / (np.std(y_i) * np.std(y_j))

m = np.zeros((10, 10))
for i in range(m.shape[0]):
    for j in range(m.shape[1]):
        m[i, j] = cov(i, j)

m

array([[0.99999988, 0.99140847, 0.99009228, 0.99774969, 0.9921217 ,
        0.99108124, 0.9992404 , 0.99072438, 0.987037  , 0.99329388],
       [0.99140847, 1.        , 0.99981236, 0.98922229, 0.99826777,
        0.99794787, 0.99416304, 0.99741453, 0.99585289, 0.99953651],
       [0.99009228, 0.99981236, 1.        , 0.98775053, 0.99867648,
        0.99824804, 0.99316853, 0.9965958 , 0.99662626, 0.99920946],
       [0.99774969, 0.98922229, 0.98775053, 1.00000012, 0.98977983,
        0.98626614, 0.9957372 , 0.99227834, 0.98336244, 0.99239725],
       [0.9921217 , 0.99826777, 0.99867648, 0.98977983, 0.99999994,
        0.99911493, 0.99492407, 0.99430639, 0.99877745, 0.99776417],
       [0.99108124, 0.99794787, 0.99824804, 0.98626614, 0.99911493,
        1.00000012, 0.99458432, 0.99251068, 0.99864388, 0.99649894],
       [0.9992404 , 0.99416304, 0.99316853, 0.9957372 , 0.99492407,
        0.99458432, 1.00000012, 0.99130762, 0.99149454, 0.99510759],
       [0.99072438, 0.99741453, 0.9965958