# Train With Colbert Q & A

In [1]:
import os
# Weights & Biases identifiers: both are passed to wandb.init() in the training cell,
# and run_name is also the sub-directory the final model is saved under.
project_name = "Science Exam - Kaggle"
run_name = "ensemble_with_270_wiki"

# Run-level settings. The whole dict is handed to wandb.init(config=...) so every key is
# logged with the run, including the ones no visible cell reads.
config = {
    # NOTE(review): 'max_words', 'nbits' and 'colbert_indexer_version' are not read by any
    # visible cell — presumably ColBERT indexing settings kept for logging; confirm.
    'max_words': 100,
    'nbits': 1,
    'colbert_indexer_version': 6,
    # Base checkpoint loaded by AutoModelForMultipleChoice in the training cell.
    'hf_model_id': 'microsoft/deberta-v3-large',
    # NOTE(review): not read by any visible cell; tokenisation uses the separate
    # module-level `max_length` instead.
    'deberta_max_length': 512,
    'large_train_deduped': True, # if this is True the individual ones below are skipped
    'use_train_daniel': False,
    'use_osmu_sci_6k': False,
    'use_osmu_21k': False,
    'use_mgoksu_13k': False,
    'use_gigkpea_3k': False,
    # Trailing rows of every training file that are held out as an extra eval split.
    'n_extra_test_rows' : 300,
    'lr_layer_factor': None, # set to an int to layer, None to not layer, need to be adapted to work with 8-bit Adam
    'learning_rate': 1e-5,
    'num_train_epochs': 1,
    # True selects bitsandbytes Adam8bit; that branch is checked first, so
    # 'lr_layer_factor' is ignored while this is True.
    '8_bit_adam': True
}

# CUDA caching-allocator settings, set here ahead of the later cell that imports torch.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'garbage_collection_threshold:0.6,max_split_size_mb:128'

from transformers import TrainingArguments, IntervalStrategy
from pathlib import Path

# Directory handed to the Trainer as output_dir; the final model goes to output_path/run_name.
output_path = Path('./checkpoints')
# Hugging Face Trainer settings. Learning rate and epoch count come from `config` so they
# are logged with the run; everything else is fixed here.
training_args = TrainingArguments(
    learning_rate=config['learning_rate'],
    num_train_epochs=config['num_train_epochs'],
    # fp16=True,
    # warmup_ratio=0.5,
    weight_decay=0,
    # 8 samples per device x 4 accumulation steps = 32 samples per optimizer update.
    per_device_train_batch_size=8,
    per_device_eval_batch_size=12,
    gradient_accumulation_steps=4, # number of batches to sum up gradient for
    gradient_checkpointing=True,
    # optim='adafactor',
    # NOTE(review): newer transformers releases rename this argument to `eval_strategy`;
    # confirm against the installed version before upgrading.
    evaluation_strategy = IntervalStrategy.STEPS,
    logging_steps=10,
    eval_steps=100,
    # No intermediate checkpoints; the training cell calls trainer.save_model() once at the end.
    save_strategy='no',
    # save_steps=20000,
    report_to='wandb',
    output_dir=str(output_path)
)

In [2]:
import numpy as np
import pandas as pd
from datasets import load_dataset, load_from_disk
from sklearn.feature_extraction.text import TfidfVectorizer
import pickle
import unicodedata

In [3]:
import os
import random
import glob
import math
import gc
import polars as pl
import polars.selectors as cs
import pandas as pd
import pyarrow as pa
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from dataclasses import dataclass
from typing import Optional, Union
from scipy.special import softmax
from functools import partial

from colbert.infra import Run, RunConfig, ColBERTConfig
from colbert.data import Queries
from colbert import Indexer, Searcher

import wandb

import torch
from torch import nn

import evaluate
from transformers import AutoTokenizer, AutoModelForMultipleChoice, Trainer, \
                        get_linear_schedule_with_warmup
from transformers.integrations import WandbCallback
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from transformers.utils.notebook import NotebookProgressCallback
from transformers.trainer_pt_utils import get_parameter_names
from datasets import Dataset # HuggingFace
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

import bitsandbytes as bnb

In [4]:
# True -> load the wiki datasets from ./data; False -> load them from the /kaggle/working paths.
is_local = True

