In [1]:
import h5py
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import os
import pickle
from matplotlib.patches import Rectangle

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.base import BaseEstimator, TransformerMixin

import torch
import torchvision.models as models
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

from skimage.registration import phase_cross_correlation

def augment_patch(patch):
    """
    Apply simple, lossless data augmentation to an image patch.

    Independently and in this order, the patch receives a random horizontal
    flip, a random vertical flip, and a random counter-clockwise rotation by
    0, 90, 180 or 270 degrees. Every operation is an exact pixel permutation
    (numpy flips / rot90), so non-square patches -- e.g. those clipped at the
    slide border by `extract_patch` -- are rotated without cropping or
    black padding.

    Parameters:
      patch (ndarray): Image patch of shape (H, W) or (H, W, C).
        Non-uint8 input is converted to uint8 first. If its maximum is
        <= 1.0 it is treated as a [0, 1] image and scaled by 255. In all cases
        values are clipped to [0, 255] before casting, so out-of-range values
        saturate instead of wrapping around.

    Returns:
      ndarray: The augmented, C-contiguous uint8 patch. Height and width are
      swapped when a 90 or 270 degree rotation is applied.
    """
    if patch.dtype != np.uint8:
        # `patch.size` guard: .max() raises on an empty array.
        if patch.size and patch.max() <= 1.0:
            patch = np.clip(patch * 255, 0, 255).astype(np.uint8)
        else:
            patch = np.clip(patch, 0, 255).astype(np.uint8)

    # Random horizontal flip (reverse the column order).
    if np.random.rand() > 0.5:
        patch = patch[:, ::-1]
    # Random vertical flip (reverse the row order).
    if np.random.rand() > 0.5:
        patch = patch[::-1]
    # Random rotation by k * 90 degrees, counter-clockwise.
    k = np.random.choice([0, 1, 2, 3])
    if k:
        patch = np.rot90(patch, k)

    # Flips/rot90 return strided views; hand back a compact array.
    return np.ascontiguousarray(patch)


# -----------------------------
# 1. Function for Patch Extraction
# -----------------------------
def extract_patch(image, center, patch_size):
    """
    Extract a square patch from the image centered at the given coordinate.

    Parameters:
      image (ndarray): Image of shape (height, width) or (height, width, channels).
      center (sequence): (x, y) pixel coordinate of the patch center. Values
        are truncated to int.
      patch_size (int): Requested side length in pixels.

    Returns:
      ndarray: The patch, spanning rows/columns [c - patch_size // 2,
      c - patch_size // 2 + patch_size) around each center coordinate c,
      intersected with the image bounds. Patches that touch the image border
      are therefore smaller (and possibly non-square). A center lying entirely
      outside the image yields an empty patch.
    """
    x, y = int(center[0]), int(center[1])
    half_size = patch_size // 2
    x_start = x - half_size
    y_start = y - half_size
    # Start + patch_size (not center + half) so odd sizes get the full width.
    # Clamp the end at 0 too: a negative slice end would wrap around and
    # silently return most of the image for far off-image centers.
    y_min = max(y_start, 0)
    y_max = min(max(y_start + patch_size, 0), image.shape[0])
    x_min = max(x_start, 0)
    x_max = min(max(x_start + patch_size, 0), image.shape[1])
    # No trailing channel index, so 2-D (grayscale) images work as well.
    return image[y_min:y_max, x_min:x_max]

