## 🧑‍💻 __AI4Code DeBERTa Train & Infer__

---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a> | <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#model'> 🧠 Model </a>  | <a href='#training'> ⚡ Training </a> 


In [None]:
# Sync Notebook with VS Code #
# Make the project packages importable both from the Kaggle dataset mounts
# (/kaggle/input/...) and from relative checkouts in the working directory.
import sys
sys.path.append('/kaggle/input/github-ai4code/fast-nlp')
sys.path.append('/kaggle/input/github-ai4code/ai4code')
sys.path.append('/kaggle/input/omegaconf')
sys.path.extend(['fast-nlp', 'ai4code'])
# NOTE(review): star import; presumably this is where names used later without an
# explicit import (HP, pd, np, torch, tqdm, partial, AutoModel, ...) come from — confirm.
from src import *

import ai4c
import ai4c.process_df
from ai4c.submission import CELL_SEP, IS_INTERACTIVE

# IPython shell escape: creates an empty __init__.py in the working directory.
# NOTE(review): presumably so the %%writefile modules below are importable — confirm it is still needed.
!touch __init__.py

# ⚙️ Hyperparameters ⚙️
---
### <a href='#data-factory'> ⚒ Data Factory </a>  | <a href='#model'> 🧠 Model </a>  | <a href='#training'> ⚡ Training </a> 

<a name='hyperparameters'></a>

In [None]:
%%hyperparameters

## Huggingface Backbone ##
backbone_name: 'microsoft/deberta-v3-large'
# Sub-directory of /kaggle/input that holds the downloaded backbone checkpoints.
backbone_dir: 'deberta-backbones'

# Dropout overrides passed to AutoModel.from_pretrained.
attention_probs_dropout_prob: 0.05
hidden_dropout_prob: 0.10

# NOTE(review): the backbone cell leaves this override commented out.
max_relative_positions: 512
# Total token budget per notebook (markdown + code cells together).
max_seq_len: 1280

## Data Factory ##
# Token budget for markdown cells; code cells get max_seq_len minus the markdown tokens actually used.
max_markdown_seq_len: 1024
train_folds: [1, 2]
valid_fold: 0
# Lower bound on the per-cell truncation length passed to the tokenizer.
max_tokens_per_cell: 256

## Model Training ##
num_train_epochs: 1
train_batch_size: 1
eval_batch_size: 4

# Micro-batches per optimizer update (effective batch = train_batch_size * this).
gradient_accumulation_steps: 16
mixed_precision: True

## Loss Function ##
# Weight of the markdown-cell loss; the code-cell loss gets (1 - this).
markdown_cell_loss_weight: 0.50

## Cosine Decay LR Scheduler ##
# Fraction of all optimizer updates used for warmup.
warmup_ratio: 0.125
learning_rate: 1e-4

## AdamW Optimizer ##
weight_decay: 1e-3
# Very large, so gradient-norm clipping will rarely (if ever) trigger.
max_grad_norm: 1000000.0
adam_epsilon: 1e-6

## Logging ##
# Print detailed stats every N optimizer updates.
logging_freq: 100
# NOTE(review): not referenced anywhere in the visible notebook.
checkpoint_freq: 1000

## Inference ##
# When True, the submission writes every predicted cell order reversed.
hide_lb_score: False

In [None]:
# Derive the on-disk checkpoint directory from the configured Hugging Face id,
# e.g. 'microsoft/deberta-v3-large' -> 'deberta_v3_large'. Previously this name was
# hard-coded, so changing HP.backbone_name had no effect on what was loaded.
backbone_code = HP.backbone_name.split('/')[-1].replace('-', '_')
backbone_dir = f'/kaggle/input/{HP.backbone_dir}/{backbone_code}'

backbone = AutoModel.from_pretrained(
    backbone_dir,
    attention_probs_dropout_prob=HP.attention_probs_dropout_prob,
    hidden_dropout_prob=HP.hidden_dropout_prob,
    # NOTE: HP.max_relative_positions is currently not forwarded to the backbone.
    # max_relative_positions=HP.max_relative_positions,
    # Two token types: the feature builder alternates token_type_ids 0/1 between cells.
    type_vocab_size=2,
)
tokenizer = AutoTokenizer.from_pretrained(backbone_dir, use_fast=True)

In [None]:
notebooks_df = pd.read_csv('/kaggle/input/ai4code-dataframes/notebooks_df.csv')
train_df = notebooks_df[notebooks_df.notebook_fold.isin(HP.train_folds)]
valid_df = notebooks_df[notebooks_df.notebook_fold == HP.valid_fold]

# Use the competition test set when it is mounted; otherwise fall back to the
# validation fold (ground-truth ranks dropped) so the inference cells still run.
if Path('/kaggle/input/AI4Code').exists():
    cell_df_test = ai4c.process_df.build_cell_df('/kaggle/input/AI4Code/test')
    test_df = ai4c.process_df.build_notebooks_df(cell_df_test)
else:
    test_df = valid_df.drop(columns=['merged_cell_pct_ranks'])

