In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import itertools

import numpy as np
import wandb
import more_itertools as mi

# Size of the embedding table (number of distinct input ids) used by Model.embed.
NUM_EMBEDDING = 2000
# NOTE(review): presumably closes any wandb run left open by a previous
# execution of this notebook -- confirm against the installed wandb version.
wandb.join()

In [3]:
import utils
from naive_model import NaiveModel
import encoding

In [4]:
import torch

import torch.nn as nn
# Fail fast: everything below assumes a CUDA device is present.
assert torch.cuda.is_available()
# Let cuDNN benchmark and pick kernels for the input shapes it sees.
torch.backends.cudnn.benchmark = True
# Make newly created default tensors CUDA float tensors.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

def to_device(d):
    """Move ``d`` to the GPU.

    An object exposing a ``cuda`` attribute is moved by calling ``d.cuda()``.
    Anything else is treated as a mapping and rebuilt as a plain dict whose
    values are the result of calling ``.cuda()`` on each original value.
    """
    if not hasattr(d, 'cuda'):
        moved = {}
        for key, value in d.items():
            moved[key] = value.cuda()
        return moved
    return d.cuda()

class Model(nn.Module):
    """Embedding + bidirectional LSTM encoder with one linear head per task.

    Args:
        units: Embedding dimension, LSTM hidden size and input size of every head.
        combinations: Iterable of task keys. One ``nn.Linear`` head with
            ``encoding.class_size(combination)`` outputs is created per key.

    Attributes:
        tasks: Plain dict mapping each combination to its head. Each head is
            also attached with ``setattr`` under ``encoding.class_name(combination)``
            so that ``nn.Module`` registers its parameters (a plain dict alone
            would hide them from ``parameters()``).
    """

    def __init__(self, units, combinations=encoding.NAMES):
        super().__init__()
        self.units = units

        self.embed = nn.Embedding(num_embeddings=NUM_EMBEDDING, embedding_dim=units)
        self.lstm = nn.LSTM(input_size=units, hidden_size=units, num_layers=1, batch_first=True, bidirectional=True)

        self.tasks = {}
        for combination in combinations:
            out = nn.Linear(in_features=units, out_features=encoding.class_size(combination))
            self.tasks[combination] = out
            setattr(self, encoding.class_name(combination), out)

    def forward(self, x):
        """Return a dict mapping each task key to the output of its head.

        The final hidden states of the two LSTM directions are summed and fed
        to every head.
        """
        embeds = self.embed(x)

        _, (h_n, _) = self.lstm(embeds)
        # h_n stacks the two directions along dim 0; split them apart.
        left, right = torch.chunk(h_n, 2, dim=0)
        # Squeeze ONLY the direction axis. A bare squeeze() would also drop the
        # batch axis for a batch of one, changing the rank of every head output.
        merge = torch.squeeze(left + right, dim=0)

        return {combination: head(merge) for combination, head in self.tasks.items()}


In [None]:
import functools
import operator as op

def sanity():
    """Smoke test: build a small model with one ('B', 'T') head and print its scores.

    Runs two hard-coded verbs through the model without gradients and prints,
    for each head, the raw scores next to the element-wise product of the
    hand-written label tensors. Nothing is asserted; output is for eyeballing.
    """
    model = to_device(Model(100, combinations=[('B', 'T')]))
    print(model)
    with torch.no_grad():
        verbs = encoding.wordlist2numpy(["אתאקלם", "יכפיל"])
        labels = {'B': torch.Tensor([3, 5]), 'T': torch.Tensor([2, 4])}
        verbs = to_device(torch.from_numpy(verbs).to(torch.int64))
        tag_scores = model(verbs)
        # The names v / c_labels / labels are echoed by the f'{name=}' prints below.
        for combination, v in tag_scores.items():
            print(combination)
            c_labels = functools.reduce(op.mul, (labels[k] for k in combination))
            print(f'{v=}')
            print(f'{c_labels=}')
            print(f'{labels=}')
            print()

sanity()

In [5]:
import concrete

