In [None]:
# Read jsonl file
import json

# Load the training set: one JSON object per line of train.jsonl.
# Whitespace-only lines are skipped because json.loads("") / json.loads("\n") raises.
train = []
with open("train.jsonl", "r", encoding='utf-8') as f:
    for line in f:
        if line.strip():
            train.append(json.loads(line))

# Fail with a clear message instead of an IndexError on train[0] below.
assert train, "train.jsonl contained no records"

# Peek at the first example and the dataset size.
print(train[0]['question'], train[0]['points'], train[0]['article'])
print(len(train))


In [None]:
# Investigate the proportion of test cases where the answer (the article title) appears
# verbatim, case-insensitively, inside the question text.
# The loop variable is `sample`, not `set`, so the builtin `set` stays usable in later cells.
total = len(train)
found = sum(
    sample['article'].lower() in sample['question'].lower()
    for sample in train
)

# Guard the ratio so an empty training set prints 0.0 instead of raising ZeroDivisionError.
print(found, total, found / total if total else 0.0)

In [None]:
# Not as many as I thought, but still a significant amount.

# Perhaps the points are a good indicator of more trivial questions, where the answer is in the question?
# Investigate the distribution of points for both groups.

from statistics import mean
import seaborn as sns
import matplotlib.pyplot as plt

trivial = []
nontrivial = []

# `sample` and `all_points` are used instead of `set` and `all` so the builtins are not shadowed.
for sample in train:
    if sample['article'].lower() in sample['question'].lower():
        trivial.append(sample['points'])
    else:
        nontrivial.append(sample['points'])

all_points = trivial + nontrivial

# statistics.mean raises StatisticsError on an empty list, so report None for an empty group.
print(mean(trivial) if trivial else None, mean(nontrivial) if nontrivial else None)
print(mean(all_points) if all_points else None)

# One titled histogram per group; empty groups are skipped rather than plotted.
for points, label in ((trivial, "answer in question"), (nontrivial, "answer not in question")):
    if not points:
        continue
    grid = sns.displot(points)
    grid.set(title=f"Points distribution: {label}", xlabel="points")
plt.show()

# There is a strong correlation, perhaps we can bias the model towards
# retrieving the answer from the question if the points are low.

In [None]:
# I suspect most of the answers are nouns. Let's use NLP to check this.
!pip install spacy
!python -m spacy download en_core_web_sm


In [None]:
import spacy
nlp = spacy.load("en_core_web_sm")

# Tally how many answers (article titles) contain at least one spaCy noun chunk.
# Answers with no noun chunk are collected in `not_noun_examples` for manual inspection.
is_noun, not_noun = 0, 0
not_noun_examples = []

docs = list(nlp.pipe(example['article'] for example in train))
for doc in docs:
    if any(True for _chunk in doc.noun_chunks):
        is_noun += 1
        continue
    not_noun += 1
    not_noun_examples.append(doc.text)

# Counts, share of answers with a noun chunk, then up to 100 answers without one.
print(is_noun, not_noun, is_noun / (is_noun + not_noun))
print(not_noun_examples[:100])

In [None]:
# More than 85% of the articles contain at least one noun chunk, so we should prioritise nouns in our search.
# Many of the articles without a noun chunk are still noun-like answers that spaCy did not chunk; many of them are years.

In [None]:
# Next, let's process the wikipedia dataset using parquet
!pip install pyarrow

In [None]:
import pyarrow.parquet as pq

# Number of leading articles to index; raise it for a fuller run.
N_ARTICLES = 10000

# Read only the two columns we use and slice the Arrow table before converting,
# so the whole dataset is never materialised as a pandas DataFrame.
# use_pandas_metadata=True keeps any stored pandas index, matching a full-table read.
wikipedia = (
    pq.read_table(
        'train-00000-of-00001.parquet',
        columns=['text', 'title'],
        use_pandas_metadata=True,
    )
    .slice(0, N_ARTICLES)
    .to_pandas()
)[['text', 'title']]

wikipedia.tail()

In [None]:
# A vector database is a natural fit for this retrieval problem, so let's set up Milvus for it.
# Milvus runs in a Docker container set up in the `milvus` folder.

In [None]:
# Connect to the Milvus server.
# Credit to this tutorial by Stephen Collins for information on setting up milvus and text embedding
# https://dev.to/stephenc222/how-to-use-milvus-to-store-and-query-vector-embeddings-5hhl
from pymilvus import connections

def connect_to_milvus(host="localhost", port="19530", alias="default"):
    """Open a pymilvus connection registered under `alias`.

    Args:
        host: Milvus server host; defaults to the local Docker setup.
        port: Milvus server port (string, as pymilvus accepts it).
        alias: connection alias that later pymilvus calls resolve against.

    Raises:
        Exception: whatever `connections.connect` raises is printed and then
            re-raised, so later cells never run without a connection.
    """
    try:
        connections.connect(alias, host=host, port=port)
        print("Connected to Milvus.")
    except Exception as e:
        print(f"Failed to connect to Milvus: {e}")
        raise

connect_to_milvus()

In [None]:
# Set up the schema and (re)create the collection.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection, utility

def create_collection(name, fields, description):
    """Create and return a Milvus collection `name` with `fields`, using strong consistency."""
    schema = CollectionSchema(fields, description)
    collection = Collection(name, schema, consistency_level="Strong")
    return collection

