In [1]:
import sys
import logging
import os.path as p
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import optuna
import print_n_log

from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data.dataloader import DataLoader
from definitions import *
from model_helper_functions import *
from dataset_helper_functions import *
from sent_nn import SentNN
from sentence_transformers import SentenceTransformer
from debates_dataset import DebatesDataset
from early_stopping import EarlyStopping
from optuna.trial import TrialState
from torchvision import transforms
# my transforms
from transforms import *

In [2]:
data = {}
optim_path = os.path.join(EXP_DIR_PATH, 'sent-nn', 'optimization')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_uw_ratio = 0
dataset_frac = 0.1
worthy_frac = 0.2
slf_dim = 0
rs = 22

Function for loading data.

In [3]:
def load_data():
    dev_path = p.join(PROC_DATA_DIR_PATH, 'dev')

    data_paths = {
        'dev': [
            p.join(dev_path, 'dev.tsv'),
        ],
        'test': [
            p.join(POLIT_DATA_DIR_PATH, 'test', 'test_combined.tsv'),
        ],
        'train': [
            p.join(POLIT_DATA_DIR_PATH, 'train', 'train_combined.tsv'),
        ],
        'val': [
            p.join(POLIT_DATA_DIR_PATH, 'val', 'val_combined.tsv'),
        ],
    }

    for dtype, dpaths in data_paths.items():
        try:
            data[dtype] = pd.read_csv(dpaths[0], sep='\t', index_col=False)

        except Exception as e:
            print(e.args)
            exit()

Datasets and DataLoaders, takes trial as input to be able to suggest values for variables.

In [4]:
def get_loaders(trial, stopwords_type):
    global train_uw_ratio, slf_dim
    # dev_df, test_df, train_df, val_df = data.values()
    subsets = {}
    for k, df in data.items():

        n_subset = int(len(df)*dataset_frac)

        worthy_df = df.loc[df['label'] == 1]
        n_worthy = int(min(n_subset*worthy_frac, len(worthy_df)))
        worthy_df = worthy_df.sample(n=n_worthy, random_state=rs)

        unworthy_df = df.loc[df['label'] == 0].sample(
            n=n_subset-n_worthy,
            random_state=rs
        )
        if k == 'train':
            train_uw_ratio = len(unworthy_df) / len(worthy_df)
        # sample(frac=1.0) -> shuffle
        subsets[k] = worthy_df.append(unworthy_df).sample(frac=1.0, random_state=rs, ignore_index=True)

    # TODO: for sentence level feature optimization    
    transforms_map = {
        'sum': Sum,
        'onehot': OneHot,
        'none': NoTransform
    }
    transforms_options = list(transforms_map.keys())
    cw_map = {
        'count_words': CountWords,
        'none': NoTransform
    }
    cw_options = list(cw_map.keys())

    from_sel = trial.suggest_categorical('from_selection', [True, False])

    # trial.suggest_categorical returns one of the keys of transforms_map, which then return Transform or None
    # if transform --> initialize
    pos_feat = transforms_map[trial.suggest_categorical('pos_feature_type', transforms_options)]
    pos_feat = pos_feat(
        'pos', from_selection=from_sel, stopwords=stopwords_type
    )

    tag_feat = transforms_map[trial.suggest_categorical('tag_feature_type', transforms_options)]
    tag_feat = tag_feat(
        'tag', from_selection=from_sel, stopwords=stopwords_type
    )

    cw_feat = cw_map[trial.suggest_categorical('word_count_feature_type', cw_options)]
    cw_feat = cw_feat()
    
    transform_pipeline = transforms.Compose([
        HandleStopwords(stopwords=stopwords_type),
        pos_feat,
        tag_feat,
        cw_feat,
        ToBinary(6),
        ToTensor()
    ])
    # transform_pipeline = None
#     transform_pipeline = None

    train_dd = DebatesDataset(data=subsets['train'], transform=transform_pipeline)
    val_dd = DebatesDataset(data=subsets['val'], transform=transform_pipeline)
    test_dd = DebatesDataset(data=subsets['test'], transform=transform_pipeline)
    
    # slf_dim = train_dd[0][-1].size()[0]
#     batch_size = trial.suggest_categorical('batch_size', [16, 32, 64, 128])
    batch_size = 16
    train_loader = DataLoader(train_dd, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dd, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dd, batch_size=batch_size, shuffle=True, drop_last=True)

    return train_loader, val_loader, test_loader

Model setup + training loop

