In [1]:
import os
import h5py
import numpy as np
import pandas as pd
from PIL import Image
from scipy.spatial import KDTree
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

##############################
# PART A: SPATIAL FOUNDATION AUTOENCODER
##############################

# 1. Load Training Data (spots/Train) with Slice Information
h5_file_path = "/kaggle/input/el-hackathon-2025/elucidata_ai_challenge_data.h5"

with h5py.File(h5_file_path, "r") as f:
    train_spots = f["spots/Train"]
    # Load each slide and tag with its slice name.
    # np.array(...) materialises each slide's dataset while the file is still open,
    # so the DataFrames stay valid after the `with` block closes the file.
    train_spot_tables = {
        slide: pd.DataFrame(np.array(train_spots[slide])).assign(slice_name=slide)
        for slide in train_spots.keys()
    }
# Concatenate all slides.
# ignore_index=True gives train_df a fresh 0..N-1 RangeIndex, so row labels coincide
# with row positions in the arrays derived from it (ranks, spatial_features, ...).
train_df = pd.concat(train_spot_tables.values(), ignore_index=True)

# Assume first two columns are coordinates, next 35 are cell abundances.
# NOTE(review): this rename is purely positional -- it assumes every slide table has
# exactly 37 data columns in the order x, y, C1..C35; confirm against the H5 schema.
cell_types = [f"C{i+1}" for i in range(35)]
train_df.columns = ["x", "y"] + cell_types + ["slice_name"]
print("Training data shape:", train_df.shape)

# 2. Compute Descending Ranks for Each Spot (highest abundance gets rank 1)
# This uses the pandas rank function row-wise.
# method="dense": tied abundances share a rank and the next distinct value gets the
# next integer; pandas returns these ranks as floats.
ranks = train_df[cell_types].rank(axis=1, method="dense", ascending=False).values  # shape (N,35)

# 3. Compute Neighbor-Aggregated Ranks (using a KDTree within each slice)
def compute_neighbor_aggregated_ranks(df, rank_array, radius=100):
    """Average each spot's rank vector over its spatial neighbourhood.

    Neighbours of a spot are the spots of the same ``slice_name`` whose (x, y)
    lies within Euclidean distance ``radius`` of it (the spot itself included),
    so different slices never mix.

    Args:
        df: DataFrame with "x", "y" and "slice_name" columns. Row *i* by
            position (whatever the index labels are) describes ``rank_array[i]``.
        rank_array: array of shape (len(df), k), one rank vector per spot.
        radius: neighbourhood radius, in the units of the "x"/"y" columns.

    Returns:
        Floating array with the shape of ``rank_array``; row *i* is the mean of
        the rank rows of spot *i*'s neighbours.
    """
    rank_array = np.asarray(rank_array)
    # Neighbourhood means are fractional: keep a floating input's dtype, but
    # never write them into an integer buffer (that would truncate them).
    if np.issubdtype(rank_array.dtype, np.floating):
        out_dtype = rank_array.dtype
    else:
        out_dtype = np.float64
    agg = np.zeros(rank_array.shape, dtype=out_dtype)

    # Work purely with row positions so any df index (non-default, shuffled,
    # duplicated labels) lines up with rank_array.
    coords_all = df[["x", "y"]].to_numpy()
    slice_names = df["slice_name"].to_numpy()
    for slice_name in df["slice_name"].unique():
        positions = np.flatnonzero(slice_names == slice_name)
        coords = coords_all[positions]
        tree = KDTree(coords)
        # One batched query yields, for every spot, its neighbours' local indices.
        for local_i, neighbor_local in enumerate(tree.query_ball_point(coords, r=radius)):
            agg[positions[local_i]] = rank_array[positions[neighbor_local]].mean(axis=0)
    return agg

# Mean ranks over each spot's same-slice neighbours within radius 100 (x/y units).
neighbor_agg = compute_neighbor_aggregated_ranks(train_df, ranks, radius=100)

