In [None]:
import itertools
import math
import sys
from pathlib import Path
from typing import Literal

import hydra
import matplotlib.pyplot as plt
import torch
import torch.utils._pytree as pytree
from lightning import Fabric
from omegaconf import OmegaConf
from PIL import Image
from torch import nn, Tensor
from torch.utils.data import default_collate
from tqdm import tqdm

# Project-local modules: make the repository root importable first.
sys.path.append("../../")
from src.models.components.partmae_v6 import PARTMaskedAutoEncoderViT
from src.data.components.transforms.multi_crop_v4 import ParametrizedMultiCropV4
# from src.utils.visualization.reconstruction_v5_anchor_reparam import reconstruction_lstsq_with_anchor_reparam
from src.utils.visualization.reconstruction_v6 import reconstruction_lstsq_with_anchor_reparam
from src.utils.visualization.reconstruction_v5_gt import reconstruction_gt

## Utils

In [None]:
def clean_model_io(batch: tuple, out: dict, device="cuda"):
    """
    Regroup a collated batch and the model outputs into one dict for plotting.

    Args:
        batch: Collated model inputs. Only ``batch[0]`` (images) and ``batch[1]``
            (parameter vectors) are read. Each is a sequence of view groups
            (presumably global / local -- TODO confirm against the transform).
            A sample's entries from all groups are concatenated into one list.
        out: Model output dictionary; every entry is copied into the result.
        device: Device every tensor in the result is moved to (default: "cuda").

    Returns:
        io: Dictionary with
            - "x": per-sample list of view images,
            - "params": per-sample list of per-view parameter vectors,
            - "canonical_params": ``param[0:4]`` per view of the FIRST sample only,
            - "crop_params": per-sample list of per-view ``param[4:8]``,
            plus every key of ``out``. All tensors are detached and moved to ``device``.
    """
    def _flatten_per_sample(grouped):
        # zip(*grouped) walks all groups sample by sample; chaining then merges
        # that sample's views from every group into a single flat list.
        return [list(itertools.chain.from_iterable(per_sample)) for per_sample in zip(*grouped)]

    images = _flatten_per_sample(batch[0])
    params = _flatten_per_sample(batch[1])

    io = {
        "x": images,
        "params": params,
        # NOTE(review): the trailing [0] keeps only sample 0, so this entry is
        # indexed by view while "crop_params" is indexed by sample.
        "canonical_params": [[p[0:4] for p in sample_params] for sample_params in params][0],
        "crop_params": [[p[4:8] for p in sample_params] for sample_params in params],
    }
    io.update(out)

    # Detach and relocate every tensor leaf; containers and non-tensors pass through.
    return pytree.tree_map_only(Tensor, lambda leaf: leaf.detach().to(device), io)


def make_plots(
    model,
    io,
    train_transform,
    original_img,
):
    """
    Plot the canonical image beside a GT-pose and a predicted-pose reconstruction.

    Only the first sample of ``io`` is used.

    Args:
        model: Provides ``_Ms``, ``patch_size``, ``canonical_img_size`` and
            ``max_scale_ratio``.
        io: Dictionary produced by ``clean_model_io``.
        train_transform: Must implement ``recreate_canonical(img, params)``.
        original_img: Source image the crops were taken from.

    Returns:
        ``(fig, axes)`` for the 1x3 panel.
    """
    # Arguments shared by both reconstruction routines.
    sample_kwargs = dict(
        x=io["x"][0],
        patch_positions_nopos=io["patch_positions_nopos"][0],
        num_tokens=model._Ms,
        crop_params=io["crop_params"][0],
        patch_size=model.patch_size,
        canonical_img_size=model.canonical_img_size,
    )
    gt_reconstruction = reconstruction_gt(**sample_kwargs)
    pred_reconstruction, *_ = reconstruction_lstsq_with_anchor_reparam(
        **sample_kwargs,
        max_scale_ratio=model.max_scale_ratio,
        pred_dT=io["pred_dT"][0],
    )

    fig, axes = plt.subplots(1, 3)
    canonical_img = train_transform.recreate_canonical(
        original_img, io["canonical_params"][0]
    )
    panels = (
        (canonical_img, "Original"),
        (gt_reconstruction.permute(1, 2, 0).cpu(), "GT Reconstruction"),
        (pred_reconstruction.permute(1, 2, 0).cpu(), "Reconstruction"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")
    return fig, axes

## Reconstruction

In [None]:
# Allow reduced-precision (e.g. TF32) kernels for float32 matmuls where the
# hardware supports them: faster inference for visualization.
torch.set_float32_matmul_precision(precision="high")

In [None]:
# Repository root, relative to this notebook's working directory.
ROOT = Path("../../")

# The saved Hydra configs presumably use `${eval:...}` interpolations. Register
# the resolver only once so re-running this cell does not try to re-register it.
# NOTE(review): this resolver runs Python `eval` on config strings -- load trusted configs only.
_eval_resolver_registered = OmegaConf.has_resolver("eval")
if not _eval_resolver_registered:
    OmegaConf.register_new_resolver("eval", eval)

In [None]:
# Training run to inspect. Its Hydra config (and checkpoint) are loaded in the
# next cell, which re-reads the config itself, so nothing is loaded here.
# Earlier run: ROOT / "outputs/2025-06-16/13-23-31"
FOLDER = ROOT / Path("outputs/2025-06-22/19-16-53")

In [None]:
# Reload the run's Hydra config so this cell is idempotent (the migration
# below mutates cfg in place).
cfg = OmegaConf.load(FOLDER / ".hydra/config.yaml")

# Config migration: a legacy boolean `predict_uncertainty` maps onto
# `uncertainty_mode` ("additive" when true, "none" otherwise). A config with
# neither key cannot be interpreted.
if "predict_uncertainty" in cfg["model"]:
    predict_uncertainty = cfg["model"].pop("predict_uncertainty")
    cfg["model"]["uncertainty_mode"] = "additive" if predict_uncertainty else "none"
elif "uncertainty_mode" not in cfg["model"]:
    raise ValueError("Uncertainty mode not specified in the config.")
print(cfg["model"]["uncertainty_mode"])

ckpt_path = FOLDER / "epoch_0199.ckpt"
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
ckpt = torch.load(ckpt_path, map_location="cuda")
state_dict = ckpt["model"]

# Checkpoint-key migration (old name -> new name), applied in this order when
# the old name is present.
# NOTE(review): "pose_head.mu.bias" is not remapped -- presumably mu had no bias;
# the strict load in the model cell would flag it otherwise.
STATE_DICT_RENAMES = (
    ("pose_head.mu.weight", "pose_head.mu_proj.weight"),
    ("pose_head.logvar.weight", "pose_head.disp_proj.weight"),
    ("pose_head.logvar.bias", "pose_head.disp_proj.bias"),
)
for old_key, new_key in STATE_DICT_RENAMES:
    if old_key in state_dict:
        state_dict[new_key] = state_dict.pop(old_key)

# The gate projection's first dimension defines `gate_dim`: infer it when the
# config lacks the option, and check consistency either way.
gate_weight_key = "pose_head.gate_proj.weight"
if gate_weight_key in state_dict:
    ckpt_gate_dim = state_dict[gate_weight_key].shape[0]
    if "gate_dim" not in cfg["model"]:
        cfg["model"]["gate_dim"] = ckpt_gate_dim
        print(f"Gate dimension not specified in the config, inferring from state_dict: {cfg['model']['gate_dim']}")
    assert cfg["model"]["gate_dim"] == ckpt_gate_dim

# Checkpoints without `gate_mult` get a zero entry so strict loading succeeds.
if "pose_head.gate_mult" not in state_dict:
    state_dict["pose_head.gate_mult"] = torch.zeros(1)

In [None]:
# View layout: V views in total, gV global, lV = V - gV local.
V = 2
gV = 2
lV = V - gV

# Image mask ratio per supported view count; pos_mask_ratio is 0.75 for all.
MASK_RATIO_BY_NUM_VIEWS = {12: 0.75, 2: 0}
if V not in MASK_RATIO_BY_NUM_VIEWS:
    raise ValueError(f"Unsupported number of views: {V}")

model = hydra.utils.instantiate(
    cfg["model"],
    _target_="src.models.components.partmae_v6.PARTMaskedAutoEncoderViT",
    num_views=V,
    mask_ratio=MASK_RATIO_BY_NUM_VIEWS[V],
    pos_mask_ratio=0.75,
)
# strict=True: fail loudly if the (migrated) checkpoint keys don't match the model.
# NOTE(review): the model is never switched to .eval() or moved to a device in
# this notebook -- confirm train-mode forward passes are intended here.
model.load_state_dict(state_dict, strict=True)
print(ckpt["global_step"], ckpt["epoch"])

In [None]:
img = Image.open("../../artifacts/dog.jpg")

# Multi-crop transform from the run's config; `distort_color=False` presumably
# disables colour augmentation, and the local-crop count follows the view layout.
train_transform = hydra.utils.instantiate(
    cfg["data"]["transform"],
    distort_color=False,
    n_local_crops=V - gV,
)

# A batch of four transform draws of the same image.
NUM_SAMPLES = 4
batch = default_collate([train_transform(img) for _ in range(NUM_SAMPLES)])

In [None]:
# Show the nested batch structure with each tensor replaced by its shape.
pytree.tree_map_only(Tensor, Tensor.size, batch)

In [None]:
# One forward pass without autograd, then regroup inputs/outputs and plot sample 0.
with torch.no_grad():
    out = model(*batch)
io = clean_model_io(batch, out, device="cuda")
fig, axes = make_plots(model, io, train_transform, img)

In [None]:
from jaxtyping import Float
from torch import Tensor
import torch.nn.functional as F


def paste_patch(
    crop: Float[Tensor, "C h w"],
    pos: Float[Tensor, "2"],
    pos_canonical: Float[Tensor, "2"],
    patch_size_canonical: Float[Tensor, "2"],
    canvas: Float[Tensor, "C H W"],
    count_map: Float[Tensor, "1 H W"],
    patch_size: int,
    canonical_size: int,
    disp: Float[Tensor, "4"] = None,
):
    """
    Copy one patch from ``crop`` onto ``canvas``, resized to its canonical size.

    The ``patch_size`` x ``patch_size`` patch with top-left ``pos`` (crop pixels)
    is bilinearly resized to ``patch_size_canonical`` and added into ``canvas``
    with top-left ``pos_canonical``. ``count_map`` is incremented over the same
    region so overlapping pastes can be averaged later (``canvas / count_map``).
    Both ``canvas`` and ``count_map`` are modified in place.

    Positions are rounded to the nearest pixel and clamped so the patch lies
    inside the crop (source) and the canvas (target). A canonical size that
    rounds to zero pixels pastes nothing. A canonical size larger than the
    canvas is clamped to ``canonical_size``.

    Args:
        crop: Source image crop of shape [C, h, w].
        pos: Patch top-left in crop coordinates [y, x].
        pos_canonical: Target top-left in canonical coordinates [y, x].
        patch_size_canonical: Patch size in canonical space [height, width].
        canvas: Target canvas [C, H, W], accumulated in place.
        count_map: Overlap counter [1, H, W], incremented in place.
        patch_size: Side length of the patch in crop space.
        canonical_size: Side length of the (square) canonical canvas.
        disp: Per-token dispersion for each transformation parameter.
            Currently unused; accepted so callers may pass it.

    Returns:
        None.
    """
    crop_h, crop_w = crop.shape[1:3]

    # Canonical patch size as Python ints (keeps 0-d tensors out of the slicing
    # and interpolate calls), capped at the canvas size so the paste always fits.
    rounded_size = patch_size_canonical.round()
    patch_h_canonical = min(int(rounded_size[0]), canonical_size)
    patch_w_canonical = min(int(rounded_size[1]), canonical_size)
    if patch_h_canonical < 1 or patch_w_canonical < 1:
        # Nothing to paste -- and F.interpolate rejects zero-sized outputs.
        return

    # Target top-left, rounded and clamped so the patch fits in the canvas.
    y_canonical = int(round(pos_canonical[0].item()))
    x_canonical = int(round(pos_canonical[1].item()))
    y_canonical = max(0, min(canonical_size - patch_h_canonical, y_canonical))
    x_canonical = max(0, min(canonical_size - patch_w_canonical, x_canonical))

    # Source top-left, rounded and clamped so the patch stays inside the crop.
    y_crop = max(0, min(crop_h - patch_size, int(round(pos[0].item()))))
    x_crop = max(0, min(crop_w - patch_size, int(round(pos[1].item()))))

    patch = crop[
        :, y_crop : y_crop + patch_size, x_crop : x_crop + patch_size
    ].unsqueeze(0)

    patch_resized = F.interpolate(
        patch,
        size=(patch_h_canonical, patch_w_canonical),
        mode="bilinear",
        align_corners=False,
    ).squeeze(0)

    # Accumulate the patch and record coverage for later averaging.
    target = (
        slice(None),
        slice(y_canonical, y_canonical + patch_h_canonical),
        slice(x_canonical, x_canonical + patch_w_canonical),
    )
    canvas[target] += patch_resized
    count_map[target] += 1


@torch.no_grad
def reconstruction_with_uncertainty_visualization(
    x: list[Float[Tensor, "C gH gW"] | Float[Tensor, "C lH lW"]],
    patch_positions_nopos: Float[Tensor, "M 2"],
    num_tokens: list[int],
    crop_params: list[Float[Tensor, "4"]],
    patch_size: int,
    canonical_img_size: int,
    max_scale_ratio: float,
    pred_dT: Float[Tensor, "M M 4"],
    disp_T: Float[Tensor, "M 4"],  # NOTE: This contains LOG-dispersions
    uncertainty_mode: Literal["none", "global_heatmap", "laplace_distributions"] = "none",
) -> tuple[
    Float[Tensor, "C canonical_img_size canonical_img_size"],  # reconstructed image
    Float[Tensor, "canonical_img_size canonical_img_size"] | None,  # uncertainty map
]:
    """
    Reconstruct image with optional uncertainty visualization.
    
    Args:
        disp_T: Per-token log-dispersions [M, 4] - NOTE: these are in log-space!
        uncertainty_mode: 
            - "none": No uncertainty visualization
            - "global_heatmap": Global uncertainty heatmap using dispersion norms
            - "laplace_distributions": Individual Laplace distribution heatmaps
    
    Returns:
        reconstructed_img: Reconstructed canonical image
        uncertainty_map: Uncertainty visualization (None if uncertainty_mode="none")
    """
    device = x[0].device
    C = x[0].shape[0]

    # Undo normalization
    dT = pred_dT[..., :2] * canonical_img_size
    dS = pred_dT[..., 2:] * math.log(max_scale_ratio)

    # Choose anchor
    T_anchor = (
        crop_params[0][:2]
        + (patch_positions_nopos[0] / x[0].shape[1]) * crop_params[0][2:4]
    )
    S_anchor = torch.log((patch_size * crop_params[0][2:4] / x[0].shape[1]))

    T_global = dT[:, 0] + T_anchor
    S_global = dS[:, 0] + S_anchor

    T_global_grouped = torch.split(T_global, num_tokens)
    S_global_grouped = torch.split(S_global, num_tokens)
    patch_positions_nopos_grouped = torch.split(patch_positions_nopos, num_tokens)
    disp_T_grouped = torch.split(disp_T, num_tokens)

    # Reconstruct the canonical image
    canvas = torch.zeros((C, canonical_img_size, canonical_img_size), device=device)
    count_map = torch.zeros((1, canonical_img_size, canonical_img_size), device=device)

    for crop, patch_positions, canonical_pos, log_size, disp in zip(
        x,
        patch_positions_nopos_grouped,
        T_global_grouped,
        S_global_grouped,
        disp_T_grouped,
    ):
        N = patch_positions.shape[0]
        for i in range(N):
            paste_patch(
                crop=crop,
                pos=patch_positions[i].float(),
                pos_canonical=canonical_pos[i],
                patch_size_canonical=torch.exp(log_size[i]),
                canvas=canvas,
                count_map=count_map,
                patch_size=patch_size,
                canonical_size=canonical_img_size,
                disp=disp[i]
            )

    count_map[count_map == 0] = 1
    reconstructed_img = canvas / count_map

    # Generate uncertainty visualization
    uncertainty_map = None
    if uncertainty_mode != "none":
        # Convert log-dispersions to actual dispersions and scale to pixel units
        # disp_T contains log-dispersions, so we need to exp() them first
        actual_dispersions = torch.exp(disp_T)  # Convert from log-space
        disp_T_pixels = actual_dispersions.clone()
        disp_T_pixels[:, :2] *= canonical_img_size  # dy, dx to pixels
        disp_T_pixels[:, 2:] *= math.log(max_scale_ratio)  # log-scale factors
        
        if uncertainty_mode == "global_heatmap":
            uncertainty_map = create_global_uncertainty_heatmap(
                T_global, disp_T_pixels, canonical_img_size, patch_size
            )
        elif uncertainty_mode == "laplace_distributions":
            uncertainty_map = create_laplace_distribution_heatmaps(
                T_global, disp_T_pixels, canonical_img_size, patch_size
            )

    return reconstructed_img, uncertainty_map


def plot_reconstruction_with_uncertainty(
    model,
    io,
    train_transform,
    original_img,
    uncertainty_mode: Literal["none", "global_heatmap", "laplace_distributions"] = "none",
):
    """
    Plot the canonical image, GT reconstruction, predicted reconstruction and,
    unless ``uncertainty_mode == "none"``, a fourth uncertainty panel.

    Only the first sample of ``io`` is used.

    Args:
        model: Provides ``_Ms``, ``patch_size``, ``canonical_img_size`` and
            ``max_scale_ratio``.
        io: Dictionary produced by ``clean_model_io`` (needs "disp_T" as well).
        train_transform: Must implement ``recreate_canonical(img, params)``.
        original_img: Source image the crops were taken from.
        uncertainty_mode: Forwarded to ``reconstruction_with_uncertainty_visualization``.

    Returns:
        ``(fig, axes)``.
    """
    # Arguments shared by the GT and predicted reconstructions.
    shared = dict(
        x=io["x"][0],
        patch_positions_nopos=io["patch_positions_nopos"][0],
        num_tokens=model._Ms,
        crop_params=io["crop_params"][0],
        patch_size=model.patch_size,
        canonical_img_size=model.canonical_img_size,
    )
    gt_reconstruction = reconstruction_gt(**shared)
    pred_reconstruction, uncertainty_map = reconstruction_with_uncertainty_visualization(
        **shared,
        max_scale_ratio=model.max_scale_ratio,
        pred_dT=io["pred_dT"][0],
        disp_T=io["disp_T"][0],
        uncertainty_mode=uncertainty_mode,
    )

    show_uncertainty = uncertainty_mode != "none"
    n_plots = 3 + int(show_uncertainty)
    fig, axes = plt.subplots(1, n_plots, figsize=(4*n_plots, 4))

    canonical_img = train_transform.recreate_canonical(
        original_img, io["canonical_params"][0]
    )
    images = (
        canonical_img,
        gt_reconstruction.permute(1, 2, 0).cpu(),
        pred_reconstruction.permute(1, 2, 0).cpu(),
    )
    titles = ("Original", "GT Reconstruction", "Reconstruction")
    for ax, image, title in zip(axes, images, titles):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")

    if show_uncertainty and uncertainty_map is not None:
        heat = axes[3].imshow(uncertainty_map.cpu(), cmap='hot', alpha=0.8)
        axes[3].set_title(f"Uncertainty ({uncertainty_mode})")
        axes[3].axis("off")
        plt.colorbar(heat, ax=axes[3], fraction=0.046, pad=0.04)

    plt.tight_layout()
    return fig, axes

In [None]:
# Inspect the pairwise pose targets/predictions (M = total tokens over all views;
# 294 in the run these comments were written against).
print(io["patch_positions_nopos"].shape)  # [B, M, 2]: per-token (y, x) top-left in its crop's pixels
print(io["gt_dT"].shape)  # [B, M, M, 4]: normalized pairwise (dy, dx, dlogh, dlogw)
print(io["patch_positions_nopos"].max())
print(io["params"])
print(io["crop_params"])

# Absolute canonical pose of every token is recovered from an anchor token whose
# pose is known from its crop parameters:
#   T_global[i] = dT[i, anchor] * canonical_img_size       + T_anchor
#   S_global[i] = dS[i, anchor] * log(max_scale_ratio)     + S_anchor
batch_idx = 0
anchor_idx = 0  # token index of the anchor
anchor_view_idx = model.view_ids_M[anchor_idx]  # view the anchor token belongs to
crop_size = io["x"][batch_idx][anchor_view_idx].shape[-1]

anchor_params = io["crop_params"][batch_idx][anchor_view_idx]  # [y0, x0, h', w']
anchor_local_pos = io["patch_positions_nopos"][batch_idx][anchor_idx]
anchor_global_pos = anchor_params[:2] + (anchor_local_pos / crop_size) * anchor_params[2:4]
anchor_global_scale = torch.log((model.patch_size * anchor_params[2:4] / crop_size))
print("Anchor global position:", anchor_global_pos)
print("Anchor global scale:", anchor_global_scale)

# Undo the normalization of the pairwise displacements: [M, M, 2] each.
translation_unit = model.canonical_img_size
log_scale_unit = math.log(model.max_scale_ratio)
gt_dt = io["gt_dT"][batch_idx, :, :, :2] * translation_unit
gt_ds = io["gt_dT"][batch_idx, :, :, 2:] * log_scale_unit
pred_dt = io["pred_dT"][batch_idx, :, :, :2] * translation_unit
pred_ds = io["pred_dT"][batch_idx, :, :, 2:] * log_scale_unit

print(io["gt_dT"].shape)  # [B, M, M, 4]
print(gt_dt.shape)  # [M, M, 2]
print(anchor_global_pos.shape)  # [2]

# GT absolute positions / log-sizes of all tokens: [M, 2] each.
gt_T_global = gt_dt[:, anchor_idx] + anchor_global_pos
gt_S_global = gt_ds[:, anchor_idx] + anchor_global_scale
# Predicted (mean) absolute positions / log-sizes of all tokens: [M, 2] each.
pred_T_global = pred_dt[:, anchor_idx] + anchor_global_pos
pred_S_global = pred_ds[:, anchor_idx] + anchor_global_scale

print("Global positions:", gt_T_global.shape)
print("Global scales:", gt_S_global.shape)

# Pick one token and look at the dispersions (Laplace scales) of its pairwise
# displacement predictions.
patch_idx = 0
print(io["disp_dT"].shape)  # [B, M, M, 4]
disp_patch_to_all = io["disp_dT"][batch_idx, patch_idx][:, :2]  # [M, 2]: only y/x are plotted

print(pred_T_global.shape)
mu_patch = pred_T_global[patch_idx]  # [2]: predicted mean top-left of that token

canonical_img = train_transform.recreate_canonical(
    img, io["canonical_params"][0]
)
# TODO: box the chosen token and overlay its Laplace heatmap on canonical_img.

# Spread of the y/x dispersions; both reductions run over all entries -> scalars.
print(disp_patch_to_all.median())
print(disp_patch_to_all.quantile(0.75) - disp_patch_to_all.quantile(0.25))

In [None]:
# Per-token dispersions: (B, V * num_tokens_per_view, 4), one value per pose
# dimension (y, x, dlogh, dlogw); treated as log-space by the reconstruction code.
print(io["disp_T"].size())


In [None]:
import torch
import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


def visualize_patch_distribution(
    io,
    model,
    train_transform,
    img,
    batch_idx: int,
    anchor_idx: int,
    patch_idx: int,
    figsize=(6,6),
    disp_scale_pixels: bool = True,
):
    """
    Visualize the predicted 2D Laplacian distribution from anchor -> patch.

    Draws ground-truth boxes (top-left anchored) for the anchor and target
    patches, lists their positions in the legend, and overlays the Laplace
    density centred on the predicted target position.

    Args:
        io: dict with keys 'pred_dT','disp_dT','gt_dT','crop_params',
            'patch_positions_nopos','canonical_params','x'
        model: object with .view_ids_M, .Is, .canonical_img_size, .patch_size
        train_transform: has .recreate_canonical(img, canonical_params)
        img: input image for reconstruction
        batch_idx: index into batch
        anchor_idx: index of the anchor token
        patch_idx:  index of target token
        figsize:    matplotlib figure size
        disp_scale_pixels: scale predicted dispersion to pixel units

    Returns:
        The (already closed) matplotlib figure; it still renders when displayed.
    """
    # --- anchor geometry in canonical (global) pixel coordinates ------------
    anchor_view_idx = model.view_ids_M[anchor_idx]  # view the anchor token comes from
    anchor_crop_size = model.Is[anchor_view_idx]    # side length of that view's crop
    canon_size = model.canonical_img_size
    p_size = model.patch_size

    cp = io['crop_params'][batch_idx][anchor_view_idx]        # [y0,x0,h',w']
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]   # [y_loc, x_loc]

    # Crop-local pixel position -> canonical top-left of the anchor patch.
    anchor_tl = cp[:2] + (lp / anchor_crop_size) * cp[2:4]
    ay_tl, ax_tl = anchor_tl.tolist()

    # A crop patch covers p_size/crop_size of the crop, i.e. this many canonical px.
    # NOTE(review): the target box reuses the anchor's size; tokens from other
    # views may have a different canonical size.
    box_h, box_w = (p_size * cp[2:4] / anchor_crop_size).tolist()

    # GT top-left of the target: anchor position minus the GT displacement.
    py_tl = anchor_tl[0].item() - io['gt_dT'][batch_idx, anchor_idx, patch_idx, 0].item() * canon_size
    px_tl = anchor_tl[1].item() - io['gt_dT'][batch_idx, anchor_idx, patch_idx, 1].item() * canon_size

    # Predicted Laplace scales (y, x) of the anchor -> patch displacement.
    # NOTE(review): used as linear-space scales; if disp_dT is log-space like
    # disp_T, it needs exp() first.
    b_norm_y, b_norm_x = io['disp_dT'][batch_idx, anchor_idx, patch_idx, :2].cpu().tolist()
    if disp_scale_pixels:
        b_y = b_norm_y * canon_size
        b_x = b_norm_x * canon_size
    else:
        b_y, b_x = b_norm_y, b_norm_x

    # Predicted mean position of the target.
    pred_off = io['pred_dT'][batch_idx, anchor_idx, patch_idx, :2] * canon_size
    mu_y = (anchor_tl[0] - pred_off[0]).item()
    mu_x = (anchor_tl[1] - pred_off[1]).item()

    # NOTE(review): clean_model_io's "canonical_params" holds sample 0's per-view
    # entries, so this index selects a view; it is exact for batch_idx=0.
    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1,2,0).cpu().numpy()
    # PIL's `.size` is (width, height) and numpy's `.size` is an element count,
    # so read (H, W) from an array view instead.
    canon = np.asarray(canon)
    H, W = canon.shape[:2]

    # Separable 2D Laplace density centred on the predicted mean.
    ys = np.arange(H)[:,None]
    xs = np.arange(W)[None,:]
    Z = (1.0/(4*b_y*b_x)) * np.exp(-np.abs(ys-mu_y)/b_y - np.abs(xs-mu_x)/b_x)

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(canon, interpolation='nearest')

    # GT box for the anchor (blue).
    rect_a = Rectangle(
        (ax_tl, ay_tl), box_w, box_h,
        edgecolor='blue', lw=2, facecolor='none',
        label=f'Anchor GT (TL): ({ay_tl:.1f}, {ax_tl:.1f})'
    )
    ax.add_patch(rect_a)

    # GT box for the target patch (red).
    rect_p = Rectangle(
        (px_tl, py_tl), box_w, box_h,
        edgecolor='red', lw=2, facecolor='none',
        label=f'Patch GT (TL): ({py_tl:.1f}, {px_tl:.1f})'
    )
    ax.add_patch(rect_p)

    # Predicted Laplace heatmap, aligned with the image pixels.
    ax.imshow(Z, cmap='hot', alpha=0.5, extent=(0,W,H,0))

    ax.legend(loc='upper right')
    ax.set_title(
        f'anchor={anchor_idx} → patch={patch_idx}   '
        f'Pred mean=({mu_y:.1f},{mu_x:.1f})'
    )
    ax.axis('off')
    plt.tight_layout()
    plt.close(fig)
    return fig

