# In [37]:
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Load the cleaned Pokémon dataset (path looks like a Google Colab workspace).
# NOTE: read_pickle can execute arbitrary code while unpickling; only point it
# at files you trust.
df = pd.read_pickle(filepath_or_buffer="/content/sample_data/cleaned_pokemon_data.pkl")

# Human-readable stat names, one per subplot, in plotting order.
stat_labels = ["HP", "Attack", "Defense", "Sp. Atk", "Sp. Def", "Speed"]
# DataFrame column for each stat: the label followed by the "_min" suffix.
stats = [f"{label}_min" for label in stat_labels]

# Trace colour keyed by generation number (only generations 1-4 have an entry).
generation_colors = dict(enumerate(["blue", "green", "red", "orange"], start=1))

# Create a 2x3 facet grid: one subplot per stat, titled with its label.
fig = make_subplots(
    rows=2,
    cols=3,
    subplot_titles=stat_labels,
)

# Split by generation once; the grouping is the same for every stat, so it
# does not need to be recomputed inside the stat loop. groupby sorts its keys
# and drops missing (NaN) generations, so `generations` lists exactly the
# groups that receive a ridge, in ridge order.
generation_groups = list(df.groupby('Generation'))
generations = [generation for generation, _ in generation_groups]

# Each ridge sits on an evenly spaced baseline 1..N. The y-axis ticks must use
# these same positions (not the raw generation values), otherwise the labels
# drift away from their ridges whenever generations are not exactly 1..N.
ridge_positions = list(range(1, len(generations) + 1))

# Shared x range keeps the six subplots visually comparable. Keep the original
# 0-500 window, but widen it to the next multiple of 100 when any stat exceeds
# 500 so high values are not clipped off the plot.
x_max = 500
data_max = df[stats].max().max()
if pd.notna(data_max) and data_max > x_max:
    x_max = int(-(-data_max // 100) * 100)  # ceiling to a multiple of 100
x_ticks = list(range(0, x_max + 1, 100))

# Main loop for ridgeline subplots
for stat_idx, stat in enumerate(stats):
    row = (stat_idx // 3) + 1  # subplot row (1-based)
    col = (stat_idx % 3) + 1  # subplot column (1-based)
    for position, (generation, group_df) in zip(ridge_positions, generation_groups):
        fig.add_trace(
            go.Violin(
                x=group_df[stat],
                y0=position,  # baseline of this generation's ridge
                orientation="h",  # distribution drawn along the x-axis
                name=f"Gen {generation}",
                meanline_visible=True,
                line=dict(width=1, color=generation_colors.get(generation)),
                # Width is in y-data units; wider than the 1-unit ridge
                # spacing so neighbouring ridges overlap.
                width=3,
                marker=dict(opacity=0.8),
                side="positive",  # only the upper half -> ridgeline look
                # One legend entry per generation toggles it in every subplot.
                legendgroup=f"Gen {generation}",
                showlegend=(stat_idx == 0),
            ),
            row=row,
            col=col,
        )

    # X-axis: shared value range with ticks every 100.
    fig.update_xaxes(
        tickmode="array",
        tickvals=x_ticks,
        ticktext=[str(tick) for tick in x_ticks],
        range=[0, x_max],
        title_text="Value",
        row=row,
        col=col,
    )
    # Y-axis: label each ridge baseline with its generation.
    fig.update_yaxes(
        title_text="Generation #",
        tickmode="array",
        tickvals=ridge_positions,
        ticktext=[f"Gen {gen}" for gen in generations],
        row=row,
        col=col,
    )

# Figure-wide layout: centred 20pt title, horizontal legend centred below the
# grid, and a fixed 1200x900 canvas.
fig.update_layout(
    title=dict(
        text="Pokémon Max Level Minimum Stats Ridgeline Plot",
        x=0.5,
        font=dict(size=20),
    ),
    showlegend=True,
    legend=dict(orientation="h", xanchor="center", x=0.5, y=-0.2),
    width=1200,
    height=900,
)

# Render the finished figure.
fig.show()