In [1]:
# Mount Google Drive so the project directory under /content/drive/MyDrive
# (used as ROOT for data, models and results later in this notebook) is
# accessible from this Colab runtime.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import json
from typing import Tuple, Optional, Dict, List

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ============================================================
# Paths and basic configuration
# ============================================================

ROOT = "/content/drive/MyDrive/Project"

# True -> use the semi-synthetic experiment tree; False -> the fully synthetic one.
USE_SEMI_SYNTH = True

# Every output directory lives under one experiment-specific sub-tree of ROOT.
_EXPERIMENT_DIR = os.path.join(
    ROOT,
    "experiments_semi_synthetic" if USE_SEMI_SYNTH else "experiments_synthetic",
)
DATA_DIR = os.path.join(_EXPERIMENT_DIR, "data")
MODEL_DIR = os.path.join(_EXPERIMENT_DIR, "models")
RESULT_DIR = os.path.join(_EXPERIMENT_DIR, "results")
PROP_DIR = os.path.join(_EXPERIMENT_DIR, "propensity_models")
# Dataset file-name template; filled in with beta_norm and l.
FNAME_FMT = "semi_beta{beta}_l{l}.pt" if USE_SEMI_SYNTH else "synthetic_beta{beta}_l{l}.pt"

for _output_dir in (DATA_DIR, MODEL_DIR, RESULT_DIR, PROP_DIR):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Training hyperparameters
# NOTE(review): LR_TARNET and LR_PROP are not referenced by the training
# helpers in this notebook, which use their own learning rates (the `lr`
# argument defaulting to 1e-3, and a hard-coded 1e-4 for the propensity
# model) -- confirm which values are intended.
BATCH_SIZE = 256
LR_TARNET = 5e-4
LR_PROP = 1e-3
N_EPOCHS_TARNET = 1000  # Upper bound only; early stopping normally ends training first
N_EPOCHS_PROP = 50
PATIENCE = 50           # Epochs without validation-PEHE improvement before stopping
LAMBDA_BAL = 10.0       # Increased from 1.0 to force balancing with strong bias


# ============================================================
# Utilities: loading synthetic datasets
# ============================================================

def dataset_path(beta_norm: float, l: int) -> str:
    """Return the file path of the dataset identified by ``(beta_norm, l)``.

    The file name is produced from the ``FNAME_FMT`` template and placed
    inside ``DATA_DIR``.
    """
    file_name = FNAME_FMT.format(beta=beta_norm, l=l)
    return os.path.join(DATA_DIR, file_name)


def load_synthetic_dataset(beta_norm: float, l: int):
    """Load one saved dataset and return ``(X, t, y, tau)`` as float32 tensors.

    ``X`` is returned as stored (cast to float); per the data format it is
    expected to be (N, 1, H, W). ``t``, ``y`` and ``tau`` are reshaped to
    column vectors of shape (N, 1). When the file has no ``"tau"`` entry, the
    true effect is derived as ``mu1 - mu0``.

    Raises:
        FileNotFoundError: if the dataset file does not exist.
    """
    path = dataset_path(beta_norm, l)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Dataset not found: {path}")
    data = torch.load(path, map_location="cpu")

    def as_column(key):
        # float32 column vector of shape (N, 1)
        return data[key].float().view(-1, 1)

    X = data["X"].float()
    t = as_column("T")
    y = as_column("Y")
    tau = as_column("tau") if "tau" in data else as_column("mu1") - as_column("mu0")
    return X, t, y, tau


