In [7]:
import configparser
import itertools

import cohere
import pandas as pd
import torch
from datasets import load_dataset

# Read the Cohere API key from a local config file so the secret is never
# written into the notebook itself.
config = configparser.ConfigParser()
if not config.read("./config.ini"):
    # ConfigParser.read() silently skips files it cannot open; fail here with a
    # clear message instead of a confusing NoSectionError on the next line.
    raise FileNotFoundError(
        "./config.ini not found; it must contain a [cohere] section with an api_key entry"
    )
api_key = config.get('cohere', 'api_key')

# Add your cohere API key from www.cohere.com
co = cohere.Client(api_key)

# Load at most max_docs documents; each record also carries its embedding
# under the 'emb' key.
max_docs = 2000
docs_stream = load_dataset("Cohere/wikipedia-22-12-simple-embeddings", split="train", streaming=True)

# islice stops pulling from the stream once max_docs records have been read,
# and correctly yields nothing when max_docs is 0.
docs = list(itertools.islice(docs_stream, max_docs))


In [9]:
# Turn the list of document records into a pandas DataFrame, one row per record.
# Example record (text and embedding abbreviated):
#   {'id': 0,
#    'title': '24-hour clock',
#    'text': 'The 24-hour clock is a way of telling the time in which the day runs from midnight to midnight ...',
#    'url': 'https://simple.wikipedia.org/wiki?curid=9985',
#    'wiki_id': 9985,
#    'views': 2450.62548828125,
#    'paragraph_id': 0,
#    'langs': 30,
#    'emb': [0.07711287587881088, ...]}

df = pd.DataFrame(docs)
df

Unnamed: 0,id,title,text,url,wiki_id,views,paragraph_id,langs,emb
0,0,24-hour clock,The 24-hour clock is a way of telling the time...,https://simple.wikipedia.org/wiki?curid=9985,9985,2450.625488,0,30,"[0.07711287587881088, 0.3197174072265625, -0.2..."
1,1,24-hour clock,A time in the 24-hour clock is written in the ...,https://simple.wikipedia.org/wiki?curid=9985,9985,2450.625488,1,30,"[0.19612890481948853, 0.5142669677734375, 0.03..."
2,2,24-hour clock,"However, the US military prefers not to say 24...",https://simple.wikipedia.org/wiki?curid=9985,9985,2450.625488,2,30,"[0.1391918957233429, 0.17759686708450317, -0.1..."
3,3,24-hour clock,"24-hour clock time is used in computers, milit...",https://simple.wikipedia.org/wiki?curid=9985,9985,2450.625488,3,30,"[0.1279686838388443, 0.06708071380853653, -0.0..."
4,4,24-hour clock,"In railway timetables 24:00 means the ""end"" of...",https://simple.wikipedia.org/wiki?curid=9985,9985,2450.625488,4,30,"[0.0753360167145729, 0.3530837893486023, -0.08..."
...,...,...,...,...,...,...,...,...,...
1995,1995,Christmas,"Many towns hold Christmas parades, street ente...",https://simple.wikipedia.org/wiki?curid=3317,3317,745.546997,33,212,"[-0.026805002242326736, 0.2564919888973236, -0..."
1996,1996,Christmas,A traditional part of Christmas is the theatre...,https://simple.wikipedia.org/wiki?curid=3317,3317,745.546997,34,212,"[0.1477964222431183, 0.26868319511413574, -0.2..."
1997,1997,Christmas,"Because many people feel very lonely, hungry a...",https://simple.wikipedia.org/wiki?curid=3317,3317,745.546997,35,212,"[0.338507741689682, 0.043504685163497925, -0.1..."
1998,1998,Christmas,Family celebrations are often very different f...,https://simple.wikipedia.org/wiki?curid=3317,3317,745.546997,36,212,"[0.09053482860326767, 0.0269009992480278, -0.1..."


In [None]:

# Get the query, then embed it
query = 'Who founded Wikipedia'
response = co.embed(texts=[query], model='multilingual-22-12')
# presumably response.embeddings holds one vector per input text, giving a
# (1, dim) tensor here — TODO confirm against the Cohere client version in use
query_embedding = torch.tensor(response.embeddings)

# Build the document embedding matrix (one row per loaded record) in this cell,
# so the search does not depend on a variable left over from an earlier kernel
# session.
if not docs:
    raise ValueError("docs is empty; load documents before running the search")
doc_embeddings = torch.tensor([doc['emb'] for doc in docs])

# Compute dot score between query embedding and document embeddings
dot_scores = torch.mm(query_embedding, doc_embeddings.transpose(0, 1))
# Clamp k so topk does not raise when fewer than 3 documents were loaded.
top_k = torch.topk(dot_scores, k=min(3, len(docs)))

# Print results
print("Query:", query)
for doc_id in top_k.indices[0].tolist():
    print(docs[doc_id]['title'])
    print(docs[doc_id]['text'], "\n")