In [1]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install beir

Collecting beir
  Downloading beir-0.2.3.tar.gz (52 kB)
[K     |████████████████████████████████| 52 kB 1.6 MB/s 
[?25hCollecting sentence-transformers
  Downloading sentence-transformers-2.1.0.tar.gz (78 kB)
[K     |████████████████████████████████| 78 kB 5.7 MB/s 
[?25hCollecting pytrec_eval
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
Collecting faiss_cpu
  Downloading faiss_cpu-1.7.1.post2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.4 MB)
[K     |████████████████████████████████| 8.4 MB 23.1 MB/s 
[?25hCollecting elasticsearch
  Downloading elasticsearch-7.15.2-py2.py3-none-any.whl (379 kB)
[K     |████████████████████████████████| 379 kB 78.4 MB/s 
Collecting tensorflow-text
  Downloading tensorflow_text-2.7.3-cp37-cp37m-manylinux2010_x86_64.whl (4.9 MB)
[K     |████████████████████████████████| 4.9 MB 58.0 MB/s 
Collecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.12.5-py3-none-any.whl (3.1 MB)
[K     |████████████████████████████████| 3.1 

In [2]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install sentence_transformers



In [3]:
# `load_metric` (imported in the imports cell) was removed in datasets 3.0,
# so pin below 3.0; %pip installs into the running kernel's environment.
%pip install "datasets<3.0"

Collecting datasets
  Downloading datasets-1.15.1-py3-none-any.whl (290 kB)
[K     |████████████████████████████████| 290 kB 7.3 MB/s 
Collecting fsspec[http]>=2021.05.0
  Downloading fsspec-2021.11.0-py3-none-any.whl (132 kB)
[K     |████████████████████████████████| 132 kB 87.6 MB/s 
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 66.5 MB/s 
Collecting xxhash
  Downloading xxhash-2.0.2-cp37-cp37m-manylinux2010_x86_64.whl (243 kB)
[K     |████████████████████████████████| 243 kB 90.5 MB/s 
Collecting yarl<2.0,>=1.0
  Downloading yarl-1.7.2-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (271 kB)
[K     |████████████████████████████████| 271 kB 80.1 MB/s 
[?25hCollecting frozenlist>=1.1.1
  Downloading frozenlist-1.2.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_

In [4]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install hnswlib

Collecting hnswlib
  Downloading hnswlib-0.5.2.tar.gz (29 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: hnswlib
  Building wheel for hnswlib (PEP 517) ... [?25l[?25hdone
  Created wheel for hnswlib: filename=hnswlib-0.5.2-cp37-cp37m-linux_x86_64.whl size=1326164 sha256=634cd53460402250b9fcd01a06a75691ad0beef380915195e7afd72db849a4e8
  Stored in directory: /root/.cache/pip/wheels/b4/11/b3/337c4a361b31217d62c3b420ad66fe20d381f1ebb29b046095
Successfully built hnswlib
Installing collected packages: hnswlib
Successfully installed hnswlib-0.5.2


In [5]:
import pandas as pd
import numpy as np

from beir import util
from beir.datasets.data_loader import GenericDataLoader

from datasets import load_dataset, load_metric

from typing import List
from typing import Dict
from typing import Tuple

from sentence_transformers import SentenceTransformer
import sentence_transformers.util

import hnswlib
import pickle

from operator import itemgetter

import time

  from tqdm.autonotebook import tqdm


In [6]:
def create_the_whole_context_dataset(model_name, squad_v2=False, target_size=10000,
                                     oversample=1500, min_length=50, device='cuda',
                                     seed=None):
    """Build a pool of unique text contexts and embed them with a SentenceTransformer.

    The pool starts with every distinct ``context`` from the first 10% of the
    SQuAD (or SQuAD v2) training split. It is then topped up with randomly
    sampled passages from the BEIR ``dbpedia-entity`` test corpus.

    Args:
        model_name: name or path passed to ``SentenceTransformer``.
        squad_v2: load ``squad_v2`` instead of ``squad`` (default: ``squad``).
        target_size: approximate number of contexts wanted in the pool.
        oversample: extra DBpedia ids drawn on top of ``target_size`` to make
            up for draws discarded later (short texts, repeats). The final pool
            size is therefore only approximately ``target_size``.
        min_length: DBpedia passages shorter than this many characters are
            dropped.
        device: device forwarded to ``model.encode``.
        seed: optional seed for the DBpedia sampling. ``None`` keeps using the
            global NumPy RNG, as before.

    Returns:
        A list of ``[text, embedding, i]`` rows, where ``i`` is the row's
        position in the list (used later as the index label).
    """
    # SQuAD contexts (first 10% of the train split).
    train_dataset = load_dataset("squad_v2" if squad_v2 else "squad", split='train[:10%]')

    # DBpedia-entity corpus from the BEIR benchmark.
    dataset = "dbpedia-entity"
    url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
    data_path = util.download_and_unzip(url, "datasets")
    corpus, _queries, _qrels = GenericDataLoader(data_folder=data_path).load(split="test")

    # Ordered list of unique contexts. The companion set makes each membership
    # check O(1) instead of a scan over the whole list.
    list_context = []
    seen = set()

    def _add(text):
        # Append `text` only the first time it is seen, preserving order.
        if text not in seen:
            seen.add(text)
            list_context.append(text)

    for elem in train_dataset:
        _add(elem["context"])

    # Top up with random DBpedia passages. np.random.choice draws with
    # replacement by default and repeats are removed below, hence the margin.
    # Clamp at 0 so a large SQuAD slice cannot produce a negative sample size.
    n_samples = max(0, target_size - len(list_context) + oversample)
    corpus_ids = list(corpus)
    if n_samples > 0 and corpus_ids:
        rng = np.random if seed is None else np.random.RandomState(seed)
        sampled_ids = rng.choice(corpus_ids, n_samples)
        sampled_texts = [corpus[doc_id]["text"] for doc_id in sampled_ids]
        sampled_texts = [text for text in sampled_texts if len(text) >= min_length]
        # np.unique also sorts, so DBpedia passages are appended in sorted order.
        for text in np.unique(sampled_texts):
            _add(text)

    # Embed every context; row i keeps its position as its label.
    model = SentenceTransformer(model_name)
    embeddings = model.encode(list_context, device=device, show_progress_bar=True)
    return [[text, embeddings[i], i] for i, text in enumerate(list_context)]

In [7]:
def create_and_train_model(param1, param2, context_list):
    """Build an hnswlib index (inner-product space) over precomputed embeddings.

    Despite the name, nothing is trained: the index is initialised and every
    embedding is inserted once.

    Args:
        param1: forwarded to ``init_index`` as ``ef_construction``.
        param2: forwarded to ``init_index`` as ``M``.
        context_list: rows of ``[text, embedding, label]``. Column 1 is the
            vector that gets indexed and column 2 is its integer label. The
            vector dimension is taken from the first row.

    Returns:
        The populated ``hnswlib.Index``.

    NOTE(review): with space='ip', ranking follows the raw dot product. If cosine
    similarity is intended, the embeddings should be normalised first. TODO confirm.
    """
    embeddings = list(map(itemgetter(1), context_list))
    labels = list(map(itemgetter(2), context_list))

    index = hnswlib.Index(space='ip', dim=len(embeddings[0]))
    index.init_index(max_elements=len(context_list),
                     ef_construction=param1,
                     M=param2)
    index.add_items(embeddings, labels)
    return index

In [8]:
# Loaded SentenceTransformer models keyed by name, so repeated queries do not
# reload the model from disk each time.
_ENCODER_CACHE = {}


def _get_encoder(model_name):
    """Return an encoder for `model_name`, loading each model name only once.

    An object that already has an ``encode`` method (e.g. a loaded
    SentenceTransformer) is returned unchanged. Strings are always treated as
    names, since ``str`` itself has an ``encode`` method.
    """
    if hasattr(model_name, "encode") and not isinstance(model_name, str):
        return model_name
    if model_name not in _ENCODER_CACHE:
        _ENCODER_CACHE[model_name] = SentenceTransformer(model_name)
    return _ENCODER_CACHE[model_name]


def find_best_context(model_name, knn, question, n, formated_context, device='cuda'):
    """Return the texts of the `n` contexts nearest to `question` in `knn`.

    Args:
        model_name: SentenceTransformer name/path, or an already-loaded model.
            Names are loaded once and cached.
        knn: hnswlib index built over the context embeddings.
        question: query text to embed.
        n: number of neighbours wanted. It is clamped to the number of items
            in the index, and ``n <= 0`` returns ``[]``.
        formated_context: ``[text, embedding, label]`` rows. The lookup assumes
            each label equals the row's position in this list, as produced by
            ``create_the_whole_context_dataset``.
        device: device forwarded to ``encode``.

    Returns:
        List of context texts, nearest first.
    """
    if n <= 0:
        return []

    model = _get_encoder(model_name)
    question_embedding = model.encode(question, device=device, show_progress_bar=False)

    # hnswlib raises when k exceeds the number of stored items, so clamp.
    k = min(n, knn.get_current_count())
    if k <= 0:
        return []

    labels, _distances = knn.knn_query(question_embedding, k=k)
    return [formated_context[int(label)][0] for label in labels[0]]