In [1]:
import os

import numpy as np

import wandb

import utils

# Silence wandb's console log statements for every wandb call made below.
os.environ['WANDB_SILENT'] = 'true'

# Number of rows in the character-embedding table (Model.embed). Inputs are raw
# Unicode code points (word2numpy uses ord), so every input character must
# satisfy ord(c) < NUM_EMBEDDING; the Hebrew letters (U+05D0..U+05EA) fit.
NUM_EMBEDDING = 2000
def word2numpy(txt):
    """Encode a string as a 1-D numpy array of its Unicode code points (via ord)."""
    return np.array(list(map(ord, txt)))

def wordlist2numpy(lines, maxlen=12):
    """Encode a list of words as a padded 2-D array of code points.

    Each word is converted with word2numpy, then all words are passed to
    ``utils.pad_sequences`` with ``dtype=int`` and padding value 0.

    Args:
        lines: iterable of strings, one word per entry.
        maxlen: target sequence length handed to ``utils.pad_sequences``.
            Defaults to 12, the length every caller in this notebook uses.

    Returns:
        The result of ``utils.pad_sequences``, presumably an int array of shape
        (len(lines), maxlen). TODO confirm, including which end of a word
        longer than ``maxlen`` gets truncated.
    """
    encoded = [word2numpy(line) for line in lines]
    return utils.pad_sequences(encoded, maxlen=maxlen, dtype=int, value=0)

# Label vocabularies for every predicted feature. A label's position in its list
# is its class id (see to_category / from_category), so reordering a list
# changes what the trained model's output indices mean.

# Root-letter vocabulary: '.' is class 0 and is stripped by predict() when it
# joins R1..R4 into a root, so it presumably marks an unused radical slot;
# then the 22 letters of the string; then geresh-marked letters and sin
# (shin with sin dot).
RADICALS = ['.'] + list('אבגדהוזחטיכלמנסעפצקרשת') + ["ג'", "ז'", "צ'", 'שׂ']

# Verb patterns (binyanim): Pa'al, Nif'al, Pi'el, Pu'al, Hif'il, Huf'al, Hitpa'el.
BINYAN = 'פעל נפעל פיעל פועל הפעיל הופעל התפעל'.split()
# Tense: past, present, future, imperative.
TENSE = 'עבר הווה עתיד ציווי'.split()
# Despite the name, these values are grammatical person: first, second, third.
VOICE = 'ראשון שני שלישי'.split()
# Gender: masculine, feminine.
GENDER = 'זכר נקבה'.split()
# Number: singular, plural.
PLURAL = 'יחיד רבים'.split()

# Output-head keys. Dataset feature columns are matched to these by position in
# list_of_lists_to_category, and the training loop iterates them in this order.
NAMES = ['B', 'T', 'V', 'G', 'P', 'R1', 'R2', 'R3', 'R4']
# Head key -> label vocabulary; the four radical slots share RADICALS.
FEATURES = {
    'B': BINYAN,
    'T': TENSE,
    'V': VOICE,
    'G': GENDER,
    'P': PLURAL,
    'R1': RADICALS,
    'R2': RADICALS,
    'R3': RADICALS,
    'R4': RADICALS,
}

def to_category(name, b):
    """Return the class id of label `b` in feature `name`'s vocabulary.

    Raises ValueError (from list.index) when `b` is not in the vocabulary.
    """
    vocabulary = FEATURES[name]
    return vocabulary.index(b)

def from_category(name, index):
    """Inverse of to_category: map class id `index` back to its label string."""
    vocabulary = FEATURES[name]
    label = vocabulary[index]
    return label

def list_to_category(name, bs):
    """Encode a sequence of labels for feature `name` as a numpy array of class ids."""
    class_ids = list(map(lambda label: to_category(name, label), bs))
    return np.array(class_ids)

def list_from_category(name, indexes):
    """Decode a sequence of class ids for feature `name` back into label strings."""
    decoded = []
    for class_id in indexes:
        decoded.append(from_category(name, class_id))
    return decoded

def list_of_lists_to_category(items):
    """Encode one label column per feature into {head key: class-id array}.

    Columns are paired with NAMES by position. zip stops at the shorter of the
    two, so surplus columns, or missing trailing features, are silently ignored.
    """
    encoded = {}
    for name, column in zip(NAMES, items):
        encoded[name] = list_to_category(name, column)
    return encoded

