In [1]:
import os, warnings
warnings.simplefilter("ignore")

import torch
from transformers import *
from transformers.utils import logging as hf_logging
from fastai.text.all import *

from blurr.text.data.all import *
from blurr.text.modeling.all import *

In [2]:
seed=1

pd.options.display.max_rows = 20
pd.options.display.max_columns = 8
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'


hf_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
def get_oversampled_dls(dls, seed=1, baseline_factor=1.0, noise_limit=0.15, label='speaker'):
    """Oversample each class of the training items of ``dls`` with replacement.

    For every distinct value of ``label`` a random multiplier is drawn from
    ``[baseline_factor - noise_limit, baseline_factor + noise_limit]``. The
    class then receives ``max(int((largest_class - class_size) * multiplier),
    class_size)`` extra rows, sampled with replacement. Because of the lower
    bound, every class (including the largest one) is at least doubled.

    Args:
        dls: Object whose ``dls.train.items`` is a pandas DataFrame containing
            the column ``label``. It is modified in place.
        seed: Seeds both the multiplier draws and the row sampling, so the
            same ``seed`` always produces the same oversampled frame.
        baseline_factor: Centre of the multiplier range.
        noise_limit: Half-width of the multiplier range.
        label: Name of the class column to balance on.

    Returns:
        The same ``dls`` object, with ``dls.train.items`` replaced by the
        original rows followed by the sampled rows (original index labels are
        kept, so the index contains duplicates).

    Raises:
        KeyError: If ``label`` is not a column of the training frame.
    """
    # NOTE(review): confirm that assigning to ``dls.train.items`` is picked up
    # by the underlying training dataset for the DataLoaders type in use.
    df_train = dls.train.items
    largest_class = df_train[label].value_counts().max()

    # Local generator: the multipliers no longer depend on (or disturb) the
    # global ``random`` state, which `seed` never controlled.
    rng = random.Random(seed)

    dfs_oversampled = [df_train]
    # Group on the requested label column rather than a hard-coded 'speaker'.
    for _, group in df_train.groupby(label):
        rand_mult = rng.uniform(baseline_factor - noise_limit, baseline_factor + noise_limit)
        sample_amt_to_max = largest_class - len(group)

        # Never add fewer rows than the class already has.
        sample_amt = max(int(sample_amt_to_max * rand_mult), len(group))

        dfs_oversampled.append(group.sample(sample_amt, replace=True, random_state=seed))

    dls.train.items = pd.concat(dfs_oversampled)
    return dls

In [4]:
# Directory holding the CSV splits, relative to the notebook.
data_path = Path('../data')

# Read the training frame and the held-out test frame in one pass.
df, df_test = (
    pd.read_csv(data_path/csv_name)
    for csv_name in ('train21_shuffled.csv', 'test21_shuffled.csv')
)

# Preview the training frame.
df

Unnamed: 0,season,episode,scene,line_text,speaker,deleted
0,7,24,8,[conducting interview] Your paper experience is very interesting. Do you think you could use that experience to inform decisions here?,Jim,False
1,9,10,27,I'm not gonna lie. Lye!,Dwight,False
2,9,9,27,Take a bowl and pass it down.,Dwight,False
3,3,15,25,It's a miracle.,Dwight,False
4,7,1,1,This is how you build a business. This is how you make it in this country.,Ryan,False
...,...,...,...,...,...,...
44370,3,5,41,I don't know. It felt far.,Pam,False
44371,3,10,43,"Excuse me [tries to take away meat with chopsticks] Ah, un guard. [Fights with chopsticks and laughs] Family style.",Michael,False
44372,5,23,19,I never went to Thailand.,Ryan,False
44373,3,4,1,"[breathless] All right. Okay. [goes back down pretend stairs, crawls on belly to the kitchen for the coffee]",Michael,False


In [5]:
# Number of distinct speakers (a missing value, if present, counts as one class).
n_labels = df['speaker'].nunique(dropna=False)

21

In [6]:
# Name of the pretrained checkpoint handed to BlearnerForSequenceClassification.from_data
# for every run of the grid search.
pretrained_model_name = 'bert-base-uncased'

'bert-base-uncased'

In [7]:
%%time

noise_limits = [0.1]
batch_sizes = [8, 16]
lrs = [0.003, 0.002, 0.001]
seeds = [1, 23]

epochs=10

