# Quality Control and Cell Filtering

## Overview
This notebook performs quality control (QC) and filtering on scRNA-seq data. This is a critical step to remove low-quality cells and genes before downstream analysis.

### Objectives
1. Calculate QC metrics (gene counts, UMI counts, mitochondrial/ribosomal content)
2. Visualize QC distributions
3. Apply the configured QC thresholds (from `config/analysis_params.yaml`) for filtering
4. Remove low-quality cells and uninformative genes

### QC Metrics
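- **n_genes_by_counts**: Number of genes with non-zero counts per cell
- **total_counts**: Total UMI counts per cell
- **pct_counts_mt**: Percentage of a cell's counts coming from mitochondrial genes (high values often indicate stressed or dying cells)
- **pct_counts_ribo**: Percentage of a cell's counts coming from ribosomal protein genes
- **pct_counts_hb**: Percentage of a cell's counts coming from hemoglobin genes (can flag red-blood-cell contamination)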

---

## 1. Setup

In [None]:
import scanpy as sc
import anndata as ad
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml
import warnings

# NOTE: this silences *all* warnings, including scanpy/anndata data warnings
# (e.g. about duplicate gene names); narrow it if you need those messages.
warnings.filterwarnings('ignore')

# Scanpy settings
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=100, facecolor='white')

# Project paths. "../.." is resolved against the current working directory,
# so this assumes the notebook is run from two levels below the project root.
PROJECT_ROOT = Path("../..").resolve()
DATA_RAW = PROJECT_ROOT / 'data' / 'raw' / 'scrna'
DATA_PROCESSED = PROJECT_ROOT / 'data' / 'processed' / 'scrna'
FIGURES = PROJECT_ROOT / 'results' / 'figures'
CONFIG_PATH = PROJECT_ROOT / 'config' / 'analysis_params.yaml'

# Create the output directories up front so later fig.savefig() /
# adata.write() calls do not fail on a fresh checkout.
for _out_dir in (DATA_PROCESSED, FIGURES):
    _out_dir.mkdir(parents=True, exist_ok=True)

# Load configuration (safe_load only builds plain Python data types)
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# QC parameters: fail early, with a clear message, if any threshold used by
# the plotting/filtering functions is missing from the config.
qc_params = config['qc']
REQUIRED_QC_KEYS = (
    'min_genes', 'max_genes', 'min_counts', 'max_counts', 'max_mito_pct', 'min_cells',
)
missing_qc_keys = [key for key in REQUIRED_QC_KEYS if key not in qc_params]
if missing_qc_keys:
    raise KeyError(f"Missing QC parameters in {CONFIG_PATH}: {missing_qc_keys}")

print("QC Parameters:")
for key, value in qc_params.items():
    print(f"  {key}: {value}")

## 2. Define QC Functions

In [None]:
def calculate_qc_metrics(adata, species='human'):
    """
    Calculate comprehensive QC metrics for scRNA-seq data.

    Flags mitochondrial (``mt``), ribosomal (``ribo``) and hemoglobin
    (``hb``) genes as boolean columns in ``adata.var``. It then runs
    ``sc.pp.calculate_qc_metrics`` with those gene sets, which adds per-cell
    metrics such as ``total_counts``, ``n_genes_by_counts`` and
    ``pct_counts_mt`` / ``pct_counts_ribo`` / ``pct_counts_hb`` to
    ``adata.obs``. Log1p-transformed copies of the count and gene totals
    are also stored for plotting.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix. It is modified in place: ``var_names`` are
        cast to ``str`` and QC columns are added to ``.obs`` / ``.var``.
    species : str
        Species of the dataset, ``'human'`` or ``'mouse'``
        (case-insensitive).

    Returns
    -------
    AnnData
        The same object, with QC metrics added to .obs and .var.

    Raises
    ------
    ValueError
        If ``species`` is not ``'human'`` or ``'mouse'``.
    """
    # Validate explicitly: previously any value other than exactly 'human'
    # (e.g. 'Human') silently fell through to the mouse rules.
    species_key = str(species).strip().lower()
    if species_key not in ('human', 'mouse'):
        raise ValueError(f"species must be 'human' or 'mouse', got {species!r}")

    # Ensure gene names are strings so the .str accessors below work
    adata.var_names = adata.var_names.astype(str)
    upper_names = adata.var_names.str.upper()

    # Mitochondrial genes carry an 'MT-' (human) or 'mt-' (mouse) prefix.
    # A case-insensitive match covers both species and also mouse datasets
    # whose gene symbols were upper-cased.
    adata.var['mt'] = upper_names.str.startswith('MT-')
    if not adata.var['mt'].any():
        # e.g. Ensembl IDs instead of symbols: pct_counts_mt would be 0 for
        # every cell and the mitochondrial filter would silently do nothing.
        print(
            "Warning: no mitochondrial genes found (expected gene symbols with an "
            "'MT-'/'mt-' prefix); pct_counts_mt will be 0 for all cells."
        )

    # Ribosomal protein genes (RPS*/RPL*)
    adata.var['ribo'] = upper_names.str.match(r'^RP[SL]')

    # Hemoglobin genes: 'HB' followed by any character except '(', 'P' or ')'
    # (the character class excludes HBP* genes)
    adata.var['hb'] = upper_names.str.match(r'^HB[^(P)]')

    # Calculate QC metrics
    sc.pp.calculate_qc_metrics(
        adata,
        qc_vars=['mt', 'ribo', 'hb'],
        percent_top=None,
        log1p=False,
        inplace=True
    )

    # Log transform of counts for visualization
    adata.obs['log1p_total_counts'] = np.log1p(adata.obs['total_counts'])
    adata.obs['log1p_n_genes_by_counts'] = np.log1p(adata.obs['n_genes_by_counts'])

    return adata


def _qc_hist(ax, values, color, xlabel, title, thresholds=()):
    """Draw a 100-bin QC histogram on ``ax`` with optional labelled red dashed threshold lines."""
    ax.hist(values, bins=100, color=color, edgecolor='black')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(title)
    for label, value in thresholds:
        ax.axvline(value, color='red', linestyle='--', label=label)
    if thresholds:
        ax.legend()


def plot_qc_metrics(adata, sample_name='Sample', params=None):
    """
    Create comprehensive QC visualization plots.

    Draws a 2x3 grid:
    - histograms of total counts, genes per cell, % mitochondrial and
      % ribosomal content, with threshold lines where thresholds exist;
    - a counts-vs-genes scatter coloured by % mitochondrial, with a colorbar;
    - a counts-vs-% mitochondrial scatter with the mito threshold line.

    Parameters
    ----------
    adata : AnnData
        Data with QC metrics in ``.obs`` (``total_counts``,
        ``n_genes_by_counts``, ``pct_counts_mt``, ``pct_counts_ribo``).
    sample_name : str
        Sample name for plot titles.
    params : dict, optional
        QC thresholds used for the cut-off lines. Uses the keys
        ``min_counts``, ``max_counts``, ``min_genes``, ``max_genes`` and
        ``max_mito_pct``. Defaults to the notebook-level ``qc_params``.

    Returns
    -------
    matplotlib.figure.Figure
        The figure; the caller decides whether to save or show it.
    """
    if params is None:
        params = qc_params
    obs = adata.obs

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    fig.suptitle(f'QC Metrics: {sample_name}', fontsize=14, y=1.02)

    # Total counts
    _qc_hist(
        axes[0, 0], obs['total_counts'], 'steelblue',
        'Total counts', 'Total UMI counts per cell',
        thresholds=[
            (f"Min: {params['min_counts']}", params['min_counts']),
            (f"Max: {params['max_counts']}", params['max_counts']),
        ],
    )

    # Genes per cell
    _qc_hist(
        axes[0, 1], obs['n_genes_by_counts'], 'steelblue',
        'Number of genes', 'Genes detected per cell',
        thresholds=[
            (f"Min: {params['min_genes']}", params['min_genes']),
            (f"Max: {params['max_genes']}", params['max_genes']),
        ],
    )

    # Mitochondrial percentage
    _qc_hist(
        axes[0, 2], obs['pct_counts_mt'], 'indianred',
        '% mitochondrial', 'Mitochondrial content',
        thresholds=[(f"Max: {params['max_mito_pct']}%", params['max_mito_pct'])],
    )

    # Scatter: counts vs genes, coloured by % MT. The colorbar makes the
    # colour scale readable.
    points = axes[1, 0].scatter(obs['total_counts'], obs['n_genes_by_counts'],
                                c=obs['pct_counts_mt'], cmap='RdYlBu_r', s=1, alpha=0.5)
    fig.colorbar(points, ax=axes[1, 0], label='% mitochondrial')
    axes[1, 0].set_xlabel('Total counts')
    axes[1, 0].set_ylabel('Number of genes')
    axes[1, 0].set_title('Counts vs Genes (colored by % MT)')

    # Scatter: counts vs MT
    axes[1, 1].scatter(obs['total_counts'], obs['pct_counts_mt'], s=1, alpha=0.3)
    axes[1, 1].set_xlabel('Total counts')
    axes[1, 1].set_ylabel('% mitochondrial')
    axes[1, 1].set_title('Counts vs Mitochondrial %')
    axes[1, 1].axhline(params['max_mito_pct'], color='red', linestyle='--')

    # Ribosomal percentage (no filter threshold, so no lines/legend)
    _qc_hist(axes[1, 2], obs['pct_counts_ribo'], 'forestgreen',
             '% ribosomal', 'Ribosomal content')

    plt.tight_layout()
    return fig


def filter_cells_genes(adata, params):
    """
    Apply QC filters to remove low-quality cells and genes.

    Cells are kept when:
    - they have at least ``min_genes`` detected genes and at least
      ``min_counts`` total counts (inclusive minimums, via
      ``sc.pp.filter_cells``);
    - ``n_genes_by_counts < max_genes``, ``total_counts < max_counts`` and
      ``pct_counts_mt < max_mito_pct`` (exclusive maximums).

    Genes are then kept when detected in at least ``min_cells`` of the
    remaining cells.

    The input object is not modified. ``sc.pp.filter_cells`` /
    ``filter_genes`` subset in place, so filtering is done on a copy.
    Otherwise the caller's ``adata`` would silently lose cells.

    Parameters
    ----------
    adata : AnnData
        Data with QC metrics (``n_genes_by_counts``, ``total_counts``,
        ``pct_counts_mt`` in ``.obs``).
    params : dict
        QC parameters with keys ``min_genes``, ``min_counts``,
        ``max_genes``, ``max_counts``, ``max_mito_pct`` and ``min_cells``.

    Returns
    -------
    AnnData
        Filtered copy. The per-cell QC columns in ``.obs`` are not
        recomputed after gene filtering.
    """
    # Copy first: the scanpy filters below mutate their argument.
    adata = adata.copy()

    n_cells_before = adata.n_obs
    n_genes_before = adata.n_vars

    # Minimum thresholds on cells
    sc.pp.filter_cells(adata, min_genes=params['min_genes'])
    sc.pp.filter_cells(adata, min_counts=params['min_counts'])

    # Maximum thresholds (doublet-like / stressed cells)
    keep = (
        (adata.obs['n_genes_by_counts'] < params['max_genes'])
        & (adata.obs['total_counts'] < params['max_counts'])
        & (adata.obs['pct_counts_mt'] < params['max_mito_pct'])
    )
    adata = adata[keep].copy()

    # Filter genes
    sc.pp.filter_genes(adata, min_cells=params['min_cells'])

    n_cells_after = adata.n_obs
    n_genes_after = adata.n_vars

    print(f"Cells: {n_cells_before} -> {n_cells_after} ({n_cells_before - n_cells_after} removed)")
    print(f"Genes: {n_genes_before} -> {n_genes_after} ({n_genes_before - n_genes_after} removed)")

    return adata

print("QC functions defined")

## 3. Load and Process Dataset

Process each dataset from the atlas. Here we show the template for one dataset.

In [None]:
# Example: Load GSE115978 Melanoma dataset
# Modify path based on your downloaded data format

geo_id = "GSE115978"
data_path = DATA_RAW / geo_id

print(f"Processing: {geo_id}")
print(f"Data path: {data_path}")

# Check if data exists; list plain file names (sorted) rather than Path reprs
if data_path.exists():
    file_names = sorted(p.name for p in data_path.glob('*'))
    print(f"Files found: {file_names}")
else:
    print(f"Data not found at {data_path}. Please download first.")

In [None]:
# Data-loading template: uncomment the reader that matches your files.
#
#   10x Genomics HDF5 file:
#   adata = sc.read_10x_h5(data_path / 'filtered_feature_bc_matrix.h5')
#
#   10x Genomics matrix/barcodes/features directory:
#   adata = sc.read_10x_mtx(data_path)
#
#   AnnData .h5ad file:
#   adata = sc.read_h5ad(data_path / 'data.h5ad')
#
#   Dense CSV/TSV count matrix (the trailing .T assumes genes are rows and
#   cells are columns in the file):
#   adata = sc.read_csv(data_path / 'counts.csv').T
#
# If scanpy warns about non-unique gene names after loading, de-duplicate
# them before QC:
#   adata.var_names_make_unique()

# Nothing is loaded yet -- remind the user which step to complete
print("Uncomment the appropriate data loader for your file format")

## 4. Calculate QC Metrics

In [None]:
# Compute per-cell / per-gene QC metrics (pass species='mouse' for mouse data)
# adata = calculate_qc_metrics(adata, species='human')
#
# Quick look at dataset size and the QC metric distributions:
# print("\nDataset summary:")
# print(f"  Cells: {adata.n_obs}")
# print(f"  Genes: {adata.n_vars}")
# print("\nQC metrics summary:")
# print(adata.obs[['total_counts', 'n_genes_by_counts', 'pct_counts_mt', 'pct_counts_ribo']].describe())

print("Uncomment after loading data")

## 5. Visualize QC Metrics

In [None]:
# Draw the QC overview grid and save it as a PNG in FIGURES
# FIGURES.mkdir(parents=True, exist_ok=True)
# fig = plot_qc_metrics(adata, sample_name=geo_id)
# fig.savefig(FIGURES / f'{geo_id}_qc_metrics.png', dpi=150, bbox_inches='tight')
# plt.show()

print("Uncomment after loading data")

In [None]:
# Per-metric violin plots (one panel per metric) via scanpy's plotting API
# sc.pl.violin(
#     adata,
#     ['n_genes_by_counts', 'total_counts', 'pct_counts_mt', 'pct_counts_ribo'],
#     jitter=0.4,
#     multi_panel=True,
# )

print("Uncomment after loading data")

## 6. Apply Filters

In [None]:
# Remove low-quality cells and rarely detected genes using the config thresholds
# print(f"\nApplying QC filters to {geo_id}...")
# adata_filtered = filter_cells_genes(adata, qc_params)

print("Uncomment after loading data")

In [None]:
# Re-draw the QC grid on the filtered data to confirm the thresholds took effect
# fig = plot_qc_metrics(adata_filtered, sample_name=f"{geo_id} (filtered)")
# fig.savefig(FIGURES / f'{geo_id}_qc_metrics_filtered.png', dpi=150, bbox_inches='tight')
# plt.show()

print("Uncomment after filtering")

## 7. Save Filtered Data

In [None]:
# Persist the filtered AnnData as .h5ad for the downstream notebooks
# output_path = DATA_PROCESSED / f'{geo_id}_filtered.h5ad'
# output_path.parent.mkdir(parents=True, exist_ok=True)
# adata_filtered.write(output_path)
# print(f"Saved filtered data to: {output_path}")

print("Uncomment after filtering")

## 8. Batch Processing Template

Use this template to process multiple datasets in batch.

In [None]:
def process_dataset(geo_id, data_path, output_dir, params, loader=None, species='human'):
    """
    Process a single dataset through QC pipeline.

    Steps when a ``loader`` is supplied:
    1. load the raw data;
    2. record its size;
    3. compute QC metrics (``calculate_qc_metrics``);
    4. filter (``filter_cells_genes``);
    5. write ``<output_dir>/<geo_id>_filtered.h5ad``.

    Without a ``loader`` the function only announces the dataset and returns
    None, because the raw-data format differs between datasets.

    Parameters
    ----------
    geo_id : str
        GEO accession ID
    data_path : Path
        Path to raw data; passed unchanged to ``loader``.
    output_dir : Path
        Output directory for processed data (created if missing).
    params : dict
        QC parameters (see ``filter_cells_genes``).
    loader : callable, optional
        ``loader(data_path) -> AnnData`` reading the raw data, e.g.
        ``lambda p: sc.read_10x_h5(p / 'filtered_feature_bc_matrix.h5')``.
    species : str
        Passed to ``calculate_qc_metrics`` (``'human'`` or ``'mouse'``).

    Returns
    -------
    dict or None
        Processing statistics, or None when no ``loader`` is given.
    """
    print(f"\nProcessing {geo_id}...")

    if loader is None:
        # Template mode: nothing to load for this dataset format yet
        return None

    adata = loader(data_path)
    # Record sizes *before* filtering so the stats reflect the raw data
    cells_before = adata.n_obs
    genes_before = adata.n_vars

    adata = calculate_qc_metrics(adata, species=species)
    adata_filtered = filter_cells_genes(adata, params)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f'{geo_id}_filtered.h5ad'
    adata_filtered.write(output_path)

    return {
        'geo_id': geo_id,
        'cells_before': cells_before,
        'cells_after': adata_filtered.n_obs,
        'genes_before': genes_before,
        'genes_after': adata_filtered.n_vars,
        'output_path': output_path,
    }

print("Batch processing function defined")

## 9. Summary

### Completed
- Defined QC metric calculation functions
- Created visualization utilities
- Established filtering pipeline

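### Key Parameters Used
Thresholds are read from the `qc` section of `config/analysis_params.yaml` (their values are printed in the Setup cell):
- Min genes per cell: `min_genes`
- Max genes per cell: `max_genes`
- Min counts per cell: `min_counts`
- Max counts per cell: `max_counts`
- Max mitochondrial %: `max_mito_pct`
- Min cells per gene: `min_cells`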

### Next Steps
1. Process all downloaded datasets
2. Continue to `02b_normalization_hvg.ipynb` for normalization
3. Run doublet detection in `02c_doublet_detection.ipynb`