In [19]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt

import config as CFG
from models import *
from dataset import *
import scanpy as sc
from torch.utils.data import DataLoader

import os
import numpy as np
import pandas as pd

import scanpy as sc
from itertools import chain

In [20]:
current_device = 3
torch.cuda.set_device(current_device)
torch.set_num_threads(28)

In [21]:
os.chdir('/data/lab_ph/matt')
!pwd

/data/lab_ph/matt


In [22]:
#print the current scanpy version
print(sc.__version__)

1.9.6


In [23]:
fold=5
data='her2st' #### Change here to test different dataset 'her2st' 'cscc'

prune='Grid' if data=='her2st' else 'NA'
genes=785 if data=='cscc' else 6073

def pk_load(fold,mode='test',flatten=False,dataset='brst',r=4,ori=True,adj=True,prune='Grid',neighs=8): #r=4 Hist2ST
    assert dataset in ['her2st','cscc', 'brst']
    if dataset=='her2st':
        dataset = CLIP_HER2ST(
            train=(mode=='train'),fold=fold,flatten=flatten,
            ori=ori,neighs=neighs,adj=adj,prune=prune,r=r
        )
    elif dataset=='cscc':
        dataset = CLIP_SKIN(
            train=(mode=='train'),fold=fold,flatten=flatten,
            ori=ori,neighs=neighs,adj=adj,prune=prune,r=r
        )

    elif dataset == 'brst':
        dataset = BRST(train=(mode=='train'),fold=fold,flatten=flatten,
            ori=ori,neighs=neighs,adj=adj,prune=prune,r=r
        )
    return dataset

def build_loaders_inference():
    print("Building loaders")
    trainset = pk_load(fold,'train',dataset=data,flatten=False,adj=True,ori=True,prune=prune)
    train_loader = DataLoader(trainset, batch_size=1, num_workers=0, shuffle=True)
    testset = pk_load(fold,'test',dataset=data,flatten=False,adj=True,ori=True,prune=prune)
    test_loader = DataLoader(testset, batch_size=1, num_workers=0, shuffle=False)
    print("Finished building loaders")
    return trainset, testset, train_loader, test_loader

#2265x256, 2277x256
def find_matches(spot_embeddings, query_embeddings, top_k=1):
    #find the closest matches
    spot_embeddings = torch.tensor(spot_embeddings)
    query_embeddings = torch.tensor(query_embeddings)
    query_embeddings = F.normalize(query_embeddings, p=2, dim=-1)
    spot_embeddings = F.normalize(spot_embeddings, p=2, dim=-1)
    dot_similarity = query_embeddings @ spot_embeddings.T   #2277x2265
    print("dot_similarity.shape = spots * reference_spots = ",dot_similarity.shape)
    _, indices = torch.topk(dot_similarity.squeeze(0), k=top_k)

    return indices.cpu().numpy()

In [24]:
### Loading data

trainset, testset, train_loader, test_loader = build_loaders_inference()
train_loader = chain(train_loader, test_loader)


print("Finished loading data")

Building loaders
Test set names: ['A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1']
Train set names: ['B2', 'A6', 'D5', 'C2', 'F2', 'C4', 'D6', 'C5', 'B3', 'C3', 'D3', 'B6', 'E3', 'H3', 'E2', 'D4', 'G3', 'D2', 'B4', 'A5', 'A4', 'H2', 'G1', 'F3', 'A2', 'C6', 'A3', 'B5']
Loading imgs...
Loading metadata...
Test set names: ['A1', 'B1', 'C1', 'D1', 'E1', 'F1', 'G2', 'H1']
Train set names: ['B2', 'A6', 'D5', 'C2', 'F2', 'C4', 'D6', 'C5', 'B3', 'C3', 'D3', 'B6', 'E3', 'H3', 'E2', 'D4', 'G3', 'D2', 'B4', 'A5', 'A4', 'H2', 'G1', 'F3', 'A2', 'C6', 'A3', 'B5']
Loading imgs...
Loading metadata...
Finished building loaders
Finished loading data


In [25]:

# model_path ="clip/best.pt"
if data =='her2st':
    model_path ="/data/lab_ph/matt/clip/SGCL2ST_152.pt"
    save_path = "SGCL2ST/clip/embeddings/her2st/"
if data =='cscc':
    model_path ="SGCL2ST/clip/SGCL2ST_SKIN.pt"
    save_path = "SGCL2ST/clip/embeddings/cscc/"
if data=='brst':
    model_path = "/data/lab_ph/matt/SGCL2ST/clip/bleep_brst_best.pt"
    save_path = "SGCL2ST/clip/embeddings/brst/"

model = myModel().cuda()

state_dict = torch.load(model_path)
new_state_dict = {}
for key in state_dict.keys():
    new_key = key.replace('module.', '')  # remove the prefix 'module.'
    new_key = new_key.replace('well', 'spot') # for compatibility with prior naming
    if "image_encoder.gnn" in new_key: # Special to GNN because GNN use torch_geometric.nn
        new_key = new_key.replace("module_1.","module_1.module.")
    new_state_dict[new_key] = state_dict[key]

model.load_state_dict(new_state_dict)
model.eval()

print("Finished loading model")

Finished loading model


In [26]:
if not os.path.exists(save_path):
        os.makedirs(save_path)