In [5]:
def objective(trial):
    global logf_path
    # this is here so that it can be accessed here and in get_loaders()
    stopwords_type = trial.suggest_categorical('stopwords_type', ['wstop', 'wostop'])
    # unused is test_loader    
    train_loader, val_loader, _ = get_loaders(trial, stopwords_type)
    
    # hyperparams opt
#     dropout = trial.suggest_float('dropout', 0.0, 0.5, step=0.01)
#     hidden_dim = trial.suggest_categorical('hidden_dim', [128, 256, 512])
#     w_seq = trial.suggest_categorical('with_sequential_layer', [True, False])
#     lr = trial.suggest_float('learning_rate', 1e-5, 1e-1, log=True)
#     opt_weight_decay = trial.suggest_float('optimizer_weigth_decay', 1e-6, 0.1, log=True)
#     pos_weight = trial.suggest_categorical('pos_weight', [1.0, train_uw_ratio])
#     emb_model_name = trial.suggest_categorical(
#         'embedding_model_name',
#         ['all-mpnet-base-v2', 'all-MiniLM-L6-v2', 'multi-qa-mpnet-base-dot-v1']
#     )
    
    # temp_best
    dropout = 0.09
    hidden_dim = 128
    w_seq = False
    lr = 0.03698629814522988
    opt_weight_decay = 0.02355650972967366
    pos_weight = train_uw_ratio
    emb_model_name = 'all-mpnet-base-v2'

    
    emb_size_map = {
        'all-mpnet-base-v2': 768,
        'all-MiniLM-L6-v2': 384,
        'multi-qa-mpnet-base-dot-v1': 768
    }
    # emb_model_name = 'all-MiniLM-L6-v2'
    embedding_model = SentenceTransformer(emb_model_name, device=device, cache_folder=SBERT_MODEL_PATH)
    
    model = SentNN(
        embeddings_dim=emb_size_map[emb_model_name],
        sentence_level_feature_dim=slf_dim,
        dropout=dropout,
        w_seq=w_seq
    ).to(device)    

    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=opt_weight_decay)

    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight]).to(device))

    n_epochs = 16
    threshold = 0.5
    early_stopping = EarlyStopping(
        patience=5,
        path=None,
        verbose=False,
        trace_func=print_n_log.run('early_stopping', logf_path, 'DEBUG')
    )
    # training
    for epoch in range(n_epochs):
        losses, val_losses = [], []
        model.train()
        for ids, sentences, labels, features in train_loader:
            labels = labels.float().to(device)

            optimizer.zero_grad()

            embeddings = embedding_model.encode(sentences, convert_to_tensor=True)
            output = model(embeddings)
            loss = criterion(output, labels)
            losses.append(loss.item())
            loss.backward()
            optimizer.step()

        model.eval()
        y_pred, y_true = [], []
        with torch.no_grad():
            for val_ids, val_sentences, val_labels, val_features in val_loader:
                val_labels = val_labels.float().to(device)

                val_embeddings = embedding_model.encode(val_sentences, convert_to_tensor=True)
                pred = model(val_embeddings)
                loss = criterion(pred, val_labels)
                val_losses.append(loss.item())

                pred = torch.sigmoid(pred)

                pred = (pred > threshold).int()
                y_pred.extend(pred.tolist())
                y_true.extend(val_labels.tolist())

        # print(f'train_loss: {np.average(losses)} | val_loss: {np.average(val_losses)}')
        cr = classification_report(y_true, y_pred, output_dict=True, digits=6, zero_division=0)
        # recall_p = cr['1.0']['recall']
        # print(recall_p)
        val_loss = np.average(val_losses)
        early_stopping(val_loss, model, acomp_metrics={'recall_p': cr['1.0']['recall']})
        
        if early_stopping.early_stop:
            break
        # trial.report(recall_p, epoch)

        # # Handle pruning based on the intermediate value.
        # if trial.should_prune():
        #     raise optuna.exceptions.TrialPruned()
    recall_p = early_stopping.acomp_metrics['recall_p'] if early_stopping.acomp_metrics else 0.0
    "Done."
    return recall_p

In [6]:
load_data()
# print('final recall: ', objective(None))

