## 🧑‍💻 __AI4Code: Longformer Train & Infer__

---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a> | <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#model'> 🧠 Model </a>  | <a href='#training'> ⚡ Training Loop </a> 


In [None]:
# TODO: Torch Profiler
# TODO: Better batch selection and offline tokenization
# TODO: MAKE SURE THAT TOKEN CELL INDICES IS POSITIVE

# 0.674 empty w. 4096

In [None]:
# Sync Notebook with VS Code #
# Make the attached Kaggle datasets importable and copy the project sources
# into the working directory so the setup script below can be run from there.
import sys
sys.path.append('/kaggle/input/github-ai4code/ai4code')
sys.path.append('/kaggle/input/omegaconf')
!cp -r /kaggle/input/github-ai4code/ai4code /kaggle/working

# Run Setup Scripts #
# NOTE(review): later cells use names such as `transformers`, `datasets`, `pd`,
# `np`, `Path`, `tqdm`, `math`, `scipy`, `colored` and `defaultdict` without
# importing them -- presumably this script defines them; confirm.
%run /kaggle/working/ai4code/ai4c/jupyter_setup.py

# Imports #
import ai4c
import ai4c.process_df
import torch

## ⚙️ Hyperparameters ⚙️
---
### <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#model'> 🧠 Model </a>|  <a href='#training'> ⚡ Training Loop </a> 

<a name='hyperparameters'>

In [None]:
%%hyperparameters HP

## Huggingface Backbone ##
# The backbone is loaded from /kaggle/input/<backbone_folder>/<backbone_name with '/' replaced by '_'>.
backbone_name: 'allenai/longformer-base-4096'
backbone_folder: 'longformer-backbones'

# The three settings below are forwarded to AutoModel.from_pretrained.
attention_probs_dropout_prob: 0.10
hidden_dropout_prob: 0.10

gradient_checkpointing: True


## Tokenization & Pre-processing ##
max_seq_len: 1024  # every tokenized notebook is padded to this length
max_markdown_seq_len: 512  # token budget for markdown cells; code cells get what is left of max_seq_len
max_tokens_per_cell: 512  # lower bound of the per-cell truncation length passed to the tokenizer
max_global_tokens_per_notebook: 128  # at most this many cells get global attention on their first token


## Data Factory ##
train_folds: [1]  # NOTE(review): not referenced by the code visible in this notebook -- confirm
valid_fold: 0
num_validation_notebooks: 1000  # number of validation notebooks sampled when a cached dataset is used
sort_notebooks_by_input_tokens: True  # sorts datasets by num_pad_tokens (descending) and disables train shuffling


## Model Training ##
num_train_epochs: 1
train_batch_size: 4
eval_batch_size: 4

gradient_accumulation_steps: 16  # batches accumulated before each optimizer step
mixed_precision: True  # enables torch.cuda.amp autocast and the GradScaler


## Loss Function ##
markdown_cell_loss_weight: 0.50  # code cell loss is weighted by (1 - markdown_cell_loss_weight)
loss_fn_name: 'mse'  # NOTE(review): not referenced by the code visible in this notebook -- confirm


## Cosine Decay LR Scheduler ##
warmup_ratio: 0.0625  # fraction of the total optimizer steps used for warmup
learning_rate: 3e-5


## AdamW Optimizer ##
weight_decay: 1e-4
max_grad_norm: 1e6
adam_epsilon: 1e-6


## Load From Cache: Tokenized Dataset ##
# When null, only the first `debug_notebooks` rows of the notebooks CSV are read and tokenized on the fly.
processed_dataset_folder: null # 'ai4code-flax-seq2seq-tokenization-2048'
debug_notebooks: 10000


## Logging ##
logging_frequency: 10  # metrics are printed every logging_frequency * gradient_accumulation_steps batches

## Global Args ##
hide_lb_score: False  # NOTE(review): not referenced by the code visible in this notebook -- confirm

In [None]:
# Resolve the local folder holding the pretrained backbone: the Hugging Face
# model name is turned into a folder name by replacing '/' with '_'.
backbone_code = HP.backbone_name.replace('/', '_')
backbone_dir = f'/kaggle/input/{HP.backbone_folder}/{backbone_code}'
print(f'Loading backbone and tokenizer from {backbone_dir}')

