# Analyzing Runs

In [124]:
import os
import json
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
import pandas as pd

# MODEL_DIR = "models/good_sweep_1"
MODEL_DIR = "models/good_sweep_1_fixed_256"
MAIN_COLORS = px.colors.qualitative.Vivid
TRANSLUCENT_COLORS = [pastel_color.replace("rgb", "rgba").replace(")", ", 0.5)") for pastel_color in MAIN_COLORS]
DARK_COLORS = px.colors.qualitative.Dark24

mu_data = {}
mu_dirs = [dir.path for dir in os.scandir(MODEL_DIR) 
    if dir.is_dir() and 
    dir.path.split('/')[-1].split('-')[1][14:] == "mu"]

print(MAIN_COLORS)

['rgb(229, 134, 6)', 'rgb(93, 105, 177)', 'rgb(82, 188, 163)', 'rgb(153, 201, 69)', 'rgb(204, 97, 176)', 'rgb(36, 121, 108)', 'rgb(218, 165, 27)', 'rgb(47, 138, 196)', 'rgb(118, 78, 159)', 'rgb(237, 100, 90)', 'rgb(165, 170, 153)']


In [125]:
fig = px.colors.qualitative.swatches()
fig.show()

In [126]:
print(len(mu_dirs))

for mu_dir in mu_dirs:
    try:
        with open(os.path.join(mu_dir, "plot_data.json"), "r") as f:
            name = mu_dir.split('/')[-1]
            full_json = json.load(f)
            mu_data[name] = {key: {
                "d_m_bd": full_json[key]["d_m_bd"],
                "d_m_bc": full_json[key]["d_m_bc"],
                "d_m_js": full_json[key]["d_m_js"],
                "d_m_t": full_json[key]["d_m_t"],
                "d_bd_t": full_json[key]["d_bd_t"],
                "d_bc_t": full_json[key]["d_bc_t"],
                "d_js_t": full_json[key]["d_js_t"],
                "c_m_bd": full_json[key]["c_m_bd"],
                "c_m_bc": full_json[key]["c_m_bc"],
                "c_m_js": full_json[key]["c_m_js"],
                "c_m_t": full_json[key]["c_m_t"],
                "c_bd_t": full_json[key]["c_bd_t"],
                "c_bc_t": full_json[key]["c_bc_t"],
                "c_js_t": full_json[key]["c_js_t"]
            } for key in full_json.keys()}
    except:
        print(f"Could not find plot_data.json in {mu_dir}")

print(len(mu_data))
print(mu_data.keys())

576
Could not find plot_data.json in models/good_sweep_1_fixed_256/transformer-plotting_type=mu-family=bert-n_positions=64-pretraining_sigma_tasks=1-pretraining_mu_tasks=8-default_sigma=1-family=bert-hid_dim=4-mlp=0-layer=8-head=8-tokenizer=learnable
Could not find plot_data.json in models/good_sweep_1_fixed_256/transformer-plotting_type=mu-family=bert-n_positions=64-pretraining_sigma_tasks=1-pretraining_mu_tasks=4-default_sigma=1-family=bert-hid_dim=4-mlp=0-layer=8-head=8-tokenizer=learnable
Could not find plot_data.json in models/good_sweep_1_fixed_256/transformer-plotting_type=mu-family=bert-n_positions=64-pretraining_sigma_tasks=1-pretraining_mu_tasks=1-default_sigma=1-family=bert-hid_dim=4-mlp=0-layer=1-head=8-tokenizer=learnable
Could not find plot_data.json in models/good_sweep_1_fixed_256/transformer-plotting_type=mu-family=bert-n_positions=64-pretraining_sigma_tasks=1-pretraining_mu_tasks=32-default_sigma=1-family=bert-hid_dim=4-mlp=0-layer=8-head=8-tokenizer=learnable
Could n

In [127]:
data_dict = {"pretraining_mu_tasks": [], "hid_dim": [], "layer": [], "head": [], "default_sigma": [], "converged": [], "converged_epoch": [], "data": []}

