In [2]:
import os

import torch
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pyro
import pyro.distributions as dist
import zuko
import numpy as np
from tqdm import tqdm
import IPython

# Global plot styling applied to every figure produced later in the notebook
sns.set(context="notebook", style="white", palette="deep", color_codes=True)

In [22]:
# Prefer the second GPU ("cuda:1") when it exists. Fall back to the first GPU on
# single-GPU machines and to the CPU when CUDA is unavailable; a hard-coded "cuda:1"
# would otherwise only fail later, at the first tensor allocation.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

In [34]:
# Root directory of the flight logs; load_data joins it with a sub-folder and an
# experiment name to locate each CSV file.
base_path = "data/uav_data"

# Experiment folder names, grouped by the label embedded in each name.
elevator_failure_experiments = [
    f"carbonZ_{stamp}_elevator_failure"
    for stamp in ("2018-09-11-14-41-51", "2018-09-11-15-05-11_1")
]
rudder_failure_experiments = [
    f"carbonZ_2018-09-11-15-06-34_{run}_rudder_{side}_failure"
    for run, side in ((1, "right"), (2, "right"), (3, "left"))
]
nominal_experiments = [
    f"carbonZ_{stamp}_no_failure"
    for stamp in (
        "2018-07-18-16-37-39_1",
        "2018-07-30-16-39-00_3",
        "2018-09-11-14-16-55",
        "2018-09-11-14-41-38",
        "2018-09-11-15-05-11_2",
        "2018-10-05-14-34-20_1",
        "2018-10-05-14-37-22_1",
        "2018-10-05-15-52-12_1",
        "2018-10-05-15-52-12_2",
        "2018-10-18-11-08-24",
    )
]

# Maps each logged topic (the suffix of its CSV file name) to a
# {raw CSV column -> short column name} dictionary used when loading the data.
fields = {
    # Linear (v) and angular (w) velocity components along x, y, z
    "mavros-local_position-velocity": {
        f"field.twist.{kind}.{axis}": f"twist_{symbol}{axis}"
        for kind, symbol in (("linear", "v"), ("angular", "w"))
        for axis in "xyz"
    },
    # Commanded and measured value of each attitude angle
    **{
        f"mavros-nav_info-{angle}": {
            "field.commanded": f"{angle}_commanded",
            "field.measured": f"{angle}_measured",
        }
        for angle in ("pitch", "roll", "yaw")
    },
}


def load_data(base_path, experiment_path, fields, dt=0.25):
    """Load one experiment's CSV logs and resample them onto a shared time grid.

    For every topic in ``fields`` the file
    ``<base_path>/<experiment_path>/<experiment_path>-<topic>.csv`` is read, its
    columns are renamed, and each renamed column is linearly interpolated onto a
    uniform grid that starts at the earliest timestamp found in any topic.

    Args:
        base_path: Directory that contains the experiment folder.
        experiment_path: Name of the experiment folder (also the CSV file prefix).
        fields: Dict mapping topic name -> {raw CSV column: output column name}.
        dt: Spacing of the resampled time grid, in seconds.

    Returns:
        A DataFrame indexed by "Time (s)" (starting at 0) with one column per output
        column name, in the order the topics and columns appear in ``fields``.
    """
    # Load all the dfs into a list, remapping columns to the names in fields
    topic_dfs = []
    for field_name, field_map in fields.items():
        path = os.path.join(
            base_path, experiment_path, experiment_path + "-" + field_name + ".csv"
        )
        topic_dfs.append(pd.read_csv(path).rename(columns=field_map))

    # Earliest and latest timestamps over all topics; the 1e-9 factor converts the
    # raw "%time" values to the seconds used by the output index.
    min_time = min(df["%time"].min() for df in topic_dfs) * 1e-9
    max_time = max(df["%time"].max() for df in topic_dfs) * 1e-9

    # Uniform time grid, spaced by the caller-supplied dt
    t = np.arange(0, max_time - min_time, dt)

    normalized_dfs = []
    for df, field_map in zip(topic_dfs, fields.values()):
        sampled_times = df["%time"] * 1e-9 - min_time

        resampled = {}
        for field in field_map.values():
            values = df[field].to_numpy()

            # Unwrap angles on the raw samples, *before* interpolating: linear
            # interpolation across a 0/360 wrap would otherwise produce spurious
            # intermediate values that a later unwrap cannot repair.
            if "roll" in field or "pitch" in field or "yaw" in field:
                values = np.unwrap(values, period=360)

            # We have to treat the error status specially, since it's only reported
            # when a failure is occurring (and is implicitly zero otherwise)
            if "status" in field:
                resampled[field] = np.interp(t, sampled_times, values, left=0)
            else:
                resampled[field] = np.interp(t, sampled_times, values)

        normalized_dfs.append(
            pd.DataFrame(
                resampled,
                index=pd.Index(t, name="Time (s)"),
                columns=list(field_map.values()),
            )
        )

    # Merge all the dataframes into one (they all share the same time index)
    return pd.concat(normalized_dfs, axis=1, join="inner")