In [2]:
import torch

import torch.nn as nn
# Tensors and the model are moved with .cuda() throughout (see to_device), so
# fail fast here rather than deep inside training when no GPU is available.
assert torch.cuda.is_available()

def to_device(d):
    """Move `d` to the default CUDA device.

    Anything exposing a ``.cuda()`` method (tensor, module) is moved directly.
    Otherwise `d` is treated as a mapping, and a new dict is returned whose
    values have each been moved with ``.cuda()``.
    """
    if not hasattr(d, 'cuda'):
        return {key: value.cuda() for key, value in d.items()}
    return d.cuda()

class Model(nn.Module):
    """Character-level bidirectional LSTM with one linear head per feature.

    Character ids are embedded and run through a single-layer bidirectional
    LSTM. The final hidden states of the two directions are summed, and the
    resulting vector feeds every head. forward() returns
    {head key: unnormalised scores}; fit() feeds these to nn.CrossEntropyLoss.
    """

    def __init__(self, units):
        super().__init__()
        self.units = units

        self.embed = nn.Embedding(num_embeddings=NUM_EMBEDDING, embedding_dim=units)
        self.lstm1 = nn.LSTM(input_size=units, hidden_size=units, num_layers=1,
                             batch_first=True, bidirectional=True)

        # (head key, attribute name, vocabulary). Heads are created, and hence
        # registered and randomly initialised, in this order.
        head_specs = [
            ('B', 'binyan', BINYAN),
            ('T', 'tense', TENSE),
            ('V', 'voice', VOICE),
            ('G', 'gender', GENDER),
            ('P', 'plural', PLURAL),
            ('R1', 'r1', RADICALS),
            ('R2', 'r2', RADICALS),
            ('R3', 'r3', RADICALS),
            ('R4', 'r4', RADICALS),
        ]
        self.features = {}
        for key, attr, vocabulary in head_specs:
            head = nn.Linear(in_features=units, out_features=len(vocabulary))
            setattr(self, attr, head)  # registers the submodule under `attr`
            self.features[key] = head

    def forward(self, x):
        # x: integer character ids; with batch_first=True this is presumably
        # (batch, seq_len). TODO confirm for every caller.
        _, (h_n, _) = self.lstm1(self.embed(x))
        # h_n is (num_directions, batch, units): forward then backward direction.
        forward_h, backward_h = torch.chunk(h_n, 2, dim=0)
        # squeeze() drops every size-1 axis, including the batch axis when x
        # holds a single word.
        merged = (forward_h + backward_h).squeeze()
        return {key: head(merged) for key, head in self.features.items()}


In [3]:

def sanity():
    """Smoke-test a freshly built model on one verb and print per-head statistics.

    For every head this prints:
    - the softmax distribution;
    - its mean probability;
    - the mean negative log-probability (about -log(1/num_classes) when the
      distribution is near uniform);
    - the reference value -log(1/num_classes), which fit() also expects as the
      initial loss.
    """
    # `create_model` is not defined anywhere in this notebook; build the model
    # exactly as the training cell does.
    model = to_device(Model(100))
    model.eval()
    with torch.no_grad():
        verbs = wordlist2numpy(["כשאתאקלם"])
        verbs = to_device(torch.from_numpy(verbs).to(torch.int64))
        tag_scores = model(verbs)
        for k in NAMES:
            print(k)
            # Make the class axis explicit (implicit Softmax dim is deprecated).
            # ravel() keeps v 1-D whether or not the batch axis was squeezed away.
            v = nn.Softmax(dim=-1)(tag_scores[k]).cpu().numpy().ravel()
            print(v)
            print(f'{np.mean(v)=}')
            print(f'{-np.mean(np.log(v))=}')
            print(f'{-np.log(1/len(v))=}')
            print()

# sanity()

In [4]:
import concrete

def load_dataset(file_pat):
    """Load the pre-split files ``<file_pat>_train.tsv`` and ``<file_pat>_test.tsv``.

    Each file is read with ``concrete.load_dataset``. The last returned column
    holds the verbs; the preceding columns hold the features, in NAMES order.

    Returns:
        ((x_train, y_train), (x_test, y_test)). The x arrays come from
        wordlist2numpy; the y dicts come from list_of_lists_to_category.
    """
    *train_features, train_verbs = concrete.load_dataset(f'{file_pat}_train.tsv')
    *test_features, test_verbs = concrete.load_dataset(f'{file_pat}_test.tsv')
    train_split = (wordlist2numpy(train_verbs), list_of_lists_to_category(train_features))
    test_split = (wordlist2numpy(test_verbs), list_of_lists_to_category(test_features))
    return train_split, test_split

