In [None]:
!pip install seaborn

In [None]:
import glob
import json
import os
import pandas as pd
import re
import seaborn as sns

from collections import OrderedDict
from glom import glom

class tasks:
    """Named groups of evaluation task identifiers used to select rows for plotting.

    Every group is a plain ``list``: the plotting helpers treat a list-valued
    constraint as "match any of these" (and use it as the plot hue), so these
    must not be turned into tuples or sets.
    """

    # General-purpose tasks reported with an accuracy-style metric in this notebook.
    default = [
        "lambada",
        "piqa",
        "hellaswag",
        "winogrande",
        "mathqa",
        "pubmedqa",
    ]

    # Tasks plotted with "bits_per_byte" in this notebook.
    ppl = [
        "wikitext",
    ]

    # BLiMP sub-tasks, one entry per paradigm, kept in alphabetical order.
    blimp = [
        "blimp_adjunct_island",
        "blimp_anaphor_gender_agreement",
        "blimp_anaphor_number_agreement",
        "blimp_animate_subject_passive",
        "blimp_animate_subject_trans",
        "blimp_causative",
        "blimp_complex_NP__island",
        "blimp_coordinate_structure_constraint_complex_left_branch",
        "blimp_coordinate_structure_constraint_object_extraction",
        "blimp_determiner_noun_agreement_1",
        "blimp_determiner_noun_agreement_2",
        "blimp_determiner_noun_agreement_irregular_1",
        "blimp_determiner_noun_agreement_irregular_2",
        "blimp_determiner_noun_agreement_with_adj_1",
        "blimp_determiner_noun_agreement_with_adj_2",
        "blimp_determiner_noun_agreement_with_adj_irregular_1",
        "blimp_determiner_noun_agreement_with_adj_irregular_2",
        "blimp_distractor_agreement_relational_noun",
        "blimp_distractor_agreement_relative_clause",
        "blimp_drop_argument",
        "blimp_ellipsis_n_bar_1",
        "blimp_ellipsis_n_bar_2",
        "blimp_existential_there_object_raising",
        "blimp_existential_there_quantifiers_1",
        "blimp_existential_there_quantifiers_2",
        "blimp_existential_there_subject_raising",
        "blimp_expletive_it_object_raising",
        "blimp_inchoative",
        "blimp_intransitive",
        "blimp_irregular_past_participle_adjectives",
        "blimp_irregular_past_participle_verbs",
        "blimp_irregular_plural_subject_verb_agreement_1",
        "blimp_irregular_plural_subject_verb_agreement_2",
        "blimp_left_branch_island_echo_question",
        "blimp_left_branch_island_simple_question",
        "blimp_matrix_question_npi_licensor_present",
        "blimp_npi_present_1",
        "blimp_npi_present_2",
        "blimp_only_npi_licensor_present",
        "blimp_only_npi_scope",
        "blimp_passive_1",
        "blimp_passive_2",
        "blimp_principle_A_c_command",
        "blimp_principle_A_case_1",
        "blimp_principle_A_case_2",
        "blimp_principle_A_domain_1",
        "blimp_principle_A_domain_2",
        "blimp_principle_A_domain_3",
        "blimp_principle_A_reconstruction",
        "blimp_regular_plural_subject_verb_agreement_1",
        "blimp_regular_plural_subject_verb_agreement_2",
        "blimp_sentential_negation_npi_licensor_present",
        "blimp_sentential_negation_npi_scope",
        "blimp_sentential_subject_island",
        "blimp_superlative_quantifiers_1",
        "blimp_superlative_quantifiers_2",
        "blimp_tough_vs_raising_1",
        "blimp_tough_vs_raising_2",
        "blimp_transitive",
        "blimp_wh_island",
        "blimp_wh_questions_object_gap",
        "blimp_wh_questions_subject_gap",
        "blimp_wh_questions_subject_gap_long_distance",
        "blimp_wh_vs_that_no_gap",
        "blimp_wh_vs_that_no_gap_long_distance",
        "blimp_wh_vs_that_with_gap",
        "blimp_wh_vs_that_with_gap_long_distance",
    ]

class models:
    """Identifiers of the evaluated model runs, in the order used for plotting."""

    # NOTE: the attribute name ``all`` shadows the builtin only inside this
    # class namespace; it is kept because callers access ``models.all``.
    all = [
        'dense_small',
        'dense_medium',
        'dense_large',
    ]

