In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import wandb
import torch
import torch.nn as nn
from torch.nn import functional as F

import pytorch_lightning as pl

import utils
import ud
import neural



In [10]:
class ConlluDataModule(pl.LightningDataModule):
    """Lightning data module that serves CoNLL-U data as tensor datasets.

    Both files are parsed eagerly in ``__init__`` with ``ud.load_dataset`` and
    ``ud.parse_opnlp``. ``validation_size`` examples are held out of the
    training file at random and used for validation.

    Args:
        train_filename: path of the file used for training and validation.
        test_filename: path of the file used for testing.
        batch_size: batch size shared by the train/val/test dataloaders.
        validation_size: number of training examples held out for validation.

    Raises:
        ValueError: if ``validation_size`` is negative or not smaller than the
            number of examples loaded from ``train_filename``.
    """

    # Length limits handed to ``ud.load_dataset``; how they are applied
    # (padding / truncation) is decided inside that helper.
    SENTENCE_MAXLEN = 30
    WORD_MAXLEN = 25

    def __init__(self, train_filename, test_filename, batch_size=32, validation_size=500):
        # TODO: use dev set instead of splitting validation
        super().__init__()
        self.validation_size = validation_size
        self.batch_size = batch_size

        self.data_train_full = self._load_tensor_dataset(train_filename)

        # random_split only checks that the lengths *sum* to the dataset size,
        # so a negative or too-large validation_size would not raise there and
        # would silently produce a nonsensical split. Fail early instead.
        if not 0 <= validation_size < len(self.data_train_full):
            raise ValueError(
                f'validation_size must be in [0, {len(self.data_train_full)}), got {validation_size}')

        self.data_test = self._load_tensor_dataset(test_filename)

        # Filled in by setup(); kept as None until the split has been drawn.
        self.data_train = None
        self.data_val = None

    def _load_tensor_dataset(self, filename):
        """Parse ``filename`` and wrap the inputs plus every label array in a TensorDataset."""
        data_x, data_y = ud.load_dataset(filename, ud.parse_opnlp, self.SENTENCE_MAXLEN, self.WORD_MAXLEN)
        # Inputs stay float tensors; each label array becomes an int64 tensor.
        label_tensors = [torch.Tensor(y).to(torch.int64) for y in data_y]
        return torch.utils.data.TensorDataset(torch.Tensor(data_x), *label_tensors)

    def prepare_data(self):
        # No state assignment here
        pass

    def setup(self, stage=None):
        """Create the train/validation split and record the input dimensions.

        ``stage`` may be ``'fit'``, ``'test'`` or ``None`` (both).
        """
        if stage == 'fit' or stage is None:
            # Draw the random split only once: setup() may be called both by
            # hand and by the trainer, and re-splitting would move validation
            # examples into the training subset between calls.
            if self.data_train is None:
                train_size = len(self.data_train_full) - self.validation_size
                self.data_val, self.data_train = torch.utils.data.random_split(
                    self.data_train_full, [self.validation_size, train_size])
            self.dims = tuple(self.data_train[0][0].shape)

        if stage == 'test' or stage is None:
            self.dims = tuple(self.data_test[0][0].shape)

    def train_dataloader(self):
        # Reshuffle every epoch; DataLoader's default keeps a fixed batch order.
        return torch.utils.data.DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.data_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.data_test, batch_size=self.batch_size)

In [9]:
# Number of rows in the character embedding table.
NUM_EMBEDDING = 2000


