# Module 4: Data Visualization

**Author:** Chinmay Nadgir  
**Date:** October 2025  
**Purpose:** Demonstrate professional data visualization techniques for effective communication

---

## Table of Contents
1. [Introduction](#intro)
2. [Setup & Data Loading](#setup)
3. [Univariate Analysis](#univariate)
4. [Bivariate Analysis](#bivariate)
5. [Multivariate Analysis](#multivariate)
6. [Time Series Visualization](#timeseries)
7. [Advanced Visualizations](#advanced)
8. [Interactive Visualizations](#interactive)
9. [Visualization Best Practices](#bestpractices)
10. [Summary](#summary)

<a id='intro'></a>
## 1. Introduction

Data visualization transforms complex data into accessible visual insights. Effective visualizations communicate patterns, trends, and relationships clearly.

**Learning Objectives:**
- Create professional univariate, bivariate, and multivariate visualizations
- Apply appropriate chart types for different data types
- Design clear, interpretable visualizations
- Build interactive visualizations with Plotly
- Follow visualization best practices

<a id='setup'></a>
## 2. Setup & Data Loading

In [None]:
# Standard library imports
import warnings
from pathlib import Path

# Third-party imports
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import plotly
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Configuration
warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', None)

# Matplotlib/Seaborn styling
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('Set2')
plt.rcParams['figure.dpi'] = 100
plt.rcParams['font.size'] = 10

# Report library versions (read from the top-level packages)
print(f"matplotlib: {matplotlib.__version__}")
print(f"seaborn: {sns.__version__}")
print(f"plotly: {plotly.__version__}")

In [None]:
# Load cleaned data
data_dir = Path('data')

try:
    df = pd.read_csv(data_dir / 'cleaned_data.csv', parse_dates=['date'])
    print(f"Loaded cleaned data: {df.shape}")
except FileNotFoundError:
    print("Creating sample dataset...")
    data_dir.mkdir(exist_ok=True)
    np.random.seed(42)
    df = pd.DataFrame({
        'customer_id': range(1, 101),
        'age': np.random.randint(18, 70, 100),
        'purchase_amount': np.random.uniform(10, 1000, 100).round(2),
        'category': np.random.choice(['Electronics', 'Clothing', 'Food', 'Books'], 100),
        'date': pd.date_range('2024-01-01', periods=100, freq='D'),
        'loyalty_member': np.random.choice([True, False], 100)
    })

display(df.head())
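
The cells below choose columns by dtype. As a minimal sketch (a tiny illustrative frame with the same column types as the generated sample, not real data), `select_dtypes` puts the integer and float columns in the numeric list, while the boolean `loyalty_member` and datetime `date` columns land in neither list, so they never get their own univariate plot.

```python
import numpy as np
import pandas as pd

# Tiny illustrative frame with the same column types as the generated sample
sample = pd.DataFrame({
    'customer_id': [1, 2, 3],
    'age': [25, 40, 61],
    'purchase_amount': [19.99, 250.00, 74.50],
    'category': ['Books', 'Food', 'Books'],
    'date': pd.date_range('2024-01-01', periods=3, freq='D'),
    'loyalty_member': [True, False, True],
})

numeric = sample.select_dtypes(include=[np.number]).columns.tolist()
categorical = sample.select_dtypes(include=['object', 'category']).columns.tolist()
print('numeric:', numeric)          # booleans are not treated as numbers here
print('categorical:', categorical)  # the datetime and boolean columns are absent
```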

<a id='univariate'></a>
## 3. Univariate Analysis

Univariate analysis examines single variables to understand their distributions.

In [None]:
# Histogram with KDE overlay for continuous variables
numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
numeric_cols = [col for col in numeric_cols if col != 'customer_id']

if numeric_cols:
    fig, axes = plt.subplots(1, len(numeric_cols), figsize=(14, 5))
    if len(numeric_cols) == 1:
        axes = [axes]
    
    for idx, col in enumerate(numeric_cols):
        axes[idx].hist(df[col], bins=20, color='skyblue', edgecolor='black', alpha=0.7)
        axes[idx].set_title(f'Distribution: {col}', fontweight='bold')
        axes[idx].set_xlabel(col)
        axes[idx].set_ylabel('Frequency')
        axes[idx].grid(axis='y', alpha=0.3)
        
        # Add KDE
        ax2 = axes[idx].twinx()
        df[col].plot(kind='kde', ax=ax2, color='red', linewidth=2)
        ax2.set_ylabel('Density', color='red')
        ax2.tick_params(axis='y', labelcolor='red')
    
    plt.suptitle('Histograms with KDE Overlay', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()
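
pandas draws the KDE curve with `scipy.stats.gaussian_kde`, which picks its bandwidth with Scott's rule by default. A minimal NumPy sketch of the same idea in one dimension, assuming a bandwidth of h = s * n^(-1/5) with s the sample standard deviation: every observation contributes a Gaussian bump, and the bumps are averaged. The helper name `gaussian_kde_scott` and the uniform data are illustrative.

```python
import numpy as np


def gaussian_kde_scott(values, grid):
    """Evaluate a 1-D Gaussian KDE on `grid` with Scott's bandwidth rule."""
    x = np.asarray(values, dtype=float)
    x = x[~np.isnan(x)]                  # drop missing values before fitting
    n = x.size
    h = x.std(ddof=1) * n ** (-1 / 5)    # Scott's rule for one dimension
    # One Gaussian bump per observation, evaluated at every grid point
    u = (grid[:, None] - x[None, :]) / h
    return np.exp(-0.5 * u ** 2).sum(axis=1) / (n * h * np.sqrt(2 * np.pi))


rng = np.random.default_rng(42)
data = rng.uniform(10, 1000, 100)        # illustrative stand-in for purchase_amount
grid = np.linspace(data.min() - 500, data.max() + 500, 2001)
density = gaussian_kde_scott(data, grid)

# Trapezoidal check: a density should integrate to about 1
area = np.sum((density[1:] + density[:-1]) / 2 * np.diff(grid))
print(f'area under the curve: {area:.4f}')
```

Because the KDE is a density while the histogram shows counts, the cell above needs a second y-axis. Passing `density=True` to `ax.hist` would put both curves on the same scale and remove the need for `twinx()`.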

In [None]:
# Box plots for outlier detection
if numeric_cols:
    fig, axes = plt.subplots(1, len(numeric_cols), figsize=(14, 5))
    if len(numeric_cols) == 1:
        axes = [axes]
    
    for idx, col in enumerate(numeric_cols):
        # Vertical is the default orientation; patch_artist=True enables facecolor
        axes[idx].boxplot(df[col].dropna(), patch_artist=True,
                          boxprops=dict(facecolor='lightcoral', alpha=0.7),
                          medianprops=dict(color='darkred', linewidth=2))
        axes[idx].set_title(f'Box Plot: {col}', fontweight='bold')
        axes[idx].set_ylabel(col)
        axes[idx].grid(axis='y', alpha=0.3)
    
    plt.suptitle('Box Plots for Outlier Detection', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()
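
Matplotlib's `boxplot` places its whiskers with the 1.5 x IQR rule by default (`whis=1.5`): points below Q1 - 1.5 * IQR or above Q3 + 1.5 * IQR are drawn as individual fliers. A sketch that reproduces those fences numerically, so outliers can be listed rather than only eyeballed; `iqr_outliers` is a hypothetical helper and the input values are made up.

```python
import numpy as np


def iqr_outliers(values, whis=1.5):
    """Return (lower_fence, upper_fence, outliers) under the whis x IQR rule."""
    x = np.asarray(values, dtype=float)
    x = x[~np.isnan(x)]
    q1, q3 = np.percentile(x, [25, 75])   # linear interpolation, NumPy's default
    iqr = q3 - q1
    lower, upper = q1 - whis * iqr, q3 + whis * iqr
    return lower, upper, x[(x < lower) | (x > upper)]


lower, upper, outliers = iqr_outliers([12, 14, 15, 15, 16, 18, 19, 95])
print(lower, upper, outliers)   # 9.5 23.5 [95.]
```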

In [None]:
# Bar charts for categorical variables
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()

if categorical_cols:
    n_cats = len(categorical_cols)
    fig, axes = plt.subplots(1, n_cats, figsize=(6*n_cats, 5))
    if n_cats == 1:
        axes = [axes]
    
    for idx, col in enumerate(categorical_cols):
        value_counts = df[col].value_counts().sort_values(ascending=False)
        bars = axes[idx].bar(range(len(value_counts)), value_counts.values, 
                            color='steelblue', edgecolor='black', alpha=0.7)
        axes[idx].set_xticks(range(len(value_counts)))
        axes[idx].set_xticklabels(value_counts.index, rotation=45, ha='right')
        axes[idx].set_title(f'Count by {col}', fontweight='bold')
        axes[idx].set_xlabel(col)
        axes[idx].set_ylabel('Count')
        axes[idx].grid(axis='y', alpha=0.3)
        
        # Add percentage annotations
        total = value_counts.sum()
        for bar in bars:
            height = bar.get_height()
            pct = height / total * 100
            axes[idx].text(bar.get_x() + bar.get_width()/2., height,
                          f'{pct:.1f}%', ha='center', va='bottom', fontsize=9)
    
    plt.suptitle('Categorical Variable Distributions', fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()
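
The percentage labels above are each bar's count divided by the total. pandas can compute the same shares directly with `value_counts(normalize=True)`; a short sketch on made-up category labels:

```python
import pandas as pd

categories = pd.Series(['Books', 'Food', 'Books', 'Clothing', 'Books', 'Food'])

# Shares of the total instead of raw counts, sorted from most to least frequent
pct = categories.value_counts(normalize=True).mul(100).round(1)
labels = [f'{p:.1f}%' for p in pct]
print(labels)   # ['50.0%', '33.3%', '16.7%']
```

Since Matplotlib 3.4, `ax.bar_label(bars, labels=labels)` can place such labels without computing text positions by hand, as long as the bars were drawn in the same order as `pct`.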

<a id='bivariate'></a>
## 4. Bivariate Analysis

Bivariate analysis explores relationships between two variables.

In [None]:
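# Scatter plot with least-squares regression line
if len(numeric_cols) >= 2:
    fig, ax = plt.subplots(figsize=(10, 6))
    
    x_col, y_col = numeric_cols[0], numeric_cols[1]
    xy = df[[x_col, y_col]].dropna()  # np.polyfit cannot handle missing values
    ax.scatter(xy[x_col], xy[y_col], alpha=0.6, s=50, color='steelblue', edgecolors='black')
    
    # Fit y = slope * x + intercept and draw it across the observed x range
    slope, intercept = np.polyfit(xy[x_col], xy[y_col], 1)
    x_line = np.linspace(xy[x_col].min(), xy[x_col].max(), 100)
    ax.plot(x_line, slope * x_line + intercept, 'r-', linewidth=2,
            label=f'y = {slope:.2f}x {intercept:+.2f}')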
    
    ax.set_xlabel(x_col, fontweight='bold', fontsize=12)
    ax.set_ylabel(y_col, fontweight='bold', fontsize=12)
    ax.set_title(f'{y_col} vs {x_col}', fontweight='bold', fontsize=14)
    ax.legend()
    ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()
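
For a degree-1 fit, `np.polyfit` solves ordinary least squares, and the slope equals cov(x, y) / var(x), or equivalently r * s_y / s_x, where r is the Pearson correlation shown in the correlation heatmaps. A sketch on synthetic data (the true slope of 3 and intercept of 20 are arbitrary choices) checks that the two routes agree:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.uniform(18, 70, 100)                   # illustrative ages
y = 3.0 * x + 20 + rng.normal(0, 5, 100)       # linear signal plus noise

# Least-squares line: polyfit returns coefficients from highest degree down
slope, intercept = np.polyfit(x, y, 1)

# Closed form via the Pearson correlation r
r = np.corrcoef(x, y)[0, 1]
slope_closed = r * y.std(ddof=1) / x.std(ddof=1)
intercept_closed = y.mean() - slope_closed * x.mean()

print(f'y = {slope:.2f}x {intercept:+.2f}  (r = {r:.3f})')
```

The `:+.2f` format spec prints the intercept's sign explicitly, which avoids labels such as `y=0.50x+-3.20` when the intercept is negative.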

In [None]:
# Scatter plot with categorical color-coding
if len(numeric_cols) >= 2 and categorical_cols:
    fig, ax = plt.subplots(figsize=(10, 6))
    
    x_col, y_col = numeric_cols[0], numeric_cols[1]
    cat_col = categorical_cols[0]
    for category in df[cat_col].unique():
        subset = df[df[cat_col] == category]
        ax.scatter(subset[x_col], subset[y_col], label=category, alpha=0.6, s=50, edgecolors='black')
    
    ax.set_xlabel(x_col, fontweight='bold', fontsize=12)
    ax.set_ylabel(y_col, fontweight='bold', fontsize=12)
    ax.set_title(f'{y_col} vs {x_col} by {cat_col}', fontweight='bold', fontsize=14)
    ax.legend(title=cat_col)
    ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()

In [None]:
# Grouped bar chart
if categorical_cols and numeric_cols:
    cat_col = categorical_cols[0]
    num_col = numeric_cols[0]
    
    if 'loyalty_member' in df.columns:
        grouped_data = df.groupby([cat_col, 'loyalty_member'])[num_col].mean().unstack()
        
        fig, ax = plt.subplots(figsize=(10, 6))
        grouped_data.plot(kind='bar', ax=ax, color=['steelblue', 'coral'], 
                         edgecolor='black', alpha=0.7)
        ax.set_title(f'Average {num_col} by {cat_col} and Loyalty Status', 
                    fontweight='bold', fontsize=14)
        ax.set_xlabel(cat_col, fontweight='bold')
        ax.set_ylabel(f'Average {num_col}', fontweight='bold')
        ax.legend(title='Loyalty Member', labels=['No', 'Yes'])
        ax.grid(axis='y', alpha=0.3)
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.show()
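
The grouped bar chart depends on the shape produced by `groupby(...).mean().unstack()`: one row per category and one column per loyalty status. Boolean column labels sort `False` before `True`, which is why the legend labels `['No', 'Yes']` line up. A small sketch with illustrative values that makes the mapping explicit by renaming the columns instead of relying on their order:

```python
import pandas as pd

orders = pd.DataFrame({                         # illustrative rows, not real data
    'category': ['Books', 'Books', 'Food', 'Food', 'Food'],
    'loyalty_member': [True, False, True, False, False],
    'purchase_amount': [120.0, 80.0, 40.0, 30.0, 50.0],
})

grouped = (orders.groupby(['category', 'loyalty_member'])['purchase_amount']
           .mean()
           .unstack()                            # loyalty status becomes columns
           .rename(columns={False: 'No', True: 'Yes'}))
print(grouped)
```

With named columns, `grouped.plot(kind='bar')` uses 'No' and 'Yes' as legend entries without a separate `labels=` argument.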

In [None]:
# Violin plots for distribution comparison
if categorical_cols and numeric_cols:
    fig, ax = plt.subplots(figsize=(10, 6))
    
    cat_col = categorical_cols[0]
    num_col = numeric_cols[1] if len(numeric_cols) > 1 else numeric_cols[0]
    
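    # Assigning the category to hue as well avoids seaborn's palette-without-hue deprecation
    sns.violinplot(data=df, x=cat_col, y=num_col, hue=cat_col, palette='Set2', legend=False, ax=ax)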
    ax.set_title(f'Distribution of {num_col} by {cat_col}', fontweight='bold', fontsize=14)
    ax.set_xlabel(cat_col, fontweight='bold')
    ax.set_ylabel(num_col, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.show()

<a id='multivariate'></a>
## 5. Multivariate Analysis

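Multivariate analysis examines relationships among three or more variables. The bubble chart below requires at least three numeric columns (besides `customer_id`), so it is skipped when the generated sample dataset, which has only `age` and `purchase_amount`, is used.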

In [None]:
# Pair plot for numeric variables
if len(numeric_cols) >= 2:
    pairplot_data = df[numeric_cols].copy()
    if categorical_cols:
        pairplot_data[categorical_cols[0]] = df[categorical_cols[0]]
        sns.pairplot(pairplot_data, hue=categorical_cols[0], palette='Set2', 
                    diag_kind='kde', plot_kws={'alpha': 0.6, 's': 30, 'edgecolor': 'k'})
    else:
        sns.pairplot(pairplot_data, diag_kind='kde', plot_kws={'alpha': 0.6, 's': 30})
    
    plt.suptitle('Pair Plot: Multivariate Relationships', y=1.01, fontweight='bold', fontsize=14)
    plt.tight_layout()
    plt.show()

In [None]:
# Correlation heatmap
if len(numeric_cols) >= 2:
    correlation_matrix = df[numeric_cols].corr()
    
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(correlation_matrix, annot=True, fmt='.3f', cmap='coolwarm', 
                center=0, square=True, linewidths=1, cbar_kws={"shrink": 0.8}, ax=ax)
    ax.set_title('Correlation Heatmap', fontweight='bold', fontsize=14, pad=20)
    plt.tight_layout()
    plt.show()
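
A correlation matrix is symmetric, so the heatmap shows every pair twice. One common option is to hide the upper triangle with seaborn's `mask` argument; a sketch that builds such a mask with `np.triu` on a random illustrative frame (`k=1` keeps the diagonal visible):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(1)
frame = pd.DataFrame(rng.normal(size=(50, 3)), columns=['a', 'b', 'c'])  # illustrative
corr = frame.corr()

# True strictly above the diagonal: those cells mirror the lower triangle
mask = np.triu(np.ones(corr.shape, dtype=bool), k=1)
print(mask.astype(int))
```

Passing `mask=mask` to `sns.heatmap(corr, ...)` leaves the cells where the mask is `True` blank.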

In [None]:
# Bubble chart (3 variables: x, y, size)
if len(numeric_cols) >= 3:
    fig, ax = plt.subplots(figsize=(10, 6))
    
    x_col, y_col, size_col = numeric_cols[0], numeric_cols[1], numeric_cols[2]
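    size_min, size_span = df[size_col].min(), df[size_col].max() - df[size_col].min()
    # Min-max scale marker areas to 50-550 points^2; a constant column gets one size
    sizes = (df[size_col] - size_min) / size_span * 500 + 50 if size_span else 300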
    
    scatter = ax.scatter(df[x_col], df[y_col], s=sizes, alpha=0.5, 
                        c=df[size_col], cmap='viridis', edgecolors='black')
    
    ax.set_xlabel(x_col, fontweight='bold', fontsize=12)
    ax.set_ylabel(y_col, fontweight='bold', fontsize=12)
    ax.set_title(f'Bubble Chart: {y_col} vs {x_col} (size: {size_col})', 
                fontweight='bold', fontsize=14)
    
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label(size_col, fontweight='bold')
    ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()
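
Matplotlib's `s` argument to `scatter` is the marker area in points^2, so scaling `s` directly, as the bubble chart does, makes bubble area grow linearly with the value. An easy mistake is to scale a width or diameter linearly and then square it into an area, which exaggerates differences. A small numeric sketch with made-up values compares the two:

```python
import numpy as np

values = np.array([1.0, 2.0, 4.0])   # illustrative values spanning a 4x range

# Option A (what scatter's `s` expects): area proportional to value
area_a = 100 * values / values.max()

# Option B (the pitfall): width proportional to value, then squared into area
width_b = 10 * values / values.max()
area_b = width_b ** 2

# Ratio of largest to smallest bubble area under each option
print(area_a.max() / area_a.min())  # 4.0, matching the data
print(area_b.max() / area_b.min())  # 16.0, four times too large
```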

<a id='timeseries'></a>
## 6. Time Series Visualization

Time series plots reveal trends, seasonality, and anomalies over time.

In [None]:
# Line plot for time series
if 'date' in df.columns and numeric_cols:
    fig, ax = plt.subplots(figsize=(12, 6))
    
    num_col = numeric_cols[1] if len(numeric_cols) > 1 else numeric_cols[0]
    ts = df.sort_values('date')  # plot in chronological order
    ax.plot(ts['date'], ts[num_col], marker='o', linestyle='-',
            color='steelblue', markersize=3, linewidth=1.5)
    
    ax.set_xlabel('Date', fontweight='bold', fontsize=12)
    ax.set_ylabel(num_col, fontweight='bold', fontsize=12)
    ax.set_title(f'{num_col} Over Time', fontweight='bold', fontsize=14)
    ax.grid(alpha=0.3)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
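
Daily values are noisy, which can hide the trend a line plot is meant to show. A sketch of two standard pandas smoothing steps on a synthetic daily series (same date range as the generated sample): a centered 7-day rolling mean and weekly totals via `resample('W')`. The 7-day window is an assumption; choose it to match the cycle you expect in your data.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(42)
daily = pd.Series(rng.uniform(10, 1000, 100),                  # synthetic values
                  index=pd.date_range('2024-01-01', periods=100, freq='D'))

# Centered 7-day mean: each point averages itself and 3 days on either side
trend = daily.rolling(window=7, center=True).mean()

# Weekly totals; 'W' bins weeks ending on Sunday
weekly = daily.resample('W').sum()
print(weekly.head())
```

Either series can be overlaid on the raw line, for example `ax.plot(trend.index, trend, color='darkred', linewidth=2, label='7-day mean')`.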

In [None]:
# Area chart for cumulative trends
if 'date' in df.columns and numeric_cols:
    num_col = numeric_cols[1] if len(numeric_cols) > 1 else numeric_cols[0]
    df_sorted = df.sort_values('date')
    df_sorted['cumulative'] = df_sorted[num_col].cumsum()
    
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.fill_between(df_sorted['date'], df_sorted['cumulative'], 
                    alpha=0.5, color='coral', edgecolor='darkred', linewidth=2)
    ax.plot(df_sorted['date'], df_sorted['cumulative'], 
           color='darkred', linewidth=2, label='Cumulative')
    
    ax.set_xlabel('Date', fontweight='bold', fontsize=12)
    ax.set_ylabel(f'Cumulative {num_col}', fontweight='bold', fontsize=12)
    ax.set_title(f'Cumulative {num_col} Over Time', fontweight='bold', fontsize=14)
    ax.legend()
    ax.grid(alpha=0.3)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

<a id='advanced'></a>
## 7. Advanced Visualizations

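Specialized plots that combine several views: a joint plot adds marginal distributions to a scatter plot, and a pivot-table heatmap summarizes a numeric variable across two categorical dimensions.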

In [None]:
# Joint plot (scatter with marginal distributions)
if len(numeric_cols) >= 2:
    x_col, y_col = numeric_cols[0], numeric_cols[1]
    g = sns.jointplot(data=df, x=x_col, y=y_col, kind='scatter', 
                     color='steelblue', alpha=0.6, marginal_kws={'bins': 20, 'color': 'coral'})
    g.fig.suptitle(f'Joint Plot: {y_col} vs {x_col}', fontweight='bold', y=1.02)
    plt.tight_layout()
    plt.show()

In [None]:
# Heatmap for pivot table
if categorical_cols and numeric_cols:
    if len(categorical_cols) >= 2:
        pivot_table = df.pivot_table(values=numeric_cols[0], 
                                     index=categorical_cols[0], 
                                     columns=categorical_cols[1], 
                                     aggfunc='mean')
    elif 'loyalty_member' in df.columns:
        pivot_table = df.pivot_table(values=numeric_cols[0], 
                                     index=categorical_cols[0], 
                                     columns='loyalty_member', 
                                     aggfunc='mean')
    else:
        pivot_table = None
    
    if pivot_table is not None:
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.heatmap(pivot_table, annot=True, fmt='.1f', cmap='YlOrRd', 
                   linewidths=1, cbar_kws={"shrink": 0.8}, ax=ax)
        ax.set_title(f'Heatmap: Average {numeric_cols[0]}', fontweight='bold', fontsize=14)
        plt.tight_layout()
        plt.show()
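
`pivot_table` leaves a cell as `NaN` when a combination of categories has no rows, and `sns.heatmap` leaves such cells blank. Averages built from one or two rows are also easy to over-read, so it can help to compute counts alongside the means. A sketch with illustrative values, where `channel` is a hypothetical second categorical column:

```python
import pandas as pd

orders = pd.DataFrame({                         # illustrative rows, not real data
    'category': ['Books', 'Books', 'Books', 'Food'],
    'channel': ['Online', 'Store', 'Store', 'Store'],
    'purchase_amount': [120.0, 60.0, 80.0, 30.0],
})

layout = dict(values='purchase_amount', index='category', columns='channel')
means = orders.pivot_table(aggfunc='mean', **layout)
counts = orders.pivot_table(aggfunc='count', **layout)   # rows behind each mean
print(means)
print(counts)
```

One option is `sns.heatmap(means, annot=counts, fmt='.0f')`, which colors each cell by its mean but labels it with the number of rows behind that average.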

<a id='interactive'></a>
## 8. Interactive Visualizations with Plotly

Plotly figures support hovering, zooming, and panning, which makes it easier to inspect individual points and dense regions than with static images.

In [None]:
# Interactive scatter plot
if len(numeric_cols) >= 2:
    x_col, y_col = numeric_cols[0], numeric_cols[1]
    color_col = categorical_cols[0] if categorical_cols else None
    
    fig = px.scatter(df, x=x_col, y=y_col, color=color_col, 
                    hover_data=df.columns.tolist(),
                    title=f'Interactive Scatter: {y_col} vs {x_col}',
                    template='plotly_white')
    fig.update_traces(marker=dict(size=8, opacity=0.7, line=dict(width=1, color='DarkSlateGrey')))
    fig.update_layout(font=dict(size=12), title_font_size=16)
    fig.show()

In [None]:
# Interactive line chart with range slider
if 'date' in df.columns and numeric_cols:
    num_col = numeric_cols[1] if len(numeric_cols) > 1 else numeric_cols[0]
    
    fig = px.line(df.sort_values('date'), x='date', y=num_col,
                 title=f'Interactive Time Series: {num_col}',
                 template='plotly_white')
    fig.update_xaxes(rangeslider_visible=True)
    fig.update_layout(font=dict(size=12), title_font_size=16)
    fig.show()
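
The chart above plots a single column. To compare several series on one Plotly Express chart, a common approach is long-form data: one row per (date, series) pair, with the series name in its own column for `color=`. A sketch of the reshaping step with `DataFrame.melt`; the wide table of per-category daily totals is illustrative and would have to be built from the raw data first, for example with `pivot_table`.

```python
import pandas as pd

wide = pd.DataFrame({                       # illustrative per-category daily totals
    'date': pd.date_range('2024-01-01', periods=4, freq='D'),
    'Electronics': [250.0, 310.0, 180.0, 400.0],
    'Books': [40.0, 35.0, 60.0, 55.0],
})

# One row per (date, category) pair: the shape color= works with
long = wide.melt(id_vars='date', var_name='category', value_name='purchase_amount')
print(long.head())
```

`px.line(long, x='date', y='purchase_amount', color='category')` then draws one line per category. Plotly Express also accepts a list of column names for `y` on wide data, but long form lets the same column drive other arguments such as `facet_col` or `line_dash`.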

In [None]:
# Interactive bar chart
if categorical_cols and numeric_cols:
    cat_col = categorical_cols[0]
    num_col = numeric_cols[0]
    
    summary = df.groupby(cat_col)[num_col].mean().reset_index()
    summary.columns = [cat_col, f'Average {num_col}']
    
    fig = px.bar(summary, x=cat_col, y=f'Average {num_col}',
                title=f'Average {num_col} by {cat_col}',
                template='plotly_white', color=cat_col)
    fig.update_layout(font=dict(size=12), title_font_size=16, showlegend=False)
    fig.show()

In [None]:
# Interactive correlation heatmap
if len(numeric_cols) >= 2:
    corr_matrix = df[numeric_cols].corr()
    
    fig = px.imshow(corr_matrix, text_auto='.2f', aspect='auto',
                   color_continuous_scale='RdBu_r', zmin=-1, zmax=1,
                   title='Interactive Correlation Heatmap',
                   template='plotly_white')
    fig.update_layout(font=dict(size=12), title_font_size=16)
    fig.show()

<a id='bestpractices'></a>
## 9. Visualization Best Practices

### Design Principles

1. **Clarity:** Every chart should have a clear message
2. **Simplicity:** Remove unnecessary elements (chart junk)
3. **Accuracy:** Do not distort data (for example, start bar-chart axes at zero and keep scales consistent across compared panels)
4. **Accessibility:** Use colorblind-friendly palettes
5. **Context:** Always label axes, add titles, and include units
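
Tufte's lie factor is one way to quantify distortion: the size of the effect shown in the graphic divided by the size of the effect in the data, where the size of an effect is |second - first| / first. A faithful chart scores 1. A worked sketch for two bars, assuming bar heights are measured from the value where the y-axis starts; `lie_factor` is a hypothetical helper. Values of 100 and 110 differ by 10%, but on an axis starting at 95 the bars are 5 and 15 units tall, a 200% visual difference.

```python
def lie_factor(first, second, baseline=0.0):
    """Effect shown in the graphic divided by the effect in the data.

    Bar heights are measured from `baseline`, the value where the y-axis starts.
    """
    data_effect = abs(second - first) / first
    shown_first, shown_second = first - baseline, second - baseline
    graphic_effect = abs(shown_second - shown_first) / shown_first
    return graphic_effect / data_effect


print(lie_factor(100, 110))               # 1.0: axis starts at zero
print(lie_factor(100, 110, baseline=95))  # 20.0: axis truncated at 95
```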

### Chart Selection Guide

| Data Type | Recommended Chart | Purpose |
|-----------|------------------|----------|
| One continuous | Histogram, Box plot | Distribution |
| One categorical | Bar chart, Pie chart | Frequency |
| Two continuous | Scatter plot, Line chart | Relationship, trend |
| Two categorical | Grouped bar, Heatmap | Comparison |
| Continuous + categorical | Box plot, Violin plot | Distribution by group |
| Time series | Line chart, Area chart | Trends over time |
| Three+ variables | Scatter with color/size, Pair plot | Multivariate relationships |

### Color Guidelines

- Use sequential colors for ordered data (light to dark)
- Use diverging colors for data with a midpoint (blue-white-red)
- Use qualitative colors for categorical data (distinct hues)
- Test visualizations for colorblind accessibility
- Limit palette to 5-7 colors for clarity
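
For a diverging colormap, the neutral midpoint should sit at the meaningful center (0 for correlations), which means the color limits need to be symmetric around it; otherwise equal deviations above and below the center get different color intensities. A sketch with a hypothetical helper `diverging_limits`:

```python
import numpy as np


def diverging_limits(values, center=0.0):
    """Symmetric (vmin, vmax) around `center`, ignoring NaN values."""
    v = np.asarray(values, dtype=float)
    half_range = np.nanmax(np.abs(v - center))   # largest deviation from center
    return center - half_range, center + half_range


vmin, vmax = diverging_limits([[1.0, -0.2], [-0.2, 1.0]])
print(vmin, vmax)   # -1.0 1.0
```

The limits can be passed as `vmin` and `vmax` to Matplotlib or seaborn together with a diverging colormap such as `'RdBu_r'`; seaborn's `heatmap` also has a `center` argument for the same purpose. For qualitative data, seaborn ships a built-in `'colorblind'` palette.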

<a id='summary'></a>
## 10. Summary

### Visualizations Created

1. **Univariate:** Histograms with KDE, box plots, bar charts
2. **Bivariate:** Scatter plots, grouped bar charts, violin plots
3. **Multivariate:** Pair plots, correlation heatmaps, bubble charts
4. **Time Series:** Line charts, area charts
5. **Advanced:** Joint plots, pivot table heatmaps
6. **Interactive:** Plotly scatter, line, bar, and heatmap visualizations

### Key Takeaways

- Choose the right chart type for your data and message
- Always label axes, add titles, and provide context
- Use color strategically and ensure accessibility
- Interactive visualizations enhance exploration
- Less is more: remove unnecessary elements
- Test visualizations with your target audience

### Next Steps

Apply these visualization techniques in:
- Comprehensive Exploratory Data Analysis (Module 5)
- Presentation slides and reports
- Interactive dashboards