In [None]:
# # optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler(sys.stdout))
feature_search_space = {
    'stopwords_type': ['wstop', 'wostop'],
    'from_selection': [True, False],
    'pos_feature_type': ['sum', 'onehot', 'none'],
    'tag_feature_type': ['sum', 'onehot', 'none'],
    'word_count_feature_type': ['count_words', 'none'],
#     'word_level_feature_type': ['dep', 'triplet']
}
params = {
    'batch_size': 16,
    'dropout': 0.09,
    'hidden_dim': 128,
    'with_sequential_layer': False,
    'learning_rate': 0.03698629814522988,
    'optimizer_weigth_decay': 0.02355650972967366,
    'pos_weight': 4.0042918454935625,
    'embedding_model_name': 'all-mpnet-base-v2'
}
study = optuna.create_study(
    study_name=f'sent_nn_featOptim_sGrid_pNone_df{dataset_frac}_wf{worthy_frac}',
#     sampler=optuna.samplers.TPESampler(),
    sampler=optuna.samplers.GridSampler(feature_search_space),
    # pruner=optuna.pruners.MedianPruner(),
    direction='maximize'
)
logf_path = p.join(LOG_DIR_PATH, f'{study.study_name}.log')
study.optimize(objective, n_trials=150)

study_path = os.path.join(optim_path, f'{study.study_name}.pkl')
torch.save(study, study_path)
torch.save(params, f'{os.path.join(optim_path, study.study_name)}_params.pkl')

