In [1]:
import numpy as np
import scipy
import pickle as pkl
import matplotlib.pyplot as plt
from nystrom import *
from utils import *
import os

In [3]:
def inv_(x, tol=1e-9):
    """Element-wise reciprocal that maps entries with ``|x| <= tol`` to 0.

    Parameters
    ----------
    x : array_like
        Values to invert (scalars, real or complex arrays).
    tol : float, optional
        Magnitude threshold; entries whose absolute value is not strictly
        greater than ``tol`` yield 0 instead of ``1/x``.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``x`` and an inexact (float/complex) dtype.
    """
    x = np.asarray(x)
    # Always allocate an inexact output. The previous np.vectorize version
    # inferred the dtype from the first element, so a leading near-zero entry
    # (int 0 result) truncated every later reciprocal to an integer, and an
    # empty input raised ValueError.
    out = np.zeros(x.shape, dtype=np.result_type(x.dtype, np.float64))
    # `where` leaves the masked-out slots at their initial 0 and never
    # evaluates 1/0, so no divide-by-zero warning is emitted.
    np.divide(1, x, out=out, where=np.abs(x) > tol)
    return out

def _offdiag_sum_of_squares(matrix):
    """Return the sum of squared entries of a square matrix, diagonal excluded."""
    squared = np.square(matrix)  # fresh array, safe to edit in place
    np.fill_diagonal(squared, 0)
    return squared.sum()


def compare_sampling_methods(kernel,
                             n_candidates=5793,
                             sampling_methods=None,
                             eval_landmarks=(256,362,512,724,1024,1448,2048,2896,4096),
                             eval_eigenvals=(16,23,32,45,64,91,128,181,256,
                                             362,512,724,1024,1448,2048,2896,4096)):
    """Score truncated Nystrom reconstructions of `kernel` for several landmark samplers.

    For every sampler, landmarks are drawn once from the leading
    ``n_candidates x n_candidates`` block of `kernel`; each entry of
    `eval_landmarks` then uses the first ``n_landmarks`` of those indices.
    The landmark sub-kernel is eigendecomposed, eigenpairs are ordered by
    decreasing ``|eigenvalue|``, and for each ``n_eigenvals <= n_landmarks``
    the kernel is reconstructed as ``K[:, L] @ pinv_k(K[L, L]) @ K[L, :]``
    using only the leading ``n_eigenvals`` eigenpairs (eigenvalues with
    magnitude <= 1e-9 are treated as zero).

    The score is ``1 - SS_res / SS_tot`` computed on the block of `kernel`
    whose rows and columns are *not* landmarks, ignoring that block's diagonal
    in both sums. The mean subtracted in ``SS_tot`` is taken over the whole
    held-out block, diagonal included.

    Parameters
    ----------
    kernel : numpy.ndarray
        2-D kernel matrix; indexed with the same indices on both axes, so it
        is assumed to be square.
    n_candidates : int
        Size of the leading block handed to each sampler.
    sampling_methods : sequence of callables, optional
        Each is called as ``f(sub_kernel, n, return_ixs=True)`` and must return
        a pair whose second element is an index sequence into ``sub_kernel``.
        Defaults to ``sample_landmarks_kmeans``, ``sample_landmarks_det`` and
        ``sample_landmarks_random``.
    eval_landmarks : sequence of int
        Landmark counts to evaluate.
    eval_eigenvals : sequence of int
        Truncation ranks to evaluate.

    Returns
    -------
    dict
        ``results[method.__name__][n_landmarks][n_eigenvals] -> R^2``.
    """
    if sampling_methods is None:
        # Resolved at call time instead of being baked into a shared mutable
        # default list at definition time.
        sampling_methods = [sample_landmarks_kmeans, sample_landmarks_det, sample_landmarks_random]

    candidates = kernel[:n_candidates][:, :n_candidates]
    # Sample enough landmarks for the largest request even when
    # `eval_landmarks` is not sorted (previously its last entry was used).
    max_landmarks = max(eval_landmarks)

    results = {}
    for sampling_method in sampling_methods:
        method_name = sampling_method.__name__
        print(f'sampling ({method_name})')
        _, all_ixs = sampling_method(candidates, max_landmarks, return_ixs=True)
        results[method_name] = {}
        for n_landmarks in eval_landmarks:
            print(f'{n_landmarks=}')
            ixs = np.asarray(all_ixs[:n_landmarks], dtype=np.intp)

            eigenvals, eigenvecs = scipy.linalg.eigh(kernel[ixs][:, ixs])
            order = np.flip(np.argsort(np.abs(eigenvals)))
            eigenvals = eigenvals[order]
            eigenvecs = eigenvecs[:, order]

            # Thresholded reciprocal eigenvalues (|lambda| <= 1e-9 -> 0).
            inv_eigenvals = np.zeros_like(eigenvals)
            np.divide(1.0, eigenvals, out=inv_eigenvals, where=np.abs(eigenvals) > 1e-9)

            # Only rows/columns that are not landmarks are scored, so build the
            # reconstruction for that block alone instead of the full matrix.
            keep = np.ones(kernel.shape[0], dtype=bool)
            keep[ixs] = False
            held_out = kernel[keep][:, keep]
            left = kernel[keep][:, ixs] @ eigenvecs      # (held-out, landmarks)
            right = eigenvecs.T @ kernel[ixs][:, keep]   # (landmarks, held-out)

            # Denominator of R^2 does not depend on the truncation rank.
            total_ss = _offdiag_sum_of_squares(held_out - held_out.mean())

            results[method_name][n_landmarks] = {}

            for n_eigenvals in eval_eigenvals:
                if n_eigenvals <= n_landmarks:
                    # Column scaling replaces the dense np.diag(...) product.
                    reconstructed = (left[:, :n_eigenvals] * inv_eigenvals[:n_eigenvals]) @ right[:n_eigenvals]
                    r_squared_truncated = 1 - _offdiag_sum_of_squares(held_out - reconstructed) / total_ss
                    print(f'{n_eigenvals=}: {r_squared_truncated}')
                    results[method_name][n_landmarks][n_eigenvals] = r_squared_truncated
            print()
    return results

