# CausalTorch: Video Generation with Temporal Causality

This notebook demonstrates how to use CausalTorch to generate video sequences with causal temporal constraints. We'll implement a battle scene example where temporal causal rules like "arrow hit → soldier fall" are enforced across frames.

## 1. Setup and Installation

In [None]:
# Install CausalTorch if not already installed
%pip install -e ..

# Import necessary libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
import imageio
from IPython.display import HTML

# Import CausalTorch components
from causaltorch.models import CNSG_VideoGenerator
from causaltorch.rules import CausalRuleSet
from causaltorch.metrics import temporal_consistency

## 2. Define Temporal Causal Rules

Temporal causal rules define how certain events trigger effects over time. For example, an arrow hit at frame t should cause a soldier to fall at frame t+3.

In [None]:
# Temporal causal rules for the battle scene, keyed by the triggering event.
# Every rule names the effect it causes, the strength of the link, the delay
# in frames before the effect starts (temporal_offset) and how many frames
# the effect lasts (duration).
battle_rules = {
    # Dust appears in the same frame as the hoof contact and lasts 3 frames.
    "hoof_contact": dict(effect="dust", strength=0.8, temporal_offset=0, duration=3),
    # The soldier starts falling 3 frames after the hit; the fall lasts 10 frames.
    "arrow_hit": dict(effect="soldier_fall", strength=0.9, temporal_offset=3, duration=10),
    # Smoke follows the explosion one frame later and lasts 15 frames.
    "explosion": dict(effect="smoke_cloud", strength=0.95, temporal_offset=1, duration=15),
}

# Wrap the rules in a rule set so the causal graph can be drawn.
rule_set = CausalRuleSet(battle_rules)
rule_set.visualize(figsize=(8, 6))

## 3. Create a Video Generator Model

We'll create a CNSG_VideoGenerator model that enforces temporal causal constraints during video generation.

In [None]:
# Seed the RNGs *before* constructing the model. Seeding afterwards (as this
# cell used to do) cannot influence anything the constructor draws at random,
# so the weights would differ between runs.
# NOTE(review): assumes CNSG_VideoGenerator initialises its parameters from
# torch's global RNG, as standard torch layers do - confirm.
torch.manual_seed(42)
np.random.seed(42)

# Create the video generator model
frame_size = (64, 64)  # Height, width
latent_dim = 16  # Latent space dimension
model = CNSG_VideoGenerator(frame_size=frame_size, latent_dim=latent_dim, causal_rules=battle_rules)

## 4. Define Metadata for Causal Events

We need to specify when causal events occur in our video sequence, for example the frames at which hooves hit the ground or arrows strike.

In [None]:
def create_battle_metadata(num_frames=48, *, hoof_period=6, arrow_hits=(12, 28), explosion_frame=20):
    """Create per-frame event signals for a battle scene.

    Each signal is a list of ``num_frames`` floats that is 1.0 on frames where
    the event fires and 0.0 elsewhere.

    Args:
        num_frames: Length of every signal.
        hoof_period: A hoof contact fires on every frame index divisible by
            this value (the default of 6 models the gallop rhythm).
        arrow_hits: Frame indices at which an arrow hits.
        explosion_frame: Frame index of the single explosion, or ``None`` for
            a sequence without any explosion.

    Returns:
        A dict with the keys ``"hoof_contact"``, ``"arrow_hit"`` and
        ``"explosion"``, in that order.

    Raises:
        ValueError: If ``hoof_period`` is smaller than 1.
    """
    if hoof_period < 1:
        raise ValueError("hoof_period must be a positive integer")

    # A set gives constant-time membership tests in the comprehension below.
    arrow_frames = frozenset(arrow_hits)
    frames = range(num_frames)

    return {
        "hoof_contact": [1.0 if i % hoof_period == 0 else 0.0 for i in frames],
        "arrow_hit": [1.0 if i in arrow_frames else 0.0 for i in frames],
        # `i == None` is never true, so explosion_frame=None yields all zeros.
        "explosion": [1.0 if i == explosion_frame else 0.0 for i in frames],
    }

# Build the event signals for a 48-frame sequence
battle_metadata = create_battle_metadata(num_frames=48)

# Plot each event signal in its own row, stacked vertically
plt.figure(figsize=(12, 4))
n_rows = len(battle_metadata)
for row, (event_name, signal) in enumerate(battle_metadata.items(), start=1):
    plt.subplot(n_rows, 1, row)
    plt.plot(signal, 'o-')
    plt.ylabel(event_name)
    plt.ylim(-0.1, 1.1)
# Issued after the loop, so the x-label lands on the last (bottom) subplot
plt.xlabel('Frame')
plt.tight_layout()
plt.show()

## 5. Generate a Battle Scene Video

Now we'll generate a video sequence with our causal constraints.

