In [3]:
import scanpy as sc
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import csr_matrix, issparse # Corrected import
import seaborn as sns # For heatmap visualization of crosstab
import os
import harmony
import re
import scvi

In [4]:
# Input files for GEO series GSE181919: the UMI count matrix and the per-barcode metadata.
# The directory can be overridden with the GSE181919_DIR environment variable so the notebook
# can run outside the original cluster home directory; the default is the original location.
gse_dir = os.environ.get(
    'GSE181919_DIR',
    '/storage/praha1/home/bucekl/labgenexp/spatial_project/sc/GSE181919',
)
metadata_file = os.path.join(gse_dir, 'GSE181919_Barcode_metadata.txt')
counts_file = os.path.join(gse_dir, 'GSE181919_UMI_counts.txt')

In [None]:
# Analysis parameters: tune these here rather than inside the later cells.
n_top_hvg = 3000         # Number of highly variable genes (seurat_v3 selection)
n_pca_comps = 50         # Number of principal components
n_pcs_neighbors = 30     # Number of Harmony-corrected dimensions used for neighbor finding
batch_key = 'sample.id'  # adata.obs column used as the Harmony batch key
# scanpy settings
sc.settings.verbosity = 3
sc.logging.print_header()
sc.settings.set_figure_params(dpi=80, facecolor='white')

# Output directory for integration plots. Create the directory that is actually defined here
# (the previous call referenced an undefined name and raised NameError on a fresh kernel).
# NOTE: the UMAPs saved later use scanpy's save=, which writes to sc.settings.figdir instead.
output_plot_dir_ml = "./integration_plots_ml/"
os.makedirs(output_plot_dir_ml, exist_ok=True)

In [None]:
#### data loading
# Load the UMI count matrix and the per-barcode metadata, harmonise the barcode spelling
# between the two files, and attach the metadata as adata.obs (restricted to shared barcodes).

print("\n--- Loading Data ---")
# Load UMI counts (genes x cells) and transpose (-> cells x genes)
print(f"Loading UMI counts from: {counts_file}")
try:
    adata = sc.read_text(counts_file, delimiter='\t', first_column_names=True).T
    print(f"Loaded counts. Initial shape: {adata.shape} (Cells x Genes)")
    # Ensure counts are sparse (CSR) before any layer copies are made.
    if not issparse(adata.X):
        adata.X = csr_matrix(adata.X)
    print(f"Count matrix type: {type(adata.X)}")
except FileNotFoundError:
    print(f"ERROR: Counts file not found at {counts_file}")
    raise
except Exception as e:
    print(f"ERROR: Failed to load counts file: {e}")
    raise

# Fix barcode format (dot to hyphen) in counts data.
# NOTE(review): presumably the counts header was written by a tool that turned '-' into '.'
# (e.g. R's check.names); compare with the metadata barcodes printed below.
print("Original first 5 cell names:", adata.obs_names[:5].tolist())
# Scan every barcode (vectorised), not just the first 100, so the correction is not skipped
# when the leading barcodes happen to contain no '.'.
if adata.obs_names.str.contains('.', regex=False).any():
    print("Attempting to fix barcode mismatch: Replacing '.' with '-' in adata.obs_names...")
    adata.obs_names = adata.obs_names.str.replace('.', '-', regex=False)
    print("Corrected first 5 cell names:", adata.obs_names[:5].tolist())
else:
    print("Barcode names in counts matrix do not appear to contain '.' - skipping replacement.")
# Barcode lookups below (intersection / .loc) assume unique labels on both sides.
if not adata.obs_names.is_unique:
    raise ValueError("Cell barcodes in the counts matrix are not unique (after '.' -> '-' correction).")

# Load metadata
print(f"Loading metadata from: {metadata_file}")
try:
    metadata = pd.read_csv(metadata_file, sep='\t', index_col=0)
    print(f"Loaded metadata. Shape: {metadata.shape}")
    print("Metadata columns:", metadata.columns.tolist())
    print("First 5 metadata index names:", metadata.index[:5].tolist())
except FileNotFoundError:
    print(f"ERROR: Metadata file not found at {metadata_file}")
    raise
except Exception as e:
    print(f"ERROR: Failed to load metadata: {e}")
    raise
if not metadata.index.is_unique:
    raise ValueError("Metadata contains duplicated cell barcodes; cannot align it to the counts matrix.")

# Merge metadata
print("Merging metadata with AnnData object...")
common_cells = adata.obs_names.intersection(metadata.index)
print(f"Found {len(common_cells)} common cells between counts and metadata.")

if len(common_cells) == 0:
    print("ERROR: No common cell barcodes found after attempting correction!")
    raise ValueError("Cell barcode mismatch persists.")
