# Actor Critic Algorithm test

In [505]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [506]:
import logging
import os
import random
import sys
import warnings
from itertools import accumulate

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

ROOT_FOLDER = os.path.join(".", "..")
if ROOT_FOLDER not in sys.path:
    sys.path.insert(0, ROOT_FOLDER)


from dataset import RegexDataset

# from environment import Environment, EnvSettings
from environment_metrics import Environment, EnvSettings

warnings.filterwarnings("ignore")
logging.disable(logging.WARNING)
torch.backends.cudnn.deterministic = True
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cpu')

## Utils

In [507]:
def set_seed(seed: int = 420):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

## Dataset

In [508]:
dataset = RegexDataset(["a2d", "2bb"], r"\d+")
data_iter = dataset.create_iterator()

for i in range(10):
    print(next(data_iter))

('a2d', [0, 1, 0], 1)
('2bb', [1, 0, 0], 1)
('a2d', [0, 1, 0], 1)
('2bb', [1, 0, 0], 1)
('2bb', [1, 0, 0], 1)
('a2d', [0, 1, 0], 1)
('2bb', [1, 0, 0], 1)
('a2d', [0, 1, 0], 1)
('a2d', [0, 1, 0], 1)
('2bb', [1, 0, 0], 1)


## Environment

In [509]:
env = Environment(dataset, settings=EnvSettings(max_steps=2))

env.action_space

101

In [510]:
state = env.reset()
for _ in range(6):
    action = np.random.randint(env.action_space)
    print(f"{action=}")
    print(env.step(action))

action=72
(array([0.28712871, 0.        ]), 0, False)
action=6
(array([0., 0.]), -134.63333333333335, True)
action=53
(array([0.47524752, 0.        ]), 0, False)
action=19
(array([0., 0.]), -100, True)
action=81
(array([0.1980198, 0.       ]), 0, False)
action=79
(array([0., 0.]), -134.63333333333335, True)


## Advantage Actor-Critic

In [511]:
def calculate_qvals(
    rewards: list[float] | np.ndarray, gamma: float = 1.0, reward_steps: int = 0
) -> np.ndarray:
    rw_steps = reward_steps if reward_steps != 0 else len(rewards)

    return np.array(
        [
            list(
                accumulate(
                    reversed(rewards[i : i + rw_steps]), lambda x, y: gamma * x + y
                )
            )[-1]
            for i in range(len(rewards))
        ]
    )

In [512]:
class A2CNet(nn.Module):
    def __init__(
        self,
        input_dim: int = env.state_space,
        output_dim: int = env.action_space,
        hidden_dim: int = 128,
    ) -> None:
        super(A2CNet, self).__init__()

        self.body = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
        )

        self.policy = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

        self.value = nn.Sequential(
            nn.Linear(hidden_dim, 1),
        )

        nn.init.xavier_uniform_(self.policy[-1].weight, gain=0.01)

    def forward(self, x):
        body_out = self.body(x)
        return self.policy(body_out), self.value(body_out)

## Agent

In [513]:
class Agent:
    def __init__(self, temperature_coefficient: float = 10.0):
        self.temperature_coefficient = temperature_coefficient

    def choose_action(self, action_logits: torch.Tensor, epoch: int):
        temperature = (
            1 / epoch * torch.max(torch.abs(action_logits)) * self.temperature_coefficient
            if self.temperature_coefficient > 0
            else 1
        )

        return np.random.choice(
            range(len(action_logits)),
            size=1,
            p=F.softmax(action_logits / temperature, dim=0).numpy(),
        )[0]

    def choose_optimal_action(self, action_logits: torch.Tensor) -> int:
        return int(np.argmax(F.softmax(action_logits, dim=0).cpu()).item())

## Trajectory Buffer

