In [1]:
# Notebook-level switches.
# do_skip: when True, every `if not do_skip:` cell below (patch extraction,
# preview plots, saving to disk) is skipped.
do_skip=False
# NOTE(review): the three plotting flags below are not read anywhere in the
# visible part of this notebook -- presumably consumed by later cells; confirm.
do_plot=True
do_stats_plot=True
do_optional_plots=False

## 0 - Import libraries

In [11]:
import os
import numpy as np
import torch
torch.backends.cudnn.benchmark = True

from utils import mk_dir
from tqdm import tqdm
from torchvision.utils import save_image

## IMPORT FUNCTIONS

In [12]:
import matplotlib.pyplot as plt
# Switch matplotlib to interactive mode for the plotting helpers defined below.
plt.ion() 

def save_img(tensor, name, norm, n_rows=16, scale_each=False):
    """Write `tensor` to the file `name` as an image grid using `save_image`.

    `n_rows` is forwarded as `nrow`, `norm` as `normalize` and `scale_each`
    unchanged; the grid always uses a padding of 5 with pad value 1.
    """
    grid_options = dict(
        nrow=n_rows,
        padding=5,
        normalize=norm,
        pad_value=1,
        scale_each=scale_each,
    )
    save_image(tensor, name, **grid_options)
    
