In [1]:
%load_ext autoreload
%autoreload 2

In [92]:
import numpy as np
import wandb
import torch
import torch.nn as nn
from torch.nn import functional as F

import pytorch_lightning as pl

import utils
import encoding



In [7]:
# Split/shape settings handed to encoding.load_sentences_split by load_dataset.
# Presumably the number of sentences held out for the test split -- confirm in encoding.
test_size = 300
# Presumably the padded sentence length (words) and word length (characters);
# the model's forward pass also relies on SENTENCE_MAXLEN when regrouping words.
SENTENCE_MAXLEN = 30
WORD_MAXLEN = 11

def load_dataset(kind):
    """Load one Hebrew UD treebank file and return its train/test split.

    Args:
        kind: Name of the treebank part; it is substituted into the
            ``he_htb-ud-{kind}.conllu`` file name (the notebook uses 'dev').

    Returns:
        ``((train_x, train_y), (test_x, test_y))`` exactly as produced by
        ``encoding.load_sentences_split``.
    """
    # Reset both global RNGs so every call starts from the same random state.
    np.random.seed(0)
    torch.manual_seed(0)

    conllu_path = f'../Hebrew_UD/he_htb-ud-{kind}.conllu'
    split = encoding.load_sentences_split(
        conllu_path,
        test_size,
        SENTENCE_MAXLEN,
        WORD_MAXLEN,
    )

    # Unpack to validate the (inputs, targets) pair structure before returning.
    (train_inputs, train_targets), (test_inputs, test_targets) = split
    return (train_inputs, train_targets), (test_inputs, test_targets)


# Load the 'dev' treebank file and keep its train/test split at module level.
kind = 'dev'
(train_x, train_y), (test_x, test_y) = load_dataset(kind)

In [114]:
# Number of rows in the character-embedding table (num_embeddings of nn.Embedding);
# every input id must be smaller than this.
# NOTE(review): 2000 looks like a generous upper bound -- confirm against encoding.
NUM_EMBEDDING = 2000

class SimpleLSTM(nn.Module):
    """Single-layer bidirectional LSTM whose two directions are summed.

    Summing (instead of concatenating) the forward and backward halves keeps
    the feature size equal to ``units``, so instances can be stacked freely.

    Args:
        units: Size of both the input features and the hidden state.
    """

    def __init__(self, units):
        super().__init__()
        self.units = units
        self.lstm = nn.LSTM(input_size=units, hidden_size=units, num_layers=1, batch_first=False, bidirectional=True)

    def forward(self, x):
        """Run the LSTM over a sequence-first input.

        Args:
            x: Tensor of shape (SEQ_LEN, BATCH, UNITS).

        Returns:
            A pair ``(lstm_out, hidden)``:
            lstm_out: (SEQ_LEN, BATCH, UNITS), per-step outputs with both
                directions summed.
            hidden: (BATCH, UNITS), final hidden states of both directions summed.
        """
        lstm_out, (hidden, _cell) = self.lstm(x)
        # lstm_out: (SEQ_LEN, BATCH, UNITS * 2)
        # hidden: (2, BATCH, UNITS) -- one entry per direction

        hidden = hidden[0] + hidden[1]
        # hidden: (BATCH, UNITS)

        forward_out, backward_out = torch.chunk(lstm_out, 2, dim=-1)
        # forward_out, backward_out: (SEQ_LEN, BATCH, UNITS)

        # No squeeze here: a dimensionless squeeze would silently drop a
        # batch or sequence axis of size 1 and change the output rank.
        lstm_out = forward_out + backward_out
        # lstm_out: (SEQ_LEN, BATCH, UNITS)

        return lstm_out, hidden
    
    
