# Legal Document Retrieval through Active Learning approaches: A comparative study

### Research Questions

1. How do different AL sampling strategies perform compared to baselines trained on all data available?
    - We will evaluate the performance of AL sampling strategies based on uncertainty and diversity against different baselines.
2. In the Legal Document Retrieval task, is there a value combination of query size, pool size, and number of initially labeled examples in the AL cycle that can consistently deliver satisfactory results compared to a model trained with all labeled data available?
    - We will evaluate experiment results using different query sizes, initial labeled samples and pool size.
3. Can different data sources from the same domain benefit the AL cycle?
    - We will evaluate using a data source in the AL cycle and the performance acquired on a test set of a different source.

| hyperparameter | description |values |
| --------------- | ------ | ------------|
| AL sampling strategy | TODO | LeastConfidence<br>LeastConfidenceDropout<br>MarginSampling<br>MarginSamplingDropout<br>KCenterGreedy<br>KMeansSampling<br>RandomSampling |
| Query size | TODO | 8, 16, 32, 64 |
| Size of initial labeled samples | TODO |  16, 32, 64, 128 |
| Pool size | TODO | 640, 1280, 2560, 5120 |
| Dataset | TODO | STJ local, IRIS STJ |

In [1]:
import pandas as pd
from IPython.display import display, HTML

In [2]:
data = pd.read_csv('../data/all_runs.csv', sep=";")
data['test_accuracy'] = data.apply(lambda x: round(x.test_accuracy,2), axis=1)
nsp_models_data = data[data.tag == 'only_nsp_v0']
nsp_models_data = nsp_models_data[(nsp_models_data.strategy != 'EntropySampling') & (nsp_models_data.strategy != 'EntropySamplingDropout')]
baselines_data = data[data.tag == 'only_baselines_v0']
print(len(nsp_models_data) + len(baselines_data))

  exec(code_obj, self.user_global_ns, self.user_ns)


21330


In [3]:
set(nsp_models_data.strategy.values)

{'KCenterGreedy',
 'KMeansSampling',
 'LeastConfidence',
 'LeastConfidenceDropout',
 'MarginSampling',
 'MarginSamplingDropout',
 'RandomSampling'}

In [4]:
set(nsp_models_data.columns.values)

{'_runtime',
 '_step',
 '_timestamp',
 'dataset',
 'initial_labeled_data',
 'model',
 'n_epochs',
 'n_init_labeled',
 'n_initial_unlabeled_pool',
 'n_query',
 'n_round',
 'name',
 'new_labeled_data',
 'predictions',
 'round',
 'samples',
 'seed',
 'strategy',
 'tag',
 'test_accuracy',
 'test_data',
 'total_labeled_data',
 'train_batch_size'}

In [5]:
# pd.options.display.max_rows = None
nsp_models_data[['model','strategy','dataset']].drop_duplicates()

Unnamed: 0,model,strategy,dataset
0,BERTikal,MarginSamplingDropout,LOCAL_STJ
6,BERT,MarginSamplingDropout,LOCAL_STJ
12,Legal_BERT_STF,MarginSamplingDropout,LOCAL_STJ
18,ITD_BERT,MarginSamplingDropout,LOCAL_STJ
96,BERTikal,LeastConfidenceDropout,LOCAL_STJ
...,...,...,...
14383,BERT,LeastConfidence,IRIS_STJ_LOCAL_STJ
14389,Legal_BERT_STF,RandomSampling,IRIS_STJ_LOCAL_STJ
14395,BERTikal,RandomSampling,IRIS_STJ_LOCAL_STJ
14401,ITD_BERT,RandomSampling,IRIS_STJ_LOCAL_STJ


In [6]:
pd.options.display.max_rows = 10

def filter_data(data, dataset, strategy, total_samples, initial_pool, query, train_batch):
    data1 = data[data.dataset == dataset]
    data1 = data1[data1.strategy == strategy]
    data1 = data1[data1.samples == total_samples]
    data1 = data1[data1.n_init_labeled == initial_pool]
    data1 = data1[data1.n_query == query]
    data1 = data1[data1.train_batch_size == train_batch]
    # data1 = data1[data1['round'] == round]
    return data1

In [7]:
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
sns.set(style="whitegrid")

