<a href="https://colab.research.google.com/github/iwatchkin/Language-modeling/blob/main/Generating_Text_from_Language_Models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Generating text from language models.**

In [None]:
# Install Hugging Face `datasets`, which provides `load_dataset` used to fetch IMDB.
!pip install datasets

In [None]:
from collections import defaultdict
from typing import Dict, List, Optional, Set

import matplotlib.pyplot as plt
import nltk
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
from datasets import load_dataset
from nltk.tokenize import sent_tokenize
from nltk.tokenize import word_tokenize
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

In [None]:
# Tokenizer models needed by sent_tokenize / word_tokenize.
# Recent NLTK releases (3.8.2+) load the 'punkt_tab' resource instead of the
# pickled 'punkt' one, so fetch both. On older NLTK the unknown 'punkt_tab' id
# should only be reported by the downloader (it returns False), not raised.
nltk.download('punkt')
nltk.download('punkt_tab')

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device  # displayed as the cell output

# Data preprocessing.

In [None]:
# Load the IMDB reviews dataset via Hugging Face `datasets`; only the 'text'
# column of its 'train' split is used further on.
dataset = load_dataset('imdb')

In [None]:
# Split every training review into sentences and keep the short ones, lowercased.
# NOTE: despite its name, `words_threshold` is compared with len() of the
# sentence *string*, so it is a character-count cutoff (sentences < 50 chars).
words_threshold = 50
sentences = []

for review in tqdm(dataset['train']['text']):
  short_ones = [candidate.lower() for candidate in sent_tokenize(review)
                if len(candidate) < words_threshold]
  sentences.extend(short_ones)

In [None]:
# Number of short sentences kept for training and evaluation (shown as output).
len(sentences)

In [None]:
# Occurrence count of every token across the kept sentences
# (unseen tokens read as 0 thanks to defaultdict).
words = defaultdict(int)

for tokens in map(word_tokenize, tqdm(sentences)):
  for token in tokens:
    words[token] += 1

In [None]:
# Vocabulary = the four special tokens plus every token seen at least
# `freq_threshold` times; rarer tokens are later mapped to '<unk>'.
freq_threshold = 250
vocab = set(('<bos>', '<eos>', '<unk>', '<pad>'))
vocab.update(token for token in tqdm(words) if words[token] >= freq_threshold)

In [None]:
# Vocabulary size = number of embedding rows and of output logits of the model.
print(f'Vocab size: {len(vocab)}')

In [None]:
# Token <-> index mappings. `vocab` is a set of strings, and because Python
# randomizes string hashes per process its iteration order (and hence every
# index) can change between sessions. Sorting makes the ids reproducible, so a
# saved model stays compatible with a vocabulary rebuilt in a new session.
word2ind = {word: i for i, word in enumerate(sorted(vocab))}
ind2word = {i: word for word, i in word2ind.items()}

In [None]:
class WordDatset(Dataset):
  """Map-style dataset that turns each sentence into a list of token ids.

  Every item is framed as [<bos>, token ids..., <eos>]; tokens missing from
  `word2ind` become the <unk> id. Tokenization runs on every item access.
  """

  def __init__(self, sentences: List[str], word2ind: Dict[str, int]):
    self.sentences = sentences
    self.word2ind = word2ind
    # Resolve the special-token ids once instead of on every item.
    self.bos_id, self.eos_id, self.unk_id, self.pad_id = (
        word2ind[token] for token in ('<bos>', '<eos>', '<unk>', '<pad>'))

  def __len__(self) -> int:
    """Number of sentences in the dataset."""
    return len(self.sentences)

  def __getitem__(self, index: int) -> List[int]:
    """Return the id sequence of sentence `index`, wrapped in <bos>/<eos>."""
    lookup, fallback = self.word2ind.get, self.unk_id
    body = [lookup(token, fallback)
            for token in word_tokenize(self.sentences[index])]
    return [self.bos_id, *body, self.eos_id]

