In [1]:
!pip install -U cohere astrapy datasets python-dotenv

Defaulting to user installation because normal site-packages is not writeable


In [2]:
import cohere
import os
from dotenv import load_dotenv
from astrapy.db import AstraDB, AstraDBCollection
from astrapy.ops import AstraDBOps
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Read credentials from the environment (optionally populated from a .env file)
# and build the Astra DB and Cohere clients used by the rest of the notebook.
load_dotenv()

token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
api_endpoint = os.getenv("ASTRA_DB_API_ENDPOINT")
cohere_key = os.getenv("COHERE_API_KEY")

# os.getenv returns None for unset variables; fail here with a readable
# message instead of passing None on to the client constructors.
_missing = [
    name
    for name, value in (
        ("ASTRA_DB_APPLICATION_TOKEN", token),
        ("ASTRA_DB_API_ENDPOINT", api_endpoint),
        ("COHERE_API_KEY", cohere_key),
    )
    if not value
]
if _missing:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing)
    )

astra_db = AstraDB(token=token, api_endpoint=api_endpoint)

# The name `cohere` is rebound from the module to a Client instance below, and
# later cells call `cohere.embed(...)` on that instance. Importing the module
# under an alias first keeps this cell re-runnable: on a second run `cohere`
# is already a Client, so `cohere.Client` would raise AttributeError.
import cohere as cohere_sdk
cohere = cohere_sdk.Client(cohere_key)

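### Embedding Models
The `dimension` used when creating the Astra DB collection must equal the output size of the embedding model; this notebook uses `embed-english-v3.0` (1024 dimensions).
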
| Model Name                    | Embedding Dimensions |
|-------------------------------|----------------------|
| embed-english-v3.0            | 1024                 |
| embed-multilingual-v3.0       | 1024                 |
| embed-english-light-v3.0      | 384                  |
| embed-multilingual-light-v3.0 | 384                  |
| embed-english-v2.0            | 4096                 |
| embed-english-light-v2.0      | 1024                 |
| embed-multilingual-v2.0       | 768                  |

In [4]:
# Create the vector collection and open a handle to it.
# The name and vector size are defined once so the two calls cannot drift apart.
COLLECTION_NAME = "cohere"
EMBEDDING_DIMENSION = 1024  # output size of embed-english-v3.0 (see table above)

astra_db.create_collection(
    collection_name=COLLECTION_NAME, dimension=EMBEDDING_DIMENSION
)
collection = AstraDBCollection(
    collection_name=COLLECTION_NAME, astra_db=astra_db
)

In [5]:
# Load only the first 2000 rows of the SQuAD training split to keep the demo small.
squad = load_dataset(
    "squad",
    split="train[:2000]",
)

In [7]:
# Preview the first 20 questions of the loaded subset.
squad["question"][:20]


['To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'What is in front of the Notre Dame Main Building?',
 'The Basilica of the Sacred heart at Notre Dame is beside to which structure?',
 'What is the Grotto at Notre Dame?',
 'What sits on top of the Main Building at Notre Dame?',
 'When did the Scholastic Magazine of Notre dame begin publishing?',
 "How often is Notre Dame's the Juggler published?",
 'What is the daily student paper at Notre Dame called?',
 'How many student news papers are found at Notre Dame?',
 'In what year did the student paper Common Sense begin publication at Notre Dame?',
 'Where is the headquarters of the Congregation of the Holy Cross?',
 'What is the primary seminary of the Congregation of the Holy Cross?',
 'What is the oldest structure at Notre Dame?',
 'What individuals live at Fatima House at Notre Dame?',
 'Which prize did Frederick Buechner create?',
 'How many BS level degrees are offered in the College of Engineering at Notre

In [8]:
# Embed every question as a document to be stored.
# NOTE: `cohere` here is the Client instance created earlier, not the module.
embeddings = cohere.embed(
    model="embed-english-v3.0",
    texts=squad["question"],
    input_type="search_document",
    truncate="END",
).embeddings

# Show the vector length of the first embedding.
len(embeddings[0])

1024

In [9]:
# Pair each dataset row with its embedding: every document is the row's own
# fields plus a "$vector" key holding the embedding at the same position.
to_insert = [
    {**row, "$vector": embeddings[position]}
    for position, row in enumerate(squad)
]


To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?
What is in front of the Notre Dame Main Building?
The Basilica of the Sacred heart at Notre Dame is beside to which structure?
What is the Grotto at Notre Dame?
What sits on top of the Main Building at Notre Dame?
When did the Scholastic Magazine of Notre dame begin publishing?
How often is Notre Dame's the Juggler published?
What is the daily student paper at Notre Dame called?
How many student news papers are found at Notre Dame?
In what year did the student paper Common Sense begin publication at Notre Dame?
Where is the headquarters of the Congregation of the Holy Cross?
What is the primary seminary of the Congregation of the Holy Cross?
What is the oldest structure at Notre Dame?
What individuals live at Fatima House at Notre Dame?
Which prize did Frederick Buechner create?
How many BS level degrees are offered in the College of Engineering at Notre Dame?
In what year was the College of Engineering at Notre Da

In [10]:
# Insert the documents in fixed-size batches, printing the start offset of each
# batch as a progress indicator.
# NOTE(review): 20 presumably matches a per-request document limit of the
# service -- confirm before raising it.
batch_size = 20
for i in range(0, len(to_insert), batch_size):
    # Slicing past the end is safe: the final batch is simply shorter.
    res = collection.insert_many(documents=to_insert[i:i + batch_size])
    print(i)

0
20
40
60
80
100
120
140
160
180
200
220
240
260
280
300
320
340
360
380
400
420
440
460
480
500
520
540
560
580
600
620
640
660
680
700
720
740
760
780
800
820
840
860
880
900
920
940
960
980
1000
1020
1040
1060
1080
1100
1120
1140
1160
1180
1200
1220
1240
1260
1280
1300
1320
1340
1360
1380
1400
1420
1440
1460
1480
1500
1520
1540
1560
1580
1600
1620
1640
1660
1680
1700
1720
1740
1760
1780
1800
1820
1840
1860
1880
1900
1920
1940
1960
1980


In [11]:
user_query = "What's in front of Notre Dame?"

# Embed the query with the same model as the stored documents, but with
# input_type "search_query" rather than "search_document".
embedded_query = cohere.embed(
    model="embed-english-v3.0",
    texts=[user_query],
    input_type="search_query",
    truncate="END",
).embeddings[0]

In [12]:
# Run a vector search with the embedded query, asking for up to 50 results.
results = collection.vector_find(
    embedded_query,
    limit=50,
)

In [13]:
# Print the query, then one line per result with its answer text and
# similarity score, in the order the search returned them.
print(f"Query: {user_query}")
print("Answers:")
for rank, hit in enumerate(results):
    answer_text = hit['answers']['text']
    similarity = hit['$similarity']
    print(f"\t Answer {rank}: {answer_text}, Score: {similarity}")

Query: What's in front of Notre Dame?
Answers:
	 Answer 0: ['a copper statue of Christ'], Score: 0.88295245
	 Answer 1: ['a golden statue of the Virgin Mary'], Score: 0.7785491
	 Answer 2: ['the City of South Bend'], Score: 0.7664312
	 Answer 3: ['the Main Building'], Score: 0.7474512
	 Answer 4: ['Old College'], Score: 0.730406
	 Answer 5: ['Basilica of the Sacred Heart'], Score: 0.726831
	 Answer 6: ['a Marian place of prayer and reflection'], Score: 0.7251403
	 Answer 7: ['the Virgin Mary'], Score: 0.72113997
	 Answer 8: ['Theodore M. Hesburgh Library'], Score: 0.7193226
	 Answer 9: ['Frank Eck Stadium'], Score: 0.71513194
	 Answer 10: ['two-story banner'], Score: 0.7143492
	 Answer 11: ['an early wind tunnel'], Score: 0.71360976
	 Answer 12: ['oldest university band in continuous existence in the United States'], Score: 0.712774
	 Answer 13: ['the first floor of Stanford Hall'], Score: 0.71262664
	 Answer 14: ['onward to victory'], Score: 0.71261424
	 Answer 15: ['Catholic research