In [None]:
# IPython autoreload in mode 2: re-import all loaded modules before each cell
# executes, so edits to library source take effect without a kernel restart.
%reload_ext autoreload
%autoreload 2

In [None]:
from climakitae.explore.amy import get_climate_profile

In [None]:
# Keyword arguments for the climate-profile query. The spatial selection is a
# box extending 0.05 either side of the point (37.704432, -121.898696).
# NOTE(review): `q` is presumably the quantile used by the profile
# computation -- confirm against get_climate_profile.
selection = dict(
    variable="Air Temperature at 2m",
    resolution="3 km",
    warming_level=[2.0],
    units="degF",
    # cached_area=['Los Angeles County']
    latitude=(37.704432 - 0.05, 37.704432 + 0.05),
    longitude=(-121.898696 - 0.05, -121.898696 + 0.05),
    q=0.95,
)

In [None]:
# Expand the selection dict into keyword arguments; the result is the profile
# that the plotting helpers below consume.
profile = get_climate_profile(**selection)

In [None]:
"""
Visualization functions for climate profile data from amy.py

This module provides plotting functions for visualizing climate profiles
including heat maps, diurnal cycles, and seasonal patterns.
"""

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
import pandas as pd
import seaborn as sns
from typing import Optional, Tuple, List, Union
from matplotlib.figure import Figure
from matplotlib.axes import Axes


def plot_climate_profile_heatmap(
    profile_df: pd.DataFrame,
    title: str = "Climate Profile Heatmap",
    cmap: str = "RdBu_r",
    figsize: Tuple[int, int] = (16, 10),
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    center: float = 0,
    cbar_label: str = "Temperature Change (°F)",
    simulation: Optional[str] = None,
    warming_level: Optional[Union[str, float]] = None,
) -> Tuple[Figure, Axes]:
    """
    Create a heatmap visualization of climate profile data.

    Rows of ``profile_df`` are drawn along the x-axis and its columns along
    the y-axis (the frame is transposed before plotting).

    Parameters
    ----------
    profile_df : pd.DataFrame
        Climate profile DataFrame from get_climate_profile or compute_profile
    title : str, optional
        Title for the plot. Each applied column selection is appended as
        `` - <value>``.
    cmap : str, optional
        Colormap to use for the heatmap
    figsize : tuple, optional
        Figure size (width, height)
    vmin : float, optional
        Minimum value for color scale
    vmax : float, optional
        Maximum value for color scale
    center : float, optional
        Center value for diverging colormap
    cbar_label : str, optional
        Label for the colorbar
    simulation : str, optional
        If the columns are a MultiIndex with a "Simulation" level, the
        simulation to plot
    warming_level : str or float, optional
        If the columns are a MultiIndex with a "Warming_Level" level, the
        warming level to plot. Must match the label stored in that level.

    Returns
    -------
    fig, ax : tuple
        Matplotlib figure and axes objects

    Raises
    ------
    KeyError
        If a requested simulation or warming level is not present in the
        corresponding column level.
    ValueError
        If there is no data left to plot.

    Notes
    -----
    ``simulation`` and ``warming_level`` are applied independently, so both
    can be used together on a frame that carries both column levels.
    """
    plot_data = profile_df

    # Apply each requested selection in turn. The MultiIndex check is repeated
    # because ``xs`` drops the selected level and may leave a flat Index.
    for level_name, wanted in (
        ("Simulation", simulation),
        ("Warming_Level", warming_level),
    ):
        if (
            wanted
            and isinstance(plot_data.columns, pd.MultiIndex)
            and level_name in plot_data.columns.names
        ):
            plot_data = plot_data.xs(wanted, level=level_name, axis=1)
            title = f"{title} - {wanted}"

    if plot_data.empty:
        raise ValueError("profile_df contains no data to plot")

    # Map '12am', '1am', ..., '11pm' to 0..23 so hour-labelled columns can be
    # put in chronological rather than lexical order.
    hour_order = {
        f"{hour % 12 or 12}{'am' if hour < 12 else 'pm'}": hour
        for hour in range(24)
    }
    if plot_data.columns[0] in hour_order:
        ordered_cols = sorted(
            plot_data.columns, key=lambda label: hour_order.get(label, 0)
        )
        plot_data = plot_data[ordered_cols]

    # Create the figure only once the data is known to be plottable, so a
    # failed selection does not leave an orphaned open figure behind.
    fig, ax = plt.subplots(figsize=figsize)

    sns.heatmap(
        plot_data.T,  # Transpose so hours are on y-axis
        cmap=cmap,
        center=center,
        vmin=vmin,
        vmax=vmax,
        cbar_kws={"label": cbar_label},
        ax=ax,
        xticklabels=20,  # Show every 20th day label
        yticklabels=True,
    )

    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.set_xlabel("Day of Year", fontsize=12)
    ax.set_ylabel("Hour of Day", fontsize=12)

    # Rotate x-axis labels for better readability
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=45, ha='right')

    fig.tight_layout()
    return fig, ax