In [5]:
# Number of trailing rows of each training file that read_training_files() holds out.
n_extra_test = config['n_extra_test_rows']
# Both containers are filled by read_training_files(): one training frame per source and
# one named held-out frame per source. They are re-created here so the cell is re-runnable.
train_sets = []
test_sets = {}

# data/train.csv is used as the main evaluation split ('test'). Built in one chain so that
# re-running the cell never renames or drops on an already-processed frame.
test = (
    pl.read_csv('data/train.csv')
    .rename({'prompt': 'question'})
    # Positional column name: accepted by both old and current polars, whereas the
    # `columns=` keyword is rejected by polars >= 1.0.
    .drop('id')
)
test_sets['test'] = test


def read_training_files(file_paths, test_name):
    """Load one training source and split off its trailing rows as an eval set.

    The files are concatenated, reduced to the question/option/answer columns, nulls are
    replaced by the string 'N/A', and the last ``n_extra_test`` rows (module-level value)
    are held out.

    Side effects: appends the training part to the module-level ``train_sets`` list and
    stores the held-out part in ``test_sets[test_name]``.

    Args:
        file_paths: a single path string or a list of path strings. The reader is chosen
            from the extension of the first path ('.parquet' or '.csv').
        test_name: key under which the held-out rows are stored in ``test_sets``.

    Raises:
        ValueError: if the first path has neither a '.parquet' nor a '.csv' extension.
    """
    if isinstance(file_paths, str):
        file_paths = [file_paths]

    first_path = file_paths[0]
    if first_path.endswith('.parquet'):
        reader = pl.read_parquet
    elif first_path.endswith('.csv'):
        reader = pl.read_csv
    else:
        # Previously this fell through and failed later with an UnboundLocalError.
        raise ValueError(
            f"Unsupported training file type: {first_path!r} (expected .parquet or .csv)"
        )
    train_raw = pl.concat([reader(file) for file in file_paths])

    if 'prompt' in train_raw.columns:
        train_raw = train_raw.rename({'prompt': 'question'})
    # Column order is kept exactly as before (note that 'C' precedes 'B').
    train_cols = ['question', 'A', 'C', 'B', 'D', 'E', 'answer']
    train_raw = train_raw[train_cols]
    train_raw = train_raw.select(pl.all().fill_null('N/A'))

    # Clamp at 0 so a source with fewer rows than n_extra_test is held out entirely
    # instead of producing a negative slice start.
    split_at = max(train_raw.shape[0] - n_extra_test, 0)
    test_split = train_raw[split_at:]
    train = train_raw[:split_at]

    train_sets.append(train)
    test_sets[test_name] = test_split

# One entry per optional source: (config flag, parquet path or list of paths, name of the
# held-out split). The order is the order in which sources are appended to train_sets.
optional_sources = (
    ('use_train_daniel', './data/train_dedupe/daniel.parquet', 'daniel'),
    ('use_osmu_sci_6k', './data/train_dedupe/osmu_sci_6k.parquet', 'osmu_sci_6k'),
    ('use_osmu_21k',
     ['./data/train_dedupe/osmu_15k.parquet', './data/train_dedupe/osmu_5_9k.parquet'],
     'osmu_21k'),
    ('use_mgoksu_13k', './data/train_dedupe/mgoksu.parquet', 'mgoksu'),
    ('use_gigkpea_3k', './data/train_dedupe/gigkpea.parquet', 'gigkpea'),
)

if config['large_train_deduped']:
    # The pre-merged, de-duplicated file replaces all individual sources.
    read_training_files('./data/large_train_deduped.parquet', 'combined')
else:
    for flag, paths, split_name in optional_sources:
        if config[flag]:
            read_training_files(paths, split_name)

# Single training frame across every source that was read; show its row count.
train = pl.concat(train_sets)
f'{train.shape[0]:,}'

'54,425'

In [6]:
# Number of question rows handed to retrieval() per call by the get_relevant_documents* helpers.
df_chunk_size = 5000


def _parsed_row_to_text(example):
    """Join title, section and text of a parsed-wiki row; newlines become spaces, apostrophes are dropped."""
    joined = f"{example['title']} {example['section']} {example['text']}"
    return {'temp_text': joined.replace('\n', " ").replace("'", "")}


