In [1]:
from getpass import getuser # Library to get the current user's login name
from pathlib import Path # Object-oriented library to deal with paths
import os
from tempfile import NamedTemporaryFile, TemporaryDirectory # Creating temporary files/directories
from subprocess import run, PIPE
import sys
 
import dask # Library for parallel/distributed computation on large data
from dask_jobqueue import SLURMCluster # Launching dask workers as SLURM batch jobs
from distributed import Client, progress, wait # Library to orchestrate distributed resources

import xarray as xr # Library to work with labeled n-dimensional data (integrates with dask)
import numpy as np
import skimage.util as sutil
import matplotlib.pyplot as plt

# sys.path.insert(0, os.path.abspath('/home/mpim/m300414/phd/'))
from org_metrics import Pairs, gen_regionprops_objects_all, gen_shapely_objects_all, gen_tuplelist
from org_metrics import radar_organisation_metric, avg_area, lower_rom_limit

In [2]:
import warnings
warnings.filterwarnings(action='ignore')

In [3]:
# Set some user-specific variables
scratch_dir = Path('/scratch') / getuser()[0] / getuser()  # the user's scratch directory

# Temporary directory that receives the SLURM log files of the dask worker jobs. It is deleted when
# the TemporaryDirectory object is cleaned up (at the latest when the Python kernel shuts down).
dask_tmp_dir = TemporaryDirectory(dir=scratch_dir, prefix='rome_')

n_jobs = 2           # number of SLURM jobs to request; each job occupies one whole node
workers_per_job = 9  # dask worker processes per job/node on the gpu partition

# NOTE(review): newer dask-jobqueue releases rename `job_extra` to `job_extra_directives`;
# update the keyword if the installed version warns about it.
cluster = SLURMCluster(memory='500GiB',
                       cores=72,
                       processes=workers_per_job,
                       project='mh0731',
                       walltime='01:15:00',
                       queue='gpu',
                       name='rome',
                       scheduler_options={'dashboard_address': ':12435'},
                       local_directory='/home/mpim/m300414/phd/Notebooks/',
                       job_extra=['-J rome',
                                  '-D /home/mpim/m300414/phd/Notebooks/',
                                  '--begin=now',
                                  # No --error flag: sbatch then writes stderr to this --output file too.
                                  f'--output={dask_tmp_dir.name}/LOG_cluster.%j.o',
                                  ],
                       interface='ib0')

cluster.scale(jobs=n_jobs)  # requests whole nodes
dask_client = Client(cluster)
# Block until every requested worker process has connected to the scheduler.
dask_client.wait_for_workers(n_jobs * workers_per_job)

In [4]:
data_path = Path('/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/')
glob_pattern_2d = 'bool_*[0-9]_14mmhour.nc'

# Collect all matching file names (searched recursively, returned sorted) with pathlib's rglob
# and a generator expression.
file_names = sorted(str(f) for f in data_path.rglob(glob_pattern_2d))
# Fail early with a clear message instead of an opaque error from xr.open_mfdataset later on.
assert file_names, f'no files matching {glob_pattern_2d!r} found under {data_path}'
file_names

