# Text Generation Demo on PyTorch Lightning: Date Generation (One-to-Many)

In this demo, we will show you how to create a text generator using PyTorch Lightning. This demo is inspired by Andrew Ng's deeplearning.ai course on sequence models. In this demo, we create a one-to-many RNN model for generating dates in the following format: e.g. "2002-03-11".  

In [None]:
import csv
import numpy as np
import random
import math
import sys

import torchtext
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
!pip install pytorch_lightning
import pytorch_lightning as pl
from pytorch_lightning import Trainer

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Generate Dataset
We generate a toy dataset using the datetime library. The target output comes in only one format (ISO format).

In [None]:
# Build a toy dataset: 1500 consecutive calendar days, newest first.
import datetime

maxlen = 10  # every ISO date string ("YYYY-MM-DD") here has 10 characters

today = datetime.datetime.today()
base = datetime.date(today.year, today.month, today.day)  # drop the time-of-day part
one_day = datetime.timedelta(days=1)
date_list = [base - offset * one_day for offset in range(1500)]
data = [day.isoformat() for day in date_list]
print(data[:5])

['2023-01-03', '2023-01-02', '2023-01-01', '2022-12-31', '2022-12-30']


In [None]:
# Collect the distinct characters that occur anywhere in the dataset.
chars = list({ch for line in data for ch in line})
data_size = len(data)
vocab_size = len(chars)
print('There are %d lines and %d unique characters in your data.' % (data_size, vocab_size))
print("max length =", maxlen)
# Sort so that the character order is reproducible from run to run.
sorted_chars = sorted(chars)
print(sorted_chars)

There are 1500 lines and 11 unique characters in your data.
max length = 10
['-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']


In [None]:
# "<S>" is the seed token used to initiate every sequence; it is placed at
# position 0 of the character list.
sorted_chars = ["<S>", *sorted_chars]
print(sorted_chars)
vocab_size = len(sorted_chars)
print(vocab_size)