adj_dict = {}
exp_dict = {}
center_dict = {}
with torch.no_grad():
    for batch in tqdm(train_loader):
        #ID, patch, center, exp, adj, oris, sfs, centers = batch
        #uncomment to process patch one by one
        ID, image, positions, exp, center, adj, oris, sfs, *_ = batch
        print("Processing image ", ID)
        #B,N,C,H,W = patch.shape
        #patch = patch.reshape(B*N,C,H,W)  # (N,3,112,112)
        if adj.dim() == 3:
            adj = adj.squeeze(0)
        if exp.dim() == 3:
            exp = exp.squeeze(0)
            centers = center.squeeze().numpy()
        adj_dict[ID] = adj
        exp_dict[ID] = exp
        center_dict[ID] = center

        # to process patches one by one
        image_features = []
        n_patches = len(center[0])
        for i in range(n_patches):
            x, y = center[0][i]
            x = x.item()
            y = y.item()
            patch = image[:, (x - 112):(x + 112), (y - 112):(y + 112), :]
            #patch = patch.unsqueeze(0)
            #print(patch.shape)
            patch=patch.permute(0,3,1,2)
            patch_features = model.image_encoder(patch.cuda())  # Process one patch at a time
            image_features.append(patch_features)
            #print(len(image_features), "processed pathces")

        image_features = torch.cat(image_features, dim=0)

        #image_features = model.image_encoder(patch.cuda())
        spot_features = model.spot_encoder(exp.cuda(), adj.cuda())

        image_embeddings = model.image_projection(image_features).cpu().numpy()
        spot_embeddings = (model.spot_projection(spot_features.cuda()))
        #print(spot_embeddings.type(), "spot embedding type")

        spot_encoding = model.spot_autoencoder.encode(spot_embeddings, adj.cuda())
        # spot_reconstruction, extras = model.spot_autoencoder.decode(spot_encoding.cuda())
        spot_reconstruction = model.spot_autoencoder.decode(spot_encoding.cuda())

        spot_embeddings = spot_embeddings.cpu().numpy()
        spot_encoding = spot_encoding.cpu().numpy()
        #spot_reconstruction = spot_reconstruction.cpu().numpy()

        # print(image_embeddings.shape)
        # print(spot_embeddings.shape)
        np.save(save_path + "img_embeddings_" + str(ID[0]) + ".npy", image_embeddings.T)
        np.save(save_path + "spot_embeddings_" + str(ID[0]) + ".npy", spot_embeddings.T)
# with torch.no_grad():
#     for batch in tqdm(train_loader):
#         ID, patch, positions, exp, adj, oris, sfs, center, *_  = batch
#         print("Processing image ", ID)
#         B,N,C,H,W = patch.shape
#         print(patch.shape, "patch shape")
#         patch = patch.reshape(B*N,C,H,W)  # (N,3,112,112)
#         if adj.dim() == 3:
#             adj = adj.squeeze(0)
#         if exp.dim() == 3:
#             exp = exp.squeeze(0)
#             #centers = centers.squeeze().numpy()
#         adj_dict[ID] = adj
#         exp_dict[ID] = exp
#         #center_dict[ID] = centers
#         exp = exp.to(torch.float32)

#         image_features = model.image_encoder(patch.cuda())
#         spot_features = model.spot_encoder(exp.cuda(), adj.cuda())

#         image_embeddings = model.image_projection(image_features).cpu().numpy()
#         spot_embeddings = (model.spot_projection(spot_features.cuda()))

#         spot_encoding = model.spot_autoencoder.encode(spot_embeddings, adj.cuda())
#         spot_reconstruction, extras = model.spot_autoencoder.decode(spot_encoding.cuda())

#         spot_embeddings = spot_embeddings.cpu().numpy()
#         spot_encoding = spot_encoding.cpu().numpy()
#         spot_reconstruction = spot_reconstruction.cpu().numpy()

#         # print(image_embeddings.shape)
#         # print(spot_embeddings.shape)
#         np.save(save_path + "img_embeddings_" + str(ID[0]) + ".npy", image_embeddings.T)
#         np.save(save_path + "spot_embeddings_" + str(ID[0]) + ".npy", spot_embeddings.T)

0it [00:00, ?it/s]

Processing image  ('H3',)


1it [00:04,  4.80s/it]

Processing image  ('A4',)


2it [00:07,  3.83s/it]

Processing image  ('G1',)


3it [00:11,  3.91s/it]

Processing image  ('D4',)


4it [00:14,  3.48s/it]

Processing image  ('C6',)


5it [00:16,  2.87s/it]

Processing image  ('B2',)


6it [00:19,  2.84s/it]

Processing image  ('D2',)


7it [00:22,  2.92s/it]

Processing image  ('A2',)


8it [00:25,  3.01s/it]

Processing image  ('C2',)


9it [00:27,  2.67s/it]

Processing image  ('B4',)


10it [00:30,  2.71s/it]

Processing image  ('C3',)


11it [00:32,  2.41s/it]

Processing image  ('A6',)


12it [00:35,  2.63s/it]

Processing image  ('F3',)


13it [00:41,  3.72s/it]

Processing image  ('C5',)


14it [00:43,  3.11s/it]

Processing image  ('E3',)


15it [00:48,  3.65s/it]

Processing image  ('A3',)


16it [00:51,  3.49s/it]

Processing image  ('B3',)


17it [00:53,  3.24s/it]

Processing image  ('E2',)


18it [00:59,  3.88s/it]

Processing image  ('D5',)


19it [01:01,  3.54s/it]

Processing image  ('C4',)


20it [01:03,  3.02s/it]

Processing image  ('A5',)


21it [01:06,  3.00s/it]

Processing image  ('B5',)


22it [01:09,  3.00s/it]

Processing image  ('H2',)


23it [01:15,  3.94s/it]

Processing image  ('B6',)


24it [01:18,  3.52s/it]

Processing image  ('D3',)


25it [01:21,  3.28s/it]

Processing image  ('D6',)


26it [01:23,  3.14s/it]

Processing image  ('G3',)


27it [01:27,  3.40s/it]

Processing image  ('F2',)


28it [01:33,  4.18s/it]

Processing image  ('A1',)


29it [01:37,  3.98s/it]

Processing image  ('B1',)


30it [01:40,  3.62s/it]

Processing image  ('C1',)


31it [01:42,  3.08s/it]

Processing image  ('D1',)


32it [01:44,  2.98s/it]

