# Preprocess text

In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [2]:
#export
from exp.nb_11a import *

## Loading and splitting the dataset

In [3]:
path = datasets.untar_data(datasets.URLs.IMDB)

In [4]:
path.ls()

[PosixPath('/home/fabiograetz/.fastai/data/imdb/README'),
 PosixPath('/home/fabiograetz/.fastai/data/imdb/train'),
 PosixPath('/home/fabiograetz/.fastai/data/imdb/test'),
 PosixPath('/home/fabiograetz/.fastai/data/imdb/ll_clas.pkl'),
 PosixPath('/home/fabiograetz/.fastai/data/imdb/unsup'),
 PosixPath('/home/fabiograetz/.fastai/data/imdb/tmp_clas'),
 PosixPath('/home/fabiograetz/.fastai/data/imdb/imdb.vocab'),
 PosixPath('/home/fabiograetz/.fastai/data/imdb/tmp_lm')]

In [5]:
#export
def read_file(fn):
    "Return the whole content of the text file `fn`, decoded as UTF-8."
    with open(fn, encoding='utf8') as handle:
        contents = handle.read()
    return contents

In [6]:
#export
class TextList(ItemList):
    "An `ItemList` of texts: items are either `Path`s to text files (read lazily in `get`) or raw strings."
    @classmethod
    def from_files(cls, path, extensions='.txt', recurse=True, include=None, **kwargs):
        "Build a `TextList` from the files with `extensions` found under `path` (only in `include` subfolders if given)."
        files = get_files(path, extensions, recurse=recurse, include=include)
        return cls(files, path, **kwargs)

    def get(self, i):
        "Return the text of item `i`: file content for a `Path`, otherwise the item unchanged."
        return read_file(i) if isinstance(i, Path) else i

In [7]:
il = TextList.from_files(path, include=['train', 'test', 'unsup'])

In [8]:
len(il)

100000

In [9]:
il

TextList (100000 items)
 [PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/9809_2.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/7291_2.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/1279_3.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/7323_1.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/9921_3.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/1825_2.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/233_1.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/3324_3.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/9439_3.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/10967_4.txt')...]
 Path: /home/fabiograetz/.fastai/data/imdb

In [10]:
text = il[0]

In [11]:
text

"Some wonder why there weren't anymore Mrs. Murphy movies after this one. Will it's because this movie totally blew snot. Disney was not the right studio to run this film. MAYBE Touchstone (well, they're owned by Disney, but it'd be more adult). The film is too kid-ish, as the book series is not. The casting is all wrong for the characters. The characters don't even act the way they do in the books. And why was Tucker changed to a guy? He's a girl in the frigging books! Was this done to make the film appeal to boys? Sheesh. And where was Pewter, the gray cat? One of the funniest characters from the book is absent from this filth. Rita Mae Brown is a good writer, but letting Disney blow her work was wrong. An animated feature film, perhaps in the vane of Don Bluth's artwork would suit a better Mrs. Murphy film. Overall, I give this a 2, because at least Disney made a film from an under-appreciated book series. But, I wish they did better. Either way, I still have my books to entertain m

In [12]:
sd = SplitData.split_by_func(il, partial(random_splitter, p_valid=0.1))

In [13]:
sd

SplitData
Train: TextList (89874 items)
 [PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/9809_2.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/7291_2.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/1279_3.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/7323_1.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/9921_3.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/1825_2.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/233_1.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/3324_3.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/9439_3.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/10967_4.txt')...]
 Path: /home/fabiograetz/.fastai/data/imdb
Valid: TextList (10126 items)
 [PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/3059_1.txt'), PosixPath('/home/fabiograetz/.fastai/data/imdb/train/neg/845_1.txt'), PosixPath('/home/fabiograetz/

## Tokenization

In [14]:
#export
import spacy, html

In [15]:
#export
# Special tokens used by the text rules in this notebook:
#   UNK     (xxunk)  replaces '<unk>' in fixup_text
#   PAD     (xxpad)  padding token
#   BOS/EOS (xxbos/xxeos) wrapped around every text by add_eos_bos
#   TK_REP  (xxrep)  character-repetition marker, see replace_rep
#   TK_WREP (xxwrep) word-repetition marker, see replace_wrep
#   TK_UP   (xxup)   put before a word that was written in ALL CAPS, see replace_all_caps
#   TK_MAJ  (xxmaj)  put before a word that was Capitalized, see deal_caps
UNK, PAD, BOS, EOS, TK_REP, TK_WREP, TK_UP, TK_MAJ = "xxunk xxpad xxbos xxeos xxrep xxwrep xxup xxmaj".split()

def sub_br(t):
    "Turn every HTML line-break tag (`<br>`, `<br/>`, `<br />`, any case, optional inner spaces) into a newline."
    return re.sub(r'<\s*br\s*/?>', "\n", t, flags=re.IGNORECASE)

def spec_add_spaces(t):
    "Surround every '/' and '#' character with one space on each side."
    return ''.join(f' {ch} ' if ch in '/#' else ch for ch in t)

def rm_useless_spaces(t):
    "Collapse every run of two or more space characters into a single space (tabs/newlines untouched)."
    return re.sub(r'  +', ' ', t)

def replace_rep(t):
    "Collapse runs of 4+ identical non-space characters: 'cccc' -> ' TK_REP 4 c '."
    rep_pattern = re.compile(r'(\S)(\1{3,})')

    def _collapse(match):
        # group(1) is the repeated character, group(2) the extra copies that follow it
        char, repeats = match.group(1), match.group(2)
        return f' {TK_REP} {len(repeats) + 1} {char} '

    return rep_pattern.sub(_collapse, t)
    
def replace_wrep(t):
    """Replace word repetitions: word word word word -> TK_WREP 4 word.

    Only runs of at least 4 identical units are collapsed. A unit is a word followed by the
    non-word characters after it (e.g. "test " or "hi, "), and that unit is written back verbatim
    after the count.
    """
    re_wrep = re.compile(r'(\b\w+\W+)(\1{3,})')

    def _replace_wrep(m):
        c, cc = m.groups()
        # `cc` is exactly `c` repeated, so count the copies by length. Counting with `cc.split()`
        # undercounted whenever the separator has no whitespace ("a,a,a,a," gave 2 instead of 4).
        return f' {TK_WREP} {len(cc) // len(c) + 1} {c} '

    return re_wrep.sub(_replace_wrep, t)

def fixup_text(x):
    """Clean up messy artefacts found in raw documents.

    Applies a fixed sequence of literal substitutions (the order matters, e.g. the literal
    backslash-n is handled before lone backslashes get spaced out), then unescapes HTML entities
    and collapses runs of 2+ spaces into one.
    """
    substitutions = (
        ('#39;', "'"), ('amp;', '&'), ('#146;', "'"), ('nbsp;', ' '), ('#36;', '$'),
        ('\\n', "\n"), ('quot;', "'"), ('<br />', "\n"), ('\\"', '"'), ('<unk>', UNK),
        (' @.@ ', '.'), (' @-@ ', '-'), ('\\', ' \\ '),
    )
    for old, new in substitutions:
        x = x.replace(old, new)
    return re.sub(r'  +', ' ', html.unescape(x))

In [16]:
#export
# String -> string rules applied, in this order, to every raw text before spaCy tokenization.
default_pre_rules = [fixup_text, replace_rep, replace_wrep, spec_add_spaces, rm_useless_spaces, sub_br]

# Special tokens: registered as spaCy special cases by TokenizeProcessor and moved to the front of the
# vocabulary (in this order) by NumericalizeProcessor.
default_spec_tok = [UNK, PAD, BOS, EOS, TK_REP, TK_WREP, TK_UP, TK_MAJ]

In [17]:
replace_rep('aaaa')

' xxrep 4 a '

In [18]:
replace_wrep("test test test test ")

' xxwrep 4 test  '

*Rules that are applied after tokenization:*

In [19]:
#export
def replace_all_caps(x):
    "Lowercase tokens of 2+ characters written in ALL CAPS and put `TK_UP` in front of them."
    out = []
    for tok in x:
        if len(tok) > 1 and tok.isupper():
            out.extend([TK_UP, tok.lower()])
        else:
            out.append(tok)
    return out

In [20]:
replace_all_caps(["AAA", "bbb", "Fabio" ,"FABIO"])

['xxup', 'aaa', 'bbb', 'Fabio', 'xxup', 'fabio']

In [21]:
#export
def deal_caps(x):
    "Lowercase every token, put `TK_MAJ` before Capitalized ones (e.g. 'Fabio'), and drop empty tokens."
    out = []
    for tok in x:
        if tok == '':
            continue
        is_capitalized = len(tok) > 1 and tok[0].isupper() and tok[1:].islower()
        if is_capitalized:
            out.append(TK_MAJ)
        out.append(tok.lower())
    return out

In [22]:
deal_caps(["AAA", "bbb", "Fabio" ,"FABIO"])

['aaa', 'bbb', 'xxmaj', 'fabio', 'fabio']

In [23]:
#export
def add_eos_bos(x):
    "Return a new token list: `BOS`, then the tokens of the list `x`, then `EOS`."
    head, tail = [BOS], [EOS]
    return head + x + tail

In [24]:
#export
# Token-list rules applied, in this order, after spaCy tokenization.
# NOTE: the order differs from fastai: deal_caps lowercases every token, so replace_all_caps has to run
# first, otherwise ALL-CAPS words would never receive their TK_UP marker.
default_post_rules = [replace_all_caps, deal_caps, add_eos_bos]

In [25]:
x = ["AAA", "bbb", "Fabio" ,"FABIO"]

for f in default_post_rules:
    x = f(x)

In [26]:
x

['xxbos', 'xxup', 'aaa', 'bbb', 'xxmaj', 'fabio', 'xxup', 'fabio', 'xxeos']

In [27]:
#export
from spacy.symbols import ORTH
from concurrent.futures import ProcessPoolExecutor

In [28]:
#export
def parallel(func, arr, max_workers=4):
    """Apply `func` to every `(index, element)` pair of `enumerate(arr)` and return the results as a list.

    With `max_workers < 2` the calls run in the current process (easier to debug); otherwise they are
    dispatched to a `ProcessPoolExecutor`, so `func` and the elements of `arr` must be picklable.
    A progress bar tracks completion in both cases.

    Both paths always return a list (possibly empty). Previously the in-process path returned `None`
    when every result was `None` or `arr` was empty, while the process-pool path returned a list.
    """
    if max_workers < 2:
        return list(progress_bar(map(func, enumerate(arr)), total=len(arr)))
    with ProcessPoolExecutor(max_workers=max_workers) as ex:
        return list(progress_bar(ex.map(func, enumerate(arr)), total=len(arr)))

In [29]:
#export
class TokenizeProcessor(Processor):
    """Turn raw texts (or `Path`s to text files) into lists of string tokens.

    Every text goes through `pre_rules` (string -> string), then the spaCy tokenizer for `lang`, then
    `post_rules` (token list -> token list). The texts are processed in chunks of `chunksize` via
    `parallel` with `max_workers` workers.
    """
    def __init__(self, lang="en", chunksize=2000, pre_rules=None, post_rules=None, max_workers=4):
        self.chunksize, self.max_workers = chunksize, max_workers
        self.tokenizer = spacy.blank(lang).tokenizer

        # Register each special token as a tokenizer special case so it is emitted verbatim as one token.
        for w in default_spec_tok:
            self.tokenizer.add_special_case(w, [{ORTH: w}])

        # Copy the module-level defaults so that editing this processor's rules cannot alter the defaults.
        self.pre_rules  = list(default_pre_rules)  if pre_rules  is None else pre_rules
        self.post_rules = list(default_post_rules) if post_rules is None else post_rules

    def proc_chunk(self, args):
        "Tokenize one `(index, chunk)` pair as produced by `parallel`; `chunk` is a list of strings."
        i, chunk = args
        chunk = [compose(t, self.pre_rules) for t in chunk]  # list of cleaned strings
        docs = [[d.text for d in doc] for doc in self.tokenizer.pipe(chunk)]  # list of token lists
        # After the default post rules, capitalized tokens are lowercase with special tokens before them.
        return [compose(t, self.post_rules) for t in docs]

    def __call__(self, items):
        "Tokenize `items` (strings, or `Path`s that are read first); returns one token list per item."
        if len(items) == 0: return []
        if isinstance(items[0], Path): items = [read_file(i) for i in items]
        chunks = [items[i: i+self.chunksize] for i in range(0, len(items), self.chunksize)]
        toks = parallel(self.proc_chunk, chunks, max_workers=self.max_workers)
        # Flatten the per-chunk results; a comprehension avoids the quadratic copying of `sum(toks, [])`.
        return [doc for chunk_toks in toks for doc in chunk_toks]

    def proc1(self, item):
        "Tokenize a single text (a string, or a `Path` to a text file)."
        if isinstance(item, Path): item = read_file(item)
        # `proc_chunk` unpacks an `(index, chunk)` pair; passing `[item]` alone raised ValueError.
        return self.proc_chunk((0, [item]))[0]

    def deprocess(self, toks): return [self.deproc1(tok) for tok in toks]
    def deproc1(self, tok):    return " ".join(tok)

In [30]:
tp = TokenizeProcessor()

In [31]:
text[:200]

"Some wonder why there weren't anymore Mrs. Murphy movies after this one. Will it's because this movie totally blew snot. Disney was not the right studio to run this film. MAYBE Touchstone (well, they'"

In [32]:
' • '.join(tp(il[:10])[0])[:400]

"xxbos • xxmaj • some • wonder • why • there • were • n't • anymore • xxmaj • mrs. • xxmaj • murphy • movies • after • this • one • . • xxmaj • will • it • 's • because • this • movie • totally • blew • snot • . • xxmaj • disney • was • not • the • right • studio • to • run • this • film • . • xxup • maybe • xxmaj • touchstone • ( • well • , • they • 're • owned • by • xxmaj • disney • , • but • it"

## Numericalization

In [33]:
#export
import collections

class NumericalizeProcessor(Processor):
    """Map token lists to lists of vocabulary indices, and back.

    Unless a `vocab` is given, it is built on the first call: the `max_vocab` most frequent tokens
    occurring at least `min_freq` times, with the tokens of `default_spec_tok` moved to the front in
    that order. With the default special tokens, `UNK` is index 0 and `PAD` is index 1. Tokens missing
    from the vocabulary are mapped to index 0.
    """
    def __init__(self, vocab=None, max_vocab=60000, min_freq=2):
        self.vocab, self.max_vocab, self.min_freq = vocab, max_vocab, min_freq

    def __call__(self, items):
        "Numericalize `items`, a list of token lists, building `vocab`/`otoi` on first use."
        if self.vocab is None:
            freq = collections.Counter(tok for toks in items for tok in toks)
            self.vocab = [tok for tok, c in freq.most_common(self.max_vocab) if c >= self.min_freq]
            # Move the special tokens to the front. Special tokens not already present are added,
            # so the vocab can end up to len(default_spec_tok) entries longer than `max_vocab`.
            for tok in reversed(default_spec_tok):
                if tok in self.vocab: self.vocab.remove(tok)
                self.vocab.insert(0, tok)

        if getattr(self, 'otoi', None) is None:
            self.otoi = collections.defaultdict(int, {v: k for k, v in enumerate(self.vocab)})

        return [self.proc1(o) for o in items]

    def proc1(self, item):
        "Map one token list to a list of ints; unknown tokens map to 0."
        # `.get` instead of `[]`: indexing the defaultdict would insert every unseen token into `otoi`.
        return [self.otoi.get(o, 0) for o in item]

    def deprocess(self, idxs):
        "Map a list of index lists back to token lists."
        assert self.vocab is not None
        return [self.deproc1(idx) for idx in idxs]

    def deproc1(self, idx):
        "Map one list of ints back to tokens."
        return [self.vocab[i] for i in idx]

In [34]:
proc_tok, proc_num = TokenizeProcessor(max_workers=8), NumericalizeProcessor()

In [35]:
%time ll = label_by_func(sd, lambda x: 0, proc_x= [proc_tok, proc_num])

CPU times: user 17.5 s, sys: 1.92 s, total: 19.4 s
Wall time: 42 s


In [36]:
idxs = proc_num.proc1(["xxbos", "xxmaj", "some", "wonder", "why", "there", "were", "n't", "anymore"])

In [37]:
idxs

[2, 7, 65, 606, 154, 54, 86, 35, 1561]

In [38]:
proc_num.deproc1(idxs)

['xxbos', 'xxmaj', 'some', 'wonder', 'why', 'there', 'were', "n't", 'anymore']

In [39]:
print(ll.train.x_obj(0))

xxbos xxmaj some wonder why there were n't anymore xxmaj mrs. xxmaj murphy movies after this one . xxmaj will it 's because this movie totally blew snot . xxmaj disney was not the right studio to run this film . xxup maybe xxmaj touchstone ( well , they 're owned by xxmaj disney , but it 'd be more adult ) . xxmaj the film is too kid - ish , as the book series is not . xxmaj the casting is all wrong for the characters . xxmaj the characters do n't even act the way they do in the books . xxmaj and why was xxmaj tucker changed to a guy ? xxmaj he 's a girl in the frigging books ! xxmaj was this done to make the film appeal to boys ? xxmaj sheesh . xxmaj and where was xxmaj pewter , the gray cat ? xxmaj one of the funniest characters from the book is absent from this filth . xxmaj rita xxmaj mae xxmaj brown is a good writer , but letting xxmaj disney blow her work was wrong . xxmaj an animated feature film , perhaps in the vane of xxmaj don xxmaj bluth 's artwork would suit a better xxmaj

In [40]:
proc_num.deproc1(ll.train[0][0])[:10]

['xxbos',
 'xxmaj',
 'some',
 'wonder',
 'why',
 'there',
 'were',
 "n't",
 'anymore',
 'xxmaj']

## Batching

In [41]:
from IPython.display import display, HTML
import pandas as pd

Example:

In [42]:
text

"Some wonder why there weren't anymore Mrs. Murphy movies after this one. Will it's because this movie totally blew snot. Disney was not the right studio to run this film. MAYBE Touchstone (well, they're owned by Disney, but it'd be more adult). The film is too kid-ish, as the book series is not. The casting is all wrong for the characters. The characters don't even act the way they do in the books. And why was Tucker changed to a guy? He's a girl in the frigging books! Was this done to make the film appeal to boys? Sheesh. And where was Pewter, the gray cat? One of the funniest characters from the book is absent from this filth. Rita Mae Brown is a good writer, but letting Disney blow her work was wrong. An animated feature film, perhaps in the vane of Don Bluth's artwork would suit a better Mrs. Murphy film. Overall, I give this a 2, because at least Disney made a film from an under-appreciated book series. But, I wish they did better. Either way, I still have my books to entertain m

In [43]:
tokens = np.array(tp([text])[0])

In [44]:
tokens

array(['xxbos', 'xxmaj', 'some', 'wonder', ..., 'entertain', 'me', '.', 'xxeos'], dtype='<U11')

Let's say we split this token stream into a batch of 6 rows (`bs = 6`), each a sequence of length 15:

In [45]:
bs, seq_len = 6, 15

In [46]:
d_tokens = np.array([tokens[i * seq_len:(i+1)*seq_len] for i in range(bs)])

In [47]:
df = pd.DataFrame(d_tokens)

In [48]:
display(HTML(df.to_html(index=False, header=None)))

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
xxbos,xxmaj,some,wonder,why,there,were,n't,anymore,xxmaj,mrs.,xxmaj,murphy,movies,after
this,one,.,xxmaj,will,it,'s,because,this,movie,totally,blew,snot,.,xxmaj
disney,was,not,the,right,studio,to,run,this,film,.,xxup,maybe,xxmaj,touchstone
(,well,",",they,'re,owned,by,xxmaj,disney,",",but,it,'d,be,more
adult,),.,xxmaj,the,film,is,too,kid,-,ish,",",as,the,book
series,is,not,.,xxmaj,the,casting,is,all,wrong,for,the,characters,.,xxmaj


For a `bptt` of 5, we would go over the following three batches:

In [49]:
bs, bptt = 6, 5

In [50]:
for k in range(3):
    d_tokens = np.array([tokens[i*seq_len + k*bptt:i*seq_len + (k+1)*bptt] for i in range(bs)])
    df = pd.DataFrame(d_tokens)
    display(HTML(df.to_html(index=False, header=None)))

0,1,2,3,4
xxbos,xxmaj,some,wonder,why
this,one,.,xxmaj,will
disney,was,not,the,right
(,well,",",they,'re
adult,),.,xxmaj,the
series,is,not,.,xxmaj


0,1,2,3,4
there,were,n't,anymore,xxmaj
it,'s,because,this,movie
studio,to,run,this,film
owned,by,xxmaj,disney,","
film,is,too,kid,-
the,casting,is,all,wrong


0,1,2,3,4
mrs.,xxmaj,murphy,movies,after
totally,blew,snot,.,xxmaj
.,xxup,maybe,xxmaj,touchstone
but,it,'d,be,more
ish,",",as,the,book
for,the,characters,.,xxmaj


In [51]:
#export
class LM_PreLoader():
    """Language-model dataset over the numericalized texts in `data.x`.

    All texts are concatenated into one token stream, which is cut into `bs` rows of `n_batch` tokens
    (the remainder is dropped). Item `idx` is the `bptt`-token window of row `idx % bs` starting at
    `(idx // bs) * bptt`, paired with the same window shifted one token to the right (the targets).
    When it is read through a non-shuffling DataLoader with batch size `bs`, consecutive batches
    continue each row. `shuffle` reorders the texts once, when the object is built.
    """
    def __init__(self, data, bs=64, bptt=70, shuffle=False):
        self.data, self.bs, self.bptt, self.shuffle = data, bs, bptt, shuffle
        n_tokens = sum(len(doc) for doc in data.x)
        self.n_batch = n_tokens // bs  # length of each of the `bs` rows
        self.batchify()

    def __len__(self):
        # Each window needs `bptt + 1` tokens of its row (inputs plus the shifted targets).
        windows_per_row = (self.n_batch - 1) // self.bptt
        return windows_per_row * self.bs

    def __getitem__(self, idx):
        row = idx % self.bs
        start = (idx // self.bs) * self.bptt
        stream = self.batched_data[row]
        return stream[start:start + self.bptt], stream[start + 1:start + self.bptt + 1]

    def batchify(self):
        docs = self.data.x
        if self.shuffle:
            docs = docs[torch.randperm(len(docs))]
        flat = torch.cat([tensor(doc) for doc in docs])
        self.batched_data = flat[:self.n_batch * self.bs].view(self.bs, self.n_batch)

*Some experiments that help to understand how `LM_PreLoader` indexes into the batched data:*

In [52]:
def visualize_tensor(t):
    "Display the (at most 2-d) array-like `t` as an HTML table without index or header."
    frame = pd.DataFrame(np.array(t))
    table_html = frame.to_html(index=False, header=None)
    display(HTML(table_html))

In [53]:
stream = torch.cat([tensor(t) for t in ll.train.x])

In [54]:
bs = 3
bptt = 5
n_batch = 15

In [55]:
batched_data = stream[:bs * n_batch].view(bs,n_batch); batched_data.shape

torch.Size([3, 15])

In [56]:
visualize_tensor(batched_data)

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
2,7,65,606,154,54,86,35,1561,7,2293,7,2128,118,119
19,43,9,7,105,16,22,107,19,29,480,4472,15067,9,7
998,25,37,8,227,1154,14,502,19,31,9,6,298,7,32086


In [57]:
idx = 4

In [58]:
source = batched_data[idx % bs]

In [59]:
visualize_tensor(source.unsqueeze(0))

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14
19,43,9,7,105,16,22,107,19,29,480,4472,15067,9,7


In [60]:
seq_idx = (idx // bs) * bptt; seq_idx

5

In [61]:
visualize_tensor(source[seq_idx:seq_idx + bptt].unsqueeze(0))
visualize_tensor(source[seq_idx+1:seq_idx+bptt+1].unsqueeze(0))

0,1,2,3,4
16,22,107,19,29


0,1,2,3,4
22,107,19,29,480


In [62]:
lm_pre = LM_PreLoader(ll.valid, shuffle=True)

In [63]:
len(lm_pre)

43392

In [64]:
sum([len(t) for t in ll.valid.x])//70

43443

In [65]:
x, y = lm_pre[10]

In [66]:
visualize_tensor(x.unsqueeze(0))
visualize_tensor(y.unsqueeze(0))

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69
1139,113,16,28,112,9,7,1222,12,833,10,225,126,10,1139,9,7,17,8,148,7,6838,3418,5581,8,440,14,1287,9,24,7,1139,753,95,7,1567,7,1062,7,1351,7,4450,36,2938,33,3111,7,10567,28,8,6,11465,125,2958,7,6193,23,7,100,302,1632,9,7,83,43,13,7,33396,22,139


0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69
113,16,28,112,9,7,1222,12,833,10,225,126,10,1139,9,7,17,8,148,7,6838,3418,5581,8,440,14,1287,9,24,7,1139,753,95,7,1567,7,1062,7,1351,7,4450,36,2938,33,3111,7,10567,28,8,6,11465,125,2958,7,6193,23,7,100,302,1632,9,7,83,43,13,7,33396,22,139,4804


In [67]:
dl = DataLoader(lm_pre, batch_size=64)

In [68]:
iter_dl = iter(dl)

In [69]:
x1, y1 = next(iter_dl)
x2, y2 = next(iter_dl)

In [70]:
x1.size(), y1.size()

(torch.Size([64, 70]), torch.Size([64, 70]))

In [71]:
vocab = proc_num.vocab

In [72]:
' '.join([vocab[o] for o in x1[0]])

'xxbos xxmaj great surprise from xxmaj peter xxmaj jackson on the xxmaj fellowship of the xxmaj ring xxmaj extended xxup dvd . xxmaj jack xxmaj black hands over the xxmaj one xxmaj ring at the xxmaj council of xxmaj elrond . xxmaj problem ? xxmaj he got drunk with some mates the night before and got is ... put somewhere personal . \n\n xxmaj spliced together with the scene from'

In [73]:
' '.join([vocab[o] for o in y1[0]])

'xxmaj great surprise from xxmaj peter xxmaj jackson on the xxmaj fellowship of the xxmaj ring xxmaj extended xxup dvd . xxmaj jack xxmaj black hands over the xxmaj one xxmaj ring at the xxmaj council of xxmaj elrond . xxmaj problem ? xxmaj he got drunk with some mates the night before and got is ... put somewhere personal . \n\n xxmaj spliced together with the scene from the'

In [74]:
' '.join([vocab[o] for o in x2[0]])

'the film , this is a great little treat for people who hate -- or love -- xxmaj the xxmaj lord of the xxmaj rings . xxeos xxbos xxmaj this is one of the funniest and most warm - hearted films ever ! xxmaj john xxmaj gordon xxmaj sinclair and xxmaj dee xxmaj hepburn were absolutely wonderful in this story of teenage love and the sudden twists & turns that'

In [75]:
' '.join([vocab[o] for o in y2[0]])

'film , this is a great little treat for people who hate -- or love -- xxmaj the xxmaj lord of the xxmaj rings . xxeos xxbos xxmaj this is one of the funniest and most warm - hearted films ever ! xxmaj john xxmaj gordon xxmaj sinclair and xxmaj dee xxmaj hepburn were absolutely wonderful in this story of teenage love and the sudden twists & turns that occur'

*Notice that the target `y` is shifted one token to the right with respect to `x`!*

In [76]:
#export
def get_lm_dls(train_ds, valid_ds, bs, bptt, **kwargs):
    """Return the (train, valid) `DataLoader`s for language modelling.

    `LM_PreLoader` lays the token stream out as N parallel rows and maps item `idx` to row `idx % N`.
    Consecutive batches therefore continue the same rows only when the DataLoader batch size equals
    N and the DataLoader does not shuffle, which is what a stateful RNN needs to carry its hidden
    state across batches. The validation set is read in batches of `2*bs`, so its preloader must also
    be built with `2*bs` rows. Previously it used `bs` rows, which broke row continuity between
    validation batches.
    """
    return (
        DataLoader(LM_PreLoader(train_ds, bs, bptt, shuffle=True), batch_size=bs, **kwargs),
        DataLoader(LM_PreLoader(valid_ds, 2*bs, bptt, shuffle=False), batch_size=2*bs, **kwargs)
    )

In [77]:
#export
def lm_databunchify(sd, bs, bptt, **kwargs):
    "Bundle the language-model train/valid DataLoaders built from `sd.train` and `sd.valid` into a `DataBunch`."
    train_dl, valid_dl = get_lm_dls(sd.train, sd.valid, bs, bptt, **kwargs)
    return DataBunch(train_dl, valid_dl)

In [78]:
bs, bptt = 64, 70

In [79]:
data = lm_databunchify(ll, bs, bptt)

In [80]:
data.train_ds

<__main__.LM_PreLoader at 0x7ff000218c50>

In [81]:
len(data.train_ds)

387328

In [82]:
len(data.train_dl)

6052

In [83]:
6054 * bs

387456

In [84]:
next(iter(data.train_dl))[0].shape

torch.Size([64, 70])

## Batching for classification

In [85]:
proc_cat = CategoryProcessor()

In [86]:
il = TextList.from_files(path, include=['train', 'test'])

In [87]:
sd = SplitData.split_by_func(il, partial(grandparent_splitter, valid_name="test"))

In [88]:
ll = label_by_func(sd, parent_labeler, proc_x = [proc_tok, proc_num], proc_y=proc_cat)

In [89]:
# Cache the labeled classification data; the `with` block closes (and flushes) the file deterministically.
with open(path/"ll_clas.pkl", "wb") as f:
    pickle.dump(ll, f)

In [90]:
# NOTE: pickle.load can execute arbitrary code -- only load a cache file you wrote yourself.
with open(path/"ll_clas.pkl", "rb") as f:
    ll = pickle.load(f)

In [91]:
[(ll.train.x_obj(i), ll.train.y_obj(i)) for i in [1, 2355]];

For memory reasons it's good to process the biggest tensors first, so that a memory problem shows up immediately instead of late in an epoch. Therefore our `SortSampler` sorts the samples by decreasing length:

In [92]:
#export
from torch.utils.data import Sampler

class SortSampler(Sampler):
    "Yield the indices of `data_source` ordered by decreasing `key(index)`; ties keep their index order."
    def __init__(self, data_source, key):
        self.data_source = data_source
        self.key = key

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        order = sorted(range(len(self.data_source)), key=self.key, reverse=True)
        return iter(order)

For `SortishSampler`:

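* Shuffle the texts
* Build megabatches of `50 * bs` texts
* Sort each megabatch by decreasing length
* Split each megabatch into 50 minibatches of size `bs`
* Move the minibatch containing the longest text to the front, keep the last (possibly incomplete) minibatch at the end, and shuffle the minibatches in between

This gives randomized batches whose texts have roughly the same length.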

In [93]:
#export
class SortishSampler(Sampler):
    """Yield indices of `data_source` in randomized batches of roughly equal `key` (e.g. text length).

    Steps:
    1. Shuffle all indices.
    2. Cut them into megabatches of `50*bs`.
    3. Sort each megabatch by decreasing `key`.
    4. Cut the result into batches of `bs`.
    5. Move the batch holding the largest key to the front, keep the last (possibly incomplete) batch
       at the end, and shuffle the batches in between.

    `key` is called with 0-d index tensors.
    """
    def __init__(self, data_source, key, bs):
        self.data_source, self.key, self.bs = data_source, key, bs

    def __len__(self) -> int: return len(self.data_source)

    def __iter__(self):
        idxs = torch.randperm(len(self.data_source))
        if len(idxs) == 0: return iter(idxs)
        mb_size = self.bs * 50
        megabatches = [idxs[i:i+mb_size] for i in range(0, len(idxs), mb_size)]
        sorted_idx = torch.cat([torch.stack(sorted(s, key=self.key, reverse=True)) for s in megabatches])
        batches = [sorted_idx[i:i+self.bs] for i in range(0, len(sorted_idx), self.bs)]

        # Megabatch size is a multiple of `bs`, so each batch is sorted and its first element has its largest key.
        max_idx = int(torch.argmax(tensor([self.key(b[0]) for b in batches])))
        batches[0], batches[max_idx] = batches[max_idx], batches[0]

        # With one or two batches there is nothing in between to shuffle. The previous code called
        # torch.randperm(-1) for one batch and torch.cat([]) for two batches; both raise.
        if len(batches) <= 2: return iter(torch.cat(batches))
        middle = [batches[i + 1] for i in torch.randperm(len(batches) - 2)]
        return iter(torch.cat([batches[0], *middle, batches[-1]]))

In [94]:
#export
def pad_collate(samples, pad_idx=1, pad_first=False):
    """Collate `(token_ids, label)` samples into a padded LongTensor of shape `(len(samples), max_len)` plus a label tensor.

    Shorter sequences are filled with `pad_idx` at the end, or at the start when `pad_first` is True.
    """
    max_len = max(len(s[0]) for s in samples)
    res = torch.full((len(samples), max_len), pad_idx, dtype=torch.long)
    for i, s in enumerate(samples):
        seq = torch.as_tensor(s[0], dtype=torch.long)
        # `max_len - len(seq)` instead of `-len(seq)`: for an empty sequence `-0:` would select the whole row.
        if pad_first: res[i, max_len - len(seq):] = seq
        else:         res[i, :len(seq)] = seq
    return res, tensor([s[1] for s in samples])

In [95]:
bs = 64

In [96]:
train_sampler = SortishSampler(ll.train.x, key=lambda t: len(ll.train[int(t)][0]), bs=bs)

In [97]:
train_dl = DataLoader(ll.train, batch_size=bs, sampler=train_sampler, collate_fn=pad_collate)

In [98]:
iter_dl = iter(train_dl)

In [99]:
x, y = next(iter_dl)

In [100]:
lengths = []
for i in range(x.size(0)): lengths.append(x.size(1) - (x[i] == 1).sum().item())

In [101]:
lengths[:5], lengths[-1]

([3352, 1788, 1632, 1588, 1423], 990)

This was the first batch, which contains the longest sequence. In the following batches the lengths are much more similar:

In [102]:
x, y = next(iter_dl)
lengths = []
for i in range(x.size(0)): lengths.append(x.size(1) - (x[i] == 1).sum().item())
lengths[:5], lengths[-1]

([326, 326, 326, 326, 326], 315)

In [103]:
visualize_tensor(x)

0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319,320,321,322,323,324,325
2,7,301,714,12,75,65,477,101,113,12,29,59,12,230,20,208,338,104,759,10,1508,14,159,14,7,16344,10,30,83,81,14,65,1843,914,13,10954,17685,10,11,115,39,240,12,3466,8546,10,11,177,32,140,16,10,39,22,8308,17,12,1513,31379,430,14,2545,10,11,115,48,7377,2855,27,48,1901,23,3626,6443,292,17,11,240,4617,47,8,230,10,654,461,14,42,7166,44,143,8,287,30,132,209,14,1636,120,3079,10,11,115,6,9875,10,55,45,242,43,230,45,6,9875,10,240,3224,59,16,11,1832,43,102,230,17,14,2386,108,198,10,30,8,230,46,1458,14,2386,8,7166,230,60,14,159,363,11,41,6936,106,10,11,40,3071,23,6464,334,2668,8,18167,10,52,39,197,35,41,18167,10,11,115,39,60,14,159,61,11,185,28,8,7166,230,27,12,31525,4744,10,11,20,91,35,83,179,10,52,39,83,81,1119,8,4107,13,331,23,4617,4921,10,11,115,54,22,158,59,12,1999,10,11,127,1517,176,101,17,12,14010,23025,10,11,12,366,27,12,3361,15859,10,11,65,265,13,6522,4071,27,7488,23,667,2520,4910,10,11,252,18,197,35,404,32,109,8,118,657,10,44,18,78,157,15,7,3321,7,12526,15,17,16,1262,27,65,230,27,8,395,411,13,7,11720,45946,10,11,63,54,22,120,1473,17,8,195,308,3073,27,19,29,1094,12,4558,10,1073,343,11,25,2083,14,114,69,118,10,11,8,195,1402,6,3127,6,144,6,119,23,6,8,6,148,50,3
2,7,5260,44,18,258,12,6029,11840,252,8,603,13,59,892,9,18,461,18,258,205,2470,47,6029,10,30,2235,37,26,94,26,8,1373,122,17,19,31,9,24,7,17,19,31,10,8,1373,122,5643,12,164,6029,1784,462,7,16071,3170,7,0,49,2410,1681,17,20766,9,7,39,15,836,28,12,475,39,275,466,7,0,10,135,12,337,2162,25,275,47,7,0,9,7,318,13,43682,10,39,346,2384,61,13,8,9347,2324,10,132,14,1018,9,24,7,8,1099,15770,13,8,1373,122,15,492,12,1785,13,65,13,8,69,627,7,0,32,559,98,17,6029,36188,9,7,6029,18,1365,15,12,70,219,3596,10,11,32,78,559,148,72,49612,17,102,3826,13,136,9,12,688,512,13,19,15,8,6029,1850,7,3428,7,16867,9,7,284,12,1194,10,39,25,103,70,4135,17,129,757,9,24,7,17,8,31,10,12,7,195,6193,1139,15,2407,10,26,223,48,7,1046,7,28184,11,7,24371,9,7,8,2847,15,12,12135,10,79,178,8,195,7,32928,677,13,7,18201,60,559,5776,26,12,2847,413,23,182,51,38689,1788,14,56,263,17,8,486,9,7,30,19,25,12,541,12135,23,79,103,4367,12,828,2847,10,135,8,8660,13,273,541,319,466,273,102,9,24,7,8,2757,177,0,181,28,12,1456,6029,4737,17,413,9,7,17,212,10,18,41,244,12,6881,397,59,16,10,28,32,14,3838,8,1808,2162,17,1541,23,388,50,5639,95,125,125,0,125,38648,3249,0,7,139,3523,7,0,3
2,7,106,13,44,10,18,258,37,12,686,355,13,2457,7,4529,458,10,79,15,107,10,8,675,6277,13,1874,12,881,1034,1009,15,47,3358,1827,8,54269,9,7,19,29,15,180,13,48,1767,959,20,2087,27,15460,10,26,8,170,11,4531,7155,9467,9,7,17,77,2788,682,10,8,3903,1281,13,19,29,15,114,306,47,1371,12,2541,0,212,16,15,670,1271,7456,17,8,2074,659,33,7,697,38727,253,1631,27,12,1368,5213,13,64,18,1017,9,24,7,214,10,8,137,15,749,428,9,7,17,212,10,110,13,8,196,60,2299,946,11,783,267,14,7840,28,64,7,49688,3615,134,644,50,7,44,1078,13,8,196,86,926,1134,17,80,589,11,88,10865,28,8,320,10,76,8,360,2339,703,9,36,7,53,7,0,33,7,45,242,10,7,49688,1029,12,418,678,13,4818,10,56,107,39,696,88,109,14,2196,8,196,9,7,102,93,20,10,39,15,56,12,1926,3294,10,49,25702,332,12,170,27,48,1767,644,9,7,233,34,10,553,15,37,158,20,4240,3193,13,1980,27,1260,125,331,1260,366,9,7,11,56,107,1926,6103,142,10,74,170,493,12,8257,17,8,497,13,7,4529,458,9,7,56,516,147,4221,11,1422,158,164,1767,10,18,258,1064,7961,14,8377,147,385,179,9,7,443,10,19,75,32,1954,14,2715,48,1767,1573,9,24,7,17,364,134,7,7312,4024,7,67,137,10,901,366,36,56,3971,50,33,3215,7,6924,4024,7,273,11,191,694,169,10,102,93,8,196,3
2,7,166,10,18,41,35,371,8,232,364,82,14,140,44,8,5182,798,20,436,383,150,10,52,18,167,37,182,14,159,198,20,2641,150,9,24,7,30,18,41,65,75,624,12869,20,7,1892,7,748,118,346,0,32427,201,124,10,27,401,12,381,13,5277,9,7,19,25,37,43,13,112,9,7,16,679,88,217,10,11,28,301,18,167,37,182,14,3905,59,8,137,10,284,7,2115,7,17061,25,26,675,421,400,7,103,18,105,3316,8,938,621,13,15594,1915,10,26,46,86,1708,14,114,8,82,150,831,944,92,24,7,30,54,25,43,4591,243,20,18,96,35,367,30,98,3685,27,9,7,11,20,410,206,516,14,8,148,13,8,29,10,11,45,242,263,121,48748,17,8,6241,13,1005,2969,9,7,8,403,18,5613,14,38,1337,95,7,54,15,12,2290,1055,642,207,8,348,9,7,32,1731,16,15,469,8,1416,9,7,58,32,24,36,12,33,98,26,516,14,8,1416,26,32,78,177,11218,5036,8,0,1015,72,10,9477,206,94,147,241,677,14,857,4185,160,12,13658,5928,3223,253,14988,34,32,10,55,24,36,626,33,831,12,138,260,51,8,1416,52,32,78,13330,469,8,356,868,148,51,12,2757,13,1076,10807,4403,10,55,45,242,84,8,2302,2096,566,45,32,10,734,32,12,138,2169,75,9,24,18,140,10,303,22,159,27,36,12,33,9,18,250,53,3549,561,72,28,8,5101,510,9,7,39737,24,7,266,217,31,171,92,7,56,37,70,666,9,3
2,7,6108,7,5353,15,1041,8,2041,304,17,8,195,30,1115,14,40,610,10,7,1264,92,88,10,2013,61,165,2453,166,50,7,15767,4943,10,6685,7,1264,209,14,123,20,191,192,39,312,10,1832,28,108,422,8,231,26,48,7,300,125,7,1046,7,1594,23,230,9973,51,8,62,7,3848,9,62,7,613,10,1804,114,108,48,137,1500,62,39,78,5324,50,62,7,19,31,56,4337,8,287,61,10,76,47,8,409,4560,17,19,10,7,1264,152,2273,8,137,0,9,7,44,8,101,7840,28,80,621,13,644,47,5373,80,431,11,2703,80,2375,59,10,40162,4825,17,12,416,20,436,61,13,1624,162,27,1184,124,9,24,7,8,130,10,64,54,10,15,10,181,74,291,26,12,7948,4712,11,523,44,8,366,104,369,23,4856,0,30,26,19,15,1170,15370,31,10,20,56,811,46,3369,14,80,7376,11,114,13026,966,53,2071,34,3219,9,7,64,574,17,8,148,10,18,167,37,205,272,26,18,25,2268,8,105,14,435,215,177,8,31,1829,9,24,7,793,19,53,8,3386,11,126,62,7,855,7,1589,62,28,12,3331,31,59,12,7948,59,14,616,8,733,9,24,6,49679,7,243,13,648,95,68,43,13,8,645,813,12371,198,14,57,7376,10,71,60,57,22456,125,12296,469,57,40813,125,54401,6032,10,16,22,22456,143,8,40813,14,1779,366,14,159,14,8,3599,27,360,9316,9,12,6935,243,10,401,10,30,294,109,1005,19,31,15,68,46,197,35,76,98,19,227,50,3
2,6,7279,36,7,11040,7,0,33,11,6,55290,36,7,11494,7,34628,33,38,5796,17,12,1080,311,27,6,0,36,7,746,7,0,33,10,30,8,106,703,4298,12,1682,488,10,6,46556,10,14,590,6,0,36,8,130,10,45,242,45,8,481,10,91,35,1267,70,88,154,33,9,7,8,230,91,16,11,119,6365,40,306,39,525,17501,8,127,5796,10,3201,17,80,1080,11,662,39,511,12,320,54,36,26,16092,55,158,1751,33,9,7,45,8,187,75,39,975,6,11909,36,7,38163,7,0,33,10,551,13,6,0,10,11,525,4597,27,57,50,24,7,17,12,82,53,19,10,135,778,10,4605,10,4821,11,8874,159,517,17,517,10,74,43,15,1335,55,78,42,15030,10,1474,113,14,6,11909,10,79,15,8,81,420,49,91,35,140,64,22,182,34,11,87,35,12756,270,6719,7,19,31,2215,27,8276,8,631,11,2493,630,79,3099,17,219,11322,53,7,28746,7,17742,10,135,778,15,191,271,22,1266,9,7,89,78,250,3122,30,103,24167,17,121,53,6,46556,36,2148,275,47,7,17742,7,20893,33,10,79,91,8874,14,8,464,49,1521,108,229,120,23879,10,11,76,16045,27,8,551,13,8,230,39,529,50,7,39,83,225,42,12,774,14,58,158,53,20,10,30,13,286,18,140,54,203,101,53,108,61,54,10,17,7,3393,55,120,102,287,6719,7,16,22,12,70,67,29,10,2493,30,9102,10,59,12,631,630,1672,7,8,137,15,100,11,8,787,117,9,3
2,18,234,324,20,8,1313,7,1721,118,86,94,146,93,8,435,211,124,9,24,18,161,131,44,13,8,1313,124,10,30,61,13,8,762,10,19,15,8,12056,10,11,16,22,269,1388,9,24,21,7,2319,13,8,7,5155,21,73,4148,26,8,139,10,115,21,7,195,22,7,2041,21,10,115,21,7,1401,7,1747,21,10,115,21,7,1018,13,8,7,8554,21,10,11,448,19,3553,248,9,24,7,17,19,10479,1313,29,10,54,22,12,1302,183,27919,17,7,11665,11,7,1721,15,3734,34,4829,761,71,22,442,55,10893,26,71,717,61,34,12,2910,28,3864,9,24,7,30,26,7,1230,240,588,27,297,210,366,10,39,792,4812,112,11,2065,49,15,8,183,27919,9,24,7,8,1273,17,19,31,15,7750,652,1083,110,13,8,124,38,3118,11,3705,10,79,25,269,1388,9,7,1230,1393,2924,61,13,122,110,13,8,75,10,8,1789,38,859,23,344,51,8,248,124,10,11,153,8,211,156,38,1123,10,46,203,83,180,183,9,7,16,103,1476,8,953,1589,8,102,124,41,10,281,20,13,21,7,1401,23,7,1747,21,79,25,1257,17,456,11,122,974,9,24,7,296,17,8,31,806,206,5967,11,8,758,121,38,1377,30,74,43,15,178,295,6266,28,9,24,7,44,8,153,10,18,83,526,19,10,18,85,12,4759,10,11,8,2001,13,8,183,27919,15,1635,10,30,19,25,35,26,1123,26,18,321,3559,9,24,36,194,194,331,61,13,4,224,194,33,3,1
2,7,1177,155,42,387,14,8,7,2665,7,2126,6,21981,125,6,288,10756,20,18,81,1521,6,22716,37971,11,18,96,41,160,12,1240,2283,3347,9,24,7,9281,66,7,109,1112,17,277,701,499,9,7,4906,7817,57,127,438,11,3107,7,1918,14,1219,119,7,7218,11,7,4201,14,7,3096,9,7,71,511,57,1145,244,30,45,8,148,10,71,103,6577,8,187,778,13,7,7218,95,71,85,369,27,7,1918,9,24,7,17,19,418,2744,392,10,7,9841,7,0,3612,27,8,55308,707,17,701,1385,27,12,652,506,9,7,109,78,43,3275,1785,1385,646,17,19,712,23,1529,75,135,101,38,166,13,2003,26594,17,1730,10,1891,11,2271,3826,66,7,38227,6133,1766,28,117,322,12,1753,20,110,101,73,2196,14,16408,14,2201,9,7,1918,301,9260,7,4201,30,39,1627,28,7,4906,119,20,6221,1325,79,875,108,14,577,94,6799,9,24,7,110,794,226,250,667,59,8,293,107,7,4906,11,7,1918,233,311,10,88,10,19,155,42,8,860,12687,14,8,10430,509,13,80,11882,9,7,214,10,18,204,12,704,10,18,204,14,84,109,8,664,73,2292,119,19,4404,9,7,401,46,226,4404,162,10,109,78,43,42,52,272,59,8,5469,133,632,9,24,7,67,137,11,67,1374,9,7,8,127,138,520,23,172,155,37,42,6766,281,8,417,68,39,6786,45,8,9344,13,8,1469,17,7,26818,9,7,11,13,286,10,109,78,89,846,8,21,15384,10,15384,21,1374,13,8,513,9,3,1
2,12,7,2255,7,1172,7,163,7,16,15,450,34,8,315,82,13,127,1895,5616,994,10,7,3681,11,7,850,10,36,7,3089,7,3272,11,7,3665,7,14827,33,654,7,9622,353,36,7,826,7,24180,33,15,12,7078,145,654,127,10453,38,40,2134,11,1895,5616,10,23,11,10,28,108,11,40,3513,10,54,15,12,500,365,223,8,127,9,7,19,82,4409,8,578,7803,13,8,994,62,476,11,109,80,476,2974,34,127,285,7613,9,7,16,15,12,1371,29,7333,47,8,170,10,7,611,7,7558,10,419,8,3653,7,3681,11,9742,34,235,215,781,11,101,215,350,9,24,7,755,1434,31851,17,8,29,10,160,26,1689,10,343,10,6423,10,11,1531,9,7,110,13,151,1434,10264,207,8,239,1601,570,13,7,850,9,7,39,15,12,2112,10,1377,10,11,2902,145,10,30,60,40,4188,1556,9,24,7,8,2742,431,2835,72,8,21,243,21,13,8,29,95,21,7,115,17,8,7,12299,331,23,652,13,8,8569,10,44,2179,9788,14,12,128,27,77,1312,11,1860,11,8,966,13,8,7,219,7,0,7,2255,11,12,664,23,1583,6208,11,8,443,20,12,2275,105,2233,9,7,870,10,44,202,16438,104,43,10,11,12,2255,1172,163,16,9,7,8,2255,25,609,47,8,195,22,100,7960,11,1172,143,3835,51,8,3456,13,75,9,7,34,65,13,165,3835,38,3999,38328,9,7,469,8,3835,38,8,692,10,11,65,13,8,692,38,8448,9,18,258,2330,47,3206,9,21,3,1
2,7,19,15,8,139,454,59,2882,252,7,7558,11732,7,8,7,1317,277,116,9,7,1792,7,3627,11,7,2919,7,20028,652,72,8,280,27,12,100,82,11,12,196,20,484,164,217,14,1499,32,104,80,476,9,24,7,8079,27,1531,23,9303,17,630,10,8,7,300,7,1002,0,119,0,27,12,315,82,59,8,7,1805,7,1462,22,2251,23,136,9721,10,7,2326,7,4375,9,7,2488,1364,554,10,7,2919,7,20028,10,312,12,2074,7,1280,7,16313,146,93,12,189,13,7,31238,18,140,134,11,7,1792,7,3627,25,3235,1672,196,26,8,3833,10,0,11,4932,2314,27,26,94,1991,12071,26,8,839,524,17,8,327,9,7,16,22,164,9,7,8,34,23,280,1232,535,9,7,63,32,53,2882,506,23,50002,10,32,203,182,14,133,19,31,9,7,8,4769,28,7,14219,11,6575,3912,86,44,17,54,9,7,77,3977,906,8,54683,1079,4977,62,9,24,7,11,1336,95,809,8,6,4907,50,7,8,238,3835,11,2772,8,82,9882,50,7,25060,692,11,238,5829,8,82,944,628,8,116,16,155,10,48,1618,20,10421,90,69,424,93,37,9,24,7,9534,95,18,321,41,387,8,2882,14,1804,345,9,7,30,7,3627,60,158,14,3430,199,44,59,0,62,11,506,9,7,4136,7,4875,11,8,8223,86,12,358,14364,10,30,46,132,83,276,80,287,17,8,82,14,868,11,516,207,112,9,12,138,215,9,7,295,191,867,17,8,248,5525,9,753,125,184,9,24,3,1,1


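Add convenience functions that build the classifier DataLoaders (length-keyed samplers, batches collated with `pad_collate`) and wrap them in a `DataBunch`: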

In [104]:
#export
def get_clas_dls(train_ds, valid_ds, bs, **kwargs):
    """Create the (train, valid) DataLoaders for text classification.

    Both samplers are keyed on the length of each item in ``ds.x`` (presumably so
    that a batch contains texts of similar length and needs little padding).
    Training uses ``SortishSampler`` with batch size ``bs``; validation uses
    ``SortSampler`` with batch size ``2*bs``. Every batch is collated with
    ``pad_collate``, and any extra ``kwargs`` are forwarded to both DataLoaders.

    Returns:
        A ``(train_dl, valid_dl)`` tuple.
    """
    def _len_key(ds):
        # Key function: sample index -> length of the matching item in `ds.x`.
        return lambda idx: len(ds.x[idx])

    train_sampler = SortishSampler(train_ds.x, key=_len_key(train_ds), bs=bs)
    valid_sampler = SortSampler(valid_ds.x, key=_len_key(valid_ds))

    # Options shared by both loaders; duplicates in kwargs raise TypeError as before.
    shared = dict(collate_fn=pad_collate, **kwargs)
    train_dl = DataLoader(train_ds, batch_size=bs, sampler=train_sampler, **shared)
    valid_dl = DataLoader(valid_ds, batch_size=2 * bs, sampler=valid_sampler, **shared)
    return train_dl, valid_dl

In [105]:
#export
def clas_databunchify(sd, bs, **kwargs):
    """Wrap the classification DataLoaders for ``sd.train`` / ``sd.valid`` in a ``DataBunch``.

    ``bs`` and any extra ``kwargs`` are passed straight through to ``get_clas_dls``.
    """
    loaders = get_clas_dls(sd.train, sd.valid, bs, **kwargs)
    return DataBunch(*loaders)

In [106]:
# Training batch size; get_clas_dls gives the validation DataLoader batch_size=2*bs.
# bptt is not consumed by the classifier DataLoaders built below.
bs, bptt = 64, 70

In [107]:
# `ll` comes from an earlier cell (not shown here); clas_databunchify reads ll.train and ll.valid.
data = clas_databunchify(ll, bs)

In [108]:
# Pull one training batch to sanity-check the shapes.
x, y = next(iter(data.train_dl))

In [109]:
# Inputs are (bs, seq_len) and targets are (bs,), per the output below; seq_len is
# presumably the padded length of the longest text in the batch (pad_collate).
x.shape, y.shape

(torch.Size([64, 3352]), torch.Size([64]))

In [110]:
# Pull one validation batch; its batch size is 2*bs = 128.
x, y = next(iter(data.valid_dl))

In [111]:
x.shape, y.shape

(torch.Size([128, 2901]), torch.Size([128]))

## Export


In [113]:
# Export this notebook (presumably its `#export`-tagged cells) to exp/nb_12.py, per the output below.
!python notebook2script.py 12_text.ipynb

Converted 12_text.ipynb to exp/nb_12.py
