## SHTools Demo

This notebook is used to experiment with a few features of the pyshtools library.


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import pyshtools
from pyshtools import spectralanalysis
from pyshtools import shio
from pyshtools import expand

from s2cnn import S2Convolution
from s2cnn import s2_fft
from s2cnn.utils.complex import as_complex

from training_set import TrainingSet
from data_source import DataSource
from visualize import Visualize
from sphere import Sphere
from model import Model

# IPython notebook setup: render matplotlib figures inline and automatically
# reload edited local modules (training_set, data_source, ...) before each cell.
%matplotlib inline
%load_ext autoreload
%autoreload 2

## Load input
First, load the current input feature set from disk.

In [None]:
# Configuration for building the test feature set.
restore = False  # passed to TrainingSet; assumed to mean "reuse saved features" — TODO confirm
bandwidth = 100
dataset_path = "/media/scratch/berlukas/spherical/"
# Alternative dataset location:
# dataset_path = "/home/berlukas/data/arche_low_res/"

# Request 20 samples and use the same value for the DataSource cache size.
n_test_data = n_test_cache = 20
ds_test = DataSource(dataset_path, n_test_cache, -1)
ds_test.load(n_test_data)
# Replace the requested count with the number of anchors actually loaded.
n_test_data = len(ds_test.anchors)

# Generate the features for every loaded sample and report the set size.
test_set = TrainingSet(restore, bandwidth)
test_set.generateAll(ds_test)
n_test_set = len(test_set)
print("Total size: ", n_test_set)

## Perform an S2 transform of the features


In [None]:
# Iterate the test set in its original order, one sample per batch.
loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False, num_workers=1, pin_memory=True, drop_last=False)

# Index along dim 1 of each batch tensor that selects the feature channel to analyze.
channel = 2

# Only the first batch is needed. next(iter(...)) replaces a loop that always
# broke after one iteration, and fails loudly if the loader is empty.
data = next(iter(loader))
# Take sample 0 of the batch at the chosen channel, giving 2-D grids.
# a/p/n presumably are anchor/positive/negative — TODO confirm against TrainingSet.
a = data[0].float()[0, channel, :, :]
p = data[1].float()[0, channel, :, :]
n = data[2].float()[0, channel, :, :]
print(a.shape)
print(as_complex(a).shape)


In [None]:
# Expand the anchor grid `a` (a 2-D CPU tensor from the previous cell) into real
# spherical harmonic coefficients. norm=1 is SHExpandDH's default, which yields
# 4pi-normalized coefficients. Pass an explicit NumPy array to the compiled routine.
a_coeffs = pyshtools.expand.SHExpandDH(a.numpy(), norm=1, sampling=1)
# Power per degree; the default normalization ('4pi') matches the expansion above.
power_per_l = pyshtools.spectralanalysis.spectrum(a_coeffs)
# a_coeffs has shape (2, lmax+1, lmax+1), so this gives degrees 0..lmax.
degrees = np.arange(a_coeffs.shape[1])

fig, ax = plt.subplots(1, 1)
# Degree 0 cannot be placed on a logarithmic x-axis, so start at degree 1.
ax.plot(degrees[1:], power_per_l[1:])
ax.set(yscale='log', xscale='log', xlabel='Spherical harmonic degree', ylabel='Power')
ax.grid()

In [None]:
# Expand the second grid `p` exactly like the anchor: norm=1 gives
# 4pi-normalized coefficients.
p_coeffs = pyshtools.expand.SHExpandDH(p.numpy(), norm=1, sampling=1)
# The normalization must match the coefficients ('4pi', not 'schmidt').
# Use the 'power' convention to agree with the anchor spectrum and the axis label.
power_per_l = spectralanalysis.cross_spectrum(a_coeffs, p_coeffs, normalization='4pi', convention='power')
degrees = np.arange(a_coeffs.shape[1])

fig, ax = plt.subplots(1, 1)
# Cross-power can be negative, and a log y-axis would hide those degrees, so
# plot the magnitude. Degree 0 is skipped because of the log x-axis.
ax.plot(degrees[1:], np.abs(power_per_l[1:]))
ax.set(yscale='log', xscale='log', xlabel='Spherical harmonic degree', ylabel='|Cross-power|')
ax.grid()

# Per-degree admittance, its error, and the correlation between the two sets.
admit, error, corr = spectralanalysis.SHAdmitCorr(a_coeffs, p_coeffs)
# Iterate over every degree actually returned instead of a hard-coded 100.
for degree, corr_l in enumerate(corr):
    prob = spectralanalysis.SHConfidence(degree, corr_l)
    if prob < 1.0:
        print(f'Probability of being correlated at {degree} is {prob}')

In [None]:
# Expand the third grid `n` exactly like the anchor: norm=1 gives
# 4pi-normalized coefficients.
n_coeffs = pyshtools.expand.SHExpandDH(n.numpy(), norm=1, sampling=1)
# The normalization must match the coefficients ('4pi', not 'schmidt').
# Use the 'power' convention to agree with the anchor spectrum and the axis label.
power_per_l = spectralanalysis.cross_spectrum(a_coeffs, n_coeffs, normalization='4pi', convention='power')
degrees = np.arange(a_coeffs.shape[1])

fig, ax = plt.subplots(1, 1)
# Cross-power can be negative, and a log y-axis would hide those degrees, so
# plot the magnitude. Degree 0 is skipped because of the log x-axis.
ax.plot(degrees[1:], np.abs(power_per_l[1:]))
ax.set(yscale='log', xscale='log', xlabel='Spherical harmonic degree', ylabel='|Cross-power|')
ax.grid()


# Per-degree admittance, its error, and the correlation between the two sets.
admit, error, corr = spectralanalysis.SHAdmitCorr(a_coeffs, n_coeffs)
# Iterate over every degree actually returned instead of a hard-coded 100.
for degree, corr_l in enumerate(corr):
    prob = spectralanalysis.SHConfidence(degree, corr_l)
    if prob < 1.0:
        print(f'Probability of being correlated at {degree} is {prob}')