In [1]:
%load_ext autoreload
%autoreload 2

import jaxfg
import flax
import jax
from jax import numpy as jnp
import numpy as onp
import matplotlib.pyplot as plt

import data
import networks
from trainer import Trainer

In [2]:
uncertainty_model, uncertainty_optimizer = networks.make_uncertainty_mlp()
uncertainty_optimizer = Trainer(
    experiment_name="initial-uncertainty"
).load_checkpoint(uncertainty_optimizer)

[Trainer] Loaded checkpoint: was at step 0, now at 1900


In [3]:
position_model, position_optimizer = networks.make_position_cnn()
position_optimizer = Trainer(experiment_name="overnight").load_checkpoint(position_optimizer)

[Trainer] Loaded checkpoint: was at step 0, now at 71900


In [22]:
from typing import List

trajectories: List[data.ToyDatasetStruct] = data.load_trajectories(train=False)
display_trajectory = trajectories[13]

[TrajectoriesFile-...tracking_val.hdf5] Loading trajectory from file: <HDF5 file "1hPujtHgYWWHyMikzGTvv1UpL3QrZfN1i-toy_tracking_val.hdf5" (mode r)>
[TrajectoriesFile-...tracking_val.hdf5] Existing trajectory count: 50
[TrajectoriesFile-...tracking_val.hdf5] Opening file...


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))


[TrajectoriesFile-...tracking_val.hdf5] Closing file...


In [23]:
import celluloid
from IPython.display import HTML
from tqdm.auto import tqdm
from typing import Optional

def predict_positions(images: jnp.ndarray):
    """Predict positions from images.
    
    Input is normalized image, output is unnormalized position.
    """
    N = images.shape[0]
    assert images.shape == (N, 120, 120, 3)

    return data.ToyDatasetStruct(
        normalized=True,
        position=jax.jit(position_model.apply)(position_optimizer.target, images),
    ).unnormalize().position

def visualize_cnn_predictions(trajectory: data.ToyDatasetStruct) -> HTML:
    print("Predicting")
    positions_pred = predict_positions(trajectory.image)
    print("Visualizing")
    return visualize_trajectory(trajectory=trajectory, positions_pred=positions_pred)

def visualize_trajectory(
    trajectory: data.ToyDatasetStruct,
    positions_pred: Optional[jnp.ndarray] = None,
) -> HTML:
    
    fig = plt.figure()#(figsize=(12,12))
    camera = celluloid.Camera(fig)

    positions_label = trajectory.unnormalize().position
    for i, image in enumerate(tqdm(trajectory.image)):
        plt.imshow(data.ToyDatasetStruct(
            normalized=True,
            image=image
        ).unnormalize().image.astype(onp.uint8), zorder=-2)
            
        color_label = "#7f7"
        color_pred = "#77f"
        plt.plot(*(positions_label.T + 60.0), c=color_label, linewidth=3, label="Label", zorder=-1)
        if positions_pred is not None:
            plt.plot(*(positions_pred.T + 60.0), c=color_pred, linewidth=3, label="Prediction", zorder=-1)
                
        if i == 0:
            legend = plt.legend()

        plt.scatter(x=positions_label[i, 0] + 60.0, y=positions_label[i, 1] + 60.0, c=color_label, label="Label", s=64, edgecolors='#fff', linewidth=2)
        if positions_pred is not None:
            plt.scatter(x=positions_pred[i, 0] + 60.0, y=positions_pred[i, 1] + 60.0, c=color_pred, label="Prediction", s=64, edgecolor='#fff', linewidth=2)

        camera.snap()

    print("Animating!")
    animation = camera.animate()
    plt.close(fig)
    return HTML(animation.to_html5_video())

visualize_cnn_predictions(trajectory=display_trajectory)

Predicting
Visualizing


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


Animating!


In [24]:
import toy_system

