In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import wandb
import torch
import torch.nn as nn

import utils
from naive_model import NaiveModel
import encoding

In [32]:
# Size of the embedding table created in Model.__init__ (passed as num_embeddings).
NUM_EMBEDDING = 2000

# Fail fast on machines without a GPU: to_device() below calls .cuda() unconditionally.
assert torch.cuda.is_available()
# NOTE(review): presumably turns on cuDNN kernel autotuning — confirm it helps for these input shapes.
torch.backends.cudnn.benchmark = True
# Make CUDA float tensors the default type for newly created tensors.
torch.set_default_tensor_type('torch.cuda.FloatTensor')

def to_device(d):
    """Move `d` onto the GPU.

    If `d` exposes a ``cuda`` attribute it is moved directly and the result
    of ``d.cuda()`` is returned. Otherwise `d` is treated as a mapping and a
    new dict is returned in which every value has been moved with ``.cuda()``.
    """
    if not hasattr(d, 'cuda'):
        return {key: value.cuda() for key, value in d.items()}
    return d.cuda()

class Model(nn.Module):
    """Shared embedding feeding two bidirectional LSTMs, with one linear head per task.

    Each task in `combinations` gets its own ``nn.Linear`` head. Heads whose
    combination mentions a root slot ('R1'..'R4') are fed the difference of
    the two LSTM summaries; all other heads are fed their sum.

    Args:
        units: Embedding dimension, LSTM hidden size and head input size.
        combinations: Iterable of task identifiers; each is passed to
            ``encoding.class_size`` and ``encoding.class_name``.
    """

    def __init__(self, units, combinations=encoding.NAMES):
        super().__init__()
        self.units = units

        self.embed = nn.Embedding(num_embeddings=NUM_EMBEDDING, embedding_dim=units)

        self.lstm_root = nn.LSTM(input_size=units, hidden_size=units, num_layers=1, batch_first=True, bidirectional=True)
        self.lstm_nonroot = nn.LSTM(input_size=units, hidden_size=units, num_layers=1, batch_first=True, bidirectional=True)

        # `tasks` is a plain dict keyed by the combination itself, so the heads
        # are additionally attached with setattr to register them as submodules
        # (otherwise their parameters would not show up in model.parameters()).
        self.tasks = {}
        for combination in combinations:
            out = nn.Linear(in_features=units, out_features=encoding.class_size(combination))
            self.tasks[combination] = out
            setattr(self, encoding.class_name(combination), out)

    def combine(self, lstm_out):
        """Sum the final hidden states of the two LSTM directions.

        Args:
            lstm_out: The ``(output, (h_n, c_n))`` tuple returned by an LSTM.

        Returns:
            ``h_n`` of the forward direction plus ``h_n`` of the backward
            direction, with only the leading direction axis removed.
        """
        _, (h_n, _) = lstm_out
        left, right = torch.chunk(h_n, 2, dim=0)
        # Squeeze only the direction axis. A bare squeeze() would also drop the
        # batch axis whenever the batch holds a single sample.
        return torch.squeeze(left + right, dim=0)

    def isroot(self, combination):
        """Return True if `combination` contains any of 'R1', 'R2', 'R3', 'R4'."""
        return any(r in combination for r in ['R1', 'R2', 'R3', 'R4'])

    def forward(self, x):
        """Compute the scores of every task head.

        Args:
            x: Index tensor fed to the embedding layer
                (assumes shape (batch, sequence) — TODO confirm against encoding).

        Returns:
            Dict mapping each combination to the output of its linear head.
        """
        embeds = self.embed(x)

        root_merge = self.combine(self.lstm_root(embeds))
        noroot_merge = self.combine(self.lstm_nonroot(embeds))

        # Both mixtures are the same for every head, so build them once.
        difference = root_merge - noroot_merge
        total = root_merge + noroot_merge

        return {combination: f(difference if self.isroot(combination) else total)
                for combination, f in self.tasks.items()}