# 4. Concatenate Own Ranks with Neighbor Aggregated Ranks -> 70-dim features.
# Columns 0-34 hold the spot's own ranks, columns 35-69 its neighbourhood mean ranks.
spatial_features = np.concatenate([ranks, neighbor_agg], axis=1)  # shape (N,70)

# 5. Define a Spatial Foundation Autoencoder that learns a latent (16-dim) embedding.
class SpatialFoundationAutoencoder(nn.Module):
    """Symmetric MLP autoencoder: input_dim -> hidden_dim -> embed_dim and back.

    ``forward`` returns ``(embedding, reconstruction)`` so callers can use the
    latent code and the reconstruction loss from a single pass.
    """

    def __init__(self, input_dim=70, embed_dim=16, hidden_dim=128):
        super().__init__()
        # Encoder first, decoder second: construction order fixes the order in
        # which parameters are randomly initialised.
        self.encoder = self._two_layer_mlp(input_dim, hidden_dim, embed_dim)
        self.decoder = self._two_layer_mlp(embed_dim, hidden_dim, input_dim)

    @staticmethod
    def _two_layer_mlp(in_features, hidden_features, out_features):
        """Build Linear -> ReLU -> Linear."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x):
        latent = self.encoder(x)
        return latent, self.decoder(latent)

# Dataset for spatial autoencoder training.
class SpatialFoundationDataset(Dataset):
    """Map-style dataset that serves one float32 feature row per index.

    The whole feature matrix is converted once into a float32 tensor;
    ``torch.tensor`` copies its input, so later edits to ``features`` are not seen.
    """

    def __init__(self, features):
        as_float32 = torch.float32
        self.data = torch.tensor(features, dtype=as_float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        return row

def train_spatial_autoencoder(model, dataloader, num_epochs=20, lr=0.001, device='cpu'):
    """Train an autoencoder to reproduce its input with MSE loss and Adam.

    Args:
        model: module whose forward returns ``(embedding, reconstruction)``.
        dataloader: yields input batches; ``dataloader.dataset`` must support
            ``len`` (used for the per-epoch mean loss that is printed).
        num_epochs: number of passes over the data.
        lr: Adam learning rate.
        device: device the model and every batch are moved to.

    Returns:
        The same ``model`` object, trained in place and left in train mode.
    """
    model.to(device)
    adam = optim.Adam(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    model.train()
    for epoch_no in range(1, num_epochs + 1):
        running = 0.0
        for batch in dataloader:
            batch = batch.to(device)
            _, reconstruction = model(batch)
            batch_loss = mse(reconstruction, batch)
            adam.zero_grad()
            batch_loss.backward()
            adam.step()
            # Weight by batch size so the printed figure is a per-sample mean.
            running += batch_loss.item() * batch.size(0)
        mean_loss = running / len(dataloader.dataset)
        print(f"Spatial Autoencoder Epoch {epoch_no}/{num_epochs}, Loss: {mean_loss:.4f}")
    return model

# Wrap the 70-d features in a shuffled mini-batch loader (batch size 32).
spatial_dataset = SpatialFoundationDataset(spatial_features)
spatial_loader = DataLoader(spatial_dataset, batch_size=32, shuffle=True)

# Use the GPU when available; the same `device` is reused for the main model below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
spatial_model = SpatialFoundationAutoencoder(input_dim=70, embed_dim=16, hidden_dim=128)
print("Training Spatial Foundation Autoencoder...")
spatial_model = train_spatial_autoencoder(spatial_model, spatial_loader, num_epochs=200, lr=0.001, device=device)

# Precompute the spatial embeddings for each training spot:
# the encoder output (one 16-d row per train_df row, in train_df order) is the
# regression target for the main model.
with torch.no_grad():
    spatial_model.eval()
    spatial_tensor = torch.tensor(spatial_features, dtype=torch.float32).to(device)
    spatial_embeddings = spatial_model.encoder(spatial_tensor).cpu().numpy()  # shape (N,16)

##############################
# PART B: MAIN MODEL TO PREDICT SPATIAL EMBEDDINGS
##############################

# For this main model we assume that at test time we don't have cell abundances,
# so we use available inputs (e.g. coordinates and image patches from the H5 file).
# The target for training is the spatial embedding computed above.

# We'll build dataset classes that load images from the H5 file.

# Dataset for training main model (loading images from H5)
class MainMappingWithImageH5Dataset(Dataset):
    """
    For training: maps (x, y) coordinates and corresponding image patch (from the H5 file)
    to the precomputed spatial embedding.
    Expects a DataFrame with columns: "x", "y", and "slice_name".

    Each item is ``(coords, patch, target)``:
      * coords -- float32 tensor ``[x, y]``;
      * patch  -- ``transform`` applied to a ``patch_size`` x ``patch_size`` PIL crop
        around (x, y) of the slice's image;
      * target -- ``target_embeddings[idx]`` as a float32 tensor.

    The image of every slice referenced by ``df`` is read once from
    ``images/Train`` (``images/Test`` when ``train=False``) and kept in memory as
    a PIL RGB image. Accepted array layouts are (H, W), (H, W, 1|3|4) and
    channel-first (C, H, W), plus extra singleton axes. An alpha channel is
    dropped, and non-uint8 data is min-max rescaled to 0..255.
    """
    def __init__(self, df, target_embeddings, h5_file_path, patch_size=64, transform=None, train=True):
        self.df = df.reset_index(drop=True)
        self.targets = target_embeddings  # should be in the same order as df.
        self.patch_size = patch_size
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.h5_file_path = h5_file_path
        self.train = train
        self.images = {}
        group = "Train" if train else "Test"
        with h5py.File(self.h5_file_path, "r") as f:
            for slice_name in self.df['slice_name'].unique():
                img_array = self._prepare_image_array(np.array(f[f"images/{group}"][slice_name]), slice_name)
                # PIL infers "RGB" from a uint8 (H, W, 3) array; passing mode= is deprecated.
                self.images[slice_name] = Image.fromarray(img_array)

    @staticmethod
    def _prepare_image_array(img_array, slice_name):
        """Turn a raw slide image array into a contiguous uint8 (H, W, 3) array.

        Raises:
            ValueError: if the array cannot be read as a 1-, 3- or 4-channel image.
        """
        img_array = np.asarray(img_array)
        # Normalize and convert to uint8 if needed.
        if img_array.dtype != np.uint8:
            if img_array.dtype.kind in "biu":
                # Integer arithmetic can wrap (int8: 127 - (-128) overflows) and bool
                # subtraction raises; float64 holds these values exactly.
                img_array = img_array.astype(np.float64)
            img_array = img_array - img_array.min()
            if img_array.max() > 0:
                img_array = img_array / img_array.max()
            img_array = (img_array * 255).astype(np.uint8)
        if img_array.ndim > 3:
            img_array = np.squeeze(img_array)
        if img_array.ndim == 3 and img_array.shape[-1] not in (1, 3, 4) and img_array.shape[0] in (1, 3, 4):
            # Channel-first (C, H, W) -> channel-last (H, W, C).
            img_array = np.transpose(img_array, (1, 2, 0))
        if img_array.ndim == 3 and img_array.shape[-1] == 1:
            img_array = img_array[..., 0]
        elif img_array.ndim == 3 and img_array.shape[-1] == 4:
            # Drop the alpha channel.
            img_array = img_array[..., :3]
        if img_array.ndim == 2:
            img_array = np.stack([img_array]*3, axis=-1)
        if img_array.ndim != 3 or img_array.shape[-1] != 3:
            raise ValueError(f"Unexpected number of channels in image for slice {slice_name}: {img_array.shape}")
        return np.ascontiguousarray(img_array)

    def _crop_patch(self, image, x, y):
        """Crop a patch_size x patch_size window anchored around (x, y).

        Near the top/left border the window is shifted inward (clamped at 0).
        Near the bottom/right border it may extend past the image edge, which
        PIL's crop fills rather than shifting.
        """
        half_patch = self.patch_size // 2
        left = max(int(x) - half_patch, 0)
        upper = max(int(y) - half_patch, 0)
        return image.crop((left, upper, left + self.patch_size, upper + self.patch_size))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        coord = np.array([row['x'], row['y']], dtype=np.float32)
        patch = self._crop_patch(self.images[row['slice_name']], row['x'], row['y'])
        patch = self.transform(patch)
        target = self.targets[idx]
        return torch.tensor(coord, dtype=torch.float32), patch, torch.tensor(target, dtype=torch.float32)

# Dataset for test (similarly loads images from H5)
class TestMappingWithImageH5Dataset(Dataset):
    """
    For testing: maps (x, y) coordinates and corresponding image patch (from the H5 file).
    Expects a DataFrame with columns: "x", "y", and "slice_name".

    Each item is ``(coords, patch)``: coords is a float32 tensor ``[x, y]``, and
    patch is ``transform`` applied to a ``patch_size`` x ``patch_size`` PIL crop
    around (x, y) of the slice's image.

    The image of every slice referenced by ``df`` is read once from
    ``images/Test`` and kept in memory as a PIL RGB image. Accepted array
    layouts are (H, W), (H, W, 1|3|4) and channel-first (C, H, W), plus extra
    singleton axes. An alpha channel is dropped, and non-uint8 data is min-max
    rescaled to 0..255.
    """
    def __init__(self, df, h5_file_path, patch_size=64, transform=None):
        self.df = df.reset_index(drop=True)
        self.patch_size = patch_size
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.h5_file_path = h5_file_path
        self.images = {}
        group = "Test"
        with h5py.File(self.h5_file_path, "r") as f:
            for slice_name in self.df['slice_name'].unique():
                img_array = self._prepare_image_array(np.array(f[f"images/{group}"][slice_name]), slice_name)
                # PIL infers "RGB" from a uint8 (H, W, 3) array; passing mode= is deprecated.
                self.images[slice_name] = Image.fromarray(img_array)

    @staticmethod
    def _prepare_image_array(img_array, slice_name):
        """Turn a raw slide image array into a contiguous uint8 (H, W, 3) array.

        Raises:
            ValueError: if the array cannot be read as a 1-, 3- or 4-channel image.
        """
        img_array = np.asarray(img_array)
        if img_array.dtype != np.uint8:
            if img_array.dtype.kind in "biu":
                # Integer arithmetic can wrap (int8: 127 - (-128) overflows) and bool
                # subtraction raises; float64 holds these values exactly.
                img_array = img_array.astype(np.float64)
            img_array = img_array - img_array.min()
            if img_array.max() > 0:
                img_array = img_array / img_array.max()
            img_array = (img_array * 255).astype(np.uint8)
        if img_array.ndim > 3:
            img_array = np.squeeze(img_array)
        if img_array.ndim == 3 and img_array.shape[-1] not in (1, 3, 4) and img_array.shape[0] in (1, 3, 4):
            # Channel-first (C, H, W) -> channel-last (H, W, C).
            img_array = np.transpose(img_array, (1, 2, 0))
        if img_array.ndim == 3 and img_array.shape[-1] == 1:
            img_array = img_array[..., 0]
        elif img_array.ndim == 3 and img_array.shape[-1] == 4:
            # Drop the alpha channel.
            img_array = img_array[..., :3]
        if img_array.ndim == 2:
            img_array = np.stack([img_array]*3, axis=-1)
        if img_array.ndim != 3 or img_array.shape[-1] != 3:
            raise ValueError(f"Unexpected number of channels in image for slice {slice_name}: {img_array.shape}")
        return np.ascontiguousarray(img_array)

    def _crop_patch(self, image, x, y):
        """Crop a patch_size x patch_size window anchored around (x, y).

        Near the top/left border the window is shifted inward (clamped at 0).
        Near the bottom/right border it may extend past the image edge, which
        PIL's crop fills rather than shifting.
        """
        half_patch = self.patch_size // 2
        left = max(int(x) - half_patch, 0)
        upper = max(int(y) - half_patch, 0)
        return image.crop((left, upper, left + self.patch_size, upper + self.patch_size))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        coord = np.array([row['x'], row['y']], dtype=np.float32)
        patch = self._crop_patch(self.images[row['slice_name']], row['x'], row['y'])
        patch = self.transform(patch)
        return torch.tensor(coord, dtype=torch.float32), patch

# Define a main model that predicts the spatial embedding from coordinates and image patch.
class MainModelMappingWithImage(nn.Module):
    """Predict a spatial embedding from (x, y) coordinates plus an image patch.

    A two-stage CNN (conv -> ReLU -> 2x max-pool, twice) flattens the patch.
    The result is concatenated after the coordinates and fed to a 2-layer MLP
    that outputs ``embed_dim`` values.
    """

    def __init__(self, coord_input_dim=2, patch_channels=3, patch_size=64, embed_dim=16, hidden_dim=64):
        super().__init__()
        conv_layers = [
            nn.Conv2d(patch_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
        ]
        self.image_encoder = nn.Sequential(*conv_layers)
        # padding=1 keeps the spatial size through each conv; each pool floors it by 2.
        pooled_side = patch_size // 4
        fused_dim = coord_input_dim + 32 * pooled_side * pooled_side
        self.fc = nn.Sequential(
            nn.Linear(fused_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, coords, patches):
        fused = torch.cat([coords, self.image_encoder(patches)], dim=1)
        return self.fc(fused)

def train_main_mapping_with_image(model, dataloader, num_epochs=20, lr=0.001, device='cpu'):
    """Regress ``model(coords, patches)`` onto target embeddings with MSE and Adam.

    Args:
        model: module called as ``model(coords, patches)``.
        dataloader: yields ``(coords, patches, target_emb)`` batches;
            ``dataloader.dataset`` must support ``len`` (for the printed mean loss).
        num_epochs: number of passes over the data.
        lr: Adam learning rate.
        device: device the model and every batch tensor are moved to.

    Returns:
        The same ``model`` object, trained in place and left in train mode.
    """
    model.to(device)
    adam = optim.Adam(model.parameters(), lr=lr)
    mse = nn.MSELoss()
    model.train()
    for epoch_no in range(1, num_epochs + 1):
        running = 0.0
        for batch in dataloader:
            coords, patches, target_emb = (t.to(device) for t in batch)
            batch_loss = mse(model(coords, patches), target_emb)
            adam.zero_grad()
            batch_loss.backward()
            adam.step()
            # Weight by batch size so the printed figure is a per-sample mean.
            running += batch_loss.item() * coords.size(0)
        mean_loss = running / len(dataloader.dataset)
        print(f"Main Mapping With Image Epoch {epoch_no}/{num_epochs}, Loss: {mean_loss:.4f}")
    return model

# Prepare training dataset for the main model.
patch_size = 64
# Use the H5-based dataset for training. (Images from "images/Train")
# Row i of spatial_embeddings was computed from row i of train_df, so the targets
# line up with the dataset rows.
main_train_dataset = MainMappingWithImageH5Dataset(
    train_df, spatial_embeddings, h5_file_path, patch_size=patch_size, transform=transforms.ToTensor(), train=True
)
main_train_loader = DataLoader(main_train_dataset, batch_size=32, shuffle=True)

# embed_dim=16 matches the spatial autoencoder's latent size, so predictions can be
# passed straight to spatial_model.decoder at inference time.
main_model_img = MainModelMappingWithImage(coord_input_dim=2, patch_channels=3, patch_size=patch_size, embed_dim=16, hidden_dim=64)
print("Training Main Mapping With Image Model...")
main_model_img = train_main_mapping_with_image(main_model_img, main_train_loader, num_epochs=100, lr=0.001, device=device)

##############################
# PART C: INFERENCE AND SUBMISSION
##############################

# Load test spots for slide "S_7" from the H5 file.
with h5py.File(h5_file_path, "r") as f:
    test_spots = f["spots/Test"]
    test_array = np.array(test_spots["S_7"])
    test_df = pd.DataFrame(test_array)
# Test file has three columns: x, y, Test_set. Drop the third column.
if test_df.shape[1] == 3:
    test_df.columns = ["x", "y", "Test_set"]
    # .copy() detaches the column subset from the original frame, so adding
    # "slice_name" below does not trigger pandas' SettingWithCopyWarning.
    test_df = test_df[["x", "y"]].copy()
# Add slice_name column for image lookup.
test_df["slice_name"] = "S_7"
print("Test data shape:", test_df.shape)

# Create the test dataset using the H5 file (images from "images/Test").
# shuffle=False keeps predictions in test_df row order, which the ID column below relies on.
test_dataset = TestMappingWithImageH5Dataset(test_df, h5_file_path, patch_size=patch_size, transform=transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Use the main model to predict the spatial embedding for test spots.
predicted_embeddings_list = []
main_model_img.eval()
with torch.no_grad():
    for coords, patches in test_loader:
        coords = coords.to(device)
        patches = patches.to(device)
        pred_emb = main_model_img(coords, patches)
        predicted_embeddings_list.append(pred_emb)
    predicted_embeddings = torch.cat(predicted_embeddings_list, dim=0)

# Use the foundation decoder (spatial_model.decoder) to reconstruct the full 70-d vector.
# Here, we assume that the first 35 values correspond to the spot's own rank vector.
spatial_model.eval()
with torch.no_grad():
    pred_reconstruction = spatial_model.decoder(predicted_embeddings)
    pred_reconstruction = pred_reconstruction.cpu().numpy()

# For submission, we extract the first 35 columns (the predicted ranks for the spot).
# NOTE(review): these are descending ranks (1 = most abundant, as in the training
# ranks), written under the abundance column names C1..C35. If the grader treats
# the values as abundance scores (e.g. rank correlation against true abundances),
# the ordering is inverted -- confirm the expected submission format.
predicted_ranks = pred_reconstruction[:, :35]
submission_df = pd.DataFrame(predicted_ranks, columns=cell_types)
# test_df keeps the 0..n-1 RangeIndex it got from the H5 array, so IDs are row positions.
submission_df.insert(0, 'ID', test_df.index)
submission_file = "submission.csv"
submission_df.to_csv(submission_file, index=False)
print(f"Submission file '{submission_file}' created!")


Training data shape: (8349, 38)
Training Spatial Foundation Autoencoder...
Spatial Autoencoder Epoch 1/200, Loss: 40.4327
Spatial Autoencoder Epoch 2/200, Loss: 13.2733
Spatial Autoencoder Epoch 3/200, Loss: 11.2352
Spatial Autoencoder Epoch 4/200, Loss: 9.4801
Spatial Autoencoder Epoch 5/200, Loss: 8.1366
Spatial Autoencoder Epoch 6/200, Loss: 7.1390
Spatial Autoencoder Epoch 7/200, Loss: 6.5800
Spatial Autoencoder Epoch 8/200, Loss: 6.1916
Spatial Autoencoder Epoch 9/200, Loss: 5.9855
Spatial Autoencoder Epoch 10/200, Loss: 5.8222
Spatial Autoencoder Epoch 11/200, Loss: 5.6645
Spatial Autoencoder Epoch 12/200, Loss: 5.5342
Spatial Autoencoder Epoch 13/200, Loss: 5.4220
Spatial Autoencoder Epoch 14/200, Loss: 5.3279
Spatial Autoencoder Epoch 15/200, Loss: 5.2535
Spatial Autoencoder Epoch 16/200, Loss: 5.1719
Spatial Autoencoder Epoch 17/200, Loss: 5.1097
Spatial Autoencoder Epoch 18/200, Loss: 5.0419
Spatial Autoencoder Epoch 19/200, Loss: 4.9750
Spatial Autoencoder Epoch 20/200, Loss