In [None]:
! pip install pytorch-lightning==1.5.3 --quiet

In [None]:
# Load the TensorBoard notebook extension so the `%tensorboard` magic used later is available.
%load_ext tensorboard

In [None]:
import requests
import tarfile
from pathlib import Path
from typing import Union, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

import pytorch_lightning as pl
import numpy as np

import pandas as pd

# The Dataset and DataModule

In [None]:
class WikiQA(Dataset):
    """Question/answer pairs from the CMU Question-Answer Dataset v1.2.

    Each item is a pair ``(question_token_ids, answer_token_ids)`` of plain
    Python lists. Batching, SOS/EOS insertion and padding are done by the
    callable returned from :meth:`collate_fn`.

    Args:
        root: Directory that holds (or will hold) the extracted dataset.
        download: When true, fetch and extract the archive into ``root`` first.
        train: Select the training split (``True``) or the held-out split.
        max_length: If set, keep only pairs whose question and answer each have
            at most ``max_length - 2`` tokens (room for SOS and EOS), and pad
            every batch to exactly ``max_length`` time steps.
    """

    url = "http://www.cs.cmu.edu/~ark/QA-data/data/Question_Answer_Dataset_v1.2.tar.gz"
    tar = "Question_Answer_Dataset_v1.2.tar.gz"
    # name of the directory created when the archive is extracted
    folder = "Question_Answer_Dataset_v1.2"
    # random train/test mask shared by all instances so the splits are complementary
    _mask = None

    def __init__(self, root: Union[str, Path], download=False, train=True, max_length=None):
        super().__init__()
        self.root = Path(root) if isinstance(root, str) else root
        self.train = train
        self.max_length = max_length

        # downloading dataset
        if download:
            WikiQA.download(self.root)

        # creating tokenizer
        self.tokenizer = get_tokenizer("basic_english")

        # reading the three files in the S08, S09 and S10 folders
        frames = []
        for s in ["S08", "S09", "S10"]:
            df = pd.read_csv(
                self.root/WikiQA.folder/s/"question_answer_pairs.txt",
                sep="\t",
                encoding="iso-8859-1"
            )
            # non in-place dropna: avoids mutating a slice of the parsed frame
            frames.append(df[["Question", "Answer"]].dropna())
        full_df = pd.concat(frames)
        if max_length:
            full_df = full_df[full_df.apply(
                # subtracting 2 from max length cuz of sos and eos tokens
                lambda row: len(self.tokenizer(row.Question)) <= max_length - 2
                and len(self.tokenizer(row.Answer)) <= max_length - 2,
                axis=1
            )]
        full_df = full_df.reset_index(drop=True)

        # using random splitting for now
        # (re)creating the shared mask when it is missing or was built for a
        # frame of a different size (e.g. a different max_length filter)
        if WikiQA._mask is None or len(WikiQA._mask) != len(full_df):
            WikiQA._mask = np.random.rand(len(full_df)) < 0.8
        train_df = full_df[WikiQA._mask]
        self.df = train_df if train else full_df[~WikiQA._mask]

        # generating vocab
        def yield_tokens(dataframe: pd.DataFrame):
            for row in dataframe.itertuples():
                yield self.tokenizer(row.Question) + self.tokenizer(row.Answer)

        # The vocabulary is ALWAYS built from the training rows, so the train
        # and test datasets map a token to the same index; tokens seen only in
        # the held-out rows fall back to <unk>.
        self.vocab = build_vocab_from_iterator(
            yield_tokens(train_df),
            specials=["<unk>", "<sos>", "<eos>", "<pad>"]
        )
        self.unk_idx = self.vocab["<unk>"]
        self.eos_idx = self.vocab["<eos>"]
        self.sos_idx = self.vocab["<sos>"]
        self.pad_idx = self.vocab["<pad>"]
        self.vocab.set_default_index(self.unk_idx)

        # defining pipelines
        self.text_pipeline = lambda x: self.vocab(self.tokenizer(x))
        self.label_pipeline = lambda x: self.vocab(self.tokenizer(x))

    def __getitem__(self, index):
        """Return ``(question_ids, answer_ids)`` for the row at position ``index``."""
        row = self.df.iloc[index]
        return self.text_pipeline(row["Question"]), self.label_pipeline(row["Answer"])

    def __len__(self):
        """Number of question/answer pairs in this split."""
        return len(self.df)

    def collate_fn(self):
        """Build the batch collation function for a ``DataLoader``.

        The returned callable takes a list of ``(question_ids, answer_ids)``
        pairs and returns ``(texts, labels, lengths)`` where ``texts`` and
        ``labels`` are ``(seq_length, batch_size)`` long tensors wrapped in
        SOS/EOS and padded with ``pad_idx``, and ``lengths`` holds the question
        lengths *before* SOS/EOS were added.
        """
        def wrapper(batch):
            texts, labels = zip(*batch)
            lengths = torch.LongTensor([len(s) for s in texts])

            sos = torch.tensor([self.sos_idx], dtype=torch.long)
            eos = torch.tensor([self.eos_idx], dtype=torch.long)

            def add_sos_eos(ids):
                # explicit dtype: an empty id list would otherwise become a
                # float tensor and break the integer batch
                return torch.cat([sos, torch.tensor(ids, dtype=torch.long), eos])

            # adding the SOS and EOS tokens
            texts = [add_sos_eos(s) for s in texts]
            labels = [add_sos_eos(l) for l in labels]

            texts = torch.nn.utils.rnn.pad_sequence(texts, padding_value=self.pad_idx, batch_first=False)
            labels = torch.nn.utils.rnn.pad_sequence(labels, padding_value=self.pad_idx, batch_first=False)

            if self.max_length:
                # if max_length was set and the number of tokens in this batch is less
                # than max_length, pad the remaining space (this can happen if by chance
                # every sample in the batch has the number of tokens < max_length)
                if texts.size(0) < self.max_length:
                    texts = F.pad(texts, (0, 0, 0, self.max_length - texts.size(0)), value=self.pad_idx)
                if labels.size(0) < self.max_length:
                    labels = F.pad(labels, (0, 0, 0, self.max_length - labels.size(0)), value=self.pad_idx)

            return texts, labels, lengths

        return wrapper

    @staticmethod
    def download(root: Union[str, Path]):
        """Download and extract the dataset archive into ``root``.

        Does nothing when the extracted dataset folder is already present.

        Raises:
            Exception: If the server does not answer with HTTP 200.
        """
        root = Path(root) if isinstance(root, str) else root
        # check for the extracted folder rather than `root` itself, so an
        # existing but empty/unrelated `root` directory still gets the data
        if (root/WikiQA.folder).exists():
            return

        root.mkdir(parents=True, exist_ok=True)
        archive = root/WikiQA.tar

        # downloading dataset (context manager releases the connection)
        with requests.get(WikiQA.url, stream=True, timeout=60) as res:
            if res.status_code != 200:
                raise Exception("Download failed.")
            with open(str(archive), "wb") as f:
                f.write(res.raw.read())

        # extracting dataset
        # NOTE(review): extractall trusts the member paths inside the archive;
        # only use this with the known upstream URL above.
        with tarfile.open(str(archive)) as f:
            f.extractall(str(root))


