# Tutorial 1a: Generate hard-links for consecutive ST slices

### loading packages

In [1]:
import scipy
import os
import pickle
import sys

# Resolve the directory one level above the current working directory and put
# it at the front of the module search path (presumably so the project modules
# imported below can be found when the notebook runs from its own folder).
parent_dir = os.path.normpath(os.path.join(os.getcwd(), os.pardir))
sys.path[:0] = [parent_dir]

from utils_local_alignment import (
    build_args_ST,
    create_optimizer
)

from models import build_model_ST

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################

  from .autonotebook import tqdm as notebook_tqdm


### Hyperparameter (HP) setup

In [2]:
# Start from the project's default argument set, then apply the overrides used
# in this tutorial. The overrides are applied in the order listed.
args = build_args_ST()

tutorial_overrides = {
    # epochs
    "max_epoch": 2000,
    "max_epoch_triplet": 500,
    # data selection (two consecutive DLPFC sections)
    "dataset_name": "DLPFC",
    "section_ids": ["151507", "151508"],
    "num_class": 7,
    # model / loss settings
    "num_hidden": "512,32",
    "alpha_l": 1,
    "lam": 1,
    "loss_fn": "sce",
    "mask_rate": 0.5,
    "in_drop": 0,
    "attn_drop": 0,
    "remask_rate": 0.1,
    "seeds": [2024],
    "num_remasking": 1,
    "hvgs": 5000,
    # NOTE: both `dataset_name` (above) and `dataset` are set to the same value.
    "dataset": "DLPFC",
    "consecutive_prior": 1,
    # optimisation
    "lr": 0.001,
    "scheduler": True,
    # location of the input data, relative to the working directory
    "st_data_dir": "../../spatial_benchmarking/benchmarking_data/DLPFC12",
}
for option_name, option_value in tutorial_overrides.items():
    setattr(args, option_name, option_value)

### data loader

In [None]:
from local_alignment_main import local_alignment_loader
import scanpy as sc
import anndata
import numpy as np
import dgl
import torch
import paste

# Short aliases for the settings that this and later cells use.
section_ids, exp_fig_dir = args.section_ids, args.exp_fig_dir
st_data_dir, dataset_name = args.st_data_dir, args.dataset

# Load the selected sections. The loader returns the graph, the number of
# input features and a concatenated object (`ad_concat`).
graph, num_features, ad_concat = local_alignment_loader(
    section_ids=section_ids,
    hvgs=args.hvgs,
    st_data_dir=st_data_dir,
    dataname=dataset_name,
)

# Stored on args before the model is built from args in a later cell.
args.num_features = num_features

name: DLPFC


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


------Calculating spatial graph...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)


The graph contains 24762 edges, 4221 cells.
5.8664 neighbors per cell on average.


  utils.warn_names_duplicates("var")
  utils.warn_names_duplicates("var")


------Calculating spatial graph...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell1'] = Spatial_Net['Cell1'].map(id_cell_trans)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  Spatial_Net['Cell2'] = Spatial_Net['Cell2'].map(id_cell_trans)


The graph contains 25692 edges, 4381 cells.
5.8644 neighbors per cell on average.


### model setup

In [4]:
# Build the network from the configured arguments.
model = build_model_ST(args)

# A negative device id selects the CPU; any other value is used as given.
device = "cpu" if not args.device >= 0 else args.device
model.to(device)

optimizer = create_optimizer(args.optimizer, model, args.lr, args.weight_decay)

# No scheduler unless it was requested in the hyperparameter cell.
scheduler = None
if args.scheduler:

    def cosine_lr_factor(epoch):
        """Learning-rate multiplier: half-cosine going from 1 at epoch 0 to 0 at `args.max_epoch`."""
        return (1 + np.cos(epoch * np.pi / args.max_epoch)) * 0.5

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=cosine_lr_factor)

=== Use sce_loss and alpha_l=1 ===
num_encoder_params: 1170016, num_decoder_params: 1176670, num_params_in_total: 2384362


### masked reconstruction loss training

In [None]:
from tqdm import tqdm
from local_alignment_main import run_local_alignment

# Settings read from args; `num_class` and `alpha_value` are also used by the
# alignment cell further down.
max_epoch, max_epoch_triplet = args.max_epoch, args.max_epoch_triplet
num_class, alpha_value = args.num_class, args.alpha_value