for noise in noise_limits:
    for bs in batch_sizes:
        for lr in lrs:
            for seed in seeds:
                learn = BlearnerForSequenceClassification.from_data(
                    df, 
                    pretrained_model_name, 
                    dl_kwargs={"bs": bs, "seed": seed},
                    learner_kwargs={"metrics": accuracy},
                    text_attr='line_text',
                    label_attr='speaker',
                    n_labels = n_labels,
                    dblock_splitter=RandomSplitter(valid_pct=0.1, seed=seed)
                )

                print(f'lr: {lr}, bs: {bs}, noise: {noise}, seed: {seed}')
                learn.dls = get_oversampled_dls(learn.dls, seed=seed, noise_limit=noise)
                test_dl = learn.dls.test_dl(df_test, with_labels='True', label_col='speaker')

                learn.fit_one_cycle(epochs, lr_max=lr)
                
                res=learn.validate(dl=test_dl)
                print(f'Validation results: [cost, accuracy]:{res}')

                learn.export(f'BERT_accuracy{"%.5f"%res[1]}_lr{lr}_bs{bs}_seed{seed}')

lr: 0.003, bs: 8, noise: 0.1, seed: 1


epoch,train_loss,valid_loss,accuracy,time
0,2.539172,2.459905,0.243858,01:58
1,2.465613,2.4247,0.266396,02:00
2,2.505112,2.421263,0.266622,01:58
3,2.466475,2.41978,0.263016,01:56
4,2.451088,2.404349,0.269551,01:57
5,2.375927,2.408864,0.258057,02:01
6,2.369609,2.349365,0.28961,01:57
7,2.325718,2.341902,0.293216,01:58
8,2.358683,2.328231,0.288258,02:02
9,2.212749,2.328589,0.289385,01:59


Validation results: [cost, accuracy]:[2.31400728225708, 0.29468846321105957]
lr: 0.003, bs: 8, noise: 0.1, seed: 23


epoch,train_loss,valid_loss,accuracy,time
0,2.41942,2.463366,0.233266,02:00
1,2.452494,2.468272,0.216362,02:01
2,2.476475,2.42781,0.249493,02:01
3,2.523357,2.454985,0.244084,02:03
4,2.439986,2.403972,0.260762,01:59
5,2.365987,2.383095,0.264142,02:00
6,2.350943,2.372041,0.26865,02:00
7,2.346364,2.351495,0.276764,01:59
8,2.356135,2.348322,0.27789,01:57
9,2.289366,2.34882,0.276313,01:58


Validation results: [cost, accuracy]:[2.3313820362091064, 0.2830694615840912]
lr: 0.002, bs: 8, noise: 0.1, seed: 1


epoch,train_loss,valid_loss,accuracy,time
0,2.492631,2.462291,0.249493,01:56
1,2.481206,2.436526,0.25648,01:57
2,2.408717,2.414394,0.259184,01:57
3,2.478541,2.398536,0.267748,01:56
4,2.364449,2.365284,0.2833,01:58
5,2.333753,2.34329,0.291413,01:57
6,2.376557,2.34819,0.286229,01:57
7,2.293641,2.333533,0.290737,02:01
8,2.375442,2.322415,0.296371,01:59
9,2.24147,2.32532,0.296146,02:02


Validation results: [cost, accuracy]:[2.3092079162597656, 0.2893258333206177]
lr: 0.002, bs: 8, noise: 0.1, seed: 23


epoch,train_loss,valid_loss,accuracy,time
0,2.541925,2.436087,0.236646,01:56
1,2.457917,2.429093,0.244084,01:57
2,2.452958,2.397052,0.260987,01:58
3,2.39999,2.388446,0.268875,01:58
4,2.401576,2.350796,0.268875,01:57
5,2.372958,2.359916,0.26865,02:00
6,2.351152,2.348769,0.275411,01:57
7,2.337976,2.324759,0.286455,01:58
8,2.269686,2.31956,0.284652,01:59
9,2.252819,2.317336,0.286455,02:00


Validation results: [cost, accuracy]:[2.302870512008667, 0.2959652841091156]
lr: 0.001, bs: 8, noise: 0.1, seed: 1


epoch,train_loss,valid_loss,accuracy,time
0,2.444618,2.479882,0.225603,01:56
1,2.46876,2.419769,0.263241,01:58
2,2.371024,2.380533,0.277665,01:56
3,2.353849,2.364193,0.2833,01:58
4,2.441352,2.363805,0.279468,01:58
5,2.317453,2.345564,0.282623,01:57
6,2.269853,2.316377,0.290737,01:56
7,2.27337,2.323289,0.295245,01:58
8,2.288617,2.310364,0.298174,01:58
9,2.242059,2.312616,0.296597,01:57


Validation results: [cost, accuracy]:[2.2984445095062256, 0.2950715124607086]
lr: 0.001, bs: 8, noise: 0.1, seed: 23


epoch,train_loss,valid_loss,accuracy,time
0,2.485796,2.474926,0.219067,01:56
1,2.448951,2.427449,0.251747,02:03
2,2.378997,2.392994,0.260311,01:59
3,2.392958,2.379833,0.264593,01:57
4,2.397234,2.329035,0.273608,01:58
5,2.286541,2.333094,0.282849,01:59
6,2.216,2.305332,0.282623,01:59
7,2.305545,2.298683,0.284652,01:56
8,2.24156,2.301047,0.290737,01:59
9,2.271419,2.300756,0.288934,01:58


