# Spherical Fourier Neural Operators

A simple notebook showcasing Spherical Fourier Neural Operators (SFNOs) operating on MERRA-2 data.


## Preparation

In [1]:
%matplotlib ipympl
from typing import Dict

import hydra
import torch
from hydra.core.global_hydra import GlobalHydra

from fmod.base.util.dates import date_list
from fmod.base.util.config import configure, cfg, start_date,  cfg2args
# from torch_harmonics.examples.sfno import SphericalFourierNeuralOperatorNet as SFNO
from fmod.model.sfno.network import SphericalFourierNeuralOperatorNet as SFNO, sfno_network_parms
from fmod.plot.training_results import ResultsPlotter
from fmod.controller.trainer import ModelTrainer
from data.merra2 import MERRA2Dataset

In [2]:
# Hydra keeps a process-wide singleton, so calling `hydra.initialize` a second time
# (e.g. when this cell is re-run) raises an "already initialized" error. Clear any
# previous instance first so the cell is idempotent.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
hydra.initialize(version_base=None, config_path="../config")
configure('merra2-sr')

# Seed torch's RNGs (CPU and all CUDA devices) so torch-based randomness,
# e.g. model weight initialization in the next cell, is reproducible.
SEED = 0
torch.manual_seed(SEED)

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda:0' if use_cuda else 'cpu')
if use_cuda:
    torch.cuda.set_device(device.index)

### Training data
To train our geometric FNOs, we require training data. To this end, let us prepare a MERRA-2 dataset that provides the training samples on the fly, together with the trainer and the SFNO model:

In [6]:
# Training dates: produced by `date_list` from the task's start date and its
# `max_steps` setting (presumably one entry per step -- TODO confirm).
task_cfg = cfg().task
train_dates = date_list(start_date(task_cfg), task_cfg.max_steps)
dataset = MERRA2Dataset(train_dates=train_dates, vres="high")
trainer = ModelTrainer(dataset)

# The SFNO is sized to the trainer's grid; the remaining network arguments are
# presumably read from the "model" config section by `cfg2args`.
sfno_args: Dict = cfg2args("model", sfno_network_parms)
model = SFNO(img_size=trainer.grid_shape, **sfno_args).to(device)

In [7]:
# pointwise model for sanity checking
# NOTE(review): this cell is entirely commented out. To re-enable the MLP baseline,
#   uncomment it and add `from torch import nn` and `from math import sqrt` to the
#   imports cell -- neither `nn` nor `sqrt` is imported in the cells shown here.
# NOTE(review): every layer is a 1x1 `nn.Conv2d`, so the same MLP is applied
#   independently at each grid point (hence "pointwise"). Hidden layers use a normal
#   init with std = sqrt(2 / fan_in) ahead of the activation.
# NOTE(review): the `bias` constructor argument is never used -- hidden layers hard-code
#   `bias=True` and the output layer hard-codes `bias=False`, so the output layer's
#   `if fc.bias is not None` branch never runs.
# NOTE(review): `input_dim`/`output_dim` default to 3; confirm they match the number of
#   MERRA-2 channels produced by the dataset before using `MLP(num_layers=10)` below.
# class MLP(nn.Module):
#     def __init__(self,
#                  input_dim = 3,
#                  output_dim = 3,
#                  num_layers = 2,
#                  hidden_dim = 32,
#                  activation_function = nn.ReLU,
#                  bias = False):
#         super().__init__()
    
#         current_dim = input_dim
#         layers = []
#         for l in range(num_layers-1):
#             fc = nn.Conv2d(current_dim, hidden_dim, 1, bias=True)
#             # initialize the weights correctly
#             scale = sqrt(2. / current_dim)
#             nn.init.normal_(fc.weight, mean=0., std=scale)
#             if fc.bias is not None:
#                 nn.init.constant_(fc.bias, 0.0)
#             layers.append(fc)
#             layers.append(activation_function())
#             current_dim = hidden_dim
#         fc = nn.Conv2d(current_dim, output_dim, 1, bias=False)
#         scale = sqrt(1. / current_dim)
#         nn.init.normal_(fc.weight, mean=0., std=scale)
#         if fc.bias is not None:
#             nn.init.constant_(fc.bias, 0.0)
#         layers.append(fc)
#         self.mlp = nn.Sequential(*layers)

#     def forward(self, x):
#         return self.mlp(x)

# model = MLP(num_layers=10).to(device)

## Training the model

In [10]:
# Fit the SFNO via the trainer, then unpack the (inputs, targets, predictions)
# triple returned by the trainer's inference pass.
trainer.train(model)
inference_outputs = trainer.inference()
inputs, targets, predictions = inference_outputs

In [None]:
# Visualize the inference targets alongside the model predictions.
plotter = ResultsPlotter(
    dataset,
    targets,
    predictions,
)
plotter.plot()

In [13]:
# NOTE(review): stale scratch code, kept commented out. `inp`, `out`, `tar`, `gridops`
#   and `plt` are not defined in any cell shown here; the current inference results are
#   named `inputs`, `targets` and `predictions`, and `plt` would need
#   `import matplotlib.pyplot as plt`. Rename and import before uncommenting, or delete.
# NOTE(review): the error panel reuses the input's (vmin, vmax), which is likely much
#   wider than the error's own range -- compute a separate color range for (tar - out).
# Plots input, prediction, target and error for sample `s`, channel `ch` using a 3-D projection.
# s = 0; ch = 0
# print( f'Input shape: {inp.shape}, Output shape: {out.shape}, type = {type(inp)} ')
# input_image = inp[s, ch]
# vmin, vmax = gridops.color_range(input_image,2.0)
# 
# fig = plt.figure()
# im = gridops.plot_griddata(input_image, fig, projection='3d', title='input', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(out[s, ch], fig, projection='3d', title='prediction', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(tar[s, ch], fig, projection='3d', title='target', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata((tar-out)[s, ch], fig, projection='3d', title='error', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()