In [None]:
def generate_battle_video(model, metadata, num_frames=48):
    """Generate a battle video with causal constraints.

    Args:
        model: Callable generator exposing ``frame_size`` (indexable as
            height, width) and ``latent_dim``. It is invoked as
            ``model(frame, latent, seq_length=..., metadata=...)``.
        metadata: Event signals, passed to the model unchanged.
        num_frames: Passed to the model as ``seq_length``.

    Returns:
        Whatever the model returns for the generated sequence.
    """
    batch_size = 1

    # Start from a uniform-noise RGB frame with values in [0.25, 0.75)
    # (could be a more realistic frame in a real implementation).
    initial_frame = torch.rand(batch_size, 3, model.frame_size[0], model.frame_size[1]) * 0.5 + 0.25

    # The latent is zero except for (up to) its first five entries, which get
    # Gaussian noise. Capping at latent_dim keeps models with fewer than five
    # latent dimensions from failing on a shape mismatch.
    initial_latent = torch.zeros(batch_size, model.latent_dim)
    n_random = min(5, model.latent_dim)
    initial_latent[:, :n_random] = torch.randn(batch_size, n_random) * 0.5

    # Inference only: no gradients are needed.
    with torch.no_grad():
        video = model(initial_frame, initial_latent, seq_length=num_frames, metadata=metadata)

    return video

# Generate battle video
# generate_battle_video draws a new random initial frame and latent on every
# call, so the result depends on the current state of torch's RNG.
num_frames = 48
battle_video = generate_battle_video(model, battle_metadata, num_frames=num_frames)

# Get video shape
print(f"Generated video shape: {battle_video.shape}")

## 6. Visualize the Generated Video

