In [None]:
import os

from IPython.display import clear_output
from google.colab import drive

# Mount Google Drive so that the CSV/PDF results outlive the Colab session.
drive.mount('/content/gdrive')
results_path = '/content/gdrive/MyDrive/LieStationaryKernel_plots/'
# Create the output folder up front: the experiment functions only write into it after a
# long computation, and `to_csv` / `savefig` raise if the directory does not exist.
os.makedirs(results_path, exist_ok=True)

In [None]:
!pip install git+https://github.com/imbirik/LieStationaryKernels.git
!pip install backends
!pip install git+https://github.com/vdutor/SphericalHarmonics.git
clear_output()

In [None]:
import sys
import itertools
import torch
import gc
import matplotlib.pyplot as plt
import random
import seaborn as sns
from torch.autograd.functional import _vmap as vmap
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm
from math import sqrt
from collections import defaultdict
from itertools import islice
from lie_stationary_kernels.spectral_kernel import EigenbasisSumKernel
from lie_stationary_kernels.spectral_measure import SqExpSpectralMeasure, MaternSpectralMeasure
from lie_stationary_kernels.prior_approximation import RandomPhaseApproximation

from lie_stationary_kernels.space import TranslatedCharactersBasis

from lie_stationary_kernels.spaces import SO, SU, Stiefel
from lie_stationary_kernels.spaces.sphere import Sphere

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Floating-point precision used for the tensors allocated by the experiments below.
dtype = torch.float64

# Wider, fixed-point tensor printing for easier inspection of kernel matrices.
torch.set_printoptions(
    precision=6,
    sci_mode=False,
    linewidth=160,
    edgeitems=15,
)

print(device)

In [None]:
import os
import shutil

import lie_stationary_kernels.spaces as _lsk_spaces

# Put the precomputed character table next to the installed `lie_stationary_kernels.spaces`
# package. The destination is derived from the imported package rather than a hard-coded
# `python3.10/dist-packages` path, so the cell keeps working when the runtime's Python
# version (and therefore the install location) changes.
precomputed_characters_src = '/content/gdrive/MyDrive/precomputed_characters.json'
precomputed_characters_dst = os.path.join(
    os.path.dirname(_lsk_spaces.__file__), 'precomputed_characters.json'
)
shutil.copy(precomputed_characters_src, precomputed_characters_dst);

In [None]:
# auxiliary functions
def compute_l2_diff(x, y):
    """Return the root-mean-square difference between tensors `x` and `y`.

    The element-wise difference is squared, averaged over all entries and the
    square root of that mean is returned as a 0-dimensional tensor.
    """
    squared_gap = (x - y) ** 2
    return squared_gap.mean() ** 0.5

def raw_character(chi, average_order, gammas_x_h):
    """Evaluate `chi` on `gammas_x_h` and arrange the result in rows of `average_order` values.

    Args:
        chi: callable applied to `gammas_x_h`; its result must support `.reshape`.
        average_order: number of columns of the returned 2-D array.
        gammas_x_h: input passed to `chi` unchanged.

    Returns:
        `chi(gammas_x_h)` reshaped to `(-1, average_order)`.
    """
    evaluated = chi(gammas_x_h)
    return evaluated.reshape((-1, average_order))

def get_order(n):
    """Return a permutation of the row-major indices of an n x n grid, ordered by growing squares.

    The list starts with index 0 and, for every k from 2 to n, appends the indices that
    extend the top-left (k-1) x (k-1) square to the top-left k x k square: first the new
    column (rows 0..k-2 of column k-1), then the new row (row k-1, columns 0..k-1).
    Consequently the first k*k entries are exactly the indices of the top-left k x k square.

    For n <= 1 the result is `[0]`.
    """
    order = [0]
    for side in range(2, n + 1):
        last = side - 1
        # Column `last`, restricted to the rows that already belong to the square.
        order.extend(row * n + last for row in range(last))
        # Row `last`, columns 0..last (row-major indices are consecutive).
        row_start = last * n
        order.extend(range(row_start, row_start + side))
    return order

