In [16]:
def state_from_gfse(statevars=['t2'], bounds=None):
    """ Build an xray state vector from GEFS forecasts """
    from efa_xray.ensemble_class import Xray_Ensemble_State
    from netCDF4 import Dataset, num2date
    from datetime import datetime, timedelta
    from numpy import reshape, zeros
    import os
    from siphon_gfse import get_gefs_ensemble
    
    vardict = {'t2' : 'Temperature_height_above_ground_ens',
               'prmsl' : 'Pressure_reduced_to_MSL_msl_ens',
               'psfc' : 'Pressure_surface_ens',
               }
    
    # Get the data from Unidata/Siphon
    raw_ensemble = get_gefs_ensemble([vardict[x] for x in statevars], bounds=bounds,writeout=False)
    
    
    # Build the state array
    meta = {}
    #print raw_ensemble
    nmems = len(raw_ensemble.dimensions['ens'])
    ntimes = len(raw_ensemble.dimensions['time2'])
    ny = len(raw_ensemble.dimensions['lat'])
    nx = len(raw_ensemble.dimensions['lon'])
    nvars = len(statevars)
    
    # Total number of geographic locations
    nlocs = ny*nx
    
    print("Allocating the state vector array...")
    state = zeros((nvars, ntimes, nlocs, nmems))
    
    # For metadata need list of times and locations
    times = raw_ensemble.variables['time2']
    valid_times = num2date(times[:], times.units)
    
   
    # Location is every latitude and longitude in the flattened array
    latl = np.array(raw_ensemble.variables['lat'][:])[:,None]*np.ones((ny,nx))
    lats = list(latl.flatten())
            
    lonl = np.array(raw_ensemble.variables['lon'][:])[None,:]*np.ones((ny,nx))
    lons = list(lonl.flatten())
    # Convert to negative lons
    lons = [l-360.0 if l > 180.0 else l for l in lons]
    # Zip these
    locations = zip(lats,lons)
    locations = ['{:3.4f},{:3.4f}'.format(l[0],l[1]) for l in locations]

    # Compose the metadata
    print("Building state metadata...")
    meta[(0,'var')] = statevars
    meta[(1,'time')] = valid_times
    meta[(2,'location')] = locations
    meta[(3,'mem')] = [n+1 for n in xrange(nmems)]

    
    
     # Now we can populate the state array
    for var in statevars:
        field = np.squeeze(raw_ensemble.variables[vardict[var]][:])
        # For surface pressure fields, put it in hPa
        if var in ['prmsl','psfc']:
            field = field / 100.
        # Reshape this to be flattened in space
        field = np.reshape(field,(ntimes,nmems,ny*nx))
        # Swap the last two axes
        field = np.swapaxes(field,1,2)
        # Populate its component of the state array
        state[statevars.index(var),:,:,:] = field
    
    """
    for mnum, mem in enumerate(memfiles):
        # Point to the file
        data = Dataset('/'.join((filedir,mem)),'r')
        # If this is the first member, calculate how large the state
        # array needs to be and allocate.  Also set up the metadata.
        if mnum == 0:
            nmems = len(memfiles)
            ntimes = len(data.dimensions['time'])
            nvars = len(statevars)
            ny = len(data.dimensions['latitude'])
            nx = len(data.dimensions['longitude'])
            nlocs = ny * nx
            # Allocate the state array
            print("Allocating the state vector array...")
            state = zeros((nvars,ntimes,nlocs,nmems))
            
            # For the metadata, need a list of times and locations
            ftimes = [datetime(1970,1,1) + timedelta(seconds=d) for d in data.variables['time'][:]]
            # Location is every latitude and longitude in the flattened array
            latl = np.array(data.variables['latitude'][:])[:,None]*np.ones((ny,nx))
            lats = list(latl.flatten())
            
            lonl = np.array(data.variables['longitude'][:])[None,:]*np.ones((ny,nx))
            lons = list(lonl.flatten())
            # Convert to negative lons
            lons = [l-360.0 if l > 180.0 else l for l in lons]
            # Zip these
            locations = zip(lats,lons)
            locations = ['{:3.4f},{:3.4f}'.format(l[0],l[1]) for l in locations]
            
            # Compose the metadata
            print("Building state metadata...")
            meta[(0,'var')] = statevars
            meta[(1,'time')] = ftimes
            meta[(2,'location')] = locations
            meta[(3,'mem')] = [n+1 for n in xrange(nmems)]
        
        # Now we can populate the state array
        for var in statevars:
            field = data.variables[vardict[var]][:]
            if var in ['prmsl','psfc']:
                field = field / 100.
            # Reshape this to be flattened in space
            field = reshape(field,(ntimes,ny*nx))
            # Populate its component of the state array
            state[statevars.index(var),:,:,mnum] = field
        data.close()
    """
    print(state.shape)
    for name,dat in meta.items():
        print(name,len(dat))
    # Make an Xray ensemble state
    statecls = Xray_Ensemble_State(state=state, meta=meta)
    return statecls, ny, nx


