# Foundation model approach

This notebook trains a Foundation model on the Training data, in particular the ranks of the cell abundancies

TBD

## Load Data

We create a big table containing the Cell abundancy across all the spots across all slides.

TODO: add slide_name column to dataframe

In [1]:
import h5py
import numpy as np
import pandas as pd
from scipy.spatial import KDTree
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

h5_file_path = "/kaggle/input/el-hackathon-2025/elucidata_ai_challenge_data.h5"

# Load training data from "spots/Train"
with h5py.File(h5_file_path, "r") as f:
    train_spots = f["spots/Train"]
    # Each slide is loaded into its own DataFrame
    train_spot_tables = {slide: pd.DataFrame(np.array(train_spots[slide])) 
                           for slide in train_spots.keys()}
train_df = pd.concat(train_spot_tables.values(), ignore_index=True)

# Assume the first two columns are coordinates and the rest (columns 2+) are cell abundances.
cell_types = [f"C{i+1}" for i in range(35)]
train_df.columns = ["x", "y"] + cell_types
print("Training data shape:", train_df.shape)


Training data shape: (8349, 37)


In [2]:
train_df

Unnamed: 0,x,y,C1,C2,C3,C4,C5,C6,C7,C8,...,C26,C27,C28,C29,C30,C31,C32,C33,C34,C35
0,1554,1297,0.014401,0.057499,0.022033,0.001704,0.533992,1.511707,0.015313,0.020029,...,0.001010,2.068237,0.121361,0.007344,0.000017,0.036891,0.035934,0.118937,0.001472,0.050057
1,462,1502,0.116196,0.197176,0.110600,0.042614,5.587681,0.006885,0.096346,0.001711,...,0.000692,0.014442,0.000238,0.024071,0.000023,0.217589,0.100662,0.004027,0.004122,0.049491
2,1488,1548,0.133284,0.035880,0.061352,0.003073,1.104479,0.009174,0.009175,0.000114,...,0.000096,0.149792,0.001401,0.000699,0.000009,0.024491,0.018810,0.004171,0.000425,0.015348
3,1725,1182,0.087715,0.235223,0.090382,0.013902,8.760482,0.140912,0.188859,0.010154,...,0.001964,0.142549,0.002036,0.047165,0.000022,0.180372,0.202981,0.003709,0.001845,0.116022
4,581,1113,0.128468,0.066399,0.098982,0.047022,3.425771,0.001009,0.026881,0.000468,...,0.000072,0.005920,0.000048,0.006359,0.000585,0.052661,0.032168,0.000107,0.000107,0.013103
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8344,1173,842,0.008615,0.052352,0.009905,0.003060,5.230128,0.000176,0.040653,0.000188,...,0.000693,0.006530,0.000047,0.011358,0.000007,0.037057,0.018493,0.000290,0.000296,0.010812
8345,1237,775,0.034781,0.028511,0.031475,0.001812,1.121948,0.000052,0.008572,0.001044,...,0.000500,0.000060,0.000802,0.020970,0.000100,0.020329,0.011358,0.000215,0.000113,0.004115
8346,903,953,0.000515,0.011848,0.001639,0.000039,0.104526,0.000024,0.001327,0.000014,...,0.000145,0.000041,0.000010,0.000049,0.000019,0.007001,0.002478,0.041246,0.000018,0.000621
8347,954,1310,0.009845,0.076963,0.013963,0.001142,5.819259,0.000598,0.073316,0.000391,...,0.000048,0.032145,0.000371,0.019168,0.000122,0.068377,0.031485,0.000532,0.000437,0.018696


## Compute Smoothed Ranks with KDTree

Compute a dataset of smoothed ranks. For every spot of a slide, look at the nearby spots, and compute the average rank for every cell type.

This step assumes that cells of the same type may tend to be close together on the slide. This may make biological sense for most cell types, but may be wrong for others.

In [3]:

def compute_smoothed_ranks(df, radius=100):
    print("Computing smoothed ranks from training data...")
    # Convert the DataFrame to long format for rank computation.
    long_df = df.melt(id_vars=["x", "y"], var_name="cell_type", value_name="abundance")
    # Compute rank per (x,y) location (higher abundance gets a lower rank number).
    long_df["rank"] = long_df.groupby(["x", "y"])["abundance"].rank(method="dense", ascending=False)
    
    # Get unique coordinates and build a KDTree.
    coords = long_df[["x", "y"]].drop_duplicates().values
    tree = KDTree(coords)
    
    # Set the index only once outside the loop for efficiency.
    long_df_indexed = long_df.set_index(["x", "y"])
    
    smoothed = []
    for (x, y) in tqdm(coords, desc="Smoothing ranks"):
        idx = tree.query_ball_point([x, y], r=radius)
        neighbor_coords = [tuple(coords[j]) for j in idx]
        spot_neighbors = long_df_indexed.loc[neighbor_coords]
        avg_ranks = spot_neighbors.groupby("cell_type")["rank"].mean()
        smoothed.append(avg_ranks)
    
    smoothed_df = pd.DataFrame(smoothed).reset_index(drop=True)
    smoothed_df["x"] = coords[:, 0]
    smoothed_df["y"] = coords[:, 1]
    print("Smoothed ranks computed. Shape:", smoothed_df.shape)
    return smoothed_df

