## 0 - Import libraries

In [1]:
# When True, every cell guarded by `if not just_sampling` (training, embedding
# extraction, t-SNE, k-means) is skipped and only the unguarded cells, which
# reload previously saved artifacts from disk, do work.
just_sampling = False

In [2]:
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader

torch.backends.cudnn.benchmark = True

from tqdm import tqdm
import imageio.v2 as imageio

from utils import set_random_seed, mk_dir
from importlib import import_module

from torchvision.utils import save_image

In [3]:
import matplotlib.pyplot as plt
plt.ion()  # matplotlib interactive mode; several plot helpers below switch it off/on around figure creation

def save_img(tensor, name, norm, n_rows=16, scale_each=False):
    """Write the images in ``tensor`` as a single grid image to the file ``name``.

    Thin wrapper around ``torchvision.utils.save_image`` with a fixed 5-pixel
    padding filled with value 1 between tiles.

    Args:
        tensor: image batch accepted by ``save_image``.
        name: output file path.
        norm: forwarded as ``normalize``.
        n_rows: forwarded as ``nrow`` (images per grid row).
        scale_each: forwarded as ``scale_each``.
    """
    grid_options = dict(
        nrow=n_rows,
        padding=5,
        normalize=norm,
        pad_value=1,
        scale_each=scale_each,
    )
    save_image(tensor, name, **grid_options)
    
def plot_10_patches(img_np, true_vessel_np=False, indexes=None):
    """Plot slice pairs side by side, one row per selected index.

    Left column: ``img_np[index]`` in gray with a fixed [0, 1] display range.
    Right column: ``true_vessel_np[index]`` when it is provided, otherwise the
    next slice ``img_np[index + 1]``.

    Args:
        img_np: stack of 2D slices indexed along axis 0.
        true_vessel_np: optional stack aligned with ``img_np``; the default
            ``False`` means "not provided".
        indexes: slice indices to plot; defaults to 5 random indices.
    """
    random_indexes = np.random.randint(0, img_np.shape[0], size=5) if indexes is None else indexes
    n = len(random_indexes)
    if n == 0:
        return

    # squeeze=False keeps `ax` 2D even for a single row, so ax[i, j] always works.
    fig, ax = plt.subplots(n, 2, figsize=(6, 3 * n), squeeze=False)

    for i, index in enumerate(random_indexes):
        # Shift indices at the end of the stack back so that index + 1 (used by
        # the "next slice" column) stays inside the array.
        if index == img_np.shape[0]:
            index -= 1
        if index == img_np.shape[0] - 1:
            index -= 2

        # Titles now match the content (same labels as plot_n_patches).
        ax[i, 0].set_title(f'Image Slice {index}')
        ax[i, 0].imshow(img_np[index, :, :], cmap='gray', vmin=0, vmax=1)
        ax[i, 0].axis('off')

        if true_vessel_np is not False:
            vessel = true_vessel_np[index, :, :]
            print(f"Shape: {vessel.shape}, Max: {vessel.max()}, Min: {vessel.min()}")
            ax[i, 1].imshow(vessel, cmap='gray')
            ax[i, 1].set_title(f'True Vessel Slice {index}')
        else:
            ax[i, 1].set_title(f'Image Slice {index+1}')
            ax[i, 1].imshow(img_np[index + 1, :, :], cmap='gray', vmin=0, vmax=1)

        ax[i, 1].axis('off')

    plt.tight_layout()
    plt.show()

def plot_n_patches_overlap(img_np, true_vessel_np=False, indexes=None, selected_class=None, add_title='', m=5, alpha=0.5, save_dir='clusters_imgs'):
    """Overlay the non-zero pixels of ``img_np`` in red on top of ``true_vessel_np``.

    One panel per index, ``m`` panels per row. With a non-empty ``add_title``
    the figure is saved as
    ``<save_embeddings_path>/<save_dir>/<add_title>_class_<selected_class>_patches.png``
    and closed; otherwise it is shown. Interactive mode is switched off while
    the figure is built and switched back on afterwards, also if plotting fails.

    Args:
        img_np: stack of 2D slices (axis 0 = slice); zero pixels stay transparent.
        true_vessel_np: stack aligned with ``img_np`` used as gray background;
            required despite the ``False`` default.
        indexes: slice indices to plot; defaults to 5 random indices.
        selected_class: only used in the output file name.
        add_title: file-name prefix; the empty string means "show, don't save".
        m: number of panels per row.
        alpha: opacity of the red overlay.
        save_dir: sub-directory of the notebook-global ``save_embeddings_path``;
            ``mk_dir`` is called on it on every call.

    Raises:
        ValueError: if ``true_vessel_np`` is not provided.
    """
    if true_vessel_np is False:
        raise ValueError("plot_n_patches_overlap requires true_vessel_np")

    save_dir = os.path.join(save_embeddings_path, save_dir)
    mk_dir(save_dir)

    if indexes is None:
        indexes = np.random.randint(0, img_np.shape[0], size=5)
    n = len(indexes)
    if n == 0:
        return
    # m is the number of panels per row
    n_rows = int(np.ceil(n / m))
    n_cols = m

    plt.ioff()
    try:
        # squeeze=False keeps a 2D axes array even for a 1x1 grid, so flatten() always works.
        fig, ax = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
        ax = ax.flatten()

        for i, index in enumerate(indexes):
            ax[i].set_title(f'Overlay for Slice {index}')
            background = true_vessel_np[index, :, :]
            # Red layer: background values, masked out wherever img_np is zero.
            masked_image = np.ma.masked_where(img_np[index, :, :] == 0, background)
            ax[i].imshow(background, cmap='gray', interpolation='none')
            ax[i].imshow(masked_image, cmap='Reds', alpha=alpha)
            ax[i].axis('off')

        # Hide the empty panels of the last row when n is not a multiple of m.
        for unused_ax in ax[n:]:
            unused_ax.axis('off')

        try:
            plt.tight_layout()
        except Exception as err:  # best effort: the figure is still usable without it
            print(f"tight_layout failed: {err}")

        if add_title != '':
            print(f"Save Fig to {save_dir}")
            plt.savefig(os.path.join(save_dir, f'{add_title}_class_{selected_class}_patches.png'))
            plt.close(fig)
        else:
            print("Plotting..")
            plt.show()
    finally:
        plt.ion()

    