def plot_groupedbar(data, x, y, hue, baseline_data, title_vars, alpha = 0.75):
    g = sns.barplot(x="total_labeled_data",y="test_accuracy",hue='model',data=data)
    g.set_title("Strategy = {} | Total Samples = {} | Initial Pool = {} | Query = {} | Batch Size = {}".format(*title_vars))
    g.set_xticklabels(g.get_xticklabels(), rotation=90)
    g.set(ylim=(-0.1, 1))
    for idx, row in baseline_data.iterrows():
        ls=['-','--','-.',':'][idx%4]
        # lw=10-8*idx/len(baseline_data)
        g.axhline(row['test_accuracy'], label=row['model'] +' -> '+ row['strategy'],
                 ls=ls, color=list(mcolors.CSS4_COLORS.keys())[idx], alpha=alpha)

    plt.legend()
    sns.move_legend(g, "upper left", bbox_to_anchor=(1, 1))
    plt.show()

In [8]:
# dataset = 'LOCAL_STJ'
# strategies = [
#  # 'EntropySampling',
#  # 'EntropySamplingDropout',
#  'KCenterGreedy',
#  'KMeansSampling',
#  'LeastConfidence',
#  'LeastConfidenceDropout',
#  'MarginSampling',
#  'MarginSamplingDropout',
#  'RandomSampling']

# total_samples = 5120
# initial_pool = 128
# query = 64
# batch_size = 32

# for strategy in strategies:
    
#     rs_local_stj = filter_data(nsp_models_data, dataset, strategy, total_samples, initial_pool, query, batch_size)
#     b_data = baselines_data[(baselines_data.dataset == dataset) & (baselines_data.samples == total_samples)]
#     plot_groupedbar(rs_local_stj, 'total_labeled_data', 'test_accuracy', 'model', b_data, [strategy, total_samples, initial_pool, query, batch_size])

In [9]:
from dash import Dash, dcc, html, Input, Output, clientside_callback
import dash_bootstrap_components as dbc 
from jupyter_dash import JupyterDash
import plotly.express as px
import plotly.graph_objects as go

app = JupyterDash(__name__, 
                  external_stylesheets=[dbc.themes.BOOTSTRAP],
                  external_scripts=[{'src': 'https://cdnjs.cloudflare.com/ajax/libs/html2canvas/1.4.0/html2canvas.min.js'}]
                 )

app.layout = dbc.Container([
    html.H1("Results Analysis Dashboard", className='mb-2', style={'textAlign':'center'}),

    dbc.Row([
        dbc.Col([
            html.Label('Dataset'),
            dcc.Dropdown(
                    id="dataset",
                    options=sorted(list(set(nsp_models_data['dataset'].values))),
                    value=sorted(list(set(nsp_models_data['dataset'].values)))[0],
                    clearable=False,)
            ], width=3),
        dbc.Col([
            html.Label('Pool size'),
            dcc.Dropdown(
                    id="samples",
                    options=sorted(list(set(nsp_models_data['samples'].values))),
                    value=sorted(list(set(nsp_models_data['samples'].values)))[0],
                    clearable=False,)  
        ], width=3)
        ]),
    dbc.Row([
        dbc.Col([
            html.Label('Initial labeled'),
            dcc.Dropdown(
                    id="n_init_labeled",
                    options=sorted(list(set(nsp_models_data['n_init_labeled'].values))),
                    value=sorted(list(set(nsp_models_data['n_init_labeled'].values)))[0],
                    clearable=False,)
            ], width=3),
        dbc.Col([
            html.Label('Query'),
            dcc.Dropdown(
                    id="n_query",
                    options=sorted(list(set(nsp_models_data['n_query'].values))),
                    value=sorted(list(set(nsp_models_data['n_query'].values)))[0],
                    clearable=False,)
            ], width=3),
        
        ]),
    
    dbc.Button(
        'Download as image',
        id='download-image'),
    
    dbc.Card(
            [
                dcc.Graph(id="rs"),
                dcc.Graph(id="ls"),
                dcc.Graph(id="ms"),
                # dcc.Graph(id="es"),
                dcc.Graph(id="km"),
                dcc.Graph(id="kc"),
                dcc.Graph(id="lsd"),
                dcc.Graph(id="msd"),
                # dcc.Graph(id="esd"),
            ],
            body=True,
            id='component-to-save'
        )
])

