In [1]:
import sys
import logging
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import optuna
import print_n_log

from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data.dataloader import DataLoader
from definitions import *
from model_helper_functions import *
from dataset_helper_functions import *
from bi_lstm import BiLSTM
from bert_embedding_model import BertEmbeddingModel
from debates_dataset import DebatesDataset
from early_stopping import EarlyStopping
from optuna.trial import TrialState
from torchvision import transforms
# my transforms
from transforms import *
from scorer.task5 import evaluate_v2

In [2]:
data = {}
optim_path = os.path.join(EXP_DIR_PATH, 'bi-lstm', 'optimization')
# bi-lstm
# - no_feat
# - sent_feat
# - word_feat
training_path = os.path.join(EXP_DIR_PATH, 'bi-lstm', 'training')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_uw_ratio = 0
slf_dim = 0


In [3]:
# studies = [m for m in os.listdir(optim_path) if m.split('_')[-1] != 'params.pkl']
# models_directories = []
# for m in studies:
#     sp = m.split('_')[1:]
#     xx = '_'.join([s for s in sp if s not in {'pNone', 'df0.2', 'wf0.2.pkl'}])
#     if len(xx):
#         models_directories.append(xx)

# for c in models_directories:
#     try:
#         os.mkdir(os.path.join(training_path, c))
#     except Exception as e:
#         print(e.args)

Function for loading data.

In [4]:
def load_data():
    dev_path = os.path.join(PROC_DATA_DIR_PATH, 'dev')

    data_paths = {
        'dev': [
            os.path.join(dev_path, 'dev.tsv'),
        ],
        'test': [
            os.path.join(POLIT_DATA_DIR_PATH, 'test', 'test_combined.tsv'),
        ],
        'train': [
            os.path.join(POLIT_DATA_DIR_PATH, 'train', 'train_combined.tsv'),
        ],
        'val': [
            os.path.join(POLIT_DATA_DIR_PATH, 'val', 'val_combined.tsv'),
        ],
    }

    for dtype, dpaths in data_paths.items():
        try:
            data[dtype] = pd.read_csv(dpaths[0], sep='\t', index_col=False)

        except Exception as e:
            print(e.args)
            exit()
    
    if training_on_weak[0]:
        if training_on_weak[1] == 'balanced_original':
            data['train'], _ = weak_data_merge(merge_type=training_on_weak[1])
        elif training_on_weak[1] == 'weak_only':
            data['train'], data['val'] = weak_data_merge(merge_type=training_on_weak[1], weak_frac=training_on_weak[2])
        else:
            data['train'], data['val'] = weak_data_merge(merge_type=training_on_weak[1])

Datasets and DataLoaders, takes trial as input to be able to suggest values for variables.

In [5]:
def get_loaders(batch_size, transforms_params=None, stopwords_type=None):
    global train_uw_ratio, slf_dim

    transform_pipeline = None

    if transforms_params:
        transforms_map = {
            'sum': Sum,
            'onehot': OneHot,
            'none': NoTransform
        }
        cw_map = {
            'count_words': CountWords,
            'none': NoTransform
        }

        from_sel = transforms_params['from_selection']
  
        pos_feat = transforms_map[transforms_params['pos_feature_type']]
        pos_feat = pos_feat(
            'pos', from_selection=from_sel, stopwords=stopwords_type
        )

        tag_feat = transforms_map[transforms_params['tag_feature_type']]
        tag_feat = tag_feat(
            'tag', from_selection=from_sel, stopwords=stopwords_type
        )
        
        dep_feat = transforms_map[transforms_params['dep_feature_type']]
        dep_feat = dep_feat(
            'dep', from_selection=from_sel, stopwords=stopwords_type
        )

        cw_feat = cw_map[transforms_params['word_count_feature_type']]
        cw_feat = cw_feat()
        
        transform_pipeline = transforms.Compose([
            HandleStopwords(stopwords=stopwords_type),
            pos_feat,
            tag_feat,
            dep_feat,
            cw_feat,
            ToBinary(6),
            ToTensor()
        ])
    print(transform_pipeline)

    train = data['train']
    worthy_train = train[train['label'] == 1]
    train_uw_ratio = (len(train) - len(worthy_train)) / len(worthy_train)
    
    train_dd = DebatesDataset(data=data['train'], transform=transform_pipeline)
    val_dd = DebatesDataset(data=data['val'], transform=transform_pipeline)
    test_dd = DebatesDataset(data=data['test'], transform=transform_pipeline)
    
    if transforms_params:
        slf_dim = train_dd[0][-1].size()[0]

    train_loader = DataLoader(train_dd, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dd, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dd, batch_size=batch_size, shuffle=True, drop_last=True)

    return train_loader, val_loader, test_loader

