In [1]:
import os

import numpy as np
import wandb

import utils
from stats import Stats
from naive_model import NaiveModel
from encoding import *

os.environ['WANDB_SILENT'] = 'true'

NUM_EMBEDDING = 2000

In [2]:
import torch

import torch.nn as nn
assert torch.cuda.is_available()

def to_device(d):
    if hasattr(d, 'cuda'):
        return d.cuda()
    return {k: v.cuda() for k, v in d.items()}

class Model(nn.Module):
    def __init__(self, units):
        super().__init__()
        self.units = units

        self.embed = nn.Embedding(num_embeddings=NUM_EMBEDDING, embedding_dim=units)
        self.lstm1 = nn.LSTM(input_size=units, hidden_size=units, num_layers=1, batch_first=True, bidirectional=True)

        self.binyan = nn.Linear(in_features=units, out_features=len(BINYAN))
        self.tense = nn.Linear(in_features=units, out_features=len(TENSE))
        self.voice = nn.Linear(in_features=units, out_features=len(VOICE))
        self.gender = nn.Linear(in_features=units, out_features=len(GENDER))
        self.plural = nn.Linear(in_features=units, out_features=len(PLURAL))

        self.r1 = nn.Linear(in_features=units, out_features=len(RADICALS))
        self.r2 = nn.Linear(in_features=units, out_features=len(RADICALS))
        self.r3 = nn.Linear(in_features=units, out_features=len(RADICALS))
        self.r4 = nn.Linear(in_features=units, out_features=len(RADICALS))

        self.features = {
            'B': self.binyan,
            'T': self.tense,
            'V': self.voice,
            'G': self.gender,
            'P': self.plural,

            'R1': self.r1,
            'R2': self.r2,
            'R3': self.r3,
            'R4': self.r4,
        }
        wandb.watch(self)

    def forward(self, x):
        embeds = self.embed(x)

        lstm_out, (h_n, c_n) = self.lstm1(embeds)
        left, right = torch.chunk(h_n, 2, dim=0)
        merge = torch.squeeze(left + right)

        outputs = { k: f(merge) for k, f in self.features.items() }
        return outputs


In [3]:

def sanity():
    model = create_model(100)
    with torch.no_grad():
        verbs = wordlist2numpy(["כשאתאקלם"])
        verbs = to_device(torch.from_numpy(verbs).to(torch.int64))
        tag_scores = model(verbs)
        for k in NAMES:
            print(k)
            v = nn.Softmax()(tag_scores[k]).cpu().detach().numpy()
            print(v)
            print(f'{np.mean(v)=}')
            print(f'{-np.log(1/len(v))=}')
            print()

# sanity()

In [4]:
import concrete

def load_dataset(file_pat):
    *features_train, verbs_train = concrete.load_dataset(f'{file_pat}_train.tsv')
    *features_test, verbs_test = concrete.load_dataset(f'{file_pat}_test.tsv')
    return ((wordlist2numpy(verbs_train), list_of_lists_to_category(features_train)),
            (wordlist2numpy(verbs_test), list_of_lists_to_category(features_test)))

def load_dataset_split(filename, split):
    *features_train, verbs_train = concrete.load_dataset(filename)
    features_test = [t[-split:] for t in features_train]
    verbs_test = verbs_train[-split:]
    del verbs_train[-split:]
    for t in features_train:
        del t[-split:]
    return ((wordlist2numpy(verbs_train), list_of_lists_to_category(features_train)),
            (wordlist2numpy(verbs_test ), list_of_lists_to_category(features_test )))


In [9]:
BATCH_SIZE = 64

def batch(a):
    ub = a.shape[0] // BATCH_SIZE * BATCH_SIZE
    return to_device(torch.from_numpy(a[:ub]).to(torch.int64)).split(BATCH_SIZE)

def batch_all_ys(ys):
    res = []
    m = {k: batch(ys[k]) for k in NAMES}
    nbatches = len(m['B'])
    for i in range(nbatches):
        res.append({k: m[k][i] for k in NAMES})
    return res

