In [139]:
import os, warnings

import torch
from transformers import *
from transformers.utils import logging as hf_logging
from fastai.text.all import *

from blurr.text.data.all import *
from blurr.text.modeling.all import *

seed=10

pd.options.display.max_rows = 20
pd.options.display.max_columns = 8
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'

In [140]:
# Turn off the HF tokenizers' internal parallelism
# (presumably to avoid its fork warnings once DataLoaders start).
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Quiet library chatter so cell outputs stay readable.
hf_logging.set_verbosity_error()  # transformers: only report errors
warnings.simplefilter("ignore")   # NOTE(review): hides *all* Python warnings, including useful ones

In [141]:
def get_oversampled_dls(dls, seed=1, noise_limit=0.15, label='speaker'):
    """Oversample minority classes in `dls.train.items` towards the majority-class count.

    Every original training row is kept. For each class in column `label`, about
    `(max_count - class_count) * m` extra rows are drawn from that class with replacement.
    `m` is drawn uniformly from `[1 - noise_limit, 1 + noise_limit]` once per class, so the
    classes end up roughly, not exactly, balanced. Validation items are not touched.

    Args:
        dls: DataLoaders whose `train.items` is a pandas DataFrame containing `label`.
        seed: Seed for both the per-class noise and the row sampling; same seed -> same result.
        noise_limit: Maximum relative jitter applied to each class's number of extra rows.
        label: Name of the class-label column to balance on.

    Returns:
        The same `dls` object (modified in place) with `train.items` set to the oversampled
        DataFrame. Original index labels are kept, so duplicate index values are expected.

    NOTE(review): with fastai, assigning `dls.train.items` may only set an attribute on the
    DataLoader and not on its underlying dataset. In that case training still iterates the
    original rows — TODO confirm (e.g. check whether `dls.train.n` changes).
    """
    import numpy as np
    import pandas as pd

    df_train = dls.train.items
    # One seeded generator drives both the noise factors and `DataFrame.sample`,
    # so the `seed` argument actually makes the result reproducible.
    rng = np.random.RandomState(seed)
    max_count = df_train[label].value_counts().max()

    dfs_oversampled = [df_train]
    # Group on `label` (not a hard-coded column) so the parameter is honoured.
    for _, group in df_train.groupby(label):
        rand_mult = rng.uniform(1.0 - noise_limit, 1.0 + noise_limit)
        n_extra = int((max_count - len(group)) * rand_mult)
        if n_extra > 0:  # the majority class (and any equally large one) needs nothing extra
            dfs_oversampled.append(group.sample(n_extra, replace=True, random_state=rng))

    dls.train.items = pd.concat(dfs_oversampled)
    return dls

In [142]:
# Train/test CSVs (columns include `line_text` and `speaker`, as the output below shows).
data_path = Path('../data')
train_csv = data_path / 'train21_shuffled.csv'
test_csv = data_path / 'test21_shuffled.csv'

df = pd.read_csv(train_csv)
df_test = pd.read_csv(test_csv)

Unnamed: 0,season,episode,scene,line_text,speaker,deleted
0,3,16,29,"Don't hurt that bat, Creed! It's a living thing with feelings and a family!",Kelly,False
1,7,7,27,"I cancelled my plans to come to this thing, and they repay me with this?",Kevin,False
2,8,13,25,"Oh, yes. Oh, what a beautiful child. Prominent forehead, short arms, tiny nose. You will lead millions... [whispers] willingly, or as slaves.",Dwight,False
3,2,17,18,Brad Pitt. Also there will be no bonuses.,Dwight,False
4,5,4,32,"Okay, alright. Hey, you know what? I would appreciate it if people would stop storming off the stage.",Michael,False
...,...,...,...,...,...,...
7827,9,7,32,Yeah.,Pam,False
7828,4,3,25,Alright. Well fight it out amongst yourselves. I was thinking Pammy but boys night out is also good.,Michael,False
7829,7,8,19,Rachel.,Kelly,False
7830,8,21,19,[chuckling] Okay.,Andy,False


In [143]:
# One output class per distinct speaker in the training data (NaN counted once, like unique()).
n_labels = df['speaker'].nunique(dropna=False)

21

In [144]:
# Standalone HF objects for the hand-built DataBlock in the next cell.
# NOTE(review): the Blearner cell further down is given only `pretrained_model_name`, so it
# presumably loads BERT again; drop these cells if the manual pipeline is not needed.
pretrained_model_name = 'bert-base-uncased'
model_cls = AutoModelForSequenceClassification

config = AutoConfig.from_pretrained(pretrained_model_name)
config.num_labels = n_labels  # classification head sized to the number of speakers

hf_arch, hf_config, hf_tokenizer, hf_model = get_hf_objects(
    pretrained_model_name, model_cls=model_cls, config=config
)

In [145]:
# Hand-built pipeline: BERT-tokenized `line_text` -> `speaker` category, 10% random validation split.
blocks = (
    TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model),
    CategoryBlock,
)
dblock = DataBlock(
    blocks=blocks,
    get_x=ColReader('line_text'),
    get_y=ColReader('speaker'),
    splitter=RandomSplitter(valid_pct=0.1, seed=seed),
)
dls = dblock.dataloaders(df, bs=8)
# with_labels expects a bool: the old string 'True' only worked because any non-empty
# string (including 'False') is truthy.
# NOTE(review): `label_col` is not a documented test_dl argument (labels come from get_y) — confirm it is needed.
test_dl = dls.test_dl(df_test, with_labels=True, label_col='speaker')

