In [None]:
import torch
from sklearn import metrics
import multiprocessing as mp
from MGCL import MGCL
import os,csv,re, time
import pickle
import random
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
import numpy as np
from scipy import stats
from scipy.sparse import issparse
import scanpy as sc
import matplotlib.colors as clr
import matplotlib.pyplot as plt
import cv2
import ST as ST
from IPython.display import Image

# Select the compute device: prefer the second GPU (index 1) when at least two
# GPUs are visible, otherwise the first GPU, and fall back to the CPU when CUDA
# is unavailable. Requesting 'cuda:1' unconditionally would name a device that
# does not exist on single-GPU machines.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')

# Sample identifier (matches the data directory name in the path below) and the
# number of clusters handed to the clustering step.
dataset = '151673'
n_clusters = 7

# Read the sample's AnnData file and make its variable names unique.
sample_h5ad_path = "/home/dingsq/dsq/MGCL-project/Data/151673/sample_data.h5ad"
adata = sc.read(sample_h5ad_path)
adata.var_names_make_unique()

In [None]:
# Build the MGCL model on the loaded AnnData object, using the selected device.
model = MGCL.MGCL(
    adata,
    device=device,
    alpha=6,
)

# Train the model; train() returns the object that replaces `adata` for the
# following cells.
adata = model.train()

In [None]:
# Clustering settings.
radius = 50       # forwarded to clustering() as `radius`
tool = 'leiden'   # one of: 'mclust', 'leiden', 'louvain'

# clustering
from MGCL.utils import clustering

# Run the selected clustering method. The optional refinement step is enabled
# because it is used for the DLPFC dataset.
if tool == 'mclust':
    clustering(adata, n_clusters, radius=radius, method=tool, refinement=True)
elif tool in ('leiden', 'louvain'):
    # start/end/increment are forwarded to clustering() unchanged.
    # NOTE(review): presumably the resolution search range -- confirm in MGCL.utils.
    clustering(adata, n_clusters, radius=radius, method=tool,
               start=0.1, end=1.1, increment=0.01, refinement=True)
else:
    # Fail loudly instead of silently skipping clustering, which would only
    # surface later as a missing 'domain' column.
    raise ValueError(
        f"Unknown clustering tool {tool!r}; expected 'mclust', 'leiden', or 'louvain'."
    )

In [None]:
import numpy as np
import anndata as ad
from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, fowlkes_mallows_score, homogeneity_score, completeness_score, v_measure_score

from scipy.optimize import linear_sum_assignment

# Load the annotation table; the 'cell_code' and 'type' columns are read below.
label_file_path = "/home/dingsq/dsq/Data/151673/label_data.csv"
label_df = pd.read_csv(label_file_path, sep=',')

cell_codes_in_labels = label_df['cell_code'].tolist()
label_types = label_df['type'].str.strip().tolist()

# cell_code -> annotation (for a duplicated cell_code the last row wins).
label_dict = dict(zip(cell_codes_in_labels, label_types))

# Attach the annotation to adata.obs in one vectorised step instead of a
# row-by-row .at loop.
in_label_file = adata.obs.index.isin(cell_codes_in_labels)
if 'ground_truth' not in adata.obs.columns and not in_label_file.any():
    raise ValueError(
        f"No cell_code in {label_file_path!r} matches adata.obs.index; "
        "cannot build the 'ground_truth' column."
    )
mapped_labels = np.asarray(adata.obs.index.map(label_dict), dtype=object)
if 'ground_truth' in adata.obs.columns:
    # Keep any pre-existing annotation for observations absent from the label file.
    previous_labels = adata.obs['ground_truth'].to_numpy(dtype=object)
    mapped_labels = np.where(in_label_file, mapped_labels, previous_labels)
adata.obs['ground_truth'] = mapped_labels

# Drop observations without an annotation before scoring.
adata = adata[~pd.isnull(adata.obs['ground_truth'])]
print(adata.shape)

