In [2]:
import os

import torch
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pyro
import pyro.distributions as dist
import zuko
import numpy as np
from tqdm import tqdm
import IPython

# Global plot styling (set_theme is seaborn's preferred name for the legacy sns.set alias)
sns.set_theme(style="white", context="notebook", palette="deep", color_codes=True)

In [None]:
# TODO load data from various failures

In [88]:
def _euler_rate_matrix(roll, pitch):
    """Return the matrix mapping body rates [p, q, r] to Euler-angle rates.

    Roll-pitch-yaw (Z-Y-X) Euler kinematics:

        phi_dot   = p + q*sin(phi)*tan(theta) + r*cos(phi)*tan(theta)
        theta_dot =     q*cos(phi)            - r*sin(phi)
        psi_dot   =     q*sin(phi)/cos(theta) + r*cos(phi)/cos(theta)

    Entries diverge as pitch approaches +/- pi/2 (gimbal lock), because of the
    tan(theta) and 1/cos(theta) terms.

    Args:
        roll: Roll angles phi in radians, shape (N,).
        pitch: Pitch angles theta in radians, shape (N,).

    Returns:
        Tensor of shape (N, 3, 3); row i gives the i-th Euler-angle rate.
    """
    zeros = torch.zeros_like(roll)
    ones = torch.ones_like(roll)
    sin_r, cos_r = torch.sin(roll), torch.cos(roll)
    tan_p, cos_p = torch.tan(pitch), torch.cos(pitch)
    # Assemble with torch.stack rather than in-place writes into a zeros tensor
    # so the result is a plain, autograd-friendly graph node.
    rows = [
        torch.stack([ones, sin_r * tan_p, cos_r * tan_p], dim=-1),
        torch.stack([zeros, cos_r, -sin_r], dim=-1),
        torch.stack([zeros, sin_r / cos_p, cos_r / cos_p], dim=-1),
    ]
    return torch.stack(rows, dim=-2)


def model(initial_state, commands, observed_states=None, observed_pqr=None, dt=0.25, observation_noise_scale=1e-1):
    """Define a simplified probabilistic model for the UAV attitude dynamics.

    The attitude x = [phi, theta, psi] (roll, pitch, yaw; radians) is driven by
    body rates

        pqr = A x + K e + d + eta,

    where e = command - x is the tracking error, A is the state-to-rate matrix,
    K the error-to-rate matrix, d a constant bias and eta ~ Normal(0,
    noise_strength) process noise. The attitude is advanced with one explicit
    Euler step per time step, x_{t+1} = x_t + dt * J^-1(x_t) @ pqr_t, where
    J^-1 is the Euler-angle kinematic matrix. Each new latent attitude is
    observed with Gaussian noise of scale ``observation_noise_scale``.

    Args:
        initial_state: Initial attitude of the UAV, shape (N, 3) (any shape
            reshapeable to (N, 3, 1)).
        commands: Commanded attitudes, shape (N, T, 3).
        observed_states: Optional observed attitudes, shape (N, T, 3). Index 0
            is not used: the initial state is given, not observed.
        observed_pqr: Optional observed body rates, shape (N, T, 3).
        dt: Integration time step.
        observation_noise_scale: Standard deviation of the attitude
            observation noise.

    Returns:
        Tuple ``(states, pqrs, state_observation_noise, action_noise_trajectory)``,
        each of shape (N, T, 3):

        - states: latent attitudes; ``states[:, 0]`` is ``initial_state``.
        - pqrs: mean body rates ``A x + K e + d`` at each step.
        - state_observation_noise: observed minus latent attitude (zero at t=0).
        - action_noise_trajectory: sampled/observed minus mean body rate.
    """
    # Check consistency of batch and time dimensions
    N = commands.shape[0]
    T = commands.shape[1]
    assert initial_state.shape[0] == N
    assert observed_pqr is None or observed_pqr.shape[0] == N
    assert observed_pqr is None or observed_pqr.shape[1] == T
    assert observed_states is None or observed_states.shape[0] == N
    assert observed_states is None or observed_states.shape[1] == T

    # Sample the system matrices, bias and process-noise level from the prior
    A = pyro.sample("A", dist.Normal(torch.zeros(3, 3), torch.ones(3, 3)))
    K = pyro.sample("K", dist.Normal(torch.zeros(3, 3), torch.ones(3, 3)))
    d = pyro.sample("d", dist.Normal(torch.zeros(3), torch.ones(3))).reshape(3, 1)
    log_noise_strength = pyro.sample("log_noise_strength", dist.Normal(-2, 1.0))
    noise_strength = torch.exp(log_noise_strength)

    if T == 0:
        # Nothing to simulate; keep the (N, 0, 3) output shapes.
        return tuple(torch.zeros(N, 0, 3) for _ in range(4))

    # Trajectories are accumulated in lists and stacked at the end. Writing in
    # place into a preallocated tensor that later steps read views of would
    # invalidate tensors autograd saved for backward (breaking SVI).
    state = initial_state.reshape(N, 3, 1)
    states = [state.reshape(N, 3)]
    state_observation_noise = [torch.zeros_like(states[0])]
    pqrs = []
    action_noise_trajectory = []

    for t in range(1, T + 1):
        command = commands[:, t - 1].reshape(-1, 3, 1)

        # Tracking error
        e = command - state

        # Mean body rates from the linear system, then add process noise
        pqr_mean = A @ state + K @ e + d
        pqrs.append(pqr_mean.reshape(-1, 3))
        pqr = pyro.sample(
            f"pqr_{t}",
            dist.Normal(pqr_mean, noise_strength * torch.ones_like(pqr_mean)).to_event(2),
            obs=observed_pqr[:, t - 1].reshape(-1, 3, 1) if observed_pqr is not None else None,
        )
        action_noise_trajectory.append((pqr - pqr_mean).reshape(-1, 3))

        # The last body rate has no successor state to integrate into
        if t == T:
            break

        # Integrate one explicit Euler step through the attitude kinematics.
        # Index [:, i, 0] keeps roll/pitch at shape (N,) so any batch size works.
        Jinv = _euler_rate_matrix(state[:, 0, 0], state[:, 1, 0])
        next_state = (state + dt * Jinv @ pqr).reshape(-1, 3)
        states.append(next_state)
        observed_state = pyro.sample(
            f"state_{t}",
            dist.Normal(next_state, observation_noise_scale * torch.ones_like(next_state)).to_event(1),
            obs=observed_states[:, t] if observed_states is not None else None,
        )
        state_observation_noise.append((observed_state - next_state).reshape(-1, 3))
        state = next_state.reshape(-1, 3, 1)

    return (
        torch.stack(states, dim=1),
        torch.stack(pqrs, dim=1),
        torch.stack(state_observation_noise, dim=1),
        torch.stack(action_noise_trajectory, dim=1),
    )


# Prior-predictive smoke test: one trajectory (N=1) of T=2 steps from rest.
# Seed so the sampled output below is reproducible on Restart & Run All.
pyro.set_rng_seed(0)
model(torch.zeros(1, 3), torch.zeros(1, 2, 3), observed_states=None, observed_pqr=None, dt=0.25)

(tensor([[[ 0.0000,  0.0000,  0.0000],
          [-0.1232, -0.0953, -0.1550]]]),
 tensor([[[-0.5274, -0.5219, -0.7345],
          [-1.0774, -0.7063, -0.6193]]]),
 tensor([[[ 0.0000,  0.0000,  0.0000],
          [-0.0803,  0.0579,  0.2720]]]),
 tensor([[[ 0.0344,  0.1407,  0.1145],
          [-0.0763, -0.0358,  0.0572]]]))