# Dropout and gradient checkpointing settings are passed as keyword arguments
# to from_pretrained.
# NOTE(review): whether `gradient_checkpointing` is still accepted this way
# depends on the installed transformers version -- confirm.
backbone = transformers.AutoModel.from_pretrained(
    backbone_dir,
    attention_probs_dropout_prob=HP.attention_probs_dropout_prob,
    hidden_dropout_prob=HP.hidden_dropout_prob,
    gradient_checkpointing=HP.gradient_checkpointing,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(backbone_dir, use_fast=True)

## ⚒️ Data Factory ⚒️

---
#### <a href='#prepare-huggingface-datasets'> 🤗 Huggingface Datasets </a> | <a href='#prepare-pytorch-datasets'> 🔥 PyTorch Data Module </a> 


<a name='data-factory'>

In [None]:
# Load the notebook-level dataframe and split it into train / validation parts.
# With a cached tokenized dataset the split follows `notebook_fold`; otherwise
# only the first HP.debug_notebooks rows are read and used for both roles.
notebooks_csv_path = '/kaggle/input/ai4code-dataframes/notebooks_df.csv'
processed_dataset_path = Path(f'/kaggle/input/{HP.processed_dataset_folder}')
if HP.processed_dataset_folder is not None:
    print(f'Loading dataframes from {processed_dataset_path}')
    notebooks_df = pd.read_csv(notebooks_csv_path)
    train_df = notebooks_df[notebooks_df.notebook_fold != HP.valid_fold]
    valid_df = notebooks_df[notebooks_df.notebook_fold == HP.valid_fold]
    # DataFrame.sample raises when asked for more rows than exist, so cap the
    # sample size at the number of notebooks in the validation fold.
    valid_df = valid_df.sample(min(HP.num_validation_notebooks, len(valid_df)))
else:
    print(f'Loading {HP.debug_notebooks} notebooks for debugging.')
    notebooks_df = pd.read_csv(notebooks_csv_path, nrows=HP.debug_notebooks)
    # In debug mode the same dataframe object serves as train and validation data.
    train_df = valid_df = notebooks_df

In [None]:
# Fast Submission #
# When the competition test folder is attached and holds fewer than 100
# notebooks, shrink the training data to a small sample.
is_fast_submission = False
if Path('/kaggle/input/AI4Code').exists():
    cell_df_test = ai4c.process_df.build_cell_df('/kaggle/input/AI4Code/test')
    test_df = ai4c.process_df.build_notebooks_df(cell_df_test)
    is_fast_submission = len(test_df) < 100

# `test_df` only exists when the test folder was found, so the decision is
# carried in a flag instead of reading `test_df` unconditionally.
if is_fast_submission:
    # Never ask for more rows than are available (sample() would raise).
    train_df = valid_df = train_df.sample(min(64, len(train_df)))

### 🤗 Prepare Huggingface Datasets
---
#### <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#hyperparameters'> ⚙️ Hyperparameters </a>|  <a href='#training'> ⚡ Training </a> 

<a name='prepare-huggingface-datasets'>

In [None]:
%%writefile prepare_hf_dataset.py

from functools import partial
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import argparse

import transformers
import datasets

tqdm.pandas()
CELL_SEP = '[CELL_SEP]'

def prune_cell_tokens(cell_token_ids, max_seq_len):
    """
    Trims tokens from the longest cells until all cells fit in max_seq_len.

    Tokens are repeatedly removed from the end of the currently longest cell
    (the first one on ties) until it is one token shorter than the runner-up,
    so long cells are shortened before shorter ones are touched.

    Args:
        cell_token_ids: list holding one list of token ids per cell.
        max_seq_len: total token budget shared by all cells.

    Returns:
        A new list with the (possibly shortened) token ids of every cell, in
        the original order.
    """
    cell_token_counts = [len(token_ids) for token_ids in cell_token_ids]
    total_tokens_to_prune = max(sum(cell_token_counts) - max_seq_len, 0)

    tokens_to_prune_per_cell = [0] * len(cell_token_counts)
    while total_tokens_to_prune > 0:
        longest_count = max(cell_token_counts)
        if longest_count <= 0:
            # Every cell is already empty; nothing more can be removed.
            break
        longest_idx = cell_token_counts.index(longest_count)
        # Length of the longest *other* cell. A notebook with a single cell has
        # no runner-up, which previously raised IndexError on sorted(...)[-2].
        runner_up_count = max(
            cell_token_counts[:longest_idx] + cell_token_counts[longest_idx + 1:],
            default=0,
        )

        # Never remove more tokens than the cell holds or than are still needed.
        num_tokens_to_pop = min(
            longest_count - runner_up_count + 1,
            longest_count,
            total_tokens_to_prune,
        )
        tokens_to_prune_per_cell[longest_idx] += num_tokens_to_pop
        total_tokens_to_prune -= num_tokens_to_pop
        cell_token_counts[longest_idx] -= num_tokens_to_pop

    # Prune the cell tokens
    return [
        token_ids[:len(token_ids) - num_tokens_to_pop]
        for token_ids, num_tokens_to_pop in zip(cell_token_ids, tokens_to_prune_per_cell)
    ]


def _tokenize_cell_group(tokenizer, cell_sources, max_tokens_per_cell, token_budget):
    '''Tokenize one group of cells (markdown or code) and prune it to token_budget tokens.'''
    if not cell_sources or token_budget <= 0:
        # No cells or no room left: nothing to tokenize (also avoids dividing by zero cells).
        return [[] for _ in cell_sources]
    per_cell_limit = max(max_tokens_per_cell, token_budget // len(cell_sources))
    token_ids = tokenizer(
        cell_sources,
        max_length=per_cell_limit,
        truncation=True,
    )['input_ids']
    return prune_cell_tokens(token_ids, token_budget)


def convert_to_features_longformer(
    notebook_dict,
    tokenizer,
    max_seq_len,
    max_markdown_seq_len,
    max_tokens_per_cell,
    max_global_tokens_per_notebook,
):
    '''
    Tokenize the notebook and convert to features for the model.

    Markdown cells are placed first and may use up to max_markdown_seq_len
    tokens; code cells follow and share whatever is left of max_seq_len.
    Every token carries the index, percentile-rank label and weight
    (1 / number of tokens in the cell) of the cell it belongs to.

    Args:
        notebook_dict: row with the CELL_SEP-joined `merged_*` columns and `notebook_id`.
        tokenizer: callable returning a dict with `input_ids` for a list of strings.
        max_seq_len: length every feature list is padded to.
        max_markdown_seq_len: token budget of the markdown cells.
        max_tokens_per_cell: lower bound for the per-cell truncation length.
        max_global_tokens_per_notebook: maximum number of cells whose first
            token is flagged in `global_attention_mask`.

    Returns:
        Dict of per-token feature lists of length max_seq_len plus
        `notebook_id` and `num_pad_tokens`.
    '''

    markdown_cell_sources = notebook_dict['merged_markdown_cell_sources'].split(CELL_SEP)
    markdown_cell_pct_ranks = [float(rank) for rank in notebook_dict['merged_markdown_cell_pct_ranks'].split(CELL_SEP)]

    code_cell_sources = notebook_dict['merged_code_cell_sources'].split(CELL_SEP)
    code_cell_pct_ranks = [float(rank) for rank in notebook_dict['merged_code_cell_pct_ranks'].split(CELL_SEP)]

    # Remove cells from the end of the notebook so that all cells have at least one representative token
    max_markdown_cells = max_markdown_seq_len//2
    max_code_cells = (max_seq_len-max_markdown_seq_len)//2
    markdown_cell_sources = markdown_cell_sources[:max_markdown_cells]
    markdown_cell_pct_ranks = markdown_cell_pct_ranks[:max_markdown_cells]
    code_cell_sources = code_cell_sources[:max_code_cells]
    code_cell_pct_ranks = code_cell_pct_ranks[:max_code_cells]

    markdown_cell_count = len(markdown_cell_sources)

    markdown_cell_token_ids = _tokenize_cell_group(
        tokenizer, markdown_cell_sources, max_tokens_per_cell, max_markdown_seq_len,
    )
    total_markdown_cell_tokens = sum(len(token_ids) for token_ids in markdown_cell_token_ids)

    # Code cells share whatever the markdown cells left unused.
    code_cell_token_ids = _tokenize_cell_group(
        tokenizer, code_cell_sources, max_tokens_per_cell, max_seq_len-total_markdown_cell_tokens,
    )

    # Merge the tokenized cells and create the model features
    all_cell_token_ids = markdown_cell_token_ids + code_cell_token_ids
    cell_pct_ranks = markdown_cell_pct_ranks + code_cell_pct_ranks

    input_ids, markdown_token_mask, code_token_mask = [], [], []
    global_attention_mask = []
    token_weights, token_labels = [], []
    token_cell_indices = []
    num_global_tokens = 0

    for cur_cell_idx, cur_cell_token_ids in enumerate(all_cell_token_ids):
        token_count_for_cell = len(cur_cell_token_ids)
        if token_count_for_cell == 0:
            # A cell without tokens contributes nothing; skipping it keeps all
            # feature lists aligned and avoids dividing by zero below.
            continue

        is_markdown_cell = cur_cell_idx < markdown_cell_count
        markdown_token_mask += [int(is_markdown_cell)]*token_count_for_cell
        code_token_mask += [int(not is_markdown_cell)]*token_count_for_cell

        # Only the first token of a cell can be global, and only for the first
        # max_global_tokens_per_notebook cells.
        if num_global_tokens < max_global_tokens_per_notebook:
            global_attention_mask += [1] + [0]*(token_count_for_cell-1)
            num_global_tokens += 1
        else:
            global_attention_mask += [0]*token_count_for_cell
        input_ids += cur_cell_token_ids
        token_cell_indices += [cur_cell_idx] * token_count_for_cell
        token_labels += [cell_pct_ranks[cur_cell_idx]] * token_count_for_cell
        token_weights += [1/token_count_for_cell] * token_count_for_cell

    # Pad to max_seq_len for efficient storage
    # NOTE(review): padding uses token id 0 -- confirm this matches the tokenizer's pad token id.
    num_pad_tokens = max_seq_len - len(input_ids)
    attention_mask = [1]*len(input_ids) + [0]*num_pad_tokens
    input_ids += [0]*num_pad_tokens
    global_attention_mask += [0]*num_pad_tokens
    markdown_token_mask += [0]*num_pad_tokens
    code_token_mask += [0]*num_pad_tokens
    token_cell_indices += [-100]*num_pad_tokens
    token_labels += [-100]*num_pad_tokens
    token_weights += [0]*num_pad_tokens

    # Build the feature dict for the input
    notebook_features = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'global_attention_mask': global_attention_mask,
        'markdown_token_mask': markdown_token_mask,
        'code_token_mask': code_token_mask,
        'token_cell_indices': token_cell_indices,

        'token_labels': token_labels,
        'token_weights': token_weights,

        'notebook_id': notebook_dict['notebook_id'],
        'num_pad_tokens': num_pad_tokens,
    }
    return notebook_features


def build_hf_dataset(
    df, 
    tokenizer, 
    max_seq_len,
    max_markdown_seq_len,
    max_tokens_per_cell,
    max_global_tokens_per_notebook,
    ):
    '''
    Tokenizes every notebook row of `df` into a huggingface dataset.

    The raw dataframe columns are replaced by the features produced by
    convert_to_features_longformer, the dataset is switched to numpy format
    and the share of notebooks whose last position is padding is printed.
    '''
    featurizer_kwargs = dict(
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        max_markdown_seq_len=max_markdown_seq_len,
        max_tokens_per_cell=max_tokens_per_cell,
        max_global_tokens_per_notebook=max_global_tokens_per_notebook,
    )
    source_dataset = datasets.Dataset.from_pandas(df)
    tokenized_dataset = source_dataset.map(
        partial(convert_to_features_longformer, **featurizer_kwargs),
        remove_columns=source_dataset.column_names,
        desc='Running tokenizer on raw dataset',
    )
    tokenized_dataset.set_format(type='numpy')

    # A zero in the last attention-mask position means the notebook did not fill max_seq_len.
    last_position_mask = np.array(tokenized_dataset['attention_mask'])[:, -1]
    padded_notebook_count = (last_position_mask == 0).sum()
    print('Empty sentences ratio:', padded_notebook_count/len(tokenized_dataset))
    return tokenized_dataset

# Command-line entry point: tokenizes the notebooks of every fold and saves
# one huggingface dataset per fold as `hf_dataset_fold_<fold>`.
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--tokenizer_name', default='google/bigbird-roberta-large', type=str, help='The tokenizer name')
    # The default must exceed --max_markdown_seq_len: the code cells only get
    # max_seq_len - max_markdown_seq_len tokens, and the previous default (512)
    # left them none.
    parser.add_argument('--max_seq_len', default=1024, type=int, help='The maximum sequence length')
    parser.add_argument('--max_markdown_seq_len', default=512, type=int, help='The maximum sequence length for markdown cells')
    parser.add_argument('--max_tokens_per_cell', default=256, type=int, help='The maximum number of tokens per cell')
    parser.add_argument('--max_global_tokens_per_notebook', default=128, type=int, help='The maximum number of global tokens per notebook')
    parser.add_argument('--notebooks_df_path', default='notebooks_df.csv', type=str, help='Path to notebooks.csv')
    parser.add_argument('--num_folds', default=8, type=int, help='The number of folds (0..num_folds-1) to tokenize')

    args = parser.parse_args()
    if args.max_seq_len <= args.max_markdown_seq_len:
        parser.error('--max_seq_len must be larger than --max_markdown_seq_len to leave room for code cells')
    
    tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_name)
    notebooks_df = pd.read_csv(args.notebooks_df_path)
    print('Total number of notebooks:', len(notebooks_df))
    
    for fold in tqdm(range(args.num_folds), desc='Tokenizing notebooks for each fold'):
        fold_df = notebooks_df[notebooks_df.notebook_fold == fold]
        fold_dataset = build_hf_dataset(
            df=fold_df,
            tokenizer=tokenizer,
            max_seq_len=args.max_seq_len,
            max_markdown_seq_len=args.max_markdown_seq_len,
            max_tokens_per_cell=args.max_tokens_per_cell,
            max_global_tokens_per_notebook=args.max_global_tokens_per_notebook,
        )
        fold_dataset.save_to_disk(f'hf_dataset_fold_{fold}')
    print('Done!')