Processing image  ('E1',)


33it [01:50,  3.68s/it]

Processing image  ('F1',)


34it [01:56,  4.36s/it]

Processing image  ('G2',)


35it [01:59,  4.23s/it]

Processing image  ('H1',)


36it [02:05,  3.48s/it]


In [27]:

all_files = os.listdir(save_path)

# exp_dict = {}
# for batch in tqdm(train_loader):
#     ID, patch, center, exp, adj, oris, sfs, *_ = batch
#     print(ID)
#     print(exp.shape)
#     exp_dict[ID] = exp  # Assuming ID and exp are tensors, we fetch their first elements

image_embeddings_dict = {}
spot_embeddings_dict = {}
ID_list = []

for file in all_files:
    if file.endswith(".npy"):
        # Extract the ID from the filename (e.g., A2, C3, etc.)
        if data=='her2st':
            if 'rep' not in file:
                ID = file.split("_")[2].split(".")[0]
        elif data=='cscc':
            if 'rep' in file:
                ID = "_".join(file.split("_")[2:-1]) + "_" + file.split("_")[-1].split(".")[0]
        elif data =='brst':
            if 'rep' not in file: 
                ID = "_".join(file.split("_")[2:]).split(".")[0]
        if (ID,) in adj_dict:
            adj_dict[ID] = adj_dict.pop((ID,))
        if (ID,) in exp_dict:
            exp_dict[ID] = exp_dict.pop((ID,))
        if (ID,) in center_dict:
            center_dict[ID] = center_dict.pop((ID,))

        # Determine the type of file based on its prefix and load the data
        if "img_embeddings" in file:
            print("file", file)
            print("ID",ID)
            image_embeddings_dict[ID] = np.load(os.path.join(save_path, file))
            ID_list.append(ID)
        elif "spot_embeddings" in file:
            spot_embeddings_dict[ID] = np.load(os.path.join(save_path, file))

# Now, image_embeddings_dict and spot_embeddings_dict contain the required data
print("IMAGE EMB", image_embeddings_dict.keys())  # Should list all the image embedding IDs
print(spot_embeddings_dict.keys())  # Should list all the spot embedding IDs
print(exp_dict.keys())  # Should list all the spot embedding IDs
print(ID_list)

if data=='her2st':
    fold=[0,6,12,18,24,27,31,33]
    test_ID = ['A1','B1','C1','D1','E1','F1','G2','H1']
elif data=='cscc':
    fold=[0,3,6,9]
    test_ID = ['P2_ST_rep1', 'P5_ST_rep1', 'P9_ST_rep1', 'P10_ST_rep1']
elif data =='brst':
    fold=[0,6,12,18,24,27,31,33]
    test_ID = ['BC23209_C1', 'BC23269_C1', 'BC23272_D2', 'BC23287_C1', 'BC23377_C1', 'BC23450_D2', 'BC23506_C2', 'BC23508_D2']

# test_ID = [ID_list[i] for i in fold]
print("Test set names:", test_ID)
train_ID = list(set(ID_list)-set(test_ID))
print("Train set names:",train_ID)



file img_embeddings_C3.npy
ID C3
file img_embeddings_C1.npy
ID C1
file img_embeddings_D3.npy
ID D3
file img_embeddings_H2.npy
ID H2
file img_embeddings_G3.npy
ID G3
file img_embeddings_B5.npy
ID B5
file img_embeddings_B3.npy
ID B3
file img_embeddings_D1.npy
ID D1
file img_embeddings_D4.npy
ID D4
file img_embeddings_C6.npy
ID C6
file img_embeddings_F3.npy
ID F3
file img_embeddings_C2.npy
ID C2
file img_embeddings_D2.npy
ID D2
file img_embeddings_B6.npy
ID B6
file img_embeddings_B4.npy
ID B4
file img_embeddings_A1.npy
ID A1
file img_embeddings_D5.npy
ID D5
file img_embeddings_A6.npy
ID A6
file img_embeddings_H1.npy
ID H1
file img_embeddings_A4.npy
ID A4
file img_embeddings_A2.npy
ID A2
file img_embeddings_G2.npy
ID G2
file img_embeddings_H3.npy
ID H3
file img_embeddings_F1.npy
ID F1
file img_embeddings_A5.npy
ID A5
file img_embeddings_F2.npy
ID F2
file img_embeddings_D6.npy
ID D6
file img_embeddings_B1.npy
ID B1
file img_embeddings_C4.npy
ID C4
file img_embeddings_E3.npy
ID E3
file img_e

In [28]:
#query
# test_ID.remove('A1')
train_ID.remove("B3")
print(train_ID)
# image_query = [spot_embeddings_dict[ID] for ID in test_ID]
# expression_gt = [exp_dict[ID].numpy().T for ID in test_ID]

# image_train_data = [image_embeddings_dict[ID] for ID in train_ID]
spot_train_data = [spot_embeddings_dict[ID] for ID in train_ID]
expression_train_data = [exp_dict[ID].numpy().T for ID in train_ID]

spot_key = np.concatenate(spot_train_data, axis=1)
expression_key = np.concatenate(expression_train_data, axis=1)

# print(image_query.shape)
# print(expression_gt.shape)
print(spot_key.shape)
print(expression_key.shape)

if spot_key.shape[1] != 256:
    spot_key = spot_key.T
    print("spot_key shape: ", spot_key.shape)
if expression_key.shape[0] != spot_key.shape[0]:
    expression_key = expression_key.T
    print("expression_key shape: ", expression_key.shape)

['B2', 'A6', 'D5', 'C2', 'F2', 'C4', 'D6', 'C3', 'C5', 'D3', 'B6', 'H3', 'E3', 'E2', 'G3', 'D4', 'D2', 'B4', 'A5', 'A4', 'H2', 'G1', 'F3', 'C6', 'A2', 'A3', 'B5']
(256, 9841)
(785, 9841)
spot_key shape:  (9841, 256)
expression_key shape:  (9841, 785)