def load_dataset(file_pat):
    """Load and encode the ``<file_pat>_train.tsv`` / ``<file_pat>_test.tsv`` pair.

    For each file, the last column returned by ``concrete.load_dataset`` is
    treated as the verbs and all preceding columns as features.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` where ``x`` comes from
        ``encoding.wordlist2numpy`` and ``y`` from
        ``encoding.list_of_lists_to_category``.
    """
    *train_feats, train_verbs = concrete.load_dataset(f'{file_pat}_train.tsv')
    *test_feats, test_verbs = concrete.load_dataset(f'{file_pat}_test.tsv')

    train_pair = (encoding.wordlist2numpy(train_verbs),
                  encoding.list_of_lists_to_category(train_feats))
    test_pair = (encoding.wordlist2numpy(test_verbs),
                 encoding.list_of_lists_to_category(test_feats))
    return train_pair, test_pair

def load_dataset_split(filename, split):
    """Load one dataset file and hold out its last ``split`` rows as the test set.

    The last column returned by ``concrete.load_dataset`` is treated as the
    verbs and all preceding columns as features.

    Args:
        filename: Path passed to ``concrete.load_dataset``.
        split: Number of trailing rows to use for testing. ``0`` yields an
            empty test set (a ``[-0:]`` slice would instead have moved
            *everything* to test). Values larger than the dataset put every
            row in the test set.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` encoded with
        ``encoding.wordlist2numpy`` / ``encoding.list_of_lists_to_category``.
    """
    *features, verbs = concrete.load_dataset(filename)

    def _cut(seq):
        # Index of the first held-out element, clamped so oversize splits
        # hold out the entire sequence.
        return max(len(seq) - split, 0)

    # Slice into fresh sequences rather than deleting from the loaded ones in place.
    features_train = [column[:_cut(column)] for column in features]
    features_test = [column[_cut(column):] for column in features]
    verbs_train = verbs[:_cut(verbs)]
    verbs_test = verbs[_cut(verbs):]

    return ((encoding.wordlist2numpy(verbs_train), encoding.list_of_lists_to_category(features_train)),
            (encoding.wordlist2numpy(verbs_test), encoding.list_of_lists_to_category(features_test)))


In [6]:
BATCH_SIZE = 128

def batch(a):
    """Split numpy array ``a`` into int64 device tensors of ``BATCH_SIZE`` rows.

    Trailing rows that do not fill a complete batch are dropped.
    """
    full_batches = a.shape[0] // BATCH_SIZE
    usable = a[:full_batches * BATCH_SIZE]
    as_int64 = torch.from_numpy(usable).to(torch.int64)
    return to_device(as_int64).split(BATCH_SIZE)

def ravel_multi_index(ys, combinations):
    """Build one label array per task key in ``combinations``.

    Args:
        ys: Mapping from a single label name to its array of category indices.
        combinations: Iterable of task keys. A key that is not a tuple/list is
            looked up in ``ys`` directly. A one-element tuple/list reuses the
            array of its only member. A longer one is flattened into a single
            joint index with ``np.ravel_multi_index`` over
            ``encoding.combined_shape(combination)``.

    Returns:
        Dict mapping every key of ``combinations`` to its label array.
    """
    def _labels_for(combination):
        if not isinstance(combination, (tuple, list)):
            return ys[combination]
        # The original guard tested ``len(combination) == 0`` and then indexed
        # ``combination[0]``, which can only raise; the singleton case is meant.
        if len(combination) == 1:
            return ys[combination[0]]
        return np.ravel_multi_index([ys[k] for k in combination],
                                    encoding.combined_shape(combination))

    return {combination: _labels_for(combination) for combination in combinations}

def batch_all_ys(ys):
    """Turn a dict of label arrays into a list of per-batch label dicts.

    Every array is split with ``batch``; element ``i`` of the result maps each
    key of ``ys`` to its ``i``-th batch. The batch count is taken from the
    first key.
    """
    per_key = {key: batch(ys[key]) for key in ys}
    batch_count = len(next(iter(per_key.values())))
    return [{key: per_key[key][index] for key in ys} for index in range(batch_count)]

