In [None]:
from pathlib import Path
import pandas as pd
import numpy as np
import re

# Widen pandas' console/HTML rendering so large result tables are not truncated.
# set_option accepts several (pattern, value) pairs in a single call.
pd.set_option(
    'display.max_rows', 500,
    'display.max_columns', 500,
    'display.width', 150,
)

import matplotlib.pyplot as plt

import plotly.graph_objects as go
from plotly.subplots import make_subplots

# %matplotlib inline
# Project root is taken to be the parent of the current working directory
# (presumably the notebook lives in a sub-folder of the project — confirm).
PROJECT_ROOT = Path.cwd().parent
data_path = PROJECT_ROOT / 'data'
print(PROJECT_ROOT)

In [None]:
# Load the cross-task results and floor each prompt length to a 5-token bucket.
data = pd.read_csv(data_path.joinpath('cross_task.csv'))
data['len_bucket'] = data['prompt_tokens'] // 5 * 5

# Per-prompt medians of every numeric column. numeric_only=True is required on
# pandas >= 2.0, where median() raises TypeError on non-numeric columns instead of
# silently dropping them; on older pandas it matches the previous default.
# NOTE(review): if prompt_tokens varies within a prompt, the median of len_bucket
# can fall between buckets (e.g. 7.5) — confirm lengths are constant per prompt.
medians = (
    data.groupby(['prompt_id', 'prompt_name', 'prompt_task'])
    .median(numeric_only=True)
    .reset_index()
)


In [None]:

# Shared lookup tables. NOTE(review): none of these are referenced by the plotting
# code below; they are kept in case other cells rely on them.
task_colors = {'CLASSIFICATION': '#1f77b4',
 'COMPLETION': '#ff7f0e',
 'ENTAILMENT': '#2ca02c',
 'MCQ': '#d62728',
 'QA': '#9467bd',
 'SENTIMENT': '#8c564b',
 'SUMMARIZATION': '#e377c2'}

# Raw dataset identifiers -> short display names.
GROUP_NAME_MAPPING = {
    "anli[dev_r1]"                            : "ANLI R1",
    "anli[dev_r2]"                            : "ANLI R2",
    "anli[dev_r3]"                            : "ANLI R3",
    "CommitmentBank[validation]"              : "CB",
    "AQuA[validation]"                        : "AQuA",
    "WordsinContext[validation]"              : "WiC",
    "RecognizingTextualEntailment[validation]": "RTE",
    "craigslist_bargains[validation]"         : "Craigslist",
}
valid_metrics = ['f1', 'em', 'accuracy']

figure_object = {}

inverse_map_names = {v: k for k, v in GROUP_NAME_MAPPING.items()}
subplot_titles = sorted(inverse_map_names)

# Scatter of each prompt's median rank against its median length, one panel per metric.
panel_titles = ["Accuracy", "F1"]
fig = make_subplots(rows=1, cols=len(panel_titles), subplot_titles=panel_titles)
for col_idx, metric in enumerate(panel_titles, start=1):
    rank_scatter = go.Scatter(
        x=medians['prompt_tokens'].tolist(),
        y=medians[f'{metric.lower()}_rank'].tolist(),
        mode='markers',
        showlegend=False,
    )
    # add_trace(row, col) replaces the deprecated append_trace.
    fig.add_trace(rank_scatter, row=1, col=col_idx)

# Layout is identical for every panel, so it is applied once, outside the loop.
fig.update_layout(
    title_x=0.5,
    font=dict(size=15),
    template="plotly_white",
    legend_orientation='h',
    legend=dict(xanchor="center", x=0.5, bgcolor="rgba(0,0,0,0)"),
    width=1000,
    height=600,
)
fig.update_xaxes(title_text='Length (# of Tokens)')
# Best (smallest) rank at the top. A descending range both reverses the axis and
# pins it to 0-100 on both panels. Previously an explicit autorange="reversed"
# overrode the separate range=[0,100], so that range was never applied.
# NOTE(review): assumes ranks lie within 0-100; points outside are clipped — confirm.
fig.update_yaxes(title_text='Median Rank', range=[100, 0])
fig.write_image('scatter_ranks_graphs.png')
fig.show()


In [None]:
# Box plots of per-prompt median rank grouped by prompt-length bucket, one panel per
# metric. Each panel's tick labels report how many prompts fall in each bucket.

# Bucket width; must match the `// 5 * 5` bucketing used when len_bucket was created.
bucket_width = 5
panel_titles = ["Accuracy", "F1"]

fig = make_subplots(rows=1, cols=len(panel_titles), subplot_titles=panel_titles)
# Bucket positions 0, 5, ... up to the largest bucket present; shared by both panels.
tick_vals = list(range(0, int(np.max(medians['len_bucket'])) + bucket_width, bucket_width))

for col_idx, metric in enumerate(panel_titles, start=1):
    rank_col = f'{metric.lower()}_rank'
    rank_box = go.Box(
        x=medians['len_bucket'].tolist(),
        y=medians[rank_col].tolist(),
        showlegend=False,
        name=metric,
        boxpoints=False,
    )
    fig.add_trace(rank_box, row=1, col=col_idx)

    # Non-null prompt count per bucket for THIS metric. Ticks are set on this panel's
    # axis only. Previously both x-axes were overwritten on every iteration, so the
    # Accuracy panel ended up labelled with the F1 counts.
    counts = medians.groupby('len_bucket')[rank_col].count().to_dict()
    fig.update_xaxes(
        tickmode='array',
        tickvals=tick_vals,
        ticktext=[f"{v} (p={int(counts.get(v, 0))})" for v in tick_vals],
        title_text='Length (# of Tokens, p=# of Datapoints)',
        row=1,
        col=col_idx,
    )

fig.update_layout(
    title_x=0.5,
    font=dict(size=15),
    template="plotly_white",
    legend_orientation='h',
    legend=dict(xanchor="center", x=0.5, bgcolor="rgba(0,0,0,0)"),
    width=1200,
    height=600,
)
# Best (smallest) rank at the top. A descending range reverses the axis and pins it to
# 0-100; an explicit autorange="reversed" would otherwise override any range.
# NOTE(review): assumes ranks lie within 0-100; values outside are clipped — confirm.
fig.update_yaxes(title_text='Median Rank (Inverse Scale)', range=[100, 0])
fig.write_image('ranks_graphs.png')
fig.show()


In [None]:
# Inspection: number of prompts with a non-null f1_rank in each length bucket.
medians.groupby('len_bucket')['f1_rank'].describe()['count'].to_dict()

In [None]:
# Inspection: boolean mask of rows belonging to the 'CTNoText' run.
# NOTE(review): 'rum_name' may be a typo for 'run_name' — confirm against the CSV header.
data['rum_name'].eq('CTNoText')