# Ocean macronutrients

---

## Overview
The availability of three macronutrients, nitrate, phosphate, and silicate, controls primary production in most of the ocean. Here we look at maps and depth profiles of these nutrients and compare them to an observational dataset.

1. General setup
2. Subsetting
3. Transforming from monthly to annual data
4. Compare to World Ocean Atlas data
5. Make depth profiles


## Prerequisites

| Concepts | Importance | Notes |
| --- | --- | --- |
| [Dask Cookbook](https://projectpythia.org/dask-cookbook/README.html) | Helpful | |
| [Intro to Xarray](https://foundations.projectpythia.org/core/xarray.html) | Helpful | |
| [Matplotlib](https://foundations.projectpythia.org/core/matplotlib.html) | Necessary | |
| [Intro to Cartopy](https://foundations.projectpythia.org/core/cartopy/cartopy.html) | Necessary | |


- **Time to learn**: 30 min


---

## Imports

In [None]:
import xarray as xr
import glob
import numpy as np
import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs
import pop_tools
from dask.distributed import LocalCluster
import s3fs
import netCDF4

## General setup (see intro notebooks for explanations)

### Connect to cluster

In [None]:
cluster = LocalCluster()
client = cluster.get_client()

In [None]:
cluster.scale(20)

In [None]:
client

### Bring in POP grid utilities

In [None]:
ds_grid = pop_tools.get_grid('POP_gx1v7')
lons = ds_grid.TLONG
lats = ds_grid.TLAT
# POP stores z_t in centimeters; convert to meters
depths = ds_grid.z_t * 0.01

In [None]:
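def adjust_pop_grid(tlon,tlat,field):
    """
    Shift the POP grid so longitudes are continuous for plotting, and append a
    cyclic column so the field wraps around the globe without a gap.
    Returns the adjusted (lon, lat, field) arrays.
    """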
    nj = tlon.shape[0]
    ni = tlon.shape[1]
    xL = int(ni/2 - 1)
    xR = int(xL + ni)

    tlon = np.where(np.greater_equal(tlon,min(tlon[:,0])),tlon-360.,tlon)
    lon  = np.concatenate((tlon,tlon+360.),1)
    lon = lon[:,xL:xR]

    if ni == 320:
        lon[367:-3,0] = lon[367:-3,0]+360.
    lon = lon - 360.
    lon = np.hstack((lon,lon[:,0:1]+360.))
    if ni == 320:
        lon[367:,-1] = lon[367:,-1] - 360.

    # Trick cartopy into doing the right thing:
    # it gets confused when the cyclic coords are identical
    lon[:,0] = lon[:,0]-1e-8
    
    # Periodicity
    lat  = np.concatenate((tlat,tlat),1)
    lat = lat[:,xL:xR]
    lat = np.hstack((lat,lat[:,0:1]))

    field = np.ma.concatenate((field,field),1)
    field = field[:,xL:xR]
    field = np.ma.hstack((field,field[:,0:1]))
    return lon,lat,field
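The core of `adjust_pop_grid` is a periodic column shift followed by appending a cyclic column. The toy sketch below isolates just those two steps on a small array so you can see what the concatenate-and-slice trick does; the helper names `shift_periodic` and `add_cyclic_column` are hypothetical, and the sketch leaves out the longitude offsets and the special cases for `ni == 320`.

```python
import numpy as np


def shift_periodic(field, xL):
    """Rotate columns left by xL by concatenating the field with itself and slicing."""
    ni = field.shape[1]
    doubled = np.concatenate((field, field), axis=1)
    return doubled[:, xL:xL + ni]


def add_cyclic_column(field):
    """Append a copy of the first column so the field closes around the globe."""
    return np.hstack((field, field[:, 0:1]))


field = np.arange(12).reshape(2, 6)
ni = field.shape[1]
xL = int(ni / 2 - 1)  # same seam choice as adjust_pop_grid
shifted = shift_periodic(field, xL)

# The concatenate-and-slice result is the same as rotating columns with np.roll
assert np.array_equal(shifted, np.roll(field, -xL, axis=1))

cyclic = add_cyclic_column(shifted)
```

The extra column is what lets Cartopy draw a seamless field across the dateline; `adjust_pop_grid` also nudges the first longitude column by `1e-8` because identical cyclic coordinates confuse Cartopy.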

### Load the data

In [None]:
jetstream_url = 'https://js2.jetstream-cloud.org:8001/'

s3 = s3fs.S3FileSystem(anon=True, client_kwargs=dict(endpoint_url=jetstream_url))

# List all monthly timeseries files under this path
s3path = 's3://pythia/ocean-bgc/cesm/g.e22.GOMIPECOIAF_JRA-1p4-2018.TL319_g17.4p2z.002branch/ocn/proc/tseries/month_1/*'
remote_files = s3.glob(s3path)

# Open all files from bucket
fileset = [s3.open(file) for file in remote_files]

# Open with xarray
ds = xr.open_mfdataset(fileset, data_vars="minimal", coords='minimal', compat="override", parallel=True,
                       drop_variables=["transport_components", "transport_regions", 'moc_components'], decode_times=True)

ds
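The glob above opens every file in the `month_1` directory and relies on `drop_variables` and the later subsetting step to discard what we don't need. If the files follow the usual CESM timeseries convention of one variable per file, named like `<case>.pop.h.<VAR>.<start>-<end>.nc` (an assumption here; check `remote_files` first), you could filter the list before opening anything. A minimal sketch, with a hypothetical helper `select_variable_files` and made-up example paths:

```python
def select_variable_files(paths, variables):
    """Keep only paths whose file name contains '.<VAR>.' for one of `variables`.

    Assumes CESM timeseries naming: <case>.pop.h.<VAR>.<start>-<end>.nc
    """
    selected = []
    for path in paths:
        name = path.rsplit("/", 1)[-1]  # file name without the directory
        if any(f".{var}." in name for var in variables):
            selected.append(path)
    return selected


# Hypothetical file names, for illustration only
example = [
    "bucket/case.pop.h.NO3.200901-201812.nc",
    "bucket/case.pop.h.TEMP.200901-201812.nc",
    "bucket/case.pop.h.PO4.200901-201812.nc",
    "bucket/case.pop.h.SiO3.200901-201812.nc",
]
nutrient_files = select_variable_files(example, ["PO4", "NO3", "SiO3"])
```

In the notebook this would look like `fileset = [s3.open(f) for f in select_variable_files(remote_files, variables)]`, with `variables` defined before this cell rather than in the Subsetting step. Opening fewer files means less metadata to read and fewer tasks for the cluster.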

## Subsetting
Make our dataset smaller by keeping only the three macronutrient variables we're interested in, plus the grid and time variables we need alongside them.

In [None]:
variables = ['PO4', 'NO3', 'SiO3']

In [None]:
# Keep the nutrients plus the depth, time, and grid-cell variables
keep_vars = ['z_t', 'z_t_150m', 'dz', 'time_bound', 'time', 'TAREA', 'TLAT', 'TLONG'] + variables
ds = ds.drop_vars([v for v in ds.variables if v not in keep_vars])

Let's take a quick look at surface nitrate (the first time step at the top depth level) to make sure things look okay:

In [None]:
ds.NO3.isel(time=0,z_t=0).plot(cmap="viridis")

## Transforming from monthly to annual data
We can't just use xarray's regular `mean()` function here, because months have different numbers of days: each monthly value has to be weighted by the length of its month for the annual mean to be accurate. See this [ESDS blog post](https://ncar.github.io/esds/posts/2021/yearly-averages-xarray/) for a more detailed explanation with examples!
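The weighting described above can be written out explicitly. For a year $y$ with monthly values $x_m$ and month lengths $d_m$ (in days), the annual mean is

$$
\bar{x}_y = \sum_{m=1}^{12} w_m \, x_m, \qquad w_m = \frac{d_m}{\sum_{k=1}^{12} d_k}, \qquad \sum_{m=1}^{12} w_m = 1 .
$$

Assuming a 365-day (noleap) calendar, January gets $w = 31/365 \approx 0.0849$ and February gets $w = 28/365 \approx 0.0767$, whereas an unweighted mean gives every month $1/12 \approx 0.0833$.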

In [None]:
def year_mean(ds):
    """
    Properly convert monthly data to annual means, taking into account month lengths.
    Source: https://ncar.github.io/esds/posts/2021/yearly-averages-xarray/
    """
    
    # Make a DataArray with the number of days in each month, size = len(time)
    month_length = ds.time.dt.days_in_month

    # Calculate the weights by grouping by 'time.year'
    weights = (
        month_length.groupby("time.year") / month_length.groupby("time.year").sum()
    )

    # Test that the sum of the weights for each year is 1.0
    np.testing.assert_allclose(weights.groupby("time.year").sum().values, np.ones((len(ds.groupby("time.year")), )))

    # Calculate the weighted average
    return (ds * weights).groupby("time.year").sum(dim="time")
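To see what `year_mean` computes without a cluster or remote data, here is a NumPy-only sketch of the same day-weighted grouping for a 1-D monthly series. The helper name `year_mean_np` and the 365-day month lengths are illustrative assumptions, not part of the original notebook or of xarray.

```python
import numpy as np

# Month lengths for a 365-day ("noleap") calendar; assumed here for illustration
NOLEAP_DAYS = np.array([31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31])


def year_mean_np(values, years, month_length):
    """Day-weighted annual means of a 1-D monthly series (hypothetical helper)."""
    values = np.asarray(values, dtype=float)
    years = np.asarray(years)
    month_length = np.asarray(month_length, dtype=float)

    unique_years = np.unique(years)
    means = np.empty(unique_years.size)
    for i, yr in enumerate(unique_years):
        in_year = years == yr
        # Normalize month lengths within this year so the weights sum to 1
        weights = month_length[in_year] / month_length[in_year].sum()
        np.testing.assert_allclose(weights.sum(), 1.0)
        means[i] = np.sum(values[in_year] * weights)
    return unique_years, means


# Two years of synthetic monthly data: the value equals the month number
years = np.repeat([2009, 2010], 12)
values = np.tile(np.arange(1, 13), 2)
month_length = np.tile(NOLEAP_DAYS, 2)
unique_years, annual = year_mean_np(values, years, month_length)
```

With these inputs each annual mean is 2382/365 (about 6.526) rather than the unweighted 6.5, because the long months late in the year pull the mean upward. `year_mean` does the same thing with `groupby("time.year")`, but for every grid cell and variable at once.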

In [None]:
ds_annual = year_mean(ds)
ds_annual

Note that the time dimension is now called `year` and holds one value per year, so we can select specific years to plot:

In [None]:
ds_annual['NO3'].sel(year=2010).isel(z_t=0).plot()

---

## Summary
You've learned how to plot and evaluate the distribution of some key ocean nutrients in CESM output.

## Resources and references

- [Converting from monthly to annual data](https://ncar.github.io/esds/posts/2021/yearly-averages-xarray/)
- [About World Ocean Atlas data](https://www.ncei.noaa.gov/products/world-ocean-atlas)
- [World Ocean Atlas data location](https://www.ncei.noaa.gov/access/world-ocean-atlas-2018/)