def plot_diurnal_cycle(
    profile_df: pd.DataFrame,
    months: Optional[List[str]] = None,
    title: str = "Average Diurnal Cycle",
    figsize: Tuple[int, int] = (12, 6),
    show_range: bool = True,
    simulation: Optional[str] = None,
    ylabel: str = "Temperature (°F)",
) -> Tuple[Figure, Axes]:
    """
    Plot the average diurnal (daily) cycle for specified months.

    The mean of every column is taken across the selected rows and drawn
    against the column position.

    Parameters
    ----------
    profile_df : pd.DataFrame
        Climate profile DataFrame. When ``months`` is given, the index must
        hold strings whose first three characters are a month abbreviation.
    months : list of str, optional
        List of month abbreviations to include (e.g., ['Jan', 'Feb', 'Mar'])
        If None, plots annual average
    title : str, optional
        Title for the plot
    figsize : tuple, optional
        Figure size
    show_range : bool, optional
        Whether to add two shaded bands: mean ± 1 standard deviation, and
        the min/max envelope
    simulation : str, optional
        If MultiIndex, which simulation to plot
    ylabel : str, optional
        Label for the y-axis. Override it when the profile is not a
        temperature in °F.

    Returns
    -------
    fig, ax : tuple
        Matplotlib figure and axes objects
    """
    # Select and reduce the data first; the figure is created afterwards so a
    # failing selection does not leave an orphaned open figure behind.
    plot_data = profile_df
    if isinstance(plot_data.columns, pd.MultiIndex) and simulation:
        plot_data = plot_data.xs(simulation, level="Simulation", axis=1)
        title = f"{title} - {simulation}"

    if months:
        in_selected_months = plot_data.index.str[:3].isin(months)
        plot_data = plot_data.loc[in_selected_months]
        title = f"{title} ({', '.join(months)})"

    mean_cycle = plot_data.mean(axis=0)

    # Columns are plotted by position; their labels are attached as ticks below
    hours = np.arange(len(mean_cycle))

    fig, ax = plt.subplots(figsize=figsize)
    ax.plot(hours, mean_cycle, 'b-', linewidth=2, label='Mean')

    if show_range:
        std_cycle = plot_data.std(axis=0)
        ax.fill_between(
            hours,
            mean_cycle - std_cycle,
            mean_cycle + std_cycle,
            alpha=0.3,
            color='blue',
            label='±1 std dev'
        )
        ax.fill_between(
            hours,
            plot_data.min(axis=0),
            plot_data.max(axis=0),
            alpha=0.1,
            color='gray',
            label='Min/Max range'
        )

    ax.set_xlabel("Hour of Day", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(loc='best')

    # Label every other column to keep the axis readable
    ax.set_xticks(hours[::2])
    ax.set_xticklabels(plot_data.columns[::2], rotation=45, ha='right')

    fig.tight_layout()
    return fig, ax


def plot_seasonal_profiles(
    profile_df: pd.DataFrame,
    title: str = "Seasonal Climate Profiles",
    figsize: Tuple[int, int] = (14, 10),
    simulation: Optional[str] = None,
    ylabel: str = "Temperature (°F)",
) -> Tuple[Figure, np.ndarray]:
    """
    Create a 2x2 subplot showing seasonal diurnal cycles.

    Each panel shows the per-column mean over the rows belonging to one
    season, with a shaded band of ± 1 standard deviation.

    Parameters
    ----------
    profile_df : pd.DataFrame
        Climate profile DataFrame. The index must hold strings whose first
        three characters are a month abbreviation ('Jan' ... 'Dec').
    title : str, optional
        Main title for the figure
    figsize : tuple, optional
        Figure size
    simulation : str, optional
        If MultiIndex, which simulation to plot
    ylabel : str, optional
        Label for the y-axis of every panel. Override it when the profile is
        not a temperature in °F.

    Returns
    -------
    fig, axes : tuple
        Matplotlib figure and flattened (length-4) array of axes, ordered
        winter, spring, summer, fall
    """
    seasons = {
        'Winter (DJF)': ['Dec', 'Jan', 'Feb'],
        'Spring (MAM)': ['Mar', 'Apr', 'May'],
        'Summer (JJA)': ['Jun', 'Jul', 'Aug'],
        'Fall (SON)': ['Sep', 'Oct', 'Nov']
    }

    # Resolve the data before creating the figure so a failing selection does
    # not leave an orphaned open figure behind.
    plot_data = profile_df
    if isinstance(plot_data.columns, pd.MultiIndex) and simulation:
        plot_data = plot_data.xs(simulation, level="Simulation", axis=1)
        title = f"{title} - {simulation}"

    # The month prefix of each row is the same for every season: compute once
    month_prefix = plot_data.index.str[:3]

    fig, axes = plt.subplots(2, 2, figsize=figsize)
    axes = axes.flatten()

    for ax, (season_name, season_months) in zip(axes, seasons.items()):
        season_data = plot_data.loc[month_prefix.isin(season_months)]

        mean_cycle = season_data.mean(axis=0)
        std_cycle = season_data.std(axis=0)

        # Columns are plotted by position; labels are attached as ticks below
        hours = np.arange(len(mean_cycle))

        ax.plot(hours, mean_cycle, linewidth=2, label='Mean')
        ax.fill_between(
            hours,
            mean_cycle - std_cycle,
            mean_cycle + std_cycle,
            alpha=0.3,
            label='±1 std dev'
        )

        ax.set_title(season_name, fontsize=12, fontweight='bold')
        ax.set_xlabel("Hour of Day", fontsize=10)
        ax.set_ylabel(ylabel, fontsize=10)
        ax.grid(True, alpha=0.3)
        ax.legend(loc='best', fontsize=8)

        # Label every 4th column to keep the small panels readable
        ax.set_xticks(hours[::4])
        ax.set_xticklabels(plot_data.columns[::4], rotation=45, ha='right')

    fig.suptitle(title, fontsize=14, fontweight='bold')
    fig.tight_layout()
    return fig, axes


def plot_comparison_profiles(
    profiles: dict,
    month: Optional[str] = None,
    title: str = "Climate Profile Comparison",
    figsize: Tuple[int, int] = (12, 6),
    colors: Optional[List[str]] = None,
    ylabel: str = "Temperature Change (°F)",
) -> Tuple[Figure, Axes]:
    """
    Compare multiple climate profiles on the same plot.

    One line is drawn per profile: the per-column mean over the (optionally
    month-filtered) rows.

    Parameters
    ----------
    profiles : dict
        Dictionary with keys as labels and values as profile DataFrames.
        Must not be empty.
    month : str, optional
        Month abbreviation to filter (e.g., 'Jul'). Requires each profile's
        index to hold strings starting with a month abbreviation. The month
        is appended to the title unless the title already contains it.
    title : str, optional
        Plot title
    figsize : tuple, optional
        Figure size
    colors : list, optional
        List of colors to use for each profile. If there are fewer colors
        than profiles, the colors are reused cyclically. Defaults to samples
        from the ``tab10`` colormap.
    ylabel : str, optional
        Label for the y-axis

    Returns
    -------
    fig, ax : tuple
        Matplotlib figure and axes objects

    Raises
    ------
    ValueError
        If ``profiles`` is empty.
    """
    # Validate before creating a figure so bad input leaves nothing open
    if not profiles:
        raise ValueError("profiles must contain at least one profile DataFrame")

    # len() rather than truthiness: colors may be a NumPy array
    if colors is None or len(colors) == 0:
        colors = plt.cm.tab10(np.linspace(0, 1, len(profiles)))

    # The title suffix does not depend on the individual profile
    if month and month not in title:
        title = f"{title} ({month})"

    fig, ax = plt.subplots(figsize=figsize)

    for idx, (label, profile_df) in enumerate(profiles.items()):
        plot_data = profile_df
        if month:
            plot_data = plot_data.loc[plot_data.index.str[:3] == month]

        mean_cycle = plot_data.mean(axis=0)

        ax.plot(
            np.arange(len(mean_cycle)),
            mean_cycle,
            linewidth=2,
            label=label,
            color=colors[idx % len(colors)]  # wrap around if colors run out
        )

    ax.set_xlabel("Hour of Day", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(loc='best')

    # Tick positions and labels both come from the first profile so they
    # always agree in number (all profiles are assumed to share its columns).
    first_columns = next(iter(profiles.values())).columns
    ax.set_xticks(np.arange(len(first_columns))[::2])
    ax.set_xticklabels(first_columns[::2], rotation=45, ha='right')

    fig.tight_layout()
    return fig, ax


def plot_monthly_summary(
    profile_df: pd.DataFrame,
    title: str = "Monthly Temperature Summary",
    figsize: Tuple[int, int] = (12, 6),
    show_variability: bool = True,
    ylabel: str = "Temperature (°F)",
) -> Tuple[Figure, Axes]:
    """
    Plot monthly average temperatures with variability bands.

    For every month, all values of the matching rows (across all columns)
    are pooled and summarised. Missing values (NaN) are ignored, and months
    with no remaining values are left out of the plot.

    Parameters
    ----------
    profile_df : pd.DataFrame
        Climate profile DataFrame. The index must hold strings whose first
        three characters are a month abbreviation ('Jan' ... 'Dec').
    title : str, optional
        Plot title
    figsize : tuple, optional
        Figure size
    show_variability : bool, optional
        Whether to show variability bands (interquartile range and min-max)
    ylabel : str, optional
        Label for the y-axis. Override it when the profile is not a
        temperature in °F.

    Returns
    -------
    fig, ax : tuple
        Matplotlib figure and axes objects

    Raises
    ------
    ValueError
        If no row of ``profile_df`` can be assigned to a month.
    """
    months_order = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                    'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    # Month prefix of every row, computed once for all twelve lookups
    month_prefix = profile_df.index.str[:3]

    monthly_stats = []
    for month in months_order:
        values = profile_df.loc[month_prefix == month].to_numpy(dtype=float).ravel()
        # Drop NaNs so one missing cell does not turn the whole month into NaN
        values = values[~np.isnan(values)]
        if values.size == 0:
            continue

        q25, q75 = np.percentile(values, [25, 75])
        monthly_stats.append({
            'month': month,
            'mean': values.mean(),
            'min': values.min(),
            'max': values.max(),
            'q25': q25,
            'q75': q75
        })

    # Fail with a clear message (and before any figure exists) instead of a
    # KeyError on the empty statistics table.
    if not monthly_stats:
        raise ValueError(
            "No rows of profile_df have an index label starting with a "
            "month abbreviation ('Jan' ... 'Dec')"
        )

    monthly_df = pd.DataFrame(monthly_stats)
    x_pos = np.arange(len(monthly_df))

    fig, ax = plt.subplots(figsize=figsize)

    ax.plot(x_pos, monthly_df['mean'], 'b-o', linewidth=2, markersize=8, label='Mean')

    if show_variability:
        # Show interquartile range
        ax.fill_between(
            x_pos,
            monthly_df['q25'],
            monthly_df['q75'],
            alpha=0.3,
            color='blue',
            label='IQR (25-75%)'
        )
        # Show full range
        ax.fill_between(
            x_pos,
            monthly_df['min'],
            monthly_df['max'],
            alpha=0.1,
            color='gray',
            label='Min-Max'
        )

    ax.set_xticks(x_pos)
    ax.set_xticklabels(monthly_df['month'])
    ax.set_xlabel("Month", fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3, axis='y')
    ax.legend(loc='best')

    fig.tight_layout()
    return fig, ax


# Example usage function
def create_climate_profile_report(
    profile_df: pd.DataFrame,
    output_dir: Optional[str] = None,
    profile_name: str = "climate_profile"
) -> None:
    """
    Build the standard set of climate profile figures and save or show them.

    Four figures are produced, in this order: the annual heatmap, the annual
    average diurnal cycle, the seasonal profiles and the monthly summary.

    Parameters
    ----------
    profile_df : pd.DataFrame
        Climate profile DataFrame from get_climate_profile
    output_dir : str, optional
        Directory to write the figures to as PNG files named
        ``<profile_name>_<plot>.png``; it is created if missing. If None
        (or empty), the figures are displayed instead.
    profile_name : str, optional
        Prefix used for the figure titles and the saved file names
    """
    # (file-name tag, plotting function, title suffix) for each report figure
    report_plots = (
        ('heatmap', plot_climate_profile_heatmap, "Annual Heatmap"),
        ('diurnal_annual', plot_diurnal_cycle, "Annual Average Diurnal Cycle"),
        ('seasonal', plot_seasonal_profiles, "Seasonal Profiles"),
        ('monthly', plot_monthly_summary, "Monthly Summary"),
    )

    rendered = []
    for tag, make_plot, suffix in report_plots:
        figure, _ = make_plot(profile_df, title=f"{profile_name} - {suffix}")
        rendered.append((tag, figure))

    # No target directory: just display everything that was drawn
    if not output_dir:
        plt.show()
        return

    import os
    os.makedirs(output_dir, exist_ok=True)
    for tag, figure in rendered:
        filepath = os.path.join(output_dir, f"{profile_name}_{tag}.png")
        figure.savefig(filepath, dpi=150, bbox_inches='tight')
        print(f"Saved: {filepath}")
        # Saved figures are closed so they are not also displayed/kept open
        plt.close(figure)


In [None]:
# Because output_dir is set, the four report figures are written to the
# current directory as my_profile_<plot>.png and closed rather than shown.
create_climate_profile_report(profile, output_dir="./", profile_name="my_profile")