In [None]:
from warnings import filterwarnings 
filterwarnings('ignore')

from multiprocessing import Process, Pool
import os,time,random 
import cv2
import joblib
import skimage 
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import skimage.io as io
import tifffile as tiff
import scipy.ndimage as nd
import scipy.ndimage as ndi
from scipy.io import loadmat
import matplotlib.pyplot as plt
from skimage.measure import label
import skimage.transform as trans
from skimage import measure,color
from PIL.Image import fromarray as show
from sklearn.model_selection import StratifiedKFold
from skimage.morphology import remove_small_objects

import torch 
import scipy.sparse as sp
from torch_geometric.data import Data
from sklearn.neighbors import kneighbors_graph

# Root folder of the dataset; every other path in this notebook is derived from it.
data_root = '../data/Brain_Data/'
# Per-image .mat annotation files (read with loadmat; the "Boundaries" and "cellTypes" entries are used).
raw_celltype_dir = os.path.join(data_root,'BRAIN_IMC_CellType')
# Raw image stacks, one .tif per image (read with skimage.io.imread).
raw_img_dir = os.path.join(data_root,'BRAIN_IMC_MaskTif')
# Folder that receives the instance-labelled cell mask PNGs written by the mask-generation cell.
cell_mask_dir = os.path.join(data_root,'BRAIN_IMC_CELL_MASK')

In [None]:
def display(cv2img):
    """Render an OpenCV image inline with matplotlib.

    The input is converted from BGR to RGB channel order before it is handed
    to PIL / matplotlib. A new figure is opened for every call.
    """
    rgb_pixels = cv2.cvtColor(cv2img, cv2.COLOR_BGR2RGB)
    rgb_image = Image.fromarray(rgb_pixels)
    plt.figure()
    plt.imshow(rgb_image)
    plt.show()
def gen_x_y(nlist,imgshape):
    """Convert 1-based linear pixel indices into an ``(N, 2)`` int32 point array.

    Every entry ``v = nlist[k][0]`` is mapped to the pair
    ``[(v - 1) // imgshape[0], (v - 1) % imgshape[0]]``.
    NOTE(review): this looks like decoding MATLAB-style column-major linear
    indices into (x, y) points for OpenCV -- confirm against the .mat producer.

    Parameters
    ----------
    nlist : sequence of length-1 sequences (e.g. an ``(N, 1)`` array)
        Linear indices, one per boundary point.
    imgshape : tuple
        Only ``imgshape[0]`` is used, as the divisor/modulus.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(N, 2)`` and dtype ``int32`` (the point type that
        ``cv2.fillPoly`` accepts). Empty input yields shape ``(0, 2)``.
    """
    # Going through int64 first keeps the "- 1" safe for unsigned inputs.
    linear_idx = np.array([entry[0] for entry in nlist], dtype=np.int64).reshape(-1) - 1
    first_coord = linear_idx // imgshape[0]
    second_coord = linear_idx % imgshape[0]
    return np.stack([first_coord, second_coord], axis=1).astype(np.int32)

In [None]:


# Build one instance-labelled mask per annotation file: pixel value k marks
# the k-th cell outline (1-based), 0 is background.
os.makedirs(cell_mask_dir, exist_ok=True)  # the save below fails if the folder is missing

for mat_name in os.listdir(raw_celltype_dir):
    if not mat_name.lower().endswith('.mat'):
        continue  # skip stray files that loadmat cannot read
    x = loadmat(os.path.join(raw_celltype_dir,mat_name))
    p = skimage.io.imread(os.path.join(raw_img_dir,mat_name[:-4]+'.tif'))
    # The mask uses axes 1 and 2 of the image stack as its height and width.
    mask_shape = (p.shape[1],p.shape[2])
    all_mask = np.zeros(mask_shape, dtype=np.uint32)
    for cell_i in range(x["Boundaries"].shape[1]):
        boundaries_x_y = gen_x_y(x["Boundaries"][0,cell_i],mask_shape)
        if boundaries_x_y.size == 0:
            continue  # nothing to draw; the label number stays reserved for this cell
        # A single-channel scratch image is enough to rasterise one polygon.
        single_draft = np.zeros(mask_shape, dtype=np.uint8)
        # fillPoly only accepts int32 point arrays.
        cv2.fillPoly(single_draft,[boundaries_x_y.astype(np.int32)],color = 255)
        # Later cells overwrite earlier ones where outlines overlap.
        all_mask[single_draft>0]=cell_i+1 # instance_label_for_cell_mask

    # NOTE(review): PNG presumably stores at most 16 bits per pixel, so labels
    # above 65535 may not round-trip -- confirm if images can hold that many cells.
    Image.fromarray(all_mask).save(os.path.join(cell_mask_dir,mat_name[:-4]+'.png'))


