In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
import wandb
import torch
import torch.nn as nn

import utils
from naive_model import NaiveModel
import encoding

In [4]:
NUM_EMBEDDING = 2000

assert torch.cuda.is_available()
torch.backends.cudnn.deterministic = True
torch.set_default_tensor_type('torch.cuda.FloatTensor')

def to_device(d):
    if hasattr(d, 'cuda'):
        return d.cuda()
    return {k: v.cuda() for k, v in d.items()}

class Model(nn.Module):
    arch = 'lstm'
    def __init__(self, units, combinations=encoding.NAMES):
        super().__init__()
        self.units = units

        self.embed = nn.Embedding(num_embeddings=NUM_EMBEDDING, embedding_dim=units)
        
        self.pre_lstm = nn.LSTM(input_size=units, hidden_size=units, num_layers=1, batch_first=False, bidirectional=True)
        
        self.post_lstm = nn.LSTM(input_size=units, hidden_size=units, num_layers=1, batch_first=False, bidirectional=True)

        self.tasks = {}
        for combination in combinations:
            out = nn.Linear(in_features=units, out_features=encoding.class_size(combination))
            self.tasks[combination] = out
            setattr(self, encoding.class_name(combination), out)

    def isroot(self, combination):
        return any(r in combination for r in ['R1', 'R2', 'R3', 'R4'])
        
    def forward(self, x):
        # x: (BATCH_SIZE, WORD_MAXLEN)
        
        x = x.permute([1, 0])
        # x: (WORD_MAXLEN, BATCH_SIZE)
        
        embeds = self.embed(x)
        # embeds: (WORD_MAXLEN, BATCH_SIZE, UNITS)
        
        lstm_out, (h_n0, c_n) = self.pre_lstm(embeds)
        # lstm_out: (WORD_MAXLEN, BATCH_SIZE, UNITS * 2)
        # h_n0: (2, BATCH_SIZE, UNITS)
        # c_n: (2, BATCH_SIZE, UNITS)
        
        left, right = torch.chunk(lstm_out, 2, dim=-1)
        # left: (WORD_MAXLEN, BATCH_SIZE, UNITS)
        # right: (WORD_MAXLEN, BATCH_SIZE, UNITS)
        
        lstm_out = torch.squeeze(left + right)
        # lstm_out: (WORD_MAXLEN, BATCH_SIZE, UNITS)

        lstm_out, (h_n1, c_n) = self.post_lstm(lstm_out)
        # lstm_out: (WORD_MAXLEN, BATCH_SIZE, UNITS * 2)
        # h_n1: (2, BATCH_SIZE, UNITS)
        # c_n: (2, BATCH_SIZE, UNITS)
        
        h_n = h_n0 + h_n1
        # h_n: (2, BATCH_SIZE, UNITS)

        return {combination: f(h_n[0] + h_n[1])
                for combination, f in self.tasks.items()}


In [5]:
def sanity():
    model = to_device(Model(100, combinations=[('R1', 'R2', 'R3', 'R4')]))
    print(model)
    with torch.no_grad():
        verbs = encoding.wordlist2numpy(["אתאקלם", "יכפיל", "בואס"])
        labels = {'R1': torch.Tensor([3, 5]), 'R2': torch.Tensor([2, 4]), 'R3': torch.Tensor([2, 4]), 'R4': torch.Tensor([2, 4])}
        verbs = to_device(torch.from_numpy(verbs).to(torch.int64))
        tag_scores = model(verbs)
        print(f"{tag_scores[('R1', 'R2', 'R3', 'R4')].shape=}")
        for combination in tag_scores:
            print(combination)
            v = tag_scores[combination]
            # print(f'{v=}')
            print(f'{labels=}')
            print()
sanity()

