In [28]:
import scanpy as sc
import scanpy.external as sce
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import anndata as ad
from scipy import sparse
from scipy.spatial import cKDTree
from scipy.stats import percentileofscore
import warnings
import logging
import sys
import bbknn
import gc

In [29]:
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Suppress specific warnings
warnings.filterwarnings("ignore", category=FutureWarning)
sc.settings.verbosity = 1


In [30]:
# bulk_path='/private/groups/russelllab/jodie/wolbachia_induced_DE/scanpy_clustering/scanpy_objects/bulk_adata.h5ad'
# ref_path='/private/groups/russelllab/jodie/wolbachia_induced_DE/scanpy_clustering/scanpy_objects/blood_adata.h5ad'
# output_dir='/private/groups/russelllab/jodie/wolbachia_induced_DE/wolbachia_induced_differentiation/scripts/celltype_clustering/claude/MNN_kNN_tree/blood_atlas'
# annotation_key='subclustering'
# k_neighbors=None
# num_permutations=1000
# seed=42

bulk_path='/private/groups/russelllab/jodie/wolbachia_induced_DE/scanpy_clustering/scanpy_objects/bulk_adata.h5ad'
ref_path='/private/groups/russelllab/jodie/wolbachia_induced_DE/scanpy_clustering/scanpy_objects/combined_germline_sg_trachea.h5ad'
output_dir='/private/groups/russelllab/jodie/wolbachia_induced_DE/wolbachia_induced_differentiation/scripts/celltype_clustering/claude/MNN_kNN_tree/embryo_atlas_germline'
annotation_key='subtypes'
k_neighbors=None
num_permutations=1000
seed=42
mem = 1024

In [31]:
# Color map to match final figures
color_dict={
    'JW18DOX':'#87de87', # green
    'JW18wMel':'#00aa44',  # dark green
    'S2DOX':'#ffb380', # orange
    'S2wMel':'#d45500' # dark orange

}

In [32]:
"""Set up output directory and plotting parameters."""
np.random.seed(seed)

# Create output directories
os.makedirs(output_dir, exist_ok=True)
plots_dir = os.path.join(output_dir, 'plots')
os.makedirs(plots_dir, exist_ok=True)

# Set up log file
log_file = os.path.join(output_dir, 'integration_log.txt')
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
logger.addHandler(file_handler)

# Set scanpy settings
sc.settings.figdir = plots_dir
sc.settings.set_figure_params(dpi=300, frameon=False, figsize=(10, 8), facecolor='white')

# Define custom color palette for cell types
# custom_palette = sns.color_palette("husl", 100)  # Generate a large color palette

In [None]:

def load_and_validate_data(bulk_path, ref_path):
    """Load and validate input AnnData objects."""
    logger.info("Loading data files...")
    
    # Load bulk data
    try:
        bulk_adata = sc.read_h5ad(bulk_path)
        logger.info(f"Bulk dataset loaded: {bulk_adata.shape} (samples × genes)")
    except Exception as e:
        logger.error(f"Error loading bulk data: {e}")
        raise
    
    # Load reference data
    try:
        ref_adata = sc.read_h5ad(ref_path)
        logger.info(f"Reference dataset loaded: {ref_adata.shape} (cells × genes)")
    except Exception as e:
        logger.error(f"Error loading reference data: {e}")
        raise
    
    # Ensure unique gene names
    bulk_adata.var_names_make_unique()
    ref_adata.var_names_make_unique()
    
    # Check for shared genes
    shared_genes = bulk_adata.var_names.intersection(ref_adata.var_names)
    if len(shared_genes) == 0:
        logger.error("No shared genes between bulk and reference datasets!")
        raise ValueError("No shared genes between datasets")
    else:
        logger.info(f"Number of shared genes: {len(shared_genes)}")
    
    return bulk_adata, ref_adata


