# **Trust Region Policy Optimization**

Policy gradient methods are based on gradient ascent, which updates the parameters along the first-order derivative (the gradient) of the objective. However, where the objective surface has high curvature, a step of fixed size can change the policy by a relatively large amount, which makes training unstable.

TRPO limits how much the policy can change within one iteration to stabilize the update process, while ensuring that each update moves toward a better policy (or at least not a worse one).

## **Minorize-Maximization Algorithm**

The Minorize-Maximization (MM) algorithm gives the theoretical guarantee that the updates never decrease the expected reward. In one line: it iteratively maximizes a simpler surrogate function that is a lower bound of the actual reward function and approximates it locally around the current policy.

<img src="img/mm_algorithm.png" width="800">  

The expected discounted reward (return) of a policy can be expressed as,

$$\eta(\pi)=E_{\tau\sim\pi}[R(\tau)]=E_{\tau\sim\pi}[\sum_{t=0}^{\infty}\gamma^{t}r(s_t)]$$
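For one finite sampled trajectory, the quantity inside the expectation can be computed directly, and averaging it over trajectories sampled from $\pi$ gives a Monte Carlo estimate of $\eta(\pi)$. A minimal sketch; the function name `discounted_return` and the reward sequences are made up for illustration, and the infinite sum is truncated at the trajectory length:

```python
def discounted_return(rewards, gamma):
    """Return sum_t gamma**t * rewards[t] for one finite trajectory, with t starting at 0."""
    total = 0.0
    # Accumulate backwards: G_t = r_t + gamma * G_{t+1}
    for r in reversed(rewards):
        total = r + gamma * total
    return total


# Monte Carlo estimate of eta(pi): average the returns of trajectories sampled from pi
trajectories = [[1.0, 1.0, 1.0], [0.0, 2.0]]   # made-up reward sequences
eta_estimate = sum(discounted_return(tr, 0.9) for tr in trajectories) / len(trajectories)
print(eta_estimate)  # (2.71 + 1.8) / 2 = 2.255, up to floating-point rounding
```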

As mentioned in [reinforcement_learning_concept](https://github.com/kueiwen/reinforcement-learning/blob/main/reinforcement_learning_concept.ipynb), the Bellman equations involve two functions: the value function $V_{\pi}$ and the action value function $Q_{\pi}$. Here we introduce another one, the advantage function $A_{\pi}$:

$$A_{\pi}(s,a)=Q_{\pi}(s,a)-V_{\pi}(s)$$

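The advantage function measures how much better taking action $a$ in state $s$ is than the average action chosen by $\pi$ in that state. Its expectation under $\pi$ is zero, since $V_{\pi}(s)=\sum_{a}\pi(a|s)Q_{\pi}(s,a)$.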
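For a small tabular problem, the advantage can be computed directly from a $Q$ table and a policy. A minimal sketch; the numbers in `Q` and `pi` are made up for illustration:

```python
import numpy as np

def advantages(Q, pi):
    """A(s, a) = Q(s, a) - V(s), with V(s) = sum_a pi(a|s) Q(s, a).

    Q, pi: arrays of shape (num_states, num_actions); each row of pi sums to 1.
    """
    V = (pi * Q).sum(axis=1, keepdims=True)
    return Q - V

Q = np.array([[1.0, 3.0], [2.0, 2.0]])
pi = np.array([[0.5, 0.5], [0.9, 0.1]])
A = advantages(Q, pi)
print(A)                      # A = [[-1, 1], [0, 0]]
# The policy-weighted advantage is zero in every state
print((pi * A).sum(axis=1))   # [0, 0]
```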

We can then use the advantage function to express how much better another policy $\tilde{\pi}$ is than the current policy $\pi$:

$$\eta(\tilde{\pi})=\eta(\pi)+E_{\tau\sim\tilde{\pi}}[\sum_{t=0}^{\infty}\gamma^{t}A_{\pi}(s_t,a_t)]=\eta(\pi)+\sum_{s}\rho_{\tilde{\pi}}(s)\sum_{a}\tilde{\pi}(a|s)A_{\pi}(s,a)$$


where $\rho_{\pi}(s)=P(s_0=s)+\gamma P(s_1=s)+\gamma^2P(s_2=s)+\dots$ is the (unnormalized) discounted visitation frequency of state $s$ when following policy $\pi$.
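The identity can be checked numerically on a small tabular MDP, where $\eta$, $V$, $Q$ and $\rho$ have closed forms: $V_{\pi}=(I-\gamma P_{\pi})^{-1}r$ and $\rho_{\pi}^T=\rho_0^T(I-\gamma P_{\pi})^{-1}$. A minimal sketch, assuming a randomly generated MDP with state-only rewards $r(s)$ and a uniform start distribution (all values hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
S, A, gamma = 4, 3, 0.9
P = rng.random((S, A, S)); P /= P.sum(axis=2, keepdims=True)   # hypothetical transitions P[s, a, s']
r = rng.random(S)                                               # state rewards r(s)
rho0 = np.full(S, 1.0 / S)                                      # start-state distribution

def random_policy():
    pi = rng.random((S, A))
    return pi / pi.sum(axis=1, keepdims=True)

def evaluate(pi):
    """Exact V_pi, Q_pi, eta(pi) and discounted visitation rho_pi for the tabular MDP."""
    P_pi = np.einsum("sa,sat->st", pi, P)          # P_pi[s, s'] = sum_a pi(a|s) P(s'|s, a)
    M = np.linalg.inv(np.eye(S) - gamma * P_pi)    # sum_t gamma^t P_pi^t
    V = M @ r                                       # solves V = r + gamma * P_pi V
    Q = r[:, None] + gamma * P @ V                  # Q(s, a) = r(s) + gamma * E[V(s')]
    rho = rho0 @ M                                  # rho(s) = sum_t gamma^t P(s_t = s)
    return V, Q, rho0 @ V, rho

pi, pi_tilde = random_policy(), random_policy()
V, Q, eta, _ = evaluate(pi)
_, _, eta_tilde, rho_tilde = evaluate(pi_tilde)
A_pi = Q - V[:, None]                               # advantages of the *current* policy pi
rhs = eta + rho_tilde @ (pi_tilde * A_pi).sum(axis=1)
print(eta_tilde, rhs)                               # equal up to floating-point error
```

Note that the advantage is taken under the current policy $\pi$, while the state visitation comes from the new policy $\tilde{\pi}$.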



This equation implies that any policy update $\pi\to\tilde{\pi}$ with a non-negative expected advantage at *every* state $s$, that is, $\sum_{a}\tilde{\pi}(a|s)A_{\pi}(s,a)\geq0$, is guaranteed to increase the policy performance $\eta$ or leave it unchanged.

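With exact values, the deterministic policy $\tilde{\pi}(s)=\arg\max_{a}A_{\pi}(s,a)$ (exact policy iteration) therefore improves the policy whenever some visited state has an action with positive advantage. In practice, however, the advantages are estimated and approximated, so some states will inevitably have a negative expected advantage. Moreover, the right-hand side is hard to optimize directly because $\rho_{\tilde{\pi}}$ depends on $\tilde{\pi}$ in a complex way.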

Therefore, we introduce a local approximation of $\eta$:

$$L_{\pi}(\tilde{\pi})=\eta(\pi)+\sum_{s}\rho_{\pi}(s)\sum_{a}\tilde{\pi}(a|s)A_{\pi}(s,a)$$

$L_{\pi}$ uses the state visitation distribution $\rho_{\pi}$ of the current policy instead of $\rho_{\tilde{\pi}}$, ignoring the change in state visitation caused by the policy update.

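If the policy is parameterized by a differentiable function $\pi_{\theta}$, then for any parameter value $\theta_{0}$, $L$ matches $\eta$ to first order:

$$L_{\pi_{\theta_{0}}}(\pi_{\theta_{0}})=\eta(\pi_{\theta_{0}})$$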

$$\nabla_{\theta}L_{\pi_{\theta_{0}}}(\pi_{\theta})|_{\theta=\theta_{0}}=\nabla_{\theta}\eta(\pi_{\theta})|_{\theta=\theta_{0}}$$

This implies that a sufficiently small step $\pi_{\theta_0}\to\tilde{\pi}$ that improves $L_{\pi_{\theta_{0}}}$ also improves $\eta$, but it does not tell us how big a step we can take.
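The first-order match can be checked with finite differences on a small tabular MDP. A minimal sketch, assuming a randomly generated MDP and a softmax policy parameterized by a table of logits `theta[s, a]` (both hypothetical choices for illustration):

```python
import numpy as np

rng = np.random.default_rng(1)
S, A, gamma = 4, 3, 0.9
P = rng.random((S, A, S)); P /= P.sum(axis=2, keepdims=True)   # hypothetical random MDP
r = rng.random(S)                                               # state rewards r(s)
rho0 = np.full(S, 1.0 / S)                                      # start-state distribution

def softmax_policy(theta):
    """pi_theta(a|s) from a table of logits theta[s, a]."""
    z = np.exp(theta - theta.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)

def evaluate(pi):
    """Exact eta(pi), discounted visitation rho_pi and advantages A_pi."""
    P_pi = np.einsum("sa,sat->st", pi, P)
    M = np.linalg.inv(np.eye(S) - gamma * P_pi)
    V = M @ r
    Q = r[:, None] + gamma * P @ V
    return rho0 @ V, rho0 @ M, Q - V[:, None]

def eta(theta):
    return evaluate(softmax_policy(theta))[0]

theta0 = rng.normal(size=(S, A))
eta0, rho_0, A_0 = evaluate(softmax_policy(theta0))

def L(theta):
    """Local approximation L_{pi_theta0}(pi_theta): state visitation frozen at rho_{pi_theta0}."""
    return eta0 + rho_0 @ (softmax_policy(theta) * A_0).sum(axis=1)

d = rng.normal(size=(S, A))          # random direction in parameter space
h = 1e-5
grad_eta = (eta(theta0 + h * d) - eta(theta0 - h * d)) / (2 * h)
grad_L = (L(theta0 + h * d) - L(theta0 - h * d)) / (2 * h)
print(L(theta0) - eta0)              # ~0: same value at theta0
print(grad_eta, grad_L)              # same directional derivative, up to finite-difference error
```

Away from $\theta_0$ the two functions drift apart, which is exactly why a step-size rule is needed.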

To overcome this issue, a policy update scheme called conservative policy iteration was introduced, which provides explicit lower bounds on the improvement of $\eta$:

$$\pi_{\text{new}}(a|s)=(1-\alpha)\pi_{\text{old}}(a|s)+\alpha\pi^{\prime}(a|s)$$

where $\pi_{\text{old}}$ is the current policy and $\pi^{\prime}=\arg\max_{\pi^{\prime}}L_{\pi_{\text{old}}}(\pi^{\prime})$.

For this mixture policy, the improvement is bounded from below by

$$\eta(\pi_{\text{new}})\geq L_{\pi_{\text{old}}}(\pi_{\text{new}})-\frac{2\epsilon\gamma}{(1-\gamma)^2}\alpha^2$$

$$\text{where }\epsilon=\max_{s}|E_{a\sim\pi^{\prime}(a|s)}[A_{\pi}(s,a)]|$$
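The mixture update and its lower bound can be evaluated exactly on a small tabular MDP. A minimal sketch, assuming a randomly generated MDP (hypothetical values); because $\rho_{\pi_{\text{old}}}\geq0$ is held fixed, $\pi^{\prime}$ is simply the policy that is greedy with respect to $A_{\pi_{\text{old}}}$ in every state:

```python
import numpy as np

rng = np.random.default_rng(2)
S, A, gamma = 4, 3, 0.9
P = rng.random((S, A, S)); P /= P.sum(axis=2, keepdims=True)   # hypothetical random MDP
r = rng.random(S)
rho0 = np.full(S, 1.0 / S)

def evaluate(pi):
    """Exact eta(pi), discounted visitation rho_pi and advantages A_pi."""
    P_pi = np.einsum("sa,sat->st", pi, P)
    M = np.linalg.inv(np.eye(S) - gamma * P_pi)
    V = M @ r
    Q = r[:, None] + gamma * P @ V
    return rho0 @ V, rho0 @ M, Q - V[:, None]

pi_old = rng.random((S, A)); pi_old /= pi_old.sum(axis=1, keepdims=True)
eta_old, rho_old, A_old = evaluate(pi_old)

# pi' = argmax L_{pi_old}(pi'): greedy with respect to A_old in every state
pi_prime = np.eye(A)[A_old.argmax(axis=1)]
eps = np.abs((pi_prime * A_old).sum(axis=1)).max()

results = []
for alpha in [0.01, 0.1, 0.5, 1.0]:
    pi_new = (1 - alpha) * pi_old + alpha * pi_prime          # conservative policy iteration mixture
    eta_new = evaluate(pi_new)[0]
    L_new = eta_old + rho_old @ (pi_new * A_old).sum(axis=1)
    bound = L_new - 2 * eps * gamma / (1 - gamma) ** 2 * alpha ** 2
    results.append((alpha, eta_new, bound))
    print(f"alpha={alpha}: eta(pi_new)={eta_new:.4f} >= bound={bound:.4f}")
```

The penalty grows with $\alpha^2$, so the bound is only informative for small mixture coefficients.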

We can make this bound applicable to general (non-mixture) policies by replacing $\alpha$ with a distance between $\pi$ and $\tilde{\pi}$. We use the total variation divergence $D_{\text{TV}}(p||q)=\frac{1}{2}\sum_{i}|p_i-q_i|$ for discrete probability distributions $p$ and $q$, and take its maximum over states as the distance between $\pi$ and $\tilde{\pi}$:

$$D_{\text{TV}}^{\max}(\pi,\tilde{\pi})=\max_{s}D_{\text{TV}}(\pi(\cdot|s)||\tilde{\pi}(\cdot|s))$$

Then, with $\alpha=D_{\text{TV}}^{\max}(\pi_{\text{old}},\pi_{\text{new}})$, the following bound holds:


$$\eta(\pi_{\text{new}})\geq L_{\pi_{\text{old}}}(\pi_{\text{new}})-\frac{4\epsilon\gamma}{(1-\gamma)^2}\alpha^2$$

$$\text{where }\epsilon=\max_{s,a}|A_{\pi}(s,a)|$$

Because $D_{\text{TV}}(p||q)^2\leq D_{\text{KL}}(p||q)$, the bound can be rewritten in terms of the KL divergence:

$$\eta(\pi_{\text{new}})\geq L_{\pi_{\text{old}}}(\pi_{\text{new}})-CD_{\text{KL}}^{\max}(\pi_{\text{old}},\pi_{\text{new}})$$

$$\text{where }C=\frac{4\epsilon\gamma}{(1-\gamma)^2}$$

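Here $D_{\text{KL}}$ is the Kullback-Leibler (KL) divergence, which measures how much one distribution differs from another (it is not symmetric in its arguments), and $D_{\text{KL}}^{\max}$ is its maximum over states, defined analogously to $D_{\text{TV}}^{\max}$:

$$D_{\text{KL}}(p||q)=\sum_{i}p_i\log\frac{p_i}{q_i}$$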
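Both divergences and their per-state maxima are easy to compute for tabular policies. A minimal sketch with two hypothetical policies (rows are states, columns are actions), which also checks $D_{\text{TV}}^2\leq D_{\text{KL}}$ state by state:

```python
import numpy as np

def tv_divergence(p, q):
    """D_TV(p||q) = 0.5 * sum_i |p_i - q_i|, computed row-wise (one row per state)."""
    return 0.5 * np.abs(p - q).sum(axis=-1)

def kl_divergence(p, q):
    """D_KL(p||q) = sum_i p_i log(p_i / q_i), row-wise; assumes all entries are strictly positive."""
    return np.sum(p * np.log(p / q), axis=-1)

# Two hypothetical policies over 3 actions in 2 states
pi_old = np.array([[0.2, 0.5, 0.3], [0.6, 0.2, 0.2]])
pi_new = np.array([[0.3, 0.4, 0.3], [0.4, 0.3, 0.3]])

tv = tv_divergence(pi_old, pi_new)
kl = kl_divergence(pi_old, pi_new)
print("D_TV^max =", tv.max(), " D_KL^max =", kl.max())
print("TV^2 <= KL in every state:", np.all(tv ** 2 <= kl))
print("KL is not symmetric:", kl, kl_divergence(pi_new, pi_old))
```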

Using this bound, we can generate a sequence of policies with monotonically non-decreasing performance, $\eta(\pi_0)\leq\eta(\pi_1)\leq\eta(\pi_2)\leq\dots$

Following the Minorize-Maximization (MM) algorithm, define the surrogate function $M_i(\pi)=L_{\pi_{i}}(\pi)-CD_{\text{KL}}^{\max}(\pi_i,\pi)$. Then,


$$\eta(\pi_{i+1})\geq M_i(\pi_{i+1})$$

$$\eta(\pi_{i})=M_i(\pi_{i})$$

$$\text{therefore  }\eta(\pi_{i+1})-\eta(\pi_{i})\geq M_i(\pi_{i+1})-M_i(\pi_{i})$$

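Since $M_i(\pi_i)=\eta(\pi_i)$, any $\pi_{i+1}$ with $M_i(\pi_{i+1})\geq M_i(\pi_i)$, in particular the maximizer of $M_i$, satisfies $\eta(\pi_{i+1})\geq\eta(\pi_i)$. So by maximizing $M_i$ at each iteration, we guarantee that $\eta$ is non-decreasing.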

***policy iteration algorithm guaranteeing non-decreasing expected return $\eta$***

---
Initialize $\pi_{0}$

**repeat**  
    $\quad$ Compute all advantage values $A_{\pi_{i}}(s,a)$  
    $\quad$ Solve the (penalized) optimization problem  
    $\quad\quad \pi_{i+1}=\arg\max_{\pi}[L_{\pi_{i}}(\pi)-CD_{\text{KL}}^{\max}(\pi_i,\pi)]$  
            $\quad\quad$ where $C=4\epsilon\gamma/(1-\gamma)^2$  
            $\quad\quad\quad$ and $L_{\pi_{i}}(\pi)=\eta(\pi_{i})+\sum_{s}\rho_{\pi_{i}}(s)\sum_a \pi(a|s)A_{\pi_{i}}(s,a)$  
**until convergence** 

---
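The algorithm can be run exactly on a small tabular MDP if the $\arg\max$ is simplified. A minimal sketch under two stated assumptions: the MDP is randomly generated (hypothetical), and instead of maximizing $M_i$ over all policies we only search mixtures $(1-a)\pi_i+a\,\pi_{\text{greedy}}$ on a grid of $a$. Since $a=0$ is always a candidate with $M_i(\pi_i)=\eta(\pi_i)$, the chosen policy still satisfies $M_i(\pi_{i+1})\geq M_i(\pi_i)$, so $\eta$ cannot decrease:

```python
import numpy as np

rng = np.random.default_rng(3)
S, A, gamma = 4, 3, 0.9
P = rng.random((S, A, S)); P /= P.sum(axis=2, keepdims=True)   # hypothetical random MDP
r = rng.random(S)
rho0 = np.full(S, 1.0 / S)

def evaluate(pi):
    """Exact eta(pi), discounted visitation rho_pi and advantages A_pi."""
    P_pi = np.einsum("sa,sat->st", pi, P)
    M = np.linalg.inv(np.eye(S) - gamma * P_pi)
    V = M @ r
    Q = r[:, None] + gamma * P @ V
    return rho0 @ V, rho0 @ M, Q - V[:, None]

def kl_max(p, q):
    """D_KL^max(p, q) = max_s KL(p(.|s) || q(.|s)); assumes strictly positive entries."""
    return np.max(np.sum(p * np.log(p / q), axis=1))

pi = np.full((S, A), 1.0 / A)            # pi_0: uniform policy
etas = []
for i in range(20):
    eta_i, rho_i, A_i = evaluate(pi)
    etas.append(eta_i)
    C = 4 * np.abs(A_i).max() * gamma / (1 - gamma) ** 2
    greedy = np.eye(A)[A_i.argmax(axis=1)]
    best_pi, best_M = pi, eta_i          # a = 0 gives M_i(pi_i) = eta(pi_i)
    for a in np.linspace(0.0, 0.99, 100):
        cand = (1 - a) * pi + a * greedy
        M_val = eta_i + rho_i @ (cand * A_i).sum(axis=1) - C * kl_max(pi, cand)
        if M_val > best_M:
            best_pi, best_M = cand, M_val
    pi = best_pi

print(etas[0], etas[-1])   # eta never decreases from one iteration to the next
```

With the theoretical $C$ the accepted steps are very small, which is the practical problem the next section addresses.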


## **Optimizing the Policy Parameters**

From the derivation above, TRPO bounds the policy change through the KL divergence so that each update is a small step, which makes the update relatively robust.

To simplify notation, we use the parameter $\theta$ to represent the policy $\pi_{\theta}$:

$$\max_{\theta}[L_{\theta_{\text{old}}}(\theta)-CD_{\text{KL}}^{\max}(\theta_{\text{old}},\theta)]$$

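However, if we use the penalty coefficient $C$ recommended by the theory, the step sizes will be very small. Instead, we use a constraint on the KL divergence between the new policy and the old policy (a trust region), where the maximum KL over states is replaced by the average KL over the states visited by the old policy, which is easier to estimate from samples: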


$$\max_{\theta}L_{\theta_{\text{old}}}(\theta)$$

$$\text{subject to  }\overline{D}_{\text{KL}}^{\rho_{\theta_{\text{old}}}}(\theta_{\text{old}},\theta)\leq\delta$$

where $\overline{D}_{\text{KL}}^{\rho}(\theta_{1},\theta_{2}):=E_{s\sim\rho}D_{\text{KL}}(\pi_{\theta_{1}}(.|s),\pi_{\theta_{2}}(.|s))$

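Maximizing $L_{\theta_{\text{old}}}(\theta)$ is equivalent to maximizing $\sum_{s}\rho_{\theta_{\text{old}}}(s)\sum_{a}\pi_{\theta}(a|s)A_{\theta_{\text{old}}}(s,a)$, since $\eta(\pi_{\theta_{\text{old}}})$ does not depend on $\theta$. We then replace the sum over states with an expectation over $s\sim\rho_{\theta_{\text{old}}}$ (which only rescales the objective), replace the advantage $A_{\theta_{\text{old}}}$ with $Q_{\theta_{\text{old}}}$ (which only shifts it by a constant, because $\sum_a\pi_{\theta}(a|s)V_{\theta_{\text{old}}}(s)=V_{\theta_{\text{old}}}(s)$ does not depend on $\theta$), and estimate the sum over actions by importance sampling with a sampling distribution $q$: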

$$\max_{\theta}E_{s\sim\rho_{\theta_{\text{old}}},a\sim q}[\frac{\pi_{\theta}(a|s)}{q(a|s)}Q_{\theta_{\text{old}}}(s,a)]$$

$$\text{subject to  }E_{s\sim\rho_{\theta_{\text{old}}}}[D_{\text{KL}}(\pi_{\theta_{\text{old}}}(.|s)||\pi_{\theta}(.|s))]\leq\delta$$
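The importance-sampling form can be sanity-checked at a single state: sampling actions from $q$ and weighting by $\pi_{\theta}/q$ recovers $\sum_a\pi_{\theta}(a|s)Q(s,a)$ in expectation. A minimal sketch with hypothetical probabilities and $Q$ values for one state with three actions:

```python
import numpy as np

rng = np.random.default_rng(0)
q = np.array([0.5, 0.3, 0.2])          # sampling distribution q(a|s)
pi_theta = np.array([0.2, 0.3, 0.5])   # candidate policy pi_theta(a|s)
Q = np.array([1.0, 2.0, 3.0])          # Q_{theta_old}(s, a)

exact = np.dot(pi_theta, Q)            # sum_a pi_theta(a|s) Q(s, a) = 2.3
actions = rng.choice(3, size=100_000, p=q)
estimate = np.mean(pi_theta[actions] / q[actions] * Q[actions])
print(exact, estimate)                 # the estimate is close to 2.3
```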

There are two sampling schemes to estimate Q value: **single path** and **vine**.
<img src="img/trpo_path.png" width="600">  

##### **Single Path**

Single path is the sampling scheme typically used for policy gradient estimation, and it is based on sampling individual trajectories.

1. Collect a sequence of states by sampling $s_0\sim\rho_0$.

2. Simulate the policy $\pi_{\theta_{\text{old}}}$ for some number of timesteps to generate a trajectory $s_0,a_0,s_1,a_1,...,s_{T-1},a_{T-1},s_T$. Here the sampling distribution is $q=\pi_{\theta_{\text{old}}}$.

3. Compute $Q_{\theta_{\text{old}}}$ at each state-action pair $(s_t,a_t)$ by taking discounted sum of future rewards along the trajectory.
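Step 3 is the discounted reward-to-go at every timestep of the trajectory. A minimal sketch; the function name `rewards_to_go` is ours, and the sum is truncated at the end of the trajectory:

```python
import numpy as np

def rewards_to_go(rewards, gamma):
    """Q_hat(s_t, a_t) = sum_{l>=0} gamma^l * r_{t+l}, truncated at the end of the trajectory."""
    q_hat = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        q_hat[t] = running
    return q_hat

print(rewards_to_go([1.0, 0.0, 2.0], gamma=0.5))  # -> 1.5, 1.0, 2.0
```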

##### **Vine**

Vine involves constructing a rollout set and then performing multiple actions from each state in the rollout set.

1. Collect a sequence of states by sampling $s_0\sim\rho_0$.

2. Simulate the policy $\pi_{\theta_{i}}$ to generate a number of trajectories.

3. Choose a subset of $N$ states along the trajectories, called the rollout set ($s_1,s_2,...,s_{N}$).

4. For each state $s_n$ in the rollout set, sample $K$ actions according to $a_{n,k}\sim q(\cdot|s_n)$.

5. For each action $a_{n,k}$ sampled at each state $s_n$, estimate $\hat{Q}_{\theta_{i}}(s_n,a_{n,k})$ by performing a rollout starting with state $s_n$ and action $a_{n,k}$.
    * By using the same random number sequence for the noise in each of the $K$ rollouts, i.e., *common random numbers*, the variance of the Q-value differences between rollouts can be largely reduced.
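Steps 4 and 5 can be sketched for a single state of the rollout set. Everything environment-specific here is hypothetical: `toy_step` is a made-up 1-D system that can be started from any state (the reset requirement of vine), `rollout_policy` stands in for $\pi_{\theta_i}$, and the $K$ actions are drawn from a standard normal as $q$. Each rollout reuses the same noise sequence, which is the common-random-numbers trick:

```python
import numpy as np

GAMMA, HORIZON, NOISE_STD = 0.95, 50, 0.1

def toy_step(s, a, noise):
    """Hypothetical 1-D dynamics: s' = s + a + noise, reward r = -s'^2."""
    s_next = s + a + noise
    return s_next, -s_next ** 2

def rollout_policy(s):
    """Hypothetical fixed policy followed after the first action."""
    return -0.5 * s

def vine_q_estimates(s_n, actions, seed):
    """One rollout per sampled action a_{n,k}, all sharing the same noise sequence."""
    # Common random numbers: every rollout from s_n sees exactly the same noise
    noise = np.random.default_rng(seed).normal(scale=NOISE_STD, size=HORIZON)
    q_hat = np.empty(len(actions))
    for k, a in enumerate(actions):
        s, ret = s_n, 0.0
        for t in range(HORIZON):
            act = a if t == 0 else rollout_policy(s)   # take a_{n,k} first, then follow the policy
            s, reward = toy_step(s, act, noise[t])
            ret += GAMMA ** t * reward
        q_hat[k] = ret
    return q_hat

rng = np.random.default_rng(0)
s_n = 1.0                                # one state from the rollout set
sampled_actions = rng.normal(size=4)     # K = 4 actions, here drawn from q = N(0, 1)
print(vine_q_estimates(s_n, sampled_actions, seed=123))
```

Because the noise is shared, differences between the $K$ estimates are caused only by the first action, which is what lowers the variance of Q-value differences.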

In finite action spaces, we can generate a rollout for every possible action from a given state, so $K$ equals the number of actions. The contribution to $L_{\theta_{\text{old}}}$ from a single state $s_n$ is then

$$L_n(\theta)=\sum_{k=1}^{K}\pi_{\theta}(a_{n,k}|s_n)\hat{Q}(s_n,a_{n,k})$$

In large or continuous action spaces, we can construct an estimator of the surrogate objective using importance sampling. The self-normalized estimator of $L_{\theta_{\text{old}}}$ obtained at a single state $s_n$ is,

$$L_n(\theta)=\frac{\sum_{k=1}^{K}\frac{\pi_{\theta}(a_{n,k}|s_n)}{\pi_{\theta_{\text{old}}}(a_{n,k}|s_n)}\hat{Q}(s_n,a_{n,k})}{\sum_{k=1}^{K}\frac{\pi_{\theta}(a_{n,k}|s_n)}{\pi_{\theta_{\text{old}}}(a_{n,k}|s_n)}}$$
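Both per-state estimators are a few lines each. A minimal sketch; the function names are ours, and the probability arrays are assumed to hold $\pi_{\theta}$ and $\pi_{\theta_{\text{old}}}$ evaluated at the $K$ actions used at $s_n$ (all numbers hypothetical):

```python
import numpy as np

def L_n_all_actions(pi_theta_probs, q_hat):
    """Finite action space, one rollout per action: sum_k pi_theta(a_k|s_n) * Q_hat(s_n, a_k)."""
    return float(np.dot(pi_theta_probs, q_hat))

def L_n_self_normalized(pi_theta_probs, pi_old_probs, q_hat):
    """Self-normalized importance-sampling estimate from K actions sampled at s_n."""
    w = pi_theta_probs / pi_old_probs            # importance weights pi_theta / pi_theta_old
    return float(np.sum(w * q_hat) / np.sum(w))

q_hat = np.array([1.0, 2.0, 4.0])               # hypothetical rollout returns
print(L_n_all_actions(np.array([0.25, 0.25, 0.5]), q_hat))   # 0.25 + 0.5 + 2.0 = 2.75
# At theta = theta_old all weights are 1, so the estimate reduces to the plain average
p_old = np.array([0.2, 0.3, 0.5])
print(L_n_self_normalized(p_old, p_old, q_hat))               # (1 + 2 + 4) / 3
```

Normalizing by $\sum_k w_k$ makes the estimate invariant to a common scale of the weights, at the cost of a small bias.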

Comparing the two schemes, vine gives much lower-variance estimates of the advantage values, but it needs to generate multiple trajectories from each state in the rollout set, which limits it to settings where the system can be reset to an arbitrary state. Single path, in contrast, does not need state resets and can be implemented directly.

## **TRPO Algorithm Steps**

1. Use the single path or vine procedure to collect a set of state-action pairs along with Monte Carlo estimates of their Q-values.

2. By averaging over samples, construct the estimated objective and constraint.

$$\max_{\theta}L_{\theta_{\text{old}}}(\theta)$$

$$\text{subject to  }\overline{D}_{\text{KL}}^{\rho_{\theta_{\text{old}}}}(\theta_{\text{old}},\theta)\leq\delta$$


3. Approximately solve this constrained optimization problem with the conjugate gradient algorithm followed by a line search, and update the policy's parameter vector $\theta$.

#### **Conjugate Gradient**

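Conjugate gradient is used to solve linear equations or, equivalently, to optimize quadratic functions.

For a symmetric positive-definite matrix $A$, the following linear equation and quadratic optimization problem are equivalent.

Linear equation
$$Ax=b$$

Quadratic optimization
$$\min_x \frac{1}{2}x^{T}Ax-b^Tx$$

The gradient of the quadratic is $Ax-b$, so its unique minimizer is exactly the solution of $Ax=b$.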

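Conjugate gradient is much more efficient than plain gradient descent. Like gradient descent, it performs a line search along one direction per iteration, but each new direction is $A$-conjugate ($A$-orthogonal) to all previous ones, so a step never undoes progress made along earlier directions. In exact arithmetic it converges in at most $n$ iterations for an $n\times n$ system.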

<img src="img/conjugate_gradient.png" width="600">  

***conjugate gradient***

---
$x_0:=0$ (or any initial guess)   
$r_0:=b-Ax_0$   
$p_0:=r_0$   
$k:=0$   

**repeat**  
    $\quad \alpha_k:=\frac{r_k^Tr_k}{p_k^TAp_k}\to$ how far to move in direction $p_k$          
    $\quad x_{k+1}:=x_k+\alpha_kp_k\to$ the next point     
    $\quad r_{k+1}:=r_k-\alpha_kAp_k\to$ the new residual $b-Ax_{k+1}$      
    $\quad$ if $r_{k+1}$ is sufficiently small, then exit loop    
    $\quad \beta_k:=\frac{r_{k+1}^Tr_{k+1}}{r_k^Tr_k}\to$ coefficient that makes the new direction A-orthogonal to $p_k$    
    $\quad p_{k+1}:=r_{k+1}+\beta_{k}p_{k}\to$ the next direction to go   
    $\quad k:=k+1$    
**end** 

---
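The pseudocode translates almost line by line into numpy for an explicit matrix. A minimal sketch; `conjugate_gradient_np` is our name, and the test matrix is a random symmetric positive-definite example:

```python
import numpy as np

def conjugate_gradient_np(A, b, tol=1e-10, max_iter=None):
    """Solve A x = b for a symmetric positive-definite A (float arrays)."""
    x = np.zeros_like(b)
    r = b - A @ x                       # residual
    p = r.copy()                        # first search direction
    rs_old = r @ r
    for _ in range(max_iter if max_iter is not None else 10 * len(b)):
        Ap = A @ p
        alpha = rs_old / (p @ Ap)       # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:       # residual small enough
            break
        p = r + (rs_new / rs_old) * p   # beta = rs_new / rs_old keeps directions A-conjugate
        rs_old = rs_new
    return x

rng = np.random.default_rng(0)
B = rng.normal(size=(5, 5))
A = B @ B.T + 5 * np.eye(5)             # symmetric positive definite
b = rng.normal(size=5)
x = conjugate_gradient_np(A, b)
print(np.allclose(A @ x, b))            # True
```

Only the product `A @ p` touches the matrix, which is what allows TRPO to replace it with a function call.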

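In TRPO, the matrix $A$ in the pseudocode above is the Fisher information matrix (the Hessian of the average KL divergence). It is never formed explicitly: conjugate gradient only needs the Fisher-vector product $Ap$, which can be computed with automatic differentiation.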
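The Fisher-vector product can be obtained with two backward passes: differentiate the KL once with `create_graph=True`, take the dot product with $v$, and differentiate again. A minimal sketch in which a single categorical distribution with logits `theta` stands in for the policy network (an assumption for illustration); at $\theta=\theta_{\text{old}}$ the result should match the softmax Fisher matrix $\text{diag}(p)-pp^T$:

```python
import torch

torch.manual_seed(0)
theta_old = torch.randn(4)
p_old = torch.softmax(theta_old, dim=0)

def mean_kl(theta):
    """KL(pi_old || pi_theta) for one categorical distribution with logits theta."""
    return torch.sum(p_old * (torch.log(p_old) - torch.log_softmax(theta, dim=0)))

def fisher_vector_product(theta, v):
    kl = mean_kl(theta)
    (grad_kl,) = torch.autograd.grad(kl, theta, create_graph=True)   # first backward pass
    (hvp,) = torch.autograd.grad(grad_kl @ v, theta)                 # second backward pass
    return hvp

theta = theta_old.clone().requires_grad_(True)
v = torch.randn(4)
fvp = fisher_vector_product(theta, v)
F = torch.diag(p_old) - torch.outer(p_old, p_old)   # analytic Fisher matrix of the softmax
print(torch.allclose(fvp, F @ v, atol=1e-5))        # True
```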

#### **Line Search**

Since the theoretical TRPO update is hard to compute exactly, TRPO makes some approximations. Expanding the objective and the constraint in a Taylor series to leading order around $\theta_{\text{old}}$ gives

$$L_{\theta_{\text{old}}}(\theta)\approx L_{\theta_{\text{old}}}(\theta_{\text{old}})+g^T(\theta-\theta_{\text{old}})$$

$$\overline{D}_{\text{KL}}^{\rho_{\theta_{\text{old}}}}(\theta_{\text{old}},\theta)\approx\frac{1}{2}(\theta-\theta_{\text{old}})^TH(\theta-\theta_{\text{old}})$$

where $g$ is the gradient of $L_{\theta_{\text{old}}}$ and $H$ is the Hessian of the average KL divergence, both evaluated at $\theta=\theta_{\text{old}}$ (the zeroth- and first-order terms of the KL vanish there). This gives an approximate optimization problem,

$$\max_\theta g^T(\theta-\theta_{\text{old}})$$
$$\text{subject to  }\frac{1}{2}(\theta-\theta_{\text{old}})^TH(\theta-\theta_{\text{old}})\leq\delta$$

Solving it with Lagrangian duality gives the update (writing $\theta_k$ for $\theta_{\text{old}}$)

$$\theta_{k+1}\leftarrow\theta_k+\alpha^j\sqrt{\frac{2\delta}{g^{T}H^{-1}g}}H^{-1}g$$

where $\alpha\in(0,1)$ is the backtracking coefficient and $j$ is the smallest non-negative integer such that $\pi_{\theta_{k+1}}$ satisfies the KL constraint and produces a positive surrogate advantage. Without the $\alpha^j$ factor, this update is exactly the Natural Policy Gradient step. However, because of the error introduced by the Taylor expansion, that step may violate the KL constraint or fail to improve the surrogate advantage, so TRPO adds the $\alpha^j$ factor as a backtracking line search.
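The scaling factor places the full step exactly on the boundary of the quadratic trust region, and each backtracking halving shrinks the quadratic KL model by a factor of four. A minimal numeric check, assuming a random symmetric positive-definite `H` and a random `g` as stand-ins for the KL Hessian and the surrogate gradient:

```python
import numpy as np

rng = np.random.default_rng(0)
n, delta = 3, 0.01
B = rng.normal(size=(n, n))
H = B @ B.T + np.eye(n)      # stand-in for the KL Hessian (symmetric positive definite)
g = rng.normal(size=n)       # stand-in for the surrogate gradient

x = np.linalg.solve(H, g)    # H^{-1} g; TRPO approximates this with conjugate gradient
step = np.sqrt(2 * delta / (g @ x)) * x

print(0.5 * step @ H @ step)  # equals delta: the step lands on the trust-region boundary
print(g @ step)               # predicted improvement sqrt(2 * delta * g^T H^{-1} g) > 0

alpha = 0.5
for j in range(3):
    # Shrinking the step by alpha^j shrinks the quadratic KL model by alpha^(2j)
    s_j = alpha ** j * step
    print(j, 0.5 * s_j @ H @ s_j)
```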

***line search for TRPO***

---
Compute proposed policy step $\Delta_k=\sqrt{\frac{2\delta}{\hat{g}_{k}^{T}\hat{H}^{-1}\hat{g}_{k}}}\hat{H}^{-1}\hat{g}_{k}$

**repeat for j in 0,1,...,K**  
    $\quad$ Compute proposed update $\theta=\theta_k+\alpha^j\Delta_k$       
    $\quad$ **if** $L_{\theta_k}(\theta)-L_{\theta_k}(\theta_k)\geq 0$ and $\overline{D}_{\text{KL}}(\theta_k,\theta)\leq\delta$        
    $\quad\quad$ accept the update and set $\theta_{k+1}=\theta_k+\alpha^j\Delta_k$   
    $\quad\quad$ break      
**end** 

---

***TRPO***

---
**Input:** initial policy parameters $\theta_0$, value function parameter $\phi_0$     
**Hyperparameters:** KL-divergence limit $\delta$, backtracking coefficient $\alpha$, maximum number of backtracking steps $K$

**repeat for k in 0,1,...**  
    $\quad$ Collect set of trajectories $\cal{D}_k$ on policy $\pi_k=\pi(\theta_k)$       
    $\quad$ Compute rewards-to-go $\hat{R}_t$     
    $\quad$ Estimate advantages $\hat{A}_t^{\pi_k}$ using any advantage estimation algorithm based on current value function $V_{\phi_k}$       
    $\quad$ Estimate policy gradient $\hat{g}_k=\frac{1}{|\cal{D}_k|}\sum_{\tau\in\cal{D}_k}\sum_{t=0}^{T}\nabla_{\theta}\log\pi_{\theta}(a_t|s_t)|_{\theta_{k}}\hat{A}_t$        
    $\quad$ Use conjugate gradient with $n_{cg}$ iterations to obtain $x_k\approx\hat{H}_{k}^{-1}\hat{g}_{k}$      
    $\quad$ Compute proposed step $\Delta_k=\sqrt{\frac{2\delta}{x_{k}^{T}\hat{H}_{k}x_{k}}}x_{k}$     
    $\quad$ Perform backtracking line search with exponential decay to obtain final update $\theta_{k+1}\leftarrow\theta_{k}+\alpha^{j}\Delta_{k}$    
    $\quad$ Fit value function by regression on mean-squared error, $\phi_{k+1}=\arg\min_{\phi}\frac{1}{|\cal{D}_k|T}\sum_{\tau\in\cal{D}_k}\sum_{t=0}^{T}[V_{\phi}(s_t)-\hat{R}_t]^2$       
**end** 

---

#### Policy and value network

```python
import torch
import torch.nn as nn
import torch.optim as optim
import scipy.optimize # For L-BFGS
import numpy as np
from typing import Tuple, Callable, Dict, Any

# Precompute constant
LOG_2_PI = np.log(2 * np.pi)

class Policy(nn.Module):
    """
    A Gaussian policy network for continuous action spaces.

    Args:
        input_dim: Dimension of the state space.
        hidden_dim: Dimension of the hidden layers.
        output_dim: Dimension of the action space.
    """
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super(Policy, self).__init__()
        self.inputLayer = nn.Linear(input_dim, hidden_dim)
        self.hiddenLayer = nn.Linear(hidden_dim, hidden_dim)
        self.outputLayer = nn.Linear(hidden_dim, output_dim)

        # Small initial output weights keep the initial action means close to zero
        self.outputLayer.weight.data.uniform_(-0.003, 0.003)
        self.outputLayer.bias.data.uniform_(-0.003, 0.003)

        # Learnable, state-independent log standard deviation by nn.Parameter
        self.log_std = nn.Parameter(torch.zeros(1, output_dim))
        # Bounds for log_std; clamping can improve stability
        self.log_std_min = -20
        self.log_std_max = 2

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass to get action distribution parameters.

        Args:
            x: Input state tensor of shape (batch, input_dim).

        Returns:
            A tuple containing:
            - action_mean: Mean of the action distribution.
            - action_log_std: Log standard deviation of the action distribution.
            - action_std: Standard deviation of the action distribution.
        """
        x = torch.tanh(self.inputLayer(x))
        x = torch.tanh(self.hiddenLayer(x))
        action_mean = self.outputLayer(x)

        # Clamp log_std for stability (out of place, so forward() does not modify the parameter)
        action_log_std = torch.clamp(self.log_std, self.log_std_min, self.log_std_max)
        action_log_std = action_log_std.expand_as(action_mean)
        action_std = torch.exp(action_log_std)
        return action_mean, action_log_std, action_std

    def get_log_probability_density(self, states: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """
        Calculate the log probability density of actions under the policy.

        Args:
            states: State tensor.
            actions: Action tensor.

        Returns:
            Log probability density for each state-action pair, shape (batch, 1).
        """
        action_mean, action_log_std, action_std = self.forward(states)
        log_prob_per_dim = -0.5 * (((actions - action_mean) / action_std)**2) \
                           - action_log_std \
                           - 0.5 * LOG_2_PI
        return log_prob_per_dim.sum(dim=1, keepdim=True)

    def get_KL_divergence(self, states: torch.Tensor, actions: torch.Tensor, old_log_prob: torch.Tensor) -> torch.Tensor:
        """
        Estimate the KL divergence D_KL(old_policy || current_policy) using samples.
        Assumes 'old_log_prob' contains log probabilities from the sampling policy,
        with the same shape (batch, 1) as get_log_probability_density's output.

        Args:
            states: State tensor.
            actions: Action tensor (sampled from the old policy).
            old_log_prob: Log probability of the actions under the old policy.

        Returns:
            Mean KL divergence estimate.
        """
        current_log_prob = self.get_log_probability_density(states, actions)
        kl_div = old_log_prob - current_log_prob
        return kl_div.mean()

    def get_action(self, state: torch.Tensor, deterministic: bool = False) -> torch.Tensor:
        """
        Sample or get the mean action from the policy.

        Args:
            state: Input state tensor (should be preprocessed, e.g., unsqueezed).
            deterministic: If True, return the mean action. Otherwise, sample.

        Returns:
            Action tensor.
        """
        with torch.no_grad(): # No need to track gradients for action selection
            action_mean, _, action_std = self.forward(state)
            if deterministic:
                return action_mean
            else:
                normal = torch.distributions.normal.Normal(action_mean, action_std)
                return normal.sample()


class Value(nn.Module):
    """
    A simple MLP value function network.

    Args:
        input_dim: Dimension of the state space.
        hidden_dim: Dimension of the hidden layers.
    """
    def __init__(self, input_dim: int, hidden_dim: int):
        super(Value, self).__init__()
        self.hidden1 = nn.Linear(input_dim, hidden_dim)
        self.hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.value_head = nn.Linear(hidden_dim, 1)

        self.value_head.weight.data.uniform_(-0.003, 0.003)
        self.value_head.bias.data.uniform_(-0.003, 0.003)

    def forward(self, x) -> torch.Tensor:
        # Accept a numpy array or a tensor; add a batch dimension only for a single state
        x = torch.as_tensor(x, dtype=torch.float32)
        if x.dim() == 1:
            x = x.unsqueeze(0)
        x = torch.tanh(self.hidden1(x))
        x = torch.tanh(self.hidden2(x))
        value = self.value_head(x)
        return value
```

#### Util function

```python
def normal_log_density(actions: torch.Tensor, means: torch.Tensor, log_stds: torch.Tensor) -> torch.Tensor:
    """Log density of a diagonal Gaussian, summed over action dimensions.

    Uses the same formula as Policy.get_log_probability_density.
    """
    stds = torch.exp(log_stds)
    log_prob_per_dim = -0.5 * (((actions - means) / stds)**2) \
                       - log_stds \
                       - 0.5 * np.log(2 * np.pi)
    return log_prob_per_dim.sum(dim=1, keepdim=True)

def set_flat_params_to(model: nn.Module, flat_params: torch.Tensor):
    """Sets model parameters from a flat tensor."""
    offset = 0
    for param in model.parameters():
        numel = param.numel()
        # Slice the flat_params and reshape it to the correct parameter shape
        param.data.copy_(flat_params[offset:offset + numel].view_as(param.data))
        offset += numel

def get_flat_params_from(model: nn.Module) -> torch.Tensor:
    """Flattens model parameters into a single tensor."""
    return torch.cat([p.detach().view(-1) for p in model.parameters()])

def get_flat_grad_from(model: nn.Module) -> torch.Tensor:
    """Flattens param.grad of every parameter into one tensor (zeros where no gradient exists).

    Second derivatives are not stored in param.grad; they are computed with
    torch.autograd.grad (see fisher_vector_product in trpo_step).
    """
    return torch.cat([
        (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
        for p in model.parameters()
    ])
```

#### TRPO function

In [3]:
def conjugate_gradient(fvp_callable: Callable, b: torch.Tensor, n_steps: int, residual_tol: float = 1e-10) -> torch.Tensor:
    """
    Solves the linear system A*x = b using the conjugate gradient method,
    where A is implicitly defined by the Fisher-vector product function fvp_callable.

    Args:
        fvp_callable: A function that takes a vector v and returns A*v (the FVP).
        b: The right-hand side vector of the system A*x = b.
        nsteps: Maximum number of iterations.
        residual_tol: Tolerance for convergence.

    Returns:
        The solution vector x.
    """
    x = torch.zeros_like(b)
    r = b.clone() # Initial residual: r = b - A*x = b (since x=0)
    p = r.clone() # Initial search direction
    rdotr = torch.dot(r, r)

    for _ in range(n_steps):
        Ap = fvp_callable(p) # Calculate A*p (Fisher-vector product)
        alpha = rdotr / (torch.dot(p, Ap) + 1e-8) # Add epsilon for stability
        x += alpha * p
        r -= alpha * Ap
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        beta = new_rdotr / (rdotr + 1e-8) # Add epsilon for stability
        p = r + beta * p
        rdotr = new_rdotr

    return x

def line_search(model: nn.Module,
                compute_objective: Callable, # Function to compute surrogate loss
                compute_constraint: Callable, # Function to compute KL divergence
                current_params: torch.Tensor,
                search_direction: torch.Tensor,
                expected_improvement: float,
                max_backtracks: int = 10,
                accept_ratio: float = 0.1,
                max_kl_constraint: float = 0.01) -> Tuple[bool, torch.Tensor]:
    """
    Performs backtracking line search to find parameters that improve the objective
    while satisfying the KL constraint.

    Args:
        model: The policy model.
        compute_objective: Function that takes flat params and returns scalar objective value.
        compute_constraint: Function that takes flat params and returns scalar constraint value (KL).
        current_params: Flat parameters before the update.
        search_direction: The proposed step direction (scaled).
        expected_improvement: Expected improvement from the quadratic approximation.
        max_backtracks: Maximum number of backtracking steps.
        accept_ratio: Minimum ratio of actual_improvement / expected_improvement.
        max_kl_constraint: The KL divergence limit.

    Returns:
        A tuple (success, new_params). 'success' is True if a valid step is found.
    """
    current_objective = compute_objective(current_params)

    for step_frac in (0.5**np.arange(max_backtracks)):
        new_params = current_params + step_frac * search_direction
        set_flat_params_to(model, new_params) # Temporarily update model
        
        # Evaluate objective and constraint with new parameters
        new_objective = compute_objective(new_params) # Needs scalar loss
        kl_divergence = compute_constraint(new_params) # Needs scalar KL
        
        actual_improvement = new_objective - current_objective
        expected_step_improvement = expected_improvement * step_frac
        ratio = actual_improvement / (expected_step_improvement + 1e-8) # Add epsilon

        # Check if improvement is sufficient and KL constraint is met
        if ratio > accept_ratio and actual_improvement > 0 and kl_divergence <= max_kl_constraint:
            # print(f"Line search success: step_frac={step_frac:.4f}, KL={kl_divergence:.4f}, Impr={actual_improvement:.4f}")
            return True, new_params # Keep the temporary update

    # If no step is found, revert to original parameters
    set_flat_params_to(model, current_params)
    # print("Line search failed.")
    return False, current_params

def trpo_step(model: Policy, # Use the specific Policy type hint if available
              value_net: Value, # Often needed for advantage calculation (outside this func)
              states: torch.Tensor,
              actions: torch.Tensor,
              advantages: torch.Tensor,
              old_log_prob: torch.Tensor, # Log probs from the policy *before* the update
              max_kl: float = 1e-2,
              damping: float = 1e-2):
    """
    Performs a single TRPO policy update step.

    Args:
        model: The policy network to be updated.
        value_net: The value network (potentially needed for helpers, though not directly used here).
        states: Batch of states.
        actions: Batch of actions corresponding to states.
        advantages: Batch of advantages corresponding to state-action pairs.
        old_log_prob: Log probabilities of the actions under the policy *before* this update.
        max_kl: Maximum KL divergence constraint.
        damping: Damping factor for the Fisher-vector product calculation.

    Returns:
        Tuple: (loss_gradient, kl_divergence_before_update) - for logging/debugging.
               The primary result is the update to the 'model' parameters.
    """
    # Ensure inputs are tensors
    states = torch.as_tensor(states, dtype=torch.float32)
    actions = torch.as_tensor(actions, dtype=torch.float32)
    advantages = torch.as_tensor(advantages, dtype=torch.float32)
    old_log_prob = torch.as_tensor(old_log_prob, dtype=torch.float32)

    # --- 1. Calculate Surrogate Loss and Gradient ---
    # Define functions needed for line search *before* calculating the initial gradient
    # These functions take flat parameters as input
    def compute_surrogate_loss(flat_params: torch.Tensor) -> torch.Tensor:
        """Computes the scalar surrogate loss for given parameters."""
        # Temporarily set model parameters
        original_params = get_flat_params_from(model)
        set_flat_params_to(model, flat_params)

        # Calculate loss (ensure no gradient tracking needed here for line search eval)
        with torch.no_grad():
            log_prob = model.get_log_probability_density(states, actions)
            # Importance sampling ratio: pi_new(a|s) / pi_old(a|s)
            ratio = torch.exp(log_prob - old_log_prob)
            surr_loss = - (advantages.squeeze() * ratio).mean() # Negative for minimization

        # Restore original parameters
        set_flat_params_to(model, original_params)
        return surr_loss
    
    def compute_kl_divergence(flat_params: torch.Tensor) -> torch.Tensor:
        """Computes the scalar KL divergence for given parameters."""
        # Temporarily set model parameters
        original_params = get_flat_params_from(model)
        set_flat_params_to(model, flat_params)

        # Calculate KL (ensure no gradient tracking needed here for line search eval)
        # Assumes model.get_KL_divergence computes KL(old || new) using old_log_prob
        # If get_KL_divergence needs the 'old' policy explicitly, adjust accordingly.
        with torch.no_grad():
            # Recompute current log_prob with new params
            current_log_prob = model.get_log_probability_density(states, actions)
            # KL divergence D_KL(pi_old || pi_new) = E_{a~pi_old} [log pi_old(a|s) - log pi_new(a|s)]
            kl = (old_log_prob - current_log_prob).mean()

        # Restore original parameters
        set_flat_params_to(model, original_params)
        return kl
    
    # Calculate the initial loss and gradient using current model parameters
    log_prob = model.get_log_probability_density(states, actions)
    ratio = torch.exp(log_prob - old_log_prob)
    action_loss = - (advantages.squeeze() * ratio) # Surrogate objective term per sample
    loss = action_loss.mean() # Average surrogate loss

    # Calculate the gradient of the loss w.r.t. current model parameters
    model.zero_grad() # Zero gradients before backward pass
    loss.backward() # Compute gradients
    loss_grad_list = [param.grad.clone() for param in model.parameters() if param.grad is not None]
    loss_grad = torch.cat([grad.view(-1) for grad in loss_grad_list]).detach() # Flatten and detach

    # --- 2. Fisher-Vector Product Calculation ---
    # (Defined as an inner function for cleaner scope, using variables from trpo_step)
    
    def fisher_vector_product(v: torch.Tensor) -> torch.Tensor:
        """
        Computes the Fisher-vector product (FVP) F*v or (H + damping*I)*v.

        Uses the Hessian of the KL divergence as an approximation for the Fisher
        Information Matrix (FIM). The computation H*v is done efficiently via
        two backward passes without explicitly constructing H.

        Args:
            v: The vector to multiply with the FIM (or damped Hessian). Should be
                a flattened tensor with the same number of elements as model parameters.

        Returns:
            The Fisher-vector product (H + damping*I)*v, detached from the graph.
        """
        model.zero_grad() # Ensure gradients are zeroed before computation

        # 1. Calculate KL divergence KL(old || current)
        # Ensure get_KL_divergence calculates the mean KL: E[log_old - log_new]
        # Recompute current log prob *with gradients enabled* for FVP
        current_log_prob_fvp = model.get_log_probability_density(states, actions)
        kl = (old_log_prob - current_log_prob_fvp).mean() # KL(pi_old || pi_new)

        # 2. Calculate first gradient: grad_kl = d(kl_div) / d(theta)
        # retain_graph=True and create_graph=True is essential for the Hessian-vector product calculation
        # They tell PyTorch to keep the computational graph that produced these gradients (g) intact, 
        # because we need to differentiate through this gradient calculation in a later step.
        grads = torch.autograd.grad(kl, model.parameters(), retain_graph=True, create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        # Ensure v is on the correct device and dtype, and detached
        v_device = flat_grad_kl.device
        v_dtype = flat_grad_kl.dtype
        v = v.to(device=v_device, dtype=v_dtype).detach()

        # 3. Calculate dot product: gv = sum(grad_kl * v)
        grad_kl_dot_v = torch.dot(flat_grad_kl, v)

        # 4. Calculate second gradient (Hessian-vector product):
        # Hv = d(grad_kl_dot_v) / d(theta) = d(dot(g, v)) / d(theta) = d(dot(d(kl)/d(theta), v)) / d(theta)
        # H is the Hessian of the KL divergence (H = d^2(kl) / d(theta)^2)
        # No need for create_graph=True or retain_graph=True here usually
        hessian_vector_prod_tuple = torch.autograd.grad(grad_kl_dot_v, model.parameters())
        # Flatten the result H*v
        flat_hessian_vector_prod = torch.cat([
            grad.contiguous().view(-1) for grad in hessian_vector_prod_tuple
        ])

        # 5. Apply damping and detach: FVP = H*v + damping * v
        # Damping helps stabilize optimization algorithms like Natural Gradient Descent or TRPO 
        # that use this calculation, preventing issues if H is ill-conditioned
        fvp = flat_hessian_vector_prod.detach() + damping * v
        # v is already detached from the earlier step

        return fvp
    
    # --- 3. Conjugate Gradient ---
    # Solve F * step_dir = -loss_grad (approximately F^{-1} * -loss_grad)
    step_dir = conjugate_gradient(fisher_vector_product, -loss_grad, 10)
    
    # --- 4. Calculate Proposed Step Size ---
    # Calculate s^T * F * s (where s = step_dir)
    # This is needed to scale the step to meet the KL constraint
    fvp_step_dir = fisher_vector_product(step_dir)
    shs = 0.5 * torch.dot(step_dir, fvp_step_dir) # 0.5 * s^T * F * s
    # Add epsilon to prevent division by zero or instability near zero
    lagrange_multiplier = torch.sqrt(shs / (max_kl + 1e-8))
    fullstep = step_dir / (lagrange_multiplier + 1e-8) # Scaled step: s / lm

    # Expected improvement of the linearized objective: g^T * fullstep, with g = -loss_grad
    # (loss_grad is d(loss)/d(theta), so -loss_grad points in the improvement direction)
    expected_improvement = torch.dot(-loss_grad, fullstep) # Positive when step_dir is a descent direction for the loss

    # --- 5. Line Search ---
    prev_params = get_flat_params_from(model)
    initial_kl = compute_kl_divergence(prev_params) # For logging

    success, new_params = line_search(
        model,
        compute_surrogate_loss, # Pass the function that computes scalar loss
        compute_kl_divergence,  # Pass the function that computes scalar KL
        prev_params,
        fullstep,
        expected_improvement,
        max_kl_constraint=max_kl
    )

    # --- 6. Update Policy ---
    set_flat_params_to(model, new_params) # Update model with parameters found by line search

    # Return initial gradient and KL for logging purposes
    return loss_grad, initial_kl
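The conjugate-gradient solve and the step scaling in `trpo_step` only ever touch the Fisher matrix through `fisher_vector_product`. The NumPy sketch below illustrates those two steps on a toy problem. It assumes a small synthetic symmetric positive-definite matrix `F` in place of the KL Hessian and a random vector `g` in place of `-loss_grad`; the helper names `cg_solve` and `scale_to_kl` are hypothetical and are not the notebook's own `conjugate_gradient`.

```python
import numpy as np


def cg_solve(fvp, b, n_iters=10, residual_tol=1e-10):
    """Approximately solve F x = b, accessing F only through fvp(v) = F @ v."""
    x = np.zeros_like(b)
    r = b.copy()  # residual b - F x (x starts at zero)
    p = b.copy()  # current search direction
    r_dot_r = r @ r
    for _ in range(n_iters):
        Fp = fvp(p)
        alpha = r_dot_r / (p @ Fp)  # exact minimizer of the quadratic along p
        x += alpha * p
        r -= alpha * Fp
        new_r_dot_r = r @ r
        if new_r_dot_r < residual_tol:
            break
        p = r + (new_r_dot_r / r_dot_r) * p  # next F-conjugate direction
        r_dot_r = new_r_dot_r
    return x


def scale_to_kl(step_dir, fvp, max_kl):
    """Rescale step_dir so that the quadratic KL model 0.5 * s^T F s equals max_kl."""
    shs = 0.5 * step_dir @ fvp(step_dir)
    lagrange_multiplier = np.sqrt(shs / max_kl)
    return step_dir / lagrange_multiplier


# Toy problem: F is a synthetic SPD stand-in for the KL Hessian, g for -loss_grad
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
F = A @ A.T + np.eye(5)
g = rng.normal(size=5)
damping, max_kl = 0.1, 0.01


def fvp(v):
    """Damped Fisher-vector product (F + damping * I) v, mirroring trpo_step."""
    return F @ v + damping * v


step_dir = cg_solve(fvp, g, n_iters=10)
fullstep = scale_to_kl(step_dir, fvp, max_kl)
expected_improvement = g @ fullstep  # linearized gain g^T s

print(0.5 * fullstep @ fvp(fullstep))  # approximately max_kl
print(expected_improvement > 0)        # True
```

Because `0.5 * s^T F s` grows quadratically with the step length, dividing by `sqrt(shs / max_kl)` places the quadratic KL estimate exactly on the `max_kl` boundary; the line search then only needs to shrink the step when the true KL or surrogate loss disagrees with this local model.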

In [4]:
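import gym

# Uses the pre-0.26 gym API: reset() returns only the observation and step() returns 4 values.
# Newer gym releases register this environment as "Pendulum-v1".
env = gym.make("Pendulum-v0")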
num_inputs = env.observation_space.shape[0]
num_actions = env.action_space.shape[0]
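`Pendulum` has a continuous action space (`num_actions` is 1), so the policy is a diagonal Gaussian, and both `old_log_prob` and the `kl` term inside `fisher_vector_product` are built from Gaussian log-densities. As a sanity check, the NumPy sketch below assumes a log-density helper of the same form as `normal_log_density` (the hypothetical `gaussian_log_density`) and compares the sample estimate `mean(log pi_old - log pi_new)` with the closed-form Gaussian KL. The two distributions and the sample size are arbitrary toy choices.

```python
import numpy as np


def gaussian_log_density(x, mean, log_std):
    """Diagonal-Gaussian log-density, summed over the action dimension.

    x, mean and log_std broadcast to [N, action_dim]; the result has shape [N, 1].
    """
    var = np.exp(2.0 * log_std)
    log_density = -((x - mean) ** 2) / (2.0 * var) - 0.5 * np.log(2.0 * np.pi) - log_std
    return log_density.sum(axis=1, keepdims=True)


def gaussian_kl(mean_old, log_std_old, mean_new, log_std_new):
    """Closed-form KL(old || new) between diagonal Gaussians, summed over action dims."""
    var_old = np.exp(2.0 * log_std_old)
    var_new = np.exp(2.0 * log_std_new)
    kl = log_std_new - log_std_old + (var_old + (mean_old - mean_new) ** 2) / (2.0 * var_new) - 0.5
    return kl.sum(axis=-1)


# Old policy N(0, 1) and new policy N(0.5, 1.2^2) for a 1-D action (toy values)
mean_old, log_std_old = np.zeros((1, 1)), np.zeros((1, 1))
mean_new, log_std_new = np.full((1, 1), 0.5), np.full((1, 1), np.log(1.2))

# Actions sampled from the old policy, as in a TRPO batch
rng = np.random.default_rng(0)
actions = mean_old + np.exp(log_std_old) * rng.standard_normal((200_000, 1))

# Same form as `kl` in fisher_vector_product: mean of log pi_old - log pi_new
kl_estimate = np.mean(
    gaussian_log_density(actions, mean_old, log_std_old)
    - gaussian_log_density(actions, mean_new, log_std_new)
)
kl_exact = gaussian_kl(mean_old, log_std_old, mean_new, log_std_new)[0]
print(f"sample estimate: {kl_estimate:.4f}, closed form: {kl_exact:.4f}")
```

The sample estimate is unbiased for the KL when actions come from the old policy, which is why the surrogate KL in `trpo_step` can be computed from the stored batch alone.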

In [None]:
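from typing import Any, Dict

import numpy as np
import scipy.optimize
import torch
import torch.nn as nn

hidden_dim = 64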
gamma = 0.99 # Discount factor
tau = 0.95 # GAE lambda parameter
l2_reg = 1e-3 # L2 regularization strength for value net
max_kl: float = 0.01 # Max KL constraint for TRPO
damping: float = 0.1 # Damping for FVP in TRPO
policy_net = Policy(num_inputs, hidden_dim, num_actions)
value_net = Value(num_inputs, hidden_dim)



def update_policy(batch: Dict[str, Any]):
    """
    Updates the policy and value networks using GAE and TRPO.

    Args:
        batch: A dictionary containing 'states', 'actions', 'rewards', 'mask'.
               The data may span several concatenated episodes; 'mask' is 0 at
               the last step of each episode so returns and advantages do not
               leak across episode boundaries.
    """
    # --- 1. Data Preparation ---
    # Expected layout for N steps: states [N, state_dim], actions [N, action_dim],
    # rewards and masks [N, 1]. Add .to(device) here if training on a GPU.
    try:
        states = torch.as_tensor(batch["states"], dtype=torch.float32)
        actions = torch.as_tensor(batch["actions"], dtype=torch.float32)
        rewards = torch.as_tensor(batch["rewards"], dtype=torch.float32)
        # Ensure masks are treated as floats for multiplication
        masks = torch.as_tensor(batch["mask"], dtype=torch.float32)
    except KeyError as e:
        print(f"Error: Batch dictionary missing key: {e}")
        return
    except Exception as e:
        print(f"Error processing batch tensors: {e}")
        return
    
    # Normalize shapes: states -> [N, state_dim], actions -> [N, action_dim],
    # rewards and masks -> [N, 1] (one value per step)
    if states.dim() == 1:
        states = states.unsqueeze(0)  # Single step: [state_dim] -> [1, state_dim]
    num_samples = states.size(0)
    actions = actions.reshape(num_samples, -1)
    rewards = rewards.reshape(num_samples, 1)
    masks = masks.reshape(num_samples, 1)

    
    # --- 2. Value Function Estimation ---
    with torch.no_grad(): # No gradients needed for calculating targets
        values = value_net(states).squeeze(-1) # Value net output [N, 1] -> [N]


    # --- 3. GAE and Returns Calculation ---
    num_steps = rewards.size(0)
    returns = torch.zeros_like(rewards)     # Use zeros_like for correct shape/device/dtype
    deltas = torch.zeros_like(rewards)
    advantages = torch.zeros_like(rewards)

    prev_return = 0.0
    prev_value = 0.0
    prev_advantage = 0.0
    for i in reversed(range(num_steps)):
        # Per-step quantities (single-element tensors that broadcast together)
        current_reward = rewards[i]
        current_mask = masks[i]
        current_value = values[i] # From value_net(states) calculated earlier

        # Calculate return G(t) = r_t + gamma * G(t+1) * mask
        returns[i] = current_reward + gamma * prev_return * current_mask

        # Calculate TD error (delta) = r_t + gamma * V(s_{t+1}) * mask - V(s_t)
        # Note: prev_value holds V(s_{t+1}) from the previous iteration
        deltas[i] = current_reward + gamma * prev_value * current_mask - current_value

        # Calculate GAE advantage A(t) = delta_t + gamma * tau * A(t+1) * mask
        advantages[i] = deltas[i] + gamma * tau * prev_advantage * current_mask

        # Carry G(t), V(s_t) and A(t) back to the earlier step t-1 as Python floats
        prev_return = returns[i].item()
        prev_value = current_value.item() # V(s_t) plays the role of V(s_{t+1}) at step t-1
        prev_advantage = advantages[i].item()

    # --- 4. Value Function Update (using SciPy L-BFGS) ---
    # Define the loss+gradient function required by fmin_l_bfgs_b
    def get_value_loss_and_grad(flat_params_numpy):
        # Convert numpy array back to tensor
        flat_params = torch.tensor(flat_params_numpy, dtype=torch.float32) # Match model dtype
        set_flat_params_to(value_net, flat_params)

        # Zero gradients using standard method
        value_net.zero_grad()

        # Forward pass - ensure states is correctly shaped
        values_pred = value_net(states) # Output shape likely [N, 1]

        # Calculate MSE loss - ensure targets (returns) have compatible shape [N, 1]
        value_loss = nn.functional.mse_loss(values_pred, returns) # Use torch MSE

        # Add L2 regularization
        l2_penalty = 0.0
        for param in value_net.parameters():
            l2_penalty += param.pow(2).sum()
        value_loss += l2_penalty * l2_reg * 0.5 # Common to scale L2 by 0.5

        # Backward pass to compute gradients
        value_loss.backward()

        # Get flat gradient and convert back to numpy double
        flat_grad = get_flat_grad_from(value_net)
        # Return loss and gradient as numpy doubles
        return value_loss.item(), flat_grad.cpu().numpy().astype(np.float64)

    # Get initial parameters as numpy double
    initial_params_numpy = get_flat_params_from(value_net).cpu().numpy().astype(np.float64)


    # Optimize using L-BFGS-B
    try:
        optimal_params_numpy, _, _ = scipy.optimize.fmin_l_bfgs_b(
            get_value_loss_and_grad, initial_params_numpy, maxiter=25 # Adjust maxiter as needed
        )
        # Update the value network with the optimized parameters
        set_flat_params_to(value_net, torch.tensor(optimal_params_numpy, dtype=torch.float32))
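    except Exception as e:
        print(f"L-BFGS optimization failed: {e}")
        # get_value_loss_and_grad overwrites the network's parameters while it
        # evaluates candidates, so restore the pre-optimization parameters
        set_flat_params_to(value_net, torch.tensor(initial_params_numpy, dtype=torch.float32))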


    # --- 5. Advantage Normalization ---
    # Ensure advantages tensor has requires_grad=False before in-place ops if needed elsewhere
    advantages = advantages.detach() # Make sure it's detached before normalization
    adv_mean = advantages.mean()
    adv_std = advantages.std() + 1e-8 # Add epsilon for numerical stability
    normalized_advantages = (advantages - adv_mean) / adv_std

    # --- 6. Prepare for TRPO Step ---
    # Calculate log probabilities using the policy *before* the update
    with torch.no_grad(): # Ensure no gradients are computed here
        action_means, action_log_stds, _ = policy_net(states)
        # Log-probabilities of the collected actions under the pre-update policy;
        # computed under no_grad, so they act as constants in the TRPO step
        old_log_prob = normal_log_density(actions, action_means, action_log_stds)

    
    # --- 7. Perform TRPO Step ---
    # Pass the normalized advantages and the TRPO hyperparameters (max_kl, damping)
    loss_grad, kl_divergence = trpo_step(
        model=policy_net,
        value_net=value_net, # Pass if needed by helpers within trpo_step
        states=states,
        actions=actions,
        advantages=normalized_advantages, # Use normalized advantages
        old_log_prob=old_log_prob,
        max_kl=max_kl,     # Pass the constraint value
        damping=damping    # Pass the damping value
    )

    
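n_episodes = 1000 # Number of policy updates (each collects about batch_size steps)
batch_size = 4000  # Target number of environment steps per policy update
max_episode_steps = 1000 # Upper bound; Pendulum's own time limit ends episodes after 200 steps
log_interval = 100 # Print logs every log_interval updates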

# --- Main Training Loop ---
total_steps_processed = 0 # Keep track of total steps across all updates

rewards = []

for i_episode in range(n_episodes):
    # Data storage for the current batch (will collect multiple episodes)
    batch_states = []
    batch_actions = []
    batch_rewards = []
    batch_masks = [] # Represents (1 - done)

    steps_in_batch = 0
    episodes_in_batch = 0
    total_reward_in_batch = 0.0

    # Collect experience until batch_size is reached
    while steps_in_batch < batch_size:
        state = env.reset()

        episode_reward = 0.0
        episode_steps = 0

        # Temporary storage for the current episode's trajectory
        episode_states = []
        episode_actions = []
        episode_rewards = []
        episode_masks = []

        for t in range(max_episode_steps):
            # 1. Get Action
            # Flatten the observation and add a batch dimension -> [1, state_dim]
            state_tensor = torch.from_numpy(state.reshape(-1)).float().unsqueeze(0)
            action_tensor = policy_net.get_action(state_tensor)
            # The action (presumably [1, action_dim]) is passed to the environment unflattened;
            # Pendulum-v0 then returns the reward as a 1-element array and the next observation
            # with an extra trailing axis, hence reshape(-1) here and the [0] indexing when logging
            action_numpy = action_tensor.detach().cpu().numpy()

            # 2. Step Environment
            # Ensure env.step returns consistent types (usually numpy for state/reward)
            next_state, reward, done, _ = env.step(action_numpy)

            # 3. Store Transition Data (using consistent types, e.g., numpy for states)
            episode_states.append(state.reshape(-1)) # Store original state (numpy)
            episode_actions.append(action_tensor.detach()) # Store action tensor without autograd history
            episode_rewards.append(reward) # Store reward (float/numpy)
            episode_masks.append(1.0 - float(done)) # Store mask (float)

            state = next_state # Update state for next iteration
            episode_reward += reward
            episode_steps += 1

            if done:
                break

        # End of episode: Append episode data to the main batch lists
        batch_states.extend(episode_states)
        batch_actions.extend(episode_actions)
        batch_rewards.extend(episode_rewards)
        batch_masks.extend(episode_masks)

        # Update batch counters
        steps_in_batch += episode_steps
        total_reward_in_batch += episode_reward
        episodes_in_batch += 1

        # Store the reward of the *last completed* episode for logging
        last_episode_reward = episode_reward

    # --- Batch Finalization and Policy Update ---
    # Calculate average reward per episode in this batch
    avg_reward_per_episode = total_reward_in_batch / episodes_in_batch if episodes_in_batch > 0 else 0.0
    total_steps_processed += steps_in_batch

    # Prepare batch dictionary for update_policy by converting the per-step lists
    # into single tensors (add .to(device) if using a GPU)
    update_batch = {
        "states": torch.tensor(np.asarray(batch_states), dtype=torch.float32),  # [N, state_dim]
        # Flatten each stored action so the batch is [N, action_dim] rather than [N, 1, action_dim]
        "actions": torch.stack(batch_actions).reshape(len(batch_actions), -1),
        # Rewards may arrive as 1-element arrays rather than floats, so build one NumPy
        # array and reshape to [N, 1] instead of adding a dimension blindly
        "rewards": torch.as_tensor(np.asarray(batch_rewards, dtype=np.float32)).reshape(-1, 1),
        "mask": torch.as_tensor(np.asarray(batch_masks, dtype=np.float32)).reshape(-1, 1),  # [N, 1]
    }

    # Call the policy update function
    update_policy(update_batch) # Pass the correctly formatted batch
    rewards.append(avg_reward_per_episode[0])
    # --- Logging ---
    if i_episode % log_interval == 0:
        print(f'Episode {i_episode}\tSteps Collected: {steps_in_batch}\t'
              f'Last Ep Reward: {last_episode_reward[0]:.2f}\t'
              f'Avg Batch Ep Reward: {avg_reward_per_episode[0]:.2f}')


Episode 0	Steps Collected: 4000	Last Ep Reward: -1211.06	Avg Batch Ep Reward: -1324.67
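The backward loop in `update_policy` implements generalized advantage estimation (GAE). The NumPy sketch below repeats the same recursion over two concatenated toy episodes and checks that, inside an episode, it matches the explicit sum $A_t=\sum_{k}(\gamma\lambda)^k\delta_{t+k}$. The function name `compute_gae` and the random data are hypothetical; `lam` plays the role of `tau` in the notebook.

```python
import numpy as np


def compute_gae(rewards, values, masks, gamma=0.99, lam=0.95):
    """Backward GAE recursion over concatenated episodes.

    masks[t] = 0 marks the last step of an episode; as in update_policy, the
    bootstrap value V(s_{t+1}) comes from the next stored step and is zeroed there.
    """
    n = len(rewards)
    returns, deltas, advantages = np.zeros(n), np.zeros(n), np.zeros(n)
    prev_return = prev_value = prev_advantage = 0.0
    for t in reversed(range(n)):
        returns[t] = rewards[t] + gamma * prev_return * masks[t]
        deltas[t] = rewards[t] + gamma * prev_value * masks[t] - values[t]
        advantages[t] = deltas[t] + gamma * lam * prev_advantage * masks[t]
        prev_return, prev_value, prev_advantage = returns[t], values[t], advantages[t]
    return returns, deltas, advantages


# Toy data: two 3-step episodes with random rewards and value estimates
rng = np.random.default_rng(0)
toy_rewards = rng.normal(size=6)
toy_values = rng.normal(size=6)
toy_masks = np.array([1.0, 1.0, 0.0, 1.0, 1.0, 0.0])

gamma, lam = 0.99, 0.95
returns, deltas, advantages = compute_gae(toy_rewards, toy_values, toy_masks, gamma, lam)

# Explicit form inside the first episode: A_t = sum_k (gamma * lam)^k * delta_{t+k}
explicit = [sum((gamma * lam) ** k * deltas[t + k] for k in range(3 - t)) for t in range(3)]
print(np.allclose(advantages[:3], explicit))  # True
```

The mask both stops the discounted return at an episode boundary and drops the bootstrap term $\gamma V(s_{t+1})$ there, so the advantage of the last step of each episode reduces to $r_t - V(s_t)$.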
