In [14]:
import torch
import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

"""
a toy example of a actor-critic network.

the actor has to pick the index with the highest value in a list of 10 random numbers.

"""


class Actor(nn.Module):
    def __init__(self, input_dim, hs):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(input_dim, hs)
        self.fc2 = nn.Linear(hs, input_dim)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=-1)
        return x

class Critic(nn.Module):
    def __init__(self, input_dim, hs):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(input_dim, hs)
        self.fc2 = nn.Linear(hs, 1)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    

input_dim = 10 # number of values the agent has to choose from
hs = 5 # hidden layer size for both actor/critic
lr = .01

# Initialize the networks
actor = Actor(input_dim=input_dim, hs=hs)
critic = Critic(input_dim=input_dim, hs=hs)

# Initialize the optimizers
actor_optimizer = optim.Adam(actor.parameters(), lr=lr)
critic_optimizer = optim.Adam(critic.parameters(), lr=lr)

# Initialize the average_rewards_per_step list
average_rewards_per_step = []

# Start the training loop
for episode in range(1):  # Train for 1000 episodes
    average_rewards = 0

    inputs = torch.randn(input_dim)  # Generate a random list of scalars

    # Perform 30 updates for the given policy
    for _ in range(30):
        # Pass the list through the Actor network
        log_probs = actor(inputs.unsqueeze(0))

        action = torch.multinomial(log_probs, 1)

        # Calculate the reward
        reward = inputs[action]

        # Calculate the average reward per step * for monitoring only
        average_rewards += reward.item() / 100

        # Pass the list through the Critic network
        value = critic(inputs.unsqueeze(0))

        # Calculate the advantage
        advantage = value - reward

        # Calculate the losses
        actor_loss = -log_probs[0][action] * advantage.detach()  # Detach the advantage tensor to prevent gradients from flowing into the critic network
        critic_loss = advantage.pow(2)

        # Print the weights before update for debugging ; they are changing after updates
        print("First value in Actor weights:", actor.fc1.weight.data[0][0])
        print("First value in Critic weights:", critic.fc1.weight.data[0][0])

        # Backpropagate the losses and update the networks
        actor_optimizer.zero_grad()
        actor_loss.backward(retain_graph=True)
        actor_optimizer.step()

        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()


    average_rewards_per_step.append(average_rewards)
    if len(average_rewards_per_step) > 200:
        average_rewards_per_step.pop(0)

    # * Check for early stopping every 100 episodes, starting from the 101st episode
    if episode >= 100 and episode % 50 == 0:
        print(f"Episode {episode} Average reward per step:{average_rewards}")
        if episode >= 200 and average_rewards <= average_rewards_per_step[0]:
            print("Early stopping triggered")
            break

First value in Actor weights: tensor(-0.1295)
First value in Critic weights: tensor(-0.2304)
First value in Actor weights: tensor(-0.1195)
First value in Critic weights: tensor(-0.2204)
First value in Actor weights: tensor(-0.1211)
First value in Critic weights: tensor(-0.2219)
First value in Actor weights: tensor(-0.1276)
First value in Critic weights: tensor(-0.2183)
First value in Actor weights: tensor(-0.1326)
First value in Critic weights: tensor(-0.2137)
First value in Actor weights: tensor(-0.1355)
First value in Critic weights: tensor(-0.2086)
First value in Actor weights: tensor(-0.1405)
First value in Critic weights: tensor(-0.2021)
First value in Actor weights: tensor(-0.1453)
First value in Critic weights: tensor(-0.1993)
First value in Actor weights: tensor(-0.1486)
First value in Critic weights: tensor(-0.1946)
First value in Actor weights: tensor(-0.1509)
First value in Critic weights: tensor(-0.1899)
First value in Actor weights: tensor(-0.1525)
First value in Critic we

In [2]:
action

tensor([[5]])

In [3]:
inputs

tensor([-0.9331,  0.2693,  0.1480,  0.2423, -0.2720, -0.5691, -1.4266,  1.7765,
         0.7039,  0.3265])