# __🧑‍💻 AI4Code TFRankEncoder Train (CV: 0.85+)__

---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a> | <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#training'> ⚡ Training </a> 

In [None]:
# Sync Notebook with VS Code #
# Installs pinned transformers (4.10.0) plus datasets, makes the local `ai4code`
# checkout importable, then runs the project setup scripts.
# NOTE(review): many names used later without a visible import/definition (HP,
# tf_accelerator, TFAutoModel, AutoTokenizer, load_tf_model_weights, pd, Path,
# datasets, tfa, transformers, the %%hyperparameters magic, ...) are presumably
# provided by these %run scripts — verify.
!pip install -q transformers==4.10.0 datasets
import sys; sys.path.append('ai4code')

%run /kaggle/working/ai4code/ai4c/jupyter_setup.py
%run /kaggle/working/ai4code/ai4c/tensorflow_setup.py

## ⚙️ Hyperparameters ⚙️
---
### <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#model-factory'> 🧠 Model </a>|  <a href='#training'> ⚡ Training </a> 

<a name='hyperparameters'></a>

In [None]:
%%hyperparameters

## Huggingface Backbone ##
backbone_name: 'facebook/muppet-roberta-large'
backbone_weights: null

attention_probs_dropout_prob: 0.10
hidden_dropout_prob: 0.10
max_seq_len: 512


## Tokenization ##
max_markdown_seq_len: 256
max_tokens_per_markdown_cell: 128
max_tokens_per_code_cell: 128


## Model Training ##
num_train_epochs: 10
train_batch_size: 256
eval_batch_size: 1024


## Loss Function ##
loss_fn: 'mse'
markdown_cell_weight: 0.50


## Cosine Decay LR Scheduler ##
warmup_ratio: 0.10
peak_lr: 3e-5
min_lr: 1e-8


## AdamW Optimizer ##
beta_1: 0.9
beta_2: 0.999
epsilon: 1e-8

weight_decay: 1e-4
max_grad_norm: 1.00
average_decay: 0.999


## Load From Cache: Tokenized Dataset ##
processed_dataset_folder: null # 'ai4code-tokenization-bigbird'
debug_notebooks: 10000


## Data Factory ##
valid_fold: 0

In [None]:
# Create the distribution strategy via the project helper (bfloat16 + XLA jit_compile
# flags) and build the backbone under its scope so the backbone's variables are
# created by that strategy.
STRATEGY = tf_accelerator(bfloat16=True, jit_compile=True)
with STRATEGY.scope():
    # from_pt=True loads the PyTorch checkpoint into the TF model class; the dropout
    # overrides come from the hyperparameter cell.
    backbone = TFAutoModel.from_pretrained(
        HP.backbone_name, 
        attention_probs_dropout_prob=HP.attention_probs_dropout_prob,
        hidden_dropout_prob=HP.hidden_dropout_prob,
        from_pt=True
    )
    # HP.backbone_weights is null in the config; presumably the helper is a no-op
    # for None — verify.
    load_tf_model_weights(backbone, HP.backbone_weights)

tokenizer = AutoTokenizer.from_pretrained(HP.backbone_name)

## ⚒️ Data Factory ⚒️

---
#### <a href='#prepare-huggingface-datasets'> 🤗 Huggingface Datasets </a> | <a href='#prepare-tensorflow-datasets'> Tensorflow Datasets </a> 


<a name='data-factory'></a>

In [None]:
# Load the notebook-level dataframe and split it by fold.
# NOTE(review): the path is built even when HP.processed_dataset_folder is None
# (giving '/kaggle/input/None'); it is only used behind `is not None` checks.
processed_dataset_path = Path(f'/kaggle/input/{HP.processed_dataset_folder}')
if HP.processed_dataset_folder is not None:
    # NOTE(review): the message names processed_dataset_path, but the CSV is read
    # from /kaggle/input/ai4code-dataframes.
    print(f'Loading dataframes from {processed_dataset_path}')
    notebooks_df = pd.read_csv('/kaggle/input/ai4code-dataframes/notebooks_df.csv')
    train_df = notebooks_df[notebooks_df.notebook_fold != HP.valid_fold]
    valid_df = notebooks_df[notebooks_df.notebook_fold == HP.valid_fold]
else:
    # Debug mode: the first HP.debug_notebooks rows are used for BOTH train and
    # valid, so validation scores overlap the training data.
    print(f'Loading {HP.debug_notebooks} notebooks for debugging.')
    train_df = valid_df = notebooks_df = pd.read_csv('/kaggle/input/ai4code-dataframes/notebooks_df.csv', nrows=HP.debug_notebooks)

### 🤗 Prepare Huggingface Datasets
---
#### <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#hyperparameters'> ⚙️ Hyperparameters </a>|  <a href='#training'> ⚡ Training </a> 

<a name='prepare-huggingface-datasets'></a>

In [None]:
%%writefile prepare_hf_dataset.py

from functools import partial
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import argparse
import gc

import transformers
import datasets

tqdm.pandas()
CELL_SEP = '[CELL_SEP]'