In [17]:
import numpy as np
from datetime import datetime, timedelta
statevars = ['t2']
# Only focus on CONUS
bounds=(-124.9,-66.8,24.3,49.4)
obtypes = statevars

# Download the latest GFSE ensemble
state_vect, ny, nx = state_from_gfse(statevars=statevars, bounds=bounds)
ftimes = state_vect.ensemble_times()
print(ftimes)


<type 'netCDF4._netCDF4.Dataset'>
root group (NETCDF3_CLASSIC data model, file format NETCDF3):
    Originating_or_generating_Center: US National Weather Service, National Centres for Environmental Prediction (NCEP)
    Originating_or_generating_Subcenter: NCEP Ensemble Products
    GRIB_table_version: 2,1
    Type_of_generating_process: Ensemble forecast
    Analysis_or_forecast_generating_process_identifier_defined_by_originating_centre: Global Ensemble Forecast System (GEFS)
    Conventions: CF-1.6
    history: Read using CDM IOSP GribCollection v3
    featureType: GRID
    History: Translated to CF-1.0 Conventions by Netcdf-Java CDM (CFGridWriter2)
Original Dataset = /data/ldm/pub/native/grid/NCEP/GEFS/Global_1p0deg_Ensemble/member/GEFS_Global_1p0deg_Ensemble_20160120_1800.grib2.ncx3#LatLon_181X360-p5S-180p0E; Translation Date = 2016-01-21T01:39:54.824Z
    geospatial_lat_min: 24.0
    geospatial_lat_max: 49.0
    geospatial_lon_min: -125.0
    geospatial_lon_max: -67.0
    dimensi

In [18]:
def get_local_observations(siteids = None, obtypes = None, use_times = None):
    """ Get observations and return a list of obs objects """
    from surface_parse_bufkit import obs_parser
    from efa_xray.observation_class import Observation
    observations = []
    obdict = {'t2' : 'tmpc',
              'prmsl' : 'press'}
    oberrs = {'t2' : 0.2,
              'prmsl' : 1.0,
             }
              
    for siteid in siteids:
        cur_obs, xnlist = obs_parser(siteid)
        obtimes = cur_obs.keys()
        obtimes.sort()
        #print obtimes
        #print use_times
        overlap = [x for x in obtimes if x in use_times]
        for t in overlap:
            for obtype in obtypes:
                if obtype in ['t2']:
                    addval = 273.
                else:
                    addval = 0.0
                curob = Observation(value=getattr(cur_obs[t],obdict[obtype])+addval,\
                        obtype=obtype, error=oberrs[obtype], time=t,\
                                    location='{:3.7f},{:3.7f}'.format(cur_obs[t].lat,cur_obs[t].lon),\
                                    description=siteid, localize_radius=1000.)
                observations.append(curob)
                
    return observations
    
    

