In [1]:
%load_ext autoreload
%autoreload 2

In [56]:
import numpy as np
import wandb
import torch
import torch.nn as nn
from torch.nn import functional as F

import pytorch_lightning as pl

import utils
import encoding

In [3]:
class ConlluDataModule(pl.LightningDataModule):
    """LightningDataModule over the sentences of a single CoNLL-U file.

    The whole file is encoded once in ``__init__``; ``setup`` then randomly
    splits it into a training part and a ``val_size``-sentence validation
    part. There is no test split: the test hooks raise ``NotImplementedError``.

    Args:
        conllu_filename: path of the ``.conllu`` file handed to
            ``encoding.load_sentences``.
        batch_size: batch size used by both dataloaders.
        val_size: number of sentences held out for validation.
        seed: optional seed for the train/validation split; ``None`` draws
            the permutation from torch's global RNG (the previous behavior).
    """

    def __init__(self, conllu_filename, batch_size=64, val_size=300, seed=None):
        super().__init__()
        SENTENCE_MAXLEN = 30
        WORD_MAXLEN = 11
        self.batch_size = batch_size
        self.val_size = val_size
        self.seed = seed
        # NOTE(review): presumably data_x holds (N, SENTENCE_MAXLEN, WORD_MAXLEN)
        # character ids and data_y one (N, SENTENCE_MAXLEN) label array per
        # task — confirm against encoding.load_sentences.
        data_x, data_y = encoding.load_sentences(conllu_filename, SENTENCE_MAXLEN, WORD_MAXLEN)
        # Targets become int64 class indices, the dtype F.cross_entropy needs.
        self.data = torch.utils.data.TensorDataset(
            torch.Tensor(data_x), *[torch.Tensor(y).to(torch.int64) for y in data_y])

    def prepare_data(self):
        # Nothing to download or cache; the file is already read in __init__.
        pass

    def setup(self, stage=None):
        """Create the train/validation split for the ``fit`` stage (or ``None``).

        Raises:
            ValueError: if ``val_size`` leaves no sentences for training.
            NotImplementedError: for ``stage == 'test'`` (no test split).
        """
        if stage == 'fit' or stage is None:
            total = len(self.data)
            if not 0 <= self.val_size < total:
                raise ValueError(
                    f'val_size={self.val_size} must be in [0, {total}) for a dataset of {total} sentences')
            train_size = total - self.val_size
            split_kwargs = {}
            if self.seed is not None:
                split_kwargs['generator'] = torch.Generator().manual_seed(self.seed)
            # random_split returns subsets in the order of the lengths list,
            # so the training length must come first to match the unpacking.
            self.data_train, self.data_val = torch.utils.data.random_split(
                self.data, [train_size, self.val_size], **split_kwargs)
            self.dims = tuple(self.data_train[0][0].shape)

        if stage == 'test':
            raise NotImplementedError('ConlluDataModule has no test split')

    def train_dataloader(self):
        # Reshuffle every epoch; random_split only permutes once.
        return torch.utils.data.DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.data_val, batch_size=self.batch_size)

    def test_dataloader(self):
        raise NotImplementedError('ConlluDataModule has no test split')

In [70]:
NUM_EMBEDDING = 2_000  # rows of the character embedding table; NOTE(review): assumes every character id from `encoding` is < 2000

