* 12 May 2024, Julian Mak (some tidy up of conda usage)
* 30 Apr 2022, Fei Er Yan + Julian Mak (whatever with copyright, do what you want with this)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
from sklearn import preprocessing
from sklearn.model_selection import train_test_split

In [None]:
# NOTE (JM 20 May): various things that could fix, some fixes below, (un)comment as required

# A) if using Colab, will need a separate upload of the data and mount on colab
#
# 1) go to https://drive.google.com/drive/folders/1JJ0cpshu6-JE8wp93UsHuqy6V33rQy7s?usp=sharing
# 2) download the folder
# 3) upload that to your own instance of Colab
# 4) load as something like below:

# from google.colab import drive
# drive.mount('/content/drive')
#data = xr.open_zarr("/path/to/folder/GLOB_HOMOGENEOUS_variables.zarr/")

# test to see if data actually exists (should return some folders, including BRV2)
#ls "/content/drive/MyDrive/Colab Notebooks/GLOB_HOMOGENEOUS_variables.zarr/"

# B) engine problems, e.g. unrecognized engine zarr must be one of: ['h5netcdf', 'scipy', 'store']
#
# brute force fix is to install complete
# !pip install xarray[complete] # xarray[io] # zarr
#
# then need to force a kernel restart ("Runtime -> Restart Session")

---------------------------

# loose introduction to Machine learning

[Machine Learning](https://en.wikipedia.org/wiki/Machine_learning) has found a ton of applications in multiple disciplines, notably in:

* image and facial recognition
* natural language processing (e.g. predictive text, translations)
* classical literature (e.g. [Kuzushiji](https://nips2018creativity.github.io/doc/deep_learning_for_classical_japanese_literature.pdf))
* self-driving cars (e.g. [HKUST efforts](https://seng.hkust.edu.hk/news/20200402/autonomous-vehicles-developed-hkust-engineering-professor-serve-community-during-covid-19-outbreak-mainland-china))
* Artificial Intelligence (e.g. the famous [AlphaGo](https://en.wikipedia.org/wiki/AlphaGo))

There have also been applications in multiple aspects of environmental sciences, including oceanography. In this notebook we are going to briefly touch on the more elementary applications. The focus here will be on demonstrating some use of machine learning and the relevant Python syntax, and less on the understanding of the algorithms themselves (they can get quite technical and mathematical). The explanations do exist, but we will largely use them like a *black box*, and just assume we have made some deal/sacrifice with say [Hermaeus Mora](https://elderscrolls.fandom.com/wiki/Hermaeus_Mora) in exchange for some answers.

<img src="https://i.imgur.com/ZFxVaBm.jpg" width="400" alt='Hermeowus Mora'>

(Disciple of Hermaeus Mora, Hermeowus Mora)

> NOTE: Again, these algorithms are just tools, and they have their own limitations, and should not be regarded as silver bullets that will solve all your problems. One criticism I have with these things is that these tools can work really well, but you might not understand why they work so well, and that is something that generally makes me a little uneasy.

Machine learning can briefly be split into **unsupervised** and **supervised** learning. Unsupervised is when you let the algorithms find features for you, while supervised is when the data itself is already tagged, and a model is *trained up* to try and reproduce the target data.

PCA would be an example of unsupervised learning, where you feed in the data, and the algorithm returns features that capture certain amounts of variance. Linear regression would be an example of supervised learning I guess, where given some outputs and inputs, you want to find some sort of model that minimises the mismatch between outputs and predictions using inputs. More on this later.
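
As a rough illustration of the distinction, here is a sketch on made-up data (nothing to do with the Argo data yet; the names `X_demo` and `y_demo` and all the numbers are invented for this example). PCA is only ever shown the inputs, while the regression needs the tagged outputs as well.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)

# synthetic inputs: 200 samples with 2 strongly correlated features
x1 = rng.normal(size=200)
X_demo = np.column_stack((x1, 0.5 * x1 + 0.1 * rng.normal(size=200)))

# unsupervised: PCA only sees the inputs, and returns directions ranked by variance
pca = PCA(n_components=2).fit(X_demo)
var_ratio = pca.explained_variance_ratio_

# supervised: regression needs tagged outputs as well as inputs
y_demo = 3.0 * X_demo[:, 0] - 2.0 * X_demo[:, 1] + 1.0   # noise-free "truth"
reg = LinearRegression().fit(X_demo, y_demo)

print("fraction of variance captured by each PC:", var_ratio)
print("recovered coefficients:", reg.coef_, "intercept:", reg.intercept_)
```

Because the "truth" here is noise-free and linear, the regression should recover the coefficients used to build `y_demo`; with real data neither of those assumptions holds.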

---------------------------

## a) Argo data

For our demonstration here we are going to be using data from the [argo observation system](https://argo.ucsd.edu/). Argo is a system of autonomous floats that are put into the ocean, drift around with the currents, and periodically do vertical profiles to take in-situ measurements of things like temperature, salinity, pressure, and so forth down to about 2000 m depth; see below for the schematic. There is increasing interest in [BGC-Argo](https://biogeochemical-argo.org/) floats that measure quantities relevant to biogeochemistry, and [deep Argo](https://argo.ucsd.edu/expansion/deep-argo-mission/) floats that go down to 4000 m.

<img src="https://argo.ucsd.edu/wp-content/uploads/sites/361/2020/06/float_cycle_1-768x424.png" width="600" alt='Argo'>

> NOTE: The namesake of argo is related to the [JASON](https://en.wikipedia.org/wiki/Jason-1) satellites if you know your Greek mythology.

There are multiple products and file formats that one could get (see later). Instead of the more standard **gridded** reanalyses products, we are going to be dealing with the float profile data directly. The data is collated in the folder `GLOB_HOMOGENEOUS_variables.zarr`, which can be opened with the [`zarr`](https://zarr.readthedocs.io/en/stable/) plugin, a so far experimental data format that promises to give better performance particularly when parallel processing is concerned. In this case we happen to open it using the `zarr` plugin through `xarray`.

> NOTE: If you are on your own computer, you probably want to do `conda install -c conda-forge zarr` in your environment. It might be available already on Google Colab, but if not, try `!pip install xarray[io]`.

In [None]:
# load the data (good luck...)
#
#data = xr.open_zarr("/path/to/folder/GLOB_HOMOGENEOUS_variables.zarr/")

data = xr.open_zarr("./GLOB_HOMOGENEOUS_variables.zarr/")
data

The files themselves exist as separate binary files, which are then collated into an xarray object here. The data is in dimensions of **depth** and **profile number**. There are various variables, but we are only going to be using `PSAL` (practical salinity) and `TEMP` (*in-situ* temperature). These are subsetted out in the code below, with a little bit of tidying up:

1) dropping all the entries with salinity outside an expected range (here anything outside of the 25 to 40 g/kg interval)

2) dropping profiles that have a NaN entry

In [None]:
# subset some data out

da_al = data[['PSAL','TEMP']] #(DEPTH: 302, N_PROF: 128910)
da_s  = da_al.where((da_al.PSAL <40.) & (da_al.PSAL>25.), drop= True)

# NOTE (JM Apr 15): if it complains about no indexing by booleans, try adding the .compute() bit in as below
# da_s  = da_al.where(((da_al.PSAL <40.) & (da_al.PSAL>25.)).compute(), drop= True)

da_s  = da_s.dropna('N_PROF')
da_s

### Visualising the raw profile data as a function of geographical co-ordinates

As per tradition, we follow the number -1 step of data analysis and plot out what the data looks like.

Each profile has associated with it a longitude and latitude, so in the below we do a scatter plot, and each profile is marked on as a dot. We colour the dots by the magnitude of the data, so this ends up looking a bit like a contour/pcolor graph.

> <span style="color:red">**Q.**</span> You can try and make plots through Cartopy in the stuff below.

In [None]:
# plot out what the observation data actually looks like

nl = 20 # change this index to plot different depths (as an index entry)

fig = plt.figure(figsize=(14, 4))

# temperature
ax = plt.subplot(1, 2, 1)
cs = ax.scatter(da_s.LONGITUDE, da_s.LATITUDE, 10, da_s.TEMP[:,nl], 
                cmap=plt.get_cmap('Spectral_r'), zorder=3)
ax.set_xlabel(r"lon ($^\circ$)")
ax.set_ylabel(r"lat ($^\circ$)")
plt.colorbar(cs)
ax.grid(lw=0.5, zorder=0)
ax.set_title(f"Temp at {da_s.DEPTH[nl].values} m")

# salinity
ax = plt.subplot(1, 2, 2)
cs = ax.scatter(da_s.LONGITUDE, da_s.LATITUDE, 2, da_s.PSAL[:,nl], 
                alpha=.5, cmap=plt.get_cmap('viridis', 5), zorder=3)
ax.set_xlabel(r"lon ($^\circ$)")
ax.set_ylabel(r"lat ($^\circ$)")
plt.colorbar(cs)
ax.grid(lw=0.5, zorder=0)
ax.set_title(f"Salinity at {da_s.DEPTH[nl].values} m")

We consider subsetting the data a bit more into different regions. Below are some of the ones we chose in light of the clustering analysis we will be doing later, using xarray conditionals. We then plot out the data to see the geographical distribution of the data we subset out at a fixed depth.

In [None]:
# subsetting and showing different locations

# North Atlantic
da_na=da_s.where(  (da_s['LATITUDE']  >   0.) 
                 & (da_s['LATITUDE']  <= 50.) 
                 & (da_s['LONGITUDE'] > -78.) 
                 & (da_s['LONGITUDE'] <  31.),
                 drop=True)

# higher latitude, split into Atlantic and Pacific sector
da_ao=da_s.where(da_s['LATITUDE'] > 50., drop=True)
da_ao1=da_s.where((da_s['LATITUDE'] > 50.) & (da_s['LONGITUDE'] >-78.) & (da_s['LONGITUDE'] < 31), drop=True)
da_ao2=da_ao.where((da_ao['LONGITUDE'] >= 100.) | (da_ao['LONGITUDE'] <= -100), drop=True)

# Southern Ocean, split into a few sectors
da_so=da_s.where(da_s['LATITUDE'] <=  -56., drop=True)
da_so1=da_s.where((da_s['LATITUDE'] > -56.) & (da_s['LATITUDE'] <= -40.) , drop=True)

In [None]:
# plot out the locations that have been subsetted out

fig = plt.figure(figsize=(10, 6))
ax = plt.axes()
ax.plot(da_s.LONGITUDE  , da_s.LATITUDE,   "o", markersize=2, label="total")
ax.plot(da_na.LONGITUDE , da_na.LATITUDE,  "o", markersize=2, label='North Atlantic')
ax.plot(da_ao1.LONGITUDE, da_ao1.LATITUDE, "o", markersize=2, label='Arctic Ocean')
ax.plot(da_so.LONGITUDE , da_so.LATITUDE,  "o", markersize=2, label='Southern Ocean')
ax.plot(da_so1.LONGITUDE, da_so1.LATITUDE, "o", markersize=2, label='Southern Ocean 1')
ax.plot(da_ao2.LONGITUDE, da_ao2.LATITUDE, "o", markersize=2, label='Arctic Ocean 2')
ax.set_xlabel(r"lon ($^\circ$)")
ax.set_ylabel(r"lat ($^\circ$)")
plt.grid()
ax.legend()
ax.set_title("Geographical locations of subsets")

### Visualising the raw data in $TS$-space

The diagrams below show the above labelled data in a **$TS$-diagram** (i.e. data but in temperature-salinity space) at some chosen depths, which provides another visualisation of the data that is perhaps more in line with the **watermass properties**. Notice for example the North Atlantic data tends to be clustered in a certain region in $TS$-space, with the distinguishing feature of it being generally quite salty; this is consistent with the observations that the Atlantic waters tend to be more salty because of the known physical oceanic processes at play (e.g. lec 5 of OCES 2003; `09/10_fun_with_maps` data).

In [None]:
# TS-diagrams at different depths
nl = 0

fig = plt.figure(figsize=(10, 6))
ax = plt.subplot(1, 2, 1)
ax.plot(da_s.PSAL  [:, nl], da_s.TEMP  [:, nl], "o", markersize=2, label="total")
ax.plot(da_na.PSAL [:, nl], da_na.TEMP [:, nl], "o", markersize=2, label='North Atlantic')
ax.plot(da_ao1.PSAL[:, nl], da_ao1.TEMP[:, nl], "o", markersize=2, label='Arctic Ocean')
ax.plot(da_so.PSAL [:, nl], da_so.TEMP [:, nl], "o", markersize=2, label='Southern Ocean')
ax.plot(da_so1.PSAL[:, nl], da_so1.TEMP[:, nl], "o", markersize=2, label='Southern Ocean 1')
ax.plot(da_ao2.PSAL[:, nl], da_ao2.TEMP[:, nl], "o", markersize=2, label='Arctic Ocean 2')
ax.grid()
plt.legend()
ax.set_ylabel(r'Temperature ($^\circ\ \mathrm{C}$)')
ax.set_xlabel(r'Salinity ($\mathrm{g}/\mathrm{kg}$)')
ax.set_title(f"TS diagram at {da_s.DEPTH[nl].values} m")

nl = 20

ax = plt.subplot(1, 2, 2)
ax.plot(da_s.PSAL  [:, nl], da_s.TEMP  [:, nl]  , "o", markersize=2, label="total")
ax.plot(da_na.PSAL [:, nl], da_na.TEMP [:, nl] , "o", markersize=2, label='North Atlantic')
ax.plot(da_ao1.PSAL[:, nl], da_ao1.TEMP[:, nl], "o", markersize=2, label='Arctic Ocean')
ax.plot(da_so.PSAL [:, nl], da_so.TEMP [:, nl] , "o", markersize=2, label='Southern Ocean')
ax.plot(da_so1.PSAL[:, nl], da_so1.TEMP[:, nl], "o", markersize=2, label='Southern Ocean 1')
ax.plot(da_ao2.PSAL[:, nl], da_ao2.TEMP[:, nl], "o", markersize=2, label='Arctic Ocean 2')
ax.grid()
plt.legend()
ax.set_ylabel(r'Temperature ($^\circ\ \mathrm{C}$)')
ax.set_xlabel(r'Salinity ($\mathrm{g}/\mathrm{kg}$)')
ax.set_title(f"TS diagram at {da_s.DEPTH[nl].values} m")

### Visualising the data as meridional sections

For demonstration purposes we are going to focus on the Atlantic here. We also plot out the raw profile data as a scatter plot as above to show the distribution of temperature and salinity.

In [None]:
# select Atlantic sector

da_aw = da_s.where((da_s['LONGITUDE'] > -75.) & (da_s['LONGITUDE'] < 17), drop=True)
da_aw

In [None]:
# plot temperature and salinity at fixed depth only in Atlantic sector

nl = -20

fig = plt.figure(figsize=(16, 8))
ax = plt.subplot(1, 2, 1)
cs = ax.scatter(da_aw.LONGITUDE, da_aw.LATITUDE, 10, da_aw.TEMP[:, nl],
                cmap = plt.get_cmap('Spectral_r', 10))
ax.set_xlabel(r"lon ($^\circ$)")
ax.set_ylabel(r"lat ($^\circ$)")
ax.set_title(f"Temperature at {da_s.DEPTH[nl].values} m")
plt.colorbar(cs)
ax.grid()

ax = plt.subplot(1, 2, 2)
cs = ax.scatter(da_aw.LONGITUDE, da_aw.LATITUDE, 10, da_aw.PSAL[:, nl],
                cmap = plt.get_cmap('viridis', 10))
ax.set_xlabel(r"lon ($^\circ$)")
ax.set_ylabel(r"lat ($^\circ$)")
ax.set_title(f"Salinity at {da_s.DEPTH[nl].values} m")
plt.colorbar(cs)
ax.grid()

Since the data here is not gridded and each profile has its own longitude and latitude, it is not immediately possible to do meridional sections, and further processing is required. 

> <span style="color:red">**Q.**</span> One way is to **interpolate** the data. This could be done through e.g. `scipy.interpolate`, since you basically have a collection of co-ordinates with associated data, or through xarray (which as far as I can tell leverages the `scipy.interpolate` anyway). The other leverages the xarray functionality `.groupby('LATITUDE')`, and then taking averages (skipping NaN values). This procedure results in an averaging over profiles, longitude and time, so is really a meridional section of the time and zonally averaged data.
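
One possible shape of the binning-and-averaging route is sketched below with plain NumPy on made-up scattered profiles (the variable names, the 10 degree bin width and the cosine temperature field are all assumptions for illustration). Binning into latitude bands, rather than grouping on exact latitude values, is a choice made here because scattered floats rarely share a latitude exactly.

```python
import numpy as np

rng = np.random.default_rng(1)

# made-up scattered "profiles": a latitude plus a temperature at 3 depth levels
n_prof = 500
lat = rng.uniform(-60.0, 60.0, n_prof)
temp = 25.0 * np.cos(np.deg2rad(lat))[:, None] * np.array([1.0, 0.6, 0.3])

# assign each profile to a 10-degree latitude bin
edges = np.arange(-60.0, 61.0, 10.0)
bin_idx = np.digitize(lat, edges) - 1          # values in 0 ... len(edges) - 2
n_bins = len(edges) - 1

# average all profiles falling in the same bin (NaNs skipped), level by level
section = np.full((n_bins, temp.shape[1]), np.nan)
for b in range(n_bins):
    in_bin = bin_idx == b
    if in_bin.any():
        section[b] = np.nanmean(temp[in_bin], axis=0)

lat_mid = 0.5 * (edges[:-1] + edges[1:])       # bin centres, for plotting
print(section.shape)                           # (latitude bin, depth level)
```

The array `section` could then be passed to something like `plt.contourf` against `lat_mid` and depth. xarray has a `groupby_bins` method that does the binning step for you, which may be the more natural route with `da_aw`.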

---------------------------

## b) Example of unsupervised learning: cluster analysis

As advertised above, unsupervised learning is where you let the algorithms pick out the data features of interest. PCA (and by corollary EOF analysis) is one example of this, which we have encountered already in `04_regression`. These kinds of algorithms have found uses in image recognition and reconstruction; the example below shows PCA applied to pictures of cats and dogs (right panels are the first 4 PCs of data; from Fig 10 of [Brunton, Brunton, Proctor & Kutz (2013)](https://arxiv.org/pdf/1310.4217.pdf)).

<img src="https://i.imgur.com/D5TJanm.png" width="800" alt='brunton_et_al_13_fig10'>

Unsupervised learning is useful for data exploration. In the argo data of interest to us here, we know from theory and observations already that different water masses cluster in different ways, so can machine learning pick those out for us? In the case below, we are going to demonstrate the use of the **$k$-means** algorithm, which is one possible way of identifying clusters. $k$-means very loosely considers partitions of the data, computing the means of the data associated with each partition, and iterating on the choice of partition such that there is a minimisation of the deviations of the partitioned data from the means. The algorithm is available in `scikit-learn`, with syntax demonstrated below.

> NOTE: Sounds familiar? A lot of machine learning could (and should?) be framed as an optimisation problem where we want to minimise some cost functional (often called the **loss function** in machine learning), and because it is an optimisation problem, we have some liberty in the choice of mismatch and/or regularisations, which may help with finding "better" solutions depending on the context.
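
To make the loose description of $k$-means a bit more concrete, below is a bare-bones sketch of the usual assign-then-update iteration (often called Lloyd's algorithm) on two made-up blobs of points. This is for illustration only and is not how `scikit-learn` implements it internally; the function name `kmeans_sketch` and the synthetic data are invented here.

```python
import numpy as np

def kmeans_sketch(points, k, n_iter=50, seed=0):
    """Bare-bones Lloyd iteration for k-means (illustrative only)."""
    rng = np.random.default_rng(seed)
    # initial guess: k distinct data points picked at random
    centres = points[rng.choice(len(points), size=k, replace=False)]
    for _ in range(n_iter):
        # assignment step: each point joins the partition of its nearest centre
        dist = np.linalg.norm(points[:, None, :] - centres[None, :, :], axis=2)
        labels = dist.argmin(axis=1)
        # update step: each centre moves to the mean of its members
        new_centres = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centres[j]
            for j in range(k)
        ])
        if np.allclose(new_centres, centres):
            break
        centres = new_centres
    # the cost being reduced: sum of squared deviations from the assigned centre
    cost = ((points - centres[labels]) ** 2).sum()
    return labels, centres, cost

# two well-separated synthetic blobs (made-up data)
rng = np.random.default_rng(42)
blob_a = rng.normal(loc=[0.0, 0.0], scale=0.3, size=(100, 2))
blob_b = rng.normal(loc=[5.0, 5.0], scale=0.3, size=(100, 2))
pts = np.vstack((blob_a, blob_b))

labels, centres, cost = kmeans_sketch(pts, k=2)
print(centres)
```

The result in general depends on the initial guess, which is why a seed is passed in; the same consideration is behind fixing `random_state` when using `KMeans`.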

For the case here we want to tell the algorithm what are the features of interest. We are going to stack the temperature and salinity data together, so the clustering tries to find clusters given both the temperature and salinity as an input feature.

> NOTE: We are going to get rid of the deeper parts of the data as a choice of pre-processing.

In [None]:
# only demonstrating one, could try others (uncomment accordingly)

from sklearn.cluster import KMeans
# from sklearn.mixture import GaussianMixture
# from sklearn.cluster import DBSCAN
# from sklearn.cluster import OPTICS

fx = np.stack((da_s.TEMP[:, :-30].values, da_s.PSAL[:, :-30].values), axis=2)
fx.shape

In the code below, I am going to fix a depth level indexed by `nl` and consider a clustering with `nc` clusters. The model is then fitted with the input data (with a seed specified to make sure the initial guess is fixed and so results are exactly reproducible), and a prediction is made. We then plot out the resulting clusters' distribution geographically as well as in $TS$-space.

In [None]:
# Fit to clustering model according to temp and salinity characteristics
nc = 5  # number of clusters
nl = 20 # level of data to be used

seed = 3315088937
model = KMeans(n_clusters=nc, random_state=seed)
# model = GaussianMixture(n_components=nc, random_state=seed)
# model = DBSCAN(eps=0.3, min_samples=1000)
# model = OPTICS(min_samples=100)

# fit data and use clustering for prediction
model.fit(fx[:, nl])
cluster_idx = model.predict(fx[:, nl])

In [None]:
# plot out the predictions
fig = plt.figure(figsize=(14, 6))

# horizontally varying
ax = plt.subplot2grid((1, 3), (0, 0), colspan=2)
cs = ax.scatter(da_s.LONGITUDE, da_s.LATITUDE, nc, cluster_idx, cmap = plt.get_cmap('viridis', nc))
ax.set_xlabel(r"lon ($^\circ$)")
ax.set_ylabel(r"lat ($^\circ$)")
ax.set_title(f'Clusters at {da_s.DEPTH[nl].values} m')
ax.grid()

# on TS diagram
ax = plt.subplot2grid((1, 3), (0, 2), colspan=1)
cs = ax.scatter(fx[:, nl, 1], fx[:, nl, 0], nc, cluster_idx, cmap = plt.get_cmap('viridis', nc))
ax.set_ylabel(r'Temperature ($^\circ\ \mathrm{C}$)')
ax.set_xlabel(r'Salinity ($\mathrm{g}/\mathrm{kg}$)')
ax.set_title(f'Clusters on TS diagram at {da_s.DEPTH[nl].values} m')
ax.grid()
cax = plt.colorbar(cs)
cax.set_ticks(np.arange(nc))

Notice that the clusters we got here are not that similar to the manual ones we specified in the exploratory plots above. However, there are some physical rationalisations here:

* the Eastern boundary water looks like it is indexed by cluster 0
* the Southern Antarctic water is indexed by cluster 1, highlighting the waters that are generally cold and relatively fresh
* there is a delineation between the Antarctic waters, between the colder and relatively fresh waters and what might be classified as the ACC waters that form part of cluster 4
* cluster 4 is picking out the polar waters
* there is a suggestive pattern for the subpolar gyres given by cluster 2 and 3, which are generally warmer

There have been some papers using similar techniques to identify watermass properties (e.g. [Jones et al., 2019](https://agupubs.onlinelibrary.wiley.com/doi/full/10.1029/2018JC014629) for the Southern Ocean, using Gaussian Mixture Model).

> <span style="color:red">**Q.**</span> I haven't standardised the data here, but you should try and see if it makes a difference (hint: it does quite a bit). Remember to invert the transform if you are going to plot the data in a $TS$-diagram. See below in the neural network part for some related code.

> <span style="color:red">**Q.**</span> Look up and/or try the other clustering models that have been commented out above.

> <span style="color:red">**Q.**</span> Try doing the clustering analysis for smaller regions (e.g. the Southern Ocean region; cf. [Jones et al., 2019](https://agupubs.onlinelibrary.wiley.com/doi/full/10.1029/2018JC014629)).
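
For the standardisation exercise, the mechanics might look something like the sketch below. It uses made-up temperature and salinity pairs rather than the Argo data, and the names `ts_toy`, `scaler_ts`, `km_std` and `km_raw` are invented for this example. The toy data is built so that temperature spans a much larger numerical range than salinity, which is the situation where scaling tends to matter.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(7)

# made-up (temperature, salinity) pairs: two salinity groups, wide temperature spread
temp_toy = rng.uniform(0.0, 30.0, size=400)
salt_toy = np.where(np.arange(400) < 200, 34.0, 36.0) + 0.1 * rng.normal(size=400)
ts_toy = np.column_stack((temp_toy, salt_toy))

# standardise each column (zero mean, unit variance) before clustering
scaler_ts = StandardScaler().fit(ts_toy)
ts_std = scaler_ts.transform(ts_toy)
km_std = KMeans(n_clusters=2, n_init=10, random_state=0).fit(ts_std)

# centres come out in standardised units; invert the transform for a TS-diagram
centres_phys = scaler_ts.inverse_transform(km_std.cluster_centers_)

# for comparison: the same clustering on the raw (unscaled) data
km_raw = KMeans(n_clusters=2, n_init=10, random_state=0).fit(ts_toy)

print("centres (standardised fit, back in physical units):\n", centres_phys)
print("centres (raw fit):\n", km_raw.cluster_centers_)
```

Comparing the two sets of centres gives a feel for how much the unscaled clustering is dominated by whichever variable has the larger numerical range (temperature in this toy case).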

---------------------------

## c) Example of supervised learning: neural networks

If we want to predict things we might want to employ (semi-)supervised learning instead. Ultimately we have

\begin{equation*}
    Y = f(X),
\end{equation*}

where $X$ is the input, $Y$ is the output, and $f$ is the model. Generically, we talk about **training/fitting** a model $f$ via exposing the associated algorithm to some 

* **training data** ($X_{\rm train}, Y_{\rm train}$) to minimise misfits encapsulated in some **loss function** (usually some sort of square of the mismatch)

* **validation data** ($X_{\rm val}, Y_{\rm val}$) for tuning model **hyperparameters** and/or selecting from a collection of models trained up

The acid test for the performance of the trained model is then examined through

* **test data** ($X_{\rm test}, Y_{\rm test}$) that the model has ***not*** seen before via some measure of mismatch between the "truth" data $Y_{\rm test}$ and prediction $f(X_{\rm test})$, with mismatch to be defined accordingly.

(Multi-)Linear regression is an example of the more basic supervised learning algorithms in this regard, where for the $L^2$ misfit (or loss function) we don't strictly have to distinguish training, validation or test data (though we might consider having training and test data at least). More sophisticated nonlinear algorithms such as **neural networks** normally do want a splitting, and we will demonstrate the procedure below for the argo data.
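
As a concrete instance of the loss function mentioned above, one common choice (assumed here for illustration; other measures of mismatch are possible) is the mean of the squared misfit over $N$ training samples, with $i$ a sample index introduced for this sketch:

\begin{equation*}
    \mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \left| Y_{{\rm train}, i} - f(X_{{\rm train}, i}) \right|^2 .
\end{equation*}

Training then amounts to choosing the free parameters in $f$ so that $\mathcal{L}$ is as small as possible.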

> NOTE: We are going to be using neural networks for supervised learning, but you could in principle use them for unsupervised as well as reinforcement learning. See for example the [wikipedia entry](https://en.wikipedia.org/wiki/Artificial_neural_network) on related procedures (you might want to do a Google search if you find the description on the Wikipedia page too abstract).

The goal here is to ***predict salinity from temperature*** (more for demonstration rather than scientific reasons). In this case we will do a Z-score standardisation of all the data, and a minor tidy up.
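
As a reminder of what the Z-score standardisation does, here is a small sketch on a made-up array laid out like the profile data (rows as profiles, columns as depth levels; the array `profiles` and its numbers are invented for illustration). It checks that `StandardScaler` matches the by-hand formula, and that the scaling can be undone.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(3)

# made-up array: 1000 "profiles" at 4 "depth levels"
profiles = rng.normal(loc=35.0, scale=0.5, size=(1000, 4))

# by hand: subtract the mean and divide by the standard deviation, column by column
z_manual = (profiles - profiles.mean(axis=0)) / profiles.std(axis=0)

# the same thing through scikit-learn
zscaler = StandardScaler().fit(profiles)
z_sklearn = zscaler.transform(profiles)

# undoing the scaling recovers the original values
recovered = zscaler.inverse_transform(z_sklearn)

print(np.allclose(z_manual, z_sklearn), np.allclose(recovered, profiles))
```

Note the scaler works column by column, so with the layout used here every depth level gets its own mean and standard deviation.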

In [None]:
# standardise the data

scaler = preprocessing.StandardScaler()
scaler.fit(da_s.PSAL)
data_salt=scaler.transform(da_s.PSAL)

scaler.fit(da_s.TEMP)
data_temp=scaler.transform(da_s.TEMP)

# input is temperature, output is salinity
#   don't include the stuff at the lower depths
xx = data_temp[:,:-30]
yy = data_salt[:,:-30]

# original size = (N_PROF: 128910, DEPTH: 302)
print(f"# of profiles = {xx.shape[0]} of size {xx.shape[1]}") 

We have over 100,000 profiles as a function of depth, and the input will be some vertical profile of temperature, while the output we want out of this is a vertical profile of salinity.

In anticipation of demonstrating the neural network algorithm, we will in this case (somewhat randomly) divide the data up into a test, validation and training set, using the `train_test_split` sub-function from `scikit-learn`. The code below does the following:

1) first split out 20% of the total data as a test set (so now we have 80% of the data left)

2) we split out 25% of the remaining data (25% of 80% = 20% of 100%) into a validation set

3) what is left over (60% of the original data set) is the training set

In [None]:
# split into test data first, then validation, and the remaining are training data
# X are the inputs, y are the outputs
indices = np.arange(xx.shape[0])

seed = 42

# split out test data (20% of 100%)
X_tr, X_test, y_tr, y_test, x_ind, ind_test = train_test_split(xx, yy, indices,
                                                               test_size=0.2, 
                                                               random_state=seed, 
                                                               shuffle=True)

# from the remaining, split out the 80% data into 20% validation (hence the 0.25) and remaining to be training
X_train, X_val, y_train, y_val, ind_train, ind_val = train_test_split(X_tr, y_tr, x_ind,
                                                                      test_size=0.25, 
                                                                      random_state=seed, 
                                                                      shuffle=True)

print(f"number of test       data = {X_test.shape[0]}")
print(f"number of validation data = {X_val.shape[0]}")
print(f"number of training   data = {X_train.shape[0]}")

So for the case here we will be using about 60% of the data to train up the neural network, 20% to validate it during training, and testing it against the remaining 20% of the data that the network has not been previously exposed to. The exact splitting is flexible, but holding out 10% to 20% of the data as a test set (i.e. a 90:10 or 80:20 split) is fairly common.

> NOTE: For neural networks it is generally considered good to use as much training data as possible.

The below graph plots the geographical distribution of the train/validation/test datasets.

In [None]:
# plot out where these are
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(14, 3), facecolor='w', edgecolor='k', sharey='row')

X = da_s
ax[0].plot(X.LONGITUDE.values[ind_train], X.LATITUDE.values[ind_train], 'C0.', markersize=0.5)
ax[1].plot(X.LONGITUDE.values[ind_val],   X.LATITUDE.values[ind_val],   'C1.', markersize=0.5)
ax[2].plot(X.LONGITUDE.values[ind_test],  X.LATITUDE.values[ind_test],  'C2.', markersize=0.5)

ax[0].set_title("Train data (60%)")
ax[1].set_title("Validation data (20%)")
ax[2].set_title("Test data (20%)")

ax[0].set_xlabel(r"lon ($^\circ$)")
ax[0].set_ylabel(r"lat ($^\circ$)")
ax[1].set_xlabel(r"lon ($^\circ$)")
ax[2].set_xlabel(r"lon ($^\circ$)")

### First try: linear regression

By the principle of Occam's razor we should probably at least try the simpler linear regression case first. We are going to use `scikit-learn` and train up a linear model using the (normalised) training dataset. We then plot the squared ($L^2$) mismatch between $y_{\rm test}$ and $f(X_{\rm test})$; averaging these and taking the square root gives the **root-mean-squared (RMS) loss**. In this case I didn't bother undoing the scaling, so a RMS loss larger than 1 is pretty bad.

> NOTE: Linear regression basically doesn't work, which is perhaps not a huge surprise, but lets demonstrate this explicitly. I am not going to be that careful about doing diagnostics for this case.

In [None]:
from sklearn import linear_model

# do a linear regression (note data here has already been Z-scored)

ols = linear_model.LinearRegression()
model = ols.fit(X_train, y_train)
print(f"R^2 score on training data = {model.score(X_train, y_train):.3f}")

y_pred = model.predict(X_test)

In [None]:
# plot the squared mismatches per index of prediction for SCALED data

fig = plt.figure(figsize=(10, 4))
ax = plt.axes()
ax.plot((y_pred - y_test).flatten()**2, "x")
ax.set_xlabel("index")
ax.set_ylabel(r"$(y - y_{\mathrm{data}})^2$")
ax.grid()
ax.set_title(r"Squared mismatch for SCALED data")

Note the array has been flattened, so each cross here is a prediction at some location and at some *fixed* depth. Quite a few predictions have very large squared mismatches, indicating linear regression has failed pretty hard here. This is not entirely a surprise, given we do not expect there to be a linear relation between temperature and salinity, as can be seen from the $TS$-diagram. Increasing the number of inputs might help, so you could try this in the extended exercise later.

### Neural networks 

A neural network is a network with a schematic like the one below (diagram taken from [Wikipedia](https://en.wikipedia.org/wiki/Artificial_neural_network)).

<img src="https://upload.wikimedia.org/wikipedia/commons/thumb/4/46/Colored_neural_network.svg/800px-Colored_neural_network.svg.png" width="200" alt='network schematic'>

The idea here is that each node (the blobs) is some feature, and each link is a connection with some **weights** (which could be deterministic or probabilistic in principle) leading to a transition, which is a recipe for transforming some input into some intermediate output. Given some input, the model splits it into multiple features, passes them through the network, each with some transitional weights or probabilities, and eventually arrives at a collection of outputs that is assembled to give you an output. For a simple case where only the weights are varied, for each choice of weights there is an associated mismatch between the prediction and the provided "truth", and the goal is to iterate on the weights such that the eventual associated mismatch (or the loss function) is minimised, or at least approaches some asymptote. Again, it might be helpful to think of these as optimisation problems (which is certainly what I tend to do, because I am more familiar with optimisation problems).
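
The statement that each choice of weights has an associated mismatch can be sketched with a tiny hand-rolled network in NumPy. This is a toy with one hidden layer, made-up sizes and random data, not the `keras` model used below; the helper names `relu`, `forward` and `mse` are invented here.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def forward(x, params):
    """Pass inputs through one hidden layer and a linear output layer."""
    W1, b1, W2, b2 = params
    hidden = relu(x @ W1 + b1)   # weighted sums at the hidden nodes, then a nonlinearity
    return hidden @ W2 + b2      # weighted sums again to assemble the output

def mse(y_hat, y):
    return np.mean((y_hat - y) ** 2)

rng = np.random.default_rng(5)
n_in, n_hidden, n_out = 3, 8, 2   # sizes chosen arbitrarily for illustration

x_demo = rng.normal(size=(10, n_in))
y_demo = rng.normal(size=(10, n_out))

# two different choices of weights give two different mismatches;
# training amounts to searching for the weights with the smallest one
params_a = (rng.normal(size=(n_in, n_hidden)), np.zeros(n_hidden),
            rng.normal(size=(n_hidden, n_out)), np.zeros(n_out))
params_b = tuple(0.1 * p for p in params_a)

loss_a = mse(forward(x_demo, params_a), y_demo)
loss_b = mse(forward(x_demo, params_b), y_demo)
print(loss_a, loss_b)
```

In practice the search over weights is done with gradient-based optimisers rather than by trying weights at random.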

Depending on the problem and available computational resources, various **hyperparameters** (e.g. the number of features, loss function decrease threshold, regularisation, number of hidden layers, model training parameters) might need/want to be varied. There are cases where one could vary the features themselves during the iterations, or employ other algorithms (e.g. **convolutional neural networks (CNN)**, **generative adversarial networks (GAN)** etc.), but this is well beyond the scope here.

For our problem we are going to keep it simple and just use a standard neural network. We are going to be using the `keras` package, and the `tensorflow` backend (another possibility is `pytorch`).

> NOTE: If you are using this notebook locally through Anaconda, you probably want to do `conda install -c conda-forge tensorflow` and `conda install -c conda-forge keras`. Tensorflow and keras should be available through Google Colab as is.

> NOTE: The present problem is small so I am not going to bother with GPU capabilities. Look online for how to get things working with GPUs.

In [None]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras import regularizers
from keras.models import Model
from keras import optimizers

The model is initialised here through `Sequential()`, and two hidden layers with 400 features are added in. The `Dropout` command is to drop some features (in this case 20% in each layer), which acts as a regulariser and reduces chances of overfitting.
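
What dropping 20% of the features means can be sketched in NumPy as below. This is the so-called inverted dropout, which as far as I am aware mirrors what `keras` does during training; the function `dropout_sketch` and the array of ones standing in for a layer output are made up for illustration.

```python
import numpy as np

def dropout_sketch(activations, rate, rng, training=True):
    """Inverted dropout: zero a random fraction `rate` of entries, rescale the rest."""
    if not training:
        return activations                         # nothing is dropped at prediction time
    keep = rng.random(activations.shape) >= rate   # True with probability 1 - rate
    return activations * keep / (1.0 - rate)       # rescaling keeps the expected value unchanged

rng = np.random.default_rng(11)
acts = np.ones((1000, 400))   # stand-in for the output of a 400-node layer

dropped = dropout_sketch(acts, rate=0.2, rng=rng)
frac_zero = np.mean(dropped == 0.0)
print(frac_zero, dropped.mean())
```

A different random set of entries is zeroed on every call, so the network cannot rely too heavily on any single feature, which is the sense in which dropout acts as a regulariser.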

In [None]:
model = Sequential()
model.add(Dense(400, input_shape=(X_train.shape[1],), kernel_initializer='normal', activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(400, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(y_test.shape[1], activation='linear'))
model.summary()

Having defined the properties of the neural network model, we proceed to train it with the training data, and validate it using the validation data. Here we are training the model based on the mean squared error loss (specified via the `loss` keyword; this is the square of the RMS loss), with the optimisation done through the [adam](https://arxiv.org/abs/1412.6980) algorithm (which is a first-order gradient-based stochastic optimisation method). The model is going to train for 30 epochs (cf. full iterations).

> NOTE: If neither the optional keyword `validation_data` nor `validation_split` is specified, Keras trains without any validation; `validation_split` would instead set aside a fraction of the provided training set to serve as validation data. `validation_data` is specified here so that the validation set is fixed and the results are somewhat reproducible.

> NOTE: "adam" is the name of the algorithm (it's not an acronym), and it's one of the go-to algorithms that is used in machine learning for the optimisation problem. Google scholar notes the adam paper ([Kingma & Ba, 2014](https://arxiv.org/abs/1412.6980)) has around 100,000 citations to date (checked in May 2022) and was the #1 most cited scientific paper of the past five years in 2020 [link](https://www.natureindex.com/news-blog/google-scholar-reveals-most-influential-papers-research-citations-twenty-twenty), so who says no one cares about numerical quantitative research...
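
For the curious, the adam update itself is short enough to sketch in NumPy. The below follows the update rule as written in Kingma & Ba (2014), applied to a made-up quadratic loss rather than to a neural network; the function name `adam_minimise`, the target and the step size `lr=0.1` are choices made for this illustration only.

```python
import numpy as np

def adam_minimise(grad, w0, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8, n_steps=500):
    """Minimal Adam loop following the update rule in Kingma & Ba (2014)."""
    w = np.array(w0, dtype=float)
    m = np.zeros_like(w)   # running mean of the gradient (first moment)
    v = np.zeros_like(w)   # running mean of the squared gradient (second moment)
    for t in range(1, n_steps + 1):
        g = grad(w)
        m = beta1 * m + (1.0 - beta1) * g
        v = beta2 * v + (1.0 - beta2) * g**2
        m_hat = m / (1.0 - beta1**t)   # bias corrections for the zero initialisation
        v_hat = v / (1.0 - beta2**t)
        w = w - lr * m_hat / (np.sqrt(v_hat) + eps)
    return w

# toy quadratic loss L(w) = |w - target|^2, with gradient 2 (w - target)
target = np.array([1.0, -2.0])
w_opt = adam_minimise(lambda w: 2.0 * (w - target), w0=[0.0, 0.0])
print(w_opt)
```

Only gradients of the loss are needed (hence first order), and the running averages are what make the step size adapt per parameter. In the neural network case the gradient is evaluated on batches of the training data, which is where the stochastic part comes in.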

In [None]:
model.compile(loss='mean_squared_error', optimizer='adam')
batch_size = 1280
epochs = 30
history = model.fit(X_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(X_val, y_val))

The training record is given in this case in the `history` variable. The code below plots out the loss against the epoch, and note that the loss is gradually decreasing, but is not zero. We probably don't want it to be zero, because that would usually indicate that the model is very overfitted. Remember here the data is scaled and a loss of 1 is pretty bad, so the model is getting a loss below 0.1 for the whole dataset, which might be regarded as reasonable.

In [None]:
# plot the training diagnostics of the neural network

fig = plt.figure(figsize=(10, 4))
ax = plt.axes()
ax.plot(history.history['loss'])
ax.plot(history.history['val_loss'])
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('model loss')
ax.grid()
ax.legend(['Train', 'Validation'], loc='upper right')
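
Besides eyeballing the plot, the same record can be interrogated directly. The sketch below uses a made-up dictionary standing in for `history.history` (the numbers are invented, not output from this notebook) to find the epoch with the lowest validation loss and the gap between validation and training loss, a growing gap being one hint of overfitting.

```python
import numpy as np

# made-up loss curves standing in for history.history (not real output)
hist = {
    "loss":     [0.90, 0.40, 0.20, 0.12, 0.08, 0.06, 0.05],
    "val_loss": [0.95, 0.45, 0.24, 0.16, 0.14, 0.15, 0.17],
}

loss = np.asarray(hist["loss"])
val_loss = np.asarray(hist["val_loss"])

best_epoch = int(np.argmin(val_loss))  # zero-based epoch with the lowest validation loss
gap = val_loss - loss                  # validation minus training loss per epoch

print(f"lowest validation loss {val_loss[best_epoch]:.2f} at epoch {best_epoch}")
print(f"train/validation gap at the last epoch: {gap[-1]:.2f}")
```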

Now that we have a model, we can proceed to use it. The acid test here is to use the model on the test data to see how the model performs, given the model has not been exposed to the test data at all. We just need to make sure to undo the data standardisation if we want a "real" output. The output in this case is salinity, and we make a subroutine below to undo the standardisation based on salinity data.

In [None]:
# subroutine to return unscaled output

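def destd(x):
    # refit the scaler on the salinity data at the depth levels covered by x,
    # then undo the standardisation for the single profile x
    scaler1 = preprocessing.StandardScaler()
    scaler1.fit(da_s.PSAL[:,:x.size])
    return scaler1.inverse_transform(x.reshape(1,-1)).reshape(-1)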

# make prediction
y_pred = model.predict(X_test)

The code below randomly chooses three profiles in the test dataset and plots out the truth $y_{\rm test}$ and the prediction $f(X_{\rm test})$, and we should be able to see that the model is not perfect, but fairly reasonable in the deeper parts of the profile, with deficiencies in some cases near the ocean surface.

In [None]:
# randomly plot three profiles and the prediction from neural networks

dp = da_s.DEPTH[:xx.shape[1]]
np.random.seed(4167)

fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(14, 6), facecolor='w', edgecolor='k', sharey='row')

for i in range(3):
    ind = np.random.randint(y_pred.shape[0])  # upper bound is exclusive, so no +1 needed
    ax[i].plot(destd(y_pred[ind]), dp, label='Prediction')
    ax[i].plot(destd(y_test[ind]), dp, label='Test')
    ax[i].set_title(f'Argo profile # {ind}')
    ax[i].set_xlabel(r'Salinity ($\mathrm{g}/\mathrm{kg}$)')
    ax[i].grid()
    
ax[0].set_ylabel(r'Depth ($\mathrm{m}$)')
ax[1].legend()

To get a more quantitative measure, we compute the RMS error over all profiles in each of the three datasets as a function of depth. Here we expect the model to perform reasonably well for the training and validation datasets, with the test data being the "worst", given the model has not seen the test data before. From the above plots, we might expect the model to be less accurate near the top of the ocean, and to do better at depth.

In [None]:
# compute the RMS errors between input data and prediction
from sklearn.metrics import mean_squared_error

def finde(x_e, y_e):
    y_pred = model.predict(x_e)
    # RMS error per depth level; sklearn expects (y_true, y_pred) in that order
    return np.sqrt(mean_squared_error(y_e, y_pred, multioutput='raw_values'))

fig = plt.figure(figsize=(6, 8))
ax = plt.axes()
ax.plot(finde(X_test,  y_test),  dp, "C3", label="Test") 
ax.plot(finde(X_train, y_train), dp, "C0", label="Train")
ax.plot(finde(X_val,   y_val),   dp, "C1", label="Validation")
ax.set_xlabel('RMS error')
ax.set_ylabel(r'Depth ($\mathrm{m}$)')
ax.set_title('Averaged RMS error')
ax.legend()
ax.grid()

So the above plot is largely consistent with our expectations. It is however of interest to see that the RMS error starts increasing below 600 m depth. This could be a numerical or physical artifact, but we don't really have enough information to say thus far.

> NOTE: The temperature here is the in-situ temperature rather than potential temperature, which might make a slight difference. The model behaviour could be arising from this particular realisation of the model training and can be tested by doing some sort of ensemble calculation to get the RMS loss (I do think the behaviour is probably generic, although I haven't looked). It could just be that the neural network is better at getting the bulk values rather than the extremes.

Below we compile the histogram of RMS error values averaged over *depth* for each profile, with the aim of finding geographical locations where the model performs particularly badly.

In [None]:
def finde1(x_e,y_e):
    y_pred = model.predict(x_e)
    err = np.zeros(y_e.shape[0])
    for i in range(y_e.shape[0]):
        err[i] = np.sqrt(((y_pred[i]-y_e[i])**2).mean())
    return err
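
The two error measures used here differ only in which axis is averaged over: averaging over profiles gives one value per depth level, while averaging over depth gives one value per profile. A minimal sketch with synthetic arrays (the shapes and noise level are assumptions for illustration), which also checks that a vectorised version matches the explicit loop:

```python
import numpy as np

rng = np.random.default_rng(2)

# synthetic stand-ins: 6 profiles by 4 depth levels
y_true = rng.normal(size=(6, 4))
y_hat = y_true + 0.1 * rng.normal(size=(6, 4))

sq_err = (y_hat - y_true) ** 2

rms_per_depth = np.sqrt(sq_err.mean(axis=0))    # shape (4,), averaged over profiles
rms_per_profile = np.sqrt(sq_err.mean(axis=1))  # shape (6,), averaged over depth

# explicit loop over profiles, in the same style as the subroutine above
rms_loop = np.array([np.sqrt(((y_hat[i] - y_true[i]) ** 2).mean())
                     for i in range(y_true.shape[0])])
print(np.allclose(rms_per_profile, rms_loop))
```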

In [None]:
# histogram of error of SCALED data and associated predictions (>1 is bad basically) averaged over depths

eh_all = finde1(xx, yy)  # this is SCALED data
eh1_a  = np.where(eh_all < 1, eh_all, 1)

fig = plt.figure(figsize=(6, 4))
ax = plt.axes()
ax.hist(eh1_a, bins=25, color='#0504aa', alpha=0.7, rwidth=0.85)
ax.set_ylabel('Count')
ax.set_xlabel('RMS')
ax.grid()

From the histogram we see the model performs ok in most locations, with a slow-ish decaying tail as we get to the higher RMS errors. There are cases where the model does really badly though, as given by the bar at an RMS error of 1, which collects every profile with an error of 1 or more because the values are capped there (again this is the RMS error for scaled data, so 1 is bad).

The plot below shows the geographical distribution of the depth-averaged RMS error.

In [None]:
# plot out the geographical distribution of errors
eh1_a = np.where(eh_all < 1, eh_all, 1)
X = da_s

fig = plt.figure(figsize=(14, 4))

# plot the raw errors
ax = plt.subplot(1, 2, 1)
cs = ax.scatter(X.LONGITUDE, X.LATITUDE, 2, eh1_a, cmap='cubehelix_r')
ax.set_ylim(-70, 70)
ax.set_xlim(-180, 180)
ax.set_xlabel(r"lon ($^\circ$)")
ax.set_ylabel(r"lat ($^\circ$)")
plt.colorbar(cs)

# make it a bit easier to read and limit the colorbars
eh1_a = np.where(eh_all < 0.4, eh_all, 0.4)
ax = plt.subplot(1, 2, 2)
cs = ax.scatter(X.LONGITUDE, X.LATITUDE, 2, eh1_a, cmap='cubehelix_r')
ax.set_ylim(-70, 70)
ax.set_xlim(-180, 180)
plt.colorbar(cs)
ax.set_xlabel(r"lon ($^\circ$)")

As we can see, the model seems to perform badly in the Mediterranean Sea and its outflow region, which is plausible since this is a region with particularly high salinity (through high evaporation and low precipitation) that is perhaps not as well correlated with temperature. There are regions around the tip of South Africa and the Bay of Bengal in the Indian Ocean where the errors are also notable, which are also regions known to have salinity anomalies (the Agulhas rings transport salty water, while the Bay of Bengal is rather fresh by comparison).

The code below picks out one of these cases with a larger RMS error to compare the prediction against the given data.

In [None]:
lt_eh = eh_all.tolist()
f_eh  = sorted(i for i in lt_eh if i >= 0.4)
print(f"number of profiles above error threshold = {len(f_eh)}")
print(f"  totalling {len(f_eh)/xx.shape[0]*100:.2f}% of data")

y_predal = model.predict(xx)
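
The same selection can also be done with `numpy` index arrays rather than Python lists. One possible advantage is that `list.index` only ever returns the first match, so two profiles with identical errors would map to the same index, whereas index arrays keep each profile's position. A minimal sketch with a made-up error array standing in for `eh_all`:

```python
import numpy as np

# made-up per-profile errors standing in for eh_all
err = np.array([0.05, 0.62, 0.10, 0.41, 0.33, 0.90, 0.41])
threshold = 0.4

# positions of all profiles at or above the threshold
bad_idx = np.flatnonzero(err >= threshold)

# same positions, ordered by increasing error (stable sort keeps ties in order)
bad_sorted = bad_idx[np.argsort(err[bad_idx], kind="stable")]

print(bad_idx)
print(bad_sorted)
print(f"{bad_idx.size / err.size * 100:.2f}% of profiles at or above the threshold")
```

The resulting indices could then be passed to something like `da_s.isel(N_PROF=...)` to pull out the corresponding profiles.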

In [None]:
# randomly select one of the "bad" profiles and see how the prediction compares with the data

np.random.seed(4167)
idx = lt_eh.index(f_eh[np.random.randint(len(f_eh))])  # upper bound is exclusive, so no +1 needed

fig, ax = plt.subplots(nrows=1, ncols=2,  facecolor='w',figsize=(7,5), edgecolor='k', sharey='row')

h1 = da_s.isel(N_PROF=idx)

ax[0].plot(destd(y_predal[idx]), dp, label='Prediction')
ax[0].plot(h1.PSAL.values,h1.DEPTH, '.-', label='Data')
ax[0].set_xlabel(r'Salinity ($\mathrm{g}/\mathrm{kg}$)')
ax[0].set_ylabel(r'Depth ($\mathrm{m}$)')
ax[0].legend()
ax[0].grid()

ax[1].plot(h1.TEMP.values,h1.DEPTH,'.-', color='k')
ax[1].grid()
ax[1].set_xlabel(r'Temp ($^\circ\mathrm{C}$)')

fig.suptitle(f"""ARGO profile # {da_s.N_PROF[idx].values}
[Lat: {h1.LATITUDE.values:05.2f}°, Lon: {h1.LONGITUDE.values:05.2f}°]""")

For the chosen seed (4167), the model makes a fairly large error in the salinity throughout the depth (salinity unlike temperature doesn't vary that much in raw numerical values in the ocean). Without looking at other "bad" cases, we cannot definitively conclude here why the model fails at these particular cases (which is often an issue with black box models like neural networks...)

> <span style="color:red">**Q.**</span> Try picking out the particularly "bad" profiles (through xarray, error thresholds, or otherwise) and see what the model is actually doing there.

> <span style="color:red">**Q.**</span> Consider doing ensemble type calculations to test for robustness of model behaviour and skill.

### Some personal comments

I personally feel uneasy about black box models because you could be getting right and/or wrong answers for the wrong and/or right reasons, and you don't necessarily know why or have a good way to test it. It is my opinion that black box models need to be used with caution; the last thing I feel people should be doing is pressing buttons and relying blindly on the numbers that come out (too much doing not enough thinking). In the absence of knowing what the model is actually doing, it would be prudent to explore how and when the model works and/or fails.

----------------
# More involved exercises with this notebook

## 1) `argopy`

Less machine learning and more data manipulation + argo data. There is now a package called [`argopy`](https://argopy.readthedocs.io/en/latest/) that you might want to try to use instead of the provided binary files here. Consider swapping out what we have done above that reads a local file for something that reads data off a remote database instead.

> NOTE: Let me know if you end up doing this, as that will release some space on my GitHub LFS quota, and you can get acknowledgment on the code too :)

## 2) Clustering algorithms for categorical data

Try it for some other data (e.g. atmospheric data, biogeochemical data) and apply the clustering algorithms to see what you get.

## 3) Neural networks

In our case we used temperature to predict salinity. Consider trying for example:

* predict temperature from salinity

* include other features such as depth and so forth

* use temperature, salinity and/or depth to predict `sigma0` (potential density referenced to sea surface); in this case we know the true answer given `sigma0` is actually derived from temperature and salinity, so we could see if the neural network is at least able to reproduce the derived data itself

* make up some for yourself

## 4) Random forests

Have a look at the [random forest](https://en.wikipedia.org/wiki/Random_forest) algorithms and see what you get from that. 

Random forests might be more suited to smaller datasets such as the penguin or iris data (see [example](https://medium.com/edviconedu/random-forest-algorithm-to-classify-iris-flower-datasets-in-google-colab-b0652a8a8a66)). Try some of the exercises floated around this set of notebooks for datasets and questions of your own choosing (be your own teacher!)