# Compute smoothed ranks for training data.
smoothed_df = compute_smoothed_ranks(train_df, radius=100)
smoothed_df

Computing smoothed ranks from training data...


Smoothing ranks: 100%|██████████| 8341/8341 [05:42<00:00, 24.37it/s]


Smoothed ranks computed. Shape: (8341, 37)


cell_type,C1,C10,C11,C12,C13,C14,C15,C16,C17,C18,...,C34,C35,C4,C5,C6,C7,C8,C9,x,y
0,9.534722,23.055556,20.493056,19.166667,27.270833,16.340278,18.965278,12.187500,12.416667,14.972222,...,23.569444,15.138889,21.152778,14.319444,22.041667,17.125000,25.680556,20.506944,1554,1297
1,8.873786,19.300971,26.961165,16.718447,28.485437,18.563107,15.466019,9.592233,20.038835,12.242718,...,18.174757,12.737864,16.155340,12.310680,20.873786,15.970874,27.058252,20.466019,462,1502
2,9.402778,22.375000,24.583333,17.444444,29.333333,19.277778,14.416667,4.152778,16.736111,16.513889,...,26.263889,15.652778,20.013889,3.319444,19.555556,14.402778,28.430556,12.763889,1488,1548
3,14.097015,23.253731,20.686567,20.119403,25.708955,16.007463,17.253731,12.619403,11.694030,16.432836,...,25.440299,14.216418,24.761194,14.865672,24.276119,13.641791,23.649254,20.194030,1725,1182
4,9.696203,28.392405,25.715190,20.221519,24.607595,18.392405,18.620253,9.645570,12.113924,14.436709,...,23.924051,14.734177,17.835443,6.462025,25.145570,13.727848,26.120253,19.512658,581,1113
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
8336,10.653333,25.803333,24.106667,19.333333,27.266667,17.346667,16.923333,8.373333,13.153333,15.110000,...,23.480000,14.396667,21.063333,10.650000,22.553333,16.050000,25.053333,19.210000,1173,842
8337,8.789030,25.257384,24.843882,19.907173,27.489451,18.109705,16.518987,6.729958,14.337553,15.232068,...,24.189873,15.616034,18.426160,10.860759,19.210970,17.493671,26.080169,15.696203,1237,775
8338,12.087248,29.624161,24.687919,18.845638,27.761745,15.909396,17.620805,8.043624,9.832215,14.963087,...,23.882550,14.140940,22.171141,9.281879,21.637584,15.281879,25.298658,18.664430,903,953
8339,8.768116,28.586957,23.637681,19.605072,26.188406,17.648551,19.097826,10.065217,13.590580,13.239130,...,22.028986,13.960145,18.398551,10.072464,23.268116,15.641304,25.184783,21.184783,954,1310


# 3. Create PyTorch Datasets Using the Challenge Data


In [4]:


# ----------------------------------------------------
# ----------------------------------------------------
class FoundationDataset(Dataset):
    """
    Dataset for pretraining the foundation model.
    Each sample consists of (x, y) coordinates and the corresponding smoothed ranks.
    """
    def __init__(self, coords, smoothed_ranks):
        self.coords = torch.tensor(coords, dtype=torch.float32)
        self.smoothed_ranks = torch.tensor(smoothed_ranks, dtype=torch.float32)

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        return self.coords[idx], self.smoothed_ranks[idx]

class MainDataset(Dataset):
    """
    Dataset for training the main model.
    Each sample consists of (x, y) coordinates and the raw cell abundances.
    """
    def __init__(self, coords, abundances):
        self.coords = torch.tensor(coords, dtype=torch.float32)
        self.abundances = torch.tensor(abundances, dtype=torch.float32)

    def __len__(self):
        return len(self.coords)

    def __getitem__(self, idx):
        return self.coords[idx], self.abundances[idx]

# For the foundation model, use the unique coordinates from smoothed_df.
foundation_coords = smoothed_df[['x','y']].values
foundation_targets = smoothed_df[cell_types].values

# For the main model, use all training spots (which might include duplicate coordinates).
main_coords = train_df[['x','y']].values
main_targets = train_df[cell_types].values