In [None]:
import torch
import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


def visualize_patch_distribution_per_token(
    io,
    model,
    train_transform,
    img,
    batch_idx: int,
    figsize=(6,6),
    disp_scale_pixels: bool = True,
    patch_idx: int = 0,
):
    """
    Visualize one token's predicted canonical position with a Laplace heatmap
    built from its per-token dispersion ``io['disp_T']``.

    Token 0 acts as the anchor that turns relative predictions into absolute
    canonical coordinates, because its own position is known from its crop.
    The function is self-contained: it reads no notebook globals.

    Args:
        io: dict with keys 'pred_dT', 'gt_dT', 'disp_T', 'crop_params',
            'patch_positions_nopos', 'canonical_params'.
        model: object with .view_ids_M, .Is, .canonical_img_size, .patch_size.
        train_transform: has .recreate_canonical(img, canonical_params).
        img: input image for reconstruction.
        batch_idx: index into the batch.
        figsize: matplotlib figure size.
        disp_scale_pixels: scale the dispersion to canonical pixel units.
        patch_idx: token to visualize (default 0).

    Returns:
        The (already closed) matplotlib figure; it still renders when displayed.
    """
    #### 1. Anchor's canonical position and patch size
    anchor_idx = 0  # only used to make the relative predictions absolute
    anchor_view_idx = model.view_ids_M[anchor_idx]
    anchor_crop_size = model.Is[anchor_view_idx]
    canon_size = model.canonical_img_size
    p_size = model.patch_size

    cp = io['crop_params'][batch_idx][anchor_view_idx]        # [y0,x0,h',w']
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]   # [y_loc, x_loc]

    anchor_tl = cp[:2] + (lp / anchor_crop_size) * cp[2:4]
    ay_tl, ax_tl = anchor_tl.tolist()
    # NOTE(review): the target box reuses the anchor's canonical patch size.
    box_h, box_w = (p_size * cp[2:4] / anchor_crop_size).tolist()

    #### 2. Target GT top-left and predicted mean (anchor minus displacement)
    gt_off = io['gt_dT'][batch_idx, anchor_idx, patch_idx, :2] * canon_size
    py_tl = (anchor_tl[0] - gt_off[0]).item()
    px_tl = (anchor_tl[1] - gt_off[1]).item()

    pred_off = io['pred_dT'][batch_idx, anchor_idx, patch_idx, :2] * canon_size
    mu_y = (anchor_tl[0] - pred_off[0]).item()
    mu_x = (anchor_tl[1] - pred_off[1]).item()

    #### 3. Per-token Laplace scales
    # disp_T holds log-dispersions: exponentiate (and optionally scale to pixels),
    # matching the conversion used by the uncertainty reconstruction code.
    b_norm_y, b_norm_x = torch.exp(io['disp_T'][batch_idx, patch_idx, :2]).cpu().tolist()
    if disp_scale_pixels:
        b_y = b_norm_y * canon_size
        b_x = b_norm_x * canon_size
    else:
        b_y, b_x = b_norm_y, b_norm_x

    #### 4. Canonical image and heatmap
    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1,2,0).cpu().numpy()
    # PIL `.size` is (W, H) and numpy `.size` is an int, so use the array shape.
    canon = np.asarray(canon)
    H, W = canon.shape[:2]

    ys = np.arange(H)[:,None]
    xs = np.arange(W)[None,:]
    Z = (1.0/(4*b_y*b_x)) * np.exp(-np.abs(ys-mu_y)/b_y - np.abs(xs-mu_x)/b_x)

    #### 5. Plot
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(canon, interpolation='nearest')

    rect_a = Rectangle(
        (ax_tl, ay_tl), box_w, box_h,
        edgecolor='blue', lw=2, facecolor='none',
        label=f'Anchor GT (TL): ({ay_tl:.1f}, {ax_tl:.1f})'
    )
    ax.add_patch(rect_a)

    rect_p = Rectangle(
        (px_tl, py_tl), box_w, box_h,
        edgecolor='red', lw=2, facecolor='none',
        label=f'Patch GT (TL): ({py_tl:.1f}, {px_tl:.1f})'
    )
    ax.add_patch(rect_p)

    # The heatmap was previously computed but never drawn.
    ax.imshow(Z, cmap='hot', alpha=0.5, extent=(0,W,H,0))

    ax.legend(loc='upper right')
    ax.set_title(
        f'anchor={anchor_idx} → patch={patch_idx}   '
        f'Pred mean=({mu_y:.1f},{mu_x:.1f})'
    )
    ax.axis('off')
    plt.tight_layout()
    plt.close(fig)
    return fig

In [None]:
# Example: anchor 0 -> patch 16 displacement distribution for sample 0
# (the returned figure is the cell's displayed output).
visualize_patch_distribution(
    io, model, train_transform, img,
    batch_idx=0, anchor_idx=0, patch_idx=16,
    figsize=(6, 6),
)

In [None]:
import torch
import numpy as np
import math
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib import animation


# ------------------------------------------------------------------
# ANIMATION ACROSS (anchor_idx, patch_idx) PAIRS
# ------------------------------------------------------------------

def animate_all_pairs(
    io,
    model,
    train_transform,
    img,
    batch_idx: int = 0,
    anchors: list | None = None,
    patch_order: str = 'sequential',  # or 'random'
    fps: int = 2,
    disp_scale_pixels: bool = True,
    figsize=(6,6),
):
    """
    Create a matplotlib.animation that iterates over every (anchor, patch) pair.

    Args:
        anchors: list of anchor indices to iterate. If None, uses all patches.
        patch_order: 'sequential' or 'random' iteration of target patches.
        fps: frames per second for the resulting animation.

    Returns:
        anim (FuncAnimation) – you can save with anim.save('out.mp4', fps=fps)
    """
    if anchors is None:
        anchors = list(range(io['pred_dT'].shape[2]))

    # prepare list of (anchor_idx, patch_idx)
    pairs = []
    for a in anchors:
        patches = list(range(io['pred_dT'].shape[2]))
        if patch_order == 'random':
            np.random.shuffle(patches)
        for p in patches:
            pairs.append((a, p))

    # set up matplotlib figure once
    fig, ax = plt.subplots(figsize=figsize)
    plt.axis('off')

    def init():
        ax.clear()
        ax.axis('off')
        return []

    def update(frame_idx):
        a_idx, p_idx = pairs[frame_idx]
        ax.clear()
        ax.axis('off')
        # generate the frame using the static function
        frame_fig = visualize_patch_distribution(
            io, model, train_transform, img,
            batch_idx=batch_idx,
            anchor_idx=a_idx,
            patch_idx=p_idx,
            figsize=figsize,
            disp_scale_pixels=disp_scale_pixels,
        )
        # extract the Axes image from the returned fig & draw onto our ax
        ax.imshow(frame_fig.axes[0].images[0].get_array(), interpolation='nearest')
        ax.imshow(frame_fig.axes[0].images[1].get_array(), cmap='hot', alpha=0.5,)
        for child in frame_fig.axes[0].get_children():
            if isinstance(child, Rectangle):
                ax.add_patch(Rectangle(child.get_xy(), child.get_width(), child.get_height(),
                                        edgecolor=child.get_edgecolor(), facecolor='none', lw=child.get_lw()))
        ax.set_title(f'anchor={a_idx} → patch={p_idx}')
        plt.close(frame_fig)
        return ax.patches  # need to return updated artists

    anim = animation.FuncAnimation(
        fig, update,
        init_func=init,
        frames=len(pairs),
        interval=1000//fps,
        blit=False,
        repeat=True,
    )
    return anim