def pairwise_embed(self, x, y):
    """For arrays of the form x_iH, y_jH compute the embedding corresponding to x_i, y_j,
    i.e. a flattened array of the form G.embed(h_m^{-1} x_i^{-1} y_j h_k).

    `self` is the manifold object (it must provide `pairwise_diff`, `g`, `h_samples` and
    `average_order`). Compared with the raw output of `self.g.pairwise_embed`, the
    `average_order ** 2` entries belonging to each (x_i, y_j) pair are permuted with
    `get_order(self.average_order)` before the array is flattened back to 2-D.
    NOTE(review): presumably this makes the first k*k entries of every pair correspond to
    the first k samples of H on both sides - confirm against the library's layout.
    """
    n_avg = self.average_order

    x_y_ = self.pairwise_diff(x, y)
    # h^{-1}-shifted differences: inv(pairwise_diff(inv(x^{-1}y), h)).
    shifted = self.g.pairwise_diff(self.g.inv(x_y_), self.h_samples)
    shifted = self.g.inv(shifted)
    flat_embed = self.g.pairwise_embed(shifted, self.h_samples)

    feature_dim = flat_embed.shape[-1]
    per_pair = flat_embed.reshape(len(x) * len(y), n_avg * n_avg, -1)
    reordered = per_pair[:, get_order(n_avg), :]
    return reordered.reshape(-1, feature_dim)

In [None]:
def stiefel_kernel_approx_error(measure_class, measure_kwargs, measure_name, manifold_class, manifold_kwargs, manifold_name,
                                num_tries=20, average_order_sqrt=150, num_points=20):
    """Measure how the sample-averaged kernel converges as more samples are averaged.

    For every trial a manifold, a spectral measure and an `EigenbasisSumKernel` are built,
    `num_points` random points are drawn and the characters of all pairwise embeddings are
    evaluated. `values[k]` is the kernel matrix obtained by averaging the first `k` character
    samples (`values[0]` stays zero); the error of `values[k]` is its root-mean-square
    difference to `values[-1]`, the estimate using `average_order - 1` samples.

    The per-trial error curves are written to `results_path` as
    `{manifold_name}_{measure_name}_2.csv`; mean and 25%/75% quantiles over trials (first
    1500 entries) are plotted and saved as `{manifold_name}_{measure_name}_2.pdf`.

    Args:
        measure_class: callable building the spectral measure as
            `measure_class(manifold.dim, **measure_kwargs)`.
        measure_kwargs: keyword arguments for `measure_class`.
        measure_name: short name used in the output file names.
        manifold_class: callable building the manifold as
            `manifold_class(average_order=average_order_sqrt, **manifold_kwargs)`.
        manifold_kwargs: keyword arguments for `manifold_class`.
        manifold_name: short name used in the output file names.
        num_tries: number of independent trials; trial `t` seeds all RNGs with `t`.
        average_order_sqrt: value passed to the manifold as `average_order`; the number of
            character samples per pair of points is its square.
        num_points: number of random points drawn on the manifold per trial.
    """
    average_order = average_order_sqrt ** 2
    all_errors = []
    for trial in tqdm(range(num_tries)):
        # Seed every RNG in use so that each trial is reproducible on its own.
        torch.manual_seed(trial)
        np.random.seed(trial)
        random.seed(trial)

        manifold = manifold_class(average_order=average_order_sqrt, **manifold_kwargs)
        measure = measure_class(manifold.dim, **measure_kwargs)
        space_kernel = EigenbasisSumKernel(measure=measure, manifold=manifold)

        x = manifold.rand(num_points)
        y = x
        x_y_embed = pairwise_embed(manifold, x, y)
        values = [torch.zeros(len(x), len(y), dtype=dtype, device=device) for _ in range(average_order)]

        # Nothing here is differentiated; disabling autograd avoids recording a graph for
        # every in-place accumulation in case the measure carries learnable parameters.
        with torch.no_grad():
            normalizer = space_kernel.normalizer
            for eigenspace in tqdm(manifold.lb_eigenspaces):
                weight = measure(eigenspace.lb_eigenvalue)
                chis_x_y_h = raw_character(eigenspace.phase_function.chi, average_order, x_y_embed)

                # Means of the first 1..average_order-1 samples for every pair of points,
                # obtained from one cumulative sum instead of re-averaging each prefix.
                # Only the real part is accumulated, so it is taken before summing.
                running_sum = torch.cumsum(chis_x_y_h.real[:, :-1], dim=-1)
                counts = torch.arange(1, average_order, dtype=running_sum.dtype, device=running_sum.device)
                prefix_means = (running_sum / counts).T.reshape(average_order - 1, len(x), len(y))

                for k in range(average_order - 1):
                    values[k + 1] += weight * prefix_means[k] / normalizer

            reference = values[-1]
            errors = [compute_l2_diff(reference, approx).detach().cpu().numpy() for approx in values]

        all_errors.append(errors)

    all_errors = np.array(all_errors)
    pd.DataFrame(data=all_errors).to_csv(results_path + f'{manifold_name}_{measure_name}_2.csv')
    mean_error = np.mean(all_errors, axis=0)
    quantile_error_25 = np.quantile(all_errors, 0.25, axis=0)
    quantile_error_75 = np.quantile(all_errors, 0.75, axis=0)

    # Explicit figure: repeated calls never draw on top of a previous call's axes.
    fig, ax = plt.subplots()
    ax.plot(mean_error[:1500], c='blue', label='mean')
    ax.plot(quantile_error_25[:1500], c='red', label='q25')
    ax.plot(quantile_error_75[:1500], c='green', label='q75')
    ax.legend()
    fig.savefig(results_path + f'{manifold_name}_{measure_name}_2.pdf')
    plt.show()
    plt.close(fig)

