In [1]:
import sys

# Detect Colab instead of relying on a hand-edited flag: the `google.colab`
# package is loaded in Colab kernels and absent locally. Assign a literal
# True/False here to force either mode.
is_on_colab = 'google.colab' in sys.modules

if is_on_colab:
    # Mount Google Drive so the project checkout under /content/drive is reachable.
    from google.colab import drive
    drive.mount('/content/drive')
    

In [2]:
#%cd /content/drive/MyDrive/git/optimal_control_jax
#%pip install jaxopt tensor-canvas

In [3]:
import os

# Stop XLA from reserving (nearly) all GPU memory up front. This has to be set
# before JAX initialises its GPU backend, hence this early cell.
os.environ.update({'XLA_PYTHON_CLIENT_PREALLOCATE': 'false'})

In [4]:
from functools import partial

import numpy as np
import pandas as pd
import math
import plotly.express as px
import matplotlib.pyplot as plt 
from ipywidgets import interact, interactive

import tensorflow as tf
from tensorflow.keras.models import Model

# Outside Colab, make no GPU visible to TensorFlow (workaround on mac M1).
# NOTE(review): the guard is "not on Colab", so this also hides GPUs from TF on
# non-Mac machines — confirm that is intended.
if not is_on_colab:
    tf.config.set_visible_devices([], device_type='GPU')

In [6]:
import jax

# Enable 64-bit types immediately after importing JAX, before anything else can
# create JAX arrays: the star-imported project modules below may build
# module-level jnp values at import time, which would otherwise stay 32-bit.
jax.config.update('jax_enable_x64', True)

from jax import jit
from jax import lax
from jax import vmap
import jax.numpy as jnp
from jax.experimental import jax2tf

from jax_control_algorithms.plot_helpers import plot_states, plot_output_comparison, plot_state_comparison
from load_video import load_dataset_single_video, load_video
# NOTE(review): star imports hide where names such as `run_traing_tasks` come
# from; consider importing the needed names explicitly.
from uncontrolled_pendulum import *
from pendulum_nn_models import *
from pendulum_nn_helper import *

# Automatic training

In [10]:
# One row per recorded pendulum video: the scenario name and the path of its clip.
pdf_scenarios = pd.DataFrame(
    [
        ('redapple', 'pendulum_videos/redapple_480p.mov'),
        ('zucchini', 'pendulum_videos/zucchini_480p.mov'),
        ('star',     'pendulum_videos/star_480p.mov'),
        ('tree1',    'pendulum_videos/tree1_480p.mov'),
        ('tree2',    'pendulum_videos/tree2_480p.mov'),
        ('leaf',     'pendulum_videos/leaf_480p.mov'),
    ],
    columns=['scenario', 'vfile'],
)

pdf_scenarios

Unnamed: 0,scenario,vfile
0,redapple,pendulum_videos/redapple_480p.mov
1,zucchini,pendulum_videos/zucchini_480p.mov
2,star,pendulum_videos/star_480p.mov
3,tree1,pendulum_videos/tree1_480p.mov
4,tree2,pendulum_videos/tree2_480p.mov
5,leaf,pendulum_videos/leaf_480p.mov


In [11]:
# Epoch counts shared by every task below.
# NOTE(review): presumably one entry per training stage of `run_traing_tasks` —
# confirm there. A longer alternative schedule appends two further 500-epoch stages:
#   [200, 200, 200, 200,   500, 500, 500, 500,   500, 500, 500]
eps = [200, 200, 200, 200,   500, 500, 500, 500,   500]

# Hyper-parameters common to all training tasks. Previously every task repeated
# this block verbatim; one shared copy keeps the scenarios from drifting apart.
TRAINING_DEFAULTS = {
    'learning_rate' : 0.001,
    'lambda_ml' : 1.0,
    'lambda_mv' : 1.0,
    'lambda_pendulum_fit' : 1.0,
    'lambda_stability' : 0.01,

    'wy1' : 1.0, 'wy2' : 1.0, 'wx1' : 100.0, 'wx2' : 100.0,

    'lambda_exp' : 0.0,

    'n_epochs' : eps,
}

# Scenarios to train, in task order; task ids are assigned 1, 2, 3, ...
# NOTE(review): 'leaf' is listed in pdf_scenarios but has no task here — confirm intended.
TRAINING_SCENARIOS = ['redapple', 'zucchini', 'star', 'tree1', 'tree2']

# Optional per-scenario deviations from TRAINING_DEFAULTS,
# e.g. {'star': {'learning_rate': 0.0005}}.
TASK_OVERRIDES = {}

pdf_training_tasks = pd.DataFrame([
    {
        'scenario' : scenario,
        'task_id' : task_id,
        **TRAINING_DEFAULTS,
        **TASK_OVERRIDES.get(scenario, {}),
    }
    for task_id, scenario in enumerate(TRAINING_SCENARIOS, start=1)
])

pdf_training_tasks

Unnamed: 0,scenario,task_id,learning_rate,lambda_ml,lambda_mv,lambda_pendulum_fit,lambda_stability,wy1,wy2,wx1,wx2,lambda_exp,n_epochs
0,redapple,1,0.001,1.0,1.0,1.0,0.01,1.0,1.0,100.0,100.0,0.0,"[200, 200, 200, 200, 500, 500, 500, 500, 500]"
1,zucchini,2,0.001,1.0,1.0,1.0,0.01,1.0,1.0,100.0,100.0,0.0,"[200, 200, 200, 200, 500, 500, 500, 500, 500]"
2,star,3,0.001,1.0,1.0,1.0,0.01,1.0,1.0,100.0,100.0,0.0,"[200, 200, 200, 200, 500, 500, 500, 500, 500]"
3,tree1,4,0.001,1.0,1.0,1.0,0.01,1.0,1.0,100.0,100.0,0.0,"[200, 200, 200, 200, 500, 500, 500, 500, 500]"
4,tree2,5,0.001,1.0,1.0,1.0,0.01,1.0,1.0,100.0,100.0,0.0,"[200, 200, 200, 200, 500, 500, 500, 500, 500]"


In [None]:
# Train every task in pdf_training_tasks on its scenario's video, logging under
# the given folder.
# NOTE(review): `run_traing_tasks` (sic) comes from one of the star-imported
# pendulum_* modules; its behaviour is not visible in this notebook.
run_traing_tasks(
    pdf_scenarios,
    pdf_training_tasks,
    logfolder='trained_models/autorun_xx',
    jit_compile=False,
)