def train_val_test_split(
    X: torch.Tensor,
    t: torch.Tensor,
    y: torch.Tensor,
    tau: torch.Tensor,
    train_frac: float = 0.6,
    val_frac: float = 0.2,
    seed: int = 1234,
):
    """Randomly split aligned tensors into train / validation / test subsets.

    A single permutation of ``range(N)`` is drawn from a dedicated
    ``torch.Generator`` seeded with ``seed``. The split is therefore
    reproducible and does not consume the global RNG. The first
    ``int(train_frac * N)`` permuted indices form the training set, and the
    next ``int(val_frac * N)`` form the validation set. The remainder is the
    test set.

    Args:
        X, t, y, tau: tensors indexed along their first dimension; all must
            have the same length ``N``.
        train_frac: fraction of samples assigned to training.
        val_frac: fraction of samples assigned to validation.
        seed: seed for the permutation.

    Returns:
        ``(train, val, test)``, each a tuple ``(X, t, y, tau)``.

    Raises:
        ValueError: if the tensors disagree on ``N``, or a fraction is
            negative, or ``train_frac + val_frac`` exceeds 1. The last case
            would otherwise silently shrink the validation set and leave an
            empty test set.
    """
    N = X.shape[0]
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if not (N == t.shape[0] == y.shape[0] == tau.shape[0]):
        raise ValueError(
            "X, t, y and tau must have the same number of samples; got "
            f"{N}, {t.shape[0]}, {y.shape[0]}, {tau.shape[0]}"
        )
    # Small tolerance so e.g. 0.7 + 0.3 is not rejected due to float rounding.
    if train_frac < 0 or val_frac < 0 or train_frac + val_frac > 1.0 + 1e-9:
        raise ValueError(
            f"Invalid split fractions: train_frac={train_frac}, val_frac={val_frac}"
        )

    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(N, generator=g)

    n_train = int(train_frac * N)
    n_val = int(val_frac * N)

    index_sets = (
        perm[:n_train],
        perm[n_train:n_train + n_val],
        perm[n_train + n_val:],
    )
    return tuple((X[idx], t[idx], y[idx], tau[idx]) for idx in index_sets)


# ============================================================
# Dataset object
# ============================================================

class CausalDataset(Dataset):
    """Map-style dataset yielding ``(x, t, y, tau, e_hat)`` tuples.

    Every tensor is indexed along its first dimension. If no propensity
    scores are supplied, a fresh placeholder ``tensor([0.0])`` (float32) is
    returned in their slot, so every item keeps the same five-field structure.
    """

    def __init__(
        self,
        X: torch.Tensor,
        t: torch.Tensor,
        y: torch.Tensor,
        tau: torch.Tensor,
        e_hat: Optional[torch.Tensor] = None,
    ):
        self.X, self.t, self.y, self.tau = X, t, y, tau
        self.e_hat = e_hat

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        propensity = (
            torch.tensor([0.0], dtype=torch.float32)  # placeholder, never used by the model
            if self.e_hat is None
            else self.e_hat[idx]
        )
        return self.X[idx], self.t[idx], self.y[idx], self.tau[idx], propensity


# ============================================================
# Models: representation, TARNet, propensity model
# ============================================================

class ConvRepresentation(nn.Module):
    """Small CNN encoder mapping single-channel images to ``rep_dim`` features.

    Three 3x3 stride-2 convolutions (channels 1 -> 8 -> 16 -> 32, each
    followed by ReLU) are globally average-pooled, flattened and projected by
    ``Linear(32, rep_dim)`` + ReLU. The widths were reduced from [16, 32, 64]
    to [8, 16, 32] to lower model capacity.
    """

    def __init__(self, rep_dim: int = 64):
        super().__init__()
        widths = [1, 8, 16, 32]
        conv_stack = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            conv_stack.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
            conv_stack.append(nn.ReLU())
        conv_stack.append(nn.AdaptiveAvgPool2d((1, 1)))
        # Indices 0..6 inside `features` match the original explicit Sequential,
        # so state_dict keys (features.0/2/4.*) are unchanged.
        self.features = nn.Sequential(*conv_stack)
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(widths[-1], rep_dim),  # input = channels of the last conv
            nn.ReLU(),
        )

    def forward(self, x):
        # Output shape: (N, rep_dim)
        return self.fc(self.features(x))


