In [1]:
import pandas as pd
import weaviate
from sentence_transformers import SentenceTransformer
import numpy as np
from tqdm import tqdm

  from tqdm.autonotebook import tqdm, trange


In [3]:
# Initialize Weaviate client
client = weaviate.Client("http://localhost:8080")


Python client v3 `weaviate.Client(...)` connections and methods are deprecated and will
            be removed by 2024-11-30.

            Upgrade your code to use Python client v4 `weaviate.WeaviateClient` connections and methods.
                - For Python Client v4 usage, see: https://weaviate.io/developers/weaviate/client-libraries/python
                - For code migration, see: https://weaviate.io/developers/weaviate/client-libraries/python/v3_v4_migration

            If you have to use v3 code, install the v3 client and pin the v3 dependency in your requirements file: `weaviate-client>=3.26.7;<4.0.0`
  client = weaviate.Client("http://localhost:8080")


In [8]:

# Define the schema
game_class = {
    "class": "Game",
    "description": "A class representing a game",
    "vectorizer": "none",
    "properties": [
        {"name": "gameId", "dataType": ["text"], "description": "Unique identifier for the game"},
        {"name": "gameName", "dataType": ["text"], "description": "Name of the game"},
        {"name": "alternateNames", "dataType": ["text"], "description": "Alternative names for the game"},
        {"name": "subcategory", "dataType": ["text"], "description": "Subcategory of the game"},
        {"name": "level", "dataType": ["text"], "description": "Difficulty or experience level required"},
        {"name": "description", "dataType": ["text"], "description": "Detailed description of the game"},
        {"name": "playersMax", "dataType": ["int"], "description": "Maximum number of players"},
        {"name": "ageRange", "dataType": ["text"], "description": "Suitable age range for players"},
        {"name": "duration", "dataType": ["text"], "description": "Duration of the game"},
        {"name": "equipmentNeeded", "dataType": ["text"], "description": "Equipment required to play the game"},
        {"name": "objective", "dataType": ["text"], "description": "The main objective of the game"},
        {"name": "skillsDeveloped", "dataType": ["text"], "description": "Skills that players develop"},
        {"name": "setupTime", "dataType": ["text"], "description": "Time required to set up the game"},
        {"name": "place", "dataType": ["text"], "description": "Specific place or setting for the game"},
        {"name": "physicalIntensityLevel", "dataType": ["text"], "description": "Physical intensity level"},
        {"name": "educationalBenefits", "dataType": ["text"], "description": "Educational benefits"},
        {"name": "category", "dataType": ["text"], "description": "Main category of the game"},
    ]
}

# Delete the class if it already exists
if client.schema.exists("Game"):
    client.schema.delete_class("Game")

# Create the schema
client.schema.create_class(game_class)



In [9]:
# Load the game data
df_games = pd.read_csv('../data/game-dataset.csv')
df_games = df_games.fillna('')

# Function to combine fields
def combine_fields(row):
    fields = [
        'gameName',
        'alternateNames',
        'subcategory',
        'level',
        'description',
        'playersMax',
        'ageRange',
        'duration',
        'equipmentNeeded',
        'objective',
        'skillsDeveloped',
        'setupTime',
        'place',
        'physicalIntensityLevel',
        'educationalBenefits',
        'category'
    ]
    return ' '.join(str(row[field]) for field in fields if row[field])

# Load the embedding model
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

# Function to import data
def import_data(df):
    with client.batch as batch:
        batch.batch_size = 100
        for index, row in tqdm(df.iterrows(), total=len(df)):
            properties = {
                "gameId": str(row["gameId"]),  # Changed from 'id' to 'gameId'
                "gameName": row["gameName"],
                "alternateNames": row["alternateNames"],
                "subcategory": row["subcategory"],
                "level": row["level"],
                "description": row["description"],
                "playersMax": int(row["playersMax"]) if str(row["playersMax"]).isdigit() else 0,
                "ageRange": row["ageRange"],
                "duration": row["duration"],
                "equipmentNeeded": row["equipmentNeeded"],
                "objective": row["objective"],
                "skillsDeveloped": row["skillsDeveloped"],
                "setupTime": row["setupTime"],
                "place": row["place"],
                "physicalIntensityLevel": row["physicalIntensityLevel"],
                "educationalBenefits": row["educationalBenefits"],
                "category": row["category"],
            }

            # Generate embedding for the 'combined_text' field
            combined_text = combine_fields(row)
            embedding = embedding_model.encode(combined_text).astype('float32')

            client.batch.add_data_object(
                data_object=properties,
                class_name="Game",
                vector=embedding
            )

# Import the data
import_data(df_games)

# Load the ground truth data
df_ground_truth = pd.read_csv('../data/ground-truth-retrieval.csv')
df_ground_truth = df_ground_truth.fillna('')

# Adjust column names if necessary
# Assume columns 'query' and 'relevant_doc_id' exist
queries = df_ground_truth['question'].tolist()
ground_truth_ids = df_ground_truth['q_id'].astype(str).tolist()

# Number of top results to consider
k = 10

# Initialize metrics
hit_count = 0
reciprocal_ranks = []

# Iterate over queries
for query_text, true_id in tqdm(zip(queries, ground_truth_ids), total=len(queries)):
    # Generate embedding for the query
    query_embedding = embedding_model.encode([query_text])[0].astype('float32')

    # Perform vector search
    response = client.query.get(
        class_name="Game",
        properties=["gameId"]  # Changed from 'id' to 'gameId'
    ).with_near_vector({
        "vector": query_embedding.tolist(),
    }).with_limit(k).do()

    # Extract the retrieved IDs
    results = response["data"]["Get"]["Game"]
    retrieved_ids = [res['gameId'] for res in results]

    # Check if the true ID is in the top k results
    if true_id in retrieved_ids:
        hit_count += 1
        # Calculate reciprocal rank
        rank = retrieved_ids.index(true_id) + 1
        reciprocal_rank = 1 / rank
        reciprocal_ranks.append(reciprocal_rank)
    else:
        reciprocal_ranks.append(0)

# Calculate metrics
hit_rate = hit_count / len(queries)
mrr = np.mean(reciprocal_ranks)

print(f"Hit Rate@{k}: {hit_rate:.4f}")
print(f"MRR@{k}: {mrr:.4f}")


100%|██████████| 590/590 [01:05<00:00,  9.00it/s]
100%|██████████| 2950/2950 [02:21<00:00, 20.87it/s]

Hit Rate@10: 0.9515
MRR@10: 0.7799



