In [1]:
import numpy as np
import pyart
import os
from sklearn.mixture import GaussianMixture
import pickle
from netCDF4 import num2date, date2num
import math
from multiprocessing import Pool
import traceback


## You are using the Python ARM Radar Toolkit (Py-ART), an open source
## library for working with weather radar data. Py-ART is partly
## supported by the U.S. Department of Energy as part of the Atmospheric
## Radiation Measurement (ARM) Climate Research Facility, an Office of
## Science user facility.
##
## If you use this software to prepare a publication, please cite:
##
##     JJ Helmus and SM Collis, JORS 2016, doi: 10.5334/jors.119



In [10]:
### Set key parameters
# All values below are read as module-level globals by the functions in the
# next cell and by the batch-processing cell.

# Directory of ODIM HDF5 volumes; every entry is passed to
# pyart.aux_io.read_odim_h5 by load_plot_scan
dirr = './raw_data/Mt Bolton/radar_hdf'
# Set True to test on a single volume (the first file after sorting)
test = False

# Pickle holding the fitted GaussianMixture models; each model is iterated
# over and identified by its n_components
model_fn = './models/Mt_Bolton_altitude_4_12.gmm'

# Filename prefix for the saved PNGs
firename = 'MtBolton'
# Plots are written to ./plots/<name_of_plot_run>/GMM_n<k>/
name_of_plot_run = 'GMM_MtBolton_height'

# If True, plotted gates are masked with give_gatefilter()
gfilt = True

# Sweep index passed to plot_k_and_ref
tilt = 0

# Exclude gates whose SNR is BELOW this value (give_gatefilter calls
# exclude_below); None disables the SNR test
snr_cutoff = 2
# Exclude gates whose VRADH lies inside [lo, hi]; None disables
VRADH_inside = None #[-0.5,0.5]
# Exclude gates whose VRADH lies outside [lo, hi]; None disables
VRADH_outside = None #[-10,10]
# Despeckle VRADH with this minimum feature size (passed as `size` to
# pyart.correct.despeckle.despeckle_field); None disables
depseck_size = None #10

# (x, y) axis labels shared by both RHI panels
axislabs=('Distance from UQ-XPOL (km)', 'Height (km)')

# x and y axis limits (km) for the RHI panels
RHIxlims, RHIylims = [5,28],[0, 7]
# NOTE(review): not referenced by any active code — TODO confirm and remove
downadjust = 0.3

# Colour-scale limits for the DBZH panel
minref = 0
maxref = 35

# Hours added to the first ray time before it is formatted into the output
# filename — presumably a UTC -> local-time offset (TODO confirm)
tz =11

# GMM input features, one column each, in this order. 'ALT' is the gate
# altitude taken from radar.gate_z, not a radar field. Presumably this must
# match the column order used when the models were trained — TODO confirm.
fields = ['DBZH','ZDR','RHOHV','WRADH','KDP','ALT']

In [14]:
### Define all functions

def load_plot_scan(x):
    """Classify and plot one radar volume with every loaded GMM.

    Worker function for ``Pool.map``. It reads the module-level ``fls``
    (sorted file names), ``dirr``, ``models``, ``name_of_plot_run``,
    ``gfilt`` and ``tilt`` that are set up by the configuration and
    processing cells.

    Parameters
    ----------
    x : int
        Index into ``fls`` of the volume to process.

    Returns
    -------
    str
        ``'Finished <file>'`` on success, or
        ``'Could not open file: <file> (<error>)'`` when Py-ART cannot read
        the volume. Errors raised while classifying or plotting propagate.
    """
    vol_fn = fls[x]
    try:
        myradar = pyart.aux_io.read_odim_h5('/'.join([dirr, vol_fn]),
                                            file_field_names=True)
    except Exception as e:
        # Keep the batch going, but report *why* the read failed instead of
        # discarding the exception.
        return 'Could not open file: ' + vol_fn + ' (' + str(e) + ')'

    for model in models:
        model_name = 'GMM_n' + str(model.n_components)
        outloc = '/'.join(['.', 'plots', name_of_plot_run, model_name])
        myradar.add_field(model_name, predict_labels(myradar, model, filt=gfilt))
        plot_k_and_ref(myradar, tilt, outloc, model_name, model.n_components)

    return 'Finished ' + vol_fn