class results_dir:
    """Collect evaluation result files from one directory into a long-format DataFrame.

    Every ``*.json`` file directly inside ``path`` is expected to be named
    ``<run_id>_eval_results_<timestamp>.json``. The run id is parsed into
    metadata columns using the first matching form:

    * ``<model>.global_step<N>``          -> ``model``, ``step``
    * ``<model>.<depth_method>.<layer>``  -> ``model``, ``layer``, ``depth_method``
    * ``<model>`` (no dots)               -> ``model``

    A trailing ``_checkpoints`` is stripped from the model name. The file
    content must be a JSON object shaped ``{task: {metric: value}}``; every
    (task, metric) pair becomes one row.
    """

    # Raw strings: "\." and "\d" inside ordinary string literals are invalid
    # escape sequences (DeprecationWarning, SyntaxWarning since Python 3.12).
    # The "." before "json" is escaped so it only matches a literal dot.
    rgx_toplevel = re.compile(r"^(.*)_eval_results_([0-9-]+)\.json$")
    rgx_plain = re.compile(r"^[^.]+$")
    rgx_step = re.compile(r"(^[^.]+)\.global_step(\d+)$")
    rgx_depth = re.compile(r"^([^.]+)\.([a-z_]*)\.(\d+)$")

    # Ordered categorical so that the "model" column sorts in the order of
    # ``models.all``; names outside that list become NaN on conversion.
    model_dtype = pd.CategoricalDtype(
        models.all,
        ordered=True
    )
    # Unordered categorical for the depth method parsed from the run id.
    depth_method_dtype = pd.CategoricalDtype(
        ["extra_linear", "final_linear", "logit_lens"]
    )

    def __init__(self, path):
        """Remember the directory ``path`` that :meth:`as_df` will scan."""
        self.path = path

    @staticmethod
    def rchop(s, suffix):
        """Return ``s`` without ``suffix`` if it ends with it; an empty suffix is a no-op."""
        if suffix and s.endswith(suffix):
            return s[:-len(suffix)]
        return s

    @staticmethod
    def canonical_model_name(name):
        """Strip a trailing ``_checkpoints`` from a model name."""
        return results_dir.rchop(name, "_checkpoints")

    @classmethod
    def parse_results_file_name(cls, file_name):
        """Parse a results file base name into an ordered metadata dict.

        Args:
            file_name: base name (no directory) of a results file.

        Returns:
            An ``OrderedDict`` with keys ``path`` (the given file name),
            ``timestamp`` and ``model``, plus either ``step`` (int) or
            ``layer`` (int) and ``depth_method`` (str) depending on the run-id
            form. Returns ``None`` after printing a warning when the name does
            not match any supported form.
        """
        res = OrderedDict()
        m = cls.rgx_toplevel.match(file_name)
        if m is None:
            print("WARNING: cannot parse results file '{}'".format(file_name))
            return None

        run_id = m[1]
        res['path'] = file_name
        res['timestamp'] = m[2]

        # Checkpoint run: "<model>.global_step<N>".
        m_step = cls.rgx_step.match(run_id)
        if m_step is not None:
            res['model'] = cls.canonical_model_name(m_step[1])
            res['step'] = int(m_step[2])
            return res

        # Per-layer run: "<model>.<depth_method>.<layer>".
        m_depth = cls.rgx_depth.match(run_id)
        if m_depth is not None:
            res['model'] = cls.canonical_model_name(m_depth[1])
            res['layer'] = int(m_depth[3])
            res['depth_method'] = m_depth[2]
            return res

        # Plain run: the whole run id (which contains no dots) is the model.
        m_plain = cls.rgx_plain.match(run_id)
        if m_plain is not None:
            res['model'] = cls.canonical_model_name(m_plain[0])
            return res

        print("WARNING: cannot parse run id '{}'".format(run_id))
        return None

    def as_df(self):
        """Read every parsable results file under ``self.path`` into one DataFrame.

        Returns:
            A DataFrame with one row per (file, task, metric) and columns
            ``task``, ``metric`` and ``value``, followed by the metadata parsed
            from the file name (see :meth:`parse_results_file_name`). A ``lens``
            column is always present (``None`` when no record provides it).
            Columns are cast to fixed dtypes where present: ``model`` and
            ``depth_method`` to categoricals, ``step``/``layer`` to nullable
            ``Int64``, ``value`` to float.

        Files whose name cannot be parsed or whose content is not valid JSON
        are skipped with a printed warning.
        """
        dict_list = []
        # Sorted so the row order does not depend on the filesystem listing order.
        for file_path in sorted(glob.glob(os.path.join(self.path, "*.json"))):
            meta = self.parse_results_file_name(os.path.basename(file_path))
            if meta is None:
                continue

            with open(file_path, encoding="utf-8") as f:
                try:
                    result_json = json.load(f)
                except ValueError:
                    # json.JSONDecodeError and UnicodeDecodeError both derive from
                    # ValueError; anything else is unexpected and should surface.
                    print("WARNING: cannot load file '{}'".format(file_path))
                    continue

            for task, metrics in result_json.items():
                for metric, value in metrics.items():
                    # Key order (task, metric, value, then metadata) fixes the
                    # resulting column order.
                    dict_list.append({"task": task, "metric": metric, "value": value, **meta})

        res = pd.json_normalize(dict_list)

        # Guarantee a "lens" column exists even if no record set one.
        if 'lens' not in res:
            res['lens'] = None

        def set_col_type(col_name, t):
            # Optional metadata columns (e.g. "step", "layer") exist only if
            # some file name carried them, so skip absent columns.
            if col_name in res:
                res[col_name] = res[col_name].astype(t)

        set_col_type('path', str)
        set_col_type('model', self.model_dtype)
        set_col_type('step', pd.Int64Dtype())
        set_col_type('layer', pd.Int64Dtype())
        set_col_type('depth_method', self.depth_method_dtype)
        set_col_type('task', str)
        set_col_type('metric', str)
        set_col_type('value', float)

        return res