def _cohere_row_to_text(example):
    """Join title and text of a cohere-wiki row, NFKD-normalise it and drop double quotes."""
    joined = unicodedata.normalize("NFKD", f"{example['title']} {example['text']}")
    return {'temp_text': joined.replace('"', "")}


# Parsed wiki corpus: kept as a dataset for title/text lookup plus a flat list of strings
# that the TF-IDF retrieval scores against.
parsed_path = "./data/wiki-270k" if is_local else "/kaggle/working/all-paraphs-parsed-expanded"
paraphs_parsed_dataset = load_from_disk(parsed_path)
modified_texts_parsed = paraphs_parsed_dataset.map(_parsed_row_to_text, num_proc=2)["temp_text"]

# Cohere STEM wiki corpus, prepared the same way.
cohere_path = "./data/stem-wiki-cohere" if is_local else "/kaggle/working/stem-wiki-cohere-no-emb"
cohere_dataset_filtered = load_from_disk(cohere_path)
modified_texts = cohere_dataset_filtered.map(_cohere_row_to_text, num_proc=2)["temp_text"]

Setting TOKENIZERS_PARALLELISM=false for forked processes.


In [9]:
def SplitList(mylist, chunk_size):
    """Cut ``mylist`` into consecutive slices of ``chunk_size`` items.

    The last slice holds whatever is left and may be shorter. Each piece is produced by
    slicing, so it has the same sequence type as ``mylist``. An empty input gives ``[]``.
    """
    pieces = []
    for start in range(0, len(mylist), chunk_size):
        pieces.append(mylist[start:start + chunk_size])
    return pieces

def _collect_retrieved_articles(df_valid, texts, dataset, text_transform=None):
    """Run retrieval() over ``df_valid`` in chunks and look up the matched articles.

    Args:
        df_valid: pandas DataFrame of questions; it is sliced with ``.iloc`` in steps of
            the module-level ``df_chunk_size`` and each slice is passed to retrieval().
        texts: the flattened corpus strings that retrieval() scores against.
        dataset: the dataset the corpus strings were built from; rows are read by integer
            index and must expose "title" and "text".
        text_transform: optional callable applied to each article text before it is returned.

    Returns:
        One list per question, each holding ``(score, title, text)`` tuples in the order
        retrieval() returned them. An empty ``df_valid`` yields ``[]``.
    """
    all_articles_indices = []
    all_articles_values = []
    for start in tqdm(range(0, df_valid.shape[0], df_chunk_size)):
        df_chunk = df_valid.iloc[start: start + df_chunk_size]
        articles_indices, merged_top_scores = retrieval(df_chunk, texts)
        all_articles_indices.append(articles_indices)
        all_articles_values.append(merged_top_scores)

    if not all_articles_indices:
        # np.concatenate() rejects an empty list; nothing was asked, nothing is returned.
        return []

    article_indices_array = np.concatenate(all_articles_indices, axis=0)
    articles_values_array = np.concatenate(all_articles_values, axis=0).reshape(-1)
    top_per_query = article_indices_array.shape[1]

    articles_flatten = []
    for position, doc_idx in enumerate(article_indices_array.reshape(-1)):
        # Fetch the row once and reuse it for both fields instead of indexing the
        # dataset twice per retrieved article.
        row = dataset[doc_idx.item()]
        text = row["text"]
        if text_transform is not None:
            text = text_transform(text)
        articles_flatten.append((articles_values_array[position], row["title"], text))

    # Regroup the flat list into one list of hits per question.
    return SplitList(articles_flatten, top_per_query)


def get_relevant_documents_parsed(df_valid):
    """Retrieve the top TF-IDF articles for each question from the parsed wiki corpus.

    Returns one list per row of ``df_valid`` with ``(score, title, text)`` tuples taken
    from ``paraphs_parsed_dataset``.
    """
    return _collect_retrieved_articles(df_valid, modified_texts_parsed, paraphs_parsed_dataset)


def get_relevant_documents(df_valid):
    """Retrieve the top TF-IDF articles for each question from the cohere wiki corpus.

    Returns one list per row of ``df_valid`` with ``(score, title, text)`` tuples taken
    from ``cohere_dataset_filtered``; the text is NFKD-normalised.
    """
    return _collect_retrieved_articles(
        df_valid,
        modified_texts,
        cohere_dataset_filtered,
        text_transform=lambda text: unicodedata.normalize("NFKD", text),
    )



