In [None]:
import os
import shutil

from IPython.display import clear_output
from google.colab import drive
drive.mount('/content/gdrive')
# Output directory for the CSV error curves and PDF figures written later in the notebook.
results_path = '/content/gdrive/MyDrive/LieStationaryKernel_plots/'
# pandas' to_csv and matplotlib's savefig do not create missing directories, so make sure the
# output folder exists before anything is written into it.
os.makedirs(results_path, exist_ok=True)

In [None]:
# Install the kernel library and the extra packages this notebook needs.
# NOTE(review): none of these installs are pinned to a version/commit, so re-running the
# notebook later may pull different code — consider pinning (e.g. `...git@<commit>`).
!pip install git+https://github.com/imbirik/LieStationaryKernels.git
!pip install backends
!pip install git+https://github.com/vdutor/SphericalHarmonics.git
# Clear the (long) pip output from the cell.
clear_output()

In [None]:
import sys
import itertools
import torch
import gc
import matplotlib.pyplot as plt
import random
import seaborn as sns
from torch.autograd.functional import _vmap as vmap
import numpy as np
import pandas as pd
from tqdm import tqdm_notebook as tqdm
from math import sqrt
from collections import defaultdict
from itertools import islice
from lie_stationary_kernels.spectral_kernel import EigenbasisSumKernel
from lie_stationary_kernels.spectral_measure import SqExpSpectralMeasure, MaternSpectralMeasure
from lie_stationary_kernels.prior_approximation import RandomPhaseApproximation

from lie_stationary_kernels.space import TranslatedCharactersBasis

from lie_stationary_kernels.spaces import SO, SU, Stiefel
from lie_stationary_kernels.spaces.sphere import Sphere

# Use the GPU whenever torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Wide, fixed-point (non-scientific) tensor printing with 6 decimals.
torch.set_printoptions(linewidth=160, precision=6, edgeitems=15, sci_mode=False)

# Double-precision dtype. NOTE(review): not referenced in the visible cells — confirm it is used.
dtype = torch.float64

In [None]:
# Install the precomputed character table next to the installed `lie_stationary_kernels.spaces`
# package. The target directory is resolved from the package itself instead of hard-coding
# '/usr/local/lib/python3.10/dist-packages/...', which breaks whenever the runtime's Python
# version changes. shutil.copy also raises on failure, whereas a failing `!cp` is silent.
import lie_stationary_kernels.spaces as _lsk_spaces
shutil.copy('/content/gdrive/MyDrive/precomputed_characters.json',
            os.path.join(os.path.dirname(_lsk_spaces.__file__), 'precomputed_characters.json'))

In [None]:
# Auxiliary functions
def compute_l2_diff(x, y):
        """Root-mean-square difference between two tensors.

        Args:
            x, y: tensors of broadcast-compatible shape.

        Returns:
            0-dim tensor equal to sqrt(mean((x - y) ** 2)) over all elements.
        """
        residual = x - y
        mean_square = torch.pow(residual, 2).mean()
        return mean_square.pow(0.5)


