<a href="https://colab.research.google.com/github/deterministic-algorithms-lab/Jax-Journey/blob/main/basic_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install git+https://github.com/deepmind/dm-haiku
!pip install transformers
!pip install git+git://github.com/deepmind/optax.git

In [None]:
import haiku as hk
import jax.numpy as jnp
import jax

from jax import jit
from jax.random import PRNGKey
import numpy as np

# Transformers-Classification Using pre-trained weights from RoBERTa

## Embedding Layers

In [None]:
from transformers import RobertaModel

class Embedding(hk.Module):
    """Token plus position embeddings, followed by layer norm and optional dropout.

    Every parameter created here is initialised from the matching array in
    ``config['pretrained']``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

    def __call__(self, token_ids, training=False):
        """
        token_ids: ints of shape (batch, n_seq)
        training: when true, dropout with rate ``config['embed_dropout_rate']``
            is applied to the normalised embeddings.
        """
        pretrained = self.config['pretrained']
        word_table = pretrained['embeddings/word_embeddings']
        vocab_size, embed_dim = word_table.shape[0], word_table.shape[1]

        # Look the ids up as one flat vector, then restore the batch axis.
        embed_layer = hk.Embed(vocab_size=vocab_size,
                               embed_dim=embed_dim,
                               w_init=hk.initializers.Constant(word_table))
        flat_vectors = embed_layer(jnp.reshape(token_ids, [-1]))
        token_vectors = jnp.reshape(flat_vectors, [token_ids.shape[0], -1, embed_dim])

        position_vectors = PositionEmbeddings(self.config)()
        summed = token_vectors + position_vectors

        layer_norm = hk.LayerNorm(
            axis=-1,
            create_scale=True,
            create_offset=True,
            scale_init=hk.initializers.Constant(pretrained['embeddings/LayerNorm/gamma']),
            offset_init=hk.initializers.Constant(pretrained['embeddings/LayerNorm/beta']),
        )
        normalized = layer_norm(summed)

        if not training:
            return normalized

        return hk.dropout(hk.next_rng_key(),
                          rate=self.config['embed_dropout_rate'],
                          x=normalized)

In [None]:
class PositionEmbeddings(hk.Module):
    """
    Position embedding table initialised from the pre-trained weights.

    Calling the module returns ``config['max_length']`` consecutive rows of
    the table, starting ``offset`` (2) rows in.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.offset = 2

    def __call__(self):
        pretrained_table = self.config['pretrained']['embeddings/position_embeddings']

        # The parameter has the same shape as, and starts equal to, the pre-trained table.
        table = hk.get_parameter(
            "position_embeddings",
            pretrained_table.shape,
            init=hk.initializers.Constant(pretrained_table),
        )

        first_row = self.offset
        last_row = self.offset + self.config['max_length']
        return table[first_row:last_row]

## Tokenizer and Utilities for Downloading and Extracting pre-trained weights

In [None]:
from io import BytesIO
from functools import lru_cache

import joblib
import requests

from transformers import RobertaModel, RobertaTokenizer

# Pre-trained HuggingFace RoBERTa; get_pretrained_weights() reads its input embedding matrix.
huggingface_roberta = RobertaModel.from_pretrained('roberta-base', output_hidden_states=True)

# Matching tokenizer, used to turn raw text into token ids.
huggingface_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')


In [None]:
def postprocess_key(key):
    """Strip checkpoint-specific fragments from a weight name.

    Removes, in this order, every occurrence of 'model/featurizer/bert/',
    ':0' and 'self/' from ``key`` and returns what is left.
    """
    for fragment in ('model/featurizer/bert/', ':0', 'self/'):
        key = key.replace(fragment, '')
    return key