Model setup + training loop

In [6]:
def train_model(params, features_params=None, model_checkpoint_path=None, is_finetuning=False, bam=''):
    global logf_path
    # this is here so that it can be accessed here and in get_loaders()
    stopwords_type = None
    if features_params:
        stopwords_type = features_params['stopwords_type'] if 'stopwords_type' in features_params else None

    train_loader, val_loader, test_loader = get_loaders(
        params['batch_size'],
        transforms_params=features_params,
        stopwords_type=stopwords_type
    )

    # best for given trial
    pooling_strategy = params['pooling_strategy']
    dropout = params['dropout']
    hidden_dim = params['hidden_dim']
    w_seq = params['with_sequential_layer']
    w_att = params['w_att']
    lr = params['learning_rate']
    opt_weight_decay = params['optimizer_weigth_decay']
    pos_weight = train_uw_ratio if params['pos_weight'] > 1.0 else 1.0
    
#     fnn_hidden_dim = params['fnn_hidden_dim']
#     fnn_n_layers = params['fnn_n_hidden_layers']
#     fnn_dropout = params['fnn_dropout']
    
#     pnn_hidden_dim = params['pnn_hidden_dim']
#     pnn_dropout = params['pnn_dropout']
    
    word_level_feat = features_params['word_level_feature_type'] if features_params else 'none'
#     word_level_feat = 'none'
   
    # TODO: test these as well
    # remove_stopwords = stopwords_type != 'wstop'
    # dep_feat = trial.suggest_categorical('word_level_dep_features', [True, False])
    # triplet_feat = trial.suggest_categorical('word_level_triplet_features', [True, False])
#     word_level_feat = trial.suggest_categorical('word_level_feature_type', ['dep', 'triplet'])
    embedding_model = BertEmbeddingModel(
        device=device,
        pooling_strat=pooling_strategy,
        scale=params['scale'],
        dep_features=word_level_feat == 'dep',
        triplet_features=word_level_feat == 'triplet',
#         remove_stopwords=stopwords_type == 'wostop'
    )

    model = BiLSTM(
        dropout=dropout,
        hidden_dim=hidden_dim,
        embedding_dim=embedding_model.dim,
        sent_level_feature_dim=slf_dim,
        device=device,
        w_seq=w_seq,
        w_att=w_att,
    ).to(device)    

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=opt_weight_decay)

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]).to(device))
#     criterion = nn.SmoothL1Loss(reduction='sum')
    if is_finetuning:
        load_checkpoint(model_checkpoint_path, model, optimizer, device, bam=bam)
    
    early_stopping = EarlyStopping(
        patience=10,
        path=model_checkpoint_path,
        verbose=False,
        trace_func=print_n_log.run('early_stopping', logf_path, 'DEBUG')
    )

    n_epochs = 30
    threshold = 0.5
    val_losses, train_losses, val_clf_reports, train_clf_reports = [], [], [], []

    for epoch in range(n_epochs):
        epoch_val_losses, epoch_train_losses = [], []

        model.train()
        y_pred, y_true = [], []
        for ids, sentences, labels, features in train_loader:
            labels = labels.float().to(device)
            features = features.to(device)
            
            embeddings, lengths = embedding_model(sentences)
            output = model(embeddings, lengths.cpu(), sent_level_features=features)
            loss = criterion(output, labels)
            
            loss.backward()
            epoch_train_losses.append(loss.item())
            
            pred = torch.sigmoid(output)
            pred = (pred > threshold).int()
            y_pred.extend(pred.tolist())
            y_true.extend(labels.tolist())

            optimizer.step()
            optimizer.zero_grad()
            
        cr = classification_report(y_true, y_pred, digits=6, output_dict=True, zero_division=0)
        train_clf_reports.append(cr)

        model.eval()
        y_pred, y_true = [], []
        with torch.no_grad():
            for val_ids, val_sentences, val_labels, val_features in val_loader:
                val_labels = val_labels.float().to(device)
                val_features = val_features.to(device)
                
                val_embeddings, val_lengths = embedding_model(val_sentences)
                pred = model(val_embeddings, val_lengths.cpu(), sent_level_features=val_features)
                val_loss = criterion(pred, val_labels)
                epoch_val_losses.append(val_loss.item())
                
                pred = torch.sigmoid(pred)
                
                pred = (pred > threshold).int()
                y_pred.extend(pred.tolist())
                y_true.extend(val_labels.tolist())
        
        val_losses.append(np.average(epoch_val_losses))
        train_losses.append(np.average(epoch_train_losses))
        avg_val_loss = np.average(val_losses)
        print(
            'epoch ==> ', epoch,
            ' | avg train loss ==> ', np.average(train_losses),
            ' | avg val loss ==> ', avg_val_loss
        )
        print(classification_report(y_true, y_pred, digits=6))
        cr = classification_report(y_true, y_pred, digits=6, output_dict=True, zero_division=0)
        val_clf_reports.append(cr)
        
        early_stopping(
            val_loss=avg_val_loss,
            model=model,
            optimizer=optimizer,
            train_losses=train_losses,
            val_losses=val_losses,
            train_clf_reports=train_clf_reports,
            val_clf_reports=val_clf_reports,
            acomp_metrics=('f1_p', cr['1.0']['f1-score'])
        )
        
        if early_stopping.early_stop:
            print('early stopping...')
            break

    # recall_p = early_stopping.acomp_metrics['recall_p'] if early_stopping.acomp_metrics else 0.0
    "Done."
