In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import numpy as np
import torch
import torch.nn as nn

import pytorch_lightning as pl

import verbs
import neural

In [5]:
class DataModule(pl.LightningDataModule):
    """Lightning data module serving a train/validation split of one verb dataset.

    The whole file is loaded eagerly in ``__init__`` through
    ``verbs.load_dataset`` and wrapped in a ``TensorDataset`` whose first
    tensor holds the inputs and whose remaining tensors hold one int64 label
    array per entry of the loaded ``y`` collection.

    Args:
        train_filename: path handed to ``verbs.load_dataset``.
        features: stored on the instance as ``self.features``; it is not used
            anywhere else in this class.
        batch_size: batch size used by every dataloader.
        validation_size: number of samples held out for validation.
    """

    def __init__(self, train_filename, features, batch_size=64, validation_size=10000):
        super().__init__()
        WORD_MAXLEN = 11
        self.validation_size = validation_size
        self.batch_size = batch_size
        self.features = features
        # No test split is loaded by this module. Keeping the attribute defined
        # lets setup() and test_dataloader() fail gracefully instead of raising
        # AttributeError.
        self.data_test = None

        # Seed the global RNGs; random_split() in setup() draws from the
        # global torch generator.
        torch.manual_seed(0)
        np.random.seed(0)

        train_data_x, train_data_y = verbs.load_dataset(train_filename, word_maxlen=WORD_MAXLEN)

        label_tensors = [torch.Tensor(y).to(torch.int64) for y in train_data_y]
        self.data_train_full = torch.utils.data.TensorDataset(torch.Tensor(train_data_x), *label_tensors)

    def prepare_data(self):
        # No state assignment here
        pass

    def setup(self, stage=None):
        """Create the train/validation subsets and record the input dims.

        Raises:
            ValueError: if ``validation_size`` is negative or leaves no
                training samples.
        """
        if stage == 'fit' or stage is None:
            n_total = len(self.data_train_full)
            n_val = self.validation_size
            if not 0 <= n_val < n_total:
                raise ValueError(
                    f"validation_size must be in [0, {n_total}), got {n_val}")
            # random_split returns subsets in the order of the given lengths:
            # the training subset comes first and gets everything that is not
            # held out for validation.
            self.data_train, self.data_val = torch.utils.data.random_split(
                self.data_train_full, [n_total - n_val, n_val])
            self.dims = tuple(self.data_train[0][0].shape)

        # Only touch the test split when one has actually been provided.
        if (stage == 'test' or stage is None) and self.data_test is not None:
            self.dims = tuple(self.data_test[0][0].shape)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.data_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.data_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return a loader over the test split.

        Raises:
            RuntimeError: if no test dataset has been assigned to
                ``self.data_test``.
        """
        if self.data_test is None:
            raise RuntimeError("DataModule has no test dataset: assign `data_test` before testing")
        return torch.utils.data.DataLoader(self.data_test, batch_size=self.batch_size)

In [8]:
NUM_EMBEDDING = 2000


class IndependentModel(neural.UdModel):
    """Character-level model with one independent linear head per label.

    The input indices are embedded, passed through ``neural.SumBiLSTM`` and
    the second value it returns is fed to every head in ``self.tasks``.

    Args:
        label_names: iterable of label names; each gets its own ``nn.Linear``
            head sized by ``verbs.Verb.class_size(label_name)``.
        units: embedding dimension, LSTM size and head input size.
        learning_rate: Adam learning rate, read back through ``self.hparams``.
    """

    def __init__(self, label_names, units=400, learning_rate=2e-3):
        super().__init__()
        # Must run before anything else so the constructor arguments are
        # captured as hyperparameters.
        self.save_hyperparameters()

        self.units = units
        self.embed = nn.Embedding(NUM_EMBEDDING, units)
        self.lstm = neural.SumBiLSTM(units)

        # One classification head per label, created in iteration order.
        heads = {}
        for label_name in label_names:
            heads[label_name] = nn.Linear(units, verbs.Verb.class_size(label_name))
        self.tasks = nn.ModuleDict(heads)

    def forward(self, x):
        """Return a dict mapping each task name to its permuted head output.

        ``x`` is expected as (BATCH_SIZE, WORD_MAXLEN) and is transposed to
        sequence-first before embedding.
        """
        # (BATCH_SIZE, WORD_MAXLEN) -> (WORD_MAXLEN, BATCH_SIZE)
        seq_first = x.permute(1, 0)

        # (WORD_MAXLEN, BATCH_SIZE, UNITS)
        embeds = self.embed(seq_first)

        _, char_hidden = self.lstm(embeds)
        # NOTE(review): the original comment gives char_hidden as
        # (BATCH_SIZE, UNITS), but the 3-axis permute below needs a 3-D
        # tensor - confirm what SumBiLSTM actually returns.

        logits = {}
        for name, linear in self.tasks.items():
            logits[name] = linear(char_hidden).permute(0, 2, 1)
        return logits

    def configure_optimizers(self):
        """Adam over all parameters with the saved learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    