class TARNet(nn.Module):
    """TARNet: a shared representation ``phi(x)`` feeding two outcome heads.

    ``head0`` predicts the control outcome and ``head1`` the treated outcome.
    When ``use_propensity`` is True, the estimated propensity score is
    concatenated to ``phi(x)`` before both heads. Each head is
    Linear(in, 64) -> ReLU -> Dropout(0.3) -> Linear(64, 1); the dropout
    guards against overfitting in the heads.
    """

    def __init__(self, rep_dim: int = 64, use_propensity: bool = False):
        super().__init__()
        self.use_propensity = use_propensity
        self.rep = ConvRepresentation(rep_dim=rep_dim)

        head_in = rep_dim + (1 if use_propensity else 0)
        self.head0 = self._outcome_head(head_in)
        self.head1 = self._outcome_head(head_in)

    @staticmethod
    def _outcome_head(in_features: int) -> nn.Sequential:
        # One potential-outcome regression head.
        return nn.Sequential(
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        )

    def forward(self, x, e_hat=None):
        """Return ``(y0_hat, y1_hat, phi_x)``.

        Raises:
            ValueError: if ``use_propensity`` is True and ``e_hat`` is None.
        """
        phi_x = self.rep(x)
        if not self.use_propensity:
            head_input = phi_x
        elif e_hat is None:
            raise ValueError("e_hat must be provided when use_propensity=True")
        else:
            head_input = torch.cat([phi_x, e_hat], dim=1)
        return self.head0(head_input), self.head1(head_input), phi_x