class SumBiLSTM(nn.Module):
    """Single-layer bidirectional LSTM whose two directions are summed.

    A plain ``nn.LSTM(bidirectional=True)`` concatenates forward and backward
    features; this module adds them instead, so both outputs keep the input
    width ``units``. Input is sequence-first (``batch_first=False``).

    Args:
        units: input size and hidden size of the LSTM.

    Forward:
        x: (SEQ_LEN, BATCH, UNITS).

    Returns:
        (lstm_out, hidden): ``lstm_out`` is (SEQ_LEN, BATCH, UNITS), the
        per-step sum of forward and backward outputs; ``hidden`` is
        (BATCH, UNITS), the sum of both directions' final hidden states.
    """

    def __init__(self, units):
        super().__init__()
        self.units = units
        self.lstm = nn.LSTM(input_size=units, hidden_size=units, num_layers=1,
                            batch_first=False, bidirectional=True)

    def forward(self, x):
        lstm_out, (hidden, _cell) = self.lstm(x)
        # lstm_out: (SEQ_LEN, BATCH, 2 * UNITS), forward direction first
        # hidden:   (2, BATCH, UNITS), one final state per direction

        hidden = hidden[0] + hidden[1]
        # hidden: (BATCH, UNITS)

        forward_out, backward_out = torch.chunk(lstm_out, 2, dim=-1)
        # No squeeze: it would also drop a size-1 batch or sequence axis and
        # break the (SEQ_LEN, BATCH, UNITS) contract callers rely on.
        lstm_out = forward_out + backward_out
        # lstm_out: (SEQ_LEN, BATCH, UNITS)

        return lstm_out, hidden
    
    
