In [17]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

# Load the pre-cleaned Pokémon dataset; point DATA_PATH at your own copy.
# NOTE(review): read_pickle can execute arbitrary code while unpickling — only
# load pickle files you produced yourself or otherwise trust.
DATA_PATH = "/content/sample_data/cleaned_pokemon_data.pkl"
df = pd.read_pickle(DATA_PATH)

# One colour per Pokémon type, shared by every panel so that a given type is
# drawn in the same colour throughout the figure.
type_colors = dict(
    Bug="#A8B820",
    Dark="#705848",
    Dragon="#7038F8",
    Electric="#F8D030",
    Fairy="#EE99AC",
    Fighting="#C03028",
    Fire="#F08030",
    Flying="#A890F0",
    Ghost="#705898",
    Grass="#78C850",
    Ground="#E0C068",
    Ice="#98D8D8",
    Normal="#A8A878",
    Poison="#A040A0",
    Psychic="#F85888",
    Rock="#B8A038",
    Steel="#B8B8D0",
    Water="#6890F0",
)

# Per-row marker colour taken from the row's primary type; any type that is
# not a key of type_colors maps to NaN.
df["Color"] = df["Primary Type"].map(type_colors)

# 7x7 grid: four 3x3 cartesian panels anchored in the corners plus one
# single-cell 'domain' panel (for the pie) in the exact centre. make_subplots
# requires None in every cell that is covered by a span or left empty.
GRID_SIZE = 7
specs = [[None] * GRID_SIZE for _ in range(GRID_SIZE)]
for anchor_row, anchor_col in ((0, 0), (0, 4), (4, 0), (4, 4)):
    specs[anchor_row][anchor_col] = {'rowspan': 3, 'colspan': 3}
specs[3][3] = {'type': 'domain'}

# Titles are assigned to non-None specs in row-major order:
# (1,1), (1,5), (4,4), (5,1), (5,5).
panel_titles = [
    "Pokémon Total Stat Value over Primary Types",
    "Base Exp Requirements vs Pokémon Strength",
    "Pokémon Type Distribution",
    "Pokémon Height vs Pokémon Weight",
    "Pokémon Primary Type - Base Exp Heatmap",
]

# Initialize subplots
fig = make_subplots(
    rows=GRID_SIZE,
    cols=GRID_SIZE,
    specs=specs,
    subplot_titles=panel_titles,
)

# -------------------------------------------
# 1. Top-left: Ridgeline plot for Total Stat Value over Primary Types
# -------------------------------------------
# Each primary type gets its own baseline (RIDGE_SPACING apart). Its histogram
# of "Total" is scaled so the tallest bin reaches RIDGE_HEIGHT above that
# baseline. Per-type peak scaling is needed because density=True heights for
# a stat spanning hundreds of points are ~1e-3, too small to see at this scale.
RIDGE_SPACING = 1.0  # vertical distance between consecutive baselines
RIDGE_HEIGHT = 0.9   # peak height of each ridge; < spacing so ridges don't overlap
RIDGE_BINS = 20
FALLBACK_COLOR = "#68A090"  # used for any type that is not a key of type_colors

primary_types = sorted(df["Primary Type"].dropna().unique())
ridge_baselines = []
ridge_labels = []
offset = 0.0
for primary_type in primary_types:
    stats = pd.to_numeric(
        df.loc[df["Primary Type"] == primary_type, "Total"], errors="coerce"
    ).dropna()
    if stats.empty:
        # No numeric totals for this type: nothing to draw (and a histogram of
        # no data would only yield NaNs).
        continue

    counts, edges = np.histogram(stats, bins=RIDGE_BINS)
    centers = (edges[:-1] + edges[1:]) / 2  # plot at bin centres, not left edges
    ridge = counts / counts.max() * RIDGE_HEIGHT + offset

    # Invisible baseline trace. fill="tonexty" fills down to the *previous*
    # trace, so this makes each ridge fill to its own baseline rather than to
    # the previous type's curve.
    fig.add_trace(
        go.Scatter(
            x=centers,
            y=np.full_like(centers, offset),
            mode="lines",
            line=dict(width=0),
            hoverinfo="skip",
            showlegend=False,
        ),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(
            x=centers,
            y=ridge,
            fill="tonexty",
            name=primary_type,
            line=dict(color=type_colors.get(primary_type, FALLBACK_COLOR), width=2),
            mode="lines",
            # Show the real bin count on hover; y itself is offset + scaled height.
            customdata=counts,
            hovertemplate="Total ≈ %{x:.0f}<br>%{customdata} Pokémon",
        ),
        row=1,
        col=1,
    )
    ridge_baselines.append(offset)
    ridge_labels.append(primary_type)
    offset += RIDGE_SPACING  # Offset for the ridgeline effect