def preprocess_data(bulk_adata, ref_adata, annotation_key):
    """Preprocess AnnData objects and prepare for integration."""
    logger.info("Preprocessing datasets...")
    
    # Make copies to avoid modifying the originals
    bulk = bulk_adata.copy()
    ref = ref_adata.copy()
    
    # Add dataset labels for batch correction
    bulk.obs["dataset"] = "bulk"
    ref.obs["dataset"] = "reference"
    
    # Find shared genes
    shared_genes = bulk.var_names.intersection(ref.var_names)
    logger.info(f"Using {len(shared_genes)} shared genes")
    
    # Subset to shared genes
    bulk_subset = bulk[:, shared_genes].copy()
    ref_subset = ref[:, shared_genes].copy()
    
    # Check if annotation key exists in reference data
    if annotation_key not in ref_subset.obs.columns:
        logger.error(f"Annotation key '{annotation_key}' not found in reference data.")
        available_keys = list(ref_subset.obs.columns)
        logger.error(f"Available keys: {available_keys}")
        raise KeyError(f"Annotation key '{annotation_key}' not found in reference data")
    
    return bulk_subset, ref_subset


def integrate_datasets(bulk_adata, ref_adata):
    """Integrate bulk and single-cell datasets."""
    logger.info("Integrating datasets...")
    
    # Concatenate datasets
    combined = ad.concat([ref_adata, bulk_adata], join="outer", merge="first")
    logger.info(f"Combined dataset shape: {combined.shape}")
    
    # Since data is already normalized and log-transformed, we skip those steps
    logger.info("Data is already normalized and log-transformed")
    
    # Ensure no NaN values that could cause issues
    if sparse.issparse(combined.X):
        combined.X = sparse.csr_matrix(np.nan_to_num(combined.X.toarray(), nan=0, posinf=0, neginf=0))
    else:
        combined.X = np.nan_to_num(combined.X, nan=0, posinf=0, neginf=0)
    
    # # # Apply BBKNN batch correction
    # logger.info("Applying BBKNN batch correction...")
    # # Perform batch correction with BBKNN
    # bbknn.bbknn(combined, batch_key='dataset')

    # Apply batch correction using MNN
    logger.info("Applying MNN batch correction...")
    try:
        corrected = sce.pp.mnn_correct(combined, batch_key="dataset", return_only_var_genes=False)
        # MNN returns a tuple with the corrected AnnData as the first element
        corrected_adata = corrected[0]
        logger.info("MNN batch correction completed")
        print(type(corrected_adata))
        return corrected_adata
    except Exception as e:
        logger.error(f"MNN batch correction failed: {e}")
        logger.warning("Continuing without batch correction")
        return combined
    # return combined
    

def compute_p_value(neighbor_labels, assigned_label, k, num_permutations=1000):
    """
    Compute p-value by shuffling labels and checking how often
    the assigned label appears by chance.
    """
    # Convert to numpy array for efficient operations
    neighbor_labels = np.array(neighbor_labels)
    simulated_counts = []
    
    for _ in range(num_permutations):
        shuffled_labels = np.random.permutation(neighbor_labels)
        simulated_counts.append((shuffled_labels == assigned_label).sum() / k)
    
    observed_prob = (neighbor_labels == assigned_label).sum() / k
    p_value = (100 - percentileofscore(simulated_counts, observed_prob)) / 100
    
    return p_value


def determine_optimal_k(ref_adata, annotation_key):
    """Determine the optimal k value for kNN classification."""
    num_ref_cells = ref_adata.shape[0]
    
    # Get the size of the smallest class
    try:
        min_class_size = ref_adata.obs[annotation_key].value_counts().min()
    except:
        logger.warning("Could not compute min class size, using default")
        min_class_size = 100
    
    # Calculate potential k values:
    # 1. Square root of number of reference cells
    # 2. 10% of smallest class size
    k_sqrt = int(np.sqrt(num_ref_cells))
    k_10pct = int(min_class_size * 0.1)
    
    # Use the smaller of the two values, but ensure k is at least 5
    k = max(5, min(k_sqrt, k_10pct))
    
    logger.info(f"Automatically determined k = {k} (sqrt(n) = {k_sqrt}, 10% of smallest class = {k_10pct})")
    
    return k


# def kNN_classifier(combined_adata, ref_label_key, k, num_permutations=1000):
#     """
#     Classify bulk cells based on their k nearest neighbors in the reference dataset.
#     """
#     logger.info(f"Performing kNN classification with k={k}...")
    