<fastai.text.data.SortedDL at 0x7f61aeb6c730>

In [146]:
# dls = get_oversampled_dls(dls, seed=seed)

In [147]:
# dls.show_batch(dataloaders=dls, max_n=4)

In [149]:
# Build the BERT learner straight from the raw DataFrame; the `hf_*` objects and `dls`
# from the earlier cells are not passed in here.
learn = BlearnerForSequenceClassification.from_data(
    df,
    pretrained_model_name,
    dl_kwargs={"bs": 4},
    learner_kwargs={"metrics": accuracy},
    text_attr='line_text',
    label_attr='speaker',
    n_labels=n_labels,
    dblock_splitter=RandomSplitter(valid_pct=0.1, seed=seed),
)

# Oversample minority speakers in the training split only; validation is left as is.
# NOTE(review): confirm the DataLoader really iterates the oversampled rows — assigning
# `dls.train.items` may not reach the underlying dataset.
learn.dls = get_oversampled_dls(learn.dls, seed=seed)
# with_labels expects a bool (the string 'True' only worked by being truthy); the labels
# let learn.validate score the test set.
test_dl = learn.dls.test_dl(df_test, with_labels=True, label_col='speaker')


<fastai.text.data.SortedDL at 0x7f61aeb50bb0>

In [150]:
# One-cycle fine-tuning: 3 epochs, peak learning rate 1e-3.
# NOTE(review): if the whole BERT body is trainable, 1e-3 is high (commonly ~2e-5–5e-5); check with learn.lr_find().
learn.fit_one_cycle(n_epoch=3, lr_max=1e-3)

epoch,train_loss,valid_loss,accuracy,time
0,2.471381,2.456425,0.25693,03:58
1,2.284889,2.400234,0.263692,03:58
2,2.284127,2.358051,0.279243,03:58


In [151]:
# [loss, accuracy] on the labelled test set (displayed via last_expr_or_assign).
test_metrics = learn.validate(dl=test_dl)

(#2) [2.327181577682495,0.27885597944259644]

TODO:

- Oversample the DataFrame directly, or extract the training DataLoader from the learner and oversample that.
- Ensure `learn.validate` is working as intended.
- Try using my own DataLoaders, etc.?

In [152]:
# Inspect the Blearner API (constructor signature, method resolution order, methods).
# NOTE(review): exploration aid — consider removing it (and its long output) from the finished notebook.
help(learn)

Help on BlearnerForSequenceClassification in module blurr.text.modeling.core object:

class BlearnerForSequenceClassification(Blearner)
 |  BlearnerForSequenceClassification(dls: fastai.data.core.DataLoaders, hf_model: transformers.modeling_utils.PreTrainedModel, base_model_cb: blurr.text.modeling.core.BaseModelCallback = <class 'blurr.text.modeling.core.BaseModelCallback'>, *, loss_func: 'callable | None' = None, opt_func: 'Optimizer | OptimWrapper' = <function Adam at 0x7f615e06c820>, lr: 'float | slice' = 0.001, splitter: 'callable' = <function trainable_params at 0x7f615e18cb80>, cbs: 'Callback | MutableSequence | None' = None, metrics: 'callable | MutableSequence | None' = None, path: 'str | Path | None' = None, model_dir: 'str | Path' = 'models', wd: 'float | int | None' = None, wd_bn_bias: 'bool' = False, train_bn: 'bool' = True, moms: 'tuple' = (0.95, 0.85, 0.95), default_cbs: 'bool' = True) -> fastai.learner.Learner
 |  
 |  # Cell
 |  
 |  Method resolution order:
 |      Ble

In [153]:
# Peek at the training items after oversampling (displayed via last_expr_or_assign).
# NOTE(review): this is the value get_oversampled_dls assigned; confirm the DataLoader
# itself iterates these rows (e.g. compare len(train_items) with learn.dls.train.n).
train_items = learn.dls.train.items

Unnamed: 0,season,episode,scene,line_text,speaker,deleted
11149,8,13,28,[to Cece] You want a giraffe?,Jim,False
7664,7,6,41,"Ah! [lets go, candy corn flies everywhere] That's enough.",Erin,False
24808,8,13,5,Ah! Angela had the baby!,Erin,False
5476,3,11,16,Hmm.,Pam,False
39675,6,1,23,[to Pam] Did you know a baby conceived out of wedlock is still a bastard?,Angela,False
...,...,...,...,...,...,...
33076,5,8,23,I have a reasonable right to privacy.,Toby,False
40040,3,2,15,What?,Toby,False
12010,2,13,30,He seems fine to me.,Toby,False
21105,6,2,18,"Well, you know, 'cause of the trains.",Toby,False


In [154]:
# Save the trained learner for later inference.
# NOTE(review): no extension is given, so the file is written without `.pkl`
# (presumably under `learn.path`).
learn.export(fname='BertTransformOversampled')