In [None]:
import wandb
api = wandb.Api()

# Just load seaborn & set theme and the chart looks better:
! pip install seaborn -q
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
import seaborn as sns
import numpy as np
import matplotlib.colors as cor
from collections import defaultdict 
import pandas as pd
sns.set_theme()

# set matplotlib dpi to 200 to make the images bigger
plt.rcParams['figure.dpi'] = 200
# Notebook-level dict that WandbRunData uses to store fetched wandb run objects
wandb_cache = {}

In [None]:
# Fetch the wandb run collections that the plotting cells below filter and plot.
#_runs_pile_code = api.runs("seperability/seperability-pile-code")
_runs_pile_code_old = api.runs("seperability/seperability-pile-code")
_runs_pile_code_new = api.runs("seperability/new-method-compare")
# Combined list of runs from both projects above (old project first, then new)
_runs_pile_code = [ *_runs_pile_code_old, *_runs_pile_code_new ]
_runs_code_python = api.runs("seperability/seperability-code-python")
#_runs_attn = api.runs("seperability/pile-code-attn")
# NOTE(review): in the visible cells this flag is only referenced from commented-out code
use_fmt_map = False

In [None]:
import re

def extract_name_number_unit(s):
    """Turn a model-style name such as ``'opt-1.3b'`` into a sortable key.

    Args:
        s: The string to parse (e.g. ``'Pile Code opt-125m'``).

    Returns:
        A tuple ``(name, size, s)`` where
        * ``name`` is the first run of letters/spaces in ``s`` ('' if none),
        * ``size`` is the first number in ``s`` scaled by its unit suffix
          ('m' -> x1, 'b' -> x1000, case-insensitive; no/unknown suffix -> x1),
          or ``inf`` when ``s`` contains no number,
        * ``s`` is the original string, kept as a final tie-breaker.
    """
    name_match = re.search(r'[a-zA-Z ]+', s)
    name = name_match.group() if name_match else ''

    # No sign in the pattern: the '-' in names like 'opt-125m' is a separator,
    # not a minus sign, so it must not make the parsed size negative.
    number_match = re.search(r'\d+(?:\.\d*)?|\.\d+', s)
    number = float(number_match.group()) if number_match else float('inf')

    # Letters that directly follow a digit or '.' are treated as the unit suffix.
    unit_match = re.search(r'(?<=[\d.])[a-zA-Z]+', s)
    unit = unit_match.group().lower() if unit_match else ''

    # Default multiplier is 1 so that a missing/unknown unit neither zeroes the
    # number nor turns the "no number" sentinel into nan (inf * 0).
    unit_value = {'m': 1, 'b': 1000}.get(unit, 1)

    return (name, number * unit_value, s)


# Quick sanity check of the parser on a sample model name
print( extract_name_number_unit('model-1.2M') )

In [None]:
# Internal metric key -> name shown in plot titles / table columns
metric_map = dict(
    base="Top1",
    topk="Top10",
    skip="Skip50-Top1",
    topk_skip="Skip50-Top10",
    loss="Loss",
    perplexity="Perplexity",
)

# Dataset key -> matplotlib colour name
color_map = dict(pile="tab:orange", code="tab:blue", python="tab:green")

# Dataset key -> human-readable dataset label
dataset_map = dict(
    pile="Pile",
    pile_codeless="Pile",
    python="Python",
    code="Code",
)

# Lower-cased model name -> matplotlib format string for its curve
fmt_map = {
    # OPT
    "opt-125m": ":y",
    "opt-1.3b": "--y",
    "opt-6.7b": "-yo",
    # Galactica
    "galactica-125m": ":C2",
    "galactica-1.3b": "--C2",
    "galactica-6.7b": "-C2s",
    # Pythia
    "pythia-160m": ":C4",
    "pythia-1.4b": "--C4",
    "pythia-6.9b": "-C4^",
    # RoBERTa
    "roberta-large": "-rx",
}

# Model family -> matplotlib format string
fmt_map_loss = dict(
    opt="-yo",
    galactica="-C2s",
    pythia="-C4^",
    roberta="-rx",
)

use_fmt_map = False

