# Get S2CNN to work on WeatherBench and ERA5 data

In [None]:
import numpy as np
import torch

# input data and results live under the same GPFS work directory
_gpfs_root = '/gpfs/work/nonnenma/'
datadir = _gpfs_root + 'data/forecast_predictability/weatherbench/5_625deg/'
res_dir = _gpfs_root + 'results/forecast_predictability/weatherbench/5_625deg/'

### load ERA5 states here

In [None]:
import xarray as xr

def _open_grib(name):
    '''Open <datadir>/<name>.grib with the cfgrib engine.'''
    return xr.open_dataset(datadir + name + '.grib', engine='cfgrib')


# file names carry the spectral truncation (T19, T63, T179); the unsuffixed file is T639
x19 = _open_grib('test_temperature_JAN_2016_MARS_T19')
x63 = _open_grib('test_temperature_JAN_2016_MARS_T63')
x179 = _open_grib('test_temperature_JAN_2016_MARS_T179')
x639 = _open_grib('test_temperature_JAN_2016_MARS')  # this is T639

# let's check output shapes
x19.t.shape, x179.t.shape, x639.t.shape

# Index translation between ERA5 and S2CNN coefficient orderings

In [None]:
# this function reads in spherical coefficients from ERA5 in 'values' and returns them ordered as for s2cnn 
def translate_idx(T, Tnew=None):
    
    Tnew = T if Tnew is None else Tnew    
    T_, Tnew_ = T+1, Tnew+1

    # get indices of lower triangular matrix
    idx_in_i, idx_in_j = np.where(np.triu(np.ones((T_,T_)))) 

    # get indices to read out elements of full coefficient matrix
    idx_out_i, idx_out_j = np.zeros(Tnew_**2,dtype=np.int), np.zeros(Tnew_**2, dtype=np.int)
    for l in range(Tnew_):
        # first read indices m = -l : -1 from lower triangle, then  m = 0 : l from upper
        idx_out_i[l**2 : l**2 + 2*l + 1] = np.asarray(np.concatenate([l*np.ones(l), np.arange(l+1)]), dtype=np.int)
        # note there's a shift on the column indices in lower triangle due to Mc[1:,-1] = M[1:, 1:] above
        idx_out_j[l**2 : l**2 + 2*l + 1] = np.asarray(np.concatenate([np.arange(l)[::-1], l**np.ones(l+1)]), dtype=np.int)
    
    return idx_in_i, idx_in_j, idx_out_i, idx_out_j
    
def cohmp(T, values, idcs=None, Tnew=None):
    '''Reorder a flat vector of stored spectral coefficients into s2cnn layout.

    The (real, imag) pairs in ``values`` are scattered into the upper triangle
    of a (T+1) x (T+1) x 2 array (row = m, column = l) and truncated to Tnew.
    The m < 0 entries are then filled in via c_l^{-m} = (-1)^m conj(c_l^m),
    and the result is gathered in s2cnn order.

    :param T: truncation of the stored coefficients
    :param values: flat array with real and imaginary parts interleaved, one
        pair per upper-triangle entry of a (T+1) x (T+1) matrix
    :param idcs: precomputed translate_idx(T, Tnew) output; computed if None
    :param Tnew: truncation to keep (defaults to T)
    :return: float array of shape ((Tnew+1)**2, 1, 2)
    '''
    Tnew = T if Tnew is None else Tnew
    n_full, n_keep = T + 1, Tnew + 1
    if idcs is None:
        idcs = translate_idx(T, Tnew)
    rows_in, cols_in, rows_out, cols_out = idcs

    # scatter (real, imag) pairs into the upper triangle, then truncate to Tnew
    M = np.zeros((n_full, n_full, 2))
    M[rows_in, cols_in] = np.reshape(values, (-1, 2))
    M = M[:n_keep, :n_keep]

    # mirror into the strictly lower triangle: real part times (-1)^m and
    # imaginary part times (-1)^(m+1), i.e. (-1)^m times the complex conjugate;
    # the sign is applied along the m (row) axis before transposing
    m = np.arange(1, n_keep)
    parity = np.stack([(-1.0) ** m, (-1.0) ** (m + 1)], axis=-1)[:, None, :]
    M[1:, :-1, :] += (M[1:, 1:, :] * parity).transpose(1, 0, 2)

    # gather in s2cnn order, keeping each (real, imag) pair together
    return M[rows_out, cols_out, :].reshape(-1, 1, 2)

