In [1]:
import matplotlib.pyplot as plt
from PIL import Image
import os
import math
import sys
sys.path.insert(0, '../Utils/')
from Utils import create_dir_if_not_exists

In [2]:
def put_multiple_plot_in_one_figure(in_path,all_img_file_list, selected_group, out_path):

    selected_files = [x for x in all_img_file_list if selected_group in x]
    file_list = [in_path + "/" + x for x in selected_files]
    site_list = [x.split('_')[-1].split('.')[0] for x in file_list]
    id_list = ['OPX_' + x.split('_')[1].split('.')[0] for x in selected_files]
    zipped_lists = list(zip(file_list, site_list, id_list))
    sorted_zipped_lists = sorted(zipped_lists, key=lambda x: x[1]) #sort by site
    
    # Load images
    images = [Image.open(file) for file, site, pt in sorted_zipped_lists]
    sites  = [site for file, site, pt in sorted_zipped_lists]
    ids  = [pt for file, site, pt in sorted_zipped_lists]

    
    # Determine the number of rows needed
    num_images = len(images)
    num_cols = 4
    num_rows = math.ceil(num_images / num_cols)
    
    # Create a figure with subplots
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows))
    
    # Flatten axes array for easy iteration
    axes = axes.flatten()
    
    # Display each image in a subplot
    for ax, img, site, pt in zip(axes, images, sites, ids):
        ax.imshow(img)
        ax.set_title(site + '(' + pt + ')')
        ax.axis('off')  # Hide axes
    
    # Hide any unused subplots
    for ax in axes[num_images:]:
        ax.axis('off')
        
    # Show the figure
    fig.suptitle(selected_group, fontsize=16)
    #plt.show()
    plt.savefig(os.path.join(out_path, selected_group + '.png'), dpi=500,bbox_inches='tight')
    plt.close()

In [3]:
####################################
######      USERINPUT       ########
####################################
SELECTED_LABEL = ["AR","MMR (MSH2, MSH6, PMS2, MLH1, MSH3, MLH3, EPCAM)2","PTEN","RB1","TP53","TMB_HIGHorINTERMEDITATE","MSI_POS"]
SELECTED_FEATURE = [str(i) for i in range(0,2048)] + ['TUMOR_PIXEL_PERC']
TUMOR_FRAC_THRES = 0.9
TRAIN_SAMPLE_SIZE = "ALLTUMORTILES"
TRAIN_OVERLAP = 100
TEST_OVERLAP = 0
SELECTED_FOLD = 0
CLUSTER_ALG = 'KMEAN'
N_CLUSTERS = 4
CLUSTER_DIST = 'L2'
feature_extraction_method = 'retccl'
save_name = "_NCLUSTER_" + str(N_CLUSTERS) +  "_DISTMETRIC_" + CLUSTER_DIST

##################
###### DIR  ######
##################
proj_dir = '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/'
data_dir = proj_dir + 'intermediate_data/model_ready_data/feature_' + feature_extraction_method + '/MAXSS'+ str(TRAIN_SAMPLE_SIZE)  + '_TrainOL' + str(TRAIN_OVERLAP) +  '_TestOL' + str(TEST_OVERLAP) + '_TFT' + str(TUMOR_FRAC_THRES) + "/split_fold" + str(SELECTED_FOLD) + "/"
image_paths =   os.path.join(data_dir, "spatial_model_input", save_name, "heatmaps")

################################################
#Create output dir
################################################
outdir =   os.path.join(image_paths, "Group_Plot")
create_dir_if_not_exists(outdir)

Directory '/fh/fast/etzioni_r/Lucas/mh_proj/mutation_pred/intermediate_data/model_ready_data/feature_retccl/MAXSSALLTUMORTILES_TrainOL100_TestOL0_TFT0.9/split_fold0/spatial_model_input/_NCLUSTER_4_DISTMETRIC_L2/heatmaps/Group_Plot' created.


In [4]:
#Get all image file 
all_image_files = os.listdir(image_paths)

#Label Postive Group plot
selected_groups = ["AR1","MMR (MSH2, MSH6, PMS2, MLH1, MSH3, MLH3, EPCAM)21","PTEN1","RB11","TP531","TMB_HIGHorINTERMEDITATE1","MSI_POS1"] #Postive groups

for group in selected_groups:
    put_multiple_plot_in_one_figure(image_paths,all_image_files,group, outdir)

ValueError: Number of rows must be a positive integer, not 0

In [None]:
all_site_names = list(set([x.split('_')[-1].split('.png')[0] for x in all_image_files]))
prostate_files = [x for x in all_image_files if 'Prostate' in x]
na_files = [x for x in all_image_files if 'NA' in x]
st_files = [x for x in all_image_files if 'Soft Tissue' in x]
ln_files = [x for x in all_image_files if 'Lymph Node' in x]
lung_files = [x for x in all_image_files if 'Lung' in x]
liver_files = [x for x in all_image_files if 'Liver' in x]
rectum_files = [x for x in all_image_files if 'Rectum' in x]
bone_files = [x for x in all_image_files if 'Bone' in x]
brain_files = [x for x in all_image_files if 'Brain' in x]
print("Prostate",len(prostate_files))
print("NA",len(na_files))
print("Soft Tissue",len(st_files))
print("Lymph Node",len(ln_files))
print("Lung",len(lung_files))
print("Liver",len(liver_files))
print("Rectum",len(rectum_files))
print("Bone",len(bone_files))
print("Brain",len(brain_files))