In [2]:
import collections
import random
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

import predictive_coding as pc

# import analysis_utils as au


# def plot(df, plot="Reward"):

#     df = au.nature_pre(df)

#     groups = ['Env', 'Rule', 'pc_learning_rate']

#     df = au.add_metric_per_group(
#         df, groups,
#         lambda df: (
#             'mean per group', df['Mean of episode reward'].mean()
#         ),
#     )

#     groups.pop(-1)

#     df = au.select_rows_per_group(
#         df, groups,
#         lambda df: df['mean per group'] == df['mean per group'].max()
#     )

#     df = au.drop_cols(df, ['mean per group'])

#     df = au.extract_plot(df, f'Episode {plot}', 'training_iteration')

#     df = df[df['training_iteration'].isin(list(range(0, 10000, 100)))]

#     g = au.nature_relplot_curve(
#         data=df,
#         x='training_iteration',
#         y=f'Episode {plot}',
#         hue='Rule', style='Rule',
#         hue_order=['PC', 'BP'],
#         style_order=['PC', 'BP'],
#         col='Env',
#         aspect=0.8,
#         sharey=False
#     )

#     au.nature_post(g, is_grid=True)

#     return df


class RunningStats(object):
    """Computes running mean and standard deviation (Welford-style updates).

    Url: https://gist.github.com/wassname/a9502f562d4d3e73729dc5b184db2501
    Adapted from:
        *
        <http://stackoverflow.com/questions/1174984/how-to-efficiently-\
calculate-a-running-standard-deviation>
        * <http://mathcentral.uregina.ca/QQ/database/QQ.09.02/carlos1.html>
        * <https://gist.github.com/fvisin/5a10066258e43cf6acfa0a474fcdb59f>

    Pushed values may be Python/NumPy scalars or NumPy arrays. With
    ``per_dim=True`` (default) each pushed value is one sample and statistics
    are tracked element-wise; with ``per_dim=False`` every element of the
    pushed value is treated as a separate scalar sample.

    ``variance()`` is the population variance (``s / n``).

    Usage:
        rs = RunningStats()
        for i in range(10):
            rs += np.random.randn()
            print(rs)
        print(rs.mean, rs.std)
    """

    def __init__(self, n=0., m=None, s=None, per_dim=True):
        self.n = n  # number of samples pushed so far
        self.m = m  # running mean
        self.s = s  # running sum of squared deviations from the mean
        self.per_dim = per_dim

    def clear(self):
        """Forget all samples; ``m`` and ``s`` are re-initialised by the next push."""
        self.n = 0.

    def push(self, x):
        """Add ``x`` as one sample (``per_dim``) or each of its elements as a sample."""
        if self.per_dim:
            self.update_params(x)
        else:
            for el in np.asarray(x).flatten():
                self.update_params(el)

    def update_params(self, x):
        """Update ``n``, ``m`` and ``s`` with the single sample ``x``."""
        self.n += 1
        if self.n == 1:
            # Store a float copy: keeping ``x`` itself would alias the caller's
            # array and would not support float updates for integer input.
            self.m = np.array(x, dtype=float)
            self.s = 0.
        else:
            # Non-mutating updates work for Python floats, NumPy scalars and
            # arrays alike, and never write into an array owned by a caller.
            prev_m = self.m
            self.m = prev_m + (x - prev_m) / self.n
            self.s = self.s + (x - prev_m) * (x - self.m)

    def __add__(self, other):
        """Merge with another ``RunningStats``, or push ``other`` into ``self``.

        The push branch mutates and returns ``self`` so that ``rs += x`` works.
        """
        if not isinstance(other, RunningStats):
            self.push(other)
            return self
        # An empty side contributes nothing (and has ``m is None``).
        if not other.n:
            return self._snapshot()
        if not self.n:
            return other._snapshot()
        sum_ns = self.n + other.n
        prod_ns = self.n * other.n
        delta2 = (other.m - self.m) ** 2.
        return RunningStats(sum_ns,
                            (self.m * self.n + other.m * other.n) / sum_ns,
                            self.s + other.s + delta2 * prod_ns / sum_ns,
                            per_dim=self.per_dim)

    def _snapshot(self):
        """Return an independent copy of this accumulator."""
        if not self.n:
            return RunningStats(per_dim=self.per_dim)
        return RunningStats(self.n, np.copy(self.m), np.copy(self.s),
                            per_dim=self.per_dim)

    @property
    def mean(self):
        return self.m if self.n else 0.0

    def variance(self):
        return self.s / (self.n) if self.n else 0.0

    @property
    def std(self):
        return np.sqrt(self.variance())

    @staticmethod
    def _fmt(value):
        """Format a scalar as ``' 2.4f'``; arrays via ``array2string``; None as 'None'."""
        if value is None:
            return 'None'
        if np.ndim(value) == 0:
            return '{: 2.4f}'.format(float(value))
        return np.array2string(np.asarray(value, dtype=float), precision=4)

    def __repr__(self):
        return '<RunningMean(mean={}, std={}, n={: 2f}, m={}, s={})>'.format(
            self._fmt(self.mean), self._fmt(self.std), self.n,
            self._fmt(self.m), self._fmt(self.s))

    def __str__(self):
        return 'mean={}, std={}'.format(self._fmt(self.mean), self._fmt(self.std))

    def normalize(self, x):
        """Return ``(x - mean) / std``, dividing by 1.0 wherever std is zero.

        The zero-std replacement is per element, so one constant feature does
        not switch off scaling for all the other features.
        """
        std = np.asarray(self.std, dtype=float)
        safe_std = np.where(std > 0, std, 1.0)
        return (x - self.mean) / safe_std