In [None]:
import prepare_hf_dataset

# Either load the pre-tokenized folds from disk or tokenize the debug
# dataframe on the fly (in which case train and validation share one dataset).
if HP.processed_dataset_folder is not None:
    print(f'Loading tokenized datasets from {processed_dataset_path}')
    valid_hf_dataset = datasets.load_from_disk(processed_dataset_path/f'hf_dataset_fold_{HP.valid_fold}')
    # `valid_df` is only a sample of the validation fold. Keep the tokenized
    # dataset in sync with it, otherwise the validation loop looks up notebooks
    # that are missing from `valid_df` and fails on `.iloc[0]`.
    valid_notebook_ids = set(valid_df.notebook_id)
    valid_hf_dataset = valid_hf_dataset.filter(
        lambda example: str(example['notebook_id']) in valid_notebook_ids,
        desc='Selecting sampled validation notebooks',
    )
    train_hf_dataset = datasets.concatenate_datasets([
        datasets.load_from_disk(processed_dataset_path/f'hf_dataset_fold_{fold}')
        for fold in tqdm(range(8), desc='Loading training dataset')
        if fold != HP.valid_fold
    ])
    # Round-trip through disk so the concatenated dataset is stored as one dataset.
    train_hf_dataset.save_to_disk('train_hf_dataset')
    train_hf_dataset = datasets.load_from_disk('train_hf_dataset')
else:
    train_hf_dataset = valid_hf_dataset = prepare_hf_dataset.build_hf_dataset(
        df=train_df, 
        tokenizer=tokenizer, 
        max_seq_len=HP.max_seq_len,
        max_markdown_seq_len=HP.max_markdown_seq_len,
        max_global_tokens_per_notebook=HP.max_global_tokens_per_notebook,
        max_tokens_per_cell=HP.max_tokens_per_cell,
    )

