In [1]:
%load_ext autoreload
%autoreload 2

In [14]:
import numpy as np
import torch
import torch.nn as nn

import pytorch_lightning as pl

import verbs
import neural

In [20]:
class VerbDataModule(pl.LightningDataModule):
    """Lightning data module over a single verb TSV file, split into train/validation.

    The whole file is loaded eagerly in ``__init__`` through ``verbs.load_dataset``
    and wrapped in a ``TensorDataset``. Its first tensor holds the encoded inputs,
    and the remaining tensors hold one int64 tensor per element of the label
    sequence returned by the loader. ``setup`` holds out ``validation_size``
    examples for validation and keeps the rest for training.

    Args:
        train_filename: path passed to ``verbs.load_dataset``.
        batch_size: batch size used by every dataloader.
        validation_size: number of examples held out for validation.
        word_maxlen: ``word_maxlen`` passed to ``verbs.load_dataset`` (default 11,
            the value previously hard-coded).
        seed: seed for the global torch/numpy RNGs and for the train/val split.
    """

    def __init__(self, train_filename, batch_size=128, validation_size=10000, word_maxlen=11, seed=0):
        super().__init__()
        self.validation_size = validation_size
        self.batch_size = batch_size
        self.seed = seed

        torch.manual_seed(seed)
        np.random.seed(seed)

        train_data_x, train_data_y = verbs.load_dataset(train_filename, word_maxlen=word_maxlen)

        self.data_train_full = torch.utils.data.TensorDataset(
            torch.tensor(train_data_x, dtype=torch.int64),
            *[torch.tensor(y, dtype=torch.int64) for y in train_data_y],
        )

    def prepare_data(self):
        # Nothing to download; the data is loaded eagerly in __init__.
        # No state assignment here (Lightning may call this on a single process only).
        pass

    def setup(self, stage=None):
        """Create ``data_train`` / ``data_val`` for the fit stage.

        Raises:
            ValueError: if ``validation_size`` is negative or exceeds the dataset size.
            NotImplementedError: for the ``'test'`` stage, which has no data yet.
        """
        if stage == 'fit' or stage is None:
            total = len(self.data_train_full)
            if not 0 <= self.validation_size <= total:
                raise ValueError(
                    f"validation_size={self.validation_size} must be between 0 and the dataset size ({total})")
            # Lengths are listed in the same order as the unpacked results:
            # the training part gets everything except the held-out validation examples.
            # A dedicated generator keeps the split reproducible regardless of
            # how much global RNG state was consumed since __init__.
            generator = torch.Generator().manual_seed(self.seed)
            self.data_train, self.data_val = torch.utils.data.random_split(
                self.data_train_full,
                [total - self.validation_size, self.validation_size],
                generator=generator,
            )
            self.dims = tuple(self.data_train[0][0].shape)

        if stage == 'test':
            raise NotImplementedError("VerbDataModule has no test split")

    def train_dataloader(self):
        # Reshuffle every epoch; random_split alone fixes one permutation for all epochs.
        return torch.utils.data.DataLoader(self.data_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return torch.utils.data.DataLoader(self.data_val, batch_size=self.batch_size)

    def test_dataloader(self):
        # No test dataset is ever built (see setup); fail with a clear message
        # instead of an AttributeError on the missing `data_test`.
        raise NotImplementedError("VerbDataModule has no test split")

In [26]:
# Default size of the input embedding table (number of distinct input token ids).
NUM_EMBEDDING = 2000

class IndependentModel(neural.UdModel):
    """Multi-task classifier with one independent linear head per label.

    Inputs are embedded, encoded by ``neural.SumBiLSTM`` into one vector per
    word, and each entry of ``label_map`` gets its own ``nn.Linear`` head
    over that shared vector.

    Args:
        label_map: mapping from label name to its number of classes.
        units: embedding dimension and the size passed to ``neural.SumBiLSTM``.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay; ``None`` means no weight decay.
        num_embeddings: number of rows in the embedding table.
    """

    def __init__(self, label_map, units=400, learning_rate=1e-3, weight_decay=None,
                 num_embeddings=NUM_EMBEDDING):
        super().__init__()
        self.save_hyperparameters()

        self.units = units

        self.embed = nn.Embedding(num_embeddings=num_embeddings, embedding_dim=units)

        self.lstm = neural.SumBiLSTM(units)

        # One classification head per label, all reading the same encoding.
        self.tasks = nn.ModuleDict([(label_name, nn.Linear(in_features=units, out_features=class_size))
                                    for label_name, class_size in label_map.items()])

    def forward(self, x):
        """Return a dict mapping each label name to its logits for the batch ``x``."""
        # x: (BATCH_SIZE, WORD_MAXLEN)

        x = x.permute([1, 0])
        # x: (WORD_MAXLEN, BATCH_SIZE)

        embeds = self.embed(x)
        # embeds: (WORD_MAXLEN, BATCH_SIZE, UNITS)

        _, char_hidden = self.lstm(embeds)
        # char_hidden: (BATCH_SIZE, UNITS)

        return {name: linear(char_hidden)
                for name, linear in self.tasks.items()}

    def configure_optimizers(self):
        # torch.optim.Adam rejects weight_decay=None (it compares it against 0.0),
        # so the "no decay" default must be mapped to 0.0 explicitly.
        weight_decay = self.hparams.weight_decay
        if weight_decay is None:
            weight_decay = 0.0
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=weight_decay)
        return optimizer
    