def plot_10_patches(img_np, true_vessel_np=False, indexes=None):
    """Show a two-column preview figure with one row per selected slice.

    Left column: ``img_np[index]`` drawn in gray with a fixed [0, 1] range.
    Right column: ``true_vessel_np[index]`` when ``true_vessel_np`` is given,
    otherwise the following slice ``img_np[index + 1]``.

    Parameters
    ----------
    img_np : array indexed as ``img_np[index, :, :]``.
    true_vessel_np : array indexed the same way, or ``False`` (default) to
        compare consecutive slices of ``img_np`` instead.
    indexes : sequence of int, optional
        Slices to show. When ``None``, five indices are drawn with
        ``np.random.randint`` (global NumPy RNG, repeats possible).

    Notes
    -----
    NOTE(review): the left panel is titled 'True Vessel Slice' although it
    shows ``img_np``, and the right panel 'Image Slice' although it shows
    ``true_vessel_np``. The calls visible in this notebook pass the vessel
    array first, so the titles are left untouched -- confirm the intent.
    """
    random_indexes = np.random.randint(0, img_np.shape[0], size=5) if indexes is None else indexes
    has_reference = true_vessel_np is not False

    n = len(random_indexes)
    last = img_np.shape[0] - 1

    # squeeze=False keeps `ax` two-dimensional even for a single row, so the
    # ax[i, j] indexing below also works when only one index is requested.
    fig, ax = plt.subplots(n, 2, figsize=(6, 3*n), squeeze=False)

    for i, index in enumerate(random_indexes):

        # Tolerate a one-past-the-end index by falling back to the last slice.
        index = min(index, last)
        if not has_reference:
            # The right panel shows slice index+1, so keep that lookup in
            # range; only this mode needs the index moved back.
            index = max(min(index, last - 1), 0)

        ax[i, 0].set_title(f'True Vessel Slice {index}')
        ax[i, 0].imshow(img_np[index,:,:], cmap='gray', vmin=0, vmax=1)
        ax[i, 0].axis('off')

        if has_reference:
            print(f"Shape: {true_vessel_np[index,:,:].shape}, Max: {true_vessel_np[index,:,:].max()}, Min: {true_vessel_np[index,:,:].min()}")
            ax[i, 1].imshow(true_vessel_np[index,:,:], cmap='gray')
            ax[i, 1].set_title(f'Image Slice {index}')
        else:
            next_index = min(index + 1, last)
            ax[i, 1].set_title(f'Image Slice {next_index}')
            ax[i, 1].imshow(img_np[next_index,:,:], cmap='gray', vmin=0, vmax=1)

        ax[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

def plot_n_patches_overlap(img_np, true_vessel_np=False, indexes=None, selected_class=None, add_title='', m=5, alpha=0.5, save_dir='clusters_imgs'):
    """Draw ``true_vessel_np`` slices with a semi-transparent red overlay.

    For every index, ``true_vessel_np[index]`` is drawn in gray and drawn
    again with the 'Reds' colormap wherever ``img_np[index]`` is non-zero.

    Parameters
    ----------
    img_np : array indexed as ``img_np[index, :, :]``; its zero entries mask
        the overlay out.
    true_vessel_np : array indexed the same way. Required despite the
        ``False`` default.
    indexes : sequence of int, optional
        Slices to draw. When ``None``, five random indices are drawn with
        ``np.random.randint``.
    selected_class : value embedded in the output file name.
    add_title : str
        When non-empty the figure is saved as
        ``{add_title}_class_{selected_class}_patches.png`` and closed;
        when empty the figure is shown instead.
    m : int
        Number of panels per row (also the number of columns).
    alpha : float
        Opacity of the red overlay.
    save_dir : str
        Sub-directory of the module-level ``save_data_path`` used for output.
        It is created through ``mk_dir`` on every call.

    Raises
    ------
    TypeError
        If ``true_vessel_np`` is left at its ``False`` default.
    """
    if true_vessel_np is False:
        raise TypeError("plot_n_patches_overlap requires true_vessel_np")

    save_dir = os.path.join(save_data_path,save_dir)
    mk_dir(save_dir)

    if indexes is None:
        # Same fallback as plot_10_patches: pick five slices at random.
        indexes = np.random.randint(0, img_np.shape[0], size=5)

    plt.ioff()
    n = len(indexes)
    # m is the number of panels per row, i.e. the number of columns
    n_rows = int(np.ceil(n/m))
    n_cols = m

    # squeeze=False always returns an array of axes, so flatten() also works
    # for a 1x1 grid (m == 1 with a single index).
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False)
    ax = ax.flatten()

    for i, index in enumerate(indexes):
        ax[i].set_title(f'Overlay for Slice {index}')
        masked_image = np.ma.masked_where(img_np[index,:,:] == 0, true_vessel_np[index,:,:])
        # Gray base layer first, then the masked copy in red on top of it
        ax[i].imshow(true_vessel_np[index,:,:], cmap='gray', interpolation='none')
        ax[i].imshow(masked_image, cmap='Reds', alpha=alpha)

        ax[i].axis('off')

    try:
        plt.tight_layout()
    except Exception as exc:
        # Layout is best-effort: keep the default layout but report why.
        print(f"tight_layout failed, keeping the default layout: {exc}")

    if add_title!='':
        print(f"Save Fig to {save_dir}")
        plt.savefig(os.path.join(save_dir,f'{add_title}_class_{selected_class}_patches.png'))
        plt.close(fig)

    else:
        print("Plotting..")
        plt.show()
    plt.ion()

    