class IndependentModel(pl.LightningModule):
    """Char-to-word BiLSTM tagger with one independent linear head per task.

    Each word is encoded by two stacked character-level ``SumBiLSTM`` layers
    (the final summed hidden state is the word vector). The sentence is then
    encoded by two word-level ``SumBiLSTM`` layers, the second one residual.
    Six linear heads predict pos, binyan and r1..r4 for every word.

    Args:
        units: width of the character embeddings and of every LSTM layer.
    """

    def __init__(self, units):
        super().__init__()
        self.units = units

        self.embed = nn.Embedding(num_embeddings=NUM_EMBEDDING, embedding_dim=units)

        self.char_lstm1 = SumBiLSTM(units)
        self.char_lstm2 = SumBiLSTM(units)

        self.word_lstm1 = SumBiLSTM(units)
        self.word_lstm2 = SumBiLSTM(units)

        self.pos = nn.Linear(in_features=units, out_features=len(encoding.Classes.xpos))
        self.binyan = nn.Linear(in_features=units, out_features=len(encoding.Classes.HebBinyan))
        self.r1 = nn.Linear(in_features=units, out_features=len(encoding.RADICALS))
        self.r2 = nn.Linear(in_features=units, out_features=len(encoding.RADICALS))
        self.r3 = nn.Linear(in_features=units, out_features=len(encoding.RADICALS))
        self.r4 = nn.Linear(in_features=units, out_features=len(encoding.RADICALS))

    def forward(self, x):
        """Return per-word logits for every head.

        Args:
            x: character ids of shape (BATCH_SIZE, SENT_MAXLEN, WORD_MAXLEN);
               cast to int64 for the embedding lookup.

        Returns:
            list of six tensors in the order [pos, binyan, r1, r2, r3, r4],
            each (BATCH_SIZE, NUM_CLASSES, SENT_MAXLEN). The class axis is
            second, which is the layout F.cross_entropy expects.
        """
        batch_size, sent_maxlen, word_maxlen = x.shape

        x = x.to(torch.int64)

        # Step 0: character embedding. Every word becomes one LSTM sequence;
        # the flattened word index is batch-major: b * SENT_MAXLEN + s.
        x = x.reshape(batch_size * sent_maxlen, word_maxlen)
        embeds = self.embed(x)
        # embeds: (BATCH_SIZE * SENT_MAXLEN, WORD_MAXLEN, UNITS)

        # SumBiLSTM is sequence-first, so move the character axis to the front.
        embeds = embeds.permute(1, 0, 2)
        # embeds: (WORD_MAXLEN, BATCH_SIZE * SENT_MAXLEN, UNITS)

        # Step 1: character-level LSTMs -> one vector per word.
        char_lstm_out, _ = self.char_lstm1(embeds)
        # char_lstm_out: (WORD_MAXLEN, BATCH_SIZE * SENT_MAXLEN, UNITS)
        _, char_hidden = self.char_lstm2(char_lstm_out)
        # char_hidden: (BATCH_SIZE * SENT_MAXLEN, UNITS)

        # Undo the batch-major flattening first, then make words the sequence
        # axis. Reshaping directly to (SENT_MAXLEN, -1, UNITS) would interleave
        # words belonging to different sentences.
        word_embeds = char_hidden.reshape(batch_size, sent_maxlen, self.units).permute(1, 0, 2)
        # word_embeds: (SENT_MAXLEN, BATCH_SIZE, UNITS)

        # Step 2: sequence tagging with the word-level LSTMs.
        word_lstm_out, _ = self.word_lstm1(word_embeds)
        word_lstm_out = word_lstm_out + self.word_lstm2(word_lstm_out)[0]
        # word_lstm_out: (SENT_MAXLEN, BATCH_SIZE, UNITS)

        word_lstm_out = word_lstm_out.permute(1, 0, 2)
        # word_lstm_out: (BATCH_SIZE, SENT_MAXLEN, UNITS)

        heads = [self.pos, self.binyan, self.r1, self.r2, self.r3, self.r4]
        return [head(word_lstm_out).permute(0, 2, 1) for head in heads]

    def compute_metrics(self, batch):
        """Compute the summed loss and per-head accuracy for one batch.

        Args:
            batch: ``(x, *ys)`` with one int64 target tensor per head.

        Returns:
            (loss, accuracy): the sum of the per-head cross-entropy losses,
            and a dict mapping each name from ``encoding.names()`` to that
            head's accuracy.
        """
        x, *ys = batch
        ys_hat = self(x)
        # NOTE(review): every word position, including any sentence padding,
        # contributes to loss and accuracy — confirm whether padding should
        # be masked out.
        loss = sum(F.cross_entropy(y_hat, y) for y_hat, y in zip(ys_hat, ys))
        # accuracy compares predicted *labels* with targets, so reduce the
        # logits over the class axis (dim 1) first.
        accuracy = {name: pl.metrics.functional.accuracy(y_hat.argmax(dim=1), y)
                    for y_hat, y, name in zip(ys_hat, ys, encoding.names())}
        return loss, accuracy

    def training_step(self, batch, batch_nb):
        """Minimize the summed loss; log it and each head's accuracy."""
        loss, accuracy = self.compute_metrics(batch)
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss, prog_bar=True)
        for k, v in accuracy.items():
            result.log(f't_{k}_acc', v, prog_bar=True)
        return result

    def validation_step(self, batch, batch_idx):
        """Log validation loss and each head's accuracy."""
        loss, accuracy = self.compute_metrics(batch)
        result = pl.EvalResult()
        result.log('val_loss', loss, prog_bar=True)
        for k, v in accuracy.items():
            result.log(f'val_{k}_acc', v, prog_bar=True)
        return result

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.003)

    def test_step(self, batch, batch_idx):
        """Return the summed cross-entropy over all heads for one test batch."""
        x, *ys = batch
        ys_hat = self(x)
        # F.cross_entropy takes single tensors, so score each head separately.
        loss = sum(F.cross_entropy(y_hat, y) for y_hat, y in zip(ys_hat, ys))
        return {'test_loss': loss}

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        return {'avg_test_loss': avg_loss }


In [46]:
# NOTE(review): this is the UD *dev* split of he_htb — confirm it is intended as the training corpus.
CONLLU_PATH = '../Hebrew_UD/he_htb-ud-dev.conllu'

dataset = ConlluDataModule(CONLLU_PATH)
dataset.setup()

In [None]:
UNITS = 200
MAX_EPOCHS = 1

model = IndependentModel(UNITS)
trainer = pl.Trainer(max_epochs=MAX_EPOCHS)
trainer.fit(model, dataset)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

   | Name       | Type      | Params
------------------------------------------
0  | embed      | Embedding | 400 K 
1  | char_lstm1 | SumBiLSTM | 643 K 
2  | char_lstm2 | SumBiLSTM | 643 K 
3  | word_lstm1 | SumBiLSTM | 643 K 
4  | word_lstm2 | SumBiLSTM | 643 K 
5  | pos        | Linear    | 3 K   
6  | binyan     | Linear    | 1 K   
7  | r1         | Linear    | 5 K   
8  | r2         | Linear    | 5 K   
9  | r3         | Linear    | 5 K   
10 | r4         | Linear    | 5 K   


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…