# __🐥 AI4Code Train: Big Bird 🐥__

---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a> | <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#training-loop'> ⚡ Training Loop </a>

In [None]:
# %%time
# # TODO: Integrate WanDB
# !pip install git+https://github.com/google-research/t5x
# import t5x

In [None]:
# from flax.jax_utils import replicate, unreplicate
# from flax import jax_utils, struct, traverse_util

In [None]:
%%writefile flax_setup.py

import optax
import flax
import jax

import jax.numpy as jnp
import flax.linen as nn

from flax.training import train_state
from flax.training.common_utils import shard

import requests
import os

def kaggle_tpu_setup(request_timeout=300):
    """Configure JAX to use the Kaggle TPU through the ``tpu_driver`` backend.

    When ``TPU_NAME`` is not set this only prints a message and returns.
    Otherwise it asks the TPU host for the nightly ``tpu_driver`` build, points
    JAX at ``TPU_NAME`` as the backend target and sets the default matmul
    precision to bfloat16.

    Args:
        request_timeout: Seconds to wait for the TPU host to answer the
            driver-version request. Without a timeout ``requests.post`` can
            block forever if the host never answers.
    """
    if 'TPU_NAME' not in os.environ:
        print('TPU not found')
        return
    os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
    # NOTE(review): assumes TPU_NAME looks like 'grpc://<ip>:<port>', so that
    # split(':')[1] is '//<ip>' and the URL becomes 'http://<ip>:8475/...'.
    url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
    resp = requests.post(url, timeout=request_timeout)
    if not resp.ok:
        # Best effort, as before: report the failure but keep configuring JAX.
        print(f'TPU driver version request failed: HTTP {resp.status_code}')
    jax.config.FLAGS.jax_xla_backend = 'tpu_driver'
    jax.config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    jax.config.update('jax_default_matmul_precision', 'bfloat16')

kaggle_tpu_setup()

In [None]:
# Installations for Flax & BigBird #
!pip install --upgrade jaxlib jax flax optax -q
!pip install --upgrade transformers -q

# Sync Notebook with VS Code #
import sys; sys.path.append('ai4code')

from ai4c.jupyter_setup import *
from flax_setup import *
import ai4c

# ⚙️ Hyperparameters ⚙️
---
### <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#model'> 🧠 Model </a>|  <a href='#training-loop'> ⚡ Training Loop </a> 

<a name='hyperparameters'>

In [None]:
%%hyperparameters HP

## Huggingface Backbone ##
backbone_name: 'google/bigbird-roberta-base'
backbone_weights: null

attention_probs_dropout_prob: 0.10
hidden_dropout_prob: 0.10

num_random_blocks: 3
block_size: 64

gradient_checkpointing: False
max_seq_length: 1536


## Tokenization ##
max_markdown_seq_len: 1280
max_tokens_per_markdown_cell: 512
max_tokens_per_code_cell: 256


## Model Training ##
num_train_epochs: 3

per_device_train_batch_size: 2
per_device_eval_batch_size: 2


## Loss Function ##
loss_fn_name: 'mse'
markdown_cell_loss_weight: 0.50


## Cosine Decay LR Scheduler ##
warmup_ratio: 0.10
peak_lr: 1e-5
min_lr: 1e-8


## AdamW Optimizer ## 
weight_decay: 1e-6
max_grad_norm: 1.00
beta_2: 0.98
epsilon: 1e-6
ema_decay: 0.99


# Load From Cache: Tokenized Dataset #
processed_dataset_folder: 'ai4code-tokenization-bigbird'
debug_notebooks: null


# Data Factory #
valid_fold: 0
random_state: 69420
logging_freq: 100

## ⚒️ Data Factory ⚒️
---
#### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#model'> 🧠 Model </a> | <a href='#training-loop'> ⚡ Training Loop </a>

<a name='data-factory'>

In [None]:
# Tokenizer matching the backbone; also used when datasets are built on the fly.
tokenizer = transformers.AutoTokenizer.from_pretrained(HP.backbone_name)

# Where the pre-tokenized fold datasets live (only used when a folder is configured).
processed_dataset_path = Path(f'/kaggle/input/{HP.processed_dataset_folder}')
if HP.processed_dataset_folder is None:
    # Debug mode: read `debug_notebooks` rows (all rows when null) and use the
    # same notebooks for both training and validation.
    print(f'Loading {HP.debug_notebooks} notebooks for debugging.')
    notebooks_df = pd.read_csv('/kaggle/input/ai4code-dataframes/notebooks_df.csv', nrows=HP.debug_notebooks)
    train_df = valid_df = notebooks_df
else:
    print(f'Loading dataframes from {processed_dataset_path}')
    notebooks_df = pd.read_csv('/kaggle/input/ai4code-dataframes/notebooks_df.csv')
    # Hold out `valid_fold`; every other fold is training data.
    valid_df = notebooks_df[notebooks_df.notebook_fold == HP.valid_fold]
    train_df = notebooks_df[notebooks_df.notebook_fold != HP.valid_fold]

### 🤗 Hugging Face Dataset
---

#### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#model'> 🧠 Model </a>|  <a href='#training-loop'> ⚡ Training Loop </a> 

In [None]:
%%writefile prepare_hf_dataset.py

from functools import partial
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import argparse
import gc

import transformers
import datasets

