In [66]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision

import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

from tqdm import tqdm

import math

In [67]:
def _make_flat_loader(dataset_cls, train, batch_size=512):
    """Build a shuffled DataLoader yielding normalised, flattened images.

    Args:
        dataset_cls: torchvision dataset class (e.g. ``torchvision.datasets.MNIST``).
        train: ``True`` for the training split, ``False`` for the test split.
        batch_size: Number of samples per batch.

    Returns:
        A ``torch.utils.data.DataLoader`` over the requested split, stored under
        (and downloaded to, if missing) the local ``data`` directory.
    """
    # NOTE(review): the same single-channel mean/std pair is applied to every
    # dataset, including 3-channel CIFAR-10 (the single value is broadcast over
    # the channels). Kept unchanged because the saved models were presumably
    # trained with exactly this preprocessing -- confirm before altering.
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        # Flatten each image into a 1-D vector.
        torchvision.transforms.Lambda(lambda x: x.view(-1)),
    ])
    dataset = dataset_cls('data', train=train, download=True, transform=transform)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


mnist_train = _make_flat_loader(torchvision.datasets.MNIST, train=True)
mnist_test = _make_flat_loader(torchvision.datasets.MNIST, train=False)

kmnist_train = _make_flat_loader(torchvision.datasets.KMNIST, train=True)
kmnist_test = _make_flat_loader(torchvision.datasets.KMNIST, train=False)

fashion_train = _make_flat_loader(torchvision.datasets.FashionMNIST, train=True)
fashion_test = _make_flat_loader(torchvision.datasets.FashionMNIST, train=False)

cifar10_train = _make_flat_loader(torchvision.datasets.CIFAR10, train=True)
cifar10_test = _make_flat_loader(torchvision.datasets.CIFAR10, train=False)


# Maps dataset name -> (train loader, test loader, flattened input size).
# NOTE: this rebinds the name `datasets`, shadowing `from torchvision import datasets`;
# the name is kept because later cells index `datasets[DATASET]`.
datasets = {
    'mnist': (mnist_train, mnist_test, 28*28),
    'kmnist': (kmnist_train, kmnist_test, 28*28),
    'fashion': (fashion_train, fashion_test, 28*28),
    'cifar10': (cifar10_train, cifar10_test, 32*32*3)
}

Files already downloaded and verified
Files already downloaded and verified


In [68]:
# Number of subplot columns used by the grid-plotting helpers in this notebook.
PLOTS_PER_ROW = 4

# When True, the interactive exploration cells (guarded by `if not ONLY_PDF`) are skipped.
ONLY_PDF = True
# When True, the final cell calls create_report to write the PDF report.
CREATE_PDF = True
# When True, each loaded model is evaluated with trainer.test_epoch while loading.
TEST_MODELS = False

# When True, create_report adds an extra page that embeds the latents with UMAP.
USE_UMAP = False
# When True, create_report adds persistence diagrams of DBSCAN-filtered latents.
USE_FILTER_TDA = False

# Key into the `datasets` dict; also names the experiment and report sub-folders.
DATASET = "mnist"

In [69]:
from ff_mod.goodness import L2_Goodness, L1_Goodness, Norm_goodness, L2_Goodness_SQRT
from ff_mod.probability import SigmoidProbability, SymmetricFFAProbability

from ff_mod.network.base_ffa import FFANetwork, FFALayer 

from ff_mod.overlay import AppendToEndOverlay, CornerOverlay
from ff_mod.loss import BCELoss

# Hyper-parameters shared by every network built in create_network.
NUM_CLASSES = 10
DIM = 1000  # width of both FFALayer outputs
LEARNING_RATE = 0.0001

# Activation modules, looked up by the 'activation' entry of each experiment config.
activations = {
    'ReLU': torch.nn.ReLU(),
    'Sigmoid': torch.nn.Sigmoid(),
    'Tanh': torch.nn.Tanh()
}

