# Physics-Informed Autoencoders for Virtual Pathology of Prostate

In [1]:
from PIA import PIA, ADC_slice, get_batch, density_scatter
import numpy as np
from torch import optim
import torch
import itertools

### Create the test image

In [None]:
def three_compartment_fit(M, D_ep, D_st, D_lu, T2_ep,  T2_st, T2_lu, V_ep, V_st):
    """Evaluate the three-compartment hybrid signal model.

    Every compartment contributes ``V * exp(-b/1000 * D) * exp(-TE / T2)``.
    The third ("lu") volume fraction is not a free parameter: it is the
    remainder ``1 - V_ep - V_st``.  The summed signal is scaled by 1000.

    Args:
        M: Pair ``(b, TE)`` (e.g. a 2-row array) of acquisition settings.
        D_ep, D_st, D_lu: Diffusivity of each compartment.
        T2_ep, T2_st, T2_lu: T2 of each compartment.
        V_ep, V_st: Volume fractions of the first two compartments.

    Returns:
        The modelled signal, with the broadcast shape of ``b`` and ``TE``.
    """
    b, TE = M
    V_lu = 1 - V_ep - V_st

    def compartment(fraction, diffusivity, t2):
        # Mono-exponential decay along b and along TE for one compartment.
        return fraction*np.exp(-b/1000*diffusivity)*np.exp(-TE/t2)

    total = (compartment(V_ep, D_ep, T2_ep)
             + compartment(V_st, D_st, T2_st)
             + compartment(V_lu, D_lu, T2_lu))
    return 1000*total

In [None]:
from tqdm import tqdm
from scipy.optimize import curve_fit
def hybrid_fit(signals, bvals=(0, 150, 1000, 1500), normTE=(0, 13, 93, 143)):
    """Fit the three-compartment model to every row of ``signals``.

    Each row is fitted independently with ``scipy.optimize.curve_fit``
    (bounded, ``method='dogbox'``) against ``three_compartment_fit``.  If the
    optimiser raises ``RuntimeError`` for a row, that row falls back to the
    initial guess instead of aborting the whole run.

    Args:
        signals: 2-D array with one row per voxel.  Each row must hold
            ``len(bvals) * len(normTE)`` samples ordered b-value-major (all
            TEs of the first b-value, then all TEs of the next, ...), because
            that is how the design matrix below is flattened.
        bvals: b-values of the acquisition grid.
        normTE: Echo-time offsets of the acquisition grid.

    Returns:
        Tuple ``(D, T2, v)`` of ``(n_voxels, 3)`` arrays holding the fitted
        diffusivities, T2 values and volume fractions.  Columns follow the
        parameter order of ``three_compartment_fit``; the third column of
        ``v`` is ``1 - v[:, 0] - v[:, 1]``.
    """
    # Single source of truth for the start point and the fallback result.
    initial_guess = [0.55, 1.3, 2.8, 50, 70, 750, 0.3, 0.4]
    lower_bounds = [0.3, 0.7, 2.7, 20, 40, 500, 0, 0]
    upper_bounds = [0.7, 1.7, 3.0, 70, 100, 1000, 1, 1]

    numcols, _ = signals.shape

    # The (b, TE) design matrix is identical for every voxel: build it once.
    te_grid, b_grid = np.meshgrid(normTE, bvals)
    xdata = np.vstack((b_grid.ravel(), te_grid.ravel()))

    D = np.zeros((numcols, 3))
    T2 = np.zeros((numcols, 3))
    v = np.zeros((numcols, 3))
    for col in tqdm(range(numcols)):
        ydata = signals[col].ravel()
        try:
            coeffs, _ = curve_fit(three_compartment_fit,
                                  xdata,
                                  ydata,
                                  p0=initial_guess,
                                  check_finite=True,
                                  bounds=(lower_bounds, upper_bounds),
                                  method='dogbox',
                                  maxfev=5000)
        except RuntimeError:
            # Optimiser did not converge: keep the prior values for this voxel.
            coeffs = initial_guess
        D[col, :] = coeffs[0:3]
        T2[col, :] = coeffs[3:6]
        v[col, 0:2] = coeffs[6:]
        # NOTE(review): the bounds allow V_ep + V_st > 1, so this remainder
        # can become negative — confirm whether that is acceptable.
        v[col, 2] = 1 - coeffs[6] - coeffs[7]
    return D, T2, v

In [None]:
!pip install scikit-image

