In [1]:
%load_ext autoreload
%autoreload 2

import jaxfg
import flax
import jax
from jax import numpy as jnp
import numpy as onp
import matplotlib.pyplot as plt

import data
import networks
from trainer import Trainer

In [2]:
def make_model_params(model: networks.SimpleCNN):
    """Create an initial set of variables for `model`.

    The model is initialized by tracing it on a dummy batch containing a
    single all-zero 120x120x3 image, using the fixed PRNG seed 0 so that
    initialization is reproducible across runs.

    Args:
        model: The CNN whose variables should be created.

    Returns:
        Whatever `model.init` returns for that dummy batch.
    """
    placeholder_batch = onp.zeros((1, 120, 120, 3))
    return model.init(jax.random.PRNGKey(0), placeholder_batch)

# Build the CNN and an Adam optimizer over freshly initialized parameters,
# then restore optimizer state from the step-5000 checkpoint of the "test"
# experiment.
model = networks.SimpleCNN()
trainer = Trainer(experiment_name="test")

optimizer = trainer.load_checkpoint(
    flax.optim.Adam(learning_rate=1e-3).create(target=make_model_params(model)),
    step=5000,
)

[Trainer] Loaded checkpoint: was at step 0, now at 5000


In [3]:
from typing import List

# Load the trajectories stored in the toy HDF5 dataset file.
trajectories: List[data.ToyDatasetStruct] = data.load_trajectories(
    "data/toy_0.hdf5",
)

[TrajectoriesFile-data/toy_0.hdf5] Loading trajectory from file: <HDF5 file "toy_0.hdf5" (mode r)>
[TrajectoriesFile-data/toy_0.hdf5] Existing trajectory count: 1
[TrajectoriesFile-data/toy_0.hdf5] Opening file...
[TrajectoriesFile-data/toy_0.hdf5] Closing file...
(image) Mean, std dev: [0.14879851 0.262405   0.33178994] [1.3279926 1.4022821 1.7159437]
(position) Mean, std dev: [0.25458547 4.1511474 ] [1.0018286 8.166728 ]
(velocity) Mean, std dev: [-0.3679276 -1.0096759] [0.8446168 2.978547 ]


In [4]:
# Mean squared error between the CNN's predicted positions and the labels for
# the first trajectory. NOTE(review): both are presumably in normalized units,
# since the visualization cell below unnormalizes them before plotting; confirm.
print(jnp.mean(
    (model.apply(optimizer.target, trajectories[0].image) - trajectories[0].position) ** 2
))

43.88358


In [5]:
import celluloid
from IPython.display import HTML
from tqdm.auto import tqdm

def visualize_cnn_predictions(trajectory: data.ToyDatasetStruct) -> HTML:
    """Run the CNN on every frame of `trajectory` and animate the result.

    Reads the notebook-level `model` and `optimizer` globals; the parameters
    passed to `model.apply` are `optimizer.target`.

    Args:
        trajectory: A trajectory whose `image` field holds the frames fed to
            the CNN.

    Returns:
        The HTML video produced by `visualize_trajectory`.
    """
    print("Predicting")
    frames = trajectory.image
    num_frames = frames.shape[0]
    # The network was initialized on 120x120x3 inputs (see make_model_params).
    assert frames.shape == (num_frames, 120, 120, 3)

    jitted_apply = jax.jit(model.apply)
    normalized_prediction = data.ToyDatasetStruct(
        normalized=True,
        position=jitted_apply(optimizer.target, frames),
    )
    # Predictions come out normalized; map them back to the labels' units.
    predicted_positions = normalized_prediction.unnormalize().position

    print("Visualizing")
    return visualize_trajectory(positions_pred=predicted_positions, trajectory=trajectory)

def visualize_trajectory(
    positions_pred: jnp.ndarray,
    trajectory: data.ToyDatasetStruct,
    pixel_offset: float = 60.0,
) -> HTML:
    """Animate predicted vs. label positions over a trajectory's frames.

    Each animation frame shows the unnormalized image, the current label and
    predicted positions as scatter points, and the full label and predicted
    paths as lines.

    Args:
        positions_pred: Predicted positions, indexed as `positions_pred[i, 0]`
            (x) and `positions_pred[i, 1]` (y) for frame `i`, in the same
            unnormalized units as `trajectory.unnormalize().position`.
        trajectory: A normalized trajectory; its positions and images are
            unnormalized here for display.
        pixel_offset: Added to every x/y coordinate before plotting. The
            default of 60.0 is half of the 120-pixel image size used in this
            notebook. NOTE(review): presumably positions are measured from
            the image center; confirm.

    Returns:
        An HTML5 video of the animation.
    """
    fig = plt.figure(figsize=(12, 12))
    camera = celluloid.Camera(fig)

    positions_label = trajectory.unnormalize().position
    # The full paths are the same on every frame; shift them once up front.
    path_label = positions_label.T + pixel_offset
    path_pred = positions_pred.T + pixel_offset

    for i, image in enumerate(tqdm(trajectory.image)):
        image_unnormalized = onp.asarray(
            data.ToyDatasetStruct(normalized=True, image=image).unnormalize().image
        )
        # Round and clip before casting: `astype(uint8)` truncates toward zero
        # (so round-trip float error like 254.9999 becomes 254), and casting
        # values outside [0, 255] to uint8 is undefined.
        plt.imshow(onp.clip(onp.round(image_unnormalized), 0, 255).astype(onp.uint8))

        plt.scatter(
            x=positions_label[i, 0] + pixel_offset,
            y=positions_label[i, 1] + pixel_offset,
            c="#7f7",
            label="Label",
        )
        plt.scatter(
            x=positions_pred[i, 0] + pixel_offset,
            y=positions_pred[i, 1] + pixel_offset,
            c="#f77",
            label="Prediction",
        )

        if i == 0:
            # The legend is only created on the first frame.
            plt.legend()

        plt.plot(*path_label, c="#7f7")
        plt.plot(*path_pred, c="#f77")

        camera.snap()

    animation = camera.animate()
    plt.close(fig)
    return HTML(animation.to_html5_video())

# Render CNN predictions vs. labels for the first trajectory; the returned
# HTML video is the cell's displayed output.
visualize_cnn_predictions(trajectories[0])

Predicting
Visualizing


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=50.0), HTML(value='')))




{'loss': 1.2524703496552547e-13}