In [1]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import sqlite3
from transformers import AutoTokenizer
access_token = "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"

def semantic_search(query, database='raw/tokenized_data_llama2.db'):
    conn = sqlite3.connect(database)
    cursor = conn.cursor()

    # Retrieve all chunk texts from the database
    cursor.execute("SELECT chunk_text FROM tokenized_chunks")
    rows = cursor.fetchall()
    chunk_texts = [row[0] for row in rows]

    conn.close()

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token = access_token)


    # Vectorize query and chunk texts using TF-IDF
    vectorizer = TfidfVectorizer()
    tfidf_matrix = vectorizer.fit_transform([query] + chunk_texts)

    cosine_similarities = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:]).flatten()
    sorted_chunk_indices = cosine_similarities.argsort()[::-1]
    top_chunk_indices = sorted_chunk_indices[:2]

    top_matching_chunks = []
    for i in top_chunk_indices:
        # Tokenize the chunk text
        tokens = tokenizer.tokenize(chunk_texts[i])
        # Count the number of tokens
        num_tokens = len(tokens)
        # Store chunk text along with cosine similarity and number of tokens
        top_matching_chunks.append((cosine_similarities[i], chunk_texts[i], num_tokens))

    conn.close()
    return top_matching_chunks[0]

# query = "Elesh Norn, Grand Cenobite"
# matching_chunks = semantic_search(query)

# for i, (cosine_similarity_score, chunk, num_tokens) in enumerate(matching_chunks):
#     # Remove <s> tags from the chunk text
#     clean_chunk = chunk.replace('<s>', '')
#     print(f"Matching Chunk {i + 1}: (Cosine Similarity: {cosine_similarity_score}, Tokens: {num_tokens})")
#     print(clean_chunk)
#     print("-o"*60)

In [2]:
import pandas as pd

RAG_eval_data_raw = "RAG_MTG_Test.csv"
RAG_eval_data = pd.read_csv(RAG_eval_data_raw)
CARD_INFO_DF = pd.read_csv("raw/filtered_oracle_database.csv")
CARD_EVAL_DF = pd.read_csv("RAG_MTG_Test.csv")

In [3]:
FORMATTED_DATA = []

for name in CARD_EVAL_DF["Card Name"].tolist():
    card_info = CARD_INFO_DF[CARD_INFO_DF["name"] == name]

    formatted_info = ""
    for index, row in card_info.iterrows():
        formatted_row = ""
        for column_name, value in row.items():
            formatted_row += f"{column_name}: {value}\n"
        formatted_info += formatted_row.strip() + "\n"
        
        FORMATTED_DATA.append(formatted_info)

In [4]:
import torch
from transformers import AutoModelForCausalLM

# Load pre-trained model and tokenizer
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token = access_token)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token = access_token)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
def run_card_model_inference(model, tokenizer, card_name, card_information):
    
    raw_text_prompt = f"I want you to tell me the following information about {card_name} in this exact same format:\nname:\nmana_cost:\ncmc:\ntype_line:\noracle_text:\npower:\ntoughness:\ncolors:\ncolor_identity:\nkeywords: \nDo not provide any other context for this, all I want is to know the listed parameters for this card."
    
    raw_chunk_data = semantic_search(card_name)[1].replace('<s>', '')
    
    RAG_text_prompt = raw_chunk_data + "\n" + raw_text_prompt
    input_ids = tokenizer.encode(raw_text_prompt, return_tensors="pt", max_length=1024, truncation=True)
    
    # Generate text
    output = model.generate(input_ids, min_new_tokens = 70, max_length=2048, pad_token_id=tokenizer.eos_token_id)

    # Decode and display the generated text
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text

In [12]:
test_card_name = CARD_EVAL_DF["Card Name"].iloc[0]
test_card_info = FORMATTED_DATA[0]
generated_card_info = run_card_model_inference(model, tokenizer, test_card_name, test_card_info)

In [13]:
print(generated_card_info)

I want you to tell me the following information about Goldnight Redeemer in this exact same format:
name:
mana_cost:
cmc:
type_line:
oracle_text:
power:
toughness:
colors:
color_identity:
keywords:
rarity:

name: Goldnight Redeemer

mana_cost: 4GG

cmc: 4

type_line: Creature - Human Cleric

oracle_text: Whenever Goldnight Redeemer deals combat damage to a player, you may return that player's card to their hand. If you do, create a 2/2 white Human Cleric creature token.

power: 3

toughness: 3

colors: White

color_identity: White

keywords: Return to Hand, Creature

rarity: Rare

I hope that helps! Let me know if you have any other questions.


In [None]:
"""name: Goldnight Redeemer
mana_cost: 3UU
cmc: 3
type_line: Creature - Angel
oracle_text: Whenever Goldnight Redeemer deals combat damage to a player, that player discards a card.
power: 4
toughness: 4
colors: U, U/B
color_identity: Blue
keywords: Flying, Lifelink"""