def plot_n_patches(img_np, true_vessel_np=False, indexes=None, selected_class=None, add_title='', m=5):
    """Plot slice pairs in a grid with ``m`` pairs (``2 * m`` panels) per row.

    For the i-th index, panel ``2*i`` shows ``img_np[index]`` and panel
    ``2*i + 1`` shows ``true_vessel_np[index]`` when provided, otherwise the
    next slice ``img_np[index + 1]``. With a non-empty ``add_title`` the
    figure is saved as ``<add_title>_class_<selected_class>_patches.png`` in
    the current working directory and closed; otherwise it is shown.
    Interactive mode is switched off while building the figure and switched
    back on afterwards, also if plotting fails.

    Args:
        img_np: stack of 2D slices indexed along axis 0.
        true_vessel_np: optional stack aligned with ``img_np``; the default
            ``False`` means "not provided".
        indexes: slice indices to plot; defaults to 5 random indices.
        selected_class: only used in the output file name.
        add_title: file-name prefix; the empty string means "show, don't save".
        m: number of pairs per row.
    """
    if indexes is None:
        indexes = np.random.randint(0, img_np.shape[0], size=5)
    n = len(indexes)
    if n == 0:
        return
    if n > 1000:
        print(f'WARNING: YOU ARE TRYING TO PLOT {n} images')
    # m is the number of images to plot in each row (2m is the number of columns)
    n_rows = int(np.ceil(n / m))
    n_cols = 2 * m

    plt.ioff()
    try:
        fig, ax = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
        ax = ax.flatten()

        for i, index in enumerate(indexes):
            # Shift indices at the end of the stack back so that index + 1
            # (used when true_vessel_np is absent) stays inside the array.
            if index == img_np.shape[0]:
                index -= 1
            if index == img_np.shape[0] - 1:
                index -= 2

            image = img_np[index, :, :]
            ax[2 * i].set_title(f'Image Slice {index}')
            ax[2 * i].imshow(image, cmap='gray', vmin=0, vmax=1)
            ax[2 * i].axis('off')
            print(f"Shape: {image.shape}, Max: {image.max()}, Min: {image.min()}")
            if true_vessel_np is not False:
                ax[2 * i + 1].imshow(true_vessel_np[index, :, :], cmap='gray')
                ax[2 * i + 1].set_title(f'True Vessel Slice {index}')
            else:
                ax[2 * i + 1].set_title(f'Image Slice {index+1}')
                ax[2 * i + 1].imshow(img_np[index + 1, :, :], cmap='gray', vmin=0, vmax=1)

            ax[2 * i + 1].axis('off')

        # Hide the empty panels of the last row when n is not a multiple of m.
        for unused_ax in ax[2 * n:]:
            unused_ax.axis('off')

        plt.tight_layout()
        if add_title != '':
            print("Save Fig")
            plt.savefig(f'{add_title}_class_{selected_class}_patches.png')
            plt.close(fig)
        else:
            print("Plotting..")
            plt.show()
    finally:
        plt.ion()
    
def reshape_to_square(vector):
    """Zero-pad ``vector`` and reshape it into the smallest square matrix that fits it.

    Args:
        vector: array-like; multi-dimensional input is flattened (row-major)
            first, so e.g. a 2D embedding also works.

    Returns:
        np.ndarray of shape (n, n) with ``n = ceil(sqrt(size))``; the cells
        after the original values are filled with 0.
    """
    flat = np.asarray(vector).ravel()

    # Side of the smallest square holding all elements.
    n = int(np.ceil(np.sqrt(flat.size)))

    # Pad the tail with zeros up to n*n elements, then reshape.
    num_zeros = n * n - flat.size
    vector_padded = np.pad(flat, (0, num_zeros), mode='constant')
    return vector_padded.reshape((n, n))

In [4]:
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import KDTree

def label_point(x, y, ids, ax):
    """Write each label of ``ids`` at its point ``(x[k], y[k])`` on the axes ``ax``."""
    for k, label in enumerate(ids):
        position = (x[k], y[k])
        ax.annotate(label, position)

def interactive_plot(x, y, ids=None, colors=None, action='click', img_list=None, emb_list=None, true_img_list=None ,zoom=False, filtered_ids=None):
    """Scatter (x, y) and show details of the point nearest to a click/hover.

    Panel 0 is the scatter. When an event in panel 0 lands within ``threshold``
    (1/80 of the larger data span) of a point ``i``, panel 1 shows
    ``img_list[int(ids[i])]``, panel 2 the embedding ``emb_list[int(ids[i])]``
    reshaped to a square, and panel 3 ``true_img_list[int(ids[i])]``. With
    ``zoom=True`` a 2x3 grid is used and panel 4 repeats the scatter limited to
    [-100, 100] on both axes.

    Args:
        x, y: point coordinates.
        ids: per-point identifiers convertible to int (row index into the
            lists); defaults to ``"0" .. "len(x)-1"``.
        colors: optional per-point colors (also shown in the panel-1 title).
        action: ``'click'`` (button press) or ``'hover'`` (mouse motion).
        img_list, emb_list, true_img_list: data shown for the selected point.
        zoom: add the zoomed scatter panel.
        filtered_ids: optional per-point labels used only in the panel titles;
            ``None`` or empty means "use ``ids``".
    """
    if ids is None:
        ids = [str(i) for i in range(len(x))]
    if filtered_ids is None:
        filtered_ids = []

    if zoom:
        fig, ax = plt.subplots(2, 3, figsize=(12, 8))
    else:
        fig, ax = plt.subplots(1, 4, figsize=(16, 4))
    ax = ax.flatten()

    for panel in ax[1:4]:
        panel.axis('off')

    if colors is not None:
        ax[0].scatter(x, y, s=1, c=colors)
    else:
        ax[0].scatter(x, y, s=1 if zoom else 0.1)

    if zoom:
        # Second scatter restricted to the central region.
        if colors is not None:
            ax[4].scatter(x, y, s=1, c=colors)
        else:
            ax[4].scatter(x, y, s=1)
        ax[4].set_xlim([-100, 100])
        ax[4].set_ylim([-100, 100])

    # An event selects a point only if it is closer than this to it.
    threshold = max(max(x) - min(x), max(y) - min(y)) / 80
    print(f"Threshold: {threshold}")
    tree = KDTree(np.column_stack((x, y)))

    def onclick(event):
        """Show the panels for the point nearest to the event, if close enough."""
        if event.inaxes is not ax[0]:
            return
        dist, i = tree.query([event.xdata, event.ydata])
        if dist >= threshold:
            return

        row = int(ids[i])
        id_img_title = ids[i] if len(filtered_ids) == 0 else filtered_ids[i]

        ax[1].imshow(img_list[row], cmap='gray')
        if colors is not None:
            ax[1].set_title(f'Image {id_img_title} (color: {colors[i]})')
        else:
            ax[1].set_title(f'Image {id_img_title}')
        ax[2].imshow(reshape_to_square(emb_list[row]), cmap='gray')
        ax[2].set_title(f'Embedding {id_img_title}')
        ax[3].imshow(true_img_list[row], cmap='gray')
        ax[3].set_title(f'True Image {id_img_title}')

        for panel in ax[1:4]:
            panel.axis('off')
        # Ask the (interactive) canvas to repaint so the new panels appear.
        fig.canvas.draw_idle()

    fig.tight_layout()
    if action == 'click':
        fig.canvas.mpl_connect('button_press_event', onclick)
    elif action == 'hover':
        fig.canvas.mpl_connect('motion_notify_event', onclick)

    plt.show()
    
