In [1]:
from data_rnn import load_brackets
from torch import nn
import torch.optim as optim
import torch
import random
from tqdm import tqdm

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
x_train, (i2w, w2i) = load_brackets(n=150_000, seed=0)

In [3]:
len(x_train)

150000

In [4]:
x_train[94000]

[5, 5, 4, 4]

In [5]:
print(''.join([i2w[i] for i in x_train[97000]]))

(())


In [6]:
i2w

['.pad', '.start', '.end', '.unk', ')', '(']

In [7]:
len(i2w)

6

In [8]:
max_tokens = 20000
batches = []

start_seq_len = len(x_train[-1]) + 2
batch_x = torch.zeros(size=(max_tokens//start_seq_len, start_seq_len), dtype=torch.long)
num_sequences = 0
for seq in reversed(x_train):
    if num_sequences == batch_x.shape[0]:
        batch_y = torch.zeros_like(batch_x, dtype=torch.long)
        batch_y[:, :-1] = batch_x[:, 1:]
        batches.append((batch_x, batch_y))
        seq_len = len(seq) + 2
        batch_x = torch.zeros(size=(max_tokens//seq_len, seq_len), dtype=torch.long)
        num_sequences = 0
    seq = [1] + seq + [2]
    batch_x[num_sequences, :len(seq)] = torch.tensor(seq)
    num_sequences += 1
batches.append((batch_x, batch_y))

print(batches[0][0].shape)
print(batches[0][1].shape)

torch.Size([19, 1024])
torch.Size([19, 1024])


In [9]:
random.shuffle(batches)
print(batches[0][0].shape)
batches[0][1].shape

torch.Size([5000, 4])


torch.Size([5000, 4])

## Model

In [10]:
class AutoregressModel(nn.Module):
    def __init__(self, embedding_dim, hidden_size, dropout, vocab_size, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout)
        self.out = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, input):
        embeddings = self.embedding(input)
        hidden, _ = self.lstm(embeddings)
        out = self.out(hidden)
        return out



In [11]:
import torch.distributions as dist
import torch.nn.functional as F
def sample(lnprobs, temperature=1.0):
    """
    Sample an element from a categorical distribution
    :param lnprobs: Outcome logits
    :param temperature: Sampling temperature. 1.0 follows the given
    distribution, 0.0 returns the maximum probability element.
    :return: The index of the sampled element.
    """
    if temperature == 0.0:
        return lnprobs.argmax()
    p = F.softmax(lnprobs / temperature, dim=0)
    cd = dist.Categorical(p)
    return cd.sample()

In [12]:
combos = [
    ['.start','(','(','(','(','(','('],
    ['.start','(','(','(','(','(',')'],
    ['.start','(','(','(','(',')',')'],
    ['.start','(','(','(',')',')','('],
    ['.start','(','(',')','(',')','('],
    ['.start','(','(','(',')',')','('],
    ['.start','(','(',')','(','(','('],
    ['.start','(','(',')','(',')','('],
    ['.start','(','(',')','(','(',')'],
    ['.start','(','(','(','(',')','('],
    ]

In [13]:
writer = SummaryWriter('runs/brackets_long')

n_epochs = 60
lr = 0.01
max_lr = 0.05
weight_decay = 0.0001

hidden_size = 100
embedding_dim = 50
dropout=0
vocab_size = len(i2w)
num_layers = 1
max_generation_length = 160

print(f"Batches per epoch: {len(batches)}")

autoregressmodel = AutoregressModel(
    embedding_dim=embedding_dim,
    hidden_size=hidden_size,
    dropout=dropout,
    vocab_size=vocab_size,
    num_layers=num_layers)

writer.add_graph(autoregressmodel, batches[0][0])

criterion = nn.CrossEntropyLoss(ignore_index=0) # TODO reduction="sum"
optimizer = optim.Adam(autoregressmodel.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr, steps_per_epoch=len(batches), epochs=n_epochs)

for epoch in range(n_epochs):
    print(f"Epoch {epoch+1}/{n_epochs}")
    for i, data in enumerate(batches):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = autoregressmodel(inputs)
        outputs = torch.moveaxis(outputs, 1, 2)
        loss = criterion(outputs, labels)
        writer.add_scalar("Training loss", loss, (i+1)+len(batches)*epoch)
        loss.backward()
        optimizer.step()
        scheduler.step()
        
    for combo in combos:
        seq = torch.tensor([[w2i[char] for char in combo]], dtype=torch.long)
        with torch.no_grad():
            pred = autoregressmodel(seq)
            next_char = sample(lnprobs=pred[0][-1], temperature=0.5)
            while seq.shape[1] < max_generation_length-1:
                if int(next_char) == 2:
                    break
                seq = torch.cat((seq, torch.tensor([[next_char]], dtype=torch.long)), dim=1)
                pred = autoregressmodel(seq)
                next_char = sample(lnprobs=pred[0][-1], temperature=0.01)
        seq = torch.cat((seq, torch.tensor([[next_char]], dtype=torch.long)), dim=1)
        print(''.join([i2w[i] for i in seq[0]]), end=" ")
        # get if this is correct
        counts = torch.bincount(seq[0])
        if int(seq[0][-1]) != 4:
            print("Longer than max_generation_length --> correctness can not be evaluated")
        elif counts[4] == counts[5]:
            print("correct :)")
        else: print("incorrect :(")

            
        
print("\nDONE\n")

Batches per epoch: 86
Epoch 1/60
.start((((((())))))())()())(()))(())())(())(()))(())).end Longer than max_generation_length --> correctness can not be evaluated
.start((((()))))()).end Longer than max_generation_length --> correctness can not be evaluated
.start(((()))).end Longer than max_generation_length --> correctness can not be evaluated
.start((())(()))()).end Longer than max_generation_length --> correctness can not be evaluated
.start(()()(()))()).end Longer than max_generation_length --> correctness can not be evaluated
.start((())())()())(())).end Longer than max_generation_length --> correctness can not be evaluated
.start(()((())))()).end Longer than max_generation_length --> correctness can not be evaluated
.start(()()())()).end Longer than max_generation_length --> correctness can not be evaluated
.start(()(()()))()())(())).end Longer than max_generation_length --> correctness can not be evaluated
.start(((()(()))))()).end Longer than max_generation_length --> correctne

