In [None]:
# Import necessary libraries
import os
import sys
import numpy as np
import pandas as pd
import scanpy as sc
import scipy
import matplotlib.pyplot as plt
import seaborn as sns
import requests
import time
import math
import json
from tqdm.notebook import tqdm
import warnings
# Silence only the noisy deprecation chatter emitted by the scientific stack.
# A blanket `warnings.filterwarnings('ignore')` would also hide UserWarnings
# and other categories that can flag genuine data problems.
warnings.filterwarnings('ignore', category=FutureWarning)
warnings.filterwarnings('ignore', category=DeprecationWarning)

# Set plotting defaults
sc.settings.set_figure_params(dpi=100, facecolor='white')
plt.rcParams['figure.figsize'] = (10, 8)

In [None]:
# Function to download h5ad files from cellxgene
def download_h5ad_file(url, output_path, chunk_size=8192, timeout=60):
    """Stream a file from ``url`` to ``output_path`` unless it already exists.

    Data is first written to ``output_path + '.part'`` and renamed to
    ``output_path`` only after the stream finishes. An interrupted download
    therefore never leaves a truncated file that a later run would mistake
    for a complete one and skip.

    Args:
        url: HTTP(S) URL to download.
        output_path: Destination path. If it already exists, nothing is downloaded.
        chunk_size: Bytes per streamed chunk (default 8 KiB).
        timeout: Passed to ``requests.get``, so a stalled server raises instead
            of hanging the kernel indefinitely.

    Returns:
        ``output_path``.

    Raises:
        requests.HTTPError: if the server answers with an error status
            (via ``raise_for_status``). Any error during the transfer is
            re-raised after the partial file has been removed.
    """
    if os.path.exists(output_path):
        print(f"File already exists at {output_path}, skipping download")
        return output_path

    print(f"Downloading file from {url} to {output_path}")

    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    part_path = output_path + '.part'
    try:
        # Stream download to handle large files
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            # Missing content-length -> total=None, so tqdm shows progress without a bar total
            total_size = int(r.headers.get('content-length', 0))
            with open(part_path, 'wb') as f:
                with tqdm(total=total_size or None, unit='iB', unit_scale=True) as t:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        if not chunk:  # skip keep-alive chunks
                            continue
                        f.write(chunk)
                        t.update(len(chunk))
        # Publish the file only once it is complete.
        os.replace(part_path, output_path)
    except BaseException:
        # Drop the partial file (also on KeyboardInterrupt) so a retry starts clean.
        if os.path.exists(part_path):
            os.remove(part_path)
        raise

    print(f"Download complete: {output_path}")
    return output_path

In [None]:
# Ensure the raw and processed data folders exist before anything is written.
for data_dir in ('../data/raw', '../data/processed'):
    os.makedirs(data_dir, exist_ok=True)

# Source dataset (CELLxGENE), used here as a smaller PBMC dataset for exploration.
url = "https://datasets.cellxgene.cziscience.com/89619149-162f-4839-8e97-24735924417c.h5ad"
output_path = '../data/raw/pbmc_dataset.h5ad'

# No-op when the file is already present on disk.
download_h5ad_file(url, output_path)

In [None]:
# Read the downloaded file into an in-memory AnnData object.
adata = sc.read_h5ad(output_path)

# Overall size, then the shapes of the per-cell (obs) and per-gene (var) tables.
print(f"AnnData object with {adata.n_obs} cells and {adata.n_vars} genes")
for table_name, table in (("Observation", adata.obs), ("Variable", adata.var)):
    print(f"{table_name} data shape: {table.shape}")

In [None]:
# Summarize every observation-metadata column: cardinality plus its most common values.
print("Observation metadata columns:")
for col_name, values in adata.obs.items():
    n_unique = values.nunique()
    counts = values.value_counts()
    print(f"  - {col_name}: {n_unique} unique values")

    # Large vocabularies are abbreviated to the top 5; small ones are listed in full.
    if n_unique >= 20:
        print(f"    Top 5 values: {counts.head(5).to_dict()}")
    else:
        print(f"    Values: {counts.to_dict()}")

