# Single Cell Analysis Data Visualization

This notebook creates visualizations from single-cell analysis data stored in a CSV file in which each row is a cell/object and each column is a feature or a metadata field.

**Visualizations included:**
1. Feature distributions (histograms, violin plots, box plots)
2. Correlation heatmaps
3. Clustered heatmaps (cells × features)
4. Dimensionality reduction (PCA, UMAP, t-SNE)
5. Spatial scatter plots
6. Feature comparison scatter plots
7. Cluster summary plots

## 1. Setup and Imports

In [None]:
import os
from pathlib import Path
from typing import List, Optional, Literal

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

# Optional imports
try:
    import umap
    UMAP_AVAILABLE = True
except ImportError:
    UMAP_AVAILABLE = False
    print("UMAP not available. Install with: pip install umap-learn")

try:
    from sklearn.manifold import TSNE
    TSNE_AVAILABLE = True
except ImportError:
    TSNE_AVAILABLE = False

# Set plotting style
sns.set_style("whitegrid")
%matplotlib inline
plt.rcParams['figure.dpi'] = 100

## 2. Configuration

Set your input file path and parameters below.

In [None]:
# =============================================================================
# USER CONFIGURATION - Modify these parameters
# =============================================================================

# Path to your cell table CSV file (read with pandas.read_csv in the next section)
INPUT_CSV = "path/to/your/cell_table.csv"

# Output directory for saving plots (set to None to skip saving)
OUTPUT_DIR = "./cell_visualizations"

# Column names.
# NOTE: only the spatial coordinate columns (X_COORD_COL / Y_COORD_COL) are
# auto-detected when left as None. CLUSTER_COL is never auto-detected: leaving
# it as None disables cluster colouring and skips the cluster summary plot.
# NOTE(review): CELL_ID_COL is not read by any cell shown in this notebook.
CLUSTER_COL = None          # e.g., "cell_meta_cluster", "cluster", "phenotype"
CELL_ID_COL = None          # e.g., "label", "cell_id"
X_COORD_COL = None          # e.g., "centroid-0", "x"
Y_COORD_COL = None          # e.g., "centroid-1", "y"

# Visualization parameters
DPI = 150                   # Resolution (dots per inch) passed to savefig for saved figures
MAX_FEATURES = 20           # Max features to show in distribution plots
SAMPLE_CELLS = 2000         # Max cells to sample for clustermaps (passed as n_cells)

## 3. Load Data

In [None]:
# Read the cell table into memory and report its dimensions and columns
print(f"Loading data from: {INPUT_CSV}")
data = pd.read_csv(INPUT_CSV)

n_rows, n_columns = data.shape
print(f"\nDataset shape: {n_rows} cells × {n_columns} columns")
print(f"\nColumn names:")
# Number the columns from 1 so the listing reads naturally
for position, column_name in enumerate(data.columns, start=1):
    print(f"  {position:3d}. {column_name}")

In [None]:
# Preview the data: display the first 10 rows of the loaded table
data.head(10)

In [None]:
# Summary statistics of the loaded table (pandas `describe()` with its defaults)
data.describe()

## 4. Auto-detect Feature Columns

Automatically identifies numeric columns that are likely features (not metadata like cell IDs, coordinates, etc.).

In [None]:
def detect_feature_columns(df: pd.DataFrame) -> List[str]:
    """Detect numeric columns that are likely features (not metadata).

    A numeric column is treated as metadata when its name

    * contains one of the distinctive metadata words (``label``, ``fov``,
      ``centroid``, ``cell_size``, ``area``, ``index``, ``cluster``) anywhere
      in the name, or
    * has one of the short words ``id``, ``x``, ``y``, ``row``, ``col`` as a
      whole token. Tokens are obtained by splitting the name on
      non-alphanumeric characters and on lower-to-upper camelCase boundaries,
      so ``cell_id``, ``centroid-0``, ``X`` and ``cellID`` are metadata while
      ``CXCR4``, ``Ly6G``, ``Collagen`` and ``growth_factor`` are not.

    Args:
        df: Table whose columns are inspected; only numeric columns are
            considered.

    Returns:
        Names of the numeric columns judged to be features, in column order.
        If fewer than three columns survive the filter, all numeric columns
        are returned instead.
    """
    import re  # local import: only this function needs it

    # Distinctive words: safe to match anywhere inside a column name.
    substring_patterns = (
        'label', 'fov', 'centroid', 'cell_size', 'area', 'index', 'cluster',
    )
    # Short words: a plain substring test would reject real markers
    # ('x' in "CXCR4", 'y' in "Ly6G", 'col' in "Collagen", 'row' in "growth"),
    # so these only count when they form a complete token of the name.
    token_patterns = {'id', 'x', 'y', 'row', 'col'}

    def is_metadata(name: str) -> bool:
        lowered = name.lower()
        if any(pattern in lowered for pattern in substring_patterns):
            return True
        # Break camelCase ("cellID" -> "cell_ID") before tokenizing.
        separated = re.sub(r'(?<=[a-z])(?=[A-Z])', '_', name)
        tokens = re.split(r'[^0-9a-z]+', separated.lower())
        return not token_patterns.isdisjoint(tokens)

    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()

    # str() keeps non-string column labels (e.g. integers) from crashing.
    feature_cols = [col for col in numeric_cols if not is_metadata(str(col))]

    # If we filtered too aggressively, include all numeric columns
    if len(feature_cols) < 3:
        feature_cols = numeric_cols

    return feature_cols

