In [109]:
import torch.nn as nn
import torch.optim as optim
import predictive_coding as pc
import collections
import random

class QNet(nn.Module):
    """MLP Q-network with a predictive-coding hidden layer, bundled with its PC trainer.

    Architecture: Linear(input_size, hidden_size) -> PCLayer -> ReLU -> Linear(hidden_size, output_size).
    ``self.trainer`` is a ``pc.PCTrainer`` built around ``self.model``; it runs ``T``
    inference iterations on the latent state per batch and updates the weights
    according to ``update_p_at``.
    """

    def __init__(self, input_size, hidden_size, output_size, T=20):
        """
        Args:
            input_size: size of the input vector (the Agent passes the observation dimension).
            hidden_size: width of the hidden (predictive-coding) layer.
            output_size: number of outputs (the Agent passes the number of actions).
            T: number of inference iterations the trainer runs per batch (default 20).
        """
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            pc.PCLayer(),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

        # Options for the update of the latent state x.
        # SGD performs plain gradient descent on x; alternatives are Adam, RMSprop, etc.
        optimizer_x_fn = optim.SGD
        # The best learning rate depends on the task and the optimizer. Other
        # optimizer parameters could be added here, e.g. "momentum": 0.9, "weight_decay": 0.01.
        optimizer_x_kwargs = {'lr': 0.01}

        # Options for the update of the parameters p.
        # 'all' updates p at every one of the T inference iterations, which implements
        # incremental PC (iPC, https://arxiv.org/abs/2212.00720).
        # NOTE(review): 'last' presumably restricts the update to the final iteration;
        # check the predictive_coding documentation for the full list of options.
        update_p_at = 'all'
        optimizer_p_fn = optim.Adam
        # 0.001 is a common starting point for Adam; tune it for the task.
        optimizer_p_kwargs = {'lr': 0.001}

        self.trainer = pc.PCTrainer(
            self.model,
            T=T,
            optimizer_x_fn=optimizer_x_fn,
            optimizer_x_kwargs=optimizer_x_kwargs,
            update_p_at=update_p_at,
            optimizer_p_fn=optimizer_p_fn,
            optimizer_p_kwargs=optimizer_p_kwargs,
            plot_progress_at=[],
        )

    def forward(self, x):
        """Run the plain forward pass of the wrapped ``nn.Sequential`` model."""
        return self.model(x)
    
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random minibatch sampling.

    Each transition is kept as an
    ``(observation, action, reward, done, next_observation)`` tuple. Once
    ``max_size`` transitions are held, every new ``add`` evicts the oldest one.
    """

    def __init__(self, batch_size=128, max_size=1e4):
        # max_size may arrive as a float (the default is 1e4), so cast it for deque.
        capacity = int(max_size)
        self.buffer = collections.deque(maxlen=capacity)
        self.batch_size = batch_size

    def add(self, observation, action, reward, done, next_observation):
        """Append one transition, evicting the oldest when the buffer is full."""
        self.buffer.append((observation, action, reward, done, next_observation))

    def can_sample(self):
        """Return True once at least ``batch_size`` transitions are stored."""
        return self.batch_size <= len(self.buffer)

    def sample(self):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A list of 5 tuples (observations, actions, rewards, dones,
            next_observations), each of length ``batch_size``.
        """
        picked = random.sample(self.buffer, self.batch_size)
        return [tuple(column) for column in zip(*picked)]

In [110]:
import gymnasium as gym
import random
import torch

