# GPSA alignment tutorial

In [None]:
import numpy as np
import anndata
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from gpsa import VariationalGPSA
from gpsa import matern12_kernel, rbf_kernel
from gpsa.plotting import callback_twod
import time
import os
import pandas as pd
from st_loading_functions import load_mHypothalamus, load_DLPFC
import scanpy as sc
from sklearn.neighbors import NearestNeighbors, KNeighborsRegressor
from sklearn.metrics import r2_score
from tqdm import tqdm


# Number of times the whole benchmark over all section pairs is repeated.
iters = 1

# Use the first CUDA GPU when one is visible to torch, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

def scale_spatial_coords(X, max_val=10.0):
    """Min-max rescale each column of ``X`` to the range ``[0, max_val]``.

    Parameters
    ----------
    X : array-like
        Coordinates; the min/max are taken along axis 0, so each column is
        rescaled independently. ``X`` itself is not modified.
    max_val : float, default 10.0
        Upper end of the rescaled range.

    Returns
    -------
    numpy.ndarray
        Rescaled copy of ``X``. A column whose values are all identical has
        zero range and is mapped to 0 rather than to NaN.
    """
    shifted = X - X.min(0)
    span = shifted.max(0)
    # A constant column would otherwise give 0/0 -> NaN, which then silently
    # poisons everything downstream (e.g. the GP model inputs).
    span = np.where(span == 0, 1, span)
    return shifted / span * max_val

In [None]:
def process_data(
    adata,
    n_top_genes=2000,
    *,
    min_counts=None,
    max_counts=35000,
    max_pct_counts_mt=None,
    min_cells=10,
    mt_prefix="MT-",
):
    """QC-filter, normalise, log-transform and HVG-subset an AnnData object.

    Parameters
    ----------
    adata : anndata.AnnData
        Raw-count data. It is modified in place by the scanpy calls below;
        always use the returned object (it is a new object when
        ``max_pct_counts_mt`` is given).
    n_top_genes : int, default 2000
        Number of highly variable genes to keep (``flavor="seurat"``).
    min_counts, max_counts : int or None
        Per-cell total-count bounds; ``None`` disables that bound. Defaults
        reproduce the original behaviour (only ``max_counts=35000``).
    max_pct_counts_mt : float or None
        If given, drop cells whose mitochondrial count percentage is not
        below this value. Disabled by default.
    min_cells : int, default 10
        Keep only genes detected in at least this many cells.
    mt_prefix : str, default "MT-"
        Gene-name prefix marking mitochondrial genes. "MT-" matches
        human-style symbols; mouse data typically uses "mt-".

    Returns
    -------
    anndata.AnnData
        The processed object restricted to the selected genes.
    """
    adata.var_names_make_unique()
    adata.var["mt"] = adata.var_names.str.startswith(mt_prefix)
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], inplace=True)

    # scanpy's filter_cells accepts a single threshold per call.
    if min_counts is not None:
        sc.pp.filter_cells(adata, min_counts=min_counts)
    if max_counts is not None:
        sc.pp.filter_cells(adata, max_counts=max_counts)
    if max_pct_counts_mt is not None:
        adata = adata[adata.obs["pct_counts_mt"] < max_pct_counts_mt].copy()
    sc.pp.filter_genes(adata, min_cells=min_cells)

    sc.pp.normalize_total(adata, inplace=True)
    sc.pp.log1p(adata)
    # "seurat" HVG flavor expects log-transformed data, hence after log1p.
    sc.pp.highly_variable_genes(
        adata, flavor="seurat", n_top_genes=n_top_genes, subset=True
    )
    return adata

def _training_global(name):
    """Look up *name* in module globals, as the original implicit access did."""
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"name {name!r} is not defined; pass it to train() explicitly"
        ) from None


def train(
    model,
    loss_fn,
    optimizer,
    *,
    x=None,
    view_idx=None,
    Ns=None,
    data_dict=None,
    S=5,
):
    """Run one optimisation step of a GPSA model.

    Parameters
    ----------
    model : torch.nn.Module
        Model called as ``model(X_spatial=..., view_idx=..., Ns=..., S=...)``,
        returning ``(G_means, G_samples, F_latent_samples, F_samples)``.
    loss_fn : callable
        Called as ``loss_fn(data_dict, F_samples)``; must return a scalar tensor.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    x, view_idx, Ns, data_dict : optional
        Spatial-coordinate tensor, view index structure, per-view sizes and
        data dict. When omitted they are read from the module-level globals of
        the same names, so ``train(model, loss_fn, optimizer)`` keeps working;
        pass them explicitly to avoid depending on notebook state.
    S : int, default 5
        Forwarded to the model as ``S`` (presumably the number of Monte-Carlo
        samples — confirm against the model's forward signature).

    Returns
    -------
    tuple
        ``(loss_value, G_means)``: the loss as a Python float and the
        ``G_means`` returned by the forward pass (computed before this
        step's parameter update).
    """
    if x is None:
        x = _training_global("x")
    if view_idx is None:
        view_idx = _training_global("view_idx")
    if Ns is None:
        Ns = _training_global("Ns")
    if data_dict is None:
        data_dict = _training_global("data_dict")

    model.train()

    # Call the module rather than .forward() so any registered hooks run.
    G_means, G_samples, F_latent_samples, F_samples = model(
        X_spatial={"expression": x}, view_idx=view_idx, Ns=Ns, S=S
    )

    loss = loss_fn(data_dict, F_samples)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item(), G_means

