In [64]:
from fastai.text.all import *
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from datasets import Dataset
from inspect import signature
from pathlib import Path
import gc

pd.options.display.max_rows = 20
pd.options.display.max_columns = 8
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'

In [65]:
DATA_ROOT = Path('../data')

df_train = pd.read_csv(DATA_ROOT/'train21_shuffled.csv')
df_test = pd.read_csv(DATA_ROOT/'test21_shuffled.csv')

df_train
del df_train['season']
del df_train['episode']
del df_train['scene']
del df_train['deleted']

del df_test['season']
del df_test['episode']
del df_test['scene']
del df_test['deleted']

# generate expected format of transformer class
ds_train = Dataset.from_pandas(df_train)
ds_test = Dataset.from_pandas(df_test)

ds_train[0]

{'line_text': '[conducting interview] Your paper experience is very interesting. Do you think you could use that experience to inform decisions here?',
 'speaker': 'Jim'}

In [66]:
split = 0.85
nt = int(len(ds_train)*0.85)
nv = len(ds_train)-nt

train_idx, valid_idx = L(range(nt)), L(range(nt, nt+nv))

In [67]:
ds_name = 'DunderAI'
model_name = "distilroberta-base"

max_len = 512
bs = 32
val_bs = bs*2

n_epoch = 4
lr = 2e-5
opt_func = Adam

<function fastai.optimizer.Adam(params: 'Tensor | Iterable', lr: 'float | slice', mom: 'float' = 0.9, sqr_mom: 'float' = 0.99, eps: 'float' = 1e-05, wd: 'Real' = 0.01, decouple_wd: 'bool' = True) -> 'Optimizer'>

In [68]:
dunder_metrics = [accuracy]
dunder_textfields = ["line_text", None]

['line_text', None]

In [69]:
lens = ds_train.map(lambda s: {'len': sum([len(s[i]) for i in dunder_textfields if i])},
                    remove_columns=ds_train.column_names, num_proc=2, keep_in_memory=True)

train_lens = lens.select(train_idx)['len']
valid_lens = lens.select(valid_idx)['len']

Map (num_proc=2):   0%|          | 0/44375 [00:00<?, ? examples/s]

[43,
 22,
 31,
 19,
 21,
 34,
 66,
 26,
 20,
 31,
 34,
 70,
 22,
 63,
 23,
 14,
 26,
 39,
 53,
 106,
 17,
 108,
 24,
 29,
 343,
 49,
 64,
 21,
 19,
 27,
 61,
 59,
 4,
 28,
 29,
 259,
 43,
 76,
 59,
 176,
 136,
 48,
 35,
 65,
 147,
 43,
 64,
 14,
 55,
 657,
 23,
 46,
 247,
 11,
 14,
 42,
 20,
 11,
 12,
 5,
 119,
 16,
 22,
 17,
 45,
 77,
 4,
 8,
 30,
 19,
 8,
 23,
 154,
 75,
 22,
 4,
 31,
 74,
 144,
 4,
 18,
 5,
 4,
 48,
 18,
 31,
 73,
 166,
 7,
 30,
 58,
 98,
 5,
 39,
 54,
 3,
 203,
 31,
 15,
 113,
 61,
 130,
 207,
 10,
 33,
 13,
 19,
 6,
 38,
 58,
 124,
 12,
 25,
 8,
 45,
 128,
 144,
 20,
 13,
 160,
 56,
 65,
 58,
 69,
 123,
 7,
 33,
 5,
 6,
 97,
 72,
 29,
 6,
 19,
 5,
 144,
 5,
 39,
 23,
 25,
 151,
 52,
 205,
 88,
 62,
 7,
 67,
 115,
 91,
 4,
 57,
 16,
 40,
 217,
 19,
 72,
 25,
 3,
 5,
 6,
 35,
 82,
 39,
 10,
 8,
 34,
 20,
 56,
 118,
 51,
 31,
 220,
 46,
 6,
 275,
 23,
 16,
 7,
 19,
 5,
 114,
 265,
 58,
 70,
 202,
 9,
 92,
 39,
 107,
 24,
 46,
 61,
 184,
 27,
 28,
 12,
 8,
 11,
 27,
 

In [79]:
class TextGetter(ItemTransform):
    def __init__(self, s1='line_text', s2=None):
        self.s1, self.s2 = s1, s2
    def encodes(self, sample):
        if self.s2 is None: return sample[self.s1]
        else: return sample[self.s1], sample[self.s2]

In [80]:
class TransTensorText(TensorBase): pass

@typedispatch
def show_batch(x:TransTensorText, y, samples, ctxs=None, max_n=10, trunc_at=150, **kwargs):
    if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))
    if isinstance(samples[0][0], tuple):
        samples = L((*s[0], *s[1:]) for s in samples)
        if trunc_at is not None: samples = L((s[0].truncate(trunc_at), s[1].truncate(trunc_at), *s[2:]) for s in samples)
    if trunc_at is not None: samples = L((s[0].truncate(trunc_at),*s[1:]) for s in samples)
    ctxs = show_batch[object](x, y, samples, max_n=max_n, ctxs=ctxs, **kwargs)
    display_df(pd.DataFrame(ctxs))

