<a href="https://colab.research.google.com/github/deterministic-algorithms-lab/NLP-Journey/blob/main/LanguageModelling/CLM_MLM_TLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell runs, so edits to the cloned `src/` package are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
!git clone https://github.com/deterministic-algorithms-lab/NLP-Journey
%cd NLP-Journey
!pip install -r requirements.txt

In [None]:
import jax
import jax.numpy as jnp
import haiku as hk
import optax

import numpy as np
from functools import partial

In [None]:
import src.DataLoaders.tfds as tfdl
from src.Tokenizers.hf_tokenizer import LM_Tokenizer
from src.model.transformer import LogitsTransformer
from src.optimizers.adam import get_adam_opt
from src.Tokenizers.masking_utils import mask_batch_mlm

## Setting Up Config

In [None]:
config = {
    # ---- Data ----
    'max_length': 512,
    'batch_size': 4,

    # ---- Model ----
    'intermediate_size': 3072,
    'n_heads': 12,
    'n_layers': 12,
    'hidden_size': 768,
    'd_model': 768,  # kept equal to hidden_size

    # ---- Embeddings ----
    'embed_dropout_rate': 0.1,
    'lang2id': {'en': 1, 'ne': 2},

    # ---- Multi-head attention ----
    'attention_drop_rate': 0.1,

    # ---- MLP ----
    'fully_connected_drop_rate': 0.1,

    # ---- Training ----
    'learning_rate': 1e-5,
    'max_grad_norm': 1.0,
    'l2': 0.1,
    'n_epochs': 5,
    'n_examples': 25000,

    # ---- Task ids (compared against the `task` argument further down) ----
    'mlm': 0,
    'clm': 1,
}


## Getting Data

In [None]:
# Monolingual text, used by the MLM and CLM loops below.
# NOTE(review): n_epochs=3 here differs from config['n_epochs'] (5) -- confirm
# which value is intended.
imdb_ds = tfdl.load_tf_dataset(
    config,
    training=True,
    split='train',
    n_epochs=3,
    n_examples=25000,
)

# Nepali-English parallel text, used by the TLM loop below.
flores_neen = tfdl.load_tf_dataset(
    config,
    training=True,
    split='test',
    n_epochs=50,
    n_examples=-1,
    name='flores/neen_plain_text',
)

## Training Tokenizer


In [None]:
def enne_iter():
    """Yield raw text samples for tokenizer training.

    Walks `flores_neen` and `imdb_ds` in lockstep (stopping at the shorter
    one) and, for every pair of elements, emits the English side, the Nepali
    side twice, and then the monolingual `text` field.

    NOTE(review): the Nepali sample is yielded twice per step; presumably to
    up-weight Nepali in the tokenizer's training data -- confirm it is not an
    accidental duplicate.
    """
    for parallel_pair, review in zip(flores_neen, imdb_ds):
        emit_order = (
            (parallel_pair, 'en'),
            (parallel_pair, 'ne'),
            (parallel_pair, 'ne'),
            (review, 'text'),
        )
        for record, field in emit_order:
            yield record[field]

In [None]:
# Build the project's tokenizer wrapper from the config and train it on the
# text stream produced by `enne_iter()`.
lm_tokeniser = LM_Tokenizer(config)
lm_tokeniser.train_tokenizer(binary_iterator=enne_iter())

In [None]:
# Dump the learned vocabulary for inspection.
# NOTE(review): this prints the whole vocabulary; the output may be very long.
print(lm_tokeniser.tokenizer.get_vocab())

### Updating Config