def retrieval(df_valid, modified_texts, chunk_size=100, top_per_chunk=3, top_per_query=3):
    """Score every question against a text corpus with TF-IDF and keep the best matches.

    Each row of ``df_valid`` becomes one query string: the question repeated three times
    followed by options A-E. The vocabulary is restricted to terms occurring in those
    queries, IDF weights are fitted on the first 500,000 corpus texts, and the corpus is
    scored in chunks so that only a small dense score matrix exists at any time.

    Args:
        df_valid: pandas DataFrame with columns "question", "A", "B", "C", "D", "E".
        modified_texts: sequence of corpus strings; supports ``len()`` and slicing.
        chunk_size: number of corpus texts scored per iteration.
        top_per_chunk: candidates kept per query from each chunk.
        top_per_query: matches kept per query after merging all chunks.

    Returns:
        ``(articles_indices, merged_top_scores)``: two arrays with one row per query and
        up to ``top_per_query`` columns, ordered by ascending score. Indices refer to
        positions in ``modified_texts``.
    """
    corpus_df_valid = df_valid.apply(lambda row:
                                     f'{row["question"]}\n{row["question"]}\n{row["question"]}\n{row["A"]}\n{row["B"]}\n{row["C"]}\n{row["D"]}\n{row["E"]}',
                                     axis=1).values
    token_pattern = r"(?u)\b[\w/.-]+\b|!|/|\?|\"|\'"

    # First pass: collect the uni/bi-gram vocabulary that occurs in the queries.
    query_vectorizer = TfidfVectorizer(ngram_range=(1, 2),
                                       token_pattern=token_pattern,
                                       stop_words=stop_words)
    query_vectorizer.fit(corpus_df_valid)
    vocab_df_valid = query_vectorizer.get_feature_names_out()

    # Second pass: same tokenisation, fixed to the query vocabulary, IDF from the corpus.
    vectorizer = TfidfVectorizer(ngram_range=(1, 2),
                                 token_pattern=token_pattern,
                                 stop_words=stop_words,
                                 vocabulary=vocab_df_valid)
    vectorizer.fit(modified_texts[:500000])
    corpus_tf_idf = vectorizer.transform(corpus_df_valid)

    print(f"length of vectorizer vocab is {len(vectorizer.get_feature_names_out())}")

    all_chunk_top_indices = []
    all_chunk_top_values = []

    for start in tqdm(range(0, len(modified_texts), chunk_size)):
        wiki_vectors = vectorizer.transform(modified_texts[start: start + chunk_size])
        temp_scores = (corpus_tf_idf * wiki_vectors.T).toarray()
        # The final chunk can hold fewer than top_per_chunk documents; argpartition
        # raises if asked for more columns than exist, so cap the count per chunk.
        keep = min(top_per_chunk, temp_scores.shape[1])
        chunk_top_indices = temp_scores.argpartition(-keep, axis=1)[:, -keep:]
        chunk_top_values = np.take_along_axis(temp_scores, chunk_top_indices, axis=1)

        # Shift chunk-local column numbers to positions in the full corpus.
        all_chunk_top_indices.append(chunk_top_indices + start)
        all_chunk_top_values.append(chunk_top_values)

    top_indices_array = np.concatenate(all_chunk_top_indices, axis=1)
    top_values_array = np.concatenate(all_chunk_top_values, axis=1)

    # One argsort drives both outputs, so every score is paired with its own index.
    merged_top_indices = top_values_array.argsort(axis=1)[:, -top_per_query:]
    merged_top_scores = np.take_along_axis(top_values_array, merged_top_indices, axis=1)
    articles_indices = np.take_along_axis(top_indices_array, merged_top_indices, axis=1)

    return articles_indices, merged_top_scores


def prepare_answering_input(
        tokenizer,
        question,
        options,
        context,
        max_seq_length=4096,
    ):
    """Tokenise one question with all of its options for a multiple-choice model.

    The first segment of every pair is ``context``, the tokenizer's BOS token and
    ``question`` joined by single spaces; the second segment is the option. Pairs are
    padded to the longest one and are not truncated.

    Returns a dict with "input_ids" and "attention_mask", each with a leading dimension
    of size 1 added and moved to the device index of the module-level ``model``.
    """
    prompt = ' '.join([context, tokenizer.bos_token, question])
    encoded = tokenizer(
        [prompt] * len(options), options,
        max_length=max_seq_length,
        padding="longest",
        truncation=False,
        return_tensors="pt",
    )
    # Add the leading dimension of size 1 in front of the per-option rows.
    ids = encoded['input_ids'].unsqueeze(0)
    mask = encoded['attention_mask'].unsqueeze(0)
    return {
        "input_ids": ids.to(model.device.index),
        "attention_mask": mask.to(model.device.index),
    }


