# In[2]:
import tensorflow as tf
from tensorboard.backend.event_processing import event_accumulator
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import numpy as np

# Scalar tags plotted when the caller does not pass ``scalar_tags``.
_DEFAULT_SCALAR_TAGS = (
    'eval/mean_ep_length', 'eval/mean_reward',
    'train/approx_kl', 'train/clip_fraction',
    'train/clip_range', 'train/entropy_loss',
    'train/explained_variance', 'train/learning_rate',
    'train/loss', 'train/policy_gradient_loss',
    'train/value_loss',
)


def _load_scalar_series(ea, tags):
    """Return ``{tag: (steps, values)}`` for every requested tag that has data.

    Tags absent from the event file, or present but without any points, are
    skipped with a warning instead of aborting the whole animation.  Each
    series is sorted by step so step-based slicing later is valid.
    """
    import warnings

    available = set(ea.Tags().get(event_accumulator.SCALARS, []))
    series = {}
    for tag in tags:
        if tag not in available:
            warnings.warn(f"Skipping scalar tag {tag!r}: not found in event file", stacklevel=3)
            continue
        events = ea.Scalars(tag)
        if not events:
            warnings.warn(f"Skipping scalar tag {tag!r}: no data points", stacklevel=3)
            continue
        steps = np.array([event.step for event in events], dtype=float)
        values = np.array([event.value for event in events], dtype=float)
        order = np.argsort(steps, kind='stable')
        series[tag] = (steps[order], values[order])
    return series


def _axis_limits(data, pad_frac):
    """Return non-singular ``(low, high)`` axis limits for ``data``.

    Padding is ``pad_frac`` of the data span (not of the endpoint magnitudes),
    so it behaves for negative, tiny or large-offset data.  NaN/inf values are
    ignored.  Constant data gets a window of 10% of its magnitude (or 0.1
    around zero) so matplotlib never receives identical limits.
    """
    finite = data[np.isfinite(data)]
    if finite.size == 0:
        return 0.0, 1.0  # nothing plottable; any valid window will do
    low, high = float(finite.min()), float(finite.max())
    span = high - low
    pad = pad_frac * span if span > 0 else 0.1 * (abs(low) or 1.0)
    return low - pad, high + pad


def _frame_cutoffs(series, max_frames):
    """Return one increasing step threshold per frame, ending on the last logged step."""
    all_steps = np.concatenate([steps for steps, _ in series.values()])
    longest = max(len(steps) for steps, _ in series.values())
    n_frames = max(1, min(max_frames, longest))
    # n_frames + 1 evenly spaced points minus the start: the final threshold is
    # exactly the last step (linspace includes its endpoint), so the last frame
    # always shows every point, even when n_frames == 1.
    return np.linspace(all_steps.min(), all_steps.max(), n_frames + 1)[1:]


def _style_axis(ax, tag, steps, values, *, label_x, label_y):
    """Apply the dark theme, data limits, title and optional axis labels to one subplot."""
    ax.set_facecolor('#262626')
    ax.grid(True, linestyle='--', alpha=0.2, color='#888888')
    ax.set_xlim(*_axis_limits(steps, pad_frac=0.0))
    ax.set_ylim(*_axis_limits(values, pad_frac=0.1))

    title = tag.replace('train/', '').replace('eval/', '').replace('_', ' ').title()
    ax.set_title(title, color='#00ff99', fontsize=12, pad=10)

    for spine in ax.spines.values():
        spine.set_color('#404040')
    ax.tick_params(colors='#ffffff', labelsize=8)

    if label_x:
        ax.set_xlabel('Steps', color='#ffffff', fontsize=10)
    if label_y:
        ax.set_ylabel('Value', color='#ffffff', fontsize=10)