[32m[I 2022-03-29 20:58:39,193][0m A new study created in memory with name: sent_nn_featOptim_sGrid_pNone_df0.1_wf0.2[0m
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:03:08,857][0m Trial 0 finished with value: 0.72 and parameters: {'stopwords_type': 'wstop', 'from_selection': False, 'pos_feature_type': 'none', 'tag_feature_type': 'none', 'word_count_feature_type': 'none'}. Best is trial 0 with value: 0.72.[0m


4


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:05:38,961][0m Trial 1 finished with value: 0.8571428571428571 and parameters: {'stopwords_type': 'wstop', 'from_selection': True, 'pos_feature_type': 'onehot', 'tag_feature_type': 'sum', 'word_count_feature_type': 'none'}. Best is trial 1 with value: 0.8571428571428571.[0m


19
49


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:07:51,880][0m Trial 2 finished with value: 0.6578947368421053 and parameters: {'stopwords_type': 'wstop', 'from_selection': False, 'pos_feature_type': 'onehot', 'tag_feature_type': 'onehot', 'word_count_feature_type': 'none'}. Best is trial 1 with value: 0.8571428571428571.[0m


2


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:09:59,076][0m Trial 3 finished with value: 0.9333333333333333 and parameters: {'stopwords_type': 'wostop', 'from_selection': True, 'pos_feature_type': 'onehot', 'tag_feature_type': 'sum', 'word_count_feature_type': 'none'}. Best is trial 3 with value: 0.9333333333333333.[0m


6


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:14:06,833][0m Trial 4 finished with value: 0.6533333333333333 and parameters: {'stopwords_type': 'wstop', 'from_selection': True, 'pos_feature_type': 'sum', 'tag_feature_type': 'onehot', 'word_count_feature_type': 'none'}. Best is trial 3 with value: 0.9333333333333333.[0m


49


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:19:06,292][0m Trial 5 finished with value: 0.7866666666666666 and parameters: {'stopwords_type': 'wstop', 'from_selection': False, 'pos_feature_type': 'sum', 'tag_feature_type': 'onehot', 'word_count_feature_type': 'none'}. Best is trial 3 with value: 0.9333333333333333.[0m
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:22:07,254][0m Trial 6 finished with value: 0.7792207792207793 and parameters: {'stopwords_type': 'wostop', 'from_selection': True, 'pos_feature_type': 'sum', 'tag_feature_type': 'none', 'word_count_feature_type': 'count_words'}. Best is trial 3 with value: 0.9333333333333333.[0m


6


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:24:58,874][0m Trial 7 finished with value: 0.8533333333333334 and parameters: {'stopwords_type': 'wstop', 'from_selection': True, 'pos_feature_type': 'none', 'tag_feature_type': 'onehot', 'word_count_feature_type': 'count_words'}. Best is trial 3 with value: 0.9333333333333333.[0m
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:28:26,009][0m Trial 8 finished with value: 0.6710526315789473 and parameters: {'stopwords_type': 'wstop', 'from_selection': True, 'pos_feature_type': 'none', 'tag_feature_type': 'none', 'word_count_feature_type': 'count_words'}. Best is trial 3 with value: 0.9333333333333333.[0m
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:31:54,509][0m

2
2


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:35:01,747][0m Trial 10 finished with value: 0.6493506493506493 and parameters: {'stopwords_type': 'wostop', 'from_selection': True, 'pos_feature_type': 'onehot', 'tag_feature_type': 'onehot', 'word_count_feature_type': 'none'}. Best is trial 3 with value: 0.9333333333333333.[0m


19
49


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:37:39,179][0m Trial 11 finished with value: 0.7368421052631579 and parameters: {'stopwords_type': 'wostop', 'from_selection': False, 'pos_feature_type': 'onehot', 'tag_feature_type': 'onehot', 'word_count_feature_type': 'count_words'}. Best is trial 3 with value: 0.9333333333333333.[0m
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:40:46,607][0m Trial 12 finished with value: 0.7733333333333333 and parameters: {'stopwords_type': 'wostop', 'from_selection': False, 'pos_feature_type': 'sum', 'tag_feature_type': 'sum', 'word_count_feature_type': 'count_words'}. Best is trial 3 with value: 0.9333333333333333.[0m
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:43:30,84

6


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:45:36,921][0m Trial 14 finished with value: 0.5866666666666667 and parameters: {'stopwords_type': 'wstop', 'from_selection': True, 'pos_feature_type': 'none', 'tag_feature_type': 'onehot', 'word_count_feature_type': 'none'}. Best is trial 3 with value: 0.9333333333333333.[0m


49


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:48:30,688][0m Trial 15 finished with value: 0.7105263157894737 and parameters: {'stopwords_type': 'wostop', 'from_selection': False, 'pos_feature_type': 'sum', 'tag_feature_type': 'onehot', 'word_count_feature_type': 'none'}. Best is trial 3 with value: 0.9333333333333333.[0m


19


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:50:45,994][0m Trial 16 finished with value: 0.922077922077922 and parameters: {'stopwords_type': 'wostop', 'from_selection': False, 'pos_feature_type': 'onehot', 'tag_feature_type': 'none', 'word_count_feature_type': 'count_words'}. Best is trial 3 with value: 0.9333333333333333.[0m


4


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:55:47,091][0m Trial 17 finished with value: 0.7066666666666667 and parameters: {'stopwords_type': 'wstop', 'from_selection': True, 'pos_feature_type': 'onehot', 'tag_feature_type': 'sum', 'word_count_feature_type': 'count_words'}. Best is trial 3 with value: 0.9333333333333333.[0m
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 21:58:23,531][0m Trial 18 finished with value: 0.8 and parameters: {'stopwords_type': 'wostop', 'from_selection': True, 'pos_feature_type': 'sum', 'tag_feature_type': 'sum', 'word_count_feature_type': 'none'}. Best is trial 3 with value: 0.9333333333333333.[0m


19


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 22:00:29,772][0m Trial 19 finished with value: 0.7105263157894737 and parameters: {'stopwords_type': 'wstop', 'from_selection': False, 'pos_feature_type': 'onehot', 'tag_feature_type': 'none', 'word_count_feature_type': 'count_words'}. Best is trial 3 with value: 0.9333333333333333.[0m


4


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 22:03:40,315][0m Trial 20 finished with value: 0.8133333333333334 and parameters: {'stopwords_type': 'wstop', 'from_selection': True, 'pos_feature_type': 'onehot', 'tag_feature_type': 'none', 'word_count_feature_type': 'none'}. Best is trial 3 with value: 0.9333333333333333.[0m


4
6


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 22:08:34,414][0m Trial 21 finished with value: 0.8 and parameters: {'stopwords_type': 'wstop', 'from_selection': True, 'pos_feature_type': 'onehot', 'tag_feature_type': 'onehot', 'word_count_feature_type': 'count_words'}. Best is trial 3 with value: 0.9333333333333333.[0m


19


[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 22:12:31,949][0m Trial 22 finished with value: 0.8421052631578947 and parameters: {'stopwords_type': 'wstop', 'from_selection': False, 'pos_feature_type': 'onehot', 'tag_feature_type': 'none', 'word_count_feature_type': 'none'}. Best is trial 3 with value: 0.9333333333333333.[0m
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 22:15:29,732][0m Trial 23 finished with value: 0.6973684210526315 and parameters: {'stopwords_type': 'wostop', 'from_selection': True, 'pos_feature_type': 'none', 'tag_feature_type': 'sum', 'word_count_feature_type': 'count_words'}. Best is trial 3 with value: 0.9333333333333333.[0m
[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[32m[I 2022-03-29 22:19:42,294][0m Tri

In [None]:
loaded_study = torch.load(study_path)

print(loaded_study.best_trial.params)
print(loaded_study.best_trial)