In [None]:
import numpy as np

### 3.1. Task Specific Metrics

In [2]:

class TaskMetrics:
    """
    Evaluate a trained RL model in an environment and summarise task metrics.

    The environment is driven through the classic Gym API: ``reset()`` must
    return an observation, and ``step(action)`` must return
    ``(obs, reward, done, info)`` with ``info`` a dict.
    NOTE(review): an env whose ``step`` returns a 5-tuple (Gymnasium API) will
    fail the 4-value unpacking in ``evaluate``.
    """

    def __init__(self, env, model, num_episodes=100, success_threshold=200):
        """
        Initialize the metrics calculation class.

        Parameters:
        - env: Gym environment (classic ``reset``/4-tuple ``step`` API).
        - model: Trained RL model exposing ``predict(obs, deterministic=...)``
          that returns a 2-tuple whose first element is the action.
        - num_episodes: Number of episodes to evaluate the model on; must be
          at least 1 by the time ``evaluate`` is called.
        - success_threshold: Reward threshold to consider an episode
          successful (an episode succeeds when its total reward is >= it).
        """
        self.env = env
        self.model = model
        self.num_episodes = num_episodes
        self.success_threshold = success_threshold
        # Empty until `evaluate` runs. Always a dict so the attribute's type
        # never changes between construction and evaluation.
        self.results = {}

    def evaluate(self):
        """
        Run the model in the environment for ``num_episodes`` episodes
        and collect metrics.

        Returns:
        - A copy of the metrics dict, which is also stored on ``self.results``:
          - ``mean_reward``: mean total reward per episode.
          - ``success_rate``: fraction of episodes whose reward is
            >= ``success_threshold``.
          - ``total_distance``: the *mean* per-episode distance, summed from
            ``info['distance']`` (0 when absent). The historical key name is
            kept for compatibility.
          - ``reward_variance``: population variance (ddof=0) of episode
            rewards.

        Raises:
        - ValueError: if ``num_episodes`` is less than 1. Without this check
          the averages below would divide by zero.
        """
        if self.num_episodes < 1:
            raise ValueError(
                f"num_episodes must be at least 1, got {self.num_episodes!r}."
            )

        episode_rewards = []
        success_count = 0
        total_distance = 0

        for _ in range(self.num_episodes):
            obs = self.env.reset()
            episode_reward = 0
            episode_distance = 0
            done = False

            while not done:
                # Ask for deterministic actions. This assumes the model honours
                # the flag, so that evaluation does not sample from the policy.
                action, _state = self.model.predict(obs, deterministic=True)
                obs, reward, done, info = self.env.step(action)

                episode_reward += reward
                # Environments that do not report 'distance' contribute 0.
                episode_distance += info.get('distance', 0)

            episode_rewards.append(episode_reward)
            total_distance += episode_distance
            if episode_reward >= self.success_threshold:
                success_count += 1

        self.results = {
            "mean_reward": np.mean(episode_rewards),
            "success_rate": success_count / self.num_episodes,
            # Mean distance per episode (not a sum over all episodes). The key
            # name is kept unchanged so existing callers keep working.
            "total_distance": total_distance / self.num_episodes,
            "reward_variance": np.var(episode_rewards),
        }
        return dict(self.results)

    def get_results(self):
        """
        Return the computed metrics.

        Returns:
        - A shallow copy of the metrics dict produced by ``evaluate``. The copy
          means callers cannot mutate the stored results.

        Raises:
        - ValueError: if ``evaluate`` has not been run yet.
        """
        if not self.results:
            raise ValueError("Run `evaluate` method before fetching results.")
        return dict(self.results)
