In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.categorical import Categorical
from torch.optim import Adam
torch.manual_seed(42)


import gymnasium as gym
import numpy as np



In [3]:
# Module-level environment. It is also the default `env_name` argument of train(),
# bound at the moment the train() cell is executed.
env = gym.make(id='CartPole-v1')

In [None]:
# Shared network builder, used for both the actor (policy) and the critic (value) networks.
def mlp(sizes, activation = nn.Tanh, output_activation = nn.Identity):
    """Build a feed-forward network of ``nn.Linear`` layers, each followed by an activation.

    Example (actor): ``mlp(sizes=[obs_dim] + hidden_sizes + [n_acts])``.

    Args:
        sizes: layer widths, input first and output last. Each consecutive pair
            (sizes[i], sizes[i + 1]) becomes one ``nn.Linear``.
        activation: module class instantiated after every hidden layer.
        output_activation: module class instantiated after the final layer.

    Returns:
        ``nn.Sequential`` alternating ``nn.Linear`` and activation modules.
    """
    n_layers = len(sizes) - 1
    layers = []
    for i in range(n_layers):
        # every layer but the last gets `activation`; the last gets `output_activation`
        act = activation if i < n_layers - 1 else output_activation
        # The activation must be its own list element. Passing it as a third
        # argument to nn.Linear would silently bind it to `bias` and drop it.
        layers += [nn.Linear(sizes[i], sizes[i + 1]), act()]
    return nn.Sequential(*layers)  # * unpacks the list into positional modules


In [None]:
# Standalone policy network (observation -> action logits) for interactive use.
# NOTE: train() builds and returns its own actor; this global is not used by it.
obs_dim = env.observation_space.shape[0]
n_acts = env.action_space.n
actor = mlp(sizes=[obs_dim, 64, n_acts])  # 64 matches train()'s default hidden_sizes

In [None]:
# Standalone value network (observation -> single scalar) for interactive use.
# obs_dim is derived here so this cell does not depend on any other cell defining it.
obs_dim = env.observation_space.shape[0]
critic = mlp([obs_dim, 32, 1], activation=nn.ReLU, output_activation=nn.Identity)

In [4]:
def train(env_name = env, hidden_sizes = [64], lr = 1e-2, epochs = 50, batch_size = 50, render = False):
    """Train a categorical policy with vanilla policy gradient (REINFORCE, R(tau) weights).

    Each epoch collects complete episodes with the current policy until more
    than ``batch_size`` steps have been gathered. Every log-prob is weighted by
    its episode's total return R(tau), and one Adam step is taken on the actor.

    Args:
        env_name: either an environment id string (passed to ``gym.make``) or an
            already-constructed environment instance. Defaults to the
            module-level ``env`` as it was when this cell was executed.
        hidden_sizes: widths of the actor's hidden layers (any sequence of ints).
        lr: Adam learning rate for the actor.
        epochs: number of collect-then-update iterations.
        batch_size: collection stops at the first episode end after more than
            this many steps have been recorded.
        render: if True, call ``env.render()`` during the first episode of each
            epoch (presumably requires the env to have been created with a
            ``render_mode`` -- confirm).

    Returns:
        The trained actor (``nn.Sequential`` mapping observations to action logits).
    """
    # Accept an env id (as the parameter name suggests) or an env instance.
    env = gym.make(env_name) if isinstance(env_name, str) else env_name

    obs_dim = env.observation_space.shape[0]
    n_acts = env.action_space.n

    # policy network: observation -> action logits
    actor = mlp(sizes = [obs_dim] + list(hidden_sizes) + [n_acts])

    # TODO: value network for a future actor-critic baseline. It is currently
    # neither used in the loss nor optimized.
    critic = mlp([obs_dim, 32, 1], activation=nn.ReLU, output_activation=nn.Identity)

    def get_policy(obs):
        """Return the categorical action distribution for observation(s) ``obs``."""
        logits = actor(obs)
        return Categorical(logits = logits)

    def get_action(obs):
        """Sample one action index (Python int) from the current policy."""
        return get_policy(obs).sample().item()

    def compute_loss(obs, act, weights):
        """Surrogate loss whose gradient, for on-policy data, is the policy gradient."""
        logp = get_policy(obs).log_prob(act)
        return -(logp * weights).mean()

    optimizer = Adam(actor.parameters(), lr=lr)

    def train_one_epoch():
        """Collect one batch of episodes, take one policy-gradient step, return stats."""
        batch_obs = []          # observations seen at each step
        batch_acts = []         # actions taken at each step
        batch_weights = []      # R(tau) weight for each step's log-prob
        batch_rets = []         # per-episode returns (logging)
        batch_lens = []         # per-episode lengths (logging)

        obs, info = env.reset()
        ep_rews = []            # rewards accrued in the current episode

        # render only the first episode of each epoch
        finished_rendering_this_epoch = False

        while True:
            if render and not finished_rendering_this_epoch:
                env.render()

            # copy in case the env reuses its observation buffer between steps
            batch_obs.append(obs.copy())

            act = get_action(torch.as_tensor(obs, dtype=torch.float32))
            obs, rew, terminated, truncated, info = env.step(act)

            batch_acts.append(act)
            ep_rews.append(rew)

            if terminated or truncated:
                ep_ret, ep_len = sum(ep_rews), len(ep_rews)
                batch_rets.append(ep_ret)
                batch_lens.append(ep_len)

                # Every step of the episode is weighted by the full return R(tau).
                # This is the place to swap in reward-to-go or other shaping.
                batch_weights += [ep_ret] * ep_len

                obs, info = env.reset()
                ep_rews = []
                finished_rendering_this_epoch = True

                if len(batch_obs) > batch_size:
                    break

        optimizer.zero_grad()
        # Stack into one ndarray first: building a tensor from a list of
        # ndarrays is very slow and triggers a PyTorch UserWarning.
        batch_loss = compute_loss(obs=torch.as_tensor(np.array(batch_obs), dtype=torch.float32),
                                  act=torch.as_tensor(batch_acts, dtype=torch.int64),
                                  weights=torch.as_tensor(batch_weights, dtype=torch.float32)
                                  )
        batch_loss.backward()
        optimizer.step()
        # .item() so the caller does not keep the autograd graph alive
        return batch_loss.item(), batch_rets, batch_lens

    for i in range(epochs):
        batch_loss, batch_rets, batch_lens = train_one_epoch()
        print('epoch: %3d \t loss: %.3f \t return: %.3f \t ep_len: %.3f'%
                (i, batch_loss, np.mean(batch_rets), np.mean(batch_lens)))

    return actor



