## 0 - Import libraries

In [1]:
# Global switch: cells guarded by `if not just_sampling:` are skipped when this is True.
just_sampling = False

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

# Turn on the cuDNN benchmark flag for the whole session
# (presumably to let cuDNN auto-tune convolution algorithms — see the PyTorch backend docs).
torch.backends.cudnn.benchmark = True

from tqdm import tqdm

from utils import set_random_seed, mk_dir
from importlib import import_module

from torchvision.utils import save_image

In [3]:
def extract_img_vessels_np2(dataset_img_dir):
    """Load every ``*32_img.npy`` file of a directory together with its label file.

    For each directory entry whose name contains ``'32_img.npy'`` the array is
    loaded, and the matching label array is loaded from the file whose name is
    obtained by replacing ``'img'`` with ``'label'`` in the file name.

    Parameters:
    dataset_img_dir : str
        Directory that contains the ``.npy`` files.

    Returns:
    tuple of np.ndarray
        ``(images, labels)``, each the concatenation along axis 0 of all the
        loaded arrays, in ``os.listdir`` order.

    Raises:
    FileNotFoundError
        If no entry of the directory matches ``'32_img.npy'``.
    """
    img_chunks = []
    label_chunks = []
    for img_name in tqdm(os.listdir(dataset_img_dir)):
        if '32_img.npy' not in img_name:
            continue
        # Build the label name from the file name only, so an 'img' substring
        # in the directory path is never rewritten.
        label_name = img_name.replace('img', 'label')
        img_chunks.append(np.load(os.path.join(dataset_img_dir, img_name)))
        label_chunks.append(np.load(os.path.join(dataset_img_dir, label_name)))

    if not img_chunks:
        raise FileNotFoundError(f"No '*32_img.npy' file found in {dataset_img_dir}")

    # A single concatenation at the end avoids re-copying the growing array
    # for every file, and does not depend on which entry is listed first.
    return np.concatenate(img_chunks, axis=0), np.concatenate(label_chunks, axis=0)

In [4]:
import matplotlib.pyplot as plt
# Switch matplotlib to interactive mode for the plotting helpers defined below.
plt.ion() 

def save_img(tensor, name, norm, n_rows=16, scale_each=False):
    """Save ``tensor`` to ``name`` as an image grid through ``save_image``.

    The grid uses ``n_rows`` images per row (passed as ``nrow``), a padding of
    5 with pad value 1, and forwards ``norm`` / ``scale_each`` as the
    ``normalize`` / ``scale_each`` options.
    """
    grid_options = dict(
        nrow=n_rows,
        padding=5,
        normalize=norm,
        pad_value=1,
        scale_each=scale_each,
    )
    save_image(tensor, name, **grid_options)
    
def plot_10_patches(img_np, true_vessel_np=False, indexes=None):
    """Plot a two-column figure with one row per selected slice.

    The left panel of each row shows ``img_np[index]``. The right panel shows
    ``true_vessel_np[index]`` when that array is given, otherwise the next
    slice ``img_np[index + 1]``.

    Parameters:
    img_np : np.ndarray
        Array indexed as ``img_np[index, :, :]``.
    true_vessel_np : np.ndarray or False, default=False
        Optional second array indexed the same way; ``False`` means "not given".
    indexes : sequence of int, optional
        Slices to plot. When ``None``, 5 random indices are drawn with
        ``np.random.randint``.
    """
    random_indexes = np.random.randint(0, img_np.shape[0], size=5) if indexes is None else indexes

    n = len(random_indexes)

    # squeeze=False keeps `ax` two-dimensional even when a single row is
    # requested, so the `ax[i, col]` indexing below always works.
    fig, ax = plt.subplots(n, 2, figsize=(6, 3*n), squeeze=False)

    for i, index in enumerate(random_indexes):

        # Index adjustment kept exactly as in the original code.
        # NOTE(review): an index equal to shape[0] ends up at shape[0]-3 because
        # both conditions fire in sequence — confirm this is intended.
        if index == img_np.shape[0]:
            index -= 1
        if index == img_np.shape[0]-1:
            index -= 2

        #print(f"Shape: {img_np[index,:,:].shape}, Max: {img_np[index,:,:].max()}, Min: {img_np[index,:,:].min()}")
        ax[i, 0].set_title(f'True Vessel Slice {index}')
        ax[i, 0].imshow(img_np[index,:,:], cmap='gray', vmin=0, vmax=1)
        ax[i, 0].axis('off')

        if true_vessel_np is not False:
            print(f"Shape: {true_vessel_np[index,:,:].shape}, Max: {true_vessel_np[index,:,:].max()}, Min: {true_vessel_np[index,:,:].min()}")
            ax[i, 1].imshow(true_vessel_np[index,:,:], cmap='gray')
            ax[i, 1].set_title(f'Image Slice {index}')
        else:
            ax[i, 1].set_title(f'Image Slice {index+1}')
            ax[i, 1].imshow(img_np[index+1,:,:], cmap='gray', vmin=0, vmax=1)

        ax[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

def plot_n_patches_overlap(img_np, true_vessel_np=False, indexes=None, selected_class=None, add_title='', m=5, alpha=0.5, save_dir='clusters_imgs'):
    """Plot slices of ``true_vessel_np`` with a red overlay where ``img_np`` is non-zero.

    Parameters:
    img_np : np.ndarray
        Array indexed as ``img_np[index, :, :]``; pixels equal to 0 are masked
        out of the overlay.
    true_vessel_np : np.ndarray
        Array indexed the same way; drawn in gray as background and, masked,
        in red on top. Required despite the ``False`` default.
    indexes : sequence of int, optional
        Slices to plot. When ``None``, 5 random indices are drawn.
    selected_class : any, optional
        Only used to build the output file name.
    add_title : str, default=''
        When non-empty the figure is saved as
        ``{add_title}_class_{selected_class}_patches.png`` in the output
        directory and closed; otherwise it is shown.
    m : int, default=5
        Number of images per row.
    alpha : float, default=0.5
        Opacity of the red overlay.
    save_dir : str, default='clusters_imgs'
        Sub-directory of the global ``save_embeddings_path`` used for output.

    Raises:
    ValueError
        If ``true_vessel_np`` is not provided.
    """
    if true_vessel_np is False:
        raise ValueError("true_vessel_np is required to draw the overlay.")
    if indexes is None:
        # Same fallback as plot_10_patches: 5 random slices.
        indexes = np.random.randint(0, img_np.shape[0], size=5)

    save_dir = os.path.join(save_embeddings_path,save_dir)
    mk_dir(save_dir)

    n = len(indexes)
    # m is the number of images to plot in each row (one column per image)
    n_rows = int(np.ceil(n/m))
    n_cols = m

    plt.ioff()
    try:
        # squeeze=False guarantees an array of axes even for a 1x1 grid,
        # so flatten() is always available.
        fig, ax = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 4*n_rows), squeeze=False)
        ax = ax.flatten()

        for i, index in enumerate(indexes):
            ax[i].set_title(f'Overlay for Slice {index}')
            masked_image = np.ma.masked_where(img_np[index,:,:] == 0, true_vessel_np[index,:,:])
            # Overlay the red image on top of true_vessel_np[index,:,:]
            ax[i].imshow(true_vessel_np[index,:,:], cmap='gray', interpolation='none')
            ax[i].imshow(masked_image, cmap='Reds', alpha=alpha)

            ax[i].axis('off')

        try:
            plt.tight_layout()
        except Exception:
            # Layout tightening is best-effort: a failure here must not
            # prevent the figure from being saved or shown.
            pass

        if add_title!='':
            print(f"Save Fig to {save_dir}")
            plt.savefig(os.path.join(save_dir,f'{add_title}_class_{selected_class}_patches.png'))
            plt.close(fig)

        else:
            print("Plotting..")
            plt.show()
    finally:
        # Re-enable interactive mode even if plotting failed.
        plt.ion()

    
