In [None]:
from IPython.display import clear_output
from google.colab import drive
# Mount Google Drive at /content/gdrive so results can be written to it.
drive.mount('/content/gdrive')
# Output directory on the mounted Drive. The plotting functions in this notebook build
# file names with `results_path + f'...'`, so the trailing slash is required.
results_path = '/content/gdrive/MyDrive/LieStationaryKernel_plots/'

In [None]:
# Install the LieStationaryKernels package directly from its GitHub repository.
!pip install git+https://github.com/imbirik/LieStationaryKernels.git
# Install the `backends` package.
!pip install backends
# Install the SphericalHarmonics package directly from its GitHub repository.
!pip install git+https://github.com/vdutor/SphericalHarmonics.git
# Clear the (long) pip output from the cell once the installs have run.
clear_output()

In [None]:
import sys
import itertools
import torch
import gc
import matplotlib.pyplot as plt
import random
import seaborn as sns
from torch.autograd.functional import _vmap as vmap
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm
from math import sqrt
from collections import defaultdict
from itertools import islice
from lie_stationary_kernels.spectral_kernel import EigenbasisSumKernel
from lie_stationary_kernels.spectral_measure import SqExpSpectralMeasure, MaternSpectralMeasure
from lie_stationary_kernels.prior_approximation import RandomPhaseApproximation

from lie_stationary_kernels.space import TranslatedCharactersBasis

from lie_stationary_kernels.spaces import SO, SU, Stiefel
from lie_stationary_kernels.spaces.sphere import Sphere

# Device name: 'cuda' when a CUDA device is available, otherwise 'cpu'.
# NOTE(review): `device` is not referenced by any cell visible in this notebook section.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Tensor printing: 6 decimals, no scientific notation, 160-character lines, 15 edge items.
torch.set_printoptions(precision=6, sci_mode=False, linewidth=160, edgeitems=15)
# Double-precision dtype constant.
# NOTE(review): `dtype` is not referenced by any cell visible in this notebook section.
dtype = torch.float64

In [None]:
# Replace the `precomputed_characters.json` that ships inside the installed
# `lie_stationary_kernels.spaces` package with the precomputed copy kept on Drive.
# The destination is derived from the imported package instead of a hard-coded
# '/usr/local/lib/python3.10/dist-packages/...' path, which silently stops matching
# as soon as the runtime's Python version changes.
import os
import shutil

import lie_stationary_kernels.spaces as _lsk_spaces

_characters_src = '/content/gdrive/MyDrive/precomputed_characters.json'
_characters_dst = os.path.join(os.path.dirname(_lsk_spaces.__file__), 'precomputed_characters.json')
# Unlike `!cp`, shutil.copy raises when the source is missing, so a failed copy cannot go unnoticed.
shutil.copy(_characters_src, _characters_dst)
print(f'Copied {_characters_src} -> {_characters_dst}')

In [None]:
# auxiliary functions
def compute_l2_diff(x, y):
    """Return the root-mean-square difference between two tensors.

    The element-wise difference ``x - y`` is squared, averaged over all
    elements, and the square root of that mean is returned as a 0-d tensor.
    """
    squared_residual = (x - y) ** 2
    return squared_residual.mean() ** 0.5

def make_raw_embedding(self, x):
    """Evaluate every eigenspace's phase function on the sampler's phases paired with ``x``.

    Unlike a finished feature map, the result is left "raw": one tensor per
    eigenspace, with no spectral-measure weighting applied here.

    Args:
        self: sampler-like object providing ``phases``, ``phase_order``,
            ``approx_order`` and ``kernel.manifold`` (this is a free function that
            takes the sampler explicitly as its first argument).
        x: points to embed; ``x.size()[0]`` is taken as the number of points.

    Returns:
        A list with one tensor per eigenspace, for the first ``approx_order``
        entries of ``manifold.lb_eigenspaces``. Each tensor is the real part of the
        phase function's output, viewed as ``(phase_order, len(x))``, transposed to
        ``(len(x), phase_order)`` and divided by the eigenspace's ``dimension``.
    """
    manifold = self.kernel.manifold
    n_points = x.size()[0]
    # left multiplication
    # [len(x), num_phase, ...] according to the original author's note -- TODO confirm against pairwise_embed
    phase_x_inv = manifold.pairwise_embed(self.phases, x)
    return [
        eigenspace.phase_function(phase_x_inv).real.view(self.phase_order, n_points).T / eigenspace.dimension
        for eigenspace in islice(manifold.lb_eigenspaces, self.approx_order)
    ]

