# Preview

## Issues

todo
- try fitting a neural network to `tot`; expect large variance from the unknown tail

ideas
- transform of outputs
- pca outputs, to reduce dimensionality as needed
- pca inputs, to reduce complexity
- test for signal
  - 1 vs 2 nearest neighbor outputs
  - chl-a retrieval algorithms
- dealing with unbalanced data (are they unbalanced?)
- try classification only

## Imports

In [None]:
import importlib
import os
import warnings
import datetime as dt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

from IPython.display import Markdown
from scipy.stats import zscore
import holoviews as hv
import hvplot.xarray
import numpy as np
import pandas as pd
import panel as pn
import param as p
import xarray as xr
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp

#from re_nobm_pcc import preprocess
from re_nobm_pcc import DATA_DIR, WAVELENGTH, TAXA
from re_nobm_pcc import kit

warnings.filterwarnings(action='ignore', category=FutureWarning)
hv.opts.defaults(
    hv.opts.Bars(active_tools=[]),
    hv.opts.Curve(active_tools=[]),
    hv.opts.Image(active_tools=[]),
    hv.opts.Scatter(active_tools=[]),
    hv.opts.HexTiles(active_tools=[], tools=['hover']),
)
hv.extension('bokeh')

## Raw Data

The OASIM model requires absorption and backscattering for each phytoplankton group.

In [None]:
ds = []
for item in ['dia', 'chl', 'cya', 'coc', 'pha', 'din']:
    path = f'../data/oasim_param/{item}1.txt'
    df = pd.read_table(path, sep='\t', dtype={0: int})
    df.columns = ('wavelength', 'absorption', 'scattering')
    da = df.set_index('wavelength').to_xarray().expand_dims('component')
    da['component'] = [item]
    ds.append(da)
ds = xr.concat(ds, 'component')
(
    ds.hvplot.line(x='wavelength', y='absorption', by='component')
    + ds.hvplot.line(x='wavelength', y='scattering', by='component')
).cols(1)

The NOBM data provided by Cecile contains the ocean constituents that are sufficient inputs for the OASIM Fortran library to calculate Rrs.

Below 350 nm, however, there is no phytoplankton absorption data, so those Rrs values should be ignored.
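
One way to ignore the short-wavelength values is to mask them before any analysis. The sketch below uses a made-up wavelength grid and made-up Rrs values (both hypothetical, not taken from the NOBM files) and replaces everything below 350 nm with NaN.

```python
import numpy as np

# Hypothetical wavelength grid (nm) and Rrs spectrum, for illustration only.
wavelength = np.arange(250, 751, 50)
rrs = np.linspace(0.001, 0.011, wavelength.size)

# Keep 350 nm and above; values below 350 nm have no phytoplankton absorption data.
valid = wavelength >= 350
rrs_masked = np.where(valid, rrs, np.nan)
```

The notebook itself restricts the spectrum with `ds.sel` and the `WAVELENGTH` bounds, which drops the values instead of masking them.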

In [None]:
paths = [f'../data/rrs_day/rrs{1998+i}{1+j:02}.nc' for i in range(23, 24) for j in range(4, 5)]
ds = xr.open_mfdataset(paths)

In [None]:
# FIXME move to re_nobm_pcc.simulate.py once oasim is fixed
ds = ds.roll({'lon': ds.sizes['lon'] // 2})
coords = xr.Dataset(
    coords={
        'lon': np.linspace(-180, 180, ds.sizes['lon']),
        'lat': np.linspace(-84, 71.4, ds.sizes['lat'])
    },
)
ds = xr.merge((ds.drop_vars(('lon', 'lat')), coords))
ds = ds.sel({'wavelength': slice(WAVELENGTH[0], WAVELENGTH[-1])})
(ds['date'].min().data, ds['date'].max().data)
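
The roll above shifts the data by half the longitude axis before new coordinates are attached. The sketch below shows the idea on a tiny grid, assuming the source grid runs from 0 to 360 degrees east in equal steps (the notebook does not state this; the arrays here are hypothetical).

```python
import numpy as np

# Assumption: the source grid runs 0..360 degrees east in equal steps.
lon_0_360 = np.arange(0, 360, 45)   # 0, 45, ..., 315
field = np.arange(lon_0_360.size)   # one value per longitude

half = lon_0_360.size // 2
# Roll the data by half the axis so the western hemisphere comes first.
field_rolled = np.roll(field, half)
# Relabel longitudes into the -180..180 convention and roll them the same way.
lon_centered = np.where(lon_0_360 >= 180, lon_0_360 - 360, lon_0_360)
lon_rolled = np.roll(lon_centered, half)
```

The notebook assigns `np.linspace(-180, 180, n)` instead, which includes both endpoints, so its coordinate spacing differs slightly from this relabelling.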

In [None]:
class Dashboard(p.Parameterized):
    
    # part of the GUI
    date = p.Date(dt.date(2021, 5, 8))
    h2o = p.Selector(['tot', 'dtc', 'pic', 'cdc', 't', 's'], label='Ocean Property Variable')
    phy = p.Selector(ds['component'].values.tolist(), label='Phytoplankton Group')
    # needed as dependencies, not part of the GUI
    data = p.ClassSelector(xr.Dataset)
    stream = hv.streams.Tap(x=0, y=0)
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.tap_rrs = xr.DataArray(
            np.empty((ds.sizes['wavelength'], 0), dtype=ds['rrs'].dtype),
            dims=('wavelength', 'tap'),
            name='rrs',
        )
        self.tap_phy = xr.DataArray(
            np.empty((ds.sizes['component'], 0), dtype=ds['phy'].dtype),
            dims=('component', 'tap'),
            name='phy',
        )
    
    @p.depends('date', watch=True, on_init=True)
    def _load_date(self):
        self.data = ds.sel({'date': np.datetime64(self.date)}).load()
        
    @p.depends('data', 'h2o')
    def plt_h2o(self):
        da = self.data[self.h2o]
        return da.hvplot.image(x='lon', y='lat', clabel=self.h2o, title='')
    
    @p.depends('data', 'phy')
    def plt_phy(self):
        da = self.data['phy'].sel({'component': self.phy})
        plt = da.hvplot.image(x='lon', y='lat', clabel=self.phy, title='')
        self.stream.source = plt
        return plt

    @p.depends('stream.x', 'stream.y')
    def plt_rrs(self):
        da = self.data['rrs'].sel(
            {'lon': self.stream.x, 'lat': self.stream.y},
            method='nearest',
        )
        da = da.expand_dims('tap')
        self.tap_rrs = xr.concat((self.tap_rrs, da), dim='tap')
        return self.tap_rrs.hvplot(x='wavelength', by='tap', title='', color='black', fontscale=1.4)

    @p.depends('stream.x', 'stream.y')
    def plt_phy_bar(self):
        da = self.data['phy'].sel(
            {'lon': self.stream.x, 'lat': self.stream.y},
            method='nearest',
        )
        da = da.expand_dims('tap')
        da = xr.concat((self.tap_phy, da), dim='tap')
        self.tap_phy = da.drop_vars(('lon', 'lat', 'date'))
        return self.tap_phy.hvplot.bar(x='component', by='tap', color='component', title='', fontscale=1.4)


dash = Dashboard(name='NOBM Variables and Computed Rrs')
pn.Column(
    pn.panel(dash.param, parameters=['date', 'h2o', 'phy'], widgets={'date': pn.widgets.DatePicker}),
    dash.plt_h2o,
    dash.plt_phy,
    dash.plt_rrs,
    dash.plt_phy_bar,
)

In [None]:
phy = TAXA[5]
da = ds['phy'].sel({'component': phy, 'date': '2021-05-08'})
plt = da.hvplot.image(x='lon', y='lat', title=phy, clabel='chl-a', aspect=288/234)
plt.options(fontscale=1.4)

## Chl-a OC4 Algorithm

OC4 (SeaWiFS) from https://oceancolor.gsfc.nasa.gov/atbd/chlor_a/

In [None]:
a = [0.32814, -3.20725, 3.22969, -1.36769, -0.81739]
blue = [443, 489, 510]
green = 555

dim = 'wavelength'
da = xr.DataArray(
    np.arange(len(WAVELENGTH)),
    coords={dim: ds[dim].loc[WAVELENGTH[0]:WAVELENGTH[-1]]},
)
blue = da.sel({dim: blue}, method='nearest').values.tolist()
green = da.sel({dim: green}, method='nearest').values.item()

a = tf.expand_dims(tf.constant(a), 1)
blue, green
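
For a single spectrum, OC4 is a fourth-order polynomial in the log10 of the maximum blue-to-green band ratio. The sketch below evaluates it in plain Python with the coefficients from the cell above; the function name and the Rrs values are hypothetical.

```python
import math

# OC4 (SeaWiFS) coefficients, copied from the cell above.
A = [0.32814, -3.20725, 3.22969, -1.36769, -0.81739]


def oc4_log10_chl(rrs443, rrs489, rrs510, rrs555):
    """Return log10(chl-a) from four Rrs bands via the OC4 polynomial."""
    x = math.log10(max(rrs443, rrs489, rrs510) / rrs555)
    return sum(a_i * x ** i for i, a_i in enumerate(A))


# Hypothetical Rrs values; the largest blue band equals the green band,
# so the log ratio is zero and the result reduces to the first coefficient.
log_chl = oc4_log10_chl(0.004, 0.005, 0.0045, 0.005)
```

The TensorFlow cells below compute the same polynomial in batches by raising the log ratio to the powers 0 through 4 and multiplying by the coefficient column `a`.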

In [None]:
@tf.function
def log_blue_green_ratio(x, y):
    return (
        tf.expand_dims(tf.experimental.numpy.log10(
            tf.reduce_max(tf.gather(x, blue, axis=1), axis=1) / x[:, green]
        ), axis=1),
        tf.expand_dims(tf.experimental.numpy.log10(
            tf.reduce_sum(y, axis=1)
        ), axis=1),
    )

batch_size = 2 ** 10
tfds_rrs_day = tfds.builder('rrs_day_tfds', data_dir=DATA_DIR)
train, test = tfds_rrs_day.as_dataset(split=['split[7%:8%]', 'split[9%:10%]'], as_supervised=True)
train_size = train.cardinality()
test_size = test.cardinality()
train = train.batch(batch_size).cache()
test = test.batch(batch_size).cache()

### no retraining

In [None]:
log_x, log_y = test.map(log_blue_green_ratio).rebatch(test_size).get_single_element()
log_y = log_y[:, 0].numpy()
log_y_hat = (log_x ** tf.range(5, dtype=np.float32) @ a)[:, 0].numpy()

In [None]:
R2 = 1 - ((log_y - log_y_hat)**2).sum()/((log_y - log_y.mean())**2).sum()
print(f'R2: {R2}')
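
The same R2 expression recurs in the cells below. As a minimal sketch, it could be wrapped in a helper (the name `r_squared` is hypothetical); the toy inputs show that the value is 1 for perfect predictions, 0 when predicting the mean, and negative when predictions are worse than the mean.

```python
def r_squared(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1 - ss_res / ss_tot


truth = [1.0, 2.0, 3.0]
perfect = r_squared(truth, [1.0, 2.0, 3.0])
baseline = r_squared(truth, [2.0, 2.0, 2.0])
worse = r_squared(truth, [3.0, 2.0, 1.0])
```

A negative R2 is therefore possible for the un-retrained coefficients and is not by itself a sign of a coding error.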

In [None]:
pred = hv.Dimension('prediction', range=(-2.5, 2))
true = hv.Dimension('truth', range=(-16, 3))
count = hv.Dimension('count', range=(1, 10**6))
plt = (
    hv.HexTiles((log_y_hat, log_y), kdims=[pred, true], vdims=count)
    * hv.Slope(1, 0)
)
plt = plt.options('HexTiles', logz=True, colorbar=True, aspect=1, fontscale=1.4)
plt = plt.options('Slope', color='red', line_width=2)
plt

### re-trained coeffs but still OC4

In [None]:
network = tf.keras.Sequential([
    tf.keras.layers.Lambda(lambda x: x ** tf.range(1, 5, dtype=np.float32)),
    tf.keras.layers.Dense(1),
])
network.compile(
    optimizer=tf.optimizers.Adam(learning_rate=3e-4),
    loss=tf.keras.losses.MeanSquaredError(),
)

In [None]:
fit = network.fit(
    train.map(log_blue_green_ratio).cache().shuffle(batch_size * 4),
    epochs=10,
)

In [None]:
log_x, log_y = test.map(log_blue_green_ratio).rebatch(test_size).get_single_element()
log_y_hat = network(log_x)[:, 0].numpy()
log_y = log_y[:, 0].numpy()

In [None]:
R2 = 1 - ((log_y - log_y_hat)**2).sum()/((log_y - log_y.mean())**2).sum()
print(f'R2: {R2}')

In [None]:
plt = (
    hv.HexTiles((log_y_hat, log_y), kdims=[pred, true], vdims=count)
    * hv.Slope(1, 0)
)
plt = plt.options('HexTiles', logz=True, colorbar=True, aspect=1, fontscale=1.4)
plt = plt.options('Slope', color='red', line_width=2)
plt

### mlp

In [None]:
@tf.function
def four_wavelengths(x, y):
    return (
        tf.experimental.numpy.log10(
            tf.concat(
                (tf.gather(x, blue, axis=1), tf.gather(x, [green], axis=1)),
                axis=1,
            )
        ),
        tf.experimental.numpy.log10(
            tf.reduce_sum(y, axis=1)
        ),
    )    

network = tf.keras.Sequential([
    tf.keras.layers.Dense(64, 'relu'),
    tf.keras.layers.Dense(64, 'relu'),
    tf.keras.layers.Dense(1),
])
network.compile(
    optimizer=tf.optimizers.Adam(learning_rate=3e-4),
    loss=tf.keras.losses.MeanSquaredError(),
)
train_cache = train.map(four_wavelengths).cache()
# Iterate once so the cache is filled before training starts.
for _ in train_cache: pass

In [None]:
fit = network.fit(
    train_cache.shuffle(batch_size * 4),
    epochs=100,
)

In [None]:
log_x, log_y = test.map(four_wavelengths).rebatch(test_size).get_single_element()
log_y_hat = network(log_x)[:, 0].numpy()
log_y = log_y.numpy()

In [None]:
R2 = 1 - ((log_y - log_y_hat)**2).sum()/((log_y - log_y.mean())**2).sum()
print(f'R2: {R2}')

In [None]:
pred = hv.Dimension('prediction', range=(-6, 3))
plt = (
    hv.HexTiles((log_y_hat, log_y), kdims=[pred, true], vdims=count)
    * hv.Slope(1, 0)
)
plt = plt.options('HexTiles', logz=True, colorbar=True, aspect=1, fontscale=1.4)
plt = plt.options('Slope', color='red', line_width=2)
plt

### mlp, loc scale out

In [None]:
network = tf.keras.Sequential([
    tf.keras.layers.Dense(64, 'relu'),
    tf.keras.layers.Dense(64, 'relu'),
    # tf.keras.layers.Dense(64, 'relu'),
    tf.keras.layers.Dense(2),
    tfp.layers.IndependentNormal(1),    
])
network.compile(
    optimizer=tf.optimizers.Adam(learning_rate=3e-4),
    loss=lambda y, model: tf.reduce_sum(-model.log_prob(y)),
)
train_cache = train.map(four_wavelengths).cache()
for _ in train_cache: pass
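
The loss above is the negative log-likelihood of the target under the predicted Normal distribution. The sketch below writes that quantity out for one sample in plain Python (the function name is hypothetical, and this is the textbook formula, not code from TensorFlow Probability).

```python
import math


def normal_nll(y, loc, scale):
    """Negative log-density of y under Normal(loc, scale)."""
    return 0.5 * math.log(2 * math.pi * scale ** 2) + 0.5 * ((y - loc) / scale) ** 2


# With a residual of 3, a wider predicted scale is penalized less
# than an overconfident narrow one.
narrow = normal_nll(3.0, 0.0, 1.0)
wide = normal_nll(3.0, 0.0, 3.0)
```

This is why a location-scale output can report larger uncertainty where the four bands carry little signal, instead of being forced to a single point estimate.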

In [None]:
fit = network.fit(
    train_cache.shuffle(batch_size * 4),
    epochs=100,
)

In [None]:
log_x, log_y = test.map(four_wavelengths).rebatch(test_size).get_single_element()
log_y_model = network(log_x)
log_y_hat = log_y_model.mean()[:, 0].numpy()
log_y = log_y.numpy()

In [None]:
R2 = 1 - ((log_y - log_y_hat)**2).sum()/((log_y - log_y.mean())**2).sum()
print(f'R2: {R2}')

In [None]:
plt = (
#    hv.HexTiles((log_y_hat, log_y), kdims=['pred', 'true'], vdims='cbar')
    hv.HexTiles((log_y_hat, log_y), kdims=[pred, true], vdims=count)
    * hv.Slope(1, 0)
)
plt = plt.options('HexTiles', logz=True, colorbar=True, aspect=1, fontscale=1.4)
plt = plt.options('Slope', color='red', line_width=2)
plt

In [None]:
idx = (log_y_model.stddev() / tf.abs(log_y_model.mean()) < 0.5)[:, 0].numpy()
idx.sum() / idx.size

In [None]:
R2 = 1 - ((log_y[idx] - log_y_hat[idx])**2).sum()/((log_y[idx] - log_y[idx].mean())**2).sum()
print(f'R2: {R2}')

In [None]:
plt = (
    hv.HexTiles((log_y_hat[idx], log_y[idx]), kdims=[pred, true], vdims=count)
    * hv.Slope(1, 0)
)
plt = plt.options('HexTiles', logz=True, colorbar=True, aspect=1, fontscale=1.4)
plt = plt.options('Slope', color='red', line_width=2)
plt

In [None]:
std_norm = (log_y - log_y_model.mean()[:, 0].numpy()) / log_y_model.stddev()[:, 0].numpy()
ecdf = kit.ecdf(std_norm)
by = 1000
x = np.linspace(-5, 9, 80)
y = tfp.distributions.Normal(0, 1).cdf(x)
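
If the predicted distributions are well calibrated, the standardized residuals should follow a standard Normal, so their empirical CDF should track the Normal CDF. The implementation of `kit.ecdf` is not shown in this notebook; the sketch below is one possible version that returns values in the input order (an assumption), with ties given distinct ranks for simplicity.

```python
import math


def ecdf(values):
    """Fraction of samples <= each value, returned in the input order."""
    n = len(values)
    order = sorted(range(n), key=lambda i: values[i])
    out = [0.0] * n
    for rank, i in enumerate(order, start=1):
        out[i] = rank / n
    return out


def standard_normal_cdf(x):
    """CDF of Normal(0, 1) via the error function."""
    return 0.5 * (1 + math.erf(x / math.sqrt(2)))


# Hypothetical standardized residuals.
residuals = [0.3, -1.2, 2.0, 0.0]
empirical = ecdf(residuals)
reference = [standard_normal_cdf(r) for r in residuals]
```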

In [None]:
kdims = r'$$\text{log}(\textit{chl-a}\ [\mathit{mg}\ m^{-3}])$$'
vdims = r'$$\text{eCDF}$$'
plt = (
    hv.Scatter(
        (std_norm[::by], ecdf[::by]),
        kdims=kdims,
        vdims=vdims,
    ).options(color='black')
    * hv.Curve((x, y)).options(color='red')
)
plt.options(fontscale=1.4, aspect=1)

## Spectrum with Taxa

In [None]:
x, y = next(train.shuffle(32).take(4).as_numpy_iterator())
(hv.Curve(x) + hv.Bars(y)).opts(shared_axes=False)

---

## Outdated Below

## Preprocessed Data

The features and labels are both model output from NASA GMAO using the [NOBM and OASIM](https://gmao.gsfc.nasa.gov/gmaoftp/NOBM) models. The labels are four phytoplankton chlorophyll densities output by NOBM. The features are normalized water leaving radiances output by OASIM, using the NOBM model as input.

### Features

One NetCDF file contains all the predictor data. Note that the `FillValue` attribute is not set to `9.99e11` in the netCDF file (Cecile will fix in next version). There are no explicit coordinates given; they are documented as attributes.

### Labels

Each of twelve NetCDF files contains a month of NOBM model output. The first is representative. Unlike the HyperLwn file, this one contains coordinates.

The `PhytoChl` xarray.Dataset includes the different phytoplankton groups as variables.

## Plot your Data

### Features

The radiances currently make a nice map, but the data should be more sparsely sampled.

A few "typical" hyperspectral radiances.

Mean centered radiances and corresponding phytoplankton abundances.

SVD to reduce the wavelength dimension to `k` vectors accounting for the most variation in the features. The singular values are:

The corresponding vectors:

A matrix of univariate (diagonal) and bivariate (off-diagonal) histograms of the `scores`, or coefficients generating each wavelength by linear combination of the `vectors` above.
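
The SVD reduction described above can be sketched with NumPy as follows. The feature matrix here is random and purely hypothetical (200 samples by 30 wavelengths); the names `vectors` and `scores` follow the text, but the original cells are not available, so this is an assumption about how they were computed.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical features: 200 samples x 30 wavelengths.
features = rng.normal(size=(200, 30))
centered = features - features.mean(axis=0)

k = 3
u, s, vt = np.linalg.svd(centered, full_matrices=False)
vectors = vt[:k]            # k directions along the wavelength axis
scores = u[:, :k] * s[:k]   # coefficients of each sample on those directions
approx = scores @ vectors   # rank-k reconstruction of the centered features
```

The singular values in `s` come back in descending order, so the first `k` rows of `vt` account for the most variation.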

### Labels

A map of the phytoplankton labels in `PhytoChl` at one month.

The distribution of the four phytoplankton groups.