# Seaborn Chart Examples

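This notebook creates various example charts, mostly with the seaborn library (a few fall back to pandas or plain matplotlib plotting):
1. Mathematical function plots (sin, cos, tan, etc.)
2. Random walk time series simulation
3. Bar charts (simple, grouped, stacked, percentage stacked, and multiple)
4. Scatterplots (basic, categorical, and bubble)
5. Boxplots (horizontal, grouped, and violin)
6. Candlestick chart
7. Interactive chart selector (dropdown widget)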

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import math
from numpy.random import normal
from IPython.display import display, clear_output
import ipywidgets as widgets

# Apply seaborn's "whitegrid" theme so every chart below shares the same look
sns.set_theme(style="whitegrid")

## Mathematical Function Plots

In [None]:
def plot_math_functions():
    """Plot six mathematical functions on a 3x2 grid of line charts.

    Each function is sampled at 1000 evenly spaced points on [-2π, 2π] and
    drawn on its own axes so the very different value ranges do not flatten
    one another.

    Returns:
        The matplotlib figure holding the six subplots. The figure is closed
        in pyplot before it is returned (presumably so the notebook does not
        render it a second time); the caller decides when to display it.
    """
    x = np.linspace(-2 * np.pi, 2 * np.pi, 1000)

    # Insertion order fixes the subplot order (filled row by row).
    curves = {
        'sin(x)': np.sin(x),
        'cos(x)': np.cos(x),
        'tan(x)': np.tan(x),
        'x²': x**2,
        'log(|x|+1)': np.log(np.abs(x) + 1),
        'exp(x/4)': np.exp(x / 4),
    }

    fig, axes = plt.subplots(3, 2, figsize=(15, 12))

    for ax, (label, y) in zip(axes.flatten(), curves.items()):
        frame = pd.DataFrame({'x': x, 'y': y})
        line_kwargs = {}

        if label == 'tan(x)':
            # Keep only |y| < 10 so the asymptotes do not dominate the scale.
            frame = frame[(frame['y'] > -10) & (frame['y'] < 10)].copy()
            # tan has asymptotes at odd multiples of π/2. Number each
            # continuous branch and draw the branches as separate lines;
            # otherwise the last point of one branch is joined to the first
            # point of the next by a spurious near-vertical stroke.
            frame['branch'] = np.floor((frame['x'] + np.pi / 2) / np.pi)
            line_kwargs = {'units': 'branch', 'estimator': None}

        sns.lineplot(x='x', y='y', data=frame, ax=ax, linewidth=2.5,
                     **line_kwargs)
        ax.set_title(label, fontsize=14)

        # Faint reference lines through the origin.
        ax.axhline(y=0, color='gray', linestyle='-', alpha=0.3)
        ax.axvline(x=0, color='gray', linestyle='-', alpha=0.3)

        # Label the x-axis in multiples of π.
        ax.set_xticks([-2 * np.pi, -np.pi, 0, np.pi, 2 * np.pi])
        ax.set_xticklabels(['-2π', '-π', '0', 'π', '2π'])

    fig.tight_layout()
    plt.close(fig)
    return fig


## Random Walk Time Series

In [None]:
def random_walk(steps=1000, step_size=0.1):
    """Simulate a one-dimensional Gaussian random walk.

    Args:
        steps: Number of increments to draw.
        step_size: Standard deviation of each zero-mean normal increment.

    Returns:
        A DataFrame with columns ``time`` (0, 1, ..., steps - 1) and
        ``value`` (the running sum of the increments).
    """
    # Independent N(0, step_size) increments, one per time step.
    increments = normal(loc=0, scale=step_size, size=steps)

    # The position at each time is the sum of all increments so far.
    positions = increments.cumsum()

    return pd.DataFrame({
        'time': np.arange(positions.size),
        'value': positions,
    })

def plot_random_walk():
    """Plot five independent random walks as coloured lines on one chart.

    Each walk has 500 steps with a step size of 0.1 and is produced by
    ``random_walk``.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    walks = [
        random_walk(steps=500, step_size=0.1).assign(series=f'Series {i + 1}')
        for i in range(5)
    ]

    # Every walk carries its own 0..499 index. Renumber on concatenation so
    # the combined frame has unique row labels; duplicate labels are rejected
    # by some seaborn/pandas combinations.
    all_walks = pd.concat(walks, ignore_index=True)

    fig, ax = plt.subplots(figsize=(12, 6))
    sns.lineplot(x='time', y='value', hue='series', data=all_walks,
                 linewidth=1.5, ax=ax)

    ax.set_title('Random Walk Time Series Simulation', fontsize=16)
    ax.set_xlabel('Time', fontsize=12)
    ax.set_ylabel('Value', fontsize=12)
    ax.grid(True, alpha=0.3)
    ax.legend(title='')  # drop the automatic "series" legend title

    fig.tight_layout()
    plt.close(fig)
    return fig

## Bar Charts

In [None]:
def create_sample_data():
    """Build the three demo datasets consumed by the bar-chart functions.

    The global NumPy random generator is seeded with 42, so the generated
    numbers are the same on every call.

    Returns:
        A tuple ``(simple_data, complex_data, percentage_data)``:

        * ``simple_data`` - five fixed rows with ``Category`` and ``Value``.
        * ``complex_data`` - one row per (group, product) pair with random
          ``Sales``, ``Profit`` and ``Returns`` integers.
        * ``percentage_data`` - one row per (region, segment) pair whose
          integer ``Percentage`` values sum to 100 within each region.
    """
    # Fixed values for the simple bar chart.
    simple_data = pd.DataFrame({
        'Category': ['Category A', 'Category B', 'Category C', 'Category D', 'Category E'],
        'Value': [25, 40, 30, 55, 15],
    })

    np.random.seed(42)  # For reproducibility

    # Grouped / stacked bar chart data: every group crossed with every product.
    group_names = ['Group 1', 'Group 2', 'Group 3', 'Group 4']
    product_names = ['Product X', 'Product Y', 'Product Z']
    complex_data = pd.DataFrame([
        {
            'Group': group_name,
            'Product': product_name,
            'Sales': np.random.randint(10, 100),
            'Profit': np.random.randint(5, 30),
            'Returns': np.random.randint(1, 10),
        }
        for group_name in group_names
        for product_name in product_names
    ])

    # Percentage stacked bar chart data.
    segment_names = ['Segment A', 'Segment B', 'Segment C']
    pct_rows = []
    for region_name in ['North', 'South', 'East', 'West']:
        raw = np.random.randint(10, 50, size=len(segment_names))
        shares = (raw / raw.sum() * 100).astype(int)

        # Truncation can leave the total short of 100; the last segment
        # absorbs the remainder.
        shares[-1] = 100 - shares[:-1].sum()

        pct_rows.extend(
            {'Region': region_name, 'Segment': segment_name, 'Percentage': share}
            for segment_name, share in zip(segment_names, shares)
        )
    percentage_data = pd.DataFrame(pct_rows)

    return simple_data, complex_data, percentage_data

In [None]:
def plot_simple_bar_chart(data):
    """Draw one bar per category with its value printed above the bar.

    Args:
        data: DataFrame with ``Category`` and ``Value`` columns. The value
            labels are placed by row position, so this assumes one row per
            category, in the order the bars are drawn.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # seaborn 0.12 replaced the ``ci`` parameter with ``errorbar``. Choose the
    # right keyword from the installed version instead of calling and
    # catching TypeError: older releases forward unknown keywords to
    # matplotlib, which raises AttributeError, so that fallback never ran.
    seaborn_version = tuple(int(part) for part in sns.__version__.split('.')[:2])
    no_error_bars = {'errorbar': None} if seaborn_version >= (0, 12) else {'ci': None}

    sns.barplot(x='Category', y='Value', data=data, palette='viridis', ax=ax,
                **no_error_bars)

    ax.set_title('Simple Bar Chart', fontsize=16)
    ax.set_xlabel('Category', fontsize=12)
    ax.set_ylabel('Value', fontsize=12)
    ax.grid(axis='y', alpha=0.3)

    # Print each value just above its bar.
    for position, value in enumerate(data['Value']):
        ax.text(position, value + 1, str(value), ha='center', fontsize=10)

    fig.tight_layout()
    plt.close(fig)
    return fig

def plot_grouped_bar_chart(data):
    """Draw sales per group with one bar per product side by side.

    Args:
        data: DataFrame with ``Group``, ``Product`` and ``Sales`` columns.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    fig, ax = plt.subplots(figsize=(12, 7))

    # seaborn 0.12 replaced the ``ci`` parameter with ``errorbar``. Choose the
    # right keyword from the installed version instead of calling and
    # catching TypeError: older releases forward unknown keywords to
    # matplotlib, which raises AttributeError, so that fallback never ran.
    seaborn_version = tuple(int(part) for part in sns.__version__.split('.')[:2])
    no_error_bars = {'errorbar': None} if seaborn_version >= (0, 12) else {'ci': None}

    sns.barplot(x='Group', y='Sales', hue='Product', data=data, palette='Set2',
                ax=ax, **no_error_bars)

    ax.set_title('Grouped Bar Chart - Sales by Group and Product', fontsize=16)
    ax.set_xlabel('Group', fontsize=12)
    ax.set_ylabel('Sales', fontsize=12)
    ax.grid(axis='y', alpha=0.3)
    ax.legend(title='Product', loc='upper right')

    fig.tight_layout()
    plt.close(fig)
    return fig

def plot_stacked_bar_chart(data):
    """Draw sales per group as bars stacked by product, with group totals.

    Args:
        data: DataFrame with ``Group``, ``Product`` and ``Sales`` columns.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    # Wide layout: one row per group, one column per product.
    sales_by_group = data.pivot_table(
        values='Sales',
        index='Group',
        columns='Product',
        aggfunc='sum',
    )

    fig, ax = plt.subplots(figsize=(12, 7))
    sales_by_group.plot(kind='bar', stacked=True, figsize=(12, 7), colormap='Set3', ax=ax)

    ax.set_title('Stacked Bar Chart - Sales by Group and Product', fontsize=16)
    ax.set_xlabel('Group', fontsize=12)
    ax.set_ylabel('Sales', fontsize=12)
    ax.grid(axis='y', alpha=0.3)
    ax.legend(title='Product', loc='upper right')

    # Write the per-group total just above each stack.
    group_totals = sales_by_group.sum(axis=1)
    for position, total in enumerate(group_totals):
        ax.text(position, total + 1, f'Total: {total}', ha='center', fontsize=10)

    fig.tight_layout()
    plt.close(fig)
    return fig

def plot_percentage_stacked_bar(data):
    """Draw segment shares per region as stacked bars with centred labels.

    Args:
        data: DataFrame with ``Region``, ``Segment`` and ``Percentage``
            columns.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    # Wide layout: one row per region, one column per segment.
    shares = data.pivot_table(
        values='Percentage',
        index='Region',
        columns='Segment',
        aggfunc='sum',
    )

    fig, ax = plt.subplots(figsize=(12, 7))
    shares.plot(kind='bar', stacked=True, figsize=(12, 7), colormap='tab10', ax=ax)

    ax.set_title('Percentage Stacked Bar Chart - Market Segments by Region', fontsize=16)
    ax.set_xlabel('Region', fontsize=12)
    ax.set_ylabel('Percentage', fontsize=12)
    ax.grid(axis='y', alpha=0.3)
    ax.legend(title='Segment', loc='upper right')

    # Label every segment at its vertical midpoint.
    for position, (_, region_row) in enumerate(shares.iterrows()):
        bottom = 0  # height of the stack below the current segment
        for share in region_row:
            ax.text(position, bottom + share/2, f'{share}%',
                    ha='center', va='center', fontsize=10)
            bottom += share

    fig.tight_layout()
    plt.close(fig)
    return fig

def plot_multiple_bar_charts(data):
    """Draw sales by group in three side-by-side panels, one per product.

    Args:
        data: DataFrame with ``Group``, ``Product`` and ``Sales`` columns.
            The value labels are placed by row position, so this assumes
            each product's rows appear in the order the bars are drawn.

    Returns:
        The matplotlib figure containing the three panels. The figure is
        closed in pyplot before it is returned; the caller decides when to
        display it.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # seaborn 0.12 replaced the ``ci`` parameter with ``errorbar``. Choose the
    # right keyword from the installed version instead of calling and
    # catching TypeError: older releases forward unknown keywords to
    # matplotlib, which raises AttributeError, so that fallback never ran.
    seaborn_version = tuple(int(part) for part in sns.__version__.split('.')[:2])
    no_error_bars = {'errorbar': None} if seaborn_version >= (0, 12) else {'ci': None}

    # One fixed palette per panel.
    products = ['Product X', 'Product Y', 'Product Z']
    palettes = ['Blues', 'Greens', 'Oranges']

    for ax, product, palette in zip(axes, products, palettes):
        product_data = data[data['Product'] == product]

        sns.barplot(x='Group', y='Sales', data=product_data, ax=ax,
                    palette=palette, **no_error_bars)

        ax.set_title(f'Sales for {product}', fontsize=14)
        ax.set_xlabel('Group', fontsize=12)
        ax.set_ylabel('Sales', fontsize=12)
        ax.grid(axis='y', alpha=0.3)

        # Print each value just above its bar.
        for position, value in enumerate(product_data['Sales']):
            ax.text(position, value + 1, str(value), ha='center', fontsize=9)

    fig.tight_layout()
    plt.close(fig)
    return fig

## More Sample Data

In [None]:
def create_more_sample_data():
    """Build demo datasets for the scatter, box/violin and candlestick charts.

    The global NumPy random generator is seeded with 42, so the generated
    numbers are the same on every call.

    Returns:
        A tuple ``(scatter_df, box_data, candlestick_data)``:

        * ``scatter_df`` - 100 rows with ``x``, a noisy copy ``y``, a random
          ``category`` in A-D and a random marker ``size``.
        * ``box_data`` - 150 rows: five groups (A-E) of 30 normal samples
          each, split into subgroups X, Y, Z of 10.
        * ``candlestick_data`` - 30 business days starting 2023-01-01 with
          ``open``, ``close``, ``high`` and ``low`` prices, where ``high`` is
          above and ``low`` is below both ``open`` and ``close``.
    """
    np.random.seed(42)

    # Scatterplot data: y follows x with added Gaussian noise.
    n = 100
    x = np.random.normal(size=n)
    y = x + np.random.normal(size=n, scale=0.5)
    scatter_df = pd.DataFrame({
        'x': x,
        'y': y,
        'category': np.random.choice(['A', 'B', 'C', 'D'], size=n),
        'size': np.random.uniform(10, 200, size=n),
    })

    # Boxplot data: each group has its own mean and spread.
    box_data = pd.DataFrame({
        'group': np.repeat(['A', 'B', 'C', 'D', 'E'], 30),
        'value': np.concatenate([
            np.random.normal(0, 1, 30),
            np.random.normal(2, 1.5, 30),
            np.random.normal(4, 1, 30),
            np.random.normal(1.5, 2, 30),
            np.random.normal(3, 1, 30),
        ]),
        'subgroup': np.tile(np.repeat(['X', 'Y', 'Z'], 10), 5),
    })

    # Candlestick data.
    n_days = 30
    dates = pd.date_range(start='2023-01-01', periods=n_days, freq='B')
    open_prices = np.random.uniform(100, 150, size=n_days)
    close_prices = np.random.uniform(100, 150, size=n_days)

    # One (high margin, low margin) pair per day. The (n_days, 2) array is
    # filled row by row, so the generator is consumed in the same order as
    # drawing the two margins day by day.
    margins = np.random.uniform(1, 10, size=(n_days, 2))

    # High sits above, and low below, both the open and the close.
    candlestick_data = pd.DataFrame({
        'date': dates,
        'open': open_prices,
        'close': close_prices,
        'high': np.maximum(open_prices, close_prices) + margins[:, 0],
        'low': np.minimum(open_prices, close_prices) - margins[:, 1],
    })

    return scatter_df, box_data, candlestick_data


## Scatterplots

### Basic Scatterplot

In [None]:
def plot_basic_scatterplot(data):
    """Draw ``y`` against ``x`` as semi-transparent blue points.

    Args:
        data: DataFrame with ``x`` and ``y`` columns.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    sns.scatterplot(data=data, x='x', y='y', color='blue', alpha=0.7, ax=ax)

    ax.set_title('Basic Scatterplot', fontsize=16)
    ax.set_xlabel('X-axis', fontsize=12)
    ax.set_ylabel('Y-axis', fontsize=12)
    ax.grid(alpha=0.3)

    fig.tight_layout()
    plt.close(fig)
    return fig

### Categorical Scatterplot

In [None]:
def plot_categorical_scatterplot(data):
    """Draw ``y`` against ``x`` with points coloured by ``category``.

    Args:
        data: DataFrame with ``x``, ``y`` and ``category`` columns.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    sns.scatterplot(data=data, x='x', y='y', hue='category',
                    palette='Set2', alpha=0.8, ax=ax)

    ax.set_title('Scatterplot with Categorical Variables', fontsize=16)
    ax.set_xlabel('X-axis', fontsize=12)
    ax.set_ylabel('Y-axis', fontsize=12)
    ax.legend(title='Category', loc='upper right')
    ax.grid(alpha=0.3)

    fig.tight_layout()
    plt.close(fig)
    return fig

### Bubble Chart

In [None]:
def plot_bubble_chart(data):
    """Draw a scatterplot with marker area from ``size`` and colour from ``category``.

    Args:
        data: DataFrame with ``x``, ``y``, ``size`` and ``category`` columns.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    sns.scatterplot(data=data, x='x', y='y', hue='category', size='size',
                    sizes=(20, 200), palette='coolwarm', alpha=0.8, ax=ax)

    ax.set_title('Bubble Chart', fontsize=16)
    ax.set_xlabel('X-axis', fontsize=12)
    ax.set_ylabel('Y-axis', fontsize=12)
    # Anchor the legend to the right of the plotting area, outside the axes.
    ax.legend(title='Category', loc='upper right', bbox_to_anchor=(1.2, 1))
    ax.grid(alpha=0.3)

    fig.tight_layout()
    plt.close(fig)
    return fig

## Boxplots

### Basic Horizontal Boxplot

In [None]:
def plot_horizontal_boxplot(data):
    """Draw one horizontal box per group showing the spread of ``value``.

    Args:
        data: DataFrame with ``group`` and ``value`` columns.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    sns.boxplot(data=data, x='value', y='group', orient='h', ax=ax)
    ax.set_title('Horizontal Boxplot')

    fig.tight_layout()
    plt.close(fig)
    return fig



### Grouped Boxplot

In [None]:
def plot_grouped_horizontal_boxplot(data):
    """Draw horizontal boxes per group, split and coloured by subgroup.

    Args:
        data: DataFrame with ``group``, ``value`` and ``subgroup`` columns.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    sns.boxplot(data=data, x='value', y='group', hue='subgroup',
                palette='Set2', ax=ax)

    ax.set_title('Grouped Horizontal Boxplot', fontsize=16)
    ax.set_xlabel('Value', fontsize=12)
    ax.set_ylabel('Group', fontsize=12)
    ax.legend(title='Subgroup', loc='upper right')
    ax.grid(axis='x', alpha=0.3)

    fig.tight_layout()
    plt.close(fig)
    return fig

### Violin Boxplot

In [None]:
def plot_horizontal_violin_boxplot(data):
    """Draw one horizontal violin per group with a box drawn inside it.

    Args:
        data: DataFrame with ``group`` and ``value`` columns.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    fig, ax = plt.subplots(figsize=(12, 8))

    # inner='box' draws the box inside each violin.
    sns.violinplot(data=data, x='value', y='group', inner='box',
                   orient='h', palette='muted', ax=ax)

    ax.set_title('Horizontal Violin Plot with Boxplot', fontsize=16)
    ax.set_xlabel('Value', fontsize=12)
    ax.set_ylabel('Group', fontsize=12)
    ax.grid(axis='x', alpha=0.3)

    fig.tight_layout()
    plt.close(fig)
    return fig

## Candlestick

In [None]:
def plot_candlestick_chart(data):
    """Draw a candlestick chart with plain matplotlib.

    Each row becomes a thin high-low line plus a wider open-close bar. The
    candle is green when the close is at or above the open, red otherwise.

    Args:
        data: DataFrame with ``date``, ``open``, ``close``, ``high`` and
            ``low`` columns, one row per candle.

    Returns:
        The matplotlib figure containing the chart. The figure is closed in
        pyplot before it is returned; the caller decides when to display it.
    """
    # Imported explicitly rather than reached through ``plt.matplotlib``,
    # which depends on what pyplot happens to import internally.
    from matplotlib.dates import DateFormatter

    fig, ax = plt.subplots(figsize=(14, 8))

    # Body width in x-axis units; matplotlib measures date axes in days.
    body_width = 0.6
    up_color = 'green'
    down_color = 'red'

    for _, row in data.iterrows():
        color = up_color if row['close'] >= row['open'] else down_color

        # High-to-low price range line.
        ax.plot([row['date'], row['date']], [row['low'], row['high']],
                color=color, linewidth=1)

        # Open-to-close body.
        ax.bar(row['date'], height=abs(row['close'] - row['open']),
               bottom=min(row['open'], row['close']), width=body_width,
               color=color, alpha=0.7)

    # Format the date axis after plotting. tick_params sets the rotation on
    # the axis itself rather than only on the tick labels that exist at the
    # time of the call.
    ax.xaxis.set_major_formatter(DateFormatter('%Y-%m-%d'))
    ax.tick_params(axis='x', labelrotation=45)

    ax.set_title('Candlestick Chart')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price')
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.close(fig)
    return fig


## Interactive Chart Selector

In [None]:
# Build every dataset once, up front; the chart builders below reuse them.
simple_data, complex_data, percentage_data = create_sample_data()
scatter_data, box_data, candlestick_data = create_more_sample_data()

# Maps each dropdown label to a zero-argument callable that builds the figure.
# This is the single source of truth for the available charts: the dropdown
# options are derived from it, and insertion order is the display order.
CHART_BUILDERS = {
    'Math Functions': plot_math_functions,
    'Random Walk': plot_random_walk,
    'Simple Bar Chart': lambda: plot_simple_bar_chart(simple_data),
    'Grouped Bar Chart': lambda: plot_grouped_bar_chart(complex_data),
    'Stacked Bar Chart': lambda: plot_stacked_bar_chart(complex_data),
    'Percentage Stacked Bar': lambda: plot_percentage_stacked_bar(percentage_data),
    'Multiple Bar Charts': lambda: plot_multiple_bar_charts(complex_data),
    'Basic Scatterplot': lambda: plot_basic_scatterplot(scatter_data),
    'Categorical Scatterplot': lambda: plot_categorical_scatterplot(scatter_data),
    'Bubble Chart': lambda: plot_bubble_chart(scatter_data),
    'Horizontal Boxplot': lambda: plot_horizontal_boxplot(box_data),
    'Grouped Horizontal Boxplot': lambda: plot_grouped_horizontal_boxplot(box_data),
    'Horizontal Violin Boxplot': lambda: plot_horizontal_violin_boxplot(box_data),
    'Candlestick Chart': lambda: plot_candlestick_chart(candlestick_data),
}


def display_chart(chart_type):
    """Build the figure for the chart named ``chart_type``.

    Closes all open pyplot figures and clears the current cell/widget output
    before building the new figure.

    Args:
        chart_type: One of the keys of ``CHART_BUILDERS``.

    Returns:
        The matplotlib figure produced by the selected chart function.

    Raises:
        ValueError: If ``chart_type`` is not a known chart name.
    """
    try:
        build_chart = CHART_BUILDERS[chart_type]
    except KeyError:
        raise ValueError(f'Unknown chart type: {chart_type!r}') from None

    plt.close('all')  # Close any existing plots
    clear_output(wait=True)
    return build_chart()


# Dropdown listing every available chart.
chart_dropdown = widgets.Dropdown(
    options=list(CHART_BUILDERS),
    value='Math Functions',
    description='Chart Type:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(width='50%')
)

# Output area that holds the currently selected chart.
output = widgets.Output()


def on_change(change):
    """Redraw the chart inside ``output`` when the dropdown value changes."""
    with output:
        display(display_chart(change.new))


# Only react to changes of the selected value.
chart_dropdown.observe(on_change, names='value')

# Display the initial chart
with output:
    display(display_chart(chart_dropdown.value))

# Display the widget and output
display(widgets.VBox([chart_dropdown, output]))