def plot_n_patches(img_np, true_vessel_np=False, indexes=None, selected_class=None, add_title='', m=5):
    """Plot pairs of panels for the selected slices, ``m`` pairs per row.

    Each pair shows ``img_np[index]`` next to ``true_vessel_np[index]`` when
    ``true_vessel_np`` is given, or next to the following slice
    ``img_np[index + 1]`` otherwise.

    Parameters
    ----------
    img_np : array indexed as ``img_np[index, :, :]``.
    true_vessel_np : array indexed the same way, or ``False`` (default).
    indexes : sequence of int, optional
        Slices to plot. When ``None``, five random indices are drawn with
        ``np.random.randint``.
    selected_class : value embedded in the output file name.
    add_title : str
        When non-empty the figure is saved in the current working directory
        as ``{add_title}_class_{selected_class}_patches.png`` and closed;
        when empty the figure is shown instead.
    m : int
        Number of pairs per row (the grid has ``2 * m`` columns).
    """
    if indexes is None:
        # Same fallback as plot_10_patches: pick five slices at random.
        indexes = np.random.randint(0, img_np.shape[0], size=5)
    has_reference = true_vessel_np is not False

    plt.ioff()
    n = len(indexes)
    if n > 1000:
        print(f'WARNING: YOU ARE TRYING TO PLOT {n} images')
    # m is the number of images to plot in each row (2m is the number of columns)
    n_rows = int(np.ceil(n/m))
    n_cols = 2*m
    last = img_np.shape[0] - 1

    fig, ax = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows))
    ax = ax.flatten()

    for i, index in enumerate(indexes):

        # Tolerate a one-past-the-end index by falling back to the last slice.
        index = min(index, last)
        if not has_reference:
            # The second panel shows slice index+1, so keep that lookup in
            # range; only this mode needs the index moved back.
            index = max(min(index, last - 1), 0)

        ax[2*i].set_title(f'Image Slice {index}')
        ax[2*i].imshow(img_np[index,:,:], cmap='gray', vmin=0, vmax=1)
        ax[2*i].axis('off')
        print(f"Shape: {img_np[index,:,:].shape}, Max: {img_np[index,:,:].max()}, Min: {img_np[index,:,:].min()}")
        if has_reference:
            ax[2*i + 1].imshow(true_vessel_np[index,:,:], cmap='gray')
            ax[2*i + 1].set_title(f'True Vessel Slice {index}')

        else:
            next_index = min(index + 1, last)
            ax[2*i + 1].set_title(f'Image Slice {next_index}')
            ax[2*i + 1].imshow(img_np[next_index,:,:], cmap='gray', vmin=0, vmax=1)

        ax[2*i+1].axis('off')

    plt.tight_layout()
    if add_title!='':
        print("Save Fig")
        plt.savefig(f'{add_title}_class_{selected_class}_patches.png')
        plt.close(fig)

    else:
        print("Plotting..")
        plt.show()
    plt.ion()
    
def reshape_to_square(vector):
    """Zero-pad a 1-D vector at its end and reshape it into an (n, n) matrix.

    ``n`` is the smallest integer whose square is at least ``len(vector)``;
    the padded entries fill the tail of the last row(s).
    """
    length = len(vector)
    side = int(np.ceil(np.sqrt(length)))

    # Number of trailing zeros needed to reach side * side elements
    missing = side * side - length

    return np.pad(vector, (0, missing), mode='constant').reshape((side, side))

In [13]:
def extract_img_vessels_np2(dataset_img_dir):
    """Load and stack all image patch files and their label files from a directory.

    Every file in ``dataset_img_dir`` whose name contains ``'32_img.npy'`` is
    loaded with ``np.load``; its label file is the same name with ``'img'``
    replaced by ``'label'`` in the same directory. Files are visited in
    sorted name order so the result does not depend on the order returned by
    ``os.listdir``.

    Parameters
    ----------
    dataset_img_dir : str
        Directory containing the ``*32_img.npy`` files and their labels.

    Returns
    -------
    (img_list_np, vess_list_np)
        Images and labels, each concatenated along axis 0 in matching order.

    Raises
    ------
    FileNotFoundError
        If the directory contains no ``'32_img.npy'`` file.
    """
    print(f"Extracting from {dataset_img_dir}")
    img_chunks = []
    vess_chunks = []
    for img_name in tqdm(sorted(os.listdir(dataset_img_dir))):
        if '32_img.npy' not in img_name:
            continue
        # Load the .npy image patches and the matching labels. The rename is
        # applied to the file name only, so an 'img' in the directory path
        # is never rewritten.
        label_name = img_name.replace('img','label')
        img_chunks.append(np.load(os.path.join(dataset_img_dir,img_name)))
        vess_chunks.append(np.load(os.path.join(dataset_img_dir,label_name)))

    if not img_chunks:
        raise FileNotFoundError(f"No '32_img.npy' files found in {dataset_img_dir}")

    # One concatenation at the end instead of re-copying the growing array
    # for every file.
    img_list_np = np.concatenate(img_chunks, axis=0)
    vess_list_np = np.concatenate(vess_chunks, axis=0)

    print(f"Extracted {img_list_np.shape[0]} images and {vess_list_np.shape[0]} labels")

    return img_list_np, vess_list_np