In [16]:
max_generation_length = 300
for combo in combos:
    seq = torch.tensor([[w2i[char] for char in combo]], dtype=torch.long)
    with torch.no_grad():
        pred = autoregressmodel(seq)
        next_char = sample(lnprobs=pred[0][-1], temperature=0.5)
        while seq.shape[1] < max_generation_length:
            if int(next_char) == 2:
                break
            seq = torch.cat((seq, torch.tensor([[next_char]], dtype=torch.long)), dim=1)
            pred = autoregressmodel(seq)
            next_char = sample(lnprobs=pred[0][-1], temperature=0.01)
    seq = torch.cat((seq, torch.tensor([[next_char]], dtype=torch.long)), dim=1)

    # get if this is correct
    counts = torch.bincount(seq[0])
    print(seq.shape[1])
    print(''.join([i2w[i] for i in seq[0]]), end=" ")
    if int(seq[0][-1]) == 4:
        print("Longer than max_generation_length --> correctness can not be evaluated")
    elif counts[4] == counts[5]:
        print("correct :)")
    else: print("incorrect :(")

14
.start(((((()))))).end correct :)
12
.start((((())))).end correct :)
12
.start(((())())).end correct :)
12
.start((())(())).end correct :)
12
.start(()()(())).end correct :)
10
.start((())()).end correct :)
14
.start(()(((())))).end correct :)
10
.start(()()()).end correct :)
10
.start(()(())).end correct :)
14
.start(((()(())))).end correct :)