def create_multi_scalar_animation(logdir, output_file='training_metrics.gif', *,
                                  scalar_tags=None, n_cols=3, max_frames=200,
                                  interval=20, dpi=150):
    """Render an animated GIF showing several TensorBoard scalar curves in a grid.

    Each tag gets its own subplot.  All curves are revealed in lockstep along
    a shared training-step axis: at every frame each subplot shows exactly
    the points logged up to the same step, and the final frame shows every
    point of every series.

    Args:
        logdir: Path passed to TensorBoard's ``EventAccumulator``.
        output_file: Path of the GIF to write.
        scalar_tags: Tags to plot; defaults to ``_DEFAULT_SCALAR_TAGS``.  Tags
            that are missing from the log or have no points are skipped with
            a warning.
        n_cols: Number of subplot columns.
        max_frames: Upper bound on the number of animation frames (the
            longest series' length is used when it is smaller).
        interval: Delay between frames in milliseconds.  The GIF frame rate
            is derived from it, so preview and saved file always agree.
        dpi: Resolution used when rendering frames.

    Raises:
        ValueError: If none of the requested tags has any data.
    """
    ea = event_accumulator.EventAccumulator(logdir)
    ea.Reload()

    tags = _DEFAULT_SCALAR_TAGS if scalar_tags is None else tuple(scalar_tags)
    series = _load_scalar_series(ea, tags)
    if not series:
        raise ValueError(f"No scalar data found in {logdir!r} for tags {list(tags)}")

    n_plots = len(series)
    n_rows = (n_plots + n_cols - 1) // n_cols

    # For every frame, how many leading points of each series have
    # step <= that frame's threshold (series are sorted by step).
    cutoffs = _frame_cutoffs(series, max_frames)
    visible = {tag: np.searchsorted(steps, cutoffs, side='right')
               for tag, (steps, _) in series.items()}

    # Scope the style to this figure: plt.style.use would leak into every later
    # figure in the session, and calling it after plt.figure styled nothing here.
    with plt.style.context('dark_background'):
        fig = plt.figure(figsize=(20, 15))
        fig.patch.set_facecolor('#1a1a1a')

        lines = {}
        for idx, (tag, (steps, values)) in enumerate(series.items(), 1):
            ax = fig.add_subplot(n_rows, n_cols, idx)
            line, = ax.plot([], [], lw=2, color='#00ff99')
            lines[tag] = line
            _style_axis(ax, tag, steps, values,
                        label_x=idx + n_cols > n_plots,   # no subplot below this one
                        label_y=(idx - 1) % n_cols == 0)  # leftmost column

        fig.suptitle('Training Metrics Overview', color='#00ff99', fontsize=16, y=0.98)
        # Reserve the top strip for the suptitle so it does not overlap the first row.
        fig.tight_layout(rect=(0, 0, 1, 0.95))

        def init():
            for line in lines.values():
                line.set_data([], [])
            return list(lines.values())

        def animate(frame):
            for tag, line in lines.items():
                steps, values = series[tag]
                shown = visible[tag][frame]
                line.set_data(steps[:shown], values[:shown])
            return list(lines.values())

        anim = animation.FuncAnimation(
            fig,
            animate,
            init_func=init,
            frames=len(cutoffs),
            interval=interval,
            blit=True,
            repeat=True,
        )

        try:
            # fps must match interval.  GIF delays are stored in centiseconds and
            # browsers clamp very short delays (e.g. 60 fps -> 1 cs) to 10 cs,
            # which made the old output play far slower than intended.
            # An explicit facecolor stops the style's savefig.facecolor
            # (black) from overriding the figure background.
            anim.save(
                output_file,
                writer='pillow',
                fps=1000 / interval,
                dpi=dpi,
                savefig_kwargs={'facecolor': fig.get_facecolor()},
            )
        finally:
            plt.close(fig)


# Guarded so importing this module does not immediately render a large GIF;
# a notebook kernel runs with __name__ == "__main__", so the cell still executes.
if __name__ == "__main__":
    # Directory that holds one TensorBoard log folder per training run.
    tensorboard_log_root = "/home/tommy/Downloads/ai-playground/tensorboard_logs"
    run_name_to_plot = "RecurrentPPO_14"
    logdir = f"{tensorboard_log_root}/{run_name_to_plot}"
    create_multi_scalar_animation(logdir)