### 0. import packages and select GPU if accessible

In [None]:
import os
import random
import numpy as np
import scanpy as sc
import torch
from torch.utils.data import DataLoader
import argparse
import matplotlib.pyplot as plt
from sklearn.metrics import adjusted_rand_score

from st_loading_utils import load_DLPFC, load_BC, load_mVC, load_mPFC, load_mHypothalamus, load_her2_tumor, load_mMAMP
from model import SpaCLR, TrainerSpaCLR
from utils import get_predicted_results, load_ST_file
import pandas as pd
import warnings
from dataset import Dataset
warnings.filterwarnings("ignore")

In [None]:
def seed_torch(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy and PyTorch (including the CUDA
    generator), stores the seed in the ``PYTHONHASHSEED`` environment
    variable, and puts cuDNN into deterministic mode with benchmarking
    turned off.

    Args:
        seed: Seed value handed unchanged to each generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # The same seed goes to each library's seeding entry point.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Command-line options read by the experiment cells below; each cell parses
# them and then overrides `path`, `name` and `dataset` for its own data.
parser = argparse.ArgumentParser()

# preprocess
parser.add_argument('--dataset', type=str, default="SpatialLIBD")
parser.add_argument('--path', type=str, default="../spatialLIBD")
parser.add_argument("--gene_preprocess", choices=("pca", "hvg"), default="pca")
# type=int is required here: argparse converts the raw string before checking
# it against `choices`, so without it "--n_gene 1000" stays the string "1000"
# and is rejected because it matches neither 300 nor 1000.
parser.add_argument("--n_gene", type=int, choices=(300, 1000), default=300)
parser.add_argument('--img_size', type=int, default=112)
parser.add_argument('--num_workers', type=int, default=8)

# model
parser.add_argument('--last_dim', type=int, default=64)
parser.add_argument('--lr', type=float, default=0.0003)
parser.add_argument('--p_drop', type=float, default=0)

# loss weights (used in the log directory name and passed to the trainer via args)
parser.add_argument('--w_g2i', type=float, default=1)
parser.add_argument('--w_g2g', type=float, default=0.1)
parser.add_argument('--w_i2i', type=float, default=0.1)
parser.add_argument('--w_recon', type=float, default=0)

# data augmentation
parser.add_argument('--prob_mask', type=float, default=0.5)
parser.add_argument('--pct_mask', type=float, default=0.2)
parser.add_argument('--prob_noise', type=float, default=0.5)
parser.add_argument('--pct_noise', type=float, default=0.8)
parser.add_argument('--sigma_noise', type=float, default=0.5)
parser.add_argument('--prob_swap', type=float, default=0.5)
parser.add_argument('--pct_swap', type=float, default=0.1)

# train
parser.add_argument('--batch_size', type=int, default=96)
parser.add_argument('--epochs', type=int, default=35)
# NOTE(review): the default targets the fourth GPU; pass --device on machines
# with a different GPU layout.
parser.add_argument('--device', type=str, default="cuda:3")
parser.add_argument('--log_name', type=str, default="log_name")
parser.add_argument('--name', type=str, default="None")

# Number of repeated train/evaluate runs performed for every slide.
iters = 20

### 1. DLPFC dataset (12 slides)

Change the hard-coded `path` in the cell below to 'path/to/your/DLPFC/data'.

In [None]:
"""DLPFC"""
# The 12 DLPFC slides as [number of clusters, slide id]. Only the slide id is
# read below; the cluster count given to the trainer is trainset.n_clusters.
# Shorten this list to run a subset of the slides.
setting_combinations = [
    [7, '151507'], [7, '151508'], [7, '151509'], [7, '151510'],
    [5, '151669'], [5, '151670'], [5, '151671'], [5, '151672'],
    [7, '151673'], [7, '151674'], [7, '151675'], [7, '151676'],
]
for setting_combi in setting_combinations:
    # parse_known_args() skips arguments the parser does not define (such as
    # the "-f <connection file>" a Jupyter kernel places in sys.argv), where
    # parse_args() would abort with "unrecognized arguments".
    args, _ = parser.parse_known_args()
    # seed: applied once per slide, so the repeated runs below continue from
    # the same RNG stream instead of being re-seeded each time.
    seed_torch(1)

    # Per-slide overrides are written back into args because the trainer and
    # get_predicted_results read them from there.
    path = args.path = '/home/yunfei/spatial_benchmarking/benchmarking_data/DLPFC12'
    name = args.name = setting_combi[1]
    gene_preprocess = args.gene_preprocess
    n_gene = args.n_gene
    last_dim = args.last_dim
    gene_dims=[n_gene, 2*last_dim]
    image_dims=[n_gene]
    lr = args.lr
    p_drop = args.p_drop
    batch_size = args.batch_size
    dataset = args.dataset = 'DLPFC'
    epochs = args.epochs
    img_size = args.img_size
    device = args.device
    log_name = args.log_name
    num_workers = args.num_workers
    prob_mask = args.prob_mask
    pct_mask = args.pct_mask
    prob_noise = args.prob_noise
    pct_noise = args.pct_noise
    sigma_noise = args.sigma_noise
    prob_swap = args.prob_swap
    pct_swap = args.pct_swap
    # One ARI per repeated run of this slide.
    aris = []
    for iter_ in range(iters):
        # dataset: the train and test sets are built from the same slide with
        # identical options and differ only in the `train` flag.
        trainset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                        prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise, pct_noise=pct_noise, sigma_noise=sigma_noise,
                        prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=True)
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

        testset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                        prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise, pct_noise=pct_noise, sigma_noise=sigma_noise,
                        prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=False)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

        # network: a fresh model and optimizer for every repeat
        network = SpaCLR(gene_dims=gene_dims, image_dims=image_dims, p_drop=p_drop, n_pos=trainset.n_pos, backbone='densenet', projection_dims=[last_dim, last_dim])
        optimizer = torch.optim.AdamW(network.parameters(), lr=lr)

        # log
        save_name = f'{name}_{args.w_g2i}_{args.w_g2g}_{args.w_i2i}'
        log_dir = os.path.join('log', log_name, save_name)

        # train
        trainer = TrainerSpaCLR(args, trainset.n_clusters, network, optimizer, log_dir, device=device)
        trainer.fit(trainloader, epochs)
        xg, xi, _ = trainer.valid(testloader)
        # Combine the first two outputs of valid(), weighting the second by 0.1
        # (presumably gene and image embeddings -- confirm in TrainerSpaCLR).
        z = xg + 0.1*xi

        ARI, pred_label = get_predicted_results(args.dataset, args.name, args.path, z)
        print("Ari value : ", ARI)

        print('Dataset:', name)
        print('ARI:', ARI)
        aris.append(ARI)
    print('Dataset:', name)
    print(aris)
    print(np.mean(aris))
    # Append one line per slide: a label followed by the space-separated ARIs.
    with open('congi_aris.txt', 'a+') as fp:
        fp.write('DLPFC' + name + ' ')
        fp.write(' '.join([str(i) for i in aris]))
        fp.write('\n')