Model(
  (embed): Embedding(2000, 100)
  (pre_lstm): LSTM(100, 100, bidirectional=True)
  (post_lstm): LSTM(100, 100, bidirectional=True)
  (R1xR2xR3xR4): Linear(in_features=100, out_features=531441, bias=True)
)
tag_scores[('R1', 'R2', 'R3', 'R4')].shape=torch.Size([3, 531441])
('R1', 'R2', 'R3', 'R4')
labels={'R1': tensor([3., 5.]), 'R2': tensor([2., 4.]), 'R3': tensor([2., 4.]), 'R4': tensor([2., 4.])}



In [6]:

# TEMP_PATH = 'model.pt'

#                 best_lr = 8e-4
#                 best_loss = 10
                
#                 torch.save({
#                     'state_dict': model.state_dict(),
#                     'optimizer': optimizer.state_dict(),
#                 }, TEMP_PATH)
                
#                 for i in range(1):
#                     checkpoint = torch.load(TEMP_PATH)
#                     model.load_state_dict(checkpoint['state_dict'])
#                     optimizer.load_state_dict(checkpoint['optimizer'])

def fit(model, train, test, *, epochs,  runsize, criterion, optimizer, batch_size, **_):
    train_x, train_y = train
    test_x, test_y = test
    
    assert_reasonable_initial = utils.Once(utils.assert_reasonable_initial)
    
    for epoch in range(epochs):
        train_stats = utils.Stats(model.tasks.keys())
        
        nbatches = len(train_x)
        for batch, (inputs, labels) in enumerate(zip(train_x, train_y), 1):
            model.train()

            inputs = to_device(inputs)
            labels = to_device(labels)

            outputs = model(inputs)

            losses = {combination: criterion(output.double(), labels[combination])
                      for combination, output in outputs.items()}

            loss = sum(losses.values())
            
            assert_reasonable_initial(losses, nn.CrossEntropyLoss)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
#             scheduler.step()
            
            train_stats.update(loss=loss.item(),
                               batch_size=inputs.size(0),
                               outputs=outputs,
                               labels=labels)

            if batch % runsize == 0 or batch == nbatches:
                model.eval()

                test_stats = utils.Stats(model.tasks.keys())
                for inputs, labels in zip(test_x, test_y):
                    inputs = to_device(inputs)
                    labels = to_device(labels)

                    with torch.no_grad():
                        outputs = model(inputs)

                    losses = {combination: criterion(output.double(), labels[combination])
                              for combination, output in outputs.items()}

                    loss = sum(losses.values())

                    test_stats.update(loss=loss.item(),
                                      batch_size=inputs.size(0),
                                      outputs=outputs,
                                      labels=labels)
                    
                utils.log(train_stats, test_stats, batch, nbatches, epoch)


In [36]:
test_size = 300

def load_dataset(corpus_name, artifact_name):
    torch.manual_seed(0)
    np.random.seed(0)

    filename = f'{corpus_name}/{artifact_name}.tsv'  # all_verbs_shuffled

    artifact = wandb.Artifact(artifact_name, type='dataset')
    artifact.add_file(filename)

    (train_x, pre_train_y), (test_x, pre_test_y) = encoding.load_dataset_split(filename, split=test_size)

    utils.shuffle_in_unison([train_x, *pre_train_y.values()])
    return (train_x, pre_train_y), (test_x, pre_test_y), artifact


corpus_name = 'ud'
arity = 'combined'
gen = 'train'
artifact_name = f'nocontext-{gen}'
ud_corpus = load_dataset(corpus_name, artifact_name)

corpus_name = 'synthetic'
arity = 'combined'
gen = 'all_pref'
artifact_name = f'{gen}_{arity}_shufroot'
synthetic_corpus = load_dataset(corpus_name, artifact_name)

In [37]:
%env WANDB_SILENT true