# Redefining S2CNN classes for grids with N x 2N grid points
- CUDA version not even started

In [None]:
from s2cnn.soft.s2_fft import _setup_wigner

def s2_fft_half_height(x, for_grad=False, b_out=None):
    '''
    S2 Fourier transform (spherical-harmonic analysis) on a "half-height" grid
    with N latitudes and 2N longitudes, instead of s2cnn's 2b x 2b grid.

    :param x: [..., beta, alpha, complex] with beta = N latitudes (N even) and alpha = 2N longitudes
    :param for_grad: if True, use the unweighted Wigner-d table and no
        quadrature rescaling (this is how S2_ifft_hh_real.backward calls it)
    :param b_out: output bandwidth, defaults to N; must satisfy b_out <= N
    :return:  [l * m, ..., complex] with l * m = b_out ** 2
    '''
    assert x.size(-1) == 2
    b_in_m = x.size(-3)  # vertical resolution (latitudes)
    b_in_l = 2 * b_in_m  # horizontal resolution (longitudes)
    assert x.size(-2) == b_in_l
    # _setup_wigner(b) tabulates 2 * b latitudes, so the table below is built
    # for b = b_in_m // 2, which requires an even number of latitudes
    assert b_in_m % 2 == 0
    if b_out is None:
        b_out = b_in_m
    assert b_out <= b_in_m
    batch_size = x.size()[:-3]

    x = x.view(-1, b_in_m, b_in_l, 2)  # [batch, beta, alpha, complex]
    nspec = b_out ** 2
    nbatch = x.size(0)

    b_wigner = b_in_m // 2
    wigner = _setup_wigner(b_wigner, nl=b_out, weighted=not for_grad, device=x.device)
    wigner = wigner.view(b_in_m, -1)  # [beta, l * m] (b_in_m, nspec)
    if not for_grad:
        # NOTE(review): the weighted table is normalized for a square grid with
        # 2 * b_wigner longitudes, but the FFT below sums over b_in_l = 2 * b_in_m
        # of them; rescale so that e.g. a constant field of ones gives c_00 = 1
        # as with s2cnn's s2_fft. (The latitude quadrature is only exact up to
        # bandwidth b_wigner.)
        wigner = wigner * (2 * b_wigner / b_in_l)

    x = torch.fft(x, 1)  # [batch, beta, m, complex]; legacy (PyTorch < 1.8) FFT API

    # The half-height CUDA kernel does not exist (the s2cnn one presumably
    # assumes the square grid), so the einsum loop is used on every device.
    output = x.new_empty((nspec, nbatch, 2))
    for l in range(b_out):
        s = slice(l ** 2, l ** 2 + 2 * l + 1)
        # frequencies m = -l..-1 sit at the end of the FFT output, m = 0..l at the start
        xx = torch.cat((x[:, :, -l:], x[:, :, :l + 1]), dim=2) if l > 0 else x[:, :, :1]
        output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))

    output = output.view(-1, *batch_size, 2)  # [l * m, ..., complex] (nspec, ..., 2)
    return output


