# Neural ODE

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchdyn.core as tdcore
import torchdyn.nn as tdnn
import torchdyn.utils as tdutils
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

%load_ext tensorboard

In [None]:
# Run configuration: with `dry_run` enabled the training loops run a single
# epoch; `epochs_default` is the default epoch count for `train_model`.
dry_run = False
epochs_default = 200

# Seed PyTorch's global RNG so repeated runs draw the same random numbers.
seed = 42
torch.manual_seed(seed)

# Use the GPU when PyTorch reports one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### Dataset

In [None]:
def evolve_unicycle(x: torch.Tensor, dt: float) -> torch.Tensor:
    """Advance a batch of unicycle states by one explicit-Euler step.

    Args:
        x: 2-D tensor whose columns are ``[x, y, theta, v, omega]``.
        dt: Step length multiplied into every velocity term.

    Returns:
        Tensor holding the three pose columns ``[x, y, theta]`` after the
        step; ``v`` and ``omega`` are not part of the result.
    """
    heading = x[:, 2]
    speed = x[:, 3]
    turn_rate = x[:, 4]

    # Per-sample pose increment, one column per pose component.
    pose_step = torch.stack(
        (
            speed * torch.cos(heading) * dt,
            speed * torch.sin(heading) * dt,
            turn_rate * dt,
        ),
        dim=1,
    )
    return x[:, 0:3] + pose_step

In [None]:
# Sample `n_dataset` random unicycle states and pair each one with the pose
# reached after a single step of length `dt`.
n_dataset = 1000
dt = 0.1

# State layout: [x, y, theta, v, omega].
# Start from uniform samples in [0, 1), shift each column, then scale each column.
x = torch.rand(n_dataset, 5, dtype=torch.float32, device=device)
x = x + torch.tensor([-0.5, -0.5, 0, -0.5, -0.5], device=device)
x = x * torch.tensor([10.0, 10.0, 2*3.14, 1.0, 1.0], device=device)
y = evolve_unicycle(x, dt)

# (state, next pose) pairs served in shuffled mini-batches of 32.
train = torch.utils.data.TensorDataset(x, y)
train_loader = torch.utils.data.DataLoader(
    dataset=train,
    batch_size=32,
    shuffle=True,
)

### Vanilla Neural ODE

In [None]:
class UnicycleVF(nn.Module):
    """Learned vector field for the 5-D state ``s = [x, y, theta, v, omega]``.

    An MLP maps the full state to the pose rates ``[dx, dy, dtheta]``. The
    rates of ``v`` and ``omega`` are fixed to zero, which keeps the controls
    constant along the integrated trajectory.
    """

    def __init__(self):
        super().__init__()
        # 5 -> 64 -> 32 -> 3 tanh MLP that outputs [dx, dy, dtheta].
        layers = [
            nn.Linear(5, 64),
            nn.Tanh(),
            nn.Linear(64, 32),
            nn.Tanh(),
            nn.Linear(32, 3),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, t, s, *args, **kwargs):
        """Return ds/dt for a batch of states ``s``.

        ``t`` and any extra arguments are accepted but not used. The first
        three output columns come from the MLP; the last two are zeros with
        the same dtype and device as ``s``.
        """
        batch_size = s.shape[0]
        pose_rates = self.mlp(s)
        control_rates = s.new_zeros(batch_size, 2)
        return torch.cat((pose_rates, control_rates), dim=1)

In [None]:
def build_node():
    """Create a fresh Neural ODE around a new `UnicycleVF` and move it to `device`.

    The torchdyn wrapper is configured with the 'rk4' solver and 'autograd'
    sensitivity.
    """
    vector_field = UnicycleVF()
    node = tdcore.NeuralODE(vector_field, solver='rk4', sensitivity='autograd')
    return node.to(device)

In [None]:
# Shared TensorBoard writer; `train_model` logs the per-epoch loss through it.
# NOTE(review): no log_dir is passed, so PyTorch's default location is used —
# presumably under ./runs, which is what the TensorBoard cell's --logdir
# points at; confirm.
writer = SummaryWriter()

def angular_difference(pred, target):
    """Return ``pred - target`` wrapped into the interval [-pi, pi).

    The raw difference is shifted by pi, reduced modulo a full turn, and
    shifted back, so angles that differ by a multiple of 2*pi compare as equal.
    """
    half_turn = torch.pi
    wrapped = (pred - target + half_turn) % (2 * half_turn)
    return wrapped - half_turn