# Training. The call returns `batchlist_` (indexed per slice in the next cell)
# together with an updated `ad_concat`.
batchlist_, ad_concat = run_local_alignment(
    graph,
    model,
    device,
    ad_concat,
    section_ids,
    max_epoch=max_epoch,
    max_epoch_triplet=max_epoch_triplet,
    optimizer=optimizer,
    scheduler=scheduler,
    logger=None,
    num_class=num_class,
    use_mnn=True,
)

### Optimal transport (OT) alignment to generate hard-links

In [24]:
# `ot` (the POT package) is referenced below through ot.backend.TorchBackend()
# but was never imported in this notebook, so the cell failed with a NameError
# on a fresh kernel.
import ot

slice1 = batchlist_[0]
slice2 = batchlist_[1]

# Dense (spots in slice1) x (spots in slice2) matrix that collects the
# per-cluster transport plans; pairs from different clusters stay at 0.
global_PI = np.zeros((len(slice1.obs.index), len(slice2.obs.index)))

# Map each spot name to its row (slice1) / column (slice2) in `global_PI`.
slice1_idx_mapping = {spot: pos for pos, spot in enumerate(slice1.obs.index)}
slice2_idx_mapping = {spot: pos for pos, spot in enumerate(slice2.obs.index)}

for i in range(num_class):
    print("run for cluster:", i)
    # Cluster labels in `mclust` are compared against 1..num_class.
    subslice1 = slice1[slice1.obs['mclust'] == i + 1]
    subslice2 = slice2[slice2.obs['mclust'] == i + 1]

    if subslice1.shape[0] == 0 or subslice2.shape[0] == 0:
        # Nothing to align when the cluster is missing from either slice.
        print("  skipped: cluster", i, "has no spots in at least one slice")
        continue

    if subslice1.shape[0] > 1 and subslice2.shape[0] > 1:
        # Initialise the alignment from a purely spatial matching.
        pi00 = paste.match_spots_using_spatial_heuristic(
            subslice1.obsm['spatial'], subslice2.obsm['spatial'], use_ot=True
        )
        local_PI = paste.pairwise_align(
            subslice1, subslice2, alpha=alpha_value, dissimilarity='kl',
            use_rep=None, norm=True, verbose=True, G_init=pi00,
            use_gpu=True, backend=ot.backend.TorchBackend(),
        )
    else:
        # If there is only one spot in a slice, spatial dissimilarity can't be
        # normalized, so run without normalisation and without an initial plan.
        local_PI = paste.pairwise_align(
            subslice1, subslice2, alpha=alpha_value, dissimilarity='kl',
            use_rep=None, norm=False, verbose=True, G_init=None,
            use_gpu=True, backend=ot.backend.TorchBackend(),
        )

    # Write the whole cluster block in one indexed assignment instead of a
    # Python loop over every (spot, spot) pair.
    rows = [slice1_idx_mapping[spot] for spot in subslice1.obs.index]
    cols = [slice2_idx_mapping[spot] for spot in subslice2.obs.index]
    global_PI[np.ix_(rows, cols)] = local_PI

gpu is available, using gpu.
gpu is available, using gpu.
gpu is available, using gpu.
gpu is available, using gpu.
gpu is available, using gpu.
gpu is available, using gpu.
gpu is available, using gpu.


### Save hard-links and aligned coordinates

In [None]:
import pandas as pd
import scipy.sparse  # make sure the `sparse` submodule is loaded

# Create the output directory if it does not exist yet.
if exp_fig_dir:
    os.makedirs(exp_fig_dir, exist_ok=True)

# Hard-links are stored as a sparse matrix named after the two sections and alpha.
file_name = section_ids[0] + '_' + section_ids[1] + '_' + str(alpha_value)
mapping_mat = scipy.sparse.csr_matrix(global_PI)
# `with` closes (and therefore flushes) the file; the original left the handle open.
with open(os.path.join(exp_fig_dir, file_name + "_HL.pickle"), 'wb') as file:
    pickle.dump(mapping_mat, file)

# stack_slices_pairwise takes one dense transport plan per consecutive pair of
# slices, i.e. a list of length len(slices) - 1, so the single plan is
# densified and wrapped in a list (a bare sparse matrix has no len()).
new_slices = paste.stack_slices_pairwise(batchlist_, [mapping_mat.toarray()])
for i, aligned_slice in enumerate(new_slices):
    spatial_data = aligned_slice.obsm['spatial']

    # One CSV of aligned coordinates per section.
    output_path = os.path.join(exp_fig_dir, f"coordinates_{section_ids[i]}.csv")
    pd.DataFrame(spatial_data).to_csv(output_path, index=False)
    print(f"Saved spatial data for slice {i} to {output_path}")