In [1]:
import rlig
from   scipy.stats import multivariate_normal
import numpy as np
import torch

from   rlig.base import Base
from   rlig.agent import TD3
from   rlig.buffer import ReplayBuffer
from   rlig.pytorch.bvae import BetaVAE
import gym

env = gym.make('LunarLanderContinuous-v2')

In [51]:
class ImaginedGoalsAgent(Base):

    def __init__(self, Q, bvae, buffer, gamma = 0.99):
        super().__init__(locals())

    def store(self, *args):
        self.buffer.store(*args)

    def sample(self, n = 1):
        return self.buffer.sample(n)

    def get_action(self, state, latent_goal):
        return self.Q.select_action(torch.concat([torch.Tensor(state), torch.Tensor(latent_goal)]))

    def recall_states(self):
        return torch.Tensor(list(zip(*self.buffer.buffer))[0])

    def fit(self, enc_state, action, enc_next_state, latent_goal, reward, done):
        state      = torch.concat([enc_state     , latent_goal])
        next_state = torch.concat([enc_next_state, latent_goal])
        self.Q.train(state, action, reward, next_state, done)

    def prior(self, states):
        _, z_mu, _ = self.bvae(states)
        z_np       = z_mu.detach().numpy()
        mu, cov    = np.mean(z_np, axis = 0), np.cov(z_np, rowvar = False)
        return multivariate_normal(mu, cov, allow_singular = True)

In [52]:
def rig(env, agent, n_episodes, n_exploration, epochs = 10, K = 10):
    D   = exploration_policy(env, n_exploration)
    agent.bvae.fit(D, epochs = epochs)
    p_z = agent.prior(D)
    for i in range(n_episodes):
        z_g   = p_z.rvs()
        T     = 0
        done  = False
        state = env.reset()
        while not done:
            action = agent.get_action(state, z_g)
            next_state, reward, done, info = env.step(action)
            agent.store(state, action, reward, next_state, z_g, done)
            T += 1
            (state, action, reward, next_state_, z_g, done) = agent.sample()
            z  = agent.bvae.mean_encode(state)
            z_ = agent.bvae.mean_encode(next_state_)
            if np.random.binomial(1, 0.5):
                z_g = p_z.rvs()
            reward += torch.cdist(z, z_)
            agent.fit(z, action, z_, z_g, reward, done)
        for t in range(T):
            state, action, reward, next_state, _ = agent.buffer[-T + t]
            for _ in range(K):
                i = np.random.randint(-T + t, -1)
                future = agent.buffer[i][2]
                z_g    = agent.bvae.mean_encode(future)
                agent.store(state, action, reward, next_state, z_g, done)
        
        if i and i % K:
            R = agent.recall_states()
            M = torch.concat(D, R)
            agent.bvae.fit(M)
            p_z = agent.prior(M)

In [53]:
def exploration_policy(env, N):
    states = [env.reset()]
    n      = 1
    while n < N:
        state, reward, done, info = env.step(env.action_space.sample())
        states.append(state)
        n += 1 + done
        if done:
            states.append(env.reset())

    return torch.Tensor(states)
    


In [54]:

env_dim    = env.observation_space.shape[0]
latent_dim = 32

in_dim     = env_dim + latent_dim
Q = TD3(in_dim, [64, 64], env.action_space.shape[0], 1)
Q


TD3(state_dim=40, neurons=[64, 64], action_dim=2, max_action=1, discount=0.99, tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2)

In [55]:

bvae = BetaVAE([env_dim, 64, 128], [64, 64, env_dim])

In [56]:
bvae.encoder

Sequential(
  (0): Linear(in_features=8, out_features=64, bias=True)
  (1): ReLU(inplace=True)
  (2): Linear(in_features=64, out_features=128, bias=True)
)

In [57]:
agent = ImaginedGoalsAgent(Q, bvae, ReplayBuffer())

In [58]:
D   = exploration_policy(env, 100)
agent.bvae.fit(D, epochs = 2)
p_z = agent.prior(D)
for i in range(5):
    z_g   = p_z.rvs()
    T     = 0
    done  = False
    state = env.reset()
    while not done:
        action = agent.get_action(state, z_g)
        next_state, reward, done, info = env.step(action)
        agent.store(state, action, reward, next_state, z_g, done)
        T += 1
        (state, action, reward, next_state_, z_g, done) = agent.sample()
        z  = agent.bvae.mean_encode(state)
        z_ = agent.bvae.mean_encode(next_state_)
        if np.random.binomial(1, 0.5):
            z_g = p_z.rvs()
        reward += torch.cdist(z, z_)
        agent.fit(z, action, z_, z_g, reward, done)
    for t in range(T):
        state, action, reward, next_state, _ = agent.buffer[-T + t]
        for _ in range(K):
            i = np.random.randint(-T + t, -1)
            future = agent.buffer[i][2]
            z_g    = agent.bvae.mean_encode(future)
            agent.store(state, action, reward, next_state, z_g, done)
    
    if i and i % 1:
        R = agent.recall_states()
        M = torch.concat(D, R)
        agent.bvae.fit(M)
        p_z = agent.prior(M)



RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x72 and 40x64)

In [60]:
Q

TD3(state_dim=40, neurons=[64, 64], action_dim=2, max_action=1, discount=0.99, tau=0.005, policy_noise=0.2, noise_clip=0.5, policy_freq=2)

In [11]:
dist = bvae._encode(D)
mu   = dist[:,:bvae.z_dim]
ls2  = dist[:,bvae.z_dim:]

mu.shape

torch.Size([100, 128])

In [18]:
D.shape

torch.Size([100, 8])

In [13]:
rig(env, agent, 100, 1000)

  return torch.Tensor(states)


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1000x8 and 40x64)