# Load every experiment into a resampled DataFrame. Nominal runs live under
# "<base_path>/nominal"; failure runs live under "<base_path>/failure" and carry one
# extra topic holding the status of the failed actuator.
nominal_dfs = [
    load_data(os.path.join(base_path, "nominal"), run_name, fields, dt=0.25)
    for run_name in nominal_experiments
]
elevator_failure_dfs = [
    load_data(
        os.path.join(base_path, "failure"),
        run_name,
        {**fields, "failure_status-elevator": {"field.data": "elevator_status"}},
        dt=0.25,
    )
    for run_name in elevator_failure_experiments
]
rudder_failure_dfs = [
    load_data(
        os.path.join(base_path, "failure"),
        run_name,
        {**fields, "failure_status-rudder": {"field.data": "rudder_status"}},
        dt=0.25,
    )
    for run_name in rudder_failure_experiments
]


# Convert one resampled DataFrame into the tensors consumed by the model
def df_to_tensors(df):
    """Extract attitude, body-rate and command tensors from a resampled run.

    Angle columns are multiplied by pi / 180 (degrees to radians). The twist_wy and
    twist_wz columns are negated; twist_wx is used as-is.

    Args:
        df: DataFrame with the columns roll/pitch/yaw ``_measured`` and
            ``_commanded`` plus ``twist_wx``, ``twist_wy`` and ``twist_wz``.

    Returns:
        Tuple ``(initial_states, states, pqr, desired_states)`` of float32 tensors on
        ``device``: the first row of ``states`` (shape (3,)), and three (T, 3)
        tensors where T is the number of rows in ``df``.
    """

    def _radians(column):
        return df[column].to_numpy() * np.pi / 180

    def _stack_to_tensor(*series):
        return torch.tensor(
            np.stack(series, axis=1),
            dtype=torch.float32,
            device=device,
        )

    # Measured attitude [roll, pitch, yaw]
    states = _stack_to_tensor(
        _radians("roll_measured"),
        _radians("pitch_measured"),
        _radians("yaw_measured"),
    )

    # Commanded attitude, in the same order
    desired_states = _stack_to_tensor(
        _radians("roll_commanded"),
        _radians("pitch_commanded"),
        _radians("yaw_commanded"),
    )

    # Angular rates [p, q, r]; the y and z components are sign-flipped
    pqr = _stack_to_tensor(
        df["twist_wx"].to_numpy(),
        -df["twist_wy"].to_numpy(),
        -df["twist_wz"].to_numpy(),
    )

    return states[0], states, pqr, desired_states