class SentenceModel(neural.UdModel):
    """Multi-task sentence tagger: characters -> word vectors -> per-word label scores.

    Character ids are embedded and passed through ``neural.SumSharedBiLSTM``;
    its hidden output is batch-normalised and used as one vector per word.
    Those word vectors go through ``neural.SumBiLSTM``, the result is added
    back to the word vectors, and one linear head per entry of ``label_map``
    scores the classes of that label.

    Args:
        label_map: mapping from label name to its number of classes.
        units: width of the embedding, both LSTMs and the head inputs.
        learning_rate: Adam learning rate (read back from ``self.hparams``).
    """

    def __init__(self, label_map, units=400, learning_rate=2e-3):
        super().__init__()
        # Stores the constructor arguments under self.hparams.
        self.save_hyperparameters()

        self.units = units

        self.embed = nn.Embedding(NUM_EMBEDDING, units)
        self.char_lstm = neural.SumSharedBiLSTM(units)
        self.word_lstm = neural.SumBiLSTM(units)
        self.norm = nn.BatchNorm1d(units, affine=False)

        # Built from (name, module) pairs so the heads follow the iteration
        # order of label_map.
        heads = []
        for label_name, class_size in label_map.items():
            heads.append((label_name, nn.Linear(units, class_size)))
        self.tasks = nn.ModuleDict(heads)

    def forward(self, x):
        """Return ``{label_name: scores}`` with scores shaped (batch, classes, sentence)."""
        batch_size, sent_len, word_len = x.shape

        # Step 0: character embedding.
        # Flatten sentences so that every word is an independent sequence.
        char_ids = x.to(torch.int64).reshape(batch_size * sent_len, word_len)
        # char_ids: (batch * sent, word)

        char_embeds = self.embed(char_ids).transpose(0, 1)
        # char_embeds: (word, batch * sent, units) -- sequence axis first

        # Step 1: character-level LSTM -> one vector per word.
        _, word_vectors = self.char_lstm(char_embeds)
        # word_vectors: (batch * sent, units)
        word_vectors = self.norm(word_vectors)
        word_vectors = word_vectors.reshape(batch_size, sent_len, -1)
        # word_vectors: (batch, sent, units)

        # Step 2: sequence tagging with the word-level LSTM, which is fed and
        # returns tensors with the sentence axis first.
        word_lstm_out, _ = self.word_lstm(word_vectors.transpose(0, 1))
        # word_lstm_out: (sent, batch, units)

        # Add the LSTM output to its batch-first input.
        summed = word_lstm_out.transpose(0, 1) + word_vectors
        # summed: (batch, sent, units)

        # Move the class axis to position 1 for every head.
        scores = {}
        for name, head in self.tasks.items():
            scores[name] = head(summed).transpose(1, 2)
        return scores

    def configure_optimizers(self):
        """Plain Adam over all parameters; no LR scheduler is active."""
        adam = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        # Disabled alternative:
        #   scheduler = torch.optim.lr_scheduler.MultiStepLR(adam, [1, 2], gamma=0.3)
        #   return [adam], [scheduler]
        return [adam]


In [5]:
# Build the data module from the Hebrew UD treebank files.
# NOTE: the *dev* file is passed as ``test_filename``, so everything this
# notebook reports as "test" is measured on the dev file.
dataset = ConlluDataModule(
    '../Hebrew_UD/he_htb-ud-train.conllu',
    '../Hebrew_UD/he_htb-ud-dev.conllu',
    batch_size=32,
    validation_size=500)

In [6]:
# Run setup with the default stage (None), which executes both the 'fit' and
# the 'test' branches: the train/validation split is drawn here.
dataset.setup()

In [11]:
# Train a fresh model on the data module, logging the run to Weights & Biases.
model = SentenceModel(ud.Token.label_map(), units=400, learning_rate=0.8e-3)
# Optional warm start of the character LSTM from a saved state dict:
# model.char_lstm.load_state_dict(torch.load('models/verbs_sumsharedlstm.pth'))
wandb_logger = pl.loggers.WandbLogger(project='rootem', group='ud-real', name='dotted-alternating')
# Use one GPU when CUDA is available; otherwise fall back to CPU instead of
# failing at trainer construction.
trainer = pl.Trainer(gpus=1 if torch.cuda.is_available() else 0, max_epochs=50, logger=wandb_logger)
trainer.fit(model, dataset)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]



  | Name      | Type            | Params
----------------------------------------------
0 | embed     | Embedding       | 800 K 
1 | char_lstm | SumSharedBiLSTM | 2 M   
2 | word_lstm | SumBiLSTM       | 2 M   
3 | norm      | BatchNorm1d     | 0     
4 | tasks     | ModuleDict      | 87 K  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

Saving latest checkpoint..





1

In [None]:
# Hand-written sample sentence, split on whitespace (punctuation stays
# attached to the neighbouring word).
sent = 'רציתי לעשות דברים אז הלכתי לאכול פיצה. לא שבעתי. אני אמשיך לכתוב אם תורידי את המים'.split()
# Encode every word with utils.encode_word(word, 11) and repeat the encoded
# sentence twice, giving a batch of two identical sentences.
text = torch.Tensor(
    2 * [[utils.encode_word(token, 11) for token in sent]]
)  # .cuda()

In [None]:
# Tag the sample sentence and print a few selected label columns per token.
#
# Inference has to run in eval mode: the model contains BatchNorm1d, which in
# training mode normalises with the statistics of this tiny batch and updates
# its running statistics as a side effect. no_grad() avoids building an
# autograd graph that is never used.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        outputs = model(text)
finally:
    # Leave the model in whatever mode it was in before this cell.
    model.train(was_training)

# Every head output has its class axis at position 1, so argmax(1) yields one
# class id per token; [0] takes the first of the two identical sentences.
res = [[ud.Token.decode_label(label, idx) for idx in value.argmax(1)[0]]
       for label, value in outputs.items()]

