In [1]:
import os
import numpy as np

import dataset

import torch

import torch.nn as nn
# Everything below hard-codes torch.device("cuda"); fail fast with a clear
# message (a bare assert is stripped under `python -O` and gives no reason).
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required: the model and batches are placed on torch.device('cuda').")

In [2]:
MAXLEN = 64


def load_data(source, maxlen=MAXLEN, validation=0.1):
    """Load the given corpora from the `texts` directory and split them.

    Args:
        source: names joined onto the 'texts' directory to form the paths
            handed to dataset.load_data.
        maxlen: sequence length limit forwarded to dataset.load_data.
        validation: forwarded to dataset.load_data as its split argument
            (presumably the held-out fraction — confirm in `dataset`).

    Returns:
        The (train, valid) pair produced by dataset.load_data.
    """
    paths = []
    for name in source:
        paths.append(os.path.join('texts', name))
    train_split, valid_split = dataset.load_data(paths, validation, maxlen=maxlen)
    return (train_split, valid_split)

# The three mixed corpora are loaded as one dataset; the modern corpus is
# loaded separately with validation=0.2 instead of the default 0.1.
data_mix = load_data(source=['poetry', 'rabanit', 'pre_modern'])
data_modern = load_data(['modern'], validation=0.2)

In [46]:
# Hidden width shared by the embedding, both BiLSTMs and the dense layer.
UNITS = 300

# Input vocabulary size and the class count of each output head, taken from
# the lookup tables exposed by the project-local `dataset` module.
LETTERS_SIZE, NIQQUD_SIZE, DAGESH_SIZE, SIN_SIZE = (
    len(table)
    for table in (dataset.letters_table, dataset.niqqud_table,
                  dataset.dagesh_table, dataset.sin_table)
)

class Model(nn.Module):
    """Letter-level tagger predicting three diacritic labels per position.

    Architecture: embedding -> BiLSTM -> BiLSTM (with a residual connection)
    -> dense+ReLU added back onto the embeddings -> one linear head each for
    niqqud ('N'), dagesh ('D') and sin ('S'). The forward and backward halves
    of each BiLSTM output are summed rather than concatenated, so every
    hidden tensor keeps width `units`.

    All constructor arguments are optional. With no arguments, the sizes come
    from the module-level constants and dropout is disabled.

    Args:
        units: hidden width; defaults to the module-level UNITS.
        dropout: drop probability applied to each merged BiLSTM output.
            nn.LSTM's own `dropout` argument only acts between stacked layers
            and has no effect with num_layers=1, so it is applied explicitly
            here. The default 0.0 makes it a no-op.
        letters_size: input vocabulary size; defaults to LETTERS_SIZE.
        niqqud_size, dagesh_size, sin_size: class counts of the three heads;
            default to NIQQUD_SIZE, DAGESH_SIZE and SIN_SIZE.
    """

    def __init__(self, units=None, dropout=0.0, letters_size=None,
                 niqqud_size=None, dagesh_size=None, sin_size=None):
        super().__init__()
        # Resolve defaults at construction time so the module-level constants
        # are only needed when the caller does not pass explicit sizes.
        units = UNITS if units is None else units
        letters_size = LETTERS_SIZE if letters_size is None else letters_size
        niqqud_size = NIQQUD_SIZE if niqqud_size is None else niqqud_size
        dagesh_size = DAGESH_SIZE if dagesh_size is None else dagesh_size
        sin_size = SIN_SIZE if sin_size is None else sin_size

        self.embed = nn.Embedding(num_embeddings=letters_size, embedding_dim=units)
        self.lstm1 = nn.LSTM(input_size=units, hidden_size=units, num_layers=1, batch_first=True, bidirectional=True)
        self.lstm2 = nn.LSTM(input_size=units, hidden_size=units, num_layers=1, batch_first=True, bidirectional=True)
        # Parameter-free, so state_dict keys are unchanged by adding it.
        self.drop = nn.Dropout(dropout)

        self.dense = nn.Linear(in_features=units, out_features=units)
        self.act = nn.ReLU()

        self.niqqud = nn.Linear(in_features=units, out_features=niqqud_size)
        self.dagesh = nn.Linear(in_features=units, out_features=dagesh_size)
        self.sin = nn.Linear(in_features=units, out_features=sin_size)

    def forward(self, x):
        """Return unnormalised class scores for every position of `x`.

        Args:
            x: int64 tensor of letter indices. The LSTMs use batch_first=True,
                so this is presumably (batch, seq) — TODO confirm at callers.

        Returns:
            dict mapping 'N', 'D', 'S' to tensors of shape
            (batch, classes, seq): classes on dim 1, the layout
            nn.CrossEntropyLoss expects for per-position targets.
        """
        embeds = self.embed(x)

        lstm_out, _ = self.lstm1(embeds)
        # Sum the two directions so the width stays `units`.
        left, right = torch.chunk(lstm_out, 2, dim=-1)
        merge = self.drop(left + right)

        lstm_out, _ = self.lstm2(merge)
        left, right = torch.chunk(lstm_out, 2, dim=-1)
        # Residual connection around the second BiLSTM.
        merge = self.drop(left + right) + merge

        mid = embeds + self.act(self.dense(merge))

        # (batch, seq, classes) -> (batch, classes, seq) for CrossEntropyLoss.
        niqqud = self.niqqud(mid).permute([0, 2, 1])
        dagesh = self.dagesh(mid).permute([0, 2, 1])
        sin = self.sin(mid).permute([0, 2, 1])
        return {'N': niqqud, 'D': dagesh, 'S': sin}

