```
This notebook sets up and runs a set of benchmarks to compare
different numerical discretizations of the SWEs

Copyright (C) 2016  SINTEF ICT

This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program.  If not, see <http://www.gnu.org/licenses/>.
```

# MPI Skeleton for basic particle filter with SIR




## Set environment

In [None]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import animation, rc
from scipy.special import lambertw

import os
import sys

sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../')))

#Set large figure sizes
rc('figure', figsize=(16.0, 12.0))
rc('animation', html='html5')
matplotlib.rcParams['contour.negative_linestyle'] = 'solid'

#Import our simulator
from SWESimulators import CDKLM16, PlotHelper, Common, IPythonMagic

from SWESimulators import BathymetryAndICs as BC
from SWESimulators import OceanStateNoise
from SWESimulators import OceanNoiseEnsemble
from SWESimulators import BaseOceanStateEnsemble
from SWESimulators import GPUDrifterCollection
from SWESimulators import DataAssimilationUtils as dautils


In [None]:
%cuda_context_handler gpu_ctx

# Ensemble

Some basic assumptions and notes on what has to be redesigned in the future:


### Local vs global ensemble
The local ensemble is the ensemble owned by one process. The global ensemble is the combined ensemble from all processes. Only the "master" (rank 0) will care about the global ensemble in this skeleton.
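Each rank owns one contiguous block of `local_numParticles` particles, so converting between local and global particle indices is plain integer arithmetic. The class below does this inline through `global_particle_indices`, `particle_owner` and `local_index`. The following is a minimal sketch that assumes the particle count divides evenly among the ranks. The helper names are hypothetical.

```python
import numpy as np

def global_indices(rank, local_num):
    """Global indices of the particles owned by `rank` (one contiguous block)."""
    return np.arange(rank*local_num, (rank + 1)*local_num)

def owner_of(global_index, local_num):
    """Rank that owns a given global particle index."""
    return global_index // local_num

def to_local(global_index, local_num):
    """Position of a global particle inside its owner's local particle list."""
    return global_index % local_num

# Example: 12 particles on 3 ranks -> 4 particles per rank
local_num = 12 // 3
print(global_indices(1, local_num))                       # [4 5 6 7]
print(owner_of(9, local_num), to_local(9, local_num))     # 2 1
```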

### Code sections requiring MPI are clearly marked
Such sections are enclosed in
```
# MPI START -----------------------------
if self.rank == 0:
    ...
# MPI END -----------------------------

```
The keyword `pass` is used wherever a block would otherwise be empty (or contain only a comment), so that the skeleton remains valid Python code.

### Initial states

The ensemble is initialized from a single simulator instance, and the ocean state in this simulator instance is completely random. The initialization therefore requires some exchange of data, so that all MPI processes have the same initial data from which to create their "local" ensemble.

My suggestion is that rank 0 shares its initialization input (eta, hu, hv) with the other processes before they enter the loop that creates the ensemble members.
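With mpi4py, rank 0 could share its initial state through one collective call instead of three point-to-point sends per rank. The three fields would be stacked into one contiguous buffer and broadcast with `comm.Bcast(buf, root=0)`, where `comm` would be `MPI.COMM_WORLD`. The sketch below covers only the packing and unpacking, so it runs without MPI. The `Bcast` call appears only as a comment, and the function names are hypothetical.

```python
import numpy as np

def pack_state(eta, hu, hv):
    """Stack the three fields into one contiguous float32 buffer of shape (3, ny, nx)."""
    return np.ascontiguousarray(np.stack([eta, hu, hv]), dtype=np.float32)

def unpack_state(buf):
    """Split a (3, ny, nx) buffer back into eta, hu, hv (copies that own their memory)."""
    return buf[0].copy(), buf[1].copy(), buf[2].copy()

data_shape = (44, 44)  # (ny + ghost cells, nx + ghost cells) for nx = ny = 40
rank = 0               # would come from comm.Get_rank() in an MPI run

if rank == 0:
    eta = np.random.rand(*data_shape).astype(np.float32)
    hu = np.random.rand(*data_shape).astype(np.float32)
    hv = np.random.rand(*data_shape).astype(np.float32)
    buf = pack_state(eta, hu, hv)
else:
    buf = np.empty((3,) + data_shape, dtype=np.float32)  # receive buffer

# comm.Bcast(buf, root=0)  # collective: every rank must make this call
eta_r, hu_r, hv_r = unpack_state(buf)
```

A single broadcast also keeps the number of messages independent of the number of fields.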

**In the future** we will start the ensemble from a predefined set of initial states (e.g. a large ensemble created from a small ensemble of ROMS models), so the initialization will need to change.


### Initialization of the ensemble
This has been a mess in the OceanStateEnsemble class. In the skeleton below I have tried to clean this up, so that the standard constructor is the only required function.

I've removed the option to choose between different observation operators, etc. This can be re-introduced later.

### Shortcuts where there should have been more options
We assume 
* Direct observation of underlying flow at drifter positions only
* CDKLM simulator only
* periodic boundary conditions only
* The total number of particles **MUST** be a multiple of the number of MPI processes (in the version below)
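Requiring an exact multiple keeps the index arithmetic trivial. A later version could lift that requirement by giving the first `numParticles % num_procs` ranks one extra particle each. The per-rank counts and offsets are what a variable-size gather needs, for example mpi4py's `Gatherv`. This is a hypothetical sketch and is not part of the skeleton.

```python
import numpy as np

def block_partition(num_particles, num_procs):
    """Return (counts, offsets): particles per rank and the first global index on each rank."""
    base, remainder = divmod(num_particles, num_procs)
    # The first `remainder` ranks get one extra particle each
    counts = np.array([base + 1 if r < remainder else base for r in range(num_procs)])
    offsets = np.concatenate(([0], np.cumsum(counts)[:-1]))
    return counts, offsets

def owner_of_uneven(global_index, counts, offsets):
    """Rank owning `global_index` under the partition above."""
    return int(np.searchsorted(offsets, global_index, side='right') - 1)

counts, offsets = block_partition(10, 4)
print(counts, offsets)  # [3 3 2 2] [0 3 6 8]
```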


### Synthetic truth
The truth should only be a model realization in the master process. Currently all processes have a truth model, so that as much of the old code as possible can be reused. For all processes with rank > 0, the simulator in self.particles[self.obs_index] should never be used.


### Too much communication
Because I want to be able to copy-paste as much code as possible, this version will most likely contain too much communication (or resend information that has already been communicated). This can be fixed later, since correctness is what matters most in the short term.

### Synchronous sends only
All communication shown below assumes synchronous (blocking) sends: nothing else happens before a send/receive is complete.
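The gather in `observeParticles` can serve as an example. Every rank holds a block of identical shape, so assembling the blocks on rank 0 in rank order matches a single collective gather followed by a reshape. With mpi4py that gather would be `comm.Gather(sendbuf, recvbuf, root=0)`, with `recvbuf` of shape `(num_procs, local_numParticles, D, 2)`. The sketch below emulates the ranks with a Python list, so it runs without MPI. The variable names are illustrative only.

```python
import numpy as np

num_procs, local_num, num_drifters = 3, 4, 2

# Emulated per-rank results of observeParticles: one (local_num, D, 2) block per rank
local_blocks = [np.full((local_num, num_drifters, 2), float(rank)) for rank in range(num_procs)]

# Point-to-point version, as in the skeleton: copy block p into its slice on rank 0
observed_loop = np.empty((num_procs*local_num, num_drifters, 2))
for p in range(num_procs):
    observed_loop[p*local_num:(p+1)*local_num] = local_blocks[p]

# Collective version: Gather would fill recvbuf[p] with rank p's block;
# here np.stack plays that role
recvbuf = np.stack(local_blocks)  # shape (num_procs, local_num, D, 2)
observed_gather = recvbuf.reshape(num_procs*local_num, num_drifters, 2)
```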


### No resampling of drifter positions 
The only drifters used for anything are those in the synthetic truth. We keep the drifters in all the particles as well (for fear of breaking something), but we ignore them. There is therefore no reason to resample drifter positions, so we don't.

# How to follow the code?

There are very few functions of MPIOceanEnsemble that are used from the outside.

**First of all**, there is of course **the constructor**. It uses no other functions.

**Step** is used to run the ensemble forward in time.

**And then**, the functions **getGaussianWeight** and **resample** are called **via the function DataAssimilationUtils.residualSampling**.
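For reference, textbook residual resampling works in two steps. First, particle i is copied floor(N w_i) times deterministically. Then the remaining slots are filled by drawing from the normalized residual weights. The following is a minimal NumPy sketch of that generic definition. It may not match `dautils.residualSampling` exactly, and `residual_resample` is a hypothetical name. Its output plays the role of `newSampleIndices` in `resample`.

```python
import numpy as np

def residual_resample(weights, rng):
    """Return N particle indices drawn by residual resampling from normalized weights."""
    weights = np.asarray(weights, dtype=np.float64)
    n = len(weights)
    # Deterministic part: particle i is kept floor(n * w_i) times
    copies = np.floor(n * weights).astype(int)
    indices = np.repeat(np.arange(n), copies)
    # Stochastic part: fill the remaining slots by sampling the residual weights
    num_remaining = n - len(indices)
    if num_remaining > 0:
        residual = n * weights - copies
        residual /= residual.sum()
        extra = rng.choice(n, size=num_remaining, p=residual)
        indices = np.concatenate((indices, extra))
    return np.sort(indices)

rng = np.random.default_rng(0)
w = np.array([0.5, 0.3, 0.15, 0.05])
print(residual_resample(w, rng))  # particle 0 twice, particle 1 at least once
```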

That's it.

In [None]:
class MPIOceanEnsemble():
    
    """
    Initialization of ensemble - lots of boilerplate code to read parameters etc...
    """
    def __init__(self, gpu_ctx, numParticles, sim, \
                 driftersPerOceanModel=1, \
                 observation_variance = None, \
                 small_scale_perturbation_amplitude = 0.0, \
                 initialization_variance_factor_ocean_field = 0.0):
    
        self.gpu_ctx = gpu_ctx
        self.numParticles = numParticles
        self.driftersPerOceanModel = driftersPerOceanModel
                
        # MPI START -----------------------------
        # Obtain my rank and number of other processes
        self.rank = 0
        self.num_procs = 1
        
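        # Integer division: numParticles must be a multiple of num_procs (see assumptions above)
        self.local_numParticles = self.numParticles//self.num_procs 
        # What if num_procs is not a factor of numParticles? 
        
        # MPI END -----------------------------
        
        # Local particles, followed by the synthetic truth stored last
        self.particles = [None]*(self.local_numParticles + 1)
        self.obs_index = self.local_numParticles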
        
        self.simType = 'CDKLM16'
        
        self.t = 0.0
    
        self.observation_type = dautils.ObservationType.DirectUnderlyingFlow
        self.prev_observation = None
        
        
        #-------------------------------------------------        
        ### ----- Stochastic parameters
        #-------------------------------------------------
        
        self.observation_variance = observation_variance
        if self.observation_variance is None:
            # Just setting something related to drifter velocity
            self.observation_variance = 0.01**2
        
        # Build observation covariance matrix:
        self.observation_cov = None
        self.observation_cov_inverse = None
        if np.isscalar(self.observation_variance):
            self.observation_cov = np.eye(2)*self.observation_variance
            self.observation_cov_inverse = np.eye(2)*(1.0/self.observation_variance)
        else:
            # Assume that we have a correctly shaped matrix here
            self.observation_cov = self.observation_variance
            self.observation_cov_inverse = np.linalg.inv(self.observation_cov)
         
        self.small_scale_perturbation_amplitude = small_scale_perturbation_amplitude
    
        # When initializing an ensemble, each member should be perturbed so that they 
        # have slightly different starting point.
        # This factor should be multiplied to the small_scale_perturbation_amplitude for that 
        # perturbation
        self.initialization_variance_factor_ocean_field = initialization_variance_factor_ocean_field
        
        
        
        #-------------------------------------------------        
        ### ---- -Read parameters from sim into self.*
        #-------------------------------------------------
        
        self.nx = sim.nx
        self.ny = sim.ny
        self.dx = sim.dx
        self.dy = sim.dy
        self.dt = sim.dt
        self.g = sim.g
        self.f = sim.f
        self.beta = sim.coriolis_beta
        self.r = sim.r
        self.wind = sim.wind_stress
        self.boundaryConditions = sim.boundary_conditions
        self.ghostCells = np.array([2,2,2,2])
        
        self.dataShape =  ( self.ny + self.ghostCells[0] + self.ghostCells[2], 
                            self.nx + self.ghostCells[1] + self.ghostCells[3]  )
        
        self.base_eta, self.base_hu, self.base_hv = sim.download(interior_domain_only=False)
        self.base_H = sim.downloadBathymetry()[0]
        
        # MPI START -----------------------------
        # sim's ocean state is random, and we must make rank 0 send its valid state to
        # all the other processes.
        if self.rank == 0:
            for p in range(1, self.num_procs):
                #mpisend self.base_eta of size self.dataShape to rank p
                #mpisend self.base_hu  of size self.dataShape to rank p
                #mpisend self.base_hv  of size self.dataShape to rank p
                pass
        else:
            #mpirecieve self.base_eta of size self.dataShape from rank 0
            #mpirecieve self.base_hu  of size self.dataShape from rank 0
            #mpirecieve self.base_hv  of size self.dataShape from rank 0
            pass
        # MPI END -----------------------------
        
        
        
        #-------------------------------------------------
        ### ---- Create the local ensemble
        #-------------------------------------------------
        
        # MPI START -----------------------------
        # Mapping of particle indices from local ensemble to global ensemble
        self.global_particle_indices = np.arange(self.rank*self.local_numParticles, \
                                                 (self.rank + 1)*self.local_numParticles)
        # MPI END -----------------------------
        
        # Define mid-points for the different drifters 
        # Decompose the domain, so that we spread the drifters as much as possible
        sub_domains_y = int(np.round(np.sqrt(self.driftersPerOceanModel)))
        sub_domains_x = int(np.ceil(1.0*self.driftersPerOceanModel/sub_domains_y))
        self.midPoints = np.empty((driftersPerOceanModel, 2))
        for sub_y in range(sub_domains_y):
            for sub_x in range(sub_domains_x):
                drifter_id = sub_y*sub_domains_x + sub_x
                if drifter_id >= self.driftersPerOceanModel:
                    break
                self.midPoints[drifter_id, 0]  = (sub_x + 0.5)*self.nx*self.dx/sub_domains_x
                self.midPoints[drifter_id, 1]  = (sub_y + 0.5)*self.ny*self.dy/sub_domains_y
              
        
        for i in range(self.local_numParticles+1):
            self.particles[i] = CDKLM16.CDKLM16(self.gpu_ctx, \
                                                self.base_eta, self.base_hu, self.base_hv, \
                                                self.base_H, \
                                                self.nx, self.ny, self.dx, self.dy, self.dt, \
                                                self.g, self.f, self.r, \
                                                boundary_conditions=self.boundaryConditions, \
                                                write_netcdf=False, \
                                                small_scale_perturbation=True, \
                                                small_scale_perturbation_amplitude=self.small_scale_perturbation_amplitude)
            
            if self.initialization_variance_factor_ocean_field != 0.0:
                self.particles[i].perturbState(q0_scale=self.initialization_variance_factor_ocean_field)
            
            drifters = GPUDrifterCollection.GPUDrifterCollection(self.gpu_ctx, driftersPerOceanModel,
                                                                 observation_variance=self.observation_variance,
                                                                 boundaryConditions=self.boundaryConditions,
                                                                 domain_size_x=self.nx*self.dx, domain_size_y=self.ny*self.dy)
            
            drifters.setDrifterPositions(self.midPoints)
            self.particles[i].attachDrifters(drifters)
   
    def cleanUp(self):
        for oceanState in self.particles:
            if oceanState is not None:
                oceanState.cleanUp()
    
    def step(self, sub_t):
        """
        Advances all particles, including the synthetic truth, by the time interval sub_t,
        and returns the current model time.
        """
        for p in self.particles:
            self.t = p.step(sub_t)
        return self.t
    
    
    def observeTrueState(self):
        """
        Applies the observation operator to the synthetic true state.

        Returns a numpy array with D drifter positions and drifter velocities
        [[x_1, y_1, u_1, v_1], ... , [x_D, y_D, u_D, v_D]]
        
        Only rank 0 obtains the true state and spreads it to the other ranks.
        
        MPI: All processes MUST call this function simultaneously.
        """

        # MPI START -----------------------------
        if self.rank == 0:
            truth = self.particles[self.obs_index]
            trueDrifterPositions = truth.drifters.getDrifterPositions()
            
            # Download bathymetry and ocean state (without ghost cells) once, not per drifter
            H = truth.downloadBathymetry()[1]
            eta, hu, hv = truth.download(interior_domain_only=True)
            
            trueState = np.empty((self.driftersPerOceanModel, 4))
            
            for d in range(self.driftersPerOceanModel):
                x = trueDrifterPositions[d,0]
                y = trueDrifterPositions[d,1]
                id_x = int(np.floor(x/self.dx))
                id_y = int(np.floor(y/self.dy))

                # Skipping interpolation
                depth = H[id_y, id_x]
                u = hu[id_y, id_x]/(depth + eta[id_y, id_x])
                v = hv[id_y, id_x]/(depth + eta[id_y, id_x])

                trueState[d,:] = np.array([x, y, u, v])
                
            # Share true state with the other processes:
            for p in range(1, self.num_procs):
                # mpisend trueState, size (driftersPerOceanModel, 4) to rank p
                pass
        
        else:
            # Receive buffer for the true state
            trueState = np.empty((self.driftersPerOceanModel, 4))
            # mpireceive trueState, size (driftersPerOceanModel, 4) from rank 0
        
        return trueState    
        
        
        
    def observeParticles(self):
        """
        Applying the observation operator on each particle.

        Structure on the output:
        [
        particle 1:  [u_1, v_1], ... , [u_D, v_D],
        particle 2:  [u_1, v_1], ... , [u_D, v_D],
        particle Ne: [u_1, v_1], ... , [u_D, v_D]
        ]
        numpy array with dimensions (particles, drifters, 2)
        
        MPI: All processes MUST call this function simultaneously.
        """
        local_observedState = np.empty((self.local_numParticles, \
                                        self.driftersPerOceanModel, 2))

        trueState = self.observeTrueState()
        # trueState structure: [[x1, y1, u1, v1], ..., [xD, yD, uD, vD]]

        for p in range(self.local_numParticles):
            # Downloading ocean state without ghost cells
            Hi = self.particles[p].downloadBathymetry()[1]
            eta, hu, hv = self.particles[p].download(interior_domain_only=True)

            for d in range(self.driftersPerOceanModel):
                id_x = int(np.floor(trueState[d,0]/self.dx))
                id_y = int(np.floor(trueState[d,1]/self.dy))

                depth = Hi[id_y, id_x]
                local_observedState[p,d,0] = hu[id_y, id_x]/(depth + eta[id_y, id_x])
                local_observedState[p,d,1] = hv[id_y, id_x]/(depth + eta[id_y, id_x])
        
        
        # MPI START -----------------------------
        # Gather all observed states on rank 0
        
        if self.rank == 0:
            observedState = np.empty((self.numParticles, \
                                      self.driftersPerOceanModel, 2))

            observedState[0:self.local_numParticles, :, :] = local_observedState
            for p in range(1, self.num_procs):
                remote_observedState = np.empty_like(local_observedState)  # receive buffer
                # mpireceive remote_observedState, size (see above), from rank p
                observedState[p*self.local_numParticles:(p+1)*self.local_numParticles, :, :] = remote_observedState
            
            return observedState
            
        else:
            # mpisend local_observedState, size (see above) to rank 0
            
            return local_observedState
        
        ## CHALLENGE!!! All processes need to enter functions that are structured like this one.
        #  But what should all processes other than rank 0 return???
        #  Does it matter what they return?
        
        # MPI END -----------------------------
        
    def getInnovations(self, obs=None):
        r"""
        Obtaining the innovation vectors, y^m - H(\psi_i^m)

        Returns a numpy array with dimensions (particles, drifters, 2)

        MPI: All processes must enter this function (even though we only care about the output from rank 0)!!!
        """
        if obs is None:
            obs = self.observeTrueState()
        # Only select the velocities, not the positions.
        # Assumption: a user-supplied obs has the same layout as observeTrueState(), [[x, y, u, v], ...]
        trueState = obs[:, 2:]
            
        innovations = trueState - self.observeParticles()
        return innovations
    
    def getInnovationNorms(self, obs=None):
        
        # Innovations have the structure 
        # [ particle: [drifter: [x, y] ] ], or
        # [ particle: [drifter: [u, v] ] ]
        # We simply find the norm for each particle:
        innovations = self.getInnovations(obs=obs)
        return np.linalg.norm(np.linalg.norm(innovations, axis=2), axis=1)
    
    
    def getGaussianWeight(self, innovations=None, normalize=True):
        """
        Calculates a weight associated to every particle, based on its innovation vector, using 
        Gaussian uncertainty for the observation.
        
        MPI: All processors must enter this function (even though we only care about the output from rank 0)!!!
        
        """

        if innovations is None:
            innovations = self.getInnovations()
            
            # MPI: Now, rank 0 will have all innovations, and rank > 0 will have their local innovations
            
        # MPI START -----------------------------
        
        ## Suggestion: We only care about rank 0, and return some valid (but bogus) weights for rank > 0 
        
        if self.rank == 0:
            observationVariance = self.getObservationVariance()
            Rinv = None

            weights = np.zeros(innovations.shape[0])
            if len(innovations.shape) == 1:
                weights = (1.0/np.sqrt(2*np.pi*observationVariance))* \
                        np.exp(- (innovations**2/(2*observationVariance)))

            else:
                Ne = self.getNumParticles()
                Nd = innovations.shape[1] # number of drifters per particle
                Ny = innovations.shape[2]

                Rinv = self.observation_cov_inverse
                R = self.observation_cov

                for i in range(Ne):
                    w = 0.0
                    for d in range(Nd):
                        inn = innovations[i,d,:]
                        w += np.dot(inn, np.dot(Rinv, inn.transpose()))

                    ## TODO: Restructure to do the normalization before applying
                    # the exponential function. The current version is sensitive to overflows.
                    weights[i] = (1.0/((2*np.pi)**Nd*np.linalg.det(R)**(Nd/2.0)))*np.exp(-0.5*w)
            if normalize:
                return weights/np.sum(weights)
            return weights
        
        else: # if rank > 0
            return np.ones(innovations.shape[0])/(1.0*self.local_numParticles)
        
        # MPI END  -----------------------------

        
    def resample(self, newSampleIndices, reinitialization_variance):
        """
        Resampling the particles given by the newSampleIndices input array.
        Here, the reinitialization_variance input is ignored, meaning that exact
        copies only are resampled.
        
        
        
        MPI: MUST be called by all processes.
        newSampleIndices is a valid input for rank 0 only. For rank > 0 it is only bogus.
        
        MPI: It is very important that we don't overwrite particles that still need to be copied.
        Here is a highly stupid, brute force resampling scheme, but it should be safe as long as we
        don't run out of memory on a single node...
        """
        newOceanStates = [None]*self.getNumParticles()
                
        # MPI START -----------------------------
        
        # Create an array containing the process id in charge of each global particle
        particle_owner = [None]*self.getNumParticles()
        for p in range(self.getNumParticles()):
            particle_owner[p] = p//self.local_numParticles  # integer division

            
        # Share the global IDs with all other processes
        if self.rank == 0:
            for p in range(1, self.num_procs):
                # mpisend newSampleIndices size (numParticles) to rank p 
                pass
        else:
            #mpirecv newSampleIndices size (numParticles) from rank 0
            pass
        
        
        # Gather all resampled particles onto rank 0
        for i in range(self.getNumParticles()):
            index = newSampleIndices[i]
            owner = particle_owner[index]
            
            # Send state index to rank 0 from rank owner:
            if self.rank == 0:
                if owner == 0:
                    eta0, hu0, hv0 = self.particles[index].download()
                    eta1, hu1, hv1 = self.particles[index].downloadPrevTimestep()
                    newOceanStates[i] = (eta0, hu0, hv0, eta1, hu1, hv1)
                else:
                    #mpi_recv eta0, hu0, hv0, eta1, hu1, hv1 from rank owner
                    # all these arrays should have size self.dataShape (interior + ghost cells)
                    newOceanStates[i] = (eta0, hu0, hv0, eta1, hu1, hv1)
                    pass
                
            
            elif self.rank == owner:
                # Download the ocean state of global particle `index` (converted to a local index):
                local_index = index - owner*self.local_numParticles
                eta0, hu0, hv0 = self.particles[local_index].download()
                eta1, hu1, hv1 = self.particles[local_index].downloadPrevTimestep()
                
                #mpi_send eta0, hu0, hv0, eta1, hu1, hv1 to rank 0

                
                
        # New loop for transferring the correct ocean states back to their owners, and to the GPU
        for i in range(self.getNumParticles()):
            owner = particle_owner[i]
            
            if self.rank == 0:
                if owner == 0:
                    self.particles[i].upload(newOceanStates[i][0],
                                             newOceanStates[i][1],
                                             newOceanStates[i][2],
                                             newOceanStates[i][3],
                                             newOceanStates[i][4],
                                             newOceanStates[i][5])
                else:
                    # mpi_send all six arrays in newOceanStates[i] to rank owner.
                    # Each array has size self.dataShape
                    pass
            elif self.rank == owner:
                # mpi_recv eta0, hu0, hv0, eta1, hu1, hv1 from rank 0.
                # Each array has size self.dataShape
                local_index = i - self.rank*self.local_numParticles
                self.particles[local_index].upload(eta0, hu0, hv0,
                                                   eta1, hu1, hv1)
  

        # MPI END -----------------------------

                    
    def getDomainSizeX(self):
        return self.nx*self.dx
    def getDomainSizeY(self):
        return self.ny*self.dy
    def getObservationVariance(self):
        return self.observation_variance
    def getNumParticles(self):
        return self.numParticles
    
    
    def plotDistanceInfo(self, title=None):
        """
        MPI: Only rank 0 creates a figure. The other processes only help produce the
        required information, but all processes MUST call this function.
        """
        
        # All processes help out with gathering info
        innovations = self.getInnovationNorms()
        obs_var = self.getObservationVariance()
        range_x = np.sqrt(obs_var)*20
        x = np.linspace(0, range_x, num=100)
        gauss_pdf = self.getGaussianWeight(x, normalize=False)
        gaussWeights = self.getGaussianWeight()
        
        # Only rank 0 creates a figure:
        fig = None
        if self.rank==0:
            plotRows = 2
            fig = plt.figure(figsize=(10, 6))
            gridspec.GridSpec(plotRows, 3)


            # PLOT DISTRIBUTION OF PARTICLE DISTANCES AND THEORETICAL OBSERVATION PDF
            ax0 = plt.subplot2grid((plotRows,3), (0,0), colspan=3)
            range_x = np.sqrt(obs_var)*20

            # With observation 
            x = np.linspace(0, range_x, num=100)
            plt.plot(x, gauss_pdf, 'g', label="pdf directly from innovations")
            plt.legend()
            plt.title("Distribution of particle innovations")

            #histograms:
            ax1 = ax0.twinx()
            ax1.hist(innovations, bins=30, \
                     range=(0, range_x),\
                     density=True, label="particle innovations (norm)")

            # PLOT SORTED DISTANCES FROM OBSERVATION
            ax0 = plt.subplot2grid((plotRows,3), (1,0), colspan=3)
            indices_sorted_by_observation = innovations.argsort()
            ax0.plot(gaussWeights[indices_sorted_by_observation]/np.max(gaussWeights),\
                     'g', label="Weight directly from innovations")
            ax0.set_ylabel('Weights directly from innovations', color='g')
            ax0.grid()
            ax0.set_ylim(0,1.4)
            #plt.legend(loc=7)
            ax0.set_xlabel('Particle ID')

            ax1 = ax0.twinx()
            ax1.plot(innovations[indices_sorted_by_observation], label="innovations")
            ax1.set_ylabel('Innovations', color='b')

            plt.title("Sorted distances from observation")

            if title is not None:
                plt.suptitle(title, fontsize=16)
            #plt.tight_layout()
            
        return fig
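The TODO in `getGaussianWeight` notes that applying the exponential before normalizing is sensitive to overflow and underflow. A common remedy is the log-sum-exp trick. It works on log-weights and subtracts the largest one before exponentiating. The constant factor 1/((2π)^Nd det(R)^(Nd/2)) is identical for every particle, so it cancels in the normalization. The following is a minimal sketch that assumes innovations shaped `(particles, drifters, 2)` and a 2x2 observation covariance `R`, as in the class. The name `normalized_gaussian_weights` is hypothetical.

```python
import numpy as np

def normalized_gaussian_weights(innovations, R):
    """Normalized Gaussian weights computed in the log domain to avoid under/overflow."""
    Rinv = np.linalg.inv(R)
    # w[i] = sum over drifters d of inn[i,d]^T Rinv inn[i,d]
    w = np.einsum('ndk,kl,ndl->n', innovations, Rinv, innovations)
    log_weights = -0.5 * w
    # Subtract the largest log-weight so the largest term becomes exp(0) = 1
    log_weights -= np.max(log_weights)
    weights = np.exp(log_weights)
    return weights / np.sum(weights)

R = np.eye(2) * 0.02**2  # observation covariance used in the notebook
innovations = np.random.default_rng(0).normal(size=(40, 3, 2))
weights = normalized_gaussian_weights(innovations, R)
```

With innovations this large compared with the 0.02 standard deviation, `np.exp(-0.5*w)` underflows to zero for many particles. The log-domain version still returns finite weights that sum to one.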

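`observeTrueState` and `observeParticles` apply the same observation operator. Each drifter is placed in the cell `(floor(x/dx), floor(y/dy))`, without interpolation. Momentum is then converted to velocity through u = hu/(H + eta). The vectorized NumPy sketch below works on plain arrays. It assumes interior-domain fields of shape `(ny, nx)` and a depth array `H` that uses the same cell indexing as `eta`. In the class, depth comes from `downloadBathymetry()[1]`, whose exact layout is not shown here. `observe_velocities` is a hypothetical name.

```python
import numpy as np

def observe_velocities(positions, eta, hu, hv, H, dx, dy):
    """Velocities (u, v) in the cell containing each drifter, without interpolation.

    positions: array (D, 2) of drifter x, y coordinates.
    eta, hu, hv, H: interior-domain arrays of shape (ny, nx), indexed [y, x].
    """
    id_x = np.floor(positions[:, 0] / dx).astype(int)
    id_y = np.floor(positions[:, 1] / dy).astype(int)
    depth = H[id_y, id_x] + eta[id_y, id_x]  # total water depth in each drifter's cell
    u = hu[id_y, id_x] / depth
    v = hv[id_y, id_x] / depth
    return np.stack((u, v), axis=1)  # shape (D, 2)

ny, nx, dx, dy = 40, 40, 4.0, 4.0
H = np.full((ny, nx), 10.0)
eta = np.zeros((ny, nx))
hu = np.full((ny, nx), 2.0)
hv = np.full((ny, nx), -1.0)
drifters = np.array([[10.0, 30.0], [150.0, 5.0]])
print(observe_velocities(drifters, eta, hu, hv, H, dx, dy))
```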
## Create the ensemble:

1. Set parameters
2. Create the simulator on which all ensemble members are based
3. Create the ensemble 
4. Define observation times and run simulation


In [None]:
#-------------------------------------------------------------------------------
# 1) Set parameters
#-------------------------------------------------------------------------------

nx = 40
ny = 40

dx = 4.0
dy = 4.0

dt = 0.05
g = 9.81
r = 0.0

f = 0.05
beta = 0.0

ensemble_size = 40
drifters = 3

#-------------------------------------------------------------------------------
# 2) Create the simulator on which all ensemble members are based
#-------------------------------------------------------------------------------

ghosts = np.array([2,2,2,2]) # north, east, south, west
validDomain = np.array([2,2,2,2])
boundaryConditions = Common.BoundaryConditions(2,2,2,2)

# Define which cell index which has lower left corner as position (0,0)
x_zero_ref = 2
y_zero_ref = 2

dataShape = (ny + ghosts[0]+ghosts[2], 
             nx + ghosts[1]+ghosts[3])
dataShapeHi = (ny + ghosts[0]+ghosts[2]+1, 
             nx + ghosts[1]+ghosts[3]+1)

eta0 = np.zeros(dataShape, dtype=np.float32, order='C');
hv0 = np.zeros(dataShape, dtype=np.float32, order='C');
hu0 = np.zeros(dataShape, dtype=np.float32, order='C');
waterDepth = 10.0
Hi = np.ones(dataShapeHi, dtype=np.float32, order='C')*waterDepth


if 'sim' in globals():
    sim.cleanUp()
if 'ensemble' in globals():
    ensemble.cleanUp()

# Choose a suitable amplitude for the model error.
# This expression does not make sense dimensionally, but it gives a number
# that fits well with all the other numbers (:
q0 = 0.5*dt*f/(g*waterDepth)

sim = CDKLM16.CDKLM16(gpu_ctx, eta0, hu0, hv0, Hi, \
                      nx, ny, dx, dy, dt, g, f, r, \
                      boundary_conditions=boundaryConditions, \
                      write_netcdf=False, \
                      small_scale_perturbation=True, \
                      small_scale_perturbation_amplitude=q0)

# Create a random initial state 
sim.perturbState(q0_scale=100)


#-------------------------------------------------------------------------------
# 3) Create the ensemble
#-------------------------------------------------------------------------------

ensemble = MPIOceanEnsemble(gpu_ctx, ensemble_size, sim, \
                            driftersPerOceanModel=drifters, \
                            observation_variance = 0.02**2, \
                            small_scale_perturbation_amplitude=q0, \
                            initialization_variance_factor_ocean_field=50)

#print "ensemble.observeTrueState()", ensemble.observeTrueState()
#print "ensemble.observeParticles()", ensemble.observeParticles()
#print "ensemble.getInnovations()", ensemble.getInnovations()
#print "ensemble.getGaussianWeight()", ensemble.getGaussianWeight()



#-------------------------------------------------------------------------------
# 4) Define observation times and run simulation
#    Here, we store a plot before and after each observation time step
#    in order to inspect the results.
#-------------------------------------------------------------------------------

T = 50
sub_t = 10*dt
resampling_points = range(5, 100, 10)
print("Will resample at iterations: " + str(list(resampling_points)))
infoPlots = []

# Run particle filter:
for it in range(T):
    t = ensemble.step(sub_t)
    
    # Check if we are at an observation
    for rp in resampling_points:
        if it == rp:
            print("resampling at iteration " + str(it))
            infoFig = ensemble.plotDistanceInfo(title="it = " + str(it) + " before resampling")
            if ensemble.rank == 0:
                plt.close(infoFig)
                infoPlots.append(infoFig)
            
            dautils.residualSampling(ensemble)
            
            infoFig = ensemble.plotDistanceInfo(title="it = " + str(it) + " post resampling")
            if ensemble.rank == 0:
                plt.close(infoFig)
                infoPlots.append(infoFig)
    
    if (it%10 == 0):
        print("{:03.0f}".format(100*it / T) + " % => t=" + str(t))

print("Done")

In [None]:
# Inspect results
def show_figures(figs):
    for f in figs:
        dummy = plt.figure()
        new_manager = dummy.canvas.manager
        new_manager.canvas.figure = f
        f.set_canvas(new_manager.canvas)
        filename= f._suptitle.get_text().replace(" ", "_").replace("=_", "") + ".png"
        #plt.savefig(filename)
        
if ensemble.rank == 0:
    show_figures(infoPlots)

# plotDistanceInfo involves communication, so all ranks must call it (only rank 0 gets a figure)
fig = ensemble.plotDistanceInfo(title="Final ensemble")