def plot_2_clusters(embeddings_tsne, cluster_labels, cluster_labels_2d):
    """Show the same 2D t-SNE scatter twice, colored by two different labelings.

    Args:
        embeddings_tsne: array whose columns 0 and 1 are the t-SNE coordinates.
        cluster_labels: labels for the left panel.
        cluster_labels_2d: labels for the right panel.
    """
    xs = embeddings_tsne[:, 0]
    ys = embeddings_tsne[:, 1]
    panels = (
        (cluster_labels, 'Clustering Results (t-SNE embedding)'),
        (cluster_labels_2d, 'Clustering Results (t-SNE embedding 2D)'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for axis, (labels, title) in zip(axes, panels):
        axis.scatter(xs, ys, c=labels, s=1)
        axis.set_title(title)

    plt.show()

## 3) VAE NETWORK

In [None]:
# Root directory of the brain data and of every artifact this notebook reads/writes.
path_data = "/data/falcetta/brain_data"
# Directory holding the test arrays, metadata and embedding outputs.
save_embeddings_path = os.path.join(path_data, f"embeddings_VDISNET") 

# NOTE(review): not referenced in the cells shown; file names below hard-code 'CAS'.
dataset_name = 'CAS'

# 'SOTA' sub-directory: train/val splits, checkpoint, losses, embeddings and
# cluster labels of this model. mk_dir (from utils) presumably creates it if missing.
save_embeddings_path_SOTA = os.path.join(save_embeddings_path,'SOTA')
print(f"Save embeddings in {save_embeddings_path_SOTA}")
mk_dir(save_embeddings_path_SOTA)


In [6]:
%matplotlib inline

def imshow(img,text=None,should_save=False):
    """Display a channel-first image tensor in a new figure.

    Args:
        img: CPU tensor laid out as (C, H, W); it is transposed to (H, W, C)
            for matplotlib.
        text: optional caption drawn in a white box at data position (75, 8).
        should_save: accepted but not used.
    """
    channels_last = np.transpose(img.numpy(), (1, 2, 0))

    plt.figure()
    plt.axis("off")
    if text:
        caption_box = {'facecolor': 'white', 'alpha': 0.8, 'pad': 10}
        plt.text(75, 8, text, style='italic', fontweight='bold', bbox=caption_box)
    plt.imshow(channels_last)
    plt.show()

def show_plot(iteration,loss, title=''):
    """Line-plot ``loss`` against ``iteration`` in a new figure titled '<title> Loss'."""
    figure_title = f"{title} Loss"
    plt.figure()
    plt.plot(iteration, loss)
    plt.title(figure_title)
    plt.show()

In [7]:
import numpy as np
from PIL import Image

from torch.utils.data import Dataset
import torch
import random
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torch
import random

import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
import random

class ConvVAEDataset(Dataset):
    """Dataset of image patches prepared for the convolutional VAE.

    Each item is converted to a PIL image, passed through ``transform``
    (e.g. ``transforms.ToTensor()``), standardized with the dataset-level
    ``(mean, std)`` and finally min-max rescaled to [0, 1] per image. The last
    step keeps targets in the range required by the BCE reconstruction loss.

    Args:
        images_array: numpy array of patches indexed along axis 0 (presumably
            (N, H, W) so that ``Image.fromarray`` receives 2D slices; confirm).
        transform: optional callable applied to the PIL image. It must return
            a tensor for the normalization step to work.
        metadata: optional ``(mean, std)`` pair. When ``None`` it is computed
            from ``images_array`` (see :meth:`compute_metadata`).
        device: stored but not used by this class.
    """

    def __init__(self, images_array, transform=None, metadata=None, device='cuda'):
        self.images_array = images_array
        self.device = device
        self.transform = transform

        if metadata is None:
            self.compute_metadata()
        else:
            self.mean, self.std = metadata

    def __getitem__(self, index):
        img0 = Image.fromarray(self.images_array[int(index)])
        if self.transform is not None:
            img0 = self.transform(img0)
        return self._normalize(img0)

    def _normalize(self, img):
        """Standardize ``img`` with the dataset stats, then min-max rescale it to [0, 1].

        A constant image (zero value range, e.g. a pure-background patch) is
        mapped to all zeros. Otherwise the rescaling would compute 0 / 0 and
        return NaNs, which would poison the BCE loss.
        """
        img = (img - self.mean) / self.std
        img_min = img.min()
        value_range = img.max() - img_min
        if value_range == 0:
            return torch.zeros_like(img)
        return (img - img_min) / value_range

    def compute_metadata(self):
        """Set ``self.mean`` / ``self.std`` (0-dim tensors) from ``images_array``.

        ``mean`` is the average of the per-pixel means over all images, and
        ``std`` is the average of the per-pixel standard deviations.
        """
        flattened_data = self.images_array.reshape(self.images_array.shape[0], -1)
        means = np.mean(flattened_data, axis=0)
        stds = np.std(flattened_data, axis=0)

        self.mean = torch.tensor(np.mean(means))
        self.std = torch.tensor(np.mean(stds))

        print(f"Mean: {self.mean} - Std: {self.std}")

    def get_metadata(self):
        """Return the ``(mean, std)`` used for normalization."""
        return self.mean, self.std

    def __len__(self):
        return len(self.images_array)


In [None]:
# Load the CAS test patches with their masks: the main set plus a separate
# "empty" set (the file names suggest patches with empty masks; TODO confirm).
X_CAS = np.load(os.path.join(save_embeddings_path,f'X_test_CAS_all.npy'))
X_CAS_mask = np.load(os.path.join(save_embeddings_path,f'X_test_mask_CAS_all.npy'))

X_CAS_empty = np.load(os.path.join(save_embeddings_path,f'X_test_empty_CAS_all.npy'))
X_CAS_mask_empty = np.load(os.path.join(save_embeddings_path,f'X_test_empty_mask_CAS_all.npy'))

print(f"Data loaded")

print(f"X_test shape: {X_CAS.shape}")
print(f"X_test_mask shape: {X_CAS_mask.shape}")

print(f"X_test_empty shape: {X_CAS_empty.shape}")
print(f"X_test_empty_mask shape: {X_CAS_mask_empty.shape}")




In [None]:
# Stack the main and "empty" sets along the sample axis. Images and masks are
# concatenated in the same order, so index i refers to the same patch in both.
X_CAS_tot = np.concatenate((X_CAS, X_CAS_empty), axis=0)
X_CAS_mask_tot = np.concatenate((X_CAS_mask, X_CAS_mask_empty), axis=0)

print(f"X_test_tot shape: {X_CAS_tot.shape}")
print(f"X_test_mask_tot shape: {X_CAS_mask_tot.shape}")


In [None]:
# From here on X_CAS / X_CAS_mask denote the merged (main + empty) test set.
X_CAS = X_CAS_tot
X_CAS_mask = X_CAS_mask_tot

print(f"X_test shape: {X_CAS.shape}")
print(f"X_test_mask shape: {X_CAS_mask.shape}")

### Load

In [None]:
# Train/validation splits (images and masks) stored in the SOTA sub-directory.
X_train = np.load(os.path.join(save_embeddings_path_SOTA,f'X_train_CAS_SOTA.npy'))
X_train_mask = np.load(os.path.join(save_embeddings_path_SOTA,f'X_train_mask_CAS_SOTA.npy'))

X_val = np.load(os.path.join(save_embeddings_path_SOTA,f'X_val_CAS_SOTA.npy'))
X_val_mask = np.load(os.path.join(save_embeddings_path_SOTA,f'X_val_mask_CAS_SOTA.npy'))

# Last expression of the cell: the notebook displays the four shapes.
X_train.shape, X_train_mask.shape, X_val.shape, X_val_mask.shape


In [None]:
if not just_sampling:
    import torchvision.transforms as transforms

    # Training set: no metadata passed, so (mean, std) is computed from X_train.
    RA_dataset = ConvVAEDataset(images_array=X_train,
                                            transform=transforms.ToTensor(),)

    # Persist the training statistics so later cells/sessions reuse the same normalization.
    metadata = RA_dataset.get_metadata()
    np.save(os.path.join(save_embeddings_path,f'metadata_CAS.npy'), metadata)

    # Validation set normalized with the training statistics.
    RA_val_dataset = ConvVAEDataset(images_array=X_val,
                                                transform=transforms.ToTensor(),
                                                metadata=metadata)

### Load metadata

In [None]:
# Reload the saved (mean, std); this also works when just_sampling skipped the dataset cell.
metadata = np.load(os.path.join(save_embeddings_path,f'metadata_CAS.npy'))
print(metadata)

### skip

In [None]:
from torch.utils.data import DataLoader
import torchvision

if not just_sampling:
    # One shuffled sample per batch for a quick visual sanity check.
    vis_dataloader = DataLoader(RA_dataset,
                            shuffle=True,
                            num_workers=0,
                            batch_size=1)

    # NOTE(review): plot_same / plot_diff are not used in the visible cells.
    plot_same = True
    plot_diff = True

    print(f"Plotting examples")

    plt.close('all')
    # NOTE(review): this figure is never drawn into; each sample opens its own figure below.
    plt.figure()
    for en,example_batch in enumerate(vis_dataloader):
        plt.figure()
        # First sample of the batch, first channel (presumably an (H, W) slice for single-channel patches).
        plt.imshow(example_batch[0][0], cmap='gray')
        plt.axis('off')
        plt.show()
        
        # Stops after 7 samples (en = 0..6).
        if en > 5:
            break

### CONV-VAE

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

class ConvVAE(nn.Module):
    """Convolutional VAE for single-channel 32x32 patches.

    Four stride-2 convolutions reduce a (bs, 1, 32, 32) input to (bs, 256, 2, 2).
    Linear heads map the flattened features to ``mu`` / ``logvar`` of size
    ``latent_dim``. A mirrored transposed-conv decoder ending in a Sigmoid
    reconstructs a (bs, 1, 32, 32) image with values in [0, 1]. The input must
    be 32x32 because ``fc_mu``/``fc_logvar`` expect exactly 256 * 2 * 2 features.
    """

    def __init__(self, latent_dim=20):
        super(ConvVAE, self).__init__()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1), # [bs, 32, 16, 16]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), # [bs, 64, 8, 8]
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # [bs, 128, 4, 4]
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # [bs, 256, 2, 2]
            nn.ReLU(),
        )

        dim_last_conv = 256 * 2 * 2
        self.fc_mu = nn.Linear(dim_last_conv, latent_dim) # [bs, latent_dim]
        self.fc_logvar = nn.Linear(dim_last_conv, latent_dim) # [bs, latent_dim]
        self.fc_decode = nn.Linear(latent_dim, dim_last_conv) # [bs, 256 * 2 * 2]

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # [bs, 128, 4, 4]
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1), # [bs, 64, 8, 8]
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1), # [bs, 32, 16, 16]
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1), # [bs, 1, 32, 32]
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x, return_latent=False):
        """Encode ``x``, sample a latent code and (unless asked not to) decode it.

        Args:
            x: batch of shape (bs, 1, 32, 32).
            return_latent: if True, return only the sampled latent code.

        Returns:
            The sampled latent code z of shape (bs, latent_dim) when
            ``return_latent`` is set. Note that z is a random sample, not ``mu``.
            Otherwise ``(reconstruction, mu, logvar)`` with the reconstruction
            of shape (bs, 1, 32, 32).
        """
        h = self.encoder(x)                 # [bs, 256, 2, 2]
        h = h.view(h.size(0), -1)           # [bs, 256 * 2 * 2]
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        z_lat = self.reparameterize(mu, logvar)
        if return_latent:
            # Only the code is needed: skip fc_decode and the decoder entirely.
            return z_lat

        z = self.fc_decode(z_lat)           # [bs, 256 * 2 * 2]
        z = z.view(z.size(0), 256, 2, 2)    # [bs, 256, 2, 2]
        return self.decoder(z), mu, logvar