# Seeded sampling keeps the validation subset (and hence the tau reported and
# written into the checkpoint name) reproducible across runs. min() avoids a
# ValueError when a fold is smaller than the requested sample size.
SAMPLE_SEED = 42
valid_df = valid_df.sample(min(100, len(valid_df)), random_state=SAMPLE_SEED)
# A tiny test set presumably means a quick run: train on a small subset only.
if len(test_df) < 100:
    train_df = train_df.sample(min(300, len(train_df)), random_state=SAMPLE_SEED)

In [None]:
%%writefile data_module.py

from collections import defaultdict
from tqdm.auto import tqdm 
from pathlib import Path
from time import time
import pandas as pd

import transformers
import datasets 

import torch
import ai4c

# Separator used to join per-cell values inside the merged_* dataframe columns.
# NOTE(review): the notebook also imports CELL_SEP from ai4c.submission; presumably the same string — confirm.
CELL_SEP = '[CELL_SEP]'
# Registers tqdm's progress_apply / progress_map on pandas objects.
tqdm.pandas()

def prune_cell_tokens(cell_token_ids, max_seq_len):
    '''Trim per-cell token lists so their combined length fits in ``max_seq_len``.

    Tokens are removed from the end of the currently longest cell, repeatedly,
    so long cells are shortened first and short cells stay intact whenever
    possible.

    Args:
        cell_token_ids: list of token-id lists, one per cell.
        max_seq_len: maximum total number of tokens across all cells.

    Returns:
        A new list of token-id lists (same order and length as the input) whose
        total length is ``<= max_seq_len`` for any non-negative ``max_seq_len``.
    '''
    import heapq

    cell_token_counts = [len(token_ids) for token_ids in cell_token_ids]
    tokens_left_to_prune = max(sum(cell_token_counts) - max_seq_len, 0)
    tokens_to_prune_per_cell = [0] * len(cell_token_counts)

    while tokens_left_to_prune > 0:
        # Largest and runner-up counts. Padding with zeros handles a single cell
        # (the runner-up is then 0), where ``sorted(...)[-2]`` would raise IndexError.
        largest, second_largest = (heapq.nlargest(2, cell_token_counts) + [0, 0])[:2]
        if largest == 0:
            break  # nothing left to remove (only reachable when max_seq_len < 0)

        # Shrink the first longest cell to one token below the runner-up, so the
        # next iteration picks another cell. Never remove more than is still
        # required, or more than the cell holds.
        cell_idx = cell_token_counts.index(largest)
        num_tokens_to_pop = min(largest - second_largest + 1, tokens_left_to_prune, largest)
        tokens_to_prune_per_cell[cell_idx] += num_tokens_to_pop
        cell_token_counts[cell_idx] -= num_tokens_to_pop
        tokens_left_to_prune -= num_tokens_to_pop

    # Remove the pruned tokens from the end of each cell.
    return [
        token_ids[:len(token_ids) - num_pruned]
        for token_ids, num_pruned in zip(cell_token_ids, tokens_to_prune_per_cell)
    ]