# Words ignored by both TfidfVectorizer instances in retrieval() (passed as stop_words=).
# 'the' is listed twice in this literal.
stop_words = ['each', 'you', 'the', 'use', 'used',
                  'where', 'themselves', 'nor', "it's", 'how', "don't", 'just', 'your',
                  'about', 'himself', 'with', "weren't", 'hers', "wouldn't", 'more', 'its', 'were',
                  'his', 'their', 'then', 'been', 'myself', 're', 'not',
                  'ours', 'will', 'needn', 'which', 'here', 'hadn', 'it', 'our', 'there', 'than',
                  'most', "couldn't", 'both', 'some', 'for', 'up', 'couldn', "that'll",
                  "she's", 'over', 'this', 'now', 'until', 'these', 'few', 'haven',
                  'of', 'wouldn', 'into', 'too', 'to', 'very', 'shan', 'before', 'the', 'they',
                  'between', "doesn't", 'are', 'was', 'out', 'we', 'me',
                  'after', 'has', "isn't", 'have', 'such', 'should', 'yourselves', 'or', 'during', 'herself',
                  'doing', 'in', "shouldn't", "won't", 'when', 'do', 'through', 'she',
                  'having', 'him', "haven't", 'against', 'itself', 'that',
                  'did', 'theirs', 'can', 'those',
                  'own', 'so', 'and', 'who', "you've", 'yourself', 'her', 'he', 'only',
                  'what', 'ourselves', 'again', 'had', "you'd", 'is', 'other',
                  'why', 'while', 'from', 'them', 'if', 'above', 'does', 'whom',
                  'yours', 'but', 'being', "wasn't", 'be']

In [10]:
# Retrieve parsed-wiki context for the main eval split, then for the whole training frame.
# The captured output below ends in a KeyboardInterrupt during the training-frame pass.
test_name = 'test'
test_data = test_sets[test_name].to_pandas()
test_wiki_tfidf_parsed = get_relevant_documents_parsed(test_data)
# NOTE(review): "ifidf" looks like a typo for "tfidf"; no visible cell reads this name.
# NOTE(review): preprocess_articles() reads a 'retrieved_articles_parsed' column, but no
# visible cell attaches these results to `train` / `test` under that name — confirm where
# that happens.
train_wiki_ifidf_parsed = get_relevant_documents_parsed(train.to_pandas())
# test_wiki_tfidf = get_relevant_documents(test_data)

  0%|          | 0/2 [00:00<?, ?it/s]



length of vectorizer vocab is 6286


  0%|          | 0/21013 [00:00<?, ?it/s]

length of vectorizer vocab is 5933


  0%|          | 0/21013 [00:00<?, ?it/s]

  0%|          | 0/545 [00:00<?, ?it/s]



length of vectorizer vocab is 6118


  0%|          | 0/21013 [00:00<?, ?it/s]

length of vectorizer vocab is 6245


  0%|          | 0/21013 [00:00<?, ?it/s]

length of vectorizer vocab is 5971


  0%|          | 0/21013 [00:00<?, ?it/s]

length of vectorizer vocab is 6857


  0%|          | 0/21013 [00:00<?, ?it/s]

length of vectorizer vocab is 6528


  0%|          | 0/21013 [00:00<?, ?it/s]

length of vectorizer vocab is 5952


  0%|          | 0/21013 [00:00<?, ?it/s]



length of vectorizer vocab is 5928


  0%|          | 0/21013 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
