In [4]:
# Notebook switches. do_skip=True skips the (expensive) recompute cells guarded by `if not do_skip:`
# and relies on the cached .npy files loaded by the unguarded cells; the other flags toggle plotting
# cells (do_optional_plots is presumably used by cells further down — confirm).
do_skip, do_plot, do_stats_plot, do_optional_plots = True, True, True, False

## 0 - Import libraries

In [5]:
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader

# Let cuDNN benchmark convolution algorithms and cache the fastest one per input shape.
# NOTE: this can make GPU results non-deterministic even though a seed is set via set_random_seed.
torch.backends.cudnn.benchmark = True

from tqdm import tqdm

from utils import set_random_seed, mk_dir
from importlib import import_module

from torchvision.utils import save_image

## Helper functions (plotting, data loading, utilities)

In [6]:
import matplotlib.pyplot as plt
plt.ion()  # interactive mode on for the whole notebook

def save_img(tensor, name, norm, n_rows=16, scale_each=False):
    """Write ``tensor`` to ``name`` as an image grid via torchvision's ``save_image``.

    The grid uses ``n_rows`` images per row, 5 px of padding and white (1) padding value;
    ``norm`` and ``scale_each`` are forwarded as ``normalize`` / ``scale_each``.
    """
    grid_options = dict(nrow=n_rows, padding=5, normalize=norm, pad_value=1, scale_each=scale_each)
    save_image(tensor, name, **grid_options)
    
def plot_10_patches(img_np, true_vessel_np=False, indexes=None):
    """Show patches side by side: ``img_np[index]`` on the left, a companion patch on the right.

    NOTE(review): despite the name, 5 random indices are drawn when ``indexes`` is None.
    The right column shows ``true_vessel_np[index]`` when it is given, otherwise the next slice
    ``img_np[index + 1]``. The left title reads "True Vessel Slice" because the callers in this
    notebook pass the vessel masks as ``img_np`` and the image patches as ``true_vessel_np``.

    Args:
        img_np: stack indexed as ``img_np[index, :, :]``, displayed with vmin=0, vmax=1.
        true_vessel_np: optional companion stack with the same indexing; ``False`` means "not given".
        indexes: indices to plot; 5 random ones are drawn when None.
    """
    random_indexes = np.random.randint(0, img_np.shape[0], size=5) if indexes is None else indexes

    n = len(random_indexes)
    if n == 0:
        return

    # squeeze=False keeps `ax` 2-D even for a single row, so `ax[i, j]` always works.
    fig, ax = plt.subplots(n, 2, figsize=(6, 3 * n), squeeze=False)

    for i, index in enumerate(random_indexes):
        # Move indices away from the end of the stack so `index + 1` stays valid for the
        # fallback right-hand panel.
        if index == img_np.shape[0]:
            index -= 1
        if index == img_np.shape[0] - 1:
            index -= 2

        ax[i, 0].set_title(f'True Vessel Slice {index}')
        ax[i, 0].imshow(img_np[index, :, :], cmap='gray', vmin=0, vmax=1)
        ax[i, 0].axis('off')

        if true_vessel_np is not False:
            print(f"Shape: {true_vessel_np[index,:,:].shape}, Max: {true_vessel_np[index,:,:].max()}, Min: {true_vessel_np[index,:,:].min()}")
            ax[i, 1].imshow(true_vessel_np[index, :, :], cmap='gray')
            ax[i, 1].set_title(f'Image Slice {index}')
        else:
            ax[i, 1].set_title(f'Image Slice {index+1}')
            ax[i, 1].imshow(img_np[index + 1, :, :], cmap='gray', vmin=0, vmax=1)

        ax[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

def plot_n_patches_overlap(img_np, true_vessel_np=False, indexes=None, selected_class=None, add_title='', m=5, alpha=0.5, save_dir='clusters_imgs'):
    """Overlay the non-zero pixels of ``img_np`` (in red) on ``true_vessel_np`` patches, in a grid.

    Each panel shows ``true_vessel_np[index]`` in gray; the pixels where ``img_np[index]`` is
    non-zero are redrawn with the values of ``true_vessel_np[index]`` using the ``Reds`` colormap.

    Args:
        img_np: stack indexed as ``img_np[index, :, :]`` whose non-zero pixels are highlighted
            (callers in this notebook pass the vessel masks).
        true_vessel_np: background stack with the same indexing (callers pass the images). Required.
        indexes: indices to plot; 5 random ones are drawn when None.
        selected_class: only used in the saved file name.
        add_title: if non-empty, the figure is saved to
            ``<save_embeddings_path>/<save_dir>/<add_title>_class_<selected_class>_patches.png``
            and closed instead of being shown.
        m: number of panels per row.
        alpha: opacity of the red overlay.
        save_dir: sub-folder of the notebook-level ``save_embeddings_path`` used when saving.

    Raises:
        ValueError: if ``true_vessel_np`` is not provided.
    """
    if true_vessel_np is False or true_vessel_np is None:
        raise ValueError("plot_n_patches_overlap requires `true_vessel_np` (the background patches).")
    if indexes is None:
        indexes = np.random.randint(0, img_np.shape[0], size=5)

    n = len(indexes)
    if n == 0:
        print("No patches to plot.")
        return

    # m panels per row.
    n_rows = int(np.ceil(n / m))
    n_cols = m

    plt.ioff()
    try:
        # squeeze=False: always an array of Axes, even for a 1x1 grid (a bare Axes has no .flatten()).
        fig, ax = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
        ax = ax.flatten()

        for i, index in enumerate(indexes):
            ax[i].set_title(f'Overlay for Slice {index}')
            masked_image = np.ma.masked_where(img_np[index, :, :] == 0, true_vessel_np[index, :, :])
            # Gray background, then the red overlay restricted to the unmasked pixels.
            ax[i].imshow(true_vessel_np[index, :, :], cmap='gray', interpolation='none')
            ax[i].imshow(masked_image, cmap='Reds', alpha=alpha)
            ax[i].axis('off')

        # Hide the unused trailing cells of the last row.
        for spare_ax in ax[n:]:
            spare_ax.axis('off')

        try:
            fig.tight_layout()
        except Exception as exc:  # layout is cosmetic; don't abort the plot over it
            print(f"tight_layout failed: {exc}")

        if add_title != '':
            full_save_dir = os.path.join(save_embeddings_path, save_dir)
            mk_dir(full_save_dir)
            print(f"Save Fig to {full_save_dir}")
            fig.savefig(os.path.join(full_save_dir, f'{add_title}_class_{selected_class}_patches.png'))
            plt.close(fig)
        else:
            print("Plotting..")
            plt.show()
    finally:
        # Always leave interactive mode on, as the rest of the notebook expects.
        plt.ion()

    
def plot_n_patches(img_np, true_vessel_np=False, indexes=None, selected_class=None, add_title='', m=5):
    """Grid of patch pairs: ``img_np[index]`` next to ``true_vessel_np[index]`` (or ``img_np[index + 1]``).

    Args:
        img_np: stack indexed as ``img_np[index, :, :]``, displayed with vmin=0, vmax=1.
        true_vessel_np: optional companion stack; ``False`` means "show the next slice instead".
        indexes: indices to plot; 5 random ones are drawn when None.
        selected_class: only used in the saved file name.
        add_title: if non-empty, the figure is saved to
            ``<add_title>_class_<selected_class>_patches.png`` (current working directory)
            and closed instead of being shown.
        m: number of pairs per row (the grid has 2*m columns).
    """
    if indexes is None:
        indexes = np.random.randint(0, img_np.shape[0], size=5)

    n = len(indexes)
    if n == 0:
        print("No patches to plot.")
        return
    if n > 1000:
        print(f'WARNING: YOU ARE TRYING TO PLOT {n} images')

    # m pairs per row -> 2m columns.
    n_rows = int(np.ceil(n / m))
    n_cols = 2 * m

    plt.ioff()
    try:
        # squeeze=False: always an array of Axes regardless of the grid size.
        fig, ax = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
        ax = ax.flatten()

        for i, index in enumerate(indexes):
            # Move indices away from the end of the stack so `index + 1` stays valid for the
            # fallback right-hand panel.
            if index == img_np.shape[0]:
                index -= 1
            if index == img_np.shape[0] - 1:
                index -= 2

            ax[2 * i].set_title(f'Image Slice {index}')
            ax[2 * i].imshow(img_np[index, :, :], cmap='gray', vmin=0, vmax=1)
            ax[2 * i].axis('off')
            print(f"Shape: {img_np[index,:,:].shape}, Max: {img_np[index,:,:].max()}, Min: {img_np[index,:,:].min()}")
            if true_vessel_np is not False:
                ax[2 * i + 1].imshow(true_vessel_np[index, :, :], cmap='gray')
                ax[2 * i + 1].set_title(f'True Vessel Slice {index}')
            else:
                ax[2 * i + 1].set_title(f'Image Slice {index+1}')
                ax[2 * i + 1].imshow(img_np[index + 1, :, :], cmap='gray', vmin=0, vmax=1)
            ax[2 * i + 1].axis('off')

        # Hide the unused trailing cells of the last row.
        for spare_ax in ax[2 * n:]:
            spare_ax.axis('off')

        fig.tight_layout()
        if add_title != '':
            print("Save Fig")
            fig.savefig(f'{add_title}_class_{selected_class}_patches.png')
            plt.close(fig)
        else:
            print("Plotting..")
            plt.show()
    finally:
        # Always leave interactive mode on, as the rest of the notebook expects.
        plt.ion()
    
def reshape_to_square(vector):
    """Zero-pad a 1-D vector to the next perfect-square length and return it as a ``(side, side)`` matrix.

    ``side`` is ``ceil(sqrt(len(vector)))``; used to display an embedding vector as a small image.
    """
    length = len(vector)
    side = int(np.ceil(np.sqrt(length)))
    padded = np.pad(vector, (0, side * side - length), mode='constant')
    return padded.reshape(side, side)

In [7]:
def extract_img_vessels_np2(dataset_img_dir):
    """Load and stack every ``*32_img.npy`` patch file in a directory together with its label file.

    For each matching file the label is read from the same directory, with ``img`` replaced by
    ``label`` in the file name only (e.g. ``x_32_img.npy`` -> ``x_32_label.npy``); the directory
    path itself is left untouched. Files are processed in sorted order so the output order is
    reproducible across machines.

    Returns:
        (img_list_np, vess_list_np): image and label patches concatenated along axis 0.

    Raises:
        FileNotFoundError: if the directory contains no ``*32_img.npy`` file.
    """
    print(f"Extracting from {dataset_img_dir}")
    img_chunks, vess_chunks = [], []
    for img_name in tqdm(sorted(os.listdir(dataset_img_dir))):
        # The directory also holds the label files (and possibly other sizes); skip them.
        if '32_img.npy' not in img_name:
            continue
        img_chunks.append(np.load(os.path.join(dataset_img_dir, img_name)))
        vess_chunks.append(np.load(os.path.join(dataset_img_dir, img_name.replace('img', 'label'))))

    if not img_chunks:
        raise FileNotFoundError(f"No '*32_img.npy' files found in {dataset_img_dir}")

    # Concatenate once (linear) instead of growing the arrays inside the loop (quadratic).
    img_list_np = np.concatenate(img_chunks, axis=0)
    vess_list_np = np.concatenate(vess_chunks, axis=0)

    print(f"Extracted {img_list_np.shape[0]} images and {vess_list_np.shape[0]} labels")

    return img_list_np, vess_list_np

def extract_empty_img_vessels(dataset_empty_img_dir):
    """Load and stack every ``*32_img.npy`` patch file found in ``dataset_empty_img_dir``.

    Used for the vessel-free ("empty") patches; no label files are read here. Files are
    processed in sorted order so the output order is reproducible.

    Returns:
        numpy array: the patches concatenated along axis 0.

    Raises:
        FileNotFoundError: if the directory contains no ``*32_img.npy`` file.
    """
    print(f"Extracting from {dataset_empty_img_dir}")
    img_chunks = []
    for img_name in tqdm(sorted(os.listdir(dataset_empty_img_dir))):
        if '32_img.npy' not in img_name:
            continue
        img_chunks.append(np.load(os.path.join(dataset_empty_img_dir, img_name)))

    if not img_chunks:
        raise FileNotFoundError(f"No '*32_img.npy' files found in {dataset_empty_img_dir}")

    # Concatenate once (linear) instead of growing the array inside the loop (quadratic).
    img_list_np = np.concatenate(img_chunks, axis=0)
    print(f"Extracted {img_list_np.shape[0]} images")

    return img_list_np


In [8]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import KDTree

def label_point(x, y, ids, ax):
    """Write each id from ``ids`` next to its point ``(x[k], y[k])`` on ``ax``."""
    for position, text in enumerate(ids):
        anchor = (x[position], y[position])
        ax.annotate(text, anchor)

def interactive_plot(x, y, ids=None, colors=None, action='click', img_list=None, emb_list=None, true_img_list=None, zoom=False, filtered_ids=None):
    """Scatter plot of (x, y) whose points can be inspected by clicking (or hovering) on them.

    When the cursor is within ``threshold`` (1/80 of the larger data span) of the nearest point,
    three panels show ``img_list[id]``, the embedding ``emb_list[id]`` reshaped to a square, and
    ``true_img_list[id]``, where ``id = int(ids[i])``.

    Args:
        x, y: point coordinates (sequences of equal length).
        ids: per-point ids used to index the lists; defaults to ``'0' .. str(len(x) - 1)``.
        colors: optional per-point values for the scatter colors, also shown in the title.
        action: 'click' (button press) or 'hover' (mouse motion).
        img_list, emb_list, true_img_list: sequences indexed by id.
        zoom: if True, use a 2x3 layout with an extra scatter limited to [-100, 100] on both axes.
        filtered_ids: optional labels shown in the titles instead of ``ids`` (e.g. the original
            indices after filtering); None or empty means "use ids".
    """
    if ids is None:
        ids = [str(i) for i in range(len(x))]
    if filtered_ids is None:
        filtered_ids = []

    if zoom:
        fig, ax = plt.subplots(2, 3, figsize=(12, 8))
    else:
        fig, ax = plt.subplots(1, 4, figsize=(16, 4))
    ax = ax.flatten()

    for detail_ax in ax[1:4]:
        detail_ax.axis('off')

    if colors is not None:
        ax[0].scatter(x, y, s=1, c=colors, cmap='viridis')
    else:
        ax[0].scatter(x, y, s=1 if zoom else 0.1)

    if zoom:
        # Second scatter restricted to the central [-100, 100] window of the map.
        if colors is not None:
            ax[4].scatter(x, y, s=1, c=colors)
        else:
            ax[4].scatter(x, y, s=1)
        ax[4].set_xlim([-100, 100])
        ax[4].set_ylim([-100, 100])
        ax[5].axis('off')

    # A click selects the nearest point only if it lies within 1/80 of the larger data span.
    threshold = max(max(x) - min(x), max(y) - min(y)) / 80
    print(f"Threshold: {threshold}")
    tree = KDTree(np.column_stack((x, y)))

    def onclick(event):
        """Show the patch, embedding and image of the point nearest to the cursor."""
        if event.inaxes != ax[0] or event.xdata is None:
            return
        dist, i = tree.query([event.xdata, event.ydata])
        if not dist < threshold:
            return

        point_id = int(ids[i])
        id_img_title = ids[i] if len(filtered_ids) == 0 else filtered_ids[i]

        # Clear first so repeated clicks/hovers don't keep stacking image artists.
        for detail_ax in ax[1:4]:
            detail_ax.cla()

        ax[1].imshow(img_list[point_id], cmap='gray')
        if colors is not None:
            ax[1].set_title(f'Image {id_img_title} (color: {colors[i]})')
        else:
            ax[1].set_title(f'Image {id_img_title}')
        ax[2].imshow(reshape_to_square(emb_list[point_id]), cmap='gray')
        ax[2].set_title(f'Embedding {id_img_title}')
        ax[3].imshow(true_img_list[point_id], cmap='gray')
        ax[3].set_title(f'True Image {id_img_title}')

        for detail_ax in ax[1:4]:
            detail_ax.axis('off')
        # Explicit redraw so the update shows even when the backend does not auto-refresh.
        fig.canvas.draw_idle()

    fig.tight_layout()
    if action == 'click':
        fig.canvas.mpl_connect('button_press_event', onclick)
    elif action == 'hover':
        fig.canvas.mpl_connect('motion_notify_event', onclick)

    plt.show()

In [9]:
def plot_2_clusters(embeddings_tsne, cluster_labels, cluster_labels_2d):
    """Plot the same 2-D t-SNE map twice, colored by two different cluster labelings."""
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (cluster_labels, 'Clustering Results (t-SNE embedding)'),
        (cluster_labels_2d, 'Clustering Results (t-SNE embedding 2D)'),
    )
    for panel_ax, (labels, title) in zip(axes, panels):
        panel_ax.scatter(embeddings_tsne[:, 0], embeddings_tsne[:, 1], c=labels, s=1)
        panel_ax.set_title(title)

    plt.show()

In [10]:
def filter_class(embeddings, classes, class_label):
    """Keep only the items whose class equals ``class_label``.

    Classes are compared as strings, so ``1``, ``np.int64(1)`` and ``'1'`` all match. Items are
    paired with ``zip``, so trailing items of the longer input are ignored.

    Returns:
        (filtered_embeddings, filtered_classes, filtered_indices) as numpy arrays, where the
        indices are the positions of the kept items in the input.
    """
    wanted = str(class_label)
    kept = [
        (position, embedding, class_value)
        for position, (embedding, class_value) in enumerate(zip(embeddings, classes))
        if str(class_value) == wanted
    ]

    filtered_embeddings = np.array([item[1] for item in kept])
    filtered_classes = np.array([item[2] for item in kept])
    filtered_indices = np.array([item[0] for item in kept])

    return filtered_embeddings, filtered_classes, filtered_indices


