# Dodge Timing Experiment

Can a pure CNN distinguish frames where dodge is about to happen vs not?

- **Positive samples**: 16 frames leading up to a dodge
- **Negative samples**: 16 frames from random points (no dodge within ±X frames)

In [3]:
import zarr
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from flax.training import train_state
from pathlib import Path
from collections import defaultdict
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Config
DATASET_PATH = '../dataset/margit_100_256x144.zarr'
NUM_FRAMES = 16  # Frames to stack
DODGE_IDX = 4  # dodge_roll/dash action index
NEGATIVE_BUFFER = 20  # Negative samples must be at least this far from any dodge
TRAIN_RATIO = 0.8
BATCH_SIZE = 32
NUM_EPOCHS = 20
LEARNING_RATE = 0.001

In [4]:
# Load dataset
zarr_root = zarr.open(DATASET_PATH, mode='r')
episodes = sorted([k for k in zarr_root.keys() if k.startswith('episode_')])
print(f"Loaded {len(episodes)} episodes")

Loaded 100 episodes


In [5]:
# Extract positive and negative samples
positive_samples = []  # (episode, end_frame) - frames [end_frame-15:end_frame+1] lead to dodge
negative_candidates = []  # (episode, end_frame) - no dodge nearby

for ep_name in tqdm(episodes, desc="Scanning episodes"):
    ep = zarr_root[ep_name]
    actions = ep['actions'][:]
    num_frames = actions.shape[0]
    
    # Find all dodge frames
    dodge_frames = set(np.where(actions[:, DODGE_IDX])[0])
    
    # Find dodge onsets (first frame of dodge press)
    dodge_onset_frames = []
    prev_dodge = False
    for i in range(num_frames):
        curr_dodge = actions[i, DODGE_IDX]
        if curr_dodge and not prev_dodge:
            dodge_onset_frames.append(i)
        prev_dodge = curr_dodge
    
    # Positive samples: 16 frames ending at dodge onset
    for onset in dodge_onset_frames:
        if onset >= NUM_FRAMES - 1:  # Need 16 frames before
            positive_samples.append((ep_name, onset))
    
    # Negative candidates: frames far from any dodge
    for i in range(NUM_FRAMES - 1, num_frames):
        # Check if any dodge within buffer
        nearby_dodge = False
        for d in range(i - NEGATIVE_BUFFER, i + NEGATIVE_BUFFER + 1):
            if d in dodge_frames:
                nearby_dodge = True
                break
        if not nearby_dodge:
            negative_candidates.append((ep_name, i))

print(f"Positive samples (dodge onsets): {len(positive_samples)}")
print(f"Negative candidates: {len(negative_candidates)}")

Scanning episodes:  13%|█▎        | 13/100 [00:00<00:00, 318.38it/s]


IndexError: index 4 is out of bounds for axis 1 with size 4

In [None]:
# Balance dataset - sample negatives to match positives
np.random.seed(42)
num_negatives = len(positive_samples)  # 1:1 ratio
negative_samples = [negative_candidates[i] for i in 
                    np.random.choice(len(negative_candidates), size=num_negatives, replace=False)]

print(f"Using {len(positive_samples)} positive, {len(negative_samples)} negative samples")

# Create labels
all_samples = positive_samples + negative_samples
all_labels = [1] * len(positive_samples) + [0] * len(negative_samples)

# Shuffle
indices = np.random.permutation(len(all_samples))
all_samples = [all_samples[i] for i in indices]
all_labels = [all_labels[i] for i in indices]

# Train/val split
split_idx = int(len(all_samples) * TRAIN_RATIO)
train_samples, val_samples = all_samples[:split_idx], all_samples[split_idx:]
train_labels, val_labels = all_labels[:split_idx], all_labels[split_idx:]

print(f"Train: {len(train_samples)}, Val: {len(val_samples)}")

In [None]:
def load_frame_stack(zarr_root, ep_name, end_frame, num_frames=16):
    """Load stack of frames ending at end_frame."""
    ep = zarr_root[ep_name]
    start_frame = end_frame - num_frames + 1
    frames = ep['frames'][start_frame:end_frame + 1]  # [T, C, H, W]
    # Normalize to [0, 1]
    frames = frames.astype(np.float32) / 255.0
    return frames

