In [1]:
import gym
import numpy as np
import torch

In [9]:
env = gym.make('CartPole-v0')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [8]:
class SimpleNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out, num_hidden):
        super(SimpleNet, self).__init__()
        self.lin_in = torch.nn.Linear(D_in, H)
        self.lin_hd = torch.nn.Linear(H, H)
        self.lin_out = torch.nn.Linear(H, D_out)
        self.num_hidden = num_hidden
    
    def forward(self, x):
        x = self.lin_in(x).clamp(min=0)
        for _ in range(self.num_hidden):
            x = self.lin_hd(x).clamp(min=0)
        y = self.lin_out(x)
        return y

In [107]:
class rollout_replay():
    def __init__(self):
        self.states = []
        self.actions = []
        self.rewards = []
        self.next_states = []
        self.dones = []
        self.dtypes = [torch.float32, torch.int64, torch.float32, torch.float32, torch.float32]

    def add(self, state, action, reward, next_state, done):
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward) 
        self.next_states.append(next_state) 
        self.dones.append(done)    

    def get_batch(self, batch_size):
        indexes = np.arange(len(self.states))
        np.random.shuffle(indexes)
        batch_indexes = indexes[:batch_size]
        
        # batch and convert to tensors
        data = []
        for buf, dtype in zip([self.states,  self.actions, self.rewards, self.next_states, self.dones], self.dtypes):
            batch = torch.tensor(buf[batch_indexes], dtype=dtype)
            data.append(batch)
        return tuple(data)

    def tensors(self):
        # convert to tensors
        data = []
        for buf, dtype in zip([self.states,  self.actions, self.rewards, self.next_states, self.dones], self.dtypes):
            data.append(torch.tensor(buf, dtype=dtype))
        return tuple(data)

In [108]:
def rollout(N):
    replay = rollout_replay()
    
    for _ in range(N):
        state = env.reset()
        done = False
        
        while not done:
            # Stochastic Sampling
            act = env.action_space.sample()
            next_state, rew, done, _ = env.step(act)
            replay.add(state, act, rew, next_state, done)
            state = next_state
    
    return replay

In [125]:
def fit_VFunc(replay_buff):
    gamma = 0.99
    loss_fcn = torch.nn.MSELoss()
    optim = torch.optim.Adam(VFunc.parameters(), lr=1e-2)
    
    states, _, rewards, next_states, dones = replay_buff.tensors()

    state_val = rewards + (1 - dones) * gamma * torch.squeeze(VFunc(next_states))
    state_val = state_val.detach()
    
    state_val_pred = torch.squeeze(VFunc(states))
    
    loss = loss_fcn(state_val_pred, state_val)

    optim.zero_grad()
    loss.backward()
    optim.step()
    
    return replay_buff

In [110]:
def get_adv(rewards, states, next_states, dones):
    v_s = torch.squeeze(VFunc(states))
    v_ns = torch.squeeze(VFunc(next_states))
    return rewards + (1 - dones) * v_ns - v_s

In [120]:
def get_logprob(states, actions):
    act_distrib = QFunc(states)
    act_prob = torch.nn.LogSoftmax(dim=1)(act_distrib)
    return act_prob[np.arange(len(act_prob)), actions]

In [148]:
def train(num_rollouts):
    replay_buff = rollout(num_rollouts)
    buff = fit_VFunc(replay_buff)

    states, actions, rewards, next_states, dones = buff.tensors()
    adv = get_adv(rewards, states, next_states, dones)
    adv = adv.detach()
    logprob = get_logprob(states, actions)

    optim = torch.optim.Adam(QFunc.parameters(), lr=1e-2)
    loss = -1 * (logprob * adv).mean()
    optim.zero_grad()
    loss.backward()
    optim.step()

In [181]:
def view_greedy(render=False):
    done = False
    state = env.reset()
    total_reward = 0
    while not done:
        if render: 
            plt.imshow(env.render(mode='rgb_array'))
            display.display(plt.gcf())    
            display.clear_output(wait=True)
            
        act = torch.argmax(QFunc(torch.tensor(state, dtype=torch.float32)), dim=0)
        state, reward, done, _ = env.step(act.item())
        total_reward += reward
        
    print("TotalReward:", total_reward)

In [150]:
H, N = 64, 4 
# Two Network Design
VFunc = SimpleNet(D_in=state_dim, D_out=1, H=H, num_hidden=N)
QFunc = SimpleNet(D_in=state_dim, D_out=action_dim, H=H, num_hidden=N)

In [153]:
epochs = 1000
N = 500
for e in range(epochs):
    train(N)

In [183]:
import matplotlib.pyplot as plt
from IPython import display
%matplotlib inline

In [184]:
view_greedy(render=True)

TotalReward: 56.0
