# Thai2fit Language Model Pre-training

The goal of this notebook is to train a language model using the [fast.ai](http://www.fast.ai/) version of [AWD LSTM Language Model](https://arxiv.org/abs/1708.02182), with data from [Thai Wikipedia Dump](https://dumps.wikimedia.org/thwiki/latest/thwiki-latest-pages-articles.xml.bz2) last updated February 17, 2019. Using a 40M/200k/200k-token train-validation-test split, we achieved a validation perplexity of **46.80959 with 60,002 embeddings at 400 dimensions**, compared to the state of the art as of October 27, 2018 of **42.41 for English WikiText-2 by [Yang et al. (2018)](https://arxiv.org/abs/1711.03953)**. To the best of our knowledge, there is no comparable research for the Thai language at the time of writing (February 17, 2019).

Our workflow is as follows:

* Retrieve and process [Thai Wikipedia Dump](https://dumps.wikimedia.org/thwiki/latest/thwiki-latest-pages-articles.xml.bz2) according to [n-waves/ulmfit-multilingual](https://github.com/n-waves/ulmfit-multilingual)
* Perform a 40M/200k/200k-token train-validation-test split
* Perform minimal text cleaning and tokenization using `newmm` with a frozen dictionary (`engine='ulmfit'`) of [pyThaiNLP](https://github.com/pyThaiNLP/pythainlp/)
* Train language model
* Evaluate model based on perplexity and eyeballing
* Extract embeddings to use as "word2vec"
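
For reference, the perplexity used in the evaluation step is the exponential of the mean per-token negative log-likelihood. The sketch below is only an illustration; the helper names `perplexity` and `mean_nll_from_probs` are hypothetical and are not part of the notebook or of fastai.

```python
import math

def perplexity(mean_nll: float) -> float:
    """Perplexity = exp(mean negative log-likelihood per token, in nats)."""
    return math.exp(mean_nll)

def mean_nll_from_probs(token_probs):
    """Average negative log-probability the model gave to each true token."""
    return -sum(math.log(p) for p in token_probs) / len(token_probs)

# Toy example: a model that assigns probability 1/50 to every true token
# is "as uncertain as" a uniform choice among 50 tokens.
probs = [1 / 50] * 8
print(perplexity(mean_nll_from_probs(probs)))
```

Read the other way, a validation perplexity of 46.80959 corresponds to a mean cross-entropy loss of `math.log(46.80959)` nats per token.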

## Imports

In [21]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai import *    
from fastai.text import * 
from utils import *

import dill as pickle

data_path = ''

## Text Cleaning

We follow the dataset creation, pre- and post-processing of [n-waves/ulmfit-multilingual](https://github.com/n-waves/ulmfit-multilingual):

* `ulmfit/create_wikitext.py` - Download thwiki in JSON format and split it into 40M/200k/200k tokens of train-validation-test data. Also perform tokenization, with whitespace as the token separator.
* `ulmfit/postprocess_wikitext.py` - Replace numbers, and replace out-of-vocabulary tokens (those with a frequency of less than 3) with `xxunk`.
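
The out-of-vocabulary replacement can be sketched as follows. This is a simplified illustration and not the code of `postprocess_wikitext.py`; the function names are hypothetical, and counting frequencies on the training tokens only is an assumption.

```python
from collections import Counter

UNK = "xxunk"  # unknown-token marker used in this notebook

def replace_rare(train_tokens, other_tokens, min_count=3):
    """Replace tokens seen fewer than `min_count` times in training with UNK."""
    counts = Counter(train_tokens)
    keep = {tok for tok, c in counts.items() if c >= min_count}

    def mapper(tokens):
        return [t if t in keep else UNK for t in tokens]

    return mapper(train_tokens), mapper(other_tokens)

def oov_ratio(tokens):
    """Share of tokens that ended up as UNK."""
    return sum(t == UNK for t in tokens) / len(tokens)

train = "a a a b b c".split()
valid = "a d".split()
new_train, new_valid = replace_rare(train, valid)
print(new_train, new_valid, oov_ratio(new_train))
```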

We replaced the Moses Tokenizer with the following code to use [pyThaiNLP](https://github.com/pyThaiNLP/pythainlp/)'s `newmm` dictionary-based tokenizer with a frozen dictionary instead.

In [8]:
from pythainlp.tokenize import word_tokenize
class ThaiNLPTokenizer:
    def __init__(self,engine='ulmfit'):
        self.engine=engine  #use the engine passed in instead of a hard-coded value
    def tokenize(self, t, return_str=True):
        res = word_tokenize(t,engine=self.engine)
        return ' '.join(res) if return_str else res

Here are the statistics of the dataset:

```
before postprocessing
# documents: 121,027. # tokens: 39,378,410

after postprocessing
OOV ratio: 0.0042
data/wiki/th-all vocab size: 111,224
th.wiki.train.tokens. # of tokens: 41,482,435
th.wiki.valid.tokens. # of tokens: 200,563
th.wiki.test.tokens. # of tokens: 200,827
```

## Data Preparation

We used the `newmm` engine of `pyThaiNLP` to perform tokenization. From the tokens in the training set, we kept 60,000 embeddings (plus two for the unknown and padding tokens), restricted to tokens that appear more than twice in the training set, which filters out most typos.


We perform the following text processing:

* Convert HTML markup to plain text
* Lowercase all English words; if a word is written in all caps, lowercase it and add `xxup` before it
* Repetitive characters: Thai often emphasizes adjectives by repeating the last character. `อร่อยมากกกกกกก` becomes `อร่อยมา ก xxrep 7 `: the repeated character is kept once, and the `xxrep` marker and the repeat count come after it rather than before.
* Normalize character order: for instance `นำ้` to `น้ำ`
* Add spaces around / and #
* Remove multiple spaces and newlines
* Remove empty brackets of all types (`([{`) which might result from cleaning up
* Tokenize the texts with `pyThaiNLP`'s `newmm` word tokenizer with a frozen dictionary (`engine='ulmfit'`).

### Thai Tokenizer

We use the `newmm` tokenizer with a dictionary frozen as of 2018-10-23.

In [25]:
text='วิทยาศาสตร์ดาวเคราะห์เป็นสาขาวิชาที่ศึกษาเกี่ยวกับองค์ประกอบของดาวเคราะห์'
a = word_tokenize(text,engine='ulmfit')
a

['วิทยาศาสตร์',
 'ดาวเคราะห์',
 'เป็น',
 'สาขาวิชา',
 'ที่',
 'ศึกษา',
 'เกี่ยวกับ',
 'องค์ประกอบ',
 'ของ',
 'ดาวเคราะห์']

In [26]:
#integrated into pythainlp.ulmfit.utils
from fastai.text.transform import *
from pythainlp.tokenize import word_tokenize
from pythainlp.util import normalize as normalize_char_order

class ThaiTokenizer(BaseTokenizer):
    "Wrapper around a newmm tokenizer to make it a `BaseTokenizer`."
    def __init__(self, lang:str = 'th'):
        self.lang = lang
    def tokenizer(self, t:str) -> List[str]:
        return(word_tokenize(t,engine='ulmfit'))
    def add_special_cases(self, toks:Collection[str]):
        pass
    
def replace_rep_after(t:str) -> str:
    "Replace repetitions at the character level in `t` after the repetition"
    def _replace_rep(m:Collection[str]) -> str:
        c,cc = m.groups()
        return f' {c} {TK_REP} {len(cc)+1} '
    re_rep = re.compile(r'(\S)(\1{3,})')
    return re_rep.sub(_replace_rep, t)

def rm_useless_newlines(t:str) -> str:
    "Remove multiple newlines in `t`."
    return re.sub('[\n]{2,}', ' ', t)

def rm_brackets(t:str) -> str:
    "Remove all empty brackets from `t`."
    #raw strings avoid invalid-escape warnings for the backslashes
    new_line = re.sub(r'\(\)','',t)
    new_line = re.sub(r'\{\}','',new_line)
    new_line = re.sub(r'\[\]','',new_line)
    return(new_line)

#in case we want to add more specific rules for thai
thai_rules = [fix_html, deal_caps, replace_rep_after, normalize_char_order, 
              spec_add_spaces, rm_useless_spaces, rm_useless_newlines, rm_brackets]
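
To see what the custom rules do without installing fastai or pyThaiNLP, here is a standalone sketch using only the standard library. It mirrors the regular expressions above; `TK_REP = "xxrep"` is assumed to match fastai's marker, and the three bracket patterns are merged into one expression for brevity.

```python
import re

TK_REP = "xxrep"  # assumed value of fastai's TK_REP marker

def replace_rep_after(t: str) -> str:
    """Collapse a character repeated 4+ times into ' c xxrep n '."""
    def _replace_rep(m):
        c, cc = m.groups()
        return f" {c} {TK_REP} {len(cc) + 1} "
    return re.sub(r"(\S)(\1{3,})", _replace_rep, t)

def rm_brackets(t: str) -> str:
    """Remove empty (), {} and [] pairs."""
    return re.sub(r"\(\)|\{\}|\[\]", "", t)

# Seven repetitions of the final character, built explicitly for clarity.
sample = "อร่อยมา" + "ก" * 7
print(repr(replace_rep_after(sample)))
print(repr(rm_brackets("a () b [] {}")))
```

Note that the pattern requires one character plus at least three repeats, so a run of only three identical characters is left untouched.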

### Data Bunch

We build the data bunch from the Thai Wikipedia data; the data bunch loaded below holds out roughly 1% for validation (20,706 validation items against 2,049,894 training items). Tokens are numericalized, keeping only tokens with a frequency greater than 2, up to a maximum vocabulary size of 60,000 (plus the unknown and padding tokens).

In [5]:
# tt = Tokenizer(tok_func = ThaiTokenizer, lang = 'th', rules = thai_rules)
# data_lm = TextLMDataBunch.from_csv(path = Path(DATA_PATH),csv_name='train.csv',valid_pct=0.01,
#                                   tokenizer=tt, min_freq=2, max_vocab=60000, bs=32)
# data_lm.save('databunch')

In [17]:
data = TextLMDataBunch.load(DATA_PATH,'qrnn_db',bs=32)

In [18]:
len(data.valid_ds), len(data.train_ds)

(20706, 2049894)

In [19]:
next(iter(data.valid_dl))[0]

tensor([[    5,  5941,  7611,  ...,     2,   732,    11],
        [    2,    42,  4357,  ...,    41,     8, 38718],
        [    3, 11114,  3112,  ...,    19,  1125,    11],
        ...,
        [ 1725,  6629, 48984,  ...,  8315,     2, 10855],
        [   38, 21976,    52,  ...,   647,    60,  2388],
        [  348,     2,     2,  ...,   143,    72,     0]])

### Vocab

In [20]:
vocab_lm = data.vocab
vocab_lm.numericalize(word_tokenize('สวัสดีครับพี่น้อง', engine='ulmfit'))

[10986, 7391, 1405]

In [21]:
vocab_lm.textify([10986, 7391, 1405])

'สวัสดี ครับ พี่น้อง'
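
The round trip between tokens and ids can be illustrated with a toy vocabulary. This is a simplified stand-in, not fastai's `Vocab` class: the function names, the special-token order and the `min_freq=3` default (tokens appearing more than twice) are assumptions made for the example.

```python
from collections import Counter

def build_itos(tokens, max_vocab=60000, min_freq=3, specials=("xxunk", "xxpad")):
    """Index-to-string list: special tokens first, then frequent tokens."""
    counts = Counter(tokens)
    frequent = [tok for tok, c in counts.most_common(max_vocab) if c >= min_freq]
    return list(specials) + frequent

def numericalize(tokens, stoi, unk_id=0):
    """Map tokens to ids, sending unseen tokens to the unknown id."""
    return [stoi.get(t, unk_id) for t in tokens]

def textify(ids, itos):
    """Map ids back to a space-separated string."""
    return " ".join(itos[i] for i in ids)

corpus = ["สวัสดี"] * 4 + ["ครับ"] * 3 + ["rare"]
itos = build_itos(corpus)
stoi = {tok: i for i, tok in enumerate(itos)}
ids = numericalize(["สวัสดี", "ครับ", "rare"], stoi)
print(ids, textify(ids, itos))
```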

## Language Modeling

We train the language model following the [ULMFiT paper](https://arxiv.org/abs/1801.06146), but replace the LSTM modules with [QRNN modules](https://arxiv.org/abs/1611.01576).
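
In a QRNN the candidate values and gates are computed by convolutions over the whole sequence, and only a cheap elementwise pooling step runs sequentially. The sketch below shows the `fo`-pooling recurrence from the QRNN paper in plain Python; it assumes the gate sequences `z`, `f` and `o` are already computed, and the function name `fo_pool` is hypothetical rather than taken from fastai.

```python
def fo_pool(z, f, o, c0=None):
    """fo-pooling: c_t = f_t * c_{t-1} + (1 - f_t) * z_t ; h_t = o_t * c_t.

    z, f, o are lists over time steps, each a list over hidden units.
    Returns the hidden states for every step and the final cell state.
    """
    hidden = len(z[0])
    c = list(c0) if c0 is not None else [0.0] * hidden
    hs = []
    for z_t, f_t, o_t in zip(z, f, o):
        # Forget gate blends the previous cell state with the new candidate.
        c = [f_i * c_i + (1.0 - f_i) * z_i for z_i, f_i, c_i in zip(z_t, f_t, c)]
        # Output gate scales the cell state to give the hidden state.
        hs.append([o_i * c_i for o_i, c_i in zip(o_t, c)])
    return hs, c

hs, c_last = fo_pool(z=[[1.0], [0.0]], f=[[0.5], [0.5]], o=[[1.0], [1.0]])
print(hs, c_last)
```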

In [18]:
# # select weight decays and lrs
# wds = [1e-8,1e-7,1e-6]
# plts = []
# for wd in wds:
#     learn = language_model_learner(data, bptt = 70, emb_sz = 300, nh = 1550, nl = 3,
#                                   drop_mult = 0.1, bias = True, qrnn = True, tie_weights=True,
#                                   pretrained_fnames = None)
#     learn.wd = wd
#     learn.opt_func = partial(optim.Adam, betas=(0.8, 0.99)) #heuristic reference from imdb_scripts
#     learn.lr_find(start_lr = 1e-3, end_lr = 1e1)
#     plts.append(learn.recorder.plot())

### Training with 20% Validation

In [19]:
#heuristic reference from imdb_scripts
learn = language_model_learner(data, bptt = 70, emb_sz = 300, nh = 1550, nl = 3,
                                  drop_mult = 0.1, bias = True, qrnn = True, 
                                  alpha=2, beta = 1, clip = 0.12, tie_weights=True,
                                  pretrained_fnames = None)
learn.metrics = [accuracy]
learn.opt_func = partial(optim.Adam, betas=(0.8, 0.99))

In [20]:
lr = 0.005
wd = 1e-7
learn.wd=wd

In [21]:
# learn.fit_one_cycle(cyc_len = 20, 
#                     max_lr= lr, #learning rate
#                     div_factor=20, #factor to discount from max
#                     moms = (0.8, 0.7), #momentums
#                     pct_start = 0.1, #where the peak is at 
#                     wd = wd #weight decay
#                    ) 
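
The `max_lr`, `div_factor` and `pct_start` arguments describe a one-cycle learning-rate schedule: start at `max_lr / div_factor`, rise to `max_lr` during the first `pct_start` share of iterations, then anneal down. The sketch below is a simplified illustration with the values from the cell above; the cosine shape and the final value (`final_div`) are assumptions, not taken from the notebook, and the function names are hypothetical.

```python
import math

def cos_anneal(start, end, pct):
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2 * (math.cos(math.pi * pct) + 1)

def one_cycle_lr(step, total_steps, max_lr=0.005, div_factor=20,
                 pct_start=0.1, final_div=1e4):
    """Learning rate at `step` (0-based) of a one-cycle schedule."""
    warm = int(total_steps * pct_start)
    low = max_lr / div_factor
    if step < warm:
        # Warm-up phase: low -> max_lr
        return cos_anneal(low, max_lr, step / warm)
    # Annealing phase: max_lr -> a small fraction of the starting rate
    rest = max(1, total_steps - warm - 1)
    return cos_anneal(max_lr, low / final_div, (step - warm) / rest)

schedule = [one_cycle_lr(s, 100) for s in range(100)]
print(schedule[0], max(schedule), schedule[-1])
```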

###  Training with 20% Validation and Less Regularization

In [22]:
#we tried training a little longer with a lower learning rate, less dropout and less weight decay, but it did not really help
learn = language_model_learner(data, bptt = 70, emb_sz = 400, nh = 1550, nl = 3,
                                  drop_mult = 0., bias = True, qrnn = True, 
                                  alpha=2, beta = 1, clip = 0.12, tie_weights=True,
                                  pretrained_fnames = None)
learn.metrics = [accuracy]
learn.opt_func = partial(optim.Adam, betas=(0.8, 0.99))

In [23]:
lr = 0.001
wd = 1e-8
learn.wd=wd

In [24]:
# learn.fit_one_cycle(cyc_len = 10, 
#                     max_lr= lr, #learning rate
#                     div_factor=20, #factor to discount from max
#                     moms = (0.8, 0.7), #momentums
#                     pct_start = 0.1, #where the peak is at 
#                     wd = wd #weight decay
#                    ) 

### Training with 1% Validation

### Eyeballing Test

We perform an eyeballing test by having the model "fill in the blanks".

In [22]:
learn = language_model_learner(data, bptt = 70, emb_sz = 400, nh = 1550, nl = 3,
                                  drop_mult = 0., bias = True, qrnn = True, 
                                  alpha=2, beta = 1, clip = 0.12,
                                  pretrained_fnames = None)
learn.load('thwiki_model_qrnn')
m = learn.model

In [23]:
import torch
from torch.autograd import Variable
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def gen_next(ss,topk):
    s = word_tokenize(ss,engine='ulmfit')
    t = torch.LongTensor(data.train_ds.vocab.numericalize(s)).view(-1,1).to(device)
    t.requires_grad = False
    m.reset()
    pred,*_ = m(t)
    pred_i = pred[-1].topk(topk)[1]
    return(data.train_ds.vocab.textify(pred_i))

def gen_sentences(ss,nb_words):
    result = []
    s = word_tokenize(ss,engine='ulmfit')
    t = torch.LongTensor(data.train_ds.vocab.numericalize(s)).view(-1,1).to(device)
    t.requires_grad = False
    m.reset()
    pred,*_ = m(t)
    for i in range(nb_words):
        pred_i = pred[-1].topk(2)[1]
        #take the top prediction unless it is index 0 (xxunk); then take the runner-up
        pred_i = pred_i[1] if pred_i.data[0] == 0 else pred_i[0]
        pred_i = pred_i.view(-1,1)
        result.append(data.train_ds.vocab.textify(pred_i))
        t = torch.cat((t,pred_i))
        pred,*_ = m(t)
    return(result)

In [24]:
ss = 'ฉันเดินไป'
gen_next(ss,10)

'ข้างนอก ข้างหลัง ทุกที่ เบื้องล่าง เบื้องหน้า ถึงกัน สะดุด ตามรอย ที่ใด ข้างบน'

In [25]:
ss = 'ฉันเดินไป'
''.join(gen_sentences(ss,100))

'ข้างนอกข้างนอก ฉันจะเดินต่อไป และเดินต่อไป เราจะนั่งพักผ่อน เราจะเดินถอยหลัง เราจะเราเราจะนั่งพักผ่อนฉัน เราจะนั่งบนเก้าอี้ เราเราจะนั่งเดินบนเก้าอี้ เราจะนั่งตรงหน้าฉันนั่งบนเก้าอี้ เราจะนั่งเดินบนรันเวย์ เราจะนั่งเดินบนเก้าอี้ เราจะนั่งอาสน์ เราจะนั่งเดินบนเก้าอี้ เราจะเดินต่อไป เราจะนั่งบนเก้าอี้ เราจะนั่งเดินบนรันเวย์ เราจะ'
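
The selection step inside `gen_sentences` is greedy: at every position it takes the highest-scoring token, falling back to the runner-up when the best one is the unknown token. Greedy decoding of this kind tends to fall into loops, which is consistent with the repetitive output above. A standalone sketch of that selection rule, with hypothetical function names and plain Python lists in place of tensors:

```python
def top_k(scores, k):
    """Indices of the k largest scores, best first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]

def pick_next(scores, unk_id=0):
    """Greedy choice that skips the unknown token if it ranks first."""
    best, runner_up = top_k(scores, 2)
    return runner_up if best == unk_id else best

print(top_k([0.1, 0.7, 0.2], 2))
print(pick_next([0.9, 0.05, 0.3]))  # index 0 ranks first, so it is skipped
print(pick_next([0.1, 0.7, 0.2]))
```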

## Embeddings

We extract the embedding layer of the encoder to be used in the same manner as `word2vec`. We can also create a sentence vector by summing or averaging the word vectors. For more details about `word2vec` use cases, see `word2vec_examples.ipynb`.
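
Summing or averaging word vectors into a sentence vector can be sketched in plain Python. The function name `sentence_vector` and the toy two-dimensional vectors are illustrative assumptions; skipping tokens without a vector is one possible design choice, not something the notebook specifies.

```python
def sentence_vector(tokens, vectors, how="mean"):
    """Combine the word vectors of `tokens` by summing or averaging."""
    rows = [vectors[t] for t in tokens if t in vectors]  # skip unknown tokens
    if not rows:
        raise ValueError("no token has a vector")
    summed = [sum(col) for col in zip(*rows)]
    if how == "sum":
        return summed
    return [s / len(rows) for s in summed]

toy_vectors = {"วันนี้": [1.0, 2.0], "วันดี": [3.0, 4.0]}
tokens = ["วันนี้", "วันดี", "missing"]
print(sentence_vector(tokens, toy_vectors))
print(sentence_vector(tokens, toy_vectors, how="sum"))
```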

### Extract

In [97]:
emb_weights = list(learn.model.named_parameters())[0][1]
emb_np = to_np(emb_weights.data)

In [99]:
thai2vec = pd.DataFrame(emb_np)
new_itos = data.vocab.itos
#replace space with xxspace
new_itos[2] = 'xxspace'
#replace space for named entities with _
new_itos = [re.sub(' ','_',i) for i in new_itos]
#replace \n with xxeol
new_itos[4] = 'xxeol'
thai2vec.index = new_itos
thai2vec.head(10)

token,0,1,2,3,4,5,6,7,8,9,...,390,391,392,393,394,395,396,397,398,399
xxunk,9.412697,-0.488594,-3.401881,-42.730244,35.064205,-26.100941,-3.747117,15.602914,37.966507,1.014693,...,-6.495045,-25.819569,38.371155,6.613634,1.742496,-9.713803,-12.573798,1.576579,-20.847286,-1.370693
xxpad,3.876736,-0.049269,-0.021756,-33.481445,31.61129,-20.91984,-0.508237,14.074152,32.848042,-1.403501,...,1.038427,-19.635546,29.998386,7.486456,0.269155,-11.942196,-4.076655,-0.319159,-13.353608,-0.664269
xxspace,10.631902,-3.973876,-1.743042,-42.599854,34.616451,-26.990564,10.954692,14.398766,37.990967,-3.763945,...,-1.278642,-25.836325,40.264915,7.057155,1.878293,-9.766123,-11.934086,0.281595,-21.392384,-2.502031
1,9.687781,-3.374289,-5.566193,-41.561626,36.881172,-28.293665,-1.500996,15.841971,39.105003,-0.694619,...,0.800407,-26.646166,42.336922,7.244391,1.356798,-10.226693,-12.59201,-1.624527,-21.064072,-9.900548
xxeol,10.076826,-3.298112,-3.062871,-42.180008,34.165661,-27.856457,6.696773,15.293684,38.029015,-4.269083,...,-0.678904,-25.555805,40.705482,7.259528,1.35303,-14.982669,-13.479355,-0.881967,-20.467594,-5.054183
xxfld,14.89963,-1.930843,2.001746,-42.514164,38.589172,-29.482353,-3.962751,3.097306,39.768669,-0.693229,...,3.201032,-22.658134,44.450848,6.989926,-1.499702,-7.79194,-4.544464,0.243512,-23.96633,-6.069082
ใน,9.683018,-1.806045,-3.538491,-42.151825,34.464622,-28.008196,-2.961989,14.223104,37.152637,1.24337,...,0.13816,-26.05076,38.420036,7.372793,2.484624,-11.34588,-17.19705,-0.72984,-20.466942,-0.287305
ที่,9.770124,-4.167158,-0.65167,-42.355728,36.28714,-26.348225,-2.798183,15.309365,37.787487,-1.505493,...,-0.002396,-26.294212,39.143227,7.409772,1.725253,-10.572651,-18.086336,-0.444782,-21.733566,-4.995567
และ,9.779366,-4.125625,-5.410614,-42.443588,35.46442,-26.667896,2.241813,13.938269,37.312401,-4.813928,...,-0.503761,-25.952402,38.727116,7.179166,2.370239,-8.047087,-15.263973,2.641281,-21.697216,-2.367746
เป็น,8.904553,-3.972028,0.898944,-41.665886,35.541092,-28.023504,-3.800491,16.07267,38.58075,-2.286343,...,2.570196,-25.606272,39.234039,7.785551,1.330086,-12.832536,-14.681183,-1.536232,-21.056726,-2.545995


In [115]:
thai2vec.tail(10)

token,0,1,2,3,4,5,6,7,8,9,...,390,391,392,393,394,395,396,397,398,399
ː],4.526951,-0.684044,0.658317,-34.973621,32.692394,-21.356636,-0.297866,13.243143,33.953781,-2.175802,...,0.913986,-20.240484,31.480227,6.308282,0.222549,-12.421679,-5.094245,-0.876305,-16.382822,-1.94636
femina,5.643933,-0.454897,-0.781048,-34.636936,32.502296,-23.502565,0.150455,13.833745,35.022408,-0.543715,...,1.087782,-22.103447,33.913628,7.328459,0.590396,-11.619066,-6.42272,-0.582779,-15.718392,-1.16842
แมชอัป,4.505654,0.813675,0.5412,-36.005527,33.578327,-22.218058,-0.812224,14.030306,33.559189,-1.831509,...,2.220871,-20.00276,31.140167,7.040421,0.533256,-12.791709,-2.348727,-1.542387,-12.926419,-0.873332
ko-d,4.165557,0.011141,-0.987677,-35.324249,32.02523,-20.510134,-0.564647,14.144468,33.327763,-0.892376,...,1.502303,-20.634552,31.940294,7.187971,-0.247172,-11.427867,-5.017011,-0.749183,-15.432819,-0.678519
­่,5.919821,-0.623172,-1.006667,-37.552708,34.542671,-22.609804,-0.612308,13.414227,36.015945,-0.484734,...,0.972129,-21.235813,32.55191,7.112097,1.700822,-14.420567,-7.045519,-0.921345,-16.223467,-0.626158
มูโญซ,3.681906,-0.218666,0.948459,-33.641006,31.695154,-22.561327,0.889483,13.91015,34.946846,-1.855071,...,1.924642,-21.629454,29.379236,7.132369,0.257729,-12.290989,-5.269632,-1.54043,-15.146802,-0.754419
สัมภว,5.752477,0.315618,0.293014,-36.698799,34.379864,-23.329159,-0.070712,13.357928,35.762405,-1.04763,...,0.906964,-21.077778,33.069584,7.098242,1.803606,-11.642494,-5.033253,0.524389,-16.31374,-1.029239
เซ็ตเบ,4.389275,0.097917,-1.411201,-36.771015,34.442936,-23.039188,1.612276,14.783897,35.015724,-1.216834,...,0.590751,-19.663397,30.306244,6.418957,-0.314835,-10.774155,-4.693771,-0.30036,-12.498419,-1.852149
dacc,4.742739,-0.908443,0.547335,-35.256859,31.603384,-24.493853,-0.259395,13.682415,35.604534,-1.805019,...,1.418943,-22.847486,32.852303,7.006009,-0.03993,-10.81945,-6.981693,-1.151021,-14.195145,-1.810433
monsta,5.33329,0.036935,0.209927,-38.188976,33.112858,-21.396271,-0.456774,13.161014,34.834602,-1.270794,...,0.637157,-20.428938,33.175938,7.09113,1.161157,-11.597325,-5.889217,-0.936701,-16.72142,-1.459496


In [117]:
thai2save = thai2vec 
thai2save.to_csv(f'{MODEL_PATH}thai2vec.vec',sep=' ',header=False, line_terminator='\n')
#the word2vec text format needs a first line 'NB_ROWS NB_COLS' (here '60002 400'); add it to the file before loading with gensim
thai2save.shape

(60002, 400)
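
The word2vec text format starts with a header line holding the number of rows and the number of dimensions, followed by one `token v1 v2 ...` line per word. Instead of adding the header by hand, the file can be written in one pass. The sketch below is an illustration with a hypothetical helper name and toy data, writing to an in-memory buffer rather than to `MODEL_PATH`.

```python
import io

def write_word2vec_text(handle, words, vectors):
    """Write vectors in word2vec text format: 'n_rows n_cols' header, then rows."""
    handle.write(f"{len(words)} {len(vectors[0])}\n")
    for word, vec in zip(words, vectors):
        handle.write(word + " " + " ".join(repr(float(x)) for x in vec) + "\n")

buf = io.StringIO()
write_word2vec_text(buf, ["xxunk", "ใน"], [[1.0, 2.0], [3.0, 4.5]])
lines = buf.getvalue().splitlines()
print(lines)
```

Tokens must not contain spaces in this format, which is why the cell above replaces spaces with `_`.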

In [119]:
from gensim.models import KeyedVectors
model = KeyedVectors.load_word2vec_format(f'{MODEL_PATH}thai2vec.vec',binary=False,
                                         unicode_errors = 'ignore')

In [133]:
# model.save_word2vec_format(f'{MODEL_PATH}thai2vec.vec',f'{MODEL_PATH}thai2vec.vocab',False)
# model.save_word2vec_format(f'{MODEL_PATH}thai2vec.bin',None,True)

## Document Vectors

We can also get a document vector from the language model by applying the encoder to a sentence.

In [121]:
tt = ThaiTokenizer()
def document_vector(ss, learn, data):
    s = tt.tokenizer(ss)
    t = torch.tensor(data.vocab.numericalize(s), requires_grad=False)[:,None].to(device)
    m = learn.model[0]
    m.reset()
    pred,_ = m(t)
    res = pred[-1][-1,:,:].squeeze().detach().cpu().numpy()  #move to CPU first so this also works on CUDA
    return(res)

In [122]:
ss = 'วันนี้วันดีปีใหม่'
document_vector(ss,learn,data).shape

(400,)

In [125]:
from pythainlp.ulmfit import *

In [127]:
document_vector(ss,learn,data)

array([ 0.083069,  0.014881,  0.010677,  0.012863, ..., -0.246675, -0.015147,  0.153276, -0.004983], dtype=float32)