# Module 8: Data Visualization

## Topics Covered
1. Introduction to Matplotlib
2. Line Plots and Customization
3. Bar Charts and Histograms
4. Scatter Plots
5. Subplots and Figure Layouts
6. Introduction to Seaborn
7. Distribution Plots
8. Categorical Plots
9. Heatmaps and Correlation Matrices
10. Pair Plots and Joint Plots
11. Styling and Themes
12. Saving Figures

## Learning Objectives

By the end of this module, you will be able to:
- Create publication-quality visualizations with Matplotlib
- Build informative statistical graphics with Seaborn
- Customize plots with titles, labels, legends, and colors
- Choose appropriate chart types for different data
- Create multi-panel figures and dashboards
- Export visualizations for reports and presentations

---
# Section 1: Introduction to Matplotlib
---

## What is Matplotlib?

Matplotlib is Python's foundational plotting library. It provides:

- **Complete control** over every aspect of a figure
- **Multiple interfaces**: pyplot (simple) and object-oriented (flexible)
- **Wide variety** of plot types
- **Publication quality** output
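
The two interfaces draw into the same kind of objects. A minimal sketch (the plotted values are arbitrary): pyplot calls act on an implicit "current" Axes, which `plt.gca()` returns, while the object-oriented style keeps explicit references to the Figure and Axes.

```python
import numpy as np
import matplotlib.pyplot as plt

x = [1, 2, 3, 4]
y = [1, 4, 9, 16]

# pyplot (implicit) interface: functions act on the "current" figure and Axes
plt.figure()
plt.plot(x, y)
implicit_ax = plt.gca()  # the Axes that plt.plot() just drew into

# Object-oriented (explicit) interface: keep references to the Figure and Axes
fig, ax = plt.subplots()
ax.plot(x, y)

# Both routes end up with an Axes holding one line with the same data
print(len(implicit_ax.lines), len(ax.lines))
print(np.array_equal(implicit_ax.lines[0].get_ydata(), ax.lines[0].get_ydata()))

plt.close('all')  # this cell only inspects objects; nothing needs displaying
```

The pyplot style is convenient for quick plots; the explicit style scales better once a figure has several Axes.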

### Why This Matters in Data Science

Visualization is essential for:
- Exploring data to find patterns
- Communicating findings to stakeholders
- Validating model results
- Creating reports and dashboards

In [None]:
# Import libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Display plots inline (for Jupyter)
%matplotlib inline

print("Libraries imported successfully")

In [None]:
# Example: Your first plot

x = [1, 2, 3, 4, 5]
y = [2, 4, 6, 8, 10]

plt.plot(x, y)
plt.title('My First Plot')
plt.xlabel('X axis')
plt.ylabel('Y axis')
plt.show()

In [None]:
# Example: Figure and Axes (object-oriented approach)

# This approach gives more control
fig, ax = plt.subplots(figsize=(8, 5))

x = np.linspace(0, 10, 100)
y = np.sin(x)

ax.plot(x, y)
ax.set_title('Sine Wave')
ax.set_xlabel('X')
ax.set_ylabel('sin(x)')
ax.grid(True)

plt.show()

## Figure Anatomy

Understanding Matplotlib's structure:

- **Figure**: The entire window/page
- **Axes**: The actual plot area (a figure can have multiple axes)
- **Axis**: The x or y axis with ticks and labels
- **Artist**: Everything visible on the figure (lines, text, etc.)
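
Each level of this hierarchy is an object you can inspect. A minimal sketch (the plotted values are arbitrary) showing that a Figure holds a list of Axes, each Axes owns its `xaxis` and `yaxis`, and drawing calls return Artist objects:

```python
import matplotlib.pyplot as plt
from matplotlib.artist import Artist
from matplotlib.axis import XAxis, YAxis

# A Figure containing two Axes (plot areas)
fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(8, 3))

# Drawing methods return the Artists they create
line, = ax_left.plot([0, 1, 2], [0, 1, 4])
title = ax_right.set_title('Right panel')

print(len(fig.axes))                                          # Figure -> its Axes
print(isinstance(ax_left.xaxis, XAxis), isinstance(ax_left.yaxis, YAxis))  # Axes -> Axis
print(isinstance(line, Artist), isinstance(title, Artist))    # visible elements are Artists
print(line in ax_left.lines)                                  # the line belongs to the left Axes

plt.close(fig)
```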

In [None]:
# Example: Figure components

fig, ax = plt.subplots(figsize=(10, 6))

# Plot data
x = np.arange(1, 11)
y = x ** 2

ax.plot(x, y, color='blue', linewidth=2, marker='o', label='y = x^2')

# Customize
ax.set_title('Understanding Figure Components', fontsize=14, fontweight='bold')
ax.set_xlabel('X Axis Label', fontsize=12)
ax.set_ylabel('Y Axis Label', fontsize=12)
ax.legend(loc='upper left')
ax.grid(True, linestyle='--', alpha=0.7)

# Set axis limits
ax.set_xlim(0, 12)
ax.set_ylim(0, 120)

plt.show()

---
# Section 2: Line Plots and Customization
---

Line plots are ideal for showing trends over time or continuous data.

In [None]:
# Example: Multiple lines on one plot

x = np.linspace(0, 2 * np.pi, 100)

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(x, np.sin(x), label='sin(x)')
ax.plot(x, np.cos(x), label='cos(x)')
ax.plot(x, np.sin(x) + np.cos(x), label='sin(x) + cos(x)')

ax.set_title('Trigonometric Functions')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.legend()
ax.grid(True, alpha=0.3)

plt.show()

In [None]:
# Example: Line styles and colors

x = np.arange(1, 11)

fig, ax = plt.subplots(figsize=(10, 6))

# Different line styles
ax.plot(x, x, 'r-', label='solid red', linewidth=2)
ax.plot(x, x + 2, 'g--', label='dashed green', linewidth=2)
ax.plot(x, x + 4, 'b:', label='dotted blue', linewidth=2)
ax.plot(x, x + 6, 'm-.', label='dash-dot magenta', linewidth=2)

ax.set_title('Line Styles')
ax.legend()
ax.grid(True, alpha=0.3)

plt.show()

In [None]:
# Example: Markers

x = np.arange(1, 8)

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(x, x, 'o-', label='circle', markersize=10)
ax.plot(x, x + 1, 's-', label='square', markersize=10)
ax.plot(x, x + 2, '^-', label='triangle', markersize=10)
ax.plot(x, x + 3, 'D-', label='diamond', markersize=10)
ax.plot(x, x + 4, '*-', label='star', markersize=12)

ax.set_title('Marker Styles')
ax.legend()
ax.grid(True, alpha=0.3)

plt.show()

In [None]:
# Example: Real data - Sales trend

# Load and prepare data
sales = pd.read_csv('assets/datasets/sales_data.csv', parse_dates=['date'])

# Monthly sales
monthly = sales.groupby(sales['date'].dt.to_period('M'))['total_amount'].sum()
monthly.index = monthly.index.to_timestamp()

fig, ax = plt.subplots(figsize=(12, 6))

ax.plot(monthly.index, monthly.values, 'b-o', linewidth=2, markersize=6)

ax.set_title('Monthly Sales Trend', fontsize=14, fontweight='bold')
ax.set_xlabel('Month')
ax.set_ylabel('Total Sales ($)')
ax.grid(True, alpha=0.3)

# Rotate x-axis labels
plt.xticks(rotation=45)
plt.tight_layout()

plt.show()
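
The cell above builds monthly totals with `dt.to_period('M')` and `to_timestamp()`. pandas' `resample` is an equivalent route once the dates form a `DatetimeIndex`. A minimal sketch that checks the two agree, using a synthetic daily series (`daily_sales` is made up) as a stand-in for `sales_data.csv`:

```python
import numpy as np
import pandas as pd

# Synthetic stand-in for the sales data: one row per day for one quarter
dates = pd.date_range('2024-01-01', '2024-03-31', freq='D')
rng = np.random.default_rng(0)
daily_sales = pd.DataFrame({'date': dates,
                            'total_amount': rng.uniform(100, 500, len(dates))})

# resample needs a DatetimeIndex; 'MS' labels each bin by its month start,
# which matches what to_period('M') followed by to_timestamp() produces
monthly_resampled = daily_sales.set_index('date')['total_amount'].resample('MS').sum()

# The groupby approach used in the cell above
monthly_grouped = daily_sales.groupby(daily_sales['date'].dt.to_period('M'))['total_amount'].sum()
monthly_grouped.index = monthly_grouped.index.to_timestamp()

print((monthly_resampled.index == monthly_grouped.index).all())
print(np.allclose(monthly_resampled.values, monthly_grouped.values))
```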

## Practice Exercise 2.1

**Task:** Create a line plot showing the cumulative sales over time. Add a horizontal line showing the average daily sales.

Hints:
- Use `cumsum()` for cumulative sum
- Use `ax.axhline()` for horizontal line

In [None]:
# Your code here


In [None]:
# Solution 2.1

sales = pd.read_csv('assets/datasets/sales_data.csv', parse_dates=['date'])
daily = sales.groupby('date')['total_amount'].sum().sort_index()

fig, ax = plt.subplots(figsize=(12, 6))

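# Cumulative sales on the primary y-axis
cumulative = daily.cumsum()
ax.plot(cumulative.index, cumulative.values, 'b-', linewidth=2, label='Cumulative Sales')

# Average daily sales is far smaller than the cumulative total, so draw the
# daily values and their average on a secondary y-axis that shares the dates
avg_daily = daily.mean()
ax2 = ax.twinx()
ax2.plot(daily.index, daily.values, color='gray', alpha=0.4, label='Daily Sales')
ax2.axhline(avg_daily, color='red', linestyle='--', linewidth=2,
            label=f'Average Daily Sales: ${avg_daily:,.2f}')
ax2.set_ylabel('Daily Sales ($)')

ax.set_title('Cumulative Sales Over Time', fontsize=14, fontweight='bold')
ax.set_xlabel('Date')
ax.set_ylabel('Cumulative Sales ($)')
ax.grid(True, alpha=0.3)

# Combine the legend entries from both y-axes into a single legend
handles1, labels1 = ax.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
ax.legend(handles1 + handles2, labels1 + labels2, loc='upper left')

# twinx() hides ax2's x-axis, so rotate the date labels on ax itself
ax.tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.show()

print(f"Average daily sales: ${avg_daily:,.2f}")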

---
# Section 3: Bar Charts and Histograms
---

Bar charts compare categorical data. Histograms show the distribution of continuous data.
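
The key difference is what the bars represent: a histogram's bars are counts of values falling into consecutive numeric bins. A minimal sketch with synthetic data (the `amounts` array is made up) showing that `np.histogram` computes the same counts and bin edges that `ax.hist` draws and returns:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(1)
amounts = rng.normal(loc=200, scale=50, size=1_000)  # synthetic transaction amounts

# Compute the histogram without plotting: counts per bin and the bin edges
counts, edges = np.histogram(amounts, bins=20)

# ax.hist performs the same binning and also returns the bar patches it drew
fig, ax = plt.subplots(figsize=(8, 4))
hist_counts, hist_edges, patches = ax.hist(amounts, bins=20,
                                           color='steelblue', edgecolor='white')

print(np.array_equal(counts, hist_counts), np.allclose(edges, hist_edges))
print(counts.sum(), len(edges))  # every value lands in a bin; 20 bins need 21 edges
plt.show()
```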

In [None]:
# Example: Basic bar chart

categories = ['Electronics', 'Furniture', 'Office Supplies']
values = [45000, 32000, 28000]

fig, ax = plt.subplots(figsize=(8, 6))

bars = ax.bar(categories, values, color=['steelblue', 'coral', 'seagreen'])

ax.set_title('Sales by Category')
ax.set_xlabel('Category')
ax.set_ylabel('Sales ($)')

# Add value labels on bars
for bar, val in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 500, 
            f'${val:,}', ha='center', va='bottom')

plt.show()

In [None]:
# Example: Horizontal bar chart

sales = pd.read_csv('assets/datasets/sales_data.csv')
product_sales = sales.groupby('product')['total_amount'].sum().nlargest(10)

fig, ax = plt.subplots(figsize=(10, 6))

ax.barh(product_sales.index, product_sales.values, color='steelblue')

ax.set_title('Top 10 Products by Sales', fontsize=14, fontweight='bold')
ax.set_xlabel('Total Sales ($)')

plt.tight_layout()
plt.show()

In [None]:
# Example: Grouped bar chart

sales = pd.read_csv('assets/datasets/sales_data.csv')

# Total sales by category (rows) and region (columns)
pivot = sales.pivot_table(values='total_amount', index='category', 
                          columns='region', aggfunc='sum')

fig, ax = plt.subplots(figsize=(12, 6))

x = np.arange(len(pivot.index))
n_regions = len(pivot.columns)
width = 0.8 / n_regions  # fit all regions' bars into 80% of each category slot

# One set of bars per region, each shifted right by one bar width
for i, region in enumerate(pivot.columns):
    ax.bar(x + i * width, pivot[region], width, label=region)

ax.set_title('Sales by Category and Region', fontsize=14, fontweight='bold')
ax.set_xlabel('Category')
ax.set_ylabel('Sales ($)')
# Center each tick label under its group of bars
ax.set_xticks(x + width * (n_regions - 1) / 2)
ax.set_xticklabels(pivot.index)
ax.legend(title='Region')

plt.tight_layout()
plt.show()

In [None]:
# Example: Stacked bar chart

pivot = sales.pivot_table(values='total_amount', index='region', 
                          columns='category', aggfunc='sum')

fig, ax = plt.subplots(figsize=(10, 6))

pivot.plot(kind='bar', stacked=True, ax=ax, 
           color=['steelblue', 'coral', 'seagreen'])

ax.set_title('Sales by Region (Stacked by Category)', fontsize=14, fontweight='bold')
ax.set_xlabel('Region')
ax.set_ylabel('Sales ($)')
ax.legend(title='Category')
plt.xticks(rotation=0)

plt.tight_layout()
plt.show()

In [None]:
# Example: Histogram

fig, ax = plt.subplots(figsize=(10, 6))

ax.hist(sales['total_amount'], bins=30, color='steelblue', edgecolor='white', alpha=0.7)

ax.set_title('Distribution of Transaction Amounts', fontsize=14, fontweight='bold')
ax.set_xlabel('Transaction Amount ($)')
ax.set_ylabel('Frequency')

# Add mean line
mean_val = sales['total_amount'].mean()
ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: ${mean_val:.2f}')
ax.legend()

plt.show()

In [None]:
# Example: Multiple histograms (overlapping)

fig, ax = plt.subplots(figsize=(10, 6))

for category in sales['category'].unique():
    data = sales[sales['category'] == category]['total_amount']
    ax.hist(data, bins=20, alpha=0.5, label=category)

ax.set_title('Transaction Amount Distribution by Category')
ax.set_xlabel('Transaction Amount ($)')
ax.set_ylabel('Frequency')
ax.legend()

plt.show()

---
# Section 4: Scatter Plots
---

Scatter plots show relationships between two continuous variables.
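
A scatter plot shows a relationship; a correlation coefficient and a fitted line summarize it. A minimal sketch on synthetic data (the `years` and `salary` arrays and their linear relationship are invented for illustration), using `np.corrcoef` for Pearson's r and `np.polyfit` for a least-squares line:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(2)
years = rng.uniform(0, 20, 200)                               # synthetic tenure
salary = 50_000 + 2_500 * years + rng.normal(0, 8_000, 200)   # synthetic salaries

# Pearson correlation: the off-diagonal entry of the 2x2 correlation matrix
r = np.corrcoef(years, salary)[0, 1]

# Least-squares line (degree-1 polynomial): returns slope, then intercept
slope, intercept = np.polyfit(years, salary, deg=1)

fig, ax = plt.subplots(figsize=(8, 5))
ax.scatter(years, salary, alpha=0.5, color='steelblue')
xs = np.linspace(years.min(), years.max(), 100)
ax.plot(xs, slope * xs + intercept, color='red', label=f'fit (r = {r:.2f})')
ax.set_xlabel('Years')
ax.set_ylabel('Salary ($)')
ax.legend()
plt.show()
```

Pearson's r only measures linear association, so the scatter plot itself is still the way to spot curved or clustered patterns.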

In [None]:
# Example: Basic scatter plot

employees = pd.read_csv('assets/datasets/employees.csv')

# Approximate years of experience as years since hire date (tenure)
employees['hire_date'] = pd.to_datetime(employees['hire_date'])
employees['years_exp'] = (pd.Timestamp.now() - employees['hire_date']).dt.days / 365

fig, ax = plt.subplots(figsize=(10, 6))

ax.scatter(employees['years_exp'], employees['salary'], alpha=0.6, color='steelblue')

ax.set_title('Salary vs Years of Experience', fontsize=14, fontweight='bold')
ax.set_xlabel('Years of Experience')
ax.set_ylabel('Salary ($)')
ax.grid(True, alpha=0.3)

plt.show()

In [None]:
# Example: Scatter with color coding by category

fig, ax = plt.subplots(figsize=(10, 6))

departments = employees['department'].unique()
colors = plt.cm.Set2(np.linspace(0, 1, len(departments)))

for dept, color in zip(departments, colors):
    subset = employees[employees['department'] == dept]
    ax.scatter(subset['years_exp'], subset['salary'], 
               alpha=0.6, label=dept, color=color, s=60)

ax.set_title('Salary vs Experience by Department', fontsize=14, fontweight='bold')
ax.set_xlabel('Years of Experience')
ax.set_ylabel('Salary ($)')
ax.legend(title='Department', bbox_to_anchor=(1.02, 1), loc='upper left')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Example: Scatter with size encoding

sales = pd.read_csv('assets/datasets/sales_data.csv')

# Aggregate by product
product_stats = sales.groupby('product').agg(
    total_sales=('total_amount', 'sum'),
    avg_transaction=('total_amount', 'mean'),
    num_transactions=('transaction_id', 'count')
).reset_index()

fig, ax = plt.subplots(figsize=(12, 8))

scatter = ax.scatter(product_stats['num_transactions'], 
                     product_stats['avg_transaction'],
                     s=product_stats['total_sales'] / 100,  # Size by total sales
                     alpha=0.6, c='steelblue')

ax.set_title('Products: Transactions vs Avg Amount (size = total sales)', fontsize=14)
ax.set_xlabel('Number of Transactions')
ax.set_ylabel('Average Transaction ($)')
ax.grid(True, alpha=0.3)

plt.show()

---
# Section 5: Subplots and Figure Layouts
---

Multi-panel figures let you compare several views of the same data side by side.
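
Because `plt.subplots` returns the Axes as a NumPy array, panels can be filled in a loop rather than indexed one at a time. A minimal sketch with synthetic groups; `sharex=True, sharey=True` give every panel identical axis ranges so the histograms are directly comparable:

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(3)
groups = {name: rng.normal(loc, 1.0, 300)             # synthetic samples per group
          for name, loc in [('A', 0), ('B', 1), ('C', 2), ('D', 3)]}

# sharex/sharey give every panel the same axis limits
fig, axes = plt.subplots(2, 2, figsize=(10, 8), sharex=True, sharey=True)

# axes is a 2x2 ndarray; .flat iterates over it in row-major order
for ax, (name, values) in zip(axes.flat, groups.items()):
    ax.hist(values, bins=20, color='steelblue', edgecolor='white')
    ax.set_title(f'Group {name}')

fig.suptitle('One histogram per group (shared axes)')
fig.tight_layout()
plt.show()
```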

In [None]:
# Example: Simple 2x2 grid

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Plot 1: Line plot
x = np.linspace(0, 10, 100)
axes[0, 0].plot(x, np.sin(x))
axes[0, 0].set_title('Sine Wave')

# Plot 2: Bar chart
categories = ['A', 'B', 'C', 'D']
values = [25, 40, 30, 55]
axes[0, 1].bar(categories, values, color='coral')
axes[0, 1].set_title('Bar Chart')

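# Plot 3: Scatter (a seeded generator makes the random data reproducible)
rng = np.random.default_rng(42)
x = rng.standard_normal(100)
y = x + rng.standard_normal(100) * 0.5
axes[1, 0].scatter(x, y, alpha=0.5)
axes[1, 0].set_title('Scatter Plot')

# Plot 4: Histogram
data = rng.standard_normal(1000)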
axes[1, 1].hist(data, bins=30, color='seagreen', edgecolor='white')
axes[1, 1].set_title('Histogram')

plt.tight_layout()
plt.show()

In [None]:
# Example: Sales dashboard

sales = pd.read_csv('assets/datasets/sales_data.csv', parse_dates=['date'])

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Sales Dashboard', fontsize=16, fontweight='bold')

# 1. Monthly trend
monthly = sales.groupby(sales['date'].dt.to_period('M'))['total_amount'].sum()
monthly.index = monthly.index.to_timestamp()
axes[0, 0].plot(monthly.index, monthly.values, 'b-o')
axes[0, 0].set_title('Monthly Sales Trend')
axes[0, 0].tick_params(axis='x', rotation=45)
axes[0, 0].grid(True, alpha=0.3)

# 2. Sales by category
category_sales = sales.groupby('category')['total_amount'].sum()
axes[0, 1].pie(category_sales, labels=category_sales.index, autopct='%1.1f%%',
               colors=['steelblue', 'coral', 'seagreen'])
axes[0, 1].set_title('Sales by Category')

# 3. Sales by region
region_sales = sales.groupby('region')['total_amount'].sum().sort_values()
axes[1, 0].barh(region_sales.index, region_sales.values, color='steelblue')
axes[1, 0].set_title('Sales by Region')
axes[1, 0].set_xlabel('Sales ($)')

# 4. Transaction amount distribution
axes[1, 1].hist(sales['total_amount'], bins=30, color='coral', edgecolor='white')
axes[1, 1].set_title('Transaction Amount Distribution')
axes[1, 1].set_xlabel('Amount ($)')
axes[1, 1].set_ylabel('Frequency')

plt.tight_layout()
plt.show()

In [None]:
# Example: Unequal subplot sizes with GridSpec

from matplotlib.gridspec import GridSpec

fig = plt.figure(figsize=(12, 8))
gs = GridSpec(2, 3, figure=fig)

# Large plot on left
ax1 = fig.add_subplot(gs[:, 0:2])  # All rows, first 2 columns
ax1.plot(np.random.randn(100).cumsum())
ax1.set_title('Main Plot')

# Two smaller plots on right
ax2 = fig.add_subplot(gs[0, 2])
ax2.bar(['A', 'B', 'C'], [3, 7, 5])
ax2.set_title('Top Right')

ax3 = fig.add_subplot(gs[1, 2])
ax3.scatter(np.random.randn(50), np.random.randn(50))
ax3.set_title('Bottom Right')

plt.tight_layout()
plt.show()

---
# Section 6: Introduction to Seaborn
---

Seaborn is a statistical visualization library built on Matplotlib. It provides:

- **Beautiful default styles**
- **Statistical plotting functions**
- **Automatic handling of DataFrames**
- **Built-in themes and color palettes**

In [None]:
import seaborn as sns

# Set default style
sns.set_style('whitegrid')
sns.set_palette('Set2')

print(f"Seaborn version: {sns.__version__}")

In [None]:
# Example: Seaborn vs Matplotlib comparison

sales = pd.read_csv('assets/datasets/sales_data.csv')

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Matplotlib
category_sales = sales.groupby('category')['total_amount'].sum()
axes[0].bar(category_sales.index, category_sales.values)
axes[0].set_title('Matplotlib')

# Seaborn (estimator=sum matches the totals; errorbar=None drops the default error bars)
sns.barplot(data=sales, x='category', y='total_amount', estimator=sum, errorbar=None, ax=axes[1])
axes[1].set_title('Seaborn')

plt.tight_layout()
plt.show()

In [None]:
# Example: Seaborn built-in datasets

# Seaborn comes with example datasets
tips = sns.load_dataset('tips')
print("Tips dataset:")
print(tips.head())
print(f"\nShape: {tips.shape}")

---
# Section 7: Distribution Plots
---

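Seaborn has dedicated functions for distributions: `histplot` draws histograms (optionally with a KDE curve), while `boxplot` and `violinplot` summarize a numeric variable per category.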

In [None]:
# Example: Histogram with KDE (histplot)

employees = pd.read_csv('assets/datasets/employees.csv')

fig, ax = plt.subplots(figsize=(10, 6))

sns.histplot(data=employees, x='salary', kde=True, ax=ax)

ax.set_title('Salary Distribution', fontsize=14, fontweight='bold')
ax.set_xlabel('Salary ($)')

plt.show()

In [None]:
# Example: Distribution by category

fig, ax = plt.subplots(figsize=(12, 6))

sns.histplot(data=employees, x='salary', hue='department', 
             element='step', kde=True, ax=ax)

ax.set_title('Salary Distribution by Department')
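# histplot builds its own legend, so move it instead of calling ax.legend()
sns.move_legend(ax, 'upper left', bbox_to_anchor=(1.02, 1), title='Department')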

plt.tight_layout()
plt.show()

In [None]:
# Example: Box plot

fig, ax = plt.subplots(figsize=(12, 6))

sns.boxplot(data=employees, x='department', y='salary', ax=ax)

ax.set_title('Salary Distribution by Department', fontsize=14, fontweight='bold')
ax.set_xlabel('Department')
ax.set_ylabel('Salary ($)')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()

In [None]:
# Example: Violin plot

fig, ax = plt.subplots(figsize=(12, 6))

sns.violinplot(data=employees, x='department', y='salary', ax=ax)

ax.set_title('Salary Distribution by Department (Violin Plot)', fontsize=14, fontweight='bold')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()

In [None]:
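# Example: Box plot with strip plot overlay (individual, jittered points)

fig, ax = plt.subplots(figsize=(12, 6))

# showfliers=False avoids drawing outliers twice, since every point is overlaid
sns.boxplot(data=employees, x='department', y='salary', ax=ax, color='lightblue',
            showfliers=False)
sns.stripplot(data=employees, x='department', y='salary', ax=ax, 
              color='darkblue', alpha=0.5, size=4)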

ax.set_title('Salary by Department (Box + Individual Points)')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()

## Practice Exercise 7.1

**Task:** Create a figure with 2 subplots comparing salary distributions:
1. Left: Box plot by department
2. Right: Violin plot by status (Active, On Leave, Terminated)

In [None]:
# Your code here


In [None]:
# Solution 7.1

employees = pd.read_csv('assets/datasets/employees.csv')

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Left: Box by department
sns.boxplot(data=employees, x='department', y='salary', ax=axes[0])
axes[0].set_title('Salary by Department')
axes[0].tick_params(axis='x', rotation=45)

# Right: Violin by status
sns.violinplot(data=employees, x='status', y='salary', ax=axes[1])
axes[1].set_title('Salary by Status')

plt.tight_layout()
plt.show()

---
# Section 8: Categorical Plots
---

In [None]:
# Example: Count plot (like value_counts visualized)

fig, ax = plt.subplots(figsize=(10, 6))

sns.countplot(data=employees, x='department', ax=ax, order=employees['department'].value_counts().index)

ax.set_title('Employee Count by Department', fontsize=14, fontweight='bold')
plt.xticks(rotation=45)

plt.tight_layout()
plt.show()
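
`countplot` is effectively a plotted `value_counts()`. A minimal sketch with a hypothetical six-row table that checks each bar height against the matching count; passing `order=counts.index` sorts the bars the same way `value_counts()` does:

```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Hypothetical employee table
staff = pd.DataFrame({'department': ['Sales', 'IT', 'Sales', 'HR', 'IT', 'Sales']})

counts = staff['department'].value_counts()  # sorted from most to least frequent

fig, ax = plt.subplots(figsize=(6, 4))
sns.countplot(data=staff, x='department', order=counts.index, ax=ax)

# Read the bars left to right and compare their heights with value_counts()
bars = sorted(ax.patches, key=lambda bar: bar.get_x())
bar_heights = [int(bar.get_height()) for bar in bars]
print(bar_heights == counts.tolist())
plt.show()
```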

In [None]:
# Example: Count plot with hue

fig, ax = plt.subplots(figsize=(12, 6))

sns.countplot(data=employees, x='department', hue='status', ax=ax)

ax.set_title('Employee Status by Department')
plt.xticks(rotation=45)
ax.legend(title='Status')

plt.tight_layout()
plt.show()

In [None]:
# Example: Bar plot with confidence intervals

sales = pd.read_csv('assets/datasets/sales_data.csv')

fig, ax = plt.subplots(figsize=(10, 6))

sns.barplot(data=sales, x='category', y='total_amount', ax=ax, errorbar='ci')

ax.set_title('Average Transaction Amount by Category (with 95% CI)')
ax.set_ylabel('Average Amount ($)')

plt.show()

In [None]:
# Example: Point plot (shows mean with confidence interval)

fig, ax = plt.subplots(figsize=(10, 6))

sns.pointplot(data=sales, x='region', y='total_amount', hue='category', ax=ax)

ax.set_title('Average Transaction by Region and Category')
sns.move_legend(ax, 'upper left', bbox_to_anchor=(1.02, 1), title='Category')

plt.tight_layout()
plt.show()

---
# Section 9: Heatmaps and Correlation Matrices
---

In [None]:
# Example: Basic heatmap with pivot table

sales = pd.read_csv('assets/datasets/sales_data.csv')

pivot = sales.pivot_table(values='total_amount', index='category', 
                          columns='region', aggfunc='sum')

fig, ax = plt.subplots(figsize=(10, 6))

sns.heatmap(pivot, annot=True, fmt=',.0f', cmap='Blues', ax=ax)

ax.set_title('Sales Heatmap: Category vs Region', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Example: Correlation matrix

employees = pd.read_csv('assets/datasets/employees.csv')

# Select numeric columns
numeric_cols = employees.select_dtypes(include=[np.number])

# Calculate correlation
corr = numeric_cols.corr()

fig, ax = plt.subplots(figsize=(8, 6))

sns.heatmap(corr, annot=True, cmap='coolwarm', center=0, 
            fmt='.2f', square=True, ax=ax)

ax.set_title('Correlation Matrix', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# Example: Heatmap with mask (show only lower triangle)

# Create mask for upper triangle
mask = np.triu(np.ones_like(corr, dtype=bool))

fig, ax = plt.subplots(figsize=(8, 6))

sns.heatmap(corr, mask=mask, annot=True, cmap='coolwarm', center=0,
            fmt='.2f', square=True, ax=ax, linewidths=0.5)

ax.set_title('Correlation Matrix (Lower Triangle)', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()
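
`sns.heatmap` hides every cell where `mask` is `True`. `np.triu` keeps the upper triangle of an all-`True` matrix, diagonal included, so the heatmap above also hides the diagonal of 1.0 self-correlations. A small sketch on a 3x3 case; `k=1` shifts the triangle up by one so the diagonal stays visible:

```python
import numpy as np

ones = np.ones((3, 3), dtype=bool)

mask_with_diag = np.triu(ones)       # True on and above the diagonal -> hidden
mask_keep_diag = np.triu(ones, k=1)  # True strictly above the diagonal -> diagonal shown

print(mask_with_diag.astype(int))
print(mask_keep_diag.astype(int))
print(mask_with_diag.sum(), mask_keep_diag.sum())
```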

---
# Section 10: Pair Plots and Joint Plots
---

In [None]:
# Example: Pair plot

# Use seaborn's tips dataset for a good example
tips = sns.load_dataset('tips')

sns.pairplot(tips, hue='time', height=2.5)
plt.suptitle('Pair Plot: Tips Dataset', y=1.02, fontsize=14)

plt.show()

In [None]:
# Example: Pair plot with specific columns

employees = pd.read_csv('assets/datasets/employees.csv')
employees['hire_date'] = pd.to_datetime(employees['hire_date'])
employees['years_exp'] = (pd.Timestamp.now() - employees['hire_date']).dt.days / 365

# Fill missing numeric values for plotting
employees['performance_rating'] = employees['performance_rating'].fillna(employees['performance_rating'].median())
employees['bonus'] = employees['bonus'].fillna(0)

# Select subset of numeric columns
cols = ['salary', 'bonus', 'years_exp', 'performance_rating']

g = sns.pairplot(employees[cols + ['department']], hue='department', 
                 height=2, corner=True)
g.fig.suptitle('Employee Metrics by Department', y=1.02)

plt.show()

In [None]:
# Example: Joint plot

g = sns.jointplot(data=employees, x='years_exp', y='salary', kind='scatter', height=8)
g.fig.suptitle('Years of Experience vs Salary', y=1.02)

plt.show()

In [None]:
# Example: Joint plot with regression line

g = sns.jointplot(data=employees, x='years_exp', y='salary', kind='reg', height=8)
g.fig.suptitle('Years of Experience vs Salary (with regression)', y=1.02)

plt.show()

In [None]:
# Example: Joint plot with hex bins (good for large datasets)

g = sns.jointplot(data=employees, x='years_exp', y='salary', kind='hex', height=8)
g.fig.suptitle('Years of Experience vs Salary (Hexbin)', y=1.02)

plt.show()

---
# Section 11: Styling and Themes
---

In [None]:
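# Example: Seaborn styles

styles = ['darkgrid', 'whitegrid', 'dark', 'white', 'ticks']

fig = plt.figure(figsize=(20, 4))

x = np.linspace(0, 10, 100)

# axes_style() only affects Axes created inside the with-block,
# so each subplot is added within its own style context
for i, style in enumerate(styles, start=1):
    with sns.axes_style(style):
        ax = fig.add_subplot(1, 5, i)
        ax.plot(x, np.sin(x))
        ax.set_title(f"Style: {style}")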
plt.tight_layout()
plt.show()

In [None]:
# Example: Color palettes

palettes = ['deep', 'muted', 'bright', 'pastel', 'dark', 'colorblind']

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for ax, palette in zip(axes.flat, palettes):
    colors = sns.color_palette(palette)
    ax.bar(range(len(colors)), [1]*len(colors), color=colors)
    ax.set_title(f"Palette: {palette}")
    ax.set_ylim(0, 1.5)

plt.tight_layout()
plt.show()

In [None]:
# Example: Custom styling

# Set the global style and palette (they stay in effect for all later plots)
sns.set_style('whitegrid')
sns.set_palette('husl')

fig, ax = plt.subplots(figsize=(10, 6))

sales = pd.read_csv('assets/datasets/sales_data.csv')
sns.barplot(data=sales, x='category', y='total_amount', hue='region', ax=ax)

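# barplot shows the mean transaction amount by default, not the total
ax.set_title('Average Transaction Amount by Category and Region', fontsize=14, fontweight='bold')
sns.move_legend(ax, 'upper left', bbox_to_anchor=(1.02, 1), title='Region')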

plt.tight_layout()
plt.show()

In [None]:
# Example: Matplotlib style sheets

# Available styles
print("Available matplotlib styles:")
print(plt.style.available)

In [None]:
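# Example: Using a style context

x = np.linspace(0, 10, 100)

# 'seaborn-v0_8-*' style names require Matplotlib 3.6 or newer
styles_to_show = ['seaborn-v0_8-darkgrid', 'ggplot', 'bmh']

fig = plt.figure(figsize=(15, 4))

# Most style settings (background, grid, color cycle) are read when an Axes
# is created, so add each subplot inside its own plt.style.context block
for i, style in enumerate(styles_to_show, start=1):
    with plt.style.context(style):
        ax = fig.add_subplot(1, 3, i)
        ax.plot(x, np.sin(x), label='sin(x)')
        ax.plot(x, np.cos(x), label='cos(x)')
        ax.set_title(f"Style: {style}")
        ax.legend()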

plt.tight_layout()
plt.show()

---
# Section 12: Saving Figures
---

In [None]:
# Example: Saving figures in different formats

sales = pd.read_csv('assets/datasets/sales_data.csv')

fig, ax = plt.subplots(figsize=(10, 6))

category_sales = sales.groupby('category')['total_amount'].sum()
ax.bar(category_sales.index, category_sales.values, color=['steelblue', 'coral', 'seagreen'])
ax.set_title('Sales by Category', fontsize=14, fontweight='bold')
ax.set_ylabel('Total Sales ($)')

# Save as PNG
fig.savefig('assets/datasets/sales_chart.png', dpi=150, bbox_inches='tight')
print("Saved as PNG")

# Save as PDF (vector format - good for publications)
fig.savefig('assets/datasets/sales_chart.pdf', bbox_inches='tight')
print("Saved as PDF")

# Save as SVG (vector format - good for web)
fig.savefig('assets/datasets/sales_chart.svg', bbox_inches='tight')
print("Saved as SVG")

plt.show()

In [None]:
# Example: Saving with transparent background

fig, ax = plt.subplots(figsize=(8, 6))

ax.plot([1, 2, 3, 4, 5], [1, 4, 9, 16, 25], 'b-o', linewidth=2)
ax.set_title('Sample Plot')

fig.savefig('assets/datasets/transparent_plot.png', 
            dpi=150, 
            bbox_inches='tight',
            transparent=True)
print("Saved with transparent background")

plt.show()

In [None]:
# Example: High-resolution figure for publication

fig, ax = plt.subplots(figsize=(8, 6))

employees = pd.read_csv('assets/datasets/employees.csv')
employees['hire_date'] = pd.to_datetime(employees['hire_date'])
employees['years_exp'] = (pd.Timestamp.now() - employees['hire_date']).dt.days / 365

ax.scatter(employees['years_exp'], employees['salary'], alpha=0.5)
ax.set_title('Salary vs Experience')
ax.set_xlabel('Years of Experience')
ax.set_ylabel('Salary ($)')

# Save at 300 DPI (publication quality)
fig.savefig('assets/datasets/publication_quality.png', dpi=300, bbox_inches='tight')
print("Saved at 300 DPI")

plt.show()
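
Output size in pixels is figure size in inches multiplied by DPI, so an 8x6-inch figure at 300 DPI is 2400x1800 pixels. A minimal sketch that confirms this by saving to an in-memory buffer and reading the width and height from the PNG header. It leaves out `bbox_inches='tight'`, which crops the image and therefore changes the final size:

```python
import io
import struct
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 6))   # 8 x 6 inches
ax.plot([1, 2, 3], [1, 4, 9])

buffer = io.BytesIO()
fig.savefig(buffer, format='png', dpi=300)   # no bbox_inches='tight', so no cropping
plt.close(fig)

# A PNG starts with an 8-byte signature followed by the IHDR chunk;
# width and height are big-endian 32-bit integers at byte offsets 16 and 20
png_bytes = buffer.getvalue()
width, height = struct.unpack('>II', png_bytes[16:24])
print(width, height)
```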

## Practice Exercise 12.1

**Task:** Create a comprehensive dashboard with 4 subplots analyzing the sales data:
1. Line plot: Monthly sales trend
2. Bar chart: Top 5 products by total sales
3. Pie chart: Sales distribution by category
4. Heatmap: Sales by region and category

Save the dashboard as a PNG file.

In [None]:
# Your code here


In [None]:
# Solution 12.1

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

sales = pd.read_csv('assets/datasets/sales_data.csv', parse_dates=['date'])

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Sales Analysis Dashboard', fontsize=16, fontweight='bold')

# 1. Monthly trend
monthly = sales.groupby(sales['date'].dt.to_period('M'))['total_amount'].sum()
monthly.index = monthly.index.to_timestamp()
axes[0, 0].plot(monthly.index, monthly.values, 'b-o', linewidth=2)
axes[0, 0].set_title('Monthly Sales Trend')
axes[0, 0].tick_params(axis='x', rotation=45)
axes[0, 0].grid(True, alpha=0.3)
axes[0, 0].set_ylabel('Sales ($)')

# 2. Top 5 products
top_products = sales.groupby('product')['total_amount'].sum().nlargest(5)
axes[0, 1].barh(top_products.index, top_products.values, color='steelblue')
axes[0, 1].set_title('Top 5 Products')
axes[0, 1].set_xlabel('Total Sales ($)')

# 3. Category pie chart
category_sales = sales.groupby('category')['total_amount'].sum()
axes[1, 0].pie(category_sales, labels=category_sales.index, autopct='%1.1f%%',
               colors=['steelblue', 'coral', 'seagreen'])
axes[1, 0].set_title('Sales by Category')

# 4. Heatmap
pivot = sales.pivot_table(values='total_amount', index='category', 
                          columns='region', aggfunc='sum')
sns.heatmap(pivot, annot=True, fmt=',.0f', cmap='Blues', ax=axes[1, 1])
axes[1, 1].set_title('Sales: Category vs Region')

plt.tight_layout()

# Save
fig.savefig('assets/datasets/sales_dashboard.png', dpi=150, bbox_inches='tight')
print("Dashboard saved to assets/datasets/sales_dashboard.png")

plt.show()

---
# Module Summary

## Key Takeaways

1. **Matplotlib** is the foundation - provides complete control over visualizations
2. **Seaborn** simplifies statistical visualizations with better defaults
3. **Choose the right chart type**:
   - Line plots for trends over time
   - Bar charts for comparing categories
   - Scatter plots for relationships between variables
   - Histograms/box plots for distributions
   - Heatmaps for matrix data
4. **Always label** your axes, add titles, and include legends
5. **Subplots** allow comparing multiple views in one figure
6. **Save figures** at appropriate resolution for your use case

## Essential Functions

```python
# Matplotlib
plt.plot(), plt.bar(), plt.scatter(), plt.hist()
plt.subplots(), fig.savefig()

# Seaborn
sns.barplot(), sns.boxplot(), sns.violinplot()
sns.heatmap(), sns.pairplot(), sns.jointplot()
sns.set_style(), sns.set_palette()
```

## Next Module

In the next module, we'll cover **Exploratory Data Analysis (EDA)** - the systematic approach to understanding your data using the visualization and analysis techniques you've learned.

## Additional Practice

For extra practice, try these challenges:

1. **Multi-variable Analysis**: Create a figure that shows the relationship between employee salary, experience, department, and performance rating using appropriate chart types.

2. **Time Series Dashboard**: Create a dashboard showing daily, weekly, and monthly trends from the sales data, including moving averages.

3. **Custom Theme**: Create your own color palette and styling for a professional-looking report with multiple visualizations.