In [1]:
import collections
import random

import gym
import numpy as np
import torch
from torch import optim
from torch.nn import functional as F

from models.bvae import BetaVAE
from models.ddpg import OUNoise
from models.sac import GaussianPolicy, QNetwork
from utils.memory import ReplayBuffer
from utils.sync import soft_sync

# Compute device for all models and tensors; fall back to CPU when CUDA is
# not available so the notebook can still run (slowly) without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Environment used throughout the notebook. reset()/step() return a dict
# observation; only its 'observation' entry is used below. step() is used
# with the legacy 4-tuple API (obs, reward, done, info).
# NOTE(review): this id needs a gym version that still ships the v1 Fetch
# robotics envs — confirm the pinned gym/mujoco versions.
env = gym.make('FetchReach-v1')

## VAE pre-training

In [3]:
# Collect observations from a uniformly random policy to pre-train the VAE.
# Every episode's initial (reset) observation is recorded along with every
# observation reached by env.step().
N_PRETRAIN_EPISODES = 100

pre_training = list()
for _ in range(N_PRETRAIN_EPISODES):
    pre_training.append(env.reset()['observation'])
    done = False
    while not done:
        action = env.action_space.sample()
        next_obs, _, done, _ = env.step(action)
        pre_training.append(next_obs['observation'])

