In [19]:
import torch.nn as nn
import torch.optim as optim
import predictive_coding as pc
import collections
import random

class QNet(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):

        super(QNet, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            pc.PCLayer(),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size)
        )

    def forward(self, x):
        return self.model(x)
    
class ReplayBuffer:

    def __init__(self, batch_size=128, max_size=1e4):
        self.buffer = collections.deque(maxlen=int(max_size))
        self.batch_size = batch_size

    def add(self, observation, action, reward, done, next_observation):
        transition = (observation, action, reward, done, next_observation)
        self.buffer.append(transition)

    def can_sample(self):
        return len(self.buffer) >= self.batch_size

    def sample(self):
        transitions = random.sample(self.buffer, self.batch_size)
        batch = list(zip(*transitions))
        return batch

In [20]:
import gymnasium as gym
import random
import torch

class Agent():

    def __init__(
            self, 
            env: gym.Env, 
            hidden_layer_size: int, 
            x_lr: float,
            p_lr: float,
            initial_epsilon: float, 
            min_epsilon: float, 
            epsilon_decay: float,
            buffer_size: int,
            batch_size: int,
            gamma: float
        ) -> None:

        self.env    = env
        self.gamma  = gamma

        # Q-network
        self.qnet           = QNet(env.observation_space.shape[0], hidden_layer_size, env.action_space.n)
        self.target_network = QNet(env.observation_space.shape[0], hidden_layer_size, env.action_space.n)
        self.target_network.load_state_dict(self.qnet.state_dict())

        # PC Trainer

        # options for the update of the latent state x
        optimizer_x_fn = optim.SGD          # optimizer for latent state x, SGD perform gradient descent. Other alternative are Adam, RMSprop, etc. 
        optimizer_x_kwargs = {'lr': x_lr}   # optimizer parameters for latent state x to pass to the optimizer. The best learning rate will depend on the task and the optimiser. 
                                            # Other parameters such as momentum, weight_decay could also be set here with additional elements, e.g., "momentum": 0.9, "weight_decay": 0.01

        # options for the update of the parameters p
        # Randomly updating=
        update_p_at = 'all'                 # update parameters p at the last iteration, can be set to 'all' to implement ipc (https://arxiv.org/abs/2212.00720)
        optimizer_p_fn = optim.Adam         # optimizer for parameters p
        optimizer_p_kwargs = {'lr': p_lr}   # optimizer parameters for parameters p, 0.001 is a good starting point for Adam, but it should be adjusted for the task

        T = 20

        self.trainer = pc.PCTrainer(self.qnet, 
            T = 20, 
            optimizer_x_fn = optimizer_x_fn,
            optimizer_x_kwargs = optimizer_x_kwargs,
            update_p_at = update_p_at,   
            optimizer_p_fn = optimizer_p_fn,
            optimizer_p_kwargs = optimizer_p_kwargs,
            plot_progress_at = [],
        )

        # Epsilon-greedy
        self.epsilon        = initial_epsilon
        self.min_epsilon    = min_epsilon
        self.epsilon_decay  = epsilon_decay

        # Replay Memory
        self.buffer_size    = buffer_size
        self.batch_size     = batch_size
        self.replay_buffer  = ReplayBuffer(batch_size, buffer_size)

    def act(self, state):

        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)

        self.qnet.eval()

        if random.random() < self.epsilon:
            return self.env.action_space.sample()
        
        else:
            return self.qnet(state).argmax().item()
        
    
    def learn(self):
            
        if not self.replay_buffer.can_sample():
            return

        batch = self.replay_buffer.sample()

        observations        = torch.Tensor(batch[0])
        actions             = torch.Tensor(batch[1])
        rewards             = torch.Tensor(batch[2])
        dones               = torch.Tensor(batch[3])
        next_observations   = torch.Tensor(batch[4])

        # Target:

        with torch.no_grad():
            target = rewards + self.gamma * self.target_network(next_observations).max(dim=1).values * (1 - dones)
        
        # Loss:

        def loss_fn(outputs, actions, target):
            predicted = outputs.gather(1, actions.long().unsqueeze(1)).squeeze(1)
            loss = (predicted - target).pow(2).sum() * 0.5
            return loss
        
        self.qnet.train()

        self.trainer.train_on_batch(
        inputs=observations,
        loss_fn=loss_fn,
        loss_fn_kwargs = {
            'actions': actions,
            'target': target                    
        }
    )
        
    def load_target_network(self):
        self.target_network.load_state_dict(self.qnet.state_dict(), strict=False)