def train_model(model, t_span, epochs=epochs_default, lr=1e-3):
    """Fit `model` to the one-step pose targets served by `train_loader`.

    Each batch is integrated over `t_span`; the pose columns of the final
    trajectory point are compared with the target pose. The loss is a Huber
    loss on (x, y) plus a Huber loss on the wrapped heading error, so headings
    that differ by a full turn are not penalised.

    Args:
        model: Module called as ``model(xb, t_span)`` that returns
            ``(t_eval, traj)``; ``traj[-1]`` is taken as the final state.
        t_span: Time grid forwarded unchanged to the model.
        epochs: Number of passes over `train_loader`; forced to 1 when the
            module-level `dry_run` flag is set.
        lr: Adam learning rate.

    Returns:
        The same `model` object, trained in place and left in train mode.
    """
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()

    # Resolve the dry-run override once so the bar total and the loop agree.
    n_epochs = 1 if dry_run else epochs
    huber = torch.nn.functional.huber_loss

    # Iterate the progress bar itself rather than a second, parallel range.
    with tqdm(range(n_epochs)) as pbar:
        for ep in pbar:
            running = 0.0
            for xb, yb in train_loader:
                # xb: (B,5) initial state; yb: (B,3) target next pose
                _, traj = model(xb, t_span)       # traj: (T, B, 5)
                y_pred = traj[-1][:, :3]

                position_loss = huber(y_pred[:, 0:2], yb[:, 0:2])
                heading_error = angular_difference(y_pred[:, 2], yb[:, 2])
                heading_loss = huber(heading_error, torch.zeros_like(heading_error))
                loss = position_loss + heading_loss

                opt.zero_grad()
                loss.backward()
                opt.step()
                running += loss.item()

            error_loss = running / len(train_loader)
            # NOTE: the logged value is the mean Huber loss; the historical
            # 'train/mse' tag is kept so existing TensorBoard runs stay comparable.
            writer.add_scalar('train/mse', error_loss, ep)
            pbar.set_postfix(loss=error_loss)

    writer.flush()
    return model

In [None]:
@torch.no_grad()
def quick_eval_plot(model, t_span):
    """Scatter-plot true vs. predicted one-step (dx, dy) displacements.

    Reads the module-level dataset tensors `x` (initial states) and `y`
    (target poses), integrates `x` over `t_span`, and plots the displacement
    of the final pose from the initial pose for both the model and the targets.

    Args:
        model: Module called as ``model(x, t_span)`` that returns
            ``(t_eval, traj)``; ``traj[-1]`` is taken as the final state.
        t_span: Time grid forwarded unchanged to the model.

    Note:
        The model is switched to eval mode and left that way.
    """
    model.eval()
    _, traj = model(x, t_span)

    # Work with displacements so the plot is independent of the start position.
    start_pose = x[:, :3].cpu()
    y_hat = traj[-1][:, :3].cpu() - start_pose
    y_true = y.cpu() - start_pose

    plt.figure(figsize=(5, 4))
    plt.scatter(y_true[:, 0], y_true[:, 1], s=6, label='true', alpha=0.6)
    plt.scatter(y_hat[:, 0], y_hat[:, 1], s=6, label='pred', alpha=0.6)
    plt.legend()
    plt.xlabel(r"$\Delta x$ [m]")
    plt.ylabel(r"$\Delta y$ [m]")
    # The title separator is an en dash; the original text held its mis-decoded bytes.
    plt.title(r"$(\Delta x_{k+1}, \Delta y_{k+1})$ – true vs pred")
    plt.tight_layout()
    plt.show()

In [None]:
from IPython.display import display, Markdown

port = 4242
display(Markdown(f"**TensorBoard running at:** [http://localhost:{port}](http://localhost:{port})"))

%tensorboard --reload_interval 10 --logdir runs --port {port}

In [None]:
# Two-point time grid [0, dt]: the model integrates over one real step.
t_span = torch.linspace(start=0.0, end=dt, steps=2, device=device)

# Fresh Neural ODE, trained for 100 epochs (a single epoch on a dry run).
model = build_node()
train_model(model, t_span, lr=1e-2, epochs=1 if dry_run else 100)

In [None]:
# Visual check: plot the trained model's one-step (dx, dy) against the true values.
quick_eval_plot(model, t_span)

In [None]:
# Single-sample check: compare the model's one-step pose prediction with the
# analytic unicycle step for one hand-picked state.
model.eval()
t_span = torch.linspace(start=0.0, end=dt, steps=2, device=device)
x_0 = torch.tensor([[1.0, 1.0, 3.14, 0.5, 0.0]], device=device)  # example (B,5) input

with torch.no_grad():
    _, traj = model(x_0, t_span)
    y_pred = traj[-1, :, :3]     # final [x', y', theta']
print(y_pred)

y_true = evolve_unicycle(x_0, dt)[:, :3]
print(y_true)