def extract_empty_img_vessels(dataset_empty_img_dir):
    """Load and stack all image patch files from a directory.

    Every file in ``dataset_empty_img_dir`` whose name contains
    ``'32_img.npy'`` is loaded with ``np.load``. Files are visited in sorted
    name order so the result does not depend on the order returned by
    ``os.listdir``.

    Parameters
    ----------
    dataset_empty_img_dir : str
        Directory containing the ``*32_img.npy`` files.

    Returns
    -------
    img_list_np
        All loaded arrays concatenated along axis 0.

    Raises
    ------
    FileNotFoundError
        If the directory contains no ``'32_img.npy'`` file.
    """
    print(f"Extracting from {dataset_empty_img_dir}")
    img_chunks = []
    for img_name in tqdm(sorted(os.listdir(dataset_empty_img_dir))):
        if '32_img.npy' not in img_name:
            continue
        # Load the .npy image patches
        img_chunks.append(np.load(os.path.join(dataset_empty_img_dir,img_name)))

    if not img_chunks:
        raise FileNotFoundError(f"No '32_img.npy' files found in {dataset_empty_img_dir}")

    # One concatenation at the end instead of re-copying the growing array
    # for every file.
    img_list_np = np.concatenate(img_chunks, axis=0)
    print(f"Extracted {img_list_np.shape[0]} images")

    return img_list_np


In [14]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import KDTree

def label_point(x, y, ids, ax):
    """Write each entry of `ids` next to its (x, y) point on `ax`."""
    for position, label in enumerate(ids):
        anchor = (x[position], y[position])
        ax.annotate(label, anchor)

def interactive_plot(x, y, ids=None, colors=None, action='click', img_list=None, emb_list=None, true_img_list=None ,zoom=False, filtered_ids=[]):
    """Scatter-plot (x, y) and show the sample behind the point under the cursor.

    Panel 0 holds the scatter plot. When the mouse event falls close enough
    to a point, panels 1-3 show ``img_list``, the embedding from ``emb_list``
    reshaped to a square, and ``true_img_list`` for that sample.

    Parameters
    ----------
    x, y : sequences of point coordinates.
    ids : sequence, optional
        Per-point identifiers; ``int(ids[i])`` is used to index the three
        lists. Defaults to ``'0', '1', ...``.
    colors : sequence, optional
        Per-point colour values passed to ``scatter``; also shown in the
        image title.
    action : {'click', 'hover'}
        Event that triggers the lookup: mouse button press or mouse motion.
    img_list, emb_list, true_img_list : indexable collections shown in
        panels 1, 2 and 3.
    zoom : bool
        When True a 2x3 grid is used and panel 4 repeats the scatter plot
        restricted to [-100, 100] on both axes.
    filtered_ids : sequence
        When non-empty, ``filtered_ids[i]`` replaces ``ids[i]`` in the panel
        titles. The list default is only read, never mutated.
    """
    if ids is None:
        ids = [str(i) for i in range(len(x))]
    if zoom:
        fig, ax = plt.subplots(2, 3, figsize=(12, 8))
    else:
        fig, ax = plt.subplots(1, 4, figsize=(16,4))
    ax = ax.flatten()

    detail_axes = (ax[1], ax[2], ax[3])
    for panel in detail_axes:
        panel.axis('off')

    if colors is not None:
        ax[0].scatter(x, y, s=1, c=colors, cmap='viridis')
    else:
        ax[0].scatter(x, y, s=1 if zoom else 0.1)
    # Second scatter panel with fixed limits
    if zoom:
        if colors is not None:
            ax[4].scatter(x, y, s=1, c=colors)
        else:
            ax[4].scatter(x, y, s=1)
        ax[4].set_xlim([-100, 100])
        ax[4].set_ylim([-100, 100])
        ax[5].axis('off')
    # Pick radius: 1/80 of the larger data extent. np.max/np.min avoid a
    # Python-level loop over large arrays.
    threshold = max(np.max(x) - np.min(x), np.max(y) - np.min(y)) / 80
    print(f"Threshold: {threshold}")
    tree = KDTree(np.column_stack((x, y)))

    def onclick(event):
        """Show the sample nearest to the event position, if close enough."""
        if event.inaxes != ax[0]:
            return
        dist, i = tree.query([event.xdata, event.ydata])
        if not dist < threshold:
            return

        # Drop the previously shown images; otherwise every event stacks
        # three more images on the axes (unbounded in 'hover' mode).
        for panel in detail_axes:
            panel.clear()

        sample = int(ids[i])
        id_img_title = ids[i] if len(filtered_ids) == 0 else filtered_ids[i]

        ax[1].imshow(img_list[sample], cmap='gray')
        if colors is not None:
            ax[1].set_title(f'Image {id_img_title} (color: {colors[i]})')
        else:
            ax[1].set_title(f'Image {id_img_title}')
        ax[2].imshow(reshape_to_square(emb_list[sample]), cmap='gray')
        ax[2].set_title(f'Embedding {id_img_title}')
        ax[3].imshow(true_img_list[sample], cmap='gray')
        ax[3].set_title(f'True Image {id_img_title}')

        for panel in detail_axes:
            panel.axis('off')

        # Ask for a redraw explicitly rather than relying on interactive mode.
        fig.canvas.draw_idle()

    fig.tight_layout()
    if action == 'click':
        fig.canvas.mpl_connect('button_press_event', onclick)
    elif action == 'hover':
        fig.canvas.mpl_connect('motion_notify_event', onclick)

    plt.show()