# Fix the RNG before building the model so weight initialisation is repeatable.
SEED = 0
torch.manual_seed(SEED)

# Move the model to the GPU before constructing the optimizer, as the
# torch.optim documentation recommends.
device = torch.device("cuda")
model = Model().to(device)

# Target index 0 is excluded from the loss (presumably padding — confirm
# against the `dataset` tables).
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-3)

model

Model(
  (embed): Embedding(44, 300)
  (lstm1): LSTM(300, 300, batch_first=True, bidirectional=True)
  (lstm2): LSTM(300, 300, batch_first=True, bidirectional=True)
  (dense): Linear(in_features=300, out_features=300, bias=True)
  (act): ReLU()
  (niqqud): Linear(in_features=300, out_features=16, bias=True)
  (dagesh): Linear(in_features=300, out_features=3, bias=True)
  (sin): Linear(in_features=300, out_features=4, bias=True)
)

In [59]:
def sanity(data=None):
    """Smoke-test a forward pass of the module-level `model` on one example.

    Prints each head's output shape, the shape of the matching niqqud
    targets, and the raw score tensors. No parameters are updated, and the
    model's train/eval mode is restored afterwards.

    Args:
        data: (train, valid) pair as returned by `load_data`; defaults to the
            module-level `data_modern`. Only the training split is used.
    """
    train, _ = data_modern if data is None else data
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # The model lives on `device`, so the input must be moved there too.
            inputs = torch.from_numpy(train.normalized[:1]).to(torch.int64).to(device)
            tag_scores = model(inputs)
            # forward() returns a dict of per-head tensors, not a single tensor.
            for name, scores in tag_scores.items():
                print(name, tuple(scores.shape))
            # Targets taken from the same split and row as the input.
            print(train.niqqud[:1].shape)
            print(tag_scores)
    finally:
        model.train(was_training)

torch.Size([1, 64, 16])
(1, 64)
tensor([[[ 0.0443,  0.1131,  0.1651,  ...,  0.1137,  0.0319,  0.1902],
         [ 0.1396, -0.1169,  0.2076,  ...,  0.0352,  0.2090,  0.2759],
         [ 0.1007, -0.1327,  0.1329,  ..., -0.0743,  0.1096, -0.0350],
         ...,
         [ 0.3901,  0.0842,  0.2003,  ...,  0.3017,  0.1579, -0.1293],
         [ 0.3616,  0.0776,  0.2187,  ...,  0.2765,  0.1573, -0.1133],
         [ 0.2803,  0.0605,  0.2528,  ...,  0.2360,  0.1469, -0.0808]]])


In [39]:
BATCH_SIZE = 32