In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAELoss(nn.Module):
    """VAE loss: mean binary cross-entropy reconstruction + ``beta`` * KL divergence.

    ``KLD`` is the closed-form KL divergence between N(mu, exp(logvar)) and
    N(0, I), summed over latent dimensions and averaged over the batch. With
    the default ``beta=0`` the KL term is not added, so the objective is a
    plain (noise-injected) autoencoder loss, as with the former hard-coded
    beta = 0.

    Every call also accumulates the scalar loss, so callers can read the mean
    over a set of calls with :meth:`get_mean_loss`.
    """

    def __init__(self, beta=0.0):
        super(VAELoss, self).__init__()
        self.beta = beta  # weight of the KL term (0 => plain autoencoder)
        self.loss_accumulator = 0.0
        self.num_samples = 0

    def forward(self, recon_x, x, mu, logvar, message='TR'):
        """Return the loss for one batch and add it to the running accumulator.

        Args:
            recon_x: reconstructions in [0, 1] (the device of the other inputs
                is aligned to this one).
            x: targets in [0, 1], same shape as ``recon_x``.
            mu, logvar: latent parameters with the batch on axis 0.
            message: prefix of the occasional (p = 0.001) debug print.
        """
        x = x.to(recon_x.device)
        mu = mu.to(recon_x.device)
        logvar = logvar.to(recon_x.device)

        # Reconstruction loss (mean over all elements)
        BCE = F.binary_cross_entropy(recon_x, x, reduction='mean')

        # KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)),
        # summed per sample, then averaged over the batch.
        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        KLD = -0.5 * kl_terms.reshape(kl_terms.shape[0], -1).sum(dim=1).mean()

        beta = self.beta
        # Skip the KL term when beta == 0 so an overflowing exp(logvar) cannot
        # turn the loss into 0 * inf = NaN.
        total_loss = (BCE + beta * KLD) if beta else BCE
        if random.random() < 0.001:
            print(f"{message} SINGLE LOSSES ==> BCE: {BCE}, W-KLD: {beta * KLD}, KLD: {KLD}")

        # Accumulate loss
        self.loss_accumulator += total_loss.item()
        self.num_samples += 1

        return total_loss

    def get_accumulated_loss(self):
        """Return the sum of the losses accumulated since the last reset."""
        return self.loss_accumulator

    def get_mean_loss(self):
        """Return the mean accumulated loss, then reset the accumulator.

        Returns NaN when nothing was accumulated, instead of raising
        ZeroDivisionError.
        """
        if self.num_samples:
            mean_loss = self.loss_accumulator / self.num_samples
        else:
            mean_loss = float('nan')
        self.reset()
        return mean_loss

    def reset(self):
        """Clear the running sum and count."""
        self.loss_accumulator = 0.0
        self.num_samples = 0


### skip

In [17]:
if not just_sampling:
    # Shuffled mini-batches of 128 for training.
    train_dataloader = DataLoader(RA_dataset,
                            shuffle=True,
                            num_workers=8,
                            batch_size=128)

    # Fixed-order batches of 256 for validation.
    val_dataloader = DataLoader(RA_val_dataset,
                            shuffle=False,
                            num_workers=8,
                            batch_size=256)

### Training (skip)

In [18]:
plt.close('all')  # close figures left open by earlier cells

In [19]:
load_model = False  # if True, the next cell loads best_model_RA.pt before training
start_epoch=0  # first epoch of the training loop (set to 36 when load_model is True)
latent_dim = 32  # size of the VAE latent code