def drop_collection(name):
    """Drop the collection `name` if it exists; do nothing otherwise.

    Constructing `Collection(name)` without a schema fails when the collection
    does not exist, so on a fresh server the unguarded drop would stop the notebook.
    """
    if utility.has_collection(name):
        Collection(name).drop()

# Define fields for our collection
fields = [
    FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
    # NOTE(review): dim must equal the length of the vectors produced by generate_embeddings;
    # 768 is assumed here — confirm against embedding_util.
    FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=768),
    FieldSchema(name="title", dtype=DataType.VARCHAR, max_length=500),
]

# Recreate from scratch so every run of the notebook starts with an empty collection.
drop_collection("wikipedia_simple")
collection = create_collection("wikipedia_simple", fields, "Text embeddings of the simple wikipedia dataset")

In [None]:
import os

from embedding_util import generate_embeddings

# Generate an embedding for each article and cache it, one per line, in embeddings.txt.
# Lines already in the file count as done, so re-running this cell resumes where it left off
# instead of appending a second copy (which would misalign embeddings with `wikipedia` rows).
# NOTE: delete embeddings.txt if the selection of articles in `wikipedia` changes.
EMBEDDINGS_PATH = "embeddings.txt"

already_done = 0
if os.path.exists(EMBEDDINGS_PATH):
    with open(EMBEDDINGS_PATH, "r", encoding='utf-8') as f:
        already_done = sum(1 for _ in f)

# Open the file once rather than once per article; flush after each line so progress
# is persisted even if the loop is interrupted.
with open(EMBEDDINGS_PATH, "a", encoding='utf-8') as f:
    for i, doc in enumerate(wikipedia['text'].iloc[already_done:], start=already_done):
        embedding = generate_embeddings(doc)
        f.write(f"{embedding}\n")
        f.flush()
        print(f"{i}/{len(wikipedia)}")

In [None]:
# Read the cached embeddings back. Each line is the printed form of a list of floats,
# e.g. "[0.1, -0.2, ...]".
# Strip whitespace and brackets instead of slicing fixed offsets: the old `[1:-2]` slice
# silently dropped the last digit of a final line that lacked a trailing newline.
# Blank lines are skipped.
with open("embeddings.txt", "r", encoding='utf-8') as f:
    embeddings = [
        [float(value) for value in line.strip().strip("[]").split(",")]
        for line in f
        if line.strip()
    ]

assert embeddings, "embeddings.txt contained no embeddings"

# Show how many vectors were read and a few values of the first one rather than all of it.
print(len(embeddings), embeddings[0][:5])

In [None]:
# Write into Milvus. Columns are given in schema order: pk, embeddings, title.
# Row i of every column must describe the same article, so check the lengths line up
# before sending anything (e.g. a stale or duplicated embeddings.txt).
assert len(embeddings) == len(wikipedia), (
    f"{len(embeddings)} embeddings for {len(wikipedia)} articles; regenerate embeddings.txt"
)

entities = [
    [str(i) for i in range(len(wikipedia))],
    embeddings,
    [str(title) for title in wikipedia['title']],
]

insert_result = collection.insert(entities)
print(insert_result)

In [None]:
# Build a vector index on the embedding field.
def create_index(collection, field_name, index_type, metric_type, params):
    """Create an index on `field_name` of `collection`.

    `index_type`, `metric_type` and `params` are packed into the index-parameter
    dict that is passed to `collection.create_index`.
    """
    collection.create_index(
        field_name,
        dict(index_type=index_type, metric_type=metric_type, params=params),
    )

create_index(
    collection,
    field_name="embeddings",
    index_type="IVF_FLAT",
    metric_type="L2",
    params={"nlist": 128},
)

In [None]:
def search_and_query(collection, search_vectors, search_field, search_params):
    """Return the title of the top hit for the first query vector, or None if there is none.

    Calls `collection.load()`, then searches `search_field` of `collection` with
    `search_params` (up to 3 hits per query, returning the "title" field). Only the
    best hit of `search_vectors[0]` is used. When that query has no hits (e.g. an
    empty collection), this returns None instead of raising IndexError.
    """
    collection.load()
    result = collection.search(search_vectors, search_field, search_params, limit=3, output_fields=["title"])
    hits = result[0]
    if len(hits) == 0:
        return None
    return hits[0].entity.get("title")

# Test search
query = "how do living organisms in a natural environment respond to changes in weather or climate?"
query_vector = generate_embeddings(query)
search_and_query(collection, [query_vector], "embeddings", {"metric_type": "L2", "params": {"nprobe": 10}})

# Correctly returns "Environment"!

In [None]:
# Test the performance of our model: a question earns its points when the top retrieved
# title equals the expected article, compared case-insensitively.
score = 0
totalScore = 0

search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
for sample in train:
    query_vector = generate_embeddings(sample['question'])
    result = search_and_query(collection, [query_vector], "embeddings", search_params)
    # Use ==, not `is`: `is` compares object identity, and .lower() builds new string
    # objects, so the identity check was almost always False and the score stayed 0.
    if result is not None and result.lower() == sample['article'].lower():
        score += sample['points']
    totalScore += sample['points']

print(f"Our model scored {score}/{totalScore} points on the training set.")