def load_dataset_split(filename, split):
    """Load one TSV file and hold out its last `split` rows as the test set.

    The file is read with ``concrete.load_dataset``. The last returned column
    holds the verbs; the preceding columns hold the features, in NAMES order.

    Args:
        filename: path of the TSV file.
        split: number of trailing rows to use as the test set (0 keeps every
            row for training).

    Returns:
        ((x_train, y_train), (x_test, y_test)). The x arrays come from
        wordlist2numpy; the y dicts come from list_of_lists_to_category.

    Raises:
        ValueError: if `split` is negative or larger than the number of rows.
    """
    *features, verbs = concrete.load_dataset(filename)
    if not 0 <= split <= len(verbs):
        raise ValueError(f'split must be in [0, {len(verbs)}], got {split}')

    # Cut at an explicit index instead of `[-split:]`. For split == 0, `[-0:]`
    # is the whole list, which would move every row into the test set. Slicing
    # also leaves the loaded columns unmodified.
    cut = len(verbs) - split
    features_train = [column[:cut] for column in features]
    features_test = [column[cut:] for column in features]
    return ((wordlist2numpy(verbs[:cut]), list_of_lists_to_category(features_train)),
            (wordlist2numpy(verbs[cut:]), list_of_lists_to_category(features_test)))


In [5]:
BATCH_SIZE = 64

def batch(a):
    """Cut an array into BATCH_SIZE-row int64 tensors on the CUDA device.

    Rows beyond the last full batch are dropped, so every returned chunk has
    exactly BATCH_SIZE rows when `a` has at least BATCH_SIZE rows.

    Args:
        a: numpy array whose first axis indexes samples.

    Returns:
        Tuple of tensors produced by ``Tensor.split(BATCH_SIZE)``.
    """
    usable_rows = a.shape[0] - a.shape[0] % BATCH_SIZE
    as_tensor = torch.from_numpy(a[:usable_rows]).to(torch.int64)
    return to_device(as_tensor).split(BATCH_SIZE)

def batch_all_ys(ys):
    """Batch every feature's label array and regroup the pieces per batch.

    Args:
        ys: {head key: class-id array} containing every key in NAMES.

    Returns:
        A list with one {head key: label tensor} dict per batch, in batch order.
        The number of batches is taken from feature 'B'.
    """
    per_feature = {name: batch(ys[name]) for name in NAMES}
    n_batches = len(per_feature['B'])
    return [{name: per_feature[name][idx] for name in NAMES} for idx in range(n_batches)]

def callback(phase, epoch, batch, total, running_corrects, running_divisor, running_loss):
    """Report metrics for the current window of batches.

    Prints one progress line, ended with a carriage return so that successive
    training reports overwrite each other. Logs the same values to wandb under a
    "train/" or "val/" prefix.

    Args:
        phase: 'train' or 'test'; anything other than 'train' logs under "val/".
        epoch: current epoch index.
        batch: 1-based index of the last batch in the window.
        total: number of batches in this phase.
        running_corrects: {head key: correct predictions in the window}, as
            plain numbers or 0-d tensors.
        running_divisor: number of samples in the window.
        running_loss: per-batch summed losses in the window.
    """
    # Use plain floats so the format strings and wandb get Python numbers rather
    # than (possibly CUDA) 0-d tensors. Report NaN for an empty window instead of
    # dividing by zero or averaging an empty list.
    mean_loss = float(np.mean(running_loss)) if running_loss else float('nan')
    accuracies = {k: (float(v) / running_divisor if running_divisor else float('nan'))
                  for k, v in running_corrects.items()}

    print("{:2} {:5}/{:5}".format(epoch, batch, total), end=' ')
    for k, acc in accuracies.items():
        print("{}_acc: {:.3f}".format(k, acc), end=' ')
    print("Loss: {:.4f}".format(mean_loss), end='\r')

    pref = "train/" if phase == 'train' else "val/"
    wandb.log({'phase': phase,
               'epoch': epoch,
               f"{pref}Loss": mean_loss,
               **{f"{pref}Accuracy_{k}": acc for k, acc in accuracies.items()}})