['<S>', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
12


In [None]:
# Start from an empty torchtext vocabulary and register the tokens one at a
# time, in list order.
vocab = torchtext.vocab.vocab({})
for token in sorted_chars:
    vocab.append_token(token)

In [None]:
# Show the index -> token table of the vocabulary.
print(vocab.get_itos()) 

['<S>', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9']


In [None]:
# Show the token -> index mapping of the vocabulary.
print(vocab.get_stoi())

{'7': 9, '6': 8, '8': 10, '5': 7, '9': 11, '4': 6, '-': 1, '<S>': 0, '0': 2, '1': 3, '2': 4, '3': 5}


# Preprocessing data

In [None]:
# Encode every date string as a list of vocabulary indices, one per character.
encoded = [vocab(list(line)) for line in data]
  

In [None]:
class DateDataset(Dataset):
    """Dataset of index-encoded date strings, each prefixed with a start token.

    Args:
        data: Sequence of equal-length sequences of vocabulary indices, one
            per date string.
        sos_index: Vocabulary index of the start-of-sequence token that is
            prepended to every sample. Defaults to 0.
    """

    def __init__(self, data, sos_index=0):
        # Prepend the start token to every sample before building the tensor.
        prefixed = [[sos_index, *sample] for sample in data]
        # torch.tensor with an explicit dtype replaces the legacy
        # torch.LongTensor constructor.
        self.encoded = torch.tensor(prefixed, dtype=torch.long)

    def __getitem__(self, idx):
        """Return the encoded sample at position ``idx``."""
        return self.encoded[idx]

    def __len__(self):
        """Return the number of samples."""
        return len(self.encoded)

In [None]:
class DateDataModule(pl.LightningDataModule):
    """Lightning data module that serves the encoded dates as training batches."""

    def __init__(self, train_data, batch_size, num_workers=0):
        super().__init__()
        self.train_data = train_data
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: str):
        # Nothing to prepare here: the dataset is built in train_dataloader().
        pass

    def collate_fn(self, batch):
        """Merge a list of index tensors into one batch dict.

        Returns a dict with "x", the float one-hot encoding of the stacked
        index tensors, and "y", the stacked index tensors themselves.
        """
        indices = torch.stack(batch)
        one_hot_x = F.one_hot(indices, num_classes=len(vocab))
        return {"x": one_hot_x.float(), "y": indices}

    def train_dataloader(self):
        """Return a shuffling DataLoader over a freshly built DateDataset."""
        return DataLoader(
            DateDataset(self.train_data),
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
        )
    
 

In [None]:
# Mini-batch size used by the training DataLoader.
batch_size = 16
data_module = DateDataModule(train_data=encoded, batch_size=batch_size, num_workers=0)

# Create & train model


In [None]:
class SimpleRNN(pl.LightningModule):
    """Single-cell RNN that maps the current one-hot token to next-token logits.

    Args:
        vocab_size: Number of tokens; used as both the input width of the RNN
            cell and the number of output logits.
        learning_rate: Learning rate passed to the Adam optimizer.
        criterion: Loss callable applied to (logits, target indices).
        hidden_dim: Size of the recurrent hidden state. Defaults to 16.
    """

    def __init__(self, vocab_size, learning_rate, criterion, hidden_dim=16):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.rnn = nn.RNNCell(self.vocab_size, self.hidden_dim)
        self.fc = nn.Linear(self.hidden_dim, self.vocab_size)
        self.learning_rate = learning_rate
        self.criterion = criterion

    def forward(self, src, hx):
        """Run one time step.

        Returns:
            A tuple ``(prediction_logit, hx)``: the logits computed from the
            new hidden state, and the new hidden state itself.
        """
        hx = self.rnn(src, hx)
        prediction_logit = self.fc(hx)
        return prediction_logit, hx

    def training_step(self, batch, batch_idx):
        """Compute the loss for one batch and log it as "train_loss".

        ``batch['x']`` and ``batch['y']`` are the one-hot inputs and the index
        targets produced by ``DateDataModule.collate_fn``.
        """
        # Feed every position except the last; predict every position except
        # the first (the target is the input shifted by one step).
        src = batch['x'][:, :-1]
        target = batch['y'][:, 1:]
        num_samples, num_steps = src.shape[0], src.shape[1]

        # Random initial hidden state, moved to the device holding the weights.
        hx = torch.randn(num_samples, self.hidden_dim).to(self.rnn.weight_ih.device)
        prediction = torch.zeros(
            (num_samples, num_steps, self.vocab_size), device=hx.device
        )

        for i in range(num_steps):
            prediction_logit, hx = self(src[:, i], hx)
            prediction[:, i, :] = prediction_logit

        # Flatten the batch and time axes for the criterion. Uses the module's
        # own vocab_size rather than a global of the same name.
        prediction = prediction.reshape(-1, self.vocab_size)
        target = target.reshape(-1)
        loss = self.criterion(prediction, target)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
# Learning rate, vocabulary size and loss that configure the model.
lr = 0.005
vocab_size = len(vocab)
criterion = nn.CrossEntropyLoss()
model = SimpleRNN(vocab_size=vocab_size, learning_rate=lr, criterion=criterion)

In [None]:
def generate(model, length=10):
    """Sample a string of ``length`` tokens from ``model``.

    Generation starts from the "<S>" seed token (vocabulary index 0) and a
    random hidden state; each sampled token is fed back as the next input.

    Args:
        model: Module exposing ``hidden_dim`` and ``device`` whose call
            ``model(one_hot, hx)`` returns ``(logits, new_hx)``.
        length: Number of tokens to sample. Defaults to 10.

    Returns:
        The sampled tokens joined into one string.
    """
    # Remember the current mode so a call made during training (e.g. from a
    # callback) does not leave the model stuck in eval mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            num_classes = len(vocab)
            output_list = []
            token = torch.zeros([1], dtype=torch.long)  # index 0 == "<S>"
            hx = torch.randn(1, model.hidden_dim).to(model.device)
            for _ in range(length):
                step_input = F.one_hot(token, num_classes=num_classes)
                step_input = step_input.float().to(model.device)
                logit, hx = model(step_input, hx)
                prob = F.softmax(logit, dim=-1)
                # Sample from the distribution rather than taking the argmax,
                # so repeated calls give different strings.
                output = torch.multinomial(prob, 1).item()
                output_list.append(output)
                token = torch.tensor([output], dtype=torch.long)
    finally:
        if was_training:
            model.train()
    return "".join(vocab.lookup_tokens(output_list))

In [None]:
class PrintCallback(pl.callbacks.Callback):
    """Callback that prints three sampled strings after every second epoch.

    Args:
        what: Unit being counted. Only "epochs" is handled; with any other
            value the callback does nothing.
        verbose: Stored on the instance; not consulted by this callback.
    """

    def __init__(self, what="epochs", verbose=True):
        super().__init__()
        self.what = what
        self.verbose = verbose
        self.state = {"epochs": 0, "batches": 0}

    def on_train_epoch_end(self, trainer, pl_module, *args, **kwargs):
        """Count the finished epoch and print samples on even epoch counts."""
        # Without this guard the counter would stay at 0 for other `what`
        # values, and 0 % 2 == 0 would print "Epoch: 0" after every epoch.
        if self.what != "epochs":
            return
        self.state["epochs"] += 1
        if self.state["epochs"] % 2 == 0:
            print('----- Generating text after Epoch: %d' % self.state["epochs"])
            for _ in range(3):
                # Sample from the module being trained, not a global variable.
                print(generate(pl_module))


In [None]:
# `gpus=` was removed in PyTorch Lightning 2.0; `accelerator`/`devices` is the
# supported spelling, and "auto" picks a GPU when one is available while still
# allowing the demo to run on a CPU-only machine.
trainer = Trainer(
    max_epochs=10,
    accelerator="auto",
    devices=1,
    callbacks=[PrintCallback()],
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


# Let's train the model and generate some text

In [None]:
# Samples from the model before training, as a baseline.
for _ in range(3):
    print(generate(model))

3<S>39908261
6308535277
2<S>92761-63


In [None]:
# Train the model on the batches served by the data module.
trainer.fit(model, datamodule=data_module)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type             | Params
-----------------------------------------------
0 | rnn       | RNNCell          | 480   
1 | fc        | Linear           | 204   
2 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
684       Trainable params
0         Non-trainable params
684       Total params
0.003     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

----- Generating text after Epoch: 2
2029-11-21
2021-02-11
2019-06-14
----- Generating text after Epoch: 4
2022-10-04
2011-08-04
2021-12-17
----- Generating text after Epoch: 6
2022-02-15
2020-11-27
2022-06-05
----- Generating text after Epoch: 8
2062-07-27
2019-07-14
2019-01-15


INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


----- Generating text after Epoch: 10
2022-07-25
2021-02-03
2021-11-21


In [None]:
# Samples from the trained model.
for _ in range(10):
    print(generate(model))

2021-09-32
2022-01-04
2019-10-12
2019-08-27
2020-02-25
2020-03-10
2019-12-02
2019-03-21
2022-10-24
2019-10-22