# Group notebooks with a similar amount of padding so that batches can be
# trimmed to a short common length by the collate function.
if HP.sort_notebooks_by_input_tokens:
    train_hf_dataset = train_hf_dataset.sort('num_pad_tokens', reverse=True)
    valid_hf_dataset = valid_hf_dataset.sort('num_pad_tokens', reverse=True)

### 🔥 Prepare PyTorch Data Module
---
#### <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#hyperparameters'> ⚙️ Hyperparameters </a>|  <a href='#training'> ⚡ Training </a> 

<a name='prepare-pytorch-datasets'>

In [None]:
def collate_fn(batch):
    """
    Collates tokenized notebooks into model inputs and labels.

    Every example is stored padded to the same length; the batch is trimmed to
    the length of its longest example (the one with the fewest pad tokens).

    Args:
        batch: list of feature dicts with the per-token lists produced by the
            tokenization step plus `num_pad_tokens`.

    Returns:
        Tuple `(model_inputs_batch, model_labels_batch)` of dicts of tensors
        with shape (batch size, trimmed sequence length).
    """
    min_pad_tokens_in_batch = min(inputs['num_pad_tokens'] for inputs in batch)
    max_seq_len = len(batch[0]['input_ids'])
    batch_seq_len = max_seq_len - min_pad_tokens_in_batch

    def stack(key, dtype):
        # Prune the elements in the batch to batch_seq_len
        return torch.tensor([inputs[key][:batch_seq_len] for inputs in batch], dtype=dtype)

    integer_input_keys = (
        'input_ids',
        'attention_mask',
        'global_attention_mask',
        'markdown_token_mask',
        'code_token_mask',
    )
    model_inputs_batch = {key: stack(key, torch.long) for key in integer_input_keys}

    model_labels_batch = {
        # Labels are fractional percentile ranks; casting them to an integer
        # dtype would truncate them, so they must stay floating point.
        'token_labels': stack('token_labels', torch.float),
        'token_weights': stack('token_weights', torch.float),
    }
    return model_inputs_batch, model_labels_batch

# Shuffling is enabled only when the notebooks were not sorted by padding.
shuffle_train_dataset = not HP.sort_notebooks_by_input_tokens

# Settings common to the training and evaluation loaders.
shared_loader_kwargs = dict(
    num_workers=2,
    collate_fn=collate_fn,
    pin_memory=True,
    prefetch_factor=4,
)

# Training drops the trailing incomplete batch, evaluation keeps every notebook.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_hf_dataset,
    batch_size=HP.train_batch_size,
    shuffle=shuffle_train_dataset,
    drop_last=True,
    **shared_loader_kwargs,
)
eval_dataloader = torch.utils.data.DataLoader(
    dataset=valid_hf_dataset,
    batch_size=HP.eval_batch_size,
    shuffle=False,
    drop_last=False,
    **shared_loader_kwargs,
)

# Pull one batch from the training loader.
batch = next(iter(train_dataloader))

In [None]:
%%writefile torch_model.py

import transformers
import torch.nn as nn
import torch

class AI4CodeModel(nn.Module):
    """Backbone encoder with a linear head that predicts one value per token."""

    def __init__(self, backbone):
        """
        Args:
            backbone: encoder module exposing `config.hidden_size` and returning
                an output with `last_hidden_state` when called.
        """
        super().__init__()
        self.backbone = backbone
        self.ranker = nn.Linear(backbone.config.hidden_size, 1)
    
    def forward(self, input_ids, global_attention_mask, attention_mask=None):
        """
        Predicts one rank value for every input token.

        Args:
            input_ids: token ids, (batch, sequence length).
            global_attention_mask: global attention flags, same shape as input_ids.
            attention_mask: optional padding mask forwarded to the backbone.
                Defaults to None, which keeps the previous behaviour of letting
                the backbone build its own mask.

        Returns:
            Tensor of per-token predictions with shape (batch, sequence length).
        """
        backbone_outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
        )
        token_preds = self.ranker(backbone_outputs.last_hidden_state)
        # The head emits a trailing dimension of size one; drop it.
        return token_preds.squeeze(-1)

def get_optimizer_grouped_parameters(model, weight_decay):
    """
    Splits the model parameters into two optimizer groups.

    Parameters whose name contains "bias" or "LayerNorm.weight" go into the
    second group with a weight decay of 0.0; all others go into the first
    group with the given weight decay.
    """
    no_decay = ("bias", "LayerNorm.weight")
    decayed_params, exempt_params = [], []
    for param_name, param in model.named_parameters():
        is_exempt = any(marker in param_name for marker in no_decay)
        target_group = exempt_params if is_exempt else decayed_params
        target_group.append(param)

    return [
        {"params": decayed_params, "weight_decay": weight_decay},
        {"params": exempt_params, "weight_decay": 0.0},
    ]

In [None]:
import torch_model
import gc

# Build the model and move it to the GPU before the optimizer is constructed,
# so the optimizer is created from the parameters that are actually trained on.
model = torch_model.AI4CodeModel(backbone)
_ = model.cuda()

# The parameter groups are built from the whole model: passing only
# `model.backbone` would leave the `ranker` head out of the optimizer and it
# would keep its random initial weights.
optimizer = torch.optim.AdamW(
    params=torch_model.get_optimizer_grouped_parameters(model, HP.weight_decay), 
    eps=HP.adam_epsilon,
    lr=HP.learning_rate,
)

# One optimizer step is taken per HP.gradient_accumulation_steps batches.
train_steps_per_epoch = math.ceil(len(train_dataloader) / HP.gradient_accumulation_steps)
total_train_steps = HP.num_train_epochs * train_steps_per_epoch

lr_scheduler = transformers.get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=int(HP.warmup_ratio*total_train_steps),
    num_training_steps=total_train_steps,
)

torch.cuda.empty_cache()
gc.collect()

# ⚡ Training Loop ⚡
---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#model'> 🧠 Model </a>

<a name='training'>

In [None]:
def mse_loss(token_labels, token_preds, token_weights, token_mask):
    """
    Weighted mean squared error over the tokens selected by token_mask.

    For every notebook (row) the squared errors of the selected tokens are
    averaged with token_weights; the result is the mean over notebooks.
    A notebook without any selected token contributes a loss of zero instead
    of producing NaN through a 0/0 division.

    Args:
        token_labels: target value per token, (batch, sequence length).
        token_preds: predicted value per token, same shape.
        token_weights: weight per token, same shape.
        token_mask: 1 for tokens that count towards this loss, 0 otherwise.

    Returns:
        Scalar tensor with the loss.
    """
    masked_labels = token_labels * token_mask
    masked_preds = token_preds * token_mask
    masked_weights = token_weights * token_mask

    sum_weights = torch.sum(masked_weights, dim=-1)
    weighted_errors = torch.sum((masked_labels - masked_preds) ** 2 * masked_weights, dim=-1)
    # Rows without selected tokens have a zero numerator and a zero weight sum;
    # clamping the denominator turns their loss into 0 rather than NaN.
    safe_sum_weights = sum_weights.clamp_min(torch.finfo(sum_weights.dtype).eps)
    notebook_losses = weighted_errors / safe_sum_weights
    return torch.mean(notebook_losses)