In [15]:
def plot_2_clusters(embeddings_tsne, cluster_labels, cluster_labels_2d):
    """Draw the first two columns of `embeddings_tsne` twice, side by side.

    The left panel is coloured by `cluster_labels`, the right panel by
    `cluster_labels_2d`.
    """
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (axes[0], cluster_labels, 'Clustering Results (t-SNE embedding)'),
        (axes[1], cluster_labels_2d, 'Clustering Results (t-SNE embedding 2D)'),
    )

    for panel, labels, heading in panels:
        panel.scatter(embeddings_tsne[:, 0], embeddings_tsne[:, 1], c=labels, s=1)
        panel.set_title(heading)

    plt.show()

In [16]:
def filter_class(embeddings, classes, class_label):
    """Keep the entries whose class matches `class_label`.

    Classes are compared through their string form, so ``1`` and ``'1'``
    match. Returns three NumPy arrays: the kept embeddings, their class
    values, and their positions in the input.
    """
    wanted = str(class_label)
    matches = [
        (position, embedding, class_value)
        for position, (embedding, class_value) in enumerate(zip(embeddings, classes))
        if str(class_value) == wanted
    ]

    kept_embeddings = np.array([embedding for _, embedding, _ in matches])
    kept_classes = np.array([class_value for _, _, class_value in matches])
    kept_positions = np.array([position for position, _, _ in matches])

    return kept_embeddings, kept_classes, kept_positions


In [17]:
def interactive_plot_filtered(x, y, ids=None, colors=None, action='click', selected_class=None, img_list=None, emb_list=None, true_img_list=None, do='plot'):
    """Restrict the data to one class and optionally explore it interactively.

    `filter_class` selects the points whose entry in `colors` matches
    `selected_class`; `y` and the three lists are indexed with the resulting
    positions. With ``do == 'plot'`` the subset is handed to
    `interactive_plot` (zoomed layout, original positions used in titles).
    The `ids` argument is accepted but not used.

    Returns the filtered `img_list`, `emb_list`, `true_img_list` and the
    positions of the kept points.
    """
    class_x, _, indices = filter_class(x, colors, selected_class)

    class_y = y[indices]
    class_imgs = img_list[indices]
    class_embs = emb_list[indices]
    class_true_imgs = true_img_list[indices]

    if do == 'plot':
        print("Interactive plot...")
        interactive_plot(
            class_x,
            class_y,
            action=action,
            img_list=class_imgs,
            emb_list=class_embs,
            true_img_list=class_true_imgs,
            zoom=True,
            filtered_ids=indices,
        )

    return class_imgs, class_embs, class_true_imgs, indices