In [11]:
def interactive_plot_filtered(x, y, ids=None, colors=None, action='click', selected_class=None, img_list=None, emb_list=None, true_img_list=None, do='plot'):
    """Restrict the points (and their patches/embeddings) to one class and optionally plot them.

    Points are selected with ``filter_class(x, colors, selected_class)``, i.e. ``colors`` holds the
    per-point class labels and must not be None. ``y``, ``img_list``, ``emb_list`` and
    ``true_img_list`` must support numpy fancy indexing.

    Args:
        ids: accepted for signature compatibility; not used.
        do: 'plot' to open ``interactive_plot`` (zoom layout, titles showing the original indices);
            any other value only filters.

    Returns:
        (img_list, emb_list, true_img_list, indices) restricted to the selected class, where
        ``indices`` are the positions in the unfiltered inputs.
    """
    x, filtered_colors, indices = filter_class(x, colors, selected_class)
    # An empty selection yields a float64 array, which numpy refuses as an index.
    indices = indices.astype(int)

    y = y[indices]
    img_list = img_list[indices]
    emb_list = emb_list[indices]
    true_img_list = true_img_list[indices]

    if do == 'plot':
        if len(indices) == 0:
            # KDTree / max() inside interactive_plot cannot handle an empty point set.
            print(f"No points in class {selected_class}; nothing to plot.")
        else:
            print("Interactive plot...")
            interactive_plot(x, y, action=action, img_list=img_list, emb_list=emb_list, true_img_list=true_img_list, zoom=True, filtered_ids=indices)

    return img_list, emb_list, true_img_list, indices

from matplotlib.widgets import Slider
plt.ion()


def plot_slices_with_cursor(img_np, mask_np, vessel_np, grid_np=None, cursor_position=10, indeces=None):
    """Browse three stacks slice by slice with a single slider.

    For slider value k the panels show ``img_np[k]``, ``reshape_to_square(mask_np[k])`` and
    ``vessel_np[k]``; titles show ``indeces[k]`` when ``indeces`` is non-empty, else k.

    Args:
        grid_np: NOTE(review): not implemented — when given, nothing is drawn and only
            ``plt.show()`` is called.
        cursor_position: initial slice, clamped to ``[0, img_np.shape[0] - 1]``.
        indeces: optional labels for the titles (e.g. original indices after filtering).

    Returns:
        The Slider widget (also stored on the figure), or None when ``grid_np`` is given.
        Matplotlib widgets stop responding once nothing references them.
    """
    if indeces is None:
        indeces = []

    if grid_np is not None:
        plt.show()
        return None

    n_slices = img_np.shape[0]
    cursor_position = int(min(max(cursor_position, 0), n_slices - 1))

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))

    def update(val):
        slice_index = int(slider.val)
        title = slice_index if len(indeces) == 0 else indeces[slice_index]
        ax[0].imshow(img_np[slice_index], cmap='gray')
        ax[0].set_title(f'Image Slice {title}')
        ax[1].imshow(reshape_to_square(mask_np[slice_index]), cmap='gray')
        ax[1].set_title(f'Mask Slice {title}')
        ax[2].imshow(vessel_np[slice_index], cmap='gray')
        ax[2].set_title(f'Embedding Slice {title}')
        fig.canvas.draw_idle()

    # A single slider over the first (slice) axis of img_np.
    slider_ax = fig.add_axes([0.1, 0.01, 0.65, 0.03])
    slider = Slider(slider_ax, 'Slice', 0, n_slices - 1, valinit=cursor_position, valstep=1)
    slider.on_changed(update)
    fig.subplots_adjust(bottom=0.2)

    # Initialize the plot
    update(cursor_position)

    # Keep a strong reference so the widget stays responsive after this function returns.
    fig._slice_slider = slider

    plt.show()
    return slider
    

In [12]:
import numpy as np

def train_test_split_arrays(*arrays, test_size=0.2, random_state=None):
    """
    Split numpy arrays along the first axis into random train and test subsets.

    Parameters:
    *arrays : array-like
        Arrays to be split. All arrays must have the same size along the first axis.
    test_size : float or int, default=0.2
        If float, must be between 0.0 and 1.0 and represents the proportion of the dataset to
        include in the test split (truncated to an int).
        If int (Python or numpy integer), represents the absolute number of test samples and
        must be between 0 and the number of samples.
        In both cases the test split always contains at least 1 sample.
    random_state : None, int or numpy.random.RandomState, default=None
        Controls the shuffling. A RandomState instance is used as-is (and advanced); any other
        value is passed to ``np.random.RandomState`` as a seed.

    Returns:
    tuple of tuples
        ``(train_arrays, test_arrays)``, each a tuple with one entry per input array; the same
        permutation is applied to every array.

    Raises:
    ValueError
        If the arrays differ in length along the first axis, or ``test_size`` has an invalid
        type or value.
    """
    # Check if all arrays have the same size along the first axis
    first_axis_lengths = [arr.shape[0] for arr in arrays]
    if len(set(first_axis_lengths)) != 1:
        raise ValueError("All input arrays must have the same size along the first axis.")

    n_samples = first_axis_lengths[0]
    assert n_samples > 0, "The size of the first axis should be greater than 0."

    # Determine the size of the test set. Integral is checked first so numpy integers are
    # accepted too; numpy floats are registered as numbers.Real.
    if isinstance(test_size, numbers.Integral):
        if test_size < 0 or test_size > n_samples:
            raise ValueError("test_size should be a positive integer less than or equal to the size of the first axis.")
        test_size = int(test_size)
    elif isinstance(test_size, numbers.Real):
        if not 0.0 <= test_size <= 1.0:
            raise ValueError("A float test_size should be between 0.0 and 1.0.")
        test_size = int(test_size * n_samples)
    else:
        raise ValueError("test_size should be either float or int.")

    # Never return an empty test split.
    test_size = test_size if test_size > 0 else 1

    # Generate random indices for the test set
    if isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)
    indices = np.arange(n_samples)
    rng.shuffle(indices)
    test_indices = indices[:test_size]
    train_indices = indices[test_size:]

    # Split arrays
    train_arrays = tuple(arr[train_indices] for arr in arrays)
    test_arrays = tuple(arr[test_indices] for arr in arrays)

    return train_arrays, test_arrays
import numpy as np

## 1 - Load Data

### Load and save Data

In [None]:
# Machine-specific locations; set VDISNET_ROOT / VDISNET_DATA to run on another machine.
root_dir = os.environ.get('VDISNET_ROOT', '/home/falcetta/0_PhD/sparse_var/vdisnet/')

path_data = os.environ.get('VDISNET_DATA', "/data/falcetta/brain_data")

preprocessed_data_dir = os.path.join(path_data, "IXIJ/processed/patches_preprocessed_with_empty")

# Patches containing vessels; the empty-patch folder is derived from this one later ('true' -> 'empty').
dataset_img_dir = os.path.join(preprocessed_data_dir, "5_seg_true_patch_extraction")

dataset_name = 'IXI'
save_embeddings_path = os.path.join(path_data, "embeddings_VDISNET")

# Number of patches to keep ('all' keeps every patch; otherwise an integer sample size).
take = 'all'

mk_dir(save_embeddings_path)


print(f"Extracting images and vessels from {dataset_img_dir}")
# save_embeddings_path is already absolute; abspath also covers a relative override.
print(f"Saving outputs to {os.path.abspath(save_embeddings_path)}")

In [14]:
if not do_skip:

    img_list_np, ves_list_np = extract_img_vessels_np2(dataset_img_dir)

    # Optionally keep a reproducible random subset of `take` patches; 'all' keeps everything.
    if take != 'all':
        print(f"Taking {take} at random")
        np.random.seed(0)
        # int(take): accept `take` given as a string too, like the empty-patch cell does.
        random_idx = np.random.choice(img_list_np.shape[0], int(take), replace=False)

        img_list_np = img_list_np[random_idx]
        ves_list_np = ves_list_np[random_idx]

    print(f"Vessel list length: {len(ves_list_np)} - Img list length: {len(img_list_np)}")


In [15]:
if not do_skip:
    # Vessel-free patches live in the sibling folder with 'true' replaced by 'empty'.
    dataset_empty_img_dir = dataset_img_dir.replace('true', 'empty')

    empty_img_list_np = extract_empty_img_vessels(dataset_empty_img_dir)
    print(f'Empty img list shape: {empty_img_list_np.shape}')

    if take != 'all':
        print(f"Taking {take} empty images")
        # Same fixed seed as the vessel-patch subsampling.
        np.random.seed(0)
        random_indices = np.random.choice(empty_img_list_np.shape[0], int(take), replace=False)
        empty_img_list_np = empty_img_list_np[random_indices]

    # No vessels in these patches, so their labels are all zeros.
    empty_vess_list_np = np.zeros_like(empty_img_list_np)
    print(f'Empty img list shape: {empty_img_list_np.shape}, Data list shape: {empty_vess_list_np.shape}')

In [16]:
if not do_skip:
    print("Plotting Empty (No Vessels) Patches")
    # Left column: the all-zero vessel masks; right column: the matching image patches.
    plot_10_patches(img_np=empty_vess_list_np, true_vessel_np=empty_img_list_np)

In [17]:
if not do_skip:
    print("Plotting Vessel Patches")
    for array_label, array_value in (('img_list_np', img_list_np), ('data_list_np', ves_list_np)):
        print(f'{array_label} shape: {array_value.shape}')

    # Left column: vessel masks; right column: the matching image patches.
    plot_10_patches(img_np=ves_list_np, true_vessel_np=img_list_np)

In [18]:
if not do_skip:
    # Cache the extracted patches so later runs (do_skip=True) can load them directly.
    arrays_to_save = {
        f'vess_list_{dataset_name}_{take}.npy': ves_list_np,
        f'img_list_{dataset_name}_{take}.npy': img_list_np,
        f'empty_vess_list_{dataset_name}_{take}.npy': empty_vess_list_np,
        f'empty_img_list_{dataset_name}_{take}.npy': empty_img_list_np,
    }
    for file_name, array_value in arrays_to_save.items():
        np.save(os.path.join(save_embeddings_path, file_name), array_value)

    print(f"Data saved in {save_embeddings_path}")

### LOAD SAVED DATA

In [None]:
# Load the cached vessel masks and image patches (written by the save cell above).
vess_list_np, img_list_np = (
    np.load(os.path.join(save_embeddings_path, f'{prefix}_list_{dataset_name}_{take}.npy'))
    for prefix in ('vess', 'img')
)

print(f"Data loaded from {save_embeddings_path}")
print(f'data_list_np shape: {vess_list_np.shape}')
print(f'img_list_np shape: {img_list_np.shape}')

In [None]:
# Load the cached vessel-free patches and their all-zero masks.
empty_vess_list_np, empty_img_list_np = (
    np.load(os.path.join(save_embeddings_path, f'{prefix}_list_{dataset_name}_{take}.npy'))
    for prefix in ('empty_vess', 'empty_img')
)

print(f"Data loaded from {save_embeddings_path}")
print(f'empty_data_list_np shape: {empty_vess_list_np.shape}')
print(f'empty_img_list_np shape: {empty_img_list_np.shape}')

## 2 - Create Dictionary Learning Embeddings

### Load Encoder

In [None]:
OUTPUT_PATH = os.path.join(root_dir, "results")
# Re-join dirname/basename; the empty middle component is a hook for inserting an extra folder
# (e.g. "CODE_DIMS") if needed — as written it is a no-op.
OUTPUT_PATH = os.path.join(os.path.dirname(OUTPUT_PATH), "", os.path.basename(OUTPUT_PATH))

# Training run to load: <OUTPUT_PATH>/<date>/<dtime>/checkpoints/
date = '2024-07-05'
dtime = '15-57-09'  # SDL
# dtime = '16-11-06'  # SDL-NL
datetime = os.path.join(date, dtime)

pretrained_path_enc = '{}/{}/checkpoints/ENC_best.pth'.format(OUTPUT_PATH, datetime)

assert os.path.exists(pretrained_path_enc), f"Pretrained encoder not found at {pretrained_path_enc}"

# Hyper-parameters of the run. NOTE: code_dim is overridden to 128 in a later cell.
args = dict(
    batch_size=200,
    code_dim=256,
    code_reg=1,
    cuda=True,
    dataset="VPATCHES",
    decoder="linear_dictionary",
    encoder="lista_encoder",
    epochs=1,
    FISTA="--FISTA",
    hidden_dim=0,
    hinge_threshold=0.5,
    im_size=32,
    lrt_D=0.001,
    lrt_E=0.0003,
    lrt_Z=1,
    n_steps_inf=200,
    noise=[1],
    norm_decoder=1,
    num_iter_LISTA=3,
    num_workers=4,
    outdir=OUTPUT_PATH,
    patch_size=0,
    positive_ISTA="--positive_ISTA",
    seed=31,
    sparsity_reg=0.005,
    stop_early=0.001,
    pretrained_path_enc=pretrained_path_enc,
    use_Zs_enc_as_init="use_Zs_enc_as_init",
    variance_reg=0,
    weight_decay_D=0,
    weight_decay_E=0,
    weight_decay_E_bias=0,
    weight_decay_D_bias=0,
)

print(f"Loading model from experiment {datetime} from {OUTPUT_PATH}")

In [22]:
class DotAccessibleDict:
    """Attribute-style view over a dict, used as a stand-in for an ``argparse.Namespace``.

    Reads, writes and deletes go straight to the wrapped dict (no copy), so ``obj.x = 1`` also
    updates the dictionary that was passed in. Missing keys raise ``AttributeError``, so
    ``hasattr``/``getattr(obj, name, default)`` behave as usual.
    """

    def __init__(self, d):
        # Bypass __setattr__ so the wrapped dict is stored as a real instance attribute.
        self.__dict__['_dict'] = d

    def __getattr__(self, name):
        # Only called when normal lookup fails. '_dict' is read through __dict__ so that an
        # instance created without __init__ (copy.copy, pickle) raises AttributeError instead
        # of recursing forever.
        data = self.__dict__.get('_dict')
        if data is None:
            raise AttributeError(f"'DotAccessibleDict' object has no attribute '{name}'")
        try:
            return data[name]
        except KeyError:
            raise AttributeError(f"'DotAccessibleDict' object has no attribute '{name}'") from None

    def __setattr__(self, name, value):
        self._dict[name] = value

    def __delattr__(self, name):
        try:
            del self._dict[name]
        except KeyError:
            raise AttributeError(f"'DotAccessibleDict' object has no attribute '{name}'") from None

    def __repr__(self):
        return f"DotAccessibleDict({self._dict!r})"
        

args = DotAccessibleDict(args)

# Inference-only overrides: single-channel patches, frozen encoder/decoder, 128-atom codes
# (replacing the 256 listed in the config dict).
args.n_channels, args.train_decoder, args.train_encoder, args.code_dim = 1, False, False, 128

In [None]:
print("----- START -----")
# Get arguments
print("--- Step 0: Get arguments")
# The checkpoint path ends in .../<day>/<exp_time>/checkpoints/ENC_best.pth; reuse that run id
# to name the embeddings folder.
day, exp_time = args.pretrained_path_enc.split('/')[-4:-2]

outdir = os.path.join(save_embeddings_path, "embeddings", day, exp_time)
mk_dir(outdir)

print(f"Embeddings will be saved in {outdir}")

### skip

In [24]:
if not do_skip:
    # More logistics: pick the device from the config flag.
    if args.cuda:
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Device: {device}")

    # Seed python, numpy and torch for reproducibility.
    print("--- Step 2: Set random seed")
    set_random_seed(args.seed, torch, np, random, args.cuda)

In [25]:
from torch.utils.data import Dataset
from torchvision import transforms

class MasksDataset(Dataset):
    """Map-style dataset over a stack of 2-D patches, returning each patch rescaled to [0, 1].

    Args:
        masks_array: numpy array whose first axis indexes samples (this notebook passes the
            vessel patches).
        transform: optional callable applied to each sample after casting it to float32,
            e.g. ``transforms.ToTensor()``.

    Attributes:
        mean, std: 0-d torch tensors computed once in ``__init__`` by ``compute_metadata``.
    """

    def __init__(self, masks_array, transform=None):
        self.masks_array = masks_array
        self.transform = transform
        self.compute_metadata()

    def __getitem__(self, index):
        """Return sample ``index``: standardised with the dataset mean/std, then min-max scaled to [0, 1]."""
        img0 = self.masks_array[index]

        # Apply transformations if provided
        if self.transform is not None:
            img0 = self.transform(img0.astype(np.float32))

        # Dataset-level standardisation; guard std == 0 (all values identical) against 0/0.
        std = self.std if self.std != 0 else torch.ones_like(self.std)
        img0 = (img0 - self.mean) / std

        # Per-patch min-max scaling to [0, 1]. A constant patch (e.g. an empty mask) has zero
        # range and would become 0/0 = NaN, so it is mapped to all zeros instead.
        img_min = img0.min()
        value_range = img0.max() - img_min
        if value_range == 0:
            return img0 * 0
        return (img0 - img_min) / value_range

    def compute_metadata(self):
        """Compute, store and print the normalisation statistics.

        ``mean`` is the average of the per-pixel means over the sample axis; ``std`` is the
        average of the per-pixel standard deviations (not the global standard deviation).
        NOTE(review): presumably this is what the saved VESSEL metadata / trained models expect;
        don't change it without checking.

        Returns:
            (mean, std) as 0-d torch tensors.
        """
        flattened_data = self.masks_array.reshape(self.masks_array.shape[0], -1)
        means = np.mean(flattened_data, axis=0)
        stds = np.std(flattened_data, axis=0)

        # compute the single mean and std
        mean = np.mean(means)
        std = np.mean(stds)

        # convert to pytorch
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

        print(f"Mean: {self.mean} - Std: {self.std}")
        return self.mean, self.std

    def __len__(self):
        return len(self.masks_array)

    def __shape__(self):
        """Shape of the underlying array (non-standard helper used by the notebook)."""
        return self.masks_array.shape



