In [1]:
from collections import defaultdict
import math
import random
import numpy as np

from delphi.eval.vis_per_token_model import visualize_per_token_category


random.seed(0)

# generate mock data
# model_names = ['llama2-100k', 'llama2-200k', 'llama2-1m', 'llama2-10m', "0"]
model_names = list(range(500))
categories = ['nouns', 'verbs', 'prepositions', 'adjectives']
entries = [200, 100, 150, 300, 100]*100
performance_data = defaultdict()
for i, model in enumerate(model_names):
    performance_data[model] = defaultdict()
    for cat in categories:
        x = [math.log2(random.random()) for _ in range(entries[i])]
        means = np.mean(x)
        err_low = means - np.percentile(x, 25)
        err_hi = np.percentile(x, 75) - means
        performance_data[model][cat] = (-means, err_low, err_hi)


visualize_per_token_category(performance_data, log_scale=True, checkpoint_mode=True)

FigureWidget({
    'data': [{'line': {'width': 0},
              'marker': {'color': '#444'},
              'mode': 'lines',
              'name': 'Upper Bound',
              'showlegend': False,
              'type': 'scatter',
              'uid': '4a40445f-0502-42f2-878b-9a9f03d66717',
              'x': [0, 1, 2, ..., 497, 498, 499],
              'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,
                          2.43304471])},
             {'fill': 'tonexty',
              'fillcolor': 'rgba(68, 68, 68, 0.3)',
              'line': {'width': 0},
              'marker': {'color': '#444'},
              'mode': 'lines',
              'name': 'Lower Bound',
              'showlegend': False,
              'type': 'scatter',
              'uid': 'a08f2c43-1513-4042-bd1e-f1f01e28a0ef',
              'x': [0, 1, 2, ..., 497, 498, 499],
              'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,
                         

In [2]:
visualize_per_token_category(performance_data, log_scale=True, checkpoint_mode=True, line_metric="Median", line_color='Orange' , shade_color="wheat")

FigureWidget({
    'data': [{'line': {'width': 0},
              'marker': {'color': 'wheat'},
              'mode': 'lines',
              'name': 'Upper Bound',
              'showlegend': False,
              'type': 'scatter',
              'uid': '274f7f1b-21af-41bb-8c00-6fa385439bff',
              'x': [0, 1, 2, ..., 497, 498, 499],
              'y': array([2.34006592, 2.41241021, 2.57781922, ..., 2.56474203, 2.59573629,
                          2.43304471])},
             {'fill': 'tonexty',
              'fillcolor': 'wheat',
              'line': {'width': 0},
              'marker': {'color': 'wheat'},
              'mode': 'lines',
              'name': 'Lower Bound',
              'showlegend': False,
              'type': 'scatter',
              'uid': '050ad540-2443-452c-8f2c-e4d218640318',
              'x': [0, 1, 2, ..., 497, 498, 499],
              'y': array([0.93626447, 0.9302987 , 0.99836227, ..., 0.95607835, 0.76146911,
                          0.81709211])}