# Run the detection on the loaded table and list what was kept
feature_cols = detect_feature_columns(data)
n_detected = len(feature_cols)
print(f"Detected {n_detected} feature columns:")
for feature_name in feature_cols:
    print(f"  - {feature_name}")

In [None]:
# Fill in the coordinate columns from well-known name pairs when the user
# left at least one of them unset; the first pair fully present in the table wins.
if X_COORD_COL is None or Y_COORD_COL is None:
    spatial_candidates = [
        ('centroid-0', 'centroid-1'),
        ('centroid_x', 'centroid_y'),
        ('x', 'y'),
        ('X', 'Y'),
    ]
    detected_pair = next(
        (
            pair for pair in spatial_candidates
            if pair[0] in data.columns and pair[1] in data.columns
        ),
        None,
    )
    if detected_pair is not None:
        X_COORD_COL, Y_COORD_COL = detected_pair

if not (X_COORD_COL and Y_COORD_COL):
    print("No spatial columns detected")
else:
    print(f"Spatial columns: {X_COORD_COL}, {Y_COORD_COL}")

In [None]:
# Make sure the target folder (and any missing parents) exists before saving
if OUTPUT_DIR:
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print(f"Output directory: {OUTPUT_DIR}")

---
## 5. Feature Distributions

Visualize the distribution of each feature across all cells.

In [None]:
def plot_feature_distributions(
    df: pd.DataFrame,
    features: List[str],
    plot_type: Literal["histogram", "violin", "box"] = "histogram",
    ncols: int = 4,
    figsize: Optional[tuple] = None,
) -> plt.Figure:
    """Plot the distribution of each selected feature in a grid of subplots.

    Args:
        df: Cell table containing the feature columns.
        features: Column names to plot; only the first ``MAX_FEATURES``
            (module-level setting) are drawn.
        plot_type: ``"histogram"`` (50 bins), ``"violin"`` or ``"box"``.
        ncols: Number of subplot columns in the grid.
        figsize: Figure size in inches; defaults to 3 x 2.5 inches per subplot.

    Returns:
        The figure holding one subplot per plotted feature.

    Raises:
        ValueError: If ``plot_type`` is not one of the supported values,
            ``ncols`` is smaller than 1, or ``features`` is empty.
    """
    supported_types = ("histogram", "violin", "box")
    # Fail loudly: an unknown type used to produce a grid of blank axes.
    if plot_type not in supported_types:
        raise ValueError(
            f"Unknown plot_type: {plot_type!r}; expected one of {supported_types}"
        )
    if ncols < 1:
        raise ValueError(f"ncols must be at least 1, got {ncols}")

    features = list(features)[:MAX_FEATURES]  # Limit number of features
    if not features:
        raise ValueError("No features to plot")

    nrows = int(np.ceil(len(features) / ncols))

    if figsize is None:
        figsize = (ncols * 3, nrows * 2.5)

    # squeeze=False always yields a 2-D array, whatever the grid shape.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    axes = axes.ravel()

    for idx, feature in enumerate(tqdm(features, desc="Plotting")):
        ax = axes[idx]
        feature_data = df[feature].dropna()

        if plot_type == "histogram":
            ax.hist(feature_data, bins=50, edgecolor='black', alpha=0.7)
        elif plot_type == "violin":
            sns.violinplot(y=feature_data, ax=ax)
        else:  # "box"
            sns.boxplot(y=feature_data, ax=ax)

        ax.set_title(feature, fontsize=10)
        ax.tick_params(labelsize=8)

    # Hide the unused cells of the grid
    for unused_ax in axes[len(features):]:
        unused_ax.set_visible(False)

    # Address the figure directly rather than pyplot's "current figure".
    fig.suptitle("Feature Distributions", fontsize=14, y=1.02)
    fig.tight_layout()
    return fig

