# Assignment 2

Instructions: Implement both PG (policy gradient) and an evolutionary algorithm to solve the OpenAI Gym Lunar Lander problem, and then apply them to my area of choice, which is chess.

First, we need to do some setup.

In [None]:
import torch
import numpy as np
import gym

# Set the device
if torch.cuda.is_available():
    device = "cuda" # 🧮
# elif torch.backends.mps.is_available():
#     device = "mps" # 🧠
else:
    device = "cpu" # 🥺
    
print(f"Using device: {device}")

Next, we need to write the code for our policy gradient method with a baseline (REINFORCE with baseline). I'm going to use PyTorch as my neural network library (I want to try JAX, but PyTorch is the more practical choice for me at the moment. Exploration-exploitation tradeoff 🤷‍♂️).

I'm going to start with a basic feed-forward net both for the network that chooses the policy and for the network that learns state values.

First, the policy network for choosing actions:

In [None]:
from torch import nn

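class PolicyChoice(nn.Module):
    """Maps an 8-dim Lunar Lander observation to logits over the 4 discrete actions."""
    def __init__(self):
        super(PolicyChoice, self).__init__()
        self.layer1 = torch.nn.Linear(8, 8)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(8, 4)
        self.relu2 = torch.nn.ReLU()
        self.layer3 = torch.nn.Linear(4, 4)  # raw logits; softmax is applied when sampling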

    def forward(self, x):
        x_weighted1 = self.layer1(x)
        h1 = self.relu1(x_weighted1)
        x_weighted2 = self.layer2(h1)
        h2 = self.relu2(x_weighted2)
        logits = self.layer3(h2)
        return logits

policy_model = PolicyChoice().to(device)
policy_adam = torch.optim.Adam(policy_model.parameters(), 1e-3)
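The training loops below turn these logits into an action with `softmax` followed by `torch.multinomial`. As a sketch (random, hypothetical logits stand in for `policy_model`'s output), `torch.distributions.Categorical` describes the same distribution, and its `log_prob` is exactly the negative of the per-sample cross-entropy. That is why a cross-entropy loss can stand in for $-\ln \pi(A_t|S_t, \theta)$ in the policy loss:

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical logits for a batch of 3 states and the 4 Lunar Lander actions
logits = torch.randn(3, 4)

# Softmax + multinomial sampling, as in the training loops
probs = F.softmax(logits, dim=1)
sampled = torch.multinomial(probs, 1).squeeze(1)

# Categorical wraps the same distribution and exposes log-probabilities
dist = Categorical(logits=logits)
log_probs = dist.log_prob(sampled)

# Per-sample cross-entropy against the chosen action is -ln pi(a|s)
ce = F.cross_entropy(logits, sampled, reduction="none")

print(torch.allclose(dist.probs, probs))
print(torch.allclose(-log_probs, ce))
```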

For the policy network's loss, we want to adjust the parameters mainly so as to change the probability of the action we actually took at that time step. If the return following that action is better than expected (i.e., higher than the baseline), we want to increase that probability in proportion to the difference; if it is worse than expected, we want to decrease it proportionally. Thus, we multiply the gradient of the taken action's probability with respect to the parameters by the difference between the return and the baseline for that state-action pair.

There is one more factor we must consider, however. We only observe the actions we sample, and the policy picks action $A_t$ with probability $\pi(A_t|S_t, \theta)$, so without a correction, frequently chosen actions would be pushed up more often simply because they are chosen more often. To correct for this, we divide the gradient by the action's probability. The result is the gradient of the action's probability (conditioned on the state and parameters) divided by that same probability.
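A small numerical check of this argument, assuming a softmax policy over four actions with made-up preferences `theta` and made-up per-action returns `q` (both hypothetical). For a softmax, $\nabla\pi(a) = \pi(a)(e_a - \pi)$, so dividing by $\pi(a)$ leaves $e_a - \pi$. Averaging that over sampled actions recovers the exact all-actions sum $\sum_a q(a)\nabla\pi(a)$, while skipping the division gives a visibly different (biased) direction:

```python
import numpy as np

rng = np.random.default_rng(0)

theta = np.array([0.5, -0.2, 1.0, 0.0])  # hypothetical action preferences (logits)
q = np.array([1.0, 3.0, -2.0, 0.5])      # hypothetical return for each action

pi = np.exp(theta - theta.max())
pi /= pi.sum()

n_actions = len(theta)
eye = np.eye(n_actions)

# Exact expected update: sum over all actions of q(a) * grad pi(a),
# where grad pi(a) = pi(a) * (e_a - pi) for a softmax policy
exact = sum(q[a] * pi[a] * (eye[a] - pi) for a in range(n_actions))

# Sample actions from the policy, as the agent does
actions = rng.choice(n_actions, size=200_000, p=pi)
one_hot = eye[actions]

# With the division: q(A) * grad pi(A) / pi(A) = q(A) * (e_A - pi)
estimate = np.mean(q[actions, None] * (one_hot - pi), axis=0)

# Without the division: frequently sampled actions are over-weighted
biased = np.mean(q[actions, None] * pi[actions, None] * (one_hot - pi), axis=0)

print("exact:   ", exact.round(3))
print("estimate:", estimate.round(3))
print("biased:  ", biased.round(3))
```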

Thus, the general form of the update to backpropagate in the REINFORCE algorithm with a baseline is:


$\Large (G_t - \hat{\upsilon}) \frac{\nabla\pi(A_t|S_t, \theta)}{\pi(A_t|S_t, \theta)}$

Because $\nabla \ln x = \frac{\nabla x}{x}$, this can be expressed as:

$\Large (G_t - \hat{\upsilon}) \nabla \ln{\pi(A_t|S_t, \theta)}$
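The identity can be checked with autograd. This sketch assumes a hypothetical linear softmax policy (an 8×4 weight matrix `theta` and a random 8-dim state, matching Lunar Lander's sizes) rather than the network above:

```python
import torch

torch.manual_seed(0)

# Hypothetical linear softmax policy: 8-dim state, 4 actions
theta = torch.randn(8, 4, dtype=torch.float64, requires_grad=True)
state = torch.randn(8, dtype=torch.float64)
action = 2

def pi(theta):
    """pi(A_t | S_t, theta) for the chosen action."""
    return torch.softmax(state @ theta, dim=0)[action]

# Left-hand side: gradient of pi divided by pi
grad_pi, = torch.autograd.grad(pi(theta), theta)
lhs = grad_pi / pi(theta).detach()

# Right-hand side: gradient of ln pi
grad_log_pi, = torch.autograd.grad(torch.log(pi(theta)), theta)

print(torch.allclose(lhs, grad_log_pi))
```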


The code below only computes the loss, not the gradient, as PyTorch's autograd handles differentiation behind the scenes.

In [None]:
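def policy_loss(logits, action_chosen, state_util_difference):
    # reduction='none' keeps one cross-entropy term per timestep, so each can be
    # scaled by its own advantage (the default 'mean' would average them first)
    ce_loss = nn.CrossEntropyLoss(reduction='none')
    # NOTE: RL literature typically describes this as gradient ascent on
    # (G_t - v_hat) * ln pi(A_t|S_t). Cross-entropy already equals -ln pi(A_t|S_t),
    # so multiplying it by the advantage (no extra -1) gives a loss we minimize
    # with ordinary gradient descent
    # TODO: if working and I have time, refactor to gradient ascent
    neg_log_probs = ce_loss(logits, action_chosen)
    # flatten to (N,) so the product is elementwise, and treat the advantage as a constant
    advantages = state_util_difference.reshape(-1).detach()
    action_scaled = neg_log_probs * advantages
    return action_scaled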
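A quick sign check, assuming a single hypothetical state with random logits and plain SGD rather than the notebook's Adam optimizer: because `F.cross_entropy` already returns $-\ln \pi(A_t|S_t, \theta)$, one gradient-descent step on cross-entropy times the advantage should raise the chosen action's probability when the advantage is positive and lower it when the advantage is negative.

```python
import torch
import torch.nn.functional as F

def prob_after_step(advantage, lr=0.1):
    """One SGD step on CE(logits, a) * advantage; returns pi(a) before and after."""
    torch.manual_seed(0)
    logits = torch.randn(1, 4, requires_grad=True)  # hypothetical logits for one state
    action = torch.tensor([1])                      # hypothetical chosen action
    opt = torch.optim.SGD([logits], lr=lr)

    before = F.softmax(logits, dim=1)[0, action].item()
    # cross-entropy is -ln pi(a|s), so this is the negated REINFORCE objective
    loss = (F.cross_entropy(logits, action, reduction="none") * advantage).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    after = F.softmax(logits, dim=1)[0, action].item()
    return before, after

print(prob_after_step(advantage=2.0))   # (before, after) for a better-than-expected return
print(prob_after_step(advantage=-2.0))  # (before, after) for a worse-than-expected return
```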

Now, the network for approximating state utilities.

In [None]:
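class StateUtility(nn.Module):
    """Maps an 8-dim Lunar Lander observation to a scalar estimate of the state's value."""
    def __init__(self):
        super(StateUtility, self).__init__()
        self.layer1 = torch.nn.Linear(8, 4)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(4, 2)
        self.relu2 = torch.nn.ReLU()
        self.layer3 = torch.nn.Linear(2, 1)  # single output: the predicted state value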
        # self.relu3 = torch.nn.ReLU()
        # self.layer4 = torch.nn.Linear(2, 1)

    def forward(self, x):
        x_weighted = self.layer1(x)
        h1 = self.relu1(x_weighted)
        h1_weighted = self.layer2(h1)
        h2 = self.relu2(h1_weighted)
        # h2_weighted = self.layer3(h2)
        # h3 = self.relu3(h2_weighted)
        state_utility = self.layer3(h2)
        return state_utility

state_util_model = StateUtility().to(device)
state_util_adam = torch.optim.Adam(params=state_util_model.parameters(), lr=1e-1)

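For the state-utility network, we use L1 loss between the predicted state utility and the observed return. The update below is the negative gradient, with respect to $W$, of the squared error $\frac{1}{2}(G_t - \hat{\upsilon}(S_t, W))^2$; with L1 loss, the factor $(G_t - \hat{\upsilon}(S_t, W))$ is replaced by $\operatorname{sign}(G_t - \hat{\upsilon}(S_t, W))$, so every error produces a step of the same size.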

$\Large (G_t - \hat{\upsilon}(S_t, W)) \nabla \hat{\upsilon}(S_t, W)$
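To see how the two losses differ, this sketch compares their gradients with respect to the prediction on a few hypothetical returns and predictions. The squared error's gradient is $-(G_t - \hat{\upsilon})$, while L1's is $-\operatorname{sign}(G_t - \hat{\upsilon})$; chaining either through $\nabla \hat{\upsilon}(S_t, W)$ gives the corresponding weight update.

```python
import torch

G = torch.tensor([5.0, -1.0, 0.2])                         # hypothetical observed returns G_t
v_hat = torch.tensor([1.0, 2.0, 0.0], requires_grad=True)  # hypothetical predictions

# Squared error 1/2 (G - v_hat)^2: gradient w.r.t. v_hat is -(G - v_hat)
sq = 0.5 * (G - v_hat) ** 2
grad_sq, = torch.autograd.grad(sq.sum(), v_hat)

# L1 |v_hat - G| (same argument order as state_util_loss): gradient is -sign(G - v_hat)
l1 = torch.nn.L1Loss(reduction="none")(v_hat, G)
grad_l1, = torch.autograd.grad(l1.sum(), v_hat)

print(grad_sq)
print(grad_l1)
```

The large error on the first sample dominates the squared-error gradient but contributes a unit-size step under L1.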

As above, the code below only computes the loss, not the gradient, as PyTorch's autograd handles differentiation.

In [None]:
def state_util_loss(calculated_state_value, episode_state_value):
    # calculated_state_value is the network's prediction v_hat(S_t, W);
    # episode_state_value is the observed return G_t, which serves as the target
    l1_loss = nn.L1Loss(reduction='none')  # keep one loss per timestep
    return l1_loss(calculated_state_value, episode_state_value)


Let's define our hyperparameters.

Having limited compute (and long runtimes) led me to reflect on tuning gamma. I initially noticed that discounting more heavily (a smaller gamma) improved results a lot on this Lunar Lander task. My first thought was that this reflects the fact that the task has well-defined, frequent rewards that mostly reflect short-term actions.

That's definitely true, but as I reflected on it, it also occurred to me that the choice of gamma is, in general, a tradeoff. With a small gamma you hamstring your ability to learn long-term dependencies, but you can learn actions' values much faster. With a gamma close to 1 you can, in theory, still learn fine-grained action values without conflating them, but it takes many more training examples, because the returns of different actions only become distinguishable once enough Monte Carlo samples accumulate (my models with higher gammas still seemed to be learning when they terminated). A larger gamma also makes long-term dependencies fairly easy to learn.

I interpreted this as meaning that I should set gamma a little higher than its best-performing value.
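To make the tradeoff concrete: a reward $k$ steps ahead is weighted by $\gamma^k$ in the return. This sketch uses the same backward recursion as the training loops below, $G_t = r_t + \gamma G_{t+1}$ with $G_T = r_T$ at the last step, on a hypothetical five-step episode with small penalties followed by a large landing reward:

```python
import numpy as np

def discounted_returns(rewards, gamma):
    """Backward recursion G_t = r_t + gamma * G_{t+1}, with G_T = r_T at the last step."""
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

# Hypothetical episode: small per-step penalties, then a large terminal reward
rewards = [-1.0, -1.0, -1.0, -1.0, 100.0]

for gamma in (0.05, 0.4, 0.99):
    print(gamma, discounted_returns(rewards, gamma).round(3))
```

With gamma = 0.05 the first state's return is barely affected by the landing reward four steps later (weight $0.05^4$), whereas with gamma = 0.99 every state sees most of it.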

In [None]:
# gamma = .4

gamma = .05

Let's load the Lunar Lander environment now.

In [None]:
# torch.autograd.set_detect_anomaly(mode=True)

In [None]:
# # TODO: use a custom dataloader class and see if speed up

# env = gym.make(
#     "LunarLander-v2",
#     #render_mode="human",
#     enable_wind=False,
# )

# env_2 = gym.make(
#     "LunarLander-v2",
#     render_mode="human",
#     enable_wind=False,
    
# )

# num_of_actions = 4

# action_space_seed = np.random.seed(13)

# observation, info = env.reset(seed=13)

# episodes_total_rewards_sums = []
# # for debug of state-value function
# episode_total_state_err = []

# observations = []
# # NOTE: rewards[0] corresponds to the result of calc_reward(state_of(observations[0]), action_indices[0])
# # thus len(rewards) == len(action_indices) == len(observations) - 1
# # i.e. no reward for the first timestep, no action_index for the last timestep
# rewards = []
# action_indices = []
# action_logits_per_ep = []
# # for debug of state-value function
# state_preds = []
# state_err = []

# policy_adam.zero_grad()
# state_util_adam.zero_grad()

# #warmup, policy frozen
# for timestep in range(1300000):

#     if timestep==0 or timestep==900000:
#         print('debug entry')
    
#     # use policy gradient to get action probabilities; sample stochastically
#     action_logits = policy_model(torch.tensor(observation, device=device, dtype=torch.float32))
#     with torch.no_grad():
#         action_logits_per_ep.append(action_logits.detach().clone())
#         action_probs = torch.nn.functional.softmax(action_logits, dim=0)
#         action_sampling = torch.multinomial(action_probs, 1)
#         action = action_sampling.item()
#         action_indices.append(action)
    
#     observations.append(observation)
#     # get info from environment
#     observation, reward, terminated, truncated, info = env.step(action)
#     rewards.append(reward)
    
#     # end of episode
#     if terminated or truncated:
#         observations.append(observation)
#         ep_length = len(observations[:-1]) # Do not take the terminal state as we have no action in the terminal state
#         ep_total_rewards_sum = np.sum(np.array(rewards))
#         episodes_total_rewards_sums.append(ep_total_rewards_sum)
#         returns = np.zeros(len(observations) - 1)
#         for timestep in reversed(range(ep_length)):

#             # calculate state's actual return by looking at reward + future rewards
#             terminal = timestep == len(rewards) - 1
#             returns[timestep] = rewards[timestep] + (gamma * returns[timestep+1]) if not terminal else rewards[timestep]

#         with torch.no_grad():
#             actual_state_util = torch.zeros((len(returns), 1), device=device)
#             for i, actual_util in enumerate(returns):
#                 actual_state_util[i] = torch.tensor(returns[i], device=device)
#             # calculate baseline expected state value
#             input_state_util = torch.zeros((len(observations)-1, len(observation)), device=device)
#             for i, input_samples in enumerate(observations[:-1]):
#                 input_state_util[i] = torch.tensor(observations[i], device=device)
#         pred_state_util = state_util_model(input_state_util)
#         loss_state_utility = state_util_loss(pred_state_util, actual_state_util)
#         with torch.no_grad():
#             state_pred_err = np.abs(loss_state_utility.detach().clone().mean().item())
#             state_preds.append(pred_state_util.detach().clone())
#             state_err.append(state_pred_err)
        
#         with torch.no_grad():
#             state_util_differences = []
#             for timestep in range(ep_length):
#                 # make updates to policy (specific action) based on return
#                 # get the state's return minus the baseline (predicted state return)
#                 state_util_differences.append(actual_state_util.detach().clone()[timestep] - pred_state_util.detach().clone()[timestep])
        
#             # TODO: change to store rather than recomputation, but without autograd complaining about inplace operations
#             # e.g. putting the tensor in a list
#         # with torch.no_grad():
#         #     input_policy = torch.zeros((len(observations)-1, len(observation)), device=device)
#         #     for i, input_samples in enumerate(observations[:-1]):
#         #         input_policy[i] = torch.tensor(observations[i], device=device)
#         #     actions_chosen_tensor = torch.zeros((len(action_indices), num_of_actions), device=device)
#         #     for i, action_index in enumerate(action_indices):
#         #         actions_chosen_tensor[i][action_index] = 1
#         #     state_util_diffs_tensor = torch.tensor(state_util_differences, device=device)
#         # recomputed_policy = policy_model(input_state_util)
#         # loss_policy = policy_loss(recomputed_policy, actions_chosen_tensor, state_util_diffs_tensor)

#         episode_total_state_err.append(np.sum(np.array(state_err)))

#         # accumulate, avg, and add gradients to parameters for state value network
#         loss_state_utility.mean().backward()
#         state_util_adam.step()
#         state_util_adam.zero_grad()
    
#         # accumulate, avg, and add gradients to parameters for policy network
#         # loss_policy.mean().backward()
#         # policy_adam.step()
#         # policy_adam.zero_grad()

#         observation, info = env.reset()
#         observations, rewards, action_indices, action_logits_per_ep = [], [], [], []
#         state_err, state_preds = [], []






# print(f'The avg. state val prediction error on the first quarter of episodes was: {np.sum(episode_total_state_err[:len(episode_total_state_err)//4]) / (len(episode_total_state_err)/4)}')
# print(f'The avg. state val prediction error on the second quarter of episodes was: {np.sum(episode_total_state_err[len(episode_total_state_err)//4:2 * len(episode_total_state_err)//4]) / (len(episode_total_state_err)/4)}')
# print(f'The avg. state val prediction error on the third quarter of episodes was: {np.sum(episode_total_state_err[2 * len(episode_total_state_err)//4:3 *len(episode_total_state_err)//4]) / (len(episode_total_state_err)/4)}')
# print(f'The avg. state val prediction error on the fourth quarter of episodes was: {np.sum(episode_total_state_err[3 *len(episode_total_state_err)//4:len(episode_total_state_err)]) / (len(episode_total_state_err)/4)}')





# episodes_total_rewards_sums = []
# ep_total_rewards = []
# # for debug of state-value function
# episode_total_state_err = []
# action_logits_episodes = []
# observations, rewards, returns, action_indices, action_logits_per_ep = [], [], [], [], []
# state_err, state_preds = [], []
# observation, info = env.reset()

# # state_util_adam = torch.optim.Adam(params=state_util_model.parameters(), lr=1e-3, weight_decay=)

# # everything unfrozen
# for timestep in range(1000000):

#     if timestep==0 or timestep==900000:
#         print('debug entry')
    
#     # use policy gradient to get action probabilities; sample stochastically
#     action_logits = policy_model(torch.tensor(observation, device=device, dtype=torch.float32))
#     with torch.no_grad():
#         action_logits_per_ep.append(action_logits.detach().clone())
#         action_probs = torch.nn.functional.softmax(action_logits, dim=0)
#         action_sampling = torch.multinomial(action_probs, 1)
#         action = action_sampling.item()
#         action_indices.append(action)
    
#     observations.append(observation)
#     # get info from environment
#     observation, reward, terminated, truncated, info = env.step(action)
#     rewards.append(reward)
    
#     # end of episode
#     if terminated or truncated:
#         observations.append(observation)
#         ep_length = len(observations[:-1]) # Do not take the terminal state as we have no action in the terminal state
#         ep_total_rewards_sum = np.sum(np.array(rewards))
#         ep_total_rewards.append(rewards)
#         episodes_total_rewards_sums.append(ep_total_rewards_sum)
#         returns = np.zeros(len(observations) - 1)
#         for timestep in reversed(range(ep_length)):

#             # calculate state's actual return by looking at reward + future rewards
#             terminal = timestep == len(rewards) - 1
#             returns[timestep] = rewards[timestep] + (gamma * returns[timestep+1]) if not terminal else rewards[timestep]

#         with torch.no_grad():
#             actual_state_util = torch.zeros((len(returns), 1), device=device)
#             for i, actual_util in enumerate(returns):
#                 actual_state_util[i] = torch.tensor(returns[i], device=device)
#             # calculate baseline expected state value
#             input_state_util = torch.zeros((len(observations)-1, len(observation)), device=device)
#             for i, input_samples in enumerate(observations[:-1]):
#                 input_state_util[i] = torch.tensor(observations[i], device=device)
#         pred_state_util = state_util_model(input_state_util)
#         loss_state_utility = state_util_loss(pred_state_util, actual_state_util)
        
#         # some extra info helpful for debug
#         with torch.no_grad():
#             state_pred_err = np.abs(loss_state_utility.detach().clone().mean().item())
#             state_preds.append(pred_state_util.detach().clone())
#             state_err.append(state_pred_err)
#             state_util_differences = []
#             for timestep in range(ep_length):
#                 # make updates to policy (specific action) based on return
#                 # get the state's return minus the baseline (predicted state return)
#                 state_util_differences.append(actual_state_util.detach().clone()[timestep] - pred_state_util.detach().clone()[timestep])
        
#             # TODO: change to store rather than recomputation, but without autograd complaining about inplace operations
#             # e.g. putting the tensor in a list
#         with torch.no_grad():
#             input_policy = torch.zeros((len(observations)-1, len(observation)), device=device)
#             for i, input_samples in enumerate(observations[:-1]):
#                 input_policy[i] = torch.tensor(observations[i], device=device)
#             actions_chosen_tensor = torch.zeros((len(action_indices), num_of_actions), device=device)
#             for i, action_index in enumerate(action_indices):
#                 actions_chosen_tensor[i][action_index] = 1
#             state_util_diffs_tensor = torch.tensor(state_util_differences, device=device)
#         recomputed_policy = policy_model(input_state_util)
#         loss_policy = policy_loss(recomputed_policy, actions_chosen_tensor, state_util_diffs_tensor)

#         episode_total_state_err.append(np.sum(np.array(state_err)))

#         # accumulate, avg, and add gradients to parameters for state value network
#         loss_state_utility.mean().backward()
#         state_util_adam.step()
#         state_util_adam.zero_grad()
    
#         # accumulate, avg, and add gradients to parameters for policy network
#         loss_policy.mean().backward()
#         policy_adam.step()
#         policy_adam.zero_grad()

#         observation, info = env.reset()
#         action_logits_episodes.append(action_logits_per_ep)
#         observations, rewards, action_indices, action_logits_per_ep = [], [], [], []
#         state_err, state_preds = [], []

# # TODO: move these to a self-contained function

# print(f'The avg. state val prediction error on the first quarter of episodes was: {np.sum(episode_total_state_err[:len(episode_total_state_err)//4]) / (len(episode_total_state_err)/4)}')
# print(f'The avg. state val prediction error on the second quarter of episodes was: {np.sum(episode_total_state_err[len(episode_total_state_err)//4:2 * len(episode_total_state_err)//4]) / (len(episode_total_state_err)/4)}')
# print(f'The avg. state val prediction error on the third quarter of episodes was: {np.sum(episode_total_state_err[2 * len(episode_total_state_err)//4:3 *len(episode_total_state_err)//4]) / (len(episode_total_state_err)/4)}')
# print(f'The avg. state val prediction error on the fourth quarter of episodes was: {np.sum(episode_total_state_err[3 *len(episode_total_state_err)//4:len(episode_total_state_err)]) / (len(episode_total_state_err)/4)}')

# print(f'The avg. episode reward on the first quarter of episodes was: {np.sum(episodes_total_rewards_sums[:len(episodes_total_rewards_sums)//4]) / (len(episodes_total_rewards_sums)//4)}')
# print(f'The avg. episode reward on the second quarter of episodes was: {np.sum(episodes_total_rewards_sums[len(episodes_total_rewards_sums)//4:2 * len(episodes_total_rewards_sums)//4]) / (len(episodes_total_rewards_sums)/4)}')
# print(f'The avg. episode reward on the third quarter of episodes was: {np.sum(episodes_total_rewards_sums[2 * len(episodes_total_rewards_sums)//4:3 *len(episodes_total_rewards_sums)//4]) / (len(episodes_total_rewards_sums)/4)}')
# print(f'The avg. episode reward on the fourth quarter of episodes was: {np.sum(episodes_total_rewards_sums[3 *len(episodes_total_rewards_sums)//4:len(episodes_total_rewards_sums)]) / (len(episodes_total_rewards_sums)/4)}')


# observation, info = env_2.reset(seed=45)

# observations, rewards, returns, action_indices, action_logits_per_ep = [], [], [], [], []

# episodes_total_rewards = []

# # just to see what the model has learnt visually
# for timestep in range(1000000):

#     if timestep==0 or timestep==900000:
#         print('debug entry')
    
#     # use policy gradient to get action probabilities; sample stochastically
#     action_logits = policy_model(torch.tensor(observation, device=device, dtype=torch.float32))
#     with torch.no_grad():
#         action_logits_per_ep.append(action_logits.detach().clone())
#         action_probs = torch.nn.functional.softmax(action_logits, dim=0)
#         action_sampling = torch.multinomial(action_probs, 1)
#         action = action_sampling.item()
#         action_indices.append(action)
    
#     observations.append(observation)
#     # get info from environment
#     observation, reward, terminated, truncated, info = env_2.step(action)
#     rewards.append(reward)
    
#     # end of episode
#     if terminated or truncated:
#         observations.append(observation)
#         ep_length = len(observations[:-1]) # Do not take the terminal state as we have no action in the terminal state
#         ep_total_reward = np.sum(np.array(rewards))
#         episodes_total_rewards.append(ep_total_reward)
#         returns = np.zeros(len(observations) - 1)
#         for timestep in reversed(range(ep_length)):

#             # calculate state's actual return by looking at reward + future rewards
#             terminal = timestep == len(rewards) - 1
#             returns[timestep] = rewards[timestep] + (gamma * returns[timestep+1]) if not terminal else rewards[timestep]

#         with torch.no_grad():
#             actual_state_util = torch.zeros((len(returns), 1), device=device)
#             for i, actual_util in enumerate(returns):
#                 actual_state_util[i] = torch.tensor(returns[i], device=device)
#             # calculate baseline expected state value
#             input_state_util = torch.zeros((len(observations)-1, len(observation)), device=device)
#             for i, input_samples in enumerate(observations[:-1]):
#                 input_state_util[i] = torch.tensor(observations[i], device=device)
#         pred_state_util = state_util_model(input_state_util)
#         loss_state_utility = state_util_loss(pred_state_util, actual_state_util)
        
#         # some extra info helpful for debug
#         with torch.no_grad():
#             state_pred_err = np.abs(loss_state_utility.detach().clone().mean().item())
#             state_preds.append(pred_state_util.detach().clone())
#             state_err.append(state_pred_err)
#             state_util_differences = []
#             for timestep in range(ep_length):
#                 # make updates to policy (specific action) based on return
#                 # get the state's return minus the baseline (predicted state return)
#                 state_util_differences.append(actual_state_util.detach().clone()[timestep] - pred_state_util.detach().clone()[timestep])
        
#             # TODO: change to store rather than recomputation, but without autograd complaining about inplace operations
#             # e.g. putting the tensor in a list
#         with torch.no_grad():
#             input_policy = torch.zeros((len(observations)-1, len(observation)), device=device)
#             for i, input_samples in enumerate(observations[:-1]):
#                 input_policy[i] = torch.tensor(observations[i], device=device)
#             actions_chosen_tensor = torch.zeros((len(action_indices), num_of_actions), device=device)
#             for i, action_index in enumerate(action_indices):
#                 actions_chosen_tensor[i][action_index] = 1
#             state_util_diffs_tensor = torch.tensor(state_util_differences, device=device)
#         recomputed_policy = policy_model(input_state_util)
#         loss_policy = policy_loss(recomputed_policy, actions_chosen_tensor, state_util_diffs_tensor)

#         episode_total_state_err.append(np.sum(np.array(state_err)))

#         # accumulate, avg, and add gradients to parameters for state value network
#         loss_state_utility.sum().backward()
#         state_util_adam.step()
#         state_util_adam.zero_grad()
    
#         # accumulate, avg, and add gradients to parameters for policy network
#         loss_policy.sum().backward()
#         policy_adam.step()
#         policy_adam.zero_grad()

#         observation, info = env_2.reset()
#         observations, rewards, action_indices, action_logits_per_ep = [], [], [], []
#         state_err, state_preds = [], []


# env_2.close()
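The TODO above asks to move the quarter-by-quarter summaries into a self-contained function. A minimal sketch of such a helper (the name `print_quarter_averages` is hypothetical) follows. One design difference: the printed summaries divide each quarter's sum by `len/4`, whereas `np.array_split` gives quarters whose lengths differ by at most one, so each quarter is averaged over its actual length.

```python
import numpy as np

def print_quarter_averages(values, label):
    """Hypothetical helper: print and return the mean of `values` over each quarter of the run."""
    quarters = np.array_split(np.asarray(values, dtype=float), 4)
    names = ["first", "second", "third", "fourth"]
    means = []
    for name, chunk in zip(names, quarters):
        # guard against empty quarters when there are fewer than 4 episodes
        mean = chunk.mean() if len(chunk) else float("nan")
        means.append(mean)
        print(f"The avg. {label} on the {name} quarter of episodes was: {mean}")
    return means
```

For example, `print_quarter_averages(episode_total_state_err, "state val prediction error")` would replace four of the print statements.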

The non-batched version is below.

In [None]:
# TODO: use a custom dataloader class and see if speed up

env = gym.make(
    "LunarLander-v2",
    #render_mode="human",
    enable_wind=False,
)

env_2 = gym.make(
    "LunarLander-v2",
    render_mode="human",
    enable_wind=False,
    
)

num_of_actions = 4

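# np.random.seed returns None, so there is nothing to store; actions are sampled with
# torch.multinomial, which draws from torch's RNG, so seed that as well
np.random.seed(13)
torch.manual_seed(13)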

observation, info = env.reset(seed=13)

episodes_total_rewards_sums = []
# for debug of state-value function
episode_total_state_err = []

observations = []
# NOTE: rewards[0] corresponds to the result of calc_reward(state_of(observations[0]), action_indices[0])
# thus len(rewards) == len(action_indices) == len(observations) - 1
# i.e. no reward for the first timestep, no action_index for the last timestep
rewards = []
action_indices = []
action_logits_per_ep = []
# for debug of state-value function
state_preds = []
state_err = []

policy_adam.zero_grad()
state_util_adam.zero_grad()

#warmup, policy frozen
for timestep in range(1300000):

    if timestep==0 or timestep==900000:
        print('debug entry')
    
    # use policy gradient to get action probabilities; sample stochastically
    action_logits = policy_model(torch.tensor(observation, device=device, dtype=torch.float32))
    with torch.no_grad():
        action_logits_per_ep.append(action_logits.detach().clone())
        action_probs = torch.nn.functional.softmax(action_logits, dim=0)
        action_sampling = torch.multinomial(action_probs, 1)
        action = action_sampling.item()
        action_indices.append(action)
    
    observations.append(observation)
    # get info from environment
    observation, reward, terminated, truncated, info = env.step(action)
    rewards.append(reward)
    
    # end of episode
    if terminated or truncated:
        observations.append(observation)
        ep_length = len(observations[:-1]) # Do not take the terminal state as we have no action in the terminal state
        ep_total_rewards_sum = np.sum(np.array(rewards))
        episodes_total_rewards_sums.append(ep_total_rewards_sum)
        returns = np.zeros(len(observations) - 1)
        for t in reversed(range(ep_length)):

            # calculate state's actual return by looking at reward + future rewards
            # (`t` avoids reusing the outer `timestep` loop variable)
            terminal = t == len(rewards) - 1
            returns[t] = rewards[t] + (gamma * returns[t+1]) if not terminal else rewards[t]

        with torch.no_grad():
            actual_state_util = torch.zeros((len(returns), 1), device=device)
            for i, actual_util in enumerate(returns):
                actual_state_util[i] = torch.tensor(returns[i], device=device)
            # calculate baseline expected state value
            input_state_util = torch.zeros((len(observations)-1, len(observation)), device=device)
            for i, input_samples in enumerate(observations[:-1]):
                input_state_util[i] = torch.tensor(observations[i], device=device)
        pred_state_util = state_util_model(input_state_util)
        loss_state_utility = state_util_loss(pred_state_util, actual_state_util)
        with torch.no_grad():
            state_pred_err = np.abs(loss_state_utility.detach().clone().mean().item())
            state_preds.append(pred_state_util.detach().clone())
            state_err.append(state_pred_err)
        
        with torch.no_grad():
            state_util_differences = []
            for t in range(ep_length):
                # make updates to policy (specific action) based on return
                # get the state's return minus the baseline (predicted state return)
                state_util_differences.append(actual_state_util.detach().clone()[t] - pred_state_util.detach().clone()[t])
        
            # TODO: change to store rather than recomputation, but without autograd complaining about inplace operations
            # e.g. putting the tensor in a list
        # with torch.no_grad():
        #     input_policy = torch.zeros((len(observations)-1, len(observation)), device=device)
        #     for i, input_samples in enumerate(observations[:-1]):
        #         input_policy[i] = torch.tensor(observations[i], device=device)
        #     actions_chosen_tensor = torch.zeros((len(action_indices), num_of_actions), device=device)
        #     for i, action_index in enumerate(action_indices):
        #         actions_chosen_tensor[i][action_index] = 1
        #     state_util_diffs_tensor = torch.tensor(state_util_differences, device=device)
        # recomputed_policy = policy_model(input_state_util)
        # loss_policy = policy_loss(recomputed_policy, actions_chosen_tensor, state_util_diffs_tensor)

        episode_total_state_err.append(np.sum(np.array(state_err)))

        # accumulate, avg, and add gradients to parameters for state value network
        loss_state_utility.mean().backward()
        state_util_adam.step()
        state_util_adam.zero_grad()
    
        # accumulate, avg, and add gradients to parameters for policy network
        # loss_policy.mean().backward()
        # policy_adam.step()
        # policy_adam.zero_grad()

        observation, info = env.reset()
        observations, rewards, action_indices, action_logits_per_ep = [], [], [], []
        state_err, state_preds = [], []






print(f'The avg. state val prediction error on the first quarter of episodes was: {np.sum(episode_total_state_err[:len(episode_total_state_err)//4]) / (len(episode_total_state_err)/4)}')
print(f'The avg. state val prediction error on the second quarter of episodes was: {np.sum(episode_total_state_err[len(episode_total_state_err)//4:2 * len(episode_total_state_err)//4]) / (len(episode_total_state_err)/4)}')
print(f'The avg. state val prediction error on the third quarter of episodes was: {np.sum(episode_total_state_err[2 * len(episode_total_state_err)//4:3 *len(episode_total_state_err)//4]) / (len(episode_total_state_err)/4)}')
print(f'The avg. state val prediction error on the fourth quarter of episodes was: {np.sum(episode_total_state_err[3 *len(episode_total_state_err)//4:len(episode_total_state_err)]) / (len(episode_total_state_err)/4)}')





episodes_total_rewards_sums = []
ep_total_rewards = []
# for debug of state-value function
episode_total_state_err = []
action_logits_episodes = []
observations, rewards, returns, action_indices, action_logits_per_ep = [], [], [], [], []
state_err, state_preds = [], []
observation, info = env.reset()

# state_util_adam = torch.optim.Adam(params=state_util_model.parameters(), lr=1e-3, weight_decay=)

# everything unfrozen
for timestep in range(1000000):

    if timestep==0 or timestep==900000:
        print('debug entry')
    
    # use policy gradient to get action probabilities; sample stochastically
    action_logits = policy_model(torch.tensor(observation, device=device, dtype=torch.float32))
    with torch.no_grad():
        action_logits_per_ep.append(action_logits.detach().clone())
        action_probs = torch.nn.functional.softmax(action_logits, dim=0)
        action_sampling = torch.multinomial(action_probs, 1)
        action = action_sampling.item()
        action_indices.append(action)
    
    observations.append(observation)
    # get info from environment
    observation, reward, terminated, truncated, info = env.step(action)
    rewards.append(reward)
    
    # end of episode
    if terminated or truncated:
        observations.append(observation)
        ep_length = len(observations[:-1]) # Do not take the terminal state as we have no action in the terminal state
        ep_total_rewards_sum = np.sum(np.array(rewards))
        ep_total_rewards.append(rewards)
        episodes_total_rewards_sums.append(ep_total_rewards_sum)
        returns = np.zeros(ep_length)
        # use `t` here so the outer `timestep` counter is not shadowed
        for t in reversed(range(ep_length)):
            # calculate state's actual return by looking at reward + discounted future return
            terminal = t == ep_length - 1
            returns[t] = rewards[t] if terminal else rewards[t] + gamma * returns[t + 1]

        # update both networks one timestep at a time
        for i in range(ep_length):
            # actual return G_i as a float32 target, shaped like the value network's (1,) output
            actual_state_util = torch.tensor([returns[i]], device=device, dtype=torch.float32)
            # baseline: the value network's prediction for state s_i
            input_state_util = torch.tensor(observations[i], device=device, dtype=torch.float32)
            pred_state_util = state_util_model(input_state_util)
            loss_state_utility = state_util_loss(pred_state_util, actual_state_util)

            # some extra info helpful for debug, plus the inputs for the policy update (no gradients needed)
            with torch.no_grad():
                state_err.append(np.abs(loss_state_utility.mean().item()))
                # the state's return minus the baseline (predicted state return)
                state_util_diff = actual_state_util - pred_state_util.detach()
                # one-hot vector of the action taken at step i
                actions_chosen_tensor = torch.zeros(4, device=device)
                actions_chosen_tensor[action_indices[i]] = 1

            # TODO: change to store rather than recomputation, but without autograd complaining about inplace operations
            # e.g. putting the tensor in a list
            recomputed_policy = policy_model(input_state_util)
            loss_policy = policy_loss(recomputed_policy, actions_chosen_tensor, state_util_diff)

            loss_policy.mean().backward()
            policy_adam.step()
            policy_adam.zero_grad()

            loss_state_utility.mean().backward()
            state_util_adam.step()
            state_util_adam.zero_grad()

        episode_total_state_err.append(np.sum(np.array(state_err)))

    
        # # accumulate, avg, and add gradients to parameters for policy network
        # loss_policy.mean().backward()
        # policy_adam.step()
        # policy_adam.zero_grad()

        # accumulate, avg, and add gradients to parameters for state value network
        # loss_state_utility.mean().backward()
        # state_util_adam.step()
        # state_util_adam.zero_grad()

        observation, info = env.reset()
        action_logits_episodes.append(action_logits_per_ep)
        observations, rewards, action_indices, action_logits_per_ep = [], [], [], []
        state_err, state_preds = [], []

# TODO: move these to a self-contained function
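# --- Sketch for the TODO above (hypothetical helper name `quarter_means`) ---
# Splits a per-episode metric into four contiguous chunks and averages each one,
# dividing every quarter by its own length. np.array_split handles lengths that are
# not multiples of 4 (the first chunks get one extra element). It assumes at least
# 4 episodes; otherwise some chunks are empty and np.mean returns nan with a warning.
import numpy as np

def quarter_means(values):
    """Return the means of the four contiguous quarters of `values`."""
    chunks = np.array_split(np.asarray(values, dtype=float), 4)
    return [float(np.mean(chunk)) for chunk in chunks]

# e.g. quarter_means([1, 2, 3, 4, 5, 6, 7, 8]) -> [1.5, 3.5, 5.5, 7.5]
# for name, m in zip(['first', 'second', 'third', 'fourth'], quarter_means(episodes_total_rewards_sums)):
#     print(f'The avg. episode reward on the {name} quarter of episodes was: {m}')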

print(f'The avg. state val prediction error on the first quarter of episodes was: {np.mean(episode_total_state_err[:len(episode_total_state_err)//4])}')
print(f'The avg. state val prediction error on the second quarter of episodes was: {np.mean(episode_total_state_err[len(episode_total_state_err)//4:2 * len(episode_total_state_err)//4])}')
print(f'The avg. state val prediction error on the third quarter of episodes was: {np.mean(episode_total_state_err[2 * len(episode_total_state_err)//4:3 * len(episode_total_state_err)//4])}')
print(f'The avg. state val prediction error on the fourth quarter of episodes was: {np.mean(episode_total_state_err[3 * len(episode_total_state_err)//4:])}')

print(f'The avg. episode reward on the first quarter of episodes was: {np.mean(episodes_total_rewards_sums[:len(episodes_total_rewards_sums)//4])}')
print(f'The avg. episode reward on the second quarter of episodes was: {np.mean(episodes_total_rewards_sums[len(episodes_total_rewards_sums)//4:2 * len(episodes_total_rewards_sums)//4])}')
print(f'The avg. episode reward on the third quarter of episodes was: {np.mean(episodes_total_rewards_sums[2 * len(episodes_total_rewards_sums)//4:3 * len(episodes_total_rewards_sums)//4])}')
print(f'The avg. episode reward on the fourth quarter of episodes was: {np.mean(episodes_total_rewards_sums[3 * len(episodes_total_rewards_sums)//4:])}')


observation, info = env_2.reset(seed=45)

observations, rewards, returns, action_indices, action_logits_per_ep = [], [], [], [], []

episodes_total_rewards = []
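
# --- Sketch (hypothetical helper, assumes the gym>=0.26 API already used in this notebook:
# reset() -> (obs, info), step() -> (obs, reward, terminated, truncated, info)) ---
# The loop below keeps updating both networks while it runs. To only *watch* the learnt
# policy, one option (an alternative to the stochastic sampling used in this notebook) is to
# act greedily, taking the argmax over the logits, with no parameter updates at all.
# `max_steps` is an assumed safety cap.
def run_greedy_episode(environment, choose_action, max_steps=1000):
    """Play one episode with `choose_action(obs) -> int` and return the total reward."""
    obs, _ = environment.reset()
    total_reward = 0.0
    for _ in range(max_steps):
        obs, reward, terminated, truncated, _ = environment.step(choose_action(obs))
        total_reward += reward
        if terminated or truncated:
            break
    return total_reward

# e.g., with the policy network defined above (no gradients needed):
# with torch.no_grad():
#     run_greedy_episode(env_2, lambda o: policy_model(
#         torch.tensor(o, device=device, dtype=torch.float32)).argmax().item())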

# run the policy in env_2 to see visually what it has learnt; note that this loop also keeps updating both networks after every episode
for timestep in range(1000000):

    if timestep==0 or timestep==900000:
        print('debug entry')
    
    # use policy gradient to get action probabilities; sample stochastically
    action_logits = policy_model(torch.tensor(observation, device=device, dtype=torch.float32))
    with torch.no_grad():
        action_logits_per_ep.append(action_logits.detach().clone())
        action_probs = torch.nn.functional.softmax(action_logits, dim=0)
        action_sampling = torch.multinomial(action_probs, 1)
        action = action_sampling.item()
        action_indices.append(action)
    
    observations.append(observation)
    # get info from environment
    observation, reward, terminated, truncated, info = env_2.step(action)
    rewards.append(reward)
    
    # end of episode
    if terminated or truncated:
        observations.append(observation)
        ep_length = len(observations[:-1]) # Do not take the terminal state as we have no action in the terminal state
        ep_total_reward = np.sum(np.array(rewards))
        episodes_total_rewards.append(ep_total_reward)
        returns = np.zeros(len(observations) - 1)
        for t in reversed(range(ep_length)):

            # calculate state's actual return by looking at reward + discounted future return
            terminal = t == len(rewards) - 1
            returns[t] = rewards[t] if terminal else rewards[t] + gamma * returns[t + 1]

        with torch.no_grad():
            actual_state_util = torch.zeros((len(returns), 1), device=device)
            for i, actual_util in enumerate(returns):
                actual_state_util[i] = torch.tensor(returns[i], device=device)
            # calculate baseline expected state value
            input_state_util = torch.zeros((len(observations)-1, len(observation)), device=device)
            for i, input_samples in enumerate(observations[:-1]):
                input_state_util[i] = torch.tensor(observations[i], device=device)
        pred_state_util = state_util_model(input_state_util)
        loss_state_utility = state_util_loss(pred_state_util, actual_state_util)
        
        # some extra info helpful for debug
        with torch.no_grad():
            state_pred_err = np.abs(loss_state_utility.detach().clone().mean().item())
            state_preds.append(pred_state_util.detach().clone())
            state_err.append(state_pred_err)
            state_util_differences = []
            for t in range(ep_length):
                # make updates to policy (specific action) based on return
                # get the state's return minus the baseline (predicted state return)
                state_util_differences.append(actual_state_util.detach().clone()[t] - pred_state_util.detach().clone()[t])
        
            # TODO: change to store rather than recomputation, but without autograd complaining about inplace operations
            # e.g. putting the tensor in a list
        with torch.no_grad():
            # one-hot encoding of the action taken at each step of the episode
            actions_chosen_tensor = torch.zeros((len(action_indices), num_of_actions), device=device)
            for i, action_index in enumerate(action_indices):
                actions_chosen_tensor[i][action_index] = 1
            state_util_diffs_tensor = torch.tensor(state_util_differences, device=device)
        # recompute the policy logits for every state of the episode (input_state_util already holds them all)
        recomputed_policy = policy_model(input_state_util)
        loss_policy = policy_loss(recomputed_policy, actions_chosen_tensor, state_util_diffs_tensor)

        episode_total_state_err.append(np.sum(np.array(state_err)))

        # sum the per-step losses, backpropagate, and update the state value network
        loss_state_utility.sum().backward()
        state_util_adam.step()
        state_util_adam.zero_grad()
    
        # sum the per-step losses, backpropagate, and update the policy network
        loss_policy.sum().backward()
        policy_adam.step()
        policy_adam.zero_grad()

        observation, info = env_2.reset()
        observations, rewards, action_indices, action_logits_per_ep = [], [], [], []
        state_err, state_preds = [], []


env_2.close()

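Both loops above pass three things to `policy_loss`, which is defined earlier in the notebook: the recomputed logits, a one-hot vector of the chosen actions, and the advantage (actual return minus the predicted state value). The NumPy function below is a hedged numerical sketch of what such a REINFORCE-with-baseline loss usually computes. It is an assumption, not necessarily the exact `policy_loss` used here, and `reinforce_baseline_loss` is a hypothetical name. It takes the log-probability of the chosen action from a log-softmax, weights it by the advantage, and negates the result, so minimizing the loss raises the probability of actions that did better than expected.

```python
import numpy as np

def reinforce_baseline_loss(logits, actions_one_hot, advantages):
    """Mean of -log pi(a_t | s_t) * (G_t - V(s_t)) over a batch of timesteps.

    logits: (T, A) action scores, actions_one_hot: (T, A), advantages: (T,).
    """
    logits = np.asarray(logits, dtype=float)
    # numerically stable log-softmax over the action dimension
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # log-probability of the action actually taken at each timestep
    chosen_log_probs = (log_probs * np.asarray(actions_one_hot, dtype=float)).sum(axis=1)
    # the advantage is only a weight: no gradient should flow into the baseline here
    return float(-(chosen_log_probs * np.asarray(advantages, dtype=float)).mean())

# uniform logits over 4 actions: log pi = -log 4, so an advantage of 1 gives a loss of log 4
print(reinforce_baseline_loss(np.zeros((1, 4)), [[0, 1, 0, 0]], [1.0]))  # ≈ 1.386
```

With a positive advantage, the loss falls as the chosen action's log-probability rises. A negative advantage pushes the probability the other way. Shapes are worth checking here. If `advantages` had shape `(T, 1)` while the chosen log-probabilities had shape `(T,)`, broadcasting in NumPy and PyTorch alike would silently produce a `(T, T)` product instead of a per-timestep weighting.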
In [None]:
print(torch.__version__)
print(gym.__version__)