In [None]:
# Extend the config with values that are only known once the tokenizer is
# trained, then freeze it.
#
# A new mapping is built from the old one instead of assigning into `config`
# in place: after the first run `config` is the frozen result of
# `to_immutable_dict`, so item assignment on a re-run of this cell could fail.
# Rebuilding from `{**config, ...}` keeps the cell safe to run repeatedly.
_tok = lm_tokeniser.tokenizer
config = hk.data_structures.to_immutable_dict({
    **config,
    'vocab_size': _tok.get_vocab_size(),

    # Tokenization ids
    'mask_id': _tok.token_to_id("<mask>"),
    'pad_id': _tok.token_to_id("<pad>"),
    'sos_id': _tok.token_to_id("<s>"),
    'eos_id': _tok.token_to_id("</s>"),
})

## Language Model

In [None]:
def logits_fn(masked_token_ids, lang_ids=None, training=True, task=config['mlm']):
    """Forward pass of the language model; meant to be wrapped by `hk.transform`.

    Args:
        masked_token_ids: token ids fed to the model.
        lang_ids: optional language ids, forwarded to the model unchanged.
        training: forwarded to the model as its `training` flag.
        task: task id; the model is run with `is_autoregressive=True` exactly
            when this equals `config['clm']`.

    Returns:
        Whatever `LogitsTransformer(config)` returns for these inputs.
    """
    autoregressive = (task == config['clm'])
    transformer = LogitsTransformer(config)
    return transformer(masked_token_ids, lang_ids,
                       training=training,
                       is_autoregressive=autoregressive)

# Seed the PRNG once; `key` is carried forward and split again by later cells.
key, subkey = jax.random.split( jax.random.PRNGKey(42) )
# Turn the Haiku-style function into a pure (init, apply) pair.
pure_logits_fn = hk.transform(logits_fn)

# Dummy batch used only to initialise the parameters (four sentences, one per
# example of the configured batch size).
token_encoding = lm_tokeniser.batch_encode_plus(['sample sentence', 'Another one!', "we need to make", "this equal to batch size"])

# NOTE(review): int16 ids assume every id fits in 16 bits (vocab < 32768) -- TODO confirm.
token_ids = np.asarray(lm_tokeniser.get_token_ids(token_encoding), dtype=np.int16)
lang_ids = np.asarray(lm_tokeniser.get_lang_ids(token_encoding), dtype=np.int16)

# Mask the dummy batch so that init sees the same kind of input as training.
masked_token_ids, original_batch = mask_batch_mlm(subkey, config, token_ids)

# Fresh subkey for parameter initialisation.
key, subkey = jax.random.split(key)
params = pure_logits_fn.init(subkey, masked_token_ids, lang_ids=lang_ids)

In [None]:
def loss(params, key, original_batch, masked_token_ids, lang_ids=None, task=config['mlm']) :
    """Mean cross-entropy over the scored positions of a batch.

    Args:
        params: model parameters for `pure_logits_fn.apply`.
        key: PRNG key handed to `pure_logits_fn.apply`.
        original_batch: target token ids.
        masked_token_ids: token ids fed to the model.
        lang_ids: optional language ids forwarded to the model.
        task: `config['mlm']` or `config['clm']`. Must be a plain Python value
            (not a traced array) because it selects a Python branch.

    Returns:
        Scalar: the summed cross-entropy of the scored positions divided by
        their count, or 0 when no position is scored.

    Scored positions:
        * CLM: logits at position t are compared with the token at t+1, and
          every non-padding target is scored.
        * MLM/TLM: non-padding positions whose input was replaced by
          `config['mask_id']`, the same selection `accuracy` uses.
          NOTE(review): if `mask_batch_mlm` also corrupts tokens in other ways
          (random replacement / keep), those positions are not scored here --
          confirm against its implementation.
    """
    logits = pure_logits_fn.apply(params, key,
                                  masked_token_ids, lang_ids=lang_ids,
                                  training=True, task=task)

    not_pad = (original_batch != config['pad_id'])

    if task == config['clm']:
        # Shift by one: predict token t+1 from the prefix ending at t.
        logits = logits[:, :-1, :]
        original_batch = original_batch[:, 1:]
        logits_mask = not_pad[:, 1:]
    else:
        # Previously `non_pad OR input != mask`, which is true almost
        # everywhere (padding included) and so did not restrict the loss.
        logits_mask = jnp.logical_and(not_pad,
                                      masked_token_ids == config['mask_id'])

    labels = hk.one_hot(original_batch, config['vocab_size'])
    # Per-token cross-entropy. The mask is applied to the loss, not to the
    # logits: zeroed logits give a uniform softmax, which would add
    # log(vocab_size) to the sum for every ignored position.
    token_xent = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)

    total_masks = jnp.sum(logits_mask)
    # No Python `if` on `total_masks`: it is a traced value under
    # jit/value_and_grad. With no scored positions the numerator is already 0,
    # so clamping the denominator yields a loss of 0.
    return jnp.sum(token_xent * logits_mask) / jnp.maximum(total_masks, 1)