In [4]:
# Fixed seed for the row/column shuffle so that a cache miss regenerates the
# same permutation. Randomness inside the samplers themselves is not
# controlled from here.
SHUFFLE_SEED = 0

# Legend text per sampler name; unknown samplers fall back to their own name
# instead of being labelled "optimal subspace".
METHOD_LABELS = {
    'sample_landmarks_random': 'random',
    'sample_landmarks_kmeans': 'k-means',
    'sample_landmarks_det': 'optimal subspace',
}

# Tick positions shown on the x axis.
LANDMARK_TICKS = [256,512,724,1024,1448,2048,2896,4096]

for dataset in ['8_10000_25','16_10000_15']:
    for distance, iterated in [(False,False), (True,False)]:
        run_name = dataset_string(dataset,distance,iterated)

        print('##################################################')
        print(f'Comparing sampling methods for {run_name}')
        print('##################################################')
        print()

        results_path = f'../data/sampling_methods_compare_results_{run_name}.pkl'

        # The comparison is expensive, so reuse a cached result when present.
        if os.path.isfile(results_path):
            # NOTE: pickle.load can execute arbitrary code; only load cache
            # files produced by this notebook.
            with open(results_path,'rb') as f:
                results = pkl.load(f)
        else:
            kernel = load_kernel_matrix(dataset,distance,iterated)
            # compare_sampling_methods draws candidates from the leading block,
            # so shuffle rows/columns consistently first.
            kernel_shuffle = np.random.default_rng(SHUFFLE_SEED).permutation(kernel.shape[0])
            kernel = kernel[kernel_shuffle][:,kernel_shuffle]

            results = compare_sampling_methods(kernel)

            os.makedirs(os.path.dirname(results_path), exist_ok=True)
            with open(results_path,'wb') as f:
                pkl.dump(results,f)

        # For each method and landmark count keep the best R^2 over all
        # truncation ranks, together with the rank that achieved it.
        # Layout per method: (landmark counts, best R^2, best rank).
        plot_results = {}

        for method, method_res in results.items():
            method_results = ([],[],[])
            for n_landmarks, landmarks_res in method_res.items():
                r_squared_max = -np.inf
                argmax = None
                for n_eigs, r_squared in landmarks_res.items():
                    if r_squared > r_squared_max:
                        r_squared_max = r_squared
                        argmax = n_eigs
                method_results[0].append(n_landmarks)
                method_results[1].append(r_squared_max)
                method_results[2].append(argmax)
            # Methods without any evaluated landmark count are left out.
            if method_results[0]:
                plot_results[method] = method_results

        plt.figure(figsize=(12,8))
        plt.rcParams.update({'font.size': 13})

        if distance:
            plt.title(f'$R^2$ for different sampling methods (by #landmarks, optimally truncated), for $\\delta$ on {DATASET_MAP[dataset]}')
        else:
            plt.title(f'$R^2$ for different sampling methods (by #landmarks, optimally truncated), for $K{"^{(2)}" if iterated else ""}$ on {DATASET_MAP[dataset]}')
        plt.gca().set_ylabel('$R^2$')
        plt.gca().set_xlabel('# landmarks')

        cmap = plt.cm.magma  # define the colormap
        cmaplist = [cmap(i) for i in range(cmap.N)] # extract all colors from the colormap

        first = True

        for i,(method, method_results) in enumerate(plot_results.items()):
            print(method,max(method_results[1]))
            # Spread the methods evenly over the interior of the colormap.
            # With three methods and a 256-entry map this gives 64/128/192,
            # and it cannot run past the end when more methods are added.
            color_ix = (i + 1) * cmap.N // (len(plot_results) + 1)
            plt.plot(method_results[0],method_results[1], lw=2, marker='v', ms=8, color=cmaplist[color_ix], label=METHOD_LABELS.get(method, method))
            # Offset the rank annotations so the curves' labels do not overlap.
            if i <= 1:
                shift_x = 25
                shift_y = -.0005
            else:
                shift_x = -100
                shift_y = .00025
            for x in zip(*method_results):
                # Only the very first annotation explains what the number means.
                if first:
                    first = False
                    annotation = f'{x[2]} (optimal truncation)'
                else:
                    annotation = f'{x[2]}'
                plt.annotate(annotation, (x[0]+shift_x,x[1]+shift_y))

        plt.xticks(LANDMARK_TICKS,LANDMARK_TICKS)
        plt.minorticks_off()
        plt.tight_layout()
        plt.legend()
        plot_path = f'../plots/sampling_comparison_{run_name}.png'
        os.makedirs(os.path.dirname(plot_path), exist_ok=True)
        plt.savefig(plot_path,format='png',dpi=300)
        # plt.show()
        plt.close()

##################################################
Comparing sampling methods for 8_10000_25
##################################################

sample_landmarks_kmeans 0.988598507828863
sample_landmarks_det 0.9889041095936378
sample_landmarks_random 0.9850736015755234
##################################################
Comparing sampling methods for 8_10000_25_distance
##################################################

sample_landmarks_kmeans 0.9618390439474204
sample_landmarks_det 0.9620433227999934
sample_landmarks_random 0.9448636727042176
##################################################
Comparing sampling methods for 16_10000_15
##################################################

sample_landmarks_kmeans 0.880030241123318
sample_landmarks_det 0.8807795387757664
sample_landmarks_random 0.8653330336988072
##################################################
Comparing sampling methods for 16_10000_15_distance
##################################################

sample_landmarks_kmeans 