In [2]:
import gym
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim

In [3]:
# Define the Actor-Critic network
class ActorCritic(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ActorCritic, self).__init__()
        self.fc1 = nn.Linear(input_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.policy_head = nn.Linear(128, output_dim)
        self.value_head = nn.Linear(128, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        policy = torch.softmax(self.policy_head(x), dim=-1)
        value = self.value_head(x)
        return policy, value

# Function to select action based on policy
def select_action(state,model,output_dim):
    state = torch.FloatTensor(state).unsqueeze(0)  # Add batch dimension
    policy, _ = model(state)
    action = np.random.choice(output_dim, p=policy.detach().numpy()[0])  # Use the first (and only) batch
    return action

In [4]:
# Hyperparameters
learning_rate = 0.001
num_episodes = 1000
gamma = 0.99  # Discount factor

# Initialize environment and model
env = gym.make('CartPole-v1')
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n
model = ActorCritic(input_dim, output_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [17]:
# Training loop
for episode in range(num_episodes):
    state = env.reset()[0]
    # Debugging: Print initial state
    print(f"Initial state: {state}")

    done = False
    episode_reward = 0

    while not done:
        action = select_action(state,model,output_dim)
        next_state, reward, done, _, _ = env.step(action)

        # Debugging: Print state transition
        print(f"State: {state}, Action: {action}, Reward: {reward}, Next State: {next_state}")

        # Convert to tensors
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
        reward_tensor = torch.FloatTensor([reward])

        # Get value estimates
        _, value = model(state_tensor)
        _, next_value = model(next_state_tensor)

        # Calculate advantage
        advantage = reward_tensor + (1 - done) * gamma * next_value - value

        # Update policy (actor)
        policy, _ = model(state_tensor)
        policy_loss = -torch.log(policy[0, action]) * advantage.detach()  # Use first batch item
        
        # Update value function (critic)
        value_loss = advantage.pow(2)

        # Total loss
        optimizer.zero_grad()
        (policy_loss + value_loss).backward()
        optimizer.step()

        state = next_state
        episode_reward += reward

    if episode % 100 == 0:
        print(f'Episode {episode}, Reward: {episode_reward}')

env.close()

Initial state: [0.03083471 0.02280321 0.00026013 0.0204192 ]
State: [0.03083471 0.02280321 0.00026013 0.0204192 ], Action: 0, Reward: 1.0, Next State: [ 0.03129078 -0.17232247  0.00066852  0.3131842 ]
State: [ 0.03129078 -0.17232247  0.00066852  0.3131842 ], Action: 0, Reward: 1.0, Next State: [ 0.02784433 -0.36745393  0.0069322   0.60607785]
State: [ 0.02784433 -0.36745393  0.0069322   0.60607785], Action: 1, Reward: 1.0, Next State: [ 0.02049525 -0.1724296   0.01905376  0.31558645]
State: [ 0.02049525 -0.1724296   0.01905376  0.31558645], Action: 1, Reward: 1.0, Next State: [0.01704666 0.02241583 0.02536548 0.02897281]
State: [0.01704666 0.02241583 0.02536548 0.02897281], Action: 1, Reward: 1.0, Next State: [ 0.01749497  0.21716502  0.02594494 -0.25560033]
State: [ 0.01749497  0.21716502  0.02594494 -0.25560033], Action: 0, Reward: 1.0, Next State: [0.02183827 0.02168242 0.02083294 0.04515182]
State: [0.02183827 0.02168242 0.02083294 0.04515182], Action: 1, Reward: 1.0, Next State: [

KeyboardInterrupt: 

In [10]:
state = env.reset()[0]
# Debugging: Print initial state
print(f"Initial state: {state}")

done = False
episode_reward = 0

action = select_action(state,model,output_dim)

Initial state: [-0.01542485 -0.0308018   0.00522729  0.04654868]


In [14]:
state

array([-0.01542485, -0.0308018 ,  0.00522729,  0.04654868], dtype=float32)

In [15]:
action

0

In [16]:
env.step(action)

(array([-0.01604089, -0.22599831,  0.00615826,  0.34087628], dtype=float32),
 1.0,
 False,
 False,
 {})