# Verification mixed dataset

### This notebook thest the mixed-dataset dataloader to see if it gives corresponding patches

In [15]:
# imports

import torch
import zarr
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import sys
project_root = Path().resolve().parent
sys.path.append(str(project_root))

import zarr  # Optional, only used to debug if needed
from supertrab.sr_dataset_utils import create_triplet_dataloader

In [4]:
def test_patch_alignment_grid(zarr_path, group_name, num_samples=50, seed=42):
    """
    Saves a single PNG showing a grid of QCT, HR-pQCT, and LR triplets side-by-side.
    """
    zarr_path = Path(zarr_path)
    z = zarr.open(str(zarr_path), mode="r")[group_name]

    qct = z["qct"]
    hrpqct = z["hrpqct"]
    lr = z["lr"]

    num_patches = len(qct)
    assert len(hrpqct) == num_patches and len(lr) == num_patches, "Patch count mismatch!"

    np.random.seed(seed)
    indices = np.random.choice(range(num_patches), size=min(num_samples, num_patches), replace=False)

    save_dir = Path("patch_outputs")
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / f"{group_name}_grid.png"

    fig, axes = plt.subplots(nrows=num_samples, ncols=3, figsize=(10, 3 * num_samples))

    for row_idx, idx in enumerate(indices):
        patch_qct = torch.tensor(qct[idx]).squeeze(0)
        patch_hr  = torch.tensor(hrpqct[idx]).squeeze(0)
        patch_lr  = torch.tensor(lr[idx]).squeeze(0)

        for col_idx, (img, title) in enumerate(zip(
            [patch_qct, patch_hr, patch_lr], 
            ["QCT", "HR-pQCT", "LR"]
        )):
            ax = axes[row_idx, col_idx] if num_samples > 1 else axes[col_idx]
            ax.imshow(img, cmap="gray")
            if row_idx == 0:
                ax.set_title(title, fontsize=12)
            ax.axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"Saved grid image to: {save_path}")


In [5]:
zarr_path = "/usr/terminus/data-xrm-01/stamplab/external/tacosound/HR-pQCT_II/zarr_data/paired_patch_dataset.zarr"
group_name = "1996_R"
test_patch_alignment_grid(zarr_path, group_name, num_samples=50)

Saved grid image to: patch_outputs/1996_R_grid.png


# Test correspondence when loading from dataloder - several groups