@partial(jax.jit, static_argnums=(5,))
def update(params, rng, opt_state, original_batch, masked_token_ids, task, lang_ids=None):
    """Run one optimisation step.

    `task` (positional index 5) is a static jit argument, so it must be passed
    positionally and a new compilation is triggered per distinct task value.

    Args:
        params: current model parameters.
        rng: PRNG key forwarded to `loss`.
        opt_state: current state of the global optimizer `opt`.
        original_batch: target token ids.
        masked_token_ids: token ids fed to the model.
        task: task id forwarded to `loss`.
        lang_ids: optional language ids forwarded to `loss`.

    Returns:
        Tuple `(new_params, new_opt_state, batch_loss)`.
    """
    loss_and_grad = jax.value_and_grad(loss)
    batch_loss, grads = loss_and_grad(params, rng, original_batch, masked_token_ids,
                                      lang_ids=lang_ids, task=task)
    param_updates, next_opt_state = opt.update(grads, opt_state)
    return optax.apply_updates(params, param_updates), next_opt_state, batch_loss

# `task` is positional index 4 here (there is no `opt_state` argument), so
# that is the index that has to be static for the Python branch below.
@partial(jax.jit, static_argnums=(4,))
def accuracy(params, rng, original_batch, masked_token_ids, task, lang_ids=None):
    """Fraction of scored positions where the arg-max prediction is correct.

    Args:
        params: model parameters for `pure_logits_fn.apply`.
        rng: PRNG key handed to `pure_logits_fn.apply`.
        original_batch: target token ids.
        masked_token_ids: token ids fed to the model.
        task: `config['mlm']` or `config['clm']`; static under jit, pass it
            positionally.
        lang_ids: optional language ids forwarded to the model.

    Returns:
        Scalar accuracy; 0 when no position is scored.

    Scored positions:
        * CLM: logits at position t are compared with the token at t+1, over
          non-padding targets (the same selection `loss` uses).
        * MLM/TLM: positions whose input is `config['mask_id']`.
    """
    # The model must be run through the transformed function with `params`;
    # a Haiku module cannot be instantiated outside `hk.transform`.
    # training=False so dropout does not perturb the evaluation.
    logits = pure_logits_fn.apply(params, rng,
                                  masked_token_ids, lang_ids=lang_ids,
                                  training=False, task=task)

    if task == config['clm']:
        logits = logits[:, :-1, :]
        original_batch = original_batch[:, 1:]
        logits_mask = (original_batch != config['pad_id'])
    else:
        logits_mask = (masked_token_ids == config['mask_id'])

    is_correct = (jnp.argmax(logits, axis=-1) == original_batch)
    total_masks = jnp.sum(logits_mask)
    # Clamp instead of branching: `total_masks` is traced under jit, and the
    # numerator is 0 whenever nothing is scored.
    return jnp.sum(is_correct * logits_mask) / jnp.maximum(total_masks, 1)


## Optimizer

In [None]:
# Build the optimizer from the config and initialise its state for `params`.
# `opt` is read as a global inside `update`, so this cell must run before any
# training loop.
opt = get_adam_opt(config)
opt_state = opt.init(params)

## Training Loops