def s2_ifft_half_height(x, for_grad=False, b_out=None):
    '''
    Inverse S2 Fourier transform (spherical-harmonic synthesis) onto a
    "half-height" grid with 2 * (b_out // 2) latitudes and 2 * b_out longitudes.

    :param x: [l * m, ..., complex] with l * m = b_in ** 2
    :param for_grad: if True, use the quadrature-weighted Wigner-d table
        (this is how S2_fft_hh_real.backward calls it)
    :param b_out: output bandwidth, defaults to b_in; must satisfy b_out >= b_in
    :return: [..., beta, alpha, complex]
    '''
    assert x.size(-1) == 2
    nspec = x.size(0)
    b_in = round(nspec ** 0.5)
    assert nspec == b_in ** 2
    if b_out is None:
        b_out = b_in
    assert b_out >= b_in
    # _setup_wigner(b) tabulates 2 * b latitudes, so b_out // 2 gives ~b_out latitudes
    b_out_m = b_out // 2
    n_beta = 2 * b_out_m  # vertical resolution (latitudes)
    n_alpha = 2 * b_out  # horizontal resolution (longitudes)
    batch_size = x.size()[1:-1]

    x = x.view(nspec, -1, 2)  # [l * m, batch, complex] (nspec, nbatch, 2)
    nbatch = x.size(1)

    wigner = _setup_wigner(b_out_m, nl=b_in, weighted=for_grad, device=x.device)
    wigner = wigner.view(n_beta, -1)  # [beta, l * m] (n_beta, nspec)

    # The half-height CUDA kernel does not exist (the s2cnn one presumably
    # assumes the square grid), so the einsum loop is used on every device.
    output = x.new_zeros((nbatch, n_beta, n_alpha, 2))
    for l in range(b_in):
        s = slice(l ** 2, l ** 2 + 2 * l + 1)
        out = torch.einsum("mzc,bm->zbmc", (x[s], wigner[:, s]))
        output[:, :, :l + 1] += out[:, :, -l - 1:]  # m = 0..l
        if l > 0:
            output[:, :, -l:] += out[:, :, :l]  # m = -l..-1

    # torch.ifft (legacy API) transforms along the alpha axis and divides by its
    # length n_alpha, so plain synthesis must multiply by n_alpha (not by the
    # number of latitudes, which would halve the field). For for_grad, the
    # weighted table is normalized for 2 * b_out_m longitudes. The correctly
    # normalized analysis on this grid therefore scales by
    # 2 * b_out_m / n_alpha, and its adjoint inherits that factor: overall 2 * b_out_m.
    scale = 2 * b_out_m if for_grad else n_alpha
    output = torch.ifft(output, 1) * scale  # [batch, beta, alpha, complex]
    output = output.view(*batch_size, n_beta, n_alpha, 2)
    return output


class S2_fft_hh_real(torch.autograd.Function):
    '''Autograd wrapper: real signal on the half-height grid -> complex spectrum.'''

    @staticmethod
    def forward(ctx, x, b_out=None):  # pylint: disable=W
        from s2cnn.utils.complex import as_complex
        # remember the bandwidths needed to shape the gradient in backward()
        ctx.b_out = b_out
        ctx.b_in = x.size(-1) // 2  # half the number of longitudes
        x_complex = as_complex(x)
        return s2_fft_half_height(x_complex, b_out=ctx.b_out)

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=W
        grad_x = s2_ifft_half_height(grad_output, for_grad=True, b_out=ctx.b_in)
        # forward() received a real tensor (it wrapped it with as_complex), so keep the real part
        return grad_x[..., 0], None


class S2_ifft_hh_real(torch.autograd.Function):
    '''Autograd wrapper: complex spectrum -> real part of the half-height grid signal.'''

    @staticmethod
    def forward(ctx, x, b_out=None):  # pylint: disable=W
        ctx.b_out = b_out
        ctx.b_in = round(x.size(0) ** 0.5)  # input bandwidth, from nspec = b_in ** 2
        grid_signal = s2_ifft_half_height(x, b_out=ctx.b_out)
        return grid_signal[..., 0]  # real part only

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=W
        from s2cnn.utils.complex import as_complex
        grad_spectrum = s2_fft_half_height(as_complex(grad_output), for_grad=True, b_out=ctx.b_in)
        return grad_spectrum, None


from s2cnn import S2Convolution