In [None]:
# Histogram grid for the detected features
fig = plot_feature_distributions(data, feature_cols, plot_type="histogram")

if OUTPUT_DIR:
    distributions_png = f"{OUTPUT_DIR}/feature_distributions.png"
    fig.savefig(distributions_png, dpi=DPI, bbox_inches='tight')
    print(f"Saved: {distributions_png}")
plt.show()

In [None]:
# Optional: Plot as violin plots
# fig = plot_feature_distributions(data, feature_cols[:12], plot_type="violin")
# plt.show()

---
## 6. Correlation Heatmap

Visualize correlations between features.

In [None]:
def plot_correlation_heatmap(
    df: pd.DataFrame,
    features: List[str],
    method: Literal["pearson", "spearman", "kendall"] = "pearson",
    figsize: Optional[tuple] = None,
    cmap: str = "RdBu_r",
) -> plt.Figure:
    """Draw a square heatmap of pairwise feature correlations.

    Only the first 30 features are used. Cell annotations are shown when at
    most 15 features are plotted. The colour scale is fixed to [-1, 1] and
    centred on 0.

    Args:
        df: Cell table containing the feature columns.
        features: Column names to correlate.
        method: Correlation method forwarded to ``DataFrame.corr``.
        figsize: Figure size in inches; by default a square whose side grows
            with the number of features (minimum 8).
        cmap: Colormap name for the heatmap.

    Returns:
        The figure containing the heatmap.
    """
    shown = features[:30]  # keep the matrix readable
    n_shown = len(shown)
    correlations = df[shown].corr(method=method)

    if figsize is None:
        side = max(8, n_shown * 0.4)
        figsize = (side, side)

    fig, ax = plt.subplots(figsize=figsize)

    heatmap_options = dict(
        annot=n_shown <= 15,
        fmt=".2f",
        cmap=cmap,
        center=0,
        vmin=-1,
        vmax=1,
        square=True,
        cbar_kws={"shrink": 0.8},
    )
    sns.heatmap(correlations, ax=ax, **heatmap_options)

    title = f"Feature Correlation ({method.capitalize()})"
    ax.set_title(title, fontsize=14)
    plt.xticks(rotation=45, ha='right', fontsize=8)
    plt.yticks(fontsize=8)
    plt.tight_layout()
    return fig

In [None]:
# Pearson correlation between the detected features
fig = plot_correlation_heatmap(data, feature_cols, method="pearson")

if OUTPUT_DIR:
    correlation_png = f"{OUTPUT_DIR}/correlation_heatmap.png"
    fig.savefig(correlation_png, dpi=DPI, bbox_inches='tight')
    print(f"Saved: {correlation_png}")
plt.show()

---
## 7. Clustered Heatmap (Clustermap)

Hierarchically clustered heatmap showing cells vs features.

