# PASTE2 alignment tutorial

In [1]:
from src.paste2 import partial_pairwise_align
from src.paste2.model_selection import select_overlap_fraction
import os
import numpy as np
import scanpy as sc

from utils.function_utils import *
import time
import paste as pst

# Configuration shared by the alignment cells below.
iters = 1  # number of times each section pair is aligned
save_dir_gt = "./paste2_results"  # root folder for saved alignment matrices and labels

  from .autonotebook import tqdm as notebook_tqdm


### Mouse hypothalamus data integration (pairwise)

In [None]:
# Pairwise PASTE2 alignment of consecutive mouse hypothalamus sections.
# For every iteration and every (section1, section2) pair this cell:
#   1. loads both sections,
#   2. picks the overlap fraction with PASTE2 model selection,
#   3. computes the partial alignment matrix `pi12`,
#   4. saves `pi12` and the concatenated 'original_clusters' labels to
#      <save_dir_gt>/<id1>_<id2>/iter<k>embedding.npy and iter<k>labels.npy.
# The elapsed time for steps 1-3 plus saving `pi12` is appended to `run_times`.
section_ids_list = [['-0.04', '-0.09'], ['-0.09', '-0.14'], ['-0.14', '-0.19'], ['-0.19', '-0.24']]
run_times = []
output = '.'
for iter_ in range(iters):
    for section_ids in section_ids_list:
        dataset = section_ids[0] + '_' + section_ids[1]
        print(dataset)
        # perf_counter is monotonic (unaffected by system clock changes),
        # which makes it the right clock for measuring elapsed time.
        start_time = time.perf_counter()
        slice1 = load_mHypothalamus(section_id=section_ids[0])
        slice2 = load_mHypothalamus(section_id=section_ids[1])

        # PASTE2 runs on copies so that slice1/slice2, whose labels are saved
        # below, are kept separate from whatever the PASTE2 calls do.
        slice1_copy = slice1.copy()
        slice2_copy = slice2.copy()

        overlap_frac = select_overlap_fraction(slice1_copy, slice2_copy, alpha=0.1, dis="kl")
        print(overlap_frac)
        pi12 = partial_pairwise_align(
            slice1_copy, slice2_copy, overlap_frac,
            alpha=0.1, armijo=False, dissimilarity='kl', use_rep=None,
            G_init=None, a_distribution=None, b_distribution=None,
            norm=True, return_obj=False, verbose=True,
        )

        # save alignment matrix (exist_ok avoids the exists/makedirs race)
        out_dir = os.path.join(save_dir_gt, dataset)
        os.makedirs(out_dir, exist_ok=True)
        np.save(os.path.join(out_dir, f'iter{iter_}embedding.npy'), pi12)
        run_times.append(time.perf_counter() - start_time)

        # save labels: slice1 labels followed by slice2 labels
        labels = list(slice1.obs['original_clusters']) + list(slice2.obs['original_clusters'])
        np.save(os.path.join(out_dir, f'iter{iter_}labels.npy'), labels)


### DLPFC data integration (pairwise)

In [None]:
# Pairwise PASTE2 alignment of consecutive DLPFC sections.
# For every iteration and every (section1, section2) pair this cell:
#   1. loads both sections,
#   2. picks the overlap fraction with PASTE2 model selection,
#   3. computes the partial alignment matrix `pi12`,
#   4. saves `pi12` and the concatenated 'original_clusters' labels to
#      <save_dir_gt>/<id1>_<id2>/iter<k>embedding.npy and iter<k>labels.npy.
# The elapsed time for steps 1-3 plus saving `pi12` is appended to `run_times`.
section_ids_list = [['151507', '151508'], ['151508', '151509'], ['151509', '151510']]
run_times = []
for iter_ in range(iters):
    for section_ids in section_ids_list:
        dataset = section_ids[0] + '_' + section_ids[1]
        print(dataset)
        # perf_counter is monotonic (unaffected by system clock changes),
        # which makes it the right clock for measuring elapsed time.
        start_time = time.perf_counter()
        slice1 = load_DLPFC(section_id=section_ids[0])
        slice2 = load_DLPFC(section_id=section_ids[1])

        # PASTE2 runs on copies so that slice1/slice2, whose labels are saved
        # below, are kept separate from whatever the PASTE2 calls do.
        slice1_copy = slice1.copy()
        slice2_copy = slice2.copy()

        # run paste2 pairwise alignment.
        # Alternative: skip model selection and fix overlap_frac by hand
        # (e.g. overlap_frac = 0.2) when the true overlap is known.
        overlap_frac = select_overlap_fraction(slice1_copy, slice2_copy, alpha=0.1, dis="kl")
        print(overlap_frac)
        pi12 = partial_pairwise_align(
            slice1_copy, slice2_copy, overlap_frac,
            alpha=0.1, armijo=False, dissimilarity='kl', use_rep=None,
            G_init=None, a_distribution=None, b_distribution=None,
            norm=True, return_obj=False, verbose=True,
        )

        # save alignment matrix (exist_ok avoids the exists/makedirs race).
        # os.path.join is used instead of `osp`, which this file never imports.
        out_dir = os.path.join(save_dir_gt, dataset)
        os.makedirs(out_dir, exist_ok=True)
        np.save(os.path.join(out_dir, f'iter{iter_}embedding.npy'), pi12)
        run_times.append(time.perf_counter() - start_time)

        # Optional (not run): stack the two slices using the alignment.
        #   new_slices = pst.stack_slices_pairwise([slice1, slice2], [pi12])
        # Center-slice integration additionally requires `import ot`:
        #   lmbda = [0.5, 0.5]  # uniform weights over the two slices
        #   center_slice, pis = pst.center_align(
        #       slice1.copy(), [slice1, slice2], lmbda,
        #       backend=ot.backend.TorchBackend(), use_gpu=True)

        # save labels: slice1 labels followed by slice2 labels
        labels = list(slice1.obs['original_clusters']) + list(slice2.obs['original_clusters'])
        np.save(os.path.join(out_dir, f'iter{iter_}labels.npy'), labels)