In [None]:
# Explore cell types: number of cells per annotated type, largest first.
cell_counts = adata.obs['cell_type'].value_counts()
# `cell_type` may be a categorical column. Passing plain strings plus an explicit
# `order` keeps the bars in value_counts' (descending) order rather than the
# column's category order.
cell_type_labels = cell_counts.index.astype(str)

fig, ax = plt.subplots(figsize=(14, 6))
sns.barplot(x=cell_type_labels, y=cell_counts.values, order=list(cell_type_labels), ax=ax)
ax.tick_params(axis='x', labelrotation=90)
ax.set_title('Cell Type Distribution')
ax.set_xlabel('Cell type')
ax.set_ylabel('Number of cells')
fig.tight_layout()
plt.show()

In [None]:
# Filter dataset to cells of the selected types whose `disease` label is 'normal'.
selected_cell_types = [
    'naive thymus-derived CD4-positive, alpha-beta T cell', 'classical monocyte', 
    'natural killer cell', 'naive B cell', 'CD4-positive helper T cell', 
    'CD8-positive, alpha-beta cytotoxic T cell', 'naive thymus-derived CD8-positive, alpha-beta T cell', 
    'central memory CD8-positive, alpha-beta T cell', 'non-classical monocyte', 
    'regulatory T cell'
]

# Both conditions go into one boolean mask, so the AnnData subset is copied
# exactly once (no intermediate copy of the larger type-only subset).
keep_mask = (
    adata.obs['cell_type'].isin(selected_cell_types)
    & adata.obs['disease'].isin(['normal'])
).to_numpy()
adata_filtered = adata[keep_mask].copy()

# Fail here with a clear message instead of obscurely in later cells.
assert adata_filtered.n_obs > 0, "No cells passed the cell_type/disease filter"

print(f"Filtered dataset has {adata_filtered.n_obs} cells and {adata_filtered.n_vars} genes")

In [None]:
# Basic preprocessing: QC filtering, then library-size normalization + log1p.
# NOTE(review): CELLxGENE h5ad files often keep normalized values in `X` and raw
# counts in `.raw` — confirm `adata_filtered.X` holds raw counts here; otherwise
# the normalize/log steps below re-transform already-normalized data.
#
# Guarded so re-running this cell is safe: `sc.pp.log1p` records itself in
# `.uns['log1p']`, and normalizing/logging a second time would silently
# transform the data twice.
if 'log1p' in adata_filtered.uns:
    print("adata_filtered is already normalized and log-transformed; skipping preprocessing")
else:
    # Filter cells with too few genes
    sc.pp.filter_cells(adata_filtered, min_genes=200)

    # Filter genes expressed in too few cells
    sc.pp.filter_genes(adata_filtered, min_cells=3)

    # QC metrics after filtering, so they describe only the retained cells/genes
    sc.pp.calculate_qc_metrics(adata_filtered, inplace=True)

    # Normalize total counts per cell to 10,000
    sc.pp.normalize_total(adata_filtered, target_sum=1e4)

    # Log transform the data: log(1 + x)
    sc.pp.log1p(adata_filtered)

print(f"Preprocessed dataset has {adata_filtered.n_obs} cells and {adata_filtered.n_vars} genes")

In [None]:
# Create standardized (identifier-friendly) cell type names
cell_type_mapping = {
    "naive thymus-derived CD4-positive, alpha-beta T cell": "naive_CD4_T_cell",
    "classical monocyte": "classical_monocyte",
    "natural killer cell": "natural_killer_cell",
    "naive B cell": "naive_B_cell",
    "CD4-positive helper T cell": "CD4_helper_T_cell",
    "CD8-positive, alpha-beta cytotoxic T cell": "CD8_cytotoxic_T_cell",
    "naive thymus-derived CD8-positive, alpha-beta T cell": "naive_CD8_T_cell",
    "central memory CD8-positive, alpha-beta T cell": "central_memory_CD8_T_cell",
    "non-classical monocyte": "non_classical_monocyte",
    "regulatory T cell": "regulatory_T_cell"
}

# Apply mapping; any label missing from the dict becomes NaN.
adata_filtered.obs['standardized_cell_type'] = adata_filtered.obs['cell_type'].map(cell_type_mapping)

