In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.text import *
import numpy as np
import pickle
import sentencepiece as spm
from tqdm import tqdm

In [3]:
import fastai, torch
fastai.__version__ , torch.__version__

('1.0.57', '1.0.0')

In [4]:
!nvidia-smi

Sat Aug  8 13:16:51 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 430.64       Driver Version: 430.64       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla V100-PCIE...  Off  | 00000001:00:00.0 Off |                    0 |
| N/A   23C    P0    25W / 250W |     11MiB / 16160MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage    

In [5]:
torch.cuda.set_device(0)

In [6]:
!pwd

/data/home/ubuntu/gaurav/in/fire/code-mixed-enta/language_model


In [7]:
path = Path('./')

In [11]:
def handle_all_caps(t: str) -> str:
    tokens = t.split()
    tokens = replace_all_caps(tokens)
    return ' '.join(tokens)

def handle_upper_case_first_letter(t: str) -> str:
    tokens = t.split()
    tokens = deal_caps(tokens)
    return ' '.join(tokens)

def lower_case_everything(t: str) -> str:
    return t.lower()

In [12]:
class CodeMixedTamilTokenizer(BaseTokenizer):
    def __init__(self, lang:str):
        self.lang = lang
        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(str(path/"../tokenizer/taen_spm.model"))
        
    def tokenizer(self, t:str) -> List[str]:
        return self.sp.EncodeAsPieces(t)

In [13]:
sp = spm.SentencePieceProcessor()
sp.Load(str(path/"../tokenizer/taen_spm.model"))
itos = [sp.IdToPiece(int(i)) for i in range(8000)]

In [14]:
len(itos)

8000

In [15]:
itos[:20]

['xxunk',
 'xxbos',
 'xxeos',
 'xxpad',
 'xxfld',
 'xxmaj',
 'xxup',
 'xxrep',
 'xxwrep',
 '.',
 ',',
 '▁',
 's',
 'a',
 '="',
 'in',
 'doc',
 't',
 'il',
 'i']

In [16]:
# 8,000 is the vocab size that we chose in sentencepiece
taen_vocab = Vocab(itos)

In [17]:
tokenizer = Tokenizer(lang='taen', tok_func=CodeMixedTamilTokenizer)

In [18]:
tokenizer.pre_rules.append(lower_case_everything)
tokenizer.pre_rules.append(handle_all_caps)
tokenizer.pre_rules.append(handle_upper_case_first_letter)

In [19]:
tokenizer.special_cases, tokenizer.pre_rules, tokenizer.post_rules

(['xxunk',
  'xxpad',
  'xxbos',
  'xxeos',
  'xxfld',
  'xxmaj',
  'xxup',
  'xxrep',
  'xxwrep'],
 [<function fastai.text.transform.fix_html>,
  <function fastai.text.transform.replace_rep>,
  <function fastai.text.transform.replace_wrep>,
  <function fastai.text.transform.spec_add_spaces>,
  <function fastai.text.transform.rm_useless_spaces>,
  <function __main__.lower_case_everything>,
  <function __main__.handle_all_caps>,
  <function __main__.handle_upper_case_first_letter>],
 [<function fastai.text.transform.replace_all_caps>,
  <function fastai.text.transform.deal_caps>])

In [20]:
tokens = tokenizer.process_all(['Tell me about TOUR self, mujhe jaanna hai'])

In [21]:
''.join(tokens[0])

'▁tell▁me▁about▁tour▁self,▁mujhe▁jaanna▁hai'

In [22]:
path

PosixPath('.')

In [23]:
data_lm = TextLMDataBunch.from_folder(path=path/'../dataset_preparation', train='train_uncased' , valid='valid_uncased', vocab=taen_vocab, tokenizer=tokenizer)

In [24]:
data_lm.batch_size

64

In [25]:
# data_lm.save()

In [26]:
data_lm.show_batch()

