In [1]:
import numpy as np
import scipy
import pickle as pkl
from matplotlib import pyplot as plt

from utils import *
from nystrom import *

import os

In [3]:
def inv_(x, tol=1e-9):
    """Elementwise pseudo-reciprocal: ``1/x`` where ``|x| > tol``, else ``0``.

    Replaces an ``np.vectorize`` lambda. That version inferred the output dtype
    from the first element, so a (near-)zero first element returned the int 0
    and truncated every other reciprocal to an integer. It also raised on
    empty input.

    Parameters
    ----------
    x : array_like
        Values to invert (real or complex).
    tol : float, optional
        Magnitudes at or below this threshold are treated as zero and map to 0.

    Returns
    -------
    numpy.ndarray
        Array of ``x``'s shape with a floating (or complex) dtype; a 0-d array
        for scalar input.
    """
    x = np.asarray(x)
    # Always produce an inexact dtype so reciprocals are never truncated.
    out = np.zeros(x.shape, dtype=np.result_type(x.dtype, float))
    # Only divide where |x| is safely non-zero; other entries keep their 0.
    np.divide(1, x, out=out, where=np.abs(x) > tol)
    return out

def analyse_sampling_method(kernel,
                            sampling_method=sample_landmarks_random,
                            n_landmarks=1024,
                            n_eigenvals=(128,)):
    """Score the (truncated) indefinite Nyström reconstruction of ``kernel``.

    ``sampling_method`` is called as
    ``sampling_method(kernel, n_landmarks, return_ixs=True)`` and must return
    ``(subkernel, ixs)``. Presumably ``subkernel == kernel[ixs][:, ixs]``
    (TODO confirm against nystrom.py).

    The full approximation is ``kernel[:, ixs] @ pinvh(subkernel) @ kernel[ixs]``.
    For each ``n_eigs`` in ``n_eigenvals``, a truncated variant keeps only the
    ``n_eigs`` eigenpairs of ``subkernel`` with the largest |eigenvalue|, so
    negative eigenvalues (indefinite kernels) are retained by magnitude.

    Quality is an R^2 score over the off-diagonal entries of the non-landmark
    block: landmark rows/columns and the diagonal are excluded from both sums.

    Parameters
    ----------
    kernel : ndarray
        Square kernel (or distance) matrix.
    sampling_method : callable
        Landmark selector, see above.
    n_landmarks : int
        Number of landmarks to sample.
    n_eigenvals : iterable of int
        Truncation levels to evaluate.

    Returns
    -------
    r_squared : float
        R^2 of the untruncated (pseudo-inverse) approximation.
    truncation_r_squared : dict
        ``{n_eigs: R^2}`` for each requested truncation level.

    Prints one ``"{n_eigs}: {R^2}"`` line per truncation level.
    """
    subkernel, ixs = sampling_method(kernel, n_landmarks, return_ixs=True)

    # Eigendecomposition of the symmetric landmark block, ordered by decreasing |eigenvalue|.
    eigenvals, eigenvecs = scipy.linalg.eigh(subkernel)
    order = np.flip(np.argsort(np.abs(eigenvals)))
    eigenvals = eigenvals[order]
    eigenvecs = eigenvecs[:, order]

    def masking(x):
        # Remove the landmark rows and columns, which the approximation reproduces by construction.
        return np.delete(np.delete(x, ixs, axis=0), ixs, axis=1)

    def clear_diag(x):
        # Zero the diagonal so it does not contribute to the sums.
        return x - np.diag(np.diag(x))

    # Loop-invariant pieces, computed once instead of per truncation level.
    landmark_cols = kernel[:, ixs]
    landmark_rows = kernel[ixs]
    masked_kernel = masking(kernel)
    # NOTE(review): the mean below includes the diagonal entries that clear_diag
    # later drops from SS_tot. Kept as-is so results match previously cached runs.
    ss_tot = clear_diag(np.square(masked_kernel - np.mean(masked_kernel))).sum()

    def off_diag_r_squared(approx):
        """R^2 of ``approx`` against ``kernel`` on the masked off-diagonal entries."""
        ss_res = clear_diag(np.square(masking(kernel - approx))).sum()
        return 1 - ss_res / ss_tot

    r_squared = off_diag_r_squared(
        landmark_cols @ scipy.linalg.pinvh(subkernel) @ landmark_rows)

    truncation_r_squared = {}
    for n_eigs in n_eigenvals:
        top_vecs = eigenvecs[:, :n_eigs]
        # Rank-n_eigs pseudo-inverse of the landmark block (near-zero eigenvalues map to 0).
        trunc_pinv = top_vecs @ np.diag(inv_(eigenvals[:n_eigs])) @ top_vecs.T
        r_squared_truncated = off_diag_r_squared(landmark_cols @ trunc_pinv @ landmark_rows)
        truncation_r_squared[n_eigs] = r_squared_truncated
        print(f'{n_eigs}: {r_squared_truncated}')

    return r_squared, truncation_r_squared


    

In [4]:
# Landmark counts and truncation levels swept by the analysis (roughly powers of sqrt(2)).
NYSTROM_SIZE_GRID = [2, 3, 4, 6, 8, 11, 16, 23, 32, 45, 64, 91, 128, 181, 256,
                     362, 512, 724, 1024, 1448, 2048, 2896, 4096, 5793, 8192]
# Power-of-two tick positions (and labels) used on both log axes of the plot.
NYSTROM_TICKS = [2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]


def _snap_to_sqrt2_grid(x):
    """Round ``x`` to the nearest integer power of sqrt(2), rounding in log space."""
    return np.sqrt(2) ** np.round(np.log(x) / np.log(np.sqrt(2)))


