# Convert `fastai 0.7` Pretrained Weights to `fastai 1.0`

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import numpy as np
import dill as pickle
from collections import Counter
from pathlib import Path
from sklearn.model_selection import train_test_split

from fastai.text import *
from pythainlp.ulmfit import *


# Root directory of the language-model data; MODEL_PATH is its `models/` subdirectory.
# Both keep a trailing slash because later cells build file paths by string concatenation.
DATA_PATH='../lm_data/'
MODEL_PATH = f'{DATA_PATH}models/'

In [2]:
# Load the language-model data bunch stored under the name 'qrnn_db' in DATA_PATH.
# NOTE(review): bs=64 is presumably the batch size — confirm against the fastai version in use.
data_lm = TextLMDataBunch.load(DATA_PATH,'qrnn_db',bs=64)

In [34]:
# Build the language-model learner on `data_lm`.
# The hyper-parameters below are heuristics borrowed from imdb_scripts.
learn = language_model_learner(
    data_lm,
    bptt=70,
    emb_sz=300,
    nh=1150,
    nl=3,
    drop_mult=0.3,
    bias=False,
    alpha=2,
    beta=1,
)
# Training settings attached to the learner after construction.
learn.metrics = [accuracy]
learn.opt_func = partial(optim.Adam, betas=(0.8, 0.99))
learn.wd = 1e-7

In [35]:
# State dict of the freshly built learner; later cells overwrite its entries
# with the weights read from the old checkpoint.
new_wgts = learn.model.state_dict()
# Old checkpoint. The identity map_location returns each storage unchanged,
# so tensors are not moved to the device they were saved from.
wgts = torch.load(
    f'{MODEL_PATH}thwiki_model_v02.pth',
    map_location=lambda cpu_storage, _location: cpu_storage,
)

In [36]:
# Copy the old checkpoint's tensors into the new state dict under the new key names.
# Embeddings: the plain encoder and its dropout-wrapped twin.
new_wgts['0.encoder.weight'] = wgts['0.encoder.weight']
new_wgts['0.encoder_dp.emb.weight'] = wgts['0.encoder_with_dropout.embed.weight']
# Recurrent layers rnn0..rnn2 all follow one naming scheme, so map them in a loop.
for layer_idx in range(3):
    rnn_prefix = f'0.rnns.{layer_idx}'
    raw_hidden = wgts[f'{rnn_prefix}.module.weight_hh_l0_raw']
    # The raw hidden-to-hidden matrix fills both the wrapper-level `_raw` slot
    # and the wrapped module's own weight_hh_l0 slot.
    new_wgts[f'{rnn_prefix}.weight_hh_l0_raw'] = raw_hidden
    new_wgts[f'{rnn_prefix}.module.weight_ih_l0'] = wgts[f'{rnn_prefix}.module.weight_ih_l0']
    new_wgts[f'{rnn_prefix}.module.weight_hh_l0'] = raw_hidden
    for bias_name in ('bias_ih_l0', 'bias_hh_l0'):
        new_wgts[f'{rnn_prefix}.module.{bias_name}'] = wgts[f'{rnn_prefix}.module.{bias_name}']
# Decoder weight keeps the same key in both formats.
new_wgts['1.decoder.weight'] = wgts['1.decoder.weight']

In [37]:
# Write the converted state dict to MODEL_PATH, alongside the original checkpoint.
torch.save(new_wgts,f'{MODEL_PATH}thwiki_model_lstm.pth')

In [38]:
# Load the converted weights into the learner by name, as a check that the new state dict is accepted.
# NOTE(review): presumably fastai resolves this name to the .pth file saved in the previous cell — confirm.
learn.load('thwiki_model_lstm')