def plot_n_patches(img_np, true_vessel_np=False, indexes=None, selected_class=None, add_title='', m=5):
    """Plot pairs of panels for the selected slices, ``m`` pairs per row.

    For every index the first panel shows ``img_np[index]`` and the second
    shows ``true_vessel_np[index]`` when that array is given, otherwise the
    next slice ``img_np[index + 1]``.

    Parameters:
    img_np : np.ndarray
        Array indexed as ``img_np[index, :, :]``.
    true_vessel_np : np.ndarray or False, default=False
        Optional second array indexed the same way; ``False`` means "not given".
    indexes : sequence of int, optional
        Slices to plot. When ``None``, 5 random indices are drawn.
    selected_class : any, optional
        Only used to build the output file name.
    add_title : str, default=''
        When non-empty the figure is saved as
        ``{add_title}_class_{selected_class}_patches.png`` and closed;
        otherwise it is shown.
    m : int, default=5
        Number of image pairs per row (the grid has ``2*m`` columns).
    """
    if indexes is None:
        # Same fallback as plot_10_patches: 5 random slices.
        indexes = np.random.randint(0, img_np.shape[0], size=5)

    n = len(indexes)
    if n > 1000:
        print(f'WARNING: YOU ARE TRYING TO PLOT {n} images')
    # m is the number of images to plot in each row (2m is the number of columns)
    n_rows = int(np.ceil(n/m))
    n_cols = 2*m

    plt.ioff()
    try:
        fig, ax = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows), squeeze=False)
        ax = ax.flatten()

        for i, index in enumerate(indexes):

            # Index adjustment kept exactly as in the original code.
            # NOTE(review): an index equal to shape[0] ends up at shape[0]-3
            # because both conditions fire in sequence — confirm this is intended.
            if index == img_np.shape[0]:
                index -= 1
            if index == img_np.shape[0]-1:
                index -= 2

            ax[2*i].set_title(f'Image Slice {index}')
            ax[2*i].imshow(img_np[index,:,:], cmap='gray', vmin=0, vmax=1)
            ax[2*i].axis('off')
            print(f"Shape: {img_np[index,:,:].shape}, Max: {img_np[index,:,:].max()}, Min: {img_np[index,:,:].min()}")
            if true_vessel_np is not False:
                ax[2*i + 1].imshow(true_vessel_np[index,:,:], cmap='gray')
                ax[2*i + 1].set_title(f'True Vessel Slice {index}')

            else:
                ax[2*i + 1].set_title(f'Image Slice {index+1}')
                ax[2*i + 1].imshow(img_np[index+1,:,:], cmap='gray', vmin=0, vmax=1)

            ax[2*i+1].axis('off')

        plt.tight_layout()
        if add_title!='':
            print("Save Fig")
            plt.savefig(f'{add_title}_class_{selected_class}_patches.png')
            plt.close(fig)

        else:
            print("Plotting..")
            plt.show()
    finally:
        # Re-enable interactive mode even if plotting failed.
        plt.ion()
    
def reshape_to_square(vector):
    """Zero-pad ``vector`` up to the next perfect-square length and fold it into a square matrix.

    The side of the square is ``ceil(sqrt(len(vector)))``; the missing
    trailing entries are filled with zeros.
    """
    length = len(vector)
    side = int(np.ceil(np.sqrt(length)))

    # Number of trailing zeros needed to reach side * side elements.
    missing = side * side - length
    padded = np.pad(vector, (0, missing), mode='constant')

    return padded.reshape((side, side))

In [5]:
def extract_img_vessels_np2(dataset_img_dir):
    """Load every ``*32_img.npy`` file of a directory together with its label file.

    For each directory entry whose name contains ``'32_img.npy'`` the array is
    loaded, and the matching label array is loaded from the file whose name is
    obtained by replacing ``'img'`` with ``'label'`` in the file name.

    Parameters:
    dataset_img_dir : str
        Directory that contains the ``.npy`` files.

    Returns:
    tuple of np.ndarray
        ``(images, labels)``, each the concatenation along axis 0 of all the
        loaded arrays, in ``os.listdir`` order.

    Raises:
    FileNotFoundError
        If no entry of the directory matches ``'32_img.npy'``.
    """
    img_chunks = []
    label_chunks = []
    for img_name in tqdm(os.listdir(dataset_img_dir)):
        if '32_img.npy' not in img_name:
            continue
        # Sanity check kept from the original: an image file must not itself
        # look like a label file.
        assert 'label' not in img_name
        # Build the label name from the file name only, so an 'img' substring
        # in the directory path is never rewritten.
        label_name = img_name.replace('img', 'label')
        img_chunks.append(np.load(os.path.join(dataset_img_dir, img_name)))
        label_chunks.append(np.load(os.path.join(dataset_img_dir, label_name)))

    if not img_chunks:
        raise FileNotFoundError(f"No '*32_img.npy' file found in {dataset_img_dir}")

    # A single concatenation at the end avoids re-copying the growing array
    # for every file, and does not depend on which entry is listed first.
    return np.concatenate(img_chunks, axis=0), np.concatenate(label_chunks, axis=0)

