# ALBRT, help me out

This notebook attempts to reproduce ALBRT (reference implementation: https://github.com/engrodawood/ALBRT) on the competition data.

The paper used two datasets for training:

- [NuCL](https://drive.google.com/drive/folders/1ER1fnse5FXotFeQnbrbjff09rI5wAnEn) (mostly cancer slides from TCGA)
- *(second dataset — TODO: fill in)*

In [None]:
# ================================
# 1. Import Required Libraries
# ================================
import os
import h5py
import numpy as np
import pandas as pd
import cv2
import pickle
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.multioutput import MultiOutputRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.base import BaseEstimator, TransformerMixin

from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.models import Model

from skimage.registration import phase_cross_correlation

# ================================
# 2. Define Helper Functions and Transformers
# ================================
def extract_patch(image, center, patch_size, pad_value=0):
    """
    Extract a square ``patch_size`` x ``patch_size`` patch centered at ``center``.

    ``center`` is an ``(x, y)`` pair in pixel coordinates (x indexes columns,
    y indexes rows); both values are truncated with ``int()``. The window
    covers rows ``[y - patch_size // 2, y - patch_size // 2 + patch_size)``
    and the analogous column range. Any part of the window that lies outside
    ``image`` is filled with ``pad_value``, so every returned patch has the
    same shape regardless of where the center falls.

    Args:
        image: Array whose first two axes are rows and columns; trailing axes
            (e.g. color channels) are carried through unchanged.
        center: ``(x, y)`` coordinates of the patch center.
        patch_size: Side length, in pixels, of the returned patch.
        pad_value: Fill value for pixels outside the image (default 0).

    Returns:
        Array of shape ``(patch_size, patch_size) + image.shape[2:]`` with the
        same dtype as ``image``.
    """
    x, y = int(center[0]), int(center[1])
    half_size = patch_size // 2

    # Requested window in image coordinates; may extend past the borders.
    top, left = y - half_size, x - half_size
    bottom, right = top + patch_size, left + patch_size

    # Pad rather than clip: callers resize patches to a fixed CNN input size,
    # so clipped border patches would otherwise be stretched/distorted, and a
    # window entirely outside the image would be empty.
    patch = np.full((patch_size, patch_size) + image.shape[2:], pad_value, dtype=image.dtype)

    # Overlap between the window and the image.
    height, width = image.shape[:2]
    y_min, y_max = max(top, 0), min(bottom, height)
    x_min, x_max = max(left, 0), min(right, width)
    if y_max > y_min and x_max > x_min:
        patch[y_min - top:y_max - top, x_min - left:x_max - left] = image[y_min:y_max, x_min:x_max]
    return patch

# Transformer to extract CNN features from image patches.
class PatchFeatureExtractor(BaseEstimator, TransformerMixin):
    """
    scikit-learn transformer mapping spot coordinates to CNN feature vectors.

    For each ``(x, y)`` coordinate in ``X`` it extracts a patch from ``image``
    with ``extract_patch``, resizes it to 224x224, applies the ResNet50
    ``preprocess_input`` and runs ``cnn_model.predict``. Patches are processed
    ``batch_size`` at a time, so memory stays bounded on slides with many
    spots (the full set of 224x224 float32 patches is never materialized).

    Args:
        image: Slide image the patches are cut from (assumed H x W x C —
            TODO confirm against the H5 data).
        patch_size: Side length of the patch cut around each coordinate.
        cnn_model: Model exposing ``predict(batch, verbose=0)``.
        batch_size: Number of patches fed to the CNN per ``predict`` call.
    """

    def __init__(self, image, patch_size, cnn_model, batch_size=64):
        self.image = image
        self.patch_size = patch_size
        self.cnn_model = cnn_model
        self.batch_size = batch_size

    def fit(self, X, y=None):
        """Stateless transformer: nothing to learn."""
        return self

    def _preprocess_batch(self, coords):
        """Cut, resize and ResNet-preprocess the patches for ``coords``."""
        patches = [
            cv2.resize(extract_patch(self.image, coord, self.patch_size), (224, 224))
            for coord in coords
        ]
        return preprocess_input(np.asarray(patches, dtype=np.float32))

    def transform(self, X):
        """
        Return a 2-D array with one flattened CNN feature row per coordinate.

        Raises:
            ValueError: if ``X`` is empty or ``batch_size`` is not positive.
        """
        if len(X) == 0:
            raise ValueError("PatchFeatureExtractor.transform received no coordinates.")
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}.")

        feature_chunks = []
        for start in range(0, len(X), self.batch_size):
            batch = self._preprocess_batch(X[start:start + self.batch_size])
            features = self.cnn_model.predict(batch, verbose=0)
            # Reshape features to 2D (one row per patch)
            feature_chunks.append(features.reshape(features.shape[0], -1))
        return np.concatenate(feature_chunks, axis=0)

# ================================
# 3. Define the CellTypePipeline Class
# ================================
class CellTypePipeline:
    """
    Pipeline for training ALBRT on competition data.

    Loads training images and spot tables from the H5 file, extracts CNN
    features from a patch around each spot using a pre-trained CNN, and trains
    a multi-output regression model to predict cell type abundances.

    Expected call order: ``initialize_cnn_model()``, ``load_train_data()``,
    ``load_train_images()``, ``prepare_all_training_set()``, ``train()``.
    """

    # Layout assumed for every training spot array: two coordinate columns
    # followed by 35 cell-type abundance columns.
    COORD_COLUMNS = ["x", "y"]
    TARGET_COLUMNS = [f"C{i}" for i in range(1, 36)]

    def __init__(self, h5_file_path, patch_size=64):
        self.h5_file_path = h5_file_path
        self.patch_size = patch_size
        self.train_spot_tables = {}   # slide name -> DataFrame of x, y, C1..C35
        self.train_images = {}        # slide name -> image array
        self.cell_type_columns = None  # set by prepare_training_set
        self.cnn_model = None  # set by initialize_cnn_model
        self.feature_extractor_pipeline = None  # last per-slide feature pipeline

    def initialize_cnn_model(self):
        """
        Initialize a ResNet50 feature extractor with ImageNet weights.

        ``include_top=False`` with ``pooling='avg'`` yields one pooled feature
        vector per input image. Fine-tuning on histopathology data would
        replace this step.
        """
        base_model = ResNet50(weights='imagenet', include_top=False, pooling='avg')
        self.cnn_model = Model(inputs=base_model.input, outputs=base_model.output)
        print("CNN model initialized.")

    def load_train_data(self):
        """
        Load training spot data from the H5 file (``spots/Train``).

        Each slide's spot array becomes one DataFrame with columns
        ``x, y, C1..C35`` stored in ``self.train_spot_tables``.
        """
        with h5py.File(self.h5_file_path, "r") as f:
            train_spots = f["spots/Train"]
            for slide_name in train_spots.keys():
                spot_array = np.array(train_spots[slide_name])
                df = pd.DataFrame(spot_array, columns=self.COORD_COLUMNS + self.TARGET_COLUMNS)
                self.train_spot_tables[slide_name] = df
        print("Training spot data loaded.")

    def load_train_images(self):
        """Load every training image (``images/Train``) into ``self.train_images``."""
        with h5py.File(self.h5_file_path, "r") as f:
            train_imgs = f["images/Train"]
            for slide_name in train_imgs.keys():
                self.train_images[slide_name] = np.array(train_imgs[slide_name])
        print("Training images loaded.")

    def prepare_training_set(self, slide_id, cache_path=None):
        """
        Build ``(X_features, y)`` for one slide.

        Extracts CNN features from a patch around each spot, standardizes them
        with a StandardScaler fit on this slide only, and returns them with the
        target cell type abundances. When ``cache_path`` exists it is loaded
        instead; otherwise the result is written there (if given).

        Raises:
            ValueError: if the slide has no spot table or image loaded.
            RuntimeError: if ``initialize_cnn_model()`` has not been called.
        """
        feature_cols = ['x', 'y']

        if cache_path is not None and os.path.exists(cache_path):
            print(f"Loading cached features from {cache_path} for slide {slide_id} ...")
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # cache_path at files this pipeline wrote itself.
            with open(cache_path, "rb") as f:
                X_features, y = pickle.load(f)
            # Keep cell_type_columns populated on cache hits too; otherwise it
            # stays None whenever every slide comes from the cache.
            if slide_id in self.train_spot_tables:
                df = self.train_spot_tables[slide_id]
                self.cell_type_columns = [col for col in df.columns if col not in feature_cols]
            return X_features, y

        if slide_id not in self.train_spot_tables or slide_id not in self.train_images:
            raise ValueError(f"Slide {slide_id} not found in training data.")
        if self.cnn_model is None:
            raise RuntimeError("CNN model is not initialized; call initialize_cnn_model() first.")

        df = self.train_spot_tables[slide_id]
        # Every non-coordinate column is a cell type abundance target.
        target_cols = [col for col in df.columns if col not in feature_cols]
        self.cell_type_columns = target_cols

        X_coords = df[feature_cols].values.astype(float)
        y = df[target_cols].values.astype(float)

        # (Optional) Registration/alignment could be applied here.
        # For simplicity, we assume the spots are already well-aligned.

        he_image = self.train_images[slide_id]
        patch_extractor = PatchFeatureExtractor(he_image, self.patch_size, self.cnn_model)
        # The scaler is fit per slide, so inference must standardize each
        # test slide's features the same way.
        self.feature_extractor_pipeline = Pipeline([
            ('patch_extractor', patch_extractor),
            ('scaler', StandardScaler())
        ])

        print(f"Extracting CNN features for slide {slide_id} ...")
        X_features = self.feature_extractor_pipeline.fit_transform(X_coords)

        if cache_path is not None:
            with open(cache_path, "wb") as f:
                pickle.dump((X_features, y), f)
        return X_features, y

    def prepare_all_training_set(self, cache_dir=None):
        """
        Concatenate per-slide training features into a single training set.

        Slides are processed in sorted order. When ``cache_dir`` is given it is
        created if missing, and each slide's features are cached there under a
        name that includes ``patch_size``, so changing the patch size does not
        silently reuse stale features.

        Raises:
            ValueError: if no training spot data has been loaded.
        """
        if not self.train_spot_tables:
            raise ValueError("No training spot data loaded; call load_train_data() first.")
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)

        X_list, y_list = [], []
        for slide_id in sorted(self.train_spot_tables):
            slide_cache_path = (
                os.path.join(cache_dir, f"train_features_{slide_id}_p{self.patch_size}.pkl")
                if cache_dir else None
            )
            X, y = self.prepare_training_set(slide_id, cache_path=slide_cache_path)
            X_list.append(X)
            y_list.append(y)
        X_all = np.concatenate(X_list, axis=0)
        y_all = np.concatenate(y_list, axis=0)
        print("All training features extracted and concatenated.")
        return X_all, y_all

    def build_regression_pipeline(self):
        """
        Build a regression pipeline with a multi-output RandomForest regressor.

        MultiOutputRegressor fits one RandomForestRegressor per target column.
        """
        pipeline = Pipeline([
            ('regressor', MultiOutputRegressor(RandomForestRegressor(n_estimators=100, random_state=42)))
        ])
        return pipeline

    def train(self, X, y):
        """Fit a fresh regression pipeline on ``X`` / ``y`` and return it."""
        reg_pipeline = self.build_regression_pipeline()
        reg_pipeline.fit(X, y)
        print("Regression model trained.")
        return reg_pipeline

    def predict(self, reg_model, X_test):
        """Predict cell type abundances for ``X_test`` with ``reg_model``."""
        predictions = reg_model.predict(X_test)
        return predictions