def convert_to_features_torch_rankencoder(notebook, tokenizer, max_seq_len, max_markdown_seq_len, max_tokens_per_cell): 
    '''Tokenize one notebook and convert it to token-level model features.

    Markdown cells come first, then code cells. Each cell is tokenized
    separately and the results are concatenated into a single sequence.

    Args:
        notebook: mapping-like row (e.g. a pandas Series) with
            ``merged_markdown_cell_sources`` / ``merged_code_cell_sources`` joined
            by ``CELL_SEP``. When ``merged_cell_pct_ranks`` is present, the row is
            treated as labelled and ``merged_markdown_cell_pct_ranks`` /
            ``merged_code_cell_pct_ranks`` are read as well.
        tokenizer: callable tokenizer whose result has an ``input_ids`` entry.
        max_seq_len: total token budget for the notebook.
        max_markdown_seq_len: token budget for the markdown cells.
        max_tokens_per_cell: lower bound for the per-cell truncation length.

    Returns:
        dict with per-token lists:
        - ``input_ids``;
        - ``token_type_ids``, alternating 0/1 from cell to cell;
        - ``token_cell_ids``, 1.0 for markdown and 2.0 for code;
        - for labelled rows only, ``token_labels`` (the cell's pct rank) and
          ``token_weights`` (1 / tokens-in-cell, so every cell's weights sum to 1).
    '''
    # A single flag decides both whether ranks are parsed and whether label
    # features are returned, so the two cannot disagree.
    has_labels = 'merged_cell_pct_ranks' in notebook

    markdown_cell_sources = notebook['merged_markdown_cell_sources'].split(CELL_SEP)
    code_cell_sources = notebook['merged_code_cell_sources'].split(CELL_SEP)
    if has_labels:
        markdown_cell_pct_ranks = [float(rank) for rank in notebook['merged_markdown_cell_pct_ranks'].split(CELL_SEP)]
        code_cell_pct_ranks = [float(rank) for rank in notebook['merged_code_cell_pct_ranks'].split(CELL_SEP)]
    else:
        markdown_cell_pct_ranks = [None] * len(markdown_cell_sources)
        code_cell_pct_ranks = [None] * len(code_cell_sources)

    # Drop cells from the end so every kept cell still fits its budget with at
    # least a couple of tokens (hence the // 2).
    max_markdown_cells = max_markdown_seq_len // 2
    max_code_cells = (max_seq_len - max_markdown_seq_len) // 2
    markdown_cell_sources = markdown_cell_sources[:max_markdown_cells]
    markdown_cell_pct_ranks = markdown_cell_pct_ranks[:max_markdown_cells]
    code_cell_sources = code_cell_sources[:max_code_cells]
    code_cell_pct_ranks = code_cell_pct_ranks[:max_code_cells]
    markdown_cell_count = len(markdown_cell_sources)
    code_cell_count = len(code_cell_sources)

    # Markdown cells: tokenize, then prune to the markdown budget.
    # max(..., 1) guards the division if a cell list ends up empty.
    max_tokens_per_markdown_cell = max(max_tokens_per_cell, max_markdown_seq_len // max(markdown_cell_count, 1))
    markdown_cell_token_ids = tokenizer(
        markdown_cell_sources,
        max_length=max_tokens_per_markdown_cell,
        truncation=True,
    )['input_ids']
    markdown_cell_token_ids = prune_cell_tokens(markdown_cell_token_ids, max_markdown_seq_len)
    total_markdown_cell_tokens = sum(len(token_ids) for token_ids in markdown_cell_token_ids)

    # Code cells get whatever the markdown cells left of the overall budget.
    max_code_seq_len = max_seq_len - total_markdown_cell_tokens
    max_tokens_per_code_cell = max(max_tokens_per_cell, max_code_seq_len // max(code_cell_count, 1))
    code_cell_token_ids = tokenizer(
        code_cell_sources,
        max_length=max_tokens_per_code_cell,
        truncation=True,
    )['input_ids']
    code_cell_token_ids = prune_cell_tokens(code_cell_token_ids, max_code_seq_len)

    # Merge the tokenized cells and expand per-cell values to per-token values.
    cell_token_ids = markdown_cell_token_ids + code_cell_token_ids
    cell_pct_ranks = markdown_cell_pct_ranks + code_cell_pct_ranks

    input_ids, token_type_ids, token_cell_ids = [], [], []
    token_labels, token_weights = [], []
    for cell_idx, token_ids in enumerate(cell_token_ids):
        token_count = len(token_ids)
        input_ids += token_ids
        # 1.0 marks markdown tokens, 2.0 code tokens.
        token_cell_ids += [1.0 if cell_idx < markdown_cell_count else 2.0] * token_count
        # Alternating 0/1 lets downstream code recover cell boundaries.
        token_type_ids += [cell_idx % 2] * token_count
        token_labels += [cell_pct_ranks[cell_idx]] * token_count
        # Weights sum to 1 per cell, so long cells do not dominate a weighted loss.
        token_weights += [1 / token_count] * token_count

    notebook_features = {
        'input_ids': input_ids,
        'token_type_ids': token_type_ids,
        'token_cell_ids': token_cell_ids,
    }
    if has_labels:
        notebook_features['token_labels'] = token_labels
        notebook_features['token_weights'] = token_weights
    return notebook_features

class AI4CodeDataset(torch.utils.data.Dataset):
    '''Map-style dataset that tokenizes one notebook row per item, on the fly.

    For a dataframe without a ``merged_cell_pct_ranks`` column, each item is an
    ``inputs`` dict (``input_ids``, ``token_type_ids``, ``token_cell_ids``).
    Otherwise, each item is an ``(inputs, labels)`` tuple, where ``labels`` holds
    ``token_labels`` and ``token_weights``.
    '''

    def __init__(
        self,
        df,
        tokenizer,
        max_seq_len,
        max_markdown_seq_len,
        max_tokens_per_cell,
    ):
        self.df = df
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.max_markdown_seq_len = max_markdown_seq_len
        self.max_tokens_per_cell = max_tokens_per_cell
        # Unlabelled (test) frames lack the ground-truth rank column.
        self.is_test = 'merged_cell_pct_ranks' not in df.columns

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        features = convert_to_features_torch_rankencoder(
            notebook=self.df.iloc[idx],
            tokenizer=self.tokenizer,
            max_seq_len=self.max_seq_len,
            max_markdown_seq_len=self.max_markdown_seq_len,
            max_tokens_per_cell=self.max_tokens_per_cell,
        )
        inputs = {key: features[key] for key in ('input_ids', 'token_type_ids', 'token_cell_ids')}
        if self.is_test:
            return inputs
        labels = {key: features[key] for key in ('token_labels', 'token_weights')}
        return inputs, labels
    
    
def train_collate_fn(batch, tokenizer):
    '''Collate ``(inputs, labels)`` examples into right-padded tensors.

    ``input_ids`` and ``attention_mask`` are padded by ``tokenizer.pad`` to a
    multiple of 8. The other per-token features are right-padded to that same
    length with -100 (``token_type_ids``, ``token_cell_ids``, ``token_labels``)
    or 0 (``token_weights``), so padded positions can be masked out.

    Returns:
        ``(batch_inputs, batch_labels)``: dicts of long tensors for the inputs
        and float tensors for the labels.
    '''
    padded = tokenizer.pad(
        [{'input_ids': example_inputs['input_ids']} for example_inputs, _ in batch],
        pad_to_multiple_of=8,
        padding=True,
    )
    target_len = len(padded['input_ids'][0])
    fill_values = {
        'token_type_ids': -100.0,
        'token_cell_ids': -100.0,
        'token_labels': -100.0,
        'token_weights': 0.0,
    }

    def right_pad(name, values):
        return list(values) + [fill_values[name]] * (target_len - len(values))

    collected_inputs = {'token_type_ids': [], 'token_cell_ids': []}
    collected_labels = {'token_weights': [], 'token_labels': []}
    for example_inputs, example_labels in batch:
        for source, sink in ((example_inputs, collected_inputs), (example_labels, collected_labels)):
            for name, values in source.items():
                if name in fill_values:
                    sink[name].append(right_pad(name, values))

    batch_inputs = {
        'input_ids': torch.tensor(padded['input_ids'], dtype=torch.long),
        'attention_mask': torch.tensor(padded['attention_mask'], dtype=torch.long),
        'token_type_ids': torch.tensor(collected_inputs['token_type_ids'], dtype=torch.long),
        'token_cell_ids': torch.tensor(collected_inputs['token_cell_ids'], dtype=torch.long),
    }
    batch_labels = {
        name: torch.tensor(collected_labels[name], dtype=torch.float)
        for name in ('token_weights', 'token_labels')
    }
    return batch_inputs, batch_labels


def test_collate_fn(batch, tokenizer):
    batch_input_ids = [{'input_ids': inputs['input_ids']} for inputs in batch]
    padded_batch = tokenizer.pad(
        batch_input_ids,
        pad_to_multiple_of=8,
        padding=True, 
    )
    feature_to_pad_token_id = {
        'token_type_ids': -100.0,
        'token_cell_ids': -100.0,
        'token_weights': 0.0,
    }
    batch_inputs = {
        'input_ids': padded_batch['input_ids'],
        'attention_mask': padded_batch['attention_mask'],
        'token_type_ids': [],
        'token_cell_ids': [],
    }
    
    batch_sequence_length = torch.tensor(padded_batch['input_ids']).shape[1]
    for example_inputs in batch:
        for feature, value in example_inputs.items():
            if feature not in feature_to_pad_token_id:
                continue
            num_tokens_to_pad = batch_sequence_length-len(value)
            pad_token_id = feature_to_pad_token_id[feature]
            batch_inputs[feature].append(list(value) + [pad_token_id]*num_tokens_to_pad)
    
    batch_inputs = {
        'input_ids': torch.tensor(batch_inputs['input_ids'], dtype=torch.long),
        'attention_mask': torch.tensor(batch_inputs['attention_mask'], dtype=torch.long),
        'token_type_ids': torch.tensor(batch_inputs['token_type_ids'], dtype=torch.long),
        'token_cell_ids': torch.tensor(batch_inputs['token_cell_ids'], dtype=torch.long),
    }
    return batch_inputs

In [None]:
import data_module

# Both datasets share the tokenizer and the per-notebook token budgets.
_dataset_kwargs = dict(
    tokenizer=tokenizer,
    max_seq_len=HP.max_seq_len,
    max_markdown_seq_len=HP.max_markdown_seq_len,
    max_tokens_per_cell=HP.max_tokens_per_cell,
)
train_dataset = data_module.AI4CodeDataset(df=train_df, **_dataset_kwargs)
eval_dataset = data_module.AI4CodeDataset(df=valid_df, **_dataset_kwargs)

# The validation fold is labelled, so both loaders use the labelled collate function.
train_collate_fn = partial(data_module.train_collate_fn, tokenizer=tokenizer)
_loader_kwargs = dict(
    num_workers=2,
    collate_fn=train_collate_fn,
    pin_memory=True,
    prefetch_factor=4,
)
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=HP.train_batch_size,
    shuffle=True,
    drop_last=True,
    **_loader_kwargs,
)
eval_dataloader = torch.utils.data.DataLoader(
    dataset=eval_dataset,
    batch_size=HP.eval_batch_size,
    shuffle=False,
    drop_last=False,
    **_loader_kwargs,
)

# Pull one training batch (presumably a smoke test of tokenization + collation).
batch = next(iter(train_dataloader))

In [None]:
%%writefile torch_model.py

import transformers
import torch.nn as nn
import torch

class AI4CodeModel(nn.Module):
    '''Transformer backbone plus a linear head that predicts one rank score per token.'''

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        # Projects each token's final hidden state to a single scalar prediction.
        self.ranker = nn.Linear(backbone.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask, token_type_ids):
        '''Return per-token predictions of shape ``(batch_size, seq_len)``.

        ``attention_mask`` is forwarded so the padding added by the collate
        functions is not attended to. Otherwise the predictions for real tokens
        would depend on how much padding shares their batch.

        ``token_type_ids`` is accepted for interface compatibility but not
        forwarded: the collate functions fill padded positions with -100, which
        is not a valid embedding index.
        '''
        backbone_outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        token_preds = self.ranker(backbone_outputs.last_hidden_state)
        # (batch, seq_len, 1) -> (batch, seq_len)
        return token_preds.squeeze(-1)

def get_optimizer_grouped_parameters(model, weight_decay):
    '''Split ``model``'s parameters into weight-decayed and non-decayed groups.

    Parameters whose name contains ``"bias"`` or ``"LayerNorm.weight"`` get no
    weight decay; every other parameter uses ``weight_decay``. The result is
    suitable as the ``params`` argument of a torch optimizer.
    '''
    exempt_markers = ("bias", "LayerNorm.weight")
    decayed, undecayed = [], []
    for name, param in model.named_parameters():
        is_exempt = any(marker in name for marker in exempt_markers)
        (undecayed if is_exempt else decayed).append(param)
    return [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": undecayed, "weight_decay": 0.0},
    ]

In [None]:
import torch_model

model = torch_model.AI4CodeModel(backbone)
# Move to the GPU before building the optimizer, as torch.optim recommends.
_ = model.cuda()

# Optimize the whole model. Passing only model.backbone would leave the freshly
# initialised ranker head out of the optimizer, so it would never be trained.
optimizer = torch.optim.AdamW(
    params=torch_model.get_optimizer_grouped_parameters(model, HP.weight_decay),
    eps=HP.adam_epsilon,
    lr=HP.learning_rate,
)
# One optimizer update per `gradient_accumulation_steps` micro-batches, rounded up
# so a trailing partial accumulation window counts as one update.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / HP.gradient_accumulation_steps)
total_train_steps = HP.num_train_epochs * num_update_steps_per_epoch

lr_scheduler = transformers.get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=int(HP.warmup_ratio * total_train_steps),
    num_training_steps=total_train_steps,
)