from matplotlib.widgets import Slider
# Make sure matplotlib interactive mode is on before the slider viewer below is used.
plt.ion() 

def plot_slices_with_cursor(img_np, mask_np, vessel_np, grid_np=None, cursor_position=10, indeces = []):
    """Browse three aligned stacks slice by slice with a slider.

    For the slice selected on the slider, the three panels show
    ``img_np[slice]``, ``reshape_to_square(mask_np[slice])`` and
    ``vessel_np[slice]``.

    Parameters
    ----------
    img_np, mask_np, vessel_np : collections indexed by slice along axis 0.
    grid_np : optional
        Nothing is drawn when this is not ``None``; only ``plt.show()`` runs.
    cursor_position : int
        Initial slider position.
    indeces : sequence
        When non-empty, ``indeces[slice]`` replaces the slice number in the
        titles. The list default is only read, never mutated.

    Notes
    -----
    NOTE(review): the middle panel is titled 'Mask Slice' and the right one
    'Embedding Slice', yet the reshaped-to-square input is ``mask_np`` --
    the titles may be swapped relative to the data; confirm with the callers.
    """
    if grid_np is None:

        fig, ax = plt.subplots(1, 3, figsize=(15, 5))
        plt.subplots_adjust(bottom=0.2)

        # A single slider spanning axis 0, the axis the panels are indexed on.
        slider_ax = plt.axes([0.1, 0.01, 0.65, 0.03])
        slider = Slider(slider_ax, 'Slice', 0, img_np.shape[0] - 1, valinit=cursor_position, valstep=1)

        def update(val):
            """Redraw the three panels for the slider's current slice."""
            slice_index = int(slider.val)
            title = slice_index if len(indeces) == 0 else indeces[slice_index]

            # Remove the previous images so they do not pile up on the axes.
            for panel in ax:
                panel.clear()

            ax[0].imshow(img_np[slice_index], cmap='gray')
            ax[0].set_title(f'Image Slice {title}')
            ax[1].imshow(reshape_to_square(mask_np[slice_index]), cmap='gray')
            ax[1].set_title(f'Mask Slice {title}')
            ax[2].imshow(vessel_np[slice_index], cmap='gray')
            ax[2].set_title(f'Embedding Slice {title}')
            fig.canvas.draw_idle()

        slider.on_changed(update)
        # Keep a reference on the figure: a slider that is garbage-collected
        # after this function returns stops responding.
        fig._slice_slider = slider

        # Initialize the plot
        update(cursor_position)

    plt.show()
    

In [18]:
import numpy as np