class S2Convolution_hh(S2Convolution):
    '''S2Convolution whose input lives on a half-height grid (b_in latitudes x 2 * b_in longitudes).

    The input is transformed with S2_fft_hh_real; the remaining steps (kernel
    transform, spectral product, inverse SO(3) transform, bias) use s2cnn's functions.
    '''

    def __init__(self, nfeature_in, nfeature_out, b_in, b_out, grid):
        '''
        :param nfeature_in: number of input features
        :param nfeature_out: number of output features
        :param b_in: input bandwidth (the input grid has b_in latitudes and 2 * b_in longitudes)
        :param b_out: output bandwidth
        :param grid: points of the sphere defining the kernel, tuple of (alpha, beta)'s
        '''
        super().__init__(nfeature_in, nfeature_out, b_in, b_out, grid)

    def forward(self, x):  # pylint: disable=W
        '''
        :x:      [batch, feature_in,  beta, alpha] with beta = b_in and alpha = 2 * b_in
        :return: [batch, feature_out, beta, alpha, gamma]
        '''
        # imported here because the notebook only imports these in a later cell
        from s2cnn import s2_mm, s2_rft
        from s2cnn.soft.so3_fft import SO3_ifft_real

        assert x.size(1) == self.nfeature_in
        assert x.size(2) == self.b_in
        assert x.size(3) == 2 * self.b_in
        x = S2_fft_hh_real.apply(x, self.b_out)  # [l * m, batch, feature_in, complex]
        y = s2_rft(self.kernel * self.scaling, self.b_out, self.grid)  # [l * m, feature_in, feature_out, complex]

        z = s2_mm(x, y)  # [l * m * n, batch, feature_out, complex]
        z = SO3_ifft_real.apply(z)  # [batch, feature_out, beta, alpha, gamma]

        return z + self.bias

### SO3 rewrite (unfinished)
- CUDA version not even started

In [None]:

def so3_rfft_half_height(x, for_grad=False, b_out=None):
    '''
    Real SO(3) Fourier transform on a half-height grid (not implemented yet).

    :param x: [..., beta, alpha, gamma]
    :param for_grad: True when called from a backward pass
    :param b_out: output bandwidth
    :return: [l * m * n, ..., complex]
    :raises NotImplementedError: always; the half-height SO(3) transform is unfinished
    '''
    raise NotImplementedError('so3_rfft_half_height is not implemented yet')

def so3_ifft_half_height(x, for_grad=False, b_out=None):
    '''
    Complex inverse SO(3) Fourier transform onto a half-height grid (not implemented yet).

    :param x: [l * m * n, ..., complex]
    :param for_grad: True when called from a backward pass
    :param b_out: output bandwidth
    :raises NotImplementedError: always; the half-height SO(3) transform is unfinished
    '''
    raise NotImplementedError('so3_ifft_half_height is not implemented yet')


def so3_rifft_half_height(x, for_grad=False, b_out=None):
    '''
    Real inverse SO(3) Fourier transform onto a half-height grid (not implemented yet).

    :param x: [l * m * n, ..., complex]
    :param for_grad: True when called from a backward pass
    :param b_out: output bandwidth
    :raises NotImplementedError: always; the half-height SO(3) transform is unfinished
    '''
    raise NotImplementedError('so3_rifft_half_height is not implemented yet')


class SO3_fft_hh_real(torch.autograd.Function):
    '''Autograd wrapper around so3_rfft_half_height (real SO(3) signal -> spectrum).'''

    @staticmethod
    def forward(ctx, x, b_out=None):  # pylint: disable=W
        ctx.b_in = x.size(-1) // 2
        ctx.b_out = b_out
        return so3_rfft_half_height(x, b_out=b_out)

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=W
        # ifft of grad_output is not necessarily real, therefore we cannot use rifft;
        # the complex result is reduced to its real part
        grad_x = so3_ifft_half_height(grad_output, for_grad=True, b_out=ctx.b_in)
        return grad_x[..., 0], None


