# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import math

# ==========================================
# 1. CONFIGURATION (MAP YOUR COLUMNS HERE)
# ==========================================
# Continuous per-agent feature columns, in the order they appear in the input tensors.
# The model's reconstruction target uses the first two features as (x, y), so keep them first.
_CONTINUOUS_COLS = ['x', 'y', 's', 'dir', 'o']

FEATURE_CONFIG = {
    # Continuous features fed to the model (normalize them upstream; the model does not)
    'continuous_cols': _CONTINUOUS_COLS,

    # Categorical/Static features
    'role_col': 'role_label',        # e.g., 0=CB, 1=LB, 2=S...
    'coverage_col': 'coverage_type', # e.g., 0=Man, 1=Zone...

    # Dimensions
    'num_agents': 22,      # 11 Offense + 11 Defense (or 23 with ball); informational only,
                           # DefensiveGhostModel mean-pools over agents and never reads it
    'num_roles': 10,       # Size of the role embedding table (role ids must be < this)
    'num_coverages': 5,    # Size of the coverage embedding table (coverage ids must be < this)
    'input_dim': len(_CONTINUOUS_COLS),  # Derived so it can never drift from 'continuous_cols'
    'hidden_dim': 256,     # Must be divisible by 'nhead'
    'latent_dim': 32,
    'nhead': 4,
    'num_layers': 4,
}

