# Linear Pedagogical Example (Real Data)

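Uses `bab_datasets` to load the `multisine_05` experiment and fit a linear second-order physical model,
$J\ddot{\theta} + R\dot{\theta} + K(\theta + \delta) = \tau V$ (`Tau` in the code), by integrating it with `torchdiffeq` over short windows of the measured position and velocity.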


In [None]:
# !pip install torchdiffeq git+https://github.com/helonayala/bab_datasets.git
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torchdiffeq import odeint
import bab_datasets as nod

# Run on the CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


In [None]:
# ==========================================
# 1. LOAD DATA (multisine_05)
# ==========================================
# Velocity-estimation method forwarded to bab_datasets as `y_dot_method`
# (presumably central finite differences — confirm against the library).
velMethod = "central"

# NOTE(review): the meaning of preprocess / resample_factor / zoom_last_n is
# defined by bab_datasets.load_experiment; resample_factor=50 presumably
# downsamples the raw record — confirm before changing it.
data = nod.load_experiment(
    "multisine_05",
    preprocess=True,
    plot=True,
    end_idx=None,
    resample_factor=50,
    zoom_last_n=200,
    y_dot_method=velMethod,
)

# `data` unpacks into four signals and also exposes `sampling_time`.
# Presumably u = input, y = measured position, y_ref = reference and
# y_dot = estimated velocity — TODO confirm. y_ref is not used below.
u, y, y_ref, y_dot = data

# Time vector from sampling_time: uniform grid t[k] = k * Ts starting at 0.
Ts = data.sampling_time
t = np.arange(len(u)) * Ts

# Stack states: [position, velocity] as columns (one row per sample,
# assuming y and y_dot are 1-D — confirm).
y_sim = np.column_stack([y, y_dot])


In [None]:
# ==========================================
# 2. PREPARE TENSORS
# ==========================================
# float32 copies of the time grid, the input (as a single column) and the
# stacked [position, velocity] measurements, created directly on `device`.
t_tensor = torch.tensor(t, dtype=torch.float32, device=device)
u_tensor = torch.tensor(u, dtype=torch.float32, device=device).view(-1, 1)
y_tensor = torch.tensor(y_sim, dtype=torch.float32, device=device)


In [None]:
# ==========================================
# 3. MODEL DEFINITION
# ==========================================
class LinearPhysODE(nn.Module):
    """Linear second-order physical model used as an ODE right-hand side.

    Dynamics:  J * thdd + R * thd + K * (th + delta) = Tau * V
    State:     x = [th, thd], passed as a (batch, 2) tensor.

    J, R, K and Tau are kept strictly positive by optimising their logarithms.
    delta (an offset on th) is unconstrained.

    The input V is not part of the state. It is read from ``u_series``, sampled
    on the grid ``t_series``, and linearly interpolated at the solver's query
    time. Both attributes must be assigned before the first forward call.

    Args:
        J0, R0, K0, Tau0: positive initial values of the corresponding
            parameters (they are stored as logarithms).
        delta0: initial value of the offset ``delta``.
        The defaults reproduce the original hard-coded initial guesses.

    Raises:
        ValueError: if any of J0, R0, K0, Tau0 is not positive.
    """

    def __init__(self, J0=0.1, R0=0.1, K0=1.0, delta0=0.0, Tau0=1.0):
        super().__init__()
        for pname, value in (("J0", J0), ("R0", R0), ("K0", K0), ("Tau0", Tau0)):
            if not value > 0:
                raise ValueError(
                    f"{pname} must be positive (it is log-parameterised), got {value}"
                )
        # Registration order (log_J, log_R, log_K, delta, log_Tau) is kept
        # stable so parameters()/state_dict() order does not change.
        self.log_J = nn.Parameter(torch.tensor(np.log(J0), dtype=torch.float32))
        self.log_R = nn.Parameter(torch.tensor(np.log(R0), dtype=torch.float32))
        self.log_K = nn.Parameter(torch.tensor(np.log(K0), dtype=torch.float32))
        self.delta = nn.Parameter(torch.tensor(float(delta0), dtype=torch.float32))
        self.log_Tau = nn.Parameter(torch.tensor(np.log(Tau0), dtype=torch.float32))

        # Input samples (indexed per grid point, one column) and their 1-D time grid.
        self.u_series = None
        self.t_series = None
        # Absolute start time of each trajectory in the batch, shape (batch, 1).
        # None means the solver time `t` is already absolute.
        self.batch_start_times = None

    def get_params(self):
        """Return (J, R, K, delta, Tau); J, R, K, Tau are exp of their log-parameters."""
        J = torch.exp(self.log_J)
        R = torch.exp(self.log_R)
        K = torch.exp(self.log_K)
        Tau = torch.exp(self.log_Tau)
        return J, R, K, self.delta, Tau

    def forward(self, t, x):
        """Return dx/dt = [thd, thdd] for state ``x`` at solver time ``t``.

        Args:
            t: scalar time from the ODE solver. It is relative to
                ``batch_start_times`` when that is set, and absolute otherwise.
            x: state tensor of shape (batch, 2) holding [th, thd].

        Returns:
            Tensor of shape (batch, 2) with [thd, thdd].

        Raises:
            RuntimeError: if ``u_series`` or ``t_series`` has not been assigned.
        """
        if self.t_series is None or self.u_series is None:
            raise RuntimeError(
                "LinearPhysODE.u_series and t_series must be set before forward()"
            )
        J, R, K, delta, Tau = self.get_params()

        if self.batch_start_times is not None:
            t_abs = self.batch_start_times + t
        else:
            t_abs = t * torch.ones_like(x[:, 0:1])

        # First grid index strictly after t_abs, clamped so [k-1, k] is always a
        # valid interval. Times outside the grid are linearly extrapolated from
        # the first or last interval.
        k_idx = torch.searchsorted(self.t_series, t_abs.reshape(-1), right=True)
        k_idx = torch.clamp(k_idx, 1, len(self.t_series) - 1)

        t1 = self.t_series[k_idx - 1].unsqueeze(1)
        t2 = self.t_series[k_idx].unsqueeze(1)
        u1 = self.u_series[k_idx - 1]
        u2 = self.u_series[k_idx]

        # Replace (near-)zero interval lengths with 1 to avoid dividing by zero,
        # without writing into the tensor in place.
        denom = t2 - t1
        denom = torch.where(denom < 1e-6, torch.ones_like(denom), denom)
        alpha = (t_abs - t1) / denom
        u_t = u1 + alpha * (u2 - u1)

        th, thd = x[:, 0:1], x[:, 1:2]
        thdd = (Tau * u_t - R * thd - K * (th + delta)) / J
        return torch.cat([thd, thdd], dim=1)


In [None]:
# ==========================================
# 4. TRAINING FUNCTION
# ==========================================

def train_model(model, name, epochs=500, lr=0.02, k_steps=20, batch_size=128,
                log_every=100):
    """Fit ``model`` by multiple shooting on short windows of the measured data.

    Each epoch:
    - draws ``batch_size`` random start indices;
    - integrates the model with RK4 for ``k_steps`` samples, starting from the
      measured state at each start;
    - takes one Adam step on the mean-squared error against the measured
      [position, velocity] windows.

    Reads the notebook globals ``t_tensor``, ``u_tensor``, ``y_tensor`` and
    ``device``. ``model`` must expose ``u_series``, ``t_series`` and
    ``batch_start_times`` attributes, as ``LinearPhysODE`` does.

    Args:
        model: ODE right-hand-side module; moved to ``device`` in place.
        name: label used in the progress printout.
        epochs: the loop runs ``epochs + 1`` optimisation steps.
        lr: Adam learning rate.
        k_steps: window length in samples (previously hard-coded as 20).
        batch_size: windows per step (previously hard-coded as 128).
        log_every: print the loss every this many epochs; 0 disables logging.

    Returns:
        The same ``model`` object, trained. Its ``batch_start_times`` is reset
        to None so the per-batch offsets do not leak into later simulations.

    Raises:
        ValueError: if ``k_steps < 1`` or the data has ``k_steps`` samples or fewer.
    """
    if k_steps < 1:
        raise ValueError(f"k_steps must be >= 1, got {k_steps}")
    n_samples = len(t_tensor)
    if n_samples <= k_steps:
        raise ValueError(
            f"need more than k_steps={k_steps} samples to train, got {n_samples}"
        )

    print(f"--- Training {name} ---")
    model.to(device)
    model.u_series = u_tensor
    model.t_series = t_tensor

    optimizer = optim.Adam(model.parameters(), lr=lr)

    # The time grid is uniform (t[k] = k * Ts), so one dt describes every window.
    dt_local = (t_tensor[1] - t_tensor[0]).item()
    # Build the evaluation grid from an integer range. torch.arange with a float
    # step can return k_steps + 1 points due to rounding, which would no longer
    # line up with the k_steps-long target windows.
    offsets = torch.arange(k_steps, device=device)
    t_eval = offsets.to(t_tensor.dtype) * dt_local

    for epoch in range(epochs + 1):
        optimizer.zero_grad()

        # Starts in [0, n_samples - k_steps) keep every window inside the data.
        start_idx = np.random.randint(0, n_samples - k_steps, size=batch_size)
        starts = torch.as_tensor(start_idx, dtype=torch.long, device=device)
        x0 = y_tensor[starts]
        model.batch_start_times = t_tensor[starts].reshape(-1, 1)

        # Shape (k_steps, batch, 2): time first, as returned by odeint.
        pred_state = odeint(model, x0, t_eval, method='rk4')

        # Gather all target windows at once: window_idx[k, b] = starts[b] + k,
        # giving the same (k_steps, batch, 2) layout as pred_state.
        window_idx = offsets.unsqueeze(1) + starts.unsqueeze(0)
        y_target = y_tensor[window_idx]

        loss = torch.mean((pred_state - y_target) ** 2)
        loss.backward()
        optimizer.step()

        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.6f}")

    model.batch_start_times = None
    return model


In [None]:
# ==========================================
# 5. TRAIN PHYSICAL MODEL
# ==========================================
phys_model = train_model(
    LinearPhysODE(),
    "Linear Physical Model",
    epochs=500,
    lr=0.02,
)

# Report the identified parameters (get_params maps the log-parameters back to
# positive values).
J, R, K, delta, Tau = phys_model.get_params()
print(
    f"identified: J={J.item():.4f}, R={R.item():.4f}, K={K.item():.4f}, "
    f"delta={delta.item():.4f}, Tau={Tau.item():.4f}"
)


In [None]:
# ==========================================
# 6. FULL SIMULATION & COMPARISON
# ==========================================
print("--- Running Full Data Simulation ---")

# Free-run the identified model over the whole record, starting from the first
# measured state as a single trajectory at absolute time 0.
with torch.no_grad():
    x0 = y_tensor[0].unsqueeze(0)
    phys_model.batch_start_times = torch.zeros(1, 1).to(device)
    full_pred = odeint(phys_model, x0, t_tensor, method='rk4').squeeze(1).cpu().numpy()

plt.figure(figsize=(12, 6))
comparison_panels = (
    ("Position Comparison", "Position"),
    ("Velocity Comparison", "Velocity"),
)
for state_col, (panel_title, panel_ylabel) in enumerate(comparison_panels):
    plt.subplot(1, 2, state_col + 1)
    plt.plot(t, y_sim[:, state_col], 'k-', alpha=0.4, linewidth=3, label='Measured')
    plt.plot(t, full_pred[:, state_col], 'r--', linewidth=1.5, label='Simulated')
    plt.title(panel_title)
    plt.xlabel("Time (s)")
    plt.ylabel(panel_ylabel)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()


In [None]:
# ==========================================
# 7. RESIDUALS (Measured - Simulated)
# ==========================================
res_pos, res_vel = (y_sim[:, col] - full_pred[:, col] for col in (0, 1))

# Each panel: (signal, legend/y-axis label, optional title, optional x label).
residual_panels = (
    (res_pos, 'Position residual', "Residuals", None),
    (res_vel, 'Velocity residual', None, "Time (s)"),
)

plt.figure(figsize=(12, 6))
for panel_row, (residual, panel_label, panel_title, panel_xlabel) in enumerate(residual_panels, start=1):
    plt.subplot(2, 1, panel_row)
    plt.plot(t, residual, 'k-', linewidth=1.5, label=panel_label)
    plt.axhline(0, color='gray', linewidth=1)
    if panel_title is not None:
        plt.title(panel_title)
    plt.ylabel(panel_label)
    if panel_xlabel is not None:
        plt.xlabel(panel_xlabel)
    plt.grid(True)
    plt.legend()

plt.tight_layout()
plt.show()
