In [1]:
from fastai.text.all import *
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from datasets import Dataset, concatenate_datasets
from inspect import signature
from pathlib import Path
import gc

seed=2
pd.options.display.max_rows = 20
pd.options.display.max_columns = 8
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'

In [2]:
DATA_ROOT = Path('../data')

df_train = pd.read_csv(DATA_ROOT/'train21_shuffled.csv')
df_test = pd.read_csv(DATA_ROOT/'test21_shuffled.csv')

df_train
del df_train['season']
del df_train['episode']
del df_train['scene']
del df_train['deleted']

del df_test['season']
del df_test['episode']
del df_test['scene']
del df_test['deleted']

df_train = df_train[:1000]

# generate expected format of transformer class
ds_train = Dataset.from_pandas(df_train)
ds_test = Dataset.from_pandas(df_test)



Dataset({
    features: ['line_text', 'speaker'],
    num_rows: 7832
})

In [3]:
ds_train = ds_train.train_test_split(train_size=0.8, seed=seed)
nt, nv = len(ds_train['train']), len(ds_train['test'])
train_idx, valid_idx = L(range(nt)), L(range(nt, nt+nv))
ds_train = concatenate_datasets([ds_train['train'], ds_train['test']])
ds_train[0]

# split = 0.85
# nt = int(len(ds_train)*0.85)
# nv = len(ds_train)-nt

# train_idx, valid_idx = L(range(nt)), L(range(nt, nt+nv))

{'line_text': "And I'd blow your mind.", 'speaker': 'Darryl'}

In [4]:
ds_name = 'DunderAI'
model_name = "distilroberta-base"

max_len = 512
bs = 32
val_bs = bs*2

n_epoch = 4
lr = 2e-5
opt_func = Adam

<function fastai.optimizer.Adam(params: 'Tensor | Iterable', lr: 'float | slice', mom: 'float' = 0.9, sqr_mom: 'float' = 0.99, eps: 'float' = 1e-05, wd: 'Real' = 0.01, decouple_wd: 'bool' = True) -> 'Optimizer'>

In [5]:
dunder_metrics = [accuracy]
dunder_textfields = ["line_text"]

['line_text']

In [6]:
lens = ds_train.map(lambda s: {'len': sum([len(s[i]) for i in dunder_textfields if i])},
                    remove_columns=ds_train.column_names, num_proc=2, keep_in_memory=True)
train_lens = lens.select(train_idx)['len']
valid_lens = lens.select(valid_idx)['len']

Map (num_proc=2):   0%|          | 0/1000 [00:00<?, ? examples/s]