In [None]:
def plot_sampler_error_normalized(measure_class, measure_kwargs, measure_name, manifold_class, manifold_kwargs, manifold_name,
                                  n_repeats=20, phase_order=10 ** 4, n_points=50):
    """Measure and plot the error of the random-phase approximation (unit-diagonal variant).

    For each of ``n_repeats`` repetitions this builds the manifold, the spectral
    measure (constructed as ``measure_class(manifold.dim, **measure_kwargs)``), an
    ``EigenbasisSumKernel`` and a ``RandomPhaseApproximation`` with ``phase_order``
    phases, draws ``n_points`` points with ``manifold.rand`` and evaluates the kernel
    matrix on them. For every ``k`` in ``0 .. phase_order - 1`` the kernel matrix is
    approximated from the first ``k`` phases only: per eigenspace, the mean outer
    product of the raw embedding columns is rescaled to unit diagonal and weighted by
    ``dimension**2 * measure(lb_eigenvalue) / normalizer``, and the weighted terms are
    summed over eigenspaces. The root-mean-square difference to the kernel matrix is
    recorded; ``k = 0`` is the all-zero approximation.

    Side effects: writes ``{manifold_name}_{measure_name}_3_normalized.csv`` (one row
    per repetition, one column per ``k``) and a PDF of the same stem under the
    notebook-level ``results_path``, and shows the figure (mean, 25% and 75% quantiles
    over repetitions).

    Args:
        measure_class: spectral measure class, called with ``manifold.dim`` and ``measure_kwargs``.
        measure_kwargs: keyword arguments for ``measure_class``.
        measure_name: label used in the output file names and the plot title.
        manifold_class: manifold class, called with ``manifold_kwargs``.
        manifold_kwargs: keyword arguments for ``manifold_class``.
        manifold_name: label used in the output file names and the plot title.
        n_repeats: number of independent repetitions.
        phase_order: number of random phases drawn by the sampler.
        n_points: number of random manifold points the kernel matrix is evaluated on.

    Returns:
        None.
    """
    sampled_errors = []
    for _ in tqdm(range(n_repeats)):
        manifold = manifold_class(**manifold_kwargs)
        measure = measure_class(manifold.dim, **measure_kwargs)
        kernel = EigenbasisSumKernel(measure=measure, manifold=manifold)
        sampler = RandomPhaseApproximation(kernel=kernel, phase_order=phase_order)

        # Nothing below is differentiated; disabling autograd avoids retaining a graph
        # per accumulated term in case the kernel's hyperparameters are trainable.
        with torch.no_grad():
            x = manifold.rand(n_points)
            # The kernel matrix is evaluated on (x, x), so one embedding serves both sides.
            embeddings = make_raw_embedding(sampler, x)
            cov_exact = kernel(x, x)

            # Per-eigenspace weight; it does not depend on the number of phases used.
            weights = [
                (eigenspace.dimension * eigenspace.dimension
                 * kernel.measure(eigenspace.lb_eigenvalue) / kernel.normalizer).real
                for eigenspace in manifold.lb_eigenspaces
            ]

            # running_sums[s] holds the sum of outer products of the phases consumed so far
            # for eigenspace s, so each additional phase costs one rank-1 update instead of
            # a full matrix product over all previous phases.
            running_sums = [0] * len(weights)
            approximation = torch.zeros_like(cov_exact)  # approximation built from zero phases
            errors = []
            for n_used in tqdm(range(phase_order)):
                # `approximation` currently uses the first `n_used` phases.
                errors.append(compute_l2_diff(cov_exact, approximation).item())

                # Consume phase number `n_used` and rebuild the approximation from
                # `n_used + 1` phases (the one built in the final pass is not read).
                approximation = torch.zeros_like(cov_exact)
                for space_idx, weight in enumerate(weights):
                    column = embeddings[space_idx][:, n_used]
                    running_sums[space_idx] = running_sums[space_idx] + torch.outer(column, torch.conj(column))
                    cov = (running_sums[space_idx] / (n_used + 1)).real
                    # Rescale the empirical covariance to unit diagonal before weighting.
                    diag = torch.sqrt(torch.diagonal(cov))
                    approximation += cov / diag[:, None] / diag[None, :] * weight
        sampled_errors.append(errors)

    sampled_errors = np.array(sampled_errors)
    pd.DataFrame(data=sampled_errors).to_csv(results_path + f'{manifold_name}_{measure_name}_3_normalized.csv')

    mean_error = np.mean(sampled_errors, axis=0)
    quantile_error_25 = np.quantile(sampled_errors, 0.25, axis=0)
    quantile_error_75 = np.quantile(sampled_errors, 0.75, axis=0)

    # Explicit figure/axes so repeated calls never draw onto a previous call's axes.
    fig, ax = plt.subplots()
    phase_counts = np.arange(phase_order)
    ax.plot(phase_counts, mean_error, c='blue', label='mean')
    ax.plot(phase_counts, quantile_error_25, c='red', label='q25')
    ax.plot(phase_counts, quantile_error_75, c='green', label='q75')
    ax.axhline(0)
    ax.set_xlabel('number of phases used')
    ax.set_ylabel('RMS error of the approximated kernel matrix')
    ax.set_title(f'{manifold_name} / {measure_name} (normalized)')
    ax.legend()  # the labels above are only rendered once a legend is requested
    fig.savefig(results_path + f'{manifold_name}_{measure_name}_3_normalized.pdf')

    plt.show()
    plt.close(fig)

