Skip to content

Commit

Permalink
add use_img
Browse files Browse the repository at this point in the history
  • Loading branch information
frickyinn committed May 8, 2022
1 parent e0896ab commit c3875c1
Show file tree
Hide file tree
Showing 4 changed files with 302 additions and 20 deletions.
212 changes: 212 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
#! /usr/bin/env python
# -*- coding:utf-8 -*-
#%%

import torch
import argparse
import random
import numpy as np
import pandas as pd
from src.graph_func import graph_construction
from src.utils_func import mk_dir, adata_preprocess, load_ST_file, res_search_fixed_clus, plot_clustering
from src.training import conST_training

import anndata
from sklearn import metrics
import matplotlib.pyplot as plt
import scanpy as sc
import os
import warnings
warnings.filterwarnings('ignore')

#%%

# ---------------------------------------------------------------------------
# Command-line configuration.
# ---------------------------------------------------------------------------

def _str2bool(v):
    """Parse a textual boolean CLI value.

    argparse's ``type=bool`` is a well-known trap: any non-empty string —
    including ``'False'`` — is truthy, so ``--use_img False`` would still
    enable the flag.  This converter interprets the usual spellings and
    rejects anything else.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {v!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--k', type=int, default=10, help='parameter k in spatial graph')
parser.add_argument('--knn_distanceType', type=str, default='euclidean',
                    help='graph distance type: euclidean/cosine/correlation')
parser.add_argument('--epochs', type=int, default=200, help='Number of epochs to train.')
parser.add_argument('--cell_feat_dim', type=int, default=300, help='Dim of PCA')
parser.add_argument('--feat_hidden1', type=int, default=100, help='Dim of DNN hidden 1-layer.')
parser.add_argument('--feat_hidden2', type=int, default=20, help='Dim of DNN hidden 2-layer.')
parser.add_argument('--gcn_hidden1', type=int, default=32, help='Dim of GCN hidden 1-layer.')
parser.add_argument('--gcn_hidden2', type=int, default=8, help='Dim of GCN hidden 2-layer.')
parser.add_argument('--p_drop', type=float, default=0.2, help='Dropout rate.')
parser.add_argument('--use_img', type=_str2bool, default=False, help='Use histology images.')
parser.add_argument('--img_w', type=float, default=0.1, help='Weight of image features.')
parser.add_argument('--use_pretrained', type=_str2bool, default=False, help='Use pretrained weights.')
parser.add_argument('--using_mask', type=_str2bool, default=False, help='Using mask for multi-dataset.')
parser.add_argument('--feat_w', type=float, default=10, help='Weight of DNN loss.')
parser.add_argument('--gcn_w', type=float, default=0.1, help='Weight of GCN loss.')
parser.add_argument('--dec_kl_w', type=float, default=10, help='Weight of DEC loss.')
parser.add_argument('--gcn_lr', type=float, default=0.01, help='Initial GNN learning rate.')
parser.add_argument('--gcn_decay', type=float, default=0.01, help='Initial decay rate.')
parser.add_argument('--dec_cluster_n', type=int, default=10, help='DEC cluster number.')
parser.add_argument('--dec_interval', type=int, default=20, help='DEC interval number.')
parser.add_argument('--dec_tol', type=float, default=0.00, help='DEC tol.')

parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--beta', type=float, default=100, help='beta value for l2c')
parser.add_argument('--cont_l2l', type=float, default=0.3, help='Weight of local contrastive learning loss.')
parser.add_argument('--cont_l2c', type=float, default=0.1, help='Weight of context contrastive learning loss.')
parser.add_argument('--cont_l2g', type=float, default=0.1, help='Weight of global contrastive learning loss.')

parser.add_argument('--edge_drop_p1', type=float, default=0.1, help='drop rate of adjacent matrix of the first view')
parser.add_argument('--edge_drop_p2', type=float, default=0.1, help='drop rate of adjacent matrix of the second view')
parser.add_argument('--node_drop_p1', type=float, default=0.2, help='drop rate of node features of the first view')
parser.add_argument('--node_drop_p2', type=float, default=0.3, help='drop rate of node features of the second view')

# ______________ Eval clustering Setting ______________
parser.add_argument('--eval_resolution', type=int, default=1, help='Eval cluster number.')
parser.add_argument('--eval_graph_n', type=int, default=20, help='Eval graph kN tol.')

# NOTE: the argument list is hard-coded here (notebook-style run); arguments
# given on the real command line are ignored.  Drop ``args=...`` to use sys.argv.
params = parser.parse_args(args=['--k', '20', '--knn_distanceType', 'euclidean', '--epochs', '200'])

# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Seed NumPy/torch up front; seed_torch() re-seeds everything before the run.
np.random.seed(params.seed)
torch.manual_seed(params.seed)
torch.cuda.manual_seed(params.seed)
# NOTE(review): GPU index 1 is hard-coded — confirm it matches the target machine.
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'
print('Using device: ' + device)
params.device = device

#%%

"""
path = './data/spatialLIBD/151673'
adata_h5 = load_ST_file(path, count_file='151673_filtered_feature_bc_matrix.h5')
adata_X = adata_preprocess(adata_h5, min_cells=5, pca_n_comps=params.cell_feat_dim)
graph_dict = graph_construction(adata_h5.obsm['spatial'], adata_h5.shape[0], params)
np.save('./input/adatax.npy', adata_X)
np.save('./input/graphdict.npy', graph_dict, allow_pickle = True)
"""

#%%

# set seed before every run
def seed_torch(seed):
    """Make a run reproducible by seeding every RNG in play.

    Covers Python's `random`, hash randomization, NumPy, and torch
    (CPU and current CUDA device), and pins cuDNN to deterministic mode.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # deterministic kernels; disable the cuDNN autotuner between runs
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
seed_torch(params.seed)

