# CausalTorch: Image Generation with Causal Constraints

This notebook demonstrates how to use CausalTorch to generate images that adhere to causal rules. We'll implement the classic example "if it rains, the ground is wet": the model is built so that any generated image with rain also shows wet ground.

## 1. Setup and Installation

In [None]:
# Install CausalTorch if not already installed
%pip install -e ..

# Import necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl

# Seed both random number generators used in this notebook so runs are
# repeatable: PyTorch (latent sampling via torch.randn / torch.randn_like)
# and NumPy (synthetic image generation via np.random).
torch.manual_seed(42)
np.random.seed(42)

## 2. Generate Synthetic Training Data

Let's create a simple dataset of synthetic images with rain and wet/dry ground.

In [None]:
def generate_synthetic_image(has_rain=True, img_size=28):
    """Generate a synthetic grayscale scene with or without rain.

    Pixel values follow the usual grayscale convention (0 = black,
    1 = white). The bottom 8 rows are the ground, and *darker ground means
    wetter ground*: wet ground is drawn at 0.1 and dry ground at 0.6. This is
    the convention assumed by the training loss (rainy images are pushed
    towards a ground mean below 0.3) and by the evaluation code (wetness is
    computed as ``1 - mean(ground)``).

    Args:
        has_rain: If True, draw bright vertical rain streaks and wet (dark)
            ground; otherwise draw no streaks and dry (lighter) ground.
        img_size: Height and width of the square image in pixels.

    Returns:
        A ``(img_size, img_size)`` float array with values clipped to [0, 1].
    """
    img = np.zeros((img_size, img_size))

    if has_rain:
        # Add rain: one bright vertical streak in every third column
        rain_intensity = 0.8
        for i in range(0, img_size, 3):
            # Random raindrop length
            drop_length = np.random.randint(5, 15)
            # Keep the upper bound positive so images shorter than the drop
            # do not make randint raise; the slice below simply truncates.
            max_start = max(1, img_size - drop_length)
            start_y = np.random.randint(0, max_start)
            img[start_y:start_y+drop_length, i] = rain_intensity

        # Wet ground: bottom 8 rows are dark (this overwrites any streak
        # pixels that reached the ground rows)
        img[-8:, :] = 0.1
    else:
        # Dry ground: bottom 8 rows are lighter
        img[-8:, :] = 0.6

    # Add noise for realism, then clamp back into the valid pixel range
    img += np.random.normal(0, 0.05, (img_size, img_size))
    img = np.clip(img, 0, 1)

    return img

# Generate and display a few examples (first row: rain, second row: no rain).
# vmin/vmax pin the colormap to the full [0, 1] pixel range. Without them,
# imshow rescales every panel to its own min/max, so ground brightness could
# not be compared between the rainy and the dry panels.
plt.figure(figsize=(10, 5))
for i in range(6):
    plt.subplot(2, 3, i+1)
    has_rain = i < 3  # First row: rain, Second row: no rain
    img = generate_synthetic_image(has_rain=has_rain)
    plt.imshow(img, cmap='gray', vmin=0, vmax=1)
    plt.title(f"{'Rain' if has_rain else 'No Rain'}")
    plt.axis('off')
plt.tight_layout()
plt.show()

## 3. Define the Causal Layer

This is the core of our approach. The `CausalSymbolicLayer` enforces the relationship between rain and wet ground in the latent space.