#     return model, 

In [7]:
def evaluate_model(params, features_params=None, load_path=None, bam=''):
    stopwords_type = None
    if features_params is not None:
        stopwords_type = features_params['stopwords_type'] if 'stopwords_type' in features_params else None

    train_loader, val_loader, test_loader = get_loaders(
        params['batch_size'],
        transforms_params=features_params,
        stopwords_type=stopwords_type
    )

    # best for given trial
    pooling_strategy = params['pooling_strategy']
    dropout = params['dropout']
    hidden_dim = params['hidden_dim']
    w_seq = params['with_sequential_layer']
    w_att = params['w_att']
    lr = params['learning_rate']
    opt_weight_decay = params['optimizer_weigth_decay']
    pos_weight = train_uw_ratio if params['pos_weight'] > 1.0 else 1.0
    
#     fnn_hidden_dim = params['fnn_hidden_dim']
#     fnn_n_layers = params['fnn_n_hidden_layers']
#     fnn_dropout = params['fnn_dropout']
    
#     pnn_hidden_dim = params['pnn_hidden_dim']
#     pnn_dropout = params['pnn_dropout']
    
    word_level_feat = features_params['word_level_feature_type'] if features_params else 'none'
#     word_level_feat = 'none'
    embedding_model = BertEmbeddingModel(
        device=device,
        pooling_strat=pooling_strategy,
        scale=params['scale'],
        dep_features=word_level_feat == 'dep',
        triplet_features=word_level_feat == 'triplet',
#         remove_stopwords=stopwords_type == 'wostop'
    )
        
    model = BiLSTM(
        dropout=dropout,
        hidden_dim=hidden_dim,
        embedding_dim=embedding_model.dim,
        sent_level_feature_dim=slf_dim,
        device=device,
        w_seq=w_seq,
        w_att=w_att,
    ).to(device)
    
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=opt_weight_decay)
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]).to(device))
#     criterion = nn.SmoothL1Loss(reduction='sum')
    
    load_checkpoint(load_path, model, optimizer, device, bam=bam)
    
    threshold = 0.5
    y_pred = []
    y_true = []
    scores = []
    ids = []

    model.eval()
    with torch.no_grad():
        for test_ids, test_sentences, test_labels, test_features in test_loader:           
            test_labels = test_labels.float().to(device)
            test_features = test_features.to(device)

            embeddings, lengths = embedding_model(test_sentences)
            output = torch.sigmoid(model(embeddings, lengths.cpu(), sent_level_features=test_features))
            
            ids.extend(test_ids.tolist())
            scores.extend(output.tolist())
            output = (output > threshold).int()
            y_pred.extend(output.tolist())
            y_true.extend(test_labels.tolist())


    predictions = list(zip(ids, scores))
    predictions = sorted(
        predictions, 
        key=lambda x: x[0]
    )
    _, _, avg_precision, rr, num_relevant = evaluate_v2(predictions)
    print('Avg. precision: ', avg_precision)
    print('Classification Report:')
    print(classification_report(y_true, y_pred, digits=4)) #