def prune_code_tokens(code_token_ids, max_seq_len):
    """
    Trim trailing tokens from the longest cells until the total fits in ``max_seq_len``.

    On each step the (first) longest cell is shortened to one token below the
    runner-up (or by whatever overflow remains, if smaller), so the largest
    cells shrink first and shorter cells are only touched once the long ones
    have been brought down to their size.

    Args:
        code_token_ids: list of per-cell token-id lists.
        max_seq_len: maximum total number of tokens across all cells.

    Returns:
        A new list with one (possibly shortened) token list per input cell;
        tokens are removed from the END of each cell.
    """
    token_counts = [len(token_ids) for token_ids in code_token_ids]
    total_tokens_to_prune = max(sum(token_counts) - max_seq_len, 0)
    tokens_to_prune_per_cell = [0] * len(token_counts)

    while total_tokens_to_prune > 0:
        top_two = sorted(token_counts)[-2:]
        cur_max_count = top_two[-1]
        # A lone cell has no runner-up; treat it as 0 so the overflow is taken
        # from that cell (indexing sorted(...)[-2] raised IndexError here).
        second_max_count = top_two[0] if len(top_two) > 1 else 0
        cell_idx = token_counts.index(cur_max_count)  # first cell with the max count

        num_tokens_to_pop = min(cur_max_count - second_max_count + 1, total_tokens_to_prune)
        tokens_to_prune_per_cell[cell_idx] += num_tokens_to_pop
        token_counts[cell_idx] -= num_tokens_to_pop
        total_tokens_to_prune -= num_tokens_to_pop

    # Slice with an explicit end index: `ids[:-0]` would yield an empty list.
    return [
        cell_token_ids[:len(cell_token_ids) - num_to_pop]
        for cell_token_ids, num_to_pop in zip(code_token_ids, tokens_to_prune_per_cell)
    ]