def preprocess_articles(example):
    """Tokenise one question row into five (context, question + option) pairs.

    The first segment is the same for every option: the last three entries of
    ``example['retrieved_articles_parsed']`` in reverse order, joined with literal
    "[CLS]" / "[SEP]" markers. The second segment is the question followed by the option
    text, or the string 'N/A' when the option is None.

    Relies on the module-level ``tokenizer``, ``max_length`` and ``option_to_index``.

    Returns the tokenizer output (one entry per option); when the row has an 'answer' key
    a 'label' entry with the answer's index is added.
    """
    options = 'ABCDE'
    articles = example['retrieved_articles_parsed']
    context = f"""[CLS] {articles[-1]} [SEP] {articles[-2]} [SEP] {articles[-3]}"""
    # One copy of the context per option, so both lists always have the same length.
    first_sentence = [context] * len(options)

    second_sentences = []
    for option in options:
        answer = example[option]
        if answer is not None:
            second_sentences.append(" #### " + example["question"] + " [SEP] " + answer + " [SEP]")
        else:
            second_sentences.append('N/A')

    try:
        tokenized_example = tokenizer(first_sentence, second_sentences, truncation='longest_first', max_length=max_length)
    except Exception:
        # Show the offending input, then let the error propagate unchanged. Catching
        # Exception (not a bare except) leaves KeyboardInterrupt/SystemExit alone.
        print(first_sentence, second_sentences)
        raise

    if 'answer' in example.keys():
        tokenized_example['label'] = option_to_index[example['answer']]

    return tokenized_example

def tokenized_articles(data):
    """Convert a polars frame into a tokenised Hugging Face ``Dataset``.

    Every row is passed through preprocess_articles(); all source columns other than the
    model inputs ('input_ids', 'token_type_ids', 'attention_mask') and 'label' are dropped.

    Args:
        data: polars DataFrame (converted with ``.to_pandas()``) holding the columns that
            preprocess_articles() reads.
    """
    columns_to_keep = {'input_ids', 'token_type_ids', 'attention_mask', 'label'}
    dataset = Dataset.from_pandas(data.to_pandas(), preserve_index=False)
    # The columns to drop are known from the source dataset itself, so the expensive
    # tokenising map runs once instead of twice. A list is passed because
    # `remove_columns` is documented as a str or list of str.
    col_to_remove = [col for col in dataset.column_names if col not in columns_to_keep]
    return dataset.map(preprocess_articles, remove_columns=col_to_remove)

In [None]:
a  # NOTE(review): bare undefined name — raises NameError and halts "Run All" here; presumably a deliberate stopper, otherwise delete this cell

In [None]:
import torch

from transformers import AutoTokenizer, AutoModelForMultipleChoice, Trainer, TrainingArguments, IntervalStrategy, get_linear_schedule_with_warmup
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from datasets import Dataset # HuggingFace
from torch.optim import AdamW

In [None]:
# Directory the tokenizer (and, in the prediction cell, the model) is loaded from.
# NOTE(review): this is a hard-coded Kaggle input path and is used even when `is_local`
# is True — confirm the directory exists locally or branch on `is_local` like the dataset cell.
fine_tuned_path = '/kaggle/input/deberta-for-colbert'
tokenizer = AutoTokenizer.from_pretrained(fine_tuned_path)

In [None]:
# Token budget per (context, question + option) pair in preprocess_articles().
max_length = 450

# Answer letter <-> class index, both in A..E order.
option_to_index = dict(zip('ABCDE', range(5)))
index_to_option = dict(enumerate('ABCDE'))

@dataclass
class DataCollatorForMultipleChoice:
    """Pad multiple-choice samples and stack them into (samples, choices, tokens) tensors.

    Each sample is a dict whose values are lists with one entry per choice, plus an
    optional integer 'label'. The choices of all samples are flattened, padded together
    by ``tokenizer.pad`` and reshaped back to three dimensions. When the samples carry a
    'label', the result also contains an int64 'labels' tensor with one entry per sample.

    The input dicts are not modified.
    """
    # Annotations are quoted forward references: they are documentation only, so defining
    # the class does not require the transformers types to be resolvable.
    # tokenizer: any object exposing `pad(features, padding=, max_length=, return_tensors=)`.
    tokenizer: "PreTrainedTokenizerBase"
    # padding / max_length are forwarded unchanged to `tokenizer.pad`.
    padding: "Union[bool, str, PaddingStrategy]" = True
    max_length: Optional[int] = None

    def __call__(self, input_batch):
        """Collate a list of per-sample dicts into a batch dict of tensors."""
        # input_batch is list of samples, choices, tokens
        expected_cols = {'input_ids', 'token_type_ids', 'attention_mask', 'label'}
        additional_cols = set(input_batch[0].keys()) - expected_cols
        if len(additional_cols) > 0:
            print(f'{additional_cols=}')

        label_name = 'label'
        # Decide once whether labels are present. The earlier version popped the labels
        # and then re-checked the (now label-free) first sample, so 'labels' was never
        # added to the batch.
        has_labels = label_name in input_batch[0]
        if has_labels:
            labels = [feature[label_name] for feature in input_batch]

        batch_size = len(input_batch)
        num_choices = len(input_batch[0]['input_ids'])
        # One flat list with a dict per (sample, choice); the label is left out because
        # it is per-sample, not per-choice.
        flattened_input = [
            {k: v[i] for k, v in sample.items() if k != label_name}
            for sample in input_batch
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_input,
            padding=self.padding,
            max_length=self.max_length,
            return_tensors='pt',
        )

        # batch.shape = (n_samples, n_choices, n_tokens)
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        if has_labels:
            batch['labels'] = torch.tensor(labels, dtype=torch.int64)
        return batch