Validation results: [cost, accuracy]:[2.3018124103546143, 0.2939223647117615]
lr: 0.003, bs: 16, noise: 0.1, seed: 1


epoch,train_loss,valid_loss,accuracy,time
0,2.476907,2.455433,0.250394,01:08
1,2.444071,2.446665,0.253324,01:07
2,2.409045,2.418907,0.261889,01:07
3,2.427057,2.387456,0.275637,01:07
4,2.409184,2.371911,0.281271,01:07
5,2.301491,2.390669,0.273834,01:08
6,2.347175,2.345754,0.282849,01:07
7,2.320655,2.327004,0.292991,01:07
8,2.240168,2.319178,0.295245,01:08
9,2.278965,2.317724,0.29547,01:08


Validation results: [cost, accuracy]:[2.3043105602264404, 0.2927732467651367]
lr: 0.003, bs: 16, noise: 0.1, seed: 23


epoch,train_loss,valid_loss,accuracy,time
0,2.505589,2.449225,0.244535,01:06
1,2.462316,2.475521,0.243858,01:06
2,2.424091,2.393294,0.263016,01:07
3,2.355465,2.408838,0.248817,01:08
4,2.443497,2.383975,0.272707,01:06
5,2.33005,2.367282,0.270453,01:07
6,2.288141,2.336361,0.275186,01:07
7,2.303878,2.340541,0.279017,01:07
8,2.254934,2.307923,0.288483,01:07
9,2.156837,2.306529,0.290061,01:06


Validation results: [cost, accuracy]:[2.304107189178467, 0.2913687527179718]
lr: 0.002, bs: 16, noise: 0.1, seed: 1


epoch,train_loss,valid_loss,accuracy,time
0,2.497804,2.468195,0.2389,01:06
1,2.47645,2.42096,0.271805,01:07
2,2.413869,2.423901,0.249944,01:09
3,2.326588,2.369741,0.28082,01:06
4,2.337647,2.367922,0.284652,01:08
5,2.278206,2.342206,0.282173,01:09
6,2.357115,2.344881,0.284877,01:07
7,2.324603,2.315871,0.293667,01:08
8,2.235004,2.318541,0.294343,01:08
9,2.260113,2.317029,0.295921,01:08


Validation results: [cost, accuracy]:[2.3065810203552246, 0.29341164231300354]
lr: 0.002, bs: 16, noise: 0.1, seed: 23


epoch,train_loss,valid_loss,accuracy,time
0,2.520587,2.461661,0.232364,01:06
1,2.446217,2.459481,0.240478,01:07
2,2.366411,2.394435,0.25986,01:07
3,2.390589,2.409674,0.256254,01:07
4,2.336251,2.366139,0.272481,01:07
5,2.334701,2.331068,0.27789,01:07
6,2.296978,2.311776,0.284201,01:08
7,2.243833,2.313174,0.284877,01:06
8,2.243411,2.304312,0.292765,01:08
9,2.200046,2.304912,0.294118,01:08


Validation results: [cost, accuracy]:[2.301335096359253, 0.2951991856098175]
lr: 0.001, bs: 16, noise: 0.1, seed: 1


epoch,train_loss,valid_loss,accuracy,time
0,2.494608,2.484764,0.248817,01:06
1,2.437532,2.446423,0.25355,01:07
2,2.422692,2.392617,0.267748,01:06
3,2.392332,2.365801,0.275862,01:08
4,2.330584,2.341162,0.292765,01:07
5,2.271879,2.330481,0.288709,01:07
6,2.252472,2.330214,0.290061,01:07
7,2.288999,2.313688,0.296371,01:08
8,2.215034,2.312549,0.299301,01:08
9,2.239938,2.315486,0.297949,01:08


Validation results: [cost, accuracy]:[2.304358720779419, 0.2914964258670807]
lr: 0.001, bs: 16, noise: 0.1, seed: 23


epoch,train_loss,valid_loss,accuracy,time
0,2.449923,2.482248,0.221997,01:06
1,2.44508,2.429351,0.25941,01:06
2,2.386391,2.394469,0.266171,01:07
3,2.381663,2.378314,0.267523,01:07
4,2.342631,2.343393,0.275186,01:08
5,2.293617,2.33842,0.278792,01:08
6,2.304154,2.344872,0.274961,01:08
7,2.253644,2.320167,0.282398,01:08
8,2.232118,2.319192,0.288483,01:08
9,2.205819,2.320465,0.288032,01:07


Validation results: [cost, accuracy]:[2.3020877838134766, 0.29009193181991577]
CPU times: user 3h 3min 13s, sys: 5min 45s, total: 3h 8min 59s
Wall time: 3h 12min 59s