In [None]:
os.makedirs("/working/cache_dir", exist_ok=True)  # feature cache for prepare_all_training_set

In [None]:
# ================================
# 4. Train ALBRT on Competition Data
# ================================
# Point this at the competition H5 file before running.
h5_file_path = "/kaggle/input/el-hackathon-2025/elucidata_ai_challenge_data.h5"
feature_cache_dir = "/working/cache_dir"
model_output_path = "trained_regression_model.pkl"

# Build the pipeline, then load the CNN weights, spot tables and images.
pipeline = CellTypePipeline(h5_file_path, patch_size=64)
for setup_step in (
    pipeline.initialize_cnn_model,
    pipeline.load_train_data,
    pipeline.load_train_images,
):
    setup_step()

# Features for every training slide, cached per slide under feature_cache_dir.
X_train, y_train = pipeline.prepare_all_training_set(cache_dir=feature_cache_dir)

print("Training a regression model")
reg_model = pipeline.train(X_train, y_train)

print("Writing to Pkl")
# Persist the fitted regressor so inference can run without retraining.
with open(model_output_path, "wb") as model_file:
    pickle.dump(reg_model, model_file)
print("Trained regression model saved.")


In [None]:

# ================================
# 5. Minimal Inference on a Test Slide
# ================================
def load_test_data(h5_file_path, slide_id):
    """
    Read one test slide's image and spot coordinates from the H5 file.

    Reads ``images/Test/<slide_id>`` and ``spots/Test/<slide_id>``; the spot
    array is wrapped in a DataFrame with columns ``x`` and ``y``.

    Returns:
        Tuple ``(image_array, spots_dataframe)``.
    """
    with h5py.File(h5_file_path, "r") as h5:
        slide_image = np.array(h5["images/Test"][slide_id])
        spot_coords = np.array(h5["spots/Test"][slide_id])
        spots_frame = pd.DataFrame(spot_coords, columns=["x", "y"])
    print(f"Test data for slide {slide_id} loaded.")
    return slide_image, spots_frame

