In [None]:
import os
import warnings
import PIL
import numpy as np
import matplotlib.pyplot as plt
import vtk
from vtk.util.numpy_support import vtk_to_numpy
from tqdm import tqdm
import scipy
from scipy.optimize import minimize
from scipy.ndimage import binary_erosion
from skimage import measure
from vedo import *
import meshio
import seaborn as sns
import pyvista as pv
from data_processing.mesh_to_array import *
import pandas as pd
from matplotlib.colors import LinearSegmentedColormap
import hydra
from hydra import compose, initialize
from omegaconf import OmegaConf, DictConfig
import fnmatch
import cv2
import torch
from train.testing import testing
from data_processing.dataset_3d import load_dataset_3d, generate_points
from data_processing.obj2py import read_get, read_mat, read_egt
from data_processing.helper import sort_filenames
from utils.helper import make_cmap
import copy
# Colormap from the project helper `make_cmap`, used for the conductivity images in this notebook.
cmap = make_cmap()



In [None]:
def get_all_cases(cfg: "DictConfig", base_dir=".."):
    """Return the list of case names to load.

    If ``cfg.data.cases == 'all'``, every entry of
    ``<base_dir>/<cfg.data.processed_data_folder>`` whose name matches
    ``case_TCIA*`` is used, with everything from the first ``.`` on stripped
    (``case_TCIA_401_0.npz`` -> ``case_TCIA_401_0``). Otherwise
    ``cfg.data.cases`` is returned unchanged.

    Args:
        cfg: Hydra/OmegaConf config with a ``data`` section.
        base_dir: Directory the processed-data folder is resolved against.

    Returns:
        Case names in ``os.listdir`` order (not sorted), or ``cfg.data.cases``.
    """
    if cfg.data.cases != 'all':
        return cfg.data.cases
    folder = os.path.join(base_dir, cfg.data.processed_data_folder)
    # The listdir order is kept as-is so any order-dependent downstream behaviour is unchanged.
    return [entry.split('.')[0] for entry in os.listdir(folder)
            if fnmatch.fnmatch(entry, 'case_TCIA*')]

def interpolate_arrays(arr, t):
    """Blend four arrays along a cubic Bézier curve.

    ``arr`` is unpacked into four control arrays ``(arr1, arr2, arr3, arr4)``.
    The result equals ``arr1`` at ``t = 0`` and ``arr4`` at ``t = 1``; ``arr2`` and
    ``arr3`` only shape the path in between. The cubic Bernstein weights sum to 1
    for every ``t``, so blending identical arrays returns them unchanged.

    Args:
        arr: Sequence (or array/tensor with leading length 4) of broadcast-compatible arrays.
        t: Interpolation parameter, 0 <= t <= 1 (scalar or broadcastable).

    Returns:
        The blended array.
    """
    arr1, arr2, arr3, arr4 = arr
    s = 1 - t
    # Cubic Bernstein basis: s^3, 3 s^2 t, 3 s t^2, t^3 (sums to 1).
    return s ** 3 * arr1 + 3 * s ** 2 * t * arr2 + 3 * s * t ** 2 * arr3 + t ** 3 * arr4

def remove_empty_space(img, lung_mask=None, threshold=0.001):
    """Crop all-empty rows/columns from a 2-D image and mask its remaining empty pixels.

    A pixel is "empty" when its value is below ``threshold``. Rows and columns made
    only of empty pixels are dropped. Remaining empty pixels are masked
    (``numpy.ma``), so imshow draws them with the colormap's "bad" colour.

    Args:
        img: 2-D array.
        lung_mask: Optional 2-D array of the same shape. It is cropped with the same
            rows/columns as ``img`` and masked where it is below ``threshold``.
        threshold: Values strictly below this count as empty (default 0.001).

    Returns:
        The cropped masked image, or ``(image, lung_mask)`` when ``lung_mask`` is given.
    """
    empty = img < threshold
    keep_rows = ~np.all(empty, axis=1)
    keep_cols = ~np.all(empty, axis=0)
    cropped = np.ma.masked_where(empty, img)[keep_rows][:, keep_cols]
    if lung_mask is None:
        return cropped
    cropped_mask = np.ma.masked_where(lung_mask < threshold, lung_mask)[keep_rows][:, keep_cols]
    return cropped, cropped_mask

# Load model and process test data

In [None]:
def load_model(path):
    """Instantiate the EIT model of a Hydra run directory and load its checkpoint.

    Args:
        path: Run directory holding the ``.hydra`` config (read with ``load_cfg``)
            and the checkpoint ``model_lung.pt``.

    Returns:
        ``(model, cfg)``; ``cfg.inference_path`` is set to ``path``.

    Warns:
        UserWarning if checkpoint keys and model parameters do not match.
    """
    cfg = load_cfg(path)
    cfg.inference_path = path
    model = hydra.utils.instantiate(cfg.learning.model, model_3d=cfg.data.model_3d)
    checkpoint = torch.load(os.path.join(path, 'model_lung.pt'),
                            map_location=cfg.learning.training.device)
    # strict=False tolerates key mismatches. Those would otherwise go unnoticed and leave
    # the affected layers at their initial weights, so report them instead of hiding them.
    result = model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    missing = list(getattr(result, 'missing_keys', ()))
    unexpected = list(getattr(result, 'unexpected_keys', ()))
    if missing or unexpected:
        warnings.warn(f"load_model({path!r}): missing keys {missing}, unexpected keys {unexpected}")
    return model, cfg

def load_seg_model(path):
    """Instantiate the segmentation model of a Hydra run and load ``seg_model.pt``.

    Args:
        path: Run directory holding the ``.hydra`` config and ``seg_model.pt``.

    Returns:
        ``(model, cfg)``.
    """
    cfg = load_cfg(path)
    seg_model = hydra.utils.instantiate(cfg.learning.segmentation_model)
    weights_path = os.path.join(path, 'seg_model.pt')
    state = torch.load(weights_path, map_location=cfg.learning.training.device)
    seg_model.load_state_dict(state)
    return seg_model, cfg

def load_cfg(path):
    """Compose and return the Hydra config named ``config`` stored in ``<path>/.hydra``.

    NOTE(review): Hydra's ``initialize`` expects a relative ``config_path``; the
    callers in this notebook pass relative run directories such as ``outputs/<date>/<time>``.
    """
    hydra_dir = os.path.join(path, ".hydra")
    with initialize(version_base=None, config_path=hydra_dir, job_name="test"):
        return compose(config_name="config")

