# Prepare climate modes data

In [None]:
from dask_jobqueue import PBSCluster
from dask.distributed import Client

In [None]:
# One node on Gadi has 48 cores - try and use up a full node before going to multiple nodes (jobs)
# Small interactive cluster: 4 cores with 4 GB of memory per core, one-hour walltime.
walltime = '01:00:00'
cores = 4
memory = f'{4 * cores}GB'

cluster = PBSCluster(
    walltime=walltime,
    cores=cores,
    memory=memory,
    # Explicit PBS resource requests; the default "select" header line is skipped below.
    job_extra=[
        f'-l ncpus={cores}',
        f'-l mem={memory}',
        '-P xv83',
        '-l storage=gdata/xv83+gdata/rt52+scratch/xv83',
    ],
    header_skip=["select"],
)

In [None]:
# Request a single PBS job's worth of Dask workers and connect a client to the cluster.
cluster.scale(jobs=1)
client = Client(cluster)

In [None]:
# Bare expression so the notebook displays the Dask client's summary.
client

In [None]:
import xarray as xr
import numpy as np
import pandas as pd

import matplotlib
import matplotlib.pyplot as plt

import cartopy.crs as ccrs
import cartopy
# Point cartopy at a shared local data directory for both pre-existing and downloaded
# map data (presumably so coastlines work on compute nodes without internet — confirm).
cartopy.config.update({
    'pre_existing_data_dir': '/g/data/xv83/dr6273/work/data/cartopy-data',
    'data_dir': '/g/data/xv83/dr6273/work/data/cartopy-data',
})

# MJO

### Adames index

1. Daily anomalies of $\chi_{150} - \chi_{850}$ (velocity potential at the 150 level minus the 850 level).
2. Area weight using $\sqrt{\cos(\phi)}$, where $\phi$ is latitude.
3. PCA. First two PCs and EOFs.
4. Amplitude: $\sqrt{\mathrm{PC}_{1}^{2} + \mathrm{PC}_{2}^{2}}$.
5. Phase: $\arctan\mathrm{2}(\mathrm{PC}_{2}, \mathrm{PC}_{1})$. Note the order of PCs here!
6. Check that the signal propagates eastwards. If it doesn't, change the sign of one of the PCs.

In [None]:
def write_and_read(ds, filename):
    """
    Write to zarr and read back in.

    The returned object is re-opened (lazily) from the zarr store, so later
    operations read from ``filename`` rather than from the original inputs.

    Parameters
    ----------
    ds : xarray.Dataset or xarray.DataArray
        Data to write. A DataArray may be unnamed.
    filename : str
        Path of the zarr store. Any existing store at this path is overwritten
        (``mode='w'``); consolidated metadata is written and used on read.

    Returns
    -------
    xarray.Dataset or xarray.DataArray
        The data opened from ``filename``, of the same type as ``ds``. A
        DataArray keeps its original ``name`` (including ``None``).
    """
    is_dataarray = isinstance(ds, xr.DataArray)
    if is_dataarray:
        original_name = ds.name
        # to_dataset() refuses name=None, so unnamed arrays are stored under
        # xarray's own placeholder variable name and renamed back on read.
        var_name = original_name if original_name is not None else '__xarray_dataarray_variable__'
        ds = ds.to_dataset(name=var_name)

    ds.to_zarr(filename, consolidated=True, mode='w')
    ds = xr.open_zarr(filename, consolidated=True)

    if not is_dataarray:
        return ds
    da = ds[var_name]
    da.name = original_name
    return da

### SVD takes ~1.5 hours to compute (for two data sets - unadjusted and detrended vpot)

In [None]:
# True: recompute the vpot anomalies and PCA (~1.5 hours) on a larger cluster and overwrite
# the zarr stores. False: lazily open the previously saved PCA results.
compute = False

