In [17]:
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.express as px
import statsmodels.api as sm
from scipy.optimize import curve_fit
import re
import json
from typing import Dict, List, Tuple

In [57]:
# Compiled model results (one CSV per model), keyed by the display name used in legends.
data_files = {
    "Tuned112-2B": "../results/compiled_gemma2b-tuned112.csv",
    "Tuned112-7B": "../results/compiled_gemma7b-tuned112.csv",
    "Pretrained-2B": "../results/compiled_gemma2b-pretrained.csv",
    "Pretrained-7B": "../results/compiled_gemma7b-pretrained.csv",
}

# (line colour, dash style) for every plotted series. Dash separates tuned (solid) from
# pretrained (dashdot) models, and the two sizes within a family use different shades,
# so every pairing of models (e.g. 2B vs 7B) stays visually distinguishable.
color_and_dash = {
    "Human": ('rgb(105,105,105)', 'solid'),
    "Tuned112-2B": ('royalblue', 'solid'),
    "Tuned112-7B": ('navy', 'solid'),
    "Pretrained-2B": ('skyblue', 'dashdot'),
    "Pretrained-7B": ('steelblue', 'dashdot'),
}

# Concept ids plotted in the trajectory figure (presumably the primitive / OR concepts,
# going by the name — confirm against the concept list).
primitives_or = [
    'hg03', 'hg04', 'hg18', 'hg19', 'hg20', 'hg24', 'hg09', 'hg25', 'hg06'
]

## P(Human) v. P(Model) Correlation Graph

In [33]:
def make_r2_graph(title: str, mdf: pd.DataFrame, bins: int = 15, show_raw_scatter=False):
    """
    Plot human P(Yes) against model P(Yes) and report their squared Pearson correlation.

    The background is a percent-normalised 2-D histogram (or a raw scatter when
    ``show_raw_scatter`` is set). On top of it the function draws the y = x diagonal,
    the mean human score inside each model-score bin with standard-error bars, and a
    quadratic least-squares fit through those bin means.

    :param str title: Title of the graph
    :param pd.DataFrame mdf: Accepts Pandas DataFrames as generated from `process_results.py`;
        the 'hyes' and 'myes' columns hold row-aligned stringified lists of scores
    :param int bins: Number of equal-width model-score bins on [0, 1] (also the heatmap resolution)
    :param bool show_raw_scatter: Draw a raw scatter instead of the density heatmap
    :returns:
        - figure
    :raises ValueError: if the parsed human and model score lists differ in total length
    """

    def func(x, b, c, d):
        # Quadratic trend fitted through the per-bin human means.
        return b * (x ** 2) + c * x + d

    def parse_scores(cell: str) -> list[float]:
        # Cells look like "[0.1, 0.2]" or "[0.1 0.2]": drop the brackets, treat commas as spaces.
        return [float(v) for v in re.sub(",", " ", cell[1:-1]).split()]

    all_hscores = np.array([s for cell in mdf['hyes'] for s in parse_scores(cell)], dtype=float)
    all_mscores = np.array([s for cell in mdf['myes'] for s in parse_scores(cell)], dtype=float)
    if len(all_hscores) != len(all_mscores):
        raise ValueError(
            f"'hyes' and 'myes' hold {len(all_hscores)} vs {len(all_mscores)} scores; they must align"
        )

    # Bin edges 0, 1/bins, ..., 1; the first edge is nudged up so exact zeros land in bin 0.
    x = np.arange(bins + 1) / bins
    x[0] += 0.00001
    # np.digitize maps scores equal to 1.0 (the last edge) to index bins + 1, which the loop
    # below never visits; fold them into the last bin so they are not silently dropped.
    mscores_binned = np.clip(np.digitize(all_mscores, x), 0, bins)

    valid_xs = []
    center_mean = []
    # NaN (not a sentinel value) so plotly simply leaves empty bins undrawn.
    plot_array_mean = np.full(len(x), np.nan)
    plot_array_se = np.zeros(len(x))

    for i in range(bins + 1):
        in_bin = all_hscores[mscores_binned == i]
        if in_bin.size == 0:
            continue
        mean = np.mean(in_bin)
        valid_xs.append(i / bins)
        center_mean.append(mean)
        plot_array_mean[i] = mean
        # Standard error of the bin mean (population std / sqrt(n)).
        plot_array_se[i] = np.std(in_bin) / np.sqrt(in_bin.size)

    labels = {'x': 'Model P(Yes)', 'y': "Human P(Yes)"}
    if show_raw_scatter:
        fig = px.scatter(x=all_mscores, y=all_hscores, labels=labels, title=title)
    else:
        fig = px.density_heatmap(x=all_mscores, y=all_hscores,
                                 labels=labels,
                                 height=500, width=500,
                                 histnorm='percent', color_continuous_scale=px.colors.sequential.dense,
                                 nbinsx=bins, nbinsy=bins,
                                 title=title)

    # y = x reference line (perfect agreement).
    fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], showlegend=False, mode='lines', marker_color='white'))

    # curve_fit needs at least as many points as parameters (3); skip the trend otherwise.
    if len(valid_xs) >= 3:
        popt, _ = curve_fit(func, valid_xs, center_mean)
        fig.add_trace(go.Scatter(x=x, y=func(x, *popt),
                                 mode='lines',
                                 line_width=1,
                                 showlegend=False,
                                 marker_color='slategray'))

    r2 = np.corrcoef(all_mscores, all_hscores)[0, 1] ** 2
    fig.add_annotation(x=0.9, y=0.05,
                       text=f"R<sup>2</sup>={r2:.2f}",
                       showarrow=False)

    # Bin means as markers ...
    fig.add_trace(go.Scatter(x=x, y=plot_array_mean,
                             mode='markers',
                             showlegend=False,
                             marker_color='slategray'))

    # ... and their standard-error bars (drawn on invisible markers).
    fig.add_trace(go.Scatter(x=x, y=plot_array_mean,
                             error_y=dict(type='data', array=plot_array_se, visible=True,
                                          thickness=1, color='lightslategray'),
                             marker_color='rgba(255,255,255,0)',
                             mode='markers',
                             showlegend=False))

    fig.update_layout(width=500, height=500)
    fig.update_yaxes(range=[-0.01, 1.01], dtick=0.1)
    fig.update_xaxes(range=[-0.01, 1.01], dtick=0.1)
    fig.update_layout(template='ggplot2', coloraxis_showscale=False, font_family='times new roman', font_size=14)
    fig.update_layout(margin=dict(l=10, r=10, t=30, b=10))

    return fig

