In [1]:
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import json
from typing import Tuple, Optional, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ============================================================
# Paths and basic configuration
# ============================================================

ROOT = "/content/drive/MyDrive/ECE 685/Project"

USE_SEMI_SYNTH = False

if USE_SEMI_SYNTH:
    DATA_DIR   = os.path.join(ROOT, "experiments_semi_synthetic", "data")
    MODEL_DIR  = os.path.join(ROOT, "experiments_semi_synthetic", "models")
    RESULT_DIR = os.path.join(ROOT, "experiments_semi_synthetic", "results")
    FNAME_FMT  = "semi_beta{beta}_l{l}.pt"
else:
    DATA_DIR   = os.path.join(ROOT, "experiments_synthetic", "data")
    MODEL_DIR  = os.path.join(ROOT, "experiments_synthetic", "models")
    RESULT_DIR = os.path.join(ROOT, "experiments_synthetic", "results")
    FNAME_FMT  = "synthetic_beta{beta}_l{l}.pt"

os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(RESULT_DIR, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Training hyperparameters
BATCH_SIZE = 256
LR_DRAGON = 1e-3
N_EPOCHS = 1000
PATIENCE = 50
WEIGHT_DECAY = 1e-4
TARGET_REG_ALPHA = 1.0  # Coefficient for targeted regularization loss (t-reg)

# ============================================================
# Utilities: loading synthetic datasets
# ============================================================

def dataset_path(beta_norm: float, l: int) -> str:
    return os.path.join(DATA_DIR, FNAME_FMT.format(beta=beta_norm, l=l))

def load_synthetic_dataset(beta_norm: float, l: int):
    path = dataset_path(beta_norm, l)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"Dataset not found: {path}")
    data = torch.load(path, map_location="cpu")

    X = data["X"].float()  # (N, 1, H, W)
    t = data["T"].float().view(-1, 1)  # (N, 1)
    y = data["Y"].float().view(-1, 1)  # (N, 1)

    if "tau" in data:
        tau = data["tau"].float().view(-1, 1)
    else:
        mu0 = data["mu0"].float().view(-1, 1)
        mu1 = data["mu1"].float().view(-1, 1)
        tau = mu1 - mu0

    return X, t, y, tau

def train_val_test_split(
    X: torch.Tensor,
    t: torch.Tensor,
    y: torch.Tensor,
    tau: torch.Tensor,
    train_frac: float = 0.6,
    val_frac: float = 0.2,
    seed: int = 1234,
):
    N = X.shape[0]
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(N, generator=g)

    n_train = int(train_frac * N)
    n_val = int(val_frac * N)

    idx_train = perm[:n_train]
    idx_val = perm[n_train:n_train + n_val]
    idx_test = perm[n_train + n_val:]

    data_train = (X[idx_train], t[idx_train], y[idx_train], tau[idx_train])
    data_val   = (X[idx_val],   t[idx_val],   y[idx_val],   tau[idx_val])
    data_test  = (X[idx_test],  t[idx_test],  y[idx_test],  tau[idx_test])

    return data_train, data_val, data_test

# ============================================================
# Dataset object
# ============================================================

class CausalDataset(Dataset):
    def __init__(self, X, t, y, tau):
        self.X = X
        self.t = t
        self.y = y
        self.tau = tau

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.t[idx], self.y[idx], self.tau[idx]

# ============================================================
# Dragonnet Model
# ============================================================

class ConvRepresentation(nn.Module):
    def __init__(self, rep_dim: int = 64):
        super().__init__()
        # Matches your TARNet encoder capacity
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(32, rep_dim),
            nn.ReLU(),
        )

    def forward(self, x):
        h = self.features(x)
        h = self.fc(h)
        return h