In [28]:
def sanity():
    """Smoke-test the model on two hard-coded words.

    Builds a 100-unit Model with the single ('B', 'T') task, prints it, runs
    the two words through it with gradients disabled and prints the raw
    scores of every task next to a hard-coded label dict. Nothing is
    asserted or returned; the output is meant to be inspected by eye.
    """
    model = to_device(Model(100, combinations=[('B', 'T')]))
    print(model)
    with torch.no_grad():
        encoded_words = encoding.wordlist2numpy(["אתאקלם", "יכפיל"])
        labels = {'B': torch.Tensor([3, 5]), 'T': torch.Tensor([2, 4])}
        word_tensor = to_device(torch.from_numpy(encoded_words).to(torch.int64))
        tag_scores = model(word_tensor)
        # The names `v` and `labels` are part of the printed text (f'{name=}').
        for combination, v in tag_scores.items():
            print(combination)
            print(f'{v=}')
            print(f'{labels=}')
            print()


In [35]:

# TEMP_PATH = 'model.pt'

#                 best_lr = 8e-4
#                 best_loss = 10
                
#                 torch.save({
#                     'state_dict': model.state_dict(),
#                     'optimizer': optimizer.state_dict(),
#                 }, TEMP_PATH)
                
#                 for i in range(1):
#                     checkpoint = torch.load(TEMP_PATH)
#                     model.load_state_dict(checkpoint['state_dict'])
#                     optimizer.load_state_dict(checkpoint['optimizer'])

def fit(model, train, test, *, epochs,  runsize, criterion, optimizer, phases, teacher, batch_size):
    """Run the train/evaluate loop and return the collected statistics.

    Args:
        model: Callable mapping a batch of inputs to a dict of per-task outputs;
            must expose ``tasks`` and ``train(mode=...)``.
        train, test: Datasets handed to ``utils.batch_xy`` together with `batch_size`.
        epochs: Number of passes over `phases`.
        runsize: Forwarded to ``utils.Stats``.
        criterion: Loss callable applied per task as ``criterion(output, target)``.
        optimizer: Optimizer stepped during the 'train' phase.
        phases: Iterable of phase names ('train' and/or 'test'), run in order each epoch.
        teacher: Optional callable producing pseudo labels from the inputs; when
            given, its outputs replace the dataset labels in the loss.
        batch_size: Batch size used to split both datasets.

    Returns:
        The ``utils.Stats`` object updated over all epochs.
    """
    train_inputs, train_labels = utils.batch_xy(train, batch_size)
    test_inputs, test_labels = utils.batch_xy(test, batch_size)
    batches = {
        'train': (train_inputs, train_labels),
        'test': (test_inputs, test_labels),
    }

    stats = utils.Stats(model.tasks.keys(), runsize)

    for _ in range(epochs):
        stats.epoch_start()

        for phase in phases:
            training = phase == 'train'
            model.train(mode=training)

            phase_inputs, phase_labels = batches[phase]
            stats.phase_start(phase, batches_in_phase=len(phase_inputs))

            for inputs, labels in zip(phase_inputs, phase_labels):
                stats.batch_start()

                inputs = to_device(inputs)
                labels = to_device(labels)

                # Gradients are only tracked through the forward pass when training.
                with utils.conditional_grad(training):
                    outputs = model(inputs)

                losses = {}
                if teacher is None:
                    for combination, output in outputs.items():
                        losses[combination] = criterion(output.double(), labels[combination])
                else:
                    pseudo_labels = teacher(inputs)
                    for k in outputs:
                        losses[k] = criterion(outputs[k].double(), pseudo_labels[k])

                if training:
                    stats.assert_reasonable_initial(losses, nn.CrossEntropyLoss)

                # The total loss is the plain sum of the per-task losses.
                loss = sum(losses.values())

                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # Statistics are always computed against the dataset labels,
                # even when a teacher supplied the loss targets.
                per_task = {k: (outputs[k], labels[k]) for k in outputs}
                stats.update(loss=loss.item(),
                             batch_size=inputs.size(0),
                             d=per_task)

                stats.batch_end()
            stats.phase_end()
        stats.epoch_end()
    return stats

