```
This software is a part of GPU Ocean.

Copyright (C) 2019  SINTEF Digital

In this notebook we carry out prototyping for developing a new 
ensemble class that can be used for reading observations from file
and still work in the current Data Assimilation structure.

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
```

# Development of a new ensemble class based on files

In this notebook we prototype a new ensemble class that reads its
observations from file while still working within the current
Data Assimilation structure.


## Set environment

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as np
import matplotlib
from matplotlib import pyplot as plt
from matplotlib import animation, rc

import pycuda.driver as cuda
import os
import sys
import datetime

from importlib import reload
# Make the directory two levels above the working directory importable,
# so that the SWESimulators package imported below can be found.
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../../')))

#Set large figure sizes
# (and render matplotlib animations as HTML5 video in the notebook)
rc('figure', figsize=(16.0, 12.0))
rc('animation', html='html5')

#Import our simulator
# NOTE(review): importing IPythonMagic presumably registers the
# %cuda_context_handler magic used in the next cell — confirm.
from SWESimulators import IPythonMagic, CDKLM16, EnsembleFromFiles

from SWESimulators import BaseOceanStateEnsemble, SimReader, Observation
from SWESimulators import DataAssimilationUtils as dautils


In [None]:
# Set up the CUDA context for the simulators and store it as `gpu_ctx`;
# it is passed to EnsembleFromFiles when the ensemble is created.
%cuda_context_handler gpu_ctx

In [None]:
#Create output directory for images
# Disabled: `imgdir` / `filename_prefix` are not referenced by the cells of
# this notebook. Uncomment to save figures under a timestamped prefix in
# the `double_jet/` directory.
#imgdir = 'double_jet'
#filename_prefix = imgdir + "/" + datetime.datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + "_"
#os.makedirs(imgdir, exist_ok=True)
#print("Saving images to " + imgdir)

### Define helper functions for plotting and time conversion

In [None]:
def imshow(im, interpolation="None", title=None, figsize=(4,4), interior=False):
    """Display one 2D field in a new figure, with a colorbar.

    Args:
        im: 2D array to display; row 0 is drawn at the bottom (origin='lower').
        interpolation: interpolation mode forwarded to plt.imshow.
        title: optional axes title; nothing is set when None.
        figsize: size of the new figure in inches.
        interior: when True, crop two cells from every edge ([2:-2, 2:-2])
            before plotting.
    """
    plt.figure(figsize=figsize)

    # Crop the outer 2-cell frame only on request.
    shown = im[2:-2, 2:-2] if interior else im
    plt.imshow(shown, interpolation=interpolation, origin='lower')

    plt.colorbar()
    if title is not None:
        plt.title(title)
        
def imshow3(eta, hu, hv, interpolation="None", title=None, figsize=(12,3),
            interior=False, color_bar_from_zero=False):
    """Plot eta, hu and hv side by side, each panel with its own colorbar.

    The eta panel uses the colour range [-max|eta|, max|eta|]. The hu and hv
    panels share one range based on the largest magnitude found in either
    field, so the two momentum components are directly comparable.

    Args:
        eta, hu, hv: 2D arrays to plot (row 0 drawn at the bottom).
        interpolation: interpolation mode forwarded to Axes.imshow.
        title: optional figure title (suptitle).
        figsize: figure size in inches.
        interior: when True, crop two cells from every edge ([2:-2, 2:-2])
            before plotting. The colour range is computed from the cropped
            (displayed) data, so values in the hidden border cannot stretch it.
        color_bar_from_zero: when True, the lower colour limit is 0 instead of
            the negated maximum magnitude.
    """
    if interior:
        eta, hu, hv = (field[2:-2, 2:-2] for field in (eta, hu, hv))

    fig, axs = plt.subplots(1, 3, figsize=figsize)

    eta_max = np.max(np.abs(eta))
    huv_max = max(np.max(np.abs(hu)), np.max(np.abs(hv)))
    eta_min = 0 if color_bar_from_zero else -eta_max
    huv_min = 0 if color_bar_from_zero else -huv_max

    # (data, panel title, vmin, vmax) for each of the three panels.
    # Raw string: "\e" is not a valid escape sequence in a normal literal.
    panels = [
        (eta, r"$\eta$", eta_min, eta_max),
        (hu, "$hu$", huv_min, huv_max),
        (hv, "$hv$", huv_min, huv_max),
    ]
    for ax, (field, panel_title, vmin, vmax) in zip(axs, panels):
        field_im = ax.imshow(field, interpolation=interpolation, origin='lower',
                             vmin=vmin, vmax=vmax)
        ax.set_title(panel_title)
        fig.colorbar(field_im, ax=ax)

    if title is not None:
        fig.suptitle(title)
    fig.tight_layout()
    