In [22]:
# Get observations
import cPickle
infile = open('asos_lat_lon.pickle','r')
asos_sites = cPickle.load(infile)
infile.close()
siteids = [s for s in asos_sites if not s.startswith('P') and not s.startswith('T')]
siteids.sort()
siteids = siteids[::3]
#siteids = ['KJAX']
observations = get_local_observations(siteids=siteids,obtypes=obtypes,use_times=ftimes)
print('Number of obs found:', len(observations))

('Number of obs found:', 169)


In [25]:
from mpl_toolkits.basemap import Basemap
import matplotlib
import matplotlib.pyplot as plt
m = Basemap(projection='lcc', resolution = 'l',\
            width=185*40635, height=129*40635,\
            area_thresh=1000, lat_0 = 40.5, lon_0 = -95, lat_1=50.0)

In [26]:
observations[1].map_localization(state_vect, m, ny, nx)

In [30]:
import efa_xray.enkf_update
reload(efa_xray.enkf_update)
for ob in observations:
    ob.assimilate_this = False
    if (ob.time.hour <= 18) and (ob.time.hour >= 12):
        ob.assimilate_this = True

post_state, post_obs = efa_xray.enkf_update.update(state_vect, observations, loc='GC',nproc=1)

(15340,) 15340
(15340, 21)


In [31]:
for ob in observations[0:100:10]:
    prior_mean = np.mean(ob.H_Xb(state_vect))
    post_mean = np.mean(ob.H_Xb(post_state))
    prior_error = abs(prior_mean - ob.value)
    post_error = abs(post_mean - ob.value)
    print ob.description, ob.time, ob.obtype, ob.value, prior_mean, ob.assimilated, post_mean, post_error - prior_error

KAAF 2016-01-20 18:00:00 t2 288.0 288.467994202 True 288.638429579 0.170435377256
KANJ 2016-01-21 00:00:00 t2 263.0 263.063410034 False 263.026134554 -0.0372754803947
KBIS 2016-01-21 00:00:00 t2 263.555555556 264.720448365 False 265.611328988 0.890880622763
KBWI 2016-01-21 00:00:00 t2 271.333333333 269.832497944 False 269.701582012 0.130915932322
KCOU 2016-01-21 00:00:00 t2 269.111111111 267.803723594 False 268.711250435 -0.907526841482
KDFW 2016-01-21 00:00:00 t2 280.222222222 280.462510793 False 280.705543242 0.243032449198
KEYE 2016-01-21 00:00:00 t2 265.222222222 263.898248884 False 264.245845472 -0.34759658815
KGLD 2016-01-21 00:00:00 t2 273.0 274.567463031 False 274.26255923 -0.304903801028
KHUF 2016-01-20 18:00:00 t2 268.0 267.698965837 True 267.901309324 -0.202343487509
KJFK 2016-01-20 18:00:00 t2 276.333333333 274.974236639 True 275.425302971 -0.451066331968


In [None]:
# Get the increment for a particular time
timelist = ftimes[::2]
ntimes = len(timelist)
nrows = ntimes / 2 + 1
# Make the plot
plt.figure(figsize=(12,6))
for tnum, t in enumerate(timelist):
    plt.subplot(2,nrows,tnum+1)
    prior_mean = state_vect.ensemble_mean().loc[dict(time=t,var='t2')].values
    post_mean = post_state.ensemble_mean().loc[dict(time=t,var='t2')].values
    prior_mean = np.reshape(prior_mean,(ny,nx))
    post_mean = np.reshape(post_mean, (ny,nx))
    increment = np.subtract(post_mean, prior_mean)
    # Get the plotting coordinates
    x,y = state_vect.project_coordinates(m,ny,nx)

    incplot = plt.pcolormesh(x,y,increment,vmin=-2.0,vmax=2.0,cmap = matplotlib.cm.RdBu_r)
    m.drawcoastlines()
    m.drawcountries()
    m.drawstates()
    plt.title(t.strftime('%d/%HZ'))
    #if tnum == 0:
    #    plt.colorbar(incplot)
plt.suptitle('Increment of T2 (init 00Z, assim 12Z+18Z obs)')
plt.tight_layout()
plt.show()
