In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.ode import odeint
import matplotlib.pyplot as plt
from functools import partial # reduces arguments to function by making some subset implicit

from jax.experimental import stax
from jax.experimental import optimizers

# visualization
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from moviepy.editor import ImageSequenceClip
from functools import partial
import proglog
from PIL import Image

from LNN_functions import *
from get_data import *

In [3]:
lagrangian

<function LNN_functions.lagrangian(q, q_dot, m1, m2, l1, l2, g)>

In [2]:
parse_testing_annotations

<function get_data.parse_testing_annotations(csv_file_X, csv_file_Y)>

In [8]:
time_step = 0.01
N = 1500
analytical_step = jax.jit(jax.vmap(partial(rk4_step, f_analytical, t=0.0, h=time_step)))

# x0 = np.array([-0.3*np.pi, 0.2*np.pi, 0.35*np.pi, 0.5*np.pi], dtype=np.float32)
x0 = np.array([3*np.pi/7, 3*np.pi/4, 0, 0], dtype=np.float32)
t = np.arange(N, dtype=np.float32) # time steps 0 to N
%time x_train = jax.device_get(solve_analytical(x0, t)) # dynamics for first N time steps
%time xt_train = jax.device_get(jax.vmap(f_analytical)(x_train)) # time derivatives of each state
%time y_train = jax.device_get(analytical_step(x_train)) # analytical next step
print(x_train)

noise = np.random.RandomState(0).randn(x0.size)
t_test = np.arange(N, 2*N, dtype=np.float32) # time steps N to 2N
%time x_test = jax.device_get(solve_analytical(x0, t_test)) # dynamics for next N time steps
%time xt_test = jax.device_get(jax.vmap(f_analytical)(x_test)) # time derivatives of each state
%time y_test = jax.device_get(analytical_step(x_test)) # analytical next step

CPU times: user 7.71 s, sys: 878 ms, total: 8.59 s
Wall time: 8.14 s
CPU times: user 424 ms, sys: 394 ms, total: 817 ms
Wall time: 1.71 s
CPU times: user 836 ms, sys: 20.1 ms, total: 856 ms
Wall time: 1.16 s
[[ 1.3463968   2.3561945   0.          0.        ]
 [-0.42738873 -1.1811658  -4.6661725  -0.63316566]
 [-1.246472   -0.41255212  4.414217   -0.6399282 ]
 ...
 [-1.5063784  42.331657    2.8762856  -5.039038  ]
 [ 0.21640965 45.60818    -2.4176426   6.5494537 ]
 [-0.4651543  52.325962    4.988856    6.28838   ]]
CPU times: user 6.62 s, sys: 0 ns, total: 6.62 s
Wall time: 6.58 s
CPU times: user 24 ms, sys: 0 ns, total: 24 ms
Wall time: 18.8 ms
CPU times: user 938 µs, sys: 0 ns, total: 938 µs
Wall time: 494 µs


In [3]:
# training data
train_dir = '/raid/cs152/zxaa2018/penndulum/train_and_test_split/dpc_dataset_traintest_4_200_csv/train'
train_dir_video = '/raid/cs152/zxaa2018/penndulum//train_and_test_split/dpc_dataset_traintest_4_200_h264/train'

# test data
test_inputs_dir = '/raid/cs152/zxaa2018/penndulum/train_and_test_split/dpc_dataset_traintest_4_200_csv/test_inputs/'
test_targets_dir = '/raid/cs152/zxaa2018/penndulum/train_and_test_split/dpc_dataset_traintest_4_200_csv/test_targets/'
test_targets_video = '/raid/cs152/zxaa2018/penndulum/train_and_test_split/dpc_dataset_traintest_4_200_h264/test_targets/'

# validation data
validation_inputs_dir = '/raid/cs152/zxaa2018/penndulum/train_and_test_split/dpc_dataset_traintest_4_200_csv/validation_inputs/'
validation_targets_dir = '/raid/cs152/zxaa2018/penndulum/rain_and_test_split/dpc_dataset_traintest_4_200_csv/validation_targets/'
validation_targets_video = '/raid/cs152/zxaa2018/penndulum/train_and_test_split/dpc_dataset_traintest_4_200_h264/validation_targets/'

In [4]:
# some constants
DEFAULT_X_RED, DEFAULT_Y_RED = (240, 232)

PIXEL_DISTANCE_GREEN_TO_RED = 118 # approx. value | calculated with the Pythagorean theorem and averaged: np.sqrt((y_green-y_red)**2 + (x_green-x_red)**2)
PIXEL_DISTANCE_BLUE_TO_GREEN = 90 # approx. value | calculated with the Pythagorean theorem and averaged: np.sqrt((y_blue-y_green)**2 + (x_blue-x_green)**2)


In [5]:
BATCH_SIZE = 1000

# load in all separate files
X_train = []
y_train = []
for filename in tqdm([x for x in os.listdir(train_dir) if not x.startswith('.')]):
    # load in a file
    X_data, y_data = parse_training_annotations(os.path.join(train_dir, filename))

    X_train = X_train + X_data
    y_train = y_train + y_data

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:35<00:00,  1.12it/s]


In [6]:
print(X_train[0:10])

[[[-0.5921029675006285, 0.8058623181890003, 0.28084609155505114, 0.9597528186248019], [-0.6234520062909286, 0.7818616219330733, 0.20891963585531031, 0.9779328124948485], [-0.6507955605498628, 0.7592530134076453, 0.13955877171118825, 0.9902137896628508], [-0.6784961933963939, 0.7346039174593362, 0.07093120900555809, 0.9974812096420713]], [[-0.6234520062909286, 0.7818616219330733, 0.20891963585531031, 0.9779328124948485], [-0.6507955605498628, 0.7592530134076453, 0.13955877171118825, 0.9902137896628508], [-0.6784961933963939, 0.7346039174593362, 0.07093120900555809, 0.9974812096420713], [-0.7062616489257983, 0.7079509045524359, 0.0, 1.0]], [[-0.6507955605498628, 0.7592530134076453, 0.13955877171118825, 0.9902137896628508], [-0.6784961933963939, 0.7346039174593362, 0.07093120900555809, 0.9974812096420713], [-0.7062616489257983, 0.7079509045524359, 0.0, 1.0], [-0.7346120694155456, 0.6784873672140179, -0.06651827470213303, 0.9977852069111125]], [[-0.6784961933963939, 0.7346039174593362, 0.0

In [9]:
%time xt_train = jax.device_get(jax.vmap(f_analytical)(X_train)) # time derivatives of each state
%time y_train = jax.device_get(analytical_step(y_train))

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

CPU times: user 1.01 ms, sys: 134 µs, total: 1.15 ms
Wall time: 580 µs