# Convert each group of runs to tensors, then transpose the per-run 4-tuples into
# four per-quantity lists (one entry per run).
nominal_data = tuple(
    list(per_quantity)
    for per_quantity in zip(*(df_to_tensors(run_df) for run_df in nominal_dfs))
)
(
    nominal_initial_states,
    nominal_observed_states,
    nominal_observed_pqr,
    nominal_commands,
) = nominal_data

elevator_failure_data = tuple(
    list(per_quantity)
    for per_quantity in zip(
        *(df_to_tensors(run_df) for run_df in elevator_failure_dfs)
    )
)
(
    elevator_failure_initial_states,
    elevator_failure_observed_states,
    elevator_failure_observed_pqr,
    elevator_failure_commands,
) = elevator_failure_data

rudder_failure_data = tuple(
    list(per_quantity)
    for per_quantity in zip(*(df_to_tensors(run_df) for run_df in rudder_failure_dfs))
)
(
    rudder_failure_initial_states,
    rudder_failure_observed_states,
    rudder_failure_observed_pqr,
    rudder_failure_commands,
) = rudder_failure_data

In [75]:
def model(
    initial_state,
    commands,
    observed_states=None,
    observed_pqr=None,
    dt=0.25,
    observation_noise_scale=1e-1,
):
    """Define a simplified model for the UAV attitude dynamics.

    Uses attitude dynamics with state x = [phi, theta, psi] (roll, pitch, yaw):

        dx/dt = J^-1(x) * (A x + K e + d + eta)

    where J^-1 is the Euler-angle kinematics, A is the state-to-state transfer
    matrix, K is the error-to-state transfer matrix, e = command - x is the tracking
    error, d is a constant bias, and eta is Gaussian noise. The state is integrated
    with a forward-Euler step of size ``dt``.

    Args:
        initial_state: List of initial states (shape (3,)) of the UAV, one per run.
        commands: List of (T_i, 3) commanded attitudes, one per run.
        observed_states: List of (T_i, 3) observed states (or None entries), one per
            run. ``None`` means no run has state observations.
        observed_pqr: List of (T_i, 3) observed angular velocities (or None entries),
            one per run. ``None`` means no run has angular-velocity observations.
        dt: Time step.
        observation_noise_scale: Standard deviation of the state observation noise.

    Returns:
        Tuple ``(states, pqrs, state_observation_noise, action_noise_trajectory)``;
        each element is a list with one (T_i, 3) tensor per run: the simulated
        states, the mean angular velocities, the state observation residuals (zero
        at the first step, which has no observation site) and the sampled-minus-mean
        angular velocities.
    """
    # Check consistency of batch and time dimensions
    N = len(initial_state)
    assert len(commands) == N
    # Allow the documented defaults: no observations for any run
    if observed_states is None:
        observed_states = [None] * N
    if observed_pqr is None:
        observed_pqr = [None] * N

    Ts = [x.shape[0] for x in commands]
    for i in range(N):
        assert initial_state[i].shape[0] == 3
        assert commands[i].shape == (Ts[i], 3)
        assert observed_pqr[i] is None or observed_pqr[i].shape == (Ts[i], 3)
        assert observed_states[i] is None or observed_states[i].shape == (Ts[i], 3)

    # Sample the matrices from the prior
    A = pyro.sample(
        "A",
        dist.Normal(torch.zeros(3, 3, device=device), torch.ones(3, 3, device=device)),
    )
    K = pyro.sample(
        "K",
        dist.Normal(torch.zeros(3, 3, device=device), torch.ones(3, 3, device=device)),
    )
    d = pyro.sample(
        "d", dist.Normal(torch.zeros(3, device=device), torch.ones(3, device=device))
    ).reshape(3, 1)
    log_noise_strength = pyro.sample(
        "log_noise_strength", dist.Normal(torch.tensor(-2.0, device=device), 1.0)
    )
    noise_strength = torch.exp(log_noise_strength)

    states = []
    pqrs = []
    state_observation_noise = []
    action_noise_trajectory = []

    for i in range(N):
        # Each run has its own length; do not reuse a length from another run
        T = Ts[i]
        if T == 0:
            states.append(torch.zeros(0, 3, device=device))
            pqrs.append(torch.zeros(0, 3, device=device))
            state_observation_noise.append(torch.zeros(0, 3, device=device))
            action_noise_trajectory.append(torch.zeros(0, 3, device=device))
            continue

        # Start the simulation from this run's initial state. The trajectory is
        # collected in Python lists and stacked at the end: writing into a
        # preallocated tensor in place would invalidate values autograd has saved
        # for the backward pass.
        state = initial_state[i].reshape(3, 1)
        run_states = [state.reshape(3)]
        run_obs_noise = [torch.zeros_like(run_states[0])]
        run_pqrs = []
        run_action_noise = []

        for t in range(1, T + 1):
            command = commands[i][t - 1].reshape(3, 1)

            # Compute the error
            e = command - state

            # Get the mean velocity based on the system matrices
            pqr_mean = A @ state + K @ e + d
            run_pqrs.append(pqr_mean.reshape(3))

            # Add noise (the whole (3, 1) vector is a single event)
            pqr = pyro.sample(
                f"pqr_{i}_t{t}",
                dist.Normal(
                    pqr_mean, noise_strength * torch.ones_like(pqr_mean)
                ).to_event(2),
                obs=observed_pqr[i][t - 1].reshape(pqr_mean.shape)
                if observed_pqr[i] is not None
                else None,
            )
            run_action_noise.append((pqr - pqr_mean).reshape(3))

            # Only update the dynamics if we're not on the last step
            if t == T:
                break

            # Construct the kinematic matrix mapping body rates to Euler-angle rates:
            #   [1, sin(phi) tan(theta),  cos(phi) tan(theta)]
            #   [0, cos(phi),            -sin(phi)           ]
            #   [0, sin(phi)/cos(theta),  cos(phi)/cos(theta)]
            roll, pitch = state[0, 0], state[1, 0]
            one = torch.ones_like(roll)
            zero = torch.zeros_like(roll)
            sin_roll, cos_roll = torch.sin(roll), torch.cos(roll)
            tan_pitch, cos_pitch = torch.tan(pitch), torch.cos(pitch)
            Jinv = torch.stack(
                [
                    torch.stack([one, sin_roll * tan_pitch, cos_roll * tan_pitch]),
                    torch.stack([zero, cos_roll, -sin_roll]),
                    torch.stack([zero, sin_roll / cos_pitch, cos_roll / cos_pitch]),
                ]
            )

            # Integrate the change in state
            next_state = state + dt * Jinv @ pqr
            run_states.append(next_state.reshape(3))
            observed_state = pyro.sample(
                f"state_{i}_t{t}",
                dist.Normal(
                    next_state, observation_noise_scale * torch.ones_like(next_state)
                ).to_event(2),
                obs=observed_states[i][t].reshape(next_state.shape)
                if observed_states[i] is not None
                else None,
            )
            run_obs_noise.append((observed_state - next_state).reshape(3))

            state = next_state

        states.append(torch.stack(run_states))
        pqrs.append(torch.stack(run_pqrs))
        state_observation_noise.append(torch.stack(run_obs_noise))
        action_noise_trajectory.append(torch.stack(run_action_noise))

    return states, pqrs, state_observation_noise, action_noise_trajectory


