# Exp. 1:
Political Campaigning - Estimate how regional campaign spending affects subregional election outcomes when effectiveness depends on subregional wealth levels. Only aggregated regional outcomes are observed, and the goal is to recover the heterogeneous local causal effects.

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn

In [None]:
# Simulation configuration for Experiment 1 (10x10 regions, each with 100 subregions).
data_dir = 'data_exp1'
num_regions = 100
subregions_per_region = 100
# Max +/- shift applied to the poor/rich split of each region in the context step below.
SPATIAL_VARIANCE = 5
# Variance of the iid Gaussian subregion noise (its sqrt is used as the std below).
noise_variance = 0.01
# Global NumPy seed: every np.random call in this cell, and their order, fixes the dataset.
np.random.seed(42)

os.makedirs(data_dir, exist_ok=True)

# 1) Intervention matrix: 100x100, each row all 0s or all 1s with 50% probability
# One treatment flag is drawn per region and broadcast to all of its subregions.
interventions_region_flags = np.random.choice([0, 1], size=num_regions, p=[0.5, 0.5])
interventions_sub = np.repeat(interventions_region_flags[:, np.newaxis], subregions_per_region, axis=1)
pd.DataFrame(interventions_sub).to_csv(os.path.join(data_dir, 'interventions_subregion.csv'), index=False, header=False)

# Region-level interventions: average per row (which is 0 or 1 since all values same)
# np.mean returns floats, so the saved CSV contains 0.0 / 1.0.
interventions_reg = np.mean(interventions_sub, axis=1)
pd.DataFrame(interventions_reg).to_csv(os.path.join(data_dir, 'interventions_region.csv'), index=False, header=False)

# 2) Context matrix: 100x100, wealth levels 1,2,3
# Wealth codes: 1 = poor, 2 = middle class, 3 = rich.
# The np.random calls below draw from the globally seeded stream, so their order
# determines the generated dataset. Rows 0 and 1 are hand-crafted special cases.
context_sub = np.zeros((num_regions, subregions_per_region), dtype=int)
for i in range(num_regions):
    if i == 0:
        # First row: all 2 (middle class)
        context_sub[i, :] = 2
    elif i == 1:
        # Second row: exactly half 1 (poor) and half 3 (rich)
        # With 100 subregions its mean wealth is 2.0, the same as row 0.
        half = subregions_per_region // 2
        arr = np.concatenate([np.ones(half, dtype=int), np.full(subregions_per_region - half, 3, dtype=int)])
        np.random.shuffle(arr)
        context_sub[i, :] = arr
    else:
        # For other rows: choose x3 middle class between 20 and 100
        # (randint's upper bound is exclusive, so x3 is in 20..100 inclusive).
        x3 = np.random.randint(20, 101)
        remaining = subregions_per_region - x3
        # Split remaining roughly 50:50 into poor (1) and rich (3), with wiggle room
        # of up to +/- SPATIAL_VARIANCE; num_poor is clamped to [0, remaining].
        deviation = np.random.randint(-SPATIAL_VARIANCE, SPATIAL_VARIANCE + 1)
        num_poor = max(0, min(remaining, remaining // 2 + deviation))
        num_rich = remaining - num_poor
        # Create and shuffle array
        # (shuffling scatters the wealth levels across the region's subregion slots).
        arr = np.concatenate([np.ones(num_poor, dtype=int), np.full(x3, 2, dtype=int), np.full(num_rich, 3, dtype=int)])
        np.random.shuffle(arr)
        context_sub[i, :] = arr
pd.DataFrame(context_sub).to_csv(os.path.join(data_dir, 'context_subregion.csv'), index=False, header=False)

# Region-level context: mean per row
context_reg = np.mean(context_sub, axis=1)
pd.DataFrame(context_reg).to_csv(os.path.join(data_dir, 'context_region.csv'), index=False, header=False)

# 3) Subregion noise: iid Gaussian with standard deviation sqrt(noise_variance).
noise_sub = np.random.normal(0, np.sqrt(noise_variance), size=(num_regions, subregions_per_region))
pd.DataFrame(noise_sub).to_csv(os.path.join(data_dir, 'noise_subregion.csv'), index=False, header=False)

# 4) Ground-truth treatment effect per wealth level: poor -0.1, middle 0, rich +0.3.
effect_by_wealth = {1: -0.1, 2: 0.0, 3: 0.3}
delta = np.zeros((num_regions, subregions_per_region), dtype=float)
for wealth_level, effect in effect_by_wealth.items():
    delta[context_sub == wealth_level] = effect

# Observed subregion outcome: noise plus the effect wherever the region is treated.
outcome_sub = noise_sub + interventions_sub * delta
pd.DataFrame(outcome_sub).to_csv(os.path.join(data_dir, 'outcome_subregion.csv'), index=False, header=False)

# Region-level outcome is the per-region average over its subregions.
outcome_reg = outcome_sub.mean(axis=1)
pd.DataFrame(outcome_reg).to_csv(os.path.join(data_dir, 'outcome_region.csv'), index=False, header=False)

# 5) The true causal effect of treating a subregion is exactly delta.
causal_effect_sub = delta
pd.DataFrame(causal_effect_sub).to_csv(os.path.join(data_dir, 'causal_effect_subregion.csv'), index=False, header=False)

In [None]:
data_dir = 'data_exp1'
region_grid_size = 10       # regions per side of the region map
sub_grid_size = 100         # cells per side of the full subregion map
sub_per_region_side = 10    # subregions per side inside one region

# One light (white -> hue) colormap per quantity, all taken from the "deep" palette:
# red = interventions, blue = outcomes, green = context,
# purple = causal effects, orange = noise.
colors = sns.color_palette("deep")
red_cmap, blue_cmap, green_cmap, purple_cmap, orange_cmap = (
    sns.light_palette(colors[palette_idx], as_cmap=True) for palette_idx in (3, 0, 2, 4, 1)
)

# Function to reshape subregion data (100x100) to 100x100 grid
def reshape_to_subgrid(data, region_side=None, sub_side=None):
    """Tile per-region subregion rows into one square spatial map.

    Row ``k`` of *data* holds the ``sub_side**2`` subregion values of region
    ``k``. Regions are placed row-major on a ``region_side x region_side``
    grid, and each region's values are placed row-major inside its own
    ``sub_side x sub_side`` tile.

    Args:
        data: array-like of shape (region_side**2, sub_side**2).
        region_side: regions per grid side. Defaults to the module-level
            ``region_grid_size``, which is looked up at call time.
        sub_side: subregions per region side. Defaults to the module-level
            ``sub_per_region_side``.

    Returns:
        float ndarray of shape (region_side * sub_side, region_side * sub_side).
    """
    if region_side is None:
        region_side = region_grid_size
    if sub_side is None:
        sub_side = sub_per_region_side
    blocks = np.asarray(data, dtype=float).reshape(region_side, region_side, sub_side, sub_side)
    # (region_row, region_col, sub_row, sub_col) -> (region_row, sub_row, region_col, sub_col)
    return blocks.transpose(0, 2, 1, 3).reshape(region_side * sub_side, region_side * sub_side)

# Function to reshape region data (100,) to 10x10 grid
def reshape_to_reggrid(data, region_side=None):
    """Arrange a flat per-region vector on the square region grid (row-major).

    Args:
        data: array-like with ``region_side**2`` elements, one per region.
        region_side: regions per grid side. Defaults to the module-level
            ``region_grid_size``, which is looked up at call time.

    Returns:
        ndarray of shape (region_side, region_side).
    """
    if region_side is None:
        region_side = region_grid_size
    return np.asarray(data).reshape(region_side, region_side)

# Independent plotting functions

def plot_interventions_subregion():
    """Draw interventions_subregion.csv as a full-resolution subregion heatmap.

    Reads the CSV from ``data_dir``, tiles it with ``reshape_to_subgrid``,
    saves the figure as JPG (300 dpi) and PDF in ``data_dir``, then shows it.
    """
    values = pd.read_csv(os.path.join(data_dir, 'interventions_subregion.csv'), header=None).values
    tiled = reshape_to_subgrid(values)
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(
        tiled, ax=ax, cmap=red_cmap, square=True, linewidths=0,
        cbar=True, cbar_kws={'location': 'right'},
    )
    ax.set_title('Interventions Subregion', fontsize=20)
    ax.set(xticks=[], yticks=[])
    ax.collections[0].colorbar.ax.tick_params(labelsize=14)
    fig.savefig(os.path.join(data_dir, 'interventions_subregion.jpg'), dpi=300, bbox_inches='tight')
    fig.savefig(os.path.join(data_dir, 'interventions_subregion.pdf'), bbox_inches='tight')
    plt.show()

def plot_interventions_region():
    """Draw interventions_region.csv as a region-level heatmap.

    Reads one value per region from ``data_dir``, lays the values out with
    ``reshape_to_reggrid``, saves JPG (300 dpi) and PDF copies, then shows it.
    """
    values = pd.read_csv(os.path.join(data_dir, 'interventions_region.csv'), header=None).values.squeeze()
    region_map = reshape_to_reggrid(values)
    fig, ax = plt.subplots(figsize=(5, 5))
    sns.heatmap(
        region_map, ax=ax, cmap=red_cmap, square=True, linewidths=1,
        cbar=True, cbar_kws={'location': 'right'},
    )
    ax.set_title('Interventions Region', fontsize=20)
    ax.set(xticks=[], yticks=[])
    ax.collections[0].colorbar.ax.tick_params(labelsize=14)
    fig.savefig(os.path.join(data_dir, 'interventions_region.jpg'), dpi=300, bbox_inches='tight')
    fig.savefig(os.path.join(data_dir, 'interventions_region.pdf'), bbox_inches='tight')
    plt.show()

def plot_context_subregion():
    """Draw context_subregion.csv (wealth levels 1..3) as a subregion heatmap.

    The colour scale is pinned to [1, 3]. The figure is saved as JPG
    (300 dpi) and PDF in ``data_dir`` and then shown.
    """
    values = pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values
    tiled = reshape_to_subgrid(values)
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(
        tiled, ax=ax, cmap=green_cmap, square=True, linewidths=0,
        cbar=True, cbar_kws={'location': 'right'}, vmin=1, vmax=3,
    )
    ax.set_title('Context Subregion', fontsize=20)
    ax.set(xticks=[], yticks=[])
    ax.collections[0].colorbar.ax.tick_params(labelsize=14)
    fig.savefig(os.path.join(data_dir, 'context_subregion.jpg'), dpi=300, bbox_inches='tight')
    fig.savefig(os.path.join(data_dir, 'context_subregion.pdf'), bbox_inches='tight')
    plt.show()

def plot_context_region():
    """Draw context_region.csv (mean wealth per region) as a region heatmap.

    The colour scale is pinned to [1.8, 2.2]. The figure is saved as JPG
    (300 dpi) and PDF in ``data_dir`` and then shown.
    """
    values = pd.read_csv(os.path.join(data_dir, 'context_region.csv'), header=None).values.squeeze()
    region_map = reshape_to_reggrid(values)
    fig, ax = plt.subplots(figsize=(5, 5))
    sns.heatmap(
        region_map, ax=ax, cmap=green_cmap, square=True, linewidths=1,
        cbar=True, cbar_kws={'location': 'right'}, vmin=1.8, vmax=2.2,
    )
    ax.set_title('Context Region', fontsize=20)
    ax.set(xticks=[], yticks=[])
    ax.collections[0].colorbar.ax.tick_params(labelsize=14)
    fig.savefig(os.path.join(data_dir, 'context_region.jpg'), dpi=300, bbox_inches='tight')
    fig.savefig(os.path.join(data_dir, 'context_region.pdf'), bbox_inches='tight')
    plt.show()

def plot_noise_subregion():
    """Draw noise_subregion.csv as a subregion heatmap with an automatic colour range.

    The figure is saved as JPG (300 dpi) and PDF in ``data_dir`` and then shown.
    """
    values = pd.read_csv(os.path.join(data_dir, 'noise_subregion.csv'), header=None).values
    tiled = reshape_to_subgrid(values)
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(
        tiled, ax=ax, cmap=orange_cmap, square=True, linewidths=0,
        cbar=True, cbar_kws={'location': 'right'},
    )
    ax.set_title('Noise Subregion', fontsize=20)
    ax.set(xticks=[], yticks=[])
    ax.collections[0].colorbar.ax.tick_params(labelsize=14)
    fig.savefig(os.path.join(data_dir, 'noise_subregion.jpg'), dpi=300, bbox_inches='tight')
    fig.savefig(os.path.join(data_dir, 'noise_subregion.pdf'), bbox_inches='tight')
    plt.show()

def plot_outcome_subregion():
    """Draw outcome_subregion.csv as a subregion heatmap (colour range [-0.1, 0.5]).

    The figure is saved as JPG (300 dpi) and PDF in ``data_dir`` and then shown.
    """
    values = pd.read_csv(os.path.join(data_dir, 'outcome_subregion.csv'), header=None).values
    tiled = reshape_to_subgrid(values)
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(
        tiled, ax=ax, cmap=blue_cmap, square=True, linewidths=0,
        cbar=True, cbar_kws={'location': 'right'}, vmin=-0.1, vmax=0.5,
    )
    ax.set_title('Outcome Subregion', fontsize=20)
    ax.set(xticks=[], yticks=[])
    ax.collections[0].colorbar.ax.tick_params(labelsize=14)
    fig.savefig(os.path.join(data_dir, 'outcome_subregion.jpg'), dpi=300, bbox_inches='tight')
    fig.savefig(os.path.join(data_dir, 'outcome_subregion.pdf'), bbox_inches='tight')
    plt.show()

def plot_outcome_region():
    """Draw outcome_region.csv as a region heatmap (colour range [-0.05, 0.1]).

    The figure is saved as JPG (300 dpi) and PDF in ``data_dir`` and then shown.
    """
    values = pd.read_csv(os.path.join(data_dir, 'outcome_region.csv'), header=None).values.squeeze()
    region_map = reshape_to_reggrid(values)
    fig, ax = plt.subplots(figsize=(5, 5))
    sns.heatmap(
        region_map, ax=ax, cmap=blue_cmap, square=True, linewidths=1,
        cbar=True, cbar_kws={'location': 'right'}, vmin=-0.05, vmax=0.1,
    )
    ax.set_title('Outcome Region', fontsize=20)
    ax.set(xticks=[], yticks=[])
    ax.collections[0].colorbar.ax.tick_params(labelsize=14)
    fig.savefig(os.path.join(data_dir, 'outcome_region.jpg'), dpi=300, bbox_inches='tight')
    fig.savefig(os.path.join(data_dir, 'outcome_region.pdf'), bbox_inches='tight')
    plt.show()

def plot_causal_effect_subregion():
    """Draw the ground-truth causal_effect_subregion.csv as a subregion heatmap.

    The colour range [-0.1, 0.3] spans the simulated effects. The figure is
    saved as JPG (300 dpi) and PDF in ``data_dir`` and then shown.
    """
    values = pd.read_csv(os.path.join(data_dir, 'causal_effect_subregion.csv'), header=None).values
    tiled = reshape_to_subgrid(values)
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(
        tiled, ax=ax, cmap=purple_cmap, square=True, linewidths=0,
        cbar=True, cbar_kws={'location': 'right'}, vmin=-0.1, vmax=0.3,
    )
    ax.set_title('Causal Effect Subregion', fontsize=20)
    ax.set(xticks=[], yticks=[])
    ax.collections[0].colorbar.ax.tick_params(labelsize=14)
    fig.savefig(os.path.join(data_dir, 'causal_effect_subregion.jpg'), dpi=300, bbox_inches='tight')
    fig.savefig(os.path.join(data_dir, 'causal_effect_subregion.pdf'), bbox_inches='tight')
    plt.show()

# Render every diagnostic figure, in a fixed order.
for plot_fn in (
    plot_interventions_subregion,
    plot_interventions_region,
    plot_context_subregion,
    plot_context_region,
    plot_noise_subregion,
    plot_outcome_subregion,
    plot_outcome_region,
    plot_causal_effect_subregion,
):
    plot_fn()

In [None]:
# Define parameters
data_dir = 'data_exp1'
region_grid_size, sub_grid_size, sub_per_region_side = 10, 100, 10

# Light colormaps from the "deep" palette: red = interventions,
# blue = outcomes, green = context.
colors = sns.color_palette("deep")
red_cmap, blue_cmap, green_cmap = (
    sns.light_palette(colors[palette_idx], as_cmap=True) for palette_idx in (3, 0, 2)
)

# Function to reshape region data (100,) to 10x10 grid
def reshape_to_reggrid(data, region_side=None):
    """Arrange a flat per-region vector on the square region grid (row-major).

    Args:
        data: array-like with ``region_side**2`` elements, one per region.
        region_side: regions per grid side. Defaults to the module-level
            ``region_grid_size``, which is looked up at call time.

    Returns:
        ndarray of shape (region_side, region_side).
    """
    if region_side is None:
        region_side = region_grid_size
    return np.asarray(data).reshape(region_side, region_side)

# Function to reshape subregion data (100x100) to 100x100 grid
def reshape_to_subgrid(data, region_side=None, sub_side=None):
    """Tile per-region subregion rows into one square spatial map.

    Row ``k`` of *data* holds the ``sub_side**2`` subregion values of region
    ``k``. Regions are placed row-major on a ``region_side x region_side``
    grid, and each region's values are placed row-major inside its own
    ``sub_side x sub_side`` tile.

    Args:
        data: array-like of shape (region_side**2, sub_side**2).
        region_side: regions per grid side. Defaults to the module-level
            ``region_grid_size``, which is looked up at call time.
        sub_side: subregions per region side. Defaults to the module-level
            ``sub_per_region_side``.

    Returns:
        float ndarray of shape (region_side * sub_side, region_side * sub_side).
    """
    if region_side is None:
        region_side = region_grid_size
    if sub_side is None:
        sub_side = sub_per_region_side
    blocks = np.asarray(data, dtype=float).reshape(region_side, region_side, sub_side, sub_side)
    # (region_row, region_col, sub_row, sub_col) -> (region_row, sub_row, region_col, sub_col)
    return blocks.transpose(0, 2, 1, 3).reshape(region_side * sub_side, region_side * sub_side)

# Load the three panels: region interventions, region outcomes, subregion wealth.
interv_data = pd.read_csv(os.path.join(data_dir, 'interventions_region.csv'), header=None).values.squeeze()
outcome_data = pd.read_csv(os.path.join(data_dir, 'outcome_region.csv'), header=None).values.squeeze()
context_data = pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values
context_data_norm = (context_data - 1) / 2.0  # wealth 1..3 -> 0..1 to match the [0, 1] colour range

space_between = 0.05  # horizontal gap (wspace) between the three panels

fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# (grid, colormap, cell-border width, vmin, vmax), left to right:
# interventions (red), outcomes (blue), context (green).
panel_specs = [
    (reshape_to_reggrid(interv_data), red_cmap, 2, 0.0, 1.0),
    (reshape_to_reggrid(outcome_data), blue_cmap, 2, -0.05, 0.1),
    (reshape_to_subgrid(context_data_norm), green_cmap, 0, 0.0, 1.0),
]
for panel_ax, (panel_grid, panel_cmap, panel_lw, panel_vmin, panel_vmax) in zip(axs, panel_specs):
    sns.heatmap(
        panel_grid, ax=panel_ax, cmap=panel_cmap, cbar=False, square=True,
        linewidths=panel_lw, vmin=panel_vmin, vmax=panel_vmax,
    )
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])

fig.subplots_adjust(wspace=space_between)