In [26]:
if not do_skip:
    masks_dataset = MasksDataset(vess_list_np, transform=transforms.ToTensor())
    print(f"Dataset shape: {masks_dataset.__shape__()}")
    # The constructor already computed the statistics; reuse them instead of recomputing
    # (and re-printing). Stored as a plain float array [mean, std].
    vessel_metadata = np.array([masks_dataset.mean.item(), masks_dataset.std.item()])

    np.save(os.path.join(save_embeddings_path, f'VESSEL_metadata_{dataset_name}.npy'), vessel_metadata)
    print(f"Metadata saved in {save_embeddings_path}")


In [None]:
# (mean, std) of the vessel patches, used by inverse_transform.
# NOTE(review): allow_pickle=True unpickles file contents — only load trusted files.
vessel_metadata = np.load(
    os.path.join(save_embeddings_path, 'VESSEL_metadata_{}.npy'.format(dataset_name)),
    allow_pickle=True,
)
print(f"Metadata loaded from {save_embeddings_path}: {vessel_metadata}")

In [28]:
if not do_skip:
    # One patch per batch, in dataset order, so row i of the embeddings matches patch i.
    test_batch_size, test_shuffle = 1, False

    data_emb = DataLoader(masks_dataset, shuffle=test_shuffle, batch_size=test_batch_size)
    args.n_channels = 1

#### Decoder (for dictionary visualization)