LanguageLearner(data=TextLMDataBunch;
Train: LabelList
y: LMLabel (2049894 items)
[0 0 0 0 ... 0 0 0 0]
Path: .
x: TextList (2049894 items)
[list([5, 2, 3, 2, 479, 289, 4]) list([5, 2, 3, 2, 4]) list([5, 2, 3, 2, 4])
 list([5, 2, 3, 2, 183, 516, 596, 2214, 834, 21, 1877, 183, 516, 4524, 2, 742, 64, 9257, 7830, 2, 9038, 103, 132, 2, 235, 4])
 ...
 list([5, 2, 3, 2, 6, 26, 51, 2, 1302, 2, 13, 2750, 27, 499, 53, 646, 2, 37994, 2, 62, 905, 2, 28, 13, 2620, 1381, 27, 26, 51, 2, 2151, 2, 7, 13, 2750, 2, 26710, 2, 62, 4])
 list([5, 2, 3, 2, 4]) list([5, 2, 3, 2, 4]) list([5, 2, 3, 2, 20, 2, 12, 2, 22, 35, 4])]
Path: ../lm_data;
Valid: LabelList
y: LMLabel (20706 items)
[0 0 0 0 ... 0 0 0 0]
Path: .
x: TextList (20706 items)
[list([5, 2, 3, 2, 134, 71, 6, 26, 51, 2, 22917, 2, 33, 874, 7, 13, 1005, 55, 258, 91, 77, 252, 2221, 2, 12380, 14, 1012, 250, 166, 135, 280, 2, 28, 38, 55, 7745, 2, 6, 26, 2, 51, 2, 10210, 2, 208, 7594, 144, 1555, 9171, 441, 2, 954, 324, 6870, 2, 6127, 6460, 38, 404, 31, 

In [None]:
#new weight merging function
def merge_wgts(em_sz, wgts, itos_pre, itos_cls):
    """Re-index pretrained embedding weights onto a new vocabulary.

    Each token of `itos_cls` that also occurs in `itos_pre` receives its
    pretrained embedding row; every other token receives the mean of all
    pretrained rows. The resulting matrix is written to the encoder, the
    embedding-dropout encoder and the decoder entries of `wgts`.

    Args:
        em_sz: embedding width; must equal the number of columns of
            ``wgts['0.encoder.weight']``.
        wgts: state dict holding a torch tensor under ``'0.encoder.weight'``.
            It is modified in place.
        itos_pre: index-to-token list of the pretrained vocabulary.
        itos_cls: index-to-token list of the target vocabulary.

    Returns:
        The same `wgts` object, with ``'0.encoder.weight'``,
        ``'0.encoder_dp.emb.weight'`` and ``'1.decoder.weight'`` replaced by
        independent float32 CPU tensors of shape ``(len(itos_cls), em_sz)``.
    """
    vocab_size = len(itos_cls)
    # Plain torch conversion, so the function needs no fastai helper in scope.
    enc_wgts = wgts['0.encoder.weight'].detach().cpu().numpy()
    #average weight of encoding, used for tokens missing from the pretrained vocab
    row_m = enc_wgts.mean(0)
    # Reverse lookup token -> pretrained row; for duplicated tokens the last index wins.
    stoi_pre = {tok: idx for idx, tok in enumerate(itos_pre)}
    #new embedding based on classification dataset
    new_w = np.zeros((vocab_size, em_sz), dtype=np.float32)
    for i, w in enumerate(itos_cls):
        r = stoi_pre.get(w, -1)
        #use pretrained embedding if present; else use the average
        new_w[i] = enc_wgts[r] if r >= 0 else row_m
    wgts['0.encoder.weight'] = torch.from_numpy(new_w)
    # Key must be 'emb' (not 'embed') to match the new-format name used when
    # the converted checkpoint is built in this notebook.
    wgts['0.encoder_dp.emb.weight'] = torch.from_numpy(np.copy(new_w))
    wgts['1.decoder.weight'] = torch.from_numpy(np.copy(new_w))
    return wgts

In [6]:
#replace tokens with new version
# Read the pickled vocabulary list, rename four special tokens by index, and
# write the list back to the same file. Re-running the cell gives the same result.
# NOTE: unpickling can execute arbitrary code; only load vocabulary files you trust.
with open(f'{MODEL_PATH}itos_lstm.pkl','rb') as itos_file:
    itos_lstm = pickle.load(itos_file)
# NOTE(review): these indices are presumably where the old-format special tokens
# sit in this particular vocabulary — confirm before reusing with another file.
itos_lstm[0] = 'xxunk'
itos_lstm[1] = 'xxpad'
itos_lstm[6] = 'xxfld'
itos_lstm[3661] = 'xxrep'
# The context manager closes the handle, so the data is flushed to disk
# before any later cell reads the file.
with open(f'{MODEL_PATH}itos_lstm.pkl','wb') as itos_file:
    pickle.dump(itos_lstm, itos_file)

In [7]:
# Display the updated vocabulary list to check the renamed special tokens.
itos_lstm

['xxunk',
 'xxpad',
 ' ',
 'ใน',
 'ที่',
 'และ',
 'xxfld',
 'เป็น',
 'ของ',
 'มี',
 'ได้',
 '"',
 'การ',
 '(',
 ')',
 'โดย',
 'กับ',
 'ส',
 'จะ',
 'ปี',
 'ว่า',
 'จาก',
 ',',
 'ให้',
 'ซึ่ง',
 'พ.ศ.',
 'ไป',
 '.',
 'เมื่อ',
 '-',
 'มา',
 'อ',
 'น',
 'พระ',
 'ก็',
 'หรือ',
 'อยู่',
 '์',
 'นี้',
 'คือ',
 'ร์',
 'แต่',
 'ใช้',
 'ค.ศ.',
 '/',
 'ยัง',
 'ด้วย',
 'เขา',
 'คน',
 'วันที่',
 'ถูก',
 'ไม่',
 'แห่ง',
 'เพื่อ',
 'ๆ',
 'จึง',
 'ทาง',
 'ขึ้น',
 '2',
 ':',
 'ได้รับ',
 'นั้น',
 '1',
 'สามารถ',
 'ผู้',
 'แล้ว',
 'เมือง',
 'ล',
 'ทำให้',
 '3',
 'กัน',
 'ประเทศ',
 'ถึง',
 'ทรง',
 'เธอ',
 'ชื่อ',
 'หนึ่ง',
 'เรื่อง',
 'เพลง',
 'ริ',
 'สร้าง',
 'ออก',
 'ตาม',
 'ต่อ',
 'เกิด',
 'ท',
 'ต้อง',
 'แบบ',
 'ไทย',
 'อีก',
 'ทั้ง',
 'ระหว่าง',
 'ต่อมา',
 'จังหวัด',
 'นา',
 'พระองค์',
 'ส่วน',
 'มิ',
 'ความ',
 'ทำ',
 'ปัจจุบัน',
 '4',
 'แรก',
 'เช่น',
 'ลง',
 'มาก',
 'สำหรับ',
 'กลุ่ม',
 'ด้าน',
 'ทีม',
 'เรียก',
 'พบ',
 'รา',
 'เอ',
 'ตัว',
 'อย่าง',
 'ใหม่',
 '%',
 'ช',
 'วัน',
 '5',
 'ณ',
 'เริ่ม