#     # Identify reference and bulk cells
#     ref_indices = combined_adata.obs["dataset"] == "reference"
#     bulk_indices = combined_adata.obs["dataset"] == "bulk"
    
#     # Get indices as arrays
#     ref_idx = np.where(ref_indices)[0]
#     bulk_idx = np.where(bulk_indices)[0]
    
#     logger.info(f"Reference cells: {len(ref_idx)}, Bulk samples: {len(bulk_idx)}")
    
#     # Extract data for classification
#     try:
#         # Handle sparse matrices if needed
#         if sparse.issparse(combined_adata.X):
#             X_ref = combined_adata.X[ref_idx].toarray()
#             X_bulk = combined_adata.X[bulk_idx].toarray()
#         else:
#             X_ref = combined_adata.X[ref_idx]
#             X_bulk = combined_adata.X[bulk_idx]
        
#         # Handle any NaNs or infs
#         X_ref = np.nan_to_num(X_ref, nan=0, posinf=0, neginf=0)
#         X_bulk = np.nan_to_num(X_bulk, nan=0, posinf=0, neginf=0)
        
#         # Build kd-tree for efficient nearest neighbor search
#         tree = cKDTree(X_ref)
#         distances, neighbor_idx = tree.query(X_bulk, k=k)
        
#         # Get reference cell labels
#         ref_labels = combined_adata.obs.loc[ref_indices, ref_label_key].values
        
#         results = []
#         for i, (dists, neighbors) in enumerate(zip(distances, neighbor_idx)):
#             # Get labels of k nearest neighbors
#             neighbor_labels = ref_labels[neighbors]
            
#             # Determine the most frequent label (majority vote)
#             unique_labels, counts = np.unique(neighbor_labels, return_counts=True)
#             assigned_label = unique_labels[np.argmax(counts)]
#             max_count = counts[np.argmax(counts)]
            
#             # Calculate confidence score (percentage of neighbors with the assigned label)
#             confidence = max_count / k
            
#             # Compute p-value with permutation test
#             p_value = compute_p_value(neighbor_labels, assigned_label, k, num_permutations)
            
#             # Store results
#             results.append({
#                 "Bulk_Sample": combined_adata.obs.index[bulk_idx[i]],
#                 "Predicted_Label": assigned_label,
#                 "Confidence": confidence,
#                 "P_value": p_value,
#                 "Nearest_Neighbors": list(combined_adata.obs.index[ref_idx[neighbors]]),
#                 "Neighbor_Labels": list(neighbor_labels),
#                 "Neighbor_Distances": list(dists)
#             })
        
#         # Create results DataFrame
#         results_df = pd.DataFrame(results)
        
#         # Add predicted labels to the combined dataset
#         for i, idx in enumerate(bulk_idx):
#             combined_adata.obs.loc[combined_adata.obs.index[idx], ref_label_key] = results_df.iloc[i]["Predicted_Label"]
        
#         logger.info("kNN classification completed")
#         return results_df, combined_adata
    
#     except Exception as e:
#         logger.error(f"kNN classification failed: {e}")
#         import traceback
#         logger.error(traceback.format_exc())
#         raise



