In [16]:
import sqlite3
import numpy as np
from tqdm import tqdm
from scipy.spatial.distance import cosine
from angle_emb import AnglE
import os
import configparser
config = configparser.ConfigParser()
config.read('config.ini')
OPENAI_API_KEY = config['API_KEYS']['OPENAI_API_KEY']

angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='cls')

vector_db_name = "raw/full_card_vector_database.db"

In [18]:
from openai import OpenAI
def semantic_search(query, vector_db_name, number_chunks = 5):
    conn = sqlite3.connect(vector_db_name)
    c = conn.cursor()

    query_embedding = angle.encode(query, to_numpy=True).flatten()

    c.execute("SELECT id, name, card_text, vector FROM vectordb")
    rows = c.fetchall()

    similarities = []
    for row in rows :
        id_, name, card_text, vector_bytes = row
        stored_embedding = np.frombuffer(vector_bytes, dtype=np.float32).flatten()
        sim = 1 - cosine(query_embedding, stored_embedding)
        similarities.append((id_, card_text, sim))

    similarities.sort(key=lambda x: x[2], reverse=True)  # Sort by similarity scores in descending order

    top_matches = similarities[:number_chunks]

    conn.close()

    return [(match[1], match[2]) for match in top_matches]

def rag_query(query, RAG=True):
    prompt_rag = ""
    if RAG:
        chunks = semantic_search(query, vector_db_name, 3)

        prompt_rag = ""
        for chunk in chunks:
            prompt_rag += chunk[0]+"\n\n"

        prompt = f"""[INST]
        Given the following card data, provide me with the exact text of {query} in the format of :
        \nname: \nmana_cost: \ncmc: \ntype_line: \noracle_text: \npower: \ntoughness: \ncolors: \ncolor_identity: \nkeywords:

        \n{prompt_rag}

        Use only the data in the provided chunks above. [/INST]
        """
    else:
        prompt = f"""[INST]
        Provide me with the exact text of {query} in the format of :
        \nname: \nmana_cost: \ncmc: \ntype_line: \noracle_text: \npower: \ntoughness: \ncolors: \ncolor_identity: \nkeywords: [/INST]
        """

    
    client = OpenAI(api_key = OPENAI_API_KEY)

    response = client.chat.completions.create(
      model="gpt-3.5-turbo",
      messages=[
        {"role": "user", "content": prompt},
      ]
    )
    return response.choices[0].message.content

In [28]:
query = "Braids, Arisen Nightmare"
RAG = True

print(prompt)

[INST]
    Given the following card data, provide me with the exact text of Braids, Arisen Nightmare in the format of :
    
name: 
mana_cost: 
cmc: 
type_line: 
oracle_text: 
power: 
toughness: 
colors: 
color_identity: 
keywords:

    
name: Wrenn and Realmbreaker
mana_cost: {1}{G}{G}
cmc: 3.0
type_line: Legendary Planeswalker — Wrenn
oracle_text: Lands you control have "{T}: Add one mana of any color."
+1: Up to one target land you control becomes a 3/3 Elemental creature with vigilance, hexproof, and haste until your next turn. It's still a land.
−2: Mill three cards. You may put a permanent card from among the milled cards into your hand.
−7: You get an emblem with "You may play lands and cast permanent spells from your graveyard."
power: nan
toughness: nan
colors: ['G']
color_identity: ['G']
keywords: ['Mill']

name: Braids's Frightful Return
mana_cost: {2}{B}
cmc: 3.0
type_line: Enchantment — Saga
oracle_text: Read ahead (Choose a chapter and start with that many lore counters. 