idx,text
0,"▁k . ▁la kkum an an "", ▁bhi . ▁1939 ) ▁woru ▁tamilk ▁arasielvadi . ▁thirunelveli ▁ma wa ▁ xxrep ▁4 ▁t ▁ch ▁serneaver . ▁tanatu ▁19 ▁vatu ▁vayatil ▁dravid ▁iyakkatil ▁tannai ▁itu badutti ky and ▁iver ▁binner ▁nell ai ▁dootukudi ▁on ou patt ▁nell ai ▁mawatt ▁porul aurag ▁8 ▁andus , ▁bri coppatt ▁nell ai ▁mawatt ▁seyal a parag ▁8 ▁andus ▁(19 87 ▁muthal ▁1994 ▁varai ), ▁tim uk"
1,"i ddu ▁azi ki m , ▁anal ▁don mayana ▁nagar i ganglum ▁ala bbar iya ▁valam um ▁niraint ▁afbrika ▁khanda tap ▁parchiya ▁arimugham ▁ikk ddurai . ▁tamadhu ▁ay rat ▁u zi ppal ▁ulak ukkuk ▁ganakk a tar es ▁khandupitip pu gans ▁valang y ▁china wap ▁parchiya ▁kddurai ▁itu . ▁airopavin ▁vadamengu ▁or attil ▁aru kon ▁amaiple ▁amaindull ▁france ▁nott ap e ▁parchiya ▁kddurai . ▁ansaiya ▁mes pa dom iya ▁yennum"
2,"▁id ="" 11 70 07"" ▁url ="" https : ▁ / ▁ / ▁ ta . wikipedia . org ▁ / ▁ wiki ? curid = 11 70 07"" ▁title ="" per iya vach an b illa ""> ▁periya vach an b illa ▁periya vach an b illa ▁id ikk al ▁tamil ▁ur y asrier . ▁vainav ▁urya si ar kalul ▁mudanmayana var . ▁iver ▁nal ay ir ▁di v"
3,"▁ / ▁ / ▁ ta . wikipedia . org ▁ / ▁ wiki ? curid = 27 2 68"" ▁title ="" chi th thi ▁( tirippadam ) ""> ▁chith thi ▁( tirippadam ) ▁chith thi ▁1966 ▁aam ▁antu ▁velivanth ▁tamilt ▁tirippatamakum . ▁ke . ▁es . ▁go bal krishna ▁iyakkatil ▁velivanth ▁ittrappattil ▁em . ▁aar . ▁radha , ▁muddu raman , ▁pad mini ▁manllum ▁balarum ▁naditthrundaner . ▁< ▁"
4,▁kaikal ack ▁gan u u du ▁enpatu ▁payanull tak ▁irukk in ta ▁ena ▁ari a patvilly . ▁tho mar uku ▁aal an avar luku ▁arugamiil ▁irukki il ▁mukh ad r ikhal ▁ani vatu ▁balan ▁mik atak ▁irukalam . ▁adhikapatiana ▁udal ▁alltu ▁samuk ▁do dudal il ▁irundhu ▁vilaki ▁iruppatu ▁balan ull tak ▁irundatu ▁ena ▁dirmani padaskana ▁podhumana ▁sandues ▁ille ▁alltu . ▁tani oru ▁al uku d ▁tati man ▁bi di kum


In [27]:
len(data_lm.train_dl)

7614

In [28]:
len(data_lm.valid_dl)

1888

In [29]:
len(data_lm.vocab.itos)

8000

In [30]:
learn = language_model_learner(data_lm, AWD_LSTM, pretrained=False)

In [31]:
gc.collect()

3305

In [32]:
learn.model

SequentialRNN(
  (0): AWD_LSTM(
    (encoder): Embedding(8000, 400, padding_idx=1)
    (encoder_dp): EmbeddingDropout(
      (emb): Embedding(8000, 400, padding_idx=1)
    )
    (rnns): ModuleList(
      (0): WeightDropout(
        (module): LSTM(400, 1152, batch_first=True)
      )
      (1): WeightDropout(
        (module): LSTM(1152, 1152, batch_first=True)
      )
      (2): WeightDropout(
        (module): LSTM(1152, 400, batch_first=True)
      )
    )
    (input_dp): RNNDropout()
    (hidden_dps): ModuleList(
      (0): RNNDropout()
      (1): RNNDropout()
      (2): RNNDropout()
    )
  )
  (1): LinearDecoder(
    (decoder): Linear(in_features=400, out_features=8000, bias=True)
    (output_dp): RNNDropout()
  )
)

In [33]:
learn.fit_one_cycle(10, 1e-2, callbacks=[callbacks.SaveModelCallback(learn, every='improvement', monitor='valid_loss', name='best_model')])