In [None]:
%%time
# Either recompute velocity-potential anomalies and their PCA (compute=True) and write the
# results to zarr, or lazily open the saved PCA results (compute=False).
if compute:
    # Swap the small default cluster for a single 24-core, 1-process job for the PCA.
    client.close()
    cluster.close()
    
    walltime = '02:00:00'
    cores = 24
    memory = str(4 * cores)
    memory = memory + 'GB'

    cluster = PBSCluster(processes=1,
                         walltime=str(walltime), cores=cores, memory=str(memory),
                         job_extra=['-l ncpus='+str(cores),
                                    '-l mem='+str(memory),
                                    '-P xv83',
                                    '-l storage=gdata/xv83+gdata/rt52+scratch/xv83'],
                         header_skip=["select"])
    cluster.scale(jobs=1)
    client = Client(cluster)
    
    # Load vpot data, take difference, anomalise, rechunk to single chunk and then write and read back in
    vpot = xr.open_zarr('/g/data/xv83/dr6273/work/data/era5/vpot/vpot_era5_daily_2x2.zarr',
                    consolidated=True)
    # 150 minus 850 level velocity potential
    vpot_diff = vpot.vpot.sel(level=150) - vpot.vpot.sel(level=850)
    # Daily anomalies: subtract each day-of-year's mean over all years.
    # NOTE(review): GroupBy.apply emits a PendingDeprecationWarning in recent xarray;
    # GroupBy.map is the recommended spelling.
    vpot_anoms = vpot_diff.groupby('time.dayofyear').apply(lambda x: x - x.mean('time'))
    vpot_anoms = vpot_anoms.chunk({'time': -1, 'latitude': -1, 'longitude': -1})
    vpot_anoms = write_and_read(vpot_anoms,
                            '/g/data/xv83/dr6273/work/data/era5/vpot/derived/vpot_150-850_diff_anoms_era5_daily_2x2_single_chunk.zarr')
    # Detrend vpot
    # NOTE(review): `fn` is not imported anywhere in the visible notebook — presumably a
    # project helper module providing detrend_dim; confirm the import exists before
    # running with compute=True.
    vpot_anoms_dt = fn.detrend_dim(vpot_anoms, 'time')
    # Explicit name so write_and_read can convert the DataArray to a Dataset.
    vpot_anoms_dt.name = 'vpot'
    vpot_anoms_dt = write_and_read(vpot_anoms_dt,
                                '/g/data/xv83/dr6273/work/data/era5/vpot/derived/vpot_150-850_diff_anoms_era5_daily_2x2_single_chunk_detrended.zarr')
    # PCA
    import xeof
    # Tell xeof this data uses 'latitude' as its latitude coordinate name
    # (presumably used by the sqrt_cos_lat weighting — confirm against xeof).
    xeof.core.LAT_NAME = 'latitude'
    
    # Five leading modes of the (non-detrended) anomalies; PCs left unnormalised.
    vpot_eof = xeof.eof(vpot_anoms,
                        sensor_dims=['latitude', 'longitude'],
                        sample_dim='time',
                        weight='sqrt_cos_lat',
                        n_modes=5,
                        norm_PCs=False)
    vpot_eof = vpot_eof.compute()
    vpot_eof.to_zarr('/g/data/xv83/dr6273/work/data/era5/vpot/derived/vpot_150-850_diff_anoms_era5_daily_2x2_PCA.zarr',
                     mode='w',
                     consolidated=True)
    
    # Same PCA for the detrended anomalies.
    vpot_eof_dt = xeof.eof(vpot_anoms_dt,
                            sensor_dims=['latitude', 'longitude'],
                            sample_dim='time',
                            weight='sqrt_cos_lat',
                            n_modes=5,
                            norm_PCs=False)
    vpot_eof_dt = vpot_eof_dt.compute()
    vpot_eof_dt.to_zarr('/g/data/xv83/dr6273/work/data/era5/vpot/derived/vpot_150-850_diff_anoms_era5_daily_2x2_PCA_detrended.zarr',
                         mode='w',
                         consolidated=True)
    
    # Release the large cluster; no cluster/client is running after this branch.
    client.close()
    cluster.close()