In [29]:
if not do_skip:
    # The decoder checkpoint is stored next to the encoder one.
    args.pretrained_path_dec = args.pretrained_path_enc.replace('ENC_best', 'DEC_best')
    print(f'Loading pre-trained enc from {args.pretrained_path_enc}')
    print(f'Loading pre-trained dec from {args.pretrained_path_dec}')

    # Decoder class resolved dynamically from the module models.<args.decoder>
    decoder = import_module('models.{}'.format(args.decoder)).Decoder(args).to(device)

    print("--- Step 4: Load pretrained decoder, turn off gradients if not training it")
    decoder.load_pretrained(args.pretrained_path_dec, freeze=not args.train_decoder)

    # Inference only: eval mode and no gradient tracking.
    decoder.eval()
    decoder.requires_grad_(False)

    # Save the dictionary atoms as image grids; norm_each=True normalises each column
    # (max value to 1, min value to 0).
    n_samples = decoder.get_n_atoms()
    cols1, cols2 = decoder.viz_columns2(n_samples, norm_each=True)

    save_img(cols1, os.path.join(outdir, "complete_dict.png"),
             norm=False, n_rows=int(2 ** (np.log2(n_samples) // 2)))

    # NOTE(review): viz_columns2 appears to return 0 for cols2 when there is no hidden layer — confirm.
    if cols2 != 0:
        save_img(cols2, os.path.join(outdir, "complete_dict_hidden.png"),
                 norm=False, n_rows=int(2 ** (np.log2(n_samples) // 2)))
        
        

# TODO: a copy of the `models` package must currently live inside `get_embedding_DL` so that `import_module('models.…')` can find the models. Is there a better way (e.g. adding the project root to `sys.path`)?

#### Encoder (for latent-space extraction)

In [30]:
if not do_skip:
    # Logistics: data
    print(f"Number of channels: {args.n_channels}")

    # Encoder class resolved dynamically from the module models.<args.encoder>
    encoder = import_module('models.{}'.format(args.encoder)).Encoder(args).to(device)
    # This notebook only extracts embeddings: the encoder must be frozen and a checkpoint given.
    assert not(args.train_encoder) and len(args.pretrained_path_enc) > 0

    print("--- Step 6: Load pretrained encoder, turn off gradients if not training it")
    encoder.load_pretrained(args.pretrained_path_enc, freeze=not args.train_encoder)

    # Inference only: eval mode and no gradient tracking.
    encoder.eval()
    encoder.requires_grad_(False)



In [31]:
def inverse_transform(X, metadata=None):
    """Undo the dataset standardisation: return ``X * std + mean``.

    Args:
        X: array or tensor in normalised units.
        metadata: optional ``(mean, std)`` pair; when None, the notebook-level
            ``vessel_metadata`` is used (as before).
    """
    mean, std = vessel_metadata if metadata is None else metadata

    return X * std + mean


# TODO: where is `VESSEL_mean_std.npy` (referenced in the commented-out lines of `inverse_transform`) created?

In [32]:
def PSNR(target, pred, tar_sample_mean=None, tar_sample_std=None, pred_sample_mean=None, pred_sample_std=None,
         R=1, dummy=1e-4, reduction='mean', binary_output=False):
    """Peak signal-to-noise ratio ``10 * log10(R**2 / MSE)`` between two tensors.

    Both inputs are optionally de-standardised with per-sample statistics, then mapped back with
    ``inverse_transform``; the squared error is averaged over every axis except the first
    (one value per sample). A 1-D input is treated as a single sample.

    Args:
        target, pred: torch tensors with identical shapes.
        tar_sample_mean, tar_sample_std: if given, both inputs are first mapped with
            ``x * std + mean`` (``pred`` uses ``pred_sample_*`` instead when provided).
        pred_sample_mean, pred_sample_std: statistics for ``pred`` when they differ from the
            target's (e.g. denoising).
        R: peak signal value.
        dummy: added to an MSE of exactly 0 so the log stays finite (40 dB with the defaults).
        reduction: 'mean' -> scalar averaged over samples; 'none' -> one value per sample.
        binary_output: unused; kept for signature compatibility.

    Raises:
        ValueError: for an unknown ``reduction``.
    """
    assert target.shape == pred.shape, f"Target shape: {target.shape}, pred shape: {pred.shape}"
    if reduction not in ('mean', 'none'):
        raise ValueError(f"reduction must be 'mean' or 'none', got {reduction!r}")

    with torch.no_grad():
        # Map inputs back to image space
        if tar_sample_mean is not None:
            target = (target * tar_sample_std) + tar_sample_mean
            if pred_sample_mean is not None:
                # Prediction comes from sample different from the target (e.g. in the case of denoising)
                pred = (pred * pred_sample_std) + pred_sample_mean
            else:
                pred = (pred * tar_sample_std) + tar_sample_mean
        target = inverse_transform(target)
        pred = inverse_transform(pred)

        # Average over every non-batch axis so (B, N), (B, H, W) and (B, C, H, W) all give
        # one MSE per sample.
        sq_err = (target - pred) ** 2
        if sq_err.dim() > 1:
            mean_sq_err = sq_err.mean(dim=tuple(range(1, sq_err.dim())))
        else:
            mean_sq_err = sq_err.mean()
        mean_sq_err = mean_sq_err + (mean_sq_err == 0).float() * dummy  # if 0, fill with dummy -> PSNR of 40 by default
        output = 10 * torch.log10(R ** 2 / mean_sq_err)

    return output.mean() if reduction == 'mean' else output

In [33]:
if not do_skip:
    # Training loop
    print("--- Step 8: Embedding extraction")
    # Size the buffers from the dataset itself so this works for any `take` (previously
    # take_num was only defined when take == 'all') and for any batch size.
    take_num = len(data_emb.dataset)
    embeddins_tot = np.empty((take_num, args.code_dim))
    psnr_tot = np.empty((take_num))

    start = 0
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for mask in tqdm(data_emb):
            y = mask.to(device)
            batch_len = y.shape[0]

            # Encoder predictions
            Zs_enc = encoder(y)  # Encoder input: y
            embeddins_tot[start:start + batch_len] = Zs_enc.cpu().numpy().reshape(batch_len, -1)

            # Reconstruction quality: one PSNR value per sample of the batch.
            y_hat = decoder(Zs_enc)
            psnr = PSNR(y, y_hat, None, None, reduction='none', binary_output=True)
            psnr_tot[start:start + batch_len] = psnr.cpu().numpy().reshape(-1)

            start += batch_len

    print(f"embeddins_tot shape: {embeddins_tot.shape}")
    print(f"psnr_tot shape: {psnr_tot.shape}")
    # save embeddings as a numpy array
    np.save(os.path.join(outdir, f"embeddings_{dataset_name}_{take}.npy"), embeddins_tot)
    print(f"Embeddings saved in {outdir}/embeddings_{dataset_name}_{take}.npy")

In [34]:
if not do_skip:
    import matplotlib.pyplot as plt

    # Summary statistics of the per-patch reconstruction PSNR.
    psnr_mean = np.mean(psnr_tot)
    psnr_std = np.std(psnr_tot)
    print("Maximum:", np.max(psnr_tot))
    print("Minimum:", np.min(psnr_tot))
    print("Mean:", psnr_mean)
    print("Standard Deviation:", psnr_std)

    plt.figure()
    plt.hist(psnr_tot, bins=10, alpha=0.5, color='blue', edgecolor='black')

    # Dashed guides: red at the mean, green at mean +/- one standard deviation.
    plt.axvline(x=psnr_mean, color='red', linestyle='dashed', linewidth=1)
    for signed_std in (psnr_std, -psnr_std):
        plt.axvline(x=psnr_mean + signed_std, color='green', linestyle='dashed', linewidth=1)

    plt.xlabel('PSNR')
    plt.ylabel('Frequency')
    plt.title('Histogram of PSNR')

    plt.show()


### load

In [None]:
# Sparse codes produced by the encoder, one row per patch.
embeddings = np.load(os.path.join(outdir, "embeddings_{}_{}.npy".format(dataset_name, take)))
print("Embeddings loaded")
print(f"embeddings shape: {embeddings.shape}")

### t-SNE

### skip

In [36]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

if not do_skip:
    # Project the embeddings to 2-D with t-SNE (default perplexity / early exaggeration).
    # Values previously noted for experiments (not used here): perplexity=500, early_exaggeration=40
    tsne = TSNE(n_components=2, random_state=42)
    embeddings_tsne = tsne.fit_transform(embeddings)

    plt.figure()
    plt.scatter(embeddings_tsne[:, 0], embeddings_tsne[:, 1], s=1)
    plt.title('Clustering Results (t-SNE embedding)')

    # Save the 2-D projection and report the path that was actually written
    # (the previous message omitted dataset_name from the file name).
    tsne_path = os.path.join(outdir, f'embeddings_tsne_{dataset_name}_{take}.npy')
    np.save(tsne_path, embeddings_tsne)
    print(f"TSNE embeddings saved in {tsne_path}")


### load

In [None]:
# Reload the 2-D t-SNE projection saved by the t-SNE cell
tsne_npy_path = os.path.join(outdir, f'embeddings_tsne_{dataset_name}_{take}.npy')
embeddings_tsne = np.load(tsne_npy_path)
print("TSNE embeddings loaded")
print(f"TSNE embeddings shape: {embeddings_tsne.shape}")

In [None]:
if do_plot:
    %matplotlib widget
    interactive_plot(embeddings_tsne[:, 0], embeddings_tsne[:, 1], colors=None, action='click', img_list=vess_list_np, emb_list=embeddings, true_img_list=img_list_np)

In [None]:
if do_plot:
    # Hand-picked patch indices to inspect
    picked_patch_indexes = [1625, 24, 25, 26, 23, 46, 1364, 1862]
    plot_n_patches_overlap(vess_list_np, img_list_np, indexes=picked_patch_indexes, m=4)


## PLOT STATS (OPTIONAL)

In [None]:
if do_stats_plot:
    # Fraction of exactly-zero entries in each embedding (one value per row).
    # Computed once and reused; it used to be recomputed for every statistic.
    frac_zeros = np.sum(embeddings == 0, axis=1) / embeddings.shape[1]

    plt.figure(figsize=(10, 5))
    plt.hist(frac_zeros, bins=25)
    plt.xlabel('Fraction of zeros')
    plt.ylabel('Count')
    plt.show()

    max_frac_0s = np.max(frac_zeros)
    min_frac_0s = np.min(frac_zeros)
    mean_frac_0s = np.mean(frac_zeros)
    std_frac_0s = np.std(frac_zeros)

    # Rows with the most / fewest zeros, with their count of non-zero entries
    idx_max_frac_0s = np.argmax(frac_zeros)
    idx_min_frac_0s = np.argmin(frac_zeros)

    print(f"Max frac of zeros: {max_frac_0s} (index: {idx_max_frac_0s}) (Number of non-zero elements: {np.sum(embeddings[idx_max_frac_0s] != 0)}/{embeddings.shape[1]})")
    print(f"Min frac of zeros: {min_frac_0s} (index: {idx_min_frac_0s}) (Number of non-zero elements: {np.sum(embeddings[idx_min_frac_0s] != 0)}/{embeddings.shape[1]})")
    print(f"Mean frac of zeros: {mean_frac_0s}")
    print(f"Standard deviation of frac of zeros: {std_frac_0s}")
    

In [None]:
if do_stats_plot:
    print(f"Shape: {embeddings.shape}")

    # Per-dimension extrema taken over all embeddings (axis 0 = samples)
    maxs = embeddings.max(axis=0)
    mins = embeddings.min(axis=0)
    print(f"Min of min {np.min(mins)} - Max of mins {np.max(mins)}")
    print(f"Min of max {np.min(maxs)} - Max of maxs {np.max(maxs)}")

    x = np.arange(0, maxs.size, 1)
    y = maxs

    # Left: max value per dimension; right: distribution of those maxima
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    line_ax, hist_ax = ax

    line_ax.plot(x, y)
    line_ax.set_xlabel('Dimension')
    line_ax.set_ylabel('Max value')
    line_ax.set_title('Max value per dimension')

    hist_ax.hist(maxs, bins=50)
    hist_ax.set_xlabel('Max value')
    hist_ax.set_ylabel('Count')
    hist_ax.set_title('Histogram of max values')
    plt.show()


In [None]:
if do_stats_plot:
    # Distribution of the values taken by one embedding dimension (column index 1)
    embedding1 = embeddings[:, 1]
    print(f"Frac of zeros (embedding 1): {np.sum(embedding1 == 0) / embedding1.size}")

    plt.figure(figsize=(10, 5))
    plt.hist(embedding1, bins=50)
    plt.xlabel('Atom Value')
    plt.ylabel('Count')
    plt.show()

    # Same distribution with the exact zeros removed
    embedding1_wo_0s = embedding1[embedding1 != 0]
    if embedding1_wo_0s.size == 0:
        # np.min / np.max raise ValueError on an empty array; report instead of crashing
        print("Embedding 1 is zero for every sample: nothing to plot without zeros")
    else:
        plt.figure(figsize=(10, 5))
        plt.hist(embedding1_wo_0s, bins=10)
        plt.xlabel('Atom Value')
        plt.ylabel('Count')
        plt.show()

        print(f"Min of embedding 1 w/o 0s: {np.min(embedding1_wo_0s)} - Max of embedding 1 w/o 0s: {np.max(embedding1_wo_0s)}")

In [None]:
if do_stats_plot:
    # Number of non-zero entries in each embedding (vectorized; kept as a list of ints)
    non_zero_num = np.count_nonzero(embeddings, axis=1).tolist()

    # Listing the active indices prints one line per embedding, which floods the
    # notebook output; enable only when inspecting individual rows.
    print_non_zero_idx = False
    if print_non_zero_idx:
        for i, row in enumerate(embeddings):
            non_zero_idx = np.flatnonzero(row)
            print(f"Non zero idx ({i}) (Len: {len(non_zero_idx):3d}): {non_zero_idx}")

    print(f"Max non zero elements: {np.max(non_zero_num)} (index: {np.argmax(non_zero_num)})")
    print(f"Min non zero elements: {np.min(non_zero_num)} (index: {np.argmin(non_zero_num)})")

    # Summary statistics reused by the histogram cell below
    mean_non_zero = np.mean(non_zero_num)
    print(f"Mean non zero elements: {mean_non_zero}")

    median_non_zero = np.median(non_zero_num)
    print(f"Median non zero elements: {median_non_zero}")

    std_non_zero = np.std(non_zero_num)
    print(f"Std non zero elements: {std_non_zero}")



In [None]:
if do_stats_plot:
    plt.figure(figsize=(10, 5))
    plt.hist(non_zero_num, bins=25)
    plt.xlabel('Number of non zero elements')
    plt.ylabel('Count')
    # Labels are attached to the lines themselves: plt.legend(['Mean', 'Median'])
    # assigns labels to the first artists on the axes, which (depending on the
    # Matplotlib version) can be histogram bars rather than these lines.
    plt.axvline(mean_non_zero, color='r', linestyle='dashed', linewidth=1, label='Mean')
    plt.axvline(median_non_zero, color='orange', linestyle='dashed', linewidth=1, label='Median')
    plt.axvline(mean_non_zero + std_non_zero, color='g', linestyle='dashed', linewidth=1, label='Mean ± Std')
    plt.axvline(mean_non_zero - std_non_zero, color='g', linestyle='dashed', linewidth=1)
    plt.legend()
    plt.show()

In [None]:
if do_stats_plot:
    # 1 where an embedding entry (atom) is non-zero, 0 otherwise
    binary_embeddings = np.where(embeddings != 0, 1, 0)

    def plot_activation_map(embedding_np, mode = 'count'):
        """Plot the per-atom column sums of ``embedding_np`` as a 2-D map.

        Args:
            embedding_np: 2-D array (samples x atoms). With ``mode='count'`` a 0/1
                matrix is expected, so column sums are activation counts; with
                ``mode='sum'`` the raw values are summed.
            mode: ``'count'`` or ``'sum'``; only changes the printed summary and title.

        Raises:
            ValueError: if ``mode`` is neither ``'count'`` nor ``'sum'``.
        """
        if mode not in ('count', 'sum'):
            raise ValueError(f"mode must be 'count' or 'sum', got {mode!r}")

        sum_col = np.sum(embedding_np, axis=0)

        if mode == 'count':
            print(f"\nMax activated atom: {np.max(sum_col)} times (index: {np.argmax(sum_col)})")
            print(f"Min activated atom: {np.min(sum_col)} times (index: {np.argmin(sum_col)})")
        else:
            print(f"\nMax activated atom: value {np.max(sum_col)} (index: {np.argmax(sum_col)})")
            print(f"Min activated atom: value {np.min(sum_col)} (index: {np.argmin(sum_col)})")

        # Map of the per-atom sums laid out by reshape_to_square (presumably a
        # near-square grid); with the YlGn colormap darker green = more activated.
        plt.figure(figsize=(10, 5))
        plt.imshow(reshape_to_square(sum_col), cmap='YlGn', interpolation='nearest')
        plt.colorbar()
        plt.title(f'Activation map of the atoms ({mode.upper()})')
        plt.show()

    plot_activation_map(binary_embeddings)  # activation COUNT per atom
    plot_activation_map(embeddings, 'sum')  # COUNT + WEIGHTS (sum of activation values)



## K-means Clustering

In [46]:
n_clusters: int = 50  # number of K-means clusters used by both clusterings below

### K-MEANS with TSNE (Skip)

In [47]:
from sklearn.cluster import KMeans 

if not do_skip:
    # K-means on the 2-D t-SNE projection
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
    cluster_labels_2d = kmeans.fit(embeddings_tsne).labels_  # fit() returns the estimator
    print(f"Classes: {set(cluster_labels_2d)}")

    labels_2d_out = os.path.join(outdir, f'cluster_labels_2d_{dataset_name}_{take}.npy')
    np.save(labels_2d_out, cluster_labels_2d)

    # Cluster-size distribution
    plt.figure(figsize=(10, 5))
    plt.hist(cluster_labels_2d, bins=n_clusters)
    plt.xlabel('Cluster Label')
    plt.ylabel('Count')
    plt.title('Histogram of Cluster Labels')
    plt.show()


### load

In [None]:
# Reload the K-means labels computed on the t-SNE projection
cluster_labels_2d_file = os.path.join(outdir, f'cluster_labels_2d_{dataset_name}_{take}.npy')
cluster_labels_2d = np.load(cluster_labels_2d_file)
print(f"Loading cluster labels from {cluster_labels_2d_file}")
print("Cluster labels loaded")
print(f"Cluster labels shape: {cluster_labels_2d.shape}")

In [None]:
if do_plot:
    # t-SNE scatter coloured by the K-means labels computed in 2-D
    interactive_plot(
        embeddings_tsne[:, 0],
        embeddings_tsne[:, 1],
        colors=cluster_labels_2d,
        action='click',
        img_list=vess_list_np,
        emb_list=embeddings,
        true_img_list=img_list_np,
    )

### K-MEANS with embedding code (Skip)

In [50]:
from sklearn.cluster import KMeans

if not do_skip:
    # K-means directly on the full embedding vectors
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init='auto')
    cluster_labels = kmeans.fit(embeddings).labels_  # fit() returns the estimator
    print(f"Classes: {set(cluster_labels)}")

    labels_out = os.path.join(outdir, f'cluster_labels_{dataset_name}_{take}.npy')
    np.save(labels_out, cluster_labels)

### load

In [None]:
# Reload the K-means labels computed on the full embeddings
cluster_labels_file = os.path.join(outdir, f'cluster_labels_{dataset_name}_{take}.npy')
cluster_labels = np.load(cluster_labels_file)
print("Cluster labels loaded")
print(f"Cluster labels shape: {cluster_labels.shape}")

In [None]:
if do_plot:
    # t-SNE scatter coloured by the K-means labels computed on the full embeddings
    interactive_plot(
        embeddings_tsne[:, 0],
        embeddings_tsne[:, 1],
        colors=cluster_labels,
        action='click',
        img_list=vess_list_np,
        emb_list=embeddings,
        true_img_list=img_list_np,
    )

In [None]:
if do_plot:
    # Compare the embedding-space and t-SNE-space clusterings on the same projection
    plot_2_clusters(embeddings_tsne,
                    cluster_labels,
                    cluster_labels_2d)

### Optional: save cluster images

In [None]:
import numpy as np
from tqdm import tqdm

if do_optional_plots:
    # Which labels to filter by: '2D' (K-means on t-SNE) or 'ND' (K-means on embeddings)
    pick_colors = 'ND'
    add_title = f"{pick_colors}_{take}"

    # Export `n_selected_classes` clusters chosen at random (without replacement).
    # Alternatives: range(10), or range(n_clusters) to export every cluster.
    n_selected_classes = 1
    selected_classes = np.random.choice(n_clusters, n_selected_classes, replace=False)

    colors = cluster_labels_2d if pick_colors == '2D' else cluster_labels
    do = "filter"  # "filter" or "plot"
    n_columns = 10

    print(f"Interactive plot filtered over 1 class ({pick_colors} Embedding)")
    print(f"Saving image of clusters {selected_classes}")
    for selected_class in tqdm(selected_classes):
        plt.close('all')  # avoid accumulating open figures across clusters
        f_img_list, f_emb_list, f_true_img_list, f_indices = interactive_plot_filtered(embeddings_tsne[:, 0], embeddings_tsne[:, 1], colors=colors, action='click', selected_class=selected_class, img_list=vess_list_np, emb_list=embeddings, true_img_list=img_list_np, do=do)
        plot_n_patches_overlap(vess_list_np, img_list_np, indexes=f_indices, selected_class=selected_class, add_title=add_title, m=n_columns)

In [None]:
if do_optional_plots:
    # Plot a single hand-picked cluster
    inspected_class = 12
    interactive_plot_filtered(
        embeddings_tsne[:, 0],
        embeddings_tsne[:, 1],
        colors=colors,
        action='click',
        selected_class=inspected_class,
        img_list=vess_list_np,
        emb_list=embeddings,
        true_img_list=img_list_np,
        do='plot',
    )

In [None]:
if do_optional_plots:
    interactive_plot(
        embeddings_tsne[:, 0],
        embeddings_tsne[:, 1],
        colors=cluster_labels,
        action='click',
        img_list=vess_list_np,
        emb_list=embeddings,
        true_img_list=img_list_np,
    )

## SIAMESE NETWORK

In [54]:
use_resnet: bool = False  # True replaces the small-CNN SiameseNetwork with the ResNet-50 variant

In [None]:
# Shapes of the patches, embeddings and cluster labels currently in memory
print(f"Img shape: {img_list_np.shape}",
      f"Data shape: {vess_list_np.shape}\n", sep="\n")
print(f"Embeddings shape: {embeddings.shape}",
      f"TSNE embeddings shape: {embeddings_tsne.shape}\n", sep="\n")
print(f"Cluster labels shape: {cluster_labels.shape}",
      f"Cluster labels 2D shape: {cluster_labels_2d.shape}", sep="\n")

### LOAD EMPTY

In [None]:
print(f"Empty img list shape: {empty_img_list_np.shape}")
print(f"Empty data list shape: {empty_vess_list_np.shape}")

max_cluster_value = np.max(cluster_labels)
max_cluster_value_2d = np.max(cluster_labels_2d)

print(f"Max cluster value: {max_cluster_value}")
print(f"Max cluster value 2D: {max_cluster_value_2d}")

# Empty patches form one extra class: one label per empty patch
# (length = number of empty patches), equal to the largest existing label + 1.
# NOTE: run this before the overwrite cell below. Re-running it afterwards reads the
# already-extended labels and would shift the empty class by one.
empty_cluster_labels = np.full((empty_img_list_np.shape[0],), max_cluster_value+1)
empty_cluster_labels_2d = np.full((empty_img_list_np.shape[0],), max_cluster_value_2d+1)

In [None]:
# Append the empty patches (and their extra-class labels) after the vessel patches
img_list_np_tot = np.concatenate([img_list_np, empty_img_list_np])
vess_list_np_tot = np.concatenate([vess_list_np, empty_vess_list_np])
print(f"Img shape: {img_list_np_tot.shape}",
      f"Data shape: {vess_list_np_tot.shape}\n", sep="\n")

cluster_labels_tot = np.concatenate([cluster_labels, empty_cluster_labels])
cluster_labels_2d_tot = np.concatenate([cluster_labels_2d, empty_cluster_labels_2d])
print(f"Cluster labels shape: {cluster_labels_tot.shape}",
      f"Cluster labels 2D shape: {cluster_labels_2d_tot.shape}", sep="\n")

unique_cluster_labels = set(cluster_labels_tot)
print(f"Unique cluster labels {len(unique_cluster_labels)} ==> {unique_cluster_labels}")

### Overwrite the variables with the versions that include the empty patches (do not re-run the cells above afterwards)

In [None]:
# Replace the working arrays with the versions that include the empty patches.
img_list_np = img_list_np_tot
vess_list_np = vess_list_np_tot

# These are the combined arrays, not the empty-only ones (messages previously said "Empty").
print(f"Img list shape (incl. empty patches): {img_list_np.shape}")
print(f"Data list shape (incl. empty patches): {vess_list_np.shape}")

cluster_labels = cluster_labels_tot
cluster_labels_2d = cluster_labels_2d_tot

print(f"Max cluster value: {np.max(cluster_labels)}")
print(f"Max cluster value 2D: {np.max(cluster_labels_2d)}")

### CONTINUE SIAMESE

In [59]:
%matplotlib inline

def imshow(img, text=None, should_save=False):
    """Show a channel-first image tensor in a new figure, optionally with a caption.

    Args:
        img: tensor with a ``.numpy()`` method, laid out (C, H, W).
        text: optional caption drawn in a white box near the top-left.
        should_save: accepted for compatibility; currently unused.
    """
    chw_array = img.numpy()
    plt.figure()
    plt.axis("off")
    if text:
        caption_box = {'facecolor': 'white', 'alpha': 0.8, 'pad': 10}
        plt.text(75, 8, text, style='italic', fontweight='bold', bbox=caption_box)
    # (C, H, W) -> (H, W, C) as expected by plt.imshow
    plt.imshow(chw_array.transpose(1, 2, 0))
    plt.show()

def show_plot(iteration, loss, title=''):
    """Line-plot ``loss`` against ``iteration`` in a new figure titled '<title> Loss'."""
    plt.figure()
    plt.plot(iteration, loss)
    plt.title("{} Loss".format(title))
    plt.show()

In [60]:
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torch
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torch
import random

class SiameseNetworkDataset(Dataset):
    """Image-patch dataset for training/evaluating a Siamese network.

    In ``'single'`` mode an item is one preprocessed image. In any other mode
    (default ``'double'``) an item is ``(img0, img1, flag)``:

    * ``img0`` is the requested sample.
    * ``img1`` is drawn at random from the same class or from a different class,
      with probability 1/2 each.
    * ``flag`` is ``0.`` for a same-class pair and ``1.`` for a different-class
      pair (float32 tensor of shape (1,)).

    Preprocessing of every image: ``Image.fromarray`` -> optional ``transform`` ->
    standardization with ``(mean, std)`` -> min-max rescaling to [0, 1].

    Args:
        images_array: array of images indexed along axis 0.
        labels_array: class label of each image (same length as ``images_array``).
        transform: optional callable applied to each PIL image (e.g. ``ToTensor``).
        test_mode: if True, print the indices and classes of every sampled pair.
        mode: ``'single'`` or ``'double'`` (see above).
        metadata: optional ``(mean, std)`` used for standardization, e.g. the
            training-set statistics. If None, they are computed from ``images_array``.
            Otherwise this dataset's own statistics are still computed and exposed
            via ``get_test_metadata``.
    """

    def __init__(self, images_array, labels_array, transform=None, test_mode=False, mode='double', metadata=None):
        self.images_array = images_array
        self.labels_array = labels_array
        self.transform = transform
        # A different-class partner only exists when there are at least two labels
        self.n_classes = len(np.unique(np.asarray(labels_array)))

        if metadata is None:
            self.compute_metadata()
        else:
            self.mean, self.std = metadata
            self.compute_test_metadata()

        self.test_mode = test_mode
        self.mode = mode

    @staticmethod
    def _min_max(img):
        """Rescale ``img`` to [0, 1]; a constant image maps to zeros instead of NaN."""
        lo, hi = img.min(), img.max()
        value_range = hi - lo
        if value_range == 0:
            return img - lo  # all zeros, same type/shape as img
        return (img - lo) / value_range

    def _preprocess(self, img):
        """PIL conversion, optional transform, standardization, then [0, 1] rescaling."""
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        img = (img - self.mean) / self.std
        return self._min_max(img)

    def _sample_partner(self, idx_0, same):
        """Return a uniformly random index whose label equals (``same``) or differs from
        the label of ``idx_0``. ``idx_0`` itself is a valid same-class partner.

        Raises:
            ValueError: if a different-class partner is requested but every sample
                has the same label (the rejection loop would never terminate).
        """
        label0 = self.labels_array[idx_0]
        if not same and self.n_classes < 2:
            raise ValueError("Cannot draw a different-class pair: labels_array contains a single class")
        while True:
            idx_1 = random.randint(0, len(self.images_array) - 1)
            if (self.labels_array[idx_1] == label0) == same:
                return idx_1

    def __getitem__(self, index):
        idx_0 = int(index)
        if self.mode == 'single':
            return self._preprocess(self.images_array[idx_0])

        label0 = self.labels_array[idx_0]
        # Same-class or different-class partner with probability 1/2 each
        should_get_same_class = random.randint(0, 1)
        idx_1 = self._sample_partner(idx_0, same=bool(should_get_same_class))
        label1 = self.labels_array[idx_1]

        if self.test_mode:
            print(f"Testing {idx_0} (class {label0}) against {idx_1} (class {label1})")

        img0 = self._preprocess(self.images_array[idx_0])
        img1 = self._preprocess(self.images_array[idx_1])

        # Flag: 0 if same class, 1 if different class (ContrastiveLoss convention)
        flag = torch.from_numpy(np.array([int(label1 != label0)], dtype=np.float32))

        return img0, img1, flag

    def _mean_std(self):
        """Return ``(mean, std)`` of ``images_array`` as 0-d tensors.

        ``mean`` is the average of the per-pixel means over samples. ``std`` is the
        average of the per-pixel standard deviations, not the global std.
        """
        flattened_data = self.images_array.reshape(self.images_array.shape[0], -1)
        mean = np.mean(np.mean(flattened_data, axis=0))
        std = np.mean(np.std(flattened_data, axis=0))
        return torch.tensor(mean), torch.tensor(std)

    def compute_metadata(self):
        """Compute this dataset's statistics and use them for standardization."""
        self.mean, self.std = self._mean_std()
        print(f"Mean: {self.mean} - Std: {self.std}")

    def compute_test_metadata(self):
        """Compute this dataset's statistics for reference only (not used to normalize)."""
        self.test_metadata = self._mean_std()
        print(f"Test Mean: {self.test_metadata[0]} - Test Std: {self.test_metadata[1]}")

    def get_metadata(self):
        return self.mean, self.std

    def get_test_metadata(self):
        return self.test_metadata

    def __len__(self):
        return len(self.images_array)

In [None]:
# Shapes after the empty patches have been appended
print(f"Img shape: {img_list_np.shape}",
      f"Data shape: {vess_list_np.shape}\n", sep="\n")
print(f"Embeddings shape: {embeddings.shape}",
      f"TSNE embeddings shape: {embeddings_tsne.shape}\n", sep="\n")
print(f"Cluster labels shape: {cluster_labels.shape}",
      f"Cluster labels 2D shape: {cluster_labels_2d.shape}", sep="\n")

### Skip

In [None]:
if not do_skip:
    # 85/15 train/validation split of images, vessel masks and cluster labels.
    # The test set is a different dataset (CAS), loaded further below.
    train_arrays, val_arrays = train_test_split_arrays(img_list_np, vess_list_np, cluster_labels, test_size=0.15, random_state=42)
    X_train, X_train_mask, y_train = train_arrays
    X_val, X_val_mask, y_val = val_arrays

    print("Train set shape:", X_train.shape)
    print("Val set shape:", X_val.shape)
    print(y_train.shape, y_val.shape)

    # Persist each split as <name>_<dataset_name>_<take>.npy
    for split_name, split_array in (
        ('X_train', X_train),
        ('X_val', X_val),
        ('y_train', y_train),
        ('y_val', y_val),
        ('X_train_mask', X_train_mask),
        ('X_val_mask', X_val_mask),
    ):
        np.save(os.path.join(save_embeddings_path, f'{split_name}_{dataset_name}_{take}.npy'), split_array)
else:
    print("Skipping train-test split")

### Load the CAS test dataset (requires patches_dataloader to have been run first)

In [63]:
# Start from a clean slate of figures
plt.close('all')
dataset_test_name = 'CAS'

if not do_skip:
    # Vessel-containing patches of the external CAS test set
    testset_img_dir = '/data/falcetta/brain_data/CASJ/preprocessed/patches_preprocessed_with_empty/5_seg_true_patch_extraction'
    img_list_np_test, data_list_np_test = extract_img_vessels_np2(testset_img_dir)
    X_test, X_test_mask = img_list_np_test, data_list_np_test
    # The test set has no cluster labels: every sample gets label 0
    y_test = np.zeros(len(X_test))

    print("Test set shape:", X_test.shape)
    print("Test set mask shape:", X_test_mask.shape)
    print("Test set labels shape:", y_test.shape)

    plot_10_patches(X_test_mask, X_test)

In [None]:
if not do_skip:
    # Persist the CAS test arrays as <name>_<dataset_test_name>_<take>.npy
    for stem, test_array in (('X_test', X_test), ('y_test', y_test), ('X_test_mask', X_test_mask)):
        np.save(os.path.join(save_embeddings_path, f'{stem}_{dataset_test_name}_{take}.npy'), test_array)
    print(f"Test set saved in {save_embeddings_path} ({take})")

In [None]:
if not do_skip:
    # The empty-patch directory is the vessel-patch directory with 'true' -> 'empty'
    empty_patches_dir = testset_img_dir.replace('true', 'empty')
    empty_img_list_np_test = extract_empty_img_vessels(empty_patches_dir)
    # No vessels in these patches: all-zero masks of the same shape
    empty_data_list_np_test = np.zeros_like(empty_img_list_np_test)
    print(f'Empty img list shape: {empty_img_list_np_test.shape}, Data list shape: {empty_data_list_np_test.shape}')

    X_test_empty = empty_img_list_np_test
    X_test_mask_empty = empty_data_list_np_test
    y_test_empty = np.zeros(len(X_test_empty))

    print("Test empty set shape:", X_test_empty.shape)
    print("Test empty set mask shape:", X_test_mask_empty.shape)
    print("Test empty set labels shape:", y_test_empty.shape)

    plot_10_patches(X_test_mask_empty, X_test_empty)

In [None]:
if not do_skip:
    # Persist the empty CAS patches (note the '_empty_mask' file-name ordering)
    for stem, empty_array in (
        ('X_test_empty', X_test_empty),
        ('y_test_empty', y_test_empty),
        ('X_test_empty_mask', X_test_mask_empty),
    ):
        np.save(os.path.join(save_embeddings_path, f'{stem}_{dataset_test_name}_{take}.npy'), empty_array)
    print(f"Empty set saved in {save_embeddings_path} ({take})")

## Load DATA

In [64]:
import os
import numpy as np

In [None]:
def _load_split_npy(stem, ds_name):
    """Load <save_embeddings_path>/<stem>_<ds_name>_<take>.npy."""
    return np.load(os.path.join(save_embeddings_path, f'{stem}_{ds_name}_{take}.npy'))


X_train = _load_split_npy('X_train', dataset_name)
X_train_mask = _load_split_npy('X_train_mask', dataset_name)
y_train = _load_split_npy('y_train', dataset_name)

X_val = _load_split_npy('X_val', dataset_name)
X_val_mask = _load_split_npy('X_val_mask', dataset_name)
y_val = _load_split_npy('y_val', dataset_name)

X_test = _load_split_npy('X_test', dataset_test_name)
X_test_mask = _load_split_npy('X_test_mask', dataset_test_name)
y_test = _load_split_npy('y_test', dataset_test_name)

X_test_empty = _load_split_npy('X_test_empty', dataset_test_name)
X_test_mask_empty = _load_split_npy('X_test_empty_mask', dataset_test_name)
y_test_empty = _load_split_npy('y_test_empty', dataset_test_name)

print("Train Data loaded (IXI)",
      f"X_train shape: {X_train.shape}",
      f"X_train_mask shape: {X_train_mask.shape}",
      f"y_train shape: {y_train.shape}\n", sep="\n")

print("Val Data loaded (IXI)",
      f"X_val shape: {X_val.shape}",
      f"X_val_mask shape: {X_val_mask.shape}",
      f"y_val shape: {y_val.shape}\n", sep="\n")

print("Test Data loaded (CAS)",
      f"X_test shape: {X_test.shape}",
      f"X_test_mask shape: {X_test_mask.shape}",
      f"y_test shape: {y_test.shape}\n", sep="\n")

print(f"X_test_empty shape: {X_test_empty.shape}",
      f"X_test_empty_mask shape: {X_test_mask_empty.shape}",
      f"y_test_empty shape: {y_test_empty.shape}", sep="\n")



In [None]:
# Append the empty CAS patches after the vessel-containing ones
X_test_tot = np.concatenate([X_test, X_test_empty])
X_test_mask_tot = np.concatenate([X_test_mask, X_test_mask_empty])
y_test_tot = np.concatenate([y_test, y_test_empty])

print(f"X_test_tot shape: {X_test_tot.shape}",
      f"X_test_mask_tot shape: {X_test_mask_tot.shape}",
      f"y_test_tot shape: {y_test_tot.shape}\n", sep="\n")

### Overwrite the test variables with the versions that include the empty patches

In [None]:
# From here on, the test arrays include the empty patches
X_test, X_test_mask, y_test = X_test_tot, X_test_mask_tot, y_test_tot

print(f"X_test shape: {X_test.shape}",
      f"X_test_mask shape: {X_test_mask.shape}",
      f"y_test shape: {y_test.shape}\n", sep="\n")

In [None]:
plot_10_patches(X_test_mask, true_vessel_np=X_test)  # masks in the left column, images in the right

### skip

In [None]:

import torchvision.transforms as transforms

if not do_skip:
    to_tensor = transforms.ToTensor()  # stateless, safe to share between datasets

    siamese_dataset = SiameseNetworkDataset(images_array=X_train, labels_array=y_train, transform=to_tensor)

    # Training-set mean/std: saved, and reused to standardize the validation set
    metadata = siamese_dataset.get_metadata()
    np.save(os.path.join(save_embeddings_path, f'metadata_{dataset_name}_{take}.npy'), metadata)

    siamese_val_dataset = SiameseNetworkDataset(
        images_array=X_val,
        labels_array=y_val,
        transform=to_tensor,
        metadata=metadata,
    )

### Load metadata

In [None]:
# Training-set (mean, std) saved when the datasets were built
metadata_npy_path = os.path.join(save_embeddings_path, f'metadata_{dataset_name}_{take}.npy')
metadata = np.load(metadata_npy_path)
print(metadata)

### skip

In [None]:
from torch.utils.data import DataLoader
import torchvision

if not do_skip:
    vis_dataloader = DataLoader(siamese_dataset,
                            shuffle=True,
                            num_workers=0,
                            batch_size=1)

    plot_same = True
    plot_diff = True

    print("Plotting examples of the same (0) and different (1) clusters")

    plt.close('all')
    # No plt.figure() here: imshow() opens its own figure, so an extra call only
    # left an empty figure behind.
    for example_batch in vis_dataloader:
        img0, img1, flag_tensor = example_batch
        flag = int(flag_tensor)  # 0 = same cluster, 1 = different cluster
        if (flag == 0 and plot_same) or (flag == 1 and plot_diff):
            # Two batches of size 1 stacked along dim 0 -> a grid of 2 images
            concatenated = torch.cat((img0, img1), 0)
            print(f"Flag: {flag}")
            imshow(torchvision.utils.make_grid(concatenated))
            if flag == 0:
                plot_same = False
            else:
                plot_diff = False
        if not plot_same and not plot_diff:
            break

### Define the model and loss

In [171]:
### ORIGINAL

import torch.nn as nn

class SiameseNetwork(nn.Module):
    """Small CNN encoder shared by both branches of a Siamese network.

    Three reflection-padded 3x3 conv blocks (1 -> 4 -> 8 -> 8 channels, each followed
    by ReLU, BatchNorm2d and Dropout2d) keep the spatial size unchanged. An MLP
    then maps the flattened features to an ``embedding_size`` vector.

    Args:
        embedding_size: dimension of the output embedding.
        pretrained: unused; kept for interface compatibility with the ResNet variant.
        input_size: side length of the square, single-channel input patches.
            Defaults to 32, the size the fully connected layer was originally built for.
    """

    def __init__(self, embedding_size=128, pretrained = False, input_size=32):
        super(SiameseNetwork, self).__init__()
        self.cnn1 = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(1, 4, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(4),
            nn.Dropout2d(p=0.3),

            nn.ReflectionPad2d(1),
            nn.Conv2d(4, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=0.3),

            nn.ReflectionPad2d(1),
            nn.Conv2d(8, 8, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(8),
            nn.Dropout2d(p=0.3)
        )

        # Padding 1 + kernel 3 preserves H and W, so the flattened size is 8 * H * W
        self.fc1 = nn.Sequential(
            nn.Linear(8 * input_size * input_size, 500),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(500),

            nn.Linear(500, 500),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(500),

            nn.Linear(500, embedding_size)
        )

    def forward_once(self, x):
        """Embed a batch ``x`` of shape (N, 1, input_size, input_size) -> (N, embedding_size)."""
        output = self.cnn1(x)
        output = output.view(output.size()[0], -1)
        output = self.fc1(output)
        return output

    def forward(self, input1, input2=None, mode='double'):
        """Embed one batch (``mode='single'``) or a pair of batches.

        Returns ``output1`` in single mode, otherwise ``(output1, output2)``.

        Raises:
            ValueError: if ``input2`` is missing and ``mode`` is not ``'single'``.
        """
        output1 = self.forward_once(input1)
        if mode == 'single':
            return output1
        if input2 is None:
            raise ValueError("input2 is required unless mode='single'")
        output2 = self.forward_once(input2)

        return output1, output2


In [172]:
### WITH RESNET
if use_resnet:
    print("USE RESNET")
    import torch.nn as nn
    import torchvision.models as models

    class SiameseNetwork(nn.Module):
        """ResNet-50 based Siamese encoder (replaces the small-CNN version when use_resnet).

        A single-channel 7x7 stem followed by the ResNet-50 stages and global
        average pool produces ``resnet.fc.in_features`` features per image. An MLP
        then maps them to ``embedding_size``.
        """

        def __init__(self, embedding_size=128, pretrained = False):
            super(SiameseNetwork, self).__init__()

            # NOTE(review): `pretrained=` is torchvision's legacy argument (newer
            # releases prefer `weights=`). The first conv is replaced below, so its
            # pretrained weights are discarded either way.
            resnet = models.resnet50(pretrained=pretrained)

            # New 1-channel first conv, then bn1 ... layer4 and the global average pool
            # (only fc is dropped). Keeping avgpool makes the flattened size equal
            # resnet.fc.in_features for any input size. Without it the sizes only
            # matched when layer4's output was 1x1 (e.g. 32x32 inputs), and for those
            # inputs the pool is an identity, so results there are unchanged.
            self.resnet_features = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
                *list(resnet.children())[1:-1]
            )
            # Fully connected layers producing the embedding
            self.fc1 = nn.Sequential(
                nn.Linear(resnet.fc.in_features, 500),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(500),
                nn.Linear(500, 500),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(500),
                nn.Linear(500, embedding_size)
            )

        def forward_once(self, x):
            """Embed a batch of single-channel images."""
            features = self.resnet_features(x)
            features = features.view(features.size(0), -1)
            output = self.fc1(features)
            return output

        def forward(self, input1, input2, mode='double'):
            """Return ``output1`` when ``mode == 'single'``, otherwise ``(output1, output2)``."""
            output1 = self.forward_once(input1)
            if mode == 'single':
                return output1
            output2 = self.forward_once(input2)
            return output1, output2

In [None]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    With ``d`` the Euclidean distance between ``output1`` and ``output2`` and
    ``label`` 0 for a similar pair and 1 for a dissimilar pair:

        loss = mean((1 - label) * d**2 + label * max(margin - d, 0)**2)

    Every call also adds the returned (batch-mean) loss to a running accumulator.
    ``num_samples`` counts calls, not individual pairs. Calls made during
    validation are accumulated too: call ``reset()`` (or ``get_mean_loss()``) after
    a validation pass if training and validation losses must not be mixed.
    """

    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        self.loss_accumulator = 0.0
        self.num_samples = 0

    def forward(self, output1, output2, label):
        # Fully qualified so the class does not depend on `F` being imported in a later cell
        euclidean_distance = torch.nn.functional.pairwise_distance(output1, output2, keepdim=True)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) + (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

        self.loss_accumulator += loss_contrastive.item()
        self.num_samples += 1
        return loss_contrastive

    def get_accumulated_loss(self):
        """Sum of the losses returned since the last reset."""
        return self.loss_accumulator

    def get_mean_loss(self):
        """Mean of the accumulated per-call losses, then reset the accumulator.

        Returns NaN (instead of raising ZeroDivisionError) when nothing was accumulated.
        """
        if self.num_samples == 0:
            self.mean_loss = float('nan')
        else:
            self.mean_loss = self.loss_accumulator / self.num_samples
        self.reset()
        return self.mean_loss

    def reset(self):
        """Clear the running accumulator."""
        self.loss_accumulator = 0.0
        self.num_samples = 0

# NOTE(review): the string below is an older, unused ContrastiveLoss variant kept
# for reference only; it is never executed. Its label convention is the opposite
# of the class above (here gt=1 pulls a pair together), and it returns the
# un-reduced per-pair loss instead of the mean.
"""
class ContrastiveLoss(torch.nn.Module):

    def __init__(self, margin=.5, **kwargs):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        # self.metric = metric
        self.distance = torch.nn.PairwiseDistance(p=2)

    def forward(self, out0, out1, label):
        gt = label.float()
        D = self.distance(out0, out1).float().squeeze()
        loss = gt * 0.5 * torch.pow(D, 2) + (1 - gt) * 0.5 * torch.pow(torch.clamp(self.margin - D, min=0.0), 2)
        return loss
"""

### skip

In [174]:
if not do_skip:
    # Shuffled training pairs; validation batches are not shuffled
    train_dataloader = DataLoader(siamese_dataset, shuffle=True, num_workers=8, batch_size=128)
    val_dataloader = DataLoader(siamese_val_dataset, shuffle=False, num_workers=8, batch_size=256)

### Training (skip)

In [83]:
if not do_skip:
    net = SiameseNetwork().cuda()
    criterion = ContrastiveLoss()

    # Best validation loss so far; +inf so the first epoch always counts as an
    # improvement (replaces the magic sentinel 9999)
    best_loss = float('inf')
    best_loss_epoch = 0
    cumul_epochs = 0  # epoch counter carried across training rounds

    # Per-epoch loss histories
    mean_loss_contrastive_list = []
    best_loss_contrastive_list = []
    validation_loss_list = []
    # Set True when resuming; the training cell then converts saved history arrays back to lists
    continue_training = False

In [84]:
from torch import optim
import torch.nn.functional as F

if not do_skip:
    n_epochs = 500
    early_stopping_tolerance = 50  # epochs without a new best val loss before stopping
    optimizer = optim.Adam(net.parameters(), lr=0.0005)

    if continue_training:
        # Histories reloaded with np.load are ndarrays; np.asarray(...).tolist() turns them
        # back into appendable lists and also accepts histories that are still lists
        # (where calling `.tolist()` directly would raise AttributeError).
        mean_loss_contrastive_list = np.asarray(mean_loss_contrastive_list).tolist()
        best_loss_contrastive_list = np.asarray(best_loss_contrastive_list).tolist()
        validation_loss_list = np.asarray(validation_loss_list).tolist()

    print(f"Starting round of training from epoch {cumul_epochs}")

    for epoch in tqdm(range(0, n_epochs), desc='Epochs'):
        # ---- Training ----
        net.train()
        # `criterion` adds every loss it computes to its accumulator, including the
        # validation batches of the previous epoch; clear it so the training mean below
        # only averages this epoch's training batches.
        criterion.reset()
        for i, data in enumerate(train_dataloader, 0):
            img0, img1, label = data
            img0, img1, label = img0.cuda(), img1.cuda(), label.cuda()
            optimizer.zero_grad()
            output1, output2 = net(img0, img1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()

        # Mean training loss per batch (get_mean_loss also resets the accumulator)
        mean_loss_contrastive = criterion.get_mean_loss()
        mean_loss_contrastive_list.append(mean_loss_contrastive)

        # ---- Validation ----
        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            for i, val_data in enumerate(val_dataloader, 0):
                val_img0, val_img1, val_label = val_data
                val_img0, val_img1, val_label = val_img0.cuda(), val_img1.cuda(), val_label.cuda()
                val_output1, val_output2 = net(val_img0, val_img1)
                val_loss += criterion(val_output1, val_output2, val_label).item()

        val_loss /= len(val_dataloader)  # mean over validation batches
        validation_loss_list.append(val_loss)

        # ---- Checkpointing / early stopping ----
        if val_loss < best_loss:
            print(f"Epoch number {cumul_epochs} --- Best loss {val_loss:.2f}")
            best_loss = val_loss
            best_loss_epoch = cumul_epochs
            torch.save(net.state_dict(), os.path.join(save_embeddings_path, 'best_model.pt'))
        elif cumul_epochs - best_loss_epoch > early_stopping_tolerance:
            print(f"Early stopping at epoch {cumul_epochs}")
            best_loss_contrastive_list.append(best_loss)
            cumul_epochs += 1
            break

        best_loss_contrastive_list.append(best_loss)
        cumul_epochs += 1
        

In [85]:
if not do_skip:
    # Persist each loss history as <name>_<take>.npy
    loss_histories = {
        'mean_loss_contrastive_list': mean_loss_contrastive_list,
        'best_loss_contrastive_list': best_loss_contrastive_list,
        'validation_loss_list': validation_loss_list,
    }
    for history_name, history_values in loss_histories.items():
        np.save(os.path.join(save_embeddings_path, f'{history_name}_{take}.npy'), history_values)
    print(f"Losses saved in {save_embeddings_path} ({take})")


### Plot losses

In [None]:
if do_plot:

    mean_loss_contrastive_list = np.load(os.path.join(save_embeddings_path, f'mean_loss_contrastive_list_{take}.npy'))
    best_loss_contrastive_list = np.load(os.path.join(save_embeddings_path, f'best_loss_contrastive_list_{take}.npy'))
    validation_loss_list = np.load(os.path.join(save_embeddings_path, f'validation_loss_list_{take}.npy'))

    continue_training=True

    %matplotlib inline
    plt.ioff()
    plt.close('all')
    show_plot(range(0,len(mean_loss_contrastive_list)),mean_loss_contrastive_list, title='Training Loss')
    show_plot(range(0,len(best_loss_contrastive_list)),validation_loss_list, title='Val Loss')
    show_plot(range(0,len(best_loss_contrastive_list)),best_loss_contrastive_list, title='Best Val Loss')


### LOAD MODEL and Test (skip)

In [176]:
import torch.nn.functional as F
do_skip = False  # NOTE: overrides the global flag, so every later `if not do_skip` cell runs
if not do_skip:
    # Rebuild the architecture and restore the best checkpoint saved during training
    checkpoint_path = os.path.join(save_embeddings_path, 'best_model.pt')
    net = SiameseNetwork().cuda()
    best_state_dict = torch.load(checkpoint_path)
    net.load_state_dict(best_state_dict)
    net.eval()

### Skip

In [None]:
batch_size = 128

siamese_tot_IXI_dataset = SiameseNetworkDataset(images_array=np.concatenate((X_train, X_val), axis=0),
                                                labels_array=np.concatenate((y_train, y_val), axis=0),
                                                transform=transforms.ToTensor(),
                                                test_mode=True,
                                                mode='single')

siamese_tot_IXI_dataloader = DataLoader(siamese_tot_IXI_dataset,
                                        shuffle=False,
                                        num_workers=8,
                                        batch_size=batch_size)


# Pre-allocate memory for the embeddings array
num_samples = len(siamese_tot_IXI_dataset)
embedding_size = 128  # must equal the output width of net(..., mode='single')
img_embeddings = np.empty((num_samples, embedding_size))
# len(dataloader) == ceil(num_samples / batch_size); `num_samples // batch_size + 1`
# over-counts by one whenever num_samples is a multiple of batch_size.
print(f"Number of samples: {num_samples} (batch size: {batch_size}) ==> Passages: {len(siamese_tot_IXI_dataloader)}")

start_idx = 0
net.eval()
# Inference only: no_grad avoids building and holding an autograd graph for every batch.
with torch.no_grad():
    for img in tqdm(siamese_tot_IXI_dataloader):
        output = net(img.cuda(), None, mode='single').cpu().numpy()
        # Write this batch into the next rows of the pre-allocated array
        end_idx = start_idx + output.shape[0]
        img_embeddings[start_idx:end_idx] = output
        start_idx = end_idx

# Every np.empty row must have been written, otherwise uninitialised memory would be saved.
assert start_idx == num_samples, f"Only {start_idx} of {num_samples} embedding rows were filled"

np.save(os.path.join(save_embeddings_path, f'img_embeddings_IXI_{take}.npy'), img_embeddings)
print(f"Image Embeddings saved in {save_embeddings_path}/img_embeddings_IXI_{take}.npy")

In [None]:
# Reload the IXI (train + val) embeddings written above
IXI_img_embeddings = np.load(
    os.path.join(save_embeddings_path, f'img_embeddings_IXI_{take}.npy')
)
print(f"IXI Image Embeddings loaded")

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

if not do_skip:
    # 2-D t-SNE of the IXI embeddings with default settings
    # (previously tried: perplexity=500, early_exaggeration=40)
    tsne = TSNE(n_components=2, random_state=42)
    IXI_img_embeddings_tsne = tsne.fit_transform(IXI_img_embeddings)

    plt.figure()
    plt.scatter(*IXI_img_embeddings_tsne.T, s=1)
    plt.title('Image Clustering Results (t-SNE img_embedding)')

    # Save the projection for later reloads
    ixi_tsne_file = os.path.join(save_embeddings_path, f'img_embeddings_tsne_AIXI_{take}.npy')
    np.save(ixi_tsne_file, IXI_img_embeddings_tsne)
    print(f"TSNE img_embeddings saved in {ixi_tsne_file}")

In [None]:
# Reload the 2-D t-SNE projection of the IXI embeddings
IXI_img_embeddings_tsne = np.load(
    os.path.join(save_embeddings_path, f'img_embeddings_tsne_AIXI_{take}.npy')
)
print(f"TSNE img_embeddings loaded")
print(f"TSNE img_embeddings shape: {IXI_img_embeddings_tsne.shape}")


In [None]:
%matplotlib widget
if do_plot:
    interactive_plot(IXI_img_embeddings_tsne[:, 0], IXI_img_embeddings_tsne[:, 1], colors=None, action='click', img_list=img_list_np, emb_list=IXI_img_embeddings, true_img_list=img_list_np)

In [None]:
from sklearn.cluster import KMeans

n_clusters_img = 50

if not do_skip:
    # K-means on the full IXI embedding vectors; fit() returns the fitted estimator
    kmeans = KMeans(n_clusters=n_clusters_img, random_state=42, n_init='auto').fit(IXI_img_embeddings)
    IXI_img_cluster_labels = kmeans.labels_
    print(f"Classes: {set(IXI_img_cluster_labels)}")
    np.save(os.path.join(save_embeddings_path, f'img_cluster_labels_AIXI_{take}.npy'), IXI_img_cluster_labels)

In [None]:
# Reload the K-means labels of the IXI embeddings
IXI_img_cluster_labels = np.load(
    os.path.join(save_embeddings_path, f'img_cluster_labels_AIXI_{take}.npy')
)
print(f"Cluster labels loaded")
print(f"Cluster labels shape: {IXI_img_cluster_labels.shape}")

In [None]:
if do_plot:
    # Same IXI scatter, coloured by K-means cluster label
    interactive_plot(
        IXI_img_embeddings_tsne[:, 0],
        IXI_img_embeddings_tsne[:, 1],
        colors=IXI_img_cluster_labels,
        action='click',
        img_list=img_list_np,
        emb_list=IXI_img_embeddings,
        true_img_list=img_list_np,
    )

### Test CAS

In [None]:
from torch.autograd import Variable 
from torchvision import transforms
do_skip = False  # NOTE: re-enables every following `if not do_skip` cell
if not do_skip:
    testset_kwargs = dict(
        images_array=X_test,
        labels_array=y_test,
        transform=transforms.ToTensor(),
        test_mode=True,
        metadata=metadata,
    )
    siamese_testset = SiameseNetworkDataset(**testset_kwargs)

    # Persist the test-set metadata exposed by the dataset
    CAS_metadata = siamese_testset.get_test_metadata()
    print(f"Saving metadata for CAS dataset: {CAS_metadata}")
    cas_metadata_file = os.path.join(save_embeddings_path, f'metadata_{dataset_test_name}_{take}.npy')
    np.save(cas_metadata_file, CAS_metadata)

In [None]:
# Reload the test-set metadata saved above
CAS_metadata = np.load(
    os.path.join(save_embeddings_path, f'metadata_{dataset_test_name}_{take}.npy')
)
print(f"Metadata loaded for CAS dataset: {CAS_metadata}")

In [None]:
if not do_skip:

    batch_size = 32

    siamese_testset = SiameseNetworkDataset(
        images_array=X_test,
        labels_array=y_test,
        transform=transforms.ToTensor(),
        test_mode=True,
        metadata=metadata,
        mode='single'
    )

    test_dataloader = DataLoader(siamese_testset, shuffle=False, batch_size=batch_size)

    # Pre-allocate memory for the embeddings array
    num_samples = len(siamese_testset)
    embedding_size = 128  # must equal the output width of net(..., mode='single')
    img_embeddings = np.empty((num_samples, embedding_size))
    # len(dataloader) == ceil(num_samples / batch_size); `num_samples // batch_size + 1`
    # over-counts by one whenever num_samples is a multiple of batch_size.
    print(f"Number of samples: {num_samples} (batch size: {batch_size}) ==> Passages: {len(test_dataloader)}")

    start_idx = 0
    net.eval()
    # Inference only: no_grad avoids building and holding an autograd graph for every batch.
    with torch.no_grad():
        for img in tqdm(test_dataloader):
            output = net(img.cuda(), None, mode='single').cpu().numpy()
            # Write this batch into the next rows of the pre-allocated array
            end_idx = start_idx + output.shape[0]
            img_embeddings[start_idx:end_idx] = output
            start_idx = end_idx

    # Every np.empty row must have been written, otherwise uninitialised memory would be saved.
    assert start_idx == num_samples, f"Only {start_idx} of {num_samples} embedding rows were filled"

    np.save(os.path.join(save_embeddings_path, f'img_embeddings_{dataset_test_name}_{take}.npy'), img_embeddings)
    print(f"Image Embeddings saved")

In [None]:
from tqdm import tqdm
import torch
from tqdm import tqdm
import numpy as np

if not do_skip:
    # NOTE(review): this cell repeats the previous one; running either is enough.

    batch_size = 32

    siamese_testset = SiameseNetworkDataset(
        images_array=X_test,
        labels_array=y_test,
        transform=transforms.ToTensor(),
        test_mode=True,
        metadata=metadata,
        mode='single'
    )

    test_dataloader = DataLoader(siamese_testset, shuffle=False, batch_size=batch_size)

    # Pre-allocate memory for the embeddings array
    num_samples = len(siamese_testset)
    embedding_size = 128  # must equal the output width of net(..., mode='single')
    img_embeddings = np.empty((num_samples, embedding_size))
    # len(dataloader) == ceil(num_samples / batch_size); `num_samples // batch_size + 1`
    # over-counts by one whenever num_samples is a multiple of batch_size.
    print(f"Number of samples: {num_samples} (batch size: {batch_size}) ==> Passages: {len(test_dataloader)}")

    start_idx = 0
    net.eval()
    # Inference only: no_grad avoids building and holding an autograd graph for every batch.
    with torch.no_grad():
        for img in tqdm(test_dataloader):
            output = net(img.cuda(), None, mode='single').cpu().numpy()
            # Write this batch into the next rows of the pre-allocated array
            end_idx = start_idx + output.shape[0]
            img_embeddings[start_idx:end_idx] = output
            start_idx = end_idx

    # Every np.empty row must have been written, otherwise uninitialised memory would be saved.
    assert start_idx == num_samples, f"Only {start_idx} of {num_samples} embedding rows were filled"

    np.save(os.path.join(save_embeddings_path, f'img_embeddings_{dataset_test_name}_{take}.npy'), img_embeddings)
    print(f"Image Embeddings saved")

### Load img embeddings

In [None]:
# Reload the test-set embeddings
img_embeddings = np.load(
    os.path.join(save_embeddings_path, f'img_embeddings_{dataset_test_name}_{take}.npy')
)

print(f"img_embeddings loaded from {save_embeddings_path}")
print(f'img_embeddings shape: {img_embeddings.shape}')

### Plot IMG Embeddings

### T-SNE (skip)

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

if not do_skip:
    # 2-D t-SNE of the test embeddings with default settings
    # (previously tried: perplexity=500, early_exaggeration=40)
    tsne = TSNE(n_components=2, random_state=42)
    img_embeddings_tsne = tsne.fit_transform(img_embeddings)

    plt.figure()
    plt.scatter(*img_embeddings_tsne.T, s=1)
    plt.title('Image Clustering Results (t-SNE img_embedding)')

    # Save the projection for later reloads
    test_tsne_file = os.path.join(save_embeddings_path, f'img_embeddings_tsne_{dataset_test_name}_{take}.npy')
    np.save(test_tsne_file, img_embeddings_tsne)
    print(f"TSNE img_embeddings saved in {test_tsne_file}")

### Load

In [None]:
# Reload the 2-D t-SNE projection of the test embeddings
img_embeddings_tsne = np.load(
    os.path.join(save_embeddings_path, f'img_embeddings_tsne_{dataset_test_name}_{take}.npy')
)
print(f"TSNE img_embeddings loaded")
print(f"TSNE img_embeddings shape: {img_embeddings_tsne.shape}")

In [None]:
%matplotlib widget
if do_plot:
    interactive_plot(img_embeddings_tsne[:, 0], img_embeddings_tsne[:, 1], colors=None, action='click', img_list=X_test_mask, emb_list=img_embeddings, true_img_list=X_test)

In [None]:
if do_plot:
    # Hand-picked slice indices to inspect
    inspect_indexes = [5597, 24, 25, 26, 23]
    plot_n_patches_overlap(X_test_mask, X_test, indexes=inspect_indexes)

In [None]:
n_clusters_img = 50  # number of K-means clusters used by the cells below

### K-MEANS T-SNE (skip)

In [None]:
from sklearn.cluster import KMeans 

if not do_skip:
    # K-means on the 2-D t-SNE projection
    kmeans = KMeans(n_clusters=n_clusters_img, random_state=42, n_init='auto')

    # Fit the KMeans model to the t-SNE coordinates
    kmeans.fit(img_embeddings_tsne)

    # Cluster label for each data point
    img_cluster_labels_2d = kmeans.labels_
    print(f"Classes: {set(img_cluster_labels_2d)}")
    np.save(os.path.join(save_embeddings_path,f'img_cluster_labels_2d_{dataset_test_name}_{take}.npy'), img_cluster_labels_2d)


    # Histogram of cluster labels: one bin per cluster. The default of 10 bins would
    # lump about n_clusters_img / 10 clusters into each bar.
    plt.figure(figsize=(10, 5))
    plt.hist(img_cluster_labels_2d, bins=n_clusters_img)
    plt.xlabel('Cluster Label')
    plt.ylabel('Count')
    plt.title('Histogram of Cluster Labels')
    plt.show()

### load

In [None]:
# Reload the K-means labels computed on the t-SNE projection
img_cluster_labels_2d = np.load(
    os.path.join(save_embeddings_path, f'img_cluster_labels_2d_{dataset_test_name}_{take}.npy')
)
print(f"Cluster labels loaded")
print(f"Cluster labels shape: {img_cluster_labels_2d.shape}")

In [None]:
if do_plot:
    # Test-set scatter coloured by the t-SNE-space K-means labels
    interactive_plot(
        img_embeddings_tsne[:, 0],
        img_embeddings_tsne[:, 1],
        colors=img_cluster_labels_2d,
        action='click',
        img_list=X_test_mask,
        emb_list=img_embeddings,
        true_img_list=X_test,
    )

### K-MEANS with embedding code (skip)

In [None]:
from sklearn.cluster import KMeans

if not do_skip:
    # K-means on the full test embedding vectors; fit() returns the fitted estimator
    kmeans = KMeans(n_clusters=n_clusters_img, random_state=42, n_init='auto').fit(img_embeddings)
    img_cluster_labels = kmeans.labels_
    print(f"Classes: {set(img_cluster_labels)}")
    np.save(os.path.join(save_embeddings_path, f'img_cluster_labels_{dataset_test_name}_{take}.npy'), img_cluster_labels)

### load

In [None]:
# Reload the K-means labels computed on the full embeddings
img_cluster_labels = np.load(
    os.path.join(save_embeddings_path, f'img_cluster_labels_{dataset_test_name}_{take}.npy')
)
print(f"Cluster labels loaded")
print(f"Cluster labels shape: {img_cluster_labels.shape}")

In [None]:
if do_plot:
    # Compare K-means labels from the full embeddings with those from the 2-D t-SNE projection
    plot_2_clusters(img_embeddings_tsne,
                    img_cluster_labels,
                    img_cluster_labels_2d)

In [None]:
if do_plot:
    # Test-set scatter coloured by the full-embedding K-means labels
    interactive_plot(
        img_embeddings_tsne[:, 0],
        img_embeddings_tsne[:, 1],
        colors=img_cluster_labels,
        action='click',
        img_list=X_test_mask,
        emb_list=img_embeddings,
        true_img_list=X_test,
    )

In [None]:
if do_plot:
    # NOTE(review): identical to the previous cell (a second copy of the same interactive view)
    interactive_plot(
        img_embeddings_tsne[:, 0],
        img_embeddings_tsne[:, 1],
        colors=img_cluster_labels,
        action='click',
        img_list=X_test_mask,
        emb_list=img_embeddings,
        true_img_list=X_test,
    )

### Optional: save cluster images

In [None]:
import numpy as np
from tqdm import tqdm

if do_optional_plots:
    pick_colors = 'IMG_2D'  # 'IMG_2D' or 'IMG_ND'
    add_title = f"{pick_colors}_{take}"
    # Alternatives: np.random.choice(n_clusters_img, 1, replace=False) or range(10)
    selected_classes = range(n_clusters_img)

    if pick_colors == 'IMG_2D':
        colors = img_cluster_labels_2d
    else:
        colors = img_cluster_labels

    do = "filter"  # "filter" or "plot"

    print(f"Interactive plot filtered over 1 class ({pick_colors} Embedding)")
    for selected_class in tqdm(selected_classes):
        print(f"Selected class: {selected_class}")
        f_img_list, f_emb_list, f_true_img_list, f_indices = interactive_plot_filtered(
            img_embeddings_tsne[:, 0],
            img_embeddings_tsne[:, 1],
            colors=colors,
            action='click',
            selected_class=selected_class,
            img_list=X_test_mask,
            emb_list=img_embeddings,
            true_img_list=X_test,
            do=do,
        )
        plot_n_patches_overlap(X_test_mask, X_test, indexes=f_indices, selected_class=selected_class,
                               add_title=add_title, m=40, alpha=0.5)
        # NOTE(review): the loop stops after the first class; remove `break` to process every cluster.
        break

In [None]:
if do_optional_plots:
    interactive_plot(
        img_embeddings_tsne[:, 0],
        img_embeddings_tsne[:, 1],
        colors=img_cluster_labels,
        action='click',
        img_list=X_test_mask,
        emb_list=img_embeddings,
        true_img_list=X_test,
    )

## Sampling (NO MORE FROM HERE ON)

In [None]:
# def filter_lists(*args, indices):
#     filtered_lists = []
#     for lst in args:
#         filtered_lst = [lst[i] for i in indices]
#         filtered_lists.append(filtered_lst)
#     return filtered_lists

In [None]:
def random_sampling(*embedding_lists, n_size=1000, random_seed=42):
    """
    Uniform random sampling without replacement, applied identically to every array.

    Parameters:
    - *embedding_lists: array-likes of equal length, indexed along axis 0.
    - n_size: number of rows to draw; a value < 1 is read as a fraction of the rows
      (``int(n_size * num_samples)``).
    - random_seed: passed to ``np.random.seed`` (global NumPy RNG) before drawing.

    Returns:
    - (random_lists, random_idx): the sampled arrays, in input order and with their
      original dtypes, and the ndarray of selected row indices.

    Raises:
    - AssertionError if the arrays do not all have the same length.
    - ValueError if more rows are requested than are available.
    """
    assert len(set(len(embedding_list) for embedding_list in embedding_lists)) == 1, "All embedding lists should have the same length."

    np.random.seed(random_seed)
    embedding_arrays = [np.asarray(embedding_list) for embedding_list in embedding_lists]
    num_samples = embedding_arrays[0].shape[0]

    if n_size<1: # percentage of num_samples
        print(f"WARNING: n_size parameter is <1 ({n_size})")
        print(f"Taking the {n_size*100}% of the num_samples ({num_samples}) ==> {int(n_size * num_samples)}")
        n_size = int(n_size * num_samples)

    if n_size > num_samples:
        raise ValueError(f"Cannot draw {n_size} samples without replacement from {num_samples}")

    random_idx = np.random.choice(num_samples, n_size, replace=False)
    # Index directly: a `.tolist()` -> `np.array` round trip would upcast every array to
    # int64/float64 (e.g. uint8 masks, float32 images) and multiply memory use.
    random_lists = [embedding_array[random_idx] for embedding_array in embedding_arrays]
    return random_lists, random_idx
    
import numpy as np

def random_class_sampling(*embedding_lists, class_array, n_size=1000, random_seed=42):
    """
    Stratified random sampling: choose rows class by class and apply the same
    selection to every array in ``embedding_lists``.

    Parameters:
    - *embedding_lists: array-likes of equal length, indexed along axis 0.
    - class_array: iterable of per-row class labels (same length as the arrays).
    - n_size: if < 1, keep ``int(class_size * n_size)`` rows of every class
      (without replacement). Otherwise keep ``n_size // n_classes`` rows per class,
      drawing with replacement only for classes smaller than that quota.
    - random_seed: passed to ``np.random.seed`` (global NumPy RNG) before drawing.

    Returns:
    - (random_lists, random_idx): the sampled arrays, in input order and with their
      original dtypes, and the list of selected row indices.
    """
    np.random.seed(random_seed)
    assert len(set(len(embedding_list) for embedding_list in embedding_lists)) == 1, "All embedding lists should have the same length."
    embedding_arrays = [np.asarray(embedding_list) for embedding_list in embedding_lists]

    # Row indices of each class, in order of first appearance. Dicts keep insertion
    # order, which fixes the order of the random draws below.
    class_indices = {}
    for i, class_label in enumerate(class_array):
        class_indices.setdefault(class_label, []).append(i)

    random_idx = []
    if n_size < 1:
        # Same PERCENTAGE for each cluster
        print(f"Warning: n_size<1 ({n_size})")
        print(f"Taking {n_size*100}% of samples for each cluster")
        for key, indices in class_indices.items():
            samples_per_class = int(len(indices) * n_size)
            print(f"Taking {samples_per_class} for cluster {key}")
            random_idx.extend(np.random.choice(indices, samples_per_class, replace=False))
    else:
        # Same NUMBER for each cluster; small classes are drawn with replacement
        samples_per_class = n_size // len(class_indices)
        print(f"Taking {samples_per_class} samples per class")
        for indices in class_indices.values():
            random_idx.extend(np.random.choice(indices, samples_per_class, replace=(len(indices) < samples_per_class)))

    # Index directly: a `.tolist()` -> `np.array` round trip would upcast every array to
    # int64/float64 and multiply memory use.
    random_lists = [embedding_array[random_idx] for embedding_array in embedding_arrays]
    return random_lists, random_idx

    
    # # Randomly select indices for each class
    # class_indices = []
    # for embedding_array in embedding_arrays:
    #     indices = np.arange(len(embedding_array))
    #     np.random.shuffle(indices)
    #     class_indices.append(indices[:n_samples_per_class])
    
    # # Concatenate indices from all classes
    # selected_indices = np.concatenate(class_indices)
    # np.random.shuffle(selected_indices)
    
    # # Select samples based on the concatenated indices
    # selected_samples = []
    # for embedding_array in embedding_arrays:
    #     selected_samples.append(embedding_array[selected_indices])
    
    # return selected_samples, selected_indices
    
    

In [None]:
import numpy as np
from sklearn.metrics import pairwise_distances

def farthest_point_sampling(X, num_points, random_seed):
    """
    Greedy farthest point sampling.

    Starts from one random point, then repeatedly adds the point whose Euclidean
    distance to its nearest already-selected point is largest.

    Parameters:
    - X: array-like, shape (n_samples, n_features); a 1-D input is treated as
        (n_samples, 1).
    - num_points: int
        Number of points to select. Values <= 0 select nothing; values larger
        than n_samples are capped at n_samples.
    - random_seed: int
        Passed to ``np.random.seed`` (global NumPy RNG) to choose the start point.

    Returns:
    - selected_indices: list of distinct int indices into X, in selection order.
    """
    X = np.asarray(X)
    if X.ndim == 1:
        X = X[:, None]
    n_samples = X.shape[0]

    # Never request more distinct points than exist, and honour a request for zero.
    num_points = min(int(num_points), n_samples)
    if num_points <= 0:
        return []

    # Choose a random point to start with
    np.random.seed(random_seed)
    initial_index = int(np.random.choice(n_samples))
    selected_indices = [initial_index]

    # Distance from every point to its closest selected point. Selected points are
    # marked -inf so argmax can never pick them again, even when the remaining points
    # coincide with selected ones (all remaining distances == 0).
    min_distances = np.linalg.norm(X - X[initial_index], axis=1)
    min_distances[initial_index] = -np.inf

    while len(selected_indices) < num_points:
        # Point farthest from the current selection
        farthest_index = int(np.argmax(min_distances))
        selected_indices.append(farthest_index)

        # np.minimum keeps -inf for already-selected points
        min_distances = np.minimum(min_distances, np.linalg.norm(X - X[farthest_index], axis=1))
        min_distances[farthest_index] = -np.inf

    return selected_indices

def sample_within_class(coordinates, classes, n_size, *arrays, random_seed, single_cluster=False):
    """
    Stratified farthest point sampling.

    For each distinct value in ``classes``, select rows of that class with
    ``farthest_point_sampling`` on ``coordinates``. Then index ``coordinates``,
    ``classes`` and every extra array in ``arrays`` with the combined selection.

    Parameters:
    - coordinates: ndarray used for the distance computation.
    - classes: per-row class labels.
    - n_size: if < 1, keep ``int(n_size * class_size)`` rows per class. Otherwise keep
      ``min(n_size // n_classes, class_size)`` rows per class.
    - *arrays: extra arrays indexed along axis 0 with the same selection.
    - random_seed: forwarded to ``farthest_point_sampling`` for every class.
    - single_cluster: if True, treat all rows as one class (the returned classes are
      then all zeros).

    Returns:
    - list ``[coordinates_sel, classes_sel, *arrays_sel]``.
    """
    if single_cluster:
        classes = np.zeros_like(classes)

    unique_classes = np.unique(classes)
    use_fraction = n_size < 1
    if use_fraction:
        print(f"Warning: n_size<1 ({n_size})")
        print(f"Taking {n_size*100}% of samples for each cluster")

    sampled_indices = []
    for c in unique_classes:
        members = np.where(classes == c)[0]
        if use_fraction:
            budget = int(n_size * len(members))
        else:
            budget = min(n_size // len(unique_classes), len(members))

        picked = farthest_point_sampling(coordinates[members], budget, random_seed=random_seed)
        print(f"Class {c}: {len(picked)} samples selected")
        sampled_indices.extend(members[picked])

    return [array[sampled_indices] for array in (coordinates, classes, *arrays)]


In [None]:
# # Generate some random data points
# np.random.seed(0)
# X = np.random.rand(100, 2)

# # Number of points to select
# num_points = 50

# # Perform farthest point sampling
# X_selected, selected_indices = farthest_point_sampling(X, num_points)
# X_randomaaa, random_aaa = random_sampling(X, n_size=num_points)
# X_randomaaa = X_randomaaa[0] 
# print(X_randomaaa.shape)
# #plot all X in grey and the selected indices in red
# import matplotlib.pyplot as plt

# plt.figure(figsize=(10,5))
# plt.subplot(1,2,1)
# plt.scatter(X[:, 0], X[:, 1], color='grey')
# plt.scatter(X_selected[:, 0], X_selected[:, 1], color='red')
# plt.subplot(1,2,2)
# plt.scatter(X[:, 0], X[:, 1], color='grey')
# plt.scatter(X_randomaaa[:, 0], X_randomaaa[:, 1], color='blue')
# plt.scatter


# plt.show()


# # Example usage:
# n = 1000  # Number of samples
# p = 1    # Number of classes
# m = 200   # Total number of samples to select
# coordinates = np.random.rand(n, 2)  # Sample coordinates
# classes = np.random.randint(0, p, size=(n, 1))  # Sample classes
# # Additional arrays
# arrays = [np.random.rand(n, np.random.randint(2, 10)) for _ in range(5)]  # Generating 5 random arrays

# sampled_arrays = sample_within_class(coordinates, classes, m, *arrays)
# coordinate_filtered, classes_filtered, *other_arrays = sampled_arrays

# import matplotlib.pyplot as plt

# # Plot coordinates in grey
# plt.scatter(coordinates[:, 0], coordinates[:, 1], color='grey')

# # Plot filtered coordinates with color of filtered classes
# for c in np.unique(classes_filtered):
#     indices = np.where(classes_filtered == c)[0]
#     plt.scatter(coordinate_filtered[indices, 0], coordinate_filtered[indices, 1], label=f'Class {c}')

# plt.legend()
# plt.show()



In [None]:
for shape_line in (
    f"Img list shape: {X_test.shape}",
    f"Mask list shape: {X_test_mask.shape}\n",
    f"Img Embeddings shape: {img_embeddings.shape}",
    f"Img Embeddings TSNE shape {img_embeddings_tsne.shape}",
    f"Class shape: {img_cluster_labels.shape}",
    f"Class 2D shape: {img_cluster_labels_2d.shape}",
):
    print(shape_line)

### START FROM THIS CELL!!!

In [None]:
################################
# Sampling budget. Either:
#   - an integer n >= 1: about n samples in total, split equally across clusters
#     (n // n_clusters per cluster), or
#   - a fraction 0 < p < 1: the same percentage of samples from every cluster.
################################
n_size = 75/100
do_plot = True
################################

n_selected_total = int(X_test.shape[0]*n_size) if n_size < 1 else n_size
print(f"In total {n_selected_total} samples will be selected")

In [None]:
#histogram of the classes
import matplotlib.pyplot as plt

if do_plot:
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax = ax.flatten()

    # One bin per K-means cluster, matching the histograms of the sampled subsets below;
    # bins='auto' does not guarantee one bar per cluster label.
    ax[0].hist(img_cluster_labels, bins=n_clusters_img)
    ax[0].set_title('(embedding)')
    ax[0].set_xlabel('Class')
    ax[0].set_ylabel('Frequency')

    ax[1].hist(img_cluster_labels_2d, bins=n_clusters_img)
    ax[1].set_title('(t-SNE embedding 2D)')
    ax[1].set_xlabel('Class')

    plt.show()

### Random Sampling (R)

In [None]:
selected_lists_r, selected_indices_r = random_sampling(
    X_test, X_test_mask, img_embeddings, img_embeddings_tsne, img_cluster_labels, img_cluster_labels_2d,
    n_size=n_size, random_seed=42,
)
(X_test_r, X_test_mask_r, img_embeddings_r,
 img_embeddings_tsne_r, img_cluster_labels_r, img_cluster_labels_2d_r) = selected_lists_r

for shape_desc, shape_array in [
    ("Random X Test shape", X_test_r),
    ("Random X Test Mask shape", X_test_mask_r),
    ("Random Img Embedding shape", img_embeddings_r),
    ("Random Img Embedding 2D shape", img_embeddings_tsne_r),
    ("Random Img Class shape", img_cluster_labels_r),
    ("Random Img Class 2D shape", img_cluster_labels_2d_r),
]:
    print(f"{shape_desc} {shape_array.shape}")

In [None]:
#histogram of the classes
import matplotlib.pyplot as plt

if do_plot:

    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax = ax.flatten()

    # One bin per K-means cluster, consistent with the other sampled-subset histograms;
    # bins='auto' does not guarantee one bar per cluster label.
    ax[0].hist(img_cluster_labels_r, bins=n_clusters_img)
    ax[0].set_title('(embedding)')
    ax[0].set_xlabel('Class')
    ax[0].set_ylabel('Frequency')

    ax[1].hist(img_cluster_labels_2d_r, bins=n_clusters_img)
    ax[1].set_title('(t-SNE embedding 2D)')
    ax[1].set_xlabel('Class')

    plt.show()

### Diversity sampling (C) (RANDOM WITHIN THE CLASS)

In [None]:
selected_lists_c, selected_indices_c = random_class_sampling(
    X_test, X_test_mask, img_embeddings, img_embeddings_tsne, img_cluster_labels, img_cluster_labels_2d,
    class_array=img_cluster_labels, n_size=n_size, random_seed=42,
)
(X_test_c, X_test_mask_c, img_embeddings_c,
 img_embeddings_tsne_c, img_cluster_labels_c, img_cluster_labels_2d_c) = selected_lists_c

for shape_desc, shape_array in [
    ("Class X Test shape", X_test_c),
    ("Class X Test Mask shape", X_test_mask_c),
    ("Class Img Embedding shape", img_embeddings_c),
    ("Class Img Embedding 2D shape", img_embeddings_tsne_c),
    ("Class Img Class shape", img_cluster_labels_c),
    ("Class Img Class 2D shape", img_cluster_labels_2d_c),
]:
    print(f"{shape_desc} {shape_array.shape}")

In [None]:
#histogram of the classes
import matplotlib.pyplot as plt

if do_plot:
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax = ax.flatten()

    # Left: full-embedding cluster labels; right: t-SNE cluster labels (one bin per cluster)
    panel_specs = [
        (img_cluster_labels_c, '(embedding)'),
        (img_cluster_labels_2d_c, '(t-SNE embedding 2D)'),
    ]
    for panel_ax, (panel_labels, panel_title) in zip(ax, panel_specs):
        panel_ax.hist(panel_labels, bins=n_clusters_img)
        panel_ax.set_title(panel_title)
        panel_ax.set_xlabel('Class')
    ax[0].set_ylabel('Frequency')

    plt.show()

### Diversity Sampling (Farthest) (CF) ==> FPS with latent-vector classes and t-SNE distance computation

In [None]:
# Classes from the full-embedding K-means; FPS distances computed on the t-SNE coordinates
sampled_arrays_cf = sample_within_class(
    img_embeddings_tsne, img_cluster_labels, n_size,
    X_test_mask, X_test, img_embeddings, img_cluster_labels_2d,
    random_seed=42,
)
(img_embeddings_tsne_cf, img_cluster_labels_cf, X_test_mask_cf,
 X_test_cf, img_embeddings_cf, img_cluster_labels_2d_cf) = sampled_arrays_cf

for shape_desc, shape_array in [
    ("Class X Test shape", X_test_cf),
    ("Class X Test Mask shape", X_test_mask_cf),
    ("Class Img Embedding shape", img_embeddings_cf),
    ("Class Img Embedding 2D shape", img_embeddings_tsne_cf),
    ("Class Img Class shape", img_cluster_labels_cf),
    ("Class Img Class 2D shape", img_cluster_labels_2d_cf),
]:
    print(f"{shape_desc} {shape_array.shape}")

In [None]:
#histogram of the classes
import matplotlib.pyplot as plt


# Guarded by do_plot like the other histogram cells
if do_plot:
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax = ax.flatten()

    ax[0].hist(img_cluster_labels_cf, bins=n_clusters_img)
    ax[0].set_title('(embedding)')
    ax[0].set_xlabel('Class')
    ax[0].set_ylabel('Frequency')

    ax[1].hist(img_cluster_labels_2d_cf, bins=n_clusters_img)
    ax[1].set_title('(t-SNE embedding 2D)')
    ax[1].set_xlabel('Class')

    plt.show()

In [None]:
# Writing the sampled subsets is off by default (the np.save calls used to be commented
# out while the cell still printed "Saving ..."). Set to True to write them to
# save_embeddings_path as <array>_<c|cf|r>_<n_size>_<take>.npy.
save_sampled_datasets = False

sampled_array_names = ['X_test', 'X_test_mask', 'img_embeddings', 'img_embeddings_tsne',
                       'img_cluster_labels', 'img_cluster_labels_2d']
sampled_datasets = [
    ("Class-sampled", 'c', [X_test_c, X_test_mask_c, img_embeddings_c, img_embeddings_tsne_c,
                            img_cluster_labels_c, img_cluster_labels_2d_c]),
    ("Class-sampled (F)", 'cf', [X_test_cf, X_test_mask_cf, img_embeddings_cf, img_embeddings_tsne_cf,
                                 img_cluster_labels_cf, img_cluster_labels_2d_cf]),
    ("Random-sampled", 'r', [X_test_r, X_test_mask_r, img_embeddings_r, img_embeddings_tsne_r,
                             img_cluster_labels_r, img_cluster_labels_2d_r]),
]

for description, suffix, arrays in sampled_datasets:
    if not save_sampled_datasets:
        print(f"Not saving {description} dataset (save_sampled_datasets=False)")
        continue
    print(f"Saving {description} dataset")
    for array_name, array in zip(sampled_array_names, arrays):
        np.save(os.path.join(save_embeddings_path, f'{array_name}_{suffix}_{n_size}_{take}.npy'), array)

### (CFX) ==> Stratified FPS: both the classes and the FPS distances are computed from the latent vectors

In [None]:
# CFX subset: per the section header, stratified FPS where the latent vectors
# (img_embeddings) are passed as the first (distance) argument. Outputs come
# back in the same order as the arrays passed in.
sampled_arrays_cfx = sample_within_class(
    img_embeddings,
    img_cluster_labels,
    n_size,
    X_test_mask,
    X_test,
    img_embeddings_tsne,
    img_cluster_labels_2d,
    random_seed=42,
)
(
    img_embeddings_cfx,
    img_cluster_labels_cfx,
    X_test_mask_cfx,
    X_test_cfx,
    img_embeddings_tsne_cfx,
    img_cluster_labels_2d_cfx,
) = sampled_arrays_cfx

print(
    f"Class X Test shape {X_test_cfx.shape}",
    f"Class X Test Mask shape {X_test_mask_cfx.shape}",
    f"Class Img Embedding shape {img_embeddings_cfx.shape}",
    f"Class Img Embedding 2D shape {img_embeddings_tsne_cfx.shape}",
    f"Class Img Class shape {img_cluster_labels_cfx.shape}",
    f"Class Img Class 2D shape {img_cluster_labels_2d_cfx.shape}",
    sep="\n",
)

In [None]:
# Persist the CFX subset (stratified FPS, distances on latent vectors).
# NOTE(review): the np.save calls are commented out, so this cell only prints.
# The later load cell np.load()s the *_cfx_{n_size}_{take}.npy files and will
# fail unless they already exist in save_embeddings_path.
print("Saving Class-sampled (CFX) dataset")
#np.save(os.path.join(save_embeddings_path,f'X_test_cfx_{n_size}_{take}.npy'), X_test_cfx)
#np.save(os.path.join(save_embeddings_path,f'X_test_mask_cfx_{n_size}_{take}.npy'), X_test_mask_cfx)
#np.save(os.path.join(save_embeddings_path,f'img_embeddings_cfx_{n_size}_{take}.npy'), img_embeddings_cfx)
#np.save(os.path.join(save_embeddings_path,f'img_embeddings_tsne_cfx_{n_size}_{take}.npy'), img_embeddings_tsne_cfx)
#np.save(os.path.join(save_embeddings_path,f'img_cluster_labels_cfx_{n_size}_{take}.npy'), img_cluster_labels_cfx)
#np.save(os.path.join(save_embeddings_path,f'img_cluster_labels_2d_cfx_{n_size}_{take}.npy'), img_cluster_labels_2d_cfx)

### FPS WHOLE (CFW) ==> FPS on the whole dataset (no class stratification), distances computed on the 2D t-SNE points

In [None]:
# CFW subset: per the section header, FPS over the whole dataset
# (single_cluster=True) with the t-SNE points as the first (distance) argument.
# img_cluster_labels is also passed as an extra array so its sampled copy comes
# back as the last output; the second output is discarded.
sampled_arrays_cfw = sample_within_class(
    img_embeddings_tsne,
    img_cluster_labels,
    n_size,
    X_test_mask,
    X_test,
    img_embeddings,
    img_cluster_labels_2d,
    img_cluster_labels,
    random_seed=42,
    single_cluster=True,
)
(
    img_embeddings_tsne_cfw,
    _,
    X_test_mask_cfw,
    X_test_cfw,
    img_embeddings_cfw,
    img_cluster_labels_2d_cfw,
    img_cluster_labels_cfw,
) = sampled_arrays_cfw

print(
    f"Class X Test shape {X_test_cfw.shape}",
    f"Class X Test Mask shape {X_test_mask_cfw.shape}",
    f"Class Img Embedding shape {img_embeddings_cfw.shape}",
    f"Class Img Embedding 2D shape {img_embeddings_tsne_cfw.shape}",
    f"Class Img Class shape {img_cluster_labels_cfw.shape}",
    f"Class Img Class 2D shape {img_cluster_labels_2d_cfw.shape}",
    sep="\n",
)

In [None]:
# Persist the CFW subset (whole-dataset FPS, distances on 2D t-SNE points).
# NOTE(review): the np.save calls are commented out, so this cell only prints.
# The later load cell np.load()s the *_cfw_{n_size}_{take}.npy files and will
# fail unless they already exist in save_embeddings_path.
print("Saving Class-sampled (FW) dataset")
#np.save(os.path.join(save_embeddings_path,f'X_test_cfw_{n_size}_{take}.npy'), X_test_cfw)
#np.save(os.path.join(save_embeddings_path,f'X_test_mask_cfw_{n_size}_{take}.npy'), X_test_mask_cfw)
#np.save(os.path.join(save_embeddings_path,f'img_embeddings_cfw_{n_size}_{take}.npy'), img_embeddings_cfw)
#np.save(os.path.join(save_embeddings_path,f'img_embeddings_tsne_cfw_{n_size}_{take}.npy'), img_embeddings_tsne_cfw)
#np.save(os.path.join(save_embeddings_path,f'img_cluster_labels_cfw_{n_size}_{take}.npy'), img_cluster_labels_cfw)
#np.save(os.path.join(save_embeddings_path,f'img_cluster_labels_2d_cfw_{n_size}_{take}.npy'), img_cluster_labels_2d_cfw)

### FPS WHOLE (CFWX) ==> FPS on the whole dataset (no class stratification), distances computed on the latent vectors

In [None]:
# CFWX subset: per the section header, FPS over the whole dataset
# (single_cluster=True) with the latent vectors as the first (distance)
# argument. img_cluster_labels is also passed as an extra array so its sampled
# copy comes back last; the second output is discarded.
sampled_arrays_cfwx = sample_within_class(
    img_embeddings,
    img_cluster_labels,
    n_size,
    X_test_mask,
    X_test,
    img_embeddings_tsne,
    img_cluster_labels_2d,
    img_cluster_labels,
    random_seed=42,
    single_cluster=True,
)
(
    img_embeddings_cfwx,
    _,
    X_test_mask_cfwx,
    X_test_cfwx,
    img_embeddings_tsne_cfwx,
    img_cluster_labels_2d_cfwx,
    img_cluster_labels_cfwx,
) = sampled_arrays_cfwx

print(
    f"Class X Test shape {X_test_cfwx.shape}",
    f"Class X Test Mask shape {X_test_mask_cfwx.shape}",
    f"Class Img Embedding shape {img_embeddings_cfwx.shape}",
    f"Class Img Embedding 2D shape {img_embeddings_tsne_cfwx.shape}",
    f"Class Img Class shape {img_cluster_labels_cfwx.shape}",
    f"Class Img Class 2D shape {img_cluster_labels_2d_cfwx.shape}",
    sep="\n",
)

In [None]:
# Persist the CFWX subset (whole-dataset FPS, distances on latent vectors).
# NOTE(review): the np.save calls are commented out, so this cell only prints.
# The matching np.load lines in the later load cell are commented out as well,
# so the CFWX plot uses the in-memory result of the sampling cell above.
print("Saving Class-sampled (FWX) dataset")
#np.save(os.path.join(save_embeddings_path,f'X_test_cfwx_{n_size}_{take}.npy'), X_test_cfwx)
#np.save(os.path.join(save_embeddings_path,f'X_test_mask_cfwx_{n_size}_{take}.npy'), X_test_mask_cfwx)
#np.save(os.path.join(save_embeddings_path,f'img_embeddings_cfwx_{n_size}_{take}.npy'), img_embeddings_cfwx)
#np.save(os.path.join(save_embeddings_path,f'img_embeddings_tsne_cfwx_{n_size}_{take}.npy'), img_embeddings_tsne_cfwx)
#np.save(os.path.join(save_embeddings_path,f'img_cluster_labels_cfwx_{n_size}_{take}.npy'), img_cluster_labels_cfwx)
#np.save(os.path.join(save_embeddings_path,f'img_cluster_labels_2d_cfwx_{n_size}_{take}.npy'), img_cluster_labels_2d_cfwx)

In [None]:
#histogram of the classes
import matplotlib.pyplot as plt

# Side-by-side class histograms for the CFW subset.
# Guarded by do_plot, consistent with the C histogram cell, so that setting
# do_plot=False in the config cell suppresses this figure too.
if do_plot:
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    ax = ax.flatten()

    # Left panel: img_cluster_labels_cfw.
    ax[0].hist(img_cluster_labels_cfw, bins=n_clusters_img)
    ax[0].set_title('(embedding)')
    ax[0].set_xlabel('Class')
    ax[0].set_ylabel('Frequency')

    # Right panel: img_cluster_labels_2d_cfw.
    ax[1].hist(img_cluster_labels_2d_cfw, bins=n_clusters_img)
    ax[1].set_title('(t-SNE embedding 2D)')
    ax[1].set_xlabel('Class')

    plt.show()

### Plot all vs random vs class sampling

In [None]:
# Reload the saved sampling subsets from save_embeddings_path. This overwrites
# any in-memory versions of the same variables.
# NOTE(review): the np.save calls in the "Saving ..." cells are commented out,
# so this cell raises FileNotFoundError unless the .npy files were written by
# an earlier run.


def _load_sampled_split(suffix):
    """Load the six arrays saved for one sampling strategy.

    Each array is read from
    ``os.path.join(save_embeddings_path, f'{name}_{suffix}_{n_size}_{take}.npy')``
    using the current global values of ``save_embeddings_path``, ``n_size``
    and ``take``.

    Args:
        suffix: strategy tag used in the file names, e.g. 'c', 'cf', 'r',
            'cfw', 'cfx', 'cfwx'.

    Returns:
        Tuple ``(X_test, X_test_mask, img_embeddings, img_embeddings_tsne,
        img_cluster_labels, img_cluster_labels_2d)`` for that strategy.

    Raises:
        FileNotFoundError: if any of the six files is missing.
    """
    names = (
        'X_test',
        'X_test_mask',
        'img_embeddings',
        'img_embeddings_tsne',
        'img_cluster_labels',
        'img_cluster_labels_2d',
    )
    return tuple(
        np.load(os.path.join(save_embeddings_path, f'{name}_{suffix}_{n_size}_{take}.npy'))
        for name in names
    )


# STRATIFIED RANDOM CLASS (C)
(X_test_c, X_test_mask_c, img_embeddings_c, img_embeddings_tsne_c,
 img_cluster_labels_c, img_cluster_labels_2d_c) = _load_sampled_split('c')

# STRATIFIED FPS CLASS (CF)
(X_test_cf, X_test_mask_cf, img_embeddings_cf, img_embeddings_tsne_cf,
 img_cluster_labels_cf, img_cluster_labels_2d_cf) = _load_sampled_split('cf')

# RANDOM TOTAL (R)
(X_test_r, X_test_mask_r, img_embeddings_r, img_embeddings_tsne_r,
 img_cluster_labels_r, img_cluster_labels_2d_r) = _load_sampled_split('r')

# FPS TOTAL (CFW)
(X_test_cfw, X_test_mask_cfw, img_embeddings_cfw, img_embeddings_tsne_cfw,
 img_cluster_labels_cfw, img_cluster_labels_2d_cfw) = _load_sampled_split('cfw')

# --------- # In CF and CFW above, the classes are computed through k-means on the latent vectors, but the FPS distance is computed on the 2D t-SNE points # --------- #

# FPS TOTAL (distance over latent vector) (CFWX)
# Not reloaded: the CFWX plot uses the in-memory result of its sampling cell.
# (X_test_cfwx, X_test_mask_cfwx, img_embeddings_cfwx, img_embeddings_tsne_cfwx,
#  img_cluster_labels_cfwx, img_cluster_labels_2d_cfwx) = _load_sampled_split('cfwx')

# (CFX) ==> Stratified FPS with both classes and distance computation from latent vectors
(X_test_cfx, X_test_mask_cfx, img_embeddings_cfx, img_embeddings_tsne_cfx,
 img_cluster_labels_cfx, img_cluster_labels_2d_cfx) = _load_sampled_split('cfx')


In [None]:
# R     ==> RANDOM  +   NO CLASS    (No class)          (RANDOM)
# C     ==> RANDOM  +   STRATIFIED  (class from latent) (RANDOM)
# CF    ==> FPS     +   STRATIFIED  (class from latent) (FPS Distance from 2D)
# CFW   ==> FPS     +   NO CLASS    (No class)          (FPS Distance from 2D)
# CFWX  ==> FPS     +   NO CLASS    (No class)          (FPS Distance from latent)
# CFX   ==> FPS     +   STRATIFIED  (class from latent) (FPS Distance from latent)

In [None]:
# Interactive t-SNE scatter of the full test set, coloured by cluster label.
print("Plot complete test dataset", "COMPLETE", sep="\n")
interactive_plot(
    img_embeddings_tsne[:, 0],
    img_embeddings_tsne[:, 1],
    colors=img_cluster_labels,
    action='click',
    img_list=X_test_mask,
    emb_list=img_embeddings,
    true_img_list=X_test,
)

In [None]:
# Interactive t-SNE scatter of the random (R) subset.
print("Plot random sampled test dataset", "R", sep="\n")
interactive_plot(
    img_embeddings_tsne_r[:, 0],
    img_embeddings_tsne_r[:, 1],
    colors=img_cluster_labels_r,
    action='click',
    img_list=X_test_mask_r,
    emb_list=img_embeddings_r,
    true_img_list=X_test_r,
)

In [None]:
# Interactive t-SNE scatter of the stratified random (C) subset.
print("Plot class sampled test dataset", "C", sep="\n")
interactive_plot(
    img_embeddings_tsne_c[:, 0],
    img_embeddings_tsne_c[:, 1],
    colors=img_cluster_labels_c,
    action='click',
    img_list=X_test_mask_c,
    emb_list=img_embeddings_c,
    true_img_list=X_test_c,
)

In [None]:
# Interactive t-SNE scatter of the stratified FPS (CF) subset.
print("Plot class (furthest) sampled test dataset", "CF", sep="\n")
interactive_plot(
    img_embeddings_tsne_cf[:, 0],
    img_embeddings_tsne_cf[:, 1],
    colors=img_cluster_labels_cf,
    action='click',
    img_list=X_test_mask_cf,
    emb_list=img_embeddings_cf,
    true_img_list=X_test_cf,
)

In [None]:
# Interactive t-SNE scatter of the whole-dataset FPS (CFW) subset.
print("Plot whole (furthest) sampled test dataset", "CFW", sep="\n")
interactive_plot(
    img_embeddings_tsne_cfw[:, 0],
    img_embeddings_tsne_cfw[:, 1],
    colors=img_cluster_labels_cfw,
    action='click',
    img_list=X_test_mask_cfw,
    emb_list=img_embeddings_cfw,
    true_img_list=X_test_cfw,
)

In [None]:
# Interactive t-SNE scatter of the whole-dataset FPS subset with latent-vector
# distances (CFWX).
print("Plot whole (furthest) sampled test dataset", "CFWX", sep="\n")
interactive_plot(
    img_embeddings_tsne_cfwx[:, 0],
    img_embeddings_tsne_cfwx[:, 1],
    colors=img_cluster_labels_cfwx,
    action='click',
    img_list=X_test_mask_cfwx,
    emb_list=img_embeddings_cfwx,
    true_img_list=X_test_cfwx,
)

In [None]:
# This cell was an exact copy of the CFWX plot cell above: same two prints and
# the same interactive_plot call on the *_cfwx arrays. The repeat call was
# removed so the CFWX figure is rendered only once per Run All. Every sampling
# strategy (COMPLETE, R, C, CF, CFW, CFWX, CFX) still has its own plot cell.

In [None]:
# Interactive t-SNE scatter of the CFX subset (class-stratified FPS with
# latent-vector distances). The title previously read "whole (furthest)",
# which describes the unstratified CFW/CFWX subsets, not this one.
print("Plot class (furthest, latent distance) sampled test dataset")
print("CFX")
interactive_plot(
    img_embeddings_tsne_cfx[:, 0],
    img_embeddings_tsne_cfx[:, 1],
    colors=img_cluster_labels_cfx,
    action='click',
    img_list=X_test_mask_cfx,
    emb_list=img_embeddings_cfx,
    true_img_list=X_test_cfx,
)

In [None]:
# Final status line for this n_size run.
# NOTE(review): the np.save calls in the "Saving ..." cells of this section are
# commented out, so those cells do not write anything to disk.
print(f"Data saved for {n_size} samples")