def extract_empty_img_vessels(dataset_empty_img_dir):
    """Load and stack every ``*32_img.npy`` file found in a directory.

    Parameters:
    dataset_empty_img_dir : str
        Directory that contains the ``.npy`` files.

    Returns:
    np.ndarray
        Concatenation along axis 0 of all the loaded arrays, in
        ``os.listdir`` order.

    Raises:
    FileNotFoundError
        If no entry of the directory matches ``'32_img.npy'``.
    """
    img_chunks = [
        np.load(os.path.join(dataset_empty_img_dir, img_name))
        for img_name in tqdm(os.listdir(dataset_empty_img_dir))
        if '32_img.npy' in img_name
    ]

    if not img_chunks:
        raise FileNotFoundError(f"No '*32_img.npy' file found in {dataset_empty_img_dir}")

    # Concatenating once does not depend on which directory entry comes first
    # and avoids re-copying the accumulated array for each file.
    return np.concatenate(img_chunks, axis=0)


In [6]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import KDTree

def label_point(x, y, ids, ax):
    """Write each entry of ``ids`` next to its point ``(x[k], y[k])`` on ``ax``."""
    for position, text in enumerate(ids):
        anchor = (x[position], y[position])
        ax.annotate(text, anchor)

def interactive_plot(x, y, ids=None, colors=None, action='click', img_list=None, emb_list=None, true_img_list=None ,zoom=False, filtered_ids=None):
    """Scatter-plot points and show the data behind the point nearest to the mouse.

    The first panel holds the scatter plot. When the mouse event lands close
    enough to a point, the next three panels show ``img_list[id]``, the
    embedding ``emb_list[id]`` reshaped to a square, and ``true_img_list[id]``.

    Parameters:
    x, y : array-like
        Point coordinates.
    ids : sequence, optional
        Identifier of every point; ``int(ids[i])`` is used to index the three
        lists. Defaults to ``'0', '1', ...``.
    colors : array-like, optional
        Per-point colour values passed to ``scatter``.
    action : {'click', 'hover'}, default='click'
        Mouse event that triggers the lookup.
    img_list, emb_list, true_img_list : indexable
        Data shown in panels 1, 2 and 3.
    zoom : bool, default=False
        When True a 2x3 grid is used and panel 4 repeats the scatter plot
        restricted to [-100, 100] on both axes.
    filtered_ids : sequence, optional
        When non-empty, ``filtered_ids[i]`` replaces ``ids[i]`` in the titles.
    """
    if ids is None:
        ids = [str(i) for i in range(len(x))]
    if filtered_ids is None:
        filtered_ids = []

    fig, ax = plt.subplots(2, 3, figsize=(12, 8)) if zoom else plt.subplots(1, 4, figsize=(16,4))
    ax = ax.flatten()

    detail_axes = ax[1:4]
    for panel in detail_axes:
        panel.axis('off')

    if colors is not None:
        ax[0].scatter(x, y, s=1, c=colors)
    else:
        ax[0].scatter(x, y, s=1 if zoom else 0.1)
    # Set limit to the plot
    if zoom:
        if colors is not None:
            ax[4].scatter(x, y, s=1, c=colors)
        else:
            ax[4].scatter(x, y, s=1)
        ax[4].set_xlim([-100, 100])
        ax[4].set_ylim([-100, 100])

    points = np.column_stack((x, y))
    # Pick radius: 1/80 of the larger of the x and y extents (vectorised
    # instead of Python-level max/min over every element).
    extents = points.max(axis=0) - points.min(axis=0)
    threshold = extents.max() / 80
    print(f"Threshold: {threshold}")
    #label_point(x, y, ids, ax[0])
    tree = KDTree(points)

    def onclick(event):
        """Event handler for the configured mouse event."""
        if event.inaxes != ax[0]:
            return
        dist, i = tree.query([event.xdata, event.ydata])
        if dist >= threshold:
            return

        sample = int(ids[i])
        id_img_title = ids[i] if len(filtered_ids) == 0 else filtered_ids[i]

        # Drop what a previous event drew so images do not pile up on the axes.
        for panel in detail_axes:
            panel.clear()

        ax[1].imshow(img_list[sample], cmap='gray')
        if colors is not None:
            ax[1].set_title(f'Image {id_img_title} (color: {colors[i]})')
        else:
            ax[1].set_title(f'Image {id_img_title}')
        ax[2].imshow(reshape_to_square(emb_list[sample]), cmap='gray')
        ax[2].set_title(f'Embedding {id_img_title}')
        ax[3].imshow(true_img_list[sample], cmap='gray')
        ax[3].set_title(f'True Image {id_img_title}')

        for panel in detail_axes:
            panel.axis('off')

        # Ask the canvas to repaint so the new images become visible.
        fig.canvas.draw_idle()

    fig.tight_layout()
    if action == 'click':
        fig.canvas.mpl_connect('button_press_event', onclick)
    elif action == 'hover':
        fig.canvas.mpl_connect('motion_notify_event', onclick)

    plt.show()
    
def plot_2_clusters(embeddings_tsne, cluster_labels, cluster_labels_2d):
    """Draw the same 2-D points twice, side by side, coloured by two label sets.

    The left panel is coloured with ``cluster_labels`` and the right panel
    with ``cluster_labels_2d``; both use columns 0 and 1 of
    ``embeddings_tsne`` as coordinates.
    """
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    xs = embeddings_tsne[:, 0]
    ys = embeddings_tsne[:, 1]

    panels = (
        (ax[0], cluster_labels, 'Clustering Results (t-SNE embedding)'),
        (ax[1], cluster_labels_2d, 'Clustering Results (t-SNE embedding 2D)'),
    )
    for axis, labels, title in panels:
        axis.scatter(xs, ys, c=labels, s=1)
        axis.set_title(title)

    plt.show()

## 3) SIAMESE NETWORK

In [None]:
# Root data directory and the folder that holds the embedding artefacts.
path_data = "/data/falcetta/brain_data"
save_embeddings_path = os.path.join(path_data, f"embeddings_VDISNET") 

# NOTE(review): dataset_name is not referenced by any other cell visible in this section.
dataset_name = 'CAS'

# Sub-folder used by every save/load in the rest of this section.
# mk_dir comes from utils; presumably it creates the directory if missing — confirm.
save_embeddings_path_SOTA = os.path.join(save_embeddings_path,'SOTA')
print(f"Save embeddings in {save_embeddings_path_SOTA}")
mk_dir(save_embeddings_path_SOTA)