def _decode_scores(outputs, row):
    """Decode the best-scoring class of every task for one input row.

    Args:
        outputs: Dict mapping each combination to its score tensor, as
            returned by the model.
        row: Index of the input (batch row) to decode.

    Returns:
        Dict mapping each feature name to its decoded value. A feature name
        that occurs in more than one combination gets "'" appended until the
        key is unique. If all of 'R1'..'R4' are present, 'R' holds their
        concatenation with '.' characters removed.
    """
    res = {}
    for combination, v in outputs.items():
        # Normalise to (rows, classes) so that scores without a batch axis
        # (single input) and batched scores are handled the same way.
        scores = v.reshape(-1, v.shape[-1])
        combined_index = torch.argmax(scores[row]).detach().cpu().numpy()
        indices = np.unravel_index(combined_index, encoding.combined_shape(combination))
        if isinstance(combination, str):
            combination = (combination,)
        for k, i in zip(combination, indices):
            s = k
            # Keep every overlapping prediction instead of overwriting an
            # earlier one: first stays `k`, later ones become k', k'', ...
            while s in res:
                s += "'"
            res[s] = encoding.from_category(k, i)
    if all(r in res for r in ['R1', 'R2', 'R3', 'R4']):
        res['R'] = ''.join(res[k] for k in ['R1', 'R2', 'R3', 'R4']).replace('.', '')
    return res


@torch.no_grad()
def predict(model, *verbs):
    """Predict the features of one or more words.

    The model is switched to eval mode and run without gradient tracking.
    The argmax is taken separately for every word; taking it over the whole
    batch at once would mix the rows of different words.

    Args:
        model: Callable mapping an index tensor to a dict of per-task scores.
        *verbs: Words to analyse, encoded with ``encoding.wordlist2numpy``.

    Returns:
        For exactly one word, a dict of decoded features (see
        ``_decode_scores``). For any other number of words, a list with one
        such dict per word, in input order.
    """
    model.eval()
    encoded = encoding.wordlist2numpy(verbs)
    encoded = to_device(torch.from_numpy(encoded).to(torch.int64))
    outputs = model(encoded)
    decoded = [_decode_scores(outputs, row) for row in range(len(verbs))]
    return decoded[0] if len(decoded) == 1 else decoded

In [10]:
# Seed torch and NumPy before loading/shuffling
# (presumably so the shuffle below is repeatable — depends on which RNG utils uses).
torch.manual_seed(0)
np.random.seed(0)

# Dataset selection: `gen` and `arity` form the artifact name, which in turn
# determines the TSV path under synthetic/.
arity = '3'
gen = 'all'
artifact_name = f'{gen}_{arity}_shufroot'
filename = f'synthetic/{artifact_name}.tsv'  # all_verbs_shuffled
test_size = 20000

# Dataset artifact pointing at the TSV; experiment() attaches it to every run.
artifact = wandb.Artifact(artifact_name, type='dataset')
artifact.add_file(filename)

# Returns (inputs, labels) pairs for train and test; the label objects expose
# .values(), so they are mappings (see the shuffle call below).
(train_x, pre_train_y), (test_x, pre_test_y) = encoding.load_dataset_split(filename, split=test_size)

# The return value is ignored, so this presumably shuffles the training inputs
# and every label array in place with one shared permutation — TODO confirm in utils.
utils.shuffle_in_unison([train_x, *pre_train_y.values()])

In [11]:
def naive_config(filename):
    """Build the settings for evaluating the naive baseline learned from `filename`.

    The result holds the baseline model, a test-only phase list, a
    cross-entropy criterion and no optimizer. It has no 'teacher' entry.
    """
    baseline = NaiveModel.learn_from_file(filename)
    return dict(
        model=baseline,
        phases=['test'],
        criterion=nn.CrossEntropyLoss(),
        optimizer=None,
    )