In [28]:
def sample_random_triplets_from_dataloader(zarr_path, num_samples=10, batch_size=1, seed=42, save_path="patch_outputs/mixed_groups_grid.png"):
    """
    Sample random triplets (QCT, HR, LR) from a multi-group dataset and save a grid PNG.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load dataset & dataloader
    conditioning_mode = "mix"  #"qct"  or "mix"
    groups = ["1955_L", "1956_L", "1996_R", "2005_L"]
    dataloader = create_triplet_dataloader(zarr_path, groups, conditioning_mode=conditioning_mode, patch_size=(1, 256, 256), batch_size=4)


    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    collected = []
    for batch in dataloader:
        for i in range(batch_size):
            collected.append({
                "qct": batch["qct"][i].squeeze(0),
                "hr": batch["hr_image"][i].squeeze(0),
                "lr": batch["lr"][i].squeeze(0),
                "group": batch["group"][i],
                "index": batch["index"][i].item(),
            })
            if len(collected) >= num_samples:
                break
        if len(collected) >= num_samples:
            break

    # Plot grid
    fig, axes = plt.subplots(nrows=num_samples, ncols=3, figsize=(10, 3 * num_samples))
    group_str = []
    index_str = []

    for row_idx, sample in enumerate(collected):
        for col_idx, (img, title) in enumerate(zip(
            [sample["qct"], sample["hr"], sample["lr"]],
            ["QCT", "HR-pQCT", "LR"]
        )):
            ax = axes[row_idx, col_idx] if num_samples > 1 else axes[col_idx]
            ax.imshow(img, cmap="gray")
            if row_idx == 0:
                ax.set_title(title, fontsize=12)
            ax.axis("off")

        group_str.append(sample["group"])
        index_str.append(sample["index"])
        # axes[row_idx, 0].set_ylabel(f"{group_str}\n#{index_str}", fontsize=10)

    plt.tight_layout()
    # plt.show()
    plt.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"✅ Saved multi-group patch grid to: {save_path}")
    print(group_str)
    print(index_str)


In [25]:
zarr_path = "/usr/terminus/data-xrm-01/stamplab/external/tacosound/HR-pQCT_II/zarr_data/paired_patch_dataset.zarr"

sample_random_triplets_from_dataloader(
    zarr_path=zarr_path,
    num_samples=50,  # number of rows in the image
    seed=52,
    batch_size=2,
    save_path="patch_outputs/mixed_groups_grid_2.png"
)


✅ Saved multi-group patch grid to: patch_outputs/mixed_groups_grid_2.png
['1996_R', '1956_L', '2005_L', '1996_R', '1996_R', '1955_L', '1956_L', '2005_L', '1955_L', '1996_R', '1996_R', '1955_L', '1996_R', '1955_L', '1955_L', '1955_L', '1955_L', '2005_L', '1996_R', '1955_L', '1996_R', '1996_R', '2005_L', '2005_L', '1955_L', '2005_L', '1955_L', '1955_L', '1955_L', '1955_L', '2005_L', '1996_R', '2005_L', '1956_L', '1955_L', '2005_L', '1955_L', '2005_L', '1956_L', '2005_L', '2005_L', '1955_L', '1955_L', '1956_L', '1955_L', '1956_L', '1955_L', '2005_L', '1955_L', '1955_L']
[5679, 20799, 7917, 16720, 8247, 20125, 12008, 3772, 599, 13185, 11319, 50859, 15747, 33156, 45750, 19569, 49665, 22093, 9227, 37733, 3955, 14298, 7848, 9682, 15450, 15874, 34806, 4250, 44540, 6176, 12743, 3729, 8510, 11317, 22661, 13615, 1263, 11686, 16176, 15882, 8034, 7624, 3848, 21529, 39475, 17881, 12431, 19194, 11234, 27041]


# Generate as in training loop

In [31]:

def plot_hr_conditioning_pairs_from_dataloader(
    dataloader,
    conditioning_mode="mix",
    num_pairs=50,
    save_path="patch_outputs/hr_conditioning_grid.png",
    seed=50
):
    """
    Sample HR and conditioning image pairs from a dataloader and save as a grid image.
    No air filtering — assumes it was already applied during dataset creation.
    """
    def scale(x):
        x_min = x.amin(dim=(-2, -1), keepdim=True)
        x_max = x.amax(dim=(-2, -1), keepdim=True)
        return (x - x_min) / (x_max - x_min + 1e-8)

    np.random.seed(seed)
    torch.manual_seed(seed)

    collected = []
    for batch in dataloader:
        clean_images = batch["hr_image"]

        if conditioning_mode == "qct":
            conditioning = batch["qct"]
        elif conditioning_mode == "lr":
            conditioning = batch["lr"]
        elif conditioning_mode == "mix":
            rand_mask = torch.rand(clean_images.size(0), device=clean_images.device) < 0.5
            conditioning = torch.where(
                rand_mask[:, None, None, None],
                batch["qct"],
                batch["lr"]
            )
        else:
            raise ValueError(f"Unsupported conditioning_mode: {conditioning_mode}")

        # Normalize for visualization
        conditioning = scale(conditioning)
        clean_images = scale(clean_images)

        for i in range(clean_images.shape[0]):
            collected.append({
                "hr": clean_images[i].squeeze(0).cpu(),
                "cond": conditioning[i].squeeze(0).cpu()
            })
            if len(collected) >= num_pairs:
                break
        if len(collected) >= num_pairs:
            break

    # Plot grid
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    fig, axes = plt.subplots(nrows=num_pairs, ncols=2, figsize=(6, 2.5 * num_pairs))
    for i, pair in enumerate(collected):
        axes[i, 0].imshow(pair["hr"], cmap="gray")
        axes[i, 0].set_title("HR-pQCT")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(pair["cond"], cmap="gray")
        axes[i, 1].set_title("Conditioning")
        axes[i, 1].axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"✅ Saved HR-conditioning pair grid to: {save_path}")

In [33]:
zarr_path = "/usr/terminus/data-xrm-01/stamplab/external/tacosound/HR-pQCT_II/zarr_data/paired_patch_dataset.zarr"
group_names = ["1955_L", "1956_L", "1996_R", "2005_L"]
conditioning_mode = "mix"

train_dataloader = create_triplet_dataloader(
    zarr_path=zarr_path,
    group_names=group_names,
    conditioning_mode=conditioning_mode,
    patch_size=(1, 256, 256),
    batch_size=8
)

plot_hr_conditioning_pairs_from_dataloader(
    dataloader=train_dataloader,
    conditioning_mode=conditioning_mode,
    num_pairs=30,
    save_path="patch_outputs/hr_conditioning_grid.png"
)


✅ Saved HR-conditioning pair grid to: patch_outputs/hr_conditioning_grid.png


# use dataloaders mix function

In [42]:

def plot_hr_conditioning_pairs_from_dataloader_dataloadermix(
    dataloader,
    num_pairs=50,
    save_path="patch_outputs/hr_conditioning_grid.png",
):
    """
    Sample HR and conditioning image pairs from a dataloader and save as a grid image.
    No air filtering — assumes it was already applied during dataset creation.
    """

    collected = []
    for batch in dataloader:
        clean_images = batch["hr_image"]
        conditioning = batch["conditioning"]

        for i in range(clean_images.shape[0]):
            cond_source = "Unknown"
            if conditioning_mode == "mix":
                if torch.allclose(conditioning[i], batch["qct"][i]):
                    cond_source = "QCT"
                elif torch.allclose(conditioning[i], batch["lr"][i]):
                    cond_source = "LR"
                print(f"[{batch['group'][i]} #{batch['index'][i]}] Used {cond_source} as conditioning")

            collected.append({
                "hr": clean_images[i].squeeze(0).cpu(),
                "cond": conditioning[i].squeeze(0).cpu(),
            })

            if len(collected) >= num_pairs:
                break
        if len(collected) >= num_pairs:
            break

    # Plot grid
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    fig, axes = plt.subplots(nrows=num_pairs, ncols=2, figsize=(6, 2.5 * num_pairs))
    for i, pair in enumerate(collected):
        axes[i, 0].imshow(pair["hr"], cmap="gray")
        axes[i, 0].set_title("HR-pQCT")
        axes[i, 0].axis("off")

        axes[i, 1].imshow(pair["cond"], cmap="gray")
        axes[i, 1].set_title("Conditioning")
        axes[i, 1].axis("off")

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close(fig)
    print(f"✅ Saved HR-conditioning pair grid to: {save_path}")

In [43]:
zarr_path = "/usr/terminus/data-xrm-01/stamplab/external/tacosound/HR-pQCT_II/zarr_data/paired_patch_dataset.zarr"
group_names = ["1955_L", "1956_L", "1996_R", "2005_L"]
conditioning_mode = "mix"

train_dataloader = create_triplet_dataloader(
    zarr_path=zarr_path,
    group_names=group_names,
    conditioning_mode=conditioning_mode,
    patch_size=(1, 256, 256),
    batch_size=8, 
    num_workers=0
)

plot_hr_conditioning_pairs_from_dataloader_dataloadermix(
    dataloader=train_dataloader,
    num_pairs=30,
    save_path="patch_outputs/hr_conditioning_grid_dataloader_mix.png"
)


[1956_L #21320] Used QCT as conditioning
[1956_L #10277] Used QCT as conditioning
[1996_R #5627] Used LR as conditioning
[1996_R #3405] Used QCT as conditioning
[1996_R #16110] Used QCT as conditioning
[1955_L #27451] Used QCT as conditioning
[1955_L #13900] Used QCT as conditioning
[1955_L #42294] Used QCT as conditioning
[1955_L #30380] Used LR as conditioning
[1996_R #15155] Used QCT as conditioning
[1955_L #9807] Used LR as conditioning
[1955_L #36390] Used QCT as conditioning
[1955_L #11780] Used LR as conditioning
[2005_L #9917] Used LR as conditioning
[1955_L #1660] Used LR as conditioning
[2005_L #11508] Used LR as conditioning
[1955_L #40883] Used LR as conditioning
[1955_L #15650] Used LR as conditioning
[1955_L #20628] Used QCT as conditioning
[1996_R #3565] Used LR as conditioning
[1996_R #14847] Used LR as conditioning
[2005_L #21558] Used LR as conditioning
[1955_L #5368] Used QCT as conditioning
[1956_L #18716] Used QCT as conditioning
[1996_R #16981] Used LR as conditio