In [None]:
import ast
import json

import pandas as pd
from sqlalchemy import create_engine
from sqlalchemy.orm import Query

import src
from src.bert.gridsearch.models import Model
from src.bert.gridsearch.models import Result

In [None]:
# Location of the grid-search results database and its SQLAlchemy URL.
DBPATH = src.PATH / "tmp/gridsearch.db"
# SQLite silently creates a new, empty database file when the path does not exist,
# which would only surface later as a confusing "no such table" error; fail fast.
assert DBPATH.is_file(), f"grid-search database not found: {DBPATH}"
DB = f"sqlite:///{DBPATH}"

In [None]:
# SQLAlchemy engine for the grid-search database. Engine creation is lazy:
# no connection is opened until `engine.connect()` is called.
engine = create_engine(DB)

In [None]:
# Load one row per (model, epoch) result for the xlm-roberta grid search.
query = (
    Query(Model)
    .join(Result)
    .filter(Model.name == "xlm-roberta")
    .with_entities(
        Model.id,
        Model.lr,
        Model.batch_size,
        Model.kfold,
        Model.weight_decay,
        Result.epoch,
        Result.score,
        Result.best_threshold,
        Model.clip,
    )
)
# Only the read needs the connection; the transform below is pure pandas.
with engine.connect() as con:
    raw_results = pd.read_sql_query(query.statement, con)

# `best_threshold` is stored as the text of a literal (presumably a list with one
# threshold per label — verify against the writer). Parse it with `ast.literal_eval`
# instead of `eval`, which would execute arbitrary code stored in the database,
# then expand it into one `thresh_<i>` column per element.
thresholds = pd.DataFrame(
    raw_results["best_threshold"].map(ast.literal_eval).tolist(),
    index=raw_results.index,
).add_prefix("thresh_")
df = pd.concat([raw_results.drop(columns="best_threshold"), thresholds], axis=1)

In [None]:
# Let long result tables render without truncation (up to 512 rows).
pd.options.display.max_rows = 512

In [None]:
HYPERPARAMS = ["lr", "batch_size", "weight_decay", "clip"]

# For each hyperparameter combination and fold, keep the single epoch (row) with the
# highest score, so epoch and thresholds all come from that same epoch.
# Note: `GroupBy.max("score")` does not do this — its first positional argument is
# `numeric_only`, so it returns the column-wise maximum of every column independently.
best_row_idx = df.groupby(HYPERPARAMS + ["kfold"])["score"].idxmax()
best_runs_per_fold = df.loc[best_row_idx].reset_index(drop=True)

In [None]:
# Rank hyperparameter combinations by their fold-averaged best-epoch metrics.
# `id` and `kfold` are identifiers, so averaging them is meaningless; drop them.
# Sorting before `reset_index` makes the displayed index the rank.
(
    best_runs_per_fold
    .drop(columns=["id", "kfold"])
    .groupby(HYPERPARAMS)
    .mean()
    .sort_values("score", ascending=False)
    .reset_index()
)

Unnamed: 0,lr,batch_size,weight_decay,clip,kfold,id,epoch,score,thresh_0,thresh_1,thresh_2,thresh_3
17,9e-06,8,0.05,5.0,3.0,438.0,6.8,0.701201,0.680575,0.566126,0.689746,0.66525
27,1e-05,8,0.05,0.5,3.0,460.0,6.8,0.699631,0.70883,0.639195,0.660093,0.62446
34,1e-05,16,0.05,1.0,3.0,467.0,7.8,0.699546,0.662954,0.649891,0.717136,0.657541
20,9e-06,16,0.01,5.0,3.0,441.0,7.6,0.699185,0.661501,0.720497,0.735761,0.73534
18,9e-06,16,0.01,0.5,3.0,439.0,7.6,0.699009,0.58371,0.865043,0.689722,0.665074
19,9e-06,16,0.01,1.0,3.0,440.0,7.6,0.698697,0.597272,0.659476,0.726083,0.683667
15,9e-06,8,0.05,0.5,3.0,436.0,6.8,0.698525,0.736853,0.684373,0.735482,0.798517
35,1e-05,16,0.05,5.0,3.0,468.0,7.4,0.698519,0.638093,0.673913,0.605158,0.736257
23,9e-06,16,0.05,5.0,3.0,444.0,7.2,0.698407,0.523441,0.64151,0.591099,0.465405
31,1e-05,16,0.01,1.0,3.0,464.0,7.4,0.698147,0.716151,0.739032,0.712745,0.603661


In [None]:
# Best thresholds: per-epoch averages over all folds of the top-ranked configuration
# (highest fold-mean best score). The configuration is derived from the ranking rather
# than from hard-coded model ids (previously 366, 402, 438, 474, 510), so this cell
# stays correct when new runs are added. Any duplicate runs of the same configuration
# are included as well.
best_config = best_runs_per_fold.groupby(HYPERPARAMS)["score"].mean().idxmax()
is_best_config = (df[HYPERPARAMS] == pd.Series(best_config, index=HYPERPARAMS)).all(axis=1)
df.loc[is_best_config].drop(columns=["id", "kfold"]).groupby("epoch").mean()

Unnamed: 0_level_0,id,lr,batch_size,kfold,weight_decay,score,clip,thresh_0,thresh_1,thresh_2,thresh_3
epoch,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
1,438.0,9e-06,8.0,3.0,0.05,0.430129,5.0,0.284438,0.320607,0.274925,0.289761
2,438.0,9e-06,8.0,3.0,0.05,0.586985,5.0,0.364797,0.320191,0.346379,0.445821
3,438.0,9e-06,8.0,3.0,0.05,0.666381,5.0,0.332504,0.295839,0.318649,0.280451
4,438.0,9e-06,8.0,3.0,0.05,0.666311,5.0,0.336603,0.239379,0.476275,0.243902
5,438.0,9e-06,8.0,3.0,0.05,0.675052,5.0,0.318503,0.251811,0.36006,0.349562
6,438.0,9e-06,8.0,3.0,0.05,0.689628,5.0,0.565325,0.339535,0.408757,0.232659
7,438.0,9e-06,8.0,3.0,0.05,0.699741,5.0,0.413472,0.265016,0.657122,0.384051