class SO3_ifft_hh_real(torch.autograd.Function):
    '''Autograd wrapper around so3_rifft_half_height (SO(3) spectrum -> real signal).'''

    @staticmethod
    def forward(ctx, x, b_out=None):  # pylint: disable=W
        ctx.b_out = b_out
        # invert nspec ~ 4/3 * b_in**3 to recover the input bandwidth
        ctx.b_in = round((3 / 4 * x.size(0)) ** (1 / 3))
        return so3_rifft_half_height(x, b_out=b_out)

    @staticmethod
    def backward(ctx, grad_output):  # pylint: disable=W
        grad_spectrum = so3_rfft_half_height(grad_output, for_grad=True, b_out=ctx.b_in)
        return grad_spectrum, None


from s2cnn import SO3Convolution

class SO3Convolution_hh(SO3Convolution):
    '''SO3Convolution variant intended for half-height grids (work in progress).

    The forward pass goes through SO3_fft_hh_real / SO3_ifft_hh_real, whose
    underlying half-height transforms are not implemented yet.
    '''

    def __init__(self, nfeature_in, nfeature_out, b_in, b_out, grid):
        '''
        :param nfeature_in: number of input features
        :param nfeature_out: number of output features
        :param b_in: input bandwidth (precision of the input SOFT grid)
        :param b_out: output bandwidth
        :param grid: points of the SO(3) group defining the kernel, tuple of (alpha, beta, gamma)'s
        '''
        super().__init__(nfeature_in, nfeature_out, b_in, b_out, grid)

    def forward(self, x):  # pylint: disable=W
        '''
        :x:      [batch, feature_in,  beta, alpha, gamma]
        :return: [batch, feature_out, beta, alpha, gamma]
        '''
        # so3_rft / so3_mm are not imported anywhere else in this notebook
        from s2cnn import so3_mm, so3_rft

        assert x.size(1) == self.nfeature_in
        # NOTE(review): these still check SO3Convolution's square 2 * b_in grid;
        # presumably beta should become b_in once the half-height transform exists
        assert x.size(2) == 2 * self.b_in
        assert x.size(3) == 2 * self.b_in
        assert x.size(4) == 2 * self.b_in

        x = SO3_fft_hh_real.apply(x, self.b_out)  # [l * m * n, batch, feature_in, complex]
        y = so3_rft(self.kernel * self.scaling, self.b_out, self.grid)  # [l * m * n, feature_in, feature_out, complex]
        assert x.size(0) == y.size(0)
        assert x.size(2) == y.size(1)
        z = so3_mm(x, y)  # [l * m * n, batch, feature_out, complex]
        assert z.size(0) == x.size(0)
        assert z.size(1) == x.size(1)
        assert z.size(2) == y.size(2)
        z = SO3_ifft_hh_real.apply(z)  # [batch, feature_out, beta, alpha, gamma]

        return z + self.bias

# test some FTs and IFTs

In [None]:
import torch
import matplotlib.pyplot as plt
from s2cnn.soft.s2_fft import s2_ifft

t = 736  # time index (in hours)
lvl = 0  # level index into [500, 850] hPa

idcs63 = translate_idx(T=639, Tnew=63)
ft = cohmp(T=639, Tnew=63, values=x639.t.values[t, lvl, :], idcs=idcs63)

print('fourier transform shape: ', ft.shape)


def show_real_imag(field):
    '''Plot the real and imaginary parts of field[0] side by side.'''
    panel_titles = ['real part', 'imaginary part']
    plt.figure(figsize=(16, 7))
    for k, title in enumerate(panel_titles):
        plt.subplot(1, 2, k + 1)
        plt.imshow(field[0, :, :, k], aspect='auto')
        plt.colorbar()
        plt.title(title)
    plt.show()


# s2cnn's own inverse transform
x = s2_ifft(torch.tensor(ft, dtype=torch.float32))
show_real_imag(x)

# inverse transform onto the N x 2N (half-height) grid
x = s2_ifft_half_height(torch.tensor(ft, dtype=torch.float32))
show_real_imag(x)

# round trip: grid -> spectrum -> grid
x = s2_ifft_half_height(s2_fft_half_height(x))
show_real_imag(x)

# plug into (custom) convolutional layer for S2CNN
- S2Convolution() assumes its input is already in grid space, but here we have spherical-harmonic coefficients