In [21]:
import torch
import wandb
from gym.wrappers import TimeLimit

# WandB – Config

params = {
    'env_name': 'CartPole-v1',
    'buffer_size' : 1000,
    'batch_size' : 64,
    'gamma' : 0.99,
    'x_lr' : 0.01,
    'p_lr': 0.001
}

max_episodes        = 100
max_episode_steps   = 500

wandb.init(project="supervised-predictive-coding", config=params)

env = TimeLimit(gym.make(params['env_name']), max_episode_steps=max_episode_steps)

agent = Agent(
    env                 = env, 
    hidden_layer_size   = 128,
    x_lr                = params['x_lr'],
    p_lr                = params['p_lr'],
    initial_epsilon     = 0.9,
    min_epsilon         = 0.001,
    epsilon_decay       = 0.999,
    buffer_size         = params['buffer_size'],
    batch_size          = params['buffer_size'],
    gamma               = params['gamma']
)

obs, _ = env.reset()
action = agent.act(torch.from_numpy(obs))

next_obs, reward, termination, truncation, info = env.step(action)

agent.replay_buffer.add(obs, action, reward, termination, next_obs)

next_obs = obs 

try: 
    for i in range(250):

        action = agent.act(torch.from_numpy(obs))

        next_obs, reward, termination, truncation, info = env.step(action)

        agent.replay_buffer.add(obs, action, reward, termination, next_obs)

        if termination or truncation:
            obs, _ = env.reset()

    agent.learn()

    # Training Loop
    for episode in range(max_episodes):
        obs, _ = env.reset()
        done = False
        episode_rewards = 0
        while not done:
            action = agent.act(torch.from_numpy(obs))

            next_obs, reward, termination, truncation, infos = env.step(action)

            agent.replay_buffer.add(obs, action, reward, termination, next_obs)

            episode_rewards += reward
            obs             = next_obs
            done            = termination or truncation

            agent.learn()
        agent.load_target_network()

        print(f'Episode {episode} completed - Reward {episode_rewards} - Epsilon {agent.epsilon}')
        
        wandb.log({'reward': episode_rewards, 'eplison': agent.epsilon})

    env.close()
    wandb.finish()

except KeyboardInterrupt:
    env.close()
    wandb.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eplison,██▇▇▇▇▇▆▆▆▆▅▅▅▅▅▅▅▅▅▄▄▄▄▄▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁
reward,▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▆▅▆▄▅▄▄▅▇▇▅▇█▆

0,1
eplison,0.001
reward,184.0


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011167390745443602, max=1.0…

Episode 0 completed - Reward 10.0 - Epsilon 0.6931623037120513
Episode 1 completed - Reward 13.0 - Epsilon 0.6842050626737857
Episode 2 completed - Reward 28.0 - Epsilon 0.6653037229189369
Episode 3 completed - Reward 29.0 - Epsilon 0.6462776129689248
Episode 4 completed - Reward 19.0 - Epsilon 0.6341082280488003
Episode 5 completed - Reward 14.0 - Epsilon 0.6252881865229492
Episode 6 completed - Reward 18.0 - Epsilon 0.6141281599309497
Episode 7 completed - Reward 27.0 - Epsilon 0.5977604730006341
Episode 8 completed - Reward 22.0 - Epsilon 0.5847469090696757
Episode 9 completed - Reward 23.0 - Epsilon 0.5714446407006086
Episode 10 completed - Reward 13.0 - Epsilon 0.5640602700281566
Episode 11 completed - Reward 24.0 - Epsilon 0.5506773684938275
Episode 12 completed - Reward 9.0 - Epsilon 0.545741050375066
Episode 13 completed - Reward 13.0 - Epsilon 0.5386888288296824
Episode 14 completed - Reward 20.0 - Epsilon 0.528016791626918
Episode 15 completed - Reward 10.0 - Epsilon 0.522760

wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)