class plot:
    """Seaborn line-plot helpers over a long-format results DataFrame.

    The frame is expected to have ``metric`` and ``value`` columns plus the
    metadata columns being filtered or plotted (``model``, ``task``, ``step``,
    ``layer``, ``depth_method``, ``lens``, ...).
    """

    @staticmethod
    def display_name(axis, value):
        """Return a human-readable label for ``value`` on the given ``axis``.

        Known depth methods get descriptive names and model names have
        underscores replaced by spaces; everything else is returned unchanged.
        """
        if axis == "depth_method":
            if value == "extra_linear":
                return "tuned extra projection"
            if value == "final_linear":
                return "tuned unembedding"
            if value == "logit_lens":
                return "logit lens"
        if axis == "model":
            return value.replace('_', ' ')
        return value

    @staticmethod
    def metric_display_name(metric):
        """Return a human-readable label for a metric name (unknown names unchanged)."""
        if metric == "acc":
            return "accuracy"
        if metric == "bits_per_byte":
            return "bits per byte"
        return metric

    @staticmethod
    def _null_or_missing(df, col):
        """Boolean mask that is True where ``df[col]`` is null, or everywhere if ``col`` is absent."""
        if col in df:
            return df[col].isnull()
        return pd.Series(True, index=df.index)

    @staticmethod
    def by_axis2(df, metric, axes, x_axis):
        """Plot ``metric`` against ``x_axis`` for the rows of ``df`` selected by ``axes``.

        Args:
            df: results frame containing ``metric``, ``value``, ``x_axis`` and
                every column named in ``axes``.
            metric: value of the ``metric`` column to keep (e.g. ``"acc"``).
            axes: sequence of ``(column, constraint)`` pairs. A scalar constraint
                keeps rows equal to it and contributes to the plot title; a
                list constraint keeps rows whose value is in the list and
                becomes the seaborn ``hue``.
            x_axis: column plotted on the x axis; rows where it is null are dropped.

        Returns:
            The filtered frame, with ``value`` renamed to the display metric name.

        Raises:
            ValueError: if more than one constraint is a list.
        """
        display_metric = plot.metric_display_name(metric)

        # All masks are built from ``df`` itself so they align with the frame
        # being filtered, whatever frame the caller passes in.
        expr = df[x_axis].notnull() & (df["metric"] == metric)
        hue = None
        title_parts = []
        n_lists = 0

        for axis, constraint in axes:
            if isinstance(constraint, list):
                expr = expr & df[axis].isin(constraint)
                hue = axis
                n_lists += 1
            else:
                expr = expr & (df[axis] == constraint)
                title_parts.append(str(plot.display_name(axis, constraint)))
        if n_lists > 1:
            raise ValueError("At most one of the axes should be a list")

        title = ", ".join(title_parts) if title_parts else None
        selected = df[expr].rename(columns={"value": display_metric})
        sns.lineplot(data=selected, x=x_axis, y=display_metric, hue=hue, marker='o').set_title(title)
        return selected

    @staticmethod
    def by_step(df, metric, model, task):
        """Plot ``metric`` over training ``step`` for the given model(s) and task(s).

        At most one of ``model`` / ``task`` may be a list (it becomes the hue).
        """
        return plot.by_axis2(df, metric, [("model", model), ("task", task)], "step")

    @staticmethod
    def by_layer(df, metric, model, task, depth_method=None):
        """Plot ``metric`` over ``layer`` for the given model(s), task(s) and depth method(s).

        ``depth_method`` defaults to all three methods
        (``["logit_lens", "extra_linear", "final_linear"]``), which makes it the
        hue; at most one of ``model`` / ``task`` / ``depth_method`` may be a list.
        """
        # None sentinel instead of a shared mutable list default.
        if depth_method is None:
            depth_method = ["logit_lens", "extra_linear", "final_linear"]
        return plot.by_axis2(df, metric, [("model", model), ("task", task), ("depth_method", depth_method)], "layer")

    @staticmethod
    def by_scale(df, metric, tasks):
        """Plot ``metric`` across models for the given tasks, one line per task.

        Only rows with no ``lens``, ``step`` or ``layer`` value are used; a
        column that does not exist at all counts as null for every row.
        Returns the filtered frame with ``value`` renamed to the display metric name.
        """
        display_metric = plot.metric_display_name(metric)

        expr = (
            plot._null_or_missing(df, "lens")
            & plot._null_or_missing(df, "step")
            & plot._null_or_missing(df, "layer")
            & df["task"].isin(tasks)
            & (df["metric"] == metric)
        )
        selected = df[expr].rename(columns={"value": display_metric})
        sns.lineplot(data=selected, x="model", y=display_metric, hue="task", marker='o')
        return selected

