<a href="https://colab.research.google.com/github/lkarjun/malayalam-language-model/blob/main/Malayalam-Language-Model/malayalam-language-model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Malayalam Language Model

### Imports

In [None]:
!pip install -qq ohmeow-blurr datasets

In [46]:
from fastai.text.all import *

from transformers import (BertLMHeadModel,
                          BertConfig,
                          PreTrainedTokenizerFast)

from blurr.modeling.language_modeling import (HF_LMBeforeBatchTransform, 
                                              HF_CausalLMInput, 
                                              HF_MLMInput,
                                              BertMLMStrategy,
                                              CausalLMStrategy)

from blurr.modeling.core import (HF_BaseModelWrapper, 
                                 HF_TextBlock,
                                 HF_PreCalculatedLoss, 
                                 hf_splitter,
                                 hf_splitter,HF_BaseModelCallback,

                                 )

### Dataset Loading

In [4]:
from datasets import load_dataset

In [5]:
dset = load_dataset("lkarjun/Malayalam-Articles")

Using custom data configuration lkarjun--Malayalam-Articles-d44c52244000c266


Downloading and preparing dataset csv/lkarjun--Malayalam-Articles to /root/.cache/huggingface/datasets/csv/lkarjun--Malayalam-Articles-d44c52244000c266/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e...


  0%|          | 0/2 [00:00<?, ?it/s]

Downloading:   0%|          | 0.00/252M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/81.3M [00:00<?, ?B/s]

  0%|          | 0/2 [00:00<?, ?it/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/lkarjun--Malayalam-Articles-d44c52244000c266/0.0.0/6b9057d9e23d9d8a2f05b985917a0da84d70c5dae3d22ddd8a3f22fb01c69d9e. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
train = dset['train'].to_pandas()[:1000]
train.dropna(inplace=True)

### Tokenizer

In [8]:
MalayalamTokenizer = PreTrainedTokenizerFast.from_pretrained("lkarjun/malayalam-language-model")

Downloading:   0%|          | 0.00/158 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.27M [00:00<?, ?B/s]

### Model Config

Default Bert Config 

( vocab_size = 30522 hidden_size = 768 num_hidden_layers = 12 num_attention_heads = 12 intermediate_size = 3072 hidden_act = 'gelu' hidden_dropout_prob = 0.1 attention_probs_dropout_prob = 0.1 max_position_embeddings = 512 type_vocab_size = 2 initializer_range = 0.02 layer_norm_eps = 1e-12 pad_token_id = 0 position_embedding_type = 'absolute'use_cache = Truec lassifier_dropout = None **kwargs )

In [25]:
CONFIG = BertConfig(
            num_attention_heads = 8, 
            intermediate_size = 2072, 
            hidden_size = 568,
            vocab_size=MalayalamTokenizer.vocab_size,
            pad_token_id=MalayalamTokenizer.pad_token_id,
            is_decoder=True,
            name_or_path = "lkarjun/malayalam-language-model",
            )


In [26]:
LModel = BertLMHeadModel(CONFIG)

### Dataloder

In [27]:
train_bs, val_bs, train_sl, val_sl = 200, 256, 250, 300

In [28]:
splits = RandomSplitter(valid_pct=.1, seed=7)(train)
splits

((#897) [527,870,283,292,430,133,54,42,21,646...],
 (#99) [843,493,393,713,391,671,743,621,396,47...])

In [30]:
before_batch_tfm = HF_LMBeforeBatchTransform(
                            hf_arch = 'bert',
                            hf_config = CONFIG,
                            hf_tokenizer = MalayalamTokenizer,
                            hf_model = LModel,
                            lm_strategy_cls=BertMLMStrategy
                    )

In [31]:
block = HF_TextBlock(before_batch_tfm = before_batch_tfm, 
                     input_return_type = HF_MLMInput), noop
          

In [32]:
Mdblock = DataBlock(
          blocks = block,
          get_x = ColReader('content'),
          splitter = RandomSplitter(valid_pct=.1, seed=7)
      )

In [51]:
dls = Mdblock.dataloaders(train, bs=3, max_seq_len = 500)

In [None]:
dls.show_batch(dataloaders = dls, max_n = 2, trunc_at=100)

In [52]:
b = dls.one_batch()
b[0]['input_ids'].shape, b[0]['labels'].shape, b[1].shape

(torch.Size([3, 6682]), torch.Size([3, 6682]), torch.Size([3, 6682]))

### Training

In [37]:
from fastai.text.all import *

In [53]:
learn = Learner(dls, 
                HF_BaseModelWrapper(LModel),
                opt_func=partial(Adam, decouple_wd=True),
                loss_func= HF_PreCalculatedLoss(),
                cbs=[HF_BaseModelCallback],
                metrics=Perplexity(),
                splitter=hf_splitter).to_fp16()