In [None]:
def plot_clustermap(
    df: pd.DataFrame,
    features: List[str],
    n_cells: int = 1000,
    standardize: bool = True,
    figsize: tuple = (12, 10),
    cmap: str = "viridis",
) -> sns.matrix.ClusterGrid:
    """Plot a hierarchically clustered heatmap of cells (rows) vs features.

    Only the first 30 features are used. Cells with a missing or infinite
    value in any of those features are left out, because the clustering
    step needs finite values throughout.

    Args:
        df: Cell table containing the feature columns.
        features: Column names to include.
        n_cells: Maximum number of cells to draw; larger tables are randomly
            subsampled with a fixed seed (42).
        standardize: If True, scale each feature with ``StandardScaler``
            before clustering.
        figsize: Figure size in inches.
        cmap: Colormap name for the heatmap.

    Returns:
        The seaborn ``ClusterGrid`` holding the plot.

    Raises:
        ValueError: If fewer than two features are given, or fewer than two
            cells have complete, finite feature values.
    """
    features = features[:30]
    if len(features) < 2:
        raise ValueError("Need at least 2 features to build a clustermap")

    # Treat +/-inf like missing data and drop incomplete cells.
    feature_table = df[features].replace([np.inf, -np.inf], np.nan).dropna()
    n_dropped = len(df) - len(feature_table)
    if n_dropped:
        print(f"Dropped {n_dropped} cells with missing or infinite feature values")
    if len(feature_table) < 2:
        raise ValueError(
            "Need at least 2 cells with complete feature values to build a clustermap"
        )

    # Sample cells if dataset is large
    if len(feature_table) > n_cells:
        sample_data = feature_table.sample(n=n_cells, random_state=42)
        print(f"Sampled {n_cells} cells from {len(df)} total")
    else:
        sample_data = feature_table

    # Standardize if requested
    if standardize:
        scaler = StandardScaler()
        plot_data = pd.DataFrame(
            scaler.fit_transform(sample_data),
            columns=features,
            index=sample_data.index
        )
    else:
        plot_data = sample_data

    g = sns.clustermap(
        plot_data,
        cmap=cmap,
        figsize=figsize,
        xticklabels=True,
        yticklabels=False,
        dendrogram_ratio=(0.1, 0.15),
        cbar_pos=(0.02, 0.8, 0.03, 0.15),
    )

    g.ax_heatmap.set_xlabel("Features", fontsize=12)
    g.ax_heatmap.set_ylabel(f"Cells (n={len(plot_data)})", fontsize=12)
    plt.setp(g.ax_heatmap.get_xticklabels(), rotation=45, ha='right', fontsize=8)

    return g

In [None]:
# Clustered heatmap on at most SAMPLE_CELLS cells
g = plot_clustermap(data, feature_cols, n_cells=SAMPLE_CELLS)

if OUTPUT_DIR:
    clustermap_png = f"{OUTPUT_DIR}/clustermap.png"
    g.savefig(clustermap_png, dpi=DPI, bbox_inches='tight')
    print(f"Saved: {clustermap_png}")
plt.show()

---
## 8. Dimensionality Reduction (PCA, UMAP, t-SNE)

Reduce high-dimensional feature space to 2D for visualization.

