KDE-based prediction: estimating transition temperatures
===

Global module imports
---

In [1]:
from netCDF4 import Dataset
import pandas as pd
import numpy as np
from scipy import stats
import os
# import seaborn as sns
# import matplotlib.pyplot as plt
# import matplotlib.ticker as ticker
# from matplotlib.lines import Line2D
# from matplotlib.text import Text

%load_ext autoreload
%autoreload 2
from model import Model
from GCNet import GCNet
# from plotUtils import PlotUtils
from itertools import izip

pd.options.mode.chained_assignment = None

Function: Remove masked elements (keep items whose selector is truthy)
---

In [2]:
def compress(data, selectors):
    """Yield the items of ``data`` whose matching entry in ``selectors`` is truthy.

    compress('ABCDEF', [1,0,1,0,1,1]) --> A C E F

    Delegates to the standard-library ``itertools.compress`` instead of a
    hand-rolled ``izip`` generator (``izip`` exists only on Python 2).
    The result is still a lazy iterator that stops as soon as either input
    is exhausted, so callers wrapping it in ``list(...)`` are unaffected.
    """
    from itertools import compress as _itercompress
    return _itercompress(data, selectors)

Function: Load and prepare a data file
---

In [3]:
def _modelPath( var, model, period, mms, yrs, suff, site ):
    """Return ``(directory, filename)`` of the netCDF file holding ``var``.

    Gridded files (``site is None``) and per-site "closest point" files live
    in different directories with different naming patterns. ``period`` is
    only used for the gridded model runs, and ``suff`` only for the
    CESM-driven ("cesmle"/"cesmlw") runs.

    Raises ValueError for an unrecognized ``model``.
    """
    if model == "melt":
        modelDir = "/Users/dbr/Documents/gismelt/mote_sfc_melt"
        if site is None:
            modelFN = "Mote_"+yrs+"_"+mms+".nc"
        else:
            modelFN = "wrf_mote_melt_closest_"+yrs+"_"+mms+".nc"
    elif model == "erai":
        if site is None:
            modelDir = "/Volumes/sbp1/model/pwrf/gis_erai/"+period+"/wrf/postproc/tas"
            modelFN = var+"_wrf_erai_"+yrs+"_"+mms+"_d.nc"
        else:
            modelDir = "/Users/dbr/Documents/gismelt/models_vs_obs/gis_erai"
            modelFN = "wrf_erai_"+var+"_closest_"+yrs+"_"+mms+".nc"
    elif model in ( "cesmle", "cesmlw" ):
        # both CESM-driven ensembles share one naming scheme; an optional
        # suffix goes right before the ".nc" extension
        tail = ".nc" if suff is None else "_"+suff+".nc"
        if site is None:
            modelDir = "/Volumes/sbp1/model/pwrf/gis_"+model+"/"+period+"/wrf/postproc/tas/ens"
            modelFN = var+"_wrf_"+model+"_ens_"+yrs+"_"+mms+"_d"+tail
        else:
            modelDir = "/Users/dbr/Documents/gismelt/models_vs_obs/gis_"+model
            modelFN = "wrf_"+model+"_"+var+"_closest_"+yrs+"_"+mms+tail
    else:
        # previously fell through and crashed with UnboundLocalError
        raise ValueError( "loadData: unknown model '%s'" % model )
    return modelDir, modelFN


def _applyMask( X1d, mask ):
    """Filter the flattened values ``X1d`` according to ``mask``.

    ``mask`` is None (no filtering), ``("lsm",)`` (keep points where the
    module-level ``lsm`` list is truthy) or ``("orog", elev)`` (land points
    whose module-level ``orog`` value is above ``elev`` when ``elev > 0``,
    otherwise below ``abs(elev)``). ``lsm``/``orog`` must be defined as
    module-level globals before a mask is used.

    Returns a list, or None (after printing a message) for an unknown mask.
    """
    if mask is None:
        return X1d

    maskVar = mask[0]
    if maskVar == "lsm":  # masking against landsea mask
        return list( compress( X1d, lsm ) )

    if maskVar == "orog":  # masking against orography (land points only)
        X1dLSM = np.array( list( compress( X1d, lsm ) ) )
        orogLSM = np.array( list( compress( orog, lsm ) ) )
        maskElev = mask[1]
        if maskElev > 0:
            keep = orogLSM > abs(maskElev)
        else:
            keep = orogLSM < abs(maskElev)
        return X1dLSM[ np.where( keep )[0] ].tolist()

    print( "Mask variable not recognized" )
    return None