# Make factor graph
def make_factor_graph(
    trajectory_length: int,
    include_dynamics: bool = True, # Set to False to disable dynamics for debugging
) -> jaxfg.core.PreparedFactorGraph:
    variables = []
    factors = []
    for t in range(trajectory_length):
        variables.append(toy_system.StateVariable())
        variables[-1]._timestep = t

        # Add perception constraint
        factors.append(
            toy_system.VisionFactor.make(
                state_variable=variables[-1],
                predicted_position=onp.zeros(2) + 0.8,  # To be populated by network
                scale_tril_inv=onp.identity(2),  # To be populated by network
            )
        )

        # Add dynamics constraint
        if not include_dynamics:
            factors.append(
                toy_system.DummyVelocityFactor.make(variables[-1])
            )
        elif t != 0:
            factors.append(
                toy_system.DynamicsFactor.make(
                    before_variable=variables[-2],
                    after_variable=variables[-1],
                )
            )

    return jaxfg.core.PreparedFactorGraph.from_factors(factors)

graph_template: jaxfg.core.PreparedFactorGraph = make_factor_graph(
    trajectory_length=len(trajectories[0].image)
)

In [25]:
import dataclasses
from typing import Tuple

def update_factor_graph(
    graph_template: jaxfg.core.PreparedFactorGraph,
    trajectory: data.ToyDatasetStruct,
    uncertainty_factor: jnp.ndarray,
) -> Tuple[jaxfg.core.PreparedFactorGraph, jaxfg.core.VariableAssignments]:
    """Update factor graph, and produce guess of initial assignments."""
    predicted_positions = predict_positions(trajectory.image)
    
    # Guess initial assignments
    assignments_dict = {}
    velocity_guesses = jnp.roll(predicted_positions, shift=-1, axis=0) - predicted_positions
    velocity_guesses = velocity_guesses.at[-1].set(velocity_guesses[-2])
    for i, variable in enumerate(graph_template.variables):
        assignments_dict[variable] = toy_system.State.make(
            position=predicted_positions[i],
            velocity=velocity_guesses[i],
        )
    initial_assignments = jaxfg.core.VariableAssignments.from_dict(assignments_dict)

    # Populate positions
    stacked_factors = list(graph_template.stacked_factors)
    stacked_vision_factor: toy_system.VisionFactor = stacked_factors[0]
    assert isinstance(stacked_vision_factor, toy_system.VisionFactor)
    assert predicted_positions.shape == stacked_vision_factor.predicted_position.shape

    uncertainty_factor = jnp.asarray(uncertainty_factor)
    stacked_factors[0] = dataclasses.replace(
        stacked_vision_factor,
        predicted_position=predicted_positions,
        scale_tril_inv=stacked_vision_factor.scale_tril_inv * jnp.reshape(uncertainty_factor, (-1, 1, 1)),
    )
    

    # Return new graph with new factors
    return dataclasses.replace(
        graph_template,
        stacked_factors=stacked_factors
    ), initial_assignments

In [26]:
graph, initial_assignments = update_factor_graph(
    graph_template,
    trajectory=display_trajectory,
    uncertainty_factor=1.0,
)
solved_assignments = graph.solve(initial_assignments)

# Verify that variables are in the order we expect
# This is just to make sure our storage reshape visualization method works
for i, variable in enumerate(graph_template.variables):
    assert variable._timestep == i

# Visualize smoothed trajectory
visualize_trajectory(
    trajectory=display_trajectory,
    positions_pred=solved_assignments.storage.reshape((-1, 4))[:, :2],
)

[GaussNewtonSolver] Starting solve with GaussNewtonSolver(cost_tolerance=1e-05, gradient_tolerance=1e-09, gradient_tolerance_start_step=10, parameter_tolerance=1e-07, inexact_step_eta=0.1, max_iterations=100, verbose=True), initial cost=2521.48583984375
[GaussNewtonSolver] Iteration #0: cost=1003.55505     
[GaussNewtonSolver] Iteration #1: cost=986.14343      
[GaussNewtonSolver] Iteration #2: cost=986.04724      
[GaussNewtonSolver] Iteration #3: cost=986.04645      
Terminating early!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


Animating!


In [27]:
graph, initial_assignments = update_factor_graph(
    graph_template,
    trajectory=display_trajectory,
    uncertainty_factor=0.09389641135931015,
)
solved_assignments = graph.solve(initial_assignments)

# Verify that variables are in the order we expect
# This is just to make sure our storage reshape visualization method works
for i, variable in enumerate(graph_template.variables):
    assert variable._timestep == i