class IndependentModel(pl.LightningModule):
    """Character-to-word BiLSTM tagger with six independent output heads.

    Characters are embedded and summarised per word by two stacked
    character-level LSTMs; the resulting word vectors are tagged by two
    word-level LSTMs (the second one residual). Six linear heads then predict,
    for every word: pos, binyan and the four radicals r1..r4.

    Args:
        units: Width of the embedding, of every LSTM and of the head inputs.
    """

    def __init__(self, units):
        super().__init__()
        self.units = units

        self.embed = nn.Embedding(num_embeddings=NUM_EMBEDDING, embedding_dim=units)

        self.char_lstm1 = SimpleLSTM(units)
        self.char_lstm2 = SimpleLSTM(units)

        self.word_lstm1 = SimpleLSTM(units)
        self.word_lstm2 = SimpleLSTM(units)

        self.pos = nn.Linear(in_features=units, out_features=len(encoding.Classes.xpos))
        self.binyan = nn.Linear(in_features=units, out_features=len(encoding.Classes.HebBinyan))
        self.r1 = nn.Linear(in_features=units, out_features=len(encoding.RADICALS))
        self.r2 = nn.Linear(in_features=units, out_features=len(encoding.RADICALS))
        self.r3 = nn.Linear(in_features=units, out_features=len(encoding.RADICALS))
        self.r4 = nn.Linear(in_features=units, out_features=len(encoding.RADICALS))

    def forward(self, x):
        """Compute per-word logits for every head.

        Args:
            x: Integer character ids of shape (BATCH_SIZE, SENT_LEN, WORD_LEN).

        Returns:
            List of six tensors, each (BATCH_SIZE, SENT_LEN, num_classes), in
            the order pos, binyan, r1, r2, r3, r4.
        """
        x = x.to(torch.int64)
        # x: (BATCH_SIZE, SENT_LEN, WORD_LEN)

        # Taken from the input rather than a global so any sentence length works.
        sent_len = x.shape[-2]

        # Step 0: character embedding

        x = x.reshape(-1, x.shape[-1])
        # x: (BATCH_SIZE * SENT_LEN, WORD_LEN) -- rows are batch-major

        embeds = self.embed(x)
        # embeds: (BATCH_SIZE * SENT_LEN, WORD_LEN, UNITS)

        embeds = embeds.permute(1, 0, 2)
        # embeds: (WORD_LEN, BATCH_SIZE * SENT_LEN, UNITS) -- sequence-first for the LSTM

        # STEP 1: character-level lstm -> word embedding

        char_lstm_out, _ = self.char_lstm1(embeds)
        # char_lstm_out: (WORD_LEN, BATCH_SIZE * SENT_LEN, UNITS)

        _, char_hidden = self.char_lstm2(char_lstm_out)
        # char_hidden: (BATCH_SIZE * SENT_LEN, UNITS)

        # Rows are ordered batch-major, so they must be regrouped as
        # (BATCH_SIZE, SENT_LEN, UNITS) first and only then transposed;
        # reshaping straight to (SENT_LEN, BATCH_SIZE, UNITS) would mix words
        # belonging to different sentences.
        char_hidden = char_hidden.reshape(-1, sent_len, self.units).permute(1, 0, 2)
        # char_hidden: (SENT_LEN, BATCH_SIZE, UNITS)

        # STEP 2: sequence tagging using word-level lstm

        word_lstm_out, _ = self.word_lstm1(char_hidden)
        # word_lstm_out: (SENT_LEN, BATCH_SIZE, UNITS)

        # Residual second word-level layer. Out-of-place add: the LSTM keeps
        # its input for the backward pass, so updating it in place can trigger
        # autograd's "modified by an inplace operation" error.
        word_lstm_out = word_lstm_out + self.word_lstm2(word_lstm_out)[0]
        # word_lstm_out: (SENT_LEN, BATCH_SIZE, UNITS)

        word_lstm_out = word_lstm_out.permute(1, 0, 2)
        # word_lstm_out: (BATCH_SIZE, SENT_LEN, UNITS)

        heads = [self.pos, self.binyan, self.r1, self.r2, self.r3, self.r4]
        return [head(word_lstm_out) for head in heads]

    def _multi_head_loss(self, y_hat, y):
        """Sum the cross-entropy of every head against its own targets.

        Args:
            y_hat: List of logits, each (BATCH_SIZE, SENT_LEN, num_classes).
            y: Either a tensor of shape (BATCH_SIZE, NUM_HEADS, SENT_LEN), as a
                DataLoader produces from per-sample (NUM_HEADS, SENT_LEN)
                targets, or a sequence with one (BATCH_SIZE, SENT_LEN) target
                per head, in the same order as ``y_hat``.

        Raises:
            ValueError: If the number of targets differs from the number of heads.
        """
        targets = y.unbind(dim=1) if torch.is_tensor(y) else list(y)
        if len(targets) != len(y_hat):
            raise ValueError(f'expected {len(y_hat)} target tensors, got {len(targets)}')

        # NOTE(review): every position contributes to the loss, padding
        # included -- consider ignore_index if encoding reserves a pad label.
        return sum(
            F.cross_entropy(
                logits.reshape(-1, logits.shape[-1]),
                torch.as_tensor(target, device=logits.device).reshape(-1).long(),
            )
            for logits, target in zip(y_hat, targets)
        )

    def training_step(self, batch, batch_nb):
        """Return the summed multi-head loss for one ``(x, y)`` training batch."""
        x, y = batch
        y_hat = self(x)
        loss = self._multi_head_loss(y_hat, y)
        return { 'loss': loss }

    def configure_optimizers(self):
        """Use Adam over all parameters with a learning rate of 0.003."""
        return torch.optim.Adam(self.parameters(), lr=0.003)

    def test_step(self, batch, batch_idx):
        """Return the summed multi-head loss for one ``(x, y)`` test batch."""
        x, y = batch
        y_hat = self(x)
        return {'test_loss': self._multi_head_loss(y_hat, y)}

    def test_epoch_end(self, outputs):
        """Average the per-batch test losses collected by ``test_step``."""
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        return {'avg_test_loss': avg_loss }


In [115]:
# Build the tagger with units=200 (embedding size and LSTM width).
model = IndependentModel(200)

In [116]:
# Smoke test: forward pass on the first 64 training sentences.
y = model(torch.tensor(train_x[:64, :, :]))

In [117]:
# One entry per output head: (batch, sentence length, number of classes).
[a.shape for a in y]

[torch.Size([64, 30, 16]),
 torch.Size([64, 30, 8]),
 torch.Size([64, 30, 27]),
 torch.Size([64, 30, 27]),
 torch.Size([64, 30, 27]),
 torch.Size([64, 30, 27])]

In [50]:
# Shape of the first target array for the same 64 sentences, to compare with the head outputs.
train_y[0, :64].shape

(64, 30)

In [None]:
# Train on the split loaded above. The model needs its hidden size, and
# Trainer.fit expects dataloaders rather than raw arrays or model outputs.
BATCH_SIZE = 64

model = IndependentModel(200)

# train_y is indexed target-first (train_y[0] holds one target for all
# sentences), so move the sentence axis to the front for batching. Each batch
# is then (x, y) with y shaped (batch, targets, words).
# NOTE(review): assumes train_y is a (targets, sentences, words) array -- confirm.
train_inputs = torch.as_tensor(train_x)
train_targets = torch.as_tensor(train_y).transpose(0, 1)
train_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(train_inputs, train_targets),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

trainer = pl.Trainer()
trainer.fit(model, train_loader)