In [1]:
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
import math
import mlflow
from mlflow.tracking import MlflowClient
import optuna
import os

from TMDP import TMDP
from algorithms import *
from model_functions import *
from policy_utils import *
from experiment_result_utils import *
from constants import *

from RiverSwim import *
from CurriculumMPI import CurriculumMPI
from scipy.stats import uniform

In [2]:
# --- River Swim environment ---
nS = 50                 # number of states (size of every distribution below)
uniform_restart = True  # True: mu = xi; False: mu = original_mu
num_runs = 1            # number of seeded runs executed by run_experiments

small = 5e-3            # forwarded to RiverSwim(small=...)
large = 1.              # forwarded to RiverSwim(large=...)
nA = 2
gamma = 0.99            # forwarded to TMDP(gamma=...)

# original_mu: all probability mass on state 1; used for evaluation and passed to train().
original_mu = np.zeros(nS)
original_mu[1] = 1.

# xi: uniform over the nS - 2 interior states, zero on both end states. Passed to TMDP
# (presumably the teleport distribution -- confirm against TMDP).
xi = np.ones(nS) / (nS - 2)
xi[0] = 0
xi[-1] = 0
assert np.isclose(xi.sum(), 1.0), "xi must be a probability distribution"

# mu is passed to RiverSwim (presumably its initial-state distribution). Take a copy so that an
# in-place update of xi (e.g. by the xi curriculum schedule) cannot silently change mu through
# aliasing, and vice versa.
mu = xi.copy() if uniform_restart else original_mu.copy()

# --- Training / evaluation budget ---
episodes = 10_000_000
checkpoint_step = 10_000
test_episodes = 10_000
discount_tau = True
param_decay = True
debug = False

lam = 1

# Module-level accumulators filled by run_experiment (one entry appended per run).
experiment_results = []
tests_returns = []
tests_lens = []
exp_taus = []

In [3]:
# Naming: one MLflow experiment per (nS, uniform_restart) configuration; the run name
# encodes the algorithm and the restart mode.
run_name = "CurrMPI_" + str(uniform_restart)
experiment_name = "RiverSwim_" + str(nS) + "_" + str(uniform_restart)

# Algorithm label (text before the first "_"), stored under "label" in the saved results.
label = run_name.partition("_")[0]
# Local directory used when logging to MLflow fails.
save_path = "results/" + experiment_name + "/run_" + run_name

# Point MLflow at the tracking server, make sure the experiment exists, and activate it.
mlflow.set_tracking_uri(MLFLOW_URI)
experiment_id = get_or_create_experiment(experiment_name)
mlflow.set_experiment(experiment_name)

In [4]:
# --- Curriculum-MPI hyperparameters ---
tau = 0.3          # passed to TMDP(tau=...) and to tmdp.update_tau()
model_lr = 0.09    # first positional argument of CurriculumMPI.train
pol_lr = 0.05      # second positional argument of CurriculumMPI.train
temp = 3.0         # forwarded as train(temp=...)
final_temp = 1     # forwarded as train(final_temp=...)

epochs = 10        # forwarded as train(epochs=...)
batch_size = 10    # forwarded as train(batch_size=...); also divides the "Avg Return" metric

# Boolean switches forwarded to CurriculumMPI.train.
check_convergence = False
biased = False

