In [1]:
import os
import sys,argparse
import pdb
import matplotlib.pyplot as plt
import numpy as np
import time
import tensorflow as tf
import pandas as pd
from sklearn.metrics import homogeneity_completeness_v_measure
from utils.plots import *
from utils.metrics import *
import h5py
from disca_dataset.DISCA_visualization import *
import pickle, mrcfile
import scipy.ndimage as SN
from PIL import Image
from collections import Counter
from disca.DISCA_gmmu_cavi_llh_scanning_new import *
from GMMU.gmmu_cavi_stable_new import CAVI_GMMU as GMM
from config import *
from tqdm import *
import cv2
import warnings
# Suppress all warnings to keep cell output readable.
# NOTE(review): this also hides potentially useful warnings.
warnings.filterwarnings("ignore")

# Restrict CUDA to physical GPUs 4 and 5, numbered in PCI bus order.
# NOTE(review): these are read when TensorFlow initializes its GPU devices.
# Setting them after `import tensorflow` only works if no GPU op or device
# query has run yet. The safest place is before the TF import.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"

# Seed NumPy's global RNG and TensorFlow's global seed so that re-runs
# (feature extraction and GMMU fitting) are reproducible.
# NOTE(review): confirm no other RNG (e.g. the stdlib `random` module) is used
# by the project helpers.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Palette of 16 hex colors. A NumPy array supports fancy indexing,
# e.g. color[label_array].
color = np.array([
    '#6A539D', '#E6D7B2', '#99CCCC', '#FFCCCC', '#DB7093', '#D8BFD8', '#6495ED',
    '#1E90FF', '#7FFFAA', '#FFFF00', '#FFA07A', '#FF1493', '#B0C4DE', '#00CED1',
    '#FFDAB9', '#DA70D6',
])

In [2]:
# Training config (YAML).
# Set the DISCA_CONFIG_YAML environment variable to point at a different file
# instead of editing this machine-specific default path.
DEFAULT_CONFIG_YAML = os.environ.get(
    'DISCA_CONFIG_YAML',
    '/local/scratch/v_yijian_bai/disca/deepgmmu/config/train.yaml',
)

parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config_yaml', type=str, default=DEFAULT_CONFIG_YAML, help='YAML config file')
# In a notebook, parse an empty argument list so the kernel's own command line
# is ignored. In a .py script use parser.parse_args() instead.
config_parser = parser.parse_args(args=[])
args = jupyter_parse_args_yaml(config_parser)

In [3]:
# Output locations derived from the YAML config.
label_path = os.path.join(args.saving_path, 'results')
model_path = os.path.join(args.saving_path, 'models')
label_names = ['labels_' + args.algorithm_name]
figures_path = os.path.join(args.saving_path, 'figures', label_names[0])
# Trained classification model files are matched on this '_'-separated prefix:
# 'classificationmodel' followed by the parts of the algorithm name.
algorithms = ['classificationmodel'] + args.algorithm_name.split('_')

# Per-particle metadata records.
# NOTE(review): pickle_load presumably unpickles this file. Unpickling can
# execute arbitrary code, so only load trusted info.pickle files.
infos = pickle_load(os.path.join(args.data_path, 'info.pickle'))

# Reference tomogram, standardized to zero mean and unit variance.
v = read_mrc_numpy_vol(os.path.join(args.data_path, 'emd_4603.map'))
v = (v - np.mean(v)) / np.std(v)

vs = []  # NOTE(review): not used in the cells shown; kept in case later cells rely on it.

# Cubic region (BOX_SIZE voxels per side) copied around each particle center
# when rendering cluster volumes. `s` is its half-width.
BOX_SIZE = 32
s = BOX_SIZE // 2

In [4]:
# Trained classification models whose '_'-separated name prefix equals
# `algorithms`. Names are stored without the '.h5' extension, which the
# loading cells append.
# The listing is sorted because os.listdir returns entries in arbitrary order,
# which would make figure generation order non-reproducible.
model_names = sorted(
    os.path.splitext(fname)[0]
    for fname in os.listdir(model_path)
    if fname.endswith('.h5') and fname.split('_')[:len(algorithms)] == algorithms
)
model_names

['classificationmodel_gmmu_cavi_llh_hist_M_20_lr_0.01_reg_1e-06']

In [5]:
# Extracted particle subvolumes (only the 'dataset_1' key is read).
# The context manager closes the HDF5 file even if the read raises.
with h5py.File(args.filtered_data_path, 'r') as h5f:
    x_train = h5f['dataset_1'][:]
print(x_train.shape)

(16265, 24, 24, 24, 1)


