# Reproducing isoST on a mouse brain dataset

This is a reproduction of isoST on MERFISH mouse brain slices.

Bohan Li @ Deng ai Lab @ BUAA 2025.

Software provided as-is under the MIT License.

## Step 1: Import required libraries and modules

In [9]:
import sys
import os
import os

# The repository root is two directories above the notebook's working
# directory; put it on sys.path so the project's `utils` package can be imported.
project_root = os.path.abspath(os.path.join(os.getcwd(), '../../'))
if project_root not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(project_root)


from utils.train_ode import biaxial_train  # custom training function
from utils.inference import fine_inference  # custom inference function
import torch
import numpy as np
import yaml
import time

## Step 2: Set random seed for reproducibility

In [3]:
import random
def seed_all(seed):
    """Seed every RNG the notebook relies on and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU plus all
    CUDA devices) with the same ``seed``. It then turns off cuDNN autotuning and
    requests deterministic cuDNN kernels.
    """
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # benchmark=True lets cuDNN pick algorithms by timing them, which can differ
    # between runs; disable it and ask for deterministic algorithms instead.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_all(0)

## Step 3: Define the list of raw slide names and the corresponding identifiers used for training.

Experiment: Mouse brain isotropic volume reconstruction.

Inputs:
   - 54 serial tissue slices, ordered by their z-axis positions (ascending or descending is acceptable, 
     but the order must be consistent).

Outputs:
   - Reconstructed isotropic 3D volume, with each voxel having physical dimensions of 
     0.05 mm × 0.05 mm × 0.05 mm.

Features:
   - Min–max normalized top 50 principal components (PCs) computed from the expression profiles 
     of all 1,122 measured genes.

In [4]:
# 54 Zhuang-ABCA-2 sections numbered 004-061, in ascending order.
# Sections 024, 029, 038 and 043 are not part of this list.
slide_names_ = [
    f'Zhuang-ABCA-2.{idx:03d}'
    for idx in range(4, 62)
    if idx not in {24, 29, 38, 43}
]

dim = 50  # number of principal components used as features (also encoded in `proj`)
# Per-slide identifiers of the preprocessed (log, PC) feature files.
slide_names = [f'{name}_log_PC' for name in slide_names_]

len(slide_names_)

54

## Step 4: Set project name and path to preprocessed data

In [5]:
proj = f'zhuang/zhuang_ABCA_2/zscore_PC{dim}_minmax'  # project sub-path; encodes the PC count `dim`
model = 'isoST'  # NOTE(review): not referenced by any later cell in this notebook
batch_num = 5  # we only take 20% (the 1_of_5 split) of the original data as a demonstration
# NOTE(review): absolute, machine-specific path — adjust to your local data location.
data_dir = f'/home/lbh/projects_dir/3DProject/{proj}/1_of_{batch_num}_normPC_1'

## Step 5: Load training configuration file

