In [1]:
# ## Setup: runtime environment, imports, configuration
#
# GPU selection and OpenCV's image-size limit are environment variables that
# are only honoured if set before CUDA / the cv2 library are initialised, so
# they are set *before* any other import. After changing GPU_ID in a running
# kernel, restart the kernel for it to take effect.

# Standard library imports
import os

GPU_ID = '7'  # physical GPU to expose to this process; replace as needed
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID
# Read by OpenCV when the library is loaded; setting it after `import cv2` has no effect.
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(2 ** 40)

import sys
import json
import math
import random
import argparse
import datetime
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple, Set, Union, Optional

# Third-party imports (kept in their original first-import order)
import numpy as np
import pandas as pd
import scanpy as sc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, ConcatDataset, TensorDataset
import torchvision
from torchvision import transforms
import wandb
from tqdm import tqdm
import sklearn.model_selection
from PIL import Image
import seaborn as sns
from numba import njit, prange
from scipy.stats import wasserstein_distance
from scipy.spatial import cKDTree
import tangram as tg
import imageio.v3 as iio
import cv2
import scgpt
import timm
from einops import rearrange
from torch import einsum
import torch.nn.utils as U

# Project code; presumably provides MerNet and ImageDataset used below (TODO: import them explicitly).
from schaf_method import *

# Allow PIL to open very large H&E images.
Image.MAX_IMAGE_PIXELS = 933120000

# CUDA_VISIBLE_DEVICES (set above) leaves the chosen GPU as the only visible
# device, so it is ordinal 0 -- "cuda:7" would be an invalid device ordinal.
# Fall back to CPU when no GPU is visible.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = DEVICE
NUM_WORKERS = 6 if torch.cuda.is_available() else 2
PIN_MEMORY = torch.cuda.is_available()

# Seed every RNG the notebook draws from so Restart & Run All is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


In [2]:

# ## Example data, Tangram projection, and data loaders
#
# Synthetic stand-ins for a single-cell reference, a spatial transcriptomics
# (ST) section with (x, y) coordinates and CV folds, and an H&E image.
# Replace with real data as needed.

HOLD_OUT_FOLD = 1   # ST fold used as the test set
TILE_RADIUS = 10    # passed to ImageDataset as tile_radius
BATCH_SIZE = 64

# --- single-cell reference ---
n_sc_cells = 1000
n_genes = 2000
expression_matrix = np.random.negative_binomial(5, 0.3, size=(n_sc_cells, n_genes))
adata_sc = sc.AnnData(X=expression_matrix)
adata_sc.obs.index = [f"cell_{i}" for i in range(n_sc_cells)]
adata_sc.var.index = [f"gene_{i}" for i in range(n_genes)]
adata_sc.obs['cluster'] = np.random.choice([0, 1, 2, 3], n_sc_cells)

# --- spatial transcriptomics (measures a subset of the sc genes) ---
n_st_cells = 500
x_coords = np.random.uniform(0, 1000, n_st_cells)
y_coords = np.random.uniform(0, 1000, n_st_cells)
spatial_genes = adata_sc.var.index[:1500]
spatial_expression = np.random.negative_binomial(3, 0.4, size=(n_st_cells, len(spatial_genes)))
adata_st = sc.AnnData(X=spatial_expression)
adata_st.obs.index = [f"spatial_cell_{i}" for i in range(n_st_cells)]
adata_st.var.index = spatial_genes
adata_st.obsm['spatial'] = np.column_stack([x_coords, y_coords])
adata_st.obs['fold'] = np.random.choice([0, 1, 2, 3], n_st_cells)
adata_st.obs['x'] = x_coords
adata_st.obs['y'] = y_coords

# --- dummy H&E image (uint8 RGB; randint's upper bound is exclusive) ---
he_image = np.random.randint(0, 256, size=(1000, 1000, 3), dtype=np.uint8)


def _log_normalize(adata):
    """Return a filtered, total-count-normalised, log1p copy of `adata` (input untouched)."""
    out = adata.copy()
    sc.pp.filter_cells(out, min_genes=8)
    sc.pp.filter_genes(out, min_cells=1)
    sc.pp.normalize_total(out)
    sc.pp.log1p(out)
    return out