In [20]:
if not just_sampling:
    net = ConvVAE(latent_dim=latent_dim).cuda()

    if load_model:
        checkpoint_path = os.path.join(save_embeddings_path_SOTA, 'best_model_RA.pt')
        print(f"Loading model from {checkpoint_path}")
        # Load state dictionary into model
        net.load_state_dict(torch.load(checkpoint_path))
        # NOTE(review): hard-coded resume epoch, presumably that of the saved checkpoint; confirm.
        start_epoch= 36


    criterion = VAELoss()

    # Bookkeeping read and updated by the training-loop cell.
    best_loss = 999999999  # sentinel larger than any expected validation loss
    best_loss_epoch = 0
    cumul_epochs = 0  # epochs run so far across repeated runs of the training cell

    # Per-epoch histories: mean training loss, best validation loss so far,
    # validation loss ("contrastive" in the names is a leftover; this is the VAE loss).
    mean_loss_contrastive_list = []
    best_loss_contrastive_list = []
    validation_loss_list = []
    # Set to True by the "Plot losses" cell once it reloads the histories as arrays.
    continue_training = False

In [None]:
from torch import optim
import torch.nn.functional as F

if not just_sampling:
    n_epochs = 500
    # Stop once the validation loss has not improved for more than this many epochs.
    early_stopping_tolerance = 50
    optimizer = optim.Adam(net.parameters(), betas=(0.9, 0.999), lr=1e-4)

    if continue_training:
        # The "Plot losses" cell reloads these histories as numpy arrays.
        # np.asarray(...).tolist() accepts both arrays and lists, so re-running
        # this cell does not fail on a list that has no .tolist().
        mean_loss_contrastive_list = np.asarray(mean_loss_contrastive_list).tolist()
        best_loss_contrastive_list = np.asarray(best_loss_contrastive_list).tolist()
        validation_loss_list = np.asarray(validation_loss_list).tolist()

    print(f"Starting round of training from epoch {cumul_epochs} (Best loss: {best_loss:.2f})")

    for epoch in tqdm(range(start_epoch, n_epochs), desc='Epochs'):
        # ---- Training ----
        net.train()
        # The criterion also accumulated the previous epoch's validation
        # batches; clear it so the training mean covers training batches only.
        criterion.reset()

        for data in tqdm(train_dataloader, desc='Batches', leave=False):
            img0 = data.cuda()
            optimizer.zero_grad()
            recon_batch, mu, logvar = net(img0)
            loss_contrastive = criterion(recon_batch, img0, mu, logvar)
            loss_contrastive.backward()
            optimizer.step()

        # Show a few reconstructions every 10 epochs.
        show_examples = epoch % 10 == 0 and cumul_epochs % 10 == 0
        if show_examples:
            print(f"Recon shape: {recon_batch.shape}, Mu shape: {mu.shape}, Logvar shape: {logvar.shape}")
            imshow(torchvision.utils.make_grid(recon_batch.cpu().detach()[0:4]), 'TR - Reconstruction')
            imshow(torchvision.utils.make_grid(img0.cpu().detach()[0:4]), 'Original')

        # Mean training loss of this epoch (get_mean_loss also resets the criterion).
        mean_loss_contrastive = criterion.get_mean_loss()
        mean_loss_contrastive_list.append(mean_loss_contrastive)

        # ---- Validation ----
        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            for val_data in tqdm(val_dataloader, desc='Validation', leave=False):
                val_img0 = val_data.cuda()
                val_recon_batch, val_mu, val_logvar = net(val_img0)
                val_loss += criterion(val_recon_batch, val_img0, val_mu, val_logvar,"val").item()

        # Average of the per-batch mean losses.
        val_loss /= len(val_dataloader)
        validation_loss_list.append(val_loss)
        if show_examples:
            print(f"Recon shape: {val_recon_batch.shape}, Mu shape: {val_mu.shape}, Logvar shape: {val_logvar.shape}")
            imshow(torchvision.utils.make_grid(val_recon_batch.cpu().detach()[0:4]), 'VAL - Reconstruction')
            imshow(torchvision.utils.make_grid(val_img0.cpu().detach()[0:4]), 'Original')

        # ---- Checkpointing / early stopping ----
        epochs_without_improvement = cumul_epochs - best_loss_epoch
        if val_loss < best_loss:
            print(f"Epoch number {cumul_epochs} --- NEW Best loss {val_loss}")
            best_loss = val_loss
            best_loss_epoch = cumul_epochs
            torch.save(net.state_dict(), os.path.join(save_embeddings_path_SOTA, 'best_model_RA.pt'))
        elif epochs_without_improvement > early_stopping_tolerance:
            # Must be tested before the 10-epoch warning, which would otherwise
            # catch every such epoch first and make early stopping unreachable.
            print(f"Early stopping at epoch {cumul_epochs}")
            best_loss_contrastive_list.append(best_loss)
            cumul_epochs += 1
            break
        elif epochs_without_improvement > 10:
            print(f"No improvement in the last 10 epochs. Validation loss: {val_loss:.2f}")

        best_loss_contrastive_list.append(best_loss)
        cumul_epochs += 1

In [22]:
if not just_sampling:
    # Persist the per-epoch loss histories so the next cell can reload and plot them.
    np.save(os.path.join(save_embeddings_path_SOTA, f'mean_loss_contrastive_list_CAS_RA.npy'), mean_loss_contrastive_list)
    np.save(os.path.join(save_embeddings_path_SOTA, f'best_loss_contrastive_list_CAS_RA.npy'), best_loss_contrastive_list)
    np.save(os.path.join(save_embeddings_path_SOTA, f'validation_loss_list_CAS_RA.npy'), validation_loss_list)


### Plot losses

In [None]:
if not just_sampling:
    # Reload the loss histories (as numpy arrays).
    mean_loss_contrastive_list = np.load(os.path.join(save_embeddings_path_SOTA, f'mean_loss_contrastive_list_CAS_RA.npy'))
    best_loss_contrastive_list = np.load(os.path.join(save_embeddings_path_SOTA, f'best_loss_contrastive_list_CAS_RA.npy'))
    validation_loss_list = np.load(os.path.join(save_embeddings_path_SOTA, f'validation_loss_list_CAS_RA.npy'))

    # Tells the training cell that the histories are arrays to convert back to lists.
    continue_training=True

    %matplotlib inline
    plt.ioff()
    plt.close('all')
    # show_plot appends " Loss" to each title.
    show_plot(range(0,len(mean_loss_contrastive_list)),mean_loss_contrastive_list, title='Training Loss')
    show_plot(range(0,len(best_loss_contrastive_list)),validation_loss_list, title='Val Loss')
    show_plot(range(0,len(best_loss_contrastive_list)),best_loss_contrastive_list, title='Best Val Loss')


### LOAD MODEL and Test (skip)

In [24]:
import torch.nn.functional as F
# LOAD BEST MODEL
if not just_sampling:
    # Rebuild the architecture and load the best checkpoint written by the training loop.
    net = ConvVAE(latent_dim=latent_dim).cuda()
    checkpoint_path = os.path.join(save_embeddings_path_SOTA, 'best_model_RA.pt')

    # Load state dictionary into model
    net.load_state_dict(torch.load(checkpoint_path))
    net.eval()

### Skip

In [None]:
from torch.autograd import Variable 
from tqdm import tqdm