# Label each baseline with its type, since raw y values carry no meaning here.
fig.update_yaxes(tickvals=ridge_baselines, ticktext=ridge_labels, row=1, col=1)
fig.update_xaxes(title_text="Total Stat Value", row=1, col=1)

# -------------------------------------------
# 2. Top-right: Scatter plot of Base Exp vs Total Stat Value
# -------------------------------------------
# Read "Total" the same way the ridgeline panel does (pd.to_numeric with
# errors="coerce"). Non-numeric entries then become gaps instead of turning
# the y-axis categorical, and both panels interpret the column identically.
total_numeric = pd.to_numeric(df["Total"], errors="coerce")

fig.add_trace(
    go.Scatter(
        x=df["Base Exp."],
        y=total_numeric,
        mode="markers",
        marker=dict(size=8, color=df["Color"], opacity=0.7),
        name="Base Exp vs Total Stat",
        # Colour alone identifies the type only via the legend; put it in the hover.
        text=df["Primary Type"],
        hovertemplate="%{text}<br>Base Exp: %{x}<br>Total: %{y}<extra></extra>",
    ),
    row=1,
    col=5,
)

# Add axis labels
fig.update_xaxes(title_text="Base Experience", row=1, col=5)
fig.update_yaxes(title_text="Total Stat Value", row=1, col=5)

# -------------------------------------------
# 3. Middle-center: Pie chart of Primary Type distribution
# -------------------------------------------
# Percentage share of each primary type (NaN types excluded by value_counts).
primary_type_distribution = df["Primary Type"].value_counts(normalize=True) * 100

# Feed the pie raw counts. textinfo="percent" makes Plotly compute the same
# percentages for the slice text, and the hover then reports real counts
# rather than an already-scaled percentage.
primary_type_counts = df["Primary Type"].value_counts()

fig.add_trace(
    go.Pie(
        labels=primary_type_counts.index,
        values=primary_type_counts.values,
        textinfo="percent+label",
        # Fall back to a neutral colour instead of raising KeyError for a type
        # that is not a key of type_colors.
        marker=dict(
            colors=[type_colors.get(t, "#68A090") for t in primary_type_counts.index]
        ),
        name="Type Distribution",
    ),
    row=4,
    col=4,
)

# -------------------------------------------
# 4. Bottom-left: Scatter plot of Height vs Weight
# -------------------------------------------
# Same marker styling as the Base Exp scatter: type-coloured, semi-transparent.
hw_marker = dict(size=8, color=df["Color"], opacity=0.7)
hw_trace = go.Scatter(
    x=df["Height_m"],
    y=df["Weight_kg"],
    mode="markers",
    marker=hw_marker,
    name="Height vs Weight",
)
fig.add_trace(hw_trace, row=5, col=1)

# Label both axes of the bottom-left panel (x first, then y).
for set_axis_title, axis_label in (
    (fig.update_xaxes, "Height (m)"),
    (fig.update_yaxes, "Weight (kg)"),
):
    set_axis_title(title_text=axis_label, row=5, col=1)

# -------------------------------------------
# 5. Bottom-right: Heatmap of BaseExp vs Primary Type
# -------------------------------------------
# Split Base Exp into 10 equal-width intervals; each heatmap cell counts the
# rows of one primary type falling in one interval.
df["Base Exp (Binned)"] = pd.cut(df["Base Exp."], bins=10)
heatmap_data = pd.crosstab(df["Primary Type"], df["Base Exp (Binned)"])

fig.add_trace(
    go.Heatmap(
        z=heatmap_data.values,
        x=heatmap_data.columns.astype(str),  # interval labels such as "(a, b]"
        y=heatmap_data.index,
        colorscale="Purples",
        name="BaseExp Heatmap",
        # The colour scale stays hidden (presumably to avoid colliding with the
        # right-hand legend), so any colorbar settings would be ignored. The
        # hover label is what tells the reader the cell count.
        showscale=False,
        hovertemplate="Type: %{y}<br>Base Exp: %{x}<br>Count: %{z}<extra></extra>",
    ),
    row=5,
    col=5,
)

# Axis titles, matching the labelled axes on the other cartesian panels.
fig.update_xaxes(title_text="Base Experience (binned)", row=5, col=5)
fig.update_yaxes(title_text="Primary Type", row=5, col=5)

# -------------------------------------------
# Final Layout Update
# -------------------------------------------
# Overall canvas size, centred main title, global font and outer margins.
fig.update_layout(
    title=dict(text="Comprehensive Pokémon Data Visualization", x=0.5),
    width=1400,
    height=1000,
    showlegend=True,
    font=dict(family="Arial", size=12, color="black"),
    margin=dict(l=50, r=50, t=80, b=50),
)

# Render the assembled dashboard.
fig.show()
