In [1]:
%load_ext autoreload
%autoreload 2

In [19]:
import numpy as np
import wandb
import torch
import torch.nn as nn
from torch.nn import functional as F

import pytorch_lightning as pl

import utils
import encoding
import ud
import neural

In [3]:
class ConlluDataModule(pl.LightningDataModule):
    """Lightning data module serving encoded CoNLL-U sentences.

    Both files are encoded eagerly in ``__init__`` with ``encoding.load_sentences``.
    Each dataset item is ``(chars, label_1, ..., label_k)``: ``chars`` is built with
    ``torch.Tensor`` (float32) and every label tensor is int64, one per returned
    label array.

    Args:
        train_filename: CoNLL-U file for training; ``validation_size`` sentences
            are randomly held out from it for validation.
        test_filename: CoNLL-U file served by ``test_dataloader``.
        features: forwarded to ``encoding.load_sentences`` and exposed as
            ``self.features``. NOTE(review): the training cell also passes it to
            ``SentenceModel`` as a ``{label_name: class_size}`` map — confirm.
        batch_size: batch size of every dataloader.
        validation_size: number of training sentences moved to the validation set.
        label_map: keyword alias for ``features`` (give one or the other).
        sentence_maxlen: max words per sentence passed to the encoder (default 30).
        word_maxlen: max characters per word passed to the encoder (default 11).
        split_seed: seed for the train/validation split; ``None`` (default) uses
            torch's global RNG, so each ``setup('fit')`` call may re-split.

    Raises:
        TypeError: if neither or both of ``features`` / ``label_map`` are given.
    """

    def __init__(self, train_filename, test_filename, features=None, batch_size=64, validation_size=300,
                 label_map=None, sentence_maxlen=30, word_maxlen=11, split_seed=None):
        # TODO: use dev set instead of splitting validation
        super().__init__()
        if label_map is not None:
            if features is not None:
                raise TypeError("pass either 'features' or 'label_map', not both")
            features = label_map
        if features is None:
            raise TypeError("missing required argument: 'features' (or 'label_map')")

        self.validation_size = validation_size
        self.batch_size = batch_size
        self.features = features
        self.sentence_maxlen = sentence_maxlen
        self.word_maxlen = word_maxlen
        self.split_seed = split_seed

        self.data_train_full = self._encode_file(train_filename)
        self.data_test = self._encode_file(test_filename)

    def _encode_file(self, filename):
        """Encode one CoNLL-U file into a TensorDataset of ``(chars, *labels)``."""
        data_x, data_y = encoding.load_sentences(filename, self.features, self.sentence_maxlen, self.word_maxlen)
        return torch.utils.data.TensorDataset(torch.Tensor(data_x),
                                              *[torch.Tensor(y).to(torch.int64) for y in data_y])

    def prepare_data(self):
        # No state assignment here
        pass

    def setup(self, stage=None):
        """Split train/validation for ``'fit'`` and record ``self.dims`` (shape of one input).

        Raises:
            ValueError: if ``validation_size`` is negative or leaves no training sentences.
        """
        if stage == 'fit' or stage is None:
            n_total = len(self.data_train_full)
            n_val = self.validation_size
            if not 0 <= n_val < n_total:
                raise ValueError(f'validation_size={n_val} must be in [0, {n_total}) '
                                 f'(number of training sentences)')
            split_kwargs = {}
            if self.split_seed is not None:
                split_kwargs['generator'] = torch.Generator().manual_seed(self.split_seed)
            # Lengths are [train, val]: the FIRST split becomes the training set.
            self.data_train, self.data_val = torch.utils.data.random_split(
                self.data_train_full, [n_total - n_val, n_val], **split_kwargs)
            self.dims = tuple(self.data_train[0][0].shape)

        if stage == 'test' or stage is None:
            self.dims = tuple(self.data_test[0][0].shape)

    def train_dataloader(self):
        # Reshuffle every epoch; the one-off random_split alone fixes a single order.
        return torch.utils.data.DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.data_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return torch.utils.data.DataLoader(self.data_test, batch_size=self.batch_size)

In [114]:
NUM_EMBEDDING = 2000  # default character vocabulary size (rows of the embedding table)