else:
    # Lazily open the previously saved PCA results.
    vpot_eof = xr.open_zarr('/g/data/xv83/dr6273/work/data/era5/vpot/derived/vpot_150-850_diff_anoms_era5_daily_2x2_PCA.zarr',
                            consolidated=True)
    vpot_eof_dt = xr.open_zarr('/g/data/xv83/dr6273/work/data/era5/vpot/derived/vpot_150-850_diff_anoms_era5_daily_2x2_PCA_detrended.zarr',
                                consolidated=True)

In [None]:
# Load the (non-detrended) PCA results into memory.
vpot_eof = vpot_eof.compute()

In [None]:
# Load the detrended PCA results into memory.
vpot_eof_dt = vpot_eof_dt.compute()

### Check EOFs and PCs look right

In [None]:
# Figure: one row per mode (1 and 2). The left column holds the EOF maps and the
# right column holds the PC time series for 2020.
# NOTE(review): plt_params is not defined in the visible notebook — confirm it is defined before this cell.
with plt.rc_context(plt_params):
    fig = plt.figure(figsize=(6.9, 3), dpi=150)

    for mode in range(2):
        mode_number = mode + 1

        # EOF map (left column), centred on the Pacific
        ax = fig.add_subplot(2, 2, 2 * mode + 1, projection=ccrs.PlateCarree(central_longitude=180))
        ax.coastlines(lw=0.5)
        vpot_eof.eof.sel(mode=mode_number).plot(ax=ax,
                                                transform=ccrs.PlateCarree(),
                                                cbar_kwargs={'location': 'left',
                                                             'label': r'EOF$_{' + str(mode_number) + '}$'})
        ax.set_title('')
        # TODO: annotate the explained variance here if the PCA output provides it.

        # PC time series (right column)
        ax = fig.add_subplot(2, 2, 2 * mode + 2)
        vpot_eof.pc.sel(mode=mode_number, time='2020').plot(ax=ax)
        ax.set_ylim(-1e9, 1e9)
        ax.set_title('')
        if mode < 1:
            # Hide the x tick labels on the top PC panel only. tick_params works with any
            # tick locator, unlike set_xticklabels on an auto-located date axis.
            ax.tick_params(axis='x', labelbottom=False)
            ax.set_xlabel('')
        ax.set_ylabel(r'PC$_{' + str(mode_number) + '}$')
        ax.yaxis.tick_right()
        ax.yaxis.set_label_position("right")

    plt.subplots_adjust(hspace=0.18, wspace=0.05)

    plt.savefig('./figures/vpot_eofs.pdf', format='pdf', dpi=400, bbox_inches='tight')

### Amplitude $A$ and phase $\theta$

In [None]:
def calc_mjo(da, mode_1=1, mode_2=2, transform_mode_1=False, transform_mode_2=False):
    """
    Calculate the amplitude and phase of the MJO from two modes.

    Parameters
    ----------
    da : xarray.DataArray
        Array with a ``mode`` dimension (e.g. the PCs or EOFs from the PCA).
    mode_1, mode_2 : int
        ``mode`` labels used as the first and second components.
    transform_mode_1, transform_mode_2 : bool
        If True, multiply the corresponding component by -1 first (used to make
        the phase increase with time, i.e. eastward propagation).

    Returns
    -------
    amplitude : same type as ``da.sel(...)``
        sqrt(m1**2 + m2**2).
    phase : same type as ``da.sel(...)``
        arctan2(m2, m1) in radians, in [-pi, pi].
    """
    m1 = da.sel(mode=mode_1)
    m2 = da.sel(mode=mode_2)

    if transform_mode_1:
        m1 = m1 * -1
    if transform_mode_2:
        m2 = m2 * -1

    # NumPy ufuncs dispatch on xarray (and dask-backed) arrays via __array_ufunc__;
    # the xr.ufuncs wrappers used previously have been removed from xarray.
    amplitude = np.sqrt(np.square(m1) + np.square(m2))
    phase = np.arctan2(m2, m1)
    return amplitude, phase