# Fail fast: NaN labels would otherwise only surface later (e.g. when sorting
# the label set or looking them up during label encoding).
unmapped = adata_filtered.obs.loc[adata_filtered.obs['standardized_cell_type'].isna(), 'cell_type']
assert unmapped.empty, f"cell types missing from cell_type_mapping: {sorted(unmapped.astype(str).unique())}"

# Check results
adata_filtered.obs['standardized_cell_type'].value_counts()

In [None]:
# UMAP visualization of cells: HVG selection -> PCA -> kNN graph -> UMAP.
sc.pp.highly_variable_genes(adata_filtered, n_top_genes=2000)
sc.pp.pca(adata_filtered, n_comps=50, use_highly_variable=True)
sc.pp.neighbors(adata_filtered, n_neighbors=10, n_pcs=40)
sc.tl.umap(adata_filtered)

# Plot UMAP colored by cell type. Drawing into our own axes with show=False
# avoids sc.pl.umap opening (and showing) a separate figure, which would leave
# a pre-created plt.figure() empty.
fig, ax = plt.subplots(figsize=(12, 10))
sc.pl.umap(adata_filtered, color='standardized_cell_type', frameon=False,
           legend_loc='on data', ax=ax, show=False)
fig.tight_layout()
plt.show()

In [None]:
# Now let's prepare the data in the Spoticell format (spatial matrices).
# Each gene gets a fixed (row, col) slot in a square matrix: genes are sorted
# alphabetically and laid out row-major.
gene_names = adata_filtered.var_names.tolist()
n_genes = len(gene_names)
# Duplicate names would share one slot and overwrite each other's values.
assert len(set(gene_names)) == n_genes, "var_names must be unique to get one matrix slot per gene"

matrix_size = math.ceil(math.sqrt(n_genes))
print(f"Matrix size will be {matrix_size}x{matrix_size} for {len(gene_names)} genes")

# Alphabetical slot index of each gene
sorted_genes = sorted(gene_names)
gene_to_idx = {gene: i for i, gene in enumerate(sorted_genes)}

# gene_positions[slot] = (row, col) of that slot, filled row-major
gene_positions = np.column_stack(np.divmod(np.arange(n_genes), matrix_size)).astype(np.int32)

# gene_idx_map[original var index] = alphabetical slot. Direct lookup, so any
# inconsistency raises KeyError instead of silently landing on slot 0.
gene_idx_map = np.array([gene_to_idx[gene] for gene in gene_names], dtype=np.int32)

In [None]:
# Visualize the mapping of a cell's expression to matrix format
cell_idx = 0  # Get the first cell

# This cell's expression as a 1-D array; np.asarray(...).ravel() also flattens
# a 2-D dense row (e.g. when X is an np.matrix).
if scipy.sparse.issparse(adata_filtered.X):
    cell_expr = adata_filtered.X[cell_idx].toarray().ravel()
else:
    cell_expr = np.asarray(adata_filtered.X[cell_idx]).ravel()

# Scatter the non-zero genes into their alphabetical (row, col) slots in one
# vectorized assignment; zero genes keep the initial 0.
matrix = np.zeros((matrix_size, matrix_size), dtype=np.float32)
non_zero_indices = np.flatnonzero(cell_expr)
rows, cols = gene_positions[gene_idx_map[non_zero_indices]].T
matrix[rows, cols] = cell_expr[non_zero_indices]

# Plot the matrix
cell_type_label = adata_filtered.obs["standardized_cell_type"].iloc[cell_idx]
fig, ax = plt.subplots(figsize=(10, 10))
im = ax.imshow(matrix, cmap='viridis')
fig.colorbar(im, ax=ax, label='Gene Expression')
ax.set_title(f'Gene Expression Matrix for Cell {cell_idx}\nCell Type: {cell_type_label}')
fig.tight_layout()
plt.show()

print(f"Cell has {len(non_zero_indices)} non-zero genes out of {len(gene_names)}")
print(f"Matrix shape: {matrix.shape}")
print(f"Sum of expression values: {np.sum(matrix):.4f}")

In [None]:
# Let's visualize one example matrix (the first cell in obs order) per cell type
unique_cell_types = adata_filtered.obs['standardized_cell_type'].unique()
max_panels = 10  # capacity of the 2x5 grid below