['/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200131T0000_14mmhour.nc',
 '/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200201T0000_14mmhour.nc',
 '/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200202T0000_14mmhour.nc',
 '/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200203T0000_14mmhour.nc',
 '/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200204T0000_14mmhour.nc',
 '/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200205T0000_14mmhour.nc',
 '/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200206T0000_14mmhour.nc',
 '/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200207T0000_14mmhour.nc',
 '/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200208T0000_14mmhour.nc',
 '/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200209T0000_14mmhour.nc',
 '/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200210T0000_14mmhour.nc',
 '/work/mh0731/m300414/DyWinter_b10/Fake_Steiner/bool_20200211T0000_14mmhour.nc',
 '/work/mh0731/m

In [5]:
def rome_per_domain(domain):
    """Compute ROME for a single 2-D domain.

    The domain's objects are built twice: once with ``gen_regionprops_objects_all`` and once
    with ``gen_shapely_objects_all``. Each object list is turned into ``Pairs`` via
    ``gen_tuplelist``, and both pair sets are passed to ``radar_organisation_metric``.

    Parameters
    ----------
    domain : array-like
        One 2-D window without a time axis. It is presumably a 0/1 convective mask;
        confirm against the org_metrics helpers.

    Returns
    -------
    Whatever ``radar_organisation_metric`` returns for the two pair sets.
    """
    # The object generators iterate over a leading time axis. A lone domain has none, so feed a
    # one-element list in and keep only the single per-step result that comes out.
    regionprop_steps = list(gen_regionprops_objects_all([domain]))
    regionprop_objects = regionprop_steps[0]
    shapely_steps = list(gen_shapely_objects_all([domain]))
    shapely_objects = shapely_steps[0]

    regionprop_pairs = Pairs(pairlist=list(gen_tuplelist(regionprop_objects)))
    shapely_pairs = Pairs(pairlist=list(gen_tuplelist(shapely_objects)))

    return radar_organisation_metric(r_pairs=regionprop_pairs, s_pairs=shapely_pairs)

In [6]:
def area_number_per_domain(domain):
    """Return the mean object area and the number of objects in one 2-D domain.

    Objects are built with ``gen_regionprops_objects_all``.

    Returns
    -------
    tuple
        ``(avg_area(objects), len(objects))``.
    """
    # One-element "time" list in, one per-step object list out; unwrap that single entry.
    per_step_objects = list(gen_regionprops_objects_all([domain]))
    objects = per_step_objects[0]

    mean_area = avg_area(objects)
    object_count = len(objects)
    return mean_area, object_count

In [7]:
def low_rome_limit_per_domain(domain):
    """Return ``lower_rom_limit`` evaluated on the objects of one 2-D domain.

    Objects are built with ``gen_regionprops_objects_all``. The notebook stores the result
    as the non-interacting ROME ('r_ni').
    """
    # The generator loops over a time axis the single domain lacks: wrap it in a list and
    # take back the sole per-step object list.
    step_results = list(gen_regionprops_objects_all([domain]))
    return lower_rom_limit(step_results[0])

In [8]:
@dask.delayed
def slide_domain_over_tropics(classifier, domain_size=(117, 117), metric=None):
    """Evaluate a per-domain metric on overlapping square windows slid over one 2-D field.

    Because of ``@dask.delayed``, calling this returns a Delayed object. The work runs on
    ``compute``/``persist``.

    Parameters
    ----------
    classifier : xarray.DataArray
        A single 2-D field with 'lat' and 'lon' coordinates. It is converted with ``np.array``
        and windowed along its first two axes. The dimension order is therefore assumed to be
        (lat, lon); TODO confirm.
    domain_size : tuple of int, optional
        Window shape in grid cells. Must be square with an odd edge length so every window
        has a centre pixel. Default is (117, 117).
    metric : callable, optional
        Maps one 2-D window (NumPy array) to a scalar. Defaults to ``low_rome_limit_per_domain``.

    Returns
    -------
    xarray.DataArray
        Float array with dims ('lat', 'lon'). Each value is ``metric`` of one window, labelled
        with the coordinates of that window's centre pixel.
    """
    if metric is None:
        metric = low_rome_limit_per_domain

    assert domain_size[0] == domain_size[1], 'domain must be square'
    assert domain_size[0] % 2 == 1, 'domain edge length must be odd so it has a centre pixel'
    # Windows start every edge // 2 + 1 cells, so neighbouring windows overlap by edge // 2 cells.
    stride_between_domains = domain_size[0] // 2 + 1
    mid_point = domain_size[0] // 2  # offset of the centre pixel inside a window

    radar_domains = sutil.view_as_windows(
        np.array(classifier),
        window_shape=tuple(domain_size),
        step=stride_between_domains,
    )
    map_shape = radar_domains.shape[:2]

    # Centre-pixel coordinates of every window (trimmed to the number of complete windows).
    latitude = classifier['lat'][mid_point::stride_between_domains][:map_shape[0]]
    longitude = classifier['lon'][mid_point::stride_between_domains][:map_shape[1]]

    # Fill a plain NumPy array and wrap it once at the end; element-wise assignment into an
    # xarray object on every iteration adds avoidable overhead.
    values = np.zeros(shape=map_shape)
    for i in range(map_shape[0]):
        for j in range(map_shape[1]):
            values[i, j] = metric(radar_domains[i, j, :, :])

    return xr.DataArray(values, coords={'lat': latitude, 'lon': longitude}, dims=('lat', 'lon'))

In [9]:
fakesteiner = xr.open_mfdataset(file_names)['conv_rain_class']

# Parallelisation on the time level: one delayed task per time step. Steps are selected by
# position (isel) rather than by a stringified timestamp, so the lookup is exact and cannot
# fall back to partial-string (slice) matching, e.g. with cftime calendars.
map_singletime = [
    slide_domain_over_tropics(fakesteiner.isel(time=step))
    for step in range(fakesteiner.sizes['time'])
]

In [10]:
# Submit every per-time-step task to the cluster and keep the results in distributed memory.
# `jobs` is a 1-tuple holding the persisted list. Unpacking it for `progress` tracks the same
# set of futures and prints a text progress bar.
jobs = dask.persist(map_singletime)
progress(*jobs, notebook=False)

[########################################] | 100% Completed | 27min 10.5s

In [11]:
# Gather the already-persisted results (jobs[0] is the persisted list from dask.persist) rather
# than the original un-persisted delayed objects, so the work is not resubmitted.
rom_low = xr.concat(dask.compute(*jobs[0]), dim=fakesteiner.time)

In [21]:
# Convert to km^2. Each grid cell is 2.5 km x 2.5 km = 6.25 km^2 (cf. the (117*2.5)x(117*2.5) km
# domain in long_name). NOTE(review): assumes lower_rom_limit returns an area in grid cells; confirm.
# The units check keeps this cell idempotent: re-running it does not rescale a second time.
pixel_area_km2 = 2.5 * 2.5
if rom_low.attrs.get('units') != 'km^2':
    rom_low = rom_low * pixel_area_km2
rom_low.name = 'r_ni'
rom_low.attrs['units'] = 'km^2'
rom_low.attrs['long_name'] = 'Non-interacting ROME across (117*2.5)x(117*2.5) km domain.'
rom_low.attrs['convective_threshold'] = 'Convective pixels > 14 mm/hour.'

In [12]:
# Disabled alternative: mean object area ('o_area') and object count ('o_number') per domain.
# The unpacking below expects each task in map_singletime to return a (mean-area map,
# object-count map) tuple. Re-enabling it therefore requires slide_domain_over_tropics to fill
# and return two maps, presumably via area_number_per_domain; confirm before re-enabling.
# tuplelist = dask.compute(*map_singletime)
# map1_singletime, map2_singletime = list(zip(*tuplelist))
# area = xr.concat(map1_singletime, dim=fakesteiner.time)
# number = xr.concat(map2_singletime, dim=fakesteiner.time)

In [13]:
# Area: scaled by 6.25 to km^2 and written to netCDF.
# area.name='o_area'
# area *= 6.25
# area.attrs['units'] = 'km^2'
# area.attrs['long_name'] = 'Object mean area (117*2.5)x(117*2.5) km domain.'
# area.to_netcdf('/work/mh0731/m300414/DyWinter_b10/o_area_14mmhour.nc')

In [14]:
# Count: dimensionless, written to netCDF.
# number.name='o_number'
# number.attrs['units'] = '1'
# number.attrs['long_name'] = 'Number of objects in (117*2.5)x(117*2.5) km domain.'
# number.to_netcdf('/work/mh0731/m300414/DyWinter_b10/o_number_14mmhour.nc')

In [15]:
# Disabled alternative: full ROME at a 7 mm/hour convective threshold. Judging from the names
# and output path, this presumably requires slide_domain_over_tropics to evaluate
# rome_per_domain and the input files to be the 7 mm/hour ones. The active run above globs
# *_14mmhour.nc files. Confirm before re-enabling.
# rome = xr.concat(dask.compute(*map_singletime), dim=fakesteiner.time)

In [16]:
# rome.name = 'rome'
# rome *= 6.25
# rome.attrs['units'] = 'km^2'
# rome.attrs['long_name'] = 'ROME across (117*2.5)x(117*2.5) km.'
# rome.attrs['convective_threshold'] = 'Convective pixels > 7 mm/hour.'

In [17]:
# rome.to_netcdf('/work/mh0731/m300414/DyWinter_b10/ROME/rome_7mmhour.nc')