In [None]:
import math
import time
import os
import random
from pathlib import Path
import yaml

import pandas as pd
import numpy as np
import scanpy as sc
import paste as pst
import ot

In [None]:
# Figure resolution: 150 dpi for inline display, 200 dpi for saved files.
sc.set_figure_params(dpi=150, dpi_save=200)

In [None]:
# Parameters cell: these defaults are meant to be overridden when the
# notebook is executed (presumably via a parameterizing runner — confirm).
adata1_file = ''    # path to the first slice (.h5ad)
adata2_file = ''    # path to the second slice (.h5ad)
metric_file = ''    # not used further in this notebook
matching_file = ''  # output path for the spot matching; run_time.yaml is written next to it
# PASTE does not produce embeddings, so there are no emb0_file / emb1_file parameters.

In [None]:
def global_seed(seed: int):
    r"""
    Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Parameters
    ----------
    seed
        int. Pass ``-1`` to draw a fresh seed from ``torch.seed()``.
        Values above ``2**32 - 1`` are folded down by repeatedly keeping
        their high 32 bits, because ``np.random.seed`` only accepts
        seeds in ``[0, 2**32 - 1]``.
    """
    # torch is not imported in the notebook's import cell; importing it here
    # keeps this function self-contained instead of raising NameError.
    import torch

    seed = seed if seed != -1 else torch.seed()
    # torch.seed() may return a 64-bit value; np.random.seed rejects anything
    # >= 2**32. Loop so even larger user-supplied seeds are folded into range.
    while seed > 2**32 - 1:
        seed = seed >> 32

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without CUDA; PyTorch ignores it in that case.
    torch.cuda.manual_seed_all(seed)
    print(f"Global seed set to {seed}.")

In [None]:
# Load both slices from their .h5ad files (slice 1 first, then slice 2).
adata1, adata2 = (sc.read_h5ad(path) for path in (adata1_file, adata2_file))

In [None]:
# Alignment tuning parameter passed to pst.pairwise_align below
# (PASTE documents it as 0 <= alpha <= 1).
alpha = 1e-1

In [None]:
# Compute the PASTE transport plan between the two slices. The initial plan
# pi0 comes from a heuristic on the spatial coordinates; pairwise_align is
# asked to use the GPU via the torch backend of POT.
start = time.time()
pi0 = pst.match_spots_using_spatial_heuristic(adata1.obsm['spatial'], adata2.obsm['spatial'], use_ot=True)
pi12 = pst.pairwise_align(adata1, adata2, use_gpu=True, backend=ot.backend.TorchBackend(),
                          alpha=alpha, G_init=pi0, norm=True, verbose=True)
# Measure the elapsed time once, so the printed runtime and the runtime saved
# to run_time.yaml later are the same number (previously measured twice).
run_time = str(time.time() - start)
print('Runtime: ' + run_time)

In [None]:
# Wrap the transport plan in a DataFrame and show its shape
# (presumably (n_spots_slice1, n_spots_slice2), per PASTE's pairwise_align — confirm).
result = pd.DataFrame(data=pi12)
result.shape

In [None]:
# For every column of the plan, take the row that receives the most mass.
# Row 0 of `matching` holds the column indices 0..n_cols-1, row 1 holds the
# matched row index for each column.
matching_index = result.to_numpy().argmax(axis=0)
matching = np.vstack((np.arange(result.shape[1]), matching_index))

# Save the matching and the runtime

In [None]:
# Write the 2-row matching as plain-text integers.
np.savetxt(fname=matching_file, X=matching, fmt='%i')

In [None]:
# Write the measured runtime to run_time.yaml in the directory containing
# matching_file (the current directory if matching_file has no directory part).
out_dir = Path(os.path.dirname(matching_file))
time_dic = {'run_time': run_time}
with open(out_dir / 'run_time.yaml', "w") as f:
    yaml.dump(time_dic, f)