#%%

# ---------------- data loading & training ----------------
# Section 151673 of the spatialLIBD DLPFC benchmark.
data_name = '151673'
save_root = './output/spatialLIBD/'
data_root = '../spatialLIBD'

params.save_path = mk_dir(f'{save_root}/{data_name}/conST')

# Raw 10x Visium counts + spatial coordinates.
path = '../spatialLIBD/151673'
adata_h5 = load_ST_file(path, count_file='151673_filtered_feature_bc_matrix.h5')

# Cached inputs produced by the (commented-out) preprocessing step:
# PCA-reduced expression features and the spatial adjacency graph.
adata_X = np.load('./input/adatax.npy')
graph_dict = np.load('./input/graphdict.npy', allow_pickle = True).item()
params.cell_num = adata_h5.shape[0]
# Ground-truth layer annotations; pd.Categorical codes unlabeled spots as -1.
df_meta = pd.read_csv(f'{data_root}/{data_name}/metadata.tsv', sep='\t')
labels = pd.Categorical(df_meta['layer_guess']).codes

log_dir = './log/_test/'

# Target cluster count — matches the layer annotation granularity of this section.
n_clusters = 7
if params.use_img:
    # Histology features extracted by MAE, rescaled to the mean/std of the
    # expression features so neither modality dominates the loss.
    img_transformed = np.load('./MAE-pytorch/extracted_feature.npy')
    img_transformed = (img_transformed - img_transformed.mean()) / img_transformed.std() * adata_X.std() + adata_X.mean()
    conST_net = conST_training(adata_X, graph_dict, params, n_clusters, labels, log_dir, img_transformed)
else:
    conST_net = conST_training(adata_X, graph_dict, params, n_clusters, labels, log_dir)
if params.use_pretrained:
    conST_net.load_model('conST_151673.pth')
else:
    conST_net.pretraining()
    conST_net.major_training()

# ARI results recorded by the author for the two configurations (reference only):
'''
use_img:
pretrain 0.410
major 0.432
no img:
pretrain 0.499
major 0.439
'''

# Latent embedding learned by conST; persisted for downstream analysis.
conST_embedding = conST_net.get_embedding()

np.save(f'{params.save_path}/conST_result.npy', conST_embedding)
# clustering: wrap the embedding in an AnnData and carry over spatial metadata
adata_conST = anndata.AnnData(conST_embedding)
adata_conST.uns['spatial'] = adata_h5.uns['spatial']
adata_conST.obsm['spatial'] = adata_h5.obsm['spatial']

