In [None]:
import sys
import torch
from tqdm import tqdm

sys.path.append('..')
from utils import *
from datasets import *
from rnn import LanguageModelSRNN

In [None]:
# Hyperparameters shared by the data-loading, model and training cells below.
NUM_EPOCHS = 1  # number of full passes over the training dataloader
SEQ_LEN = 10    # sequence length handed to the dataloader, the model and generation
BATCH_SIZE = 4  # batch size handed to the dataloader

### Load the TFDS TinyShakespeare Train Set

In [None]:
# Load the TinyShakespeare 'train' split and merge its strings into one
# newline-separated text.
train_set = '\n'.join(load_tf_dataset('tiny_shakespeare', split_name='train'))
# Keep only the first eighth (12.5%) of the characters.
train_set = train_set[:len(train_set) // 8]
train_dataloader, vocab_size = create_char_level_dataloader(
    train_set,
    batch_size=BATCH_SIZE,
    seq_len=SEQ_LEN,
)
# Separate dataset object over the same text; later cells read its `vocab` mapping.
dataset = CharLevelDataset(train_set)
# Show the vocabulary entry for the newline character next to the vocabulary size.
dataset.vocab['\n'], vocab_size

### Messing around with dataloader and character embeddings

In [None]:
# Pull a single batch from the dataloader and expand it with the RNN-training helper.
first_batch = next(iter(train_dataloader))
data, labels = first_batch
contexts, embedded_labels = expand_sequence_for_rnn_training(
    data, labels, vocab_size, SEQ_LEN
)
# Compare the size of the first context with the size of the labels at index 0 of dim 1.
# Labels are indexed along dim 1 (presumably the seq_len axis, with batch on dim 0)
# because each context is batch x one timestep.
contexts[0].size(), embedded_labels[:, 0].size()

In [None]:
# Build the simple-RNN language model over the character vocabulary.
model = LanguageModelSRNN(
    vocab_size=vocab_size,
    seq_len=SEQ_LEN,
    hidden_dim=100,
)

In [None]:
# Sample from the model before any training, starting from the newline character.
start_token = torch.tensor([dataset.vocab['\n']])
gen_seq = model.generate(start_token, SEQ_LEN, 20)
# Decode by position in the vocab's key order; each generated id is shifted down by
# one before the lookup. The key list is built once rather than once per id.
vocab_chars = list(dataset.vocab.keys())
"".join(vocab_chars[i - 1] for i in gen_seq)

In [None]:
# Report how many learnable parameters the model has.
param_count = num_learnable_params(model)
print(f"LM variant of SimpleRNN has: {param_count:,} parameters")


In [None]:
# Train the language model: for every batch, take one optimizer step per timestep of
# the expanded sequence and show the mean per-timestep loss on the progress bar.
# NOTE(review): lr=1e-1 is far above Adam's default of 1e-3 -- confirm it is intentional.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)
loss = torch.nn.CrossEntropyLoss()
for epoch in range(NUM_EPOCHS):
    progress = tqdm(train_dataloader)
    for data, labels in progress:
        contexts, embedded_labels = expand_sequence_for_rnn_training(data, labels, vocab_size, SEQ_LEN)
        avg_loss = 0
        for i in range(SEQ_LEN):
            context = contexts[i]
            label = embedded_labels[:, i]
            # Gradients accumulate across backward() calls, so they must be cleared
            # before every step. Clearing them only once per batch would make each
            # later step use this timestep's gradient plus all earlier ones.
            optimizer.zero_grad()
            # Keep only the prediction at the last position along dim 1.
            # squeeze(1) is a no-op unless that remaining dim has size 1.
            pred = model(context)[:, -1, :].squeeze(1)
            loss_val = loss(pred, label)
            loss_val.backward()
            # Cap the global gradient norm at 1 before applying the update.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()
            avg_loss += loss_val.item()

        # Mean loss over the SEQ_LEN timesteps of this batch.
        avg_loss /= SEQ_LEN
        progress.set_postfix(loss=avg_loss)
    

In [None]:
# Sample again from the newline character, this time after the training cell above.
start_token = torch.tensor([dataset.vocab['\n']])
gen_seq = model.generate(start_token, SEQ_LEN, 20)

In [None]:
# Decode the generated ids into text by position in the vocab's key order; each id
# is shifted down by one before the lookup. The key list is built once up front.
vocab_chars = list(dataset.vocab.keys())
"".join(vocab_chars[i - 1] for i in gen_seq)