# Tutorial 5: Integrate Embryo data with MaskGraphene

In [1]:
import logging
import numpy as np
from tqdm import tqdm
import torch
import pickle
import sys
import os
import scanpy as sc
import sklearn.metrics.pairwise

# Put the directory one level above the current working directory (not the
# script's own location) at the front of the import path; presumably this is
# where the local `utils`, `datasets` and `models` packages imported below live.
parent_dir = os.path.dirname(os.path.abspath(os.getcwd()))
sys.path.insert(0, parent_dir)

from utils import (
    build_args_ST,
    create_optimizer
)
from datasets.st_loading_utils import visualization_umap_spatial, create_dictionary_mnn
from models import build_model_ST

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# python ../maskgraphene_main_dev.py --max_epoch 2000 --max_epoch_triplet 500 --logging False --section_ids "E12.5_E1S1.h5ad,E13.5_E1S1.h5ad" --num_class 23 --load_model False --num_hidden "512,32" \
#                                --exp_fig_dir "./" --h5ad_save_dir "./" --st_data_dir "/maiziezhou_lab/yunfei/Projects/spatial_benchmarking/benchmarking_data/Embryo" --alpha_l 1 --lam 1 --loss_fn "sce" --mask_rate 0.1 --in_drop 0 --attn_drop 0 --remask_rate 0.1 \
#                                --mapping_mat "./" \
#                                --seeds 2023 42 2 3 4 5 6 7 8 9 2024 --num_remasking 1 --dataset Embryo --lr 0.001 --log_name "./mg0916_embryo.log" --hvgs 7500

In [2]:
args = build_args_ST()

# Hidden layer sizes; the number of layers is derived from this list.
num_hidden = [512, 32]
lr = 0.0003

# Overrides applied on top of the defaults from build_args_ST(). The dict keeps
# insertion order, so attributes are set in the same order as listed here.
# Remember to point st_data_dir / hl_dir at your own data and hard-link paths.
_arg_overrides = {
    "section_ids": ["E11.5_E1S1.h5ad", "E12.5_E1S1.h5ad"],
    "max_epoch": 3000,
    "max_epoch_triplet": 500,
    "dataset": "Embryo",
    "num_hidden": num_hidden,
    "num_layers": len(num_hidden),
    "alpha_l": 1,
    "lam": 1,
    "loss_fn": "sce",
    "mask_rate": 0.4,
    "in_drop": 0.1,
    "attn_drop": 0.05,
    "remask_rate": 0.1,
    "seeds": [2024],
    "hvgs": 7500,
    "lr": lr,
    "activation": "prelu",
    "negative_slope": 0.2,
    "num_dec_layers": 1,
    "st_data_dir": "../../spatial_benchmarking/benchmarking_data/Embryo",
    "hl_dir": "../hard_links/Embryo",
}
for _arg_name, _arg_value in _arg_overrides.items():
    setattr(args, _arg_name, _arg_value)

In [18]:
import dgl
import scipy
import anndata
from datasets.data_proc import load_ST_dataset

# Load the requested sections of the dataset as one graph plus the
# concatenated AnnData, then record the feature/class counts on `args`.
dataset_name, section_ids = args.dataset, args.section_ids

graph, feature_info, ad_concat = load_ST_dataset(
    dataset_name=dataset_name,
    section_ids=section_ids,
    args_=args,
)
num_features, num_cls = feature_info
args.num_features, args.num_class = num_features, num_cls

# Node feature matrix stored on the graph.
x = graph.ndata["feat"]

['E11.5_E1S1.h5ad', 'E12.5_E1S1.h5ad']


  concat_annot[label] = label_col
  adata_concat.obs["batch_name"] = adata_concat.obs["slice_name"].astype('category')


num of class
19


In [19]:
# Build the model described by `args` and print its layer structure.
model = build_model_ST(args)
print(model)

# Pick the device: a non-negative args.device is used as the CUDA device index,
# a negative value means CPU. If a GPU was requested but CUDA is unavailable,
# fall back to CPU with a warning instead of failing on the first .to() call.
if args.device >= 0 and torch.cuda.is_available():
    device = args.device