def fit(model, train, test, *, epochs, runsize, criterion, optimizer, phases, teacher=None):
    """Run the train/evaluate loop over pre-batched data and return the statistics.

    Args:
        model: Callable mapping a batch of inputs to a dict ``{task_key: scores}``;
            must expose ``tasks``, ``train()`` and ``eval()``.
        train: ``(x, y)`` pair; ``x`` is a numpy array, ``y`` a dict of label arrays.
        test: Same layout as ``train``.
        epochs: Number of passes over ``phases``.
        runsize: Forwarded to ``utils.Stats``.
        criterion: Loss called as ``criterion(output.double(), target)`` per task.
        optimizer: Stepped once per batch in the ``'train'`` phase only.
        phases: Sequence of ``'train'`` / ``'test'`` phase names to run each epoch.
        teacher: Optional callable producing pseudo-labels from the inputs. When
            given, the loss is computed against its outputs instead of the true
            labels. Defaults to ``None`` so that configs without a ``'teacher'``
            entry can be splatted into this function.

    Returns:
        The ``utils.Stats`` object that was updated throughout the run.
    """
    x_train, y_train = train
    x_test, y_test = test
    data = {
        'train': (batch(x_train), batch_all_ys(y_train)),
        'test': (batch(x_test), batch_all_ys(y_test)),
    }

    stats = utils.Stats(model.tasks.keys(), runsize)

    for _ in range(epochs):
        stats.epoch_start()

        for phase in phases:
            training = phase == 'train'
            if training:
                model.train()
            else:
                model.eval()

            stats.phase_start(phase, batches_in_phase=len(data[phase][0]))

            for inputs, labels in zip(*data[phase]):
                inputs = to_device(inputs)
                labels = to_device(labels)
                stats.batch_start()

                # Track gradients for the forward pass only while training.
                with torch.set_grad_enabled(training):
                    outputs = model(inputs)

                if teacher is not None:
                    pseudo_labels = teacher(inputs)
                    losses = {k: criterion(outputs[k].double(), pseudo_labels[k]) for k in outputs}
                else:
                    losses = {combination: criterion(output.double(), labels[combination])
                              for combination, output in outputs.items()}

                # The total loss is the unweighted sum over all task heads.
                loss = sum(losses.values())

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Accuracy statistics always compare against the true labels,
                # even when the loss was computed against the teacher.
                stats.update(loss=loss.item(),
                             batch_size=inputs.size(0),
                             d={k: (outputs[k], labels[k]) for k in outputs})

                stats.batch_end()
            stats.phase_end()
        stats.epoch_end()
    return stats

@torch.no_grad()
def predict(model, *verbs):
    """Decode the model's most likely labels for one or more verbs.

    Args:
        model: Trained model returning ``{combination: scores}``.
        *verbs: Words to analyse.

    Returns:
        For a single verb, a dict mapping each label name to its decoded value
        (plus ``'R'``, the joined root, when ``R1``..``R4`` are all present).
        For several verbs, a list with one such dict per verb, in input order.
        Previously the argmax was taken over the flattened batch, so more than
        one verb produced a wrong or out-of-range index.
    """
    model.eval()
    encoded = encoding.wordlist2numpy(verbs)
    encoded = to_device(torch.from_numpy(encoded).to(torch.int64))
    outputs = model(encoded)

    count = len(verbs)
    results = [{} for _ in range(count)]
    # FIX: assumes no overlaps between the label names of different heads;
    # a repeated name is stored under "<name>'" instead of being merged.
    for combination, v in outputs.items():
        shape = encoding.combined_shape(combination)
        # Normalise to (verbs, classes) whether or not the model kept the batch
        # axis for a single sample, then pick the best class per verb.
        scores = v.reshape(count, v.shape[-1])
        best = scores.argmax(dim=1).cpu().numpy()
        for res, combined_index in zip(results, best):
            indices = np.unravel_index(combined_index, shape)
            for k, i in zip(combination, indices):
                s = k
                if k in res:
                    s += "'"
                res[s] = encoding.from_category(k, i)

    root_keys = ['R1', 'R2', 'R3', 'R4']
    for res in results:
        if all(r in res for r in root_keys):
            res['R'] = ''.join(res[k] for k in root_keys).replace('.', '')

    return results[0] if count == 1 else results
    

In [7]:
# Seed both RNGs used in this notebook so weight init and shuffling are repeatable.
torch.manual_seed(0)
np.random.seed(0)

