In [None]:
# Install the third-party packages this notebook uses into the running kernel's environment.
# `datasets` provides `load_dataset`; `transformers[torch]` provides the T5 model/tokenizer classes.
%pip install datasets
%pip install transformers[torch]
# `pytorch_lightning` provides the LightningModule / Trainer used for fine-tuning.
%pip install pytorch_lightning
# NOTE(review): sentencepiece is never imported directly here; presumably it is a
# runtime requirement of T5Tokenizer — confirm.
%pip install sentencepiece
# NOTE(review): no versions are pinned, so a fresh run may resolve to different releases.

In [None]:
# Google Colab only: uncomment the two lines below to mount Drive, so that the
# locations configured in this cell live on the mounted volume.
# from google.colab import drive
# drive.mount('/content/drive')

# Where the tokenizer files and the fine-tuned model weights are written after training.
tokenizer_path = "/content/drive/MyDrive/Colab Notebooks/nlp/t5_tokenizer"
train_path = "/content/drive/MyDrive/Colab Notebooks/nlp/t5_trained_model"

In [None]:
from datasets import load_dataset
# Fetch the WikiSQL splits used for fine-tuning and for validation, in that order.
# Sizes noted by the notebook author: train 56,355 samples, validation 8,421 samples.
train_dataset, valid_dataset = (
    load_dataset('wikisql', split=split_name)
    for split_name in ('train', 'validation')
)

In [None]:
from torch.utils.data import Dataset
import pandas as pd
import copy

class WikiSQLDataset(Dataset):
    """Map-style dataset that turns WikiSQL rows into padded seq2seq tensors.

    Every row is tokenized eagerly in ``__init__``: the model input is
    ``row['question']`` prefixed with ``"translate English to SQL: "`` and the
    target is ``row['sql']['human_readable']``. Both are truncated/padded to a
    fixed length so samples can be stacked by a default ``DataLoader``.

    Args:
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer). It is
            called with a one-element list of strings plus ``max_length``,
            ``truncation``, ``padding`` and ``return_tensors`` and must return
            a mapping holding ``"input_ids"`` and ``"attention_mask"`` tensors
            with a leading batch axis of size 1.
        data: Iterable of rows, each exposing ``row['question']`` and
            ``row['sql']['human_readable']``.
        max_len_inp: Fixed token length of the encoded question.
        max_len_out: Fixed token length of the encoded SQL target.
    """

    def __init__(self, tokenizer, data, max_len_inp=512,max_len_out=96):
        # Kept for backward compatibility; not read anywhere inside this class.
        self.answer = "answer"
        self.question = "question"
        self.data = data

        self.max_len_input = max_len_inp
        self.max_len_output = max_len_out
        self.tokenizer = tokenizer

        # Padding id used to mask labels. Taken from the tokenizer instead of
        # assuming 0; falls back to 0 (the previous hard-coded value) when the
        # tokenizer does not expose one.
        pad_id = getattr(tokenizer, "pad_token_id", None)
        self.pad_token_id = 0 if pad_id is None else pad_id

        self.inputs = []
        self.targets = []
        self._build()

    def __len__(self):
        """Return the number of tokenized samples."""
        return len(self.inputs)

    def __getitem__(self, index):
        """Return one sample as a dict of 1-D tensors.

        Keys: ``source_ids``, ``source_mask``, ``target_ids``, ``target_mask``
        and ``labels``. ``labels`` is a copy of ``target_ids`` in which every
        padding position is replaced by -100, the index that the loss of
        Hugging Face seq2seq models ignores.
        """
        source = self.inputs[index]
        target = self.targets[index]

        # squeeze(0) removes only the leading batch axis; a bare squeeze()
        # would also collapse the sequence axis when its length is 1.
        source_ids = source["input_ids"].squeeze(0)
        src_mask = source["attention_mask"].squeeze(0)
        target_ids = target["input_ids"].squeeze(0)
        target_mask = target["attention_mask"].squeeze(0)

        # clone() so that masking the labels never alters the stored target ids.
        labels = target_ids.clone()
        labels[labels == self.pad_token_id] = -100

        return {"source_ids": source_ids, "source_mask": src_mask,
                "target_ids": target_ids, "target_mask": target_mask,
                "labels": labels}

    def _encode(self, text, max_length):
        """Tokenize one string to exactly ``max_length`` tokens (batch of 1)."""
        # Calling the tokenizer directly is the supported entry point;
        # ``batch_encode_plus`` is deprecated in favour of ``__call__``.
        return self.tokenizer(
            [text], max_length=max_length,
            truncation=True,
            padding='max_length', return_tensors="pt"
        )

    def _build(self):
        """Tokenize every row of ``self.data`` into ``self.inputs`` / ``self.targets``."""
        for row in self.data:
            question = f"translate English to SQL: {row['question']}"
            target = row['sql']['human_readable']

            self.inputs.append(self._encode(question, self.max_len_input))
            self.targets.append(self._encode(target, self.max_len_output))