def standard_config(model, lr):
    """Build the default supervised settings for `model`.

    Trains and then tests each epoch with a cross-entropy criterion and no
    teacher. `lr` is accepted for signature symmetry but is not used here.
    """
    settings = dict(model=model)
    settings['criterion'] = nn.CrossEntropyLoss()
    settings['phases'] = ['train', 'test']
    settings['teacher'] = None
    return settings

def teacher_config(train, model, lr):
    """Build the settings for training `model` against a naive teacher.

    Starts from ``standard_config(model, lr)`` and replaces the teacher with
    a NaiveModel learned from `train` and the criterion with
    ``BCEWithLogitsLoss``.
    """
    return {
        **standard_config(model, lr),
        'teacher': NaiveModel.learn_from_data(train),
        # Plain BCELoss also runs, but the total loss comes out as nan.
        'criterion': nn.BCEWithLogitsLoss(),
    }

In [39]:
%env WANDB_SILENT true
%env WANDB_MODE run

# Run configuration shared by all experiments. experiment() updates this dict
# in place (runsize, batch_size, units, lr, optimizer) before handing it to wandb.
config = {
    'epochs': 1,
    'test_size': test_size,
}

def experiment(units, lr, batch_size):
    """Train and evaluate one Model configuration inside its own wandb run.

    Relies on the notebook globals `train_x`, `pre_train_y`, `test_x`,
    `pre_test_y`, `config`, `artifact`, `gen` and `arity`, and updates
    `config` in place with this run's hyper-parameters.

    Args:
        units: Model width passed to ``Model``.
        lr: Learning rate of the Adam optimizer.
        batch_size: Batch size used by ``fit``.

    Returns:
        Tuple ``(model, stats)`` with the trained model and the statistics
        object returned by ``fit``.
    """
    # Re-seed for every run so that runs differ only in their hyper-parameters.
    torch.manual_seed(1)
    np.random.seed(1)

    combinations = list(encoding.CLASSES)

    train = (train_x, utils.ravel_multi_index(pre_train_y, combinations))
    test = (test_x, utils.ravel_multi_index(pre_test_y, combinations))

    config.update(
        runsize=2 * 8192 // batch_size,
        batch_size=batch_size,
        units=units,
        lr=lr,
        optimizer='adam',
    )

    model = to_device(Model(units=units, combinations=combinations))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Human-readable task list; not used further in this function.
    names_str = '+'.join([encoding.class_name(combination) for combination in combinations])

    run_name = f'sub_add-{units}-{lr:.0e}-{batch_size}'
    run_tags = [gen, arity, 'synthetic', 'shuffle-root', 'no_prefix', 'shuffle']
    run = wandb.init(project="rootem",
                     group='separate-lstm',
                     name=run_name,
                     tags=run_tags,
                     config=config)
    with run:
        run.use_artifact(artifact)

        wandb.config.update(config, allow_val_change=True)

        print(config)

        if isinstance(model, nn.Module):
            wandb.watch(model)

        fit_settings = standard_config(model, lr)
        stats = fit(train=train,
                    test=test,
                    epochs=config['epochs'],
                    runsize=config['runsize'],
                    optimizer=optimizer,
                    batch_size=batch_size,
                    **fit_settings)
        wandb.save(f"simple_{arity}.h5")

    return model, stats

# Grid search over batch size and learning rate at a fixed width.
# Only the model and stats of the last configuration remain bound afterwards;
# the cells below use that last `model`.
units = 150
for batch_size in [32, 64, 128, 256, 512]:
    for lr in [3e-4, 1e-3, 3e-3, 5e-3]:
        model, stats = experiment(units, lr, batch_size)


env: WANDB_SILENT=true
env: WANDB_MODE=run