In [None]:
from tqdm import tqdm
def ADC_slice(bvalues, slicedata):
    """Compute a clipped ADC map from a stack of diffusion-weighted images.

    For every voxel a straight line is fitted to ``log(signal + eps)`` versus
    ``b / 1000``; the ADC is the negated slope, clipped to ``[0, 3.0]``.

    Args:
        bvalues: b-values, one per entry on the last axis of ``slicedata``.
            May be a list, tuple or array of any shape (it is flattened).
        slicedata: Array of shape ``(rows, cols, n_bvalues)``.

    Returns:
        ``(rows, cols)`` float array of clipped ADC values.
    """
    min_adc = 0
    max_adc = 3.0
    eps = 1e-7  # keeps log() finite for zero-valued voxels
    numrows, numcols, _ = slicedata.shape

    # Convert once, outside the voxel loop.  np.asarray also accepts plain
    # Python lists, which have no .flatten() method.
    b_scaled = np.asarray(bvalues, dtype=float).ravel()/1000

    adc_map = np.zeros((numrows, numcols))
    for row in range(numrows):
        for col in range(numcols):
            ydata = np.squeeze(slicedata[row, col, :])
            slope = np.polyfit(b_scaled, np.log(ydata + eps), 1)[0]
            adc_map[row, col] = max(min(-slope, max_adc), min_adc)
    return adc_map


def detect_PIDS_slice(b, S):
    """ Inputs: b - diffusion weight values used in image
                S - Hybrid Multi-dimensional image, indexed as
                    (row, col, b-value index, TE index)
        Outputs:
                PIDS_ADC1 : Binary Map with voxels ADC > 3 (could mean motion induced signal loss at high-b)
                PIDS_ADC2 : Binary Map with voxels ADC < 0 (could mean the voxel is below the noise level)
                PIDS_b_decay : Binary Map with voxels disobeying decay rule along b direction,
                               shape (rows, cols, num_TEs, num_bvalues - 1)
                PIDS_TE_decay : Binary Map with voxels disobeying decay rule along TE direction,
                                shape (rows, cols, num_bvalues, num_TEs - 1)

        A decay-rule entry is 1 when a sample is greater than or equal to its
        predecessor along the respective axis.  The ADC is fitted on the
        first TE only (index 0 of the last axis).
    """

    eps = 1e-7  # keeps log() finite for zero-valued voxels
    num_rows, num_cols, num_bvalues, num_TEs = S.shape

    # Scale the b axis once; np.asarray also accepts plain lists.
    b_scaled = np.asarray(b, dtype=float).ravel()/1000

    PIDS_ADC1 = np.zeros((num_rows, num_cols))
    PIDS_ADC2 = np.zeros((num_rows, num_cols))
    for row in tqdm(range(num_rows)):
        for col in range(num_cols):
            te0 = np.squeeze(S[row, col, :, 0])
            slope = np.polyfit(b_scaled, np.log(te0 + eps), 1)[0]
            adc = -slope
            PIDS_ADC1[row, col] = int(adc > 3)
            PIDS_ADC2[row, col] = int(adc < 0)

    # Compare neighbours directly instead of subtracting an int-cast copy:
    # the cast truncated fractional signals (0.9 -> 0), so on non-integer data
    # genuine decay was reported as a violation.  A plain comparison is also
    # safe for unsigned integer inputs, where subtraction could wrap around.
    PIDS_TE_decay = (S[:, :, :, 1:] >= S[:, :, :, :-1]).astype(float)

    # Along b the comparison axis is axis 2; move it last so the result is
    # indexed (row, col, TE index, neighbour index).
    b_violations = S[:, :, 1:, :] >= S[:, :, :-1, :]
    PIDS_b_decay = np.transpose(b_violations, (0, 1, 3, 2)).astype(float)

    return PIDS_ADC1, PIDS_ADC2, PIDS_b_decay, PIDS_TE_decay