### 2. BC/MA datasets (2 slides)

In [None]:
"""BC"""
# Slides as [number of clusters, slide id]. Only the slide id is read below;
# the cluster count given to the trainer is trainset.n_clusters.
setting_combinations = [[20, 'section1']]
for setting_combi in setting_combinations:
    # parse_known_args() skips arguments the parser does not define (such as
    # the "-f <connection file>" a Jupyter kernel places in sys.argv), where
    # parse_args() would abort with "unrecognized arguments".
    args, _ = parser.parse_known_args()
    # seed: applied once per slide, so the repeated runs below continue from
    # the same RNG stream instead of being re-seeded each time.
    seed_torch(1)

    # Per-slide overrides are written back into args because the trainer and
    # get_predicted_results read them from there.
    path = args.path = '/home/yunfei/spatial_benchmarking/benchmarking_data/BC'
    name = args.name = setting_combi[1]
    gene_preprocess = args.gene_preprocess
    n_gene = args.n_gene
    last_dim = args.last_dim
    gene_dims=[n_gene, 2*last_dim]
    image_dims=[n_gene]
    lr = args.lr
    p_drop = args.p_drop
    batch_size = args.batch_size
    dataset = args.dataset = 'BC'
    epochs = args.epochs
    img_size = args.img_size
    device = args.device
    log_name = args.log_name
    num_workers = args.num_workers
    prob_mask = args.prob_mask
    pct_mask = args.pct_mask
    prob_noise = args.prob_noise
    pct_noise = args.pct_noise
    sigma_noise = args.sigma_noise
    prob_swap = args.prob_swap
    pct_swap = args.pct_swap
    # One ARI per repeated run of this slide.
    aris = []
    for iter_ in range(iters):
        # dataset: the train and test sets are built from the same slide with
        # identical options and differ only in the `train` flag.
        trainset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                        prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise, pct_noise=pct_noise, sigma_noise=sigma_noise,
                        prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=True)
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

        testset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                        prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise, pct_noise=pct_noise, sigma_noise=sigma_noise,
                        prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=False)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

        # network: a fresh model and optimizer for every repeat
        network = SpaCLR(gene_dims=gene_dims, image_dims=image_dims, p_drop=p_drop, n_pos=trainset.n_pos, backbone='densenet', projection_dims=[last_dim, last_dim])
        optimizer = torch.optim.AdamW(network.parameters(), lr=lr)

        # log
        save_name = f'{name}_{args.w_g2i}_{args.w_g2g}_{args.w_i2i}'
        log_dir = os.path.join('log', log_name, save_name)

        # train
        trainer = TrainerSpaCLR(args, trainset.n_clusters, network, optimizer, log_dir, device=device)
        trainer.fit(trainloader, epochs)
        xg, xi, _ = trainer.valid(testloader)
        # Combine the first two outputs of valid(), weighting the second by 0.1
        # (presumably gene and image embeddings -- confirm in TrainerSpaCLR).
        z = xg + 0.1*xi

        ARI, pred_label = get_predicted_results(args.dataset, args.name, args.path, z)
        print("Ari value : ", ARI)

        print('Dataset:', name)
        print('ARI:', ARI)
        aris.append(ARI)
    print('Dataset:', name)
    print(aris)
    print(np.mean(aris))
    # Append one line per slide: a label followed by the space-separated ARIs.
    with open('congi_aris.txt', 'a+') as fp:
        fp.write('BC' + name + ' ')
        fp.write(' '.join([str(i) for i in aris]))
        fp.write('\n')

