# Download Dataset

In [1]:
! gdown https://drive.google.com/uc?id=1dPCpNIUxqhe2GccLF0tmAapnwgA5Olr2

Downloading...
From: https://drive.google.com/uc?id=1dPCpNIUxqhe2GccLF0tmAapnwgA5Olr2
To: /content/quotes_dataset.csv
165MB [00:00, 180MB/s]


# Import Dependencies

In [2]:
import pandas as pd
import string
import re
import numpy as np

import torch
from torch import nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# Initialize Constants

In [167]:
START_TOKEN = "<str>"
END_TOKEN = "<end>"
PAD_TOKEN = "<pad>"
SPECIAL_TOKEN = "<spc>"
FREQUENCY_THRESOLD = 1

MAX_QUOTE_LEN = 8
MIN_QUOTE_LEN = 6

BATCH_SIZE = 64
EMB_SIZE = 256
NUM_LAYERS = 1
LSTM_SIZE = 256
LEARNING_RATE = 0.001

CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda") if CUDA else torch.device("cpu")

# Preprocess Text

## Read Data

In [168]:
quote_df = pd.read_csv('quotes_dataset.csv')
quotes = list(quote_df.iloc[:, 0])

  interactivity=interactivity, compiler=compiler, result=result)


## Clean & tokenize text

In [169]:
def clean(text):
  text = str(text)
  text = text.lower().strip()

  text = re.sub(r"won\'t", "will not", text)
  text = re.sub(r"can\'t", "can not", text)

  # general
  text = re.sub(r"n\'t", " not", text)
  text = re.sub(r"\'re", " are", text)
  text = re.sub(r"\'s", " is", text)
  text = re.sub(r"\'d", " would", text)
  text = re.sub(r"\'ll", " will", text)
  text = re.sub(r"\'t", " not", text)
  text = re.sub(r"\'ve", " have", text)
  text = re.sub(r"\'m", " am", text)
  
  for p in string.punctuation:
    text = text.replace(p, " ")
  
  text = re.sub(r"\s{2,}", " ", text)
  return START_TOKEN + " " + text.strip() + " " + END_TOKEN

In [170]:
quotes_cleaned = [clean(i).split() for i in quotes if str(i) != 'nan']

In [171]:
processed_quotes = []

for q in quotes_cleaned:
  if len(q) >= MIN_QUOTE_LEN and len(q) <= MAX_QUOTE_LEN:
    processed_quotes.append(q)

In [172]:
print("Total " + str(len(processed_quotes)) + " quotes selected having word count between " + str(MIN_QUOTE_LEN-2) + " and " + str(MAX_QUOTE_LEN-2))

Total 16393 quotes selected having word count between 4 and 6


## Create word frequencies

In [173]:
word_frequency_dict = {}

for q in processed_quotes:
  for word in q:
    if word not in word_frequency_dict:
      word_frequency_dict[word] = 1
    else:
      word_frequency_dict[word] += 1

## Create word-integer mappings

In [174]:
word_to_int = {}
int_to_word = {}

word_to_int[PAD_TOKEN] = 0
int_to_word[0] = PAD_TOKEN
word_to_int[SPECIAL_TOKEN] = 1
int_to_word[1] = SPECIAL_TOKEN

index = 2
for word, freq in word_frequency_dict.items():
  if freq > FREQUENCY_THRESOLD:
    word_to_int[word] = index
    int_to_word[index] = word

    index += 1

In [175]:
vocab = pd.DataFrame()

vocab['Words'] = list(word_to_int.keys())
vocab['ID'] = list(word_to_int.values())

vocab.to_csv('vocab.csv')

In [176]:
print("The size of vocabulary is " + str(len(word_to_int)))

The size of vocabulary is 3138


## Filter & Pad text


In [177]:
def filter_pad(text_tokens):
  for q in range(len(text_tokens)):
    if text_tokens[q] not in word_to_int:
      text_tokens[q] = SPECIAL_TOKEN

  text_tokens = text_tokens + [PAD_TOKEN] * (MAX_QUOTE_LEN - len(text_tokens))
  
  return text_tokens

In [178]:
processed_quotes = [filter_pad(q) for q in processed_quotes]

## Map tokens to integers

In [179]:
def map_word_to_int(text_tokens):
  return [word_to_int[i] for i in text_tokens]

In [180]:
def map_int_to_word(text_tokens):
  return [int_to_word[i] for i in text_tokens]

In [181]:
processed_quotes = [map_word_to_int(q) for q in processed_quotes]

# Dataset & Dataloader

In [182]:
class QuoteDataset(Dataset):
  def __init__(self, quotes=processed_quotes):
    self.quotes = np.array(quotes)

  def __len__(self):
    return len(self.quotes)

  def __getitem__(self, index):
    quote = self.quotes[index]
    data = {
        'x': torch.from_numpy(quote[:-1]), 
        'y': torch.from_numpy(quote[1:])
        }
    return data