def _assert_initial_losses(losses, tolerance=0.1):
    """Assert each head's first-batch loss is close to a uniform predictor's.

    A classifier that spreads probability evenly over its classes has a
    cross-entropy of -log(1/num_classes). An untrained model should start near
    that value; a large deviation hints at mismatched heads/labels or a bad
    initialisation.

    Args:
        losses: {head key: 0-d loss tensor} for the first batch.
        tolerance: maximum allowed relative deviation from the expected loss.
    """
    for k, loss in losses.items():
        expected = -np.log(1 / len(FEATURES[k]))
        actual = loss.item()
        assert abs(1 - actual / expected) < tolerance, (
            f'initial loss for {k} is {actual:.4f}, expected about {expected:.4f}')


def fit(model, x_train, y_train, x_test, y_test, *, epochs, criterion, optimizer, runsize, train_only=False):
    """Train `model` on all heads jointly, evaluating it after every epoch.

    A batch's loss is the sum of `criterion` over all output heads. Inputs and
    labels are cut into BATCH_SIZE batches by `batch` / `batch_all_ys`, so any
    trailing partial batch is dropped.

    Args:
        model: Model returning {head key: scores}.
        x_train, x_test: encoded verbs, as produced by wordlist2numpy.
        y_train, y_test: {head key: class-id array}, as produced by
            list_of_lists_to_category.
        epochs: number of passes over the training data.
        criterion: per-head loss, called as criterion(scores, labels).
        optimizer: optimizer over model.parameters().
        runsize: report training metrics every `runsize` batches; the running
            window is reset after each report.
        train_only: skip the evaluation phase when True.

    Raises:
        AssertionError: if the very first batch's per-head losses are not
            within 10% of -log(1/num_classes).
    """
    data = {
        'train': (batch(x_train), batch_all_ys(y_train)),
        'test': (batch(x_test), batch_all_ys(y_test)),
    }
    phases = ['train'] if train_only else ['train', 'test']
    initial = True

    for epoch in range(epochs):
        for phase in phases:
            is_train = phase == 'train'
            model.train(is_train)  # equivalent to model.train() / model.eval()

            inputs_batches, labels_batches = data[phase]
            total = len(inputs_batches)

            running_corrects = {k: 0 for k in NAMES}
            running_divisor = 0
            running_loss = []
            i = 0

            for i, (inputs, labels) in enumerate(zip(inputs_batches, labels_batches), 1):
                # Build the autograd graph only when we are going to backprop.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    losses = {k: criterion(outputs[k], labels[k]) for k in outputs}
                    loss = sum(losses.values())

                if initial:
                    _assert_initial_losses(losses)
                    initial = False

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                running_loss.append(loss.item())
                running_divisor += inputs.size(0)
                for k in outputs:
                    predicted = torch.argmax(outputs[k], dim=1)
                    # .item() keeps the counters as Python ints instead of
                    # accumulating (CUDA) tensors.
                    running_corrects[k] += (predicted == labels[k]).sum().item()

                if is_train and i % runsize == 0:
                    callback(phase, epoch, i, total, running_corrects, running_divisor, running_loss)
                    running_corrects = {k: 0 for k in NAMES}
                    running_divisor = 0
                    running_loss = []

            # Evaluation reports once per phase. Skip the report when there
            # were no test batches: there are no accuracies to compute.
            if not is_train and running_divisor:
                callback(phase, epoch, i, total, running_corrects, running_divisor, running_loss)

            print()

