In [8]:
import gc, itertools
import os, warnings
warnings.simplefilter("ignore")

import torch
from transformers import *
from transformers.utils import logging as hf_logging
from fastai.text.all import *

from blurr.text.data.all import *
from blurr.text.modeling.all import *

In [9]:
seed=1

pd.options.display.max_rows = 20
pd.options.display.max_columns = 8
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'


hf_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [10]:
def get_df_by_hyperparam(fp, oversample=True, seed=1, baseline_factor=1.0, noise_limit=0.15, label='speaker'):
    """Load the training CSV and optionally oversample every class.

    Args:
        fp: Path (or anything `pd.read_csv` accepts) of the training CSV.
        oversample: When False the CSV is returned exactly as read.
        seed: Seeds both the per-class multiplier draw and the row sampling,
            so identical arguments always give an identical frame.
        baseline_factor: Centre of the multiplier applied to each class's
            shortfall relative to the largest class.
        noise_limit: Half-width of the uniform range around `baseline_factor`
            from which each class's multiplier is drawn.
        label: Name of the column holding the class labels.

    Returns:
        The original rows followed by the extra rows sampled (with
        replacement) for each class. The original index values are kept, so
        the result can contain duplicate index labels.
    """
    df_train = pd.read_csv(fp)

    if not oversample: return df_train

    # Size of the largest class; every other class is topped up towards it.
    max_dialog = df_train[label].value_counts().max()

    # Private generator: reproducible for a given seed and independent of the
    # global `random` state.
    rng = random.Random(seed)

    dfs_oversampled = [df_train]
    # Group by the configurable label column rather than a fixed column name.
    for _, group in df_train.groupby(label):
        rand_mult = rng.uniform(baseline_factor - noise_limit, baseline_factor + noise_limit)
        sample_amt_to_max = max_dialog - len(group)

        # Never add fewer rows than the class already has, so every class
        # (including the largest one) is at least doubled.
        sample_amt = max(int(sample_amt_to_max * rand_mult), len(group))

        dfs_oversampled.append(group.sample(sample_amt, replace=True, random_state=seed))

    return pd.concat(dfs_oversampled)

In [11]:
# Folder holding the CSV splits, relative to the notebook.
data_path = Path('..') / 'data'

# Training split: only the path is stored here; the file is read later,
# once per hyperparameter combination.
train_path = data_path.joinpath('train21_shuffled.csv')

# Hold-out split: read once and reused for every run.
df_test = pd.read_csv(data_path.joinpath('test21_shuffled.csv'))

Unnamed: 0,season,episode,scene,line_text,speaker,deleted
0,3,16,29,"Don't hurt that bat, Creed! It's a living thing with feelings and a family!",Kelly,False
1,7,7,27,"I cancelled my plans to come to this thing, and they repay me with this?",Kevin,False
2,8,13,25,"Oh, yes. Oh, what a beautiful child. Prominent forehead, short arms, tiny nose. You will lead millions... [whispers] willingly, or as slaves.",Dwight,False
3,2,17,18,Brad Pitt. Also there will be no bonuses.,Dwight,False
4,5,4,32,"Okay, alright. Hey, you know what? I would appreciate it if people would stop storming off the stage.",Michael,False
...,...,...,...,...,...,...
7827,9,7,32,Yeah.,Pam,False
7828,4,3,25,Alright. Well fight it out amongst yourselves. I was thinking Pammy but boys night out is also good.,Michael,False
7829,7,8,19,Rachel.,Kelly,False
7830,8,21,19,[chuckling] Okay.,Andy,False


In [12]:
# Number of distinct values in the hold-out set's `speaker` column
# (a missing value, if present, counts as one).
n_labels = df_test['speaker'].nunique(dropna=False)

21

In [13]:
# Name of the pretrained checkpoint handed to
# `BlearnerForSequenceClassification.from_data` for every run.
pretrained_model_name = 'bert-base-uncased'

'bert-base-uncased'

In [15]:
%%time

noise_limits = [0.1]
batch_sizes = [8, 16]
lrs = [0.003, 0.002, 0.001]
seeds = [1, 23]
sample_types = ['normal', 'oversample']


for noise in noise_limits:
    for bs in batch_sizes:
        for lr in lrs:
            for seed in seeds:
                for sample_type in sample_types:
                    
                    oversampling = True if sample_type == 'oversample' else False
                    epochs= 16 if oversampling else 10
                    
                    learn = BlearnerForSequenceClassification.from_data(
                        get_df_by_hyperparam(train_path, oversample=oversampling, seed=seed, noise_limit=noise), 
                        pretrained_model_name, 
                        dl_kwargs={"bs": bs, "seed": seed},
                        learner_kwargs={"metrics": accuracy},
                        text_attr='line_text',
                        label_attr='speaker',
                        n_labels = n_labels,
                        dblock_splitter=RandomSplitter(valid_pct=0.1, seed=seed)
                    )

                    print(f'lr: {lr}, bs: {bs}, noise: {noise}, seed: {seed}, oversample: {oversampling}')
                    test_dl = learn.dls.test_dl(df_test, with_labels='True', label_col='speaker')

                    learn.fit_one_cycle(epochs, lr_max=lr)

                    res=learn.validate(dl=test_dl)
                    print(f'Validation results: [cost, accuracy]:{res}')

                    learn.export(f'BERT_accuracy{"%.5f"%res[1]}_oversample:{oversampling}_lr{lr}_bs{bs}_seed{seed}')