# kNN-based classification
def kNN_classifier_with_stats(combined_adata, ref_label_key, k):
    """
    Classifies unknown cells based on their k nearest neighbors in the reference dataset.
    Reports neighbor indices, distances, cell type assignments, and statistical confidence.

    Returns:
    - DataFrame with predicted labels, neighbor information, and z-scores.
    """

    print("Performing kNN-based classification...")

    ref_indices = combined_adata.obs["dataset"] == "Reference"
    unknown_indices = combined_adata.obs["dataset"] == "Unknown"

    # Extract gene expression space
    X_ref = combined_adata[ref_indices].X.toarray() if hasattr(combined_adata.X, "toarray") else combined_adata[ref_indices].X
    X_unknown = combined_adata[unknown_indices].X.toarray() if hasattr(combined_adata.X, "toarray") else combined_adata[unknown_indices].X

    # Construct k-d tree for nearest neighbor searching
    tree = cKDTree(X_ref)
    distances, indices = tree.query(X_unknown, k=k)

    # Retrieve reference cell labels
    ref_labels = combined_adata.obs.loc[ref_indices, ref_label_key].values

    results = []
    for i, (dists, neighbor_indices) in enumerate(zip(distances, indices)):
        neighbor_labels = ref_labels[neighbor_indices]
        assigned_label = pd.Series(neighbor_labels).mode()[0]

        # Compute probability
        cell_type_counts = pd.Series(neighbor_labels).value_counts()
        expected_prob = cell_type_counts / k
        observed_prob = expected_prob.loc[assigned_label]
        z_score = (observed_prob - expected_prob.mean()) / expected_prob.std()

        # Compute p-value with permutation test
        p_value = compute_p_value(neighbor_labels, assigned_label, k)

        results.append({
            "Unknown_Cell": combined_adata.obs.index[unknown_indices][i],
            "Predicted_Label": assigned_label,
            "Neighbor_Cells": list(combined_adata.obs.index[ref_indices][neighbor_indices]),
            "Neighbor_Types": list(neighbor_labels),
            "Distances": list(dists),
            "Z-score": z_score,
            "P-value": p_value
        })

    return pd.DataFrame(results), combined_adata

def visualize_integration(combined_adata, annotation_key, plots_dir):
    """Generate UMAP visualizations of the integrated data."""
    logger.info("Generating UMAP visualizations...")
    
    # Compute PCA
    sc.pp.pca(combined_adata, svd_solver='arpack')
    
    # Compute neighborhood graph
    sc.pp.neighbors(combined_adata, n_neighbors=15, n_pcs=30)
    
    # Compute UMAP embedding
    sc.tl.umap(combined_adata)
    
    # Save plots
    sc.pl.umap(combined_adata, color='dataset', title='Dataset (bulk vs reference)',
               save='_dataset.pdf')
    
    sc.pl.umap(combined_adata, color=annotation_key, title=f'Cell Types ({annotation_key})',
               save=f'_{annotation_key}.pdf')
    
    # Run leiden clustering
    sc.tl.leiden(combined_adata, resolution=0.8)
    sc.pl.umap(combined_adata, color='leiden', title='Leiden Clusters',
               save='_leiden.pdf')
    
    # Create a custom plot highlighting bulk samples
    fig, ax = plt.subplots(figsize=(10, 8))
    
    # Plot reference cells in gray
    ref_mask = combined_adata.obs['dataset'] == 'reference'
    ax.scatter(
        combined_adata.obsm['X_umap'][ref_mask, 0],
        combined_adata.obsm['X_umap'][ref_mask, 1],
        c='lightgray', s=5, alpha=0.5, label='Reference'
    )
    
    # Plot bulk samples with distinct colors based on their predicted cell type
    bulk_mask = combined_adata.obs['dataset'] == 'bulk'
    bulk_cell_types = combined_adata.obs.loc[bulk_mask, annotation_key].astype('category')
    
    for ct in bulk_cell_types.cat.categories:
        ct_mask = (combined_adata.obs['dataset'] == 'bulk') & (combined_adata.obs[annotation_key] == ct)
        ax.scatter(
            combined_adata.obsm['X_umap'][ct_mask, 0],
            combined_adata.obsm['X_umap'][ct_mask, 1],
            s=100, alpha=0.9, label=f'Bulk - {ct}'
        )
    
    ax.set_title('UMAP - Bulk Samples Highlighted')
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, 'bulk_samples_highlighted.pdf'), bbox_inches='tight')
    plt.close()
    
    logger.info("UMAP visualizations completed")
    
    return combined_adata



In [34]:


