In [8]:
# TODO.
'''
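REINFORCE (Monte-Carlo policy gradient) on a small grid world: a policy network learns to move a figure to the exit cell.
'''

'\nREINFORCE (Monte-Carlo policy gradient) on a small grid world: a policy network learns to move a figure to the exit cell.\n'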

In [9]:
import numpy as np
import torch
import torch.nn as nn

In [10]:
# GRID INTERFACE GAME
class Grid:
    """An ``m`` x ``n`` board with one movable figure and one exit cell.

    Positions are ``(x, y)`` pairs: ``x`` is bounded by ``m`` (moved by
    'left'/'right') and ``y`` by ``n`` (moved by 'up'/'down', with 'up'
    increasing ``y``).
    """

    def __init__(self, m, n, exit_pos, figure_pos):
        super().__init__()
        self.m, self.n = m, n
        self.exit_pos = exit_pos
        self.figure_pos = figure_pos

    def move(self, direction):
        """Shift the figure one cell in ``direction``.

        A move is applied only while the coordinate it changes stays inside
        the board; otherwise, and for any unrecognised ``direction``, the
        figure stays where it is.
        """
        col, row = self.figure_pos
        if direction == 'up' and row < self.n - 1:
            self.figure_pos = (col, row + 1)
        elif direction == 'down' and row > 0:
            self.figure_pos = (col, row - 1)
        elif direction == 'left' and col > 0:
            self.figure_pos = (col - 1, row)
        elif direction == 'right' and col < self.m - 1:
            self.figure_pos = (col + 1, row)

    def is_at_exit(self):
        """True when the figure occupies the exit cell."""
        return self.figure_pos == self.exit_pos

    def get_state(self, device):
        """Figure position as a float32 tensor of shape (1, 2) placed on ``device``."""
        position = torch.tensor(self.figure_pos, dtype=torch.float32)
        return position.unsqueeze(0).to(device)

In [11]:
class Policy(nn.Module):
    """Two-layer MLP mapping a grid state to a probability distribution over moves.

    Args:
        state_dim: length of the state vector (default 2: the (x, y) position
            produced by ``Grid.get_state``).
        hidden_dim: width of the hidden ReLU layer.
        n_actions: number of discrete actions, i.e. the length of the output
            probability vector.
    """

    def __init__(self, state_dim=2, hidden_dim=16, n_actions=4):
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_actions)

    def forward(self, x):
        """Return action probabilities for ``x``.

        ``x`` may be a batch ``(batch, state_dim)`` or a single state
        ``(state_dim,)``; softmax runs over the last axis, so every output row
        sums to 1.
        """
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        # dim=-1 (not dim=1) so an unbatched 1-D state is also accepted.
        return nn.functional.softmax(x, dim=-1)

In [12]:
# Move names in policy-output order: the sampled index i selects actions[i].
actions = "up down left right".split()

In [5]:
def generate_episode(grid, policy_net, device="cpu", max_episode_len=100, action_names=None):
    """Roll out one episode on ``grid`` by sampling moves from ``policy_net``.

    Yields ``((state, action, reward), log_probs)`` once per move:

    * ``state`` is ``grid.get_state(device)`` before the move, ``action`` is the
      sampled action index (a Python int), and ``reward`` is -0.1, or 0 for the
      move that lands on the exit.
    * ``log_probs`` is ``log`` of the full action-probability vector for that
      step (still attached to the autograd graph), so ``log_probs[action]`` is
      the log-probability of the chosen action.

    When the figure is at the exit (including when it starts there), one final
    terminal sample ``((exit_state, None, 0), log_probs)`` is yielded, where
    ``log_probs`` repeats the last step's value, or is ``None`` if no move was
    made. If ``max_episode_len`` moves are made without reaching the exit, the
    generator stops without a terminal sample.

    Args:
        grid: environment exposing ``get_state(device)``, ``move(name)`` and
            ``is_at_exit()``.
        policy_net: callable mapping a state tensor to action probabilities.
        device: device passed to ``grid.get_state``.
        max_episode_len: maximum number of moves before the episode is cut off.
        action_names: names passed to ``grid.move`` for each action index;
            defaults to the notebook-level ``actions`` list.
    """
    if action_names is None:
        action_names = actions

    state = grid.get_state(device)
    log_probs = None  # stays None when the grid already starts at the exit
    steps_taken = 0
    while not grid.is_at_exit():
        # Checked before moving so at most max_episode_len moves are made.
        if steps_taken >= max_episode_len:
            return  # truncated: no terminal sample
        steps_taken += 1

        action_probs = policy_net(state).squeeze()
        log_probs = torch.log(action_probs)
        # numpy sampling needs a detached CPU copy of the probabilities.
        probs_np = action_probs.detach().cpu().numpy()
        action = int(np.random.choice(len(probs_np), p=probs_np))

        grid.move(action_names[action])
        next_state = grid.get_state(device)
        reward = 0 if grid.is_at_exit() else -0.1

        yield (state, action, reward), log_probs
        state = next_state

    # Terminal sample: exit state with no action taken from it.
    yield (grid.get_state(device), None, 0), log_probs

In [6]:
def gradients_wrt_params(
    net: torch.nn.Module, loss_tensor: torch.Tensor
):
    """Store d(loss_tensor)/d(param) in ``param.grad`` for every parameter of ``net``.

    Existing ``.grad`` values are overwritten, not accumulated. The autograd
    graph is retained, so ``loss_tensor`` can be differentiated again later.

    Args:
        net: module whose parameters receive the gradients.
        loss_tensor: scalar tensor to differentiate.
    """
    params = list(net.parameters())
    # torch.autograd.grad returns the gradients without touching .grad; a single
    # call covers all parameters instead of one backward pass per parameter.
    grads = torch.autograd.grad(loss_tensor, params, retain_graph=True)
    for param, g in zip(params, grads):
        param.grad = g

def update_params(net: torch.nn.Module, lr: float) -> None:
    """Gradient-ascent step: ``param += lr * param.grad`` for each parameter of ``net``.

    ``lr`` may be a Python number or a 0-dim tensor. Parameters whose ``.grad``
    is None are skipped, as ``torch.optim`` optimizers do.
    """
    # An in-place update under no_grad is the supported alternative to mutating .data.
    with torch.no_grad():
        for param in net.parameters():
            if param.grad is None:
                continue
            param.add_(lr * param.grad)

In [14]:
# Prefer an accelerator when present: CUDA first, then Apple-silicon MPS, else CPU.
# torch.backends.mps.is_available() is the documented MPS check and exists on
# older torch releases that lack torch.mps.is_available().
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
device

'mps'

In [7]:
from tqdm import tqdm

policy_net = Policy()
policy_net.to(device)

lengths = []  # samples yielded per episode (includes the terminal sample when the exit is reached)
rewards = []  # discounted return G_t used for every per-step update

gamma = 0.99
lr_policy_net = 2**-13
# Only optimizer.zero_grad() is used below; parameters are updated manually by update_params.
optimizer = torch.optim.Adam(policy_net.parameters(), lr=lr_policy_net)

prefix = "reinforce-per-step"

for episode_num in tqdm(range(2500)):
    # NOTE(review): get_good_starting_grid is not defined anywhere in this notebook;
    # presumably it returns a fresh Grid — define it before running this cell.
    grid = get_good_starting_grid()
    episode = list(generate_episode(grid, policy_net=policy_net, device=device))
    lengths.append(len(episode))

    # Keep real transitions only. The terminal sample carries action None, and a
    # truncated episode has no terminal sample, so slicing off the last item would
    # drop a genuine step there.
    transitions = [
        (sample, log_probs) for sample, log_probs in episode if sample[1] is not None
    ]

    # Discounted returns from the rewards the episode actually produced:
    # G_t = r_t + gamma * G_{t+1}, computed backwards in a single pass.
    returns = [0.0] * len(transitions)
    running_return = 0.0
    for t in reversed(range(len(transitions))):
        (_, _, reward), _ = transitions[t]
        running_return = reward + gamma * running_return
        returns[t] = running_return

    for t, (((state, action, reward), log_probs), G) in enumerate(zip(transitions, returns)):
        rewards.append(G)
        policy_loss = log_probs[action]
        optimizer.zero_grad()
        gradients_wrt_params(policy_net, policy_loss)
        # REINFORCE ascent step: theta += lr * gamma^t * G_t * grad log pi(a_t | s_t)
        update_params(policy_net, lr_policy_net * G * gamma**t)

NameError: name 'PolicyNet' is not defined