In [5]:
def run_experiment(index, seed, run_name):
    """Train and evaluate one seeded Curriculum-MPI agent inside a nested MLflow run.

    Configuration is read from notebook globals (nS, mu, xi, tau, gamma, the learning
    rates, original_mu, ...). The run's results are appended to the module-level lists
    `tests_returns`, `experiment_results`, `tests_lens` and `exp_taus`.

    Args:
        index: position of this run; suffixed to `run_name` to name the nested run and
            stored under "index" in the result dict.
        seed: seed given to `set_policy_seed`, to the RiverSwim environment and to the TMDP.
        run_name: prefix of the nested MLflow run name.
    """
    with mlflow.start_run(nested=True, run_name=f"{run_name}_{index}"):

        # Environment-specific setup: seed the policy RNG and build RiverSwim.
        set_policy_seed(seed)
        river_env = RiverSwim(nS, mu, small=small, large=large, seed=seed)

        # Environment-independent setup: wrap the environment in a TMDP with the xi schedule.
        tmdp = TMDP(
            river_env,
            xi,
            tau=tau,
            gamma=gamma,
            discount_tau=discount_tau,
            seed=seed,
            xi_schedule=river_swim_uniform_curr_xi,
        )
        tmdp.update_tau(tau)

        agent = CurriculumMPI(tmdp, checkpoint_step=checkpoint_step)
        agent.train(
            model_lr,
            pol_lr,
            batch_size=batch_size,
            lam=lam,
            temp=temp,
            final_temp=final_temp,
            episodes=episodes,
            check_convergence=check_convergence,
            param_decay=param_decay,
            biased=biased,
            debug=debug,
            epochs=epochs,
            original_mu=original_mu,
        )

        # Metric: mean of the last 10 reward records, normalised by the batch size.
        recent_records = agent.reward_records[-10:]
        mlflow.log_metric("Avg Return", np.average(recent_records) / batch_size)

        # Evaluate every stored parameter vector on original_mu. The 1e-100 argument is
        # presumably a near-zero (greedy) temperature -- confirm in test_policies_len.
        evaluation = test_policies_len(tmdp, agent.thetas, test_episodes, 1e-100, mu=original_mu)
        returns_per_policy = evaluation[0]
        lens_per_policy = evaluation[1]

        tests_returns.append(returns_per_policy)
        experiment_results.append({
            "thetas": agent.thetas,
            "taus": agent.taus,
            "reward_records": agent.reward_records,
            "test_policies_return": returns_per_policy,
            "index": index,
            "test_pol_len": lens_per_policy,
            "Qs": agent.Qs,
        })
        tests_lens.append(lens_per_policy)
        exp_taus.append(agent.taus)


In [6]:
def run_experiments(num_runs=10):
    """Run `num_runs` seeded Curriculum-MPI experiments under one parent MLflow run.

    Run `i` is delegated to `run_experiment(i, constants.SEEDS[i], run_name)`, which
    appends to the module-level lists `experiment_results`, `tests_returns`,
    `tests_lens` and `exp_taus`. Afterwards:
    - the per-run curves are padded to a common length;
    - the hyperparameters are logged as MLflow tags;
    - the results and the average-test-return figure are logged to MLflow.
    If an MLflow upload fails, the results and the figure are written under
    `save_path` instead.

    Args:
        num_runs: number of runs to execute; must be >= 1.

    Raises:
        ValueError: if `num_runs` < 1 (there would be nothing to aggregate).
    """
    # Local imports: neither name is explicitly imported by the notebook, and
    # `from constants import *` binds the module's contents, not `constants` itself.
    import time
    import constants

    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    # The accumulators are module-level because later cells read experiment_results[-1].
    # Reset them so re-running this cell does not mix in runs from a previous call.
    for accumulator in (experiment_results, tests_returns, tests_lens, exp_taus):
        accumulator.clear()

    # Resolve every seed up front; fails fast if SEEDS is shorter than num_runs.
    seeds = [constants.SEEDS[i] for i in range(num_runs)]

    with mlflow.start_run(run_name=run_name):

        for i, seed in enumerate(seeds):
            run_experiment(i, seed, run_name)

        pad_results = pad_to_same_length(tests_returns)
        pad_lens = pad_to_same_length(tests_lens)
        pad_taus = pad_to_same_length(exp_taus)
        experiment_dict = {
            "tests_returns": pad_results,
            "taus": pad_taus,
            "tests_lens": pad_lens,
            "num_runs": num_runs,
            "label": label,
            "checkpoint_step": checkpoint_step,
            "uniform_restart": uniform_restart,
        }

        mlflow.set_tags(tags={
            # All seeds used, comma-separated (just the single seed when num_runs == 1).
            "seed": ",".join(str(s) for s in seeds),
            "tau": tau,
            "gamma": gamma,
            "checkpoint_step": checkpoint_step,
            "test_episodes": test_episodes,
            "episodes": episodes,
            "model_lr": model_lr,
            "pol_lr": pol_lr,
            "temp": temp,
            "final_temp": final_temp,
            "batch_size": batch_size,
            "epochs": epochs,
            "lam": lam,
            "discount_tau": discount_tau,
            "param_decay": param_decay,
            "small": small,
            "large": large,
        })
        try:
            save_to_mlflow(experiment_dict)
        except Exception as e:
            # Best-effort upload: fall back to a local save instead of losing the results.
            print(e)
            print("Something went wrong saving the experiment results to MLFlow.")
            print("Saving locally instead.")
            time.sleep(5)
            save(save_path, experiment_dict)

        rewards_fig = plot_avg_test_return(tests_returns, f"{run_name} Avg Return on {num_runs} runs")
        try:
            mlflow.log_figure(figure=rewards_fig, artifact_file="reward_image.png")
        except Exception as e:
            print(e)
            print("Something went wrong saving the figure to MLFlow.")
            print("Saving locally instead.")
            time.sleep(5)
            # savefig does not create missing parent directories.
            os.makedirs(save_path, exist_ok=True)
            rewards_fig.savefig(os.path.join(save_path, "reward_image.png"))

