In [1]:
from collections import Counter


import numpy as np
import json
import re
import string

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


# --- Data / tokenization ---
VOCAB_SIZE = 10_000       # max vocabulary size, including the "" (pad) and "[UNK]" entries
MAX_LEN = 200             # tokens per example before the one-step input/target shift
VALIDATION_SPLIT = 0.2    # fraction of data intended for validation (not applied in the cells shown)

# --- Model ---
EMBEDDING_DIM = 100       # width of the token embedding
N_UNITS = 128             # LSTM hidden size

# --- Training ---
SEED = 42                 # random seed
LOAD_MODEL = False        # NOTE(review): presumably toggles loading saved weights; not referenced in the cells shown
BATCH_SIZE = 32           # examples per DataLoader batch
EPOCHS = 25               # passes over the training data

In [2]:
# Load the full dataset. JSON text is UTF-8 by specification, so decode it
# explicitly instead of relying on the platform's default locale encoding.
with open("./full_format_recipes.json", encoding="utf-8") as json_data:
    recipe_data = json.load(json_data)

# Keep only recipes that have a non-null title and non-null directions, and
# flatten each into one "Recipe for <title> | <directions joined by spaces>" string.
filtered_data = [
    "Recipe for " + x["title"] + " | " + " ".join(x["directions"])
    for x in recipe_data
    if x.get("title") is not None and x.get("directions") is not None
]


In [9]:
# Pad the punctuation, to treat each mark as a separate 'word'.
def pad_punctuation(s):
    """Surround every ASCII punctuation character in ``s`` with spaces.

    Runs of spaces created by the padding are then collapsed to a single
    space. Other whitespace (tabs, newlines) is left as-is, and a leading or
    trailing space may remain.

    Args:
        s: input text.

    Returns:
        The padded text.
    """
    # re.escape is required: string.punctuation contains ']', '\\', '^' and
    # '-', which are special inside a character class. Interpolated raw, the
    # '\\]' pair is read as an escaped ']', so the backslash itself was
    # silently never matched.
    s = re.sub(f"([{re.escape(string.punctuation)}])", r" \1 ", s)
    s = re.sub(" +", " ", s)
    return s


def build_vocab(text_data, vocab_size=VOCAB_SIZE):
    """Build a frequency-ranked vocabulary from whitespace-tokenized lines.

    The first two entries are reserved: ``""`` (used as the padding id 0)
    and ``"[UNK]"`` (id 1). They are followed by up to ``vocab_size - 2`` of
    the most frequent words, most frequent first.

    Args:
        text_data: iterable of text lines; each line is split on whitespace.
        vocab_size: total vocabulary size including the two reserved entries.

    Returns:
        A list of words, of length at most ``vocab_size`` when
        ``vocab_size >= 2``.
    """
    counts = Counter()
    for line in text_data:
        counts.update(line.split())
    reserved = ["", "[UNK]"]
    frequent = [word for word, _count in counts.most_common(vocab_size - 2)]
    return reserved + frequent


class TextDataset(Dataset):
    """Next-token-prediction dataset over a list of raw text strings.

    Each string is punctuation-padded and lower-cased. A vocabulary is built
    from the whole corpus, and ``__getitem__`` returns an ``(input, target)``
    pair of token-id tensors, where the target is the input shifted by one
    position. Id 0 is padding (the ``""`` vocab entry) and id 1 is ``[UNK]``.
    """

    def __init__(self, text_data, max_len=MAX_LEN, vocab_size=VOCAB_SIZE):
        """
        Args:
            text_data: iterable of raw text strings.
            max_len: tokens per example after truncation/padding. The returned
                input and target each have length ``max_len - 1``.
            vocab_size: maximum vocabulary size passed to ``build_vocab``.
        """
        self.text_data = [pad_punctuation(x).lower() for x in text_data]
        self.max_len = max_len
        self.vocab_size = vocab_size
        self.vocab = build_vocab(self.text_data, vocab_size=vocab_size)
        # word -> id map built once. The previous list.index() lookup was an
        # O(vocab) scan for every token of every example.
        self.word_index = {word: idx for idx, word in enumerate(self.vocab)}

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, idx):
        """Return ``(x, y)`` long tensors of length ``max_len - 1``; ``y`` is ``x`` shifted left by one."""
        text = self.text_data[idx]
        # split() with no argument matches build_vocab's tokenization.
        # split(" ") would emit "" tokens (id 0) for the leading/trailing
        # spaces pad_punctuation can leave, and would keep newline/tab-joined
        # words together (mapping them to [UNK]).
        token_ids = [
            self.word_index.get(word, 1)  # 1 = [UNK]
            for word in text.split()[: self.max_len]
        ]
        # Explicit dtype: an empty token list would otherwise become a float tensor.
        token = torch.tensor(token_ids, dtype=torch.long)
        token = F.pad(token, (0, self.max_len - token.shape[0]), value=0)  # 0 = [PAD]

        x = token[:-1]  # Input sequence
        y = token[1:]  # Target sequence

        return x, y

In [10]:
class LSTM(nn.Module):
    """Single-layer LSTM language model: embedding -> LSTM -> per-step linear layer to vocabulary logits."""

    def __init__(self, vocab_size, embedding_dim, n_units):
        super().__init__()
        # Keep this creation order (embedding, lstm, dense): it fixes both the
        # parameter registration order and the RNG draws used for initialization.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, n_units, batch_first=True)
        self.dense = nn.Linear(n_units, vocab_size)

    def forward(self, x):
        """Map token ids (typically shaped (batch, seq_len)) to logits of shape (batch, seq_len, vocab_size)."""
        hidden_seq, _final_state = self.lstm(self.embedding(x))
        return self.dense(hidden_seq)


In [None]:
# Seed torch's RNG so weight initialization and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

lstm = LSTM(VOCAB_SIZE, EMBEDDING_DIM, N_UNITS)

text_dataset = TextDataset(filtered_data, max_len=MAX_LEN, vocab_size=VOCAB_SIZE)
text_dataloader = DataLoader(text_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Sequences are right-padded with id 0. Ignore those target positions so the
# loss and gradients reflect real next-token predictions, not padding.
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
optimizer = torch.optim.Adam(lstm.parameters())

# Training loop: log the mean batch loss once per epoch instead of every batch.
lstm.train()
for epoch in range(EPOCHS):
    epoch_loss = 0.0
    n_batches = 0
    for x, y in text_dataloader:
        optimizer.zero_grad()
        out = lstm(x)  # (batch_size, seq_len, vocab_size)
        # Flatten batch and time so every position is one classification.
        loss = loss_fn(out.reshape(-1, out.shape[-1]), y.reshape(-1))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_batches += 1

    print(f"Epoch: {epoch}, Loss: {epoch_loss / max(n_batches, 1):.4f}")

In [107]:
# Inspect one batch explicitly instead of relying on `y` leaking out of the
# training loop (it is undefined on a fresh kernel if the loop never ran).
x_batch, y_batch = next(iter(text_dataloader))
y_batch.shape

torch.Size([32, 199])