In [218]:
from data_rnn import load_ndfa, load_brackets
# from data_prep import pad_and_convert
import pandas as pd
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

In [219]:
x_train_ndfa, (i2w_ndfa, w2i_ndfa) = load_ndfa(n=150_000)
x_train_brackets, (i2w_brackets, w2i_brackets) = load_brackets(n=150_000)


In [220]:
print(''.join([i2w_ndfa[i] for i in x_train_ndfa[50]]))


ss


In [221]:
w2i_ndfa

{'.pad': 0,
 '.start': 1,
 '.end': 2,
 '.unk': 3,
 'l': 4,
 '!': 5,
 'u': 6,
 'b': 7,
 'c': 8,
 's': 9,
 'm': 10,
 'k': 11,
 'a': 12,
 'v': 13,
 'w': 14}

In [222]:
# print(''.join([i2w_brackets[i] for i in x_train_brackets[10_000]]))


In [223]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [224]:
class LSTM(nn.Module):
    def __init__(self, vocab_size, emb_size, h, num_char, n_layers=1):
        super(LSTM, self).__init__()

        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.lstm = nn.LSTM(input_size=emb_size, hidden_size=h, num_layers=n_layers, batch_first=True)
        self.fc = nn.Linear(h, num_char)
    
    def forward(self, input_seq, h=None):
        embedded = self.embedding(input_seq)
        lstm_out, hidden = self.lstm(embedded, h)
        # lstm_out = lstm_out[:, -1, :]
        output = self.fc(lstm_out)
        # raise Exception('stop')
        return output, hidden

In [225]:
def pad_and_convert3(batch, w2i, batch_size):
    start_token = w2i['.start']
    end_token = w2i['.end']
    pad_token = w2i['.pad']

    num_examples = len(batch)
    num_batches = (num_examples + batch_size - 1) // batch_size
    
    # Create batches
    batches = [batch[i * batch_size: (i + 1) * batch_size] for i in range(num_batches)]
    
    padded_batches = []
    
    for b in batches:
        # Pad each sequence in the batch to the maximum length within the batch
        b = [[start_token] + x + [end_token] for x in b]
        max_len = max(len(x) for x in b)
        padded_batch = [x + [pad_token] * (max_len - len(x)) for x in b]
        padded_batches.append(padded_batch)

    # Convert the padded batches to PyTorch tensors
    padded_batches = [torch.tensor(pb, dtype=torch.long) for pb in padded_batches]

    return padded_batches

In [226]:
x_train_ndfa_padded3 = pad_and_convert3(x_train_ndfa, w2i_ndfa, batch_size=64)
x_train_brackets_padded3 = pad_and_convert3(x_train_brackets, w2i_brackets, batch_size=64)

In [227]:
vocab_size = len(w2i_ndfa)
emb_size = 32
h = 16
num_char = vocab_size
n_layers = 1

In [228]:
model = LSTM(vocab_size=vocab_size, emb_size=emb_size, h=h, num_char=num_char, n_layers=1)

In [243]:
num_epochs = 3
learning_rate = 0.001

In [244]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [245]:
def create_target2(padded_batches):
    target_batches = []

    for batch in padded_batches:
        target_batch = torch.cat((batch[:, 1:], torch.zeros(batch.size(0), 1).int()), dim=1)
        target_batches.append(target_batch)
    return target_batches

In [246]:
target_ndfa2 = create_target2(x_train_ndfa_padded3)

In [247]:
import torch.distributions as dist
def sample(lnprobs, temperature=1.0): 
    """
    Sample an element from a categorical distribution
    :param lnprobs: Outcome logits
    :param temperature: Sampling temperature. 1.0 follows the given
        distribution, 0.0 returns the maximum probability element. :return: The index of the sampled element.
    """
    if temperature == 0.0:
        return lnprobs.argmax()
    p = F.softmax(lnprobs / temperature, dim=0)
    cd = dist.Categorical(p)
    return cd.sample()

In [248]:
trainloader_ndfa = list(zip(x_train_ndfa_padded3, target_ndfa2))

In [249]:

