# Heatmaps and Correlation Analysis

This notebook demonstrates PlotSmith's heatmap capabilities for visualizing correlation matrices, confusion matrices, and other 2D data structures.


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from plotsmith import plot_heatmap

# Seed NumPy's global RNG once, up front, so the synthetic data generated in
# the cells below is identical on every Restart & Run All.
SEED = 42
np.random.seed(SEED)


## Correlation Matrix

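Visualize pairwise (Pearson) correlations between the features of a synthetic dataset in which the first three features are deliberately correlated: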


In [None]:
# Synthetic dataset: 8 features of standard-normal noise, where Feature 2 and
# Feature 3 are deliberately built from Feature 1 (and each other) so the
# heatmap shows a visible block of correlation.
n_samples = 200
n_features = 8

corr_data = np.random.randn(n_samples, n_features)
# Feature 2 is driven mostly by Feature 1.
corr_data[:, 1] = corr_data[:, 0] * 0.7 + np.random.randn(n_samples) * 0.3
# Feature 3 depends on both Feature 1 and (the already-correlated) Feature 2.
corr_data[:, 2] = (
    corr_data[:, 0] * 0.5
    + corr_data[:, 1] * 0.3
    + np.random.randn(n_samples) * 0.2
)

feature_names = [f"Feature {i + 1}" for i in range(n_features)]
features_df = pd.DataFrame(corr_data, columns=feature_names)

# Pairwise Pearson correlation (pandas' default method for DataFrame.corr).
corr = features_df.corr()

# Diverging, colour-blind-friendly colormap with limits fixed to [-1, 1] so
# the colour for 0 sits at the neutral midpoint and signs are directly comparable.
fig, ax = plot_heatmap(
    corr,
    annotate=True,
    fmt=".2f",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
    title="Feature Correlation Matrix",
    cbar_label="Correlation",
)

fig.tight_layout()
plt.show()


## Custom Heatmap with Labels

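When the input is a plain NumPy array with no index or column names, pass `x_labels` and `y_labels` explicitly. Here they label the classes of a confusion matrix: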


In [None]:
# Synthetic confusion matrix. Following the scikit-learn convention, rows are
# the true classes and columns the predicted classes, so the diagonal holds
# the correctly classified counts.
classes = ["Class A", "Class B", "Class C", "Class D"]
confusion = np.array(
    [
        [85, 5, 3, 2],
        [4, 78, 6, 3],
        [2, 5, 88, 4],
        [3, 4, 2, 91],
    ],
    dtype=int,  # fmt="d" below only accepts integer values
)

fig, ax = plot_heatmap(
    confusion,
    x_labels=classes,
    y_labels=classes,
    annotate=True,
    fmt="d",
    cmap="Blues",
    title="Confusion Matrix",
    cbar_label="Count",
)
# Without axis titles the reader cannot tell truth from prediction.
ax.set_xlabel("Predicted class")
ax.set_ylabel("True class")

fig.tight_layout()
plt.show()


## Time-Series Heatmap

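Visualize patterns across time and categories. Here a DataFrame with one row per month and one column per region is passed directly, without explicit labels: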


In [None]:
# Synthetic monthly sales for four regions over 2024, with a yearly seasonal cycle.
# freq="MS" (month start) instead of "M": the "M" alias is deprecated since
# pandas 2.2, and only month/year appear in the labels, so the result is identical.
dates = pd.date_range("2024-01-01", periods=12, freq="MS")
regions = ["North", "South", "East", "West"]
n_months = len(dates)

# Uniform base level in [0, 100) for every (month, region) cell.
base = np.random.rand(n_months, len(regions)) * 100
# One full sine period over the year, shifted up to [0, 2 * amplitude] so the
# seasonal effect can never push sales below zero.
seasonal_amplitude = 20
month_phase = np.arange(n_months).reshape(-1, 1) * np.pi / 6
seasonal = seasonal_amplitude * (1 + np.sin(month_phase))
sales = base + seasonal  # (n_months, 1) broadcasts across regions

sales_df = pd.DataFrame(sales, index=dates.strftime("%b %Y"), columns=regions)

fig, ax = plot_heatmap(
    sales_df,
    annotate=True,
    fmt=".0f",
    cmap="YlOrRd",
    title="Monthly Sales by Region",
    cbar_label="Sales ($K)",
)

fig.tight_layout()
plt.show()