# Goodness functions, looked up by the 'goodness' entry of each experiment config.
# Key naming, as visible from the constructor arguments below:
#   L1 / L2      -> L1_Goodness / L2_Goodness
#   ...q         -> L2_Goodness_SQRT
#   M / S        -> use_mean=True / use_mean=False
#   _TK15        -> topk_units=15
#   _Split       -> positive_split=DIM//2
goodnesses = {
    'L2M_TK15': L2_Goodness(use_mean=True, topk_units=15),
    'L2M': L2_Goodness(use_mean=True),
    'L1M': L1_Goodness(use_mean=True),
    'L1M_TK15': L1_Goodness(use_mean=True, topk_units=15),
    'L2M_Split': L2_Goodness(positive_split=DIM//2, use_mean=True),
    'L2M_TK15_Split': L2_Goodness(positive_split=DIM//2, use_mean=True, topk_units=15),
    'L1M_Split': L1_Goodness(positive_split=DIM//2, use_mean=True),
    'L1M_TK15_Split': L1_Goodness(positive_split=DIM//2, use_mean=True, topk_units=15),
    'L2S': L2_Goodness(use_mean=False),
    'L2S_TK15': L2_Goodness(use_mean=False, topk_units=15),
    'L1S': L1_Goodness(use_mean=False),
    'L1S_TK15': L1_Goodness(use_mean=False, topk_units=15),
    'L2S_Split': L2_Goodness(positive_split=DIM//2, use_mean=False),
    'L2S_TK15_Split': L2_Goodness(positive_split=DIM//2, use_mean=False, topk_units=15),
    'L1S_Split': L1_Goodness(positive_split=DIM//2, use_mean=False),
    'L1S_TK15_Split': L1_Goodness(positive_split=DIM//2, use_mean=False, topk_units=15),
    'L2Sq' : L2_Goodness_SQRT(use_mean=False),
    'L2Sq_TK15' : L2_Goodness_SQRT(use_mean=False, topk_units=15),
    'L2Sq_Split' : L2_Goodness_SQRT(positive_split=DIM//2, use_mean=False),
    'L2Sq_TK15_Split' : L2_Goodness_SQRT(positive_split=DIM//2, use_mean=False, topk_units=15),
    'L2Mq' :  L2_Goodness_SQRT(use_mean=True),
    'L2Mq_TK15' :  L2_Goodness_SQRT(use_mean=True, topk_units=15),
    'L2Mq_Split' :  L2_Goodness_SQRT(positive_split=DIM//2, use_mean=True),
    'L2Mq_TK15_Split' :  L2_Goodness_SQRT(positive_split=DIM//2, use_mean=True, topk_units=15)
}

# Short display names for goodness identifiers; only consulted by the
# model-loading cell when USE_GOODNESS_MAP is True.
# NOTE(review): these keys do not match the keys of `goodnesses` above, so
# enabling USE_GOODNESS_MAP with the current configs would presumably raise
# a KeyError -- confirm against the config files.
goodness_map = {
    "L2_Goodness": "L2",
    "L1_Goodness": "L1",
    "L2_Goodness_Positive_Split_250": "L2_Split",
    "L1_Goodness_Positive_Split_250": "L1_Split",
}

# Probability functions, looked up by the 'probability' entry of each experiment config.
probabilities = {
    'SigmoidProbability_Theta_0': SigmoidProbability(theta=0),
    'SigmoidProbability_Theta_2': SigmoidProbability(theta=2),
    'SymmetricFFAProbability': SymmetricFFAProbability()
}

def create_network(goodness, activation, probability, input_size=784, pattern_size=100):
    """Build a two-layer FFANetwork from named configuration entries.

    Args:
        goodness: Key into the module-level ``goodnesses`` dict.
        activation: Key into the module-level ``activations`` dict.
        probability: Key into the module-level ``probabilities`` dict.
        input_size: Number of features of one flattened input sample.
        pattern_size: Size passed to ``AppendToEndOverlay``. The first layer is
            sized ``input_size + pattern_size`` (the overlay presumably appends
            a label pattern of that length to the input -- see ff_mod.overlay).
            Previously the literal 100 was repeated in both places.

    Returns:
        The constructed ``FFANetwork`` with two ``FFALayer`` instances of width ``DIM``.

    Raises:
        KeyError: If any of the three names is not a key of its lookup dict.
    """
    overlay = AppendToEndOverlay(pattern_size=pattern_size, num_classes=NUM_CLASSES, p=0.1)

    network = FFANetwork(overlay)

    # Both layers share the same loss, goodness and activation objects.
    loss = BCELoss(probability_function=probabilities[probability])
    goodness_fn = goodnesses[goodness]
    activation_fn = activations[activation]

    network.add_layer(FFALayer(input_size + pattern_size, DIM, goodness_fn, loss, activation_fn, learning_rate=LEARNING_RATE))
    network.add_layer(FFALayer(DIM, DIM, goodness_fn, loss, activation_fn, learning_rate=LEARNING_RATE))

    return network

In [70]:
from ff_mod.trainer import Trainer

# Shared Trainer instance; get_latents reads its loaders and device.
trainer = Trainer()
# Entries 0 and 1 of the selected `datasets` tuple are the train and test DataLoaders.
trainer.set_dataloader(datasets[DATASET][0], datasets[DATASET][1])

In [71]:
# Traverse all folders and create networks
import os
import json 

# Minimum test accuracy a model must exceed to be kept (only used when TEST_MODELS is True).
MIN_ACCURATION = 0.00

EXPERIMENT_FOLDER = f'experiments_train/{DATASET}/'

# When True, goodness names are translated through `goodness_map` for display.
USE_GOODNESS_MAP = False

# model name -> loaded network, and model name -> test accuracy (if evaluated).
current_models = {}
accuracies = {}

# os.listdir returns entries in arbitrary order; sorting makes the model order
# (and the "_<n>" suffixes given to duplicate configurations) reproducible.
experiment_folders = sorted(os.listdir(EXPERIMENT_FOLDER))

for folder in tqdm(experiment_folders, leave=False):
    experiment_path = os.path.join(EXPERIMENT_FOLDER, folder)

    # Skip stray files (e.g. hidden OS files) that are not experiment folders.
    if not os.path.isdir(experiment_path):
        continue

    # Read the json config of this experiment; the file is closed before the
    # (potentially slow) network loading starts.
    with open(os.path.join(experiment_path, 'config.json'), 'r') as f:
        config = json.load(f)

    network = create_network(config['goodness'], config['activation'], config['probability'], input_size=datasets[DATASET][2])
    network.load_network(os.path.join(experiment_path, 'best_model'))

    if USE_GOODNESS_MAP:
        config_str = f"{config['activation']}_{goodness_map[config['goodness']]}_{config['probability']}"
    else:
        config_str = f"{config['activation']}_{config['goodness']}_{config['probability']}"

    # Disambiguate repeated configurations with a numeric suffix.
    if config_str in current_models:
        act_ind = 1
        while config_str + f"_{act_ind}" in current_models:
            act_ind += 1
        config_str = config_str + f"_{act_ind}"

    if TEST_MODELS:
        trainer.set_network(network)
        acc = trainer.test_epoch(verbose=0)

        print(f"{config_str} - {acc}")

        # Only keep models whose accuracy exceeds the threshold.
        if acc > MIN_ACCURATION:
            current_models[config_str] = network
            accuracies[config_str] = acc
    else:
        current_models[config_str] = network

print(f"Total models: {len(current_models)}")

0it [00:00, ?it/s]

                        

Total models: 108




In [72]:
def normalize_over_mean(latents, batch_size = 512, class_size = 10, dim = 1000):
    """Centre each sample's latents across its ``class_size`` label variants, in place.

    The rows of ``latents`` are expected in consecutive groups of
    ``batch_size * class_size`` rows; inside a group, the row for sample ``i``
    and label variant ``j`` sits at offset ``i + j * batch_size``. For every
    sample the mean over its ``class_size`` rows is subtracted from each of them.

    Trailing rows that do not fill a complete group are left untouched.

    Args:
        latents: 2-D numpy array of latent vectors; modified in place.
        batch_size: Number of samples per batch in the layout described above.
        class_size: Number of label variants stored per sample.
        dim: Kept for backward compatibility; the latent width is now taken
            from ``latents`` itself, so this value is ignored.

    Returns:
        The same ``latents`` array, after centring.
    """
    group_size = batch_size * class_size
    total_batches = latents.shape[0] // group_size

    for batch in range(total_batches):
        skip = batch * group_size
        chunk = latents[skip:skip + group_size]

        # View as (variant, sample, feature) so the mean over axis 0 is the
        # per-sample mean across label variants. Accumulate in float64, as the
        # original element-wise loop did.
        grouped = chunk.reshape(class_size, batch_size, -1)
        mean_state = grouped.mean(axis=0, keepdims=True, dtype=np.float64)

        # Write back explicitly so the update is in place even if reshape copied.
        latents[skip:skip + group_size] = (grouped - mean_state).reshape(chunk.shape)

    return latents

def get_latents(network, total_batches, layer = 1, use_train = True, normalize = False, remove_zeros = False, normalize_mean = False, num_classes = 10):
    """Collect latent vectors of ``network`` for every label variant of each sample.

    For each of the first ``total_batches`` batches of the module-level
    ``trainer``'s loader, the network is queried ``num_classes`` times with the
    label shifted by ``0 .. num_classes - 1`` (modulo ``num_classes``). Shift 0
    uses the true label and is marked as positive; all other shifts are negative.

    Args:
        network: Object exposing ``get_latent(data, labels, layer)``.
        total_batches: Maximum number of batches to read from the loader.
        layer: Layer index forwarded to ``network.get_latent``.
        use_train: Read from ``trainer.train_loader`` if True, else ``trainer.test_loader``.
        normalize: Divide every latent vector by its L2 norm (plus 1e-5).
        remove_zeros: Drop latent vectors whose L2 norm is not above 0.01.
        normalize_mean: Apply ``normalize_over_mean`` to the stacked latents.
            That helper relies on a fixed row layout (full batches of its
            default ``batch_size``), so combining it with ``remove_zeros`` or a
            partial final batch breaks the layout it assumes.
        num_classes: Number of label variants generated per sample.

    Returns:
        Tuple ``(latents, labels, positiveness)`` of numpy arrays with one row /
        entry per collected latent; ``labels`` holds the true labels as floats
        and ``positiveness`` is 1.0 for the true-label variant, 0.0 otherwise.

    Raises:
        ValueError: If no batch was read (empty loader or ``total_batches <= 0``).
    """
    loader = trainer.train_loader if use_train else trainer.test_loader

    latents = []
    labels = []
    positiveness = []

    for batch_idx, (data, target) in enumerate(loader):
        # Stop before moving an unneeded batch to the device.
        if batch_idx >= total_batches:
            break
        data, target = data.to(trainer.device), target.to(trainer.device)

        for shift in range(num_classes):
            latent_t = network.get_latent(data, (target + shift) % num_classes, layer)
            target_t = target.clone().detach()

            if remove_zeros:
                # Compute the mask once and apply it to labels and latents alike.
                keep = torch.norm(latent_t, dim=1) > 0.01
                target_t = target_t[keep]
                latent_t = latent_t[keep]

            if normalize:
                latent_t = latent_t / (torch.norm(latent_t, dim=1, keepdim=True) + 0.00001)

            latents.append(latent_t.detach().cpu().numpy())
            labels.append(target_t.detach().cpu().numpy().astype(np.float64))
            positiveness.append(np.full(latent_t.shape[0], 1.0 if shift == 0 else 0.0))

    if not latents:
        raise ValueError("No latents collected: the loader is empty or total_batches <= 0.")

    latents = np.concatenate(latents)
    labels = np.concatenate(labels)
    positiveness = np.concatenate(positiveness)

    if normalize_mean:
        # The helper works in place; rebinding makes the data flow explicit
        # (the result used to be assigned to an unused variable).
        latents = normalize_over_mean(latents, class_size=num_classes, dim=latents.shape[1])

    return latents, labels, positiveness

In [73]:
def get_all_latents(all_models, use_train = True, normalize_mean = False):
    """Collect latents for every model in ``all_models``.

    Models whose name contains ``"SymmetricFFAProbability"`` have near-zero
    latents removed and get a second entry, keyed ``<name>_Normalized``, with
    L2-normalised latents.

    Args:
        all_models: Mapping of model name -> network.
        use_train: Forwarded to ``get_latents`` (train vs. test loader).
        normalize_mean: Forwarded to ``get_latents`` for the main entry only;
            the ``_Normalized`` entry is collected without it, as before.

    Returns:
        Dict mapping name -> ``(latents, labels, positiveness)`` as returned by
        ``get_latents`` (two batches per model).
    """
    all_latents = {}

    for name, network in all_models.items():
        is_symmetric = "SymmetricFFAProbability" in name

        all_latents[name] = get_latents(network, 2, use_train = use_train, remove_zeros=is_symmetric, normalize_mean=normalize_mean)

        if is_symmetric:
            # Use the network passed in via `all_models`; this branch used to
            # read the global `current_models` and ignore the argument.
            all_latents[name + "_Normalized"] = get_latents(network, 2, use_train = use_train, normalize = True, remove_zeros=True)

    return all_latents

In [74]:
# Interactive exploration only: skipped when ONLY_PDF is True
# (create_report computes its own latents).
if not ONLY_PDF:
    all_latents = get_all_latents(current_models)

In [75]:
from sklearn.manifold import TSNE
import umap


def get_tsne(all_latents, verbose = 1, limit = 4000, plot = False, plot_class = False, plot_possetive = False, use_umap = False):
    """Embed every model's latents in 2-D (t-SNE or UMAP) and optionally plot them.

    Args:
        all_latents: Dict of model name -> ``(latents, labels, positiveness)``.
        verbose: Verbosity forwarded to the embedding estimator.
        limit: Maximum number of rows embedded per model (random subsample
            without replacement). With ``None`` nothing is subsampled, and any
            model with more than 4000 rows is rejected.
        plot: Scatter the embedding without colouring.
        plot_class: Colour points by ``labels`` (takes precedence over the others).
        plot_possetive: Colour points by ``positiveness``.
        use_umap: Use ``umap.UMAP`` instead of ``sklearn.manifold.TSNE``.

    Returns:
        Dict of model name -> 2-D embedding array.

    Raises:
        AssertionError: If ``all_latents`` is empty.
        ValueError: If ``limit`` is None and a model has more than 4000 rows.
    """
    total_plots = len(all_latents.keys())

    assert total_plots > 0

    if limit is None:
        for model in all_latents.keys():
            if all_latents[model][0].shape[0] > 4000:
                raise ValueError(f'Latent {model} is too big with {all_latents[model][0].shape[0]} samples.')

    should_plot = plot or plot_class or plot_possetive
    ax = None
    cols = 1

    if should_plot:
        if total_plots > PLOTS_PER_ROW:
            rows, cols = math.ceil(total_plots / PLOTS_PER_ROW), PLOTS_PER_ROW
        else:
            rows, cols = 1, total_plots
        # squeeze=False keeps `ax` 2-D even for a single row or a single plot.
        fig, ax = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
        # The figure only exists when plotting was requested; titling it
        # unconditionally used to raise NameError otherwise.
        fig.suptitle('UMAP of latents' if use_umap else 'TSNE of latents')

    tsne_latents = {}

    for i, model in tqdm(enumerate(all_latents.keys()), leave=False, total=total_plots):
        latents, labels, positiveness = all_latents[model]

        # Random subsample when there are more rows than `limit`.
        if limit is not None and latents.shape[0] > limit:
            rand_indices = np.random.choice(latents.shape[0], limit, replace=False)
        else:
            rand_indices = np.arange(latents.shape[0])

        if not use_umap:
            tsne = TSNE(n_components=2, verbose=verbose, n_iter=700)
        else:
            tsne = umap.UMAP(n_components=2, verbose=verbose)

        tsne_results = tsne.fit_transform(latents[rand_indices])

        tsne_latents[model] = tsne_results

        if should_plot:
            axis = ax[i // cols, i % cols]
            if plot_class:
                axis.scatter(tsne_results[:,0], tsne_results[:,1], c=labels[rand_indices], cmap='Set3', s = 1)
            elif plot_possetive:
                axis.scatter(tsne_results[:,0], tsne_results[:,1], c=positiveness[rand_indices], cmap='bwr', s = 1)
            else:
                axis.scatter(tsne_results[:,0], tsne_results[:,1], alpha=0.5)
            axis.set_title(model)

    return tsne_latents

In [76]:
# Interactive preview of the 2-D embedding coloured by positive/negative flag;
# skipped when ONLY_PDF is True.
if not ONLY_PDF:
    _ = get_tsne(all_latents, limit = 1800, verbose=0, plot_possetive=True)

In [77]:
def plot_latents(all_latents, amount = 1000, range = (400, 600)):
    """Show the row-normalised latents of every model as grayscale images.

    Args:
        all_latents: Dict of model name -> ``(latents, labels, positiveness)``.
        amount: Number of leading rows (samples) displayed per model.
        range: ``(start, stop)`` slice of latent units displayed. The name
            shadows the builtin inside this function; it is kept because
            callers pass it by keyword.

    Returns:
        None. One subplot per model is drawn on a new figure.
    """
    total_plots = len(all_latents.keys())

    if total_plots > PLOTS_PER_ROW:
        rows, cols = math.ceil(total_plots / PLOTS_PER_ROW), PLOTS_PER_ROW
    else:
        rows, cols = 1, total_plots

    # squeeze=False keeps `ax` 2-D, so a single model no longer fails on `ax[i]`.
    fig, ax = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)

    fig.suptitle('Plot of latents')

    for i, model in enumerate(all_latents.keys()):
        latents = all_latents[model][0]

        # Normalisation is row-wise, so restricting to the displayed rows first
        # gives the same values while skipping rows that are never shown.
        shown = latents[:amount]
        normal_latents = shown / (np.linalg.norm(shown, axis=1)[:, None]+0.000001)

        axis = ax[i // cols, i % cols]
        axis.imshow(normal_latents[:, range[0]:range[1]], cmap='gray', aspect='auto')
        axis.set_title(model)

In [78]:
# Interactive preview of the first 1000 latents over all DIM units;
# skipped when ONLY_PDF is True.
if not ONLY_PDF:
    plot_latents(all_latents, amount = 1000, range=(0, DIM))

In [79]:
def plot_positives(all_latents, eps = 0.05):
    """Plot, per model, histograms of how many units are active in each latent.

    A unit counts as active when its value is greater than ``eps``. Positive
    samples are drawn in red and negative samples in blue; each histogram is
    normalised to sum to 1.

    Args:
        all_latents: Dict of model name -> ``(latents, labels, positiveness)``.
        eps: Activation threshold.

    Returns:
        Dict of model name -> ``pos_counts.mean() + neg_counts.mean() / 9``.
        NOTE(review): this combination is not a conventional weighted mean of
        the two groups -- confirm the intended weighting.
    """
    total_plots = len(all_latents.keys())

    mean_positives = {}

    if total_plots > PLOTS_PER_ROW:
        rows, cols = math.ceil(total_plots / PLOTS_PER_ROW), PLOTS_PER_ROW
    else:
        rows, cols = 1, total_plots

    # squeeze=False keeps `ax` 2-D, so a single model no longer fails on `ax[i]`.
    fig, ax = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)

    fig.suptitle('Use of neurons in latents')

    for i, model in enumerate(all_latents.keys()):
        latents, labels, positiveness = all_latents[model]

        # Number of units above the threshold, per sample.
        pos_counts = np.sum(latents[positiveness == 1] > eps, axis=1)
        neg_counts = np.sum(latents[positiveness == 0] > eps, axis=1)

        mean_positives[model] = pos_counts.mean() + neg_counts.mean() * 1/9

        # Between 2 and 20 bins, depending on the largest observed count.
        max_bins = min(max(np.max(pos_counts), np.max(neg_counts), 2), 20)

        axis = ax[i // cols, i % cols]
        axis.hist(neg_counts, bins=max_bins, color='b', alpha=0.7, weights=np.ones(len(neg_counts))/len(neg_counts))
        axis.hist(pos_counts, bins=max_bins, color='r', alpha=0.7, weights=np.ones(len(pos_counts))/len(pos_counts))
        axis.set_title(model)

    return mean_positives

In [80]:
# Interactive preview of the active-unit histograms; skipped when ONLY_PDF is True.
if not ONLY_PDF:
    mean_positives = plot_positives(all_latents)

In [81]:
def get_hoyer_distribution(all_latents, eps = 1e-6):
    """Compute the mean Hoyer sparsity of every model's latents.

    For a vector ``x`` with ``n`` entries the measure is
    ``(sqrt(n) - ||x||_1 / ||x||_2) / (sqrt(n) - 1)``; it is close to 1 for a
    vector with a single non-zero entry and 0 when all entries are equal.

    ``n`` is taken from the latent array itself rather than from the global
    ``DIM``, so the value stays correct for latents of any width.

    Args:
        all_latents: Dict of model name -> ``(latents, labels, positiveness)``
            where ``latents`` is a 2-D array (one row per sample).
        eps: Added to the L2 norm to avoid division by zero.

    Returns:
        Dict of model name -> mean Hoyer sparsity over all rows.
    """
    hoyer_mean = {}

    for model, (latents, _labels, _positiveness) in all_latents.items():
        sqrt_dim = np.sqrt(latents.shape[1])

        norm_ratio = np.linalg.norm(latents, ord=1, axis=1) / (np.linalg.norm(latents, ord=2, axis=1)+eps)
        hoyer = (sqrt_dim - norm_ratio) / (sqrt_dim - 1)

        hoyer_mean[model] = hoyer.mean()

    return hoyer_mean

def plot_hoyer_distribution(all_latents, eps = 1e-6):
    """Plot a histogram of per-sample Hoyer sparsity for every model.

    For a vector ``x`` with ``n`` entries the measure is
    ``(sqrt(n) - ||x||_1 / ||x||_2) / (sqrt(n) - 1)``. ``n`` is taken from the
    latent array itself rather than from the global ``DIM``.

    Args:
        all_latents: Dict of model name -> ``(latents, labels, positiveness)``.
        eps: Added to the L2 norm to avoid division by zero.

    Returns:
        Dict of model name -> mean Hoyer sparsity over all rows.
    """
    total_plots = len(all_latents.keys())

    hoyer_mean = {}

    if total_plots > PLOTS_PER_ROW:
        rows, cols = math.ceil(total_plots / PLOTS_PER_ROW), PLOTS_PER_ROW
    else:
        rows, cols = 1, total_plots

    # squeeze=False keeps `ax` 2-D, so a single model no longer fails on `ax[i]`.
    fig, ax = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)

    fig.suptitle('Hoyer Sparsity of latents')

    for i, model in enumerate(all_latents.keys()):
        latents, labels, positiveness = all_latents[model]

        sqrt_dim = np.sqrt(latents.shape[1])

        hoyer = np.linalg.norm(latents, ord=1, axis=1) / (np.linalg.norm(latents, ord=2, axis=1)+eps)
        hoyer = (sqrt_dim - hoyer) / (sqrt_dim - 1)

        hoyer_mean[model] = hoyer.mean()

        axis = ax[i // cols, i % cols]
        axis.hist(hoyer, bins=50)
        axis.set_title(model)
        axis.set_xlim(0, 1)

    return hoyer_mean

In [82]:
# Interactive computation of the mean Hoyer sparsity per model (no plot);
# skipped when ONLY_PDF is True.
if not ONLY_PDF:
    hoyer_mean = get_hoyer_distribution(all_latents)

In [83]:
# NOTE(review): repeats the plot_positives call of an earlier cell with the
# same threshold (0.05 is the default eps) and discards the result.
if not ONLY_PDF:
    plot_positives(all_latents, eps = 0.05)

In [84]:
from sklearn.cluster import DBSCAN
from scipy.spatial.distance import cdist

def get_best_eps(latents, limit = 1000, percentage = 0.9, steps = 10, min_samples = 7):
    """Bisect for a DBSCAN ``eps`` that clusters about ``percentage`` of the points.

    Only the first ``limit`` rows of ``latents`` are used. The search interval
    starts at ``[0.01, max pairwise distance]`` and is halved ``steps`` times:
    when fewer than ``percentage`` of the points are assigned to a cluster the
    lower bound is raised, otherwise the upper bound is lowered.

    Args:
        latents: 2-D array of points.
        limit: Maximum number of leading rows considered.
        percentage: Target fraction of non-noise points.
        steps: Number of bisection steps.
        min_samples: ``min_samples`` forwarded to DBSCAN.

    Returns:
        The midpoint of the final search interval.
    """
    latents = latents[:limit]
    # Use the number of rows actually clustered; dividing by `limit` understated
    # the clustered fraction whenever fewer than `limit` rows were available.
    n_samples = len(latents)

    min_eps, max_eps = 0.01, np.max(cdist(latents, latents))

    for _ in range(steps):
        eps = (min_eps + max_eps) / 2
        db = DBSCAN(eps=eps, min_samples=min_samples).fit(latents)

        # DBSCAN labels noise points with -1.
        clustered_fraction = np.sum(db.labels_ != -1) / n_samples

        if clustered_fraction < percentage:
            min_eps = eps
        else:
            max_eps = eps

    return (min_eps + max_eps) / 2
    

def get_filtered_data(all_latents, min_samples = 7, eps = 2, percentage = 0.9, batch_size = 3000):
    """Remove DBSCAN noise points from every model's latents.

    For each model an ``eps`` is chosen with ``get_best_eps`` and DBSCAN is then
    run independently on consecutive windows of ``batch_size`` rows; rows
    labelled as noise (-1) are dropped from latents, labels and positiveness.

    Args:
        all_latents: Dict of model name -> ``(latents, labels, positiveness)``.
        min_samples: ``min_samples`` used for the per-window DBSCAN runs.
        eps: Not used; the value is always determined by ``get_best_eps``.
        percentage: Target clustered fraction forwarded to ``get_best_eps``.
        batch_size: Number of rows clustered per DBSCAN run.

    Returns:
        Dict of model name -> filtered ``(latents, labels, positiveness)``.
    """
    filtered_latents = {}

    for _, model in tqdm(enumerate(all_latents.keys()), leave=False):
        latents, labels, positiveness = all_latents[model]

        best_eps = get_best_eps(latents, percentage = percentage)

        # The dataset may be too big for one DBSCAN run, so process it in windows.
        kept_latents, kept_labels, kept_positiveness = [], [], []

        for start in range(0, len(latents), batch_size):
            window = slice(start, start + batch_size)
            clustering = DBSCAN(eps=best_eps, min_samples=min_samples).fit(latents[window])

            # -1 marks noise; keep everything that belongs to a cluster.
            is_clustered = clustering.labels_ != -1

            kept_latents.append(latents[window][is_clustered])
            kept_labels.append(labels[window][is_clustered])
            kept_positiveness.append(positiveness[window][is_clustered])

        filtered_latents[model] = (
            np.concatenate(kept_latents),
            np.concatenate(kept_labels),
            np.concatenate(kept_positiveness),
        )

    return filtered_latents

In [85]:
import numpy as np
from ripser import ripser
from persim import plot_diagrams
from ripser import Rips

# Reset matplotlib's rcParams to their defaults; required to avoid a LaTeX-related error
import matplotlib as mpl
mpl.rcParams.update(mpl.rcParamsDefault)

def get_ph_diagrams(all_latents, max_dim = 1, verbose = True, limit = 1700):
    """Compute and plot Vietoris-Rips persistence diagrams for every model.

    Each model gets two stacked subplots: the persistence diagram (lifetime
    view) and, below it, a histogram of the second column of the dimension-0
    diagram with its last entry dropped.

    Args:
        all_latents: Dict of model name -> ``(latents, labels, positiveness)``.
        max_dim: Maximum homology dimension passed to ``Rips``.
        verbose: Verbosity passed to ``Rips``.
        limit: Maximum number of rows used per model (random subsample without
            replacement); ``None`` uses every row.

    Returns:
        Dict of model name -> list of diagrams as returned by ``Rips.fit_transform``.

    Raises:
        AssertionError: If ``all_latents`` is empty.
    """
    total_plots = len(all_latents.keys())

    assert total_plots > 0

    rips = Rips(maxdim=max_dim, verbose=verbose)

    if total_plots > PLOTS_PER_ROW:
        rows, cols = math.ceil(total_plots / PLOTS_PER_ROW), PLOTS_PER_ROW
    else:
        rows, cols = 1, total_plots

    # Two subplot rows per grid row. squeeze=False keeps `ax` 2-D, so a single
    # model no longer fails on `ax[0, i]`.
    fig, ax = plt.subplots(2 * rows, cols, figsize=(5 * cols, 2 * 5 * rows), gridspec_kw={'hspace': 0.2}, squeeze=False)

    fig.suptitle('pi 0 of latents')

    diagrams_coords = {}

    for i, model in tqdm(enumerate(all_latents.keys()), leave=False, total=total_plots):
        latents = all_latents[model][0]

        if limit is not None and latents.shape[0] > limit:
            rand_indices = np.random.choice(latents.shape[0], limit, replace=False)
        else:
            rand_indices = np.arange(latents.shape[0])

        diagrams = rips.fit_transform(latents[rand_indices])

        top_row = 2 * (i // cols)
        col = i % cols

        plot_diagrams(diagrams, ax=ax[top_row, col], lifetime=True, size=5, title=model)
        # The last dimension-0 entry is excluded -- presumably the component
        # that never dies (infinite value); confirm against ripser's output.
        ax[top_row + 1, col].hist(diagrams[0][:, 1][:-1], bins=100)

        diagrams_coords[model] = diagrams

    return diagrams_coords

In [86]:
# Interactive preview of the dimension-0 persistence diagrams on at most
# 1001 random rows per model; skipped when ONLY_PDF is True.
if not ONLY_PDF:
    _ = get_ph_diagrams(all_latents, max_dim = 0, verbose = False, limit = 1001)

In [87]:
from sklearn.cluster import DBSCAN
from scipy.spatial.distance import cdist
def compute_separability(all_latents, kfactor = 5, batch_size = 32, limit = 4000):
    """Compute a nearest-neighbour separability index for every model.

    For a random subsample of each model's latents, the ``kfactor`` nearest
    neighbours of every point are looked up; the index is the average of
    (a) the fraction of neighbours of positive points that are positive and
    (b) the fraction of neighbours of negative points that are negative.
    The results are also rendered as a table on a new figure.

    Args:
        all_latents: Dict of model name -> ``(latents, labels, positiveness)``.
        kfactor: Number of nearest neighbours considered per point.
        batch_size: Number of query points per distance-matrix chunk.
        limit: Maximum number of rows sampled per model; models with fewer rows
            are used in full (requesting more rows than available used to raise).

    Returns:
        Dict of model name -> separability value.
    """
    figure = plt.figure(figsize=(11.69,8.27))
    figure.suptitle('Separability')

    separability = {}

    for model in all_latents.keys():
        latents_pre, labels_pre, positiveness_pre = all_latents[model]

        # Sampling without replacement cannot draw more rows than exist.
        sample_size = min(limit, latents_pre.shape[0])
        random_indices = np.random.choice(latents_pre.shape[0], sample_size, replace=False)
        latents, positiveness = latents_pre[random_indices], positiveness_pre[random_indices]

        pos_sum, neg_sum = 0, 0

        for chunk in tqdm(range(math.ceil(latents.shape[0]/batch_size)), leave=False):
            start = batch_size * chunk
            stop = start + batch_size

            # Indices of the kfactor nearest neighbours of each query point.
            # Column 0 of the sorted distances is skipped on the assumption
            # that it is the query point itself (distance 0).
            indexes = np.argsort(cdist(latents[start:stop], latents), axis=1)[:, 1 : 1 + kfactor]

            current_positives = positiveness[start:stop] == 1
            current_negatives = positiveness[start:stop] == 0

            pos_sum += np.sum(positiveness[indexes[current_positives]] == 1)
            neg_sum += np.sum(positiveness[indexes[current_negatives]] == 0)

        total_pos = np.sum(positiveness == 1) * kfactor
        total_neg = np.sum(positiveness == 0) * kfactor

        separability[model] = ((pos_sum / total_pos) + (neg_sum / total_neg))/2

    plt.axis('off')
    plt.table(cellText=[[k, v] for k, v in separability.items()], colLabels=["Model", "Separability"], loc='center')

    return separability

In [88]:
# Interactive computation of the separability table on 1000 random rows per
# model; skipped when ONLY_PDF is True.
if not ONLY_PDF:
    separability_mean = compute_separability(all_latents, batch_size=128, limit=1000)

## Create a PDF report containing all of the plots above

In [89]:
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages

import os

def create_report(filename = "report.pdf", use_train = False, normalize_over_mean = False):
    """Write a multi-page PDF with the latent-space plots of all loaded models.

    Latents are collected for every model in the module-level ``current_models``
    and one page is written per analysis: 2-D embedding (t-SNE, plus UMAP when
    ``USE_UMAP`` is set), raw latents, active-unit histograms, Hoyer sparsity,
    persistence diagrams and, when ``USE_FILTER_TDA`` is set, persistence
    diagrams of the DBSCAN-filtered latents.

    Args:
        filename: Name of the PDF created inside ``reports/<DATASET>/``.
        use_train: Collect latents from the train loader instead of the test loader.
        normalize_over_mean: Forwarded to ``get_all_latents`` as ``normalize_mean``.
            Inside this function the name shadows the module-level function of
            the same name; it is kept because callers pass it by keyword.

    Returns:
        None. The PDF is written to disk and a confirmation is printed.
    """
    all_latents = get_all_latents(current_models, use_train=use_train, normalize_mean=normalize_over_mean)

    report_dir = os.path.join('reports', DATASET)
    os.makedirs(report_dir, exist_ok=True)

    with PdfPages(os.path.join(report_dir, filename)) as pdf:

        def save_page():
            """Append the current figure to the PDF as a new page and close it."""
            pdf.savefig()
            plt.close()

        _ = get_tsne(all_latents, limit = 1500, verbose=0, plot_possetive=True)
        save_page()

        if USE_UMAP:
            _ = get_tsne(all_latents, limit = 1800, verbose=0, plot_possetive=True, use_umap=True)
            save_page()

        _ = plot_latents(all_latents, amount = 512 * 2, range=(0, DIM))
        save_page()

        plot_positives(all_latents)
        save_page()

        plot_hoyer_distribution(all_latents)
        save_page()

        _ = get_ph_diagrams(all_latents, max_dim = 0, verbose = False, limit = 5000)
        save_page()

        #_ = compute_separability(all_latents, batch_size=128*2, limit=512*10)
        #pdf.savefig()
        #plt.close()

        if USE_FILTER_TDA:
            FILTER_PERCENTAGE = 0.95
            all_filtered_latents = get_filtered_data(all_latents, percentage=FILTER_PERCENTAGE)

            _ = get_ph_diagrams(all_filtered_latents, max_dim = 0, verbose = False, limit = 6500)
            # Title the whole page: plt.title would only label the most
            # recently created subplot of the grid.
            plt.gcf().suptitle(f"Persistence Diagrams Filtered ({FILTER_PERCENTAGE})")
            save_page()

    print(f'PDF file "{filename}" created successfully.')

# Write the report for the test split when CREATE_PDF is enabled; the two
# commented-out calls are the train-split variants (plain and mean-normalised).
if CREATE_PDF:
    create_report(filename = f"report_{DATASET}_3_bigtest.pdf", use_train = False)
    #create_report(filename = f"report_{DATASET}_train_3_final.pdf", use_train = True)
    #create_report(filename = f"report_{DATASET}_train_meaned_3_final.pdf", use_train = True, normalize_over_mean=True)

    # Raising here aborts the cell once the report is written -- presumably to
    # stop a "Run All" at this point.
    # NOTE(review): the notebook therefore always ends in an error state while
    # CREATE_PDF is True; consider removing this once no later cell needs skipping.
    raise Exception("Finished")

                                                 

PDF file "report_mnist_3_bigtest.pdf" created successfully.


Exception: Finished