In [None]:
# show example: binarise one saved mask (any cell -> white) and plot it
example_mask_path = os.path.join(cell_mask_dir,'BrM_001C6.png')
x = cv2.imread(example_mask_path,cv2.IMREAD_UNCHANGED)  # keep the stored bit depth
if x is None:
    # cv2.imread signals a missing/unreadable file by returning None.
    raise FileNotFoundError('Could not read example mask: ' + example_mask_path)
x = np.where(x>0,255,0).astype(np.uint8)
display(cv2.merge([x,x,x]))

In [None]:
def measure_markers(raw_img,raw_mask,measure_label=False):
    """Compute per-cell mean marker expression and shape descriptors.

    Parameters
    ----------
    raw_img : array
        Image stack indexed as ``raw_img[:, mask]``: the first axis is the
        marker/channel axis and the remaining axes must match ``raw_mask``.
    raw_mask : array
        Cell mask. With ``measure_label=False`` it is used directly as the
        label image (0 = background, k > 0 = cell k).
    measure_label : bool, optional
        When True, connected components of ``raw_mask`` are labelled first
        (``connectivity=2``) and those labels are used instead.

    Returns
    -------
    labels_prop : dict
        Keys ``'index'``, ``'pos'``, ``'area'``, ``'perimeter'``,
        ``'minor_axis'`` and ``'major_axis'``; each is a list with one entry
        per non-background label, in ascending label order.
    mean_exp_matrix : numpy.matrix
        One row per label (same order), one column per channel of ``raw_img``.
    """
    if measure_label:
        labels=measure.label(raw_mask,connectivity=2)
    else:
        labels=raw_mask

    # regionprops only returns labels that actually occur in the image, so the
    # regions are looked up by label value rather than by list position; this
    # stays correct when some label numbers are missing from the mask.
    props_by_label = {region.label: region for region in measure.regionprops(labels)}

    # Drop background explicitly instead of assuming it is the first unique value.
    labels_index=[int(value) for value in np.unique(labels) if value != 0]

    mean_exp_list=[]
    cell_pos = []
    cell_area = []
    cell_perimeter = []
    cell_minor_axis = []
    cell_major_axis = []
    for i in labels_index:
        region = props_by_label[i]
        cell_pixels = labels==i
        n_pixel = int(np.count_nonzero(cell_pixels))
        # Mean of every channel over the pixels of this cell.
        avg=np.sum(raw_img[:,cell_pixels],axis=1)/n_pixel
        mean_exp_list.append(avg)
        cell_pos.append(region.centroid)
        cell_area.append(region.area)
        cell_perimeter.append(region.perimeter)
        cell_minor_axis.append(region.minor_axis_length)
        cell_major_axis.append(region.major_axis_length)

    # np.matrix is kept so the return type stays what existing callers receive.
    mean_exp_matrix=np.matrix(mean_exp_list)

    labels_prop = {'index':labels_index, 'pos':cell_pos, 'area':cell_area, 'perimeter':cell_perimeter, 'minor_axis':cell_minor_axis,
                  'major_axis':cell_major_axis}
    return labels_prop,mean_exp_matrix

In [None]:
# Output folder for the per-image graph files. The path component must not
# start with '/': os.path.join discards everything before an absolute component.
target_dir = os.path.join(data_root,'gnn_data')
os.makedirs(target_dir,exist_ok=True)  # torch.save does not create missing folders

