### Experimental Preparation

In [3]:
import torch
import time
import scanpy as sc
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold
from sklearn import metrics
import stMSG

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the spatial-transcriptomics (ST) and single-cell (SC) AnnData files.
output_dir = './dataset/'
adata_st = sc.read_h5ad(f'{output_dir}/STdata.h5ad')
adata_sc = sc.read_h5ad(f'{output_dir}/scRNAdata.h5ad')

# Restrict BOTH objects to the shared genes, in the same order.
# The cross-validation cell allocates its prediction matrix from adata_st.X and
# fills columns by their position inside all_common_genes, and the evaluation
# cell labels those columns with adata_st.var_names. That is only consistent
# when adata_st.var_names is exactly all_common_genes, so the ST data must be
# subset as well (previously only the SC data was).
all_common_genes = adata_st.var_names.intersection(adata_sc.var_names)
if len(all_common_genes) == 0:
    raise ValueError("STdata.h5ad and scRNAdata.h5ad share no gene names")
adata_st = adata_st[:, all_common_genes].copy()
adata_sc = adata_sc[:, all_common_genes].copy()

# Library-size normalisation to 10,000 counts per observation, then log1p.
sc.pp.normalize_total(adata_st, target_sum=1e4)
sc.pp.log1p(adata_st)
sc.pp.normalize_total(adata_sc, target_sum=1e4)
sc.pp.log1p(adata_sc)


# k is handed to stMSG.Create_cell_network further down; use the larger value
# when the ST data has more than 3,000 observations.
k = 30 if adata_st.n_obs > 3e3 else 15


# kmap is handed to stMSG.train_autoencoder further down. It is 5 when the SC
# data has fewer observations than the ST data; otherwise it grows in steps of
# 20 per 10,000 SC observations, from 10 up to a cap of 100.
n_sc_obs = adata_sc.n_obs
if n_sc_obs < adata_st.n_obs:
    kmap = 5
elif n_sc_obs > 5e4:
    kmap = 100
elif n_sc_obs > 4e4:
    kmap = 80
elif n_sc_obs > 3e4:
    kmap = 60
elif n_sc_obs > 2e4:
    kmap = 40
elif n_sc_obs > 1e4:
    kmap = 20
else:
    kmap = 10

# Passed to stMSG.train_autoencoder. NOTE(review): the logged losses are
# consistent with Loss = (1 - a) * RecST + a * MMDz -- confirm in stMSG.
a = 0.1
# Number of training epochs used for both stMSG training stages.
epochs = 100

### Model Training and Imputation

In [5]:
# 5-fold cross-validation over genes: in every fold the model is trained on the
# training genes only, and the predictions for the held-out genes are written
# into the matching columns of all_pred_res.
kf = KFold(n_splits=5, shuffle=True, random_state=0)
# Allowed latent dimensions; loop-invariant, so defined once.
latent_options = np.array([8, 16, 32, 64, 128, 256])
all_pred_res = np.zeros_like(adata_st.X)