In [8]:
# load_data()

In [9]:
# study_path = os.path.join(optim_path, 'bi-lstm_wAtt_sTPE_pNone_df0.2_wf0.2.pkl')
# params_path = os.path.join(optim_path, 'bi-lstm_featOptim_wAtt_sTPE_pNone_df0.2_wf0.2_params.pkl')
# model_checkpoint_path = os.path.join(training_path, 'wAtt_sTPE')
model_optim_path = os.path.join(optim_path, 'bi-lstm_NO_FEAT_mF1_wAtt_sTPE_pNone_df0.2_wf0.03')
# feature_params_path = f'{model_optim_path}_featureParams.pkl'
# params_path = f'{model_optim_path}_params.pkl'
study_path = f'{model_optim_path}.pkl'


# for now ignore features 
# studies = [s for s in studies if '_params' not in s and 'featOptim' not in s]
is_training = True
is_finetuning = True
training_on_weak = (True, 'weak_only', 0.75)
load_data()
# bam = ''
bam = 'best_f1_p_'
# for study_name in studies[:1]:
# TODO: check whether all of these exist
study = torch.load(os.path.join(optim_path, study_path))
params = study.best_params
# params['w_att'] = False
params['with_sequential_layer'] = False
print(params)
# features_params = torch.load(feature_params_path)
features_params = None
# features_params = { 'stopwords_type': 'wostop', 'from_selection': True, 'dep_feature_type': 'onehot', 'word_count_feature_type': 'count_words' }
# features_params = {'stopwords_type': 'wstop', 'word_count_feature_type': 'count_words', 'word_level_feature_type': 'dep', 'from_selection': False}
print(features_params)

# add this manually
# features_params['word_level_feature_type'] = 'dep'
# features_params['stopwords_type'] = 'wstop'
# del features_params['dep_feature_type']
# TODO: este raz spustit word --> { 'stopwords_type': 'wstop', 'word_count_feature_type': 'count_words', 'word_level_feature_type': 'dep' }
checkpoint_dir = 'no_feat_no_seq_pre_train_075'
#TODO: tu este chyba w_feat_sent
# checkpoint_dir = [d for d in models_directories if d in study_name][0]
# v2 params: {'stopwords_type': 'wstop', 'from_selection': False, 'dep_feature_type': 'sum', 'word_count_feature_type': 'count_words'}
study_log_name = 'bi-lstm_noFeat_noFeatNoSeqPreTrain075'
# params = study.best_params
logf_path = os.path.join(LOG_DIR_PATH, f'training_{study_log_name}.log')

if features_params is not None:
    for ft in ['pos', 'tag', 'dep', 'word_count', 'word_level']:
        feature_type = f'{ft}_feature_type'
        if feature_type not in features_params:
            features_params[feature_type] = 'none'
            
for p in ['w_att', 'scale', 'with_sequential_layer']:
    if p not in params:
        params[p] = False

if isinstance(params['pos_weight'], bool):
    params['pos_weight'] = 2.0 if params['pos_weight'] else 1.0

checkpoint_path = os.path.join(training_path, checkpoint_dir)
if is_training:
    train_model(params, features_params, checkpoint_path)
    if is_finetuning:
        training_on_weak = (False, 'weak_only')
        load_data()
        tr_bam = 'best_f1_p_' if os.path.isfile(os.path.join(checkpoint_path, 'best_f1_p_checkpoint.pt')) else ''
        train_model(params, features_params, checkpoint_path, is_finetuning, tr_bam)
else:
    evaluate_model(params, features_params, checkpoint_path, bam=bam)