# Enable tqdm's pandas integration (progress bars for `progress_apply` & co.).
tqdm.pandas()
# Separator used to join per-cell values (sources, pct ranks, ids) into one
# string per notebook in the `merged_*` dataframe columns; split on below.
CELL_SEP = '[CELL_SEP]'

def prune_code_tokens(code_token_ids, max_seq_len):
    """
    Prunes cells that take too many tokens to fit in max_seq_len.

    Tokens are removed from the end of the currently longest cell, one level
    at a time (the longest cell is cut to just below the runner-up), until the
    total number of tokens across all cells is at most ``max_seq_len``. Short
    cells are therefore left untouched for as long as possible.

    Args:
        code_token_ids: list of per-cell token-id lists.
        max_seq_len: token budget for all cells combined.

    Returns:
        A new list with one token-id list per input cell. Cells that need no
        pruning are returned as the same list objects; pruned cells are
        truncated copies.
    """
    token_counts = [len(cell_token_ids) for cell_token_ids in code_token_ids]
    tokens_left_to_prune = max(sum(token_counts) - max_seq_len, 0)

    tokens_to_prune_per_cell = [0] * len(token_counts)
    while tokens_left_to_prune > 0:
        sorted_counts = sorted(token_counts)
        cur_max = sorted_counts[-1]
        # A single cell has no runner-up: treat it as 0 so the remaining excess
        # comes off that cell (indexing [-2] used to raise IndexError here).
        second_max = sorted_counts[-2] if len(sorted_counts) > 1 else 0
        # First cell holding the current maximum.
        cell_idx = token_counts.index(cur_max)
        # Cut this cell to just below the runner-up, but never more than needed.
        num_tokens_to_pop = min(cur_max - second_max + 1, tokens_left_to_prune)
        tokens_to_prune_per_cell[cell_idx] += num_tokens_to_pop
        token_counts[cell_idx] -= num_tokens_to_pop
        tokens_left_to_prune -= num_tokens_to_pop

    # Drop the pruned tokens from the end of each cell.
    return [
        cell_token_ids[:-num_to_pop] if num_to_pop else cell_token_ids
        for cell_token_ids, num_to_pop in zip(code_token_ids, tokens_to_prune_per_cell)
    ]