In [81]:
def find_first(t, e):
    for i, v in enumerate(t):
        if v == e: return i
        
def split_by_sep(t, sep_tok_id):
    idx = find_first(t, sep_tok_id)
    return t[:idx], t[idx:]

In [82]:
class TokBatchTransform(Transform):
    """
    Tokenizes texts in batches using pretrained HuggingFace tokenizer.
    The first element in a batch can be single string or 2-tuple of strings.
    If `with_labels=True` the "labels" are added to the output dictionary.
    """
    def __init__(self, pretrained_model_name=None, tokenizer_cls=AutoTokenizer, 
                 config=None, tokenizer=None, with_labels=False,
                 padding=True, truncation=True, max_length=None, **kwargs):
        if tokenizer is None:
            tokenizer = tokenizer_cls.from_pretrained(pretrained_model_name, config=config)
        self.tokenizer = tokenizer
        self.kwargs = kwargs
        self._two_texts = False
        store_attr()
    
    def encodes(self, batch):
        # batch is a list of tuples of ({text or (text1, text2)}, {targets...})
        if is_listy(batch[0][0]): # 1st element is tuple
            self._two_texts = True
            texts = ([s[0][0] for s in batch], [s[0][1] for s in batch])
        elif is_listy(batch[0]): 
            texts = ([s[0] for s in batch],)

        inps = self.tokenizer(*texts,
                              add_special_tokens=True,
                              padding=self.padding,
                              truncation=self.truncation,
                              max_length=self.max_length,
                              return_tensors='pt',
                              **self.kwargs)
        # inps are batched, collate targets into batches too
        labels = default_collate([s[1:] for s in batch])
        if self.with_labels:
            inps['labels'] = labels[0]
            res = (inps, )
        else:
            res = (inps, ) + tuple(labels)
        return res
    
    def decodes(self, x:TransTensorText):
        if self._two_texts:
            x1, x2 = split_by_sep(x, self.tokenizer.sep_token_id)
            return (TitledStr(self.tokenizer.decode(x1.cpu(), skip_special_tokens=True)),
                    TitledStr(self.tokenizer.decode(x2.cpu(), skip_special_tokens=True)))
        return TitledStr(self.tokenizer.decode(x.cpu(), skip_special_tokens=True))

In [83]:
class Undict(Transform):
    def decodes(self, x:dict):
        if 'input_ids' in x: res = TransTensorText(x['input_ids'])
        return res

In [89]:
dls_kwargs = {
    'before_batch': TokBatchTransform(pretrained_model_name=model_name, max_length=max_len),
    'create_batch': fa_convert
}
text_block = TransformBlock(dl_type=SortedDL, dls_kwargs=dls_kwargs, batch_tfms=Undict(), )

dblock = DataBlock(blocks = [text_block, CategoryBlock()],
                   get_x=TextGetter(*dunder_textfields),
                   get_y=ItemGetter('speaker'),
                   splitter=IndexSplitter(valid_idx))

<fastai.data.block.DataBlock at 0x7f0de03b0280>

In [90]:
%%time
dl_kwargs=[{'res':train_lens}, {'val_res':valid_lens}]
dls = dblock.dataloaders(ds_train, bs=bs, val_bs=val_bs, dl_kwargs=dl_kwargs)

CPU times: user 2.38 s, sys: 23.2 ms, total: 2.41 s
Wall time: 2.41 s


In [91]:
dir(dls)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_add_tfms',
 '_after_item',
 '_call',
 '_call1',
 '_component_attr_filter',
 '_component_attr_filter',
 '_component_attr_filter',
 '_dbunch_type',
 '_dbunch_type',
 '_decode_batch',
 '_default',
 '_default',
 '_default',
 '_device',
 '_dir',
 '_dir',
 '_dir',
 '_dl_type',
 '_dl_type',
 '_do_call',
 '_docs',
 '_docs',
 '_get',
 '_is_showable',
 '_methods',
 '_n_inp',
 '_new',
 '_new',
 '_noop_methods',
 '_one_pass',
 '_pre_show_batch',
 '_repr_pretty_',
 '_retain',
 '_retain_dl',
 '_set',
 '_types',
 'add',
 'add_na',
 'add_tfms',
 'after_batch',
 'after_item',


In [92]:
dls.show_batch()

RecursionError: maximum recursion depth exceeded while calling a Python object