# Reproduction of isoST-i on a reduced mouse brain dataset and imaging data

This notebook reproduces isoST on MERFISH mouse brain slices and imaging data.

Bohan Li @ Deng ai Lab @ BUAA 2025.

Software provided as is under MIT License.

## Step 1: Import required libraries and modules

In [1]:
import sys
import os

# Make the repository root (two levels above this notebook) importable so that the
# project-local `utils` package resolves. The membership check keeps the cell
# idempotent: re-running it does not keep appending the same entry to sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), '../../'))
if project_root not in sys.path:
    sys.path.append(project_root)


from utils.train_ode import biaxial_train  # custom training function
from utils.inference import fine_inference  # custom inference function
import torch
import numpy as np
import yaml
import time

## Step 2: Set random seed for reproducibility

In [2]:
import random


def seed_all(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch on the CPU
    and on all CUDA devices. It then requests deterministic cuDNN kernels and turns
    off cuDNN benchmark autotuning, so repeated runs make the same choices.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False


seed_all(0)

## Step 3: Define list of raw slide names and their corresponding identifiers for training.

Experiment: Mouse brain isotropic volume reconstruction.

Inputs:
   - 8 serial tissue slices, ordered by their z-axis positions (ascending or descending is acceptable, 
     but the order must be consistent).

Outputs:
   - Reconstructed isotropic 3D volume, with each voxel having physical dimensions of 
     0.05 mm × 0.05 mm × 0.05 mm.

Features:
   - Min–max normalized top 50 principal components (PCs) computed from the expression profiles 
     of all 1,122 measured genes.

In [3]:
# Identifiers of the Zhuang-ABCA-3 sections used here, listed in the order required by
# the description above. Sections .014 and .018 are not part of this list.
slide_names_all = [
    f'Zhuang-ABCA-3.{section:03d}'
    for section in (*range(1, 14), 15, 16, 17, *range(19, 25))
]

dim = 50  # feature dimension: number of principal components per spot
kk = 3    # Skip step: keep every kk-th section

slide_names_ = slide_names_all[0::kk]
slide_names = [name + '_log_PC' for name in slide_names_]

len(slide_names)

8

## Step 4: Set project name and path to preprocessed data

In [4]:
proj = 'zhuang/zhuang_ABCA_3/zscore_PC50_minmax'
model = 'isoST_i'
# Training uses batch 1 of `batch_num` batches, i.e. 1/16 (about 6%) of the full data,
# as a demonstration.
batch_num = 16
data_dir = f'/home/lbh/projects_dir/3DProject/{proj}/1_of_{batch_num}_normPC_1'

## Step 5: Load training configuration file

In [5]:
# Training configuration shipped with the repository; the parsed content is displayed in the next cell.
config_file = os.path.join(project_root, 'config', 'img_reg.yml')
with open(config_file, 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

dd = config['params']['delta_d']
dd  # the size of step (Δz)

0.01

In [6]:
config  # show the parsed YAML configuration: the trainer name and its 'params' hyper-parameters

{'params': {'K': 8,
  '_lambda_1': 1000,
  '_lambda_2': 0.0001,
  'alpha': 0.1,
  'beta_end_value': 0.001,
  'beta_n_iterations': 8000,
  'beta_start_iteration': 8000,
  'beta_start_value': 1,
  'delta_d': 0.01,
  'gene_dim': 50,
  'head_num': 1,
  'hidden_dim': 64,
  'image_data_dir': 'data/CCFv3_feature',
  'lr': 0.001,
  'method': 'euler',
  'optimizer_name': 'NAdam',
  'scale_z': 1,
  'slice_data_dir': 'data/zhuang_ABCA_3',
  'slice_width': 0.2,
  'spacing': [0.01, 0.01, 0.01],
  'std_seq': 0.1,
  'std_x': 0.01,
  'std_y': 0.01,
  'std_z': 0.057,
  'stride': 1,
  'template_sample_rate': 0.125,
  'warm_up_rate': 0.01,
  'weight_decay': 1e-08},
 'trainer': 'IsoSTImageReg'}

## Step 6: Set training hyperparameters

In [7]:
import warnings

# Preferred GPU for this experiment. If it is not present, fall back to the first GPU,
# or to the CPU when no GPU exists, so the notebook still runs on other machines.
# Training on CPU will be very slow.
GPU_INDEX = 3
if torch.cuda.is_available() and torch.cuda.device_count() > GPU_INDEX:
    device = f'cuda:{GPU_INDEX}'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
if device != f'cuda:{GPU_INDEX}':
    warnings.warn(f"cuda:{GPU_INDEX} is not available; using {device} instead.")

checkpoint_every = 20
backup_every = 5
epochs = [50, 50, 50]  # epoch setting for 3 stages
mode = 'joint'  # optimize both shape and expression

## Step 7: Create experiment and result directories

In [8]:
# `experiment_dir` is handed to training and inference (presumably it holds the model
# artifacts). `result_dir` receives the inference results and exported volumes.
experiment_dir = 'experiments'
result_dir = 'result'
# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(result_dir, exist_ok=True)

## Step 8: Start training using biaxial_train (~8h)

In [None]:
# Train on the batch selected in Step 4. `batch_num` is passed through (not hard-coded)
# so it always matches the batch index encoded in `data_dir`.
biaxial_train(experiment_dir=experiment_dir,
              data_dir=data_dir,
              slide_names=slide_names,
              batch_num=batch_num,
              config_file=config_file,
              device=device,
              checkpoint_every=checkpoint_every,
              backup_every=backup_every,
              epoch=epochs,
              mode=mode)

In [9]:
experiment_dir  # display the experiment directory that is passed to training and inference

'experiments'

## Step 9: Run inference on the full data using the trained model

In [9]:
# Inference runs on the complete dataset (batch 1 of 1), not on the training subset.
total_data_dir = '/home/lbh/projects_dir/3DProject/' + proj + '/1_of_1_normPC_1/'
defined_d = dd  # reuse the step size Δz read from the training config
fine_inference(
    experiment_dir, total_data_dir, slide_names, mode,
    defined_d, result_dir, batch_num, device,
)

Pretrained Model Loaded!


100%|██████████| 8/8 [00:00<00:00, 151.01it/s]
  adj = torch.sparse_csr_tensor(
100%|██████████| 7/7 [04:28<00:00, 38.34s/it]

Done





## Step 10: Postprocess the inference result into 3D volume


In [10]:
import pandas as pd
from utils.postprocess import VolumeProcessor  # custom postprocessing class

# Raw dataset folder used for post-processing. It is stored under its own name so the
# training `data_dir` from Step 4 is not overwritten; otherwise re-running Steps 8-9
# after this cell would silently point them at the wrong folder.
raw_data_dir = "/home/lbh/projects_dir/3DProject/zhuang/zhuang_ABCA_3"
gene = pd.read_csv(os.path.join(raw_data_dir, "gene.csv"), index_col=0)
processor = VolumeProcessor(
    data_dir=raw_data_dir,
    result_dir=result_dir,
    # Extents along (x, y, z). NOTE(review): units are not shown here; confirm against VolumeProcessor.
    volume_size=(1.2, 0.6, 0.4),
    gene_list=gene['gene_symbol'].tolist(),
    max_lence=264  # the longest length of the volume
)

## Step 11: Convert results into 3D volume and export

In [11]:
# Rasterize the inferred per-slice PCs into a volume with `dim` feature channels, then
# export the volume, its density counts, and a flattened table.
volume, count = processor.result_to_volume(n_features=dim, swamp=False)
pc_df = processor.volume_to_df(volume)

np.save(f"{result_dir}/volume.npy", volume)
np.save(f"{result_dir}/density.npy", count)
pc_df.to_csv(f"{result_dir}/pc_volume.csv")
volume.shape

Starting load_result...


Loading inferred results: 100%|██████████| 442/442 [00:15<00:00, 29.46it/s]


Finished load_result.
Starting scatter_to_volume...


Processing slices: 100%|██████████| 442/442 [06:39<00:00,  1.11it/s]


Finished scatter_to_volume.
Starting volume_to_df...
Finished volume_to_df.


(317, 159, 106, 50)

## Step 12: Mapping to Gene Expression

In [12]:
import joblib

def load_model(model_path):
    """Load and return the PCA model serialized at ``model_path``.

    The file is read with ``joblib.load``, which unpickles it. Only load files
    from a trusted source.
    """
    return joblib.load(model_path)


model_path = "/home/lbh/projects_dir/3DProject/zhuang/zhuang_ABCA_3/zscore_PC50_minmax/zscore_pc_model.pkl"
pc_model = load_model(model_path)
print("PCA Model loaded successfully!")

PCA Model loaded successfully!


In [13]:
import numpy as np

# Reload the PC volume saved in Step 11 so this step does not depend on in-memory state.
volume = np.load('{}/volume.npy'.format(result_dir))
processor.pc_to_expression(volume, pc_model, 264)  # same value as `max_lence` in Step 10

Starting pc_to_expression...
Starting volume_to_df...
Finished volume_to_df.
Finished pc_to_expression.


'result/log2_expr_264_all_pc.parquet'

In [15]:
import pyarrow as pa
import pyarrow.parquet as pq

# Path of the expression table returned by `pc_to_expression` in the previous cell.
output_path = "{}/log2_expr_264_all_pc.parquet".format(result_dir)

# Read the Parquet file, convert it to a pandas DataFrame and cast every column to float32.
table = pq.read_table(output_path)
predictions = table.to_pandas().astype('float32')

In [16]:
# Genes of interest to visualize; the x, y, z coordinate columns are kept alongside them.
gene_roi = [
    'Sv2b', 'Hs3st4', 'Ppp1r1b', 'C1ql2',
    'Slc17a6', 'Pvalb', 'Cbln1',
    'Sox10', 'Frzb', 'Meis2',
]
predictions_roi = predictions[['x', 'y', 'z', *gene_roi]]
predictions_roi

Unnamed: 0,x,y,z,Sv2b,Hs3st4,Ppp1r1b,C1ql2,Slc17a6,Pvalb,Cbln1,Sox10,Frzb,Meis2
0,0.0,58.0,84.0,-0.513949,-0.577605,-0.335458,0.011921,-0.096269,-0.329574,-0.525311,-0.095616,-0.540668,0.436061
1,0.0,59.0,81.0,-0.521150,-0.501448,-0.308009,-0.049255,0.324807,0.572786,-0.492219,-0.142330,-0.378186,0.125802
2,0.0,59.0,82.0,-0.495607,-0.547312,-0.297352,-0.065462,0.347662,0.565812,-0.457016,-0.144736,-0.335664,0.164168
3,0.0,59.0,84.0,-0.522623,-0.526652,-0.287416,0.013507,-0.086913,-0.273510,-0.541314,-0.100054,-0.521169,0.469179
4,0.0,59.0,86.0,-0.504508,-0.459823,-0.221050,0.008126,0.065121,-0.246580,-0.615679,0.080057,-0.482968,0.417330
...,...,...,...,...,...,...,...,...,...,...,...,...,...
2072749,271.0,132.0,70.0,-0.370281,-0.130955,-0.344442,0.744877,2.114329,-0.370300,0.355134,-0.030697,-0.583163,0.361622
2072750,271.0,133.0,70.0,-0.579651,-0.499265,0.051818,-0.057371,0.080423,0.120076,-0.060299,-0.410379,-1.032582,-0.896028
2072751,271.0,133.0,71.0,-0.601567,-0.540597,0.021636,-0.025346,0.128069,0.164613,-0.045016,-0.450622,-1.067313,-0.916806
2072752,272.0,122.0,69.0,-0.852956,-0.944329,-0.194449,0.331559,0.476395,0.126190,-0.119669,-0.139450,-1.311114,0.116500


In [17]:
predictions_roi.to_csv('{}/pred_roi.csv'.format(result_dir))  # export coordinates + selected genes

## Step 13: 3D Visualization

In [18]:
# Define color scales
zero_value_color = '#0000FF'
zero_value_color = '#440154'
custom_colorscales = [
    [[0, zero_value_color], [1, '#FF0000']],  # Red
    [[0, zero_value_color], [1, '#00FF00']],  # Green
    [[0, zero_value_color], [1, '#0000FF']],  # Blue

    [[0, zero_value_color], [1, '#FFFFFF']],  # White
    [[0, zero_value_color], [1, '#FFA500']],  # Orange
    [[0, zero_value_color], [1, '#FFFF00']],  # Yellow

    [[0, zero_value_color], [1, '#FF4500']],  # Orange Red
    [[0, zero_value_color], [1, '#FF00FF']],  # Magenta
    [[0, zero_value_color], [1, '#FF69B4']],  # Pink

    [[0, zero_value_color], [1, '#00FF7F']],  # Spring Green
    [[0, zero_value_color], [1, '#FFD700']],  # Gold
    [[0, zero_value_color], [1, '#ADFF2F']],  # Green Yellow
    
    [[0, zero_value_color], [1, '#800080']],  # Purple
    [[0, zero_value_color], [1, '#00FFFF']],  # Cyan
    [[0, zero_value_color], [1, '#DA70D6']],  # Orchid
    [[0, zero_value_color], [1, '#7B68EE']],  # Medium Slate Blue
]
color_dic = {gene_roi[i]:custom_colorscales[i] for i in range(len(gene_roi))}
color_dic

{'Sv2b': [[0, '#440154'], [1, '#FF0000']],
 'Hs3st4': [[0, '#440154'], [1, '#00FF00']],
 'Ppp1r1b': [[0, '#440154'], [1, '#0000FF']],
 'C1ql2': [[0, '#440154'], [1, '#FFFFFF']],
 'Slc17a6': [[0, '#440154'], [1, '#FFA500']],
 'Pvalb': [[0, '#440154'], [1, '#FFFF00']],
 'Cbln1': [[0, '#440154'], [1, '#FF4500']],
 'Sox10': [[0, '#440154'], [1, '#FF00FF']],
 'Frzb': [[0, '#440154'], [1, '#FF69B4']],
 'Meis2': [[0, '#440154'], [1, '#00FF7F']]}

In [19]:
# Sample 1% of the voxels for display. A fixed random_state keeps the subset reproducible
# regardless of which cells were executed before (the global NumPy RNG is not consulted).
predictions_roi_sampled = predictions_roi.sample(frac=0.01, random_state=0)

In [20]:
from utils.ploting import plot_multi_gene_3d

r = 1.2  # camera distance scale factor
fig_ii = plot_multi_gene_3d(
    predictions_roi_sampled,
    gene_roi,
    threathold=0,
    noise_std=0.1,
    gene_colorscales=custom_colorscales,
    custom_colorscales_dict=color_dic,
    opacity=0.9,
    cmin=-2,
    cmax=2,
    size=0.5,
)
fig_ii.update_layout(
    width=800,                 # figure size in pixels
    height=800,
    paper_bgcolor='#080808',   # background outside the plotting area
    plot_bgcolor='#080808',    # background of the plotting area
    scene_dragmode='orbit',    # free rotation of the 3D scene with the mouse
    scene={
        'aspectmode': 'data',  # keep proportions consistent with the data scale
        # Hide ticks, grid and axis lines for a clean view.
        'xaxis': {'showticklabels': False, 'showgrid': False, 'visible': False},
        'yaxis': {'showticklabels': False, 'showgrid': False, 'visible': False},
        'zaxis': {'showticklabels': False, 'showgrid': False, 'visible': False},
        # Initial camera: eye position relative to the scene center, -y as "up".
        'camera': {
            'eye': {'x': -1.5 * r, 'y': -1.0 * r, 'z': -1.0 * r},
            'up': {'x': 0, 'y': -1, 'z': 0},
            'center': {'x': 0, 'y': 0, 'z': 0},
        },
    },
)

fig_ii.show()