In [None]:
def plot_dimensionality_reduction(
    df: pd.DataFrame,
    features: List[str],
    method: Literal["pca", "umap", "tsne"] = "pca",
    color_by: Optional[str] = None,
    n_components: int = 2,
    figsize: tuple = (10, 8),
    cmap: str = "tab20",
    alpha: float = 0.6,
    point_size: int = 10,
    **kwargs,
) -> plt.Figure:
    """Embed cells in a low-dimensional space and scatter the first two axes.

    Cells with a missing value in any of ``features`` are dropped, the
    remaining values are standardized with ``StandardScaler`` and the chosen
    reducer is fitted on them.

    Args:
        df: Cell table containing the feature columns.
        features: Column names used as input dimensions.
        method: ``"pca"``, ``"umap"`` or ``"tsne"``.
        color_by: Optional column used to colour the points. Non-numeric
            columns and columns with fewer than 20 distinct values are drawn
            as categories with a legend; other columns use a continuous
            ``viridis`` scale. Ignored if the column does not exist.
        n_components: Number of embedding dimensions to compute (only the
            first two are plotted).
        figsize: Figure size in inches.
        cmap: Colormap name used for categorical colouring.
        alpha: Point opacity.
        point_size: Marker size.
        **kwargs: Extra keyword arguments forwarded to the reducer. UMAP and
            t-SNE use ``random_state=42`` unless overridden here.

    Returns:
        The figure containing the scatter plot.

    Raises:
        ValueError: If ``method`` is unknown, ``n_components`` is below 2, or
            no cell has complete feature values.
        ImportError: If the requested reducer is not installed.
    """
    # Validate everything up front so no expensive fit runs before a failure.
    if method not in ("pca", "umap", "tsne"):
        raise ValueError(f"Unknown method: {method}")
    if n_components < 2:
        raise ValueError(
            f"n_components must be at least 2 to draw a 2D scatter, got {n_components}"
        )
    if method == "umap" and not UMAP_AVAILABLE:
        raise ImportError("UMAP not installed. Run: pip install umap-learn")
    if method == "tsne" and not TSNE_AVAILABLE:
        raise ImportError("scikit-learn not installed for t-SNE")

    # Prepare data. A boolean mask (rather than index labels) keeps the
    # colour column aligned even when the index holds duplicate labels.
    complete_rows = df[features].notna().all(axis=1)
    X = df.loc[complete_rows, features]
    if X.empty:
        raise ValueError("No cells with complete feature values to embed")
    X_scaled = StandardScaler().fit_transform(X)

    print(f"Running {method.upper()} on {len(X)} cells with {len(features)} features...")

    # Apply dimensionality reduction
    if method == "pca":
        reducer = PCA(n_components=n_components, **kwargs)
        embedding = reducer.fit_transform(X_scaled)
        var_explained = reducer.explained_variance_ratio_
        axis_labels = [f"PC{i+1} ({var_explained[i]:.1%})" for i in range(n_components)]
        print(f"Variance explained: {sum(var_explained):.1%}")
    else:
        # Default seed for reproducibility; a caller-supplied value wins
        # instead of raising "multiple values for keyword argument".
        reducer_options = {"random_state": 42, **kwargs}
        if method == "umap":
            reducer = umap.UMAP(n_components=n_components, **reducer_options)
            axis_prefix = "UMAP"
        else:
            reducer = TSNE(n_components=n_components, **reducer_options)
            axis_prefix = "t-SNE"
        embedding = reducer.fit_transform(X_scaled)
        axis_labels = [f"{axis_prefix}{i+1}" for i in range(n_components)]

    # Create plot
    fig, ax = plt.subplots(figsize=figsize)

    # Determine coloring
    if color_by and color_by in df.columns:
        color_data = df.loc[complete_rows, color_by]
        is_categorical = (
            not pd.api.types.is_numeric_dtype(color_data)
            or color_data.nunique() < 20
        )

        if is_categorical:
            # factorize assigns code -1 to missing labels, so NaN never has
            # to be used as a lookup key.
            codes, categories = pd.factorize(color_data)
            # pyplot.get_cmap works on Matplotlib releases where
            # matplotlib.cm.get_cmap has been removed.
            palette = plt.get_cmap(cmap)(np.linspace(0, 1, len(categories)))
            missing_rgba = (0.8, 0.8, 0.8, 1.0)
            point_colors = np.tile(missing_rgba, (len(codes), 1))
            labelled = codes >= 0
            point_colors[labelled] = palette[codes[labelled]]

            ax.scatter(
                embedding[:, 0], embedding[:, 1],
                c=point_colors, alpha=alpha, s=point_size
            )

            # Empty scatters act as legend proxies, one per category
            handles = [
                ax.scatter([], [], c=[color], label=category)
                for category, color in zip(categories, palette)
            ]
            if not labelled.all():
                handles.append(ax.scatter([], [], c=[missing_rgba], label="missing"))
            ax.legend(handles=handles, title=color_by,
                      bbox_to_anchor=(1.05, 1), loc='upper left')
        else:
            # Continuous coloring
            scatter = ax.scatter(
                embedding[:, 0], embedding[:, 1],
                c=color_data, cmap='viridis', alpha=alpha, s=point_size
            )
            fig.colorbar(scatter, ax=ax, label=color_by)
    else:
        ax.scatter(
            embedding[:, 0], embedding[:, 1],
            alpha=alpha, s=point_size, c='steelblue'
        )

    ax.set_xlabel(axis_labels[0], fontsize=12)
    ax.set_ylabel(axis_labels[1], fontsize=12)
    ax.set_title(f"{method.upper()} - {len(X)} cells", fontsize=14)

    fig.tight_layout()
    return fig

In [None]:
# PCA embedding, coloured by CLUSTER_COL when it is set
fig = plot_dimensionality_reduction(
    data,
    feature_cols,
    method="pca",
    color_by=CLUSTER_COL,
)

if OUTPUT_DIR:
    pca_png = f"{OUTPUT_DIR}/pca.png"
    fig.savefig(pca_png, dpi=DPI, bbox_inches='tight')
    print(f"Saved: {pca_png}")
plt.show()

In [None]:
# UMAP embedding; skipped with a hint when umap-learn is not installed
if not UMAP_AVAILABLE:
    print("UMAP not available. Install with: pip install umap-learn")
else:
    umap_params = {"n_neighbors": 15, "min_dist": 0.1}
    fig = plot_dimensionality_reduction(
        data, feature_cols, method="umap", color_by=CLUSTER_COL, **umap_params
    )

    if OUTPUT_DIR:
        umap_png = f"{OUTPUT_DIR}/umap.png"
        fig.savefig(umap_png, dpi=DPI, bbox_inches='tight')
        print(f"Saved: {umap_png}")
    plt.show()

In [None]:
# t-SNE (optional - can be slow for large datasets)
# Uncomment to run