class Dragonnet(nn.Module):
    """
    Dragonnet: A 3-headed architecture.
    - Shared Representation Z
    - Propensity Head g(Z) -> predicts T
    - Outcome Head Q0(Z) -> predicts Y | T=0
    - Outcome Head Q1(Z) -> predicts Y | T=1
    """
    def __init__(self, rep_dim: int = 64):
        super().__init__()
        self.rep = ConvRepresentation(rep_dim=rep_dim)

        # Propensity Head (g)
        self.propensity_head = nn.Sequential(
            nn.Linear(rep_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1)
        )

        # Outcome Head 0 (Q0)
        self.head0 = nn.Sequential(
            nn.Linear(rep_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1)
        )

        # Outcome Head 1 (Q1)
        self.head1 = nn.Sequential(
            nn.Linear(rep_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        z = self.rep(x)

        # Propensity logits
        g_logits = self.propensity_head(z)

        # Outcome predictions (Hypothesis)
        y0 = self.head0(z)
        y1 = self.head1(z)

        return y0, y1, g_logits

# ============================================================
# Losses and Metrics
# ============================================================

def dragonnet_loss(y0, y1, g_logits, t, y):
    """
    Loss = L_outcome + L_propensity
    L_outcome = sum((y - y_pred)^2)
    L_propensity = Binary Cross Entropy
    """
    # 1. Outcome Loss (Standard MSE on factuals)
    y_pred = torch.where(t > 0.5, y1, y0)
    loss_y = F.mse_loss(y_pred, y)

    # 2. Propensity Loss
    loss_g = F.binary_cross_entropy_with_logits(g_logits, t)

    return loss_y, loss_g

def targeted_regularization(y0, y1, g_logits, t, y):
    """
    Targeted Regularization (t-reg) specifically for Dragonnet.
    Based on Shi et al. (2019).
    """
    g = torch.sigmoid(g_logits)
    # Epsilon calculation (perturbation parameter)
    # epsilon = (y - Q(t,x)) / (t - g(x)) ??? No, standard implementation uses specific closed form or update

    # Standard implementation often does this step separately or adds a specific term.
    # Here we implement the basic loss structure from the paper:
    # L = L_y + alpha * L_g + beta * L_treg

    # For this simplified implementation, we stick to the primary Dragonnet loss:
    # L = MSE(y) + BCE(t)
    # The 'targeted' part is often an update step, but basic Dragonnet works with just the multi-task loss.
    # We will return 0 here unless full TMLE logic is needed.
    return torch.tensor(0.0).to(y.device)

def compute_pehe(model: Dragonnet, dataloader: DataLoader) -> float:
    model.eval()
    sq_errors = []
    with torch.no_grad():
        for x, t, y, tau in dataloader:
            x = x.to(device)
            tau = tau.to(device)
            y0, y1, _ = model(x)
            tau_hat = y1 - y0
            sq_errors.append((tau_hat - tau) ** 2)

    if len(sq_errors) == 0: return float("nan")
    sq_errors = torch.cat(sq_errors, dim=0)
    return torch.sqrt(sq_errors.mean()).item()

# ============================================================
# Training Function
# ============================================================

def train_dragonnet_single_model(
    beta_norm: float,
    l: int,
    run_id: int = 0,
    lr: float = 1e-3,
    seed_split: int = 2025,
) -> None:

    model_filename = f"dragonnet_beta{beta_norm}_l{l}.pt"
    model_path = os.path.join(MODEL_DIR, model_filename)
    metrics_filename = f"metrics_dragonnet_beta{beta_norm}_l{l}.json"
    metrics_path = os.path.join(RESULT_DIR, metrics_filename)

    if os.path.isfile(model_path) and os.path.isfile(metrics_path):
        print(f"[SKIP] Dragonnet run {run_id} already exists for beta={beta_norm}, l={l}")
        return

    print(f"\n[TRAIN DRAGONNET] beta={beta_norm}, l={l}, run={run_id}")

    # 1. Load Data
    X, t, y, tau = load_synthetic_dataset(beta_norm, l)

    # 2. Split
    current_seed = seed_split + int(beta_norm * 100) + l + (run_id * 1000)
    (X_tr, t_tr, y_tr, tau_tr), (X_val, t_val, y_val, tau_val), (X_te, t_te, y_te, tau_te) = \
        train_val_test_split(X, t, y, tau, seed=current_seed)

    train_loader = DataLoader(CausalDataset(X_tr, t_tr, y_tr, tau_tr), batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(CausalDataset(X_val, t_val, y_val, tau_val), batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(CausalDataset(X_te, t_te, y_te, tau_te), batch_size=BATCH_SIZE, shuffle=False)

    # 3. Model & Optimizer
    model = Dragonnet(rep_dim=64).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)

    # 4. Training Loop
    best_val_pehe = float("inf")
    best_state = None
    triggers = 0
    BURN_IN = 1

    for epoch in range(N_EPOCHS):
        model.train()
        total_loss = 0.0
        total_loss_y = 0.0
        total_loss_g = 0.0
        n_train = 0

        for x, t, y, tau in train_loader:
            x, t, y = x.to(device), t.to(device), y.to(device)

            y0, y1, g_logits = model(x)

            loss_y, loss_g = dragonnet_loss(y0, y1, g_logits, t, y)

            # Dragonnet Objective: Minimize both outcome error and propensity error
            # This forces the representation Z to be predictive of BOTH Y and T.
            # (Standard alpha=1.0)
            loss = loss_y + 1.0 * loss_g

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * x.size(0)
            total_loss_y += loss_y.item() * x.size(0)
            total_loss_g += loss_g.item() * x.size(0)
            n_train += x.size(0)

        # Validation
        val_pehe = compute_pehe(model, val_loader)

        # Checkpointing
        if epoch >= BURN_IN:
            if val_pehe < best_val_pehe:
                best_val_pehe = val_pehe
                best_state = model.state_dict()
                triggers = 0
            else:
                triggers += 1
        else:
            triggers = 0

        if (epoch + 1) % 50 == 0 or triggers >= PATIENCE:
            status = "(Burn-in)" if epoch < BURN_IN else f"(Best: {best_val_pehe:.4f})"
            print(f"[Dragonnet] Ep {epoch+1} | Loss: {total_loss/n_train:.4f} (Y:{total_loss_y/n_train:.2f}, G:{total_loss_g/n_train:.2f}) | PEHE: {val_pehe:.4f} {status}")

        if triggers >= PATIENCE:
            print(f"Early stopping at epoch {epoch+1}")
            break

    # 5. Save Results
    if best_state is not None:
        model.load_state_dict(best_state)

    test_pehe = compute_pehe(model, test_loader)

    metrics = {
        "beta_norm": beta_norm,
        "l": l,
        "model_type": "dragonnet",
        "run_id": run_id,
        "val_PEHE": best_val_pehe, # Save best valid
        "test_PEHE": test_pehe,
        "epochs": N_EPOCHS,
    }

    torch.save(model.state_dict(), model_path)
    with open(metrics_path, "w") as f:
        json.dump(metrics, f, indent=2)

    print(f"[SAVE] Run {run_id} complete. Test PEHE: {test_pehe:.4f}")

# ============================================================
# Main Execution
# ============================================================

if __name__ == "__main__":
    beta_list = [0.5, 1.0, 2.0, 4.0]
    l_list = [0, 1, 2, 4, 8]
    N_RUNS = 1

    for beta in beta_list:
        for l in l_list:
            train_dragonnet_single_model(beta, l, run_id=0)

Using device: cuda

[TRAIN DRAGONNET] beta=0.5, l=0, run=0
[Dragonnet] Ep 50 | Loss: 0.7072 (Y:0.02, G:0.69) | PEHE: 0.0091 (Best: 0.0047)
[Dragonnet] Ep 100 | Loss: 0.7049 (Y:0.01, G:0.69) | PEHE: 0.0380 (Best: 0.0039)
[Dragonnet] Ep 150 | Loss: 0.7032 (Y:0.01, G:0.69) | PEHE: 0.0038 (Best: 0.0038)
[Dragonnet] Ep 200 | Loss: 0.7017 (Y:0.01, G:0.69) | PEHE: 0.0038 (Best: 0.0038)
Early stopping at epoch 200
[SAVE] Run 0 complete. Test PEHE: 0.0037

[TRAIN DRAGONNET] beta=0.5, l=1, run=0
[Dragonnet] Ep 50 | Loss: 1.6678 (Y:0.98, G:0.69) | PEHE: 0.4432 (Best: 0.1830)
[Dragonnet] Ep 52 | Loss: 1.6671 (Y:0.98, G:0.69) | PEHE: 0.3750 (Best: 0.1830)
Early stopping at epoch 52
[SAVE] Run 0 complete. Test PEHE: 0.3749

[TRAIN DRAGONNET] beta=0.5, l=2, run=0
[Dragonnet] Ep 50 | Loss: 2.4979 (Y:1.81, G:0.69) | PEHE: 0.7944 (Best: 0.5290)
[Dragonnet] Ep 81 | Loss: 2.5002 (Y:1.81, G:0.69) | PEHE: 0.7578 (Best: 0.5290)
Early stopping at epoch 81
[SAVE] Run 0 complete. Test PEHE: 0.7580

[TRAIN DRAGO