In [None]:
# Tokenise the training frame and the main eval frame with their retrieved context.
# NOTE(review): both frames must already carry a 'retrieved_articles_parsed' column
# (read by preprocess_articles); no visible cell adds it — confirm.
# NOTE(review): the training cell reads `tokenized_train` / `eval_datasets` and the
# prediction cell reads `tokenized_test`, none of which are defined in the visible
# cells; these two names may be what those cells are meant to use.
tokenized_train_articles = tokenized_articles(train)
tokenized_test_articles = tokenized_articles(test)

## Train DeBERTa

In [None]:
def precision_at_k(predictions, actuals, k=3):
    """Reciprocal-rank score of the true label in each row of ranked predictions.

    Args:
        predictions: 2-D array-like; each row lists class indices from best to worst.
        actuals: 1-D array-like with the true class index of each row.
        k: ranks beyond ``k`` score 0.

    Returns:
        1-D float array with ``1 / rank`` for every row whose true label appears at
        rank <= k, and 0 for matches at a later rank. A row whose label does not appear
        in its predictions contributes no entry.
    """
    predictions = np.asarray(predictions)
    actuals = np.asarray(actuals)

    found_at = np.where(predictions == actuals.reshape(-1, 1))
    # found_at is a tuple with the array of found indices in the second position
    score = 1 / (1 + found_at[1])
    score[score < 1/k] = 0
    return score

def mean_avg_precision_at_k(predictions, actual, k=3):
    """Mean of precision_at_k() over all rows of ``predictions`` (MAP@k).

    ``k`` is forwarded to precision_at_k(); previously it was accepted but ignored, so
    every call was scored as MAP@3. Returns 0.0 when there are no rows.
    """
    predictions = np.asarray(predictions)
    n = predictions.shape[0]
    if n == 0:
        return 0.0
    row_precision = precision_at_k(predictions, actual, k=k)
    return row_precision.sum()/n

# Accuracy metric from the `evaluate` library, loaded once and reused on every evaluation.
acc_metric = evaluate.load("accuracy")
def compute_metrics(eval_pred):
    """Turn the Trainer's (predictions, labels) pair into accuracy and MAP@3.

    Classes are ranked per row by descending score; accuracy uses the top-ranked class
    and MAP@3 uses the full ranking. MAP@3 is rounded to three decimals.
    """
    scores, labels = eval_pred
    # argsort is ascending, so flip each row to get best-first class indices.
    ranking = np.flip(scores.argsort(axis=1), axis=1)
    best_guess = ranking[:, 0]
    accuracy = acc_metric.compute(predictions=best_guess, references=labels)['accuracy']
    map_at_3 = mean_avg_precision_at_k(ranking, labels)
    metrics = {}
    metrics['accuracy'] = accuracy
    metrics['map_at_3'] = round(map_at_3, 3)
    return metrics

In [None]:
# Set to False to reuse a previously saved model instead of training again.
retrain = True