sc.pp.neighbors(adata_conST, n_neighbors=params.eval_graph_n)

# res_search_fixed_clus presumably tunes the Leiden resolution until it yields
# n_clusters clusters — verify against src/utils_func.py.
eval_resolution = res_search_fixed_clus(adata_conST, n_clusters)
print(eval_resolution)
cluster_key = "conST_leiden"
sc.tl.leiden(adata_conST, key_added=cluster_key, resolution=eval_resolution)

# plotting
savepath = f'{params.save_path}/conST_leiden_plot.jpg'
plot_clustering(adata_conST, cluster_key, savepath = savepath)

# Evaluation: ARI of predicted clusters vs. the manual layer annotation,
# computed only over spots that actually have a ground-truth label.
df_meta['conST'] = adata_conST.obs[cluster_key].tolist()
df_meta.to_csv(f'{params.save_path}/metadata.tsv', sep='\t', index=False)
df_meta = df_meta[~pd.isnull(df_meta['layer_guess'])]
ARI = metrics.adjusted_rand_score(df_meta['layer_guess'], df_meta['conST'])
print('===== Project: {} ARI score: {:.3f}'.format(data_name, ARI))

#%%

# index = np.arange(start=0, stop=adata_X.shape[0]).tolist()
# index = [str(x) for x in index]
#
# def refine(sample_id, pred, dis, shape="hexagon"):
# refined_pred=[]
# pred=pd.DataFrame({"pred": pred}, index=sample_id)
# dis_df=pd.DataFrame(dis, index=sample_id, columns=sample_id)
# if shape=="hexagon":
# num_nbs=6
# elif shape=="square":
# num_nbs=4
# else:
# print("Shape not recongized, shape='hexagon' for Visium data, 'square' for ST data.")
# for i in range(len(sample_id)):
# index=sample_id[i]
# dis_tmp=dis_df.loc[index, :].sort_values(ascending=False)
# nbs=dis_tmp[0:num_nbs+1]
# nbs_pred=pred.loc[nbs.index, "pred"]
# self_pred=pred.loc[index, "pred"]
# v_c=nbs_pred.value_counts()
# if (v_c.loc[self_pred]<num_nbs/2) and (np.max(v_c)>num_nbs/2):
# refined_pred.append(v_c.idxmax())
# else:
# refined_pred.append(self_pred)
# return refined_pred
#
# dis = graph_dict['adj_norm'].to_dense().numpy() + np.eye(graph_dict['adj_norm'].shape[0])
# refine = refine(sample_id = index, pred = adata_conST.obs['leiden'].tolist(), dis=dis)
# adata_conST.obs['refine'] = refine
#
# #%%
#
# cluster_key = 'refine'
# savepath = f'{params.save_path}/conST_leiden_plot_refined.jpg'
# plot_clustering(adata_conST, cluster_key, savepath = savepath)
#
# df_meta = pd.read_csv(f'{data_root}/{data_name}/metadata.tsv', sep='\t')
# df_meta['conST_refine'] = adata_conST.obs['refine'].tolist()
# df_meta.to_csv(f'{params.save_path}/metadata.tsv', sep='\t', index=False)
# df_meta = df_meta[~pd.isnull(df_meta['layer_guess'])]
# ARI = metrics.adjusted_rand_score(df_meta['layer_guess'], df_meta['conST_refine'])
# print('===== Project: {} refined ARI score: {:.3f}'.format(data_name, ARI))

#%% md

#
6 changes: 3 additions & 3 deletions src/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ def cluster(data, k, temp, num_iter, init, cluster_temp):
cuda0 = torch.cuda.is_available()

if cuda0:
mu = init.cuda()
data = data.cuda()
cluster_temp = cluster_temp.cuda()
mu = init.to('cuda:1')
data = data.to('cuda:1')
cluster_temp = cluster_temp.to('cuda:1')
else:
mu = init

Expand Down
39 changes: 29 additions & 10 deletions src/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,19 @@