# ------------------------------------------------------------------
# Example usage (inside a notebook):
# anim = animate_all_pairs(io, model, train_transform, img, batch_idx=0, fps=2, anchors=[0])
# from IPython.display import HTML
# HTML(anim.to_jshtml())


In [None]:
# anim = animate_all_pairs(io, model, train_transform, img, batch_idx=0, fps=2, anchors=[0])
# anim.save("uncertainty.mp4", fps=2, extra_args=['-vcodec', 'libx264'])

In [None]:
import torch
import numpy as np
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt


def cluster_and_plot_patches(
    io,
    model,
    train_transform,
    img,
    batch_idx=0,
    anchor_idx=0,
    sigma_unc=0.05,
    eps=0.1,
    min_samples=5,
    figsize=(8,8),
):
    """
    Cluster the patch positions predicted from one anchor patch with DBSCAN.

    1) Extract (mu_y, mu_x, b_y, b_x) for each patch
    2) Drop patches predicted outside the canonical image, normalize the rest
       and cluster them with DBSCAN
    3) Overlay clusters on the canonical image

    Args:
        io: dict of model I/O tensors; reads 'crop_params', 'canonical_params',
            'patch_positions_nopos', 'pred_dT' and 'disp_dT'
        model: object with view_ids_M, Is and canonical_img_size
        train_transform: must have recreate_canonical(img, params)
        img: input image batch element
        batch_idx: which batch element to use
        anchor_idx: which patch to anchor on
        sigma_unc: scale for uncertainty normalization (fraction of canonical size)
        eps, min_samples: DBSCAN tuning
        figsize: plot size

    Returns:
        labels: DBSCAN label per in-bounds patch (-1 = noise)
        features: [M, 4] matrix that was clustered (M = number of in-bounds patches)
    """
    canon_size = model.canonical_img_size

    # the anchor patch comes from view `view_idx`, which was resized to `crop_size`
    view_idx = model.view_ids_M[anchor_idx]
    crop_size = model.Is[view_idx]

    # map the anchor's crop-space position back to canonical coordinates
    cp = io['crop_params'][batch_idx][view_idx]
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]
    anchor_tl = cp[:2] + (lp / crop_size) * cp[2:4]

    # predicted offsets (scaled to pixels) and dispersions
    pred_dt = io['pred_dT'][batch_idx, anchor_idx, :, :2] * canon_size  # [N,2]
    disp    = io['disp_dT'][batch_idx, anchor_idx, :, :2]               # [N,2]

    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1,2,0).cpu().numpy()
    # `.size` is an int for numpy arrays and (W, H) for PIL images, so read
    # (H, W) from the array shape, which is correct for both
    H, W = np.asarray(canon).shape[:2]

    # predicted positions: subtract offsets per the GT convention
    mus_np = (anchor_tl.unsqueeze(0) - pred_dt).cpu().numpy()  # [N,2]
    bs_np  = disp.cpu().numpy()                                 # [N,2]

    # keep only points inside the canonical image
    y, x = mus_np[:, 0], mus_np[:, 1]
    in_bounds = (y >= 0) & (y <= H) & (x >= 0) & (x <= W)
    mus_np = mus_np[in_bounds]
    bs_np  = bs_np[in_bounds]

    mus_norm = mus_np / canon_size
    bs_norm  = bs_np  / (canon_size * sigma_unc)
    features = np.hstack([mus_norm, bs_norm])     # [M,4]

    if len(features) == 0:
        # DBSCAN rejects empty input; nothing landed inside the image
        labels = np.empty(0, dtype=int)
    else:
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit(features).labels_

    plt.figure(figsize=figsize)
    plt.imshow(canon, interpolation='nearest')
    if len(labels):
        plt.scatter(
            mus_np[:, 1], mus_np[:, 0],
            c=labels, cmap='tab20', s=40, edgecolor='k'
        )
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    plt.title(f'DBSCAN Clustering: {n_clusters} clusters (σ_unc={sigma_unc}, eps={eps})')
    plt.axis('off')
    plt.show()

    return labels, features

# Example usage: cluster the patches predicted from anchor 0 of the first batch
# element. `features` is reused by the k-distance cell below.
labels, features = cluster_and_plot_patches(
    io, model, train_transform, img,
    batch_idx=0,
    anchor_idx=0,
    sigma_unc=0.5,
    eps=0.06,
    min_samples=5,
)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors

# Helper to plot k-distance graph for eps tuning
def plot_k_distance(features, k=5):
    """
    Draw the sorted k-distance curve used to pick DBSCAN's ``eps``.

    For every sample, take the distance to its k-th nearest neighbour (not
    counting the sample itself), sort those distances ascending and plot them.
    A knee in the curve is a common choice for ``eps``.

    Args:
        features: numpy array of shape [N, D]
        k: which neighbour's distance to plot
    """
    # ask for k + 1 neighbours: querying the fitted data returns each point
    # as its own neighbour at distance 0 in column 0
    knn = NearestNeighbors(n_neighbors=k + 1).fit(features)
    neighbour_dists = knn.kneighbors(features)[0]
    kth_dists = np.sort(neighbour_dists[:, k])

    fig, ax = plt.subplots(figsize=(6, 4))
    ax.plot(kth_dists)
    ax.set_xlabel(f'Samples sorted by distance to {k}th NN')
    ax.set_ylabel(f'{k}th NN distance')
    ax.set_title('k-distance graph for DBSCAN eps selection')
    ax.grid(True)
    plt.show()

# Example usage: `features` is the matrix returned by the clustering call above.
plot_k_distance(features=features, k=5)

# Inspect the curve and choose eps near the knee (elbow) point.


In [None]:
import torch
import numpy as np
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt


def cluster_and_plot_patches(
    io,
    model,
    train_transform,
    img,
    batch_idx=0,
    anchor_idx=0,
    sigma_unc=0.05,
    eps=0.1,
    min_samples=5,
    figsize=(8,8),
    tune_sigma: bool = False,
    sigma_range=(0.01, 0.2, 10),
):
    """
    Cluster the patch positions predicted from one anchor patch with DBSCAN.

    1) Extract (mu_y, mu_x, b_y, b_x) for each patch
    2) Drop patches predicted outside the canonical image, normalize the rest
       and cluster them with DBSCAN (optionally plotting k-distance curves for a
       sweep of sigma_unc values first)
    3) Overlay clusters on the canonical image

    Args:
        io: dict of model I/O tensors; reads 'crop_params', 'canonical_params',
            'patch_positions_nopos', 'pred_dT' and 'disp_dT'
        model: object with view_ids_M, Is and canonical_img_size
        train_transform: must have recreate_canonical(img, params)
        img: input image batch element
        batch_idx: which batch element to use
        anchor_idx: which patch to anchor on
        sigma_unc: scale for uncertainty normalization (fraction of canonical size)
        eps, min_samples: DBSCAN tuning
        figsize: plot size
        tune_sigma: if True, before clustering plot the sorted distance in
            column 5 of NearestNeighbors(n_neighbors=6).kneighbors (column 0 is
            the point itself) for each sigma_unc in np.linspace(*sigma_range)
        sigma_range: (start, stop, num) arguments for np.linspace

    Returns:
        labels: DBSCAN label per in-bounds patch (-1 = noise)
    """
    canon_size = model.canonical_img_size

    # the anchor patch comes from view `view_idx`, which was resized to `crop_size`
    view_idx = model.view_ids_M[anchor_idx]
    crop_size = model.Is[view_idx]

    # map the anchor's crop-space position back to canonical coordinates
    cp = io['crop_params'][batch_idx][view_idx]
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]
    anchor_tl = cp[:2] + (lp / crop_size) * cp[2:4]

    # predicted offsets (scaled to pixels) and dispersions
    pred_dt = io['pred_dT'][batch_idx, anchor_idx, :, :2] * canon_size  # [N,2]
    disp    = io['disp_dT'][batch_idx, anchor_idx, :, :2]               # [N,2]

    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1,2,0).cpu().numpy()
    # `.size` is an int for numpy arrays and (W, H) for PIL images, so read
    # (H, W) from the array shape, which is correct for both
    H, W = np.asarray(canon).shape[:2]

    # predicted positions: subtract offsets per the GT convention
    mus_np = (anchor_tl.unsqueeze(0) - pred_dt).cpu().numpy()  # [N,2]
    bs_np  = disp.cpu().numpy()                                 # [N,2]

    # keep only points inside the canonical image
    y, x = mus_np[:, 0], mus_np[:, 1]
    in_bounds = (y >= 0) & (y <= H) & (x >= 0) & (x <= W)
    mus_np = mus_np[in_bounds]
    bs_np  = bs_np[in_bounds]

    mus_norm = mus_np / canon_size
    bs_norm  = bs_np  / (canon_size * sigma_unc)
    features = np.hstack([mus_norm, bs_norm])     # [M,4]

    # optionally plot sigma_unc sensitivity
    if tune_sigma:
        n_pts = len(features)
        if n_pts < 6:
            # NearestNeighbors(n_neighbors=6) needs at least 6 samples
            print(f'tune_sigma skipped: only {n_pts} in-bounds patches (need at least 6)')
        else:
            from sklearn.neighbors import NearestNeighbors
            sigmas = np.linspace(*sigma_range)
            plt.figure(figsize=(12, 3))
            for i, s in enumerate(sigmas, 1):
                feat = np.hstack([mus_norm, bs_np / (canon_size * s)])
                dists, _ = NearestNeighbors(n_neighbors=6).fit(feat).kneighbors(feat)
                kdist = np.sort(dists[:, 5])
                ax = plt.subplot(1, len(sigmas), i)
                ax.plot(kdist)
                ax.set_title(f'sigma_unc={s:.3f}')
                ax.set_xlabel('Sample index')
                ax.set_ylabel('6th NN distance')
            plt.suptitle('k-distance vs. sigma_unc')
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            plt.show()

    if len(features) == 0:
        # DBSCAN rejects empty input; nothing landed inside the image
        labels = np.empty(0, dtype=int)
    else:
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit(features).labels_

    plt.figure(figsize=figsize)
    plt.imshow(canon, interpolation='nearest')
    if len(labels):
        plt.scatter(
            mus_np[:, 1], mus_np[:, 0],
            c=labels, cmap='tab20', s=40, edgecolor='k'
        )
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    plt.title(f'DBSCAN Clustering: {n_clusters} clusters (σ_unc={sigma_unc}, eps={eps})')
    plt.axis('off')
    plt.show()

    return labels

# Example usage: plot k-distance curves for a sigma_unc sweep, then cluster.
labels = cluster_and_plot_patches(
    io,
    model,
    train_transform,
    img,
    tune_sigma=True,
    sigma_range=(0.01, 0.2, 5),  # np.linspace(0.01, 0.2, 5): 5 values between 0.01 and 0.2
    eps=0.06,
    min_samples=5,
    sigma_unc=0.001,
)



In [None]:
import torch
import numpy as np
import networkx as nx
from community import community_louvain
import matplotlib.pyplot as plt


def graph_cluster_and_plot(
    io,
    model,
    train_transform,
    img,
    batch_idx=0,
    anchor_idx=0,
    sigma_pos=0.05,
    sigma_unc=0.05,
    weight_threshold=0.01,
    figsize=(8,8)
):
    """
    Cluster the patch positions predicted from one anchor patch with Louvain.

    1) Compute predicted patch positions (mu) and dispersions (b) from anchor -> each patch
    2) Build a weighted graph where nodes are patches and edges are weighted by
       exp(-||mu_i-mu_j||^2 / (sigma_pos*S)^2) * exp(-||b_i-b_j||^2 / (sigma_unc*S)^2)
       with S = model.canonical_img_size, i.e. both sigmas are fractions of S
    3) Keep only edges with weight >= weight_threshold
    4) Run Louvain community detection
    5) Plot patches color-coded by community on the canonical image

    Args:
        io: dict of model I/O tensors; reads 'crop_params', 'canonical_params',
            'patch_positions_nopos', 'pred_dT' and 'disp_dT'
        model: object with view_ids_M, Is and canonical_img_size
        train_transform: must have recreate_canonical(img, params)
        img: input image batch element
        batch_idx: which batch element to use
        anchor_idx: which patch to anchor on
        sigma_pos: position bandwidth, as a fraction of the canonical size
        sigma_unc: dispersion bandwidth, as a fraction of the canonical size
        weight_threshold: minimum affinity for an edge to be added
        figsize: plot size

    Returns:
        partition: dict mapping patch index -> community label
    """
    canon_size = model.canonical_img_size

    # the anchor patch comes from view `view_idx`, which was resized to `crop_size`
    view_idx = model.view_ids_M[anchor_idx]
    crop_size = model.Is[view_idx]

    # map the anchor's crop-space position back to canonical coordinates
    cp = io['crop_params'][batch_idx][view_idx]
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]
    anchor_tl = cp[:2] + (lp / crop_size) * cp[2:4]        # tensor[2]

    # predicted offsets (scaled to pixels) and dispersions
    pred_dt = io['pred_dT'][batch_idx, anchor_idx, :, :2] * canon_size   # [N,2]
    disp    = io['disp_dT'][batch_idx, anchor_idx, :, :2]                 # [N,2]

    # predicted patch positions mu = anchor_tl - pred_dt
    mus_np = (anchor_tl.unsqueeze(0) - pred_dt).cpu().numpy()  # (N,2)
    bs_np  = disp.cpu().numpy()                                 # (N,2)
    N = mus_np.shape[0]

    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1,2,0).cpu().numpy()

    # pairwise squared distances, [N,N] each
    dist_mu2 = np.sum((mus_np[:, None, :] - mus_np[None, :, :]) ** 2, axis=2)
    dist_b2  = np.sum((bs_np[:, None, :] - bs_np[None, :, :]) ** 2, axis=2)

    affinity = np.exp(-dist_mu2 / (sigma_pos**2 * canon_size**2)) * \
        np.exp(-dist_b2  / (sigma_unc**2  * canon_size**2))

    G = nx.Graph()
    G.add_nodes_from(range(N))
    # strict upper triangle (i < j) above threshold; np.nonzero walks it in the
    # same row-major order as a nested i/j loop would
    rows, cols = np.nonzero(np.triu(affinity >= weight_threshold, k=1))
    G.add_weighted_edges_from(
        (int(i), int(j), affinity[i, j]) for i, j in zip(rows, cols)
    )

    partition = community_louvain.best_partition(G, weight='weight')

    plt.figure(figsize=figsize)
    plt.imshow(canon, interpolation='nearest')
    labels = [partition[i] for i in range(N)]
    plt.scatter(
        mus_np[:, 1], mus_np[:, 0],
        c=labels,
        cmap='tab20',
        s=40,
        edgecolor='k'
    )
    num_com = len(set(labels))
    plt.title(f'Louvain Clustering: {num_com} communities')
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    return partition

# Example usage: Louvain communities for the patches predicted from anchor 34.
partition = graph_cluster_and_plot(
    io,
    model,
    train_transform,
    img,
    batch_idx=0,
    anchor_idx=34,
    sigma_pos=0.7,
    sigma_unc=1,
    weight_threshold=0.9,
)


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Literal

def create_global_uncertainty_heatmap(
    patch_positions: Float[Tensor, "M 2"],
    dispersions: Float[Tensor, "M 4"],  # [dy, dx, dlogh, dlogw]
    canonical_size: int,
    patch_size: int,
) -> Float[Tensor, "canonical_size canonical_size"]:
    """
    Paint a per-patch positional-uncertainty score onto a canonical-size canvas.

    For every patch, the L2 norm of its (dy, dx) dispersion is added to the
    patch_size x patch_size square starting at the patch position. The position
    is truncated to int and clamped so the square stays inside the canvas.
    Overlapping squares accumulate.

    Args:
        patch_positions: Predicted patch positions in canonical space [M, 2] as (y, x)
        dispersions: Predicted dispersions for each patch [M, 4] (in pixel units)
        canonical_size: Size of the (square) canonical image
        patch_size: Size of the patches in the image

    Returns:
        uncertainty_map: [canonical_size, canonical_size] accumulated heatmap
    """
    heatmap = torch.zeros((canonical_size, canonical_size), device=patch_positions.device)
    last_start = canonical_size - patch_size
    scores = dispersions[:, :2].norm(dim=1)  # [M]

    for corner, score in zip(patch_positions, scores):
        top, left = (torch.clamp(c, 0, last_start) for c in corner.int())
        heatmap[top:top + patch_size, left:left + patch_size] += score

    return heatmap


def create_laplace_distribution_heatmaps(
    patch_positions: Float[Tensor, "M 2"],
    dispersions: Float[Tensor, "M 4"],  # [dy, dx, dlogh, dlogw]
    canonical_size: int,
    patch_size: int,
    alpha: float = 0.7,
) -> Float[Tensor, "canonical_size canonical_size"]:
    """
    Overlay one peak-normalized 2-D Laplace density per predicted patch position.

    Each patch contributes (1/(4*b_y*b_x)) * exp(-|y-mu_y|/b_y - |x-mu_x|/b_x),
    rescaled so its own maximum over the grid is 1 and then multiplied by alpha.
    The contributions are summed, and the total is rescaled to a peak of 1.
    Patches positioned more than 2 * patch_size outside the canvas are ignored.

    Args:
        patch_positions: Predicted patch positions in canonical space [M, 2] as (y, x)
        dispersions: Per-patch dispersions [M, 4] (in pixel units); only the first
            two (b_y, b_x) are used, clamped to [0.5, canonical_size / 2]
        canonical_size: Size of the (square) canonical image
        patch_size: Size of patches; sets the out-of-bounds margin
        alpha: Weight of each per-patch contribution. NOTE: the summed map is
            renormalized to a peak of 1 at the end, so any alpha > 0 gives the
            same output.

    Returns:
        combined_heatmap: [canonical_size, canonical_size] combined heatmap
    """
    device = patch_positions.device
    grid = torch.arange(canonical_size, device=device).float()
    Y, X = torch.meshgrid(grid, grid, indexing='ij')

    lo = -2 * patch_size
    hi = canonical_size + 2 * patch_size
    total = torch.zeros((canonical_size, canonical_size), device=device)

    for (mu_y, mu_x), disp in zip(patch_positions, dispersions):
        # keep the Laplace scales away from degenerate (tiny / huge) values
        b_y = disp[0].clamp(min=0.5, max=canonical_size / 2)
        b_x = disp[1].clamp(min=0.5, max=canonical_size / 2)

        if not (lo <= mu_y <= hi and lo <= mu_x <= hi):
            continue

        density = (1.0 / (4 * b_y * b_x)) * torch.exp(
            -torch.abs(Y - mu_y) / b_y - torch.abs(X - mu_x) / b_x
        )
        # per-patch peak normalization so no single patch dominates
        peak = density.max()
        if peak > 0:
            density = density / peak
        total += alpha * density

    overall_peak = total.max()
    if overall_peak > 0:
        total = total / overall_peak
    return total