# Dataset selection: these two strings are also used for the wandb group/name/tags.
arity = '3'
gen = 'all'
artifact_name = f'{gen}_{arity}_shufroot'
filename = f'synthetic/{artifact_name}.tsv'  # all_verbs_shuffled
# Number of trailing rows of the file held out as the test set.
test_size = 5000

# Register the dataset file as a wandb artifact; each run declares it via use_artifact.
artifact = wandb.Artifact(artifact_name, type='dataset')
artifact.add_file(filename)

# pre_* label dicts hold one array per single label; experiment() combines them per head.
(train_x, pre_train_y), (test_x, pre_test_y) = load_dataset_split(filename, split=test_size)

def shuffle_in_unison_scary(arrs):
    """Shuffle every array in ``arrs`` in place using the same random draws.

    The global numpy RNG state is captured once and restored before each
    shuffle, so each array is shuffled from an identical RNG state.
    """
    saved_state = np.random.get_state()

    def _replay_shuffle(target):
        np.random.set_state(saved_state)
        np.random.shuffle(target)

    for target in arrs:
        _replay_shuffle(target)
shuffle_in_unison_scary([train_x, *pre_train_y.values()])
# NOTE: the call above already shuffles train_x and every pre_train_y array in unison.

In [8]:
def naive_config(filename):
    """Build the ``fit`` keyword arguments for evaluating the naive baseline.

    The baseline is learned from ``filename`` and only evaluated (no optimizer,
    ``'test'`` phase only). ``'teacher'`` is included, as in
    ``standard_config``, so the dict can be splatted straight into ``fit``.
    """
    return {
        'model': NaiveModel.learn_from_file(filename),
        'phases': ['test'],
        'criterion': nn.CrossEntropyLoss(),
        'optimizer': None,
        'teacher': None,
    }

def teacher_config(train, model, lr):
    """Build the ``fit`` keyword arguments for training against a naive teacher.

    Starts from ``standard_config(model, lr)`` and swaps in a teacher learned
    from ``train`` plus a loss suited to its soft targets.

    Args:
        train: Training data handed to ``NaiveModel.learn_from_data``.
        model: Student model to optimise.
        lr: Learning rate for the optimizer.

    ``model`` and ``lr`` are required because ``standard_config`` cannot be
    called without them; the old zero-argument call always raised TypeError.
    """
    res = standard_config(model, lr)
    res['teacher'] = NaiveModel.learn_from_data(train)
    res['criterion'] = nn.BCEWithLogitsLoss()  # BCELoss: works, but total loss is nan
    return res


def standard_config(model, lr):
    """Return the default ``fit`` keyword arguments for supervised training.

    Uses Adam over ``model.parameters()`` with learning rate ``lr``,
    cross-entropy loss, both the train and test phases, and no teacher.
    """
    adam = torch.optim.Adam(model.parameters(), lr=lr)
    return dict(
        model=model,
        criterion=nn.CrossEntropyLoss(),
        optimizer=adam,
        phases=['train', 'test'],
        teacher=None,
    )

In [9]:
%env WANDB_SILENT true
%env WANDB_MODE run

# Run configuration logged to wandb. experiment() mutates this dict in place
# (adds 'units' and 'lr') before passing it to wandb.init.
config = {
    # Logged label only: the optimizer actually used is built in standard_config().
    'optimizer': 'adam',
    'batch_size': BATCH_SIZE,
    'epochs': 1,
    # Forwarded to utils.Stats via fit(); evaluates to 128 when BATCH_SIZE is 128.
    'runsize': 2 * 8192 // BATCH_SIZE,
    'test_size': test_size,
}

