## 0 - Import libraries

In [1]:
# When True, every cell guarded by `if not just_sampling:` (dataset building,
# training, embedding extraction, t-SNE) is skipped, and the notebook relies
# on the artifacts that the unguarded np.load cells read from disk.
just_sampling = False

In [2]:
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader

torch.backends.cudnn.benchmark = True

from tqdm import tqdm
import imageio.v2 as imageio

from utils import set_random_seed, mk_dir
from importlib import import_module

from torchvision.utils import save_image

In [3]:
import matplotlib.pyplot as plt
plt.ion() 

def save_img(tensor, name, norm, n_rows=16, scale_each=False):
    """Write a grid of images to the file ``name`` using torchvision's ``save_image``.

    ``n_rows`` is forwarded as ``nrow`` (images per grid row). The grid uses a
    5-pixel padding filled with value 1. ``norm`` toggles min-max
    normalization, which is applied per image when ``scale_each`` is True.
    """
    grid_options = dict(
        nrow=n_rows,
        padding=5,
        normalize=norm,
        pad_value=1,
        scale_each=scale_each,
    )
    save_image(tensor, name, **grid_options)
    
def plot_10_patches(img_np, true_vessel_np=False, indexes=None):
    """Plot pairs of 2-D slices side by side, one pair per row.

    Column 0 always shows ``img_np[index]`` on a fixed [0, 1] gray scale.
    Column 1 shows ``true_vessel_np[index]`` when a second array is given.
    Otherwise it shows the next slice, ``img_np[index + 1]``.

    Parameters
    ----------
    img_np : array indexed as ``img_np[index, :, :]``.
    true_vessel_np : array indexed like ``img_np``, or ``False`` (default)
        to show the next image slice in column 1 instead.
    indexes : iterable of slice indices. When None, 5 random indices are
        drawn from ``range(img_np.shape[0])`` (despite the function name).
    """
    random_indexes = np.random.randint(0, img_np.shape[0], size=5) if indexes is None else indexes

    n = len(random_indexes)

    # squeeze=False keeps `ax` 2-D even when there is a single row, so that
    # ax[i, j] indexing works for any number of indices.
    fig, ax = plt.subplots(n, 2, figsize=(6, 3 * n), squeeze=False)

    for i, index in enumerate(random_indexes):

        # Pull indices at the very end back so that `index + 1` (used below when
        # no second array is given) stays in bounds: an index of n or n-1 ends
        # up at n-3.
        if index == img_np.shape[0]:
            index -= 1
        if index == img_np.shape[0] - 1:
            index -= 2

        # Column 0 is the image slice (titles were previously swapped).
        ax[i, 0].set_title(f'Image Slice {index}')
        ax[i, 0].imshow(img_np[index, :, :], cmap='gray', vmin=0, vmax=1)
        ax[i, 0].axis('off')

        if true_vessel_np is not False:
            print(f"Shape: {true_vessel_np[index,:,:].shape}, Max: {true_vessel_np[index,:,:].max()}, Min: {true_vessel_np[index,:,:].min()}")
            ax[i, 1].imshow(true_vessel_np[index, :, :], cmap='gray')
            ax[i, 1].set_title(f'True Vessel Slice {index}')
        else:
            ax[i, 1].set_title(f'Image Slice {index+1}')
            ax[i, 1].imshow(img_np[index + 1, :, :], cmap='gray', vmin=0, vmax=1)

        ax[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

def plot_n_patches_overlap(img_np, true_vessel_np=False, indexes=None, selected_class=None, add_title='', m=5, alpha=0.5, save_dir='clusters_imgs'):
    """Draw one overlay subplot per slice index, ``m`` subplots per row.

    Each subplot shows ``true_vessel_np[index]`` in gray. On top of it, the
    same values are drawn with the 'Reds' colormap at opacity ``alpha``, but
    only where ``img_np[index]`` is non-zero (zero pixels are masked out).

    When ``add_title`` is non-empty, the figure is saved and closed. The file is
    ``<save_embeddings_path>/<save_dir>/<add_title>_class_<selected_class>_patches.png``.
    Otherwise the figure is shown.

    Parameters
    ----------
    img_np : array indexed as ``img_np[index, :, :]``; its zero pixels are
        excluded from the red layer.
    true_vessel_np : array indexed like ``img_np``. Required, despite the
        ``False`` default.
    indexes : slice indices to plot. When None, 5 random indices are drawn.
    selected_class, add_title : only used to build the output file name.
    m : number of subplots per row.
    alpha : opacity of the red layer.
    save_dir : sub-folder of the module-level ``save_embeddings_path``.

    Raises
    ------
    ValueError
        If ``true_vessel_np`` is not provided.
    """
    if true_vessel_np is False or true_vessel_np is None:
        raise ValueError("plot_n_patches_overlap requires true_vessel_np.")
    if indexes is None:
        # Same fallback as plot_10_patches: 5 random slice indices.
        indexes = np.random.randint(0, img_np.shape[0], size=5)

    # save_embeddings_path is a module-level setting defined in the notebook.
    save_dir = os.path.join(save_embeddings_path, save_dir)
    mk_dir(save_dir)

    n = len(indexes)
    if n == 0:
        print("No indexes to plot.")
        return
    n_rows = int(np.ceil(n / m))
    n_cols = m

    plt.ioff()
    try:
        # squeeze=False always returns a 2-D array of axes, so flatten() also
        # works for a 1x1 grid.
        fig, ax = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
        ax = ax.flatten()

        for i, index in enumerate(indexes):
            ax[i].set_title(f'Overlay for Slice {index}')
            masked_image = np.ma.masked_where(img_np[index, :, :] == 0, true_vessel_np[index, :, :])
            ax[i].imshow(true_vessel_np[index, :, :], cmap='gray', interpolation='none')
            ax[i].imshow(masked_image, cmap='Reds', alpha=alpha)
            ax[i].axis('off')

        # Hide grid cells left empty in the last row.
        for unused_ax in ax[n:]:
            unused_ax.axis('off')

        try:
            plt.tight_layout()
        except Exception as exc:
            # Layout is cosmetic: report the problem but still save/show.
            print(f"tight_layout failed: {exc}")

        if add_title != '':
            print(f"Save Fig to {save_dir}")
            plt.savefig(os.path.join(save_dir, f'{add_title}_class_{selected_class}_patches.png'))
            plt.close(fig)
        else:
            print("Plotting..")
            plt.show()
    finally:
        # Restore interactive mode even if plotting fails.
        plt.ion()

    
def plot_n_patches(img_np, true_vessel_np=False, indexes=None, selected_class=None, add_title='', m=5):
    """Plot slices in pairs of adjacent subplots, ``m`` pairs (``2 * m`` columns) per row.

    The left subplot of each pair shows ``img_np[index]`` on a fixed [0, 1]
    gray scale. The right subplot shows ``true_vessel_np[index]`` when a second
    array is given, otherwise ``img_np[index + 1]``.

    When ``add_title`` is non-empty, the figure is saved and closed. The file is
    ``<add_title>_class_<selected_class>_patches.png`` in the current working
    directory. Otherwise the figure is shown.

    Parameters
    ----------
    img_np : array indexed as ``img_np[index, :, :]``.
    true_vessel_np : array indexed like ``img_np``, or ``False`` (default).
    indexes : slice indices to plot. When None, 5 random indices are drawn.
    selected_class, add_title : only used to build the output file name.
    m : number of pairs per row.
    """
    if indexes is None:
        # Same fallback as plot_10_patches: 5 random slice indices.
        indexes = np.random.randint(0, img_np.shape[0], size=5)
    n = len(indexes)
    if n == 0:
        print("No indexes to plot.")
        return
    if n > 1000:
        print(f'WARNING: YOU ARE TRYING TO PLOT {n} images')
    n_rows = int(np.ceil(n / m))
    n_cols = 2 * m

    plt.ioff()
    try:
        fig, ax = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
        ax = ax.flatten()

        for i, index in enumerate(indexes):

            # Pull indices at the very end back so that `index + 1` (used when
            # no second array is given) stays in bounds: n or n-1 -> n-3.
            if index == img_np.shape[0]:
                index -= 1
            if index == img_np.shape[0] - 1:
                index -= 2

            ax[2 * i].set_title(f'Image Slice {index}')
            ax[2 * i].imshow(img_np[index, :, :], cmap='gray', vmin=0, vmax=1)
            ax[2 * i].axis('off')
            print(f"Shape: {img_np[index,:,:].shape}, Max: {img_np[index,:,:].max()}, Min: {img_np[index,:,:].min()}")
            if true_vessel_np is not False:
                ax[2 * i + 1].imshow(true_vessel_np[index, :, :], cmap='gray')
                ax[2 * i + 1].set_title(f'True Vessel Slice {index}')
            else:
                ax[2 * i + 1].set_title(f'Image Slice {index+1}')
                ax[2 * i + 1].imshow(img_np[index + 1, :, :], cmap='gray', vmin=0, vmax=1)

            ax[2 * i + 1].axis('off')

        # Hide grid cells left empty in the last row.
        for unused_ax in ax[2 * n:]:
            unused_ax.axis('off')

        plt.tight_layout()
        if add_title != '':
            print("Save Fig")
            plt.savefig(f'{add_title}_class_{selected_class}_patches.png')
            plt.close(fig)
        else:
            print("Plotting..")
            plt.show()
    finally:
        # Restore interactive mode even if plotting fails.
        plt.ion()
    
def reshape_to_square(vector):
    """Zero-pad a 1-D sequence and fold it row-major into the smallest square matrix.

    The side length is ``ceil(sqrt(len(vector)))``. Cells beyond the input's
    length are filled with 0.
    """
    length = len(vector)
    side = int(np.ceil(np.sqrt(length)))
    trailing_zeros = side * side - length
    padded = np.pad(vector, (0, trailing_zeros), mode='constant')
    return padded.reshape((side, side))

In [4]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import KDTree

def label_point(x, y, ids, ax):
    """Write each label from ``ids`` at its matching (x, y) position on ``ax``."""
    for position, label in enumerate(ids):
        anchor = (x[position], y[position])
        ax.annotate(label, anchor)

def interactive_plot(x, y, ids=None, colors=None, action='click', img_list=None, emb_list=None, true_img_list=None, zoom=False, filtered_ids=None):
    """Scatter (x, y) and show the data behind the point nearest the cursor.

    Clicking (``action='click'``) or hovering (``action='hover'``) on the
    scatter in ``ax[0]`` selects the nearest point. The selection only counts
    if the point lies within ``threshold`` of the cursor. Three panels are
    then filled: ``img_list[id]``, ``emb_list[id]`` reshaped to a square,
    and ``true_img_list[id]``, where ``id = int(ids[i])``.

    Parameters
    ----------
    x, y : point coordinates.
    ids : labels convertible to ``int`` and used to index the lists. Defaults
        to ``"0" .. str(len(x) - 1)``.
    colors : optional per-point colors, also shown in the image title.
    action : ``'click'`` or ``'hover'``.
    img_list, emb_list, true_img_list : data indexed by ``int(ids[i])``.
    zoom : if True, add a second scatter limited to [-100, 100] on both axes.
    filtered_ids : optional labels used only for panel titles (indexed by the
        point position). None or empty means "use ``ids``".
    """
    if ids is None:
        ids = [str(i) for i in range(len(x))]
    if filtered_ids is None:
        # None instead of a mutable [] default; an empty list means "use ids".
        filtered_ids = []
    if zoom:
        fig, ax = plt.subplots(2, 3, figsize=(12, 8))
    else:
        fig, ax = plt.subplots(1, 4, figsize=(16, 4))
    ax = ax.flatten()

    ax[1].axis('off')
    ax[2].axis('off')
    ax[3].axis('off')

    if colors is not None:
        ax[0].scatter(x, y, s=1, c=colors)
    else:
        ax[0].scatter(x, y, s=1 if zoom else 0.1)
    if zoom:
        # Same scatter, restricted to a fixed window around the origin.
        if colors is not None:
            ax[4].scatter(x, y, s=1, c=colors)
        else:
            ax[4].scatter(x, y, s=1)
        ax[4].set_xlim([-100, 100])
        ax[4].set_ylim([-100, 100])

    # A cursor "hits" a point only if it is within 1/80 of the larger data span.
    threshold = max(max(x) - min(x), max(y) - min(y)) / 80
    print(f"Threshold: {threshold}")
    tree = KDTree(np.column_stack((x, y)))

    def onclick(event):
        """Fill the detail panels for the point nearest the cursor in ax[0]."""
        if event.inaxes != ax[0]:
            return
        dist, i = tree.query([event.xdata, event.ydata])
        if dist < threshold:
            data_index = int(ids[i])
            # len() rather than truthiness, so numpy arrays are accepted too.
            id_img_title = ids[i] if len(filtered_ids) == 0 else filtered_ids[i]

            ax[1].imshow(img_list[data_index], cmap='gray')
            if colors is not None:
                ax[1].set_title(f'Image {id_img_title} (color: {colors[i]})')
            else:
                ax[1].set_title(f'Image {id_img_title}')
            ax[2].imshow(reshape_to_square(emb_list[data_index]), cmap='gray')
            ax[2].set_title(f'Embedding {id_img_title}')
            ax[3].imshow(true_img_list[data_index], cmap='gray')
            ax[3].set_title(f'True Image {id_img_title}')

            ax[1].axis('off')
            ax[2].axis('off')
            ax[3].axis('off')
            # Without an explicit redraw request, interactive backends may not
            # repaint the updated panels until another event forces a draw.
            fig.canvas.draw_idle()

    fig.tight_layout()
    if action == 'click':
        fig.canvas.mpl_connect('button_press_event', onclick)
    elif action == 'hover':
        fig.canvas.mpl_connect('motion_notify_event', onclick)

    plt.show()
    
def plot_2_clusters(embeddings_tsne, cluster_labels, cluster_labels_2d):
    """Show the same 2-D embedding twice, colored by two different labelings."""
    xs = embeddings_tsne[:, 0]
    ys = embeddings_tsne[:, 1]
    panels = (
        (cluster_labels, 'Clustering Results (t-SNE embedding)'),
        (cluster_labels_2d, 'Clustering Results (t-SNE embedding 2D)'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for panel_ax, (labels, title) in zip(axes, panels):
        panel_ax.scatter(xs, ys, c=labels, s=1)
        panel_ax.set_title(title)

    plt.show()

## 3) SIAMESE NETWORK (AET transformation-prediction model)

In [None]:
# Root folder of the brain data; the CAS arrays and shared outputs live in
# its embeddings_VDISNET sub-folder.
path_data = "/data/falcetta/brain_data"
save_embeddings_path = os.path.join(path_data, f"embeddings_VDISNET") 

dataset_name = 'CAS'

# Artifacts of this baseline (splits, checkpoints, losses, embeddings) are
# read from / written to a separate 'SOTA' sub-folder.
save_embeddings_path_SOTA = os.path.join(save_embeddings_path,'SOTA')
print(f"Save embeddings in {save_embeddings_path_SOTA}")
# mk_dir comes from the project's utils; presumably creates the folder if missing.
mk_dir(save_embeddings_path_SOTA)

In [6]:
%matplotlib inline

def imshow(img, text=None, should_save=False):
    """Display a channels-first (C, H, W) tensor in a new figure without axes.

    ``text``, when truthy, is drawn in an italic bold white box at (75, 8).
    ``should_save`` is accepted for compatibility but not used.
    """
    # matplotlib expects channels last, hence the (1, 2, 0) transpose.
    channels_last = np.transpose(img.numpy(), (1, 2, 0))
    plt.figure()
    plt.axis("off")
    if text:
        text_box = {'facecolor': 'white', 'alpha': 0.8, 'pad': 10}
        plt.text(75, 8, text, style='italic', fontweight='bold', bbox=text_box)
    plt.imshow(channels_last)
    plt.show()

def show_plot(iteration, loss, title=''):
    """Line-plot ``loss`` against ``iteration`` in a new figure titled "<title> Loss"."""
    figure_title = f"{title} Loss"
    plt.figure()
    plt.plot(iteration, loss)
    plt.title(figure_title)
    plt.show()

In [7]:
def random_augmentation(img_np):
    """Randomly flip and rotate a 2-D array using numpy's global RNG.

    Each step fires independently with probability 0.7 (``rand() > 0.3``), in
    this order:

    1. flip along axis 1;
    2. flip along axis 0;
    3. rotate by 1-3 quarter turns.
    """
    # Axis 1 first (horizontal flip), then axis 0 (vertical flip).
    for flip_axis in (1, 0):
        if np.random.rand() > 0.3:
            img_np = np.flip(img_np, axis=flip_axis)
    if np.random.rand() > 0.3:
        quarter_turns = np.random.randint(1, 4)
        img_np = np.rot90(img_np, k=quarter_turns)
    return img_np

In [8]:
import numpy as np
from PIL import Image

from torch.utils.data import Dataset
import torch
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torch
import random

import torch
from torchvision.transforms import functional as FT
from torchvision.transforms.functional import to_pil_image
import numpy as np
from PIL import Image

# Define a function to create the transformation matrix
def get_transformation_matrix(transform_params):
    """Build a 3x3 matrix from a dict of flip / rotation / scale parameters.

    Recognized keys (any subset may be present):

    * ``'flip'`` -- when truthy, negate the x coordinate.
    * ``'rotation'`` -- rotation angle in degrees.
    * ``'scale'`` -- isotropic scale factor.

    The factors are multiplied in the order ``flip @ rotation @ scale``,
    starting from the identity.

    Returns
    -------
    np.ndarray of shape (3, 3).
    """
    matrix = np.eye(3)

    # Test the flag's value, not just the key's presence: {'flip': False}
    # must not flip.
    if transform_params.get('flip', False):
        matrix = np.dot(matrix, np.array([[-1, 0, 0], [0, 1, 0], [0, 0, 1]]))

    if 'rotation' in transform_params:
        angle = np.deg2rad(transform_params['rotation'])
        rotation_matrix = np.array([
            [np.cos(angle), -np.sin(angle), 0],
            [np.sin(angle), np.cos(angle), 0],
            [0, 0, 1]
        ])
        matrix = np.dot(matrix, rotation_matrix)

    if 'scale' in transform_params:
        scale = transform_params['scale']
        scale_matrix = np.array([
            [scale, 0, 0],
            [0, scale, 0],
            [0, 0, 1]
        ])
        matrix = np.dot(matrix, scale_matrix)

    return matrix

# Updated AETDataset class
class AETDataset(Dataset):
    """Dataset of image patches for AET-style (transformation prediction) training.

    Each item goes through these steps:

    1. converted with ``PIL.Image.fromarray``;
    2. passed through ``transform``;
    3. standardized with the dataset ``(mean, std)``;
    4. min-max rescaled to [0, 1] per image.

    In ``test_mode`` only that image is returned. Otherwise the item also
    includes a randomly flipped, rotated and rescaled copy, plus the flattened
    3x3 matrix describing those parameters.

    Parameters
    ----------
    images_array : array of patches indexed along axis 0. Each item must be
        accepted by ``PIL.Image.fromarray``.
    transform : callable applied to the PIL image. It must return a
        ``torch.Tensor``, since normalization uses tensor arithmetic.
    metadata : optional ``(mean, std)`` pair. Computed from ``images_array``
        when None.
    test_mode : if True, ``__getitem__`` returns only the normalized image.
    """

    def __init__(self, images_array, transform=None, metadata=None, test_mode=False):
        self.images_array = images_array
        self.transform = transform
        self.test_mode = test_mode

        if metadata is None:
            self.compute_metadata()
        else:
            self.mean, self.std = metadata

    def __getitem__(self, index):
        """Return ``img0`` in test mode, otherwise ``(img0, img1, t_true)``.

        ``img0`` and ``img1`` are float32 tensors. ``t_true`` is a float32
        tensor of 9 values (the flattened 3x3 matrix).

        Raises
        ------
        TypeError
            If ``transform`` does not produce a ``torch.Tensor``.
        """
        img0 = Image.fromarray(self.images_array[int(index)])

        # Apply transformations if provided
        if self.transform is not None:
            img0 = self.transform(img0)
        if not isinstance(img0, torch.Tensor):
            raise TypeError(
                "AETDataset needs a transform that returns a torch.Tensor "
                "(e.g. transforms.ToTensor())."
            )

        # Dataset-level standardization. NOTE: the per-image min-max rescale
        # below cancels any positive affine map, so with std > 0 this step does
        # not change the result beyond floating-point rounding.
        img0 = (img0 - self.mean) / self.std
        # Per-image rescale to [0, 1]. A constant image (e.g. a blank patch) has
        # zero range and would become 0/0 = NaN, so it is mapped to zeros instead.
        lo, hi = img0.min(), img0.max()
        if hi > lo:
            img0 = (img0 - lo) / (hi - lo)
        else:
            img0 = torch.zeros_like(img0)
        img0 = img0.to(torch.float32)

        if self.test_mode:
            return img0

        # Apply a random transformation to the image and get the transformation matrix
        img1, t_true = self.random_augmentation_with_matrix(img0)
        img1 = img1.to(torch.float32)

        return img0, img1, t_true

    def random_augmentation_with_matrix(self, img):
        """Randomly flip, rotate and rescale ``img``; return it with its flattened 3x3 matrix.

        ``img`` is a channels-first tensor. It is converted to PIL and back,
        and the returned tensor has the same height and width as the input.
        """
        transform_params = {}
        original_size = img.size()
        img = to_pil_image(img)

        if random.random() > 0.5:
            img = FT.hflip(img)
            transform_params['flip'] = True

        angle = random.uniform(-30, 30)
        img = FT.rotate(img, angle)
        transform_params['rotation'] = angle

        scale_factor = random.uniform(0.8, 1.0)
        # PIL's .size is (width, height) while FT.resize expects [height, width].
        width, height = img.size
        img = FT.resize(img, [int(height * scale_factor), int(width * scale_factor)])
        transform_params['scale'] = scale_factor

        # Resize back to the original (H, W). NOTE(review): this undoes the
        # geometric effect of the scaling, so the recorded scale only shows up
        # in the image as resampling blur.
        img = FT.resize(img, original_size[1:])
        img = FT.to_tensor(img)

        t_true = get_transformation_matrix(transform_params)
        t_true = torch.tensor(t_true, dtype=torch.float32).view(-1)  # flattened to 9 values

        return img, t_true

    def compute_metadata(self):
        """Compute scalar ``mean`` / ``std`` tensors from ``images_array`` and print them.

        ``mean`` equals the mean over all pixels. ``std`` is the average of the
        per-pixel standard deviations across images. That is not the global
        standard deviation, but it is kept as-is so previously saved metadata
        stays comparable.
        """
        flattened_data = self.images_array.reshape(self.images_array.shape[0], -1)
        means = np.mean(flattened_data, axis=0)
        stds = np.std(flattened_data, axis=0)

        # Collapse the per-pixel statistics into single scalars.
        mean = np.mean(means)
        std = np.mean(stds)

        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

        print(f"Mean: {self.mean} - Std: {self.std}")

    def get_metadata(self):
        """Return the ``(mean, std)`` pair used for normalization."""
        return self.mean, self.std

    def __len__(self):
        return len(self.images_array)


In [9]:
import numpy as np

def train_test_split_arrays(*arrays, test_size=0.2, random_state=None):
    """
    Split numpy arrays along the first axis into random train and test subsets.

    All arrays are split with the same shuffled indices, so rows stay aligned.

    Parameters:
    *arrays : array-like with a ``.shape`` attribute
        Arrays to be split. All arrays must have the same size along the first axis.
    test_size : float or int, default=0.2
        If float, must be between 0.0 and 1.0; it is the proportion of the
        dataset to include in the test split (rounded down).
        If int (numpy integers included), it is the absolute number of test samples.
        The test split always receives at least one sample.
    random_state : int, None or numpy.random.RandomState, default=None
        Seed for a new ``RandomState``, or an existing ``RandomState`` instance
        to use (and advance) directly.

    Returns:
    tuple
        ``(train_arrays, test_arrays)``, each a tuple with one entry per input array.

    Raises:
    ValueError
        If no arrays are given, their first-axis sizes differ or are zero, or
        ``test_size`` is out of range or of an unsupported type.
    """
    import numbers

    if not arrays:
        raise ValueError("At least one array must be provided.")

    # Check if all arrays have the same size along the first axis
    first_axis_lengths = [arr.shape[0] for arr in arrays]
    if len(set(first_axis_lengths)) != 1:
        raise ValueError("All input arrays must have the same size along the first axis.")
    n_samples = first_axis_lengths[0]
    if n_samples <= 0:
        raise ValueError("The size of the first axis should be greater than 0.")

    # Determine the size of the test set. Integral is checked first because
    # every integer is also a Real number.
    if isinstance(test_size, numbers.Integral):
        if test_size < 0 or test_size > n_samples:
            raise ValueError("test_size should be a positive integer less than or equal to the size of the first axis.")
    elif isinstance(test_size, numbers.Real):
        if not 0.0 <= test_size <= 1.0:
            raise ValueError("test_size as a float should be between 0.0 and 1.0.")
        test_size = int(test_size * n_samples)
    else:
        raise ValueError("test_size should be either float or int.")

    # Guarantee a non-empty test split (e.g. int(0.2 * 3) == 0).
    test_size = max(int(test_size), 1)

    # Use a provided generator as-is; otherwise seed a fresh one.
    if isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)
    indices = np.arange(n_samples)
    rng.shuffle(indices)
    test_indices = indices[:test_size]
    train_indices = indices[test_size:]

    train_arrays = tuple(arr[train_indices] for arr in arrays)
    test_arrays = tuple(arr[test_indices] for arr in arrays)

    return train_arrays, test_arrays

In [None]:
# Load the CAS arrays stored directly under save_embeddings_path: images and
# masks, plus the "_empty" variants (presumably patches without vessel
# pixels -- confirm against the script that wrote these files).
X_CAS = np.load(os.path.join(save_embeddings_path,f'X_test_CAS_all.npy'))
X_CAS_mask = np.load(os.path.join(save_embeddings_path,f'X_test_mask_CAS_all.npy'))

X_CAS_empty = np.load(os.path.join(save_embeddings_path,f'X_test_empty_CAS_all.npy'))
X_CAS_mask_empty = np.load(os.path.join(save_embeddings_path,f'X_test_empty_mask_CAS_all.npy'))

print(f"Data loaded")

print(f"X_test shape: {X_CAS.shape}")
print(f"X_test_mask shape: {X_CAS_mask.shape}")

print(f"X_test_empty shape: {X_CAS_empty.shape}")
print(f"X_test_empty_mask shape: {X_CAS_mask_empty.shape}")


In [None]:
# Stack the regular and "_empty" sets along axis 0 (the other dimensions
# must match for np.concatenate to succeed).
X_CAS_tot = np.concatenate((X_CAS, X_CAS_empty), axis=0)
X_CAS_mask_tot = np.concatenate((X_CAS_mask, X_CAS_mask_empty), axis=0)

print(f"X_test_tot shape: {X_CAS_tot.shape}")
print(f"X_test_mask_tot shape: {X_CAS_mask_tot.shape}")


In [None]:
# From here on X_CAS / X_CAS_mask refer to the combined (regular + "_empty")
# arrays; X_CAS is what the embedding-extraction cell encodes.
X_CAS = X_CAS_tot
X_CAS_mask = X_CAS_mask_tot

print(f"X_test shape: {X_CAS.shape}")
print(f"X_test_mask shape: {X_CAS_mask.shape}")

### Load

In [None]:
# Load the train / validation splits (images and masks) saved in the SOTA sub-folder.
X_train = np.load(os.path.join(save_embeddings_path_SOTA,f'X_train_CAS_SOTA.npy'))
X_train_mask = np.load(os.path.join(save_embeddings_path_SOTA,f'X_train_mask_CAS_SOTA.npy'))

X_val = np.load(os.path.join(save_embeddings_path_SOTA,f'X_val_CAS_SOTA.npy'))
X_val_mask = np.load(os.path.join(save_embeddings_path_SOTA,f'X_val_mask_CAS_SOTA.npy'))

# Bare expression: the notebook displays the four shapes as the cell output.
X_train.shape, X_train_mask.shape, X_val.shape, X_val_mask.shape


In [None]:
if not just_sampling:
    import torchvision.transforms as transforms

    # The training dataset computes the normalization statistics (metadata=None).
    siamese_dataset = AETDataset(images_array=X_train,
                                            transform=transforms.ToTensor(),)

    # (mean, std) pair of 0-d tensors; np.save stores it as a 2-element array.
    # NOTE(review): saved under save_embeddings_path, not the SOTA sub-folder
    # used by this notebook's other outputs -- confirm it cannot overwrite
    # another model's metadata_CAS.npy.
    metadata = siamese_dataset.get_metadata()
    np.save(os.path.join(save_embeddings_path,f'metadata_CAS.npy'), metadata)

    # The validation set reuses the training statistics.
    siamese_val_dataset = AETDataset(images_array=X_val,
                                                transform=transforms.ToTensor(),
                                                metadata=metadata)

### Load metadata

In [None]:
# Reload the saved (mean, std) pair; AETDataset unpacks it with
# `self.mean, self.std = metadata`.
metadata = np.load(os.path.join(save_embeddings_path,f'metadata_CAS.npy'))
print(metadata)

### skip

In [None]:
from torch.utils.data import DataLoader
import torchvision

if not just_sampling:
    # batch_size=1 and no workers: only a handful of items are drawn for display.
    vis_dataloader = DataLoader(siamese_dataset,
                                shuffle=True,
                                num_workers=0,
                                batch_size=1)

    print(f"Plotting examples")

    plt.close('all')
    # imshow() opens its own figure, so no extra (empty) plt.figure() is created here.
    for i, example_batch in enumerate(vis_dataloader):
        # Batches are (img0, img1, t_true): images are (1, C, H, W) and
        # t_true is (1, 9) -- C presumably 1 for 2-D slices, confirm.
        print(example_batch[0].shape, example_batch[1].shape, example_batch[2].shape)

        # Original and augmented view stacked along the batch dim: (2, C, H, W).
        concatenated = torch.cat((example_batch[0], example_batch[1]), 0)
        imshow(torchvision.utils.make_grid(concatenated))
        if i == 3:
            break
        
          

### Load model

In [18]:
import torch
from torchvision import transforms
import torchvision.transforms.functional as FT
from PIL import Image
import numpy as np

class Mapper:
    """Randomly flip, rotate and crop each image of a batch, recording a 3x3 matrix per image.

    The matrices are built from the sampled parameters and returned flattened
    (9 values per image). AETModel uses them as the decoder's regression target.

    NOTE(review): the matrices are not the exact pixel mapping. They are
    composed in application order (flip @ rotation @ crop), the rotation is
    taken about the origin while ``FT.rotate`` rotates about the image center,
    and the flip translation uses ``output_size`` rather than
    ``output_size - 1``. They are kept as-is so targets remain consistent with
    already-trained checkpoints.
    """

    def __init__(self, output_size=32):
        # Side length of the square output images (and the value used in the
        # flip / crop matrices). Defaults to the 32x32 size the Encoder expects.
        self.output_size = output_size
        # Equivalent torchvision pipeline, kept for reference. __call__ does
        # not use it: it applies the same steps manually to record their parameters.
        self.transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(30),
            transforms.RandomResizedCrop(output_size, scale=(0.8, 1.0))
        ])

    def __call__(self, x):
        """
        Apply a random flip / rotation / resized crop to each image in the batch.

        Parameters:
        x (torch.Tensor): batch of images, e.g. shape (batch_size, 1, 32, 32).

        Returns:
        tuple: ``(images, matrices)``. ``images`` has shape
        (batch_size, C, output_size, output_size) and ``matrices`` is a float32
        (batch_size, 9) tensor. Both are placed on ``x.device``.
        """
        size = self.output_size
        transformed_images = []
        transformation_matrices = []

        for sample_idx in range(x.size(0)):
            img = FT.to_pil_image(x[sample_idx].cpu())
            matrix = np.eye(3)

            # Random horizontal flip (p = 0.5)
            if torch.rand(1).item() > 0.5:
                img = FT.hflip(img)
                matrix = np.dot(matrix, np.array([[-1, 0, size], [0, 1, 0], [0, 0, 1]]))

            # Random rotation in [-30, 30] degrees
            angle = torch.empty(1).uniform_(-30, 30).item()
            img = FT.rotate(img, angle)
            rad = np.deg2rad(angle)
            rotation_matrix = np.array([
                [np.cos(rad), -np.sin(rad), 0],
                [np.sin(rad), np.cos(rad), 0],
                [0, 0, 1]
            ])
            matrix = np.dot(matrix, rotation_matrix)

            # Random square crop covering 80-100% of the area, resized to size x size.
            # (Named top/left so the batch index above is not shadowed.)
            top, left, h, w = transforms.RandomResizedCrop.get_params(img, scale=(0.8, 1.0), ratio=(1.0, 1.0))
            img = FT.resized_crop(img, top, left, h, w, (size, size))
            crop_matrix = np.array([
                [size / w, 0, -size * left / w],
                [0, size / h, -size * top / h],
                [0, 0, 1]
            ])
            matrix = np.dot(matrix, crop_matrix)

            transformed_images.append(FT.to_tensor(img))
            transformation_matrices.append(matrix.flatten())

        matrices = np.array(transformation_matrices, dtype=np.float32)

        # Return on the input's device instead of a hard-coded .cuda().
        return torch.stack(transformed_images).to(x.device), torch.from_numpy(matrices).to(x.device)