In [48]:
# Which compiled results to plot; any key of `data_files` works
# ('Tuned112-2B', 'Tuned112-7B', 'Pretrained-2B', 'Pretrained-7B').
plot_item = 'Tuned112-7B'
make_r2_graph(plot_item, pd.read_csv(data_files[plot_item]))

## Trajectory Graph

In [49]:
def plot_trajectory_graphs(concepts: list[str],
                           sources: Dict[str, str],
                           color_dash_dict,
                           end_early=1000,
                           num_cols=3,
                           height_per_row=150,
                           width_per_col=400,
                           human_results_path: str = "../results/compiled_humans.csv",
                           labels_path: str = "../data/labels_to_readable.json"):
    """
    Plot human and model score trajectories, one subplot per concept.

    Each subplot shows the human score for the correct answer, one line per model in
    ``sources`` (its 'mscores' list), a dotted reference line at p**2 + (1-p)**2 where p is
    the fraction of items whose answer is 1, and an annotation with the squared Pearson
    correlation between human P(Yes) and each model's P(Yes).

    :param concepts: Concept ids; each must be a key of the labels JSON (used as subplot title)
    :param sources: Maps a legend name to the path of that model's compiled results CSV
    :param color_dash_dict: Maps "Human" and every name in ``sources`` to a (colour, dash) tuple
    :param end_early: Keep at most this many model scores per concept
    :param num_cols: Number of subplot columns
    :param height_per_row: Figure height (px) per subplot row
    :param width_per_col: Figure width (px) per subplot column
    :param human_results_path: CSV of compiled human results
    :param labels_path: JSON mapping concept ids to readable subplot titles
    :returns: plotly figure
    """

    def parse_list_cell(cell: str) -> list[float]:
        # Cells look like "[0.1, 0.2]" or "[0.1 0.2]": drop the brackets, treat commas as spaces.
        return [float(v) for v in re.sub(",", " ", cell[1:-1]).split()]

    hdf = pd.read_csv(human_results_path)
    with open(labels_path, 'r', encoding='utf-8') as f:
        labels_to_readable = json.load(f)
    # Read every model CSV once up front rather than once per concept.
    model_dfs = {name: pd.read_csv(fp) for name, fp in sources.items()}

    num_rows = -(-len(concepts) // num_cols)  # ceiling division

    fig = make_subplots(rows=num_rows, cols=num_cols,
                        subplot_titles=[labels_to_readable[x] for x in concepts],
                        shared_yaxes=True)

    for i, concept in enumerate(concepts):
        row, col = i // num_cols + 1, i % num_cols + 1
        concept_hdf = hdf[hdf['concepts'] == concept]

        # Human yes/no counts -> proportions.
        hyes = concept_hdf['hyes'].to_numpy().astype(int)
        hno = concept_hdf['hno'].to_numpy().astype(int)
        htotals = hyes + hno
        hyes, hno = hyes / htotals, hno / htotals

        answer_idx = concept_hdf['answers'].to_numpy().astype(int)
        yes_rate = np.mean(answer_idx)
        # Stack rows (P(no), P(yes)) and pick, per item, the row indexed by its answer.
        hscore = np.vstack((hno, hyes))[answer_idx, np.arange(len(hyes))]

        fig.add_trace(go.Scatter(y=hscore,
                                 x=np.arange(len(hscore)),
                                 name='Human',
                                 line_color=color_dash_dict['Human'][0],
                                 showlegend=i == 0, legendgroup="Human L2"),
                      row=row, col=col)

        # Presumably the expected score of guessing "yes" at the empirical yes-rate.
        fig.add_hline(y=(yes_rate ** 2) + ((1 - yes_rate) ** 2), line_color='black',
                      line_width=2,
                      line_dash="dot",
                      row=row, col=col)

        exmax = len(hscore)
        r_squared = {}
        for name, mdf in model_dfs.items():
            model_row = mdf[mdf['concept'] == concept].iloc[0]
            mscore = np.array(parse_list_cell(model_row['mscores'])[:end_early])
            myes = np.array(parse_list_cell(model_row['myes'])[:end_early])

            fig.add_trace(go.Scatter(y=mscore,
                                     x=np.arange(len(mscore)),
                                     name=name,
                                     line_color=color_dash_dict[name][0],
                                     line_dash=color_dash_dict[name][1],
                                     showlegend=i == 0,
                                     legendgroup=name),
                          row=row, col=col)

            # Correlate over the overlapping prefix of the human and model series.
            corrmax = min(len(hyes), len(myes))
            exmax = min(exmax, len(mscore))
            r_squared[name] = np.corrcoef(hyes[:corrmax], myes[:corrmax])[0, 1] ** 2

        annotation = "R²: " + " | ".join(f"{k}={v:.2f}" for k, v in r_squared.items())
        fig.add_annotation(x=exmax, y=0, text=annotation, xanchor='right',
                           showarrow=False, row=row, col=col)

    fig.update_annotations(font_size=14, font_family='times new roman')
    fig.update_yaxes(showticklabels=True, dtick=0.25, range=[-0.1, 1.1])
    fig.update_layout(template='ggplot2', width=width_per_col * num_cols, height=height_per_row * num_rows)
    fig.update_layout(legend=dict(orientation="h", entrywidth=120, yanchor="top", y=-0.08,
                                  xanchor="right",
                                  x=1))
    fig.update_layout(margin=dict(l=70, r=50, t=40, b=60))

    return fig

In [50]:
def make_rule_threes(filepath: str, metric=lambda x: x['pyes_corr'] ** 2):
    """
    Select nine representative concepts from a compiled results CSV.

    Rows are scored with ``metric`` (applied row-wise; default: squared 'pyes_corr'),
    ranked highest first, and the concepts at the top three, the middle three and the
    bottom three ranks are returned.

    :param filepath: CSV with a 'concept' column plus whatever columns ``metric`` reads
    :param metric: Callable mapping one row (pd.Series) to a sortable score
    :returns: pd.Series of concept names in rank order, keeping the CSV's row index
    """
    results = pd.read_csv(filepath)
    ranked = (
        results
        .assign(sort=results.apply(metric, axis=1))
        .sort_values('sort', ascending=False)['concept']
    )
    mid = len(results) // 2
    top, middle, bottom = [0, 1, 2], [mid - 1, mid, mid + 1], [-3, -2, -1]
    return ranked.iloc[top + middle + bottom]

In [58]:
# Compare the two tuned model sizes on the primitive concepts.
# Other comparisons used: ['Pretrained-2B', 'Tuned112-2B'] and ['Pretrained-7B', 'Tuned112-7B'].
# make_rule_threes(<results csv>) can supply `concepts` instead of primitives_or.
plot_trajectory_graphs(
    concepts=primitives_or,
    sources={name: data_files[name] for name in ['Tuned112-2B', 'Tuned112-7B']},
    color_dash_dict=color_and_dash,
)