In [None]:
# default_exp pretrained

# Pretrained

> fast.ai ULMFiT helpers to easily use pretrained models

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import json
from fastai.text.all import SentencePieceTokenizer, SpacyTokenizer, language_model_learner, \
                            text_classifier_learner, untar_data, Path, patch, \
                            LMLearner, os, pickle, shutil, AWD_LSTM, accuracy, \
                            Perplexity, delegates

In [None]:
#export
def _get_config(path):
    with open(path/'model.json', 'r') as f:
        config = json.load(f)
    return config

In [None]:
#export
def _get_pretrained_model(url):
    "Fetch/extract the `<last url segment>.tgz` model archive via fastai's `untar_data` and return its local path."
    archive_name = url.split('/')[-1] + '.tgz'
    return untar_data(url, fname=archive_name, c_key='model')

In [None]:
#export
def _get_direction(backwards):
    return 'bwd' if backwards else 'fwd'

In [None]:
#hide
# Each direction flag must map to its model sub-directory name.
for flag, expected in ((False, 'fwd'), (True, 'bwd')):
    assert _get_direction(backwards=flag) == expected

Get the `model` and `vocab` file paths (without file extension) for the forward or backward model stored under `path`.

In [None]:
#export
def _get_model_files(path, backwards=False):
    """Return `[model, vocab]` file paths (extension stripped) for one direction of a pretrained model.

    `path`: root directory of the extracted pretrained model.
    `backwards`: look in the `bwd` sub-directory instead of `fwd`.
    Returns absolute paths as strings, in the form passed to `language_model_learner(pretrained_fnames=...)`.
    Raises `IndexError` if the `*model.pth` or `*vocab.pkl` file is missing.
    """
    direction = _get_direction(backwards)
    model_path = path/direction
    # Loading the config fails early if `model.json` is missing or malformed for this direction.
    _get_config(model_path)
    # sorted() makes the pick deterministic if several files match a pattern
    model_files = sorted(model_path.glob('*model.pth'))
    vocab_files = sorted(model_path.glob('*vocab.pkl'))
    if not model_files or not vocab_files:
        raise IndexError(f'The model in {model_path} is incomplete, download again')
    fnames = [model_files[0].absolute(), vocab_files[0].absolute()]
    # Extensions are stripped, presumably because `language_model_learner` appends them itself — confirm.
    return [str(f.parent/f.stem) for f in fnames]

## Tokenizer

Create a tokenizer from the model config (`model.json`). The tokenizer parameters in `model.json`, plus any additional keyword arguments, are passed to the tokenizer. Currently SentencePiece and spaCy tokenizers are supported.

In [None]:
#export
def tokenizer_from_pretrained(url, pretrained=False, backwards=False, **kwargs):
    """Create the tokenizer described in the pretrained model's `model.json`.

    `url`: pretrained model URL (fetched via `_get_pretrained_model`).
    `pretrained`: for SentencePiece, pass the package's `spm/spm.model` as `sp_model`; otherwise `sp_model=None`.
    `backwards`: read the config of the backward (`bwd`) model instead of the forward (`fwd`) one.
    `kwargs`: extra tokenizer arguments; they override same-named parameters from `model.json`.
    Raises `ValueError` for tokenizer classes other than SentencePiece and Spacy.
    """
    path = _get_pretrained_model(url)
    config = _get_config(path/_get_direction(backwards))
    tok_class = config['tokenizer']['class']
    tok_params = config['tokenizer']['params']
    if tok_class == 'SentencePieceTokenizer':
        sp_model = path/'spm'/'spm.model' if pretrained else None
        # Merge instead of passing separately: duplicate keys would raise
        # "got multiple values for keyword argument"; caller kwargs win.
        return SentencePieceTokenizer(**{**tok_params, 'sp_model': sp_model, **kwargs})
    if tok_class == 'SpacyTokenizer':
        return SpacyTokenizer(**{**tok_params, **kwargs})
    raise ValueError(f'Tokenizer not supported: {tok_class}')

## Language Model Learner

Create a `language_model_learner` from a pretrained model URL. All other parameters are passed on to `language_model_learner`; `arch`, `pretrained` and `pretrained_fnames` are set automatically. By default, `accuracy` and `perplexity` are used as `metrics`.

In [None]:
#export
@delegates(language_model_learner)
def language_model_from_pretrained(dls, url=None, backwards=False, metrics=None, **kwargs):
    """Create a `language_model_learner` initialised with the pretrained weights found at `url`.

    `url`: pretrained model URL (required despite the `None` default).
    `backwards`: load the backward (`bwd`) instead of the forward (`fwd`) weights; presumably this
        should match the `backwards` setting used to build `dls`.
    `metrics`: defaults to `[accuracy, Perplexity()]`.
    All other `kwargs` are forwarded to `language_model_learner`; `arch`, `pretrained` and
    `pretrained_fnames` are set here.
    Raises `ValueError` if `url` is not given.
    """
    if url is None: raise ValueError('`url` of a pretrained model is required')
    arch = AWD_LSTM # TODO: Read from config
    path = _get_pretrained_model(url)
    # Pass the direction through so backward models load the `bwd` weights.
    fnames = _get_model_files(path, backwards=backwards)
    if metrics is None: metrics = [accuracy, Perplexity()]
    return language_model_learner(dls,
                                  arch,
                                  pretrained=True,
                                  pretrained_fnames=fnames,
                                  metrics=metrics,
                                  **kwargs)

In [None]:
#export
def _get_model_path(learn=None, path=None):
    path = (learn.path/learn.model_dir) if not path else Path(path)
    if not path.exists(): os.makedirs(path, exist_ok=True)
    return path

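Saves the following model files to `path` (defaults to `learn.path/learn.model_dir`):
- Model (`lm_model.pth`)
- Encoder (`lm_encoder.pth`), only if `with_encoder=True`
- Vocab from dataloaders (`lm_vocab.pkl`)
- SentencePiece model (`spm/`), only if a `SentencePieceTokenizer` is used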

In [None]:
#export
@patch
def save_lm(x:LMLearner, path=None, with_encoder=True):
    """Save the language model, encoder, vocab and (for SentencePiece) tokenizer model to `path`.

    `path`: target directory; defaults to `x.path/x.model_dir` and is created if missing.
    `with_encoder`: also write `lm_encoder.pth`, which `text_classifier_from_lm` loads.
    Writes `lm_model.pth` (without optimizer state), `lm_vocab.pkl` and, when a
    `SentencePieceTokenizer` is used, `spm/spm.model` + `spm/spm.vocab`.
    Calls `x.to_fp32()` on the learner itself before saving.
    Returns the directory the files were written to.
    """
    path = _get_model_path(x, path)
    x.to_fp32()
    x.save((path/'lm_model').absolute(), with_opt=False)

    if with_encoder:
        x.save_encoder((path/'lm_encoder').absolute())

    with open((path/'lm_vocab.pkl').absolute(), 'wb') as f:
        pickle.dump(x.dls.vocab, f)

    tok = x.dls.tok
    if isinstance(tok, SentencePieceTokenizer):
        # A tokenizer created with an explicit `sp_model` (e.g. `tokenizer_from_pretrained(pretrained=True)`)
        # keeps its files next to that model; otherwise they are expected in `cache_dir`.
        # NOTE(review): assumes `spm.vocab` sits beside `spm.model` in both cases — confirm.
        spm_path = Path(tok.sp_model).parent if tok.sp_model is not None else Path(tok.cache_dir)
        target_path = path/'spm'
        # Compare against the actual copy target: copying a file onto itself raises shutil.SameFileError.
        if target_path.absolute() != spm_path.absolute():
            os.makedirs(target_path, exist_ok=True)
            for fname in ('spm.model', 'spm.vocab'):
                shutil.copyfile(spm_path/fname, target_path/fname)

    return path

## Text Classifier

In [None]:
#def vocab_from_lm(learn=None, path=None):
#    path = _get_model_path(learn, path)
#    with open((path/'lm_vocab.pkl').absolute(), 'rb') as f:
#        return pickle.load(f)

In [None]:
#def spm_from_lm(learn=None, path=None):
#    path = _get_model_path(learn, path)

Create a `text_classifier_learner` from a fine-tuned language model path (saved with `learn.save_lm()`). The saved encoder (`lm_encoder.pth`) is loaded into the classifier.

In [None]:
#export
@delegates(text_classifier_learner)
def text_classifier_from_lm(dls, path=None, backwards=False, **kwargs):
    """Create a `text_classifier_learner` whose encoder is loaded from a model saved with `learn.save_lm()`.

    `path`: directory containing `lm_encoder.pth` (as returned by `save_lm`). If omitted, the new
        learner's `path/model_dir` is used. This matches `save_lm`'s default when both learners share
        `path` and `model_dir`.
    `backwards`: currently unused; presumably the direction comes from how `dls` was built. It is kept
        for API symmetry with `language_model_from_pretrained`.
    All other `kwargs` are forwarded to `text_classifier_learner`.
    """
    arch = AWD_LSTM # TODO: Read from config
    learn = text_classifier_learner(dls, arch, pretrained=False, **kwargs)
    # Resolve the directory after the learner exists so `path=None` can fall back to learn.path/learn.model_dir.
    path = _get_model_path(learn, path)
    learn.load_encoder((path/'lm_encoder').absolute())
    return learn

# Tests - Tokenizer, LM and Classifier

In [None]:
#hide
#slow
# Served by a local model server; this cell is marked #slow.
url = 'http://localhost:8080/ulmfit-dewiki'
tok = tokenizer_from_pretrained(url, pretrained=True)
assert tok.vocab_sz == 15000
# Compare path components instead of a '/'-joined string so the check also holds on Windows.
assert Path(tok.sp_model).parts[-3:] == ('ulmfit-dewiki', 'spm', 'spm.model')

In [None]:
#hide
#slow
tok = tokenizer_from_pretrained(url, pretrained=False)
# Without `pretrained`, no SentencePiece model file is passed in.
assert tok.sp_model is None
assert tok.vocab_sz == 15000

In [None]:
#hide
#slow
from fastai.text.all import AWD_LSTM, DataBlock, TextBlock, ColReader, RandomSplitter
import pandas as pd

backwards = False

df = pd.read_csv(Path('_test/data_lm_sample.csv'))

# `blocks` is a collection of blocks; the trailing comma makes this an explicit 1-tuple
# (without it the parentheses are only grouping).
dblocks = DataBlock(blocks=(TextBlock.from_df('text', tok=tok, is_lm=True, backwards=backwards),),
                    get_x=ColReader('text'),
                    splitter=RandomSplitter(valid_pct=0.1, seed=42))
dls = dblocks.dataloaders(df, bs=128)

learn = language_model_from_pretrained(dls, url=url, backwards=backwards)
learn.fit_one_cycle(1)

  return array(a, dtype, copy=False, order=order)


epoch,train_loss,valid_loss,accuracy,perplexity,time
0,6.440651,6.521347,0.169837,679.493103,00:01


In [None]:
#hide
#slow
path = learn.save_lm()  # directory the LM files were written to; reused by the classifier cell below
vocab = learn.dls.vocab  # LM vocab, passed to the classifier's TextBlock (presumably so token ids match the saved encoder)

In [None]:
#hide
#slow
from fastai.text.all import AWD_LSTM, DataBlock, TextBlock, ColReader, RandomSplitter, CategoryBlock
import pandas as pd

backwards = False

df = pd.read_csv(Path('_test/data_class_sample.csv'))

# Reuse the LM vocab (presumably so token ids line up with the saved encoder).
# An explicit seeded splitter replaces DataBlock's unseeded default 80/20 split, keeping runs reproducible.
dblocks = DataBlock(blocks=(TextBlock.from_df('text', tok=tok, vocab=vocab, backwards=backwards), CategoryBlock),
                    get_x=ColReader('text'),
                    get_y=ColReader('label'),
                    splitter=RandomSplitter(valid_pct=0.2, seed=42))
dls = dblocks.dataloaders(df, bs=128)

learn = text_classifier_from_lm(dls, path=path, backwards=backwards)
learn.fit_one_cycle(1)
preds, targs = learn.get_preds()
assert len(preds) == len(targs)
preds[:5]  # show a few rows instead of dumping every prediction

  return array(a, dtype, copy=False, order=order)


epoch,train_loss,valid_loss,time
0,0.773969,0.685955,00:02


(tensor([[0.5084, 0.4916],
         [0.4719, 0.5281],
         [0.5388, 0.4612],
         [0.5013, 0.4987],
         [0.5037, 0.4963],
         [0.4966, 0.5034],
         [0.5330, 0.4670],
         [0.5015, 0.4985],
         [0.5461, 0.4539],
         [0.5114, 0.4886],
         [0.5471, 0.4529],
         [0.4729, 0.5271],
         [0.4876, 0.5124],
         [0.4873, 0.5127],
         [0.4879, 0.5121],
         [0.5057, 0.4943],
         [0.5063, 0.4937],
         [0.4945, 0.5055],
         [0.4684, 0.5316],
         [0.4992, 0.5008],
         [0.4818, 0.5182],
         [0.4998, 0.5002],
         [0.4930, 0.5070],
         [0.5001, 0.4999],
         [0.4888, 0.5112],
         [0.5351, 0.4649],
         [0.5052, 0.4948],
         [0.5320, 0.4680],
         [0.5020, 0.4980],
         [0.4859, 0.5141],
         [0.5468, 0.4532],
         [0.4762, 0.5238],
         [0.4925, 0.5075],
         [0.4900, 0.5100],
         [0.4985, 0.5015],
         [0.5184, 0.4816],
         [0.5138, 0.4862],
 