def df_append(df, item: dict):
    """Return a new DataFrame made of `df` followed by one row built from `item`.

    Each key of `item` becomes a column and its value the single cell of the
    new row; the resulting index is renumbered from 0.
    """
    single_row_columns = {}
    for column, value in item.items():
        single_row_columns[column] = [value]
    row = pd.DataFrame(single_row_columns)
    return pd.concat([df, row], ignore_index=True)

def is_loss_metric(metric_name):
    """Return True for "perplexity" or any metric name ending in "loss"."""
    return metric_name == "perplexity" or metric_name.endswith("loss")

def normed(h, dataset, key):
    """Return the `key` metric for `dataset`, divided by its first entry.

    * "perplexity": exp of the "loss_data/{dataset}/loss" column, relative to
      its first value.
    * other loss metrics: the "loss_data/{dataset}/{key}" column relative to
      its first value.
    * everything else: the "accuracy/{dataset}/{key}" column relative to its
      first value.

    For the two loss branches, a NaN first entry yields an array of ones.
    """
    def _as_float_array(column):
        return np.array([float(value) for value in column])

    if key == "perplexity":
        losses = _as_float_array(h[f"loss_data/{dataset}/loss"])
        if np.isnan(losses[0]):
            return np.ones_like(losses)
        perplexities = np.exp(losses)
        return perplexities / perplexities[0]

    if is_loss_metric(key):
        values = _as_float_array(h[f"loss_data/{dataset}/{key}"])
        if np.isnan(values[0]):
            return np.ones_like(values)
        return values / values[0]

    accuracy = h[f"accuracy/{dataset}/{key}"]
    return accuracy / accuracy[0]

def calculate_area(x, y):
    """Integrate `y` over `x` with the composite trapezoidal rule.

    Args:
        x: Sample positions.
        y: Values at those positions (same length as `x`).

    Returns:
        The trapezoidal estimate of the area under `y` as a function of `x`.
    """
    # NumPy 2.0 renamed `trapz` to `trapezoid` and deprecated the old name;
    # prefer the new one and fall back for older NumPy versions.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid(y, x)

