In [3]:
import os
import sys
import numpy as np
import argparse
import time
        
import SimpleITK as sitk
from scipy.io import loadmat
from scipy.optimize import least_squares, curve_fit, leastsq, root
from scipy.special import huber
from numpy.fft import fftn, ifftn
#from tqdm import trange, tqdm
from functools import partial
from multiprocessing import Pool, Lock, cpu_count, current_process
from multiprocessing.shared_memory import SharedMemory #V3.8+ only

import warnings
from sklearn.mixture import GaussianMixture
from scipy.stats import mode


In [None]:
import ants
def read_img_ants(path):
    """Read an image through ANTsPy and unpack it into plain pieces.

    Parameters
    ----------
    path : path of the image file handed to ``ants.image_read``.

    Returns
    -------
    tuple
        ``(array, origin, spacing, direction)`` where ``array`` is the result
        of the image's ``numpy()`` call and the remaining three entries are
        the image's ``origin``, ``spacing`` and ``direction`` attributes.
    """
    image = ants.image_read(path)
    voxels = image.numpy()
    geometry = (image.origin, image.spacing, image.direction)
    return (voxels, *geometry)

def write_img_ants(path, data, o, s, d):
    """Wrap a numpy array as an ANTs image and save it to ``path``.

    Parameters
    ----------
    path : destination file name passed to the image's ``to_file``.
    data : array handed to ``ants.from_numpy``.
    o, s, d : forwarded to ``ants.from_numpy`` as ``origin``, ``spacing``
        and ``direction`` respectively.
    """
    geometry = {"spacing": s, "origin": o, "direction": d}
    ants.from_numpy(data, **geometry).to_file(path)

# Define a function that computes the mean GM and WM signals (y_gm, y_wm) over a range of
# inversion times and returns the zero-crossing point, i.e. the TI at which the two means cancel.

def get_zero_crossing_point(fa,t1,pd,paths, o1, s1, d1,
                            ti_max=3000, ti_step=5, window=20, null_threshold=30):
    """Locate the inversion time (TI) where the mean GM and WM signals cancel.

    For every TI in ``range(0, ti_max, ti_step)`` the inversion-recovery signal
    ``pd * (1 - (1 - cos(fa)) * exp(-TI / t1))`` is evaluated, its mean inside
    the ``gm`` and ``wm`` masks is recorded, and both curves are plotted. The
    zero-crossing TI is the one minimising ``|mean_gm + mean_wm|``.

    Images are then written for every TI in
    ``range(ti_zero - window, ti_zero + window, ti_step)`` together with a
    binary map of voxels whose ``|signal| < null_threshold``; the union of
    those maps, restricted to ``(wm + gm) != 0``, is written as
    ``combinedZerosSeg.nrrd``. Finally the signal and its masked low-signal
    map at the zero-crossing TI itself are written.

    NOTE: this function reads the names ``gm``, ``wm`` and ``plt`` from the
    enclosing (notebook) namespace; they are not parameters.

    Parameters
    ----------
    fa : flip angle, passed directly to ``np.cos`` (so expected in radians).
    t1 : T1 map; divides TI inside the exponential.
    pd : scaling map multiplied onto the recovery term.
    paths : output directory prefix; files are written to ``paths + '/...'``.
    o1, s1, d1 : forwarded to ``write_img_ants`` as origin, spacing, direction.
    ti_max, ti_step : upper bound (exclusive) and step of the TI sweep.
    window : half-width of the TI range written around the zero crossing.
    null_threshold : absolute-signal cut-off used for the low-signal maps.

    Returns
    -------
    tuple
        ``(ids, z, ti_zero)``: the combined low-signal map within the GM/WM
        region, the low-signal map at the zero-crossing TI weighted by
        ``wm + gm``, and the zero-crossing TI.

    Raises
    ------
    ValueError
        If the mean GM + WM signal is NaN for every TI.
    """
    # Loop-invariant part of the signal model.
    inversion_factor = 1 - np.cos(fa)

    def signal_at(ti):
        # INVERSION RECOVERY SEQUENCE - SIGNAL MODEL
        return pd * (1 - inversion_factor * np.exp(-ti / t1))

    tis = list(range(0, ti_max, ti_step))
    y_gm_mean = []
    y_wm_mean = []
    for ti in tis:
        y = signal_at(ti)
        # Get masked y for gm and wm
        y_gm = y*gm
        y_wm = y*wm
        # Get the mean values for y_gm and y_wm for each TI
        y_gm_mean.append(np.mean(y[y_gm != 0]))
        y_wm_mean.append(np.mean(y[y_wm != 0]))

    # plot the y_gm and y_wm versus tis
    plt.figure(figsize=(10,10))
    plt.plot(tis, y_gm_mean, label='gm')
    plt.plot(tis, y_wm_mean, label='wm')
    plt.legend()
    # Also show the x-axis passing through the origin
    plt.plot([0,ti_max],[0,0], 'k--')
    # show the grid
    plt.grid()
    # show the fine grid
    plt.minorticks_on()
    plt.show()

    # The zero crossing is where the two mean signals cancel each other.
    abs_sum = np.abs(np.array(y_gm_mean) + np.array(y_wm_mean))
    if np.all(np.isnan(abs_sum)):
        raise ValueError('Mean GM/WM signal is NaN for every TI; cannot locate the zero crossing')
    # nanargmin ignores TIs whose mean is NaN (e.g. 0/0 at ti=0 where t1 == 0)
    # and returns the first minimum, like the original where(...)[0][0] lookup.
    index = int(np.nanargmin(abs_sum))
    ti_zero = tis[index]
    print('Zero crossing point is at',index)
    print(index,y_gm_mean[index],y_wm_mean[index],ti_zero)

    # Accumulate the low-signal maps over a window of TIs around the crossing.
    w = np.zeros(y.shape)
    for i in range(ti_zero - window, ti_zero + window, ti_step):
        y = signal_at(i)
        write_img_ants(paths+'/ti_'+str(i)+'.nrrd', y, o1, s1, d1)
        z = 1.0*(np.abs(y) < null_threshold)
        w = w + z
        write_img_ants(paths+'/ti_'+str(i)+'seg.nrrd', z, o1, s1, d1)

    # Union of the per-TI maps, restricted to voxels inside the GM/WM masks.
    w = 1.0*(w != 0)
    ids = 1.0*((wm+gm) != 0) * w
    write_img_ants(paths+'/combinedZerosSeg.nrrd', ids, o1, s1, d1)

    # Outputs at the zero-crossing TI itself. The file names carry ti_zero;
    # the original code reused the stale loop variable of the window loop
    # here, which labelled these files with the wrong TI.
    y = signal_at(ti_zero)
    write_img_ants(paths+'/ti_'+str(ti_zero)+'zerocrossing.nrrd', y, o1, s1, d1)
    z = 1.0*(np.abs(y) < null_threshold) * (wm + gm)
    write_img_ants(paths+'/ti_'+str(ti_zero)+'zerocrossing_seg.nrrd', z, o1, s1, d1)

    return ids, z, ti_zero

