Save and load model parameters with stCluster
=====

Users may need to keep the embeddings or model parameters obtained during training, either to reuse them in later training sessions or to reproduce results on other hardware. stCluster can save and load model parameters for this purpose. This tutorial uses the ZESTA dataset to illustrate the process.

First, we load the dataset, train the latent representation with stCluster, and save the model parameters and embedding matrix by setting the `save_model` and `save_embedding` arguments of `stCluster.train.train()`.

In [1]:
from st_datasets.dataset import get_data, get_zesta_data
from stCluster.train import train

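# Load the ZESTA dataset together with its number of clusters.
adata, n_cluster = get_data(get_zesta_data)
# Train stCluster, then export the model parameters and the embedding matrix.
train(adata, radius=15, save_model='zesta_model.pkl', save_embedding='zesta_embedding.npy')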




>>> INFO: dataset name: ZESTA dataset, size: (13166, 26628), cluster: 45.(0.445s)
>>> INFO: Input size torch.Size([13166, 3000]).
>>> INFO: Graph contains 41704 edges, average 3.168 edges per node.
>>> INFO: Build graph success!




>>> INFO: Finish generate precluster embedding!
>>> INFO: Finish pre-cluster, result image is saved at "None", begin to prune graph.
>>> INFO: Finish pruning graph, result image is saved at "None".
>>> INFO: Graph contains 367103 edges, average 27.883 edges per node.
>>> INFO: Build graph success!
>>> INFO: Finish model preparations, begin to train model, input data size: (13166, 3000).


>>> INFO: Training: 100%|██████████| 1000/1000 [01:13<00:00, 13.53it/s]

>>> INFO: Successfully save embedding at zesta_embedding.npy.
>>> INFO: Successfully export model at zesta_model.pkl.
>>> INFO: Finish embedding process, total time: 142.766s.





(AnnData object with n_obs × n_vars = 13166 × 3000
     obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'seurat_clusters', 'spatial_x', 'spatial_y', 'slice', 'bin_annotation', 'colors', 'layer_annotation', 'layer_colors', 'time', 'cluster'
     var: 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm'
     uns: 'hvg', 'log1p', 'neighbors', 'louvain'
     obsm: 'spatial', 'embedding'
     layers: 'counts', 'scale.data'
     obsp: 'distances', 'connectivities',
 Graph(num_nodes=13166, num_edges=41704,
       ndata_schemes={}
       edata_schemes={}))
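
The log above reports that the embedding was written to `zesta_embedding.npy`. The sketch below shows one way to read such a file back. It assumes the file is a standard NumPy `.npy` array with one row per spot, which is inferred from the file extension rather than from stCluster's documentation. The helper `load_embedding` is hypothetical, and the demo uses a small stand-in array instead of the real file.

```python
import os
import tempfile

import numpy as np


def load_embedding(path, n_obs):
    """Load a saved embedding and check that it has one row per observation.

    Hypothetical helper: assumes the file was written in NumPy's .npy format.
    """
    embedding = np.load(path)
    if embedding.ndim != 2:
        raise ValueError(f"expected a 2-D matrix, got shape {embedding.shape}")
    if embedding.shape[0] != n_obs:
        raise ValueError(
            f"embedding has {embedding.shape[0]} rows, expected {n_obs}"
        )
    return embedding


# Demo on a stand-in file (5 spots, 3 latent dimensions), not real stCluster output.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo_embedding.npy")
    np.save(demo_path, np.zeros((5, 3), dtype=np.float32))
    demo = load_embedding(demo_path, n_obs=5)
```

For the run above, the corresponding call would presumably be `load_embedding('zesta_embedding.npy', n_obs=13166)`.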

## Load model parameters
We can then load the model parameters on another device, regenerate the embedding, and run downstream analysis.

In [2]:
from st_datasets.dataset import get_data, get_zesta_data
from stCluster.run import load_and_evaluate

# Reload the dataset, restore the saved parameters, and score the clustering.
adata, n_cluster = get_data(get_zesta_data)
adata, score = load_and_evaluate(
    adata, radius=15, n_cluster=n_cluster,
    cluster_method='mclust', cluster_score_method='ARI',
    model_paras_path='zesta_model.pkl',
)

print(score)



>>> INFO: dataset name: ZESTA dataset, size: (13166, 26628), cluster: 45.(0.401s)
>>> INFO: Input size torch.Size([13166, 3000]).
>>> INFO: Graph contains 41704 edges, average 3.168 edges per node.
>>> INFO: Build graph success!
>>> INFO: Finish load model, begin to generate embedding and rebuild gene expression, input data size: (13166, 3000).
>>> INFO: Finish embedding generation process, please use the embedding to do downstream evaluation, total time: 0.334s


R[write to console]:                    __           __ 
   ____ ___  _____/ /_  _______/ /_
  / __ `__ \/ ___/ / / / / ___/ __/
 / / / / / / /__/ / /_/ (__  ) /_  
/_/ /_/ /_/\___/_/\__,_/____/\__/   version 5.4.10
Type 'citation("mclust")' for citing this R package in publications.



fitting ...
{'mclust': 0.35190349034463836}
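
The value printed above is the score requested with `cluster_score_method='ARI'`, presumably the adjusted Rand index (ARI) between the mclust labels and the reference annotation. How stCluster computes it internally is not shown in this tutorial. The sketch below implements the standard ARI definition with the Python standard library only; `adjusted_rand_index` is a hypothetical helper name, not part of stCluster.

```python
from collections import Counter
from math import comb


def adjusted_rand_index(labels_true, labels_pred):
    """Standard adjusted Rand index between two labelings of the same items."""
    if len(labels_true) != len(labels_pred):
        raise ValueError("both labelings must have the same length")
    total_pairs = comb(len(labels_true), 2)
    if total_pairs == 0:
        return 1.0  # fewer than two items: the partitions trivially agree

    # Pairs of items that share a cluster in both labelings, in the
    # reference labeling only, and in the predicted labeling only.
    both = sum(comb(c, 2) for c in Counter(zip(labels_true, labels_pred)).values())
    in_true = sum(comb(c, 2) for c in Counter(labels_true).values())
    in_pred = sum(comb(c, 2) for c in Counter(labels_pred).values())

    expected = in_true * in_pred / total_pairs  # agreement expected by chance
    maximum = (in_true + in_pred) / 2
    if maximum == expected:
        return 1.0  # degenerate case, e.g. every item in a single cluster
    return (both - expected) / (maximum - expected)


# Identical partitions score 1.0 even when the cluster names differ.
relabeled = adjusted_rand_index([0, 0, 1, 1], [1, 1, 0, 0])
# Fully crossed partitions score below zero.
crossed = adjusted_rand_index([0, 0, 1, 1], [0, 1, 0, 1])
```

An ARI of 1 means the two partitions agree exactly, while values near 0 indicate agreement no better than chance.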


The model parameters obtained by training stCluster are available in the containers we provide, at `\root\stCluster_paras`. With them, you can quickly generate latent representations on your own device and run downstream analyses.
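
When an embedding is regenerated on a different device, it can be useful to check how closely it matches the one saved during training. The sketch below compares two embedding matrices within a tolerance. The helper `compare_embeddings` and the default tolerance are assumptions for illustration; the tutorial does not state how closely stCluster's embeddings agree across devices. The demo arrays are synthetic stand-ins.

```python
import numpy as np


def compare_embeddings(saved, regenerated, atol=1e-5):
    """Return (match, max_abs_diff) for two embedding matrices.

    Hypothetical helper: `atol` is an illustrative tolerance, not a value
    recommended by stCluster.
    """
    saved = np.asarray(saved, dtype=np.float64)
    regenerated = np.asarray(regenerated, dtype=np.float64)
    if saved.shape != regenerated.shape:
        raise ValueError(
            f"shape mismatch: {saved.shape} vs {regenerated.shape}"
        )
    if saved.size == 0:
        return True, 0.0
    max_abs_diff = float(np.max(np.abs(saved - regenerated)))
    return max_abs_diff <= atol, max_abs_diff


# Synthetic stand-ins for a saved and a regenerated embedding.
rng = np.random.default_rng(0)
reference = rng.normal(size=(4, 3))
same, diff = compare_embeddings(reference, reference + 1e-7)
different, _ = compare_embeddings(reference, reference + 1.0)
```

In practice, `saved` could come from `np.load('zesta_embedding.npy')` and `regenerated` from the `embedding` entry of `adata.obsm`, assuming `load_and_evaluate` stores the embedding there as `train` does.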