# if len(data) <= 5000:  # t-SNE is slow, limit cells
#     fig = plot_dimensionality_reduction(
#         data, feature_cols, method="tsne", color_by=CLUSTER_COL,
#         perplexity=30
#     )
#     if OUTPUT_DIR:
#         fig.savefig(f"{OUTPUT_DIR}/tsne.png", dpi=DPI, bbox_inches='tight')
#     plt.show()

---
## 9. Spatial Scatter Plot

Visualize cells in their spatial coordinates.

In [None]:
def plot_spatial_scatter(
    df: pd.DataFrame,
    x_col: str,
    y_col: str,
    color_by: Optional[str] = None,
    figsize: tuple = (10, 10),
    cmap: str = "tab20",
    alpha: float = 0.7,
    point_size: int = 20,
) -> plt.Figure:
    """Plot cells at their spatial coordinates.

    The axes use an equal aspect ratio and the y axis is inverted, as is
    common for image coordinates.

    Args:
        df: Cell table containing the coordinate columns.
        x_col: Column plotted on the horizontal axis.
        y_col: Column plotted on the vertical axis.
        color_by: Optional column used to colour the points. Non-numeric
            columns and columns with fewer than 20 distinct values are drawn
            as categories with a legend; other columns use a continuous
            ``viridis`` scale. Ignored if the column does not exist.
        figsize: Figure size in inches.
        cmap: Colormap name used for categorical colouring.
        alpha: Point opacity.
        point_size: Marker size.

    Returns:
        The figure containing the scatter plot.

    Raises:
        KeyError: If ``x_col`` or ``y_col`` is not a column of ``df``.
    """
    # Look the coordinates up first so a bad column name does not leave an
    # empty figure behind.
    xs, ys = df[x_col], df[y_col]

    fig, ax = plt.subplots(figsize=figsize)

    if color_by and color_by in df.columns:
        color_data = df[color_by]
        is_categorical = (
            not pd.api.types.is_numeric_dtype(color_data)
            or color_data.nunique() < 20
        )

        if is_categorical:
            # factorize assigns code -1 to missing labels, so NaN never has
            # to be used as a lookup key.
            codes, categories = pd.factorize(color_data)
            # pyplot.get_cmap works on Matplotlib releases where
            # matplotlib.cm.get_cmap has been removed.
            palette = plt.get_cmap(cmap)(np.linspace(0, 1, len(categories)))
            missing_rgba = (0.8, 0.8, 0.8, 1.0)
            point_colors = np.tile(missing_rgba, (len(codes), 1))
            labelled = codes >= 0
            point_colors[labelled] = palette[codes[labelled]]

            ax.scatter(xs, ys, c=point_colors, alpha=alpha, s=point_size)

            # Empty scatters act as legend proxies, one per category
            handles = [
                ax.scatter([], [], c=[color], label=category)
                for category, color in zip(categories, palette)
            ]
            if not labelled.all():
                handles.append(ax.scatter([], [], c=[missing_rgba], label="missing"))
            ax.legend(handles=handles, title=color_by,
                      bbox_to_anchor=(1.05, 1), loc='upper left')
        else:
            # Continuous
            scatter = ax.scatter(
                xs, ys,
                c=color_data, cmap='viridis', alpha=alpha, s=point_size
            )
            fig.colorbar(scatter, ax=ax, label=color_by)
    else:
        ax.scatter(xs, ys, alpha=alpha, s=point_size, c='steelblue')

    ax.set_xlabel(x_col, fontsize=12)
    ax.set_ylabel(y_col, fontsize=12)
    ax.set_title(f"Spatial Distribution ({len(df)} cells)", fontsize=14)
    ax.set_aspect('equal')
    ax.invert_yaxis()  # Common for image coordinates

    fig.tight_layout()
    return fig

In [None]:
# Spatial map of the cells; needs both coordinate columns
if not (X_COORD_COL and Y_COORD_COL):
    print("No spatial columns available. Set X_COORD_COL and Y_COORD_COL in configuration.")
else:
    fig = plot_spatial_scatter(data, X_COORD_COL, Y_COORD_COL, color_by=CLUSTER_COL)

    if OUTPUT_DIR:
        spatial_png = f"{OUTPUT_DIR}/spatial_scatter.png"
        fig.savefig(spatial_png, dpi=DPI, bbox_inches='tight')
        print(f"Saved: {spatial_png}")
    plt.show()

---
## 10. Feature Comparison Scatter Plot

Compare two features against each other.

