In [None]:
%load_ext autoreload
%autoreload 2

from openweights import OpenWeights
from openweights.utils import flatten, compare
import pandas as pd
from dotenv import load_dotenv

# Load variables from a local .env file (python-dotenv) before the client is built.
load_dotenv()
# Shared OpenWeights client used by the cells and helper functions below.
# NOTE(review): presumably reads its credentials from the environment populated above — confirm.
client = OpenWeights()


In [None]:
parent = 'unsloth/Llama-3.3-70B-Instruct-bnb-4bit'

# Every job pushed without merging is fetched and only the completed ones are kept.
# To restrict the search to fine-tunes of `parent`, query with
# client.jobs.find(model=parent, merge_before_push=False) instead.
jobs = [
    job
    for job in client.jobs.find(merge_before_push=False)
    if job['status'] == 'completed'
]

# One row per job; `flatten` is applied to every job dict before tabulating.
df = pd.DataFrame(list(map(flatten, jobs)))
df

In [None]:
from matplotlib import pyplot as plt
# Keep only the last path component of each model id (the part after the final '/').
df['params.model'] = df['params.model'].apply(lambda name: name.rsplit('/', 1)[-1])

# Final eval loss against learning rate.
fig = compare(
    df,
    x='params.learning_rate',
    y='outputs.eval_loss',
    subplot_rows=None,
)
plt.show()

In [12]:
import numpy as np
import matplotlib.pyplot as plt


def get_last_completed_run(job_id):
    """Return the last run of a job whose status is ``'completed'``.

    "Last" means the final completed entry in the order returned by
    ``client.runs.list``; no extra sorting is applied here.

    Args:
        job_id: Identifier forwarded to ``client.runs.list(job_id=...)``.

    Returns:
        The run dict of the last completed run.

    Raises:
        IndexError: If the job has no completed run. The message names the
            job so the failing entry can be found when looping over many jobs.
    """
    completed = [
        run
        for run in client.runs.list(job_id=job_id)
        if run['status'] == 'completed'
    ]
    if not completed:
        # Same exception type that indexing the empty list would raise, but
        # with a message that says which job is affected.
        raise IndexError(f"job {job_id!r} has no completed runs")
    return completed[-1]


def get_events_from_run(run):
    """Return the events that ``client.events.list`` reports for ``run``.

    Args:
        run: Run dict; only its ``'id'`` entry is read.

    Returns:
        Whatever ``client.events.list(run_id=...)`` returns for that id.
    """
    return client.events.list(run_id=run['id'])


def get_events_from_job(job_id):
    """Return the events of the last completed run of ``job_id``.

    Looks the run up with ``get_last_completed_run`` and hands it to
    ``get_events_from_run``.
    """
    return get_events_from_run(get_last_completed_run(job_id))


def plot_events(jobs, metric, x='step', color='learning_rate', columns='model'):
    """Plot a logged metric for several jobs, one subplot per group.

    Jobs are split into side-by-side subplots by ``job['params'][columns]``
    and coloured by ``job['params'][color]``. For each job the events of its
    last completed run are fetched with ``get_events_from_job``; the rows that
    have a value for ``metric`` are drawn as one line.

    Args:
        jobs: Non-empty iterable of job dicts. Each needs an ``'id'`` and a
            ``'params'`` dict holding the ``color`` and ``columns`` keys.
            Within each of those two keys the values must be mutually
            sortable.
        metric: Event field plotted on the y-axis.
        x: Event field plotted on the x-axis.
        color: Job parameter that selects the line colour.
        columns: Job parameter that selects the subplot.

    Returns:
        ``(fig, axes)``. ``axes`` is a one-element list when there is a single
        subplot, otherwise the axes object returned by ``plt.subplots``.

    Raises:
        ValueError: If ``jobs`` is empty.
    """
    jobs = list(jobs)
    if not jobs:
        raise ValueError("plot_events needs at least one job to plot")

    # One subplot per distinct value of the `columns` parameter.
    column_values = sorted({job['params'][columns] for job in jobs})
    n_columns = len(column_values)

    # Figure width grows with the number of subplots, with a floor of 8 inches.
    width_per_col = 5
    fig_width = max(width_per_col * n_columns, 8)

    fig, axes = plt.subplots(1, n_columns, figsize=(fig_width, 5), sharey=True)
    if n_columns == 1:
        axes = [axes]  # plt.subplots returns a bare Axes for a 1x1 grid

    # Map every colour value to an evenly spaced viridis colour.
    c_values = sorted({job['params'][color] for job in jobs})
    print('colors', c_values)
    palette = plt.cm.viridis(np.linspace(0, 1, len(c_values)))
    colors = {c: palette[i] for i, c in enumerate(c_values)}

    # First plotted line per colour value, used as the legend handle.
    legend_entries = {}

    for ax, column_value in zip(axes, column_values):
        jobs_in_column = sorted(
            (job for job in jobs if job['params'][columns] == column_value),
            key=lambda job: job['params'][color],
        )

        for job in jobs_in_column:
            events = get_events_from_job(job['id'])
            df_events = pd.DataFrame([event['data'] for event in events])
            # A run that never logged the requested fields has nothing to
            # draw; skip it instead of failing the whole figure.
            if metric not in df_events.columns or x not in df_events.columns:
                continue
            df_events = df_events.dropna(subset=[metric])

            c_value = job['params'][color]
            line = ax.plot(df_events[x], df_events[metric],
                           color=colors[c_value],
                           alpha=0.8)[0]
            legend_entries.setdefault(c_value, line)

        ax.set_title(f"{columns.replace('_', ' ').capitalize()}: {column_value}", pad=10)
        ax.set_xlabel(x.capitalize())
        ax.grid(True, linestyle='--', alpha=0.7)

    # The y-axis is shared, so only the leftmost subplot is labelled.
    axes[0].set_ylabel(metric.capitalize())

    # List legend entries in the sorted order of the colour values so the
    # legend follows the colour gradient. Sorting the label strings would put
    # e.g. 0.0001 before 1e-05.
    plotted_values = [c for c in c_values if c in legend_entries]
    if plotted_values:
        fig.legend([legend_entries[c] for c in plotted_values],
                   [f"{color}={c}" for c in plotted_values],
                   loc='center right',
                   bbox_to_anchor=(0.98, 0.5),
                   title=color.replace('_', ' ').title())

    # Tighten the layout, then reserve the right margin for the legend.
    fig.tight_layout()
    fig.subplots_adjust(right=0.85)

    return fig, axes

# Usage example:
# fig, axes = plot_events(jobs, 'loss')
# plt.show()

In [None]:
# Rewrite every non-string learning rate as the string "1e<log10(value)>";
# values that are already strings are left untouched.
# NOTE(review): presumably done so all learning rates share one sortable type
# for plot_events — confirm.
for job in jobs:
    params = job['params']
    lr = params['learning_rate']
    if isinstance(lr, str):
        continue
    params['learning_rate'] = f"1e{np.log10(lr)}"
    print(params['learning_rate'])


# Training-loss curves for all jobs, y-axis clipped to [0, 1.5].
fig, axes = plot_events(jobs, 'loss')
axes[0].set_ylim(0, 1.5)
plt.show()