if not just_sampling:
    # Merged test set normalized with the training statistics.
    siamese_testset = ConvVAEDataset(images_array=X_CAS,
                                            transform=transforms.ToTensor(),
                                            metadata=metadata,)

    # batch_size=1: the loop below writes exactly one embedding row per batch,
    # so row i of img_embeddings corresponds to X_CAS[i] (shuffle=False).
    test_dataloader = DataLoader(siamese_testset,
                            shuffle=False,
                            batch_size=1)

    num_images = len(siamese_testset)

    # Inference only: no_grad avoids building an autograd graph for every image.
    with torch.no_grad():
        # Probe one image to report the latent size and to size the output array.
        dummy_img = next(iter(test_dataloader)).cuda()
        latent_vector_size = net(dummy_img, return_latent=True).cpu().numpy().shape
        print(f"Latent vector size: {latent_vector_size}")

        img_embeddings = np.zeros((num_images, latent_vector_size[-1]))

        # NOTE(review): ConvVAE returns the *sampled* z (mu + eps * std), not mu,
        # so these embeddings are stochastic; consider mu for deterministic codes.
        for i, img in enumerate(tqdm(test_dataloader)):
            output1 = net(img.cuda(), return_latent=True).cpu().numpy()
            img_embeddings[i] = output1.flatten()

    np.save(os.path.join(save_embeddings_path_SOTA, f'img_embeddings_CAS_RA.npy'), img_embeddings)
    print(f"Image Embeddings saved (shape {img_embeddings.shape})")

### Load img embeddings

In [None]:
img_embeddings = np.load(os.path.join(save_embeddings_path_SOTA, f'img_embeddings_CAS_RA.npy'))

print(f"img_embeddings loaded from {save_embeddings_path_SOTA}")
print(f'img_embeddings shape: {img_embeddings.shape}') # expected (num_images, latent_dim), as written by the extraction cell

# The extraction cell fills an np.zeros array (float64 by default); downcast to float32.
img_embeddings = img_embeddings.astype(np.float32)
print(f'img_embeddings dtype: {img_embeddings.dtype}')

### Plot IMG Embeddings

### T-SNE (skip)

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
if not just_sampling:
    # Project the latent codes to 2D with t-SNE (default parameters, fixed seed).
    # Alternative settings noted previously (not used): perplexity=500, early_exaggeration=40
    tsne = TSNE(n_components=2, random_state=42)
    img_embeddings_tsne = tsne.fit_transform(img_embeddings)

    plt.figure()
    plt.scatter(img_embeddings_tsne[:, 0], img_embeddings_tsne[:, 1], s=1)
    plt.title('Image Clustering Results (t-SNE img_embedding)')

    # save the embeddings
    np.save(os.path.join(save_embeddings_path_SOTA,f'img_embeddings_tsne_CAS_RA.npy'), img_embeddings_tsne)
    print(f"TSNE img_embeddings saved in {os.path.join(save_embeddings_path_SOTA,f'img_embeddings_tsne_CAS_RA.npy')}")

### Load

In [None]:
# Reload the 2D t-SNE coordinates saved by the t-SNE cell.
img_embeddings_tsne = np.load(os.path.join(save_embeddings_path_SOTA,f'img_embeddings_tsne_CAS_RA.npy'))
print(f"TSNE img_embeddings loaded")
print(f"TSNE img_embeddings shape: {img_embeddings_tsne.shape}")

In [None]:
# Interactive widget backend so interactive_plot's click callback can update the figure.
%matplotlib widget
# Click a t-SNE point to see its mask (img_list), latent code and image (true_img_list).
interactive_plot(img_embeddings_tsne[:, 0], img_embeddings_tsne[:, 1], colors=None, action='click', img_list=X_CAS_mask, emb_list=img_embeddings, true_img_list=X_CAS)

In [None]:
# Red mask overlay on the images for a few hand-picked slices (shown, not saved: no add_title).
plot_n_patches_overlap(X_CAS_mask, X_CAS, indexes=[5597,24,25,26,23])

In [31]:
n_clusters_img = 50  # number of k-means clusters used by both clustering cells below

### K-MEANS T-SNE (skip)

In [None]:
from sklearn.cluster import KMeans 

if not just_sampling:
    # Create a KMeans object with the desired number of clusters
    kmeans = KMeans(n_clusters=n_clusters_img, random_state=42, n_init='auto')

    # Fit k-means on the 2D t-SNE coordinates (not on the latent codes)
    kmeans.fit(img_embeddings_tsne)

    # Get the cluster labels for each data point
    img_cluster_labels_2d = kmeans.labels_
    print(f"Classes: {set(img_cluster_labels_2d)}")
    np.save(os.path.join(save_embeddings_path_SOTA,f'img_cluster_labels_2d_CAS_RA.npy'), img_cluster_labels_2d)


    # Plot histogram of cluster labels.
    # NOTE: plt.hist uses its default bin count (10 unless rcParams changed), so bars group several labels.
    plt.figure(figsize=(10, 5))
    plt.hist(img_cluster_labels_2d)
    plt.xlabel('Cluster Label')
    plt.ylabel('Count')
    plt.title('Histogram of Cluster Labels')
    plt.show()

### load

In [None]:
# Reload the k-means labels computed on the t-SNE coordinates.
img_cluster_labels_2d = np.load(os.path.join(save_embeddings_path_SOTA,f'img_cluster_labels_2d_CAS_RA.npy'))
print(f"Cluster labels loaded")
print(f"Cluster labels shape: {img_cluster_labels_2d.shape}")

In [None]:
# Same interactive scatter, colored by the t-SNE k-means labels.
interactive_plot(img_embeddings_tsne[:, 0], img_embeddings_tsne[:, 1], colors=img_cluster_labels_2d, action='click', img_list=X_CAS_mask, emb_list=img_embeddings, true_img_list=X_CAS)

### K-MEANS with embedding code (skip)

In [None]:
from sklearn.cluster import KMeans
if not just_sampling:
    # Create a KMeans object with the desired number of clusters
    kmeans = KMeans(n_clusters=n_clusters_img, random_state=42, n_init='auto')

    # Fit k-means directly on the latent embeddings
    kmeans.fit(img_embeddings)

    # Get the cluster labels for each data point
    img_cluster_labels = kmeans.labels_
    print(f"Classes: {set(img_cluster_labels)}")
    np.save(os.path.join(save_embeddings_path_SOTA,f'img_cluster_labels_CAS_RA.npy'), img_cluster_labels)

### load

In [None]:
# Reload the k-means labels computed on the latent embeddings.
img_cluster_labels = np.load(os.path.join(save_embeddings_path_SOTA,f'img_cluster_labels_CAS_RA.npy'))
print(f"Cluster labels loaded")
print(f"Cluster labels shape: {img_cluster_labels.shape}")

In [None]:
# Left: labels from k-means on latent codes; right: labels from k-means on t-SNE coordinates.
plot_2_clusters(img_embeddings_tsne, img_cluster_labels, img_cluster_labels_2d) 

### Sampling ==> AGGLOMERATIVE CLUSTERING (TOO MUCH COMPUTATION NEEDED)

In [None]:
# import numpy as np
# from scipy.cluster.hierarchy import linkage, dendrogram
# from joblib import Parallel, delayed

# def parallel_linkage(data, method='ward', n_jobs=-1):
#     # Split data into chunks for parallel processing
#     chunks = np.array_split(data, n_jobs)
#     results = Parallel(n_jobs=n_jobs)(delayed(linkage)(chunk, method=method) for chunk in chunks)
    
#     # Combine results
#     combined = np.vstack(results)
#     return combined

# def determine_num_clusters(data, max_clusters=30, n_jobs=12):
#     print(f"DETERMINE NUM CLUSTERS IN PARALLEL")
#     data_flat = data.reshape(data.shape[0], -1)
    