class WikiQADataModule(pl.LightningDataModule):
    """Lightning data module wrapping the train and held-out ``WikiQA`` splits.

    Args:
        root: Directory that holds (or will hold) the extracted dataset.
        num_workers: Worker processes used by every ``DataLoader``.
        batch_size: Number of question/answer pairs per batch.
        max_length: Forwarded to ``WikiQA``; ``None`` disables length filtering.
    """

    def __init__(self, root: Union[str, Path], num_workers: int=2, batch_size: int=128, max_length: Optional[int] = None):
        # initialise the LightningDataModule base state before adding our own
        super().__init__()
        self.root = root
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.max_length = max_length
        self.train_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """Download and extract the dataset if it is not on disk yet."""
        WikiQA.download(self.root)

    def setup(self, stage: Optional[str]=None):
        """Create the train and held-out datasets (``stage`` is not used)."""
        self.train_dataset = WikiQA(self.root, train=True, max_length=self.max_length)
        self.test_dataset = WikiQA(self.root, train=False, max_length=self.max_length)

    def _loader(self, dataset, shuffle: bool):
        """Build a ``DataLoader`` over ``dataset`` using its own collate function."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=dataset.collate_fn()
        )

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return self._loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the held-out split (also used for testing)."""
        return self._loader(self.test_dataset, shuffle=False)

    def test_dataloader(self):
        """Unshuffled loader over the held-out split."""
        return self._loader(self.test_dataset, shuffle=False)

# The Models

In [None]:
class Encoder(nn.Module):
    """Embeds token ids and runs them through a one-layer, uni-directional GRU."""

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        # the embedding width equals the GRU width, so no projection is needed
        self.embedding = nn.Embedding(num_embeddings=input_size, embedding_dim=hidden_size)
        self.gru = nn.GRU(input_size=hidden_size, hidden_size=hidden_size)

    def forward(self, input):
        """Encode a batch of token-id sequences.

        Args:
            input: token ids of shape (seq_length, batch_size).

        Returns:
            ``(output, hidden)`` as produced by the GRU:
            output is (seq_length, batch_size, hidden_size) and holds the hidden
            state for every time step; hidden is (1, batch_size, hidden_size)
            and holds the state after the last time step.
        """
        # (seq_length, batch_size) -> (seq_length, batch_size, hidden_size)
        token_vectors = self.embedding(input)
        return self.gru(token_vectors)

    def init_hidden(self):
        """All-zero hidden state for a single sequence: shape (1, 1, hidden_size)."""
        return torch.zeros((1, 1, self.hidden_size))


