In [None]:
%matplotlib inline
import pickle

import numpy as np
import pandas as pd
import scipy.interpolate as spi
import xarray as xr
import cartopy.crs as ccrs
import matplotlib
import matplotlib.pyplot as plt
from cftime import num2date, date2num

from foam.ocean import ocean
from foam.atmosphere import atmosphere
from foam.ionosphere import ionosphere
from foam.sky import sky
from foam.solver import solver, bin_observations
from foam.spacecraft import spacecraft, make_smap, make_aquarius, strings_to_epochs, revisit_time, angle_conversion
from foam.utils import reader
from foam.utils.retrieval_plots import plot_snapshot_dsss

matplotlib.rcParams.update({'figure.dpi': 150})  # render inline figures at 150 dpi


# Retrieve ocean salinity

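This notebook demonstrates how to use the FOAM package to simulate retrievals of sea surface salinity (SSS) over an arbitrary region and time window. Two workflows are included:

- **Set 1** simulates spacecraft measurements along an orbit and uses downloaded ancillary data.
- **Set 2** uses the package's default ancillary data on a regular global grid and ignores observation timing.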

## Set 1: Simulate spacecraft measurements, download ancillary data 

### 1. Create spacecraft 

In [None]:
# Select the simulation time window. `date_range` is what the ancillary-data
# readers below receive; `eps` holds the same window as start/end epochs for
# the orbit propagation.
date_range = ('02-01-2020', '2-04-2020')
eps = strings_to_epochs('2020 FEB 01 12:00 UTC', '2020 FEB 04 12:00 UTC')

# Choose a spacecraft: 'SMAP', 'Aquarius', or any other value for the custom
# (CIMR-like) orbit defined in the final branch.
pick_spacecraft = 'SMAP'

# epoch_res determines fidelity of orbit position interpolation, not sample rate
if pick_spacecraft == 'SMAP':
    craft = make_smap(*eps, epoch_res=10, sc_number=0)
elif pick_spacecraft == 'Aquarius':
    craft = make_aquarius(*eps, epoch_res=10, sc_number=0)
else:
    # Custom (this one is CIMR)
    height = 817  # km
    # Presumably converts the 55-degree incidence angle into the instrument look
    # angle at this altitude -- angle_conversion is imported from foam.spacecraft.
    look_angle = angle_conversion(height, 55, in_angle_type='incidence')
    inclination = np.radians(98.7)
    raan = 18 / 24 * 2 * np.pi  # 18/24 of a full turn, in radians
    tle_epoch = eps[0]
    # `spacecraft` is the name imported from foam.spacecraft; call it directly.
    craft = spacecraft(sc_number=0)
    elems = craft.get_manual_elems(inclination=inclination, raan=raan,
                                   height=height * 1e3, tle_epoch=tle_epoch)
    craft.write_tle_kernels(elems=elems, tle_epoch=tle_epoch, start_epoch=eps[0],
                            end_epoch=eps[1], epoch_res=10)
    craft.write_radiometer_ck(look_angle, 'Y', 7.8, 'X')  # Look angle and RPM


### 2. Make spacecraft observation grid 

In [None]:
# Observation-grid settings
bounds = ((-90, 90), (-180, 180))  # (latitude bounds, longitude bounds) in degrees
sample_time = 100e-3  # sampling interval, seconds
grid_res = 0.5  # grid spacing, degrees
grid_mode = 'linear'  # one of 'linear', 'ease', 'cosine' -- see the FOAM docs
make_plots = False

# revisit_time runs with parallel=True, so keep the call behind a __main__ guard.
if __name__ == '__main__':
    grid, lon_bins, lat_bins, obs_dict = revisit_time(
        craft, *eps, sample_time,
        plots=make_plots,
        grid_res=grid_res,
        grid_mode=grid_mode,
        bounds=bounds,
        parallel=True,
        nproc=6,
        ndiv=6,
    )
    # To cache the observation geometry for a later session, uncomment:
    # with open('obs_dict.p', 'wb') as f:
    #     pickle.dump(obs_dict, f)

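### 2.5 (Optional) Pre-download ancillary data

Run this cell only if you plan to use Option 2 in step 3, which builds the modules offline from these cached NetCDF files.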

In [None]:
# Fetch each ancillary product for `date_range` and cache it as NetCDF so that
# Option 2 below can build the modules with online=False.
ancillary_sources = [
    (reader.GHRSSTReader, 'ghrsst_data.nc'),  # Ocean: SST
    (reader.OISSSReader, 'oisss_data.nc'),    # Ocean: SSS
    (reader.NCEPReader, 'ncep_data.nc'),      # Atmosphere
    (reader.IONEXReader, 'ionex_data.nc'),    # Ionosphere
]

for reader_cls, out_path in ancillary_sources:
    rdr = reader_cls(date_range)
    ds = rdr.get_dataset()
    ds.to_netcdf(out_path)

### 3. Option 1: Create modules and download data

In [None]:
date_range = ('02-01-2020', '2-04-2020')

# Ocean is using empirical model functions for surface roughening
oc = ocean(date_range, mode='rough', online=True,
           sst_reader=reader.GHRSSTReader,
           sss_reader=reader.OISSSReader)  # You can also use reader.HYCOMReader, will take longer to download

# Atmosphere is using empirical model functions for transmissivity
# NOTE(review): the offline variant of this cell passes `reader=`/`file=` to
# atmosphere() rather than `atm_reader=`; confirm which keyword is accepted.
atm = atmosphere(date_range, mode='simple', online=True,
                 atm_reader=reader.NCEPReader)  # You can also use reader.MERRAReader, will take longer to download

# Ionosphere: use the same date window as the ocean/atmosphere modules so the
# TEC data matches the simulated observation dates.
ion = ionosphere(date_range, online=True,
                 tec_reader=reader.IONEXReader)

# Sky
sk = sky(scattered_galaxy=True)

### 3. Option 2: Create modules and use pre-downloaded data

In [None]:
date_range = ('02-01-2020', '2-04-2020')

# Offline mode: the ocean, atmosphere and ionosphere modules read the NetCDF
# files written by the pre-download step instead of fetching data.

# Ocean -- empirical model functions for surface roughening
oc = ocean(
    date_range,
    mode='rough',
    online=False,
    sst_reader=reader.GHRSSTReader,
    sst_file='ghrsst_data.nc',
    sss_reader=reader.OISSSReader,
    sss_file='oisss_data.nc',  # reader.HYCOMReader is an alternative SSS source (slower to download)
)

# Atmosphere -- empirical model functions for transmissivity
# NOTE(review): the online variant passes `atm_reader=` while this cell passes
# `reader=`/`file=`; confirm which keywords atmosphere() actually accepts.
atm = atmosphere(
    date_range,
    mode='simple',
    online=False,
    reader=reader.NCEPReader,  # reader.MERRAReader is an alternative (slower to download)
    file='ncep_data.nc',
)

# Ionosphere -- presumably {'from_dataset': True} makes the reader load the
# cached NetCDF rather than raw IONEX files; confirm in foam.utils.reader.
ion = ionosphere(
    date_range,
    online=False,
    tec_reader=reader.IONEXReader,
    tec_reader_kwargs={'from_dataset': True},
    tec_file='ionex_data.nc',
)

# Sky
sk = sky(scattered_galaxy=True)

### 4. Make forward model 

In [None]:
# Reuse the observation geometry computed in this kernel; fall back to the
# pickle cache only in a fresh session. Only unpickle files you created
# yourself -- pickle.load can execute arbitrary code.
if 'obs_dict' not in globals():
    with open('obs_dict.p', 'rb') as f:
        obs_dict = pickle.load(f)

# Radiometer configuration (also used by the retrieval cell)
frequency = np.array([1.4e3])  # presumably MHz -- confirm units in the foam docs
bandwidth = np.array([24])  # presumably MHz -- confirm
noise_figure = 2.
int_time = 50e-3  # presumably seconds -- confirm

sol = solver(ocean=oc, atmosphere=atm,
             ionosphere=ion, sky=sk)

# Forward model TBs are 'exact', noise is added prior to the retrieval stage.
# This may change in the future
TB, anc_pack = sol.compute_spacecraft_TB(frequency, obs_dict)

# Quick look: bin TB[0] (presumably the single frequency channel) onto a
# 1-degree lat/lon grid.
lat_bins = np.arange(-90, 90, 1)
lon_bins = np.arange(-180, 180, 1)
mean_TB, std_TB = bin_observations(TB[0].ravel(), obs_dict, lat_bins, lon_bins)

fig, ax = plt.subplots()
im = ax.pcolormesh(lon_bins, lat_bins, mean_TB.unstack().values, shading='auto')
fig.colorbar(im, ax=ax, label='Mean TB')
ax.set(xlabel='Longitude (deg)', ylabel='Latitude (deg)', title='Forward-model brightness temperature');

### 5. Retrieve salinity 

In [None]:
# Keep the retrieval behind a __main__ guard so the multiprocessing variant
# below can be swapped in safely.
if __name__ == '__main__':
    # Serial retrieval of sea surface salinity only
    outputs = sol.retrieval(
        TB, anc_pack,
        frequency=frequency,
        bandwidth=bandwidth,
        retrieve=['sss'],
        noise_figure=noise_figure,
        int_time=int_time,
    )

    # Parallel alternative (this variant also retrieves wind speed):
    # outputs = sol.parallel_retrieval(TB, anc_pack, frequency=frequency, bandwidth=bandwidth,
    #                                  retrieve=['sss', 'windspd'],
    #                                  noise_figure=noise_figure, int_time=int_time, nproc=4, ndiv=4)

    out_dict, unc_dict, anc_dict = outputs

### 6. Make some plots 

In [None]:
# Bin the retrieved SSS and its uncertainty onto a 0.5-degree grid
lat_bins = np.arange(-90, 90, 0.5)
lon_bins = np.arange(-180, 180, 0.5)
average_time = 24 * 3600  # one day, in seconds
epoch_bins = np.arange(np.min(obs_dict['epoch']), np.max(obs_dict['epoch']), average_time)
mean_sss, _ = bin_observations(out_dict['sss'], obs_dict, lat_bins, lon_bins, epoch_bins)
std_sss, _ = bin_observations(unc_dict['sss'], obs_dict, lat_bins, lon_bins, epoch_bins)

# NOTE(review): the reshape below assumes the binned result has exactly
# len(lat_bins) * len(lon_bins) entries; confirm how bin_observations handles
# epoch_bins (an extra time axis would break this reshape).
grid_shape = (len(lat_bins), len(lon_bins))

# Bind the figure explicitly; the panels below draw onto it.
fig = plt.figure(figsize=(12, 4))

panels = [
    (121, mean_sss, 'Retrieved SSS'),
    (122, std_sss, 'SSS uncertainty'),
]
for position, binned, label in panels:
    ax = fig.add_subplot(position, projection=ccrs.PlateCarree())
    ax.coastlines()
    im = ax.pcolormesh(lon_bins, lat_bins, binned.values.reshape(grid_shape),
                       shading='auto', cmap='turbo', transform=ccrs.PlateCarree())
    fig.colorbar(im, ax=ax, fraction=0.03, pad=0.04, label=label)
    ax.set_title(label)


## Set 2: Use package defaults, ignore timing, gridded map
### 1. Create modules

In [None]:
# Build every module offline from FOAM's packaged default data (no date window).
oc = ocean(online=False, mode='rough')         # empirical surface roughening
atm = atmosphere(online=False, mode='simple')  # empirical transmissivity
ion = ionosphere(online=False)
sk = sky(scattered_galaxy=True)

### 2. Make forward model 

In [None]:
# Radiometer configuration
frequency = np.array([1.4e3])
bandwidth = np.array([24])
noise_figure = 2.
int_time = 50e-3

incidence_angle = 40.  # presumably degrees -- confirm units expected by compute_TB

# Global 1-degree grid, flattened into 1-D observation vectors
lat = np.arange(-90, 90, 1.)
lon = np.arange(-180, 180, 1.)
lon_grid, lat_grid = np.meshgrid(lon, lat)
original_shape = lon_grid.shape  # (len(lat), len(lon)); used to un-flatten results
lon_grid = lon_grid.flatten()
lat_grid = lat_grid.flatten()
n_obs = lon_grid.size

times = np.zeros(n_obs)  # Placeholder; use_time=False is passed below
theta = np.full(n_obs, incidence_angle)
phi = np.zeros(n_obs)
sun_flag = np.zeros(n_obs, dtype=bool)
moon_flag = np.zeros(n_obs, dtype=bool)

sol = solver(ocean=oc, atmosphere=atm,
             ionosphere=ion, sky=sk)

# NOTE(review): ra/dec are filled with the geographic lon/lat grid here,
# presumably as a placeholder sky pointing -- confirm this is intended.
TB = sol.compute_TB(frequency, times, lat_grid, lon_grid, theta, phi,
                    ra=lon_grid, dec=lat_grid, sun_flag=sun_flag, moon_flag=moon_flag, use_time=False)
anc_pack = sol.ancillary_pack(times, lat_grid, lon_grid, theta, phi,
                              ra=lon_grid, dec=lat_grid, use_time=False)

# Plot TB[0] back on the original lat/lon grid
fig, ax = plt.subplots()
im = ax.pcolormesh(lon, lat, TB[0].reshape(original_shape), shading='auto')
fig.colorbar(im, ax=ax, label='TB')
ax.set(xlabel='Longitude (deg)', ylabel='Latitude (deg)', title='Forward-model brightness temperature');


### 3. Retrieve salinity 

In [None]:
# Keep the retrieval behind a __main__ guard so the multiprocessing variant
# below can be swapped in safely.
if __name__ == '__main__':
    # Serial retrieval of sea surface salinity
    outputs = sol.retrieval(
        TB, anc_pack,
        frequency=frequency,
        bandwidth=bandwidth,
        retrieve=['sss'],
        noise_figure=noise_figure,
        int_time=int_time,
    )

    # Parallel alternative:
    # outputs = sol.parallel_retrieval(TB, anc_pack, frequency=frequency, bandwidth=bandwidth,
    #                                  retrieve=['sss'],
    #                                  noise_figure=noise_figure, int_time=int_time, nproc=4, ndiv=4)

    out_dict, unc_dict, anc_dict = outputs

### 4. Make plots 

In [None]:
# Bin the gridded retrieval back onto the 1-degree global grid
mean_sss, _ = bin_observations(out_dict['sss'], anc_dict, lat, lon)
std_sss, _ = bin_observations(unc_dict['sss'], anc_dict, lat, lon)
grid_shape = (len(lat), len(lon))

fig = plt.figure(figsize=(10, 10))

# Top panel: retrieved SSS
ax = fig.add_subplot(211, projection=ccrs.PlateCarree())
ax.coastlines()
im = ax.pcolormesh(lon, lat, mean_sss.values.reshape(grid_shape),
                   shading='auto', cmap='turbo', transform=ccrs.PlateCarree())
fig.colorbar(im, ax=ax, fraction=0.03, pad=0.04, label='Retrieved SSS')
ax.set_title('Retrieved SSS')

# Bottom panel: retrieval uncertainty, colour scale fixed to 0-10
ax = fig.add_subplot(212, projection=ccrs.PlateCarree())
ax.coastlines()
im = ax.pcolormesh(lon, lat, std_sss.values.reshape(grid_shape),
                   shading='auto', cmap='turbo', transform=ccrs.PlateCarree(), vmin=0, vmax=10)
fig.colorbar(im, ax=ax, fraction=0.03, pad=0.04, label='SSS uncertainty')
ax.set_title('SSS retrieval uncertainty');

In [None]:
# Retrieval diagnostic from foam.utils.retrieval_plots (imported in the top
# import cell); the name suggests a snapshot of the SSS error -- confirm.
plot_snapshot_dsss(out_dict, anc_dict)