In [None]:
def collate_fn(input_batch: List[List[int]],
               pad_id: Optional[int] = None,
               device: Optional[str] = None) -> Dict[str, torch.Tensor]:
  """Right-pad a batch of id sequences and build next-token LM inputs/targets.

  Args:
    input_batch: token-id sequences of possibly different lengths.
    pad_id: id used for padding. Defaults to the current `word2ind['<pad>']`,
      looked up at call time rather than frozen when this function is defined.
    device: where to place the tensors. Defaults to 'cuda' when available,
      otherwise 'cpu', so the notebook also runs on CPU-only runtimes.

  Returns:
    {'input_ids': batch[:, :-1], 'target_ids': batch[:, 1:]}, both of shape
    (batch_size, max_seq_len - 1): target_ids[:, t] is the token that follows
    input_ids[:, t].
  """
  if pad_id is None:
    pad_id = word2ind['<pad>']
  if device is None:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

  max_seq_len = max(len(seq) for seq in input_batch)
  padded = [seq + [pad_id] * (max_seq_len - len(seq)) for seq in input_batch]
  batch = torch.tensor(padded, dtype=torch.long, device=device)

  # Shift by one position: predict token t+1 from tokens up to t.
  return {'input_ids': batch[:, :-1], 'target_ids': batch[:, 1:]}

In [None]:
# Random 80/20 split into training and evaluation sentences
# (no random_state is set, so the split differs between runs).
train_sentences, eval_sentences = train_test_split(sentences, train_size=0.8)

train_dataset, eval_dataset = (WordDatset(split, word2ind)
                               for split in (train_sentences, eval_sentences))

In [None]:
batch_size = 256

# Reshuffle the training sentences every epoch so the model does not see the
# same batch sequence each time; evaluation order is irrelevant, so keep it fixed.
train_dataloader = DataLoader(train_dataset,
                              collate_fn=collate_fn,
                              batch_size=batch_size,
                              shuffle=True)

eval_dataloader = DataLoader(eval_dataset,
                             collate_fn=collate_fn,
                             batch_size=batch_size)

# The architecture of the language model.

In [None]:
class LanguageModel(nn.Module):
  """Word-level LSTM language model: embedding -> LSTM -> two-layer MLP head.

  Args:
    hidden_dim: size of the embeddings, the LSTM state and the hidden MLP layer.
    vocab_size: number of tokens; also the size of the output logits.
  """

  def __init__(self, hidden_dim: int, vocab_size: int):
    super().__init__()
    # Construction order matters: it fixes the order in which parameters are
    # randomly initialised and the key order of the state_dict.
    self.embedding = nn.Embedding(vocab_size, hidden_dim)
    self.lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
    self.linear = nn.Linear(hidden_dim, hidden_dim)
    self.linear_output = nn.Linear(hidden_dim, vocab_size)
    self.dropout = nn.Dropout(p=0.2)
    self.relu = nn.ReLU()

  def forward(self, input_batch) -> torch.Tensor:
    """Return unnormalised next-token logits for every input position.

    A (batch, seq_len) id tensor yields (batch, seq_len, vocab_size); a 1-D
    (seq_len,) tensor is treated by the LSTM as one unbatched sequence and
    yields (seq_len, vocab_size).
    """
    # Only the per-step LSTM outputs are used; the final (h, c) state is dropped.
    hidden_states = self.lstm(self.embedding(input_batch))[0]
    projected = self.dropout(self.linear(self.relu(hidden_states)))
    return self.linear_output(self.relu(projected))

# The training loop of the model.

In [None]:
def calculate_perplexity(model, criterion, eval_dataloader) -> float:
  """Corpus-level perplexity of `model` on `eval_dataloader`.

  Computed as exp(total NLL / number of scored target tokens). Averaging the
  per-batch exp(loss) instead over-estimates perplexity (the mean of
  exponentials is >= the exponential of the mean) and gives a short last batch
  the same weight as a full one.

  Args:
    model: maps batch['input_ids'] to logits whose last dim is the vocabulary.
    criterion: a mean-reduced loss such as nn.CrossEntropyLoss(ignore_index=...);
      its `ignore_index` (if present) marks padding targets. Class weights are
      assumed to be unused.
    eval_dataloader: iterable of dicts with 'input_ids' and 'target_ids'.

  Returns:
    The perplexity as a Python float. The model is left in eval mode.

  Raises:
    ValueError: if the loader yields no non-ignored target tokens.
  """
  ignore_index = getattr(criterion, 'ignore_index', None)
  model.eval()
  total_nll = 0.0
  total_tokens = 0

  with torch.no_grad():
    for batch in eval_dataloader:
      logits = model(batch['input_ids']).flatten(start_dim=0, end_dim=1)
      targets = batch['target_ids'].flatten()
      if ignore_index is None:
        num_tokens = targets.numel()
      else:
        num_tokens = int((targets != ignore_index).sum())
      if num_tokens == 0:
        # An all-padding batch has a NaN mean loss and carries no information.
        continue
      loss = criterion(logits, targets)
      # Undo the mean reduction so batches are weighted by their token count.
      total_nll += loss.item() * num_tokens
      total_tokens += num_tokens

  if total_tokens == 0:
    raise ValueError('eval_dataloader contains no non-padding target tokens')

  mean_nll = torch.tensor(total_nll / total_tokens, dtype=torch.float64)
  return torch.exp(mean_nll).item()