class PatchFeatureExtractor(BaseEstimator, TransformerMixin):
    """
    scikit-learn transformer that maps spot coordinates to CNN feature vectors.

    For each (x, y) row of X, a patch is cut from `image` with `extract_patch`.
    The patch is forced to 3 channels, optionally augmented, resized to
    128x128, normalized with ImageNet statistics, and passed through
    `cnn_model`. The model's outputs are returned as a numpy array, one row
    per coordinate, in input order.
    """

    def __init__(self, image, patch_size, cnn_model, augment=False, device='cpu', batch_size=256):
        """
        Parameters:
          image (ndarray): The whole-slide HE image as a numpy array.
          patch_size (int): Size (in pixels) of the square patch to extract.
          cnn_model (nn.Module): Pre-trained PyTorch CNN model for feature extraction.
          augment (bool): Whether to apply data augmentation on the patches.
          device (str): Device to run the model ('cpu' or 'cuda').
          batch_size (int): Maximum number of patches sent through the CNN at
            once. Bounds peak (GPU) memory for slides with many spots.
        """
        self.image = image
        self.patch_size = patch_size
        self.cnn_model = cnn_model
        self.augment = augment
        self.device = device
        self.batch_size = batch_size

        # Define the transformation pipeline.
        # This converts the patch (numpy array) to a PIL image, resizes it to 128x128,
        # then converts to tensor and normalizes with ImageNet means and stds.
        self.transform_pipeline = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def fit(self, X, y=None):
        """No-op: the CNN is pre-trained, so nothing is learned here."""
        return self

    @staticmethod
    def _to_rgb(patch):
        """Return `patch` as an (H, W, 3) array (gray is replicated, alpha dropped)."""
        if patch.ndim == 2:
            return np.repeat(patch[:, :, np.newaxis], 3, axis=2)
        channels = patch.shape[2]
        if channels == 3:
            return patch
        if channels == 1:
            return np.repeat(patch, 3, axis=2)
        if channels == 4:
            return patch[:, :, :3]
        raise ValueError(f"Unsupported number of image channels: {channels}")

    def _patch_tensor(self, coord):
        """Extract, fix up, optionally augment and transform the patch at `coord`."""
        patch = self._to_rgb(extract_patch(self.image, coord, self.patch_size))
        # Optionally apply augmentation (only during training).
        if self.augment:
            patch = augment_patch(patch)
        return self.transform_pipeline(patch)

    def transform(self, X):
        """
        Compute CNN features for every coordinate in X.

        Parameters:
          X (array-like): Sequence of (x, y) spot coordinates.

        Returns:
          ndarray: Feature matrix with one row per coordinate.

        Raises:
          ValueError: If X is empty or a patch has an unsupported channel count.
        """
        coords = list(X)
        if not coords:
            raise ValueError("PatchFeatureExtractor.transform received no coordinates.")
        step = max(1, int(self.batch_size))
        self.cnn_model.eval()
        chunks = []
        with torch.no_grad():
            # Process in fixed-size chunks instead of one whole-slide batch.
            for start in range(0, len(coords), step):
                tensors = [self._patch_tensor(c) for c in coords[start:start + step]]
                batch = torch.stack(tensors).to(self.device)
                chunks.append(self.cnn_model(batch).cpu().numpy())
        return np.concatenate(chunks, axis=0)

