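# GCP-HOLO Gradio Interface (work in progress)

Interactive front end for GCP-HOLO linkage design: pick a target curve, adjust the environment/model parameters, run training (`train.main`), and view animations of the best 1–5 cycle linkages found.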

## Imports

In [1]:
import gradio as gr
from PIL import Image, ImageFilter
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import collections  as mc
import pickle
import sys
from matplotlib.animation import FuncAnimation
import matplotlib.patheffects as pe
from matplotlib.offsetbox import AnchoredText
from IPython import display
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'

# adding subfolder to the system path
sys.path.insert(0, '/home/mitch/home/gcp_holo/linkage_gym/envs')
 
sys.path.insert(0, '/home/mitch/home/gcp_holo/linkage_gym/utils')
from env_utils import normalize_curve, uniquify

sys.path.insert(0, '/home/mitch/home/gcp_holo/')
from train import main

sys.path.insert(0, '/home/mitch/home/gcp_holo/linkage_gym/envs')
from Mech import Mech

from os import walk
import os
from datetime import datetime



In [2]:
# %load_ext tensorboard
# %tensorboard --logdir logs/ --port 6008

## Global Variables

In [3]:
# Module-level configuration shared by the helper functions and the Gradio callbacks below.

# Selectable target curves: display name -> pickle file holding the target trajectory.
_curve_table = [
    ("Jansen", "../data/other_curves/jansen_traj.pkl"),
    ("Klann", "../data/other_curves/klann_traj.pkl"),
    ("Strider", "../data/other_curves/strider_traj.pkl"),
    ("Trot", "../data/other_curves/trot_traj.pkl"),
    ("Dualloop", "../data/test_curves/dualloop.pkl"),
    ("Figure Eight", "../data/test_curves/eight.pkl"),
    ("Hour Glass", "../data/test_curves/hourglass.pkl"),
    ("Letter B", "../data/test_curves/letterB.pkl"),
    ("Moon", "../data/test_curves/moon.pkl"),
    ("Outer Trifoil", "../data/test_curves/outerTrifoil.pkl"),
    ("Triangle", "../data/test_curves/triangle.pkl"),
    ("Trifoil", "../data/test_curves/trifoil.pkl"),
]
curve_names = [name for name, _ in _curve_table]
curve_paths = [path for _, path in _curve_table]
name_to_path_dict = dict(_curve_table)

# `run` passes a curve to training as its file stem ("goal_filename") and its folder ("goal_path");
# both are derived from the pickle path so the three tables cannot drift apart.
curve_id = [os.path.splitext(os.path.basename(path))[0] for path in curve_paths]
name_to_id_dict = dict(zip(curve_names, curve_id))

curve_dir = [os.path.dirname(path) for path in curve_paths]
name_to_dir_dict = dict(zip(curve_names, curve_dir))

# Parameters exposed in the UI, in the order used by `reset` and by `run`'s parameter update.
param_names = [
    "sample_points", "max_nodes", "resolution", "feature_points", "fixed_initial_state", "seed",
    "model", "steps", "n_eval_episodes", "n_envs", "update_freq", "opt_iter", "gamma", "lr",
    "batch_size", "eps_clip", "ent_coef", "buffer_size",
    "log_freq", "save_freq", "cuda", "no_train", "cmaes",
]
param_default_values = [
    20, 11, 11, 1, True, 123,
    "PPO", 50000, 50000, 1, 1000, 1, 0.99, 0.0001,
    1000, 0.2, 0.01, 1000000,
    10000, 100000, "cpu", False, False,
]
default_values_dict = dict(zip(param_names, param_default_values))

# Complete argument set handed to `main`; the UI values above are layered on top of it per run.
full_param_names = [
    'max_nodes', 'resolution', 'bound', 'sample_points', 'feature_points',
    'goal_filename', 'goal_path', 'use_self_loops', 'normalize', 'use_node_type',
    'fixed_initial_state', 'seed', 'ordered', 'body_constraints', 'coupler_constraints',
    'use_gnn', 'batch_normalize', 'model', 'n_envs', 'checkpoint',
    'update_freq', 'opt_iter', 'eps_clip', 'ent_coef', 'gamma',
    'lr', 'batch_size', 'buffer_size', 'steps', 'num_trials',
    'n_eval_episodes', 'm_evals', 'log_freq', 'save_freq', 'wandb_mode',
    'wandb_project', 'verbose', 'cuda', 'no_train', 'cmaes',
]
full_param_values = [
    11, 11, 1.0, 20, 1,
    'jansen_traj', 'data/other_curves', False, False, False,
    True, 123, True, None, None,
    True, True, 'PPO', 1, None,
    1000, 1, 0.2, 0.01, 0.99,
    0.0001, 1000, 1000000, 50000, 1,
    100, 1, 1000, 10000, 'disabled',
    'linkage_sb4', 0, 'cpu', False, False,
]
full_default_value_dict = dict(zip(full_param_names, full_param_values))

# "<n> cycle" -> rendered video path (None until `run` produces one); separate maps for the CMA-ES results.
_cycle_labels = [f"{count} cycle" for count in range(1, 6)]
video_file_dict = dict.fromkeys(_cycle_labels)
video_file_dict_cma = dict.fromkeys(_cycle_labels)

## Helper Functions

In [4]:
def animateWalker(paths, edges, goal, frames, filename="linkage_video.mp4"):
    """
    Creates a video animation of the linkage mechanism and saves it to disk.

        Args:
            paths: 3D numpy array of shape (n_nodes, 2, n_frames) with the x/y coordinates of every node per frame.
                   Node 0 is drawn as the motor joint, node 1 as the crank joint, node 2 as the fixed joint and the
                   last node (paths.shape[0] - 1) as the coupler joint, whose trace is drawn as the foot path.
            edges: 2D array of shape (n_links, 2); each row holds the indices of the two nodes a link connects
            goal: 2D array of shape (2, n_points) with the target curve; it is normalized, then scaled by the coupler
                  path's largest per-axis std and shifted to the coupler path's mean before plotting
            frames: number of frames to render (frame i draws paths[:, :, i])
            filename: requested output file name; passed through `uniquify` first (presumably so an existing
                      file is not overwritten — confirm in env_utils)

        Returns:
            filename: the (possibly uniquified) file name the animation was written to
    """
    coupler_idx = paths.shape[0] - 1
    fixed_ids = [0, 2]
    # Every node except the motor (0) and the fixed joint (2). Built from range() so the order is
    # deterministic rather than relying on set iteration order.
    non_fixed_ids = [k for k in range(coupler_idx + 1) if k not in fixed_ids]
    # Drop the crank (node 1); it gets its own marker.
    pin_ids = non_fixed_ids[1:]

    # Figure with a 0.1 margin around every node position over all frames.
    fig, ax1 = plt.subplots(figsize=(4, 4))
    fig.set_facecolor('white')
    ax1.set_xlim(np.min(paths[:, 0, :]) - 0.1, np.max(paths[:, 0, :]) + 0.1)
    ax1.set_ylim(np.min(paths[:, 1, :]) - 0.1, np.max(paths[:, 1, :]) + 0.1)
    ax1.set_aspect('equal')
    ax1.set_axis_off()

    ## Plot the goal shifted/scaled into the coupler path's frame
    mu = paths[coupler_idx, :, :].mean(1).reshape(-1, 1)
    std = max(paths[coupler_idx, :, :].std(1))
    goal = normalize_curve(goal) * std + mu
    ax1.plot(goal[0, :], goal[1, :], 'y-o', linewidth=4)

    # One grey, black-outlined line per link; positions are filled in per frame by `animate`.
    lines = [
        ax1.plot([], [], '-', color='0.7', linewidth=3,
                 path_effects=[pe.Stroke(linewidth=5, foreground='k'), pe.Normal()])[0]
        for _ in edges
    ]

    fixed_joint_plt, = ax1.plot(paths[2, 0, 0], paths[2, 1, 0], marker='^', color='gray', label="fixed joint", ms=15,
                                path_effects=[pe.Stroke(linewidth=3, foreground='k'), pe.Normal()])
    motor_joint_plt, = ax1.plot(paths[0, 0, 0], paths[0, 1, 0], marker='^', color='magenta', label="motor joint", ms=15,
                                path_effects=[pe.Stroke(linewidth=3, foreground='k'), pe.Normal()])
    crank_joint_plt, = ax1.plot(paths[1, 0, 0], paths[1, 1, 0], marker='o', color='lime', label="crank joint", ms=15)
    revolute_joint_plt, = ax1.plot(paths[pin_ids, 0, 0], paths[pin_ids, 1, 0], 'ro', label="pin joints", ms=15)
    coupler_joint_plt, = ax1.plot(paths[coupler_idx, 0, 0], paths[coupler_idx, 1, 0], 'yo', label="coupler joint", markersize=15)
    # Foot path: the coupler's trace, grown frame by frame.
    footpath, = ax1.plot(paths[coupler_idx, 0, 0], paths[coupler_idx, 1, 0], 'r-o', lw=2, ms=3, alpha=1, mfc='red', mec='red')

    def animate(i):
        """Move every artist to its position in frame `i`."""
        for line, edge in zip(lines, edges):
            line.set_data(paths[edge, 0, i], paths[edge, 1, i])
        # Line2D.set_data expects sequences (scalars were deprecated in Matplotlib 3.7), so single points are wrapped.
        motor_joint_plt.set_data([paths[0, 0, i]], [paths[0, 1, i]])
        crank_joint_plt.set_data([paths[1, 0, i]], [paths[1, 1, i]])
        fixed_joint_plt.set_data([paths[2, 0, i]], [paths[2, 1, i]])
        revolute_joint_plt.set_data(paths[pin_ids, 0, i], paths[pin_ids, 1, i])
        coupler_joint_plt.set_data([paths[coupler_idx, 0, i]], [paths[coupler_idx, 1, i]])
        # Coupler trace up to (not including) the current frame.
        footpath.set_data(paths[coupler_idx, 0, :i], paths[coupler_idx, 1, :i])
        return (*lines, motor_joint_plt, crank_joint_plt, fixed_joint_plt,
                revolute_joint_plt, coupler_joint_plt, footpath)

    ani = FuncAnimation(fig, animate, frames=frames, interval=200, blit=False)
    filename = uniquify(filename)
    try:
        ani.save(filename=filename)
    finally:
        # This runs once per design inside the long-lived Gradio process; unclosed pyplot figures accumulate.
        plt.close(fig)

    return filename

## Main Gradio Code

In [5]:
plt.close("all")

# Build the Gradio UI. Components are laid out in creation order; the event handlers are defined and
# wired further down inside this same `with` block, and the app is launched after it.
with gr.Blocks() as demo:
    # NOTE(review): the text says parameters appear after a curve is chosen, but `details_col` below is
    # created with visible=True; only the three parameter columns start hidden.
    gr.Markdown(
        """
    # GCP-HOLO Linkage Generator
    Once you select a target trajectory, algorithm parameters will become visible
    """
    )
    # "Previous Runs" lists the parameter pickles that `run` writes to ./parameters. The dropdown is
    # only created when that directory already exists at UI-build time; otherwise prev_runs stays None.
    prev_runs = None
    if os.path.isdir(f"./parameters"):
        path = f"parameters/"
        old_param_files = next(walk(path), (None, None, []))[2]  # [] if no file
        prev_runs = gr.Dropdown(label="Previous Runs", choices=old_param_files)
        
    
    # sample_points = gr.Slider(label="Target Curve Points", minimum=6, maximum=200, value=20)
    
    with gr.Column(visible=True) as details_col:
        with gr.Row():
            # Left column: choose the target curve (Load tab) or sketch one (Draw tab, not implemented yet).
            with gr.Column():
                with gr.Tab("Load"):
                    # value="" means no curve is selected until the user picks one.
                    example_target_curves = gr.Dropdown(label="Example Target Curves", choices=curve_names, value="", interactive=True)

                    target_curve_plot = gr.Plot(label="Target Curve Plot")

                with gr.Tab("Draw"):
                    draw_goal = gr.Paint()
                    set_draw_goal_btn = gr.Button("Set Goal")
                    discrete_goal_plot = gr.Plot(label="Drawn Goal Plot", visible=False)
                # Number of points the target curve is resampled to; its maximum is reset by update_curve.
                sample_points = gr.Slider(label="Target Curve Points", minimum=6, maximum=200, value=default_values_dict["sample_points"])

            # Right column: parameter groups, shown one at a time via the `view_params` radio.
            with gr.Column():
                view_params = gr.Radio(label="Update Parameters", choices=["Environment Parameters", "Model Parameters", "Other Parameters"])
                # view_env_params = gr.Checkbox(label="Update Environment Parameters")
                # view_model_params = gr.Checkbox(label="Update Model Parameters")
                # view_other_params = gr.Checkbox(label="Update Other Parameters")
                reset_btn = gr.Button("Reset Params")
                # NOTE(review): `precision` is a gr.Number argument; confirm the installed Gradio's gr.Slider accepts it.
                with gr.Column(visible=False) as env_col:
                    max_nodes = gr.Slider(label="Max Allowable Nodes", minimum=5, maximum=20, step=1, value=default_values_dict["max_nodes"], precision=0)
                    resolution = gr.Slider(label="Scaffold Node Resolution", minimum=5, maximum=20, step=1, value=default_values_dict["resolution"], precision=0)
                    feature_points = gr.Slider(label="Number of Feature Points", minimum=1, maximum=20, step=1, value=default_values_dict["feature_points"], precision=0) #TODO updated to allow maximum = sample_points
                    fixed_initial_state = gr.Checkbox(label="Fixed Initial State", value=default_values_dict["fixed_initial_state"])
                    seed = gr.Number(label="Seed", value=default_values_dict["seed"], precision=0)
                    
                with gr.Column(visible=False) as model_col:
                    model = gr.Radio(label="Select Model", choices=["PPO", "DQN", "Random"], value=default_values_dict["model"])
                    steps = gr.Number(label="Number of Training Steps", value=default_values_dict["steps"], precision=0)
                    n_eval_episodes = gr.Number(label="Number of Evaluation Steps", value=default_values_dict["n_eval_episodes"], precision=0)
                    n_envs = gr.Number(label="Number of Parallel Environments to Run", value=default_values_dict["n_envs"], precision=0)
                    update_freq = gr.Number(label="Update Model Frequency", value=default_values_dict["update_freq"], precision=0)
                    opt_iter = gr.Slider(label="Gradient Steps per Update", minimum=1, maximum=10, step=1, value=default_values_dict["opt_iter"], precision=0)
                    gamma = gr.Slider(label="Discount Factor", minimum=0, maximum=1, step=0.01, value=default_values_dict["gamma"])
                    lr = gr.Slider(label="Learning Rate", minimum=0.00001, maximum=0.001, step=0.000005, value=default_values_dict["lr"])
                    # NOTE(review): no `step` given, so this slider may deliver non-integer batch sizes — confirm.
                    batch_size = gr.Slider(label="Batch Size", minimum=1, maximum=5000, value=default_values_dict["batch_size"])
                    eps_clip = gr.Slider(label="Epsilon Clipping Value (PPO)", minimum=0, maximum=1, step=0.1, value=default_values_dict["eps_clip"])
                    ent_coef = gr.Slider(label="Entropy Coefficient Value (PPO)", minimum=0, maximum=1, step=0.01, value=default_values_dict["ent_coef"])
                    buffer_size = gr.Number(label="Buffer Size (DQN)", value=default_values_dict["buffer_size"], precision=0)
                    
                with gr.Column(visible=False) as other_col: 
                    log_freq = gr.Number(label="How Frequent to Log Training Data", value=default_values_dict["log_freq"], precision=0)
                    save_freq = gr.Number(label="How Save to Model Checkpoint", value=default_values_dict["save_freq"], precision=0)
                    cuda = gr.Radio(label="Device", choices=["cpu", "cuda"], value=default_values_dict["cuda"])
                    no_train = gr.Checkbox(label="Evaluate Only", value=default_values_dict["no_train"])
                    # Hidden until "Evaluate Only" is toggled; its choices are filled by load_checkpoints.
                    checkpoint_dropdown = gr.Dropdown(label="Checkpoints", choices=[], visible=False)
                    cmaes = gr.Checkbox(label="Enable CMA-ES Local Optimization", value=default_values_dict["cmaes"]) 
                
        train_btn = gr.Button("Run")
       
        
    # Currently unused: its change handler (show_solution) is not wired up.
    solution_dropdown = gr.Dropdown(label="Solution Videos", choices=list(video_file_dict.keys()), visible=False)
    
    # Result videos (one per cycle count), hidden until `run` returns.
    with gr.Row(visible=False) as video_row:
        video_cycle_1 = gr.PlayableVideo(label="1 Cycle Solution", visible=False)
        video_cycle_2 = gr.PlayableVideo(label="2 Cycle Solution", visible=False)
        video_cycle_3 = gr.PlayableVideo(label="3 Cycle Solution", visible=False)
        video_cycle_4 = gr.PlayableVideo(label="4 Cycle Solution", visible=False)
        video_cycle_5 = gr.PlayableVideo(label="5 Cycle Solution", visible=False)

    # Same layout for the CMA-ES locally optimized designs.
    with gr.Row(visible=False) as video_row_cma:
        video_cycle_1_cma = gr.PlayableVideo(label="1 Cycle Solution cma", visible=False)
        video_cycle_2_cma = gr.PlayableVideo(label="2 Cycle Solution cma", visible=False)
        video_cycle_3_cma = gr.PlayableVideo(label="3 Cycle Solution cma", visible=False)
        video_cycle_4_cma = gr.PlayableVideo(label="4 Cycle Solution cma", visible=False)
        video_cycle_5_cma = gr.PlayableVideo(label="5 Cycle Solution cma", visible=False)
        
    # "Training started" banner: shown by show_message, hidden again by run.
    training_message = gr.Markdown(visible=False)


    def update_curve(curve_name, slider_value):
        """
        Load the selected example curve, resample it and plot it.

        Args:
            curve_name: display name from `curve_names` (a key of `name_to_path_dict`)
            slider_value: requested number of sample points on the curve

        Returns:
            (slider update, matplotlib figure, details column update). The slider's maximum becomes the
            number of points stored in the curve file and its value is clamped to that maximum. When
            `curve_name` is not a known curve (e.g. the dropdown's initial ""), all three are no-op updates.
        """
        if curve_name not in name_to_path_dict:
            return gr.update(), gr.update(), gr.update()

        plt.close("all")
        with open(name_to_path_dict[curve_name], 'rb') as curve_file:
            goal_curve = pickle.load(curve_file)  # NOTE: sometimes ordering needs to be reversed add [:,::-1]

        n_available = goal_curve.shape[1]
        # Clamp before sampling so the plotted curve matches the slider value sent back to the UI.
        n_samples = min(n_available, int(slider_value))
        idx = np.round(np.linspace(0, n_available - 1, n_samples)).astype(int)
        goal = normalize_curve(goal_curve[:, idx])  #R@normalize_curve(goal_curve[:,::-1][:,idx])
        goal[:, -1] = goal[:, 0]  # close the loop: the last sample is replaced by the first

        fig, ax = plt.subplots()
        ax.plot(goal[0, :], goal[1, :], '-o', c='red', mfc='blue', mec='k')

        return gr.Slider.update(minimum=6, maximum=n_available, value=n_samples), fig, gr.update(visible=True)
    
    def update_slider(curve_name, slider_value):
        """
        Re-plot the selected example curve resampled to `slider_value` points.

        Args:
            curve_name: display name from `curve_names` (a key of `name_to_path_dict`)
            slider_value: number of sample points to draw

        Returns:
            A matplotlib figure, or a no-op update when no known curve is selected (the slider can be
            moved, or reset by "Reset Params", before any curve has been chosen).
        """
        if curve_name not in name_to_path_dict:
            return gr.update()

        plt.close("all")
        with open(name_to_path_dict[curve_name], 'rb') as curve_file:
            goal_curve = pickle.load(curve_file)

        idx = np.round(np.linspace(0, goal_curve.shape[1] - 1, int(slider_value))).astype(int)
        goal = normalize_curve(goal_curve[:, idx])  #R@normalize_curve(goal_curve[:,::-1][:,idx])
        goal[:, -1] = goal[:, 0]  # close the loop: the last sample is replaced by the first

        fig, ax = plt.subplots()
        ax.plot(goal[0, :], goal[1, :], '-o', c='red', mfc='blue', mec='k')

        return fig
    
    def show_params(show):
        """Visibility updates for (env_col, model_col, other_col): only the section named by `show` is shown."""
        sections = ("Environment Parameters", "Model Parameters", "Other Parameters")
        return tuple(gr.update(visible=(show == section)) for section in sections)
    
    def _visibility_update(show):
        """Gradio update that only toggles a component's visibility."""
        return gr.update(visible=show)

    def show_env_params(show):
        """Show/hide the environment-parameter column (handler for the commented-out checkbox wiring)."""
        return _visibility_update(show)

    def show_model_params(show):
        """Show/hide the model-parameter column (handler for the commented-out checkbox wiring)."""
        return _visibility_update(show)

    def show_other_params(show):
        """Show/hide the other-parameter column (handler for the commented-out checkbox wiring)."""
        return _visibility_update(show)
    
    def _render_designs(designs, env_kwargs, frames, suffix, target_dict):
        """Animate every design in `designs` and store the video path in `target_dict["<cycles> cycle"]`."""
        for design in designs.values():
            # Each design is indexed as (node_positions, edges, ...).
            env_kwargs['node_positions'] = design[0]
            env_kwargs['edges'] = design[1]
            mech = Mech(**env_kwargs)
            n_nodes = mech.number_of_nodes()
            cycles = mech.number_of_cycles()
            target_dict[f"{cycles} cycle"] = animateWalker(
                mech.paths[:n_nodes, :, :], mech.get_edges(), mech.goal, frames,
                filename=f"{cycles}_cycle_linkage{suffix}.mp4")

    def run(example_target_curves,
            sample_points, 
            max_nodes, 
            resolution, 
            feature_points, 
            fixed_initial_state, 
            seed,
            model,
            steps,
            n_eval_episodes, 
            n_envs,
            update_freq,
            opt_iter,
            gamma,
            lr,
            batch_size,
            eps_clip,
            ent_coef,
            buffer_size,
            log_freq,
            save_freq,
            cuda,
            no_train,
            checkpoint,
            cmaes):
        """
        Train/evaluate a linkage for the chosen target curve and render the best designs as videos.

        Every argument is the current value of the Gradio component of the same name (see the
        `train_btn.click` wiring); `checkpoint` is the value of `checkpoint_dropdown`.

        Side effects:
            * pickles the raw UI inputs to ./parameters/<timestamp>_<goal>_<model>.pkl (reloadable via "Previous Runs")
            * calls `main(parameters)` from train.py
            * writes one mp4 per best design and records the paths in `video_file_dict` / `video_file_dict_cma`

        Returns:
            13 Gradio updates in the order of the `outputs` of `train_btn.click`: video_row, the 1-5 cycle videos,
            training_message (hidden), video_row_cma, the 1-5 cycle CMA-ES videos. A video is shown only when a
            file was produced for it; the CMA-ES videos additionally require `cmaes`.
        """
        # Work on a copy: updating the module-level defaults in place would leak this run's values
        # (in particular "checkpoint") into every later run.
        parameters = dict(full_default_value_dict)
        parameters.update(zip(param_names, [sample_points, max_nodes, resolution, feature_points,
                                            fixed_initial_state, seed, model, steps, n_eval_episodes,
                                            n_envs, update_freq, opt_iter, gamma, lr, batch_size,
                                            eps_clip, ent_coef, buffer_size, log_freq, save_freq,
                                            cuda, no_train, cmaes]))
        parameters["goal_filename"] = name_to_id_dict[example_target_curves]
        parameters["goal_path"] = name_to_dir_dict[example_target_curves]

        if checkpoint:
            # Checkpoints live under trained/<goal>/<model>/<subdir>/, where <subdir> is the first 10
            # characters of the file name (presumably an %m_%d_%Y date).
            parameters["checkpoint"] = f"trained/{parameters['goal_filename']}/{model}/{checkpoint[:10]}/{checkpoint}"

        # File-name-safe timestamp (no ':' characters).
        now = datetime.now().strftime("%m_%d_%Y_%H_%M_%S")

        ## Log / eval / save location
        parameter_dir = "./parameters"
        os.makedirs(parameter_dir, exist_ok=True)

        # Same order as the outputs of the "Previous Runs" handler, so a run can be reloaded into the UI.
        ui_values = [example_target_curves, sample_points, max_nodes, resolution, feature_points,
                     fixed_initial_state, seed, model, steps, n_eval_episodes, n_envs, update_freq,
                     opt_iter, gamma, lr, batch_size, eps_clip, ent_coef, buffer_size, log_freq,
                     save_freq, cuda, no_train, checkpoint, cmaes]
        param_file = os.path.join(parameter_dir, f"{now}_{parameters['goal_filename']}_{parameters['model']}.pkl")
        with open(param_file, 'wb') as handle:
            pickle.dump(ui_values, handle)
        ## TODO add "Take a break message with Coffee" 
        ## TODO add progress bar 
        ## TODO add output figures and animations
        ## TODO add save button 

        # Forget videos from earlier runs so a cycle count this run did not produce shows no stale file.
        for video_dict in (video_file_dict, video_file_dict_cma):
            for label in video_dict:
                video_dict[label] = None

        env_kwargs, best_designs, best_cmaes = main(parameters)

        frames = int(sample_points)
        _render_designs(best_designs, env_kwargs, frames, "", video_file_dict)
        if cmaes:
            _render_designs(best_cmaes, env_kwargs, frames, "_cmaes", video_file_dict_cma)

        labels = ["1 cycle", "2 cycle", "3 cycle", "4 cycle", "5 cycle"]
        videos = [gr.update(value=video_file_dict[label], visible=video_file_dict[label] is not None)
                  for label in labels]
        videos_cma = [gr.update(value=video_file_dict_cma[label],
                                visible=bool(cmaes) and video_file_dict_cma[label] is not None)
                      for label in labels]
        return (gr.update(visible=True), *videos, gr.update(visible=False), gr.update(visible=True), *videos_cma)

    
    def reset():
        """Default value of every UI parameter, ordered like `param_names` (the order of the reset_btn outputs)."""
        return [default_values_dict[name] for name in param_names]
    
    def show_solution(solution_dropdown):
        """
        Return (and echo to stdout) the video file recorded for the selected "<n> cycle" entry.

        Currently unused: its `solution_dropdown.change` wiring is commented out.
        """
        video_path = video_file_dict[solution_dropdown]
        print(video_path)
        return video_path


    # Picking an example curve loads and plots it and resizes the sample-point slider.
    example_target_curves.change(update_curve,
                                 inputs=[example_target_curves, sample_points],
                                 outputs=[sample_points, target_curve_plot, details_col])

    # Moving the slider re-samples and re-plots the selected curve.
    sample_points.change(update_slider,
                         inputs=[example_target_curves, sample_points],
                         outputs=target_curve_plot)

    # view_env_params.change(show_env_params, view_env_params, env_col)
    # view_model_params.change(show_model_params, view_model_params, model_col)
    # view_other_params.change(show_other_params, view_other_params, other_col)
    # The radio button reveals exactly one of the three parameter columns.
    view_params.change(show_params, inputs=view_params, outputs=[env_col, model_col, other_col])

    # Components restored by "Reset Params", in the same order as `param_names` / `reset()`.
    reset_outputs = [sample_points, max_nodes, resolution, feature_points, fixed_initial_state, seed,
                     model, steps, n_eval_episodes, n_envs, update_freq, opt_iter, gamma, lr,
                     batch_size, eps_clip, ent_coef, buffer_size, log_freq, save_freq, cuda,
                     no_train, cmaes]
    reset_btn.click(reset, inputs=[], outputs=reset_outputs)
    def show_message():
        """Reveal the "training started" banner; `run` hides it again when it returns."""
        banner = " ## Started Training...go grab a coffee while running "
        return gr.update(value=banner, visible=True)

    # Two handlers on the same button: the first shows the banner, the second runs training.
    train_btn.click(show_message, inputs=[], outputs=training_message)

    # Inputs in the order of `run`'s parameters; outputs in the order of `run`'s returned updates.
    train_inputs = [example_target_curves, sample_points, max_nodes, resolution, feature_points,
                    fixed_initial_state, seed, model, steps, n_eval_episodes, n_envs, update_freq,
                    opt_iter, gamma, lr, batch_size, eps_clip, ent_coef, buffer_size, log_freq,
                    save_freq, cuda, no_train, checkpoint_dropdown, cmaes]
    train_outputs = [video_row,
                     video_cycle_1, video_cycle_2, video_cycle_3, video_cycle_4, video_cycle_5,
                     training_message,
                     video_row_cma,
                     video_cycle_1_cma, video_cycle_2_cma, video_cycle_3_cma, video_cycle_4_cma, video_cycle_5_cma]
    train_btn.click(run, inputs=train_inputs, outputs=train_outputs)
    
    # solution_dropdown.change(show_solution, solution_dropdown, gr.Video())
    def load_checkpoints(show, goal_name, model):
        """
        Populate the checkpoint dropdown when "Evaluate Only" is toggled.

        Lists every file one directory below trained/<goal id>/<model>/ (one sub-directory per checkpoint
        group; `run` rebuilds the sub-directory name from the file name's first 10 characters).

        Args:
            show: state of the "Evaluate Only" checkbox; controls the dropdown's visibility/interactivity
            goal_name: display name of the selected target curve
            model: selected model name ("PPO", "DQN", ...)

        Returns:
            Dropdown update with the sorted checkpoint file names as choices (empty when no curve is
            selected or nothing has been trained yet).
        """
        if goal_name not in name_to_id_dict:
            return gr.update(choices=[], visible=show, interactive=show)

        model_dir = f"trained/{name_to_id_dict[goal_name]}/{model}/"
        # Default has [] for the sub-directory list: walk() yields nothing when model_dir does not exist.
        date_dirs = next(walk(model_dir), (None, [], []))[1]
        filenames = sorted(
            filename
            for date_dir in date_dirs
            for filename in next(walk(os.path.join(model_dir, date_dir)), (None, [], []))[2]
        )

        return gr.update(choices=filenames, visible=show, interactive=show)

    no_train.change(load_checkpoints, inputs=[no_train, example_target_curves, model], outputs=checkpoint_dropdown)
    
    if prev_runs is not None:
        # Components restored from a saved run, in the order `run` pickles its inputs.
        prev_run_outputs = [example_target_curves, sample_points, max_nodes, resolution, feature_points,
                            fixed_initial_state, seed, model, steps, n_eval_episodes, n_envs,
                            update_freq, opt_iter, gamma, lr, batch_size, eps_clip, ent_coef,
                            buffer_size, log_freq, save_freq, cuda, no_train, checkpoint_dropdown, cmaes]

        def upload_params(filename):
            """
            Restore the UI inputs that `run` saved to parameters/<filename>.

            Returns the pickled list of values (one per component in `prev_run_outputs`), or no-op updates
            when the dropdown has been cleared.
            """
            if not filename:
                return [gr.update() for _ in prev_run_outputs]
            # Dropdown choices are bare file names; basename() keeps a crafted value from escaping ./parameters.
            # NOTE: pickle.load can execute arbitrary code, so only files written by this app belong in ./parameters.
            with open(f"parameters/{os.path.basename(filename)}", "rb") as param_file:
                return pickle.load(param_file)

        prev_runs.change(upload_params, prev_runs, outputs=prev_run_outputs)
        
    def convert_draw_to_goal(image):
        #! TODO: implement this
        raise NotImplementedError
        
    set_draw_goal_btn.click(convert_draw_to_goal, inputs=draw_goal, outputs=discrete_goal_plot)
    
demo.launch()



Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


main took 164.73ms


Traceback (most recent call last):
  File "/home/mitch/anaconda3/envs/py_walkers/lib/python3.9/site-packages/gradio/routes.py", line 292, in run_predict
    output = await app.blocks.process_api(
  File "/home/mitch/anaconda3/envs/py_walkers/lib/python3.9/site-packages/gradio/blocks.py", line 1007, in process_api
    result = await self.call_function(fn_index, inputs, iterator, request)
  File "/home/mitch/anaconda3/envs/py_walkers/lib/python3.9/site-packages/gradio/blocks.py", line 848, in call_function
    prediction = await anyio.to_thread.run_sync(
  File "/home/mitch/anaconda3/envs/py_walkers/lib/python3.9/site-packages/anyio/to_thread.py", line 31, in run_sync
    return await get_asynclib().run_sync_in_worker_thread(
  File "/home/mitch/anaconda3/envs/py_walkers/lib/python3.9/site-packages/anyio/_backends/_asyncio.py", line 937, in run_sync_in_worker_thread
    return await future
  File "/home/mitch/anaconda3/envs/py_walkers/lib/python3.9/site-packages/anyio/_backends/_asyncio.

Training set to: True
Starting Training...
Finished Training...
Saving Model...
Evaluating Model...
evaluate_policy took 2.510sec
Saving Evaluation Designs...
main took 33.276sec




Training set to: True
Starting Training...