# Mixed-precision training with gradient accumulation: one optimizer step is
# taken every HP.gradient_accumulation_steps batches (and on the last batch of
# an epoch, so a trailing partial window is not silently dropped).
_ = model.train()
scaler = torch.cuda.amp.GradScaler(enabled=HP.mixed_precision)

running_metrics = defaultdict(int)
steps_progress_bar = tqdm(range(total_train_steps), desc="Training Progress") 

# Clip exactly the parameters the optimizer updates.
optimized_parameters = [p for group in optimizer.param_groups for p in group['params']]
logging_width = (HP.logging_frequency * HP.gradient_accumulation_steps)
batches_per_epoch = len(train_dataloader)
batches_seen = 0
optimizer_steps_done = 0
grad_norm = 0.0  # gradient norm measured at the most recent optimizer step
optimizer.zero_grad()

for epoch in range(HP.num_train_epochs):
    for step, batch in enumerate(train_dataloader):
        inputs, labels = batch
        inputs = {k: v.cuda() for k, v in inputs.items()}
        labels = {k: v.cuda() for k, v in labels.items()}

        with torch.cuda.amp.autocast(enabled=HP.mixed_precision):
            token_preds = model(
                input_ids=inputs['input_ids'], 
                global_attention_mask=inputs['global_attention_mask'],
            )

            markdown_cell_loss = mse_loss(
                token_labels=labels['token_labels'],
                token_preds=token_preds,
                token_weights=labels['token_weights'],
                token_mask=inputs['markdown_token_mask'],
            )
            code_cell_loss = mse_loss(
                token_labels=labels['token_labels'],
                token_preds=token_preds,
                token_weights=labels['token_weights'],
                token_mask=inputs['code_token_mask'],
            )
            loss = HP.markdown_cell_loss_weight*markdown_cell_loss + (1-HP.markdown_cell_loss_weight) * code_cell_loss

        # Divide by the accumulation steps so the accumulated gradient is the
        # average (not the sum) over the batches of one optimizer step.
        scaler.scale(loss / HP.gradient_accumulation_steps).backward()
        batches_seen += 1

        # Learning rate in effect for this batch (read before the scheduler moves on).
        current_learning_rate = lr_scheduler.get_last_lr()[0]

        accumulation_window_full = (step+1) % HP.gradient_accumulation_steps == 0
        last_batch_of_epoch = (step+1) == batches_per_epoch
        if accumulation_window_full or last_batch_of_epoch:
            # Gradients must exist (after backward) and be unscaled before
            # their norm is measured and clipped.
            scaler.unscale_(optimizer)
            grad_norm = torch.nn.utils.clip_grad_norm_(optimized_parameters, HP.max_grad_norm).item()
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()
            optimizer_steps_done += 1
            _ = steps_progress_bar.update(1)

        running_metrics['total_loss'] += loss.item()
        running_metrics['markdown_cell_loss'] += markdown_cell_loss.item()
        running_metrics['code_cell_loss'] += code_cell_loss.item()
        running_metrics['learning_rate'] += current_learning_rate
        running_metrics['num_input_tokens'] += torch.sum(inputs['attention_mask']).item()/HP.train_batch_size
        running_metrics['gradient_norm'] += grad_norm

        # Show the per-batch losses on the same scale as the logged averages.
        steps_progress_bar.set_postfix(
            loss=f"{loss.item():.4f}", 
            markdown_cell_loss=f"{markdown_cell_loss.item():.4f}", 
            code_cell_loss=f"{code_cell_loss.item():.4f}",
            grad_norm=f"{grad_norm:.2f}",
        )

        if batches_seen % logging_width == 0:
            print('-'*50)
            print(f"Batches {batches_seen-logging_width+1}-{batches_seen} (optimizer step {optimizer_steps_done} out of {total_train_steps})")
            for k, v in running_metrics.items():
                print(colored(k, 'blue'), ':', colored(v / logging_width, 'red'))
            running_metrics = defaultdict(int)
            print()

##  Validation
---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#model'> 🧠 Model </a>

<a name='validation'>

In [None]:
from collections import defaultdict

# Validation: average the token predictions of every cell into one predicted
# rank per cell, then score each notebook's predicted order with Kendall tau.
CELL_SEP = '[CELL_SEP]'
torch.cuda.empty_cache()
gc.collect()
_ = model.eval()

# Read the bookkeeping columns once instead of once per batch.
all_notebook_ids = valid_hf_dataset['notebook_id']
all_token_cell_indices = valid_hf_dataset['token_cell_indices']

# Token cell indices count the (truncated) markdown cells first and the
# (truncated) code cells after them, so the cell ids are rebuilt the same way.
# NOTE(review): assumes the dataset was tokenized with the current HP sequence lengths -- confirm for cached datasets.
max_markdown_cells = HP.max_markdown_seq_len//2
max_code_cells = (HP.max_seq_len-HP.max_markdown_seq_len)//2
notebook_id_to_cell_ids = {}
for notebook_id, markdown_ids, code_ids in zip(
    valid_df.notebook_id, valid_df.merged_markdown_cell_ids, valid_df.merged_code_cell_ids
):
    if notebook_id in notebook_id_to_cell_ids:
        continue
    notebook_id_to_cell_ids[notebook_id] = (
        markdown_ids.split(CELL_SEP)[:max_markdown_cells] + code_ids.split(CELL_SEP)[:max_code_cells]
    )