In [8]:
%matplotlib inline

def imshow(img,text=None,should_save=False):
    """Show a channel-first tensor as an image in a new figure without axes.

    ``img`` is converted with ``.numpy()`` and its axes are reordered from
    (0, 1, 2) to (1, 2, 0) before display. An optional ``text`` is drawn in a
    white box at position (75, 8). ``should_save`` is accepted but not used.
    """
    pixels = img.numpy()
    plt.figure()
    plt.axis("off")
    if text:
        box_style = {'facecolor':'white', 'alpha':0.8, 'pad':10}
        plt.text(75, 8, text, style='italic', fontweight='bold', bbox=box_style)
    channels_last = np.transpose(pixels, (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

def show_plot(iteration,loss, title=''):
    """Plot ``loss`` against ``iteration`` in a new figure titled ``"<title> Loss"``."""
    heading = f"{title} Loss"
    plt.figure()
    plt.plot(iteration, loss)
    plt.title(heading)
    plt.show()

In [9]:
def random_augmentation(img_np):
    """Randomly flip and rotate an image array.

    Three independent draws from ``np.random.rand()`` decide, in order,
    whether to flip along axis 1, flip along axis 0, and rotate by a random
    number (1 to 3) of quarter turns. Each step is applied when its draw is
    greater than 0.3.
    """
    # Flips: axis 1 first (horizontal), then axis 0 (vertical).
    for flip_axis in (1, 0):
        if np.random.rand() > 0.3:
            img_np = np.flip(img_np, axis=flip_axis)

    # Rotation by 1, 2 or 3 quarter turns.
    if np.random.rand() > 0.3:
        quarter_turns = np.random.randint(1, 4)
        img_np = np.rot90(img_np, k=quarter_turns)

    return img_np

In [10]:
import numpy as np
from PIL import Image

from torch.utils.data import Dataset
import torch
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torch
import random

class SiameseNetworkDataset(Dataset):
    """Dataset that yields single images or image pairs for a siamese network.

    In ``mode='single'`` an item is one normalised image. In any other mode an
    item is ``(img0, img1, flag)`` where ``img1`` is either a randomly
    augmented copy of ``img0`` (flag 0) or a different image of the array
    (flag 1).

    Parameters:
    images_array : np.ndarray
        Images stacked along the first axis.
    transform : callable, optional
        Applied to the PIL image built from each array. The normalisation
        that follows uses tensor arithmetic, so the transform is expected to
        return a tensor (the notebook passes ``transforms.ToTensor()``).
    test_mode : bool, default=False
        When True, pair mode prints the two indices that were paired.
    mode : str, default='double'
        ``'single'`` for single images, anything else for pairs.
    metadata : tuple, optional
        ``(mean, std)`` used for normalisation. Computed from
        ``images_array`` when omitted.
    """

    def __init__(self, images_array, transform=None, test_mode=False, mode='double', metadata=None):
        self.images_array = images_array
        self.transform = transform

        if metadata is None:
            self.compute_metadata()
        else:
            self.mean, self.std = metadata

        self.test_mode = test_mode
        self.mode = mode

    @staticmethod
    def _rescale_to_unit(img):
        """Min-max rescale ``img`` to [0, 1]; a constant image maps to all zeros."""
        lowest = img.min()
        span = img.max() - lowest
        if span == 0:
            # Avoid 0/0 = NaN for constant images.
            return img - lowest
        return (img - lowest) / span

    def _prepare(self, img):
        """Convert one array to a PIL image, transform, normalise and rescale it."""
        img = Image.fromarray(img)
        # Apply transformations if provided
        if self.transform is not None:
            img = self.transform(img)
        # Normalise with the dataset statistics, then squash to the 0-1 range
        img = (img - self.mean) / self.std
        return self._rescale_to_unit(img)

    def __getitem__(self,index):
        idx_0 = int(index)
        img0 = self.images_array[idx_0]

        if self.mode == 'single':
            return self._prepare(img0)

        # Decide whether the second image is an augmented copy (same class)
        # or another image of the array (different class).
        should_get_same_class = random.randint(0, 1)
        if should_get_same_class:
            idx_1 = idx_0
            img1 = random_augmentation(img0)
        else:
            n_images = len(self.images_array)
            if n_images > 1:
                # Uniform draw among the other images, so a "different" pair
                # never contains the same image twice.
                idx_1 = (idx_0 + random.randint(1, n_images - 1)) % n_images
            else:
                idx_1 = idx_0
            img1 = self.images_array[idx_1]

        if self.test_mode:
            print(f"Testing {idx_0} against {idx_1}")

        img0 = self._prepare(img0)
        img1 = self._prepare(img1)

        # Flag: 0 if same class (augmented copy), 1 if different class
        flag = torch.from_numpy((1-should_get_same_class) * np.ones(1))

        return img0, img1, flag

    def compute_metadata(self):
        """Compute ``self.mean`` / ``self.std`` as the average of the per-pixel means / stds."""
        flattened_data = self.images_array.reshape(self.images_array.shape[0], -1)
        means = np.mean(flattened_data, axis=0)
        stds = np.std(flattened_data, axis=0)

        # compute the single mean and std
        mean = np.mean(means)
        std = np.mean(stds)

        # convert to pytorch
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

        print(f"Mean: {self.mean} - Std: {self.std}")

    def get_metadata(self):
        """Return the ``(mean, std)`` pair used for normalisation."""
        return self.mean, self.std

    def __len__(self):
        return len(self.images_array)

In [11]:
import numpy as np

def train_test_split_arrays(*arrays, test_size=0.2, random_state=None):
    """
    Split numpy arrays along the first axis into random train and test subsets.

    Parameters:
    *arrays : array-like
        Arrays to be split. All arrays must have the same size along the first axis.
    test_size : float or int, default=0.2
        If float, should be between 0.0 and 1.0 and represent the proportion of the dataset to include in the test split.
        If int, represents the absolute number of test samples.
        In both cases at least one sample is placed in the test split.
    random_state : int or RandomState instance, default=None
        Controls the randomness of the training and testing indices.

    Returns:
    tuple of tuples
        ``(train_arrays, test_arrays)``, each a tuple with one entry per input array.

    Raises:
    ValueError
        If the arrays differ in length along the first axis, are empty, or
        ``test_size`` has an unsupported type or is out of range.
    """
    # Check if all arrays have the same size along the first axis
    first_axis_lengths = {arr.shape[0] for arr in arrays}
    if len(first_axis_lengths) != 1:
        raise ValueError("All input arrays must have the same size along the first axis.")

    n_samples = first_axis_lengths.pop()
    if n_samples <= 0:
        raise ValueError("The size of the first axis should be greater than 0.")

    # Determine the size of the test set
    if isinstance(test_size, float):
        if not 0.0 <= test_size <= 1.0:
            raise ValueError("test_size as a float should be between 0.0 and 1.0.")
        test_size = int(test_size * n_samples)
    elif isinstance(test_size, int):
        if test_size < 0 or test_size > n_samples:
            raise ValueError("test_size should be a positive integer less than or equal to the size of the first axis.")
    else:
        raise ValueError("test_size should be either float or int.")

    # Always keep at least one test sample
    test_size = max(test_size, 1)

    # Generate random indices for the test set
    rng = np.random.RandomState(random_state)
    indices = np.arange(n_samples)
    rng.shuffle(indices)
    test_indices = indices[:test_size]
    train_indices = indices[test_size:]

    # Split arrays
    train_arrays = tuple(arr[train_indices] for arr in arrays)
    test_arrays = tuple(arr[test_indices] for arr in arrays)

    return train_arrays, test_arrays

In [None]:
# Load four arrays from `save_embeddings_path`: the CAS test set and its mask,
# plus the files named "..._empty_..." and their mask, then print their shapes.
if not just_sampling:
    X_CAS = np.load(os.path.join(save_embeddings_path,f'X_test_CAS_all.npy'))
    X_CAS_mask = np.load(os.path.join(save_embeddings_path,f'X_test_mask_CAS_all.npy'))

    X_CAS_empty = np.load(os.path.join(save_embeddings_path,f'X_test_empty_CAS_all.npy'))
    X_CAS_mask_empty = np.load(os.path.join(save_embeddings_path,f'X_test_empty_mask_CAS_all.npy'))

    print(f"Data loaded")

    print(f"X_test shape: {X_CAS.shape}")
    print(f"X_test_mask shape: {X_CAS_mask.shape}")

    print(f"X_test_empty shape: {X_CAS_empty.shape}")
    print(f"X_test_empty_mask shape: {X_CAS_mask_empty.shape}")



In [None]:
# Append the "empty" arrays after the regular ones along the first axis,
# for both the images and the masks.
if not just_sampling:
    X_CAS_tot = np.concatenate((X_CAS, X_CAS_empty), axis=0)
    X_CAS_mask_tot = np.concatenate((X_CAS_mask, X_CAS_mask_empty), axis=0)

    print(f"X_test_tot shape: {X_CAS_tot.shape}")
    print(f"X_test_mask_tot shape: {X_CAS_mask_tot.shape}")


In [None]:
# From here on X_CAS / X_CAS_mask refer to the concatenated arrays.
if not just_sampling:
    X_CAS = X_CAS_tot
    X_CAS_mask = X_CAS_mask_tot

    print(f"X_test shape: {X_CAS.shape}")
    print(f"X_test_mask shape: {X_CAS_mask.shape}")

In [None]:
# Split images and masks with the same shuffled indices: 15% validation, fixed seed 42.
if not just_sampling:
    # Perform train-test split
    train_arrays, val_arrays = train_test_split_arrays(X_CAS, X_CAS_mask, test_size=0.15, random_state=42)

    X_train, X_train_mask = train_arrays
    X_val, X_val_mask = val_arrays
    
    print(f"X_train shape: {X_train.shape}")
    print(f"X_train_mask shape: {X_train_mask.shape}")

    print(f"X_val shape: {X_val.shape}")
    print(f"X_val_mask shape: {X_val_mask.shape}")

In [16]:
# Persist the train/validation split under `save_embeddings_path_SOTA`
# so the "Load" cell can restore it without recomputing.
if not just_sampling:
    
    np.save(os.path.join(save_embeddings_path_SOTA,f'X_train_CAS_SOTA.npy'), X_train)
    np.save(os.path.join(save_embeddings_path_SOTA,f'X_train_mask_CAS_SOTA.npy'), X_train_mask)

    np.save(os.path.join(save_embeddings_path_SOTA,f'X_val_CAS_SOTA.npy'), X_val)
    np.save(os.path.join(save_embeddings_path_SOTA,f'X_val_mask_CAS_SOTA.npy'), X_val_mask)

### Load

In [None]:
# Reload the saved train/validation arrays (runs regardless of just_sampling).
X_train = np.load(os.path.join(save_embeddings_path_SOTA,f'X_train_CAS_SOTA.npy'))
X_train_mask = np.load(os.path.join(save_embeddings_path_SOTA,f'X_train_mask_CAS_SOTA.npy'))

X_val = np.load(os.path.join(save_embeddings_path_SOTA,f'X_val_CAS_SOTA.npy'))
X_val_mask = np.load(os.path.join(save_embeddings_path_SOTA,f'X_val_mask_CAS_SOTA.npy'))

# Last expression of the cell: the notebook displays the four shapes.
X_train.shape, X_train_mask.shape, X_val.shape, X_val_mask.shape

In [None]:
# Build the training and validation datasets. The normalisation statistics are
# computed on the training set (metadata=None), saved to disk, and reused
# unchanged for the validation set.
if not just_sampling:
    import torchvision.transforms as transforms

    siamese_dataset = SiameseNetworkDataset(images_array=X_train,
                                            transform=transforms.ToTensor(),)

    # (mean, std) pair returned by the dataset
    metadata = siamese_dataset.get_metadata()
    np.save(os.path.join(save_embeddings_path_SOTA,f'metadata_CAS_SOTA.npy'), metadata)

    siamese_val_dataset = SiameseNetworkDataset(images_array=X_val,
                                                transform=transforms.ToTensor(),
                                                metadata=metadata)

### Load metadata

In [None]:
# Reload the normalisation statistics from the file written by the dataset cell.
# Note that `metadata` is now whatever np.load returns, not the original tuple.
metadata = np.load(os.path.join(save_embeddings_path_SOTA,f'metadata_CAS_SOTA.npy'))
print(metadata)

### skip

In [None]:
from torch.utils.data import DataLoader
import torchvision

# Show one example of each pair type produced by the training dataset.
if not just_sampling:
    vis_dataloader = DataLoader(siamese_dataset,
                            shuffle=True,
                            num_workers=0,
                            batch_size=1)

    # One example is still wanted for each pair type.
    plot_same = True
    plot_diff = True

    # The dataset emits flag 0 for an augmented copy of the same image
    # and flag 1 for a different image.
    print(f"Plotting examples of the same (0) and different (1) pairs")

    plt.close('all')
    for example_batch in vis_dataloader:
        pair_flag = int(example_batch[2])
        wanted = (pair_flag == 0 and plot_same) or (pair_flag == 1 and plot_diff)
        if wanted:
            # Stack the two images of the pair so they appear side by side in the grid
            concatenated = torch.cat((example_batch[0],example_batch[1]),0)
            imshow(torchvision.utils.make_grid(concatenated))
            print(f"Flag: {pair_flag}")
            if pair_flag == 0:
                plot_same = False
            else:
                plot_diff = False
        if not plot_same and not plot_diff:
            break

### Load model

In [21]:
import torch.nn as nn
import torchvision.models as models

class SiameseNetwork(nn.Module):
    """Siamese embedding network with a single-channel ResNet-50 backbone.

    Parameters:
    embedding_size : int, default=128
        Size of the output embedding vector.
    pretrained : bool, default=False
        Forwarded to ``models.resnet50``. The first convolution is always
        replaced by a freshly initialised single-channel one.
    """

    def __init__(self, embedding_size=128, pretrained = False):
        super().__init__()

        # ResNet-50 backbone (pre-trained weights only when `pretrained` is True)
        resnet = models.resnet50(pretrained=pretrained)

        # Replace the first convolution so the network accepts 1 input channel,
        # keep the remaining children except the last two (in torchvision's
        # ResNet these are the average pool and the fc layer), and finish with
        # an adaptive pool. The pool has no parameters, so state_dict keys are
        # unchanged; it is the identity on a 1x1 feature map and lets larger
        # inputs reach the fully connected head with the expected size.
        self.resnet_features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False),
            *list(resnet.children())[1:-2],
            nn.AdaptiveAvgPool2d(1),
        )

        # Fully connected head that maps backbone features to the embedding
        self.fc1 = nn.Sequential(
            nn.Linear(resnet.fc.in_features, 1000),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(1000),
            nn.Linear(1000, 1000),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(1000),
            nn.Linear(1000, embedding_size)
        )

    def forward_once(self, x):
        """Return the embedding of one batch of images."""
        features = self.resnet_features(x)
        features = features.view(features.size(0), -1)
        return self.fc1(features)

    def forward(self, input1, input2=None, mode='double'):
        """Embed one or two batches.

        With ``mode='single'`` only ``input1`` is embedded and a single tensor
        is returned; otherwise both inputs are embedded with the same weights
        and a pair of tensors is returned.

        Raises:
        ValueError
            If ``input2`` is missing when a pair is requested.
        """
        output1 = self.forward_once(input1)
        if mode == 'single':
            return output1
        if input2 is None:
            raise ValueError("input2 is required unless mode == 'single'.")
        output2 = self.forward_once(input2)
        return output1, output2