def give_gatefilter(myradar):
    """Build the Py-ART GateFilter used to mask gates in the plots.

    Side effect: if ``myradar`` has no ``'SNR'`` field, one is derived from
    ``'DBZH'`` with ``calculate_snr_from_reflectivity`` and added to the
    radar.

    The filter then applies these module-level settings, each skipped when
    it is None:

    - ``VRADH_inside`` / ``VRADH_outside``: [lo, hi] bounds on ``'VRADH'``.
    - ``snr_cutoff``: exclude gates whose SNR is below it.
    - ``depseck_size``: despeckle ``'VRADH'``.

    Returns
    -------
    pyart.correct.GateFilter
    """
    # Plain membership test instead of check_field_exists() inside a bare
    # ``except:``, which would also swallow unrelated errors and interrupts.
    if 'SNR' not in myradar.fields:
        snr = pyart.retrieve.calculate_snr_from_reflectivity(
            myradar, refl_field='DBZH', snr_field=None, toa=25000.0)
        myradar.add_field('SNR', snr)

    gatefilter = pyart.correct.GateFilter(myradar)
    if VRADH_inside is not None:
        gatefilter.exclude_inside('VRADH', VRADH_inside[0], VRADH_inside[1])
    if VRADH_outside is not None:
        gatefilter.exclude_outside('VRADH', VRADH_outside[0], VRADH_outside[1])
    if snr_cutoff is not None:
        gatefilter.exclude_below('SNR', snr_cutoff)
    if depseck_size is not None:
        # The filter returned by despeckle_field replaces the current one.
        gatefilter = pyart.correct.despeckle.despeckle_field(
            myradar, 'VRADH', gatefilter=gatefilter, size=depseck_size)
    return gatefilter

def predict_labels(myradar, model, filt=True, field_names=None):
    """Classify every gate of a radar volume with a fitted mixture model.

    Parameters
    ----------
    myradar : pyart Radar (or any object with ``fields`` and ``gate_z``)
        Must provide ``fields['DBZH']['data']``, which defines the output
        shape, plus every field named in ``field_names``.
    model : fitted classifier
        Anything with ``predict(X)`` and ``n_components``, e.g. a fitted
        sklearn ``GaussianMixture``.
    filt : bool
        Unused; kept for call compatibility. Gate filtering is applied at
        plot time in plot_k_and_ref.
    field_names : sequence of str, optional
        Feature columns, in the order the model expects. ``'ALT'`` selects
        ``myradar.gate_z['data']`` instead of a radar field. Defaults to the
        module-level ``fields`` list.

    Returns
    -------
    dict
        Py-ART style field dictionary. ``'data'`` is a masked array of
        labels with the shape of ``'DBZH'``. Gates where any input feature is
        masked are masked, because their feature values are fill values and
        not measurements.
    """
    if field_names is None:
        field_names = fields

    orig_shape = np.shape(myradar.fields['DBZH']['data'])
    n_gates = int(np.prod(orig_shape))

    # One row per gate, one column per feature (width follows field_names
    # rather than a hard-coded 6).
    pred_data = np.zeros((n_gates, len(field_names)))
    invalid = np.zeros(n_gates, dtype=bool)
    for col, name in enumerate(field_names):
        if name == 'ALT':
            values = myradar.gate_z['data']
        else:
            values = myradar.fields[name]['data']
        # Copying a masked array into a plain ndarray keeps only the raw
        # data, so remember which gates were masked in any input.
        invalid |= np.ma.getmaskarray(values).ravel()
        pred_data[:, col] = np.ma.getdata(values).ravel()

    labels = np.reshape(model.predict(pred_data), orig_shape)
    labels = np.ma.masked_array(labels, mask=np.reshape(invalid, orig_shape))

    n_comp = str(model.n_components)
    return {
        'data': labels,
        'units': 'NA',
        'standard_name': 'GMM_n' + n_comp,
        'long_name': 'Labels as predicted by Gaussian Mixture Model where k = ' + n_comp,
    }

def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map.

    Adapted from https://gist.github.com/jakevdp/91077b0cae40f8f8244a

    Parameters
    ----------
    N : int
        Number of discrete colour bins.
    base_cmap : None, str or matplotlib Colormap
        Map to sample. None selects Matplotlib's default colormap; a string
        must be a registered colormap name (e.g. one registered by Py-ART).

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Named ``<base name><N>``, with N evenly spaced samples of the base map.
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap

    # pyplot.get_cmap accepts None, a name, or a Colormap instance;
    # plt.cm.get_cmap (matplotlib.cm.get_cmap) was removed in Matplotlib 3.9.
    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    # Call from_list on the class: ListedColormap bases (e.g. the default
    # viridis) have no from_list method.
    return LinearSegmentedColormap.from_list(cmap_name, color_list, N)