class WandbRunData:
    """History and config of a single wandb run, plus derived plot metrics.

    Attributes set by ``__init__``:
        run_name:       run path with any leading "https://wandb.ai/" removed.
        run:            the wandb run object (shared through ``wandb_cache``).
        history / h:    result of ``run.history()``.
        model_name:     last path component of ``config["model_size"]``, lower-cased.
        frac:           ``history["_step"]`` times the larger of ``ff_frac`` and
                        ``attn_frac`` (presumably the fraction pruned so far).
        cripple, focus: dataset keys from the run config.
        cripple_label, focus_label: display labels from ``dataset_map``.
        name_set:       "<focus label> <cripple label>".
        name_set_model: ``name_set`` followed by the model name.

    The metric methods also set ``self.scale`` as a side effect; the plotting
    helpers read it afterwards.
    """

    def __init__(self, run_name):
        """Load the run identified by `run_name` (a wandb URL or run path)."""
        self.run_name = run_name.split("https://wandb.ai/")[-1]

        # Fetch each run at most once; the cache is keyed by the normalised
        # run path so URL and bare-path spellings share one entry.
        if self.run_name not in wandb_cache:
            wandb_cache[self.run_name] = api.run(self.run_name)
        self.run = wandb_cache[self.run_name]

        self.history = self.run.history()
        self.h = self.history
        c = self.run.config

        self.model_name = c["model_size"].split("/")[-1].lower()

        max_frac = max( c["ff_frac"], c["attn_frac"] )
        self.frac = self.history["_step"] * max_frac

        self.cripple, self.focus = c["cripple"], c["focus"]
        self.cripple_label = dataset_map[self.cripple]
        self.focus_label   = dataset_map[self.focus]

        # Get "unique" names
        self.name_set = self.focus_label + " " + self.cripple_label
        self.name_set_model = self.name_set + " " + self.model_name

    def get_metric(self, metric, diff):
        """Dispatch to the loss / diagonal / differential computation.

        Loss metrics ignore `diff`. Note the tuple order differs by branch:
        loss metrics give ``(focus, cripple, area)``, the others
        ``(d_cripple, d_focus, area)``.
        """
        if is_loss_metric(metric):
            return self.get_loss_metric(metric)
        if diff:
            return self.get_diag_metric(metric)
        return self.get_d_metric(metric)

    def get_loss_metric(self, metric):
        """Return ``(focus_perf, cripple_perf, area)`` for a loss-type metric.

        Both series are normalised to their first value; ``area`` is twice the
        trapezoidal integral of cripple over focus. Sets ``self.scale = 1``.
        """
        self.scale = 1
        focus_perf     = normed(self.h, self.focus,   metric) * self.scale
        cripple_perf   = normed(self.h, self.cripple, metric) * self.scale

        area = calculate_area(focus_perf, cripple_perf)*2 / (self.scale**2)
        return focus_perf, cripple_perf, area

    def get_diag_metric(self, metric):
        """Return ``(d_cripple, d_focus, area)`` with drops measured in percent.

        ``area`` is the integral of the cripple drop over the focus drop,
        minus the area under the diagonal (scale**2 / 2), as a fraction of
        scale**2. Sets ``self.scale = 100``.
        """
        self.scale = 100
        focus_perf     = normed(self.h, self.focus,   metric) * self.scale
        cripple_perf   = normed(self.h, self.cripple, metric) * self.scale

        d_cripple, d_focus = self.scale-cripple_perf, self.scale-focus_perf
        area = ( calculate_area(d_focus, d_cripple) - self.scale**2/2 )/(self.scale**2)
        return d_cripple, d_focus, area

    def get_d_metric(self, metric):
        """Return ``(d_cripple, d_focus, area)`` with drops measured in percent.

        ``area`` is twice the integral of (cripple drop - focus drop) over the
        focus drop, as a fraction of scale**2. Sets ``self.scale = 100``.
        """
        self.scale = 100
        focus_perf     = normed(self.h, self.focus,   metric) * self.scale
        cripple_perf   = normed(self.h, self.cripple, metric) * self.scale

        d_cripple, d_focus = self.scale-cripple_perf, self.scale-focus_perf
        area = calculate_area(d_focus, d_cripple-d_focus)*2 / (self.scale**2)
        return d_cripple, d_focus, area

    def get_max_diff(self, metric, reversed=False):
        """Largest gap between focus and cripple performance (in percent).

        With ``reversed=True`` the gap is cripple minus focus instead of focus
        minus cripple. Sets ``self.scale = 100``.
        """
        self.scale = 100
        focus_perf     = normed(self.h, self.focus,   metric) * self.scale
        cripple_perf   = normed(self.h, self.cripple, metric) * self.scale

        return self.calculate_max_diff(focus_perf, cripple_perf, reversed=reversed)

    def calculate_max_diff(self, x, y, reversed=False):
        """Return max(x - y), or max(y - x) when `reversed` is true."""
        x, y = np.array(x), np.array(y)
        diff = (y-x) if reversed else (x-y)
        return diff.max()

def plot_frac_pruned(run_obj, metric):
    """Plot normalised focus and cripple performance against `run_obj.frac`.

    Draws both curves in a new figure, shades the region between them,
    prints the metric name with the area ratio, and returns that ratio.

    Args:
        run_obj: A WandbRunData-like object (uses h, focus, cripple, frac,
            focus_label, cripple_label, model_name).
        metric: Key into `metric_map` / the run history.

    Returns:
        (focus_area - cripple_area) / focus_area, where each area is the
        trapezoidal integral of the normalised curve over `run_obj.frac`.
    """
    r = run_obj
    focus_perf     = normed(r.h, r.focus,   metric)
    cripple_perf   = normed(r.h, r.cripple, metric)
    metric_name = metric_map[metric]

    # Begin plotting
    plt.figure()
    plt.plot(r.frac, focus_perf,   label=r.focus_label, color="tab:orange")
    plt.plot(r.frac, cripple_perf, label=r.cripple_label, color="tab:blue")
    # Shade between the two curves on the same x values as the lines above
    # (the previous code referenced an undefined name `x` here).
    plt.fill_between(r.frac, focus_perf, cripple_perf, color="tab:purple", alpha=0.2)

    # Add details
    plt.xlim(-0.01, 1)
    plt.ylim(-0.01, None)
    plt.xlabel("Fraction of Model Pruned")
    plt.ylabel("Fraction of Original Accuracy")
    plt.title(f"{metric_name} Accuracy ({r.model_name})")
    plt.legend()

    cripple_area = calculate_area(r.frac, cripple_perf)
    focus_area   = calculate_area(r.frac, focus_perf)
    area_ratio = (focus_area-cripple_area)/focus_area
    print(metric_name, "%.3f" % area_ratio)

    return area_ratio