def experiment(combinations):
    """Train and evaluate one model whose heads are given by ``combinations``.

    Builds the per-head labels from the module-level ``pre_train_y`` /
    ``pre_test_y``, records the hyper-parameters in the module-level ``config``
    (mutated in place), and runs ``fit`` inside a wandb run.

    Returns:
        ``(model, stats)``: the trained model and the statistics from ``fit``.
    """
    train = (train_x, ravel_multi_index(pre_train_y, combinations))
    test = (test_x, ravel_multi_index(pre_test_y, combinations))

    units, lr = 400, 8e-4
    config.update({'units': units, 'lr': lr})

    # Head names joined with '+', e.g. used to tell runs apart in the wandb UI.
    heads = '+'.join(map(encoding.class_name, combinations))
    run = wandb.init(
        project="rootem",
        group=f'joint-{gen}-{arity}-noroot',
        name=f'joint-{gen}-{arity}-{heads}-noroot',
        tags=[gen, arity, 'synthetic', 'shuffle-root', 'no_prefix', 'shuffle', 'partitions'],
        config=config,
    )
    with run:
        run.use_artifact(artifact)

        wandb.config.update(config, allow_val_change=True)

        print(config)

        model = to_device(Model(units=units, combinations=combinations))
        if isinstance(model, nn.Module):
            wandb.watch(model)

        fit_kwargs = standard_config(model, lr)
        stats = fit(train=train,
                    test=test,
                    epochs=config['epochs'],
                    runsize=config['runsize'],
                    **fit_kwargs)
        wandb.save(f"simple_{arity}.h5")
    return model, stats

def experiment_partitions():
    """Run ``experiment`` once for every partition of the non-root labels.

    The labels are every entry of ``encoding.CLASSES`` except ``R1``..``R4``.
    Each partition becomes a list of tuples, one tuple per head.

    Labels are taken in ``encoding.CLASSES`` iteration order rather than from a
    ``set``: string hashing is randomised per interpreter session, so iterating
    a set would change both the order of the partitions and the order of labels
    inside each tuple from one kernel restart to the next.
    """
    import more_itertools as mi
    roots = {'R1', 'R2', 'R3', 'R4'}
    # dict.fromkeys removes duplicates while keeping first-seen order.
    names = list(dict.fromkeys(name for name in encoding.CLASSES if name not in roots))
    partitions = [[tuple(x) for x in part] for part in mi.set_partitions(names)]
    for i, part in enumerate(partitions, 1):
        print(f'{i}/{len(partitions)}: {part}')
        # The (model, stats) result is not kept: holding every trained model
        # would accumulate them all in memory.
        experiment(part)

experiment_partitions()
# TODO: statistics for each k in each combination