epoch_metrics = defaultdict(int)
cell_id_to_pred_rank = {}
num_eval_batches = 0
batch_start = 0  # dataset offset of the first example of the current batch
for inputs, labels in tqdm(eval_dataloader, total=len(eval_dataloader)):
    inputs = {k: v.cuda() for k, v in inputs.items()}
    labels = {k: v.cuda() for k, v in labels.items()}

    with torch.no_grad(), torch.cuda.amp.autocast(enabled=HP.mixed_precision):
        token_preds = model(
            input_ids=inputs['input_ids'],
            global_attention_mask=inputs['global_attention_mask'],
        )
        markdown_cell_loss = mse_loss(
            token_labels=labels['token_labels'],
            token_preds=token_preds,
            token_weights=labels['token_weights'],
            token_mask=inputs['markdown_token_mask'],
        )
        code_cell_loss = mse_loss(
            token_labels=labels['token_labels'],
            token_preds=token_preds,
            token_weights=labels['token_weights'],
            token_mask=inputs['code_token_mask'],
        )
        loss = HP.markdown_cell_loss_weight*markdown_cell_loss + (1-HP.markdown_cell_loss_weight) * code_cell_loss

    epoch_metrics['loss'] += loss.item()
    epoch_metrics['markdown_cell_loss'] += markdown_cell_loss.item()
    epoch_metrics['code_cell_loss'] += code_cell_loss.item()
    num_eval_batches += 1

    # Average in float32 even when the model ran in half precision.
    token_preds = token_preds.detach().float().cpu().numpy()

    # The loader is not shuffled, so the batch covers the next `batch_size`
    # dataset rows. A running offset stays correct for a smaller last batch,
    # which `step * batch_size` did not.
    batch_size = inputs['input_ids'].shape[0]
    notebook_ids = all_notebook_ids[batch_start: batch_start+batch_size]
    token_cell_indices = all_token_cell_indices[batch_start: batch_start+batch_size]
    batch_start += batch_size

    for example_idx in range(batch_size):
        cell_ids = notebook_id_to_cell_ids[notebook_ids[example_idx]]
        example_token_preds = token_preds[example_idx]
        example_token_cell_indices = token_cell_indices[example_idx]

        cell_idx_to_sum_preds = defaultdict(int)
        cell_idx_to_num_preds = defaultdict(int)
        for token_pred, cell_idx in zip(example_token_preds, example_token_cell_indices):
            if cell_idx < 0:
                # Negative indices mark padding, which only occurs at the end.
                break
            cell_idx_to_sum_preds[cell_idx] += token_pred
            cell_idx_to_num_preds[cell_idx] += 1

        for cell_idx, sum_pred in cell_idx_to_sum_preds.items():
            num_pred = cell_idx_to_num_preds[cell_idx]
            cell_id_to_pred_rank[cell_ids[cell_idx]] = sum_pred / num_pred


for k, v in epoch_metrics.items():
    print(f"{colored(k, 'blue')}: {colored(v/max(num_eval_batches, 1), 'red')}")

# Compute Kendall Tau for the predictions with the ground truth
all_notebook_cell_pct_ranks = valid_df.merged_cell_pct_ranks.values
all_notebook_cell_ids = valid_df.merged_cell_ids.values
all_notebook_kendall_taus, all_notebook_cell_order_preds = [], []
for notebook_cell_pct_ranks, notebook_cell_ids in zip(all_notebook_cell_pct_ranks, all_notebook_cell_ids):
    true_cell_ranks = [float(rank) for rank in notebook_cell_pct_ranks.split(CELL_SEP)]
    cell_ids = notebook_cell_ids.split(CELL_SEP)
    # Cells without a prediction fall back to their relative position in the list.
    pred_cell_ranks = [cell_id_to_pred_rank.get(cell_id, cell_idx/len(cell_ids)) for cell_idx, cell_id in enumerate(cell_ids)]

    notebook_tau = scipy.stats.kendalltau(true_cell_ranks, pred_cell_ranks, method='asymptotic')[0]
    notebook_preds = CELL_SEP.join([str(round(rank, 4)) for rank in pred_cell_ranks])
    all_notebook_kendall_taus.append(notebook_tau)
    all_notebook_cell_order_preds.append(notebook_preds)
all_notebook_kendall_taus = np.array(all_notebook_kendall_taus)

# assign() returns a new dataframe, so the column is not written into a slice
# (or into the dataframe shared with the training data in debug mode).
valid_df = valid_df.assign(kendall_tau=all_notebook_kendall_taus)
avg_tau = all_notebook_kendall_taus.mean()
print(f"{colored('Kendall Tau', 'blue')}: {colored(avg_tau, 'red')}")

for cutoff in [4, 16, 64]:
    cutoff_df = valid_df[valid_df.markdown_cell_count>cutoff]
    tau = cutoff_df.kendall_tau.mean()
    print(f"Kendall Tau for {colored(len(cutoff_df), 'blue')} notebooks with {colored(cutoff, 'yellow')}+ markdown cells:", \
          colored(tau, 'red'))

valid_df.to_csv('valid_df.csv', index=False)

In [None]:
# tau_32 = valid_df[valid_df.markdown_cell_count>32].kendall_tau.mean()
# model_save_path = f"{backbone_code}_tau{int(avg_tau*10000)}_tau32_{int(tau_32*1000)}.pt"
# torch.save(model.state_dict(), model_save_path)
# print(f'Model saved at {model_save_path}')

## 🎯 Inference
---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#training'> ⚡ Training </a>

<a name='inference'>

In [None]:
# Drop the training/validation dataframes and datasets before inference, then
# release cached CUDA memory and run the garbage collector.
# NOTE(review): train_dataloader and eval_dataloader were built from these
# datasets and are not deleted here -- confirm whether they should be dropped too.
del train_df, valid_df, train_hf_dataset, valid_hf_dataset
torch.cuda.empty_cache()

gc.collect()

In [None]:
# Build the cell-level and notebook-level dataframes of the test notebooks.
# `test_df` is only (re)defined when the competition folder is attached.
if Path('/kaggle/input/AI4Code').exists():
    cell_df_test = ai4c.process_df.build_cell_df('/kaggle/input/AI4Code/test')
    test_df = ai4c.process_df.build_notebooks_df(cell_df_test)

In [None]:
%%writefile prepare_hf_dataset.py

from functools import partial
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import argparse

import transformers
import datasets

tqdm.pandas()
CELL_SEP = '[CELL_SEP]'

def prune_cell_tokens(cell_token_ids, max_seq_len):
    """
    Trims tokens from the longest cells until all cells fit in max_seq_len.

    Tokens are repeatedly removed from the end of the currently longest cell
    (the first one on ties) until it is one token shorter than the runner-up,
    so long cells are shortened before shorter ones are touched.

    Args:
        cell_token_ids: list holding one list of token ids per cell.
        max_seq_len: total token budget shared by all cells.

    Returns:
        A new list with the (possibly shortened) token ids of every cell, in
        the original order.
    """
    remaining_counts = [len(token_ids) for token_ids in cell_token_ids]
    tokens_left_to_prune = max(sum(remaining_counts) - max_seq_len, 0)

    pruned_per_cell = [0] * len(remaining_counts)
    while tokens_left_to_prune > 0:
        longest_count = max(remaining_counts)
        if longest_count <= 0:
            # Every cell is already empty; nothing more can be removed.
            break
        longest_idx = remaining_counts.index(longest_count)
        # Length of the longest *other* cell. A notebook with a single cell has
        # no runner-up, which previously raised IndexError on sorted(...)[-2].
        other_counts = remaining_counts[:longest_idx] + remaining_counts[longest_idx + 1:]
        runner_up_count = max(other_counts, default=0)

        # Never remove more tokens than the cell holds or than are still needed.
        num_tokens_to_pop = min(
            longest_count - runner_up_count + 1,
            longest_count,
            tokens_left_to_prune,
        )
        pruned_per_cell[longest_idx] += num_tokens_to_pop
        tokens_left_to_prune -= num_tokens_to_pop
        remaining_counts[longest_idx] -= num_tokens_to_pop

    # Prune the cell tokens
    return [
        token_ids[:len(token_ids) - num_tokens_to_pop]
        for token_ids, num_tokens_to_pop in zip(cell_token_ids, pruned_per_cell)
    ]