# Hard-coded absolute output directory. Not referenced by the code in this
# notebook section: per-pair results are written under ./results/ instead.
save_dir_time = (
    "/maiziezhou_lab/yunfei/Projects/spatial_benchmarking/spatial-alignment/results"
)

### DLPFC: pairwise alignment of adjacent sections

In [None]:
# --- Data selection -------------------------------------------------------
# NOTE(review): N_GENES is not read by the benchmark loop, which passes
# n_top_genes=200 to process_data directly.
N_GENES = 10
# None -> keep every spot; an int -> randomly subsample that many per slice.
N_SAMPLES = None

# --- Model configuration (passed to VariationalGPSA) ----------------------
n_spatial_dims = 2
n_views = 2  # not referenced by the benchmark loop
m_G = 200
m_X_per_view = 200
N_LATENT_GPS = dict(expression=None)

# --- Optimisation ---------------------------------------------------------
N_EPOCHS = 5_000
PRINT_EVERY = 50  # not referenced by the benchmark loop

In [None]:
"""DLPFC""" 
section_ids_list = [['151507', '151508'], ['151508', '151509'], ['151509', '151510'], ['151669', '151670'], ['151670', '151671'], ['151671', '151672'], ['151673', '151674'], ['151674', '151675'], ['151675', '151676']]
run_times = []
for iter_ in range(iters):
    for section_ids in section_ids_list:
        dataset = section_ids[0] + '_' + section_ids[1]
        start_time = time.time()
        output = '.'
        data_slice1 = load_DLPFC(root_dir="../benchmarking_data/DLPFC12", section_id=section_ids[0])
        data_slice1 = process_data(data_slice1, n_top_genes=200)
        data_slice1.obs['batch'] = 0
        data_slice2 = load_DLPFC(root_dir="../benchmarking_data/DLPFC12", section_id=section_ids[1])
        data_slice2 = process_data(data_slice2, n_top_genes=200)
        data_slice2.obs['batch'] = 1

        data = anndata.concat([data_slice1, data_slice2])

        if N_SAMPLES is not None:
            rand_idx = np.random.choice(
                np.arange(data_slice1.shape[0]), size=N_SAMPLES, replace=False
            )
            data_slice1 = data_slice1[rand_idx]
            rand_idx = np.random.choice(
                np.arange(data_slice2.shape[0]), size=N_SAMPLES, replace=False
            )
            data_slice2 = data_slice2[rand_idx]

        # all_slices = anndata.concat([data_slice1, data_slice2])
        n_samples_list = [data_slice1.shape[0], data_slice2.shape[0]]
        view_idx = [
            np.arange(data_slice1.shape[0]),
            np.arange(data_slice1.shape[0], data_slice1.shape[0] + data_slice2.shape[0]),
        ]

        X1 = data_slice1.obsm["spatial"]
        X2 = data_slice2.obsm["spatial"]
        Y1 = data_slice1.X.todense()
        Y2 = data_slice2.X.todense()

        X1 = scale_spatial_coords(X1)
        X2 = scale_spatial_coords(X2)

        Y1 = (Y1 - Y1.mean(0)) / Y1.std(0)
        Y2 = (Y2 - Y2.mean(0)) / Y2.std(0)

        X = np.concatenate([X1, X2])

        Y = np.concatenate([Y1, Y2])

        n_outputs = Y.shape[1]

        x = torch.from_numpy(X).float().clone().to(device)
        y = torch.from_numpy(Y).float().clone().to(device)

        data_dict = {
            "expression": {
                "spatial_coords": x,
                "outputs": y,
                "n_samples_list": n_samples_list,
            }
        }

        model = VariationalGPSA(
            data_dict,
            n_spatial_dims=n_spatial_dims,
            m_X_per_view=m_X_per_view,
            m_G=m_G,
            data_init=True,
            minmax_init=False,
            grid_init=False,
            n_latent_gps=N_LATENT_GPS,
            mean_function="identity_fixed",
            kernel_func_warp=rbf_kernel,
            kernel_func_data=rbf_kernel,
            # fixed_warp_kernel_variances=np.ones(n_views) * 1.,
            # fixed_warp_kernel_lengthscales=np.ones(n_views) * 10,
            fixed_view_idx=0,
        ).to(device)

        view_idx, Ns, _, _ = model.create_view_idx_dict(data_dict)

        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

        for t in tqdm(range(N_EPOCHS), desc="Training Progress"):
            loss, G_means = train(model, model.loss_fn, optimizer)
            curr_aligned_coords = G_means["expression"].detach().cpu().numpy()
        print("Done!")

        # G_means, _, _, _ = model.forward({"expression": x}, view_idx=view_idx, Ns=Ns)

        # out = G_means['expression'].detach().cpu().numpy()
        df3 = pd.DataFrame(
            {
                "aligned_x": curr_aligned_coords.T[0],
                "aligned_y": curr_aligned_coords.T[1],
            },
        )
        df3.index = data.obs.index

        results = pd.concat([data.obs, df3], axis=1)
        results.to_csv('./results/' + dataset + '_' + str(0) + '.csv')
        end_time = time.time()
        run_times.append(end_time - start_time)