In [514]:
class TrajectoryBuffer:
    """
    Buffer class to store the experience from a unique policy
    """

    def _batch(self, iterable):
        ln = len(iterable)
        for ndx in range(0, ln, self.batch_size):
            yield iterable[ndx : min(ndx + self.batch_size, ln)]

    def __init__(self, batch_size: int = 64):
        self.batch_size = batch_size
        self.clean()

    def clean(self):
        self.states = []
        self.actions = []
        self.discounted_rewards = []

    def store(
        self,
        states_trajectory: np.ndarray,
        trajectory: np.ndarray,
    ):
        """
        Add trajectory values to the buffers and compute the advantage and reward to go

        Parameters:
        -----------
        states_trajectory:  list that contains states
        trajectory: list where each element is a list that contains: reward, action
        """
        assert len(states_trajectory) == len(trajectory)

        if len(states_trajectory) > 0:
            self.states.extend(states_trajectory)
            self.actions.extend(trajectory[:, 1])

            self.discounted_rewards.extend(calculate_qvals(trajectory[:, 0]))

    def get_batches(self, mean_baseline: bool):
        mean_rewards = np.mean(self.discounted_rewards) if mean_baseline else 0

        for states_batch, actions_batch, discounted_rewards_batch in zip(
            self._batch(self.states),
            self._batch(self.actions),
            self._batch(self.discounted_rewards),
        ):
            mean_batch_reward = np.mean(discounted_rewards_batch) if mean_baseline else 0
            yield (
                torch.tensor(states_batch, dtype=torch.float32, device=DEVICE),
                torch.tensor(actions_batch, dtype=torch.long, device=DEVICE),
                torch.tensor(
                    np.array(discounted_rewards_batch) - mean_rewards,
                    # np.array(discounted_rewards_batch) - mean_batch_reward,
                    dtype=torch.float,
                    device=DEVICE,
                ),
            )

    def __len__(self):
        return len(self.states)

## Training

In [515]:
def fill_buffer(
    a2c_net: nn.Module, agent: Agent, buffer: TrajectoryBuffer, episodes: int, epoch: int
):
    buffer.clean()
    state = env.reset()
    done_episodes = 0
    ep_states_buf, ep_rew_act_buf = [], []

    train_rewards = []

    epoch_loop = tqdm(total=episodes, desc=f"Epoch #{epoch}", position=0, disable=True)

    with torch.no_grad():
        while done_episodes < episodes:
            state_tensor = torch.tensor(state, dtype=torch.float32, device=DEVICE)

            action_logits, _ = a2c_net(state_tensor)

            action = agent.choose_action(action_logits, epoch=epoch)
            next_state, reward, done = env.step(action)

            ep_states_buf.append(state)
            ep_rew_act_buf.append([reward, int(action)])

            state = next_state

            if done:
                buffer.store(
                    np.array(ep_states_buf),
                    np.array(ep_rew_act_buf),
                )

                ep_states_buf, ep_rew_act_buf = [], []

                train_rewards.append(reward)

                done_episodes += 1
                epoch_loop.update(1)

    return train_rewards


def train(
    a2c_net: nn.Module,
    a2c_optimizer: optim.Optimizer,
    buffer: TrajectoryBuffer,
    mean_baseline: bool = True,
    entropy_beta: float = 1e-3,
    clip_grad: float = 10,
):
    a2c_net.train()
    losses = []
    entropies = []
    for batch in buffer.get_batches(mean_baseline):
        a2c_optimizer.zero_grad()
        (
            state_batch,
            action_batch,
            reward_batch,
        ) = batch

        logits_v, value_v = a2c_net(state_batch)

        # Value loss
        loss_value_v = F.mse_loss(value_v.squeeze(-1), reward_batch)

        # Policy loss
        log_prob_v = F.log_softmax(logits_v, dim=1)
        adv_v = reward_batch - value_v.detach()
        log_prob_actions_v = adv_v * log_prob_v[range(len(state_batch)), action_batch]
        loss_policy_v = -log_prob_actions_v.mean()

        # Entropy loss
        prob_v = F.softmax(logits_v, dim=1)
        entropy_v = -(prob_v * log_prob_v).sum(dim=1).mean()
        entropy_loss_v = entropy_beta * entropy_v
        loss_policy_v = loss_policy_v - entropy_loss_v

        # Policy backward
        loss_policy_v.backward(retain_graph=True)

        # Value backward
        loss_v = loss_value_v - entropy_loss_v
        loss_v.backward()

        if clip_grad > 0:
            nn.utils.clip_grad_norm_(a2c_net.parameters(), clip_grad)

        a2c_optimizer.step()

        losses.append(loss_v.item() + loss_policy_v.item())
        entropies.append(entropy_v.item())

    return losses, entropies


