In [1]:
import xarray as xr
import pandas as pd
import numpy as np
import os 
import timeit
import math
import time
import sparse
import dask
import dask.array as da
from dask.diagnostics import ProgressBar
from glob import glob
from os import path
import pickle
from tqdm import tqdm

from ll_Balltree import *
%run -i 'll_Balltree.py'

In [2]:
# Directory that holds the inputs/outputs of the posterior computation.
outputDir = 'data/posterior_computation_data/'
gom_masks = xr.open_dataset(outputDir + 'gom_masks.nc')

# Coordinate axes of the mask grid, extracted once and reused below.
lon_values = gom_masks['lon'].values
lat_values = gom_masks['lat'].values

# GLOBAL CONSTANTS
# Domain extent: extreme values of the coordinate axes.
MIN_LON = np.min(lon_values)
MAX_LON = np.max(lon_values)
MIN_LAT = np.min(lat_values)
MAX_LAT = np.max(lat_values)

# Domain width and height (cell counts).
# `Dataset.sizes` is the stable name -> length mapping; indexing
# `Dataset.dims` by name is deprecated in recent xarray releases.
LAT_SIZE = gom_masks.sizes['lat']
LON_SIZE = gom_masks.sizes['lon']

# Cell size: spacing between the first two coordinate values along each axis
# (representative of the whole grid only if the spacing is uniform -- TODO confirm).
D_LON = lon_values[1] - lon_values[0]
D_LAT = lat_values[1] - lat_values[0]

BIN_CELL_LATS = gom_masks.bin_cell_lats.values
BIN_CELL_LONS = gom_masks.bin_cell_lons.values

# Echo the constants as the cell output for a quick sanity check.
MIN_LON, MAX_LON, MIN_LAT, MAX_LAT,LAT_SIZE,LON_SIZE, D_LON,D_LAT

(-97.98001098632812,
 -76.45999145507812,
 18.140000343322754,
 31.899998664855957,
 345,
 539,
 0.03997802734375,
 0.03999900817871094)

In [3]:
# Load the pickled dictionary of precomputed results (cell counts, cell trees,
# windows, ...) that the following cell unpacks into individual variables.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# load 'output_dict.obj' if it was produced by a trusted run of this pipeline.
# The `with` block guarantees the file is closed even if unpickling fails.
with open(outputDir + 'output_dict.obj', 'rb') as fileObj:
    output = pickle.load(fileObj)

In [4]:
# Unpack the entries of the `output` dict into notebook-level names,
# grouped by theme. Every name matches its dictionary key.

# Size parameters (echoed at the end of this cell).
n_cell_beaching, n_cell_source, n_window_beaching, n_window_source = (
    output[key]
    for key in ('n_cell_beaching', 'n_cell_source', 'n_window_beaching', 'n_window_source')
)

particle_count = output['particle_count']

# Beaching-cell collection and its lookup tree.
beaching_cells, beaching_cell_tree = (
    output[key] for key in ('beaching_cells', 'beaching_cell_tree')
)

# Source-cell mask, collection and lookup tree.
source_cell_mask, source_cells, source_cell_tree = (
    output[key] for key in ('source_cell_mask', 'source_cells', 'source_cell_tree')
)

# Time windows for beaching and source events.
beaching_windows, source_windows = (
    output[key] for key in ('beaching_windows', 'source_windows')
)

# `d` enters later cells as (d + 1) blocks of source cells
# (presumably the maximum lag, in windows -- TODO confirm).
d = output['d']

beaching_ym_mat, source_ym_mat = (
    output[key] for key in ('beaching_ym_mat', 'source_ym_mat')
)

# Display the size parameters as the cell output.
n_cell_beaching,n_cell_source,n_window_beaching, n_window_source

(188, 114024, 30, 36)

Compute Posterior Quantities

In [5]:
# Open the zarr store whose variables (l_nn_post, r_nn_post, f_nn_post) are
# rearranged by the next cell.
ds = xr.open_zarr(f"{outputDir}nn_post.zarr")

In [6]:
# Rearrange the posteriors stored in `ds` into dense arrays laid out as
#   rows    : (d + 1) consecutive blocks of n_cell_source rows
#   columns : n_window_beaching consecutive blocks of n_cell_beaching columns
# Nothing is normalized here: the normalization step at the bottom of this
# cell is commented out, so values are copied unchanged.
post_shape = ((n_cell_source*(d+1)), (n_cell_beaching*n_window_beaching))
l_n_post = np.zeros(post_shape)
r_n_post = np.zeros(post_shape)
f_n_post = np.zeros(post_shape)

# (destination array, source variable in `ds`) pairs handled identically below.
transfers = (
    (l_n_post, ds.l_nn_post),
    (r_n_post, ds.r_nn_post),
    (f_n_post, ds.f_nn_post),
)

for beaching_time in tqdm(range(n_window_beaching)):

    # Destination columns: the block belonging to this beaching window.
    post_beaching_idx = slice(beaching_time*n_cell_beaching, (beaching_time+1)*n_cell_beaching)

    for j in range(d+1):

        # Source window paired with block j for this beaching window.
        source_time = beaching_time + j

        # Destination rows: block j of the output.
        post_source_idx = slice(j*n_cell_source, (j+1)*n_cell_source)
        # Rows of `ds`: the block of source cells for `source_time`.
        ll_source_idx = slice(n_cell_source*source_time, n_cell_source*(source_time+1))
        # Columns of `ds`: block j of beaching cells.
        ll_beaching_idx = slice(j*n_cell_beaching, (j+1)*n_cell_beaching)

        for target, source in transfers:
            target[post_source_idx, post_beaching_idx] = source[ll_source_idx, ll_beaching_idx]


# Disabled normalization (kept for reference):
# n_post_norm = np.sum(l_n_post, axis = 0) +  np.sum(r_n_post, axis = 0) + np.sum(f_n_post, axis = 0)

# l_n_post  = np.nan_to_num(l_n_post / n_post_norm)
# r_n_post  = np.nan_to_num(r_n_post / n_post_norm)
# f_n_post  = np.nan_to_num(f_n_post / n_post_norm)

100%|██████████| 30/30 [01:55<00:00,  3.86s/it]


In [7]:
# Wrap the three reshaped arrays in one Dataset. All variables share the same
# two dimensions, each labelled with a plain 0..N-1 integer coordinate.
post_dims = ['dsource_windows_cells', 'beaching_window_cells' ]

post_variables = {
    'l_n_post': (post_dims, l_n_post),
    'r_n_post': (post_dims, r_n_post),
    'f_n_post': (post_dims, f_n_post),
}

post_coords = {
    'dsource_windows_cells': np.arange(n_cell_source*(d+1)),
    'beaching_window_cells': np.arange((n_cell_beaching*n_window_beaching)),
}

n_post = xr.Dataset(data_vars=post_variables, coords=post_coords)

In [8]:
# Chunk the dataset into tiles of n_cell_source rows by n_cell_beaching
# columns, and report how long the re-chunking call takes.
chunksize = dict(
    dsource_windows_cells=n_cell_source,
    beaching_window_cells=n_cell_beaching,
)

print('re-chunking')
rechunk_start = time.time()
n_post = n_post.chunk(chunksize)
rechunk_elapsed = time.time() - rechunk_start
print('   done in', rechunk_elapsed)

re-chunking
   done in 114.46145153045654


In [9]:
# Set up the zarr write without executing it (compute=False), then run the
# deferred write under a dask progress bar.
zarr_target = outputDir + "nn_post_reshaped.zarr"
delayed_obj = n_post.to_zarr(zarr_target, compute=False)

with ProgressBar():
    results = delayed_obj.compute()

[########################################] | 100% Completed | 83.05 s


In [10]:
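# Cleanup (disabled): uncomment to delete the reshaped zarr store, e.g. before
# re-running the to_zarr cell above.
# import shutil
# shutil.rmtree(outputDir + 'nn_post_reshaped.zarr', ignore_errors=True)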