In [None]:
class CausalSymbolicLayer(nn.Module):
    """Overwrite the ground-wetness latent as a function of the rain latent.

    The layer has no trainable parameters. It reads latent dimension 0
    (rain intensity) and replaces latent dimension 1 (ground wetness) with
    ``sigmoid((rain - 0.5) * 10) * strength``, a smooth step centred on a
    rain intensity of 0.5.

    Args:
        causal_rules: Optional rule dictionary. Only
            ``causal_rules["rain"]["strength"]`` is read by ``forward``.
            Defaults to a single rain -> ground_wetness rule of strength 0.9.
    """

    def __init__(self, causal_rules=None):
        super().__init__()
        # Default rule: rain → wet ground
        self.causal_rules = causal_rules or {
            "rain": {"effect": "ground_wetness", "strength": 0.9}
        }

    def forward(self, z):
        """Apply causal constraints to latent vector z.

        Args:
            z: Tensor of shape [batch_size, latent_dim] with latent_dim >= 2,
                where z[:, 0] is rain intensity and z[:, 1] is ground
                wetness.

        Returns:
            A new tensor of the same shape as ``z`` whose column 1 has been
            replaced by the wetness implied by column 0. The input tensor is
            left unmodified.
        """
        rain_intensity = z[:, 0]

        # Simple causal rule: if rain_intensity > 0.5, ground must be wet.
        # The sigmoid makes the transition around the 0.5 threshold smooth.
        strength = self.causal_rules["rain"]["strength"]
        ground_wetness = torch.sigmoid((rain_intensity - 0.5) * 10) * strength

        # Write into a copy instead of the caller's tensor: an in-place write
        # would silently change the caller's latent, and raises a
        # RuntimeError when z is a leaf tensor that requires grad.
        z_causal = z.clone()
        z_causal[:, 1] = ground_wetness

        return z_causal

## 4. Define the CNSG-Net Model (Causal VAE)

We'll create a simple VAE with our causal layer integrated.