class color:
    """ANSI escape sequences for styling terminal output.

    Wrap text as ``color.RED + text + color.END`` to colour it; ``END``
    resets every attribute.
    """

    # Text attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    END = '\033[0m'

    # Standard-intensity foreground
    DARKCYAN = '\033[36m'

    # Bright foreground colours
    RED = '\033[91m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    BLUE = '\033[94m'
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
from skimage import morphology
# Flag implausible voxels in a 50x50 crop (rows/cols 40:90, index 15 on the
# third axis) and turn the flags into a boolean keep-mask.
# Relies on b_values, hybrid_data and plt being defined by other cells.
PIDS_1, PIDS_2, PIDS_3, PIDS_4 = detect_PIDS_slice(b_values, hybrid_data[40:90, 40:90, 15, :, :])
# Start from the two ADC-based flag maps.
PIDS = PIDS_1.astype(float) + PIDS_2.astype(float)
# Add the b-direction decay flags for i in 0..3 and j in 0..1.
# NOTE(review): only two entries of the last axis of PIDS_3 are used and
# PIDS_4 (TE-direction flags) is never used — confirm this is intentional.
for i in range(0,4):
    for j in range(0,2):
        PIDS += PIDS_3[:,:,i, j].astype(float)

# Invert: 1.0 where no flag fired, 0.0 where at least one did.
PIDS = 1 - (PIDS>0).astype(float)
# Morphological clean-up of the keep-mask.
# NOTE(review): connectivity=50 is far above the rank of this 2-D mask;
# confirm the intended neighbourhood (the second call uses connectivity=1).
PIDS = morphology.remove_small_objects(PIDS.astype(bool), min_size=50, connectivity=50)
PIDS = morphology.remove_small_holes(PIDS, area_threshold=50, connectivity=1)
plt.imshow(PIDS)

In [None]:
!pip install ipywidgets

In [None]:
# Draw 2500 samples from get_batch; the fifth return value is bound to `_`.
# presumably (signals, D, T2, v, ...) — verify against PIA.get_batch.
test, D_test2, T2_test2, v_test2, _ = get_batch(2500)

In [None]:
# Take the 50x50 crop at index 10 of the third axis and flatten it to one row
# per voxel with 16 values each (overwrites any earlier `test`).
test = np.squeeze(hybrid_data[40:90,40:90, 10, :])
test = np.reshape(test, (2500, 16))

In [None]:
# Time the per-voxel curve fit on `test` and show the fitted parameters as a
# 3x3 grid of 50x50 maps (row 0: v, row 1: D, row 2: T2).
# Relies on plt and display being imported by other cells.
import time
#test = test.detach().cpu().numpy()
start = time.time()
D, T2, v = hybrid_fit(test)
end = time.time()
print(end - start)
bins = (50, 50)
fig, ax = plt.subplots(3,3, figsize=(15,15))
for r in range(3):
    for c in range(3):
        # Pick the parameter family for this row.  `ylims` is assigned but
        # not used anywhere in this cell.
        if r==0:
            x_image =  v
            title = ['V_ep', 'V_st', 'V_lu']

        elif r==1:
            x_image =  D
            title = ['D_ep', 'D_st', 'D_lu']
            ylims = [(0.3, 0.7), (0.7, 1.7), (2.7, 3)]

        else:
            x_image = T2
            title = ['T2_ep', 'T2_st', 'T2_lu']
            ylims = [(20, 70), (40, 100), (500, 1000)]

        # Column c of the chosen array, reshaped back to the 50x50 crop.
        ax[r,c].imshow(np.reshape(x_image[:,c], bins), cmap='jet')
        ax[r,c].set_title(fr'{title[c]}')
display.display(plt.gcf())
#display.clear_output(wait=True)

In [None]:
# Threshold the fitted volume fractions into a binary map and overlay it on
# the ADC map of the same crop.  Expects v to be a NumPy array (as returned
# by hybrid_fit) and bins, PIDS, b_values, hybrid_data from other cells.
v_ep = np.reshape(v[:,0], bins) 
v_lu = np.reshape(v[:,2], bins)
# Candidate voxels: first fraction above 0.4 and third fraction at most 0.2.
cancer = (v_ep > 0.4)*(v_lu <= 0.2)
from skimage import morphology
from skimage import segmentation
# ADC from index 10 of the third axis, using index 0 of the last axis.
# NOTE(review): elsewhere in this notebook PIDS is built from index 15 of the
# third axis, not 10 — confirm the mask and the data refer to the same slice.
adc_map = ADC_slice(b_values, 
                    np.squeeze(hybrid_data[40:90,40:90, 10, :, 0]))
fig, ax = plt.subplots(1, figsize=(6,6))
fig.suptitle('predicted cancer map')
ax.imshow(adc_map, cmap='gray')
# Keep only candidates inside the PIDS keep-mask, then drop small components.
cancer_map = np.multiply(cancer.astype(float), PIDS.astype(float))
cancer_map = morphology.remove_small_objects(cancer_map.astype(bool), min_size=12, connectivity=1)
cancer_map = cancer_map.astype(float)
# NaN pixels are not drawn, so only positive voxels tint the ADC background.
cancer_map[cancer_map==0] = np.nan
ax.imshow(cancer_map, cmap='autumn',alpha = 0.4)
ax.axis('off')

In [None]:
# Run the curve fit on a batch that has reference values and compare the
# estimates with D_test2 / T2_test2 / v_test2 in a 3x3 grid of scatter plots.
import time
# Requires `test` to be a torch tensor at this point (it is converted here).
test = test.detach().cpu().numpy()
start = time.time()
D, T2, v = hybrid_fit(test)
end = time.time()
print(end - start)

fig, ax = plt.subplots(3,3, figsize=(15,15))
for r in range(3):
    for c in range(3):
        # x_image: reference tensor, y_image: curve-fit estimate (NumPy).
        # `ylims` is only consumed by the commented-out set_ylim call below.
        if r==0:
            x_image, y_image = v_test2, v
            title = ['V_ep', 'V st', 'V lu']
            ylims = [(0,1), (0,1), (0,1)]
        elif r==1:
            x_image, y_image = D_test2, D
            title = ['D ep', 'D st', 'D lu']
            ylims = [(0.3, 0.7), (0.7, 1.7), (2.7, 3)]

        else:
            x_image, y_image = T2_test2, T2
            title = ['T2 ep', 'T2 st', 'T2 lu']
            ylims = [(20, 70), (40, 100), (500, 1000)]

        density_scatter(x_image.detach().cpu().numpy()[:,c] , 
                        y_image[:, c], ax = ax[r,c], sort = True, bins = 15)
        # Mean absolute error and Pearson correlation for column c.
        err = np.mean(np.abs(x_image.detach().cpu().numpy()[:,c]-y_image[:, c]))
        corr = np.corrcoef(x_image.detach().cpu().numpy()[:,c],y_image[:, c])[0,1]

        ax[r,c].set_title(fr'{title[c]}, MAE = {err:.3f}, $\rho$ = {corr:.3f}')
        ax[r,c].set_xlabel('true')
        ax[r,c].set_ylabel('predicted')
        #ax[r,c].set_ylim(ylims[c])

In [None]:
from IPython import display
from scipy.stats import ttest_rel
%matplotlib inline
# Train a fresh PIA model on batches from get_batch and, every 1000
# iterations, plot its predictions on a fixed 2500-sample batch.
model =  PIA(predictor_depth=2)
# Only the encoder and the three predictor heads are handed to the optimiser.
params = list(model.encoder.parameters()) + list(model.v_predictor.parameters()) + list(model.D_predictor.parameters()) + list(model.T2_predictor.parameters())
optimizer = optim.Adam(params, lr=0.00005)



ctr = 1
total_loss = 0 
# Fixed evaluation batch; its reference values stay in D_test2/T2_test2/v_test2.
test, D_test2, T2_test2, v_test2, _ = get_batch(2500)
test = test.cuda()
for ep in range(50000):
 
    x, D_true, T2_true, v_true, y = get_batch(128)
    x , y = x.cuda(), y.cuda()
    optimizer.zero_grad()
    D, T2, v = model.encode(x)        
    recon = model.decode(D, T2, v).cuda()
    loss = model.loss_function([recon, D.float(), T2.float(), v],[y, D_true, T2_true, v_true], 
                               mathematical_model=False)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()
    if not ep % 1000:

        # NOTE(review): this evaluation runs without torch.no_grad(), and the
        # figure created below is never closed, so figures accumulate over
        # the run — confirm this is acceptable.
        # D, T2 and v are rebound here to the evaluation outputs.
        D, T2, v = model.encode(test)
        bins = (50, 50)
        fig, ax = plt.subplots(3,3, figsize=(15,15))
        for r in range(3):
            for c in range(3):
                # x_image: reference values, y_image: model prediction.
                # `ylims` is only consumed by the commented-out set_ylim call.
                if r==0:
                    x_image, y_image = v_test2, v
                    title = ['V_ep', 'V_st', 'V_lu']

                elif r==1:
                    x_image, y_image = D_test2, D
                    title = ['D_ep', 'D_st', 'D_lu']
                    ylims = [(0.3, 0.7), (0.7, 1.7), (2.7, 3)]

                else:
                    x_image, y_image = T2_test2, T2
                    title = ['T2_ep', 'T2_st', 'T2_lu']
                    ylims = [(20, 70), (40, 100), (500, 1000)]
                
                density_scatter(x_image.detach().cpu().numpy()[:,c] , 
                                y_image.detach().cpu().numpy()[:, c], ax = ax[r,c], sort = True, bins = 15)
                # Mean absolute error and Pearson correlation for column c.
                err = np.mean(np.abs(x_image.detach().cpu().numpy()[:,c]-y_image.detach().cpu().numpy()[:, c]))
                corr = np.corrcoef(x_image.detach().cpu().numpy()[:,c],y_image.detach().cpu().numpy()[:, c])[0,1]
            
                #print(fr'{title[c]}, MAE = {err:.3f}, $\rho$ = {corr:.3f}')
                ax[r,c].set_title(fr'{title[c]}, MAE = {err:.3f}, $\rho$ = {corr:.3f}')
                ax[r,c].set_xlabel('true')
                ax[r,c].set_ylabel('predicted')
                #ax[r,c].set_ylim(ylims[c])
        display.display(plt.gcf())
    #display.clear_output(wait=True)
    # Running mean of the loss over all iterations so far, rewritten in place.
    print(f'{total_loss/ctr}',end ="\r")
    ctr += 1


In [None]:
import mat73
import os
import scipy.io as sio
# Load one patient's .mat file with scipy.io and pull out the arrays used by
# the rest of the notebook.
# NOTE(review): the path is hard-coded to one user's Downloads folder.
BASE_ADDRESS = '/home/gundogdu/Downloads'
data_address = os.path.join(BASE_ADDRESS, 'pat099_hybridSortedInput.mat')


print('Loading data')
data = sio.loadmat(data_address)     
print('Data loaded')

# Keys read from the file: 'b', 'TE', 'TE_norm' and 'hybrid_data'.
b_values = data['b']
TE_values = data['TE']
TE_norm = data['TE_norm']
hybrid_data = data['hybrid_data']

In [None]:
from IPython import display
from scipy.stats import ttest_rel
%matplotlib inline
# Continue optimising the existing `model` directly on `test` for 101 steps,
# plotting the predicted maps at steps 0, 50 and 100.
#model =  PIA(predictor_depth=2)
params = list(model.encoder.parameters()) + list(model.v_predictor.parameters()) + list(model.D_predictor.parameters()) + list(model.T2_predictor.parameters())
optimizer = optim.Adam(params, lr=0.00005)



ctr = 1
total_loss = 0 
for ep in range(101):
    if not ep % 50:
        # Snapshot of the current predictions, drawn as 50x50 maps.
        # NOTE(review): `test` must be a tensor holding exactly 2500 rows for
        # the reshape to (50, 50) to work — depends on which cell ran last.
        D, T2, v = model.encode(test)
        bins = (50, 50)
        fig, ax = plt.subplots(3,3, figsize=(15,15))
        for r in range(3):
            for c in range(3):
                # `ylims` is assigned but not used anywhere in this cell.
                if r==0:
                    x_image =  v
                    title = ['V_ep', 'V_st', 'V_lu']

                elif r==1:
                    x_image =  D
                    title = ['D_ep', 'D_st', 'D_lu']
                    ylims = [(0.3, 0.7), (0.7, 1.7), (2.7, 3)]

                else:
                    x_image = T2
                    title = ['T2_ep', 'T2_st', 'T2_lu']
                    ylims = [(20, 70), (40, 100), (500, 1000)]
                
                ax[r,c].imshow(np.reshape(x_image.detach().cpu().numpy()[:,c], bins), cmap='jet')
                ax[r,c].set_title(fr'{title[c]}')
        display.display(plt.gcf())
    #display.clear_output(wait=True)

    optimizer.zero_grad()
    D, T2, v = model.encode(test)        
    recon = model.decode(D, T2, v).cuda()
    # The input itself is the reconstruction target.
    # NOTE(review): the remaining targets are whatever `_` is currently bound
    # to (a leftover from get_batch, or IPython's last output) — presumably
    # ignored by loss_function in this mode; verify against PIA.loss_function.
    loss = model.loss_function([recon, D.float(), T2.float(), v],
                               [test, _, _, _], 
                                mathematical_model=False)
    loss.backward()
    optimizer.step()
    total_loss += loss.item()

    # Running mean of the loss, rewritten in place.
    print(f'{total_loss/ctr}',end ="\r")
    ctr += 1


In [None]:
# Show the dimensions of the loaded array as the cell output.
hybrid_data.shape

In [None]:
# Encode a 50x50 crop (index 9 on the third axis) with the trained model and
# show the predicted parameters as a 3x3 grid of maps.
# Reshape to 16 values per voxel; assumes the first two axes are 128x128.
hybrid_data2 = np.reshape(hybrid_data, (128, 128, hybrid_data.shape[2], 16))
# The crop is multiplied by 1000 before being fed to the model.
test = np.squeeze(hybrid_data2[40:90,40:90, 9, :])*1000
test = np.reshape(test, (2500, 16))
test = torch.from_numpy(test)
test = test.float().cuda()
D, T2, v = model.encode(test)
bins = (50, 50)
fig, ax = plt.subplots(3,3, figsize=(15,15))
for r in range(3):
    for c in range(3):
        # `ylims` is assigned but not used anywhere in this cell.
        if r==0:
            x_image =  v
            title = ['V_ep', 'V_st', 'V_lu']

        elif r==1:
            x_image =  D
            title = ['D_ep', 'D_st', 'D_lu']
            ylims = [(0.3, 0.7), (0.7, 1.7), (2.7, 3)]

        else:
            x_image = T2
            title = ['T2_ep', 'T2_st', 'T2_lu']
            ylims = [(20, 70), (40, 100), (500, 1000)]

        ax[r,c].imshow(np.reshape(x_image.detach().cpu().numpy()[:,c], bins), cmap='jet')
        ax[r,c].set_title(fr'{title[c]}')
display.display(plt.gcf())

In [None]:
# End-to-end run on a second patient file: load, build the PIDS keep-mask,
# optionally fine-tune the model on the crop, then plot the three volume
# maps and an ADC map with the thresholded overlay.
# NOTE(review): the path is hard-coded to one user's Downloads folder.
BASE_ADDRESS = '/home/gundogdu/Downloads'
data_address = os.path.join(BASE_ADDRESS, 'pat029_hybridSortedInput.mat')
# Set to False to skip the fine-tuning loop below.
train = True

print('Loading data')
#data = sio.loadmat(data_address)     
# This file is read with mat73 rather than scipy.io.
data = mat73.loadmat(data_address)    
print('Data loaded')

b_values = data['b']
TE_values = data['TE']
TE_norm = data['TE_norm']
hybrid_data = data['hybrid_data']
# Index on the third axis and the square crop window shared by every step.
_slice = 9
_from = 35
_to = 95
PIDS_1, PIDS_2, PIDS_3, PIDS_4 = detect_PIDS_slice(b_values, hybrid_data[_from:_to, _from:_to, _slice, :, :])
PIDS = PIDS_1.astype(float) + PIDS_2.astype(float)
# NOTE(review): only two entries of the last axis of PIDS_3 are used and
# PIDS_4 is never used — confirm this is intentional.
for i in range(0,4):
    for j in range(0,2):
        PIDS += PIDS_3[:,:,i, j].astype(float)

# Invert: 1.0 where no flag fired, 0.0 where at least one did.
PIDS = 1 - (PIDS>0).astype(float)
# NOTE(review): connectivity=50 is far above the rank of this 2-D mask;
# confirm the intended neighbourhood.
PIDS = morphology.remove_small_objects(PIDS.astype(bool), min_size=50, connectivity=50)
PIDS = morphology.remove_small_holes(PIDS, area_threshold=50, connectivity=1)
# Reshape to 16 values per voxel; assumes the first two axes are 128x128.
hybrid_data2 = np.reshape(hybrid_data, (128, 128, hybrid_data.shape[2], 16))
test = np.squeeze(hybrid_data2[_from:_to, _from:_to, _slice, :])*1000
# Remember the crop's 2-D shape before flattening to one row per voxel.
bins = (test.shape[0], test.shape[1])
test = np.reshape(test, (test.shape[0]*test.shape[1], 16))
test = torch.from_numpy(test)
test = test.float().cuda()
if train:
    # `params` and `model` come from an earlier cell.
    optimizer = optim.Adam(params, lr=0.00005)

    ctr = 1
    total_loss = 0 
    for ep in range(100):

        optimizer.zero_grad()
        D, T2, v = model.encode(test)        
        recon = model.decode(D, T2, v).cuda()
        # NOTE(review): the non-signal targets are whatever `_` is bound to;
        # presumably ignored by loss_function — verify.
        loss = model.loss_function([recon, D.float(), T2.float(), v],
                                   [test, _, _, _], 
                                    mathematical_model=False)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        print(f'{total_loss/ctr}',end ="\r")
        ctr += 1

# `v` is the output of the last forward pass, computed before the final
# optimiser step.  NOTE(review): with train = False it is whatever an earlier
# cell left behind and may not match this crop.
v_ep = np.reshape(v.detach().cpu().numpy()[:,0], bins)
v_st = np.reshape(v.detach().cpu().numpy()[:,1], bins)
v_lu = np.reshape(v.detach().cpu().numpy()[:,2], bins)
# Candidate voxels: first fraction above 0.4 and third fraction at most 0.2.
cancer = (v_ep > 0.4)*(v_lu <= 0.2)
from skimage import morphology
from skimage import segmentation
adc_map = ADC_slice(b_values, 
                    np.squeeze(hybrid_data[_from:_to, _from:_to, _slice, :, 0]))
fig, ax = plt.subplots(1,4, figsize=(20,5))

# Zero out voxels rejected by the PIDS keep-mask.
img1 = np.multiply(v_ep.astype(float), PIDS.astype(float))
img2 = np.multiply(v_st.astype(float), PIDS.astype(float))
img3 = np.multiply(v_lu.astype(float), PIDS.astype(float))

ax[0].imshow(img1, cmap='jet')
# NOTE(review): the title string below misspells "epithelium".
ax[0].set_title('epitheluim volume map')

ax[1].imshow(img2, cmap='jet')
ax[1].set_title('stroma volume map')

ax[2].imshow(img3, cmap='jet')
ax[2].set_title('lumen volume map')

ax[3].imshow(adc_map, cmap='gray')
ax[3].set_title('ADC map')
 
# Masked candidates with small components removed, drawn over the ADC map.
cancer_map = np.multiply(cancer.astype(float), PIDS.astype(float))
cancer_map = morphology.remove_small_objects(cancer_map.astype(bool), min_size=4, connectivity=1)
cancer_map = cancer_map.astype(float)
# NaN pixels are not drawn, so only positive voxels tint the background.
cancer_map[cancer_map==0] = np.nan
ax[3].imshow(cancer_map, cmap='autumn',alpha = 0.4)
for i in range(4):
    ax[i].axis('off')

In [None]:
# Wall-clock time of a single model.encode call on `test`.
# NOTE(review): if the model runs on a GPU, kernels may still be executing
# when time.time() is read — confirm whether a synchronisation is needed.
import time
start = time.time()
D, T2, v = model.encode(test)
end = time.time()
print(end - start)

In [None]:
import re
import os
import mat73
import scipy.io as sio
import numpy as np
from scipy.optimize import curve_fit
import itertools
from tqdm import tqdm
import matplotlib
import matplotlib.pyplot as plt

In [None]:
# Reshape hybrid_crop to a 55x45 grid with 16 values per voxel.
# hybrid_crop is not defined in this part of the notebook; presumably it is a
# (55, 45, 4, 4) array whose last two axes are merged here — TODO confirm.
hybrid_crop2 = np.reshape(hybrid_crop, (55, 45, 16))

In [None]:
# Inspect the values stored for voxel (10, 10) across the last two axes.
hybrid_crop[10,10,:, :]

In [None]:
def adc_voxel(b, S):
    """Least-squares ADC estimate for a single voxel, clipped to [0, 3.0].

    Fits a straight line to ``log(S + eps)`` versus ``b`` with the
    closed-form simple-regression slope and returns ``-slope * 1000``.

    Args:
        b: Sequence of b-values.
        S: Sequence of signals, one per b-value.

    Returns:
        The ADC estimate, limited to the range ``[0, 3.0]``.
    """
    min_adc = 0
    max_adc = 3.0
    # Previously read from an undefined global; same value as the other ADC
    # helpers in this notebook.  Keeps log() finite for zero signals.
    eps = 1e-7

    points = [(b_i, np.log(s_i + eps)) for b_i, s_i in zip(b, S)]
    n = len(points)
    sum_x = sum(x for x, _ in points)
    sum_y = sum(y for _, y in points)
    sum_xy = sum(x*y for x, y in points)
    sum_x2 = sum(x**2 for x, _ in points)

    # Simple linear regression slope: (n*Sxy - Sx*Sy) / (n*Sxx - Sx^2).
    slope = (n*sum_xy - sum_x*sum_y)/(n*sum_x2 - sum_x**2)
    adc = -slope*1000
    return max(min(adc, max_adc), min_adc)

In [None]:
from skimage import morphology
from skimage import segmentation


In [None]:
from torch import optim
# Per-voxel optimisation over hybrid_crop2 followed by volume-fraction maps
# and an ADC overlay.
# NOTE(review): this cell calls model.sample(...) and indexes loss['loss'],
# which differs from how `model` is used elsewhere in this notebook — it
# presumably targets the commented-out HybridVAE model; confirm before running.
D_mean = torch.from_numpy(np.asarray([0.55, 1.3, 2.8]))
T2_mean = torch.from_numpy(np.asarray([50, 70, 750]))
# Output maps, zero wherever a voxel is skipped by the threshold below.
v_ep = np.zeros_like(hybrid_crop2[:,:,1])
v_st = np.zeros_like(hybrid_crop2[:,:,1])
v_lu = np.zeros_like(hybrid_crop2[:,:,1])
#model =  HybridVAE(number_of_signals=16, hidden_dims= [128, 128])
# Only the encoder parameters are optimised here.
optimizer = optim.Adam(model.encoder.parameters(), lr=0.0003)
for ep in tqdm(range(1)):
    total_loss = 0
    for row in range(hybrid_crop2.shape[0]):
        for col in range(hybrid_crop2.shape[1]):
            # Skip voxels whose first sample is at or below 5000.
            if hybrid_crop2[row, col, 0] > 5000:
                # Normalise the voxel by its first sample.
                voxel = hybrid_crop2[row, col, :]/hybrid_crop2[row, col, 0]
                optimizer.zero_grad()
                x = torch.from_numpy(voxel).float()
                D_var, T2_var, v = model.encode(x)        
                D, T2 = model.sample(D_mean, T2_mean, D_var, T2_var)
                recon = model.decode(D, T2, v)
                loss = model.loss_function(recon, x)
                loss['loss'].backward()
                optimizer.step()
                total_loss += loss['loss'].item()
                # Record the three predicted fractions for this voxel.
                v_ep[row, col] = v.detach().numpy()[0]
                v_st[row, col] = v.detach().numpy()[1]
                v_lu[row, col] = v.detach().numpy()[2]
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
im = ax[0].imshow(v_ep,  cmap = 'jet')
im = ax[1].imshow(v_st,  cmap = 'jet')
im = ax[2].imshow(v_lu,  cmap = 'jet')
# NOTE(review): `im` now refers to the third image only, so all three
# colorbars are built from the v_lu image — confirm whether each axis should
# get the colorbar of its own image.
fig.colorbar(im, ax=ax[0],shrink=0.75)
fig.colorbar(im, ax=ax[1],shrink=0.75)
fig.colorbar(im, ax=ax[2],shrink=0.75)
ax[0].set_title('Epithelium Volume')
ax[1].set_title('Stroma Volume')
ax[2].set_title('Lumen Volume')
ax[0].axis('off')
ax[1].axis('off')
ax[2].axis('off')
# Candidate voxels: first fraction above 0.4 and third fraction at most 0.2.
imageMap=  (v_ep > 0.4)*(v_lu <= 0.2)
b0_th = 5000
data_th = 3#1000
# Index 3 / index 0 on the third axis, both at index 0 of the last axis.
signal = np.squeeze(hybrid_crop[:,:, 3, 0])
b0 = np.squeeze(hybrid_crop[:,:, 0, 0])
mask = b0 > b0_th
mask2 = signal > data_th
# data_segmented and b0_segmented are not used later in this cell.
data_segmented = mask2*signal
b0_segmented = mask*b0
#mask = mask*mask2
# b_values is a plain Python list here, not an ndarray, so ADC_slice has to
# cope with list input.
b_values = [0, 150, 1000, 1500]
adc_map = ADC_slice(b_values, hybrid_crop[:,:, :, 0])
fig, ax = plt.subplots(1, figsize=(6,6))
fig.suptitle('predicted cancer map')
ax.imshow(adc_map, cmap='gray')
# Keep candidates that pass both intensity masks.
cancer_map = np.multiply(imageMap.astype(float), mask.astype(float))
cancer_map = np.multiply(cancer_map, mask2.astype(float))
cancer_map = morphology.remove_small_objects(cancer_map.astype(bool), min_size=1, connectivity=1)
cancer_map = cancer_map.astype(float)
# NaN pixels are not drawn, so only positive voxels tint the background.
cancer_map[cancer_map==0] = np.nan
ax.imshow(cancer_map, cmap='autumn', alpha = 0.4)
ax.axis('off')