def batch(a, batch_size=None):
    """Split array `a` along axis 0 into int64 tensors of exactly `batch_size` rows.

    Trailing rows that would form a short final batch are dropped, so every
    returned batch has the same number of rows. If `a` has fewer rows than
    one batch, an empty tuple is returned. Plain `Tensor.split` would instead
    yield a single empty tensor, which the training loops would feed to the
    model.

    Args:
        a: numpy array whose first axis indexes examples.
        batch_size: rows per batch; defaults to BATCH_SIZE, resolved at call time.

    Returns:
        tuple of torch.int64 tensors, each of shape (batch_size, *a.shape[1:]).
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    usable = a.shape[0] // batch_size * batch_size
    if usable == 0:
        return ()
    return torch.from_numpy(a[:usable]).to(torch.int64).split(batch_size)

def accuracy(output, ybatch):
    """Per-position accuracy over the positions whose target label is non-zero.

    Args:
        output: numpy array of class scores with the class dimension on axis 1
            (the layout Model.forward produces after its permute).
        ybatch: integer numpy array of target labels shaped like `output`
            with axis 1 removed.

    Returns:
        numpy float: (# positions where the argmax prediction equals a
        non-zero target) / (# positions with a non-zero target). Label 0 is
        excluded, matching the loss's ignore_index=0. If every target is 0,
        the division is 0/0 and numpy yields nan.
    """
    predicted = np.argmax(output, axis=1)
    # Wherever a hit is counted, predicted == target, so requiring
    # predicted != 0 is the same as requiring target != 0.
    hits = np.logical_and(predicted == ybatch, predicted != 0).sum()
    labelled = np.sum(ybatch != 0)
    return hits / labelled

def fit(data, epochs=1, log_every=20):
    """Train the module-level `model` on the training split of `data`.

    Relies on the globals `model`, `optimizer`, `criterion` and `device`.
    The training split is cut into fixed-size batches once (see `batch`), and
    the batches are visited in the same order on every epoch.

    Args:
        data: (train, valid) pair as returned by `load_data`; only `train`
            is used.
        epochs: number of passes over the training batches.
        log_every: print a progress line whenever the batch index is a
            multiple of this value. The line shows, per head, the mean batch
            accuracy since the previous progress line, plus the current
            batch's summed loss.
    """
    train, _ = data
    x_all = batch(train.normalized)
    y_all = [{'N': n, 'D': d, 'S': s}
             for n, d, s in zip(batch(train.niqqud), batch(train.dagesh), batch(train.sin))]
    total = len(x_all)

    # Train-mode-dependent layers (e.g. dropout) must be active, even if an
    # earlier call left the model in eval mode.
    model.train()
    for _ in range(epochs):
        accs = {'N': [], 'D': [], 'S': []}
        for i, (x, y) in enumerate(zip(x_all, y_all)):
            x = x.to(device)
            y = {k: v.to(device) for k, v in y.items()}

            optimizer.zero_grad()
            # The model's outputs are already on `device`; no extra copy is needed.
            outputs = model(x)
            # One cross-entropy term per head, summed into a single objective.
            loss = sum(criterion(outputs[k], y[k]) for k in outputs)
            loss.backward()
            optimizer.step()

            # Detach before converting: the scores still carry autograd history.
            scores = {k: v.detach().cpu().numpy() for k, v in outputs.items()}
            targets = {k: v.cpu().numpy() for k, v in y.items()}
            for k in scores:
                accs[k].append(accuracy(scores[k], targets[k]))

            if i % log_every == 0:
                print("{:4}/{}".format(i, total), end=' ')
                for k in accs:
                    print("{}_acc: {:.4f}".format(k, np.mean(accs[k])), end=' ')
                print("Loss: {:.4f}".format(loss.item()), end='\r')
                accs = {'N': [], 'D': [], 'S': []}
        print()

def validate(data):
    """Evaluate the module-level `model` on the validation split of `data`.

    Relies on the globals `model`, `criterion` and `device`; no parameters
    are updated. Prints, per head, the mean of the per-batch accuracies, then
    the mean per-batch summed loss. The model's train/eval mode is restored
    afterwards.

    Args:
        data: (train, valid) pair as returned by `load_data`; only `valid`
            is used.
    """
    _, valid = data
    x_all = batch(valid.normalized)
    y_all = [{'N': n, 'D': d, 'S': s}
             for n, d, s in zip(batch(valid.niqqud), batch(valid.dagesh), batch(valid.sin))]

    accs = {'N': [], 'D': [], 'S': []}
    losses = []
    was_training = model.training
    # Evaluation must not apply train-only behaviour such as dropout.
    model.eval()
    try:
        with torch.no_grad():
            for x, y in zip(x_all, y_all):
                x = x.to(device)
                y = {k: v.to(device) for k, v in y.items()}

                # Outputs are already on `device`; no extra copy is needed.
                outputs = model(x)
                loss = sum(criterion(outputs[k], y[k]) for k in outputs)

                scores = {k: v.cpu().numpy() for k, v in outputs.items()}
                targets = {k: v.cpu().numpy() for k, v in y.items()}
                for k in scores:
                    accs[k].append(accuracy(scores[k], targets[k]))
                losses.append(loss.item())
    finally:
        model.train(was_training)

    for k in accs:
        print("{}_acc: {:.4f}".format(k, np.mean(accs[k])), end=' ')
    print("Loss: {:.4f}".format(np.mean(losses)))

In [45]:
# Train on the mixed corpora first, then on the modern corpus, validating
# on each dataset's own held-out split right after training on it.
for dataset_pair in (data_mix, data_modern):
    fit(dataset_pair)
    validate(dataset_pair)

4360/4372 N_acc: 0.9495 D_acc: 0.9845 S_acc: 0.9990 Loss: 0.1880
N_acc: 0.9232 D_acc: 0.9737 S_acc: 0.9984 Loss: 0.3225
 460/467 N_acc: 0.9686 D_acc: 0.9880 S_acc: 0.9995 Loss: 0.1225
N_acc: 0.9559 D_acc: 0.9853 S_acc: 0.9991 Loss: 0.1864