def loadData( var, model, period, mms, yrs, suff, mask, site = None ):
    """Load ``var`` for one model/period/month and return it as a flat list.

    Parameters
        var    : variable name to read (e.g. "tas").
        model  : "melt", "erai", "cesmle" or "cesmlw"; selects the file set.
        period : run period directory (e.g. "historical"); gridded data only.
        mms    : month string such as "06" or "JJA".
        yrs    : year-range string such as "1986-2015".
        suff   : optional filename suffix (CESM-driven runs only) or None.
        mask   : optional mask spec, see ``_applyMask``.
        site   : None for gridded data, "All" for every AWS site, or one
                 AWS site name.

    Returns the flattened values (list), or None when ``mask`` names an
    unknown mask variable. Raises ValueError for an unknown ``model``.
    """
    modelDir, modelFN = _modelPath( var, model, period, mms, yrs, suff, site )
    path = modelDir+"/"+modelFN

    if site is None:
        # gridded model output; NOTE(review): meaning of the -2 argument to
        # Model.loadData is not visible here -- confirm in model.py
        M = Model("wrf_geog.nc", path)
        X = M.loadData( var, True, -2 )
    else:
        # per-site series; presumably one column per AWS site -- confirm in GCNet
        A = GCNet("site_info.nc", path)
        A.loadMeta()
        M = A.loadData( var, True )
        X = M if site == "All" else M[site]

    # flatten to a plain 1-d list of values
    X1d = X.values.reshape(-1,).tolist()

    return _applyMask( X1d, mask )

Function: Make Kernel Density Estimate plot
---

In [4]:
def plotData( X, bw = None ):
    """Draw a seaborn kernel-density curve of ``X``.

    ``bw`` falls back to 0.2 when not given; the curve is clipped to the
    data range (``cut = 0``) and drawn with the module-level ``linewidth``.
    Needs ``sns`` (seaborn) in scope -- its import is commented out at the
    top of this notebook and must be re-enabled before calling this.
    NOTE(review): newer seaborn releases replaced the ``bw`` keyword with
    ``bw_method``/``bw_adjust`` -- confirm against the installed version.
    """
    bandwidth = 0.2 if bw is None else bw
    sns.kdeplot( X, cut = 0, bw = bandwidth, linewidth = linewidth )

Function: Keep elements whose mask value matches
---

In [5]:
def maskData( X, Y, mask ):
    """Keep the paired entries of ``X`` and ``Y`` at which ``Y == mask``.

    Returns a tuple ``(X2, Y2)`` of lists holding the selected items, in
    their original order. If ``X`` is longer than ``Y`` its extra items
    are never selected.
    """
    hitIndices = set( np.where( np.array(Y) == mask )[0].tolist() )
    X2 = [ x for i, x in enumerate(X) if i in hitIndices ]
    Y2 = [ y for i, y in enumerate(Y) if i in hitIndices ]
    return ( X2, Y2 )

Function: Drop NaN (or out-of-range fill value) elements
---

In [6]:
def dropNaN( X, Y, theNaN = None ):
    """Remove paired entries of ``X`` and ``Y`` where ``Y`` is invalid.

    What counts as invalid depends on ``theNaN``:
      * None          -> ``Y`` is NaN
      * positive      -> ``Y > theNaN``  (large fill value)
      * zero/negative -> ``Y < theNaN``  (negative fill value)

    Only ``Y`` is tested; ``X`` is filtered in lockstep. Returns a tuple
    ``(X2, Y2)`` of lists, truncated to the shorter input.
    """
    values = np.array(Y)
    if theNaN is None:
        invalid = np.isnan(values)
    elif theNaN > 0:
        invalid = values > theNaN
    else:
        invalid = values < theNaN

    # 1.0 keeps an entry, 0.0 drops it
    keep = np.where( invalid, 0., 1. ).tolist()
    return ( list( compress( X, keep ) ), list( compress( Y, keep ) ) )