{'batch_size': 16, 'pooling_strategy': 'last_four', 'dropout': 0.46, 'hidden_dim': 256, 'with_sequential_layer': False, 'learning_rate': 7.534940239689385e-05, 'optimizer_weigth_decay': 6.948459981308478e-06, 'pos_weight': 1.0}
None
None


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


epoch ==>  0  | avg train loss ==>  0.19477999946268942  | avg val loss ==>  0.1760047251370684
              precision    recall  f1-score   support

         0.0   0.960813  0.939332  0.949951     19681
         1.0   0.782077  0.850367  0.814794      5039

    accuracy                       0.921197     24720
   macro avg   0.871445  0.894850  0.882372     24720
weighted avg   0.924379  0.921197  0.922400     24720

epoch ==>  1  | avg train loss ==>  0.17363981063257158  | avg val loss ==>  0.1673522035594318
              precision    recall  f1-score   support

         0.0   0.951119  0.962861  0.956954     19683
         1.0   0.847518  0.806631  0.826569      5037

    accuracy                       0.931028     24720
   macro avg   0.899318  0.884746  0.891762     24720
weighted avg   0.930009  0.931028  0.930387     24720

epoch ==>  2  | avg train loss ==>  0.1563355647024466  | avg val loss ==>  0.16677611665194902
              precision    recall  f1-score   support

   

INFO : Early stopping with best value: 0.16677611665194902 and acompanying metrics: ('f1_p', 0.8285543608124254)


epoch ==>  12  | avg train loss ==>  0.05827413506577582  | avg val loss ==>  0.266263726064708
              precision    recall  f1-score   support

         0.0   0.946271  0.970788  0.958373     19684
         1.0   0.872956  0.784551  0.826396      5036

    accuracy                       0.932848     24720
   macro avg   0.909614  0.877670  0.892385     24720
weighted avg   0.931335  0.932848  0.931486     24720

early stopping...
None


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model loaded from <== /home/jovyan/sharedstorage/s12b3v/dp/dt/exp/bi-lstm/training/no_feat_no_seq_pre_train_075/best_f1_p_checkpoint.pt
epoch ==>  0  | avg train loss ==>  0.12187815043916836  | avg val loss ==>  0.10187296854493058
              precision    recall  f1-score   support

         0.0   0.974016  0.999736  0.986708      3787
         1.0   0.000000  0.000000  0.000000       101

    accuracy                       0.973765      3888
   macro avg   0.487008  0.499868  0.493354      3888
weighted avg   0.948714  0.973765  0.961076      3888

epoch ==>  1  | avg train loss ==>  0.10886922065648971  | avg val loss ==>  0.10243802377379235
              precision    recall  f1-score   support

         0.0   0.974253  0.999472  0.986701      3786
         1.0   0.500000  0.019608  0.037736       102

    accuracy                       0.973765      3888
   macro avg   0.737127  0.509540  0.512219      3888
weighted avg   0.961812  0.973765  0.961806      3888

epoch ==>  2  | 

INFO : Early stopping with best value: 0.10187296854493058 and acompanying metrics: ('f1_p', 0.0)


epoch ==>  10  | avg train loss ==>  0.03748836167160428  | avg val loss ==>  0.15310698551627178
              precision    recall  f1-score   support

         0.0   0.976841  0.991548  0.984139      3786
         1.0   0.288889  0.127451  0.176871       102

    accuracy                       0.968879      3888
   macro avg   0.632865  0.559499  0.580505      3888
weighted avg   0.958793  0.968879  0.962961      3888

early stopping...