def _standardize_stats(X):
    """Per-gene (column) mean and std of X.

    A zero std (gene constant in the training split) is replaced by 1 so the
    gene standardises to 0 instead of NaN/inf, which would poison the MSE loss.
    """
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    std = np.where(std == 0, np.ones_like(std), std)
    return mean, std


def _make_loader(adata, shuffle):
    """DataLoader over ImageDataset(he_image, adata) covering every row of `adata`."""
    return DataLoader(
        ImageDataset(
            he_image=he_image,
            adata=adata,
            tile_radius=TILE_RADIUS,
            indices=list(range(adata.shape[0])),
        ),
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=1,
        pin_memory=PIN_MEMORY,  # pinning only helps (and only works cleanly) with CUDA
    )


# --- Tangram: map sc cells onto ST spots, then project all sc genes to space ---
sc_for_tang = _log_normalize(adata_sc)
st_for_tang = _log_normalize(adata_st)
common_genes = np.intersect1d(sc_for_tang.var.index, st_for_tang.var.index)
# .copy() so tg.pp_adatas writes its `uns` entries into real objects, not views.
sc_for_tang = sc_for_tang[:, common_genes].copy()
st_for_tang = st_for_tang[:, common_genes].copy()
tg.pp_adatas(sc_for_tang, st_for_tang)
ad_map = tg.map_cells_to_space(
    adata_sc=sc_for_tang,
    adata_sp=st_for_tang,
    device=DEVICE,
)
# Projection uses the raw (un-normalised) single-cell counts.
projected_adata = tg.project_genes(ad_map, adata_sc)

# Rescale. NOTE(review): /10 on projected expression is a fixed scale factor
# with no stated rationale -- confirm against the method's reference settings.
projected_adata.X = projected_adata.X / 10.
he_image = he_image / 255.
sc.pp.log1p(adata_st)

# --- train / test split by fold (rows of the image correspond to samples) ---
st_is_train = adata_st.obs['fold'] != HOLD_OUT_FOLD
proj_is_train = projected_adata.obs['fold'] != HOLD_OUT_FOLD
# .copy() materialises the subsets so the standardisation below writes to
# independent AnnData objects instead of implicitly copying views.
train_st = adata_st[st_is_train].copy()
test_st = adata_st[~st_is_train].copy()
train_projected = projected_adata[proj_is_train].copy()
test_projected = projected_adata[~proj_is_train].copy()

# z-score training targets with training-split statistics only. test_st and
# test_projected are left unstandardised; use st_*/proj_* stats to map between spaces.
st_mean, st_std = _standardize_stats(train_st.X)
proj_mean, proj_std = _standardize_stats(train_projected.X)
train_st.X = (train_st.X - st_mean) / st_std
train_projected.X = (train_projected.X - proj_mean) / proj_std

train_stage1_dl = _make_loader(train_st, shuffle=True)
train_stage2_dl = _make_loader(train_projected, shuffle=True)
# Must NOT shuffle: inference pairs prediction i with row i of test_st.obs.
test_dl = _make_loader(test_st, shuffle=False)

INFO:root:1500 training genes are saved in `uns``training_genes` of both single cell and spatial Anndatas.
INFO:root:1500 overlapped genes are saved in `uns``overlap_genes` of both single cell and spatial Anndatas.
INFO:root:uniform based density prior is calculated and saved in `obs``uniform_density` of the spatial Anndata.
INFO:root:rna count based density prior is calculated and saved in `obs``rna_count_based_density` of the spatial Anndata.
INFO:root:Allocate tensors for mapping.
INFO:root:Begin training with 1500 genes and rna_count_based density_prior in cells mode...
INFO:root:Printing scores every 100 epochs.


Score: 0.916, KL reg: 0.001
Score: 0.918, KL reg: 0.000
Score: 0.920, KL reg: 0.000
Score: 0.921, KL reg: 0.000
Score: 0.921, KL reg: 0.000
Score: 0.921, KL reg: 0.000
Score: 0.921, KL reg: 0.000
Score: 0.921, KL reg: 0.000
Score: 0.921, KL reg: 0.000
Score: 0.921, KL reg: 0.000