fig, axes = plt.subplots(2, 5, figsize=(20, 8))
axes = axes.flatten()
# Hide every panel up front so grid slots without a cell type stay blank
# instead of showing empty axes frames.
for ax in axes:
    ax.axis('off')

cell_type_values = adata_filtered.obs['standardized_cell_type'].to_numpy()

for ax, cell_type in zip(axes, unique_cell_types[:max_panels]):
    # Row index of the first cell of this type (no per-type AnnData view needed)
    type_rows = np.flatnonzero(cell_type_values == cell_type)
    if type_rows.size == 0:
        continue

    row_data = adata_filtered.X[type_rows[0]]
    if scipy.sparse.issparse(row_data):
        row_data = row_data.toarray()
    cell_expr = np.asarray(row_data).ravel()

    # Scatter non-zero genes into their alphabetical (row, col) slots
    matrix = np.zeros((matrix_size, matrix_size), dtype=np.float32)
    non_zero_indices = np.flatnonzero(cell_expr)
    rows, cols = gene_positions[gene_idx_map[non_zero_indices]].T
    matrix[rows, cols] = cell_expr[non_zero_indices]

    ax.imshow(matrix, cmap='viridis')
    ax.set_title(cell_type, fontsize=10)

fig.tight_layout()
fig.suptitle('Expression Matrices by Cell Type', fontsize=16, y=1.02)
plt.show()

In [None]:
# Define a simple pytorch model (without training) to verify the architecture
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Conv -> BatchNorm, summed with a skip path, then ReLU.

    The skip path is ``nn.Identity`` when the channel counts match and the
    stride is 1; otherwise it is a 1x1 convolution with the same stride that
    projects the input to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Submodules are registered in the order conv, bn, relu, residual; this
        # order fixes parameter initialization order and state_dict keys.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        shortcut = self.residual(x)
        out = self.bn(self.conv(x))
        # In-place ReLU acts on the freshly allocated sum, never on `x`.
        return self.relu(out + shortcut)

class SpoticellModelDemo(nn.Module):
    """Minimal Spoticell-style CNN classifier, used only to check the architecture.

    Three ``ConvBlock`` stages (c, 2c, 4c channels) with 2x2 max-pooling between
    them, global average pooling, then a single linear layer producing logits.

    Args:
        matrix_size: Side length of the square input matrix. No layer uses it
            (the adaptive pooling makes the head independent of input size);
            kept for interface compatibility.
        num_classes: Number of output logits.
        cnn_channels: Channel count of the first stage.
    """

    def __init__(self, matrix_size, num_classes, cnn_channels=16):
        super().__init__()
        c1, c2, c3 = cnn_channels, cnn_channels * 2, cnn_channels * 4

        # List order defines the Sequential indices (cnn_path.0 ... cnn_path.5).
        stages = [
            ConvBlock(1, c1),
            nn.MaxPool2d(2),
            ConvBlock(c1, c2),
            nn.MaxPool2d(2),
            ConvBlock(c2, c3),
            nn.AdaptiveAvgPool2d((1, 1)),
        ]
        self.cnn_path = nn.Sequential(*stages)
        self.classifier = nn.Linear(c3, num_classes)

    def forward(self, x):
        """Map a single-channel image batch to per-class logits (batch, num_classes)."""
        features = torch.flatten(self.cnn_path(x), start_dim=1)
        return self.classifier(features)

In [None]:
# Test model with a batch of randomly chosen cells
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed both RNGs so the sampled cells and the random initial weights are reproducible.
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# Number of unique cell types
num_classes = len(adata_filtered.obs['standardized_cell_type'].unique())

# Create model. eval() makes BatchNorm use (and not update) its running
# statistics; in train mode even a no_grad forward pass would update them from
# this tiny batch, and outputs would depend on the batch composition.
model = SpoticellModelDemo(matrix_size, num_classes).to(device)
model.eval()

# Build a batch of (1, matrix_size, matrix_size) inputs
batch_size = 4
test_matrices = []