### MLM

In [None]:
# Masked language modelling on the monolingual dataset.
losses = []
for step, train_batch in enumerate(imdb_ds):
    if step%100==0:
        print(f'[Step {step}]')
    
    token_encoding = lm_tokeniser.batch_encode_plus(train_batch['text'])
    token_ids = np.asarray(lm_tokeniser.get_token_ids(token_encoding), dtype=np.int16)
    lang_ids = np.asarray(lm_tokeniser.get_lang_ids(token_encoding), dtype=np.int16)

    # One subkey for the masking, a separate one for the model's own randomness.
    key, subkey = jax.random.split(key)
    masked_token_ids, original_batch = mask_batch_mlm(subkey, config, token_ids)
    
    key, subkey = jax.random.split(key)
    params, opt_state, batch_loss = update(params, subkey, opt_state,
                                           original_batch, masked_token_ids, 
                                           config['mlm'], lang_ids=lang_ids)
    losses.append(batch_loss)

    if step%100==0 and step!=0:
        # Average over the losses actually collected: the first window holds
        # 101 entries (steps 0..100), later ones 100, so a fixed divisor of
        # 100 overstated the first reported value.
        print(sum(losses)/len(losses))
        losses = []

### TLM

In [None]:
# Translation language modelling: each example is an English/Nepali sentence
# pair encoded together, then trained with the MLM objective.
losses = []
for step, train_batch in enumerate(flores_neen):
    if step%100==0:
        print(f'[Step {step}]')
    
    token_encoding = lm_tokeniser.batch_encode_plus(train_batch['en'], train_batch['ne'])
    token_ids = np.asarray(lm_tokeniser.get_token_ids(token_encoding), dtype=np.int16)
    lang_ids = np.asarray(lm_tokeniser.get_lang_ids(token_encoding), dtype=np.int16)

    # One subkey for the masking, a separate one for the model's own randomness.
    key, subkey = jax.random.split(key)
    masked_token_ids, original_batch = mask_batch_mlm(subkey, config, token_ids)

    key, subkey = jax.random.split(key)
    params, opt_state, batch_loss = update(params, subkey, opt_state, 
                                           original_batch, masked_token_ids,
                                           config['mlm'], lang_ids=lang_ids,)
    losses.append(batch_loss)
    
    if step%100==0 and step!=0:
        # Average over the losses actually collected: the first window holds
        # 101 entries (steps 0..100), later ones 100, so a fixed divisor of
        # 100 overstated the first reported value.
        print(sum(losses)/len(losses))
        losses = []

### CLM

In [None]:
# Causal language modelling on the monolingual dataset.
losses = []
for step, train_batch in enumerate(imdb_ds):
    if step%100==0:
        print(f'[Step {step}]')
    
    token_encoding = lm_tokeniser.batch_encode_plus(train_batch['text'])
    token_ids = np.asarray(lm_tokeniser.get_token_ids(token_encoding), dtype=np.int16)
    lang_ids = np.asarray(lm_tokeniser.get_lang_ids(token_encoding), dtype=np.int16)

    # Next-token prediction uses the uncorrupted sequence as both input and
    # target (the CLM branch of `loss` does the one-step shift). Feeding the
    # output of `mask_batch_mlm` here would train the causal model on
    # <mask>-corrupted prefixes.
    # NOTE(review): assumes the `original_batch` returned by `mask_batch_mlm`
    # was simply `token_ids` -- confirm.
    key, subkey = jax.random.split(key)
    params, opt_state, batch_loss = update(params, subkey, opt_state,
                                           token_ids, token_ids, 
                                           config['clm'], lang_ids=lang_ids)
    losses.append(batch_loss)

    if step%100==0 and step!=0:
        # Average over the losses actually collected: the first window holds
        # 101 entries (steps 0..100), later ones 100, so a fixed divisor of
        # 100 overstated the first reported value.
        print(sum(losses)/len(losses))
        losses = []