# Steps
Generator:
    * Randomly cut the sentence
    * Generate text for the rest of the sentence
    
Discriminator:
    * Predict where the sentence was cut. Regression?

* Start with a pre trained network?
* Embedding weight tying
* Encoder weight tying?

Use MLM model to generate, predict like electra

In [1]:
from fastai2.text.all import *

In [2]:
%load_ext line_profiler

In [3]:
def decode(dls, idxs, decoder=decode_spec_tokens):
    vocab = dls.train_ds.numericalize.vocab
    tokens = [[vocab[i] for i in row if vocab[i] not in [BOS, PAD]] for row in idxs]
    sep = dls.train_ds.tokenizer[-1].sep
    return [sep.join(decoder(o)) for o in tokens]

In [4]:
#export
# TODO: Currently the hidden state is considering the generated words. How to instead predict on generated but still keep hidden of only original?
# it's important to keep the idxs that where cut
@patch
def generate(self:LMLearner, idxs, n_words=1, no_unk=False, temperature=1., min_p=None, no_bar=True, only_last_word=True, reset=False):
    "Return `idxs` and the `n_words` that come after"
    idxs_all = idxs
    model = learn.model.eval()
    if no_unk: unk_idx = self.dls.vocab.index(UNK)
    if reset: self.model.reset()
    for _ in (range(n_words) if no_bar else progress_bar(range(n_words), leave=False)):
        logits = model(idxs)[0][:, -1, :]
        probs = F.softmax(logits, dim=-1)
        if no_unk: probs[:, unk_idx] = 0.
        if min_p is not None:
            if not all((res >= min_p).float().sum(-1)):
                warn(f"There is no item with probability >= {min_p}, try a lower value.")
            else: probs[res < min_p] = 0.
        if temperature != 1.: probs.pow_(1 / temperature)
        samples = torch.multinomial(probs, 1).to(idxs.device)
        idxs = idxs_all = retain_type(torch.cat([idxs_all, samples], dim=-1), idxs)
        if only_last_word: idxs = idxs[:, -1:]
    return idxs_all

In [5]:
class GeneratorCB(Callback):
    def __init__(self, gen): self.gen = gen
        
    def begin_batch(self):
        inp,cut,cut_pct = self._random_cut(self.xb[0])
        with torch.no_grad(): pred = self.gen.generate(inp, n_words=inp.shape[-1]-cut)
        self.learn.xb = (pred,)
        self.learn.yb = (tensor(cut_pct).expand(inp.shape[0]).to(self.dls.device),)
        
    def _random_cut(self, idxs):
        seq_len = idxs.shape[1]
        cut = random.randint(int(seq_len*.3), int(seq_len*.9))
        return idxs.new(idxs[:, :cut]), cut, cut/seq_len

In [6]:
source = untar_data(URLs.IMDB)
fns = get_text_files(source)

In [7]:
tfms = [Tokenizer.from_folder(source), Numericalize()]
splits = RandomSplitter(.1)(fns)
# TODO: Don't need to be LMDataLoader
dset_lm = Datasets(fns, tfms=[tfms], splits=splits, dl_type=LMDataLoader)

In [8]:
dls_lm = dset_lm.dataloaders()

In [9]:
gen = language_model_learner(dls_lm, AWD_LSTM)

In [11]:
learn = text_classifier_learner(dls_lm, AWD_LSTM, n_out=1, y_range=(0,1), metrics=L1LossFlat(), loss_func=MSELossFlat(), cbs=GeneratorCB(gen))

In [None]:
moms=(0.8,0.7,0.8)

In [14]:
learn.fit_one_cycle(1, 2e-2, moms=moms)

epoch,train_loss,valid_loss,None,time
0,0.130774,0.005478,0.057658,13:53


In [None]:
learn.save('pre_freeze')

In [None]:
learn.unfreeze()
learn.fit_one_cycle(10, 2e-3, moms=(0.8,0.7,0.8))

In [None]:
learn.save('pre_unfreeze')

In [16]:
learn.save_encoder('enc')

## Classifier

In [23]:
dblock_clas = DataBlock(blocks=(TextBlock.from_folder(source, vocab=dset_lm.vocab),CategoryBlock),
                      get_x=read_tokenized_file,
                      get_y = parent_label,
                      get_items=partial(get_text_files, folders=['train', 'test']),
                      splitter=GrandparentSplitter(valid_name='test'))
dls_clas = dblock_clas.dataloaders(source, bs=128, seq_len=80)

In [None]:
learn = text_classifier_learner(dbunch_clas, AWD_LSTM, drop_mult=0.5, metrics=accuracy).to_fp16()
learn.load_encoder('fine_tuned_enc')

In [None]:
learn.lr_find()

In [None]:
def get_lrs(lr): return slice(lr/(2.6**4),lr)

In [None]:
learn.fit_one_cycle(1, 2e-2, moms=moms)

In [None]:
learn.freeze_to(-2)
learn.fit_one_cycle(1, get_lrs(1e-2), moms=moms)

In [None]:
learn.freeze_to(-3)
learn.fit_one_cycle(1, get_lrs(5e-3), moms=moms)

In [None]:
learn.unfreeze()
learn.fit_one_cycle(2, get_lrs(1e-3), moms=moms)