foundation_dataset = FoundationDataset(foundation_coords, foundation_targets)
main_dataset = MainDataset(main_coords, main_targets)

# DataLoaders
batch_size = 32
foundation_loader = DataLoader(foundation_dataset, batch_size=batch_size, shuffle=True)
main_loader = DataLoader(main_dataset, batch_size=batch_size, shuffle=True)



## Define the Foundation Model and the main model to predict the challenge

In [5]:

# ------------------------------
# 4. Define the Foundation and Main Models
# ------------------------------
class FoundationModel(nn.Module):
    """
    A network that learns to predict the smoothed ranks (latent spatial features)
    from the (x, y) coordinates.
    """
    def __init__(self, input_dim=2, hidden_dim=64, output_dim=35):
        super(FoundationModel, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)

class MainModel(nn.Module):
    """
    Main model that predicts cell abundances.
    It augments the original (x, y) coordinates with features from the pretrained foundation model.
    """
    def __init__(self, foundation_model, input_dim=2, hidden_dim=64, output_dim=35, freeze_foundation=True):
        super(MainModel, self).__init__()
        self.foundation_model = foundation_model
        if freeze_foundation:
            for param in self.foundation_model.parameters():
                param.requires_grad = False
        # The combined input is (x, y) plus the foundation model's output.
        combined_input_dim = input_dim + output_dim
        self.net = nn.Sequential(
            nn.Linear(combined_input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        with torch.no_grad():
            foundation_features = self.foundation_model(x)
        x_combined = torch.cat([x, foundation_features], dim=1)
        return self.net(x_combined)



## Define Training Functions

In [6]:

def train_foundation_model(model, dataloader, num_epochs=20, lr=0.001, device='cpu'):
    model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for coords, targets in dataloader:
            coords, targets = coords.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(coords)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * coords.size(0)
        print(f"Foundation Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(dataloader.dataset):.4f}")
    return model

def train_main_model(model, dataloader, num_epochs=20, lr=0.001, device='cpu'):
    model.to(device)
    # Only optimize parameters that require gradients (i.e. main model's parameters).
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    criterion = nn.MSELoss()
    model.train()
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for coords, targets in dataloader:
            coords, targets = coords.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(coords)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item() * coords.size(0)
        print(f"Main Model Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(dataloader.dataset):.4f}")
    return model


## Trigger the Training of the Foundation Model

In [7]:

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Train the foundation model.
foundation_model = FoundationModel(input_dim=2, hidden_dim=64, output_dim=35)
print("Training foundation model...")
foundation_model = train_foundation_model(foundation_model, foundation_loader, num_epochs=20, lr=0.001, device=device)

Training foundation model...
Foundation Epoch 1/20, Loss: 287.9302
Foundation Epoch 2/20, Loss: 19.4312
Foundation Epoch 3/20, Loss: 18.4559
Foundation Epoch 4/20, Loss: 18.1584
Foundation Epoch 5/20, Loss: 17.9333
Foundation Epoch 6/20, Loss: 17.7148
Foundation Epoch 7/20, Loss: 17.6489
Foundation Epoch 8/20, Loss: 17.6779
Foundation Epoch 9/20, Loss: 17.4465
Foundation Epoch 10/20, Loss: 17.3919
Foundation Epoch 11/20, Loss: 17.1368
Foundation Epoch 12/20, Loss: 17.3264
Foundation Epoch 13/20, Loss: 16.9236
Foundation Epoch 14/20, Loss: 16.8397
Foundation Epoch 15/20, Loss: 16.7348
Foundation Epoch 16/20, Loss: 16.4314
Foundation Epoch 17/20, Loss: 16.3606
Foundation Epoch 18/20, Loss: 16.1860
Foundation Epoch 19/20, Loss: 15.8413
Foundation Epoch 20/20, Loss: 16.0077


## Trigger the Training of the Main model

In [8]:


# Train the main model using the pretrained foundation model's features.
main_model = MainModel(foundation_model, input_dim=2, hidden_dim=64, output_dim=35, freeze_foundation=True)
print("Training main model...")
main_model = train_main_model(main_model, main_loader, num_epochs=20, lr=0.001, device=device)



Training main model...
Main Model Epoch 1/20, Loss: 61.5980
Main Model Epoch 2/20, Loss: 1.6298
Main Model Epoch 3/20, Loss: 1.5056
Main Model Epoch 4/20, Loss: 1.4920
Main Model Epoch 5/20, Loss: 1.5036
Main Model Epoch 6/20, Loss: 1.5023
Main Model Epoch 7/20, Loss: 1.4978
Main Model Epoch 8/20, Loss: 1.4833
Main Model Epoch 9/20, Loss: 1.4636
Main Model Epoch 10/20, Loss: 1.4849
Main Model Epoch 11/20, Loss: 1.4886
Main Model Epoch 12/20, Loss: 1.4848
Main Model Epoch 13/20, Loss: 1.4875
Main Model Epoch 14/20, Loss: 1.4932
Main Model Epoch 15/20, Loss: 1.4738
Main Model Epoch 16/20, Loss: 1.4801
Main Model Epoch 17/20, Loss: 1.5053
Main Model Epoch 18/20, Loss: 1.4782
Main Model Epoch 19/20, Loss: 1.4898
Main Model Epoch 20/20, Loss: 1.4880


## Write Submission file

In [9]:
# --------------------------
# 7. Inference on Test Data and Submission Creation
# --------------------------
# Load test data from "spots/Test", using slide "S_7"
with h5py.File(h5_file_path, "r") as f:
    test_spots = f["spots/Test"]
    test_df = pd.DataFrame(np.array(test_spots["S_7"]))
    
# For test data, we expect only coordinates.
test_df.columns = ["x", "y", "Test_set"]
test_df = test_df[["x", "y"]]
    
test_coords = test_df[['x','y']].values
test_coords_tensor = torch.tensor(test_coords, dtype=torch.float32).to(device)

main_model.eval()

MainModel(
  (foundation_model): FoundationModel(
    (net): Sequential(
      (0): Linear(in_features=2, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): ReLU()
      (4): Linear(in_features=64, out_features=35, bias=True)
    )
  )
  (net): Sequential(
    (0): Linear(in_features=37, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=35, bias=True)
  )
)

In [10]:
with torch.no_grad():
    predictions = main_model(test_coords_tensor)
predictions = predictions.cpu().numpy()

# Create submission DataFrame and save as CSV.
submission_df = pd.DataFrame(predictions, columns=cell_types)
submission_df.insert(0, 'ID', test_df.index)
submission_file = "submission.csv"
submission_df.to_csv(submission_file, index=False)
print(f"Submission file '{submission_file}' created!")

Submission file 'submission.csv' created!


In [11]:
submission_df

Unnamed: 0,ID,C1,C2,C3,C4,C5,C6,C7,C8,C9,...,C26,C27,C28,C29,C30,C31,C32,C33,C34,C35
0,0,0.443814,-0.106649,0.624589,-0.445415,0.926104,-0.091391,-0.130439,-0.092622,0.612996,...,0.280710,0.423029,-0.122155,0.097438,0.233067,0.015489,0.018033,0.068716,-0.086791,0.170541
1,1,0.697913,0.033329,0.949456,-0.163560,0.950780,-0.109424,-0.130879,-0.348345,0.344447,...,0.304455,0.320713,-0.100551,0.204787,0.039071,0.083105,0.088864,0.010034,-0.235753,-0.024494
2,2,1.486667,-0.010109,1.500400,0.089003,0.156977,0.078268,-0.152856,-0.225797,0.623986,...,0.145301,0.512873,0.023388,0.034958,0.026112,0.032808,-0.016921,0.173338,-0.056170,-0.081946
3,3,1.395471,-0.010434,1.306253,0.016643,0.164487,-0.005278,-0.091837,-0.176774,0.419763,...,0.207736,0.506486,0.023310,0.132480,0.010228,-0.045011,0.047634,0.214746,-0.024696,-0.016433
4,4,0.565398,-0.022647,0.716997,-0.260894,0.556178,-0.123664,-0.119875,-0.152182,0.168082,...,0.287131,0.442836,-0.153977,0.257009,-0.167299,0.090447,0.285522,0.074945,-0.122729,0.033879
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2083,2083,1.115021,0.024317,1.279812,-0.029678,0.399016,0.056700,-0.065979,0.102767,0.912682,...,0.219881,0.419946,0.061821,-0.015085,0.191864,-0.009588,-0.098528,0.053446,0.008142,0.082603
2084,2084,0.285380,-0.005729,0.561254,-0.130385,0.433216,0.000371,-0.066507,0.097263,0.629875,...,0.305071,0.452940,-0.057946,0.102565,0.356884,0.107763,0.058468,0.013911,-0.026446,0.126042
2085,2085,0.103151,-0.020869,0.371702,-0.213651,0.578808,-0.095650,-0.099714,0.052289,0.502729,...,0.340517,0.475567,-0.050067,0.141680,0.385241,0.115566,0.093493,-0.011902,-0.031044,0.166546
2086,2086,1.590991,-0.060172,1.502130,-0.029158,0.279232,-0.005167,-0.164078,-0.329623,0.560773,...,0.195260,0.551318,0.039067,0.067388,0.004007,-0.001105,0.020462,0.203567,-0.042509,-0.082369
