<a href="https://colab.research.google.com/github/hsynj/AI-Special-HW/blob/main/Movie%20RAG%20Recommender/01_Movie_RAG_Recommender(Advanced).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Advanced RAG Pipeline for Movie Recommendations

This notebook starts from `simple_rag_exercise.ipynb` and extends it with several advanced RAG techniques that improve accuracy, context-awareness, and interactivity.

The result is a more robust pipeline than the baseline, closer to what a production system would need.

### Key Enhancements:
1.  **Richer Embeddings:**
    * Instead of embedding *only* the `overview`, this pipeline combines `title`, `genres`, `keywords`, and `overview` into a single, rich document for embedding. This allows for much more complex and accurate queries (e.g., searching by genre *and* story).

2.  **Re-Ranking (Cross-Encoder):**
    * To improve precision, the pipeline first retrieves the Top-K candidates (Top 5 in this notebook) from Qdrant.
    * It then uses a `Cross-Encoder` (Re-Ranker) to score each candidate against the query *directly*.
    * The highest-scoring candidate is taken as the *single best match*, which addresses the "Top-1 limitation" of plain vector search, where the nearest vector is not always the most relevant result.

3.  **Conversational Memory:**
    * The pipeline now includes a `chat_history` list.
    * This history is fed back into the LLM prompt, allowing the system to remember previous interactions and handle follow-up questions (e.g., "Thanks, find me another one...").

4.  **Flexible Generation:**
    * The generation step is configured to use Google's Gemini API (the code below calls the `gemini-2.5-flash` model).

In [None]:
# requirements
!pip install datasets qdrant-client sentence-transformers google-generativeai openai rich

In [None]:
from rich.console import Console

# Create a console (we will use the default theme)
console = Console()

In [None]:
from datasets import load_dataset

dataset = load_dataset('AiresPucrs/tmdb-5000-movies')


In [None]:
console.print(dataset)

In [None]:
from qdrant_client import models, QdrantClient
from sentence_transformers import SentenceTransformer

# create the vector database client
qdrant = QdrantClient(":memory:") # Create in-memory Qdrant instance

# Create the embedding encoder
encoder = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

In [None]:
console.print(encoder)

In [None]:
# Create a collection to store the movie data
collection_name="movies"

qdrant.recreate_collection(
    collection_name=collection_name,
    vectors_config=models.VectorParams(
        size=encoder.get_sentence_embedding_dimension(), # Vector size is defined by used model
        distance=models.Distance.COSINE
    )
)

True

In [None]:
### [ADVANCED] Helper function to create a rich document for embedding.
import json

def create_movie_document(movie):
    try:
        genres_list = json.loads(movie['genres'])
    except (TypeError, json.JSONDecodeError):
        genres_list = movie['genres'] if isinstance(movie['genres'], list) else []

    try:
        keywords_list = json.loads(movie['keywords'])
    except (TypeError, json.JSONDecodeError):
        keywords_list = movie['keywords'] if isinstance(movie['keywords'], list) else []

    genres = ", ".join([g['name'] for g in genres_list if g and 'name' in g])
    keywords = ", ".join([k['name'] for k in keywords_list if k and 'name' in k])

    return f"Title: {movie['original_title']}. Genres: {genres}. Keywords: {keywords}. Overview: {movie['overview']}"
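
To make the output of the helper concrete, the sketch below builds the same kind of document string for a made-up record. The `sample_movie` dict and the `names_from_json` helper are assumptions for illustration only; they are not rows or functions from the dataset or this notebook. The field layout (JSON-encoded lists of objects with a `name` key) mirrors what `create_movie_document` expects.

```python
import json

def names_from_json(raw):
    """Return the comma-joined 'name' values from a JSON-encoded list of dicts."""
    try:
        items = json.loads(raw)
    except (TypeError, json.JSONDecodeError):
        # Already-parsed lists are used as-is; anything else is treated as empty.
        items = raw if isinstance(raw, list) else []
    return ", ".join(item["name"] for item in items if item and "name" in item)

# Hypothetical record shaped like a dataset row (all values are made up).
sample_movie = {
    "original_title": "Example Film",
    "genres": '[{"id": 35, "name": "Comedy"}, {"id": 10749, "name": "Romance"}]',
    "keywords": '[{"id": 1, "name": "road trip"}]',
    "overview": "Two strangers share a car across the country.",
}

document = (
    f"Title: {sample_movie['original_title']}. "
    f"Genres: {names_from_json(sample_movie['genres'])}. "
    f"Keywords: {names_from_json(sample_movie['keywords'])}. "
    f"Overview: {sample_movie['overview']}"
)
print(document)
```

Because genre and keyword names end up in the embedded text, a query such as "romantic comedy road trip" can match on those terms even when the overview never mentions them.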

In [None]:
points_list = []
for i, movie in enumerate(dataset['train']):
  if movie['overview']:
    print(f"{i+1} - Processing ID: {movie['id']}, Title: {movie['original_title']}")
    points_list.append(
        models.PointStruct(
            id=movie['id'],
            vector=encoder.encode(create_movie_document(movie)).tolist(),
            payload=movie
        )
    )

print("\nStarting upload to Qdrant...")
qdrant.upload_points(
    collection_name=collection_name,
    points=points_list
)
print("Upload complete!")

1 - Processing ID: 5, Title: Four Rooms
2 - Processing ID: 11, Title: Star Wars
3 - Processing ID: 12, Title: Finding Nemo
4 - Processing ID: 13, Title: Forrest Gump
5 - Processing ID: 14, Title: American Beauty
6 - Processing ID: 16, Title: Dancer in the Dark
7 - Processing ID: 18, Title: The Fifth Element
8 - Processing ID: 19, Title: Metropolis
9 - Processing ID: 20, Title: My Life Without Me
10 - Processing ID: 22, Title: Pirates of the Caribbean: The Curse of the Black Pearl
11 - Processing ID: 24, Title: Kill Bill: Vol. 1
12 - Processing ID: 25, Title: Jarhead
13 - Processing ID: 28, Title: Apocalypse Now
14 - Processing ID: 33, Title: Unforgiven
15 - Processing ID: 35, Title: The Simpsons Movie
16 - Processing ID: 38, Title: Eternal Sunshine of the Spotless Mind
17 - Processing ID: 55, Title: Amores perros
18 - Processing ID: 58, Title: Pirates of the Caribbean: Dead Man's Chest
19 - Processing ID: 59, Title: A History of Violence
20 - Processing ID: 62, Title: 2001: A Space Ody
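
The loop above calls `encoder.encode` once per movie. `SentenceTransformer.encode` also accepts a list of strings together with a `batch_size` argument, so one possible speed-up is to collect the documents first and encode them in a single call. The sketch below shows that shape under some assumptions: `build_points` and `fake_encode` are hypothetical names, plain dicts stand in for `models.PointStruct`, and a stand-in encoder is used so the example runs without downloading a model.

```python
def build_points(movies, make_document, encode_batch):
    """Pair each movie that has an overview with its vector, using one encode call."""
    kept = [m for m in movies if m.get("overview")]
    documents = [make_document(m) for m in kept]
    vectors = encode_batch(documents)  # one batched call instead of one per movie
    return [
        {"id": m["id"], "vector": list(v), "payload": m}
        for m, v in zip(kept, vectors)
    ]

# Stand-in encoder (assumption) so the sketch is runnable on its own.
# In the notebook this could be: lambda docs: encoder.encode(docs, batch_size=64)
def fake_encode(documents):
    return [[float(len(d)), 0.0] for d in documents]

movies = [
    {"id": 1, "original_title": "A", "overview": "First plot."},
    {"id": 2, "original_title": "B", "overview": ""},  # skipped: empty overview
    {"id": 3, "original_title": "C", "overview": "Third."},
]
points = build_points(movies, lambda m: m["overview"], fake_encode)
print([p["id"] for p in points])
```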

In [None]:
console.print(
    qdrant
    .get_collection(
        collection_name=collection_name
    )
)

In [None]:
user_prompt = "Thanks! Now find a comedy from the 2000s"

In [None]:
query_vector = encoder.encode(user_prompt).tolist()

In [None]:
from qdrant_client import models

# Restrict results to the decade named in the prompt (the 2000s)
query_filter = models.Filter(
    must=[
        models.FieldCondition(
            key='release_date',
            range=models.DatetimeRange(
                gte='2000-01-01T00:00:00Z',
                lte='2009-12-31T23:59:59Z'
            )
        )
    ]
)
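
The date bounds in the filter are hard-coded. One way to keep them in step with the prompt is to derive them from a decade phrase such as "2000s". The helper below is a minimal sketch of that idea; `decade_to_range` is a hypothetical name, it only recognises decades written as four digits followed by "s", and it is not part of the notebook's pipeline.

```python
import re

def decade_to_range(text):
    """Return (gte, lte) timestamp strings for a decade like '1990s', else None."""
    match = re.search(r"\b((?:19|20)\d)0s\b", text)
    if match is None:
        return None
    start_year = int(match.group(1)) * 10
    return (
        f"{start_year}-01-01T00:00:00Z",
        f"{start_year + 9}-12-31T23:59:59Z",
    )

bounds = decade_to_range("Thanks! Now find a comedy from the 2000s")
print(bounds)
```

The two strings could then be passed as `gte` and `lte` to `models.DatetimeRange`; when the function returns `None`, the search would run without a date filter.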

In [None]:
# Search for the movies closest to the query, restricted by the date filter

hits = qdrant.search(
    collection_name=collection_name,
    query_vector=query_vector,
    limit=5,
    query_filter=query_filter,
)

In [None]:
from rich.text import Text
from rich.table import Table

table = Table(title="Retrieval Results", show_lines=True)

table.add_column("ID", style="#e0e0e0")
table.add_column("Original Title", style="#e0e0e0")
table.add_column("Overview", style="bright_red")
table.add_column("Score", style="#89ddff")

for hit in hits:
    table.add_row(
        str(hit.payload["id"]),
        hit.payload["original_title"],
        f'{hit.payload["overview"]}',
        f"{hit.score:.4f}"
    )

console.print(table)

In [None]:
### [ADVANCED] Re-rank the top-K retrieval results for higher precision.

from sentence_transformers.cross_encoder import CrossEncoder

cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

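query = user_prompt
# Score each (query, overview) pair jointly with the cross-encoder
pairs = [(query, hit.payload['overview']) for hit in hits]

scores = cross_encoder.predict(pairs)

# Keep the candidate with the highest cross-encoder score
best_hit = hits[int(scores.argmax())]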
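
Taking the `argmax` keeps only the top candidate. If the full re-ranked order is useful (for example, to pass the candidates to the LLM from most to least relevant), the scores can be used as a sort key instead. The sketch below shows this with made-up titles and scores; `rerank` is a hypothetical helper, and the numbers are not real cross-encoder outputs.

```python
def rerank(candidates, scores):
    """Return (candidate, score) pairs sorted from most to least relevant."""
    if len(candidates) != len(scores):
        raise ValueError("candidates and scores must have the same length")
    ranked = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)
    return [(candidate, float(score)) for candidate, score in ranked]

# Made-up example data; in the notebook these would be `hits` and `scores`.
example_titles = ["Movie A", "Movie B", "Movie C"]
example_scores = [-2.5, 4.1, 0.3]

ranked = rerank(example_titles, example_scores)
print(ranked[0])
```

The first element of `ranked` corresponds to what `best_hit` selects, and the remaining elements preserve the re-ranked order of the other candidates.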

In [None]:
### [ADVANCED] Display the single best re-ranked result.
from rich.table import Table
from rich.console import Console

if 'best_hit' in locals() and best_hit is not None:

    table = Table(
        title="Best Hit Details",
        show_header=True,
        header_style="bold magenta",
        show_lines=True,
        padding=(0, 1)
    )

    table.add_column("Field", style="dim", width=18)
    table.add_column("Value", style="cyan")
    table.add_row("Qdrant ID", str(best_hit.id))
    table.add_row("Vector Score", f"{best_hit.score:.4f}")  # Qdrant similarity, not the cross-encoder score
    payload = best_hit.payload

    if 'original_title' in payload:
        table.add_row("Original Title", payload['original_title'])

    if 'release_date' in payload:
        table.add_row("Release Date", payload['release_date'])

    if 'popularity' in payload:
        table.add_row("Popularity", f"{payload['popularity']:.2f}")

    if 'genres' in payload:
        table.add_row("Genres", str(payload['genres']))

    if 'overview' in payload:
        table.add_row("Overview", payload['overview'])

    console.print(table)

else:
    console.print("[bold red]Error:[/bold red] The variable 'best_hit' was not found or is empty.")

In [None]:
# define a variable to hold the search results with specific fields
search_results = [
    {
        'original_title': hit.payload['original_title'],
        'title': hit.payload['title'],
        'overview': hit.payload['overview'],
        'release_date': hit.payload['release_date'],
        'popularity': hit.payload['popularity']
    } for hit in hits]
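
The prompt later embeds `search_results` as a Python list, so the model sees the raw `repr` of the dicts. A possible alternative is to render the results as a short numbered list first. The sketch below assumes records with the same keys as `search_results`; `format_context` is a hypothetical helper and the sample records are made up.

```python
def format_context(results):
    """Render search results as a numbered, human-readable block of text."""
    lines = []
    for rank, movie in enumerate(results, start=1):
        # Release dates look like 'YYYY-MM-DD'; fall back to 'n/a' if missing.
        year = (movie.get("release_date") or "")[:4] or "n/a"
        lines.append(f"{rank}. {movie['title']} ({year}): {movie['overview']}")
    return "\n".join(lines)

# Made-up records with the same keys as `search_results`.
sample_results = [
    {"title": "Example One", "release_date": "2004-06-11", "overview": "A heist goes wrong."},
    {"title": "Example Two", "release_date": "", "overview": "Two friends open a cafe."},
]
context_text = format_context(sample_results)
print(context_text)
```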

In [None]:
console.print(search_results)

In [None]:
### [ADVANCED] Implement conversational memory (chat history) for follow-ups.
if 'chat_history' not in globals():
    chat_history = []

In [None]:
import google.generativeai as genai
from google.colab import userdata
from rich.panel import Panel
from rich.text import Text

try:
    GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
    genai.configure(api_key=GOOGLE_API_KEY)
except Exception:
    console.print("[bold red]Error:[/bold red] Could not configure Gemini. Make sure you have stored 'GOOGLE_API_KEY' in Colab Secrets.")
    raise

history_string = "\n".join([f"User: {turn['user']}\nModel: {turn['model']}" for turn in chat_history])
system_prompt = "You are a helpful movie recommendation assistant..."
user_request = f"The user is asking for: '{user_prompt}'."
context = f"Based on this, I found the following movie context: {search_results}."
full_prompt = f"Previous conversation:\n{history_string}\n\n{system_prompt}\n\n{context}\n\n{user_request}\n\nProvide a recommendation for the user based *only* on the new context and be aware of the previous conversation."


model = genai.GenerativeModel('gemini-2.5-flash')
completion = model.generate_content(full_prompt)

response_text = completion.text
styled_panel = Panel(
    Text(response_text),
    title="Movie Recommendation with Retrieval",
    expand=False,
    border_style="bright_yellow",
    padding=(1, 1)
)

# Save this turn so follow-up questions can refer back to it
chat_history.append({"user": user_prompt, "model": response_text})

console.print(styled_panel)
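
Because every turn is appended to `chat_history` and the whole list is formatted into the prompt, the prompt grows with each interaction. A simple way to bound it is to include only the most recent turns. The sketch below assumes the same `{"user": ..., "model": ...}` turn structure used above; `format_history` and the default of three turns are illustrative choices, not part of the notebook.

```python
def format_history(history, max_turns=3):
    """Format only the last `max_turns` turns of the chat history for the prompt."""
    recent = history[-max_turns:] if max_turns > 0 else []
    return "\n".join(
        f"User: {turn['user']}\nModel: {turn['model']}" for turn in recent
    )

# Made-up turns using the same structure as `chat_history`.
sample_history = [
    {"user": "q1", "model": "a1"},
    {"user": "q2", "model": "a2"},
    {"user": "q3", "model": "a3"},
]
trimmed = format_history(sample_history, max_turns=2)
print(trimmed)
```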

In [None]:
### [ADVANCED] Display the full conversational chat history log.
from rich.console import Console
from rich.table import Table

if 'chat_history' not in locals() or not chat_history:
    console.print("[yellow]Alert: 'chat_history' not found. Using sample data for demonstration.[/yellow]\n")
    chat_history = [
        {"user": "Love story between an Asian king and European teacher",
         "model": "Based on your request, I recommend 'Anna and the King' (1999). It's a romance about the King of Siam and a British school teacher..."},
        {"user": "Thanks! Can you find me another one like that, but a comedy?",
         "model": "Based on the new context, I recommend 'Crazy Rich Asians'. It's a modern romantic comedy..."}
    ]

if 'chat_history' in locals() and chat_history:
    table = Table(
        title="💬 Chat History Log",
        show_header=True,
        header_style="bold magenta",
        show_lines=True
    )
    table.add_column("Role", style="dim", width=15)
    table.add_column("Message", style="cyan")

    for i, turn in enumerate(chat_history):
        table.add_row(f"User (Turn {i+1})", turn['user'])
        table.add_row(f"Model (Turn {i+1})", turn['model'], end_section=True)
    console.print(table)

else:
    console.print("[bold red]Chat history is empty.[/bold red]")