def _tokenize_cell_group(tokenizer, cell_sources, max_tokens_per_cell, token_budget):
    '''Tokenize one group of cells (markdown or code) and prune it to token_budget tokens.'''
    if not cell_sources or token_budget <= 0:
        # No cells or no room left: nothing to tokenize (also avoids dividing by zero cells).
        return [[] for _ in cell_sources]
    per_cell_limit = max(max_tokens_per_cell, token_budget // len(cell_sources))
    token_ids = tokenizer(
        cell_sources,
        max_length=per_cell_limit,
        truncation=True,
    )['input_ids']
    return prune_cell_tokens(token_ids, token_budget)


def convert_to_features_longformer(
    notebook_dict,
    tokenizer,
    max_seq_len,
    max_markdown_seq_len,
    max_tokens_per_cell,
    max_global_tokens_per_notebook,
):
    '''
    Tokenize the notebook and convert to features for the model (inference, no labels).

    Markdown cells are placed first and may use up to max_markdown_seq_len
    tokens; code cells follow and share whatever is left of max_seq_len.
    Every token carries the index of the cell it belongs to.

    Args:
        notebook_dict: row with the CELL_SEP-joined `merged_*_cell_sources`
            columns and `notebook_id`.
        tokenizer: callable returning a dict with `input_ids` for a list of strings.
        max_seq_len: length every feature list is padded to.
        max_markdown_seq_len: token budget of the markdown cells.
        max_tokens_per_cell: lower bound for the per-cell truncation length.
        max_global_tokens_per_notebook: maximum number of cells whose first
            token is flagged in `global_attention_mask`.

    Returns:
        Dict of per-token feature lists of length max_seq_len plus
        `notebook_id` and `num_pad_tokens`.
    '''

    markdown_cell_sources = notebook_dict['merged_markdown_cell_sources'].split(CELL_SEP)
    code_cell_sources = notebook_dict['merged_code_cell_sources'].split(CELL_SEP)

    # Remove cells from the end of the notebook so that all cells have at least one representative token
    max_markdown_cells = max_markdown_seq_len//2
    max_code_cells = (max_seq_len-max_markdown_seq_len)//2
    markdown_cell_sources = markdown_cell_sources[:max_markdown_cells]
    code_cell_sources = code_cell_sources[:max_code_cells]

    markdown_cell_count = len(markdown_cell_sources)

    markdown_cell_token_ids = _tokenize_cell_group(
        tokenizer, markdown_cell_sources, max_tokens_per_cell, max_markdown_seq_len,
    )
    total_markdown_cell_tokens = sum(len(token_ids) for token_ids in markdown_cell_token_ids)

    # Code cells share whatever the markdown cells left unused.
    code_cell_token_ids = _tokenize_cell_group(
        tokenizer, code_cell_sources, max_tokens_per_cell, max_seq_len-total_markdown_cell_tokens,
    )

    # Merge the tokenized cells and create the model features
    all_cell_token_ids = markdown_cell_token_ids + code_cell_token_ids

    input_ids, markdown_token_mask, code_token_mask = [], [], []
    global_attention_mask = []
    token_cell_indices = []
    num_global_tokens = 0

    for cur_cell_idx, cur_cell_token_ids in enumerate(all_cell_token_ids):
        token_count_for_cell = len(cur_cell_token_ids)
        if token_count_for_cell == 0:
            # A cell without tokens contributes nothing; skipping it keeps all
            # feature lists aligned.
            continue

        is_markdown_cell = cur_cell_idx < markdown_cell_count
        markdown_token_mask += [int(is_markdown_cell)]*token_count_for_cell
        code_token_mask += [int(not is_markdown_cell)]*token_count_for_cell

        # Only the first token of a cell can be global, and only for the first
        # max_global_tokens_per_notebook cells.
        if num_global_tokens < max_global_tokens_per_notebook:
            global_attention_mask += [1] + [0]*(token_count_for_cell-1)
            num_global_tokens += 1
        else:
            global_attention_mask += [0]*token_count_for_cell
        input_ids += cur_cell_token_ids
        token_cell_indices += [cur_cell_idx] * token_count_for_cell

    # Pad to max_seq_len for efficient storage
    # NOTE(review): padding uses token id 0 -- confirm this matches the tokenizer's pad token id.
    num_pad_tokens = max_seq_len - len(input_ids)
    attention_mask = [1]*len(input_ids) + [0]*num_pad_tokens
    input_ids += [0]*num_pad_tokens
    global_attention_mask += [0]*num_pad_tokens
    markdown_token_mask += [0]*num_pad_tokens
    code_token_mask += [0]*num_pad_tokens
    token_cell_indices += [-100]*num_pad_tokens

    # Build the feature dict for the input
    notebook_features = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'global_attention_mask': global_attention_mask,
        'markdown_token_mask': markdown_token_mask,
        'code_token_mask': code_token_mask,
        'token_cell_indices': token_cell_indices,

        'notebook_id': notebook_dict['notebook_id'],
        'num_pad_tokens': num_pad_tokens,
    }
    return notebook_features


def build_hf_dataset(
    df,
    tokenizer,
    max_seq_len,
    max_markdown_seq_len,
    max_tokens_per_cell,
    max_global_tokens_per_notebook,
    ):
    '''Convert a notebook dataframe into a numpy-formatted HuggingFace dataset.

    Each row of `df` is passed through `convert_to_features_longformer` with
    the tokenizer and length limits bound as keyword arguments. The original
    dataframe columns are dropped, so the result only holds the features the
    converter returns.

    As a side effect this prints the fraction of examples whose last
    attention-mask entry is 0.

    Returns:
        The processed `datasets.Dataset`, with its format set to numpy.
    '''
    featurize = partial(
        convert_to_features_longformer,
        tokenizer=tokenizer,
        max_seq_len=max_seq_len,
        max_markdown_seq_len=max_markdown_seq_len,
        max_tokens_per_cell=max_tokens_per_cell,
        max_global_tokens_per_notebook=max_global_tokens_per_notebook,
    )

    source_dataset = datasets.Dataset.from_pandas(df)
    feature_dataset = source_dataset.map(
        featurize,
        remove_columns=source_dataset.column_names,
        desc='Running tokenizer on raw dataset',
    )
    feature_dataset.set_format(type='numpy')

    # A 0 in the final attention-mask slot means the example ends in padding,
    # i.e. it did not use the full sequence length.
    attention_masks = np.array(feature_dataset['attention_mask'])
    ends_in_padding = attention_masks[:, -1] == 0
    print('Empty sentences ratio:', ends_in_padding.sum()/len(feature_dataset))
    return feature_dataset