{'epochs': 1, 'test_size': 20000, 'runsize': 512, 'batch_size': 32, 'units': 150, 'lr': 0.0003, 'optimizer': 'adam'}
 1 20992/21412 B_acc: 0.822 T_acc: 0.928 V_acc: 0.693 G_acc: 0.794 P_acc: 0.980 R1_acc: 0.968 R2_acc: 0.821 R3_acc: 1.000 R4_acc: 0.981 Loss: 1.9948
 1   625/  625 B_acc: 0.810 T_acc: 0.922 V_acc: 0.697 G_acc: 0.796 P_acc: 0.978 R1_acc: 0.942 R2_acc: 0.798 R3_acc: 1.000 R4_acc: 0.946 Loss: 2.3877


{'epochs': 1, 'test_size': 20000, 'runsize': 512, 'batch_size': 32, 'units': 150, 'lr': 0.001, 'optimizer': 'adam'}
 1 20992/21412 B_acc: 0.821 T_acc: 0.929 V_acc: 0.694 G_acc: 0.793 P_acc: 0.980 R1_acc: 0.970 R2_acc: 0.817 R3_acc: 1.000 R4_acc: 0.982 Loss: 1.9607
 1   625/  625 B_acc: 0.813 T_acc: 0.919 V_acc: 0.698 G_acc: 0.794 P_acc: 0.978 R1_acc: 0.934 R2_acc: 0.774 R3_acc: 1.000 R4_acc: 0.944 Loss: 2.5145


{'epochs': 1, 'test_size': 20000, 'runsize': 512, 'batch_size': 32, 'units': 150, 'lr': 0.003, 'optimizer': 'adam'}
 1 20992/21412 B_acc: 0.821 T_acc: 0.926 V_acc: 0.691 G_acc: 0.789 P_acc: 0.979 R1_acc: 0.964 R2_acc: 0.812 R3_acc: 1.000 R4_acc: 0.977 Loss: 2.1055
 1   625/  625 B_acc: 0.816 T_acc: 0.920 V_acc: 0.699 G_acc: 0.794 P_acc: 0.977 R1_acc: 0.929 R2_acc: 0.770 R3_acc: 1.000 R4_acc: 0.948 Loss: 2.6879


{'epochs': 1, 'test_size': 20000, 'runsize': 512, 'batch_size': 32, 'units': 150, 'lr': 0.005, 'optimizer': 'adam'}
 1 20992/21412 B_acc: 0.812 T_acc: 0.923 V_acc: 0.690 G_acc: 0.788 P_acc: 0.977 R1_acc: 0.959 R2_acc: 0.799 R3_acc: 1.000 R4_acc: 0.974 Loss: 2.2782
 1   625/  625 B_acc: 0.808 T_acc: 0.915 V_acc: 0.689 G_acc: 0.788 P_acc: 0.977 R1_acc: 0.938 R2_acc: 0.796 R3_acc: 1.000 R4_acc: 0.946 Loss: 2.6837


{'epochs': 1, 'test_size': 20000, 'runsize': 256, 'batch_size': 64, 'units': 150, 'lr': 0.0003, 'optimizer': 'adam'}
 1 10496/10706 B_acc: 0.817 T_acc: 0.926 V_acc: 0.694 G_acc: 0.794 P_acc: 0.978 R1_acc: 0.965 R2_acc: 0.814 R3_acc: 1.000 R4_acc: 0.979 Loss: 2.05961
 1   312/  312 B_acc: 0.807 T_acc: 0.922 V_acc: 0.697 G_acc: 0.796 P_acc: 0.977 R1_acc: 0.948 R2_acc: 0.790 R3_acc: 1.000 R4_acc: 0.955 Loss: 2.3764


