In [None]:
import pybullet_envs
from gym import make
import numpy as np
import torch
from torch import nn
from torch.distributions import Normal
from torch.nn import functional as F
from torch.optim import Adam
import random
from itertools import product
import joblib
from os import makedirs
import uuid
from train import *
import json
from joblib import Parallel, delayed
from matplotlib import pyplot as plt
from itertools import combinations, product, permutations
from tqdm.notebook import tqdm
import warnings
warnings.filterwarnings("ignore")

from train import *

In [None]:
def run(env_name="AntBulletEnv-v0",
        transitions=1000000,
        eps=0.2,
        gamma=0.99,
        tau=0.002,
        actor_lr=2e-4,
        critic_lr=5e-4,
        sigma=2,
        c=1,
        updates_number=1,
        policy_delay=1,
        batch_size=128,
        buffer_size=200000,
        start_training=None,
        evaluate_every=None,
        seed=42):
    """Train a TD3 agent on a gym environment and periodically evaluate it.

    A fresh ``experiments/<uuid4>/`` directory is created for every call. It
    receives ``params.json`` (the resolved configuration), ``log.csv`` (one
    ``step,rmean,rstd`` row per evaluation) and a checkpoint for every
    evaluation whose mean reward exceeds 2500.

    Args:
        env_name: id passed to ``gym.make`` for both the training and the
            evaluation environment.
        transitions: total number of environment steps to perform.
        eps: scale of the Gaussian noise added to the policy action before it
            is clipped to [-1, 1].
        gamma, tau, sigma, c, updates_number, policy_delay, batch_size:
            forwarded unchanged to ``TD3.update`` (semantics are defined in
            ``train`` -- presumably discount, soft-update rate, target noise
            std / clip, updates per step, actor delay; TODO confirm there).
        actor_lr, critic_lr, buffer_size: forwarded to the ``TD3`` constructor.
        start_training: steps with index ``i <= start_training`` use uniformly
            random actions and only fill the replay buffer. ``None`` selects
            ``buffer_size // 10``.
        evaluate_every: evaluation period in steps. ``None``/``0`` selects
            ``transitions // 100`` (at least 1).
        seed: seed passed to ``set_seed`` and ``evaluate_policy``.

    Returns:
        dict with keys ``"cfg"`` (resolved configuration), ``"step"``,
        ``"rmean"`` and ``"rstd"`` (parallel lists, one entry per evaluation).
    """
    # `is None` rather than truthiness: an explicit 0 must not be silently
    # replaced by the default (callers previously had to pass -1 instead).
    if start_training is None:
        start_training = buffer_size // 10
    # A period of 0 can never be valid (it is used as a modulus below), so any
    # falsy value selects the default; clamp so transitions < 100 still works.
    if not evaluate_every:
        evaluate_every = max(1, transitions // 100)

    log = {
        "cfg": {
            "transitions": transitions,
            "eps": eps,
            "gamma": gamma,
            "tau": tau,
            "actor_lr": actor_lr,
            "critic_lr": critic_lr,
            "sigma": sigma,
            "c": c,
            "updates_number": updates_number,
            "policy_delay": policy_delay,
            "batch_size": batch_size,
            "start_training": start_training,
            # Recorded so that params.json is sufficient to repeat the run;
            # every key is also a keyword argument of run().
            "buffer_size": buffer_size,
            "evaluate_every": evaluate_every,
            "seed": seed,
            "env_name": env_name,
        },
        "step": [],
        "rmean": [],
        "rstd": []
    }

    makedirs("experiments", exist_ok=True)
    saved_agent_dir = "experiments/" + str(uuid.uuid4())
    makedirs(saved_agent_dir)
    with open(f"{saved_agent_dir}/params.json", "w") as param:
        json.dump(log["cfg"], param, indent=4)

    env = make(env_name)
    test_env = make(env_name)
    try:
        # Seed before the agent is built and before the first reset so that
        # neither happens with unseeded randomness. set_seed comes from
        # `train`; presumably it seeds the env and the global RNGs -- confirm.
        set_seed(env, seed=seed)

        td3 = TD3(state_dim=env.observation_space.shape[0],
                  action_dim=env.action_space.shape[0],
                  actor_lr=actor_lr, critic_lr=critic_lr,
                  buffer_size=buffer_size)

        state = env.reset()
        progress = tqdm(range(transitions))

        # Context manager guarantees log.csv is closed even if training fails.
        with open(f"{saved_agent_dir}/log.csv", "a") as log_file:
            for i in progress:
                if i > start_training:
                    # Policy action plus Gaussian exploration noise.
                    action = td3.act(state)
                    action = np.clip(action + eps * np.random.randn(*action.shape), -1, +1)

                    next_state, reward, done, _ = env.step(action)

                    td3.update((state, action, next_state, reward, done),
                               sigma=sigma, c=c, updates_number=updates_number,
                               policy_delay=policy_delay,
                               batch_size=batch_size, gamma=gamma, tau=tau)
                else:
                    # Warm-up: random actions, only stored in the replay buffer.
                    action = np.random.uniform(-1, 1, size=env.action_space.shape)
                    next_state, reward, done, _ = env.step(action)
                    td3.replay_buffer.append((state, action, next_state, reward, done))

                state = env.reset() if done else next_state

                if (i + 1) % evaluate_every == 0:
                    # 5 is the third positional argument of evaluate_policy
                    # (presumably the number of evaluation episodes).
                    rewards = evaluate_policy(test_env, td3, 5, seed=seed)
                    rmean = np.mean(rewards)
                    rstd = np.std(rewards)

                    log["step"].append(i + 1)
                    log["rmean"].append(rmean)
                    log["rstd"].append(rstd)

                    # Flush every row so partial results survive a crash.
                    log_file.write(f"{i + 1},{rmean},{rstd}\n")
                    log_file.flush()

                    if rmean > 2500:
                        td3.save(name=f"{saved_agent_dir}/{i + 1}_{int(rmean)}_{int(rstd)}.pkl")

                    progress.set_description(f"{rmean:0.2f} | {rstd:0.2f}")
    finally:
        # Release simulator resources; matters when many runs share a worker.
        env.close()
        test_env.close()

    return log

# joblib-delayed form of run(): drun(**cfg) describes a call to run() that is
# handed to joblib.Parallel instead of being executed directly.
drun = delayed(run)

In [None]:
def plot(log):
    """Plot the evaluation curve of one training run.

    Draws the mean evaluation reward against the number of transitions, a
    band of one standard deviation around it, and a horizontal reference line
    at 2000 labelled "Solved". The run configuration is used as the title.

    Args:
        log: dict as returned by ``run`` with keys ``"cfg"``, ``"step"``,
            ``"rmean"`` and ``"rstd"`` (the last three of equal length).

    Returns:
        None. If the log contains no evaluation points a message is printed
        and no figure is created.
    """
    steps = np.asarray(log["step"])
    if steps.size == 0:
        # min()/max() of an empty sequence raise ValueError; a run that never
        # reached its first evaluation has nothing to show.
        print("plot: log contains no evaluation points, nothing to draw")
        return None

    rmean = np.asarray(log["rmean"])
    rstd = np.asarray(log["rstd"])

    # Use the Axes object throughout instead of the implicit pyplot state.
    fig, ax = plt.subplots(figsize=(12, 8))
    # The cfg dict is long; wrap it instead of letting it run off the figure.
    ax.set_title(str(log["cfg"]), wrap=True)
    ax.set_xlabel("№ of transitions")
    ax.set_ylabel("Mean reward")

    ax.hlines(2000, steps.min(), steps.max(), colors="r", label="Solved")
    ax.plot(steps, rmean, label="TD3")
    ax.fill_between(steps, rmean - rstd, rmean + rstd, alpha=0.5)

    ax.legend()
    plt.show()
    return None

In [None]:
# Keyword arguments for a single long baseline call of run().
base_config = dict(
    transitions=10_000_000,
    eps=0.1,
    gamma=0.99,
    tau=0.002,
    actor_lr=2e-4,
    critic_lr=5e-4,
    sigma=2,
    c=0.5,
    updates_number=1,
    policy_delay=1,
    batch_size=128,
    buffer_size=200_000,
    # run() takes the training branch when `i > start_training`, so a
    # negative value means every step, including the first, is a training step.
    start_training=-1,
    evaluate_every=10_000,
)

In [None]:
# Train once with base_config and plot the returned evaluation log.
plot(run(**base_config))

In [None]:
# Hyper-parameter grid for the sweep. Each key is a keyword argument of run()
# and maps to the candidate values tried for it; keeping name and values
# together avoids parallel lists that can drift apart.
param_grid = {
    "transitions": (1_000_000,),
    "eps": (0.2, 0.5),
    "gamma": (0.999, 0.99),
    "tau": (1e-3, 1e-2),
    "actor_lr": (1e-4, 1e-3),
    "critic_lr": (2e-4, 2e-3),
    "sigma": (2, 1, 0.5),
    "c": (2, 1, 0.5),
    "updates_number": (1, 2, 4, 8),
    "policy_delay": (1, 2, 4, 8),
    "batch_size": (64, 128, 256),
}

# Full Cartesian product in the order above (last key varies fastest).
# NOTE(review): this is 13,824 configurations of 1M transitions each.
params = product(*param_grid.values())

# One kwargs dict per combination. The comprehension keeps its loop variable
# local, so no names such as `eps` or `c` leak into the notebook namespace.
configs = [dict(zip(param_grid, values)) for values in params]

print(len(configs))

In [None]:
%%time

# Execute run() for every entry of `configs` with joblib, at most 5 jobs at a
# time; `logs` collects the log dict returned by each run.
logs = Parallel(n_jobs=5)(drun(**cfg) for cfg in configs)

In [None]:
# Draw one figure per sweep run.
for log in logs:
    plot(log)