def subsample_celltypes(adata, annotation_key, max_cells_per_type=1000):
    """Subsample cell types to ensure even representation with a max cap."""
    # Get the cell type counts
    celltype_counts = adata.obs[annotation_key].value_counts()

    # Get the minimum cell type count, but cap at max_cells_per_type
    min_count = min(celltype_counts.min(), max_cells_per_type)

    # Create a list to store the subsampled AnnData objects
    subsampled_adatas = []

    # Iterate over each cell type
    for celltype in celltype_counts.index:
        # Get the indices for the current cell type
        celltype_indices = adata.obs[annotation_key] == celltype
        celltype_idx = np.where(celltype_indices)[0]
        
        # Determine sample size (minimum of actual count or min_count)
        sample_size = min(len(celltype_idx), min_count)
        
        # Subsample the current cell type
        if len(celltype_idx) > sample_size:
            chosen_idx = np.random.choice(celltype_idx, sample_size, replace=False)
            subset_idx = np.zeros(adata.shape[0], dtype=bool)
            subset_idx[chosen_idx] = True
            subsampled_adata = adata[subset_idx].copy()
        else:
            subsampled_adata = adata[celltype_indices].copy()
            
        # Append the subsampled AnnData object to the list
        subsampled_adatas.append(subsampled_adata)
        
        # Force garbage collection
        gc.collect()

    # Concatenate the subsampled AnnData objects
    logger.info(f"Concatenating {len(subsampled_adatas)} subsampled datasets")
    subsampled_adata = ad.concat(subsampled_adatas, join='inner', index_unique='-')
    
    # Clean up to free memory
    del subsampled_adatas
    gc.collect()

    return subsampled_adata

In [35]:

def plot_UMAP(combined_adata, annotation):
    # Perform UMAP
    # neighbors_rank(combined_adata)

    # Plot UMAP, save to file
    sc.pl.umap(combined_adata, color=['dataset'], save='combine-dataset.pdf')
    # Visualization
    sc.pl.umap(combined_adata, color=[annotation], save='combined-tissue.pdf')

    #Annotate UMAP with larger markers
    # Plot the UMAP
    sc.pl.umap(combined_adata, color=[annotation], show=False)

    # Get the UMAP coordinates
    umap_coords = combined_adata.obsm['X_umap']

    # Get your sample labels from the data
    labels = combined_adata.obs['Sample']

    # Iterate over each point and add a label if it's not NA
    for idx, label in enumerate(labels):
        if pd.notna(label):  # Check if the label is not NA

            plt.plot(umap_coords[idx, 0], umap_coords[idx, 1], color=color_dict[label[:-8]], marker='o', markersize=3, alpha=0.5)
        # Remove background grid and ticks for a cleaner look
    plt.grid(False)

    # Adjust layout to fit the legend outside
    plt.tight_layout(rect=[0, 0, 0.85, 1])  # Leaves space for legend on the right
    plt.savefig('combined_dataset_samples_and_tissue.pdf', dpi=600, bbox_inches="tight")
    plt.close() 

    # # Save the plot
    # plt.savefig('marker_gene_clustering.pdf', dpi=600)
    # plt.close()

    # # Plot Leiden clustering
    # sc.pl.umap(combined_adata, color=['leiden'], save='leiden_clustering.pdf')

    # # Identify marker genes
    # identify_marker_genes(combined_adata, annotation)

def identify_marker_genes(adata, annotation):
    sc.tl.rank_genes_groups(adata, annotation, method='wilcoxon') #Find marker genes by tissue instead of by leiden clustering (done earlier)
    marker_genes = adata.uns['rank_genes_groups']['names']

    # Get the top N genes for each cluster
    n_top_genes = 5
    top_genes = pd.DataFrame(marker_genes).iloc[:n_top_genes]

    sc.pl.rank_genes_groups_dotplot(
        adata,
        groupby=annotation,  # Use tissue labels for grouping instead of 'leiden'
        n_genes=4,
        values_to_plot="logfoldchanges", cmap='bwr', #changed from 'viridis'  
        vmin=-4,
        vmax=4,
        min_logfoldchange=3,
        colorbar_title='log fold change'
    )
    plt.savefig('marker_genes_by_tissue.pdf')
    plt.close()

In [36]:

# Log start of processing
logger.info("Starting Scanpy bulk to single-cell integration pipeline")
logger.info(f"Bulk data: {bulk_path}")
logger.info(f"Reference data: {ref_path}")
logger.info(f"Output directory: {output_dir}")

# try:
# Load and validate data
bulk_adata, ref_adata = load_and_validate_data(bulk_path, ref_path)

# Preprocess data
bulk_processed, ref_processed = preprocess_data(bulk_adata, ref_adata, annotation_key)

# Sample even cell types across the reference
ref_processed=subsample_celltypes(ref_processed, annotation_key)