# -----------------------------
# 3. Pipeline Class for the Elucidata Challenge with Caching and Visualization Options
# -----------------------------
class CellTypePipeline:
    """
    Pipeline for loading data, extracting image patch features using a CNN,
    training a multi-output regression model, and generating a submission file.

    Optionally, CNN features can be cached (saved/loaded) using pickle to speed up re-runs.
    Additional visualization methods are provided to verify that the spot coordinates
    and extracted patches align with the HE slide image.

    Attributes:
      h5_file_path (str): Path to the challenge HDF5 file.
      patch_size (int): Side length (pixels) of the patch cut around each spot.
      device (str): Torch device used for CNN inference.
      train_spot_tables (dict): Slide id -> DataFrame with 'x', 'y' and target columns.
      train_images (dict): Slide id -> HE image array.
      cell_type_columns (list | None): Target column names. Set by
        prepare_training_set and used as the submission header.
      cnn_model (nn.Module | None): Feature extractor set by initialize_cnn_model.
      feature_extractor_pipeline (Pipeline | None): Last per-slide feature pipeline built.
    """

    def __init__(self, h5_file_path, patch_size=110, device='cpu'):
        self.h5_file_path = h5_file_path
        self.patch_size = patch_size
        self.device = device
        self.train_spot_tables = {}
        self.train_images = {}
        self.cell_type_columns = None
        self.cnn_model = None  # To be initialized
        self.feature_extractor_pipeline = None

    def initialize_cnn_model(self):
        """
        Initialize a pre-trained ResNet34 model for feature extraction using torchvision.
        The final fully-connected layer is replaced with an identity mapping so that the model
        outputs a feature vector.
        """
        # torchvision >= 0.13 deprecates `pretrained=True` in favour of weight
        # enums (IMAGENET1K_V1 is what `pretrained=True` maps to). Fall back on
        # older releases that lack the enum.
        weights_enum = getattr(models, "ResNet34_Weights", None)
        if weights_enum is not None:
            self.cnn_model = models.resnet34(weights=weights_enum.IMAGENET1K_V1)
        else:
            self.cnn_model = models.resnet34(pretrained=True)
        # Replace the final fully-connected layer with identity
        self.cnn_model.fc = nn.Identity()
        self.cnn_model = self.cnn_model.to(self.device)
        self.cnn_model.eval()
        print("ResNet34 feature extractor initialized on device:", self.device)

    def load_train_data(self):
        """
        Load training spot data from the H5 file and store each slide as a DataFrame
        with columns 'x', 'y', 'C1'..'C35'.
        """
        with h5py.File(self.h5_file_path, "r") as f:
            train_spots = f["spots/Train"]
            for slide_name in train_spots.keys():
                spot_array = np.array(train_spots[slide_name])
                df = pd.DataFrame(spot_array, columns=["x", "y"] + [f"C{i}" for i in range(1, 36)])
                self.train_spot_tables[slide_name] = df
        print("Training spot data loaded successfully.")

    def load_train_images(self):
        """
        Load training HE images from the H5 file.
        """
        with h5py.File(self.h5_file_path, "r") as f:
            train_imgs = f["images/Train"]
            for slide_name in train_imgs.keys():
                image_array = np.array(train_imgs[slide_name])
                self.train_images[slide_name] = image_array
        print("Training images loaded successfully.")

    def load_test_data(self, slide_id):
        """
        Load test spot data (columns 'x', 'y') for a given slide.

        Raises:
          ValueError: If the slide is not present under "spots/Test".
        """
        with h5py.File(self.h5_file_path, "r") as f:
            test_spots = f["spots/Test"]
            if slide_id not in test_spots:
                raise ValueError(f"Slide {slide_id} not found in test spot data.")
            spot_array = np.array(test_spots[slide_id])
            test_df = pd.DataFrame(spot_array, columns=["x", "y"])
        print(f"Test spot data for slide {slide_id} loaded successfully.")
        return test_df

    def load_test_image(self, slide_id):
        """
        Load test HE image for a given slide.

        Raises:
          ValueError: If the slide is not present under "images/Test".
        """
        with h5py.File(self.h5_file_path, "r") as f:
            test_imgs = f["images/Test"]
            if slide_id not in test_imgs:
                raise ValueError(f"Slide {slide_id} not found in test images.")
            image_array = np.array(test_imgs[slide_id])
        print(f"Test image for slide {slide_id} loaded successfully.")
        return image_array

    def prepare_training_set(self, slide_id='S_1', cache_path=None):
        """
        Prepare training features and targets for a given slide.
        Uses the HE image to extract patches and then CNN features.

        Parameters:
          slide_id (str): Training slide key.
          cache_path (str | None): Pickle file holding (X_features, y). It is
            loaded if it exists, otherwise written after extraction.

        Returns:
          tuple: (X_features, y) numpy arrays.

        Raises:
          ValueError: If no cache is used and the slide's spots/image are not loaded.
        """
        # Record target column names up front so they are also known when the
        # features come from the cache (create_submission needs them).
        if slide_id in self.train_spot_tables:
            df = self.train_spot_tables[slide_id]
            self.cell_type_columns = [col for col in df.columns if col not in ('x', 'y')]

        if cache_path is not None and os.path.exists(cache_path):
            print(f"Loading cached training features from {cache_path} for slide {slide_id} ...")
            # NOTE: pickle is only safe for caches this pipeline wrote itself.
            with open(cache_path, "rb") as f:
                X_features, y = pickle.load(f)
            return X_features, y

        if slide_id not in self.train_spot_tables:
            raise ValueError(f"Slide {slide_id} not found in training spot data.")
        if slide_id not in self.train_images:
            raise ValueError(f"Slide {slide_id} image not loaded.")

        df = self.train_spot_tables[slide_id]
        # First two columns: coordinates; remaining columns: cell type abundances.
        feature_cols = ['x', 'y']
        target_cols = [col for col in df.columns if col not in feature_cols]
        self.cell_type_columns = target_cols

        X_coords = df[feature_cols].values.astype(float)
        y = df[target_cols].values.astype(float)

        he_image = self.train_images[slide_id]
        # Augmentation is applied once here, so cached features bake in one
        # random flip/rotation per spot.
        patch_extractor = PatchFeatureExtractor(he_image, self.patch_size, self.cnn_model,
                                                 augment=True, device=self.device)
        # A fresh scaler per slide: features are standardized within each slide.
        self.feature_extractor_pipeline = Pipeline([
            ('patch_extractor', patch_extractor),
            ('scaler', StandardScaler())
        ])
        print(f"Extracting CNN features for slide {slide_id} ...")
        X_features = self.feature_extractor_pipeline.fit_transform(X_coords)

        if cache_path is not None:
            print(f"Saving training features for slide {slide_id} to {cache_path} ...")
            with open(cache_path, "wb") as f:
                pickle.dump((X_features, y), f)

        print(f"Extracted CNN features for slide {slide_id}.")
        return X_features, y

    def prepare_all_training_set(self, cache_dir=None, align_spots=True):
        """
        Prepare training features and targets for all slides in the training set.

        Slides are processed in sorted order and their rows concatenated.

        Parameters:
          cache_dir (str | None): Directory for per-slide caches named
            "train_features_<slide_id>.pkl".
          align_spots (bool): Currently has no effect. Reserved for applying a
            coordinate correction (see compute_optimal_shift /
            manual_shift_alignment).

        Raises:
          ValueError: If no training spot data has been loaded.
        """
        if not self.train_spot_tables:
            raise ValueError("No training spot data loaded; call load_train_data() first.")
        X_list = []
        y_list = []
        for slide_id in sorted(self.train_spot_tables):
            slide_cache_path = os.path.join(cache_dir, f"train_features_{slide_id}.pkl") if cache_dir else None
            X, y = self.prepare_training_set(slide_id=slide_id, cache_path=slide_cache_path)
            X_list.append(X)
            y_list.append(y)
        X_all = np.concatenate(X_list, axis=0)
        y_all = np.concatenate(y_list, axis=0)
        print("All training features extracted and concatenated.")
        return X_all, y_all

    def build_regression_pipeline(self):
        """
        Build and return a regression pipeline that uses the pre-extracted CNN features.
        """
        pipeline = Pipeline([
            ('regressor', MultiOutputRegressor(RandomForestRegressor(n_estimators=100, random_state=42)))
        ])
        return pipeline

    def train(self, X, y):
        """
        Train the regression model on the provided features and targets.
        """
        reg_pipeline = self.build_regression_pipeline()
        reg_pipeline.fit(X, y)
        print("Regression model training complete.")
        return reg_pipeline

    def predict(self, reg_model, X_test):
        """
        Predict cell type abundances on test features.
        """
        predictions = reg_model.predict(X_test)
        return predictions

    def create_submission(self, test_df, predictions, submission_filename="submission.csv"):
        """
        Create a submission CSV file with predicted cell type abundances.

        The CSV has an 'ID' column (test_df's index) followed by one column per
        cell type.

        Raises:
          ValueError: If the cell-type column names are not known yet.
        """
        if self.cell_type_columns is None:
            raise ValueError("Cell-type columns are unknown; call prepare_training_set() "
                             "or prepare_all_training_set() first.")
        pred_df = pd.DataFrame(predictions, columns=self.cell_type_columns, index=test_df.index)
        pred_df.insert(0, 'ID', pred_df.index)
        pred_df.to_csv(submission_filename, index=False)
        print(f"Submission file '{submission_filename}' created!")

    # -----------------------------
    # Visualization / alignment helpers
    # -----------------------------
    def _get_slide_and_coords(self, slide_id, flip_y=False):
        """
        Return (image, spot table, float copy of the (N, 2) x/y coordinates).

        When `flip_y` is True, y is mirrored as image_height - y.

        Raises:
          ValueError: If the slide's image or spot table has not been loaded.
        """
        if slide_id not in self.train_images or slide_id not in self.train_spot_tables:
            raise ValueError(f"Slide {slide_id} data not found.")
        image = self.train_images[slide_id]
        df = self.train_spot_tables[slide_id]
        coords = df[['x', 'y']].values.astype(float)
        if flip_y:
            coords[:, 1] = image.shape[0] - coords[:, 1]
        return image, df, coords

    def visualize_spot_overlay(self, slide_id, flip_y=False):
        """Scatter all spot coordinates of a training slide over its HE image."""
        image, _, coords = self._get_slide_and_coords(slide_id, flip_y)
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        plt.scatter(coords[:, 0], coords[:, 1], marker='o', color='red', s=25)
        plt.title(f"Overlay of Spot Coordinates for Slide {slide_id}")
        plt.show()

    def visualize_extracted_patches(self, slide_id, num_patches=5, flip_y=False):
        """
        Show the patches around the first `num_patches` spots of a slide.

        The count is capped at the number of spots.

        Raises:
          ValueError: If the slide is missing or there is nothing to display.
        """
        image, _, coords = self._get_slide_and_coords(slide_id, flip_y)
        num_patches = min(num_patches, len(coords))
        if num_patches < 1:
            raise ValueError(f"No patches to display for slide {slide_id}.")
        # squeeze=False keeps a 2-D axes grid, so a single patch works too.
        fig, axes = plt.subplots(1, num_patches, figsize=(num_patches * 3, 3), squeeze=False)
        for i, ax in enumerate(axes[0]):
            patch = extract_patch(image, coords[i], self.patch_size)
            ax.imshow(patch)
            ax.set_title(f"Patch {i}")
            ax.axis("off")
        plt.suptitle(f"Extracted Patches for Slide {slide_id}")
        plt.show()

    def visualize_cnn_input(self, slide_id, index=0, flip_y=False):
        """
        Show one spot's patch location on the slide, the 128x128 resized patch
        with the spot marked, and that spot's cell-type abundances.
        """
        image, df, coords = self._get_slide_and_coords(slide_id, flip_y)
        coord = coords[index]

        half_size = self.patch_size // 2
        x = int(coord[0])
        y = int(coord[1])
        x_min = max(x - half_size, 0)
        y_min = max(y - half_size, 0)
        patch = extract_patch(image, coord, self.patch_size)

        patch_resized = cv2.resize(patch, (128, 128))
        rel_x = x - x_min
        rel_y = y - y_min
        scale_x = 128 / patch.shape[1]
        scale_y = 128 / patch.shape[0]
        spot_resized_x = rel_x * scale_x
        spot_resized_y = rel_y * scale_y

        abundances = df.iloc[index][[col for col in df.columns if col not in ['x', 'y']]]

        print(f"Patch top-left coordinates: (x_min: {x_min}, y_min: {y_min})")

        fig, (ax_full, ax_img, ax_bar) = plt.subplots(1, 3, figsize=(18, 5))
        ax_full.imshow(image)
        # Use the actual (possibly border-clipped) patch size for the outline.
        rect = Rectangle((x_min, y_min), patch.shape[1], patch.shape[0], linewidth=2, edgecolor='red', facecolor='none')
        ax_full.add_patch(rect)
        ax_full.set_title("Full Slide with Patch Overlay")
        ax_full.axis("off")

        ax_img.imshow(patch_resized)
        ax_img.scatter([spot_resized_x], [spot_resized_y], marker='x', color='red', s=50)
        ax_img.set_title(f"Resized Patch (Index {index})")
        ax_img.text(5, 20, f"({x_min}, {y_min})", color='yellow', fontsize=12,
                    bbox=dict(facecolor='black', alpha=0.5))
        ax_img.axis("off")

        cell_types = abundances.index.tolist()
        ax_bar.bar(cell_types, abundances.values)
        ax_bar.set_title("Cell Type Abundance Distribution")
        # Rotate the existing categorical tick labels (set_xticklabels without
        # fixed ticks triggers a matplotlib FixedFormatter warning).
        ax_bar.tick_params(axis='x', labelrotation=90)
        ax_bar.set_ylabel("Abundance")

        plt.tight_layout()
        plt.show()

    def compute_optimal_shift(self, slide_id, flip_y=False, upsample_factor=10, spot_radius=3, display=False):
        """
        Estimate a global (x, y) translation aligning spot coordinates with tissue.

        An Otsu-thresholded grayscale mask of the slide (reference) is
        registered against a mask of filled circles drawn at the spots
        (moving) with skimage's phase_cross_correlation.

        Returns:
          tuple: (optimal_shift, error, diffphase). optimal_shift is [dx, dy].
          skimage's shift registers the *moving* image onto the reference, so
          corrected spot coordinates are `coords + optimal_shift`.
        """
        image, _, coords = self._get_slide_and_coords(slide_id, flip_y)

        if len(image.shape) == 3 and image.shape[2] == 3:
            image_gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        else:
            image_gray = image.copy()

        if image_gray.dtype != np.uint8:
            image_gray = cv2.normalize(image_gray, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

        _, tissue_mask = cv2.threshold(image_gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        spot_mask = np.zeros_like(tissue_mask, dtype=np.uint8)
        for pt in coords:
            x, y = int(round(pt[0])), int(round(pt[1]))
            cv2.circle(spot_mask, (x, y), radius=spot_radius, color=255, thickness=-1)

        shift, error, diffphase = phase_cross_correlation(tissue_mask, spot_mask, upsample_factor=upsample_factor)
        # skimage returns (row, col) = (dy, dx); reorder to (x, y).
        optimal_shift = np.array([shift[1], shift[0]])
        print("Optimal shift (x, y):", optimal_shift)
        print("Registration error:", error)

        if display:
            plt.figure(figsize=(15, 5))
            plt.subplot(1, 3, 1)
            plt.title("Tissue Mask")
            plt.imshow(tissue_mask, cmap='gray')
            plt.subplot(1, 3, 2)
            plt.title("Spot Mask")
            plt.imshow(spot_mask, cmap='gray')
            plt.subplot(1, 3, 3)
            plt.title("Overlay of Tissue and Spots")
            plt.imshow(tissue_mask, cmap='gray')
            plt.imshow(spot_mask, cmap='jet', alpha=0.5)
            plt.tight_layout()
            plt.show()

            # The shift moves the spot mask (moving image) onto the tissue, so add it.
            adjusted_coords = coords + optimal_shift
            plt.figure(figsize=(10, 10))
            plt.imshow(image)
            plt.scatter(adjusted_coords[:, 0], adjusted_coords[:, 1], marker='o', color='lime', s=25)
            plt.title("Corrected Spots Overlay on Tissue Image")
            plt.show()

        return optimal_shift, error, diffphase

    def manual_shift_alignment(self, slide_id, x_shift, y_shift, flip_y=False, display=True):
        """
        Translate a slide's spot coordinates by (x_shift, y_shift).

        Optionally plots the original and shifted spots side by side.

        Returns:
          tuple: (original_coords, shifted_coords), both (N, 2) float arrays.
        """
        image, _, original_coords = self._get_slide_and_coords(slide_id, flip_y)

        shift_vector = np.array([x_shift, y_shift])
        shifted_coords = original_coords + shift_vector

        if display:
            plt.figure(figsize=(14, 7))
            plt.subplot(1, 2, 1)
            plt.imshow(image)
            plt.scatter(original_coords[:, 0], original_coords[:, 1],
                        marker='o', color='red', s=25, label='Original Spots')
            plt.title(f"Slide {slide_id} - Original Spots")
            plt.legend()
            plt.axis("off")

            plt.subplot(1, 2, 2)
            plt.imshow(image)
            plt.scatter(shifted_coords[:, 0], shifted_coords[:, 1],
                        marker='o', color='lime', s=25, label='Shifted Spots')
            plt.title(f"Slide {slide_id} - Shifted Spots\n(x_shift: {x_shift}, y_shift: {y_shift})")
            plt.legend()
            plt.axis("off")

            plt.tight_layout()
            plt.show()

        return original_coords, shifted_coords

# -----------------------------
# 4. Example Usage with Caching and Visualization Options
# -----------------------------

# Path to the provided H5 data file
h5_file_path = "/kaggle/input/el-hackathon-2025/elucidata_ai_challenge_data.h5"

# Optionally specify a directory for caching training features (for slides S_1 to S_6)
train_cache_dir = "train_features_cache"
os.makedirs(train_cache_dir, exist_ok=True)
test_cache_path = "test_features_S_7.pkl"  # For slide S_7 test features

# Initialize the pipeline with desired patch size (110x110) and set the device (e.g., 'cuda' if available)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pipeline_obj = CellTypePipeline(h5_file_path, patch_size=110, device=device)

# Initialize the ResNet34 feature extractor
pipeline_obj.initialize_cnn_model()

# Load training spots and images
pipeline_obj.load_train_data()
pipeline_obj.load_train_images()


Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 175MB/s]


ResNet34 feature extractor initialized on device: cuda
Training spot data loaded successfully.
Training images loaded successfully.


In [2]:

skip_training = False

if not skip_training:
    # Step 1: CNN features + abundance targets for every training slide,
    # reusing per-slide caches in train_cache_dir when they exist.
    X_train, y_train = pipeline_obj.prepare_all_training_set(cache_dir=train_cache_dir)

    # Step 2: fit the multi-output random-forest regressor.
    reg_model = pipeline_obj.train(X_train, y_train)

    # Step 3: held-out slide S_7 -- spot coordinates and HE image.
    test_df = pipeline_obj.load_test_data(slide_id='S_7')
    test_image = pipeline_obj.load_test_image(slide_id='S_7')

    # Inference extractor: same CNN, augmentation off, and its own scaler fit
    # on S_7 (mirrors the per-slide scaling used for the training slides).
    test_patch_extractor = PatchFeatureExtractor(
        test_image,
        pipeline_obj.patch_size,
        pipeline_obj.cnn_model,
        augment=False,
        device=pipeline_obj.device,
    )
    test_feature_pipeline = Pipeline([
        ('patch_extractor', test_patch_extractor),
        ('scaler', StandardScaler()),
    ])
    X_test_coords = test_df[['x', 'y']].values.astype(float)

    # Reuse the S_7 feature cache if an earlier run produced it.
    if not os.path.exists(test_cache_path):
        print("Extracting CNN features for test data ...")
        X_test_features = test_feature_pipeline.fit_transform(X_test_coords)
        print(f"Saving test features to {test_cache_path} ...")
        with open(test_cache_path, "wb") as f:
            pickle.dump(X_test_features, f)
    else:
        print(f"Loading cached test features from {test_cache_path} ...")
        with open(test_cache_path, "rb") as f:
            X_test_features = pickle.load(f)

    # Step 4: predict abundances for S_7 and write the submission CSV.
    predictions = pipeline_obj.predict(reg_model, X_test_features)
    pipeline_obj.create_submission(test_df, predictions, submission_filename="submission.csv")


Extracting CNN features for slide S_1 ...
Saving training features for slide S_1 to train_features_cache/train_features_S_1.pkl ...
Extracted CNN features for slide S_1.
Extracting CNN features for slide S_2 ...
Saving training features for slide S_2 to train_features_cache/train_features_S_2.pkl ...
Extracted CNN features for slide S_2.
Extracting CNN features for slide S_3 ...
Saving training features for slide S_3 to train_features_cache/train_features_S_3.pkl ...
Extracted CNN features for slide S_3.
Extracting CNN features for slide S_4 ...
Saving training features for slide S_4 to train_features_cache/train_features_S_4.pkl ...
Extracted CNN features for slide S_4.
Extracting CNN features for slide S_5 ...
Saving training features for slide S_5 to train_features_cache/train_features_S_5.pkl ...
Extracted CNN features for slide S_5.
Extracting CNN features for slide S_6 ...
Saving training features for slide S_6 to train_features_cache/train_features_S_6.pkl ...
Extracted CNN feat