In [29]:
import torch
import numpy as np
import scanpy as sc
import anndata as ad
from tqdm import tqdm
from scipy.stats import pearsonr,spearmanr
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score as ari_score
from sklearn.metrics.cluster import normalized_mutual_info_score as nmi_score

def test(model,test,device='cuda'):
    model=model.to(device)
    model.eval()
    preds=None
    ct=None
    gt=None
    loss=0
    with torch.no_grad():
        for patch, position, exp, adj, *_, center in tqdm(test):
            patch, position, adj = patch.to(device), position.to(device), adj.to(device).squeeze(0)
            pred = model(patch, position, adj)[0]
            preds = pred.squeeze().cpu().numpy()
            ct = center.squeeze().cpu().numpy()
            gt = exp.squeeze().cpu().numpy()
    adata = ad.AnnData(preds)
    adata.obsm['spatial'] = ct
    adata_gt = ad.AnnData(gt)
    adata_gt.obsm['spatial'] = ct
    return adata,adata_gt

def cluster(adata,label):
    idx = label != 'undetermined'
    tmp=adata[idx]
    l=label[idx]
    print("cluster number:",len(set(l)))
    sc.pp.pca(tmp)
    sc.tl.tsne(tmp)
    kmeans = KMeans(n_clusters=len(set(l)), init="k-means++", random_state=0).fit(tmp.obsm['X_pca'])
    p=kmeans.labels_.astype(str)
    lbl=np.full(len(adata),str(len(set(l))))
    lbl[idx]=p
    adata.obs['kmeans']=lbl
    return p,round(ari_score(p,l),3)

def get_R(data1,data2,dim=1,func=pearsonr):
    adata1=data1.X
    adata2=data2.X
    r1,p1=[],[]
    for g in range(data1.shape[dim]):
        if dim==1:
            r,pv=func(adata1[:,g],adata2[:,g])
        elif dim==0:
            r,pv=func(adata1[g,:],adata2[g,:])
        r1.append(r)
        p1.append(pv)
    r1=np.array(r1)
    p1=np.array(p1)
    return r1,p1

def get_top_values(arr, num_top_values=10, lowest=False):
    return sorted([(i, arr[i]) for i in range(len(arr))], key=lambda x: x[1], reverse=not lowest)[:num_top_values]
top_k = 50
results = {}
top_results = {}
selected_folds = [5]



In [30]:
import warnings
warnings.filterwarnings('ignore')

from sklearn.metrics import mean_squared_error
from math import sqrt

