In [None]:
# default_exp pretrained

# Pretrained

> fast.ai ULMFiT helpers to easily use pretrained models

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
import json
from fastai.text.all import SentencePieceTokenizer, language_model_learner, \
                            text_classifier_learner, untar_data, Path, patch, \
                            LMLearner, os, pickle, shutil, AWD_LSTM, accuracy, \
                            Perplexity, delegates

In [None]:
#export
def _get_config(path):
    with open(path/'model.json', 'r') as f:
        config = json.load(f)
    return config

In [None]:
#export
def _get_pretrained_model(url):
    "Resolve `url` to a local model directory via `untar_data` (config key 'model'); the archive is named after the last URL segment plus '.tgz'."
    archive_name = url.split('/')[-1] + '.tgz'
    return untar_data(url, fname=archive_name, c_key='model')

In [None]:
#export
def _get_direction(backwards):
    return 'bwd' if backwards else 'fwd'

In [None]:
#hide
for _flag, _expected in ((False, 'fwd'), (True, 'bwd')):
    assert _get_direction(backwards=_flag) == _expected, _flag

Get the `model` and `vocab` file names (absolute paths without extension) for the requested direction from a downloaded model `path`.

In [None]:
#export
def _get_model_files(path, backwards=False):
    """Return `[model_file, vocab_file]` for one direction of a downloaded pretrained model.

    Looks in `path/'fwd'` (or `path/'bwd'` when `backwards`) for a `*model.pth` and a
    `*vocab.pkl` file. Both are returned as absolute path strings *without* extension;
    this list is what gets passed as `pretrained_fnames` to `language_model_learner`.

    Raises FileNotFoundError if the direction's `model.json` is missing, and IndexError
    (after printing a hint) if the weights or vocab file is missing.
    """
    direction = _get_direction(backwards)
    model_path = path/direction
    # The config itself is not needed here; reading it checks that `model.json` exists.
    _get_config(model_path)
    fnames = []
    for pattern in ('*model.pth', '*vocab.pkl'):
        # sorted() makes the pick deterministic if several files match the pattern.
        matches = sorted(model_path.glob(pattern))
        if not matches:
            print(f'The model in {model_path} is incomplete, download again')
            raise IndexError(f'No file matching {pattern!r} in {model_path}')
        fnames.append(matches[0].absolute())
    return [str(f.parent/f.stem) for f in fnames]

## Tokenizer

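Create a tokenizer from the model config (`model.json`). The tokenizer parameters stored in `model.json` are passed to the tokenizer, together with any extra keyword arguments. Currently only `SentencePieceTokenizer` is supported; for any other tokenizer class the function returns `None`.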

In [None]:
#export
def tokenizer_from_pretrained(url, backwards=False, **kwargs):
    """Build the tokenizer described in the pretrained model's `model.json`.

    `url`: pretrained model URL (resolved locally via `_get_pretrained_model`).
    `backwards`: read the config of the 'bwd' instead of the 'fwd' model.
    `kwargs`: extra tokenizer arguments; they override same-named `params` from `model.json`.

    Returns a `SentencePieceTokenizer`, or `None` if the config names another tokenizer class.
    """
    path = _get_pretrained_model(url)
    config = _get_config(path/_get_direction(backwards))
    tok_config = config['tokenizer']
    if tok_config['class'] != 'SentencePieceTokenizer':
        return None  # only SentencePiece is supported for now
    # Merge rather than passing both via `**`: a key present in both would otherwise raise
    # "got multiple values for keyword argument". A missing/null `params` means no defaults.
    params = {**(tok_config.get('params') or {}), **kwargs}
    return SentencePieceTokenizer(**params)

## Language Model Learner

Create a `language_model_learner` from a pretrained model URL.

In [None]:
#export
@delegates(language_model_learner)
def language_model_from_pretrained(dls, url=None, backwards=False, metrics=None, **kwargs):
    """Create a `language_model_learner` initialised with the pretrained model at `url`.

    `url`: pretrained model URL (required).
    `backwards`: load the 'bwd' instead of the 'fwd' weights and vocab; presumably this
        should match the direction the `dls` were built with.
    `metrics`: defaults to `[accuracy, Perplexity()]`.
    Remaining `kwargs` are forwarded to `language_model_learner`.

    Raises ValueError if `url` is not given.
    """
    if url is None:
        raise ValueError('`url` of a pretrained model is required')
    arch = AWD_LSTM # TODO: Read from config
    path = _get_pretrained_model(url)
    # Select the weights/vocab for the requested direction.
    fnames = _get_model_files(path, backwards=backwards)
    if metrics is None:
        metrics = [accuracy, Perplexity()]
    return language_model_learner(dls,
                                  arch,
                                  pretrained=True,
                                  pretrained_fnames=fnames,
                                  metrics=metrics,
                                  **kwargs)

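Saving a trained or fine-tuned language model. The helper below resolves the target directory (defaulting to the learner's model directory) and creates it if it does not exist.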

In [None]:
#export
def _get_model_path(learn=None, path=None):
    path = (learn.path/learn.model_dir) if not path else Path(path)
    if not path.exists(): os.makedirs(path, exist_ok=True)
    return path

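Saves the following model files to `path` (default: the learner's model directory):
- Model (`lm_model.pth`)
- Encoder (`lm_encoder.pth`, only if `with_encoder=True`)
- Vocab from dataloaders (`lm_vocab.pkl`)
- SentencePiece model and vocab (`spm/spm.model`, `spm/spm.vocab`), copied only when `path` differs from the tokenizer's cache directory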

In [None]:
#export
@patch
def save_lm(x:LMLearner, path=None, with_encoder=True):
    """Save the LM weights, optional encoder, vocab and SentencePiece files to `path`.

    `path` defaults to the learner's model directory. `x.to_fp32()` is called first.
    The SPM files are copied into `path/'spm'` unless `path` is the tokenizer's cache dir.
    Returns the directory the files were written to.
    """
    out_dir = _get_model_path(x, path)
    x.to_fp32()
    x.save((out_dir/'lm_model').absolute(), with_opt=False)
    if with_encoder:
        x.save_encoder((out_dir/'lm_encoder').absolute())

    vocab_file = (out_dir/'lm_vocab.pkl').absolute()
    with open(vocab_file, 'wb') as fh:
        pickle.dump(x.dls.vocab, fh)

    spm_dir = Path(x.dls.tok.cache_dir)
    if out_dir.absolute() == spm_dir.absolute():
        # The SPM files already live in the output directory; nothing to copy.
        return out_dir
    spm_target = out_dir/'spm'
    if not spm_target.exists():
        os.makedirs(spm_target, exist_ok=True)
    for fname in ('spm.model', 'spm.vocab'):
        shutil.copyfile(spm_dir/fname, spm_target/fname)
    return out_dir

## Text Classifier

In [None]:
#def vocab_from_lm(learn=None, path=None):
#    path = _get_model_path(learn, path)
#    with open((path/'lm_vocab.pkl').absolute(), 'rb') as f:
#        return pickle.load(f)

In [None]:
#def spm_from_lm(learn=None, path=None):
#    path = _get_model_path(learn, path)

Create a `text_classifier_learner` whose encoder is loaded from a fine-tuned language model saved with `learn.save_lm()` (pass the returned directory as `path`).

In [None]:
#export
@delegates(text_classifier_learner)
def text_classifier_from_lm(dls, path=None, backwards=False, **kwargs):
    """Create a `text_classifier_learner` whose encoder is loaded from a saved language model.

    `path`: directory written by `learn.save_lm()`; must contain its `lm_encoder` weights.
    `backwards`: currently unused (kept for symmetry with `language_model_from_pretrained`).
    Remaining `kwargs` are forwarded to `text_classifier_learner`.

    Raises ValueError if `path` is not given (previously this failed with an AttributeError).
    """
    if not path:
        raise ValueError('`path` to a language model saved with `save_lm` is required')
    arch = AWD_LSTM # TODO: Read from config
    path = _get_model_path(path=path)
    learn = text_classifier_learner(dls, arch, pretrained=False, **kwargs)
    learn.load_encoder((path/'lm_encoder').absolute())
    return learn

# Tests - Tokenizer, LM and Classifier

In [None]:
#hide
#slow
# Assumes a local server hosting the `ulmfit-dewiki` model archive.
url = 'http://localhost:8080/ulmfit-dewiki'
tok = tokenizer_from_pretrained(url)
expected_vocab_sz = 15000
assert tok.vocab_sz == expected_vocab_sz

In [None]:
#hide
#slow
from fastai.text.all import AWD_LSTM, DataBlock, TextBlock, ColReader, RandomSplitter
import pandas as pd

# Direction of the language model under test (passed to both the DataBlock and the learner).
backwards = False

# Sample LM text data; must contain a 'text' column.
df = pd.read_csv(Path('_test/data_lm_sample.csv'))

# `tok` and `url` come from the tokenizer test cell above.
dblocks = DataBlock(blocks=(TextBlock.from_df('text', tok=tok, is_lm=True, backwards=backwards)),
                    get_x=ColReader('text'), 
                    splitter=RandomSplitter(valid_pct=0.1, seed=42))
dls = dblocks.dataloaders(df, bs=128)

# Smoke test: one epoch of fine-tuning starting from the pretrained weights.
learn = language_model_from_pretrained(dls, url=url, backwards=backwards)
learn.fit_one_cycle(1)

  return array(a, dtype, copy=False, order=order)


epoch,train_loss,valid_loss,accuracy,perplexity,time
0,6.504971,6.521173,0.169837,679.374512,00:01


In [None]:
#hide
#slow
# Save the fine-tuned LM (see `save_lm`) and keep its vocab for the classifier test below.
path = learn.save_lm()
vocab = learn.dls.vocab

In [None]:
#hide
#slow
from fastai.text.all import AWD_LSTM, DataBlock, TextBlock, ColReader, RandomSplitter, CategoryBlock
import pandas as pd

backwards = False

# Sample classification data; must contain 'text' and 'label' columns.
df = pd.read_csv(Path('_test/data_class_sample.csv'))

# Reuse the LM's `tok` and `vocab` from the cells above, presumably so token ids
# line up with the saved encoder's embeddings.
dblocks = DataBlock(blocks=(TextBlock.from_df('text', tok=tok, vocab=vocab, backwards=backwards), CategoryBlock),
                    get_x=ColReader('text'), 
                    get_y=ColReader('label'))
dls = dblocks.dataloaders(df, bs=128)

# `path` is the directory returned by `learn.save_lm()` in the previous cell.
learn = text_classifier_from_lm(dls, path=path, backwards=backwards)
learn.fit_one_cycle(1)
# Smoke-check that predictions can be produced.
learn.get_preds()

  return array(a, dtype, copy=False, order=order)


epoch,train_loss,valid_loss,time
0,0.737829,0.69308,00:02


(tensor([[0.5207, 0.4793],
         [0.4851, 0.5149],
         [0.5548, 0.4452],
         [0.5162, 0.4838],
         [0.5283, 0.4717],
         [0.5159, 0.4841],
         [0.5616, 0.4384],
         [0.5148, 0.4852],
         [0.5429, 0.4571],
         [0.5390, 0.4610],
         [0.5448, 0.4552],
         [0.4894, 0.5106],
         [0.4974, 0.5026],
         [0.4913, 0.5087],
         [0.4934, 0.5066],
         [0.5321, 0.4679],
         [0.5187, 0.4813],
         [0.5144, 0.4856],
         [0.4850, 0.5150],
         [0.5156, 0.4844],
         [0.4975, 0.5025],
         [0.5126, 0.4874],
         [0.4875, 0.5125],
         [0.4996, 0.5004],
         [0.4993, 0.5007],
         [0.5608, 0.4392],
         [0.5059, 0.4941],
         [0.5372, 0.4628],
         [0.5267, 0.4733],
         [0.5041, 0.4959],
         [0.5775, 0.4225],
         [0.4946, 0.5054],
         [0.4977, 0.5023],
         [0.5091, 0.4909],
         [0.5064, 0.4936],
         [0.5321, 0.4679],
         [0.5233, 0.4767],
 