# Color Palette

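Implement a color palette class that samples hex colors from a matplotlib colormap, maps data values to colors, and draws the palette as a colorbar.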

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib import gridspec

class ColorPalette:
    """
    A class to handle matplotlib color palettes and extract hex colors.

    Attributes:
        palette: colormap name, colormap object, or sequence of colors, as
            passed to the constructor
        num_colors: number of colors to extract from the palette
        data_min: data value mapped to the low end of the colormap
        data_max: data value mapped to the high end of the colormap
    """

    def __init__(self, palette, num_colors, data_min=None, data_max=None):
        """
        Initialize the ColorPalette class.

        Parameters:
        -----------
        palette : str, matplotlib.colors.Colormap or sequence of colors
            A colormap name (e.g., 'viridis', 'tab10'), a colormap object, or
            an explicit list of matplotlib color specifications
            (e.g., ['red', '#00ff00']), which is wrapped in a ListedColormap.
        num_colors : int
            Number of colors to extract from the palette
        data_min : float, optional
            Minimum data value (e.g., minimum wind speed). If None, defaults to 0
        data_max : float, optional
            Maximum data value (e.g., maximum wind speed). If None, defaults to num_colors

        Raises:
        -------
        ValueError
            If ``palette`` is a string that is not a known colormap name.
        """
        self.palette = palette
        self.num_colors = num_colors
        self.data_min = data_min if data_min is not None else 0
        self.data_max = data_max if data_max is not None else num_colors
        self._colormap = self._get_colormap()

    def _get_colormap(self):
        """Resolve ``self.palette`` into a callable matplotlib colormap."""
        if isinstance(self.palette, str):
            try:
                return plt.get_cmap(self.palette)
            except ValueError as err:
                raise ValueError(f"Unknown colormap name: {self.palette}") from err
        if callable(self.palette):
            # Already a Colormap (or a compatible callable): use it as-is.
            return self.palette
        # Anything else is treated as an explicit list of color specifications.
        return mcolors.ListedColormap(self.palette)

    def get_hex_colors(self):
        """
        Sample ``num_colors`` evenly spaced colors from the palette.

        Returns:
        --------
        list of str
            Hex color strings ('#rrggbb', alpha dropped) ordered from the low
            end of the colormap to the high end; empty when num_colors is 0.
        """
        # Evenly spaced float positions over the colormap's [0, 1] domain.
        positions = np.linspace(0, 1, self.num_colors)
        return [mcolors.to_hex(rgba[:3]) for rgba in self._colormap(positions)]

    def _normalize(self, value):
        """Map ``value`` linearly from [data_min, data_max] onto [0.0, 1.0], clamped."""
        if self.data_max == self.data_min:
            # Degenerate range: place every value at the middle of the colormap.
            return 0.5
        fraction = (value - self.data_min) / (self.data_max - self.data_min)
        # Clamp with float bounds so the result is never an int: matplotlib
        # colormaps interpret integer inputs as lookup-table indices, so an
        # int 1 would select the second color rather than the top color.
        return float(max(0.0, min(1.0, fraction)))

    def get_color_for_value(self, value):
        """
        Get the hex color corresponding to a specific data value.

        Values outside [data_min, data_max] are clamped to the end colors.

        Parameters:
        -----------
        value : float
            The data value to map to a color

        Returns:
        --------
        str: Hex color string ('#rrggbb') corresponding to the value
        """
        rgba = self._colormap(self._normalize(value))
        return mcolors.to_hex(rgba[:3])

    def plot_palette(self, figsize=None, title=None,
                     title_size=16,
                     label_size=20,
                     width_ratio=0.8,
                     orientation='vertical'):
        """
        Draw the palette as a standalone colorbar spanning the data range.

        Parameters:
        -----------
        figsize : tuple, optional
            Figure size (width, height). Defaults to (2, 8) for a vertical
            colorbar and (10, 2) for a horizontal one.
        title : str, optional
            Colorbar label. If None, no label is drawn.
        title_size : int, optional
            Font size of the colorbar label. Default is 16.
        label_size : int, optional
            Font size of the tick labels. Default is 20.
        width_ratio : float, optional
            Relative thickness of the colorbar cell, in (0, 1]; the remaining
            (1 - width_ratio) is split evenly between two spacer cells on
            either side. Default is 0.8.
        orientation : str, optional
            'vertical' or 'horizontal'. Default is 'vertical'.

        Returns:
        --------
        matplotlib.figure.Figure: The created figure object

        Raises:
        -------
        ValueError
            If ``orientation`` is not recognised or ``width_ratio`` is
            outside (0, 1].
        """
        if orientation not in ('vertical', 'horizontal'):
            raise ValueError(
                f"orientation must be 'vertical' or 'horizontal', got {orientation!r}")
        if not 0 < width_ratio <= 1:
            raise ValueError(f"width_ratio must be in (0, 1], got {width_ratio!r}")

        vertical = orientation == 'vertical'
        if figsize is None:
            figsize = (2, 8) if vertical else (10, 2)

        fig = plt.figure(figsize=figsize)

        # Centre the colorbar between two equal spacer cells across its thickness.
        margin = (1 - width_ratio) / 2
        ratios = [margin, width_ratio, margin]
        if vertical:
            gs = gridspec.GridSpec(1, 3, width_ratios=ratios)
        else:
            gs = gridspec.GridSpec(3, 1, height_ratios=ratios)
        cax = fig.add_subplot(gs[1])

        norm = mcolors.Normalize(vmin=self.data_min, vmax=self.data_max)
        scalar_mappable = plt.cm.ScalarMappable(norm=norm, cmap=self._colormap)

        cbar = fig.colorbar(scalar_mappable,
                            cax=cax,
                            orientation=orientation,
                            extend='both')
        if title is not None:
            cbar.set_label(title, fontsize=title_size)
        cbar.ax.tick_params(labelsize=label_size)

        return fig

    def save_plot(self, filename='colorbar', orientation='vertical', figsize=None, title=None,
                  label_size=12, dpi=300, bbox_inches='tight'):
        """
        Create and save a colorbar visualization of the palette.

        Parameters:
        -----------
        filename : str
            Output path; the file format follows the extension
            (e.g., 'palette.png', 'colorbar.pdf')
        orientation : str, optional
            'horizontal' or 'vertical'. Default is 'vertical'
        figsize : tuple, optional
            Figure size (width, height). Chosen from orientation if None
        title : str, optional
            Colorbar label. If None, no label is drawn
        label_size : int, optional
            Font size for the tick labels. Default is 12
        dpi : int, optional
            Resolution for output image. Default is 300
        bbox_inches : str, optional
            Bounding box in inches. Default is 'tight'
        """
        fig = self.plot_palette(orientation=orientation, figsize=figsize,
                                title=title, label_size=label_size)
        try:
            fig.savefig(filename, dpi=dpi, bbox_inches=bbox_inches)
        finally:
            plt.close(fig)  # Free the figure even if saving fails
        print(f"Palette saved as: {filename}")

    def __repr__(self):
        """String representation of the ColorPalette object."""
        if isinstance(self.palette, str):
            name = self.palette
        else:
            # Colormap objects carry a readable name; fall back to the object.
            name = getattr(self.palette, 'name', self.palette)
        return (f"{type(self).__name__}(palette='{name}', num_colors={self.num_colors}, "
                f"data_range=[{self.data_min}, {self.data_max}])")

In [None]:

# Sample 5 viridis colors spread over the data range 0-15.
discrete_palette = ColorPalette('viridis', 5, data_min=0, data_max=15)
print(f"Hex colors: {discrete_palette.get_hex_colors()}")
print(discrete_palette)

In [None]:
# plot_palette draws a vertical colorbar by default; there is no plot_palette3.
fig = discrete_palette.plot_palette(label_size=16)