# In [None]:
#import standard libraries
import os
import glob
from datetime import datetime

#import other libraries
import numpy as np
from cpol_processing import processing as cpol_prc
from pyhail import hsda, hdr, mesh, common
import pyart

#turn off warnings
import warnings
warnings.filterwarnings('ignore')


def _merge_hca_hsda(hca, hsda_class, hca_hsda_idx):
    """Return a copy of the HCA classes with the HSDA hail-size classes overlaid.

    Gates where `hsda_class` is 1, 2 or 3 are set to hca_hsda_idx[0],
    hca_hsda_idx[1] or hca_hsda_idx[2] respectively; every other gate keeps its
    HCA class. `hca` itself is not modified.
    """
    #copy so the HCA field stored on the radar object is not overwritten
    hca_hsda = hca.copy()
    for hsda_value, merged_class in zip((1, 2, 3), hca_hsda_idx):
        hca_hsda[np.isin(hsda_class, hsda_value)] = merged_class
    return hca_hsda


def _fetch_sounding(radar, radar_id, dt):
    """Fetch the sounding for one radar volume and return the path of the local file.

    The volume time `dt` is floored to the start of its 6-hour block and the
    radar site latitude/longitude are read from `radar`.
    """
    import tempfile

    adj_date = dt.replace(hour=(dt.hour // 6) * 6, minute=0, second=0, microsecond=0)
    radar_lat = float(radar.latitude['data'][0])
    radar_lon = float(radar.longitude['data'][0])
    #name the file after the radar and the full volume time so different volumes
    #do not overwrite (or delete) each other's sounding
    outffn = os.path.join(tempfile.gettempdir(),
                          'sonde_%s_%s.nc' % (radar_id, dt.strftime('%Y%m%dT%H%M%S')))
    #NOTE(review): get_sounding_data is defined outside this section; it is assumed to
    #write the sounding to `outffn` and return the path of the file it wrote.
    #TODO: if the GFS extract fails, pull data from the previous year in ERA-Interim.
    return get_sounding_data(adj_date, radar_lat, radar_lon, outffn)


def _correct_radar(radar, fieldn, gatefilter, sonde_ffn):
    """Add sounding, noise, phase and attenuation corrected fields to `radar`.

    The gate filter is applied to the corrected reflectivity, ZDR, KDP and
    RHOHV fields at the end. `radar` is modified in place.
    """
    #build temperature/height/SNR information (temperature from the radiosonde nc file)
    height, temperature, snr = cpol_prc.radar_codes.snr_and_sounding(
        radar, sonde_ffn, refl_field_name=fieldn['dbzh'], temp_field_name='temp')
    radar.add_field(fieldn['temp'], temperature, replace_existing=True)
    radar.add_field(fieldn['alt'], height, replace_existing=True)
    radar.add_field(fieldn['snr'], snr, replace_existing=True)

    #add a fake NCP field (1 where SNR > 7.5, else 0) if the volume has none
    if fieldn['ncp'] not in radar.fields:
        ncp = pyart.config.get_metadata('normalized_coherent_power')
        fake_ncp_data = np.zeros_like(snr['data'])
        fake_ncp_data[snr['data'] > 7.5] = 1
        ncp['data'] = fake_ncp_data
        ncp['description'] = "THIS FIELD IS FAKE. SHOULD BE REMOVED!"
        radar.add_field(fieldn['ncp'], ncp)

    #RHOHV noise correction
    rho_corr = cpol_prc.radar_codes.correct_rhohv(radar, rhohv_name=fieldn['rhv'], snr_name=fieldn['snr'])
    radar.add_field_like(fieldn['rhv'], fieldn['rhv_corr'], rho_corr, replace_existing=True)

    #ZDR noise correction
    corr_zdr = cpol_prc.radar_codes.correct_zdr(radar, zdr_name=fieldn['zdr'], snr_name=fieldn['snr'])
    radar.add_field_like(fieldn['zdr'], fieldn['zdr_corr'], corr_zdr, replace_existing=True)

    #unfold PHIDP
    phi_unfold = cpol_prc.phase.unfold_raw_phidp(radar, refl_field=fieldn['dbzh'], ncp_field=fieldn['ncp'],
                                                 rhv_field=fieldn['rhv_corr'], phi_name=fieldn['phi'])
    radar.add_field(fieldn['phi_unfold'], phi_unfold, replace_existing=True)

    #recalculate PHIDP and KDP (Bringi)
    phimeta, kdpmeta = cpol_prc.phase.phidp_bringi(radar, gatefilter, refl_field=fieldn['dbzh'], ncp_name=fieldn['ncp'],
                                                   rhohv_name=fieldn['rhv_corr'], unfold_phidp_name=fieldn['phi_unfold'])
    radar.add_field(fieldn['phi_bringi'], phimeta, replace_existing=True)
    radar.add_field(fieldn['kdp_bringi'], kdpmeta, replace_existing=True)
    radar.fields[fieldn['phi_bringi']]['long_name'] = "corrected_differential_phase"
    radar.fields[fieldn['kdp_bringi']]['long_name'] = "corrected_specific_differential_phase"

    #ZH attenuation correction; `phidp_field` must be a differential phase field,
    #so pass the corrected PHIDP rather than KDP
    atten_spec, zh_corr = cpol_prc.attenuation.correct_attenuation_zh_pyart(
        radar, refl_field=fieldn['dbzh'], ncp_field=fieldn['ncp'],
        rhv_field=fieldn['rhv_corr'], phidp_field=fieldn['phi_bringi'])
    radar.add_field(fieldn['dbzh_corr'], zh_corr, replace_existing=True)
    radar.add_field(fieldn['a_dbz'], atten_spec, replace_existing=True)

    #ZDR attenuation correction
    atten_spec_zdr, zdr_corr = cpol_prc.attenuation.correct_attenuation_zdr(
        radar, zdr_name=fieldn['zdr_corr'], kdp_name=fieldn['kdp_bringi'], alpha=0.016)
    radar.add_field_like(fieldn['zdr'], fieldn['zdr_corr'], zdr_corr, replace_existing=True)
    radar.add_field(fieldn['a_zdr'], atten_spec_zdr, replace_existing=True)

    #apply the RHOHV gate filter to the corrected fields
    for key in ('dbzh_corr', 'zdr_corr', 'kdp_bringi', 'rhv_corr'):
        field = radar.fields[fieldn[key]]
        field['data'] = cpol_prc.filtering.filter_hardcoding(field['data'], gatefilter)


def _classify_radar(radar, fieldn, sonde_ffn, hca_hail_idx, hca_hsda_idx):
    """Add the CSU HCA, HSDA, HDR and merged HCA+HSDA fields to `radar` (in place)."""
    #CSU HCA
    hydro_class = cpol_prc.hydrometeors.hydrometeor_classification(
        radar, refl_name=fieldn['dbzh_corr'], zdr_name=fieldn['zdr_corr'],
        kdp_name=fieldn['kdp_bringi'], rhohv_name=fieldn['rhv_corr'],
        height_name=fieldn['alt'], temperature_name=fieldn['temp'])
    radar.add_field(fieldn['hca'], hydro_class, replace_existing=True)

    #HSDA (`hsda` here is the pyhail module; no local variable shadows it)
    hsda_meta = hsda.main(radar, sonde_ffn, hca_hail_idx, 0, ref_name=fieldn['dbzh_corr'], zdr_name=fieldn['zdr_corr'],
                          rhv_name=fieldn['rhv_corr'], phi_name=fieldn['phi_bringi'],
                          snr_name=fieldn['snr'], cbb_name=fieldn['cbb'],
                          hca_name=fieldn['hca'])
    radar.add_field(fieldn['hsda'], hsda_meta, replace_existing=True)

    #HDR
    hdr_meta = hdr.main(radar, ref_name=fieldn['dbzh_corr'], zdr_name=fieldn['zdr_corr'])
    radar.add_field(fieldn['hdr'], hdr_meta, replace_existing=True)

    #merge HCA with HSDA
    hca_hsda = _merge_hca_hsda(radar.fields[fieldn['hca']]['data'],
                               radar.fields[fieldn['hsda']]['data'],
                               hca_hsda_idx)
    the_comments = ("1: Drizzle; 2: Rain; 3: Ice Crystals; 4: Aggregates; "
                    "5: Wet Snow; 6: Vertical Ice; 7: LD Graupel; 8: HD Graupel; 9: NOT USED; 10: Big Drops; "
                    "11: Small Hail (< 25 mm); 12: Large Hail (25 - 50 mm); 13: Giant Hail (> 50 mm)")
    hca_hsda_meta = {'data': hca_hsda, 'units': ' ', 'long_name': 'Hydrometeor classification + HSDA',
                     'standard_name': 'Hydrometeor_ID_HSDA', 'comments': the_comments}
    radar.add_field(fieldn['hca_hsda'], hca_hsda_meta, replace_existing=True)


def retrieval_hail(odim_ffn, srtm_ffn, hsda_geotif_ffn, hdr_geotif_ffn, mesh_geotif_ffn, dbzh_offset, zdr_offset, out_alt_m):
    """Run the hail retrieval chain on one ODIM HDF5 radar volume.

    Steps: load the volume, compute beam blockage, fetch a sounding, apply
    calibration offsets, correct the polarimetric fields, classify hydrometeors
    (CSU HCA + HSDA + HDR), grid the results and compute MESH.

    Parameters
    ----------
    odim_ffn : str
        Path of the ODIM HDF5 volume.
    srtm_ffn : str
        SRTM file passed to calc_beam_blocking.
    hsda_geotif_ffn, hdr_geotif_ffn, mesh_geotif_ffn : str
        Not used here; kept for interface compatibility (GeoTIFF output is done
        by grid_to_geotif).
    dbzh_offset : float
        Offset added to the DBZH field before processing.
    zdr_offset : float
        Offset added to the ZDR field before processing.
    out_alt_m : float
        Altitude (m) of the single-level discrete HCA+HSDA grid.

    Returns
    -------
    radar, grid_dis, grid_cts
        The processed radar object, the single-level nearest-neighbour grid of
        the HCA+HSDA field, and the Barnes-interpolated grid of continuous
        fields with MESH added.

    The temporary sounding file is deleted before returning, also on error.
    """
    ###########################################################
    # Configuration
    ###########################################################

    #field names (used to map to radar object fields)
    fieldn = {'dbzh': 'DBZH',
              'dbzh_corr': 'DBZH_CORR',
              'zdr': 'ZDR',
              'zdr_corr': 'ZDR_CORR',
              'phi': 'PHIDP',
              'phi_unfold': 'PHI_UNF',
              'phi_bringi': 'PHIDP_BRINGI',
              'kdp': 'KDP',
              'kdp_bringi': 'KDP_BRINGI',
              'rhv': 'RHOHV',
              'ncp': 'NCP',
              'a_dbz': 'SPEC_ATT_REFL',
              'a_zdr': 'SPEC_ATT_DIFF',
              'rhv_corr': 'RHOHV_CORR',
              'temp': 'TEMPERATURE',
              'alt': 'HEIGHT',
              'snr': 'SNR',
              'cbb': 'CBB',
              'hca': 'HCA',
              'hail_ke': 'HAIL_KE',
              'shi': 'SHI',
              'posh': 'POSH',
              'mesh': 'MESH',
              'hdr': 'HDR',
              'hsda': 'HSDA',
              'hca_hsda': 'HCA_HSDA'}

    #radar id -> sounding station id
    #NOTE(review): currently unused; presumably intended for a UWYO sounding lookup
    sounde_id = {'66': '94578',
                 '71': '94776',
                 '02': '94866',
                 '64': '94672'}

    #hsda vars
    hca_hail_idx = [9]              #list of hail class(es) indices in HCA
    hca_hsda_idx = [11, 12, 13]     #HCA+HSDA classes for HSDA values 1, 2, 3

    #continuous grid
    grid_shape_cts = (41, 201, 201)
    grid_limits_cts = ((0, 20000), (-100000.0, 100000.0), (-100000.0, 100000.0))
    grid_roi_cts = 2000

    #discrete grid (a single level at out_alt_m)
    grid_shape_dis = (1, 201, 201)
    grid_limits_dis = ((out_alt_m, out_alt_m), (-100000.0, 100000.0), (-100000.0, 100000.0))
    grid_roi_dis = 1000

    ###########################################################
    # Load file
    ###########################################################

    radar = pyart.aux_io.read_odim_h5(odim_ffn, file_field_names=True)
    radar_id = radar.metadata['source'][6:8]

    #extract date from the time units (last 20 characters, e.g. 2018-01-01T00:00:00Z)
    date_str = radar.time['units'][-20:]
    dt = datetime.strptime(date_str, '%Y-%m-%dT%H:%M:%SZ')

    #calculate beam blockage for this volume
    #NOTE(review): calc_beam_blocking is defined outside this section
    cbb_meta = calc_beam_blocking(odim_ffn, srtm_ffn)
    radar.add_field(fieldn['cbb'], cbb_meta, replace_existing=True)

    ###########################################################
    # Sounding
    ###########################################################

    sonde_ffn = _fetch_sounding(radar, radar_id, dt)
    try:
        #add calibration offsets
        radar.fields[fieldn['dbzh']]['data'] = radar.fields[fieldn['dbzh']]['data'] + dbzh_offset
        radar.fields[fieldn['zdr']]['data'] = radar.fields[fieldn['zdr']]['data'] + zdr_offset

        #rhohv gatefilter
        gatefilter = pyart.filters.GateFilter(radar)
        gatefilter.exclude_below(fieldn['rhv'], 0.7)

        _correct_radar(radar, fieldn, gatefilter, sonde_ffn)
        _classify_radar(radar, fieldn, sonde_ffn, hca_hail_idx, hca_hsda_idx)

        #grid object for continuous fields
        grid_cts = pyart.map.grid_from_radars(
            radar,
            grid_shape=grid_shape_cts,
            grid_limits=grid_limits_cts,
            weighting_function='Barnes',
            gridding_algo='map_gates_to_grid',
            roi_func='constant', constant_roi=grid_roi_cts,
            fields=[fieldn['dbzh_corr'], fieldn['zdr_corr'],
                    fieldn['kdp_bringi'], fieldn['rhv_corr'],
                    fieldn['hdr'], fieldn['temp'], fieldn['alt']])

        #grid object for discrete fields (nearest neighbour); `fields` must be a list
        grid_dis = pyart.map.grid_from_radars(
            radar,
            grid_shape=grid_shape_dis,
            grid_limits=grid_limits_dis,
            weighting_function='Nearest',
            gridding_algo='map_to_grid',
            roi_func='constant', constant_roi=grid_roi_dis,
            fields=[fieldn['hca_hsda']])

        #MESH, added to the continuous grid object
        #NOTE(review): the original call passed `False` positionally after the
        #`ref_name=` keyword (a SyntaxError) and referenced an undefined `out_ffn`.
        #Arguments are passed positionally in the order that call implied, with no
        #output path (the trailing False is presumably a save flag). Confirm against
        #the installed pyhail.mesh.main signature.
        grid_cts = mesh.main(grid_cts, fieldn, None, sonde_ffn, [], fieldn['dbzh_corr'], False)
    finally:
        #remove the temporary sounding
        if sonde_ffn and os.path.exists(sonde_ffn):
            os.remove(sonde_ffn)

    return radar, grid_dis, grid_cts
    
def grid_to_geotif(grid_dis, grid_cts, hsda_geotif_ffn, hdr_geotif_ffn, mesh_geotif_ffn,
                   out_alt_m=None, hca_hsda_name='HCA_HSDA', hdr_name='HDR', mesh_name='MESH',
                   sonde_ffn=None):
    """Write the HCA+HSDA, HDR and MESH grids to GeoTIFF files.

    Parameters
    ----------
    grid_dis : Grid
        Single-level grid holding the discrete HCA+HSDA field (level 0 is written).
    grid_cts : Grid
        Multi-level grid holding the HDR and MESH fields.
    hsda_geotif_ffn, hdr_geotif_ffn, mesh_geotif_ffn : str
        Output GeoTIFF filenames for HCA+HSDA, HDR and MESH.
    out_alt_m : float, optional
        Altitude (m) at which HDR is written; the level of `grid_cts` closest to
        it is used. Defaults to the first (single) level altitude of `grid_dis`.
    hca_hsda_name, hdr_name, mesh_name : str, optional
        Names of the three product fields in the grids.
    sonde_ffn : str, optional
        Sounding file to delete after the GeoTIFFs are written, if it exists.
    """
    ###########################################################
    # Geotiff output Processing and Output
    ###########################################################

    #the discrete grid is built on a single level at the output altitude
    if out_alt_m is None:
        out_alt_m = grid_dis.z['data'][0]

    #index of the continuous-grid level nearest to out_alt_m (used for HDR)
    grid_z = grid_cts.z['data']
    z_idx = int(np.abs(grid_z - out_alt_m).argmin())

    #discrete grid only computed on one level
    pyart.io.output_to_geotiff.write_grid_geotiff(grid_dis, hsda_geotif_ffn, hca_hsda_name,
                                                  level=0)

    pyart.io.output_to_geotiff.write_grid_geotiff(grid_cts, hdr_geotif_ffn, hdr_name,
                                                  level=z_idx)

    #zeroth level is the valid MESH level
    pyart.io.output_to_geotiff.write_grid_geotiff(grid_cts, mesh_geotif_ffn, mesh_name,
                                                  level=0)

    #remove sounding
    if sonde_ffn is not None and os.path.exists(sonde_ffn):
        os.remove(sonde_ffn)