for ID in test_ID:
    print("Begin Processing Image", ID)
    #image_query = spot_embeddings_dict[ID]
    image_query = image_embeddings_dict[ID]
    expression_gt = exp_dict[ID].numpy().T

    method = "weighted_average" # "average" "weighted_average"
    save_path = ""
    if image_query.shape[1] != 256:
        image_query = image_query.T
        print("image query shape: ", image_query.shape)
    if expression_gt.shape[0] != image_query.shape[0]:
        expression_gt = expression_gt.T
        print("expression_gt shape: ", expression_gt.shape)
    if spot_key.shape[1] != 256:
        spot_key = spot_key.T
        print("spot_key shape: ", spot_key.shape)
    if expression_key.shape[0] != spot_key.shape[0]:
        expression_key = expression_key.T
        print("expression_key shape: ", expression_key.shape)

    if method == "simple":
        indices = find_matches(spot_key, image_query, top_k=1)
        matched_spot_embeddings_pred = spot_key[indices[:,0],:]
        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        matched_spot_expression_pred = expression_key[indices[:,0],:]
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "average":
        print("finding matches, using average of top 50 expressions")
        indices = find_matches(spot_key, image_query, top_k=10)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0)

        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "weighted_average":
        print("finding matches, using weighted average of top 100 expressions")
        indices = find_matches(spot_key, image_query, top_k=10)
        # print("indices = ", indices)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            a = np.sum((spot_key[indices[i,0],:] - image_query[i,:])**2) #the smallest MSE
            weights = np.exp(-(np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1)-a+1))
            # weights = a/np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1)
            # a = np.sqrt(np.sum((spot_key[indices[i,0],:] - image_query[i,:])**2)) #the smallest RMSE
            # weights = np.exp(-(np.sqrt(np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1))-a+1))

            # sorted_indices = np.argsort(weights)[::-1]  #
            # top_10_weights = weights[sorted_indices[:10]]
            # least_10_weights = weights[sorted_indices[-10:]]
            # print("Top 10 weights: ", top_10_weights)
            # print("least 10 weights: ", least_10_weights)

            # if i == 0:
            #     print("weights: ", weights)
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0, weights=weights)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0, weights=weights)

        # print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        # print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    true = expression_gt
    pred = matched_spot_expression_pred
    adj = adj_dict[ID]

    model.eval()

    # Create the directory if it doesn't exist
    output_dir = './figures/show'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Additional subdirectories
    subdirs = ['gene', 'clus']
    for subdir in subdirs:
        subdir_path = os.path.join(output_dir, subdir)
        if not os.path.exists(subdir_path):
            os.makedirs(subdir_path)

    with torch.no_grad():
        pred_features = model.spot_encoder(torch.tensor(pred, dtype=torch.float32).cuda(), adj.cuda())
        pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
        pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
        pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())

        pred_features = pred_features.cpu().numpy()
        pred_embeddings = pred_embeddings.cpu().numpy()
        pred_encoding = pred_encoding.cpu().numpy()
        #pred_reconstruction = pred_reconstruction.cpu().numpy()


    print("check for NaN")
    if np.isnan(pred).any():
        print("pred contains NaN values")
    if np.isnan(true).any():
        print("true contains NaN values")
        
    print("pred.shape",pred.shape)
    print("true.shape",true.shape)
    print("np.max(pred)",np.max(pred))
    print("np.max(true)",np.max(true))
    print("np.min(pred)",np.min(pred))
    print("np.min(true)",np.min(true))

    ####### Prediction PCC performance
    #mix = (pred + pred_reconstruction)/2

    def evaluate_gene_expression(pred, true, ID, top_k, fold, top_results, testset):
        # Genewise correlation across cells
        corr_cells = np.zeros(pred.shape[0])
        for i in range(pred.shape[0]):
            corr_cells[i] = np.corrcoef(pred[i, :], true[i, :])[0, 1]
        # Remove NaN
        corr_cells = corr_cells[~np.isnan(corr_cells)]
        print("Cell Mean R: ", np.mean(corr_cells))

        # Calculate RMSE across cells
        mse_cells = mean_squared_error(pred, true)
        rmse_cells = sqrt(mse_cells)
        print("MSE across cells: ", mse_cells)
        print("RMSE across cells: ", rmse_cells)

        # Genewise correlation across genes
        corr_genes = np.zeros(pred.shape[1])
        p_values = np.zeros(pred.shape[1])
        for i in range(pred.shape[1]):
            # corr_genes[i] = np.corrcoef(pred[:, i], true[:, i])[0, 1]
            corr_genes[i], p_values[i] = pearsonr(pred[:, i], true[:, i])
        # Remove NaN
        valid_indices = ~np.isnan(corr_genes)
        corr_genes = corr_genes[valid_indices]
        p_values = p_values[valid_indices]

        if corr_genes.size == 0:
            print("corr_genes is an empty array")
        elif np.isnan(corr_genes).all():
            print("corr_genes is an array of NaNs")
        else:
            print("Max correlation across genes:", np.nanmax(corr_genes))

        print("Genes mean R: ", np.mean(corr_genes))
        print("Gene median R: ", np.median(corr_genes))
        print("number of genes with correlation > 0.3: ", np.sum(corr_genes > 0.3))

        mlog_p_values = -np.log10(p_values)
        # Top-k genes
        # top_k_indices = np.argsort(corr_genes)[-top_k:] ## highest R
        top_k_indices = np.argsort(mlog_p_values)[-top_k:] ## highest -log10 p-values
        top_R_values = corr_genes[top_k_indices]
        top_pred_values = pred[:, top_k_indices]
        top_results[ID] = (top_R_values, top_pred_values)
        print(f'Top {top_k} Genes Mean Pearson Correlation:', np.nanmean(top_R_values))
        print(f'Top {top_k} Genes Median Pearson Correlation:', np.nanmedian(top_R_values))

        # Get top gene correlations
        top_R_values = get_top_values(corr_genes)
        print('Fold', ID, "Top 10 genes with highest -log10 p-values:")
        for gene_id, r_value in top_R_values:
            gene_name = testset.gene_set[gene_id]
            print(f"Gene ID: {gene_id}, Gene Name: {gene_name}, R: {r_value}, p_values: {p_values[gene_id]}")

        return corr_genes

    # Example usage:
    print(f"The Prediction: prediction")
    corr_genes = evaluate_gene_expression(pred, true, ID, top_k, fold, top_results, testset)
    # print(f"\n The Prediction Matrix: pred_reconstruction")
    # evaluate_gene_expression(pred_reconstruction, true, ID, top_k, fold, top_results, testset)
    # print(f"\n The Prediction Matrix: mix")
    # evaluate_gene_expression(mix, true, ID, top_k, fold, top_results, testset)


    # ####### Clustering
    # ### Change the type of pred to AnnData for the next clustering task
    # pred = sc.AnnData(pred)
    # pred.obsm['spatial'] = center_dict[ID]
    # true = ad.AnnData(true)
    # true.obsm['spatial'] = center_dict[ID]
    # if data == "her2st":
    #     pred.var_names = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True))
    #     true.var_names = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True))
    # elif data == "cscc":
    #     pred.var_names = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True))
    #     true.var_names = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True))
    # elif data =='brst':
    #     pred.var_names = list(np.load("SGCL2ST/data/filtered_hvg_indices.npy", allow_pickle=True))
    #     true.var_names = list(np.load("SGCL2ST/data/filtered_hvg_indices.npy", allow_pickle=True))
    # pred_features = sc.AnnData(pred_features)
    # pred_features.obsm['spatial'] = center_dict[ID]
    # pred_embeddings = sc.AnnData(pred_embeddings)
    # pred_embeddings.obsm['spatial'] = center_dict[ID]
    # # pred_encoding = sc.AnnData(pred_encoding)
    # # pred_encoding.obsm['spatial'] = center_dict[ID]
    # pred_reconstruction = sc.AnnData(pred_reconstruction)
    # pred_reconstruction.obsm['spatial'] = center_dict[ID]
    # mix = sc.AnnData(mix)
    # mix.obsm['spatial'] = center_dict[ID]

    # # Extract top 2 genes based on -log10 p-values
    # top_R_values_2 = get_top_values(corr_genes, num_top_values=3)
    # # Visualize the top 2 genes for this ID
    # for gene_id, r_value in top_R_values_2:
    #     gene_name = testset.gene_set[gene_id]
    #     title = f"ID {ID} Gene: {gene_name} R = {r_value:.3f}"
    #     file_path = f"/gene/{ID}_{gene_name}_{r_value:.3f}.pdf"
    #     sc.pl.spatial(pred, img=testset.get_img(ID), color=gene_name, spot_size=112, title=title, color_map='magma', save=file_path)

        # #if want gt
        # title = f"ID {ID} Gene: {gene_name} gt"
        # file_path = f"/gene/{ID}_{gene_name}_gt.pdf"
        # sc.pl.spatial(true, img=testset.get_img(ID), color=gene_name, spot_size=112, title=title, color_map='magma', save=file_path)


    # if data=='her2st':
    #     ####### Generate cluster figure
    #     label = testset.label[ID]
    #     # print("label = ",label)
    #     clus, ARI = cluster(pred, label)
    #     print('Fold:', fold, 'ARI:', ARI)
    #     title = f"SGCL2ST {ID} ARI = {ARI:.3f}"  # Format title with ARI value
    #     # sc.pl.spatial(pred, img=testset.get_img(ID), color='kmeans', spot_size=112, title=title, save=f"/SGCL2ST_Her2_{ID}_{ARI:.3f}.pdf")


    #     # clus, Top_ARI = cluster(top_pred_values, label)
    #     # print('Fold:', fold, 'Top 100 ARI:', Top_ARI)
    #     clus, feature_ARI = cluster(pred_features, label)
    #     print('Fold:', fold, 'Expression features ARI:', feature_ARI)
    #     title = f"SGCL2ST {ID} ARI = {feature_ARI:.3f}"  # Format title with ARI value
    #     # sc.pl.spatial(pred_features, img=testset.get_img(ID), color='kmeans', spot_size=112, title=title, save=f"/SGCL2ST_Her2_{ID}_features_{feature_ARI:.3f}.pdf")

    #     clus, Emb_ARI = cluster(pred_embeddings, label)
    #     print('Fold:', fold, 'Expression Embeddings ARI:', Emb_ARI)
    #     title = f"SGCL2ST {ID} ARI = {Emb_ARI:.3f}"  # Format title with ARI value
    #     # sc.pl.spatial(pred_embeddings, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/SGCL2ST_Her2_{ID}_Emb_{Emb_ARI:.3f}.pdf")


    #     clus, Re_ARI = cluster(pred_reconstruction, label)
    #     print('Fold:', fold, 'Reconstruction ARI:', Re_ARI)
    #     title = f"SGCL2ST {ID} ARI = {Re_ARI:.3f}"  # Format title with ARI value
    #     # sc.pl.spatial(pred_reconstruction, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/SGCL2ST_Her2_{ID}_Reconstruction_{Re_ARI:.3f}.pdf")

    #     # clus, true_ARI = cluster(true, label) # Observed Gene Expression clustering
    #     # print('Fold:', fold, 'Observed Gene Expression ARI:', true_ARI)
    #     # title = f"Observed Gene Expression {ID} ARI = {true_ARI:.3f}"  # Format title with ARI value
    #     # sc.pl.spatial(true, img=testset.get_img(ID), color='kmeans', spot_size=112, title = title, save=f"/SGCL2ST_Her2_{ID}_true_{true_ARI:.3f}.pdf")


    print("Result of ", ID, " ended! ")
    print("\n\n")


    # if save_path != "":
    #     np.save(save_path + "matched_spot_embeddings_pred.npy", matched_spot_embeddings_pred.T)
    #     np.save(save_path + "matched_spot_expression_pred.npy", matched_spot_expression_pred.T)