In [None]:
def plot_approx_error(measure_class, measure_kwargs, measure_name, manifold_class, manifold_kwargs, manifold_name,
                      max_terms=1000, report_index=20):
        """Plot and save the log truncation error of a spectral (eigenspace) kernel expansion.

        For every Laplace-Beltrami eigenspace j (eigenvalue ``lambda_j``, dimension ``d_j``) the
        spectral measure value ``a_j = measure(lambda_j)`` is evaluated. With
        ``val_at_e = sum_j a_j * d_j`` (the name suggests the kernel value at the identity) and
        squared terms ``t_j = a_j**2 * d_j**2``, the error after keeping the first ``k``
        eigenspaces is ``log(sqrt(sum_{j >= k} t_j) / val_at_e)``.

        Args:
            measure_class: spectral-measure class, built as ``measure_class(manifold.dim, **measure_kwargs)``.
            measure_kwargs: keyword arguments for ``measure_class``.
            measure_name: label used in the legend and in the CSV file name.
            manifold_class: space class, built as ``manifold_class(**manifold_kwargs)``; must expose
                ``dim`` and ``lb_eigenspaces`` (items with ``dimension`` and ``lb_eigenvalue``).
            manifold_kwargs: keyword arguments for ``manifold_class``.
            manifold_name: label used in the legend and in the CSV file name.
            max_terms: number of truncation levels (starting at 0 kept eigenspaces) to record.
            report_index: truncation level whose relative squared tail is printed (clamped to the
                number of eigenspaces).

        Returns:
            np.ndarray of the recorded log errors. They are also written to
            ``results_path + f'{manifold_name}_{measure_name}_1.csv'`` and plotted on the current
            matplotlib axes. Levels whose tail is exactly zero yield ``-inf``.
        """
        manifold = manifold_class(**manifold_kwargs)
        measure = measure_class(manifold.dim, **measure_kwargs)

        val_at_e = 0.0
        sq_terms = []
        for eigspace in manifold.lb_eigenspaces:
                eig_dim = eigspace.dimension
                # The measure output is indexed with [0]; presumably a length-1 tensor — TODO confirm.
                # float() keeps the accumulation in float64 even if the measure returns float32.
                spectral_val = float(measure(eigspace.lb_eigenvalue).detach().cpu().numpy()[0])
                val_at_e += spectral_val * eig_dim
                sq_terms.append((spectral_val ** 2) * eig_dim * eig_dim)

        # tail[k] = sum_{j >= k} sq_terms[j], with tail[-1] == 0 (nothing left after all terms).
        # Summing the tail directly (smallest terms first) instead of computing total - prefix sum
        # avoids catastrophic cancellation: with the subtraction every tail below one ulp of the
        # total collapses to exactly 0 and becomes a spurious -inf after the log.
        sq_terms = np.asarray(sq_terms, dtype=np.float64)
        tail = np.append(np.cumsum(sq_terms[::-1])[::-1], 0.0)
        total = tail[0]

        # log(0) for an exactly-zero tail is an intentional -inf; silence the divide warning.
        with np.errstate(divide='ignore'):
                error = np.log(np.sqrt(tail[:max_terms]) / val_at_e)

        report_at = min(report_index, len(tail) - 1)
        print(manifold_name, measure_name, tail[report_at] / total)
        pd.DataFrame(error).to_csv(results_path + f'{manifold_name}_{measure_name}_1.csv')
        plt.plot(error, label = manifold_name + "_" + measure_name)
        return error


In [None]:
import gc
# Release cached GPU memory and run the garbage collector before the sweep.
torch.cuda.empty_cache()
gc.collect()

# Each entry: (spectral measure class, constructor kwargs, name used in labels / file names).
measures = [
    (MaternSpectralMeasure, {'lengthscale': 0.5, 'nu': 0.5}, 'matern12'),
    (MaternSpectralMeasure, {'lengthscale': 0.6, 'nu': 1.5}, 'matern32'),
    (MaternSpectralMeasure, {'lengthscale': 0.7, 'nu': 2.5}, 'matern52'),
    (SqExpSpectralMeasure, {'lengthscale': 0.95}, 'heat'),
]

# Each entry: (space class, constructor kwargs, name used in labels / file names).
# NOTE(review): 'order' presumably sets how many eigenspaces the space precomputes — confirm.
groups = [
    (SO, {'n': 3, 'order': 1000}, 'so3'),
    (SO, {'n': 5, 'order': 1000}, 'so5'),
]

for group in groups:
    # A fresh figure per group, so curves from earlier groups never end up in this plot or its
    # legend, even on backends where plt.show() does not close the current figure (e.g. Agg).
    fig = plt.figure()
    for measure in measures:
        print(measure)
        # Tuple order matches plot_approx_error's positional parameters:
        # (measure_class, measure_kwargs, measure_name, manifold_class, manifold_kwargs, manifold_name).
        plot_approx_error(*measure, *group)
        torch.cuda.empty_cache()
        gc.collect()
    plt.legend()
    plt.savefig(results_path + f'{group[2]}_1.pdf')
    plt.show()
    plt.close(fig)