In [None]:
import torch
from torch.nn.parameter import Parameter
import time

from s2cnn.soft.so3_fft import SO3_ifft_real
from s2cnn import s2_mm
from s2cnn import s2_rft

from s2cnn import s2_near_identity_grid
from s2cnn import S2Convolution


class FTConvolution(S2Convolution):
    '''S2 convolution whose input is already given as spherical-harmonic coefficients.

    S2Convolution transforms a grid signal itself; here the forward pass starts
    from the spectrum (e.g. ERA5 coefficients reordered with cohmp()).
    '''

    def __init__(self, nfeature_in, nfeature_out, b_in, b_out, grid):
        '''
        :param nfeature_in: number of input features
        :param nfeature_out: number of output features
        :param b_in: input bandwidth (the input holds b_in**2 coefficients)
        :param b_out: output bandwidth, must satisfy b_out <= b_in
        :param grid: points of the sphere defining the kernel, tuple of (alpha, beta)'s
        '''
        super().__init__(nfeature_in, nfeature_out, b_in, b_out, grid)

    def forward(self, x):  # pylint: disable=W
        ''' We rewrite the S2 convolution to start from x already Fourier-transformed
        :x:      [b_in**2, batch, feature_in, complex], ordered by l and then m = -l..l
        :return: [batch, feature_out, beta, alpha, gamma]
        '''
        assert x.size(0) == self.b_in ** 2
        assert x.size(2) == self.nfeature_in
        assert x.size(3) == 2
        assert self.b_out <= self.b_in

        # coefficients are ordered by degree l, so the first b_out**2 are exactly
        # the degrees l < b_out that s2_mm pairs with the b_out kernel spectrum
        x = x[:self.b_out ** 2]

        y = s2_rft(self.kernel * self.scaling, self.b_out, self.grid)  # [l * m, feature_in, feature_out, complex]
        # NOTE(review): cast kept from the original ("because reasons..."); presumably
        # makes y match the float32 x for s2_mm - confirm whether it is still needed
        y = torch.as_tensor(y, dtype=torch.float32)

        z = s2_mm(x, y)  # [l * m * n, batch, feature_out, complex]
        z = SO3_ifft_real.apply(z)  # [batch, feature_out, beta, alpha, gamma]
        return z + self.bias

# define convolutional layer

T = 63
b_in = T + 1
b_out = T + 1
grid_s2 = s2_near_identity_grid()

conv1 = FTConvolution(nfeature_in=2, nfeature_out=4, b_in=b_in, b_out=b_out, grid=grid_s2)

# load and translate input data
ts, lvls = [0, 1, 2], [0, 1]
idcs63 = translate_idx(T=639, Tnew=63)

# stack inputs to [nspec, batch=len(ts), feature_in=len(lvls), complex]
# (cohmp() is not vectorized over time steps / levels yet)
ft = np.stack(
    [
        np.concatenate(
            [cohmp(T=639, Tnew=63, values=x639.t.values[t, lvl, :], idcs=idcs63) for lvl in lvls],
            axis=1,
        )
        for t in ts
    ],
    axis=1,
)
print('ft.shape', ft.shape)

print('\n beginning convolution ! \n')
# perf_counter is monotonic and meant for measuring elapsed time
t_start = time.perf_counter()
out = conv1(torch.tensor(ft, requires_grad=False, dtype=torch.float32))
print(f'- finished in {time.perf_counter() - t_start}s -')
out.shape

In [None]:
from s2cnn.soft.so3_fft import SO3_fft_real, SO3_ifft_real
out_f32 = torch.as_tensor(out, dtype=torch.float32)  # conv1 output [batch, feature, beta, alpha, gamma]
x = SO3_fft_real.apply(out_f32, b_out)  # [l * m * n, batch, feature, complex]
x.shape

# plug into spherical CNN

In [None]:
import torch.nn as nn
import torch.nn.functional as F

from s2cnn import so3_near_identity_grid
from s2cnn import s2_near_identity_grid

from s2cnn import SO3Convolution
from s2cnn import S2Convolution