def train_test_split_arrays(*arrays, test_size=0.2, random_state=None):
    """
    Split numpy arrays along the first axis into random train and test subsets.

    Parameters:
    *arrays : array-like
        Arrays to be split. All arrays must have the same size along the first axis.
    test_size : float or int, default=0.2
        If float, must be between 0.0 and 1.0 and is the proportion of the dataset to include in the test split.
        If int, is the absolute number of test samples (0 <= test_size <= number of samples).
        NumPy integer and floating scalars are accepted as well.
        At least one sample is always placed in the test split.
    random_state : int or None, default=None
        Seed passed to ``np.random.RandomState`` to control the shuffling.

    Returns:
    (train_arrays, test_arrays)
        Two tuples, each holding one array per input array, in input order.

    Raises:
    ValueError
        If no array is given, the arrays differ in length along the first axis,
        or test_size has an unsupported type or is out of range.
    AssertionError
        If the arrays are empty along the first axis.
    """
    # Local import keeps this cell runnable on its own.
    import numbers

    if not arrays:
        raise ValueError("At least one array is required.")

    # Check if all arrays have the same size along the first axis
    first_axis_lengths = [arr.shape[0] for arr in arrays]
    if len(set(first_axis_lengths)) != 1:
        raise ValueError("All input arrays must have the same size along the first axis.")

    n_samples = first_axis_lengths[0]
    assert n_samples > 0, "The size of the first axis should be greater than 0."

    # Determine the size of the test set. Integral is tested first because
    # every Integral is also a Real.
    if isinstance(test_size, numbers.Integral):
        if test_size < 0 or test_size > n_samples:
            raise ValueError("test_size should be a positive integer less than or equal to the size of the first axis.")
        test_size = int(test_size)
    elif isinstance(test_size, numbers.Real):
        if not 0.0 <= test_size <= 1.0:
            raise ValueError("test_size should be a float between 0.0 and 1.0.")
        test_size = int(test_size * n_samples)
    else:
        raise ValueError("test_size should be either float or int.")

    # Never return an empty test split
    test_size = max(test_size, 1)

    # Generate random indices for the test set
    rng = np.random.RandomState(random_state)
    indices = np.arange(n_samples)
    rng.shuffle(indices)
    test_indices = indices[:test_size]
    train_indices = indices[test_size:]

    # Split arrays
    train_arrays = tuple(arr[train_indices] for arr in arrays)
    test_arrays = tuple(arr[test_indices] for arr in arrays)

    return train_arrays, test_arrays
import numpy as np

## 1 - Load Data

### Load and save Data

In [None]:
# NOTE(review): root_dir is not read anywhere in the visible part of this
# notebook -- presumably used by later cells; confirm.
root_dir = '/home/falcetta/0_PhD/sparse_var/vdisnet/'

# Dataset variant to process: 'IXI' or 'IXI_weak'. A name containing 'weak'
# selects the weak-segmentation patch directory below.
dataset_name = 'IXI_weak'

root_data_dir = '/data/falcetta/brain_data/IXIJ/processed'
dataset_img_dir_general = os.path.join(root_data_dir,'patches_preprocessed_with_empty')
if 'weak' in dataset_name:
    dataset_img_dir = os.path.join(dataset_img_dir_general, '5_dil_kmeans_seg_weak_patch_extraction')
else:
    dataset_img_dir = os.path.join(dataset_img_dir_general, '5_seg_true_patch_extraction')

# Output directory for the stacked arrays (also read by plot_n_patches_overlap)
save_data_path = os.path.join(root_data_dir,f'vessels_{dataset_name}')

# 'all' keeps every patch; any other value is used as the number of patches
# to sample at random in the cells below.
take = 'all'

mk_dir(save_data_path)


print(f"Extracting images and vessels from {dataset_img_dir}")
# save_data_path is already absolute, so it must not be prefixed with the
# current working directory.
print(f"Saving outputs to {os.path.abspath(save_data_path)}")

In [None]:
if not do_skip:

    # Load and stack every image/label patch file of the selected dataset
    img_list_np, ves_list_np = extract_img_vessels_np2(dataset_img_dir)

    # Optionally keep only `take` patches, drawn at random with a fixed seed
    if take != 'all':
        print(f"Taking {take} at random")
        np.random.seed(0)
        # int(take): same conversion as the empty-patch cell, so a numeric
        # string is accepted here too.
        random_idx = np.random.choice(img_list_np.shape[0], int(take), replace=False)

        img_list_np = img_list_np[random_idx]
        ves_list_np = ves_list_np[random_idx]

    print(f"Vessel list length: {len(ves_list_np)} - Img list length: {len(img_list_np)}")