2025-03-17 14:27:30 - INFO - Starting Scanpy bulk to single-cell integration pipeline
2025-03-17 14:27:30 - INFO - Bulk data: /private/groups/russelllab/jodie/wolbachia_induced_DE/scanpy_clustering/scanpy_objects/bulk_adata.h5ad
2025-03-17 14:27:30 - INFO - Reference data: /private/groups/russelllab/jodie/wolbachia_induced_DE/scanpy_clustering/scanpy_objects/combined_germline_sg_trachea.h5ad
2025-03-17 14:27:30 - INFO - Output directory: /private/groups/russelllab/jodie/wolbachia_induced_DE/wolbachia_induced_differentiation/scripts/celltype_clustering/claude/MNN_kNN_tree/embryo_atlas_germline
2025-03-17 14:27:30 - INFO - Loading data files...
2025-03-17 14:27:30 - INFO - Bulk dataset loaded: (24, 10957) (samples × genes)
2025-03-17 14:27:30 - INFO - Reference dataset loaded: (3335, 10083) (cells × genes)
2025-03-17 14:27:30 - INFO - Number of shared genes: 8325
2025-03-17 14:27:30 - INFO - Preprocessing datasets...
2025-03-17 14:27:31 - INFO - Using 8325 shared genes
2025-03-17 14:27:3

In [None]:
# Integrate datasets
combined_adata = integrate_datasets(bulk_processed, ref_processed)

combined_adata=combined_adata[0] #necesasry to extract the first element of the tuple of MNN corrected data


plot_UMAP(combined_adata, annotation_key)

2025-03-17 14:27:35 - INFO - Integrating datasets...
2025-03-17 14:27:35 - INFO - Combined dataset shape: (804, 8325)
2025-03-17 14:27:35 - INFO - Data is already normalized and log-transformed
2025-03-17 14:27:35 - INFO - Applying MNN batch correction...
2025-03-17 14:27:35 - INFO - MNN batch correction completed


<class 'anndata._core.anndata.AnnData'>


AttributeError: 'tuple' object has no attribute '_sanitize'

In [None]:

# Determine optimal k if not provided
k = k_neighbors
if k is None:
    k = determine_optimal_k(ref_processed, annotation_key)

# Run kNN classification
results_df, annotated_adata = kNN_classifier(
    combined_adata,
    ref_label_key=annotation_key,
    k=k,
    num_permutations=num_permutations
)

# Generate visualizations
# annotated_adata = visualize_integration(annotated_adata, annotation_key, plots_dir)

# Extract bulk annotations and save to files
bulk_samples = annotated_adata[annotated_adata.obs['dataset'] == 'bulk']
bulk_annotations = bulk_samples.obs[[annotation_key]]

# Save results
results_df.to_csv(os.path.join(output_dir, 'bulk_classification_results.csv'), index=False)
bulk_annotations.to_csv(os.path.join(output_dir, 'bulk_annotations.csv'))
annotated_adata.write_h5ad(os.path.join(output_dir, 'annotated_data.h5ad'))

logger.info("Integration pipeline completed successfully")
logger.info(f"Results saved to {output_dir}")

# Plot UMAP
logger.info("Plotting UMAP...")
# plot_UMAP(combined_adata, annotation_key)


# except Exception as e:
# logger.error(f"Integration pipeline failed: {e}")
# import traceback
# logger.error(traceback.format_exc())
# sys.exit(1)


2025-03-17 11:11:59 - INFO - Automatically determined k = 21 (sqrt(n) = 32, 10% of smallest class = 21)


NameError: name 'kNN_classifier' is not defined

In [None]:
sc.pl.umap(combined_adata, color=[annotation_key], show=False)

# Get the UMAP coordinates
umap_coords = combined_adata.obsm['X_umap']

# Get your sample labels from the data
labels = combined_adata.obs['Sample']

# Iterate over each point and add a label if it's not NA
for idx, label in enumerate(labels):
    if pd.notna(label):  # Check if the label is not NA

        plt.plot(umap_coords[idx, 0], umap_coords[idx, 1], 
                color=color_dict[label[:-8]], 
                marker='o',            # '*' for star marker
                markersize=10,          # Increase size for better visibility
                markeredgecolor='white', # White outline
                markeredgewidth=0.25,   # Width of the outline
                alpha=1)    # Remove background grid and ticks for a cleaner look