In [None]:
# Load every parsable results file into one long-format DataFrame.
# NOTE(review): machine-specific absolute path; adjust for your environment.
r = results_dir("/mnt/ssd-1/igor/gpt-neox/results").as_df()

In [None]:
# Enlarge the figure size used by all following seaborn plots.
sns.set(rc = {'figure.figsize':(15,8)})

In [None]:
# Accuracy vs. layer for dense_small on the default tasks (one line per task), extra_linear only.
# Each plot call sits in its own cell; presumably so each one draws on a fresh figure
# (seaborn draws onto the current axes when no `ax` is given) — keep them separate.
d = plot.by_layer(r, "acc", "dense_small", tasks.default, "extra_linear")

In [None]:
# Bits per byte vs. layer for dense_small on the perplexity tasks, extra_linear only.
d = plot.by_layer(r, "bits_per_byte", "dense_small", tasks.ppl, "extra_linear")

In [None]:
# Bits per byte vs. layer on wikitext, one line per model, extra_linear only.
d = plot.by_layer(r, "bits_per_byte", models.all, "wikitext", "extra_linear")

In [None]:
# Accuracy vs. layer on lambada, one line per model, extra_linear only.
d = plot.by_layer(r, "acc", models.all, "lambada", "extra_linear")

In [None]:
# Accuracy vs. layer for dense_small on hellaswag, one line per depth method (default list).
d = plot.by_layer(r, "acc", "dense_small", "hellaswag")

In [None]:
# Bits per byte vs. layer for dense_small on wikitext, one line per depth method.
d = plot.by_layer(r, "bits_per_byte", "dense_small", "wikitext")

In [None]:
# Accuracy vs. layer for dense_small on piqa, one line per depth method.
d = plot.by_layer(r, "acc", "dense_small", "piqa")

In [None]:
# Accuracy vs. layer for dense_small on lambada, one line per depth method.
d = plot.by_layer(r, "acc", "dense_small", "lambada")

In [None]:
# Accuracy vs. layer for dense_small on mathqa, one line per depth method.
d = plot.by_layer(r, "acc", "dense_small", "mathqa")

In [None]:
# Accuracy vs. training step for dense_medium on the default tasks (one line per task).
d = plot.by_step(r, "acc", "dense_medium", tasks.default)

In [None]:
# Bits per byte vs. training step for dense_medium on wikitext.
d = plot.by_step(r, "bits_per_byte", "dense_medium", "wikitext")

In [None]:
# Bits per byte vs. training step on wikitext, one line per model.
d = plot.by_step(r, "bits_per_byte", models.all, "wikitext")

In [None]:
# Accuracy vs. training step on lambada, one line per model.
d = plot.by_step(r, "acc", models.all, "lambada")

In [None]:
# Accuracy across models (rows with no lens/step/layer) for the default tasks, one line per task.
d = plot.by_scale(r, "acc", tasks.default)

In [None]:
# Bits per byte across models (rows with no lens/step/layer) for the perplexity tasks.
d = plot.by_scale(r, "bits_per_byte", tasks.ppl)