class S2ConvNet_original(nn.Module):
    '''Two-layer spherical CNN followed by a 1x1 convolution over the beta axis.

    conv1 (S2Convolution_hh) takes the half-height input grid to SO(3), conv2
    (SO3Convolution) convolves on SO(3), and out_layer treats the 2 * b_l2 beta
    samples as Conv2d channels over the remaining two grid axes.
    '''

    def __init__(self, f1=20, f2=40, f_output=1, b_in=64, b_l1=10, b_l2=6, nfeature_in=1):
        '''
        :param f1: number of features after conv1
        :param f2: number of features after conv2
        :param f_output: number of output channels of the final 1x1 convolution
        :param b_in: input bandwidth (input grid has b_in latitudes and 2 * b_in longitudes)
        :param b_l1: bandwidth after conv1
        :param b_l2: bandwidth after conv2
        :param nfeature_in: number of input features
        '''
        super().__init__()

        grid_s2 = s2_near_identity_grid()
        grid_so3 = so3_near_identity_grid()

        self.conv1 = S2Convolution_hh(
            nfeature_in=nfeature_in,
            nfeature_out=f1,
            b_in=b_in,
            b_out=b_l1,
            grid=grid_s2)

        self.conv2 = SO3Convolution(
            nfeature_in=f1,
            nfeature_out=f2,
            b_in=b_l1,
            b_out=b_l2,
            grid=grid_so3)

        # 1x1 convolution whose input channels are the 2 * b_l2 beta samples
        self.out_layer = torch.nn.Conv2d(
            in_channels=2 * b_l2,
            out_channels=f_output,
            kernel_size=(1, 1))

    def forward(self, x):
        '''
        :x:      [batch, nfeature_in, b_in, 2 * b_in]
        :return: [batch, f2 * f_output, H, W] (for f_output == 1: [batch, f2, H, W])
        '''
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))  # [batch, f2, beta, alpha, gamma]

        n, c = x.size(0), x.size(1)
        # fold features into the batch so that Conv2d sees beta as its channel axis
        x = x.view(n * c, x.size(2), x.size(3), x.size(4))
        x = self.out_layer(x)  # [n * c, f_output, H, W]
        # keep the f_output channels instead of dropping them (the previous
        # view(n, c, H, W) only worked for f_output == 1)
        return x.view(n, c * x.size(1), x.size(2), x.size(3))
    
    
model = S2ConvNet_original(f1=20, f2=1, f_output=1, b_in=T + 1, b_l1=T + 1, b_l2=(T + 1) // 2, nfeature_in=2)

# stack inputs to [nspec, batch=len(ts), feature_in=len(lvls), complex]
# (cohmp() is not vectorized over time steps / levels yet)
ft = np.stack(
    [
        np.concatenate(
            [cohmp(T=639, Tnew=63, values=x639.t.values[t, lvl, :], idcs=idcs63) for lvl in lvls],
            axis=1,
        )
        for t in ts
    ],
    axis=1,
)
print('ft.shape', ft.shape)
# synthesize onto the half-height grid and keep the real part: [batch, feature_in, beta, alpha]
x = s2_ifft_half_height(torch.tensor(ft, dtype=torch.float32))[..., 0]

# x is already a float32 tensor, so pass it directly (torch.tensor(x) would copy
# and warn); calling the module rather than .forward() also runs nn.Module hooks
out = model(x)

x.shape, out.shape

In [None]:
sample = out[2, 0].detach().numpy()  # third batch element, first output channel
plt.imshow(sample)
plt.show()

# Have a look at the grids used for the kernel

In [None]:
"""
def s2_soft_grid(b):
    beta = (np.arange(2 * b) + 0.5) / (2 * b) * np.pi
    alpha = np.linspace(start=0, stop=2 * np.pi, num=2 * b, endpoint=False)
    B, A = np.meshgrid(beta, alpha, indexing='ij')
    B = B.flatten()
    A = A.flatten()
    grid = np.stack((B, A), axis=1)
    return tuple(tuple(ba) for ba in grid)
"""

# debug