elif len(common_cells) < adata.n_obs or len(common_cells) < len(metadata):
    print("Warning: Subsetting AnnData and metadata to common cells.")
    adata = adata[common_cells, :].copy()
    metadata = metadata.loc[common_cells]  # Keep metadata aligned (same order as adata)
    print(f"Filtered AnnData shape: {adata.shape}")
else:
    print("All cells match. Ordering metadata to match AnnData.")
    # Reorder metadata to ensure exact match
    metadata = metadata.loc[adata.obs_names]

# Verify alignment BEFORE assigning: AnnData's obs setter adopts the DataFrame's index as
# obs_names, so a comparison made after the assignment could never fail.
assert metadata.index.equals(adata.obs_names), "ERROR: Mismatch after merging!"
adata.obs = metadata
print("Successfully merged metadata into adata.obs.")
print("adata.obs head:\n", adata.obs.head())

In [None]:
# raw counts saving
print("\n--- Saving Raw Counts ---")
# Important: Do this BEFORE normalization/transformation.
# Guarded so that re-running this cell after adata.X has been normalized/log-transformed
# does not overwrite the raw counts with processed values. Re-running the loading cell
# rebinds `adata` to a newly read object (presumably without a 'counts' layer), so the
# snapshot is taken again in that case.
if "counts" not in adata.layers:
    adata.layers["counts"] = adata.X.copy()
    print("Raw counts saved to adata.layers['counts']")
else:
    print("adata.layers['counts'] already present - keeping existing raw counts.")

In [None]:
### HVGs, normalization and log transform
# HVGs are selected on raw counts (flavor='seurat_v3' expects integer counts); normalization
# and log1p are then applied to adata.X, and the log-normalized matrix is snapshotted into .raw.
print(f"\n--- Finding Top {n_top_hvg} Highly Variable Genes (using RAW counts) ---")
# Start from the stored raw counts so re-running this cell does not normalize and log an
# already log-normalized adata.X a second time.
adata.X = adata.layers["counts"].copy()

sc.pp.highly_variable_genes(
    adata,
    layer='counts',       # Use the raw counts stored here
    n_top_genes=n_top_hvg,
    flavor='seurat_v3',
    subset=False,         # Keep all genes, just add the boolean flag to adata.var
)
# NOTE(review): with several samples, batch_key=batch_key here would select HVGs per batch
# before Harmony integration; consider it.
sc.pl.highly_variable_genes(adata, show=False)
# The HVG figure may have several panels; plt.title would only label the current (last) axes.
plt.suptitle('Highly Variable Genes (calculated on raw counts)')
plt.show()
print(f"Identified {adata.var['highly_variable'].sum()} highly variable genes.")

# total counts normalization
print("\n--- Normalizing Total Counts ---")
# Normalize based on the counts in adata.X (restored from the raw counts above)
sc.pp.normalize_total(adata, target_sum=1e4)  # Modifies adata.X

# log transformation
print("\n--- Log-Transforming Data ---")
sc.pp.log1p(adata)  # Modifies adata.X

# saving log
print("\n--- Saving Log-Normalized Data to .raw ---")
# Snapshot from an independent copy so later in-place changes to adata.X (e.g. scaling)
# cannot reach the stored log-normalized values.
adata.raw = adata.copy()
print("Log-normalized data stored in adata.raw")

In [None]:
### data scaling
# Scale the log-normalized expression of all genes (clipped at 10), then run PCA on the
# highly variable genes only and inspect the uncorrected batch structure.
print("\n--- Scaling Data (ALL Genes in current .X) ---")
# Rebuild adata.X from the log-normalized snapshot in adata.raw so that re-running this cell
# does not scale already scaled and clipped values a second time.
if adata.raw is None:
    raise RuntimeError("adata.raw is not set - run the normalization/log1p cell first.")
adata.X = adata.raw[:, adata.var_names].X.copy()
# NOTE(review): zero-centering presumably densifies the sparse matrix here
# (n_cells x n_genes dense); watch memory on large datasets.
sc.pp.scale(adata, max_value=10)
print("Data scaled (max_value=10).")

# PCA on scaled HVGs
print(f"\n--- Running PCA (n_comps={n_pca_comps}, using HVGs) ---")
# mask_var restricts the PCA to the genes flagged in adata.var['highly_variable']
sc.pp.pca(adata, n_comps=n_pca_comps, mask_var='highly_variable', svd_solver='arpack')
sc.pl.pca_variance_ratio(adata, log=True, n_pcs=n_pca_comps, show=False)
plt.title('PCA Variance Ratio (Before Integration)')
plt.show()
print("Visualizing PCA colored by batch (Before Integration)...")
sc.pl.pca(adata, color=batch_key, title=f'PCA Before Harmony (Colored by {batch_key})')
plt.show()