In [None]:
val_size = 300


def load_dataset(corpus_name, artifact_name):
    """Load one corpus file, split off validation data and shuffle the train part.

    The file read is ``{corpus_name}/{artifact_name}.tsv``. It is also
    registered in a wandb ``dataset`` artifact named ``artifact_name``.

    Returns:
        ``((train_x, train_y), (val_x, val_y), artifact)`` where the ``*_y``
        values are the mappings produced by ``encoding.load_dataset_split``.

    NOTE(review): ``wandb``, ``encoding`` and ``utils`` are not imported in the
    visible import cell - confirm they are in scope on a fresh kernel.
    """
    # Reset both global RNGs so the shuffle below is repeatable on every call.
    torch.manual_seed(0)
    np.random.seed(0)

    tsv_path = f'{corpus_name}/{artifact_name}.tsv'  # all_verbs_shuffled

    dataset_artifact = wandb.Artifact(artifact_name, type='dataset')
    dataset_artifact.add_file(tsv_path)

    # The module-level val_size is forwarded as the `split` argument.
    train_part, val_part = encoding.load_dataset_split(tsv_path, split=val_size)
    train_x, pre_train_y = train_part
    val_x, pre_val_y = val_part

    # The return value is discarded, so this presumably shuffles in place,
    # keeping inputs and every label array aligned - confirm in utils.
    utils.shuffle_in_unison([train_x, *pre_train_y.values()])
    return (train_x, pre_train_y), (val_x, pre_val_y), dataset_artifact


# Load the UD corpus from 'ud/nocontext-train.tsv'.
# NOTE(review): `arity` and `gen` are module-level names that `experiment`
# reads when it is called (for the wandb tags), so every later run is tagged
# with the LAST values assigned in this cell ('combined' / 'all_pref'),
# including runs on the UD corpus.
corpus_name = 'ud'
arity = 'combined'
gen = 'train'
artifact_name = f'nocontext-{gen}'
ud_corpus = load_dataset(corpus_name, artifact_name)

# Load the synthetic corpus from 'synthetic/all_pref_combined_shufroot.tsv'.
# `experiment` compares its argument to this object by identity to decide
# whether to build a fresh model or load the pretrained one.
corpus_name = 'synthetic'
arity = 'combined'
gen = 'all_pref'
artifact_name = f'{gen}_{arity}_shufroot'
synthetic_corpus = load_dataset(corpus_name, artifact_name)

In [None]:

# TEMP_PATH = 'model.pt'

#                 best_lr = 8e-4
#                 best_loss = 10
                
#                 torch.save({
#                     'state_dict': model.state_dict(),
#                     'optimizer': optimizer.state_dict(),
#                 }, TEMP_PATH)
                
#                 for i in range(1):
#                     checkpoint = torch.load(TEMP_PATH)
#                     model.load_state_dict(checkpoint['state_dict'])
#                     optimizer.load_state_dict(checkpoint['optimizer'])

