# In [None]:
from tensorboard.backend.event_processing import event_accumulator
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def calculate_moving_average(values, window_size=50):
    """Return the trailing rolling mean of ``values``.

    Because ``min_periods=1``, the first ``window_size - 1`` outputs are
    means over the partial window seen so far rather than NaN, so the
    result has the same length as the input.

    Args:
        values: Sequence of numbers to smooth.
        window_size (int): Number of trailing points averaged per output.

    Returns:
        pandas.Series: The smoothed values, aligned with ``values``.
    """
    series = pd.Series(values)
    windowed = series.rolling(window=window_size, min_periods=1)
    return windowed.mean()

def visualize_rewards(events_file_path, ma_window=1):
    """
    Plot the ``rollout/ep_rew_mean`` scalar from a TensorBoard events file.

    The raw series is drawn faintly in blue, with its trailing moving
    average overlaid as a thicker red line.

    Args:
        events_file_path (str): Path to the events file.
        ma_window (int): Window size for the moving average. With the
            default of 1 the moving average is identical to the raw series;
            pass a larger value to actually smooth the curve.

    Returns:
        matplotlib.figure.Figure | None: The created figure, or ``None``
        (after printing a message) if the events file contains no
        ``rollout/ep_rew_mean`` scalars.

    Raises:
        ValueError: If ``ma_window`` is smaller than 1.
    """
    # Fail fast with a message naming the parameter; otherwise pandas raises
    # from inside rolling() with an error that never mentions ma_window.
    if ma_window < 1:
        raise ValueError(f"ma_window must be >= 1, got {ma_window}")

    # Load the events file. A size_guidance of 0 tells EventAccumulator to
    # keep every event rather than sampling down to its default limits.
    ea = event_accumulator.EventAccumulator(
        events_file_path,
        size_guidance={
            event_accumulator.SCALARS: 0,
            event_accumulator.COMPRESSED_HISTOGRAMS: 0,
            event_accumulator.IMAGES: 0,
            event_accumulator.AUDIO: 0,
            event_accumulator.HISTOGRAMS: 0,
        }
    )
    ea.Reload()

    # Get reward data
    tag = "rollout/ep_rew_mean"
    if tag not in ea.Tags()['scalars']:
        print(f"No {tag} data found in the events file.")
        return None

    scalar_events = ea.Scalars(tag)
    steps = [event.step for event in scalar_events]
    values = [event.value for event in scalar_events]

    # Create figure
    fig, ax = plt.subplots(figsize=(12, 6))

    # Plot original values (faint, so the smoothed curve stands out)
    ax.plot(steps, values, 'b-', alpha=0.3, label='Raw rewards')

    # Calculate and plot moving average
    ma_values = calculate_moving_average(values, ma_window)
    ax.plot(steps, ma_values, 'r-', linewidth=2,
            label=f'Moving average (window={ma_window})')

    # Customize plot
    ax.set_title('Episode Rewards Over Time', fontsize=14)
    ax.set_xlabel('Steps', fontsize=12)
    ax.set_ylabel('Average Episode Reward', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=10)

    plt.tight_layout()
    return fig

if __name__ == "__main__":
    # Replace this placeholder with the path to your TensorBoard events file.
    events_file = "PATH-TO-YOUR-EVENTS-FILE"
    # Smooth with a 10-point moving average (the function's default is 1,
    # which would leave the curve unsmoothed).
    fig = visualize_rewards(events_file, ma_window=10)
    # visualize_rewards returns None when the reward tag is missing, so only
    # display a window when a figure was actually created.
    if fig is not None:
        plt.show()