class AttentionDecoder(nn.Module):
    """One-step GRU decoder with attention over a fixed number of encoder positions.

    The attention layer emits exactly ``max_length`` weights, so the encoder
    outputs passed to :meth:`forward` must cover ``max_length`` time steps.
    """

    def __init__(self, hidden_size, output_size, max_length, dropout_p=0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.max_length = max_length
        self.dropout_p = dropout_p

        self.embedding = nn.Embedding(output_size, hidden_size)
        # scores one weight per encoder position from [embedded token ; hidden]
        self.attn = nn.Linear(2 * hidden_size, max_length)
        # merges [embedded token ; attention context] back to hidden_size
        self.attn_combine = nn.Linear(2 * hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout_p)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, encoder_outputs):
        """Decode a single time step for the whole batch.

        Args:
            input: current token ids, shape (batch_size).
            hidden: previous hidden state, shape (1, batch_size, hidden_size).
            encoder_outputs: shape (max_length, batch_size, hidden_size).

        Returns:
            ``(output, hidden, attn_weights)`` where output is the
            (batch_size, output_size) log-probabilities, hidden is the new
            (1, batch_size, hidden_size) state and attn_weights is
            (1, batch_size, max_length).
        """
        # (batch_size) -> (1, batch_size, hidden_size): a length-1 sequence
        embedded = self.dropout(self.embedding(input).unsqueeze(0))

        # attention distribution over encoder positions
        attn_scores = self.attn(torch.cat((embedded, hidden), dim=2))
        attn_weights = F.softmax(attn_scores, dim=2)

        # weighted sum of encoder outputs; bmm wants the batch dimension first:
        # (batch, 1, max_length) @ (batch, max_length, hidden) -> (batch, 1, hidden)
        context = torch.bmm(
            attn_weights.transpose(0, 1),
            encoder_outputs.transpose(0, 1)
        )
        # back to (1, batch_size, hidden_size)
        context = context.transpose(0, 1)

        gru_in = self.attn_combine(torch.cat((embedded, context), dim=2))

        rnn_out, hidden = self.gru(gru_in, hidden)
        rnn_out = F.relu(rnn_out)

        # (1, batch_size, hidden_size) -> (batch_size, output_size) log-probabilities
        logits = self.out(rnn_out.squeeze(0))
        output = F.log_softmax(logits, dim=1)

        return output, hidden, attn_weights

    def init_hidden(self):
        """All-zero hidden state for a single sequence: shape (1, 1, hidden_size)."""
        return torch.zeros((1, 1, self.hidden_size))