In [None]:
import gc

# Release cached GPU memory and collect garbage before starting the sweep.
torch.cuda.empty_cache()
gc.collect()

# Each entry: (spectral measure class, constructor kwargs, label used in file names).
measures = [
    (MaternSpectralMeasure, dict(lengthscale=0.5, nu=0.5), 'matern12'),
    (MaternSpectralMeasure, dict(lengthscale=0.6, nu=1.5), 'matern32'),
    (MaternSpectralMeasure, dict(lengthscale=0.7, nu=2.5), 'matern52'),
    (SqExpSpectralMeasure, dict(lengthscale=0.95), 'heat'),
]

# Each entry: (manifold class, constructor kwargs, label used in file names).
groups = [
    (SO, dict(n=3, order=20), 'so3'),
    (SO, dict(n=5, order=20), 'so5'),
]

# Run the normalized experiment for every (group, measure) pair, freeing memory after each run.
for group in groups:
    for measure in measures:
        print(measure)
        plot_sampler_error_normalized(*measure, *group)
        torch.cuda.empty_cache()
        gc.collect()

In [None]:
def plot_sampler_error(measure_class, measure_kwargs, measure_name, manifold_class, manifold_kwargs, manifold_name,
                       n_repeats=20, phase_order=10 ** 4, n_points=50):
    """Measure and plot the error of the random-phase approximation of the kernel matrix.

    For each of ``n_repeats`` repetitions this builds the manifold, the spectral
    measure (constructed as ``measure_class(manifold.dim, **measure_kwargs)``), an
    ``EigenbasisSumKernel`` and a ``RandomPhaseApproximation`` with ``phase_order``
    phases, draws ``n_points`` points with ``manifold.rand`` and evaluates the kernel
    matrix on them. For every ``k`` in ``0 .. phase_order - 1`` the kernel matrix is
    approximated from the first ``k`` phases only: per eigenspace, the mean outer
    product of the raw embedding columns is weighted by
    ``dimension**2 * measure(lb_eigenvalue) / normalizer``, and the weighted terms are
    summed over eigenspaces. The root-mean-square difference to the kernel matrix is
    recorded; ``k = 0`` is the all-zero approximation.

    Side effects: writes ``{manifold_name}_{measure_name}_3.csv`` (one row per
    repetition, one column per ``k``) and a PDF of the same stem under the
    notebook-level ``results_path``, and shows the figure (mean, 25% and 75% quantiles
    over repetitions).

    Args:
        measure_class: spectral measure class, called with ``manifold.dim`` and ``measure_kwargs``.
        measure_kwargs: keyword arguments for ``measure_class``.
        measure_name: label used in the output file names and the plot title.
        manifold_class: manifold class, called with ``manifold_kwargs``.
        manifold_kwargs: keyword arguments for ``manifold_class``.
        manifold_name: label used in the output file names and the plot title.
        n_repeats: number of independent repetitions.
        phase_order: number of random phases drawn by the sampler.
        n_points: number of random manifold points the kernel matrix is evaluated on.

    Returns:
        None.
    """
    sampled_errors = []
    for _ in tqdm(range(n_repeats)):
        manifold = manifold_class(**manifold_kwargs)
        measure = measure_class(manifold.dim, **measure_kwargs)
        kernel = EigenbasisSumKernel(measure=measure, manifold=manifold)
        sampler = RandomPhaseApproximation(kernel=kernel, phase_order=phase_order)

        # Nothing below is differentiated; disabling autograd avoids retaining a graph
        # per accumulated term in case the kernel's hyperparameters are trainable.
        with torch.no_grad():
            x = manifold.rand(n_points)
            # The kernel matrix is evaluated on (x, x), so one embedding serves both sides.
            embeddings = make_raw_embedding(sampler, x)
            cov_exact = kernel(x, x)

            # Per-eigenspace weight; it does not depend on the number of phases used.
            weights = [
                (eigenspace.dimension * eigenspace.dimension
                 * kernel.measure(eigenspace.lb_eigenvalue) / kernel.normalizer).real
                for eigenspace in manifold.lb_eigenspaces
            ]

            # running_sums[s] holds the sum of outer products of the phases consumed so far
            # for eigenspace s, so each additional phase costs one rank-1 update instead of
            # a full matrix product over all previous phases.
            running_sums = [0] * len(weights)
            approximation = torch.zeros_like(cov_exact)  # approximation built from zero phases
            errors = []
            for n_used in tqdm(range(phase_order)):
                # `approximation` currently uses the first `n_used` phases.
                errors.append(compute_l2_diff(cov_exact, approximation).item())

                # Consume phase number `n_used` and rebuild the approximation from
                # `n_used + 1` phases (the one built in the final pass is not read).
                approximation = torch.zeros_like(cov_exact)
                for space_idx, weight in enumerate(weights):
                    column = embeddings[space_idx][:, n_used]
                    running_sums[space_idx] = running_sums[space_idx] + torch.outer(column, torch.conj(column))
                    cov = (running_sums[space_idx] / (n_used + 1)).real
                    approximation += cov * weight
        sampled_errors.append(errors)

    sampled_errors = np.array(sampled_errors)
    pd.DataFrame(data=sampled_errors).to_csv(results_path + f'{manifold_name}_{measure_name}_3.csv')

    mean_error = np.mean(sampled_errors, axis=0)
    quantile_error_25 = np.quantile(sampled_errors, 0.25, axis=0)
    quantile_error_75 = np.quantile(sampled_errors, 0.75, axis=0)

    # Explicit figure/axes so repeated calls never draw onto a previous call's axes.
    fig, ax = plt.subplots()
    phase_counts = np.arange(phase_order)
    ax.plot(phase_counts, mean_error, c='blue', label='mean')
    ax.plot(phase_counts, quantile_error_25, c='red', label='q25')
    ax.plot(phase_counts, quantile_error_75, c='green', label='q75')
    ax.axhline(0)
    ax.set_xlabel('number of phases used')
    ax.set_ylabel('RMS error of the approximated kernel matrix')
    ax.set_title(f'{manifold_name} / {measure_name}')
    ax.legend()  # the labels above are only rendered once a legend is requested
    fig.savefig(results_path + f'{manifold_name}_{measure_name}_3.pdf')

    plt.show()
    plt.close(fig)

In [None]:
import gc

# Release cached GPU memory and collect garbage before starting the sweep.
torch.cuda.empty_cache()
gc.collect()

# Each entry: (spectral measure class, constructor kwargs, label used in file names).
measures = [
    (MaternSpectralMeasure, dict(lengthscale=0.5, nu=0.5), 'matern12'),
    (MaternSpectralMeasure, dict(lengthscale=0.6, nu=1.5), 'matern32'),
    (MaternSpectralMeasure, dict(lengthscale=0.7, nu=2.5), 'matern52'),
    (SqExpSpectralMeasure, dict(lengthscale=0.95), 'heat'),
]

# Each entry: (manifold class, constructor kwargs, label used in file names).
groups = [
    (SO, dict(n=3, order=20), 'so3'),
    (SO, dict(n=5, order=20), 'so5'),
]

# Run the unnormalized experiment for every (group, measure) pair, freeing memory after each run.
for group in groups:
    for measure in measures:
        print(measure)
        plot_sampler_error(*measure, *group)
        torch.cuda.empty_cache()
        gc.collect()