torch.cuda.empty_cache()
gc.collect()

# ⚡ Training ⚡
---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#model'> 🧠 Model </a>

<a name='training'></a>

In [None]:
def mse_loss(token_labels, token_preds, token_weights, mask):
    """
    Weighted, masked squared error: averaged per notebook, then over the batch.

    Tokens where ``mask`` is 0 contribute nothing. Within each row, squared
    errors are averaged using ``token_weights`` (a 1e-8 term guards rows whose
    masked weights sum to zero, which then yield 0).

    Args:
        token_labels: tensor of shape (batch_size, seq_len)
        token_preds: tensor of shape (batch_size, seq_len)
        token_weights: tensor of shape (batch_size, seq_len)
        mask: tensor of shape (batch_size, seq_len)
    Returns:
        scalar tensor: mean of the per-notebook weighted errors.
    """
    masked_weights = token_weights * mask
    squared_error = (token_labels * mask - token_preds * mask) ** 2
    per_notebook = (squared_error * masked_weights).sum(dim=-1) / (masked_weights.sum(dim=-1) + 1e-8)
    return per_notebook.mean()

progress_bar = tqdm(range(total_train_steps), desc="Training Progress")
completed_batches = 0  # optimizer updates performed so far
_ = model.train()
scaler = torch.cuda.amp.GradScaler(enabled=HP.mixed_precision)
model.zero_grad(set_to_none=True)