In [None]:
"""MA"""
# Slides as [number of clusters, slide id]. Only the slide id is read below;
# the cluster count given to the trainer is trainset.n_clusters.
setting_combinations = [[52, 'MA']]
for setting_combi in setting_combinations:
    # parse_known_args() skips arguments the parser does not define (such as
    # the "-f <connection file>" a Jupyter kernel places in sys.argv), where
    # parse_args() would abort with "unrecognized arguments".
    args, _ = parser.parse_known_args()
    # seed: applied once per slide, so the repeated runs below continue from
    # the same RNG stream instead of being re-seeded each time.
    seed_torch(1)

    # Per-slide overrides are written back into args because the trainer and
    # get_predicted_results read them from there.
    path = args.path = '/home/yunfei/spatial_benchmarking/benchmarking_data/mMAMP'
    name = args.name = setting_combi[1]
    gene_preprocess = args.gene_preprocess
    n_gene = args.n_gene
    last_dim = args.last_dim
    gene_dims=[n_gene, 2*last_dim]
    image_dims=[n_gene]
    lr = args.lr
    p_drop = args.p_drop
    batch_size = args.batch_size
    dataset = args.dataset = 'MA'
    epochs = args.epochs
    img_size = args.img_size
    device = args.device
    log_name = args.log_name
    num_workers = args.num_workers
    prob_mask = args.prob_mask
    pct_mask = args.pct_mask
    prob_noise = args.prob_noise
    pct_noise = args.pct_noise
    sigma_noise = args.sigma_noise
    prob_swap = args.prob_swap
    pct_swap = args.pct_swap
    # One ARI per repeated run of this slide.
    aris = []
    for iter_ in range(iters):
        # dataset: the train and test sets are built from the same slide with
        # identical options and differ only in the `train` flag.
        trainset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                        prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise, pct_noise=pct_noise, sigma_noise=sigma_noise,
                        prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=True)
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

        testset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                        prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise, pct_noise=pct_noise, sigma_noise=sigma_noise,
                        prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=False)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

        # network: a fresh model and optimizer for every repeat
        network = SpaCLR(gene_dims=gene_dims, image_dims=image_dims, p_drop=p_drop, n_pos=trainset.n_pos, backbone='densenet', projection_dims=[last_dim, last_dim])
        optimizer = torch.optim.AdamW(network.parameters(), lr=lr)

        # log
        save_name = f'{name}_{args.w_g2i}_{args.w_g2g}_{args.w_i2i}'
        log_dir = os.path.join('log', log_name, save_name)

        # train
        trainer = TrainerSpaCLR(args, trainset.n_clusters, network, optimizer, log_dir, device=device)
        trainer.fit(trainloader, epochs)
        xg, xi, _ = trainer.valid(testloader)
        # Combine the first two outputs of valid(), weighting the second by 0.1
        # (presumably gene and image embeddings -- confirm in TrainerSpaCLR).
        z = xg + 0.1*xi

        ARI, pred_label = get_predicted_results(args.dataset, args.name, args.path, z)
        print("Ari value : ", ARI)

        print('Dataset:', name)
        print('ARI:', ARI)
        aris.append(ARI)
    print('Dataset:', name)
    print(aris)
    print(np.mean(aris))
    # Append one line per slide: a label followed by the space-separated ARIs.
    # NOTE(review): the label written here is 'mAB' although the dataset key
    # above is 'MA' -- confirm this is the intended tag in congi_aris.txt.
    with open('congi_aris.txt', 'a+') as fp:
        fp.write('mAB' + name + ' ')
        fp.write(' '.join([str(i) for i in aris]))
        fp.write('\n')

