### Imports and loads

In [1]:
from typing import List

import pickle as pkl
import numpy as np
import html
from pathlib import Path

from fastai.text import *
from sklearn.model_selection import train_test_split

In [2]:
DATA_PATH = Path('DATA/')

### Dataset properties, inspection, tokenization

In [3]:
DATASET_NAME = 'x_and_y_cleaned.pkl'
with open(DATA_PATH/DATASET_NAME, 'rb') as f:
    articles, categories = pkl.load(f)

In [4]:
# Label None as 'none'
categories = ['none' if not x else x for x in categories]

In [5]:
CLASSES = sorted(list(set(categories)))
ARTICLE_COUNT = len(articles)
CLASS_COUNT = len(CLASSES)
BOS = 'xbos'  # beginning-of-sentence tag
FLD = 'xfld'  # data field tag
MAX_SIZE = 250

print(ARTICLE_COUNT)
print(CLASS_COUNT)

48514
138


In [144]:
print(CLASSES)

['aiaeri', 'arvamus', 'arvamus/intervjuu', 'arvamus/juhtkiri', 'arvamus/karikatuur', 'arvamus/kommentaar', 'arvamus/koomiks', 'arvamus/lugejakiri', 'arvamus/nadalatipud', 'arvamus/repliik', 'arvamus/seisukoht', 'blogid/avastaeestimaad', 'blogid/aveameerikas', 'blogid/filmiblogi', 'blogid/hollandiblogi', 'blogid/indoneesiablogi', 'blogid/jumestusblogi', 'blogid/korvpallimm', 'blogid/lehesaba', 'blogid/londonilustiblogi', 'blogid/malluka', 'blogid/meistriteblogi', 'blogid/moeajakiri', 'blogid/moekeeris', 'blogid/motteid', 'blogid/muusikablogi', 'blogid/opetajablogi', 'blogid/psyhholoogiablogi', 'blogid/pulmablogi', 'blogid/raamatublogi', 'blogid/raha', 'blogid/seljakotigablogi', 'blogid/spordiblogi', 'blogid/teleblogi', 'blogid/trenniblogi', 'blogid/valdojahilo', 'blogid/yksikvanem', 'eestinaine/elud-inimesed', 'eriline/horoskoop', 'eriline/mystika', 'joulud', 'kroonika/eesti', 'lemmikloom', 'linnaleht/arvamus', 'linnaleht/dilaila', 'linnaleht/karikatuur', 'linnaleht/kodusedlood', 'linna

In [56]:
# Dataset examples:
index = 0
print('ARTICLE: ', articles[index][0:110], '...')
print('CATEGORY: ', categories[index])

ARTICLE:  Kas parima aastavahetuse programmi pani eetrisse ETV, Kanal 2 või hoopis TV3? ETVst näegid vaatajad saateid "V ...
CATEGORY:  televeeb/tvuudised


In [54]:
# Get median/average word count
print(np.median([len(x.split(' ')) for x in articles]))
print(np.mean([len(x.split(' ')) for x in articles]))

261.0
387.0457805994146


In [6]:
# One hot encoding
# labels = []
# for x in categories:
#     y = [0 for x in range(CLASS_COUNT)]
#     y[CLASSES.index(x)] = 1
#     labels.append(y)

# Class index encoding
labels = []
for x in categories:
    y = CLASSES.index(x)
    labels.append(y)

In [7]:
np.random.seed(42)
train_texts, val_texts, train_labels, val_labels = train_test_split(articles, labels, test_size=0.1, random_state=42)
pickle.dump([train_texts, val_texts, train_labels, val_labels], open(DATA_PATH/'tokens'/'trnx_valx_trny_valy_ind_split.pkl', 'wb'))

### Tokenize

In [8]:
max_vocab = 60000
min_freq = 5

In [59]:
tok_train = Tokenizer(lang='xx').proc_all_mp(partition_by_cores(train_texts))
tok_val = Tokenizer(lang='xx').proc_all_mp(partition_by_cores(val_texts))

In [60]:
freq = Counter(p for o in tok_train for p in o)
print(len(tok_train))
freq.most_common(25)

43662


[(',', 657926),
 ('.', 559252),
 ('"', 217175),
 ('ja', 210514),
 ('on', 197759),
 ('et', 150766),
 ('ei', 106727),
 ('kui', 74991),
 ('ta', 66639),
 ('ka', 58212),
 ('oli', 51101),
 ('oma', 46727),
 ('-', 46020),
 ('ning', 45314),
 ('see', 45285),
 ('xbos', 43662),
 ('xfld', 43662),
 ('0', 42597),
 ('aga', 38936),
 ('t_up', 31812),
 ('mis', 31436),
 ('ma', 30478),
 ('siis', 29830),
 ('kes', 29218),
 ('tema', 28739)]

In [61]:
print(tok_val[5])

['xbos', 'vehklemisliidu', 'president', ',', 'riigikogu', 'liige', 'margus', 'hanson', 'tõdes', ',', 'et', 'naiskond', 'vehkles', 'kaunilt', 'kuni', 'finaalini', '.', '"', 'naised', 'olid', 'väga', 'tublid', '.', 'meil', 'on', 'noor', ',', 'perspektiivikas', 'ja', 'arenev', 'võistkond', ':', 'teise', 'kohaga', 'tuleb', 'igati', 'rahul', 'olla', ',', 'sest', 'ega', 'jõu', 'ja', 'võimu', 'vastu', 'ei', 'saa', '!', '"', 'hanson', 'lisas', ',', 'et', 'teda', 'rõõmustab', 'sten', 'priinitsa', 'individuaalturniiril', 'saadud', 'kaheksas', 'koht', ',', 'millega', 'mees', 'suurendab', 'ka', 't_up', 'eok', 'toetusraha', '.', '"', 'meie', 'vehklejad', 'on', 'tõestanud', ',', 'et', 'neid', 'saab', 'usaldada', '.', 'sportlased', 'seavad', 'kõrged', 'sihid', 'ja', 'on', 'võimelised', 'neid', 'täitma', ';', '"', 'kinnitas', 'ta', '.', 'ühtlasi', 'märkis', 'hanson', ',', 'et', 'suur', 'on', 'treener', 'igor', 'tšikinjovi', 'panus', '.', '"', 'ta', 'on', 'toonud', 'värsket', 'verd', 'ja', 'hingamist',

In [62]:
freq_val = Counter(p for o in tok_val for p in o)
print(len(tok_val))
freq_val.most_common(25)

4852


[(',', 72534),
 ('.', 62293),
 ('"', 23741),
 ('ja', 23599),
 ('on', 21847),
 ('et', 16600),
 ('ei', 11625),
 ('kui', 8402),
 ('ta', 7294),
 ('ka', 6594),
 ('oli', 5541),
 ('ning', 5219),
 ('oma', 5142),
 ('-', 5101),
 ('see', 4893),
 ('xbos', 4852),
 ('xfld', 4852),
 ('0', 4743),
 ('aga', 4345),
 ('mis', 3589),
 ('t_up', 3475),
 ('ma', 3323),
 ('tema', 3264),
 ('eesti', 3248),
 ('siis', 3235)]

In [63]:
np.save(DATA_PATH/'tokens/tok_train_pad.npy', tok_train)
np.save(DATA_PATH/'tokens/tok_val_pad.npy', tok_val)

In [64]:
itos = [o for o,c in freq.most_common(max_vocab) if c>min_freq]
itos.insert(0, '_pad_')
itos.insert(0, '_unk_')

In [65]:
stoi = collections.defaultdict(lambda:0, {v:k for k,v in enumerate(itos)})

In [97]:
train_lm = np.array([[stoi[o] for o in p] for p in tok_train])
val_lm = np.array([[stoi[o] for o in p] for p in tok_val])

In [117]:
# Pad and crop values
train_lm_pad = [x[:MAX_SIZE] if len(x) > MAX_SIZE else x + [0 for i in range(MAX_SIZE - len(x))] for x in train_lm]
val_lm_pad = [x[:MAX_SIZE] if len(x) > MAX_SIZE else x + [0 for i in range(MAX_SIZE - len(x))] for x in val_lm]

In [118]:
np.save(DATA_PATH/'tokens'/'trn_ids.npy', train_lm_pad) # Oversaved all as padded
np.save(DATA_PATH/'tokens'/'val_ids.npy', val_lm_pad)
pickle.dump(itos, open(DATA_PATH/'tokens'/'itos.pkl', 'wb'))

### Load tokenized data

In [9]:
train_texts, val_texts, train_labels, val_labels = pickle.load(open(DATA_PATH/'tokens'/'trnx_valx_trny_valy_ind_split.pkl', 'rb'))
train_lm = np.load(DATA_PATH/'tokens'/'trn_ids.npy')
val_lm = np.load(DATA_PATH/'tokens'/'val_ids.npy')
itos = pickle.load(open(DATA_PATH/'tokens'/'itos.pkl', 'rb'))

In [91]:
print(train_texts[0])

xbos Peaminister Taavi Rõivas jätab võimutüli tõttu ära visiidid Leedusse ja Rootsi, teda asendab väliskaubandus- ja ettevõtlusminister Anne Sulling.  Valitsuse pressiesindaja kinnitas pühapäeva pärastlõunal, et Rõivas ei sõida esmaspäeval visiidile Leetu ja Rootsi. Pressiesindaja teatel jäävad visiidid ära "seoses ametikohustustega Eestis". Reformierakonna esimees, peaminister Taavi Rõivas pidi esmaspäeval koos teiste Balti riikide valitsusjuhtidega osalema Leedus Klaipedas aset leidval LNG ujuvterminali saabumistseremoonial. Enne tseremooniat pidi aset leidma peaministrite ning Ameerika Ühendriikide esindajate ühine töölõuna. Pärastlõunal pidi Rõivas suunduma edasi Stockholmi, kus toimub Balti- ja Põhjamaade tippkohtumine. Rootsi, Soome, Norra, Islandi, Taani, Eesti, Läti ja Leedu peaministrite kohtumisel räägitakse majanduse olukorrast Euroopas, transatlantilistest suhetest ning Ukrainaga seotud arengutest. Pühapäeval kohtuvad Reformierakonna ja Sotsiaaldemokraatliku Erakonna esimeh

In [105]:
print(train_lm[0])

[17, 425, 524, 658, 2109, 0, 254, 63, 48013, 0, 5, 563, 2, 84, 28902, 0, 5, 53171, 1428, 26646, 3, 64, 755, 588, 438, 2029, 1368, 2, 7, 658, 8, 7061, 661, 4845, 17602, 5, 563, 3, 588, 704, 1070, 48013, 63, 4, 552, 0, 136, 4, 3, 813, 829, 2, 425, 524, 658, 388, 661, 79, 383, 555, 1197, 0, 3403, 5258, 0, 1815, 0, 21, 11591, 0, 0, 3, 105, 50278, 388, 1815, 3905, 31416, 15, 542, 1406, 7017, 3196, 0, 3, 1368, 388, 658, 0, 180, 3805, 2, 45, 638, 50279, 5, 9650, 11347, 3, 563, 2, 322, 2, 954, 2, 5409, 2, 2147, 2, 27, 2, 662, 5, 1109, 31416, 2357, 2275, 3893, 3806, 938, 2, 0, 5920, 15, 12824, 433, 43618, 3, 679, 10646, 813, 5, 7930, 871, 48014, 2, 7, 3819, 1726, 1064, 194, 0, 7214, 1899, 3, 18, 37, 7215, 19, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [10]:
print(train_labels[0])
print(CLASSES[train_labels[0]])
# print(CLASSES[train_labels[0].index(1)]) # for one hot

130
uudised/eesti


In [11]:
bs=32

class TokDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.len = len(self.x)
        self.x_data = torch.from_numpy(self.x); self.x_data = self.x_data.long()
        self.y_data = torch.from_numpy(self.y); self.y_data = self.y_data.long()
        print('x shape', self.x_data.shape)
        print('y shape', self.y_data.shape)
        
    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]
    
    def __len__(self):
        return self.len
    
ds = TokDataset(train_lm, np.asarray(train_labels))
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=bs, shuffle=True, num_workers=0)

x shape torch.Size([43662, 250])
y shape torch.Size([43662])


In [12]:
test_values = iter(dl)

In [13]:
xs, ys = next(test_values)

In [14]:
print(xs)

tensor([[   17,  2631, 13166,  ...,   170,     0,     0],
        [   17,    35,     0,  ...,     3,     0,     0],
        [   17,  3274,  2432,  ...,     0,     0,     0],
        ...,
        [   17,    27, 52997,  ...,   216,     0,     0],
        [   17,   353,     6,  ...,     0,     0,     0],
        [   17,   929,     0,  ...,    71,     0,     0]])


### Feed-forward NN

In [15]:
torch.__version__

'0.4.1'

In [40]:
class SimpleFNN(nn.Module):
    def __init__(self, input_size, vocab_size, num_outputs, num_l, neurons: List[int], e_size=200):
        super(SimpleFNN, self).__init__()
        self.e = nn.Embedding(vocab_size, e_size)
        self.input_l = nn.Linear(e_size * input_size, neurons[0])
        self.middle_l = nn.ModuleList()
        for i in range(num_l):
            self.middle_l.append(nn.Linear(neurons[i], neurons[i+1]))
        self.output_l = nn.Linear(neurons[-1], num_outputs)
        
    def forward(self, x):
        i_sz = x.shape[-1]
        x = F.relu(self.e(x))
        x = x.view(-1,  i_sz * x.shape[-1])
        x = F.relu(self.input_l(x))
        for l in self.middle_l:
            x = F.relu(l(x))
        return self.output_l(x) # No softmax for crossentropy
        

In [41]:
print(train_labels[0])
print(train_lm.shape)

130
(43662, 250)


In [42]:
fnn = SimpleFNN(MAX_SIZE, max_vocab, CLASS_COUNT, 4, [200, 300, 100, 50, 20]).cuda()
print(fnn)

SimpleFNN(
  (e): Embedding(60000, 200)
  (input_l): Linear(in_features=50000, out_features=200, bias=True)
  (middle_l): ModuleList(
    (0): Linear(in_features=200, out_features=300, bias=True)
    (1): Linear(in_features=300, out_features=100, bias=True)
    (2): Linear(in_features=100, out_features=50, bias=True)
    (3): Linear(in_features=50, out_features=20, bias=True)
  )
  (output_l): Linear(in_features=20, out_features=138, bias=True)
)


In [43]:
# Test pass through
fnn(xs.cuda()).shape

torch.Size([32, 138])

In [44]:
crit = torch.nn.CrossEntropyLoss()
opt = optim.Adam(fnn.parameters(), 1e-3)

In [45]:
for ep in range(5):
    for i, data in tqdm(enumerate(dl, 0)):
        x, y = data
        x = x.cuda()
        y = y.cuda()
        

        y_h = fnn(x)
        loss = crit(y_h, y)
        #import pdb;pdb.set_trace()
        if i % 250 == 0:
            print(f'Epoch: {ep} | loss {loss.item()}')

        opt.zero_grad()
        loss.backward()
        opt.step()

0it [00:00, ?it/s]Epoch: 0 | loss 4.919979095458984
245it [00:03, 67.29it/s]Epoch: 0 | loss 2.7303550243377686
497it [00:07, 67.25it/s]Epoch: 0 | loss 1.531820297241211
750it [00:11, 67.33it/s]Epoch: 0 | loss 1.9544901847839355
995it [00:14, 67.32it/s]Epoch: 0 | loss 1.102027416229248
1247it [00:18, 67.29it/s]Epoch: 0 | loss 1.9706389904022217
1365it [00:20, 67.27it/s]
0it [00:00, ?it/s]Epoch: 1 | loss 1.1481518745422363
245it [00:03, 67.12it/s]Epoch: 1 | loss 1.3593789339065552
497it [00:07, 67.16it/s]Epoch: 1 | loss 1.0671448707580566
749it [00:11, 67.22it/s]Epoch: 1 | loss 0.9334531426429749
994it [00:14, 67.16it/s]Epoch: 1 | loss 0.3123042583465576
1247it [00:18, 67.15it/s]Epoch: 1 | loss 1.134935975074768
1365it [00:20, 67.13it/s]
0it [00:00, ?it/s]Epoch: 2 | loss 1.0404680967330933
245it [00:03, 67.03it/s]Epoch: 2 | loss 0.6578649282455444
497it [00:07, 67.09it/s]Epoch: 2 | loss 1.1025166511535645
749it [00:11, 67.13it/s]Epoch: 2 | loss 0.4597235918045044
994it [00:14, 67.11it/s]

In [None]:
# Can also check how balanced the classes are