In [None]:
class ContrastiveLoss(torch.nn.Module):
    """
    Contrastive loss function.
    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    A label of 0 pulls the two embeddings together (squared distance); a label
    of 1 pushes them apart until their distance reaches ``margin``.

    Every call to ``forward`` also adds the batch loss to a running total, so
    the mean over a series of calls can be read with ``get_mean_loss``.
    """

    def __init__(self, margin=2.0):
        super().__init__()
        self.margin = margin
        self.loss_accumulator = 0.0
        self.num_samples = 0

    def forward(self, output1, output2, label):
        """Return the mean contrastive loss of the batch and accumulate it."""
        euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
        # Align the label with the (batch, 1) distance so a flat label does
        # not broadcast into a (batch, batch) matrix.
        label = label.reshape(euclidean_distance.shape)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) + (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))

        self.loss_accumulator += loss_contrastive.item()
        self.num_samples += 1
        return loss_contrastive

    def get_accumulated_loss(self):
        """Return the sum of the batch losses seen since the last reset."""
        return self.loss_accumulator

    def get_mean_loss(self):
        """Return the mean batch loss since the last reset, then reset.

        Returns NaN when no batch has been accumulated.
        """
        if self.num_samples == 0:
            self.mean_loss = float("nan")
        else:
            self.mean_loss = self.loss_accumulator / self.num_samples
        self.reset()
        return self.mean_loss

    def reset(self):
        """Clear the running total and the batch counter."""
        self.loss_accumulator = 0.0
        self.num_samples = 0

