<a href="https://colab.research.google.com/github/mushrafi88/asr_bangla/blob/main/gpt_2_asr_bn_corrector.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# **Install libraries**

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [2]:
!pip install huggingface
!pip install datasets tqdm pandas
!pip install sentencepiece
!pip install transformers
!pip install wandb
!pip install deepspeed
!pip install pytorch_lightning

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting huggingface
  Downloading huggingface-0.0.1-py3-none-any.whl (2.5 kB)
Installing collected packages: huggingface
Successfully installed huggingface-0.0.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.8.0-py3-none-any.whl (452 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m452.9/452.9 KB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0.0,>=0.2.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m182.4/182.4 KB[0m [31m24.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting xxhash
  Downloading xxhash-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (213 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m213.0/213.0 KB[0m [31m23.1 MB/s[0m

In [3]:
# hide
import datetime
import os
from pathlib import Path
import random
from typing import Any, Dict, List, Optional, Tuple

import datasets
from deepspeed.ops import adam
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pydantic
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision import models
from tqdm.auto import tqdm
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
os.environ["TOKENIZERS_PARALLELISM"] = "true"

print(pl.__version__, torch.__version__, transformers.__version__)

1.8.6 1.13.0+cu116 4.25.1


In [4]:
pd.set_option('display.max_colwidth', None)

In [5]:
df = pd.read_csv('/content/drive/MyDrive/bert_asr/bert_corrector/data/train/raw.csv')
df = df[['wav2vec2','sentence']]
df = df.dropna(how='any')

In [6]:
#hide
LEARNING_RATE = 1e-4
EPOCHS = 1
BATCH_SIZE = 12
MAX_LEN = 256
LANGUAGE_MODEL = "flax-community/gpt2-bengali"
LOG_PATH = "/content/kaggle/working/logs/"
FREEZE_LAYERS = 2
UNFREEZE_LAYERS = False
UNFREEZE_BATCH_IDX = 1000
LABEL_MASK = -100
NUM_BATCHES = 10_000

In [7]:
#collapse-show
class Tokenizer:
    def __init__(self, tokenizer, max_len: int):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.bos = tokenizer.bos_token
        self.eos = tokenizer.eos_token
        self.sep = tokenizer.sep_token
        self.num_special_tokens = len(self.tokenizer.all_special_tokens)
        
    def __getattr__(self, attribute: str):
        if hasattr(self.tokenizer, attribute):
            return getattr(self.tokenizer, attribute)
        else:
            raise AttributeError(f"{attribute} not found")

    def __call__(self, input_sentences: List[str], output_sentences: Optional[List[str]]=None, device:torch.device=None) -> AutoTokenizer:
        if output_sentences is None:
            sentences = [self.bos + x + self.sep for x in input_sentences]
        else:
            sentences = [self.bos + x + self.sep + y + self.eos for x, y in zip(input_sentences, output_sentences)]
        
        tokenized = self.tokenizer(
            sentences, 
            truncation=True,
            padding=True,
            return_tensors="pt",
            max_length=self.max_len,
        )
        if device is not None:
            return {key: tensor.to(device) for key, tensor in tokenized.items()}
        return tokenized

    def decode(self, x: Dict[str, torch.LongTensor]):
        return [self.tokenizer.decode(sentence[:sentence_len]) for sentence, sentence_len in 
                zip(x["input_ids"], target["attention_mask"].sum(axis=-1))]
    
    def batch_decode(self, encoded_outputs: torch.LongTensor) -> List[str]:
        return self.tokenizer.batch_decode(encoded_outputs.cpu(), skip_special_tokens=True)
    
    def __len__(self):
        return len(self.tokenizer)


# get text base and transform
language_model = AutoModelForCausalLM.from_pretrained(LANGUAGE_MODEL)
tokenizer = Tokenizer(
    AutoTokenizer.from_pretrained(
        LANGUAGE_MODEL, 
        bos_token="<|startoftext|>",
        eos_token="<|endoftext|>", 
        pad_token="<|pad|>", 
        sep_token="<|sep|>"
    ),
    MAX_LEN,
)
language_model.resize_token_embeddings(len(tokenizer))

Downloading:   0%|          | 0.00/864 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/510M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.76M [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Embedding(50259, 768)

In [8]:
# hide
class GeneratorConfig(pydantic.BaseModel):
    bos_token_id: int
    pad_token_id: int
    eos_token_id: int
    repetition_penalty: float = 1.2
    beam_search: bool = True
    num_beam: int = 5
    early_stopping: bool = True
    max_generated_len: int = MAX_LEN
    no_repeat_ngram_size: int = 2
    top_k: int = 2000
    top_p: float = 0.95

    def build_generator_kwargs(self) -> Dict[str, Any]:
        common_params = {
            "bos_token_id": self.bos_token_id,
            "pad_token_id": self.pad_token_id,
            "eos_token_id": self.eos_token_id,
        }   
        if self.beam_search:
            return {
                **common_params,
                **{
                    "max_length": self.max_generated_len,
                    "num_beams": self.num_beam,
                    "no_repeat_ngram_size": self.no_repeat_ngram_size,
                    "early_stopping": self.early_stopping,
                    "repetition_penalty": self.repetition_penalty,
                }
            }
        else:
            return {
                **common_params,  
                **{
                    "max_length": self.max_generated_len,
                    "do_sample": True,
                    "top_k": self.top_k,
                    "top_p": self.top_p,
                    "early_stopping": self.early_stopping,
                    "repetition_penalty": self.repetition_penalty,
                }
            }

In [9]:
df = df.rename(columns={"wav2vec2": "input", "sentence": "output"})

In [10]:
df=df.reset_index(drop=True)

In [11]:
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(df, test_size=0.10, shuffle=True)
train_df.shape, test_df.shape

((155459, 2), (17274, 2))

In [12]:
test_df.head()

Unnamed: 0,input,output
3170,পেয়েছেন ছাববি সাজার দুইষ উনুসার ভোট,পেয়েছেন হাজার ভোট
121840,মেট্রোরেল প্রকল্পের কারণে দেশের অন্যতম অন্যতম বিরল ই মুখল কাঠাম ভুমকির মধ্যে রয়েছে,মেট্রোরেল প্রকল্পের কারণে দেশের অন্যতম অন্যতম বিরল এই মুঘল কাঠামো হুমকির মধ্যে রয়েছে
24671,তান জান পুটি,তানজাং পুটি
23526,ডিভিলিয়ার্স মাত্র বতরিশ বলে একচার,ডিভিলিয়ার্স মাত্র বলে চার
17677,আছে মধ্যপ্রদেশের,আছে মধ্য প্রদেশের


### We will use a token length of 64 since it will cover the vast majority of examples

In [13]:
from datasets import Dataset
train_dataset = Dataset.from_pandas(train_df)
test_dataset = Dataset.from_pandas(test_df)

In [14]:
train_dataset.save_to_disk('/content/train')
test_dataset.save_to_disk('/content/test')

Saving the dataset (0/1 shards):   0%|          | 0/155459 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/17274 [00:00<?, ? examples/s]

In [15]:
train_dataset[1]

{'input': 'এরপর এই চলচ্চিত্রের হিন্দি পুনর নির্মাণ দিয়ে তার বলিউডে অভিষেক ঘটে',
 'output': 'এরপর এই চলচ্চিত্রের হিন্দি পুনর্নির্মাণ দিয়ে তার বলিউডে অভিষেক ঘটে',
 '__index_level_0__': 154134}

### Load the Dataset

In [16]:
max_len = 64 

In [17]:
train_dataset

Dataset({
    features: ['input', 'output', '__index_level_0__'],
    num_rows: 155459
})

In [18]:
train_dataset = train_dataset.remove_columns("__index_level_0__")
test_dataset = test_dataset.remove_columns("__index_level_0__")

In [19]:
def group_batch(batch):
    return {k: [v] for k, v in batch.items()}
train_dl = train_dataset.map(group_batch, batched=True, batch_size=16)
valid_dl  = test_dataset.map(group_batch, batched=True, batch_size=16)

  0%|          | 0/9717 [00:00<?, ?ba/s]

  0%|          | 0/1080 [00:00<?, ?ba/s]

In [25]:
#collapse-show
class LightningModule(pl.LightningModule):
    def __init__(
        self,
        model: nn.Module,
        tokenizer: Tokenizer,
        generation_kwargs: Dict[str, Any],
        lr: float = 1e-3,
    ) -> None:
        super().__init__()
        self.model = model
        self.tokenizer = tokenizer
        self.lr = lr
        self.generation_kwargs = generation_kwargs
        self.original_embed_weights = self.model.transformer.wte.weight[:-self.tokenizer.num_special_tokens].clone()
        
        for layer in self.model.transformer.h[:FREEZE_LAYERS]:
            layer.eval()
            for p in layer.parameters():
                p.requires_grad = False
        
        self.table_logging = 0
        
    def common_step(self, batch: Dict[str, torch.LongTensor]) -> torch.Tensor:
        good_grammar_batch = self.tokenizer(batch["output"], batch["output"], self.device)
        good_grammar_labels = good_grammar_batch["input_ids"].clone()
        good_grammar_labels[good_grammar_batch["attention_mask"] == 0] = LABEL_MASK
        mask = (good_grammar_labels == self.tokenizer.sep_token_id).roll(shifts=1, dims=-1).cumsum(dim=-1) == 0
        good_grammar_labels[mask] = LABEL_MASK
        
        bad_grammar_batch = self.tokenizer(batch["input"], batch["output"], self.device)
        bad_grammar_labels = bad_grammar_batch["input_ids"].clone()
        bad_grammar_labels[bad_grammar_batch["attention_mask"] == 0] = LABEL_MASK
        mask = (bad_grammar_labels == self.tokenizer.sep_token_id).roll(shifts=1, dims=-1).cumsum(dim=-1) == 0
        bad_grammar_labels[mask] = LABEL_MASK

        good_grammar_out = self.model(
            **good_grammar_batch,
            labels=good_grammar_labels,
        )
        bad_grammar_out = self.model(
            **bad_grammar_batch,
            labels=bad_grammar_labels,
        )
        return good_grammar_out.loss + bad_grammar_out.loss
        
    def training_step(
        self, batch: Dict[str, torch.LongTensor], batch_idx: int,
    ) -> torch.Tensor:
        if (batch_idx + 1) % 100 == 0:
            self.model.transformer.wte.weight[:-self.tokenizer.num_special_tokens].data = self.original_embed_weights
            
        loss = self.common_step(batch)     
        self.log("training_loss", loss, on_step=True, on_epoch=True, batch_size=len(batch["input"]))
             
        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, List[str]], batch_idx: int,
    ) -> torch.Tensor:
        loss = self.common_step(batch)
        self.log("validation_loss", loss, on_step=False, on_epoch=True, batch_size=len(batch["input"]))
        
        if batch_idx == 0:
            self.log_examples(batch)
            
    def log_examples(self, batch):
        good_grammar_batch = self.tokenizer(batch["output"], device=self.device)
        bad_grammar_batch = self.tokenizer(batch["input"], device=self.device)
        encoded_good_outputs = self.model.generate(**good_grammar_batch, **self.generation_kwargs)
        encoded_bad_outputs = self.model.generate(**bad_grammar_batch, **self.generation_kwargs)
        generated_good_sentences = self.tokenizer.batch_decode(encoded_good_outputs)
        generated_bad_sentences = self.tokenizer.batch_decode(encoded_bad_outputs)
        
        data = list(map(list, zip(batch["output"] + batch["input"], generated_good_sentences + generated_bad_sentences)))
        columns = ["Actual Sentence", "Generated Sentence"]
        #data = [[x, y.split(x)[1]] for x, y in data]
        table = wandb.Table(data=data, columns=columns)
        if self.logger is not None:
            self.table_logging += 1
            self.logger.experiment.log({f"epoch {self.table_logging} results": table})

    def configure_optimizers(self) -> torch.optim.Optimizer:
        caption_params = [
            {"params": self.model.transformer.ln_f.parameters() , "lr": self.lr},
            {"params": self.model.transformer.h[FREEZE_LAYERS:].parameters() , "lr": self.lr},
            {"params": self.model.transformer.wte.parameters() , "lr": self.lr},
        ]
        return adam.FusedAdam(caption_params)

In [29]:
#hide
generator_config = GeneratorConfig(
    bos_token_id=tokenizer.bos_token_id, 
    pad_token_id=tokenizer.pad_token_id, 
    eos_token_id=tokenizer.eos_token_id
)

lightning_module = LightningModule(
    language_model, 
    tokenizer, 
    generation_kwargs=generator_config.build_generator_kwargs(), 
    lr=LEARNING_RATE
)
is_interactive = True
logger = None if is_interactive else pl.loggers.WandbLogger(str(datetime.datetime.now().date()), LOG_PATH, project="Grammar_Correction")
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    gpus=torch.cuda.device_count(),
    gradient_clip_val=1.0,
    precision=16,
    #save_weights_only=True,
    num_sanity_val_steps=0,
    logger=logger,
    enable_progress_bar=is_interactive,
    log_every_n_steps=200,
    limit_train_batches=20 if is_interactive else NUM_BATCHES,
    limit_val_batches=3 if is_interactive else 1.0,
    val_check_interval=UNFREEZE_BATCH_IDX if not is_interactive else 4,
)
trainer.fit(lightning_module, train_dl, valid_dl) #

TypeError: ignored

In [28]:
trainer.save()

AttributeError: ignored

In [30]:
model = ModelClass.load_from_checkpoint('/content/checkpoints/epoch=0-step=20.ckpt')

NameError: ignored

In [31]:

torch.save(
    model.input_embeddings.state_dict(),
    "input_embeddings.pt"
)
torch.save(model.mlp.state_dict(), "mlp.pt")

NameError: ignored