class conST(nn.Module):
def __init__(self, input_dim, params, n_clusters, dim, use_img):
def __init__(self, input_dim, params, n_clusters, dim, use_img, input_img_dim=768):
super(conST, self).__init__()
self.alpha = 1.0
self.latent_dim = params.gcn_hidden2 + params.feat_hidden2
self.tau = 0.5
self.n_clusters = n_clusters
self.dim = dim
self.params = params
self.use_img = use_img

if self.use_img:
self.latent_dim = params.gcn_hidden2 + params.feat_hidden2 * 2
else:
self.latent_dim = params.gcn_hidden2 + params.feat_hidden2

# feature autoencoder
self.encoder = nn.Sequential()
Expand All @@ -30,11 +34,11 @@ def __init__(self, input_dim, params, n_clusters, dim, use_img):
# img feature autoencoder
if self.use_img:
self.img_encoder = nn.Sequential()
self.img_encoder.add_module('img_encoder_L1', full_block(input_dim, params.feat_hidden1, params.p_drop))
self.img_encoder.add_module('img_encoder_L1', full_block(input_img_dim, params.feat_hidden1, params.p_drop))
self.img_encoder.add_module('img_encoder_L2', full_block(params.feat_hidden1, params.feat_hidden2, params.p_drop))

self.img_decoder = nn.Sequential()
self.img_decoder.add_module('img_decoder_L0', full_block(self.latent_dim, input_dim, params.p_drop))
self.img_decoder.add_module('img_decoder_L0', full_block(self.latent_dim, input_img_dim, params.p_drop))

# GCN layers
if self.use_img:
Expand All @@ -46,19 +50,34 @@ def __init__(self, input_dim, params, n_clusters, dim, use_img):
self.dc = InnerProductDecoder(params.p_drop, act=lambda x: x)

# DEC cluster layer
self.cluster_layer = Parameter(torch.Tensor(params.dec_cluster_n, params.gcn_hidden2 + params.feat_hidden2))
if use_img:
self.cluster_layer = Parameter(torch.Tensor(params.dec_cluster_n,
params.gcn_hidden2 + params.feat_hidden2 * 2))
else:
self.cluster_layer = Parameter(torch.Tensor(params.dec_cluster_n, params.gcn_hidden2 + params.feat_hidden2))
torch.nn.init.xavier_normal_(self.cluster_layer.data)

# projection
self.fc1 = torch.nn.Linear(params.feat_hidden2, params.feat_hidden2 * 2)
self.fc2 = torch.nn.Linear(params.feat_hidden2 * 2, params.feat_hidden2)
if use_img:
self.fc1 = torch.nn.Linear(params.gcn_hidden2 + params.feat_hidden2 * 2, params.feat_hidden2 * 2)
self.fc2 = torch.nn.Linear(params.feat_hidden2 * 2, params.gcn_hidden2 + params.feat_hidden2 * 2)
else:
self.fc1 = torch.nn.Linear(params.feat_hidden2, params.feat_hidden2 * 2)
self.fc2 = torch.nn.Linear(params.feat_hidden2 * 2, params.feat_hidden2)

self.read = AvgReadout()
self.sigm = nn.Sigmoid()
self.cluster = Clusterator(params.feat_hidden2, K=self.n_clusters)
self.disc_c = Discriminator_cluster(params.feat_hidden2, params.feat_hidden2, n_nb=self.dim,
num_clusters=self.n_clusters)
self.disc = Discriminator(params.feat_hidden2)

if use_img:
self.disc_c = Discriminator_cluster(params.gcn_hidden2 + params.feat_hidden2 * 2,
params.gcn_hidden2 + params.feat_hidden2 * 2, n_nb=self.dim,
num_clusters=self.n_clusters)
self.disc = Discriminator(params.gcn_hidden2 + params.feat_hidden2 * 2)
else:
self.disc_c = Discriminator_cluster(params.feat_hidden2, params.feat_hidden2, n_nb=self.dim,
num_clusters=self.n_clusters)
self.disc = Discriminator(params.feat_hidden2)

def encode(self, x, adj, img=None):
feat_x = self.encoder(x)
Expand Down
Loading

0 comments on commit c3875c1

Please sign in to comment.