In [183]:
dataset = QuoteDataset(quotes=processed_quotes)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [184]:
for batch in dataloader:
  print(batch['x'].shape, batch['y'].shape)
  break

torch.Size([64, 7]) torch.Size([64, 7])


# Model

In [185]:
class Model(nn.Module):
  def __init__(self):
    super(Model, self).__init__()

    self.vocab_size = len(word_to_int)
    self.lstm_size = LSTM_SIZE
    self.embedding_dim = EMB_SIZE
    self.num_layers = NUM_LAYERS

    self.emb = nn.Embedding(
        num_embeddings=self.vocab_size, 
        embedding_dim=self.embedding_dim
        )
    self.lstm = nn.LSTM(
        input_size=self.embedding_dim,
        hidden_size=self.lstm_size,
        num_layers=self.num_layers
        )
    self.fc = nn.Linear(self.lstm_size, self.vocab_size)

  def forward(self, x, prev_state):
    emb = self.emb(x)
    output, state = self.lstm(emb, prev_state)
    y = self.fc(output)
    return y, state

  def init_state(self, seq_length):
    return (
        torch.zeros(self.num_layers, seq_length, self.lstm_size).to(DEVICE),
        torch.zeros(self.num_layers, seq_length, self.lstm_size).to(DEVICE)
    )

# Training & Quote Prediction

In [186]:
def generate(model, max_words=MAX_QUOTE_LEN-2):
  model.eval()
  h, c = model.init_state(seq_length=1)

  x = torch.from_numpy(np.array([word_to_int[START_TOKEN]])).to(DEVICE)
  x = x.unsqueeze(0)

  words = []
  for w in range(max_words):
    y, (h, c) = model(x, (h, c))        
    y = y[0][-1]

    # topk, indices = torch.topk(y, 10)
    # y[[i for i in range(len(y)) if i not in indices]] = 0
    
    p = nn.functional.softmax(y, dim=0).cpu().detach().numpy()
    word_index = np.random.choice(len(y), p=p)

    while word_index == word_to_int[SPECIAL_TOKEN]:
      word_index = np.random.choice(len(y), p=p)        

    if int_to_word[word_index] == END_TOKEN or int_to_word[word_index] == PAD_TOKEN:
      break

    x = torch.from_numpy(np.array([word_index])).to(DEVICE)
    x = x.unsqueeze(0)

    words.append(int_to_word[word_index])

  return words

In [187]:
def train(model, dataloader, epochs):  
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

  for e in range(epochs):
    model.train()
    loss = 0
    for index, batch in enumerate(dataloader):
      x = batch['x'].to(DEVICE)
      y = batch['y'].to(DEVICE)

      optimizer.zero_grad()

      h, c = model.init_state(seq_length=MAX_QUOTE_LEN-1)
      y_pred, (h, c) = model(x, (h, c))

      batch_loss = criterion(y_pred.transpose(1, 2), y)
      batch_loss.backward()
      optimizer.step()

      loss += batch_loss.item() * x.size(0)
      # print({ 'Epoch': e, 'Batch': index, 'Loss': batch_loss.item() })

    avg_loss = loss / len(dataloader.sampler)
    print('Average loss after ' + str(e+1) + ' epoch = ' + str(avg_loss))

    quote = " ".join(generate(model))
    print('Sample quote --> ' + quote)  

    print('')
  
  torch.save(model.state_dict(), "model.pt")

In [188]:
model = Model().to(DEVICE)
train(model, dataloader, epochs=50)

Average loss after 1 epoch = 4.685294387339527
Sample quote --> love all of business moment is

Average loss after 2 epoch = 4.126872152739096
Sample quote --> words write your situations made me

Average loss after 3 epoch = 3.8973622079489836
Sample quote --> music god is sun free without

Average loss after 4 epoch = 3.73212630170046
Sample quote --> sacred life is a dream sees

Average loss after 5 epoch = 3.6050949697343833
Sample quote --> every stone towards this

Average loss after 6 epoch = 3.4997692533564133
Sample quote --> love your sacred people a pretty

Average loss after 7 epoch = 3.419645624178801
Sample quote --> fear is my life to create

Average loss after 8 epoch = 3.3464082544851874
Sample quote --> custom is power of understanding leads

Average loss after 9 epoch = 3.289383109039249
Sample quote --> follow the best revenge is god

Average loss after 10 epoch = 3.240859399333152
Sample quote --> dwell in hell sunlight is your

Average loss after 11 epoch = 3.2027

In [189]:
model = Model().to(DEVICE)
model.load_state_dict(torch.load("model.pt"))

for _ in range(10):
  generated_quote = " ".join(generate(model)).capitalize()
  print(generated_quote)

Greater accomplishment
The soul is contentment is no
Reach out of spirit of destiny
Praying is the soul of the
Get mad here is a seductive
Increased
Praying is contagious remember about achieving
Stay present stay determined soul has
A poet is god direct your
Change plus time to dream