# Smoke-test the model on the nominal runs, sampling freely from the prior: every run
# gets None in place of nominal_observed_states / nominal_observed_pqr.
sim_states, sim_pqrs, sim_obs_noise, sim_action_noise = model(
    initial_state=nominal_initial_states,
    commands=nominal_commands,
    observed_states=[None for _ in nominal_commands],
    observed_pqr=[None for _ in nominal_commands],
    dt=0.25,
    observation_noise_scale=1e-1,
)

In [76]:
# Define loss functions
def elbo_loss(model, guide, context, num_particles=10, *model_args, **model_kwargs):
    """Monte Carlo estimate of the negative ELBO of ``guide`` under ``model``.

    Args:
        model: Pyro model with sample sites "A", "K", "d" and "log_noise_strength".
        guide: Callable mapping ``context`` to an object exposing
            ``rsample_and_log_prob()``; each sample must hold at least 22 entries,
            laid out as A (9), K (9), d (3), log_noise_strength (1).
        context: Tensor passed to ``guide``; the result is placed on its device.
        num_particles: Number of posterior samples to average over.
        *model_args: Positional arguments forwarded to ``model``.
        **model_kwargs: Keyword arguments forwarded to ``model``.

    Returns:
        Scalar tensor equal to minus the averaged (model log-prob - guide log-prob).
    """
    guide_dist = guide(context)
    elbo = torch.tensor(0.0).to(context.device)

    for _particle in range(num_particles):
        latent, latent_logprob = guide_dist.rsample_and_log_prob()

        # Split the flat sample into the values of the model's latent sites
        site_values = {
            "A": latent[:9].reshape(3, 3),
            "K": latent[9:18].reshape(3, 3),
            "d": latent[18:21].reshape(3),
            "log_noise_strength": latent[21],
        }

        # Run the model with those sites pinned and score the resulting trace
        pinned_model = pyro.poutine.condition(model, data=site_values)
        trace = pyro.poutine.trace(pinned_model).get_trace(
            *model_args, **model_kwargs
        )
        elbo += (trace.log_prob_sum() - latent_logprob) / num_particles

    # Negate so that minimising the returned value maximises the ELBO
    return -elbo


