# Step 1: Set up Parameters

In [1]:
%matplotlib inline
from dotmap import DotMap
from envs import *
import torch as th
import numpy as np
import os
import datetime
import run_learning
import run_eval

from stable_baselines3 import SAC
from stable_baselines3.sac import MlpPolicy
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.logger import configure
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.callbacks import CheckpointCallback

In [4]:
# Root configuration container. Nested fields such as params.envs.<name>.<field>
# are populated purely by attribute assignment in this cell.
params = DotMap()

# Shared timing values: each environment section below copies these into its
# own .dt / .total_time fields.
# presumably dt is the simulation step and total_time the episode duration,
# both in seconds — TODO confirm against the env implementations
dt = 0.01
total_time = 5

#General Params
params.runner = "Mohsin" #just your first name
params.device = "Hybrid Robotics Server"  # free-form label for where the run executes
# presumably step intervals for evaluation / checkpointing (EvalCallback and
# CheckpointCallback are imported above) — verify how run_learning uses them
params.eval_freq = 2500
params.save_freq = 50000
params.timesteps = 500000  # presumably the total training timesteps — TODO confirm
params.gamma = 0.98  # presumably the RL discount factor — TODO confirm
# Requests torch's Tanh as the activation function; presumably forwarded to the
# SAC policy constructor — verify in run_learning.
params.policy_kwargs = dict(activation_fn=th.nn.Tanh)
# Default eps: only the pendulum section reuses this value; pvtol, manipulator
# and cartpole each hard-code 0.01 below.
params.eps = 0.1
params.num_trials = 1

#Env Specific Params
# Pendulum: two separate env instances are constructed here with no arguments;
# the physical constants are stored next to them in the config. How these
# values reach the env objects is not shown in this notebook.
params.envs.pendulum.env = Pendulum.Pendulum() #base env for simulation
params.envs.pendulum.eval_env = Pendulum.Pendulum() #extra env for eval callback
params.envs.pendulum.run = True #if you want run_learning to train on this env
params.envs.pendulum.m = 1 #mass of pendulum
params.envs.pendulum.l = 1 #half the length of pendulum (length to com)
params.envs.pendulum.g = 1 #gravity
params.envs.pendulum.lam = 0.005 #damping coefficient
params.envs.pendulum.eps = params.eps  # reuses the global default (0.1) set above
# Scalar actuation limits (symmetric about zero).
params.envs.pendulum.max_input = 10
params.envs.pendulum.min_input = -10
# Two-element bounds — presumably the per-component range for sampling the
# initial state (angle, angular velocity); confirm in the Pendulum env.
params.envs.pendulum.init_low = [-1, -0.5]
params.envs.pendulum.init_high = [1, 0.5]
# Shared timing values defined at the top of this cell.
params.envs.pendulum.dt = dt
params.envs.pendulum.total_time = total_time

# Quadrotor: training env plus a separate instance for the eval callback.
params.envs.quadrotor.env = Quadrotor.Quadrotor()
params.envs.quadrotor.eval_env = Quadrotor.Quadrotor()
params.envs.quadrotor.run = False  # disabled: same flag the pendulum section sets to True
# Shared timing values defined at the top of this cell.
params.envs.quadrotor.dt = dt
params.envs.quadrotor.total_time = total_time
# NOTE(review): unlike the other env sections, no `eps` is set here — confirm
# the Quadrotor env does not read one.
params.envs.quadrotor.i_xx = 1  # presumably moment of inertia about the x axis — TODO confirm
params.envs.quadrotor.m = 1
params.envs.quadrotor.g = 1
# Two-component input limits; the first component's lower bound is positive
# (5), the second is symmetric (+/-3). Presumably (thrust, torque) — confirm.
params.envs.quadrotor.max_input = np.array([10, 3])
params.envs.quadrotor.min_input = np.array([5, -3])
# Scalar bounds, presumably applied to every initial-state component — confirm.
params.envs.quadrotor.init_low = -1
params.envs.quadrotor.init_high = 1

# PVTOL: training env plus a separate instance for the eval callback.
params.envs.pvtol.env = Pvtol.Pvtol()
params.envs.pvtol.eval_env = Pvtol.Pvtol()
params.envs.pvtol.run = False  # disabled: same flag the pendulum section sets to True
# Shared timing values defined at the top of this cell.
params.envs.pvtol.dt = dt
params.envs.pvtol.total_time = total_time
params.envs.pvtol.eps = 0.01  # hard-coded; does not use the global params.eps (0.1)
params.envs.pvtol.m = 1
params.envs.pvtol.g = 1
# Two-component input limits; the first component's lower bound is positive
# (1), the second is symmetric (+/-1). Meaning of each component is defined by
# the Pvtol env — TODO confirm.
params.envs.pvtol.max_input = np.array([4, 1])
params.envs.pvtol.min_input = np.array([1, -1])
# Scalar bounds, presumably applied to every initial-state component — confirm.
params.envs.pvtol.init_low = -1
params.envs.pvtol.init_high = 1

