In [None]:
import mlflow
import xarray as xr
import matplotlib.pyplot as plt
import os,sys
sys.path.insert(1, os.path.join(os.getcwd()  , '../../src/gz21_ocean_momentum'))
from utils import select_experiment, select_run

%env MLFLOW_TRACKING_URI /scratch/ag7531/mlruns

In [None]:
%env MLFLOW_TRACKING_URI

In [None]:
# Pick the MLflow experiment to analyse; only the first returned value (the
# experiment id) is kept, the second is discarded.
exp_id, _ = select_experiment()

In [None]:
# Fetch the runs of the selected experiment. Later cells index this table by
# position (.iloc) and read its 'artifact_uri' column.
runs=mlflow.search_runs(experiment_ids=(exp_id,))

In [None]:
# Display the runs table; presumably used to pick the positional run indices
# hard-coded in the cells below.
runs

In [None]:
# `os` is already imported in the first cell; the re-import is harmless.
import os
# List the artifact files of the run at positional index 7.
# NOTE(review): the artifact URI is used directly as a local directory path;
# this presumably fails if the URI carries a scheme such as file:// - confirm.
os.listdir(runs.iloc[7]['artifact_uri'])

In [None]:
def load_data_from_run(i_run: int):
    """Open every artifact file of one MLflow run as an xarray dataset.

    Parameters
    ----------
    i_run : int
        Positional index of the run in the notebook-level ``runs`` table.

    Returns
    -------
    list
        One object returned by ``xr.open_dataset`` per entry of the run's
        artifact directory, in the order produced by ``os.listdir``.

    Notes
    -----
    ``os.listdir`` returns entries in arbitrary order, so callers that index
    the returned list by position depend on the listing order of the
    filesystem. The artifact URI is used directly as a local directory path.
    """
    # Look the artifact location up once instead of once per file.
    artifact_dir = runs.iloc[i_run]['artifact_uri']
    datasets = []
    for fn in os.listdir(artifact_dir):
        print(f'Loading {fn}')
        datasets.append(xr.open_dataset(os.path.join(artifact_dir, fn)))
    return datasets

In [None]:
# Load the artifacts of four runs (positional indices 9, 2, 6, 7 in `runs`)
# and of run 0. Going by the variable names these are presumably the
# low-resolution runs and the high-resolution reference - confirm.
# NOTE(review): positional indices change whenever the runs table changes.
low_rez_datas = [load_data_from_run(i) for i in (9, 2, 6, 7)]
data_h = load_data_from_run(0)

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Dec  6 13:21:28 2020

