### Text Generation Using LSTM


In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import io
import re
from tqdm.notebook import trange, tqdm

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torch.nn.functional as F
from torch.distributions import Categorical

from torchtext.datasets import WikiText2, EnWik9, AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import torchtext.transforms as T
from torch.hub import load_state_dict_from_url
from torchtext.data.functional import sentencepiece_tokenizer, load_sp_model

In [8]:
# Training hyperparameters:
#   learning_rate - step size used for each parameter update
#   nepochs       - number of passes over the training data
#   batch_size    - number of samples grouped into one mini-batch
#   max_len       - maximum sequence length
learning_rate, nepochs, batch_size, max_len = 1e-4, 20, 32, 64

In [5]:
# AGNews dataset class definition
class AGNews(Dataset):
    """Map-style dataset over one AG News CSV split, yielding lowercased article text.

    The file ``<data_dir>/<test_train>.csv`` is read with explicit column names
    (Class, Title, Content), so its first line is treated as data, not a header.
    Title and Content are merged into a single ``Article`` column of the form
    ``"<Title> : <Content>"``, and escape residue is replaced with spaces.

    Args:
        test_train: Split name used as the CSV file stem, e.g. "train" or "test".
        data_dir: Directory containing the split CSV files.
    """

    # Sequences/characters replaced by a single space in the article text.
    # Regex alternation is tried left to right, so the longer literal escape
    # sequences (backslash-r-backslash-n, backslash-n, backslash-r) must come
    # before the lone-backslash alternative; otherwise they are unreachable and
    # a literal "\r" would leave a stray "r" behind.
    _CLEAN_PATTERN = r'\\r\\n|\\n|\\r|\\|\n|"'

    def __init__(self, test_train="train", data_dir="../data"):
        csv_path = os.path.join(data_dir, f"{test_train}.csv")

        # Read Title/Content as text so concatenation below never hits a
        # numeric column; missing cells stay NaN until the fillna below.
        df = pd.read_csv(csv_path,
                         names=["Class", "Title", "Content"],
                         dtype={"Title": str, "Content": str})

        # Missing values become empty strings so the concatenation yields text, not NaN.
        df = df.fillna('')

        # Combine Title and Content into a single Article column, then drop the sources.
        df['Article'] = df['Title'] + " : " + df['Content']
        df = df.drop(columns=['Title', 'Content'])

        # Replace escape residue (literal \r\n, \n, \r, backslashes), real newlines
        # and double quotes with whitespace.
        df['Article'] = df['Article'].str.replace(self._CLEAN_PATTERN, ' ', regex=True)

        self.df = df

    def __getitem__(self, index):
        """Return the lowercased article text stored at row label ``index``."""
        return self.df.loc[index, "Article"].lower()

    def __len__(self):
        """Return the number of articles (rows) in the split."""
        return len(self.df)

In [6]:
# Build the AG News splits: the "train" split is constructed first, then "test".
dataset_train, dataset_test = (AGNews(test_train=split) for split in ("train", "test"))

In [9]:
# Wrap each split in a DataLoader that shuffles and loads with 8 worker processes.
# The training loader uses drop_last=True so every training batch has exactly
# batch_size samples; the test loader keeps its final partial batch.
data_loader_train = DataLoader(
    dataset_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    drop_last=True,
)
data_loader_test = DataLoader(
    dataset_test,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)