# Actor-Critic models

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Actor network
class PolicyNetwork(nn.Module):
    """Actor: maps a state vector to probabilities over the discrete actions.

    Architecture: Linear(state_dim -> 128) -> ReLU -> Linear(128 -> action_dim)
    -> softmax over the last dimension.
    """

    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 128)   # state features -> hidden units
        self.fc2 = nn.Linear(128, action_dim)  # hidden units -> one logit per action

    def forward(self, state):
        """Return the action probabilities for ``state`` (last dim sums to 1)."""
        hidden = self.fc1(state).relu()
        logits = self.fc2(hidden)
        return logits.softmax(dim=-1)

# Instantiate the policy network (actor).
# state_dim / action_dim are example sizes; the training cell further down feeds
# CartPole-v1 observations through this same network instance.
state_dim = 4   # Example: state has 4 features
action_dim = 2  # Example: 2 possible actions
policy_net = PolicyNetwork(state_dim, action_dim)

In [2]:
# Critic network
class ValueNetwork(nn.Module):
    """Critic: maps a state vector to a single state-value estimate V(s).

    Architecture: Linear(state_dim -> 128) -> ReLU -> Linear(128 -> 1).
    """

    def __init__(self, state_dim):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, 128)  # state features -> hidden units
        self.fc2 = nn.Linear(128, 1)          # hidden units -> value estimate

    def forward(self, state):
        """Return V(state); the last dimension of the result has size 1."""
        hidden = self.fc1(state).relu()
        return self.fc2(hidden)

# Instantiate the value network (critic). It reuses state_dim because the actor
# and the critic are both evaluated on the same state vector.
value_net = ValueNetwork(state_dim)


The advantage $A(s_t, a_t)$ is estimated by the one-step TD error, which can be computed as:
$$ \delta_t = R_{t+1} + \gamma V(s_{t+1}) - V(s_t)  $$

In [3]:
# Compute TD error and advantage
def compute_td_error(state, reward, next_state, done, gamma=0.99, critic=None):
    """Return the one-step TD error ``r + gamma * V(s') * (1 - done) - V(s)``.

    The bootstrapped target ``r + gamma * V(s')`` is evaluated under
    ``torch.no_grad()`` so it acts as a constant: when the result is squared and
    back-propagated, gradients reach the critic only through ``V(s)``
    (semi-gradient TD). The returned tensor therefore still carries a graph and
    can be used directly as the critic's regression residual; call ``.detach()``
    on it before using it as an advantage for the actor.

    Args:
        state: Current state, shape ``(state_dim,)`` or ``(batch, state_dim)``.
        reward: Reward tensor, scalar or shape ``(batch,)``.
        next_state: Next state, same shape as ``state``.
        done: Float tensor, 1.0 where the transition ended in a terminal state
            (which disables bootstrapping) and 0.0 otherwise.
        gamma: Discount factor.
        critic: Module mapping states to values with a trailing dimension of
            size 1. Defaults to the notebook-level ``value_net``.

    Returns:
        The TD error: a scalar tensor for a single transition, shape
        ``(batch,)`` for a batch.
    """
    if critic is None:
        critic = value_net  # notebook-level critic defined in an earlier cell

    # V(s): gradients flow through this term only.
    # squeeze(-1) drops just the value dimension, so a batch of size 1 stays a batch.
    value = critic(state).squeeze(-1)

    # The target is a fixed regression label; (1 - done) removes the bootstrap
    # term for terminal transitions.
    with torch.no_grad():
        next_value = critic(next_state).squeeze(-1)
        td_target = reward + gamma * next_value * (1 - done)

    return td_target - value

# Example usage: one hand-written transition (state, reward, next_state, done).
# No random seed is set in this notebook, so the printed value depends on the
# critic's random initialisation.
state = torch.tensor([1.0, 0.5, -0.2, 0.3])      # Example current state
next_state = torch.tensor([1.1, 0.4, -0.1, 0.3])  # Example next state
reward = torch.tensor(1.0)                        # Example reward
done = torch.tensor(0.0)                          # Example "not done"

td_error = compute_td_error(state, reward, next_state, done)
print("TD Error:", td_error.item())


TD Error: 1.0241752862930298


## Computing the loss

The actor uses the advantage $A(s_t,a_t)$ to adjust the policy, while the critic minimizes the TD error.

*Actor loss*:
$$ L_{\textrm{actor}} = -\log \pi_\theta(a_t|s_t)A(s_t, a_t) $$

*Critic loss*:
$$ L_{\textrm{critic}} = \delta_t^2 $$

In [4]:
# Optimizers: the actor and the critic each get their own Adam instance, so
# stepping one never touches the other network's parameters.
policy_optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
value_optimizer = optim.Adam(value_net.parameters(), lr=1e-3)