@app.callback(
    Output("rs", "figure"),
    Output("ls", "figure"), 
    Output("ms", "figure"),
    # Output("es", "figure"),
    Output("km", "figure"),
    Output("kc", "figure"),
    Output("lsd", "figure"),
    Output("msd", "figure"),
    # Output("esd", "figure"),
    
    Input("dataset", "value"),
    Input("samples", "value"),
    Input("n_init_labeled", "value"),
    Input("n_query", "value"),
)
def update_bar_chart(dataset, samples, n_init_labeled, n_query):
    mask1 = nsp_models_data["dataset"] == dataset
    mask2 = nsp_models_data["samples"] == samples
    mask3 = nsp_models_data["n_init_labeled"] == n_init_labeled
    mask4 = nsp_models_data["n_query"] == n_query
    
    nsp_data = nsp_models_data[mask1 & mask2 & mask3 & mask4]
    baseline_data = baselines_data[(baselines_data.dataset == dataset) & (baselines_data.samples == samples)]
    
    output = []
    for strategy in ["RandomSampling", "LeastConfidence","MarginSampling","KMeansSampling","KCenterGreedy",
                "LeastConfidenceDropout","MarginSamplingDropout",]: #"EntropySamplingDropout", "EntropySampling",
        data = nsp_data[nsp_data.strategy == strategy]
        fig = px.bar(data, title=strategy,
                     x='total_labeled_data', y="test_accuracy", 
                     color="model", barmode="group")

        for idx, row in baseline_data.iterrows():
            fig.add_trace(go.Scatter(
                x=(n_init_labeled - (n_query/2),
                   data['total_labeled_data'].max() + (n_query/2)),
                y=(row['test_accuracy'],row['test_accuracy']),
                mode='lines',
                name=row['model'] +' -> '+ row['strategy'],
                yaxis='y1',
                opacity=0.75
            ))
        output.append(fig)
        
    return *output,

clientside_callback(
    """
    function(n_clicks){
        if(n_clicks > 0){
            html2canvas(document.getElementById("component-to-save"), {useCORS: true}).then(function (canvas) {
                var anchorTag = document.createElement("a");
                document.body.appendChild(anchorTag);
                anchorTag.download = "download.png";
                anchorTag.href = canvas.toDataURL();
                anchorTag.target = '_blank';
                anchorTag.click();
            });
        }
    }
    """,
    Output('download-image', 'n_clicks'),
    Input('download-image', 'n_clicks')
)

app.run_server(mode="inline", host='0.0.0.0',debug=True, port=8049)


Dash is running on http://0.0.0.0:8049/



### **Research Question 1**: How do different AL sampling strategies perform compared to baselines trained on all data available?

### **Research Question 2**: In the Legal Document Retrieval task, is there a value combination of query size, pool size, and number of initially labeled examples in the AL cycle that can consistently deliver satisfactory results compared to a model trained with all labeled data available?

#### 1.1 Evaluate AL methods in different pool sizes comparing with the respective raw and trained baselines

In [10]:
question1_baseline_data = baselines_data[baselines_data.dataset != 'IRIS_STJ_LOCAL_STJ']
question1_nsp_data = nsp_models_data[nsp_models_data.dataset != 'IRIS_STJ_LOCAL_STJ']
question1_results = pd.DataFrame()

In [11]:
for dataset_pool, group in question1_baseline_data[question1_baseline_data.strategy == 'TRAINED BASELINE'].groupby(["dataset","samples"]):
    best_in_group = group.nlargest(1, "test_accuracy", keep='all')
    question1_results = question1_results.append(best_in_group)

In [12]:
for dataset_pool, group in question1_baseline_data[question1_baseline_data.strategy == 'RAW BASELINE'].groupby(["dataset","samples"]):
    best_in_group = group.nlargest(1, "test_accuracy", keep='all')
    question1_results = question1_results.append(best_in_group)

In [13]:
for dataset_pool, group in question1_nsp_data.groupby(["dataset","samples"]):
    best_in_group = group.nlargest(1, "test_accuracy", keep='all')
    question1_results = question1_results.append(best_in_group)