def extract_features_test(image, spots_df, cnn_model, patch_size=64, batch_size=64):
    """
    Extract raw (unscaled) CNN features from a patch around each test spot.

    Patches are cut with ``extract_patch``, resized to 224x224, passed through
    the ResNet50 ``preprocess_input`` and fed to ``cnn_model.predict`` in
    batches of ``batch_size``. This avoids one ``predict`` call per spot,
    which was the dominant cost on slides with many spots.

    Args:
        image: Test slide image.
        spots_df: DataFrame with ``x`` and ``y`` columns, one row per spot.
        cnn_model: Model exposing ``predict(batch, verbose=0)``.
        patch_size: Side length of the patch cut around each spot.
        batch_size: Number of patches per ``predict`` call (must be >= 1).

    Returns:
        2-D array with one flattened feature row per spot, in ``spots_df``
        row order. For an empty ``spots_df`` the result has shape ``(0, 0)``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}.")

    coords = spots_df[["x", "y"]].to_numpy(dtype=float)
    feature_chunks = []
    for start in range(0, len(coords), batch_size):
        batch = np.asarray(
            [
                cv2.resize(extract_patch(image, coord, patch_size), (224, 224))
                for coord in coords[start:start + batch_size]
            ],
            dtype=np.float32,
        )
        feats = cnn_model.predict(preprocess_input(batch), verbose=0)
        feature_chunks.append(feats.reshape(feats.shape[0], -1))

    if not feature_chunks:
        return np.empty((0, 0), dtype=np.float32)
    return np.concatenate(feature_chunks, axis=0)

def minimal_inference(h5_file_path, slide_id, cnn_model, reg_model, patch_size=64):
    """
    Run minimal inference on a test slide:
      1. Load the test image and spot coordinates.
      2. Extract CNN features from patches around each spot.
      3. Standardize those features with a StandardScaler fit on this slide,
         matching how training features are produced (CellTypePipeline fits a
         StandardScaler per training slide before the regressor sees them).
      4. Predict cell type abundances.
      5. Save ``submission_<slide_id>.csv`` with an ``ID`` column (row index)
         followed by one column per predicted cell type (``C1``, ``C2``, ...).

    Returns:
        The prediction array from ``reg_model.predict``.
    """
    test_image, test_spots_df = load_test_data(h5_file_path, slide_id)
    # (Optional: add registration here if needed.)
    raw_features = extract_features_test(test_image, test_spots_df, cnn_model, patch_size=patch_size)
    # The regressor was trained on per-slide z-scored features; feeding raw
    # CNN activations would put every feature on a different scale than the
    # split thresholds it learned.
    test_features = StandardScaler().fit_transform(raw_features)
    predictions = reg_model.predict(test_features)

    # Name columns from the model's output width (35 for the training layout).
    cell_type_columns = [f"C{i}" for i in range(1, predictions.shape[1] + 1)]
    pred_df = pd.DataFrame(predictions, columns=cell_type_columns, index=test_spots_df.index)
    pred_df.insert(0, 'ID', pred_df.index)
    submission_filename = f"submission_{slide_id}.csv"
    pred_df.to_csv(submission_filename, index=False)
    print(f"Submission file '{submission_filename}' created!")
    return predictions

# Inference on a single test slide; change the slide ID to target another one.
test_slide_id = "S_1"
predictions = minimal_inference(
    h5_file_path,
    test_slide_id,
    pipeline.cnn_model,
    reg_model,
    patch_size=64,
)