fig.savefig(os.path.join(data_dir, 'combined_heatmaps.jpg'), dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(os.path.join(data_dir, 'combined_heatmaps.pdf'), bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
# Define parameters
data_dir = 'data_exp1'

# Palette: blue = control arm, red = intervened arm.
colors = sns.color_palette("deep")
blue_color, red_color = colors[0], colors[3]

# Shared font sizes for all text in these figures.
label_fontsize, tick_fontsize, legend_fontsize = 20, 18, 18

# Region-level data (one value per region).
context_reg, outcome_reg, interv_reg = (
    pd.read_csv(os.path.join(data_dir, csv_name), header=None).values.squeeze()
    for csv_name in ('context_region.csv', 'outcome_region.csv', 'interventions_region.csv')
)

# 1) Region wealth vs. outcome; colour and marker encode the treatment arm.
fig1, ax1 = plt.subplots(figsize=(8, 6))
for arm, arm_color, arm_label, arm_marker in ((0, blue_color, 'Control', 'o'), (1, red_color, 'Intervened', '^')):
    in_arm = interv_reg == arm
    sns.scatterplot(
        x=context_reg[in_arm], y=outcome_reg[in_arm], color=arm_color, alpha=0.6, s=150,
        label=arm_label, marker=arm_marker, ax=ax1,
    )
ax1.set_xlabel('Regional Wealth', fontsize=label_fontsize)
ax1.set_ylabel('Outcome', fontsize=label_fontsize)
ax1.tick_params(labelsize=tick_fontsize, length=0)
ax1.legend(loc='upper left', bbox_to_anchor=(1, 1), frameon=False, fontsize=legend_fontsize)
plt.tight_layout()
fig1.savefig(os.path.join(data_dir, 'scatter_wealth_outcome.jpg'), dpi=300, bbox_inches='tight')
fig1.savefig(os.path.join(data_dir, 'scatter_wealth_outcome.pdf'), bbox_inches='tight')
plt.show()

# Subregion-level data, flattened to one entry per subregion.
context_sub, outcome_sub, interv_sub = (
    pd.read_csv(os.path.join(data_dir, csv_name), header=None).values.flatten()
    for csv_name in ('context_subregion.csv', 'outcome_subregion.csv', 'interventions_subregion.csv')
)

df = pd.DataFrame({'Wealth': context_sub, 'Intervention': interv_sub, 'Outcome': outcome_sub})

# 2) Outcome distribution per wealth level, split by treatment arm.
fig2, ax2 = plt.subplots(figsize=(8, 6))
sns.violinplot(
    data=df, x='Wealth', y='Outcome', hue='Intervention', palette={0: blue_color, 1: red_color},
    split=False, inner='quartile', ax=ax2,
)
ax2.set_xlabel('Subregional Wealth', fontsize=label_fontsize)
ax2.set_ylabel('Outcome', fontsize=label_fontsize)
ax2.tick_params(labelsize=tick_fontsize, length=0)
# Replace the numeric hue labels (0/1) with readable names.
handles, labels = ax2.get_legend_handles_labels()
ax2.legend(handles, ['Control', 'Intervened'], loc='upper left', bbox_to_anchor=(1, 1), frameon=False, fontsize=legend_fontsize)
plt.tight_layout()
fig2.savefig(os.path.join(data_dir, 'violin_wealth_intervention_outcome.jpg'), dpi=300, bbox_inches='tight')
fig2.savefig(os.path.join(data_dir, 'violin_wealth_intervention_outcome.pdf'), bbox_inches='tight')
plt.show()

In [None]:
def reshape_to_subgrid(data, region_side=None, sub_side=None):
    """Tile per-region subregion rows into one square spatial map.

    Row ``k`` of *data* holds the ``sub_side**2`` subregion values of region
    ``k``. Regions are placed row-major on a ``region_side x region_side``
    grid, and each region's values are placed row-major inside its own
    ``sub_side x sub_side`` tile.

    Args:
        data: array-like of shape (region_side**2, sub_side**2).
        region_side: regions per grid side. Defaults to the module-level
            ``region_grid_size``, which is looked up at call time.
        sub_side: subregions per region side. Defaults to the module-level
            ``sub_per_region_side``.

    Returns:
        float ndarray of shape (region_side * sub_side, region_side * sub_side).
    """
    if region_side is None:
        region_side = region_grid_size
    if sub_side is None:
        sub_side = sub_per_region_side
    blocks = np.asarray(data, dtype=float).reshape(region_side, region_side, sub_side, sub_side)
    # (region_row, region_col, sub_row, sub_col) -> (region_row, sub_row, region_col, sub_col)
    return blocks.transpose(0, 2, 1, 3).reshape(region_side * sub_side, region_side * sub_side)

def estimate_current_subregionaloutcome(model, context_sub_t, interventions_reg_t):
    """Return the model's per-subregion predicted outcomes.

    The model is called as ``model(interventions_reg_t, context_sub_t)``
    (note the argument order). It must return a pair; the region-level
    aggregate (first element) is discarded.
    """
    _region_level, subregion_level = model(interventions_reg_t, context_sub_t)
    return subregion_level

def estimate_current_causaleffectmatrix(model, context_sub_t):
    """Map each subregion's wealth level to the model's implied treatment effect.

    ``model.theta`` is ordered [t=0: c=1, 2, 3 | t=1: c=1, 2, 3], so the
    effect for wealth level ``c`` is ``theta[c + 2] - theta[c - 1]``.
    Entries whose level is not 1, 2 or 3 stay 0.

    Args:
        model: object exposing a 6-element ``theta`` tensor.
        context_sub_t: integer tensor of wealth levels.

    Returns:
        float32 tensor with the same shape as ``context_sub_t``.
    """
    effect = torch.zeros_like(context_sub_t, dtype=torch.float32)
    for level in (1, 2, 3):
        effect[context_sub_t == level] = model.theta[level + 2] - model.theta[level - 1]
    return effect

# Training configuration
data_dir = 'data_exp1'
num_regions = 100
subregions_per_region = 100
n_epochs = 1000
lr = 0.1
seed = 42
loglog_flag = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed NumPy and PyTorch so that runs are reproducible.
np.random.seed(seed)
torch.manual_seed(seed)

# Observed data: region interventions and outcomes, plus subregion wealth (1, 2, 3).
interventions_reg = pd.read_csv(os.path.join(data_dir, 'interventions_region.csv'), header=None).values.squeeze()  # (100,)
context_sub = pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values  # (100, 100)
outcome_reg = pd.read_csv(os.path.join(data_dir, 'outcome_region.csv'), header=None).values.squeeze()  # (100,)

# Ground-truth subregion effects, used only to monitor estimation error.
causal_effect_true = pd.read_csv(os.path.join(data_dir, 'causal_effect_subregion.csv'), header=None).values  # (100, 100)

# Interventions (0/1) and wealth levels become integer tensors; outcomes become float32.
interventions_reg_t, context_sub_t = (
    torch.tensor(src, dtype=torch.long, device=device) for src in (interventions_reg, context_sub)
)
outcome_reg_t = torch.tensor(outcome_reg, dtype=torch.float32, device=device)

# Define PyTorch module
class OutcomeModel(nn.Module):
    """Tabular outcome model f_theta(t, c) with one parameter per (treatment, wealth) cell.

    ``theta`` holds six values, all initialised to 1, ordered
    [t=0,c=1; t=0,c=2; t=0,c=3; t=1,c=1; t=1,c=2; t=1,c=3]. A subregion's
    prediction is the entry for its region's treatment and its own wealth
    level. Combinations outside t in {0, 1} and c in {1, 2, 3} predict 0.
    """

    def __init__(self):
        super().__init__()
        self.theta = nn.Parameter(torch.ones(6))

    def forward(self, interventions_reg_t, context_sub_t):
        """Predict subregion outcomes and their per-region means.

        Args:
            interventions_reg_t: (R,) tensor of region treatments (0/1).
            context_sub_t: (R, S) tensor of subregion wealth levels (1/2/3).

        Returns:
            (pred_reg, pred_sub): the (R,) region means and the (R, S)
            float32 subregion predictions.
        """
        pred_sub = torch.zeros_like(context_sub_t, dtype=torch.float32)
        # A treatment applies to a whole region: broadcast it across the row.
        treatment_by_row = interventions_reg_t[:, None]
        for t in (0, 1):
            in_arm = treatment_by_row == t
            for c in (1, 2, 3):
                pred_sub[in_arm & (context_sub_t == c)] = self.theta[3 * t + c - 1]
        return pred_sub.mean(dim=1), pred_sub

# Instantiate model and optimizer
model = OutcomeModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Per-epoch histories. Entry k of BOTH lists is evaluated at the same
# parameters (those in effect before the k-th optimizer step), so the two
# curves can be compared epoch by epoch.
loss_hist = []
ce_loss_hist = []

for epoch in range(n_epochs):
    optimizer.zero_grad()
    # Forward pass: only region-level aggregates are compared with data.
    pred_reg, _ = model(interventions_reg_t, context_sub_t)
    loss = torch.mean((pred_reg - outcome_reg_t) ** 2)
    loss_hist.append(loss.item())

    # The causal-effect error is a diagnostic only. Compute it without
    # recording an autograd graph, and BEFORE the step so it matches `loss`.
    with torch.no_grad():
        causal_effect_est_t = estimate_current_causaleffectmatrix(model, context_sub_t)
    ce_mse = np.mean((causal_effect_est_t.cpu().numpy() - causal_effect_true) ** 2)
    ce_loss_hist.append(ce_mse)

    # Backward and optimize
    loss.backward()
    optimizer.step()

# OutcomeModel has no train-only layers; this only flips the module's training flag.
model.eval()

# Persist both per-epoch histories so they can be re-plotted later.
pd.DataFrame({'loss': loss_hist, 'ce_loss': ce_loss_hist}).to_csv(os.path.join(data_dir, 'estimation_loss.csv'), index=False)

fig, axs = plt.subplots(1, 2, figsize=(16, 6))
epochs = np.arange(1, len(loss_hist) + 1)

# (axis, series, y label, title) for the two side-by-side panels.
curve_specs = (
    (axs[0], loss_hist, 'MSE Loss (Region Outcomes)', 'Training Loss Curve'),
    (axs[1], ce_loss_hist, 'MSE Loss (Causal Effects)', 'Causal Effect Estimation Error'),
)
for curve_ax, curve_values, curve_ylabel, curve_title in curve_specs:
    draw = curve_ax.loglog if loglog_flag else curve_ax.plot
    draw(epochs, curve_values)
    curve_ax.set_xlabel('Epoch')
    curve_ax.set_ylabel(curve_ylabel)
    curve_ax.set_title(curve_title)

plt.tight_layout()
fig.savefig(os.path.join(data_dir, 'estimation_losses.jpg'), dpi=300)
fig.savefig(os.path.join(data_dir, 'estimation_losses.pdf'))
plt.show()

# Report the fitted table f_theta(t, c), one line per cell.
print("Final theta estimates:")
theta_labels = (
    "theta1 (t=0, poor): ",
    "theta2 (t=0, middle class): ",
    "theta3 (t=0, rich): ",
    "theta4 (t=1, poor): ",
    "theta5 (t=1, middle class): ",
    "theta6 (t=1, rich): ",
)
for theta_idx, theta_label in enumerate(theta_labels):
    print(theta_label, model.theta[theta_idx].detach().cpu().numpy())

# Save theta as a header line plus one comma-separated value line
# (this file is read back by the counterfactual cell).
theta_np = model.theta.detach().cpu().numpy().flatten()
with open(os.path.join(data_dir, 'parameter_estimate.txt'), 'w') as f:
    f.write("theta1,theta2,theta3,theta4,theta5,theta6\n")
    f.write(','.join(str(value) for value in theta_np) + '\n')

# Noise-free subregion outcomes predicted under the observed interventions.
with torch.no_grad():
    outcome_sub_est_t = estimate_current_subregionaloutcome(model, context_sub_t, interventions_reg_t)
outcome_sub_est = outcome_sub_est_t.cpu().numpy()
pd.DataFrame(outcome_sub_est).to_csv(os.path.join(data_dir, 'outcome_subregion_estimated.csv'), index=False, header=False)

# Visualize estimated subregion outcome
# Grid geometry read by reshape_to_subgrid.
region_grid_size, sub_grid_size, sub_per_region_side = 10, 100, 10

colors = sns.color_palette("deep")
blue_cmap = sns.light_palette(colors[0], as_cmap=True)  # blue = outcomes

grid = reshape_to_subgrid(outcome_sub_est)
fig_vis, ax_vis = plt.subplots(figsize=(10, 10))
sns.heatmap(
    grid, ax=ax_vis, cmap=blue_cmap, square=True, linewidths=0,
    cbar=True, cbar_kws={'location': 'right'}, vmin=-0.1, vmax=0.5,
)
ax_vis.set_title('Estimated Outcome Subregion', fontsize=20)
ax_vis.set(xticks=[], yticks=[])
cbar = ax_vis.collections[0].colorbar
cbar.ax.tick_params(labelsize=14)
fig_vis.savefig(os.path.join(data_dir, 'outcome_subregion_estimated.jpg'), dpi=300, bbox_inches='tight')
fig_vis.savefig(os.path.join(data_dir, 'outcome_subregion_estimated.pdf'), bbox_inches='tight')
plt.show()

# Estimated subregion effect f(1, c_ij) - f(0, c_ij), detached to NumPy.
causal_effect_sub_t = estimate_current_causaleffectmatrix(model, context_sub_t)
causal_effect_sub = causal_effect_sub_t.detach().cpu().numpy()
pd.DataFrame(causal_effect_sub).to_csv(os.path.join(data_dir, 'causal_effect_subregion_estimated.csv'), index=False, header=False)

purple_cmap = sns.light_palette(colors[4], as_cmap=True)  # purple = causal effects
grid_ce = reshape_to_subgrid(causal_effect_sub)
fig_ce, ax_ce = plt.subplots(figsize=(10, 10))
sns.heatmap(
    grid_ce, ax=ax_ce, cmap=purple_cmap, square=True, linewidths=0,
    cbar=True, cbar_kws={'location': 'right'}, vmin=-0.1, vmax=0.3,
)
ax_ce.set_title('Estimated Causal Effect Subregion', fontsize=20)
ax_ce.set(xticks=[], yticks=[])
cbar_ce = ax_ce.collections[0].colorbar
cbar_ce.ax.tick_params(labelsize=14)
fig_ce.savefig(os.path.join(data_dir, 'causal_effect_subregion_estimated.jpg'), dpi=300, bbox_inches='tight')
fig_ce.savefig(os.path.join(data_dir, 'causal_effect_subregion_estimated.pdf'), bbox_inches='tight')
plt.show()


In [None]:
# Re-plot the saved training histories on twin y-axes.
data_dir = 'data_exp1'
loglog_flag = True

# Per-epoch histories written by the training cell.
loss_data = pd.read_csv(os.path.join(data_dir, 'estimation_loss.csv'))
loss_hist, ce_loss_hist = loss_data['loss'].values, loss_data['ce_loss'].values
epochs = np.arange(1, len(loss_hist) + 1)

fig, ax1 = plt.subplots(figsize=(8, 6))
colors = sns.color_palette("deep")

# Left axis: training loss as a solid line (thicker and more opaque in log-log mode).
ax1.set_xlabel('Epoch', fontsize=20)
ax1.set_ylabel('Training Loss', fontsize=20, color=colors[0])
loss_style = dict(color=colors[0], alpha=0.9, lw=3) if loglog_flag else dict(color=colors[0], alpha=0.7, lw=2)
(ax1.loglog if loglog_flag else ax1.plot)(epochs, loss_hist, **loss_style)
ax1.tick_params(axis='both', which='both', length=0, labelsize=18, colors=colors[0])
ax1.set_title('MSE Curves', fontsize=22)
ax1.spines['top'].set_visible(False)

# Right (twin) axis: causal-effect MSE as a dashed line.
ax2 = ax1.twinx()
ax2.set_ylabel('MSE Causal Effect', fontsize=20, color=colors[3])
ce_style = dict(color=colors[3], alpha=0.9, ls='--', lw=5) if loglog_flag else dict(color=colors[3], alpha=0.7, ls='--', lw=2)
(ax2.loglog if loglog_flag else ax2.plot)(epochs, ce_loss_hist, **ce_style)
ax2.tick_params(axis='y', which='both', length=0, labelsize=18, colors=colors[3])
ax2.spines['top'].set_visible(False)

plt.tight_layout()
fig.savefig(os.path.join(data_dir, 'estimation_loss.jpg'), dpi=300)
fig.savefig(os.path.join(data_dir, 'estimation_loss.pdf'))
plt.show()

In [None]:
# Define important parameters
data_dir = 'data_exp1'
num_regions = 100
region_grid_size = 10

# Counterfactual policy: keep the first half of the regions as observed and
# flip the treatment (0 <-> 1) of the second half.
interventions_reg = pd.read_csv(os.path.join(data_dir, 'interventions_region.csv'), header=None).values.squeeze()  # (100,)
cf_start = num_regions // 2
interventions_reg_cf = interventions_reg.copy()
interventions_reg_cf[cf_start:] = 1 - interventions_reg[cf_start:]
pd.DataFrame(interventions_reg_cf).to_csv(os.path.join(data_dir, 'interventions_region_cf.csv'), index=False, header=False)

# Context and the noise draw saved by the data-generation cell.
context_sub = pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values  # (100, 100)
noise_sub = pd.read_csv(os.path.join(data_dir, 'noise_subregion.csv'), header=None).values  # (100, 100)

# Broadcast each region's counterfactual flag to all of its subregions.
interventions_sub_cf = np.repeat(interventions_reg_cf[:, np.newaxis], context_sub.shape[1], axis=1)  # (100, 100)

# Ground-truth effect per wealth level (same values as in the data-generation cell).
delta = np.zeros_like(context_sub, dtype=float)
for wealth_level, effect in {1: -0.1, 2: 0.0, 3: 0.3}.items():
    delta[context_sub == wealth_level] = effect

# Same outcome equation as the factual data, evaluated under the CF policy.
outcome_sub_cf = noise_sub + interventions_sub_cf * delta
pd.DataFrame(outcome_sub_cf).to_csv(os.path.join(data_dir, 'outcome_subregion_cf.csv'), index=False, header=False)

# Region-level CF outcome is the per-region mean.
outcome_reg_cf = outcome_sub_cf.mean(axis=1)
pd.DataFrame(outcome_reg_cf).to_csv(os.path.join(data_dir, 'outcome_region_cf.csv'), index=False, header=False)

# Visualize region-level counterfactual outcome
colors = sns.color_palette("deep")
blue_cmap = sns.light_palette(colors[0], as_cmap=True)

def reshape_to_reggrid(data, region_side=None):
    """Arrange a flat per-region vector on the square region grid (row-major).

    Args:
        data: array-like with ``region_side**2`` elements, one per region.
        region_side: regions per grid side. Defaults to the module-level
            ``region_grid_size``, which is looked up at call time.

    Returns:
        ndarray of shape (region_side, region_side).
    """
    if region_side is None:
        region_side = region_grid_size
    return np.asarray(data).reshape(region_side, region_side)

# Region-level ground-truth counterfactual outcome (same colour range as the factual map).
grid_cf = reshape_to_reggrid(outcome_reg_cf)
fig_cf, ax_cf = plt.subplots(figsize=(5, 5))
sns.heatmap(
    grid_cf, ax=ax_cf, cmap=blue_cmap, square=True, linewidths=1,
    cbar=True, cbar_kws={'location': 'right'}, vmin=-0.05, vmax=0.1,
)
ax_cf.set_title('CF Outcome GT', fontsize=20)
ax_cf.set(xticks=[], yticks=[])
cbar = ax_cf.collections[0].colorbar
cbar.ax.tick_params(labelsize=14)
fig_cf.savefig(os.path.join(data_dir, 'outcome_region_cf.jpg'), dpi=300, bbox_inches='tight')
fig_cf.savefig(os.path.join(data_dir, 'outcome_region_cf.pdf'), bbox_inches='tight')
plt.show()

In [None]:
def plot_interventions_region_cf():
    """Draw the counterfactual region-level intervention map and save it as JPG/PDF.

    Uses ``data_dir``, ``reshape_to_reggrid`` and ``red_cmap`` defined in earlier cells.
    """
    flags = pd.read_csv(os.path.join(data_dir, 'interventions_region_cf.csv'), header=None).values.squeeze()
    flag_map = reshape_to_reggrid(flags)
    fig, ax = plt.subplots(figsize=(5, 5))
    sns.heatmap(
        flag_map, ax=ax, cmap=red_cmap, square=True, linewidths=1,
        cbar=True, cbar_kws={'location': 'right'},
    )
    ax.set_title('CF Interventions', fontsize=20)
    ax.set(xticks=[], yticks=[])
    ax.collections[0].colorbar.ax.tick_params(labelsize=14)
    fig.savefig(os.path.join(data_dir, 'interventions_region_cf.jpg'), dpi=300, bbox_inches='tight')
    fig.savefig(os.path.join(data_dir, 'interventions_region_cf.pdf'), bbox_inches='tight')
    plt.show()


plot_interventions_region_cf()

In [None]:
# Define important parameters
data_dir = 'data_exp1'
num_regions, subregions_per_region = 100, 100
region_grid_size, sub_grid_size, sub_per_region_side = 10, 100, 10
device = 'cpu'

# Factual and counterfactual region treatments, subregion wealth, observed region outcomes.
interventions_reg_cf = pd.read_csv(os.path.join(data_dir, 'interventions_region_cf.csv'), header=None).values.squeeze()  # (100,)
interventions_reg = pd.read_csv(os.path.join(data_dir, 'interventions_region.csv'), header=None).values.squeeze()  # (100,)
context_sub = pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values  # (100, 100)
outcome_reg = pd.read_csv(os.path.join(data_dir, 'outcome_region.csv'), header=None).values.squeeze()  # (100,)

# Fitted theta: line 0 of the file is the header, line 1 holds the comma-separated values.
with open(os.path.join(data_dir, 'parameter_estimate.txt'), 'r') as f:
    lines = f.readlines()
    theta_values = [float(token) for token in lines[1].strip().split(',')]
theta = np.array(theta_values)  # [theta1, theta2, theta3, theta4, theta5, theta6]

# Integer tensors for treatments and wealth levels.
interventions_reg_t, interventions_reg_cf_t, context_sub_t = (
    torch.tensor(src, dtype=torch.long, device=device)
    for src in (interventions_reg, interventions_reg_cf, context_sub)
)

# Function to compute subregional predictions with theta
def compute_pred_sub(interventions_t, context_t, theta):
    """Look up f_theta(t, c) for every subregion.

    ``theta`` is indexed as ``theta[3 * t + c - 1]`` for treatment t in
    {0, 1} and wealth c in {1, 2, 3}. Each region's treatment (one entry of
    ``interventions_t`` per row of ``context_t``) applies to all of its
    subregions. Other (t, c) combinations predict 0.

    Returns:
        float32 tensor shaped like ``context_t``.
    """
    predictions = torch.zeros_like(context_t, dtype=torch.float32)
    row_treatment = interventions_t[:, None]
    for arm in (0, 1):
        for level in (1, 2, 3):
            cell = (row_treatment == arm) & (context_t == level)
            predictions[cell] = theta[3 * arm + level - 1]
    return predictions

# Deterministic subregion predictions under the factual and counterfactual treatments.
pred_sub_original, pred_sub_cf = (
    compute_pred_sub(treatments, context_sub_t, theta)
    for treatments in (interventions_reg_t, interventions_reg_cf_t)
)

# 1) Regional noise estimate: observed regional outcome minus the model's regional mean.
pred_reg_original = pred_sub_original.mean(dim=1).numpy()
regional_noise = outcome_reg - pred_reg_original  # (100,)

# 2) Save the counterfactual subregion predictions.
pd.DataFrame(pred_sub_cf.numpy()).to_csv(os.path.join(data_dir, 'outcome_subregion_cf_estimated.csv'), index=False, header=False)

# Counterfactual regional outcome = deterministic CF regional mean + recovered regional noise.
pred_reg_cf_deterministic = pred_sub_cf.mean(dim=1).numpy()
outcome_reg_cf = pred_reg_cf_deterministic + regional_noise
pd.DataFrame(outcome_reg_cf).to_csv(os.path.join(data_dir, 'outcome_region_cf_estimated.csv'), index=False, header=False)

# 3) Visualize CF regional outcome
colors = sns.color_palette("deep")
blue_cmap = sns.light_palette(colors[0], as_cmap=True)

def reshape_to_reggrid(data, grid_size=None):
    """Reshape a flat vector of per-region values into a square region grid.

    Args:
        data: Array (anything with ``.reshape``) holding ``grid_size ** 2`` values,
            one per region, in row-major order.
        grid_size: Side length of the region grid. Defaults to the notebook-level
            ``region_grid_size`` so existing calls keep working unchanged.

    Returns:
        ``data`` reshaped to ``(grid_size, grid_size)``.
    """
    if grid_size is None:
        grid_size = region_grid_size
    return data.reshape(grid_size, grid_size)

grid_cf = reshape_to_reggrid(outcome_reg_cf)
fig_cf, ax_cf = plt.subplots(figsize=(5, 5))
# NOTE(review): the colour range is fixed to [-0.05, 0.1]; values outside it saturate.
heatmap_style = {
    'cmap': blue_cmap,
    'square': True,
    'linewidths': 1,
    'cbar': True,
    'cbar_kws': {'location': 'right'},
    'vmin': -0.05,
    'vmax': 0.1,
}
sns.heatmap(grid_cf, ax=ax_cf, **heatmap_style)
ax_cf.set_title('CF Outcome Estimate', fontsize=20)
for clear_ticks in (ax_cf.set_xticks, ax_cf.set_yticks):
    clear_ticks([])
cbar = ax_cf.collections[0].colorbar
cbar.ax.tick_params(labelsize=14)
# Raster (300 dpi) first, then vector, both cropped tightly.
for out_path, extra_kwargs in (
    (os.path.join(data_dir, 'outcome_region_cf_estimated.jpg'), {'dpi': 300}),
    (os.path.join(data_dir, 'outcome_region_cf_estimated.pdf'), {}),
):
    fig_cf.savefig(out_path, bbox_inches='tight', **extra_kwargs)
plt.show()

# Ablation for Exp. 1
Test how regional (mean) context variation and subregional context variation affect recovery of the underlying causal mechanism.

In [None]:
import numpy as np
import pandas as pd
import os

In [None]:
# ==== CONFIG ====
# Grid geometry: GRID x GRID regions, each split into SUB x SUB subregions.
GRID, SUB = 10, 10
NUM_REGIONS, NUM_SUBREGIONS = GRID ** 2, SUB ** 2
OUTPUT_DIR = "data_ablation_exp1"

# Output root for the ablation data (created if missing).
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Full Simulation Generator ---
def generate_full_simulation(regional_std, inter_std, seed=42, outfolder="data_ablation_exp1"):
    """Simulate one Exp. 1 ablation setting and write all artifacts as CSV.

    The world is a GRID_SIZE x GRID_SIZE grid of regions. Each region is split into
    SUBREGION_SIZE x SUBREGION_SIZE subregions with a wealth class in {1, 2, 3}
    (poor / intermediate / rich). Each region is treated or untreated.
    A subregion's vote share is its class mean if treated, or the baseline if
    untreated, plus Gaussian noise, clipped to [0, 1]. The regional outcome is
    the mean vote share times 100.

    Args:
        regional_std: Controls how far the poor/rich split may deviate from 50:50
            (mapped to ``ceil(regional_std * 10)`` subregions of wiggle).
        inter_std: Standard deviation of the per-subregion vote noise.
        seed: Seed for ``np.random.default_rng``; also part of the folder name.
        outfolder: Parent directory; a per-setting subfolder is created inside it.

    Writes into
    ``<outfolder>/inter{inter_std:.2f}_regional{regional_std:.2f}_sample{seed}``:
    interventions.csv, wealth_high_res.csv, subregion_noise_gt.csv,
    region_noise_gt.csv, outcome.csv, interventions_cf_gt.csv and outcome_cf_gt.csv.
    """
    rng = np.random.default_rng(seed)

    # Generate a unique folder for each setting
    setting_folder = os.path.join(outfolder, f"inter{inter_std:.2f}_regional{regional_std:.2f}_sample{seed}")
    os.makedirs(setting_folder, exist_ok=True)

    def save_csv(array, filename):
        # Every artifact is written without header/index so it loads with header=None.
        pd.DataFrame(array).to_csv(os.path.join(setting_folder, filename), header=False, index=False)

    # --- Fixed parameters ---
    GRID_SIZE, SUBREGION_SIZE = 10, 10
    N_SUB = SUBREGION_SIZE * SUBREGION_SIZE  # subregions per region
    RICH_VOTE, POOR_VOTE, INTER_VOTE, BASELINE_VOTE = 0.80, 0.40, 0.50, 0.50

    # Map std hyperparams to data gen control
    spatial_variance = int(np.ceil(regional_std * 10))   # controls wiggle in counts
    outcome_noise = inter_std                            # subregion-level noise

    def region_slices(i, j):
        # (row slice, col slice) of region (i, j) inside the high-resolution grid.
        r0, c0 = i * SUBREGION_SIZE, j * SUBREGION_SIZE
        return slice(r0, r0 + SUBREGION_SIZE), slice(c0, c0 + SUBREGION_SIZE)

    def mean_votes(block_wealth, treated):
        # Noise-free vote share per subregion. The factual and counterfactual passes
        # both use this, so they cannot drift apart.
        if not treated:
            return np.full(block_wealth.shape, BASELINE_VOTE)
        return np.where(block_wealth == 3, RICH_VOTE,
                        np.where(block_wealth == 1, POOR_VOTE, INTER_VOTE))

    # --- 1) interventions ---
    interventions = rng.integers(0, 2, size=(GRID_SIZE, GRID_SIZE))
    # The two fixed-composition regions (0, 0) and (0, 1) are always treated.
    interventions[0, 0] = 1
    interventions[0, 1] = 1
    save_csv(interventions, "interventions.csv")

    # --- 2) wealth grid generation ---
    wealth_hi = np.zeros((GRID_SIZE * SUBREGION_SIZE, GRID_SIZE * SUBREGION_SIZE), dtype=int)
    for i in range(GRID_SIZE):
        for j in range(GRID_SIZE):
            # Counts are drawn for every region, including the two fixed ones, so the
            # RNG stream is the same regardless of which branch is taken below.
            n_inter = rng.integers(20, N_SUB + 1)
            remaining = N_SUB - n_inter
            wiggle = min(spatial_variance, remaining // 2)
            n_rich = np.clip(remaining // 2 + rng.integers(-wiggle, wiggle + 1), 0, remaining)
            n_poor = remaining - n_rich

            if (i, j) == (0, 0):
                block_vals = np.full(N_SUB, 2)  # all intermediate
            elif (i, j) == (0, 1):
                half = N_SUB // 2
                block_vals = np.concatenate([np.ones(half, int), np.full(N_SUB - half, 3, int)])
                rng.shuffle(block_vals)
            else:
                block_vals = np.concatenate([
                    np.ones(n_poor, int),
                    np.full(n_inter, 2, int),
                    np.full(n_rich, 3, int),
                ])
                rng.shuffle(block_vals)

            wealth_hi[region_slices(i, j)] = block_vals.reshape(SUBREGION_SIZE, SUBREGION_SIZE)

    save_csv(wealth_hi, "wealth_high_res.csv")

    # --- 3) outcome generation ---
    subregion_noise = np.zeros_like(wealth_hi, dtype=float)
    outcome = np.zeros_like(interventions, dtype=float)
    for i in range(GRID_SIZE):
        for j in range(GRID_SIZE):
            block = region_slices(i, j)
            # One draw per subregion, filled in row-major order (same order as drawing
            # them one at a time, row by row).
            noise = rng.normal(0, outcome_noise, size=(SUBREGION_SIZE, SUBREGION_SIZE))
            subregion_noise[block] = noise
            votes = np.clip(mean_votes(wealth_hi[block], bool(interventions[i, j])) + noise, 0, 1)
            outcome[i, j] = np.mean(votes) * 100

    save_csv(np.round(subregion_noise, 4), "subregion_noise_gt.csv")
    # Regional noise on the outcome scale: outcome = mean(vote) * 100, so (ignoring
    # clipping at [0, 1]) its noise term is mean(subregion noise) * 100.
    region_noise = subregion_noise.reshape(GRID_SIZE, SUBREGION_SIZE, GRID_SIZE, SUBREGION_SIZE).mean(axis=(1, 3)) * 100
    save_csv(np.round(region_noise, 4), "region_noise_gt.csv")
    save_csv(np.round(outcome, 2), "outcome.csv")

    # --- 4) counterfactual ---
    interventions_cf = interventions.copy()
    # Flip treatment in the top half of the grid rows; keep the bottom half unchanged.
    interventions_cf[:GRID_SIZE // 2, :] = 1 - interventions_cf[:GRID_SIZE // 2, :]
    save_csv(interventions_cf, "interventions_cf_gt.csv")

    outcome_cf = np.zeros_like(interventions_cf, dtype=float)
    for i in range(GRID_SIZE):
        for j in range(GRID_SIZE):
            block = region_slices(i, j)
            # Reuse the factual noise so that only the treatment differs.
            votes = np.clip(
                mean_votes(wealth_hi[block], bool(interventions_cf[i, j])) + subregion_noise[block],
                0, 1,
            )
            outcome_cf[i, j] = np.mean(votes) * 100

    save_csv(np.round(outcome_cf, 2), "outcome_cf_gt.csv")

    print(f"✓ Simulation saved for inter_std={inter_std}, regional_std={regional_std}, sample={seed}")


In [None]:
output_dir = "data_ablation_exp1"
os.makedirs(output_dir, exist_ok=True)

inter_std_values = [0.01, 0.1, 0.3, 0.5, 0.7]
regional_std_values = [0.1, 0.3, 0.5, 0.7, 1.0]
num_samples = 5
base_seed = 42

# Sweep every (inter_std, regional_std, sample) combination. output_dir is passed
# explicitly so editing it above actually redirects the data. generate_full_simulation
# already prints one confirmation line per setting, so no second print is needed here.
for inter_std in inter_std_values:
    for regional_std in regional_std_values:
        for sample_idx in range(num_samples):
            seed = base_seed + sample_idx
            generate_full_simulation(
                inter_std=inter_std,
                regional_std=regional_std,
                seed=seed,
                outfolder=output_dir,
            )

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import pathlib
import matplotlib.pyplot as plt

DATA_ROOT = "data_ablation_exp1"
OUTFOLDER = "exp1_ablation_results"
GRID, SUB = 10, 10
N_EPOCHS, LR = 2000, 0.1
DEVICE = "cpu"

# Ground-truth cell means, in the same order as the learned parameters:
# [untreated: poor, inter, rich, treated: poor, inter, rich].
TRUE_PARAMS = np.array([
    0.50, 0.50, 0.50,
    0.40, 0.50, 0.80
])
PARAM_NAMES = ["mu_poor_0", "mu_inter_0", "mu_rich_0",
               "mu_poor_1", "mu_inter_1", "mu_rich_1"]

# Ablation ranges
inter_std_values = [0.01, 0.1, 0.3, 0.5, 0.7]
regional_std_values = [0.1, 0.3, 0.5, 0.7, 1.0]
num_samples = 5
base_seed = 42

pathlib.Path(OUTFOLDER).mkdir(exist_ok=True)


def region_class_counts(wealth_hi, grid=GRID, sub=SUB):
    """Count wealth classes 1, 2, 3 inside every (sub x sub) region block.

    Args:
        wealth_hi: (grid * sub, grid * sub) array of wealth codes.

    Returns:
        Integer array of shape (grid, grid, 3); the last axis is (class 1, 2, 3).
    """
    # Axes after reshape: [region row, row within region, region col, col within region].
    blocks = np.asarray(wealth_hi).reshape(grid, sub, grid, sub)
    return np.stack([(blocks == k).sum(axis=(1, 3)) for k in (1, 2, 3)], axis=-1)


def predict_region_mean(params, counts_t, treated_mask, n_sub=SUB * SUB):
    """Predicted regional mean vote share (0-1 scale).

    Each region's prediction is the count-weighted average of its arm's class means:
    ``params[0:3]`` for untreated regions and ``params[3:6]`` for treated regions.
    """
    mean_ctrl = (counts_t[..., 0] * params[0] +
                 counts_t[..., 1] * params[1] +
                 counts_t[..., 2] * params[2]) / float(n_sub)
    mean_trt = (counts_t[..., 0] * params[3] +
                counts_t[..., 1] * params[4] +
                counts_t[..., 2] * params[5]) / float(n_sub)
    return torch.where(treated_mask, mean_trt, mean_ctrl)


# Built once instead of once per epoch (float64, as before).
true_params_t = torch.tensor(TRUE_PARAMS, device=DEVICE)

for inter_std in inter_std_values:
    for regional_std in regional_std_values:
        for sample_idx in range(num_samples):

            seed = base_seed + sample_idx
            setting_dir = os.path.join(
                DATA_ROOT, f"inter{inter_std:.2f}_regional{regional_std:.2f}_sample{seed}"
            )

            if not os.path.exists(setting_dir):
                print(f" Missing data for {setting_dir}, skipping.")
                continue

            print(f"--- Training for inter_std={inter_std:.2f}, "
                  f"regional_std={regional_std:.2f}, sample={sample_idx:02d} ---")

            wealth_hi = pd.read_csv(os.path.join(setting_dir, "wealth_high_res.csv"), header=None).values
            interv_np = pd.read_csv(os.path.join(setting_dir, "interventions.csv"), header=None).values
            outcome_np = pd.read_csv(os.path.join(setting_dir, "outcome.csv"), header=None).values

            counts_np = region_class_counts(wealth_hi)

            interv_t = torch.tensor(interv_np, dtype=torch.float32, device=DEVICE)
            counts_t = torch.tensor(counts_np, dtype=torch.float32, device=DEVICE)
            outcome_t = torch.tensor(outcome_np, dtype=torch.float32, device=DEVICE)
            treated_mask = interv_t == 1

            params = torch.nn.Parameter(torch.tensor(
                [0.5, 0.5, 0.5, 0.5, 0.5, 0.5], dtype=torch.float32, device=DEVICE
            ))
            opt = torch.optim.Adam([params], lr=LR)

            loss_hist = []
            param_mse_hist = []

            for epoch in range(N_EPOCHS):
                opt.zero_grad()
                mean_pred = predict_region_mean(params, counts_t, treated_mask)
                # The outcome CSV is on a 0-100 scale, so fit the prediction times 100.
                loss = torch.mean((mean_pred * 100 - outcome_t) ** 2)
                # Measured for the same (pre-step) parameters as the loss.
                param_mse = torch.mean((params - true_params_t) ** 2)

                loss.backward()
                opt.step()

                loss_hist.append(loss.item())
                param_mse_hist.append(param_mse.item())

            loss_df = pd.DataFrame({
                "mse_loss": loss_hist,
                "param_mse": param_mse_hist
            })
            loss_filename = f"loss_inter{inter_std:.2f}_regional{regional_std:.2f}_sample{sample_idx:02d}.csv"
            loss_df.to_csv(os.path.join(OUTFOLDER, loss_filename), index_label="epoch")

            results_df = pd.DataFrame({"parameter": PARAM_NAMES,
                                       "value": params.detach().cpu().numpy()})
            results_filename = f"results_inter{inter_std:.2f}_regional{regional_std:.2f}_sample{sample_idx:02d}.csv"
            results_df.to_csv(os.path.join(OUTFOLDER, results_filename), index=False)

            # Abducted regional noise: observed outcome minus the final fitted mean.
            with torch.no_grad():
                mean_factual = predict_region_mean(params, counts_t, treated_mask) * 100

            region_noise = (outcome_t - mean_factual).cpu().numpy()
            noise_filename = f"region_noise_inter{inter_std:.2f}_regional{regional_std:.2f}_sample{sample_idx:02d}.csv"
            pd.DataFrame(np.round(region_noise, 3)).to_csv(os.path.join(OUTFOLDER, noise_filename),
                                                           header=False, index=False)

print("✓ Ablation study training complete.")


In [None]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

OUTFOLDER = "exp1_ablation_results"
inter_std_values = [0.01, 0.1, 0.3, 0.5, 0.7]
regional_std_values = [0.1, 0.3, 0.5, 0.7, 1.0]
num_samples = 5

sns.set_context("paper", font_scale=1.4)
plt.rcParams.update({
    "axes.edgecolor": "black",
    "axes.linewidth": 1.5,
    "axes.spines.top": False,
    "axes.spines.right": False,
    "xtick.direction": "out",
    "ytick.direction": "out",
    "xtick.major.size": 5,
    "ytick.major.size": 5,
    "xtick.labelsize": 13,
    "ytick.labelsize": 13,
    "legend.frameon": False,
})
colors = sns.color_palette("colorblind")

# Collect the last-epoch metrics of every (inter_std, regional_std, sample) run.
loss_records = []

for inter_std in inter_std_values:
    for regional_std in regional_std_values:
        for sample_idx in range(num_samples):
            fname = f"loss_inter{inter_std:.2f}_regional{regional_std:.2f}_sample{sample_idx:02d}.csv"
            fpath = os.path.join(OUTFOLDER, fname)
            if not os.path.exists(fpath):
                print(f"[Warning] Missing: {fname}")
                continue
            df = pd.read_csv(fpath)
            loss_records.append({
                "inter_std": inter_std,
                "regional_std": regional_std,
                "sample": sample_idx,
                "final_mse_loss": df["mse_loss"].iloc[-1],
                "final_param_mse": df["param_mse"].iloc[-1]
            })

loss_df = pd.DataFrame(loss_records)
if loss_df.empty:
    # An empty frame has no columns, so the groupby below would fail with an opaque KeyError.
    raise FileNotFoundError(
        f"No loss files found in {OUTFOLDER!r}; run the ablation training cell first."
    )

# Mean and std across samples for each (inter_std, regional_std) setting.
loss_summary = (
    loss_df.groupby(["inter_std", "regional_std"])
    .agg(final_mse_loss_mean=("final_mse_loss", "mean"),
         final_mse_loss_std=("final_mse_loss", "std"),
         final_param_mse_mean=("final_param_mse", "mean"),
         final_param_mse_std=("final_param_mse", "std"))
    .reset_index()
)


def plot_metric_vs_inter_std(mean_col, std_col, title, ylabel, stem):
    """Plot mean ± std of a summary metric against inter_std, one line per regional_std.

    Saves the figure as ``OUTFOLDER/<stem>.jpg`` (300 dpi) and ``.pdf``, shows it,
    and returns ``(fig, ax)``.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    for i, reg_std in enumerate(regional_std_values):
        subset = loss_summary[loss_summary["regional_std"] == reg_std]
        ax.errorbar(
            subset["inter_std"], subset[mean_col],
            yerr=subset[std_col],
            label=f"{reg_std:.2f}",
            marker="o", markersize=6,
            lw=2.5, capsize=4,
            color=colors[i % len(colors)]
        )

    ax.set_title(title, fontsize=16, pad=12)
    ax.set_xlabel("Within-Region Standard Deviation", fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)
    ax.legend(title="Across-Region Std", fontsize=11, title_fontsize=12)
    ax.grid(False)
    ax.tick_params(axis='both', which='both', length=0)  # hide tick marks, keep labels
    plt.tight_layout()
    fig.savefig(os.path.join(OUTFOLDER, f"{stem}.jpg"), dpi=300)
    fig.savefig(os.path.join(OUTFOLDER, f"{stem}.pdf"))
    plt.show()
    return fig, ax


fig, ax = plot_metric_vs_inter_std(
    "final_mse_loss_mean", "final_mse_loss_std",
    "Final Region-Level MSE", "Region-Level MSE",
    "final_mse_loss_with_uncertainty",
)
fig, ax = plot_metric_vs_inter_std(
    "final_param_mse_mean", "final_param_mse_std",
    "Final Parameter MSE", "Final Parameter-Level MSE",
    "final_param_mse_with_uncertainty",
)

# Heatmap of the mean final region MSE: rows = inter_std, columns = regional_std.
pivot_loss = loss_summary.pivot(index="inter_std", columns="regional_std", values="final_mse_loss_mean")
fig, ax = plt.subplots(figsize=(7.5, 5.5))
sns.heatmap(
    pivot_loss, annot=True, fmt=".3f", cmap="viridis", linewidths=0.5,
    cbar_kws={'label': 'Final Region MSE'}, ax=ax
)

ax.set_title("Mean Final Region MSE (Heatmap)", fontsize=16, pad=12)
ax.set_xlabel("Across-Region Std", fontsize=14)
ax.set_ylabel("Within-Region Std", fontsize=14)
plt.tight_layout()
fig.savefig(os.path.join(OUTFOLDER, "final_mse_loss_heatmap_mean.jpg"), dpi=300)
fig.savefig(os.path.join(OUTFOLDER, "final_mse_loss_heatmap_mean.pdf"))
plt.show()

# Exp. 2
Unknown Intervention Locations: Examine how regional public school funding affects educational outcomes when the exact subregion receiving the funding is unknown. Each region contains multiple school districts, but at most one district receives additional resources (region 0 receives none), and at the regional level only the total spending is observed. The goal is to identify the hidden intervention location within each region and recover the corresponding local causal effects.

In [None]:
import numpy as np
import pandas as pd
import os

# Define parameters
num_regions, num_subregions = 100, 4
shift_param, scale_param = 1.23, 4.2
noise_var = 0.005
data_dir = 'data_exp2'
os.makedirs(data_dir, exist_ok=True)
np.random.seed(42)


def save_matrix(array, filename):
    """Write ``array`` to ``data_dir/filename`` as a header-less, index-less CSV."""
    pd.DataFrame(array).to_csv(os.path.join(data_dir, filename), index=False, header=False)


# 1) Interventions subregion: region 0 stays untreated; every other region gets one
#    randomly placed intervention of strength U(0.2, 1).
interventions_sub = np.zeros((num_regions, num_subregions))
for region in range(1, num_regions):
    target = np.random.randint(0, num_subregions)
    interventions_sub[region, target] = np.random.uniform(0.2, 1)
save_matrix(interventions_sub, 'interventions_subregion.csv')

# Interventions region: row sums (each row has at most one nonzero entry)
interventions_reg = interventions_sub.sum(axis=1)
save_matrix(interventions_reg, 'interventions_region.csv')

# 2) Context (wealth) subregion, U(0, 1)
context_sub = np.random.uniform(0, 1, (num_regions, num_subregions))
save_matrix(context_sub, 'context_subregion.csv')

# 3) Noise subregion, N(0, noise_var)
noise_std = np.sqrt(noise_var)
noise_sub = np.random.normal(0, noise_std, (num_regions, num_subregions))
save_matrix(noise_sub, 'noise_subregion.csv')

# 4) Outcome subregion: effect shrinks linearly with wealth and scales with intervention size
outcome_sub = (shift_param - context_sub) * scale_param * interventions_sub + noise_sub
save_matrix(outcome_sub, 'outcome_subregion.csv')

# Outcome region: row means
outcome_reg = outcome_sub.mean(axis=1)
save_matrix(outcome_reg, 'outcome_region.csv')

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Define parameters
data_dir = 'data_exp2'
num_regions, num_subregions = 100, 4
grid_size_reg = 10  # regions laid out on a 10x10 grid
grid_size_sub = 20  # subregions on a 20x20 grid (each region is a 2x2 block)
sns.set_context('paper', font_scale=1.5)
deep_palette = sns.color_palette("deep")
reds = deep_palette[3]
blues = deep_palette[0]

def visualize_subregion_csv(file_name, cmap, title, is_wealth=False):
    """Render a per-subregion CSV (one row per region) as a subregion-level heatmap.

    Each region's four values become a 2x2 block at the region's position on the
    ``grid_size_reg`` x ``grid_size_reg`` grid. The figure is saved in ``data_dir``
    as ``<file stem>.jpg`` (300 dpi) and ``.pdf``, then shown and closed.

    Args:
        file_name: CSV file inside ``data_dir``.
        cmap: A colour tuple (expanded with ``sns.light_palette``) or any value
            ``sns.heatmap`` accepts as ``cmap``; ignored when ``is_wealth`` is set.
        title: Figure title.
        is_wealth: Use seaborn's "Spectral" colormap instead of ``cmap``.
    """
    values = pd.read_csv(os.path.join(data_dir, file_name), header=None).values
    grid = np.zeros((grid_size_sub, grid_size_sub))
    for region_idx in range(grid_size_reg * grid_size_reg):
        row, col = divmod(region_idx, grid_size_reg)
        grid[2 * row:2 * row + 2, 2 * col:2 * col + 2] = values[region_idx, :].reshape(2, 2)

    if is_wealth:
        heatmap_cmap = sns.color_palette("Spectral", as_cmap=True)
    elif isinstance(cmap, tuple):
        heatmap_cmap = sns.light_palette(cmap, as_cmap=True)
    else:
        heatmap_cmap = cmap

    fig, ax = plt.subplots(figsize=(8, 8))
    colorbar_opts = {'label': 'Value', 'location': 'right', 'pad': 0.1}
    sns.heatmap(grid, ax=ax, cmap=heatmap_cmap, square=True, cbar=True, cbar_kws=colorbar_opts)
    ax.set_title(title, fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    plt.legend([], [], frameon=False)  # No legend box
    plt.tight_layout()
    stem = file_name.replace('.csv', '')
    plt.savefig(os.path.join(data_dir, f'{stem}.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, f'{stem}.pdf'))
    plt.show()
    plt.close()

def visualize_region_csv(file_name, cmap, title, is_wealth=False):
    """Render a per-region CSV (one value per region) as a region-level heatmap.

    The values are laid out row-major on a ``grid_size_reg`` x ``grid_size_reg`` grid.
    The figure is saved in ``data_dir`` as ``<file stem>.jpg`` (300 dpi) and ``.pdf``,
    then shown and closed.

    Args:
        file_name: CSV file inside ``data_dir``.
        cmap: A colour tuple (expanded with ``sns.light_palette``) or any value
            ``sns.heatmap`` accepts as ``cmap``; ignored when ``is_wealth`` is set.
        title: Figure title.
        is_wealth: Use seaborn's "Spectral" colormap instead of ``cmap``.
    """
    flat = pd.read_csv(os.path.join(data_dir, file_name), header=None).values.flatten()
    grid = flat.reshape(grid_size_reg, grid_size_reg)

    if is_wealth:
        heatmap_cmap = sns.color_palette("Spectral", as_cmap=True)
    elif isinstance(cmap, tuple):
        heatmap_cmap = sns.light_palette(cmap, as_cmap=True)
    else:
        heatmap_cmap = cmap

    fig, ax = plt.subplots(figsize=(8, 8))
    colorbar_opts = {'label': 'Value', 'location': 'right', 'pad': 0.1}
    sns.heatmap(grid, ax=ax, cmap=heatmap_cmap, square=True, cbar=True, cbar_kws=colorbar_opts)
    ax.set_title(title, fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    plt.legend([], [], frameon=False)
    plt.tight_layout()
    stem = file_name.replace('.csv', '')
    plt.savefig(os.path.join(data_dir, f'{stem}.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, f'{stem}.pdf'))
    plt.show()
    plt.close()

# Visualize each saved array, in order: (plot function, csv file, colour, title, is_wealth).
# coolwarm is used for the noise because no colour was specified for it.
plot_jobs = [
    (visualize_subregion_csv, 'interventions_subregion.csv', reds, 'Interventions Subregion', False),
    (visualize_region_csv, 'interventions_region.csv', reds, 'Interventions Region', False),
    (visualize_subregion_csv, 'context_subregion.csv', None, 'Wealth Subregion', True),
    (visualize_subregion_csv, 'noise_subregion.csv', 'coolwarm', 'Noise Subregion', False),
    (visualize_subregion_csv, 'outcome_subregion.csv', blues, 'Outcome Subregion', False),
    (visualize_region_csv, 'outcome_region.csv', blues, 'Outcome Region', False),
]
for plot_fn, csv_name, colour, plot_title, wealth_flag in plot_jobs:
    plot_fn(csv_name, colour, plot_title, is_wealth=wealth_flag)

# Wealth low-res: per-region mean of the subregional wealth, written to CSV so the
# region-level plotting helper can read it back.
context_sub = pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values
wealth_reg = context_sub.mean(axis=1)
pd.DataFrame(wealth_reg).to_csv(os.path.join(data_dir, 'context_region.csv'), index=False, header=False)
visualize_region_csv('context_region.csv', None, 'Wealth Region', is_wealth=True)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Define parameters
data_dir = 'data_exp2'
grid_size_reg = 10  # 10x10 for regions
grid_size_sub = 20  # 20x20 for subregions (2x2 per region)

# Color maps from seaborn deep palette
colors = sns.color_palette("deep")
red_cmap = sns.light_palette(colors[3], as_cmap=True)  # Red
blue_cmap = sns.light_palette(colors[0], as_cmap=True)  # Blue
green_cmap = sns.light_palette(colors[2], as_cmap=True)  # Green

# Horizontal space between subplots (passed as wspace, a fraction of average axis width)
space_between = 0.05  # e.g. 0.0 for no space, 0.5 for more space

# Load data for the combined figure.
# Regional intervention totals: 0 for untreated regions, otherwise one U(0.2, 1) draw.
interv_data = (pd.read_csv(os.path.join(data_dir, 'interventions_region.csv'), header=None)
               .values.flatten().reshape(grid_size_reg, grid_size_reg).astype(float))
outcome_data = (pd.read_csv(os.path.join(data_dir, 'outcome_region.csv'), header=None)
                .values.flatten().reshape(grid_size_reg, grid_size_reg))
context_data = pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values
context_grid = np.zeros((grid_size_sub, grid_size_sub))
for i in range(grid_size_reg):
    for j in range(grid_size_reg):
        region_idx = i * grid_size_reg + j
        context_grid[2*i:2*(i+1), 2*j:2*(j+1)] = context_data[region_idx, :].reshape(2, 2)
# Exp. 2 wealth is drawn U(0, 1) in the generation cell, so it is already on [0, 1].
# A (x - 1) / 2 rescaling only applies to Exp. 1's discrete classes 1-3.
context_norm = context_grid

# Create a single figure with three subplots side by side
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# Left: Intervention in red
sns.heatmap(interv_data, ax=axs[0], cmap=red_cmap, cbar=False, square=True, linewidths=2, vmin=0.0, vmax=1.0)
axs[0].set_xticks([])
axs[0].set_yticks([])

# Middle: Outcome in blue.
# NOTE(review): with shift 1.23, scale 4.2 and interventions up to 1, a regional mean can
# reach about 1.23 * 4.2 / 4 ≈ 1.29, so vmax=1.0 may saturate the largest outcomes — confirm.
sns.heatmap(outcome_data, ax=axs[1], cmap=blue_cmap, cbar=False, square=True, linewidths=2, vmin=0.0, vmax=1.0)
axs[1].set_xticks([])
axs[1].set_yticks([])

# Right: Context in green, on its known [0, 1] range
sns.heatmap(context_norm, ax=axs[2], cmap=green_cmap, cbar=False, square=True, linewidths=1, vmin=0.0, vmax=1.0)
axs[2].set_xticks([])
axs[2].set_yticks([])

# Adjust the space between subplots
fig.subplots_adjust(wspace=space_between)

# Save and show
fig.savefig(os.path.join(data_dir, 'combined_heatmaps.jpg'), dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(os.path.join(data_dir, 'combined_heatmaps.pdf'), bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Define parameters
data_dir = 'data_exp2'
num_regions, num_subregions = 100, 4
epochs, lr = 10000, 0.001
torch.manual_seed(42)
grid_size_sub = 20
sns.set_context('paper', font_scale=1.5)
reds = sns.color_palette("deep")[3]

# True parameters (from data generation)
true_shift_param, true_scale_param = 1.23, 4.2


def load_csv_tensor(filename, flatten=False):
    """Read a header-less CSV from data_dir into a float32 tensor, optionally flattened."""
    values = pd.read_csv(os.path.join(data_dir, filename), header=None).values
    if flatten:
        values = values.flatten()
    return torch.tensor(values, dtype=torch.float32)


# Regional observations, subregional context, and subregional outcomes
# (the latter only feed the diagnostic subregion loss, not the gradient).
interventions_reg = load_csv_tensor('interventions_region.csv', flatten=True)
outcome_reg = load_csv_tensor('outcome_region.csv', flatten=True)
context_sub = load_csv_tensor('context_subregion.csv')
outcome_sub = load_csv_tensor('outcome_subregion.csv')


# Model
class CausalModel(nn.Module):
    """Latent-intervention model for Exp. 2.

    Learns per-region scores over subregions (``inter_est``). These are turned into a
    soft, near-one-hot allocation of the known regional intervention total. The model
    also learns the global ``shift`` and ``scale`` of the outcome mechanism
    ``(shift - context) * scale * intervention``.

    Relies on the notebook globals ``num_regions``, ``num_subregions``,
    ``interventions_reg`` (regional totals) and ``context_sub`` (subregional context).
    """

    def __init__(self):
        super().__init__()
        self.inter_est = nn.Parameter(torch.randn(num_regions, num_subregions))
        self.shift = nn.Parameter(torch.tensor(1.0))
        self.scale = nn.Parameter(torch.tensor(1.0))
        # Softmax temperature; the training loop anneals it by overwriting this attribute.
        self.temp = torch.tensor(10.0)
        # Std of the Gaussian jitter added to the scores (training mode only).
        self.noise_scale = 0.1

    def preprocess_inter(self):
        """Return the (num_regions, num_subregions) estimated subregional interventions.

        Each row is a softmax over squared scores, multiplied by the region's known
        total, so rows sum to ``interventions_reg``. Random jitter is added to the scores
        only in training mode, so after ``model.eval()`` the estimate is deterministic.
        """
        scores = self.inter_est
        if self.training:
            scores = scores + torch.randn_like(scores) * self.noise_scale
        inter_pos = scores ** 2  # Ensure positivity
        inter_sparse = torch.softmax(inter_pos / self.temp, dim=1)  # Differentiable near-one-hot sparsity
        return inter_sparse * interventions_reg.unsqueeze(1)  # Multiply by known region sum

    def forward(self):
        """Return ``(regional outcome prediction, subregional outcome prediction)``."""
        inter_sub = self.preprocess_inter()
        outcome_sub_pred = (self.shift - context_sub) * self.scale * inter_sub
        outcome_reg_pred = outcome_sub_pred.mean(dim=1)
        return outcome_reg_pred, outcome_sub_pred

model = CausalModel()
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()

# Train
losses = []
param_losses = []
subregion_losses = []
parameter_logs = {'shift': [], 'scale': []}
start_temp = 10.0
end_temp = 2.0
true_param_vec = torch.tensor([true_shift_param, true_scale_param])
# Denominator of the annealing fraction; max(..., 1) avoids ZeroDivisionError when epochs == 1.
anneal_steps = max(epochs - 1, 1)
for epoch in range(epochs):
    # Linearly anneal the softmax temperature from start_temp (first epoch) to end_temp (last).
    model.temp = torch.tensor(start_temp + (end_temp - start_temp) * (epoch / anneal_steps))
    optimizer.zero_grad()
    outcome_reg_pred, outcome_sub_pred = model()
    # Only the aggregated regional outcome drives the gradient.
    loss = criterion(outcome_reg_pred, outcome_reg)

    # Diagnostic only: subregional fit against the ground truth, kept out of autograd.
    with torch.no_grad():
        subregion_losses.append(criterion(outcome_sub_pred, outcome_sub).item())

    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    # Log parameter values (after this epoch's update)
    shift_now, scale_now = model.shift.item(), model.scale.item()
    parameter_logs['shift'].append(shift_now)
    parameter_logs['scale'].append(scale_now)

    # MSE between the current and the true generating parameters
    param_losses.append(criterion(torch.tensor([shift_now, scale_now]), true_param_vec).item())


# Save loss histories to CSV: one column each, with a header row.
for history, csv_name, column in (
    (losses, 'outcome_region_loss.csv', 'loss'),
    (param_losses, 'parameter_loss.csv', 'param_loss'),
    (subregion_losses, 'outcome_subregion_loss.csv', 'subregion_loss'),
):
    pd.DataFrame(history).to_csv(os.path.join(data_dir, csv_name), index=False, header=[column])


def save_training_plot(series, title, ylabel, stem):
    """Plot each ``(label, values)`` pair against epoch and save ``<stem>.jpg``/``.pdf``.

    A ``None`` label plots without a legend entry. A legend is drawn only if at least
    one series is labelled. The figure is saved in ``data_dir``, shown, then closed.
    """
    plt.figure(figsize=(10, 5))
    for label, values in series:
        if label is None:
            plt.plot(values)
        else:
            plt.plot(values, label=label)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    if any(label is not None for label, _ in series):
        plt.legend()
    plt.savefig(os.path.join(data_dir, f'{stem}.jpg'))
    plt.savefig(os.path.join(data_dir, f'{stem}.pdf'))
    plt.show()
    plt.close()


save_training_plot([(None, losses)], 'Loss Curve (Outcome Region)', 'MSE Loss', 'loss_curve_region')
save_training_plot([(None, subregion_losses)], 'Loss Curve (Outcome Subregion)', 'MSE Loss', 'loss_curve_subregion')
save_training_plot(
    [('Shift Parameter', parameter_logs['shift']), ('Scale Parameter', parameter_logs['scale'])],
    'Parameter Values During Training', 'Parameter Value', 'parameter_values',
)
save_training_plot([(None, param_losses)], 'Loss w.r.t True Parameters', 'MSE Loss', 'parameter_loss_curve')


# Print final estimates
print(f"Final shift_param: {model.shift.item()}")
print(f"Final scale_param: {model.scale.item()}")

# Save params to txt, one "<name>: <value>" line each (read back later by splitting on ': ')
with open(os.path.join(data_dir, 'parameter_estimate.txt'), 'w') as f:
    f.write(f"shift_param: {model.shift.item()}\n")
    f.write(f"scale_param: {model.scale.item()}\n")

# Save estimated interventions sub (processed). This is inference, so switch to eval mode
# and disable autograd instead of detaching afterwards.
# NOTE(review): preprocess_inter adds random jitter to the scores; the saved estimate is
# deterministic only if the model skips that jitter outside training mode — confirm.
model.eval()
with torch.no_grad():
    inter_est_final = model.preprocess_inter().numpy()
pd.DataFrame(inter_est_final).to_csv(os.path.join(data_dir, 'interventions_subregion_estimated.csv'), index=False, header=False)

# Visualize estimated subregion
def visualize_estimated_subregion():
    """Draw ``inter_est_final`` as a ``grid_size_sub`` x ``grid_size_sub`` heatmap.

    Each region's ``num_subregions`` estimates form a square block at the region's
    position; regions are laid out row-major. The block and region counts are derived
    from ``num_subregions`` and ``grid_size_sub`` instead of hard-coded 2 and 10. The
    figure is saved as interventions_subregion_estimated.jpg (300 dpi) and .pdf in
    ``data_dir``, then shown and closed.
    """
    data = inter_est_final
    block_side = int(round(np.sqrt(num_subregions)))  # 2 for four subregions
    regions_per_side = grid_size_sub // block_side    # 10 for a 20x20 subregion grid
    grid = np.zeros((grid_size_sub, grid_size_sub))
    for i in range(regions_per_side):
        for j in range(regions_per_side):
            region_idx = i * regions_per_side + j
            sub_data = data[region_idx, :].reshape(block_side, block_side)
            grid[block_side*i:block_side*(i+1), block_side*j:block_side*(j+1)] = sub_data
    fig, ax = plt.subplots(figsize=(8, 8))
    heatmap_cmap = sns.light_palette(reds, as_cmap=True)
    sns.heatmap(grid, ax=ax, cmap=heatmap_cmap, square=True, cbar=True, cbar_kws={'label': 'Value', 'location': 'right', 'pad': 0.1})
    ax.set_title('Estimated Interventions Subregion', fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    plt.legend([], [], frameon=False)
    plt.tight_layout()
    plt.savefig(os.path.join(data_dir, 'interventions_subregion_estimated.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, 'interventions_subregion_estimated.pdf'))
    plt.show()
    plt.close()

visualize_estimated_subregion()

In [None]:
data_dir = 'data_exp2'
loglog_flag = False


def read_loss_column(csv_name):
    """Return the first column of a loss CSV in data_dir (header row skipped) as an array."""
    return pd.read_csv(os.path.join(data_dir, csv_name), header=0).iloc[:, 0].values


loss_hist = read_loss_column('outcome_region_loss.csv')
ce_loss_hist = read_loss_column('outcome_subregion_loss.csv')
epochs = np.arange(1, len(loss_hist) + 1)

# Two y-axes: regional training loss (left, solid) and subregional MSE (right, dashed).
fig, ax1 = plt.subplots(figsize=(8, 6))
colors = sns.color_palette("deep")
train_color, local_color = colors[0], colors[3]

ax1.set_xlabel('Epoch', fontsize=40)
ax1.set_ylabel('Training Loss', fontsize=30, color=train_color)
if loglog_flag:
    ax1.loglog(epochs, loss_hist, color=train_color, alpha=0.9, lw=3)
else:
    ax1.plot(epochs, loss_hist, color=train_color, alpha=0.7, lw=2)
ax1.tick_params(axis='both', which='both', length=0, labelsize=18, colors=train_color)
ax1.spines['top'].set_visible(False)

ax2 = ax1.twinx()
ax2.set_ylabel('MSE Local Effect', fontsize=30, color=local_color)
if loglog_flag:
    ax2.loglog(epochs, ce_loss_hist, color=local_color, alpha=0.9, ls='--', lw=5)
else:
    ax2.plot(epochs, ce_loss_hist, color=local_color, alpha=0.7, ls='--', lw=2)
ax2.tick_params(axis='y', which='both', length=0, labelsize=18, colors=local_color)
ax2.spines['top'].set_visible(False)

plt.tight_layout()
# NOTE(review): the JPG and PDF use different file stems
# ("estimation_loss" vs "exp2_estimation_loss") — confirm this is intended.
fig.savefig(os.path.join(data_dir, 'estimation_loss.jpg'), dpi=300)
fig.savefig(os.path.join(data_dir, 'exp2_estimation_loss.pdf'))
plt.show()


In [None]:
# Noise-free ground-truth subregional outcome from the generator's parameters and true
# intervention placement (shift_param, scale_param and interventions_sub come from the
# Exp. 2 data-generation cell). Convert to numpy once, whatever the source type.
context_sub_np = context_sub.numpy() if isinstance(context_sub, torch.Tensor) else context_sub
interventions_sub_np = interventions_sub.numpy() if isinstance(interventions_sub, torch.Tensor) else interventions_sub
true_outcome_sub = (shift_param - context_sub_np) * scale_param * interventions_sub_np

# Load estimated interventions
inter_est_final = pd.read_csv(os.path.join(data_dir, 'interventions_subregion_estimated.csv'), header=None).values

# Check the shape of the loaded estimated interventions (warn only; keep going)
if inter_est_final.shape != (num_regions, num_subregions):
    print(f"Warning: Loaded estimated interventions have shape {inter_est_final.shape}, expected ({num_regions}, {num_subregions})")

# Learned parameters, stored one per line as "shift_param: <v>" then "scale_param: <v>".
with open(os.path.join(data_dir, 'parameter_estimate.txt'), 'r') as f:
    lines = f.readlines()
estimated_shift = float(lines[0].split(': ')[1])
estimated_scale = float(lines[1].split(': ')[1])

# Predicted subregional outcome from the estimated interventions and learned parameters
predicted_outcome_sub = (estimated_shift - context_sub_np) * estimated_scale * inter_est_final

# Visualize true vs predicted outcome at subregion level
def visualize_true_vs_predicted_subregion(true_data, predicted_data, num_regions=100, num_subregions=4):
    """Plot true, predicted and residual (true - predicted) subregion outcomes side by side.

    Row ``r`` of ``true_data`` / ``predicted_data`` holds the ``num_subregions``
    subregion values of region ``r``. Regions are tiled row-major on a
    ``sqrt(num_regions)`` x ``sqrt(num_regions)`` grid. Each region's
    subregions form a ``sqrt(num_subregions)`` x ``sqrt(num_subregions)`` tile,
    laid out row-major. Both counts must therefore be perfect squares.

    The figure is saved to the notebook-level ``data_dir`` as a 300-dpi JPEG
    and as a PDF, then shown and closed.

    Args:
        true_data: array-like with at least ``num_regions`` rows of
            ``num_subregions`` values each.
        predicted_data: array-like with the same layout as ``true_data``.
        num_regions: number of regions to plot (a perfect square).
        num_subregions: subregions per region (a perfect square).

    Raises:
        ValueError: if ``num_regions`` or ``num_subregions`` is not a perfect
            square, or an input does not match the expected layout.
    """
    regions_per_dim = int(round(np.sqrt(num_regions)))
    subregions_per_dim = int(round(np.sqrt(num_subregions)))
    # Guard explicitly: a truncating int(np.sqrt(...)) would otherwise tile a
    # smaller grid and silently drop the leftover regions.
    if regions_per_dim ** 2 != num_regions or subregions_per_dim ** 2 != num_subregions:
        raise ValueError(
            f"num_regions ({num_regions}) and num_subregions ({num_subregions}) must be perfect squares"
        )
    side = regions_per_dim * subregions_per_dim

    def _to_grid(data):
        """Tile (region, subregion) values into a (side, side) float grid."""
        arr = np.asarray(data, dtype=float)
        if arr.ndim != 2 or arr.shape[0] < num_regions or arr.shape[1] != num_subregions:
            raise ValueError(
                f"expected data of shape ({num_regions}, {num_subregions}), got {arr.shape}"
            )
        # (region_row, region_col, sub_row, sub_col) -> (region_row, sub_row, region_col, sub_col)
        tiles = arr[:num_regions].reshape(regions_per_dim, regions_per_dim, subregions_per_dim, subregions_per_dim)
        return tiles.transpose(0, 2, 1, 3).reshape(side, side)

    true_grid = _to_grid(true_data)
    predicted_grid = _to_grid(predicted_data)
    difference_grid = true_grid - predicted_grid

    # Colour maps: light ramp of seaborn "deep" blue for outcomes, diverging map for residuals.
    colors = sns.color_palette("deep")
    blue_cmap = sns.light_palette(colors[0], as_cmap=True)
    coolwarm_cmap = 'coolwarm'

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panels = ((true_grid, blue_cmap), (predicted_grid, blue_cmap), (difference_grid, coolwarm_cmap))
    for ax, (grid, cmap) in zip(axs, panels):
        sns.heatmap(grid, ax=ax, cmap=cmap, square=True, cbar=False, linewidths=1)
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()
    # Raster copy at 300 dpi plus a vector PDF, as for the other figures in this notebook.
    out_stem = os.path.join(data_dir, 'exp2_true_vs_predicted_outcome_subregion_styled')
    fig.savefig(f'{out_stem}.jpg', dpi=300)
    fig.savefig(f'{out_stem}.pdf')
    plt.show()
    plt.close(fig)

visualize_true_vs_predicted_subregion(true_outcome_sub, predicted_outcome_sub, num_regions, num_subregions)

# Mean squared error over all (region, subregion) cells, each weighted equally.
residual_sub = np.asarray(true_outcome_sub) - np.asarray(predicted_outcome_sub)
mse_subregion = np.square(residual_sub).mean()
print(f"MSE between true and predicted subregion outcome: {mse_subregion}")

# Exp. 3
Hidden Spatial Confounding: Estimate the effect of heatwaves (regional interventions over time) on school outcomes when both parental education level (observed covariate) and vegetation coverage (unobserved confounder) modulate the effect. The goal is to recover local treatment effects and reconstruct the hidden confounder field from aggregated outcomes.

In [None]:
import os
import numpy as np
import pandas as pd

# Synthetic data for Exp. 3. Every random draw below comes from the global
# NumPy RNG, so the order of steps 1-4 determines the generated dataset.
np.random.seed(42)

num_months = 48
num_regions = 100
num_subregions_per_region = 9  # plotted as a 3x3 block per region by the cells below
data_folder = 'data_exp3'

# Create directories
os.makedirs(f'{data_folder}/interventions_subregion', exist_ok=True)
os.makedirs(f'{data_folder}/interventions_region', exist_ok=True)
os.makedirs(f'{data_folder}/context_subregion', exist_ok=True)
os.makedirs(f'{data_folder}/noise_subregion', exist_ok=True)
os.makedirs(f'{data_folder}/outcome_subregion', exist_ok=True)
os.makedirs(f'{data_folder}/outcome_region', exist_ok=True)

# 1) Interventions: each month every region is independently treated with
#    probability 0.4, and all of its subregions share that 0/1 value.
for month in range(1, num_months + 1):
    interventions_sub = np.zeros((num_regions, num_subregions_per_region))
    for r in range(num_regions):
        val = 1 if np.random.rand() < 0.4 else 0
        interventions_sub[r, :] = val
    pd.DataFrame(interventions_sub).to_csv(
        f'{data_folder}/interventions_subregion/interventions_subregion_{month:02d}.csv', index=False, header=False
    )
    # Region-level intervention = mean over subregions (0 or 1, since each row is constant).
    interventions_reg = interventions_sub.mean(axis=1).reshape(-1, 1)
    pd.DataFrame(interventions_reg).to_csv(
        f'{data_folder}/interventions_region/interventions_region_{month:02d}.csv', index=False, header=False
    )

# 2) Context (parents education): levels 1, 2, 3 drawn uniformly in month 1.
#    Afterwards each cell is carried over from the previous month, except that
#    with probability 0.05 it is redrawn uniformly (possibly to the same level).
prev_context = None
for month in range(1, num_months + 1):
    if month == 1:
        context = np.random.randint(1, 4, size=(num_regions, num_subregions_per_region))
    else:
        context = prev_context.copy()
        flips = np.random.rand(num_regions, num_subregions_per_region) < 0.05
        context[flips] = np.random.randint(1, 4, size=flips.sum())
    pd.DataFrame(context).to_csv(
        f'{data_folder}/context_subregion/context_subregion_{month:02d}.csv', index=False, header=False
    )
    prev_context = context

# 3) Unobserved confounder (vegetation): one static Uniform[0, 1) field shared by all months.
vegetation = np.random.uniform(0, 1, size=(num_regions, num_subregions_per_region))
pd.DataFrame(vegetation).to_csv(
    f'{data_folder}/unobserved_confounder_context_subregion.csv', index=False, header=False
)

# 4) Noise: N(0, variance 0.02) per cell and month (the std argument is sqrt(0.02)).
#    NOTE(review): these noise files are written, but step 5 never adds them to
#    the outcomes, so the saved outcomes are noise-free — confirm this is intended.
for month in range(1, num_months + 1):
    noise = np.random.normal(0, np.sqrt(0.02), size=(num_regions, num_subregions_per_region))
    pd.DataFrame(noise).to_csv(
        f'{data_folder}/noise_subregion/noise_subregion_{month:02d}.csv', index=False, header=False
    )

# 5) Outcomes: a treated subregion gets an effect of 10 / 5 / 1 for context
#    level 1 / 2 / 3, damped by (1 - vegetation); an untreated subregion gets 0.
#    The inputs are re-read from the CSVs written above.
for month in range(1, num_months + 1):
    interv_sub = pd.read_csv(
        f'{data_folder}/interventions_subregion/interventions_subregion_{month:02d}.csv', header=None
    ).values
    context = pd.read_csv(
        f'{data_folder}/context_subregion/context_subregion_{month:02d}.csv', header=None
    ).values
    poor_indicator = (context == 1).astype(float)
    intermediate_indicator = (context == 2).astype(float)
    rich_indicator = (context == 3).astype(float)
    outcome_sub = ((10 * poor_indicator * interv_sub) + (5 * intermediate_indicator * interv_sub) + (rich_indicator * interv_sub)) * (1 - vegetation)
    pd.DataFrame(outcome_sub).to_csv(
        f'{data_folder}/outcome_subregion/outcome_subregion_{month:02d}.csv', index=False, header=False
    )
    # Region-level outcome = mean of the region's subregion outcomes.
    outcome_reg = outcome_sub.mean(axis=1).reshape(-1, 1)
    pd.DataFrame(outcome_reg).to_csv(
        f'{data_folder}/outcome_region/outcome_region_{month:02d}.csv', index=False, header=False
    )

In [None]:
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap

# Parameters
data_folder = 'data_exp3'
num_regions = 100
region_grid_size = 10  # regions form a 10x10 grid
sub_per_region_side = 3  # each region is a 3x3 block of subregions
sub_grid_size = region_grid_size * sub_per_region_side  # 30x30 subregion grid

# Color definitions (seaborn "deep" palette)
colors = sns.color_palette("deep")
blue, green, red = colors[0], colors[2], colors[3]

# Function to reshape subregion data (100x9) to 30x30 grid
def reshape_to_subgrid(data):
    """Tile per-region subregion rows into the (sub_grid_size, sub_grid_size) float grid.

    Region ``ri * region_grid_size + rj`` fills the ``sub_per_region_side``-wide
    square whose top-left corner is ``(ri * sub_per_region_side, rj * sub_per_region_side)``.
    Its row is laid out row-major inside that square.
    """
    n_reg, n_sub = region_grid_size, sub_per_region_side
    span = n_reg * n_sub
    grid = np.zeros((sub_grid_size, sub_grid_size))
    blocks = np.asarray(data[:n_reg * n_reg]).reshape(n_reg, n_reg, n_sub, n_sub)
    # (ri, rj, a, b) -> (ri, a, rj, b) so that row ri*n_sub + a, column rj*n_sub + b lines up.
    grid[:span, :span] = blocks.transpose(0, 2, 1, 3).reshape(span, span)
    return grid

# Function to reshape region data (100x1) to 10x10 grid
def reshape_to_reggrid(data):
    """Arrange one value per region (row-major region order) into the square region grid."""
    side = region_grid_size
    return data.reshape((side, side))

# Base plot function
def base_heatmap(grid, cmap, title, vmin=None, vmax=None, cbar_ticks=None, cbar_labels=None, figsize=(10,10)):
    """Draw ``grid`` as a square, titled seaborn heatmap and return the new figure.

    The colorbar is shrunk to half height. It is relabelled only when both
    ``cbar_ticks`` and ``cbar_labels`` are given. Saving and showing are left
    to the caller.
    """
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(grid, ax=ax, cmap=cmap, vmin=vmin, vmax=vmax, square=True,
                linewidths=0, cbar_kws={'shrink': 0.5})
    ax.set_aspect('equal')
    ax.set_title(title, fontsize=20)
    ax.tick_params(labelsize=14)
    relabel_colorbar = cbar_ticks is not None and cbar_labels is not None
    if relabel_colorbar:
        colorbar = ax.collections[0].colorbar
        colorbar.set_ticks(cbar_ticks)
        colorbar.set_ticklabels(cbar_labels)
        colorbar.ax.tick_params(labelsize=14)
    fig.tight_layout()
    return fig

# Specific plot functions (one per data type)
def plot_interventions_sub(file_path):
    """Render a subregion intervention CSV (100x9) as a 30x30 red heatmap on [0, 1].

    Saves ``file_path`` with its last four characters (".csv") replaced by
    ".jpg" and ".pdf", then shows and closes the figure.
    """
    stem = file_path[:-4]
    grid = reshape_to_subgrid(pd.read_csv(file_path, header=None).values)
    fig = base_heatmap(grid, sns.light_palette(red, as_cmap=True), os.path.basename(file_path),
                       vmin=0, vmax=1)
    for suffix in ('.jpg', '.pdf'):
        fig.savefig(stem + suffix)
    plt.show()
    plt.close(fig)

def plot_interventions_reg(file_path):
    """Render a region intervention CSV (100x1) as a 10x10 red heatmap on [0, 1].

    Saves ``file_path`` with its last four characters replaced by ".jpg" and
    ".pdf", then shows and closes the figure.
    """
    stem = file_path[:-4]
    values = pd.read_csv(file_path, header=None).values.squeeze()  # 100 values
    grid = reshape_to_reggrid(values)
    fig = base_heatmap(grid, sns.light_palette(red, as_cmap=True), os.path.basename(file_path),
                       vmin=0, vmax=1)
    for suffix in ('.jpg', '.pdf'):
        fig.savefig(stem + suffix)
    plt.show()
    plt.close(fig)

def plot_context_sub(file_path):
    """Render a subregion context CSV (100x9) as a 30x30 categorical heatmap.

    Levels 1/2/3 are shown as poor/intermediate/rich using three discrete
    "Spectral" colours. The colour range is padded by half a level on each side
    (0.5 to 3.5). Each level then fills exactly one colour band, and the
    colorbar ticks sit in the middle of their band instead of on its edges.

    Saves ``file_path`` with its last four characters replaced by ".jpg" and
    ".pdf", then shows and closes the figure.
    """
    data = pd.read_csv(file_path, header=None).values  # 100x9
    grid = reshape_to_subgrid(data)
    cmap = ListedColormap(sns.color_palette("Spectral", 3))
    fig = base_heatmap(grid, cmap, os.path.basename(file_path), vmin=0.5, vmax=3.5,
                       cbar_ticks=[1, 2, 3], cbar_labels=['poor', 'intermediate', 'rich'])
    fig.savefig(f'{file_path[:-4]}.jpg')
    fig.savefig(f'{file_path[:-4]}.pdf')
    plt.show()
    plt.close(fig)

def plot_confounder(file_path):
    """Render the vegetation CSV (100x9) as a 30x30 green heatmap on [0, 1].

    Saves ``file_path`` with its last four characters replaced by ".jpg" and
    ".pdf", then shows and closes the figure.
    """
    stem = file_path[:-4]
    grid = reshape_to_subgrid(pd.read_csv(file_path, header=None).values)
    fig = base_heatmap(grid, sns.light_palette(green, as_cmap=True), os.path.basename(file_path),
                       vmin=0, vmax=1)
    for suffix in ('.jpg', '.pdf'):
        fig.savefig(stem + suffix)
    plt.show()
    plt.close(fig)

def plot_outcome_sub(file_path):
    """Render a subregion outcome CSV (100x9) as a 30x30 blue heatmap on [0, 10].

    Saves ``file_path`` with its last four characters replaced by ".jpg" and
    ".pdf", then shows and closes the figure.
    """
    stem = file_path[:-4]
    grid = reshape_to_subgrid(pd.read_csv(file_path, header=None).values)
    fig = base_heatmap(grid, sns.light_palette(blue, as_cmap=True), os.path.basename(file_path),
                       vmin=0, vmax=10)
    for suffix in ('.jpg', '.pdf'):
        fig.savefig(stem + suffix)
    plt.show()
    plt.close(fig)

def plot_outcome_reg(file_path):
    """Render a region outcome CSV (100x1) as a 10x10 blue heatmap on [0, 10].

    Saves ``file_path`` with its last four characters replaced by ".jpg" and
    ".pdf", then shows and closes the figure.
    """
    stem = file_path[:-4]
    values = pd.read_csv(file_path, header=None).values.squeeze()  # 100 values
    grid = reshape_to_reggrid(values)
    fig = base_heatmap(grid, sns.light_palette(blue, as_cmap=True), os.path.basename(file_path),
                       vmin=0, vmax=10)
    for suffix in ('.jpg', '.pdf'):
        fig.savefig(stem + suffix)
    plt.show()
    plt.close(fig)

def plot_noise_sub(file_path):
    """Render a subregion noise CSV (100x9) as a 30x30 blue heatmap with automatic range.

    NOTE(review): the noise is signed (drawn from a zero-mean normal in the
    generation cell), so a diverging colour map may read better than this
    sequential one.

    Saves ``file_path`` with its last four characters replaced by ".jpg" and
    ".pdf", then shows and closes the figure.
    """
    stem = file_path[:-4]
    grid = reshape_to_subgrid(pd.read_csv(file_path, header=None).values)
    fig = base_heatmap(grid, sns.light_palette(blue, as_cmap=True), os.path.basename(file_path))
    for suffix in ('.jpg', '.pdf'):
        fig.savefig(stem + suffix)
    plt.show()
    plt.close(fig)

# Unobserved confounder (single file); skipped silently if the file is missing.
confounder_file = f'{data_folder}/unobserved_confounder_context_subregion.csv'
if os.path.exists(confounder_file):
    plot_confounder(confounder_file)

# Independent plotting blocks (one per data type, no shared loop or conditions)
# Each block plots every monthly CSV of one type in sorted filename order. The
# zero-padded month numbers make that order chronological. The plotters write
# .jpg/.pdf files next to each CSV; the *.csv globs do not pick those up.
# Interventions subregion
for file in sorted(glob.glob(f'{data_folder}/interventions_subregion/*.csv')):
    plot_interventions_sub(file)

# Interventions region
for file in sorted(glob.glob(f'{data_folder}/interventions_region/*.csv')):
    plot_interventions_reg(file)

# Context subregion
for file in sorted(glob.glob(f'{data_folder}/context_subregion/*.csv')):
    plot_context_sub(file)

# Noise subregion
for file in sorted(glob.glob(f'{data_folder}/noise_subregion/*.csv')):
    plot_noise_sub(file)

# Outcome subregion
for file in sorted(glob.glob(f'{data_folder}/outcome_subregion/*.csv')):
    plot_outcome_sub(file)

# Outcome region
for file in sorted(glob.glob(f'{data_folder}/outcome_region/*.csv')):
    plot_outcome_reg(file)

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Parameters
data_folder = 'data_exp3'
region_grid_size = 10  # regions form a 10x10 grid
sub_per_region_side = 3  # 3x3 subregions per region
sub_grid_size = region_grid_size * sub_per_region_side  # 30x30 subregion grid

# Colormaps: light-to-saturated ramps of the seaborn "deep" red, blue and green.
colors = sns.color_palette("deep")
red_cmap, blue_cmap, green_cmap = (
    sns.light_palette(colors[idx], as_cmap=True) for idx in (3, 0, 2)
)

# Function to reshape subregion data (100x9) to 30x30 grid
def reshape_to_subgrid(data):
    """Tile per-region subregion rows into the (sub_grid_size, sub_grid_size) float grid.

    Region ``ri * region_grid_size + rj`` fills the ``sub_per_region_side``-wide
    square whose top-left corner is ``(ri * sub_per_region_side, rj * sub_per_region_side)``.
    Its row is laid out row-major inside that square.
    """
    n_reg, n_sub = region_grid_size, sub_per_region_side
    span = n_reg * n_sub
    grid = np.zeros((sub_grid_size, sub_grid_size))
    blocks = np.asarray(data[:n_reg * n_reg]).reshape(n_reg, n_reg, n_sub, n_sub)
    # (ri, rj, a, b) -> (ri, a, rj, b) so that row ri*n_sub + a, column rj*n_sub + b lines up.
    grid[:span, :span] = blocks.transpose(0, 2, 1, 3).reshape(span, span)
    return grid

# Function to reshape region data (100x1) to 10x10 grid
def reshape_to_reggrid(data):
    """Arrange one value per region (row-major region order) into the square region grid."""
    side = region_grid_size
    return data.reshape((side, side))

# Parameter to control space between subplots (as fraction of average axis width)
space_between = 0.05
# Index of the monthly file to show. NOTE(review): the data files are monthly,
# so this is a month index despite its name; the name is also baked into the
# output filenames.
day = 1

# Load the three inputs for the combined figure.
interv_file = os.path.join(data_folder, 'interventions_region', f'interventions_region_{day:02d}.csv')
outcome_file = os.path.join(data_folder, 'outcome_region', f'outcome_region_{day:02d}.csv')
context_file = os.path.join(data_folder, 'context_subregion', f'context_subregion_{day:02d}.csv')

interv_data = pd.read_csv(interv_file, header=None).values.squeeze()
outcome_data = pd.read_csv(outcome_file, header=None).values.squeeze()
context_data = pd.read_csv(context_file, header=None).values

interv_grid = reshape_to_reggrid(interv_data.astype(float))  # 0/1 region flags as floats
outcome_grid = reshape_to_reggrid(outcome_data)  # raw (unnormalised) region outcomes
context_grid = reshape_to_subgrid(context_data)
# Map levels 1, 2, 3 to 0, 0.5, 1, then lift into [1/6, 1] so the lowest level is not drawn white.
context_norm = (context_grid - 1) / 2.0
context_norm = (context_norm + 0.2) / 1.2

# One figure, three panels: interventions (red), outcomes (blue, automatic range), context (green).
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (interv_grid, red_cmap, dict(linewidths=2, vmin=0.0, vmax=1.0)),
    (outcome_grid, blue_cmap, dict(linewidths=2)),
    (context_norm, green_cmap, dict(linewidths=1.0, vmin=0.0, vmax=1.0)),
]
for ax, (grid, cmap, extra) in zip(axs, panels):
    sns.heatmap(grid, ax=ax, cmap=cmap, cbar=False, square=True, **extra)
    ax.set_xticks([])
    ax.set_yticks([])

# Adjust the space between subplots
fig.subplots_adjust(wspace=space_between)

# Save and show
fig.savefig(os.path.join(data_folder, f'combined_heatmaps_day{day}.jpg'), dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(os.path.join(data_folder, f'combined_heatmaps_day{day}.pdf'), bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from matplotlib.colors import ListedColormap

np.random.seed(42)
torch.manual_seed(42)

# Parameters
data_folder = 'data_exp3'
num_months = 48
num_regions = 100
num_subregions_per_region = 9
sub_grid_size = 30
region_grid_size = 10
sub_per_region_side = 3
learning_rate = 0.0001
epochs = 10000


def _read_monthly(kind, month):
    """Return the headerless CSV ``<data_folder>/<kind>/<kind>_<MM>.csv`` as a NumPy array."""
    return pd.read_csv(f'{data_folder}/{kind}/{kind}_{month:02d}.csv', header=None).values


# Load data: one entry per month, in month order (list index 0 = month 01).
interv_reg_list = [_read_monthly('interventions_region', m).squeeze() for m in range(1, num_months + 1)]
context_list = [_read_monthly('context_subregion', m) for m in range(1, num_months + 1)]
outcome_reg_list = [_read_monthly('outcome_region', m).squeeze() for m in range(1, num_months + 1)]

# Ground-truth vegetation; only compared against the estimate, never fed to the model.
gt_veg = pd.read_csv(f'{data_folder}/unobserved_confounder_context_subregion.csv', header=None).values

# Model
class Model(nn.Module):
    """Subregion outcome MLP plus a free vegetation logit per (region, subregion).

    The MLP takes 5 features per subregion. The training code below builds
    them as a one-hot parent-education vector (3), the intervention flag, and
    sigmoid(veg). There is no ``forward``; callers use ``self.mlp`` directly.
    """

    def __init__(self):
        super().__init__()
        # Layer order (and therefore parameter init order and state_dict keys) is fixed.
        layers = [
            nn.Linear(5, 32), nn.ReLU(),
            nn.Linear(32, 32), nn.Dropout(p=0.1), nn.ReLU(),
            nn.Linear(32, 8), nn.ReLU(),
            nn.Linear(8, 1),
        ]
        self.mlp = nn.Sequential(*layers)
        self.veg = nn.Parameter(torch.randn(num_regions, num_subregions_per_region))

# Training: one Adam step per month (all regions at once). The months are
# visited in a fresh random order each epoch. model.veg is a registered
# Parameter, so Adam updates the vegetation logits together with the MLP.
model = Model()
optimizer = Adam(model.parameters(), lr=learning_rate)
losses = []
veg_mses = []

for epoch in range(epochs):
    perm = np.random.permutation(num_months)
    epoch_loss = 0.0
    for i in perm:
        parent_edu = torch.tensor(context_list[i], dtype=torch.long)
        interv = torch.tensor(interv_reg_list[i], dtype=torch.float32)
        gt = torch.tensor(outcome_reg_list[i], dtype=torch.float32)

        # Per-subregion features: one-hot education (levels 1-3 -> indices 0-2),
        # the region's intervention copied to each subregion, and the current
        # vegetation estimate squashed into (0, 1).
        onehot = F.one_hot(parent_edu - 1, num_classes=3).float()  # 100x9x3
        interv_sub = interv.unsqueeze(1).repeat(1, num_subregions_per_region)  # 100x9
        inputs = torch.cat([onehot, interv_sub.unsqueeze(-1), torch.sigmoid(model.veg).unsqueeze(-1)], dim=-1)  # 100x9x5
        flat = inputs.view(-1, 5)  # 900x5
        pred_sub = model.mlp(flat).view(num_regions, num_subregions_per_region)  # 100x9
        # Only region-level outcomes are supervised: average the subregion predictions.
        pred_reg = pred_sub.mean(dim=1)  # 100

        loss = F.mse_loss(pred_reg, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Per-epoch diagnostics: mean month loss and MSE of the recovered vegetation vs. ground truth.
    losses.append(epoch_loss / num_months)
    veg_est = torch.sigmoid(model.veg).detach().numpy()
    veg_mse = np.mean((gt_veg - veg_est)**2)
    veg_mses.append(veg_mse)
    # NOTE(review): epochs // 100 is 0 when epochs < 100, which would make this modulo raise ZeroDivisionError.
    if (epoch + 1) % (epochs // 100) == 0:
        print(f"Epoch {epoch + 1}/{epochs}: Avg Loss = {losses[-1]:.6f}, Veg MSE = {veg_mse:.6f}")
    # Append one row per epoch to the CSV log (header written on the first epoch only).
    row = pd.DataFrame({'epoch': [epoch+1], 'loss': [losses[-1]], 'veg_mse': [veg_mses[-1]]})
    if epoch == 0:
        row.to_csv(f'{data_folder}/training_curves.csv', index=False)
    else:
        row.to_csv(f'{data_folder}/training_curves.csv', mode='a', header=False, index=False)

# Save loss curve
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.title('Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.savefig(f'{data_folder}/loss_curve.jpg')
plt.show()

plt.figure(figsize=(10, 5))
plt.plot(veg_mses)
plt.title('Vegetation MSE Curve')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.savefig(f'{data_folder}/veg_mse_curve.jpg')
plt.show()

# Save model weights
torch.save(model.state_dict(), f'{data_folder}/model.pth')

# Compute and save estimated subregion outcomes (no noise).
# Switch to eval mode first. The model is still in train mode (used for the
# training above), and there the Dropout layer would randomly zero hidden units
# even under no_grad(), making the saved estimates stochastic.
model.eval()
os.makedirs(f'{data_folder}/outcome_subregion_estimated', exist_ok=True)
estimated_sub_outcomes = []
with torch.no_grad():
    for month in range(num_months):
        parent_edu = torch.tensor(context_list[month], dtype=torch.long)
        interv = torch.tensor(interv_reg_list[month], dtype=torch.float32)
        onehot = F.one_hot(parent_edu - 1, num_classes=3).float()
        interv_sub = interv.unsqueeze(1).repeat(1, num_subregions_per_region)
        inputs = torch.cat([onehot, interv_sub.unsqueeze(-1), torch.sigmoid(model.veg).unsqueeze(-1)], dim=-1)
        flat = inputs.view(-1, 5)
        pred_sub = model.mlp(flat).view(num_regions, num_subregions_per_region).numpy()
        estimated_sub_outcomes.append(pred_sub)
        # Month index is 0-based here; files are numbered 01..num_months like the inputs.
        pd.DataFrame(pred_sub).to_csv(
            f'{data_folder}/outcome_subregion_estimated/outcome_subregion_estimated_{month+1:02d}.csv',
            index=False, header=False
        )

# Save estimated vegetation
estimated_veg = torch.sigmoid(model.veg).detach().numpy()
pd.DataFrame(estimated_veg).to_csv(
    f'{data_folder}/unobserved_confounder_context_subregion_estimated.csv', index=False, header=False
)

# Visualization helpers (repeated here so this cell runs on its own)
colors = sns.color_palette("deep")
blue, green = colors[0], colors[2]

def reshape_to_subgrid(data):
    """Tile per-region subregion rows into the (sub_grid_size, sub_grid_size) float grid.

    Region ``ri * region_grid_size + rj`` fills the ``sub_per_region_side``-wide
    square whose top-left corner is ``(ri * sub_per_region_side, rj * sub_per_region_side)``.
    Its row is laid out row-major inside that square.
    """
    n_reg, n_sub = region_grid_size, sub_per_region_side
    span = n_reg * n_sub
    grid = np.zeros((sub_grid_size, sub_grid_size))
    blocks = np.asarray(data[:n_reg * n_reg]).reshape(n_reg, n_reg, n_sub, n_sub)
    # (ri, rj, a, b) -> (ri, a, rj, b) so that row ri*n_sub + a, column rj*n_sub + b lines up.
    grid[:span, :span] = blocks.transpose(0, 2, 1, 3).reshape(span, span)
    return grid

def base_heatmap(grid, cmap, title, vmin=None, vmax=None, figsize=(10,10)):
    """Draw ``grid`` as a square, titled seaborn heatmap and return the new figure.

    The colorbar is shrunk to half height. Saving and showing are left to the caller.
    """
    fig, ax = plt.subplots(figsize=figsize)
    heatmap_kwargs = dict(cmap=cmap, vmin=vmin, vmax=vmax, square=True,
                          linewidths=0, cbar_kws={'shrink': 0.5})
    sns.heatmap(grid, ax=ax, **heatmap_kwargs)
    ax.set_aspect('equal')
    ax.set_title(title, fontsize=20)
    ax.tick_params(labelsize=14)
    fig.tight_layout()
    return fig

def plot_outcome_sub_est(file_path):
    """Render an estimated subregion outcome CSV (100x9) as a 30x30 blue heatmap on [0, 10].

    Saves ``file_path`` with its last four characters replaced by ".jpg" and
    ".pdf", then shows and closes the figure.
    """
    stem = file_path[:-4]
    grid = reshape_to_subgrid(pd.read_csv(file_path, header=None).values)
    fig = base_heatmap(grid, sns.light_palette(blue, as_cmap=True), os.path.basename(file_path),
                       vmin=0, vmax=10)
    for suffix in ('.jpg', '.pdf'):
        fig.savefig(stem + suffix)
    plt.show()
    plt.close(fig)

def plot_confounder_est(file_path):
    """Render an estimated vegetation CSV (100x9) as a 30x30 green heatmap on [0, 1].

    Saves ``file_path`` with its last four characters replaced by ".jpg" and
    ".pdf", then shows and closes the figure.
    """
    stem = file_path[:-4]
    grid = reshape_to_subgrid(pd.read_csv(file_path, header=None).values)
    fig = base_heatmap(grid, sns.light_palette(green, as_cmap=True), os.path.basename(file_path),
                       vmin=0, vmax=1)
    for suffix in ('.jpg', '.pdf'):
        fig.savefig(stem + suffix)
    plt.show()
    plt.close(fig)

# Visualize the estimated subregion outcomes of the first few months, then the
# estimated vegetation field.
num_months_new = 3
est_sub_files = [
    f'{data_folder}/outcome_subregion_estimated/outcome_subregion_estimated_{m:02d}.csv'
    for m in range(1, num_months_new + 1)
]
for file in est_sub_files:
    plot_outcome_sub_est(file)
plot_confounder_est(f'{data_folder}/unobserved_confounder_context_subregion_estimated.csv')

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from matplotlib.colors import ListedColormap

np.random.seed(42)
torch.manual_seed(42)

# Parameters
data_folder = 'data_exp3'
num_months = 48
num_regions = 100
num_subregions_per_region = 9
sub_grid_size = 30
region_grid_size = 10
sub_per_region_side = 3
learning_rate = 0.0001
epochs = 10000


def _read_monthly(kind, month):
    """Return the headerless CSV ``<data_folder>/<kind>/<kind>_<MM>.csv`` as a NumPy array."""
    return pd.read_csv(f'{data_folder}/{kind}/{kind}_{month:02d}.csv', header=None).values


# Load data: one entry per month, in month order (list index 0 = month 01).
interv_reg_list = [_read_monthly('interventions_region', m).squeeze() for m in range(1, num_months + 1)]
context_list = [_read_monthly('context_subregion', m) for m in range(1, num_months + 1)]
outcome_reg_list = [_read_monthly('outcome_region', m).squeeze() for m in range(1, num_months + 1)]

# Ground-truth vegetation; only compared against the estimate, never fed to the model.
gt_veg = pd.read_csv(f'{data_folder}/unobserved_confounder_context_subregion.csv', header=None).values

# Model
class Model(nn.Module):
    """Subregion outcome MLP plus a free vegetation logit per (region, subregion).

    In this variant the MLP sees only 4 features per subregion: the one-hot
    parent education (3) and the intervention flag. The training code applies
    vegetation multiplicatively as ``(1 - sigmoid(veg))``. There is no
    ``forward``; callers use ``self.mlp`` directly.
    """

    def __init__(self):
        super().__init__()
        # Layer order (and therefore parameter init order and state_dict keys) is fixed.
        layers = [
            nn.Linear(4, 32), nn.ReLU(),
            nn.Linear(32, 32), nn.Dropout(p=0.1), nn.ReLU(),
            nn.Linear(32, 8), nn.ReLU(),
            nn.Linear(8, 1),
        ]
        self.mlp = nn.Sequential(*layers)
        self.veg = nn.Parameter(torch.randn(num_regions, num_subregions_per_region))

# Training: one Adam step per month (all regions at once). The months are
# visited in a fresh random order each epoch. model.veg is a registered
# Parameter, so Adam updates the vegetation logits together with the MLP.
model = Model()
optimizer = Adam(model.parameters(), lr=learning_rate)
losses = []
veg_mses = []

for epoch in range(epochs):
    perm = np.random.permutation(num_months)
    epoch_loss = 0.0
    for i in perm:
        parent_edu = torch.tensor(context_list[i], dtype=torch.long)
        interv = torch.tensor(interv_reg_list[i], dtype=torch.float32)
        gt = torch.tensor(outcome_reg_list[i], dtype=torch.float32)

        # Per-subregion features: one-hot education (levels 1-3 -> indices 0-2)
        # and the region's intervention copied to each subregion.
        onehot = F.one_hot(parent_edu - 1, num_classes=3).float()  # 100x9x3
        interv_sub = interv.unsqueeze(1).repeat(1, num_subregions_per_region)  # 100x9
        inputs = torch.cat([onehot, interv_sub.unsqueeze(-1)], dim=-1)  # 100x9x4
        flat = inputs.view(-1, 4)  # 900x4
        pred_sub_base = model.mlp(flat).view(num_regions, num_subregions_per_region)  # 100x9
        # Vegetation enters multiplicatively, damping the MLP output by (1 - sigmoid(veg)).
        pred_sub = pred_sub_base * (1 - torch.sigmoid(model.veg))
        # Only region-level outcomes are supervised: average the subregion predictions.
        pred_reg = pred_sub.mean(dim=1)  # 100

        loss = F.mse_loss(pred_reg, gt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    # Per-epoch diagnostics: mean month loss and MSE of the recovered vegetation vs. ground truth.
    losses.append(epoch_loss / num_months)
    veg_est = torch.sigmoid(model.veg).detach().numpy()
    veg_mse = np.mean((gt_veg - veg_est)**2)
    veg_mses.append(veg_mse)
    # NOTE(review): epochs // 100 is 0 when epochs < 100, which would make this modulo raise ZeroDivisionError.
    if (epoch + 1) % (epochs // 100) == 0:
        print(f"Epoch {epoch + 1}/{epochs}: Avg Loss = {losses[-1]:.6f}, Veg MSE = {veg_mse:.6f}")
    # Append one row per epoch to the CSV log (header written on the first epoch only).
    row = pd.DataFrame({'epoch': [epoch+1], 'loss': [losses[-1]], 'veg_mse': [veg_mses[-1]]})
    if epoch == 0:
        row.to_csv(f'{data_folder}/training_curves_alt.csv', index=False)
    else:
        row.to_csv(f'{data_folder}/training_curves_alt.csv', mode='a', header=False, index=False)

# epoch_loss is the SUM of the final epoch's per-month losses (losses[-1] holds the mean).
print("last epoch loss:", epoch_loss)
# Save loss curve
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.title('Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.savefig(f'{data_folder}/loss_curve_alt.jpg')
plt.show()

plt.figure(figsize=(10, 5))
plt.plot(veg_mses)
plt.title('Vegetation MSE Curve')
plt.xlabel('Epoch')
plt.ylabel('MSE')
plt.savefig(f'{data_folder}/veg_mse_curve_alt.jpg')
plt.show()

# Save model weights
torch.save(model.state_dict(), f'{data_folder}/model_alt.pth')

# Compute and save estimated subregion outcomes (no noise).
# Switch to eval mode first. The model is still in train mode (used for the
# training above), and there the Dropout layer would randomly zero hidden units
# even under no_grad(), making the saved estimates stochastic.
model.eval()
os.makedirs(f'{data_folder}/outcome_subregion_estimated_alt', exist_ok=True)
estimated_sub_outcomes = []
with torch.no_grad():
    for month in range(num_months):
        parent_edu = torch.tensor(context_list[month], dtype=torch.long)
        interv = torch.tensor(interv_reg_list[month], dtype=torch.float32)
        onehot = F.one_hot(parent_edu - 1, num_classes=3).float()
        interv_sub = interv.unsqueeze(1).repeat(1, num_subregions_per_region)
        inputs = torch.cat([onehot, interv_sub.unsqueeze(-1)], dim=-1)
        flat = inputs.view(-1, 4)
        pred_sub_base = model.mlp(flat).view(num_regions, num_subregions_per_region)
        pred_sub = pred_sub_base * (1 - torch.sigmoid(model.veg))
        pred_sub = pred_sub.numpy()
        estimated_sub_outcomes.append(pred_sub)
        # Month index is 0-based here; files are numbered 01..num_months like the inputs.
        pd.DataFrame(pred_sub).to_csv(
            f'{data_folder}/outcome_subregion_estimated_alt/outcome_subregion_estimated_{month+1:02d}.csv',
            index=False, header=False
        )

# Save estimated vegetation
estimated_veg = torch.sigmoid(model.veg).detach().numpy()
pd.DataFrame(estimated_veg).to_csv(
    f'{data_folder}/unobserved_confounder_context_subregion_estimated_alt.csv', index=False, header=False
)

# Visualization helpers (repeated here so this cell runs on its own)
colors = sns.color_palette("deep")
blue, green = colors[0], colors[2]

def reshape_to_subgrid(data):
    """Tile per-region subregion rows into the (sub_grid_size, sub_grid_size) float grid.

    Region ``ri * region_grid_size + rj`` fills the ``sub_per_region_side``-wide
    square whose top-left corner is ``(ri * sub_per_region_side, rj * sub_per_region_side)``.
    Its row is laid out row-major inside that square.
    """
    n_reg, n_sub = region_grid_size, sub_per_region_side
    span = n_reg * n_sub
    grid = np.zeros((sub_grid_size, sub_grid_size))
    blocks = np.asarray(data[:n_reg * n_reg]).reshape(n_reg, n_reg, n_sub, n_sub)
    # (ri, rj, a, b) -> (ri, a, rj, b) so that row ri*n_sub + a, column rj*n_sub + b lines up.
    grid[:span, :span] = blocks.transpose(0, 2, 1, 3).reshape(span, span)
    return grid

def base_heatmap(grid, cmap, title, vmin=None, vmax=None, figsize=(10,10)):
    """Draw ``grid`` as a square, titled seaborn heatmap and return the new figure.

    The colorbar is shrunk to half height. Saving and showing are left to the caller.
    """
    fig, ax = plt.subplots(figsize=figsize)
    heatmap_kwargs = dict(cmap=cmap, vmin=vmin, vmax=vmax, square=True,
                          linewidths=0, cbar_kws={'shrink': 0.5})
    sns.heatmap(grid, ax=ax, **heatmap_kwargs)
    ax.set_aspect('equal')
    ax.set_title(title, fontsize=20)
    ax.tick_params(labelsize=14)
    fig.tight_layout()
    return fig

def plot_outcome_sub_est(file_path):
    """Render an estimated subregion outcome CSV (100x9) as a 30x30 blue heatmap on [0, 10].

    Saves ``file_path`` with its last four characters replaced by ".jpg" and
    ".pdf", then shows and closes the figure.
    """
    stem = file_path[:-4]
    grid = reshape_to_subgrid(pd.read_csv(file_path, header=None).values)
    fig = base_heatmap(grid, sns.light_palette(blue, as_cmap=True), os.path.basename(file_path),
                       vmin=0, vmax=10)
    for suffix in ('.jpg', '.pdf'):
        fig.savefig(stem + suffix)
    plt.show()
    plt.close(fig)

def plot_confounder_est(file_path):
    """Render an estimated vegetation CSV (100x9) as a 30x30 green heatmap on [0, 1].

    Saves ``file_path`` with its last four characters replaced by ".jpg" and
    ".pdf", then shows and closes the figure.
    """
    stem = file_path[:-4]
    grid = reshape_to_subgrid(pd.read_csv(file_path, header=None).values)
    fig = base_heatmap(grid, sns.light_palette(green, as_cmap=True), os.path.basename(file_path),
                       vmin=0, vmax=1)
    for suffix in ('.jpg', '.pdf'):
        fig.savefig(stem + suffix)
    plt.show()
    plt.close(fig)

# Visualize the estimated subregion outcomes of the first few months, then the
# estimated vegetation field (multiplicative-vegetation model).
num_months_new = 3
# Format each filename with the comprehension's own variable. Iterating a
# different name would pick up the stale module-level `month` (left at 47 by the
# prediction loop) and plot the same file every time.
est_sub_files = [
    f'{data_folder}/outcome_subregion_estimated_alt/outcome_subregion_estimated_{month:02d}.csv'
    for month in range(1, num_months_new + 1)
]
for file in est_sub_files:
    plot_outcome_sub_est(file)
plot_confounder_est(f'{data_folder}/unobserved_confounder_context_subregion_estimated_alt.csv')

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Toggles for log scales
log_x = True
log_y = True

# Optional axis ranges (None = automatic)
xlim = None
ylim = None

# Rolling-mean window applied to the loss curve (0 = no smoothing)
smoothing = 0

plt.rcParams.update({
    'font.size': 24,  # Very large text
    'axes.titlesize': 28,
    'axes.labelsize': 26,
    'legend.fontsize': 24,
    'xtick.labelsize': 22,
    'ytick.labelsize': 22,
    'lines.linewidth': 3,
    'figure.figsize': (12, 8),
    'axes.grid': False,  # No grid clutter
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# Per-epoch log written by the vegetation-as-input training cell
df_regular = pd.read_csv('data_exp3/training_curves.csv')
regular_loss = (
    df_regular['loss'].rolling(smoothing, min_periods=1).mean()
    if smoothing > 0
    else df_regular['loss']
)

# Colors from sns deep palette
colors = sns.color_palette("deep")
blue, red = colors[0], colors[3]

# Twin-axis figure: training loss on the left (solid blue), vegetation MSE on the right (dashed red)
fig, ax1 = plt.subplots()
ax1.plot(df_regular['epoch'], regular_loss, color=blue, linestyle='-', label='Loss', alpha=0.9, lw=3)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Training Loss', color=blue)
ax1.tick_params(axis='both', which='both', length=0, labelsize=22, colors=blue)
for side in ('top', 'right'):
    ax1.spines[side].set_visible(False)

ax2 = ax1.twinx()
ax2.plot(df_regular['epoch'], df_regular['veg_mse'], color=red, linestyle='--', label='Veg MSE', lw=5, alpha=0.9)
ax2.set_ylabel('MSE Hidden Confounder', color=red)
ax2.tick_params(axis='y', which='both', length=0, labelsize=22, colors=red)
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(True)
ax2.yaxis.set_major_locator(plt.MaxNLocator(nbins=5))

if log_x:
    ax1.set_xscale('log')
if log_y:
    ax1.set_yscale('log')
    ax2.set_yscale('log')
if xlim is not None:
    ax1.set_xlim(xlim)
# NOTE(review): the same ylim is applied to both y-axes even though the loss and
# the vegetation MSE live on different scales.
if ylim is not None:
    ax1.set_ylim(ylim)
    ax2.set_ylim(ylim)

fig.tight_layout()
fig.savefig('data_exp3/linear_plot.jpg', dpi=300)
fig.savefig('data_exp3/linear_plot.pdf')
plt.show()

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# --- Figure settings -------------------------------------------------------
# Log scaling per axis; limits of None keep matplotlib's autoscaling.
# Example limits: xlim=(1, 10000), ylim=(1e-5, 1).
log_x, log_y = True, True
xlim, ylim = None, None

# Window of the rolling mean applied to the loss (0 = plot the raw loss).
smoothing = 0

# Publication styling: large fonts, thick lines, no grid, open top/right frame.
plt.rcParams.update({
    'font.size': 24,
    'axes.titlesize': 28,
    'axes.labelsize': 26,
    'legend.fontsize': 24,
    'xtick.labelsize': 22,
    'ytick.labelsize': 22,
    'lines.linewidth': 3,
    'figure.figsize': (12, 8),
    'axes.grid': False,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

# --- Data (alternative run) ------------------------------------------------
df_regular = pd.read_csv('data_exp3/training_curves_alt.csv')
regular_loss = (
    df_regular['loss'].rolling(smoothing, min_periods=1).mean()
    if smoothing > 0
    else df_regular['loss']
)

# Seaborn "deep" palette: blue for the loss, red for the confounder MSE.
colors = sns.color_palette("deep")
blue, red = colors[0], colors[3]

# --- Loss (left axis) and hidden-confounder MSE (right axis) over epochs ---
fig, ax1 = plt.subplots()
ax1.plot(df_regular['epoch'], regular_loss, color=blue, linestyle='-', label='Loss', alpha=0.9, lw=3)
ax1.set(xlabel='Epoch')
ax1.set_ylabel('Training Loss', color=blue)
ax1.tick_params(axis='both', which='both', length=0, labelsize=22, colors=blue)
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)

ax2 = ax1.twinx()  # second y-axis sharing the epoch axis with ax1
ax2.plot(df_regular['epoch'], df_regular['veg_mse'], color=red, linestyle='--', label='Veg MSE', lw=5, alpha=0.9)
ax2.set_ylabel('MSE Hidden Confounder', color=red)
ax2.tick_params(axis='y', which='both', length=0, labelsize=22, colors=red)
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(True)  # keep the frame line beside the right-hand axis
ax2.yaxis.set_major_locator(plt.MaxNLocator(nbins=5))

if log_x:
    ax1.set_xscale('log')
if log_y:
    ax1.set_yscale('log')
    ax2.set_yscale('log')
if xlim is not None:
    ax1.set_xlim(xlim)
if ylim is not None:
    # The same limits are applied to both y-axes.
    ax1.set_ylim(ylim)
    ax2.set_ylim(ylim)

plt.tight_layout()
fig.savefig('data_exp3/alt_plot.jpg', dpi=300)
fig.savefig('data_exp3/alt_plot.pdf')
plt.show()

In [None]:
import matplotlib as mpl

# Evaluation of the *alternative* model run. Every figure written here carries an
# `_alt` suffix so it does not overwrite the figures of the standard run.
import torch.nn.functional as F  # one-hot encoding below; imported here so the cell is self-contained

# "deep" palette: index 2 (green) for confounder maps, index 4 for causal-effect maps.
# Defined here instead of relying on a `colors` variable left over from an earlier cell.
colors = sns.color_palette("deep")

# Load true and estimated (alternative-model) confounder data
gt_veg = pd.read_csv(f'{data_folder}/unobserved_confounder_context_subregion.csv', header=None).values
estimated_veg = pd.read_csv(f'{data_folder}/unobserved_confounder_context_subregion_estimated_alt.csv', header=None).values

# Reshape data for plotting
gt_veg_grid = reshape_to_subgrid(gt_veg)
estimated_veg_grid = reshape_to_subgrid(estimated_veg)

# Shared colormap and value range so both confounder panels are directly comparable
green_cmap = sns.light_palette(colors[2], as_cmap=True)
vmin_veg = np.min([gt_veg_grid, estimated_veg_grid])
vmax_veg = np.max([gt_veg_grid, estimated_veg_grid])

fig_veg, axs_veg = plt.subplots(1, 2, figsize=(12, 6), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.05})

# Left panel: true confounder
sns.heatmap(gt_veg_grid, ax=axs_veg[0], cmap=green_cmap, cbar=False, square=True,
            vmin=vmin_veg, vmax=vmax_veg, linewidths=0)
axs_veg[0].set_xticks([])
axs_veg[0].set_yticks([])

# Right panel: estimated confounder
sns.heatmap(estimated_veg_grid, ax=axs_veg[1], cmap=green_cmap, cbar=False, square=True,
            vmin=vmin_veg, vmax=vmax_veg, linewidths=0)
axs_veg[1].set_xticks([])
axs_veg[1].set_yticks([])

# One vertical colorbar shared by both panels (the heatmaps draw none themselves)
norm_veg = mpl.colors.Normalize(vmin=vmin_veg, vmax=vmax_veg)
sm_veg = mpl.cm.ScalarMappable(cmap=green_cmap, norm=norm_veg)
sm_veg.set_array([])

cbar_veg = fig_veg.colorbar(sm_veg, ax=axs_veg, orientation='vertical', fraction=0.046, pad=0.04)
cbar_veg.set_label("Vegetation Value")

fig_veg.savefig(f'{data_folder}/vegetation_heatmaps_alt.jpg', dpi=300, bbox_inches='tight', pad_inches=0)
fig_veg.savefig(f'{data_folder}/vegetation_heatmaps_alt.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

print("✓ Hidden confounder heatmaps generated.")

# Absolute per-subregion error of the confounder estimate
vegetation_error = np.abs(gt_veg_grid - estimated_veg_grid)
fig_veg_error, ax_veg_error = plt.subplots(figsize=(6, 6))
error_cmap = sns.light_palette("red", as_cmap=True)  # distinct colour for the error map
sns.heatmap(vegetation_error, ax=ax_veg_error, cmap=error_cmap, square=True, linewidths=0)
ax_veg_error.set_xticks([])
ax_veg_error.set_yticks([])
plt.tight_layout()
fig_veg_error.savefig(f'{data_folder}/vegetation_error_heatmap_alt.jpg', dpi=300, bbox_inches='tight', pad_inches=0)
fig_veg_error.savefig(f'{data_folder}/vegetation_error_heatmap_alt.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

print("✓ Vegetation error heatmap generated.")


# True causal effect per subregion: wealth-dependent effect (10 poor / 5 intermediate /
# 1 rich) multiplied by (1 - true vegetation).
context_week1 = pd.read_csv(f'{data_folder}/context_subregion/context_subregion_01.csv', header=None).values
poor_indicator_week1 = (context_week1 == 1).astype(float)
intermediate_indicator_week1 = (context_week1 == 2).astype(float)
rich_indicator_week1 = (context_week1 == 3).astype(float)
true_causal_effect = ((10 * poor_indicator_week1) + (5 * intermediate_indicator_week1) + (rich_indicator_week1)) * (1 - gt_veg)


def _counterfactual_outcome(model, onehot, veg_factor, treatment):
    """Return the model's subregion outcomes with every intervention set to *treatment*.

    *onehot* is the one-hot wealth encoding with a trailing class axis. The
    intervention value is appended as one extra feature, the result is passed
    row-wise through ``model.mlp``, reshaped back to the subregion grid and
    multiplied by *veg_factor*.
    """
    grid_shape = onehot.shape[:-1]
    interv = torch.full(grid_shape, float(treatment))
    inputs = torch.cat([onehot, interv.unsqueeze(-1)], dim=-1)
    base = model.mlp(inputs.view(-1, inputs.shape[-1])).view(grid_shape)
    return base * veg_factor


# Estimated causal effect from the trained alternative model
model_alt = Model()
# map_location lets a checkpoint saved on GPU load on a CPU-only machine; all tensors below are CPU tensors.
model_alt.load_state_dict(torch.load(f'{data_folder}/model_alt.pth', map_location='cpu'))
model_alt.eval()

with torch.no_grad():
    # The first week's context is used as the representative covariate map.
    wealth_week1 = torch.tensor(context_week1, dtype=torch.long)
    onehot_week1 = F.one_hot(wealth_week1 - 1, num_classes=3).float()  # wealth 1..3 -> classes 0..2
    estimated_veg_tensor = torch.sigmoid(model_alt.veg)  # estimated vegetation map
    veg_factor = 1 - estimated_veg_tensor

    # Effect = predicted outcome with intervention 1 minus predicted outcome with intervention 0
    estimated_causal_effect = (
        _counterfactual_outcome(model_alt, onehot_week1, veg_factor, 1.0)
        - _counterfactual_outcome(model_alt, onehot_week1, veg_factor, 0.0)
    ).numpy()


# Reshape causal effect data for plotting
causal_effect_true_grid = reshape_to_subgrid(true_causal_effect)
estimated_causal_effect_grid = reshape_to_subgrid(estimated_causal_effect)

# Shared colormap and value range for the two causal-effect panels
pink_cmap = sns.light_palette(colors[4], as_cmap=True)
vmin = np.min([causal_effect_true_grid, estimated_causal_effect_grid])
vmax = np.max([causal_effect_true_grid, estimated_causal_effect_grid])

fig, axs = plt.subplots(1, 2, figsize=(12, 6), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.05})

# Left panel: true causal effect
sns.heatmap(causal_effect_true_grid, ax=axs[0], cmap=pink_cmap, cbar=False, square=True,
            vmin=vmin, vmax=vmax, linewidths=0)
axs[0].set_title("True Causal Effect (Subregion)")
axs[0].set_xticks([])
axs[0].set_yticks([])

# Right panel: estimated causal effect
sns.heatmap(estimated_causal_effect_grid, ax=axs[1], cmap=pink_cmap, cbar=False, square=True,
            vmin=vmin, vmax=vmax, linewidths=0)
axs[1].set_title("Estimated Causal Effect (Subregion)")
axs[1].set_xticks([])
axs[1].set_yticks([])

# One vertical colorbar shared by both panels
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
sm = mpl.cm.ScalarMappable(cmap=pink_cmap, norm=norm)
sm.set_array([])

cbar = fig.colorbar(sm, ax=axs, orientation='vertical', fraction=0.046, pad=0.04)
cbar.set_label("Causal Effect")  # raw effect units, not a percentage

fig.savefig(f'{data_folder}/subregion_causal_effects_heatmaps_alt.jpg', dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(f'{data_folder}/subregion_causal_effects_heatmaps_alt.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

print("✓ Subregion causal effect plots generated.")

In [None]:
import matplotlib as mpl

# Evaluation of the *standard* model run (confounder estimate without the `_alt` suffix).
import torch.nn.functional as F  # one-hot encoding below; imported here so the cell is self-contained

# "deep" palette: index 2 (green) for confounder maps, index 4 for causal-effect maps.
# Defined up front: the green colormap below needs it before any later assignment in this cell.
colors = sns.color_palette("deep")

# Load true and estimated confounder data
gt_veg = pd.read_csv(f'{data_folder}/unobserved_confounder_context_subregion.csv', header=None).values
estimated_veg = pd.read_csv(f'{data_folder}/unobserved_confounder_context_subregion_estimated.csv', header=None).values

# Reshape data for plotting
gt_veg_grid = reshape_to_subgrid(gt_veg)
estimated_veg_grid = reshape_to_subgrid(estimated_veg)

# Shared colormap and value range so both confounder panels are directly comparable
green_cmap = sns.light_palette(colors[2], as_cmap=True)
vmin_veg = np.min([gt_veg_grid, estimated_veg_grid])
vmax_veg = np.max([gt_veg_grid, estimated_veg_grid])

fig_veg, axs_veg = plt.subplots(1, 2, figsize=(12, 6), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.05})

# Left panel: true confounder
sns.heatmap(gt_veg_grid, ax=axs_veg[0], cmap=green_cmap, cbar=False, square=True,
            vmin=vmin_veg, vmax=vmax_veg, linewidths=0)
axs_veg[0].set_xticks([])
axs_veg[0].set_yticks([])

# Right panel: estimated confounder
sns.heatmap(estimated_veg_grid, ax=axs_veg[1], cmap=green_cmap, cbar=False, square=True,
            vmin=vmin_veg, vmax=vmax_veg, linewidths=0)
axs_veg[1].set_xticks([])
axs_veg[1].set_yticks([])

# One vertical colorbar shared by both panels (the heatmaps draw none themselves)
norm_veg = mpl.colors.Normalize(vmin=vmin_veg, vmax=vmax_veg)
sm_veg = mpl.cm.ScalarMappable(cmap=green_cmap, norm=norm_veg)
sm_veg.set_array([])

cbar_veg = fig_veg.colorbar(sm_veg, ax=axs_veg, orientation='vertical', fraction=0.046, pad=0.04)
cbar_veg.set_label("Vegetation Value")

fig_veg.savefig(f'{data_folder}/vegetation_heatmaps.jpg', dpi=300, bbox_inches='tight', pad_inches=0)
fig_veg.savefig(f'{data_folder}/vegetation_heatmaps.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

print("✓ Hidden confounder heatmaps generated.")

# Absolute per-subregion error of the confounder estimate
vegetation_error = np.abs(gt_veg_grid - estimated_veg_grid)
fig_veg_error, ax_veg_error = plt.subplots(figsize=(6, 6))
error_cmap = sns.light_palette("red", as_cmap=True)  # distinct colour for the error map
sns.heatmap(vegetation_error, ax=ax_veg_error, cmap=error_cmap, square=True, linewidths=0)
ax_veg_error.set_xticks([])
ax_veg_error.set_yticks([])
plt.tight_layout()
fig_veg_error.savefig(f'{data_folder}/vegetation_error_heatmap.jpg', dpi=300, bbox_inches='tight', pad_inches=0)
fig_veg_error.savefig(f'{data_folder}/vegetation_error_heatmap.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

print("✓ Vegetation error heatmap generated.")


# True causal effect per subregion: wealth-dependent effect (10 poor / 5 intermediate /
# 1 rich) multiplied by (1 - true vegetation).
context_week1 = pd.read_csv(f'{data_folder}/context_subregion/context_subregion_01.csv', header=None).values
poor_indicator_week1 = (context_week1 == 1).astype(float)
intermediate_indicator_week1 = (context_week1 == 2).astype(float)
rich_indicator_week1 = (context_week1 == 3).astype(float)
true_causal_effect = ((10 * poor_indicator_week1) + (5 * intermediate_indicator_week1) + (rich_indicator_week1)) * (1 - gt_veg)


def _counterfactual_outcome(model, onehot, veg_factor, treatment):
    """Return the model's subregion outcomes with every intervention set to *treatment*.

    *onehot* is the one-hot wealth encoding with a trailing class axis. The
    intervention value is appended as one extra feature, the result is passed
    row-wise through ``model.mlp``, reshaped back to the subregion grid and
    multiplied by *veg_factor*.
    """
    grid_shape = onehot.shape[:-1]
    interv = torch.full(grid_shape, float(treatment))
    inputs = torch.cat([onehot, interv.unsqueeze(-1)], dim=-1)
    base = model.mlp(inputs.view(-1, inputs.shape[-1])).view(grid_shape)
    return base * veg_factor


# Estimated causal effect from a trained model.
# NOTE(review): this cell visualises the *standard* confounder estimate
# (`..._estimated.csv`) but loads the `model_alt.pth` checkpoint. Confirm which
# checkpoint the standard run saved and point this load at it.
model_alt = Model()
# map_location lets a checkpoint saved on GPU load on a CPU-only machine; all tensors below are CPU tensors.
model_alt.load_state_dict(torch.load(f'{data_folder}/model_alt.pth', map_location='cpu'))
model_alt.eval()

with torch.no_grad():
    # The first week's context is used as the representative covariate map.
    wealth_week1 = torch.tensor(context_week1, dtype=torch.long)
    onehot_week1 = F.one_hot(wealth_week1 - 1, num_classes=3).float()  # wealth 1..3 -> classes 0..2
    estimated_veg_tensor = torch.sigmoid(model_alt.veg)  # estimated vegetation map
    veg_factor = 1 - estimated_veg_tensor

    # Effect = predicted outcome with intervention 1 minus predicted outcome with intervention 0
    estimated_causal_effect = (
        _counterfactual_outcome(model_alt, onehot_week1, veg_factor, 1.0)
        - _counterfactual_outcome(model_alt, onehot_week1, veg_factor, 0.0)
    ).numpy()


# Reshape causal effect data for plotting
causal_effect_true_grid = reshape_to_subgrid(true_causal_effect)
estimated_causal_effect_grid = reshape_to_subgrid(estimated_causal_effect)

# Shared colormap and value range for the two causal-effect panels
pink_cmap = sns.light_palette(colors[4], as_cmap=True)
vmin = np.min([causal_effect_true_grid, estimated_causal_effect_grid])
vmax = np.max([causal_effect_true_grid, estimated_causal_effect_grid])

fig, axs = plt.subplots(1, 2, figsize=(12, 6), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.05})

# Left panel: true causal effect
sns.heatmap(causal_effect_true_grid, ax=axs[0], cmap=pink_cmap, cbar=False, square=True,
            vmin=vmin, vmax=vmax, linewidths=0)
axs[0].set_title("True Causal Effect (Subregion)")
axs[0].set_xticks([])
axs[0].set_yticks([])

# Right panel: estimated causal effect
sns.heatmap(estimated_causal_effect_grid, ax=axs[1], cmap=pink_cmap, cbar=False, square=True,
            vmin=vmin, vmax=vmax, linewidths=0)
axs[1].set_title("Estimated Causal Effect (Subregion)")
axs[1].set_xticks([])
axs[1].set_yticks([])

# One vertical colorbar shared by both panels, driven by a ScalarMappable so the scale is exact
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
sm = mpl.cm.ScalarMappable(cmap=pink_cmap, norm=norm)
sm.set_array([])

cbar = fig.colorbar(sm, ax=axs, orientation='vertical', fraction=0.046, pad=0.04)
cbar.set_label("Causal Effect")  # raw effect units, not a percentage

fig.savefig(f'{data_folder}/subregion_causal_effects_heatmaps.jpg', dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(f'{data_folder}/subregion_causal_effects_heatmaps.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

print("✓ Subregion causal effect plots generated.")

# Exp. 4
Confounded Treatment Allocation: estimate which subregions were treated and the causal effect of treatment when treatment assignment depends on contextual covariates.

## Exp. 4: Low Confounding

In [None]:
#with wealth based intervention allocation
import numpy as np
import pandas as pd
import os

# Simulation settings for the low-confounding variant of Exp. 4.
num_regions, num_subregions = 100, 4   # 100 regions, each split into 4 subregions
shift_param, scale_param = 1.23, -4.2  # ground-truth parameters of the outcome model
noise_var = 0.01                       # variance of the additive Gaussian outcome noise
data_dir = 'data_exp4_lc'
os.makedirs(data_dir, exist_ok=True)
np.random.seed(42)

# 1) Subregion context ("wealth"), drawn uniformly from [0, 1).
context_sub = np.random.uniform(low=0, high=1, size=(num_regions, num_subregions))
pd.DataFrame(context_sub).to_csv(os.path.join(data_dir, 'context_subregion.csv'), header=False, index=False)

# 2) Interventions subregion (only one per region, using softmax over context)
def softmax(x, tau=1, axis=None):
    """Return the temperature-scaled softmax of *x*.

    Parameters
    ----------
    x : array_like
        Scores to normalise (lists and other array-likes are accepted).
    tau : float, default 1
        Temperature. Scores are divided by ``tau`` before exponentiation, so a
        smaller ``tau`` gives a sharper, more one-hot distribution.
    axis : int or None, default None
        Axis along which the probabilities sum to one. ``None`` (the original
        behaviour) normalises over all elements, which is what a 1-D input needs.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as ``x`` whose entries are non-negative and
        sum to 1 along ``axis``.
    """
    z = np.asarray(x, dtype=float) / tau
    # Subtracting the maximum leaves the result unchanged but keeps exp() from overflowing.
    e_z = np.exp(z - np.max(z, axis=axis, keepdims=True))
    return e_z / e_z.sum(axis=axis, keepdims=True)

# Allocate exactly one treated subregion per region. The treated subregion is
# sampled with probabilities softmax(context), so higher-context subregions are
# more likely to be treated; this is how treatment is confounded by context.
# NOTE: the loop draws from the global NumPy RNG in a fixed order (choice, then
# uniform, per region); reordering these calls changes the generated data.
interventions_sub = np.zeros((num_regions, num_subregions))
for i in range(num_regions):
    probs = softmax(context_sub[i], tau=1)  # smaller tau = sharper preferences for high context
    col_idx = np.random.choice(np.arange(num_subregions), p=probs)
    # Treatment intensity of the chosen subregion, uniform in [0.2, 1).
    interventions_sub[i, col_idx] = np.random.uniform(0.2, 1)
pd.DataFrame(interventions_sub).to_csv(os.path.join(data_dir, 'interventions_subregion.csv'), index=False, header=False)


# Interventions region: sum of subregion interventions
# (with one treated subregion per row, this is that subregion's intensity).
interventions_reg = interventions_sub.sum(axis=1)
pd.DataFrame(interventions_reg).to_csv(os.path.join(data_dir, 'interventions_region.csv'), index=False, header=False)

# 3) Noise subregion
# np.random.normal takes a standard deviation, hence the square root of the variance.
noise_sub = np.random.normal(0, np.sqrt(noise_var), size=(num_regions, num_subregions))
pd.DataFrame(noise_sub).to_csv(os.path.join(data_dir, 'noise_subregion.csv'), index=False, header=False)

# 4) Outcome subregion: context-dependent + noise
# The effect per unit of treatment is scale_param * (shift_param - context);
# untreated subregions (intervention 0) receive pure noise.
outcome_sub = (shift_param - context_sub) * scale_param * interventions_sub + noise_sub
pd.DataFrame(outcome_sub).to_csv(os.path.join(data_dir, 'outcome_subregion.csv'), index=False, header=False)

# Outcome region: average outcome across subregions
# (note: region interventions are a sum, region outcomes a mean).
outcome_reg = outcome_sub.mean(axis=1)
pd.DataFrame(outcome_reg).to_csv(os.path.join(data_dir, 'outcome_region.csv'), index=False, header=False)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Layout of the Exp. 4 low-confounding data: 100 regions on a 10x10 map, each
# region split into a 2x2 block of subregions (20x20 subregion map overall).
data_dir = 'data_exp4_lc'
num_regions, num_subregions = 100, 4
grid_size_reg, grid_size_sub = 10, 20

# Plot styling shared by all figures in this cell.
sns.set_context('paper', font_scale=1.5)
colors = sns.color_palette("deep")
reds, blues = colors[3], colors[0]  # red for intervention maps, blue for outcome maps

def visualize_subregion_csv(file_name, cmap, title, is_wealth=False, num_regions=100, num_subregions=4):
    data = pd.read_csv(os.path.join(data_dir, file_name), header=None).values
    # Reshape to subregion grid size: each region row becomes block of subregions
    grid_size_sub_calc = int(np.sqrt(num_regions) * np.sqrt(num_subregions))
    subregions_per_dim = int(np.sqrt(num_subregions))
    grid = np.zeros((grid_size_sub_calc, grid_size_sub_calc))
    for i in range(int(np.sqrt(num_regions))):
        for j in range(int(np.sqrt(num_regions))):
            region_idx = i * int(np.sqrt(num_regions)) + j
            sub_data = data[region_idx, :].reshape(subregions_per_dim, subregions_per_dim)
            grid[i*subregions_per_dim:(i+1)*subregions_per_dim, j*subregions_per_dim:(j+1)*subregions_per_dim] = sub_data
    fig, ax = plt.subplots(figsize=(8, 8))
    if is_wealth:
        heatmap_cmap = sns.color_palette("Spectral", as_cmap=True)
    else:
        if isinstance(cmap, tuple):
            heatmap_cmap = sns.light_palette(cmap, as_cmap=True)
        else:
            heatmap_cmap = cmap
    sns.heatmap(grid, ax=ax, cmap=heatmap_cmap, square=True, cbar=False, cbar_kws={'label': 'Value', 'location': 'right', 'pad': 0.1})
    ax.set_title(title, fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    plt.legend([], [], frameon=False)  # No legend box
    plt.tight_layout()
    base_name = file_name.replace('.csv', '')
    plt.savefig(os.path.join(data_dir, f'{base_name}.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, f'{base_name}.pdf'))
    plt.show()
    plt.close()

def visualize_region_csv(file_name, cmap, title, is_wealth=False, num_regions=100):
    """Plot a per-region CSV from ``data_dir`` as a square heatmap, saved as JPG and PDF.

    The CSV holds one value per region; values are laid out row-major on a
    ``sqrt(num_regions)`` square. ``is_wealth`` selects the "Spectral" palette,
    an RGB-tuple ``cmap`` becomes a light sequential palette, and any other
    ``cmap`` is handed to seaborn unchanged.
    """
    values = pd.read_csv(os.path.join(data_dir, file_name), header=None).values.flatten()
    side = int(np.sqrt(num_regions))
    grid = values.reshape(side, side)

    fig, ax = plt.subplots(figsize=(8, 8))
    if is_wealth:
        palette = sns.color_palette("Spectral", as_cmap=True)
    elif isinstance(cmap, tuple):
        palette = sns.light_palette(cmap, as_cmap=True)
    else:
        palette = cmap

    sns.heatmap(grid, ax=ax, cmap=palette, square=True, cbar=False,
                cbar_kws={'label': 'Value', 'location': 'right', 'pad': 0.1})
    ax.set_title(title, fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    plt.legend([], [], frameon=False)
    plt.tight_layout()

    stem = file_name.replace('.csv', '')
    for extension, save_kwargs in (('jpg', {'dpi': 300}), ('pdf', {})):
        plt.savefig(os.path.join(data_dir, f'{stem}.{extension}'), **save_kwargs)
    plt.show()
    plt.close()

# Render every generated CSV: interventions, context, noise and outcomes.
visualize_subregion_csv('interventions_subregion.csv', reds, 'Interventions Subregion',
                        num_regions=num_regions, num_subregions=num_subregions)
visualize_region_csv('interventions_region.csv', reds, 'Interventions Region', num_regions=num_regions)
visualize_subregion_csv('context_subregion.csv', None, 'Wealth Subregion', is_wealth=True,
                        num_regions=num_regions, num_subregions=num_subregions)
# Noise has no designated colour, so a diverging colormap is used.
visualize_subregion_csv('noise_subregion.csv', 'coolwarm', 'Noise Subregion',
                        num_regions=num_regions, num_subregions=num_subregions)
visualize_subregion_csv('outcome_subregion.csv', blues, 'Outcome Subregion',
                        num_regions=num_regions, num_subregions=num_subregions)
visualize_region_csv('outcome_region.csv', blues, 'Outcome Region', num_regions=num_regions)

# Region-level wealth = mean over each region's subregions. It is written to CSV
# only because the region plotter reads its input from disk.
context_sub = pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values
wealth_reg = context_sub.mean(axis=1)
pd.DataFrame(wealth_reg).to_csv(os.path.join(data_dir, 'context_region.csv'), header=False, index=False)
visualize_region_csv('context_region.csv', None, 'Wealth Region', is_wealth=True, num_regions=num_regions)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Define parameters
data_dir = 'data_exp4_lc'
num_regions = 100
num_subregions = 4
grid_size_reg = int(np.sqrt(num_regions))  # 10x10 regions
grid_size_sub_calc = int(np.sqrt(num_regions) * np.sqrt(num_subregions))  # 20x20 subregions (2x2 per region)
subregions_per_dim = int(np.sqrt(num_subregions))

# Colour maps from the seaborn "deep" palette
colors = sns.color_palette("deep")
red_cmap = sns.light_palette(colors[3], as_cmap=True)    # interventions
blue_cmap = sns.light_palette(colors[0], as_cmap=True)   # outcomes
green_cmap = sns.light_palette(colors[2], as_cmap=True)  # context

# Horizontal gap between the three panels, as a fraction of the average axis width
# (e.g. 0.0 for no space, 0.5 for more space).
space_between = 0.05

# Region-level maps. In Exp. 4 the region intervention is a continuous intensity
# (sum over subregions; see the data-generation cell), not a 0/1 flag.
interv_data = pd.read_csv(os.path.join(data_dir, 'interventions_region.csv'), header=None).values.flatten().reshape(grid_size_reg, grid_size_reg).astype(float)
outcome_data = pd.read_csv(os.path.join(data_dir, 'outcome_region.csv'), header=None).values.flatten().reshape(grid_size_reg, grid_size_reg)

# Subregion context map: region r fills the block at (r // grid_size_reg, r % grid_size_reg),
# and its subregion values fill that block row-major.
context_data = pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values
context_grid = (
    np.asarray(context_data[:grid_size_reg * grid_size_reg], dtype=float)
    .reshape(grid_size_reg, grid_size_reg, subregions_per_dim, subregions_per_dim)
    .transpose(0, 2, 1, 3)
    .reshape(grid_size_sub_calc, grid_size_sub_calc)
)

# Exp. 4 context is drawn from U(0, 1), not from the discrete wealth levels {1, 2, 3}
# of Exp. 1, so the former `(x - 1) / 2` rescaling did not map it to [0, 1].
# Min-max scale instead. The heatmap autoscales its colour range, so the rendered
# panel is unchanged; only the values held in `context_norm` are corrected.
context_min, context_max = context_grid.min(), context_grid.max()
if context_max > context_min:
    context_norm = (context_grid - context_min) / (context_max - context_min)
else:
    context_norm = np.zeros_like(context_grid)

# One figure with three panels side by side
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# Left: intervention (red)
sns.heatmap(interv_data, ax=axs[0], cmap=red_cmap, cbar=False, square=True, linewidths=2)
axs[0].set_xticks([])
axs[0].set_yticks([])

# Center: outcome (blue)
sns.heatmap(outcome_data, ax=axs[1], cmap=blue_cmap, cbar=False, square=True, linewidths=2)
axs[1].set_xticks([])
axs[1].set_yticks([])

# Right: subregion context (green)
sns.heatmap(context_norm, ax=axs[2], cmap=green_cmap, cbar=False, square=True, linewidths=1)
axs[2].set_xticks([])
axs[2].set_yticks([])

fig.subplots_adjust(wspace=space_between)

# NOTE(review): the JPG and PDF use different base names (`exp5_` prefix only on
# the PDF, inside the Exp. 4 folder); confirm this is intentional.
fig.savefig(os.path.join(data_dir, 'combined_heatmaps.jpg'), dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(os.path.join(data_dir, 'exp5_combined_heatmaps.pdf'), bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Define parameters
# Layout of the Exp. 4 low-confounding data and optimiser settings for the
# full-batch training loop below.
data_dir = 'data_exp4_lc'
num_regions = 100
num_subregions = 4
epochs = 10000
lr = 0.001
# Seeds torch's RNG, which drives both the random initialisation of the
# intervention scores and the jitter added to them on every forward pass.
torch.manual_seed(42)
grid_size_sub = 20  # side length of the subregion map drawn by visualize_estimated_subregion
sns.set_context('paper', font_scale=1.5)
reds = sns.color_palette("deep")[3]

# Read all CSVs
# Only the region-level aggregates and the subregion context are loaded; the
# subregion interventions are what the model has to recover.
# NOTE: this rebinds `context_sub` to a float32 torch tensor, shadowing the
# NumPy array of the same name from the data-generation cell.
interventions_reg = torch.tensor(pd.read_csv(os.path.join(data_dir, 'interventions_region.csv'), header=None).values.flatten(), dtype=torch.float32)
outcome_reg = torch.tensor(pd.read_csv(os.path.join(data_dir, 'outcome_region.csv'), header=None).values.flatten(), dtype=torch.float32)
context_sub = torch.tensor(pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values, dtype=torch.float32)

# Model
class CausalModel(nn.Module):
    """Recover subregion treatment allocation and effect parameters from region aggregates.

    Learnable parameters:
      * ``inter_est``: one score per (region, subregion). A tempered softmax over
        subregions turns the squared scores into a soft choice of the treated subregion.
      * ``shift`` / ``scale``: parameters of the subregion outcome model
        ``(shift - context) * scale * intervention``.

    ``temp`` (softmax temperature) and ``noise_scale`` (standard deviation of the
    Gaussian jitter added to the scores) are plain attributes that the training
    loop overwrites.

    Reads the module-level globals ``num_regions``, ``num_subregions``,
    ``interventions_reg`` and ``context_sub``.
    """

    def __init__(self):
        super().__init__()
        # Registration order (inter_est, shift, scale) fixes the order of parameters().
        self.inter_est = nn.Parameter(torch.randn(num_regions, num_subregions))
        self.shift = nn.Parameter(torch.tensor(1.0))
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.temp = torch.tensor(10.0)
        self.noise_scale = 0.1

    def preprocess_inter(self):
        """Return subregion interventions: softmax allocation times each region's known total."""
        jitter = torch.randn_like(self.inter_est) * self.noise_scale
        scores = (self.inter_est + jitter) ** 2
        allocation = (scores / self.temp).softmax(dim=1)
        return allocation * interventions_reg[:, None]

    def forward(self):
        """Predict region outcomes as the mean of the modelled subregion outcomes."""
        effect_per_unit = (self.shift - context_sub) * self.scale
        return (effect_per_unit * self.preprocess_inter()).mean(dim=1)

model = CausalModel()
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()  # fits only the observed region-level outcomes

# Train
# Full-batch optimisation. The softmax temperature is annealed linearly from
# `start_temp` to `end_temp`, gradually sharpening how each region's known
# intervention total is allocated over its subregions.
# NOTE(review): `epoch / (epochs - 1)` divides by zero if epochs == 1.
losses = []
start_temp = 10.0
end_temp = 2.0
for epoch in range(epochs):
    model.temp = torch.tensor(start_temp + (end_temp - start_temp) * (epoch / (epochs - 1)))
    optimizer.zero_grad()
    pred = model.forward()
    loss = criterion(pred, outcome_reg)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# For the final read-out, use an almost-zero temperature and no jitter so the
# allocation becomes (near) one-hot and deterministic.
model.temp = 0.0001
model.noise_scale = 0.0

# Save loss curve
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.title('Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.savefig(os.path.join(data_dir, 'loss_curve.jpg'))
plt.savefig(os.path.join(data_dir, 'loss_curve.pdf'))
plt.show()
plt.close()

# Print final estimates
print(f"Final shift_param: {model.shift.item()}")
print(f"Final scale_param: {model.scale.item()}")

# Save params to txt
# One `name: value` line per parameter; the evaluation cell parses this format.
with open(os.path.join(data_dir, 'parameter_estimate.txt'), 'w') as f:
    f.write(f"shift_param: {model.shift.item()}\n")
    f.write(f"scale_param: {model.scale.item()}\n")

# Save estimated interventions sub (processed)
# "processed" = after the softmax allocation and scaling by the region totals.
inter_est_final = model.preprocess_inter().detach().numpy()
pd.DataFrame(inter_est_final).to_csv(os.path.join(data_dir, 'interventions_subregion_estimated.csv'), index=False, header=False)

# Visualize estimated subregion
def visualize_estimated_subregion(data=None):
    """Plot estimated subregion interventions as a square heatmap and save JPG/PDF copies.

    Parameters
    ----------
    data : array-like, optional
        ``(num_regions, num_subregions)`` array; defaults to the global
        ``inter_est_final``. The map layout is derived from this shape (both
        counts must be perfect squares) instead of a hard-coded 10x10 map of
        2x2 blocks, so other layouts no longer silently mis-render.

    Raises
    ------
    ValueError
        If either dimension of *data* is not a perfect square.
    """
    if data is None:
        data = inter_est_final
    data = np.asarray(data, dtype=float)
    n_regions, n_subregions = data.shape
    regions_per_dim = int(round(np.sqrt(n_regions)))
    subregions_per_dim = int(round(np.sqrt(n_subregions)))
    if regions_per_dim ** 2 != n_regions or subregions_per_dim ** 2 != n_subregions:
        raise ValueError(f"data shape {data.shape} must have perfect-square dimensions")
    # Region r goes to block (r // regions_per_dim, r % regions_per_dim); its values fill the block row-major.
    side = regions_per_dim * subregions_per_dim
    grid = (
        data.reshape(regions_per_dim, regions_per_dim, subregions_per_dim, subregions_per_dim)
        .transpose(0, 2, 1, 3)
        .reshape(side, side)
    )

    fig, ax = plt.subplots(figsize=(8, 8))
    heatmap_cmap = sns.light_palette(reds, as_cmap=True)
    sns.heatmap(grid, ax=ax, cmap=heatmap_cmap, square=True, cbar=False)
    ax.set_title('Estimated Interventions Subregion', fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    plt.tight_layout()
    fig.savefig(os.path.join(data_dir, 'interventions_subregion_estimated.jpg'), dpi=300)
    fig.savefig(os.path.join(data_dir, 'interventions_subregion_estimated_standard.pdf'))
    plt.show()
    plt.close(fig)

visualize_estimated_subregion()

In [None]:
# Ground-truth and predicted (noise-free) subregion outcomes.
#
# `context_sub` may be the float32 torch tensor loaded by the training cell, which
# shadows the NumPy array created by the data-generation cell. Convert it once so the
# true and the predicted outcomes are both computed on NumPy arrays. Previously only
# the prediction used the converted array, so `true_outcome_sub` could be a tensor.
if isinstance(context_sub, torch.Tensor):
    context_sub_np = context_sub.detach().cpu().numpy()
else:
    context_sub_np = np.asarray(context_sub)

# True subregion outcome from the generating parameters (noise omitted).
true_outcome_sub = (shift_param - context_sub_np) * scale_param * interventions_sub

# Estimated subregion interventions saved by the training cell.
inter_est_final = pd.read_csv(os.path.join(data_dir, 'interventions_subregion_estimated.csv'), header=None).values

# Learned parameters, stored one `name: value` pair per line. They are looked up
# by name, so the result does not depend on line order.
with open(os.path.join(data_dir, 'parameter_estimate.txt'), 'r') as f:
    params = dict(line.strip().split(': ', 1) for line in f if line.strip())
estimated_shift = float(params['shift_param'])
estimated_scale = float(params['scale_param'])

predicted_outcome_sub = (estimated_shift - context_sub_np) * estimated_scale * inter_est_final
# Visualize true vs predicted outcome at subregion level
def visualize_true_vs_predicted_subregion(true_data, predicted_data, num_regions=100, num_subregions=4):
    """Show true outcome, predicted outcome and their difference as three heatmaps.

    Args:
        true_data: indexable as ``[region, subregion]``, one row per region.
        predicted_data: same layout as ``true_data``.
        num_regions: number of regions (laid out on a square grid).
        num_subregions: subregions per region (each drawn as a square tile).

    Region ``k`` is placed at grid position ``divmod(k, sqrt(num_regions))``.
    The figure is saved as JPG and PDF under the global ``data_dir`` and shown.
    """
    side = int(np.sqrt(num_regions) * np.sqrt(num_subregions))
    tile = int(np.sqrt(num_subregions))
    regions_side = int(np.sqrt(num_regions))

    panels = {key: np.zeros((side, side)) for key in ('true', 'predicted', 'diff')}
    for region_idx in range(regions_side * regions_side):
        row, col = divmod(region_idx, regions_side)
        window = (slice(row * tile, (row + 1) * tile), slice(col * tile, (col + 1) * tile))
        true_tile = true_data[region_idx, :].reshape(tile, tile)
        pred_tile = predicted_data[region_idx, :].reshape(tile, tile)
        panels['true'][window] = true_tile
        panels['predicted'][window] = pred_tile
        panels['diff'][window] = true_tile - pred_tile

    # Both outcome panels share one light-blue map; the signed difference uses a diverging map.
    outcome_cmap = sns.light_palette(sns.color_palette("deep")[0], as_cmap=True)
    layout = (
        ('true', outcome_cmap, 'True Outcome Subregion'),
        ('predicted', outcome_cmap, 'Predicted Outcome Subregion'),
        ('diff', 'coolwarm', 'Difference (True - Predicted)'),
    )

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (key, cmap, title) in zip(axs, layout):
        sns.heatmap(panels[key], ax=ax, cmap=cmap, square=True, cbar=False, linewidths=1)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.savefig(os.path.join(data_dir, 'true_vs_predicted_outcome_subregion_styled.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, 'standard_true_vs_predicted_outcome_subregion_styled.pdf'))
    plt.show()
    plt.close()

visualize_true_vs_predicted_subregion(
    true_outcome_sub, predicted_outcome_sub, num_regions, num_subregions
)

# Mean squared error between the true and the predicted subregion outcomes
mse_subregion = np.square(
    np.asarray(true_outcome_sub) - np.asarray(predicted_outcome_sub)
).mean()
print(f"MSE between true and predicted subregion outcome: {mse_subregion}")

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Experiment configuration (low-confounding data, argmax assignment model)
data_dir = 'data_exp4_lc'
num_regions, num_subregions = 100, 4
epochs, lr = 10000, 0.01
torch.manual_seed(42)
grid_size_sub = int(np.sqrt(num_regions) * np.sqrt(num_subregions))
sns.set_context('paper', font_scale=1.5)
reds = sns.color_palette("deep")[3]


def _csv_values(file_name):
    """Return the raw NumPy values of a header-less CSV stored in ``data_dir``."""
    return pd.read_csv(os.path.join(data_dir, file_name), header=None).values


# Region-level series are flattened to 1-D; the subregion context stays 2-D.
interventions_reg = torch.tensor(_csv_values('interventions_region.csv').flatten(), dtype=torch.float32)
outcome_reg = torch.tensor(_csv_values('outcome_region.csv').flatten(), dtype=torch.float32)
context_sub = torch.tensor(_csv_values('context_subregion.csv'), dtype=torch.float32)
# Region-average context; not consumed by preprocess_inter, kept for reference.
context_reg = context_sub.mean(dim=1)

# Model
class CausalModelCombined(nn.Module):
    """Outcome model with a hard, context-driven assignment of the regional intervention.

    Within each region the whole regional amount ``interventions_reg`` goes to
    the single subregion with the largest ``context_sub / tau`` score (the
    argmax of a temperature softmax); the other subregions get 0. Subregion
    outcomes follow ``(shift - context) * scale * intervention`` and are
    averaged per region. The assignment is built by indexing, so ``tau``
    receives gradients only through the returned probabilities (e.g. an NLL
    term), not through the outcome.

    Args:
        num_regions: number of regions; stored for reference.
        num_subregions: number of subregions per region; stored for reference.
    """

    def __init__(self, num_regions, num_subregions):
        super().__init__()
        self.shift = nn.Parameter(torch.tensor(1.0))
        self.scale = nn.Parameter(torch.tensor(1.0))
        # Learnable temperature stored in log space (tau = exp(log_tau) > 0),
        # initialised to tau = 1. Built from a Python float so it is float32
        # like the other parameters; torch.tensor(np.log(1.0)) would create a
        # float64 parameter because np.log returns np.float64.
        self.log_tau = nn.Parameter(torch.tensor(0.0))
        self.num_regions = num_regions
        self.num_subregions = num_subregions

    def preprocess_inter(self, context_sub, interventions_reg):
        """Assign each region's intervention to its highest-scoring subregion.

        Args:
            context_sub: (regions, subregions) tensor of context values.
            interventions_reg: (regions,) tensor of total intervention per region.

        Returns:
            ``(inter_constrained, inter_probs)``: the one-hot argmax assignment
            scaled by ``interventions_reg``, and the row-wise softmax of
            ``context_sub / tau`` (used for the intervention loss).
        """
        tau = torch.exp(self.log_tau)
        inter_probs = torch.softmax(context_sub / tau, dim=1)

        # Hard assignment: index rows by the actual batch size rather than
        # self.num_regions so inputs with any number of regions work.
        rows = torch.arange(inter_probs.shape[0], device=inter_probs.device)
        inter_discrete = torch.zeros_like(inter_probs)
        inter_discrete[rows, inter_probs.argmax(dim=1)] = 1.0

        # Scale the one-hot row by the total regional intervention amount
        inter_constrained = inter_discrete * interventions_reg.unsqueeze(1)
        return inter_constrained, inter_probs

    def forward(self, context_sub, interventions_reg):
        """Return ``(outcome_reg_pred, inter_sub, inter_probs)``.

        ``outcome_reg_pred`` is the per-region mean of the predicted subregion
        outcomes ``(shift - context_sub) * scale * inter_sub``.
        """
        inter_sub, inter_probs = self.preprocess_inter(context_sub, interventions_reg)
        outcome_sub_pred = (self.shift - context_sub) * self.scale * inter_sub
        outcome_reg_pred = outcome_sub_pred.mean(dim=1)
        return outcome_reg_pred, inter_sub, inter_probs

# Model, optimiser and the two loss terms: regional-outcome MSE and an NLL
# classification of which subregion was intervened on.
model_combined = CausalModelCombined(num_regions, num_subregions)
optimizer_combined = optim.Adam(params=model_combined.parameters(), lr=lr)
criterion_outcome, criterion_inter = nn.MSELoss(), nn.NLLLoss()

# Per-epoch histories: total loss, outcome term, intervention term
losses_combined, outcome_losses_combined, inter_losses_combined = [], [], []

# Ground-truth subregion interventions; the NLL target for each region is the
# column holding its largest true intervention value.
true_interventions_sub = torch.tensor(
    pd.read_csv(os.path.join(data_dir, 'interventions_subregion.csv'), header=None).values,
    dtype=torch.float32,
)
true_intervention_indices = true_interventions_sub.argmax(dim=1)

for epoch in range(epochs):
    optimizer_combined.zero_grad()
    outcome_reg_pred, inter_sub_pred, inter_probs = model_combined(context_sub, interventions_reg)

    # Regional outcome fit plus classification of the intervened subregion;
    # NLLLoss expects log-probabilities, 1e-9 guards against log(0).
    loss_outcome = criterion_outcome(outcome_reg_pred, outcome_reg)
    loss_inter = criterion_inter(torch.log(inter_probs + 1e-9), true_intervention_indices)
    loss = loss_outcome + loss_inter

    loss.backward()
    optimizer_combined.step()

    for _hist, _term in (
        (losses_combined, loss),
        (outcome_losses_combined, loss_outcome),
        (inter_losses_combined, loss_inter),
    ):
        _hist.append(_term.item())


# Plot the three loss histories on one axis and save them as JPG and PDF
plt.figure(figsize=(10, 5))
for _curve, _label in (
    (losses_combined, 'Total Loss'),
    (outcome_losses_combined, 'Outcome Loss'),
    (inter_losses_combined, 'Intervention Prediction Loss (NLL)'),
):
    plt.plot(_curve, label=_label)
plt.title('Combined Model Loss Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
for _file_name in ('loss_curves_combined_new.jpg', 'loss_curves_combined_new.pdf'):
    plt.savefig(os.path.join(data_dir, _file_name))
plt.show()
plt.close()


# Learned scalars; tau is held as log_tau inside the model.
_shift_hat = model_combined.shift.item()
_scale_hat = model_combined.scale.item()
_tau_hat = torch.exp(model_combined.log_tau).item()

print(f"Final shift_param (Combined): {_shift_hat}")
print(f"Final scale_param (Combined): {_scale_hat}")
print(f"Final estimated tau (Combined): {_tau_hat}")

# Persist them as "name: value" lines
with open(os.path.join(data_dir, 'parameter_estimate_combined_new.txt'), 'w') as f:
    f.writelines([
        f"shift_param: {_shift_hat}\n",
        f"scale_param: {_scale_hat}\n",
        f"estimated_tau: {_tau_hat}\n",
    ])

# Final hard (argmax) subregion assignment, saved as a header-less CSV
inter_est_final_combined, _ = model_combined.preprocess_inter(context_sub, interventions_reg)
inter_est_final_combined = inter_est_final_combined.detach().numpy()
pd.DataFrame(inter_est_final_combined).to_csv(
    os.path.join(data_dir, 'interventions_subregion_estimated_combined_new.csv'),
    index=False,
    header=False,
)

# Visualize estimated subregion
def visualize_estimated_subregion(data, title, file_suffix):
    """Tile per-region subregion values into one heatmap and save it.

    Row ``k`` of ``data`` becomes a square tile of ``sqrt(num_subregions)``
    cells placed at region position ``divmod(k, sqrt(num_regions))``. Uses the
    notebook globals ``num_regions``, ``num_subregions``, ``reds`` and
    ``data_dir``; writes ``interventions_subregion_estimated_<file_suffix>``
    as JPG and PDF.
    """
    side = int(np.sqrt(num_regions) * np.sqrt(num_subregions))
    tile = int(np.sqrt(num_subregions))
    regions_side = int(np.sqrt(num_regions))

    grid = np.zeros((side, side))
    for region_idx in range(regions_side * regions_side):
        row, col = divmod(region_idx, regions_side)
        grid[row * tile:(row + 1) * tile, col * tile:(col + 1) * tile] = (
            data[region_idx, :].reshape(tile, tile)
        )

    fig, ax = plt.subplots(figsize=(8, 8))
    sns.heatmap(
        grid,
        ax=ax,
        cmap=sns.light_palette(reds, as_cmap=True),
        square=True,
        cbar=False,
        cbar_kws={'label': 'Value', 'location': 'right', 'pad': 0.1},
    )
    ax.set_title(title, fontsize=18)
    for hide_ticks in (ax.set_xticks, ax.set_yticks):
        hide_ticks([])
    ax.grid(False)
    plt.legend([], [], frameon=False)
    plt.tight_layout()
    plt.savefig(os.path.join(data_dir, f'interventions_subregion_estimated_{file_suffix}.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, f'interventions_subregion_estimated_{file_suffix}.pdf'))
    plt.show()
    plt.close()

visualize_estimated_subregion(
    data=inter_est_final_combined,
    title='Estimated Interventions Subregion (Combined Model)',
    file_suffix='combined_new',
)

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
import os

# --- Parameters (low-confounding data, sampling assignment model) ---
data_dir = 'data_exp4_lc'
num_regions, num_subregions = 100, 4
epochs, lr = 10000, 0.01
torch.manual_seed(42)  # also fixes the torch.multinomial draws made later
grid_size_sub = int(np.sqrt(num_regions) * np.sqrt(num_subregions))
sns.set_context('paper', font_scale=1.5)
reds = sns.color_palette("deep")[3]


# --- Load Data ---
def _csv_tensor(file_name, flatten=False):
    """Read a header-less CSV from ``data_dir`` into a float32 tensor (optionally 1-D)."""
    values = pd.read_csv(os.path.join(data_dir, file_name), header=None).values
    return torch.tensor(values.flatten() if flatten else values, dtype=torch.float32)


interventions_reg = _csv_tensor('interventions_region.csv', flatten=True)
outcome_reg = _csv_tensor('outcome_region.csv', flatten=True)
context_sub = _csv_tensor('context_subregion.csv')
context_reg = context_sub.mean(dim=1)

# --- Model ---
class CausalModelCombined(nn.Module):
    """Outcome model that *samples* which subregion receives each region's intervention.

    Per region, a subregion index is drawn from ``softmax(context_sub / tau)``
    with ``torch.multinomial``, and that subregion receives the whole regional
    amount ``interventions_reg``; the other subregions get 0. A new draw is made
    on every call, including the final export. Subregion outcomes follow
    ``(shift - context) * scale * intervention`` and are averaged per region.
    The sampled assignment is built by indexing, so ``tau`` receives gradients
    only through the returned probabilities (e.g. an NLL term).

    Args:
        num_regions: number of regions; stored for reference.
        num_subregions: number of subregions per region; stored for reference.
    """

    def __init__(self, num_regions, num_subregions):
        super().__init__()
        self.shift = nn.Parameter(torch.tensor(1.0))
        self.scale = nn.Parameter(torch.tensor(1.0))
        # Learnable temperature in log space (tau = exp(log_tau) > 0), initialised
        # to tau = 1. Built from a Python float so it is float32 like the other
        # parameters; torch.tensor(np.log(1.0)) would give a float64 parameter.
        self.log_tau = nn.Parameter(torch.tensor(0.0))
        self.num_regions = num_regions
        self.num_subregions = num_subregions

    def preprocess_inter(self, context_sub, interventions_reg):
        """Sample a one-hot assignment per region and scale it by the regional amount.

        Args:
            context_sub: (regions, subregions) tensor of context values.
            interventions_reg: (regions,) tensor of total intervention per region.

        Returns:
            ``(inter_constrained, inter_probs)``: the scaled one-hot sample and
            the row-wise softmax probabilities it was drawn from.
        """
        tau = torch.exp(self.log_tau)
        inter_probs = torch.softmax(context_sub / tau, dim=1)

        # One multinomial draw per row (consumes the global torch RNG).
        sampled_indices = torch.multinomial(inter_probs, num_samples=1).squeeze(1)

        # Index rows by the actual batch size rather than self.num_regions.
        rows = torch.arange(inter_probs.shape[0], device=inter_probs.device)
        inter_discrete = torch.zeros_like(inter_probs)
        inter_discrete[rows, sampled_indices] = 1.0
        inter_constrained = inter_discrete * interventions_reg.unsqueeze(1)

        return inter_constrained, inter_probs

    def forward(self, context_sub, interventions_reg):
        """Return ``(outcome_reg_pred, inter_sub, inter_probs)``.

        ``outcome_reg_pred`` is the per-region mean of the predicted subregion
        outcomes ``(shift - context_sub) * scale * inter_sub``.
        """
        inter_sub, inter_probs = self.preprocess_inter(context_sub, interventions_reg)
        outcome_sub_pred = (self.shift - context_sub) * self.scale * inter_sub
        outcome_reg_pred = outcome_sub_pred.mean(dim=1)
        return outcome_reg_pred, inter_sub, inter_probs

# --- Setup: model, optimiser, loss functions and NLL targets ---
model_combined = CausalModelCombined(num_regions, num_subregions)
optimizer_combined = optim.Adam(params=model_combined.parameters(), lr=lr)
criterion_outcome, criterion_inter = nn.MSELoss(), nn.NLLLoss()

# Target class per region = column of its largest true subregion intervention
true_interventions_sub = torch.tensor(
    pd.read_csv(os.path.join(data_dir, 'interventions_subregion.csv'), header=None).values,
    dtype=torch.float32,
)
true_intervention_indices = true_interventions_sub.argmax(dim=1)

# --- Training ---
losses_combined, outcome_losses_combined, inter_losses_combined = [], [], []

for epoch in range(epochs):
    optimizer_combined.zero_grad()
    # forward() draws a fresh sampled assignment on every call
    outcome_reg_pred, inter_sub_pred, inter_probs = model_combined(context_sub, interventions_reg)

    loss_outcome = criterion_outcome(outcome_reg_pred, outcome_reg)
    # NLLLoss takes log-probabilities; 1e-9 keeps log() finite
    loss_inter = criterion_inter(torch.log(inter_probs + 1e-9), true_intervention_indices)
    loss = loss_outcome + loss_inter

    loss.backward()
    optimizer_combined.step()

    for _history, _term in zip(
        (losses_combined, outcome_losses_combined, inter_losses_combined),
        (loss, loss_outcome, loss_inter),
    ):
        _history.append(_term.item())

# --- Save Loss Curves ---
plt.figure(figsize=(10, 5))
_loss_series = {
    'Total Loss': losses_combined,
    'Outcome Loss': outcome_losses_combined,
    'Intervention Prediction Loss (NLL)': inter_losses_combined,
}
for _label, _curve in _loss_series.items():
    plt.plot(_curve, label=_label)
plt.title('Combined Model Loss Curves (Sampling Always)')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
for _file_name in ('loss_curves_combined_sampling_always.jpg', 'loss_curves_combined_sampling_always.pdf'):
    plt.savefig(os.path.join(data_dir, _file_name))
plt.show()
plt.close()

# --- Save Parameters ---
_shift_hat = model_combined.shift.item()
_scale_hat = model_combined.scale.item()
_tau_hat = torch.exp(model_combined.log_tau).item()

print(f"Final shift_param (Combined): {_shift_hat}")
print(f"Final scale_param (Combined): {_scale_hat}")
print(f"Final estimated tau (Combined): {_tau_hat}")

with open(os.path.join(data_dir, 'parameter_estimate_combined_sampling_always.txt'), 'w') as f:
    f.writelines([
        f"shift_param: {_shift_hat}\n",
        f"scale_param: {_scale_hat}\n",
        f"estimated_tau: {_tau_hat}\n",
    ])

# --- Save Estimated Interventions (sampled) ---
# preprocess_inter draws a new sample here, so this export is one random
# assignment rather than the most probable one.
inter_est_final_combined, _ = model_combined.preprocess_inter(context_sub, interventions_reg)
_estimated_csv = os.path.join(data_dir, 'interventions_subregion_estimated_combined_sampling_always.csv')
pd.DataFrame(inter_est_final_combined.detach().numpy()).to_csv(_estimated_csv, index=False, header=False)

# --- Visualization ---
def visualize_estimated_subregion(data, title, file_suffix):
    """Render per-region subregion values as one tiled heatmap and save it.

    Each row of ``data`` (one region) is reshaped to a square tile and placed
    row-major on the region grid. Reads the globals ``num_regions``,
    ``num_subregions``, ``reds`` and ``data_dir``; writes
    ``interventions_subregion_estimated_<file_suffix>`` as JPG and PDF.
    """
    tile = int(np.sqrt(num_subregions))
    regions_side = int(np.sqrt(num_regions))
    side = int(np.sqrt(num_regions) * np.sqrt(num_subregions))

    grid = np.zeros((side, side))
    for region_idx in range(regions_side * regions_side):
        top, left = (k * tile for k in divmod(region_idx, regions_side))
        grid[top:top + tile, left:left + tile] = data[region_idx, :].reshape(tile, tile)

    fig, ax = plt.subplots(figsize=(8, 8))
    palette_cmap = sns.light_palette(reds, as_cmap=True)
    sns.heatmap(grid, ax=ax, cmap=palette_cmap, square=True, cbar=False,
                cbar_kws={'label': 'Value', 'location': 'right', 'pad': 0.1})
    ax.set_title(title, fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    plt.legend([], [], frameon=False)
    plt.tight_layout()
    plt.savefig(os.path.join(data_dir, f'interventions_subregion_estimated_{file_suffix}.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, f'interventions_subregion_estimated_{file_suffix}.pdf'))
    plt.show()
    plt.close()

visualize_estimated_subregion(
    data=inter_est_final_combined.detach().numpy(),
    title='Estimated Interventions Subregion (Combined Model)',
    file_suffix='combined_sampling_always',
)


In [None]:
# Calculate true causal effect at subregion level.
# This is based on the original data generation formula before adding noise.
# `context_sub` may still be the torch tensor from the training cell; convert
# it to NumPy first so the true and predicted outcomes are both computed from
# plain NumPy arrays (never a torch tensor mixed with `interventions_sub`).
if isinstance(context_sub, torch.Tensor):
    context_sub_np = context_sub.detach().cpu().numpy()
else:
    context_sub_np = np.asarray(context_sub)

true_outcome_sub = (shift_param - context_sub_np) * scale_param * np.asarray(interventions_sub)

# Load estimated interventions (argmax combined model)
inter_est_final = pd.read_csv(os.path.join(data_dir, 'interventions_subregion_estimated_combined_new.csv'), header=None).values

# Load the learned parameters. The file holds "name: value" lines (it also
# contains estimated_tau); look values up by name instead of by line position.
with open(os.path.join(data_dir, 'parameter_estimate_combined_new.txt'), 'r') as f:
    lines = f.readlines()
estimated_params = {
    key.strip(): float(value)
    for key, sep, value in (line.partition(':') for line in lines)
    if sep
}
estimated_shift = estimated_params['shift_param']
estimated_scale = estimated_params['scale_param']

# Predicted subregion outcome from the estimated interventions and parameters
predicted_outcome_sub = (estimated_shift - context_sub_np) * estimated_scale * inter_est_final

# Visualize true vs predicted outcome at subregion level
def visualize_true_vs_predicted_subregion(true_data, predicted_data, num_regions=100, num_subregions=4):
    """Plot true outcome, predicted outcome and (true - predicted) side by side.

    Args:
        true_data: indexable as ``[region, subregion]``, one row per region.
        predicted_data: same layout as ``true_data``.
        num_regions: number of regions, arranged on a square grid.
        num_subregions: subregions per region, each region drawn as a square tile.

    Saves a JPG and a PDF into the global ``data_dir`` and shows the figure.
    """
    grid_side = int(np.sqrt(num_regions) * np.sqrt(num_subregions))
    block = int(np.sqrt(num_subregions))
    per_side = int(np.sqrt(num_regions))

    true_grid, predicted_grid, difference_grid = (np.zeros((grid_side, grid_side)) for _ in range(3))
    for region_idx in range(per_side * per_side):
        r, c = divmod(region_idx, per_side)
        rows = slice(r * block, (r + 1) * block)
        cols = slice(c * block, (c + 1) * block)
        t = true_data[region_idx, :].reshape(block, block)
        p = predicted_data[region_idx, :].reshape(block, block)
        true_grid[rows, cols] = t
        predicted_grid[rows, cols] = p
        difference_grid[rows, cols] = t - p

    blues = sns.light_palette(sns.color_palette("deep")[0], as_cmap=True)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panel_specs = [
        (true_grid, blues, 'True Outcome Subregion'),
        (predicted_grid, blues, 'Predicted Outcome Subregion'),
        (difference_grid, 'coolwarm', 'Difference (True - Predicted)'),
    ]
    for idx, (grid, cmap, title) in enumerate(panel_specs):
        sns.heatmap(grid, ax=axs[idx], cmap=cmap, square=True, cbar=False, linewidths=1)
        axs[idx].set_title(title)
        axs[idx].set_xticks([])
        axs[idx].set_yticks([])

    plt.tight_layout()
    plt.savefig(os.path.join(data_dir, 'true_vs_predicted_outcome_subregion_styled.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, 'exp5_tau_true_vs_predicted_outcome_subregion_styled.pdf'))
    plt.show()
    plt.close()

visualize_true_vs_predicted_subregion(
    true_outcome_sub, predicted_outcome_sub, num_regions, num_subregions
)

# Mean squared error between the true and the predicted subregion outcomes
mse_subregion = np.square(
    np.asarray(true_outcome_sub) - np.asarray(predicted_outcome_sub)
).mean()
print(f"MSE between true and predicted subregion outcome: {mse_subregion}")

## Exp. 4: High Confounding

In [None]:
import numpy as np
import pandas as pd
import os

# Define parameters for the high-confounding data set
num_regions, num_subregions = 100, 4
# Ground-truth effect parameters: outcome = (shift - context) * scale * intervention
shift_param, scale_param = 1.23, -4.2
noise_var = 0.01
data_dir = 'data_exp4_hc'
os.makedirs(data_dir, exist_ok=True)
np.random.seed(42)

# 1) Context (wealth) subregion: i.i.d. U(0, 1) per subregion
context_sub = np.random.uniform(low=0, high=1, size=(num_regions, num_subregions))
pd.DataFrame(context_sub).to_csv(
    os.path.join(data_dir, 'context_subregion.csv'), index=False, header=False
)

# 2) Interventions subregion (only one per region, using softmax over context)
def softmax(x, tau=1, axis=None):
    """Temperature-scaled softmax.

    Args:
        x: Array-like of scores (plain lists are accepted).
        tau: Temperature (non-zero); smaller values give sharper distributions.
        axis: Axis to normalise over. ``None`` (default) normalises over all
            elements, which is the ordinary softmax for the 1-D rows used in
            this notebook. Pass ``axis=-1`` to apply it row-wise to a 2-D array.

    Returns:
        NumPy float array with the same shape as ``x`` whose entries along
        ``axis`` sum to 1.
    """
    z = np.asarray(x, dtype=float) / tau  # temperature scaling
    # Subtract the max before exponentiating for numerical stability.
    e_z = np.exp(z - np.max(z, axis=axis, keepdims=True))
    return e_z / e_z.sum(axis=axis, keepdims=True)

# Exactly one intervened subregion per region. Its column is drawn from a
# sharp softmax over that region's context (tau = 0.1 strongly favours
# high-context subregions) and its dose from U(0.2, 1).
_subregion_ids = np.arange(num_subregions)
interventions_sub = np.zeros((num_regions, num_subregions))
for i in range(num_regions):
    probs = softmax(context_sub[i], tau=0.1)
    col_idx = np.random.choice(_subregion_ids, p=probs)
    interventions_sub[i, col_idx] = np.random.uniform(0.2, 1)
pd.DataFrame(interventions_sub).to_csv(
    os.path.join(data_dir, 'interventions_subregion.csv'), index=False, header=False
)

# Region-level intervention = total dose over the region's subregions
interventions_reg = np.sum(interventions_sub, axis=1)
pd.DataFrame(interventions_reg).to_csv(
    os.path.join(data_dir, 'interventions_region.csv'), index=False, header=False
)

# 3) Noise subregion: i.i.d. Gaussian with standard deviation sqrt(noise_var)
noise_sub = np.random.normal(loc=0, scale=np.sqrt(noise_var), size=(num_regions, num_subregions))

# 4) Outcome subregion: context-dependent effect where a dose was applied, plus noise
outcome_sub = (shift_param - context_sub) * scale_param * interventions_sub + noise_sub

# Outcome region: average outcome across each region's subregions
outcome_reg = outcome_sub.mean(axis=1)

for _array, _file_name in (
    (noise_sub, 'noise_subregion.csv'),
    (outcome_sub, 'outcome_subregion.csv'),
    (outcome_reg, 'outcome_region.csv'),
):
    pd.DataFrame(_array).to_csv(os.path.join(data_dir, _file_name), index=False, header=False)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Plotting parameters for the high-confounding data set
data_dir = 'data_exp4_hc'
num_regions, num_subregions = 100, 4
grid_size_reg = 10  # regions per side of the 10x10 region grid
grid_size_sub = 20  # subregions per side (each region drawn as a 2x2 tile)
sns.set_context('paper', font_scale=1.5)
colors = sns.color_palette("deep")
blues, reds = colors[0], colors[3]

def visualize_subregion_csv(file_name, cmap, title, is_wealth=False, num_regions=100, num_subregions=4):
    """Render a (regions x subregions) CSV from ``data_dir`` as one tiled heatmap.

    Args:
        file_name: CSV inside the global ``data_dir`` (no header, one row per region).
        cmap: a color tuple (turned into a light sequential map) or any value
            ``sns.heatmap`` accepts as ``cmap``; ignored when ``is_wealth``.
        title: axis title.
        is_wealth: use the "Spectral" colormap instead of ``cmap``.
        num_regions: number of regions, laid out on a square grid.
        num_subregions: subregions per region, drawn as a square tile.

    Saves ``<file_name without .csv>.jpg`` and ``.pdf`` next to the CSV.
    """
    values = pd.read_csv(os.path.join(data_dir, file_name), header=None).values
    side = int(np.sqrt(num_regions) * np.sqrt(num_subregions))
    tile = int(np.sqrt(num_subregions))
    regions_side = int(np.sqrt(num_regions))

    grid = np.zeros((side, side))
    for region_idx in range(regions_side * regions_side):
        r, c = divmod(region_idx, regions_side)
        grid[r * tile:(r + 1) * tile, c * tile:(c + 1) * tile] = values[region_idx, :].reshape(tile, tile)

    fig, ax = plt.subplots(figsize=(8, 8))
    if is_wealth:
        heatmap_cmap = sns.color_palette("Spectral", as_cmap=True)
    elif isinstance(cmap, tuple):
        heatmap_cmap = sns.light_palette(cmap, as_cmap=True)
    else:
        heatmap_cmap = cmap
    sns.heatmap(grid, ax=ax, cmap=heatmap_cmap, square=True, cbar=False,
                cbar_kws={'label': 'Value', 'location': 'right', 'pad': 0.1})
    ax.set_title(title, fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    plt.legend([], [], frameon=False)  # No legend box
    plt.tight_layout()
    stem = file_name.replace('.csv', '')
    plt.savefig(os.path.join(data_dir, f'{stem}.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, f'{stem}.pdf'))
    plt.show()
    plt.close()

def visualize_region_csv(file_name, cmap, title, is_wealth=False, num_regions=100):
    """Render a one-value-per-region CSV from ``data_dir`` as a square heatmap.

    Args:
        file_name: CSV inside the global ``data_dir`` (no header).
        cmap: a color tuple (turned into a light sequential map) or any value
            ``sns.heatmap`` accepts as ``cmap``; ignored when ``is_wealth``.
        title: axis title.
        is_wealth: use the "Spectral" colormap instead of ``cmap``.
        num_regions: number of regions, laid out on a square grid.

    Saves ``<file_name without .csv>.jpg`` and ``.pdf`` next to the CSV.
    """
    side = int(np.sqrt(num_regions))
    grid = (
        pd.read_csv(os.path.join(data_dir, file_name), header=None)
        .values.flatten()
        .reshape(side, side)
    )

    fig, ax = plt.subplots(figsize=(8, 8))
    if is_wealth:
        heatmap_cmap = sns.color_palette("Spectral", as_cmap=True)
    elif isinstance(cmap, tuple):
        heatmap_cmap = sns.light_palette(cmap, as_cmap=True)
    else:
        heatmap_cmap = cmap
    sns.heatmap(grid, ax=ax, cmap=heatmap_cmap, square=True, cbar=False,
                cbar_kws={'label': 'Value', 'location': 'right', 'pad': 0.1})
    ax.set_title(title, fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    plt.legend([], [], frameon=False)
    plt.tight_layout()
    stem = file_name.replace('.csv', '')
    plt.savefig(os.path.join(data_dir, f'{stem}.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, f'{stem}.pdf'))
    plt.show()
    plt.close()

# Visualize each generated CSV
_sub_dims = {'num_regions': num_regions, 'num_subregions': num_subregions}
visualize_subregion_csv('interventions_subregion.csv', reds, 'Interventions Subregion', is_wealth=False, **_sub_dims)
visualize_region_csv('interventions_region.csv', reds, 'Interventions Region', is_wealth=False, num_regions=num_regions)
visualize_subregion_csv('context_subregion.csv', None, 'Wealth Subregion', is_wealth=True, **_sub_dims)
# Diverging map for the zero-mean noise
visualize_subregion_csv('noise_subregion.csv', 'coolwarm', 'Noise Subregion', is_wealth=False, **_sub_dims)
visualize_subregion_csv('outcome_subregion.csv', blues, 'Outcome Subregion', is_wealth=False, **_sub_dims)
visualize_region_csv('outcome_region.csv', blues, 'Outcome Region', is_wealth=False, num_regions=num_regions)

# Region-level wealth = mean of each region's subregion context; written to a
# CSV only so that visualize_region_csv can read it back.
context_sub = pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values
wealth_reg = context_sub.mean(axis=1)
pd.DataFrame(wealth_reg).to_csv(os.path.join(data_dir, 'context_region.csv'), index=False, header=False)
visualize_region_csv('context_region.csv', None, 'Wealth Region', is_wealth=True, num_regions=num_regions)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Define parameters
data_dir = 'data_exp4_hc'
num_regions = 100
num_subregions = 4
grid_size_reg = int(np.sqrt(num_regions))  # 10x10 for regions
grid_size_sub_calc = int(np.sqrt(num_regions) * np.sqrt(num_subregions))  # 20x20 for subregions (2x2 per region)
subregions_per_dim = int(np.sqrt(num_subregions))


# Color maps from seaborn deep palette
colors = sns.color_palette("deep")
red_cmap = sns.light_palette(colors[3], as_cmap=True)  # Red
blue_cmap = sns.light_palette(colors[0], as_cmap=True)  # Blue
green_cmap = sns.light_palette(colors[2], as_cmap=True)  # Green

# Space between subplots as a fraction of the average axis width (0.0 = no gap)
space_between = 0.05

# Load data for the combined figure. In this data set the region-level
# intervention is the continuous sum of subregion doses (not a 0/1 flag);
# cast to float for the heatmap.
interv_data = pd.read_csv(os.path.join(data_dir, 'interventions_region.csv'), header=None).values.flatten().reshape(grid_size_reg, grid_size_reg).astype(float)
outcome_data = pd.read_csv(os.path.join(data_dir, 'outcome_region.csv'), header=None).values.flatten().reshape(grid_size_reg, grid_size_reg)
context_data = pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values
context_grid = np.zeros((grid_size_sub_calc, grid_size_sub_calc))
for i in range(grid_size_reg):
    for j in range(grid_size_reg):
        region_idx = i * grid_size_reg + j
        sub_data = context_data[region_idx, :].reshape(subregions_per_dim, subregions_per_dim)
        context_grid[i*subregions_per_dim:(i+1)*subregions_per_dim, j*subregions_per_dim:(j+1)*subregions_per_dim] = sub_data

# Min-max normalise the context to [0, 1] using its observed range. A fixed
# `(context_grid - 1) / 2.0` assumes discrete wealth levels 1..3, but the
# context written to data_exp4_hc comes from np.random.uniform(0, 1), which
# that formula would map onto [-0.5, 0].
context_min = context_grid.min()
context_span = context_grid.max() - context_min
if context_span > 0:
    context_norm = (context_grid - context_min) / context_span
else:
    context_norm = np.zeros_like(context_grid)  # constant context: avoid 0/0

# One figure with three panels side by side: region interventions (red),
# region outcome (blue) and subregion context (green).
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (interv_data, red_cmap, 2),
    (outcome_data, blue_cmap, 2),
    (context_norm, green_cmap, 1),  # context panel drawn with thinner cell borders
]
for ax, (panel_data, panel_cmap, line_width) in zip(axs, panels):
    sns.heatmap(panel_data, ax=ax, cmap=panel_cmap, cbar=False, square=True, linewidths=line_width)
    ax.set_xticks([])
    ax.set_yticks([])

# Adjust the space between subplots
fig.subplots_adjust(wspace=space_between)

# Save and show
fig.savefig(os.path.join(data_dir, 'combined_heatmaps.jpg'), dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(os.path.join(data_dir, 'exp5_combined_heatmaps.pdf'), bbox_inches='tight', pad_inches=0)
plt.show()
# Release the figure like the other plotting cells do, so repeated runs
# (or non-inline backends) do not accumulate open figures.
plt.close(fig)

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Configuration for fitting free per-subregion intervention estimates (hc data)
data_dir = 'data_exp4_hc'
num_regions, num_subregions = 100, 4
epochs, lr = 10000, 0.001
torch.manual_seed(42)
grid_size_sub = 20  # 10x10 regions, each drawn as a 2x2 tile
sns.set_context('paper', font_scale=1.5)
reds = sns.color_palette("deep")[3]

# Observed inputs: region-level series as 1-D tensors, subregion context as 2-D.
# The subregion noise is intentionally not loaded (it is unobserved).
interventions_reg = torch.tensor(
    pd.read_csv(os.path.join(data_dir, 'interventions_region.csv'), header=None).values.flatten(),
    dtype=torch.float32,
)
outcome_reg = torch.tensor(
    pd.read_csv(os.path.join(data_dir, 'outcome_region.csv'), header=None).values.flatten(),
    dtype=torch.float32,
)
context_sub = torch.tensor(
    pd.read_csv(os.path.join(data_dir, 'context_subregion.csv'), header=None).values,
    dtype=torch.float32,
)

# Model
class CausalModel(nn.Module):
    """Learn a free (regions x subregions) intervention estimate plus shift/scale.

    Reads the notebook globals ``num_regions``, ``num_subregions``,
    ``interventions_reg`` and ``context_sub``. The raw estimate is jittered,
    squared and passed through a temperature softmax per region, then scaled so
    each region's subregion interventions sum to the known regional amount.
    ``temp`` and ``noise_scale`` are plain attributes that callers overwrite
    (annealed during training, then made near-deterministic).
    """

    def __init__(self):
        super().__init__()
        # Registration order: inter_est, shift, scale.
        self.inter_est = nn.Parameter(torch.randn(num_regions, num_subregions))
        self.shift = nn.Parameter(torch.tensor(1.0))
        self.scale = nn.Parameter(torch.tensor(1.0))
        self.temp = torch.tensor(10.0)
        self.noise_scale = 0.1

    def preprocess_inter(self):
        """Return the (regions, subregions) intervention estimate with rows summing to interventions_reg."""
        # The jitter is drawn on every call, even when noise_scale == 0.
        jitter = torch.randn_like(self.inter_est) * self.noise_scale
        squared = (self.inter_est + jitter) ** 2  # ensure positivity
        # Softmax gives a differentiable, near-one-hot split within each region
        weights = torch.softmax(squared / self.temp, dim=1)
        return weights * interventions_reg.unsqueeze(1)

    def forward(self):
        """Predict each region's outcome as the mean of its subregion outcomes."""
        inter_sub = self.preprocess_inter()
        per_subregion = (self.shift - context_sub) * self.scale * inter_sub
        return per_subregion.mean(dim=1)

model = CausalModel()  # draws the random initial inter_est
optimizer = optim.Adam(params=model.parameters(), lr=lr)
criterion = nn.MSELoss()

# Train with the softmax temperature annealed linearly from start_temp to end_temp
losses = []
start_temp, end_temp = 10.0, 2.0
for epoch in range(epochs):
    progress = epoch / (epochs - 1)
    model.temp = torch.tensor(start_temp + (end_temp - start_temp) * progress)
    optimizer.zero_grad()
    pred = model()
    loss = criterion(pred, outcome_reg)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

# Near-hard softmax and no jitter for extracting the final estimate
model.temp = 0.0001
model.noise_scale = 0.0

# Plot the per-epoch MSE and save it as JPG and PDF
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.title('Loss Curve')
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
for _file_name in ('loss_curve.jpg', 'loss_curve.pdf'):
    plt.savefig(os.path.join(data_dir, _file_name))
plt.show()
plt.close()

# Report and persist the learned shift/scale
_shift_hat = model.shift.item()
_scale_hat = model.scale.item()
print(f"Final shift_param: {_shift_hat}")
print(f"Final scale_param: {_scale_hat}")

with open(os.path.join(data_dir, 'parameter_estimate.txt'), 'w') as f:
    f.writelines([
        f"shift_param: {_shift_hat}\n",
        f"scale_param: {_scale_hat}\n",
    ])

# Final intervention estimate (computed with the temp/noise set after training)
inter_est_final = model.preprocess_inter().detach().numpy()
pd.DataFrame(inter_est_final).to_csv(
    os.path.join(data_dir, 'interventions_subregion_estimated.csv'), index=False, header=False
)

# Visualize estimated subregion
def visualize_estimated_subregion():
    """Plot the final estimated sub-region interventions as one tiled heat-map.

    Reads the module-level ``inter_est_final`` array (one row per region, one
    column per sub-region). Saves the figure as JPG and PDF in ``data_dir``.
    Regions are tiled row-major on a square grid, and each region's sub-regions
    form a square tile. The geometry is derived from the array's shape rather
    than hard-coded to 10x10 regions of 2x2 sub-regions.

    Raises:
        ValueError: if the number of regions or sub-regions is not a perfect square.
    """
    data = np.asarray(inter_est_final)
    n_regions, n_subregions = data.shape
    regions_per_side = int(round(np.sqrt(n_regions)))
    subs_per_side = int(round(np.sqrt(n_subregions)))
    if regions_per_side ** 2 != n_regions or subs_per_side ** 2 != n_subregions:
        raise ValueError(
            f"Expected square layouts, got {n_regions} regions x {n_subregions} sub-regions"
        )
    side = regions_per_side * subs_per_side
    grid = np.zeros((side, side))
    for region_idx in range(n_regions):
        # Row-major placement: region k goes to tile (k // regions_per_side, k % regions_per_side).
        i, j = divmod(region_idx, regions_per_side)
        grid[i * subs_per_side:(i + 1) * subs_per_side,
             j * subs_per_side:(j + 1) * subs_per_side] = data[region_idx].reshape(subs_per_side, subs_per_side)
    fig, ax = plt.subplots(figsize=(8, 8))
    heatmap_cmap = sns.light_palette(reds, as_cmap=True)
    # No colour bar. The former cbar_kws were ignored by seaborn when cbar=False, so they were dropped.
    sns.heatmap(grid, ax=ax, cmap=heatmap_cmap, square=True, cbar=False)
    ax.set_title('Estimated Interventions Subregion', fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    plt.legend([], [], frameon=False)
    plt.tight_layout()
    plt.savefig(os.path.join(data_dir, 'interventions_subregion_estimated.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, 'exp5_interventions_subregion_estimated_standard.pdf'))
    plt.show()
    plt.close()

visualize_estimated_subregion()

In [None]:
def _as_numpy(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and moved to CPU first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


# Convert inputs to NumPy *before* building the ground truth, so that the true and
# the predicted sub-region outcomes below are both NumPy arrays. Previously the true
# outcome was computed from the raw, possibly torch, context_sub.
context_sub_np = _as_numpy(context_sub)

# Calculate true causal effect at subregion level
true_outcome_sub = (shift_param - context_sub_np) * scale_param * _as_numpy(interventions_sub)

# Load estimated interventions
inter_est_final = pd.read_csv(os.path.join(data_dir, 'interventions_subregion_estimated.csv'), header=None).values

# Load the learned parameters written by the training cell ("name: value" per line)
with open(os.path.join(data_dir, 'parameter_estimate.txt'), 'r') as f:
    lines = f.readlines()
    estimated_shift = float(lines[0].split(': ')[1])
    estimated_scale = float(lines[1].split(': ')[1])

# Predicted sub-region outcome from the estimated interventions and learned parameters
predicted_outcome_sub = (estimated_shift - context_sub_np) * estimated_scale * inter_est_final

# Visualize true vs predicted outcome at subregion level
def visualize_true_vs_predicted_subregion(true_data, predicted_data, num_regions=100, num_subregions=4):
    """Plot true vs. predicted sub-region outcomes and their difference side by side.

    Each input has one row per region and one column per sub-region. Regions are
    tiled row-major on a sqrt(num_regions)-sided grid. Each region's sub-regions
    form a sqrt(num_subregions)-sided tile. Inputs may be NumPy arrays or CPU torch
    tensors that do not require grad (both go through ``np.asarray``). The difference
    panel uses a diverging map centred at 0, so colour encodes the sign of
    (true - predicted). Figures are written to ``data_dir``.

    Raises:
        ValueError: if num_regions or num_subregions is not a perfect square.
            Previously the int(sqrt) truncation silently dropped regions or
            failed inside reshape.
    """
    true_data = np.asarray(true_data)
    predicted_data = np.asarray(predicted_data)

    regions_per_side = int(round(np.sqrt(num_regions)))
    subregions_per_dim = int(round(np.sqrt(num_subregions)))
    if regions_per_side ** 2 != num_regions or subregions_per_dim ** 2 != num_subregions:
        raise ValueError(
            f"num_regions ({num_regions}) and num_subregions ({num_subregions}) must be perfect squares"
        )
    grid_size_sub_calc = regions_per_side * subregions_per_dim

    true_grid = np.zeros((grid_size_sub_calc, grid_size_sub_calc))
    predicted_grid = np.zeros((grid_size_sub_calc, grid_size_sub_calc))
    for region_idx in range(num_regions):
        i, j = divmod(region_idx, regions_per_side)
        rows = slice(i * subregions_per_dim, (i + 1) * subregions_per_dim)
        cols = slice(j * subregions_per_dim, (j + 1) * subregions_per_dim)
        true_grid[rows, cols] = true_data[region_idx].reshape(subregions_per_dim, subregions_per_dim)
        predicted_grid[rows, cols] = predicted_data[region_idx].reshape(subregions_per_dim, subregions_per_dim)
    difference_grid = true_grid - predicted_grid

    # Colour maps: light blue (seaborn "deep" palette) for outcomes, diverging for the difference.
    colors = sns.color_palette("deep")
    blue_cmap = sns.light_palette(colors[0], as_cmap=True)
    coolwarm_cmap = 'coolwarm'

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    # True Outcome (Blue)
    sns.heatmap(true_grid, ax=axs[0], cmap=blue_cmap, square=True, cbar=False, linewidths=1)
    axs[0].set_title('True Outcome Subregion')
    axs[0].set_xticks([])
    axs[0].set_yticks([])

    # Predicted Outcome (Blue)
    sns.heatmap(predicted_grid, ax=axs[1], cmap=blue_cmap, square=True, cbar=False, linewidths=1)
    axs[1].set_title('Predicted Outcome Subregion')
    axs[1].set_xticks([])
    axs[1].set_yticks([])

    # Difference (Coolwarm). center=0 keeps zero error at the neutral colour.
    sns.heatmap(difference_grid, ax=axs[2], cmap=coolwarm_cmap, center=0, square=True, cbar=False, linewidths=1)
    axs[2].set_title('Difference (True - Predicted)')
    axs[2].set_xticks([])
    axs[2].set_yticks([])

    plt.tight_layout()
    plt.savefig(os.path.join(data_dir, 'true_vs_predicted_outcome_subregion_styled.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, 'exp5_standard_true_vs_predicted_outcome_subregion_styled.pdf'))
    plt.show()
    plt.close()

visualize_true_vs_predicted_subregion(true_outcome_sub, predicted_outcome_sub, num_regions, num_subregions)

# Calculate MSE for true vs predicted subregion outcome
mse_subregion = np.mean((np.asarray(true_outcome_sub) - np.asarray(predicted_outcome_sub)) ** 2)
print(f"MSE between true and predicted subregion outcome: {mse_subregion}")

In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
import os

# ---- Experiment configuration ----
data_dir = 'data_exp4_hc'
num_regions = 100
num_subregions = 4
epochs = 10000
lr = 0.001
torch.manual_seed(42)
# Side length of the full sub-region grid (regions tiled as sqrt x sqrt blocks).
grid_size_sub = int(np.sqrt(num_regions) * np.sqrt(num_subregions))
sns.set_context('paper', font_scale=1.5)
reds = sns.color_palette("deep")[3]


def _load_csv_tensor(name, flatten=False):
    """Read the header-less CSV ``data_dir/name`` into a float32 tensor (optionally 1-D)."""
    table = pd.read_csv(os.path.join(data_dir, name), header=None).values
    if flatten:
        table = table.flatten()
    return torch.tensor(table, dtype=torch.float32)


interventions_reg = _load_csv_tensor('interventions_region.csv', flatten=True)
outcome_reg = _load_csv_tensor('outcome_region.csv', flatten=True)
context_sub = _load_csv_tensor('context_subregion.csv')
# Region-level mean context. Not used by the combined model; kept for later cells.
context_reg = context_sub.mean(dim=1)

# Model
class CausalModelCombined(nn.Module):
    """Outcome model with a context-driven hard assignment of regional interventions.

    Each region's intervention amount is placed entirely on one sub-region: the
    arg-max of ``softmax(context_sub / tau)`` along dim 1. The predicted sub-region
    outcome is ``(shift - context) * scale * intervention``, and the region outcome
    is its mean over sub-regions.

    Args:
        num_regions: number of regions. Stored for reference; the forward pass
            uses the actual number of rows of ``context_sub``, so other batch
            sizes also work.
        num_subregions: number of sub-regions per region (columns).
    """

    def __init__(self, num_regions, num_subregions):
        super().__init__()
        self.shift = nn.Parameter(torch.tensor(1.0))
        self.scale = nn.Parameter(torch.tensor(1.0))
        # Learnable temperature in log space, so tau = exp(log_tau) > 0; starts at tau = 1.
        # A plain float keeps this parameter float32 like shift/scale (np.log(1.0) is a
        # float64 scalar, which made log_tau the only float64 parameter).
        self.log_tau = nn.Parameter(torch.tensor(0.0))
        self.num_regions = num_regions
        self.num_subregions = num_subregions

    def preprocess_inter(self, context_sub, interventions_reg):
        """Return (hard sub-region interventions, softmax probabilities).

        Args:
            context_sub: (regions, sub-regions) tensor of context values.
            interventions_reg: (regions,) tensor of regional intervention totals.

        Returns:
            A tuple ``(inter_constrained, inter_pred_softmax)``. The first is a one-hot
            row per region scaled by that region's total. The second is the row-wise
            softmax of ``context_sub / tau``, used for the NLL loss.
        """
        tau = torch.exp(self.log_tau)
        inter_pred_softmax = torch.softmax(context_sub / tau, dim=1)

        # tau > 0 and softmax is monotone, so this arg-max depends only on context_sub.
        # tau influences training solely through the returned probabilities. The
        # one-hot mask is non-differentiable, so the outcome loss only reaches shift/scale.
        max_prob_indices = torch.argmax(inter_pred_softmax, dim=1, keepdim=True)
        inter_discrete = torch.zeros_like(inter_pred_softmax).scatter_(1, max_prob_indices, 1.0)

        # Scale the one-hot assignment by the total regional intervention amount
        inter_constrained = inter_discrete * interventions_reg.unsqueeze(1)

        return inter_constrained, inter_pred_softmax

    def forward(self, context_sub, interventions_reg):
        """Return (region outcome predictions, sub-region interventions, softmax probabilities)."""
        inter_sub, inter_probs = self.preprocess_inter(context_sub, interventions_reg)
        outcome_sub_pred = (self.shift - context_sub) * self.scale * inter_sub
        outcome_reg_pred = outcome_sub_pred.mean(dim=1)
        return outcome_reg_pred, inter_sub, inter_probs

# Jointly fit the combined model with two losses:
#   * MSE between predicted and observed region outcomes (reaches shift/scale only,
#     because the sub-region assignment inside preprocess_inter is a hard arg-max);
#   * NLL of the true treated sub-region under the softmax probabilities (reaches log_tau).
model_combined = CausalModelCombined(num_regions, num_subregions)
optimizer_combined = optim.Adam(model_combined.parameters(), lr=lr)
criterion_outcome = nn.MSELoss()
criterion_inter = nn.NLLLoss()

# Train
losses_combined = []
outcome_losses_combined = []
inter_losses_combined = []
# Load true subregion interventions for loss calculation
true_interventions_sub = torch.tensor(pd.read_csv(os.path.join(data_dir, 'interventions_subregion.csv'), header=None).values, dtype=torch.float32)
# Create target indices for NLLLoss: the index of the largest true intervention per region.
# NOTE(review): for a region whose row is all zeros, argmax returns index 0, so such
# regions are supervised towards sub-region 0. Consider masking them out of loss_inter.
true_intervention_indices = torch.argmax(true_interventions_sub, dim=1)


for epoch in range(epochs):
    optimizer_combined.zero_grad()
    outcome_reg_pred, inter_sub_pred, inter_probs = model_combined.forward(context_sub, interventions_reg)

    loss_outcome = criterion_outcome(outcome_reg_pred, outcome_reg)
    # NLLLoss expects log-probabilities: take the log of the softmax output
    # (1e-9 guards against log(0)); this is not a log_softmax call.
    loss_inter = criterion_inter(torch.log(inter_probs + 1e-9), true_intervention_indices)
    loss = loss_outcome + loss_inter # Combine losses (unweighted sum)

    loss.backward()
    optimizer_combined.step()
    losses_combined.append(loss.item())
    outcome_losses_combined.append(loss_outcome.item())
    inter_losses_combined.append(loss_inter.item())


# Save loss curves
plt.figure(figsize=(10, 5))
plt.plot(losses_combined, label='Total Loss')
plt.plot(outcome_losses_combined, label='Outcome Loss')
plt.plot(inter_losses_combined, label='Intervention Prediction Loss (NLL)')
plt.title('Combined Model Loss Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.savefig(os.path.join(data_dir, 'loss_curves_combined_new.jpg'))
plt.savefig(os.path.join(data_dir, 'loss_curves_combined_new.pdf'))
plt.show()
plt.close()


# Print final estimates
print(f"Final shift_param (Combined): {model_combined.shift.item()}")
print(f"Final scale_param (Combined): {model_combined.scale.item()}")
print(f"Final estimated tau (Combined): {torch.exp(model_combined.log_tau).item()}")

# Save params to txt ("name: value" per line; the evaluation cell parses lines 0 and 1)
with open(os.path.join(data_dir, 'parameter_estimate_combined_new.txt'), 'w') as f:
    f.write(f"shift_param: {model_combined.shift.item()}\n")
    f.write(f"scale_param: {model_combined.scale.item()}\n")
    f.write(f"estimated_tau: {torch.exp(model_combined.log_tau).item()}\n")

# Save estimated interventions sub (processed: one-hot per region scaled by the regional total)
inter_est_final_combined, _ = model_combined.preprocess_inter(context_sub, interventions_reg)
inter_est_final_combined = inter_est_final_combined.detach().numpy()
pd.DataFrame(inter_est_final_combined).to_csv(os.path.join(data_dir, 'interventions_subregion_estimated_combined_new.csv'), index=False, header=False)

# Visualize estimated subregion
def visualize_estimated_subregion(data, title, file_suffix):
    """Plot estimated sub-region interventions as one tiled heat-map with a colour bar.

    Args:
        data: array-like of shape (regions, sub-regions), e.g. a NumPy array.
            Both counts must be perfect squares. Regions are tiled row-major and
            each region's sub-regions form a square tile.
        title: figure title.
        file_suffix: appended to the JPG/PDF file names written to ``data_dir``.

    The tiling geometry is taken from ``data``'s shape rather than the global
    ``num_regions`` / ``num_subregions``. An array that does not match those
    globals is therefore still laid out correctly instead of being truncated.

    Raises:
        ValueError: if either dimension of ``data`` is not a perfect square.
    """
    data = np.asarray(data)
    n_regions, n_subregions = data.shape
    regions_per_side = int(round(np.sqrt(n_regions)))
    subregions_per_dim = int(round(np.sqrt(n_subregions)))
    if regions_per_side ** 2 != n_regions or subregions_per_dim ** 2 != n_subregions:
        raise ValueError(
            f"Expected square layouts, got {n_regions} regions x {n_subregions} sub-regions"
        )
    grid_size_sub_calc = regions_per_side * subregions_per_dim
    grid = np.zeros((grid_size_sub_calc, grid_size_sub_calc))
    for region_idx in range(n_regions):
        i, j = divmod(region_idx, regions_per_side)
        grid[i*subregions_per_dim:(i+1)*subregions_per_dim,
             j*subregions_per_dim:(j+1)*subregions_per_dim] = data[region_idx].reshape(subregions_per_dim, subregions_per_dim)
    fig, ax = plt.subplots(figsize=(8, 8))
    heatmap_cmap = sns.light_palette(reds, as_cmap=True)
    sns.heatmap(grid, ax=ax, cmap=heatmap_cmap, square=True, cbar=True, cbar_kws={'label': 'Value', 'location': 'right', 'pad': 0.1})
    ax.set_title(title, fontsize=18)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    plt.legend([], [], frameon=False)
    plt.tight_layout()
    plt.savefig(os.path.join(data_dir, f'interventions_subregion_estimated_{file_suffix}.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, f'exp_5_interventions_subregion_estimated_{file_suffix}.pdf'))
    plt.show()
    plt.close()

visualize_estimated_subregion(inter_est_final_combined, 'Estimated Interventions Subregion (Combined Model)', 'combined_new')

In [None]:
def _as_numpy(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and moved to CPU first."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


# context_sub is a torch tensor in this experiment. Convert it *before* building the
# ground truth so that true and predicted sub-region outcomes are both NumPy arrays.
context_sub_np = _as_numpy(context_sub)

# Calculate true causal effect at subregion level.
# NOTE(review): this uses the global `interventions_sub` (from a data-generation cell),
# not `true_interventions_sub` loaded from data_dir in the training cell. Confirm both
# refer to the same dataset.
true_outcome_sub = (shift_param - context_sub_np) * scale_param * _as_numpy(interventions_sub)

# Load estimated interventions
inter_est_final = pd.read_csv(os.path.join(data_dir, 'interventions_subregion_estimated_combined_new.csv'), header=None).values

# Load the learned parameters written by the training cell ("name: value" per line)
with open(os.path.join(data_dir, 'parameter_estimate_combined_new.txt'), 'r') as f:
    lines = f.readlines()
    estimated_shift = float(lines[0].split(': ')[1])
    estimated_scale = float(lines[1].split(': ')[1])

# Predicted sub-region outcome from the estimated interventions and learned parameters
predicted_outcome_sub = (estimated_shift - context_sub_np) * estimated_scale * inter_est_final

# Visualize true vs predicted outcome at subregion level
def visualize_true_vs_predicted_subregion(true_data, predicted_data, num_regions=100, num_subregions=4):
    """Plot true vs. predicted sub-region outcomes and their difference side by side.

    Each input has one row per region and one column per sub-region. Regions are
    tiled row-major on a sqrt(num_regions)-sided grid. Each region's sub-regions
    form a sqrt(num_subregions)-sided tile. Inputs may be NumPy arrays or CPU torch
    tensors that do not require grad (both go through ``np.asarray``). The difference
    panel uses a diverging map centred at 0, so colour encodes the sign of
    (true - predicted). Figures are written to ``data_dir``.

    Raises:
        ValueError: if num_regions or num_subregions is not a perfect square.
            Previously the int(sqrt) truncation silently dropped regions or
            failed inside reshape.
    """
    true_data = np.asarray(true_data)
    predicted_data = np.asarray(predicted_data)

    regions_per_side = int(round(np.sqrt(num_regions)))
    subregions_per_dim = int(round(np.sqrt(num_subregions)))
    if regions_per_side ** 2 != num_regions or subregions_per_dim ** 2 != num_subregions:
        raise ValueError(
            f"num_regions ({num_regions}) and num_subregions ({num_subregions}) must be perfect squares"
        )
    grid_size_sub_calc = regions_per_side * subregions_per_dim

    true_grid = np.zeros((grid_size_sub_calc, grid_size_sub_calc))
    predicted_grid = np.zeros((grid_size_sub_calc, grid_size_sub_calc))
    for region_idx in range(num_regions):
        i, j = divmod(region_idx, regions_per_side)
        rows = slice(i * subregions_per_dim, (i + 1) * subregions_per_dim)
        cols = slice(j * subregions_per_dim, (j + 1) * subregions_per_dim)
        true_grid[rows, cols] = true_data[region_idx].reshape(subregions_per_dim, subregions_per_dim)
        predicted_grid[rows, cols] = predicted_data[region_idx].reshape(subregions_per_dim, subregions_per_dim)
    difference_grid = true_grid - predicted_grid

    # Colour maps: light blue (seaborn "deep" palette) for outcomes, diverging for the difference.
    colors = sns.color_palette("deep")
    blue_cmap = sns.light_palette(colors[0], as_cmap=True)
    coolwarm_cmap = 'coolwarm'

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))

    # True Outcome (Blue)
    sns.heatmap(true_grid, ax=axs[0], cmap=blue_cmap, square=True, cbar=False, linewidths=1)
    axs[0].set_title('True Outcome Subregion')
    axs[0].set_xticks([])
    axs[0].set_yticks([])

    # Predicted Outcome (Blue)
    sns.heatmap(predicted_grid, ax=axs[1], cmap=blue_cmap, square=True, cbar=False, linewidths=1)
    axs[1].set_title('Predicted Outcome Subregion')
    axs[1].set_xticks([])
    axs[1].set_yticks([])

    # Difference (Coolwarm). center=0 keeps zero error at the neutral colour.
    sns.heatmap(difference_grid, ax=axs[2], cmap=coolwarm_cmap, center=0, square=True, cbar=False, linewidths=1)
    axs[2].set_title('Difference (True - Predicted)')
    axs[2].set_xticks([])
    axs[2].set_yticks([])

    plt.tight_layout()
    plt.savefig(os.path.join(data_dir, 'true_vs_predicted_outcome_subregion_styled.jpg'), dpi=300)
    plt.savefig(os.path.join(data_dir, 'exp5_tau_true_vs_predicted_outcome_subregion_styled.pdf'))
    plt.show()
    plt.close()

visualize_true_vs_predicted_subregion(true_outcome_sub, predicted_outcome_sub, num_regions, num_subregions)

# Calculate MSE for true vs predicted subregion outcome
mse_subregion = np.mean((np.asarray(true_outcome_sub) - np.asarray(predicted_outcome_sub)) ** 2)
print(f"MSE between true and predicted subregion outcome: {mse_subregion}")

# Exp. 5
Unknown Aggregation Functions: Learn both the causal effect of driving bans on air quality and the unknown aggregation rule that maps subregional outcomes to regional reports. The goal is to jointly recover the aggregation mechanism and the fine-scale causal effects.

## Exp. 5 mean aggregation

In [None]:
import numpy as np

def soft_aggregate_np(values, temperature=1.0, axis=None):
    """Aggregate ``values`` with a softmax-weighted mean (NumPy version).

    Each value is weighted by ``softmax(values / temperature)`` over the
    aggregated elements, and the result is ``sum(values * weights)``. Large
    temperatures give nearly uniform weights (close to the plain mean). Small
    temperatures concentrate the weight on the largest values (approaching the max).

    Args:
        values: array-like of numbers (lists are accepted).
        temperature: strictly positive softmax temperature.
        axis: axis (or tuple of axes) to aggregate over. The default ``None``
            aggregates all elements into one scalar, as before. ``axis=-1``
            mirrors the per-row aggregation of the torch ``soft_aggregate``.

    Returns:
        The aggregated value(s), with the aggregated axes removed.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("Temperature must be positive")

    values = np.asarray(values)
    logits = values / temperature
    # Subtract the max before exponentiating for numerical stability; it cancels in the normalisation.
    logits = logits - np.max(logits, axis=axis, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=axis, keepdims=True)

    return np.sum(values * weights, axis=axis)


In [None]:
import os
import numpy as np
import pandas as pd

# ---------------- Hyper-parameters ----------------
# All random draws in this cell come from `rng` in a fixed order; reordering
# statements changes the generated dataset.
OUTFOLDER        = "data_exp5_mean"
GRID_SIZE        = 10          # number of regions per side
SUBREGION_SIZE   = 10          # sub-regions per region side  (=> 100x100 grid)
SPATIAL_VARIANCE = 3           # max +/- deviation of the rich count from an even rich/poor split
OUTCOME_NOISE    = 0.02        # std-dev of the additive Gaussian noise on each sub-region vote
RICH_VOTE        = 0.80        # treated vote share for wealth class 3 (rich)
POOR_VOTE        = 0.40        # treated vote share for wealth class 1 (poor)
INTER_VOTE       = 0.50        # treated vote share for any other class (2, intermediate)
BASELINE_VOTE    = 0.50        # vote share of untreated sub-regions, regardless of wealth
N_SIM_EXTRA      = 10          # NOTE(review): not referenced in this cell; possibly unused
RANDOM_SEED      = 42
agg_temp         = 6           # temperature of soft_aggregate_np for region outcomes; for votes in
                               # [0, 1] the weights are nearly uniform, i.e. close to a plain mean
rng = np.random.default_rng(RANDOM_SEED)

os.makedirs(OUTFOLDER, exist_ok=True)

# ---------------- 1) 10x10 intervention grid ----------------
# One 0/1 treatment flag per region.
interventions = rng.integers(0, 2, size=(GRID_SIZE, GRID_SIZE))
interventions[0, 0] = 1  # force interventions in the two top-left cells
interventions[0, 1] = 1
pd.DataFrame(interventions).to_csv(f"{OUTFOLDER}/interventions.csv",
                                   header=False, index=False)

# ---------------- 2) 100x100 wealth grid ----------------
# Wealth classes per sub-region: 1 = poor, 2 = intermediate, 3 = rich.
wealth_hi = np.zeros((GRID_SIZE*SUBREGION_SIZE,
                      GRID_SIZE*SUBREGION_SIZE), dtype=int)

for i in range(GRID_SIZE):
    for j in range(GRID_SIZE):
        # Pick counts for this block. The literal 100 is the number of cells per block
        # (SUBREGION_SIZE**2 with the values above); update it if SUBREGION_SIZE changes.
        # The counts are drawn for every block, including the overridden ones below,
        # which keeps the rng stream identical regardless of the overrides.
        n_inter = rng.integers(20, 101)        # 20..100 intermediate
        remaining = 100 - n_inter
        wiggle = min(SPATIAL_VARIANCE, remaining//2)
        n_rich = remaining//2 + rng.integers(-wiggle, wiggle+1)
        n_rich = np.clip(n_rich, 0, remaining)
        n_poor = remaining - n_rich

        # special overrides for the two upper-left blocks
        if (i, j) == (0, 0):
            block_vals = np.full(100, 2)                      # all intermediate
        elif (i, j) == (0, 1):
            half = 50
            block_vals = np.concatenate([np.ones(half,  int),   # 50 poor
                                          np.full(half, 3, int)])  # 50 rich
            rng.shuffle(block_vals)
        else:
            block_vals = np.concatenate([np.ones(n_poor, int),
                                          np.full(n_inter, 2, int),
                                          np.full(n_rich, 3, int)])
            rng.shuffle(block_vals)

        # place block into wealth_hi (row-major reshape of the shuffled values)
        r0, c0 = i*SUBREGION_SIZE, j*SUBREGION_SIZE
        wealth_hi[r0:r0+SUBREGION_SIZE, c0:c0+SUBREGION_SIZE] = \
            block_vals.reshape(SUBREGION_SIZE, SUBREGION_SIZE)

pd.DataFrame(wealth_hi).to_csv(f"{OUTFOLDER}/wealth_high_res.csv",
                               header=False, index=False)

# ---------------- 3) 10x10 coarse (mean) wealth grid ---------
# Block means: axes 1 and 3 index the sub-region row/column inside each region.
wealth_lo = wealth_hi.reshape(GRID_SIZE, SUBREGION_SIZE,
                              GRID_SIZE, SUBREGION_SIZE).mean(axis=(1, 3))
pd.DataFrame(np.round(wealth_lo, 2)).to_csv(f"{OUTFOLDER}/wealth_low_res.csv",
                                            header=False, index=False)

# ---------------- 4) voting outcome ----------------
subregion_noise = np.zeros_like(wealth_hi, dtype=float)  # per-sub-region noise, filled by vote() and reused for counterfactuals

def vote(sub_wealth, treated, i, j):
    """Sample one sub-region vote share, clipped to [0, 1], and record its noise.

    Exactly one Gaussian draw (std ``OUTCOME_NOISE``) is taken from the module
    ``rng`` per call. It is stored in ``subregion_noise[i, j]`` so that the
    counterfactual pass can reuse it. Untreated cells centre on
    ``BASELINE_VOTE``. Treated cells centre on ``RICH_VOTE`` (wealth 3),
    ``POOR_VOTE`` (wealth 1) or ``INTER_VOTE`` (anything else).
    """
    eps = rng.normal(0, OUTCOME_NOISE)
    subregion_noise[i, j] = eps
    if treated:
        level = RICH_VOTE if sub_wealth == 3 else (POOR_VOTE if sub_wealth == 1 else INTER_VOTE)
    else:
        level = BASELINE_VOTE
    return np.clip(level + eps, 0, 1)

# Region outcome (percent): soft aggregate of the region's sub-region votes.
outcome = np.zeros_like(interventions, dtype=float)

for i in range(GRID_SIZE):
    for j in range(GRID_SIZE):
        r0, c0 = i*SUBREGION_SIZE, j*SUBREGION_SIZE
        block_wealth = wealth_hi[r0:r0+SUBREGION_SIZE,
                                 c0:c0+SUBREGION_SIZE]
        block_treated = bool(interventions[i, j])
        # vote() draws one noise sample per cell (row-major) and stores it in subregion_noise.
        votes = [[vote(block_wealth[r, c], block_treated, r0 + r, c0 + c)
                  for c in range(SUBREGION_SIZE)]
                 for r in range(SUBREGION_SIZE)]
        # Previous plain-mean aggregation, kept for reference:
        # outcome[i, j] = np.mean(votes) * 100.0        # percentage 0-100
        outcome[i, j] = soft_aggregate_np(np.array(votes), temperature=agg_temp) * 100.0

# Save subregion noise (rounded to 4 decimals in the CSV)
pd.DataFrame(np.round(subregion_noise, 4)).to_csv(f"{OUTFOLDER}/subregion_noise_gt.csv",
                                                  header=False, index=False)

# Save aggregated region-level noise (sum over each region's sub-regions).
# NOTE(review): region outcomes are a temperature-weighted soft aggregate (x100) of the
# clipped noisy votes, so this sum is not the additive noise term of `outcome`. Confirm
# how downstream code interprets region_noise_gt.csv.
region_noise = subregion_noise.reshape(GRID_SIZE, SUBREGION_SIZE,
                                       GRID_SIZE, SUBREGION_SIZE).sum(axis=(1, 3))
pd.DataFrame(np.round(region_noise, 4)).to_csv(f"{OUTFOLDER}/region_noise_gt.csv",
                                               header=False, index=False)

# Save outcome as before
pd.DataFrame(np.round(outcome, 2)).to_csv(f"{OUTFOLDER}/outcome.csv",
                                          header=False, index=False)


# ---------------- 5) counterfactual outcome (GT) ----------------
# Invert top half (rows 0-4) of intervention matrix
interventions_cf = interventions.copy()
interventions_cf[:GRID_SIZE//2, :] = 1 - interventions_cf[:GRID_SIZE//2, :]

pd.DataFrame(interventions_cf).to_csv(f"{OUTFOLDER}/interventions_cf_gt.csv",
                                      header=False, index=False)

# Compute counterfactual using same noise (no new rng draws): same vote rule as vote(),
# but reading the in-memory (unrounded) subregion_noise.
outcome_cf = np.zeros_like(interventions_cf, dtype=float)

for i in range(GRID_SIZE):
    for j in range(GRID_SIZE):
        r0, c0 = i*SUBREGION_SIZE, j*SUBREGION_SIZE
        block_wealth = wealth_hi[r0:r0+SUBREGION_SIZE, c0:c0+SUBREGION_SIZE]
        block_treated = bool(interventions_cf[i, j])
        votes = [[
            # reuse stored noise instead of generating new one
            np.clip(
                (RICH_VOTE if w == 3 else POOR_VOTE if w == 1 else INTER_VOTE) + subregion_noise[r0 + r, c0 + c]
                if block_treated else BASELINE_VOTE + subregion_noise[r0 + r, c0 + c],
                0, 1
            )
            for c, w in enumerate(block_wealth[r])]
            for r in range(SUBREGION_SIZE)]
        # Previous plain-mean aggregation, kept for reference:
        # outcome_cf[i, j] = np.mean(votes) * 100.0
        outcome_cf[i, j] = soft_aggregate_np(np.array(votes), temperature=agg_temp) * 100.0


pd.DataFrame(np.round(outcome_cf, 2)).to_csv(f"{OUTFOLDER}/outcome_cf_gt.csv",
                                             header=False, index=False)

print("✓ Synthetic data and counterfactuals written to:", OUTFOLDER)

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Set output folder
OUTFOLDER = "data_exp5_mean"

# Load wealth classes and the per-cell noise stored during data generation.
# Note: the noise CSV was written rounded to 4 decimals, so values can differ
# slightly from the in-memory noise used to generate outcome.csv.
wealth_hi = pd.read_csv(f"{OUTFOLDER}/wealth_high_res.csv", header=None).values
subregion_noise = pd.read_csv(f"{OUTFOLDER}/subregion_noise_gt.csv", header=None).values

# Define constants from data generation
BASELINE_VOTE = 0.50
RICH_VOTE = 0.80
POOR_VOTE = 0.40
INTER_VOTE = 0.50

# Potential outcomes per sub-cell in percent. Both reuse the same noise, so they
# differ only by the treatment.
control_outcome = np.clip(BASELINE_VOTE + subregion_noise, 0, 1) * 100.0

# Treated level per sub-cell: 3 (rich) -> RICH_VOTE, 1 (poor) -> POOR_VOTE, else INTER_VOTE.
# Vectorised with np.select instead of a Python double loop; per-cell arithmetic is unchanged.
treated_base = np.select([wealth_hi == 3, wealth_hi == 1], [RICH_VOTE, POOR_VOTE], default=INTER_VOTE)
treated_outcome = np.clip(treated_base + subregion_noise, 0, 1) * 100.0

# Compute causal effect: treated - control
causal_effect = treated_outcome - control_outcome

# Save to CSVs
pd.DataFrame(np.round(control_outcome, 2)).to_csv(f"{OUTFOLDER}/control_sub.csv", header=False, index=False)
pd.DataFrame(np.round(treated_outcome, 2)).to_csv(f"{OUTFOLDER}/treated_sub.csv", header=False, index=False)
pd.DataFrame(np.round(causal_effect, 2)).to_csv(f"{OUTFOLDER}/effect_sub.csv", header=False, index=False)

# Values are plotted unscaled (percent). The *_norm names are kept for later cells.
control_norm = control_outcome
treated_norm = treated_outcome
effect_norm = causal_effect

# Color maps from seaborn deep palette
colors = sns.color_palette("deep")
blue_cmap = sns.light_palette(colors[0], as_cmap=True)  # Blue for outcomes
pink_cmap = sns.light_palette(colors[4], as_cmap=True)  # Pink for effect

# Parameter to control space between subplots
space_between = 0.2

# Create a single figure with three subplots side by side
fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# Left: Control outcome in blue
sns.heatmap(control_norm, ax=axs[0], cmap=blue_cmap, cbar=False, square=True, linewidths=0.0)
axs[0].set_xticks([])
axs[0].set_yticks([])

# Center: Treated outcome in blue
sns.heatmap(treated_norm, ax=axs[1], cmap=blue_cmap, cbar=False, square=True, linewidths=0.0)
axs[1].set_xticks([])
axs[1].set_yticks([])

# Right: Causal effect in pink
sns.heatmap(effect_norm, ax=axs[2], cmap=pink_cmap, cbar=False, square=True, linewidths=0.0)
axs[2].set_xticks([])
axs[2].set_yticks([])

# Adjust the space between subplots
fig.subplots_adjust(wspace=space_between)

# Save and show
fig.savefig(f'{OUTFOLDER}/sub_heatmaps.jpg', dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(f'{OUTFOLDER}/sub_heatmaps.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Set output folder
OUTFOLDER = "data_exp5_mean"

# Region-level inputs (0/1 intervention flags, outcome as a 0-1 fraction) and the
# fine-grid wealth classes rescaled from {1, 2, 3} to {0, 0.5, 1}.
interv_data = pd.read_csv(f"{OUTFOLDER}/interventions.csv", header=None).values.astype(int)
outcome_data = pd.read_csv(f"{OUTFOLDER}/outcome.csv", header=None).values / 100.0
context_data = pd.read_csv(f"{OUTFOLDER}/wealth_high_res.csv", header=None).values
context_data_norm = (context_data - 1) / 2.0

# One light colour map per panel (red, blue, green) from seaborn's "deep" palette
colors = sns.color_palette("deep")
red_cmap, blue_cmap, green_cmap = (
    sns.light_palette(colors[k], as_cmap=True) for k in (3, 0, 2)
)

# Horizontal gap between panels, as a fraction of the average axis width
space_between = 0.05

fig, axs = plt.subplots(1, 3, figsize=(15, 5))

# (axis, data, colour map, cell border width, (vmin, vmax)) for each panel:
# interventions | region outcomes | fine-grid wealth
panel_specs = (
    (axs[0], interv_data, red_cmap, 2, (0.0, 1.0)),
    (axs[1], outcome_data, blue_cmap, 2, (0.4, 0.6)),
    (axs[2], context_data_norm, green_cmap, 0, (0.0, 1.0)),
)
for panel_ax, panel_data, panel_cmap, panel_lw, (panel_lo, panel_hi) in panel_specs:
    sns.heatmap(panel_data, ax=panel_ax, cmap=panel_cmap, cbar=False, square=True,
                linewidths=panel_lw, vmin=panel_lo, vmax=panel_hi)
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])

fig.subplots_adjust(wspace=space_between)

fig.savefig(f'{OUTFOLDER}/combined_heatmaps.jpg', dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(f'{OUTFOLDER}/combined_heatmaps.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

sns.set_theme(style="white")

# -------------- parameters -------------
OUTFOLDER = "data_exp5_mean"
N_SIM_PER_ARM = 100  # runs per arm: N_SIM_PER_ARM all-control plus N_SIM_PER_ARM all-treated simulations
RANDOM_SEED = 100
rng = np.random.default_rng(RANDOM_SEED)

# ---------- grid geometry & fine-grid wealth ----------
GRID_SIZE, SUBREGION_SIZE = 10, 10  # regions per side, sub-regions per region side
wealth_hi = pd.read_csv(f"{OUTFOLDER}/wealth_high_res.csv", header=None).values

def vote(sub_wealth, treated):
    """Draw one noisy sub-region vote share, clipped to [0, 1].

    Exactly one Gaussian draw (std ``OUTCOME_NOISE``) is taken from the module
    ``rng`` per call, treated or not. Untreated cells centre on ``BASELINE_VOTE``.
    Treated cells centre on ``RICH_VOTE`` (wealth 3), ``POOR_VOTE`` (wealth 1)
    or ``INTER_VOTE`` (anything else).
    """
    eps = rng.normal(0, OUTCOME_NOISE)
    if treated:
        if sub_wealth == 3:
            centre = RICH_VOTE
        elif sub_wealth == 1:
            centre = POOR_VOTE
        else:
            centre = INTER_VOTE
    else:
        centre = BASELINE_VOTE
    return np.clip(centre + eps, 0, 1)

def simulate_outcome(intervention_grid):
    """Return a GRID_SIZE x GRID_SIZE matrix of region vote shares in percent (0-100).

    For every region, each sub-region vote is drawn with ``vote`` (one rng draw
    per cell, regions and cells both visited row-major). The votes are then
    averaged with a plain mean.
    NOTE(review): the data-generation cell aggregates with
    soft_aggregate_np(temperature=agg_temp) rather than a plain mean. Confirm the
    mean is intended for this ATE ground truth.
    """
    region_votes = np.zeros((GRID_SIZE, GRID_SIZE))
    for row in range(GRID_SIZE):
        r0 = row * SUBREGION_SIZE
        for col in range(GRID_SIZE):
            c0 = col * SUBREGION_SIZE
            block = wealth_hi[r0:r0 + SUBREGION_SIZE, c0:c0 + SUBREGION_SIZE]
            is_treated = bool(intervention_grid[row, col])
            cell_votes = []
            for r in range(SUBREGION_SIZE):
                cell_votes.append([vote(block[r, c], is_treated) for c in range(SUBREGION_SIZE)])
            region_votes[row, col] = np.mean(cell_votes) * 100
    return region_votes

# ---------- run simulations ----------
# Each iteration draws one all-control and one all-treated simulation from the shared
# `rng` (in that order), so the iteration order determines every sampled outcome.
base_interv = pd.read_csv(f"{OUTFOLDER}/interventions.csv",
                          header=None).values
global_diffs = []
all_treated_runs = []
all_control_runs = []

for idx in range(N_SIM_PER_ARM):
    # ---- control (all zeros) ----
    interv_off = np.zeros_like(base_interv)
    out_off = simulate_outcome(interv_off)
    pd.DataFrame(np.round(out_off, 2)).to_csv(
        f"{OUTFOLDER}/outcome_treatment_off_{idx:02d}.csv",
        header=False, index=False)
    all_control_runs.append(out_off)

    # ---- treated (all ones) ----
    interv_on = np.ones_like(base_interv)
    out_on = simulate_outcome(interv_on)
    pd.DataFrame(np.round(out_on, 2)).to_csv(
        f"{OUTFOLDER}/outcome_treatment_on_{idx:02d}.csv",
        header=False, index=False)
    all_treated_runs.append(out_on)

    # global ATE for this pair (difference of grid-wide mean outcomes, in percentage points)
    global_diffs.append(out_on.mean() - out_off.mean())

all_control_runs = np.stack(all_control_runs) # shape (N_SIM_PER_ARM, GRID_SIZE, GRID_SIZE)
all_treated_runs = np.stack(all_treated_runs)

# ---------- compute regional ATE ----------
# Per-region difference of the run-averaged treated and control outcomes.
ate_grid = all_treated_runs.mean(axis=0) - all_control_runs.mean(axis=0)
pd.DataFrame(np.round(ate_grid, 2)).to_csv(f"{OUTFOLDER}/ate.csv",
                                           header=False, index=False)

# ---------- PLOT 1: heat-map ----------
plt.figure(figsize=(6,6))
ax = sns.heatmap(ate_grid, annot=np.round(ate_grid,2), fmt=".2f",
                 cmap="RdBu_r", center=0, cbar_kws={'label': 'ATE (%)'})
ax.set_title("Average Treatment Effect per region")
ax.set_xticks([]); ax.set_yticks([])
plt.tight_layout()
plt.savefig(f"{OUTFOLDER}/ate_heatmap.jpg", dpi=300)
plt.savefig(f"{OUTFOLDER}/ate_heatmap.pdf")
plt.show()

# ---------- PLOT 1b: minimal heat-map ----------
# Colour range is fixed to [0, 9]; regional ATEs outside it saturate.
colors = sns.color_palette("deep")
cmap = sns.light_palette(colors[4], as_cmap=True)
fig, ax = plt.subplots(figsize=(5, 5))
sns.heatmap(ate_grid, ax=ax, cmap=cmap, cbar=False, square=True, linewidths=2, vmin=0.0, vmax=9.0)
ax.axis('off')
plt.tight_layout(pad=0)
fig.savefig(f"{OUTFOLDER}/ate_heatmap_minimal.jpg", dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(f"{OUTFOLDER}/ate_heatmap_minimal.pdf", bbox_inches='tight', pad_inches=0)
plt.show()

# ---------- PLOT 2: violin of global ATE ----------
plt.figure(figsize=(4,6))
sns.violinplot(y=global_diffs, color=sns.color_palette("deep")[3],
               inner=None, cut=0)
sns.stripplot(y=global_diffs, color=sns.color_palette("deep")[0],
              size=8, alpha=0.7)
plt.axhline(np.mean(global_diffs), ls="--", lw=1,
            label=f"Mean ATE = {np.mean(global_diffs):.2f}%")
plt.ylabel("Global ATE (%)")
plt.title("Overall Average Treatment Effect")
plt.legend(frameon=False)
plt.tight_layout()
plt.savefig(f"{OUTFOLDER}/ate_global_violin.jpg", dpi=300)
plt.savefig(f"{OUTFOLDER}/ate_global_violin.pdf")
plt.show()

print("✓ ATE analysis complete – results stored in", OUTFOLDER)

In [None]:
def soft_aggregate(values, temperature=1.0):
    """Pool ``values`` over their last axis with softmax weights.

    Each entry along dim -1 is weighted by ``softmax(values / temperature)``
    and the weighted sum is returned, so the last dimension is reduced away.
    Smaller temperatures concentrate the weight on the largest entry.

    Args:
        values: tensor whose last dimension is pooled.
        temperature: positive scalar or tensor dividing the logits.

    Returns:
        Tensor with the last dimension of ``values`` removed.

    Raises:
        ValueError: if ``temperature`` (any element, for a tensor) is <= 0.
    """
    if isinstance(temperature, torch.Tensor):
        non_positive = (temperature <= 0).any()
    else:
        non_positive = temperature <= 0
    if non_positive:
        raise ValueError("Temperature must be positive")

    scaled = values / temperature
    # Shift by the per-row maximum before the softmax for numerical stability.
    scaled = scaled - scaled.amax(dim=-1, keepdim=True)
    weights = scaled.softmax(dim=-1)
    return torch.sum(values * weights, dim=-1)

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import random

# ------------------ Configuration ------------------
OUTFOLDER = "data_exp5_mean"
N_EPOCHS = 1000
LR = 0.01
SEED = 42
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
LOGLOG_FLAG = True  # Whether to plot losses in log-log scale

# ------------------ Set seeds ------------------
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# ------------------ Load data ------------------
interv_np  = pd.read_csv(f"{OUTFOLDER}/interventions.csv",   header=None).values  # 0/1 per region
wealth_hi  = pd.read_csv(f"{OUTFOLDER}/wealth_high_res.csv", header=None).values  # wealth class per subregion
outcome_np = pd.read_csv(f"{OUTFOLDER}/outcome.csv",         header=None).values  # region outcomes %

# Convert to tensor
outcome_t = torch.tensor(outcome_np, dtype=torch.float32, device=DEVICE)

# ------------------ Grid geometry ------------------
# Derive the region/subregion layout from the loaded files instead of
# hard-coding 10×10 regions of 10×10 subregions.
n_rows, n_cols = interv_np.shape
sub_rows = wealth_hi.shape[0] // n_rows
sub_cols = wealth_hi.shape[1] // n_cols
if wealth_hi.shape != (n_rows * sub_rows, n_cols * sub_cols):
    raise ValueError(
        f"wealth grid {wealth_hi.shape} does not tile the "
        f"{n_rows}x{n_cols} intervention grid"
    )
n_sub = sub_rows * sub_cols

# ------------------ Known Constants ------------------
BASELINE_VOTE = 0.50
RICH_VOTE     = 0.80
INTER_VOTE    = 0.50
POOR_VOTE     = 0.40

# ------------------ Loop-invariant tensors (built once, not every epoch) ------------------
# Regroup the high-res wealth grid so the last axis lists each region's subregions:
# (R*S, C*S) -> (R, S, C, S) -> (R, C, S, S) -> (R, C, S*S).
wealth_hi_t = torch.tensor(wealth_hi, dtype=torch.float32, device=DEVICE)
wealth_hi_reshaped = (
    wealth_hi_t.reshape(n_rows, sub_rows, n_cols, sub_cols)
    .permute(0, 2, 1, 3)
    .reshape(n_rows, n_cols, n_sub)
)
treated_mask_t = torch.tensor(interv_np, dtype=torch.float32, device=DEVICE) == 1

# Per-subregion index into `params`: wealth 1 -> poor (0), 3 -> rich (2),
# anything else -> intermediate (1); subregions of treated regions are offset
# by 3 to select the μ1 half of `params`.
class_idx = torch.ones(wealth_hi_reshaped.shape, dtype=torch.long, device=DEVICE)
class_idx[wealth_hi_reshaped == 1] = 0
class_idx[wealth_hi_reshaped == 3] = 2
param_idx = class_idx + 3 * treated_mask_t.long().unsqueeze(-1)

# Ground-truth per-class causal effects (μ1 − μ0) in vote fractions: poor, inter, rich.
true_causal_effect = torch.tensor([
    (POOR_VOTE - BASELINE_VOTE),
    (INTER_VOTE - BASELINE_VOTE),
    (RICH_VOTE  - BASELINE_VOTE)
], device=DEVICE)

# ------------------ Initialize Parameters ------------------
# params = [μ0_poor, μ0_inter, μ0_rich, μ1_poor, μ1_inter, μ1_rich]
params = torch.nn.Parameter(torch.full((6,), 0.5, dtype=torch.float32, device=DEVICE))
log_tau = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=DEVICE))  # log temperature
opt = torch.optim.Adam([params, log_tau], lr=LR)

loss_hist = []
ce_loss_hist = []
temp_hist = []

# ------------------ Training Loop ------------------
for epoch in range(N_EPOCHS):
    opt.zero_grad()

    # Predicted vote share of every subregion, shape (R, C, S*S).
    predicted_sub_outcomes = params[param_idx]

    # Soft aggregation to region level (percent).
    temperature = torch.exp(log_tau)
    mean_pred = soft_aggregate(predicted_sub_outcomes, temperature=temperature) * 100

    # Region-level MSE loss
    loss = torch.mean((mean_pred - outcome_t) ** 2)
    loss.backward()
    opt.step()

    # Log losses and temperature (both evaluated before this step's update)
    loss_hist.append(loss.item())
    temp_hist.append(temperature.item())

    # Causal-effect MSE against the known ground truth (uses updated params)
    with torch.no_grad():
        est_causal_effect = params[3:] - params[:3]  # estimated μ1 - μ0
        ce_mse = torch.mean((est_causal_effect - true_causal_effect) ** 2).item()
        ce_loss_hist.append(ce_mse)

# ------------------ Save Loss Logs ------------------
loss_df = pd.DataFrame({
    'region_loss': loss_hist,
    'causal_effect_loss': ce_loss_hist,
    'temperature': temp_hist
})
loss_df.to_csv(f"{OUTFOLDER}/estimation_loss.csv", index_label="epoch")

# ------------------ Plot Loss Curves ------------------
fig, ax1 = plt.subplots(figsize=(10, 6))
colors = sns.color_palette("deep")
epochs = np.arange(1, len(loss_hist) + 1)

# Region Loss
ax1.set_xlabel('Epoch', fontsize=16)
ax1.set_ylabel('Training Loss', fontsize=16, color=colors[0])
if LOGLOG_FLAG:
    ax1.loglog(epochs, loss_hist, color=colors[0], lw=2.5)
else:
    ax1.plot(epochs, loss_hist, color=colors[0], lw=2)
ax1.tick_params(axis='y', labelcolor=colors[0])

# Causal Effect Loss
ax2 = ax1.twinx()
ax2.set_ylabel('Causal Effect MSE', fontsize=16, color=colors[3])
if LOGLOG_FLAG:
    ax2.loglog(epochs, ce_loss_hist, color=colors[3], ls='--', lw=2.5)
else:
    ax2.plot(epochs, ce_loss_hist, color=colors[3], ls='--', lw=2)
ax2.tick_params(axis='y', labelcolor=colors[3])

plt.tight_layout()
fig.savefig(f"{OUTFOLDER}/estimation_loss.jpg", dpi=300)
fig.savefig(f"{OUTFOLDER}/estimation_loss.pdf")
plt.show()

# ------------------ Save Parameters ------------------
param_names = ["mu_poor_0", "mu_inter_0", "mu_rich_0",
               "mu_poor_1", "mu_inter_1", "mu_rich_1", "temperature"]

# Use the temperature of the final log_tau. The in-loop `temperature` was computed
# before the last optimizer step, so it would lag the saved params by one update.
final_temperature = torch.exp(log_tau).item()
values = np.append(params.detach().cpu().numpy(), final_temperature)
results_df = pd.DataFrame({'parameter': param_names, 'value': values})
results_df.to_csv(f"{OUTFOLDER}/results.csv", index=False)


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

data_dir = 'data_exp5_mean'
loglog_flag = True  # Set to True for log-log plot, False for linear plot

# Per-epoch histories written by the training cell.
loss_data = pd.read_csv(os.path.join(data_dir, 'estimation_loss.csv'))
loss_hist = loss_data['region_loss'].values
ce_loss_hist = loss_data['causal_effect_loss'].values
epochs = np.arange(1, len(loss_hist) + 1)

fig, ax1 = plt.subplots(figsize=(8, 6))
colors = sns.color_palette("deep")
train_c, ce_c = colors[0], colors[3]

# Left axis: region-level training loss (solid line).
ax1.set_xlabel('Epoch', fontsize=20)
ax1.set_ylabel('Training Loss', fontsize=20, color=train_c)
if loglog_flag:
    draw_train, train_style = ax1.loglog, dict(alpha=0.9, lw=3)
else:
    draw_train, train_style = ax1.plot, dict(alpha=0.7, lw=2)
draw_train(epochs, loss_hist, color=train_c, **train_style)
ax1.tick_params(axis='both', which='both', length=0, labelsize=18, colors=train_c)
ax1.set_title('MSE Curves', fontsize=22)
ax1.spines['top'].set_visible(False)

# Right axis (shared x): causal-effect MSE (dashed line).
ax2 = ax1.twinx()
ax2.set_ylabel('MSE Causal Effect', fontsize=20, color=ce_c)
if loglog_flag:
    draw_ce, ce_style = ax2.loglog, dict(alpha=0.9, ls='--', lw=5)
else:
    draw_ce, ce_style = ax2.plot, dict(alpha=0.7, ls='--', lw=2)
draw_ce(epochs, ce_loss_hist, color=ce_c, **ce_style)
ax2.tick_params(axis='y', which='both', length=0, labelsize=18, colors=ce_c)
ax2.spines['top'].set_visible(False)

# Same file names as the training cell's plot, so those files are replaced.
plt.tight_layout()
fig.savefig(os.path.join(data_dir, 'estimation_loss.jpg'), dpi=300)
fig.savefig(os.path.join(data_dir, 'estimation_loss.pdf'))
plt.show()

In [None]:
import matplotlib as mpl
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

sns.set_theme(style="whitegrid")

# Set output folder
OUTFOLDER = "data_exp5_mean"

# Load data generated in previous cells
wealth_hi = pd.read_csv(f"{OUTFOLDER}/wealth_high_res.csv", header=None).values
subregion_noise = pd.read_csv(f"{OUTFOLDER}/subregion_noise_gt.csv", header=None).values  # True noise

# Load the fitted parameters from this experiment's results.csv instead of the
# in-memory `results_df`. That name is overwritten by the max-aggregation
# training cell and is undefined after a kernel restart.
results_df = pd.read_csv(f"{OUTFOLDER}/results.csv")
param_values = results_df.set_index("parameter")["value"]
params = param_values[["mu_poor_0", "mu_inter_0", "mu_rich_0",
                       "mu_poor_1", "mu_inter_1", "mu_rich_1"]].to_numpy()  # μ0 (3) then μ1 (3)
learned_temp = param_values["temperature"]  # Learned temperature


# Define constants from data generation
BASELINE_VOTE = 0.50
RICH_VOTE = 0.80
POOR_VOTE = 0.40
INTER_VOTE = 0.50

# ---------------- Compute True Causal Effect (Subregion Level) ----------------
# Both potential outcomes reuse the stored ground-truth noise; values in percent.
control_outcome_true = np.clip(BASELINE_VOTE + subregion_noise, 0, 1) * 100.0
treated_vote_level = np.where(wealth_hi == 3, RICH_VOTE,
                              np.where(wealth_hi == 1, POOR_VOTE, INTER_VOTE))
treated_outcome_true = np.clip(treated_vote_level + subregion_noise, 0, 1) * 100.0

# True causal effect: treated - control
causal_effect_true = treated_outcome_true - control_outcome_true

# ---------------- Compute Estimated Causal Effect (Subregion Level) ----------------
# Look up each sub-cell's learned μ by wealth class (1 poor, 2 intermediate,
# 3 rich); cells with any other class value get 0.
wealth_masks = [wealth_hi == 1, wealth_hi == 2, wealth_hi == 3]
estimated_control_outcome_base = np.select(wealth_masks, list(params[:3]), default=0.0)
estimated_treated_outcome_base = np.select(wealth_masks, list(params[3:]), default=0.0)

estimated_causal_effect = (estimated_treated_outcome_base - estimated_control_outcome_base) * 100.0  # Convert to percentage

# Shared colormap and colour range so both panels are directly comparable
colors = sns.color_palette("deep")
pink_cmap = sns.light_palette(colors[4], as_cmap=True)
vmin = np.min([causal_effect_true, estimated_causal_effect])
vmax = np.max([causal_effect_true, estimated_causal_effect])

# Create figure and axes with space for the colorbar
fig, axs = plt.subplots(1, 2, figsize=(12, 6), gridspec_kw={'width_ratios': [1, 1], 'wspace': 0.05})

# Plot true causal effect
sns.heatmap(causal_effect_true, ax=axs[0], cmap=pink_cmap, cbar=False, square=True,
            vmin=vmin, vmax=vmax, linewidths=0)
axs[0].set_title("True Causal Effect (Subregion)")
axs[0].set_xticks([])
axs[0].set_yticks([])

# Plot estimated causal effect
sns.heatmap(estimated_causal_effect, ax=axs[1], cmap=pink_cmap, cbar=False, square=True,
            vmin=vmin, vmax=vmax, linewidths=0)
axs[1].set_title("Estimated Causal Effect (Subregion)")
axs[1].set_xticks([])
axs[1].set_yticks([])

# Single shared vertical colorbar to the right of both plots
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
sm = mpl.cm.ScalarMappable(cmap=pink_cmap, norm=norm)
sm.set_array([])

cbar = fig.colorbar(sm, ax=axs, orientation='vertical', fraction=0.046, pad=0.04)
cbar.set_label("Causal Effect (%)")

# Save and show
fig.savefig(f'{OUTFOLDER}/subregion_causal_effects_heatmaps.jpg', dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(f'{OUTFOLDER}/subregion_causal_effects_heatmaps.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

print("✓ Subregion causal effect plots generated.")


## Exp. 5: Max aggregation

In [None]:
import os
import numpy as np
import pandas as pd

# ---------------- Hyper‑parameters ----------------
OUTFOLDER        = "data_exp5_max"
GRID_SIZE        = 10          # number of regions per side
SUBREGION_SIZE   = 10          # sub‑regions per region side
N_SUB            = SUBREGION_SIZE * SUBREGION_SIZE  # sub-regions per region
MIN_INTER        = 20          # minimum intermediate sub-regions in an ordinary block
SPATIAL_VARIANCE = 3           # wiggle room for rich/poor counts
OUTCOME_NOISE    = 0.02        # std‑dev of additive noise
RICH_VOTE        = 0.80
POOR_VOTE        = 0.40
INTER_VOTE       = 0.50
BASELINE_VOTE    = 0.50
N_SIM_EXTRA      = 10
RANDOM_SEED      = 42
agg_temp         = 0.1         # temperature passed to soft_aggregate_np when pooling votes
rng = np.random.default_rng(RANDOM_SEED)

os.makedirs(OUTFOLDER, exist_ok=True)

# ---------------- 1) GRID_SIZE×GRID_SIZE intervention grid ----------------
interventions = rng.integers(0, 2, size=(GRID_SIZE, GRID_SIZE))
interventions[0, 0] = 1  # force interventions in the two top‑left cells
interventions[0, 1] = 1
pd.DataFrame(interventions).to_csv(f"{OUTFOLDER}/interventions.csv",
                                   header=False, index=False)

# ---------------- 2) high-res wealth grid (1 poor, 2 intermediate, 3 rich) ----------------
wealth_hi = np.zeros((GRID_SIZE*SUBREGION_SIZE,
                      GRID_SIZE*SUBREGION_SIZE), dtype=int)

for i in range(GRID_SIZE):
    for j in range(GRID_SIZE):
        # Class counts for this block. They are drawn for every block, including
        # the two overridden below, so the rng stream does not depend on the overrides.
        n_inter = rng.integers(min(MIN_INTER, N_SUB), N_SUB + 1)
        remaining = N_SUB - n_inter
        wiggle = min(SPATIAL_VARIANCE, remaining // 2)
        n_rich = remaining // 2 + rng.integers(-wiggle, wiggle + 1)
        n_rich = np.clip(n_rich, 0, remaining)
        n_poor = remaining - n_rich

        # special overrides for the two upper‑left blocks
        if (i, j) == (0, 0):
            block_vals = np.full(N_SUB, 2)                      # all intermediate
        elif (i, j) == (0, 1):
            half = N_SUB // 2
            block_vals = np.concatenate([np.ones(half, int),            # half poor
                                         np.full(N_SUB - half, 3, int)])  # rest rich
            rng.shuffle(block_vals)
        else:
            block_vals = np.concatenate([np.ones(n_poor, int),
                                         np.full(n_inter, 2, int),
                                         np.full(n_rich, 3, int)])
            rng.shuffle(block_vals)

        # place block into wealth_hi
        r0, c0 = i*SUBREGION_SIZE, j*SUBREGION_SIZE
        wealth_hi[r0:r0+SUBREGION_SIZE, c0:c0+SUBREGION_SIZE] = \
            block_vals.reshape(SUBREGION_SIZE, SUBREGION_SIZE)

pd.DataFrame(wealth_hi).to_csv(f"{OUTFOLDER}/wealth_high_res.csv",
                               header=False, index=False)

# ---------------- 3) coarse (per-region mean) wealth grid ---------
wealth_lo = wealth_hi.reshape(GRID_SIZE, SUBREGION_SIZE,
                              GRID_SIZE, SUBREGION_SIZE).mean(axis=(1, 3))
pd.DataFrame(np.round(wealth_lo, 2)).to_csv(f"{OUTFOLDER}/wealth_low_res.csv",
                                            header=False, index=False)

# ---------------- 4) voting outcome ----------------
subregion_noise = np.zeros_like(wealth_hi, dtype=float)  # filled by vote(): one noise draw per sub-cell

def vote(sub_wealth, treated, i, j):
    """Sample one sub-region's vote share as a fraction clipped to [0, 1].

    Every call takes one Gaussian draw (std ``OUTCOME_NOISE``) from the module
    ``rng``, whether or not the cell is treated, and records it in
    ``subregion_noise[i, j]`` so counterfactuals can reuse it. Untreated cells
    centre on ``BASELINE_VOTE``. Treated cells centre on the wealth-class level:
    3 -> ``RICH_VOTE``, 1 -> ``POOR_VOTE``, anything else -> ``INTER_VOTE``.
    """
    eps = rng.normal(0, OUTCOME_NOISE)
    subregion_noise[i, j] = eps
    if treated:
        centre = RICH_VOTE if sub_wealth == 3 else POOR_VOTE if sub_wealth == 1 else INTER_VOTE
    else:
        centre = BASELINE_VOTE
    return np.clip(centre + eps, 0, 1)

# Regions are visited row-major and, inside each region, sub-cells row-major.
# This keeps the rng draws made by `vote` in a fixed, reproducible order.
outcome = np.zeros_like(interventions, dtype=float)

for i in range(GRID_SIZE):
    r0 = i * SUBREGION_SIZE
    for j in range(GRID_SIZE):
        c0 = j * SUBREGION_SIZE
        block_wealth = wealth_hi[r0:r0 + SUBREGION_SIZE, c0:c0 + SUBREGION_SIZE]
        block_treated = bool(interventions[i, j])
        votes = [
            [vote(block_wealth[r, c], block_treated, r0 + r, c0 + c)
             for c in range(SUBREGION_SIZE)]
            for r in range(SUBREGION_SIZE)
        ]
        # Soft (temperature agg_temp) pooling instead of a plain mean; percent scale.
        outcome[i, j] = soft_aggregate_np(np.array(votes), temperature=agg_temp) * 100.0

# Sub-cell noise exactly as drawn by `vote`
pd.DataFrame(np.round(subregion_noise, 4)).to_csv(f"{OUTFOLDER}/subregion_noise_gt.csv",
                                                  header=False, index=False)

# Region-level noise as the per-region SUM of sub-cell noise.
# NOTE(review): the outcome above uses soft_aggregate_np, not a sum, so this
# file does not describe the noise inside `outcome`; confirm it is intended.
region_noise = subregion_noise.reshape(GRID_SIZE, SUBREGION_SIZE,
                                       GRID_SIZE, SUBREGION_SIZE).sum(axis=(1, 3))
pd.DataFrame(np.round(region_noise, 4)).to_csv(f"{OUTFOLDER}/region_noise_gt.csv",
                                               header=False, index=False)

# Region outcomes (percent)
pd.DataFrame(np.round(outcome, 2)).to_csv(f"{OUTFOLDER}/outcome.csv",
                                          header=False, index=False)


# ---------------- 5) counterfactual outcome (GT) ----------------
# Counterfactual assignment: flip treatment in the top half of the grid
# (rows 0 .. GRID_SIZE//2 - 1); the bottom half keeps its factual assignment.
interventions_cf = interventions.copy()
interventions_cf[:GRID_SIZE//2, :] = 1 - interventions_cf[:GRID_SIZE//2, :]

pd.DataFrame(interventions_cf).to_csv(f"{OUTFOLDER}/interventions_cf_gt.csv",
                                      header=False, index=False)

# Treated vote level of every sub-cell, looked up once from its wealth class
# (3 rich, 1 poor, anything else intermediate).
treated_level_hi = np.where(wealth_hi == 3, RICH_VOTE,
                            np.where(wealth_hi == 1, POOR_VOTE, INTER_VOTE))

# Re-evaluate every region under the counterfactual assignment. The factual
# noise in `subregion_noise` is reused, so no new rng draws are made.
outcome_cf = np.zeros_like(interventions_cf, dtype=float)
for i in range(GRID_SIZE):
    rows = slice(i * SUBREGION_SIZE, (i + 1) * SUBREGION_SIZE)
    for j in range(GRID_SIZE):
        cols = slice(j * SUBREGION_SIZE, (j + 1) * SUBREGION_SIZE)
        level = treated_level_hi[rows, cols] if interventions_cf[i, j] else BASELINE_VOTE
        votes = np.clip(level + subregion_noise[rows, cols], 0, 1)
        outcome_cf[i, j] = soft_aggregate_np(votes, temperature=agg_temp) * 100.0


pd.DataFrame(np.round(outcome_cf, 2)).to_csv(f"{OUTFOLDER}/outcome_cf_gt.csv",
                                             header=False, index=False)

print("✓ Synthetic data and counterfactuals written to:", OUTFOLDER)

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Set output folder
OUTFOLDER = "data_exp5_max"

# Ground-truth wealth classes and per-sub-cell noise from the generator
wealth_hi = pd.read_csv(f"{OUTFOLDER}/wealth_high_res.csv", header=None).values
subregion_noise = pd.read_csv(f"{OUTFOLDER}/subregion_noise_gt.csv", header=None).values

# Define constants from data generation
BASELINE_VOTE = 0.50
RICH_VOTE = 0.80
POOR_VOTE = 0.40
INTER_VOTE = 0.50

# Both potential outcomes per sub-cell (percent), reusing the stored noise.
control_outcome = np.clip(BASELINE_VOTE + subregion_noise, 0, 1) * 100.0
treated_level = np.where(wealth_hi == 3, RICH_VOTE,
                         np.where(wealth_hi == 1, POOR_VOTE, INTER_VOTE))
treated_outcome = np.clip(treated_level + subregion_noise, 0, 1) * 100.0

# Causal effect: treated - control
causal_effect = treated_outcome - control_outcome

# Save to CSVs
pd.DataFrame(np.round(control_outcome, 2)).to_csv(f"{OUTFOLDER}/control_sub.csv", header=False, index=False)
pd.DataFrame(np.round(treated_outcome, 2)).to_csv(f"{OUTFOLDER}/treated_sub.csv", header=False, index=False)
pd.DataFrame(np.round(causal_effect, 2)).to_csv(f"{OUTFOLDER}/effect_sub.csv", header=False, index=False)

# Aliases used for plotting. Values stay in percent (no rescaling is applied),
# and each panel uses its own automatic colour range.
control_norm = control_outcome
treated_norm = treated_outcome
effect_norm = causal_effect

# Color maps from seaborn deep palette
colors = sns.color_palette("deep")
blue_cmap = sns.light_palette(colors[0], as_cmap=True)  # outcomes
pink_cmap = sns.light_palette(colors[4], as_cmap=True)  # effect

# Horizontal gap between the subplots
space_between = 0.2

# Control | treated | effect, side by side, without ticks
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
panels = (
    (control_norm, blue_cmap),
    (treated_norm, blue_cmap),
    (effect_norm, pink_cmap),
)
for panel_ax, (grid, panel_cmap) in zip(axs, panels):
    sns.heatmap(grid, ax=panel_ax, cmap=panel_cmap, cbar=False, square=True, linewidths=0.0)
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])

fig.subplots_adjust(wspace=space_between)

# Save and show
fig.savefig(f'{OUTFOLDER}/sub_heatmaps.jpg', dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(f'{OUTFOLDER}/sub_heatmaps.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Set output folder
OUTFOLDER = "data_exp5_max"

# Inputs: region treatment flags, region outcomes (as fractions), wealth classes
interv_data = pd.read_csv(f"{OUTFOLDER}/interventions.csv", header=None).values.astype(int)
outcome_data = pd.read_csv(f"{OUTFOLDER}/outcome.csv", header=None).values / 100.0
context_data = pd.read_csv(f"{OUTFOLDER}/wealth_high_res.csv", header=None).values
context_data_norm = (context_data - 1) / 2.0  # maps wealth classes 1,2,3 onto 0, 0.5, 1

# Single-hue colour maps from the seaborn "deep" palette
colors = sns.color_palette("deep")
red_cmap = sns.light_palette(colors[3], as_cmap=True)
blue_cmap = sns.light_palette(colors[0], as_cmap=True)
green_cmap = sns.light_palette(colors[2], as_cmap=True)

# Horizontal gap between subplots, as a fraction of the average axis width
space_between = 0.05

# Interventions | outcomes | wealth context, each with its own colour limits
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
panels = (
    (interv_data, red_cmap, 2, 0.0, 1.0),
    (outcome_data, blue_cmap, 2, 0.4, 0.6),
    (context_data_norm, green_cmap, 0, 0.0, 1.0),
)
for panel_ax, (grid, panel_cmap, line_w, lo, hi) in zip(axs, panels):
    sns.heatmap(grid, ax=panel_ax, cmap=panel_cmap, cbar=False, square=True,
                linewidths=line_w, vmin=lo, vmax=hi)
    panel_ax.set_xticks([])
    panel_ax.set_yticks([])

fig.subplots_adjust(wspace=space_between)

# Save and show
fig.savefig(f'{OUTFOLDER}/combined_heatmaps.jpg', dpi=300, bbox_inches='tight', pad_inches=0)
fig.savefig(f'{OUTFOLDER}/combined_heatmaps.pdf', bbox_inches='tight', pad_inches=0)
plt.show()

In [None]:
def soft_aggregate(values, temperature=1.0):
    """Pool ``values`` over their last axis with softmax weights.

    Each entry along dim -1 is weighted by ``softmax(values / temperature)``
    and the weighted sum is returned, so the last dimension is reduced away.
    Smaller temperatures concentrate the weight on the largest entry.

    Args:
        values: tensor whose last dimension is pooled.
        temperature: positive scalar or tensor dividing the logits.

    Returns:
        Tensor with the last dimension of ``values`` removed.

    Raises:
        ValueError: if ``temperature`` (any element, for a tensor) is <= 0.
    """
    if isinstance(temperature, torch.Tensor):
        non_positive = (temperature <= 0).any()
    else:
        non_positive = temperature <= 0
    if non_positive:
        raise ValueError("Temperature must be positive")

    scaled = values / temperature
    # Shift by the per-row maximum before the softmax for numerical stability.
    scaled = scaled - scaled.amax(dim=-1, keepdim=True)
    weights = scaled.softmax(dim=-1)
    return torch.sum(values * weights, dim=-1)

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import random

# ------------------ Configuration ------------------
OUTFOLDER = "data_exp5_max"
N_EPOCHS = 1000
LR = 0.01
SEED = 42
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
LOGLOG_FLAG = True  # Whether to plot losses in log-log scale

# ------------------ Set seeds ------------------
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# ------------------ Load data ------------------
interv_np  = pd.read_csv(f"{OUTFOLDER}/interventions.csv",   header=None).values  # 0/1 per region
wealth_hi  = pd.read_csv(f"{OUTFOLDER}/wealth_high_res.csv", header=None).values  # wealth class per subregion
outcome_np = pd.read_csv(f"{OUTFOLDER}/outcome.csv",         header=None).values  # region outcomes %

# Convert to tensor
outcome_t = torch.tensor(outcome_np, dtype=torch.float32, device=DEVICE)

# ------------------ Grid geometry ------------------
# Derive the region/subregion layout from the loaded files instead of
# hard-coding 10×10 regions of 10×10 subregions.
n_rows, n_cols = interv_np.shape
sub_rows = wealth_hi.shape[0] // n_rows
sub_cols = wealth_hi.shape[1] // n_cols
if wealth_hi.shape != (n_rows * sub_rows, n_cols * sub_cols):
    raise ValueError(
        f"wealth grid {wealth_hi.shape} does not tile the "
        f"{n_rows}x{n_cols} intervention grid"
    )
n_sub = sub_rows * sub_cols

# ------------------ Known Constants ------------------
BASELINE_VOTE = 0.50
RICH_VOTE     = 0.80
INTER_VOTE    = 0.50
POOR_VOTE     = 0.40

# ------------------ Loop-invariant tensors (built once, not every epoch) ------------------
# Regroup the high-res wealth grid so the last axis lists each region's subregions:
# (R*S, C*S) -> (R, S, C, S) -> (R, C, S, S) -> (R, C, S*S).
wealth_hi_t = torch.tensor(wealth_hi, dtype=torch.float32, device=DEVICE)
wealth_hi_reshaped = (
    wealth_hi_t.reshape(n_rows, sub_rows, n_cols, sub_cols)
    .permute(0, 2, 1, 3)
    .reshape(n_rows, n_cols, n_sub)
)
treated_mask_t = torch.tensor(interv_np, dtype=torch.float32, device=DEVICE) == 1

# Per-subregion index into `params`: wealth 1 -> poor (0), 3 -> rich (2),
# anything else -> intermediate (1); subregions of treated regions are offset
# by 3 to select the μ1 half of `params`.
class_idx = torch.ones(wealth_hi_reshaped.shape, dtype=torch.long, device=DEVICE)
class_idx[wealth_hi_reshaped == 1] = 0
class_idx[wealth_hi_reshaped == 3] = 2
param_idx = class_idx + 3 * treated_mask_t.long().unsqueeze(-1)

# Ground-truth per-class causal effects (μ1 − μ0) in vote fractions: poor, inter, rich.
true_causal_effect = torch.tensor([
    (POOR_VOTE - BASELINE_VOTE),
    (INTER_VOTE - BASELINE_VOTE),
    (RICH_VOTE  - BASELINE_VOTE)
], device=DEVICE)

# ------------------ Initialize Parameters ------------------
# params = [μ0_poor, μ0_inter, μ0_rich, μ1_poor, μ1_inter, μ1_rich]
params = torch.nn.Parameter(torch.full((6,), 0.5, dtype=torch.float32, device=DEVICE))
log_tau = torch.nn.Parameter(torch.tensor(0.0, dtype=torch.float32, device=DEVICE))  # log temperature
opt = torch.optim.Adam([params, log_tau], lr=LR)

loss_hist = []
ce_loss_hist = []
temp_hist = []

# ------------------ Training Loop ------------------
for epoch in range(N_EPOCHS):
    opt.zero_grad()

    # Predicted vote share of every subregion, shape (R, C, S*S).
    predicted_sub_outcomes = params[param_idx]

    # Soft aggregation to region level (percent).
    temperature = torch.exp(log_tau)
    mean_pred = soft_aggregate(predicted_sub_outcomes, temperature=temperature) * 100

    # Region-level MSE loss
    loss = torch.mean((mean_pred - outcome_t) ** 2)
    loss.backward()
    opt.step()

    # Log losses and temperature (both evaluated before this step's update)
    loss_hist.append(loss.item())
    temp_hist.append(temperature.item())

    # Causal-effect MSE against the known ground truth (uses updated params)
    with torch.no_grad():
        est_causal_effect = params[3:] - params[:3]  # estimated μ1 - μ0
        ce_mse = torch.mean((est_causal_effect - true_causal_effect) ** 2).item()
        ce_loss_hist.append(ce_mse)

# ------------------ Save Loss Logs ------------------
loss_df = pd.DataFrame({
    'region_loss': loss_hist,
    'causal_effect_loss': ce_loss_hist,
    'temperature': temp_hist
})
loss_df.to_csv(f"{OUTFOLDER}/estimation_loss.csv", index_label="epoch")

# ------------------ Plot Loss Curves ------------------
fig, ax1 = plt.subplots(figsize=(10, 6))
colors = sns.color_palette("deep")
epochs = np.arange(1, len(loss_hist) + 1)

# Region Loss
ax1.set_xlabel('Epoch', fontsize=16)
ax1.set_ylabel('Training Loss', fontsize=16, color=colors[0])
if LOGLOG_FLAG:
    ax1.loglog(epochs, loss_hist, color=colors[0], lw=2.5)
else:
    ax1.plot(epochs, loss_hist, color=colors[0], lw=2)
ax1.tick_params(axis='y', labelcolor=colors[0])

# Causal Effect Loss
ax2 = ax1.twinx()
ax2.set_ylabel('Causal Effect MSE', fontsize=16, color=colors[3])
if LOGLOG_FLAG:
    ax2.loglog(epochs, ce_loss_hist, color=colors[3], ls='--', lw=2.5)
else:
    ax2.plot(epochs, ce_loss_hist, color=colors[3], ls='--', lw=2)
ax2.tick_params(axis='y', labelcolor=colors[3])


plt.tight_layout()
fig.savefig(f"{OUTFOLDER}/estimation_loss.jpg", dpi=300)
fig.savefig(f"{OUTFOLDER}/estimation_loss.pdf")
plt.show()

# ------------------ Save Parameters ------------------
param_names = ["mu_poor_0", "mu_inter_0", "mu_rich_0",
               "mu_poor_1", "mu_inter_1", "mu_rich_1", "temperature"]

# Use the temperature of the final log_tau. The in-loop `temperature` was computed
# before the last optimizer step, so it would lag the saved params by one update.
final_temperature = torch.exp(log_tau).item()
values = np.append(params.detach().cpu().numpy(), final_temperature)
results_df = pd.DataFrame({'parameter': param_names, 'value': values})
results_df.to_csv(f"{OUTFOLDER}/results.csv", index=False)


In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

data_dir = 'data_exp5_max'
loglog_flag = True  # True: log-log axes, False: linear axes

# Per-epoch histories written by the training cell.
loss_data = pd.read_csv(os.path.join(data_dir, 'estimation_loss.csv'))
loss_hist = loss_data['region_loss'].values
ce_loss_hist = loss_data['causal_effect_loss'].values
epochs = np.arange(1, len(loss_hist) + 1)

fig, ax1 = plt.subplots(figsize=(8, 6))
colors = sns.color_palette("deep")
train_c, ce_c = colors[0], colors[3]

# Left axis: region-level training loss (solid line).
ax1.set_xlabel('Epoch', fontsize=20)
ax1.set_ylabel('Training Loss', fontsize=20, color=train_c)
if loglog_flag:
    draw_train, train_style = ax1.loglog, dict(alpha=0.9, lw=3)
else:
    draw_train, train_style = ax1.plot, dict(alpha=0.7, lw=2)
draw_train(epochs, loss_hist, color=train_c, **train_style)
ax1.tick_params(axis='both', which='both', length=0, labelsize=18, colors=train_c)
ax1.set_title('MSE Curves', fontsize=22)
ax1.spines['top'].set_visible(False)

# Right axis (shared x): causal-effect MSE (dashed line).
ax2 = ax1.twinx()
ax2.set_ylabel('MSE Causal Effect', fontsize=20, color=ce_c)
if loglog_flag:
    draw_ce, ce_style = ax2.loglog, dict(alpha=0.9, ls='--', lw=5)
else:
    draw_ce, ce_style = ax2.plot, dict(alpha=0.7, ls='--', lw=2)
draw_ce(epochs, ce_loss_hist, color=ce_c, **ce_style)
ax2.tick_params(axis='y', which='both', length=0, labelsize=18, colors=ce_c)
ax2.spines['top'].set_visible(False)

# Same file names as the training cell's plot, so those files are replaced.
plt.tight_layout()
fig.savefig(os.path.join(data_dir, 'estimation_loss.jpg'), dpi=300)
fig.savefig(os.path.join(data_dir, 'estimation_loss.pdf'))
plt.show()