In [14]:
q1_tmp = question1_results[['dataset', 'samples', 'model', 'strategy', 'n_init_labeled','n_query', 'round','total_labeled_data','test_accuracy']]
for _, group in q1_tmp.groupby(['dataset', 'samples']):
    print(_)
    group['labeled_data_reduction'] = group.apply(lambda x: "{:0.2f}%".format(((x.samples - x.total_labeled_data)/x.samples)*100), axis=1)
    group['data_reduction'] = group.apply(lambda x: round(((x.samples - x.total_labeled_data)/x.samples)*100,2), axis=1)
    sorted_group = group.sort_values(['test_accuracy','data_reduction'], ascending=False)
    sorted_group = sorted_group.fillna(0).astype({'samples':'int','n_init_labeled':'int', 'n_query':'int', 'round':'int', 'total_labeled_data':'int' })
    df = sorted_group.drop(columns=['dataset', 'samples', 'data_reduction'])
    display(df)
    print(df.to_latex(index=False))
    print("#"*100)
#     group['graph_title'] = group.apply(lambda x: "M: {} - AL: {} - TD: {} - IL: {} - Q: {}".format(
#                                        x.model, x.strategy, x.total_labeled_data, x.n_init_labeled, x.n_query), 
#                                        axis=1)

#     nsp = sorted_group[(sorted_group.strategy != 'TRAINED BASELINE') & (sorted_group.strategy != 'RAW BASELINE')]
#     baselines = sorted_group[(sorted_group.strategy == 'TRAINED BASELINE') & (sorted_group.strategy == 'RAW BASELINE')]
#     g = sns.barplot(x="graph_title",y="test_accuracy",data=nsp)
#     g.set_xticklabels(g.get_xticklabels(), rotation=90)
#     g.set(ylim=(-0.1, 1))
#     for idx, row in baselines.iterrows():
#         ls=['-','--','-.',':'][idx%4]
#         # lw=10-8*idx/len(baseline_data)
#         g.axhline(row['test_accuracy'], label=row['model'] +' -> '+ row['strategy'],
#                  ls=ls, color=list(mcolors.CSS4_COLORS.keys())[idx], alpha=alpha)

#     plt.legend()
#     sns.move_legend(g, "upper left", bbox_to_anchor=(1, 1))
#     plt.show()


('IRIS_STJ', 640.0)


Unnamed: 0,model,strategy,n_init_labeled,n_query,round,total_labeled_data,test_accuracy,labeled_data_reduction
13558,BERTikal,KCenterGreedy,32,8,3,56,0.66,91.25%
13965,BERTikal,RandomSampling,16,32,2,80,0.66,87.50%
13270,BERTikal,KCenterGreedy,32,32,3,128,0.66,80.00%
12261,BERTikal,KCenterGreedy,128,16,2,160,0.66,75.00%
12045,BERTikal,MarginSampling,128,64,2,256,0.66,60.00%
12069,BERTikal,LeastConfidence,128,64,2,256,0.66,60.00%
12093,BERTikal,RandomSampling,128,64,2,256,0.66,60.00%
3348,SBERT_Legal_BERTimbau,TRAINED BASELINE,0,0,0,640,0.51,0.00%
3347,SBERT_Paraphrase_Multilingual,RAW BASELINE,0,0,0,0,0.48,100.00%


\begin{tabular}{llrrrrrl}
\toprule
                        model &         strategy &  n\_init\_labeled &  n\_query &  round &  total\_labeled\_data &  test\_accuracy & labeled\_data\_reduction \\
\midrule
                     BERTikal &    KCenterGreedy &              32 &        8 &      3 &                  56 &           0.66 &                 91.25\% \\
                     BERTikal &   RandomSampling &              16 &       32 &      2 &                  80 &           0.66 &                 87.50\% \\
                     BERTikal &    KCenterGreedy &              32 &       32 &      3 &                 128 &           0.66 &                 80.00\% \\
                     BERTikal &    KCenterGreedy &             128 &       16 &      2 &                 160 &           0.66 &                 75.00\% \\
                     BERTikal &   MarginSampling &             128 &       64 &      2 &                 256 &           0.66 &                 60.00\% \\
                   

