## TRPO+GAE Walker2D Continuous Control 

TRPO [Schulman, John, et al. "Trust region policy optimization." International conference on machine learning. 2015.](http://proceedings.mlr.press/v37/schulman15.pdf)

GAE [Schulman, John, et al. "High-Dimensional Continuous Control Using Generalized Advantage Estimation." International Conference on Learning Representations. 2016.](https://arxiv.org/pdf/1506.02438)




### TRPO: Trust Region Policy Optimization

To improve training stability, we should avoid parameter updates that change the policy too much in a single step. Trust region policy optimization (TRPO) implements this idea by enforcing a KL divergence constraint on the size of the policy update at each iteration.

Consider first the off-policy case, in which the policy $\beta$ used to collect trajectories on rollout workers differs from the policy $\pi$ being optimized. The objective function then measures the total advantage over the state visitation distribution and actions, while the mismatch between the behavior policy and the policy being optimized is compensated by an importance sampling estimator:

$$
\begin{aligned}
J(\theta)
&= \sum_{s \in \mathcal{S}} \rho^{\pi_{\theta_\text{old}}}(s) \sum_{a \in \mathcal{A}} \big( \pi_\theta(a \vert s) \hat{A}_{\theta_\text{old}}(s, a) \big) & \\
&= \sum_{s \in \mathcal{S}} \rho^{\pi_{\theta_\text{old}}}(s) \sum_{a \in \mathcal{A}} \big( \beta(a \vert s) \frac{\pi_\theta(a \vert s)}{\beta(a \vert s)} \hat{A}_{\theta_\text{old}}(s, a) \big) & \scriptstyle{\text{; Importance sampling}} \\
&= \mathbb{E}_{s \sim \rho^{\pi_{\theta_\text{old}}}, a \sim \beta} \big[ \frac{\pi_\theta(a \vert s)}{\beta(a \vert s)} \hat{A}_{\theta_\text{old}}(s, a) \big] &
\end{aligned}
$$


where $\theta_\text{old}$ denotes the policy parameters before the update, which are known to us; $\rho^{\pi_{\theta_\text{old}}}$ is the (discounted) state visitation distribution of the old policy; and $\beta(a \vert s)$ is the behavior policy used to collect trajectories. Note that we use an estimated advantage $\hat{A}(.)$ rather than the true advantage function $A(.)$, because the true advantage is unknown and has to be estimated from sampled rewards.
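To make the importance-sampling step concrete, here is a small numerical sketch that is separate from the TRPO code in this notebook: actions are sampled from a behavior policy $\beta = \mathcal{N}(0, 1)$, and reweighting them by $\pi/\beta$ recovers an expectation under a different policy $\pi = \mathcal{N}(0.5, 1)$. The two Gaussians, the toy advantage $A(a) = a$ and the sample size are illustrative assumptions.

```python
import numpy as np


def normal_pdf(x, mean, std):
    """Density of the 1-D Gaussian N(mean, std^2) evaluated at x."""
    return np.exp(-0.5 * ((x - mean) / std) ** 2) / (std * np.sqrt(2.0 * np.pi))


rng = np.random.default_rng(0)

# Behavior policy beta = N(0, 1) generates the actions.
actions = rng.normal(loc=0.0, scale=1.0, size=200_000)

# Target policy pi = N(0.5, 1); with the toy advantage A(a) = a, E_{a~pi}[A(a)] = 0.5.
weights = normal_pdf(actions, 0.5, 1.0) / normal_pdf(actions, 0.0, 1.0)  # pi(a) / beta(a)
is_estimate = np.mean(weights * actions)

print(f"importance-sampled estimate: {is_estimate:.3f} (exact: 0.5)")
```

Without the weights, the plain sample mean would estimate the expectation under $\beta$ (close to 0) instead of under $\pi$.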

When training on-policy, the policy that collects data is, in theory, the same as the policy we optimize. However, when rollout workers and optimizers run asynchronously in parallel, the behavior policy can become stale. TRPO accounts for this subtle difference by labeling the behavior policy as $\pi_{\theta_\text{old}}(a \vert s)$, so the objective function becomes:

$$ J(\theta) = \mathbb{E}_{s \sim \rho^{\pi_{\theta_\text{old}}}, a \sim \pi_{\theta_\text{old}}} \big[ \frac{\pi_\theta(a \vert s)}{\pi_{\theta_\text{old}}(a \vert s)} \hat{A}_{\theta_\text{old}}(s, a) \big] $$
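A minimal PyTorch sketch of how this ratio is usually estimated from stored log-probabilities, assuming a diagonal Gaussian policy whose mean is a plain learnable tensor (a stand-in for the policy network) and made-up advantages. At $\theta = \theta_\text{old}$ the ratio is exactly 1, so the surrogate equals the mean advantage and its gradient reduces to the vanilla policy gradient $\mathbb{E}[\nabla_\theta \log \pi_\theta(a \vert s)\, \hat{A}]$.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Hypothetical batch: 4 states, 2-D actions. The policy mean is a plain learnable
# tensor here (a stand-in for the policy network output); std is fixed to 1.
mean = torch.zeros(4, 2, requires_grad=True)
std = torch.ones(4, 2)
advantages = torch.tensor([[1.0], [-0.5], [2.0], [0.0]])

# Rollout phase under pi_theta_old: actions and their log-probs are stored as constants.
with torch.no_grad():
    old_dist = Normal(mean, std)
    actions = old_dist.sample()
    old_log_prob = old_dist.log_prob(actions).sum(-1, keepdim=True)

# Optimization phase: re-evaluate the same actions under the current parameters.
new_log_prob = Normal(mean, std).log_prob(actions).sum(-1, keepdim=True)
ratio = torch.exp(new_log_prob - old_log_prob)   # pi_theta(a|s) / pi_theta_old(a|s)
surrogate = (ratio * advantages).mean()          # sample estimate of J(theta)
surrogate.backward()

print(ratio.squeeze(-1))   # all ones, since theta == theta_old
print(surrogate.item())    # 0.625, the mean advantage
```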

TRPO maximizes the objective function $J(\theta)$ subject to a trust region constraint, which requires the KL divergence between the old and new policies to stay within a parameter $\delta$:

$$\mathbb{E}_{s \sim \rho^{\pi_{\theta_\text{old}}}} \big[ D_\text{KL}(\pi_{\theta_\text{old}}(.\vert s) \| \pi_\theta(.\vert s)) \big] \leq \delta$$

When this hard constraint is met, the old and new policies cannot diverge too much. At the same time, TRPO's theory guarantees monotonic policy improvement for an idealized version of the update, which the practical algorithm approximates.
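The trust-region problem is usually solved approximately: linearize $J(\theta)$ around $\theta_\text{old}$ (gradient $g$) and replace the KL constraint by its second-order approximation $\tfrac{1}{2}\Delta\theta^T F \Delta\theta \le \delta$, where $F$ is the Fisher information matrix discussed in the TRPO Agent section. The solution is $\Delta\theta = \sqrt{2\delta / (s^T F s)}\, s$ with $s = F^{-1} g$, which matches the `max_step_size` formula used by the agent below. This NumPy sketch assumes a small, explicitly known $F$ and $g$ (made-up numbers); the agent itself obtains $s$ with conjugate gradient and only Fisher-vector products.

```python
import numpy as np


def trust_region_step(g, F, delta):
    """Solve max_d g^T d  s.t.  0.5 * d^T F d <= delta, for symmetric positive definite F."""
    s = np.linalg.solve(F, g)                   # natural-gradient direction F^{-1} g
    beta = np.sqrt(2.0 * delta / (s @ F @ s))   # scale so the quadratic KL model equals delta
    return beta * s


# Illustrative 2-parameter example (g and F are invented for the demonstration).
g = np.array([1.0, 0.5])
F = np.array([[2.0, 0.3],
              [0.3, 1.0]])
delta = 0.01

step = trust_region_step(g, F, delta)
print(step, 0.5 * step @ F @ step)   # the quadratic KL estimate equals delta
```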

### ENV

In [1]:
import gymnasium as gym
env = gym.make('Walker2d-v5')

observation, info = env.reset()
print(env.action_space)
print(env.observation_space)
print(env.observation_space.shape[-1])
print(env.action_space.shape[0])

Box(-1.0, 1.0, (6,), float32)
Box(-inf, inf, (17,), float64)
17
6


### Define Model

We define separate policy and value networks, which stays closer to the original research papers, although the agent could also be implemented as a single actor-critic model with two separate heads.
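For comparison, a sketch of the single-model alternative: a shared trunk feeding a Gaussian policy head and a value head. The `ActorCritic` class is hypothetical and is not used anywhere else in this notebook; layer sizes mirror the `Policy` and `Value` networks below.

```python
import torch
import torch.nn as nn
from torch.distributions import Normal


class ActorCritic(nn.Module):
    """Hypothetical single-model variant: one shared trunk with a policy head and a value head."""

    def __init__(self, obs_dim: int, action_dim: int, fc_units: int = 128):
        super().__init__()
        # Shared feature extractor used by both heads
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, fc_units), nn.ReLU(),
            nn.Linear(fc_units, fc_units), nn.ReLU(),
        )
        self.mean_head = nn.Linear(fc_units, action_dim)      # policy head: Gaussian mean
        self.log_std = nn.Parameter(torch.zeros(action_dim))  # state-independent log std
        self.value_head = nn.Linear(fc_units, 1)              # value head: V(s)

    def forward(self, state: torch.Tensor):
        features = self.trunk(state)
        # Normal broadcasts the (action_dim,) std against the (batch, action_dim) mean
        dist = Normal(self.mean_head(features), self.log_std.exp())
        return dist, self.value_head(features)


model = ActorCritic(obs_dim=17, action_dim=6)   # Walker2d-v5 dimensions
dist, value = model(torch.zeros(3, 17))
print(dist.sample().shape, value.shape)        # torch.Size([3, 6]) torch.Size([3, 1])
```

One reason to keep the networks separate in TRPO is that the value function is fitted with Adam; with a shared trunk, that fit would also move the policy outside of the KL-constrained step.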

In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
from torch.distributions import Normal

# Prioritize device: CUDA > MPS > CPU
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
    print("CUDA is available. Using CUDA.")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print("MPS backend is available. Using MPS.")
else:
    DEVICE = torch.device("cpu")
    print("Neither CUDA nor MPS is available. Using CPU.")

# Override the detected device: run on CPU for now
DEVICE = torch.device("cpu")

class Policy(nn.Module):

    def __init__(self, env, fc_units=128, activation=nn.ReLU()):
        """Initialize parameters and build model."""
        super(Policy, self).__init__()

        # Check if it's a vector environment or regular environment
        if hasattr(env, 'single_observation_space'):
            obs_dim = env.single_observation_space.shape[0]
            action_dim = env.single_action_space.shape[0]
        else:
            obs_dim = env.observation_space.shape[0]
            action_dim = env.action_space.shape[0]
        

        # Actor head
        self.fc_policy = nn.Sequential(
            nn.Linear(obs_dim, fc_units),
            activation,
            nn.Linear(fc_units, fc_units),
            activation
        )

        # Gaussian policy head: a linear layer outputs the mean, and a state-independent learnable log(std) sets the spread
        self.fc_mean = nn.Linear(fc_units, action_dim) # fully connected 
        self.log_std = nn.Parameter(torch.zeros(action_dim)) # standard initialization std = exp(log_std) = exp(0) = 1      

        self.log_std_min = -20  # exp(-20) ≈ 2e-9, effectively zero
        self.log_std_max = 2 # exp(2) ≈ 7.4, reasonable upper bound

        for module in self.modules():
            if isinstance(module, nn.Linear):
                if module is self.fc_mean:
                    # Small initial weights keep the initial action means close to zero
                    init.xavier_normal_(module.weight, gain=1e-2)
                    init.zeros_(module.bias)
                else:  # Hidden layers (other activations keep PyTorch's default weight init)
                    if isinstance(activation, nn.Tanh):
                        init.xavier_normal_(module.weight, gain=1.0)
                    elif isinstance(activation, nn.ReLU):
                        init.kaiming_normal_(module.weight, mode='fan_in', nonlinearity='relu')
                    init.zeros_(module.bias)

    def forward(self, state, deterministic=False):
        """Sample an action from the Gaussian policy (or return its mean if deterministic)."""

        mean = self.fc_mean(self.fc_policy(state))

        # expand_as broadcasts the shared log_std to the batch shape of mean
        action_log_std = self.log_std.expand_as(mean)
        action_log_std = torch.clamp(action_log_std, self.log_std_min, self.log_std_max)
        std = action_log_std.exp()  # Convert log-std to std

        distribution = Normal(mean, std)
        if deterministic:
            action = mean
        else:
            # sample() does not propagate gradients; rsample() is needed only when gradients
            # must flow through the sampled actions (e.g. SAC's actor loss, or VAEs)
            action = distribution.sample()

        # Sum log-probs over the last (action) dimension, since the action components are
        # independent: shape (batch, 6) for Walker2d becomes (batch, 1)
        log_prob = distribution.log_prob(action).sum(-1, keepdim=True)

        return action, log_prob, mean, std
    
    def evaluate_actions(self, state, action):
        """Re-evaluate stored actions under the current policy parameters.

        Returns the summed log-probabilities, summed entropies, action means and stds.
        """
        action_mean = self.fc_mean(self.fc_policy(state))

        # expand_as broadcasts the shared log_std to the batch shape of action_mean
        action_log_std = self.log_std.expand_as(action_mean)
        action_log_std = torch.clamp(action_log_std, self.log_std_min, self.log_std_max)
        action_std = action_log_std.exp()  # Convert log-std to std
        distribution = Normal(action_mean, action_std)
       
        log_prob = distribution.log_prob(action).sum(-1, keepdim=True)  
        entropy = distribution.entropy().sum(-1, keepdim=True)

        return log_prob, entropy, action_mean, action_std


class Value(nn.Module):

    def __init__(self, state_size, fc_units=128):
        """Initialize parameters and build model."""
        super(Value, self).__init__()

        #Value head
        self.fc_value = nn.Sequential(
            nn.Linear(state_size, fc_units),
            nn.ReLU(),
            nn.Linear(fc_units, fc_units),
            nn.ReLU()
        )

        # Value Linear Out
        self.fc_value_out = nn.Linear(fc_units, 1)

    
    def forward(self, state):
        """Forward method implementation.""" 
        value = self.fc_value(state)
        value = self.fc_value_out(value)

        return value

MPS backend is available. Using MPS.


#### Model Utils

In [31]:
def compute_flattened_gradients(output, parameters, retain_graph=False, create_graph=False):
    """Compute the gradients of `output` w.r.t. `parameters` and flatten them into a single vector.

    Args:
        output: Scalar tensor to differentiate.
        parameters: List of parameters to differentiate with respect to.
        retain_graph: Keep the forward graph of `output` after this call, so it can be
            differentiated again. If False, the graph is freed.
        create_graph: Build a graph of the gradient computation itself, so the returned
            gradient can be differentiated again (needed for Hessian-vector products).
            Implies retain_graph=True.

    Returns:
        1-D tensor with all parameter gradients concatenated.
    """
    if create_graph:
        retain_graph = True

    grads = torch.autograd.grad(output, parameters, retain_graph=retain_graph, create_graph=create_graph)
    return torch.cat([g.reshape(-1) for g in grads])

## Rollout Storage



#### **Generalized Advantage Estimation (GAE)**
To reduce the variance of advantage estimates while keeping the bias under control, **Generalized Advantage Estimation (GAE)** is often used. The advantage is computed as an exponentially weighted sum of temporal difference (TD) residuals over multiple steps, controlled by a parameter $\lambda \in [0, 1]$:

$$ A(s_t, a_t) = \sum_{l=0}^\infty (\gamma \lambda)^l \delta_{t+l}, $$

where the TD residual $\delta_t$ is given by:

$$ \delta_t = r_t + \gamma V(s_{t+1}; w) - V(s_t; w). $$

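This introduces a bias-variance tradeoff: lower $\lambda$ relies more on the immediate TD error, which has lower variance but is biased whenever $V$ is inaccurate, while higher $\lambda$ uses longer-horizon rewards, which lowers the bias at the cost of higher variance. The two extremes are: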

\begin{align*}
\text{If }  \lambda=0,  \; \hat{A}_t^{\text{GAE}(\gamma,0)} & := \delta_t = r_t + \gamma V(s_{t+1}) - V(s_t) \\
\text{If }  \lambda=1,  \; \hat{A}_t^{\text{GAE}(\gamma,1)} & :=  \sum_{l=0}^{\infty} \gamma^{l} \delta_{t+l} = \sum_{l=0}^{\infty} \gamma^{l} r_{t+l} - V(s_t)
\end{align*}
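A minimal NumPy sketch of the GAE recursion for a single environment, assuming `dones[t] = 1` marks the end of an episode at step `t`. Unlike the `RolloutStorage.compute_returns_and_advantages` method below, it skips the special handling of truncated episodes. The checks reproduce the two extremes above: $\lambda = 0$ gives the TD residuals, and $\lambda = 1$ gives the discounted (bootstrapped) return minus $V(s_t)$.

```python
import numpy as np


def compute_gae(rewards, values, last_value, dones, gamma=0.99, lam=0.95):
    """GAE for one trajectory segment of length T.

    rewards, values, dones: arrays of length T; dones[t] = 1 if the episode ended at step t.
    last_value: V(s_T), used to bootstrap the final step.
    Returns (advantages, returns) with returns = advantages + values.
    """
    T = len(rewards)
    advantages = np.zeros(T)
    gae = 0.0
    for t in reversed(range(T)):
        next_value = last_value if t == T - 1 else values[t + 1]
        mask = 1.0 - dones[t]                                # stop bootstrapping at episode ends
        delta = rewards[t] + gamma * mask * next_value - values[t]
        gae = delta + gamma * lam * mask * gae               # exponentially weighted sum of deltas
        advantages[t] = gae
    return advantages, advantages + values


r = np.array([1.0, 1.0, 1.0])
v = np.array([0.5, 0.4, 0.3])
adv_td, _ = compute_gae(r, v, 0.2, np.zeros(3), gamma=0.9, lam=0.0)
adv_mc, ret_mc = compute_gae(r, v, 0.2, np.zeros(3), gamma=0.9, lam=1.0)
print(adv_td)   # one-step TD errors: [0.86 0.87 0.88]
print(adv_mc)   # discounted return minus V(s_t): [2.3558 1.662 0.88]
```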

---


In [32]:
import torch 
import numpy as np 
from collections.abc import Generator


class RolloutStorage:
    """
    Rollout buffer used in on-policy algorithms such as A2C, PPO and TRPO.
    It stores ``num_steps`` transitions per environment (``num_steps * n_envs`` in total),
    collected with the current policy; ``num_steps`` is the n of n-step TD.
    This experience is discarded after the policy update.
    For the TRPO/PPO surrogate objective, we also store the current value of each state,
    the log probability of each taken action, and the mean and std of the action
    distribution (used to compute the KL divergence to the old policy).

    It is only involved in policy and value function training, not in action selection.
    """
    def __init__(self, 
                 obs_shape: tuple,
                 action_dim: int = 1,
                 num_steps: int = 1, 
                 n_envs: int = 1,
                 device: torch.device = torch.device("cpu")):
        
        self.num_steps = num_steps
        self.n_envs = n_envs

        self.obs_shape = obs_shape
        self.action_dim = action_dim
        self.device = device

        self.last_obs = None # To store the last observation between rollouts

        # setup all the data   
        self.reset()  
    
    def reset(self) -> None:
        """
        Call reset whenever we start collecting the next n-step rollout of data.
        """
        self.obs = torch.zeros(self.num_steps, self.n_envs, *self.obs_shape, dtype=torch.float32, device=self.device)
        self.rewards = torch.zeros(self.num_steps, self.n_envs, 1,  dtype=torch.float32, device=self.device)
        self.values = torch.zeros(self.num_steps, self.n_envs,  1,  dtype=torch.float32, device=self.device)
        self.log_probs = torch.zeros(self.num_steps, self.n_envs, 1, dtype=torch.float32, device=self.device)
        self.actions = torch.zeros(self.num_steps, self.n_envs, self.action_dim, dtype=torch.float32, device=self.device)
        self.masks = torch.ones(self.num_steps, self.n_envs, 1, dtype=torch.float32, device=self.device)
        self.truncates = torch.zeros(self.num_steps, self.n_envs, 1, dtype=torch.bool, device=self.device)
        self.advantages = torch.zeros(self.num_steps, self.n_envs, 1, dtype=torch.float32, device=self.device)
        self.returns = torch.zeros(self.num_steps, self.n_envs, 1, dtype=torch.float32, device=self.device)
        self.action_means = torch.zeros(self.num_steps, self.n_envs, self.action_dim, dtype=torch.float32, device=self.device)
        self.actions_stds = torch.zeros(self.num_steps, self.n_envs, self.action_dim, dtype=torch.float32, device=self.device)

        self.step = 0
        self.generator_ready = False
    
    def add(
        self,
        obs: torch.Tensor,
        actions: torch.Tensor,
        log_probs: torch.Tensor,
        values: torch.Tensor,
        rewards: torch.Tensor,
        masks: torch.Tensor,
        truncates: torch.Tensor,
        means: torch.Tensor,
        stds: torch.Tensor
    ) -> None:
        """
        :param obs: Observations
        :param actions: Actions
        :param log_probs: Log probability of the actions under the current policy.
        :param values: Estimated value of the current states under the current value function.
        :param rewards: Rewards
        :param masks: 1 if the episode continues after this step, 0 if it terminated or was truncated
        :param truncates: True if the episode was truncated, needed to compute advantages correctly
        :param means: Means of the Gaussian action distribution
        :param stds: Standard deviations of the Gaussian action distribution
        """
        self.obs[self.step].copy_(obs)
        self.actions[self.step].copy_(actions)
        self.log_probs[self.step].copy_(log_probs) 
        self.values[self.step].copy_(values)
        self.rewards[self.step].copy_(rewards)
        self.masks[self.step].copy_(masks)
        self.truncates[self.step].copy_(truncates)
        self.action_means[self.step].copy_(means)
        self.actions_stds[self.step].copy_(stds)

        self.step = (self.step + 1) % self.num_steps  # wrap back to 0 after num_steps additions

    def compute_returns_and_advantages(
            self,
            last_values: torch.Tensor,
            gamma: float = 0.99,
            gae_lambda: float = 1.0,
            normalize: bool = True) -> None:
        r"""
        Post-processing step: compute the advantages A used in the gradient calculation.
            - A_1 is the one-step estimate with bootstrapping: delta_t = r_{t+1} + gamma * v(s_{t+1}) - v(s_t)
            ....
            - A_n is the n-step estimate with bootstrapping: SUM_{l=0}^{n}(gamma^{l} * delta_{t+l})
               = r_{t+1} + gamma*r_{t+2} + gamma^2*r_{t+3} + ... + gamma^(n+1)*v(s_{t+n+1}) - v(s_t)

        We use Generalized Advantage Estimation, which computes the advantage as:

            - A_t^gae(gamma,lambda) = SUM_{l=0}^{\infty}( (gamma*lambda)^{l} * \delta_{t+l})

        :param last_values: state value estimates for the observation after the last step (one per env)
        :param gamma: discount factor
        :param gae_lambda: factor for the bias vs variance trade-off of GAE
        :param normalize: whether to normalize the advantages to zero mean and unit std
        """
        gae = 0
        for step in reversed(range(self.num_steps)):
            if step == self.num_steps - 1:
                next_values = last_values.detach()
            else:
                next_values = self.values[step+1].detach()

            # Handle truncated episodes by incorporating next state value
            # https://github.com/DLR-RM/stable-baselines3/issues/633#issuecomment-961870101
            adjusted_rewards = self.rewards[step].clone()  # Start with original rewards
            adjusted_rewards[self.truncates[step]] += gamma * next_values[self.truncates[step]]
            
            delta = adjusted_rewards + gamma * self.masks[step] * next_values - self.values[step].detach() #td_error
            gae = delta + gamma * gae_lambda * self.masks[step] * gae
            self.advantages[step] = gae.detach()

        #R_t = A_t{GAE} + V(s_t) 
        self.returns = self.advantages + self.values.detach()    

        # Normalize advantages to reduce skewness and improve convergence
        if normalize:
            self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

    def get_mini_batch(self, batch_size):
        
        indices = np.random.permutation(self.num_steps * self.n_envs)
        # Prepare the data
        if not self.generator_ready:
            self.obs_ = self._swap_and_flatten(self.obs)
            self.returns_ = self._swap_and_flatten(self.returns)
            self.generator_ready = True

        # Return everything, don't create minibatches
        if batch_size is None:
            batch_size = self.num_steps * self.n_envs

        start_idx = 0
        while start_idx < self.num_steps * self.n_envs:
            batch_indices = indices[start_idx:start_idx + batch_size]
            yield self.obs_[batch_indices], self.returns_[batch_indices]
            start_idx += batch_size

    def _swap_and_flatten(self, tensor):
        """
        Swap the first two axes and flatten the tensor.
        """
        shape = tensor.shape  # e.g., (num_steps, n_envs, feature_dim)
        return tensor.swapaxes(0, 1).reshape(-1, *shape[2:])  # Flatten into (num_steps * n_envs, feature_dim)



## TRPO Agent

The Kullback-Leibler (KL) divergence between two Gaussian distributions has a closed form that follows from the general definition of KL divergence. Let's derive it step by step.

---

### **What is KL Divergence?**

For two probability distributions $P(x)$ and $Q(x)$, the KL divergence is defined as:


$$ D_{\text{KL}}(P \,||\, Q) = \int P(x) \log \frac{P(x)}{Q(x)} \, dx $$


It measures how different the distribution $Q$ is from $P$. Intuitively:
- $P$ is the "true" distribution.
- $Q$ is the "approximating" distribution.
- $D_{\text{KL}}$ quantifies the inefficiency of representing $P$ with $Q$.

---

### **KL Divergence for Gaussian Distributions**

Let $P = \mathcal{N}(\mu_0, \sigma_0^2)$ and $Q = \mathcal{N}(\mu_1, \sigma_1^2)$. The probability density functions (PDFs) of these Gaussian distributions are:


$$P(x) = \frac{1}{\sqrt{2 \pi \sigma_0^2}} \exp\left(-\frac{(x - \mu_0)^2}{2 \sigma_0^2}\right) $$


$$Q(x) = \frac{1}{\sqrt{2 \pi \sigma_1^2}} \exp\left(-\frac{(x - \mu_1)^2}{2 \sigma_1^2}\right)$$

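Taking the log-ratio of the two densities:

$$\log \frac{P(x)}{Q(x)} = \log \frac{\sigma_1}{\sigma_0} - \frac{(x - \mu_0)^2}{2 \sigma_0^2} + \frac{(x - \mu_1)^2}{2 \sigma_1^2}$$

The KL divergence is the expectation of this log-ratio under $P$. Using $\mathbb{E}_P[(x - \mu_0)^2] = \sigma_0^2$ and $\mathbb{E}_P[(x - \mu_1)^2] = \sigma_0^2 + (\mu_0 - \mu_1)^2$:

$$D_{\text{KL}}(P \,||\, Q) = \log \frac{\sigma_1}{\sigma_0} + \frac{\sigma_0^2 + (\mu_0 - \mu_1)^2}{2 \sigma_1^2} - \frac{1}{2}$$

For a diagonal Gaussian policy, the action dimensions are independent, so the total KL divergence is the sum of this expression over dimensions.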

---

### **Connection to the Code**

The code directly implements this formula:
```python
kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5
```

This computation is done element-wise across all action dimensions, and the final KL divergence is summed over dimensions and averaged across states.
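As a sanity check, the closed-form expression can be compared with PyTorch's built-in `torch.distributions.kl_divergence` for two `Normal` distributions. The batch of 5 states with 6 action dimensions and the random means and stds are arbitrary test values.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mean0, mean1 = torch.randn(5, 6), torch.randn(5, 6)
std0, std1 = torch.rand(5, 6) + 0.1, torch.rand(5, 6) + 0.1   # keep stds strictly positive
log_std0, log_std1 = std0.log(), std1.log()

# Closed-form KL(N(mean0, std0^2) || N(mean1, std1^2)), element-wise per action dimension
kl = log_std1 - log_std0 + (std0.pow(2) + (mean0 - mean1).pow(2)) / (2.0 * std1.pow(2)) - 0.5

# Reference: PyTorch's registered Normal-Normal KL divergence
kl_ref = kl_divergence(Normal(mean0, std0), Normal(mean1, std1))
print(torch.allclose(kl, kl_ref, rtol=1e-4, atol=1e-5))

# Sum over the 6 action dimensions, then average over the 5 states
mean_kl = kl.sum(-1).mean()
```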

---

The **Fisher Vector Product (FVP)** is a key step in **TRPO**: it captures the curvature of the KL-divergence constraint. It computes the product of the Fisher Information Matrix (FIM) with a vector $v$, which is essential for solving TRPO's constrained optimization problem.

Here’s a step-by-step breakdown of the function and the corresponding mathematical formulas.

---

### **What is the Fisher Information Matrix (FIM)?**

The FIM is defined as:


$$F = \mathbb{E} \left[ \nabla_\theta \log \pi(a \mid s; \theta) \nabla_\theta \log \pi(a \mid s; \theta)^T \right]$$

Where:
- $\theta$: Policy parameters.
- $\pi(a \mid s; \theta)$: The policy's probability distribution for action $a$ given state $s$.

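The FIM is a positive semi-definite matrix. It equals the Hessian of the KL divergence $\nabla_\theta^2 D_{\text{KL}}(\pi_{\theta_\text{old}} \,||\, \pi_\theta)$ evaluated at $\theta = \theta_\text{old}$, so it describes the local second-order curvature of the KL divergence. In TRPO, it is used to constrain the policy update via the trust region.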
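The link between the FIM and the curvature of the KL divergence can be checked numerically on a 1-D Gaussian policy parameterized by $\theta = (\mu, \log\sigma)$, whose Fisher information is known to be $\mathrm{diag}(1/\sigma^2, 2)$. This sketch uses `torch.autograd.functional.hessian`; the old-policy parameters are arbitrary illustrative values.

```python
import math

import torch
from torch.autograd.functional import hessian

MU0, LOG_SIGMA0 = 0.3, -0.5          # "old" policy parameters (arbitrary illustrative values)
SIGMA0 = math.exp(LOG_SIGMA0)


def kl_old_new(theta: torch.Tensor) -> torch.Tensor:
    """KL(N(MU0, SIGMA0^2) || N(mu, sigma^2)) as a function of theta = (mu, log_sigma)."""
    mu, log_sigma = theta[0], theta[1]
    sigma = torch.exp(log_sigma)
    return log_sigma - LOG_SIGMA0 + (SIGMA0**2 + (MU0 - mu) ** 2) / (2 * sigma**2) - 0.5


theta_old = torch.tensor([MU0, LOG_SIGMA0], dtype=torch.float64)
H = hessian(kl_old_new, theta_old)   # 2x2 Hessian of the KL at theta = theta_old

# Fisher information of a Gaussian in (mu, log_sigma) coordinates: diag(1/sigma^2, 2)
F = torch.tensor([[1.0 / SIGMA0**2, 0.0],
                  [0.0, 2.0]], dtype=torch.float64)
print(H)
print(torch.allclose(H, F))
```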

---

### **Why Do We Need the Fisher Vector Product?**

Working with the FIM directly is computationally expensive because:
1. It requires storing a $|\theta| \times |\theta|$ matrix, where $|\theta|$ is the number of parameters.
2. Inverting it is even more expensive ($O(|\theta|^3)$ for a dense matrix).

Instead, we use the **Fisher Vector Product (FVP)**:

$$Fv = \left[ \nabla_\theta^2 D_{\text{KL}}(\pi_{\text{old}} \,||\, \pi_{\text{new}}) \right] v$$

This computes the effect of multiplying the FIM $F$ with a vector $v$ without explicitly constructing $F$. It leverages automatic differentiation to efficiently compute $Fv$.
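Because only products $Fv$ are needed, the linear system $F x = g$ (the natural-gradient direction) can be solved with the conjugate gradient (CG) method without ever forming or inverting $F$. The agent below calls a `_conjugate_gradients` method for this step, whose body is not shown in this section; the following is a generic textbook CG sketch, not necessarily identical to that method. An explicit symmetric positive definite matrix stands in for $F$ purely so the result can be checked.

```python
import numpy as np


def conjugate_gradient(fvp, b, max_steps=10, tol=1e-10):
    """Approximately solve F x = b using only products fvp(v) = F v.

    F is assumed symmetric positive definite (for TRPO, the damped Fisher matrix).
    `tol` is a threshold on the squared residual norm.
    """
    x = np.zeros_like(b)
    r = b.copy()            # residual b - F x, with x = 0 initially
    p = r.copy()            # current search direction
    rr = r @ r
    for _ in range(max_steps):
        Fp = fvp(p)
        alpha = rr / (p @ Fp)       # exact line search along p
        x += alpha * p
        r -= alpha * Fp
        rr_new = r @ r
        if rr_new < tol:
            break
        p = r + (rr_new / rr) * p   # next F-conjugate direction
        rr = rr_new
    return x


# Illustration: an explicit SPD matrix stands in for F; CG only ever calls F @ v.
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 5))
F = A @ A.T + np.eye(5)
g = rng.normal(size=5)

x = conjugate_gradient(lambda v: F @ v, g, max_steps=10, tol=1e-20)
print(np.allclose(F @ x, g))
```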



### **Mathematical Summary**

The agent's `_hessian_vector_product` method computes the damped Fisher Vector Product:


$$(F + \epsilon I)\, v = \left( \nabla_\theta^2 D_{\text{KL}}(\pi_{\text{old}} \,||\, \pi_{\text{new}}) \right) v + \epsilon v$$

where $\epsilon$ is the damping coefficient (`cg_damping`).

Steps:
1. Compute the gradient $g = \nabla_\theta D_{\text{KL}}$, keeping its computation graph.
2. Compute the dot product $g^T v$.
3. Take the gradient of $g^T v$ with respect to $\theta$, which yields the Hessian-vector product $H v$ without forming $H$.
4. Add the damping term $\epsilon v$ for numerical stability.
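The four steps above can be sketched with `torch.autograd.grad` on a toy quadratic function, whose Hessian is known exactly. This sketch assumes a single flat parameter tensor; the agent instead works with a list of parameters and a flattened KL gradient, and the quadratic stands in for the KL divergence.

```python
import torch


def hessian_vector_product(f, params, v, damping=0.0):
    """Compute (H + damping * I) v for the scalar function f(params) via double backprop."""
    out = f(params)
    (grad,) = torch.autograd.grad(out, params, create_graph=True)   # step 1: g = grad f, graph kept
    grad_dot_v = grad @ v                                            # step 2: g^T v
    (hv,) = torch.autograd.grad(grad_dot_v, params)                  # step 3: grad(g^T v) = H v
    return hv + damping * v                                          # step 4: damping term


# Toy quadratic f(x) = 0.5 x^T A x, whose Hessian is exactly A (A is symmetric)
A = torch.tensor([[3.0, 1.0], [1.0, 2.0]], dtype=torch.float64)
x = torch.tensor([0.5, -1.0], dtype=torch.float64, requires_grad=True)
v = torch.tensor([1.0, 2.0], dtype=torch.float64)

hv = hessian_vector_product(lambda p: 0.5 * p @ A @ p, x, v, damping=0.1)
print(hv)   # approximately A @ v + 0.1 * v = [5.1, 5.2]
```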

---

In [35]:
from collections import deque
from functools import partial
import os 
import time

from helpers.utils import Logger


class TRPOAgent:
    """
    Trust Region Policy Optimization (TRPO)
    :param env: The environment to learn from (a Gymnasium vector environment)
    :param n_steps: The number of steps to run for each environment per update
        (i.e. batch size is n_steps * n_envs where n_envs is the number of environment copies running in parallel)
    :param gamma: Discount factor
    :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.
        Equivalent to classic advantage when set to 1.
    :param ent_coef: Entropy coefficient added to the surrogate objective
    :param lr: Learning rate of the Adam optimizer used for the value network
    :param target_kl: Target Kullback-Leibler divergence between updates.
        Should be small for stability. Values like 0.01, 0.05.
    :param cg_maxsteps: Maximum number of steps in the Conjugate Gradient algorithm
        used to compute the search direction
    :param cg_damping: Damping in the Hessian-vector product computation
    :param n_critic_updates: Number of passes over the rollout when fitting the value network
    :param line_search_step_decay: Factor by which the step size shrinks in the backtracking line search
    :param fc_units: Number of hidden units in the policy and value networks
    :param activation: Activation used in the policy network's hidden layers
    :param normalize_advantage: Whether to normalize advantages to zero mean and unit std
    :param device: Device (cpu, cuda, ...) on which the code should be run.
    """
    def __init__(self,
                 env: gym.vector.VectorEnv,
                 n_steps: int = 5,
                 gamma: float = 0.99,
                 gae_lambda: float = 1.0,
                 ent_coef: float = 0.01,
                 min_ent_coef: float = 0.001,
                 lr: float = 1e-3,
                 target_kl: float = 0.01,
                 cg_maxsteps: int = 15,
                 cg_damping: float = 0.001,
                 n_critic_updates: int = 20,
                 min_best_rewards: float = 2000.0,
                 line_search_step_decay: float = 0.5,
                 fc_units: int = 128,
                 activation: nn.Module = nn.ReLU(),
                 normalize_advantage: bool = True, 
                 device: torch.device = torch.device("cpu")):
        
        self.env = env
        self.n_steps = n_steps
        self.n_envs = env.num_envs

        self.gamma = gamma
        self.gae_lambda = gae_lambda

        self.entropy_coef = ent_coef
        self.min_entropy_coef = min_ent_coef

        self.target_kl = target_kl
        self.cg_damping = cg_damping
        self.cg_maxsteps = cg_maxsteps
        self.n_critic_updates = n_critic_updates
        self.normalize_advantage = normalize_advantage
        self.line_search_step_decay = line_search_step_decay

        self.device = device

        self.action_size = self.env.action_space.shape[-1]
        self.state_size = self.env.observation_space.shape[-1]

        # networks
        self.policy_net = Policy(env, fc_units, activation=activation).to(device) 
        self.value_net = Value(self.state_size, fc_units).to(device)
        self.value_optimizer = torch.optim.Adam(self.value_net.parameters(), lr=lr)

        #logger
        run_name = (
            f"lr{lr}"
            f"_nenvs{env.num_envs}"
            f"_nsteps{n_steps}"
            f"_gamma{gamma}"
            f"_gae{gae_lambda}"
            f"_ent{ent_coef}"
            f"_kl{target_kl}"
            f"_cg{cg_maxsteps}"
            f"_damping{cg_damping}"
            f"_activation{activation.__class__.__name__}"
            f"_{int(time.time())}"
        )
        self.logger = Logger(run_name=run_name, env=env.envs[0].spec.id, algo="TRPO")
        self.logger.add_run_command()

        self.episode_num = 1
        self.best_reward = float('-inf') 
        self.min_best_rewards = min_best_rewards

        # rollout storage
        self.rollout_storage =  RolloutStorage(
            (self.state_size, ), 
            self.action_size, 
            num_steps=n_steps, 
            n_envs=self.n_envs, 
            device=self.device)

    def collect_rollouts(self):
        """
        Collect experiences using the current policy and fill the ``RolloutStorage``.
        The term rollout here refers to the model-free notion and should not
        be used with the concept of rollout used in model-based RL or planning.
        """

        last_obs = self.rollout_storage.last_obs 
        self.rollout_storage.reset()

        for _ in range(self.n_steps):
            with torch.no_grad():
                actions, log_probs, means, stds =  self.policy_net(last_obs)
                values = self.value_net(last_obs)

            # Clipping actions to the valid range is not needed here,
            # because the env is wrapped with the ClipAction wrapper
            # selected_actions = actions.cpu().numpy().clip(
            #        self.env.action_space.low, 
            #        self.env.action_space.high)
            
            # Take actions in env and look the results
            obs, rewards, terminates, truncates, infos = self.env.step(actions.cpu().numpy())

            # Updating global number of steps agent done, while learning
            self.num_timesteps += self.env.num_envs

            dones = (terminates | truncates)
            masks = torch.tensor(1 - dones, dtype=torch.float32).unsqueeze(-1)
            rewards = torch.tensor(rewards, dtype=torch.float32).unsqueeze(-1)
            truncates = torch.tensor(truncates, dtype=torch.bool).unsqueeze(-1)

            if "final_info" in infos:
                source_info = infos["final_info"]
                self._extract_episode_data(source_info.get('episode'), source_info.get('_episode'))

            self.rollout_storage.add(last_obs, 
                            actions, 
                            log_probs,
                            values,
                            rewards,
                            masks,
                            truncates,
                            means,
                            stds)
            
            last_obs = torch.tensor(obs, dtype=torch.float32, device=self.device)

        # compute values for the last timestamp
        with torch.no_grad():
            last_values = self.value_net(last_obs)

        self.rollout_storage.compute_returns_and_advantages(
            last_values, 
            self.gamma,
            self.gae_lambda,
            normalize = self.normalize_advantage)
        self.rollout_storage.last_obs = last_obs

    def optimize_value_function_adam(self):
        """
        Fit the value network to the GAE returns with the Adam optimizer,
        running n_critic_updates passes over the rollout in minibatches of 128.
        """
        for _ in range(self.n_critic_updates):
            for states, targets in self.rollout_storage.get_mini_batch(128):
                values_pred = self.value_net(states)
                value_loss = F.mse_loss(values_pred, targets)
                self.value_optimizer.zero_grad()
                value_loss.backward()
                # nn.utils.clip_grad_norm_(self.value_net.parameters(), 0.5)
                self.value_optimizer.step()

    def optimize_policy(self):

        # 1. Compute initial policy metrics 
        policy_objective, kl_div = self._compute_policy_metrics()

        # 2. Prepare for optimization
        policy_params = list(self.policy_net.parameters())

        # Compute policy objective gradients
        objective_grad = compute_flattened_gradients(policy_objective,
                                                     policy_params,
                                                     retain_graph=True)
        
        # Compute KL divergence gradient (needed for Hessian-vector product)
        kl_grad = compute_flattened_gradients(kl_div, 
                                              policy_params, 
                                              create_graph=True)

        # Hessian-vector dot product function used in the conjugate gradient step
        hessian_vector_product_fn = partial(self._hessian_vector_product,
                                            policy_params,
                                            kl_grad)       
        
        # 3. Compute search direction using conjugate gradient
        search_direction = self._conjugate_gradients(
            hessian_vector_product_fn, 
            objective_grad, 
            self.cg_maxsteps)

        # 4. Determine maximal step size based on KL constraint
        quadratic_term = search_direction @ hessian_vector_product_fn(search_direction)
        if quadratic_term <= 0:
            # s^T H s must be positive for the step-size formula below to be valid
            raise ValueError(f"Invalid quadratic term (expected s^T H s > 0): {quadratic_term}")
        
        max_step_size = torch.sqrt(2 * self.target_kl/quadratic_term)  # type: ignore[assignment, arg-type]
        full_step = max_step_size * search_direction

        # 5. Line search to find acceptable step size        
        step_size = 1.0
        line_search_step_decay = self.line_search_step_decay
        is_line_search_success = False
        original_parameters = [param.detach().clone() for param in policy_params]
        
        with torch.no_grad():
            # Try increasingly smaller step sizes until constraints are satisfied
            for i in range(10):
                # Applying the scaled step direction
                n = 0
                for param, original_param in zip(policy_params, original_parameters):
                    n_params = param.numel()
                    param.data = (
                        original_param 
                        + step_size
                        * full_step[n: n+n_params].view(param.shape)
                    )
                    n += n_params

                # Evaluate new policy
                new_policy_objective, new_kl = self._compute_policy_metrics()
                improvement = new_policy_objective - policy_objective

                # Check if step satisfies constraints:
                # 1. KL divergence within bounds
                # 2. Policy objective improved
                if (new_kl <= self.target_kl) and (improvement > 0):
                    is_line_search_success = True
                    break

                # Reducing step size if line-search wasn't successful
                step_size *= line_search_step_decay

            if not is_line_search_success:
                # If the line-search wasn't successful we revert to the original parameters
                for param, original_param in zip(policy_params, original_parameters):
                    param.data = original_param.data.clone()
                
                return policy_objective.item()
        
        return new_policy_objective.item()
    
    def _compute_policy_metrics(self):
        """
        Compute metrics used for policy optimization:
        - Surrogate objective
        - KL divergence
        """
        rollout = self.rollout_storage
        states, actions = rollout.obs, rollout.actions
        advantages, fixed_log_prob = rollout.advantages, rollout.log_probs
        mean_0, std_0 = rollout.action_means, rollout.actions_stds

        # Call to reevaluate actions using updated policy, but same trajectory
        log_prob, entropy, mean_1, std_1 = self.policy_net.evaluate_actions(states, actions)

        # Probability ratio pi_new / pi_old for the importance-sampled surrogate objective
        ratio = torch.exp(log_prob - fixed_log_prob)
        surrogate_objective = (advantages * ratio + self.entropy_coef * entropy).mean()

        # KL(pi_old || pi_new) between diagonal Gaussians, per action dimension:
        # KL = log(std_1/std_0) + (var_0 + (mean_0 - mean_1)^2) / (2*std_1^2) - 0.5
        numerator = std_0.pow(2) + (mean_0 - mean_1).pow(2)
        kl_per_dim = torch.log(std_1 / std_0) + numerator / (2.0 * std_1.pow(2)) - 0.5
        # Sum over action dimensions, then average over states
        kl_divergence = kl_per_dim.sum(-1, keepdim=True).mean()

        return surrogate_objective, kl_divergence

    def _hessian_vector_product(self, 
                                params: list[nn.Parameter], 
                                grad_kl: torch.Tensor, 
                                vector: torch.Tensor):
        """
        Compute the damped Hessian-vector product of the KL divergence with respect to the policy parameters.

        The flattened KL gradient (grad_kl) is dotted with the input vector, and the resulting scalar is
        differentiated with respect to the parameters. This yields the Hessian-vector product without
        ever forming the Hessian. A damping term (self.cg_damping * vector) is added for numerical
        stability in the conjugate gradient solver.

        Mathematically:
            Let g(θ) = ∇_θ D_KL(θ).
            The Hessian-vector product is H(θ)v = ∇_θ (g(θ)ᵀ v),
            and the returned value is H(θ)v + damping * v.

        :param params: list[torch.nn.Parameter]
            The model parameters with respect to which the KL divergence is differentiated.
        :param grad_kl: torch.Tensor
            The flattened gradient of the KL divergence between the old and new policy. It must have been
            computed with create_graph=True so that it can be differentiated a second time.
        :param vector: torch.Tensor
            The vector to be multiplied with the Hessian.

        :return: torch.Tensor
            The Hessian-vector product with an added damping term.
        """
        # retain_graph=True keeps the graph of grad_kl alive for the repeated calls made by conjugate gradient
        hessian_vector_product = compute_flattened_gradients(grad_kl @ vector, params, retain_graph=True)
        return hessian_vector_product + self.cg_damping * vector

    # Conjugate Gradient Solver
    def _conjugate_gradients(self, matrix_vector_dot_fn, b, max_iter=10, residual_tol=1e-10):
        """
        Approximately solve Ax = b with the conjugate gradient method.
        A must be symmetric positive definite; it is accessed only through
        matrix_vector_dot_fn(v), which returns A @ v.
        """
        # x is not initialized at exactly 0 because of instability issues when the gradient becomes small.
        # Instead, a small uniform random perturbation in [0, 1e-4) is used as the initial guess.
        x = 1e-4 * torch.rand_like(b)
        # r = b - Ax (initial residual)
        residual = b - matrix_vector_dot_fn(x)
        # Initial squared norm of the residual (r^T r)
        residual_squared_norm = torch.matmul(residual, residual)

        if residual_squared_norm < residual_tol:
            # The initial guess already solves the system; continuing would make
            # p^T A p nearly zero and divide by (almost) zero when computing alpha
            return x

        p = residual.clone() # Search direction (start with p = r).

        for i in range(max_iter):
            # Compute A@p - matrix-vector product with current search direction
            A_dot_p = matrix_vector_dot_fn(p)
            
            # Compute optimal step size: α = (r^T r) / (p^T A p)
            # This minimizes the quadratic form along direction p
            alpha = residual_squared_norm / p.dot(A_dot_p)
            # Update solution: x = x + α*p
            x += alpha * p

            if i == max_iter - 1:
                return x
            
            # Update residual: r = r - α*A*p
            residual -= alpha * A_dot_p
            new_residual_squared_norm = torch.matmul(residual, residual)
            # Check convergence criterion: ||r||^2 < tolerance
            if new_residual_squared_norm < residual_tol:
                return x

            # Compute β = (r_new^T r_new)/(r_old^T r_old) [Fletcher-Reeves formula]
            beta = new_residual_squared_norm / residual_squared_norm
            residual_squared_norm = new_residual_squared_norm
            # Update search direction: p = r + β*p
            p = residual + beta * p

        # Note: this return statement is only used when max_iter=0
        return x

    def learn(self):
        """
        Run one update: fit the value function with Adam, then take a TRPO step on the policy.
        Returns the policy objective reported by optimize_policy.
        """
        self.optimize_value_function_adam()
        loss = self.optimize_policy()
        
        return loss

    def train(self, total_timesteps: int = 1e5, log_interval: int = 100):
        """Train the agent."""

        obs, _ = self.env.reset()
        self.rollout_storage.last_obs = torch.tensor(obs, dtype=torch.float32, device=self.device)

        self.num_timesteps = 0
        self.episode_rewards = deque(maxlen=20)
        self.episode_lengths = deque(maxlen=20)
        # Count policy updates (starting at 1) for logging and checkpointing
        update = 1
        
        while self.num_timesteps < total_timesteps:
            self.collect_rollouts()
            policy_loss = self.learn()

            # Display training info
            if update % log_interval == 0 and len(self.episode_rewards) > 1:
                mean_rewards = np.mean(self.episode_rewards)
                min_rewards = np.min(self.episode_rewards)

                print(
                    "Updates {}, num timesteps {}/{} Mean Length {:.1f}\n"
                    "Policy {:.4f}\n"
                    "Last {} training episodes: mean/median reward {:.1f}/{:.1f}, min/max reward {:.1f}/{:.1f}"
                    .format(
                        update, self.num_timesteps, total_timesteps, np.mean(self.episode_lengths),
                        policy_loss,
                        len(self.episode_rewards), mean_rewards,
                        np.median(self.episode_rewards), min_rewards,
                        np.max(self.episode_rewards)
                    )
                )
                
                if update % (10 * log_interval) == 0:
                    print('Saving checkpoint...')
                    self.save_model('torch.model')

            update += 1

        print('Saving last...')
        self.save_model('torch.model')

    def _extract_episode_data(self, episode_data, episode_flags):
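        """
        Extract the statistics of every environment that just finished an episode and log them.

        For each environment whose '_episode' flag is True, the episode return and length are written
        to the logger and appended to the self.episode_rewards and self.episode_lengths deques.
        The method also saves 'torch-best.model' when the running mean reward reaches
        min_best_rewards and beats the previous best.

        Example of the statistics for one environment:
            {'r': -21.0, '_r': True, 'l': 944, '_l': True, 't': 5.089006, '_t': True}
            r - cumulative reward
            l - episode length
            t - elapsed time since beginning of episode

        :param episode_data: dict, episode statistics from the vectorized environments
        :param episode_flags: np.ndarray, boolean array indicating which environments finished an episode
        """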
        done_envs = np.where(episode_flags)[0]  # Get indices of done environments
    
        for env_index in done_envs:
            env_specific_data = {}
            for key, value in episode_data.items():
                if isinstance(value, np.ndarray):  # Ensure it's an array
                    env_specific_data[key] = value[env_index]
            reward = env_specific_data['r']
            length = env_specific_data['l']
            self.logger.add_scalar('reward', reward, self.episode_num)
            self.logger.add_scalar('length', length, self.episode_num)        
            self.episode_rewards.append(reward)
            self.episode_lengths.append(length)
            self.episode_num += 1
        
        # Save the best model once the running mean reward reaches min_best_rewards and beats the previous best
        mean_rewards = np.mean(self.episode_rewards)
        if mean_rewards >= self.min_best_rewards and mean_rewards > self.best_reward:
            self.best_reward = mean_rewards
            self.save_model('torch-best.model')
            print(f"Saved best model with mean reward {mean_rewards:.2f}")
    
    def save_model(self, filename: str):
        """Save the current model state dictionaries and normalization info to a file in logger.dir_name."""
        model_path = os.path.join(self.logger.dir_name, filename)
        # Gather normalization statistics from each environment
        all_obs_rms = [env.obs_rms for env in self.env.envs]
        
        checkpoint = {
            'actor': self.policy_net.state_dict(),
            'critic': self.value_net.state_dict(),
            'obs_rms': all_obs_rms,
        }
        torch.save(checkpoint, model_path)
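A conjugate gradient routine such as `_conjugate_gradients` is easy to sanity-check in isolation: build a small symmetric positive-definite matrix explicitly, solve with CG, and compare with a direct solve. The sketch below is a standalone re-implementation for illustration (the name `conjugate_gradient` and the test matrix are assumptions, not part of the agent), and it starts from x = 0 rather than the small random guess used in the agent.

```python
import torch


def conjugate_gradient(matvec, b, max_iter=10, residual_tol=1e-10):
    """Approximately solve A x = b for symmetric positive-definite A, given only matvec(v) = A @ v."""
    x = torch.zeros_like(b)
    r = b.clone()       # residual b - A x, with x = 0
    p = r.clone()       # first search direction
    rs_old = r.dot(r)
    for _ in range(max_iter):
        if rs_old < residual_tol:  # converged: stop before dividing by a tiny number
            break
        Ap = matvec(p)
        alpha = rs_old / p.dot(Ap)  # exact minimizer of the quadratic along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r.dot(r)
        p = r + (rs_new / rs_old) * p  # Fletcher-Reeves update keeps directions A-conjugate
        rs_old = rs_new
    return x


torch.manual_seed(0)
n = 5
M = torch.randn(n, n, dtype=torch.float64)
A = M @ M.T + n * torch.eye(n, dtype=torch.float64)  # symmetric positive definite
b = torch.randn(n, dtype=torch.float64)

x_cg = conjugate_gradient(lambda v: A @ v, b, max_iter=2 * n, residual_tol=1e-20)
x_direct = torch.linalg.solve(A, b)
print(torch.allclose(x_cg, x_direct, atol=1e-8))
```

In exact arithmetic CG needs at most n iterations for an n×n system; in TRPO it is stopped much earlier (`cg_maxsteps`), because only an approximate search direction is needed.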
    

Set up the vectorized environments. The `NormalizeObservation` wrapper maintains running observation statistics, which `save_model` stores in every checkpoint so they can be reused at evaluation time.

```python
from helpers.envs import make_sync_vec, AutoresetMode


num_envs = 2
env_id = 'Walker2d-v5'

envs = make_sync_vec(env_id,
                     num_envs=num_envs,
                     wrappers=(gym.wrappers.RecordEpisodeStatistics,
                               gym.wrappers.ClipAction,
                               gym.wrappers.NormalizeObservation,),
                     autoreset_mode=AutoresetMode.SAME_STEP)

agent = TRPOAgent(envs,
                  n_steps=1024,
                  gae_lambda=0.97,
                  gamma=0.99,
                  ent_coef=0.00,
                  fc_units=128,
                  activation=nn.ReLU(),
                  # TRPO specific
                  cg_maxsteps=15,
                  cg_damping=0.1,
                  target_kl=0.01,
                  line_search_step_decay=0.8,

                  # Value function parameters
                  n_critic_updates=15,
                  lr=1e-3,
                  # Others
                  device=DEVICE,
                  normalize_advantage=True,
                  min_best_rewards=2000.0)

agent.train(total_timesteps=1e6, log_interval=5)
```

```
Updates 5, num timesteps 10240/1000000.0 Mean Length 27.0
Policy 0.0648
Last 20 training episodes: mean/median reward 6.7/3.0, min/max reward -10.1/41.2
Updates 10, num timesteps 20480/1000000.0 Mean Length 53.1
Policy 0.0713
Last 20 training episodes: mean/median reward 34.4/15.1, min/max reward -6.7/202.4
Updates 15, num timesteps 30720/1000000.0 Mean Length 95.4
Policy 0.0662
Last 20 training episodes: mean/median reward 102.8/40.6, min/max reward -5.2/361.1
Updates 20, num timesteps 40960/1000000.0 Mean Length 136.8
Policy 0.0730
Last 20 training episodes: mean/median reward 156.2/83.1, min/max reward 24.7/453.6
Updates 25, num timesteps 51200/1000000.0 Mean Length 164.4
Policy 0.0706
Last 20 training episodes: mean/median reward 220.2/230.2, min/max reward 67.5/367.6
Updates 30, num timesteps 61440/1000000.0 Mean Length 217.8
Policy 0.0745
Last 20 training episodes: mean/median reward 278.2/276.1, min/max reward 46.9/581.9
Updates 35, num timesteps 71680/1000000.0 Mean Length 160.
```

## Evaluate, make video 

```python
import os
import imageio
import numpy as np
import torch
from IPython.display import Video, display

def record_video(env, policy, out_directory, out_name, fps=60, min_reward=4000):
    """
    Record episodes of the agent, save them as a video and display it in the notebook.
    Episodes whose total reward is below min_reward are discarded; recording stops
    once three episodes have reached min_reward.

    :param env: Environment to record (created with render_mode='rgb_array').
    :param policy: Policy used to determine actions.
    :param out_directory: Directory in which to save the video.
    :param out_name: File name of the video.
    :param fps: Frames per second.
    :param min_reward: Minimum total reward for an episode to be kept in the video.
    """
    os.makedirs(out_directory, exist_ok=True)
    images = []
    obs, info = env.reset()

    times = 0
    total_reward = 0
    length = 0
    # Note: this loops until three episodes reach min_reward, so it never ends for a weak policy
    while times != 3:
        # Convert the observation into a float tensor for the policy network
        state = torch.tensor(obs, dtype=torch.float32)

        # Deterministic action, clipped to the action-space bounds
        action, _, _, _ = policy(state, deterministic=True)
        selected_actions = action.detach().numpy().clip(
            env.action_space.low,
            env.action_space.high)

        obs, reward, terminated, truncated, _ = env.step(selected_actions)
        total_reward += reward
        images.append(env.render())
        length += 1
        if terminated or truncated:
            obs, info = env.reset()
            print(total_reward)
            if total_reward < min_reward:
                # Drop the frames of an episode that did not reach min_reward
                images = images[:-length]
            else:
                times += 1
            total_reward = 0
            length = 0

    # Save the video
    video_path = os.path.join(out_directory, out_name)
    imageio.mimsave(video_path, [np.array(img) for img in images], fps=fps)

    # Display the video in the Jupyter notebook
    display(Video(video_path, embed=True, width=640, height=480))

def eval_policy(env, policy, num_episodes=10):
    """Run num_episodes deterministic episodes and print per-episode and summary rewards."""
    # Store rewards for each episode
    episode_rewards = []

    # Evaluation loop
    for episode in range(num_episodes):
        obs, _ = env.reset()
        done = False
        total_reward = 0  # Keep track of total reward in the episode
        while not done:
            state = torch.tensor(obs, dtype=torch.float32)
            # Select the deterministic action using the trained model
            action, _, _, _ = policy(state, deterministic=True)
            selected_actions = action.detach().numpy()

            # Step the environment
            obs, reward, terminated, truncated, info = env.step(selected_actions)
            total_reward += reward
            done = terminated or truncated

        episode_rewards.append(total_reward)  # Store the total reward for the episode
        print(f"Episode {episode + 1}: Total Reward = {total_reward}")

    # Calculate mean and standard deviation
    mean_reward = np.mean(episode_rewards)
    std_reward = np.std(episode_rewards)

    print(f"Mean Reward: {mean_reward}, Standard Deviation: {std_reward}")
```

```python
from gymnasium.wrappers import NormalizeObservation

env = gym.make('Walker2d-v5', render_mode='rgb_array')
env = NormalizeObservation(env)
# Freeze the running statistics: evaluation must not update them
env.update_running_mean = False

# Load the checkpoint; weights_only=False is needed because it also pickles the normalization objects
checkpoint = torch.load('./runs/Walker2d-v5/TRPO/lr0.001_nenvs2_nsteps1024_gamma0.99_gae0.97_ent0.0_kl0.01_cg15_damping0.1_activationReLU_1740834184/torch-best.model', weights_only=False)
all_obs_rms = checkpoint['obs_rms']

# Average mean and variance across all training environments.
# Averaging the variances is an approximation: it ignores the spread between the per-environment means.
mean = sum(obs_rms.mean for obs_rms in all_obs_rms) / len(all_obs_rms)
var = sum(obs_rms.var for obs_rms in all_obs_rms) / len(all_obs_rms)
# Assign the averaged statistics to the evaluation environment
env.obs_rms.mean = mean
env.obs_rms.var = var

eval_model = Policy(env)
eval_model.load_state_dict(checkpoint['actor'])
# Put the model in evaluation mode
eval_model.eval()

eval_policy(env, eval_model, num_episodes=10)
```

```
Episode 1: Total Reward = 4638.5414179870995
Episode 2: Total Reward = 4502.5305828268565
Episode 3: Total Reward = 4586.404502305798
Episode 4: Total Reward = 4592.828366768639
Episode 5: Total Reward = 4502.145291478862
Episode 6: Total Reward = 4640.295511547828
Episode 7: Total Reward = 4526.345484514717
Episode 8: Total Reward = 4597.64644321398
Episode 9: Total Reward = 4602.105123397844
Episode 10: Total Reward = 4580.638349752285
Mean Reward: 4576.948107379391, Standard Deviation: 47.87191732923976
```
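The evaluation cell averages the per-environment `mean` and `var` arrays from the checkpoint. Averaging the means is exact when every environment has seen the same number of observations, but averaging the variances drops the spread between the per-environment means. A minimal sketch of an exact merge via the law of total variance follows; `merge_running_stats` is a hypothetical helper, and using it on the checkpoint assumes each saved `obs_rms` exposes a `count` alongside `mean` and `var` (gymnasium's `RunningMeanStd` keeps all three). Because both training environments run the same task, the two approaches may well give similar statistics here.

```python
import numpy as np


def merge_running_stats(means, variances, counts):
    """Merge per-environment (mean, var, count) statistics into one exact pooled estimate.

    Law of total variance: pooled var = weighted mean of the group variances
    plus the weighted variance of the group means.
    """
    means = np.asarray(means, dtype=np.float64)          # shape (k, obs_dim)
    variances = np.asarray(variances, dtype=np.float64)  # shape (k, obs_dim)
    weights = np.asarray(counts, dtype=np.float64)[:, None]
    total = weights.sum()
    mean = (weights * means).sum(axis=0) / total
    var = (weights * (variances + (means - mean) ** 2)).sum(axis=0) / total
    return mean, var


rng = np.random.default_rng(0)
# Two "environments" that observed differently distributed data (toy values)
chunks = [rng.normal(0.0, 1.0, size=(1000, 3)), rng.normal(2.0, 0.5, size=(3000, 3))]
mean, var = merge_running_stats(
    [c.mean(axis=0) for c in chunks],
    [c.var(axis=0) for c in chunks],
    [len(c) for c in chunks],
)
all_data = np.concatenate(chunks)
print(np.allclose(mean, all_data.mean(axis=0)), np.allclose(var, all_data.var(axis=0)))
```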



```python
record_video(env, eval_model, './videos', 'trpo_walker2d-v5.mp4', min_reward=4000)
```

```
4714.948839794251
4725.3723102399035
4716.47982012971
```


<video width="640" height="480" controls>
  <source src="../assets/videos/trpo_walker2d-v5.mp4" type="video/mp4">
</video>

### Offtopic

```python
def get_flat_params_from(model):
    """
    Extract all parameters from the model as a single flattened tensor.
    """
    params = [param.data.reshape(-1) for param in model.parameters()]
    # Check before torch.cat, which raises on an empty list
    if not params:
        raise ValueError("Model has no parameters to flatten.")
    return torch.cat(params)


# Inspect parameter sizes and gradient shapes of a small value network
test_value = Value(17, 64)
x = test_value(torch.randn(1, 17))
x.backward()
print([param.shape for param in test_value.parameters()])
print(get_flat_params_from(test_value).shape)
a = 0
for param in test_value.parameters():
    b = int(np.prod(list(param.size())))
    print(b)
    a += b
    print(param.grad.shape)
print(a)
```

```
[torch.Size([64, 17]), torch.Size([64]), torch.Size([64, 64]), torch.Size([64]), torch.Size([1, 64]), torch.Size([1])]
torch.Size([5377])
1088
torch.Size([64, 17])
64
torch.Size([64])
4096
torch.Size([64, 64])
64
torch.Size([64])
64
torch.Size([1, 64])
1
torch.Size([1])
5377
```
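PyTorch ships helpers for this flatten/unflatten pattern: `torch.nn.utils.parameters_to_vector` and `torch.nn.utils.vector_to_parameters`. The sketch below uses an `nn.Sequential` stand-in with the same layer sizes as `Value(17, 64)` (the `Tanh` activations are an assumption and do not affect the parameter count) and applies a flat update vector, similar to how the line search applies `step_size * full_step` by slicing manually.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector, vector_to_parameters

# Stand-in network with the same layer shapes as Value(17, 64)
model = nn.Sequential(nn.Linear(17, 64), nn.Tanh(), nn.Linear(64, 64), nn.Tanh(), nn.Linear(64, 1))

flat = parameters_to_vector(model.parameters())  # a new 1-D tensor (copy of all parameters)
print(flat.shape)

# Write a flat update back into the parameters, as a line-search trial step would
step = 1e-3 * torch.ones_like(flat)
with torch.no_grad():
    vector_to_parameters(flat + step, model.parameters())
print(torch.allclose(parameters_to_vector(model.parameters()), flat + step))
```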



The **Fisher Vector Product (FVP)** is a key step in **TRPO** for capturing the curvature of the KL-divergence constraint. It is the product of the Fisher Information Matrix (FIM) with a vector $ v $, which the conjugate gradient solver needs in order to solve TRPO's constrained optimization problem.

Here’s a step-by-step breakdown of the function and the corresponding mathematical formulas.

---

### **What is the Fisher Information Matrix (FIM)?**

The FIM is defined as:

$$
F = \mathbb{E} \left[ \nabla_\theta \log \pi(a \mid s; \theta) \nabla_\theta \log \pi(a \mid s; \theta)^T \right]
$$

Where:
- $ \theta $: Policy parameters.
- $ \pi(a \mid s; \theta) $: The policy’s probability distribution for action $ a $ given state $ s $.

The FIM is a positive semi-definite matrix that captures the second-order curvature of the KL divergence. In TRPO, it is used to constrain the policy update via the trust region.
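For a one-dimensional Gaussian policy parameterized by $ \theta = (\mu, \log \sigma) $, the FIM has the closed form $ \mathrm{diag}(1/\sigma^2, 2) $. The sketch below (toy parameter values chosen for illustration) estimates $ F $ from the definition above, as an average outer product of score vectors, and compares it with the exact Hessian of $ D_{\text{KL}}(\pi_{\text{old}} \,||\, \pi_\theta) $ at $ \theta = \theta_{\text{old}} $, illustrating the link between the FIM and the curvature of the KL divergence.

```python
import math
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
# Illustrative old-policy parameters theta_old = (mu, log_sigma)
mu_old, log_std_old = 0.3, -0.5
theta_old = torch.tensor([mu_old, log_std_old], dtype=torch.float64)

# 1) FIM from its definition: average outer product of score vectors
n = 200_000
with torch.no_grad():
    actions = Normal(theta_old[0], theta_old[1].exp()).sample((n,))
# One copy of theta per sample, so row i of the gradient is the score of sample i
theta_rows = theta_old.repeat(n, 1).requires_grad_(True)
log_probs = Normal(theta_rows[:, 0], theta_rows[:, 1].exp()).log_prob(actions)
(scores,) = torch.autograd.grad(log_probs.sum(), theta_rows)  # shape (n, 2)
fim_mc = scores.T @ scores / n


# 2) Hessian of KL(pi_old || pi_theta), evaluated at theta = theta_old
def kl_to_old(theta):
    old = Normal(theta_old[0], theta_old[1].exp())
    new = Normal(theta[0], theta[1].exp())
    return kl_divergence(old, new)


hess_kl = torch.autograd.functional.hessian(kl_to_old, theta_old)

# Closed form for this parameterization: diag(1 / sigma^2, 2)
fim_exact = torch.diag(torch.tensor([math.exp(-2 * log_std_old), 2.0], dtype=torch.float64))
print(fim_mc)   # Monte Carlo estimate, close to fim_exact
print(hess_kl)  # matches fim_exact up to rounding
```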

---

### **Why Do We Need the Fisher Vector Product?**

Directly computing the FIM is computationally expensive because:
1. It involves storing a large matrix (size: $ |\theta| \times |\theta| $, where $ |\theta| $ is the number of parameters).
2. Inverting the FIM is even more expensive.

Instead, we use the **Fisher Vector Product (FVP)**. At the current parameters $ \theta = \theta_{\text{old}} $, the Hessian of the KL divergence equals the FIM, so
$$
Fv = \left[ \nabla_\theta^2 D_{\text{KL}}(\pi_{\text{old}} \,||\, \pi_{\text{new}}) \right] v
$$

This computes the product of the FIM $ F $ with a vector $ v $ without explicitly constructing $ F $. Automatic differentiation evaluates $ Fv $ at a cost of a small constant multiple of one gradient computation.
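The double-backpropagation trick can be checked on a toy function whose full Hessian is still small enough to build. The function `f` below is an arbitrary stand-in for the KL divergence (an assumption for illustration); the point is that `hvp` matches `H @ v` even though the first computation never materializes `H`.

```python
import torch

torch.manual_seed(0)
theta = torch.randn(4, dtype=torch.float64, requires_grad=True)
v = torch.randn(4, dtype=torch.float64)


def f(t):
    # Any twice-differentiable scalar function; stands in for the KL divergence
    return (t.sin() * t.roll(1)).sum() + (t ** 4).sum()


# Hessian-vector product without forming H: differentiate g^T v, where g keeps its graph
(g,) = torch.autograd.grad(f(theta), theta, create_graph=True)
(hvp,) = torch.autograd.grad(g @ v, theta)

# Reference: the full 4x4 Hessian, only affordable for tiny problems
H = torch.autograd.functional.hessian(f, theta.detach())
print(torch.allclose(hvp, H @ v))
```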

---

### **Step-by-Step Explanation of the Code**

```python
def Fvp(v):
    kl = kl_fn()
```

#### **Step 1: Compute the KL Divergence**

$$
D_{\text{KL}}(\pi_{\text{old}} \,||\, \pi_{\text{new}})
$$

This is the KL divergence between the old policy ($ \pi_{\text{old}} $) and the new policy ($ \pi_{\text{new}} $), evaluated for a batch of states. It serves as a measure of how much the new policy deviates from the old policy.
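For the diagonal Gaussian policy used in this notebook, the KL divergence has the closed form implemented in `_compute_policy_metrics`: summed over action dimensions and averaged over states. The sketch below uses random toy means and standard deviations (illustrative values, not rollout data) to check that formula against `torch.distributions.kl_divergence`.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
batch, action_dim = 8, 6
mean_0, std_0 = torch.randn(batch, action_dim), torch.rand(batch, action_dim) + 0.5  # old policy
mean_1, std_1 = torch.randn(batch, action_dim), torch.rand(batch, action_dim) + 0.5  # new policy

# Closed form for KL(pi_old || pi_new) per action dimension
kl_per_dim = (torch.log(std_1 / std_0)
              + (std_0.pow(2) + (mean_0 - mean_1).pow(2)) / (2.0 * std_1.pow(2))
              - 0.5)
kl_closed = kl_per_dim.sum(-1).mean()  # sum over actions, average over the batch of states

# Reference value from torch.distributions
kl_ref = kl_divergence(Normal(mean_0, std_0), Normal(mean_1, std_1)).sum(-1).mean()
print(torch.allclose(kl_closed, kl_ref, rtol=1e-4))
```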

---

```python
grads = torch.autograd.grad(kl, policy_net.parameters(), create_graph=True)
flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])
```

#### **Step 2: Compute the Gradient of the KL Divergence**

The gradient of the KL divergence with respect to the policy parameters $ \theta $ is:

$$
g = \nabla_\theta D_{\text{KL}}(\pi_{\text{old}} \,||\, \pi_{\text{new}})
$$

This produces a vector $ g $, where each element corresponds to the partial derivative of the KL divergence with respect to a parameter in $ \theta $.
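One detail explains why `create_graph=True` is essential here: because $ D_{\text{KL}}(\pi_{\text{old}} \,||\, \pi_\theta) $ is minimized at $ \theta = \theta_{\text{old}} $, the gradient $ g $ is (numerically) zero at that point, yet its derivative is not. The sketch below uses a toy per-state Gaussian whose parameters `mu` and `log_std` are illustrative stand-ins for a policy network's outputs.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
# Toy "policy parameters": one mean and one log-std per state (illustrative only)
mu = torch.randn(16, dtype=torch.float64, requires_grad=True)
log_std = torch.zeros(16, dtype=torch.float64, requires_grad=True)
params = [mu, log_std]

# The old policy is a detached copy of the current one, so theta == theta_old
old = Normal(mu.detach(), log_std.detach().exp())
new = Normal(mu, log_std.exp())
kl = kl_divergence(old, new).mean()

grads = torch.autograd.grad(kl, params, create_graph=True)
flat_grad_kl = torch.cat([g.reshape(-1) for g in grads])
print(flat_grad_kl.abs().max())    # ~0: the KL has its minimum at theta_old
print(flat_grad_kl.requires_grad)  # True: the graph is kept for the next derivative

# Differentiating g^T v still yields non-zero curvature information
v = torch.ones_like(flat_grad_kl)
hv = torch.cat([h.reshape(-1) for h in torch.autograd.grad(flat_grad_kl @ v, params)])
print(hv)  # 1/16 for each mean entry, 2/16 for each log-std entry
```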

---

```python
kl_v = (flat_grad_kl * v).sum()
```

#### **Step 3: Compute the Dot Product with $ v $**

The Fisher Vector Product involves the second derivative (Hessian) of the KL divergence. Instead of forming the Hessian, we first compute the scalar

$$
g^T v = \sum_i g_i v_i
$$

Here:
- $ g^T v $: The dot product of the gradient $ g $ and the vector $ v $, with $ v $ treated as a constant.

Because $ \nabla_\theta \left( g^T v \right) = H v $, differentiating this scalar once more (Step 4) yields the Hessian-vector product without ever building the full Hessian. This is also why Step 2 passes `create_graph=True`: $ g $ must remain differentiable.

---

```python
grads = torch.autograd.grad(kl_v, policy_net.parameters())
flat_grad_grad_kl = torch.cat([grad.contiguous().view(-1) for grad in grads])
```

#### **Step 4: Compute the Second Derivative (Hessian-Vector Product)**

The gradient of $ g^T v $ with respect to $ \theta $ is:

$$
H v = \nabla_\theta^2 D_{\text{KL}}(\pi_{\text{old}} \,||\, \pi_{\text{new}}) \cdot v
$$

This is equivalent to multiplying the Hessian of the KL divergence by the vector $ v $. The result is a vector of the same size as $ v $.

---

```python
return flat_grad_grad_kl + damping * v
```

#### **Step 5: Add Damping**

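In practice, the FIM can be ill-conditioned (some eigenvalues are close to zero), which makes solving linear systems with it numerically unstable. To address this, a small multiple of the vector $ v $ is added to the result:

$$
(F + \lambda I) v = H v + \lambda v
$$

Where:
- $ \lambda $ (damping): A small positive scalar that stabilizes the computation (`damping` in the snippet above, `cg_damping` in the agent).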

This corresponds to using a **regularized Fisher Information Matrix**:
$$
F_{\text{reg}} = F + \lambda I
$$

The damping term ensures that $ F_{\text{reg}} $ is invertible and prevents extreme updates.
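The effect of damping shows up directly in the eigenvalues: adding $ \lambda I $ shifts every eigenvalue of $ F $ up by $ \lambda $, so nearly flat directions no longer produce enormous steps when solving $ F x = g $. The toy 2×2 matrix below is an illustrative assumption, with $ \lambda = 0.1 $ matching the `cg_damping=0.1` used for the agent.

```python
import torch

# A toy PSD "Fisher" matrix with one nearly flat direction (illustrative values)
F = torch.diag(torch.tensor([1.0, 1e-6], dtype=torch.float64))
damping = 0.1
F_reg = F + damping * torch.eye(2, dtype=torch.float64)

print(torch.linalg.cond(F))      # about 1e6: badly conditioned
print(torch.linalg.cond(F_reg))  # about 1.1 / 0.1 = 11

# Solving F x = g along the flat direction gives a huge step; damping bounds it
g = torch.tensor([1.0, 1.0], dtype=torch.float64)
print(torch.linalg.solve(F, g))      # second component about 1e6
print(torch.linalg.solve(F_reg, g))  # second component about 10
```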

---

### **Mathematical Summary**

The function `Fvp` computes the damped Fisher Vector Product:

$$
F_{\text{reg}} v = \left( \nabla_\theta^2 D_{\text{KL}}(\pi_{\text{old}} \,||\, \pi_{\text{new}}) \right) v + \lambda v
$$

Steps:
1. Compute the gradient $ g = \nabla_\theta D_{\text{KL}} $.
2. Compute the dot product $ g^T v $.
3. Take the gradient of $ g^T v $ to obtain the Hessian-vector product $ H v $ exactly (the Hessian itself is never formed).
4. Add the damping term $ \lambda v $ for stability.
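Putting the steps together, here is a minimal self-contained sketch of an `Fvp`-style closure for a hypothetical tiny policy: a linear mean and a state-independent log-std packed into one flat parameter vector `theta`. The names `mean_and_std`, `kl_fn`, and `make_fvp` are illustrative. Using a flat tensor instead of `policy_net.parameters()` is a simplification that lets the result be checked against an explicitly built Hessian.

```python
import torch

torch.manual_seed(0)
obs_dim, act_dim, batch = 3, 2, 64
states = torch.randn(batch, obs_dim, dtype=torch.float64)

# Hypothetical tiny policy: linear mean W and state-independent log-std, packed into one vector
n_params = obs_dim * act_dim + act_dim
theta_old = 0.1 * torch.randn(n_params, dtype=torch.float64)


def mean_and_std(theta):
    W = theta[: obs_dim * act_dim].view(obs_dim, act_dim)
    log_std = theta[obs_dim * act_dim:]
    return states @ W, log_std.exp().expand(batch, act_dim)


def kl_fn(theta):
    """Step 1: mean KL(pi_old || pi_theta) over states for diagonal Gaussians."""
    mean_0, std_0 = (t.detach() for t in mean_and_std(theta_old))
    mean_1, std_1 = mean_and_std(theta)
    kl = (torch.log(std_1 / std_0)
          + (std_0.pow(2) + (mean_0 - mean_1).pow(2)) / (2.0 * std_1.pow(2)) - 0.5)
    return kl.sum(-1).mean()


def make_fvp(theta, damping):
    """Return v -> (H + damping * I) v, with H the Hessian of the KL at theta."""
    (grad_kl,) = torch.autograd.grad(kl_fn(theta), theta, create_graph=True)  # Step 2

    def fvp(v):
        # Steps 3-4; retain_graph lets fvp be called repeatedly (e.g. by conjugate gradient)
        (hv,) = torch.autograd.grad(grad_kl @ v, theta, retain_graph=True)
        return hv + damping * v  # Step 5

    return fvp


theta = theta_old.clone().requires_grad_(True)
fvp = make_fvp(theta, damping=0.1)

# Check against the explicit damped Hessian, only affordable for tiny models
H = torch.autograd.functional.hessian(kl_fn, theta_old)
v = torch.randn(n_params, dtype=torch.float64)
print(torch.allclose(fvp(v), H @ v + 0.1 * v))
```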

---

### **Why Is This Important for TRPO?**

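- The Fisher Vector Product is used to solve the **constrained optimization problem** in TRPO ($ \delta $ is `target_kl` in the agent):
  $$
  \max_\theta \; \mathbb{E}_{s, a \sim \pi_{\theta_\text{old}}} \left[ \frac{\pi_\theta(a \vert s)}{\pi_{\theta_\text{old}}(a \vert s)} \hat{A}_{\theta_\text{old}}(s, a) \right] \quad \text{subject to} \quad \mathbb{E}_{s} \left[ D_{\text{KL}}\big(\pi_{\theta_\text{old}}(\cdot \vert s) \,||\, \pi_\theta(\cdot \vert s)\big) \right] \leq \delta
  $$
- Conjugate gradient calls the FVP repeatedly to approximately solve $ F x = \nabla_\theta J(\theta) $ for the search direction $ x $, so TRPO never computes, stores, or inverts the full Fisher Information Matrix.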

This technique is essential for scaling TRPO to high-dimensional policies with many parameters.
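The step-size scaling and backtracking line search in the agent's policy update can be illustrated on a toy two-parameter problem. In this sketch the linear surrogate, the matrix `F`, and the deliberately non-quadratic `true_kl` are illustrative assumptions, and the search direction comes from a direct solve instead of conjugate gradient because the matrix is tiny. The direction $ s $ is first scaled by $ \beta = \sqrt{2\delta / s^T F s} $, so that the quadratic estimate $ \tfrac{1}{2} (\beta s)^T F (\beta s) $ equals $ \delta $, and then shrunk by `decay` until the true KL is within $ \delta $ and the surrogate improves.

```python
import math
import torch

# Toy problem (illustrative values): a linear surrogate gain g^T x and a "true" KL
# that grows faster than its quadratic model 0.5 * x^T F x
F = torch.tensor([[2.0, 0.3], [0.3, 1.0]], dtype=torch.float64)
g = torch.tensor([1.0, -0.5], dtype=torch.float64)


def surrogate(x):
    return g @ x


def true_kl(x):
    return 0.5 * x @ F @ x + (x @ x) ** 1.5


target_kl, decay = 0.01, 0.8  # play the roles of target_kl and line_search_step_decay

# Search direction s = F^{-1} g (a direct solve here; CG is used for real networks)
s = torch.linalg.solve(F, g)
# Scale s so the quadratic KL estimate of the full step equals target_kl
max_step = math.sqrt(2 * target_kl / (s @ F @ s).item())
full_step = max_step * s

# Backtracking line search with the same acceptance test as the agent
x_old = torch.zeros(2, dtype=torch.float64)
step_size, accepted = 1.0, False
for _ in range(10):
    x_new = x_old + step_size * full_step
    if true_kl(x_new) <= target_kl and surrogate(x_new) - surrogate(x_old) > 0:
        accepted = True
        break
    step_size *= decay
print(accepted, step_size, true_kl(x_new).item())
```

Here the full step overshoots the true KL (the cubic term adds to the quadratic estimate), so one backtracking step to `step_size = 0.8` is needed; if all ten trials failed, the agent would restore the original parameters instead.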