In [None]:
import pandas as pd
import json

# ------------------------ #
# 1. Load the Dataset
# ------------------------ #

df = pd.read_csv('../../data/cmu_tmdb.csv')
print("Initial DataFrame shape:", df.shape)

# ------------------------ #
# 2. Data Cleaning
# ------------------------ #
# NOTE: every filter below keeps rows where the value is strictly > 0, so
# negative values (and rows that fail the comparison, e.g. missing values)
# are dropped along with the zeros mentioned in the messages.

# 2.1 Remove rows that have no votes, then rows whose average rating is 0
df = df[df['vote_count'] > 0]
print("\nAfter removing rows with vote_count = 0:", df.shape)
df = df[df['vote_average'] > 0]
print("\nAfter removing rows with vote_average = 0:", df.shape)

# 2.2 Remove rows where revenue is 0
df = df[df["revenue"] > 0]
print("After removing rows with revenue = 0:", df.shape)

# 2.3 Remove duplicate movies based on 'id' (the first occurrence is kept)
df = df.drop_duplicates(subset="id")
print("After removing duplicate movies:", df.shape)

# ------------------------ #
# 3. Select Relevant Columns
# ------------------------ #

relevant_cols = ["id", "title", "vote_average", "vote_count", "revenue", "budget"]
# .copy() detaches the selection from the filtered frame so that the column
# assignments below write to an independent DataFrame.
df = df[relevant_cols].copy()
print("\nSelected relevant columns:", df.shape)

# ------------------------ #
# 4. Compute Additional Metrics
# ------------------------ #

# Compute profit
df["profit"] = df["revenue"] - df["budget"]

# Remove movies with budget <= 0: ROI divides by budget, so a zero budget
# would divide by zero and a negative one would flip the sign of the ratio.
# .copy() is required here: assigning the ROI column onto the result of a
# boolean filter otherwise raises SettingWithCopyWarning on pandas versions
# that do not use copy-on-write.
df = df[df["budget"] > 0].copy()
print("\nAfter removing movies with budget <= 0:", df.shape)

# Compute ROI (profit relative to budget)
df["ROI"] = df["profit"] / df["budget"]

# ------------------------ #
# 5. Remove Outliers
# ------------------------ #

def remove_outliers_iqr(data, col, factor=1.5):
    """Return the rows of *data* whose *col* value lies inside the IQR fences.

    The fences are ``Q1 - factor * IQR`` and ``Q3 + factor * IQR``, where Q1
    and Q3 are the 25th and 75th percentiles of ``data[col]`` and
    ``IQR = Q3 - Q1``. Both fences are inclusive.

    Args:
        data: DataFrame to filter. It is not modified.
        col: Name of the numeric column used to compute the fences.
        factor: Multiplier applied to the IQR. The default of 1.5 reproduces
            the previously hard-coded behavior.

    Returns:
        The subset of *data* whose *col* value is within the fences. Rows
        where *col* is NaN fail both comparisons and are therefore dropped.
    """
    Q1 = data[col].quantile(0.25)
    Q3 = data[col].quantile(0.75)
    IQR = Q3 - Q1
    lower_bound = Q1 - factor * IQR
    upper_bound = Q3 + factor * IQR
    in_range = (data[col] >= lower_bound) & (data[col] <= upper_bound)
    return data[in_range]


# Apply the IQR filter one column at a time. Each pass works on the rows that
# survived the previous pass, so the order of the columns matters.
columns_to_check = ["vote_average", "revenue", "ROI"]
for col in columns_to_check:
    before = len(df)
    df = remove_outliers_iqr(df, col)
    after = len(df)
    # Report how many rows this column's filter dropped
    message = f"\nAfter removing outliers in '{col}': {before - after} rows removed, remaining {after} rows"
    print(message)

# ------------------------ #
# 6. Randomly Select Subset of Movies
# ------------------------ #

# Target sample size (e.g., 200 movies); random_state makes the draw reproducible
sample_size = 200
# DataFrame.sample raises ValueError when asked for more rows than exist
# (sampling is without replacement), so cap the request at the number of rows
# that survived cleaning and outlier removal.
n_selected = min(sample_size, len(df))
df_sample = df.sample(n=n_selected, random_state=42)

print(f"\nRandomly selected {n_selected} movies for analysis.")

# ------------------------ #
# 7. Save Processed Data as JSON
# ------------------------ #

# Convert sampled DataFrame to a list of per-movie dictionaries
processed_data = df_sample.to_dict(orient='records')

# Save to 'metrics.json'
with open("../../docs/assets/data/metrics.json", "w", encoding="utf-8") as f:
    json.dump(processed_data, f, indent=2)

print("\nProcessing complete. 'metrics.json' has been created with the selected subset.")


In [3]:
from plot_settings import COMMON_LAYOUT

In [None]:
import pandas as pd
import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots
import statsmodels.api as sm
import os

# Load the dataset
file_path = "../../docs/assets/data/metrics.json"
try:
    df = pd.read_json(file_path)