### To propagate eastwards, we need to ensure $d\theta / dt$ is positive.

- The figure below shows that in our case it is negative, so we multiply an EOF and the corresponding PC by -1

In [None]:
# Phase for 2020 with no sign flips: for eastward propagation it should increase with time.
calc_mjo(vpot_eof.pc.sel(time='2020'))[1].plot()

### Multiply PC 2 (equivalently, EOF 2) by -1

In [None]:
# Daily MJO amplitude and phase from the non-detrended PCs, with PC2 sign-flipped so the
# phase increases with time (eastward propagation).
mjo_amplitude, mjo_phase = calc_mjo(vpot_eof.pc, transform_mode_2=True)

mjo = mjo_phase.to_dataset(name='phase')
mjo['amplitude'] = mjo_amplitude

In [None]:
# Same as above for the detrended PCs; variables carry a '_detrended' suffix.
mjo_amplitude_dt, mjo_phase_dt = calc_mjo(vpot_eof_dt.pc, transform_mode_2=True)

mjo_dt = mjo_phase_dt.to_dataset(name='phase_detrended')
mjo_dt['amplitude_detrended'] = mjo_amplitude_dt

### We also want to label the phase according to its geographical location

When the phase, arctan2(PC2, PC1), is zero, PC2 must be zero and PC1 must be positive. From looking at EOF1, we now know that a phase of zero corresponds to upper-level convergence over the Maritime Continent, and an active MJO in the Western Hemisphere/Africa.

When PC2 = 0 and PC1 is negative, the phase is $\pi$ and the MJO is active over the Maritime Continent.

We label the phases from 1 to 8 following standard terminology:
1. W Hem. & Africa
2. Indian Ocean
3. Indian Ocean
4. Maritime Continent
5. Maritime Continent
6. W Pacific
7. W Pacific
8. W Hem. & Africa

In [None]:
def discretise_phase(da):
    """
    Map MJO phase angles (radians) onto integer octant labels 1-8.

    The circle is split into eight pi/4-wide bins. The label is 1 for
    0 <= phase < pi/4 and increases with angle, up to 8 for -pi/4 <= phase < 0.
    Both -pi and pi are labelled 5.
    """
    edges = np.linspace(-np.pi, np.pi, 9)
    # np.digitize returns 1..8 for [-pi, pi), with bin 5 starting at 0, and 9 for phase == pi.
    # NOTE(review): NaN inputs are not masked — np.digitize places them past the last edge,
    # so they are labelled 5.
    bin_index = xr.apply_ufunc(np.digitize, da, edges, dask='allowed')
    # Rotate the labels by four octants so the bin starting at 0 becomes 1:
    # 5..8 -> 1..4, 1..4 -> 5..8, and 9 wraps to 5.
    return (bin_index - 5) % 8 + 1

In [None]:
# Octant label (1-8) for each day's phase.
mjo['phase_ID'] = discretise_phase(mjo.phase)

In [None]:
# Save the daily MJO dataset (overwrites any existing store).
mjo.to_zarr('/g/data/xv83/dr6273/work/data/era5/climate_modes/mjo_daily_1979-2020.zarr',
            mode='w',
            consolidated=True)

In [None]:
# Octant label (1-8) for each day's detrended phase.
mjo_dt['phase_ID_detrended'] = discretise_phase(mjo_dt.phase_detrended)

In [None]:
# Save the daily detrended MJO dataset (overwrites any existing store).
mjo_dt.to_zarr('/g/data/xv83/dr6273/work/data/era5/climate_modes/mjo_daily_1979-2020_detrended.zarr',
               mode='w',
               consolidated=True)