def fit(model, train, test, *, epochs, runsize, criterion, optimizer, batch_size, **_):
    """Train ``model`` and periodically evaluate it on the validation batches.

    Args:
        model: module whose call returns a dict of per-task outputs and which
            exposes ``tasks`` (its keys name the tracked tasks).
        train: pair ``(batched_inputs, batched_labels)`` iterated in lockstep.
        test: pair ``(batched_inputs, batched_labels)`` used for validation.
        epochs: number of passes over ``train``.
        runsize: validate and log every ``runsize`` training batches, and also
            after the last batch of each epoch.
        criterion: loss applied per task to ``output.double()`` and the
            matching entry of the label mapping.
        optimizer: optimizer stepped once per training batch.
        batch_size: accepted for signature compatibility; not used here.
        **_: extra keyword arguments are ignored, so a whole config dict can
            be splatted into the call.
    """
    train_x, train_y = train
    valx, valy = test

    # Wrapped in utils.Once - presumably so the sanity check only fires on
    # the first call; confirm in utils.
    assert_reasonable_initial = utils.Once(utils.assert_reasonable_initial)

    def per_task_losses(outputs, labels):
        # One loss per task, keyed exactly like the model outputs.
        return {combination: criterion(output.double(), labels[combination])
                for combination, output in outputs.items()}

    for epoch in range(epochs):
        train_stats = utils.Stats(model.tasks.keys())
        nbatches = len(train_x)

        for batch, (inputs, labels) in enumerate(zip(train_x, train_y), start=1):
            # Validation below switches to eval mode, so re-enter train mode
            # at the start of every step.
            model.train()

            inputs = to_device(inputs)
            labels = to_device(labels)

            outputs = model(inputs)
            losses = per_task_losses(outputs, labels)
            loss = sum(losses.values())

            assert_reasonable_initial(losses, nn.CrossEntropyLoss)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_stats.update(loss=loss.item(),
                               batch_size=inputs.size(0),
                               outputs=outputs,
                               labels=labels)

            # Only validate on every `runsize`-th batch and on the final one.
            if batch % runsize != 0 and batch != nbatches:
                continue

            model.eval()
            valstats = utils.Stats(model.tasks.keys())
            for val_inputs, val_labels in zip(valx, valy):
                val_inputs = to_device(val_inputs)
                val_labels = to_device(val_labels)

                with torch.no_grad():
                    val_outputs = model(val_inputs)

                val_loss = sum(per_task_losses(val_outputs, val_labels).values())

                valstats.update(loss=val_loss.item(),
                                batch_size=val_inputs.size(0),
                                outputs=val_outputs,
                                labels=val_labels)

            utils.log(train_stats, valstats, batch, nbatches, epoch)


In [None]:
%env WANDB_SILENT true

def experiment(corpus, config, combinations=encoding.NAMES, names_str=''):
    """Run one wandb-tracked training experiment on ``corpus``.

    For the synthetic corpus a fresh model is built and saved to
    ``models/pretrain.pt`` after training. For any other corpus the model is
    loaded from ``models/pretrain.pt``, trained further and saved to
    ``models/postrain.pt``.

    Args:
        corpus: triple ``((train_x, train_y), (val_x, val_y), artifact)`` as
            returned by ``load_dataset``.
        config: hyperparameter mapping; must contain ``batch_size``, ``units``,
            ``lr``, ``weight_decay`` and whatever ``fit`` requires (e.g.
            ``epochs``). It is NOT modified: the run-specific entries
            (``runsize``, ``optimizer``, ``criterion``, ``model``) are added
            to a private copy.
        combinations: label combinations to train on.
        names_str: currently unused.

    Returns:
        The trained model.

    NOTE(review): ``Model``, ``to_device``, ``wandb``, ``utils`` and
    ``encoding`` are not defined or imported in the visible cells, and the
    wandb tags read the module-level ``gen`` and ``arity`` at call time -
    confirm all of these are in scope on a fresh kernel.
    """
    print(config)

    torch.manual_seed(1)
    np.random.seed(1)

    (train_x, pre_train_y), (valx, pre_valy), artifact = corpus

    train_y = utils.ravel_multi_index(pre_train_y, combinations)
    valy = utils.ravel_multi_index(pre_valy, combinations)

    train = utils.batch_xy((train_x, train_y), config['batch_size'])
    test = utils.batch_xy((valx, valy), config['batch_size'])

    is_pretraining = corpus is synthetic_corpus
    if is_pretraining:
        model = to_device(Model(units=config['units'], combinations=combinations))  # NaiveModel.learn_from_file(filename)
    else:
        # SECURITY: torch.load unpickles arbitrary objects; only load
        # checkpoints produced by this notebook.
        model = torch.load("models/pretrain.pt")

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])

    # Work on a copy: the caller reuses one config dict for several runs, and
    # updating it in place would leak this run's model and optimizer into the
    # next run's printed and logged configuration.
    run_config = dict(config)
    run_config.update({
        'runsize': 2 * 8192 // config['batch_size'],
        'optimizer': optimizer,
        'criterion': nn.CrossEntropyLoss(),
        'model': model,
    })