In [None]:
import torch
import numpy as np
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt


def cluster_and_plot_patches(
    io,
    model,
    train_transform,
    img,
    batch_idx=0,
    anchor_idx=0,
    sigma_unc=0.05,
    eps=0.1,
    min_samples=5,
    figsize=(8,8),
):
    """
    Cluster the patch positions predicted from one anchor patch with DBSCAN.

    1) Extract (mu_y, mu_x, b_y, b_x) for each patch
    2) Drop patches predicted outside the canonical image, normalize the rest
       and cluster them with DBSCAN
    3) Overlay clusters on the canonical image

    Args:
        io: dict of model I/O tensors; reads 'crop_params', 'canonical_params',
            'patch_positions_nopos', 'pred_dT' and 'disp_dT'
        model: object with view_ids_M, Is and canonical_img_size
        train_transform: must have recreate_canonical(img, params)
        img: input image batch element
        batch_idx: which batch element to use
        anchor_idx: which patch to anchor on
        sigma_unc: scale for uncertainty normalization (fraction of canonical size)
        eps, min_samples: DBSCAN tuning
        figsize: plot size

    Returns:
        labels: DBSCAN label per in-bounds patch (-1 = noise)
        features: [M, 4] matrix that was clustered (M = number of in-bounds patches).
            Returned so callers can unpack `labels, features = ...` and feed
            `features` to plot_k_distance.
    """
    canon_size = model.canonical_img_size

    # the anchor patch comes from view `view_idx`, which was resized to `crop_size`
    view_idx = model.view_ids_M[anchor_idx]
    crop_size = model.Is[view_idx]

    # map the anchor's crop-space position back to canonical coordinates
    cp = io['crop_params'][batch_idx][view_idx]
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]
    anchor_tl = cp[:2] + (lp / crop_size) * cp[2:4]

    # predicted offsets (scaled to pixels) and dispersions
    pred_dt = io['pred_dT'][batch_idx, anchor_idx, :, :2] * canon_size  # [N,2]
    disp    = io['disp_dT'][batch_idx, anchor_idx, :, :2]               # [N,2]

    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1,2,0).cpu().numpy()
    # `.size` is an int for numpy arrays and (W, H) for PIL images, so read
    # (H, W) from the array shape, which is correct for both
    H, W = np.asarray(canon).shape[:2]

    # predicted positions: subtract offsets per the GT convention
    mus_np = (anchor_tl.unsqueeze(0) - pred_dt).cpu().numpy()  # [N,2]
    bs_np  = disp.cpu().numpy()                                 # [N,2]

    # keep only points inside the canonical image
    y, x = mus_np[:, 0], mus_np[:, 1]
    in_bounds = (y >= 0) & (y <= H) & (x >= 0) & (x <= W)
    mus_np = mus_np[in_bounds]
    bs_np  = bs_np[in_bounds]

    mus_norm = mus_np / canon_size
    bs_norm  = bs_np  / (canon_size * sigma_unc)
    features = np.hstack([mus_norm, bs_norm])     # [M,4]

    if len(features) == 0:
        # DBSCAN rejects empty input; nothing landed inside the image
        labels = np.empty(0, dtype=int)
    else:
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit(features).labels_

    plt.figure(figsize=figsize)
    plt.imshow(canon, interpolation='nearest')
    if len(labels):
        plt.scatter(
            mus_np[:, 1], mus_np[:, 0],
            c=labels, cmap='tab20', s=40, edgecolor='k'
        )
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    plt.title(f'DBSCAN Clustering: {n_clusters} clusters (σ_unc={sigma_unc}, eps={eps})')
    plt.axis('off')
    plt.show()

    return labels, features

# Example usage: cluster the patches predicted from anchor 0 of the first batch
# element. `features` is reused by the k-distance cell below.
labels, features = cluster_and_plot_patches(
    io, model, train_transform, img,
    batch_idx=0,
    anchor_idx=0,
    sigma_unc=0.5,
    eps=0.06,
    min_samples=5,
)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors

# Helper to plot k-distance graph for eps tuning
def plot_k_distance(features, k=5):
    """
    Plot the ascending k-th-nearest-neighbour distances to guide DBSCAN ``eps``.

    Args:
        features: numpy array of shape [N, D]
        k: which neighbour (excluding the sample itself) to measure against
    """
    # column 0 of kneighbors(features) is each sample itself (distance 0),
    # so fit with one extra neighbour and read column k
    dists = NearestNeighbors(n_neighbors=k + 1).fit(features).kneighbors(features)[0]
    curve = np.sort(dists[:, k])

    _, axis = plt.subplots(figsize=(6, 4))
    axis.plot(curve)
    axis.set_xlabel(f'Samples sorted by distance to {k}th NN')
    axis.set_ylabel(f'{k}th NN distance')
    axis.set_title('k-distance graph for DBSCAN eps selection')
    axis.grid(True)
    plt.show()

# Example usage: `features` is the matrix returned by the clustering call above.
plot_k_distance(features=features, k=5)

# Inspect the curve and choose eps near the knee (elbow) point.


In [None]:
import torch
import numpy as np
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt


def cluster_and_plot_patches(
    io,
    model,
    train_transform,
    img,
    batch_idx=0,
    anchor_idx=0,
    sigma_unc=0.05,
    eps=0.1,
    min_samples=5,
    figsize=(8,8),
    tune_sigma: bool = False,
    sigma_range=(0.01, 0.2, 10),
):
    """
    Cluster the patch positions predicted from one anchor patch with DBSCAN.

    1) Extract (mu_y, mu_x, b_y, b_x) for each patch
    2) Drop patches predicted outside the canonical image, normalize the rest
       and cluster them with DBSCAN (optionally plotting k-distance curves for a
       sweep of sigma_unc values first)
    3) Overlay clusters on the canonical image

    Args:
        io: dict of model I/O tensors; reads 'crop_params', 'canonical_params',
            'patch_positions_nopos', 'pred_dT' and 'disp_dT'
        model: object with view_ids_M, Is and canonical_img_size
        train_transform: must have recreate_canonical(img, params)
        img: input image batch element
        batch_idx: which batch element to use
        anchor_idx: which patch to anchor on
        sigma_unc: scale for uncertainty normalization (fraction of canonical size)
        eps, min_samples: DBSCAN tuning
        figsize: plot size
        tune_sigma: if True, before clustering plot the sorted distance in
            column 5 of NearestNeighbors(n_neighbors=6).kneighbors (column 0 is
            the point itself) for each sigma_unc in np.linspace(*sigma_range)
        sigma_range: (start, stop, num) arguments for np.linspace

    Returns:
        labels: DBSCAN label per in-bounds patch (-1 = noise)
    """
    canon_size = model.canonical_img_size

    # the anchor patch comes from view `view_idx`, which was resized to `crop_size`
    view_idx = model.view_ids_M[anchor_idx]
    crop_size = model.Is[view_idx]

    # map the anchor's crop-space position back to canonical coordinates
    cp = io['crop_params'][batch_idx][view_idx]
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]
    anchor_tl = cp[:2] + (lp / crop_size) * cp[2:4]

    # predicted offsets (scaled to pixels) and dispersions
    pred_dt = io['pred_dT'][batch_idx, anchor_idx, :, :2] * canon_size  # [N,2]
    disp    = io['disp_dT'][batch_idx, anchor_idx, :, :2]               # [N,2]

    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1,2,0).cpu().numpy()
    # `.size` is an int for numpy arrays and (W, H) for PIL images, so read
    # (H, W) from the array shape, which is correct for both
    H, W = np.asarray(canon).shape[:2]

    # predicted positions: subtract offsets per the GT convention
    mus_np = (anchor_tl.unsqueeze(0) - pred_dt).cpu().numpy()  # [N,2]
    bs_np  = disp.cpu().numpy()                                 # [N,2]

    # keep only points inside the canonical image
    y, x = mus_np[:, 0], mus_np[:, 1]
    in_bounds = (y >= 0) & (y <= H) & (x >= 0) & (x <= W)
    mus_np = mus_np[in_bounds]
    bs_np  = bs_np[in_bounds]

    mus_norm = mus_np / canon_size
    bs_norm  = bs_np  / (canon_size * sigma_unc)
    features = np.hstack([mus_norm, bs_norm])     # [M,4]

    # optionally plot sigma_unc sensitivity
    if tune_sigma:
        n_pts = len(features)
        if n_pts < 6:
            # NearestNeighbors(n_neighbors=6) needs at least 6 samples
            print(f'tune_sigma skipped: only {n_pts} in-bounds patches (need at least 6)')
        else:
            from sklearn.neighbors import NearestNeighbors
            sigmas = np.linspace(*sigma_range)
            plt.figure(figsize=(12, 3))
            for i, s in enumerate(sigmas, 1):
                feat = np.hstack([mus_norm, bs_np / (canon_size * s)])
                dists, _ = NearestNeighbors(n_neighbors=6).fit(feat).kneighbors(feat)
                kdist = np.sort(dists[:, 5])
                ax = plt.subplot(1, len(sigmas), i)
                ax.plot(kdist)
                ax.set_title(f'sigma_unc={s:.3f}')
                ax.set_xlabel('Sample index')
                ax.set_ylabel('6th NN distance')
            plt.suptitle('k-distance vs. sigma_unc')
            plt.tight_layout(rect=[0, 0.03, 1, 0.95])
            plt.show()

    if len(features) == 0:
        # DBSCAN rejects empty input; nothing landed inside the image
        labels = np.empty(0, dtype=int)
    else:
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit(features).labels_

    plt.figure(figsize=figsize)
    plt.imshow(canon, interpolation='nearest')
    if len(labels):
        plt.scatter(
            mus_np[:, 1], mus_np[:, 0],
            c=labels, cmap='tab20', s=40, edgecolor='k'
        )
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    plt.title(f'DBSCAN Clustering: {n_clusters} clusters (σ_unc={sigma_unc}, eps={eps})')
    plt.axis('off')
    plt.show()

    return labels

# Example usage: plot k-distance curves for a sigma_unc sweep, then cluster.
labels = cluster_and_plot_patches(
    io,
    model,
    train_transform,
    img,
    tune_sigma=True,
    sigma_range=(0.01, 0.2, 5),  # np.linspace(0.01, 0.2, 5): 5 values between 0.01 and 0.2
    eps=0.06,
    min_samples=5,
    sigma_unc=0.001,
)



In [None]:
import torch
import numpy as np
import networkx as nx
from community import community_louvain
import matplotlib.pyplot as plt


def graph_cluster_and_plot(
    io,
    model,
    train_transform,
    img,
    batch_idx=0,
    anchor_idx=0,
    sigma_pos=0.05,
    sigma_unc=0.05,
    weight_threshold=0.01,
    figsize=(8,8)
):
    """
    Cluster the patch positions predicted from one anchor patch with Louvain.

    1) Compute predicted patch positions (mu) and dispersions (b) from anchor -> each patch
    2) Build a weighted graph where nodes are patches and edges are weighted by
       exp(-||mu_i-mu_j||^2 / (sigma_pos*S)^2) * exp(-||b_i-b_j||^2 / (sigma_unc*S)^2)
       with S = model.canonical_img_size, i.e. both sigmas are fractions of S
    3) Keep only edges with weight >= weight_threshold
    4) Run Louvain community detection
    5) Plot patches color-coded by community on the canonical image

    Args:
        io: dict of model I/O tensors; reads 'crop_params', 'canonical_params',
            'patch_positions_nopos', 'pred_dT' and 'disp_dT'
        model: object with view_ids_M, Is and canonical_img_size
        train_transform: must have recreate_canonical(img, params)
        img: input image batch element
        batch_idx: which batch element to use
        anchor_idx: which patch to anchor on
        sigma_pos: position bandwidth, as a fraction of the canonical size
        sigma_unc: dispersion bandwidth, as a fraction of the canonical size
        weight_threshold: minimum affinity for an edge to be added
        figsize: plot size

    Returns:
        partition: dict mapping patch index -> community label
    """
    canon_size = model.canonical_img_size

    # the anchor patch comes from view `view_idx`, which was resized to `crop_size`
    view_idx = model.view_ids_M[anchor_idx]
    crop_size = model.Is[view_idx]

    # map the anchor's crop-space position back to canonical coordinates
    cp = io['crop_params'][batch_idx][view_idx]
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]
    anchor_tl = cp[:2] + (lp / crop_size) * cp[2:4]        # tensor[2]

    # predicted offsets (scaled to pixels) and dispersions
    pred_dt = io['pred_dT'][batch_idx, anchor_idx, :, :2] * canon_size   # [N,2]
    disp    = io['disp_dT'][batch_idx, anchor_idx, :, :2]                 # [N,2]

    # predicted patch positions mu = anchor_tl - pred_dt
    mus_np = (anchor_tl.unsqueeze(0) - pred_dt).cpu().numpy()  # (N,2)
    bs_np  = disp.cpu().numpy()                                 # (N,2)
    N = mus_np.shape[0]

    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1,2,0).cpu().numpy()

    # pairwise squared distances, [N,N] each
    dist_mu2 = np.sum((mus_np[:, None, :] - mus_np[None, :, :]) ** 2, axis=2)
    dist_b2  = np.sum((bs_np[:, None, :] - bs_np[None, :, :]) ** 2, axis=2)

    affinity = np.exp(-dist_mu2 / (sigma_pos**2 * canon_size**2)) * \
        np.exp(-dist_b2  / (sigma_unc**2  * canon_size**2))

    G = nx.Graph()
    G.add_nodes_from(range(N))
    # strict upper triangle (i < j) above threshold; np.nonzero walks it in the
    # same row-major order as a nested i/j loop would
    rows, cols = np.nonzero(np.triu(affinity >= weight_threshold, k=1))
    G.add_weighted_edges_from(
        (int(i), int(j), affinity[i, j]) for i, j in zip(rows, cols)
    )

    partition = community_louvain.best_partition(G, weight='weight')

    plt.figure(figsize=figsize)
    plt.imshow(canon, interpolation='nearest')
    labels = [partition[i] for i in range(N)]
    plt.scatter(
        mus_np[:, 1], mus_np[:, 0],
        c=labels,
        cmap='tab20',
        s=40,
        edgecolor='k'
    )
    num_com = len(set(labels))
    plt.title(f'Louvain Clustering: {num_com} communities')
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    return partition

# Example usage: Louvain communities for the patches predicted from anchor 34.
partition = graph_cluster_and_plot(
    io,
    model,
    train_transform,
    img,
    batch_idx=0,
    anchor_idx=34,
    sigma_pos=0.7,
    sigma_unc=1,
    weight_threshold=0.9,
)


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from typing import Literal

def create_global_uncertainty_heatmap(
    patch_positions: Float[Tensor, "M 2"],
    dispersions: Float[Tensor, "M 4"],  # [dy, dx, dlogh, dlogw]
    canonical_size: int,
    patch_size: int,
) -> Float[Tensor, "canonical_size canonical_size"]:
    """
    Accumulate each patch's positional uncertainty into a square heatmap.

    The score of a patch is the L2 norm of its (dy, dx) dispersion. It is added
    over a patch_size x patch_size window whose start row/column is the patch
    position cast to int and clamped to [0, canonical_size - patch_size].

    Args:
        patch_positions: Predicted patch positions in canonical space [M, 2] as (y, x)
        dispersions: Predicted dispersions for each patch [M, 4] (in pixel units)
        canonical_size: Size of the (square) canonical image
        patch_size: Size of the patches in the image

    Returns:
        uncertainty_map: [canonical_size, canonical_size] accumulated heatmap
    """
    canvas = torch.zeros((canonical_size, canonical_size), device=patch_positions.device)
    max_corner = canonical_size - patch_size
    pos_uncertainty = torch.norm(dispersions[:, :2], dim=1)  # [M]

    for int_pos, value in zip(patch_positions.int(), pos_uncertainty):
        row = torch.clamp(int_pos[0], 0, max_corner)
        col = torch.clamp(int_pos[1], 0, max_corner)
        canvas[row:row + patch_size, col:col + patch_size] += value

    return canvas


