# 🥙 LSTM on Recipe Data

This notebook is an **unofficial PyTorch implementation** of the [Keras example](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition/tree/main/notebooks/05_autoregressive/01_lstm) of an autoregressive LSTM model, originally created by David Foster as part of the companion code for the excellent book [Generative Deep Learning, 2nd Edition](https://www.oreilly.com/library/view/generative-deep-learning/9781098134174/).

_The original code is available [here](https://github.com/davidADSP/Generative_Deep_Learning_2nd_Edition) and is licensed under the Apache License 2.0._
_This implementation is distributed under the Apache License 2.0. See the LICENSE file for details._

In this notebook, we'll walk through the steps required to train your own LSTM on the recipes dataset using PyTorch.

In [None]:
%load_ext autoreload
%autoreload 2

import os

# Resolve the experiment directory relative to the kernel's current working directory.
# NOTE(review): this presumably expects the kernel to be started from the repository root — confirm.
working_dir = os.getcwd()
exp_dir = os.path.join(working_dir, "notebooks/05_autoregressive/01_lstm/")

In [None]:
import json
import re
import string

import torch
from tokenizers import Tokenizer, models, pre_tokenizers, trainers
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from torch import optim

import numpy as np

## 0. Parameters <a name="parameters"></a>

In [None]:
# Tokenizer / sequence settings
VOCAB_SIZE = 10000        # upper bound on the vocabulary size given to the tokenizer trainer
MAX_LEN = 200             # sequences are padded/truncated to MAX_LEN + 1 tokens
# Model settings
EMBEDDING_DIM = 100       # size of each token embedding
N_UNITS = 128             # hidden size of the LSTM
# NOTE(review): VALIDATION_SPLIT is not referenced by any other cell visible here.
VALIDATION_SPLIT = 0.2
SEED = 42
LOAD_MODEL = False        # when True, a saved checkpoint is loaded before training
# Training settings
BATCH_SIZE = 32
EPOCHS = 50
LEARNING_RATE = 0.001

# Apply the seed so that weight initialisation, DataLoader shuffling and
# multinomial sampling are reproducible across "Restart & Run All".
torch.manual_seed(SEED)

## 1. Prepare the data <a name="prepare"></a>

In [None]:
# Locations of the raw recipe dataset, rooted at the working directory
data_dir = f"{working_dir}/data"
dataset_dir = f"{data_dir}/epirecipes"
data_file = f"{dataset_dir}/full_format_recipes.json"

In [None]:
# Read the recipe dump. The encoding is given explicitly so the result does not
# depend on the platform's default text encoding.
with open(data_file, encoding="utf-8") as data_json:
    data_raw = json.load(data_json)

# Peek at the available fields and at the first record
print(data_raw[0].keys())
print(data_raw[0])

In [None]:
# Keep only records that have both a title and directions, and flatten each
# one into a single "Recipe for <title> | <directions...>" string.
filtered_data = [
    f"Recipe for {recipe['title']} | " + " ".join(recipe['directions'])
    for recipe in data_raw
    if recipe.get("title") is not None and recipe.get("directions") is not None
]

In [None]:
# Sanity check: number of recipes kept by the filter, and the first formatted example
print(len(filtered_data))
print(filtered_data[0])

## 2. Tokenise the data

In [None]:
def pad_punctuation(str):
    """Surround every punctuation character with spaces and collapse space runs.

    Args:
        str: The text to process. (The parameter name shadows the built-in
            ``str``; it is kept unchanged for backward compatibility.)

    Returns:
        The text with one space on each side of every character in
        ``string.punctuation`` and with runs of spaces reduced to one space.
    """
    text = str
    # Escape the punctuation before placing it in a character class: unescaped,
    # the sequence `\]` is read as an escaped bracket and the backslash itself
    # is never matched.
    text = re.sub(f"([{re.escape(string.punctuation)}])", r" \1 ", text)
    # Replace multiple spaces with one space
    text = re.sub(" +", " ", text)

    return text

In [None]:
# Quick check of pad_punctuation on a small example; test_text keeps the padded result
test_text = pad_punctuation("Hello   there!")
print(test_text)


In [None]:
# Apply the punctuation padding to every formatted recipe
train_data_list = [pad_punctuation(recipe) for recipe in filtered_data]

In [None]:
# Set the tokenizer parallelism flag explicitly to avoid getting a warning
os.environ["TOKENIZERS_PARALLELISM"] = "true"
# We use the Hugging Face Tokenizers package to tokenize the dataset and create the vocabulary.
# A simple word-level tokenizer is used;
# the tokenizer itself handles assigning a numerical value to each word
tokenizer = Tokenizer(models.WordLevel(unk_token="<unk>"))
# The pre-tokenizer pre-processes the text and splits it into words (based on whitespace)
tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()

# Show how the padded example sentence is split before any vocabulary exists
pre_tokenized_text = tokenizer.pre_tokenizer.pre_tokenize_str(test_text)
print(pre_tokenized_text)

In [None]:
# To build the vocabulary with the tokenizer we use a trainer
trainer = trainers.WordLevelTrainer(special_tokens=["<pad>", "<unk>"], vocab_size=VOCAB_SIZE)
tokenizer.train_from_iterator(train_data_list, trainer)

# Word -> id mapping, and the id assigned to the padding token
vocab = tokenizer.get_vocab()
pad_idx = vocab["<pad>"]

# Enable truncation and padding so that all encoded entries have the same length.
# MAX_LEN + 1 tokens are kept so each entry can later be split into an input and a
# target shifted by one token.
tokenizer.enable_padding(length=MAX_LEN + 1, pad_id=pad_idx, pad_token="<pad>")
tokenizer.enable_truncation(max_length=MAX_LEN + 1)

In [None]:
# Check the resulting vocabulary
print("Vocabulary size:", tokenizer.get_vocab_size())
print("padiing index = ", pad_idx)
# Reverse lookup (token id -> word), used later to decode generated ids
vocab_idnx_to_word = {index: word for word, index in vocab.items()}
# Show only the ten lowest ids instead of dumping the whole mapping (twice) into the output
print("Vocabulary:", sorted(vocab_idnx_to_word.items())[:10])
# Encode the example sentence to see ids, tokens and the padding that is applied
test_vector = tokenizer.encode(test_text)
print(test_vector.ids)
print(test_vector.tokens)

In [None]:
# Encode every recipe into its list of token ids
vectorized_data = [tokenizer.encode(recipe_text).ids for recipe_text in train_data_list]
# Number of encoded recipes, then the length of the first encoded recipe
print(len(vectorized_data))
print(len(vectorized_data[0]))

In [None]:
class TextSeqDataset(Dataset):
    """Dataset of next-token prediction pairs built from encoded sequences.

    Each item is a pair ``(x, y)`` of tensors where ``x`` is the sequence
    without its last token and ``y`` is the sequence without its first token.

    Args:
        vectorized_data_list: Indexable collection of token-id sequences.
    """

    def __init__(self, vectorized_data_list):
        super().__init__()
        self.vectorized_data_list = vectorized_data_list

    def __len__(self):
        """Number of sequences in the dataset."""
        return len(self.vectorized_data_list)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair for sequence ``idx``."""
        token_ids = self.vectorized_data_list[idx]
        inputs = torch.tensor(token_ids[:-1])
        targets = torch.tensor(token_ids[1:])
        return inputs, targets

    def get_data_pair(self, idx):
        """Named alias for indexing: same result as ``dataset[idx]``."""
        return self[idx]

## 3. Create the Training Set

In [None]:
# Wrap the encoded recipes in a Dataset and inspect the first (input, target) pair
train_dataset = TextSeqDataset(vectorized_data_list=vectorized_data)
x, y = train_dataset.get_data_pair(0)
print(x.shape)
print(y.shape)
# The target is the input shifted by one token
print(x[0:5])
print(y[0:5])

In [None]:
# Batch and shuffle the training pairs, loading them with two worker processes
train_data_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

In [None]:
## 4. Build the LSTM <a name="build"></a>

In [None]:
class Lstm(nn.Module):
    """Word-level language model: embedding -> LSTM -> linear projection to the vocabulary.

    The model returns raw logits (no softmax) for every time step.

    Args:
        vocab_size: Number of tokens in the vocabulary (embedding rows and output size).
        embedded_dim: Size of each token embedding.
        lstm_units: Hidden size of the LSTM.
        pad_idx: Token id used for padding; passed to the embedding as ``padding_idx``.
        is_pidirectional: When True, the LSTM is bidirectional and the projection
            input size is doubled.
            NOTE(review): a bidirectional LSTM lets each position see later
            tokens, which presumably defeats next-token training — confirm
            before enabling.
        log_dir: Directory handed to the TensorBoard ``SummaryWriter``.
    """

    def __init__(self, vocab_size, embedded_dim=100, lstm_units=128,
                 pad_idx=0, is_pidirectional=False, log_dir="./log"):
        super().__init__()
        self.embedded_dim = embedded_dim
        self.lstm_units = lstm_units
        self.vocab_size = vocab_size
        self.is_pidirectional = is_pidirectional
        # A bidirectional LSTM concatenates both directions, doubling the feature size
        self.lstm_unit_multipler = 2 if self.is_pidirectional else 1
        self.pad_idx = pad_idx
        self.writer = SummaryWriter(log_dir=log_dir)

        self.embedded = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embedded_dim,
                                     padding_idx=self.pad_idx)
        self.lstm = nn.LSTM(input_size=self.embedded_dim, hidden_size=self.lstm_units,
                            batch_first=True, bidirectional=self.is_pidirectional)

        self.fc = nn.Linear(in_features=self.lstm_units * self.lstm_unit_multipler, out_features=self.vocab_size)

    def forward(self, x):
        """Map token ids of shape (batch, seq_len) to logits of shape (batch, seq_len, vocab_size)."""
        x = self.embedded(x)
        # The LSTM output has shape (batch_size, seq_length, lstm_units * directions) and
        # holds the hidden state of every time step. The final hidden and cell states
        # are not needed here.
        output, _ = self.lstm(x)
        # Cross-entropy loss applies softmax internally, so raw logits are returned
        return self.fc(output)

    def fit(self, train_dataloader, loss_fn, optimizer, epochs, device, callbacks=None):
        """Train the model and report the average per-batch loss of every epoch.

        Args:
            train_dataloader: Iterable yielding ``(inputs, targets)`` batches of token ids.
            loss_fn: Callable taking logits of shape (batch, vocab_size, seq_len) and
                targets of shape (batch, seq_len).
            optimizer: Optimizer over this model's parameters.
            epochs: Number of passes over ``train_dataloader``.
            device: Device the batches are moved to.
            callbacks: Optional objects whose ``on_epoch_end(epoch, logs=...)`` is
                called after each epoch with the keys ``model``, ``device``,
                ``model_state_dict`` and ``loss``.
        """
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.device = device

        for epoch in range(epochs):
            # Re-enter training mode every epoch: a callback may have switched to eval mode
            self.train()

            acc_loss = 0.0
            num_batches = 0

            for train_data, train_gt in train_dataloader:
                train_data = train_data.to(device)
                train_gt = train_gt.to(device)

                optimizer.zero_grad()

                # Call the module (not .forward) so registered hooks run
                pred = self(train_data)
                # Move the class dimension to position 1: (batch, vocab_size, seq_len)
                pred = pred.permute(0, 2, 1)

                loss = loss_fn(pred, train_gt)
                loss.backward()
                optimizer.step()

                acc_loss += loss.item()
                num_batches += 1

            # Each batch contributes one loss value, so average over the number of
            # batches seen (not over a dataset defined outside this method).
            acc_loss /= max(num_batches, 1)

            print(f"epoch {epoch + 1} / {epochs}: loss = {acc_loss}")

            self.writer.add_scalar("training_loss", acc_loss, global_step=epoch)

            # Run callback functions
            if callbacks is not None:
                logs = {"model": self,
                        "device": self.device,
                        "model_state_dict": self.state_dict(),
                        "loss": acc_loss
                }

                for callback in callbacks:
                    callback.on_epoch_end(epoch, logs=logs)

                
    

In [None]:
# Directory later passed to the model as log_dir (created if missing)
log_dir =  exp_dir + "/log"
os.makedirs(log_dir, exist_ok=True)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

lstm_model = Lstm(vocab_size=tokenizer.get_vocab_size(),
                  embedded_dim=EMBEDDING_DIM,
                  lstm_units=N_UNITS, pad_idx=pad_idx,
                  is_pidirectional=False, log_dir=log_dir).to(device)

# Print the module tree. Printing `lstm_model.state_dict` without calling it only
# showed the bound-method repr.
print(lstm_model)

In [None]:
# Pull a single (input, target) batch from the training loader
loader_itr = iter(train_data_loader)
sample_input, sample_output = next(loader_itr)

In [None]:
# Layer-by-layer summary for an input of shape (1, 4), using the dtype of the sampled token-id batch
summary(lstm_model, input_size=(1, 4), dtypes=[sample_input.dtype])

In [None]:
class Callback:
    """Base class for training hooks; subclasses override the hooks they need."""

    def on_epoch_end(self, epoch, logs=None):
        """Hook called with the epoch index and an optional logs dict; does nothing here."""
        return None

In [None]:
class SaveCheckpoint(Callback):
    """Callback that writes a checkpoint file at a fixed epoch interval.

    Args:
        save_dir: Directory the checkpoint files are written into.
        save_every: A checkpoint is written whenever ``epoch % save_every == 0``.
    """

    def __init__(self, save_dir, save_every=10):
        super().__init__()
        self.save_every = save_every
        self.save_dir = save_dir

    def on_epoch_end(self, epoch, logs=None):
        """Save epoch, model state dict and loss from ``logs`` when the epoch is on the interval."""
        # Nothing to do for epochs that are not on the saving interval
        if (epoch % self.save_every) != 0:
            return

        payload = {
            "epoch": epoch,
            "model_state_dict": logs["model_state_dict"],
            "loss": logs["loss"],
        }
        target_path = self.save_dir + f"/checkpoint_{epoch}.pth"
        torch.save(payload, target_path)

In [None]:
class TextGenerator(Callback):
    """Callback that samples text from the model, one token at a time.

    Args:
        index_to_word: Token id -> word lookup. Either a mapping ``{id: word}``
            or a sequence where the position is the id.
        top_k: Stored on the instance; not used by the sampling code.
    """

    def __init__(self, index_to_word, top_k=10):
        self.index_to_word = index_to_word
        self.top_k = top_k
        # Build the inverse lookup. For a mapping, iterate over its (id, word)
        # items: enumerate() over a dict would yield its integer keys instead
        # of words, and every prompt word would then fall back to <unk>.
        if hasattr(index_to_word, "items"):
            id_word_pairs = index_to_word.items()
        else:
            id_word_pairs = enumerate(index_to_word)
        self.word_to_index = {word: index for index, word in id_word_pairs}
        # Special-token ids, falling back to the previously hard-coded values
        self.unk_index = self.word_to_index.get("<unk>", 1)
        self.pad_index = self.word_to_index.get("<pad>", 0)

    def sample_from(self, probs, temperature):
        """Rescale ``probs`` by ``temperature`` and draw one token id.

        Returns:
            Tuple ``(sample_token, probs)`` with the sampled id and the rescaled,
            renormalised distribution.
        """
        probs = probs ** (1 / temperature)
        probs = probs / torch.sum(probs)
        sample_token = torch.multinomial(probs, 1).item()
        return sample_token, probs

    def generate(self, model, start_prompt, max_tokens, temperature, device):
        """Extend ``start_prompt`` until ``max_tokens`` tokens or a padding token is reached.

        The prompt is split on whitespace and words missing from the vocabulary
        map to the unknown token. The generated text is printed.

        Args:
            model: Module returning logits of shape (batch, seq_len, vocab_size).
            start_prompt: Whitespace-separated prompt text.
            max_tokens: Maximum total number of tokens (prompt included).
            temperature: Sampling temperature passed to ``sample_from``.
            device: Device the input ids are created on.

        Returns:
            A list with one dict per generated token, holding the ``prompt`` so far
            and the ``word_probs`` distribution it was sampled from.

        Raises:
            ValueError: If ``start_prompt`` contains no tokens.
        """
        start_tokens = [
            self.word_to_index.get(word, self.unk_index) for word in start_prompt.split()
        ]
        if not start_tokens:
            raise ValueError("start_prompt must contain at least one token")

        sample_token = None
        info = []
        # Sample in eval mode and restore the previous mode afterwards
        was_training = model.training
        model.eval()
        try:
            with torch.no_grad():
                while len(start_tokens) < max_tokens and sample_token != self.pad_index:
                    x = torch.tensor([start_tokens], device=device)
                    # Only the logits of the last position are needed for the next token
                    last_logits = model(x)[0, -1].to("cpu")
                    # The model outputs logits, so apply softmax here to get probabilities
                    y_prob = torch.softmax(last_logits, dim=-1)
                    sample_token, probs = self.sample_from(y_prob, temperature)
                    info.append({"prompt": start_prompt, "word_probs": probs})
                    start_tokens.append(sample_token)
                    start_prompt = start_prompt + " " + self.index_to_word[sample_token]
        finally:
            model.train(was_training)
        print(f"\ngenerated text:\n{start_prompt}\n")
        return info

    def on_epoch_end(self, epoch, logs=None):
        """Print a sample generated from the model found in ``logs`` (no-op without logs)."""
        if logs:
            model = logs["model"]
            device = logs["device"]
            # NOTE(review): the training text starts with capitalised "Recipe for";
            # confirm these lowercase prompt words exist in the vocabulary.
            self.generate(model, "recipe for", max_tokens=100, temperature=1.0, device=device)

## 5. Train the LSTM <a name="train"></a>

In [None]:
# Output directories for this experiment (created if missing).
# NOTE(review): sample_dir is not referenced by any other cell visible here.
sample_dir =  exp_dir + "/sample_gen"
os.makedirs(sample_dir, exist_ok=True)

checkpoint_dir =  exp_dir + "/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

In [None]:
# Per-epoch hooks: write a checkpoint on every second epoch and print a generated sample after each epoch
callbacks = [SaveCheckpoint(save_dir=checkpoint_dir, save_every=2),
             TextGenerator(index_to_word=vocab_idnx_to_word)]

In [None]:
# Check if we have a checkpoint to load
if LOAD_MODEL:
    checkpoint_file = checkpoint_dir + "/checkpoint_10.pth"
    # map_location lets a checkpoint saved on one device (e.g. CUDA) be loaded on
    # whichever device this session uses.
    checkpoint = torch.load(checkpoint_file, map_location=device)
    lstm_model.load_state_dict(checkpoint["model_state_dict"])

In [None]:
# Adam over all model parameters. The loss takes the raw logits produced by the model.
optimizer = optim.Adam(params=lstm_model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Train for EPOCHS epochs; the callbacks are invoked at the end of every epoch
lstm_model.fit(train_data_loader, loss_fn=loss_fn, optimizer=optimizer, 
               epochs=EPOCHS, device=device, callbacks=callbacks)

## 6. Generate text using the LSTM

In [None]:
# Standalone generator (separate from the instance in the training callbacks) for the sampling experiments
text_generator = TextGenerator(index_to_word=vocab_idnx_to_word)

In [None]:
def print_probs(info, vocab, top_k=5):
    """Print the ``top_k`` most probable words for every recorded generation step.

    Args:
        info: List of dicts with a ``prompt`` string and a 1-D ``word_probs`` tensor.
        vocab: Lookup from token id to word.
        top_k: Number of candidates shown per step.
    """
    for step in info:
        print(f"\nPROMPT: {step['prompt']}")
        # Rank the whole distribution, then keep the leading top_k entries
        ranked_probs, ranked_ids = torch.sort(step["word_probs"], descending=True)
        top_probs = ranked_probs[:top_k].numpy()
        top_ids = ranked_ids[:top_k].numpy()
        for rank in range(len(top_probs)):
            round_prob = np.round(100 * top_probs[rank], 2)
            print(f"{vocab[top_ids[rank]]}:   \t{round_prob}%")
        print("--------\n")

In [None]:
# Continue a recipe prompt up to 10 tokens in total at temperature 1.0
info = text_generator.generate(lstm_model,
    "recipe for roasted vegetables | chop 1 /", max_tokens=10, temperature=1.0, device=device
)

In [None]:
# Take the distribution recorded for the first generated token and sort it by descending probability
word_probs = info[0]["word_probs"]
top_k = 5
p_sorted, i_sorted = torch.sort(word_probs, descending=True)

In [None]:
# Compare the size of the full sorted distribution with the top_k slice
print(p_sorted.shape)
print(p_sorted[:top_k].shape)

In [None]:
# Show the most probable candidate words for every generation step
print_probs(info, vocab_idnx_to_word)

In [None]:
# Same prompt at temperature 0.2: probabilities are raised to the power 1 / temperature before sampling
info = text_generator.generate(lstm_model,
    "recipe for roasted vegetables | chop 1 /", max_tokens=10, temperature=0.2, device=device
)
print_probs(info, vocab_idnx_to_word)

In [None]:
# A second prompt, limited to 7 tokens in total, at temperature 1.0
info = text_generator.generate(lstm_model,
    "recipe for chocolate ice cream |", max_tokens=7, temperature=1.0, device=device
)
print_probs(info, vocab_idnx_to_word)

In [None]:
# The second prompt again, at temperature 0.2
info = text_generator.generate(lstm_model,
    "recipe for chocolate ice cream |", max_tokens=7, temperature=0.2, device=device
)
print_probs(info, vocab_idnx_to_word)