In [111]:
from distributedDDPG import DDPGAgent, Episodes
import ray
import time
import numpy as np
import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim import Adam
from models import Actor, Critic
import copy
from env import DistributedTSCSEnv

In [79]:
# Single configuration object for the experiment. The whole dict is handed to
# every remote DDPGAgent; the 'env_config' section is also used to construct
# DistributedTSCSEnv directly.
config = dict(
    # Environment settings (consumed by DistributedTSCSEnv).
    env_config=dict(
        nCyl=4,
        k0amax=0.45,
        k0amin=0.35,
        nFreq=11,
        actionRange=0.2,
        episodeLength=100,
    ),
    # Network sizing for the actor and critic.
    model=dict(
        actor_nHidden=2,
        actor_hSize=128,
        critic_nHidden=8,
        critic_hSize=128,
    ),
    # Number of remote DDPGAgent actors to create.
    num_workers=5,
)

In [80]:
# Start the local Ray runtime. ignore_reinit_error makes the cell safe to
# re-run in a live kernel: a second call no longer raises when Ray is
# already initialised.
ray.init(ignore_reinit_error=True)

2020-11-01 21:06:30,442	INFO services.py:1164 -- View the Ray dashboard at http://127.0.0.1:8265


{'node_ip_address': '10.0.0.12',
 'raylet_ip_address': '10.0.0.12',
 'redis_address': '10.0.0.12:6379',
 'object_store_address': '/tmp/ray/session_2020-11-01_21-06-29_895930_2773572/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2020-11-01_21-06-29_895930_2773572/sockets/raylet',
 'webui_url': '127.0.0.1:8265',
 'session_dir': '/tmp/ray/session_2020-11-01_21-06-29_895930_2773572',
 'metrics_export_port': 48653}

In [106]:
# One remote DDPGAgent actor per configured worker; each receives the full config.
agents = [DDPGAgent.remote(config) for _ in range(config['num_workers'])]

N_ROLLOUT_ROUNDS = 5  # number of times every agent is asked for a rollout

# Flat list with one entry per (round, agent) rollout result.
data = []
for _ in range(N_ROLLOUT_ROUNDS):
    # Start the clock *before* dispatching: the remote calls begin executing as
    # soon as they are submitted, so timing only ray.get() under-reports the round.
    start = time.perf_counter()
    # NOTE(review): the 1.0 argument is presumably an exploration/noise scale —
    # confirm against DDPGAgent.rollout_episode.
    futures = [agent.rollout_episode.remote(1.0) for agent in agents]
    # ray.get blocks until every agent in this round has finished.
    data.extend(ray.get(futures))
    print(f'Parallel env data generation: {time.perf_counter() - start}')

Parallel env data generation: 18.601786375045776
Parallel env data generation: 5.025837421417236
Parallel env data generation: 4.2851402759552
Parallel env data generation: 4.287029504776001
Parallel env data generation: 4.35964822769165


In [107]:
# Fields present in every rollout result returned by the agents.
TRANSITION_KEYS = ('states', 'actions', 'rewards', 'next_states', 'dones')

# Merge the per-rollout arrays into one array per field (concatenated along axis 0).
episode_data = {
    key: np.concatenate([rollout[key] for rollout in data])
    for key in TRANSITION_KEYS
}

# Keep the dataset under its own name instead of rebinding `data`: the list of
# raw rollouts stays intact, so this cell can be re-run without re-collecting.
episodes = Episodes(episode_data)
trainLoader = DataLoader(episodes, batch_size=32, shuffle=True, num_workers=2)

In [108]:
# Sizes of the observation and action vectors the networks are built for.
# NOTE(review): hard-coded here; presumably determined by config['env_config']
# (nCyl / nFreq) — confirm against the environment's spaces.
observation_space = 21
action_space = 8