def create_laplace_distribution_heatmaps(
    patch_positions: Float[Tensor, "M 2"],
    dispersions: Float[Tensor, "M 4"],  # [dy, dx, dlogh, dlogw]
    canonical_size: int,
    patch_size: int,
    alpha: float = 0.7,
) -> Float[Tensor, "canonical_size canonical_size"]:
    """
    Sum one peak-normalised 2D Laplace density per patch and rescale the total to [0, 1].

    For each patch, a separable Laplace density centred at its (y, x) position
    with scales (dy, dx) is evaluated on the full pixel grid. The scales are
    clamped to [0.5, canonical_size / 2]. A patch is skipped when its centre lies
    more than ``2 * patch_size`` outside the canvas on either axis.

    Each density is divided by its own maximum before it is added, so every
    contributing patch peaks at ``alpha`` on the pixel nearest its centre. This
    includes patches just outside the canvas, whose visible tail gets stretched
    to that peak. As a result, the 1/(4*b_y*b_x) factor has no effect on the output.

    Args:
        patch_positions: Predicted patch positions in canonical space, [M, 2] as (y, x).
        dispersions: Per-patch dispersions, [M, 4]; only (dy, dx) are used, in pixels.
        canonical_size: Side length of the square canvas.
        patch_size: Patch size; only used for the off-canvas margin.
        alpha: Weight of each density in the sum. The sum is divided by its own
            maximum at the end, so alpha cancels out of the returned map.

    Returns:
        A [canonical_size, canonical_size] map whose maximum is 1, or all zeros
        when no patch contributed.
    """
    device = patch_positions.device
    heat = torch.zeros((canonical_size, canonical_size), device=device)

    # Pixel-index grids (row, column) shared by every patch.
    axis = torch.arange(canonical_size, device=device).float()
    Y, X = torch.meshgrid(axis, axis, indexing='ij')

    margin = 2 * patch_size
    lower, upper = -margin, canonical_size + margin
    max_scale = canonical_size / 2

    for center, scale in zip(patch_positions, dispersions):
        mu_y, mu_x = center[0], center[1]
        # Keep the scales in a sane range so the density is neither a spike nor flat.
        b_y = torch.clamp(scale[0], min=0.5, max=max_scale)
        b_x = torch.clamp(scale[1], min=0.5, max=max_scale)

        if not (lower <= mu_y <= upper and lower <= mu_x <= upper):
            continue

        density = (1.0 / (4 * b_y * b_x)) * torch.exp(
            -torch.abs(Y - mu_y) / b_y - torch.abs(X - mu_x) / b_x
        )
        peak = density.max()
        if peak > 0:
            density = density / peak
        heat += alpha * density

    top = heat.max()
    if top > 0:
        heat = heat / top
    return heat

In [None]:
# Recover absolute canonical-space poses of every token from pairwise displacements.
print(io["patch_positions_nopos"].shape)  # e.g. [1, 294, 2]: (y, x) of each token inside its own crop
print(io["gt_dT"].shape)  # ground-truth pairwise displacements
print(io["patch_positions_nopos"].max())
print(io["params"])
print(io["crop_params"])

# Pairwise displacements are relative, so absolute poses need an anchor token
# whose canonical pose follows from its crop params:
#
#   dT = pred_dT[..., :2] * canonical_img_size
#   dS = pred_dT[..., 2:] * log(max_scale_ratio)
#   T_anchor = crop_params[v][:2] + (patch_positions_nopos[a] / crop_size) * crop_params[v][2:4]
#   S_anchor = log(patch_size * crop_params[v][2:4] / crop_size)
#   T_global = dT[:, a] + T_anchor
#   S_global = dS[:, a] + S_anchor
#
# with `a` the anchor token and `v` the view that token belongs to.

batch_idx = 0
anchor_idx = 0
# Crop params and crop size are per *view*, so map the anchor token to its view first.
# Indexing crop_params with the token index is only correct for tokens of view 0.
anchor_view_idx = model.view_ids_M[anchor_idx]
crop_size = io["x"][batch_idx][anchor_view_idx].shape[-1]

anchor_params = io["crop_params"][batch_idx][anchor_view_idx]
anchor_local_pos = io["patch_positions_nopos"][batch_idx][anchor_idx]
anchor_global_pos = anchor_params[:2] + (anchor_local_pos / crop_size) * anchor_params[2:4]
anchor_global_scale = torch.log((model.patch_size * anchor_params[2:4] / crop_size))
print("Anchor global position:", anchor_global_pos)
print("Anchor global scale:", anchor_global_scale)

# Pairwise displacements in pixels (translation) and log-scale units, each [M, M, 2].
gt_dt = io["gt_dT"][batch_idx, :, :, :2] * model.canonical_img_size
gt_ds = io["gt_dT"][batch_idx, :, :, 2:] * math.log(model.max_scale_ratio)
pred_dt = io["pred_dT"][batch_idx, :, :, :2] * model.canonical_img_size
pred_ds = io["pred_dT"][batch_idx, :, :, 2:] * math.log(model.max_scale_ratio)

print(io["gt_dT"].shape)  # e.g. [1, 294, 294, 4]
print(gt_dt.shape)  # [M, M, 2]
print(anchor_global_pos.shape)  # [2]
# Global GT positions / log-scales of every token, read through the anchor's column.
gt_T_global = gt_dt[:, anchor_idx] + anchor_global_pos
gt_S_global = gt_ds[:, anchor_idx] + anchor_global_scale
# Predicted means of the global positions / log-scales of every token.
pred_T_global = pred_dt[:, anchor_idx] + anchor_global_pos
pred_S_global = pred_ds[:, anchor_idx] + anchor_global_scale

print("Global positions:", gt_T_global.shape)
print("Global scales:", gt_S_global.shape)

# Pick one token and inspect the Laplace dispersions of its displacements
# to all other tokens (mean = pred_dT, scale = disp_dT).
patch_idx = 0
print(io["disp_dT"].shape)  # e.g. [1, 294, 294, 4]
disp_patch_to_all = io["disp_dT"][batch_idx, patch_idx]
# Only the (y, x) dispersions can be drawn in the image plane.
disp_patch_to_all = disp_patch_to_all[:, :2]  # [M, 2]

print(pred_T_global.shape)
mu_patch = pred_T_global[patch_idx]  # [2]

# Canonical image on which the chosen patch and its distribution can be drawn.
canonical_img = train_transform.recreate_canonical(
    img, io["canonical_params"][0]
)
# Robust spread summary over both axes (tensor is flattened): median and IQR, both scalars.
print(disp_patch_to_all.median())
print(disp_patch_to_all.quantile(0.75) - disp_patch_to_all.quantile(0.25))

In [None]:
# let's plot the per-token dispersions.
# Inspect the per-token dispersion tensor before plotting it.
# NOTE(review): the shape in the trailing comment is the author's expectation; confirm against the printed value.
print(io["disp_T"].shape) # (B, V * num_tokens_per_view, 4) # one uncertainty per dimension (y, x, dlogh, dlogw)


In [None]:
import torch
import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


def visualize_patch_distribution(
    io,
    model,
    train_transform,
    img,
    batch_idx: int,
    anchor_idx: int,
    patch_idx: int,
    figsize=(6,6),
    disp_scale_pixels: bool = True,
):
    """
    Visualize the predicted 2D Laplace distribution for the pair anchor -> patch.

    On the canonical image, draws the ground-truth top-left boxes of the anchor
    (blue) and of the target patch (red), with their positions in the legend.
    On top of that it overlays a heatmap of the predicted Laplace density,
    centred on the predicted target position.

    Args:
        io: dict with keys 'pred_dT', 'disp_dT', 'gt_dT', 'crop_params',
            'patch_positions_nopos', 'canonical_params'.
        model: object with .view_ids_M, .Is, .canonical_img_size, .patch_size.
        train_transform: has .recreate_canonical(img, canonical_params). It may
            return a PIL image, an array, or a [C, H, W] tensor.
        img: input image for reconstruction.
        batch_idx: index into the batch.
        anchor_idx: index of the anchor token.
        patch_idx: index of the target token.
        figsize: matplotlib figure size.
        disp_scale_pixels: multiply the predicted dispersion by canonical_img_size
            before building the heatmap.

    Returns:
        The matplotlib Figure. It is closed via plt.close, so pyplot does not
        auto-show it. Its single Axes holds two images: the canonical image and
        the heatmap.
    """
    # Geometry of the anchor's view.
    anchor_view_idx = model.view_ids_M[anchor_idx]
    anchor_crop_size = model.Is[anchor_view_idx]
    canon_size = model.canonical_img_size
    p_size = model.patch_size

    cp = io['crop_params'][batch_idx][anchor_view_idx]       # [y0, x0, h', w']
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]  # [y_loc, x_loc]

    # Anchor top-left in canonical coordinates.
    anchor_tl = cp[:2] + (lp / anchor_crop_size) * cp[2:4]
    ay_tl, ax_tl = anchor_tl.tolist()

    # Patch footprint in canonical coordinates.
    box_h, box_w = (p_size * cp[2:4] / anchor_crop_size).tolist()

    # Target GT top-left: anchor minus the (normalised) GT displacement.
    py_tl = anchor_tl[0].item() - io['gt_dT'][batch_idx, anchor_idx, patch_idx, 0].item() * canon_size
    px_tl = anchor_tl[1].item() - io['gt_dT'][batch_idx, anchor_idx, patch_idx, 1].item() * canon_size

    # Laplace scales of the predicted displacement.
    b_y, b_x = io['disp_dT'][batch_idx, anchor_idx, patch_idx, :2].cpu().tolist()
    if disp_scale_pixels:
        b_y *= canon_size
        b_x *= canon_size
    # A zero or negative scale would turn the density into inf/NaN.
    min_scale = 1e-6
    b_y = max(b_y, min_scale)
    b_x = max(b_x, min_scale)

    # Predicted target position, using the same convention as the GT box.
    pred_off = io['pred_dT'][batch_idx, anchor_idx, patch_idx, :2] * canon_size
    mu_y = (anchor_tl[0] - pred_off[0]).item()
    mu_x = (anchor_tl[1] - pred_off[1]).item()

    # Reconstruct the canonical image as an (H, W[, C]) array.
    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1, 2, 0).cpu().numpy()
    # Read (H, W) from .shape. PIL's .size is (W, H), and a numpy array's
    # .size is its element count, so neither gives (H, W).
    canon = np.asarray(canon)
    H, W = canon.shape[:2]

    # Laplace density on the pixel grid, centred at the predicted mean.
    ys = np.arange(H)[:, None]
    xs = np.arange(W)[None, :]
    Z = (1.0 / (4 * b_y * b_x)) * np.exp(-np.abs(ys - mu_y) / b_y - np.abs(xs - mu_x) / b_x)

    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(canon, interpolation='nearest')

    # GT box for the anchor (blue); Rectangle takes (x, y).
    ax.add_patch(Rectangle(
        (ax_tl, ay_tl), box_w, box_h,
        edgecolor='blue', lw=2, facecolor='none',
        label=f'Anchor GT (TL): ({ay_tl:.1f}, {ax_tl:.1f})'
    ))

    # GT box for the target patch (red).
    ax.add_patch(Rectangle(
        (px_tl, py_tl), box_w, box_h,
        edgecolor='red', lw=2, facecolor='none',
        label=f'Patch GT (TL): ({py_tl:.1f}, {px_tl:.1f})'
    ))

    # Predicted Laplace heatmap aligned with the image pixels.
    ax.imshow(Z, cmap='hot', alpha=0.5, extent=(0, W, H, 0))

    ax.legend(loc='upper right')
    ax.set_title(f'anchor={anchor_idx} → patch={patch_idx}   Pred mean=({mu_y:.1f},{mu_x:.1f})')
    ax.axis('off')
    plt.tight_layout()
    plt.close(fig)
    return fig

In [None]:
import torch
import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


def visualize_patch_distribution_per_token(
    io,
    model,
    train_transform,
    img,
    batch_idx: int,
    figsize=(6,6),
    disp_scale_pixels: bool = True,
    patch_idx: int = 0,
):
    """
    Draw one target token's GT box and predicted Laplace density, located via token 0.

    Token 0 is only the anchor that turns relative displacements into absolute
    canonical coordinates. Its GT box is drawn in blue, the target's GT box in
    red, and the predicted density for the target as a heatmap.

    Every quantity is computed locally. The target index is the ``patch_idx``
    argument; it previously came from a notebook global of the same name, which
    held 0.

    Args:
        io: dict with 'pred_dT', 'disp_dT', 'gt_dT', 'crop_params',
            'patch_positions_nopos', 'canonical_params'.
        model: object with .view_ids_M, .Is, .canonical_img_size, .patch_size.
        train_transform: has .recreate_canonical(img, canonical_params).
        img: source image.
        batch_idx: index into the batch.
        figsize: matplotlib figure size.
        disp_scale_pixels: multiply the predicted dispersion by canonical_img_size.
        patch_idx: index of the target token.

    Returns:
        The matplotlib Figure, closed via plt.close so pyplot does not auto-show it.
    """
    #### 1. Anchor's canonical position and size
    anchor_idx = 0  # only used to make the relative predictions absolute
    anchor_view_idx = model.view_ids_M[anchor_idx]
    anchor_crop_size = model.Is[anchor_view_idx]
    canon_size = model.canonical_img_size
    p_size = model.patch_size

    cp = io['crop_params'][batch_idx][anchor_view_idx]       # [y0, x0, h', w']
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]  # [y_loc, x_loc]

    anchor_tl = cp[:2] + (lp / anchor_crop_size) * cp[2:4]
    ay_tl, ax_tl = anchor_tl.tolist()
    box_h, box_w = (p_size * cp[2:4] / anchor_crop_size).tolist()

    #### 2. Target GT top-left and predicted mean, both via the anchor's displacement row
    py_tl = anchor_tl[0].item() - io['gt_dT'][batch_idx, anchor_idx, patch_idx, 0].item() * canon_size
    px_tl = anchor_tl[1].item() - io['gt_dT'][batch_idx, anchor_idx, patch_idx, 1].item() * canon_size

    pred_off = io['pred_dT'][batch_idx, anchor_idx, patch_idx, :2] * canon_size
    mu_y = (anchor_tl[0] - pred_off[0]).item()
    mu_x = (anchor_tl[1] - pred_off[1]).item()

    #### 3. Laplace scales of the predicted displacement
    b_y, b_x = io['disp_dT'][batch_idx, anchor_idx, patch_idx, :2].cpu().tolist()
    if disp_scale_pixels:
        b_y *= canon_size
        b_x *= canon_size
    # A zero or negative scale would turn the density into inf/NaN.
    min_scale = 1e-6
    b_y = max(b_y, min_scale)
    b_x = max(b_x, min_scale)

    #### 4. Canonical image and heatmap
    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1, 2, 0).cpu().numpy()
    # .shape works for both PIL-derived and tensor-derived arrays; .size does not.
    canon = np.asarray(canon)
    H, W = canon.shape[:2]

    ys = np.arange(H)[:, None]
    xs = np.arange(W)[None, :]
    Z = (1.0 / (4 * b_y * b_x)) * np.exp(-np.abs(ys - mu_y) / b_y - np.abs(xs - mu_x) / b_x)

    #### 5. Plot
    fig, ax = plt.subplots(figsize=figsize)
    ax.imshow(canon, interpolation='nearest')

    ax.add_patch(Rectangle(
        (ax_tl, ay_tl), box_w, box_h,
        edgecolor='blue', lw=2, facecolor='none',
        label=f'Anchor GT (TL): ({ay_tl:.1f}, {ax_tl:.1f})'
    ))
    ax.add_patch(Rectangle(
        (px_tl, py_tl), box_w, box_h,
        edgecolor='red', lw=2, facecolor='none',
        label=f'Patch GT (TL): ({py_tl:.1f}, {px_tl:.1f})'
    ))

    ax.imshow(Z, cmap='hot', alpha=0.5, extent=(0, W, H, 0))

    ax.legend(loc='upper right')
    ax.set_title(f'anchor={anchor_idx} → patch={patch_idx}   Pred mean=({mu_y:.1f},{mu_x:.1f})')
    ax.axis('off')
    plt.tight_layout()
    plt.close(fig)
    return fig

In [None]:
# Predicted distribution of token 16 relative to anchor token 0 (sample 0).
visualize_patch_distribution(
    io, model, train_transform, img,
    batch_idx=0, anchor_idx=0, patch_idx=16,
    figsize=(6, 6),
)

In [None]:
import torch
import numpy as np
import math
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib import animation


# ------------------------------------------------------------------
# 2.  ANIMATION ACROSS (anchor_idx, patch_idx) PAIRS
# ------------------------------------------------------------------

def animate_all_pairs(
    io,
    model,
    train_transform,
    img,
    batch_idx: int = 0,
    anchors: list | None = None,
    patch_order: str = 'sequential',  # or 'random'
    fps: int = 2,
    disp_scale_pixels: bool = True,
    figsize=(6,6),
):
    """
    Build a matplotlib animation with one frame per (anchor, patch) pair.

    Each frame is rendered off-screen by ``visualize_patch_distribution``. Its
    two images (canonical image and heatmap) and its Rectangle children are then
    copied onto a single reusable Axes.

    Args:
        anchors: anchor token indices to iterate; None means every token.
        patch_order: 'sequential', or 'random' to reshuffle the target order
            independently for every anchor (uses the global numpy RNG).
        fps: frames per second of the resulting animation.

    Returns:
        anim (FuncAnimation). Save it with anim.save('out.mp4', fps=fps).
    """
    num_tokens = io['pred_dT'].shape[2]
    anchor_list = list(range(num_tokens)) if anchors is None else anchors

    # Flatten the (anchor, target) grid into the frame schedule.
    pairs = []
    for anchor in anchor_list:
        targets = list(range(num_tokens))
        if patch_order == 'random':
            np.random.shuffle(targets)
        pairs.extend((anchor, target) for target in targets)

    # One persistent figure; every frame redraws its single Axes.
    fig, ax = plt.subplots(figsize=figsize)
    plt.axis('off')

    def init():
        ax.clear()
        ax.axis('off')
        return []

    def update(frame_idx):
        a_idx, p_idx = pairs[frame_idx]
        ax.clear()
        ax.axis('off')
        frame_fig = visualize_patch_distribution(
            io, model, train_transform, img,
            batch_idx=batch_idx,
            anchor_idx=a_idx,
            patch_idx=p_idx,
            figsize=figsize,
            disp_scale_pixels=disp_scale_pixels,
        )
        src_ax = frame_fig.axes[0]
        background, heatmap = src_ax.images[0], src_ax.images[1]
        ax.imshow(background.get_array(), interpolation='nearest')
        ax.imshow(heatmap.get_array(), cmap='hot', alpha=0.5,)
        # Copy every Rectangle child of the source Axes (GT boxes, plus the Axes background patch).
        boxes = [child for child in src_ax.get_children() if isinstance(child, Rectangle)]
        for box in boxes:
            ax.add_patch(Rectangle(
                box.get_xy(), box.get_width(), box.get_height(),
                edgecolor=box.get_edgecolor(), facecolor='none', lw=box.get_lw(),
            ))
        ax.set_title(f'anchor={a_idx} → patch={p_idx}')
        plt.close(frame_fig)
        return ax.patches  # updated artists (blit is off, so this is informational)

    return animation.FuncAnimation(
        fig, update,
        init_func=init,
        frames=len(pairs),
        interval=1000//fps,
        blit=False,
        repeat=True,
    )