In [None]:
def plot_metric(run_obj, metric, diff=True, label=None):
    """Plot `metric` for `run_obj`, choosing the loss or accuracy plotter.

    Returns whatever the selected plotter returns (the area score).
    """
    plotter = plot_loss_metric if is_loss_metric(metric) else plot_perf_metric
    return plotter(run_obj, metric, diff, label)
    
def plot_perf_metric(run_obj, metric, diff=True, label=None): 
    """Add one run's accuracy-drop curve to the figure for its dataset pair.

    The figure is identified by `name_set + metric name`, so runs sharing a
    dataset pair and metric are drawn onto the same axes. With `diff` true the
    y axis shows cripple drop minus focus drop; otherwise the raw cripple drop.
    Dashed grey guide lines are drawn in both modes. Returns the area score
    obtained from `run_obj.get_metric`.
    """
    run = run_obj
    d_cripple, d_focus, area = run.get_metric(metric, diff)
    label = run.model_name if label is None else label

    # Use the model's registered line format when there is one
    fmt = fmt_map.get(run.model_name)
    fmt_args = (fmt,) if fmt else ()

    metric_name = metric_map[metric]
    scale = run.scale  # set by get_metric above
    guide_style = dict(color="darkgray", linestyle="--", alpha=0.2)

    plt.figure(run.name_set + metric_name)
    plt.xlabel(f"Drop in {run.focus_label} {metric_name} Accuracy (∆%)")

    if diff:
        plt.ylabel(f"Differential Drop in {run.cripple_label} Performance (∆%)")
        plt.plot(d_focus, d_cripple - d_focus, *fmt_args, label=label)
        plt.plot([0, scale], [0, 0], **guide_style)
        plt.plot([0, scale], [scale, 0], **guide_style)
    else:
        plt.ylabel(f"Drop in {run.cripple_label} {metric_name} Accuracy (∆%)")
        # Marker size is only set when an explicit format string is used
        marker_kwargs = {"markersize": 3} if fmt else {}
        plt.plot(d_focus, d_cripple, *fmt_args, label=label, **marker_kwargs)
        plt.plot([0, scale], [scale, scale], **guide_style)
        plt.plot([0, scale], [0, scale], **guide_style)

    plt.legend()

    return area

def plot_loss_metric(run_obj, metric, diff=False, label=None):
    """Add one run's loss/perplexity curve (log-log, base 2) to its figure.

    The x axis is the focus dataset's relative increase and the y axis the
    cripple dataset's relative increase, both shown as "Nx" multiples.

    Args:
        run_obj: WandbRunData-like object.
        metric: Loss-type metric key (e.g. "loss", "perplexity").
        diff: Passed through to `run_obj.get_metric` (ignored for loss metrics).
        label: Legend label; defaults to the run's model name.

    Returns:
        The area score reported by `run_obj.get_metric`.
    """
    r = run_obj
    print(r.model_name)

    # For loss metrics get_metric returns (focus, cripple, area) -- focus first.
    # The previous code unpacked these under swapped names, so the axis labels
    # did not match the plotted data; the curves themselves are unchanged here.
    focus_perf, cripple_perf, area = r.get_metric(metric, diff)

    peak = max([*focus_perf, *cripple_perf, 1])
    scale = min([peak, 100])
    x_scale = y_scale = min([peak, 130])
    if label is None:
        label = r.model_name

    # Plot differences in ability
    metric_name = metric_map[metric]
    fig = plt.figure(r.name_set+metric_name)
    plt.title(f"{metric_name} Accuracy, {r.cripple_label} Cripple {r.focus_label} Focus")

    fmt = fmt_map[r.model_name] if r.model_name in fmt_map else None

    # Custom formatter function for both axes: show ticks as multiples ("4x")
    def times_formatter(x, pos):
        return f'{x:.0f}x'

    plt.xlabel(f"Increase in {r.focus_label} {metric_name}")
    plt.ylabel(f"Increase in {r.cripple_label} {metric_name}")
    if fmt is not None:
        plt.loglog(focus_perf, cripple_perf, fmt, label=label, base=2)
    else:
        plt.loglog(focus_perf, cripple_perf, label=label, base=2)

    # Only once some value reaches 100x: draw the y = x guide and clamp the axes
    if scale == 100:
        plt.loglog([1,100], [1,100], base=2, color="darkgray", linestyle="--", alpha=0.2)
        plt.xlim([0.6, x_scale*1.01])
        plt.ylim([0.6, y_scale*1.01])
    ax = fig.gca()
    ax.xaxis.set_major_formatter(FuncFormatter(times_formatter))
    ax.yaxis.set_major_formatter(FuncFormatter(times_formatter))

    plt.legend()

    return area