{'epochs': 1, 'test_size': 20000, 'runsize': 256, 'batch_size': 64, 'units': 150, 'lr': 0.001, 'optimizer': 'adam'}
 1 10496/10706 B_acc: 0.822 T_acc: 0.928 V_acc: 0.694 G_acc: 0.794 P_acc: 0.979 R1_acc: 0.970 R2_acc: 0.820 R3_acc: 1.000 R4_acc: 0.983 Loss: 1.9604
 1   312/  312 B_acc: 0.815 T_acc: 0.924 V_acc: 0.699 G_acc: 0.796 P_acc: 0.978 R1_acc: 0.942 R2_acc: 0.794 R3_acc: 1.000 R4_acc: 0.952 Loss: 2.4073


{'epochs': 1, 'test_size': 20000, 'runsize': 256, 'batch_size': 64, 'units': 150, 'lr': 0.003, 'optimizer': 'adam'}
 1 10496/10706 B_acc: 0.822 T_acc: 0.923 V_acc: 0.696 G_acc: 0.795 P_acc: 0.979 R1_acc: 0.967 R2_acc: 0.813 R3_acc: 1.000 R4_acc: 0.981 Loss: 2.0496
 1   312/  312 B_acc: 0.810 T_acc: 0.921 V_acc: 0.699 G_acc: 0.795 P_acc: 0.978 R1_acc: 0.937 R2_acc: 0.788 R3_acc: 1.000 R4_acc: 0.944 Loss: 2.5773


{'epochs': 1, 'test_size': 20000, 'runsize': 256, 'batch_size': 64, 'units': 150, 'lr': 0.005, 'optimizer': 'adam'}
 1 10496/10706 B_acc: 0.817 T_acc: 0.922 V_acc: 0.691 G_acc: 0.791 P_acc: 0.976 R1_acc: 0.963 R2_acc: 0.806 R3_acc: 1.000 R4_acc: 0.974 Loss: 2.1819
 1   312/  312 B_acc: 0.806 T_acc: 0.919 V_acc: 0.697 G_acc: 0.791 P_acc: 0.977 R1_acc: 0.936 R2_acc: 0.793 R3_acc: 1.000 R4_acc: 0.934 Loss: 2.6740


{'epochs': 1, 'test_size': 20000, 'runsize': 128, 'batch_size': 128, 'units': 150, 'lr': 0.0003, 'optimizer': 'adam'}
 1  5248/ 5353 B_acc: 0.812 T_acc: 0.924 V_acc: 0.693 G_acc: 0.794 P_acc: 0.978 R1_acc: 0.961 R2_acc: 0.806 R3_acc: 1.000 R4_acc: 0.975 Loss: 2.15666
 1   156/  156 B_acc: 0.803 T_acc: 0.918 V_acc: 0.697 G_acc: 0.792 P_acc: 0.975 R1_acc: 0.949 R2_acc: 0.775 R3_acc: 1.000 R4_acc: 0.953 Loss: 2.4489


{'epochs': 1, 'test_size': 20000, 'runsize': 128, 'batch_size': 128, 'units': 150, 'lr': 0.001, 'optimizer': 'adam'}
 1  5248/ 5353 B_acc: 0.820 T_acc: 0.925 V_acc: 0.697 G_acc: 0.795 P_acc: 0.980 R1_acc: 0.969 R2_acc: 0.817 R3_acc: 1.000 R4_acc: 0.982 Loss: 1.9765
 1   156/  156 B_acc: 0.808 T_acc: 0.922 V_acc: 0.699 G_acc: 0.798 P_acc: 0.975 R1_acc: 0.945 R2_acc: 0.782 R3_acc: 1.000 R4_acc: 0.957 Loss: 2.3830


{'epochs': 1, 'test_size': 20000, 'runsize': 128, 'batch_size': 128, 'units': 150, 'lr': 0.003, 'optimizer': 'adam'}
 1  5248/ 5353 B_acc: 0.825 T_acc: 0.927 V_acc: 0.693 G_acc: 0.794 P_acc: 0.979 R1_acc: 0.968 R2_acc: 0.814 R3_acc: 1.000 R4_acc: 0.980 Loss: 2.0143
 1   156/  156 B_acc: 0.811 T_acc: 0.923 V_acc: 0.699 G_acc: 0.794 P_acc: 0.979 R1_acc: 0.945 R2_acc: 0.774 R3_acc: 1.000 R4_acc: 0.946 Loss: 2.5422


