# Visualize conditional dependence results
Load the saved toy dataset and plot how the contamination becomes visible only after conditioning on the context variable Z.

In [None]:
from pathlib import Path
import sys
import numpy as np
import matplotlib.pyplot as plt

# Locate the repo root: walk up from the current working directory and take
# the first directory that contains both readme.md and experiments/. Fall back
# to the current working directory when no such ancestor exists.
REPO_ROOT = None
for candidate in [Path.cwd(), *Path.cwd().parents]:
    if (candidate / 'readme.md').exists() and (candidate / 'experiments').exists():
        REPO_ROOT = candidate
        break
if REPO_ROOT is None:
    REPO_ROOT = Path.cwd()

# Put the repo root on sys.path so repo-local packages (e.g. `models`, imported
# in a later cell) can be imported from this notebook.
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Default dataset location under the repo root (../data from experiments/).
dataset_path = REPO_ROOT / 'data' / 'toy_data.npz'
if not dataset_path.exists():
    # Fail with a message that shows which root was detected, since a wrong
    # working directory is the usual cause of a missing dataset here.
    raise FileNotFoundError(
        f"Toy dataset not found at {dataset_path} "
        f"(expected data/toy_data.npz under the detected repo root {REPO_ROOT})."
    )

# np.load on an .npz archive returns an NpzFile that keeps the file open;
# using it as a context manager closes the archive once X, Y and Z are read
# into memory.
with np.load(dataset_path) as data:
    x, y, z = data['X'], data['Y'], data['Z']

print(f"Loaded dataset from {dataset_path} with shapes: X={x.shape}, Y={y.shape}, Z={z.shape}")

## Global view: X and Y appear independent
Across the full dataset, ignoring the context variable Z, a quick scatter plot and the Pearson correlation should both show X and Y as nearly independent.

In [None]:
# Global view: scatter X against Y over every sample, ignoring Z, and report
# the Pearson correlation over the full dataset.
# Flatten first: np.corrcoef treats each row of a 2-D input as a separate
# variable, so column-shaped (n, 1) arrays would produce a (2n x 2n) matrix of
# NaNs instead of the single X-Y coefficient. 1-D inputs are unaffected.
x_all = np.ravel(x)
y_all = np.ravel(y)

plt.figure(figsize=(6, 4))
plt.scatter(x_all, y_all, s=4, alpha=0.25, color="#4e79a7")
plt.xlabel("Candidate systematic X")
plt.ylabel("Shear Y")
plt.title("Global view: near-independence without conditioning")
plt.tight_layout()
plt.show()

corr = np.corrcoef(x_all, y_all)[0, 1]
print(f"Global Pearson r ≈ {corr:.3f}")


## Conditional slice: dependence activates when Z > 0
Restricting to positive contexts (Z > 0) switches on the contamination term, and the X→Y dependence becomes visible.

In [None]:
# Split the samples by the sign of the context Z and compare the X-Y relation
# in each half, both visually and via the Pearson correlation.
# Flatten X and Y so column-shaped (n, 1) and 1-D arrays behave the same when
# indexed by a 1-D mask and passed to np.corrcoef.
x_flat = np.ravel(x)
y_flat = np.ravel(y)
# Build a 1-D gate from Z. A later cell treats Z as 2-D (z.shape[1]), and a
# 2-D (n, 1) mask cannot index a 1-D array.
# NOTE(review): if Z has more than one column, only the first drives the gate.
z_gate = np.asarray(z).reshape(len(z), -1)[:, 0]
mask_positive = z_gate > 0

plt.figure(figsize=(6, 4))
plt.scatter(x_flat[mask_positive], y_flat[mask_positive], s=6, alpha=0.3, color="#f28e2b", label="Z > 0")
plt.scatter(x_flat[~mask_positive], y_flat[~mask_positive], s=4, alpha=0.1, color="#9ea3a6", label="Z ≤ 0")
plt.xlabel("Candidate systematic X")
plt.ylabel("Shear Y")
plt.title("Conditional dependence emerges for Z > 0")
plt.legend(frameon=False)
plt.tight_layout()
plt.show()

corr_positive = np.corrcoef(x_flat[mask_positive], y_flat[mask_positive])[0, 1]
corr_negative = np.corrcoef(x_flat[~mask_positive], y_flat[~mask_positive])[0, 1]
print(f"Pearson r | Z>0 ≈ {corr_positive:.3f}")
print(f"Pearson r | Z≤0 ≈ {corr_negative:.3f}")


## Context vs. shear: contamination gating
Plotting Y directly against the context Z highlights the activation of contamination as Z increases.

In [None]:
# Plot the shear Y directly against the context Z, so the gating of the
# contamination by Z can be seen without involving X.
fig, ax = plt.subplots(figsize=(6, 4))
ax.scatter(z, y, s=4, alpha=0.3, color="#59a14f")
ax.set(
    xlabel="Context Z",
    ylabel="Shear Y",
    title="Conditional activation of contamination by context",
)
fig.tight_layout()
plt.show()


## Conditional distribution colored by learned $T$
A learned contextual embedding $T(Z)$ should separate the regions where $X$ and $Y$ become dependent. In this toy collider, $Z$ alone already exposes the effect cleanly, so $T$ is easy to learn. In real applications, we want $T$ to uncover more complex, non-trivial structure in the joint behavior of $X$ and $Y$.

In [None]:

from pathlib import Path
import torch
from models.t_network import build_t_network

# Compute the contextual embedding T(Z) and use its first component to color
# the X/Y scatter. If trained weights exist, T comes from the saved T-network.
# Otherwise Z itself is used as a stand-in, so the cell still runs end to end.
# NOTE(review): this path is relative to the current working directory, not
# REPO_ROOT; confirm where the training code writes the checkpoint.
weights_path = Path("conditional_results.pt")

# Both the T-network input and the fallback proxy need Z as a 2-D
# (n_samples, z_dim) array; a 1-D Z would break z.shape[1] and
# np.repeat(..., axis=1).
z_2d = np.asarray(z).reshape(len(z), -1)

if weights_path.exists():
    # torch.load unpickles the file; only load checkpoints you produced yourself.
    state = torch.load(weights_path, map_location="cpu")
    t_net = build_t_network(z_dim=z_2d.shape[1], t_dim=2)
    t_net.load_state_dict(state["t_net"])
    t_net.eval()
    with torch.no_grad():
        t_values = t_net(torch.from_numpy(z_2d).float()).numpy()
    t_is_learned = True
    print(f"Loaded trained T-network from {weights_path} and computed T with shape {t_values.shape}.")
else:
    # Fallback: reuse Z as a simple proxy when trained weights are not present.
    t_values = np.repeat(z_2d, 2, axis=1)
    t_is_learned = False
    print("No trained T-network found; using Z as a simple proxy for T.")

plt.figure(figsize=(6, 4))
scatter = plt.scatter(x, y, c=t_values[:, 0], cmap="viridis", s=6, alpha=0.35)
plt.xlabel("Candidate systematic X")
plt.ylabel("Shear Y")
# Set the title from where T actually came from, so the fallback plot is not
# mislabeled as a learned embedding.
if t_is_learned:
    plt.title("Samples colored by learned context T(Z)")
else:
    plt.title("Samples colored by proxy context (Z used as T)")
cbar = plt.colorbar(scatter)
cbar.set_label("T component (dim 0)")
plt.tight_layout()
plt.show()