"""
class ContrastiveLoss(torch.nn.Module):

    def __init__(self, margin=.5, **kwargs):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin
        # self.metric = metric
        self.distance = torch.nn.PairwiseDistance(p=2)

    def forward(self, out0, out1, label):
        gt = label.float()
        D = self.distance(out0, out1).float().squeeze()
        loss = gt * 0.5 * torch.pow(D, 2) + (1 - gt) * 0.5 * torch.pow(torch.clamp(self.margin - D, min=0.0), 2)
        return loss
"""

### skip

In [23]:
# Data loaders for training (shuffled, batches of 128) and validation
# (fixed order, batches of 256), both with 8 worker processes.
if not just_sampling:
    train_dataloader = DataLoader(siamese_dataset,
                            shuffle=True,
                            num_workers=8,
                            batch_size=128)

    val_dataloader = DataLoader(siamese_val_dataset,
                            shuffle=False,
                            num_workers=8,
                            batch_size=256)

### Training (skip)

In [24]:
# load_model: when True the training setup cell restores weights from 'best_model_CA.pt'.
load_model = False
# First value of the epoch range used by the training loop.
start_epoch=0

In [25]:
# Create the network on the GPU, optionally restore a checkpoint, and
# initialise the loss and the bookkeeping used by the training loop.
if not just_sampling:
    net = SiameseNetwork().cuda()

    if load_model:
        checkpoint_path = os.path.join(save_embeddings_path_SOTA, 'best_model_CA.pt')
        print(f"Loading model from {checkpoint_path}")
        # Load state dictionary into model
        net.load_state_dict(torch.load(checkpoint_path))
        start_epoch= 0


    criterion = ContrastiveLoss()

    # Best validation loss so far and the epoch at which it was reached
    best_loss = 9999
    best_loss_epoch = 0
    # Epoch counter that keeps growing across repeated runs of the training cell
    cumul_epochs = 0

    # Per-epoch histories
    mean_loss_contrastive_list = []
    best_loss_contrastive_list = []
    validation_loss_list = []
    # Set to True by the loss-plot cell after the histories are reloaded as arrays
    continue_training = False