def fit(model, train, test, *, epochs,  runsize, criterion, optimizer, phases, teacher):
    x_train, y_train = train
    x_test, y_test = train
    data = {
        'train': (batch(x_train), batch_all_ys(y_train)),
        'test':  (batch(x_test ), batch_all_ys(y_test ))
    }

    stats = Stats(runsize)
    
    for epoch in range(epochs):
        stats.epoch_start()
        
        for phase in phases:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            stats.phase_start(phase, batches_in_phase=len(data[phase][0]))

            for inputs, labels in zip(*data[phase]):
                stats.batch_start()
                
                if phase == 'train':
                    outputs = model(inputs)
                else:
                    with torch.no_grad():
                        outputs = model(inputs)

                if teacher is not None:
                    pseudo_labels = teacher(inputs)
                    losses = {k: criterion(outputs[k].double(), pseudo_labels[k]) for k in outputs}
                else:
                    losses = {k: criterion(outputs[k].double(), labels[k]) for k in outputs}
                
                if phase == 'train' and isinstance(criterion, nn.CrossEntropyLoss):
                    stats.assert_resonable_initial(losses)
                
                loss = sum(losses.values())

                if phase == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                
                stats.update(loss=loss.item(),
                             batch_size=inputs.size(0),
                             d={k: (outputs[k], labels[k].detach()) for k in outputs})
                
                stats.batch_end()
            stats.phase_end()
        stats.epoch_end()

@torch.no_grad()
def predict(model, *verbs):
    model.eval()
    verbs = wordlist2numpy(verbs)
    verbs = to_device(torch.from_numpy(verbs).to(torch.int64))
    outputs = model(verbs)
    res = {k: from_category(k, torch.argmax(v))
           for k, v in outputs.items()}
    res['R'] = ''.join(res[k] for k in ['R1', 'R2', 'R3', 'R4']).replace('.', '')
    return res
    

In [24]:
arity = 'combined'
gen = 'unique'
artifact_name = f'{gen}_{arity}_shuffled'
filename = f'synthetic/{artifact_name}.tsv'  # all_verbs_shuffled
test_size = 5000

artifact = wandb.Artifact(artifact_name, type='dataset')
artifact.add_file(filename)

train, test = load_dataset_split(filename, split=test_size)

In [None]:
def naive_config(filename):
    return {
        'model': NaiveModel.learn_from_file(filename),
        'phases': ['test'],
        'criterion': nn.CrossEntropyLoss(),
        'optimizer': None
    }

def teacher_config(train):
    res = standard_config()
    res['teacher'] = NaiveModel.learn_from_data(train)
    res['criterion'] = nn.BCEWithLogitsLoss()  # BCELoss: works, but total loss is nan
    return res

In [25]:
os.environ['WANDB_MODE'] = 'run'  # 'dryrun'

config = {
    'optimizer': 'adam',
    'batch_size': BATCH_SIZE,
    'epochs': 1,
    'runsize': 8,
    'test_size': test_size,
}
# group = f'lr_units_grid_search-{arity}-{wandb.util.generate_id()}'

def experiment(lr):
    units = 400

    config.update({
        'units': units,
        'lr': lr,
    })

    run = wandb.init(project="rootem",
                     # group=group,
                     name=f'{gen}-{arity}-{lr:.0e}',# f'{arity}-batch_{BATCH_SIZE}', # f'all-{arity}-lr_{lr:.0e}-units_{units}',
                     tags=[gen, arity, 'synthetic', 'shuffle', 'no_prefix'],
                     config=config)

    run.use_artifact(artifact)

    wandb.config.update(config, allow_val_change=True)

    def standard_config():
        model = to_device(Model(units=config['units']))
        optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])
        return {
            'model': model,
            'criterion': nn.CrossEntropyLoss(),
            'optimizer': optimizer,
            'phases': ['train', 'test'],
            'teacher': None
        }

    print(config)
    fit(train=train,
        test=test,
        epochs=config['epochs'],
        runsize=config['runsize'],
        **standard_config()
    )
    wandb.save(f"simple_{arity}.h5")

for lr in [8e-4, 10e-4, 20e-4, 30e-4, 40e-4, 50e-4, 60e-4]:
    experiment(lr)