In [None]:
# Build the tokenized test dataset from `test_df`; every length limit is read
# from the `HP` hyperparameter namespace.
test_hf_dataset = prepare_hf_dataset.build_hf_dataset(
    df=test_df, 
    tokenizer=tokenizer, 
    max_seq_len=HP.max_seq_len,
    max_markdown_seq_len=HP.max_markdown_seq_len,
    max_global_tokens_per_notebook=HP.max_global_tokens_per_notebook,
    max_tokens_per_cell=HP.max_tokens_per_cell,
)

In [None]:
def collate_fn(batch):
    '''Stack per-example features into `torch.long` tensors, trimming shared padding.

    Every example in `batch` must provide `num_pad_tokens` and the five
    per-token features listed in `feature_keys`. The common padded length is
    read from the first example's `input_ids`. The batch is shortened by the
    smallest `num_pad_tokens` found in it, so only positions that count as
    padding for every example are removed (assuming padding sits at the end
    of each sequence).

    Returns:
        dict mapping each feature key to a tensor built from the truncated
        per-example sequences.
    '''
    feature_keys = (
        'input_ids',
        'attention_mask',
        'global_attention_mask',
        'markdown_token_mask',
        'code_token_mask',
    )

    shared_padding = min(example['num_pad_tokens'] for example in batch)
    padded_len = len(batch[0]['input_ids'])
    keep = padded_len - shared_padding

    return {
        key: torch.tensor([example[key][:keep] for example in batch], dtype=torch.long)
        for key in feature_keys
    }

# Serve the test set in dataset order (shuffle=False) and keep the final,
# possibly smaller, batch (drop_last=False). Batches are assembled by
# `collate_fn`, which trims padding shared by all examples in a batch.
test_dataloader = torch.utils.data.DataLoader(
    dataset=test_hf_dataset,
    batch_size=HP.eval_batch_size,
    shuffle=False,
    num_workers=2,
    collate_fn=collate_fn,
    pin_memory=True,
    drop_last=False,
    prefetch_factor=4,
)

# Fetch a single batch up front; this creates a fresh iterator, so it does not
# consume anything from later loops over `test_dataloader`.
batch = next(iter(test_dataloader))

In [None]:
_ = model.eval()
CELL_SEP = '[CELL_SEP]'

# Maps cell id -> predicted rank, where the rank is the mean of the model's
# per-token predictions over all tokens belonging to that cell.
# NOTE(review): keys are bare cell ids, so the same id appearing in two
# notebooks would overwrite the earlier entry -- confirm ids are unique
# across the whole test set.
cell_id_to_pred_rank = {}

# Offset of the current batch's first example inside `test_hf_dataset`.
# A running counter is used instead of `step * batch_size` because the last
# batch may be smaller than the others (drop_last=False); multiplying the step
# by that smaller size would point at the wrong rows. This relies on the
# dataloader iterating in dataset order (shuffle=False).
num_examples_seen = 0
for inputs in tqdm(test_dataloader, total=len(test_dataloader)):
    for k, v in inputs.items():
        inputs[k] = v.cuda()

    # Only the forward pass needs autocast / no_grad.
    with torch.cuda.amp.autocast(enabled=HP.mixed_precision), torch.no_grad():
        token_preds = model(
            input_ids=inputs['input_ids'],
            global_attention_mask=inputs['global_attention_mask'],
        )
    # Upcast before leaving the GPU: under autocast the outputs may be half
    # precision, and the per-cell sums below should not accumulate in fp16.
    token_preds = token_preds.detach().float().cpu().numpy()

    batch_size = inputs['input_ids'].shape[0]
    start = num_examples_seen
    end = start + batch_size
    num_examples_seen = end

    # Fetch only the rows of this batch rather than materialising the whole
    # column on every step.
    batch_rows = test_hf_dataset[start:end]
    notebook_ids = batch_rows['notebook_id']
    token_cell_indices = batch_rows['token_cell_indices']

    for example_idx, notebook_id in enumerate(notebook_ids):
        notebook_row = test_df[test_df.notebook_id==notebook_id].iloc[0]

        cell_ids = notebook_row.merged_cell_ids.split(CELL_SEP)
        example_token_preds = token_preds[example_idx]
        example_token_cell_indices = token_cell_indices[example_idx]

        cell_idx_to_sum_preds = defaultdict(int)
        cell_idx_to_num_preds = defaultdict(int)
        # `zip` stops at the shorter sequence: predictions cover the trimmed
        # batch length while the stored indices cover the full padded length.
        for token_pred, cell_idx in zip(example_token_preds, example_token_cell_indices):
            # Negative indices mark padding, which only occurs at the tail.
            if cell_idx < 0:
                break
            cell_idx_to_sum_preds[cell_idx] += token_pred
            cell_idx_to_num_preds[cell_idx] += 1

        for cell_idx, sum_pred in cell_idx_to_sum_preds.items():
            num_pred = cell_idx_to_num_preds[cell_idx]
            cell_id_to_pred_rank[cell_ids[cell_idx]] = sum_pred / num_pred

In [None]:
sample_sub = pd.read_csv('/kaggle/input/AI4Code/sample_submission.csv')

# For every notebook, order its cells by predicted rank. A cell that has no
# prediction falls back to its relative position in `merged_cell_ids`.
predicted_cell_orders = []
for merged_cell_ids in test_df['merged_cell_ids']:
    cell_ids = merged_cell_ids.split(CELL_SEP)
    num_cells = len(cell_ids)
    ranked_cells = [
        (cell_id_to_pred_rank.get(cell_id, position/num_cells), cell_id)
        for position, cell_id in enumerate(cell_ids)
    ]
    # Sort on the rank only; the stable sort keeps ties in their input order.
    ranked_cells.sort(key=lambda ranked: ranked[0])
    predicted_cell_orders.append(' '.join(cell_id for _, cell_id in ranked_cells))

test_df['cell_order'] = predicted_cell_orders

# When HP.hide_lb_score is set, every predicted order is written reversed.
if HP.hide_lb_score:
    test_df['cell_order'] = test_df['cell_order'].apply(
        lambda cell_order: ' '.join(reversed(cell_order.split()))
    )

test_df['id'] = test_df['notebook_id']
test_df[['id', 'cell_order']].to_csv('submission.csv', index=False)

In [None]:
# Display the test dataframe as the cell output for a visual check.
test_df