### Instructions to Run:

1. Upload the `sample_input_paragraph.csv`, `question_answers.csv`, `sample_input_question.csv` and `sample_theme_interval.csv` files.
2. Execute the helper function cells that install and import the required libraries.
3. Run the remaining cells as described in the Sample_Eval notebook.

## Helper functions

### Load packages and import libraries

In [None]:
# Install packages
!pip install --upgrade --no-cache-dir gdown
!pip install -U sentence-transformers
!pip install -U faiss-cpu
!pip install transformers sentencepiece
!pip install optimum[onnxruntime]

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gdown
  Downloading gdown-4.6.0-py3-none-any.whl (14 kB)
Installing collected packages: gdown
  Attempting uninstall: gdown
    Found existing installation: gdown 4.4.0
    Uninstalling gdown-4.4.0:
      Successfully uninstalled gdown-4.4.0
Successfully installed gdown-4.6.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence-transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 86.0/86.0 KB 3.7 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Collecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.26.0-py3-none-any.whl (6.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.3/6.3 MB 56.2 MB/s eta 0:00:00
Collecting sentencepiece
  D

In [None]:
# Import libraries
import gdown
import nltk
import faiss
import json
import os
import time
import numpy as np
import pandas as pd
import collections
import re
import string
import timeit
import tarfile

from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from optimum.onnxruntime import ORTModelForQuestionAnswering, ORTOptimizer
from optimum.onnxruntime.configuration import OptimizationConfig
from optimum.pipelines import pipeline
from tqdm import tqdm
from ast import literal_eval
from zipfile import ZipFile

nltk.download('punkt')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

### Sentence Encoder

For a given theme, split its paragraphs into sentences and record the paragraph id of each sentence. Then load the sentence encoder and compute embeddings for the paragraph sentences and for the queries.

In [None]:
def para_to_sentences(para):
    """Splits a paragraph into sentences."""
    para = para.replace('\n', ' ').replace('\t', ' ').replace('\x00', ' ')
    return nltk.sent_tokenize(para)

def load_sents_from_para(paras):
    """Splits a list of paragraphs into sentences and returns the sentences
    and their corresponding paragraph id"""
    sents = []
    para_id = []
    for i,p in enumerate(paras):
        new_sents = para_to_sentences(p['paragraph'])
        sents += new_sents
        para_id += [p['id']]*len(new_sents)
    return sents, para_id
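To make the bookkeeping concrete, here is a minimal sketch of the same idea on toy data. It swaps `nltk.sent_tokenize` for a naive regex splitter (an assumption made only so the snippet runs without downloading `punkt`); the helper names and sample paragraphs are hypothetical. The point is that `sents` and `para_id` are parallel lists, so sentence i, and later row i of the FAISS index built from those sentences, maps back to paragraph `para_id[i]`.

```python
import re

def split_sentences(paragraph):
    """Hypothetical stand-in for nltk.sent_tokenize: split after '.', '!' or '?'
    followed by whitespace (much cruder than punkt)."""
    paragraph = paragraph.replace('\n', ' ').replace('\t', ' ').replace('\x00', ' ')
    return [s for s in re.split(r'(?<=[.!?])\s+', paragraph.strip()) if s]

def sentences_with_para_ids(paras):
    """Hypothetical helper: return parallel lists of sentences and paragraph ids."""
    sents, para_ids = [], []
    for p in paras:
        new_sents = split_sentences(p['paragraph'])
        sents.extend(new_sents)
        # One copy of the paragraph id per sentence keeps both lists aligned
        para_ids.extend([p['id']] * len(new_sents))
    return sents, para_ids

example_paras = [
    {'id': 7, 'paragraph': 'The iPod was released in 2001. It used a click wheel.'},
    {'id': 9, 'paragraph': 'Later models added video playback!'},
]
example_sents, example_para_ids = sentences_with_para_ids(example_paras)
for sent, pid in zip(example_sents, example_para_ids):
    print(pid, sent)
```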

In [None]:
def load_encoder():
    """Download and load the fine-tuned mpnet sentence encoder
    (the commented-out line loads the base all-mpnet-base-v2 model instead)"""
    # model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
    gdown.download(
        "https://drive.google.com/file/d/137tZvp-iTMR2xIogasglSH4jTTLW4_Sf/view",
        fuzzy=True, use_cookies=False, quiet=True
    )
    with ZipFile('/content/finetuned_mpnet_triplet.zip') as zobj:
        zobj.extractall()
    model = SentenceTransformer('/content/kaggle/working/finetuned_mpnet_triplet')
    return model

def get_embeddings(sents, model):
    """Generates a 768-dimensional embedding for each sentence in the list"""
    return model.encode(sents)

### Nearest Neighbour Search using FAISS

Index the computed embeddings in a flat (exact) L2 index, then run a nearest-neighbour search to retrieve the top k closest sentences for each query.

In [None]:
def save_index(source_embeds, output_path):
    """Creates and saves the faiss L2 Index using source_embeds"""
    index = faiss.IndexFlatL2(source_embeds.shape[1])
    index.add(np.array(source_embeds))
    faiss.write_index(index, output_path)

def load_index(path):
    """Loads faiss index from the disk"""
    index = faiss.read_index(path)
    return index

def get_k_nearest_neighbours(index, query_embeds, k = 10):
    """Returns (distances, ids) of the k nearest neighbours of each query embedding in the index"""
    return index.search(np.array(query_embeds), k)

def get_nearest_queries(ques_embed, theme):
    """Retrieve the 3 nearest already answered queries for each question"""
    index = load_index(f'/content/indices/{theme}_ques_l2_index')
    return get_k_nearest_neighbours(index, ques_embed, 3)

def get_nearest_sentences(ques_embed, theme):
    """Retrieve the k nearest paragraph sentences for each question.
    Note: k is the global parameter set in the Execution section."""
    index = load_index(f'/content/indices/{theme}_para_l2_index')
    return get_k_nearest_neighbours(index, ques_embed, k)
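`faiss.IndexFlatL2` performs exact, exhaustive search and reports squared Euclidean distances. The pure-Python sketch below (hypothetical helper names, toy 2-D vectors instead of 768-dimensional embeddings) mimics what `index.search` returns: one row of distances and one row of ids per query, ordered from nearest to farthest. It only illustrates the semantics; FAISS is much faster on real data.

```python
import heapq

def squared_l2(a, b):
    """Squared Euclidean distance, the value faiss.IndexFlatL2 reports."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def brute_force_knn(database, queries, k):
    """Hypothetical helper: exhaustive k-NN search.
    Returns (distances, ids), one row per query, nearest first."""
    all_dists, all_ids = [], []
    for q in queries:
        # Score every database vector, then keep the k smallest distances
        scored = ((squared_l2(q, v), i) for i, v in enumerate(database))
        best = heapq.nsmallest(k, scored)
        all_dists.append([d for d, _ in best])
        all_ids.append([i for _, i in best])
    return all_dists, all_ids

# Toy 2-D "embeddings" standing in for sentence vectors
example_db = [[0.0, 0.0], [1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
example_queries = [[0.9, 0.1]]
example_dists, example_ids = brute_force_knn(example_db, example_queries, k=2)
print(example_ids)    # [[1, 0]]
print(example_dists)  # roughly 0.02 and 0.82
```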

### Load Existing QA and paragraphs data

Load the validation data used for testing, taken from the parts of the SQuAD 2.0 dataset that are missing from the training data. Round 1 data contains themes that are not present in the training data, while round 2 data contains themes that are.

In [None]:
def load_existing_data():
    """Load already answered questions and paragraphs, theme-wise.
    Also breaks the paragraphs into sentences"""
    paras, solved_ques = {}, {}
    paragraphs = json.loads(pd.read_csv("sample_input_paragraph.csv").to_json(orient="records"))
    questions = json.loads(pd.read_csv("question_answers.csv").to_json(orient="records"))
    theme_intervals = json.loads(pd.read_csv("sample_theme_interval.csv").to_json(orient="records"))
    
    for theme_interval in theme_intervals:
        theme = theme_interval["theme"]
        theme_paras = [p for p in paragraphs if p["theme"] == theme]
        sents, para_id = load_sents_from_para(theme_paras)
        paras[theme] = {
            'id': para_id,
            'sentences': sents
        }
    
    for i, ques in enumerate(questions):
        theme = ques['theme']
        if theme not in solved_ques:
            solved_ques[theme] = {
                'id': [],
                'question': [],
                'paragraph_id': [],
                'answers': []
            }
        solved_ques[theme]['id'].append(i)
        solved_ques[theme]['question'].append(ques["question"])
        solved_ques[theme]['paragraph_id'].append(ques["paragraph_id"])
        solved_ques[theme]['answers'].append(ques["answer"])
    return paras, solved_ques


def store_faiss_indices(paras, solved_ques, encoder):
    """Generates embeddings for paragraph sentences and queries. Then it creates
    and saves the faiss index using them into disk"""
    os.makedirs('/content/indices/', exist_ok=True)
    for theme in paras:
        theme_paras = paras[theme]
        
        output_path = f'/content/indices/{theme}_para_l2_index'
        if not os.path.exists(output_path):
            para_embeds = get_embeddings(theme_paras['sentences'], encoder)
            save_index(para_embeds, output_path)
        
        output_path = f'/content/indices/{theme}_ques_l2_index'    
        if theme in solved_ques and not os.path.exists(output_path):
            theme_ques = solved_ques[theme]
            ques_embeds = get_embeddings(theme_ques['question'], encoder)
            save_index(ques_embeds, output_path)
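Because `index.search` returns row positions, the per-theme lists in `solved_ques` must stay in the same order as the rows added to the `{theme}_ques_l2_index` index: `I_ques[i, 0]` is used directly as a position in `solved_queries['answers']`. A sketch of the same theme-wise grouping written with `collections.defaultdict` (the function name and sample records are hypothetical):

```python
from collections import defaultdict

def group_by_theme(records):
    """Hypothetical helper: group QA records theme-wise into parallel lists.
    'id' is the record's position in the input list, as in load_existing_data()."""
    grouped = defaultdict(
        lambda: {'id': [], 'question': [], 'paragraph_id': [], 'answers': []}
    )
    for i, rec in enumerate(records):
        g = grouped[rec['theme']]
        g['id'].append(i)
        g['question'].append(rec['question'])
        g['paragraph_id'].append(rec['paragraph_id'])
        g['answers'].append(rec['answer'])
    return dict(grouped)

example_records = [
    {'theme': 'IPod', 'question': 'Who makes the iPod?', 'paragraph_id': 1, 'answer': 'Apple'},
    {'theme': 'Pub', 'question': 'What is a pub?', 'paragraph_id': 4, 'answer': 'a drinking establishment'},
    {'theme': 'IPod', 'question': 'When was it released?', 'paragraph_id': 2, 'answer': '2001'},
]
example_grouped = group_by_theme(example_records)
# Row j of an IPod question index would correspond to position j in every list
print(example_grouped['IPod']['answers'])  # ['Apple', '2001']
```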

### Search previously answered queries

In [None]:
def search_previously_answered_queries(q_id, dist, query_idx, solved_queries):
    """Return the stored answer of the nearest previously answered query if its
    distance is within the global query_threshold; otherwise return (False, None)"""
    if dist > query_threshold:
        return False, None
    ans = {
        "question_id": q_id,
        "answers": solved_queries['answers'][query_idx],
        "paragraph_id": solved_queries['paragraph_id'][query_idx]
    }
    return True, ans
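`search_previously_answered_queries` reuses a stored answer when `dist <= query_threshold`, where `dist` comes from `faiss.IndexFlatL2` and is therefore a squared L2 distance. Assuming the encoder returns unit-length embeddings (the base all-mpnet-base-v2 model ends with a normalization layer; whether the fine-tuned checkpoint keeps it is an assumption), the threshold is equivalent to a cosine-similarity cutoff:

$$
\lVert \mathbf{a}-\mathbf{b} \rVert_2^2 = \lVert \mathbf{a} \rVert_2^2 + \lVert \mathbf{b} \rVert_2^2 - 2\,\mathbf{a}^\top\mathbf{b} = 2 - 2\cos(\mathbf{a},\mathbf{b})
$$

$$
2 - 2\cos(\mathbf{a},\mathbf{b}) \le 0.2 \iff \cos(\mathbf{a},\mathbf{b}) \ge 0.9
$$

Under that assumption, with `query_threshold = 0.2` a new question reuses an existing answer only when its embedding has a cosine similarity of at least 0.9 with the nearest previously answered question.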

### Context Generation

Build the context for a query by concatenating its nearest-neighbour sentences, and map the start index of a predicted answer back to the id of the paragraph that contains it.

In [None]:
def get_context(sents, para_ids, nearest_neighbours, distances):
    """Generate the context for a given query and store the para_id for
    each sentence"""
    context = ""
    context_para_ids, sent_length = [], []
    for sent_id, dist in zip(nearest_neighbours, distances):
        if dist > distance_threshold*distances[0]:
            break
        context += sents[sent_id] + ' '
        context_para_ids.append(para_ids[sent_id])
        sent_length.append(len(sents[sent_id]))
        if len(context.split()) >= context_length_threshold:
            break
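    # Turn the lengths into cumulative offsets: sent_length[i] becomes the index
    # of the space that follows sentence i in the context string
    offset = -1
    for i in range(len(sent_length)):
        offset += sent_length[i] + 1
        sent_length[i] = offset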
    return context.strip(), context_para_ids, sent_length


def para_id_retriever(start_idx, sent_length, context_para_ids):
    """Given start index of the answer, return the id of the paragraph
    in which the answer belongs"""
    if start_idx == -1:
        return -1
    for j in range(len(sent_length)):
        if start_idx <= sent_length[j]:
            return context_para_ids[j]
    return context_para_ids[-1]
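`get_context` turns sentence lengths into cumulative end offsets, and `para_id_retriever` scans them linearly. Because the offsets are sorted, the same lookup can be done with binary search: `bisect_left` finds the first sentence whose end offset is at or after the answer's start index, which is the condition `start_idx <= sent_length[j]` tests. A minimal sketch; the helper names and the two toy sentences are hypothetical.

```python
import bisect

def sentence_end_offsets(sentences):
    """Hypothetical helper: index of the space that follows each sentence in
    ' '.join(sentences) + ' ', matching the running offset in get_context()."""
    offsets, pos = [], -1
    for s in sentences:
        pos += len(s) + 1
        offsets.append(pos)
    return offsets

def para_id_for_start(start_idx, offsets, para_ids):
    """Hypothetical helper: binary-search version of para_id_retriever()."""
    if start_idx == -1:  # no answer predicted
        return -1
    j = bisect.bisect_left(offsets, start_idx)  # first j with offsets[j] >= start_idx
    return para_ids[min(j, len(para_ids) - 1)]  # past the end: use the last sentence

example_sentences = ['Paris is in France.', 'Berlin is in Germany.']
example_ctx_para_ids = [3, 5]
example_context = ' '.join(example_sentences)
example_offsets = sentence_end_offsets(example_sentences)
answer_start = example_context.index('Berlin')
print(example_offsets)  # [19, 41]
print(para_id_for_start(answer_start, example_offsets, example_ctx_para_ids))  # 5
```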

### Load fine-tuned QA models

Given a theme, load the corresponding fine-tuned QA model (falling back to the generic model) and wrap it in a question-answering pipeline.

In [None]:
def download_fine_tuned_models():
    """Download and unzip cluster-wise fine-tuned QA models"""
    urls = [
        ("1-7XfPhjfmUo8xz0iqmFHusbZ74q-SS3A", "zipped_0_11.tar.gz"),
        ("1-BIhfqK992YZW1eWiG5yOCLiX8vOyrNI", "zipped_12_22.tar.gz"),
        ("1-B8b2_s9i2pwTn7EPgzMg50nNM6Dp4B-", "zipped_23_34.tar.gz"),
        ("1-KDxa6wWMGqrDR7ZJq-bYSyWaa_Zsikq", "zipped_35_42.tar.gz")
    ]
    for url, filename in urls:
        if not os.path.exists(filename):
            link = f"https://drive.google.com/u/1/uc?id={url}&export=download"
            gdown.download(link, quiet=True, use_cookies=False)
            with tarfile.open(filename, 'r') as tar:
                tar.extractall()
            # os.remove(filename)

def download_generic_model():
    """Export the electra-base SQuAD 2.0 model to ONNX and optimize it with ONNX Runtime"""
    model_id = 'PremalMatalia/electra-base-best-squad2'
    save_path = "/content/models/generic_model/"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    ort_model = ORTModelForQuestionAnswering.from_pretrained(
        model_id, from_transformers=True
    )
    optimizer = ORTOptimizer.from_pretrained(ort_model)
    # optimization_level=99 enables all ONNX Runtime graph optimizations
    optimization_config = OptimizationConfig(optimization_level=99)
    optimizer.optimize(save_dir=save_path, optimization_config=optimization_config)
    # Save the tokenizer next to the optimized model, since
    # load_qa_model_pipeline() loads both from the same directory
    tokenizer.save_pretrained(save_path)

In [None]:
def load_models_mapping():
    """Load the theme -> cluster mapping and its inverse (cluster -> list of themes)"""
    theme_to_cluster = {}
    cluster_to_themes = {}
    if not os.path.exists("clusters.json"):
        file_url = "https://drive.google.com/file/d/1P6dp7f2m67-iPaUbaNZiDYTmTH7Mw9ec/view?usp=share_link"
        gdown.download(url=file_url, output='clusters.json', quiet=False, fuzzy=True)
    with open('clusters.json') as fo:
        cluster_map = json.load(fo)
    for cluster, themes in cluster_map.items():
        cluster = int(cluster)
        if cluster not in cluster_to_themes:
            cluster_to_themes[cluster] = []
        for theme in themes:
            theme_to_cluster[theme] = cluster
            cluster_to_themes[cluster].append(theme)
    return theme_to_cluster, cluster_to_themes


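def load_qa_model_pipeline(model_path):
    """Load the optimized ONNX QA model stored at model_path and wrap it in a
    question-answering pipeline"""
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = None
    # Retry up to 5 times in case loading the ONNX model fails intermittently
    for _ in range(5):
        try:
            model = ORTModelForQuestionAnswering.from_pretrained(
                model_path, file_name="model_optimized.onnx"
            )
        except Exception:
            continue
        else:
            break
    if model is None:
        raise RuntimeError(f"Could not load the ONNX model from {model_path}")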
    optimum_qa = pipeline(
        task = 'question-answering', model=model,
        tokenizer=tokenizer, handle_impossible_answer=True
    )
    return optimum_qa

## Execution

In [None]:
# We do not use the fine-tuned QA models: fine-tuning on a single cluster
# takes 3 hours on CPU, so training all 43 clusters (43 x 3 = 129 hours)
# was not possible within 12 hours.

In [None]:
# Download and optimize the generic QA model, and download the sentence encoder
# (the fine-tuned QA models are not used; see the note above)
# download_fine_tuned_models()
download_generic_model()
sentence_encoder = load_encoder()
theme_to_cluster, cluster_to_themes = load_models_mapping()

# Load the existing QA pairs and paragraphs theme-wise and pre-process them
paras, solved_ques = load_existing_data()
store_faiss_indices(paras, solved_ques, sentence_encoder)

In [None]:
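# Parameters for context generation
k = 10                          # number of nearest sentences retrieved per question
query_threshold = 0.2           # max squared L2 distance to reuse a previously answered question
distance_threshold = 2          # keep sentences within 2x the distance of the closest sentence
context_length_threshold = 205  # stop adding sentences once the context has >= 205 words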

In [None]:
def get_theme_model(theme):
    """Load theme model if available, otherwise use generic model"""
    if theme in theme_to_cluster:
        cluster = theme_to_cluster[theme]
        model_path = f'/content/models/electra-base-best-squad2-finetuned-squad-{cluster}'
        if os.path.exists(model_path):
            return load_qa_model_pipeline(model_path)
    model_path = '/content/models/generic_model'
    return load_qa_model_pipeline(model_path)

In [None]:
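def pred_theme_ans(questions, theme_model, pred_out):
    """Predict answers for all questions of one theme and append them to pred_out.
    Reuses the answer of a near-duplicate solved question when one exists;
    otherwise builds a context from the nearest sentences and runs the QA model."""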
    ann_inference_time, qna_inference_time = 0., 0.
    theme = questions[0]["theme"]
    solved_queries_exists = False
    if theme in solved_ques:
        solved_queries = solved_ques[theme]
        solved_queries_exists = True
    print(f'Theme: {theme}')

    # Nearest Neighbour Search
    start_time = time.time()
    ques_list = [q['question'] for q in questions]
    ques_embed = get_embeddings(ques_list, sentence_encoder)
    if solved_queries_exists:
        D_ques, I_ques = get_nearest_queries(ques_embed, theme)
    D_sents, I_sents = get_nearest_sentences(ques_embed, theme)
    ann_inference_time = (time.time() - start_time)*1000.

    # QA Model Prediction
    start_time = time.time()
    for i in tqdm(range(len(questions))):
        q = questions[i]
        # Check previously answered queries
        if solved_queries_exists:
            found, ans = search_previously_answered_queries(
                q["id"], D_ques[i,0], I_ques[i,0], solved_queries
            )
            if found:
                pred_out.append(ans)
                continue

        # Context Generation
        context, context_para_ids, sent_length = get_context(
            paras[theme]['sentences'], paras[theme]['id'], I_sents[i], D_sents[i]
        )
        # Answer Prediction and Paragraph Retrieval
        prediction = theme_model(question=q['question'], context=context)
        ans = {
            "question_id": q['id'],
            "answers": prediction['answer'],
            "paragraph_id": -1
        }
        if prediction['answer'] != "":
            ans["paragraph_id"] = para_id_retriever(
                prediction['start'], sent_length, context_para_ids
            )
        pred_out.append(ans)

    # Print Inference Time
    qna_inference_time = (time.time() - start_time)*1000.
    print(
        f'Avg. ANN IT = {round(ann_inference_time/len(questions), 2)} ms, ' +
        f'Avg. QnA IT = {round(qna_inference_time/len(questions),2)} ms\n'
    )

In [None]:
# NOT allowed to make changes. 

# All theme prediction.
questions = json.loads(pd.read_csv("sample_input_question.csv").to_json(orient="records"))
theme_intervals = json.loads(pd.read_csv("sample_theme_interval.csv").to_json(orient="records"))
pred_out = []
theme_inf_time = {}
for theme_interval in theme_intervals:
    theme_ques = questions[int(theme_interval["start"]) - 1: int(theme_interval["end"])]
    theme = theme_ques[0]["theme"]
    # Load model fine-tuned for this theme.
    theme_model = get_theme_model(theme)
    execution_time = timeit.timeit(lambda: pred_theme_ans(theme_ques, theme_model, pred_out), number=1)
    theme_inf_time[theme_interval["theme"]] = execution_time * 1000 # in milliseconds.
pred_df = pd.DataFrame.from_records(pred_out)
pred_df.fillna(value='', inplace=True)
# Write prediction to a CSV file. Teams are required to submit this csv file.
pred_df.to_csv('output_prediction.csv', index=False)

Theme: IPod


100%|██████████| 222/222 [00:07<00:00, 28.98it/s]


Avg. ANN IT = 59.25 ms, Avg. QnA IT = 34.54 ms

Theme: 2008_Sichuan_earthquake


100%|██████████| 192/192 [01:08<00:00,  2.82it/s]


Avg. ANN IT = 45.06 ms, Avg. QnA IT = 354.49 ms

Theme: Wayback_Machine


100%|██████████| 81/81 [01:01<00:00,  1.32it/s]


Avg. ANN IT = 61.12 ms, Avg. QnA IT = 758.44 ms

Theme: Canadian_Armed_Forces


100%|██████████| 133/133 [01:23<00:00,  1.60it/s]


Avg. ANN IT = 52.69 ms, Avg. QnA IT = 626.6 ms

Theme: Cardinal_(Catholicism)


100%|██████████| 110/110 [01:24<00:00,  1.29it/s]


Avg. ANN IT = 54.72 ms, Avg. QnA IT = 772.69 ms

Theme: Human_Development_Index


100%|██████████| 57/57 [00:28<00:00,  2.02it/s]


Avg. ANN IT = 58.73 ms, Avg. QnA IT = 495.55 ms

Theme: Heresy


100%|██████████| 68/68 [00:52<00:00,  1.30it/s]


Avg. ANN IT = 53.53 ms, Avg. QnA IT = 768.17 ms

Theme: Warsaw_Pact


100%|██████████| 46/46 [00:37<00:00,  1.24it/s]


Avg. ANN IT = 53.72 ms, Avg. QnA IT = 807.13 ms

Theme: Materialism


100%|██████████| 68/68 [00:53<00:00,  1.26it/s]


Avg. ANN IT = 62.13 ms, Avg. QnA IT = 791.61 ms

Theme: Pub


100%|██████████| 102/102 [00:36<00:00,  2.81it/s]


Avg. ANN IT = 69.53 ms, Avg. QnA IT = 355.33 ms

Theme: Web_browser


100%|██████████| 66/66 [00:44<00:00,  1.48it/s]


Avg. ANN IT = 69.52 ms, Avg. QnA IT = 676.55 ms

Theme: Catalan_language


100%|██████████| 110/110 [00:48<00:00,  2.27it/s]


Avg. ANN IT = 42.21 ms, Avg. QnA IT = 440.4 ms

Theme: Paper


100%|██████████| 117/117 [01:09<00:00,  1.69it/s]


Avg. ANN IT = 63.02 ms, Avg. QnA IT = 590.93 ms

Theme: Adult_contemporary_music


100%|██████████| 73/73 [00:32<00:00,  2.25it/s]


Avg. ANN IT = 83.14 ms, Avg. QnA IT = 445.0 ms

Theme: Nanjing


100%|██████████| 206/206 [02:45<00:00,  1.24it/s]


Avg. ANN IT = 55.1 ms, Avg. QnA IT = 805.52 ms

Theme: Dialect


100%|██████████| 251/251 [03:19<00:00,  1.26it/s]


Avg. ANN IT = 66.96 ms, Avg. QnA IT = 793.93 ms

Theme: Southampton


100%|██████████| 318/318 [03:43<00:00,  1.43it/s]


Avg. ANN IT = 60.45 ms, Avg. QnA IT = 701.35 ms

Theme: The_Times


100%|██████████| 141/141 [01:52<00:00,  1.26it/s]


Avg. ANN IT = 80.39 ms, Avg. QnA IT = 796.83 ms

Theme: Immunology


100%|██████████| 61/61 [00:53<00:00,  1.13it/s]


Avg. ANN IT = 50.56 ms, Avg. QnA IT = 882.69 ms

Theme: Imamah_(Shia_doctrine)


100%|██████████| 48/48 [00:42<00:00,  1.12it/s]


Avg. ANN IT = 68.92 ms, Avg. QnA IT = 889.05 ms

Theme: Grape


100%|██████████| 35/35 [00:26<00:00,  1.32it/s]


Avg. ANN IT = 77.58 ms, Avg. QnA IT = 758.1 ms

Theme: United_States_dollar


100%|██████████| 235/235 [03:21<00:00,  1.16it/s]


Avg. ANN IT = 50.92 ms, Avg. QnA IT = 858.71 ms

Theme: Everton_F.C.


100%|██████████| 158/158 [01:57<00:00,  1.35it/s]


Avg. ANN IT = 62.92 ms, Avg. QnA IT = 741.49 ms

Theme: Hard_rock


100%|██████████| 178/178 [02:46<00:00,  1.07it/s]


Avg. ANN IT = 56.31 ms, Avg. QnA IT = 936.52 ms

Theme: Great_Plains


100%|██████████| 76/76 [00:56<00:00,  1.34it/s]


Avg. ANN IT = 50.74 ms, Avg. QnA IT = 748.93 ms

Theme: Biodiversity


100%|██████████| 194/194 [02:12<00:00,  1.47it/s]


Avg. ANN IT = 53.86 ms, Avg. QnA IT = 681.35 ms

Theme: Federal_Bureau_of_Investigation


100%|██████████| 304/304 [05:31<00:00,  1.09s/it]


Avg. ANN IT = 52.56 ms, Avg. QnA IT = 1091.24 ms

Theme: Mary_(mother_of_Jesus)


100%|██████████| 247/247 [04:24<00:00,  1.07s/it]


Avg. ANN IT = 57.6 ms, Avg. QnA IT = 1072.16 ms

Theme: Unknown


100%|██████████| 11/11 [00:02<00:00,  3.68it/s]


Avg. ANN IT = 63.11 ms, Avg. QnA IT = 272.57 ms

Theme: DevRev


100%|██████████| 9/9 [00:03<00:00,  2.36it/s]

Avg. ANN IT = 38.15 ms, Avg. QnA IT = 423.82 ms






In [None]:
total_inf_time = 0.0
total_queries = 0
for theme_interval in theme_intervals:
    num_queries = int(theme_interval["end"]) - int(theme_interval["start"]) + 1
    exec_time = theme_inf_time[theme_interval["theme"]]
    total_queries += num_queries
    total_inf_time += exec_time
print(f'Average Execution Time: {total_inf_time / total_queries} ms')

Average Execution Time: 775.6525714352819
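The final cell divides the summed per-theme execution time by the total number of questions, so the reported average is weighted by question count: themes with many questions dominate it. Each per-theme time covers only `pred_theme_ans`; model loading in `get_theme_model` happens outside the `timeit` call. The sketch below contrasts this with an unweighted mean of per-theme averages; the function names and numbers are hypothetical, not results from this notebook.

```python
def overall_average_ms(theme_times_ms, theme_query_counts):
    """Question-weighted mean: total time over all themes / total number of questions."""
    return sum(theme_times_ms.values()) / sum(theme_query_counts.values())

def mean_of_theme_averages_ms(theme_times_ms, theme_query_counts):
    """Unweighted mean of the per-theme averages (shown only for contrast)."""
    per_theme = [theme_times_ms[t] / theme_query_counts[t] for t in theme_times_ms]
    return sum(per_theme) / len(per_theme)

# Hypothetical numbers: a small slow theme and a large fast theme
example_times = {'A': 1000.0, 'B': 900.0}
example_counts = {'A': 10, 'B': 90}
print(overall_average_ms(example_times, example_counts))         # 19.0
print(mean_of_theme_averages_ms(example_times, example_counts))  # 55.0
```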