Function: Replace elements below a threshold with NaN
---

In [7]:
def maskDataBelow( X, mask ):
    """Return a copy of ``X`` with every element below ``mask`` set to NaN.

    Returns ``(X2, mskix)``: the masked copy and the indices (along the
    first axis) that were replaced. ``X`` itself is never modified.
    """
    # np.array copies; promote non-float data so NaN can be stored
    # (assigning NaN into an integer array raises ValueError)
    X2 = np.array( X )
    if not np.issubdtype( X2.dtype, np.floating ):
        X2 = X2.astype( float )

    # compare the array, not the raw input: a plain list compared with a
    # number does not give an element-wise result
    mskix = np.where( X2 < mask )[0]
    X2[ mskix ] = np.nan
    return ( X2, mskix )

Function: Clean up time series
---

In [8]:
def cleanTimeSeries( X, Y, debug = None ):
    """Drop the days on which either series is unusable; return ``(X, Y)``.

    Three filtering passes run in order, each removing a day from both
    series at once:
      1. X values above 1.e3 (treated as X's _FillValue),
      2. X values that are NaN,
      3. Y values below -1. The original comments call these Y "NaN" days;
         presumably Y marks missing data with a negative fill value -- TODO
         confirm. Actual NaNs in Y are not removed by this pass.
    Any non-None ``debug`` prints the length before and after each pass.
    """
    # (label, test X rather than Y, dropNaN threshold)
    passes = (
        ( "X FillValue-filtered", True, 1.e3 ),
        ( "X NaN-filtered", True, None ),
        ( "Y NaN-filtered", False, -1. ),
    )
    for label, testX, threshold in passes:
        before = len(X)
        if testX:
            # dropNaN tests its second argument, so X goes second
            Y, X = dropNaN( Y, X, threshold )
        else:
            X, Y = dropNaN( X, Y, threshold )
        if debug is not None:
            print( "Orig len: %d, %s len: %d" % ( before, label, len(X) ) )

    return ( X, Y )

Function: Convert melt/no-melt KDEs to a transition function
---

In [9]:
def transitionFN( xs, Xm, Xnm, smthIX = 20 ):
    """Turn melt / no-melt KDE curves into a melt-probability transition.

    Parameters
        xs     : grid the KDEs were evaluated on (indexed by position).
        Xm     : melt KDE values on ``xs``.
        Xnm    : no-melt KDE values on ``xs``.
        (assumed: equal-length 1-d float arrays -- TODO confirm for new callers)
        smthIX : number of grid points the transition extends beyond each peak.

    Returns ``(Pm, Pnm, xLwr, xUpr, m, b)``:
        Pm, Pnm    : P(melt) and 1 - P(melt) on ``xs``.
        xLwr, xUpr : ``xs`` at the adjusted no-melt / melt peak indices,
                     rounded to 0.1.
        m, b       : slope and intercept of the line through (xLwr, 0) and
                     (xUpr, 1).

    Raises ZeroDivisionError (under Python 2, where round() returns a plain
    float) when xLwr == xUpr.
    """

    # copy input vars
    Xm2 = np.copy(Xm)
    Xnm2 = np.copy(Xnm)
    
    # find maximum values (indices of each KDE's peak)
    ixXm = np.argmax(Xm)
    ixXnm = np.argmax(Xnm)