num_batches_per_epoch = len(train_dataloader)
# Unscale and clip exactly the parameters the optimizer updates.
optimized_params = [p for group in optimizer.param_groups for p in group['params']]


def _zeroed_batch_metrics():
    '''Fresh accumulators for the micro-batches of a single optimizer update.'''
    return {
        'loss': 0.0,
        'markdown_cell_loss': 0.0,
        'code_cell_loss': 0.0,
        'num_markdown_cell_tokens': 0.0,
        'num_code_cell_tokens': 0.0,
        'num_pad_tokens': 0.0,
    }


batch_metrics = _zeroed_batch_metrics()
micro_batches_in_update = 0

for epoch in range(HP.num_train_epochs):
    epoch_metrics = {
        'loss': 0.0,
        'markdown_cell_loss': 0.0,
        'code_cell_loss': 0.0
    }
    for step, batch in enumerate(train_dataloader):
        inputs, labels = batch
        inputs = {k: v.cuda(non_blocking=True) for k, v in inputs.items()}
        labels = {k: v.cuda(non_blocking=True) for k, v in labels.items()}

        with torch.cuda.amp.autocast(enabled=HP.mixed_precision):
            token_preds = model(
                input_ids=inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                token_type_ids=inputs['token_type_ids'],
            )

            # Padding has token_cell_id -100, so it falls in neither mask.
            markdown_cell_mask = torch.where(inputs['token_cell_ids'] == 1.0, 1.0, 0.0)
            markdown_cell_loss = mse_loss(
                token_labels=labels['token_labels'],
                token_preds=token_preds,
                token_weights=labels['token_weights'],
                mask=markdown_cell_mask,
            )

            code_cell_mask = torch.where(inputs['token_cell_ids'] == 2.0, 1.0, 0.0)
            code_cell_loss = mse_loss(
                token_labels=labels['token_labels'],
                token_preds=token_preds,
                token_weights=labels['token_weights'],
                mask=code_cell_mask,
            )

            loss = HP.markdown_cell_loss_weight * markdown_cell_loss + (1 - HP.markdown_cell_loss_weight) * code_cell_loss

        loss_value = loss.item()
        markdown_loss_value = markdown_cell_loss.item()
        code_loss_value = code_cell_loss.item()
        for metrics in (epoch_metrics, batch_metrics):
            metrics['loss'] += loss_value
            metrics['markdown_cell_loss'] += markdown_loss_value
            metrics['code_cell_loss'] += code_loss_value
        batch_metrics['num_markdown_cell_tokens'] += torch.mean(torch.sum(markdown_cell_mask, dim=-1)).item()
        batch_metrics['num_code_cell_tokens'] += torch.mean(torch.sum(code_cell_mask, dim=-1)).item()
        batch_metrics['num_pad_tokens'] += torch.mean(torch.sum(1.0 - inputs['attention_mask'], dim=-1)).item()
        micro_batches_in_update += 1

        # Scale down so the accumulated gradient is an average over the
        # accumulation window, not a sum (PyTorch AMP accumulation recipe).
        scaler.scale(loss / HP.gradient_accumulation_steps).backward()

        # Update every `gradient_accumulation_steps` micro-batches, and also on the
        # epoch's last batch so a trailing partial window is not discarded. This
        # matches the ceil() used to size the LR schedule.
        is_last_batch = (step + 1) == num_batches_per_epoch
        if (step + 1) % HP.gradient_accumulation_steps != 0 and not is_last_batch:
            continue

        # Unscale before clipping, so max_grad_norm and the reported norm refer to
        # the true gradients rather than ones multiplied by the AMP loss scale.
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(optimized_params, HP.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()
        # Zero every model grad, including any parameter not owned by the optimizer.
        model.zero_grad(set_to_none=True)
        completed_batches += 1

        progress_bar.set_postfix(
            loss=f"{batch_metrics['loss']/micro_batches_in_update:.4f}",
            markdown_cell_loss=f"{batch_metrics['markdown_cell_loss']/micro_batches_in_update:.4f}",
            code_cell_loss=f"{batch_metrics['code_cell_loss']/micro_batches_in_update:.4f}",
            grad_norm=f"{grad_norm:.4f}",
        )
        _ = progress_bar.update(1)

        if completed_batches % HP.logging_freq == 0:
            print(f"Batch #{completed_batches} | Pred shape: {token_preds.shape}")
            print(f"Loss: {loss.item()} | Markdown Cell Loss: {markdown_cell_loss.item()} | Code Cell Loss: {code_cell_loss.item()}")
            print(f"Gradient Norm: {grad_norm.item()}")
            print(f"Learning Rate: {lr_scheduler.get_last_lr()[0]}")
            print("-"*50)

        # Reset after every update, so the progress bar shows this update's average.
        batch_metrics = _zeroed_batch_metrics()
        micro_batches_in_update = 0

    for k, v in epoch_metrics.items():
        print(f"{colored(k, 'blue')}: {colored(v / max(num_batches_per_epoch, 1), 'red')}")

## 🛒 Validation
---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#model'> 🧠 Model </a>

<a name='validation'></a>

In [None]:
torch.cuda.empty_cache()
gc.collect()
_ = model.eval()


def _collect_cell_rank_preds(notebook_row, token_preds, token_type_ids, token_cell_ids, cell_id_to_rank_preds):
    '''Append each non-padding token's prediction to the list of the cell it belongs to.

    Cell boundaries are recovered from ``token_type_ids``, which alternate 0/1
    between consecutive cells, and ``token_cell_ids`` (1 = markdown, 2 = code).

    ``notebook_row.merged_cell_ids`` is assumed to list the markdown cells
    first and then the code cells, in the same order as the merged sources.
    NOTE(review): confirm this layout.

    If the feature builder dropped trailing markdown cells, the first code token
    jumps past all of them, so code predictions still land on code cell ids.
    '''
    cell_ids = notebook_row.merged_cell_ids.split(CELL_SEP)
    markdown_cell_count = len(notebook_row.merged_markdown_cell_sources.split(CELL_SEP))
    cell_idx, prev_type_id, in_code_cells = -1, None, False
    for token_pred, type_id, cell_kind in zip(token_preds, token_type_ids, token_cell_ids):
        if type_id == -100:  # padding
            continue
        if cell_kind == 2 and not in_code_cells:
            in_code_cells = True
            cell_idx = markdown_cell_count
        elif type_id != prev_type_id:
            cell_idx += 1
        prev_type_id = type_id
        if cell_idx < len(cell_ids):
            cell_id_to_rank_preds[cell_ids[cell_idx]].append(token_pred)


epoch_metrics = {
    'loss': 0.0, 
    'markdown_cell_loss': 0.0, 
    'code_cell_loss': 0.0
}
cell_id_to_rank_preds = defaultdict(list)
notebook_idx = 0
for step, (inputs, labels) in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader)):
    inputs = {k: v.cuda(non_blocking=True) for k, v in inputs.items()}
    labels = {k: v.cuda(non_blocking=True) for k, v in labels.items()}

    with torch.no_grad(), torch.cuda.amp.autocast(enabled=HP.mixed_precision):
        token_preds = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            token_type_ids=inputs['token_type_ids'],
        )
        markdown_cell_mask = torch.where(inputs['token_cell_ids'] == 1.0, 1.0, 0.0)
        markdown_cell_loss = mse_loss(
            token_labels=labels['token_labels'],
            token_preds=token_preds,
            token_weights=labels['token_weights'],
            mask=markdown_cell_mask,
        )

        code_cell_mask = torch.where(inputs['token_cell_ids'] == 2.0, 1.0, 0.0)
        code_cell_loss = mse_loss(
            token_labels=labels['token_labels'],
            token_preds=token_preds,
            token_weights=labels['token_weights'],
            mask=code_cell_mask,
        )
        loss = HP.markdown_cell_loss_weight * markdown_cell_loss + (1 - HP.markdown_cell_loss_weight) * code_cell_loss

    epoch_metrics['loss'] += loss.item()
    epoch_metrics['markdown_cell_loss'] += markdown_cell_loss.item()
    epoch_metrics['code_cell_loss'] += code_cell_loss.item()

    # Under autocast the head emits fp16. Cast to fp32 before averaging many
    # per-token predictions per cell, which is imprecise in fp16.
    token_preds = token_preds.float().cpu().numpy()
    # Copy the per-token ids to host lists once. Comparing CUDA tensor elements
    # one at a time forces a GPU sync for every token.
    batch_token_type_ids = inputs['token_type_ids'].cpu().tolist()
    batch_token_cell_ids = inputs['token_cell_ids'].cpu().tolist()
    for i, notebook_token_preds in enumerate(token_preds):
        _collect_cell_rank_preds(
            valid_df.iloc[notebook_idx + i],
            notebook_token_preds,
            batch_token_type_ids[i],
            batch_token_cell_ids[i],
            cell_id_to_rank_preds,
        )
    notebook_idx += len(token_preds)

