In [1]:
import collections
import gym
import numpy as np
import torch
import random


from torch import optim
from torch.nn import functional as F

from models.ddpg import Actor, Critic, OUNoise, train
from models.vae import VAE, loss_function
from utils.memory import ReplayBuffer
from utils.sync import soft_sync

# Prefer the GPU, but fall back to the CPU so the notebook can still run on a
# machine without CUDA instead of failing on the first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Anomaly detection makes autograd report the forward op that produced a bad
# gradient. It is a debugging aid and slows training down noticeably.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7f90eeb58610>

In [2]:
# Create the FetchReach-v1 environment. Its reset()/step() return dict
# observations; this notebook only reads their 'observation' entry.
env = gym.make('FetchReach-v1')

## VAE training

In [3]:
# Build the VAE pre-training set from random rollouts: one initial observation
# up front, then every observation reached while acting uniformly at random.
pre_training = [env.reset()['observation']]
for _ in range(1000):
    # The reset observation of each rollout is not stored, only later ones.
    env.reset()
    done = False
    while not done:
        random_action = env.action_space.sample()
        step_result, _, done, _ = env.step(random_action)
        pre_training.append(step_result['observation'])

In [4]:
# The VAE maps an observation vector to a latent code one third of its size
# (integer division), e.g. 10 -> 3 for this environment.
obs_dim = env.observation_space['observation'].shape[0]
vae_model = VAE(obs_dim, obs_dim // 3)
# Use the notebook-wide `device` rather than a second hard-coded
# torch.device('cuda'), so the model always lives where the other networks do.
vae_model = vae_model.to(device)
vae_optim = optim.Adam(vae_model.parameters(), lr=1e-3)
# Kept as the last expression so the cell still displays the module summary.
vae_model.train()

VAE(
  (fc1): Linear(in_features=10, out_features=400, bias=True)
  (fc21): Linear(in_features=400, out_features=3, bias=True)
  (fc22): Linear(in_features=400, out_features=3, bias=True)
  (fc3): Linear(in_features=3, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=10, bias=True)
)

In [5]:
def train_vae(train_data):
    """Run one optimisation pass of the VAE over the given mini-batches.

    Relies on the notebook-level ``vae_model``, ``vae_optim``,
    ``loss_function``, ``env`` and ``device``.

    Args:
        train_data: Iterable of mini-batches. Each mini-batch is an ndarray or
            a sequence of observation vectors that ``np.asarray`` can stack.

    Returns:
        The mean of the per-batch loss values, or ``0.0`` when ``train_data``
        is empty. Existing callers that ignore the return value are unaffected.
    """
    obs_dim = env.observation_space['observation'].shape[0]
    train_loss = 0.0
    num_batches = 0
    for data in train_data:
        # Stack with NumPy first: building a tensor directly from a Python
        # list of ndarrays goes through a very slow element-wise path.
        data = torch.as_tensor(np.asarray(data), dtype=torch.float32).to(device)
        vae_optim.zero_grad()
        recon_batch, mu, logvar = vae_model(data)
        loss = loss_function(recon_batch, data, mu, logvar, obs_dim)
        loss.backward()
        vae_optim.step()
        train_loss += loss.item()
        num_batches += 1
    return train_loss / num_batches if num_batches else 0.0

# Stack the observations once; the array is the same for every epoch.
pre_training_array = np.array(pre_training)
# np.array_split expects an integer section count. Floor division gives the
# same count the original float was truncated to, and max(1, ...) keeps the
# split valid when fewer than 128 observations were collected.
num_vae_batches = max(1, len(pre_training) // 128)
for epoch in range(10):
    train_vae(np.array_split(pre_training_array, num_vae_batches))

## Training policy

In [6]:
# The networks take two latent codes concatenated together (the training code
# joins a state latent with a goal latent), hence twice the latent size.
latent_dim = env.observation_space['observation'].shape[0] // 3
net_input_dim = latent_dim * 2
action_dim = env.action_space.shape[0]

# Construction order (actor, target actor, critic, target critic) is kept so
# that weight initialisation consumes the RNG in the same sequence.
policy = Actor(net_input_dim, action_dim).to(device)
tgt_policy = Actor(net_input_dim, action_dim).to(device)

crt = Critic(net_input_dim, action_dim).to(device)
tgt_crt = Critic(net_input_dim, action_dim).to(device)

# Target networks start as exact copies of their online counterparts.
for online_net, target_net in ((policy, tgt_policy), (crt, tgt_crt)):
    target_net.load_state_dict(online_net.state_dict())

policy_optim = optim.Adam(policy.parameters(), lr=1e-3)
crt_optim = optim.Adam(crt.parameters(), lr=1e-3)

In [7]:
# Exploration noise, built from the environment's action space; the training
# loop applies it to the policy's actions via noise.get_action(action, steps).
noise = OUNoise(env.action_space)
# Replay buffer of transitions shared by rollouts and hindsight relabelling.
# NOTE(review): 1000000 is presumably the capacity - confirm in utils.memory.
memory = ReplayBuffer(1000000)

In [8]:
def dist(x, y):
    """Row-wise Euclidean distance between two batches of vectors.

    Args:
        x: Tensor of shape (batch, dim).
        y: Tensor of shape (batch, dim) on the same device as ``x``.

    Returns:
        Tensor of shape (batch, 1) holding ``||x_i - y_i||_2`` for each row.
        It is on the inputs' device and detached from the autograd graph, as
        the original NumPy-based result was.
    """
    # Compute on the tensors' own device. The original went
    # GPU -> CPU -> NumPy -> GPU and raised on inputs that require grad.
    diff = (x - y).detach()
    # Tensor.norm with an explicit p is used so older PyTorch releases work too.
    return diff.norm(p=2, dim=1, keepdim=True)

def train_policy(critic, critic_target, actor, actor_target,
          critic_optim, actor_optim, memory, vae_model, batch_size=128, gamma=0.99):
    """Perform one actor update followed by one critic update in latent space.

    A batch is sampled from ``memory``; states and next states are replaced by
    the mean of their VAE encoding. With probability 0.5 the goals of the whole
    batch are replaced by a latent sampled around the encoded states. The
    reward sampled from the buffer is discarded and recomputed as the negative
    distance between the encoded next state and the goal.

    Args:
        critic: Online Q-network, called as ``critic(state_goal, action)``.
        critic_target: Target Q-network used for the bootstrap value.
        actor: Online policy, called as ``actor(state_goal)``.
        actor_target: Target policy used for the bootstrap action.
        critic_optim: Optimiser for ``critic``.
        actor_optim: Optimiser for ``actor``.
        memory: Replay buffer whose ``sample(batch_size)`` returns
            (state, action, reward, next_state, done, goal) batches.
        vae_model: Model providing ``encode`` and ``reparameterize``.
        batch_size: Number of transitions to sample.
        gamma: Discount factor for the bootstrap target (previously a
            hard-coded 0.99, which remains the default).

    Returns:
        Tuple ``(actor_loss, critic_loss)`` of Python floats.
    """
    state_batch, action_batch,\
        reward_batch, next_state_batch, done_batch, goal_batch = memory.sample(batch_size)

    # Everything derived from the VAE is treated as constant input to the
    # actor/critic, so no autograd graph is recorded for it.
    with torch.no_grad():
        state_batch, logvar = vae_model.encode(state_batch)
        next_state_batch, _ = vae_model.encode(next_state_batch)

        # Half of the time, swap the stored goals for latents sampled around
        # the current states.
        if np.random.rand() > 0.5:
            goal_batch = vae_model.reparameterize(state_batch, logvar)

        # Dense reward in latent space; overrides the reward from the buffer.
        reward_batch = -dist(next_state_batch, goal_batch)

        state_batch = torch.cat((state_batch, goal_batch), 1)
        next_state_batch = torch.cat((next_state_batch, goal_batch), 1)

    # Actor step: maximise the critic's value of the actor's own actions.
    actor_loss = -critic(state_batch, actor(state_batch)).mean()

    actor_optim.zero_grad()
    actor_loss.backward()
    actor_optim.step()

    # Critic step: regress onto the one-step bootstrap target. The target is
    # built without gradient tracking instead of being detached afterwards.
    with torch.no_grad():
        next_actions_target = actor_target(next_state_batch)
        q_targets = critic_target(next_state_batch, next_actions_target)
        # NOTE(review): assumes done_batch broadcasts against (batch, 1)
        # rewards and Q-values - confirm the shape returned by memory.sample.
        targets = reward_batch + (1.0 - done_batch)*gamma*q_targets
    # NOTE(review): squeeze() drops every size-1 dim, including the batch dim
    # when batch_size == 1 - confirm the action shape from memory.sample.
    q_values = critic(state_batch, action_batch.squeeze())
    critic_loss = F.smooth_l1_loss(q_values, targets)
    # Also clears the critic gradients accumulated by actor_loss.backward().
    critic_optim.zero_grad()
    critic_loss.backward()
    critic_optim.step()

    return actor_loss.item(), critic_loss.item()

In [9]:
# Rolling pool of observations for periodic VAE refreshes, seeded with the
# random-rollout data. When full, the deque drops its oldest entries.
data_vae = collections.deque(pre_training, maxlen=500000)

for epi in range(1000):
    update_target = 0
    steps = 0
    state = env.reset()['observation']
    # The episode goal zg is a latent sampled around the encoding of the
    # start state; mu tracks the latent mean of the current state.
    mu, logvar = vae_model.encode(torch.FloatTensor(state).to(device))
    mu, logvar = mu.detach(), logvar.detach()
    zg = vae_model.reparameterize(mu, logvar)
    done = False
    episode = list()
    epi_reward = 0
    while not done:
        # Policy input: current state latent concatenated with the goal latent.
        to_fwd = torch.cat((mu, zg))
        action = policy.get_action(to_fwd)
        action = noise.get_action(action, steps)[0]
        next_state, reward, done, _ = env.step(action)
        if epi%20 == 0 and epi > 0:
            env.render()
        next_state = next_state['observation']
        done = 1 if done else 0
        memory.put((state, action, reward, next_state, done, zg.detach().cpu().numpy()))
        episode.append((state, action, next_state, done))

        if memory.size() > 128:
            act_loss, crt_loss = train_policy(crt, tgt_crt, policy, tgt_policy, crt_optim, policy_optim, memory, vae_model)
            # update_target restarts at 0 each episode, so the targets are
            # synced on every even, non-zero step of the episode.
            if update_target%2==0 and update_target > 0:
                soft_sync(policy, tgt_policy)
                soft_sync(crt, tgt_crt)

        state = next_state
        data_vae.append(state)
        mu, _ = vae_model.encode(torch.FloatTensor(next_state).to(device))
        mu = mu.detach()
        update_target += 1
        steps += 1
        epi_reward += reward
    print('Episode', epi, '-> Reward:', epi_reward)

    # Hindsight relabelling: store each transition five more times, each with
    # the latent of a state actually reached at a random step of this episode
    # as its goal. The original sampled the step index t but then encoded the
    # transition's own next_state, so t was dead and every relabelled goal
    # equalled that transition's own next state.
    reached_states = np.array([transition[-2] for transition in episode])
    # Encode all reached states in one batch instead of once per stored copy.
    reached_latents = vae_model.encode(
        torch.FloatTensor(reached_states).to(device))[0].detach().cpu().numpy()
    for past_state, past_action, past_next_state, past_done in episode:
        for t in np.random.choice(len(episode), 5):
            # The stored reward (0) is a placeholder: train_policy recomputes
            # the reward from the latent distance to the goal.
            memory.put((past_state, past_action, 0, past_next_state, past_done,
                        reached_latents[t].copy()))

    # Refresh the VAE every 5 episodes on 10 random mini-batches of 128
    # observations drawn from the rolling pool.
    if epi%5 == 0 and epi > 0:
        batches = [random.sample(data_vae, 128) for _ in range(10)]
        train_vae(batches)
            

Episode 0 -> Reward: -50.0
Episode 1 -> Reward: -50.0
Episode 2 -> Reward: -50.0
Episode 3 -> Reward: -50.0
Episode 4 -> Reward: -50.0
Episode 5 -> Reward: -50.0
Episode 6 -> Reward: -50.0
Episode 7 -> Reward: -50.0
Episode 8 -> Reward: -50.0
Episode 9 -> Reward: -50.0
Episode 10 -> Reward: -50.0
Episode 11 -> Reward: -50.0
Episode 12 -> Reward: -50.0
Episode 13 -> Reward: -50.0
Episode 14 -> Reward: -50.0
Episode 15 -> Reward: -50.0
Episode 16 -> Reward: -50.0
Episode 17 -> Reward: -50.0
Episode 18 -> Reward: -50.0
Episode 19 -> Reward: -50.0
Creating window glfw
Episode 20 -> Reward: -50.0
Episode 21 -> Reward: -50.0
Episode 22 -> Reward: -50.0
Episode 23 -> Reward: -50.0
Episode 24 -> Reward: -50.0
Episode 25 -> Reward: -50.0
Episode 26 -> Reward: -50.0
Episode 27 -> Reward: -50.0
Episode 28 -> Reward: -50.0
Episode 29 -> Reward: -50.0
Episode 30 -> Reward: -50.0
Episode 31 -> Reward: -50.0
Episode 32 -> Reward: -50.0
Episode 33 -> Reward: -50.0
Episode 34 -> Reward: -50.0
Episode 3

KeyboardInterrupt: 