In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import gym
import numpy as np

In [7]:
class ActorCritic(nn.Module):
    def __init__(self, num_inputs, num_actions):
        super(ActorCritic, self).__init__()
        self.actor = nn.Sequential(
            nn.Linear(num_inputs, 64),
            nn.ReLU(),
            nn.Linear(64, num_actions)
        )
        self.critic = nn.Sequential(
            nn.Linear(num_inputs, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        policy = F.softmax(self.actor(x), dim=1)
        value = self.critic(x)
        return policy, value

In [8]:
# Función para calcular la ventaja
def compute_advantage(rewards, values, gamma=0.99):
    returns = []
    advs = []
    R = 0
    for r, v in zip(reversed(rewards), reversed(values)):
        R = r + gamma * R
        adv = R - v.item()
        returns.insert(0, R)
        advs.insert(0, adv)
    return torch.FloatTensor(returns), torch.FloatTensor(advs)

In [9]:
num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.n
ac_model = ActorCritic(num_inputs, num_actions)
optimizer = optim.Adam(ac_model.parameters(), lr=3e-2)
env = gym.make('CartPole-v1')

# Entrenamiento
num_episodes = 1000
for episode in range(num_episodes):
    state = env.reset()
    log_probs = []
    values = []
    rewards = []

    while True:
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        policy, value = ac_model(state_tensor)
        action = torch.multinomial(policy, 1).item()
        next_state, reward, done, _ = env.step(action)

        log_prob = F.log_softmax(policy, dim=1)
        log_probs.append(log_prob[0, action])
        values.append(value)
        rewards.append(reward)

        state = next_state

        if done:
            returns, advantages = compute_advantage(rewards, values)
            log_probs = torch.stack(log_probs)
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

            actor_loss = -(log_probs * advantages.detach()).sum()
            critic_loss = F.smooth_l1_loss(torch.cat(values), returns.detach())

            total_loss = actor_loss + 0.5 * critic_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            print(f'Episode: {episode}, Total Reward: {sum(rewards)}')
            break

env.close()

ValueError: expected sequence of length 4 at dim 1 (got 0)