# Spherical Fourier Neural Operators

A simple notebook showcasing Spherical Fourier Neural Operators (SFNOs) applied to downscaling MERRA-2 data: the model is trained to map low-resolution inputs to high-resolution targets.


## Preparation

In [1]:
%matplotlib ipympl
import logging, os, time
from typing import Any, Dict, List, Tuple, Type, Optional, Union, Sequence, Mapping

import hydra
import torch
from hydra.core.global_hydra import GlobalHydra

from fmod.base.util.logging import lgm, exception_handled, log_timing
from fmod.pipeline.downscale import Downscaler
from fmod.base.util.dates import date_list
from fmod.base.util.config import cfg, configure, start_date,  cfg2args, pp
from fmod.models.sfno.downscale.network import SphericalFourierNeuralOperatorNet as SFNO, sfno_network_parms
from fmod.plot.training_results import ResultsPlotter
from fmod.pipeline.trainer import DualModelTrainer
from fmod.pipeline.merra2 import MERRA2Dataset

In [2]:
# Initialise Hydra against ../config and load the 'merra2-sr1' configuration.
# Hydra refuses to initialise twice in one process, so clear any previous
# global Hydra state first; this keeps the cell safe to re-run in the same kernel.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
hydra.initialize(version_base=None, config_path="../config")
configure('merra2-sr1')
# lgm().set_level( logging.DEBUG )

# Passed to trainer.train(); presumably controls loading / saving of model state.
load_state = True
save_state = True
# Resolutions ("vres") of the input and target MERRA2Dataset instances.
input_res = "low"
target_res = "high"
# Error type passed to trainer.inference(); alternative: "spectral-l2".
etype = "l2"

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    torch.cuda.set_device(device.index)

### Training data
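To train the SFNO we need training data. Below we build two MERRA-2 datasets over the same list of training dates: one with low-resolution inputs and one with high-resolution targets. We pair them in a `DualModelTrainer` and construct the model from the configured network parameters: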

In [6]:
# Build the input and target MERRA-2 datasets over ONE shared list of training
# dates, so the two datasets cannot drift out of sync if the date range is edited.
task_cfg = cfg().task
train_dates = date_list( start_date( task_cfg ), task_cfg.max_days )
input_dataset  = MERRA2Dataset( train_dates=train_dates, vres=input_res, load_inputs=True, load_base=True )
target_dataset = MERRA2Dataset( train_dates=train_dates, vres=target_res, load_targets=True )
# The trainer pairs both datasets; its input/output grids define the model's shapes.
trainer = DualModelTrainer( input_dataset, target_dataset, device )
# Network hyper-parameters read from the "model" config section
# (presumably restricted to the keys in sfno_network_parms — confirm in cfg2args).
sfno_args: Dict = cfg2args( "model", sfno_network_parms )
model = SFNO( in_shape=trainer.input_grid, out_shape=trainer.output_grid, **sfno_args ).to(device)

In [7]:
# Pointwise model for sanity checking (disabled): a stack of 1x1 Conv2d layers with
# activations in between, usable as a drop-in replacement for the SFNO `model`.
# NOTE(review): uncommenting this requires `from torch import nn` and `from math import sqrt`;
# neither is imported in this notebook.
# NOTE(review): the `bias` argument is unused; hidden layers always use bias=True and the
# output layer always uses bias=False. If this baseline is no longer needed, delete the cell.
# class MLP(nn.Module):
#     def __init__(self,
#                  input_dim = 3,
#                  output_dim = 3,
#                  num_layers = 2,
#                  hidden_dim = 32,
#                  activation_function = nn.ReLU,
#                  bias = False):
#         super().__init__()
    
#         current_dim = input_dim
#         layers = []
#         for l in range(num_layers-1):
#             fc = nn.Conv2d(current_dim, hidden_dim, 1, bias=True)
#             # initialize the weights correctly
#             scale = sqrt(2. / current_dim)
#             nn.init.normal_(fc.weight, mean=0., std=scale)
#             if fc.bias is not None:
#                 nn.init.constant_(fc.bias, 0.0)
#             layers.append(fc)
#             layers.append(activation_function())
#             current_dim = hidden_dim
#         fc = nn.Conv2d(current_dim, output_dim, 1, bias=False)
#         scale = sqrt(1. / current_dim)
#         nn.init.normal_(fc.weight, mean=0., std=scale)
#         if fc.bias is not None:
#             nn.init.constant_(fc.bias, 0.0)
#         layers.append(fc)
#         self.mlp = nn.Sequential(*layers)

#     def forward(self, x):
#         return self.mlp(x)

# model = MLP(num_layers=10).to(device)

## Training the model

In [10]:
# Train the model (optionally resuming from / saving state), then run inference
# to collect inputs, targets, model predictions and interpolation baselines.
trainer.train(
    model,
    load_state=load_state,
    save_state=save_state,
)
(inputs, targets, predictions, interpolates) = trainer.inference(etype=etype)

In [None]:
# Plot inference results. Target channel ids are read from `target_dataset`, the
# dataset constructed with load_targets=True (presumably the source of `targets`);
# `input_dataset` was built with load_inputs/load_base only.
plotter = ResultsPlotter( inputs, targets, predictions, interpolates, chanids=target_dataset.channel_ids('target') )
plotter.plot()

In [13]:
# Disabled per-sample 3D plots of input / prediction / target / error.
# NOTE(review): this references `inp`, `out`, `tar`, `gridops` and `plt`, none of which
# are defined or imported in this notebook. The inference cell produces `inputs`,
# `targets` and `predictions` instead. Rename or import these before re-enabling.
# NOTE(review): the error panel reuses the input's vmin/vmax, which will generally not
# suit error magnitudes. `s` and `ch` presumably index sample and channel; confirm.
# s = 0; ch = 0
# print( f'Input shape: {inp.shape}, Output shape: {out.shape}, type = {type(inp)} ')
# input_image = inp[s, ch]
# vmin, vmax = gridops.color_range(input_image,2.0)
# 
# fig = plt.figure()
# im = gridops.plot_griddata(input_image, fig, projection='3d', title='input', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(out[s, ch], fig, projection='3d', title='prediction', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata(tar[s, ch], fig, projection='3d', title='target', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()
# 
# fig = plt.figure()
# im = gridops.plot_griddata((tar-out)[s, ch], fig, projection='3d', title='error', vmin=vmin, vmax=vmax )
# plt.colorbar(im)
# plt.show()