In [None]:
# sub group cell type csv
import pandas as pd 


def create_cell_type_dictionary(csv_file_path, separator='.'):
    """Map each main cell type to the list of its sub-types.

    The CSV must contain ``Main_Types`` and ``Sub_Types`` columns, where
    ``Sub_Types`` holds the sub-type names joined by ``separator``.

    Args:
        csv_file_path: Anything accepted by ``pandas.read_csv`` (path or buffer).
        separator: String separating sub-type names inside ``Sub_Types``
            (defaults to ``'.'``).

    Returns:
        dict mapping each main type to a list of sub-type strings. A row whose
        ``Sub_Types`` cell is empty maps to ``[]``. If a main type appears on
        several rows, the last row wins.
    """
    df = pd.read_csv(csv_file_path)

    cell_type_dict = {}
    # Iterate the two columns directly instead of building a Series per row.
    for main_type, sub_types in zip(df['Main_Types'], df['Sub_Types']):
        if pd.isna(sub_types):
            # Empty CSV cells are read as NaN (a float), which has no .split().
            cell_type_dict[main_type] = []
        else:
            cell_type_dict[main_type] = str(sub_types).split(separator)

    return cell_type_dict

# Build the main-type -> sub-type lookup from the grouped sub-type table.
csv_data = '/orange/pinaki.sarder/j.fermin/SpatNet/Data/Counts/Cell_SubTypes_Grouped.csv'
result = create_cell_type_dictionary(csv_data)

# Header line, a 50-character rule, then the mapping itself.
print("Cell Type Dictionary:", "=" * 50, sep="\n")
print(result)



#

In [None]:
# Main Visium_processor.py


import scanpy as sc
import numpy as np
import pandas as pd
import torch
import cv2
from PIL import Image
import os
from pathlib import Path
import openslide
from openslide import OpenSlide
import argparse
from tqdm import tqdm
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
import h5py
import zarr
import time
from functools import partial
import gc