class Seq2Seq(pl.LightningModule):
    """GRU encoder + attention decoder trained with negative log-likelihood.

    Args:
        input_size: Size of the source vocabulary (encoder embedding rows).
        hidden_size: Width of the embeddings and of both GRUs.
        output_size: Size of the target vocabulary (decoder output classes).
        max_length: Number of encoder time steps the decoder attends over.
        dropout_p: Dropout probability applied to the decoder embeddings.
        lr: Learning rate for SGD.
        pad_idx: Optional index of the padding token. When given, padded
            target positions are excluded from the loss; when ``None`` every
            position counts (the previous behaviour).
    """

    def __init__(self, input_size, hidden_size, output_size, max_length, dropout_p, lr, pad_idx=None):
        super().__init__()
        self.output_size = output_size
        self.encoder = Encoder(input_size, hidden_size)
        self.decoder = AttentionDecoder(hidden_size, output_size, max_length, dropout_p)
        self.lr = lr
        self.loss = nn.NLLLoss() if pad_idx is None else nn.NLLLoss(ignore_index=pad_idx)

    def forward(self, texts, labels, lengths, teacher_forcing_ratio = 0.5):
        """Run the encoder once and the decoder step by step.

        Args:
            texts: source token ids, shape (seq_length, batch_size).
            labels: target token ids, shape (seq_length, batch_size); row 0
                holds the SOS tokens used as the first decoder input.
            lengths: shape (batch_size); accepted for interface symmetry with
                the batch layout but not used here.
            teacher_forcing_ratio: probability, per time step, of feeding the
                ground-truth token instead of the model's own prediction.

        Returns:
            Tensor of shape (seq_length, batch_size, output_size). Row ``t``
            (t >= 1) holds the log-probabilities predicted for ``labels[t]``;
            row 0 is never written and stays all zeros.
        """
        seq_length = labels.size(0)
        batch_size = labels.size(1)

        # tensor to store decoder outputs, allocated directly on the batch's device
        outputs = torch.zeros(seq_length, batch_size, self.output_size, device=labels.device)

        # forward prop encoder
        encoder_outputs, hidden = self.encoder(texts)

        # first input of the decoder has to be SOS tokens, so taking
        # the SOS tokens from the labels
        decoder_input = labels[0, :]

        # initial hidden of decoder = last hidden of encoder
        decoder_hidden = hidden

        for t in range(1, seq_length):
            # forward prop decoder
            decoder_output, decoder_hidden, _ = self.decoder(decoder_input, decoder_hidden, encoder_outputs)

            # store decoder output
            outputs[t] = decoder_output

            # use teacher forcing?
            use_teacher_forcing = np.random.rand() < teacher_forcing_ratio

            # get highest predicted token
            predicted_token = decoder_output.argmax(1)

            # set next input
            # if use_teacher_forcing: use actual next token as next input
            # else: use predicted token
            decoder_input = labels[t] if use_teacher_forcing else predicted_token

        return outputs

    def _common_step(self, batch, batch_idx):
        """Compute the loss for one batch (shared by train/val/test steps)."""
        texts, labels, lengths = batch

        # Teacher forcing is a training aid only: during validation/testing the
        # decoder must consume its own predictions, which also makes the
        # reported metric deterministic.
        if self.training:
            output = self(texts, labels, lengths)
        else:
            output = self(texts, labels, lengths, teacher_forcing_ratio=0.0)

        # labels shape: (seq_length, batch_size)
        # output shape: (seq_length, batch_size, output_size)
        #
        # Position 0 is the SOS token, for which the decoder makes no
        # prediction (output[0] is all zeros), so it is left out of the loss.
        loss = self.loss(
            output[1:].reshape(-1, self.output_size),
            labels[1:].reshape(-1)
        )

        return {
            "loss": loss
        }

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx)

    def training_epoch_end(self, outputs):
        # mean of the per-batch losses collected over the epoch
        log_metrics = {
            "train_loss": torch.stack([x["loss"] for x in outputs]).mean()
        }

        self.log_dict(log_metrics, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx)

    def validation_epoch_end(self, outputs):
        log_metrics = {
            "val_loss": torch.stack([x["loss"] for x in outputs]).mean()
        }

        self.log_dict(log_metrics, prog_bar=True)
        return log_metrics

    def test_step(self, batch, batch_idx):
        return self.validation_step(batch, batch_idx)

    def test_epoch_end(self, outputs):
        log_metrics = {
            "test_loss": torch.stack([x["loss"] for x in outputs]).mean()
        }

        self.log_dict(log_metrics, prog_bar=True)
        return log_metrics

    def configure_optimizers(self):
        """Plain SGD over all parameters with the configured learning rate."""
        return torch.optim.SGD(self.parameters(), lr=self.lr)


In [None]:
# Seed Python, NumPy and PyTorch so the random train/test split (and training)
# are reproducible across kernel restarts.
pl.seed_everything(42)

# Maximum number of tokens per sequence, including the SOS and EOS tokens.
max_length = 12
data_module = WikiQADataModule(root="data", batch_size=256, num_workers=2, max_length=max_length)

# Called by hand (instead of leaving it to the Trainer) because the vocabulary
# size is needed below to build the model configuration.
data_module.prepare_data()
data_module.setup()

In [None]:
# Keyword arguments forwarded to `Seq2Seq`. The encoder input size and the
# decoder output size are both the size of the training vocabulary.
vocab_size = len(data_module.train_dataset.vocab)
config = dict(
    input_size=vocab_size,
    hidden_size=256,
    output_size=vocab_size,
    max_length=max_length,
    dropout_p=0.1,
    lr=0.01,
)

In [None]:
model = Seq2Seq(**config)

# Use one GPU when CUDA is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without a GPU. The deprecated
# `reload_dataloaders_every_epoch=False` flag is dropped: not reloading is the default.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=100,
    log_every_n_steps=10,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [None]:
# Open TensorBoard inline, reading event files from the `lightning_logs` directory.
%tensorboard --logdir lightning_logs

In [None]:
# Train the model; the data module supplies the training and validation loaders.
trainer.fit(model, datamodule=data_module)

  f"DataModule.{name} has already been called, so it will not be called again. "
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | encoder | Encoder          | 1.1 M 
1 | decoder | AttentionDecoder | 2.0 M 
2 | loss    | NLLLoss          | 0     
---------------------------------------------
3.1 M     Trainable params
0         Non-trainable params
3.1 M     Total params
12.555    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"The number of training samples ({self.num_training_batches}) is smaller than the logging interval"


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [None]:
# Optional: publish the training logs (written to `lightning_logs`) to TensorBoard.dev.
# ! tensorboard dev upload --logdir lightning_logs --name "END3 Session 6 WikiQA Model"