In [4]:
# Beta-VAE over raw observations. The second constructor argument is
# obs_dim // 2; the recorded model repr shows it is used as the latent size.
obs_dim = env.observation_space['observation'].shape[0]
vae_model = BetaVAE(obs_dim, obs_dim // 2).to(device)
vae_optim = optim.Adam(vae_model.parameters(), lr=1e-3)
vae_model.train()

VAE(
  (fc1): Linear(in_features=10, out_features=400, bias=True)
  (fc21): Linear(in_features=400, out_features=5, bias=True)
  (fc22): Linear(in_features=400, out_features=5, bias=True)
  (fc3): Linear(in_features=5, out_features=400, bias=True)
  (fc4): Linear(in_features=400, out_features=10, bias=True)
)

In [5]:
def train_vae(train_data):
    """Run one optimisation pass of the VAE over a sequence of mini-batches.

    Uses the notebook-level ``vae_model``, ``vae_optim`` and ``device``.

    Args:
        train_data: sequence of mini-batches. Each batch is array-like and
            stackable into a 2-D float tensor (numpy array slices and lists
            of 1-D numpy arrays are both passed in this notebook).

    Returns:
        The mean of ``loss['loss']`` over the batches (also printed), or
        ``None`` when ``train_data`` is empty.
    """
    n_batches = len(train_data)
    if n_batches == 0:
        # Nothing to train on; avoid dividing by zero below.
        return None

    total_loss = 0.0
    for batch in train_data:
        # np.asarray first: building a tensor directly from a list of
        # ndarrays is very slow.
        data = torch.as_tensor(np.asarray(batch), dtype=torch.float32,
                               device=device)
        vae_optim.zero_grad()
        results = vae_model(data)
        # M_N is forwarded unchanged to BetaVAE.loss_function; presumably a
        # KL weighting term — TODO confirm against models/bvae.py.
        loss = vae_model.loss_function(*results, M_N=1 / n_batches)
        loss['loss'].backward()
        vae_optim.step()
        total_loss += loss['loss'].item()

    mean_loss = total_loss / n_batches
    print('Train Loss:', mean_loss)
    return mean_loss

Train Loss: 0.23852225937522376
Train Loss: -0.07086417240162309
Train Loss: -0.31805492268923
Train Loss: -0.6278284910397652
Train Loss: -0.738910334232526
Train Loss: -0.777806621331435
Train Loss: -0.792729261593941
Train Loss: -0.8013213536678216
Train Loss: -0.8050079758350666
Train Loss: -0.8104759515860142
Train Loss: -0.8127880906447386
Train Loss: -0.8138412191317632
Train Loss: -0.8146612659478799
Train Loss: -0.8151785777165339
Train Loss: -0.8152213142468379
Train Loss: -0.8160267243018517
Train Loss: -0.8166583883456695
Train Loss: -0.8171907724478306
Train Loss: -0.8164930664576017
Train Loss: -0.8174881369639666
Train Loss: -0.8172295506183918
Train Loss: -0.8176563534981165
Train Loss: -0.8178965556315887
Train Loss: -0.81769859790802
Train Loss: -0.8179870721621391
Train Loss: -0.8180087040632199
Train Loss: -0.8175732906048114
Train Loss: -0.8175020447144141
Train Loss: -0.8178203808955657
Train Loss: -0.818039440191709
Train Loss: -0.8177252793923403
Train Loss: -0.

In [None]:
# Pre-train the VAE on the random-policy observations. The data is converted
# to an array once; each epoch reshuffles it so mini-batches differ between
# epochs instead of repeating the same fixed split.
VAE_PRETRAIN_EPOCHS = 50
VAE_BATCH_SIZE = 128

pre_training_arr = np.array(pre_training)
# At least one section: np.array_split raises for 0 sections.
n_vae_batches = max(1, len(pre_training_arr) // VAE_BATCH_SIZE)
for epoch in range(VAE_PRETRAIN_EPOCHS):
    shuffled = pre_training_arr[np.random.permutation(len(pre_training_arr))]
    batches = np.array_split(shuffled, n_vae_batches)
    train_vae(batches)

## Training policy

In [6]:
lr = 1e-3

# Network sizes. The policy/critic input is the VAE latent state concatenated
# with a latent goal (the episode loop feeds torch.cat((mu, zg))), so it is
# twice the latent size.
latent_dim = env.observation_space['observation'].shape[0] // 2
policy_input_dim = latent_dim * 2
action_dim = env.action_space.shape[0]

# Automatic entropy tuning: target entropy is minus the number of action
# dimensions; the temperature is parameterised as log_alpha.
target_entropy = -float(np.prod(env.action_space.shape))
log_alpha = torch.zeros(1, requires_grad=True, device=device)
alpha_optim = optim.Adam([log_alpha], lr=lr)
# train_policy reads the global `alpha` before it assigns it, so it must
# exist before the first update; exp(0) = 1.
alpha = log_alpha.detach().exp()

policy = GaussianPolicy(policy_input_dim, action_dim).to(device)

crt = QNetwork(policy_input_dim, action_dim).to(device)
tgt_crt = QNetwork(policy_input_dim, action_dim).to(device)
# Start the target critic as an exact copy of the online critic.
tgt_crt.load_state_dict(crt.state_dict())

policy_optim = optim.Adam(policy.parameters(), lr=lr)
crt_optim = optim.Adam(crt.parameters(), lr=lr)

In [7]:
# Replay buffer with capacity for one million transitions.
memory = ReplayBuffer(int(1e6))
# Ornstein-Uhlenbeck noise, applied to the policy's actions via
# noise.get_action(...) while collecting episodes.
noise = OUNoise(env.action_space)

In [8]:
def dist(x, y):
    """Row-wise Euclidean distance between two batches of vectors.

    Args:
        x, y: tensors of matching shape; the norm is taken over dim 1
            (e.g. ``(batch, dim)``).

    Returns:
        Tensor of shape ``(batch, 1)`` holding ``||x_i - y_i||_2``, on the
        inputs' device and with their floating dtype.
    """
    # Computed in torch: no GPU->CPU->GPU round trip per call, and it also
    # works on tensors that require grad.
    return (x - y).norm(p=2, dim=1, keepdim=True)

In [9]:
def train_policy(act_net, crt_net, tgt_crt_net,
                 optimizer_act, optimizer_crt,
                 memory, vae_model, batch_size=128,
                 automatic_entropy_tuning=True, gamma=0.99):
    """Perform one goal-conditioned SAC update from a replay-buffer sample.

    Order of operations: twin-critic update, then the actor update (through
    the freshly updated critic), then the temperature update.

    Args:
        act_net: policy with ``sample(x) -> (action, log_prob, mean)``.
        crt_net: twin critic, ``crt_net(x, a) -> (q1, q2)``.
        tgt_crt_net: target twin critic with the same interface.
        optimizer_act: optimiser over ``act_net``'s parameters.
        optimizer_crt: optimiser over ``crt_net``'s parameters.
        memory: buffer whose ``sample(batch_size)`` returns
            ``(state, action, reward, next_state, done, goal)``. States and
            goals are expected to already be VAE latents (the episode loop
            pushes encoded ``mu``).
        vae_model: unused; kept so existing call sites keep working.
        batch_size: number of transitions per update.
        automatic_entropy_tuning: if True, also update the global
            ``log_alpha`` and refresh the global ``alpha``. If False, the
            global ``alpha`` must already be set.
        gamma: discount factor.

    Returns:
        ``(policy_loss, qf1_loss, qf2_loss, alpha_loss)`` as Python floats.

    Relies on notebook globals: ``device``, ``dist``, ``target_entropy``,
    ``alpha``, ``log_alpha``, ``alpha_optim``.

    NOTE(review): hindsight relabelling used to happen here too. It read the
    global ``episode`` and re-pushed relabelled copies on every call; that
    now lives only in the episode loop.
    """
    global alpha, log_alpha, alpha_optim

    if automatic_entropy_tuning:
        # Derive the temperature from log_alpha so the first call does not
        # depend on a pre-initialised global, and no graph to log_alpha leaks
        # into the critic/actor losses.
        alpha = log_alpha.exp().detach()

    (state_batch, action_batch, _stored_reward,
     next_state_batch, mask_batch, goal_batch) = memory.sample(batch_size)

    state_batch = torch.FloatTensor(state_batch).to(device)
    goal_batch = torch.FloatTensor(goal_batch).to(device)
    next_state_batch = torch.FloatTensor(next_state_batch).to(device)
    action_batch = torch.FloatTensor(action_batch).to(device)
    mask_batch = torch.BoolTensor(mask_batch).to(device).unsqueeze(1)

    # The stored reward is ignored: the reward is the negative distance
    # between the reached next state and the goal.
    reward_batch = -dist(next_state_batch, goal_batch)

    # Goal-conditioned inputs: [state, goal].
    state_batch = torch.cat((state_batch, goal_batch), 1)
    next_state_batch = torch.cat((next_state_batch, goal_batch), 1)

    # ---- Critic update ----
    with torch.no_grad():
        next_action, next_log_pi, _ = act_net.sample(next_state_batch)
        qf1_next_target, qf2_next_target = tgt_crt_net(
            next_state_batch, next_action)
        min_qf_next_target = torch.min(
            qf1_next_target, qf2_next_target) - alpha * next_log_pi
        # No bootstrapping from transitions flagged done.
        min_qf_next_target[mask_batch] = 0.0
        next_q_value = reward_batch + gamma * min_qf_next_target

    # Two Q-functions to mitigate positive bias in policy improvement.
    # JQ = E(st,at)~D[0.5(Q(st,at) - r(st,at) - gamma * V(st+1))^2]
    qf1, qf2 = crt_net(state_batch, action_batch)
    qf1_loss = F.mse_loss(qf1, next_q_value)
    qf2_loss = F.mse_loss(qf2, next_q_value)

    optimizer_crt.zero_grad()
    (qf1_loss + qf2_loss).backward()
    optimizer_crt.step()

    # ---- Actor update ----
    # Built only after the critic step: a graph through the critic that is
    # built before optimizer_crt.step() modifies its weights in place cannot
    # be back-propagated afterwards.
    pi, log_pi, _ = act_net.sample(state_batch)
    qf1_pi, qf2_pi = crt_net(state_batch, pi)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    # Jpi = E[alpha * log pi(a|s) - Q(s, a)]
    policy_loss = (alpha * log_pi - min_qf_pi).mean()

    optimizer_act.zero_grad()
    policy_loss.backward()
    optimizer_act.step()

    # ---- Temperature update ----
    if automatic_entropy_tuning:
        alpha_loss = -(log_alpha * (log_pi + target_entropy).detach()).mean()
        alpha_optim.zero_grad()
        alpha_loss.backward()
        alpha_optim.step()
        alpha = log_alpha.exp().detach()
    else:
        alpha_loss = torch.tensor(0.).to(device)

    return (policy_loss.item(), qf1_loss.item(),
            qf2_loss.item(), alpha_loss.item())


def select_action(policy, state, evaluate=False):
    """Pick an action for a single (unbatched) policy input.

    Args:
        policy: object whose ``sample(x)`` returns a 3-tuple
            ``(sampled_action, log_prob, mean_action)`` for a batch ``x``.
        state: 1-D input tensor; a batch dimension is added internally.
        evaluate: if truthy, return the deterministic third element of
            ``sample`` instead of the sampled action.

    Returns:
        The chosen action for this single input as a numpy array.
    """
    # Acting does not need gradients; skipping graph construction saves
    # work on every environment step.
    with torch.no_grad():
        sampled, _, mean_action = policy.sample(state.unsqueeze(0))
    action = mean_action if evaluate else sampled
    return action.detach().cpu().numpy()[0]

In [None]:
# Rolling store of raw observations for periodic VAE retraining. It starts
# with the pre-training observations, in order, and is capped at 500000
# entries (a full deque drops from the opposite end on append).
data_vae = collections.deque(pre_training, maxlen=500000)

In [10]:
# Main loop. Each episode:
#   1. act in latent space towards a latent goal sampled at reset;
#   2. run SAC updates once the buffer is large enough;
#   3. add hindsight-relabelled copies of the episode's transitions;
#   4. periodically retrain the VAE on the observations collected so far.
N_EPISODES = 1000
N_UPDATES_PER_EPISODE = 10
N_RELABEL_GOALS = 5        # hindsight goals sampled per transition
RENDER_EVERY = 20
VAE_REFRESH_EVERY = 10


def _encode_mean(obs):
    """Return the first output (mean) of ``vae_model.encode`` for one raw
    observation, computed without autograd."""
    with torch.no_grad():
        mean, _ = vae_model.encode(torch.FloatTensor(obs).to(device))
    return mean


update_target = 0
for epi in range(N_EPISODES):
    steps = 0
    state = env.reset()['observation']
    with torch.no_grad():
        mu, logvar = vae_model.encode(torch.FloatTensor(state).to(device))
        # Latent goal for the whole episode, sampled around the encoding of
        # the start state.
        zg = vae_model.reparameterize(mu, logvar)
    done = False
    # Raw (state, action, next_state, done) tuples, kept for relabelling.
    episode = list()
    epi_reward = 0
    while not done:
        # Policy input is [latent state, latent goal].
        action = select_action(policy, torch.cat((mu, zg)))
        action = noise.get_action(action, steps)
        next_state, reward, done, _ = env.step(action)
        next_state = next_state['observation']
        if epi % RENDER_EVERY == 0:
            env.render()
        next_mu = _encode_mean(next_state)
        memory.push(mu.cpu().numpy(), action, reward,
                    next_mu.cpu().numpy(), done, zg.cpu().numpy())
        episode.append((state, action, next_state, done))

        state = next_state
        data_vae.append(state)
        mu = next_mu
        update_target += 1
        steps += 1
        epi_reward += reward
    print('Episode', epi, '-> Reward:', epi_reward)

    if len(memory) > 128:
        for _ in range(N_UPDATES_PER_EPISODE):
            train_policy(policy, crt, tgt_crt, policy_optim,
                         crt_optim, memory, vae_model)
            soft_sync(crt, tgt_crt)

    # Hindsight relabelling. Each transition is stored again with goals
    # taken from the reached next-states of randomly chosen transitions of
    # this episode; its own next-state stays the transition's next-state.
    # Each observation is encoded once, not once per (transition, goal).
    state_codes = [_encode_mean(s).cpu().numpy() for s, _, _, _ in episode]
    next_codes = [_encode_mean(ns).cpu().numpy() for _, _, ns, _ in episode]
    for i, (_, action_i, _, done_i) in enumerate(episode):
        for t in np.random.choice(len(episode), N_RELABEL_GOALS):
            # Reward 0 is a placeholder: train_policy recomputes the reward
            # from next-state and goal.
            memory.push(state_codes[i], action_i, 0, next_codes[i],
                        done_i, next_codes[t])

    if epi % VAE_REFRESH_EVERY == 0 and epi > 0:
        batches = [random.sample(data_vae, 128) for _ in range(10)]
        for _ in range(10):
            train_vae(batches)
            

Episode 0 -> Reward: -50.0
Episode 1 -> Reward: -50.0
Episode 2 -> Reward: -50.0
Episode 3 -> Reward: -50.0
Episode 4 -> Reward: -49.0
Episode 5 -> Reward: -50.0
Episode 6 -> Reward: -50.0
Episode 7 -> Reward: -50.0
Episode 8 -> Reward: -50.0
Episode 9 -> Reward: -50.0
Episode 10 -> Reward: -50.0
Train Loss: -0.8048813879489899
Train Loss: -0.8050292730331421
Train Loss: -0.8052889287471772
Train Loss: -0.8050947844982147
Train Loss: -0.8056789755821228
Train Loss: -0.8053813457489014
Train Loss: -0.8051162898540497
Train Loss: -0.8053756833076477
Train Loss: -0.8051999628543853
Train Loss: -0.8054827213287353
Train Loss: -0.8053594946861267
Train Loss: -0.8054150402545929
Train Loss: -0.805502200126648
Train Loss: -0.8052301943302155
Train Loss: -0.8051408231258392
Train Loss: -0.805518639087677
Train Loss: -0.8054404020309448
Train Loss: -0.8055746197700501
Train Loss: -0.8055885672569275
Train Loss: -0.805409038066864
Train Loss: -0.8054190218448639
Train Loss: -0.8050252199172974
T