# Let's code a vanilla policy gradient algorithm

In [8]:
# Imports
import sys
sys.path.append('..')
from game import Game

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim



In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class vpg_network(nn.Module):
    """Fully connected ReLU network used as the policy.

    The stack is ``Linear(input_size, layer_size) -> ReLU``, followed by
    ``hidden_layers - 1`` further ``Linear(layer_size, layer_size) -> ReLU``
    pairs, and a final ``Linear(layer_size, output_size)`` with no activation.
    """

    def __init__(self, input_size, hidden_layers, layer_size, output_size):
        """Assemble the layer stack and store it as ``self.model``.

        Args:
            input_size: Number of features of each input vector.
            hidden_layers: Number of hidden ``Linear``/``ReLU`` pairs; the
                first pair is always created, so values below 1 still give one.
            layer_size: Width of every hidden layer.
            output_size: Number of output units.
        """
        super().__init__()

        # Input projection and its activation are always present.
        stack = [nn.Linear(input_size, layer_size), nn.ReLU()]

        # Remaining hidden blocks, all of identical width.
        extra_hidden = hidden_layers - 1
        for _ in range(extra_hidden):
            stack.extend((nn.Linear(layer_size, layer_size), nn.ReLU()))

        # Output head is left linear (no activation applied here).
        stack.append(nn.Linear(layer_size, output_size))

        self.model = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the raw outputs."""
        outputs = self.model(x)
        return outputs


class vpg:
    """Vanilla policy-gradient (REINFORCE-style) agent driving a ``Game`` environment.

    The agent owns a ``vpg_network`` policy that maps a state vector to one
    unnormalized score (logit) per action, and the optimizer that applies the
    gradients produced by ``compute_policy_gradient``.
    """

    def __init__(self, input_size, output_size, hidden_layers=1, hidden_size=64, learning_rate=1e-2):
        """Build the policy network, its optimizer and the environment.

        Args:
            input_size: Length of the state vector fed to the policy.
            output_size: Number of discrete actions (one logit each).
            hidden_layers: Forwarded to ``vpg_network`` as its layer count.
            hidden_size: Width of each hidden layer.
            learning_rate: Step size of the Adam optimizer used by
                ``update_policy``.
        """
        self.policy = vpg_network(input_size, hidden_layers, hidden_size, output_size)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=learning_rate)
        self.env = Game()
        self.num_rollouts = 1
        self.gamma = 0.99

    def fit(self, iters):
        """Run ``iters`` rounds of rollout -> returns -> gradient step.

        Args:
            iters (int): Number of training iterations.

        Returns:
            list[float]: The surrogate loss measured at each iteration.
        """
        losses = []
        for _ in range(iters):
            trajectories = self.rollout(self.num_rollouts)
            returns = self.compute_returns(trajectories)
            loss = self.update_policy(trajectories, returns)
            losses.append(loss.item())
        return losses

    def compute_policy_gradient(self, trajectories: dict, returns: torch.Tensor) -> torch.Tensor:
        '''Back-propagate the policy-gradient surrogate loss into the policy.

        The loss is ``-(log pi(a_t | s_t) * G_t)``, averaged over the time
        axis of each rollout and then over rollouts. Padded steps produced by
        ``rollout`` carry a zero return, so they add nothing to the sum, but
        the time average still divides by the padded length.

        Args:
            trajectories: Dictionary containing padded trajectories with keys:
                'states': Tensor of shape (num_rollouts, max_length, state_dim)
                'actions': Tensor of shape (num_rollouts, max_length)
                'rewards': Tensor of shape (num_rollouts, max_length)
            returns: Tensor of shape (num_rollouts, max_length).

        Returns:
            The scalar loss, detached from the graph. The gradients are left
            in the ``.grad`` fields of the policy parameters.
        '''
        states = trajectories['states']
        # gather() needs int64 indices.
        actions = trajectories['actions'].long()

        # log_softmax is the numerically stable form of log(softmax(x)); the
        # naive form yields -inf (and NaN once multiplied by a zero return)
        # when a probability underflows.
        log_probs = F.log_softmax(self.policy(states), dim=-1)
        # Keep only the log-probability of the action actually taken.
        selected_log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

        # Mean over time per rollout, then mean over rollouts.
        total_loss = -(selected_log_probs * returns).mean(dim=1).mean()

        self.policy.zero_grad()
        total_loss.backward()
        return total_loss.detach()

    def update_policy(self, trajectories, contributions) -> torch.Tensor:
        """Take one optimizer step on the policy.

        Args:
            trajectories: Padded trajectory dictionary as returned by ``rollout``.
            contributions: Per-step weights for the log-probabilities (the
                discounted returns), shape (num_rollouts, max_length).

        Returns:
            The detached scalar loss evaluated before the step.
        """
        loss = self.compute_policy_gradient(trajectories, contributions)
        self.optimizer.step()
        return loss

    def compute_returns(self, trajectories) -> torch.Tensor:
        """Compute discounted returns ``G_t = r_t + gamma * G_{t+1}``.

        Args:
            trajectories: Dictionary whose 'rewards' entry is a tensor of
                shape (batch_size, trajectory_length).

        Returns:
            Floating-point tensor with the same shape as the rewards.
        """
        rewards = trajectories['rewards']  # shape: (batch_size, trajectory_length)
        if not torch.is_floating_point(rewards):
            # Integer rewards would otherwise truncate every discounted
            # value (e.g. 0.99 -> 0) when written into an integer tensor.
            rewards = rewards.to(torch.get_default_dtype())
        batch_size, trajectory_length = rewards.shape

        returns = torch.zeros_like(rewards)
        # Return-to-go of the step after the current one; zero past the end.
        running = rewards.new_zeros(batch_size)

        # Work backwards through time for the entire batch in one go.
        for t in reversed(range(trajectory_length)):
            running = rewards[:, t] + self.gamma * running
            returns[:, t] = running

        return returns

    def _sample_action(self, state_tensor) -> int:
        """Draw one action index from the policy's softmax distribution.

        Assumes ``state_tensor`` is a flat (1-D) state vector, so the policy
        yields one logit per action -- TODO confirm against ``Game``.
        """
        # Sampling is not differentiated through; log-probabilities are
        # recomputed later in compute_policy_gradient.
        with torch.no_grad():
            probs = F.softmax(self.policy(state_tensor), dim=-1)
        return int(torch.multinomial(probs, num_samples=1).item())

    def rollout(self, num_rollouts) -> dict:
        """Collect trajectories by rolling out the policy in the environment.

        Each episode is assumed to contain at least one step; an episode that
        is already terminal right after ``reset`` cannot be padded.

        Args:
            num_rollouts (int): Number of trajectories to collect

        Returns:
            dict: Dictionary containing padded trajectories with keys:
                'states': float32 Tensor of shape (num_rollouts, max_length, state_dim)
                'actions': int64 Tensor of shape (num_rollouts, max_length)
                'rewards': float32 Tensor of shape (num_rollouts, max_length)
        """
        # First pass: play the episodes and remember the raw step lists.
        raw_trajectories = []
        for _ in range(num_rollouts):
            states, actions, rewards = [], [], []
            state = self.env.reset()

            while not self.env.is_terminal():
                # Convert state tuple to tensor
                state_tensor = torch.tensor(state, dtype=torch.float32)
                states.append(state_tensor)

                # Sample action and store
                action = self._sample_action(state_tensor)
                actions.append(action)

                # Take action in environment
                self.env.make_move(action)
                rewards.append(self.env.get_reward())

                # Update state
                state = self.env.copy_state()

            raw_trajectories.append((states, actions, rewards))

        max_length = max((len(states) for states, _, _ in raw_trajectories), default=0)

        # Second pass: right-pad every episode with zeros up to max_length.
        padded_states, padded_actions, padded_rewards = [], [], []
        for states, actions, rewards in raw_trajectories:
            pad_length = max_length - len(states)

            padded_states.append(
                torch.stack(states + [torch.zeros_like(states[0])] * pad_length)
            )
            padded_actions.append(
                torch.tensor(actions + [0] * pad_length, dtype=torch.long)
            )
            # Rewards are stored as floats so discounting keeps fractions.
            padded_rewards.append(
                torch.tensor(rewards + [0] * pad_length, dtype=torch.float32)
            )

        return {
            'states': torch.stack(padded_states),
            'actions': torch.stack(padded_actions),
            'rewards': torch.stack(padded_rewards),
        }