In [1]:
import pybullet_envs
from gym import make
import numpy as np
import torch
from torch import nn
from torch.distributions import Normal
from torch.nn import functional as F
from torch.optim import Adam
import random
from itertools import product
import joblib
from os import makedirs
import uuid
from train import *
import json
from joblib import Parallel, delayed
from matplotlib import pyplot as plt
from itertools import combinations, product, permutations

In [2]:
from train import *

In [3]:
def run(env_name="AntBulletEnv-v0", 
        transitions=1000000, 
        eps=0.2, 
        gamma=0.99, 
        tau=0.002, 
        actor_lr=2e-4, 
        critic_lr=5e-4, 
        sigma=2, 
        c=1, 
        updates_number=1, 
        policy_delay=1, 
        batch_size=128):
    """Train a TD3 agent for a fixed number of transitions and log its progress.

    A fresh directory ``experiments/<uuid4>/`` is created for the run. It
    receives ``params.json`` (the hyper-parameters below, excluding
    ``env_name``), a ``log`` file with one ``step,rmean,rstd`` line per
    evaluation, and a checkpoint ``<step>_<int(rmean)>_<int(rstd)>.pkl``
    written via ``td3.save`` after every evaluation.

    The policy is evaluated on a separate ``test_env`` every
    ``max(1, transitions // 100)`` transitions (about 100 times for large runs).

    Args:
        env_name: gym id passed to ``make`` for both the training and test env.
        transitions: total number of environment steps to collect.
        eps: std of the Gaussian exploration noise added to the policy action.
        gamma, tau, sigma, c, updates_number, policy_delay, batch_size:
            forwarded unchanged to ``TD3.update``.
        actor_lr, critic_lr: forwarded to the ``TD3`` constructor.

    Returns:
        dict with keys ``"cfg"`` (the hyper-parameters), ``"step"``, ``"rmean"``
        and ``"rstd"`` (parallel lists, one entry per evaluation).
    """
    from contextlib import ExitStack

    log = {
        "cfg": {
            "transitions": transitions,
            "eps": eps,
            "gamma": gamma,
            "tau": tau,
            "actor_lr": actor_lr,
            "critic_lr": critic_lr,
            "sigma": sigma,
            "c": c,
            "updates_number": updates_number,
            "policy_delay": policy_delay,
            "batch_size": batch_size
        },
        "step": [],
        "rmean": [],
        "rstd": []
    }

    makedirs("experiments", exist_ok=True)
    saved_agent_dir = "experiments/" + str(uuid.uuid4())
    makedirs(saved_agent_dir)
    with open(f"{saved_agent_dir}/params.json", "w") as param:
        json.dump(log["cfg"], param, indent=4)

    # Evaluate roughly 100 times per run; max(1, ...) prevents a modulo-by-zero
    # when transitions < 100.
    eval_every = max(1, transitions // 100)

    # ExitStack releases the log file and both environments even if training
    # raises. joblib may reuse worker processes across runs, so leaked handles
    # would otherwise pile up over a long sweep.
    with ExitStack() as stack:
        log_file = stack.enter_context(open(f"{saved_agent_dir}/log", "a"))
        env = make(env_name)
        stack.callback(env.close)
        test_env = make(env_name)
        stack.callback(test_env.close)

        td3 = TD3(state_dim=env.observation_space.shape[0],
                  action_dim=env.action_space.shape[0],
                  actor_lr=actor_lr, critic_lr=critic_lr)

        state = env.reset()

        for i in range(transitions):
            # Gaussian exploration noise with std `eps`, clipped to [-1, 1].
            # NOTE(review): assumes the env's action bounds are [-1, 1] — confirm
            # against env.action_space for environments other than the default.
            action = td3.act(state)
            action = np.clip(action + eps * np.random.randn(*action.shape), -1, +1)

            next_state, reward, done, _ = env.step(action)

            td3.update((state, action, next_state, reward, done),
                       sigma=sigma, c=c, updates_number=updates_number,
                       policy_delay=policy_delay, batch_size=batch_size,
                       gamma=gamma, tau=tau)

            state = env.reset() if done else next_state

            step = i + 1
            if step % eval_every == 0:
                # NOTE(review): 50 is presumably the number of evaluation
                # episodes — confirm against train.evaluate_policy.
                rewards = evaluate_policy(test_env, td3, 50)
                rmean = np.mean(rewards)
                rstd = np.std(rewards)

                log["step"].append(step)
                log["rmean"].append(rmean)
                log["rstd"].append(rstd)

                # Flush each line so progress is visible while the run is in flight.
                log_file.write(f"{step},{rmean},{rstd}\n")
                log_file.flush()
                td3.save(name=f"{saved_agent_dir}/{step}_{int(rmean)}_{int(rstd)}.pkl")
    return log

# joblib-wrapped `run`: `drun(**cfg)` builds a task for `Parallel` instead of
# executing `run` immediately.
drun = delayed(run)

In [4]:
def plot(log, solved_threshold=2000):
    """Plot the evaluation learning curve stored in a ``run`` log.

    Draws the mean evaluation reward against the number of transitions, a
    shaded ±1 std band around it, and a red horizontal "Solved" line at
    ``solved_threshold``. The figure title is the run's config dict.

    Args:
        log: dict with keys ``"cfg"``, ``"step"``, ``"rmean"`` and ``"rstd"``,
            as returned by ``run``.
        solved_threshold: reward level drawn as the "Solved" reference line
            (2000 by default, the value this notebook always used).

    Returns:
        None. If the log has no evaluation points, a message is printed and
        nothing is drawn.
    """
    steps = np.asarray(log["step"])
    if steps.size == 0:
        # steps.min()/steps.max() below would raise on an empty log, and there
        # is no curve to draw anyway.
        print(f"No evaluation points to plot for cfg={log['cfg']}")
        return

    rmean = np.asarray(log["rmean"])
    rstd = np.asarray(log["rstd"])

    # Use the explicit Axes API throughout instead of mixing in pyplot state.
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.set_title(f"{log['cfg']}")
    ax.set_xlabel("№ of transitions")
    ax.set_ylabel("Mean reward")

    ax.hlines(solved_threshold, steps.min(), steps.max(),
              colors="r", label="Solved")
    ax.plot(steps, rmean, label="TD3")
    ax.fill_between(steps, rmean - rstd, rmean + rstd, alpha=0.5)
    ax.legend()

    plt.show()
    # Release the figure explicitly; this function is called once per run in a loop.
    plt.close(fig)

In [5]:
# Baseline hyper-parameters; the values match the keyword defaults of `run`.
base_config = {
    "transitions": 1_000_000,
    "eps": 0.2,
    "gamma": 0.99,
    "tau": 0.002,
    "actor_lr": 2e-4,
    "critic_lr": 5e-4,
    "sigma": 2,
    "c": 1,
    "updates_number": 1,
    "policy_delay": 1,
    "batch_size": 128
}

# Independent copies of the baseline, including index 0. A per-run tweak such as
# `configs[i]["eps"] = ...` therefore never mutates `base_config` itself.
N_BASELINE_RUNS = 6
configs = [base_config.copy() for _ in range(N_BASELINE_RUNS)]

In [6]:
# eps_lst = [0.01, 0.1, 0.3, 0.4, 0.5]
# for idx, eps in enumerate(eps_lst):
#     configs[idx + 1]["eps"] = eps

In [7]:
# logs = Parallel(n_jobs=6)(drun(**cfg) for cfg in configs)

In [8]:
# for log in logs:
#     plot(log)

In [9]:
# Full-factorial hyper-parameter grid: each key is a `run` keyword argument and
# maps to the values to sweep.
# NOTE(review): 5 * 3**6 * 4 * 4 * 2 = 116640 configs at 3_000_000 transitions
# each — confirm the compute budget before launching the sweep.
param_grid = {
    "transitions": (3_000_000,),
    "eps": (0.1, 0.2, 0.3, 0.4, 0.5),
    "gamma": (0.95, 0.99, 0.999),
    "tau": (1e-4, 1e-3, 1e-2),
    "actor_lr": (1e-4, 2e-4, 1e-3),
    "critic_lr": (2e-4, 4e-4, 2e-3),
    "sigma": (0.5, 1, 2),
    "c": (0.5, 1, 2),
    "updates_number": (1, 2, 4, 8),
    "policy_delay": (1, 2, 4, 8),
    "batch_size": (64, 128),
}

# `product` varies the last axis fastest, matching nested loops over the keys in
# declaration order. The comprehension keeps the per-config values (eps, gamma,
# c, ...) from leaking into the notebook namespace as globals.
configs = [
    dict(zip(param_grid, values))
    for values in product(*param_grid.values())
]

print(len(configs))

116640


In [None]:
%%time

# Launch one `run` per entry of `configs` with 6 parallel jobs; `logs` holds the
# returned log dicts.
# NOTE(review): with the grid built above, this is 116640 runs of 3_000_000
# transitions each — confirm this budget (or subsample `configs`) before executing.
logs = Parallel(n_jobs=6)(drun(**cfg) for cfg in configs)

In [None]:
# Draw one learning-curve figure per finished run, in the order `logs` holds them.
# NOTE(review): with the full grid this produces one figure per config.
for run_log in logs:
    plot(run_log)