In [1]:
import os
import h5py
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms

# --------------------------------------------------------------------
# 1. Load Training Data (spots/Train) with Slice Information
# --------------------------------------------------------------------
h5_file_path = "/kaggle/input/el-hackathon-2025/elucidata_ai_challenge_data.h5"

# One DataFrame per training slide, each tagged with the slide's key so that
# rows can later be matched back to that slide's image.
train_spot_tables = {}
with h5py.File(h5_file_path, "r") as f:
    train_spots = f["spots/Train"]
    for slide in train_spots.keys():
        slide_table = pd.DataFrame(np.array(train_spots[slide]))
        slide_table["slice_name"] = slide
        train_spot_tables[slide] = slide_table

# Stack every slide into a single table with a fresh 0..N-1 index.
train_df = pd.concat(list(train_spot_tables.values()), ignore_index=True)

# NOTE(review): the rename below is purely positional. It assumes each slide's
# columns are x, y, then 35 cell abundances -- confirm against the H5 schema.
cell_types = ["C" + str(k) for k in range(1, 36)]
train_df.columns = ["x", "y", *cell_types, "slice_name"]
print("Training data shape:", train_df.shape)

# --------------------------------------------------------------------
# 2. Define the Foundation Autoencoder (learns embeddings from cell abundances)
# --------------------------------------------------------------------
class FoundationAutoencoder(nn.Module):
    """Small MLP autoencoder that compresses a cell-abundance vector.

    Args:
        input_dim: Width of each input (and reconstructed) vector.
        embed_dim: Size of the bottleneck embedding.
        hidden_dim: Width of the single hidden layer on each side.

    ``forward`` returns ``(embedding, reconstruction)``.
    """

    def __init__(self, input_dim=35, embed_dim=16, hidden_dim=64):
        super().__init__()
        # input -> hidden -> embedding
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        # Mirror image of the encoder: embedding -> hidden -> input space.
        # The last layer has no activation, so reconstructions are unbounded.
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

    def forward(self, x):
        """Encode ``x`` and decode it back; return both results."""
        embedding = self.encoder(x)
        return embedding, self.decoder(embedding)

# Dataset for foundation autoencoder training
class FoundationDataset(Dataset):
    """Map-style dataset over the rows of an abundance matrix.

    Args:
        abundances: Array-like of shape (n_spots, n_features); it is copied
            into a float32 tensor once, at construction time.

    Each item is one row as a float32 tensor.
    """

    def __init__(self, abundances):
        # torch.tensor always copies, so later edits to ``abundances`` do not leak in.
        self.data = torch.tensor(abundances, dtype=torch.float32)

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