def kl_divergence(p, q, p_contexts, q_contexts, num_particles=10):
    """Sample-based estimate of KL(p || q) for two conditional distributions.

    Args:
        p: Callable mapping contexts to an object exposing
            ``rsample_and_log_prob(shape)``.
        q: Callable mapping contexts to an object exposing ``log_prob(samples)``.
        p_contexts: Contexts for ``p``; a 1-D tensor is treated as a batch of one.
        q_contexts: Contexts for ``q``; a 1-D tensor is treated as a batch of one.
        num_particles: Number of samples drawn from ``p``.

    Returns:
        The difference ``log p(x) - log q(x)`` averaged over the sample dimension
        (dim 0).

    Raises:
        ValueError: If the (batched) contexts do not have the same shape.
    """
    # Promote un-batched contexts to a batch of size one
    p_contexts = p_contexts.unsqueeze(0) if p_contexts.ndim == 1 else p_contexts
    q_contexts = q_contexts.unsqueeze(0) if q_contexts.ndim == 1 else q_contexts

    if p_contexts.shape != q_contexts.shape:
        raise ValueError("Contexts must have the same shape")

    p_dist, q_dist = p(p_contexts), q(q_contexts)

    # Draw from p and score the same samples under both distributions
    samples, log_p = p_dist.rsample_and_log_prob((num_particles,))
    log_q = q_dist.log_prob(samples)

    return (log_p - log_q).mean(dim=0)

In [80]:
# Sanity-check the ELBO on the nominal data with an untrained flow as the guide.
# Gradients are not needed for this check, so it runs under no_grad.
with torch.no_grad():
    flow = zuko.flows.NSF(
        # One feature per latent entry: A (3x3), K (3x3), d (3), log_noise_strength (1)
        features=3 * 3 + 3 * 3 + 3 + 1,
        context=1,
        transforms=2,
        hidden_features=(16, 16),
    ).to(device)
    context = torch.tensor([0.0], device=device)
    print(
        elbo_loss(
            model,
            flow,
            context,
            num_particles=10,
            initial_state=nominal_initial_states,
            commands=nominal_commands,
            observed_states=nominal_observed_states,
            observed_pqr=nominal_observed_pqr,
            dt=0.25,
            observation_noise_scale=1e-1,
        )
    )