# Link to the database 
https://www.kaggle.com/datasets/ernestitus/2024-olympics-medals-vs-gdp

# Get your api key from here : 
https://aistudio.google.com/app/apikey

# Data consultation

In [32]:
import pandas as pd

In [33]:
data = pd.read_csv('olympics.csv')

In [4]:
data.isnull().sum()

country         0
country_code    0
region          0
gold            0
silver          0
bronze          0
total           0
gdp             0
gdp_year        0
population      0
dtype: int64

In [5]:
data.head()

Unnamed: 0,country,country_code,region,gold,silver,bronze,total,gdp,gdp_year,population
0,United States,USA,North America,40,44,42,126,81695.19,2023,334.9
1,China,CHN,Asia,40,27,24,91,12614.06,2023,1410.7
2,Japan,JPN,Asia,20,12,13,45,33834.39,2023,124.5
3,Australia,AUS,Oceania,18,19,16,53,64711.77,2023,26.6
4,France,FRA,Europe,16,26,22,64,44460.82,2023,68.2


# The import needed

In [None]:
from typing import List, Dict
from langchain_google_genai import GoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain.schema import Document
from tabulate import tabulate

# Initialisation

In [None]:
# Initialize Google Embeddings
initialize_embeddings = lambda api_key: GoogleGenerativeAIEmbeddings(
    model="models/embedding-001",
    google_api_key=api_key,
    task_type="retrieval_query"
)

# Initialize LLM (Gemini Pro)
initialize_llm = lambda api_key: GoogleGenerativeAI(
    model="gemini-pro",
    google_api_key=api_key,
    temperature=0.3,  # Lower temperature for more factual responses
    top_p=0.9,
    top_k=40,
    max_output_tokens=2048,
)


# To convert the dataset to text

In [None]:
# Create text representation for each Olympic record
create_text_representation = lambda row: (
    f"Country: {row['country']} ({row['country_code']}) in {row['region']} "
    f"won {row['gold']} gold medals, {row['silver']} silver medals, and {row['bronze']} bronze medals "
    f"in the Olympics, with a total of {row['total']} medals. "
    f"The country's GDP is ${row['gdp']} trillion (as of {row['gdp_year']}) "
    f"with a population of {row['population']} million people."
)

# To load the text and split it to chuncks

In [None]:
# Load and process Olympics data
def load_and_process_olympics_data(df: pd.DataFrame) -> List[Document]:
    documents = [Document(page_content=create_text_representation(row), metadata={"country": row["country"], "region": row["region"], "total_medals": row["total"]}) for _, row in df.iterrows()]
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100, length_function=len)
    return text_splitter.split_documents(documents)

# Vectore store

In [None]:
# Create vector store
create_vector_store = lambda documents, embeddings: FAISS.from_documents(documents, embeddings)

# Save vector store to disk
save_vector_store = lambda vector_store, path: vector_store.save_local(path)

# Load vector store from disk
load_vector_store = lambda path, embeddings: FAISS.load_local(path, embeddings)

# QA retrieval chain and the query function

In [None]:
# Set up retrieval QA chain
def setup_retrieval_qa(llm, vector_store) -> RetrievalQA:
    prompt_template = """You are an Olympics data expert. Use the following contextual information about Olympic medals, 
    GDP, and population to answer the question. Be precise with numbers and always mention the source country when relevant.
    If you don't have enough information to answer accurately, please say so.
    
    Context: {context}
    
    Question: {question}
    
    Answer: """
    PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
    chain_type_kwargs = {"prompt": PROMPT}
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=vector_store.as_retriever(search_kwargs={"k": 3}),
        chain_type_kwargs=chain_type_kwargs,
        return_source_documents=True
    )

# Query function
def query_olympics(llm, vector_store, question: str) -> Dict:
    if not vector_store:
        raise ValueError("Vector store not initialized. Please create or load a vector store first.")
    qa_chain = setup_retrieval_qa(llm, vector_store)
    result = qa_chain.invoke({"query": question})
    return {"answer": result["result"], "source_documents": result["source_documents"]}

# Processing 

In [30]:
# Your Olympics DataFrame
data = pd.read_csv('olympics.csv')
# Initialize RAG system
GOOGLE_API_KEY = "YOUR_API_KEY"  # Replace with your API key
# Process documents and embeddings
embeddings = initialize_embeddings(GOOGLE_API_KEY)
llm = initialize_llm(GOOGLE_API_KEY)
documents = load_and_process_olympics_data(data)
# Create vector store
vector_store = create_vector_store(documents, embeddings)
# Optional: Save vector store
save_vector_store(vector_store, "olympics_vector_store")

# Testing :

In [31]:
# Function to format and print the results with full source documents
def print_query_results(result: Dict):
    # Extract and print the answer
    print("\n\033[1mAnswer:\033[0m")
    print(f"{result['answer']}\n")
    
    # Extract source documents
    sources = result['source_documents']
    
    # Create a table with metadata from source documents
    source_data = []
    for doc in sources:
        source_data.append({
            "Country": doc.metadata.get("country"),
            "Region": doc.metadata.get("region"),
            "Total Medals": doc.metadata.get("total_medals")
        })
    
    # Convert the source data into a DataFrame for better visualization
    df_sources = pd.DataFrame(source_data)
    
    # Display source document metadata in a table
    print("\033[1mSource Documents Metadata:\033[0m")
    print(tabulate(df_sources, headers="keys", tablefmt="grid"))
    
    # Display the full content of each source document
    for i, doc in enumerate(sources, 1):
        print(f"\n\033[1mSource Document {i}:\033[0m")
        print(f"\033[3mCountry:\033[0m {doc.metadata.get('country')}")
        print(f"\033[3mRegion:\033[0m {doc.metadata.get('region')}")
        print(f"\033[3mTotal Medals:\033[0m {doc.metadata.get('total_medals')}")
        print(f"\033[3mDocument Content:\033[0m\n{doc.page_content}\n")
    
# Example query to test
question = "wich countries from europe won a medal  ?"
result = query_olympics(llm, vector_store, question)

# Print the result in a modern, visual way
print_query_results(result)



[1mAnswer:[0m
- France (FRA)
- Belgium (BEL)
- Germany (DEU)

[1mSource Documents Metadata:[0m
+----+-----------+----------+----------------+
|    | Country   | Region   |   Total Medals |
|  0 | France    | Europe   |             64 |
+----+-----------+----------+----------------+
|  1 | Belgium   | Europe   |             10 |
+----+-----------+----------+----------------+
|  2 | Germany   | Europe   |             33 |
+----+-----------+----------+----------------+

[1mSource Document 1:[0m
[3mCountry:[0m France
[3mRegion:[0m Europe
[3mTotal Medals:[0m 64
[3mDocument Content:[0m
Country: France (FRA) in Europe won 16 gold medals, 26 silver medals, and 22 bronze medals in the Olympics, with a total of 64 medals. The country's GDP is $44460.82 trillion (as of 2023) with a population of 68.2 million people.


[1mSource Document 2:[0m
[3mCountry:[0m Belgium
[3mRegion:[0m Europe
[3mTotal Medals:[0m 10
[3mDocument Content:[0m
Country: Belgium (BEL) in Europe won 3 g