In [None]:
class CausalVAE(pl.LightningModule):
    """Fully-connected VAE whose latent passes through a causal layer.

    The encoder maps a flattened ``img_size x img_size`` image to the mean
    and log-variance of a ``latent_dim``-dimensional Gaussian. Before
    decoding, ``CausalSymbolicLayer`` overwrites latent dimension 1 (ground
    wetness) as a function of latent dimension 0 (rain intensity).

    Args:
        latent_dim: Size of the latent vector; must be at least 2 because the
            causal layer reads dimension 0 and writes dimension 1.
        img_size: Height and width of the square single-channel images.
        kl_weight: Weight of the KL term in the training loss.
        causal_weight: Weight of the causal (wet-ground) penalty in the
            training loss.

    Raises:
        ValueError: If ``latent_dim`` is smaller than 2.
    """

    def __init__(self, latent_dim=3, img_size=28, kl_weight=0.1, causal_weight=5.0):
        super().__init__()
        if latent_dim < 2:
            raise ValueError("latent_dim must be at least 2 (rain, ground wetness)")
        self.latent_dim = latent_dim
        self.img_size = img_size
        self.kl_weight = kl_weight
        self.causal_weight = causal_weight

        # Encoder (Image → Latent)
        self.encoder = nn.Sequential(
            nn.Linear(img_size * img_size, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )

        # Mean and log-variance heads of the approximate posterior
        self.fc_mu = nn.Linear(128, latent_dim)
        self.fc_var = nn.Linear(128, latent_dim)

        # Causal layer to enforce rain → wet ground
        self.causal_layer = CausalSymbolicLayer()

        # Decoder (Latent → Image)
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, img_size * img_size),
            nn.Sigmoid()  # Output pixel values 0-1
        )

    def encode(self, x):
        """Return ``(mu, log_var)`` of the posterior for a batch of images."""
        x = x.view(x.size(0), -1)  # Flatten
        h = self.encoder(x)
        mu = self.fc_mu(h)
        log_var = self.fc_var(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``."""
        std = torch.exp(log_var / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Apply the causal layer to ``z`` and decode to [B, 1, H, W] images."""
        # Apply causal constraints
        z = self.causal_layer(z)
        h = self.decoder(z)
        return h.view(h.size(0), 1, self.img_size, self.img_size)

    def forward(self, x):
        """Encode, sample and decode; returns ``(reconstruction, mu, log_var)``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_reconstructed = self.decode(z)
        return x_reconstructed, mu, log_var

    def training_step(self, batch, batch_idx):
        """Compute the reconstruction + KL + causal loss for one batch.

        ``batch`` is ``(images, has_rain)`` where ``has_rain`` is 1 for rainy
        samples. The causal term penalises rainy reconstructions whose bottom
        8 rows have a mean brightness above 0.3 (darker ground = wetter).

        NOTE(review): ``has_rain`` is only used to select samples for the
        causal penalty; nothing in this loss ties latent dimension 0 to the
        rain label — confirm whether that supervision is intended elsewhere.
        """
        x, has_rain = batch
        x_reconstructed, mu, log_var = self(x)

        # Reconstruction loss
        recon_loss = F.mse_loss(x_reconstructed, x)

        # KL divergence to the standard normal prior (averaged over elements)
        kl_loss = -0.5 * torch.mean(1 + log_var - mu.pow(2) - log_var.exp())

        # Mean brightness of the ground rows of each reconstruction
        ground_region = x_reconstructed[:, :, -8:, :].mean(dim=(1, 2, 3))

        # For images with rain, ground should be wet (darker).
        # For images without rain, no constraint.
        rain_mask = has_rain == 1
        if rain_mask.any():
            # Higher ground_region value means lower wetness (brighter pixels)
            causal_loss = F.relu(ground_region[rain_mask] - 0.3).mean()
        else:
            # Keep the term a tensor so the logged value has a consistent type
            causal_loss = ground_region.new_zeros(())

        # Total loss
        loss = recon_loss + self.kl_weight * kl_loss + self.causal_weight * causal_loss

        self.log_dict({
            'train_loss': loss,
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
            'causal_loss': causal_loss
        })

        return loss

    def configure_optimizers(self):
        """Use Adam with a learning rate of 1e-3 over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def generate(self, rain_intensity=1.0, num_samples=1):
        """Generate images with specified rain intensity.

        Args:
            rain_intensity: Value written to latent dimension 0 of every
                sample before the causal layer and decoder are applied.
            num_samples: Number of images to generate.

        Returns:
            Tensor of shape [num_samples, 1, img_size, img_size].
        """
        with torch.no_grad():
            # Create the latent on the same device/dtype as the weights so
            # generation also works when the module is not on the CPU.
            weight = self.fc_mu.weight
            z = torch.randn(
                num_samples, self.latent_dim, device=weight.device, dtype=weight.dtype
            )
            z[:, 0] = rain_intensity  # Set first dimension to rain intensity

            # Apply causal constraints and decode
            return self.decode(z)

# Create the model: a 3-dimensional latent (rain intensity, ground wetness,
# plus one unconstrained dimension) for 28x28 images.
model = CausalVAE(latent_dim=3, img_size=28)

## 5. Create Dataset and DataLoader

In [None]:
class SyntheticRainDataset(torch.utils.data.Dataset):
    """In-memory dataset of synthetic rain / no-rain images.

    All samples are generated eagerly in the constructor. The first half of
    the indices are rainy images (label 1) and the rest are dry images
    (label 0). Each item is ``(image, label)`` where the image is a float32
    tensor with a leading channel axis of size 1.
    """

    def __init__(self, num_samples=100, img_size=28):
        self.num_samples = num_samples
        self.img_size = img_size

        # Decide up front which indices are rainy: half rain, half no rain
        rain_flags = [idx < num_samples / 2 for idx in range(num_samples)]

        # Render one image per flag, in index order, and add the channel axis
        self.data = [
            torch.tensor(
                generate_synthetic_image(has_rain=flag, img_size=img_size),
                dtype=torch.float32,
            ).unsqueeze(0)
            for flag in rain_flags
        ]
        # 1 for rain, 0 for no rain
        self.labels = [int(flag) for flag in rain_flags]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = self.data[idx]
        label = self.labels[idx]
        return image, label

# Create the dataset (100 samples, half rainy) and a shuffled dataloader
# that yields (images, has_rain) batches of up to 16 samples.
dataset = SyntheticRainDataset(num_samples=100)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

## 6. Train the Model

In [None]:
# Create trainer
trainer = pl.Trainer(max_epochs=30)

# NOTE: despite the message printed below, no step limit is configured here;
# fit() trains for the full max_epochs=30 over the dataloader. To shorten a
# demo run, pass a limit such as max_steps or limit_train_batches to
# pl.Trainer above.
print("Training for a few iterations...")
trainer.fit(model, dataloader)

## 7. Generate Images with Causal Constraints

In [None]:
def display_generations(rain_intensities):
    """Generate and plot one image per requested rain intensity.

    Uses the module-level ``model``. Each panel's title shows the requested
    rain intensity and the measured ground wetness, computed as one minus the
    mean brightness of the bottom 8 rows (darker = wetter).

    Args:
        rain_intensities: Sequence of rain intensity values; one panel is
            drawn per value, left to right.
    """
    plt.figure(figsize=(12, 4))
    for i, rain_intensity in enumerate(rain_intensities):
        # Generate image with specified rain intensity
        image = model.generate(rain_intensity=rain_intensity)

        # Get ground wetness from the image (average of bottom rows)
        ground_wetness = 1.0 - image[0, 0, -8:, :].mean().item()  # Invert: darker = wetter

        # Display. vmin/vmax pin the colormap to [0, 1]; otherwise imshow
        # rescales each panel to its own min/max and ground brightness could
        # not be compared across rain intensities.
        plt.subplot(1, len(rain_intensities), i+1)
        plt.imshow(image[0, 0].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
        plt.title(f"Rain: {rain_intensity:.1f}\nGround Wetness: {ground_wetness:.2f}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Generate images with varying rain intensities, from no rain (0.0) to full
# rain (1.0), crossing the causal layer's 0.5 threshold in the middle.
rain_intensities = [0.0, 0.3, 0.5, 0.7, 1.0]
display_generations(rain_intensities)

## 8. Evaluate Causal Fidelity Score (CFS)

Let's compute how well our model adheres to the causal rule "rain → wet ground".

In [None]:
def calculate_cfs():
    """Calculate Causal Fidelity Score.

    One image is generated with the module-level ``model`` for each of 50
    evenly spaced rain intensities in [0, 1]. The score is the fraction of
    images whose measured ground wetness (one minus the mean brightness of
    the bottom 8 rows) is consistent with the rain intensity.
    """
    def _rule_holds(rain):
        # Darker ground means wetter ground, hence the inversion
        image = model.generate(rain_intensity=rain)
        bottom_brightness = image[0, 0, -8:, :].mean().item()
        ground_wetness = 1.0 - bottom_brightness

        # Rainy side of the threshold: the ground must be wet
        if rain > 0.5:
            return ground_wetness > 0.5
        # Dry side of the threshold: the ground must not be too wet
        return ground_wetness <= 0.7

    levels = np.linspace(0, 1, 50)
    satisfied = sum(1 for level in levels if _rule_holds(level))
    return satisfied / len(levels)

# Score the trained model and report the result
cfs = calculate_cfs()
print(f"Causal Fidelity Score (CFS): {cfs:.2f} (higher is better)")

# Single-bar chart of the score with a dashed reference line at 0.5
fig, ax = plt.subplots(figsize=(6, 3))
ax.bar(['CFS'], [cfs], color='blue')
ax.set_ylim(0, 1)
ax.set_ylabel('Score')
ax.set_title('Causal Fidelity Score')
ax.axhline(y=0.5, color='r', linestyle='--', label='Minimum Acceptable')
ax.legend()
fig.tight_layout()
plt.show()

## 9. Latent Space Traversal

Let's explore the latent space to see how our causal constraints affect generation.

In [None]:
def latent_traversal():
    """Plot a grid of decoded images over two latent dimensions.

    Columns sweep latent dimension 0 (rain intensity) over [0, 1]; rows sweep
    latent dimension 2 over [-1, 1]. Dimension 2 is not constrained by the
    causal layer; the figure labels it "time of day" purely for illustration.
    Dimension 1 (ground wetness) is set by the causal layer, and its value is
    shown in each panel title. Uses the module-level ``model``, whose
    ``latent_dim`` must be at least 3.
    """
    dim1_values = np.linspace(0, 1, 5)  # rain intensity
    dim2_values = np.linspace(-1, 1, 4)  # free latent dimension ("time of day")
    n_rows, n_cols = len(dim2_values), len(dim1_values)

    # Build latents on the model's device so decoding works off-CPU too
    device = next(model.parameters()).device

    plt.figure(figsize=(12, 10))

    for i, dim2 in enumerate(dim2_values):
        for j, dim1 in enumerate(dim1_values):
            # Size the latent from the model instead of hard-coding 3
            z = torch.zeros(1, model.latent_dim, device=device)
            z[0, 0] = dim1  # rain intensity
            z[0, 2] = dim2  # free latent dimension

            with torch.no_grad():
                # Apply the causal layer once to read the enforced wetness;
                # decode() re-applies it, which leaves the value unchanged
                # because it depends only on the rain dimension.
                z_causal = model.causal_layer(z)
                image = model.decode(z_causal)

            # Get ground wetness from the causal z
            ground_wetness = z_causal[0, 1].item()

            # Display with a fixed [0, 1] gray scale so panels are comparable
            plt.subplot(n_rows, n_cols, i * n_cols + j + 1)
            plt.imshow(image[0, 0].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
            plt.title(f"Rain: {dim1:.1f}, GW: {ground_wetness:.2f}")
            plt.axis('off')

    plt.suptitle("Latent Space Traversal\nRows: Time of Day, Columns: Rain Intensity", fontsize=16)
    plt.tight_layout(pad=1.5)
    plt.show()

# Show the latent space traversal grid using the trained model
latent_traversal()

## 10. Counterfactual Analysis

Let's perform a simple counterfactual intervention in the latent space.

In [None]:
def counterfactual_analysis():
    """Plot a reconstruction next to its rain-flipped counterfactual.

    Encodes the dataset sample at index 10, samples a latent ``z`` from the
    posterior, and decodes it. The counterfactual keeps the same ``z`` except
    that latent dimension 0 (rain intensity) is set to 1.0 if it was below
    0.5 and to 0.0 otherwise; the causal layer inside ``decode`` then
    recomputes ground wetness. Uses the module-level ``model`` and
    ``dataset``.
    """
    # Take a fixed data point (index 10); the label is not needed here
    img, _ = dataset[10]

    with torch.no_grad():
        # Encode it on the model's device
        device = next(model.parameters()).device
        mu, log_var = model.encode(img.unsqueeze(0).to(device))
        z = model.reparameterize(mu, log_var)

        # Original image
        original_img = model.decode(z)

        # Counterfactual: What if it was raining/not raining?
        z_cf = z.clone()
        z_cf[0, 0] = 1.0 if z[0, 0] < 0.5 else 0.0  # Flip rain intensity

        # The causal layer will automatically update ground wetness
        cf_img = model.decode(z_cf)

    # Display both panels on the same fixed [0, 1] gray scale; with imshow's
    # default per-image scaling the two grounds could not be compared.
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.imshow(original_img[0, 0].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
    plt.title(f"Original (Rain: {z[0, 0].item():.2f})")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(cf_img[0, 0].cpu().numpy(), cmap='gray', vmin=0, vmax=1)
    plt.title(f"Counterfactual (Rain: {z_cf[0, 0].item():.2f})")
    plt.axis('off')

    plt.suptitle("Counterfactual Analysis: What if it was (not) raining?", fontsize=14)
    plt.tight_layout()
    plt.show()

# Perform counterfactual analysis on one training sample with the trained model
counterfactual_analysis()

## 11. Conclusion

In this notebook, we demonstrated how CausalTorch can be used to generate images that adhere to causal constraints. Key takeaways:

1. We enforced the causal rule "rain → wet ground" using a specialized layer in the latent space
2. The model was trained with minimal data (just 100 synthetic examples)
3. We validated the model's adherence to causal rules using the Causal Fidelity Score (CFS)
4. We demonstrated counterfactual reasoning by manipulating the causal variables

This approach ensures that generated images respect real-world physics and logical consistency. The same principles can be applied to more complex domains like medical imaging or scientific visualizations.