class SentenceModel(neural.UdModel):
    """Hierarchical BiLSTM tagger: characters -> word vectors -> per-token labels.

    Pipeline (see ``forward``): character embedding, a character-level
    ``neural.SumSharedBiLSTM`` whose second output is used as the word vector,
    a non-affine ``BatchNorm1d`` over those vectors, a word-level
    ``neural.SumBiLSTM`` over the sentence, a residual sum of the word-LSTM output
    with the word vectors, and one ``nn.Linear`` head per label.

    Args:
        label_map: ``{label_name: class_size}``; one output head per entry.
        units: embedding / hidden size used throughout.
        learning_rate: initial Adam learning rate.
        num_embeddings: character vocabulary size (default ``NUM_EMBEDDING``).
        lr_milestones: ``MultiStepLR`` milestones, in scheduler steps (epochs under
            Lightning's default scheduler interval).
        lr_gamma: multiplicative learning-rate decay applied at each milestone.

    All constructor arguments are stored in ``self.hparams`` by ``save_hyperparameters``.
    """

    def __init__(self, label_map, units=400, learning_rate=2e-3,
                 num_embeddings=NUM_EMBEDDING, lr_milestones=(4,), lr_gamma=0.5):
        super().__init__()
        self.save_hyperparameters()

        self.units = units

        self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=units)

        self.char_lstm = neural.SumSharedBiLSTM(units)

        self.word_lstm = neural.SumBiLSTM(units)

        self.tasks = nn.ModuleDict({label_name: nn.Linear(in_features=units, out_features=class_size)
                                    for label_name, class_size in label_map.items()})

        # affine=False: normalisation only, no learned scale/shift.
        self.norm = nn.BatchNorm1d(units, affine=False)

    def forward(self, x):
        """Compute logits for every label head.

        Args:
            x: character indices of shape ``(batch, sent_maxlen, word_maxlen)``;
                cast to int64 before the embedding lookup, so float input is accepted.

        Returns:
            dict mapping each label name to logits of shape
            ``(batch, class_size, sent_maxlen)`` — class axis second, the
            ``(N, C, ...)`` layout accepted by ``F.cross_entropy``.
        """
        BATCH_SIZE, SENT_MAXLEN, WORD_MAXLEN = x.shape

        x = x.to(torch.int64)
        # x: (BATCH_SIZE, SENT_MAXLEN, WORD_MAXLEN)

        # Step 0: character embedding; every word becomes its own row.
        x = x.reshape(BATCH_SIZE * SENT_MAXLEN, WORD_MAXLEN)
        # x: (BATCH_SIZE * SENT_MAXLEN, WORD_MAXLEN)

        embeds = self.embed(x)
        # embeds: (BATCH_SIZE * SENT_MAXLEN, WORD_MAXLEN, UNITS)

        # Character axis first for the character-level LSTM.
        embeds = embeds.permute([1, 0, 2])
        # embeds: (WORD_MAXLEN, BATCH_SIZE * SENT_MAXLEN, UNITS)

        # Step 1: character-level LSTM -> word embedding
        _, char_hidden = self.char_lstm(embeds)
        # char_hidden: (BATCH_SIZE * SENT_MAXLEN, UNITS)
        char_hidden = self.norm(char_hidden)

        char_hidden = char_hidden.reshape(BATCH_SIZE, SENT_MAXLEN, -1)
        # char_hidden: (BATCH_SIZE, SENT_MAXLEN, UNITS)

        char_hidden = char_hidden.permute([1, 0, 2])
        # char_hidden: (SENT_MAXLEN, BATCH_SIZE, UNITS)

        # Step 2: sequence tagging using word-level LSTM
        word_lstm_out, _ = self.word_lstm(char_hidden)
        # word_lstm_out: (SENT_MAXLEN, BATCH_SIZE, UNITS)

        word_lstm_out = word_lstm_out.permute([1, 0, 2])
        # word_lstm_out: (BATCH_SIZE, SENT_MAXLEN, UNITS)

        char_hidden = char_hidden.permute([1, 0, 2])
        # char_hidden: (BATCH_SIZE, SENT_MAXLEN, UNITS)

        # Residual connection: contextual features plus the word's own vector.
        summed = word_lstm_out + char_hidden
        # summed: (BATCH_SIZE, SENT_MAXLEN, UNITS)

        return {name: linear(summed).permute([0, 2, 1])
                for name, linear in self.tasks.items()}

    def configure_optimizers(self):
        """Adam plus a ``MultiStepLR`` schedule whose milestones/gamma come from hparams."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, list(self.hparams.lr_milestones), gamma=self.hparams.lr_gamma)
        return [optimizer], [scheduler]


In [23]:
# Morphological features the model predicts, mapped to their number of classes
# (one output head each).
# Currently disabled: Abbr, Definite, HebExistential, HebSource, Mood,
# Polarity, Prefix, Reflex, VerbForm.
label_map = dict(
    (feature, ud.Token.class_size(feature))
    for feature in (
        'Case', 'Cconj', 'Det', 'Gender', 'HebBinyan', 'Number', 'Person', 'Pos',
        'PronGender', 'PronNumber', 'PronPerson', 'PronType',
        'R1', 'R2', 'R3', 'R4',
        'Tense', 'VerbType', 'Voice',
    )
)

In [91]:
# Training file (a random validation split is carved out of it) and the UD dev
# file, which serves as the held-out "test" set in this notebook.
UD_DIR = '../Hebrew_UD'
dataset = ConlluDataModule(
    f'{UD_DIR}/he_htb-ud-train.conllu',
    f'{UD_DIR}/he_htb-ud-dev.conllu',
    features=label_map,
    batch_size=32)
dataset.setup()

In [115]:
# Train the tagger; metrics go to Weights & Biases under the run name below.
SEED = 0
pl.seed_everything(SEED)  # seeds python, numpy and torch RNGs (e.g. weight initialisation)
model = SentenceModel(dataset.features, units=400, learning_rate=4e-3)
wandb_logger = pl.loggers.WandbLogger(project='rootem', group='ud-fix1', name='schedule_multistep_4')
trainer = pl.Trainer(gpus=1, max_epochs=20, logger=wandb_logger)
trainer.fit(model, dataset)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]



  | Name      | Type            | Params
----------------------------------------------
0 | embed     | Embedding       | 800 K 
1 | char_lstm | SumSharedBiLSTM | 2 M   
2 | word_lstm | SumBiLSTM       | 2 M   
3 | tasks     | ModuleDict      | 76 K  
4 | norm      | BatchNorm1d     | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [None]:
# Example Hebrew sentence for a quick qualitative check of the tagger.
sent = 'מילים כדורבנות התואמות בדיוק נמרץ את מה שהחתום מטה טען במשך שנים רבות'.split()
# Encode with word_maxlen=11 (the data module's default) and duplicate into a
# batch of two identical rows. NOTE(review): the duplication looks like a
# workaround for single-sample batches — confirm. (Optionally append .cuda().)
encoded_sent = encoding.wordlist2numpy(sent, word_maxlen=11)
text = torch.Tensor([encoded_sent, encoded_sent])

In [None]:
# Tag the example sentence with every head and print one column per label.
# eval() makes BatchNorm use its running statistics instead of this tiny batch's
# (and stops it from overwriting them); no_grad() skips building the autograd graph.
model.eval()
with torch.no_grad():
    outputs = model(text.to(model.device))
# One list of decoded labels per head, for the first sentence of the batch.
res = [[ud.Token.decode_label(label, idx) for idx in value.argmax(1)[0].cpu()]
       for label, value in outputs.items()]
# Header: head names truncated to 7 characters to keep the tab columns narrow.
print('Token', *[label[:7] for label in outputs], sep='\t')
for item in zip(sent, *res):
    print(*item, sep='\t')


In [None]:
# Compact view: token, POS, binyan and the predicted root letters R1-R4.
# `res` holds one prediction list per head, in the iteration order of
# `model.tasks` (the order of the model's output dict), so zipping with the
# head names recovers a name -> predictions mapping.
predicted = dict(zip(model.tasks, res))
roots = (predicted['R1'], predicted['R2'], predicted['R3'], predicted['R4'])
for token, pos_tag, binyan, *root in zip(sent, predicted['Pos'], predicted['HebBinyan'], *roots):
    print(token, pos_tag, binyan, ''.join(root), sep='\t')

In [None]:
# Learning-rate range test on a fresh, untrained model. Distinct names keep the
# trained `model` and its `trainer` intact for the evaluation cell.
lr_model = SentenceModel(dataset.features, units=300)
lr_trainer = pl.Trainer(gpus=1, max_epochs=10)
lr_finder = lr_trainer.lr_find(lr_model, datamodule=dataset, min_lr=1e-5, early_stop_threshold=None)

In [None]:
# Plot loss vs. learning rate from the range test; suggest=True marks the suggested learning rate.
fig = lr_finder.plot(suggest=True)

In [None]:
# Evaluate the trained model on the data module's test set (built from the file
# passed as `test_filename`; the dataset cell points it at the UD dev split).
test_loader = dataset.test_dataloader()
trainer.test(model, test_dataloaders=test_loader)

In [None]:
# Interactive aid: prints the signature and docstring of Trainer.test
# (e.g. to look up the `test_dataloaders` argument). Not needed for a full run.
help(trainer.test)