[43,
 153,
 144,
 35,
 293,
 23,
 41,
 180,
 25,
 42,
 184,
 38,
 151,
 25,
 10,
 17,
 26,
 54,
 8,
 13,
 95,
 5,
 140,
 7,
 15,
 56,
 43,
 81,
 22,
 36,
 148,
 39,
 14,
 37,
 27,
 21,
 143,
 22,
 165,
 232,
 49,
 127,
 53,
 43,
 10,
 8,
 3,
 66,
 15,
 134,
 28,
 22,
 116,
 6,
 34,
 49,
 17,
 10,
 11,
 22,
 55,
 20,
 113,
 137,
 193,
 9,
 61,
 125,
 5,
 17,
 25,
 3,
 3,
 35,
 19,
 67,
 36,
 89,
 48,
 66,
 10,
 305,
 37,
 27,
 4,
 66,
 144,
 24,
 33,
 134,
 128,
 24,
 15,
 22,
 13,
 59,
 17,
 93,
 9,
 19,
 29,
 45,
 3,
 149,
 80,
 49,
 34,
 5,
 34,
 73,
 7,
 146,
 45,
 27,
 316,
 3,
 77,
 24,
 181,
 28,
 32,
 40,
 74,
 74,
 43,
 37,
 4,
 4,
 20,
 28,
 128,
 17,
 66,
 61,
 44,
 3,
 102,
 95,
 16,
 111,
 245,
 52,
 56,
 54,
 4,
 111,
 7,
 81,
 56,
 32,
 110,
 43,
 13,
 65,
 18,
 32,
 19,
 6,
 148,
 206,
 33,
 10,
 46,
 11,
 8,
 24,
 78,
 5,
 85,
 29,
 4,
 12,
 47,
 5,
 6,
 198,
 23,
 17,
 14,
 23,
 48,
 221,
 78,
 142,
 53,
 76,
 39,
 41,
 19,
 54,
 79,
 24,
 30,
 57,
 98,
 35,
 82,
 89,


In [7]:
class TextGetter(ItemTransform):
    def __init__(self, s1='line_text'):
        self.s1 = s1
        print('TextGetter extracting from field: ' + self.s1)
    def encodes(self, sample):
        print('Encoding. Sample: ' + str(sample[self.s1]))
        return sample[self.s1]

In [8]:
class TransTensorText(TensorBase): pass

@typedispatch
def show_batch(x:TransTensorText, y, samples, ctxs=None, max_n=10, trunc_at=150, **kwargs):
    print('Showing batch')
    if ctxs is None:
        ctxs = get_empty_df(min(len(samples), max_n))
        print('ctxs is none')
        print('ctxs is none')
        print('ctxs is none')
        
    if trunc_at is not None: samples = L((s[0].truncate(trunc_at)) for s in samples)
    ctxs = show_batch[object](x, y, samples, max_n=max_n, ctxs=ctxs, **kwargs)
    display_df(pd.DataFrame(ctxs))

In [9]:
def find_first(t, e):
    for i, v in enumerate(t):
        if v == e: return i
        
def split_by_sep(t, sep_tok_id):
    idx = find_first(t, sep_tok_id)
    return t[:idx], t[idx:]

In [15]:
class TokBatchTransform(Transform):
    """
    Tokenizes texts in batches using pretrained HuggingFace tokenizer.
    The first element in a batch can be single string or 2-tuple of strings.
    If `with_labels=True` the "labels" are added to the output dictionary.
    """
    def __init__(self, pretrained_model_name=None, tokenizer_cls=AutoTokenizer, 
                 config=None, tokenizer=None, with_labels=False,
                 padding=True, truncation=True, max_length=None, **kwargs):
        if tokenizer is None:
            tokenizer = tokenizer_cls.from_pretrained(pretrained_model_name, config=config)
        self.tokenizer = tokenizer
        self.kwargs = kwargs
        self._two_texts = False
        store_attr()
    
    def encodes(self, batch):
        # batch is a list of tuples of ({text or (text1, text2)}, {targets...})
        texts = [s[0] for s in batch]

        inps = self.tokenizer(texts,
                              add_special_tokens=True,
                              padding=self.padding,
                              truncation=self.truncation,
                              max_length=self.max_length,
                              return_tensors='pt',
                              **self.kwargs)
        
        print(inps)
        print(len(inps))
        # inps are batched, collate targets into batches too
        labels = default_collate([s[1:] for s in batch])
        if self.with_labels:
            inps['labels'] = labels[0]
            res = inps
            print(inps['labels'])
        else:
            print('no label')
            res = inps + labels
        return res
    
    def decodes(self, x:TransTensorText):
        if self._two_texts:
            x1, x2 = split_by_sep(x, self.tokenizer.sep_token_id)
            return (TitledStr(self.tokenizer.decode(x1.cpu(), skip_special_tokens=True)),
                    TitledStr(self.tokenizer.decode(x2.cpu(), skip_special_tokens=True)))
        return TitledStr(self.tokenizer.decode(x.cpu(), skip_special_tokens=True))

In [16]:
class Undict(Transform):
    def decodes(self, x:dict):
        print('Decoding dictionary')
        if 'input_ids' in x:
            res = TransTensorText(x['input_ids'])
            print('Created TransTensorText object')
        
        return res

In [17]:
# dls_kwargs = {
#     'before_batch': TokBatchTransform(pretrained_model_name=model_name, max_length=max_len, with_labels=True),
#     'create_batch': fa_convert
# }
dls_kwargs = {
    'before_batch': TokBatchTransform(pretrained_model_name=model_name, max_length=max_len),
    'create_batch': fa_convert
}
text_block = TransformBlock(dl_type=SortedDL, dls_kwargs=dls_kwargs, batch_tfms=Undict(),)

dblock = DataBlock(blocks = [text_block, CategoryBlock()],
                   get_x=TextGetter(*dunder_textfields),
                   get_y=ItemGetter('speaker'),
                   splitter=IndexSplitter(valid_idx))

TextGetter extracting from field: line_text


<fastai.data.block.DataBlock at 0x7f79723dfa00>

In [18]:
%%time
dl_kwargs=[{'res':train_lens}, {'val_res':valid_lens}]
dls = dblock.dataloaders(ds_train)

Encoding. Sample: And I'd blow your mind.
Encoding. Sample: And I'd blow your mind.
Encoding. Sample: Yes, a madhouse.
Encoding. Sample: You were on a boat.
Encoding. Sample: Yeah, I bet you would. Just try not to be too gay on the court. And by gay I mean, um, you know, not in a homosexual way at all. I mean the uh, you know, like the bad-at-sports way. I think that goes without saying.
Encoding. Sample: See world.  Oceans. Fish.  Jump. China.
Encoding. Sample: But maybe I'm here for a reason, because I might have some good ideas, too. I've been sitting out there, and I've been learning a lot, and maybe I can just bring something different to the table.
Encoding. Sample: Is your wife still your contact?
Encoding. Sample: Listen. [puts ear to wall] Can you hear that? Oh man. These babies are thin.
Encoding. Sample: Newsflash, the whole thing needs to go in the car.
Encoding. Sample: Ay, Kay. Come on, you know, that's not. Cool it.
Encoding. Sample: Okay, I will. I don't know who that i

Encoding. Sample: Me too.  ...I think we're just drunk.
Encoding. Sample: I'm not-
Encoding. Sample: I went to the store and I pressed the buzzer, and they looked right at me, and then they looked away. And then I pressed the buzzer again, and they started taking pictures of me on their mobile phones. I guess I'm not the kind of guy that's good enough for precious heirlooms.
Encoding. Sample: Hey-OH!
Encoding. Sample: In every good hostage movie, during the part where it gets really tense, and you don't know whether the bad guys are going to let the hostages go free, the cops order pizza.
Encoding. Sample: Take a bowl and pass it down.
Encoding. Sample: Okay, well, obviously you don't know anything about leadership.
Encoding. Sample: Can you do this, Kevin?
Encoding. Sample: Believe me, she has enough toys... she doesn't need your watch.
Encoding. Sample: OK, well, I didn't say to write your name down, did I? Fill it out, leave it anonymous. Or, don't write any disease down at all and 

Encoding. Sample: Maybe.
Encoding. Sample: Okay, that is called a compromise. And it is style 3. And it is not ideal. To sum up, win/win - make the poster into a t-shirt, win/lose - take the poster down, compromise - Tuesdays and Thursdays. And the answer is... make the poster into a t-shirt! Win/win.
Encoding. Sample: Yes they were.
Encoding. Sample: No, it's for me, bimbo.  Kids.
Encoding. Sample: On it!
Encoding. Sample: Yeah, I'm sitting on twenty-five hundred in sales I can make at any time but those are my wait till the separation is legal sales.
Encoding. Sample: And how many kitchens?
Encoding. Sample: You broke my heart more recently and more often. And I think at some point, in my head, it just sort of clicked that we're not meant to be.
Encoding. Sample: Uhm.
Encoding. Sample: Okay, greatest strength.
Encoding. Sample: Yep. I'll talk to you tomorrow.
Encoding. Sample: Happy birthday Michael.
Encoding. Sample: So happy. Yeah.
Encoding. Sample: Surprise!
Encoding. Sample: [hig

TypeError: unsupported operand type(s) for +: 'BatchEncoding' and 'list'

In [14]:
dir(dls)

NameError: name 'dls' is not defined

In [None]:
dls.show_batch(max_n=4)