In [None]:
from torch import optim

# Train with early stopping on the validation loss; the best weights are
# written to 'best_model_CA.pt'.
if not just_sampling:
    n_epochs = 500
    early_stopping_tolerance = 50
    optimizer = optim.Adam(net.parameters(),lr = 0.0005)

    if continue_training:
        # The histories may have been reloaded as numpy arrays; make them
        # lists again so .append works (also safe if they are already lists).
        mean_loss_contrastive_list = np.asarray(mean_loss_contrastive_list).tolist()
        best_loss_contrastive_list = np.asarray(best_loss_contrastive_list).tolist()
        validation_loss_list = np.asarray(validation_loss_list).tolist()

    print(f"Starting round of training from epoch {cumul_epochs}")

    for epoch in tqdm(range(start_epoch, n_epochs), desc='Epochs'): 
        # The criterion accumulates every batch it sees, including the
        # validation batches of the previous epoch; clear it so the training
        # mean computed below only covers this epoch's training batches.
        criterion.reset()

        # Training loop
        net.train()
        for data in tqdm(train_dataloader, desc='Batches', leave=False):
            img0, img1, label = data
            img0, img1, label = img0.cuda(), img1.cuda(), label.cuda()
            optimizer.zero_grad()
            output1, output2 = net(img0, img1)
            loss_contrastive = criterion(output1, output2, label)
            loss_contrastive.backward()
            optimizer.step()

        # Calculate mean loss for contrastive loss during training
        mean_loss_contrastive = criterion.get_mean_loss()
        mean_loss_contrastive_list.append(mean_loss_contrastive)
        
        # Validation loop
        net.eval()  # Set the model to evaluation mode
        val_loss = 0.0
        with torch.no_grad():
            for val_data in tqdm(val_dataloader, desc='Validation', leave=False):
                val_img0, val_img1, val_label = val_data
                val_img0, val_img1, val_label = val_img0.cuda(), val_img1.cuda(), val_label.cuda()
                val_output1, val_output2 = net(val_img0, val_img1)
                val_loss += criterion(val_output1, val_output2, val_label).item()

        # Mean of the per-batch validation losses
        val_loss /= len(val_dataloader)
        validation_loss_list.append(val_loss)
        #print(f"Validation Loss: {val_loss:.2f}")

        # Check if current loss is the best so far
        if val_loss < best_loss:
            print(f"Epoch number {cumul_epochs} --- Best loss {val_loss:.2f}")
            best_loss = val_loss
            best_loss_epoch = cumul_epochs
            torch.save(net.state_dict(), os.path.join(save_embeddings_path_SOTA, 'best_model_CA.pt'))
        else:
            # Stop once the best epoch is more than `early_stopping_tolerance` epochs old
            if cumul_epochs - best_loss_epoch > early_stopping_tolerance:
                print(f"Early stopping at epoch {cumul_epochs}")
                best_loss_contrastive_list.append(best_loss)
                cumul_epochs +=1
                break    
        
        best_loss_contrastive_list.append(best_loss)
        cumul_epochs +=1
        

In [27]:
# Persist the three per-epoch loss histories next to the model checkpoint.
if not just_sampling:
    np.save(os.path.join(save_embeddings_path_SOTA, f'mean_loss_contrastive_list_CAS_CA.npy'), mean_loss_contrastive_list)
    np.save(os.path.join(save_embeddings_path_SOTA, f'best_loss_contrastive_list_CAS_CA.npy'), best_loss_contrastive_list)
    np.save(os.path.join(save_embeddings_path_SOTA, f'validation_loss_list_CAS_CA.npy'), validation_loss_list)


### Plot losses

In [None]:
if not just_sampling:
    mean_loss_contrastive_list = np.load(os.path.join(save_embeddings_path_SOTA, f'mean_loss_contrastive_list_CAS_CA.npy'))
    best_loss_contrastive_list = np.load(os.path.join(save_embeddings_path_SOTA, f'best_loss_contrastive_list_CAS_CA.npy'))
    validation_loss_list = np.load(os.path.join(save_embeddings_path_SOTA, f'validation_loss_list_CAS_CA.npy'))

    continue_training=True

    %matplotlib inline
    plt.ioff()
    plt.close('all')
    show_plot(range(0,len(mean_loss_contrastive_list)),mean_loss_contrastive_list, title='Training Loss')
    show_plot(range(0,len(best_loss_contrastive_list)),validation_loss_list, title='Val Loss')
    show_plot(range(0,len(best_loss_contrastive_list)),best_loss_contrastive_list, title='Best Val Loss')


### LOAD MODEL and Test (skip)

In [29]:
import torch.nn.functional as F
# LOAD BEST MODEL
# Rebuild the network, restore the weights saved as 'best_model_CA.pt',
# and switch it to evaluation mode.
if not just_sampling:
    net = SiameseNetwork().cuda()
    checkpoint_path = os.path.join(save_embeddings_path_SOTA, 'best_model_CA.pt')

    # Load state dictionary into model
    net.load_state_dict(torch.load(checkpoint_path))
    net.eval()

### Skip

In [None]:
from torch.autograd import Variable 
from tqdm import tqdm

# Embed every image of X_CAS with the trained network and save the result.
if not just_sampling:
    batch_size = 64
    siamese_testset = SiameseNetworkDataset(images_array=X_CAS,
                                            transform=transforms.ToTensor(),
                                            test_mode=True,
                                            metadata=metadata,
                                            mode = 'single')

    test_dataloader = DataLoader(siamese_testset,
                            shuffle=False,
                            batch_size=batch_size)

    # Pre-allocate memory for the embeddings array
    num_samples = len(siamese_testset)
    embedding_size = 128  # Must match the embedding_size the network was built with
    img_embeddings = np.empty((num_samples, embedding_size))
    # len(test_dataloader) is the exact number of batches, also when
    # num_samples is a multiple of batch_size.
    print(f"Number of samples: {num_samples} (batch size: {batch_size}) ==> Passages: {len(test_dataloader)}")

    start_idx = 0

    # Inference only: evaluation mode and no autograd graph, which keeps
    # memory usage flat over the whole pass.
    net.eval()
    with torch.no_grad():
        for img in tqdm(test_dataloader):
            img0 = img.cuda()
            output = net(img0, None, mode='single').cpu().detach().numpy()
            # Calculate the end index for the current batch
            end_idx = start_idx + output.shape[0]
            # Store the batch of embeddings in the pre-allocated array
            img_embeddings[start_idx:end_idx] = output
            # Update the start index for the next batch
            start_idx = end_idx

    np.save(os.path.join(save_embeddings_path_SOTA, f'img_embeddings_CAS_CA.npy'), img_embeddings)
    print(f"Image Embeddings saved")