for model in mu_data:
    model_params = {model_arg.split('=')[0]: model_arg.split('=')[1] for model_arg in model.split('-')[1:]}
    keys_to_show = ("pretraining_mu_tasks", "hid_dim", "layer", "head", "default_sigma")
    model_params = {key: model_params[key] for key in model_params if key in keys_to_show}

    data_dict["pretraining_mu_tasks"].append(int(model_params["pretraining_mu_tasks"]))
    data_dict["hid_dim"].append(int(model_params["hid_dim"]))
    data_dict["layer"].append(int(model_params["layer"]))
    data_dict["head"].append(int(model_params["head"]))
    data_dict["default_sigma"].append(float(model_params["default_sigma"]))
    data_dict["converged"].append(False if len(mu_data[model].keys()) == 100000//50 + 1 else True)
    data_dict["converged_epoch"].append(max([int(key) for key in mu_data[model].keys()]))
    data_dict["data"].append(mu_data[model])

df = pd.DataFrame(data_dict)
df

Unnamed: 0,pretraining_mu_tasks,hid_dim,layer,head,default_sigma,converged,converged_epoch,data
0,2,256,8,4,1.0,True,2310,"{'0': {'d_m_bd': 0.8136696991633173, 'd_m_bc':..."
1,1,256,2,4,1.0,True,2040,"{'0': {'d_m_bd': 0.754486548269042, 'd_m_bc': ..."
2,32,4,8,1,1.0,False,100000,"{'0': {'d_m_bd': 1.2458030804020015, 'd_m_bc':..."
3,1,16,1,1,1.0,True,2100,"{'0': {'d_m_bd': 0.8799767237827255, 'd_m_bc':..."
4,2,16,8,8,1.0,True,2400,"{'0': {'d_m_bd': 1.2398886814151278, 'd_m_bc':..."
...,...,...,...,...,...,...,...,...
535,128,16,1,4,1.0,False,100000,"{'0': {'d_m_bd': 1.3970999650172546, 'd_m_bc':..."
536,2,64,2,1,1.0,True,2430,"{'0': {'d_m_bd': 2.802920816694452, 'd_m_bc': ..."
537,8,256,4,2,1.0,True,9900,"{'0': {'d_m_bd': 0.6682861388834096, 'd_m_bc':..."
538,64,64,2,8,1.0,False,100000,"{'0': {'d_m_bd': 1.171240290254922, 'd_m_bc': ..."


In [128]:
df_no_data = df.drop(columns=["data"])
# sort by pretraining_mu_tasks, then hid_dim, then layer, then head, then default_sigma
df_no_data = df_no_data.sort_values(by=["pretraining_mu_tasks", "hid_dim", "layer", "head", "default_sigma"])

df_no_data.to_csv("mu_data.csv", index=False)

## Models that didn't converge

In [129]:
# Save all the entries of mu_data that do not have 100000 epochs
models_that_failed_to_converge = []
for mu in mu_data.keys():
    if len(mu_data[mu].keys()) == 100000//50 + 1:
        models_that_failed_to_converge.append(mu)

print(f"Out of the {len(mu_data.keys())} trained models, {len(models_that_failed_to_converge)} failed to converge.")
for model in models_that_failed_to_converge:
    # model_name = model.split('-')[0]
    model_params = {model_arg.split('=')[0]: model_arg.split('=')[1] for model_arg in model.split('-')[1:]}
    keys_to_show = ("pretraining_mu_tasks", "hid_dim", "layer", "head", "default_sigma")
    model_params = {key: model_params[key] for key in model_params if key in keys_to_show}
    
    if model_params["default_sigma"] == "1":
        continue
    print(model_params)

Out of the 540 trained models, 186 failed to converge.


## Figure 2 Raventos

In [130]:
def plot_figure(plot_data=[], title="", width=600, height=425, scale=1):
    fig = go.Figure()
    
    for data in plot_data:
        fig.add_trace(go.Scatter(
            x=[i+1 for i in range(len(data["data"]))], 
            y=data["data"], 
            name=data["name"], 
            mode="lines+markers",
            marker={'symbol': data["marker"], 'color': TRANSLUCENT_COLORS[data["color_index"]], 'size': 10 * scale, 'line': {'width': 1.5 * scale, 'color': MAIN_COLORS[data["color_index"]]}}, 
            line={'color': MAIN_COLORS[data["color_index"]], 'width': 1.5 * scale}
        ))

        if max(data["data"]) > 6:
            fig.update_layout(
                yaxis_range=[-0.1, 6],
            )

    fig.update_layout(
        width=width*scale, 
        height=height*scale, 
        font=dict(
            # family="Courier New, monospace",
            size=20 * scale,
            # color="RebeccaPurple"
        ),
        # yaxis_range=[-0.1, 2],
        xaxis_title="Number of mu vectors",
        yaxis_title="Normalized Risk",
        # title=title,
        margin=dict(l=0, r=10, t=10, b=0),
        legend=dict(x=0.98, y=0.98, bordercolor="Black", borderwidth=1)
    )
    
    fig.update_layout(
        plot_bgcolor='white'
    )
    fig.update_xaxes(
        mirror=True,
        ticks='outside',
        showline=True,
        linecolor='black',
        gridcolor='lightgrey',
        ticktext=[str(2**i) for i in range(len(plot_data[0]["data"]))],
        tickvals=[i+1 for i in range(len(plot_data[0]["data"]))],
        tickmode="array"
    )
    fig.update_yaxes(
        mirror=True,
        ticks='outside',
        showline=True,
        linecolor='black',
        gridcolor='lightgrey'
    )

    return fig

In [131]:
def plot_figure_two(hid_dim=16, layer=2, head=2, default_sigma=0.2):
    df_figure_2_models = df.loc[(df["hid_dim"] == hid_dim) & (df["layer"] == layer) & (df["head"] == head) & (df["default_sigma"] == default_sigma)]
    df_figure_2_models = df_figure_2_models.sort_values(by=["pretraining_mu_tasks"])

    get_estimator_at_epoch = lambda df_subset, estimator : df_subset.apply(lambda row : row["data"][str(row["converged_epoch"])][estimator], axis=1).tolist()

    disc_bayes_disc = get_estimator_at_epoch(df_figure_2_models, "d_bd_t")
    disc_bayes_cont = get_estimator_at_epoch(df_figure_2_models, "d_bc_t")
    disc_james_stein = get_estimator_at_epoch(df_figure_2_models, "d_js_t")
    disc_model = get_estimator_at_epoch(df_figure_2_models, "d_m_t")

    disc_fig = plot_figure(
        plot_data = [
            {"name": "Bayes Discrete", "data": disc_bayes_disc, "color_index": 0, "marker": "circle"},
            {"name": "Bayes Continuous", "data": disc_bayes_cont, "color_index": 1, "marker": "circle"},
            {"name": "James Stein", "data": disc_james_stein, "color_index": 2, "marker": "circle"},
            {"name": "Transformer", "data": disc_model, "color_index": 3, "marker": "triangle-up"}
        ],
        title=f"Discrete Prior, m={hid_dim}, L={layer}, h={head}, sigma2={default_sigma}"
    )

    disc_fig.show()

    cont_bayes_disc = get_estimator_at_epoch(df_figure_2_models, "c_bd_t")
    cont_bayes_cont = get_estimator_at_epoch(df_figure_2_models, "c_bc_t")
    cont_james_stein = get_estimator_at_epoch(df_figure_2_models, "c_js_t")
    cont_model = get_estimator_at_epoch(df_figure_2_models, "c_m_t")

    cont_fig = plot_figure(
        plot_data = [
            {"name": "Bayes Discrete", "data": cont_bayes_disc, "color_index": 0, "marker": "circle"},
            {"name": "Bayes Continuous", "data": cont_bayes_cont, "color_index": 1, "marker": "circle"},
            {"name": "James Stein", "data": cont_james_stein, "color_index": 2, "marker": "circle"},
            {"name": "Transformer", "data": cont_model, "color_index": 3, "marker": "triangle-up"}
        ],
        title=f"Continuous Prior, m={hid_dim}, L={layer}, h={head}, sigma2={default_sigma}"
    )

    cont_fig.show()

In [132]:
# plot_figure_two(hid_dim=16, layer=1, head=1, default_sigma=0.2)
# plot_figure_two(hid_dim=16, layer=1, head=2, default_sigma=0.2)
# plot_figure_two(hid_dim=16, layer=2, head=1, default_sigma=0.2)
# plot_figure_two(hid_dim=16, layer=2, head=2, default_sigma=0.2)
plot_figure_two(hid_dim=4, layer=2, head=2, default_sigma=1)
plot_figure_two(hid_dim=16, layer=2, head=2, default_sigma=1)
plot_figure_two(hid_dim=64, layer=2, head=2, default_sigma=1)
plot_figure_two(hid_dim=256, layer=2, head=2, default_sigma=1)


In [133]:
plot_figure_two(hid_dim=4, layer=1, head=1, default_sigma=1)
plot_figure_two(hid_dim=16, layer=1, head=1, default_sigma=1)
plot_figure_two(hid_dim=64, layer=1, head=1, default_sigma=1)
plot_figure_two(hid_dim=256, layer=1, head=1, default_sigma=1)

In [134]:
def plot_figure_paper(plot_data=[], title="", width=600, height=425, scale=1):
    fig = go.Figure()
    
    for data in plot_data:
        if data["marker"] == "None":
            fig.add_trace(go.Scatter(
                x=[i+1 for i in range(len(data["data"]))], 
                y=data["data"], 
                name=data["name"], 
                mode="lines",
                line={'color': DARK_COLORS[data["color_index"]] if isinstance(data["color_index"], int) else data["color_index"], 'width': 5 * scale, "dash": 'dot'}
            ))
        else:
            fig.add_trace(go.Scatter(
                x=[i+1 for i in range(len(data["data"]))], 
                y=data["data"], 
                name=data["name"], 
                mode="lines+markers",
                marker={'symbol': data["marker"], 'color': 'rgba(0, 0, 0, 0)', 'size': 10 * scale, 'line': {'width': 1.5 * scale, 'color': DARK_COLORS[data["color_index"]]}}, 
                line={'color': DARK_COLORS[data["color_index"]], 'width': 1.5 * scale}
            ))

        if max(data["data"]) > 6:
            fig.update_layout(
                yaxis_range=[-0.1, 6],
            )

    fig.update_layout(
        width=width*scale, 
        height=height*scale, 
        font=dict(
            # family="Courier New, monospace",
            size=20 * scale,
            # color="RebeccaPurple"
        ),
        # yaxis_range=[-0.1, 2],
        xaxis_title="Number of mu vectors",
        yaxis_title="Normalized Risk",
        # title=title,
        margin=dict(l=0, r=10, t=10, b=0),
        legend=dict(x=0.98, y=0.98, bordercolor="Black", borderwidth=1)
    )
    
    fig.update_layout(
        plot_bgcolor='white'
    )
    fig.update_xaxes(
        mirror=True,
        ticks='outside',
        showline=True,
        linecolor='black',
        gridcolor='lightgrey',
        ticktext=[str(2**i) for i in range(len(plot_data[0]["data"]))],
        tickvals=[i+1 for i in range(len(plot_data[0]["data"]))],
        tickmode="array"
    )
    fig.update_yaxes(
        mirror=True,
        ticks='outside',
        showline=True,
        linecolor='black',
        gridcolor='lightgrey'
    )

    return fig

In [135]:
def plot_figure_two_hid_dims(layer=2, head=2, default_sigma=0.2):
    disc_model = []
    cont_model = []
    for hid_dim in [4, 16, 64, 256]:
        df_figure_2_models = df.loc[(df["hid_dim"] == hid_dim) & (df["layer"] == layer) & (df["head"] == head) & (df["default_sigma"] == default_sigma)]
        df_figure_2_models = df_figure_2_models.sort_values(by=["pretraining_mu_tasks"])

        get_estimator_at_epoch = lambda df_subset, estimator : df_subset.apply(lambda row : row["data"][str(row["converged_epoch"])][estimator], axis=1).tolist()

        disc_model.append({"hid_dim": hid_dim, "data": get_estimator_at_epoch(df_figure_2_models, "d_m_t")})
        cont_model.append({"hid_dim": hid_dim, "data": get_estimator_at_epoch(df_figure_2_models, "c_m_t")})
    
    disc_bayes_disc = get_estimator_at_epoch(df_figure_2_models, "d_bd_t")
    disc_bayes_cont = get_estimator_at_epoch(df_figure_2_models, "d_bc_t")
    disc_james_stein = get_estimator_at_epoch(df_figure_2_models, "d_js_t")
    
    cont_bayes_disc = get_estimator_at_epoch(df_figure_2_models, "c_bd_t")
    cont_bayes_cont = get_estimator_at_epoch(df_figure_2_models, "c_bc_t")
    cont_james_stein = get_estimator_at_epoch(df_figure_2_models, "c_js_t")

    disc_fig = plot_figure_paper(
        plot_data = [
            {"name": "BD", "data": disc_bayes_disc, "color_index": "black", "marker": "None"},
            {"name": "BC", "data": disc_bayes_cont, "color_index": "red", "marker": "None"},
            {"name": "JS", "data": disc_james_stein, "color_index": "blue", "marker": "None"},
        ] + [{"name": f"TF, m={data['hid_dim']}", "data": data["data"], "color_index": 8 + i, "marker": "circle"} 
             for i, data in enumerate(disc_model)],
        title=f"Discrete Prior, L={layer}, h={head}, sigma2={default_sigma}"
    )

    # disc_fig.show()

    cont_fig = plot_figure_paper(
        plot_data = [
            {"name": "BD", "data": cont_bayes_disc, "color_index": "black", "marker": "None"},
            {"name": "BC", "data": cont_bayes_cont, "color_index": "red", "marker": "None"},
            {"name": "JS", "data": cont_james_stein, "color_index": "blue", "marker": "None"},
        ] + [{"name": f"TF, m={data['hid_dim']}", "data": data["data"], "color_index": 8 + i, "marker": "circle"} 
             for i, data in enumerate(cont_model)],
        title=f"Continuous Prior, L={layer}, h={head}, sigma2={default_sigma}"
    )

    # cont_fig.show()

    return disc_fig, cont_fig

In [136]:
def plot_figure_two_layers(hid_dim=2, head=2, default_sigma=0.2):
    disc_model = []
    cont_model = []
    for layer in [1, 2, 4, 8]:
        df_figure_2_models = df.loc[(df["hid_dim"] == hid_dim) & (df["layer"] == layer) & (df["head"] == head) & (df["default_sigma"] == default_sigma)]
        df_figure_2_models = df_figure_2_models.sort_values(by=["pretraining_mu_tasks"])

        get_estimator_at_epoch = lambda df_subset, estimator : df_subset.apply(lambda row : row["data"][str(row["converged_epoch"])][estimator], axis=1).tolist()

        disc_model.append({"layer": layer, "data": get_estimator_at_epoch(df_figure_2_models, "d_m_t")})
        cont_model.append({"layer": layer, "data": get_estimator_at_epoch(df_figure_2_models, "c_m_t")})
    
    disc_bayes_disc = get_estimator_at_epoch(df_figure_2_models, "d_bd_t")
    disc_bayes_cont = get_estimator_at_epoch(df_figure_2_models, "d_bc_t")
    disc_james_stein = get_estimator_at_epoch(df_figure_2_models, "d_js_t")
    
    cont_bayes_disc = get_estimator_at_epoch(df_figure_2_models, "c_bd_t")
    cont_bayes_cont = get_estimator_at_epoch(df_figure_2_models, "c_bc_t")
    cont_james_stein = get_estimator_at_epoch(df_figure_2_models, "c_js_t")

    disc_fig = plot_figure_paper(
        plot_data = [
            {"name": "BD", "data": disc_bayes_disc, "color_index": "black", "marker": "None"},
            {"name": "BC", "data": disc_bayes_cont, "color_index": "red", "marker": "None"},
            {"name": "JS", "data": disc_james_stein, "color_index": "blue", "marker": "None"},
        ] + [{"name": f"TF, L={data['layer']}", "data": data["data"], "color_index": 8 + i, "marker": "circle"} 
             for i, data in enumerate(disc_model)],
        title=f"Discrete Prior, m={hid_dim}, h={head}, sigma2={default_sigma}"
    )

    cont_fig = plot_figure_paper(
        plot_data = [
            {"name": "BD", "data": cont_bayes_disc, "color_index": "black", "marker": "None"},
            {"name": "BC", "data": cont_bayes_cont, "color_index": "red", "marker": "None"},
            {"name": "JS", "data": cont_james_stein, "color_index": "blue", "marker": "None"},
        ] + [{"name": f"TF, L={data['layer']}", "data": data["data"], "color_index": 8 + i, "marker": "circle"} 
             for i, data in enumerate(cont_model)],
        title=f"Continuous Prior, m={hid_dim}, h={head}, sigma2={default_sigma}"
    )

    return disc_fig, cont_fig

In [137]:
def plot_figure_two_heads(hid_dim=16, layer=2, default_sigma=0.2):
    disc_model = []
    cont_model = []
    for head in [1, 2, 4, 8]:
        df_figure_2_models = df.loc[(df["hid_dim"] == hid_dim) & (df["layer"] == layer) & (df["head"] == head) & (df["default_sigma"] == default_sigma)]
        df_figure_2_models = df_figure_2_models.sort_values(by=["pretraining_mu_tasks"])

        get_estimator_at_epoch = lambda df_subset, estimator : df_subset.apply(lambda row : row["data"][str(row["converged_epoch"])][estimator], axis=1).tolist()

        disc_model.append({"head": head, "data": get_estimator_at_epoch(df_figure_2_models, "d_m_t")})
        cont_model.append({"head": head, "data": get_estimator_at_epoch(df_figure_2_models, "c_m_t")})
    
    disc_bayes_disc = get_estimator_at_epoch(df_figure_2_models, "d_bd_t")
    disc_bayes_cont = get_estimator_at_epoch(df_figure_2_models, "d_bc_t")
    disc_james_stein = get_estimator_at_epoch(df_figure_2_models, "d_js_t")
    
    cont_bayes_disc = get_estimator_at_epoch(df_figure_2_models, "c_bd_t")
    cont_bayes_cont = get_estimator_at_epoch(df_figure_2_models, "c_bc_t")
    cont_james_stein = get_estimator_at_epoch(df_figure_2_models, "c_js_t")

    disc_fig = plot_figure_paper(
        plot_data = [
            {"name": "BD", "data": disc_bayes_disc, "color_index": "black", "marker": "None"},
            {"name": "BC", "data": disc_bayes_cont, "color_index": "red", "marker": "None"},
            {"name": "JS", "data": disc_james_stein, "color_index": "blue", "marker": "None"},
        ] + [{"name": f"TF, H={data['head']}", "data": data["data"], "color_index": 8 + i, "marker": "circle"} 
             for i, data in enumerate(disc_model)],
        title=f"Discrete Prior, m={hid_dim}, L={layer}, sigma2={default_sigma}"
    )

    # disc_fig.show()

    cont_fig = plot_figure_paper(
        plot_data = [
            {"name": "BD", "data": cont_bayes_disc, "color_index": "black", "marker": "None"},
            {"name": "BC", "data": cont_bayes_cont, "color_index": "red", "marker": "None"},
            {"name": "JS", "data": cont_james_stein, "color_index": "blue", "marker": "None"},
        ] + [{"name": f"TF, H={data['head']}", "data": data["data"], "color_index": 8 + i, "marker": "circle"} 
             for i, data in enumerate(cont_model)],
        title=f"Continuous Prior, m={hid_dim}, L={layer}, sigma2={default_sigma}"
    )

    # cont_fig.show()

    return disc_fig, cont_fig

In [139]:
for hid_dim, head, layer, sigma in [(16, 4, 4, 1)]:
    disc_fig, cont_fig = plot_figure_two_hid_dims(layer=layer, head=head, default_sigma=sigma)
    disc_fig.show()
    cont_fig.show()
    disc_fig, cont_fig = plot_figure_two_layers(hid_dim=hid_dim, head=head, default_sigma=sigma)
    disc_fig.show()
    cont_fig.show()
    disc_fig, cont_fig = plot_figure_two_heads(hid_dim=hid_dim, layer=layer, default_sigma=sigma)
    disc_fig.show()
    cont_fig.show()
    # disc_fig.write_image(f"plots/trueCompDiscPriorL{layer}H{head}S{str(sigma).replace('.', '')}.pdf")
    # cont_fig.write_image(f"plots/trueCompContPriorL{layer}H{head}S{str(sigma).replace('.', '')}.pdf")

In [13]:
def plot_figure_two_b_hid_dims(comparison="", layer=2, head=2, default_sigma=1):
    disc_model = []
    cont_model = []
    for hid_dim in [4, 16, 64, 256]:
        df_figure_2_models = df.loc[(df["hid_dim"] == hid_dim) & (df["layer"] == layer) & (df["head"] == head) & (df["default_sigma"] == default_sigma)]
        df_figure_2_models = df_figure_2_models.sort_values(by=["pretraining_mu_tasks"])

        get_estimator_at_epoch = lambda df_subset, estimator : df_subset.apply(lambda row : row["data"][str(row["converged_epoch"])][estimator], axis=1).tolist()

        disc_model.append({"hid_dim": hid_dim, "data": get_estimator_at_epoch(df_figure_2_models, f"d_{comparison}")})
        cont_model.append({"hid_dim": hid_dim, "data": get_estimator_at_epoch(df_figure_2_models, f"c_{comparison}")})

    disc_fig = plot_figure(
        [{"name": f"m={data['hid_dim']}", "data": data["data"], "color_index": i, "marker": "circle"} 
             for i, data in enumerate(disc_model)],
        title=f"Discrete Prior {comparison}, L={layer}, h={head}, sigma2={default_sigma}"
    )
    # disc_fig.show()

    cont_fig = plot_figure(
        [{"name": f"m={data['hid_dim']}", "data": data["data"], "color_index": i, "marker": "circle"} 
             for i, data in enumerate(cont_model)],
        title=f"Continuous Prior {comparison}, L={layer}, h={head}, sigma2={default_sigma}"
    )
    # cont_fig.show()

    return disc_fig, cont_fig

In [14]:
for head, layer, sigma in [(2, 2, 1)]:
    disc_fig, cont_fig = plot_figure_two_b_hid_dims(comparison="m_bd", layer=layer, head=head, default_sigma=sigma)
    disc_fig.write_image(f"plots/bayesDiscCompDiscPriorL{layer}H{head}S{str(sigma).replace('.', '')}.pdf")
    cont_fig.write_image(f"plots/bayesDiscCompContPriorL{layer}H{head}S{str(sigma).replace('.', '')}.pdf")
    disc_fig, cont_fig = plot_figure_two_b_hid_dims(comparison="m_bc", layer=layer, head=head, default_sigma=sigma)
    disc_fig.write_image(f"plots/bayesContCompDiscPriorL{layer}H{head}S{str(sigma).replace('.', '')}.pdf")
    cont_fig.write_image(f"plots/bayesContCompContPriorL{layer}H{head}S{str(sigma).replace('.', '')}.pdf")