except ValueError as err:
    # Re-raise with the offending path in the message, chaining the original
    # error so the underlying parse failure stays visible in the traceback.
    raise ValueError(
        f"Could not read the file at {file_path}. Ensure the path and format are correct."
    ) from err

# Columns shown in the pair plot (one grid row and one grid column each)
columns_to_plot = ["vote_average", "revenue", "ROI"]

# ------------------------ #
# Helper Function for Regression
# ------------------------ #
def add_trendline(x, y):
    """Fit an OLS line ``y ~ const + x`` and return the fitted points.

    Observations where either ``x`` or ``y`` is NaN are excluded before
    fitting, so a missing value in one variable cannot reach the regression.

    Args:
        x: Array-like of numeric predictor values.
        y: Array-like of numeric response values, aligned with ``x``.

    Returns:
        A tuple ``(x_clean, y_pred)``: the x values used for the fit and the
        model's predicted y at each of them (one value per retained
        observation, not just the two endpoints of the line).
    """
    x = np.asarray(x)
    y = np.asarray(y)
    # Keep only the observations where both coordinates are present
    valid = ~(np.isnan(x) | np.isnan(y))
    x_clean = x[valid]
    y_clean = y[valid]
    x_with_const = sm.add_constant(x_clean)  # Add intercept
    model = sm.OLS(y_clean, x_with_const).fit()
    y_pred = model.predict(x_with_const)
    return x_clean, y_pred

# ------------------------ #
# Build the Pair-Plot Grid (histograms on the diagonal)
# ------------------------ #
n = len(columns_to_plot)
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']  # One color per plotted column

# n x n grid of subplots; no axes are shared between cells
fig = make_subplots(
    rows=n,
    cols=n,
    shared_xaxes=False,
    shared_yaxes=False,
    vertical_spacing=0.05,
    horizontal_spacing=0.05,
)

# Walk the grid row by row: the grid row selects the y variable and the grid
# column selects the x variable.
for r, col_y in enumerate(columns_to_plot):
    for c, col_x in enumerate(columns_to_plot):
        cell = dict(row=r + 1, col=c + 1)  # subplot coordinates are 1-based

        if r == c:
            # Diagonal cell: distribution of a single variable
            histogram = go.Histogram(
                x=df[col_x],
                nbinsx=20,
                marker=dict(color=colors[r], line=dict(color="black", width=1)),
                opacity=0.7,
                name=f"{col_x} Distribution",
            )
            fig.add_trace(histogram, **cell)
            continue

        # Off-diagonal cell: the raw observations first ...
        points = go.Scatter(
            x=df[col_x],
            y=df[col_y],
            mode='markers',
            marker=dict(size=5, color=colors[c], opacity=0.7),
            name=f"{col_y} vs {col_x}",
        )
        fig.add_trace(points, **cell)

        # ... then the fitted regression line drawn on top of them
        fit_x, fit_y = add_trendline(df[col_x], df[col_y])
        fitted_line = go.Scatter(
            x=fit_x,
            y=fit_y,
            mode='lines',
            line=dict(color='red', width=2),
            name=f"Trendline: {col_y} ~ {col_x}",
        )
        fig.add_trace(fitted_line, **cell)

# ------------------------ #
# Update Layout
# ------------------------ #
# COMMON_LAYOUT (imported from plot_settings) supplies the shared layout
# settings; the keywords below are specific to this figure.
# NOTE(review): this call assumes COMMON_LAYOUT does not itself define
# title_text, legend or template; a duplicate keyword would raise TypeError.
fig.update_layout(
    **COMMON_LAYOUT,
    title_text="Distribution and Correlation Between Movie Ratings, Revenue, and ROI",
    legend=dict(
        x=1.05, y=1,  # x > 1 places the legend to the right of the plot area
        traceorder="normal",
        font=dict(size=10),
        bgcolor="rgba(255, 255, 255, 0.5)",
        bordercolor="gray",
        borderwidth=1
    ),
    template="plotly_white"  # Plotly's white-background theme
)

# Label only the outer axes: x titles on the bottom row, y titles on the
# first column.
for i, col in enumerate(columns_to_plot):
    fig.update_xaxes(title_text=col, row=n, col=i + 1)
    fig.update_yaxes(title_text=col, row=i + 1, col=1)

# ------------------------ #
# Save the Figure
# ------------------------ #
OUTPUT_PATH = "../../docs/_includes/plotly/"
# exist_ok=True creates the directory tree when missing and is a no-op when
# it already exists, avoiding a separate check-then-create step.
os.makedirs(OUTPUT_PATH, exist_ok=True)

output_file = os.path.join(OUTPUT_PATH, "rq1_metrics_distribution.html")
# full_html=False writes an embeddable fragment; plotly.js is loaded from the CDN
fig.write_html(output_file, full_html=False, include_plotlyjs='cdn')
print(f"Pair plot saved as {output_file}")