#     # Perform hierarchical clustering using 'ward' method in parallel
#     linked = parallel_linkage(data_flat, method='ward', n_jobs=n_jobs)
    
#     # Create the dendrogram with no plotting and truncation to max_clusters
#     dendro = dendrogram(linked, truncate_mode='lastp', p=max_clusters, no_plot=True)
    
#     # Get the number of clusters from the dendrogram
#     num_clusters = len(dendro['leaves'])
    
#     return num_clusters

In [None]:
import numpy as np
from fastcluster import linkage as fast_linkage
from scipy.cluster.hierarchy import dendrogram

def determine_num_clusters(data, max_clusters=30, n_jobs=12):
    """Estimate a number of clusters from a Ward hierarchical clustering of ``data``.

    Samples (flattened to one row each) are clustered with fastcluster's Ward
    linkage. The returned ``k`` (2 <= k <= max_clusters) is the one whose band
    of cut heights in the dendrogram is widest, i.e. the largest gap between
    consecutive merge distances among the last ``max_clusters`` merges.

    Args:
        data: array whose first axis indexes samples.
        max_clusters: upper bound on the returned number of clusters.
        n_jobs: unused; kept for backward compatibility.

    Returns:
        int: the estimated number of clusters (``n_samples`` when fewer than
        3 samples are given).
    """
    print(f"DETERMINE NUM CLUSTERS WITH FASTCLUSTER")
    data_flat = data.reshape(data.shape[0], -1)
    n_samples = data_flat.shape[0]
    if n_samples < 3:
        return n_samples

    # Perform hierarchical clustering using 'ward' method with fastcluster
    linked = fast_linkage(data_flat, method='ward', preserve_input=True)

    # Counting the leaves of a dendrogram truncated with truncate_mode='lastp'
    # always yields min(p, n_samples), so it cannot select k. Use the merge
    # heights instead: column 2 of the linkage matrix, non-decreasing for Ward.
    heights = linked[:, 2]
    k_max = min(max_clusters, n_samples - 1)
    if k_max < 2:
        return max(k_max, 1)

    # top[j]: height of the merge that turns j + 2 clusters into j + 1.
    top = heights[::-1][:k_max]
    # gaps[j]: width of the height band in which exactly j + 2 clusters exist.
    gaps = top[:-1] - top[1:]
    return int(np.argmax(gaps)) + 2

In [None]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.metrics import pairwise_distances

def select_representatives(cluster_points, cluster_ids, num_representatives):
    """Pick diverse points of a cluster by farthest-point sampling.

    Starts from a random point (``np.random.choice``). It then repeatedly adds
    the point whose Euclidean distance (on the flattened points) to its
    nearest already-selected representative is largest; ties go to the lowest
    index.

    Args:
        cluster_points: array of N points with any trailing shape.
        cluster_ids: sequence of N identifiers aligned with ``cluster_points``.
        num_representatives: number of points to select, capped at N. At
            least one point is returned for a non-empty cluster.

    Returns:
        ``(representatives, representatives_idx)``: the selected points with
        shape ``(k, *cluster_points.shape[1:])`` and their identifiers as a
        list, both in selection order.
    """
    points = np.asarray(cluster_points)
    # Work in float64 so unsigned-integer inputs (e.g. uint8 images) do not
    # wrap around when subtracted.
    cluster_flat = points.reshape(points.shape[0], -1).astype(np.float64)
    n_points = len(cluster_flat)
    if n_points == 0:
        return np.empty((0, *points.shape[1:]), dtype=points.dtype), []
    n_select = min(max(num_representatives, 1), n_points)

    # Initialize by selecting a random point as the first representative
    first = np.random.choice(n_points)
    selected = [first]
    # Distance of every point to its nearest selected representative; selected
    # points get -inf so they are never picked again.
    min_dist = np.linalg.norm(cluster_flat - cluster_flat[first], axis=1)
    min_dist[first] = -np.inf

    while len(selected) < n_select:
        nxt = int(np.argmax(min_dist))
        selected.append(nxt)
        min_dist = np.minimum(min_dist, np.linalg.norm(cluster_flat - cluster_flat[nxt], axis=1))
        min_dist[nxt] = -np.inf

    representatives_idx = [cluster_ids[i] for i in selected]
    return points[selected].reshape(-1, *points.shape[1:]), representatives_idx


def compute_pairwise_distances(data):
    """Return the full pairwise distance matrix of *data* using scikit-learn's defaults."""
    distances = pairwise_distances(data)
    return distances