### Check polar representation of EOFs

In [None]:
def plot_polar(eofs):
    """
    Plot MJO in polar representation

    Longitudes are wrapped to [0, 360) and sorted first. The latitude means of
    EOF1 and -EOF2 are drawn against longitude, together with their amplitude
    (black). Their phase (grey, radians) goes on a twin y-axis.
    """
    wrapped = eofs.copy()
    wrapped['longitude'] = (wrapped['longitude'] + 360) % 360
    wrapped = wrapped.sortby(wrapped['longitude'])

    # Same sign convention as the PCs: EOF2 flipped.
    amplitude, phase = calc_mjo(wrapped.eof.mean('latitude'), transform_mode_2=True)

    _, ax = plt.subplots(1, 1, dpi=100)
    wrapped.eof.sel(mode=1).mean('latitude').plot(ax=ax)
    (wrapped.eof.sel(mode=2).mean('latitude') * -1).plot(ax=ax)
    amplitude.plot(ax=ax, color='k')
    phase.plot(ax=ax.twinx(), color='gray')

In [None]:
# Polar-representation check for the non-detrended EOFs.
plot_polar(vpot_eof)

In [None]:
# Polar-representation check for the detrended EOFs.
plot_polar(vpot_eof_dt)

### Compare with Bureau of Meteorology figures
- http://www.bom.gov.au/climate/mjo/
- We multiply PC2 by -1, which gives a rotated version of the BoM figures.
- To match the orientation of the BoM plots exactly, one would instead multiply PC1 (not PC2) by -1. Here we keep the PC2 flip and place the octant labels to match.

In [None]:
# PCs for the phase-space plots; PC2 is sign-flipped to match transform_mode_2=True
# used when computing mjo.phase.
pc1 = vpot_eof.pc.sel(mode=1)
pc2 = vpot_eof.pc.sel(mode=2) * -1

