In [1]:
import numpy as np
import torch
import torch.nn as nn
import gym

In [2]:
env = gym.make('CartPole-v1')

In [10]:
obs = env.reset()[0]

In [11]:
obs.shape

(4,)

In [15]:
actor = nn.Sequential(
    nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, env.action_space.n), nn.Softmax()
)

critic = nn.Sequential(
    nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 1)
)

In [21]:
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=0.001)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=0.001)

In [59]:
# Main training loop
num_episodes = 1
gamma = 0.99

for episode in range(num_episodes):
    state = env.reset()[0]
    print("State: ", state.shape)
    episode_reward = 0

    for t in range(1, 5):  # Limit the number of time steps
            # Choose an action using the actor
            action_probs = actor(torch.from_numpy(state).to(torch.float))
            print("Action probs: ", action_probs.shape)
            action = np.random.choice(env.action_space.n, p=action_probs.detach().numpy())
            print("Action: ", action)

            # Take the chosen action and observe the next state and reward
            next_state, reward, done, _, _ = env.step(action)

            # Compute the advantage
            print(critic(torch.from_numpy(state).to(torch.float)).shape)
            state_value = critic(torch.from_numpy(state).to(torch.float))
            print("State value: ", state_value)
            next_state_value = critic(torch.from_numpy(next_state).to(torch.float)) # NOTE: do not do .detach().item() here in order to keep grad_fn
            advantage = reward + gamma * next_state_value - state_value
            print("Advantage: ", advantage)

            # Compute actor and critic losses
            actor_loss = -torch.log(action_probs[action]) * advantage
            print(actor_loss)
            critic_loss = torch.square(advantage)
            print(critic_loss)

            actor_loss.backward(retain_graph=True) # needed: otherwise intermediary results get deleted before critic_loss can backprop
            actor_optimizer.step()

            critic_loss.backward()
            critic_optimizer.step()

            episode_reward += reward

            if done:
                break

    if episode % 10 == 0:
        print(f"Episode {episode}, Reward: {episode_reward}")

env.close()

State:  (4,)
Action probs:  torch.Size([2])
Action:  0
torch.Size([1])
State value:  tensor([24.5412], grad_fn=<ViewBackward0>)
Advantage:  tensor([0.4657], grad_fn=<SubBackward0>)
tensor([-0.], grad_fn=<MulBackward0>)
tensor([0.2169], grad_fn=<PowBackward0>)
Action probs:  torch.Size([2])
Action:  0
torch.Size([1])
State value:  tensor([24.5599], grad_fn=<ViewBackward0>)
Advantage:  tensor([0.1605], grad_fn=<SubBackward0>)
tensor([-0.], grad_fn=<MulBackward0>)
tensor([0.0258], grad_fn=<PowBackward0>)
Action probs:  torch.Size([2])
Action:  0
torch.Size([1])
State value:  tensor([24.5787], grad_fn=<ViewBackward0>)
Advantage:  tensor([0.1783], grad_fn=<SubBackward0>)
tensor([-0.], grad_fn=<MulBackward0>)
tensor([0.0318], grad_fn=<PowBackward0>)
Action probs:  torch.Size([2])
Action:  0
torch.Size([1])
State value:  tensor([24.5975], grad_fn=<ViewBackward0>)
Advantage:  tensor([0.8165], grad_fn=<SubBackward0>)
tensor([-0.], grad_fn=<MulBackward0>)
tensor([0.6668], grad_fn=<PowBackward0>)