# Calculate clustering metrics. scikit-learn expects (labels_true, labels_pred);
# the order matters for homogeneity and completeness, which swap when the
# arguments are swapped.
y_true = adata.obs['ground_truth']
y_pred = adata.obs['domain']

# Cluster ids are arbitrary, so accuracy is computed after matching each
# predicted cluster to an annotation with the Hungarian algorithm on the
# contingency table (maximising the number of agreeing observations).
contingency = metrics.cluster.contingency_matrix(y_true, y_pred)
matched_true, matched_pred = linear_sum_assignment(-contingency)
ACC = contingency[matched_true, matched_pred].sum() / contingency.sum()

NMI = metrics.normalized_mutual_info_score(y_true, y_pred)
ARI = metrics.adjusted_rand_score(y_true, y_pred)
FMI = fowlkes_mallows_score(y_true, y_pred)
homogeneity = homogeneity_score(y_true, y_pred)
completeness = completeness_score(y_true, y_pred)
v_measure = v_measure_score(y_true, y_pred)

print('ACC:', ACC)
print('NMI:', NMI)
print('ARI:', ARI)
print('FMI:', FMI)
print('Homogeneity:', homogeneity)
print('Completeness:', completeness)
print('V_measure:', v_measure)

In [None]:
# Read the sample's AnnData file.
counts = sc.read("/home/dingsq/dsq/MGCL-ST-project/Data/151673/sample_data.h5ad")

# Read in the histology image. cv2.imread() returns None (it does not raise)
# when the file is missing or cannot be decoded, so check explicitly and fail
# here rather than later on `img.shape`.
image_path = "/home/dingsq/dsq/MGCL-ST-project/Data/151673/151673_full_image.tif"
img = cv2.imread(image_path)
if img is None:
    raise FileNotFoundError(f"Could not read histology image: {image_path}")

In [None]:
# Scale factor that maps the shorter image side to 1000 pixels, and the
# resulting (truncated) width and height.
resize_factor = 1000 / np.min(img.shape[0:2])
resize_width = int(img.shape[1] * resize_factor)
resize_height = int(img.shape[0] * resize_factor)

# Upper-case the variable names, then de-duplicate them (upper-casing can
# introduce duplicates).
counts.var.index = [i.upper() for i in counts.var.index]
counts.var_names_make_unique()
sc.pp.log1p(counts)  # impute on log scale

# Densify the matrix. .toarray() is the documented conversion method; the `.A`
# shorthand is not available in newer SciPy releases.
if issparse(counts.X):
    counts.X = counts.X.toarray()


In [None]:
# Three different contour-detection algorithms are available; select whichever
# works best. This notebook uses the cv2-based one.

# ----------------- 1. Detect contour using cv2 -----------------
cnt = ST.cv2_detect_contour(img, apertureSize=5, L2gradient=True)

image_hw = img.shape[:2]  # (height, width) of the image

# Mask of the detected contour (negative thickness fills the contour interior).
binary = np.zeros(image_hw, dtype=np.uint8)
cv2.drawContours(binary, [cnt], -1, 1, thickness=-1)

# Enlarged filter: the same mask for the contour scaled by 1.05.
cnt_enlarged = ST.scale_contour(cnt, 1.05)
binary_enlarged = np.zeros(image_hw)
cv2.drawContours(binary_enlarged, [cnt_enlarged], -1, 1, thickness=-1)

# Copy of the image with the contour outlined, resized to the target size.
outlined = img.copy()
cv2.drawContours(outlined, [cnt], -1, 255, thickness=50)
img_new = cv2.resize(outlined, (resize_width, resize_height))


In [None]:
# Resolution passed to ST.imputation().
# Note: if the number of superpixels is too large and the run takes too long,
# you can increase res to 100.
res = 50
enhanced_exp_adata = ST.imputation(
    img=img,
    raw=counts,
    cnt=cnt,
    genes=counts.var.index.tolist(),
    shape="None",
    res=res,
    s=1,
    k=2,
    num_nbs=10,
)