In [None]:
from transformers import (
      T5ForConditionalGeneration,
      T5Tokenizer
  )

# Load the pretrained T5 tokenizer (inputs capped at 512 tokens) and the matching model.
t5_tokenizer = T5Tokenizer.from_pretrained(
    't5-base',
    model_max_length=512,
)
t5_model = T5ForConditionalGeneration.from_pretrained('t5-base')

# Tokenize the train split first, then the validation split, with the same tokenizer.
train_data, validation_data = (
    WikiSQLDataset(t5_tokenizer, raw_split)
    for raw_split in (train_dataset, valid_dataset)
)

In [None]:
import pytorch_lightning as pl
from torch.optim import AdamW
from torch.utils.data import DataLoader

class T5Tuner(pl.LightningModule):
    """LightningModule that fine-tunes a seq2seq model on pre-tokenized batches.

    Batches are expected to be dicts with the keys ``source_ids``,
    ``source_mask``, ``target_mask`` and ``labels`` (the layout produced by
    ``WikiSQLDataset.__getitem__`` after default collation).

    Args:
        batchsize: Batch size used by both data loaders.
        t5model: Model called with ``input_ids``, ``attention_mask``,
            ``decoder_attention_mask`` and ``labels``; its first output is
            used as the loss.
        t5tokenizer: Tokenizer kept alongside the model (stored, not used by
            the training logic in this class).
        train_data: Optional training dataset. When omitted, the notebook
            global ``train_data`` is used, as before.
        validation_data: Optional validation dataset. When omitted, the
            notebook global ``validation_data`` is used, as before.
    """

    def __init__(self, batchsize, t5model, t5tokenizer,
                 train_data=None, validation_data=None):
        super().__init__()
        self.batch_size = batchsize
        self.model = t5model
        self.tokenizer = t5tokenizer
        # Explicit datasets avoid depending on notebook-global state; None
        # means "fall back to the globals" for backward compatibility.
        self._train_data = train_data
        self._validation_data = validation_data

    def forward(self, input_ids, attention_mask=None,
                decoder_attention_mask=None,
                lm_labels=None):
        """Run the wrapped model; ``lm_labels`` is forwarded as ``labels``."""
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=decoder_attention_mask,
            labels=lm_labels,
        )

    def _compute_loss(self, batch, log_name):
        """Forward one batch, log its loss under ``log_name`` and return it."""
        outputs = self.forward(
            input_ids=batch["source_ids"],
            attention_mask=batch["source_mask"],
            decoder_attention_mask=batch['target_mask'],
            lm_labels=batch['labels']
        )
        # Index 0 is the loss whenever labels are supplied; indexing works for
        # both tuple and ModelOutput return types.
        loss = outputs[0]
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch`` (logged as ``train_loss``)."""
        return self._compute_loss(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch`` (logged as ``val_loss``)."""
        return self._compute_loss(batch, "val_loss")

    def train_dataloader(self):
        """Build the training loader; samples are reshuffled every epoch."""
        dataset = train_data if self._train_data is None else self._train_data
        # shuffle=True: without it every epoch visits samples in the same
        # fixed dataset order.
        return DataLoader(dataset, batch_size=self.batch_size,
                          shuffle=True, num_workers=2)

    def val_dataloader(self):
        """Build the validation loader (kept in dataset order)."""
        dataset = (validation_data if self._validation_data is None
                   else self._validation_data)
        return DataLoader(dataset,
                          batch_size=self.batch_size,
                          num_workers=2)

    def configure_optimizers(self):
        """Optimize all parameters with AdamW (lr=3e-4, eps=1e-8)."""
        return AdamW(self.parameters(), lr=3e-4, eps=1e-8)

In [None]:
import torch

device = 'cuda' if torch.cuda.is_available() else "cpu"
print(f"Using device {device}")
# The Trainer documents "gpu"/"cpu" as accelerator names; map the torch device
# string onto them instead of passing "cuda" straight through.
accelerator = "gpu" if device == "cuda" else "cpu"

# Seed the Python, NumPy and PyTorch RNGs so repeated runs are comparable.
pl.seed_everything(42, workers=True)

# Model Fine-Tuning
bs = 8  # batch size used by both data loaders
model = T5Tuner(bs, t5_model, t5_tokenizer)
ckpt_path_dir = "/content/drive/MyDrive/Colab Notebooks/nlp/t5-checkpoint"
trainer = pl.Trainer(
    max_epochs=3,
    accelerator=accelerator,
    default_root_dir=ckpt_path_dir,  # checkpoints and logs are written under this directory
    enable_checkpointing=True,
)
trainer.fit(model)


# Save artifacts for deployment and inference: the underlying Hugging Face
# model (not the Lightning wrapper) and its tokenizer.
model.model.save_pretrained(train_path)
t5_tokenizer.save_pretrained(tokenizer_path)