def days_to_sec(days):
    """Convert a duration given in days to seconds."""
    return days*24*60*60

def truth_time_step(t, spinup_days=3, output_interval=60*60):
    """Map a simulation time to the index of the corresponding truth output.

    Truth step 0 is taken to lie ``spinup_days`` days into the simulation, and
    consecutive truth steps are assumed to be ``output_interval`` seconds
    apart. The defaults reproduce the previous hard-coded values (3 days and
    1 hour); TODO confirm they match the truth files.

    Args:
        t: simulation time in seconds.
        spinup_days: time, in days, at which truth step 0 is located.
        output_interval: seconds between consecutive truth steps.

    Returns:
        int: the number of whole output intervals elapsed since the spin-up,
        truncated towards zero (as ``int()`` does).
    """
    elapsed = t - days_to_sec(spinup_days)
    return int(elapsed/output_interval)

# The new class

We base it on the OceanStateEnsemble.

In [None]:
%%time
if 'ensemble' in globals():
    ensemble.cleanUp()
    del ensemble
    
reload(BaseOceanStateEnsemble)
reload(Observation)
reload(EnsembleFromFiles)

ensemble_init_path = os.path.abspath('double_jet_ensemble_init/')
truth_path = os.path.abspath('double_jet_truth/')
#ensemble = EnsembleFromFile(

print(os.path.isdir(ensemble_init_path))
print(os.path.isdir(truth_path))

ensemble_nc_gen = (os.path.join(ensemble_init_path, file)  for file in os.listdir(ensemble_init_path) if file.endswith('.nc'))
ensemble_nc_files = list(ensemble_nc_gen)
print(type(ensemble_nc_files))
print(len(ensemble_nc_files))
print(ensemble_nc_files[10])
print(type(ensemble_nc_files[10]))
print()
print()


ensemble_size = 5
observation_variance = 1

ensemble = EnsembleFromFiles.EnsembleFromFiles(gpu_ctx, ensemble_size, 
                                               ensemble_init_path, truth_path,
                                               observation_variance, use_lcg=True)
ensemble.configureObservations(drifterSet=[2, 10, 18], observationInterval=6)

In [None]:
# Compare the drifter count reported by the observation object with the
# ensemble's own count.
num_drifters_in_observations = ensemble.observations.get_num_drifters()
num_drifters_in_ensemble = ensemble.getNumDrifters()
print(num_drifters_in_observations, num_drifters_in_ensemble)


In [None]:
# Observation time in seconds: 3 days plus 10 hours.
obs_time = 3*24*60*60 + 10*60*60
# NOTE(review): `depth` is not used here; the observation query below uses
# ensemble.mean_depth instead.
depth = 230
obs = ensemble.observations.get_observation(obs_time, ensemble.mean_depth)
print(obs)


In [None]:
%%time
ensemble.stepToObservation(obs_time)

In [None]:
# Compare the observation read from file with what the ensemble reports
# for the true state and the true drifters (calls kept in original order).
print(obs)
observedTrueState = ensemble.observeTrueState()
print(observedTrueState)
true_drifter_observation = ensemble.observeTrueDrifters()
print(true_drifter_observation)

In [None]:
# Each particle's observation, next to the file observation with its first
# two columns dropped.
observed_particles = ensemble.observeParticles()
file_observation_values = obs[:, 2:]
print(observed_particles)
print(file_observation_values)

In [None]:
# Hand-computed differences (true observation minus particle observation),
# printed for comparison with the ensemble's own innovations below.
true_observed_values = observedTrueState[:, 2:]
for particle_id in range(ensemble.getNumParticles()):
    print(true_observed_values - observed_particles[particle_id, :, :])

print(ensemble.getInnovations())
print(ensemble.getInnovationNorms())

In [None]:
# Gaussian weights of the particles, raw and normalized.
raw_weights = ensemble.getGaussianWeight()
print(raw_weights)
normalized_weights = ensemble.getGaussianWeight(normalize=True)
print(normalized_weights)

In [None]:
# One row of eta / hu / hv panels per particle.
for particle_id in range(ensemble.getNumParticles()):
    particle_eta, particle_hu, particle_hv = ensemble.downloadParticleOceanState(particle_id)
    imshow3(particle_eta, particle_hu, particle_hv,
            title='Particle {}'.format(particle_id))
    eta, hu, hv = particle_eta, particle_hu, particle_hv

In [None]:
# Plot the true ocean state. Label it with obs_time, the time the ensemble
# was stepped to above (there is no notebook-level `t`).
eta, hu, hv = ensemble.downloadTrueOceanState()
imshow3(eta, hu, hv, title='Truth at time ' + str(obs_time))