Unnamed: 0,model,strategy,n_init_labeled,n_query,round,total_labeled_data,test_accuracy,labeled_data_reduction
11373,BERTikal,RandomSampling,32,8,2,48,0.66,96.25%
1350,BERTikal,MarginSamplingDropout,64,64,0,64,0.66,95.00%
1374,BERTikal,LeastConfidenceDropout,64,64,0,64,0.66,95.00%
1422,BERTikal,MarginSamplingDropout,64,32,0,64,0.66,95.00%
1446,BERTikal,LeastConfidenceDropout,64,32,0,64,0.66,95.00%
...,...,...,...,...,...,...,...,...
10365,BERTikal,RandomSampling,64,64,2,192,0.66,85.00%
9838,BERTikal,KMeansSampling,128,32,3,224,0.66,82.50%
9815,BERTikal,KCenterGreedy,128,32,4,256,0.66,80.00%
3341,SBERT_Legal_BERTimbau,RAW BASELINE,0,0,0,0,0.62,100.00%


\begin{tabular}{llrrrrrl}
\toprule
                        model &               strategy &  n\_init\_labeled &  n\_query &  round &  total\_labeled\_data &  test\_accuracy & labeled\_data\_reduction \\
\midrule
                     BERTikal &         RandomSampling &              32 &        8 &      2 &                  48 &           0.66 &                 96.25\% \\
                     BERTikal &  MarginSamplingDropout &              64 &       64 &      0 &                  64 &           0.66 &                 95.00\% \\
                     BERTikal & LeastConfidenceDropout &              64 &       64 &      0 &                  64 &           0.66 &                 95.00\% \\
                     BERTikal &  MarginSamplingDropout &              64 &       32 &      0 &                  64 &           0.66 &                 95.00\% \\
                     BERTikal & LeastConfidenceDropout &              64 &       32 &      0 &                  64 &           0.66 &           

Unnamed: 0,model,strategy,n_init_labeled,n_query,round,total_labeled_data,test_accuracy,labeled_data_reduction
3395,SBERT_Paraphrase_Multilingual,RAW BASELINE,0,0,0,0,0.69,100.00%
26876,BERTikal,KCenterGreedy,16,16,1,32,0.69,95.00%
26468,BERTikal,KMeansSampling,32,8,1,40,0.69,93.75%
27047,BERTikal,KMeansSampling,16,8,4,48,0.69,92.50%
26440,Legal_BERT_STF,KCenterGreedy,32,8,3,56,0.69,91.25%
...,...,...,...,...,...,...,...,...
26700,BERT,LeastConfidence,16,64,5,336,0.69,47.50%
24876,BERT,KCenterGreedy,128,64,5,448,0.69,30.00%
24900,BERT,KMeansSampling,128,64,5,448,0.69,30.00%
24996,BERT,RandomSampling,128,64,5,448,0.69,30.00%


\begin{tabular}{llrrrrrl}
\toprule
                        model &         strategy &  n\_init\_labeled &  n\_query &  round &  total\_labeled\_data &  test\_accuracy & labeled\_data\_reduction \\
\midrule
SBERT\_Paraphrase\_Multilingual &     RAW BASELINE &               0 &        0 &      0 &                   0 &           0.69 &                100.00\% \\
                     BERTikal &    KCenterGreedy &              16 &       16 &      1 &                  32 &           0.69 &                 95.00\% \\
                     BERTikal &   KMeansSampling &              32 &        8 &      1 &                  40 &           0.69 &                 93.75\% \\
                     BERTikal &   KMeansSampling &              16 &        8 &      4 &                  48 &           0.69 &                 92.50\% \\
               Legal\_BERT\_STF &    KCenterGreedy &              32 &        8 &      3 &                  56 &           0.69 &                 91.25\% \\
               

Unnamed: 0,model,strategy,n_init_labeled,n_query,round,total_labeled_data,test_accuracy,labeled_data_reduction
23143,BERT,KCenterGreedy,64,64,0,64,0.72,95.00%
23167,BERT,KMeansSampling,64,64,0,64,0.72,95.00%
23215,BERT,MarginSampling,64,64,0,64,0.72,95.00%
23239,BERT,LeastConfidence,64,64,0,64,0.72,95.00%
23263,BERT,RandomSampling,64,64,0,64,0.72,95.00%
...,...,...,...,...,...,...,...,...
23695,BERT,RandomSampling,64,8,0,64,0.72,95.00%
3388,SBERT_Legal_BERTimbau,TRAINED BASELINE,0,0,0,1280,0.69,0.00%
3390,SBERT_Local_BERTibaum,TRAINED BASELINE,0,0,0,1280,0.69,0.00%
3387,SBERT_Paraphrase_Multilingual,RAW BASELINE,0,0,0,0,0.62,100.00%