def convert_to_features_bigbird(
    notebook_dict,
    tokenizer,
    max_seq_len,
    max_markdown_seq_len,
    max_tokens_per_markdown_cell,
    max_tokens_per_code_cell,
):
    '''Tokenize one notebook and pack it into fixed-length token-level features.

    Markdown cells are tokenized first within a ``max_markdown_seq_len`` budget;
    code cells then fill whatever remains of ``max_seq_len``; the result is
    padded to exactly ``max_seq_len`` tokens.

    Args:
        notebook_dict: one notebook row. Reads
            ``merged_{markdown,code}_cell_{sources,pct_ranks,ids}`` (each a
            ``CELL_SEP``-joined string) and ``notebook_id``.
        tokenizer: HF-style tokenizer, called with a list of strings,
            ``max_length`` and ``truncation=True`` and indexed with ``['input_ids']``.
        max_seq_len: total length of the produced sequence.
        max_markdown_seq_len: token budget reserved for markdown cells.
        max_tokens_per_markdown_cell: lower bound on the per-markdown-cell
            truncation length (raised to the fair share of the budget).
        max_tokens_per_code_cell: same, for code cells.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``markdown_token_mask``,
        ``code_token_mask``, ``token_cell_indices`` (index of the owning cell per
        token; markdown cells first, then code; -1 on padding) and
        ``notebook_id``. When ``'merged_cell_pct_ranks'`` is a key of
        ``notebook_dict`` it also has ``token_labels`` (the owning cell's pct rank
        per token) and ``token_weights`` (1/len(cell) per token, so each cell's
        weights sum to 1; 0 on padding).
    '''

    markdown_cell_sources = notebook_dict['merged_markdown_cell_sources'].split(CELL_SEP)
    markdown_cell_pct_ranks = [float(rank) for rank in notebook_dict['merged_markdown_cell_pct_ranks'].split(CELL_SEP)]
    markdown_cell_ids = notebook_dict['merged_markdown_cell_ids'].split(CELL_SEP)

    code_cell_sources = notebook_dict['merged_code_cell_sources'].split(CELL_SEP)
    code_cell_pct_ranks = [float(rank) for rank in notebook_dict['merged_code_cell_pct_ranks'].split(CELL_SEP)]
    code_cell_ids = notebook_dict['merged_code_cell_ids'].split(CELL_SEP)

    # Remove cells from the end of the notebook so that all cells have at least one representative token
    # (each kept cell is budgeted at least 2 tokens: budget // 2 cells at most).
    max_markdown_cells = max_markdown_seq_len//2
    max_code_cells = (max_seq_len-max_markdown_seq_len)//2
    if len(markdown_cell_sources) > max_markdown_cells:
        markdown_cell_sources = markdown_cell_sources[:max_markdown_cells]
        markdown_cell_pct_ranks = markdown_cell_pct_ranks[:max_markdown_cells]
        markdown_cell_ids = markdown_cell_ids[:max_markdown_cells]
    if len(code_cell_sources) > max_code_cells:
        code_cell_sources = code_cell_sources[:max_code_cells]
        code_cell_pct_ranks = code_cell_pct_ranks[:max_code_cells]
        code_cell_ids = code_cell_ids[:max_code_cells]
    
    # str.split always returns at least one element, so these counts are >= 1.
    markdown_cell_count = len(markdown_cell_sources)
    code_cell_count = len(code_cell_sources)

    # Let each markdown cell use at least its fair share of the markdown budget,
    # then prune the longest cells until the markdown total fits.
    max_tokens_per_markdown_cell = max(max_tokens_per_markdown_cell, max_markdown_seq_len//markdown_cell_count)
    markdown_code_token_ids = tokenizer(
        markdown_cell_sources,
        max_length=max_tokens_per_markdown_cell,
        truncation=True,
    )['input_ids']
    markdown_code_token_ids = prune_code_tokens(markdown_code_token_ids, max_markdown_seq_len)
    total_markdown_code_tokens = sum([len(token_ids) for token_ids in markdown_code_token_ids])

    # Code cells get every token the markdown cells did not use.
    max_code_seq_len = max_seq_len - total_markdown_code_tokens
    max_tokens_per_code_cell = max(max_tokens_per_code_cell, max_code_seq_len//code_cell_count)
    code_code_token_ids = tokenizer(
        code_cell_sources, 
        max_length=max_tokens_per_code_cell, 
        truncation=True, 
    )['input_ids']
    code_code_token_ids = prune_code_tokens(code_code_token_ids, max_seq_len-total_markdown_code_tokens)

    # Merge the tokenized cells and create the model features
    # (markdown cells first, then code cells; cell_ids is built but not returned).
    code_token_ids = markdown_code_token_ids + code_code_token_ids
    cell_ids = markdown_cell_ids + code_cell_ids
    notebook_cell_count = len(code_token_ids)

    # Create the model features
    # NOTE(review): the guard checks 'merged_cell_pct_ranks', while the ranks used
    # here come from the markdown/code rank fields parsed unconditionally above —
    # confirm the row schema keeps these consistent.
    if 'merged_cell_pct_ranks' in notebook_dict:
        cell_pct_ranks = markdown_cell_pct_ranks + code_cell_pct_ranks
    else:
        cell_pct_ranks = [-1]*notebook_cell_count
    
    input_ids, markdown_token_mask, code_token_mask = [], [], []
    token_weights, token_labels = [], []
    token_cell_indices = []
    
    # NOTE: the loop variable shadows the merged list above; enumerate() already
    # holds the list, so iteration is unaffected.
    for cur_cell_idx, code_token_ids in enumerate(code_token_ids):
        token_count_for_cell = len(code_token_ids)
        if cur_cell_idx < markdown_cell_count:
            markdown_token_mask += [1]*token_count_for_cell
            code_token_mask += [0]*token_count_for_cell
        else: 
            markdown_token_mask += [0]*token_count_for_cell
            code_token_mask += [1]*token_count_for_cell
        
        input_ids += code_token_ids
        token_cell_indices += [cur_cell_idx] * token_count_for_cell
        token_labels += [cell_pct_ranks[cur_cell_idx]] * token_count_for_cell
        # Spread a total weight of 1 over the cell's tokens.
        token_weights += [1/token_count_for_cell] * token_count_for_cell

    
    # Pad the features to match max_seq_len #
    # NOTE(review): pad id 0 may differ from tokenizer.pad_token_id (e.g. the
    # RoBERTa family uses 1); the padded positions are masked out by
    # attention_mask — verify this is intended.
    num_pad_tokens = max_seq_len-len(input_ids)
    token_labels += [0]*num_pad_tokens
    token_weights += [0]*num_pad_tokens
    token_cell_indices += [-1]*num_pad_tokens
    markdown_token_mask += [0]*num_pad_tokens
    code_token_mask += [0]*num_pad_tokens
    attention_mask = [1]*len(input_ids) + [0]*num_pad_tokens
    input_ids += [0]*num_pad_tokens
    
    # Check for bugs
    assert len(input_ids) == max_seq_len
    assert len(token_labels) == max_seq_len
    
    # Build the feature dict for the input 
    notebook_features = {
        'input_ids': input_ids, 
        'attention_mask': attention_mask,
        'markdown_token_mask': markdown_token_mask,
        'code_token_mask': code_token_mask,
        'token_cell_indices': token_cell_indices,
        'notebook_id': notebook_dict['notebook_id'],
    }
    if 'merged_cell_pct_ranks' in notebook_dict:
        notebook_features['token_labels'] = token_labels
        notebook_features['token_weights'] = token_weights
    return notebook_features


def build_hf_dataset(
    df, 
    tokenizer, 
    max_seq_len, 
    max_markdown_seq_len, 
    max_tokens_per_markdown_cell,
    max_tokens_per_code_cell,
    ):
    '''Builds the huggingface dataset for training the model.

    Each row of ``df`` (one notebook) is converted with
    ``convert_to_features_bigbird``; the original columns are dropped, so the
    result holds only the model features, and it is returned in numpy format.

    Also prints the fraction of notebooks whose sequence ends in padding,
    i.e. that did not fill all ``max_seq_len`` tokens.
    '''
    convert_to_features = partial(
        convert_to_features_bigbird, 
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        max_markdown_seq_len=max_markdown_seq_len,
        max_tokens_per_markdown_cell=max_tokens_per_markdown_cell,
        max_tokens_per_code_cell=max_tokens_per_code_cell,
    )
    raw_dataset = datasets.Dataset.from_pandas(df)
    processed_dataset = raw_dataset.map(
        convert_to_features, 
        remove_columns=raw_dataset.column_names, 
        desc='Running tokenizer on raw dataset'
    )
    processed_dataset.set_format(type='numpy')

    if len(processed_dataset) == 0:
        # Nothing to inspect (the mapped dataset may not even carry an
        # 'attention_mask' column) and the ratio would divide by zero.
        print('Padded sequences ratio: n/a (empty dataset)')
        return processed_dataset

    # A 0 in the last attention-mask slot means the sequence ends in padding,
    # i.e. the notebook is shorter than max_seq_len (it is not "empty").
    attention_mask = np.asarray(processed_dataset['attention_mask'])
    padded_sequences = (attention_mask[:, -1] == 0).sum()
    print('Padded sequences ratio:', padded_sequences/len(processed_dataset))
    return processed_dataset

if __name__ == '__main__':
    # CLI entry point: tokenize notebooks_df into train/valid HF datasets and save
    # them, together with the matching dataframes, in the working directory.
    parser = argparse.ArgumentParser()
    cli_options = [
        ('--tokenizer_name', 'google/bigbird-roberta-large', str, 'The tokenizer name'),
        ('--max_seq_length', 1024, int, 'The max sequence length'),
        ('--max_markdown_seq_length', 512, int, 'The max markdown sequence length'),
        ('--max_tokens_per_markdown_cell', 128, int, 'The max tokens per markdown cell'),
        ('--max_tokens_per_code_cell', 128, int, 'The max tokens per code cell'),
        ('--notebooks_df_path', 'notebooks_df.csv', str, 'Path to notebooks.csv'),
        ('--valid_fold', 0, int, 'Validation Fold'),
    ]
    for flag, default_value, value_type, help_text in cli_options:
        parser.add_argument(flag, default=default_value, type=value_type, help=help_text)
    args = parser.parse_args()

    tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_name)
    notebooks_df = pd.read_csv(args.notebooks_df_path)
    print('Total number of notebooks:', len(notebooks_df))

    # Rows whose fold equals --valid_fold are held out for validation.
    in_valid_fold = notebooks_df.notebook_fold == args.valid_fold
    train_df = notebooks_df[~in_valid_fold]
    valid_df = notebooks_df[in_valid_fold]

    dataset_kwargs = dict(
        tokenizer=tokenizer,
        max_seq_len=args.max_seq_length,
        max_markdown_seq_len=args.max_markdown_seq_length,
        max_tokens_per_markdown_cell=args.max_tokens_per_markdown_cell,
        max_tokens_per_code_cell=args.max_tokens_per_code_cell,
    )
    train_dataset = build_hf_dataset(df=train_df, **dataset_kwargs)
    valid_dataset = build_hf_dataset(df=valid_df, **dataset_kwargs)
    print('Train dataset size:', len(train_dataset))
    print('Valid dataset size:', len(valid_dataset))

    train_dataset.save_to_disk('train_dataset')
    valid_dataset.save_to_disk('valid_dataset')
    train_df.to_csv('train_df.csv', index=False)
    valid_df.to_csv('valid_df.csv', index=False)
    print('Done!')

In [None]:
import prepare_hf_dataset

if HP.processed_dataset_folder is not None:
    # Pre-tokenized folds stored as fold{k}_dataset: hold out HP.valid_fold and
    # concatenate the rest (assumes folds 0-7 exist on disk).
    print(f'Loading tokenized datasets from {processed_dataset_path}')
    valid_hf_dataset = datasets.load_from_disk(processed_dataset_path/f'fold{HP.valid_fold}_dataset')
    train_hf_dataset = datasets.concatenate_datasets([
        datasets.load_from_disk(processed_dataset_path/f'fold{fold}_dataset')
        for fold in tqdm(range(8), desc='Loading training dataset')
        if fold != HP.valid_fold
    ])
    # Round-trip through disk, presumably to materialize the concatenation as a
    # single on-disk dataset — verify.
    train_hf_dataset.save_to_disk('train_hf_dataset')
    train_hf_dataset = datasets.load_from_disk('train_hf_dataset')
else:
    # Debug mode: tokenize the debug dataframe once and reuse it for both splits.
    train_hf_dataset = valid_hf_dataset = prepare_hf_dataset.build_hf_dataset(
        df=train_df, 
        tokenizer=tokenizer, 
        # The hyperparameter is `max_seq_len`; `HP.max_seq_length` is not defined.
        max_seq_len=HP.max_seq_len, 
        max_markdown_seq_len=HP.max_markdown_seq_len,
        max_tokens_per_markdown_cell=HP.max_tokens_per_markdown_cell, 
        max_tokens_per_code_cell=HP.max_tokens_per_code_cell,
    )

### Prepare Tensorflow Datasets
---
#### <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#hyperparameters'> ⚙️ Hyperparameters </a>|  <a href='#training'> ⚡ Training </a> 

<a name='prepare-tensorflow-datasets'></a>

In [None]:
%%writefile prepare_tf_data.py

# This module is written to disk and imported on its own, so it must import its
# own dependencies (the notebook's globals are not visible here).
import numpy as np
import tensorflow as tf


def convert_hf_dataset_to_tfds(hf_dataset):
    """Convert a tokenized HF dataset into an unbatched ``(inputs, outputs)`` tf.data.Dataset.

    Note: switches ``hf_dataset`` to numpy format in place.

    Returns:
        Dataset whose elements are ``(inputs, outputs)``: ``inputs`` holds int32
        ``input_ids``, ``attention_mask``, ``markdown_token_mask`` and
        ``code_token_mask``; ``outputs`` holds float32 ``token_labels``.
    """
    hf_dataset.set_format(type='numpy')

    model_inputs = {
        'input_ids': hf_dataset['input_ids'].astype(np.int32),
        'attention_mask': hf_dataset['attention_mask'].astype(np.int32),
        'markdown_token_mask': hf_dataset['markdown_token_mask'].astype(np.int32),
        'code_token_mask': hf_dataset['code_token_mask'].astype(np.int32),
    }
    input_ds = tf.data.Dataset.from_tensor_slices(model_inputs)

    model_outputs = {
        'token_labels': hf_dataset['token_labels'].astype(np.float32),
    }
    output_ds = tf.data.Dataset.from_tensor_slices(model_outputs)

    return tf.data.Dataset.zip((input_ds, output_ds))


def hf_dataset_to_tfds(hf_dataset, dataset_type, batch_size):
    """Build a batched, prefetched training or validation pipeline.

    Args:
        hf_dataset: tokenized HF dataset (see ``convert_hf_dataset_to_tfds``).
        dataset_type: ``'train'`` (shuffled and repeated forever) or
            ``'valid'`` (cached, single pass); other values get neither.
        batch_size: batch size.

    Returns:
        ``(dataset, steps)`` where ``steps`` is the number of batches needed to
        cover ``hf_dataset`` exactly once.
    """
    ds = convert_hf_dataset_to_tfds(hf_dataset)
    if dataset_type == 'train':
        # Infinite reshuffled stream; the caller bounds each epoch via steps_per_epoch.
        ds = ds.shuffle(len(hf_dataset), reshuffle_each_iteration=True).repeat()
    elif dataset_type == 'valid':
        ds = ds.cache()
    ds = ds.batch(batch_size)
    # Ceil division. `n // b + 1` over-counted by one batch whenever n is a
    # multiple of b, asking the finite validation pipeline for a missing batch.
    steps = (len(hf_dataset) + batch_size - 1) // batch_size
    return ds.prefetch(tf.data.AUTOTUNE), steps


def convert_hf_dataset_to_test_ds(hf_dataset, batch_size):
    """Build a batched prediction pipeline over ``(input_ids, attention_mask)``.

    Each element is ``((input_ids, attention_mask), (input_ids, attention_mask))``;
    presumably duplicated so Keras' ``predict`` unpacks it as ``(x, y)`` and
    feeds only the first pair to the model — verify.
    """
    # `.astype` below needs numpy-formatted columns.
    hf_dataset.set_format(type='numpy')
    # The feature column is 'input_ids' ('input_id' raised KeyError).
    input_ids_ds = tf.data.Dataset.from_tensor_slices(hf_dataset['input_ids'].astype(np.int32))
    mask_ds = tf.data.Dataset.from_tensor_slices(hf_dataset['attention_mask'].astype(np.int32))
    input_ds = tf.data.Dataset.zip((input_ids_ds, mask_ds))
    ds = tf.data.Dataset.zip((input_ds, input_ds))
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [None]:
import prepare_tf_data

# Training pipeline: shuffled and repeated inside hf_dataset_to_tfds, so fit()
# needs the returned steps_per_epoch.
train_ds, train_steps_per_epoch = prepare_tf_data.hf_dataset_to_tfds(
    hf_dataset=train_hf_dataset,
    dataset_type='train',
    batch_size=HP.train_batch_size
)
# Validation pipeline: cached single pass of (features, token_labels) pairs.
valid_ds, valid_steps_per_epoch = prepare_tf_data.hf_dataset_to_tfds(
    hf_dataset=valid_hf_dataset,
    dataset_type='valid',
    batch_size=HP.eval_batch_size,
)

# Prediction-only pipeline over the same validation rows (passed to KendallTauCallback).
eval_ds = prepare_tf_data.convert_hf_dataset_to_test_ds(valid_hf_dataset, HP.eval_batch_size)

## 🧠 Model Factory
---
#### <a href='#training'> ⚡ Training </a>

<a name='model-factory'></a>

In [None]:
%%writefile tf_rankencoder_model.py
import tensorflow as tf

class AI4CodeTFRankEncoder(tf.keras.Model):
    """Functional Keras model that regresses a percentile rank for every token.

    The loss is computed separately over markdown tokens and code tokens
    (selected by the ``markdown_token_mask`` / ``code_token_mask`` input
    features) and combined as
    ``w * markdown_loss + (1 - w) * code_loss`` with ``w = markdown_cell_loss_weight``.

    Args:
        model_inputs: Keras inputs ``[input_ids, attention_mask]``.
        model_outputs: per-token prediction tensor built on those inputs.
        markdown_cell_loss_weight: weight ``w`` of the markdown-token loss.
        loss_fn_name: ``'mse'`` or ``'mae'``.

    Raises:
        KeyError: if ``loss_fn_name`` is not a supported name.
    """

    def __init__(
        self, 
        model_inputs, 
        model_outputs,
        markdown_cell_loss_weight=0.50,
        loss_fn_name='mse',
    ):
        super().__init__(inputs=model_inputs, outputs=model_outputs)
        self.metrics_tracker = {
            'total_loss': tf.keras.metrics.Mean(name='loss'),
            'markdown_cell_loss': tf.keras.metrics.Mean(name='markdown_cell_loss'),
            'code_cell_loss': tf.keras.metrics.Mean(name='code_cell_loss'),
            'gradient_norm': tf.keras.metrics.Mean(name='gradient_norm'),
        }
        self.markdown_cell_loss_weight = markdown_cell_loss_weight
        self.code_cell_loss_weight = 1.0 - markdown_cell_loss_weight

        self.loss_fn = {
            'mse': self.mse_loss_fn,
            'mae': self.mae_loss_fn,
        }[loss_fn_name]
    
    def mse_loss_fn(self, y_true, y_pred, token_mask):
        """Mean squared error over tokens where ``token_mask`` is 1 (0 if there are none)."""
        diff = (y_true-y_pred)*token_mask
        # divide_no_nan: a batch with no tokens of this type would otherwise give 0/0 = NaN.
        return tf.math.divide_no_nan(tf.math.reduce_sum(diff**2), tf.math.reduce_sum(token_mask))
    
    def mae_loss_fn(self, y_true, y_pred, token_mask):
        """Mean absolute error over tokens where ``token_mask`` is 1 (0 if there are none)."""
        diff = (y_true-y_pred)*token_mask
        return tf.math.divide_no_nan(tf.math.reduce_sum(tf.math.abs(diff)), tf.math.reduce_sum(token_mask))

    def _rank_losses(self, x, y, training):
        """Run the forward pass; return ``(total_loss, markdown_cell_loss, code_cell_loss)``."""
        model_inputs = (x['input_ids'], x['attention_mask'])
        token_labels_pred = tf.cast(self(model_inputs, training=training), tf.float32)
        token_labels_true = tf.cast(y['token_labels'], tf.float32)
        # Feature names as produced by prepare_hf_dataset / prepare_tf_data.
        markdown_token_mask = tf.cast(x['markdown_token_mask'], tf.float32)
        code_token_mask = tf.cast(x['code_token_mask'], tf.float32)

        markdown_cell_loss = self.loss_fn(token_labels_true, token_labels_pred, markdown_token_mask)
        code_cell_loss = self.loss_fn(token_labels_true, token_labels_pred, code_token_mask)
        loss = (
            self.markdown_cell_loss_weight*markdown_cell_loss
            + self.code_cell_loss_weight*code_cell_loss
        )
        return loss, markdown_cell_loss, code_cell_loss
    
    @tf.function
    def train_step(self, data):
        x, y = data
        with tf.GradientTape() as tape: 
            # training=True so the backbone's dropout is active; without an explicit
            # flag the call may run in inference mode.
            loss, markdown_cell_loss, code_cell_loss = self._rank_losses(x, y, training=True)
        
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        gradient_norm = tf.linalg.global_norm(gradients)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        
        self.metrics_tracker['total_loss'].update_state(loss)
        self.metrics_tracker['code_cell_loss'].update_state(code_cell_loss)
        self.metrics_tracker['markdown_cell_loss'].update_state(markdown_cell_loss)
        self.metrics_tracker['gradient_norm'].update_state(gradient_norm)
        return {m.name: m.result() for m in self.metrics}
    
    def test_step(self, data):
        x, y = data
        loss, markdown_cell_loss, code_cell_loss = self._rank_losses(x, y, training=False)
        
        self.metrics_tracker['total_loss'].update_state(loss)
        self.metrics_tracker['code_cell_loss'].update_state(code_cell_loss)
        self.metrics_tracker['markdown_cell_loss'].update_state(markdown_cell_loss)
        return {m.name: m.result() for m in self.metrics}
    
    @property
    def metrics(self):
        # Exposing the trackers here lets Keras reset them between epochs.
        return list(self.metrics_tracker.values())

def build_tfrankencoder_model(
    backbone, 
    max_seq_len,
    markdown_cell_loss_weight=0.50,
    loss_fn_name='mse',
    ):
    """Put a per-token scalar regression head on ``backbone``.

    Args:
        backbone: transformer called as ``backbone(input_ids=..., attention_mask=...)``
            whose output exposes ``last_hidden_state``.
        max_seq_len: fixed sequence length of both inputs and of the output.
        markdown_cell_loss_weight: forwarded to ``AI4CodeTFRankEncoder``.
        loss_fn_name: forwarded to ``AI4CodeTFRankEncoder``.

    Returns:
        ``AI4CodeTFRankEncoder`` mapping ``[input_ids, attention_mask]`` to one
        prediction per token, shaped ``(batch, max_seq_len)``.
    """
    input_ids = tf.keras.Input(shape=(max_seq_len,), dtype=tf.int32, name='input_ids')
    attention_mask = tf.keras.Input(shape=(max_seq_len,), dtype=tf.float32, name='attention_mask')

    # Dense(1) gives (batch, seq, 1); Reshape drops the trailing axis.
    regression_head = tf.keras.Sequential(
        [tf.keras.layers.Dense(1), tf.keras.layers.Reshape((max_seq_len,))],
        name='token_labels',
    )

    hidden_states = backbone(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    token_predictions = regression_head(hidden_states)

    return AI4CodeTFRankEncoder(
        model_inputs=[input_ids, attention_mask],
        model_outputs=token_predictions,
        markdown_cell_loss_weight=markdown_cell_loss_weight,
        loss_fn_name=loss_fn_name,
    )

In [None]:
def adamw_optimizer_factory(lr_scheduler): 
    """Build the AdamW optimizer from HP, optionally wrapped in an EMA.

    Args:
        lr_scheduler: learning rate or schedule passed as ``learning_rate``.

    Returns:
        ``tfa.optimizers.AdamW``, wrapped in ``tfa.optimizers.MovingAverage``
        when ``HP.average_decay > 0``.
    """
    optimizer = tfa.optimizers.AdamW(
        beta_1=HP.beta_1, 
        beta_2=HP.beta_2, 
        epsilon=HP.epsilon, 
        # The hyperparameter cell declares `weight_decay` (there is no `max_weight_decay`).
        weight_decay=HP.weight_decay, 
        clipnorm=HP.max_grad_norm,
        learning_rate=lr_scheduler,
    )
    if HP.average_decay > 0: 
        print(f"Using EMA with decay {colored(HP.average_decay, 'blue')}")
        optimizer = tfa.optimizers.MovingAverage(
            optimizer, 
            average_decay=HP.average_decay, 
            dynamic_decay=True, 
        )
    return optimizer

In [None]:
import tf_rankencoder_model

total_train_steps = train_steps_per_epoch * HP.num_train_epochs
warmup_steps = int(total_train_steps * HP.warmup_ratio)
decay_steps = total_train_steps - warmup_steps

# Cosine decay from the peak LR. CosineDecay's `alpha` is the final LR as a
# FRACTION of initial_learning_rate, so the absolute floor HP.min_lr is turned
# into a ratio (passing HP.min_lr directly decayed to min_lr * peak_lr).
decay_schedule_fn = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=HP.peak_lr, 
    decay_steps=decay_steps, 
    alpha=HP.min_lr / HP.peak_lr,
)
# transformers.WarmUp scales its `initial_learning_rate` by
# (step / warmup_steps) ** power during warmup, then calls
# decay_schedule_fn(step - warmup_steps). The warmup target must be the peak LR;
# warming up to HP.min_lr (1e-8) made warmup a no-op followed by a jump to the peak.
lr_scheduler = transformers.WarmUp(
    initial_learning_rate=HP.peak_lr,
    decay_schedule_fn=decay_schedule_fn,
    warmup_steps=warmup_steps,
)

with STRATEGY.scope():
    # HP names as declared in the %%hyperparameters cell (`markdown_cell_weight`, `loss_fn`).
    model = tf_rankencoder_model.build_tfrankencoder_model(
        backbone=backbone,
        max_seq_len=HP.max_seq_len,
        markdown_cell_loss_weight=HP.markdown_cell_weight,
        loss_fn_name=HP.loss_fn,
    )

In [None]:
from collections import defaultdict
from termcolor import colored
from tqdm.auto import tqdm
import numpy as np
import scipy.stats
import os

import tensorflow as tf

CELL_SEP = '[CELL_SEP]'

class KendallTauCallback(tf.keras.callbacks.Callback):
    """
    Evaluates Kendall Tau for the notebooks at the end of each epoch.

    Token predictions from ``eval_ds`` are averaged per cell (grouped by
    ``token_cell_indices``), compared with ``eval_df.merged_cell_pct_ranks``
    via ``scipy.stats.kendalltau``, and the epoch's weights are saved under
    /kaggle/tmp with the mean tau in the filename.

    Args:
        eval_df: one row per notebook with ``merged_cell_pct_ranks`` and
            ``markdown_cell_count``; row i is assumed to be the same notebook as
            row i of ``eval_dataset``/``eval_ds`` — verify.
        eval_dataset: tokenized dataset with a ``token_cell_indices`` column.
        eval_ds: tf.data pipeline passed to ``model.predict``.
        strategy: tf.distribute strategy used for prediction.
    """

    def __init__(self, eval_df, eval_dataset, eval_ds, strategy):
        # Initialize the base Callback attributes Keras relies on (e.g. `model`).
        super().__init__()
        self.eval_df = eval_df
        self.eval_dataset = eval_dataset
        self.eval_ds = eval_ds
        self.strategy = strategy

        self.saved_weights = []
        os.makedirs('/kaggle/tmp', exist_ok=True)
    
    def evaluate_model(self, model):
        """Return the mean notebook Kendall tau (NaN taus ignored).

        Side effect: writes per-notebook taus to ``eval_df['kendall_tau']``.
        """
        with self.strategy.scope():
            all_notebook_token_preds = np.array(model.predict(self.eval_ds, verbose=True)).astype(np.float32)
        # HF dataset columns come back as ndarrays/lists, which have no `.values`;
        # np.asarray also accepts a pandas Series.
        all_notebook_token_cell_indices = np.asarray(self.eval_dataset['token_cell_indices'])

        notebook_kendall_taus = []
        for notebook_idx, (token_preds, token_cell_indices) in tqdm(
            enumerate(zip(all_notebook_token_preds, all_notebook_token_cell_indices)), 
            total=len(self.eval_dataset)
        ):
            # Average the token predictions of each cell (padding collects under key -1).
            cell_idx_to_pred = defaultdict(list)
            for cell_idx, token_pred in zip(token_cell_indices, token_preds):
                cell_idx_to_pred[cell_idx].append(token_pred)
            cell_idx_to_pred = {cell_idx: np.mean(preds) for cell_idx, preds in cell_idx_to_pred.items()}

            notebook_cell_pct_ranks = [float(rank) for rank in self.eval_df.iloc[notebook_idx].merged_cell_pct_ranks.split(CELL_SEP)]
            # NOTE(review): cells dropped during feature building (too many cells for
            # the token budget) have no prediction and raise KeyError here — verify.
            notebook_cell_preds = [cell_idx_to_pred[cell_idx] for cell_idx in range(len(notebook_cell_pct_ranks))]          
            notebook_tau = scipy.stats.kendalltau(notebook_cell_pct_ranks, notebook_cell_preds, method='asymptotic')[0]
            notebook_kendall_taus.append(notebook_tau)
    
        self.eval_df['kendall_tau'] = notebook_kendall_taus
        for cell_cutoff in [4, 16, 64]:
            cutoff_df = self.eval_df[self.eval_df.markdown_cell_count > cell_cutoff]
            # pandas' mean skips NaN and yields NaN (without a warning) for an empty selection.
            tau = cutoff_df.kendall_tau.mean()
            print(f"Average Kendall Tau for notebooks with more than {cell_cutoff} markdown cells: {colored(tau, 'red')}")
        
        # kendalltau returns NaN for degenerate notebooks (e.g. constant predictions);
        # one NaN would otherwise poison the whole average.
        avg_kendall_tau = np.nanmean(notebook_kendall_taus)
        print('Average Kendall Tau:', colored(avg_kendall_tau, 'red'))
        return avg_kendall_tau

    def on_epoch_end(self, epoch, logs=None):
        epoch_kendall_tau = self.evaluate_model(self.model)
        # int(nan) raises ValueError, so use a readable tag if no tau was finite.
        tau_tag = int(epoch_kendall_tau*1e6) if np.isfinite(epoch_kendall_tau) else 'nan'
        weights_file = f'/kaggle/tmp/weights_epoch_{epoch}_tau{tau_tag}.h5'
        print('Saving weights at', colored(weights_file, 'blue'))
        self.model.save_weights(weights_file)
        self.saved_weights.append(weights_file)

In [None]:
# The tokenized validation split is `valid_hf_dataset` (`valid_dataset` is never defined).
kendall_tau_callback = KendallTauCallback(
    eval_df=valid_df,
    eval_dataset=valid_hf_dataset,
    eval_ds=eval_ds,
    strategy=STRATEGY,
)

## ⚡ Training 
---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a> | <a href='#model-factory'> 🧠 Model </a>

<a name='training'></a>

In [None]:
HP.multi_steps_per_execution = None
with STRATEGY.scope():
    # Rebuild (replacing the `model` from the scheduler cell) and compile with the
    # AdamW/EMA optimizer. HP names follow the %%hyperparameters cell.
    model = tf_rankencoder_model.build_tfrankencoder_model(
        backbone=backbone,
        max_seq_len=HP.max_seq_len,
        markdown_cell_loss_weight=HP.markdown_cell_weight,
        loss_fn_name=HP.loss_fn,
    )
    optimizer = adamw_optimizer_factory(lr_scheduler)
    model.compile(
        optimizer=optimizer, 
        steps_per_execution=HP.multi_steps_per_execution,
    )

# KendallTauCallback evaluates and checkpoints at the end of every epoch.
history = model.fit(
    train_ds, steps_per_epoch=train_steps_per_epoch, epochs=HP.num_train_epochs,
    validation_data=valid_ds, validation_steps=valid_steps_per_epoch,
    callbacks=[kendall_tau_callback]
)

In [None]:
# NOTE(review): legacy cell from an earlier version of this notebook. build_model,
# get_freeze_compiled_model, train_steps and tau_callback are not defined in any
# visible cell (the current names are `kendall_tau_callback` and
# `train_steps_per_epoch`), so this likely raises NameError unless the setup
# scripts define them — verify before running.
# Apparent intent: one pass of training, a tau evaluation, then saving the weights
# that the next cell reloads.
HP.multi_steps_per_execution = False
with STRATEGY.scope():
    model = build_model(backbone)
    model = get_freeze_compiled_model()

history = model.fit(train_ds, steps_per_epoch=train_steps)
tau_callback.evaluate_model(model)
model.save_weights('inital_model_weights.h5')

In [None]:
# NOTE(review): legacy cell. build_model, get_compiled_model, train_steps,
# valid_steps, tau_callback and HP.num_epochs (the config defines
# num_train_epochs) are not defined in any visible cell, and `gc` is only imported
# inside prepare_hf_dataset.py — verify before running.
# Apparent intent: unfreeze the backbone, reload the weights saved by the previous
# cell ('inital_model_weights.h5'), and train for HP.num_epochs epochs.
HP.multi_steps_per_execution = True
with STRATEGY.scope():
    backbone.trainable = True
    model = build_model(backbone)
    model = get_compiled_model()
model.load_weights('inital_model_weights.h5')

# 653s + 421s/epoch
history = model.fit(
    train_ds, steps_per_epoch=train_steps*4, epochs=HP.num_epochs,
    validation_data=valid_ds, validation_steps=valid_steps,
    callbacks=[tau_callback],
)
tf.keras.backend.clear_session()
gc.collect()

In [None]:
# NOTE(review): legacy cell. KendallTauCallback stores checkpoints in
# `saved_weights` (not `saved_weights_files`), and tau_callback,
# HP.num_swa_models and get_model_average are not defined in any visible cell.
# Also, sorted() orders the 'weights_epoch_{epoch}_tau...' filenames
# lexicographically, so from epoch 10 on the "last N" are not the latest epochs
# — verify before use.
# Apparent intent: average the weights of the last HP.num_swa_models checkpoints.
model_weight_files = sorted(tau_callback.saved_weights_files)[-HP.num_swa_models:]
print('Taking average of model weights:', model_weight_files)
with STRATEGY.scope():
    model = get_model_average(model, model_weight_files)

In [None]:
# NOTE(review): legacy cell. tau_callback, HP.wandb_weights_save_format,
# backbone_code, blue, wandb, train, valid and sleep are not defined in any
# visible cell (the dataframes here are train_df / valid_df), and wandb.init
# below is commented out before wandb.save is called — verify before running.
# wandb.init(project='ai4code_tfrankencoder_fold0')

# Save the model weights #
tau = tau_callback.evaluate_model(model)
weights_file = HP.wandb_weights_save_format.format(fold=HP.valid_fold, backbone_code=backbone_code, tau=int(tau*1000000))
print(blue(weights_file))
model.save_weights(weights_file)
wandb.save(weights_file)

# 👽 Save the train and validation file #
train.to_csv('train.csv', index=False)
valid.to_csv('valid.csv', index=False)

# Wait before the session ends, presumably so the wandb upload can finish — verify.
sleep(180)

## 👽 Paranoid Validation
---

<a name='paranoid-validation'></a>

In [None]:
# NOTE(review): 'df.csv' is not written by any visible cell (prepare_hf_dataset.py
# writes train_df.csv / valid_df.csv). It is expected to be a per-cell table with
# notebook_id and cell_pct_rank columns (both used below) — verify it exists.
cell_df = pd.read_csv('df.csv').set_index('notebook_id')

In [None]:
# Inspect one validation notebook: its tau, its cells in true order, and the
# decoded model input.
# NOTE(review): legacy cell. `valid` and `valid_dataset` are not defined in any
# visible cell (current names: valid_df / valid_hf_dataset), and the `tau` and
# `cell_pct_preds` columns are not produced by the visible code
# (KendallTauCallback writes `kendall_tau`) — verify before running.
i = 910
row = valid.iloc[i]
notebook_id = row.notebook_id

print('notebook id:', notebook_id)
print('notebook tau:', row.tau)

cell_pct_preds = [float(pred) for pred in row.cell_pct_preds.split(CELL_SEP)]
display(cell_df.loc[notebook_id].sort_values(by='cell_pct_rank'))

input_ids, token_labels = valid_dataset[i]['input_ids'], valid_dataset[i]['token_labels']
print(tokenizer.decode(input_ids))

In [None]:
# Walk the token labels and print label, prediction and decoded text per cell,
# accumulating the squared error between each cell's label and prediction.
# NOTE(review): legacy cell. It treats every token whose label is not -100 as the
# start of a new cell, but convert_to_features_bigbird never emits -100 (padding
# labels are 0, unlabeled cells -1), so with current features every token counts
# as a cell start. red/yellow/blue are not defined in any visible cell — verify.
cell_idx, total_loss = 0, 0.0
for token_idx, (token_id, token_label) in enumerate(zip(input_ids, token_labels)):
    if token_label == -100:
        continue
    print(red('-')*100)
    print('Label:'+yellow(token_label))
    print('Prediction:'+blue(cell_pct_preds[cell_idx]))
    print()
    # Find the next labelled token after token_idx; this rescans from the start each
    # time (quadratic), and falls back to the last index when there is none.
    for next_token_idx, (token_id, next_token_label) in enumerate(zip(input_ids, token_labels)):
        if next_token_label == -100 or next_token_idx <= token_idx:
            continue
        break
    print(tokenizer.decode(input_ids[token_idx:next_token_idx]))
    total_loss += (cell_pct_preds[cell_idx]-token_label)**2
    cell_idx += 1

print('Tau for the notebook:', row.tau)
print('MSE loss for notebook:', red(total_loss/cell_idx))

In [None]:
# # Basic Validation 
# valid, valid_dataset = train, train_dataset
# eval_ds = convert_dataset_to_test_ds(valid_dataset)
# tau_callback.evaluate_model(model)