In [None]:
# Load the trained EIT model and its Hydra config from this run directory.
eit_path = 'outputs/2024-07-12/11-22-04'
model, cfg = load_model(eit_path)
# Segmentation model (currently disabled):
# seg_path = 'outputs/2024-04-29/19-25-08'
# seg_model, seg_cfg = load_seg_model(seg_path)

In [None]:
# Rebuild the train/val/test splits with the data settings stored in the run's cfg.
# write_dataset / write_npz / overwrite_npz are disabled; signals use signal_norm='all'.
cases = get_all_cases(cfg, base_dir='')    
train_dataset, val_dataset, test_dataset = load_dataset_3d(cases,
            resolution=cfg.data.resolution, 
            base_dir = '',
            raw_data_folder=cfg.data.raw_data_folder, 
            processed_data_folder=cfg.data.processed_data_folder,
            dataset_data_folder=cfg.data.dataset_data_folder,
            name_prefix=cfg.data.name_prefix,
            write_dataset=False, write_npz=False, 
            overwrite_npz=False, n_sample_points=cfg.learning.training.sample_points,
            return_electrodes=cfg.data.return_electrodes, apply_rotation=cfg.data.apply_rotation,
            apply_subsampling=cfg.data.apply_subsampling,
            apply_translation = cfg.data.apply_translation,
            translation_x=cfg.data.translation_x, translation_y=cfg.data.translation_y, translation_z=cfg.data.translation_z,
            point_levels_3d=cfg.data.point_levels_3d, point_range_3d=cfg.data.point_range_3d,
            multi_process=cfg.data.multi_process, num_workers=cfg.data.num_workers, all_signals=cfg.data.all_signals,
            use_body_mask = cfg.learning.model.use_body_mask, signal_norm='all'
            )


## Test Model

In [None]:
# Copy of the test set whose case files are ordered by the project helper `sort_filenames`.
# The noise lines below are commented out, so the signals are currently unchanged.
# NOTE(review): only `case_files` is re-ordered; confirm the per-sample data follows the same order.
test_dataset_noise = copy.deepcopy(test_dataset)
test_dataset_noise.case_files = sort_filenames(test_dataset_noise.case_files)
# test_dataset_noise.signals = test_dataset_noise.signals + np.random.normal(0, 0.5, test_dataset_noise.signals.shape)
# test_dataset_noise.signals = torch.from_numpy(np.random.normal(0, 1., test_dataset_noise.signals.shape))

# test_dataset.signals = torch.zeros_like(test_dataset.signals)

In [None]:
# Inference on the whole copied test set (batch_size=1, point_chunks=40).
targets_model, preds_model = testing(model, test_dataset_noise, batch_size=1, device=cfg.learning.training.device, model_3d=cfg.data.model_3d, wandb_log=False, point_chunks=40)

In [None]:
def pred_postprocess(preds, targets, resolution=512, nlevel=4, nres=4,
                     lung_value=0.2, erosion_size=30):
    """Reshape flat predictions/targets into per-case grids and build eroded lung masks.

    The lung mask comes from the first slot along axis 1 (``targets[:, 0]``), where lung
    pixels equal ``lung_value``. It is shared by all ``nres`` slots and eroded with an
    ``erosion_size`` x ``erosion_size`` square, presumably to keep the evaluation region
    away from the lung boundary.

    Args:
        preds, targets: torch tensors with ``B * nres * nlevel * resolution**2`` elements.
        resolution: Image side length in pixels.
        nlevel: Number of slice levels per case.
        nres: Number of variants per case along axis 1. The plotting cells treat these
            as lung-conductivity settings.
        lung_value: Target value that marks lung pixels in slot 0 (default 0.2).
        erosion_size: Side length of the square structuring element (default 30).

    Returns:
        ``(preds, targets, lung_masks)`` as numpy arrays of shape
        ``(B, nres, nlevel, resolution, resolution)``; ``lung_masks`` is uint8 (0/1).
    """
    targets = targets.detach().cpu().reshape(-1, nres, nlevel, resolution, resolution)
    preds = preds.detach().cpu().reshape(-1, nres, nlevel, resolution, resolution)

    # Lung pixels of slot 0, compared with a tolerance rather than exact float equality.
    slot0 = targets[:, 0]
    lung = torch.isclose(slot0, torch.full_like(slot0, lung_value))
    lung = lung.unsqueeze(1).expand(-1, nres, -1, -1, -1).reshape(-1, resolution, resolution).numpy()

    # Erode into a preallocated uint8 array. This avoids building a tensor from a list of
    # ndarrays, which is slow, and it also works for an empty batch.
    structure = np.ones((erosion_size, erosion_size))
    shrunken = np.zeros(lung.shape, dtype=np.uint8)
    for k, mask in enumerate(lung):
        shrunken[k] = binary_erosion(mask, structure=structure)
    shrunken = shrunken.reshape(-1, nres, nlevel, resolution, resolution)
    return preds.numpy(), targets.numpy(), shrunken

In [None]:
def plot_lung_contour(img, mask, cond_value, ax, fontsize=20):
    """Outline a lung mask on ``ax`` and annotate the mean error inside it.

    The label shows ``cond_value - mean(img[mask == 1])`` rounded to 4 decimals.
    It is centred horizontally at the bottom edge of the image.

    Args:
        img: 2-D predicted image (plain or masked array) aligned with ``mask``.
        mask: 2-D 0/1 lung mask.
        cond_value: Reference (ground-truth) lung conductivity.
        ax: Matplotlib axes to draw on.
        fontsize: Base font size; the label uses ``fontsize - 5``.
    """
    # Only the arguments are used; no notebook globals are read.
    for contour in measure.find_contours(mask, level=0.5):
        ax.plot(contour[:, 1], contour[:, 0], linewidth=2, color='red')
    center_x, center_y = mask.shape[1] / 2, mask.shape[0]
    average_pixel_value = np.mean(img[mask == 1])
    ax.text(center_x, center_y, f'Error: \n{(np.round(cond_value-average_pixel_value, 4)):.4f}',
            fontsize=fontsize-5, color='white', ha='center', va='center',
            bbox=dict(facecolor='red', alpha=0.8))


In [None]:
# Reshape the flat outputs into (case, variant, level, H, W) grids and build the eroded
# lung masks; this evaluation uses 30 slice levels per case (nlevel=30).
preds, targets, lung_masks = pred_postprocess(preds_model, targets_model, nlevel=30)

In [None]:
preds.shape