def evaluate(
    a2c_net: nn.Module,
    env: Environment,
    agent: Agent,
) -> tuple[str, float]:
    a2c_net.eval()
    max_steps = env.settings.max_steps
    regex_actions = []
    total_reward = 0

    state = env.reset()
    with torch.no_grad():
        for i in range(len(env)):
            regex_actions = []
            for _ in range(max_steps):
                state_tensor = torch.tensor(state, dtype=torch.float32, device=DEVICE)
                action_logits, _ = a2c_net(state_tensor)

                action = agent.choose_optimal_action(action_logits)
                regex_actions.append(env.idx_to_action(action))

                next_state, reward, done = env.step(action)

                state = next_state
                if done:
                    total_reward += reward
                    break

    if regex_actions and regex_actions[-1] == env._finish_action:
        regex_actions = regex_actions[:-1]

    try:
        regex = env.rpn.to_infix(regex_actions)
    except BaseException:
        regex = f"Invalid: {regex_actions}"

    return regex, total_reward


def train_eval_loop(
    a2c_net: nn.Module,
    a2c_optimizer: optim.Optimizer,
    agent: Agent,
    buffer: TrajectoryBuffer,
    epochs: int,
    episodes: int,
    mean_baseline: bool = True,
    entropy_beta: float = 0.5,
    eval_period: int = 5,
    clip_grad: float = 10,
):
    set_seed()

    for i in range(1, epochs + 1):
        train_rewards = fill_buffer(a2c_net, agent, buffer, episodes, epoch=i)

        losses, entropies = train(
            a2c_net, a2c_optimizer, buffer, mean_baseline, entropy_beta, clip_grad
        )

        print(
            f"Epoch {i: >3}/{epochs}:"
            f"\tReward: {np.mean(train_rewards):.1f}"
            f"\tLoss: {np.mean(losses):.3f}"
            f"\tEntropy: {np.mean(entropies):.3f}"
        )

        if (i % eval_period == 0) or (eval_period == (epochs + 1)):
            built_regex, total_reward = evaluate(a2c_net, env, agent)

            print(f"\nEVALUATION\nRegex: {built_regex}\nTotal reward: {total_reward}\n")

In [516]:
def fill_buffer_pre_train(buffer: TrajectoryBuffer, action_str: str):
    buffer.clean()
    state = env.reset()
    ep_states_buf, ep_rew_act_buf = [], []
    fin_action = env.action_to_idx(env._finish_action)

    for _ in range(len(env)):
        action = env.action_to_idx(action_str)

        next_state, reward, done = env.step(action)
        ep_states_buf.append(state)
        ep_rew_act_buf.append([reward, int(action)])

        state = next_state

        next_state, reward, done = env.step(fin_action)
        ep_states_buf.append(state)
        ep_rew_act_buf.append([reward, int(action)])

        buffer.store(
            np.array(ep_states_buf),
            np.array(ep_rew_act_buf),
        )