In [None]:
# Start the sweep with as much free memory as possible.
torch.cuda.empty_cache()
gc.collect()

# Each entry: (measure class, constructor kwargs, short name used in output file names).
measures = [
    (MaternSpectralMeasure, dict(lengthscale=0.4, nu=0.5), 'matern12'),
    (MaternSpectralMeasure, dict(lengthscale=0.6, nu=1.5), 'matern32'),
    (MaternSpectralMeasure, dict(lengthscale=0.7, nu=2.5), 'matern52'),
    (SqExpSpectralMeasure, dict(lengthscale=0.95), 'heat'),
]

# Each entry: (manifold class, constructor kwargs, short name used in output file names).
groups = [
    (Stiefel, dict(n=5, m=2, order=20), 'stiefel52'),
    (Stiefel, dict(n=5, m=3, order=20), 'stiefel53'),
]

# Every measure is evaluated on every manifold (measures vary slowest).
for measure, group in itertools.product(measures, groups):
    print(measure, group)
    stiefel_kernel_approx_error(*measure, *group)
    # Free memory held by the finished run before the next one starts.
    torch.cuda.empty_cache()
    gc.collect()


In [None]:
def sphere_kernel_approx_error(measure_class, measure_kwargs, measure_name, manifold_class, manifold_kwargs, manifold_name,
                               num_tries=20, average_order_sqrt=150, num_points=20):
    """Compare the sample-averaged kernel on the manifold with the kernel of a `Sphere`.

    For every trial a manifold and a spectral measure are built and `num_points` random
    points are drawn. The reference matrix `cov` is an `EigenbasisSumKernel` on
    `Sphere(n=manifold.n - 1, order=10)` evaluated at the points reshaped to `(-1, n)`.
    `values[k]` is the kernel matrix obtained by averaging the first `k` character samples
    over (at most) 10 selected eigenspaces of the manifold, normalised by the sphere
    kernel's normalizer (`values[0]` stays zero). The error of `values[k]` is its
    root-mean-square difference to `cov`.

    The per-trial error curves are written to `results_path` as
    `{manifold_name}_{measure_name}_2.csv`; mean and 25%/75% quantiles over trials for
    sample counts 100..2499 are plotted and saved as `{manifold_name}_{measure_name}_2.pdf`.

    Args:
        measure_class: callable building the spectral measure as
            `measure_class(manifold.dim, **measure_kwargs)`.
        measure_kwargs: keyword arguments for `measure_class`.
        measure_name: short name used in the output file names.
        manifold_class: callable building the manifold as
            `manifold_class(average_order=average_order_sqrt, **manifold_kwargs)`.
        manifold_kwargs: keyword arguments for `manifold_class`.
        manifold_name: short name used in the output file names.
        num_tries: number of independent trials; trial `t` seeds all RNGs with `t`.
        average_order_sqrt: value passed to the manifold as `average_order`; the number of
            character samples per pair of points is its square.
        num_points: number of random points drawn on the manifold per trial.
    """
    num_levels = 10  # order of the reference sphere and number of eigenspaces summed
    average_order = average_order_sqrt ** 2
    all_errors = []
    for trial in tqdm(range(num_tries)):
        # Seed every RNG in use so that each trial is reproducible on its own.
        torch.manual_seed(trial)
        np.random.seed(trial)
        random.seed(trial)

        manifold = manifold_class(average_order=average_order_sqrt, **manifold_kwargs)
        measure = measure_class(manifold.dim, **measure_kwargs)
        # NOTE(review): this kernel is not used afterwards; it is still constructed in case
        # its constructor advances the random state that `manifold.rand` depends on.
        space_kernel = EigenbasisSumKernel(measure=measure, manifold=manifold)

        x = manifold.rand(num_points)
        y = x
        x_y_embed = pairwise_embed(manifold, x, y)
        values = [torch.zeros(len(x), len(y), dtype=dtype, device=device) for _ in range(average_order)]

        n = manifold.n
        sphere = Sphere(n=n - 1, order=num_levels)
        sphere_kernel = EigenbasisSumKernel(measure=measure, manifold=sphere)

        if n > 3:
            # Keep only eigenspaces whose second index component is zero.
            # NOTE(review): presumably these are the ones matching the sphere's spectrum - confirm.
            lb_eigenspaces = [eig for eig in manifold.lb_eigenspaces if eig.index[1] == 0][:num_levels]
        else:
            lb_eigenspaces = manifold.lb_eigenspaces[:num_levels]

        # Nothing here is differentiated; disabling autograd avoids recording a graph for
        # every in-place accumulation in case the measure carries learnable parameters.
        with torch.no_grad():
            cov = sphere_kernel(x.reshape(-1, n), y.reshape(-1, n))
            normalizer = sphere_kernel.normalizer

            for eigenspace in tqdm(lb_eigenspaces):
                weight = measure(eigenspace.lb_eigenvalue)
                full_result = raw_character(eigenspace.phase_function.chi, average_order, x_y_embed)

                # Means of the first 1..average_order-1 samples for every pair of points,
                # obtained from one cumulative sum instead of re-averaging each prefix.
                # Only the real part is accumulated, so it is taken before summing.
                running_sum = torch.cumsum(full_result.real[:, :-1], dim=-1)
                counts = torch.arange(1, average_order, dtype=running_sum.dtype, device=running_sum.device)
                prefix_means = (running_sum / counts).T.reshape(average_order - 1, len(x), len(y))

                for k in range(average_order - 1):
                    values[k + 1] += weight * prefix_means[k] / normalizer

            errors = [compute_l2_diff(cov, approx).detach().cpu().numpy() for approx in values]

        all_errors.append(errors)

    all_errors = np.array(all_errors)
    pd.DataFrame(data=all_errors).to_csv(results_path + f'{manifold_name}_{measure_name}_2.csv')

    mean_error = np.mean(all_errors, axis=0)
    quantile_error_25 = np.quantile(all_errors, 0.25, axis=0)
    quantile_error_75 = np.quantile(all_errors, 0.75, axis=0)

    # Plot against the true sample counts; without explicit x values the slice starting
    # at 100 would be drawn from 0 and the axis would be off by 100.
    sample_counts = np.arange(mean_error.shape[0])[100:2500]

    # Explicit figure: repeated calls never draw on top of a previous call's axes.
    fig, ax = plt.subplots()
    ax.plot(sample_counts, mean_error[100:2500], c='blue', label='mean')
    ax.plot(sample_counts, quantile_error_25[100:2500], c='red', label='q25')
    ax.plot(sample_counts, quantile_error_75[100:2500], c='green', label='q75')
    ax.legend()
    fig.savefig(results_path + f'{manifold_name}_{measure_name}_2.pdf')
    plt.show()
    plt.close(fig)

In [None]:
# Start the sweep with as much free memory as possible.
torch.cuda.empty_cache()
gc.collect()

# Each entry: (measure class, constructor kwargs, short name used in output file names).
measures = [
    (MaternSpectralMeasure, dict(lengthscale=0.4, nu=0.5), 'matern12'),
    (MaternSpectralMeasure, dict(lengthscale=0.6, nu=1.5), 'matern32'),
    (MaternSpectralMeasure, dict(lengthscale=0.7, nu=2.5), 'matern52'),
    (SqExpSpectralMeasure, dict(lengthscale=0.95), 'heat'),
]

# Each entry: (manifold class, constructor kwargs, short name used in output file names).
groups = [
    (Stiefel, dict(n=5, m=1, order=100), 'stiefel51'),
    (Stiefel, dict(n=3, m=1, order=100), 'stiefel31'),
]

# Every measure is evaluated on every manifold (measures vary slowest).
for measure, group in itertools.product(measures, groups):
    print(measure, group)
    sphere_kernel_approx_error(*measure, *group)
    # Free memory held by the finished run before the next one starts.
    torch.cuda.empty_cache()
    gc.collect()