In [None]:
import os

import hydra
import rootutils
import torch
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig, open_dict

# NOTE(review): IS_PROD is not read anywhere in this notebook; presumably project
# code imported afterwards checks it — confirm what it toggles.
os.environ.update({"IS_PROD": "1"})

# Locate the project root (marked by a ".project-root" file, searching from the
# parent directory) and add it to sys.path so the project modules import cleanly.
PROJECT_SEARCH_START = "../"
rootutils.setup_root(PROJECT_SEARCH_START, indicator=".project-root", pythonpath=True)

In [None]:
from datamodule import DataModule
from model import Model

In [None]:
# Checkpoint to evaluate. Set the MDT_CKPT_PATH environment variable to point at a
# different run instead of editing this absolute, machine-specific default.
ckpt_path = os.environ.get(
    "MDT_CKPT_PATH",
    "/u2/l2hebert/MDT-2/src/logs/train/multiruns/2025-04-13_22-58-38/1/hpc_ckpt_1.ckpt",
)


def make_hydra_config(
    overrides: list[str],
    config_name: str = "train.yaml",
    config_path: str = "../configs/",
):
    """Compose a Hydra config outside of ``@hydra.main``.

    Args:
        overrides: Hydra command-line style overrides, e.g. ``["experiment=foo"]``.
        config_name: Primary config file to compose.
        config_path: Config directory, resolved by Hydra relative to the caller.

    Returns:
        The composed config, including the ``hydra`` node
        (``return_hydra_config=True``) so it can be installed into ``HydraConfig``.
    """
    # initialize() refuses to run if Hydra is already initialised (e.g. left over
    # from an interrupted call on a live kernel); clear it so the cell is re-runnable.
    global_hydra = GlobalHydra.instance()
    if global_hydra.is_initialized():
        global_hydra.clear()
    with initialize(config_path=config_path):
        return compose(
            config_name=config_name,
            overrides=overrides,
            return_hydra_config=True,
        )


# Overrides layered on top of the training config for this evaluation run: the
# experiment preset, encoder architecture, project root, and loader batch sizes.
hydra_overrides = [
    "experiment=giga_pretrain_siglip",
    "model.encoder.graph_stack_factory.num_layers=3",
    "model.encoder.graph_stack_factory.graph_tfmr_factory.use_rope=True",
    "model.encoder.graph_stack_factory.graph_tfmr_factory.differential_attention=True",
    "model.encoder.graph_stack_factory.graph_tfmr_factory.n_heads=16",
    "model.encoder.graph_stack_factory.graph_tfmr_factory.n_kv_heads=16",
    "model.encoder.graph_stack_factory.graph_tfmr_factory.head_dim=128",
    "paths.root_dir='../'",
    "dataset.test_batch_size=32",
    "dataset.train_batch_size=1",
    "dataset.group_size=1",
]
cfg = make_hydra_config(hydra_overrides)

# Register the composed config with the HydraConfig singleton, as @hydra.main
# would, for any code that reads it at runtime.
HydraConfig().set_config(cfg)

In [None]:
# Build the encoder and loss from the config and hand them to load_from_checkpoint
# explicitly; the checkpoint then supplies the trained weights.
# (The old `encoder: Model` annotation was wrong — the encoder is not a Model.)
encoder = hydra.utils.instantiate(cfg.model.encoder)
loss = hydra.utils.instantiate(cfg.model.loss)

model = Model.load_from_checkpoint(ckpt_path, encoder=encoder, loss=loss)
# Inference only: switch to eval mode so dropout and similar layers are deterministic.
model.eval()

In [None]:
# Build the data module described by cfg.dataset (batch/group sizes come from the
# overrides passed to make_hydra_config).
dataset = hydra.utils.instantiate(cfg.dataset)

In [None]:
# Run the data module's one-off preparation hook (whatever the project DataModule implements).
dataset.prepare_data()

In [None]:
# NOTE(review): "inference" is not one of Lightning's standard stage names
# ("fit", "validate", "test", "predict"); presumably the project DataModule handles
# it — confirm it builds the validation split consumed below.
dataset.setup("inference")

In [None]:
# Debug peek at a private attribute of the data module; remove from the final notebook.
dataset._val_sampler

In [None]:
# Despite the variable name, this is the *validation* dataloader.
test_loader = dataset.val_dataloader()

In [None]:
from tqdm import tqdm