epoch,train_loss,valid_loss,accuracy,time
0,4.307283,4.490289,0.302827,10:34
1,4.319037,4.412583,0.309391,10:33
2,4.431201,4.36628,0.315155,10:32
3,4.472373,4.23228,0.328268,10:32
4,4.328704,4.110378,0.339276,10:34
5,4.103197,3.98382,0.352649,10:33
6,4.138057,3.848249,0.366615,10:46
7,4.007083,3.736214,0.379458,10:33
8,3.884695,3.649416,0.39012,10:34
9,3.968307,3.624365,0.393848,10:46


Better model found at epoch 0 with valid_loss value: 4.490289211273193.
Better model found at epoch 1 with valid_loss value: 4.412583351135254.
Better model found at epoch 2 with valid_loss value: 4.3662800788879395.
Better model found at epoch 3 with valid_loss value: 4.232280254364014.
Better model found at epoch 4 with valid_loss value: 4.110378265380859.
Better model found at epoch 5 with valid_loss value: 3.9838204383850098.
Better model found at epoch 6 with valid_loss value: 3.8482494354248047.
Better model found at epoch 7 with valid_loss value: 3.7362143993377686.
Better model found at epoch 8 with valid_loss value: 3.649416446685791.
Better model found at epoch 9 with valid_loss value: 3.6243646144866943.


In [34]:
learn.load('best_model')

