In [4]:
# cell 1
import sys
import os

# Add the project root to the system path so we can import from 'src'
# This assumes the notebook is inside a 'notebooks' folder
sys.path.append(os.path.abspath('..'))

%load_ext autoreload
%autoreload 2

# Import our custom modules
from src.utils import get_spark_session
from src.config import API_KEY, BASE_URL, MOVIE_IDS
from src.ingestion import fetch_movie_data
from src.cleaning import clean_movie_data
from src.analysis import get_ranked_movies, analyze_franchises

# Visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Initialize Spark
spark = get_spark_session("TMDB_Analysis_Lab")
print(f"Spark Version: {spark.version}")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Spark Version: 3.5.0


In [5]:
# cell 2
# 1. Fetch raw data using Python (Driver Node)
print(f"Fetching data for {len(MOVIE_IDS)} movies...")
raw_data_list = fetch_movie_data(MOVIE_IDS, API_KEY, BASE_URL)

# 2. Convert to Spark DataFrame
# Spark automatically infers the schema from the list of dicts
df_raw = spark.createDataFrame(raw_data_list)

print("Raw Schema:")
df_raw.printSchema()

Fetching data for 19 movies...
Fetching 1/19: ID 0...
Fetching 2/19: ID 299534...
Fetching 3/19: ID 19995...
Fetching 4/19: ID 140607...
Fetching 5/19: ID 299536...
Fetching 6/19: ID 597...
Fetching 7/19: ID 135397...
Fetching 8/19: ID 420818...
Fetching 9/19: ID 24428...
Fetching 10/19: ID 168259...
Fetching 11/19: ID 99861...
Fetching 12/19: ID 284054...
Fetching 13/19: ID 12445...
Fetching 14/19: ID 181808...
Fetching 15/19: ID 330457...
Fetching 16/19: ID 351286...
Fetching 17/19: ID 109445...
Fetching 18/19: ID 321612...
Fetching 19/19: ID 260513...
Raw Schema:
root
 |-- adult: boolean (nullable = true)
 |-- backdrop_path: string (nullable = true)
 |-- belongs_to_collection: map (nullable = true)
 |    |-- key: string
 |    |-- value: long (valueContainsNull = true)
 |-- budget: long (nullable = true)
 |-- genres: array (nullable = true)
 |    |-- element: map (containsNull = true)
 |    |    |-- key: string
 |    |    |-- value: long (valueContainsNull = true)
 |-- homepage: stri

In [6]:
# cell 3
# Apply our Spark cleaning pipeline
df_clean = clean_movie_data(df_raw)

# Cache the result since we will use this dataframe for multiple analyses
df_clean.cache()

print(f"Cleaned Row Count: {df_clean.count()}")
df_clean.select("title", "release_date", "revenue_musd", "roi").show(5, truncate=False)

AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "genres[name]" due to data type mismatch: Parameter 2 requires the "INTEGRAL" type, however "name" has the type "STRING".;
'Project [backdrop_path#53, belongs_to_collection#125L, budget#55L, array_join(genres#56[name], |, None) AS genres#148, id#58L, origin_country#60, original_language#61, overview#63, popularity#64, poster_path#65, production_companies#66, production_countries#67, release_date#68, revenue#69L, runtime#70L, spoken_languages#71, status#72, tagline#73, title#74, vote_average#76, vote_count#77L]
+- Project [backdrop_path#53, belongs_to_collection#54[name] AS belongs_to_collection#125L, budget#55L, genres#56, id#58L, origin_country#60, original_language#61, overview#63, popularity#64, poster_path#65, production_companies#66, production_countries#67, release_date#68, revenue#69L, runtime#70L, spoken_languages#71, status#72, tagline#73, title#74, vote_average#76, vote_count#77L]
   +- Project [backdrop_path#53, belongs_to_collection#54, budget#55L, genres#56, id#58L, origin_country#60, original_language#61, overview#63, popularity#64, poster_path#65, production_companies#66, production_countries#67, release_date#68, revenue#69L, runtime#70L, spoken_languages#71, status#72, tagline#73, title#74, vote_average#76, vote_count#77L]
      +- LogicalRDD [adult#52, backdrop_path#53, belongs_to_collection#54, budget#55L, genres#56, homepage#57, id#58L, imdb_id#59, origin_country#60, original_language#61, original_title#62, overview#63, popularity#64, poster_path#65, production_companies#66, production_countries#67, release_date#68, revenue#69L, runtime#70L, spoken_languages#71, status#72, tagline#73, title#74, video#75, ... 2 more fields], false


In [None]:
# cell 4
print("--- Top 5 Highest Revenue ---")
top_rev = get_ranked_movies(df_clean, "revenue_musd", ascending=False)
top_rev.select("title", "revenue_musd", "budget_musd").show()

print("--- Top 5 Highest ROI ---")
top_roi = get_ranked_movies(df_clean, "roi", ascending=False)
top_roi.select("title", "roi", "revenue_musd").show()

print("--- Worst 5 Flops (Lowest ROI) ---")
# Filter for significant budget first to avoid divide-by-zero anomalies on micro-films
flop_roi = get_ranked_movies(df_clean.filter("budget_musd > 10"), "roi", ascending=True)
flop_roi.select("title", "roi", "budget_musd", "revenue_musd").show()

In [None]:
# cell 5
# Compare Franchises vs Standalone
franchise_stats = analyze_franchises(df_clean)

# Collect to Pandas for display
pdf_franchise = franchise_stats.toPandas()
display(pdf_franchise)

In [None]:
# cell 6
# Collect necessary data to Pandas (Drivers only)
# Warning: Only do this on aggregated or filtered data in production!
pdf_plot = df_clean.select("title", "budget_musd", "revenue_musd", "roi", "popularity").toPandas()

# Set plot style
sns.set_theme(style="whitegrid")

# 1. Budget vs Revenue Scatter Plot
plt.figure(figsize=(10, 6))
sns.scatterplot(data=pdf_plot, x="budget_musd", y="revenue_musd", size="roi", hue="roi", sizes=(20, 200), palette="viridis")

plt.title("Movie Budget vs. Revenue (Size = ROI)")
plt.xlabel("Budget (Million USD)")
plt.ylabel("Revenue (Million USD)")
plt.axline((0, 0), slope=1, color=".5", linestyle="--", label="Break-even") # Break-even line
plt.legend(bbox_to_anchor=(1.05, 1), loc=2)

# Label top movies
for i in range(pdf_plot.shape[0]):
    if pdf_plot.revenue_musd[i] > 500 or pdf_plot.roi[i] > 5: # Only label hits
        plt.text(pdf_plot.budget_musd[i]+2, pdf_plot.revenue_musd[i], 
                 pdf_plot.title[i], fontsize=9)

plt.show()

In [None]:
# cell 7
# We need to 'explode' the genre string back into rows for this plot
# "Action|Sci-Fi" -> 
# Row 1: Action
# Row 2: Sci-Fi
from pyspark.sql.functions import explode, split

df_exploded = df_clean.withColumn("genre", explode(split("genres", "\|")))
genre_roi = df_exploded.groupBy("genre").mean("roi").toPandas().sort_values("avg(roi)", ascending=False)

plt.figure(figsize=(12, 6))
sns.barplot(data=genre_roi, x="avg(roi)", y="genre", palette="magma")
plt.title("Average ROI by Genre")
plt.xlabel("Average ROI")
plt.show()