In [None]:
def plot_mjo(pc1, pc2, time_str, z_dim='time', phase=None, filename=None, save_fig=False):
    """
    Plot MJO progression

    Draws the trajectory in (PC1, PC2) phase space, with octant boundaries and
    labels matching the phase_ID convention (octant 1 starts at phase 0).

    Parameters
    ----------
    pc1, pc2 : xarray.DataArray
        First and second PCs with a ``time`` dimension (PC2 already sign-adjusted).
    time_str : str
        Time selection passed to ``.sel(time=...)``, e.g. '2020-01'.
    z_dim : {'time', 'phase'}
        Colour points by date ('time') or by phase angle ('phase').
    phase : xarray.DataArray, optional
        Phase angle in radians with a ``time`` dimension. Required when z_dim='phase'.
    filename : str, optional
        Output file name, written under ./figures/ as PDF. Required when save_fig=True.
    save_fig : bool
        Whether to save the figure.

    Raises
    ------
    ValueError
        If z_dim is not 'time' or 'phase', if z_dim='phase' without ``phase``, or
        if save_fig=True without ``filename``.
    """
    # Validate before creating a figure so a bad call doesn't leave an empty figure behind.
    if z_dim not in ('time', 'phase'):
        raise ValueError("z_dim should be 'time' or 'phase'")
    if z_dim == 'phase' and phase is None:
        raise ValueError("phase must be provided when z_dim='phase'")
    if save_fig and filename is None:
        raise ValueError("filename must be provided when save_fig=True")

    lim = 1.7e9  # symmetric axis limit, in PC units

    with plt.rc_context(plt_params):
        fig, ax = plt.subplots(1, 1, figsize=(4.9, 4), dpi=150)

        ax.set_xlim(-lim, lim)
        ax.set_ylim(-lim, lim)
        # Octant boundaries: both axes and both diagonals.
        ax.axvline(0, color='lightgray', zorder=0)
        ax.axhline(0, color='lightgray', zorder=0)
        ax.plot((-lim, lim), (-lim, lim), color='lightgray', zorder=0)
        ax.plot((-lim, lim), (lim, -lim), color='lightgray', zorder=0)

        pc1_sel = pc1.sel(time=time_str)
        x = pc1_sel.values
        y = pc2.sel(time=time_str).values
        if z_dim == 'time':
            z = pd.to_datetime(pc1_sel.time.values).strftime('%Y-%m-%d')
            c = range(len(z))
        else:
            z = phase.sel(time=time_str).values
            c = z

        ax.plot(x, y, color='k', zorder=1)
        if z_dim == 'time':
            sc = ax.scatter(x, y, c=c, zorder=2, cmap='viridis')
        else:
            norm = matplotlib.colors.Normalize(vmin=-3.2, vmax=3.2)
            sc = ax.scatter(x, y, c=c, zorder=2, cmap='RdBu', norm=norm)

        ax.text(x[0] + 0.1e9, y[0] + 0.0e9, 'Start')
        ax.text(x[-1] - 0.2e9, y[-1] + 0.1e9, 'End')

        ax.set_xlabel(r'PC$_{1}$')
        ax.set_ylabel(r'PC$_{2}$')

        # The following octant labels were determined by comparing with figures from www.bom.gov.au/climate/mjo/
        octant_labels = [
            ('1', 0.96, 0.75), ('2', 0.75, 0.95), ('3', 0.25, 0.95), ('4', 0.04, 0.75),
            ('5', 0.04, 0.25), ('6', 0.25, 0.05), ('7', 0.75, 0.05), ('8', 0.96, 0.25),
        ]
        for label, xpos, ypos in octant_labels:
            ax.text(xpos, ypos, label, fontsize=plt_params['font.size']+7,
                    ha='center', va='center', transform=ax.transAxes)

        ax.text(0.5, 0.96, 'Indian Ocean', ha='center', va='center', transform=ax.transAxes)
        ax.text(0.04, 0.5, 'Maritime Continent', rotation=90, ha='center', va='center', transform=ax.transAxes)
        ax.text(0.5, 0.04, 'Western Pacific', ha='center', va='center', transform=ax.transAxes)
        ax.text(0.96, 0.5, 'W. Hem. and Africa', rotation=270, ha='center', va='center', transform=ax.transAxes)

        cb = fig.colorbar(sc)
        if z_dim == 'phase':
            cb.set_ticks(np.arange(-np.pi, np.pi+0.01, np.pi/4))
            cb.set_ticklabels([r'$-\pi$', r'$-3\pi/4$', r'$-\pi/2$', r'$-\pi/4$', '$0$',
                              r'$\pi/4$', r'$\pi/2$', r'$3\pi/4$', r'$\pi$'])
            cb.set_label(r'$\theta$')
        else:
            # Label every third day with its date.
            cb.set_ticks(c[::3])
            cb.set_ticklabels(z[::3])

        if save_fig:
            plt.savefig('./figures/'+filename, format='pdf', dpi=400, bbox_inches='tight')

In [None]:
# Phase-space trajectory for January 2020, coloured by date.
plot_mjo(pc1, pc2, '2020-01')

In [None]:
# Same period coloured by phase angle; saved to ./figures/mjo_phase_2020-01.pdf.
plot_mjo(pc1, pc2, '2020-01', z_dim='phase', phase=mjo.phase, filename='mjo_phase_2020-01.pdf', save_fig=True)

# Compute monthly MJO statistics

If we aggregate to monthly data, we can then use existing code to aggregate further to the seasonal data needed for the coffee analysis.

Compute these statistics:
- Number of days in each phase per month
- Average amplitude of all days in each phase, per month