{'optimizer': 'adam', 'batch_size': 64, 'epochs': 1, 'runsize': 8, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  1256/ 1258 B_acc: 0.992 T_acc: 0.996 V_acc: 0.994 G_acc: 1.000 P_acc: 0.998 R1_acc: 0.990 R2_acc: 0.988 R3_acc: 0.996 R4_acc: 0.998 Loss: 0.20089
 1  1258/ 1258 B_acc: 0.990 T_acc: 0.995 V_acc: 0.996 G_acc: 0.999 P_acc: 0.997 R1_acc: 0.992 R2_acc: 0.991 R3_acc: 0.996 R4_acc: 0.996 Loss: 0.2056


{'optimizer': 'adam', 'batch_size': 64, 'epochs': 1, 'runsize': 8, 'test_size': 5000, 'units': 400, 'lr': 0.001}
 1  1256/ 1258 B_acc: 0.992 T_acc: 0.996 V_acc: 0.992 G_acc: 1.000 P_acc: 0.998 R1_acc: 0.990 R2_acc: 0.990 R3_acc: 0.994 R4_acc: 0.998 Loss: 0.23418
 1  1258/ 1258 B_acc: 0.989 T_acc: 0.995 V_acc: 0.996 G_acc: 0.999 P_acc: 0.997 R1_acc: 0.993 R2_acc: 0.990 R3_acc: 0.995 R4_acc: 0.996 Loss: 0.2071


{'optimizer': 'adam', 'batch_size': 64, 'epochs': 1, 'runsize': 8, 'test_size': 5000, 'units': 400, 'lr': 0.002}
 1  1256/ 1258 B_acc: 0.996 T_acc: 0.998 V_acc: 0.996 G_acc: 1.000 P_acc: 0.996 R1_acc: 0.996 R2_acc: 0.984 R3_acc: 0.988 R4_acc: 0.998 Loss: 0.18975
 1  1258/ 1258 B_acc: 0.989 T_acc: 0.995 V_acc: 0.995 G_acc: 0.999 P_acc: 0.996 R1_acc: 0.992 R2_acc: 0.989 R3_acc: 0.995 R4_acc: 0.995 Loss: 0.2067


{'optimizer': 'adam', 'batch_size': 64, 'epochs': 1, 'runsize': 8, 'test_size': 5000, 'units': 400, 'lr': 0.003}
 1  1256/ 1258 B_acc: 0.988 T_acc: 0.998 V_acc: 0.994 G_acc: 1.000 P_acc: 0.998 R1_acc: 0.990 R2_acc: 0.979 R3_acc: 0.996 R4_acc: 0.998 Loss: 0.24706
 1  1258/ 1258 B_acc: 0.986 T_acc: 0.994 V_acc: 0.995 G_acc: 0.998 P_acc: 0.996 R1_acc: 0.989 R2_acc: 0.985 R3_acc: 0.992 R4_acc: 0.996 Loss: 0.2443


{'optimizer': 'adam', 'batch_size': 64, 'epochs': 1, 'runsize': 8, 'test_size': 5000, 'units': 400, 'lr': 0.004}
 1  1256/ 1258 B_acc: 0.984 T_acc: 0.994 V_acc: 0.992 G_acc: 1.000 P_acc: 1.000 R1_acc: 0.992 R2_acc: 0.980 R3_acc: 0.990 R4_acc: 0.998 Loss: 0.25330
 1  1258/ 1258 B_acc: 0.984 T_acc: 0.993 V_acc: 0.995 G_acc: 0.999 P_acc: 0.996 R1_acc: 0.988 R2_acc: 0.979 R3_acc: 0.993 R4_acc: 0.995 Loss: 0.2757


{'optimizer': 'adam', 'batch_size': 64, 'epochs': 1, 'runsize': 8, 'test_size': 5000, 'units': 400, 'lr': 0.005}
 1  1256/ 1258 B_acc: 0.982 T_acc: 0.990 V_acc: 0.992 G_acc: 0.998 P_acc: 0.998 R1_acc: 0.990 R2_acc: 0.973 R3_acc: 0.988 R4_acc: 0.998 Loss: 0.36944
 1  1258/ 1258 B_acc: 0.978 T_acc: 0.992 V_acc: 0.993 G_acc: 0.998 P_acc: 0.995 R1_acc: 0.984 R2_acc: 0.975 R3_acc: 0.989 R4_acc: 0.993 Loss: 0.3422


{'optimizer': 'adam', 'batch_size': 64, 'epochs': 1, 'runsize': 8, 'test_size': 5000, 'units': 400, 'lr': 0.006}
 1  1256/ 1258 B_acc: 0.975 T_acc: 0.994 V_acc: 0.988 G_acc: 0.996 P_acc: 1.000 R1_acc: 0.992 R2_acc: 0.982 R3_acc: 0.988 R4_acc: 0.994 Loss: 0.36916
 1  1258/ 1258 B_acc: 0.977 T_acc: 0.991 V_acc: 0.990 G_acc: 0.998 P_acc: 0.994 R1_acc: 0.985 R2_acc: 0.968 R3_acc: 0.985 R4_acc: 0.991 Loss: 0.3978


In [None]:
print(predict(model, 'סבסו'))
print(predict(model, 'מקדו'))
print(predict(model, 'נמזר'))
print(predict(model, 'כרדו'))

In [None]:
print(predict(model, 'הבריל'))
print(predict(model, 'חגוו'))
print(predict(model, 'עגו'))
print(predict(model, 'צירלל'))

In [None]:
print(predict(model, "השטקרפתי"))

In [None]:
print(predict(model, "ישסו"))

In [None]:
import importlib
import encoding
import naive_model
encoding = importlib.reload(encoding)
naive_model = importlib.reload(naive_model)
NaiveModel = naive_model.NaiveModel