In [7]:
run_experiments(num_runs)

Current seed for result reproducibility: 2999
Episode: 100000 reward: 3.5 tau 0.3 batch_len 5 teleports 1307
Episode: 200000 reward: 3.5 tau 0.3 batch_len 5 teleports 379
Episode: 300000 reward: 0.0 tau 0.3 batch_len 0 teleports 27
Episode: 400000 reward: 0.7000500000000001 tau 0.29995 batch_len 1 teleports 224
Episode: 500000 reward: 6.3126 tau 0.2986 batch_len 9 teleports 557
Episode: 600000 reward: 6.339510000000001 tau 0.29561 batch_len 9 teleports 446
Episode: 700000 reward: 4.9710013 tau 0.29087 batch_len 9 teleports 455
Episode: 800000 reward: 0.71409 tau 0.28591 batch_len 1 teleports 158
Episode: 900000 reward: 1.4411939999999999 tau 0.2812 batch_len 3 teleports 106
Episode: 1000000 reward: 3.6224500000000006 tau 0.27551 batch_len 5 teleports 335
Episode: 1100000 reward: 0.0 tau 0.26971 batch_len 0 teleports 18
Episode: 1200000 reward: 4.41618 tau 0.26397 batch_len 6 teleports 330
Episode: 1300000 reward: 4.45212 tau 0.25798 batch_len 6 teleports 491
Episode: 1400000 reward: 4.

In [None]:
# Softmax policy of the last stored thetas of the most recent run
# (second argument 1e-1 is presumably the softmax temperature).
last_run = experiment_results[-1]
print(get_softmax_policy(last_run["thetas"][-1], 1e-1))

[[4.99999595e-01 5.00000405e-01]
 [5.00224535e-01 4.99775465e-01]
 [5.00074242e-01 4.99925758e-01]
 [5.00156586e-01 4.99843414e-01]
 [5.00066930e-01 4.99933070e-01]
 [5.00069192e-01 4.99930808e-01]
 [5.00071809e-01 4.99928191e-01]
 [5.00084453e-01 4.99915547e-01]
 [5.00064295e-01 4.99935705e-01]
 [5.00044531e-01 4.99955469e-01]
 [5.00065432e-01 4.99934568e-01]
 [5.00046898e-01 4.99953102e-01]
 [5.00051028e-01 4.99948972e-01]
 [5.00054619e-01 4.99945381e-01]
 [5.00060811e-01 4.99939189e-01]
 [5.00058650e-01 4.99941350e-01]
 [5.00056644e-01 4.99943356e-01]
 [5.00061310e-01 4.99938690e-01]
 [5.00061301e-01 4.99938699e-01]
 [5.00062129e-01 4.99937871e-01]
 [5.00063220e-01 4.99936780e-01]
 [5.00063904e-01 4.99936096e-01]
 [5.00057626e-01 4.99942374e-01]
 [5.00059929e-01 4.99940071e-01]
 [5.00061504e-01 4.99938496e-01]
 [5.00059689e-01 4.99940311e-01]
 [5.00061757e-01 4.99938243e-01]
 [5.00062610e-01 4.99937390e-01]
 [5.00059413e-01 4.99940587e-01]
 [5.00056565e-01 4.99943435e-01]
 [5.000527

In [None]:
# Last stored "Qs" entry (presumably the Q-table) of the most recent run. No temperature
# argument here: unlike the policy cell, extra print() arguments would just be printed.
print(experiment_results[-1]["Qs"][-1])

[[0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.005     ]
 [0.005      0.00500001]
 [0.005      0.00500003]
 [0.00500002 0.00500011]
 [0.00500006 0.00500035]
 [0.00500021 0.00500138]
 [0.00500078 0.00500468]
 [0.00500268 0.00501728]
 [0.00501025 0.00506059]
 [0.00503569 0.00520426]