def _compute_truncation_results(dataset, distance, iterated):
    """Load and shuffle the kernel, then run analyse_sampling_method for every grid size."""
    kernel = load_kernel_matrix(dataset, distance, iterated)

    # Symmetric random permutation of rows/columns; uses the global NumPy RNG,
    # so seed it upstream if reproducibility is needed.
    perm = np.random.permutation(kernel.shape[0])
    kernel = kernel[perm][:, perm]

    results = {}
    for n_landmarks in NYSTROM_SIZE_GRID:
        print(f'{n_landmarks=}')
        n_eigenvals = [x for x in NYSTROM_SIZE_GRID if x <= n_landmarks]
        results[n_landmarks] = analyse_sampling_method(
            kernel, n_landmarks=n_landmarks, n_eigenvals=n_eigenvals)
    print()
    return results


def _collect_truncation_points(results):
    """Flatten ``results`` into scatter coords/colours plus the best truncation per landmark count."""
    plot_x, plot_y, plot_c = [], [], []
    argmaxs = []
    for n_landmarks, entry in results.items():
        truncation_r_squared = entry[1]
        x = _snap_to_sqrt2_grid(n_landmarks)
        best_r_sq, best_eigs = -np.inf, None
        for n_eigs, r_sq in truncation_r_squared.items():
            y = _snap_to_sqrt2_grid(n_eigs)
            plot_x.append(x)
            plot_y.append(y)
            plot_c.append(r_sq)
            if r_sq > best_r_sq:
                best_r_sq, best_eigs = r_sq, y
        argmaxs.append((x, best_eigs))
    return plot_x, plot_y, plot_c, argmaxs


def _plot_truncation_heatmap(dataset, distance, iterated, results, show):
    """Draw the R^2 grid with the optimal-truncation curve and save it under ../plots/."""
    plot_x, plot_y, plot_c, argmaxs = _collect_truncation_points(results)
    tick_labels = [str(t) for t in NYSTROM_TICKS]
    symbol = '$\\delta$' if distance else '$K$'

    # Scoped font size: do not leak rcParams changes into later figures.
    with plt.rc_context({'font.size': 13}):
        fig, ax = plt.subplots(figsize=(12, 10))
        ax.set_xscale('log')
        ax.set_yscale('log')
        # Large square markers tile the sqrt(2) grid to form a heat map.
        points = ax.scatter(plot_x, plot_y, s=575, c=plot_c, marker=',', edgecolors='none',
                            cmap='magma', vmin=.6, vmax=1)
        fig.colorbar(points, ax=ax)

        ax.plot(*zip(*argmaxs), color='lightslategray', lw=3, label='optimal truncation')
        ax.legend(loc='upper left')
        ax.set_xlabel('# landmarks (log-scale)')
        ax.set_ylabel('# eigenvals (log-scale)')
        ax.set_xticks(NYSTROM_TICKS)
        ax.set_xticklabels(tick_labels)
        ax.set_yticks(NYSTROM_TICKS)
        ax.set_yticklabels(tick_labels)
        ax.minorticks_off()
        ax.set_title(f'$R^2$ for the truncated indefinite Nyström approximation for {symbol} on {DATASET_MAP[dataset]}')
        fig.tight_layout()
        fig.savefig(f'../plots/nystrom_truncation_{dataset_string(dataset,distance,iterated)}.png',
                    format='png', dpi=300)
        if show:
            plt.show()
    plt.close(fig)


def analyse_kernel_sampling(dataset, distance=False, iterated=False, show=False, results=None):
    """Sweep landmark counts / eigenvalue truncations and plot the Nyström R^2 grid.

    Parameters
    ----------
    dataset : str
        Dataset identifier forwarded to ``load_kernel_matrix``, ``dataset_string``
        and ``DATASET_MAP``.
    distance, iterated : bool
        Forwarded to ``load_kernel_matrix`` / ``dataset_string``. ``distance``
        also selects the matrix symbol in the title.
    show : bool
        Call ``plt.show()`` before the figure is closed.
    results : dict, optional
        Precomputed ``{n_landmarks: (r_squared, {n_eigs: r_squared})}``. When
        given, no kernel is loaded and nothing is recomputed; only the plot is drawn.

    Returns
    -------
    dict
        The (computed or supplied) results.

    Side effects: prints progress when computing and writes a PNG to ../plots/.
    """
    if results is None:
        results = _compute_truncation_results(dataset, distance, iterated)
    _plot_truncation_heatmap(dataset, distance, iterated, results, show)
    return results

In [5]:
# Run (or reload) the truncation analysis for each dataset, for both the kernel K and the
# distance matrix. Expensive results are cached in ../data/ and reused on later runs.
for dataset in ['16_10000_15', '8_10000_25']:
    for distance, iterated in [(False, False), (True, False)]:
        ds_name = dataset_string(dataset, distance, iterated)
        print('##################################################')
        print(f'Analysing {ds_name}')
        print('##################################################')
        print()

        cache_path = f'../data/nystrom_analysis_results_{ds_name}.pkl'
        cached = os.path.isfile(cache_path)
        if cached:
            # pickle.load can execute arbitrary code: only use trusted, locally produced caches.
            with open(cache_path, 'rb') as f:
                results = pkl.load(f)
        else:
            results = None

        results = analyse_kernel_sampling(dataset, distance, iterated, False, results)

        # Persist freshly computed results so the next run skips the expensive sweep.
        if not cached:
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            with open(cache_path, 'wb') as f:
                pkl.dump(results, f)
        


##################################################
Analysing 16_10000_15
##################################################

##################################################
Analysing 16_10000_15_distance
##################################################

##################################################
Analysing 8_10000_25
##################################################

##################################################
Analysing 8_10000_25_distance
##################################################