class Agent:
    """Epsilon-greedy DQN-style agent whose Q-network is trained with predictive coding.

    The online network ``qnet`` is fitted, via its ``trainer``, towards the target
    ``reward + gamma * max_a' Q_target(next_obs, a') * (1 - done)``. ``Q_target`` is a
    separate copy of the network, refreshed by ``load_target_network``.
    """

    def __init__(
            self,
            env: gym.Env,
            initial_epsilon: float,
            min_epsilon: float,
            epsilon_decay: float,
            buffer_size: int,
            batch_size: int,
            gamma: float,
            hidden_size: int = 128,
        ) -> None:
        """
        Args:
            env: environment; ``observation_space.shape[0]`` sets the network input
                size and ``action_space.n`` sets the number of Q-values.
            initial_epsilon: starting exploration probability.
            min_epsilon: floor below which epsilon is not decayed.
            epsilon_decay: multiplicative decay applied to epsilon on every ``act`` call.
            buffer_size: replay-buffer capacity (oldest transitions are evicted first).
            batch_size: number of transitions used per ``learn`` step.
            gamma: discount factor for the bootstrapped target.
            hidden_size: width of the Q-network hidden layer (default 128).
        """
        self.env = env
        self.gamma = gamma

        n_inputs = env.observation_space.shape[0]
        n_actions = env.action_space.n

        # Online Q-network and a target copy that starts with identical weights.
        self.qnet = QNet(n_inputs, hidden_size, n_actions)
        self.target_network = QNet(n_inputs, hidden_size, n_actions)
        self.target_network.load_state_dict(self.qnet.state_dict())
        # The target network is only used for inference, so keep it in eval mode,
        # the same mode `act` uses for the online network.
        # NOTE(review): in training mode a pc.PCLayer presumably returns its stored
        # latent state (sampled once, then reused while the batch shape is unchanged)
        # rather than its input, which would make the TD targets stale. Confirm
        # against the predictive_coding source.
        self.target_network.eval()

        # Epsilon-greedy exploration schedule.
        self.epsilon = initial_epsilon
        self.min_epsilon = min_epsilon
        self.epsilon_decay = epsilon_decay

        # Replay memory.
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.replay_buffer = ReplayBuffer(batch_size, buffer_size)

    def act(self, state):
        """Choose an action for one observation tensor, epsilon-greedily.

        Epsilon is decayed (floored at ``min_epsilon``) before the choice. With
        probability epsilon a random action is sampled from the environment's
        action space; otherwise the argmax of the online Q-network is returned.
        """
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)

        self.qnet.eval()

        if random.random() < self.epsilon:
            return self.env.action_space.sample()

        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            return self.qnet(state).argmax().item()

    @staticmethod
    def _to_float_tensor(values):
        """Stack a sequence of scalars or equally shaped arrays into one float32 tensor."""
        return torch.stack([torch.as_tensor(v, dtype=torch.float32) for v in values])

    def learn(self):
        """Run one predictive-coding training step on a random replay minibatch.

        Does nothing until the replay buffer holds at least ``batch_size`` transitions.
        """
        if not self.replay_buffer.can_sample():
            return

        self.qnet.train()

        # Stacking per-element tensors avoids the slow (and warning-emitting)
        # torch.Tensor(tuple_of_ndarrays) construction.
        observations, actions, rewards, dones, next_observations = (
            self._to_float_tensor(column) for column in self.replay_buffer.sample()
        )

        # Bootstrapped TD target; terminal transitions (done == 1) do not bootstrap.
        with torch.no_grad():
            next_q = self.target_network(next_observations).max(dim=1).values
            target = rewards + self.gamma * next_q * (1 - dones)

        def loss_fn(outputs, actions, target):
            # Q-value of the action actually taken in each transition.
            predicted = outputs.gather(1, actions.long().unsqueeze(1)).squeeze(1)
            # Half sum of squared errors, summed (not averaged) over the batch.
            return (predicted - target).pow(2).sum() * 0.5

        self.qnet.trainer.train_on_batch(
            inputs=observations,
            loss_fn=loss_fn,
            loss_fn_kwargs={
                'actions': actions,
                'target': target,
            },
        )

    def load_target_network(self):
        """Copy the online Q-network's weights into the target network."""
        # strict=False: the online network's state_dict may contain entries that the
        # target lacks. These are skipped. NOTE(review): presumably this is the PCLayer
        # latent registered as a parameter during training; confirm.
        self.target_network.load_state_dict(self.qnet.state_dict(), strict=False)
        # load_state_dict does not change the module mode; keep the target in eval.
        self.target_network.eval()

In [111]:
import torch
# Configuration: seeds and run lengths.
SEED = 0
N_WARMUP_STEPS = 101   # transitions collected before the first learning step
N_EPISODES = 100

random.seed(SEED)
torch.manual_seed(SEED)

env = gym.make('CartPole-v1')
env.action_space.seed(SEED)
# Agent(env, initial_epsilon, min_epsilon, epsilon_decay, buffer_size, batch_size, gamma)
agent = Agent(env, 0.9, 0.001, 0.999, 1000, 64, 0.99)

# Warm-up: fill the replay buffer, advancing the observation after every step.
obs, _ = env.reset(seed=SEED)
for _ in range(N_WARMUP_STEPS):
    action = agent.act(torch.from_numpy(obs))
    next_obs, reward, terminated, truncated, _ = env.step(action)
    # Only true termination stops bootstrapping; truncation is just a time limit.
    agent.replay_buffer.add(obs, action, reward, terminated, next_obs)
    obs = next_obs
    if terminated or truncated:
        obs, _ = env.reset()

agent.learn()

# Training loop: one learning step per environment step; target synced per episode.
for episode in range(N_EPISODES):
    obs, _ = env.reset()
    done = False
    episode_rewards = 0
    while not done:
        action = agent.act(torch.from_numpy(obs))
        next_obs, reward, terminated, truncated, _ = env.step(action)
        episode_rewards += reward
        agent.replay_buffer.add(obs, action, reward, terminated, next_obs)
        obs = next_obs
        agent.learn()
        # End the episode on either termination or time-limit truncation.
        done = terminated or truncated
    agent.load_target_network()
    print(f'Episode {episode} completed - Reward {episode_rewards} - Epsilon {agent.epsilon}')

Episode 0 completed - Reward 18.0 - Epsilon 0.798979448275733
Episode 1 completed - Reward 18.0 - Epsilon 0.7847194125331725
Episode 2 completed - Reward 40.0 - Epsilon 0.7539350353479782
Episode 3 completed - Reward 10.0 - Epsilon 0.7464295217570214
Episode 4 completed - Reward 12.0 - Epsilon 0.7375214679987703
Episode 5 completed - Reward 30.0 - Epsilon 0.7157136715673376
Episode 6 completed - Reward 20.0 - Epsilon 0.7015345712765669
Episode 7 completed - Reward 15.0 - Epsilon 0.691084895594664
Episode 8 completed - Reward 14.0 - Epsilon 0.6814723449173306
Episode 9 completed - Reward 14.0 - Epsilon 0.6719934986967332
Episode 10 completed - Reward 78.0 - Epsilon 0.6215458242302211
Episode 11 completed - Reward 16.0 - Epsilon 0.6116753296042869
Episode 12 completed - Reward 32.0 - Epsilon 0.5924020979840192
Episode 13 completed - Reward 41.0 - Epsilon 0.5685931262319229
Episode 14 completed - Reward 45.0 - Epsilon 0.5435613584374495
Episode 15 completed - Reward 35.0 - Epsilon 0.52485