# Compute losses
def compute_losses(state, action, reward, next_state, done):
    """Build the actor and critic losses for a transition.

    Uses the notebook-level ``compute_td_error`` and ``policy_net``.

    Returns:
        A pair ``(actor_loss, critic_loss)``:
        * ``actor_loss``  = ``-log pi(action | state) * delta``, with ``delta``
          detached so that this loss does not back-propagate into the critic;
        * ``critic_loss`` = mean of ``delta ** 2``.
    """
    # delta_t: the TD error doubles as the advantage estimate
    delta = compute_td_error(state, reward, next_state, done)

    # Log-probability of the action that was actually taken
    action_dist = torch.distributions.Categorical(policy_net(state))
    chosen_log_prob = action_dist.log_prob(action)

    advantage = delta.detach()  # constant with respect to the actor update
    actor_loss = -(chosen_log_prob * advantage)
    critic_loss = (delta ** 2).mean()
    return actor_loss, critic_loss

# Example usage: one gradient step on the hand-written transition defined above.
action = torch.tensor(1)  # Example action taken (discrete action space)
actor_loss, critic_loss = compute_losses(state, action, reward, next_state, done)

# Optimize policy network.
# The actor loss uses a detached TD error, so this backward pass only produces
# gradients for policy_net.
policy_optimizer.zero_grad()
actor_loss.backward()
policy_optimizer.step()

# Optimize value network (critic_loss is the mean squared TD error).
value_optimizer.zero_grad()
critic_loss.backward()
value_optimizer.step()


In [14]:
# Training loop: online (one-step) actor-critic on CartPole-v1.
# Both networks are updated after every environment step.
import gymnasium as gym
env = gym.make("CartPole-v1")

for episode in range(500):
    state, _ = env.reset()
    terminated = truncated = False

    print(f"Running episode {episode}...")
    total_reward = 0
    # An episode is over when the task fails (terminated) OR the time limit is
    # reached (truncated); stepping an environment past either is invalid.
    while not (terminated or truncated):
        state_tensor = torch.tensor(state, dtype=torch.float32)

        # Sample action from policy
        probs = policy_net(state_tensor)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample()

        # Step environment
        next_state, reward, terminated, truncated, _ = env.step(action.item())
        total_reward += reward

        # Convert to tensors.
        # Only a true terminal state disables bootstrapping: after a time-limit
        # truncation the next state still has value, so V(s') must stay in the target.
        reward_tensor = torch.tensor(reward, dtype=torch.float32)
        next_state_tensor = torch.tensor(next_state, dtype=torch.float32)
        done_tensor = torch.tensor(float(terminated), dtype=torch.float32)

        # Compute losses
        actor_loss, critic_loss = compute_losses(state_tensor, action, reward_tensor, next_state_tensor, done_tensor)

        # Optimize actor
        policy_optimizer.zero_grad()
        actor_loss.backward()
        policy_optimizer.step()

        # Optimize critic
        value_optimizer.zero_grad()
        critic_loss.backward()
        value_optimizer.step()

        # Update state
        state = next_state
    print(f"\tFinished with reward {total_reward}")

# Release the environment's resources once training is finished
env.close()


Running episode 0...
	Finished with reward 61.0
Running episode 1...
	Finished with reward 49.0
Running episode 2...
	Finished with reward 61.0
Running episode 3...
	Finished with reward 46.0
Running episode 4...
	Finished with reward 71.0
Running episode 5...
	Finished with reward 32.0
Running episode 6...
	Finished with reward 91.0
Running episode 7...
	Finished with reward 69.0
Running episode 8...
	Finished with reward 68.0
Running episode 9...
	Finished with reward 47.0
Running episode 10...
	Finished with reward 35.0
Running episode 11...
	Finished with reward 69.0
Running episode 12...
	Finished with reward 33.0
Running episode 13...
	Finished with reward 33.0
Running episode 14...
	Finished with reward 34.0
Running episode 15...
	Finished with reward 36.0
Running episode 16...
	Finished with reward 41.0
Running episode 17...
	Finished with reward 48.0
Running episode 18...
	Finished with reward 32.0
Running episode 19...
	Finished with reward 50.0
Running episode 20...
	Finishe

In [15]:
# Evaluation: roll out one rendered episode with the trained policy.
# With render_mode="human" the environment draws itself on reset()/step(),
# so no explicit env.render() call is needed.
env = gym.make("CartPole-v1", render_mode="human")
try:
    state, _ = env.reset()
    total_reward = 0
    terminated = truncated = False

    # Stop on failure (terminated) or on the time limit (truncated)
    while not (terminated or truncated):
        state_tensor = torch.tensor(state, dtype=torch.float32)

        # Select an action using the policy network.
        # Pure inference: no autograd graph is needed here.
        with torch.no_grad():
            probs = policy_net(state_tensor)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample().item()

        # Step the environment
        next_state, reward, terminated, truncated, _ = env.step(action)

        # Update the current state
        state = next_state
        total_reward += reward
finally:
    # Always close the render window, even if the rollout raises
    env.close()

print(f"Total Reward = {total_reward}")

Total Reward = 286.0