# Example usage:
# mapper = Mapper()
# transformed_batch, transformation_matrices = mapper(input_batch)


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

# Define the Encoder
class Encoder(nn.Module):
    """Three conv -> ReLU -> 2x2 max-pool -> dropout stages, then a linear projection.

    ``fc`` takes 256 * 4 * 4 features, i.e. a single-channel 32x32 input whose
    spatial size is halved three times (32 -> 16 -> 8 -> 4). Dropout is also
    applied to the final latent vector.
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        # Attribute names double as state_dict keys of saved checkpoints.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.fc = nn.Linear(256 * 4 * 4, latent_dim)
        # A single Dropout(p=0.5) module shared by every stage.
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Map a (batch, 1, 32, 32) tensor to a (batch, latent_dim) tensor."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.dropout(F.max_pool2d(F.relu(conv(x)), 2))
        flat = x.view(x.size(0), -1)
        return self.dropout(self.fc(flat))

# Define the Decoder
class Decoder(nn.Module):
    """MLP that regresses a flattened 3x3 matrix from a pair of latent codes.

    Layer sizes: 2 * latent_dim -> latent_dim -> latent_dim // 2 -> 9.
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        width = latent_dim
        self.fc1 = nn.Linear(2 * width, width)
        self.fc2 = nn.Linear(width, width // 2)
        self.fc3 = nn.Linear(width // 2, 3 * 3)

    def forward(self, z, z_t):
        """Concatenate ``z`` and ``z_t`` along dim 1 and return 9 values per row."""
        hidden = torch.cat((z, z_t), dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = F.relu(layer(hidden))
        return self.fc3(hidden)

# Define the full AET Model
class AETModel(nn.Module):
    """Encoder + Mapper + Decoder wired for transformation prediction.

    ``forward(x)`` encodes ``x`` and transforms it with ``Mapper``, which
    also returns the flattened matrices. It then encodes the transformed
    batch and lets the decoder predict the matrix from both codes.
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        self.encoder = Encoder(latent_dim=latent_dim)
        self.decoder = Decoder(latent_dim=latent_dim)
        # Mapper is a plain object, not an nn.Module, so it adds no parameters.
        self.mapper = Mapper()

    def forward(self, x, return_latent=False):
        """Return ``z`` when ``return_latent`` is True, else ``(t_pred, x_t, true_t)``."""
        z = self.encoder(x)
        if not return_latent:
            x_t, true_t = self.mapper(x)
            t_pred = self.decoder(z, self.encoder(x_t))
            return t_pred, x_t, true_t
        return z


In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformationLoss(nn.Module):
    """
    MSE between predicted and true flattened transformation matrices.

    Every call also adds the batch loss to a running total, so ``get_mean_loss``
    can report the average over the calls made since the last reset.
    ``num_samples`` counts calls (batches), not individual samples.
    """

    def __init__(self):
        super().__init__()
        self.reset()

    def forward(self, t_pred, t_true):
        """Return ``F.mse_loss(t_pred, t_true)`` and record its value."""
        batch_loss = F.mse_loss(t_pred, t_true)
        self.loss_accumulator += batch_loss.item()
        self.num_samples += 1
        return batch_loss

    def get_accumulated_loss(self):
        """Return the sum of recorded batch losses since the last reset."""
        return self.loss_accumulator

    def get_mean_loss(self):
        """Return the average recorded batch loss and clear the running totals.

        Returns 0.0 (and leaves the state untouched) when nothing was recorded.
        """
        if not self.num_samples:
            return 0.0
        average = self.loss_accumulator / self.num_samples
        self.reset()
        return average

    def reset(self):
        """Clear the running total and the call counter."""
        self.loss_accumulator = 0.0
        self.num_samples = 0


### skip

In [21]:
if not just_sampling:
    # Shuffled training batches of 128; fixed-order validation batches of 256.
    train_dataloader = DataLoader(siamese_dataset,
                            shuffle=True,
                            num_workers=8,
                            batch_size=128)

    val_dataloader = DataLoader(siamese_val_dataset,
                            shuffle=False,
                            num_workers=8,
                            batch_size=256)

### Training (skip)

In [22]:
load_model = False  # True: initialize the network from best_model_AET.pt before training
start_epoch=0  # first value of the epoch range iterated by the training loop
latent_dim=128  # size of the encoder output, passed to AETModel(latent_dim)

In [23]:
if not just_sampling:
    net = AETModel(latent_dim).cuda()

    if load_model:
        # Resume from the best checkpoint written by the training loop.
        checkpoint_path = os.path.join(save_embeddings_path_SOTA, 'best_model_AET.pt')
        print(f"Loading model from {checkpoint_path}")
        # Load state dictionary into model
        net.load_state_dict(torch.load(checkpoint_path))
        start_epoch= 0


    criterion = TransformationLoss()

    # Early-stopping bookkeeping; best_loss starts from a large sentinel value.
    best_loss = 9999
    best_loss_epoch = 0
    cumul_epochs = 0

    # Per-epoch histories. Despite the "contrastive" names they hold the
    # transformation (MSE) loss.
    mean_loss_contrastive_list = []
    best_loss_contrastive_list = []
    validation_loss_list = []
    # Set to True by the "Plot losses" cell after it reloads the histories from disk.
    continue_training = False

In [None]:
from torch import optim

if not just_sampling:
    n_epochs = 500
    early_stopping_tolerance = 50
    # NOTE(review): lr=0.05 is far above Adam's usual 1e-3; kept as configured,
    # but check it first if the loss diverges or plateaus early.
    optimizer = optim.Adam(net.parameters(), lr=0.05)

    if continue_training:
        # After the "Plot losses" cell these are numpy arrays (np.load). list()
        # also accepts plain lists, so re-running this cell no longer fails.
        mean_loss_contrastive_list = list(mean_loss_contrastive_list)
        best_loss_contrastive_list = list(best_loss_contrastive_list)
        validation_loss_list = list(validation_loss_list)

    print(f"Starting round of training from epoch {cumul_epochs}")

    for epoch in tqdm(range(start_epoch, n_epochs), desc='Epochs'):
        # ---- Training ----
        net.train()
        # The criterion keeps a running total across calls, and the validation
        # loop below calls it too. Reset here so this epoch's training mean
        # covers only this epoch's training batches.
        criterion.reset()
        for data in tqdm(train_dataloader, desc='Batches', leave=False):
            # The dataset also yields its own augmented view and matrix, but the
            # model builds the transformed view and its target itself (Mapper),
            # so only the original images are moved to the GPU.
            img0, _dataset_img1, _dataset_matrix = data
            img0 = img0.cuda()

            optimizer.zero_grad()
            t_pred, x_t, matrix = net(img0)
            loss_transformation = criterion(t_pred, matrix)
            loss_transformation.backward()
            optimizer.step()

        # Mean transformation loss over this epoch's training batches.
        mean_loss_transformation = criterion.get_mean_loss()
        mean_loss_contrastive_list.append(mean_loss_transformation)

        # ---- Validation ----
        # NOTE: Mapper samples new random transformations on every call, so the
        # validation loss is itself stochastic.
        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            for val_data in tqdm(val_dataloader, desc='Validation', leave=False):
                val_img0, _val_dataset_img1, _val_dataset_matrix = val_data
                val_t_pred, val_x_t, val_matrix = net(val_img0.cuda())
                val_loss += criterion(val_t_pred, val_matrix).item()

        val_loss /= len(val_dataloader)
        validation_loss_list.append(val_loss)

        # Save on improvement; otherwise stop once the best epoch is more than
        # early_stopping_tolerance epochs old.
        if val_loss < best_loss:
            print(f"Epoch number {cumul_epochs} --- Best loss {val_loss}")
            best_loss = val_loss
            best_loss_epoch = cumul_epochs
            torch.save(net.state_dict(), os.path.join(save_embeddings_path_SOTA, 'best_model_AET.pt'))
        else:
            if cumul_epochs - best_loss_epoch > early_stopping_tolerance:
                print(f"Early stopping at epoch {cumul_epochs}")
                best_loss_contrastive_list.append(best_loss)
                cumul_epochs += 1
                break

        best_loss_contrastive_list.append(best_loss)
        cumul_epochs += 1

In [25]:
if not just_sampling:
    # Persist the per-epoch loss histories so the "Plot losses" cell can reload them.
    np.save(os.path.join(save_embeddings_path_SOTA, f'mean_loss_contrastive_list_CAS_AET.npy'), mean_loss_contrastive_list)
    np.save(os.path.join(save_embeddings_path_SOTA, f'best_loss_contrastive_list_CAS_AET.npy'), best_loss_contrastive_list)
    np.save(os.path.join(save_embeddings_path_SOTA, f'validation_loss_list_CAS_AET.npy'), validation_loss_list)


### Plot losses

In [None]:

# Reload the loss histories; np.load returns numpy arrays, not lists.
mean_loss_contrastive_list = np.load(os.path.join(save_embeddings_path_SOTA, f'mean_loss_contrastive_list_CAS_AET.npy'))
best_loss_contrastive_list = np.load(os.path.join(save_embeddings_path_SOTA, f'best_loss_contrastive_list_CAS_AET.npy'))
validation_loss_list = np.load(os.path.join(save_embeddings_path_SOTA, f'validation_loss_list_CAS_AET.npy'))

# Tells the training cell to turn these arrays back into lists before appending.
continue_training=True

%matplotlib inline
plt.ioff()
plt.close('all')
# show_plot already appends " Loss" to the title, so pass only the series
# name. Each series is plotted against its own length so x always matches y.
show_plot(range(len(mean_loss_contrastive_list)), mean_loss_contrastive_list, title='Training')
show_plot(range(len(validation_loss_list)), validation_loss_list, title='Val')
show_plot(range(len(best_loss_contrastive_list)), best_loss_contrastive_list, title='Best Val')


### LOAD MODEL and Test (skip)

In [27]:
import torch.nn.functional as F
# LOAD BEST MODEL
if not just_sampling:
    net = AETModel(latent_dim).cuda()
    checkpoint_path = os.path.join(save_embeddings_path_SOTA, 'best_model_AET.pt')

    # Load state dictionary into model
    net.load_state_dict(torch.load(checkpoint_path))
    # eval() disables the encoder's dropout for the embedding extraction that follows.
    net.eval()

### Skip

In [None]:
from torch.autograd import Variable 
from tqdm import tqdm

if not just_sampling:
    siamese_testset = AETDataset(images_array=X_CAS,
                                 transform=transforms.ToTensor(),
                                 test_mode=True,
                                 metadata=metadata)

    # Inference only, so batching is much faster than batch_size=1. The
    # encoder has no batch-dependent layers (conv / pool / dropout / linear,
    # with dropout off in eval mode -- assumes net.eval() from the cell above),
    # so each image's embedding does not depend on the batch it is in.
    test_dataloader = DataLoader(siamese_testset,
                                 shuffle=False,
                                 batch_size=256)

    # Get the total number of images
    num_images = len(siamese_testset)
    if num_images == 0:
        raise ValueError("X_CAS is empty: there are no embeddings to compute.")

    img_embeddings = None  # allocated once the latent size is known
    offset = 0
    # No autograd graph is needed for feature extraction.
    with torch.no_grad():
        for img in tqdm(test_dataloader):
            batch_out = net(img.cuda(), return_latent=True).cpu().numpy()
            batch_out = batch_out.reshape(batch_out.shape[0], -1)
            if img_embeddings is None:
                print(f"Latent vector size: {batch_out.shape[1]}")
                img_embeddings = np.empty((num_images, batch_out.shape[1]))
            img_embeddings[offset:offset + batch_out.shape[0]] = batch_out
            offset += batch_out.shape[0]

    np.save(os.path.join(save_embeddings_path_SOTA, f'img_embeddings_CAS_AET.npy'), img_embeddings)
    print(f"Image Embeddings saved (shape {img_embeddings.shape})")

### Load img embeddings

In [None]:
img_embeddings = np.load(os.path.join(save_embeddings_path_SOTA, f'img_embeddings_CAS_AET.npy'))

print(f"img_embeddings loaded from {save_embeddings_path_SOTA}")
print(f'img_embeddings shape: {img_embeddings.shape}')

# Count distinct embedding rows before and after the float32 cast, to see
# whether the cast merges any points.
print(f"Unique coordinates: {len(np.unique(img_embeddings, axis=0))} (Over {img_embeddings.shape[0]} total points)")
# Convert to float32
img_embeddings = img_embeddings.astype(np.float32)
print(f"Unique coordinates: {len(np.unique(img_embeddings, axis=0))} (Over {img_embeddings.shape[0]} total points)")
print(f'img_embeddings type: {type(img_embeddings)}, dtype: {img_embeddings.dtype}')

In [None]:
def plot_random_vectors(array, expected_dim=128):
    """
    Pick two distinct random rows of a 2-D array and plot them side by side.

    Each row is drawn as a line plot (value vs. dimension index), which is a
    quick visual check of how two embedding vectors compare.

    Parameters:
    array (array-like): A 2-D array of shape (n, d) with n >= 2.
    expected_dim (int or None): Required number of columns d. Defaults to 128
        (the previously hard-coded width); pass None to accept any width, e.g.
        when the encoder's latent size differs.

    Raises:
    ValueError: if `array` is not 2-D, has a column count different from
        `expected_dim`, or has fewer than two rows to sample from.
    """
    array = np.asarray(array)

    # Validate up front so callers get a clear message instead of an
    # IndexError (1-D input) or a cryptic np.random.choice error (< 2 rows).
    if array.ndim != 2:
        raise ValueError(f"The input array must be 2-D, got shape {array.shape}.")
    if expected_dim is not None and array.shape[1] != expected_dim:
        raise ValueError(f"The input array must have {expected_dim} columns.")
    if array.shape[0] < 2:
        raise ValueError("The input array must have at least 2 rows.")

    # Two distinct row indices (replace=False guarantees they differ).
    indices = np.random.choice(array.shape[0], size=2, replace=False)

    _, axs = plt.subplots(1, 2, figsize=(12, 4))
    for slot, (ax, idx) in enumerate(zip(axs, indices), start=1):
        ax.plot(array[idx])
        ax.set_title(f'Vector {slot} (Index: {idx})')

    plt.show()
    

# Visual sanity check: plot two randomly chosen embedding vectors side by side.
# plot_random_vectors raises ValueError unless img_embeddings has 128 columns.
plot_random_vectors(img_embeddings)

### Plot image embeddings

### t-SNE (skip)

In [31]:
# NOTE(review): `.all()` reduces each row to a single bool ("are all entries
# non-zero?"), so this only compares those two bools -- it does NOT compare the
# embedding values of rows 0 and 1. If the intent is "rows are identical" use
# np.array_equal(img_embeddings[0], img_embeddings[1]); if it is "rows differ",
# negate that. Confirm the intended check before relying on this assert.
assert img_embeddings[0,:].all() == img_embeddings[1,:].all()

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

if not just_sampling:
    # 2-D t-SNE projection of the embeddings with scikit-learn defaults apart
    # from a fixed seed. Candidate settings noted earlier but not applied:
    # perplexity=500, early_exaggeration=40.
    tsne = TSNE(
        n_components=2,
        random_state=42,
    )
    img_embeddings_tsne = tsne.fit_transform(img_embeddings)

    # Cache the projection so the "Load" cell can reuse it without refitting.
    np.save(
        os.path.join(save_embeddings_path_SOTA, 'img_embeddings_tsne_CAS_AET.npy'),
        img_embeddings_tsne,
    )
    print(
        "TSNE img_embeddings saved in "
        f"{os.path.join(save_embeddings_path_SOTA, 'img_embeddings_tsne_CAS_AET.npy')}"
    )
    

### Load

In [None]:
# Reload the cached 2-D t-SNE coordinates.
img_embeddings_tsne = np.load(
    os.path.join(save_embeddings_path_SOTA, 'img_embeddings_tsne_CAS_AET.npy')
)
print("TSNE img_embeddings loaded")
print(f"TSNE img_embeddings shape: {img_embeddings_tsne.shape}")

In [None]:
# Report how many t-SNE points have distinct coordinates. np.unique(axis=0)
# sorts the whole array, so compute the distinct-row count once and reuse it.
n_unique_tsne = len(np.unique(img_embeddings_tsne, axis=0))
n_points_tsne = img_embeddings_tsne.shape[0]
print(f"Unique coordinates: {n_unique_tsne} (Over {n_points_tsne} total points)")
print(f"Percentage of unique coordinates: {n_unique_tsne/n_points_tsne*100:.4f}%")

In [35]:
if not just_sampling:
    # Plain scatter of the 2-D t-SNE projection, one tiny marker per image.
    plt.figure()
    plt.scatter(*img_embeddings_tsne[:, :2].T, s=1)
    plt.title('Image Clustering Results (t-SNE img_embedding)')

In [None]:
%matplotlib widget
interactive_plot(img_embeddings_tsne[:, 0], img_embeddings_tsne[:, 1], colors=None, action='click', img_list=X_CAS_mask, emb_list=img_embeddings, true_img_list=X_CAS)

In [None]:
# Inspect a hand-picked set of sample indices: X_CAS_mask is passed as img_np
# and X_CAS as true_vessel_np. The indices are hard-coded for manual inspection.
plot_n_patches_overlap(X_CAS_mask, X_CAS, indexes=[5597,24,25,26,23])

In [38]:
# Number of K-means clusters used by both K-means cells below
# (on the t-SNE coordinates and on the raw embeddings).
n_clusters_img = 50

### K-means on the t-SNE coordinates (skip)

In [None]:
from sklearn.cluster import KMeans 

if not just_sampling:
    # K-means on the 2-D t-SNE coordinates (not on the raw embeddings).
    kmeans = KMeans(n_clusters=n_clusters_img, random_state=42, n_init='auto')
    kmeans.fit(img_embeddings_tsne)

    # One integer label per image; cache it so the "Load" cell can reuse it.
    img_cluster_labels_2d = kmeans.labels_
    print(f"Classes: {set(img_cluster_labels_2d)}")
    np.save(os.path.join(save_embeddings_path_SOTA, 'img_cluster_labels_2d_CAS_AET.npy'), img_cluster_labels_2d)

    # Histogram of cluster sizes. Bin edges sit at k - 0.5 so every integer
    # label gets its own bar; matplotlib's default of 10 bins would lump
    # several clusters (5 each for 50 clusters) into one bar.
    plt.figure(figsize=(10, 5))
    plt.hist(img_cluster_labels_2d, bins=np.arange(n_clusters_img + 1) - 0.5)
    plt.xlabel('Cluster Label')
    plt.ylabel('Count')
    plt.title('Histogram of Cluster Labels')
    plt.show()

### Load

In [None]:
# Reload the cached K-means labels computed on the t-SNE coordinates.
img_cluster_labels_2d = np.load(
    os.path.join(save_embeddings_path_SOTA, 'img_cluster_labels_2d_CAS_AET.npy')
)
print("Cluster labels loaded")
print(f"Cluster labels shape: {img_cluster_labels_2d.shape}")

In [None]:
# Same interactive t-SNE view, now coloured by the t-SNE-space K-means labels.
interactive_plot(
    img_embeddings_tsne[:, 0],
    img_embeddings_tsne[:, 1],
    colors=img_cluster_labels_2d,
    action='click',
    img_list=X_CAS_mask,
    emb_list=img_embeddings,
    true_img_list=X_CAS,
)

### K-means on the raw embedding vectors (skip)

In [None]:
from sklearn.cluster import KMeans
if not just_sampling:
    # K-means directly on the embedding vectors, using the same cluster count
    # and seed as the clustering of the t-SNE coordinates.
    kmeans = KMeans(
        n_clusters=n_clusters_img,
        n_init='auto',
        random_state=42,
    )
    img_cluster_labels = kmeans.fit(img_embeddings).labels_
    print(f"Classes: {set(img_cluster_labels)}")

    # Cache the labels so the "Load" cell can reuse them.
    np.save(
        os.path.join(save_embeddings_path_SOTA, 'img_cluster_labels_CAS_AET.npy'),
        img_cluster_labels,
    )

### Load

In [None]:
# Reload the cached K-means labels computed on the raw embeddings.
img_cluster_labels = np.load(
    os.path.join(save_embeddings_path_SOTA, 'img_cluster_labels_CAS_AET.npy')
)
print("Cluster labels loaded")
print(f"Cluster labels shape: {img_cluster_labels.shape}")

In [None]:
# Compare the two clusterings on the t-SNE map: labels from K-means on the raw
# embeddings vs. labels from K-means on the t-SNE coordinates.
# NOTE(review): plot_2_clusters is defined elsewhere; presumably it draws the two
# colourings side by side -- confirm.
plot_2_clusters(img_embeddings_tsne, img_cluster_labels, img_cluster_labels_2d) 

### Sampling (no more skipped cells from here on)

In [45]:
import numpy as np
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import normalize

def calculate_density(features, k=5):
    """
    Score every sample by the sum of distances to its ``k`` nearest neighbours.

    The neighbour index is fitted on ``features`` and queried with the same
    ``features``, so each point is normally returned as its own neighbour at
    distance 0. Despite the function name, the score grows as local density
    falls: lower values mean the sample lies in a denser region.

    Parameters:
    features (array-like): (n_samples, n_features) matrix.
    k (int): Number of neighbours queried per sample.

    Returns:
    np.ndarray: One score per sample, shape (n_samples,).
    """
    knn_index = NearestNeighbors(n_neighbors=k).fit(features)
    neighbour_dists = knn_index.kneighbors(features)[0]
    return neighbour_dists.sum(axis=1)

def select_representative_samples(densities, num_samples):
    """
    Return the indices of the ``num_samples`` smallest values in ``densities``.

    With scores from ``calculate_density`` (summed neighbour distance), the
    smallest values correspond to the densest samples, listed densest first.
    If ``num_samples`` exceeds the number of scores, all indices are returned.
    """
    print(f"Selecting {num_samples} samples from {len(densities)} total samples")
    ranked = np.argsort(densities)
    return ranked[:num_samples]

def one_shot_AET(features, num_samples, k=5):
    """
    Select the samples that lie in the densest regions of the feature space.

    Steps: (1) interpret ``num_samples`` as a count or a fraction, (2) L2-normalise
    each row of ``features`` and score it with ``calculate_density`` (summed
    k-NN distance, lower = denser), (3) keep the lowest-scoring samples.

    Parameters:
    features (array-like): (n_samples, n_features) feature matrix.
    num_samples (int or float): How many samples to select. A value in (0, 1)
        is a fraction of ``len(features)`` (rounded down, so a very small
        fraction can select nothing); values >= 1 are an absolute count and are
        truncated to int.
    k (int): Neighbours used for the density score; capped at ``len(features)``
        so that small inputs do not make the neighbour search fail.

    Returns:
    np.ndarray: Indices into ``features`` of the selected samples, densest first.

    Raises:
    ValueError: if ``num_samples`` is not positive (a negative value would
        otherwise turn into a negative slice and silently drop the tail).
    """
    if num_samples <= 0:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    # Step 1: resolve num_samples into an integer count.
    n_total = len(features)
    if num_samples < 1:
        n_selected = int(num_samples * n_total)
        print(f"WARNING: num_samples parameter is <1 ({num_samples})")
        print(f"Taking {num_samples * 100:g}% of the {n_total} samples ==> {n_selected}")
        num_samples = n_selected
    else:
        # Float counts such as 100.0 would otherwise break the final slice.
        num_samples = int(num_samples)

    # Step 2: density score on row-normalised features.
    features = normalize(features)
    densities = calculate_density(features, min(k, n_total))

    # Step 3: keep the samples with the highest local density.
    return select_representative_samples(densities, num_samples)

def filter_by_index(idx, *arrays):
    """
    Lazily yield ``arr[idx]`` for each array in ``arrays``, in the given order.

    The result is a generator, intended to be unpacked or iterated once.
    """
    for arr in arrays:
        yield arr[idx]

def AET_sampling(features, *embedding_lists, n_size=1000, random_seed=42):
    """
    Choose representative samples with ``one_shot_AET`` on ``features`` and
    subset every array in ``embedding_lists`` with the chosen indices.

    ``random_seed`` is passed to ``np.random.seed`` first, which resets NumPy's
    global RNG as a side effect.
    NOTE(review): the selection path visible here (normalise, k-NN, argsort)
    does not appear to draw random numbers -- confirm whether the seed matters.

    Returns a generator yielding one filtered array per entry of
    ``embedding_lists``, in order.
    """
    np.random.seed(random_seed)
    chosen = one_shot_AET(features, n_size)
    return (arr[chosen] for arr in embedding_lists)


In [63]:
# Sample budget for AET_sampling: a value below 1 is treated by one_shot_AET as
# a fraction of the dataset, so this selects 5% of the samples.
n_size = 5/100

In [None]:
# Subset images, masks, embeddings, t-SNE coordinates and both label sets to the
# samples chosen by density-based selection on img_embeddings.
(
    X_test_AET,
    X_test_mask_AET,
    img_embeddings_AET,
    img_embeddings_tsne_AET,
    img_cluster_labels_AET,
    img_cluster_labels_2d_AET,
) = AET_sampling(
    img_embeddings,
    X_CAS,
    X_CAS_mask,
    img_embeddings,
    img_embeddings_tsne,
    img_cluster_labels,
    img_cluster_labels_2d,
    n_size=n_size,
    random_seed=42,
)

print("Selected samples filtered")
print(f"X_test_AET shape: {X_test_AET.shape}")
print(f"X_test_mask_AET shape: {X_test_mask_AET.shape}")
print(f"img_embeddings_AET shape: {img_embeddings_AET.shape}")
print(f"img_embeddings_tsne_AET shape: {img_embeddings_tsne_AET.shape}")
print(f"img_cluster_labels_AET shape: {img_cluster_labels_AET.shape}")
print(f"img_cluster_labels_2d_AET shape: {img_cluster_labels_2d_AET.shape}")