In [None]:
# Model width, and output size = one logit per vocabulary entry.
hidden_dim, vocab_size = 256, len(vocab)

In [None]:
# Model on the selected device; padding targets are excluded from the loss.
lm = LanguageModel(hidden_dim=hidden_dim, vocab_size=vocab_size).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=word2ind['<pad>'])
optimizer = torch.optim.Adam(lm.parameters(), lr=0.01)

In [None]:
num_epochs = 10
train_loss = []    # mean training loss of every epoch
perplexities = []  # evaluation perplexity measured after every epoch

for epoch in range(num_epochs):
  batch_losses = []
  # calculate_perplexity (run at the end of each epoch) switches to eval mode.
  lm.train()

  for batch in tqdm(train_dataloader):
    optimizer.zero_grad()
    # Flatten (batch, seq) so every position is one classification example.
    logits = lm(batch['input_ids']).flatten(start_dim=0, end_dim=1)
    targets = batch['target_ids'].flatten()
    loss = criterion(logits, targets)
    loss.backward()
    optimizer.step()
    batch_losses.append(loss.item())

  avg_loss = sum(batch_losses) / len(batch_losses)
  print(f'Epoch {epoch}: average error per epoch = {avg_loss:.3f}')
  train_loss.append(avg_loss)
  perplexities.append(calculate_perplexity(lm, criterion, eval_dataloader))

In [None]:
# Mean training cross-entropy per epoch.
plt.plot(range(len(train_loss)), train_loss)
plt.xlabel('epoch')
plt.title('Cross Entropy Loss')

In [None]:
# Evaluation perplexity measured after each epoch.
plt.plot(range(len(perplexities)), perplexities)
plt.xlabel('epoch')
plt.title('Perplexities')

# Text generation.

In [None]:
def generate_sequence(model,
                      source_sequence: str,
                      max_num_words: int = 20) -> str:
  """Greedily continue `source_sequence` with `model`, one token at a time.

  The prompt is lowercased and tokenized like the training sentences, prefixed
  with <bos>, and extended by the arg-max token up to `max_num_words` times,
  stopping early once <eos> is produced.

  Args:
    model: a LanguageModel-like module mapping a 1-D tensor of token ids to
      per-position logits of shape (seq_len, vocab_size).
    source_sequence: the text prompt.
    max_num_words: maximum number of tokens to generate.

  Returns:
    Prompt plus continuation as space-separated vocabulary tokens, including
    the leading '<bos>' (and a trailing '<eos>' if one was generated);
    out-of-vocabulary prompt words appear as '<unk>'.

  The model is switched to eval mode but is NOT moved between devices.
  """
  # Run where the model already lives. `model.to('cpu')` would move the
  # caller's model in place and break any later GPU training.
  device = next(model.parameters()).device
  bos_id, eos_id, unk_id = word2ind['<bos>'], word2ind['<eos>'], word2ind['<unk>']

  # The vocabulary was built from lowercased sentences, so lowercase the prompt.
  prompt_ids = [bos_id] + [word2ind.get(word, unk_id)
                           for word in word_tokenize(source_sequence.lower())]
  input_ids = torch.tensor(prompt_ids, dtype=torch.long, device=device)

  model.eval()
  with torch.no_grad():
    for _ in range(max_num_words):
      # Unbatched input -> (seq_len, vocab_size); the last row scores the next token.
      next_word_logits = model(input_ids)[-1]
      next_word = next_word_logits.argmax()
      input_ids = torch.cat([input_ids, next_word.unsqueeze(0)])

      if next_word.item() == eos_id:
        break

  continued_sequence = ' '.join(ind2word[i] for i in input_ids.tolist())

  return continued_sequence