# NSCA Physics Prior Evaluation: Predicting Stability from Initial State

## The Actual Physion Task

**Key insight**: The model must predict what WILL happen from the INITIAL configuration.
It does NOT see the outcome - that's what makes this a physics reasoning task.

**Previous bug**: Showing the full video (including the outcome) made the task trivially solvable.

**This version**: The model sees ONLY the initial frame and must predict: "Will the object land on the table, or miss it?"

---

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Dict, List

# Report the runtime environment and pick the compute device.
print("PyTorch version:", torch.__version__)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device: {}".format(device))

## 1. Dataset: Initial State Only (No Outcome Visible)

The model sees:
- An object suspended in the air
- A support surface (table) somewhere below

The model must predict:
- Will the object land ON the table (stable) or MISS it (unstable)?

This requires understanding that objects fall straight down due to gravity, so the outcome depends only on whether the object's horizontal center lies over the table.

In [None]:
class PhysionDataset:
    """
    Generate INITIAL FRAME for stability prediction.
    Model must predict outcome WITHOUT seeing it happen.

    Physics rule: Object falls straight down.
    Stable = object's center column lies over the table's columns
    (the drawn table covers columns ``table_x`` .. ``table_x + table_w - 1``).

    Half of the samples are requested as stable and half as unstable; the
    stored label is always recomputed from the final geometry.

    Note: the constructor reseeds NumPy's and PyTorch's *global* RNGs with
    ``seed`` so that a dataset is reproducible from its seed.
    """

    def __init__(self, n_samples: int, img_size: int = 64, seed: int = 42, difficulty: str = 'hard'):
        np.random.seed(seed)
        torch.manual_seed(seed)

        self.n_samples = n_samples
        self.img_size = img_size
        self.difficulty = difficulty

        # images: (n_samples, 3, img_size, img_size); labels: (n_samples,) of 0.0/1.0
        self.images, self.labels, self.metadata = self._generate()

        # Verify balance
        balance = self.labels.mean().item()
        print(f"Dataset: {n_samples} samples, {balance:.1%} stable (target: ~50%)")

    def _generate(self):
        """Generate every sample from a shuffled, exactly 50/50 list of targets."""
        n_stable = self.n_samples // 2
        n_unstable = self.n_samples - n_stable
        targets = [1.0] * n_stable + [0.0] * n_unstable
        np.random.shuffle(targets)

        images, labels, metadata = [], [], []
        for target in targets:
            img, label, meta = self._generate_sample(force_stable=(target == 1.0))
            images.append(img)
            labels.append(label)
            metadata.append(meta)

        return torch.stack(images), torch.tensor(labels).float(), metadata

    def _generate_sample(self, force_stable: bool):
        """Generate one initial configuration.

        Returns ``(frame, label, meta)``: a (3, S, S) image, a 1.0/0.0 stability
        label, and a dict with the geometry the label was derived from.
        """
        S = self.img_size

        # Object properties
        obj_w = np.random.randint(8, 14)
        obj_h = np.random.randint(6, 12)
        obj_y = np.random.randint(5, 20)  # High up (will fall)
        obj_color = torch.rand(3) * 0.4 + 0.5

        # Table properties
        table_w = np.random.randint(15, 30)
        table_x = np.random.randint(5, S - table_w - 5)
        table_y = S - np.random.randint(12, 20)  # Near bottom
        table_color = torch.tensor([0.55, 0.35, 0.2])

        # Object X position determines stability
        table_center = table_x + table_w // 2

        if force_stable:
            # Object center must be above table
            margin = max(2, table_w // 2 - obj_w // 2 - 2)
            obj_center = table_center + np.random.randint(-margin, margin + 1)
        else:
            # Object center must NOT be above table. Half-open ranges of candidate
            # centres that also keep the whole object inside the frame:
            left_lo, left_hi = obj_w // 2 + 2, table_x - 2
            right_lo, right_hi = table_x + table_w + 2, S - obj_w // 2 - 2
            left_ok, right_ok = left_hi > left_lo, right_hi > right_lo
            go_left = np.random.random() < 0.5
            if left_ok != right_ok:
                # Only one side has room. Clamping the empty side's range instead
                # can put the centre over the table and silently turn a "forced
                # unstable" sample into a stable one (breaking the 50/50 balance).
                go_left = left_ok
            if go_left:
                obj_center = np.random.randint(left_lo, max(left_lo + 1, left_hi))
            else:
                obj_center = np.random.randint(min(right_lo, right_hi - 1), right_hi)

        obj_x = int(np.clip(obj_center - obj_w // 2, 0, S - obj_w))

        # Ground truth. Half-open interval: the drawn table occupies
        # columns [table_x, table_x + table_w).
        actual_center = obj_x + obj_w // 2
        is_stable = bool(table_x <= actual_center < table_x + table_w)

        # Draw frame
        frame = torch.zeros(3, S, S)

        # Sky background: constant red/green, blue fading from 0.5 (top row)
        # toward 0.3 (bottom row).
        frame[0:2] = 0.1
        sky_blue = torch.tensor([0.3 + 0.2 * (1 - row / S) for row in range(S)])
        frame[2] = sky_blue.view(S, 1)  # broadcast each row's value across columns

        # Draw table
        ty1, ty2 = table_y, min(table_y + 5, S)
        tx1, tx2 = table_x, min(table_x + table_w, S)
        frame[:, ty1:ty2, tx1:tx2] = table_color.view(3, 1, 1)

        # Draw object
        oy1, oy2 = obj_y, min(obj_y + obj_h, S)
        ox1, ox2 = obj_x, min(obj_x + obj_w, S)
        frame[:, oy1:oy2, ox1:ox2] = obj_color.view(3, 1, 1)

        # Add noise (makes pure memorization harder)
        if self.difficulty == 'hard':
            frame = frame + torch.randn_like(frame) * 0.03
            frame = frame.clamp(0, 1)

        meta = {
            'obj_center': actual_center, 'obj_x': obj_x, 'obj_w': obj_w,
            'table_x': table_x, 'table_w': table_w,
            'is_stable': is_stable
        }

        return frame, 1.0 if is_stable else 0.0, meta

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

In [None]:
# Build the training pool and a separately seeded test set.
print("Creating datasets...")
train_full, test_data = (
    PhysionDataset(n_samples=1000, seed=42),
    PhysionDataset(n_samples=300, seed=9999),
)

In [None]:
# Visualize samples: first five stable frames (top row), first five unstable (bottom row)
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

for row, (want_stable, caption) in enumerate(((True, "STABLE"), (False, "UNSTABLE"))):
    chosen = [k for k, meta in enumerate(train_full.metadata) if bool(meta['is_stable']) == want_stable][:5]
    for col, sample in enumerate(chosen):
        panel = axes[row, col]
        panel.imshow(train_full.images[sample].permute(1, 2, 0).numpy())
        panel.set_title(caption)
        panel.axis('off')

plt.suptitle("Task: Predict if object will land on table (without seeing it fall)", fontsize=12)
plt.tight_layout()
plt.show()

for line in (
    "\nThe model must learn: 'Objects fall straight down'",
    "If object center is above table → STABLE",
    "If object center is NOT above table → UNSTABLE",
):
    print(line)

## 2. Models

### Baseline: Pure CNN (learns everything from scratch)
### NSCA: CNN + Physics Prior ("objects fall straight down")

In [None]:
class BaselineCNN(nn.Module):
    """Plain CNN classifier with no built-in physics knowledge.

    Three 5x5 stride-2 convolutions (3 -> 32 -> 64 -> 128 channels, ReLU, with
    BatchNorm after the first two) are average-pooled to a 2x2 grid and passed
    through a two-layer MLP. ``forward`` returns sigmoid probabilities of
    shape (B, 1).
    """

    def __init__(self):
        super().__init__()
        layers = []
        in_ch = 3
        for out_ch, with_norm in ((32, True), (64, True), (128, False)):
            layers += [nn.Conv2d(in_ch, out_ch, 5, stride=2, padding=2), nn.ReLU()]
            if with_norm:
                layers.append(nn.BatchNorm2d(out_ch))
            in_ch = out_ch
        layers.append(nn.AdaptiveAvgPool2d((2, 2)))
        self.conv = nn.Sequential(*layers)

        hidden = 64
        self.fc = nn.Sequential(
            nn.Linear(in_ch * 2 * 2, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    def forward(self, x):
        features = torch.flatten(self.conv(x), start_dim=1)
        logits = self.fc(features)
        return logits.sigmoid()

In [None]:
class GravityPrior(nn.Module):
    """
    Hand-coded physics prior: objects fall straight down.

    For each image it
    1. finds the object as the region of the top half brighter than that
       half's mean + 0.1, and takes its brightness-weighted mean column;
    2. finds the table as every column with a pixel brighter than 0.3 in the
       bottom third;
    3. scores whether the object's column lies within the table's span.

    Returns a (B, 1) tensor of fixed scores: 0.2 when no table is found, 0.15
    when the object misses the table, otherwise 0.6-0.95 (higher when the
    object is more centred). It has no learnable parameters.
    """

    @staticmethod
    def _score(table_cols, center):
        """Score one sample from its (W,) table-column indicator and object column."""
        if table_cols.sum() < 3:
            return 0.2  # no (or negligible) table detected

        hit = torch.where(table_cols > 0.5)[0]
        left = hit.min().item()
        right = hit.max().item()
        if not (left <= center <= right):
            return 0.15  # object lands beside the table

        mid = (left + right) / 2
        half_span = (right - left) / 2 + 1e-6
        return 0.6 + 0.35 * (1.0 - abs(center - mid) / half_span)

    def forward(self, x):
        batch, _, height, width = x.shape
        intensity = x.mean(dim=1)  # grayscale, (B, H, W)

        # Object: bright blob in the top half -> brightness-weighted mean column.
        top = intensity[:, :height // 2, :]
        bright = top > top.mean(dim=(1, 2), keepdim=True) + 0.1
        weights = bright.float() * top
        cols = torch.arange(width, device=x.device).float().view(1, 1, width)
        centers = (weights * cols).sum(dim=(1, 2)) / (weights.sum(dim=(1, 2)) + 1e-6)

        # Table: columns containing any pixel brighter than 0.3 in the bottom third.
        bottom = intensity[:, height * 2 // 3:, :]
        table_cols = (bottom > 0.3).any(dim=1).float()  # (B, W)

        scores = [self._score(table_cols[i], centers[i].item()) for i in range(batch)]
        return torch.tensor(scores, device=x.device).view(batch, 1)

In [None]:
class NSCAModel(nn.Module):
    """
    NSCA: Neural network + Physics Prior with learnable blending.

    Output probability = w * prior + (1 - w) * learned, where ``prior`` comes
    from the fixed GravityPrior and ``learned`` from a trainable BaselineCNN.
    The blend weight w starts at ``initial_prior_weight`` and is trained with
    the CNN. A softplus keeps it above ``min_weight``, and ``forward`` caps it
    at ``max_weight``. The network can therefore lean less on physics but
    never drop it entirely.

    Args:
        initial_prior_weight: starting blend weight. Values at or below
            ``min_weight`` start just above the floor.
        min_weight: lower bound on the prior weight (default 0.2).
        max_weight: cap applied to the weight in ``forward`` (default 0.8).
    """

    def __init__(self, initial_prior_weight: float = 0.5, *, min_weight: float = 0.2, max_weight: float = 0.8):
        super().__init__()
        self.cnn = BaselineCNN()
        self.physics_prior = GravityPrior()

        # Learnable blend weight with minimum floor
        self.min_weight = min_weight
        self.max_weight = max_weight
        # Inverse softplus, so that prior_weight starts at initial_prior_weight.
        # softplus(.) > 0, so a start at or below the floor is clamped to just
        # above it; without the clamp log(expm1(<= 0)) would be NaN.
        init_val = max(initial_prior_weight - min_weight, 1e-6)
        raw = float(np.log(np.expm1(init_val)))
        # Explicit float32: built from a NumPy float64 scalar, this would
        # otherwise be the only float64 parameter in the model.
        self._raw_weight = nn.Parameter(torch.tensor(raw, dtype=torch.float32))

    @property
    def prior_weight(self):
        """Uncapped blend weight: min_weight + softplus(raw), always > min_weight."""
        return self.min_weight + F.softplus(self._raw_weight)

    def forward(self, x):
        """Return ``(blended, info)``.

        ``blended`` is the weighted mix of prior and CNN probabilities. ``info``
        holds the batch means of the prior and learned probabilities, plus the
        capped blend weight, as Python floats.
        """
        learned = self.cnn(x)

        with torch.no_grad():  # Prior is not trained
            prior = self.physics_prior(x)

        w = self.prior_weight.clamp(max=self.max_weight)
        blended = w * prior + (1 - w) * learned

        return blended, {'prior': prior.mean().item(), 'learned': learned.mean().item(), 'weight': w.item()}

## 3. Verify Prior Accuracy First

Before comparing, let's check if the physics prior actually works!

In [None]:
# Test the physics prior alone: accuracy of the hand-coded rule with no training
prior = GravityPrior()

with torch.no_grad():
    prior_preds = prior(test_data.images)
    prior_binary = prior_preds.squeeze().gt(0.5).float()
    prior_acc = prior_binary.eq(test_data.labels).float().mean().item()

print(f"Physics Prior Accuracy (standalone): {prior_acc:.1%}")
print()
verdict = (
    "Prior encodes useful physics knowledge!" if prior_acc > 0.7
    else "Prior has some signal but needs improvement" if prior_acc > 0.55
    else "WARNING: Prior is not working correctly - check detection logic"
)
print(verdict)

## 4. Training Functions

In [None]:
def train_model(model, train_images, train_labels, epochs=100, lr=0.001, batch_size=32, verbose=False):
    """Train ``model`` in place with Adam on binary cross-entropy and return it.

    Args:
        model: module whose forward returns probabilities shaped like
            ``train_labels.unsqueeze(1)``, or a ``(probabilities, info)`` tuple
            (as NSCAModel returns); the info part is ignored.
        train_images: inputs, indexed along dim 0.
        train_labels: 1-D float tensor of 0/1 targets, one per image.
        epochs: number of full passes over the data.
        lr: Adam learning rate.
        batch_size: mini-batch size; the last batch of an epoch may be smaller.
        verbose: if True, print the mean per-batch loss every 20 epochs.

    Returns:
        The same ``model`` object, left in training mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    n = len(train_labels)
    # Real number of batches per epoch (ceil). The former n // batch_size was 0
    # for n < batch_size (ZeroDivisionError) and undercounted partial batches.
    n_batches = max(1, (n + batch_size - 1) // batch_size)

    model.train()
    for epoch in range(epochs):
        perm = torch.randperm(n)
        total_loss = 0.0

        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            x = train_images[idx]
            y = train_labels[idx].unsqueeze(1)

            # Models may return (pred, info) (e.g. NSCAModel) or just pred.
            out = model(x)
            pred = out[0] if isinstance(out, tuple) else out

            loss = F.binary_cross_entropy(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        if verbose and epoch % 20 == 0:
            print(f"  Epoch {epoch}: loss={total_loss / n_batches:.4f}")

    return model


def evaluate(model, images, labels):
    """Return ``(accuracy, info)`` of ``model`` on ``images`` against ``labels``.

    Predictions are thresholded at 0.5 and compared with the 0/1 ``labels``.
    ``info`` is the diagnostics dict when the model returns a ``(pred, info)``
    tuple (as NSCAModel does), otherwise ``{}``. The model is switched to eval
    mode and left there.
    """
    model.eval()
    with torch.no_grad():
        out = model(images)
        pred, info = out if isinstance(out, tuple) else (out, {})

        # reshape(-1) rather than squeeze(): stays 1-D even for a single sample.
        pred_binary = (pred.reshape(-1) > 0.5).float()
        acc = (pred_binary == labels).float().mean().item()

    return acc, info

## 5. Main Experiment: Sample Efficiency

In [None]:
print("="*70)
print("EXPERIMENT: Does Physics Prior Improve Sample Efficiency?")
print("="*70)
print("")
print("Task: Predict if object will land on table from initial frame")
print("Physics knowledge needed: Objects fall straight down")
print("")

train_sizes = [20, 50, 100, 200, 500]  # training-set sizes to sweep
n_seeds = 5                             # repetitions per training-set size
epochs = 80

# results[model_name][n_train] -> list of test accuracies, one entry per seed
results = {
    'baseline': {n: [] for n in train_sizes},
    'nsca': {n: [] for n in train_sizes}
}

for seed in range(n_seeds):
    print(f"\n--- Seed {seed+1}/{n_seeds} ---")

    for n_train in train_sizes:
        # Subset training data: a seed-dependent random subset of the pool
        torch.manual_seed(seed * 1000 + n_train)
        perm = torch.randperm(len(train_full.labels))[:n_train]
        train_imgs = train_full.images[perm]
        train_lbls = train_full.labels[perm]

        # Seed for model initialisation and batch order. It must be unique per
        # (seed, n_train): seed * 100 + n_train collided (e.g. seed=1,N=100 and
        # seed=0,N=200 both give 200), so "independent" runs shared initial
        # weights. seed * 1000 + n_train + 1 is unique for 0 <= n_train < 999.
        init_seed = seed * 1000 + n_train + 1

        # Train baseline
        torch.manual_seed(init_seed)
        baseline = BaselineCNN()
        baseline = train_model(baseline, train_imgs, train_lbls, epochs=epochs)
        base_acc, _ = evaluate(baseline, test_data.images, test_data.labels)
        results['baseline'][n_train].append(base_acc)

        # Train NSCA from the same seed. NSCAModel builds its BaselineCNN first,
        # so its CNN starts from the same weights as the standalone baseline.
        torch.manual_seed(init_seed)
        nsca = NSCAModel(initial_prior_weight=0.5)
        nsca = train_model(nsca, train_imgs, train_lbls, epochs=epochs)
        nsca_acc, info = evaluate(nsca, test_data.images, test_data.labels)
        results['nsca'][n_train].append(nsca_acc)

        print(f"N={n_train:3d}  Baseline: {base_acc:.1%}  NSCA: {nsca_acc:.1%}  (prior_w={info.get('weight', 0):.2f})")

## 6. Results Analysis

In [None]:
print("\n" + "="*70)
print("RESULTS: Sample Efficiency Comparison")
print("="*70)
print(f"\n{'N_train':<10} {'Baseline':<20} {'NSCA (w/ prior)':<20} {'Difference'}")
print("-"*70)

# Mean accuracy difference (NSCA - baseline), one entry per training-set size.
diffs = []
for n in train_sizes:
    # Distinct names so the trained `nsca` model from the experiment loop is
    # not overwritten by these arrays.
    base_accs = np.array(results['baseline'][n])
    nsca_accs = np.array(results['nsca'][n])
    diff = nsca_accs.mean() - base_accs.mean()
    diffs.append(diff)

    print(f"{n:<10} {base_accs.mean():.1%} +/- {base_accs.std():.1%}      {nsca_accs.mean():.1%} +/- {nsca_accs.std():.1%}      {diff:+.1%}")

print("-"*70)
# Split by the actual sample counts rather than by list position, so the
# labels stay correct if train_sizes is edited. An empty side reports nan.
low_diffs = [d for n, d in zip(train_sizes, diffs) if n <= 100]
high_diffs = [d for n, d in zip(train_sizes, diffs) if n > 100]
low_adv = np.mean(low_diffs) if low_diffs else float('nan')
high_adv = np.mean(high_diffs) if high_diffs else float('nan')
print(f"\nAverage advantage in low-data (N<=100): {low_adv:+.1%}")
print(f"Average advantage in high-data (N>100): {high_adv:+.1%}")

print("\n" + "="*70)
if low_adv > 0.03:
    print("HYPOTHESIS SUPPORTED: Physics priors improve sample efficiency")
elif low_adv > 0:
    print("MARGINAL SUPPORT: Small advantage with priors")
else:
    print("HYPOTHESIS NOT SUPPORTED in this configuration")
print("="*70)

In [None]:
# Plot results: mean test accuracy (error bars = std over seeds) vs. training-set size
plt.figure(figsize=(10, 6))

x = np.array(train_sizes)

for model_key, legend, marker in (
    ('baseline', 'Baseline (no prior)', 's'),
    ('nsca', 'NSCA (with physics prior)', 'o'),
):
    runs = [results[model_key][n] for n in train_sizes]
    plt.errorbar(x, [np.mean(r) for r in runs], yerr=[np.std(r) for r in runs], label=legend,
                 marker=marker, capsize=5, linewidth=2, markersize=8)

# Reference line: accuracy of the hand-coded GravityPrior with no training
plt.axhline(y=prior_acc, color='gray', linestyle='--', label=f'Prior only ({prior_acc:.0%})')

plt.xlabel('Number of Training Samples', fontsize=12)
plt.ylabel('Test Accuracy', fontsize=12)
plt.title('Sample Efficiency: Physics Prior on Stability Prediction', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.xscale('log')
plt.ylim(0.4, 1.0)

plt.tight_layout()
plt.savefig('sample_efficiency_results.png', dpi=150)
plt.show()

## 7. Analysis: Why Results Matter

### If NSCA wins in low-data regime:
- Physics prior provides useful inductive bias
- "Objects fall down" knowledge transfers without learning
- Validates the NSCA hypothesis

### If Baseline catches up with more data:
- Expected behavior: priors buy sample efficiency, not a higher accuracy ceiling
- Neural networks can learn physics from scratch with enough data

### If Baseline wins everywhere:
- Prior might not be accurate enough
- Task might be too simple (visual shortcuts exist)
- Need harder physics scenarios

In [None]:
# Closing summary and suggested follow-ups.
print("\nEXPERIMENT COMPLETE")
print("\nNext steps:")
next_steps = (
    "Try real Physion benchmark data",
    "Test with more complex physics (stacking, collisions)",
    "Measure prior weight adaptation over training",
)
for step_no, step in enumerate(next_steps, start=1):
    print(f"{step_no}. {step}")