# Clinical table, indexed by image ID so rows can be looked up per image.
df = pd.read_excel(os.path.join(data_root,'Clinical Data GBM BrM.xlsx')).set_index('Image ID')

def gen_edge_index_weight(pos_t, n_neighbors=5, alpha=-0.05):
    """Build a k-nearest-neighbour graph over cell positions.

    Parameters
    ----------
    pos_t : torch.Tensor
        Node coordinates, one row per node. The edge weight uses the first
        two columns only.
    n_neighbors : int, optional
        Neighbours per node (default 5). It is reduced to ``n_nodes - 1`` when
        the graph is too small to provide that many distinct neighbours.
    alpha : float, optional
        Decay factor of the edge weight (default -0.05).

    Returns
    -------
    edge_index_A : torch.LongTensor of shape ``(2, E)``
        Row 0 / row 1 hold the row / column indices of the non-zero entries
        of the kNN connectivity matrix.
    ed_weight : torch.Tensor of shape ``(E,)``
        ``exp(alpha * squared_distance)`` for every edge.

    With fewer than two nodes no edge can exist, so empty tensors are returned.
    """
    n_nodes = pos_t.shape[0]
    k = min(n_neighbors, n_nodes - 1)
    if k < 1:
        weight_dtype = pos_t.dtype if pos_t.is_floating_point() else torch.get_default_dtype()
        return torch.empty((2, 0), dtype=torch.long), torch.empty((0,), dtype=weight_dtype)

    A = kneighbors_graph(pos_t.numpy(),n_neighbors=k,mode='connectivity',include_self=False)
    edge_index_temp = sp.coo_matrix(A)
    indices = np.vstack((edge_index_temp.row,edge_index_temp.col))
    edge_index_A = torch.LongTensor(indices)

    # Squared distance between the two end points of every edge.
    dist = pos_t[edge_index_A[0]]-pos_t[edge_index_A[1]]
    dist = dist * dist
    dists2 = dist[:,0]+dist[:,1]
    ed_weight = torch.exp(alpha*dists2)

    return edge_index_A,ed_weight

# Turn every eligible image into a graph file: nodes are cells, node features
# are mean marker expressions, edges come from the kNN graph over centroids.
labels_dict = {}
os_column = 'OS event (days post surgery)'

for imgname in tqdm(os.listdir(raw_img_dir)):
    image_id = imgname[:-4]

    # Only 'Glioma' images that have a row in the clinical table are usable.
    if ('Glioma' not in imgname) or (image_id not in df.index):
        print(imgname,'skip') # OS NA or Unknown
        continue
    clinical_row = df.loc[image_id]
    os_event = clinical_row[os_column]
    # NaN or any other non-text value cannot describe a death event.
    if (not isinstance(os_event,str)) or ('death' not in os_event):
        print(imgname,'skip') # OS NA or Unknown
        continue

    # The number of days is the text after the last '(' with the final character dropped.
    os_time = int(os_event[:-1].split('(')[-1])
    if os_time>=365*3: # as Long Term Survive
        y_label = 1
    elif os_time<=365*2: # as Short Term Survive
        y_label = 0
    else:
        continue  # between two and three years: left out of both classes

    mask_path = os.path.join(cell_mask_dir,image_id+'.png')
    raw_mask = cv2.imread(mask_path,-1)
    if raw_mask is None:
        # cv2.imread returns None instead of raising when the file cannot be read.
        raise FileNotFoundError('Could not read cell mask: ' + mask_path)
    raw_img = skimage.io.imread(os.path.join(raw_img_dir,image_id+'.tif'))
    raw_celltype = loadmat(os.path.join(raw_celltype_dir,image_id+'.mat'))

    labels_prop,mean_exp_matrix = measure_markers(raw_img,raw_mask)

    # Mask label k was written for the k-th annotated cell, so the cell type of
    # label k is looked up in row k-1. Indexing by label keeps cell types aligned
    # with the feature rows even when some labels are absent from the mask.
    # NOTE(review): assumes 'cellTypes' rows follow the same order as 'Boundaries' -- confirm.
    all_cell_types = [list(t)[0] if len(t)==1 else 'Uknown' for t in raw_celltype['cellTypes'][:,0]]
    cell_type_list = [all_cell_types[int(i)-1] if 0<int(i)<=len(all_cell_types) else 'Uknown'
                      for i in labels_prop['index']]

    x_fea = torch.Tensor(mean_exp_matrix)
    pos_list = [[pos[0],pos[1]] for pos in labels_prop['pos']]
    pos_tensor = torch.Tensor(pos_list)
    area_tensor = torch.Tensor(labels_prop['area'])
    perimeter_tensor = torch.Tensor(labels_prop['perimeter'])
    minor_axis_tensor = torch.Tensor(labels_prop['minor_axis'])
    major_axis_tensor = torch.Tensor(labels_prop['major_axis'])

    edge_index_t,edge_weight_t = gen_edge_index_weight(pos_tensor)

    gnn_data = Data(Image_ID=image_id,Paitent_ID=clinical_row['Patient ID'],x=x_fea, edge_index=edge_index_t, pos=pos_tensor, edge_weight=edge_weight_t.float(), area=area_tensor, cell_type_final=cell_type_list,perimeter=perimeter_tensor,minor_axis=minor_axis_tensor,major_axis=major_axis_tensor,OS=os_time,y=y_label)
    torch.save(gnn_data,os.path.join(target_dir,image_id+'.pkl'))
    labels_dict[image_id]=y_label