def experiment(corpus, config, combinations=encoding.NAMES, names_str=''):
    print(config)
    
    torch.manual_seed(1)
    np.random.seed(1)
    
    (train_x, pre_train_y), (test_x, pre_test_y), artifact = corpus
    
    train_y = utils.ravel_multi_index(pre_train_y, combinations)
    test_y = utils.ravel_multi_index(pre_test_y, combinations)
    
    train = utils.batch_xy((train_x, train_y), config['batch_size'])
    test = utils.batch_xy((test_x, test_y), config['batch_size'])
    
    if corpus is synthetic_corpus:
        model = to_device(Model(units=config['units'], combinations=combinations))  # NaiveModel.learn_from_file(filename)
    else:
        model = torch.load(f"models/pretrain.pt")
    
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], weight_decay=config['weight_decay'])
    config.update({
        'runsize': 2 * 8192 // config['batch_size'],
        'optimizer': optimizer,
        'criterion': nn.CrossEntropyLoss(),
        'model': model,
    })
    
#     names_str = '+'.join(encoding.class_name(combination) for combination in combinations if combination not in encoding.NONROOTS)
#     if len(combinations) <= 3:
#         names_str += '_only'
    run = wandb.init(project="rootem",
                     group=f'ud',  # f'lr_units_grid_search-{arity}-{wandb.util.generate_id()}',
                     name=f"pretrained-batch_{config['batch_size']}",  # {model.arch}-{config['units']}-{config['lr']:.0e}-{config['batch_size']} f'{gen}-{arity}-{lr:.0e}',# f'{arity}-batch_{BATCH_SIZE}', # f'all-{arity}-lr_{lr:.0e}-units_{units}',
                     tags=[gen, arity, "ud", 'shuffle-root', 'shuffle', 'batchval', 'full-root'],
                     config=config)
    with run:
        run.use_artifact(artifact)

        wandb.config.update(config, allow_val_change=True)

#         if isinstance(model, nn.Module):
#             wandb.watch(model)

        fit(train=train,
            test=test,
            **config
        )
        wandb.save(f"{model.arch}.h5")
        
        if corpus is synthetic_corpus:
            torch.save(model, f"models/pretrain.pt")
        else:
            torch.save(model, f"models/postrain.pt")

    return model

%env WANDB_MODE dryrun

config = {
    'epochs': 1,
    'test_size': test_size,
    'batch_size': 128,
    'units': 350,
    'weight_decay': 7e-4,
    'dropout': 0.2,
    'num_layers': 1,
    'lr': 1e-3,
}
model = experiment(synthetic_corpus, config)
model = experiment(ud_corpus, config)

env: WANDB_SILENT=true
env: WANDB_MODE=dryrun
{'epochs': 1, 'test_size': 300, 'batch_size': 128, 'units': 350, 'weight_decay': 0.0007, 'dropout': 0.2, 'num_layers': 1, 'lr': 0.001}