if not output_path.exists() or retrain:
    wandb.init(
        project=project_name,
        name=run_name,
        job_type='train',
        config=config
        # group="bert"
    )

    torch.cuda.empty_cache()
    model = AutoModelForMultipleChoice.from_pretrained(config['hf_model_id'])

    # NOTE(review): `tokenized_train` and `eval_datasets` are not defined in any visible
    # cell (only `tokenized_train_articles` / `tokenized_test_articles` are) — confirm
    # where they come from before running this cell on a fresh kernel.

    # Number of scheduler steps. OneCycleLR raises as soon as it is stepped more often
    # than `total_steps`, so both divisions round up: a partial last batch and a partial
    # last accumulation group can each add an optimizer step. Floor division could come
    # out one short; over-counting by one only leaves the cycle marginally unfinished.
    batches_per_epoch = math.ceil(len(tokenized_train) / training_args.per_device_train_batch_size)
    updates_per_epoch = max(math.ceil(batches_per_epoch / training_args.gradient_accumulation_steps), 1)
    total_steps = math.ceil(training_args.num_train_epochs * updates_per_epoch)

    if config['8_bit_adam']:
        # Weight decay applies to everything except LayerNorm parameters and biases.
        decay_parameters = get_parameter_names(model, [nn.LayerNorm])
        decay_parameters = {name for name in decay_parameters if "bias" not in name}
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if n in decay_parameters],
                "weight_decay": training_args.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
                "weight_decay": 0.0,
            },
        ]
        optimizer = bnb.optim.Adam8bit(
            optimizer_grouped_parameters,
            betas=(training_args.adam_beta1, training_args.adam_beta2),
            eps=training_args.adam_epsilon,
            lr=training_args.learning_rate,
        )
        max_lr = training_args.learning_rate
    elif type(config['lr_layer_factor']) is int:
        # Layer-wise learning rates: each group closer to the input is trained with a
        # rate `factor` times smaller than the group above it.
        factor = config['lr_layer_factor']
        base_lr = training_args.learning_rate

        embedding_lr = base_lr / factor**4
        early_layers_lr = base_lr / factor**3
        middle_layers_lr = base_lr / factor**2
        late_layers_lr = base_lr / factor
        classifier_lr = base_lr

        optimizer_grouped_parameters = [
            {'params': model.deberta.embeddings.parameters()},
            {'params': model.deberta.encoder.layer[:8].parameters()},
            {'params': model.deberta.encoder.layer[8:16].parameters()},
            {'params': model.deberta.encoder.layer[16:].parameters()},
            {'params': model.classifier.parameters()},
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          weight_decay=training_args.weight_decay)
        # One peak rate per parameter group, in group order. The previous list used
        # exponents 5..1, which shifted every group one step down and left the classifier
        # at base_lr / factor instead of base_lr.
        max_lr = [embedding_lr, early_layers_lr, middle_layers_lr, late_layers_lr, classifier_lr]

    else:
        optimizer = AdamW(model.parameters(),
                          lr=training_args.learning_rate,
                          weight_decay=training_args.weight_decay)
        max_lr = training_args.learning_rate

    scheduler = OneCycleLR(optimizer, max_lr=max_lr, total_steps=total_steps)

    #warmup_steps = int(total_steps * training_args.warmup_ratio)
    #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

    trainer = Trainer(
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer, max_length=600),
        train_dataset=tokenized_train,
        eval_dataset=eval_datasets,
        compute_metrics=compute_metrics,
        optimizers=(optimizer, scheduler)
    )

    # needed when there are multiple eval datasets
    trainer.remove_callback(NotebookProgressCallback)
    trainer.train()
    # wandb.config.update(config)
    wandb.finish()
    trainer.save_model(output_path/run_name)
else:
    model = AutoModelForMultipleChoice.from_pretrained(output_path/run_name)

## Predict with DeBERTa

In [None]:
# Load the fine-tuned checkpoint from `fine_tuned_path` for inference. This rebinds
# `model` and `trainer`, replacing the objects created by the training cell.
model = AutoModelForMultipleChoice.from_pretrained(fine_tuned_path)
# No `args` are passed, so this Trainer does not use `training_args` from the first cell.
trainer = Trainer(
    model=model,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer, max_length=600),
)

In [None]:
# NOTE(review): `tokenized_test` is not defined in any visible cell — confirm which
# tokenised dataset is meant here.
test_logits = trainer.predict(tokenized_test).predictions

In [None]:
# Logits for the eval split tokenised with retrieved article context. The previous name
# `tokenized_tokenized_articles` was never defined (doubled word); the dataset built by
# tokenized_articles(test) is `tokenized_test_articles`.
test_article_logits = trainer.predict(tokenized_test_articles).predictions