In [16]:
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import statsmodels.api as sm
from plotly.subplots import make_subplots

# Read the metrics table that backs this figure into a DataFrame.
file_path = "../../docs/assets/data/metrics.json"
df = pd.read_json(file_path)

# Metrics shown in the pair plot; their order fixes the grid order
# (rows top-to-bottom, columns left-to-right).
columns_to_plot = [
    "vote_average",
    "revenue",
    "ROI",
]

# ------------------------ #
# Helper Function for Regression
# ------------------------ #
def add_trendline(x, y):
    """Fit an OLS line ``y ~ const + x`` and return points along it for plotting.

    Pairs where either value is NaN or infinite are dropped before fitting;
    statsmodels rejects non-finite regressors, and a non-finite ``y`` would
    poison the estimated coefficients.

    Parameters
    ----------
    x, y : array-like
        Paired numeric samples of equal length (here, two DataFrame columns).

    Returns
    -------
    x_clean : numpy.ndarray
        The finite ``x`` values, sorted ascending so a ``mode='lines'`` trace
        is drawn once from left to right instead of retracing back and forth.
    y_pred : numpy.ndarray
        Fitted values at each ``x_clean``. Both arrays are empty when fewer
        than two usable pairs remain, since no line can be fitted.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Mask on BOTH coordinates: filtering only y let NaN/inf x-values reach
    # OLS, and np.isnan alone does not catch +/-inf.
    valid = np.isfinite(x) & np.isfinite(y)
    x_clean = x[valid]
    y_clean = y[valid]

    if x_clean.size < 2:
        return np.empty(0), np.empty(0)

    # Sort so consecutive line vertices are monotone in x.
    order = np.argsort(x_clean, kind="stable")
    x_clean = x_clean[order]
    y_clean = y_clean[order]

    # Always add the intercept column explicitly; with the default
    # has_constant="skip", an all-identical x would itself be taken as the
    # constant and no separate intercept column would be added.
    x_with_const = sm.add_constant(x_clean, has_constant="add")
    model = sm.OLS(y_clean, x_with_const).fit()
    y_pred = model.predict(x_with_const)
    return x_clean, y_pred

# ------------------------ #
# Create Pair Plot with Histograms
# ------------------------ #
# One row and one column per metric -> an n x n grid of subplots.
n = len(columns_to_plot)

# Palette indexed by metric position (used for histogram and scatter markers).
colors = [
    '#1f77b4',
    '#ff7f0e',
    '#2ca02c',
]

# Every cell keeps its own independent axes; subplots are separated by a
# small gap in both directions.
fig = make_subplots(
    rows=n,
    cols=n,
    horizontal_spacing=0.05,
    vertical_spacing=0.05,
    shared_yaxes=False,
    shared_xaxes=False,
)

# Fill the n x n grid. Row index i picks the y-axis metric, column index j the
# x-axis metric; the diagonal (i == j) shows that metric's own distribution.
for i, col_y in enumerate(columns_to_plot):
    for j, col_x in enumerate(columns_to_plot):
        row, col = i + 1, j + 1  # make_subplots rows/cols are 1-based

        if i == j:
            # Diagonal: histogram of a single metric.
            fig.add_trace(
                go.Histogram(
                    x=df[col_x],
                    nbinsx=20,
                    # Cycle the palette so adding metrics beyond len(colors)
                    # does not raise IndexError.
                    marker=dict(
                        color=colors[i % len(colors)],
                        line=dict(color="black", width=1),
                    ),
                    opacity=0.7,
                    name=f"{col_x} Distribution"
                ),
                row=row, col=col
            )
        else:
            # Off-diagonal: scatter of col_y against col_x, coloured by column.
            fig.add_trace(
                go.Scatter(
                    x=df[col_x],
                    y=df[col_y],
                    mode='markers',
                    marker=dict(size=5, color=colors[j % len(colors)], opacity=0.7),
                    name=f"{col_y} vs {col_x}"
                ),
                row=row, col=col
            )

            # Overlay the OLS fit of col_y on col_x in the same cell.
            x_clean, y_pred = add_trendline(df[col_x], df[col_y])
            fig.add_trace(
                go.Scatter(
                    x=x_clean,
                    y=y_pred,
                    mode='lines',
                    line=dict(color='red', width=2),
                    name=f"Trendline: {col_y} ~ {col_x}"
                ),
                row=row, col=col
            )

# ------------------------ #
# Update Layout
# ------------------------ #
# Figure-wide styling: title, fixed pixel size, a boxed legend placed just to
# the right of the subplot grid, and Plotly's white-background template.
fig.update_layout(
    title_text="Distribution and Correlation Between Movie Ratings, Revenue, and ROI",
    height=900,
    width=1100,
    legend=dict(
        x=1.05, y=1,  # paper coordinates: just past the right edge, top-aligned
        traceorder="normal",
        font=dict(size=10),
        bgcolor="rgba(255, 255, 255, 0.5)",  # semi-transparent white
        bordercolor="gray",
        borderwidth=1
    ),
    # Plotly's built-in "plotly_white" theme (not the separate "seaborn" template).
    template="plotly_white"
)

# Title only the outer edge of the grid: each x-axis title goes on the bottom
# row (row n), each y-axis title on the first column.
for i, col in enumerate(columns_to_plot):
    grid_index = i + 1
    fig.update_yaxes(title_text=col, row=grid_index, col=1)
    fig.update_xaxes(title_text=col, row=n, col=grid_index)

# ------------------------ #
# Save the Figure
# ------------------------ #
OUTPUT_PATH = "../../docs/_includes/plotly/"
output_file = Path(OUTPUT_PATH) / "rq1_metrics_distribution.html"
# Create the target folder if it does not exist yet; write_html opens the
# file directly and would fail on a missing directory.
output_file.parent.mkdir(parents=True, exist_ok=True)
# full_html=False emits an embeddable fragment for the docs' _includes folder;
# include_plotlyjs='cdn' references plotly.js from a CDN instead of inlining it.
fig.write_html(output_file, full_html=False, include_plotlyjs='cdn')
print(f"Pair plot saved as {output_file}")


Pair plot saved as ../../docs/_includes/plotly/rq1_metrics_distribution.html
