In [10]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import plotly.graph_objs as go
import numpy as np
import seaborn as sns

def logit(x, c, eps=0):
    """Rescale ``x`` relative to the lower bound ``c`` and apply a logit.

    ``x`` is first mapped with ``(x - c) / (1 - c)`` (so ``x == c`` becomes 0
    and ``x == 1`` becomes 1), then transformed with
    ``log((x2 + eps) / (1 - x2 + eps))``.

    Parameters
    ----------
    x : scalar or numpy array
        Values to transform.
    c : scalar
        Lower bound used for the rescaling; must differ from 1.
    eps : scalar, optional
        Offset added to both numerator and denominator. With the default of 0,
        ``x == c`` yields ``-inf`` and ``x == 1`` yields ``inf``.

    Returns
    -------
    tuple
        ``(transformed values, c)`` -- ``c`` is passed through unchanged.
    """
    x2 = (x - c) / (1 - c)
    # Compute the transform once; no local named `logit`, which would shadow
    # this function inside its own body.
    return np.log((x2 + eps) / (1 - x2 + eps)), c

# Offset constant; not referenced by any of the cells visible here.
start = 0

# How many of the most frequent model families (by row count) are kept
# when the per-scenario data is filtered before plotting.
num_fam = 15

In [11]:
# Score columns to visualise; each entry is used as a column name in both
# data/data_v1.csv and data/lower_bounds.csv, and gets its own 3-D scatter.
scenarios = [
    'IFEval',
    'BBH',
    'MATH Lvl 5',
    'GPQA',
    'MUSR',
    'MMLU-PRO',
    'MMLU',
    'ARC',
    'HellaSwag',
    'Winogrande',
    'TruthfulQA',
    'GSM8K',
]

In [14]:
# Per-scenario lower bounds: the first row of data/lower_bounds.csv, one column
# per scenario. The file is read once instead of once per scenario.
# In this cell only the commented-out logit transform below consumes them.
lower_bounds_table = pd.read_csv('data/lower_bounds.csv')
lower_bounds = {s: lower_bounds_table[s].iloc[0] for s in scenarios}

# Load the model table and derive the log-scale features once; neither depends
# on the scenario, so re-reading the CSV on every loop iteration was wasted work.
data_full = pd.read_csv('data/data_v1.csv')
data_full['logS'] = np.log(data_full['#Params (B)'])
data_full['logT'] = np.log(data_full['Pretraining Data Size (T)'])
data_full['logF'] = np.log(data_full['FLOPs (1E21)'])
base_columns = ['Model', 'Family', 'Instruct', 'logS', 'logT', 'logF']

for scenario in scenarios:
    print(f"************ {scenario} ************")

    # Keep rows whose `Instruct` flag is False, restrict to the columns needed
    # for this scenario, and drop rows with any missing value among them.
    data = data_full.loc[~data_full.Instruct, base_columns + [scenario]]
    data = data.dropna()

    # Restrict to the `num_fam` families with the most remaining rows.
    selected_families = list(data.Family.value_counts().iloc[:num_fam].index)
    data = data.loc[data.Family.isin(selected_families)]

    #y,c = logit(np.array(data[scenario]), c=lower_bounds[scenario])
    y = np.array(data[scenario])
    logS = np.array(data['logS'])
    logT = np.array(data['logT'])
    logF = np.array(data['logF'])  # not plotted here; kept in case later cells read it
    Fam = np.array(data['Family'])

    # One distinct colour per family, as hex strings that plotly accepts.
    unique_fams = np.unique(Fam)
    palette = sns.color_palette("hsv", len(unique_fams)).as_hex()
    fam_colors = {fam: palette[i] for i, fam in enumerate(unique_fams)}

    # One Scatter3d trace per family so each family gets its own legend entry.
    traces = []
    for fam in unique_fams:
        fam_indices = np.where(Fam == fam)  # positions of this family's rows
        trace = go.Scatter3d(
            x=logS[fam_indices],
            y=logT[fam_indices],
            z=y[fam_indices],
            mode='markers',
            marker=dict(
                size=8,
                color=fam_colors[fam],
                opacity=0.8
            ),
            name=fam
        )
        traces.append(trace)

    # Label the z axis with the scenario being plotted so the figures can be
    # told apart; a fixed 'y' made every plot look the same.
    layout = go.Layout(
        scene=dict(
            xaxis=dict(title='logS'),
            yaxis=dict(title='logT'),
            zaxis=dict(title=scenario),
        ),
        margin=dict(l=0, r=0, b=0, t=0),
    )

    fig = go.Figure(data=traces, layout=layout)
    fig.show()

************ IFEval ************


************ BBH ************


************ MATH Lvl 5 ************


************ GPQA ************


************ MUSR ************


************ MMLU-PRO ************


************ MMLU ************


************ ARC ************


************ HellaSwag ************


************ Winogrande ************


************ TruthfulQA ************


************ GSM8K ************