plt.grid(False)

# Adjust layout to fit the legend outside
plt.tight_layout(rect=[0, 0, 0.85, 1])  # Leaves space for legend on the right
plt.savefig(f'{output_dir}/combined_dataset_samples_and_tissue.pdf', dpi=600, bbox_inches="tight")
plt.show() 

plt.close()

In [None]:
def plot_enhanced_umap(combined_adata, annotation_key, plots_dir):
    """Generate enhanced UMAP visualizations that better show differences between groups."""
    # Create a figure with two subplots side by side
    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    
    # Plot 1: Reference cells with bulk cells highlighted by sample group
    ax = axes[0]
    
    # Plot reference cells in gray first
    ref_mask = combined_adata.obs['dataset'] == 'reference'
    ax.scatter(
        combined_adata.obsm['X_umap'][ref_mask, 0],
        combined_adata.obsm['X_umap'][ref_mask, 1],
        c='lightgray', s=5, alpha=0.3, label='Reference Cells'
    )
    
    # Get sample groups (assuming Sample column format like "JW18DOX221117-1")
    bulk_mask = combined_adata.obs['dataset'] == 'bulk'
    sample_groups = []
    for sample_name in combined_adata.obs.loc[bulk_mask, 'Sample']:
        # Extract the prefix (e.g., "JW18DOX", "JW18wMel", "S2DOX", "S2wMel")
        if pd.notna(sample_name):
            prefix = sample_name.split('221117')[0]  # Remove date and number suffix
            sample_groups.append(prefix)
        else:
            sample_groups.append('Unknown')
    
    # Add Sample group as a new column
    combined_adata.obs.loc[bulk_mask, 'SampleGroup'] = sample_groups
    
    # Plot bulk samples with colors based on their sample group
    for group, color in color_dict.items():
        group_mask = (combined_adata.obs['dataset'] == 'bulk') & (combined_adata.obs['SampleGroup'] == group)
        ax.scatter(
            combined_adata.obsm['X_umap'][group_mask, 0],
            combined_adata.obsm['X_umap'][group_mask, 1],
            c=color, s=100, alpha=0.9, label=f'{group}',
            edgecolors='white', linewidths=0.5
        )
    
    ax.set_title('UMAP - Samples Colored by Experimental Condition')
    ax.legend(loc='upper right')
    ax.grid(False)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    
    # Plot 2: Reference cells colored by cell type, bulk samples as larger points
    ax = axes[1]
    
    # Create a colormap for cell types
    cell_types = combined_adata.obs[annotation_key].cat.categories
    n_cell_types = len(cell_types)
    colors = plt.cm.tab20(np.linspace(0, 1, n_cell_types))
    
    # Plot reference cells colored by cell type
    for i, cell_type in enumerate(cell_types):
        ct_mask = (combined_adata.obs['dataset'] == 'reference') & (combined_adata.obs[annotation_key] == cell_type)
        ax.scatter(
            combined_adata.obsm['X_umap'][ct_mask, 0],
            combined_adata.obsm['X_umap'][ct_mask, 1],
            c=[colors[i]], s=10, alpha=0.6, label=f'{cell_type}'
        )
    
    # Plot bulk samples as larger points
    for i, cell_type in enumerate(cell_types):
        ct_mask = (combined_adata.obs['dataset'] == 'bulk') & (combined_adata.obs[annotation_key] == cell_type)
        ax.scatter(
            combined_adata.obsm['X_umap'][ct_mask, 0],
            combined_adata.obsm['X_umap'][ct_mask, 1],
            c=[colors[i]], s=150, alpha=1.0, edgecolors='black', linewidths=0.5
        )
    
    ax.set_title(f'UMAP - Cell Types ({annotation_key})')
    ax.legend(loc='upper right')
    ax.grid(False)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, 'enhanced_umap_visualization.pdf'), bbox_inches='tight', dpi=300)
    plt.close()
    
    # Create a visualization to verify KNN classification results
    verify_knn_visualization(combined_adata, annotation_key, plots_dir)
    
    return combined_adata

