In [3]:
import json
import plotly
import plotly.graph_objs as go

from typing import List, Tuple

In [4]:
import os

# Root directory of the training logs. plot() reads
# <dirname>/<trace>/metrics.jsonl for every trace name it is given.
dirname = "logs/rnd_training"

def get_data(filepath:str, metric:str)-> Tuple[List[int], List[float]]:
    """Read one metric's series from a JSON-lines log file.

    Every non-blank line of ``filepath`` is parsed as a JSON object. Records
    that contain ``metric`` contribute one point; records without it are
    skipped.

    Args:
        filepath: Path to the JSON-lines file.
        metric: Key to extract from each record.

    Returns:
        ``(steps, values)``: two equal-length lists, in file order, holding
        each matching record's ``"step"`` and ``metric`` entries.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record has ``metric`` but no ``"step"`` key.
    """
    steps = []
    values = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            # Blank lines (e.g. a trailing newline) are not records and would
            # make json.loads raise, so skip them.
            if not line.strip():
                continue
            data = json.loads(line)
            if metric in data:
                steps.append(data['step'])
                values.append(data[metric])
    return steps, values

def plot(traces:List[str],  metric:str, labels:List[str]|None=None, min_steps:int=0, max_steps:int|None=None):
    """Plot one metric over steps for several runs on a single plotly figure.

    For each name in ``traces`` the file ``<dirname>/<trace>/metrics.jsonl``
    is read with ``get_data`` and drawn as one line+marker trace.

    Args:
        traces: Run sub-directory names under the module-level ``dirname``.
        metric: Metric key to extract from each run's log; also used in the
            figure title and as the y-axis title.
        labels: Legend names, paired positionally with ``traces``. Defaults
            to the trace names themselves.
        min_steps: Points with ``step < min_steps`` are dropped.
        max_steps: Points with ``step > max_steps`` are dropped; ``None``
            means no upper bound.

    Raises:
        AssertionError: If ``labels`` and ``traces`` differ in length.
    """
    if labels is None:
        labels = traces
    assert len(traces) == len(labels), "traces and labels must have the same length"

    fig = go.Figure()
    # Traces are added straight to the figure rather than collected in a dict
    # keyed by label, so two runs sharing a label are both drawn.
    for trace, label in zip(traces, labels):
        filepath = os.path.join(dirname, trace, "metrics.jsonl")
        steps, values = get_data(filepath, metric)
        # Apply the step window unconditionally: the lower bound must take
        # effect even when no upper bound is given.
        kept = [
            (s, v)
            for s, v in zip(steps, values)
            if s >= min_steps and (max_steps is None or s <= max_steps)
        ]
        fig.add_trace(
            go.Scatter(
                x=[s for s, _ in kept],
                y=[v for _, v in kept],
                mode='lines+markers',
                name=label,
            )
        )
    fig.update_layout(title=f'{metric} over Update Steps', xaxis_title='Steps', yaxis_title=metric)
    fig.show()






### First, total rewards and DeepMath reward vs. baseline

In [7]:
# Run directories and their legend labels; the two lists are paired by position.
traces = [
    "OSS-vanilla",
    "oss-20B-both-pos-0.25-normalised",
    "oss-20B-both-pos-0.4",
    "oss-20B-deep-both-0.25",
    "oss-20B-deep-both-0.4",
]
labels = [
    "Vanilla (No curiosity)",
    "α=0.25,d=2",
    "α=0.4,d=2",
    "α=0.25,d=3",
    "α=0.4,d=3",
]


In [8]:
# Plot the "env/deepmath/correct" metric for every run in `traces`.
plot(
    traces,
    labels=labels,
    metric="env/deepmath/correct",
)

### Entropy for this setup

In [9]:
# Plot the "optim/entropy" metric for the same runs and labels.
plot(
    traces,
    labels=labels,
    metric="optim/entropy",
)

### Total tokens and rewards for Llama

In [10]:
# Llama run directories and their legend labels, paired by position.
llama_traces = [
    "Llama-3B-vanilla",
    "Llama-3B-both-pos-0.25-normalised",
]
llama_labels = [
    "Vanilla (No curiosity)",
    "α=0.25,d=2",
]
plot(
    llama_traces,
    labels=llama_labels,
    metric="env/deepmath/correct",
)

In [11]:
# Plot the "env/deepmath/ac_tokens_per_turn" metric for the Llama runs.
plot(
    llama_traces,
    labels=llama_labels,
    metric="env/deepmath/ac_tokens_per_turn",
)