In [None]:
def plot_best_metrics(run_names, diff=True, best_run_obj=None):
    """Plot, for every dataset-pair/model group, the run with the best score.

    Args:
        run_names: Iterable of wandb run URLs/paths. Only used when
            `best_run_obj` is None.
        diff: Forwarded to `get_metric` / `plot_metric`.
        best_run_obj: Optional result of a previous call; when given, no runs
            are fetched and this mapping is plotted directly.

    Returns:
        Mapping ``name_set_model -> [(main_metric_area, run_obj, run_summary), ...]``
        so that a later call can reuse the loaded runs.

    Side effects: draws the figures, prints a summary table of the selected
    runs and calls ``plt.show()``.
    """
    metrics_to_plot = ["base", "perplexity"] # metric_map.keys()
    main_metric = "base"

    # GET RUN DATA
    if best_run_obj is None:
        best_run_obj = defaultdict(list)
        for run_name in run_names:
            run_obj = WandbRunData(run_name)
            model_name = run_obj.run.config["model_size"].split("/")[-1]
            run_summary, run_areas = {"Model": model_name}, {}
            for metric in metrics_to_plot:
                _, _, area = run_obj.get_metric(metric, diff)
                run_areas[metric] = area
                # Only accuracy-type metrics are reported in the table
                if not is_loss_metric(metric):
                    run_summary[metric_map[metric]] = area
            best_run_obj[run_obj.name_set_model].append(
                (run_areas[main_metric], run_obj, run_summary)
            )

    sort_key = lambda k: extract_name_number_unit(k[0])

    # PLOT RUNS BY BEST METRIC
    summaries = []
    for _run_type, run_list in sorted(best_run_obj.items(), key=sort_key):
        # Pick the run with the largest main-metric area within this group
        best_index = np.argmax([ a for (a, _b, _c) in run_list ])
        _, run_obj, run_summary = run_list[best_index]
        for metric in metrics_to_plot:
            plot_metric(run_obj, metric, diff)
        summaries.append(run_summary)

    # Build the table once instead of concatenating row by row
    df = pd.DataFrame(summaries)
    print(df.to_string(float_format=lambda x:"%.3f"%x))
    plt.show()

    return best_run_obj
   
def plot_all_metrics(run_name_label, diff=False, run_list=None): 
    """Plot every given run with a custom legend label.

    Args:
        run_name_label: Iterable of ``(run_name, label)`` pairs. Only used
            when `run_list` is None.
        diff: Forwarded to `plot_metric` (differential vs. raw drop plot).
        run_list: Optional result of a previous call; when given, no runs are
            fetched and this list is plotted directly.

    Returns:
        List of ``(run_obj, run_summary)`` tuples so a later call can reuse
        the loaded runs.

    Side effects: draws the figures, prints a summary table and calls
    ``plt.show()``.
    """
    metrics_to_plot = ["base", "perplexity"] # metric_map.keys()

    if run_list is None:
        run_list = []
        for (run_name, run_label) in run_name_label:
            run_obj = WandbRunData(run_name)
            model_name = run_obj.run.config["model_size"].split("/")[-1]
            run_summary = {"Model": model_name, "label": run_label}
            for metric in metrics_to_plot:
                score = run_obj.get_max_diff(metric)
                # Only accuracy-type metrics are reported in the table
                if not is_loss_metric(metric):
                    run_summary["score"+metric_map[metric]] = score
            run_list.append( (run_obj, run_summary) )

    summaries = []
    for (run_obj, run_summary) in run_list:
        for metric in metrics_to_plot:
            # Honour the caller's `diff` (it used to be hard-coded to False here)
            plot_metric(run_obj, metric, diff=diff, label=run_summary["label"])
        summaries.append(run_summary)

    # Build the table once instead of concatenating row by row
    df = pd.DataFrame(summaries)
    print(df.to_string(float_format=lambda x:"%.3f"%x))
    plt.show()

    return run_list
 