else:
    if args.device >= 0:
        logging.warning("CUDA is not available; falling back to CPU.")
    device = "cpu"
# nn.Module.to() moves parameters in place, so a single call is enough.
model.to(device)

# Optimizer settings are read from `args`.
optim_type = args.optimizer
lr = args.lr
weight_decay = args.weight_decay
optimizer = create_optimizer(optim_type, model, lr, weight_decay)

use_scheduler = args.scheduler
max_epoch = args.max_epoch
max_epoch_triplet = args.max_epoch_triplet
if use_scheduler:
    logging.critical("Use scheduler")

    def _cosine_lr_factor(epoch):
        """Return the cosine-annealing LR multiplier for `epoch`.

        The factor is 1.0 at epoch 0 and decays to 0.0 at epoch ``max_epoch``.
        """
        return (1 + np.cos(epoch * np.pi / max_epoch)) * 0.5

    # NOTE(review): this scheduler is also passed to the MG_triplet stage.
    # Past max_epoch steps the factor rises again (cos climbs back from -1).
    # Confirm whether MG_triplet keeps stepping it.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=_cosine_lr_factor)
else:
    scheduler = None

# Move the graph and node features to the same device as the model.
graph = graph.to(device)
x = x.to(device)

=== Use sce_loss and alpha_l=1 ===
num_encoder_params: 3858017, num_decoder_params: 262500, num_params_in_total: 4163443
PreModel(
  (encoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=7500, out_features=512, bias=False)
        (feat_drop): Dropout(p=0.1, inplace=False)
        (attn_drop): Dropout(p=0.05, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
        (activation): PReLU(num_parameters=1)
      )
      (1): GATConv(
        (fc): Linear(in_features=512, out_features=32, bias=False)
        (feat_drop): Dropout(p=0.1, inplace=False)
        (attn_drop): Dropout(p=0.05, inplace=False)
        (leaky_relu): LeakyReLU(negative_slope=0.2)
      )
    )
    (head): Identity()
  )
  (decoder): GAT(
    (gat_layers): ModuleList(
      (0): GATConv(
        (fc): Linear(in_features=32, out_features=7500, bias=False)
        (feat_drop): Dropout(p=0.1, inplace=False)
        (attn_drop): Dropout(p=0.05, inplace=False)
 

In [None]:
from maskgraphene_main import MG, MG_triplet

# Two-stage training.
# Stage 1 (key_="MG"): train `model` on the graph and node features `x`, with
# max_epoch as the epoch budget. Returns the updated model and an AnnData
# object (presumably ad_concat with the learned embedding stored under "MG";
# confirm in maskgraphene_main).
model, ad_concat_1 = MG(model, graph, x, optimizer, max_epoch, device, ad_concat, scheduler, logger=None, key_="MG")
# Stage 2 (key_="MG_triplet"): continue from the stage-1 model and AnnData, with
# max_epoch_triplet as the epoch budget. The evaluation cell reads its result
# with use_key="MG_triplet".
# NOTE(review): the same optimizer and cosine scheduler (sized for max_epoch)
# are reused here. Confirm that this is intended.
model, ad_concat_2 = MG_triplet(model, graph, x, optimizer, max_epoch_triplet, device, adata_concat_=ad_concat_1, scheduler=scheduler, logger=None, key_="MG_triplet")

In [None]:
# Figures go to ./temp/<dataset_name><section ids joined by "_">.
# NOTE: dataset_name and the first section id are concatenated with no
# separator. This is kept as-is so existing output paths do not change.
exp_fig_dir = os.path.join("./temp", dataset_name + '_'.join(section_ids))
# exist_ok=True avoids the check-then-create race and makes the cell re-runnable.
os.makedirs(exp_fig_dir, exist_ok=True)

ari_ = visualization_umap_spatial(ad_temp=ad_concat_2, section_ids=section_ids, exp_fig_dir=exp_fig_dir, dataset_name=dataset_name, num_iter="0", identifier="stage2", num_class=args.num_class, use_key="MG_triplet")

# Report ARI per section; ari_[i] is paired with section_ids[i].
for section_id, section_ari in zip(section_ids, ari_):
    print(section_id, f", ARI = {section_ari:.3f}")