\begin{tabular}{llrrrrrl}
\toprule
                        model &         strategy &  n\_init\_labeled &  n\_query &  round &  total\_labeled\_data &  test\_accuracy & labeled\_data\_reduction \\
\midrule
                         BERT &    KCenterGreedy &              64 &       64 &      0 &                  64 &           0.72 &                 95.00\% \\
                         BERT &   KMeansSampling &              64 &       64 &      0 &                  64 &           0.72 &                 95.00\% \\
                         BERT &   MarginSampling &              64 &       64 &      0 &                  64 &           0.72 &                 95.00\% \\
                         BERT &  LeastConfidence &              64 &       64 &      0 &                  64 &           0.72 &                 95.00\% \\
                         BERT &   RandomSampling &              64 &       64 &      0 &                  64 &           0.72 &                 95.00\% \\
                   

Unnamed: 0,model,strategy,n_init_labeled,n_query,round,total_labeled_data,test_accuracy,labeled_data_reduction
488,BERT,MarginSampling,128,8,2,144,0.75,94.38%
704,BERT,LeastConfidence,128,8,2,144,0.75,94.38%
3380,SBERT_Legal_BERTimbau,TRAINED BASELINE,0,0,0,2560,0.72,0.00%
3381,SBERT_Legal_BERTimbau,RAW BASELINE,0,0,0,0,0.66,100.00%


\begin{tabular}{llrrrrrl}
\toprule
                model &         strategy &  n\_init\_labeled &  n\_query &  round &  total\_labeled\_data &  test\_accuracy & labeled\_data\_reduction \\
\midrule
                 BERT &   MarginSampling &             128 &        8 &      2 &                 144 &           0.75 &                 94.38\% \\
                 BERT &  LeastConfidence &             128 &        8 &      2 &                 144 &           0.75 &                 94.38\% \\
SBERT\_Legal\_BERTimbau & TRAINED BASELINE &               0 &        0 &      0 &                2560 &           0.72 &                  0.00\% \\
SBERT\_Legal\_BERTimbau &     RAW BASELINE &               0 &        0 &      0 &                   0 &           0.66 &                100.00\% \\
\bottomrule
\end{tabular}

####################################################################################################
('LOCAL_STJ', 5120.0)


Unnamed: 0,model,strategy,n_init_labeled,n_query,round,total_labeled_data,test_accuracy,labeled_data_reduction
20792,BERT,KMeansSampling,16,32,1,48,0.75,99.06%
20624,BERT,KCenterGreedy,16,64,1,80,0.75,98.44%
19593,BERT,RandomSampling,64,64,2,192,0.75,96.25%
3372,SBERT_Legal_BERTimbau,TRAINED BASELINE,0,0,0,5120,0.69,0.00%
3371,SBERT_Paraphrase_Multilingual,RAW BASELINE,0,0,0,0,0.62,100.00%


\begin{tabular}{llrrrrrl}
\toprule
                        model &         strategy &  n\_init\_labeled &  n\_query &  round &  total\_labeled\_data &  test\_accuracy & labeled\_data\_reduction \\
\midrule
                         BERT &   KMeansSampling &              16 &       32 &      1 &                  48 &           0.75 &                 99.06\% \\
                         BERT &    KCenterGreedy &              16 &       64 &      1 &                  80 &           0.75 &                 98.44\% \\
                         BERT &   RandomSampling &              64 &       64 &      2 &                 192 &           0.75 &                 96.25\% \\
        SBERT\_Legal\_BERTimbau & TRAINED BASELINE &               0 &        0 &      0 &                5120 &           0.69 &                  0.00\% \\
SBERT\_Paraphrase\_Multilingual &     RAW BASELINE &               0 &        0 &      0 &                   0 &           0.62 &                100.00\% \\
\bottomrule
\en

### **Research Question 3**: Can different data sources from the same domain benefit the AL cycle?

In [29]:
question3_baseline_data = baselines_data[baselines_data.dataset == 'IRIS_STJ_LOCAL_STJ']
question3_nsp_data = nsp_models_data[nsp_models_data.dataset == 'IRIS_STJ_LOCAL_STJ']