# ------------------------------------------------------------------
# Example usage (inside a notebook):
# anim = animate_all_pairs(io, model, train_transform, img, batch_idx=0, fps=2, anchors=[0])
# from IPython.display import HTML
# HTML(anim.to_jshtml())


In [None]:
# anim = animate_all_pairs(io, model, train_transform, img, batch_idx=0, fps=2, anchors=[0])
# anim.save("uncertainty.mp4", fps=2, extra_args=['-vcodec', 'libx264'])

In [None]:
import torch
import numpy as np
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt


def cluster_and_plot_patches(
    io,
    model,
    train_transform,
    img,
    batch_idx=0,
    anchor_idx=0,
    sigma_unc=0.05,
    eps=0.1,
    min_samples=5,
    figsize=(8,8),
):
    """
    Cluster every token's predicted canonical position and dispersion with DBSCAN,
    then overlay the clusters on the canonical image.

    Steps:
        1) Place every token at ``anchor_tl - pred_dT[anchor, token] * canon_size``.
           Read its (y, x) dispersion from ``disp_dT[anchor, token]``.
        2) Drop tokens whose predicted position falls outside the image. Build
           features [mu_y, mu_x, b_y, b_x]: positions / canon_size and
           dispersions / (canon_size * sigma_unc). Run DBSCAN on them.
        3) Scatter the kept tokens on the canonical image, coloured by label.

    Args:
        io: dict containing model I/O tensors.
        model: object with view_ids_M, Is, canonical_img_size.
        train_transform: must have recreate_canonical(img, params).
        img: source image.
        batch_idx: which batch element to use.
        anchor_idx: token whose displacement row is used.
        sigma_unc: scale for dispersion normalisation.
        eps, min_samples: DBSCAN parameters.
        figsize: plot size.

    Returns:
        (labels, features): DBSCAN labels of the K in-bounds tokens (-1 = noise)
        and their [K, 4] feature matrix. When K == 0, DBSCAN is skipped and
        labels is empty.
    """
    canon_size = model.canonical_img_size

    # The anchor's crop geometry comes from the view it belongs to.
    view_idx = model.view_ids_M[anchor_idx]
    crop_size = model.Is[view_idx]
    cp = io['crop_params'][batch_idx][view_idx]
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]
    anchor_tl = cp[:2] + (lp / crop_size) * cp[2:4]

    pred_dt = io['pred_dT'][batch_idx, anchor_idx, :, :2] * canon_size  # [N, 2]
    disp = io['disp_dT'][batch_idx, anchor_idx, :, :2]                  # [N, 2]

    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1, 2, 0).cpu().numpy()
    # Read (H, W) from .shape. PIL's .size is (W, H), and a numpy array's
    # .size is its element count.
    canon = np.asarray(canon)
    H, W = canon.shape[:2]

    # Predicted centres: subtract offsets, following the GT convention.
    mus_np = (anchor_tl.unsqueeze(0) - pred_dt).cpu().numpy()  # [N, 2] as (y, x)
    bs_np = disp.cpu().numpy()                                 # [N, 2]

    y, x = mus_np[:, 0], mus_np[:, 1]
    in_bounds = (y >= 0) & (y <= H) & (x >= 0) & (x <= W)
    mus_np = mus_np[in_bounds]
    bs_np = bs_np[in_bounds]

    mus_norm = mus_np / canon_size
    bs_norm = bs_np / (canon_size * sigma_unc)
    features = np.hstack([mus_norm, bs_norm])  # [K, 4]

    # DBSCAN rejects empty input; report "no clusters" instead of crashing.
    if len(features) == 0:
        labels = np.empty(0, dtype=int)
    else:
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit(features).labels_

    plt.figure(figsize=figsize)
    plt.imshow(canon, interpolation='nearest')
    plt.scatter(
        mus_np[:, 1], mus_np[:, 0],
        c=labels, cmap='tab20', s=40, edgecolor='k'
    )
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    plt.title(f'DBSCAN Clustering: {n_clusters} clusters (σ_unc={sigma_unc}, eps={eps})')
    plt.axis('off')
    plt.show()

    return labels, features

# Example usage: DBSCAN over the anchor-0 predictions of sample 0.
labels, features = cluster_and_plot_patches(
    io,
    model,
    train_transform,
    img,
    batch_idx=0,
    anchor_idx=0,
    sigma_unc=0.5,
    eps=0.06,
    min_samples=5,
)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import NearestNeighbors

# Helper to plot k-distance graph for eps tuning
def plot_k_distance(features, k=5):
    """
    Show the sorted k-distance curve used to pick DBSCAN's ``eps``.

    For every sample, the distance to its k-th nearest neighbour is computed.
    The sample itself, at distance 0, counts as neighbour 0. The distances are
    sorted ascending and plotted; the "knee" of the curve is a common eps choice.

    Args:
        features: numpy array of shape [N, D].
        k: which neighbour's distance to plot.
    """
    # Request k+1 neighbours, because each query point comes back as its own nearest neighbour.
    knn = NearestNeighbors(n_neighbors=k + 1)
    knn.fit(features)
    neighbor_dists = knn.kneighbors(features)[0]
    curve = np.sort(neighbor_dists[:, k])

    plt.figure(figsize=(6,4))
    plt.plot(curve)
    plt.xlabel(f'Samples sorted by distance to {k}th NN')
    plt.ylabel(f'{k}th NN distance')
    plt.title('k-distance graph for DBSCAN eps selection')
    plt.grid(True)
    plt.show()

# Example usage in your notebook:
# `features` is the second value returned by the `cluster_and_plot_patches` call
# in the previous cell.
plot_k_distance(features, k=5)

# After inspecting the elbow point on the graph, pick eps where the curve shows a knee.


In [None]:
import torch
import numpy as np
from sklearn.cluster import DBSCAN
import matplotlib.pyplot as plt


def cluster_and_plot_patches(
    io,
    model,
    train_transform,
    img,
    batch_idx=0,
    anchor_idx=0,
    sigma_unc=0.05,
    eps=0.1,
    min_samples=5,
    figsize=(8,8),
    tune_sigma: bool = False,
    sigma_range=(0.01, 0.2, 10),
):
    """
    Cluster every token's predicted canonical position and dispersion with DBSCAN,
    then overlay the clusters on the canonical image.

    Steps:
        1) Place every token at ``anchor_tl - pred_dT[anchor, token] * canon_size``.
           Read its (y, x) dispersion from ``disp_dT[anchor, token]``.
        2) Drop tokens whose predicted position falls outside the image. Build
           features [mu_y, mu_x, b_y, b_x]: positions / canon_size and
           dispersions / (canon_size * sigma_unc).
        3) Optionally plot how the k-distance curve changes with sigma_unc.
        4) Run DBSCAN and scatter the kept tokens coloured by label.

    Args:
        io: dict containing model I/O tensors.
        model: object with view_ids_M, Is, canonical_img_size.
        train_transform: must have recreate_canonical(img, params).
        img: source image.
        batch_idx: which batch element to use.
        anchor_idx: token whose displacement row is used.
        sigma_unc: scale for dispersion normalisation used by the final clustering.
        eps, min_samples: DBSCAN parameters.
        figsize: plot size.
        tune_sigma: if True, first plot the sorted 6th-nearest-neighbour distance
            curve for each sigma in np.linspace(*sigma_range). This needs scikit-learn.
        sigma_range: (start, stop, num) passed to np.linspace.

    Returns:
        DBSCAN labels of the in-bounds tokens (-1 = noise). The array is empty when
        no token is in bounds.
    """
    canon_size = model.canonical_img_size

    # The anchor's crop geometry comes from the view it belongs to.
    view_idx = model.view_ids_M[anchor_idx]
    crop_size = model.Is[view_idx]
    cp = io['crop_params'][batch_idx][view_idx]
    lp = io['patch_positions_nopos'][batch_idx][anchor_idx]
    anchor_tl = cp[:2] + (lp / crop_size) * cp[2:4]

    pred_dt = io['pred_dT'][batch_idx, anchor_idx, :, :2] * canon_size  # [N, 2]
    disp = io['disp_dT'][batch_idx, anchor_idx, :, :2]                  # [N, 2]

    canon = train_transform.recreate_canonical(img, io['canonical_params'][batch_idx])
    if torch.is_tensor(canon):
        canon = canon.permute(1, 2, 0).cpu().numpy()
    # Read (H, W) from .shape. PIL's .size is (W, H), and a numpy array's
    # .size is its element count.
    canon = np.asarray(canon)
    H, W = canon.shape[:2]

    # Predicted centres: subtract offsets, following the GT convention.
    mus_np = (anchor_tl.unsqueeze(0) - pred_dt).cpu().numpy()  # [N, 2] as (y, x)
    bs_np = disp.cpu().numpy()                                 # [N, 2]

    y, x = mus_np[:, 0], mus_np[:, 1]
    in_bounds = (y >= 0) & (y <= H) & (x >= 0) & (x <= W)
    mus_np = mus_np[in_bounds]
    bs_np = bs_np[in_bounds]

    mus_norm = mus_np / canon_size
    bs_norm = bs_np / (canon_size * sigma_unc)
    features = np.hstack([mus_norm, bs_norm])  # [K, 4]

    # Optional sensitivity plot of the k-distance curve w.r.t. sigma_unc.
    if tune_sigma:
        # Imported here so scikit-learn's neighbours module is only needed when tuning.
        from sklearn.neighbors import NearestNeighbors
        sigmas = np.linspace(*sigma_range)
        plt.figure(figsize=(12, 3))
        for i, s in enumerate(sigmas, 1):
            feat = np.hstack([mus_norm, bs_np / (canon_size * s)])
            # 6 neighbours including the point itself -> column 5 is the 5th true neighbour.
            dists, _ = NearestNeighbors(n_neighbors=6).fit(feat).kneighbors(feat)
            kdist = np.sort(dists[:, 5])
            ax = plt.subplot(1, len(sigmas), i)
            ax.plot(kdist)
            ax.set_title(f'sigma_unc={s:.3f}')
            ax.set_xlabel('Sample index')
            ax.set_ylabel('6th NN distance')
        plt.suptitle('k-distance vs. sigma_unc')
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()

    # DBSCAN rejects empty input; report "no clusters" instead of crashing.
    if len(features) == 0:
        labels = np.empty(0, dtype=int)
    else:
        labels = DBSCAN(eps=eps, min_samples=min_samples).fit(features).labels_

    plt.figure(figsize=figsize)
    plt.imshow(canon, interpolation='nearest')
    plt.scatter(
        mus_np[:, 1], mus_np[:, 0],
        c=labels, cmap='tab20', s=40, edgecolor='k'
    )
    n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
    plt.title(f'DBSCAN Clustering: {n_clusters} clusters (σ_unc={sigma_unc}, eps={eps})')
    plt.axis('off')
    plt.show()

    return labels

# Sweep sigma_unc over 5 values in [0.01, 0.2] to inspect k-distance sensitivity,
# then cluster with the separately chosen sigma_unc (note: 0.001 lies outside the sweep).
labels = cluster_and_plot_patches(
    io,
    model,
    train_transform,
    img,
    sigma_unc=0.001,
    eps=0.06,
    min_samples=5,
    tune_sigma=True,
    sigma_range=(0.01, 0.2, 5),
)



In [None]:
# Fresh forward pass without autograd, then plot the reconstruction with an uncertainty overlay.
with torch.no_grad():
    out = model(*batch)
io = clean_model_io(batch, out, device="cuda")
# Other uncertainty_mode values used in this notebook: "laplace_distributions", "none".
fig, axes = plot_reconstruction_with_uncertainty(
    model, io, train_transform, img, uncertainty_mode="global_heatmap"
)

In [None]:
# Test both uncertainty visualization modes
# Render the same io with each overlay so the two can be compared side by side.
fig1, axes1 = plot_reconstruction_with_uncertainty(
    model,
    io,
    train_transform,
    img,
    uncertainty_mode="global_heatmap"
)
plt.show()

fig2, axes2 = plot_reconstruction_with_uncertainty(
    model,
    io,
    train_transform,
    img,
    uncertainty_mode="laplace_distributions"
)
plt.show()

# Also check the statistics of the dispersions to understand the scale
# (e.g. whether they look like pixels or normalised units).
# NOTE(review): assumes the last axis of disp_T is (dy, dx, dlogh, dlogw); the final
# line restricts to sample 0 and the first two (positional) components — confirm.
print("Dispersion statistics:")
print(f"disp_T shape: {io['disp_T'].shape}")
print(f"disp_T min: {io['disp_T'].min():.6f}")
print(f"disp_T max: {io['disp_T'].max():.6f}")
print(f"disp_T mean: {io['disp_T'].mean():.6f}")
print(f"Position dispersions (first 2 dims) - min: {io['disp_T'][0, :, :2].min():.6f}, max: {io['disp_T'][0, :, :2].max():.6f}")

In [None]:
#!/usr/bin/env python
# coding: utf-8

from pathlib import Path
# Repository root relative to this notebook; the data/output paths below are built from it.
ROOT = Path("../../")
# In[1]:

from torch import Tensor
import torch
from torch.utils.data import  default_collate
from PIL import Image
import matplotlib.pyplot as plt
import torch.utils._pytree as pytree
# from src.utils.visualization.reconstruction_v5_anchor_reparam import reconstruction_lstsq_with_anchor_reparam
from src.utils.visualization.reconstruction_v6 import reconstruction_lstsq_with_anchor_reparam
from src.utils.visualization.reconstruction_v5_gt import reconstruction_gt
from omegaconf import OmegaConf
import hydra
import itertools
from jaxtyping import Float, Int
from typing import Literal
import math
import torch.nn.functional as F


# ## Utils

# In[2]:


def clean_model_io(batch: tuple, out: dict, device="cuda"):
    """
    Flatten a collated batch and the model outputs into one dict for plotting.

    Args:
        batch: Collated model inputs. ``batch[0]`` holds image groups and
            ``batch[1]`` the matching parameter groups. Each group is iterated
            along its first (batch) axis, and the per-sample views of all groups
            are concatenated.
        out: Model output dictionary; every entry is copied into the result.
        device: Device every tensor is detached and moved to (default: "cuda").

    Returns:
        io: dict with
            - "x": per-sample list of view images,
            - "params": per-sample list of view parameter vectors,
            - "canonical_params": ``params[0:4]`` of every view of the FIRST sample
              only. It has no batch axis, unlike "crop_params".
            - "crop_params": per-sample list of ``params[4:8]``,
            - plus every key of ``out``.
    """
    def per_sample_views(groups):
        # zip(*groups) walks the batch axis of every group in lockstep; chaining
        # then concatenates the views of all groups for that sample.
        return [list(itertools.chain.from_iterable(sample)) for sample in zip(*groups)]

    views = per_sample_views(batch[0])
    view_params = per_sample_views(batch[1])
    io = {
        "x": views,
        "params": view_params,
        # Only sample 0 is kept, so this list is indexed by view, not by sample.
        "canonical_params": [p[0:4] for p in view_params[0]],
        "crop_params": [[p[4:8] for p in sample] for sample in view_params],
    }
    io.update(out)

    # Detach and move every tensor leaf, leaving non-tensor leaves untouched.
    return pytree.tree_map_only(Tensor, lambda t: t.detach().to(device), io)


def make_plots(
    model,
    io,
    train_transform,
    original_img,
    batch_idx: int = 0,
):
    """
    Show the canonical image next to the GT and predicted reconstructions for one sample.

    Args:
        model: provides _Ms, patch_size, canonical_img_size and max_scale_ratio.
        io: dict produced by ``clean_model_io``.
        train_transform: has .recreate_canonical(img, canonical_params).
        original_img: the source image the batch was sampled from.
        batch_idx: which batch element to plot (default 0, the previous fixed choice).

    Returns:
        (fig, axes) of a 1x3 figure: original canonical image, GT reconstruction,
        and predicted (least-squares) reconstruction.
    """
    x = io["x"][batch_idx]
    positions = io["patch_positions_nopos"][batch_idx]
    crop_params = io["crop_params"][batch_idx]

    gt_reconstruction = reconstruction_gt(
        x=x,
        patch_positions_nopos=positions,
        num_tokens=model._Ms,
        crop_params=crop_params,
        patch_size=model.patch_size,
        canonical_img_size=model.canonical_img_size,
    )
    pred_reconstruction, *_ = reconstruction_lstsq_with_anchor_reparam(
        x=x,
        patch_positions_nopos=positions,
        num_tokens=model._Ms,
        crop_params=crop_params,
        patch_size=model.patch_size,
        canonical_img_size=model.canonical_img_size,
        max_scale_ratio=model.max_scale_ratio,
        pred_dT=io["pred_dT"][batch_idx],
    )

    # io["canonical_params"] only holds sample 0 (indexed by view), so for other
    # samples read the canonical part (first 4 entries) of view 0's params.
    if batch_idx == 0:
        canonical_params = io["canonical_params"][0]
    else:
        canonical_params = io["params"][batch_idx][0][0:4]
    canonical_img = train_transform.recreate_canonical(original_img, canonical_params)

    fig, axes = plt.subplots(1, 3)
    panels = (
        (canonical_img, "Original"),
        (gt_reconstruction.permute(1, 2, 0).cpu(), "GT Reconstruction"),
        (pred_reconstruction.permute(1, 2, 0).cpu(), "Reconstruction"),
    )
    for panel_ax, (image, title) in zip(axes, panels):
        panel_ax.imshow(image)
        panel_ax.set_title(title)
        panel_ax.axis("off")
    return fig, axes