num_eval_batches = max(len(eval_dataloader), 1)
for k, v in epoch_metrics.items():
    print(f"{colored(k, 'blue')}: {colored(v/num_eval_batches, 'red')}")

# Compute Kendall Tau for the predictions with the ground truth.
cell_id_to_pred_rank = {cell_id: float(np.mean(preds)) for cell_id, preds in cell_id_to_rank_preds.items()}
all_notebook_kendall_taus = []
for merged_ranks, merged_cell_ids in zip(valid_df.merged_cell_pct_ranks.values, valid_df.merged_cell_ids.values):
    true_cell_ranks = [float(rank) for rank in merged_ranks.split(CELL_SEP)]
    cell_ids = merged_cell_ids.split(CELL_SEP)
    # Cells without any token prediction keep their relative input position.
    pred_cell_ranks = [cell_id_to_pred_rank.get(cell_id, cell_idx / len(cell_ids)) for cell_idx, cell_id in enumerate(cell_ids)]
    notebook_tau = scipy.stats.kendalltau(true_cell_ranks, pred_cell_ranks, method='asymptotic')[0]
    all_notebook_kendall_taus.append(notebook_tau)
all_notebook_kendall_taus = np.array(all_notebook_kendall_taus)

valid_df['kendall_tau'] = all_notebook_kendall_taus
# kendalltau is NaN for degenerate notebooks (e.g. a single cell). Skip those
# instead of letting a single NaN turn the average into NaN.
avg_tau = np.nanmean(all_notebook_kendall_taus)
print(f"{colored('Kendall Tau', 'blue')}: {colored(avg_tau, 'red')}")

