# 03_Preprocessing_and_Dataset_Preparation

This notebook prepares the scRNA-seq dataset for model training:
- Downsample to ~20k cells
- Train/validation/test split
- Normalize counts (train-based)
- Log-transform
- Select highly variable genes (HVGs)
- Scale data
- Save preprocessed splits

In [1]:
!pwd

/mnt/d/Study/Python Scripts/scimilarity-finetune


# Imports

In [2]:
import scanpy as sc
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from scipy import sparse as sp
import os, json
import joblib

from app.utils import load_raw_data, split_data
from app.logger_config import get_logger
logger = get_logger()


# Set random seed
SEED = 42
np.random.seed(SEED)

# Display settings
pd.set_option('display.max_columns', None)
pd.set_option('display.max_colwidth', None)

2025-10-19 12:01:57 | logger_config.py:get_logger:41 | INFO | Logger initialized successfully.
2025-10-19 12:01:57 | logger_config.py:get_logger:41 | INFO | Logger initialized successfully.


# Load Raw Dataset

In [3]:
adata = load_raw_data()
logger.info(f"Original dataset shape: {adata.shape}")

2025-10-19 12:02:41 | data_utils.py:load_raw_data:32 | INFO | ✅ Loaded raw data: 65479 cells x 31460 genes.
2025-10-19 12:02:41 | 2035203210.py:<module>:2 | INFO | Original dataset shape: (65479, 31460)


# Downsample to ~20k cells

In [4]:
n_cells_target = 20000
if adata.n_obs > n_cells_target:
    idx = np.random.choice(adata.n_obs, n_cells_target, replace=False)
    adata = adata[idx, :]

logger.info(f"Downsampled dataset shape: {adata.shape}")

2025-10-19 12:02:49 | 700216280.py:<module>:6 | INFO | Downsampled dataset shape: (20000, 31460)


# Train/Validation/Test Split

In [5]:
train_data, val_data, test_data = split_data(adata, test_frac=0.3, random_state=SEED)

logger.info(f"Train: {train_data.n_obs}, Val: {val_data.n_obs}, Test: {test_data.n_obs}")

2025-10-19 12:03:02 | data_utils.py:split_data:166 | INFO | ✅ Data split: Train=14000, Val=3000, Test=3000
2025-10-19 12:03:02 | 2499359727.py:<module>:3 | INFO | Train: 14000, Val: 3000, Test: 3000


# Normalize and Log-Transform

In [6]:
# Compute total counts per cell 
train_data.obs["n_counts"] = np.array(train_data.X.sum(axis=1)).flatten()

# Normalize training data
sc.pp.normalize_total(train_data, target_sum=1e4)
sc.pp.log1p(train_data)

mean_train_total = np.mean(train_data.obs["n_counts"])
print(f"Mean total counts per cell (train): {mean_train_total:.2f}")

# Apply same normalization to val/test 
for split_data, name in zip([val_data, test_data], ['val', 'test']):
    split_data.obs["n_counts"] = np.array(split_data.X.sum(axis=1)).flatten()
    
    # Ensure sparse CSR format
    if not sp.isspmatrix_csr(split_data.X):
        split_data.X = sp.csr_matrix(split_data.X)
    
    # Compute scaling factors (per cell)
    scaling_factors = (1e4 / split_data.obs["n_counts"].to_numpy()) * (mean_train_total / 1e4)
    
    # Apply normalization 
    X_scaled = split_data.X.multiply(scaling_factors[:, None])
    split_data.X = sp.csr_matrix(X_scaled)  # ✅ ensures type compatibility
    
    # Log1p transform
    split_data.X.data = np.log1p(split_data.X.data)
    
    print(f"{name} normalized using train-based scaling.")

Mean total counts per cell (train): 2628.77
val normalized using train-based scaling.
test normalized using train-based scaling.


# Scale using **train mean and std**

In [7]:
# train_means and train_stds computed as before
train_means = np.array(train_data.X.mean(axis=0)).flatten()
X_sq_mean = np.array(train_data.X.power(2).mean(axis=0)).flatten()
train_stds = np.sqrt(X_sq_mean - np.square(train_means))
train_stds[train_stds == 0] = 1  # avoid divide-by-zero

def sparse_scale_std_only(adata_split, std, clip_val=10):
    """
    Sparse-safe standardization: divide each column by train std.
    Does NOT subtract mean (centering) to keep sparsity.
    """
    # ensure CSR format
    if not sp.isspmatrix_csr(adata_split.X):
        adata_split.X = sp.csr_matrix(adata_split.X)

    # create a copy to avoid modifying the view
    X = adata_split.X.copy()

    # multiply each column by 1/std (broadcast across columns)
    # sparse CSR multiply expects shape (n_rows, n_cols) and 1D array of length n_cols
    # but multiply only works for elementwise multiplication along same shape
    # so we convert std to a diagonal sparse matrix for safe multiplication
    D = sp.diags(1 / std)  # shape (n_genes, n_genes)
    X_scaled = X @ D        # sparse matrix multiplication
    # clip
    X_scaled.data = np.clip(X_scaled.data, -clip_val, clip_val)

    adata_split.X = X_scaled

for split_data, name in zip([train_data, val_data, test_data], ["train", "val", "test"]):
    sparse_scale_std_only(split_data, train_stds)
    print(f"{name} scaled using train std.")


train scaled using train std.
val scaled using train std.
test scaled using train std.


# Encode Cell Type Labels

In [8]:
# Initialize encoder
le = LabelEncoder()

# Fit on train labels only
train_labels = train_data.obs['cell_type'].to_numpy()
le.fit(train_labels)

# Transform labels
train_data.obs['cell_type_encoded'] = le.transform(train_labels)
val_data.obs['cell_type_encoded'] = le.transform(val_data.obs['cell_type'].to_numpy())
test_data.obs['cell_type_encoded'] = le.transform(test_data.obs['cell_type'].to_numpy())

# Save processed datasets and artifacts

In [9]:
# directories
os.makedirs("data/processed", exist_ok=True)
os.makedirs("artifacts", exist_ok=True)

# Save processed splits
train_data.write_h5ad("data/processed/train_data.h5ad")
val_data.write_h5ad("data/processed/val_data.h5ad")
test_data.write_h5ad("data/processed/test_data.h5ad")

# Save preprocessing artifacts for inference
np.save("artifacts/gene_order.npy", adata.var_names.to_numpy())
np.save("artifacts/train_means.npy", train_means) 
np.save("artifacts/train_stds.npy", train_stds)
np.save("artifacts/mean_train_total.npy", mean_train_total)

# Save encoder for inference
joblib.dump(le, "artifacts/label_encoder.joblib")

print("✅ Saved train/val/test splits and preprocessing artifacts.")

✅ Saved train/val/test splits and preprocessing artifacts.