#     print "Values:  Max no-melt", Xnm[ixXnm], "Max melt", Xm[ixXm]
#     print "Indices: Max no-melt", ixXnm, "Max melt", ixXm

    # set pieces of each var outside overlap to nan: only [ixXnm, ixXm]
    # (no-melt peak .. melt peak) keeps finite values.
    # NOTE(review): if ixXm < ixXnm (melt peak left of the no-melt peak) these
    # two slices cover the whole array, so Xm2/Xnm2 become all-NaN and xUpr
    # may come out below xLwr.
    Xm2[0:ixXnm] = np.nan
    Xm2[(ixXm+1):] = np.nan
    Xnm2[(ixXm+1):] = np.nan
    Xnm2[0:ixXnm] = np.nan

    # calculate P(melt) as the ratio of the KDEs (NaN outside the overlap)
    Pm = Xm2 / ( Xm2 + Xnm2 )
    minPm = Pm[ixXnm]
    maxPm = Pm[ixXm]
    
    # adjusted indices: widen the transition by smthIX points on each side,
    # clipped to the array bounds
    ixXmAdj  = ixXm + smthIX
#     print "Unadjusted ixXmAdj", ixXmAdj
    if ixXmAdj >= len( Xm ):
        ixXmAdj = len( Xm ) - 1
#         print "Adjusted ixXmAdj", ixXmAdj
    ixXnmAdj = ixXnm - smthIX
    if ixXnmAdj < 0:
        ixXnmAdj = 0
    
    # smooth the part between min/max and 0/1
    
    # lhs: 0 below the widened range, ramp 0 -> minPm up to the no-melt peak.
    # NOTE(review): when ixXnm < smthIX the slice was clipped at 0 and is
    # shorter than the smthIX-point infill, so this assignment raises and
    # only the printed message reports it.
    Pm[0:ixXnmAdj] = 0
    infill = np.linspace(0, minPm, num = smthIX)
    try:
        Pm[ixXnmAdj:ixXnm] = infill
    except Exception, error:
        print "transitionFN: LHS infill assign failed", error
         
    # rhs: 1 above the widened range, ramp maxPm -> 1 after the melt peak
    # (point count shrunk to the available slots when ixXmAdj was clipped)
    Pm[ixXmAdj:] = 1
    rhsIX = range( ixXm, ixXmAdj )
    rhsSmthIX = smthIX
    if len(rhsIX) < rhsSmthIX:
        rhsSmthIX = len(rhsIX)
    infill = np.linspace(maxPm, 1, num = rhsSmthIX)
    try:
        Pm[rhsIX] = infill
    except Exception, error:
        print "transitionFN: RHS infill assign failed", error

    # make in-between linear: straight 0 -> 1 ramp over [ixXnmAdj, ixXmAdj).
    # NOTE(review): when ixXnm <= ixXm this range contains both infill
    # regions above and the ratio section, so it overwrites them all. The
    # returned Pm is then 0 before ixXnmAdj, this linear ramp, and 1 from
    # ixXmAdj on; the KDE ratio only fed the overwritten infills.
    lhsIX = range( ixXnmAdj, ixXmAdj )
#     print "lhsix",lhsIX
#     npts = (ixXm + smthIX) - (ixXnm - smthIX)
    npts = len( lhsIX )
    infill = np.linspace(0, 1, num = npts)
#     print "infill", len(infill)
    try:
        Pm[lhsIX] = infill
    except Exception, error:
        print "transitionFN: center infill failed", error
    
    # likelihood of no-melt = 1 - P(melt)
    Pnm = 1 - Pm
    
    # transition point values: xs at the widened bounds, rounded to 0.1
#     print "Original values: Lower", xs[ixXnmAdj], "Upper", xs[ixXmAdj] 
    xLwr = round( xs[ixXnmAdj],1 )
    xUpr = round( xs[ixXmAdj],1 )

    # calculate linear regression parameters: line through (xLwr, 0) and
    # (xUpr, 1); divides by zero if the two rounded bounds coincide
    m = (1. - 0.) / (xUpr - xLwr)
    b = 1. - m*xUpr
    
    return Pm, Pnm, xLwr, xUpr, m, b

Function: Calculate lower/upper boundary temperatures
---