def train_foundation_autoencoder(model, dataloader, num_epochs=20, lr=0.001, device='cpu'):
    """Fit an autoencoder to reconstruct its inputs under MSE loss with Adam.

    Args:
        model: Module whose ``forward(batch)`` returns ``(embedding, reconstruction)``.
        dataloader: Yields input tensors; ``dataloader.dataset`` must support ``len``.
        num_epochs: Number of full passes over ``dataloader``.
        lr: Adam learning rate.
        device: Device the model and every batch are moved to.

    Returns:
        The same ``model`` object, trained in place and left in train mode.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    model.train()
    for epoch_idx in range(1, num_epochs + 1):
        # Sum of per-batch mean losses weighted by batch size -> per-sample mean.
        weighted_sum = 0.0
        for batch in dataloader:
            batch = batch.to(device)
            _, reconstruction = model(batch)
            batch_loss = criterion(reconstruction, batch)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            weighted_sum += batch_loss.item() * batch.size(0)
        mean_loss = weighted_sum / len(dataloader.dataset)
        print(f"Foundation Autoencoder Epoch {epoch_idx}/{num_epochs}, Loss: {mean_loss:.4f}")
    return model

# Create dataset and dataloader for the foundation autoencoder, then fit it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
foundation_dataset = FoundationDataset(train_df[cell_types].to_numpy())
foundation_loader = DataLoader(foundation_dataset, batch_size=32, shuffle=True)

foundation_model = FoundationAutoencoder(input_dim=35, embed_dim=16, hidden_dim=64)
print("Training Foundation Autoencoder...")
foundation_model = train_foundation_autoencoder(
    foundation_model,
    foundation_loader,
    num_epochs=20,
    lr=0.001,
    device=device,
)

# --------------------------------------------------------------------
# 3. Precompute Foundation Embeddings for All Training Spots
# --------------------------------------------------------------------
# Encode every training spot in one pass with the trained autoencoder. The
# resulting NumPy array is row-aligned with train_df and serves as the
# regression target for the image model.
foundation_model.eval()
with torch.no_grad():
    train_abundances = torch.tensor(
        train_df[cell_types].to_numpy(), dtype=torch.float32
    ).to(device)
    train_embeddings = foundation_model(train_abundances)[0].cpu().numpy()

# --------------------------------------------------------------------
# 4. Define Dataset Classes that Load Images from the H5 File
# --------------------------------------------------------------------
class MainMappingWithImageH5Dataset(Dataset):
    """
    For training: Maps (x, y) coordinates and the corresponding image patch
    (loaded from the H5 file) to the precomputed foundation embedding.
    Expects a DataFrame with columns: "x", "y", and "slice_name".

    Args:
        df: Spot table; after ``reset_index`` row ``i`` must correspond to
            ``embeddings[i]``.
        embeddings: Indexable per-spot targets with ``len(embeddings) == len(df)``.
        h5_file_path: Path to the H5 file holding ``images/Train`` / ``images/Test``.
        patch_size: Side length, in pixels, of the square crop around each spot.
        transform: Callable applied to each PIL patch (default ``ToTensor()``).
        train: Read images from ``images/Train`` if True, else ``images/Test``.

    Each item is ``(coord, patch, target)``: ``coord`` is a float32 tensor
    ``[x, y]``, ``patch`` is ``transform(crop)``, and ``target`` is
    ``embeddings[idx]`` as a float32 tensor.

    Raises:
        ValueError: if ``embeddings`` and ``df`` differ in length, or a slice
            image cannot be coerced to an (H, W, 3) array.
    """

    def __init__(self, df, embeddings, h5_file_path, patch_size=64, transform=None, train=True):
        self.df = df.reset_index(drop=True)
        # Targets are looked up by position, so a length mismatch would silently
        # pair spots with the wrong embeddings (or fail mid-epoch).
        if len(embeddings) != len(self.df):
            raise ValueError(
                f"embeddings has {len(embeddings)} rows but df has {len(self.df)}; "
                "they must be aligned row-for-row."
            )
        self.embeddings = embeddings  # Precomputed embeddings in the same order as df.
        self.patch_size = patch_size
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.h5_file_path = h5_file_path
        self.train = train
        # Cache per-row values once; materialising a DataFrame row per item is slow.
        self._xy = self.df[["x", "y"]].to_numpy(dtype=np.float64)
        self._slice_names = self.df["slice_name"].tolist()
        self.images = {}
        group = "Train" if train else "Test"
        # Load each unique slice image once and keep it in memory as a PIL image.
        with h5py.File(self.h5_file_path, "r") as f:
            for slice_name in self.df["slice_name"].unique():
                img_array = np.array(f[f"images/{group}"][slice_name])
                rgb = self._to_uint8_rgb(img_array, slice_name)
                # The mode ("RGB") is inferred from the uint8 (H, W, 3) array, so
                # it is not passed explicitly (mode= is deprecated in Pillow 11.3).
                self.images[slice_name] = Image.fromarray(rgb)

    @staticmethod
    def _to_uint8_rgb(img_array, slice_name=None):
        """Coerce a raw slice image array to a uint8 (H, W, 3) array.

        Non-uint8 input is min-max rescaled over the whole array to 0..255.
        Arrays with more than three dims have singleton dims squeezed out, and
        2-D (grayscale) input is replicated into three channels.

        Raises:
            ValueError: if the result is not a 3-D array with 3 channels.
        """
        if img_array.dtype != np.uint8:
            img_array = img_array - img_array.min()
            if img_array.max() > 0:
                img_array = img_array / img_array.max()
            img_array = (img_array * 255).astype(np.uint8)
        if img_array.ndim > 3:
            img_array = np.squeeze(img_array)
        if img_array.ndim == 2:
            img_array = np.stack([img_array] * 3, axis=-1)
        if img_array.ndim != 3 or img_array.shape[-1] != 3:
            raise ValueError(f"Unexpected number of channels in image for slice {slice_name}: {img_array.shape}")
        return img_array

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        x_val, y_val = self._xy[idx]
        coord = np.array([x_val, y_val], dtype=np.float32)
        image = self.images[self._slice_names[idx]]
        x, y = int(x_val), int(y_val)
        half_patch = self.patch_size // 2
        # The top-left corner is clamped at 0 (shifting the box inward). A box that
        # runs past the right/bottom edge is zero-filled by PIL's crop, so the
        # patch is always patch_size x patch_size.
        left = max(x - half_patch, 0)
        upper = max(y - half_patch, 0)
        patch = image.crop((left, upper, left + self.patch_size, upper + self.patch_size))
        patch = self.transform(patch)
        target = self.embeddings[idx]
        return torch.tensor(coord, dtype=torch.float32), patch, torch.tensor(target, dtype=torch.float32)

class TestMappingWithImageH5Dataset(Dataset):
    """
    For testing: Maps (x, y) coordinates and the corresponding image patch
    (loaded from the H5 file) to be used for prediction.
    Expects a DataFrame with columns: "x", "y", and "slice_name".

    Images are always read from the ``images/Test`` group. Each item is
    ``(coord, patch)``: ``coord`` is a float32 tensor ``[x, y]`` and ``patch``
    is ``transform`` applied to the ``patch_size`` x ``patch_size`` crop.

    Raises:
        ValueError: if a slice image cannot be coerced to an (H, W, 3) array.
    """

    def __init__(self, df, h5_file_path, patch_size=64, transform=None):
        self.df = df.reset_index(drop=True)
        self.patch_size = patch_size
        self.transform = transform if transform is not None else transforms.ToTensor()
        self.h5_file_path = h5_file_path
        # Cache per-row values once; materialising a DataFrame row per item is slow.
        self._xy = self.df[["x", "y"]].to_numpy(dtype=np.float64)
        self._slice_names = self.df["slice_name"].tolist()
        self.images = {}
        group = "Test"
        with h5py.File(self.h5_file_path, "r") as f:
            for slice_name in self.df["slice_name"].unique():
                img_array = np.array(f[f"images/{group}"][slice_name])
                rgb = self._to_uint8_rgb(img_array, slice_name)
                # The mode ("RGB") is inferred from the uint8 (H, W, 3) array, so
                # it is not passed explicitly (mode= is deprecated in Pillow 11.3).
                self.images[slice_name] = Image.fromarray(rgb)

    @staticmethod
    def _to_uint8_rgb(img_array, slice_name=None):
        """Coerce a raw slice image array to a uint8 (H, W, 3) array.

        Non-uint8 input is min-max rescaled over the whole array to 0..255.
        Arrays with more than three dims have singleton dims squeezed out, and
        2-D (grayscale) input is replicated into three channels.

        Raises:
            ValueError: if the result is not a 3-D array with 3 channels.
        """
        if img_array.dtype != np.uint8:
            img_array = img_array - img_array.min()
            if img_array.max() > 0:
                img_array = img_array / img_array.max()
            img_array = (img_array * 255).astype(np.uint8)
        if img_array.ndim > 3:
            img_array = np.squeeze(img_array)
        if img_array.ndim == 2:
            img_array = np.stack([img_array] * 3, axis=-1)
        if img_array.ndim != 3 or img_array.shape[-1] != 3:
            raise ValueError(f"Unexpected number of channels in image for slice {slice_name}: {img_array.shape}")
        return img_array

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        x_val, y_val = self._xy[idx]
        coord = np.array([x_val, y_val], dtype=np.float32)
        image = self.images[self._slice_names[idx]]
        x, y = int(x_val), int(y_val)
        half_patch = self.patch_size // 2
        # The top-left corner is clamped at 0 (shifting the box inward). A box that
        # runs past the right/bottom edge is zero-filled by PIL's crop, so the
        # patch is always patch_size x patch_size.
        left = max(x - half_patch, 0)
        upper = max(y - half_patch, 0)
        patch = image.crop((left, upper, left + self.patch_size, upper + self.patch_size))
        patch = self.transform(patch)
        return torch.tensor(coord, dtype=torch.float32), patch

# --------------------------------------------------------------------
# 5. Define the Main Model that Uses Image Patches and Coordinates
# --------------------------------------------------------------------
class MainModelMappingWithImage(nn.Module):
    """Predict a foundation embedding from spot coordinates plus an image patch.

    Args:
        coord_input_dim: Number of coordinate features per spot.
        patch_channels: Channels in each input patch.
        patch_size: Side length of the (square) input patches.
        embed_dim: Size of the predicted embedding.
        hidden_dim: Width of the hidden fully-connected layer.

    ``forward(coords, image_patches)`` concatenates ``coords`` with the
    flattened CNN features of ``image_patches`` and maps them to ``embed_dim``.
    """

    def __init__(self, coord_input_dim=2, patch_channels=3, patch_size=64, embed_dim=16, hidden_dim=64):
        super().__init__()
        # Two conv -> ReLU -> 2x2 max-pool stages. The 3x3 convs use padding=1,
        # so only the pools shrink H and W (each halves them, flooring).
        self.image_encoder = nn.Sequential(
            nn.Conv2d(patch_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
        )
        # After two floor-halvings the spatial side is patch_size // 4.
        pooled_side = patch_size // 4
        flat_features = 32 * pooled_side ** 2
        self.fc = nn.Sequential(
            nn.Linear(coord_input_dim + flat_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, coords, image_patches):
        """Return the predicted embedding for a batch of (coords, patches)."""
        features = torch.cat((coords, self.image_encoder(image_patches)), dim=1)
        return self.fc(features)

def train_main_mapping_with_image(model, dataloader, num_epochs=20, lr=0.001, device='cpu'):
    """Regress the model's predicted embeddings onto target embeddings (MSE, Adam).

    Args:
        model: Module called as ``model(coords, patches)``; returns predicted embeddings.
        dataloader: Yields ``(coords, patches, target_emb)`` batches;
            ``dataloader.dataset`` must support ``len``.
        num_epochs: Number of full passes over ``dataloader``.
        lr: Adam learning rate.
        device: Device the model and every batch are moved to.

    Returns:
        The same ``model`` object, trained in place and left in train mode.
    """
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    model.train()
    for epoch_idx in range(1, num_epochs + 1):
        # Batch-size-weighted sum of batch losses -> per-sample mean at epoch end.
        weighted_sum = 0.0
        for batch in dataloader:
            coords, patches, target_emb = (t.to(device) for t in batch)
            loss = criterion(model(coords, patches), target_emb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            weighted_sum += loss.item() * coords.size(0)
        mean_loss = weighted_sum / len(dataloader.dataset)
        print(f"Main Mapping With Image Epoch {epoch_idx}/{num_epochs}, Loss: {mean_loss:.4f}")
    return model

# --------------------------------------------------------------------
# 6. Create Dataset and Train the Main Mapping with Image Model (Training)
# --------------------------------------------------------------------
patch_size = 64

# Training pairs: (coords, image patch from images/Train) -> precomputed
# foundation embedding for the same row of train_df.
main_image_dataset = MainMappingWithImageH5Dataset(
    train_df,
    train_embeddings,
    h5_file_path,
    patch_size=patch_size,
    transform=transforms.ToTensor(),
    train=True,
)
main_image_loader = DataLoader(main_image_dataset, batch_size=32, shuffle=True)

main_model_img = MainModelMappingWithImage(
    coord_input_dim=2,
    patch_channels=3,
    patch_size=patch_size,
    embed_dim=16,
    hidden_dim=64,
)
print("Training Main Mapping With Image Model...")
main_model_img = train_main_mapping_with_image(
    main_model_img, main_image_loader, num_epochs=20, lr=0.001, device=device
)

# --------------------------------------------------------------------
# 7. Inference on Test Data and Submission Creation
# --------------------------------------------------------------------
# Only slide "S_7" is predicted; its spots come from spots/Test and its image
# from images/Test.
test_slice_name = "S_7"

with h5py.File(h5_file_path, "r") as f:
    test_spots = f["spots/Test"]
    test_array = np.array(test_spots[test_slice_name])
    test_df = pd.DataFrame(test_array)

# Test file has three columns: x, y, Test_set. Drop the third column.
# .copy() makes test_df an independent frame, so adding "slice_name" below does
# not trigger pandas' SettingWithCopyWarning on a column-selection result.
if test_df.shape[1] == 3:
    test_df.columns = ["x", "y", "Test_set"]
    test_df = test_df[["x", "y"]].copy()
elif test_df.shape[1] == 2:
    # Coordinates only: name them so the dataset can find "x" and "y".
    test_df.columns = ["x", "y"]
# Add slice_name column so the dataset knows which image to crop from.
test_df["slice_name"] = test_slice_name
print("Test data shape:", test_df.shape)

# Create the test dataset (loading images from H5). shuffle=False keeps
# predictions in test_df row order.
test_dataset = TestMappingWithImageH5Dataset(test_df, h5_file_path, patch_size=patch_size, transform=transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Predict foundation embeddings from test spots using the main mapping model with image.
predicted_embeddings_list = []
main_model_img.eval()
with torch.no_grad():
    for coords, patches in test_loader:
        coords = coords.to(device)
        patches = patches.to(device)
        pred_emb = main_model_img(coords, patches)
        predicted_embeddings_list.append(pred_emb)
    predicted_embeddings = torch.cat(predicted_embeddings_list, dim=0)

# Use the foundation decoder to convert predicted embeddings into cell abundance predictions.
# NOTE(review): the decoder's final Linear has no activation, so predictions
# are unconstrained (may be negative) -- confirm whether the metric expects that.
foundation_model.eval()
with torch.no_grad():
    predicted_abundances = foundation_model.decoder(predicted_embeddings)
    predicted_abundances = predicted_abundances.cpu().numpy()

# Create submission DataFrame and save CSV. ID is test_df's index, i.e. the
# 0..N-1 row position of each spot in the loaded test array.
submission_df = pd.DataFrame(predicted_abundances, columns=cell_types)
submission_df.insert(0, 'ID', test_df.index)
submission_file = "submission.csv"
submission_df.to_csv(submission_file, index=False)
print(f"Submission file '{submission_file}' created!")


Training data shape: (8349, 38)
Training Foundation Autoencoder...
Foundation Autoencoder Epoch 1/20, Loss: 0.3169
Foundation Autoencoder Epoch 2/20, Loss: 0.0383
Foundation Autoencoder Epoch 3/20, Loss: 0.0247
Foundation Autoencoder Epoch 4/20, Loss: 0.0189
Foundation Autoencoder Epoch 5/20, Loss: 0.0148
Foundation Autoencoder Epoch 6/20, Loss: 0.0124
Foundation Autoencoder Epoch 7/20, Loss: 0.0096
Foundation Autoencoder Epoch 8/20, Loss: 0.0079
Foundation Autoencoder Epoch 9/20, Loss: 0.0068
Foundation Autoencoder Epoch 10/20, Loss: 0.0061
Foundation Autoencoder Epoch 11/20, Loss: 0.0059
Foundation Autoencoder Epoch 12/20, Loss: 0.0057
Foundation Autoencoder Epoch 13/20, Loss: 0.0050
Foundation Autoencoder Epoch 14/20, Loss: 0.0046
Foundation Autoencoder Epoch 15/20, Loss: 0.0043
Foundation Autoencoder Epoch 16/20, Loss: 0.0041
Foundation Autoencoder Epoch 17/20, Loss: 0.0038
Foundation Autoencoder Epoch 18/20, Loss: 0.0036
Foundation Autoencoder Epoch 19/20, Loss: 0.0033
Foundation 