In [None]:
def plot_feature_comparison(
    df: pd.DataFrame,
    x_feature: str,
    y_feature: str,
    color_by: Optional[str] = None,
    figsize: tuple = (8, 8),
    alpha: float = 0.5,
    point_size: int = 10,
    categorical_cmap: str = "tab20",
) -> plt.Figure:
    """Create a scatter plot comparing two features.

    Args:
        df: Cell table containing the feature columns.
        x_feature: Column plotted on the horizontal axis.
        y_feature: Column plotted on the vertical axis.
        color_by: Optional column used to colour the points. Numeric columns
            use a continuous ``viridis`` scale with a colorbar; non-numeric
            columns (e.g. string cluster labels) are drawn as categories with
            a legend. Ignored if the column does not exist.
        figsize: Figure size in inches.
        alpha: Point opacity.
        point_size: Marker size.
        categorical_cmap: Colormap name used when ``color_by`` is
            non-numeric.

    Returns:
        The figure containing the scatter plot.

    Raises:
        KeyError: If ``x_feature`` or ``y_feature`` is not a column of ``df``.
    """
    # Look the features up first so a bad column name does not leave an
    # empty figure behind.
    xs, ys = df[x_feature], df[y_feature]

    fig, ax = plt.subplots(figsize=figsize)

    if color_by and color_by in df.columns:
        color_data = df[color_by]

        if pd.api.types.is_numeric_dtype(color_data):
            scatter = ax.scatter(
                xs, ys,
                c=color_data, cmap='viridis',
                alpha=alpha, s=point_size
            )
            fig.colorbar(scatter, ax=ax, label=color_by)
        else:
            # Label values such as strings cannot be passed to `c` directly:
            # map each category to a colour instead. factorize assigns
            # code -1 to missing labels.
            codes, categories = pd.factorize(color_data)
            palette = plt.get_cmap(categorical_cmap)(
                np.linspace(0, 1, len(categories))
            )
            missing_rgba = (0.8, 0.8, 0.8, 1.0)
            point_colors = np.tile(missing_rgba, (len(codes), 1))
            labelled = codes >= 0
            point_colors[labelled] = palette[codes[labelled]]

            ax.scatter(xs, ys, c=point_colors, alpha=alpha, s=point_size)

            # Empty scatters act as legend proxies, one per category
            handles = [
                ax.scatter([], [], c=[color], label=category)
                for category, color in zip(categories, palette)
            ]
            if not labelled.all():
                handles.append(ax.scatter([], [], c=[missing_rgba], label="missing"))
            ax.legend(handles=handles, title=color_by,
                      bbox_to_anchor=(1.05, 1), loc='upper left')
    else:
        ax.scatter(xs, ys, alpha=alpha, s=point_size, c='steelblue')

    ax.set_xlabel(x_feature, fontsize=12)
    ax.set_ylabel(y_feature, fontsize=12)
    ax.set_title(f"{x_feature} vs {y_feature}", fontsize=14)

    fig.tight_layout()
    return fig

In [None]:
# Scatter the first two detected features against each other (modify as needed)
if len(feature_cols) >= 2:
    first_feature, second_feature = feature_cols[0], feature_cols[1]
    fig = plot_feature_comparison(data, first_feature, second_feature, color_by=CLUSTER_COL)

    if OUTPUT_DIR:
        comparison_png = f"{OUTPUT_DIR}/feature_comparison.png"
        fig.savefig(comparison_png, dpi=DPI, bbox_inches='tight')
        print(f"Saved: {comparison_png}")
    plt.show()

---
## 11. Cluster Summary (if cluster column available)

Comprehensive visualization of cluster assignments.