### 3. Her2Tumor dataset (8 slides)

In [None]:
"""Her2st"""
# Slides as [number of clusters, slide id]. Only the slide id is read below;
# the cluster count given to the trainer is trainset.n_clusters.
# Slides A1 (6 clusters), B1 (5) and C1 (4) are left out of this run; prepend
# [6, 'A1'], [5, 'B1'], [4, 'C1'] to cover all eight slides.
setting_combinations = [[4, 'D1'], [4, 'E1'], [4, 'F1'], [7, 'G2'], [7, 'H1']]
for setting_combi in setting_combinations:
    # parse_known_args() skips arguments the parser does not define (such as
    # the "-f <connection file>" a Jupyter kernel places in sys.argv), where
    # parse_args() would abort with "unrecognized arguments".
    args, _ = parser.parse_known_args()
    # seed: applied once per slide, so the repeated runs below continue from
    # the same RNG stream instead of being re-seeded each time.
    seed_torch(1)

    # Per-slide overrides are written back into args because the trainer and
    # get_predicted_results read them from there.
    path = args.path = '/home/yunfei/spatial_benchmarking/benchmarking_data/Her2_tumor'
    name = args.name = setting_combi[1]
    gene_preprocess = args.gene_preprocess
    n_gene = args.n_gene
    last_dim = args.last_dim
    gene_dims=[n_gene, 2*last_dim]
    image_dims=[n_gene]
    lr = args.lr
    p_drop = args.p_drop
    batch_size = args.batch_size
    dataset = args.dataset = 'Her2st'
    epochs = args.epochs
    img_size = args.img_size
    device = args.device
    log_name = args.log_name
    num_workers = args.num_workers
    prob_mask = args.prob_mask
    pct_mask = args.pct_mask
    prob_noise = args.prob_noise
    pct_noise = args.pct_noise
    sigma_noise = args.sigma_noise
    prob_swap = args.prob_swap
    pct_swap = args.pct_swap
    # One ARI per repeated run of this slide.
    aris = []
    for iter_ in range(iters):
        # dataset: the train and test sets are built from the same slide with
        # identical options and differ only in the `train` flag.
        trainset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                        prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise, pct_noise=pct_noise, sigma_noise=sigma_noise,
                        prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=True)
        trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

        testset = Dataset(dataset, path, name, gene_preprocess=gene_preprocess, n_genes=n_gene,
                        prob_mask=prob_mask, pct_mask=pct_mask, prob_noise=prob_noise, pct_noise=pct_noise, sigma_noise=sigma_noise,
                        prob_swap=prob_swap, pct_swap=pct_swap, img_size=img_size, train=False)
        testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

        # network: a fresh model and optimizer for every repeat
        network = SpaCLR(gene_dims=gene_dims, image_dims=image_dims, p_drop=p_drop, n_pos=trainset.n_pos, backbone='densenet', projection_dims=[last_dim, last_dim])
        optimizer = torch.optim.AdamW(network.parameters(), lr=lr)

        # log
        save_name = f'{name}_{args.w_g2i}_{args.w_g2g}_{args.w_i2i}'
        log_dir = os.path.join('log', log_name, save_name)

        # train
        trainer = TrainerSpaCLR(args, trainset.n_clusters, network, optimizer, log_dir, device=device)
        trainer.fit(trainloader, epochs)
        xg, xi, _ = trainer.valid(testloader)
        # Combine the first two outputs of valid(), weighting the second by 0.1
        # (presumably gene and image embeddings -- confirm in TrainerSpaCLR).
        z = xg + 0.1*xi

        ARI, pred_label = get_predicted_results(args.dataset, args.name, args.path, z)
        print("Ari value : ", ARI)

        print('Dataset:', name)
        print('ARI:', ARI)
        aris.append(ARI)
    print('Dataset:', name)
    print(aris)
    print(np.mean(aris))
    # Append one line per slide: a label followed by the space-separated ARIs.
    with open('congi_aris.txt', 'a+') as fp:
        fp.write('Her2tumor' + name + ' ')
        fp.write(' '.join([str(i) for i in aris]))
        fp.write('\n')