In [None]:
def send_to_device(batch, device):
    """Recursively move every tensor in ``batch`` to ``device``.

    Dicts are updated in place and returned; lists and tuples (including
    namedtuples) are rebuilt with their elements moved; a bare tensor is moved and
    returned; any other value is returned unchanged.

    Args:
        batch: A tensor, or a possibly nested dict / list / tuple containing tensors.
        device: Target device, e.g. ``"cuda:0"`` or ``torch.device("cpu")``.

    Returns:
        ``batch`` with all contained tensors on ``device``.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device, non_blocking=True)
    if isinstance(batch, dict):
        # Reassigning existing keys while iterating is safe: the dict's size is unchanged.
        for key, value in batch.items():
            batch[key] = send_to_device(value, device)
        return batch
    if isinstance(batch, (list, tuple)):
        moved = [send_to_device(item, device) for item in batch]
        if isinstance(batch, tuple) and hasattr(batch, "_fields"):
            return type(batch)(*moved)  # namedtuple: fields are positional
        return type(batch)(moved)
    return batch


# Inference device; falls back to CPU so the notebook still runs without a GPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The inputs are moved to DEVICE, so the model must live there too; eval() makes
# dropout-style layers deterministic for embedding extraction.
model = model.to(DEVICE).eval()

embeddings = []
labels = []
with torch.inference_mode():
    for batch in tqdm(test_loader, desc="embedding"):
        x, y = batch["x"], batch["y"]
        x = send_to_device(x, DEVICE)
        # The model returns a pair; only the second element (graph embeddings) is kept.
        # Calling model(...) rather than .forward(...) keeps module hooks active.
        _, graph_embeddings = model(x)
        # NumPy has no bfloat16, so upcast before converting (no-op for float32).
        embeddings.append(graph_embeddings.float().cpu().numpy())
        # Labels are only needed on the host, so skip the round-trip through the GPU.
        labels.append(y["ys"].cpu().numpy())

In [None]:
import numpy as np

# Stack per-batch results into flat arrays in dataloader order.
# np.concatenate rather than np.concat: the latter only exists in NumPy >= 2.0.
output_embeddings = np.concatenate(embeddings, axis=0)
output_labels = np.concatenate(labels, axis=0)
assert len(output_embeddings) == len(output_labels), "embedding/label count mismatch"

# Subset without samples whose label is 1.
# NOTE(review): why label 1 is excluded is not documented here — TODO confirm.
label_mask = output_labels != 1
output_embeddings_mini = output_embeddings[label_mask]
output_labels_mini = output_labels[label_mask]

In [None]:
# Sanity-check the shape of the stacked embedding array (this check is repeated in later cells).
output_embeddings.shape

In [None]:
# Spot-check raw (un-normalised) embedding values; NumPy truncates large reprs.
output_embeddings

In [None]:
# Spot-check the stacked ground-truth labels.
output_labels

In [None]:
# Duplicate of an earlier shape check; candidate for removal.
output_embeddings.shape

In [None]:
import numba

In [None]:
@numba.njit()
def sigmoid(x: np.ndarray, y: np.ndarray):
    """Logistic sigmoid of the cosine similarity between ``x`` and ``y``.

    Both inputs are L2-normalised before the dot product, so the result lies in
    (sigmoid(-1), sigmoid(1)) ~= (0.269, 0.731). Larger means more similar, so this
    is a similarity, not a distance.
    """
    x_unit = x / np.linalg.norm(x, ord=2)
    y_unit = y / np.linalg.norm(y, ord=2)
    cosine = x_unit.dot(y_unit.T)
    return 1 / (1 + np.exp(-cosine))

In [None]:
import umap


@numba.njit()
def sigmoid_distance(x: np.ndarray, y: np.ndarray):
    """Dissimilarity for UMAP derived from ``sigmoid`` (smaller = closer).

    UMAP interprets a custom metric as a distance, while ``sigmoid`` grows with
    similarity; passing it directly would make the most dissimilar points
    "nearest". ``1 - sigmoid`` is monotonically decreasing in cosine similarity.
    """
    return 1.0 - sigmoid(x, y)


# Fixed seed so the projection (and the plot built from it) is reproducible;
# UMAP disables its parallelism when random_state is set.
mapper = umap.UMAP(metric=sigmoid_distance, n_neighbors=100, random_state=42).fit(
    output_embeddings
)

In [None]:
# L2-normalise each row so the dot products computed next are cosine similarities.
# Zero-norm rows are divided by 1 (staying all-zero) instead of producing NaN from 0/0.
# NOTE: this rebinds output_embeddings; anything fitted earlier used the raw values.
row_norms = np.linalg.norm(output_embeddings, ord=2, axis=1, keepdims=True)
output_embeddings = output_embeddings / np.where(row_norms == 0, 1, row_norms)

In [None]:
# Spot-check the embeddings as they currently stand (normalised if the preceding cell ran).
output_embeddings

In [None]:
# Pairwise similarity matrix (cosine similarity once rows are unit-normalised).
sim = output_embeddings @ output_embeddings.T

In [None]:
# Affine logit transform applied before the sigmoid in a later cell.
# NOTE(review): presumably the learned SigLIP temperature/bias copied by hand —
# confirm, and ideally read them from the loaded model instead of hardcoding.
LOGIT_SCALE = 5.12
LOGIT_BIAS = -2.02
sim_scaled = LOGIT_SCALE * sim + LOGIT_BIAS

In [None]:
# Exclude self-matches from nearest-neighbour lookups. Cosine similarity can be
# negative, so a 0 on the diagonal could still win an argmax; -inf never does.
np.fill_diagonal(sim, -np.inf)

In [None]:
# Map the scaled logits into (0, 1) with the logistic function.
# NOTE: not idempotent — re-running this cell alone applies the sigmoid twice;
# recompute sim_scaled from sim first.
sim_scaled = 1 / (1 + np.exp(-sim_scaled))
# Exclude self-similarity. 0 is safe for sim_scaled because every sigmoid value is
# > 0, but raw cosine sims can be negative, so sim's diagonal gets -inf instead.
np.fill_diagonal(sim_scaled, 0)
np.fill_diagonal(sim, -np.inf)

In [None]:
# Another repeat of the embedding shape check; candidate for removal.
output_embeddings.shape

In [None]:
# Indices of each sample's 10 most similar *other* samples, most similar first.
# argsort is ascending, so rank the negated similarities (sorting sim directly
# would pick the 10 least similar). The diagonal is pushed to the end so a sample
# is never its own neighbour, whatever value the diagonal currently holds.
neg_sim = -sim
np.fill_diagonal(neg_sim, np.inf)
top_10 = np.argsort(neg_sim, axis=-1, kind="stable")[:, :10]

In [None]:
# All labels as one flat array in dataloader order (first axis is the default).
labels_flat = np.concatenate(labels)

In [None]:
# Spot-check the flattened labels.
labels_flat

In [None]:
# Label of sample 3 (arbitrary spot-check index).
labels_flat[3]

In [None]:
# Labels of the neighbours that top_10 lists for sample 3.
labels_flat[top_10][3]

In [None]:
import pandas as pd

# One row per sample: its label plus the labels of the neighbours listed in top_10
# (each "top_10" cell holds one array). Built column-wise so "labels" keeps its
# numeric dtype; the orient="index" + .T construction produced object-dtype columns.
top_10_df = pd.DataFrame(
    {
        "labels": labels_flat,
        "top_10": list(labels_flat[top_10]),
    }
)

In [None]:
# Recall@10 per sample: does its label appear among its own neighbours' labels?
# Series.isin would test each label against the whole column of neighbour arrays
# rather than the row's own neighbours, so compare row-wise instead.
neighbour_labels = np.stack(top_10_df["top_10"].tolist())
query_labels = np.asarray(top_10_df["labels"].tolist())[:, None]
res = pd.Series(
    (neighbour_labels == query_labels).any(axis=1),
    index=top_10_df.index,
    name="label_in_top_10",
)

In [None]:
# Shape of the neighbour-label matrix; presumably (num_samples, 10) if labels are 1-D — confirm.
labels_flat[top_10].shape

In [None]:
# Number of labels; should match the number of embedding rows.
labels_flat.shape

In [None]:
# Duplicate of an earlier spot-check of sample 3's label; candidate for removal.
labels_flat[3]

In [None]:
# Duplicate of an earlier spot-check of sample 3's neighbour labels; candidate for removal.
labels_flat[top_10][3]

In [None]:
# Top-1 nearest-neighbour retrieval. Every column uses the SAME neighbour index;
# previously "preds" came from sim.argmax while the similarity columns used
# sim_scaled.argmax, which can disagree. sim_scaled's diagonal is filled with 0
# and all its off-diagonal sigmoid values are > 0, so no sample is matched to itself.
nearest = sim_scaled.argmax(axis=1)
rows = np.arange(len(sim))
vals = {
    "labels": labels_flat,
    "preds": labels_flat[nearest],
    "sim_scaled": sim_scaled[rows, nearest],
    "sim_native": sim[rows, nearest],
}
df_vals = pd.DataFrame(vals)

In [None]:
# Inspect per-sample top-1 retrieval results.
df_vals

In [None]:
# Index of each sample's most similar other sample (by sigmoid-scaled similarity).
sim_scaled.argmax(axis=1)

In [None]:
# The similarity matrix is square: one row and one column per sample.
sim.shape

In [None]:
# Raw similarity to the neighbour chosen by sim_scaled's argmax, one value per row.
sim[np.arange(0, len(sim)), sim_scaled.argmax(axis=1)]

In [None]:
# Duplicate display of df_vals; candidate for removal.
df_vals

In [None]:
# How many top-1 predictions match the true label (True) vs not (False).
df_vals["labels"].eq(df_vals["preds"]).value_counts()

In [None]:
# Per-label count of correctly retrieved samples.
df_vals.loc[df_vals["labels"].eq(df_vals["preds"]), "labels"].value_counts()

In [None]:
# Keep the match/mismatch tally for the accuracy computation.
counts = df_vals["labels"].eq(df_vals["preds"]).value_counts()

In [None]:
# Top-1 accuracy in percent. .get() avoids a KeyError when every prediction is
# correct (no False bucket) or every one is wrong (no True bucket).
counts.get(True, 0) / counts.sum() * 100

In [None]:
# Flag each row whose top-1 neighbour shares its label (adds a column to df_vals in place).
df_vals["match"] = df_vals["labels"].eq(df_vals["preds"])

In [None]:
# Results table including the "match" column.
df_vals

In [None]:
# Class distribution over all evaluated samples: (unique labels, counts).
all_labels = np.concatenate(labels)
np.unique(all_labels, return_counts=True)

In [None]:
import umap.plot

# Scatter of the fitted UMAP projection, coloured by ground-truth label.
umap.plot.points(mapper, labels=output_labels)