# Manipulator: training env plus a separate instance for the eval callback.
params.envs.manipulator.env = Manipulator.Manipulator()
params.envs.manipulator.eval_env = Manipulator.Manipulator()
params.envs.manipulator.run = False  # disabled: same flag the pendulum section sets to True
# Shared timing values defined at the top of this cell.
params.envs.manipulator.dt = dt
params.envs.manipulator.total_time = total_time
params.envs.manipulator.eps = 0.01  # hard-coded; does not use the global params.eps (0.1)
# Model constants k1..k3; their physical meaning is defined by the Manipulator
# env and is not visible here — TODO document once confirmed.
params.envs.manipulator.k1 = 1
params.envs.manipulator.k2 = 1
params.envs.manipulator.k3 = 1
# Scalar actuation limits (symmetric about zero).
params.envs.manipulator.max_input = 4
params.envs.manipulator.min_input = -4
# Scalar bounds, presumably applied to every initial-state component — confirm.
params.envs.manipulator.init_low = -1
params.envs.manipulator.init_high = 1

# Cartpole: training env plus a separate instance for the eval callback.
params.envs.cartpole.env = Cartpole.Cartpole()
params.envs.cartpole.eval_env = Cartpole.Cartpole()
params.envs.cartpole.run = False  # disabled: same flag the pendulum section sets to True
# Shared timing values defined at the top of this cell.
params.envs.cartpole.dt = dt
params.envs.cartpole.total_time = total_time
params.envs.cartpole.eps = 0.01  # hard-coded; does not use the global params.eps (0.1)
params.envs.cartpole.g = 1
# presumably mp = pole mass and mc = cart mass — TODO confirm in the Cartpole env
params.envs.cartpole.mp = 0.5
params.envs.cartpole.mc = 1
params.envs.cartpole.l = 0.5  # actually half the pole's length
params.envs.cartpole.f = 5 # force
params.envs.cartpole.lam = 0.01  # presumably a damping coefficient, as in the pendulum section — confirm
# Scalar actuation limits (symmetric about zero).
params.envs.cartpole.max_input = 2
params.envs.cartpole.min_input = -2
# Scalar bounds, presumably applied to every initial-state component — confirm.
params.envs.cartpole.init_low = -1
params.envs.cartpole.init_high = 1

# Step 2: Run Learning

In [None]:
# Start training with the configuration assembled in Step 1.
# NOTE(review): the import cell binds `run_learning` with `import run_learning`,
# so this name is a module object and calling it directly would raise
# TypeError. Presumably the intended call is a function inside that module —
# confirm the entry point's name before running.
run_learning(params)

# Step 3: Evaluate Models

In [None]:
def find_folders(params, check):
    """Return the run folders under ``./Runs`` whose saved params satisfy ``check``.

    Every immediate entry of ``./Runs`` that contains a ``params.pkl`` file is
    unpickled and handed to ``check``; the entry's name is collected when
    ``check`` returns a truthy value. Entries without a ``params.pkl`` are
    skipped. Prints "No folders matching" when nothing is selected.

    Args:
        params: Unused. Kept so existing call sites keep working; the loaded
            pickle is now held in a separate local instead of overwriting
            this argument.
        check: Callable that receives the unpickled params object of one run
            and returns truthy to select that run.

    Returns:
        List of matching folder names (names only, not full paths), in the
        order produced by ``os.listdir``.

    Raises:
        FileNotFoundError: If ``./Runs`` does not exist (from ``os.listdir``).
    """
    # Imported locally because the notebook's import cell does not bring in
    # pickle, which previously made this function fail with NameError.
    import pickle

    runs_root = "./Runs"
    folders = []
    for folder in os.listdir(runs_root):
        param_path = os.path.join(runs_root, folder, "params.pkl")
        if not os.path.exists(param_path):
            continue
        # SECURITY: pickle.load can execute arbitrary code while loading.
        # Only use this on run folders produced by trusted training runs.
        with open(param_path, 'rb') as f:
            run_params = pickle.load(f)
        if check(run_params):
            folders.append(folder)

    if not folders:
        print("No folders matching")
    return folders

In [None]:
# Evaluate a saved run and keep the returned results for plotting below.
# NOTE(review): neither `evaluate` nor `path` is defined in the cells shown.
# Presumably `evaluate` comes from the imported `run_eval` module (or the
# `from envs import *` star import) and `path` is a run folder such as one
# returned by find_folders — confirm and define both before running this cell.
results = evaluate(path)

In [None]:
# Plot the "actions" entry of the evaluation results.
# NOTE(review): `plt` is not imported in the cells shown; presumably this needs
# `import matplotlib.pyplot as plt` unless the star import from `envs`
# provides it — confirm.
plt.plot(results["actions"])
plt.show()