### results
- ***no_feat:***
    - **last checkpoint:** 
    ```
        Avg. precision:  0.05039179434847546
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9948    0.5435    0.7029      6328
                 1.0     0.0392    0.8676    0.0751       136

            accuracy                         0.5503      6464
           macro avg     0.5170    0.7056    0.3890      6464
        weighted avg     0.9747    0.5503    0.6897      6464
    ```
    - **best f1 checkpoint:**
    ```
        Avg. precision:  0.05696354972124724
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9986    0.3434    0.5111      6328
                 1.0     0.0310    0.9779    0.0601       136

            accuracy                         0.3567      6464
           macro avg     0.5148    0.6607    0.2856      6464
        weighted avg     0.9783    0.3567    0.5016      6464
    ```
    - **last checkpoint - weak_simple:**
    ```
        Avg. precision:  0.05155608847730918
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9974    0.4850    0.6526      6328
                 1.0     0.0378    0.9412    0.0727       136

            accuracy                         0.4946      6464
           macro avg     0.5176    0.7131    0.3626      6464
        weighted avg     0.9772    0.4946    0.6404      6464
    ```
    - **best f1 checkpoint - weak_simple:**
    ```
        Avg. precision:  0.052806869371306464
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9970    0.5240    0.6870      6328
                 1.0     0.0402    0.9265    0.0770       136

            accuracy                         0.5325      6464
           macro avg     0.5186    0.7252    0.3820      6464
        weighted avg     0.9769    0.5325    0.6741      6464
    ```
    - **last checkpoint - weak_balanced_result:**
    ```
        Avg. precision:  0.039851885803194445
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9823    0.8938    0.9360      6328
                 1.0     0.0482    0.2500    0.0808       136

            accuracy                         0.8803      6464
           macro avg     0.5152    0.5719    0.5084      6464
        weighted avg     0.9626    0.8803    0.9180      6464
    ```
    - **best f1 checkpoint - weak_balanced_result:**
    ```
        Avg. precision:  0.036280559917958234
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9835    0.8210    0.8949      6328
                 1.0     0.0415    0.3603    0.0744       136

            accuracy                         0.8113      6464
           macro avg     0.5125    0.5906    0.4846      6464
        weighted avg     0.9637    0.8113    0.8777      6464
    ```
    - **last checkpoint - weak_balanced_original:**
    ```
        Avg. precision:  0.04849566628376427
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9793    0.9960    0.9876      6328
                 1.0     0.1071    0.0221    0.0366       136

            accuracy                         0.9756      6464
           macro avg     0.5432    0.5091    0.5121      6464
        weighted avg     0.9610    0.9756    0.9676      6464
    ```
    - **best f1 checkpoint - weak_balanced_original:**
    ```
        --
    ```
    - **last checkpoint - no_att:**
    ```
        Avg. precision:  0.09279971384085937
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9801    0.9954    0.9877      6328
                 1.0     0.2162    0.0588    0.0925       136

            accuracy                         0.9757      6464
           macro avg     0.5982    0.5271    0.5401      6464
        weighted avg     0.9640    0.9757    0.9689      6464
    ```
    - **best f1 checkpoint - no_att:**
    ```
        Avg. precision:  0.09880078753917501
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9809    0.9910    0.9859      6328
                 1.0     0.1972    0.1029    0.1353       136

            accuracy                         0.9723      6464
           macro avg     0.5890    0.5470    0.5606      6464
        weighted avg     0.9644    0.9723    0.9680      6464
    ```
    - **last checkpoint - no_att_no_seq:**
    ```
        Avg. precision:  0.13303394554800413
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9793    0.9997    0.9894      6329
                 1.0     0.3333    0.0074    0.0145       135

            accuracy                         0.9790      6464
           macro avg     0.6563    0.5035    0.5019      6464
        weighted avg     0.9658    0.9790    0.9690      6464
    ```
    - **last checkpoint - no_feat_no_seq_pre_train_0.25:**
    ```
        Avg. precision:  0.11418210276223052
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9792    0.9975    0.9883      6328
                 1.0     0.1111    0.0147    0.0260       136

            accuracy                         0.9768      6464
           macro avg     0.5452    0.5061    0.5071      6464
        weighted avg     0.9609    0.9768    0.9680      6464
    ```
    - **best f1 checkpoint - no_feat_no_seq_pre_train_0.25:**
    ```
        Avg. precision:  0.11207229707910493
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9799    0.9957    0.9878      6329
                 1.0     0.1818    0.0444    0.0714       135

            accuracy                         0.9759      6464
           macro avg     0.5809    0.5201    0.5296      6464
        weighted avg     0.9633    0.9759    0.9686      6464
    ```
    - **last checkpoint - no_feat_no_seq_pre_train_0.5**
    ```
        Avg. precision:  0.11531658383033407
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9789    0.9987    0.9887      6328
                 1.0     0.0000    0.0000    0.0000       136

            accuracy                         0.9777      6464
           macro avg     0.4895    0.4994    0.4944      6464
        weighted avg     0.9583    0.9777    0.9679      6464
    ```
    - **best f1 checkpoint - no_feat_no_seq_pre_train_0.5**
    ```
        Avg. precision:  0.11337885214277875
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9792    0.9975    0.9883      6328
                 1.0     0.1111    0.0147    0.0260       136

            accuracy                         0.9768      6464
           macro avg     0.5452    0.5061    0.5071      6464
        weighted avg     0.9609    0.9768    0.9680      6464
    ```
    