In [None]:
@lru_cache()
def get_pretrained_weights():
    """Download the pre-trained RoBERTa weights and return them as a dict.

    The weight dictionary is the one used by the RoBERTa encoder at
    https://github.com/IndicoDataSolutions/finetune. Its keys are cleaned
    with ``postprocess_key``, and the word-embedding matrix is taken from
    ``huggingface_roberta``. The result is cached, so the download happens
    at most once per kernel.

    Returns:
        dict mapping each cleaned weight name to its array.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.Timeout: if the server stops responding.
    """
    remote_url = "https://bendropbox.s3.amazonaws.com/roberta/roberta-model-sm-v2.jl"

    response = requests.get(remote_url, timeout=60)
    # Fail with a clear HTTP error instead of handing an error page to joblib.
    response.raise_for_status()

    # SECURITY: joblib.load unpickles the payload, which can execute arbitrary
    # code. Keep this only while the remote host is trusted.
    raw_weights = joblib.load(BytesIO(response.content))

    weights = {
        postprocess_key(key): value
        for key, value in raw_weights.items()
    }

    input_embeddings = huggingface_roberta.get_input_embeddings()
    weights['embeddings/word_embeddings'] = input_embeddings.weight.detach().numpy()

    return weights


In [None]:
class Scope:
    """
    Lookup helper that prepends a fixed prefix to every key.

    ``scope[key]`` is shorthand for ``weights[prefix + key]``. This is plain
    Python; haiku is not involved.
    """

    def __init__(self, weights, prefix):
        self.prefix = prefix
        self.weights = weights

    def __getitem__(self, key):
        full_key = self.prefix + key
        return self.weights[full_key]

## Running the Embedding Layers

In [None]:
# One short sentence; it is duplicated below to form a batch of two.
sample_text = "This was a flower of evil."

# Minimal config: only the entries the embedding layers read.
config = {
    'pretrained': get_pretrained_weights(),
    'max_length': 512,
    'embed_dropout_rate': 0.1,
}

# Every sequence is padded out to max_length ids.
encoded = huggingface_tokenizer.batch_encode_plus(
    [sample_text] * 2,
    padding='max_length',
    max_length=config['max_length'],
)

sample_tokens = encoded['input_ids']

In [None]:


def embed_fn(tokens, training=False):
    """Run the embedding layers on a batch of token ids.

    tokens: ints of shape (batch, n_seq).
    training: forwarded to ``Embedding`` so that dropout is applied when true.
    """
    return Embedding(config)(tokens, training=training)


rng = PRNGKey(42)
# hk.transform always threads an rng through apply, so no apply_rng flag is needed.
embed = hk.transform(embed_fn)
sample_tokens = np.asarray(sample_tokens)
params = embed.init(rng, sample_tokens, training=False)
# `training` selects a Python code path, so it is bound outside the traced
# arguments; passing it through jit would turn it into a traced value.
embedded_tokens = jit(
    lambda p, key, toks: embed.apply(p, key, toks, training=False)
)(params, rng, sample_tokens)

## Transformer Block

In [None]:
class TransformerBlock(hk.Module):
    """One encoder layer: self-attention and an MLP, each followed by a
    residual connection and a layer norm initialised from pre-trained weights."""

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.n = layer_num

    def __call__(self, x, mask, training = False):
        layer_weights = Scope(self.config['pretrained'], f'encoder/layer_{self.n}/')

        # Attention sub-layer: attend, add the input back, normalise.
        attention = MultiHeadAttention(self.config, self.n)
        attended = attention(x, mask, training=training)
        attention_norm = hk.LayerNorm(
            axis=-1,
            create_scale=True,
            create_offset=True,
            scale_init=hk.initializers.Constant(layer_weights['attention/output/LayerNorm/gamma']),
            offset_init=hk.initializers.Constant(layer_weights['attention/output/LayerNorm/beta']),
        )
        attended = attention_norm(attended + x)

        # Feed-forward sub-layer: MLP, add its input back, normalise.
        mlp = TransformerMLP(self.config, self.n)
        mlp_out = mlp(attended, training=training)
        output_norm = hk.LayerNorm(
            axis=-1,
            create_scale=True,
            create_offset=True,
            scale_init=hk.initializers.Constant(layer_weights['output/LayerNorm/gamma']),
            offset_init=hk.initializers.Constant(layer_weights['output/LayerNorm/beta']),
        )
        return output_norm(mlp_out + attended)

## Multi-Head Attention