class OptimizedWSIProcessor:
    """Extract fixed-size RGB patches centred on coordinates of a whole-slide image.

    Two back ends are used, chosen by file extension:

    * ``.tif``: the whole image is decoded into memory with OpenCV (converted
      to RGB) and patches are sliced directly. No resolution rescaling is
      applied, and ``scale_factor`` stays 1.0.
    * anything else: the slide is opened with OpenSlide. Each patch is read
      from level 0 and resized to ``patch_size`` so that its resolution
      approximates ``target_mpp``.

    Args:
        wsi_path: Path (str or path-like) to the slide file.
        patch_size: Edge length in pixels of every returned patch.
        target_mpp: Desired microns-per-pixel of the returned patches
            (OpenSlide back end only).
    """

    def __init__(self, wsi_path, patch_size=224, target_mpp=0.5):
        self.wsi_path = wsi_path
        self.patch_size = patch_size
        self.target_mpp = target_mpp
        self.slide = None
        self.scale_factor = 1.0  # current_mpp / target_mpp (OpenSlide only)
        self.current_mpp = 0.25  # default until read from slide properties
        # str() so pathlib.Path inputs work as well as plain strings.
        self.is_tiff = str(wsi_path).endswith('.tif')
        self.tiff_image = None

        self._initialize_slide()

    def _initialize_slide(self):
        """Open the slide once; for OpenSlide files also compute the scale factor.

        Raises:
            ValueError: if a ``.tif`` file cannot be decoded by OpenCV.
            Any exception raised by OpenSlide is logged and re-raised.
        """
        try:
            if self.is_tiff:
                # Load once and keep in memory so extraction is a plain slice.
                image = cv2.imread(str(self.wsi_path))
                if image is None:
                    raise ValueError(f"Could not read TIFF file: {self.wsi_path}")
                # cv2.imread returns BGR; convert so TIFF patches are RGB like
                # the OpenSlide path and are saved with correct colours by PIL.
                self.tiff_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                print(f"Loaded TIFF image: {self.tiff_image.shape}")
            else:
                self.slide = OpenSlide(self.wsi_path)
                self.scale_factor, self.current_mpp = self._calculate_scale_factor()
                print(f"SVS slide initialized. Scale factor: {self.scale_factor:.3f}, Current MPP: {self.current_mpp:.3f}")
        except Exception as e:
            print(f"Error initializing slide: {e}")
            raise

    def _calculate_scale_factor(self):
        """Return ``(scale_factor, current_mpp)`` for the open slide.

        ``current_mpp`` is read from the ``openslide.mpp-x`` property, then
        ``aperio.MPP``, and otherwise defaults to 0.25.
        ``scale_factor = current_mpp / target_mpp``; for example, a 0.25 mpp
        scan with a 0.5 mpp target gives 0.5. Any parsing problem, including
        a non-positive MPP, falls back to ``(1.0, 0.25)``.
        """
        try:
            props = self.slide.properties
            if 'openslide.mpp-x' in props:
                current_mpp = float(props['openslide.mpp-x'])
            elif 'aperio.MPP' in props:
                current_mpp = float(props['aperio.MPP'])
            else:
                current_mpp = 0.25
                print(f"Warning: MPP not found. Using default: {current_mpp}")

            # A zero/negative MPP would later cause a division by zero.
            if current_mpp <= 0:
                raise ValueError(f"non-positive MPP: {current_mpp}")

            scale_factor = current_mpp / self.target_mpp
            return scale_factor, current_mpp
        except Exception as e:
            print(f"Error calculating scale factor: {e}")
            return 1.0, 0.25

    def extract_batch_patches(self, coordinates_batch):
        """Extract one patch per ``(x, y)`` coordinate.

        A patch that fails to extract is logged and replaced by an all-black
        ``(patch_size, patch_size, 3)`` uint8 array, so the result always has
        one entry per input coordinate.
        """
        patches = []

        for x, y in coordinates_batch:
            try:
                if self.is_tiff:
                    patch = self._extract_tiff_patch(x, y)
                else:
                    patch = self._extract_svs_patch(x, y)
                patches.append(patch)
            except Exception as e:
                print(f"Error extracting patch at ({x}, {y}): {e}")
                patches.append(np.zeros((self.patch_size, self.patch_size, 3), dtype=np.uint8))

        return patches

    def _extract_tiff_patch(self, x, y):
        """Slice a ``patch_size`` window centred on (x, y) from the in-memory TIFF.

        Windows that hit the image border are clipped, and the clipped window
        is resized back to ``patch_size`` (this stretches edge patches).
        NOTE(review): padding instead of stretching may be preferable. Confirm
        against downstream expectations.
        """
        half_patch = self.patch_size // 2
        x_start = max(0, int(x - half_patch))
        y_start = max(0, int(y - half_patch))
        x_end = min(self.tiff_image.shape[1], int(x + half_patch))
        y_end = min(self.tiff_image.shape[0], int(y + half_patch))

        patch = self.tiff_image[y_start:y_end, x_start:x_end]

        if patch.shape[:2] != (self.patch_size, self.patch_size):
            patch = cv2.resize(patch, (self.patch_size, self.patch_size))

        return patch

    def _extract_svs_patch(self, x, y):
        """Read a level-0 region centred on (x, y) and resize it to ``patch_size``.

        A ``patch_size`` patch at ``target_mpp`` covers
        ``patch_size * target_mpp`` microns. At level 0 that is
        ``patch_size * target_mpp / current_mpp = patch_size / scale_factor``
        pixels, so that many pixels are read and then resized to
        ``patch_size``.
        """
        original_patch_size = max(1, int(round(self.patch_size / self.scale_factor)))
        half_patch = original_patch_size // 2

        patch_pil = self.slide.read_region(
            (int(x - half_patch), int(y - half_patch)),
            0,
            (original_patch_size, original_patch_size)
        )

        # read_region returns RGBA; drop alpha before resizing.
        patch_rgb = patch_pil.convert('RGB')
        patch_resized = patch_rgb.resize((self.patch_size, self.patch_size), Image.LANCZOS)
        return np.array(patch_resized)

    def __del__(self):
        """Close the OpenSlide handle, if one was opened."""
        # getattr: __del__ can run on a partially constructed instance.
        slide = getattr(self, 'slide', None)
        if slide is not None:
            slide.close()