def filter_crashed_runs(runs, run_limit=None):
    """Return the URLs of the usable runs in `runs`.

    A run is skipped when its state is "crashed", when its summary has no
    "_step" entry or a "_step" below 1, or when its config "ff_frac" is not
    positive.

    Args:
        runs: Iterable of wandb run objects (needs state, summary, config, url).
        run_limit: Maximum number of URLs to return; None means no limit.

    Returns:
        List of run URLs, in iteration order, at most `run_limit` long.
    """
    runs_filtered = []
    for run in runs:
        # `>=` so that exactly `run_limit` runs are kept (previously `>` let
        # one extra run through before stopping).
        if run_limit is not None and len(runs_filtered) >= run_limit:
            break
        if run.state == "crashed":
            continue
        if "_step" not in run.summary or run.summary["_step"] < 1:
            continue
        if run.config["ff_frac"] <= 0:
            continue
        runs_filtered.append(run.url)
    return runs_filtered

### Bias and Iterative vs Non-Iterative

In [None]:
# Plot specific runs for appendix
# Each list holds (wandb run URL, legend label) pairs; every list is plotted
# as its own comparison by plot_all_metrics below.
runs_bias_offset = [
   ("https://wandb.ai/seperability/pile-code-attn/runs/9v07rssi", "no bias offset"),
   ("https://wandb.ai/seperability/pile-code-attn/runs/ihe734z1", "bias mean offset"),
]
runs_iter = [
   #("https://wandb.ai/seperability/method-compare/runs/v8crnxp8", "iterative"),
   #("https://wandb.ai/seperability/method-compare/runs/91w9ssrz", "single step"),
   ("https://wandb.ai/seperability/method-compare/runs/5svab3oy", "single step"),
   ("https://wandb.ai/seperability/method-compare/runs/20tvinil", "iterative"),
]
runs_value_preout = [
   #("https://wandb.ai/seperability/method-compare/runs/12n1hg1e", "pre-out"),
   ("https://wandb.ai/seperability/method-compare/runs/e2b2a7cv", "value"),
   ("https://wandb.ai/seperability/pile-code-attn/runs/c4ypwu6k", "pre-out")
]
# Note: run 20tvinil also appears in runs_iter above, under a different label
runs_both = [
   ("https://wandb.ai/seperability/method-compare/runs/peepo846", "2% FF 0.5% Attn"),
   ("https://wandb.ai/seperability/method-compare/runs/20tvinil", "2% FF Only")
]
# Called with the default diff=False; each call prints a table and shows its figures
plot_all_metrics(runs_bias_offset)
plot_all_metrics(runs_iter)
plot_all_metrics(runs_value_preout)
plot_all_metrics(runs_both)

## Normal Plots

### Plot A B C

In [None]:
# Keep the usable runs from the "new-method-compare" project and plot the best
# run per dataset-pair/model group; the returned mapping is kept in new_run_list.
runs_filtered_new = filter_crashed_runs(_runs_pile_code_new, run_limit=None)
new_run_list = plot_best_metrics(runs_filtered_new, False) 

In [None]:
# Usable runs from the combined (old + new) pile/code projects
runs_filtered = filter_crashed_runs(_runs_pile_code, run_limit=None)
print(runs_filtered)
# This call used to pass `run_list` as a pre-loaded cache, but no such name is
# defined at notebook scope, so a fresh kernel raised NameError. Let
# plot_best_metrics load the runs listed in runs_filtered instead.
plot_best_metrics(runs_filtered, False)

In [None]:
# Same best-run plots for the "seperability-code-python" project.
# Note: this rebinds `runs_filtered` from the previous cell.
runs_filtered = filter_crashed_runs(_runs_code_python, run_limit=None)
plot_best_metrics(runs_filtered, False) 