env: WANDB_SILENT=true
env: WANDB_MODE=run
28/52: [('P',), ('B', 'V', 'G'), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 P_acc: 0.980 BxVxG_acc: 0.440 T_acc: 0.931 B_acc: 0.820 V_acc: 0.685 G_acc: 0.788 Loss: 1.3253
 1    39/   39 P_acc: 0.979 BxVxG_acc: 0.447 T_acc: 0.919 B_acc: 0.811 V_acc: 0.689 G_acc: 0.793 Loss: 1.3804
29/52: [('P',), ('V', 'G'), ('B', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 P_acc: 0.980 VxG_acc: 0.533 BxT_acc: 0.791 V_acc: 0.690 G_acc: 0.787 B_acc: 0.826 T_acc: 0.929 Loss: 1.2992
 1    39/   39 P_acc: 0.979 VxG_acc: 0.535 BxT_acc: 0.768 V_acc: 0.690 G_acc: 0.790 B_acc: 0.803 T_acc: 0.913 Loss: 1.3684
30/52: [('B', 'P'), ('G',), ('V', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 BxP_acc: 0.820 G_acc: 0.795 VxT_acc: 0.679 B_acc: 0.827 P_acc: 0.979 V_acc: 0.694 T_acc: 0.925 Loss: 1.2936
 1    39/   39 BxP_acc: 0.806 G_acc: 0.801 VxT_acc: 0.672 B_acc: 0.814 P_acc: 0.977 V_acc: 0.692 T_acc: 0.912 Loss: 1.3573
31/52: [('P',), ('B', 'G'), ('V', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 P_acc: 0.981 BxG_acc: 0.665 VxT_acc: 0.680 B_acc: 0.827 G_acc: 0.793 V_acc: 0.694 T_acc: 0.925 Loss: 1.3001
 1    39/   39 P_acc: 0.978 BxG_acc: 0.652 VxT_acc: 0.668 B_acc: 0.810 G_acc: 0.794 V_acc: 0.689 T_acc: 0.910 Loss: 1.3850
32/52: [('P',), ('G',), ('B', 'V', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 P_acc: 0.980 G_acc: 0.797 BxVxT_acc: 0.568 B_acc: 0.819 V_acc: 0.693 T_acc: 0.925 Loss: 1.2466
 1    39/   39 P_acc: 0.979 G_acc: 0.799 BxVxT_acc: 0.562 B_acc: 0.797 V_acc: 0.694 T_acc: 0.916 Loss: 1.3121
33/52: [('B', 'P', 'V'), ('G',), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 BxPxV_acc: 0.572 G_acc: 0.794 T_acc: 0.931 B_acc: 0.821 P_acc: 0.979 V_acc: 0.693 Loss: 1.3610
 1    39/   39 BxPxV_acc: 0.575 G_acc: 0.801 T_acc: 0.917 B_acc: 0.809 P_acc: 0.978 V_acc: 0.689 Loss: 1.4103
34/52: [('P', 'V'), ('B', 'G'), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 PxV_acc: 0.695 BxG_acc: 0.662 T_acc: 0.932 P_acc: 0.981 V_acc: 0.699 B_acc: 0.825 G_acc: 0.792 Loss: 1.3893
 1    39/   39 PxV_acc: 0.689 BxG_acc: 0.657 T_acc: 0.915 P_acc: 0.979 V_acc: 0.694 B_acc: 0.814 G_acc: 0.796 Loss: 1.4477
35/52: [('P', 'V'), ('G',), ('B', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 PxV_acc: 0.693 G_acc: 0.793 BxT_acc: 0.791 P_acc: 0.981 V_acc: 0.698 B_acc: 0.825 T_acc: 0.927 Loss: 1.3295
 1    39/   39 PxV_acc: 0.688 G_acc: 0.798 BxT_acc: 0.774 P_acc: 0.979 V_acc: 0.692 B_acc: 0.809 T_acc: 0.913 Loss: 1.3898
36/52: [('B', 'V'), ('P', 'G'), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 BxV_acc: 0.575 PxG_acc: 0.781 T_acc: 0.931 B_acc: 0.823 V_acc: 0.692 P_acc: 0.979 G_acc: 0.791 Loss: 1.3707
 1    39/   39 BxV_acc: 0.577 PxG_acc: 0.786 T_acc: 0.914 B_acc: 0.810 V_acc: 0.693 P_acc: 0.979 G_acc: 0.795 Loss: 1.4418
37/52: [('V',), ('B', 'P', 'G'), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 V_acc: 0.697 BxPxG_acc: 0.660 T_acc: 0.932 B_acc: 0.826 P_acc: 0.977 G_acc: 0.791 Loss: 1.3880
 1    39/   39 V_acc: 0.694 BxPxG_acc: 0.657 T_acc: 0.918 B_acc: 0.819 P_acc: 0.977 G_acc: 0.796 Loss: 1.4251
38/52: [('V',), ('P', 'G'), ('B', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 V_acc: 0.696 PxG_acc: 0.782 BxT_acc: 0.790 P_acc: 0.980 G_acc: 0.792 B_acc: 0.825 T_acc: 0.928 Loss: 1.3439
 1    39/   39 V_acc: 0.694 PxG_acc: 0.785 BxT_acc: 0.771 P_acc: 0.975 G_acc: 0.797 B_acc: 0.806 T_acc: 0.912 Loss: 1.3976
39/52: [('B', 'V'), ('G',), ('P', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 BxV_acc: 0.574 G_acc: 0.797 PxT_acc: 0.928 B_acc: 0.820 V_acc: 0.694 P_acc: 0.981 T_acc: 0.930 Loss: 1.3462
 1    39/   39 BxV_acc: 0.572 G_acc: 0.799 PxT_acc: 0.917 B_acc: 0.803 V_acc: 0.688 P_acc: 0.980 T_acc: 0.917 Loss: 1.4165
40/52: [('V',), ('B', 'G'), ('P', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 V_acc: 0.697 BxG_acc: 0.665 PxT_acc: 0.930 B_acc: 0.825 G_acc: 0.795 P_acc: 0.981 T_acc: 0.931 Loss: 1.3738
 1    39/   39 V_acc: 0.692 BxG_acc: 0.655 PxT_acc: 0.916 B_acc: 0.813 G_acc: 0.795 P_acc: 0.980 T_acc: 0.916 Loss: 1.4373
41/52: [('V',), ('G',), ('B', 'P', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 V_acc: 0.696 G_acc: 0.795 BxPxT_acc: 0.787 B_acc: 0.822 P_acc: 0.980 T_acc: 0.928 Loss: 1.3253
 1    39/   39 V_acc: 0.695 G_acc: 0.800 BxPxT_acc: 0.780 B_acc: 0.814 P_acc: 0.977 T_acc: 0.915 Loss: 1.3644
42/52: [('B',), ('P',), ('V',), ('G', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 B_acc: 0.827 P_acc: 0.980 V_acc: 0.696 GxT_acc: 0.739 G_acc: 0.785 T_acc: 0.928 Loss: 1.3842
 1    39/   39 B_acc: 0.812 P_acc: 0.981 V_acc: 0.694 GxT_acc: 0.740 G_acc: 0.794 T_acc: 0.914 Loss: 1.4569
43/52: [('B',), ('P',), ('V', 'G'), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 B_acc: 0.827 P_acc: 0.980 VxG_acc: 0.536 T_acc: 0.930 V_acc: 0.692 G_acc: 0.791 Loss: 1.3618
 1    39/   39 B_acc: 0.810 P_acc: 0.976 VxG_acc: 0.537 T_acc: 0.914 V_acc: 0.690 G_acc: 0.791 Loss: 1.4644
44/52: [('B',), ('P',), ('G',), ('V', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 B_acc: 0.828 P_acc: 0.981 G_acc: 0.796 VxT_acc: 0.680 V_acc: 0.695 T_acc: 0.927 Loss: 1.3031
 1    39/   39 B_acc: 0.812 P_acc: 0.977 G_acc: 0.798 VxT_acc: 0.673 V_acc: 0.693 T_acc: 0.913 Loss: 1.3965
45/52: [('B',), ('P', 'V'), ('G',), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 B_acc: 0.827 PxV_acc: 0.695 G_acc: 0.797 T_acc: 0.930 P_acc: 0.981 V_acc: 0.699 Loss: 1.3964
 1    39/   39 B_acc: 0.810 PxV_acc: 0.687 G_acc: 0.799 T_acc: 0.919 P_acc: 0.978 V_acc: 0.693 Loss: 1.4792
46/52: [('B',), ('V',), ('P', 'G'), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 B_acc: 0.827 V_acc: 0.696 PxG_acc: 0.784 T_acc: 0.932 P_acc: 0.980 G_acc: 0.794 Loss: 1.4019
 1    39/   39 B_acc: 0.819 V_acc: 0.693 PxG_acc: 0.787 T_acc: 0.917 P_acc: 0.977 G_acc: 0.799 Loss: 1.4791
47/52: [('B',), ('V',), ('G',), ('P', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 B_acc: 0.826 V_acc: 0.696 G_acc: 0.794 PxT_acc: 0.930 P_acc: 0.981 T_acc: 0.932 Loss: 1.3771
 1    39/   39 B_acc: 0.808 V_acc: 0.697 G_acc: 0.799 PxT_acc: 0.914 P_acc: 0.979 T_acc: 0.914 Loss: 1.4610
48/52: [('B', 'P'), ('V',), ('G',), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 BxP_acc: 0.819 V_acc: 0.696 G_acc: 0.797 T_acc: 0.931 B_acc: 0.826 P_acc: 0.979 Loss: 1.3955
 1    39/   39 BxP_acc: 0.806 V_acc: 0.696 G_acc: 0.799 T_acc: 0.914 B_acc: 0.814 P_acc: 0.976 Loss: 1.4720
49/52: [('P',), ('B', 'V'), ('G',), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 P_acc: 0.981 BxV_acc: 0.572 G_acc: 0.793 T_acc: 0.932 B_acc: 0.820 V_acc: 0.693 Loss: 1.3747
 1    39/   39 P_acc: 0.978 BxV_acc: 0.579 G_acc: 0.801 T_acc: 0.914 B_acc: 0.810 V_acc: 0.695 Loss: 1.4440
50/52: [('P',), ('V',), ('B', 'G'), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 P_acc: 0.981 V_acc: 0.697 BxG_acc: 0.663 T_acc: 0.932 B_acc: 0.826 G_acc: 0.791 Loss: 1.4000
 1    39/   39 P_acc: 0.979 V_acc: 0.694 BxG_acc: 0.657 T_acc: 0.916 B_acc: 0.814 G_acc: 0.798 Loss: 1.4743
51/52: [('P',), ('V',), ('G',), ('B', 'T')]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 P_acc: 0.980 V_acc: 0.699 G_acc: 0.796 BxT_acc: 0.788 B_acc: 0.822 T_acc: 0.928 Loss: 1.3486
 1    39/   39 P_acc: 0.978 V_acc: 0.692 G_acc: 0.800 BxT_acc: 0.779 B_acc: 0.814 T_acc: 0.913 Loss: 1.3927
52/52: [('B',), ('P',), ('V',), ('G',), ('T',)]


{'optimizer': 'adam', 'batch_size': 128, 'epochs': 1, 'runsize': 128, 'test_size': 5000, 'units': 400, 'lr': 0.0008}
 1  5376/ 5470 B_acc: 0.827 P_acc: 0.979 V_acc: 0.696 G_acc: 0.793 T_acc: 0.930 Loss: 1.4110
 1    39/   39 B_acc: 0.808 P_acc: 0.978 V_acc: 0.694 G_acc: 0.800 T_acc: 0.914 Loss: 1.4965


TypeError: cannot unpack non-iterable NoneType object

In [None]:
# NOTE(review): this cell relies on names that no visible cell defines at module
# level: `CLASSES` (only `encoding.CLASSES` is used elsewhere), `sn` (no import
# is visible) and `stats` (only a local inside experiment()). It will fail
# on a fresh kernel unless those are provided.
k = 'R1'
# Each class label string is reversed before being used as a tick label.
labels = [x[::-1] for x in CLASSES[k]]
ax = sn.heatmap(stats.confusion[k], xticklabels=labels, yticklabels=labels, square=True, robust=True, cmap="cividis")

In [None]:
ax = sn.heatmap(stats.confusion_logprobs[k], xticklabels=labels, yticklabels=labels, square=True, robust=True, cmap="cividis")

In [None]:
print(predict(model, 'סבסו'))
print(predict(model, 'מקדו'))
print(predict(model, 'נמזר'))
print(predict(model, 'כרדו'))

In [None]:
print(predict(model, 'הבריל'))
print(predict(model, 'חגוו'))
print(predict(model, 'עגו'))
print(predict(model, 'צירלל'))

In [None]:
print(predict(model, "השטקרפתי"))

In [None]:
print(predict(model, "ישסו"))

In [None]:
import importlib
import encoding
import naive_model
import utils
# Force-reload the project modules (and wandb) and rebind the module names.
encoding = importlib.reload(encoding)
naive_model = importlib.reload(naive_model)
utils = importlib.reload(utils)
wandb = importlib.reload(wandb)
# Names pulled out of a module must be rebound by hand after a reload,
# otherwise they keep pointing at the pre-reload classes.
Stats = utils.Stats
NaiveModel = naive_model.NaiveModel

In [None]:
batch_all_ys(test_y)

In [None]:
test_y[('B', 'T')].shape

In [None]:
ravel_multi_index(pre_test_y, [('B', 'T')])

In [None]:
combination = ('B', 'T')
np.ravel_multi_index([pre_test_y[k] for k in combination], encoding.combined_shape(combination))

In [None]:
torch.cuda.is_available()

In [10]:
wandb.join()

In [None]:
def nonempty_powerset(seq):
    """Lazily yield every non-empty combination of ``seq``, shortest first.

    Combinations of size 1 come first, then size 2, and so on up to
    ``len(seq)``; each is a tuple as produced by ``itertools.combinations``.
    """
    sizes = range(1, len(seq) + 1)
    by_size = map(lambda r: itertools.combinations(seq, r), sizes)
    return itertools.chain.from_iterable(by_size)

def tensor_outer_product(a, b, c):
    """Batched outer product of three 2-D tensors sharing their first axis.

    For inputs of shape ``(b, i)``, ``(b, j)`` and ``(b, k)`` the result has
    shape ``(b, i, j, k)`` with ``out[n, i, j, k] = a[n, i] * b[n, j] * c[n, k]``.
    """
    equation = 'bi,bj,bk->bijk'
    operands = (a, b, c)
    return torch.einsum(equation, *operands)