# Visualize smoothed trajectory
visualize_trajectory(
    trajectory=display_trajectory,
    positions_pred=solved_assignments.storage.reshape((-1, 4))[:, :2],
)

[GaussNewtonSolver] Starting solve with GaussNewtonSolver(cost_tolerance=1e-05, gradient_tolerance=1e-09, gradient_tolerance_start_step=10, parameter_tolerance=1e-07, inexact_step_eta=0.1, max_iterations=100, verbose=True), initial cost=2521.48583984375
[GaussNewtonSolver] Iteration #0: cost=218.05492      
[GaussNewtonSolver] Iteration #1: cost=21.568502      
[GaussNewtonSolver] Iteration #2: cost=20.062626      
[GaussNewtonSolver] Iteration #3: cost=20.059406      
[GaussNewtonSolver] Iteration #4: cost=20.059385      
Terminating early!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


Animating!


In [28]:
# To-do:
# [X] Evaluate MSE of positions only on validation set
# [X] Evaluate MSE of hand-tuned smoothing
# [ ] Evaluate MSE of E2E tuned smoothing
# [ ] Evaluate MSE of learned noise model smoothing

In [29]:
def compute_vision_only_mse(trajectory: data.ToyDatasetStruct) -> float:
    """Compute position MSE for a single trajectory."""
    
    positions_predicted = predict_positions(trajectory.image)
    positions_label = trajectory.unnormalize().position

    N = positions_label.shape[0]
    assert positions_label.shape == positions_predicted.shape == (N, 2)

    mse = jnp.mean((positions_predicted - positions_label) ** 2)
    mse /= N

    return mse


# Verify that variables are in the order we expect
# This is just to make sure our storage reshape below works
for i, variable in enumerate(graph_template.variables):
    assert variable._timestep == i
    
def compute_smoother_mse(
    trajectory: data.ToyDatasetStruct,
    uncertainty_factor: float,
):
    graph, initial_assignments = update_factor_graph(
        graph_template,
        trajectory=trajectory,
        uncertainty_factor=uncertainty_factor,
    )
    solved_assignments = graph.solve(
        initial_assignments,
        solver=jaxfg.solvers.FixedIterationGaussNewtonSolver(max_iterations=5, verbose=False)
    )
    positions_predicted = solved_assignments.storage.reshape((-1, 4))[:, :2]
    positions_label = trajectory.unnormalize().position

    N = positions_label.shape[0]
    assert positions_label.shape == positions_predicted.shape == (N, 2)

    mse = jnp.mean((positions_predicted - positions_label) ** 2)
    mse /= N

    return mse

print(
     onp.mean(
        jax.vmap(compute_vision_only_mse)(jaxfg.utils.pytree_stack(*trajectories))
     )
)
print(
     onp.mean(
        jax.vmap(
            jax.partial(compute_smoother_mse, uncertainty_factor=1.0)
        )(jaxfg.utils.pytree_stack(*trajectories))
     )
)
print(
     onp.mean(
        jax.vmap(
            jax.partial(compute_smoother_mse, uncertainty_factor=0.11)
        )(jaxfg.utils.pytree_stack(*trajectories))
     )
)

4.34756
3.7289422
2.999746


In [30]:
def compute_variable_uncertainty_mse(
    trajectory: data.ToyDatasetStruct,
):
    uncertainty_factor = uncertainty_model.apply(
        uncertainty_optimizer.target,
        trajectory.visible_pixels_count.reshape((-1, 1))
    )
    
    graph, initial_assignments = update_factor_graph(
        graph_template,
        trajectory=trajectory,
        uncertainty_factor=uncertainty_factor,
    )
    solved_assignments = graph.solve(
        initial_assignments,
        solver=jaxfg.solvers.FixedIterationGaussNewtonSolver(max_iterations=5, verbose=False)
    )
    positions_predicted = solved_assignments.storage.reshape((-1, 4))[:, :2]
    positions_label = trajectory.unnormalize().position

    N = positions_label.shape[0]
    assert positions_label.shape == positions_predicted.shape == (N, 2)

    mse = jnp.mean((positions_predicted - positions_label) ** 2)
    mse /= N

    return mse

print(
     onp.mean(
        jax.vmap(
            compute_variable_uncertainty_mse
        )(jaxfg.utils.pytree_stack(*trajectories))
     )
)


0.29607016