In [None]:
def display_video_frames(video, num_frames=8):
    """Display evenly spaced frames from the video in a two-row grid.

    Args:
        video: Tensor that, once singleton dimensions are squeezed out, has
            four dimensions. The permute moves dimension 1 last for imshow.
            NOTE(review): presumably (frames, channels, height, width) -
            confirm against the generator's output.
        num_frames: How many frames to show. Odd counts are supported; the
            last grid slot is then simply left empty.
    """
    # Convert to numpy for display
    video_np = video.squeeze().permute(0, 2, 3, 1).cpu().numpy()

    # Select evenly spaced frames
    indices = np.linspace(0, len(video_np) - 1, num_frames, dtype=int)

    # Round the column count up so that 2 * n_cols >= num_frames. Plain floor
    # division gave too few slots for odd counts and zero columns for 1.
    n_cols = (num_frames + 1) // 2

    # Display frames
    plt.figure(figsize=(16, 4))
    for slot, idx in enumerate(indices, start=1):
        plt.subplot(2, n_cols, slot)
        plt.imshow(video_np[idx])
        plt.title(f"Frame {idx}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Display selected frames
display_video_frames(battle_video, num_frames=8)

In [None]:
def create_animation(video):
    """Return an HTML/JavaScript player that steps through the video frames."""
    clip = video.squeeze().permute(0, 2, 3, 1).cpu().numpy()

    # One axis-free figure holding a single image artist that is updated in place
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.set_axis_off()
    canvas = ax.imshow(clip[0])

    def _show_frame(frame_idx):
        # With blit=True the callback must return the artists it changed
        canvas.set_array(clip[frame_idx])
        return [canvas]

    player = animation.FuncAnimation(
        fig,
        _show_frame,
        frames=len(clip),
        interval=100,
        blit=True,
    )
    # Closes the current figure; presumably so the notebook shows only the
    # player and not an extra static plot.
    plt.close()
    return HTML(player.to_jshtml())

# Build the player and let the notebook render it as the cell output
create_animation(battle_video)

## 7. Save the Generated Video

In [None]:
def save_video(video, filename="battle_scene.mp4", fps=12):
    """Save the generated video to a file.

    Args:
        video: Tensor that squeezes to four dimensions; dimension 1 is moved
            last before writing. Pixel values are interpreted on a 0-1 scale.
        filename: Output path handed to ``imageio.mimsave``.
        fps: Frames per second handed to ``imageio.mimsave``.
    """
    video_np = video.squeeze().permute(0, 2, 3, 1).cpu().numpy()

    # Convert to uint8 (0-255). Clip first: casting a float outside 0-255 to
    # uint8 overflows instead of saturating, which would turn slightly
    # out-of-range pixels into wrong colours.
    video_np = (np.clip(video_np, 0.0, 1.0) * 255).astype(np.uint8)

    # Save as MP4
    imageio.mimsave(filename, video_np, fps=fps)
    print(f"Video saved to {filename}")

# Save the video
save_video(battle_video)

## 8. Measure Temporal Consistency and Causal Fidelity

In [None]:
def measure_causal_events(video, metadata):
    """Print temporal consistency and a simple dust proxy for the video.

    Args:
        video: Tensor that squeezes to four dimensions; dimension 0 is used
            as the frame axis after the permute.
        metadata: Dict with the per-frame signals ``"hoof_contact"``,
            ``"arrow_hit"`` and ``"explosion"``. An event counts as active on
            frames where its signal exceeds 0.5. Signals may be shorter than
            the video, and the explosion signal may be all zeros.

    Raises:
        KeyError: If one of the three signals is missing from ``metadata``.
    """
    video_np = video.squeeze().permute(0, 2, 3, 1).cpu().numpy()
    num_frames = video_np.shape[0]

    def _active_frames(event):
        # Slicing caps the signal at the video length; enumerate avoids
        # indexing past the end of a signal that is shorter than the video.
        return [i for i, value in enumerate(metadata[event][:num_frames]) if value > 0.5]

    # Get key frames for each causal event
    hoof_frames = _active_frames("hoof_contact")
    arrow_frames = _active_frames("arrow_hit")
    explosion_frames = _active_frames("explosion")

    # Dust proxy: mean pixel value of the bottom 20% of the frame that
    # directly follows each hoof contact (contacts on the last frame have no
    # following frame and are skipped).
    ground_start = int(0.8 * video_np.shape[1])
    dust_intensity = [
        np.mean(video_np[frame + 1, ground_start:, :, :])
        for frame in hoof_frames
        if frame + 1 < num_frames
    ]
    # np.mean([]) would give nan too, but with a RuntimeWarning.
    avg_dust = np.mean(dust_intensity) if dust_intensity else float("nan")

    # Calculate temporal consistency
    tc_score = temporal_consistency(video)

    # Print results
    print(f"Temporal Consistency Score: {tc_score:.4f}")
    print(f"Average dust intensity after hoof contact: {avg_dust:.4f}")
    print(f"Causal events: Hoof contacts at frames {hoof_frames}")
    print(f"              Arrow hits at frames {arrow_frames}")
    if explosion_frames:
        print(f"              Explosion at frame {explosion_frames[0]}")
    else:
        # E.g. counterfactual metadata in which the explosion was removed
        print("              No explosion in this sequence")

# Measure causal effects
measure_causal_events(battle_video, battle_metadata)

## 9. Counterfactual Intervention: "What if there was no explosion?"

In [None]:
# Create alternative metadata with no explosion
alternative_metadata = create_battle_metadata(num_frames=num_frames)
alternative_metadata["explosion"] = [0.0] * num_frames  # Remove explosion

# generate_battle_video draws a random initial frame and latent on every call.
# Re-seed before each run so both videos start from identical inputs and the
# explosion signal is the only difference between them.
# NOTE(review): assumes any randomness inside the model's forward pass also
# comes from torch's global RNG - confirm.
counterfactual_seed = 42
torch.manual_seed(counterfactual_seed)
factual_video = generate_battle_video(model, battle_metadata, num_frames=num_frames)
torch.manual_seed(counterfactual_seed)
counterfactual_video = generate_battle_video(model, alternative_metadata, num_frames=num_frames)

# Locate the explosion from the metadata instead of repeating a magic number
explosion_frame = next(i for i, value in enumerate(battle_metadata["explosion"]) if value > 0.5)
window = 2  # Show frames before and after

# Frames around where the explosion happens, clamped to the valid range
shown_frames = [
    frame
    for frame in range(explosion_frame - window, explosion_frame + window + 1)
    if 0 <= frame < num_frames
]
n_cols = len(shown_frames)

plt.figure(figsize=(15, 6))
# Top row: with explosion. Bottom row: explosion removed.
comparison_rows = (("Original", factual_video), ("Counterfactual", counterfactual_video))
for row, (label, clip) in enumerate(comparison_rows):
    clip_np = clip.squeeze().permute(0, 2, 3, 1).cpu().numpy()
    for col, frame_idx in enumerate(shown_frames):
        plt.subplot(2, n_cols, row * n_cols + col + 1)
        plt.imshow(clip_np[frame_idx])
        plt.title(f"{label}: Frame {frame_idx}")
        plt.axis('off')

plt.suptitle("Counterfactual: What if there was no explosion?")
plt.tight_layout()
plt.show()

## 10. Save and Load the Model

In [None]:
# Single source of truth for the checkpoint location
checkpoint_path = "battle_video_generator.pt"

# Save the model (parameters only, via its state dict)
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved to {checkpoint_path}")

# Load the model (for demonstration): rebuild the same architecture, then
# restore the saved parameters into it.
loaded_model = CNSG_VideoGenerator(frame_size=frame_size, latent_dim=latent_dim, causal_rules=battle_rules)
# map_location="cpu" lets a checkpoint written from a GPU load on a CPU-only
# machine; load_state_dict then copies the values into the model's parameters.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
loaded_model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
print("Model loaded successfully")

## 11. Conclusion

In this notebook, we demonstrated how CausalTorch can be used to generate video sequences with temporal causal constraints. Key takeaways:

1. We defined temporal causal rules with appropriate offsets and durations.
2. We created a video generator model that enforces these rules during generation.
3. We specified when causal events occur using metadata.
4. We generated a battle scene video with causally consistent effects.
5. We measured the temporal consistency and inspected a simple brightness proxy for the dust effect after hoof contacts.
6. We performed a counterfactual intervention to see how the video changes when an event is removed.

This approach enables more realistic and logically consistent video generation, especially for scenarios where temporal causality is important, like simulations, game development, and film production.