for cell_idx in rng.integers(0, adata_filtered.n_obs, size=batch_size):
    cell_idx = int(cell_idx)

    # This cell's expression as a 1-D array
    if scipy.sparse.issparse(adata_filtered.X):
        cell_expr = adata_filtered.X[cell_idx].toarray().ravel()
    else:
        cell_expr = np.asarray(adata_filtered.X[cell_idx]).ravel()

    # Scatter non-zero genes into their alphabetical (row, col) slots
    matrix = np.zeros((matrix_size, matrix_size), dtype=np.float32)
    non_zero_indices = np.flatnonzero(cell_expr)
    rows, cols = gene_positions[gene_idx_map[non_zero_indices]].T
    matrix[rows, cols] = cell_expr[non_zero_indices]

    test_matrices.append(matrix[np.newaxis])  # add channel dimension

# Convert to a float32 tensor of shape (batch, 1, matrix_size, matrix_size)
test_batch = torch.from_numpy(np.stack(test_matrices)).to(device)

# Forward pass
with torch.no_grad():
    outputs = model(test_batch)

print(f"Input batch shape: {test_batch.shape}")
print(f"Output shape: {outputs.shape}")
print("Model successfully processes matrices!")

In [None]:
# Save a small preprocessed dataset for testing
test_data_dir = '../data/processed/test_data'
os.makedirs(test_data_dir, exist_ok=True)

# Expression matrix, stored sparse (a dense X is converted to CSR first)
if scipy.sparse.issparse(adata_filtered.X):
    X_sparse = adata_filtered.X
else:
    X_sparse = scipy.sparse.csr_matrix(adata_filtered.X)
scipy.sparse.save_npz(os.path.join(test_data_dir, 'expression_data.npz'), X_sparse)

# Gene layout needed to rebuild the per-cell matrices
np.save(os.path.join(test_data_dir, 'gene_positions.npy'), gene_positions)
np.save(os.path.join(test_data_dir, 'gene_idx_map.npy'), gene_idx_map)
np.save(os.path.join(test_data_dir, 'sorted_genes.npy'), np.array(sorted_genes))
np.save(os.path.join(test_data_dir, 'original_genes.npy'), np.array(gene_names))

# One-hot cell type labels; column k is the k-th cell type in alphabetical order
unique_cell_types = sorted(adata_filtered.obs['standardized_cell_type'].unique())
cell_type_to_idx = {cell_type: idx for idx, cell_type in enumerate(unique_cell_types)}

# Vectorized encoding: category codes follow unique_cell_types, so code == cell_type_to_idx[label]
label_codes = pd.Categorical(
    adata_filtered.obs['standardized_cell_type'].to_numpy(), categories=unique_cell_types
).codes
assert (label_codes >= 0).all(), "every cell must have a known standardized_cell_type"
cell_labels = np.eye(len(unique_cell_types), dtype=np.float32)[label_codes]

np.save(os.path.join(test_data_dir, 'cell_labels.npy'), cell_labels)

# Save metadata
metadata = {
    'n_cells': int(adata_filtered.n_obs),
    'n_genes': int(adata_filtered.n_vars),
    'matrix_size': int(matrix_size),
    'cell_type_to_idx': {k: int(v) for k, v in cell_type_to_idx.items()},
    'idx_to_cell_type': {str(idx): cell_type for idx, cell_type in enumerate(unique_cell_types)}
}

with open(os.path.join(test_data_dir, 'metadata.json'), 'w') as f:
    json.dump(metadata, f, indent=4)

print("Test dataset saved successfully!")

In [None]:
# Final summary of the filtered dataset that was exported above.
type_counts = adata_filtered.obs['standardized_cell_type'].value_counts()
n_cells = adata_filtered.n_obs

summary_lines = [
    "Dataset Summary",
    "=" * 50,
    f"Total cells: {n_cells}",
    f"Total genes: {adata_filtered.n_vars}",
    f"Matrix size: {matrix_size}x{matrix_size}",
    f"Cell types: {len(unique_cell_types)}",
    "\nCell type distribution:",
]
# One line per cell type with its absolute count and share of all cells.
summary_lines.extend(
    f"  - {cell_type}: {count} cells ({100 * count / n_cells:.2f}%)"
    for cell_type, count in type_counts.items()
)
print("\n".join(summary_lines))