def process_batch_worker(batch_data):
    """Extract and save the patch and cell-type tensor for every spot in a batch.

    ``batch_data`` is a 7-tuple:
    ``(wsi_path, coordinates, spot_ids, cell_type_vectors, output_dirs,
    patch_size, target_mpp)``.

    For each spot, ``<spot_id>.png`` is written to ``output_dirs['patches']``
    and a float32 tensor ``<spot_id>.pt`` to ``output_dirs['cell_types']``.

    Returns:
        list of bool, one per spot: True if both files were written, False if
        saving raised (the error is printed).
    """
    (wsi_path, coords, ids, vectors,
     out_dirs, size, mpp) = batch_data

    # A fresh processor (and slide handle) is opened for every batch.
    extractor = OptimizedWSIProcessor(wsi_path, size, mpp)
    extracted = extractor.extract_batch_patches(coords)

    statuses = []
    for patch, spot_id, vector in zip(extracted, ids, vectors):
        try:
            png_target = out_dirs['patches'] / f"{spot_id}.png"
            Image.fromarray(patch).save(png_target)

            pt_target = out_dirs['cell_types'] / f"{spot_id}.pt"
            torch.save(torch.tensor(vector, dtype=torch.float32), pt_target)
        except Exception as e:
            print(f"Error saving spot {spot_id}: {e}")
            statuses.append(False)
        else:
            statuses.append(True)

    return statuses

def save_batch_hdf5(patches_batch, spot_ids_batch, cell_types_batch, hdf5_path):
    """Write one HDF5 group per spot, opening the file in append mode.

    Each group is named after the spot id and contains two gzip-compressed
    datasets, ``patch`` and ``cell_type``.

    NOTE(review): ``create_group`` is called unconditionally, so writing the
    same spot id twice into one file will presumably fail. Confirm against
    h5py behaviour before reusing an existing file.
    """
    with h5py.File(hdf5_path, 'a') as h5:
        for patch, spot_id, cell_type in zip(patches_batch, spot_ids_batch, cell_types_batch):
            group = h5.create_group(spot_id)
            for dataset_name, payload in (('patch', patch), ('cell_type', cell_type)):
                group.create_dataset(dataset_name, data=payload, compression='gzip')

def _build_cell_type_vectors(spot_ids, df_cell_types):
    """Return one cell-type vector per spot id, in the order of ``spot_ids``.

    ``df_cell_types`` has one row per cell type and one column per spot id.
    A spot without a column gets an all-zero vector whose length equals the
    number of rows, so every saved tensor has the same size.
    """
    n_cell_types = df_cell_types.shape[0]
    return [
        df_cell_types[spot_id].values
        if spot_id in df_cell_types.columns
        else np.zeros(n_cell_types)
        for spot_id in spot_ids
    ]