for idx, (train_index, test_index) in enumerate(kf.split(all_common_genes), start=1):
    print('-'*100)
    start_time = time.time()

    # Copy only the training-gene slice instead of copying the whole object
    # and then slicing it into a view.
    train_genes = [all_common_genes[i] for i in train_index]
    adata_st_temp = adata_st[:, train_genes].copy()
    common_genes = adata_st_temp.var_names.intersection(adata_sc.var_names)

    # Mean squared expression of the training slice (rounded to one decimal)
    # selects the learning-rate pair: larger magnitude -> larger rates.
    mse = round(np.mean(adata_st_temp.X ** 2), 1)

    if mse > 20:
        lr1 = 1e-3
        lr2 = 5e-3
    elif mse > 5:
        lr1 = 1e-4
        lr2 = 5e-4
    else:
        lr1 = 1e-5
        lr2 = 5e-5

    # Latent size: the largest allowed option not exceeding half the number of
    # training genes. With fewer than 16 training genes no option qualifies, so
    # fall back to the smallest one instead of failing on an empty selection.
    target = adata_st_temp.X.shape[1] // 2
    candidates = latent_options[latent_options <= target]
    latent_size = candidates.max() if candidates.size else latent_options.min()

    # Stage 1 uses lr1; stage 2 builds the cell network and uses lr2.
    adata_st_map = stMSG.train_autoencoder(adata_sc, adata_st_temp, common_genes, latent_size, a, lr1, kmap, epochs, device, None)
    G1, G2 = stMSG.Create_cell_network(adata_st_temp, adata_st_map, k)
    Imp = stMSG.train_graphAutoEncoder(adata_st_map, G2, latent_size, lr2, epochs, device)

    # Keep only the held-out genes' predictions.
    # NOTE(review): assumes the columns of Imp follow all_common_genes order -- confirm in stMSG.
    all_pred_res[:, test_index] = Imp[:, test_index]

    elapsed_time = time.time() - start_time
    print(f"Fold {idx} completed in {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")

----------------------------------------------------------------------------------------------------
ST Cell autoencoder begins training...
[Epoch 10/100] Loss: 2.8680 (RecST: 3.1752, MMDz: 0.1038)
[Epoch 20/100] Loss: 2.8081 (RecST: 3.1088, MMDz: 0.1018)
[Epoch 30/100] Loss: 2.7519 (RecST: 3.0465, MMDz: 0.0997)
[Epoch 40/100] Loss: 2.6974 (RecST: 2.9863, MMDz: 0.0975)
[Epoch 50/100] Loss: 2.6437 (RecST: 2.9268, MMDz: 0.0953)
[Epoch 60/100] Loss: 2.5901 (RecST: 2.8676, MMDz: 0.0931)
[Epoch 70/100] Loss: 2.5367 (RecST: 2.8084, MMDz: 0.0908)
[Epoch 80/100] Loss: 2.4832 (RecST: 2.7493, MMDz: 0.0884)
[Epoch 90/100] Loss: 2.4298 (RecST: 2.6902, MMDz: 0.0861)
[Epoch 100/100] Loss: 2.3767 (RecST: 2.6315, MMDz: 0.0838)
Preliminary clustering...



 To achieve the future defaults please pass: flavor="igraph" and n_iterations=2.  directed must also be False to work with igraph's implementation.
  sc.tl.leiden(adata_st_pre, resolution=1.0)