In [None]:
# Write the label dictionary and patient-level cross-validation folds.
os.makedirs(os.path.join(data_root,'label_and_fold'),exist_ok=True)

# Load every saved graph once and keep only what the splitting needs.
# Sorting the file names makes the resulting lists independent of listdir order.
# NOTE: torch.load unpickles the file; only use it on files produced by this notebook.
graph_records = []
for gnnname in sorted(os.listdir(target_dir)):
    g = torch.load(os.path.join(target_dir,gnnname))
    graph_records.append((gnnname,g.Paitent_ID,g.y))

# File name -> label, restricted to graphs carrying a binary label.
label_dicts = {}
paitents_to_Y = {}
for gnnname,paitent_id,y_value in graph_records:
    if y_value!=0 and y_value!=1:
        continue
    label_dicts[gnnname]=y_value
    paitents_to_Y[paitent_id]=y_value

joblib.dump(label_dicts,os.path.join(data_root,'label_and_fold','OS_label_dict.pkl'))

# Set iteration order is not stable between kernel sessions; a sorted patient
# list is required for the seeded splitter to give the same folds every run.
paitents = sorted(paitents_to_Y,key=str)
Y = [paitents_to_Y[p] for p in paitents]

fold_num = 5
random_seed = 0
cvkFold = StratifiedKFold(n_splits=fold_num, shuffle = True, random_state=random_seed)
# Dedicated seeded generator so the validation draw is reproducible as well.
val_rng = random.Random(random_seed)

final_img_cv_dict = {}
for fold_idx,(train_idx,test_idx) in enumerate(cvkFold.split(paitents,Y)):
    train_key = 'fold'+str(fold_idx)+'_train'
    val_key = 'fold'+str(fold_idx)+'_val'
    test_key = 'fold'+str(fold_idx)+'_test'
    final_img_cv_dict[train_key]=[]
    final_img_cv_dict[val_key]=[]
    final_img_cv_dict[test_key]=[]

    train_paitents = [paitents[i] for i in train_idx]
    test_paitents = {paitents[i] for i in test_idx}

    # Hold out 1/fold_num of the training patients for validation.
    val_paitents = set(val_rng.sample(train_paitents,len(train_paitents)//fold_num))
    new_train_paitents = set(train_paitents)-val_paitents

    # Images follow their patient, so one patient never spans two splits.
    for gnnname,paitent_id,_ in graph_records:
        if paitent_id in new_train_paitents:
            final_img_cv_dict[train_key].append(gnnname)
        elif paitent_id in val_paitents:
            final_img_cv_dict[val_key].append(gnnname)
        elif paitent_id in test_paitents:
            final_img_cv_dict[test_key].append(gnnname)

joblib.dump(final_img_cv_dict,os.path.join(data_root,'label_and_fold','{}_fold_for_OS.pkl'.format(fold_num)))