In [10]:
def _kdeTransition( X, Xm, Xnm, nXMelt, minMeltPts, smthIX, bwMethod ):
    """Fit melt / no-melt KDEs and return ``(xLwr, xUpr)`` or ``(None, None)``.

    Each stage (x-grid, melt KDE, no-melt KDE, transition function) prints
    its own failure message and gives up, matching the per-stage reporting
    of the original nested try blocks.
    """
    try:
        # KDE x-axis spanning the cleaned data at 0.1 spacing
        xs = np.arange( min(X), max(X), 0.1 )
    except Exception as error:
        # not expected: callers only get here with a non-empty X
        print( "*** KDE x-axis range failed *** %s" % error )
        return None, None

    try:
        Ym = np.array( stats.gaussian_kde( Xm, bw_method = bwMethod )( xs ) )
    except Exception as error:
        # e.g. a site that barely melted during the period; only report it
        # when there were more melt points than the required minimum
        if nXMelt > minMeltPts:
            print( "*** Melt KDE failed *** %s" % error )
        return None, None

    try:
        Ynm = np.array( stats.gaussian_kde( Xnm, bw_method = bwMethod )( xs ) )
    except Exception as error:
        print( "*** Non-melt KDE failed *** %s" % error )
        return None, None

    try:
        _Pm, _Pnm, xLwr, xUpr, _m, _b = transitionFN( xs, Ym, Ynm, smthIX )
    except Exception as error:
        print( "*** Transition fn failed *** %s" % error )
        return None, None

    return xLwr, xUpr


def calcLwrUpr( var, period, mms, yrsHist, suff, mask, site, minMeltPts = 1, smthIX = 10, bwMethod = None ):
    """Estimate the lower/upper melt-transition values of ``var`` for one site.

    Loads the melt record (model "melt") and ``var`` (model "erai") with
    ``loadData``, removes fill/NaN days with ``cleanTimeSeries``, and splits
    ``var`` into melt (melt == 1) and no-melt (melt == 0) samples. It then
    fits a Gaussian KDE to each group and converts them with
    ``transitionFN``.

    Parameters
        var, period, mms, yrsHist, suff, mask, site : forwarded to loadData.
        minMeltPts : minimum number of melt samples needed to fit the KDEs.
        smthIX     : forwarded to transitionFN.
        bwMethod   : ``bw_method`` for scipy.stats.gaussian_kde; when None the
                     module-level ``bw`` setting is used (the original
                     behaviour).

    Returns ``(xLwr, xUpr, nXMelt, nXNoMelt)``. ``xLwr``/``xUpr`` are None
    when no transition could be estimated. ``nXNoMelt`` is only counted
    once there are at least ``minMeltPts`` melt samples (it stays 0
    otherwise).
    """
    # default return values
    xLwr = xUpr = None
    nXMelt = nXNoMelt = 0

    melt = loadData( "greenland_surface_melt", "melt", period, mms, yrsHist, suff, mask, site )
    X = loadData( var, "erai", period, mms, yrsHist, suff, mask, site )
    if X is None or melt is None:
        # loadData already reported the problem (unrecognized mask)
        return xLwr, xUpr, nXMelt, nXNoMelt

    # drop FillValues, NaNs
    X, melt = cleanTimeSeries( X, melt, None )
    if len(X) == 0:
        # a site with no melt measurements is emptied by cleanTimeSeries
        return xLwr, xUpr, nXMelt, nXNoMelt

    # stratify into melt / no-melt using the melt data
    Xm, _ = maskData( X, melt, 1 )
    nXMelt = len(Xm)
    if nXMelt < minMeltPts:
        # too few melt samples for a KDE (short data series)
        return xLwr, xUpr, nXMelt, nXNoMelt

    Xnm, _ = maskData( X, melt, 0 )
    nXNoMelt = len(Xnm)

    kdeBW = bw if bwMethod is None else bwMethod
    xLwr, xUpr = _kdeTransition( X, Xm, Xnm, nXMelt, minMeltPts, smthIX, kdeBW )
    return xLwr, xUpr, nXMelt, nXNoMelt

Process variables, collect statistics
===

tas
---

In [11]:
# =================================================================
# Run settings. They stay module-level on purpose: calcLwrUpr reads
# `bw` and plotData reads `linewidth` as globals.
# =================================================================
var = "tas"                      # variable to process