# ## Reconstruction

# In[3]:


# Allow faster, lower-precision kernels for float32 matrix multiplications
# (per the PyTorch docs, "high" permits TF32 on GPUs that support it).
# NOTE(review): this cell was originally labelled "overfit to a few batches", which does not describe what it does.
torch.set_float32_matmul_precision("high")


# In[4]:


# Register an `eval` resolver so `${eval:...}` interpolations in the saved Hydra
# config can be resolved. The guard keeps a re-run of this cell from registering it twice.
# NOTE(review): this executes arbitrary Python taken from config strings — only load trusted configs.
if not OmegaConf.has_resolver("eval"):
    OmegaConf.register_new_resolver("eval", eval)


# In[5]:


# Run directory of the training job to inspect (previous run: outputs/2025-06-16/13-23-31).
FOLDER = ROOT / "outputs" / "2025-06-22" / "19-16-53"
cfg = OmegaConf.load(FOLDER / ".hydra" / "config.yaml")


# In[6]:


# Reload from disk so the migration below always starts from the on-disk config
# (this cell mutates `cfg` in place).
cfg = OmegaConf.load(FOLDER / ".hydra/config.yaml")
# Migrate the legacy boolean `predict_uncertainty` flag to the `uncertainty_mode` option.
if "predict_uncertainty" in cfg["model"]:
    predict_uncertainty = cfg["model"].pop("predict_uncertainty")
    if predict_uncertainty:
        cfg["model"]["uncertainty_mode"] = "additive"
    else:
        cfg["model"]["uncertainty_mode"] = "none"
elif "uncertainty_mode" in cfg["model"]:
    pass
else:
    raise ValueError("Uncertainty mode not specified in the config.")
print(cfg["model"]["uncertainty_mode"])
ckpt_path = FOLDER / "epoch_0199.ckpt"
# NOTE(review): torch.load unpickles the checkpoint — only load trusted files.
ckpt = torch.load(ckpt_path, map_location="cuda")
state_dict = ckpt["model"]
# Rename pose-head parameters saved under older names to the names the current model expects.
if "pose_head.mu.weight" in state_dict:
    state_dict["pose_head.mu_proj.weight"] = state_dict.pop("pose_head.mu.weight")
if "pose_head.logvar.weight" in state_dict:
    state_dict["pose_head.disp_proj.weight"] = state_dict.pop("pose_head.logvar.weight")
if "pose_head.logvar.bias" in state_dict:
    state_dict["pose_head.disp_proj.bias"] = state_dict.pop("pose_head.logvar.bias")
# If the checkpoint has a gate projection, infer gate_dim from it when the config
# lacks one, and assert agreement otherwise.
if "pose_head.gate_proj.weight" in state_dict:
    if "gate_dim" not in cfg["model"]:
        cfg["model"]["gate_dim"] = state_dict["pose_head.gate_proj.weight"].shape[0]
        print(f"Gate dimension not specified in the config, inferring from state_dict: {cfg['model']['gate_dim']}")
    assert cfg["model"]["gate_dim"] == state_dict["pose_head.gate_proj.weight"].shape[0]
# Checkpoints without `gate_mult` get a zero tensor so the later strict load_state_dict succeeds.
if "pose_head.gate_mult" not in state_dict:
    state_dict["pose_head.gate_mult"] = torch.zeros(1)


# In[7]:


# View configuration: V views in total, gV global, lV local.
V = 2
gV = 2
lV = V - gV
if V not in (2, 12):
    raise ValueError(f"Unsupported number of views: {V}")
# The 12-view setup masks 75% of the tokens; the 2-view setup masks none.
# (Overrides tried before: gate_dim=cfg["model"].get("gate_dim", 16), sampler='ongrid_canonical'.)
model = hydra.utils.instantiate(
    cfg["model"],
    _target_="src.models.components.partmae_v6.PARTMaskedAutoEncoderViT",
    num_views=V,
    mask_ratio=0.75 if V == 12 else 0,
    pos_mask_ratio=0.75,
)
model.load_state_dict(state_dict, strict=True)
print(ckpt["global_step"], ckpt["epoch"])


# In[8]:


img = Image.open(ROOT / "artifacts" / "samoyed.jpg")  # optionally .crop((0, 0, 1000, 1000))
# Multi-crop transform from the training config, with colour distortion disabled.
train_transform = hydra.utils.instantiate(
    cfg["data"]["transform"],
    distort_color=False,
    n_local_crops=V - gV,
)
# Four successive transform calls on the same image, collated into one batch.
batch = default_collate([train_transform(img) for _ in range(4)])


# In[10]:


# Forward the batch once (no autograd) and compare GT vs predicted reconstructions.
with torch.no_grad():
    out = model(*batch)
io = clean_model_io(batch, out, device="cuda")
fig, axes = make_plots(model, io, train_transform, img)




# In[ ]:

def create_global_confidence_heatmap(
    patch_positions: Float[Tensor, "M 2"],
    dispersions: Float[Tensor, "M 4"],  # [dy, dx, dlogh, dlogw]
    canonical_size: int,
    patch_size: int,
    reduction: str = "mean",
) -> Float[Tensor, "canonical_size canonical_size"]:
    """
    Create a global confidence heatmap from per-patch position dispersions.

    Each patch gets the scalar confidence ``1 / (1 + ||(dy, dx)||)`` (higher = more
    confident). That value is painted over a ``patch_size x patch_size`` square at the
    patch's position, truncated to int and clamped so the square stays on the canvas.

    Args:
        patch_positions: Predicted patch positions in canonical space [M, 2], as [y, x].
        dispersions: Predicted dispersions for each patch [M, 4] (in pixel units); only
            the first two columns (dy, dx) are used.
        canonical_size: Side length of the square canonical image.
        patch_size: Side length of the square footprint painted for each patch.
        reduction: How overlapping patches are combined. ``"mean"`` (default) averages
            them, so a pixel shows the mean confidence of the patches covering it.
            ``"sum"`` adds them, which mixes confidence with patch density (overlap count).

    Returns:
        confidence_map: [canonical_size, canonical_size]; pixels covered by no patch are 0.

    Raises:
        ValueError: if ``reduction`` is neither ``"mean"`` nor ``"sum"``.
    """
    if reduction not in ("mean", "sum"):
        raise ValueError(f"Unsupported reduction: {reduction!r} (expected 'mean' or 'sum')")

    device = patch_positions.device
    confidence_map = torch.zeros((canonical_size, canonical_size), device=device)
    coverage = torch.zeros_like(confidence_map)

    # Norm of the positional dispersion (dy, dx), mapped to (0, 1]: larger spread -> lower confidence.
    pos_uncertainty = torch.norm(dispersions[:, :2], dim=1)  # [M]
    pos_confidence = 1.0 / (1.0 + pos_uncertainty)  # [M]

    # Largest valid top-left coordinate; max(..., 0) keeps it sane if patch_size > canonical_size.
    max_start = max(canonical_size - patch_size, 0)
    for pos, conf in zip(patch_positions, pos_confidence):
        y, x = (int(v) for v in pos.int().clamp(0, max_start))
        confidence_map[y:y + patch_size, x:x + patch_size] += conf
        coverage[y:y + patch_size, x:x + patch_size] += 1

    if reduction == "mean":
        # Uncovered pixels have coverage 0 and value 0; clamping avoids 0/0.
        confidence_map = confidence_map / coverage.clamp(min=1)
    return confidence_map


def create_global_uncertainty_heatmap(
    patch_positions: Float[Tensor, "M 2"],
    dispersions: Float[Tensor, "M 4"],  # [dy, dx, dlogh, dlogw]
    canonical_size: int,
    patch_size: int,
    reduction: str = "mean",
) -> Float[Tensor, "canonical_size canonical_size"]:
    """
    Create a global uncertainty heatmap from per-patch position dispersions.

    Each patch gets the scalar uncertainty ``||(dy, dx)||`` (higher = more uncertain).
    That value is painted over a ``patch_size x patch_size`` square at the patch's
    position, truncated to int and clamped so the square stays on the canvas.

    Args:
        patch_positions: Predicted patch positions in canonical space [M, 2], as [y, x].
        dispersions: Predicted dispersions for each patch [M, 4] (in pixel units); only
            the first two columns (dy, dx) are used.
        canonical_size: Side length of the square canonical image.
        patch_size: Side length of the square footprint painted for each patch.
        reduction: How overlapping patches are combined. ``"mean"`` (default) averages
            them, so a pixel shows the mean uncertainty of the patches covering it.
            ``"sum"`` adds them, which mixes uncertainty with patch density (overlap count).

    Returns:
        uncertainty_map: [canonical_size, canonical_size]; pixels covered by no patch are 0.

    Raises:
        ValueError: if ``reduction`` is neither ``"mean"`` nor ``"sum"``.
    """
    if reduction not in ("mean", "sum"):
        raise ValueError(f"Unsupported reduction: {reduction!r} (expected 'mean' or 'sum')")

    device = patch_positions.device
    uncertainty_map = torch.zeros((canonical_size, canonical_size), device=device)
    coverage = torch.zeros_like(uncertainty_map)

    # Norm of the positional dispersion (dy, dx), used directly as the uncertainty.
    pos_uncertainty = torch.norm(dispersions[:, :2], dim=1)  # [M]

    # Largest valid top-left coordinate; max(..., 0) keeps it sane if patch_size > canonical_size.
    max_start = max(canonical_size - patch_size, 0)
    for pos, unc in zip(patch_positions, pos_uncertainty):
        y, x = (int(v) for v in pos.int().clamp(0, max_start))
        uncertainty_map[y:y + patch_size, x:x + patch_size] += unc
        coverage[y:y + patch_size, x:x + patch_size] += 1

    if reduction == "mean":
        # Uncovered pixels have coverage 0 and value 0; clamping avoids 0/0.
        uncertainty_map = uncertainty_map / coverage.clamp(min=1)
    return uncertainty_map


def create_laplace_confidence_heatmaps(
    patch_positions: Float[Tensor, "M 2"],
    dispersions: Float[Tensor, "M 4"],  # [dy, dx, dlogh, dlogw]
    canonical_size: int,
    patch_size: int,
    alpha: float = 0.7,
    confidence_transform: str = "inverse_width",  # "inverse_width" or "peak_height"
) -> Float[Tensor, "canonical_size canonical_size"]:
    """
    Create overlaid Laplace-shaped bumps whose heights encode per-patch confidence.

    Each patch contributes ``exp(-|Y - mu_y| / b_y - |X - mu_x| / b_x)``, normalized to
    peak 1 on the grid, then multiplied by a confidence weight:

    - ``"inverse_width"``: ``k / (1 + k)`` with ``k = 1 / (b_y * b_x)``.
    - ``"peak_height"``: ``1 / (4 * b_y * b_x)``, the Laplace density at its mode.

    The weight is applied *after* the per-patch peak normalization. Multiplying it in
    before normalizing would be divided out again, giving every patch the same height
    regardless of its dispersion.

    Args:
        patch_positions: Predicted patch positions in canonical space [M, 2], as [y, x].
        dispersions: Per-patch dispersions [M, 4] (in pixel units); columns 0 and 1 are
            used as the Laplace scales b_y, b_x, clamped to [0.5, canonical_size / 2].
        canonical_size: Side length of the square canonical image.
        patch_size: Patches whose position lies more than ``2 * patch_size`` outside the
            canvas are skipped.
        alpha: Per-patch blending factor (the final max-normalization makes it a pure scale).
        confidence_transform: ``"inverse_width"`` or ``"peak_height"`` (see above).

    Returns:
        combined_heatmap: [canonical_size, canonical_size] normalized to max 1 (all zeros
        if no patch contributes).

    Raises:
        ValueError: for an unknown ``confidence_transform``.
    """
    if confidence_transform not in ("inverse_width", "peak_height"):
        raise ValueError(
            f"Unsupported confidence_transform: {confidence_transform!r} "
            "(expected 'inverse_width' or 'peak_height')"
        )

    device = patch_positions.device
    combined_heatmap = torch.zeros((canonical_size, canonical_size), device=device)

    coords = torch.arange(canonical_size, device=device).float()
    Y, X = torch.meshgrid(coords, coords, indexing="ij")

    lo, hi = -2 * patch_size, canonical_size + 2 * patch_size
    for pos, disp in zip(patch_positions, dispersions):
        mu_y, mu_x = pos[0], pos[1]
        # Laplace scale parameters (already in pixel units), kept in a reasonable range.
        b_y = torch.clamp(disp[0], min=0.5, max=canonical_size / 2)
        b_x = torch.clamp(disp[1], min=0.5, max=canonical_size / 2)

        # Skip patches far outside the (extended) canvas.
        if not (lo <= mu_y <= hi and lo <= mu_x <= hi):
            continue

        shape = torch.exp(-torch.abs(Y - mu_y) / b_y - torch.abs(X - mu_x) / b_x)
        peak = shape.max()
        if peak <= 0:  # fully underflowed: nothing visible on the grid
            continue
        shape = shape / peak

        if confidence_transform == "inverse_width":
            # Narrower distribution -> larger weight, bounded in (0, 1).
            width_factor = 1.0 / (b_y * b_x)
            weight = width_factor / (1.0 + width_factor)
        else:  # "peak_height"
            weight = 1.0 / (4 * b_y * b_x)

        combined_heatmap += alpha * weight * shape

    # Normalize the final combined heatmap
    top = combined_heatmap.max()
    if top > 0:
        combined_heatmap = combined_heatmap / top

    return combined_heatmap


def create_laplace_distribution_heatmaps(
    patch_positions: Float[Tensor, "M 2"],
    dispersions: Float[Tensor, "M 4"],  # [dy, dx, dlogh, dlogw]
    canonical_size: int,
    patch_size: int,
    alpha: float = 0.7,
) -> Float[Tensor, "canonical_size canonical_size"]:
    """
    Overlay one 2-D Laplace distribution per predicted patch position.

    Every distribution is rescaled to peak 1 before being blended in, so the map shows
    where and how widely each patch is spread, not how tall its density is. No
    confidence weighting is applied.

    Args:
        patch_positions: Predicted patch positions in canonical space [M, 2], as [y, x].
        dispersions: Per-patch dispersions [M, 4] (in pixel units); columns 0 and 1 are
            the Laplace scales, clamped to [0.5, canonical_size / 2].
        canonical_size: Side length of the square canonical image.
        patch_size: Positions more than ``2 * patch_size`` outside the canvas are skipped.
        alpha: Blending factor for each added distribution.

    Returns:
        combined_heatmap: [canonical_size, canonical_size], normalized to max 1 when non-empty.
    """
    dev = patch_positions.device
    heat = torch.zeros((canonical_size, canonical_size), device=dev)

    # Same 0..canonical_size-1 axis for rows and columns.
    axis = torch.arange(canonical_size, device=dev).float()
    grid_y, grid_x = torch.meshgrid(axis, axis, indexing='ij')

    lo = -2 * patch_size
    hi = canonical_size + 2 * patch_size
    b_cap = canonical_size / 2

    for center, scale in zip(patch_positions, dispersions):
        cy, cx = center[0], center[1]
        sy = torch.clamp(scale[0], min=0.5, max=b_cap)
        sx = torch.clamp(scale[1], min=0.5, max=b_cap)

        if not (lo <= cy <= hi and lo <= cx <= hi):
            continue

        # Density (1/(4*b_y*b_x)) * exp(-|y-mu_y|/b_y - |x-mu_x|/b_x)
        density = (1.0 / (4 * sy * sx)) * torch.exp(
            -torch.abs(grid_y - cy) / sy - torch.abs(grid_x - cx) / sx
        )

        peak = density.max()
        if peak > 0:
            density = density / peak

        heat += alpha * density

    top = heat.max()
    if top > 0:
        heat = heat / top

    return heat


def paste_patch(
    crop: Float[Tensor, "C h w"],
    pos: Float[Tensor, "2"],
    pos_canonical: Float[Tensor, "2"],
    patch_size_canonical: Float[Tensor, "2"],
    canvas: Float[Tensor, "C H W"],
    count_map: Float[Tensor, "1 H W"],
    patch_size: int,
    canonical_size: int,
    disp: Float[Tensor, "4"] = None,
):
    """
    Cut a patch out of ``crop`` and paste it, resized, onto ``canvas`` at ``pos_canonical``.

    Pixels are accumulated into ``canvas`` (in place), and ``count_map`` is incremented
    over the same region, so the caller can average overlaps with ``canvas / count_map``.

    Args:
        crop: Source image crop of shape [C, h, w].
        pos: Top-left patch position in crop coordinates [y, x]; rounded and clamped so
            the ``patch_size`` square stays inside the crop.
        pos_canonical: Top-left target position in canonical coordinates [y, x]; rounded
            and clamped so the pasted patch stays on the canvas.
        patch_size_canonical: Patch size in canonical space [height, width]; rounded and
            clamped to [1, canonical_size]. This guards against predicted sizes that would
            round to 0 or exceed the canvas.
        canvas: Target canvas to paste onto [C, H, W] (modified in place).
        count_map: Overlap counter [1, H, W] (modified in place).
        patch_size: Side length of the square patch in crop space.
        canonical_size: Side length of the square canonical canvas.
        disp: Per-token dispersion (Laplace scale) for each transformation parameter.
            Accepted for API compatibility; not used by this function.
    """
    crop_h, crop_w = crop.shape[1:3]

    # Integer target size. A size < 0.5 px would round to 0 (rejected by F.interpolate),
    # and one larger than the canvas would not fit the clamped slice below.
    patch_h_canonical = min(max(int(round(patch_size_canonical[0].item())), 1), canonical_size)
    patch_w_canonical = min(max(int(round(patch_size_canonical[1].item())), 1), canonical_size)

    # Target top-left corner, kept inside the canonical canvas
    y_canonical = int(round(pos_canonical[0].item()))
    x_canonical = int(round(pos_canonical[1].item()))
    y_canonical = max(0, min(canonical_size - patch_h_canonical, y_canonical))
    x_canonical = max(0, min(canonical_size - patch_w_canonical, x_canonical))

    # Source top-left corner, kept inside the crop
    y_crop = max(0, min(crop_h - patch_size, int(round(pos[0].item()))))
    x_crop = max(0, min(crop_w - patch_size, int(round(pos[1].item()))))

    patch = crop[:, y_crop:y_crop + patch_size, x_crop:x_crop + patch_size].unsqueeze(0)
    patch_resized = F.interpolate(
        patch,
        size=(patch_h_canonical, patch_w_canonical),
        mode="bilinear",
        align_corners=False,
    ).squeeze(0)

    # Same region for the canvas [C, H, W] and the count map [1, H, W]
    region = (
        slice(None),
        slice(y_canonical, y_canonical + patch_h_canonical),
        slice(x_canonical, x_canonical + patch_w_canonical),
    )
    canvas[region] += patch_resized
    count_map[region] += 1