In [None]:
targets.shape

In [None]:
# Per case and variant: one figure with a row per level showing the ground truth (left)
# and the prediction with eroded lung mask, contour and error label (right).
fontsize = 20
n_cases = preds.shape[0]
n_res = preds.shape[1]
n_level = preds.shape[2]
for case in range(n_cases):
    # Only the first 4 cases are plotted.
    if case == 4:
        break
    for res in range(n_res):
        fig, axes = plt.subplots(n_level, 2, figsize=(10, int(4*n_level)))
        cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
            
        for level in range(n_level):
            if level == 0:
                axes[level,0].set_title('Ground Truth', fontsize=fontsize)
                axes[level,1].set_title('Model Tomogram', fontsize=fontsize)
            # GT
            t = remove_empty_space(targets[case,res,level])
            axes[level,0].imshow(t, vmin=0, vmax=0.7, cmap=cmap)       
            axes[level,0].axis('off')
            # Pred + Mask
            # Reference lung conductivity: mean target value inside the eroded lung mask.
            cond_value = np.mean(targets[case,res,level][lung_masks[case,res,level]==1])
            p, m = remove_empty_space(preds[case,res,level], lung_masks[case,res,level])
            axes[level,1].imshow(p, vmin=0, vmax=0.7, cmap=cmap)
            axes[level,1].imshow(m, cmap='Greys', alpha=0.3)
            axes[level,1].axis('off')
            plot_lung_contour(p, m, cond_value, axes[level,1])

        # Add colorbar to the figure
        sm = plt.cm.ScalarMappable(cmap=cmap)
        sm.set_clim(0, 0.7)
        cbar = fig.colorbar(sm, cax=cbar_ax)        
        cbar.set_label('Conductivity (S/m)', fontsize=fontsize)
        # NOTE(review): predictions come from test_dataset_noise (re-sorted case_files), but the
        # title indexes test_dataset.cases with case/4 — confirm the title matches the plotted case.
        fig.suptitle(f'{test_dataset.cases[int(case/4)]}', fontsize=fontsize)
        plt.show()
        plt.close(fig)


### All levels and resistivities

In [None]:
# cmap.set_bad(color='none')  # Set the color for masked values to 'none' for transparency