# Run the zero-crossing search, writing all outputs under 'edge'.
# Returns the combined low-signal map (ids), the low-signal map at the
# zero-crossing TI (zerocrossingids) and the zero-crossing TI itself (ti).
# NOTE(review): fa, t1, pd, o1, s1, d1 are not defined in the visible cells;
# presumably they come from an earlier cell — confirm on a fresh kernel.
ids,zerocrossingids,ti = get_zero_crossing_point(fa,t1,pd,'edge', o1, s1, d1)


In [4]:
# Save the combined low-signal map so it can be inspected alongside the outputs below.
write_img_ants('edge/Forinput.nrrd', ids, o1, s1, d1)
# Group the T1 values inside `ids` into fixed T1 bins (a GMM fit was tried
# earlier and is no longer active).
# Define a function that saves an nrrd mask for each T1 bin (cluster).
def save_t1ForClsuter(ids,t1, o1, s1, d1, out_dir='edge',
                      bins=(100,700,1300,3500), max_roi_voxels=4000000):
    """Bin the T1 values selected by ``ids`` and write one T1-range mask per bin.

    The T1 values at ``ids != 0`` are labelled with ``np.digitize`` against
    ``bins`` (label 0 is below the first edge, label ``k`` lies in
    ``[bins[k-1], bins[k])``). A label volume is written to
    ``T1ThresholdBasedMAskEDGES.nrrd``. For each label ``0 .. len(bins)-1``
    that occurs, a histogram of its T1 values is shown and, if the label
    volume holds fewer than ``max_roi_voxels`` voxels equal to that label, a
    whole-volume binary mask of all voxels whose T1 lies between the minimum
    and maximum non-zero T1 found in that label's region is written to
    ``T1ThresholdBasedMAsk<label>.nrrd``.

    A GaussianMixture-based clustering was previously tried here (it was left
    commented out); fixed T1 bins are used instead.

    NOTE: this function reads ``plt`` from the enclosing (notebook) namespace.

    Parameters
    ----------
    ids : array; voxels with ``ids != 0`` are analysed.
    t1 : T1 map, indexable with the boolean mask ``ids != 0``.
    o1, s1, d1 : forwarded to ``write_img_ants`` as origin, spacing, direction.
    out_dir : output directory prefix; files are written to ``out_dir + '/...'``.
    bins : monotonically increasing T1 bin edges.
    max_roi_voxels : labels covering this many voxels or more are skipped.

    Returns
    -------
    tuple
        ``(means, stds)``, each of shape ``(len(bins), 1)``. Row ``i`` holds
        the mean / standard deviation of the T1 values with label ``i``; rows
        stay 0 for labels that are absent or for which no mask was written.

    Raises
    ------
    ValueError
        If ``ids`` selects no voxels.
    """
    zt1 = t1[ids != 0]
    if zt1.size == 0:
        raise ValueError('ids selects no voxels; nothing to cluster')
    print(zt1.shape,t1.shape)
    hist, bin_edges = np.histogram(zt1, bins=list(bins))
    print(hist)
    print(bin_edges)

    # Assign labels to each value based on its bin
    bin_labels = 1.0*np.digitize(zt1.reshape(-1,1), bin_edges)
    print('HIIIIIIIIIIIIIIIIIII',(bin_labels))
    la = 0.0*t1
    bin_labels = bin_labels.reshape(-1)
    print(np.sum(ids!=0),np.shape(la),np.shape(bin_labels[0]))
    la[ids != 0] = bin_labels
    write_img_ants(out_dir+'/T1ThresholdBasedMAskEDGES.nrrd', la, o1, s1, d1)

    n_labels = len(bin_edges)
    means = np.zeros((n_labels,1))
    stds = np.zeros((n_labels,1))
    fulllabels = la
    for i in range(n_labels):
        n_label_voxels = np.sum(fulllabels == i)
        print(n_label_voxels)
        in_bin = bin_labels == i
        if not np.any(in_bin):
            continue

        # plot the histogram of t1 values in this cluster
        print('HEREEEEEEEEEEEEEEEEEEEE',bin_labels)
        plt.figure(figsize=(10,10))
        bin_values = zt1[in_bin]
        print(bin_values)
        plt.hist(bin_values.reshape(-1,1), bins=100)
        plt.show()

        # Label 0 also matches every voxel outside `ids`, so very large
        # regions are skipped.
        if n_label_voxels >= max_roi_voxels:
            continue

        roi = 1.0*(fulllabels == i)
        t1t = t1*roi
        nonzero_t1 = t1t[t1t != 0]
        if nonzero_t1.size == 0:
            # No non-zero T1 in this region: there is no range to threshold on.
            continue
        # Compute the T1 range once instead of re-reducing for every use.
        t1_low, t1_high = np.min(nonzero_t1), np.max(nonzero_t1)
        print(t1_low,t1_high)
        within = (t1 >= t1_low) & (t1 <= t1_high)
        thresholdedt1 = 0.0*t1
        thresholdedt1[within] = 1.0
        # write nrrd image
        write_img_ants(out_dir+'/T1ThresholdBasedMAsk'+str(i)+'.nrrd', thresholdedt1, o1, s1, d1)
        means[i] = np.mean(bin_values)
        stds[i] = np.std(bin_values)
    return means, stds


# Bin the T1 values inside `ids` and collect the per-bin mean / std.
means,stds =save_t1ForClsuter(ids,t1, o1, s1, d1)

print(means,stds)

# choose the mean which lies inside (700, 1300); build the mask once and reuse it
in_range = (means>700) & (means<1300)
chosenmean=means[in_range]
chosenstd=stds[in_range]
if chosenmean.size == 0:
    # Without this check the comparison against t1 below fails with an
    # opaque broadcasting error on an empty array.
    raise ValueError('No T1 bin has a mean inside (700, 1300); cannot derive the T1 cut-off')

# Keep voxels whose T1 lies within two standard deviations of the chosen mean.
t1min= chosenmean-2*chosenstd
t1max= chosenmean+2*chosenstd
print(t1min,t1max)
ids2=0.0*ids
ids2[(t1>t1min) & (t1<t1max)]=1

write_img_ants('edge/t1mapBasedCuttOFF.nrrd', (ids2), o1, s1, d1)

# clear output of cell (drops the debug prints and figures produced above)
from IPython.display import clear_output
clear_output()