# Columns to show. NOTE(review): header names are taken from
# ud.Token.labels()[3:], while row values are taken from zip(sent, *res) where
# position 0 is the token itself -- confirm that the positive indices line up.
indices = [9, 14, -6, -7, -8, -9]
print('Token', *[ud.Token.labels()[3:][i][:7] for i in indices], sep='\t')
for item in zip(sent, *res):
    # (A disabled variant printed the label columns only when item[14] == 'VERB'
    # and just the token otherwise.)
    print(item[0], *[item[i] for i in indices], sep='\t')


In [12]:
# Print word / tag / binyan / root predictions for the first test batch only.
#
# Positions of the heads of interest in the iteration order of the prediction
# dict. Judging by the saved output of this cell, column 12 holds the POS tag,
# column 7 the binyan and columns 19-22 the root letters.
POS_COLUMN = 12
BINYAN_COLUMN = 7
ROOT_COLUMNS = (19, 20, 21, 22)

for b in dataset.test_dataloader():
    text = b[0]
    preds = model.predict(text)
    for sent_id in range(len(text)):
        for word_id in range(len(text[sent_id])):
            char_codes = text[sent_id, word_id]
            # An all-zero word marks the end of the sentence.
            if char_codes.sum() == 0:
                break
            line = [ud.Token.decode_label(label, pred[sent_id, word_id]) for label, pred in preds.items()]
            # Drop the zero entries so that no NUL characters are printed, then
            # pad with spaces to the original width to keep the same layout.
            word = ''.join(chr(int(code)) for code in char_codes.tolist() if code).ljust(len(char_codes))
            # '.' entries among the root letters are removed from the printed root.
            root = ''.join(line[i] for i in ROOT_COLUMNS).replace('.', '')
            print(word, line[POS_COLUMN], line[BINYAN_COLUMN], root, sep='\t')
        print()
    break

עַשְׂרוֹת           	NUM	PAAL	עשר
אֲנָשִׁים           	NOUN	NIFAL	אוש
מַגִּיעִים          	VERB	HIFIL	נגע
מִתְאִילַנְד        	ADV	HITPAEL	אעד
לְיִשְׂרָאֵל        	PROPN	NIFAL	שאל
כְּשֶׁהֵם           	PRON	PAAL	כהמ
נִרְשָׁמִים         	VERB	NIFAL	רשמ
כְּמִתְנַדְּבִים,   	NOUN	HITPAEL	נבב
אַךְ                	CCONJ	PAAL	אככ
לְמַעֲשֶׂה          	ADV	PIEL	עשי
מְשֻׁמְּשִׁים       	VERB	PIEL	שמש
עוֹבְדִים           	NOUN	PAAL	עבד
שְׂכִירִים          	NOUN	PAAL	שכר
זוֹלִים.            	ADJ	PAAL	זלי

תּוֹפָעָה           	NOUN	PUAL	תפע
זוֹ                 	PRON	PAAL	זוז
הִתְבָּרְרָה        	VERB	HITPAEL	ברר
אֶתְמוֹל            	ADV	PAAL	תמל
בְּוַועֲדַת         	NOUN	PIEL	בעת
הָעֲבוֹדָה          	NOUN	PAAL	עבד
וְהָרְוָוחָה        	NOUN	PAAL	רוח
שֶׁל                	ADP	PAAL	שלל
הַכְּנֶסֶת,         	PROPN	HIFIL	כנס
שֶׁדָּנָה           	VERB	PAAL	דונ
בְּנוֹשֵׂא          	NOUN	PAAL	נשא
הָעֲסָקַת           	NOUN	PAAL	עסק
עוֹבְדִים           	NOUN	PAAL	עבד
זָרִים.             	ADJ	PAAL	זרר

יוֹ"ר   

In [None]:
# Evaluate the trained model on the data module's test loader. Note that the
# data module above was constructed with the UD dev file as its test file.
trainer.test(model, test_dataloaders=dataset.test_dataloader())

In [None]:
# Peek at the training split: element 0 of each example is the input tensor,
# so index 2 selects the second label tensor of the underlying TensorDataset.
dataset.data_train[:][2]

In [None]:
# Label name -> number of classes (via ud.Token.class_size) for a hand-picked
# subset of labels. Names left commented out are excluded from the mapping.
label_map = dict(
    (label, ud.Token.class_size(label))
    for label in (
        # 'Abbr',
        'Case',
        'Cconj',
        'Det',
        # 'Definite',
        'Gender',
        # 'HebExistential',
        # 'HebSource',
        'HebBinyan',
        # 'Mood',
        'Number',
        'Person',
        # 'Polarity',
        'Pos',
        # 'Prefix',
        'PronGender',
        'PronNumber',
        'PronPerson',
        'PronType',
        # 'Reflex',
        'R1',
        'R2',
        'R3',
        'R4',
        'Tense',
        # 'VerbForm',
        'VerbType',
        'Voice',
    )
)