# averaging window (alternative used earlier: 1996-2005)
yr1a, yr2a = "1986", "2015"
yrsHist = "%s-%s" % ( yr1a, yr2a )
period = "historical"

# Optional masking data: to use a land/sea or orography mask, define the
# module-level `lsm` / `orog` lists first, e.g.
#   lsm  = Model("wrf_geog.nc", "landmask.nc").loadData( "LANDMASK", True ).values.reshape(-1,).tolist()
#   orog = Model("wrf_geog.nc", "lat_lon_orog.nc").loadData( "orog", True ).values.reshape(-1,).tolist()

# KDE bandwidth and plot line width
linewidth = 4
bw = 0.4

# alternate input datasets: none
mask = None
suff = None

# AWS site metadata
Aglobal = GCNet("site_info.nc", None)
Aglobal.loadMeta()
awsNames = Aglobal.getNames()

# Sites to process: the pooled "All" pseudo-site first, then each AWS site.
# Other options: ( "Swiss Camp", ) for one site, or prepend None to also
# process every grid point (reported as "Grid").
sites = [ "All" ] + list( awsNames )

# One CSV line per month/site:
#   site, month, xLwr, xUpr, melt samples, no-melt samples
# A trailing "**" marks lines with no estimate (bounds printed as 0.0).
for mm in ( 6, 7, 8, "JJA" ):
    mms = mm if isinstance( mm, str ) else "%02d" % mm
    for site in sites:
        site2 = "Grid" if site is None else site
        # minMeltPts = 10, smthIX = 0 (no widening beyond the KDE peaks)
        xLwr, xUpr, ntasM, ntasNM = calcLwrUpr( var, period, mms, yrsHist, suff, mask, site, 10, 0 )
        if xLwr is None:
            row = "%s,%s,%.1f,%.1f,%d,%d,**" % ( site2, mms, 0., 0., ntasM, ntasNM )
        else:
            row = "%s,%s,%.1f,%.1f,%d,%d" % ( site2, mms, xLwr, xUpr, ntasM, ntasNM )
        print( row )
    print( " " )

Grid,06,-7.7,-0.7,1124011,5568483
All,06,-7.8,-0.1,3493,14480
Swiss Camp,06,-4.2,-0.1,598,258
Crawford Point1,06,-5.4,-1.3,116,736
NASA-U,06,-7.3,-1.4,16,844
GITS,06,-5.9,-2.2,18,843
Humboldt,06,-6.4,-0.9,11,850
Summit,06,0.0,0.0,0,0,**
Tunu-N,06,0.0,0.0,8,0,**
DYE-2,06,-7.9,-3.0,187,661
JAR1,06,-3.1,0.8,689,164
Saddle,06,-9.5,-3.6,71,777
South Dome,06,-11.2,-4.9,24,827
NASA-E,06,0.0,0.0,2,0,**
Crawford Point2,06,-5.5,-1.7,129,726
NGRIP,06,0.0,0.0,1,0,**
NASA-SE,06,-9.8,-6.6,41,809
KAR,06,-9.8,-4.7,41,813
JAR2,06,0.0,0.0,0,0,**
KULU,06,-1.9,0.1,296,551
JAR3,06,0.0,0.0,0,0,**
Aurora,06,-6.3,-2.3,379,473
Petermann GL,06,-0.7,0.5,555,305
Petermann ELA,06,-3.6,-0.5,310,551
NEEM,06,0.0,0.0,1,0,**
 
Grid,07,-6.9,-0.4,1769914,5204941
All,07,-7.2,0.3,5333,13388
Swiss Camp,07,-1.6,-0.0,691,197
Crawford Point1,07,-4.8,-1.2,177,716
NASA-U,07,-6.5,-1.6,53,842
GITS,07,-4.5,-2.0,120,776
Humboldt,07,-5.1,-1.2,71,825
Summit,07,0.0,0.0,5,0,**
Tunu-N,07,-6.8,-2.0,34,863
DYE-2,07,-7.7,-2.1,366,518
JAR1,0

tasmax
---

In [12]:
# =================================================================
# Run settings. They stay module-level on purpose: calcLwrUpr reads
# `bw` and plotData reads `linewidth` as globals.
# =================================================================
var = "tasmax"                   # variable to process