Begin Processing Image A1
image query shape:  (346, 256)
expression_gt shape:  (346, 785)
finding matches, using weighted average of top 100 expressions
dot_similarity.shape = spots * reference_spots =  torch.Size([346, 9841])
check for NaN
pred.shape (346, 785)
true.shape (346, 785)
np.max(pred) 171.7166290283203
np.max(true) 188.0
np.min(pred) 0.0
np.min(true) 0.0
The Prediction: prediction
Cell Mean R:  0.2619626484795314
MSE across cells:  20.97224660632903
RMSE across cells:  4.579546550296113
Max correlation across genes: 0.3192695422536538
Genes mean R:  0.03144953471390478
Gene median R:  0.02763726770530971
number of genes with correlation > 0.3:  1
Top 50 Genes Mean Pearson Correlation: 0.17329974823890104
Top 50 Genes Median Pearson Correlation: 0.17401420656783975
Fold A1 Top 10 genes with highest -log10 p-values:
Gene ID: 319, Gene Name: PNMT, R: 0.3192695422536538, p_values: 1.2242040175594486e-09
Gene ID: 134, Gene Name: GNAS, R: 0.2637300508809019, p_values: 6.481617233

In [31]:
print(pred.shape)
print(pred_features.shape)
print(pred_embeddings.shape)
print(pred_reconstruction.shape)

(613, 785)
(613, 785)
(613, 256)
torch.Size([613, 785])


### Visualization of the predicted gene expression

In [32]:
# Visualization of pred
pred_dict = {}
true_dict = {}