for epoch in range(num_epochs):
    total_loss = 0.0

    for batch_idx, (inputs, targets) in enumerate(trainloader_ndfa):

        model.train()        
        optimizer.zero_grad()

        h = None
        
        output, _ = model(inputs, h)
        output = output.reshape(-1, vocab_size)
        targets = targets.reshape(-1)

        loss = criterion(output, targets)  

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(trainloader_ndfa)
    print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_loss:.4f}')

torch.save(model.state_dict(), 'lstm_model_ndfa.pth')

Epoch [1/3], Average Loss: 0.1706
Epoch [2/3], Average Loss: 0.1191
Epoch [3/3], Average Loss: 0.1125


In [184]:
model = LSTM(vocab_size=len(set(w2i_brackets)), emb_size=300, h=300, num_char=len(set(w2i_brackets)), n_layers=1)
max_length = 50
for epoch in range(num_epochs):
    total_loss = 0.0

    for batch_idx, (inputs, targets) in enumerate(dataloader_brackets2):
        print(f'Batch Index: {batch_idx}, Batch Size: {inputs.size(0)}')

        model.train()        
        optimizer.zero_grad()

        h = None
        
        output, _ = model(inputs, h)

        print('output shape', output.shape)
        output = output.reshape(-1, vocab_size)
        print('output shape', output.shape)
        targets = targets.reshape(-1)


        loss = criterion(output, targets)  

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        print(f'Epoch [{epoch+1}/{num_epochs}], Iteration [{batch_idx+1}/{len(dataloader_brackets2)}], Loss: {total_loss:.4f}')

        model.eval()
        seed_seq = [w2i_brackets['.start'], w2i_brackets['('], w2i_brackets['('], w2i_brackets[')']]
        seed_input = torch.tensor([seed_seq], dtype=torch.long)
        with torch.no_grad():
            for t in range(max_length - 1):
                output, _ = model(seed_input, h)
                print('output', output[0, -1, :])
                next_token = sample(output[0, -1, :])
                print('next token', next_token)
                seed_seq.append(next_token.item())

                if next_token == w2i_brackets['.end']:
                    break

                seed_input = torch.tensor([[next_token]], dtype=torch.long)
                print('seed seq', seed_seq)
            
            generated_sequence = [''.join(i2w_brackets[i] for i in seed_seq)]
            print(f'Generated Sequence after epoch {epoch+1}: {generated_sequence}')

        average_loss = total_loss / len(dataloader_brackets2.dataset)
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {average_loss:.4f}')

torch.save(model.state_dict(), 'lstm_model.pth')

Batch Index: 0, Batch Size: 10
output shape torch.Size([10, 1024, 6])
output shape torch.Size([10240, 6])
Epoch [1/1], Iteration [1/15000], Loss: 1.8493
output tensor([-0.0378,  0.0149, -0.0281,  0.0786,  0.1398, -0.0557])
next token tensor(0)
seed seq [1, 5, 5, 4, 0]
output tensor([-0.0309,  0.0400, -0.0891,  0.0500,  0.0425,  0.0024])
next token tensor(1)
seed seq [1, 5, 5, 4, 0, 1]
output tensor([-0.0243,  0.1024,  0.0615,  0.0875,  0.1341, -0.1349])
next token tensor(3)
seed seq [1, 5, 5, 4, 0, 1, 3]
output tensor([ 0.1759,  0.1224, -0.0430,  0.0645, -0.0967,  0.0110])
next token tensor(1)
seed seq [1, 5, 5, 4, 0, 1, 3, 1]
output tensor([-0.0243,  0.1024,  0.0615,  0.0875,  0.1341, -0.1349])
next token tensor(1)
seed seq [1, 5, 5, 4, 0, 1, 3, 1, 1]
output tensor([-0.0243,  0.1024,  0.0615,  0.0875,  0.1341, -0.1349])
next token tensor(0)
seed seq [1, 5, 5, 4, 0, 1, 3, 1, 1, 0]
output tensor([-0.0309,  0.0400, -0.0891,  0.0500,  0.0425,  0.0024])
next token tensor(5)
seed seq [1, 5,

KeyboardInterrupt: 