def convert_to_features_bigbird(
    notebook_dict,
    tokenizer,
    max_seq_len,
    max_markdown_seq_len,
    max_tokens_per_markdown_cell,
    max_tokens_per_code_cell,
):
    '''Tokenize the notebook and convert to features for the model.

    Markdown cells are packed first within ``max_markdown_seq_len`` tokens,
    code cells fill whatever is left of ``max_seq_len``, and every per-token
    feature is padded to exactly ``max_seq_len``.

    Args:
        notebook_dict: one notebook row. Reads ``merged_markdown_cell_sources``,
            ``merged_markdown_cell_pct_ranks``, ``merged_code_cell_sources``,
            ``merged_code_cell_pct_ranks`` (strings joined with ``CELL_SEP``)
            and ``notebook_id``. Labels are emitted only when the key
            ``merged_cell_pct_ranks`` is present.
        tokenizer: tokenizer called on a list of strings; its ``'input_ids'``
            output is used.
        max_seq_len: length of every returned per-token feature.
        max_markdown_seq_len: token budget for all markdown cells together.
        max_tokens_per_markdown_cell: lower bound on the per-markdown-cell
            truncation length (raised when there are few markdown cells).
        max_tokens_per_code_cell: same, for code cells.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``markdown_token_mask``,
        ``code_token_mask``, ``token_cell_indices`` (cell position in
        markdown-then-code order, -1 on padding) and ``notebook_id``; plus
        ``token_labels`` and ``token_weights`` when labels are available.
    '''
    has_labels = 'merged_cell_pct_ranks' in notebook_dict

    markdown_cell_sources = notebook_dict['merged_markdown_cell_sources'].split(CELL_SEP)
    markdown_cell_pct_ranks = [float(rank) for rank in notebook_dict['merged_markdown_cell_pct_ranks'].split(CELL_SEP)]
    code_cell_sources = notebook_dict['merged_code_cell_sources'].split(CELL_SEP)
    code_cell_pct_ranks = [float(rank) for rank in notebook_dict['merged_code_cell_pct_ranks'].split(CELL_SEP)]

    # Drop cells from the end of the notebook so every kept cell fits in its
    # budget; the //2 reserves two tokens per kept cell (presumably the
    # tokenizer's start/end special tokens).
    max_markdown_cells = max_markdown_seq_len // 2
    max_code_cells = (max_seq_len - max_markdown_seq_len) // 2
    markdown_cell_sources = markdown_cell_sources[:max_markdown_cells]
    markdown_cell_pct_ranks = markdown_cell_pct_ranks[:max_markdown_cells]
    code_cell_sources = code_cell_sources[:max_code_cells]
    code_cell_pct_ranks = code_cell_pct_ranks[:max_code_cells]

    markdown_cell_count = len(markdown_cell_sources)
    code_cell_count = len(code_cell_sources)

    # Markdown first: allow longer cells when there are only a few of them.
    markdown_cell_token_limit = max(max_tokens_per_markdown_cell, max_markdown_seq_len // markdown_cell_count)
    markdown_token_ids = tokenizer(
        markdown_cell_sources,
        max_length=markdown_cell_token_limit,
        truncation=True,
    )['input_ids']
    markdown_token_ids = prune_code_tokens(markdown_token_ids, max_markdown_seq_len)
    total_markdown_tokens = sum(len(token_ids) for token_ids in markdown_token_ids)

    # Code cells get the budget the markdown cells left unused.
    max_code_seq_len = max_seq_len - total_markdown_tokens
    code_cell_token_limit = max(max_tokens_per_code_cell, max_code_seq_len // code_cell_count)
    code_token_ids = tokenizer(
        code_cell_sources,
        max_length=code_cell_token_limit,
        truncation=True,
    )['input_ids']
    code_token_ids = prune_code_tokens(code_token_ids, max_code_seq_len)

    # Merge the tokenized cells: markdown cells first, then code cells.
    cell_token_ids = markdown_token_ids + code_token_ids
    if has_labels:
        cell_pct_ranks = markdown_cell_pct_ranks + code_cell_pct_ranks
    else:
        cell_pct_ranks = [-1] * len(cell_token_ids)

    input_ids, markdown_token_mask, code_token_mask = [], [], []
    token_weights, token_labels, token_cell_indices = [], [], []
    for cell_idx, token_ids in enumerate(cell_token_ids):
        n_tokens = len(token_ids)
        is_markdown = 1 if cell_idx < markdown_cell_count else 0
        markdown_token_mask += [is_markdown] * n_tokens
        code_token_mask += [1 - is_markdown] * n_tokens
        input_ids += token_ids
        token_cell_indices += [cell_idx] * n_tokens
        token_labels += [cell_pct_ranks[cell_idx]] * n_tokens
        # Each cell contributes a total weight of 1 regardless of its length.
        token_weights += [1 / n_tokens] * n_tokens

    # Pad the features to match max_seq_len (attention mask built before
    # input_ids is padded, so it marks the real tokens).
    num_pad_tokens = max_seq_len - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * num_pad_tokens
    input_ids += [0] * num_pad_tokens
    token_labels += [0] * num_pad_tokens
    token_weights += [0] * num_pad_tokens
    token_cell_indices += [-1] * num_pad_tokens
    markdown_token_mask += [0] * num_pad_tokens
    code_token_mask += [0] * num_pad_tokens

    # Internal sanity checks: the pruning above should keep us within budget.
    assert len(input_ids) == max_seq_len
    assert len(token_labels) == max_seq_len

    notebook_features = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'markdown_token_mask': markdown_token_mask,
        'code_token_mask': code_token_mask,
        'token_cell_indices': token_cell_indices,
        'notebook_id': notebook_dict['notebook_id'],
    }
    if has_labels:
        notebook_features['token_labels'] = token_labels
        notebook_features['token_weights'] = token_weights
    return notebook_features


def build_hf_dataset(
    df,
    tokenizer,
    max_seq_len,
    max_markdown_seq_len,
    max_tokens_per_markdown_cell,
    max_tokens_per_code_cell,
):
    '''Builds the huggingface dataset for training the model.

    Every row of ``df`` is converted with ``convert_to_features_bigbird``;
    the original columns are dropped and the result is formatted as numpy.

    Returns:
        The processed ``datasets.Dataset``.
    '''
    convert_to_features = partial(
        convert_to_features_bigbird,
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        max_markdown_seq_len=max_markdown_seq_len,
        max_tokens_per_markdown_cell=max_tokens_per_markdown_cell,
        max_tokens_per_code_cell=max_tokens_per_code_cell,
    )
    raw_dataset = datasets.Dataset.from_pandas(df)
    processed_dataset = raw_dataset.map(
        convert_to_features,
        remove_columns=raw_dataset.column_names,
        desc='Running tokenizer on raw dataset'
    )
    processed_dataset.set_format(type='numpy')

    if len(processed_dataset) == 0:
        # Nothing to summarize; indexing [:, -1] below would fail.
        print('Built an empty dataset.')
        return processed_dataset

    # An example whose last position is padding is shorter than max_seq_len.
    attention_mask = np.asarray(processed_dataset['attention_mask'])
    padded_ratio = (attention_mask[:, -1] == 0).mean()
    print('Padded sequences ratio:', padded_ratio)
    return processed_dataset

if __name__ == '__main__':
    # Command-line entry point: tokenize notebooks_df.csv into train/valid datasets.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--tokenizer_name', default='google/bigbird-roberta-large', type=str, help='The tokenizer name')
    arg_parser.add_argument('--max_seq_length', default=1024, type=int, help='The max sequence length')
    arg_parser.add_argument('--max_markdown_seq_length', default=512, type=int, help='The max markdown sequence length')
    arg_parser.add_argument('--max_tokens_per_markdown_cell', default=128, type=int, help='The max tokens per markdown cell')
    arg_parser.add_argument('--max_tokens_per_code_cell', default=128, type=int, help='The max tokens per code cell')
    arg_parser.add_argument('--notebooks_df_path', default='notebooks_df.csv', type=str, help='Path to notebooks.csv')
    arg_parser.add_argument('--valid_fold', default=0, type=int, help='Validation Fold')
    cli_args = arg_parser.parse_args()

    hf_tokenizer = transformers.AutoTokenizer.from_pretrained(cli_args.tokenizer_name)
    notebooks_df = pd.read_csv(cli_args.notebooks_df_path)
    print('Total number of notebooks:', len(notebooks_df))

    # One fold is held out for validation; all remaining folds are training data.
    train_df = notebooks_df[notebooks_df.notebook_fold != cli_args.valid_fold]
    valid_df = notebooks_df[notebooks_df.notebook_fold == cli_args.valid_fold]

    # Both splits are tokenized with identical settings.
    tokenization_kwargs = dict(
        tokenizer=hf_tokenizer,
        max_seq_len=cli_args.max_seq_length,
        max_markdown_seq_len=cli_args.max_markdown_seq_length,
        max_tokens_per_markdown_cell=cli_args.max_tokens_per_markdown_cell,
        max_tokens_per_code_cell=cli_args.max_tokens_per_code_cell,
    )
    train_dataset = build_hf_dataset(df=train_df, **tokenization_kwargs)
    valid_dataset = build_hf_dataset(df=valid_df, **tokenization_kwargs)
    print('Train dataset size:', len(train_dataset))
    print('Valid dataset size:', len(valid_dataset))

    # Persist the tokenized datasets and the dataframes they came from.
    train_dataset.save_to_disk('train_dataset')
    valid_dataset.save_to_disk('valid_dataset')
    train_df.to_csv('train_df.csv', index=False)
    valid_df.to_csv('valid_df.csv', index=False)
    print('Done!')

In [None]:
import prepare_hf_dataset

if HP.processed_dataset_folder is None:
    # Debug mode: tokenize on the fly and reuse the same dataset for train and valid.
    train_hf_dataset = valid_hf_dataset = prepare_hf_dataset.build_hf_dataset(
        df=train_df,
        tokenizer=tokenizer,
        max_seq_len=HP.max_seq_length,
        max_markdown_seq_len=HP.max_markdown_seq_len,
        max_tokens_per_markdown_cell=HP.max_tokens_per_markdown_cell,
        max_tokens_per_code_cell=HP.max_tokens_per_code_cell,
    )
else:
    print(f'Loading tokenized datasets from {processed_dataset_path}')
    valid_hf_dataset = datasets.load_from_disk(processed_dataset_path / f'fold{HP.valid_fold}_dataset')
    # NOTE(review): assumes the cached data has exactly folds 0-7.
    train_hf_dataset = datasets.concatenate_datasets([
        datasets.load_from_disk(processed_dataset_path / f'fold{fold_idx}_dataset')
        for fold_idx in tqdm(range(8), desc='Loading training dataset')
        if fold_idx != HP.valid_fold
    ])
    # Round-trip through disk so training reads a single saved copy of the
    # concatenated folds.
    train_hf_dataset.save_to_disk('train_hf_dataset')
    train_hf_dataset = datasets.load_from_disk('train_hf_dataset')

### 🌱 Flax Dataloader
---

#### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#model'> 🧠 Model </a> | <a href='#training-loop'> ⚡ Training Loop </a> 


In [None]:
def collate_fn(hf_batch):
    """Convert a batch dict (as yielded by ``get_dataloader``) into sharded JAX arrays.

    Only the model inputs and training targets below are kept; any other
    columns (e.g. ``token_cell_indices``, ``notebook_id``) are dropped. Each
    array is then passed through flax's ``shard``, which per its docs splits
    the leading batch axis across local devices for ``pmap``.
    """
    int_keys = ('input_ids', 'attention_mask', 'markdown_token_mask', 'code_token_mask')
    float_keys = ('token_labels', 'token_weights')
    jax_batch = {key: jnp.array(hf_batch[key], dtype=jnp.int32) for key in int_keys}
    jax_batch.update({key: jnp.array(hf_batch[key], dtype=jnp.float32) for key in float_keys})
    # `jax.tree_map` is deprecated and removed in newer JAX releases (this
    # notebook upgrades jax); `jax.tree_util.tree_map` works on all versions.
    return jax.tree_util.tree_map(shard, jax_batch)

def get_dataloader(hf_dataset, batch_size, shuffle_seed=None):
    """Lazily yield consecutive full batches of ``hf_dataset`` as plain dicts.

    When ``shuffle_seed`` is given the dataset is shuffled with it first. A
    trailing partial batch is dropped, so every yielded batch has exactly
    ``batch_size`` rows. Being a generator, nothing (including the summary
    print) runs until iteration starts.
    """
    dataset = hf_dataset if shuffle_seed is None else hf_dataset.shuffle(seed=shuffle_seed)
    num_examples = len(dataset)
    total_batches = num_examples // batch_size
    print(f'{total_batches} batches from {num_examples} examples for dataloader')
    for start in range(0, total_batches * batch_size, batch_size):
        yield dict(dataset[start:start + batch_size])

# 🐥 BigBird Model 🐥
---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#training-loop'> ⚡ Training Loop </a> 

<a name='model'>


In [None]:
# Token-level regression head: a single output per token (num_labels=1),
# trained against each token's cell percentile rank (`token_labels`).
# NOTE(review): the HP overrides below (dropout, gradient checkpointing,
# random blocks, block size, bfloat16 dtype) are commented out, so the values
# in the `%%hyperparameters` cell are presumably ignored in favour of the
# pretrained config defaults — confirm this is intentional.
model = transformers.FlaxBigBirdForTokenClassification.from_pretrained(
    HP.backbone_name,
    num_labels=1,
    #attention_probs_dropout_prob=HP.attention_probs_dropout_prob,
    #hidden_dropout_prob=HP.hidden_dropout_prob,
    #gradient_checkpointing=HP.gradient_checkpointing,
    #num_random_blocks=HP.num_random_blocks,
    #block_size=HP.block_size,
    #dtype=jax.numpy.bfloat16,
)

## 🏭 Optimizer Factory
---
#### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#model'> 🧠 Model </a> | <a href='#training-loop'> ⚡ Training Loop </a> 

<a name='optimizer-factory'>

In [None]:
%%writefile optimizer_factory.py

import optax
import flax

def build_lr_scheduler(peak_lr, min_lr, warmup_ratio, total_train_steps):
    """Linear warmup from ``min_lr`` to ``peak_lr``, then cosine decay back to ``min_lr``.

    Args:
        peak_lr: learning rate reached at the end of warmup.
        min_lr: learning rate at step 0 and at the end of training.
        warmup_ratio: fraction of ``total_train_steps`` spent warming up.
        total_train_steps: total number of optimizer steps in training.

    Returns:
        An optax schedule mapping the step count to a learning rate.
    """
    warmup_steps = int(warmup_ratio * total_train_steps)
    decay_steps = total_train_steps - warmup_steps
    print(f'Warmup Steps: {warmup_steps} | Decay Steps: {decay_steps}')

    # optax's `decay_steps` is the length of the WHOLE schedule, warmup
    # included. Passing only the post-warmup length made the cosine reach
    # `min_lr` `warmup_steps` steps before training actually ended.
    lr_scheduler = optax.warmup_cosine_decay_schedule(
        init_value=min_lr,
        peak_value=peak_lr,
        warmup_steps=warmup_steps,
        decay_steps=total_train_steps,
        end_value=min_lr,
    )
    return lr_scheduler

def build_tx(
    lr_scheduler,
    adam_beta_2,
    adam_epsilon,
    weight_decay,
    max_grad_norm,
    ema_decay,
):
    """Build the AdamW optimizer, optionally with gradient clipping and an EMA stage.

    Args:
        lr_scheduler: learning-rate schedule (or constant) for AdamW.
        adam_beta_2: AdamW second-moment decay (``b1`` is fixed at 0.9).
        adam_epsilon: AdamW epsilon.
        weight_decay: decoupled weight decay, applied to every parameter
            except biases and LayerNorm scales.
        max_grad_norm: if not None, clip gradients to this global norm before
            AdamW sees them.
        ema_decay: if not None, append ``optax.ema`` with this decay.

    Returns:
        An optax ``GradientTransformation``.
    """
    def weight_decay_mask(params):
        # The decision must be made from the parameter *path* (the flattened
        # dict key), not from the parameter array stored under it.
        flat_params = flax.traverse_util.flatten_dict(params)
        mask = {
            path: (path[-1] != 'bias' and path[-2:] != ('LayerNorm', 'scale'))
            for path in flat_params
        }
        return flax.traverse_util.unflatten_dict(mask)

    adamw = optax.adamw(
        learning_rate=lr_scheduler,
        b1=0.9,
        b2=adam_beta_2,
        eps=adam_epsilon,
        weight_decay=weight_decay,
        mask=weight_decay_mask,
    )

    transforms = []
    if max_grad_norm is not None:
        # Clip the raw gradients before AdamW normalizes them; chaining the
        # clip after adamw would only rescale the final updates.
        transforms.append(optax.clip_by_global_norm(max_grad_norm))
    transforms.append(adamw)
    if ema_decay is not None:
        # NOTE(review): optax.ema here averages the *updates* produced by
        # AdamW (acting like extra momentum); it does not keep an EMA copy of
        # the weights. Confirm this is the intended behavior.
        transforms.append(optax.ema(decay=ema_decay))
    return adamw if len(transforms) == 1 else optax.chain(*transforms)

In [None]:
import optimizer_factory

# Global batch sizes: the per-device size times jax.device_count().
train_batch_size = jax.device_count() * HP.per_device_train_batch_size
eval_batch_size = jax.device_count() * HP.per_device_eval_batch_size
# One optimizer step per full global batch (get_dataloader drops the remainder).
total_train_steps = (len(train_hf_dataset) // train_batch_size) * HP.num_train_epochs

lr_scheduler = optimizer_factory.build_lr_scheduler(
    total_train_steps=total_train_steps,
    warmup_ratio=HP.warmup_ratio,
    peak_lr=HP.peak_lr,
    min_lr=HP.min_lr,
)

tx = optimizer_factory.build_tx(
    lr_scheduler=lr_scheduler,
    max_grad_norm=HP.max_grad_norm,
    weight_decay=HP.weight_decay,
    adam_beta_2=HP.beta_2,
    adam_epsilon=HP.epsilon,
    ema_decay=HP.ema_decay,
)

## ⚡ Training Loop
---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#model'> 🐥 BigBird Model </a> 

<a name='training-loop'>

In [None]:
from typing import Callable

def compute_loss_per_cell_type(token_labels, token_preds, token_weights, token_mask, loss_fn_name=None):
    """Weighted regression loss restricted to the tokens selected by ``token_mask``.

    Args:
        token_labels, token_preds, token_weights, token_mask: arrays of the
            same shape whose last axis is the token axis (presumably
            ``(batch, seq_len)`` after the squeeze in ``train_step``).
        loss_fn_name: 'mse', 'rmse' or 'mae'. Defaults to ``HP.loss_fn_name``.

    Returns:
        Scalar: the sum over notebooks of each notebook's weighted mean error
        (squared error for 'mse'/'rmse', absolute error for 'mae'); for
        'rmse' the square root of that sum.

    Raises:
        ValueError: if ``loss_fn_name`` is not one of the supported names.
    """
    if loss_fn_name is None:
        loss_fn_name = HP.loss_fn_name
    if loss_fn_name not in ('mse', 'rmse', 'mae'):
        raise ValueError(f'Unknown loss_fn_name: {loss_fn_name!r}')

    token_labels, token_preds, token_weights = token_labels*token_mask, token_preds*token_mask, token_weights*token_mask

    diff = jnp.abs(token_labels - token_preds)
    notebook_token_weights_sum = jnp.sum(token_weights, axis=-1)
    # A notebook with no tokens of this cell type has a zero weight sum; its
    # weighted error is also 0, so divide by 1 instead of producing 0/0 = NaN
    # (which would also poison the gradients).
    safe_weights_sum = jnp.where(notebook_token_weights_sum > 0, notebook_token_weights_sum, 1.0)

    errors = diff if loss_fn_name == 'mae' else diff**2
    notebook_losses = jnp.sum(errors * token_weights, axis=-1)
    loss = jnp.sum(notebook_losses / safe_weights_sum)
    if loss_fn_name == 'rmse':
        loss = jnp.sqrt(loss)
    return loss

def loss_fn(token_labels, token_preds, token_weights, markdown_token_mask, code_token_mask):
    """Return ``(markdown_cell_loss, code_cell_loss)`` for one batch.

    Both losses are computed by ``compute_loss_per_cell_type`` on the same
    labels/predictions/weights, differing only in which token mask is used.
    """
    shared_inputs = dict(
        token_labels=token_labels,
        token_preds=token_preds,
        token_weights=token_weights,
    )
    markdown_cell_loss = compute_loss_per_cell_type(**shared_inputs, token_mask=markdown_token_mask)
    code_cell_loss = compute_loss_per_cell_type(**shared_inputs, token_mask=code_token_mask)
    return markdown_cell_loss, code_cell_loss

class TrainState(flax.training.train_state.TrainState):
    """Flax ``TrainState`` extended with two callables stored on the state.

    Both fields use ``pytree_node=False``, so they are treated as static
    metadata rather than pytree leaves.
    """
    # Computes the per-cell-type losses (create_state passes the module-level loss_fn).
    loss_fn: Callable = flax.struct.field(pytree_node=False)
    # Post-processing for model outputs (create_state passes the identity function).
    logits_fn: Callable = flax.struct.field(pytree_node=False)

def create_state(model, tx):
    """Build a ``TrainState`` from the model's params and replicate it across devices.

    The state uses ``model.__call__`` as its apply function, ``tx`` as the
    optimizer, the module-level ``loss_fn`` and an identity ``logits_fn``.
    The result is passed through ``flax.jax_utils.replicate`` so it can be
    fed to the pmapped ``train_step``.
    """
    single_device_state = TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=tx,
        loss_fn=loss_fn,
        logits_fn=lambda x: x,
    )
    return flax.jax_utils.replicate(single_device_state)

@partial(jax.pmap, axis_name='batch')
def train_step(state, drp_rng, **model_inputs):
    """One data-parallel optimization step, run under ``pmap`` on every device.

    Args:
        state: replicated ``TrainState``.
        drp_rng: this device's PRNG key for dropout.
        **model_inputs: this device's shard of a ``collate_fn`` batch
            (``input_ids``, ``attention_mask``, ``token_labels``,
            ``token_weights``, ``markdown_token_mask``, ``code_token_mask``).

    Returns:
        ``(new_state, metrics, next_drp_rng)``; ``metrics`` holds the total
        loss, both per-cell-type losses and the learning rate, averaged
        across devices.
    """
    # Split before use: the key consumed by dropout must not also be the key
    # the next step's key is derived from (JAX keys must not be reused).
    dropout_rng, new_drop_rng = jax.random.split(drp_rng)

    def compute_loss(params):
        outputs = state.apply_fn(
            input_ids=model_inputs['input_ids'],
            attention_mask=model_inputs['attention_mask'],
            params=params,
            dropout_rng=dropout_rng,
            train=True,
        )
        # Drop size-1 axes (e.g. the single regression output per token).
        token_preds = jnp.squeeze(outputs.logits)
        token_labels = jnp.squeeze(model_inputs['token_labels'])
        token_weights = jnp.squeeze(model_inputs['token_weights'])
        markdown_token_mask = jnp.squeeze(model_inputs['markdown_token_mask'])
        code_token_mask = jnp.squeeze(model_inputs['code_token_mask'])

        markdown_cell_loss, code_cell_loss = state.loss_fn(
            token_labels=token_labels,
            token_preds=token_preds,
            token_weights=token_weights,
            markdown_token_mask=markdown_token_mask,
            code_token_mask=code_token_mask,
        )

        # Convex combination of the two losses.
        markdown_cell_loss_weight = HP.markdown_cell_loss_weight
        code_cell_loss_weight = 1 - HP.markdown_cell_loss_weight
        loss = markdown_cell_loss_weight * markdown_cell_loss + code_cell_loss_weight * code_cell_loss
        return loss, (markdown_cell_loss, code_cell_loss)

    grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
    (loss, (markdown_cell_loss, code_cell_loss)), grads = grad_fn(state.params)

    metrics = {
        'loss': loss,
        'markdown_cell_loss': markdown_cell_loss,
        'code_cell_loss': code_cell_loss,
        'learning_rate': lr_scheduler(state.step),
    }
    metrics = jax.lax.pmean(metrics, axis_name='batch')
    grads = jax.lax.pmean(grads, axis_name='batch')

    state = state.apply_gradients(grads=grads)
    return state, metrics, new_drop_rng


# Show which devices JAX will train on.
print('---------------- Available devices ----------------')
print(jax.devices())
print('---------------------------------------------------')

# Full global batches per epoch (get_dataloader drops any trailing partial batch).
train_steps_per_epoch = len(train_hf_dataset)//train_batch_size
total_train_steps = train_steps_per_epoch * HP.num_train_epochs
# One dropout key per device; train_step consumes it and returns the next one.
# NOTE(review): uses the global jax.device_count(); on multi-host setups pmap
# expects jax.local_device_count() keys per host — confirm.
rng = jax.random.PRNGKey(HP.random_state)
drp_rng = jax.random.split(rng, jax.device_count())
# Replicated TrainState for the pmapped train_step.
state = create_state(model, tx)

for epoch in range(HP.num_train_epochs):
    # Reshuffle each epoch with a different, deterministic seed (the epoch index).
    train_dataloader = get_dataloader(train_hf_dataset, train_batch_size, shuffle_seed=epoch)
    # Per-metric sums since the last log line.
    running_metrics = defaultdict(int)
    steps_progress_bar = tqdm(enumerate(train_dataloader), total=train_steps_per_epoch, desc=f'Epoch #{epoch+1}/{HP.num_train_epochs}')
    
    for step, hf_batch in steps_progress_bar:
        # Convert to JAX arrays sharded across devices, then take one pmapped step.
        batch = collate_fn(hf_batch)
        state, step_metrics, drp_rng = train_step(
            state=state,
            drp_rng=drp_rng,
            **batch
        )
        
        # Update progress bar and running metrics
        # (metrics were pmean-ed inside train_step, so device 0's copy is representative).
        step_metrics = flax.jax_utils.unreplicate(step_metrics)
        steps_progress_bar.set_postfix(**step_metrics)
        for k, v in step_metrics.items():
            running_metrics[k] += step_metrics[k]

        # Log metrics every `logging_freq` steps
        # (each value printed is the average over those steps).
        if (step + 1) % HP.logging_freq == 0:
            global_step = flax.jax_utils.unreplicate(state.step)
            print('-'*50)
            print(f"Step {global_step-HP.logging_freq}-{global_step} out of {total_train_steps}")
            for k, v in running_metrics.items():
                print(colored(k, 'blue'), ':', colored(v / HP.logging_freq, 'red'))
            running_metrics = defaultdict(int)
            print()

    print(colored(f'Epoch #{epoch+1} completed.'))
    print(colored('-'*100, 'red'))
    print('\n\n')

## 🎯 Inference
---

<a name='inference'>

In [None]:
# CELL_SEP = '[CELL_SEP]'

# @partial(jax.pmap, axis_name='batch')
# def predict_step(state, batch):
#     input_ids = jnp.squeeze(batch['input_ids'])
#     attention_mask = jnp.squeeze(batch['attention_mask'])
    
#     outputs = state.apply_fn(
#         input_ids=input_ids,
#         attention_mask=attention_mask,
#         params=state.params,
#         train=False,
#     )
    
#     token_preds = jnp.squeeze(outputs.logits)
#     token_labels = jnp.squeeze(batch['token_labels'])
#     token_weights = jnp.squeeze(batch['token_weights'])
#     markdown_token_mask = jnp.squeeze(batch['markdown_token_mask'])
#     code_token_mask = jnp.squeeze(batch['code_token_mask'])
#     markdown_cell_loss, code_cell_loss = state.loss_fn(
#         token_labels=token_labels,
#         token_preds=token_preds,
#         token_weights=token_weights,
#         markdown_token_mask=markdown_token_mask,
#         code_token_mask=code_token_mask,
#     )
    
#     metrics = {
#         'markdown_cell_loss': markdown_cell_loss, 
#         'code_cell_loss': code_cell_loss
#     }
#     metrics = jax.lax.pmean(metrics, axis_name="batch")
#     return metrics

# valid_dataloader = get_dataloader(valid_hf_dataset, eval_batch_size, shuffle_seed=epoch)
# agg_metrics = defaultdict(list)
# for hf_batch in tqdm(valid_dataloader, total=len(valid_hf_dataset)//eval_batch_size):
#     batch = collate_fn(hf_batch)
#     step_metrics = predict_step(state, batch)
#     for k, v in step_metrics.items():
#         agg_metrics[k].append(v)
# for k, v in agg_metrics.items():
#     v = np.mean(np.array(v))
#     print(colored(k, 'red'), ':', colored(v, 'blue'))

In [None]:
# CELL_SEP = '[CELL_SEP]'

# @partial(jax.pmap, axis_name='batch')
# def predict_step(state, batch):
#     # NOTE(review): this commented-out predict_step lost its body; only a stray
#     # `2` expression statement remained here. A plausible body, mirroring the
#     # predict_step in the previous (validation-loss) cell, is:
#     input_ids = jnp.squeeze(batch['input_ids'])
#     attention_mask = jnp.squeeze(batch['attention_mask'])
#     outputs = state.apply_fn(
#         input_ids=input_ids,
#         attention_mask=attention_mask,
#         params=state.params,
#         train=False,
#     )
#     token_preds = jnp.squeeze(outputs.logits)
#     return state.logits_fn(token_preds)
    
#     #return jnp.asarray(token_preds*10, dtype=jnp.int8)

# def compute_kendall_tau(valid_df, cell_id_to_pred_rank):
#     all_notebook_cell_pct_ranks = valid_df.merged_cell_pct_ranks.values
#     all_notebook_cell_ids = valid_df.merged_cell_ids.values
#     all_notebook_kendall_taus = []
#     for notebook_idx in range(len(valid_df)):
#         true_cell_ranks = [float(rank) for rank in all_notebook_cell_pct_ranks[notebook_idx].split(CELL_SEP)]
#         cell_ids = all_notebook_cell_ids[notebook_idx].split(CELL_SEP)
#         pred_cell_ranks = [cell_id_to_pred_rank.get(cell_id, -1) for cell_id in cell_ids]

#         notebook_kendall_tau = scipy.stats.kendalltau(true_cell_ranks, pred_cell_ranks, method='asymptotic')[0]
#         all_notebook_kendall_taus.append(notebook_kendall_tau)

#     valid_df['kendall_tau'] = all_notebook_kendall_taus
#     print('Average Kendall Tau:', colored(np.mean(all_notebook_kendall_taus), 'red'))

#     for cell_cutoff in [4, 16, 64]:
#         tau = np.mean(valid_df[valid_df.markdown_cell_count >= cell_cutoff].kendall_tau)
#         print(f"Kendall Tau for notebooks with {cell_cutoff}+ markdown cells:", colored(tau, 'red'))

# cell_id_to_token_preds = defaultdict(list)
# valid_dataloader = get_dataloader(valid_hf_dataset, eval_batch_size, shuffle_seed=epoch)
# for hf_batch in tqdm(valid_dataloader, total=len(valid_hf_dataset)//eval_batch_size):
#     batch = collate_fn(hf_batch)
#     token_preds = predict_step(state, batch)
#     token_preds = np.array([pred for pred in itertools.chain(*token_preds)])
    
#     token_cell_indices = hf_batch['token_cell_indices']
#     notebook_ids = hf_batch['notebook_id']
#     for example_idx, notebook_id in enumerate(notebook_ids):
#         cell_ids = valid_df[valid_df.notebook_id==notebook_id].iloc[0].merged_cell_ids.split(CELL_SEP)
#         for cell_idx, token_pred in zip(token_cell_indices[example_idx], token_preds[example_idx]):
#             cell_id = cell_ids[cell_idx]
#             cell_id_to_token_preds[cell_id].append(token_pred)
# cell_id_to_pred_rank = {cell_id: np.mean(preds) for cell_id, preds in cell_id_to_token_preds.items()}
# compute_kendall_tau(valid_df, cell_id_to_pred_rank)
# print('\n\n')

In [None]:
# avg_tau = valid_df.kendall_tau.mean()
# tau_16 = np.mean(valid_df[valid_df.markdown_cell_count >= 16].kendall_tau)
# tau_64 = np.mean(valid_df[valid_df.markdown_cell_count >= 64].kendall_tau)

# WEIGHTS_SAVE_FORMAT = "{backbone_code}-tau{avg_tau}-tau16_{tau_16}-tau64_{tau_64}.msgpack"
# weights_file = WEIGHTS_SAVE_FORMAT.format(
#     backbone_code=backbone_code,
#     avg_tau=int(avg_tau*1e6),
#     tau_16=int(tau_16*1e4),
#     tau_64=int(tau_64*1e4),
# )
# print(f'Saving weights at {weights_file}')
# if jax.process_index() == 0:
#     params = jax.device_get(flax.jax_utils.unreplicate(state.params))
    
#     with open(f'/kaggle/working/{weights_file}', 'wb') as f:
#         model_bytes = flax.serialization.to_bytes(params)
#         f.write(model_bytes)

In [None]:
weights_file = 'bigbird-seq1536_md1024-loss146.msgpack'
print(weights_file)
# Only process 0 writes the checkpoint.
if jax.process_index() == 0:
    # Take one replica of the parameters and copy it to host memory.
    params = jax.device_get(flax.jax_utils.unreplicate(state.params))
    with open(f'/kaggle/working/{weights_file}', 'wb') as f:
        f.write(flax.serialization.to_bytes(params))

In [None]:
# Upload the saved checkpoint to Weights & Biases. `wandb.finish()` waits for
# the run's files to finish syncing, which replaces the previous fixed
# `sleep(180)`; that guess could be too short for a large checkpoint and
# wasted time for a small one.
wandb.init(project='bigbird_dev')
wandb.save(weights_file)
wandb.finish()