for ID in train_ID + test_ID:
# for ID in test_ID:
    print("Begin Processing Image", ID)
    image_query = image_embeddings_dict[ID]
    expression_gt = exp_dict[ID].numpy().T

    method = "weighted_average" # "average" "weighted_average"
    save_path = ""
    if image_query.shape[1] != 256:
        image_query = image_query.T
        print("image query shape: ", image_query.shape)
    if expression_gt.shape[0] != image_query.shape[0]:
        expression_gt = expression_gt.T
        print("expression_gt shape: ", expression_gt.shape)
    if spot_key.shape[1] != 256:
        spot_key = spot_key.T
        print("spot_key shape: ", spot_key.shape)
    if expression_key.shape[0] != spot_key.shape[0]:
        expression_key = expression_key.T
        print("expression_key shape: ", expression_key.shape)

    if method == "simple":
        indices = find_matches(spot_key, image_query, top_k=1)
        matched_spot_embeddings_pred = spot_key[indices[:,0],:]
        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        matched_spot_expression_pred = expression_key[indices[:,0],:]
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "average":
        print("finding matches, using average of top 50 expressions")
        indices = find_matches(spot_key, image_query, top_k=50)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0)

        print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    if method == "weighted_average":
        print("finding matches, using weighted average of top 50 expressions")
        indices = find_matches(spot_key, image_query, top_k=100)
        # print("indices = ", indices)
        matched_spot_embeddings_pred = np.zeros((indices.shape[0], spot_key.shape[1]))
        matched_spot_expression_pred = np.zeros((indices.shape[0], expression_key.shape[1]))
        for i in range(indices.shape[0]):
            a = np.sum((spot_key[indices[i,0],:] - image_query[i,:])**2) #the smallest MSE
            weights = np.exp(-(np.sum((spot_key[indices[i,:],:] - image_query[i,:])**2, axis=1)-a+1))

            # if i == 0:
            #     print("weights: ", weights)
            matched_spot_embeddings_pred[i,:] = np.average(spot_key[indices[i,:],:], axis=0, weights=weights)
            matched_spot_expression_pred[i,:] = np.average(expression_key[indices[i,:],:], axis=0, weights=weights)

        # print("matched spot embeddings pred shape: ", matched_spot_embeddings_pred.shape)
        # print("matched spot expression pred shape: ", matched_spot_expression_pred.shape)

    true = expression_gt
    pred = matched_spot_expression_pred
    adj = adj_dict[ID]

    model.eval()

    # Create the directory if it doesn't exist
    output_dir = './figures/show'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # Additional subdirectories
    subdirs = ['gene', 'clus']
    for subdir in subdirs:
        subdir_path = os.path.join(output_dir, subdir)
        if not os.path.exists(subdir_path):
            os.makedirs(subdir_path)

    with torch.no_grad():
        pred_features = model.spot_encoder(torch.tensor(pred, dtype=torch.float32).cuda(), adj.cuda())
        pred_embeddings = model.spot_projection(torch.tensor(pred_features, dtype=torch.float32).cuda())
        pred_encoding = model.spot_autoencoder.encode(torch.tensor(pred_embeddings, dtype=torch.float32).cuda(), adj.cuda())
        pred_reconstruction, extra = model.spot_autoencoder.decode(torch.tensor(pred_encoding, dtype=torch.float32).cuda())

        pred_features = pred_features.cpu().numpy()
        pred_embeddings = pred_embeddings.cpu().numpy()
        pred_encoding = pred_encoding.cpu().numpy()
        pred_reconstruction = pred_reconstruction.cpu().numpy()

    print("pred.shape",pred.shape)
    print("true.shape",true.shape)
    print("np.max(pred)",np.max(pred))
    print("np.max(true)",np.max(true))
    print("np.min(pred)",np.min(pred))
    print("np.min(true)",np.min(true))

    pred_dict[ID] = pred
    true_dict[ID] = true
    ####### Prediction PCC performance
    # mix = (pred + pred_reconstruction)/2

import pickle
# save pred_dict to file
output_dir = './clip'
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# # set save path pred_dict
# output_file_path = os.path.join(output_dir, 'pred_dict.pkl')
# # use pickle to write pred_dict to file.
# with open(output_file_path, 'wb') as file:
#     pickle.dump(pred_dict, file)
# print(f"Dictionary saved to {output_file_path}")

# # set save path true_dict
# output_file_path = os.path.join(output_dir, 'true_dict.pkl')
# # use pickle to write pred_dict to file.
# with open(output_file_path, 'wb') as file:
#     pickle.dump(true_dict, file)
# print(f"Dictionary saved to {output_file_path}")


Begin Processing Image B2
image query shape:  (270, 256)
expression_gt shape:  (270, 785)
finding matches, using weighted average of top 50 expressions
dot_similarity.shape = spots * reference_spots =  torch.Size([270, 9841])
pred.shape (270, 785)
true.shape (270, 785)
np.max(pred) 175.4484405517578
np.max(true) 222.0
np.min(pred) 0.0
np.min(true) 0.0
Begin Processing Image A6
image query shape:  (360, 256)
expression_gt shape:  (360, 785)
finding matches, using weighted average of top 50 expressions
dot_similarity.shape = spots * reference_spots =  torch.Size([360, 9841])
pred.shape (360, 785)
true.shape (360, 785)
np.max(pred) 218.63262939453125
np.max(true) 280.0
np.min(pred) 0.0
np.min(true) 0.0
Begin Processing Image D5
image query shape:  (306, 256)
expression_gt shape:  (306, 785)
finding matches, using weighted average of top 50 expressions
dot_similarity.shape = spots * reference_spots =  torch.Size([306, 9841])
pred.shape (306, 785)
true.shape (306, 785)
np.max(pred) 223.5318

In [33]:
# pred_dict['A1']

In [34]:
from math import log10
def get_top_genes(pred_dict, true_dict, num_genes=50):
    # Initialize a dictionary to store gene p-values across all images
    gene_p_values = {}

    # Loop through each image and its predicted expression in pred_dict
    for ID, pred in pred_dict.items():
        print("Processing Image", ID)
        true = true_dict[ID]
        # Calculate p-values for each gene in this image
        for gene_idx in range(pred.shape[1]):
            _, p_value = pearsonr(pred[:, gene_idx], true[:, gene_idx])
            if not np.isnan(p_value):  # Only consider valid p-values
                if gene_idx not in gene_p_values:
                    gene_p_values[gene_idx] = []
                gene_p_values[gene_idx].append(p_value)

    # Calculate the average -log10(p-value) for each gene across all images
    avg_log_p_values = {}
    for gene_idx, p_values in gene_p_values.items():
        avg_p_value = np.mean(p_values)
        avg_log_p_value = -log10(avg_p_value)
        avg_log_p_values[gene_idx] = avg_log_p_value

    # Sort genes by the average -log10(p-value) and get the top 50
    top_genes = sorted(avg_log_p_values, key=avg_log_p_values.get, reverse=True)[:num_genes]

    # Return the top genes and their average -log10(p-value)
    return [(gene_idx, avg_log_p_values[gene_idx]) for gene_idx in top_genes]

# Usage:
# Assume testset has an attribute 'gene_set' which is a list of gene names
# true_dict is a dictionary with the same keys as pred_dict and contains the true gene expressions