@author: arthur
"""
from scipy.ndimage import gaussian_filter
import numpy as np

def coarsen(data, factor):
    """Gaussian-filter ``data`` and block-average it by ``factor`` in x and y.

    Parameters
    ----------
    data : xarray object
        Must have dimensions named 'x' and 'y'.
    factor : int
        Coarsening window along both 'x' and 'y'; the Gaussian filter uses a
        standard deviation of ``factor / 2``.

    Returns
    -------
    xarray object
        The filtered data averaged over ``factor`` x ``factor`` blocks.

    Notes
    -----
    This definition is shadowed by a second ``coarsen`` defined in a later
    cell of this notebook.
    NOTE(review): ``gaussian_filter`` receives a scalar sigma, which it
    applies along every axis of the array it is given, so any non-core
    dimension is presumably smoothed as well - confirm this is intended.
    """
    # The keyword is `output_core_dims`; the previous spelling
    # (`output_cor_dims`) was not a valid apply_ufunc argument.
    data = xr.apply_ufunc(lambda x: gaussian_filter(x, factor / 2), data,
                          input_core_dims=[['x', 'y']],
                          output_core_dims=[['x', 'y']])
    # .coarsen() needs a mapping from dimension name to window size.
    data = data.coarsen(dict(x=factor, y=factor))
    return data.mean()

def kinetic_energy(u: np.ndarray, v: np.ndarray, model):
    """Return the mean of ``u**2 + v**2`` over all elements.

    When ``u`` and ``v`` have different shapes, both are first flattened,
    multiplied by ``model.IuT`` / ``model.IvT`` and reshaped with
    ``model.h2mat``. ``model`` is not touched when the shapes already agree.
    Note that no factor 1/2 is applied here.
    """
    shapes_agree = u.shape == v.shape
    if not shapes_agree:
        u_flat = u.flatten()
        v_flat = v.flatten()
        u = model.h2mat(model.IuT.dot(u_flat))
        v = model.h2mat(model.IvT.dot(v_flat))
    squared_speed = u**2 + v**2
    return np.mean(squared_speed)

def uv2Tgrid(u: np.ndarray, v: np.ndarray, model):
    """Map ``u`` and ``v`` onto a common grid using the model's operators.

    Parameters
    ----------
    u, v : np.ndarray
        Either 2D fields or 3D stacks whose first axis is iterated over.
    model :
        Object providing ``IuT`` and ``IvT`` (with a ``dot`` method) and
        ``h2mat``; presumably the interpolation matrices to the T-grid and a
        vector-to-matrix reshape, going by the names - confirm.

    Returns
    -------
    tuple of np.ndarray
        ``(u, v)`` after the transformation; 3D inputs give 3D outputs
        stacked along the first axis.
    """
    if u.ndim == 3:
        pairs = [uv2Tgrid(u[i, ...], v[i, ...], model)
                 for i in range(u.shape[0])]
        # np.stack needs a real sequence; generators are rejected by current
        # NumPy releases.
        u_stacked = np.stack([pair[0] for pair in pairs])
        v_stacked = np.stack([pair[1] for pair in pairs])
        return u_stacked, v_stacked
    u = model.h2mat(model.IuT.dot(u.flatten()))
    v = model.h2mat(model.IvT.dot(v.flatten()))
    return u, v

def get_kinetic_energy_ts(u: np.ndarray, v: np.ndarray, model):
    """Return ``kinetic_energy`` evaluated for every index of the first axis.

    The result is a 1D float array with ``u.shape[0]`` entries; ``model`` is
    forwarded unchanged to ``kinetic_energy``.
    """
    n_steps = u.shape[0]
    per_step = [kinetic_energy(u[step, ...], v[step, ...], model)
                for step in range(n_steps)]
    return np.array(per_step, dtype=float)

def stream_function(u: np.ndarray, v: np.ndarray, model):
    """Return ``cumsum(-v, axis=1) + cumsum(u, axis=0)`` for each 2D field.

    3D inputs are processed slice by slice along the first axis and stacked
    again. When the 2D shapes of ``u`` and ``v`` differ, both are first
    flattened, multiplied by ``model.IuT`` / ``model.IvT`` and reshaped with
    ``model.h2mat``.
    """
    if u.ndim == 3:
        n_steps = u.shape[0]
        frames = [stream_function(u[t, ...], v[t, ...], model)
                  for t in range(n_steps)]
        return np.stack(frames, 0)
    if not (u.shape == v.shape):
        u = model.h2mat(model.IuT.dot(u.flatten()))
        v = model.h2mat(model.IvT.dot(v.flatten()))
    psi = np.cumsum(-v, axis=1)
    # Accumulate in place, as an augmented assignment would.
    np.add(psi, np.cumsum(u, axis=0), out=psi)
    return psi

def cum_mean(data: np.ndarray, axis=0):
    """Running mean of ``data`` along ``axis``.

    Parameters
    ----------
    data : np.ndarray
        Array of any dimensionality.
    axis : int, optional
        Axis along which the cumulative mean is taken (default 0).

    Returns
    -------
    np.ndarray
        Array of the same shape as ``data`` whose entry ``i`` along ``axis``
        is the mean of entries ``0..i``.
    """
    n = data.shape[axis]
    # Shape the sample counts so they broadcast along `axis` only; the former
    # fixed (n, 1, 1) shape was only valid for 3D data with axis=0.
    counts_shape = [1] * data.ndim
    counts_shape[axis] = n
    ns = np.arange(1, n + 1).reshape(counts_shape)
    return 1 / ns * np.cumsum(data, axis)

def eke_spec_avg(u,v,dx,dy, model):
    """Radially binned, time-averaged kinetic-energy spectrum of (t, y, x) data.

    The 2D FFT power of ``u`` and ``v`` is computed for every time step,
    averaged over time (with a factor 1/2) and binned by the radial
    wavenumber ``k = sqrt(kx**2 + ky**2)``.

    TODO: correct normalisation, so that the integral in normal space
    corresponds to the integral in Fourier space.

    Parameters
    ----------
    u, v : 3D arrays
        Velocity components; the shape of ``u`` is unpacked as (nt, ny, nx).
        If the two shapes differ, both are first passed through ``uv2Tgrid``.
    dx, dy :
        Grid spacings used to scale the x and y wavenumbers.
    model :
        Only forwarded to ``uv2Tgrid`` when the shapes of ``u`` and ``v``
        differ.

    Returns
    -------
    k : 1D array
        Midpoints between consecutive non-negative wavenumbers.
    eke_spec : 1D array
        Difference of the cumulative radial power between consecutive
        wavenumbers, divided by the wavenumber spacing.
    """
    if u.shape != v.shape:
        u, v = uv2Tgrid(u, v, model)
    
    nt,ny,nx = np.shape(u)
    # Wavenumbers listed with the non-negative values first, then the negative
    # ones, scaled by 1/(n*d).
    # NOTE(review): for odd nx (ny) the negative branch np.arange(-n/2.+1, 0)
    # yields half-integer values - confirm only even sizes are used.
    kx = (1/(dx))*np.hstack((np.arange(0,(nx+1)/2.),np.arange(-nx/2.+1,0)))/float(nx)
    ky = (1/(dy))*np.hstack((np.arange(0,(ny+1)/2.),np.arange(-ny/2.+1,0)))/float(ny)

    kxx,kyy = np.meshgrid(kx,ky)
    # radial distance from kx,ky = 0
    kk = np.sqrt(kxx**2 + kyy**2) 

    # Radial bins: the leading (non-negative) wavenumbers of the longer axis.
    if nx >= ny: #kill negative wavenumbers
        k  = kx[:int(nx/2)+1]
    else:
        k  = ky[:int(ny/2)+1]

    # Bin spacing, taken from the first two entries of k.
    dk = k[1] - k[0]

    # 2D FFT average
    p_eke = np.empty((nt,ny,nx))
    nxy2 = nx**2*ny**2

    for i in range(nt):
        pu = abs(np.fft.fft2(u[i,:,:]))**2/nxy2
        pv = abs(np.fft.fft2(v[i,:,:]))**2/nxy2
        p_eke[i,:,:] = pu+pv
        # Progress report: print the percentage whenever it passes a multiple
        # of 5% (the remainder modulo 5 wraps around).
        if ((i+1)/nt*100 % 5) < (i/nt*100 % 5):
            print(str(int((i+1)/nt*100.))+'%')
    
    # Time mean of the summed u and v power, halved.
    p_eke_avg = .5*p_eke.mean(axis=0)

    # create radial coordinates, associated with k[i]:
    # indices of all FFT cells whose radial wavenumber is strictly below k[i]
    rcoords = []
    for i in range(len(k)):
        rcoords.append(np.where(kk<k[i]))

    # Cumulative power inside radius k[i]; the difference between consecutive
    # radii divided by dk (below) gives the power per unit wavenumber.
    eke_spec = np.zeros(len(k))
    for i in range(len(k)):
        eke_spec[i] = np.sum(p_eke_avg[rcoords[i][0],rcoords[i][1]])
    
    eke_spec = np.diff(eke_spec) / dk
    # Report each band at the midpoint of its two bounding wavenumbers.
    k = (k[:-1] + k[1:])/2.

    return k,eke_spec

In [None]:
import sys
sys.path.append('/home/ag7531/code/swe_stochastic_param/')
from shallowwater import ShallowWaterModel

In [None]:
# Domain side length; presumably in km, since Lx/Ly are size * 1e3 and the
# plots further down label the 0-3840 extent in km - TODO confirm units.
size = 3840
# Low-res model: size // 10 // 4 = 96 cells per side.
model_l = ShallowWaterModel(Nx=size // 10 // 4, Ny=size // 10 // 4, Lx=size * 1e3, Ly = size * 1e3)
# High-res model: size // 10 // 1 = 384 cells per side.
model_h = ShallowWaterModel(Nx=size // 10 // 1, Ny=size // 10 // 1, Lx=size * 1e3, Ly = size * 1e3)

In [None]:
import cmocean
def my_plot(data):
    """Show ``data`` as an image on the current axes.

    The colour range is fixed to [-1, 1], the origin is at the lower left and
    the cmocean ``delta`` colormap is used. Nothing is returned.
    """
    image_options = dict(vmin=-1, vmax=1, origin='lower',
                         cmap=cmocean.cm.delta)
    plt.imshow(data, **image_options)


In [None]:
# Low rez no param
# Turn each run's list of raw datasets into a single dataset holding u, v,
# eta and kinetic energy on a common grid.
new_low_rez_datas = []
for data in low_rez_datas:
    # Positions 0, 4 and 3 index the list returned by load_data_from_run.
    # NOTE(review): that list follows os.listdir order, which is arbitrary;
    # confirm these positions really are the u, v and eta files.
    u = data[0]['u'].values
    v = data[4]['v'].values
    eta = data[3]['eta']
    u, v = uv2Tgrid(u, v, model_l)
    # Reuse eta's dimensions and coordinates for the regridded velocities.
    da_u = xr.DataArray(u, dims=eta.dims, coords=eta.coords)
    da_v = xr.DataArray(v, dims=eta.dims, coords=eta.coords)
    dataset_l = xr.Dataset(dict(u=da_u, v=da_v, eta=eta))
    dataset_l['kE'] = 1/2 * (dataset_l['u']**2 + dataset_l['v']**2)
    dataset_l = dataset_l.rename(dict(t='time'))
    new_low_rez_datas.append(dataset_l)
    print('ok')
# Rebinding low_rez_datas means this cell cannot be run twice: on a second
# run the loop would iterate over the already-converted datasets.
low_rez_datas = new_low_rez_datas

In [None]:
def coarsen(data, factor):
    """Gaussian-filter ``data`` and block-average it by ``factor`` in x and y.

    This redefinition replaces the ``coarsen`` defined in an earlier cell.
    The filter is applied through ``xr.apply_ufunc`` with 'time' declared as
    the core dimension and a standard deviation of ``factor / 2``; the result
    is then averaged over ``factor`` x ``factor`` windows along 'x' and 'y'.

    NOTE(review): ``gaussian_filter`` receives a scalar sigma, which it
    applies along every axis of the array it is given, so the data is
    presumably smoothed in time as well as in space - confirm this is
    intended.
    """
    def _smooth(values):
        return gaussian_filter(values, factor / 2)

    smoothed = xr.apply_ufunc(_smooth, data,
                              input_core_dims=[['time']],
                              output_core_dims=[['time']])
    windows = dict(x=factor, y=factor)
    return smoothed.coarsen(windows).mean()

# high rez
# Same steps as for the low-res runs, applied to the high-res run with the
# high-res model, followed by a coarsening by a factor 4 in x and y.
# NOTE(review): positions 0, 4 and 3 index the list returned by
# load_data_from_run, whose order follows os.listdir (arbitrary) - confirm
# they are the u, v and eta files.
u = data_h[0]['u'].values
v = data_h[4]['v'].values
eta = data_h[3]['eta']
u, v = uv2Tgrid(u, v, model_h)
# Reuse eta's dimensions and coordinates for the regridded velocities.
da_u = xr.DataArray(u, dims=eta.dims, coords=eta.coords)
da_v = xr.DataArray(v, dims=eta.dims, coords=eta.coords)
dataset_h = xr.Dataset(dict(u=da_u, v=da_v, eta=eta))
dataset_h = dataset_h.rename(dict(t='time'))
# 384 // 4 = 96 cells per side, the size of the low-res model grid.
dataset_h = coarsen(dataset_h, 4)

In [None]:
# Kinetic energy of the high-res run, computed from the velocities stored in
# dataset_h (already coarsened in the previous cell).
dataset_h['kE'] = 1/2 * (dataset_h['u']**2 + dataset_h['v']**2)

In [None]:
# Give the high-res dataset the time coordinate of the low-res runs.
# The last converted low-res dataset is referenced explicitly: the previous
# code used the leftover loop variable `data`, which on a fresh top-to-bottom
# run is still a plain list of raw datasets and has no `time` attribute.
dataset_h['time'] = low_rez_datas[-1]['time']

In [None]:
# Display the high-res dataset to inspect its variables and coordinates.
dataset_h

In [None]:
%matplotlib notebook
# Default figure size: 8 wide, height = width / 1.618.
plt.rcParams["figure.figsize"] = (4 * 2, 4 * 2 / 1.618)

# One curve per low-res run, then the high-res run: kinetic energy averaged
# over x and y, as a function of the time index.
plt.figure()
# Note: this loop rebinds the notebook-level name `data`.
for data in low_rez_datas:
    plt.plot(data['kE'].mean(dim=('x', 'y')))
plt.plot(dataset_h['kE'].mean(dim=('x', 'y')))
plt.ylabel(r'$m^2/s^2$')
# NOTE(review): only y-values are passed to plt.plot, so the x axis is the
# sample index; 'day' presumably assumes one sample per day - confirm.
plt.xlabel('day')

In [None]:
# Save the current pyplot figure (the kinetic-energy time series, if the
# previous cell was the last one to create a figure).
plt.savefig('online_kE.jpg', dpi=400)

In [None]:
# Empty container; filled with one kinetic-energy variable per run below.
kE_dataset = xr.Dataset()

In [None]:
# Collect the kinetic energy of every run in one dataset: 'low_rez_<i>' for
# the i-th low-res run and 'high_rez' for the high-res run.
# Note: this loop rebinds the notebook-level names `i` and `data`.
for i, data in enumerate(low_rez_datas):
    kE_dataset['low_rez_' + str(i)] = data['kE']
kE_dataset['high_rez'] = dataset_h['kE']

In [None]:
# Display the combined kinetic-energy dataset.
kE_dataset

In [None]:
# Size of one low-res kinetic-energy variable, in units of 1e9 bytes.
kE_dataset['low_rez_2'].nbytes / 1e9

In [None]:
# Persist the kinetic-energy dataset as a zarr store, overwriting any store
# already at that path (mode='w').
# The store name is spelled out here: the previous code built it from `var`
# and `func`, which are only defined in a later cell (NameError on a fresh
# run) and describe the eta/std maps rather than this dataset.
# NOTE(review): 'online_kE' mirrors the name of the figure saved above -
# confirm this is the store name downstream code expects.
kE_dataset.to_zarr('/scratch/ag7531/paper_plots_data/online_kE', mode='w')

In [None]:
# Discard the first time steps of every run (presumably the spin-up phase -
# confirm) before computing statistics.
# NOTE: this cell is not idempotent; each re-run drops another block of steps.
N_DISCARDED_STEPS = 1000
dataset_h = dataset_h.isel(time=slice(N_DISCARDED_STEPS, None))
# Update the list in place and cover however many low-res runs were loaded,
# instead of assuming exactly four.
low_rez_datas[:] = [ds.isel(time=slice(N_DISCARDED_STEPS, None))
                    for ds in low_rez_datas]

In [None]:
# Display the high-res dataset again, after the initial steps were dropped.
dataset_h

In [None]:
import numpy as np
import matplotlib
# Variable to map and the time reduction applied to it; `func` is looked up
# as a method name (e.g. .mean / .std) in the plotting cells below.
var = 'eta'
func = 'std'
# Per-reduction plotting options, keyed by `func`.
cmaps = dict(mean=cmocean.cm.delta, std=cmocean.cm.matter)
args = dict(mean=dict(), std=dict(norm=matplotlib.colors.LogNorm()))
# Colour limits are built below as std_h * vmins[func] + vmins2[func] and
# std_h * vmaxs[func], i.e. these are multiples of std_h plus a small offset.
vmins=dict(mean=-1.96, std=0.5)
vmaxs=dict(mean=1.96, std=3)
vmins2=dict(mean=0, std=0.0001)
# Image extent passed to imshow; the axes are labelled in km.
extent = (0, 3840, 0, 3840)

In [None]:
fig = plt.figure()
# Determine limits
# The colour limits are expressed in units of the spatial standard deviation
# of the high-res statistic, so that all panels share one colour scale.
std_h = getattr(dataset_h[var], func)(dim='time').std()
color_min = float(std_h * vmins[func] + vmins2[func])
color_max = float(std_h * vmaxs[func])

# Passing a norm together with vmin/vmax to imshow was deprecated in
# Matplotlib 3.3 and raises ValueError in current releases, so when a norm
# is configured the limits are set on the norm itself.
imshow_kwargs = dict(args[func])
norm = imshow_kwargs.pop('norm', None)
if norm is None:
    imshow_kwargs.update(vmin=color_min, vmax=color_max)
else:
    norm.vmin, norm.vmax = color_min, color_max
    imshow_kwargs['norm'] = norm

# Panels 1-2: the first two low-res runs; panel 3: the coarsened high-res run.
panels = [low_rez_datas[0][var], low_rez_datas[1][var], dataset_h[var]]
for i, field in enumerate(panels):
    plt.subplot(1, 3, i + 1)
    im = plt.imshow(getattr(field, func)(dim='time'), cmap=cmaps[func],
                    origin='lower', extent=extent, **imshow_kwargs)
    # Only the left-most panel keeps its y ticks and axis labels.
    if i > 0:
        im.axes.set_yticks([])
    if i == 0:
        im.axes.set_xlabel('km')
        im.axes.set_ylabel('km')
# Reserve space on the right for a single colour bar shared by all panels.
fig.subplots_adjust(right=0.8)
cbar_ax = fig.add_axes([0.85, 0.3, 0.025, 0.4])
fig.colorbar(im, cax=cbar_ax, label='m')

In [None]:
# Save the current pyplot figure as '<var>_<func>_l.jpg'.
plt.savefig(var + '_' + func + '_l.jpg', dpi=400)

In [None]:
# Stand-alone map of the high-res statistic, drawn with xarray's own plotting
# and the same colour limits as the multi-panel figure (std_h comes from the
# cell that draws that figure).
plt.figure()
getattr(dataset_h[var], func)(dim='time').plot(cmap=cmaps[func], **args[func],
                                              vmin=std_h*vmins[func] + vmins2[func],
                                                          vmax=std_h*vmaxs[func])


In [None]:
# Save the current pyplot figure as '<var>_<func>_h.jpg'.
plt.savefig(var + '_' + func + '_h.jpg', dpi=400)

In [None]:
# Display the first low-res dataset.
low_rez_datas[0]

In [None]:
# Histograms of kinetic energy (density, log-scaled counts) for the first two
# low-res runs and the high-res run, all with the same bins.
# For each run a solid vertical line marks the mean and a dashed one the
# standard deviation of kE.
plt.figure()
# Only the first two colours are used by the loop below.
colors=['b', 'g', 'r', 'c']
for i in range(2):
    # low rez
    m = low_rez_datas[i]['kE'].mean()
    s = low_rez_datas[i]['kE'].std()
    low_rez_datas[i]['kE'].plot.hist(bins=np.linspace(0.01, 1, 99), density=True, log=True, histtype='step', color=colors[i], linewidth=2)
    plt.axvline(m, color=colors[i], linewidth=3)
    plt.axvline(s, color=colors[i], linestyle='--', linewidth=3)
# high-rez
m = dataset_h['kE'].mean()
s = dataset_h['kE'].std()
dataset_h['kE'].plot.hist(bins=np.linspace(0.01, 1, 99), density=True, log=True, histtype='step', color='m', linewidth=2)
plt.axvline(m, color='m', linewidth=3)
plt.axvline(s, color='m', linestyle='--', linewidth=3)

In [None]:
# Save the current pyplot figure (the combined kE histograms, if the previous
# cell was the last one to create a figure).
plt.savefig('kE_hist_2.jpg', dpi=400)

In [None]:
# Histogram of the high-res kinetic energy alone, with filled bars.
# The bins here are np.linspace(0, 1, 100), which differ from the
# np.linspace(0.01, 1, 99) bins of the combined histogram.
plt.figure()
_ = dataset_h['kE'].plot.hist(bins=np.linspace(0, 1, 100), density=True, log=True)
m = dataset_h['kE'].mean()
s = dataset_h['kE'].std()
# Red line: mean of kE; green line: standard deviation of kE.
plt.axvline(m, color='r')
plt.axvline(s, color='g')

In [None]:
# Save the current pyplot figure (the high-res kE histogram, if the previous
# cell was the last one to create a figure).
plt.savefig('kE_hist_h.jpg', dpi=400)