{'epochs': 1, 'test_size': 20000, 'runsize': 128, 'batch_size': 128, 'units': 150, 'lr': 0.005, 'optimizer': 'adam'}
 1  5248/ 5353 B_acc: 0.821 T_acc: 0.925 V_acc: 0.692 G_acc: 0.794 P_acc: 0.979 R1_acc: 0.966 R2_acc: 0.810 R3_acc: 1.000 R4_acc: 0.978 Loss: 2.0900
 1   156/  156 B_acc: 0.808 T_acc: 0.921 V_acc: 0.697 G_acc: 0.793 P_acc: 0.976 R1_acc: 0.948 R2_acc: 0.790 R3_acc: 1.000 R4_acc: 0.949 Loss: 2.5587


{'epochs': 1, 'test_size': 20000, 'runsize': 64, 'batch_size': 256, 'units': 150, 'lr': 0.0003, 'optimizer': 'adam'}
 1  2624/ 2676 B_acc: 0.811 T_acc: 0.922 V_acc: 0.692 G_acc: 0.793 P_acc: 0.976 R1_acc: 0.954 R2_acc: 0.797 R3_acc: 1.000 R4_acc: 0.973 Loss: 2.32630
 1    78/   78 B_acc: 0.793 T_acc: 0.917 V_acc: 0.695 G_acc: 0.792 P_acc: 0.976 R1_acc: 0.941 R2_acc: 0.782 R3_acc: 1.000 R4_acc: 0.954 Loss: 2.5738


{'epochs': 1, 'test_size': 20000, 'runsize': 64, 'batch_size': 256, 'units': 150, 'lr': 0.001, 'optimizer': 'adam'}
 1  2624/ 2676 B_acc: 0.819 T_acc: 0.925 V_acc: 0.695 G_acc: 0.794 P_acc: 0.980 R1_acc: 0.967 R2_acc: 0.813 R3_acc: 1.000 R4_acc: 0.979 Loss: 2.03110
 1    78/   78 B_acc: 0.811 T_acc: 0.922 V_acc: 0.698 G_acc: 0.796 P_acc: 0.975 R1_acc: 0.945 R2_acc: 0.795 R3_acc: 1.000 R4_acc: 0.960 Loss: 2.3441


{'epochs': 1, 'test_size': 20000, 'runsize': 64, 'batch_size': 256, 'units': 150, 'lr': 0.003, 'optimizer': 'adam'}
 1  2624/ 2676 B_acc: 0.820 T_acc: 0.926 V_acc: 0.695 G_acc: 0.793 P_acc: 0.979 R1_acc: 0.968 R2_acc: 0.814 R3_acc: 1.000 R4_acc: 0.980 Loss: 2.0058
 1    78/   78 B_acc: 0.812 T_acc: 0.918 V_acc: 0.696 G_acc: 0.794 P_acc: 0.977 R1_acc: 0.935 R2_acc: 0.772 R3_acc: 1.000 R4_acc: 0.953 Loss: 2.5196


{'epochs': 1, 'test_size': 20000, 'runsize': 64, 'batch_size': 256, 'units': 150, 'lr': 0.005, 'optimizer': 'adam'}
 1  2624/ 2676 B_acc: 0.823 T_acc: 0.926 V_acc: 0.691 G_acc: 0.794 P_acc: 0.981 R1_acc: 0.969 R2_acc: 0.815 R3_acc: 1.000 R4_acc: 0.979 Loss: 2.0278
 1    78/   78 B_acc: 0.809 T_acc: 0.918 V_acc: 0.698 G_acc: 0.794 P_acc: 0.978 R1_acc: 0.941 R2_acc: 0.769 R3_acc: 1.000 R4_acc: 0.954 Loss: 2.5581