class DefensiveGhostModel(pl.LightningModule):
    """Conditional VAE that generates a play-level (x, y) trajectory from observed history.

    Pipeline (shared by ``forward`` and ``training_step`` via ``_encode_past``/``_decode``):

    1. Each agent's continuous features are linearly embedded, its role embedding is
       added, and agents are mean-pooled into one token per frame.
    2. The coverage embedding is added, positional encoding applied, and the past
       encoder produces the decoder ``memory``.
    3. A latent ``z`` -- sampled from the posterior q(z | future) in training and from
       N(0, I) at inference -- is projected to ``hidden_dim``, repeated once per output
       frame and position-encoded to form the decoder queries.
    4. The decoder cross-attends to ``memory``; a linear head emits 2 values per frame.

    Expected batch keys (shapes from the inline annotations; TODO confirm against the
    data module):
        history:  (B, T_obs, A, input_dim) float
        future:   (B, T_fut, A, input_dim) float -- training only
        roles:    (B, A) integer ids in [0, num_roles)
        coverage: (B,)   integer ids in [0, num_coverages)

    Args:
        cfg: Mapping with the keys read below (see ``FEATURE_CONFIG``).
        pred_len: Default number of frames generated by ``forward``.
        kl_weight: Weight of the KL term in the training loss.
        lr: AdamW learning rate.
    """

    def __init__(self, cfg=FEATURE_CONFIG, *, pred_len=50, kl_weight=1e-3, lr=1e-4):
        super().__init__()
        self.save_hyperparameters()
        # Private copy so later edits to the shared config dict cannot alter a built model.
        self.cfg = dict(cfg)
        self.pred_len = pred_len
        self.kl_weight = kl_weight
        self.lr = lr
        hidden = cfg['hidden_dim']

        # -------------------------------------------------------
        # A. Embeddings (Feature Extraction)
        # -------------------------------------------------------
        # Projects continuous features (x, y, s...) to hidden_dim
        self.feature_embedding = nn.Linear(cfg['input_dim'], hidden)
        # Learnable embeddings for discrete inputs
        self.role_embedding = nn.Embedding(cfg['num_roles'], hidden)
        self.coverage_embedding = nn.Embedding(cfg['num_coverages'], hidden)
        # Positional Encoding (Temporal)
        self.pos_encoder = PositionalEncoding(hidden)

        # -------------------------------------------------------
        # B. The Encoders (Transformer)
        # -------------------------------------------------------
        # nn.TransformerEncoder deep-copies this template layer, so the past and
        # future encoders below do not share weights.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=cfg['nhead'],
            batch_first=True,
            norm_first=True,
        )
        # 1. Past Encoder: encodes history + coverage context into decoder memory.
        self.past_encoder = nn.TransformerEncoder(encoder_layer, num_layers=cfg['num_layers'])
        # 2. Future Encoder (training only): posterior network q(z | future).
        #    At least one layer so num_layers == 1 does not yield an empty encoder.
        self.future_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=max(1, cfg['num_layers'] // 2)
        )

        # -------------------------------------------------------
        # C. The Latent Space (CVAE)
        # -------------------------------------------------------
        # Maps the future summary -> mean (mu) and log-variance (logvar) of q(z | future)
        self.z_mean = nn.Linear(hidden, cfg['latent_dim'])
        self.z_logvar = nn.Linear(hidden, cfg['latent_dim'])
        # Projects z back to hidden_dim to seed the decoder queries
        self.z_projection = nn.Linear(cfg['latent_dim'], hidden)

        # -------------------------------------------------------
        # D. The Decoder (Trajectory Generator)
        # -------------------------------------------------------
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden,
            nhead=cfg['nhead'],
            batch_first=True,
            norm_first=True,
        )
        self.trajectory_decoder = nn.TransformerDecoder(decoder_layer, num_layers=cfg['num_layers'])
        # Output head: 2 values per frame, trained against mean agent (x, y)
        self.output_head = nn.Linear(hidden, 2)

    def embed_agents(self, x, roles):
        """Embed per-agent continuous features and add each agent's role embedding.

        Args:
            x: (B, T, A, input_dim) continuous features.
            roles: (B, A) integer role ids.

        Returns:
            (B, T, A, hidden_dim) tensor; the role embedding is broadcast over time.
        """
        # nn.Linear acts on the last dimension, so B/T/A need no flattening.
        emb = self.feature_embedding(x)                      # (B, T, A, H)
        role_emb = self.role_embedding(roles).unsqueeze(1)   # (B, 1, A, H)
        return emb + role_emb

    def _encode_past(self, history, roles, coverage):
        """Encode observed history plus coverage context into decoder memory (B, T_obs, H)."""
        # Agents are mean-pooled into a single "play state" token per frame.
        hist_emb = self.embed_agents(history, roles).mean(dim=2)   # (B, T_obs, H)
        cov_emb = self.coverage_embedding(coverage).unsqueeze(1)   # (B, 1, H), broadcast over time
        return self.past_encoder(self.pos_encoder(hist_emb + cov_emb))

    def _decode(self, z, memory, pred_len):
        """Turn latent codes z (B, latent_dim) into (B, pred_len, 2) predictions.

        The decoder queries depend only on z and frame position, never on ground-truth
        future frames, so training and inference feed the decoder the same kind of input.
        """
        if pred_len < 1:
            raise ValueError(f"pred_len must be >= 1, got {pred_len}")
        z_proj = self.z_projection(z).unsqueeze(1)                        # (B, 1, H)
        tgt_query = self.pos_encoder(z_proj.expand(-1, pred_len, -1))     # (B, pred_len, H)
        output = self.trajectory_decoder(tgt=tgt_query, memory=memory)
        return self.output_head(output)                                   # (B, pred_len, 2)

    def forward(self, batch, pred_len=None):
        """Sample one trajectory per play (inference).

        Args:
            batch: dict with 'history', 'roles', 'coverage' (see class docstring).
            pred_len: Frames to generate; defaults to the ``pred_len`` given at construction.

        Returns:
            (B, pred_len, 2) tensor in the same units as the training target (the per-frame
            mean over agents of the first two continuous features). Each call draws a fresh
            z ~ N(0, I), so repeated calls give different samples.
        """
        history = batch['history']      # (B, T_obs, A, F)
        memory = self._encode_past(history, batch['roles'], batch['coverage'])
        if pred_len is None:
            pred_len = self.pred_len
        # Sample latent z from the prior N(0, I) because we are inferring.
        z = torch.randn(
            history.shape[0], self.cfg['latent_dim'],
            device=memory.device, dtype=memory.dtype,
        )
        return self._decode(z, memory, pred_len)

    def training_step(self, batch, batch_idx):
        """One CVAE step: encode past, sample z from q(z | future), decode, score.

        Loss = MSE(prediction, mean agent (x, y) per future frame) + kl_weight * KL.
        """
        history = batch['history']   # observed frames
        future = batch['future']     # frames to predict
        roles = batch['roles']

        # --- A. ENCODE PAST ---
        memory = self._encode_past(history, roles, batch['coverage'])

        # --- B. ENCODE FUTURE (posterior over z) ---
        fut_emb = self.pos_encoder(self.embed_agents(future, roles).mean(dim=2))
        # The future encoder is unmasked, so its last token can attend to every future frame.
        future_summary = self.future_encoder(fut_emb)[:, -1, :]

        # --- C. LATENT SPACE ---
        mu = self.z_mean(future_summary)
        logvar = self.z_logvar(future_summary)
        # Reparameterization trick: z = mu + std * eps keeps sampling differentiable.
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)

        # --- D. DECODE ---
        # Same z-seeded queries as forward(). Feeding the ground-truth future embeddings
        # here would let position t read its own target frame and would train the decoder
        # on inputs it never receives at inference.
        pred_coords = self._decode(z, memory, future.shape[1])

        # --- E. LOSS CALCULATION ---
        # Target: per-frame mean over agents of the first two continuous features (x, y).
        target_coords = future[:, :, :, :2].mean(dim=2)
        recon_loss = F.mse_loss(pred_coords, target_coords)

        # KL(q(z | future) || N(0, I)): summed over latent dims, averaged over the batch.
        kl_loss = (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)).mean()

        total_loss = recon_loss + self.kl_weight * kl_loss
        self.log("train_loss", total_loss)
        self.log("train_recon_loss", recon_loss)
        self.log("train_kl_loss", kl_loss)
        return total_loss

    def configure_optimizers(self):
        """AdamW over all parameters with the learning rate given at construction."""
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

# ==========================================
# Helper: Positional Encoding
# ==========================================
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of sequences.

    Even feature columns get sin(pos * w_i) and odd columns cos(pos * w_i), with
    w_i = 10000 ** (-2i / d_model).

    Args:
        d_model: Feature size of the inputs (odd values are supported).
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer: follows .to(device) and is saved in the state_dict, but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x + pe[:seq_len]`` for ``x`` of shape (batch, seq_len, d_model).

        Raises:
            ValueError: if seq_len is larger than ``max_len``.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len={max_len}"
            )
        return x + self.pe[:seq_len, :]