In [30]:
question3_results = pd.DataFrame()

for dataset_pool, group in question3_baseline_data[question3_baseline_data.strategy == 'TRAINED BASELINE'].groupby(["dataset","samples"]):
    best_in_group = group.nlargest(1, "test_accuracy", keep='all')
    question3_results = question3_results.append(best_in_group)

for dataset_pool, group in question3_baseline_data[question3_baseline_data.strategy == 'RAW BASELINE'].groupby(["dataset","samples"]):
    best_in_group = group.nlargest(1, "test_accuracy", keep='all')
    question3_results = question3_results.append(best_in_group)

for dataset_pool, group in question3_nsp_data.groupby(["dataset","samples"]):
    best_in_group = group.nlargest(1, "test_accuracy", keep='all')
    question3_results = question3_results.append(best_in_group)

In [36]:
q3_tmp = question3_results[['dataset', 'samples', 'model', 'strategy', 'n_init_labeled','n_query', 'round','total_labeled_data','test_accuracy']]
for _, group in q3_tmp.groupby(['dataset', 'samples']):
    print(_)
    group['labeled_data_reduction'] = group.apply(lambda x: "{:0.2f}%".format(((x.samples - x.total_labeled_data)/x.samples)*100), axis=1)
    group['data_reduction'] = group.apply(lambda x: round(((x.samples - x.total_labeled_data)/x.samples)*100,2), axis=1)
    sorted_group = group.sort_values(['test_accuracy','data_reduction'], ascending=False)
    sorted_group = sorted_group.fillna(0).astype({'samples':'int','n_init_labeled':'int', 'n_query':'int', 'round':'int', 'total_labeled_data':'int' })
    df = sorted_group.drop(columns=['dataset', 'samples', 'data_reduction'])
    display(df)
    print(df.to_latex(index=False))
    print("#"*100)
 

('IRIS_STJ_LOCAL_STJ', 640.0)


Unnamed: 0,model,strategy,n_init_labeled,n_query,round,total_labeled_data,test_accuracy,labeled_data_reduction
3363,SBERT_Paraphrase_Multilingual,RAW BASELINE,0,0,0,0,0.72,100.00%
3366,SBERT_Local_BERTibaum,TRAINED BASELINE,0,0,0,640,0.72,0.00%
17374,BERTikal,MarginSampling,64,32,3,160,0.69,75.00%
17398,BERTikal,LeastConfidence,64,32,3,160,0.69,75.00%


\begin{tabular}{llrrrrrl}
\toprule
                        model &         strategy &  n\_init\_labeled &  n\_query &  round &  total\_labeled\_data &  test\_accuracy & labeled\_data\_reduction \\
\midrule
SBERT\_Paraphrase\_Multilingual &     RAW BASELINE &               0 &        0 &      0 &                   0 &           0.72 &                100.00\% \\
        SBERT\_Local\_BERTibaum & TRAINED BASELINE &               0 &        0 &      0 &                 640 &           0.72 &                  0.00\% \\
                     BERTikal &   MarginSampling &              64 &       32 &      3 &                 160 &           0.69 &                 75.00\% \\
                     BERTikal &  LeastConfidence &              64 &       32 &      3 &                 160 &           0.69 &                 75.00\% \\
\bottomrule
\end{tabular}

####################################################################################################
('IRIS_STJ_LOCAL_STJ', 1280.0)


Unnamed: 0,model,strategy,n_init_labeled,n_query,round,total_labeled_data,test_accuracy,labeled_data_reduction
15647,BERTikal,MarginSampling,32,32,4,160,0.69,87.50%
15671,BERTikal,LeastConfidence,32,32,4,160,0.69,87.50%
3358,SBERT_Local_BERTibaum,TRAINED BASELINE,0,0,0,1280,0.62,0.00%
3359,SBERT_Local_BERTibaum,RAW BASELINE,0,0,0,0,0.56,100.00%


\begin{tabular}{llrrrrrl}
\toprule
                model &         strategy &  n\_init\_labeled &  n\_query &  round &  total\_labeled\_data &  test\_accuracy & labeled\_data\_reduction \\
\midrule
             BERTikal &   MarginSampling &              32 &       32 &      4 &                 160 &           0.69 &                 87.50\% \\
             BERTikal &  LeastConfidence &              32 &       32 &      4 &                 160 &           0.69 &                 87.50\% \\