In [6]:
# Particle metadata as an array.
# In this notebook, field 0 is the source map file name and field 2 is the
# particle center used when rendering clusters.
infonp = np.array(infos)
# Later cells mask these records with per-particle labels predicted from
# x_train, so the two must be aligned 1:1.
assert infonp.shape[0] == x_train.shape[0], (infonp.shape, x_train.shape)
print(set(infonp[:, 0]), infonp.shape)

{'emd_4604.map', 'emd_4603.map'} (16265, 3)


In [7]:
# Visualize clusters from the classification network.
# For each cluster i, every particle from emd_4603.map assigned to i is copied
# into an otherwise-zero copy of the normalized volume `v`. A montage of every
# 15th slice along the last axis is then saved as a PNG.
info_arr = np.array(infos)  # hoisted: identical for every model and cluster
for model_name in model_names:
    suffix = '_'.join(model_name.split('_')[1:])
    classmodelpath = os.path.join(model_path, model_name) + '.h5'
    yopopath = os.path.join(model_path, 'deltamodel_' + suffix) + '.h5'
    figure_path = os.path.join(figures_path, suffix)
    os.makedirs(figure_path, exist_ok=True)

    yopo = tf.keras.models.load_model(yopopath, custom_objects={'CosineSimilarity': CosineSimilarity})
    classmodel = tf.keras.models.load_model(classmodelpath, custom_objects={'CosineSimilarity': CosineSimilarity,
                                                                         'SNN': SNN,
                                                                         'NSNN': NSNN})
    # Both models are called with three copies of the same batch.
    # Only their first output is used.
    features = yopo.predict([x_train, x_train, x_train])[0]
    labels_soft = classmodel.predict([features, features, features])[0]
    labels = np.argmax(labels_soft, axis=1)  # hard label per particle

    for i in tqdm(range(np.max(labels) + 1)):
        locs = info_arr[labels == i]
        v_i = np.zeros_like(v)
        for j in locs:
            if j[0] == 'emd_4603.map':  # emd_4603_deconv_corrected.mrc / emd_4603.map
                # Clamp the start at 0: a negative start would wrap around and
                # silently give an empty slice for particles near the edge.
                # int() also avoids unsigned-integer wraparound.
                region = tuple(slice(max(int(c) - s, 0), int(c) + s) for c in j[2][:3])
                v_i[region] = v[region]
        save_png(cub_img(v_i[:, :, ::15])['im'], os.path.join(figure_path, 'NN' + str(i) + model_name + '.png'))



100%|██████████| 12/12 [00:43<00:00,  3.61s/it]


In [8]:
# Visualize clusters from a GMMU fitted directly on the embedding features.
# The classification network is not needed here.
# For each cluster i, every particle from emd_4603.map assigned to i is copied
# into an otherwise-zero copy of the normalized volume `v`. A montage of every
# 15th slice along the last axis is then saved as a PNG.
info_arr = np.array(infos)  # hoisted: identical for every model and cluster
for model_name in model_names:
    suffix = '_'.join(model_name.split('_')[1:])
    yopopath = os.path.join(model_path, 'deltamodel_' + suffix) + '.h5'
    figure_path = os.path.join(figures_path, suffix)
    os.makedirs(figure_path, exist_ok=True)

    yopo = tf.keras.models.load_model(yopopath, custom_objects={'CosineSimilarity': CosineSimilarity})
    features = yopo.predict([x_train, x_train, x_train])[0]
    # To restrict the model search, replace args.candidateKs with a fixed list of K values.
    labels_temp_proba, labels_temp, K, same_K, features_pca, gmm = \
        statistical_fitting_tf_split_merge(features=np.squeeze(features),
                                           labels=None, candidateKs=args.candidateKs,
                                           K=None, reg_covar=args.reg_covar, it=0,
                                           u_filter_rate=args.u_filter_rate, alpha=args.alpha)
    labels_soft = labels_temp_proba
    labels = labels_temp
    # NOTE(review): the recorded run reported K=13 but iterated over 14 label
    # values. The extra label may be the uniform (outlier) component; confirm.

    for i in tqdm(range(np.max(labels) + 1)):
        locs = info_arr[labels == i]
        v_i = np.zeros_like(v)
        for j in locs:
            if j[0] == 'emd_4603.map':  # emd_4603_deconv_corrected.mrc / emd_4603.map
                # Clamp the start at 0: a negative start would wrap around and
                # silently give an empty slice for particles near the edge.
                # int() also avoids unsigned-integer wraparound.
                region = tuple(slice(max(int(c) - s, 0), int(c) + s) for c in j[2][:3])
                v_i[region] = v[region]
        save_png(cub_img(v_i[:, :, ::15])['im'], os.path.join(figure_path, 'GMMU' + str(i) + model_name + '.png'))



Estimated K: 13


100%|██████████| 14/14 [00:53<00:00,  3.81s/it]