{'epochs': 1, 'test_size': 300, 'batch_size': 128, 'units': 350, 'weight_decay': 0.0007, 'dropout': 0.2, 'num_layers': 1, 'lr': 0.001, 'runsize': 128, 'optimizer': Adam (1: 0.938 train/Accuracy_R2: 0.761 train/Accuracy_R3: 0.985 train/Accuracy_R4: 0.961 train/Accuracy_R1xR2xR3xR4: 0.707 val/Loss: 2.4984 val/Accuracy_B: 0.629 val/Accuracy_T: 0.887 val/Accuracy_V: 0.691 val/Accuracy_G: 0.699 val/Accuracy_P: 0.957 val/Accuracy_R1: 0.996 val/Accuracy_R2: 0.996 val/Accuracy_R3: 1.000 val/Accuracy_R4: 0.938 val/Accuracy_R1xR2xR3xR4: 0.930 
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0.0007
), 'criterion': CrossEntropyLoss(), 'model': Model(
  (embed): Embedding(2000, 350)
  (pre_lstm): LSTM(350, 350, bidirectional=True)
  (post_lstm): LSTM(350, 350, bidirectional=True)
  (B): Linear(in_features=350, out_features=7, bias=True)
  (T): Linear(in_features=350, out_features=4, bias=True)
  (V): Linear(in_features=350, out_features=5,

 0    90/   90 train/Loss: 1.8064 train/Accuracy_B: 0.924 train/Accuracy_T: 0.941 train/Accuracy_V: 0.899 train/Accuracy_G: 0.924 train/Accuracy_P: 0.973 train/Accuracy_R1: 0.945 train/Accuracy_R2: 0.935 train/Accuracy_R3: 0.991 train/Accuracy_R4: 0.966 train/Accuracy_R1xR2xR3xR4: 0.879 val/Loss: 0.7256 val/Accuracy_B: 0.973 val/Accuracy_T: 0.980 val/Accuracy_V: 0.969 val/Accuracy_G: 0.969 val/Accuracy_P: 0.984 val/Accuracy_R1: 0.973 val/Accuracy_R2: 0.977 val/Accuracy_R3: 1.000 val/Accuracy_R4: 0.984 val/Accuracy_R1xR2xR3xR4: 0.941 

In [11]:

@torch.no_grad()
def predict(model, *verbs):
    model.eval()
    verbs = encoding.wordlist2numpy(verbs * 128)
    verbs = to_device(torch.from_numpy(verbs).to(torch.int64))
    outputs = {k: v[0] for k, v in model(verbs).items()}
    res = {}
    # FIX: assumes no overlaps
    for combination, v in outputs.items():
        if isinstance(combination, str):
            combination = tuple([combination])
        shape = encoding.combined_shape(combination)
        combined_index = v.argmax().cpu().data.numpy()
        indices = np.unravel_index(combined_index, shape)
        for k, i in zip(combination, indices):
            # assert k not in res, "Overlapping classes are not handled"
            s = k
            if k in res:
                s += "'"
            res[s] = encoding.from_category(k, i)
    if all(r in res for r in ['R1', 'R2', 'R3', 'R4']):
        res['R'] = ''.join(res[k] for k in ['R1', 'R2', 'R3', 'R4']).replace('.', '')
    return '\t'.join(f'{v:>6}' for k, v in res.items() if k not in ['R1', 'R2', 'R3', 'R4'])

In [38]:
s = 'השתזף שמרתי ירעדו נאכל הרבינו כשהתעצבנתם השגנו תרגלתי עופו פיהקתם צפינו הצפינו שרנו להתווכח תוכיחי קומו'

model = torch.load(f"models/pretrain.pt")
for k in s.split():
    print(k, predict(model, k))
print("חבל", predict(model, "חבל"))

השתזף  התפעל	 ציווי	 שלישי	   זכר	  יחיד	   שזפ
שמרתי    פעל	   עבר	 ראשון	   זכר	  יחיד	   שמר
ירעדו    פעל	  עתיד	 שלישי	   זכר	  רבים	   רעד
נאכל   נפעל	  הווה	 ראשון	   זכר	  יחיד	   אכל
הרבינו  הפעיל	   עבר	 ראשון	   זכר	  רבים	   רבי
כשהתעצבנתם  התפעל	   עבר	   שני	   זכר	  רבים	  עצבנ
השגנו  הפעיל	   עבר	 ראשון	   זכר	  רבים	   שגג
תרגלתי   פיעל	   עבר	 ראשון	  נקבה	  יחיד	  תגגל
עופו   פועל	   עבר	 שלישי	   זכר	  רבים	   עפפ
פיהקתם   פיעל	   עבר	   שני	   זכר	  רבים	   פהק
צפינו    פעל	   עבר	 ראשון	  נקבה	  רבים	   צפי
הצפינו  הפעיל	   עבר	 ראשון	   זכר	  רבים	   צפי
שרנו    פעל	   עבר	 ראשון	   זכר	  רבים	   שיר
להתווכח  התפעל	  הווה	     _	     _	     _	   וכח
תוכיחי  הפעיל	  עתיד	   שני	  נקבה	  יחיד	   יכח
קומו    פעל	   עבר	 שלישי	   זכר	  רבים	   קומ
חבל    פעל	   עבר	 שלישי	   זכר	  יחיד	   חבל