def process_visium_data_optimized(h5ad_file, wsi_file, csv_file, output_dir,
                                  patch_size=224, target_mpp=0.5, batch_size=100,
                                  n_workers=None, save_hdf5=True):
    """Extract a WSI patch and a cell-type tensor for every Visium spot.

    Spot ids come from ``adata.obs.index`` and coordinates from
    ``adata.obsm['spatial']`` of the h5ad file. Cell-type vectors come from
    ``csv_file`` (rows = cell types, columns = spot ids). Spots are grouped
    into batches of ``batch_size`` and handled by ``process_batch_worker``,
    in a process pool when more than one worker is used.

    Args:
        h5ad_file: AnnData file providing spot ids and spatial coordinates.
        wsi_file: Whole-slide image the coordinates refer to.
        csv_file: Cell-type table; its first column is used as the index.
        output_dir: Root output directory. ``patches/``, ``arrays/`` and
            ``cell_sub_types/`` are created beneath it.
        patch_size: Output patch edge length in pixels.
        target_mpp: Target microns-per-pixel of the extracted patches.
        batch_size: Number of spots per worker task (must be >= 1).
        n_workers: Worker processes. Defaults to ``cpu_count() - 1`` capped
            at 8 and never below 1. Values <= 1 run sequentially.
        save_hdf5: Reserved. Consolidated HDF5 output is not implemented.

    Returns:
        dict summarising the run; also written to ``<output_dir>/summary.txt``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    start_time = time.time()

    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    if n_workers is None:
        # Leave one core free, use at most 8, and never drop to 0 workers.
        n_workers = max(1, min(mp.cpu_count() - 1, 8))

    print(f"Using {n_workers} workers with batch size {batch_size}")

    # Create output directories
    output_dir = Path(output_dir)
    output_dirs = {
        'patches': output_dir / "patches",
        'arrays': output_dir / "arrays",
        'cell_types': output_dir / "cell_sub_types"
    }
    for dir_path in output_dirs.values():
        dir_path.mkdir(parents=True, exist_ok=True)

    # Read data
    print("Loading data...")
    adata = sc.read_h5ad(h5ad_file)
    spatial_coords = adata.obsm['spatial']
    spot_ids = adata.obs.index.tolist()

    df_cell_types = pd.read_csv(csv_file, index_col=0)
    print(f"Loaded {len(spot_ids)} spots and {df_cell_types.shape[0]} cell types")

    cell_type_vectors = _build_cell_type_vectors(spot_ids, df_cell_types)

    n_spots = len(spot_ids)
    n_batches = (n_spots + batch_size - 1) // batch_size
    print(f"Processing {n_spots} spots in {n_batches} batches...")

    # Each batch tuple matches the unpacking order in process_batch_worker.
    batch_data_list = []
    for i in range(0, n_spots, batch_size):
        end_idx = min(i + batch_size, n_spots)
        batch_data_list.append((
            wsi_file,
            spatial_coords[i:end_idx],
            spot_ids[i:end_idx],
            cell_type_vectors[i:end_idx],
            output_dirs,
            patch_size,
            target_mpp
        ))

    successful_extractions = 0

    if n_workers > 1:
        # Use ProcessPoolExecutor for CPU-bound tasks
        with ProcessPoolExecutor(max_workers=n_workers) as executor:
            futures = [executor.submit(process_batch_worker, batch_data) for batch_data in batch_data_list]

            for future in tqdm(as_completed(futures), total=len(futures), desc="Processing batches"):
                try:
                    successful_extractions += sum(future.result())
                except Exception as e:
                    # A failed batch counts all its spots as unsuccessful.
                    print(f"Batch processing error: {e}")
    else:
        # Sequential processing (also convenient for debugging)
        for batch_data in tqdm(batch_data_list, desc="Processing batches"):
            try:
                successful_extractions += sum(process_batch_worker(batch_data))
            except Exception as e:
                print(f"Batch processing error: {e}")

    if save_hdf5:
        # Not implemented: workers write per-spot PNG / .pt files only.
        print("save_hdf5 requested, but consolidated HDF5 output is not implemented yet; skipping.")

    processing_time = time.time() - start_time

    summary = {
        'h5ad_file': h5ad_file,
        'wsi_file': wsi_file,
        'csv_file': csv_file,
        'total_spots': n_spots,
        'successful_extractions': successful_extractions,
        'patch_size': patch_size,
        'target_mpp': target_mpp,
        'batch_size': batch_size,
        'n_workers': n_workers,
        'processing_time_seconds': processing_time,
        'spots_per_second': n_spots / processing_time if processing_time > 0 else 0
    }

    summary_path = output_dir / "summary.txt"
    with open(summary_path, 'w') as f:
        for key, value in summary.items():
            f.write(f"{key}: {value}\n")

    print(f"Processing complete! Processed {successful_extractions}/{n_spots} spots in {processing_time:.2f}s")
    # Reuse the guarded rate so a zero elapsed time cannot divide by zero.
    print(f"Speed: {summary['spots_per_second']:.2f} spots/second")

    return summary


base = '/blue/pinaki.sarder/j.fermin/Annotations/Data/XY04_IU-21-020F'
filename = base.split('/')[-1]  # sample name, e.g. "XY04_IU-21-020F"
h5ad_file = '/blue/pinaki.sarder/j.fermin/CellAtlas/data/Kidney/H5AD_files/' + filename + '.h5ad'
wsi_file = base + '/' + filename + '.svs'

# NOTE(review): there is no separator between the sample name and
# "counts.csv". Confirm this matches the on-disk naming.
cvs_file = '/orange/pinaki.sarder/j.fermin/SpatNet/Data/Counts/FFPE/CellTypes_SpotLevel/All_Counts/' + filename + 'counts.csv'
output = '/orange/pinaki.sarder/h.lohaan/Hari_data_pipeline/output_sub_types/' + filename

# Validate inputs up front. Stop here rather than failing deep inside
# processing (or silently producing empty output).
missing_files = [p for p in (h5ad_file, wsi_file, cvs_file) if not os.path.exists(p)]
for file_path in missing_files:
    print(f"Error: File not found: {file_path}")
if missing_files:
    raise FileNotFoundError(f"Missing input file(s): {', '.join(missing_files)}")

# Process the data
summary = process_visium_data_optimized(
    h5ad_file,
    wsi_file,
    cvs_file,
    output
)

print("Processing complete!")
print(f"Output saved to: {output}")
