In [2]:
import altair as alt
import os
import pandas as pd
import re
from transformers import T5Tokenizer

In [3]:
# Jupyter's working directory; the analysis inputs live in a sibling
# `data/rouge_analysis` folder one level up from it.
DIR = os.path.abspath('')
ANALYSIS_DIR = os.path.join(os.path.dirname(DIR), 'data', 'rouge_analysis')

# Each line of `articles` is treated as one document (its index becomes doc_id).
# The encoding is explicit so decoding does not depend on the platform's locale
# default. Assumes the file is UTF-8 — confirm.
with open(os.path.join(ANALYSIS_DIR, 'articles'), 'r', encoding='utf-8') as f:
    articles = f.readlines()

# Per-model ROUGE tables, read in a fixed order that matches the three names
# they are unpacked into.
ext_df, abs_df, hybrid_df = (
    pd.read_csv(
        os.path.join(ANALYSIS_DIR, f'{model}_rouge.csv'),
        index_col=0,
    )
    for model in ('extractive', 'abstractive', 'hybrid')
)

In [4]:
# The T5 checkpoint whose tokenizer defines an article's "length" in this analysis.
T5_CHECKPOINT = 't5-base'
tokenizer = T5Tokenizer.from_pretrained(T5_CHECKPOINT)

In [10]:
%%capture
data = pd.DataFrame({
    'doc_id': range(len(articles)),
    #'char_len': map(len, articles),
    #'word_len': (len(re.split('\s+', x)) for x in articles),
    'token_len': (len(tokenizer.encode(x, verbose=False)) for x in articles),
})

In [12]:
# Requested number of quantile bins (also reused as the histogram's maxbins).
n_cuts = 500
# Quantile-bin articles by token length into roughly equal-count bins.
# labels=False yields integer bin ids 0..k-1. duplicates='drop' merges tied
# quantile edges, which are likely with integer token counts, instead of
# raising "Bin edges must be unique". So k can be smaller than n_cuts.
data['token_bin'] = pd.qcut(
    data['token_len'],
    q=n_cuts,
    labels=False,
    duplicates='drop',
)

In [13]:
# Pair each model name with its ROUGE frame (same order as loaded above) and
# tag every row with its model, so the frames can be stacked and still told
# apart. Note: this adds the 'label' column to ext_df / abs_df / hybrid_df in place.
model_dfs = tuple(zip(
    ('extractive', 'abstractive', 'hybrid'),
    (ext_df, abs_df, hybrid_df),
))
for label, df in model_dfs:
    df['label'] = label

In [14]:
# Long-format table: every article row of `data` is left-joined on doc_id with
# each model's ROUGE rows, and the three joined frames are stacked vertically.
# Because the join is a left join, articles missing from a model's table keep
# NaN in that model's columns (including 'label').
# ignore_index=True gives a fresh 0..N-1 index instead of repeating each
# frame's 0..n-1 index three times.
full_data = pd.concat(
    [data.merge(df, on='doc_id', how='left') for _, df in model_dfs],
    axis=0,
    ignore_index=True,
)

In [15]:
# Lift Altair's row limit on inline data so the full per-article frames can be
# passed straight to alt.Chart in the cells below.
alt.data_transformers.disable_max_rows()

DataTransformerRegistry.enable('default')

In [16]:
# Distribution of article lengths in T5 tokens, drawn with at most n_cuts
# equal-width bins. These are not the quantile (equal-count) bins in token_bin.
alt.Chart(data).mark_bar().encode(
    alt.X(
        'token_len:Q',
        bin=alt.Bin(maxbins=n_cuts),
        title='Article length (T5 tokens)',
    ),
    alt.Y('count()', title='Number of articles'),
).properties(title='Article length distribution')

In [19]:
# ROUGE metrics to average and plot.
rouge_cols = [
    'rouge_1_f_score', 'rouge_2_f_score', 'rouge_l_f_score',
    'rouge_1_recall', 'rouge_2_recall', 'rouge_l_recall',
]
# Mean token length and mean ROUGE scores per (length bin, model) pair, one
# row per pair. observed=True emits only combinations that actually occur.
# If token_bin is categorical, pandas' default (observed=False, which warns in
# recent versions) would add all-NaN rows for empty bin/model combinations.
binned_data = (
    full_data
    .groupby(['token_bin', 'label'], observed=True)[['token_len'] + rouge_cols]
    .mean()
    .reset_index()
)

In [29]:
def make_hcc(row_of_charts):
    """Place the given charts side by side in a single horizontal row.

    row_of_charts: iterable of Altair charts (consumed once, order kept).
    Returns an ``alt.HConcatChart`` whose ``hconcat`` is a list of those charts.
    """
    return alt.HConcatChart(hconcat=list(row_of_charts))

def facet_wrap(charts, charts_per_row):
    """Arrange charts into a grid, filling each row left to right.

    Parameters
    ----------
    charts : iterable of Altair charts
        Charts in display order. They are materialised into a list, so any
        iterable works.
    charts_per_row : int
        Maximum number of charts per row; must be >= 1. The last row may be
        shorter.

    Returns
    -------
    alt.VConcatChart
        One ``make_hcc`` row per slice of ``charts``, with grid lines enabled
        on both the X and Y axes.

    Raises
    ------
    ValueError
        If ``charts_per_row`` is less than 1. Without this check a negative
        value would silently produce an empty grid.
    """
    if charts_per_row < 1:
        raise ValueError(f'charts_per_row must be >= 1, got {charts_per_row!r}')
    charts = list(charts)
    rows_of_charts = [
        charts[start:start + charts_per_row]
        for start in range(0, len(charts), charts_per_row)
    ]
    vconcat = [make_hcc(row) for row in rows_of_charts]
    return (
        alt.VConcatChart(vconcat=vconcat)
        .configure_axisX(grid=True)
        .configure_axisY(grid=True)
    )

# One line chart per ROUGE metric: mean score against mean token length for
# each length bin, coloured by model, laid out three charts per row.
rouge_charts = [
    alt.Chart(binned_data)
    .mark_line()
    .encode(x='token_len', y=metric, color='label')
    for metric in rouge_cols
]
compound_chart = facet_wrap(rouge_charts, charts_per_row=3)
compound_chart