for cutoff in [4, 16, 64]:
    tau = valid_df[valid_df.markdown_cell_count > cutoff].kendall_tau.mean()
    print(f"Kendall Tau for notebooks with more than {colored(cutoff, 'yellow')} markdown cells:", colored(tau, 'red'))

valid_df.to_csv('valid_df.csv', index=False)

In [None]:
def _tau_tag(tau, scale):
    '''Render a Kendall tau as an integer tag for a file name; NaN/inf becomes "nan".

    ``int(nan)`` raises ValueError, which happens e.g. when no sampled notebook
    has more than 32 markdown cells (the mean of an empty selection is NaN).
    '''
    return str(int(tau * scale)) if np.isfinite(tau) else 'nan'


# Tag scales (x10000 overall, x1000 for 32+) are kept so existing names stay comparable.
tau_32 = valid_df[valid_df.markdown_cell_count > 32].kendall_tau.mean()
model_save_path = f"{backbone_code}_tau{_tau_tag(avg_tau, 10000)}_tau32_{_tau_tag(tau_32, 1000)}.pt"
torch.save(model.state_dict(), model_save_path)
print(f'Model saved at {model_save_path}')

## 🎯 Inference
---
### <a href='#hyperparameters'> ⚙️ Hyperparameters </a>  | <a href='#training'> ⚡ Training </a>