In [None]:
def mjo_stats(mjo_ds, detrended=False):
    """
    Compute monthly statistics of MJO

    For each octant 1-8, returns a Dataset stacked along a new ``phase_ID`` dimension with:
    - the number of days per month spent in that octant;
    - the mean amplitude over those days (NaN for months with no such days).

    If ``detrended`` is True, the '_detrended' input and output variable names are used.
    """
    if detrended:
        id_key, amp_key = 'phase_ID_detrended', 'amplitude_detrended'
        days_key, mean_amp_key = 'mjo_days_per_month_detrended', 'mjo_mean_amplitude_detrended'
    else:
        id_key, amp_key = 'phase_ID', 'amplitude'
        days_key, mean_amp_key = 'mjo_days_per_month', 'mjo_mean_amplitude'

    phase_ids = mjo_ds[id_key]
    amplitudes = mjo_ds[amp_key]
    octants = range(1, 9)

    def _label(da, octant):
        # Give each per-octant series a length-1 phase_ID dimension so they can be concatenated.
        return da.expand_dims({'phase_ID': [octant]}).assign_coords({'phase_ID': [octant]})

    days_per_month = [
        _label(xr.where(phase_ids == octant, True, False).resample(time='1MS').sum(), octant)
        for octant in octants
    ]
    mean_amplitude = [
        _label(amplitudes.where(phase_ids == octant).resample(time='1MS').mean(), octant)
        for octant in octants
    ]

    monthly = xr.Dataset()
    monthly[days_key] = xr.concat(days_per_month, dim='phase_ID')
    monthly[mean_amp_key] = xr.concat(mean_amplitude, dim='phase_ID')
    return monthly

In [None]:
# Monthly days-per-octant and mean amplitude (non-detrended).
mjo_monthly = mjo_stats(mjo)

In [None]:
# Save monthly MJO statistics (overwrites any existing store).
mjo_monthly.to_zarr('/g/data/xv83/dr6273/work/data/era5/climate_modes/mjo_monthly_1979-2020.zarr',
                   mode='w', consolidated=True)

In [None]:
# Monthly days-per-octant and mean amplitude (detrended).
mjo_monthly_dt = mjo_stats(mjo_dt, detrended=True)

In [None]:
# Save detrended monthly MJO statistics (overwrites any existing store).
mjo_monthly_dt.to_zarr('/g/data/xv83/dr6273/work/data/era5/climate_modes/mjo_monthly_1979-2020_detrended.zarr',
                   mode='w', consolidated=True)

# Also calculate monthly anomalies (relative to each calendar month's mean)

In [None]:
# Monthly anomalies: subtract each calendar month's mean over all years.
# GroupBy.map is the recommended spelling (GroupBy.apply emits a PendingDeprecationWarning).
mjo_monthly_anoms = mjo_monthly.groupby('time.month').map(lambda x: x - x.mean('time'))

In [None]:
# Save monthly MJO anomalies (overwrites any existing store).
mjo_monthly_anoms.to_zarr('/g/data/xv83/dr6273/work/data/era5/climate_modes/mjo_monthly_anoms_1979-2020.zarr',
                           mode='w', consolidated=True)

In [None]:
# Detrended monthly anomalies: subtract each calendar month's mean over all years.
# GroupBy.map is the recommended spelling (GroupBy.apply emits a PendingDeprecationWarning).
mjo_monthly_dt_anoms = mjo_monthly_dt.groupby('time.month').map(lambda x: x - x.mean('time'))

In [None]:
# Save detrended monthly MJO anomalies (overwrites any existing store).
mjo_monthly_dt_anoms.to_zarr('/g/data/xv83/dr6273/work/data/era5/climate_modes/mjo_monthly_anoms_1979-2020_detrended.zarr',
                               mode='w', consolidated=True)

# Close cluster

In [None]:
# Shut down the Dask client and PBS cluster.
# NOTE(review): when compute=True, the PCA cell already closed its client/cluster, so these
# calls act on already-closed objects (presumably harmless — confirm).
client.close()
cluster.close()