In [None]:
class MultiHeadAttention(hk.Module):
    """Multi-head self-attention with every projection initialised from pre-trained weights."""

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.n = layer_num

    def _split_into_heads(self, x):
        """Reshape a rank-3 array (a, b, c) into (a, b, n_heads, c // n_heads)."""
        n_heads = self.config['n_heads']
        dim0, dim1, dim2 = x.shape[0], x.shape[1], x.shape[2]
        return jnp.reshape(x, [dim0, dim1, n_heads, dim2 // n_heads])

    def __call__(self, x, mask, training=False):
        hidden_size = self.config['hidden_size']
        pretrained = Scope(self.config['pretrained'], f'encoder/layer_{self.n}/attention/')

        # Projections are created in a fixed order (query, key, value, and the
        # output projection further down); that order must not change.
        q = hk.Linear(output_size=hidden_size,
                      w_init=hk.initializers.Constant(pretrained['query/kernel']),
                      b_init=hk.initializers.Constant(pretrained['query/bias']))(x)

        k = hk.Linear(output_size=hidden_size,
                      w_init=hk.initializers.Constant(pretrained['key/kernel']),
                      b_init=hk.initializers.Constant(pretrained['key/bias']))(x)

        v = hk.Linear(output_size=hidden_size,
                      w_init=hk.initializers.Constant(pretrained['value/kernel']),
                      b_init=hk.initializers.Constant(pretrained['value/bias']))(x)

        q = self._split_into_heads(q)
        k = self._split_into_heads(k)
        v = self._split_into_heads(v)

        # logits[b, n, s, t]: dot product of query s and key t for head n,
        # scaled by the square root of the per-head width.
        logits = jnp.einsum('bsnh,btnh->bnst', q, k)
        logits = logits / np.sqrt(q.shape[-1])

        # Key positions where mask is non-zero receive a very large negative
        # logit, broadcast over heads and query positions.
        key_penalty = jnp.reshape(mask*-2**32, [mask.shape[0],1,1,mask.shape[1]])
        logits = logits + key_penalty

        attention_probs = jax.nn.softmax(logits, axis=-1)
        per_head = jnp.einsum('btnh,bnst->bsnh', v, attention_probs)

        # Merge the head axis back into the feature axis.
        merged = jnp.reshape(per_head, [per_head.shape[0], per_head.shape[1], -1])

        projected = hk.Linear(output_size=hidden_size,
                              w_init=hk.initializers.Constant(pretrained['output/dense/kernel']),
                              b_init=hk.initializers.Constant(pretrained['output/dense/bias']))(merged)

        if not training:
            return projected

        return hk.dropout(rng=hk.next_rng_key(),
                          rate=self.config['attention_drop_rate'],
                          x=projected)

## Transformer MLP

In [None]:
def gelu(x):
    """Exact (erf-based) GELU: x * 0.5 * (1 + erf(x / sqrt(2)))."""
    half_x = x*0.5
    erf_term = jax.scipy.special.erf(x / jnp.sqrt(2.0))
    return half_x*(1.0+erf_term)

class TransformerMLP(hk.Module):
    """Feed-forward sub-layer: Linear -> GELU -> Linear, with optional dropout.

    Both linear layers are initialised from the pre-trained weights of layer ``layer_num``.
    """

    def __init__(self, config, layer_num):
        super().__init__()
        self.config = config
        self.n = layer_num

    def __call__(self, x, training=False):
        pretrained = Scope(self.config['pretrained'], f'encoder/layer_{self.n}/')

        # Expand to the intermediate width and apply the non-linearity.
        expand = hk.Linear(output_size=self.config['intermediate_size'],
                           w_init=hk.initializers.Constant(pretrained['intermediate/dense/kernel']),
                           b_init=hk.initializers.Constant(pretrained['intermediate/dense/bias']))
        hidden = gelu(expand(x))

        # Project back to the model's hidden size.
        project = hk.Linear(output_size=self.config['hidden_size'],
                            w_init=hk.initializers.Constant(pretrained['output/dense/kernel']),
                            b_init=hk.initializers.Constant(pretrained['output/dense/bias']))
        projected = project(hidden)

        if not training:
            return projected

        return hk.dropout(rng=hk.next_rng_key(),
                          rate=self.config['fully_connected_drop_rate'],
                          x=projected)

## Config and Getting Features from the Model

In [None]:
class RobertaFeaturizer(hk.Module):
    """Embedding layers followed by ``config['n_layers']`` transformer blocks."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def __call__(self, token_ids, training=False):
        """Compute the output of the last transformer block for every token.

        token_ids: ints of shape (batch, n_seq)
        training: forwarded to every sub-module to switch dropout on.
        """
        x = Embedding(self.config)(token_ids, training=training)
        # 1.0 where the token id equals config['mask_id'], 0.0 elsewhere.
        mask = (token_ids==self.config['mask_id']).astype(jnp.float32)
        for layer_num in range(self.config['n_layers']):
            # Build each block from the config this module was given, not from
            # whatever the notebook-level `config` happens to be at call time.
            x = TransformerBlock(self.config, layer_num=layer_num)(x, mask, training=training)
        return x

In [None]:
config = {
    # Carried over from the smaller config defined for the embedding demo.
    'pretrained': config['pretrained'],
    'max_length': config['max_length'],

    # Dropout rates, one per place dropout is applied in the featurizer.
    'embed_dropout_rate': 0.1,
    'fully_connected_drop_rate': 0.1,
    'attention_drop_rate': 0.1,

    # Model dimensions.
    'hidden_size': 768,
    'intermediate_size': 3072,
    'n_heads': 12,
    'n_layers': 12,

    # Token id that RobertaFeaturizer flags in the attention mask.
    'mask_id': 1,
    # Std-dev of the truncated-normal init of the classifier head.
    'weight_stddev': 0.02,

    # Classifier and training settings.
    'n_classes': 2,
    'classifier_drop_rate': 0.1,
    'learning_rate': 1e-5,
    'max_grad_norm': 1.0,
    'l2': 0.1,
    'n_epochs': 5,
    'batch_size': 4,
}

def featurizer_fn(tokens, training=False):
    """Haiku-transformable entry point: contextual embeddings for a batch of token ids."""
    featurizer = RobertaFeaturizer(config)
    return featurizer(tokens, training=training)


rng = PRNGKey(42)
roberta = hk.transform(featurizer_fn)
sample_tokens = np.asarray(sample_tokens)

# init creates the parameters; the jitted apply then runs on the same batch.
params = roberta.init(rng, sample_tokens, training=False)
contextual_embeddings = jit(roberta.apply)(params, rng, sample_tokens)
print(contextual_embeddings.shape)

(2, 512, 768)


## Getting Data

In [None]:
import tensorflow_datasets as tfds

def load_dataset(split, training, batch_size, n_epochs=1, n_examples=None):
    """Load a split of the IMDB reviews dataset as an iterable of numpy batches.

    Args:
        split: TFDS split name, e.g. 'train' or 'test'.
        training: when true, examples are shuffled (fixed seed) before batching.
        batch_size: number of examples per batch.
        n_epochs: how many times the cached examples are repeated.
        n_examples: keep only the first ``n_examples`` of the split; ``None``
            (the default) keeps the whole split.

    Returns:
        The batched dataset wrapped by ``tfds.as_numpy``.
    """
    # Only add the slice when a limit was requested: formatting None into
    # the spec would give "train[:None]", which is not a valid split string.
    split_spec = split if n_examples is None else f"{split}[:{n_examples}]"

    ds = tfds.load("imdb_reviews", split=split_spec).cache().repeat(n_epochs)

    if training:
        ds = ds.shuffle(10*batch_size, seed=0)

    ds = ds.batch(batch_size)

    return tfds.as_numpy(ds)

In [None]:
n_examples = 25000
# The batch size is read from config, the same value the total_steps computation uses.
train = load_dataset('train', training=True, batch_size=config['batch_size'], n_epochs=config['n_epochs'], n_examples=n_examples)

In [None]:
def encode_batch(batch_text):
    """Tokenize a batch of raw texts into a rectangular array of token ids.

    Args:
        batch_text: iterable of ``str`` or UTF-8 encoded ``bytes``. Each item
            is cut to its first 512 characters (or bytes) before tokenizing.

    Returns:
        np.ndarray of token ids, one row per text, every row padded or
        truncated to ``config['max_length']`` ids.
    """
    batch_text = [
        # Cutting bytes at 512 can split a multi-byte UTF-8 character;
        # errors='ignore' drops the dangling fragment instead of raising.
        text[:512].decode('utf-8', errors='ignore') if isinstance(text, bytes) else text[:512]
        for text in batch_text
    ]

    token_ids = huggingface_tokenizer.batch_encode_plus(
        batch_text,
        padding='max_length',
        # Without truncation, a text that tokenizes to more than max_length
        # ids would give a longer row and therefore a ragged batch.
        truncation=True,
        max_length=config['max_length'],
    )['input_ids']

    return np.asarray(token_ids)

## The classifier

In [None]:
class RobertaClassifier(hk.Module):
    """RoBERTa featurizer with a linear classification head on the first sequence position."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def __call__(self, token_ids, training=False):
        featurizer = RobertaFeaturizer(self.config)
        features = featurizer(token_ids=token_ids, training=training)

        # The classifier reads only the features at sequence position 0.
        first_position = features[:,0,:]

        if training:
            first_position = hk.dropout(rng=hk.next_rng_key(),
                                        rate=self.config['classifier_drop_rate'],
                                        x=first_position)

        # Freshly initialised head: one logit per class.
        head = hk.Linear(output_size=self.config['n_classes'],
                         w_init=hk.initializers.TruncatedNormal(self.config['weight_stddev']))
        return head(first_position)

## Running the Classifier

In [None]:
def roberta_classification_fn(batch_token_ids, training):
    """Haiku-transformable entry point: class logits for a batch of token ids."""
    classifier = RobertaClassifier(config)
    return classifier(batch_token_ids, training=training)


rng = jax.random.PRNGKey(42)
roberta_classifier = hk.transform(roberta_classification_fn)

# Initialise the parameters on a small two-sentence batch.
params = roberta_classifier.init(
    rng,
    batch_token_ids=encode_batch(['sample sentence', 'Another one!']),
    training=True,
)


```roberta_classifier.init()``` and ```roberta_classifier.apply()``` are pure functions now, so they can be composed together and used with other functions.

In [None]:
def loss(params, rng, batch_token_ids, batch_labels):
    """Mean softmax cross-entropy of the classifier on one batch, with dropout on.

    Args:
        params: classifier parameters.
        rng: PRNG key passed to ``roberta_classifier.apply``.
        batch_token_ids: token ids for the batch.
        batch_labels: integer class label for each example.

    Returns:
        Scalar: the summed cross-entropy divided by the number of examples.
    """
    logits = roberta_classifier.apply(params, rng, batch_token_ids, training=True)
    # jax.nn.one_hot is the supported replacement for the deprecated hk.one_hot.
    labels = jax.nn.one_hot(batch_labels, config['n_classes'])
    # Normalise over the class axis explicitly.
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    softmax_xent = -jnp.sum(labels*log_probs)
    return softmax_xent / labels.shape[0]

@jax.jit
def accuracy(params, rng, batch_token_ids, batch_labels):
    """Fraction of the batch whose arg-max logit equals its label (dropout off)."""
    logits = roberta_classifier.apply(params, rng, batch_token_ids, training=False)
    predictions = jnp.argmax(logits, axis=-1)
    is_correct = predictions==batch_labels
    return jnp.mean(is_correct)

@jax.jit
def update(params, rng, opt_state, batch_token_ids, batch_labels):
    """Run one optimisation step on a batch.

    Returns the updated parameters, the updated optimizer state and the
    batch loss computed before the update.
    """
    loss_and_grad = jax.value_and_grad(loss)
    batch_loss, grad = loss_and_grad(params, rng, batch_token_ids, batch_labels)
    updates, new_opt_state = opt.update(grad, opt_state)
    return optax.apply_updates(params, updates), new_opt_state, batch_loss

## Defining Learning rate scheduler and Optimizer

In [None]:
import optax

The below way of defining a functionality allows you to tie together namespaces with functions (or "wrap" a function in a namespace consisting of variables defined in the outer function).

Here, ```warmup_percentage``` and ```total_steps``` act as if they were variables in a class with a function ```lr_schedule()```. The ```lr_schedule()``` function can access them freely.

In [None]:
def make_lr_schedule(warmup_percentage, total_steps):
    """Build a schedule that ramps up during warmup and decays linearly throughout.

    warmup_percentage: fraction of ``total_steps`` spent warming up.
    total_steps: total number of steps in the run.

    Returns a function ``lr_schedule(step)`` that gives the scaling factor
    for that step.
    """

    def lr_schedule(step):
        progress = step/total_steps

        # 1.0 while still in warmup (inclusive of the boundary), 0.0 afterwards.
        in_warmup = jax.lax.convert_element_type((progress<=warmup_percentage),
                                                 np.float32)

        # progress / warmup_percentage during warmup, exactly 1 afterwards.
        ramp = in_warmup*(progress/warmup_percentage) + (1-in_warmup)
        # Linear decay applied over the whole run.
        decay = 1-progress

        return ramp * decay

    return lr_schedule

In [None]:
# Batches per epoch (integer division) times the number of epochs.
total_steps = (n_examples//config['batch_size'])*config['n_epochs']

# Warm up over the first 10% of the steps.
lr_schedule = make_lr_schedule(total_steps=total_steps, warmup_percentage=0.1)

In [None]:
# Optimizer pipeline; the order of the entries is part of the behaviour.
opt = optax.chain(
    # Clip gradients using config['max_grad_norm'] as the global-norm limit.
    optax.clip_by_global_norm(config['max_grad_norm']),
    # Adam with the base learning rate from config.
    optax.adam(learning_rate=config['learning_rate']),
    # Scale the result by the factor lr_schedule returns for the current step.
    optax.scale_by_schedule(lr_schedule),
)
# Optimizer state matching the classifier parameters.
opt_state = opt.init(params)

## Utility for Measuring Performance

In [None]:
def measure_current_performance(params, n_examples=None, splits=('train', 'test')):
    """Print the classifier's accuracy on the requested dataset splits.

    Args:
        params: classifier parameters to evaluate.
        n_examples: forwarded to ``load_dataset`` to limit how many examples
            of each split are evaluated.
        splits: any subset of ('train', 'test'); splits not listed are skipped.

    Returns:
        None. The mean of the per-batch accuracies is printed for each split.
    """
    def _mean_batch_accuracy(split):
        """Mean of the per-batch accuracies over ``split``, in batches of 25."""
        batches = load_dataset(split, training=False, batch_size=25, n_examples=n_examples)
        return np.mean([accuracy(params, rng,
                                 encode_batch(batch['text']),
                                 batch['label'])
                        for batch in batches])

    if 'train' in splits:
        train_accuracy = _mean_batch_accuracy('train')
        print(f"\t Train validation acc: {train_accuracy:.3f}")

    if 'test' in splits:
        test_accuracy = _mean_batch_accuracy('test')
        # Printed inside the branch: test_accuracy only exists when 'test' was requested.
        print(f"\t Test validation accuracy: {test_accuracy:.3f}")

## Training Loop

### For running on a different dataset:

**In the cell below :**

* Change Line 1 to enumerate any dataset returning batches of actual text (can have emojis too), with their integer labels. For example, ```train_batch['text']``` can be a list (or any other iterable) ```['My name is Jeevesh.', 'I live at your house.']``` with ```train_batch['label']``` as another list ```[1,2]```.

* Change ```n_classes``` in config.

* Change tokenizer/provide vocabulary to add new tokens for additional languages, using ```huggingface_tokenizer.add_tokens(<list of new tokens>)``` .

* Rest remains same.

In [None]:
for step, train_batch in enumerate(train):

    if step%100==0:
        print(f'[Step {step}]')
    if step%1000==0 and step!=0:
        measure_current_performance(params, n_examples=100)

    batch_token_ids = encode_batch(train_batch['text'])
    batch_labels = train_batch['label']
    # Derive a fresh key for every step; passing the same `rng` each time
    # would apply the identical dropout mask at every update.
    step_rng = jax.random.fold_in(rng, step)
    params, opt_state, batch_loss = update(params, step_rng, opt_state, batch_token_ids, batch_labels)