#     names_str = '+'.join(encoding.class_name(combination) for combination in combinations if combination not in encoding.NONROOTS)
#     if len(combinations) <= 3:
#         names_str += '_only'
    run = wandb.init(project="rootem",
                     group=f'ud',  # f'lr_units_grid_search-{arity}-{wandb.util.generate_id()}',
                     name=f"pretrained-batch_{config['batch_size']}",  # {model.arch}-{config['units']}-{config['lr']:.0e}-{config['batch_size']} f'{gen}-{arity}-{lr:.0e}',# f'{arity}-batch_{BATCH_SIZE}', # f'all-{arity}-lr_{lr:.0e}-units_{units}',
                     tags=[gen, arity, "ud", 'shuffle-root', 'shuffle', 'batchval', 'full-root'],
                     config=run_config)
    with run:
        run.use_artifact(artifact)

        wandb.config.update(run_config, allow_val_change=True)

#         if isinstance(model, nn.Module):
#             wandb.watch(model)

        # `model`, `optimizer`, `criterion`, `runsize`, `epochs` and
        # `batch_size` all reach fit() through the splatted config.
        fit(train=train,
            test=test,
            **run_config
        )
        wandb.save(f"{model.arch}.h5")

        if is_pretraining:
            torch.save(model, "models/pretrain.pt")
        else:
            torch.save(model, "models/postrain.pt")

    return model

%env WANDB_MODE dryrun

# Hyperparameters shared by the pretraining and fine-tuning runs.
config = {
    'epochs': 1,
    # The notebook defines `val_size` (used for the dataset split); the
    # previous `valsize` was an undefined name and raised NameError on a
    # fresh kernel.
    'valsize': val_size,
    'batch_size': 128,
    'units': 350,
    'weight_decay': 7e-4,
    'dropout': 0.2,
    'num_layers': 1,
    'lr': 1e-3,
}
# Pretrain on the synthetic corpus first: that run writes
# models/pretrain.pt, which the UD run then loads and continues training.
model = experiment(synthetic_corpus, config)
model = experiment(ud_corpus, config)

In [None]:

@torch.no_grad()
def predict(model, *verbs):
    """Decode the model's prediction for the first given verb into a string.

    The verbs are repeated 128 times before encoding, and only row 0 of every
    task output is decoded, so only the first verb's prediction is reported.
    (The reason for the 128x repetition is not visible here - presumably to
    form a full-size batch; confirm.)

    Each task output is arg-maxed, unravelled into one index per class of its
    combination and mapped back to a label with ``encoding.from_category``.
    When all of R1..R4 are present they are joined (dots removed) into a
    single ``R`` entry and left out of the result individually.

    Returns:
        Tab-separated values, each right-aligned to width 6.

    Note: inside this function the parameter ``verbs`` shadows the imported
    ``verbs`` module.
    """
    model.eval()
    encoded = encoding.wordlist2numpy(verbs * 128)
    batch = to_device(torch.from_numpy(encoded).to(torch.int64))

    radical_keys = ['R1', 'R2', 'R3', 'R4']
    res = {}
    # FIX: assumes no overlaps
    for combination, task_scores in model(batch).items():
        scores = task_scores[0]  # first row of the batch only
        if isinstance(combination, str):
            combination = (combination,)
        shape = encoding.combined_shape(combination)
        combined_index = scores.argmax().cpu().data.numpy()
        indices = np.unravel_index(combined_index, shape)
        for k, i in zip(combination, indices):
            # assert k not in res, "Overlapping classes are not handled"
            # A class predicted by more than one head gets a primed key.
            s = k + "'" if k in res else k
            res[s] = encoding.from_category(k, i)

    if all(r in res for r in radical_keys):
        res['R'] = ''.join(res[k] for k in radical_keys).replace('.', '')

    shown = [f'{v:>6}' for k, v in res.items() if k not in radical_keys]
    return '\t'.join(shown)

In [None]:
# Space-separated Hebrew verb forms used as a quick qualitative check of the
# pretrained model.
s = 'השתזף שמרתי ירעדו נאכל הרבינו כשהתעצבנתם השגנו תרגלתי עופו פיהקתם צפינו הצפינו שרנו להתווכח תוכיחי קומו'

# Reload the checkpoint from disk; this overwrites the `model` returned by
# the training cell.
# SECURITY: torch.load unpickles arbitrary objects - only load checkpoint
# files produced by this notebook.
model = torch.load(f"models/pretrain.pt")
# Print each word next to its decoded prediction, one word per predict() call.
for k in s.split():
    print(k, predict(model, k))
print("חבל", predict(model, "חבל"))