In [None]:
#### Harmony integration
# Batch-correct the PCA embedding across `batch_key` groups. The corrected coordinates are
# written to a separate obsm slot, so the uncorrected 'X_pca' remains available for comparison.
harmony_embedding_key = 'X_pca_harmony'
print(f"\n--- Running Harmony Integration (using batch key: '{batch_key}') ---")
harmony_kwargs = dict(key=batch_key, basis='X_pca', adjusted_basis=harmony_embedding_key)
sc.external.pp.harmony_integrate(adata, **harmony_kwargs)
print(
    "Harmony integration complete. Corrected embedding stored in "
    f"adata.obsm['{harmony_embedding_key}']"
)

# Visual check: after correction, samples should intermix rather than form separate islands.
print("Visualizing Harmony embedding colored by batch (After Integration)...")
harmony_title = f'Harmony Embedding (Colored by {batch_key})'
sc.pl.embedding(adata, basis=harmony_embedding_key, color=batch_key, title=harmony_title)
plt.show()

In [None]:
# compute neighbours using Harmony embeddings
# The kNN graph is built from the batch-corrected Harmony coordinates (n_pcs_neighbors
# dimensions), so the UMAP and clustering downstream reflect the integrated data.
print(f"\n--- Computing Neighbors (using Harmony embedding, {n_pcs_neighbors} dimensions) ---")
neighbor_params = {'n_neighbors': 15, 'n_pcs': n_pcs_neighbors, 'use_rep': harmony_embedding_key}
sc.pp.neighbors(adata, **neighbor_params)

# UMAP layout computed from the integrated neighbor graph above
print("\n--- Computing UMAP (based on integrated neighbors) ---")
sc.tl.umap(adata)
print("Visualizing Integrated UMAP...")
umap_colors = [batch_key, 'cell.type']
umap_titles = [f'Integrated UMAP (Colored by {color})' for color in umap_colors]
sc.pl.umap(adata, color=umap_colors, title=umap_titles)
plt.show()

In [None]:
### Leiden clustering
# Cluster the Harmony-based neighbor graph at several resolutions and compare each result with
# the published 'cell.type' annotation, then save and show UMAPs for one chosen resolution.
print("\n--- Running Leiden Clustering on Integrated Graph ---")
resolutions_to_test = [0.5, 1.0, 1.5, 2.5]
leiden_keys_generated = []  # adata.obs columns added below; reported if chosen_res is missing
for res in resolutions_to_test:
    cluster_key = f'leiden_integrated_res{res}'
    print(f"  Running Leiden resolution = {res} -> key='{cluster_key}'")
    sc.tl.leiden(adata, resolution=res, key_added=cluster_key, flavor="igraph", n_iterations=2)
    leiden_keys_generated.append(cluster_key)
    n_clusters = adata.obs[cluster_key].nunique()
    print(f"    Found {n_clusters} clusters.")

    # Optional: Visualize this specific clustering on UMAP
    sc.pl.umap(adata, color=[cluster_key, 'cell.type'],
               title=f'Leiden (res={res}) vs Metadata cell.type',
               legend_loc='on data', legend_fontsize=8)
    plt.show()

chosen_res = 2.5  # must be one of resolutions_to_test
chosen_leiden_key = f'leiden_integrated_res{chosen_res}'
# Legend beside the embedding instead of on it; used for both final plots for consistency.
umap_legend_loc = 'right margin'

if chosen_leiden_key in adata.obs.columns:
    # With save= and the default show, scanpy writes the figure into sc.settings.figdir
    # (e.g. "umap_leiden_integrated_res2.5.png") and then displays it. One call per plot
    # therefore both saves and shows it; no redraw or plt.clf() is needed.
    # If the legend is cut off, pass wspace=0.5-0.6.

    # --- Plot 1: Chosen Leiden Resolution ---
    print(f"Visualizing and saving UMAP for {chosen_leiden_key}...")
    print(f"  Using legend_loc='{umap_legend_loc}' for Leiden plot.")
    sc.pl.umap(
        adata,
        color=chosen_leiden_key,
        legend_loc=umap_legend_loc,
        legend_fontsize=8,
        title=f'Integrated Leiden Clusters (res={chosen_res})',
        save=f"_leiden_integrated_res{chosen_res}.png",
    )

    # --- Plot 2: Original 'cell.type' ---
    print("\nVisualizing and saving UMAP for original 'cell.type'...")
    print(f"  Using legend_loc='{umap_legend_loc}' for cell.type plot.")
    sc.pl.umap(
        adata,
        color='cell.type',
        legend_loc=umap_legend_loc,
        legend_fontsize=8,
        title='Original cell.type Annotation (Integrated UMAP)',
        save="_original_cell_type_integrated.png",
    )

    print(f"\nPlots saved to '{sc.settings.figdir}'.")
else:
    print(f"Chosen resolution key '{chosen_leiden_key}' not found in adata.obs.columns.")
    print(f"Available keys: {leiden_keys_generated}")

print("\n--- Integration and Re-Clustering Complete ---")
print("Review saved plots and proceed with annotation using the chosen Leiden key.")