class PropensityMLP(nn.Module):
    """Propensity model: flatten each input and apply an MLP with two hidden layers.

    Architecture: Linear(input_dim, 128) -> ReLU -> Linear(128, 64) -> ReLU
    -> Linear(64, 1). ``forward`` returns raw logits of shape (N, 1); apply
    ``torch.sigmoid`` to obtain estimated P(T=1 | x).
    """

    def __init__(self, input_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        # torch.flatten, unlike `.view(N, -1)`, also accepts non-contiguous
        # inputs and empty batches. The product of x's non-batch dimensions
        # must equal `input_dim`.
        flat = torch.flatten(x, start_dim=1)
        return self.net(flat)


# ============================================================
# Losses and metrics
# ============================================================

def pdist(sample_1, sample_2, norm=2, eps=1e-5):
    """Return the matrix of pairwise ``p``-norm distances between two sample sets.

    Args:
        sample_1: tensor of shape (n_1, d).
        sample_2: tensor of shape (n_2, d).
        norm: order ``p`` of the norm (2 gives the Euclidean distance).
        eps: accepted for API compatibility; currently unused.

    Returns:
        Tensor of shape (n_1, n_2) whose (i, j) entry is
        ``||sample_1[i] - sample_2[j]||_p``.
    """
    rows, cols = sample_1.size(0), sample_2.size(0)
    # Materialise every (i, j) pair as an (n_1, n_2, d) grid, then reduce over d.
    left = sample_1.unsqueeze(1).expand(rows, cols, -1)
    right = sample_2.unsqueeze(0).expand(rows, cols, -1)
    pairwise_diff = left - right
    return torch.norm(pairwise_diff, p=float(norm), dim=2)

def balancing_loss(phi_x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Biased RBF-kernel MMD^2 between treated and control representations.

    The kernel bandwidth ``sigma`` follows the median heuristic. It is the
    median of all off-diagonal pairwise distances in the batch, detached so
    it receives no gradient, and falls back to 1 when that median is 0. The
    estimate is ``mean K_tt - 2 mean K_tc + mean K_cc`` with
    ``K(a, b) = exp(-||a - b||^2 / (2 sigma^2))``; diagonal terms are included.

    Args:
        phi_x: representations, shape (N, d).
        t: treatment indicators with N elements (any shape); ``t > 0.5`` is
            treated and ``t <= 0.5`` is control.

    Returns:
        Scalar tensor. It is zero, carries no gradient and has phi_x's
        dtype/device when either group is empty in this batch.
    """
    t = t.view(-1)
    treated = phi_x[t > 0.5]
    control = phi_x[t <= 0.5]

    n_t = treated.size(0)
    if n_t == 0 or control.size(0) == 0:
        return phi_x.new_zeros(())

    # 1. One distance matrix over [treated; control]. It serves both the
    #    bandwidth estimate and every MMD block below, so no sub-block is
    #    recomputed.
    all_samples = torch.cat([treated, control], dim=0)
    dists = pdist(all_samples, all_samples)

    n = dists.shape[0]
    off_diag = ~torch.eye(n, dtype=torch.bool, device=phi_x.device)
    sigma = dists[off_diag].median().detach()
    if sigma == 0:
        sigma = torch.tensor(1.0, device=phi_x.device)

    # 2. MMD^2 from the treated/treated, treated/control and control/control
    #    blocks. These are identical to pdist() applied to the groups directly.
    def kernel_mean(d):
        return torch.exp(-d ** 2 / (2 * sigma ** 2)).mean()

    k_tt = kernel_mean(dists[:n_t, :n_t])
    k_tc = kernel_mean(dists[:n_t, n_t:])
    k_cc = kernel_mean(dists[n_t:, n_t:])
    return k_tt - 2 * k_tc + k_cc



def compute_pehe(model: TARNet, dataloader: DataLoader, use_propensity: bool) -> float:
    """Return sqrt-PEHE of ``model`` over every sample in ``dataloader``.

    The predicted effect ``y1_hat - y0_hat`` is compared with the true ``tau``
    of each sample, and the square root of the mean squared error is returned.
    Propensity scores from the loader are passed to the model only when
    ``use_propensity`` is True. Returns NaN if the loader yields no batches.
    The model is left in eval mode.
    """
    model.eval()
    batch_errors = []

    with torch.no_grad():
        for x, _t, _y, tau_true, e_hat in dataloader:
            e_in = e_hat.to(device) if use_propensity else None
            y0_hat, y1_hat, _ = model(x.to(device), e_hat=e_in)
            batch_errors.append((y1_hat - y0_hat - tau_true.to(device)) ** 2)

    if not batch_errors:
        return float("nan")
    return torch.sqrt(torch.cat(batch_errors, dim=0).mean()).item()


def compute_mse_factual(model: TARNet, dataloader: DataLoader, use_propensity: bool) -> float:
    """Return the mean squared error of factual-outcome predictions.

    For each sample the head matching its observed treatment (``y1_hat`` when
    ``t > 0.5``, else ``y0_hat``) is compared with the observed ``y``. The
    squared errors are summed over the loader and divided by the sample
    count. Returns NaN for an empty loader and leaves the model in eval mode.
    """
    model.eval()
    sum_squared = nn.MSELoss(reduction="sum")
    running_sse, seen = 0.0, 0

    with torch.no_grad():
        for x, t, y, _tau, e_hat in dataloader:
            t, y = t.to(device), y.to(device)
            e_in = e_hat.to(device) if use_propensity else None
            y0_hat, y1_hat, _ = model(x.to(device), e_hat=e_in)
            factual_pred = torch.where(t > 0.5, y1_hat, y0_hat)
            running_sse += sum_squared(factual_pred, y).item()
            seen += y.shape[0]

    return running_sse / seen if seen else float("nan")


# ============================================================
# Training helpers
# ============================================================

def get_or_train_propensity_model(
    beta_norm: float,
    l: int,
    train_dataset: CausalDataset,
    val_dataset: CausalDataset,
    img_size: Tuple[int, int],
) -> PropensityMLP:
    """Train a propensity model for P(T=1 | x) and save it under ``PROP_DIR``.

    Loading a previously saved model is currently disabled (the block below
    is commented out). Every call therefore trains from scratch and
    overwrites ``propensity_beta{beta_norm}_l{l}.pt``.

    Training uses Adam (lr=1e-4, weight_decay=1e-4) on ``BCEWithLogitsLoss``
    for ``N_EPOCHS_PROP`` epochs. The mean validation loss is computed after
    every epoch. The weights from the epoch with the lowest validation loss
    are restored before saving.

    Args:
        beta_norm, l: dataset identifiers; used only for the file name and logs.
        train_dataset, val_dataset: datasets yielding ``(x, t, y, tau, e_hat)``;
            only ``x`` and ``t`` are used.
        img_size: ``(H, W)`` of the single-channel input images.

    Returns:
        The trained ``PropensityMLP`` on ``device``. It outputs logits.
    """
    H, W = img_size
    input_dim = 1 * H * W
    prop_path = os.path.join(PROP_DIR, f"propensity_beta{beta_norm}_l{l}.pt")

    prop_model = PropensityMLP(input_dim=input_dim).to(device)

    # if os.path.isfile(prop_path):
    #     print(f"[Propensity] Loading existing propensity model from {prop_path}")
    #     state = torch.load(prop_path, map_location=device)
    #     prop_model.load_state_dict(state)
    #     return prop_model

    print(f"[Propensity] Training new propensity model for beta={beta_norm}, l={l}")

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Lower LR (1e-4) for stability, plus weight decay for regularization.
    # NOTE(review): the module-level LR_PROP is not used here -- confirm intent.
    optimizer = torch.optim.Adam(prop_model.parameters(), lr=1e-4, weight_decay=1e-4)
    bce = nn.BCEWithLogitsLoss()

    best_val_loss = float("inf")
    best_state = None

    for epoch in range(N_EPOCHS_PROP):
        # --- Training pass ---
        prop_model.train()
        total_loss = 0.0
        n = 0
        for x, t, _y, _tau, _e in train_loader:
            x = x.to(device)
            t = t.to(device)

            loss = bce(prop_model(x), t)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * x.size(0)
            n += x.size(0)

        avg_train_loss = total_loss / n if n > 0 else 0.0

        # --- Validation pass ---
        prop_model.eval()
        total_val_loss = 0.0
        n_val = 0
        with torch.no_grad():
            for x, t, _y, _tau, _e in val_loader:
                x = x.to(device)
                t = t.to(device)
                loss = bce(prop_model(x), t)
                total_val_loss += loss.item() * x.size(0)
                n_val += x.size(0)

        # NaN never compares as "better". An empty validation set therefore
        # cannot pin the checkpoint to the first epoch via a fake 0.0 loss.
        avg_val_loss = total_val_loss / n_val if n_val > 0 else float("nan")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # state_dict() returns references to the live parameters. Clone
            # them; otherwise later optimizer steps overwrite this snapshot
            # and the "restore best" step below silently does nothing.
            best_state = {k: v.detach().clone() for k, v in prop_model.state_dict().items()}

        if (epoch + 1) % 10 == 0:
            print(
                f"[Propensity beta={beta_norm}, l={l}] "
                f"Epoch {epoch+1}/{N_EPOCHS_PROP}, "
                f"train_loss={avg_train_loss:.4f}, val_loss={avg_val_loss:.4f} "
                f"(Best: {best_val_loss:.4f})"
            )

    # Restore the weights with the lowest validation loss.
    if best_state is not None:
        prop_model.load_state_dict(best_state)
        print(f"[Propensity] Restored best model with val_loss={best_val_loss:.4f}")
    else:
        print("[Propensity] No validation loss available; keeping final weights.")

    torch.save(prop_model.state_dict(), prop_path)
    print(f"[Propensity] Saved propensity model to {prop_path}")
    return prop_model

def _predict_clipped_propensity(prop_model: nn.Module, X: torch.Tensor) -> torch.Tensor:
    """Return sigmoid(prop_model(X)) on the CPU, clipped to [0.05, 0.95]."""
    prop_model.eval()
    with torch.no_grad():
        scores = torch.sigmoid(prop_model(X.to(device))).cpu()
    # Clip so the scores fed to the outcome heads never reach exactly 0 or 1.
    return torch.clamp(scores, 0.05, 0.95)


def _train_tarnet_epoch(
    model: TARNet,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    use_propensity: bool,
    use_balancing: bool,
) -> float:
    """Run one optimisation epoch and return the sample-weighted mean loss (0.0 if empty).

    The loss is the factual MSE, plus ``LAMBDA_BAL * balancing_loss`` on the
    representation when ``use_balancing`` is True.
    """
    model.train()
    total_loss = 0.0
    seen = 0

    for x, t_batch, y_batch, _tau, e_hat_batch in loader:
        x = x.to(device)
        t_batch = t_batch.to(device)
        y_batch = y_batch.to(device)
        e_hat_batch = e_hat_batch.to(device)

        y0_hat, y1_hat, phi_x = model(x, e_hat=e_hat_batch if use_propensity else None)

        # Only the head matching the observed treatment is supervised.
        y_pred = torch.where(t_batch > 0.5, y1_hat, y0_hat)
        loss = criterion(y_pred, y_batch)
        if use_balancing:
            loss = loss + LAMBDA_BAL * balancing_loss(phi_x, t_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        seen += x.size(0)

    return total_loss / seen if seen > 0 else 0.0


def train_tarnet_single_model(
    beta_norm: float,
    l: int,
    model_type: str,
    run_id: int = 0,         # Added for Multi-Seed Support
    lr: float = 1e-3,        # Added to allow custom LR per model if needed
    seed_split: int = 2025,
) -> None:
    """Train one TARNet variant on one dataset and save its weights and metrics.

    Pipeline:
      1. Load the dataset for ``(beta_norm, l)``.
      2. Split it with ``train_val_test_split`` using a seed derived from
         ``seed_split``, ``beta_norm``, ``l`` and ``run_id``.
      3. For ``propensity*`` variants, fit a propensity model and feed its
         clipped scores to the outcome heads.
      4. Train with Adam (weight decay 1e-4), checkpointing on validation
         PEHE, with early stopping after ``PATIENCE`` epochs without
         improvement.
      5. Restore the best checkpoint and evaluate PEHE on all three splits.
      6. Save the model to ``MODEL_DIR`` and the metrics JSON to ``RESULT_DIR``.

    Args:
        beta_norm, l: dataset identifiers.
        model_type: one of "baseline", "baseline_bal", "propensity",
            "propensity_bal". "propensity*" adds the estimated propensity
            score as a head input; "*_bal" adds the RBF-MMD balancing loss
            weighted by ``LAMBDA_BAL``.
        run_id: run index. It shifts the split seed, and for ``run_id != 0``
            it is appended to the output file names.
        lr: Adam learning rate. NOTE(review): the module-level ``LR_TARNET``
            is not used as this default -- confirm which value is intended.
        seed_split: base seed for the data split.

    Raises:
        ValueError: if ``model_type`` is not one of the supported variants.
    """
    # Explicit check instead of `assert` (stripped under -O, which would
    # silently train an unknown type as a plain baseline).
    valid_types = ("baseline", "baseline_bal", "propensity", "propensity_bal")
    if model_type not in valid_types:
        raise ValueError(f"Unknown model_type {model_type!r}; expected one of {valid_types}")

    # Output files. run_id 0 keeps the historical names; other runs get a
    # suffix so multi-seed experiments do not overwrite each other.
    run_suffix = "" if run_id == 0 else f"_run{run_id}"
    model_filename = f"tarnet_{model_type}_beta{beta_norm}_l{l}{run_suffix}.pt"
    model_path = os.path.join(MODEL_DIR, model_filename)

    metrics_filename = f"metrics_{model_type}_beta{beta_norm}_l{l}{run_suffix}.json"
    metrics_path = os.path.join(RESULT_DIR, metrics_filename)

    # if os.path.isfile(model_path) and os.path.isfile(metrics_path):
    #     print(f"[SKIP] Run {run_id} already exists for {model_type}, beta={beta_norm}, l={l}")
    #     return

    print(f"\n[TRAIN] beta={beta_norm}, l={l}, type={model_type}, run={run_id}")

    # 1. Load data.
    X, t, y, tau = load_synthetic_dataset(beta_norm, l)
    _, _, H, W = X.shape

    # 2. Deterministic split; run_id shifts the seed so each run gets its own split.
    current_seed = seed_split + int(beta_norm * 100) + l + (run_id * 1000)

    (X_tr, t_tr, y_tr, tau_tr), \
    (X_val, t_val, y_val, tau_val), \
    (X_te, t_te, y_te, tau_te) = train_val_test_split(
        X, t, y, tau, seed=current_seed
    )

    # 3. Propensity scores (only for the propensity* variants).
    use_propensity = model_type.startswith("propensity")
    if use_propensity:
        prop_model = get_or_train_propensity_model(
            beta_norm=beta_norm,
            l=l,
            train_dataset=CausalDataset(X_tr, t_tr, y_tr, tau_tr, e_hat=None),
            val_dataset=CausalDataset(X_val, t_val, y_val, tau_val, e_hat=None),
            img_size=(H, W),
        ).to(device)
        e_tr = _predict_clipped_propensity(prop_model, X_tr)
        e_val = _predict_clipped_propensity(prop_model, X_val)
        e_te = _predict_clipped_propensity(prop_model, X_te)
    else:
        e_tr = e_val = e_te = None

    # 4. Data loaders.
    train_ds = CausalDataset(X_tr, t_tr, y_tr, tau_tr, e_hat=e_tr)
    val_ds   = CausalDataset(X_val, t_val, y_val, tau_val, e_hat=e_val)
    test_ds  = CausalDataset(X_te, t_te, y_te, tau_te, e_hat=e_te)

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
    test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)

    # 5. Model and optimizer (weight decay = L2 regularization).
    model = TARNet(rep_dim=64, use_propensity=use_propensity).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    mse = nn.MSELoss()
    use_balancing = model_type.endswith("_bal")

    # 6. Training loop with optional burn-in and early stopping.
    # NOTE(review): checkpoint selection uses validation PEHE, which needs the
    # ground-truth tau. That is only available for (semi-)synthetic data.
    best_val_pehe = float("inf")
    best_state = None
    best_epoch = None
    triggers = 0
    epochs_run = 0

    # Epochs with index < BURN_IN_EPOCHS never become the checkpoint and never
    # count towards early stopping. 0 means burn-in is disabled.
    BURN_IN_EPOCHS = 0

    for epoch in range(N_EPOCHS_TARNET):
        avg_train_loss = _train_tarnet_epoch(
            model, train_loader, optimizer, mse, use_propensity, use_balancing
        )
        epochs_run = epoch + 1

        # --- Validation & checkpointing ---
        val_pehe = compute_pehe(model, val_loader, use_propensity=use_propensity)
        val_mse_f = compute_mse_factual(model, val_loader, use_propensity=use_propensity)

        if epoch >= BURN_IN_EPOCHS:
            if val_pehe < best_val_pehe:
                best_val_pehe = val_pehe
                best_epoch = epoch + 1
                # state_dict() returns references to the live parameters. Clone
                # them; otherwise later optimizer steps overwrite the "best"
                # snapshot and the restore below becomes a no-op.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                triggers = 0
            else:
                triggers += 1
        else:
            # During burn-in, never count towards early stopping.
            triggers = 0

        if (epoch + 1) % 50 == 0 or triggers >= PATIENCE:
            status = "(Burn-in)" if epoch < BURN_IN_EPOCHS else f"(Best: {best_val_pehe:.4f})"
            print(
                f"[{model_type} run={run_id}] Ep {epoch+1} | "
                f"Loss: {avg_train_loss:.4f} | MSE: {val_mse_f:.4f} | "
                f"PEHE: {val_pehe:.4f} {status}"
            )

        if triggers >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # 7. Final evaluation with the best checkpoint.
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print("[WARNING] No best state found (did not pass burn-in?). Using last state.")

    train_pehe = compute_pehe(model, train_loader, use_propensity=use_propensity)
    val_pehe_final = compute_pehe(model, val_loader, use_propensity=use_propensity)
    test_pehe = compute_pehe(model, test_loader, use_propensity=use_propensity)

    metrics = {
        "beta_norm": beta_norm,
        "l": l,
        "model_type": model_type,
        "run_id": run_id,
        "train_PEHE": train_pehe,
        "val_PEHE": val_pehe_final,
        "best_val_PEHE": best_val_pehe,  # tracked minimum during training
        "test_PEHE": test_pehe,
        "epochs": N_EPOCHS_TARNET,       # configured maximum
        "epochs_trained": epochs_run,    # actual epochs before stopping
        "best_epoch": best_epoch,        # 1-based epoch of the restored checkpoint
        "lambda_bal": LAMBDA_BAL,
        "lr": lr
    }

    torch.save(model.state_dict(), model_path)
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)

    print(f"[SAVE] Run {run_id} complete. Test PEHE: {test_pehe:.4f}")


# ============================================================
# Main loops
# ============================================================

if __name__ == "__main__":
    beta_list = [0.5, 1.0, 2.0, 4.0]
    l_list = [0, 1, 2, 4, 8]
    # "propensity_bal" is currently left out of the sweep.
    model_types = ["baseline", "baseline_bal", "propensity"]

    # Earlier ad-hoc variant: run the three model types for beta_norm=4.0 only
    # (i.e. restrict beta_list to [4.0]).

    # Full grid: every beta, every l, every model type (same order as nested loops).
    sweep = [
        (beta_norm, l, mtype)
        for beta_norm in beta_list
        for l in l_list
        for mtype in model_types
    ]
    for beta_norm, l, mtype in sweep:
        train_tarnet_single_model(beta_norm=beta_norm, l=l, model_type=mtype)


Using device: cuda

[TRAIN] beta=0.5, l=0, type=baseline, run=0
[baseline run=0] Ep 50 | Loss: 0.0124 | MSE: 0.0101 | PEHE: 0.0220 (Best: 0.0217)
[baseline run=0] Ep 100 | Loss: 0.0115 | MSE: 0.0102 | PEHE: 0.0307 (Best: 0.0215)
[baseline run=0] Ep 150 | Loss: 0.0111 | MSE: 0.0101 | PEHE: 0.0215 (Best: 0.0214)
[baseline run=0] Ep 168 | Loss: 0.0110 | MSE: 0.0100 | PEHE: 0.0216 (Best: 0.0214)
Early stopping at epoch 168
[SAVE] Run 0 complete. Test PEHE: 0.0228

[TRAIN] beta=0.5, l=0, type=baseline_bal, run=0
[baseline_bal run=0] Ep 50 | Loss: 0.1298 | MSE: 0.0101 | PEHE: 0.0270 (Best: 0.0215)
[baseline_bal run=0] Ep 75 | Loss: 0.1605 | MSE: 0.0101 | PEHE: 0.0311 (Best: 0.0215)
Early stopping at epoch 75
[SAVE] Run 0 complete. Test PEHE: 0.0324

[TRAIN] beta=0.5, l=0, type=propensity, run=0
[Propensity] Training new propensity model for beta=0.5, l=0
[Propensity beta=0.5, l=0] Epoch 10/50, train_loss=0.4606, val_loss=0.4443 (Best: 0.4439)
[Propensity beta=0.5, l=0] Epoch 20/50, train_los