LanguageLearner(data=TextLMDataBunch;

Train: LabelList (357 items)
x: LMTextList
▁x x bo s ▁< doc ▁id =" 24 65 52" ▁url =" https : ▁ / ▁ / ▁ ta . wikipedia . org ▁ / ▁ wiki ? curid = 24 65 52" ▁title =" di . ▁e . ▁ke . ▁ilakk uman an "> ▁di . ▁e . ▁ke . ▁ilakk uman an ▁di . ▁e . ▁ke . ▁ilakk uman an ▁(" t . ▁a . ▁k . ▁la kkum an an ", ▁bhi . ▁1939 ) ▁woru ▁tamilk ▁arasielvadi . ▁thirunelveli ▁ma wa ▁ xxrep ▁4 ▁t ▁ch ▁serneaver . ▁tanatu ▁19 ▁vatu ▁vayatil ▁dravid ▁iyakkatil ▁tannai ▁itu badutti ky and ▁iver ▁binner ▁nell ai ▁dootukudi ▁on ou patt ▁nell ai ▁mawatt ▁porul aurag ▁8 ▁andus , ▁bri coppatt ▁nell ai ▁mawatt ▁seyal a parag ▁8 ▁andus ▁(19 87 ▁muthal ▁1994 ▁varai ), ▁tim uk avil ▁paniaxullar . ▁ti mugha vilrundhu ▁va co ▁velieri abodhu ▁avar uton ▁vant ▁8 ▁mawatt ▁seyal a or galul ▁ivar um ▁oruvar . ▁adan bin ▁nell ai ▁mawatt ▁madi muk ▁seyal abh r agaum ▁ , talam s ▁ adsi man uku lu ▁ushupinaragaum ▁9 ▁andus ▁paniaxullar . ▁pin ▁2003 ▁il ▁va co uton ▁ettptt ▁karutu ▁vedubatt a

In [35]:
TEXT = "my name is"
N_WORDS = 40
N_SENTENCES = 2

In [37]:
print("\n".join(learn.predict(TEXT, N_WORDS, temperature=0.9) for _ in range(N_SENTENCES)))

my name is a hu gas . ▁ethr ali kalin ▁pa thak gangs ▁finwarumaru : ▁tiru pati ▁ tra in gen , ▁de m my man in um ▁er gen a way ▁has i der s t ▁makas ▁ingu p ▁pes e pattu
my name is ▁var ta gon al s ▁the ▁fa ex s ▁in s or me dh y ▁natu re : ▁3. 2 ▁c m . ▁2013 ▁in ▁fi t ▁cl in t ▁co x ce ▁a ▁mor io c ▁enjh ▁ach il


In [38]:
np.exp(3.624365)

37.500902524246804

In [39]:
defaults.device = torch.device('cpu')
learn.model.eval()
learn.export()

In [40]:
path

PosixPath('.')

In [41]:
encoder = get_model(learn.model)[0]

In [42]:
encoder.state_dict()['encoder.weight'].shape

torch.Size([8000, 400])

In [43]:
embeddings = encoder.state_dict()['encoder.weight']

In [44]:
embeddings = np.array(embeddings)

In [45]:
embeddings[0].shape

(400,)

In [46]:
df = pd.DataFrame(embeddings)

In [47]:
df.shape

(8000, 400)

In [48]:
df.to_csv('ulmfit_embeddings.tsv', sep='\t', index=False, header=False)

In [49]:
df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,390,391,392,393,394,395,396,397,398,399
0,-0.67643,0.365254,-0.109539,1.386747,0.127163,1.041088,-0.975489,0.355097,0.362999,0.622357,...,-0.186305,-0.14604,-0.088788,-0.252422,-0.618491,-0.827193,1.753406,0.155206,-1.233057,-0.852615
1,-0.104451,-0.01,-0.079616,0.559299,-0.011074,0.046993,0.045187,0.610791,-0.103101,0.213283,...,-0.015699,-0.063259,0.123854,0.073189,-0.184061,-0.299852,-0.138602,0.072578,-0.152961,-0.009376
2,-0.102247,-0.009684,-0.079079,0.559127,-0.010568,0.046423,0.044713,0.610951,-0.104661,0.213368,...,-0.01637,-0.061588,0.124559,0.074444,-0.182768,-0.298372,-0.138531,0.072295,-0.153083,-0.009217
3,-0.102839,-0.009622,-0.080255,0.55836,-0.011011,0.048089,0.045544,0.61095,-0.103975,0.213319,...,-0.016009,-0.062617,0.123953,0.072562,-0.183773,-0.299184,-0.141226,0.071492,-0.153465,-0.010092
4,-0.102592,-0.008316,-0.079421,0.558533,-0.010209,0.04809,0.045019,0.610198,-0.103783,0.211797,...,-0.016531,-0.062153,0.123998,0.073275,-0.183127,-0.300289,-0.139307,0.072279,-0.153554,-0.009857


In [50]:
df.shape

(8000, 400)

In [51]:
len(itos)

8000

In [52]:
df2 = pd.DataFrame(itos)

In [53]:
df2.head()

Unnamed: 0,0
0,xxunk
1,xxbos
2,xxeos
3,xxpad
4,xxfld


In [54]:
df2.shape

(8000, 1)

In [55]:
df2.to_csv('ulmfit_embeddings_metadata.tsv', sep='\t', index=False, header=False)

In [56]:
encoder.state_dict()['encoder.weight'][1]

tensor([-1.0445e-01, -1.0000e-02, -7.9616e-02,  5.5930e-01, -1.1074e-02,
         4.6993e-02,  4.5187e-02,  6.1079e-01, -1.0310e-01,  2.1328e-01,
        -1.1276e-02,  8.9489e-02, -8.4570e-03, -1.9432e-01,  1.7409e-02,
         1.4394e-01, -3.2438e-01,  8.1573e-02, -1.2693e-01,  1.9580e-01,
        -2.2160e-01,  1.2461e-01, -7.0240e-02, -1.1072e-01,  1.5984e-01,
         1.1763e+00,  2.0249e-01,  2.6018e-01,  1.5847e-01,  1.4048e-01,
         6.8172e-02, -2.3396e-01,  2.1471e-01, -3.5477e-01,  3.2458e-02,
        -1.7663e-01,  1.5906e-01, -1.2462e-01,  2.4571e-01, -4.1767e-01,
         1.6315e-02, -2.8302e-01, -7.6148e-02,  7.5319e-02,  6.7007e-02,
        -4.2344e-02,  2.6242e-02,  5.3554e-02,  2.4956e-01,  2.9927e-01,
        -6.3178e-02,  1.0965e-01,  2.9615e-01, -3.1136e-02,  3.4137e-01,
         1.4412e-01, -1.2900e-01, -2.9176e-02, -3.8198e-01,  1.5966e-01,
        -4.6067e-03,  3.4884e-02, -2.4460e-01, -1.2963e-02,  8.0436e-02,
        -5.2440e-01,  4.4380e-01,  1.7138e-01, -1.9