# Toy REINFORCE playground

In [47]:
import random
from itertools import accumulate

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [48]:
torch.cuda.is_available()

True

## Utilities

In [49]:
GAMMA = 0.99

In [50]:
def calculate_qvals(rewards: list[float], gamma: float = GAMMA) -> list[float]:
    return list(reversed(list(accumulate(reversed(rewards), lambda x, y: gamma * x + y))))

## Environment

In [51]:
class BlackJack:
    def _get_reward(self) -> float:
        if self.score < self.win_score:
            return self.step_penalty
        if self.score == self.win_score:
            return self.win_reward
        return self.loose_reward

    def __init__(self) -> None:
        self.actions_dict = {0: 1, 1: 5, 2: 10}
        self.win_score = 101

        self.win_reward = 1000
        # self.win_reward = 0
        self.step_penalty = -1
        self.loose_reward = -100

        self.reset()

    def reset(self):
        self.score = random.randint(0, self.win_score - 1)
        # self.score = 85
        # self.score = 91
        self.steps = 0

    def get_state(self) -> list[int]:
        return [self.score]

    def is_terminal(self) -> bool:
        return self.score >= self.win_score

    def interact(self, action: int) -> tuple[list[int], float]:
        if self.is_terminal():
            return [self.score], 0
        self.score += self.actions_dict[action]
        self.steps += 1

        return [self.score], self._get_reward()

    def get_observation_shape(self) -> int:
        return 1

    def get_actions_shape(self) -> int:
        return len(self.actions_dict)

In [52]:
test_env = BlackJack()

In [53]:
print(test_env.interact(0))
print(test_env.is_terminal())

([42], -1)
False


## Policy Gradient Network

In [54]:
class PGN(nn.Module):
    def __init__(self, input_dim: int, output_dim: int) -> None:
        super(PGN, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_dim, 16), nn.ReLU(), nn.Linear(16, output_dim)
        )

    def forward(self, x):
        return self.net(x)

## Agent

In [55]:
action_logits = torch.FloatTensor([1, 2, 3])
random.choices(range(len(action_logits)), F.softmax(action_logits, dim=0))[0]

1

In [56]:
class Agent:
    def choose_action(self, action_logits):
        return random.choices(range(len(action_logits)), F.softmax(action_logits, dim=0))[
            0
        ]

## Training

In [57]:
LEARNING_RATE = 0.01

env = BlackJack()
net = PGN(input_dim=env.get_observation_shape(), output_dim=env.get_actions_shape())
agent = Agent()
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

In [58]:
total_rewards = []
total_steps = []
trajectories = []
losses = []

batch_episodes = 0
cur_rewards = []
batch_states, batch_actions, batch_qvals = [], [], []

In [59]:
EPISODES_TO_TRAIN = 100
EPOCHS = 30

In [67]:
epochs_iterator = 0

while True:
    if epochs_iterator >= EPOCHS:
        break

    state = env.get_state()

    with torch.no_grad():
        action_logits = net(torch.FloatTensor(state))

    # action_logits = net(torch.FloatTensor(state))

    action = agent.choose_action(action_logits)
    _, reward = env.interact(action)

    batch_states.append(state)
    batch_actions.append(int(action))
    cur_rewards.append(reward)

    if env.is_terminal():
        # print(batch_states[-len(cur_rewards):])
        # print(batch_actions[-len(cur_rewards):])
        # print(len(cur_rewards))
        # print(calculate_qvals(cur_rewards))
        # print()

        batch_qvals.extend(calculate_qvals(cur_rewards))
        batch_episodes += 1

        total_rewards.append(sum(cur_rewards))
        total_steps.append(env.steps)
        trajectories.append(batch_states[-len(cur_rewards) :])

        cur_rewards.clear()

        env.reset()

    if batch_episodes < EPISODES_TO_TRAIN:
        continue
    epochs_iterator += 1

    optimizer.zero_grad()
    states_v = torch.FloatTensor(batch_states)
    batch_actions_t = torch.LongTensor(batch_actions)
    batch_qvals_t = torch.FloatTensor(batch_qvals)

    logits_v = net(states_v)
    log_prob_v = F.log_softmax(logits_v, dim=1)
    # print(logits_v)
    # print(log_prob_v)
    # print(log_prob_v[range(len(batch_states)), batch_actions_t])
    # print(log_prob_v)
    log_prob_actions_v = (
        batch_qvals_t * log_prob_v[range(len(batch_states)), batch_actions_t]
    )
    loss_v = -log_prob_actions_v.mean()
    print(loss_v)

    loss_v.backward()
    optimizer.step()

    losses.append(loss_v.item())

    batch_episodes = 0
    batch_states.clear()
    batch_actions.clear()
    batch_qvals.clear()

tensor(-0.0278, grad_fn=<NegBackward0>)
tensor(3.3351, grad_fn=<NegBackward0>)
tensor(1.4330, grad_fn=<NegBackward0>)
tensor(0.2716, grad_fn=<NegBackward0>)
tensor(-0.0563, grad_fn=<NegBackward0>)
tensor(0.3396, grad_fn=<NegBackward0>)
tensor(1.3009, grad_fn=<NegBackward0>)
tensor(1.2361, grad_fn=<NegBackward0>)
tensor(-0.1844, grad_fn=<NegBackward0>)
tensor(0.5038, grad_fn=<NegBackward0>)
tensor(0.5296, grad_fn=<NegBackward0>)
tensor(0.2015, grad_fn=<NegBackward0>)
tensor(1.9507, grad_fn=<NegBackward0>)
tensor(0.0146, grad_fn=<NegBackward0>)
tensor(-0.4242, grad_fn=<NegBackward0>)
tensor(0.7634, grad_fn=<NegBackward0>)
tensor(1.4597, grad_fn=<NegBackward0>)
tensor(1.8858, grad_fn=<NegBackward0>)
tensor(1.4237, grad_fn=<NegBackward0>)
tensor(-0.3912, grad_fn=<NegBackward0>)
tensor(0.6483, grad_fn=<NegBackward0>)
tensor(4.3323, grad_fn=<NegBackward0>)
tensor(1.5298, grad_fn=<NegBackward0>)
tensor(-0.0066, grad_fn=<NegBackward0>)
tensor(0.2961, grad_fn=<NegBackward0>)
tensor(-0.0038, gra

In [68]:
print(total_rewards[-10:])
# print(losses[-10:])
print([x[0] for x in trajectories[-10:]])
print(total_steps[-10:])

[-107, -101, -116, -108, 980, -112, -106, -101, -108, 993]
[[62], [94], [19], [57], [0], [40], [69], [94], [57], [61]]
[8, 2, 17, 9, 21, 13, 7, 2, 9, 8]