{'epochs': 1, 'test_size': 20000, 'runsize': 32, 'batch_size': 512, 'units': 150, 'lr': 0.0003, 'optimizer': 'adam'}
 1  1312/ 1338 B_acc: 0.796 T_acc: 0.919 V_acc: 0.685 G_acc: 0.790 P_acc: 0.975 R1_acc: 0.940 R2_acc: 0.772 R3_acc: 1.000 R4_acc: 0.959 Loss: 2.64987
 1    39/   39 B_acc: 0.785 T_acc: 0.912 V_acc: 0.691 G_acc: 0.792 P_acc: 0.973 R1_acc: 0.927 R2_acc: 0.744 R3_acc: 1.000 R4_acc: 0.930 Loss: 2.8993


{'epochs': 1, 'test_size': 20000, 'runsize': 32, 'batch_size': 512, 'units': 150, 'lr': 0.001, 'optimizer': 'adam'}
 1  1312/ 1338 B_acc: 0.816 T_acc: 0.924 V_acc: 0.695 G_acc: 0.795 P_acc: 0.979 R1_acc: 0.965 R2_acc: 0.810 R3_acc: 1.000 R4_acc: 0.977 Loss: 2.11007
 1    39/   39 B_acc: 0.812 T_acc: 0.918 V_acc: 0.697 G_acc: 0.794 P_acc: 0.976 R1_acc: 0.947 R2_acc: 0.779 R3_acc: 1.000 R4_acc: 0.957 Loss: 2.4012


{'epochs': 1, 'test_size': 20000, 'runsize': 32, 'batch_size': 512, 'units': 150, 'lr': 0.003, 'optimizer': 'adam'}
 1  1312/ 1338 B_acc: 0.823 T_acc: 0.926 V_acc: 0.692 G_acc: 0.795 P_acc: 0.979 R1_acc: 0.968 R2_acc: 0.812 R3_acc: 1.000 R4_acc: 0.981 Loss: 2.02732
 1    39/   39 B_acc: 0.809 T_acc: 0.921 V_acc: 0.698 G_acc: 0.795 P_acc: 0.978 R1_acc: 0.938 R2_acc: 0.794 R3_acc: 1.000 R4_acc: 0.953 Loss: 2.4805


{'epochs': 1, 'test_size': 20000, 'runsize': 32, 'batch_size': 512, 'units': 150, 'lr': 0.005, 'optimizer': 'adam'}
 1  1312/ 1338 B_acc: 0.820 T_acc: 0.925 V_acc: 0.691 G_acc: 0.795 P_acc: 0.980 R1_acc: 0.968 R2_acc: 0.814 R3_acc: 1.000 R4_acc: 0.981 Loss: 2.0325
 1    39/   39 B_acc: 0.807 T_acc: 0.921 V_acc: 0.698 G_acc: 0.790 P_acc: 0.978 R1_acc: 0.946 R2_acc: 0.773 R3_acc: 1.000 R4_acc: 0.955 Loss: 2.4832


In [None]:
def print_words(model):
    """Print the decoded prediction of `model` for a fixed set of sample words."""
    samples = ('הבריל', 'חגוו', 'עגו', 'צירלל', "השטקרפתי", "ישסו")
    for word in samples:
        print(predict(model, word))
print_words(model)

In [None]:
# Blocks until the user submits two lines at the "$ " prompt; the entered text is discarded.
# NOTE(review): presumably a manual pause before the final cell — this stalls "Run All"; confirm it is still wanted.
input("$ ")
input("$ ")

In [None]:
# NOTE(review): wandb.join appears to be a legacy alias of wandb.finish; runs opened in
# experiment() are already closed by their `with run:` block — confirm this call is still needed.
wandb.join()

NameError: name 'BATCH_SIZE' is not defined