# Overview per case:
# - row 0: ground truth of each conductivity variant at level 0, titled with its mean lung value;
# - column 0: ground truth of each level (variant index 2);
# - cell (level+1, res+1): prediction with eroded lung mask, contour and error label.
for case in range(n_cases):
    # Rows = levels (+ header row), columns = variants (+ header column). This must match
    # the axes[level+1, res+1] indexing below. squeeze=False keeps `axes` 2-D.
    fig, axes = plt.subplots(n_level+1, n_res+1, figsize=(int((n_res+1)*4), int(4*(n_level+1))), squeeze=False)
    cbar_ax = fig.add_axes([0.95, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    axes[0,0].text(0.5, 0.5, 'Ground Truth', fontsize=fontsize, ha='center', va='center', rotation=45)
    axes[0,0].text(0.5, 0., 'Level', fontsize=fontsize-5, ha='center', va='center', rotation=0)
    axes[0,0].text(0.9, 0.3, ' Conductivity', fontsize=fontsize-5, ha='center', va='center', rotation=90)
    axes[0,0].axis('off')

    for level in range(n_level):
        # GT across levels (variant index 2)
        t = remove_empty_space(targets[case,2,level])
        axes[level+1,0].imshow(t, vmin=0, vmax=0.7, cmap=cmap)
        axes[level+1,0].axis('off')
        for res in range(n_res):
            # Reference lung conductivity: mean target inside the eroded lung mask, 2 decimals.
            cond_value = np.round(np.mean(targets[case,res,level][lung_masks[case,res,level]==1]),2)
            # GT across conductivity variants (level 0) in the header row
            if level == 0:
                t = remove_empty_space(targets[case,res,0])
                axes[0,res+1].imshow(t, vmin=0, vmax=0.7, cmap=cmap)
                axes[0,res+1].axis('off')
                axes[0,res+1].set_title(f'{str(cond_value)} S/m', fontsize=fontsize-5)
            p, m = remove_empty_space(preds[case,res,level], lung_masks[case,res,level])
            ax = axes[level+1,res+1]
            ax.imshow(p, vmin=0, vmax=0.7, cmap=cmap)
            ax.imshow(m, cmap='Greys', alpha=0.3)
            plot_lung_contour(p, m, cond_value, ax)
            ax.axis('off')
            ax.set_facecolor('grey')

    # Add colorbar to the figure
    sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_clim(0, 0.7)
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Conductivity (S/m)', fontsize=fontsize-5)

    # Separate the header row and column from the prediction grid. The corners are taken
    # from the first and the last prediction panel, so this works for any grid size.
    top_left = axes[1, 1].get_position()
    bottom_right = axes[n_level, n_res].get_position()

    # Small offsets (figure coordinates) keep the lines just outside the panels.
    horizontal_line_y = top_left.y1 + 0.01
    vertical_line_x = top_left.x0 - 0.01
    fig.add_artist(plt.Line2D([top_left.x0-0.15, bottom_right.x1],
                              [horizontal_line_y, horizontal_line_y],
                              color='black', linewidth=2))
    fig.add_artist(plt.Line2D([vertical_line_x, vertical_line_x],
                              [bottom_right.y0, top_left.y1+0.15],
                              color='black', linewidth=2))

    plt.show()
    plt.close(fig)
    # Only cases 0..10 are plotted.
    if case == 10:
        break

## Test the model on a single case

In [None]:
# fnmatch pattern selecting all test samples of one case.
case = 'case_TCIA_401_0*'
# Query-point downsampling for the single-case inference cells that follow. The noise and
# interpolation cells use these names before the "Load from a file" section sets them again.
downsample_factor = 4
down_resolution = 512//downsample_factor

### Collect all test samples that belong to the selected case

In [None]:
# Collect every test sample whose file name matches `case` and stack each field along a new
# batch axis. This rebinds the notebook-level `points`, `signals`, `electrodes` and `targets`.
# torch.stack raises on an empty list, i.e. when no file matches `case`.
p = []
s = []
e = []
t = []
for i, (points, signals, electrodes, _, targets, _) in enumerate(test_dataset):
    if fnmatch.fnmatch(test_dataset.case_files[i], case):
        p.append(points)
        s.append(signals)
        e.append(electrodes)
        t.append(targets)
        print(test_dataset.case_files[i])
points = torch.stack(p, dim=0)
signals = torch.stack(s, dim=0)
electrodes = torch.stack(e, dim=0)
targets = torch.stack(t, dim=0)

# NOTE(review): `batch` is not used by any later cell.
batch = 4

In [None]:
# Run the model on each collected sample separately, without noise (no downsample_factor_test).
preds_all = []
targets_all = []
# NOTE(review): `noise` is unused here (the noisy variant is commented out below).
noise = torch.randn(signals.shape)*0.5
# No-op assignment.
points = points  

for i in range(points.shape[0]):
    # signals_tmp = signals[i].unsqueeze(0).float() + noise[i].unsqueeze(0).float()
    # signals_tmp = noise[i].unsqueeze(0).float()
    _, pred = testing(model, [signals[i].unsqueeze(0).float(), 
                              electrodes[i].unsqueeze(0).float(), 
                              points[i].unsqueeze(0).float()], 
                              batch_size=1, device=cfg.learning.training.device, wandb_log=False,)
    targets_all.append(targets[i].detach().cpu().numpy().squeeze())
    preds_all.append(pred.detach().cpu().numpy().squeeze())
# Concatenate the per-sample results along axis 0.
preds_all = np.concatenate(preds_all, axis=0)
targets_all = np.concatenate(targets_all, axis=0)

In [None]:
# One figure per sample (axis 0 of preds_all), with one row per level (axis 1): the ground
# truth (left) and the prediction with a 9x9-eroded lung contour and its mean value (right).
fontsize = 20
# Lung pixels (value 0.2) of the first sample; the same region is used for every sample.
lung_mask = targets_all[0]==0.2
for i in range(preds_all.shape[0]):
    # One subplot row per level plotted below; squeeze=False keeps `axes` 2-D for one level.
    n_rows = preds_all.shape[1]
    fig, axes = plt.subplots(n_rows, 2, figsize=(10, int(4*n_rows)), squeeze=False)
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    for level in range(n_rows):
        ax_gt, ax_pred = axes[level, 0], axes[level, 1]
        if level == 0:
            ax_gt.set_title('Ground Truth', fontsize=fontsize)
            ax_pred.set_title('Model Tomogram', fontsize=fontsize)
        target = remove_empty_space(targets_all[i,level])
        ax_gt.imshow(target, vmin=0, vmax=0.7, cmap=cmap)
        ax_gt.axis('off')

        image = preds_all[i,level]
        lung_mask_shrunken = binary_erosion(lung_mask[level], structure=np.ones((9, 9))).astype(np.uint8)
        image, lung_mask_shrunken = remove_empty_space(image, lung_mask_shrunken)
        ax_pred.imshow(image, vmin=0, vmax=0.7, cmap=cmap)

        for contour in measure.find_contours(lung_mask_shrunken, level=0.5):
            ax_pred.plot(contour[:, 1], contour[:, 0], linewidth=2, color='red')

        # Label position: horizontally centred at the bottom edge of the cropped image.
        center_x, center_y = lung_mask_shrunken.shape[1] / 2, lung_mask_shrunken.shape[0]
        # Mean predicted value inside the eroded lung mask.
        average_pixel_value = np.mean(image[lung_mask_shrunken==1])
        ax_pred.text(center_x, center_y, f'Average: \n{average_pixel_value:.2f}', fontsize=fontsize-5,
                     color='white', ha='center', va='center', bbox=dict(facecolor='red', alpha=0.8))
        ax_pred.axis('off')
    # Add colorbar to the figure
    sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_clim(0, 0.7)
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Conductivity (S/m)', fontsize=fontsize)
    # NOTE(review): `i` indexes samples of one case; confirm test_dataset.cases[i] is the right title.
    fig.suptitle(f'{test_dataset.cases[i]}', fontsize=fontsize)
    plt.show()
    plt.close(fig)


In [None]:
# Pixel-wise standard deviation over axis 0 of preds_all (the samples, as indexed above),
# taken on slot 0 of axis 1.
p = preds_all[:,0]
plt.imshow(np.std(p, axis=0), cmap='coolwarm')
plt.colorbar()

### Add Gaussian noise to the input signals

In [None]:
# Same per-sample inference as above, but with Gaussian noise (std 0.5) added to the input
# signals and downsampled query points. Requires `downsample_factor` / `down_resolution`.
preds_all = []
targets_all = []
# Seeded generator, so re-running the cell reproduces the same noisy predictions.
noise = torch.randn(signals.shape, generator=torch.Generator().manual_seed(0))*0.5

for i in range(points.shape[0]):
    signals_tmp = signals[i].unsqueeze(0).float() + noise[i].unsqueeze(0).float()
    # signals_tmp = noise[i].unsqueeze(0).float()
    _, pred = testing(model, [signals_tmp,
                              electrodes[i].unsqueeze(0).float(),
                              points[i].unsqueeze(0).float()],
                              batch_size=1, device=cfg.learning.training.device, wandb_log=False,
                              downsample_factor_test=downsample_factor)
    # NOTE(review): the targets are not downsampled here, yet they are reshaped to
    # down_resolution; confirm they really have down_resolution**2 pixels per slice.
    targets_all.append(targets[i].detach().cpu().numpy().squeeze().reshape(-1, 4, down_resolution, down_resolution))
    preds_all.append(pred.detach().cpu().numpy().squeeze().reshape(-1, 4, down_resolution, down_resolution))
preds_all = np.concatenate(preds_all, axis=0)
targets_all = np.concatenate(targets_all, axis=0)

fontsize = 20
for i in range(preds_all.shape[0]):
    # One subplot row per slot of axis 1 (the loop below); squeeze=False keeps `axes` 2-D.
    n_rows = preds_all.shape[1]
    fig, axes = plt.subplots(n_rows, 2, figsize=(10, int(4*n_rows)), squeeze=False)
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    for level in range(n_rows):
        if level == 0:
            axes[level,0].set_title('Ground Truth', fontsize=fontsize)
            axes[level,1].set_title('Model Tomogram', fontsize=fontsize)
        axes[level,0].imshow(targets_all[i,level], vmin=0, vmax=0.7, cmap=cmap)
        axes[level,0].axis('off')
        axes[level,1].imshow(preds_all[i,level], vmin=0, vmax=0.7, cmap=cmap)
        # axes[level,1].imshow(lung_masks[i,0,level], cmap='Greys', alpha=0.3)
        axes[level,1].axis('off')
    # Add colorbar to the figure
    sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_clim(0, 0.7)
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Conductivity (S/m)', fontsize=fontsize)
    # NOTE(review): confirm test_dataset.cases[i] is the right title for sample i.
    fig.suptitle(f'{test_dataset.cases[i]}', fontsize=fontsize)
    plt.show()
    plt.close(fig)

### Interpolate linearly between the signals of two samples

In [None]:
signals.shape

In [None]:
# Predict for 48 linear blends between the signals of sample 0 and sample 3, keeping the
# 4 output slices (down_resolution x down_resolution) of each step.
# NOTE(review): this rebinds `preds`, which earlier cells use for the pred_postprocess output.
n_interpolation_levels = 48
interpolation_range = np.linspace(0, 1, n_interpolation_levels)
preds = []
for interp in interpolation_range:
    signals_interpolated = signals[0] + (signals[3] - signals[0]) * interp
    # signals_interpolated = interpolate_arrays(signals, interp).unsqueeze(0)
    _, pred = testing(model, [signals_interpolated.unsqueeze(0).float(), 
                                electrodes[0].unsqueeze(0).float(), 
                                points[0].unsqueeze(0).float()], 
                                batch_size=1, device=cfg.learning.training.device, wandb_log=False,
                                downsample_factor_test=downsample_factor)
    pred = pred.detach().cpu().numpy().squeeze().reshape(4, down_resolution, down_resolution)
    preds.append(pred)

In [None]:
# 4 x 12 grid with the first output slice of every interpolation step; the titles label
# the two endpoints (sample 0 -> '0.2 S/m', sample 3 -> '0.05 S/m').
fig, axes = plt.subplots(4, int(n_interpolation_levels/4), figsize=(int(1*(n_interpolation_levels/4)), 4))
for i, ax in enumerate(axes.flatten()):
    ax.imshow(remove_empty_space(preds[i][0]), cmap=cmap, vmin=0, vmax=0.7)
    if i == 0:
        ax.set_title('0.2 S/m', fontsize=10, loc='left')
    elif i == len(interpolation_range)-1:
        ax.set_title('0.05 S/m', fontsize=10, loc='right')
    ax.axis('off')
plt.tight_layout()

In [None]:
# Pixel-wise standard deviation of the first output slice over all interpolation steps.
p = np.array(preds)[:,0]
plt.imshow(np.std(p, axis=0), cmap='coolwarm')
plt.colorbar()

## Load a sample directly from its `.npz` file

In [None]:
def load_data_from_file(path, dataset):
    """Load one processed ``.npz`` sample and normalise it with ``dataset``'s statistics.

    Signals are standardised with ``dataset.train_mean`` / ``dataset.train_std``.
    The x/y coordinates of points and electrodes are min-max scaled to [-1, 1] with
    ``dataset.points_min`` / ``points_max``, and z with ``points_min_z`` / ``points_max_z``.

    Args:
        path: ``.npz`` file with keys ``targets``, ``signals``, ``electrodes``, ``points``.
        dataset: Object providing the normalisation statistics listed above.

    Returns:
        ``(points, signal, electrode, target)``:
        - points ``(N, 3)`` float32;
        - signal reshaped to ``(4, -1)`` float32 (stored as blocks of 4 x 16 x 13 values);
        - electrodes, last axis = (x, y, z), dtype as stored;
        - targets ``(N_t, 1)`` float32.
    """
    # The context manager closes the archive handle that np.load keeps open for .npz files.
    with np.load(path) as file:
        target = torch.from_numpy(file['targets'].reshape(-1, 1))
        signal = torch.from_numpy(file['signals']).reshape(-1, 4, 16, 13)
        electrode = torch.from_numpy(file['electrodes'])
        points = torch.from_numpy(file['points']).reshape(-1, 3)

    signal = (signal - dataset.train_mean) / dataset.train_std
    signal = signal.reshape(4, -1)

    def _scale(values, lo, hi):
        # Min-max scale to [-1, 1].
        return (values - lo) / (hi - lo) * 2 - 1

    electrode[..., :2] = _scale(electrode[..., :2], dataset.points_min, dataset.points_max)
    electrode[..., 2] = _scale(electrode[..., 2], dataset.points_min_z, dataset.points_max_z)
    points[:, :2] = _scale(points[:, :2], dataset.points_min, dataset.points_max)
    points[:, 2] = _scale(points[:, 2], dataset.points_min_z, dataset.points_max_z)
    return points.float(), signal.float(), electrode, target.float()


In [None]:
# Path relative to the notebook's working directory, like the other data paths in this notebook.
points, signals, electrodes, targets = load_data_from_file('data/processed/3d/case_TCIA_401_0/case_TCIA_401_0_15.npz', test_dataset)

In [None]:
# Inference on the sample loaded from file with downsample_factor_test=4. The prediction is
# reshaped to 4 slices of down_resolution x down_resolution.
preds_all = []
targets_all = []
# NOTE(review): `noise`, `preds_all` and `targets_all` are not used in this cell.
noise = torch.randn(signals.shape)*0.5

downsample_factor = 4
down_resolution = 512//downsample_factor

targets = targets.reshape(-1, 512, 512, 1)

_, pred = testing(model, [signals.unsqueeze(0).float(), 
                              electrodes.unsqueeze(0).float(), 
                              points.unsqueeze(0).float()], 
                              batch_size=1, device=cfg.learning.training.device, wandb_log=False,
                              downsample_factor_test=downsample_factor)
# targets_down = targets.reshape(4, down_resolution, down_resolution)
pred = pred.detach().cpu().numpy().squeeze().reshape(4, down_resolution, down_resolution)


In [None]:
# Full-resolution target next to the downsampled prediction for one slice.
level = 0
fig, ax = plt.subplots(1,2)
ax[0].imshow(targets[level], cmap=cmap, vmin=0)
ax[1].imshow(pred[level], cmap=cmap, vmin=0)
ax[0].axis('off')
ax[1].axis('off')

### Sweep the z input coordinate from the upper to the lower body

In [None]:
# Reuse the x/y grid of the first level, set z to each of 48 values from 1 (upper body)
# to -1 (lower body), and predict one slice per value.
n_z_levels = 48
# clone(): reshape/indexing/unsqueeze return views. Without it, writing z below would
# overwrite the z coordinates of the notebook-level `points` tensor.
points_level_new = points.reshape(-1, 512, 512, 3)[0].clone()
z_pred = []
z_levels = torch.linspace(1, -1, n_z_levels)
points_level_new = points_level_new.unsqueeze(0)

for z in z_levels:
    points_level_new[:,:,:,2] = z
    _, pred = testing(model, [signals.unsqueeze(0).float(),
                                electrodes.unsqueeze(0).float(),
                                points_level_new.unsqueeze(0).float()],
                                batch_size=1, device=cfg.learning.training.device, wandb_log=False,
                                downsample_factor_test=downsample_factor)
    z_pred.append(pred.detach().cpu().numpy().squeeze().reshape(down_resolution, down_resolution))
fig, axes = plt.subplots(4, int(n_z_levels/4), figsize=(int(1*(n_z_levels/4)), 4))
for i, ax in enumerate(axes.flatten()):
    ax.imshow(remove_empty_space(z_pred[i]), cmap=cmap, vmin=0)
    if i == 0:
        ax.set_title('Upper Body', fontsize=10, loc='left')
    elif i == len(z_levels)-1:
        ax.set_title('Lower Body', fontsize=10, loc='right')
    ax.axis('off')
plt.tight_layout()



## Get coordinate bounds

In [None]:
# Clear the dataset's `training` flag (presumably disables training-time sampling — confirm in dataset_3d).
train_dataset.training = False

In [None]:
# Bounding box of all body points (target != 0) over the test, val and train splits.
# The bounds start at -inf/+inf. Starting at 0 would report min <= 0 and max >= 0 even when
# all coordinates lie on one side of zero.
max_x = max_y = max_z = float('-inf')
min_x = min_y = min_z = float('inf')

for dataset in (test_dataset, val_dataset, train_dataset):
    for case in tqdm(dataset):
        points = case[0]
        points = points[case[4].reshape(-1) != 0]
        if points.numel() == 0:
            # torch.max / torch.min raise on empty tensors.
            continue
        max_x = max(max_x, torch.max(points[:,0]).item())
        max_y = max(max_y, torch.max(points[:,1]).item())
        max_z = max(max_z, torch.max(points[:,2]).item())
        min_x = min(min_x, torch.min(points[:,0]).item())
        min_y = min(min_y, torch.min(points[:,1]).item())
        min_z = min(min_z, torch.min(points[:,2]).item())

print(max_x, max_y, max_z, min_x, min_y, min_z)

## Test Model with Training Data

In [None]:
# Inference on the whole training set.
targets, preds = testing(model, train_dataset, batch_size=cfg.learning.testing.batch_size_test, device=cfg.learning.training.device, wandb_log=False)


In [None]:
# Reshape to (sample, 4, level, H, W) at the test resolution and derive the masks.
test_resolution = cfg.data.resolution//cfg.learning.testing.downsample_factor_test
targets_case = targets.detach().cpu().numpy().squeeze().reshape(-1, 4, cfg.data.point_levels_3d, test_resolution, test_resolution)
preds_case = preds.detach().cpu().numpy().squeeze().reshape(-1, 4, cfg.data.point_levels_3d, test_resolution, test_resolution)

body_masks = targets_case > 0
# Lung pixels: target conductivity within [0.05, 0.2].
lung_masks = (targets_case <= 0.2) * (targets_case >= 0.05)
# Lung pixels where the prediction is also <= 0.25.
eval_lung_masks = lung_masks * (preds_case<=0.25)

In [None]:
# Training-set predictions: one figure per sample, one row per level (variant 0).
fontsize = 20
for i in range(preds_case.shape[0]):
    # squeeze=False keeps `axes` 2-D even with a single level.
    fig, axes = plt.subplots(preds_case.shape[2], 2, figsize=(10, int(4*cfg.data.point_levels_3d)), squeeze=False)
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    for level in range(preds_case.shape[2]):
        if level == 0:
            axes[level,0].set_title('Ground Truth', fontsize=fontsize)
            axes[level,1].set_title('Model Tomogram', fontsize=fontsize)
        axes[level,0].imshow(targets_case[i,0,level], vmin=0, vmax=0.7, cmap=cmap)
        axes[level,0].axis('off')
        axes[level,1].imshow(preds_case[i,0,level], vmin=0, vmax=0.7, cmap=cmap)
        # axes[level,1].imshow(lung_masks[i,0,level], cmap='Greys', alpha=0.3)
        axes[level,1].axis('off')
    # Add colorbar to the figure
    sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_clim(0, 0.7)
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Conductivity (S/m)', fontsize=fontsize)
    # Title from the dataset that was evaluated (training set, not the test set).
    fig.suptitle(f'{train_dataset.cases[i]}', fontsize=fontsize)
    plt.show()
    plt.close(fig)
    if i == 10:
        break


## Inference on one case

In [None]:
case = 'case_TCIA_10_0'

In [None]:
from data_processing.dataset import combine_electrode_positions

In [None]:
data = np.load('data/processed/'+case+'.npz')
signals = torch.from_numpy(data['signals'])
electrodes = torch.from_numpy(data['electrodes'])
points = torch.from_numpy(data['points'])
targets = torch.from_numpy(data['targets'])

# normalize
signals = (signals-train_dataset.train_mean) / train_dataset.train_std
points = points[:,:,:2]
points = ((points - train_dataset.points_min) / (train_dataset.points_max - train_dataset.points_min)) * 2 - 1
electrodes[:,:,:2] = ((electrodes[:,:,:2] - train_dataset.points_min) / (train_dataset.points_max - train_dataset.points_min)) * 2 - 1
electrodes = combine_electrode_positions(electrodes)

In [None]:
_, pred = testing(model, [signals.float(), electrodes.float(), points.float()], batch_size=1, device=cfg.learning.training.device, wandb_log=False)
# Reshape to (sample, 4, 4, 512, 512) and swap axes 1 and 2.
targets = targets.detach().cpu().numpy().squeeze().reshape(-1, 4, 4, 512, 512)
targets = np.moveaxis(targets,1,2)
preds_case = pred.detach().cpu().numpy().squeeze().reshape(-1, 4, 4, 512, 512)
preds_case = np.moveaxis(preds_case,1,2)

# SIRT reference tomograms (levels 1 and 2) of the same case. The path is built from `case`
# so the comparison always uses the case loaded above.
threshold = 50
tomogram_dir = f'data/raw/{case}/tomograms_rad'
tomogram = np.array([read_egt(f'{tomogram_dir}/level_1_15_radweight_1.egt'),
                     read_egt(f'{tomogram_dir}/level_2_15_radweight_1.egt')])
# Clip values above `threshold`, then take the reciprocal. Zeros stay 0, and no
# divide-by-zero warning is raised for them.
tomogram = np.where(tomogram>threshold, threshold, tomogram)
tomogram = np.divide(1.0, tomogram, out=np.zeros(tomogram.shape, dtype=np.result_type(tomogram, 1.0)),
                     where=tomogram != 0)


In [None]:
# Compare ground truth, model prediction (variant index 2) and the SIRT tomogram for two levels.
# Titles/labels are German: "true resistivity distribution", "reconstruction AI model",
# "reconstruction physical model", "conductivity (S/m)".
fontsize = 10
for i in range(preds_case.shape[0]):
    fig, axes = plt.subplots(2, 3, figsize=(6, 4))
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    for level in range(2):
        axes[level,0].imshow(targets[i,2,level], vmin=0, vmax=0.7, cmap=cmap)
        axes[level,0].axis('off')
        axes[level,1].imshow(preds_case[i,2,level], vmin=0, vmax=0.7, cmap=cmap)
        axes[level,1].axis('off')
        axes[level,2].imshow(tomogram[level], vmin=0, vmax=0.7, cmap=cmap)
        axes[level,2].axis('off')
        if level==0:
            axes[level,0].set_title('Wahre \n Widerstandsverteilung', fontsize=fontsize)
            axes[level,1].set_title('Rekonstruktion \n KI-Modell', fontsize=fontsize)
            axes[level,2].set_title('Rekonstruktion \n physikalisches Modell', fontsize=fontsize)
    # Add colorbar to the figure
    sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_clim(0, 0.7)
    cbar = fig.colorbar(sm, cax=cbar_ax)        
    cbar.set_label('Leitfähigkeit (S/m)', fontsize=fontsize)
    # fig.suptitle(f'{test_dataset.cases[i]}', fontsize=fontsize)
    plt.show()
    plt.close(fig)
    if i == 10:
        break


In [None]:
preds_case = pred.detach().cpu().numpy().squeeze().reshape(-1, 4, 4, 512, 512)
preds_case = np.moveaxis(preds_case,1,2)
fontsize = 20
# Number of slices plotted per figure (one subplot row each).
n_slices = 4
for i in range(preds_case.shape[0]):
    # One row per plotted slice, so every axes[resistancy] below exists.
    fig, axes = plt.subplots(n_slices, 1, figsize=(5, 16))
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    for resistancy in range(n_slices):
        if resistancy == 0:
            axes[resistancy].set_title('Model Tomogram', fontsize=fontsize)
        axes[resistancy].imshow(preds_case[i,2,resistancy], vmin=0, vmax=0.7, cmap=cmap)
        axes[resistancy].axis('off')
    # Add colorbar to the figure
    sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_clim(0, 0.7)
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Conductivity (S/m)', fontsize=fontsize)
    # NOTE(review): these predictions are for one loaded case; confirm test_dataset.cases[i] is the right title.
    fig.suptitle(f'{test_dataset.cases[i]}', fontsize=fontsize)
    plt.show()
    plt.close(fig)
    if i == 10:
        break


# Inference on training data

In [None]:
train_dataset.targets.shape

In [None]:
# Predict the first training sample only (x/y of its query points), with subsampling disabled.
train_dataset.apply_subsampling = False
train_dataset.training = False
data = [train_dataset.signals[0].unsqueeze(0).float(), train_dataset.electrodes[0].unsqueeze(0).float(), train_dataset.points[0,:,:2].unsqueeze(0).float()]
_, preds = testing(model, data, batch_size=cfg.learning.testing.batch_size_test, device=cfg.learning.training.device, wandb_log=False)
targets = train_dataset.targets[0]

In [None]:
# Reshape to (-1, 1, 1, 512, 512).
targets_case = targets.detach().cpu().numpy().squeeze().reshape(-1, 1, 1, 512, 512)
preds_case = preds.detach().cpu().numpy().squeeze().reshape(-1, 1, 1, 512, 512)

# Body masks of all training samples, resized to 512 x 512 with nearest-neighbour interpolation.
body_masks = [cv2.resize(mask.numpy(), (512, 512), interpolation=cv2.INTER_NEAREST) for mask in train_dataset.masks]
body_masks = np.stack(body_masks, 0).reshape(-1, 1, 1, 512, 512)

# Lung pixels (target within [0.05, 0.2]) of all training targets.
lung_masks = (train_dataset.targets <= 0.2) * (train_dataset.targets >= 0.05)
lung_masks = lung_masks.reshape(-1, 1, 1, 512, 512)
# eval_lung_masks = lung_masks * (preds_case<=0.25)

In [None]:
fontsize = 20
for i in range(preds_case.shape[0]):
    # One row per slot along axis 2 of the arrays reshaped above (size 1 there), so the
    # indices below stay in range. squeeze=False keeps `axes` 2-D for a single row.
    n_rows = preds_case.shape[2]
    fig, axes = plt.subplots(n_rows, 2, figsize=(10, 4*n_rows), squeeze=False)
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
    for resistancy in range(n_rows):
        if resistancy == 0:
            axes[resistancy,0].set_title('Ground Truth', fontsize=fontsize)
            axes[resistancy,1].set_title('Model Tomogram', fontsize=fontsize)
        axes[resistancy,0].imshow(targets_case[i,0,resistancy], vmin=0, vmax=0.7, cmap=cmap)
        axes[resistancy,0].axis('off')
        axes[resistancy,1].imshow(preds_case[i,0,resistancy], vmin=0, vmax=0.7, cmap=cmap)
        # axes[resistancy,1].imshow(-1*body_masks[i,0,resistancy], cmap='Reds', alpha=0.2)
        axes[resistancy,1].imshow(lung_masks[i,0,resistancy], cmap='Greys', alpha=0.3)
        # axes[resistancy,1].imshow(eval_lung_masks[i,0,resistancy], cmap='Greys', alpha=0.3)
        axes[resistancy,1].axis('off')
    # Add colorbar to the figure
    sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_clim(0, 0.7)
    cbar = fig.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Conductivity (S/m)', fontsize=fontsize)
    # The plotted sample comes from the training set.
    fig.suptitle(f'{train_dataset.cases[i]}', fontsize=fontsize)
    plt.show()
    plt.close(fig)
    if i == 10:
        break


In [None]:
targets_case = targets.detach().cpu().numpy().squeeze().reshape(-1, 4, 4, 512, 512)
preds_case_std = np.clip(preds.detach().cpu().numpy().squeeze().reshape(-1, 4, 4, 512, 512), 0, 0.7)
# Std over axis 2 of the clipped predictions.
preds_case_std = preds_case_std.std(axis=(2))
fontsize = 10
n_rows = preds_case_std.shape[0]
# One shared colour scale for every std panel, so the colorbar matches the images.
std_norm = plt.Normalize(vmin=float(preds_case_std.min()), vmax=float(preds_case_std.max()))
fig, axes = plt.subplots(n_rows, 2, figsize=(10, 80), squeeze=False)
# The colorbar axes is created once per figure.
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
for i in range(n_rows):
    axes[i,0].imshow(targets_case[i,0,0], vmin=0, vmax=0.7, cmap=cmap)
    axes[i,0].axis('off')
    axes[i,1].imshow(preds_case_std[i,0], cmap='coolwarm', norm=std_norm)
    axes[i,1].axis('off')
# Add colorbar to the figure
sm = plt.cm.ScalarMappable(norm=std_norm, cmap='coolwarm')
cbar = fig.colorbar(sm, cax=cbar_ax)
cbar.set_label('Standard Deviation of Specific Conductivity (S/m)', fontsize=fontsize)
# fig.suptitle(f'{test_dataset.cases[i]}', fontsize=fontsize)
plt.show()
plt.close(fig)


In [None]:
# Test targets as a 4 x 4 grid per sample: rows labelled by lung resistivity, panels titled by level.
# images
t = test_dataset.targets.reshape(-1, 4, 4, 512, 512)
t = t.moveaxis(2, 1)
t = t.reshape(-1, 16, 512, 512)
# level
l = test_dataset.levels.reshape(-1, 4, 4).moveaxis(2,1).reshape(-1, 16)
# electrodes in pixel coordinates (only used by the commented-out scatter below)
e = test_dataset.electrodes.reshape(-1, 4, 4, 16, 3).moveaxis(2, 1).reshape(-1, 16, 16, 3)
e = (e[:, :, :, :2] + 1) * 256

# Row labels keyed by the flattened index of each row's first panel (0, 4, 8, 12).
row_labels = {0: 'Lung 5 Ohm', 4: 'Lung 10 Ohm', 8: 'Lung 15 Ohm', 12: 'Lung 20 Ohm'}

for i in range(t.shape[0]):
    fig, axes = plt.subplots(4, 4, figsize=(10, 11))
    for resistancy, ax in enumerate(axes.flatten()):
        if resistancy in row_labels:
            ax.set_ylabel(row_labels[resistancy])
        ax.set_title(f'Level {str(l[i,resistancy].numpy()*-1+3)}')
        ax.imshow(t[i,resistancy], cmap=cmap)
        # axes.flatten()[resistancy].scatter(e[i,resistancy,:,0], e[i,resistancy,:,1], marker='x', c='r')
        # Hide the ticks but keep the axes, so the row y-labels stay visible
        # (axis('off') would hide them too).
        ax.set_xticks([])
        ax.set_yticks([])
    # fig.suptitle(f'{test_dataset.cases[i]}', fontsize=12)
    plt.show()
    plt.close(fig)
    if i == 20:
        break

In [None]:
# Maximum of the training-set signals (quick range check).
train_dataset.signals.max()

# Use real (measured) data

In [None]:
# Measured signals of a real acquisition: the first 208 values (16 x 13, see the reshape
# below), scaled by 1000. NOTE(review): presumably a unit conversion — confirm what the model expects.
real_signals = read_get('data/raw/case_real/Tag1/Tag_01_Msg_19_SF_1_U_top_I_top_372-459_min_mean.get')[:208]*1000

In [None]:
plt.plot(real_signals)

## Get electrode positions from ME0*

In [None]:
electrodes = read_mat('/home/nibdombe/deep_eit/data/raw/case_0/electrodes/electrodes.mat').reshape(-1, 16, 3)

In [None]:
real_electrodes = electrodes[0]

In [None]:
plt.scatter(real_electrodes[:,0], real_electrodes[:,1], marker='x', c='r')

## Normalize signals and electrode positions

In [None]:
# Normalise the measured signals and electrode positions with the dataset statistics.
mean_signals = test_dataset.train_mean
std_signals = test_dataset.train_std
real_signals = (torch.from_numpy(real_signals) - mean_signals) / std_signals

# All three coordinates are scaled with the x/y bounds; z is then overwritten with 1.
min_xy = test_dataset.points_min
max_xy = test_dataset.points_max
real_electrodes = (real_electrodes - min_xy) / (max_xy - min_xy) * 2 - 1
real_electrodes = torch.from_numpy(real_electrodes.reshape(-1, 16, 3))
real_electrodes[:,:,2] = 1 

# Add a batch axis, and build a 512 x 512 grid of 2-D query points.
real_signals = real_signals.reshape(1, 16, 13).float()
real_electrodes = real_electrodes.reshape(1, 16, 3).float()
points = generate_points(resolution=512)
points = points.reshape(1, -1, 2).float()

In [None]:
pred = testing(model, data=[real_signals, real_electrodes, points], batch_size=1, device='cuda:0', wandb_log=False)[1].reshape(512, 512)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])  # [left, bottom, width, height]
ax.imshow(pred, vmin=0.0, vmax=0.7, cmap=cmap)
ax.scatter((real_electrodes[0,:,0]+1)*256, (real_electrodes[0,:,1]+1)*256, marker='x', c='r')
ax.axis('off')
# Add colorbar to the figure
sm = plt.cm.ScalarMappable(cmap=cmap)
sm.set_clim(0, 0.7)
cbar = fig.colorbar(sm, cax=cbar_ax)        
cbar.set_label('Specific Conductivity (S/m)', fontsize=20)
plt.show()

In [None]:
tomogram = read_egt('/home/nibdombe/deep_eit/data/raw/case_11/tomograms_kf/level_1_20.egt')

In [None]:
plt.imshow(1/tomogram, cmap=cmap, vmin=0., vmax=2)

In [None]:
gt = read_mat('/home/nibdombe/deep_eit/data/raw/case_11/targets/level_1_20.mat', targets=True)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
plt.imshow(gt.T, cmap=cmap, vmin=0., vmax=0.7)
ax.axis('off')

In [None]:
# Scratch check: 4 evenly spaced values between n[1] and n[-2] of 0..29.
# NOTE(review): numpy is already imported at the top; this cell looks like leftover exploration.
import numpy as np
n = np.arange(30)
np.linspace(n[1],n[-2],4)

In [None]:
np.linspace(n[1],n[-2],4)