In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
#export
from nb_007a import *

# IMDB

## Fine-tuning the LM

The data was prepared as CSV files at the beginning of notebook 007a; we will use it now.

### Loading the data

In [None]:
PATH = Path('data/aclImdb/')
# Working sub-folders: classifier data, language-model data, and saved model weights.
CLAS_PATH, LM_PATH, MODEL_PATH = PATH/'clas', PATH/'lm', PATH/'models'
# Create each folder (and any missing parents); a no-op when it already exists.
CLAS_PATH.mkdir(parents=True, exist_ok=True)
LM_PATH.mkdir(parents=True, exist_ok=True)
MODEL_PATH.mkdir(parents=True, exist_ok=True)

In [None]:
# Tokenizer with the notebook's rules and the special tokens passed as special cases.
tokenizer = Tokenizer(
    rules=rules,
    special_cases=[BOS, FLD, UNK, PAD],
)
# Read and tokenize the language-model CSVs stored in LM_PATH (prepared in 007a).
train_ds, valid_ds = TextDataset.from_csv(LM_PATH, tokenizer)

In [None]:
# Batch size, and bptt (presumably the backprop-through-time sequence length).
bs, bptt = 100, 70
# For each split, the per-text id arrays are concatenated into one stream before batching.
train_dl, valid_dl = (LanguageModelLoader(np.concatenate(ds.ids), bs, bptt)
                      for ds in (train_ds, valid_ds))

In [None]:
# Bundle the training and validation loaders into the object handed to the Learner below.
data = DataBunch(train_dl, valid_dl)

### Adapt the pre-trained weights to the new vocabulary

Download the pretrained model (`lstm.pth`) and the corresponding `itos` list (`itos.pkl`, the index-to-token vocabulary) here and put them in the `MODEL_PATH` folder.

In [None]:
def replace(itos, tok1, tok2):
    "Overwrite the first occurrence of `tok1` in `itos` with `tok2`, in place, and return `itos` (ValueError if `tok1` is absent)."
    pos = itos.index(tok1)
    itos[pos] = tok2
    return itos

def apply_new_flags():
    "Temporary function to change the old special tokens by the new ones, overwriting `MODEL_PATH/'itos.pkl'`."
    fname = MODEL_PATH/'itos.pkl'
    # NOTE: unpickling can execute arbitrary code -- only run this on a trusted itos.pkl.
    with open(fname, 'rb') as f:
        itos_wt = pickle.load(f)
    olds = ['_unk_', '_pad_', 'xbos', 'xfld', 'u_n', 't_up', 'tk_rep', 'tk_wrep']
    news = [UNK, PAD, BOS, FLD, UNK, TOK_UP, TK_REP, TK_WREP]
    # NOTE(review): both '_unk_' and 'u_n' become UNK, so UNK ends up at two indices -- confirm intended.
    # replace() raises ValueError for any old token that is absent (e.g. the file was already
    # converted); that happens before the file is rewritten, so a failed run leaves it untouched.
    for tok1,tok2 in zip(olds, news):
        itos_wt = replace(itos_wt, tok1, tok2)
    # The context manager guarantees the pickle is flushed and the handle closed.
    with open(fname, 'wb') as f:
        pickle.dump(itos_wt, f)

In [None]:
# Uncomment and run once, only if itos.pkl still contains the old special tokens: apply_new_flags()
# overwrites the file, and replace() raises ValueError for any old token that is no longer present.
#apply_new_flags()

In [None]:
# Pretrained vocabulary: itos_wt[i] is the token of row i in the pretrained weights; stoi_wt inverts it.
# NOTE: pickle.load can execute arbitrary code -- only load an itos.pkl from a trusted source.
with open(MODEL_PATH/'itos.pkl', 'rb') as f:
    itos_wt = pickle.load(f)
# If a token appears at several indices, the last index wins.
stoi_wt = {v:k for k,v in enumerate(itos_wt)}

In [None]:
def convert_weights(wgts, stoi_wgts, itos_new):
    """Remap the vocab-dependent weights in `wgts` from the old vocab (`stoi_wgts`: token -> old row)
    to the new vocab `itos_new` (row i -> token).

    Tokens present in the old vocab keep their old embedding row and decoder bias; unknown tokens
    get the mean embedding row and the mean bias. `wgts` is modified in place (its entries
    '0.encoder.weight', '0.encoder_dp.emb.weight', '1.decoder.weight', '1.decoder.bias' are
    replaced) and also returned.
    """
    dec_bias, enc_wgts = wgts['1.decoder.bias'], wgts['0.encoder.weight']
    bias_m, wgts_m = dec_bias.mean(0), enc_wgts.mean(0)
    n_new = len(itos_new)
    # Old row of every new token, or -1 when the old vocab does not contain it.
    # (.get never triggers a defaultdict factory, so a defaultdict stoi is left unchanged.)
    idx = torch.tensor([stoi_wgts.get(w, -1) for w in itos_new],
                       dtype=torch.long, device=enc_wgts.device)
    found = idx >= 0
    old_rows = idx[found]
    # Start every row from the mean, then copy over the rows the old vocab knows (vectorized).
    new_w = enc_wgts.new_zeros((n_new, enc_wgts.size(1)))
    new_w[:] = wgts_m
    new_w[found] = enc_wgts[old_rows]
    new_b = dec_bias.new_zeros((n_new,))
    new_b[:] = bias_m
    new_b[found] = dec_bias[old_rows]
    wgts['0.encoder.weight'] = new_w
    # Independent copies so the three weight entries do not share storage.
    wgts['0.encoder_dp.emb.weight'] = new_w.clone()
    wgts['1.decoder.weight'] = new_w.clone()
    wgts['1.decoder.bias'] = new_b
    return wgts