### Load img embeddings

In [None]:
# Reload the embeddings file written by the extraction cell (runs regardless of just_sampling).
img_embeddings = np.load(os.path.join(save_embeddings_path_SOTA, f'img_embeddings_CAS_CA.npy'))

print(f"img_embeddings loaded from {save_embeddings_path_SOTA}")
print(f'img_embeddings shape: {img_embeddings.shape}')

### Plot IMG Embeddings

### T-SNE (skip)

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Project the embeddings to 2-D with t-SNE (fixed seed 42), scatter-plot the
# result and save it to disk.
if not just_sampling:
    # Perform t-SNE projection
    # Values noted by the author but not passed to TSNE below
    # (presumably settings tried earlier — confirm): perplexity 500
    #early_exaggeration=40
    tsne = TSNE(n_components=2, random_state=42)
    img_embeddings_tsne = tsne.fit_transform(img_embeddings)

    plt.figure()
    plt.scatter(img_embeddings_tsne[:, 0], img_embeddings_tsne[:, 1], s=1)
    plt.title('Image Clustering Results (t-SNE img_embedding)')

    # save the embeddings
    np.save(os.path.join(save_embeddings_path_SOTA,f'img_embeddings_tsne_CAS_CA.npy'), img_embeddings_tsne)
    print(f"TSNE img_embeddings saved in {os.path.join(save_embeddings_path_SOTA,f'img_embeddings_tsne_CAS_CA.npy')}")

### Load

In [None]:
# Reload the 2-D t-SNE projection from disk (runs regardless of just_sampling).
img_embeddings_tsne = np.load(os.path.join(save_embeddings_path_SOTA,f'img_embeddings_tsne_CAS_CA.npy'))
print(f"TSNE img_embeddings loaded")
print(f"TSNE img_embeddings shape: {img_embeddings_tsne.shape}")

In [None]:
%matplotlib widget
# Uncoloured interactive t-SNE scatter: clicking a point shows its entries of
# X_CAS_mask, img_embeddings and X_CAS.
# NOTE(review): X_CAS / X_CAS_mask are only assigned in cells guarded by `if not just_sampling:`.
interactive_plot(img_embeddings_tsne[:, 0], img_embeddings_tsne[:, 1], colors=None, action='click', img_list=X_CAS_mask, emb_list=img_embeddings, true_img_list=X_CAS)

In [None]:
# Overlay plot for five hand-picked indices; X_CAS is drawn as background and
# the non-zero pixels of X_CAS_mask select the red overlay.
plot_n_patches_overlap(X_CAS_mask, X_CAS, indexes=[5597,24,25,26,23])

In [40]:
# Number of clusters requested from the KMeans cells of this notebook.
n_clusters_img = 50

### K-MEANS T-SNE (skip)

In [None]:
from sklearn.cluster import KMeans 

# Cluster the 2-D t-SNE projection with KMeans, save the labels and plot how
# many points fall in each cluster.
if not just_sampling:
    # Create a KMeans object with the desired number of clusters
    kmeans = KMeans(n_clusters=n_clusters_img, random_state=42, n_init='auto')

    # Fit the KMeans model to the t-SNE projection of the embeddings
    kmeans.fit(img_embeddings_tsne)

    # Get the cluster labels for each data point
    img_cluster_labels_2d = kmeans.labels_
    print(f"Classes: {set(img_cluster_labels_2d)}")
    np.save(os.path.join(save_embeddings_path_SOTA,f'img_cluster_labels_2d_CAS_CA.npy'), img_cluster_labels_2d)


    # Plot histogram of cluster labels with one bin per cluster; the default
    # of 10 bins would merge several clusters into each bar.
    label_bins = np.arange(n_clusters_img + 1) - 0.5
    plt.figure(figsize=(10, 5))
    plt.hist(img_cluster_labels_2d, bins=label_bins)
    plt.xlabel('Cluster Label')
    plt.ylabel('Count')
    plt.title('Histogram of Cluster Labels')
    plt.show()

### load

In [None]:
# Reload the KMeans labels computed on the t-SNE projection.
img_cluster_labels_2d = np.load(os.path.join(save_embeddings_path_SOTA,f'img_cluster_labels_2d_CAS_CA.npy'))
print(f"Cluster labels loaded")
print(f"Cluster labels shape: {img_cluster_labels_2d.shape}")

In [None]:
# Interactive t-SNE scatter coloured by the labels clustered on the 2-D projection.
interactive_plot(img_embeddings_tsne[:, 0], img_embeddings_tsne[:, 1], colors=img_cluster_labels_2d, action='click', img_list=X_CAS_mask, emb_list=img_embeddings, true_img_list=X_CAS)

### K-MEANS with embedding code (skip)

In [None]:
from sklearn.cluster import KMeans
# Cluster the full embedding vectors (not the 2-D projection) with KMeans
# and save the labels.
if not just_sampling:
    # Create a KMeans object with the desired number of clusters
    kmeans = KMeans(n_clusters=n_clusters_img, random_state=42, n_init='auto')

    # Fit the KMeans model to the img_embeddings
    kmeans.fit(img_embeddings)

    # Get the cluster labels for each data point
    img_cluster_labels = kmeans.labels_
    print(f"Classes: {set(img_cluster_labels)}")
    np.save(os.path.join(save_embeddings_path_SOTA,f'img_cluster_labels_CAS_CA.npy'), img_cluster_labels)

### load

In [None]:
# Reload the KMeans labels computed on the full embedding vectors.
img_cluster_labels = np.load(os.path.join(save_embeddings_path_SOTA,f'img_cluster_labels_CAS_CA.npy'))
print(f"Cluster labels loaded")
print(f"Cluster labels shape: {img_cluster_labels.shape}")

In [None]:
# Side-by-side comparison: left coloured by the embedding-space labels,
# right by the labels clustered on the 2-D projection.
plot_2_clusters(img_embeddings_tsne, img_cluster_labels, img_cluster_labels_2d) 

In [None]:
# Interactive t-SNE scatter coloured by the labels clustered on the full embeddings.
interactive_plot(img_embeddings_tsne[:, 0], img_embeddings_tsne[:, 1], colors=img_cluster_labels, action='click', img_list=X_CAS_mask, emb_list=img_embeddings, true_img_list=X_CAS)