In [7]:
# Number of classes for every label the model predicts, in prediction-head order.
label_map = dict(
    (label_name, verbs.Verb.class_size(label_name))
    for label_name in ('Binyan', 'Tense', 'Voice', 'Gender', 'Plural', 'R1', 'R2', 'R3', 'R4')
)

In [28]:
# Load the synthetic verb data and split it right away so the splits can be inspected.
dataset = VerbDataModule(train_filename='synthetic/all_pref_combined_shufroot.tsv', batch_size=32)
dataset.setup()

In [29]:
# Train the independent-heads model for 5 epochs on one GPU, logging to Weights & Biases.
model = IndependentModel(
    label_map,
    units=400,
    learning_rate=1e-3,
    weight_decay=7e-4,
)
wandb_logger = pl.loggers.WandbLogger(
    project='rootem',
    group='verbs-lightning',
    name='shared_32',  # NOTE(review): "32" presumably mirrors batch_size above — keep in sync
)
trainer = pl.Trainer(gpus=1, max_epochs=5, logger=wandb_logger)
trainer.fit(model, dataset)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]



  | Name  | Type            | Params
------------------------------------------
0 | embed | Embedding       | 800 K 
1 | lstm  | SumSharedBiLSTM | 2 M   
2 | tasks | ModuleDict      | 54 K  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




1

In [None]:

@torch.no_grad()
def predict(model, *words):
    """Analyse ``words[0]`` with ``model`` and format the result as one tab-separated line.

    Each output head's logits are arg-maxed; a head keyed by a tuple of label
    names is unravelled into one category per label via
    ``encoding.combined_shape``. If ``R1``..``R4`` are all predicted, they are
    concatenated (dots removed) into a single ``R`` field. The individual
    ``R1``..``R4`` fields are then omitted from the output. Each field is
    right-aligned to width 6.

    The model is put in eval mode for the forward pass. Its previous
    train/eval mode is restored afterwards.

    Args:
        model: module whose forward returns a dict of per-head logits.
        *words: words to encode; only the first one is analysed.

    Returns:
        Tab-separated string of predicted categories.
    """
    # NOTE(review): `encoding` is not imported anywhere in this notebook; it must be
    # imported (presumably the project's `encoding` module) before this cell runs.
    was_training = model.training
    model.eval()
    try:
        # The words are replicated 128 times to form a batch; only row 0 is read below.
        # NOTE(review): unclear why a full batch is required — confirm before shrinking it.
        batch = encoding.wordlist2numpy(words * 128)
        # Put the input on the model's own device (the former `to_device` helper is undefined here).
        device = next(model.parameters()).device
        batch = torch.from_numpy(batch).to(device=device, dtype=torch.int64)
        outputs = {name: logits[0] for name, logits in model(batch).items()}
    finally:
        model.train(was_training)

    res = {}
    for combination, logits in outputs.items():
        if isinstance(combination, str):
            combination = (combination,)
        shape = encoding.combined_shape(combination)
        combined_index = logits.argmax().item()
        indices = np.unravel_index(combined_index, shape)
        for label, index in zip(combination, indices):
            # A label predicted by two heads keeps both values; the second key gets a trailing "'".
            # (Only one such overlap per label is distinguished.)
            key = label + "'" if label in res else label
            res[key] = encoding.from_category(label, index)

    roots = ['R1', 'R2', 'R3', 'R4']
    if all(r in res for r in roots):
        res['R'] = ''.join(res[r] for r in roots).replace('.', '')
    return '\t'.join(f'{value:>6}' for key, value in res.items() if key not in roots)

In [None]:
# Sanity-check a pretrained checkpoint on a handful of hand-picked Hebrew verbs.
s = 'השתזף שמרתי ירעדו נאכל הרבינו כשהתעצבנתם השגנו תרגלתי עופו פיהקתם צפינו הצפינו שרנו להתווכח תוכיחי קומו'

# NOTE(review): torch.load unpickles the whole module object (arbitrary code
# execution on untrusted files) — only load checkpoints you produced yourself.
# This also rebinds `model`, replacing the model trained above.
model = torch.load("models/pretrain.pt")
for k in s.split():
    print(k, predict(model, k))

# One more probe, printed separately after the list above.
print("חבל", predict(model, "חבל"))