In [None]:
def plot_cluster_summary(
    df: pd.DataFrame,
    cluster_col: str,
    features: List[str],
    figsize: tuple = (14, 10),
) -> plt.Figure:
    """Create a 2x2 summary figure for a cluster assignment.

    Panels: cluster sizes (bar), cluster proportions (pie), z-scored mean
    feature value per cluster (heatmap) and a per-cluster boxplot of the
    feature with the largest variance. Only the first 10 features are used.

    Args:
        df: Cell table containing the cluster and feature columns.
        cluster_col: Column holding the cluster label of each cell.
        features: Feature columns to summarize.
        figsize: Figure size in inches.

    Returns:
        The figure containing the four panels.

    Raises:
        ValueError: If ``features`` is empty.
    """
    features = features[:10]
    if not features:
        raise ValueError("No features to summarize")

    fig, axes = plt.subplots(2, 2, figsize=figsize)

    # 1. Cluster sizes (bar plot)
    ax1 = axes[0, 0]
    cluster_counts = df[cluster_col].value_counts().sort_index()
    cluster_counts.plot(kind='bar', ax=ax1, color='steelblue', edgecolor='black')
    ax1.set_xlabel("Cluster", fontsize=10)
    ax1.set_ylabel("Cell Count", fontsize=10)
    ax1.set_title("Cluster Sizes", fontsize=12)
    ax1.tick_params(axis='x', rotation=45)

    # 2. Cluster proportions (pie chart)
    ax2 = axes[0, 1]
    cluster_counts.plot(
        kind='pie', ax=ax2, autopct='%1.1f%%',
        startangle=90, labels=None
    )
    ax2.set_ylabel("")
    ax2.set_title("Cluster Proportions", fontsize=12)
    ax2.legend(cluster_counts.index, loc='center left', bbox_to_anchor=(1, 0.5))

    # 3. Mean feature expression by cluster (heatmap)
    ax3 = axes[1, 0]
    cluster_means = df.groupby(cluster_col)[features].mean()

    # Z-score normalize each feature across clusters. Where the spread is
    # zero (feature identical in every cluster) or undefined (single
    # cluster), divide by 1 so the result is 0 instead of NaN/inf.
    spread = cluster_means.std()
    safe_spread = spread.where(spread > 0, 1.0)
    cluster_means_z = (cluster_means - cluster_means.mean()) / safe_spread

    sns.heatmap(
        cluster_means_z.T, ax=ax3, cmap='RdBu_r', center=0,
        xticklabels=True, yticklabels=True,
        cbar_kws={'label': 'Z-score'}
    )
    ax3.set_xlabel("Cluster", fontsize=10)
    ax3.set_ylabel("Feature", fontsize=10)
    ax3.set_title("Mean Feature Expression (Z-scored)", fontsize=12)
    plt.setp(ax3.get_xticklabels(), rotation=45, ha='right', fontsize=8)
    plt.setp(ax3.get_yticklabels(), fontsize=8)

    # 4. Feature boxplots by cluster (for top feature by variance)
    ax4 = axes[1, 1]
    top_feature = df[features].var().idxmax()
    # Use the same sorted cluster order as the bar chart and the heatmap.
    sns.boxplot(
        data=df, x=cluster_col, y=top_feature,
        order=list(cluster_counts.index), ax=ax4
    )
    ax4.set_xlabel("Cluster", fontsize=10)
    ax4.set_ylabel(top_feature, fontsize=10)
    ax4.set_title(f"Distribution of {top_feature} by Cluster", fontsize=12)
    ax4.tick_params(axis='x', rotation=45)

    # Address the figure directly rather than pyplot's "current figure".
    fig.suptitle(f"Cluster Summary (n={len(df)} cells)", fontsize=14, y=1.02)
    fig.tight_layout()
    return fig

In [None]:
# Cluster overview; only possible when CLUSTER_COL names an existing column
has_cluster_column = bool(CLUSTER_COL) and CLUSTER_COL in data.columns
if not has_cluster_column:
    print("No cluster column specified. Set CLUSTER_COL in configuration to generate this plot.")
else:
    fig = plot_cluster_summary(data, CLUSTER_COL, feature_cols)

    if OUTPUT_DIR:
        cluster_summary_png = f"{OUTPUT_DIR}/cluster_summary.png"
        fig.savefig(cluster_summary_png, dpi=DPI, bbox_inches='tight')
        print(f"Saved: {cluster_summary_png}")
    plt.show()

---
## 12. Summary

List all PNG files found in the output directory (this includes images left over from earlier runs).

In [None]:
# Final report: dataset size and every PNG present in the output directory
if not OUTPUT_DIR:
    print("Plots displayed but not saved. Set OUTPUT_DIR to save plots.")
else:
    output_path = Path(OUTPUT_DIR)
    saved_files = list(output_path.glob("*.png"))
    banner = '=' * 50

    print(f"\n{banner}")
    print(f"VISUALIZATION SUMMARY")
    print(f"{banner}")
    print(f"\nDataset: {INPUT_CSV}")
    print(f"Cells: {len(data)}")
    print(f"Features: {len(feature_cols)}")
    print(f"\nSaved files ({len(saved_files)}):")
    for png_file in sorted(saved_files):
        print(f"  - {png_file.name}")
    print(f"\nOutput directory: {OUTPUT_DIR}")