INFO:root:Saving results..


In [3]:
# ## Two-stage training
# Stage 1 fits the image model to (standardised) ST expression; stage 2 reuses
# the stage-1 `part_one` and fits the (standardised) Tangram-projected expression.

criterion = nn.MSELoss()
lr = 1e-4
NUM_EPOCHS = 1


def train_stage(model, loader, optimizer, criterion, device, num_epochs, desc="train"):
    """Train `model` in place for `num_epochs` passes over `loader`.

    Args:
        model: module mapping an image batch to predicted transcripts.
        loader: yields (images, transcripts) batches.
        optimizer: optimizer over `model`'s parameters.
        criterion: loss(outputs, transcripts).
        device: device both tensors are moved to.
        num_epochs: number of passes over `loader`.
        desc: label for the progress bar.

    Returns:
        List with the mean batch loss of each epoch.
    """
    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        total_loss, n_batches = 0.0, 0
        for images, transcripts in tqdm(loader, desc=f"{desc} [{epoch + 1}/{num_epochs}]"):
            images = images.to(device)
            transcripts = transcripts.to(device)

            optimizer.zero_grad()
            loss = criterion(model(images), transcripts)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            n_batches += 1
        epoch_losses.append(total_loss / max(n_batches, 1))  # guard against an empty loader
    return epoch_losses


stage1_model = MerNet(
    num_genes=train_st.shape[1],
    model_name='xj_transformer',
    pretrained=True,
    num_hidden_layers=1,
    pretrain_hist=True,
    pretrain_st='NONE'
).to(device)
stage1_optimizer = optim.Adam(stage1_model.parameters(), lr=lr)
stage1_losses = train_stage(
    stage1_model, train_stage1_dl, stage1_optimizer, criterion, device, NUM_EPOCHS, desc="stage 1"
)

stage2_model = MerNet(
    num_genes=train_projected.shape[1],
    model_name='xj_transformer',
    pretrained=True,
    num_hidden_layers=3,
    pretrain_hist=True,
    pretrain_st='NONE'
)
# Reuse the stage-1 `part_one`. NOTE: this assigns the *same* module object, so
# stage-2 training also updates stage1_model.part_one. Use
# copy.deepcopy(stage1_model.part_one) if stage1_model must stay intact.
stage2_model.part_one = stage1_model.part_one
stage2_model = stage2_model.to(device)
stage2_optimizer = optim.Adam(stage2_model.parameters(), lr=lr)
stage2_losses = train_stage(
    stage2_model, train_stage2_dl, stage2_optimizer, criterion, device, NUM_EPOCHS, desc="stage 2"
)

stage1_losses, stage2_losses

 


100%|██████████| 6/6 [00:11<00:00,  1.86s/it]
100%|██████████| 6/6 [00:10<00:00,  1.81s/it]


In [13]:
# ## Inference on the held-out fold
# Predictions are in the standardised space stage2_model was trained on
# (train_projected z-scores); multiply by proj_std and add proj_mean to undo it.

# Training left the model in train mode; switch off dropout / batch-stat updates.
stage2_model.eval()

# Iterate the test dataset sequentially so prediction i lines up with row i of
# test_st.obs regardless of whether test_dl shuffles. Assumes ImageDataset item i
# corresponds to row i of its adata (it was built with indices=range(n)) -- TODO confirm.
ordered_test_dl = DataLoader(
    test_dl.dataset,
    batch_size=test_dl.batch_size,
    shuffle=False,
    num_workers=test_dl.num_workers,
    pin_memory=test_dl.pin_memory,
)

batch_predictions = []
with torch.no_grad():
    for images, _ in ordered_test_dl:  # targets are not needed for inference
        batch_predictions.append(stage2_model(images.to(device)).cpu().numpy())
inferred_transcripts = np.concatenate(batch_predictions, axis=0)
assert inferred_transcripts.shape[0] == test_st.n_obs, "prediction rows must match test_st rows"

# Copy obs so later edits to inferred_adata do not leak into test_st.
inferred_adata = sc.AnnData(X=inferred_transcripts, obs=test_st.obs.copy())
inferred_adata