Preliminary clustering...
Searching for scattered cells...
Creating cell network...
Cell graphautoencoder(all genes) begins training......
[Epoch 10/100] Loss: 4.7228 (RecST: 3.1544, Recz: 1.5685)
[Epoch 20/100] Loss: 4.1790 (RecST: 2.9981, Recz: 1.1809)
[Epoch 30/100] Loss: 3.7815 (RecST: 2.8234, Recz: 0.9580)
[Epoch 40/100] Loss: 3.4822 (RecST: 2.6284, Recz: 0.8538)
[Epoch 50/100] Loss: 3.2004 (RecST: 2.4033, Recz: 0.7971)
[Epoch 60/100] Loss: 2.9304 (RecST: 2.1852, Recz: 0.7452)
[Epoch 70/100] Loss: 2.6944 (RecST: 1.9804, Recz: 0.7140)
[Epoch 80/100] Loss: 2.4919 (RecST: 1.8118, Recz: 0.6801)
[Epoch 90/100] Loss: 2.3391 (RecST: 1.6893, Recz: 0.6498)
[Epoch 100/100] Loss: 2.2315 (RecST: 1.6021, Recz: 0.6295)
Fold 1 completed in 21.37 seconds (0.36 minutes)
----------------------------------------------------------------------------------------------------
ST Cell autoencoder begins training...
[Epoch 10/100] Loss: 2.8446 (RecST: 3.1506, MMDz: 0.0905)
[Epoch 20/100] Loss: 2.7807 (RecS

### Evaluation Metrics

In [6]:
import Evaluation 
from sklearn.cluster import KMeans
from sklearn import metrics
import warnings
# Keep FutureWarning messages out of the evaluation output.
warnings.filterwarnings(action="ignore", category=FutureWarning)

# Wrap the imputed and the observed expression matrices in DataFrames that
# share adata_st's observation index and gene columns.
df_imputed_data = pd.DataFrame(
    data=all_pred_res.copy(),
    index=adata_st.obs_names,
    columns=adata_st.var_names,
)
df_raw_data = pd.DataFrame(
    data=adata_st.X.copy(),
    index=adata_st.obs_names,
    columns=adata_st.var_names,
)

# Persist the imputed matrix before any masking is applied to it.
df_imputed_data.to_csv('./stMSG_impute.csv')

# AnnData view of the imputed matrix, used for clustering below.
adata = sc.AnnData(df_imputed_data)

# Reference clustering on the observed ST data: PCA -> neighbour graph -> Leiden.
# The labels end up in adata_st.obs['leiden'].
sc.pp.pca(adata_st)
sc.pp.neighbors(adata_st, use_rep='X_pca')
sc.tl.leiden(adata_st)

def clustering_evaluate(raw_adata_st, imp_adata):
    """Print clustering-agreement scores between raw and imputed data.

    A copy of ``imp_adata`` is clustered with PCA -> neighbours -> Leiden and
    the resulting labels are compared with the ``'leiden'`` labels already
    stored in ``raw_adata_st.obs``.

    Parameters
    ----------
    raw_adata_st
        Object whose ``obs['leiden']`` column holds the reference labels.
    imp_adata
        Object holding the imputed expression; it is copied, not modified.

    Returns
    -------
    None
        ARI, FMI, AMI, Homo and NMI are printed, one per line, to three
        decimals.
    """
    clustered = imp_adata.copy()
    sc.pp.pca(clustered)
    sc.pp.neighbors(clustered, use_rep='X_pca')
    sc.tl.leiden(clustered)

    predicted = clustered.obs['leiden']
    reference = raw_adata_st.obs['leiden']

    # (printed tag, sklearn scorer) pairs, in output order.
    scorers = (
        ('ARI', metrics.adjusted_rand_score),
        ('FMI', metrics.fowlkes_mallows_score),
        ('AMI', metrics.adjusted_mutual_info_score),
        ('Homo', metrics.homogeneity_score),
        ('NMI', metrics.normalized_mutual_info_score),
    )
    for tag, scorer in scorers:
        print('%s: %.3f' % (tag, scorer(reference, predicted)))

# Clustering agreement between the observed ST data and the imputed data.
clustering_evaluate(adata_st, adata)
print('-----------')

# Score only entries that were observed as non-zero: wherever the raw matrix is
# exactly 0, force the imputed value to 0 as well (this edits df_imputed_data
# in place).
groundTruth_mask = (df_raw_data == 0)
df_imputed_data[groundTruth_mask] = 0

# The metric helpers were called through `MJSUtils`, a name this notebook never
# imports, so the cell raised NameError on a fresh kernel. `Evaluation` is the
# module imported at the top of this cell and is otherwise unused.
# NOTE(review): assumes Evaluation exposes SSIM / JS / RMSE -- confirm.
# NOTE(review): the saved output also shows an SPCC line that no statement here
# prints -- confirm whether a call is missing.
print('SSIM: {:.3f}'.format(Evaluation.SSIM(df_raw_data, df_imputed_data)))
print('JS: {:.3f}'.format(Evaluation.JS(df_raw_data, df_imputed_data)))
print('RMSE: {:.3f}'.format(Evaluation.RMSE(df_raw_data, df_imputed_data)))

ARI: 0.502
FMI: 0.579
AMI: 0.723
Homo: 0.846
NMI: 0.732
-----------
SPCC: 0.927
SSIM: 0.828
JS: 0.058
RMSE: 0.320