In [None]:
# Load the pretrained LM state dict with every tensor mapped onto the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
wgts = torch.load(MODEL_PATH/'lstm.pth', map_location='cpu')

In [None]:
# First 10 decoder biases of the pretrained model, before remapping to the new vocabulary.
wgts['1.decoder.bias'][0:10]

In [None]:
# First 10 tokens of the pretrained vocabulary (they index the rows previewed above).
itos_wt[0:10]

In [None]:
# Remap the pretrained weights to this dataset's vocabulary. convert_weights modifies `wgts` in place,
# so re-run the torch.load cell before re-running this one; otherwise already-remapped rows would be
# indexed with the old vocabulary's positions.
wgts = convert_weights(wgts, stoi_wt, train_ds.vocab.itos)

In [None]:
# Same biases after remapping: entry i now belongs to train_ds.vocab.itos[i].
wgts['1.decoder.bias'][0:10]

In [None]:
# First 10 tokens of the new vocabulary, aligned with the remapped biases above.
train_ds.vocab.itos[0:10]

## Define the model

In [None]:
# The vocab size must equal the vocabulary the weights were remapped to (train_ds.vocab.itos, as passed
# to convert_weights), or load_state_dict fails on a size mismatch. The former `text_data` is not
# defined in this notebook, so the old line raised NameError on a fresh kernel.
vocab_size = len(train_ds.vocab.itos)
emb_sz,nh,nl = 400,1150,3  # embedding size, hidden size, number of RNN layers (presumably)
# Dropout probabilities, in the order they are passed to get_language_model below
# (input_p, output_p, weight_p, embed_p, hidden_p), all scaled down by 0.7.
dps = np.array([0.25, 0.1, 0.2, 0.02, 0.15])*0.7

In [None]:
# Build the LM with the same architecture as the pretrained one, then load the remapped weights.
# NOTE(review): the positional 0 is presumably the padding-token index -- confirm against get_language_model.
model = get_language_model(
    vocab_size, emb_sz, nh, nl, 0,
    input_p=dps[0],
    output_p=dps[1],
    weight_p=dps[2],
    embed_p=dps[3],
    hidden_p=dps[4],
)
model.load_state_dict(wgts)

Split the model into layer groups for discriminative learning rates and gradual unfreezing.

In [None]:
# One group per RNN layer (paired with its entry in hidden_dps), followed by a final group holding
# the embedding, its dropout wrapper (encoder_dp) and model[1] (the decoder).
groups = ([nn.Sequential(layer, drop) for layer, drop in zip(model[0].rnns, model[0].hidden_dps)]
          + [nn.Sequential(model[0].encoder, model[0].encoder_dp, model[1])])

In [None]:
learn = Learner(data, model)
# Use the hand-built layer groups so that freeze()/unfreeze() and per-group lrs act on them.
learn.layer_groups = groups
# NOTE(review): alpha/beta are presumably the activation / temporal-activation regularization
# coefficients of RNNTrainer -- confirm in its definition.
learn.callbacks.append(RNNTrainer(learn, bptt, alpha=2, beta=1))
learn.metrics = [accuracy]
# Presumably leaves only the last group (embedding + decoder) trainable -- confirm in Learner.freeze.
learn.freeze()

In [None]:
# Learning-rate range test on the frozen model; the next cell plots the result from learn.recorder.
lr_find(learn)

In [None]:
# Plot what the recorder captured during lr_find, to choose the max lr used below.
learn.recorder.plot()

In [None]:
# Train only what freeze() left trainable, for 1 cycle with max lr 1e-2 and weight decay 1e-7;
# moms=(0.8,0.7) is presumably the momentum range of the one-cycle schedule.
learn.fit_one_cycle(1, 1e-2, moms=(0.8,0.7), wd=1e-7)

In [None]:
# Make all layer groups trainable for the next stage, and checkpoint the current weights as 'fit_head'.
learn.unfreeze()
learn.save('fit_head')

In [None]:
# Restart from the 'fit_head' checkpoint, so re-running this cell starts from the same weights, then
# fine-tune the unfrozen model with a 10-epoch cycle (presumably epochs) at a lower max lr (1e-3).
learn.load('fit_head')
learn.fit_one_cycle(10, 1e-3, moms=(0.8,0.7), wd=1e-7)