<a name='inference'></a>

In [None]:
import data_module

# Same per-notebook token budgets as the training/validation datasets.
_feature_limits = {
    'max_seq_len': HP.max_seq_len,
    'max_markdown_seq_len': HP.max_markdown_seq_len,
    'max_tokens_per_cell': HP.max_tokens_per_cell,
}
test_dataset = data_module.AI4CodeDataset(test_df, tokenizer, **_feature_limits)

# Keep dataframe order (shuffle=False, drop_last=False): predictions are matched
# back to rows of test_df by position.
test_collate_fn = partial(data_module.test_collate_fn, tokenizer=tokenizer)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=HP.eval_batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=2,
    pin_memory=True,
    prefetch_factor=4,
    collate_fn=test_collate_fn,
)

In [None]:
torch.cuda.empty_cache()
gc.collect()

_ = model.eval()


def _collect_cell_rank_preds(notebook_row, token_preds, token_type_ids, token_cell_ids, cell_id_to_rank_preds):
    '''Append each non-padding token's prediction to the list of the cell it belongs to.

    Cell boundaries are recovered from ``token_type_ids``, which alternate 0/1
    between consecutive cells, and ``token_cell_ids`` (1 = markdown, 2 = code).

    ``notebook_row.merged_cell_ids`` is assumed to list the markdown cells
    first and then the code cells, in the same order as the merged sources.
    NOTE(review): confirm this layout.

    If the feature builder dropped trailing markdown cells, the first code token
    jumps past all of them, so code predictions still land on code cell ids.
    '''
    cell_ids = notebook_row.merged_cell_ids.split(CELL_SEP)
    markdown_cell_count = len(notebook_row.merged_markdown_cell_sources.split(CELL_SEP))
    cell_idx, prev_type_id, in_code_cells = -1, None, False
    for token_pred, type_id, cell_kind in zip(token_preds, token_type_ids, token_cell_ids):
        if type_id == -100:  # padding
            continue
        if cell_kind == 2 and not in_code_cells:
            in_code_cells = True
            cell_idx = markdown_cell_count
        elif type_id != prev_type_id:
            cell_idx += 1
        prev_type_id = type_id
        if cell_idx < len(cell_ids):
            cell_id_to_rank_preds[cell_ids[cell_idx]].append(token_pred)


cell_id_to_rank_preds = defaultdict(list)
notebook_idx = 0
for step, inputs in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
    inputs = {k: v.cuda(non_blocking=True) for k, v in inputs.items()}

    with torch.no_grad(), torch.cuda.amp.autocast(enabled=HP.mixed_precision):
        token_preds = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            token_type_ids=inputs['token_type_ids'],
        )

    # Under autocast the head emits fp16. Cast to fp32 before averaging many
    # per-token predictions per cell, which is imprecise in fp16.
    token_preds = token_preds.float().cpu().numpy()
    # Copy the per-token ids to host lists once. Comparing CUDA tensor elements
    # one at a time forces a GPU sync for every token.
    batch_token_type_ids = inputs['token_type_ids'].cpu().tolist()
    batch_token_cell_ids = inputs['token_cell_ids'].cpu().tolist()
    for i, notebook_token_preds in enumerate(token_preds):
        _collect_cell_rank_preds(
            test_df.iloc[notebook_idx + i],
            notebook_token_preds,
            batch_token_type_ids[i],
            batch_token_cell_ids[i],
            cell_id_to_rank_preds,
        )
    notebook_idx += len(token_preds)
cell_id_to_pred_rank = {cell_id: float(np.mean(preds)) for cell_id, preds in cell_id_to_rank_preds.items()}

In [None]:
# Sort each notebook's cells by predicted rank and write submission.csv.
predicted_cell_orders = []
for merged_cell_ids in test_df.merged_cell_ids.values:
    cell_ids = merged_cell_ids.split(CELL_SEP)
    # Cells without any token prediction keep their relative input position.
    pred_cell_ranks = [
        cell_id_to_pred_rank.get(cell_id, cell_idx / len(cell_ids))
        for cell_idx, cell_id in enumerate(cell_ids)
    ]
    ordered_cell_ids = [cell_id for _, cell_id in sorted(zip(pred_cell_ranks, cell_ids), key=lambda pair: pair[0])]
    predicted_cell_orders.append(' '.join(ordered_cell_ids))
test_df['cell_order'] = predicted_cell_orders
if HP.hide_lb_score:
    # Deliberately reverse every order (presumably to hide the public leaderboard score).
    test_df.cell_order = test_df.cell_order.apply(lambda cell_order: ' '.join(cell_order.split()[::-1]))
test_df['id'] = test_df.notebook_id
test_df[['id', 'cell_order']].to_csv('submission.csv', index=False)