In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch-first sequence.

    MATHEMATICAL LOGIC:
    Transformers process all time steps in parallel, so they have no inherent
    sense of order. A deterministic sinusoidal signal is added to every step:

        PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

    This gives each time step a distinct signature, so the model can tell
    apart, for example, an acceleration spike early in the window from
    steady-state motion later on.

    The table is computed once in ``__init__`` and stored as the buffer
    ``pe`` with shape ``[1, max_len, d_model]`` (a buffer, not a parameter,
    so it is not trained but follows the module across devices).

    Args:
        d_model: Size of the last (feature) dimension of the inputs. Both
            even and odd values are supported.
        max_len: Longest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=500):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # div_term defines the wavelength of the sinusoidal signals:
        # one frequency per (sin, cos) column pair.
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so the last frequency has no cosine partner and is trimmed.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` with the positional signal added.

        Math: z_t = embedding_vector_t + PE_vector_t

        Args:
            x: Tensor indexed as ``[Batch, T, d_model]``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``T`` exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len, :]

class PhysicsTransformerEstimator(nn.Module):
    """Estimate mass and friction coefficient (mu) from a kinematic sequence.

    Data flow (implemented in ``forward``):

    1. ENCODER: each step ``x_t = [vel, acc]`` is projected to ``d_model``
       features (``z_t = W_in * x_t + b_in``), positional encodings are added,
       and a self-attention encoder produces the motion context ``h_enc``.
    2. DECODER: ``seq_len`` learned time queries cross-attend to ``h_enc``
       (``h_t_dec = CrossAttention(q_t, H_enc, H_enc)``), followed by a
       residual feed-forward block.
    3. FORCE SEQUENCES: two MLPs map ``h_dec`` to latent net-force and
       friction-force feature sequences. The intent is to decouple the
       transient (m*a) part of the motion from the steady (mu*m*g) part.
    4. GLOBAL READOUTS: one learned query per quantity pools its sequence to
       a single vector through single-head attention; small MLPs then
       regress mass and mu. The predicted mass is fed to the mu head so it
       can account for ``F_fric = mu * m * g``.

    MULTI-HEAD ATTENTION (``nhead``):
        1. SPLITTING: ``d_model`` is split across heads,
           ``head_dim = d_model // nhead`` (64 // 4 = 16 with the defaults).
        2. PARALLEL PHYSICS: every head has its own W_Q, W_K and W_V, so
           different heads can spotlight different events (e.g. one the
           'Spike', another the 'Slide').
        3. CONCATENATION: head outputs are concatenated back to ``d_model``
           and mixed: ``Concat(head_1, ..., head_n) @ W_O``.
        4. WHY 4?: the author's trade-off between enough diversity to capture
           mass (transient) and friction (steady-state) cues and not making
           each head's feature space too small.

    Args:
        input_dim: Features per time step after concatenation
            (2 = ``[vel, acc]``).
        d_model: Latent dimension (d).
        nhead: Number of heads in the encoder layers and in the decoder
            cross-attention. The two readout attentions always use 1 head.
        num_encoder_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of the encoder and decoder
            feed-forward blocks.
        seq_len: Number of learned decoder time queries (T = 20). The
            positional table is sized ``seq_len + 10``.
        dropout: Dropout used in the encoder layers and the decoder
            cross-attention.
    """

    def __init__(
        self,
        input_dim=2,         # Subscript x_t: [vel, acc]
        d_model=64,          # Latent dimension (d)
        nhead=4,
        num_encoder_layers=2,
        dim_feedforward=128,
        seq_len=20,          # T = 20
        dropout=0.1
    ):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len

        # NOTE: attribute names and construction order below are part of the
        # checkpoint format (state_dict keys) and of seeded initialisation;
        # do not rename or reorder them.

        # ============================================================
        # STEP 1: CONTINUOUS EMBEDDING & ENCODER
        # Math: z_t = W_in * x_t + b_in
        # W_in and b_in are TRAINABLE parameters that map raw kinematics
        # into the d_model-dimensional feature space.
        # ============================================================
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len=seq_len + 10)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward,
            batch_first=True, dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        # ============================================================
        # STEP 2: DECODER (GENERATING h_t_dec)
        # Math: h_t_dec = CrossAttention(q_t_query, H_enc, H_enc)
        # ============================================================

        # Learned Time Queries: a trainable [1, seq_len, d_model] matrix,
        # one query vector per output time slot. Unlike an LLM, where
        # Query = W_Q * Embedding is computed from the input, these vectors
        # ARE parameters and are optimised directly during training.
        self.time_queries = nn.Parameter(torch.randn(1, seq_len, d_model))

        # Cross-Attention: standard dot-product attention,
        # Score = (q_dec * W_Q) @ (h_enc * W_K).T
        # MultiheadAttention still applies its own learned W_Q to the
        # parameter queries during the forward pass.
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True, dropout=dropout)

        self.norm_dec = nn.LayerNorm(d_model)
        self.ffn_dec = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model)
        )
        self.norm_ffn = nn.LayerNorm(d_model)

        # ============================================================
        # STEP 3: INTERMEDIATE FORCE SEQUENCES
        # Math: F_net_t = MLP_net(h_t_dec) and F_fric_t = MLP_fric(h_t_dec)
        # Intended to decouple the transient (ma) from the steady (mu*m*g).
        # ============================================================
        self.net_force_mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, d_model)
        )
        self.fric_force_mlp = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, d_model)
        )

        # ============================================================
        # STEP 4: GLOBAL READOUTS (MASS & MU)
        # Cross-attention is used as a "spotlight" that reduces each
        # sequence to one vector. Alpha_t^m is the weight given to frame t
        # for the mass calculation.
        # ============================================================

        # --- A. MASS ESTIMATION ---
        # q_mass is a single GLOBAL learned query (not derived from the
        # input). Training is meant to align it with the 'feat_net' vectors
        # around the acceleration peak.
        self.q_mass = nn.Parameter(torch.randn(1, 1, d_model))
        self.mass_attn = nn.MultiheadAttention(d_model, 1, batch_first=True)
        self.mass_pred_mlp = nn.Sequential(
            nn.Linear(d_model, 64), nn.ReLU(), nn.Linear(64, 1)
        )

        # --- B. FRICTION ESTIMATION ---
        # Physics: F_fric = mu * m * g.
        # q_fric is likewise a global learned query, meant to pick out the
        # steady-state friction features. The mu MLP takes one extra input
        # (the predicted mass, hence d_model + 1) so it can solve for
        # mu = F / (m*g).
        self.q_fric = nn.Parameter(torch.randn(1, 1, d_model))
        self.fric_attn = nn.MultiheadAttention(d_model, 1, batch_first=True)
        self.mu_pred_mlp = nn.Sequential(
            nn.Linear(d_model + 1, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, extracted_acc, extracted_vel, return_attn_weights=False):
        """Predict ``[mass, mu]`` for a batch of motion windows.

        Args:
            extracted_acc: Acceleration sequence. It is concatenated with
                ``extracted_vel`` on the last dimension, so both must be 3-D
                (``[Batch, T, features]``) and their feature sizes must sum
                to ``input_dim`` (``[Batch, T, 1]`` each with the defaults).
            extracted_vel: Velocity sequence, same layout as ``extracted_acc``.
            return_attn_weights: If True, also return the per-head weights of
                the decoder cross-attention (useful for visualisation).

        Returns:
            ``[Batch, 2]`` tensor with columns ``[mass, mu]``. When
            ``return_attn_weights`` is True, a tuple ``(predictions,
            attn_weights)`` where ``attn_weights`` is
            ``[Batch, nhead, seq_len, T]``.

        Raises:
            ValueError: If the concatenated input is not 3-D.
        """
        # x_t = [vel, acc] -- velocity first, then acceleration.
        x = torch.cat([extracted_vel, extracted_acc], dim=-1)
        if x.dim() != 3:
            raise ValueError(
                "expected 3-D inputs [Batch, T, features]; "
                f"concatenated input has shape {tuple(x.shape)}"
            )
        batch_size = x.size(0)

        # ENCODER: Motion Context (h_enc)
        # The 2-D input [v, a] is projected into the d_model feature space,
        # then every frame attends to every other frame. The intent is that
        # self-attention can identify "events": a transient such as a
        # deceleration spike (the critical window for mass estimation)
        # versus steady sliding, where acceleration stays near zero.
        # Analogy: as nearby LLM embeddings carry related meaning, W_in can
        # learn to place e.g. (v=0.04, a=-3.5) and (v=0.04, a=0.0) in
        # distinct regions ("heavy brake" vs "steady slide").
        z = self.input_proj(x)
        z = self.pos_encoder(z)
        h_enc = self.transformer_encoder(z)

        # =====================================================================
        # DECODER CROSS-ATTENTION
        # =====================================================================
        # 1. Per head i, nn.MultiheadAttention internally projects:
        #      Q_i = q_dec * W_i_Q   (the learned time templates)
        #      K_i = h_enc * W_i_K   (the encoder context frames)
        #      V_i = h_enc * W_i_V   (the information carried by those frames)
        #    then Score = (Q_i @ K_i.T) / sqrt(d)  -> [queries x keys]
        #    and  Head_Output = Softmax(Score) @ V_i.
        #
        # 2. Softmax is applied ROW-WISE (one query vs. all keys):
        #      Weight(t, j) = exp(Score_t_j) / sum_j exp(Score_t_j)
        #    turning arbitrary real scores into a probability distribution
        #    over frames (a sharp "spotlight" gives one frame weight near 1).
        #
        # 3. Why divide by sqrt(d)? Dot products grow with dimensionality;
        #    large scores push softmax into regions with tiny gradients.
        #    Scaling keeps values in a range where training stays effective.
        #
        # 4. Role of the Value projection: the Key decides WHERE a query
        #    looks; the Value decides WHAT gets written into that time slot.
        #    W_V lets the model transform kinematic history into
        #    force-related features rather than copying the input.
        # =====================================================================
        q_dec = self.time_queries.expand(batch_size, -1, -1)

        attn_output, attn_weights = self.cross_attn(
            query=q_dec,
            key=h_enc,
            value=h_enc,
            average_attn_weights=False  # keep one attention map per head for visualisation
        )
        # Residual + norm, then a residual feed-forward block.
        h_dec = self.norm_dec(q_dec + attn_output)
        h_dec = self.norm_ffn(h_dec + self.ffn_dec(h_dec))

        # FORCE SEQUENCES: Decoupling Dynamics
        # feat_net targets the net-force signal, feat_fric the friction signal.
        feat_net = self.net_force_mlp(h_dec)
        feat_fric = self.fric_force_mlp(h_dec)

        # GLOBAL PROPERTY READOUT: Mass Prediction
        # Inside mass_attn, W_Q is applied to the mass query and W_K to
        # feat_net. Query: "looking for inertial spikes"; Key: "latent
        # net-force signatures".
        q_m_batch = self.q_mass.expand(batch_size, -1, -1)
        mass_ctx, _ = self.mass_attn(query=q_m_batch, key=feat_net, value=feat_net)
        mass_pred = self.mass_pred_mlp(mass_ctx.squeeze(1))

        # GLOBAL PROPERTY READOUT: Friction Prediction
        # Inside fric_attn, W_Q is applied to the friction query and W_K to
        # feat_fric. Query: "looking for steady-state sliding"; Key: "latent
        # friction-force signatures".
        q_f_batch = self.q_fric.expand(batch_size, -1, -1)
        fric_ctx, _ = self.fric_attn(query=q_f_batch, key=feat_fric, value=feat_fric)

        # Concatenate the predicted mass for physics consistency.
        # Math: mu = f(Friction_Context, Mass_Scalar)
        fric_input = torch.cat([fric_ctx.squeeze(1), mass_pred], dim=-1)
        mu_pred = self.mu_pred_mlp(fric_input)

        # Result: [Batch, 2] -> [Mass, Mu]
        preds = torch.cat([mass_pred, mu_pred], dim=-1)
        if return_attn_weights:
            return preds, attn_weights
        return preds

In [None]:
import torch
import matplotlib.pyplot as plt
import os

# ==========================================
# LOAD CHECKPOINT & VISUALIZE
# ==========================================

# --- CONFIGURATION: WHICH MODEL TO LOAD? ---
# Set this to the epoch number you want to load (e.g., 50000)
# Or set to None to use the model currently in memory (if you just finished training)
LOAD_EPOCH = 100
CRITERION_TYPE = "mse"
LOSS_TYPE = "data"
CHECKPOINT_DIR = f"checkpoints/{CRITERION_TYPE}_{LOSS_TYPE}"

# 1. Initialize Model Structure (Must match training config)
#    When a checkpoint is requested we re-initialize so the weights are
#    loaded into a clean model. When LOAD_EPOCH is None we keep the model
#    that is already in memory, as the option promises; a fresh model is
#    only built if none exists yet.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if LOAD_EPOCH is None and isinstance(globals().get("model"), torch.nn.Module):
    model = model.to(device)
else:
    model = PhysicsTransformerEstimator(
        input_dim=2,
        d_model=64,
        nhead=4,
        num_encoder_layers=2,
        seq_len=20 # Make sure this matches your data prep
    ).to(device)

# 2. Load Weights
if LOAD_EPOCH is not None:
    load_path = f"{CHECKPOINT_DIR}/transformer_epoch{LOAD_EPOCH}.pth"

    if os.path.exists(load_path):
        print(f"🔄 Loading model from: {load_path}")
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints you trust. On PyTorch versions that support it,
        # prefer torch.load(..., weights_only=True) for plain state dicts.
        state_dict = torch.load(load_path, map_location=device)
        model.load_state_dict(state_dict)
        print("✅ Model weights loaded successfully!")
    else:
        # Best-effort: keep going with the freshly initialized model.
        print(f"❌ Error: Checkpoint not found at {load_path}")
        print("   Using random initialization or current model state.")
else:
    print("ℹ️ Using current model state (no file loaded).")

# 3. VISUALIZATION
print("\n--- Running Evaluation ---")

# A. Loss Curve (Only available if you just trained, otherwise skip)
if 'train_losses' in locals() and len(train_losses) > 0:
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Train Loss', linewidth=2)
    # The validation history may be missing (e.g. training ran without a
    # validation pass); plot it only when it exists instead of crashing.
    if 'val_losses' in locals():
        plt.plot(val_losses, label='Validation Loss', linewidth=2)
    plt.title(f'Training Curve (Loaded Epoch: {LOAD_EPOCH})')
    plt.xlabel('Epoch')
    plt.ylabel('MSE Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
else:
    print("ℹ️ No training loss history found in memory (skipped plot).")

# B. Prediction Check (up to the first 5 samples of the Val Set)
model.eval()
with torch.no_grad():
    # Get a batch from validation loader
    try:
        sample_acc, sample_vel, sample_y = next(iter(val_loader))
    except NameError:
        print("❌ Error: 'val_loader' is not defined. Please run Data Preparation first.")
        sample_acc = None
    except StopIteration:
        # An empty loader yields no batch at all.
        print("❌ Error: 'val_loader' is empty. No samples to evaluate.")
        sample_acc = None

    if sample_acc is not None:
        sample_acc, sample_vel = sample_acc.to(device), sample_vel.to(device)
        preds = model(sample_acc, sample_vel).cpu()

        print("\n--- Sample Predictions (Val Set) ---")
        print(f"{'GT Mass':<10} {'Pred Mass':<10} | {'GT Mu':<10} {'Pred Mu':<10}")
        print("-" * 45)
        # The batch may hold fewer than 5 samples (small or last batch).
        for i in range(min(5, preds.size(0))):
            print(f"{sample_y[i,0]:.4f}     {preds[i,0]:.4f}     | {sample_y[i,1]:.4f}     {preds[i,1]:.4f}")