# averaging window (alternative used earlier: 1996-2005)
yr1a, yr2a = "1986", "2015"
yrsHist = "%s-%s" % ( yr1a, yr2a )
period = "historical"

# Optional masking data: to use a land/sea or orography mask, define the
# module-level `lsm` / `orog` lists first, e.g.
#   lsm  = Model("wrf_geog.nc", "landmask.nc").loadData( "LANDMASK", True ).values.reshape(-1,).tolist()
#   orog = Model("wrf_geog.nc", "lat_lon_orog.nc").loadData( "orog", True ).values.reshape(-1,).tolist()

# KDE bandwidth and plot line width
linewidth = 4
bw = 0.4

# alternate input datasets: none
mask = None
suff = None

# AWS site metadata
Aglobal = GCNet("site_info.nc", None)
Aglobal.loadMeta()
awsNames = Aglobal.getNames()

# Sites to process: the pooled "All" pseudo-site first, then each AWS site.
# Other options: ( "Saddle", ) for one site, or prepend None to also
# process every grid point (reported as "Grid").
sites = [ "All" ] + list( awsNames )

# One CSV line per month/site:
#   site, month, xLwr, xUpr, melt samples, no-melt samples
# A trailing "**" marks lines with no estimate (bounds printed as 0.0).
for mm in ( 6, 7, 8, "JJA" ):
    mms = mm if isinstance( mm, str ) else "%02d" % mm
    for site in sites:
        site2 = "Grid" if site is None else site
        # minMeltPts = 10, smthIX = 0 (no widening beyond the KDE peaks)
        xLwr, xUpr, ntasM, ntasNM = calcLwrUpr( var, period, mms, yrsHist, suff, mask, site, 10, 0 )
        if xLwr is None:
            row = "%s,%s,%.1f,%.1f,%d,%d,**" % ( site2, mms, 0., 0., ntasM, ntasNM )
        else:
            row = "%s,%s,%.1f,%.1f,%d,%d" % ( site2, mms, xLwr, xUpr, ntasM, ntasNM )
        print( row )
    print( " " )

Grid,06,-3.7,0.9,1124011,5568483
All,06,-2.9,1.2,3493,14480
Swiss Camp,06,0.3,1.8,598,258
Crawford Point1,06,-0.5,0.8,116,736
NASA-U,06,-3.3,0.2,16,844
GITS,06,-3.4,-0.1,18,843
Humboldt,06,-2.4,0.3,11,850
Summit,06,0.0,0.0,0,0,**
Tunu-N,06,0.0,0.0,8,0,**
DYE-2,06,-1.3,0.6,187,661
JAR1,06,1.0,2.2,689,164
Saddle,06,-4.2,-0.0,71,777
South Dome,06,-6.4,-0.9,24,827
NASA-E,06,0.0,0.0,2,0,**
Crawford Point2,06,-0.7,0.7,129,726
NGRIP,06,0.0,0.0,1,0,**
NASA-SE,06,-4.5,-0.0,41,809
KAR,06,-4.7,0.1,41,813
JAR2,06,0.0,0.0,0,0,**
KULU,06,1.0,1.3,296,551
JAR3,06,0.0,0.0,0,0,**
Aurora,06,-0.2,0.9,379,473
Petermann GL,06,0.9,2.4,555,305
Petermann ELA,06,0.2,1.3,310,551
NEEM,06,0.0,0.0,1,0,**
 
Grid,07,-2.5,1.0,1769914,5204941
All,07,-2.1,1.4,5333,13388
Swiss Camp,07,0.8,1.6,691,197
Crawford Point1,07,0.1,0.7,177,716
NASA-U,07,-2.7,0.1,53,842
GITS,07,-1.9,0.2,120,776
Humboldt,07,-1.1,0.5,71,825
Summit,07,0.0,0.0,5,0,**
Tunu-N,07,-1.8,0.6,34,863
DYE-2,07,0.1,0.6,366,518
JAR1,07,1.1,2.1,806,80
Saddle,07,-