# Let's code a vanilla policy gradient (VPG) algorithm

In [8]:
# Imports
import sys
sys.path.append('..')
from game import Game

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim



In [None]:
class vpg:
    """Vanilla policy gradient (REINFORCE-style) trainer.

    The trainer repeatedly (1) rolls out episodes in ``self.env`` with
    ``self.policy``, (2) computes discounted returns-to-go, (3) weights the
    log-probability of each taken action by its return, and (4) takes a
    gradient step on the negated sum.

    Interfaces the code relies on:

    * ``self.env`` must provide ``reset()``, ``is_terminal()``,
      ``make_move(action)``, ``get_reward()`` and ``copy_state()``.
    * ``self.policy`` must be callable on a state tensor, returning something
      indexable by an action whose entries are valid ``torch.log`` inputs, and
      must provide ``sample(state_tensor)``.
    """

    def __init__(self, policy_dim, value_dim, policy=None, env=None,
                 optimizer=None):
        """Set up the trainer.

        Args:
            policy_dim: Size of the placeholder policy list created when no
                ``policy`` is supplied.
            value_dim: Stored for later use; nothing in this class reads it.
            policy: Optional policy object (see the class docstring for the
                required interface).
            env: Optional environment; a fresh ``Game()`` is created when
                omitted.
            optimizer: Optional optimizer exposing ``zero_grad()`` and
                ``step()``; used by ``update_policy`` when present.
        """
        self.policy_dim = policy_dim
        self.value_dim = value_dim
        # NOTE(review): the default is the original placeholder list. It is
        # neither callable nor has ``sample``, so ``rollout`` and
        # ``compute_contributions`` only work once a real policy is passed in.
        self.policy = [0] * policy_dim if policy is None else policy
        self.env = Game() if env is None else env
        self.optimizer = optimizer
        self.num_rollouts = 1
        self.gamma = 0.99

    def fit(self, iters):
        """Run ``iters`` rounds of rollout -> returns -> policy update.

        Returns:
            A list with the loss returned by ``update_policy`` per iteration.
        """
        losses = []
        for _ in range(iters):
            trajectories = self.rollout(self.num_rollouts)
            returns = self.compute_returns(trajectories)
            contributions = self.compute_contributions(trajectories, returns)
            losses.append(self.update_policy(trajectories, contributions))
        return losses

    def update_policy(self, trajectories, contributions) -> torch.Tensor:
        """Take one gradient step on the policy-gradient objective.

        The loss is the negated per-trajectory sum of ``contributions``,
        averaged over the batch. A step is taken only when an optimizer is
        configured and the loss is attached to an autograd graph.

        Args:
            trajectories: Unused; kept for interface compatibility.
            contributions: Tensor of shape (batch_size, trajectory_length).

        Returns:
            The detached scalar loss.
        """
        loss = -contributions.sum(dim=1).mean()
        optimizer = getattr(self, 'optimizer', None)
        if optimizer is not None and loss.requires_grad:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return loss.detach()

    def compute_contributions(self, trajectories, returns) -> torch.Tensor:
        """Compute ``log pi(a_t | s_t) * return_t`` for every step.

        Args:
            trajectories: Dict with 'states', 'actions' and 'rewards' tensors
                indexed as ``[b, t]``. An optional boolean 'mask' marks real
                (non-padded) steps.
            returns: Tensor indexed as ``[b, t]`` holding the return-to-go.

        Returns:
            Tensor of shape (batch_size, trajectory_length); padded steps
            (where the mask is False) are left at zero.
        """
        batch_size, trajectory_length = trajectories['rewards'].shape
        contributions = torch.zeros(batch_size, trajectory_length)
        mask = trajectories.get('mask')

        for b in range(batch_size):
            for t in range(trajectory_length):
                # Skip padding: evaluating the policy on a fake zero state is
                # wasted work and log(0) * 0 would inject NaN into the loss.
                if mask is not None and not mask[b, t]:
                    continue
                a, s = trajectories['actions'][b, t], trajectories['states'][b, t]
                contributions[b, t] = torch.log(self.policy(s)[a]) * returns[b, t]
        return contributions

    def compute_returns(self, trajectories) -> torch.Tensor:
        """Compute discounted returns-to-go for a batch of trajectories.

        ``returns[b, t] = sum_{k >= t} gamma ** (k - t) * rewards[b, k]``.

        Args:
            trajectories: Dict whose 'rewards' entry is a 2-D tensor of shape
                (batch_size, trajectory_length).

        Returns:
            Floating-point tensor with the same shape as the rewards.
        """
        rewards = trajectories['rewards']
        # Integer rewards would make the in-place writes below truncate the
        # discounted values, so promote them to floating point first.
        if not torch.is_floating_point(rewards):
            rewards = rewards.to(torch.get_default_dtype())
        batch_size, trajectory_length = rewards.shape

        returns = torch.zeros_like(rewards)
        running = torch.zeros(
            batch_size, dtype=rewards.dtype, device=rewards.device
        )
        # Walk backwards through time for the whole batch at once; an empty
        # time axis simply skips the loop.
        for t in reversed(range(trajectory_length)):
            running = rewards[:, t] + self.gamma * running
            returns[:, t] = running
        return returns

    def rollout(self, num_rollouts) -> dict:
        """Play ``num_rollouts`` episodes and return them as padded tensors.

        Shorter episodes are right-padded with zero states, action 0 and
        reward 0 up to the longest episode in the batch.

        Returns:
            Dict with:
              'states'  - float32, (num_rollouts, max_length, *state_shape)
              'actions' - long, (num_rollouts, max_length)
              'rewards' - float32, (num_rollouts, max_length)
              'mask'    - bool, (num_rollouts, max_length), True on real steps
            When no episode produced a step, every tensor has shape
            (num_rollouts, 0).
        """
        # First pass: play the episodes.
        raw_trajectories = []
        for _ in range(num_rollouts):
            states, actions, rewards = [], [], []
            state = self.env.reset()

            while not self.env.is_terminal():
                state_tensor = torch.tensor(state, dtype=torch.float32)
                states.append(state_tensor)

                action = self.policy.sample(state_tensor)
                actions.append(action)

                self.env.make_move(action)
                rewards.append(self.env.get_reward())

                state = self.env.copy_state()

            raw_trajectories.append((states, actions, rewards))

        max_length = max((len(s) for s, _, _ in raw_trajectories), default=0)

        # Any recorded state serves as the shape/dtype template for padding.
        template = next((s[0] for s, _, _ in raw_trajectories if s), None)
        if template is None:
            empty = (len(raw_trajectories), 0)
            return {
                'states': torch.zeros(empty, dtype=torch.float32),
                'actions': torch.zeros(empty, dtype=torch.long),
                'rewards': torch.zeros(empty, dtype=torch.float32),
                'mask': torch.zeros(empty, dtype=torch.bool),
            }
        pad_state = torch.zeros_like(template)

        # Second pass: pad every episode to max_length.
        trajectories = {'states': [], 'actions': [], 'rewards': [], 'mask': []}
        for states, actions, rewards in raw_trajectories:
            real_length = len(states)
            pad_length = max_length - real_length

            trajectories['states'].append(
                torch.stack(states + [pad_state] * pad_length)
            )
            trajectories['actions'].append(
                torch.tensor(actions + [0] * pad_length, dtype=torch.long)
            )
            # Always float so discounting downstream is not truncated.
            trajectories['rewards'].append(
                torch.tensor(rewards + [0] * pad_length, dtype=torch.float32)
            )
            trajectories['mask'].append(
                torch.tensor(
                    [True] * real_length + [False] * pad_length,
                    dtype=torch.bool,
                )
            )

        return {key: torch.stack(value) for key, value in trajectories.items()}