def compute_distance_matrix(data, metric='euclidean', chunk_size=100, symmetric=None):
    """Compute the full ``(n_samples, n_samples)`` distance matrix block by block.

    ``data`` is a 2-D array of samples (rows). Distances are computed with
    ``sklearn.metrics.pairwise_distances`` on ``chunk_size`` x ``chunk_size``
    tiles so that each call stays small, and are stored as float32.

    Args:
        data: 2-D array, one sample per row.
        metric: Any metric accepted by ``pairwise_distances``.
        chunk_size: Tile edge length; must be a positive integer.
        symmetric: If True, only tiles on or above the diagonal are computed and
            mirrored into the lower triangle (about half the work, and an exactly
            symmetric result). ``None`` enables this automatically for a few
            well-known symmetric string metrics.

    Returns:
        ``np.ndarray`` of shape ``(n_samples, n_samples)`` and dtype float32.

    Raises:
        ValueError: if ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
    if symmetric is None:
        known_symmetric = {'euclidean', 'l2', 'sqeuclidean', 'manhattan', 'cityblock', 'l1', 'cosine', 'chebyshev'}
        symmetric = isinstance(metric, str) and metric in known_symmetric

    n_samples = data.shape[0]
    distance_matrix = np.zeros((n_samples, n_samples), dtype=np.float32)

    for i in tqdm(range(0, n_samples, chunk_size)):
        i_end = min(i + chunk_size, n_samples)
        rows = data[i:i_end]
        j_start = i if symmetric else 0
        for j in range(j_start, n_samples, chunk_size):
            j_end = min(j + chunk_size, n_samples)
            if j == i:
                # Diagonal tile: pass Y=None so scikit-learn treats it as X vs X
                # (for Euclidean it then zeroes the self-distances exactly, which
                # two distinct slices of the same rows would not get).
                block = pairwise_distances(rows, metric=metric, n_jobs=-1)
            else:
                block = pairwise_distances(rows, data[j:j_end], metric=metric, n_jobs=-1)
            distance_matrix[i:i_end, j:j_end] = block
            if symmetric and j != i:
                distance_matrix[j:j_end, i:i_end] = block.T

    return distance_matrix

def _load_or_compute_npy(path, compute, is_valid, label):
    """Return the .npy array cached at *path* if *is_valid* accepts it; otherwise compute, save and return it."""
    if os.path.exists(path):
        print(f"Loading {label}")
        cached = np.load(path)
        if is_valid(cached):
            return cached
        # Cache file names are fixed, so a file left over from a run with other
        # data or another number of clusters must not be silently reused.
        print(f"Cached {label} at {path} does not match the current run; recomputing")
    result = compute()
    np.save(path, result)
    return result


def _precomputed_complete_linkage(num_clusters):
    """Build a complete-linkage AgglomerativeClustering that consumes a precomputed distance matrix."""
    try:
        # scikit-learn >= 1.2 names this argument `metric` (`affinity` was removed in 1.4).
        return AgglomerativeClustering(n_clusters=num_clusters, metric='precomputed', linkage='complete')
    except TypeError:
        # Older scikit-learn releases only accept `affinity`.
        return AgglomerativeClustering(n_clusters=num_clusters, affinity='precomputed', linkage='complete')


def ClsMC_RS(data, percentage=0.1, num_clusters=0, chunk_size=100):
    """Cluster-based max-coverage representative sampling.

    1. If ``num_clusters`` is 0, obtain it from ``determine_num_clusters``.
    2. Cluster the flattened samples with complete-linkage agglomerative
       clustering on a full pairwise distance matrix. The distance matrix and the
       labels are cached as .npy files under ``save_embeddings_path_SOTA`` (a
       notebook-level global) and recomputed when the cache does not fit.
    3. In every cluster, pick ``int(cluster_size * percentage)`` samples with
       ``select_representatives``, which returns at least one per cluster.

    Args:
        data: Array whose first axis indexes samples; trailing axes are flattened.
        percentage: Fraction of each cluster to keep.
        num_clusters: Number of clusters; 0 means "determine automatically".
        chunk_size: Tile size forwarded to ``compute_distance_matrix``.

    Returns:
        ``(representatives, representatives_idx)``: the selected samples as one
        array, and their row indices into ``data``.
    """
    if num_clusters == 0:
        print("STEP 0: Determining num clusters")
        num_clusters = determine_num_clusters(data)
    print(f"Number of clusters: {num_clusters}")

    print("STEP 1: Clustering data")
    n_samples = data.shape[0]
    data_flat = data.reshape(n_samples, -1)
    print(f"Data flat shape: {data_flat.shape}, dtype: {data_flat.dtype}")

    print("Computing distance matrix")
    distance_matrix = _load_or_compute_npy(
        os.path.join(save_embeddings_path_SOTA, 'distance_matrix_CAS_RA.npy'),
        compute=lambda: compute_distance_matrix(data_flat, chunk_size=chunk_size),
        # NOTE: only the shape is checked; a cache built from different data
        # with the same number of samples would still be reused.
        is_valid=lambda cached: cached.shape == (n_samples, n_samples),
        label="distance matrix",
    )

    print("Performing Agglomerative clustering")
    cluster_labels = _load_or_compute_npy(
        os.path.join(save_embeddings_path_SOTA, 'cluster_labels_CAS_RA.npy'),
        compute=lambda: _precomputed_complete_linkage(num_clusters).fit_predict(distance_matrix),
        is_valid=lambda cached: cached.shape == (n_samples,) and len(np.unique(cached)) == num_clusters,
        label="cluster labels",
    )
    print(f"Different clusters: {set(cluster_labels)}")

    print("STEP 2: Selecting representatives")
    representatives = []
    representatives_idx = []
    # Iterate over the labels actually present so no cluster is ever empty.
    for cluster_idx in np.unique(cluster_labels):
        in_cluster = cluster_labels == cluster_idx
        cluster_points = data[in_cluster]
        cluster_ids = np.flatnonzero(in_cluster)
        print(f"Cluster {cluster_idx} - Num points: {len(cluster_points)}")
        print(f"Cluster IDs: {cluster_ids}")
        reps, id_reps = select_representatives(cluster_points, cluster_ids, int(len(cluster_points) * percentage))
        representatives.extend(reps)
        representatives_idx.extend(id_reps)

        print(f"Adding {len(reps)} representatives (TOTAL: {len(representatives)})")
        print(f"Chosen representatives: {id_reps}")

    return np.array(representatives), representatives_idx
# Example usage:
# data = np.random.rand(100, 2)  # Replace with actual data
# representatives, representatives_idx = ClsMC_RS(data)

# Example usage
#num_points = 1000
#latent_space_shape = (1, 128, 4, 4)
#data = np.random.rand(num_points, *latent_space_shape)
#percentage = 0.1
#print(f"Data shape: {data.shape} (percentage: {percentage})")

#representatives, representatives_idx = ClsMC_RS(data, percentage)


#print(f"Selected representatives shape: {representatives.shape}")


In [None]:
# ClsMC_RS returns a (representatives, representative_indices) tuple; both parts
# are kept together under this single name.
img_embeddings_selected = ClsMC_RS(
    img_embeddings,
    percentage=0.1,
    num_clusters=30,
    chunk_size=15000,
)

In [None]:


def compute_pairwise_distances2(data):
    """Return the full pairwise distance matrix of *data* (scikit-learn defaults, Euclidean metric)."""
    result = pairwise_distances(data)
    return result

#def monte_carlo_sampling(cluster_data, num_samples):
#    samples = []
#    for _ in range(num_samples):
#        sample_indices = np.random.choice(len(cluster_data), size=len(cluster_data) // 2, replace=False)
#        samples.append(cluster_data[sample_indices])
#    return samples



# def ClsMC_RS(data, percentage=0.1, num_clusters=0):
#     if num_clusters == 0:
#         print(f"STEP 0: Determining num clusters")
#         # Step 1: Determine the number of clusters using agglomerative clustering and dendrogram
#         num_clusters = determine_num_clusters(data)
#     print(f"Number of clusters: {num_clusters}")
    
#     # Step 2: Perform clustering with the determined number of clusters
#     print(f"STEP 1: Clustering data")
#     data_flat = data.reshape(data.shape[0], -1)
#     clustering = AgglomerativeClustering(n_clusters=num_clusters, affinity='precomputed', linkage='complete')
#     cluster_labels = clustering.fit_predict(data_flat)
#     print(f"Different clusters: {set(cluster_labels)}")
    
#     # Step 3: Representative selection max coverage sampling
#     print(f"STEP 2: Selecting representatives")
#     representatives = []
#     representatives_idx = []
#     for cluster_idx in range(num_clusters):
#         cluster_points = data[cluster_labels == cluster_idx]
#         cluster_ids = np.where(cluster_labels == cluster_idx)[0]
#         print(f"Cluster {cluster_idx} - Num points: {len(cluster_points)}")
#         print(f"Cluster {len(cluster_ids)}: {cluster_ids}")
#         reps, id_reps = select_representatives(cluster_points, cluster_ids, len(cluster_points)*percentage)
#         representatives.extend(reps)
#         representatives_idx.extend(id_reps)
        
#         print(f"Adding {len(reps)} representatives (TOTAL: {len(representatives)})")
#         print(f"Chosen representatives: {id_reps}")
    
#     return np.array(representatives), representatives_idx



# # Example usage
# num_points = 1000
# latent_space_shape = (1, 128, 4, 4)
# data = np.random.rand(num_points, *latent_space_shape).astype(np.float32)
# percentage = 0.1
# print(f"Data shape: {data.shape} (percentage: {percentage})")

# representatives, representatives_idx = ClsMC_RS(data, percentage)


# print(f"Selected representatives shape: {representatives.shape}")



In [None]:
# Sanity check: print the shapes of arrays presumably built in earlier cells.
print(f"Img list shape: {X_CAS.shape}")
print(f"Mask list shape: {X_CAS_mask.shape}\n")

print(f"Img Embeddings shape: {img_embeddings.shape}")
print(f"Img Embeddings TSNE shape {img_embeddings_tsne.shape}")

print(f"Class shape: {img_cluster_labels.shape}")
print(f"Class 2D shape: {img_cluster_labels_2d.shape}")