In [None]:
if not do_skip:
    # Bind only dataset_empty_img_dir. Also overwriting dataset_img_dir here
    # would make a re-run of the vessel extraction cell load the empty
    # patches instead.
    dataset_empty_img_dir = os.path.join(dataset_img_dir_general, '5_seg_empty_patch_extraction')

    empty_img_list_np = extract_empty_img_vessels(dataset_empty_img_dir)
    print(f'Empty img list shape: {empty_img_list_np.shape}')

    if take == 'all':
        # The half-size relation only holds for the full, unsampled arrays;
        # with a numeric `take`, img_list_np has already been subsampled.
        assert empty_img_list_np.shape[0] == img_list_np.shape[0]//2, "The number of empty images should be half the number of images"
    else:
        print(f"Taking {take} empty images")
        np.random.seed(0)
        # Keep only `take` empty patches, drawn at random
        random_indices = np.random.choice(empty_img_list_np.shape[0], int(take), replace=False)
        empty_img_list_np = empty_img_list_np[random_indices]

    # Empty patches contain no vessels: their label array is all zeros
    empty_vess_list_np = np.zeros_like(empty_img_list_np)
    print(f'Empty img list shape: {empty_img_list_np.shape}, Data list shape: {empty_vess_list_np.shape}')

In [None]:
if not do_skip:
    print("Plotting Empty (No Vessels) Patches")
    # Arguments are (label array, image array): the all-zero label array goes
    # to the left column and the empty image patches to the right column, for
    # five randomly chosen slices.
    plot_10_patches(empty_vess_list_np, empty_img_list_np)

In [None]:
if not do_skip:
    print("Plotting Vessel Patches")
    print(f'img_list_np shape: {img_list_np.shape}')
    print(f'data_list_np shape: {ves_list_np.shape}')

    # Arguments are (label array, image array): vessel labels go to the left
    # column and the image patches to the right column, for five randomly
    # chosen slices.
    plot_10_patches(ves_list_np, img_list_np)

In [None]:
if not do_skip:
    # Persist the four stacked arrays under save_data_path; the file names
    # encode the dataset variant and the `take` setting.
    np.save(os.path.join(save_data_path, f'vess_list_{dataset_name}_{take}.npy'), ves_list_np)
    np.save(os.path.join(save_data_path, f'img_list_{dataset_name}_{take}.npy'), img_list_np)

    np.save(os.path.join(save_data_path, f'empty_vess_list_{dataset_name}_{take}.npy'), empty_vess_list_np)
    np.save(os.path.join(save_data_path, f'empty_img_list_{dataset_name}_{take}.npy'), empty_img_list_np)

    print(f"Data saved in {save_data_path}")

### LOAD SAVED DATA (CHECK SHAPE)

In [None]:
# Reload the vessel/image arrays written by the save cell. This cell is not
# guarded by do_skip, so it runs even when extraction was skipped.
# NOTE(review): the labels are bound to `vess_list_np` here, whereas the
# extraction cells above use the name `ves_list_np` -- two different variables.
vess_list_np = np.load(os.path.join(save_data_path, f'vess_list_{dataset_name}_{take}.npy'))
img_list_np = np.load(os.path.join(save_data_path, f'img_list_{dataset_name}_{take}.npy'))

print(f"Data loaded from {save_data_path}")
print(f'data_list_np shape: {vess_list_np.shape}')
print(f'img_list_np shape: {img_list_np.shape}')

In [None]:
# Reload the empty-patch arrays written by the save cell. This cell is not
# guarded by do_skip, so it runs even when extraction was skipped.
empty_vess_list_np = np.load(os.path.join(save_data_path, f'empty_vess_list_{dataset_name}_{take}.npy'))
empty_img_list_np = np.load(os.path.join(save_data_path, f'empty_img_list_{dataset_name}_{take}.npy'))

print(f"Data loaded from {save_data_path}")
print(f'empty_data_list_np shape: {empty_vess_list_np.shape}')
print(f'empty_img_list_np shape: {empty_img_list_np.shape}')