def verify_knn_visualization(combined_adata, annotation_key, plots_dir):
    """Create a visualization to verify KNN classification results for a few samples."""
    bulk_mask = combined_adata.obs['dataset'] == 'bulk'
    ref_mask = combined_adata.obs['dataset'] == 'reference'
    
    # Get sample groups
    sample_groups = combined_adata.obs.loc[bulk_mask, 'SampleGroup'].unique()
    
    # Select one sample from each group to visualize
    selected_samples = []
    for group in sample_groups:
        group_samples = combined_adata.obs.loc[bulk_mask & (combined_adata.obs['SampleGroup'] == group)].index
        if len(group_samples) > 0:
            selected_samples.append(group_samples[0])
    
    # Create a figure with subplots for each selected sample
    fig, axes = plt.subplots(len(selected_samples), 1, figsize=(12, 6*len(selected_samples)))
    if len(selected_samples) == 1:
        axes = [axes]  # Make axes iterable if only one subplot
    
    # For each selected sample
    for i, sample_id in enumerate(selected_samples):
        ax = axes[i]
        
        # Get sample info
        sample_idx = combined_adata.obs.index.get_loc(sample_id)
        sample_group = combined_adata.obs.loc[sample_id, 'SampleGroup']
        predicted_label = combined_adata.obs.loc[sample_id, annotation_key]
        
        # Plot all reference cells as background
        ax.scatter(
            combined_adata.obsm['X_umap'][ref_mask, 0],
            combined_adata.obsm['X_umap'][ref_mask, 1],
            c='lightgray', s=5, alpha=0.2
        )
        
        # Plot cells of the predicted cell type
        pred_mask = ref_mask & (combined_adata.obs[annotation_key] == predicted_label)
        ax.scatter(
            combined_adata.obsm['X_umap'][pred_mask, 0],
            combined_adata.obsm['X_umap'][pred_mask, 1],
            c='blue', s=20, alpha=0.5, label=f'Reference {predicted_label} cells'
        )
        
        # Plot the sample itself
        ax.scatter(
            combined_adata.obsm['X_umap'][sample_idx, 0],
            combined_adata.obsm['X_umap'][sample_idx, 1],
            c=color_dict[sample_group], s=200, alpha=1.0, 
            edgecolors='black', linewidths=1.0,
            label=f'{sample_id} ({sample_group})'
        )
        
        # Get the K nearest neighbors for this sample from the results file
        # Note: This would need the results DataFrame to be passed as an argument
        # For visualization purposes, we'll just use the 10 nearest reference cells based on UMAP distance
        
        # Calculate distances in UMAP space
        sample_umap = combined_adata.obsm['X_umap'][sample_idx]
        ref_umaps = combined_adata.obsm['X_umap'][ref_mask]
        
        # Calculate Euclidean distances
        distances = np.sqrt(np.sum((ref_umaps - sample_umap)**2, axis=1))
        
        # Get indices of 10 nearest neighbors
        nearest_indices = np.argsort(distances)[:10]
        nearest_indices = np.where(ref_mask)[0][nearest_indices]
        
        # Plot nearest neighbors
        ax.scatter(
            combined_adata.obsm['X_umap'][nearest_indices, 0],
            combined_adata.obsm['X_umap'][nearest_indices, 1],
            c='red', s=80, alpha=0.7, 
            edgecolors='black', linewidths=0.5,
            marker='*', label='Nearest neighbors (UMAP space)'
        )
        
        ax.set_title(f'Sample: {sample_id} - Group: {sample_group} - Predicted: {predicted_label}')
        ax.legend(loc='upper right')
        ax.grid(False)
        ax.set_xlabel('UMAP 1')
        ax.set_ylabel('UMAP 2')
    
    plt.tight_layout()
    plt.savefig(os.path.join(plots_dir, 'knn_verification.pdf'), bbox_inches='tight', dpi=300)
    plt.close()





In [None]:
plot_enhanced_umap(combined_adata, annotation_key, plots_dir)

In [None]:
verify_knn_visualization(combined_adata, annotation_key, plots_dir)