@torch.no_grad
def reconstruction_with_uncertainty_visualization(
    x: list[Float[Tensor, "C gH gW"] | Float[Tensor, "C lH lW"]],
    patch_positions_nopos: Float[Tensor, "M 2"],
    num_tokens: list[int],
    crop_params: list[Float[Tensor, "4"]],
    patch_size: int,
    canonical_img_size: int,
    max_scale_ratio: float,
    pred_dT: Float[Tensor, "M M 4"],
    disp_T: Float[Tensor, "M 4"],  # NOTE: This contains LOG-dispersions
    uncertainty_mode: Literal["none", "confidence_heatmap", "confidence_distributions", "uncertainty_heatmap", "uncertainty_distributions"] = "none",
) -> tuple[
    Float[Tensor, "C canonical_img_size canonical_img_size"],  # reconstructed image
    Float[Tensor, "canonical_img_size canonical_img_size"] | None,  # uncertainty/confidence map
]:
    """
    Reconstruct the canonical image from predicted patch poses, optionally with an uncertainty map.

    Token i's canonical pose is ``pred_dT[i, 0]`` (used as its offset from token 0) plus an
    anchor pose computed from token 0's known position in the first crop. Every patch is
    then cut from its crop, resized to its predicted canonical size, and averaged onto a
    canvas.

    Args:
        x: Crops, one per view, each [C, H, W].
        patch_positions_nopos: Top-left patch positions in their own crop's pixel
            coordinates, concatenated over views [M, 2].
        num_tokens: Number of tokens per view; used to split the token axis back into views.
        crop_params: Per-view crop parameters; entries [:2] are used as the crop offset and
            [2:4] as the crop size in canonical pixels (assumed [y, x] / [h, w] order —
            TODO confirm).
        patch_size: Patch side length in crop pixels.
        canonical_img_size: Side length of the square canonical canvas.
        max_scale_ratio: Log-scale predictions are multiplied by ``log(max_scale_ratio)``.
        pred_dT: Normalised pairwise pose predictions [M, M, 4]: [..., :2] translations as a
            fraction of the canonical size, [..., 2:] log-scales as a fraction of
            ``log(max_scale_ratio)``.
        disp_T: Per-token log-dispersions [M, 4] - NOTE: these are in log-space, in the same
            normalised units as ``pred_dT``.
        uncertainty_mode:
            - "none": No visualization
            - "confidence_heatmap": Global confidence heatmap (bright = confident)
            - "confidence_distributions": Individual Laplace confidence distributions
            - "uncertainty_heatmap": Global uncertainty heatmap (bright = uncertain)
            - "uncertainty_distributions": Individual Laplace uncertainty distributions
            Any other value emits a ``UserWarning`` and no map is produced.

    Returns:
        reconstructed_img: Reconstructed canonical image [C, canonical_img_size, canonical_img_size].
        uncertainty_map: Uncertainty/confidence visualization, or None when uncertainty_mode
            is "none" or unrecognised.
    """
    import warnings

    device = x[0].device
    C = x[0].shape[0]

    # Undo normalization of the translation and log-scale predictions.
    dT = pred_dT[..., :2] * canonical_img_size
    dS = pred_dT[..., 2:] * math.log(max_scale_ratio)

    # Anchor on token 0 of the first crop. Divide per axis by (H, W) so that y is scaled
    # by the crop height and x by the crop width (correct for non-square crops too).
    crop_hw = torch.tensor(x[0].shape[1:3], dtype=torch.float32, device=crop_params[0].device)
    T_anchor = crop_params[0][:2] + (patch_positions_nopos[0] / crop_hw) * crop_params[0][2:4]
    S_anchor = torch.log(patch_size * crop_params[0][2:4] / crop_hw)

    T_global = dT[:, 0] + T_anchor
    S_global = dS[:, 0] + S_anchor

    T_global_grouped = torch.split(T_global, num_tokens)
    S_global_grouped = torch.split(S_global, num_tokens)
    patch_positions_nopos_grouped = torch.split(patch_positions_nopos, num_tokens)
    disp_T_grouped = torch.split(disp_T, num_tokens)

    # Reconstruct the canonical image by averaging all pasted patches.
    canvas = torch.zeros((C, canonical_img_size, canonical_img_size), device=device)
    count_map = torch.zeros((1, canonical_img_size, canonical_img_size), device=device)

    for crop, patch_positions, canonical_pos, log_size, disp in zip(
        x,
        patch_positions_nopos_grouped,
        T_global_grouped,
        S_global_grouped,
        disp_T_grouped,
    ):
        for i in range(patch_positions.shape[0]):
            paste_patch(
                crop=crop,
                pos=patch_positions[i].float(),
                pos_canonical=canonical_pos[i],
                patch_size_canonical=torch.exp(log_size[i]),
                canvas=canvas,
                count_map=count_map,
                patch_size=patch_size,
                canonical_size=canonical_img_size,
                disp=disp[i],
            )

    count_map[count_map == 0] = 1  # uncovered pixels stay 0 instead of 0/0
    reconstructed_img = canvas / count_map

    if uncertainty_mode == "none":
        return reconstructed_img, None

    viz_builders = {
        "confidence_heatmap": create_global_confidence_heatmap,
        "confidence_distributions": create_laplace_confidence_heatmaps,
        "uncertainty_heatmap": create_global_uncertainty_heatmap,
        "uncertainty_distributions": create_laplace_distribution_heatmaps,
    }
    builder = viz_builders.get(uncertainty_mode)
    if builder is None:
        warnings.warn(
            f"Unknown uncertainty_mode {uncertainty_mode!r}; expected one of "
            f"{['none', *viz_builders]}. No uncertainty map is produced.",
            stacklevel=2,
        )
        return reconstructed_img, None

    # Log-dispersions -> dispersions, then into the same units as the un-normalised poses.
    disp_T_pixels = torch.exp(disp_T)  # new tensor, safe to modify in place
    disp_T_pixels[:, :2] *= canonical_img_size  # dy, dx to pixels
    disp_T_pixels[:, 2:] *= math.log(max_scale_ratio)  # log-scale factors
    # NOTE(review): T_global holds predicted top-left corners; the builders treat them as
    # patch locations (the Laplace variants use them as distribution centres).
    viz_map = builder(T_global, disp_T_pixels, canonical_img_size, patch_size)
    return reconstructed_img, viz_map


def plot_reconstruction_with_uncertainty(
    model,
    io,
    train_transform,
    original_img,
    uncertainty_mode: Literal["none", "confidence_heatmap", "confidence_distributions", "uncertainty_heatmap", "uncertainty_distributions"] = "none",
):
    """
    Plot the canonical view, GT reconstruction, predicted reconstruction and (optionally)
    an uncertainty/confidence map for the first sample in ``io``.

    Args:
        model: Provides ``_Ms``, ``patch_size``, ``canonical_img_size`` and ``max_scale_ratio``.
        io: Dict from ``clean_model_io``; uses ``x``, ``patch_positions_nopos``,
            ``crop_params``, ``canonical_params``, ``pred_dT`` and ``disp_T``.
        train_transform: Must provide ``recreate_canonical(img, params)``.
        original_img: The source image passed to ``recreate_canonical``.
        uncertainty_mode: Forwarded to ``reconstruction_with_uncertainty_visualization``.

    Returns:
        (fig, axes): three panels, plus a fourth when an uncertainty map was produced.
    """
    gt_reconstruction = reconstruction_gt(
        x=io["x"][0],
        patch_positions_nopos=io["patch_positions_nopos"][0],
        num_tokens=model._Ms,
        crop_params=io["crop_params"][0],
        patch_size=model.patch_size,
        canonical_img_size=model.canonical_img_size,
    )

    pred_reconstruction, viz_map = reconstruction_with_uncertainty_visualization(
        x=io["x"][0],
        patch_positions_nopos=io["patch_positions_nopos"][0],
        num_tokens=model._Ms,
        crop_params=io["crop_params"][0],
        patch_size=model.patch_size,
        canonical_img_size=model.canonical_img_size,
        max_scale_ratio=model.max_scale_ratio,
        pred_dT=io["pred_dT"][0],
        disp_T=io["disp_T"][0],
        uncertainty_mode=uncertainty_mode,
    )

    # Size the figure from what was actually produced: an unrecognised mode yields
    # viz_map=None, which would otherwise leave an empty fourth panel.
    show_viz = uncertainty_mode != "none" and viz_map is not None
    n_plots = 4 if show_viz else 3
    fig, axes = plt.subplots(1, n_plots, figsize=(4 * n_plots, 4))

    canonical_img = train_transform.recreate_canonical(
        original_img, io["canonical_params"][0]
    )
    panels = [
        (canonical_img, "Original"),
        (gt_reconstruction.permute(1, 2, 0).cpu(), "GT Reconstruction"),
        (pred_reconstruction.permute(1, 2, 0).cpu(), "Reconstruction"),
    ]
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")

    if show_viz:
        # 'hot' maps larger values to brighter colours. Both maps grow with what they name,
        # so bright = confident on confidence maps and bright = uncertain on uncertainty maps.
        # (The former 'hot_r' rendered high uncertainty dark, contrary to that intent.)
        title_prefix = "Confidence" if "confidence" in uncertainty_mode else "Uncertainty"
        mode_name = uncertainty_mode.split('_')[1]  # "heatmap" or "distributions"

        im = axes[3].imshow(viz_map.cpu(), cmap='hot', alpha=0.8)
        axes[3].set_title(f"{title_prefix} ({mode_name})")
        axes[3].axis("off")
        plt.colorbar(im, ax=axes[3], fraction=0.046, pad=0.04)

    plt.tight_layout()
    return fig, axes


# In[9]:

with torch.no_grad():
    out = model(*batch)
io = clean_model_io(batch, out, 'cuda')
# Valid modes: "none", "confidence_heatmap", "confidence_distributions",
# "uncertainty_heatmap", "uncertainty_distributions" ("global_heatmap" is not one of them).
fig, axes = plot_reconstruction_with_uncertainty(
    model,
    io,
    train_transform,
    img,
    uncertainty_mode="uncertainty_heatmap",
)
plt.show()


# Test both confidence and uncertainty visualization modes
UNCERTAINTY_MODES = (
    "confidence_heatmap",
    "confidence_distributions",
    "uncertainty_heatmap",
    "uncertainty_distributions",
)
mode_figures = {}
for mode in UNCERTAINTY_MODES:
    # e.g. "confidence_heatmap" -> "Testing confidence heatmap visualization..."
    print(f"Testing {mode.replace('_', ' ')} visualization...")
    mode_figures[mode] = plot_reconstruction_with_uncertainty(
        model,
        io,
        train_transform,
        img,
        uncertainty_mode=mode,
    )
    plt.show()
# Keep the per-mode names in case later cells refer to them.
(fig1, axes1), (fig2, axes2), (fig3, axes3), (fig4, axes4) = mode_figures.values()

# Dispersion statistics: the model outputs log-dispersions, so report both spaces.
print("Dispersion statistics:")
print(f"disp_T shape: {io['disp_T'].shape}")
for stat_name in ("min", "max", "mean"):
    print(f"disp_T (log-space) {stat_name}: {getattr(io['disp_T'], stat_name)():.6f}")

# exp() undoes the log; positions are then in normalized (fraction-of-canvas) units.
actual_disp = torch.exp(io['disp_T'])
for stat_name in ("min", "max", "mean"):
    print(f"Actual dispersions {stat_name}: {getattr(actual_disp, stat_name)():.6f}")
first_pos_disp = actual_disp[0, :, :2]
print(f"Position dispersions (normalized) - min: {first_pos_disp.min():.6f}, max: {first_pos_disp.max():.6f}")

# Scale the (dy, dx) dispersions to canonical-image pixels.
pixel_disp = actual_disp.clone()
pixel_disp[:, :, :2] *= model.canonical_img_size
print(f"Position dispersions (pixels) - min: {pixel_disp[0, :, :2].min():.1f}, max: {pixel_disp[0, :, :2].max():.1f}")

# Debug: Let's check what patches have high vs low uncertainty
print("\nDebugging uncertainty interpretation:")
actual_dispersions = torch.exp(io['disp_T'][0])  # from log-space; [M, 4] for the first sample
pos_uncertainties = torch.norm(actual_dispersions[:, :2], dim=1)  # [M], normalized units
print(f"Min position uncertainty: {pos_uncertainties.min():.6f}")
print(f"Max position uncertainty: {pos_uncertainties.max():.6f}")
print(f"Mean position uncertainty: {pos_uncertainties.mean():.6f}")

# Rank patches by positional dispersion; ascending order puts the most confident first.
# NOTE: positions are in each token's own crop coordinates, with tokens from all views mixed.
sorted_indices = torch.argsort(pos_uncertainties)
most_confident_patches = sorted_indices[:10]  # 10 most confident (smallest dispersion)
least_confident_patches = sorted_indices[-10:]  # 10 least confident (largest dispersion)

for header, patch_indices in (
    ("\nMost confident patch positions (smallest dispersions):", most_confident_patches),
    ("\nLeast confident patch positions (largest dispersions):", least_confident_patches),
):
    print(header)
    # .tolist() gives plain ints, so labels print as "Patch 17" rather than a tensor repr.
    for i in patch_indices.tolist():
        pos = io['patch_positions_nopos'][0][i]
        print(f"  Patch {i}: pos={pos.cpu().numpy()}, uncertainty={pos_uncertainties[i]:.6f}")

# Enhanced debugging to understand what's happening
print("\n" + "="*60)
print("DETAILED ANALYSIS: Understanding Confidence vs Uncertainty")
print("="*60)

actual_dispersions = torch.exp(io['disp_T'][0])  # from log-space; [M, 4] for the first sample
pos_uncertainties = torch.norm(actual_dispersions[:, :2], dim=1)  # [M], normalized units
# Same 1/(1+u) transform as create_global_confidence_heatmap, but applied to normalized
# (not pixel) dispersions, so these numbers are not the heatmap's values.
pos_confidences = 1.0 / (1.0 + pos_uncertainties)

print(f"Uncertainty range: {pos_uncertainties.min():.6f} to {pos_uncertainties.max():.6f}")
print(f"Confidence range: {pos_confidences.min():.6f} to {pos_confidences.max():.6f}")

# Analyze patch distribution by location.
# NOTE(review): patch_positions_nopos are top-left corners in each token's own crop, with
# tokens from all views mixed; distances are measured against the first view's crop centre.
patch_positions = io['patch_positions_nopos'][0]  # [M, 2]
crop_height, crop_width = io['x'][0][0].shape[1:3]

center_y, center_x = crop_height // 2, crop_width // 2
crop_center = torch.tensor([center_y, center_x], device=patch_positions.device)
patch_distances_from_center = torch.norm(patch_positions.float() - crop_center, dim=1)

center_patches = patch_distances_from_center < crop_height * 0.3
edge_patches = patch_distances_from_center > crop_height * 0.7
medium_patches = (patch_distances_from_center >= crop_height * 0.3) & (patch_distances_from_center <= crop_height * 0.7)

# int(...) so counts print as numbers instead of tensor reprs; empty regions are reported explicitly.
for region_label, region_mask in (
    ("Center patches (distance < 30% of crop)", center_patches),
    ("Edge patches (distance > 70% of crop)", edge_patches),
    ("Medium distance patches (30%-70% of crop)", medium_patches),
):
    n_region = int(region_mask.sum())
    print(f"\n{region_label}: {n_region} patches")
    if n_region > 0:
        print(f"  - Mean uncertainty: {pos_uncertainties[region_mask].mean():.6f}")
        print(f"  - Mean confidence: {pos_confidences[region_mask].mean():.6f}")
    else:
        print("  - No patches found in this region")

# Findings are derived from the data above rather than asserted.
overall_confidence = pos_confidences.mean().item()
print("\nKey Findings:")
print(f"- Position dispersions range from {pixel_disp[0, :, :2].min():.1f} to {pixel_disp[0, :, :2].max():.1f} pixels")
print(f"- Smallest position uncertainty: {pos_uncertainties.min():.3f} (normalized units)")
print(f"- Largest position uncertainty: {pos_uncertainties.max():.3f} (normalized units)")
if center_patches.any():
    center_confidence = pos_confidences[center_patches].mean().item()
    comparison = "more" if center_confidence > overall_confidence else "not more"
    print(f"- Center patches are {comparison} confident ({center_confidence:.3f}) than the overall mean ({overall_confidence:.3f})")
else:
    print(f"- No center patches to compare; overall mean confidence is {overall_confidence:.3f}")

# Colour semantics depend on the colormap, so point readers at the colorbar instead.
print("\nVisualization Guide:")
print("- CONFIDENCE maps: larger values (see colorbar) = smaller predicted position dispersion")
print("- UNCERTAINTY maps: larger values (see colorbar) = larger predicted position dispersion")