In [6]:
# Training configuration stored under <project_root>/config/.
config_file = os.path.join(project_root, 'config', 'mouse_brain.yml')
with open(config_file, 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

dd = config['params']['delta_d']
dd  # the size of step (Δz)

0.01

In [7]:
config  # display the full parsed YAML configuration

{'params': {'K': 8,
  'alpha': 0.1,
  'beta_end_value': 0.05,
  'beta_n_iterations': 50,
  'beta_start_iteration': 50,
  'beta_start_value': 1,
  'delta_d': 0.01,
  'dual': True,
  'gene_dim': 50,
  'head_num': 1,
  'hidden_dim': 64,
  'lr': 0.001,
  'method': 'euler',
  'optimizer_name': 'NAdam',
  'std_seq': 0.1,
  'std_x': 0.01,
  'std_y': 0.01,
  'std_z': 0.1,
  'stride': 1,
  'warm_up_rate': 1,
  'weight_decay': 1e-08},
 'trainer': 'IsoST'}

## Step 6: Set training hyperparameters

In [8]:
# Training hyperparameters consumed by biaxial_train / fine_inference.
device = 'cuda:1'  # CUDA device id
checkpoint_every, backup_every = 20, 5  # save intervals (units defined by biaxial_train)
epochs = [100] * 3  # one epoch count per training stage (3 stages)
mode = 'joint'  # optimize both shape and expression

## Step 7: Create experiment and result directories

In [11]:
# Directory for training artifacts (passed to biaxial_train and fine_inference).
# NOTE(review): the notebook output printed for `experiment_dir` in Step 8 shows
# 'experiments/zhuang/zhuang_ABCA_2/zscore_PC50_minmax_isoST', i.e. an earlier
# f'experiments/{proj}_{model}' definition — confirm which location holds the
# checkpoints you intend to use.
experiment_dir = 'experiments'
# Directory for inference results and post-processed volumes.
result_dir = 'result'
# exist_ok avoids the check-then-create race and makes the cell re-runnable.
os.makedirs(result_dir, exist_ok=True)

## Step 8: Start training using biaxial_train (~6h)

In [None]:
# Train isoST on the subset in `data_dir` (about 6 h, per the heading above).
# Presumably checkpoints/backups are written under `experiment_dir` every
# `checkpoint_every` / `backup_every` units — confirm in utils.train_ode.
# NOTE(review): batch_num is hard-coded to 1 here, while `batch_num = 5` (Step 4)
# selects the 1_of_5 data split and is passed to fine_inference in Step 9;
# confirm the literal 1 is intended.
biaxial_train(experiment_dir=experiment_dir,
              data_dir=data_dir,
              slide_names=slide_names,
              batch_num=1,
              config_file=config_file,
              device=device,
              checkpoint_every=checkpoint_every,
              backup_every=backup_every,
              epoch=epochs,
              mode=mode)

In [10]:
# NOTE(review): the output shown below came from an earlier definition of
# `experiment_dir`; with Step 7 as written this evaluates to 'experiments'.
experiment_dir

'experiments/zhuang/zhuang_ABCA_2/zscore_PC50_minmax_isoST'

## Step 9: Run inference on the full data using the trained model

In [19]:
# Run the trained model on the full (1_of_1) dataset; outputs presumably go to
# `result_dir`, which Step 10 reads from.
total_data_dir = f'/home/lbh/projects_dir/3DProject/{proj}/1_of_1_normPC_1/'  # we infer on the total dataset
defined_d = dd  # z step size Δz from config['params']['delta_d']
# NOTE(review): arguments are positional — keep this order in sync with
# fine_inference's signature in utils.inference.
fine_inference(experiment_dir,
               total_data_dir,
               slide_names,
               mode,
               defined_d,
               result_dir,
               batch_num,
               device)

Pretrained Model Loaded!


100%|██████████| 54/54 [00:00<00:00, 443.21it/s]
100%|██████████| 53/53 [02:40<00:00,  3.02s/it]

Done





## Step 10: Postprocess the inference results into a 3D volume


In [12]:
import pandas as pd
from utils.postprocess import VolumeProcessor  # custom postprocessing class

# Use a distinct name so the training-data path `data_dir` from Step 4 is not
# overwritten; re-running Steps 8-9 after this cell would otherwise read the
# wrong directory.
raw_data_dir = "/home/lbh/projects_dir/3DProject/zhuang/zhuang_ABCA_2"
gene = pd.read_csv(os.path.join(raw_data_dir, "gene.csv"), index_col=0)
processor = VolumeProcessor(
    data_dir=raw_data_dir,
    result_dir=result_dir,
    # Relative extents of the volume. Presumably 10 mm * 8 mm * 5 mm in units of
    # 1 cm — confirm against VolumeProcessor. The recorded output volume shape
    # (220, 177, 110) follows the same 1.0 : 0.8 : 0.5 proportions.
    volume_size=(1.0, 0.8, 0.5),
    gene_list=gene['gene_symbol'].tolist(),
    max_lence=220,  # the longest length of the volume, in voxels
)

## Step 11: Convert the results into a 3D volume and export them

In [12]:
# Scatter the inferred per-slice results into a dense volume. n_features=50
# matches the 50 PC features (`dim`). `count` is saved as density.npy.
# NOTE(review): the meaning of `swamp` is defined in utils.postprocess — confirm.
volume, count = processor.result_to_volume(n_features=50, swamp=True)
pc_df = processor.volume_to_df(volume)

np.save(f"{result_dir}/volume.npy", volume)
np.save(f"{result_dir}/density.npy", count)
pc_df.to_csv(f"{result_dir}/pc_volume.csv")
volume.shape  # (220, 177, 110, 50) in the run recorded below

Starting load_result...


Loading inferred results: 100%|██████████| 1074/1074 [00:09<00:00, 118.60it/s]


Finished load_result.
Starting scatter_to_volume...


Processing slices: 100%|██████████| 1074/1074 [04:02<00:00,  4.42it/s]


Finished scatter_to_volume.
Starting volume_to_df...
Finished volume_to_df.


(220, 177, 110, 50)

## Step 12: Map principal components back to gene expression

In [13]:
import joblib

def load_model(model_path):
    """Load and return the serialized model stored at ``model_path``.

    Used here for the PCA model that ``processor.pc_to_expression`` uses to map
    PC values back to gene expression.

    Security note: ``joblib.load`` unpickles the file, which can execute
    arbitrary code — only load files from a trusted source.
    """
    return joblib.load(model_path)


# Presumably the PCA fitted during preprocessing (machine-specific absolute path).
model_path = "/home/lbh/projects_dir/3DProject/zhuang/zhuang_ABCA_2/zscore_PC50_minmax/zscore_pc_model.pkl"
pc_model = load_model(model_path)
print("PCA Model loaded successfully!")

PCA Model loaded successfully!


In [15]:
import numpy as np
# Reload the PC volume saved in Step 11 and project it back to gene expression
# with the loaded PCA model; the call returns the path of the written parquet file.
volume = np.load(f'{result_dir}/volume.npy')
# NOTE(review): 220 presumably mirrors max_lence=220 from Step 10 — confirm.
processor.pc_to_expression(volume, pc_model, 220)

Starting pc_to_expression...
Starting volume_to_df...
Finished volume_to_df.
Finished pc_to_expression.


'result/log2_expr_220_all_pc.parquet'

In [16]:
import pyarrow as pa
import pyarrow.parquet as pq

# Parquet file written by processor.pc_to_expression in the previous cell
# (its returned path is shown in that cell's output).
output_path = f"{result_dir}/log2_expr_220_all_pc.parquet"

# Read the Parquet file into a pandas DataFrame and cast every column to float32.
# The intermediate Arrow table is not bound to a global name, so it can be
# released instead of staying alive next to the (large) DataFrame.
predictions = pq.read_table(output_path).to_pandas().astype('float32')

In [17]:
# Genes selected for visualization; Step 13 assigns one colorscale per gene, in this order.
gene_roi = ['Sv2b', 'Hs3st4', 'Ppp1r1b','C1ql2', 
            'Slc17a6', 'Pvalb',  'Cbln1', 
            'Sox10', 'Frzb', 'Meis2']
# Keep only the voxel coordinates plus the selected genes.
predictions_roi = predictions[['x', 'y', 'z'] + gene_roi]
predictions_roi

Unnamed: 0,x,y,z,Sv2b,Hs3st4,Ppp1r1b,C1ql2,Slc17a6,Pvalb,Cbln1,Sox10,Frzb,Meis2
0,0.0,54.0,94.0,-0.472952,0.135905,-0.486873,-0.059412,0.602690,-0.559457,0.022015,-0.555567,0.427110,0.905784
1,0.0,55.0,93.0,-0.380799,0.127583,-0.566213,0.070543,0.461925,-0.046529,0.230294,-0.426596,0.002191,1.250615
2,0.0,55.0,94.0,-0.430811,0.126502,-0.532659,-0.013509,0.513531,-0.403391,0.158040,-0.492902,0.236866,1.093434
3,0.0,56.0,82.0,-0.717183,-1.095525,-0.249123,-0.141648,0.797033,-0.473399,0.899103,-0.209580,-0.279727,1.490349
4,0.0,56.0,85.0,-0.385270,-0.129690,-0.819433,-0.011357,0.119015,-0.411119,0.383012,-0.456342,-0.177939,1.451686
...,...,...,...,...,...,...,...,...,...,...,...,...,...
2027086,219.0,147.0,93.0,-0.798851,-0.293472,-0.824903,-0.071001,0.482426,-0.208755,0.298967,-0.393051,0.286302,0.013817
2027087,219.0,147.0,95.0,-0.947453,-0.863884,-0.958538,0.052395,0.477038,0.247565,0.366447,-0.717290,-0.086774,-0.060893
2027088,219.0,147.0,100.0,-0.430277,-0.330865,-0.512455,0.150796,0.425342,-0.347921,0.464912,0.814735,-0.017342,-0.053766
2027089,219.0,148.0,94.0,-0.843477,-0.761068,-0.987156,0.092908,0.375202,0.222973,0.326311,-0.647367,-0.007662,-0.168717


In [18]:
# Persist the coordinates + selected-gene table (pandas writes the index by default).
predictions_roi.to_csv(f'{result_dir}/pred_roi.csv')

## Step 13: 3D Visualization

In [19]:
# Define color scales
zero_value_color = '#0000FF'
zero_value_color = '#440154'
custom_colorscales = [
    [[0, zero_value_color], [1, '#FF0000']],  # Red
    [[0, zero_value_color], [1, '#00FF00']],  # Green
    [[0, zero_value_color], [1, '#0000FF']],  # Blue

    [[0, zero_value_color], [1, '#FFFFFF']],  # White
    [[0, zero_value_color], [1, '#FFA500']],  # Orange
    [[0, zero_value_color], [1, '#FFFF00']],  # Yellow

    [[0, zero_value_color], [1, '#FF4500']],  # Orange Red
    [[0, zero_value_color], [1, '#FF00FF']],  # Magenta
    [[0, zero_value_color], [1, '#FF69B4']],  # Pink

    [[0, zero_value_color], [1, '#00FF7F']],  # Spring Green
    [[0, zero_value_color], [1, '#FFD700']],  # Gold
    [[0, zero_value_color], [1, '#ADFF2F']],  # Green Yellow
    
    [[0, zero_value_color], [1, '#800080']],  # Purple
    [[0, zero_value_color], [1, '#00FFFF']],  # Cyan
    [[0, zero_value_color], [1, '#DA70D6']],  # Orchid
    [[0, zero_value_color], [1, '#7B68EE']],  # Medium Slate Blue
]
color_dic = {gene_roi[i]:custom_colorscales[i] for i in range(len(gene_roi))}
color_dic

{'Sv2b': [[0, '#440154'], [1, '#FF0000']],
 'Hs3st4': [[0, '#440154'], [1, '#00FF00']],
 'Ppp1r1b': [[0, '#440154'], [1, '#0000FF']],
 'C1ql2': [[0, '#440154'], [1, '#FFFFFF']],
 'Slc17a6': [[0, '#440154'], [1, '#FFA500']],
 'Pvalb': [[0, '#440154'], [1, '#FFFF00']],
 'Cbln1': [[0, '#440154'], [1, '#FF4500']],
 'Sox10': [[0, '#440154'], [1, '#FF00FF']],
 'Frzb': [[0, '#440154'], [1, '#FF69B4']],
 'Meis2': [[0, '#440154'], [1, '#00FF7F']]}

In [20]:
# Randomly keep 1% of the voxels to keep the 3D scatter plot light.
# A fixed random_state makes the displayed subsample reproducible. Without it,
# pandas draws from NumPy's global RNG, whose state depends on how much
# randomness earlier cells consumed after seed_all(0).
predictions_roi_sampled = predictions_roi.sample(frac=0.01, random_state=0)

In [21]:
from utils.ploting import plot_multi_gene_3d

# Render the 1% subsample of the selected genes as an interactive 3D scatter.
r = 1.2  # camera distance scale factor applied to the eye position below
fig_ii = plot_multi_gene_3d(predictions_roi_sampled, gene_roi, threathold=0, noise_std=0.1,
                            gene_colorscales=custom_colorscales, custom_colorscales_dict=color_dic,
                            opacity=0.9, cmin=-2, cmax=2, size=0.5)
fig_ii.update_layout(
    # Figure size in pixels
    width=800,
    height=800,
    # Background colors
    paper_bgcolor='#080808',  # Outer background (outside the plotting area)
    plot_bgcolor='#080808',   # 2D plotting-area background; NOTE(review): 3D scenes use scene.bgcolor instead
    # Mouse interaction mode for 3D plots
    scene_dragmode='orbit',   # Allows free rotation of the scene
    scene=dict(
        # Keep aspect ratio consistent with the data scale
        aspectmode='data',
        # Hide tick labels, grid lines, and axis lines for a clean view
        xaxis=dict(showticklabels=False, showgrid=False, visible=False),
        yaxis=dict(showticklabels=False, showgrid=False, visible=False),
        zaxis=dict(showticklabels=False, showgrid=False, visible=False),
        # Set camera parameters for initial view
        camera=dict(
            eye=dict(  # Camera position relative to the scene center
                x=r * (-3 / 2),  # Move left along x-axis
                y=r * (-2 / 2),  # Move backward along y-axis
                z=r * (-2 / 2)   # Move downward along z-axis
            ),
            up=dict(x=0, y=-1, z=0),       # Define the upward direction for the camera (-y is up)
            center=dict(x=0, y=0, z=0),    # Look-at point (center of the scene)
        )
    ),
)

fig_ii.show()