lr: 0.003, bs: 8, noise: 0.1, seed: 1, oversample: False


epoch,train_loss,valid_loss,accuracy,time
0,2.449038,2.515831,0.226054,02:02
1,2.497755,2.449003,0.253324,01:58
2,2.41924,2.418982,0.268199,02:00
3,2.464818,2.417589,0.267298,02:02
4,2.378691,2.44113,0.269777,01:59
5,2.429611,2.41201,0.267072,02:01
6,2.345301,2.353127,0.287356,01:57
7,2.354997,2.336153,0.287131,01:58
8,2.248125,2.331542,0.287131,01:57
9,2.290641,2.330811,0.291864,02:00


Validation results: [cost, accuracy]:[2.3200953006744385, 0.29341164231300354]
lr: 0.003, bs: 8, noise: 0.1, seed: 1, oversample: True


epoch,train_loss,valid_loss,accuracy,time
0,2.806329,2.720455,0.180566,09:58
1,2.73522,2.63901,0.210273,10:08
2,2.724908,2.687462,0.195075,10:06
3,2.802568,2.782328,0.148577,10:13
4,2.831016,2.691442,0.195419,10:13
5,2.760704,2.715102,0.186378,10:14
6,2.753052,2.656336,0.200758,10:08
7,2.739508,2.599397,0.219744,10:11
8,2.665731,2.536348,0.228312,10:09
9,2.596526,2.467674,0.262496,10:10


Validation results: [cost, accuracy]:[2.7640573978424072, 0.1435137838125229]
lr: 0.003, bs: 8, noise: 0.1, seed: 23, oversample: False


epoch,train_loss,valid_loss,accuracy,time
0,2.465505,2.494908,0.241154,01:56
1,2.500488,2.428458,0.248141,01:59
2,2.448771,2.522764,0.212756,01:59
3,2.459241,2.423947,0.243633,01:59
4,2.410018,2.436042,0.244084,01:59
5,2.361312,2.388,0.269101,01:58
6,2.343733,2.378347,0.259635,01:59
7,2.396978,2.353674,0.272707,01:58
8,2.385136,2.334311,0.28037,01:59
9,2.30203,2.335915,0.285553,01:58


Validation results: [cost, accuracy]:[2.3148446083068848, 0.2890704870223999]
lr: 0.003, bs: 8, noise: 0.1, seed: 23, oversample: True


epoch,train_loss,valid_loss,accuracy,time
0,2.784455,2.712144,0.18455,09:56
1,2.679914,2.644136,0.20763,10:03
2,2.750516,2.633702,0.22036,10:11
3,2.766599,2.701273,0.187148,10:07
4,2.84666,2.743538,0.181519,10:06
5,2.781931,2.657959,0.210358,10:13
6,2.737207,2.632018,0.216333,10:08
7,2.712538,2.595847,0.223305,10:13
8,2.59966,2.525462,0.241751,10:12
9,2.605052,2.467556,0.265264,10:13


Validation results: [cost, accuracy]:[2.753291606903076, 0.14836567640304565]
lr: 0.002, bs: 8, noise: 0.1, seed: 1, oversample: False


epoch,train_loss,valid_loss,accuracy,time
0,2.458441,2.447896,0.259635,01:56
1,2.381426,2.426101,0.266171,01:57
2,2.419939,2.415318,0.263016,02:00
3,2.367568,2.420491,0.256254,01:59
4,2.365332,2.345868,0.2833,01:59
5,2.319314,2.353816,0.281271,02:03
6,2.302929,2.32723,0.293892,01:58
7,2.289307,2.31907,0.299527,01:58
8,2.227603,2.312952,0.298625,01:57
9,2.156249,2.315805,0.302231,02:00


Validation results: [cost, accuracy]:[2.303863763809204, 0.28855976462364197]
lr: 0.002, bs: 8, noise: 0.1, seed: 1, oversample: True


epoch,train_loss,valid_loss,accuracy,time
0,2.739271,2.734983,0.173874,10:10
1,2.659794,2.571462,0.232625,10:22
2,2.615218,2.618708,0.221155,10:16
3,2.682757,2.550074,0.2387,10:39


KeyboardInterrupt: 