SBERT\_Local\_BERTibaum & TRAINED BASELINE &               0 &        0 &      0 &                1280 &           0.62 &                  0.00\% \\
SBERT\_Local\_BERTibaum &     RAW BASELINE &               0 &        0 &      0 &                   0 &           0.56 &                100.00\% \\
\bottomrule
\end{tabular}

####################################################################################################


In [32]:
# find missing runs
# for dataset, dataset_grouped in question1_nsp_data.groupby('dataset'):
#     for samples, group in dataset_grouped.groupby('samples'):
#         print("Dataset: {} - Pool: {}".format(dataset, samples))
#         pivoted_group = group.pivot_table(index=['model','n_init_labeled','n_query', 'round'], columns='strategy', values='test_accuracy')
#         nan_rows = pivoted_group.isna().any(axis=1)
#         df_nan_rows = pivoted_group[nan_rows]
#         for column in df_nan_rows.columns[df_nan_rows.isna().any()].tolist():
#             print(column)
#             setup_nan = list(set(df_nan_rows[df_nan_rows[[column]].isna().any(axis=1)].index.tolist()))
#             result = {}
#             for model, il, q, r in setup_nan:
#                 key = "IL: {} - Q: {}".format(il,q)
#                 result[key] = [model] if key not in result else list(set(result[key] + [model]))
#             for k, v in result.items():
#                 print(k)
#                 print(v)
#             print('-------------------------------------------------------------')
#         print('#############################################')

In [33]:
# for dataset, dataset_grouped in question1_nsp_data.groupby('dataset'):
#     for samples, group in dataset_grouped.groupby('samples'):
#         print("Dataset: {} - Pool: {}".format(dataset, samples))
#         pivoted_group = group.pivot_table(index=['model','n_init_labeled','n_query', 'round'], columns='strategy', values='test_accuracy')
#         display(pivoted_group)        
#         result = autorank(pivoted_group, alpha=0.05, verbose=False)
#         # display(result)
#         create_report(result)
#         plot_stats(result)
#         plt.show()
#         print('#############################################')

In [34]:
def statistical_test_by_dataset_samples(index_columns, target_column):
    for dataset, dataset_grouped in question1_nsp_data.groupby('dataset'):
        for samples, group in dataset_grouped.groupby('samples'):
            print("Dataset: {} - Pool: {}".format(dataset, samples))
            pivoted_group = group.pivot_table(index=index_columns, columns='strategy', values=target_column)
            display(pivoted_group)        
            result = autorank(pivoted_group, alpha=0.05, verbose=False)
            # display(result)
            create_report(result)
            plot_stats(result)
            plt.show()
            print('#############################################')

In [35]:
statistical_test_by_dataset_samples(['model','n_init_labeled','n_query', 'round'], 'test_accuracy')

Dataset: IRIS_STJ - Pool: 640.0


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,strategy,KCenterGreedy,KMeansSampling,LeastConfidence,LeastConfidenceDropout,MarginSampling,MarginSamplingDropout,RandomSampling
model,n_init_labeled,n_query,round,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
BERT,16.0,8.0,0.0,0.65,0.65,0.65,0.64,0.65,0.64,0.65
BERT,16.0,8.0,1.0,0.64,0.65,0.65,0.64,0.65,0.64,0.65
BERT,16.0,8.0,2.0,0.63,0.64,0.65,0.64,0.65,0.64,0.65
BERT,16.0,8.0,3.0,0.63,0.57,0.65,0.64,0.65,0.64,0.65
BERT,16.0,8.0,4.0,0.61,0.62,0.65,0.64,0.65,0.64,0.65
...,...,...,...,...,...,...,...,...,...,...
Legal_BERT_STF,128.0,64.0,1.0,0.65,0.65,0.65,0.65,0.65,0.65,0.62
Legal_BERT_STF,128.0,64.0,2.0,0.65,0.65,0.62,0.65,0.62,0.65,0.65
Legal_BERT_STF,128.0,64.0,3.0,0.40,0.49,0.48,0.65,0.48,0.65,0.40
Legal_BERT_STF,128.0,64.0,4.0,0.41,0.39,0.35,0.65,0.35,0.65,0.56


NameError: name 'autorank' is not defined