Episode 62 completed - Reward 500.0 - Epsilon 0.09167525557152809
Episode 63 completed - Reward 500.0 - Epsilon 0.0555899447433427
Episode 64 completed - Reward 205.0 - Epsilon 0.04528155461051793
Episode 65 completed - Reward 15.0 - Epsilon 0.04460706531316048
Episode 66 completed - Reward 80.0 - Epsilon 0.041175862988562933
Episode 67 completed - Reward 301.0 - Epsilon 0.03046875201945389
Episode 68 completed - Reward 56.0 - Epsilon 0.0288085902661336
Episode 69 completed - Reward 125.0 - Epsilon 0.025421901699075003
Episode 70 completed - Reward 174.0 - Epsilon 0.021360085482111043


wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)


Episode 71 completed - Reward 179.0 - Epsilon 0.01785769408166896
Episode 72 completed - Reward 125.0 - Epsilon 0.01575837412808154
Episode 73 completed - Reward 174.0 - Epsilon 0.01324056014452908
Episode 74 completed - Reward 241.0 - Epsilon 0.010403728756924474
Episode 75 completed - Reward 423.0 - Epsilon 0.006813818176725296
Episode 76 completed - Reward 360.0 - Epsilon 0.00475298344714137
Episode 77 completed - Reward 246.0 - Epsilon 0.003716005990196342
Episode 78 completed - Reward 243.0 - Epsilon 0.002914003333960734


wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)


Episode 79 completed - Reward 285.0 - Epsilon 0.002191059587697678
Episode 80 completed - Reward 146.0 - Epsilon 0.001893282713946969
Episode 81 completed - Reward 500.0 - Epsilon 0.0011480467744070838
Episode 82 completed - Reward 135.0 - Epsilon 0.0010029989842732629
Episode 83 completed - Reward 144.0 - Epsilon 0.001
Episode 84 completed - Reward 154.0 - Epsilon 0.001
Episode 85 completed - Reward 168.0 - Epsilon 0.001
Episode 86 completed - Reward 168.0 - Epsilon 0.001
Episode 87 completed - Reward 193.0 - Epsilon 0.001


wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)


Episode 88 completed - Reward 226.0 - Epsilon 0.001
Episode 89 completed - Reward 165.0 - Epsilon 0.001
Episode 90 completed - Reward 233.0 - Epsilon 0.001
Episode 91 completed - Reward 118.0 - Epsilon 0.001
Episode 92 completed - Reward 209.0 - Epsilon 0.001
Episode 93 completed - Reward 182.0 - Epsilon 0.001
Episode 94 completed - Reward 173.0 - Epsilon 0.001
Episode 95 completed - Reward 204.0 - Epsilon 0.001
Episode 96 completed - Reward 195.0 - Epsilon 0.001
Episode 97 completed - Reward 210.0 - Epsilon 0.001
Episode 98 completed - Reward 212.0 - Epsilon 0.001
Episode 99 completed - Reward 111.0 - Epsilon 0.001


wandb: ERROR Dropped streaming file chunk (see wandb/debug-internal.log)


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
eplison,██▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▅▄▄▄▄▄▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
reward,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂█▂▂▃▃▆▄█▃▃▄▃▃▄▂

0,1
eplison,0.001
reward,111.0


GPU stats error: Command '['/Users/jackmontgomery/anaconda3/lib/python3.11/site-packages/wandb/bin/apple_gpu_stats', '--json']' died with <Signals.SIGTRAP: 5>.
Traceback (most recent call last):
  File "/Users/jackmontgomery/anaconda3/lib/python3.11/site-packages/wandb/sdk/internal/system/assets/gpu_apple.py", line 64, in sample
    subprocess.check_output(command, universal_newlines=True)
  File "/Users/jackmontgomery/anaconda3/lib/python3.11/subprocess.py", line 466, in check_output
    return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jackmontgomery/anaconda3/lib/python3.11/subprocess.py", line 571, in run
    raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['/Users/jackmontgomery/anaconda3/lib/python3.11/site-packages/wandb/bin/apple_gpu_stats', '--json']' died with <Signals.SIGTRAP: 5>.
GPU stats error: Command '['/Users/jackmontgomery/anaconda