def create_batch(zarr_root, samples, labels, batch_indices):
    """Create a batch of frame stacks and labels."""
    batch_frames = []
    batch_labels = []
    for i in batch_indices:
        ep_name, end_frame = samples[i]
        frames = load_frame_stack(zarr_root, ep_name, end_frame, NUM_FRAMES)
        batch_frames.append(frames)
        batch_labels.append(labels[i])
    return np.stack(batch_frames), np.array(batch_labels, dtype=np.float32)

In [None]:
# Simple CNN for binary classification
class DodgeTimingCNN(nn.Module):
    """CNN to predict if dodge is imminent from frame stack."""
    
    @nn.compact
    def __call__(self, x, training: bool = True):
        # x: [B, T, C, H, W]
        B, T, C, H, W = x.shape
        
        # Stack frames as channels: [B, T*C, H, W]
        x = x.reshape(B, T * C, H, W)
        
        # Transpose to [B, H, W, C] for Flax conv
        x = jnp.transpose(x, (0, 2, 3, 1))
        
        # Conv layers
        x = nn.Conv(features=32, kernel_size=(3, 3), strides=(2, 2))(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        
        x = nn.Conv(features=128, kernel_size=(3, 3), strides=(2, 2))(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        
        x = nn.Conv(features=256, kernel_size=(3, 3), strides=(2, 2))(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        
        # Global average pooling
        x = jnp.mean(x, axis=(1, 2))  # [B, 256]
        
        # Dense layers
        x = nn.Dense(features=128)(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=0.5, deterministic=not training)(x)
        
        # Output: single logit for binary classification
        x = nn.Dense(features=1)(x)
        return x.squeeze(-1)  # [B]

In [None]:
# Initialize model
model = DodgeTimingCNN()
rng = jax.random.PRNGKey(42)
init_rng, dropout_rng = jax.random.split(rng)

# Dummy input for init
dummy_input = jnp.ones((1, NUM_FRAMES, 3, 144, 256))
variables = model.init({'params': init_rng, 'dropout': dropout_rng}, dummy_input, training=False)

# Count parameters
param_count = sum(x.size for x in jax.tree_util.tree_leaves(variables['params']))
print(f"Model parameters: {param_count:,}")

In [None]:
# Training state
class TrainState(train_state.TrainState):
    batch_stats: dict

tx = optax.adam(LEARNING_RATE)
state = TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx,
    batch_stats=variables['batch_stats'],
)

In [None]:
# Loss and train step
@jax.jit
def train_step(state, batch, labels, rng):
    def loss_fn(params):
        logits, updates = state.apply_fn(
            {'params': params, 'batch_stats': state.batch_stats},
            batch, training=True,
            mutable=['batch_stats'],
            rngs={'dropout': rng}
        )
        loss = optax.sigmoid_binary_cross_entropy(logits, labels).mean()
        return loss, (logits, updates)
    
    (loss, (logits, updates)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    state = state.replace(batch_stats=updates['batch_stats'])
    
    preds = jax.nn.sigmoid(logits) > 0.5
    acc = (preds == labels).mean()
    return state, loss, acc

@jax.jit
def eval_step(state, batch, labels):
    logits = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        batch, training=False
    )
    loss = optax.sigmoid_binary_cross_entropy(logits, labels).mean()
    preds = jax.nn.sigmoid(logits) > 0.5
    acc = (preds == labels).mean()
    probs = jax.nn.sigmoid(logits)
    return loss, acc, preds, probs

In [None]:
# Training loop
train_losses, val_losses = [], []
train_accs, val_accs = [], []

rng = jax.random.PRNGKey(0)

for epoch in range(NUM_EPOCHS):
    # Shuffle training data
    perm = np.random.permutation(len(train_samples))
    
    # Training
    epoch_losses, epoch_accs = [], []
    for i in tqdm(range(0, len(train_samples), BATCH_SIZE), desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=False):
        batch_idx = perm[i:i+BATCH_SIZE]
        if len(batch_idx) < BATCH_SIZE:
            continue
        
        batch, labels = create_batch(zarr_root, train_samples, train_labels, batch_idx)
        batch = jnp.array(batch)
        labels = jnp.array(labels)
        
        rng, step_rng = jax.random.split(rng)
        state, loss, acc = train_step(state, batch, labels, step_rng)
        epoch_losses.append(float(loss))
        epoch_accs.append(float(acc))
    
    train_loss = np.mean(epoch_losses)
    train_acc = np.mean(epoch_accs)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    
    # Validation
    val_epoch_losses, val_epoch_accs = [], []
    for i in range(0, len(val_samples), BATCH_SIZE):
        batch_idx = list(range(i, min(i+BATCH_SIZE, len(val_samples))))
        if len(batch_idx) < 2:
            continue
        batch, labels = create_batch(zarr_root, val_samples, val_labels, batch_idx)
        batch = jnp.array(batch)
        labels = jnp.array(labels)
        
        loss, acc, _, _ = eval_step(state, batch, labels)
        val_epoch_losses.append(float(loss))
        val_epoch_accs.append(float(acc))
    
    val_loss = np.mean(val_epoch_losses)
    val_acc = np.mean(val_epoch_accs)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    
    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Acc={train_acc:.4f} | Val Loss={val_loss:.4f}, Acc={val_acc:.4f}")

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].plot(train_losses, label='Train')
axes[0].plot(val_losses, label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Loss Curve')
axes[0].legend()

axes[1].plot(train_accs, label='Train')
axes[1].plot(val_accs, label='Val')
axes[1].axhline(y=0.5, color='gray', linestyle='--', label='Random')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Accuracy Curve')
axes[1].legend()

plt.tight_layout()
plt.show()

print(f"\nBest Val Accuracy: {max(val_accs):.4f}")

In [None]:
# Full validation evaluation
all_preds, all_probs, all_true = [], [], []

for i in range(0, len(val_samples), BATCH_SIZE):
    batch_idx = list(range(i, min(i+BATCH_SIZE, len(val_samples))))
    if len(batch_idx) < 1:
        continue
    batch, labels = create_batch(zarr_root, val_samples, val_labels, batch_idx)
    batch = jnp.array(batch)
    labels = jnp.array(labels)
    
    _, _, preds, probs = eval_step(state, batch, labels)
    all_preds.extend(preds.tolist())
    all_probs.extend(probs.tolist())
    all_true.extend(labels.tolist())

all_preds = np.array(all_preds)
all_probs = np.array(all_probs)
all_true = np.array(all_true)

print("Classification Report:")
print(classification_report(all_true, all_preds, target_names=['No Dodge', 'Dodge']))

In [None]:
# Confusion matrix
cm = confusion_matrix(all_true, all_preds)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['No Dodge', 'Dodge'],
            yticklabels=['No Dodge', 'Dodge'])
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

In [None]:
# Probability distribution
fig, ax = plt.subplots(figsize=(10, 6))

ax.hist(all_probs[all_true == 0], bins=50, alpha=0.5, label='No Dodge (True)', color='blue')
ax.hist(all_probs[all_true == 1], bins=50, alpha=0.5, label='Dodge (True)', color='red')
ax.axvline(x=0.5, color='black', linestyle='--', label='Threshold')
ax.set_xlabel('Predicted Probability of Dodge')
ax.set_ylabel('Count')
ax.set_title('Probability Distribution by True Label')
ax.legend()
plt.show()

print(f"\n=== CONCLUSION ===")
if max(val_accs) > 0.65:
    print(f"Model achieves {max(val_accs)*100:.1f}% accuracy - VISUAL SIGNAL EXISTS!")
    print("The CNN can distinguish pre-dodge frames from non-dodge frames.")
    print("Issue is likely in how temporal_cnn uses this signal, not the signal itself.")
elif max(val_accs) > 0.55:
    print(f"Model achieves {max(val_accs)*100:.1f}% accuracy - WEAK VISUAL SIGNAL")
    print("Some signal exists but it's noisy. May need more data or augmentation.")
else:
    print(f"Model achieves {max(val_accs)*100:.1f}% accuracy - NO CLEAR VISUAL SIGNAL")
    print("The visual frames alone don't predict dodge timing well.")
    print("Need to rely more on state features (NPC animation, positions).")