@torch.no_grad()
def predict(model, *verbs):
    """Predict every feature for one or more verbs.

    Args:
        model: trained Model; it is switched to eval mode as a side effect.
        *verbs: one or more verb strings.

    Returns:
        For a single verb, a dict mapping each head key ('B', ..., 'R4') to its
        predicted label. The dict also has 'R': the predicted root, i.e.
        R1..R4 joined with '.' placeholders removed. For several verbs, a list
        of such dicts in input order.

    Raises:
        ValueError: if no verbs are given.
    """
    if not verbs:
        raise ValueError('predict() needs at least one verb')
    model.eval()
    encoded = to_device(torch.from_numpy(wordlist2numpy(verbs)).to(torch.int64))
    outputs = model(encoded)

    n = len(verbs)
    # Model.forward squeezes the batch axis away when n == 1, so restore a
    # (n, num_classes) view and take the argmax per row. A flat argmax over an
    # (n, num_classes) tensor would return a flattened index, not a class id.
    best = {k: torch.argmax(v.reshape(n, -1), dim=-1).tolist()
            for k, v in outputs.items()}

    results = []
    for row in range(n):
        res = {k: from_category(k, ids[row]) for k, ids in best.items()}
        res['R'] = ''.join(res[k] for k in ['R1', 'R2', 'R3', 'R4']).replace('.', '')
        results.append(res)
    return results[0] if n == 1 else results
    

In [9]:
# Dataset selection. `arity` picks the synthetic file (the name suggests its rows
# are pre-shuffled); the last `test_size` rows are held out as the test set.
arity = '3'
artifact_name = f'all_{arity}_shuffled'
filename = f'synthetic/{artifact_name}.tsv'  # all_verbs_shuffled
test_size = 5000

# Register the raw TSV as a wandb dataset artifact; the training cell links it
# to the run with run.use_artifact.
artifact = wandb.Artifact(artifact_name, type='dataset')
artifact.add_file(filename)

train, test = load_dataset_split(filename, split=test_size)

In [10]:
# Hyper-parameters and bookkeeping, logged to wandb with the run.
config = {
    'optimizer': 'adam',
    'batch_size': BATCH_SIZE,
    'epochs': 1,
    'runsize': 128,  # training batches per progress/wandb report (see fit)
    'test_size': test_size,
}
# group = f'lr_units_grid_search-{arity}-{wandb.util.generate_id()}'

units = 400
lr = 8e-4

config.update({
    'units': units,
    'lr': lr,
})

run = wandb.init(project="rootem",
                 # group=group,
                 name=f'all-{arity}-lr_{lr:.0e}-units_{units}',
                 tags=['all', arity, 'synthetic', 'shuffle', 'no_prefix'],
                 config=config)

run.use_artifact(artifact)

# Write the local values into wandb.config even if they were already set.
wandb.config.update(config, allow_val_change=True)

model = to_device(Model(units=config['units']))
wandb.watch(model)

print(config)
fit(model,
    *train,
    *test,
    epochs=config['epochs'],
    criterion=nn.CrossEntropyLoss(),
    optimizer=torch.optim.Adam(model.parameters(), lr=config['lr']),
    runsize=config['runsize'],
    train_only=False,
)

# wandb.save only uploads files matching the given path. Nothing ever wrote
# simple_{arity}.h5 (wandb.save returned []), so the trained weights were never
# stored. Persist the state_dict first. torch.save writes a PyTorch checkpoint,
# hence ".pt" rather than the Keras-style ".h5".
weights_path = f"simple_{arity}.pt"
torch.save(model.state_dict(), weights_path)
wandb.save(weights_path)

{'optimizer': 'adam', 'batch_size': 64, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 0 10496/10547 B_acc: 0.829 T_acc: 0.921 V_acc: 0.688 G_acc: 0.800 P_acc: 0.979 R1_acc: 0.975 R2_acc: 0.818 R3_acc: 1.000 R4_acc: 0.983 Loss: 1.9435
 0    78/   78 B_acc: 0.829 T_acc: 0.923 V_acc: 0.689 G_acc: 0.790 P_acc: 0.977 R1_acc: 0.973 R2_acc: 0.814 R3_acc: 1.000 R4_acc: 0.984 Loss: 1.9558


[]

In [None]:
# Print one prediction dict per verb, each on its own line.
print(*(predict(model, v) for v in ['סבסו', 'מקדו', 'נמזר', 'כרדו']), sep='\n')

In [None]:
# Print one prediction dict per verb, each on its own line.
print(*(predict(model, v) for v in ['הבריל', 'חגוו', 'עגו', 'צירלל']), sep='\n')

In [None]:
# Spot-check the trained model on a single verb; prints the prediction dict.
print(predict(model, "השטקרפתי"))

In [None]:
# Spot-check the trained model on a single verb; prints the prediction dict.
print(predict(model, "ישצו"))

In [None]:
# Scratch check of the '.0e' format used for the learning rate in the run name.
print(format(1e-4, '.0e'))