In [31]:
uncertainty_factor = uncertainty_model.apply(
    uncertainty_optimizer.target,
    display_trajectory.visible_pixels_count.reshape((-1, 1))
)

graph, initial_assignments = update_factor_graph(
    graph_template,
    trajectory=display_trajectory,
    uncertainty_factor=uncertainty_factor,
)
solved_assignments = graph.solve(initial_assignments)

# Verify that variables are in the order we expect
# This is just to make sure our storage reshape visualization method works
for i, variable in enumerate(graph_template.variables):
    assert variable._timestep == i

# Visualize smoothed trajectory
visualize_trajectory(
    trajectory=display_trajectory,
    positions_pred=solved_assignments.storage.reshape((-1, 4))[:, :2],
)

[GaussNewtonSolver] Starting solve with GaussNewtonSolver(cost_tolerance=1e-05, gradient_tolerance=1e-09, gradient_tolerance_start_step=10, parameter_tolerance=1e-07, inexact_step_eta=0.1, max_iterations=100, verbose=True), initial cost=2521.48583984375
[GaussNewtonSolver] Iteration #0: cost=247.87997      
[GaussNewtonSolver] Iteration #1: cost=17.2975        
[GaussNewtonSolver] Iteration #2: cost=16.703444      
[GaussNewtonSolver] Iteration #3: cost=16.70243       
[GaussNewtonSolver] Iteration #4: cost=16.702425      
Terminating early!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


Animating!


In [32]:
uncertainty_factor = uncertainty_model.apply(
    uncertainty_optimizer.target,
    display_trajectory.visible_pixels_count.reshape((-1, 1))
)

graph, initial_assignments = update_factor_graph(
    graph_template,
    trajectory=display_trajectory,
    uncertainty_factor=0.11,
)
solved_assignments = graph.solve(initial_assignments)

# Verify that variables are in the order we expect
# This is just to make sure our storage reshape visualization method works
for i, variable in enumerate(graph_template.variables):
    assert variable._timestep == i

# Visualize smoothed trajectory
visualize_trajectory(
    trajectory=display_trajectory,
    positions_pred=solved_assignments.storage.reshape((-1, 4))[:, :2],
)

[GaussNewtonSolver] Starting solve with GaussNewtonSolver(cost_tolerance=1e-05, gradient_tolerance=1e-09, gradient_tolerance_start_step=10, parameter_tolerance=1e-07, inexact_step_eta=0.1, max_iterations=100, verbose=True), initial cost=2521.48583984375
[GaussNewtonSolver] Iteration #0: cost=220.9557       
[GaussNewtonSolver] Iteration #1: cost=27.713253      
[GaussNewtonSolver] Iteration #2: cost=26.442131      
[GaussNewtonSolver] Iteration #3: cost=26.440897      
[GaussNewtonSolver] Iteration #4: cost=26.440895      
Terminating early!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


Animating!


In [35]:
uncertainty_factor = uncertainty_model.apply(
    uncertainty_optimizer.target,
    display_trajectory.visible_pixels_count.reshape((-1, 1))
)

graph, initial_assignments = update_factor_graph(
    graph_template,
    trajectory=display_trajectory,
    uncertainty_factor=0.5,
)
solved_assignments = graph.solve(initial_assignments)

# Verify that variables are in the order we expect
# This is just to make sure our storage reshape visualization method works
for i, variable in enumerate(graph_template.variables):
    assert variable._timestep == i

# Visualize smoothed trajectory
visualize_trajectory(
    trajectory=display_trajectory,
    positions_pred=solved_assignments.storage.reshape((-1, 4))[:, :2],
)

[GaussNewtonSolver] Starting solve with GaussNewtonSolver(cost_tolerance=1e-05, gradient_tolerance=1e-09, gradient_tolerance_start_step=10, parameter_tolerance=1e-07, inexact_step_eta=0.1, max_iterations=100, verbose=True), initial cost=2521.48583984375
[GaussNewtonSolver] Iteration #0: cost=442.0009       
[GaussNewtonSolver] Iteration #1: cost=370.53314      
[GaussNewtonSolver] Iteration #2: cost=370.43616      
[GaussNewtonSolver] Iteration #3: cost=370.43607      
Terminating early!


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


Animating!