- ***sent_feat:***
    - **last checkpoint:**
    ```
        Avg. precision:  0.02627628114627844
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9833    0.3347    0.4994      6329
                 1.0     0.0230    0.7333    0.0445       135

            accuracy                         0.3430      6464
           macro avg     0.5031    0.5340    0.2719      6464
        weighted avg     0.9632    0.3430    0.4899      6464
    ```
    - **best f1 checkpoint:**
    ```
        Avg. precision:  0.04701524372943559
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9905    0.3310    0.4962      6329
                 1.0     0.0264    0.8519    0.0513       135

            accuracy                         0.3419      6464
           macro avg     0.5085    0.5914    0.2738      6464
        weighted avg     0.9704    0.3419    0.4869      6464
    ```
    - **last checkpoint - no_att_no_seq:**
    ```
        Avg. precision:  0.1471588030651484
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9790    1.0000    0.9894      6328
                 1.0     0.0000    0.0000    0.0000       136

            accuracy                         0.9790      6464
           macro avg     0.4895    0.5000    0.4947      6464
        weighted avg     0.9584    0.9790    0.9686      6464
    ```
    - **best f1 checkpoint - no_att_no_seq:**
    ```
        Avg. precision:  0.14129833427615582
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9798    0.9987    0.9892      6329
                 1.0     0.3846    0.0370    0.0676       135

            accuracy                         0.9787      6464
           macro avg     0.6822    0.5179    0.5284      6464
        weighted avg     0.9674    0.9787    0.9700      6464
    ```
    

- ***word_feat:***
    - **last checkpoint:**
    ```
        Avg. precision:  0.05579242951469682
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9881    0.4718    0.6386      6329
                 1.0     0.0288    0.7333    0.0554       135

            accuracy                         0.4773      6464
           macro avg     0.5084    0.6026    0.3470      6464
        weighted avg     0.9681    0.4773    0.6265      6464
    ```
    - **best f1 checkpoint:**
    ```
        Avg. precision:  0.06383728548256118
        Classification Report:
                      precision    recall  f1-score   support

                 0.0     0.9869    0.6533    0.7862      6329
                 1.0     0.0352    0.5926    0.0664       135

            accuracy                         0.6521      6464
           macro avg     0.5110    0.6230    0.4263      6464
        weighted avg     0.9670    0.6521    0.7712      6464
    ```

In [10]:
# TODO: this is kept here just in case
# optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))

# needed for GridSampler
# search_space = {
#     'batch_size': [16, 32, 64],
#     'pooling_strategy': ['last_four', 'last_four_sum', 'second_last'],
# #     'should_scale_emb': [False, True],
#     'dropout': [i/100 for i in range(0, 51, 5)],
#     'hidden_dim': [128, 256, 512],
#     'optimizer_weigth_decay': [i/10000 for i in range(11)],
#     'learning_rate': round_to_first_non_zero([i/100000 for i in range_inc(0, 100000, 1, 10)]),
#     'pos_weight': [1.0, train_uw_ratio]
# }
# feature_search_space = {
#     'stopwords_type': ['wstop', 'wostop'],
#     'from_selection': [True, False],
#     'pos_feature_type': ['sum', 'onehot', 'none'],
#     'tag_feature_type': ['sum', 'onehot', 'none'],
#     'word_count_feature_type': ['count_words', 'none'],
# #     'word_level_feature_type': ['dep', 'triplet']
# }
# # print(search_space)
# params = {
#     'batch_size': 32,
#     'pooling_strategy': 'second_last',
#     'dropout': 0.39,
#     'hidden_dim': 256,
#     'w_seq': True,
#     'lr': 0.004118121,
#     'opt_weight_decay': 0.024460049,
#     'pos_weight': train_uw_ratio,
# }