class ReplayBuffer():
    """Fixed-capacity FIFO store of ``(s, a, r, s_prime, done_mask)`` transitions."""

    def __init__(self, buffer_limit, sample_to_device):
        # deque(maxlen=...) evicts the oldest transition once the limit is reached
        self.buffer = collections.deque(maxlen=buffer_limit)
        self.sample_to_device = sample_to_device

    def put(self, transition):
        """Append one ``(s, a, r, s_prime, done_mask)`` tuple."""
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns ``(s, a, r, s_prime, done_mask)`` tensors on
        ``self.sample_to_device``. ``s`` and ``s_prime`` are float32 stacks of
        the stored observations. ``a`` is built from ``[[a], ...]`` (int64 for
        integer actions). ``r`` and ``done_mask`` are float32 with shape
        ``(batch_size, 1)``.

        Raises ``ValueError`` (from ``random.sample``) if ``batch_size`` exceeds
        the number of stored transitions.
        """
        mini_batch = random.sample(self.buffer, batch_size)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for s, a, r, s_prime, done_mask in mini_batch:
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])

        device = self.sample_to_device
        # Stack observation arrays with NumPy first. torch.tensor() on a list
        # of ndarrays goes element by element, which is very slow.
        return (
            torch.as_tensor(np.asarray(s_lst, dtype=np.float32)).to(device),
            torch.tensor(a_lst).to(device),
            torch.tensor(r_lst, dtype=torch.float).to(device),
            torch.as_tensor(np.asarray(s_prime_lst, dtype=np.float32)).to(device),
            torch.tensor(done_mask_lst, dtype=torch.float).to(device),
        )

    def size(self):
        """Number of transitions currently stored."""
        return len(self.buffer)


class Qnet(nn.Module):
    """MLP Q-network, optionally interleaved with predictive-coding layers.

    Layout: ``Linear(num_obs, hidden_size)`` -> act, then ``num_hidden`` x
    (``Linear(hidden_size, hidden_size)`` -> act), then
    ``Linear(hidden_size, num_act)``. Here "act" is ``nn.<acf>()``. When
    ``predictive_coding`` is true, a ``pc.PCLayer()`` is placed directly
    before it (``pc_layer_at='before_acf'``) or after it
    (``pc_layer_at='after_acf'``). Any other ``pc_layer_at`` value adds no
    PC layers.
    """

    def __init__(self, predictive_coding, num_obs, num_act, bias=True, pc_layer_at='before_acf', hidden_size=128, num_hidden=1, acf='Sigmoid'):

        super(Qnet, self).__init__()

        self.predictive_coding = predictive_coding
        self.num_act = num_act

        # Look the activation class up by name rather than eval()-ing a code
        # string. An unknown name still raises AttributeError, as before.
        acf_cls = getattr(nn, acf)

        # input layer
        model = [nn.Linear(num_obs, hidden_size, bias=bias)]
        model.extend(self._activation_block(acf_cls, pc_layer_at))

        # hidden layers
        for _ in range(num_hidden):
            model.append(nn.Linear(hidden_size, hidden_size, bias=bias))
            model.extend(self._activation_block(acf_cls, pc_layer_at))

        # output layer
        model.append(nn.Linear(hidden_size, num_act, bias=bias))

        self.model = nn.Sequential(*model)

    def _activation_block(self, acf_cls, pc_layer_at):
        """Return ``[acf]``, with a PCLayer before/after it when PC is enabled."""
        layers = [acf_cls()]
        if self.predictive_coding and pc_layer_at == 'before_acf':
            layers.insert(0, pc.PCLayer())
        elif self.predictive_coding and pc_layer_at == 'after_acf':
            layers.append(pc.PCLayer())
        return layers

    def forward(self, x):

        return self.model(x)

    def sample_action(self, obs, epsilon):
        """Epsilon-greedy action: uniform random with probability ``epsilon``,
        else the argmax of ``forward(obs)``. Returns a Python int."""

        if random.random() < epsilon:
            return random.randint(0, self.num_act - 1)

        # Only the argmax index is needed, so skip building an autograd graph.
        with torch.no_grad():
            return self.forward(obs).argmax().item()


In [7]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

import predictive_coding as pc

class BaseTrainable:
    """Epsilon-greedy DQN trainer whose Q-network is updated by a ``pc.PCTrainer``.

    ``config`` keys read by this class: ``device``, ``seed``, ``env``,
    ``is_q_target``, ``buffer_limit``, ``is_norm_obs``, ``is_norm_rew``,
    ``top_epsilon``, ``bottom_epsilon``, ``anneal_epsilon_scaler``,
    ``start_learn_at_memory_size``, ``num_learn_epochs_per_eposide``,
    ``batch_size``, ``gamma``, ``is_detach_target``, ``train_on_batch_kwargs``,
    ``interval_compute_episode_reward``, and ``interval_update_target_q``
    (the last only when ``is_q_target`` is true).
    """

    def __init__(self, config):
        self.config = config
        self.device = torch.device(config['device'])
        self.seed = config['seed']
        self.env = None
        self.q = None
        self.q_target = None
        self.memory = None
        self.pc_trainer = None
        self.episode_rewards = []  # returns of the most recent episodes (sliding window)
        self.rs_s = None  # observation normaliser, if enabled
        self.rs_r = None  # reward normaliser, if enabled
        self._iteration = 0  # number of completed episodes; drives epsilon annealing
        self.setup()

    def setup(self):
        """Create the env, Q-network(s), replay buffer, PC trainer and normalisers."""
        self.env = gym.make(self.config['env'])
        # The env receives self.seed on its first reset only (see _reset_env).
        self._env_seeded = False

        num_obs = self.env.observation_space.shape[0]
        num_act = self.env.action_space.n

        # create q
        self.q = Qnet(predictive_coding=True, num_obs=num_obs, num_act=num_act).to(self.device)

        # create q_target
        if self.config['is_q_target']:
            self.q_target = Qnet(predictive_coding=True, num_obs=num_obs, num_act=num_act).to(self.device)
            self.q_target.load_state_dict(self.q.state_dict(), strict=False)
        else:
            self.q_target = None

        # create memory
        self.memory = ReplayBuffer(
            buffer_limit=self.config['buffer_limit'],
            sample_to_device=self.device,
        )

        # create pc_trainer
        self.pc_trainer = pc.PCTrainer(model=self.q, plot_progress_at=[])

        if self.config['is_norm_obs']:
            self.rs_s = RunningStats()
        if self.config['is_norm_rew']:
            self.rs_r = RunningStats()

    def _current_epsilon(self):
        """Exploration rate, annealed linearly with ``self._iteration`` down to ``bottom_epsilon``."""
        return max(
            self.config['bottom_epsilon'],
            self.config['top_epsilon'] - 0.01 * (self._iteration / self.config['anneal_epsilon_scaler'])
        )

    def _reset_env(self):
        """Reset the env and return the first observation.

        ``self.seed`` is passed only on the first reset. Passing it on every
        reset would restart the env's RNG each time, so every episode would
        begin from the same initial state.
        """
        if getattr(self, '_env_seeded', False):
            obs, _ = self.env.reset()
        else:
            obs, _ = self.env.reset(seed=self.seed)
            self._env_seeded = True
        return obs

    def _maybe_normalize_obs(self, obs):
        """Normalise ``obs`` with the running observation stats, if enabled."""
        return self.rs_s.normalize(obs) if self.rs_s is not None else obs

    def _run_episode(self, epsilon):
        """Play one episode, push every transition into memory, and return the summed reward."""
        s = self._reset_env()
        episode_reward = 0.0

        self.q.eval()
        while True:
            if self.rs_s is not None:
                self.rs_s += s
            obs = self._maybe_normalize_obs(s)
            a = self.q.sample_action(
                obs=torch.from_numpy(obs).float().to(self.device),
                epsilon=epsilon,
            )
            s_prime, r, terminated, truncated, _ = self.env.step(a)
            if self.rs_r is not None:
                self.rs_r += np.asarray([r])
            self.memory.put(
                (
                    obs,
                    a,
                    self.rs_r.normalize(np.asarray([r])).item() if self.rs_r is not None else r,
                    self._maybe_normalize_obs(s_prime),
                    # Only a real termination stops bootstrapping. After a
                    # time-limit truncation, s_prime still has a value.
                    0.0 if terminated else 1.0,
                )
            )
            episode_reward += r
            # The episode ends on termination OR truncation. Ignoring
            # truncation would keep stepping past the env's time limit forever.
            if terminated or truncated:
                if self.rs_s is not None:
                    self.rs_s += s_prime
                return episode_reward
            s = s_prime

    @staticmethod
    def _td_loss(outputs, a, target):
        """Half the summed squared TD error of the Q-values of the taken actions."""
        q_a = outputs.gather(1, a)
        return (q_a - target).pow(2).sum() * 0.5

    def _learn(self):
        """Run ``num_learn_epochs_per_eposide`` PC updates on replayed minibatches."""
        for _ in range(self.config['num_learn_epochs_per_eposide']):
            s, a, r, s_prime, done_mask = self.memory.sample(
                batch_size=self.config['batch_size'],
            )
            max_q_prime_estimator = self.q_target if self.config['is_q_target'] else self.q
            max_q_prime_estimator.eval()
            max_q_prime = max_q_prime_estimator(s_prime).max(1)[0].unsqueeze(1)
            target = r + self.config['gamma'] * max_q_prime * done_mask

            self.q.train()
            self.pc_trainer.train_on_batch(
                s, self._td_loss,
                loss_fn_kwargs={
                    'a': a,
                    'target': target.detach() if self.config['is_detach_target'] else target,
                },
                **self.config['train_on_batch_kwargs'],
            )

    def train(self, num_episodes):
        """Train for ``num_episodes`` episodes.

        Returns ``{last_episode_index: mean return over the reward window}``,
        or ``{}`` when ``num_episodes`` is 0.
        """
        result_dict = {}
        for e in range(num_episodes):
            # Recompute epsilon every episode so it actually anneals as
            # self._iteration grows within a single train() call.
            episode_reward = self._run_episode(self._current_epsilon())

            if self.memory.size() > self.config['start_learn_at_memory_size']:
                self._learn()

            if self.config['is_q_target']:
                if self._iteration % self.config['interval_update_target_q'] == 0 and self._iteration != 0:
                    self.q_target.load_state_dict(self.q.state_dict(), strict=False)

            self.episode_rewards.append(episode_reward)
            if len(self.episode_rewards) > self.config['interval_compute_episode_reward']:
                self.episode_rewards.pop(0)

            print(f'Episode {e}: {episode_reward}')
            result_dict = {e: np.mean(self.episode_rewards)}
            self._iteration += 1
        return result_dict

    def stop(self):
        """Close the environment."""
        self.env.close()


In [8]:
# Experiment configuration read by BaseTrainable. Each key appears exactly
# once; the original literal repeated four keys, and later duplicates
# silently overwrite earlier ones.
config = {
    "device": "cpu",
    "seed": 2024,
    "env": "CartPole-v1",
    "is_q_target": False,
    "buffer_limit": 1000,
    "is_norm_obs": True,
    "is_norm_rew": False,
    'interval_compute_episode_reward': 200,  # window size for the mean episode reward
    'top_epsilon': 0.08,
    'bottom_epsilon': 0.01,
    'anneal_epsilon_scaler': 200,  # epsilon drops by 0.01 per this many episodes
    'start_learn_at_memory_size': 100,
    'gamma': 0.99,
    'batch_size': 64,
    'num_learn_epochs_per_eposide': 1,
    'interval_update_target_q': 20,  # only used when is_q_target is True
    'is_detach_target': True,
    'train_on_batch_kwargs': {
        'is_log_progress': False,
        'is_return_results_every_t': False,
        'is_checking_after_callback_after_t': False,
    },
}
x = BaseTrainable(config)
x.train(10000)

# 'PCTrainer_kwargs': {
#     'update_x_at': 'all',
#     'optimizer_x_fn': 'SGD',
#     'optimizer_x_kwargs': {
#         'lr': 0.05,
#     },
#     'x_lr_discount': 1.0,
#     'x_lr_amplifier': 1.0,
#     'update_p_at': 'last',
#     'optimizer_p_fn': 'SGD',
#     'optimizer_p_kwargs': {
#         'lr': 0.05,
#     },
#     'T': "self.config['T'] if self.config['predictive_coding'] else 1",
#     'plot_progress_at': "[]",
# },

Episode 0: 9.0
Episode 1: 9.0
Episode 2: 10.0
Episode 3: 9.0
Episode 4: 10.0
Episode 5: 9.0
Episode 6: 11.0
Episode 7: 9.0
Episode 8: 9.0
Episode 9: 9.0
Episode 10: 9.0
Episode 11: 9.0
Episode 12: 12.0
Episode 13: 22.0
Episode 14: 9.0
Episode 15: 13.0
Episode 16: 24.0
Episode 17: 10.0
Episode 18: 10.0
Episode 19: 10.0
Episode 20: 10.0
Episode 21: 26.0
Episode 22: 13.0
Episode 23: 10.0
Episode 24: 10.0
Episode 25: 10.0
Episode 26: 10.0
Episode 27: 10.0
Episode 28: 11.0
Episode 29: 12.0
Episode 30: 11.0
Episode 31: 15.0
Episode 32: 13.0
Episode 33: 10.0
Episode 34: 10.0
Episode 35: 13.0
Episode 36: 10.0
Episode 37: 10.0
Episode 38: 10.0
Episode 39: 10.0
Episode 40: 11.0
Episode 41: 12.0
Episode 42: 12.0
Episode 43: 11.0
Episode 44: 14.0
Episode 45: 13.0
Episode 46: 21.0
Episode 47: 14.0
Episode 48: 12.0
Episode 49: 12.0
Episode 50: 14.0
Episode 51: 14.0
Episode 52: 13.0
Episode 53: 13.0
Episode 54: 14.0
Episode 55: 21.0
Episode 56: 15.0
Episode 57: 14.0
Episode 58: 13.0
Episode 59: 17.0


KeyboardInterrupt: 