def train_pre_train(
    a2c_net: nn.Module,
    a2c_optimizer: optim.Optimizer,
    buffer: TrajectoryBuffer,
    mean_baseline: bool = True,
    entropy_beta: float = 1e-3,
):
    a2c_net.train()
    losses = []
    entropies = []
    for batch in buffer.get_batches(mean_baseline):
        a2c_optimizer.zero_grad()
        (
            state_batch,
            action_batch,
            reward_batch,
        ) = batch

        logits_v, value_v = a2c_net(state_batch)

        # Value loss
        loss_value_v = F.mse_loss(value_v.squeeze(-1), reward_batch)

        # Policy loss
        log_prob_v = F.log_softmax(logits_v, dim=1)
        adv_v = reward_batch - value_v.detach()
        log_prob_actions_v = adv_v * log_prob_v[range(len(state_batch)), action_batch]
        loss_policy_v = -log_prob_actions_v.mean()

        # Entropy loss
        prob_v = F.softmax(logits_v, dim=1)
        entropy_v = -(prob_v * log_prob_v).sum(dim=1).mean()
        entropy_loss_v = entropy_beta * entropy_v
        loss_policy_v = loss_policy_v - entropy_loss_v

        # Policy backward
        loss_policy_v.backward(retain_graph=True)

        # Value backward
        loss_v = loss_value_v - entropy_loss_v
        loss_v.backward()

        a2c_optimizer.step()

        losses.append(loss_v.item() + loss_policy_v.item())
        entropies.append(entropy_v.item())

    return losses, entropies


def pre_train_eval_loop(
    a2c_net: nn.Module,
    a2c_optimizer: optim.Optimizer,
    mean_baseline: bool = True,
    epochs: int = 10,
):
    set_seed()
    for i in tqdm(range(1, epochs + 1)):
        buffer = TrajectoryBuffer()
        for action in env.actions:
            if action == env._finish_action:
                continue
            fill_buffer_pre_train(buffer, action)

            train_pre_train(
                a2c_net, a2c_optimizer, buffer, mean_baseline, entropy_beta=1.0
            )

In [None]:
set_seed()
env = Environment(
    # RegexDataset(["a2d", "2bb", "cc2", "d3d"], r"\d+"), settings=EnvSettings(max_steps=5)
    RegexDataset(["a2d", "2bb", "cc2"], r"\d+"),
    settings=EnvSettings(max_steps=5),
)

agent = Agent(temperature_coefficient=0)
buffer = TrajectoryBuffer(batch_size=64)

a2c_net = A2CNet(input_dim=env.state_space, output_dim=env.action_space).to(DEVICE)
a2c_optimizer = optim.Adam(a2c_net.parameters(), lr=1e-4)

In [518]:
pre_train_eval_loop(a2c_net, a2c_optimizer, mean_baseline=True, epochs=10)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:08<00:00,  1.21it/s]


In [519]:
train_eval_loop(
    a2c_net,
    a2c_optimizer,
    agent,
    buffer,
    epochs=300,
    episodes=1000,
    eval_period=20,
    entropy_beta=0.01,
    # entropy_beta=100,
    clip_grad=10,
)

Epoch   1/300:	Reward: -111.1	Loss: 373.439	Entropy: 1.633
Epoch   2/300:	Reward: -109.9	Loss: 422.268	Entropy: 1.619
Epoch   3/300:	Reward: -107.5	Loss: 387.965	Entropy: 1.597
Epoch   4/300:	Reward: -106.6	Loss: 325.420	Entropy: 1.562
Epoch   5/300:	Reward: -105.4	Loss: 225.525	Entropy: 1.502
Epoch   6/300:	Reward: -103.0	Loss: 116.297	Entropy: 1.411
Epoch   7/300:	Reward: -102.6	Loss: 127.671	Entropy: 1.294
Epoch   8/300:	Reward: -101.5	Loss: 67.792	Entropy: 1.163
Epoch   9/300:	Reward: -100.8	Loss: 17.312	Entropy: 1.028
Epoch  10/300:	Reward: -100.0	Loss: 50.371	Entropy: 0.942
Epoch  11/300:	Reward: -100.4	Loss: 29.949	Entropy: 0.879
Epoch  12/300:	Reward: -100.4	Loss: 9.926	Entropy: 0.822
Epoch  13/300:	Reward: -100.5	Loss: 11.377	Entropy: 0.773
Epoch  14/300:	Reward: -100.1	Loss: 24.127	Entropy: 0.710
Epoch  15/300:	Reward: -100.3	Loss: 8.382	Entropy: 0.683
Epoch  16/300:	Reward: -100.1	Loss: 2.469	Entropy: 0.645
Epoch  17/300:	Reward: -100.0	Loss: 25.059	Entropy: 0.630
Epoch  18/