# Use the GPU when one is visible, otherwise fall back to CPU so this cell does
# not crash on CPU-only hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE(review): the positional 2/128 (actor) and 4/128 (critic) literals look
# like the nHidden/hSize entries of config['model'], but that config lists
# critic_nHidden = 8 while 4 is used here — confirm which value is intended.
actor = Actor(
    observation_space,
    2,
    128,
    action_space,
    config['env_config']['actionRange']).to(device)

critic = Critic(
    observation_space,
    4,
    128,
    action_space).to(device)

# Target networks start as exact copies of the online networks.
targetActor = copy.deepcopy(actor).to(device)
targetCritic = copy.deepcopy(critic).to(device)


# Separate optimisers; only the critic uses L2 weight decay.
actorOpt = Adam(actor.parameters(), lr=1e-4)
criticOpt = Adam(critic.parameters(), lr=1e-3, weight_decay=1e-2)

In [109]:
def soft_update(target, source, tau=0.001):
    """Blend ``source``'s parameters into ``target`` in place (Polyak averaging).

    Each target parameter becomes ``(1 - tau) * target + tau * source``.

    Args:
        target: Module whose parameters are overwritten.
        source: Module supplying the new values; its parameters are paired with
            ``target``'s in ``parameters()`` order.
        tau: Interpolation weight given to ``source``. The default of 0.001
            matches the previously hard-coded value.

    Returns:
        None. ``target`` is modified in place.
    """
    # no_grad keeps the in-place write out of autograd without going through
    # the legacy ``.data`` attribute.
    with torch.no_grad():
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.copy_(target_param * (1.0 - tau) + param * tau)

In [110]:
# DDPG updates over the pre-collected transitions served by `trainLoader`.
N_EPOCHS = 10  # full passes over trainLoader
GAMMA = 0.90   # discount applied to the bootstrapped target value

# Train on whatever device the networks already live on instead of assuming CUDA.
device = next(critic.parameters()).device

for epoch in tqdm.tqdm(range(N_EPOCHS)):
    for s, a, r, s_, done in trainLoader:
        ## Get data from memory (moved and cast to float32 once, up front)
        s = s.to(device).float()
        a = a.to(device).float()
        r = r.to(device).float()
        s_ = s_.to(device).float()
        done = done.to(device).float()

        ## Compute target
        # The target networks are never optimised, so skip building their
        # autograd graph; target_q comes out already detached.
        with torch.no_grad():
            maxQ = targetCritic(s_, targetActor(s_))
            # Give r/done the critic's output shape. This is a no-op when the
            # shapes already agree; it guards against silent broadcasting (e.g.
            # (B,) against (B, 1) yielding (B, B)) — the shapes delivered by the
            # dataset are not visible from this cell.
            r = r.reshape(maxQ.shape)
            done = done.reshape(maxQ.shape)
            # Terminal transitions (done == 1) contribute the reward only.
            target_q = r + (1.0 - done) * GAMMA * maxQ

        ## Update the critic network
        criticOpt.zero_grad()
        current_q = critic(s, a)
        criticLoss = F.smooth_l1_loss(current_q, target_q)
        criticLoss.backward()
        criticOpt.step()

        ## Update the actor network
        # This backward pass also writes gradients into the critic; they are
        # discarded by criticOpt.zero_grad() at the start of the next step.
        actorOpt.zero_grad()
        actorLoss = -critic(s, actor(s)).mean()
        actorLoss.backward()
        actorOpt.step()

        ## Copy policy weights over to target net
        soft_update(targetActor, actor)
        soft_update(targetCritic, critic)

100%|██████████| 10/10 [00:18<00:00,  1.84s/it]


In [77]:
# Shut down the Ray runtime that was started earlier with ray.init().
ray.shutdown()

In [112]:
# Instantiate the environment directly in the notebook kernel, using the
# 'env_config' section of the experiment config.
env = DistributedTSCSEnv(config['env_config'])

In [117]:
# Time one episode (reset until `done`) driven by randomly sampled actions.
# perf_counter is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
state = env.reset()
done = False
while not done:
    action = env.action_space.sample()
    state, reward, done, info = env.step(action)
print(time.perf_counter() - start)

3.639162540435791