def plot_k_and_ref(myradar,tilt,outloc,model_name,N_cmap):
    """Save a two-panel RHI figure: DBZH (top) and GMM labels (bottom).

    Volumes whose ``scan_type`` is not ``'rhi'`` are skipped and no file is
    written.

    Parameters
    ----------
    myradar : pyart Radar
        Volume containing ``'DBZH'`` and the label field ``model_name``.
    tilt : int
        Sweep index to plot; it is also written into the output filename.
    outloc : str
        Existing directory that receives the PNG.
    model_name : str
        Name of the label field (e.g. ``'GMM_n4'``) for the lower panel.
    N_cmap : int
        Number of mixture components. It sets the discrete colour map and
        the colour-bar ticks ``1..N_cmap``.

    Notes
    -----
    Reads the module-level settings ``tz``, ``gfilt``, ``minref``/``maxref``,
    ``axislabs``, ``RHIxlims``/``RHIylims`` and ``firename``.

    The output file is ``<outloc>/<firename>_<YYYYmmdd_HHMMSS>_tilt_<tilt>.png``.
    The date is included so that volumes from different days at the same
    time of day do not overwrite each other.
    """
    import matplotlib.pyplot as plt

    # ``is`` compares object identity, which is not guaranteed for strings
    # read from a file (and is a SyntaxWarning on Python 3.8+); compare values.
    if myradar.scan_type != 'rhi':
        return

    # First ray time shifted by ``tz`` hours; assumes the time units are seconds.
    dts = num2date(myradar.time['data'][0] + tz * 60. * 60.,
                   myradar.time['units'])
    gatefilter = give_gatefilter(myradar) if gfilt else None

    # One RadarDisplay draws both panels; RHI plots do not need the map
    # projection machinery (and extra dependencies) of RadarMapDisplay.
    display = pyart.graph.RadarDisplay(myradar)
    fig, (ax_ref, ax_gmm) = plt.subplots(2, 1, figsize=[10, 8])
    try:
        display.plot('DBZH', tilt,
                     ax=ax_ref,
                     vmin=minref, vmax=maxref,
                     title_flag=False,
                     cmap=pyart.graph.cm.RRate11,
                     gatefilter=gatefilter,
                     colorbar_flag=True,
                     axislabels=axislabs)
        display.set_limits(RHIxlims, RHIylims, ax=ax_ref)

        display.plot(model_name, tilt,
                     ax=ax_gmm,
                     title_flag=False,
                     cmap=discrete_cmap(N_cmap, 'pyart_NWSRef'),
                     gatefilter=gatefilter,
                     colorbar_flag=True,
                     axislabels=axislabs,
                     ticks=range(1, N_cmap + 1))
        display.set_limits(RHIxlims, RHIylims, ax=ax_gmm)

        fn = firename + '_' + dts.strftime('%Y%m%d_%H%M%S')
        fig.savefig(outloc + '/' + fn + '_tilt_' + str(tilt) + '.png', dpi=150)
    finally:
        # Close even when plotting fails so long batch runs do not leak figures.
        plt.close(fig)


In [15]:
if __name__ == '__main__':
    # Sorted volume file names; load_plot_scan() indexes into this global.
    fls = sorted(os.listdir(dirr))

    if test:
        fls = fls[:1]

    # Load the fitted GMMs. pickle can execute arbitrary code while loading,
    # so only point model_fn at model files you created or trust.
    with open(model_fn, 'rb') as model_file:
        models = pickle.load(model_file)

    # Create one output directory per model, including any missing parents.
    for model in models:
        model_name = 'GMM_n' + str(model.n_components)
        outloc = '/'.join(['.', 'plots', name_of_plot_run, model_name])
        if not os.path.isdir(outloc):
            print('Directory does not exist yet, adding plot run directory: ' + outloc)
            os.makedirs(outloc)

    print('Begin processing')

    # NOTE(review): workers read fls/models/config as module globals defined
    # in this notebook, which presumably relies on the 'fork' start method
    # (Linux default); confirm before running under 'spawn'.
    p = Pool(3)
    outcomes = []
    try:
        outcomes = p.map(load_plot_scan, range(len(fls)))
    except Exception as e:
        print(e)
        traceback.print_exc()
    finally:
        # Always release the worker processes; map() has returned (or
        # failed) at this point, so no work is lost.
        p.terminate()
        p.join()

    # load_plot_scan reports unreadable files through its return value.
    for msg in outcomes:
        if not msg.startswith('Finished'):
            print(msg)
        

Begin processing


  full_angle_rad = np.arccos(dot_product)
  edges[edges < 0] += 360     # range from [-180, 180] to [0, 360]
  reverse_xaxis = np.all(R < 1.)


In [39]:
# Serial smoke test: run the full read/classify/plot pipeline on fls[0] in
# this process (no worker pool), which makes errors easier to debug.
load_plot_scan(0)

'Finished uq-xpol_rhi_20160223_033611.h5'

In [16]:
# Manually stop the worker pool `p` created in the batch-processing cell
# (depends on that cell having been run in this kernel).
p.terminate()

In [37]:
# Scratch check of the colour-bar ticks plot_k_and_ref passes as
# ticks=range(1, N_cmap+1), here for a 14-component model.
range(1,14+1)

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]