# Get top 50 genes based on average -log10(p-value) across all images
top_genes_info = get_top_genes(pred_dict, true_dict, num_genes=50)

# Print the names and p-values of the top genes
for gene_idx, log_p_value in top_genes_info:
    gene_name = testset.gene_set[gene_idx]
    print(f"Gene Name: {gene_name}, Average -log10(p-value): {log_p_value}")

# top_genes_info = top_genes_info[:10]
# top_genes_list = [testset.gene_set[gene_idx] for gene_idx, _ in top_genes_info]

Processing Image B2
Processing Image A6
Processing Image D5
Processing Image C2
Processing Image F2
Processing Image C4
Processing Image D6
Processing Image C3
Processing Image C5
Processing Image D3
Processing Image B6
Processing Image H3
Processing Image E3
Processing Image E2
Processing Image G3
Processing Image D4
Processing Image D2
Processing Image B4
Processing Image A5
Processing Image A4
Processing Image H2
Processing Image G1
Processing Image F3
Processing Image C6
Processing Image A2
Processing Image A3
Processing Image B5
Processing Image A1
Processing Image B1
Processing Image C1
Processing Image D1
Processing Image E1
Processing Image F1
Processing Image G2
Processing Image H1
Gene Name: BSG, Average -log10(p-value): 1.5616637478698256
Gene Name: CLDN3, Average -log10(p-value): 1.555038102413818
Gene Name: GNAS, Average -log10(p-value): 1.5125540991637683
Gene Name: IGKC, Average -log10(p-value): 1.4716854625801628
Gene Name: IGHG3, Average -log10(p-value): 1.439351613437

In [35]:
top_n_genes_info = top_genes_info[:12]
top_n_genes_list = [testset.gene_set[gene_idx] for gene_idx, _ in top_n_genes_info]
# top_n_genes_list = ['FN1', 'FASN', 'HLA-DRA', 'CLDN4', 'COL3A1', 'C3', 'GNAS', 'LUM', 'CD74', 'HLA-B', 'MYL12B', 'CCT4']
# top_n_genes_list = ['FN1', 'FASN', 'HLA-DRA', 'CLDN4', 'GNAS', 'MYL12B']
# index_to_name = {index: name for index, name in enumerate(testset.gene_set)}
# top_n_genes_info = [(gene_idx, log_p_value) for gene_idx, log_p_value in top_genes_info if index_to_name[gene_idx] in top_n_genes_list]

In [36]:
print("top_n_genes_list:", top_n_genes_list)

# 计算并保存前6个基因的信息
def calculate_and_save_top_genes(pred_dict, true_dict, testset):
    # # 从 get_top_genes 函数获取基因信息
    # top_genes_info = get_top_genes(pred_dict, true_dict, num_genes=num_genes)
    # top_6_genes_info = top_genes_info[:6]  # 取前6个基因的信息

    # 遍历所有图像ID，计算相关系数和p值
    gene_correlations = {gene_idx: [] for gene_idx, _ in top_n_genes_info}
    for ID in test_ID:
        pred = pred_dict[ID]
        true = true_dict[ID]
        for gene_idx, _ in top_n_genes_info:
            if not np.isnan(pred).any() and not np.isnan(true).any():
                R, p_value = pearsonr(pred[:, gene_idx], true[:, gene_idx])
                gene_correlations[gene_idx].append((R, p_value, ID))

    # # 确保保存图像的目录存在
    # output_dir = './figures/gene'
    # if not os.path.exists(output_dir):
    #     os.makedirs(output_dir)

    # 为每个基因选择最佳的图像ID并保存图像
    for gene_idx, _ in top_n_genes_info:
        print(gene_idx)
        # 找到相关系数最高的记录
        best_record = max(gene_correlations[gene_idx], key=lambda x: x[0])
        R, p_value, ID = best_record
        gene_name = testset.gene_set[gene_idx]
        print(f"Gene: {gene_name}, Best R: {R}, p-value: {p_value}, Image ID: {ID}")

        pred_gene = sc.AnnData(pred_dict[ID])
        pred_gene.obsm['spatial'] = center_dict[ID]
        true_gene = sc.AnnData(true_dict[ID])
        true_gene.obsm['spatial'] = center_dict[ID]
        if data == "her2st":
            pred_gene.var_names = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True))
            true_gene.var_names = list(np.load('data/her_hvg_cut_1000.npy',allow_pickle=True))
        elif data == "cscc":
            pred_gene.var_names = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True))
            true_gene.var_names = list(np.load('data/skin_hvg_cut_1000.npy',allow_pickle=True))
        title = f"ID {ID} Gene: {gene_name} R = {R:.3f}"
        file_path = f"/gene/{ID}_{gene_name}_{R:.3f}.pdf"
        # sc.pl.spatial(pred_gene, img=testset.get_img(ID), color=gene_name, spot_size=112, title=title, color_map='magma', save=file_path)

        title = f"ID {ID} Gene: {gene_name} Observed Gene Expression"
        file_path = f"/gene/{ID}_{gene_name}_Observed Gene Expression.pdf"
        # sc.pl.spatial(pred_gene, img=testset.get_img(ID), color=gene_name, spot_size=112, title=title, color_map='magma', save=file_path)

calculate_and_save_top_genes(pred_dict, true_dict, testset)


top_n_genes_list: ['BSG', 'CLDN3', 'GNAS', 'IGKC', 'IGHG3', 'FASN', 'HSPB1', 'PLK2', 'NDUFS1', 'ITGB6', 'MGP', 'IGLC3']
403
Gene: BSG, Best R: 0.4403072695090306, p-value: 2.026585464170368e-15, Image ID: B1


ValueError: Value passed for key 'spatial' is of incorrect shape. Values of obsm must match dimensions (0,) of parent. Value had shape (1,) while it should have had (295,).