In [None]:
def sq(a):
    """Squeeze *a* and mask every element that is exactly zero.

    Returns a ``numpy.ma.MaskedArray`` with singleton dimensions removed;
    zero entries are masked so that they are left blank when plotted.
    """
    import numpy as np
    squeezed = np.squeeze(a)
    return np.ma.masked_where(squeezed == 0., squeezed)

def mosaic_llc(field):
    """Arrange a 13-face llc field into one 2-D mosaic for quick plotting.

    *field* is indexed as ``field[face]``; the ``np.tri(90)`` masks below
    hard-code a face size of 90x90.  Faces 7-12 are rotated by 90 degrees
    and, together with faces 0-5, form four columns of three faces each.
    Face 6 is cut into triangular pieces (rotated copies multiplied by
    triangular masks) that form an extra band of rows appended after the
    face columns.  Only rows 30-314 of the stacked result are returned.
    """
    def rotated_column(faces):
        # faces rotated by 90 degrees, stacked along axis 0
        return np.vstack([np.rot90(field[f]) for f in faces])

    def plain_column(faces):
        # faces as stored, stacked along axis 0
        return np.vstack([field[f] for f in faces])

    face_columns = np.hstack([
        rotated_column([9, 8, 7]),
        rotated_column([12, 11, 10]),
        plain_column([0, 1, 2]),
        plain_column([3, 4, 5]),
    ])

    face6 = field[6]
    flipped_tri = np.tri(90)[::-1, :]
    face6_band = np.hstack([
        np.rot90(face6) * flipped_tri,
        np.triu(np.rot90(face6, k=2)) * flipped_tri,
        np.triu(np.rot90(face6, k=-1)),
        np.zeros(face6.shape),
    ])

    return np.vstack([face_columns, face6_band])[30:315, :]

def llc13to5faces(field):
    """Regroup a 13-face llc field into a list of 5 faces.

    ``fld = llc13to5faces(field)`` expects the face index on axis 0, e.g. a
    field as returned from ``xmitgcm.open_mdsdataset(..., geometry='llc')``,
    and returns

    * faces 0-2 stacked with ``np.vstack``,
    * faces 3-5 stacked with ``np.vstack``,
    * face 6 unchanged,
    * faces 7-9 joined with ``np.hstack``,
    * faces 10-12 joined with ``np.hstack``.
    """
    vertical_groups = [np.vstack(tuple(field[f, ...] for f in group))
                       for group in ((0, 1, 2), (3, 4, 5))]
    horizontal_groups = [np.hstack(tuple(field[f, ...] for f in group))
                         for group in ((7, 8, 9), (10, 11, 12))]
    return vertical_groups + [field[6, ...]] + horizontal_groups

def symNorm(vmax):
    """Return a matplotlib ``Normalize`` spanning the symmetric range [-vmax, vmax]."""
    from matplotlib.colors import Normalize
    lower, upper = -vmax, vmax
    return Normalize(vmin=lower, vmax=upper)

# Connectivity of the 13 llc faces, in the format expected by
# xgcm.Grid(..., face_connections=face_connections).
# For each face and axis ('X', 'Y') the tuple holds the (left, right)
# neighbour connections; each connection is
# (neighbour_face, neighbour_axis, reverse), and None marks an edge that
# has no neighbouring face.
face_connections = {'face':
                    {0: {'X':  ((12, 'Y', False), (3, 'X', False)),
                         'Y':  (None,             (1, 'Y', False))},
                     1: {'X':  ((11, 'Y', False), (4, 'X', False)),
                         'Y':  ((0, 'Y', False),  (2, 'Y', False))},
                     2: {'X':  ((10, 'Y', False), (5, 'X', False)),
                         'Y':  ((1, 'Y', False),  (6, 'X', False))},
                     3: {'X':  ((0, 'X', False),  (9, 'Y', False)),
                         'Y':  (None,             (4, 'Y', False))},
                     4: {'X':  ((1, 'X', False),  (8, 'Y', False)),
                         'Y':  ((3, 'Y', False),  (5, 'Y', False))},
                     5: {'X':  ((2, 'X', False),  (7, 'Y', False)),
                         'Y':  ((4, 'Y', False),  (6, 'Y', False))},
                     6: {'X':  ((2, 'Y', False),  (7, 'X', False)),
                         'Y':  ((5, 'Y', False),  (10, 'X', False))},
                     7: {'X':  ((6, 'X', False),  (8, 'X', False)),
                         'Y':  ((5, 'X', False),  (10, 'Y', False))},
                     8: {'X':  ((7, 'X', False),  (9, 'X', False)),
                         'Y':  ((4, 'X', False),  (11, 'Y', False))},
                     9: {'X':  ((8, 'X', False),  None),
                         'Y':  ((3, 'X', False),  (12, 'Y', False))},
                     10: {'X': ((6, 'Y', False),  (11, 'X', False)),
                          'Y': ((7, 'Y', False),  (2, 'X', False))},
                     11: {'X': ((10, 'X', False), (12, 'X', False)),
                          'Y': ((8, 'Y', False),  (1, 'X', False))},
                     12: {'X': ((11, 'X', False), None),
                          'Y': ((9, 'Y', False),  (0, 'X', False))}}}

def plot2dmap(ax,fld,levs,tstr='dummy',cmap=None):
    """Plot a 13-face llc field face by face on a global cartopy map.

    Each face ``fld.isel(face=f)`` is drawn with ``pcolormesh`` on its
    ``XC``/``YC`` coordinates, using the norm *levs* and the colormap *cmap*.
    Only the last face (12) creates a colorbar. The colorbar is horizontal
    and extended at both ends, or only at the top when ``levs.vmin == 0``.
    The title is set to *tstr*. Coastlines, land and gridlines are drawn
    after the faces.
    """
    ax.set_global()
    n_faces = 13
    for face in range(n_faces):
        is_last = face == n_faces - 1
        colorbar_kwargs = None
        if is_last:
            extend = "max" if levs.vmin == 0. else "both"
            colorbar_kwargs = {"extend": extend, "orientation": "horizontal"}
        fld.isel(face=face).plot.pcolormesh(
            ax=ax, transform=cart.crs.PlateCarree(), x="XC", y="YC",
            norm=levs, add_colorbar=is_last, cmap=cmap,
            cbar_kwargs=colorbar_kwargs)

    ax.set_title('%s'%(tstr))
    ax.coastlines()
    ax.add_feature(cart.feature.LAND, zorder=100, edgecolor='k')
    ax.gridlines()
    
def mdsTo13faces(field):
    """fld = mdsTo13faces(field) returns an array where axis=-2 has been
    split into 13 to form 13 llc-faces.

    Parameters
    ----------
    field : numpy.ndarray
        Raw llc array, shape ``(..., 13*nx, nx)``.

    Returns
    -------
    numpy.ndarray
        New array of shape ``(..., 13, nx, nx)``. Faces 0-6 are plain
        reshapes of the input. Faces 7-12 are re-ordered: for k = 0, 1, 2
        the rows ``k::3`` of faces 7-12 are gathered and split into the new
        faces ``7+k`` and ``10+k``. The input array is not modified.

    Raises
    ------
    ValueError
        If ``field.shape[-2] / field.shape[-1]`` is not 13.
    """
    nn, nx = field.shape[-2:]
    if nn/nx != 13:
        raise ValueError(
            "%s %i,%i with %i/%i = %.1f but not = 13"%(
                "unexpected horizontal llc-dimensions",nn,nx,nn,nx,nn/nx))

    n = nn//nx//4  # 13 // 4 = 3
    dims = field.shape
    # np.array() copies: reshape/moveaxis only create views, so without the
    # copy the re-arrangement below would overwrite the caller's array
    fld13 = np.array(field).reshape((*dims[:-2], 13, nx, nx))
    # move the (face, j, i) dimensions to the front for easier manipulation
    fld13 = np.moveaxis(fld13, (-3, -2, -1), (0, 1, 2))
    # same dtype as the data, so no intermediate cast (e.g. complex -> float)
    tmp = np.zeros((6, nx, nx, *dims[:-2]), dtype=fld13.dtype)
    # re-arrange faces 7-12
    for k in range(n):
        tt = fld13[7:, k::n, :, ...].reshape(2*nx, nx, *dims[:-2])
        tmp[k, ...] = tt[:nx, ...]
        tmp[3+k, ...] = tt[nx:, ...]

    fld13[7:, ...] = tmp
    # move the (face, j, i) dimensions back to the end
    return np.moveaxis(fld13, (0, 1, 2), (-3, -2, -1))

def flat2d(x):
    """Flatten a 13-face llc field into one 2-D (j, i) plane per leading index.

    Faces 0-2 and faces 3-5 are each stacked along j and the two stacks are
    placed side by side along i. Faces 7-9 and 10-12 are each joined along
    i, stacked along j, rotated by 90 degrees with ``np.rot90`` over the last
    two axes and appended along i. Face 6 is not included.

    *x* is either a numpy array (or ndarray subclass) whose last three axes
    are (face, j, i), e.g. ``(k, face, j, i)`` or ``(face, j, i)``, or an
    xarray object with ``face``, ``j`` and ``i`` dimensions. The result is
    a numpy array.
    """
    if isinstance(x, np.ndarray):
        def face(f):
            # index the face axis from the end so any number of leading axes works
            return x[..., f, :, :]
        x0 = np.concatenate([np.concatenate([face(0), face(1), face(2)], axis=-2),
                             np.concatenate([face(3), face(4), face(5)], axis=-2)], axis=-1)
        y0 = np.concatenate([np.concatenate([face(7), face(8), face(9)], axis=-1),
                             np.concatenate([face(10), face(11), face(12)], axis=-1)], axis=-2)
    else:
        x0 = xr.concat( [xr.concat( [x.isel(face=0),x.isel(face=1),x.isel(face=2)], dim = 'j' ),
                         xr.concat( [x.isel(face=3),x.isel(face=4),x.isel(face=5)], dim = 'j' )], dim='i' )
        y0 = xr.concat( [xr.concat( [x.isel(face=7),x.isel(face=8),x.isel(face=9)], dim = 'i' ),
                         xr.concat( [x.isel(face=10),x.isel(face=11),x.isel(face=12)], dim = 'i' )], dim='j' )
    return np.concatenate((x0, np.rot90(y0, k=1, axes=(-2, -1))), axis=-1)

def calc_drake_passage_transport(ds):
    """Transport across the Drake Passage section, scaled by 1e-6 (presumably Sv).

    ``VVELMASS * drF`` is summed over ``k``, multiplied by ``dxG`` and
    summed along ``i`` on the row ``j_g = 62`` of face 11 (i = 87..89) and
    face 12 (i = 0..19).
    """
    depth_integrated = (ds.VVELMASS * ds.drF).sum('k')
    vtrans = depth_integrated * ds.dxG * 1e-6
    section = vtrans.sel(j_g=62)
    part_face11 = section.sel(face=11, i=range(87, 90)).sum('i')
    part_face12 = section.sel(face=12, i=range(20)).sum('i')
    return part_face11 + part_face12

def make_masks(coords, withoutArctic=True):
    """Build global, Atlantic and Indo-Pacific surface masks on the 13-face llc grid.

    Parameters
    ----------
    coords : Dataset
        Must provide ``hFacC`` (with a ``k`` dimension), ``XC`` and ``YC``.
        The hard-coded (face, j, i) index ranges below assume one specific
        grid size. TODO: confirm they are only used with llc90.
    withoutArctic : bool
        If True, face 6 is zeroed in all three masks.

    Returns
    -------
    global_mask, atlantic_mask, indopacific_mask : numpy.ndarray
        Same shape as ``coords.hFacC.isel(k=0)``. Inside a basin the values
        are the surface hFacC; elsewhere they are 0. ``coords`` is not modified.
    """
    # copy: isel may return a view of coords.hFacC, and the item assignments
    # below would otherwise overwrite the caller's surface hFacC
    global_mask = coords.hFacC.isel(k=0).copy()
    global_mask[2,80:,60:]=0.
    global_mask[7,:,:13]=0.
    global_mask[10,:43,:11]=0.
    # remove Hudson
    global_mask[10,30:54,5:39] = 0.
    global_mask[10,30:62,10:39] = 0.
    # Atlantic: lat/lon boxes applied successively; everything excluded -> 0
    atlantic_mask = global_mask.where(coords.YC>-35).where( # Southern Ocean
        np.logical_and(coords.XC<20,coords.XC>-98)).where( # most of the non-Atlantic Ocean
        np.logical_or(coords.XC<0,np.logical_or(coords.YC<30,coords.YC>47))).where(
        np.logical_or(coords.XC<-9,np.logical_or(coords.YC<34,coords.YC>38))).where( # Strait of Gibraltar
        np.logical_or(coords.XC>-70,coords.YC>9)).where( # East Pacific
        np.logical_or(coords.XC>-84,coords.YC>14)).where( # Isthmus of Panama etc.
        np.logical_or(coords.XC>-90,coords.YC>18)).where(
        np.logical_or(coords.XC>-70,coords.YC<50)).fillna(0)
    # Indo-Pacific: the rest of the ocean between 35S and 70N
    indopacific_mask = (global_mask-atlantic_mask).where(
        np.logical_and(coords.YC>-35,coords.YC<70)).fillna(0)
    # remove Hudson
    indopacific_mask[10,10:,:39] = 0.
    # remove Med and parts of Arctic
    indopacific_mask[ 2,20:,29:84] = 0.
    # remove Bering strait and Chukchi Sea
    indopacific_mask[ 7,:,:14] = 0.
    if withoutArctic:
        global_mask[6,:,:]=0. # delete Arctic face
        atlantic_mask[6,:,:]=0. # delete Arctic face
        indopacific_mask[6,:,:]=0. # delete Arctic face
    return global_mask.values, atlantic_mask.values, indopacific_mask.values

def zonal_mean(ds,fld,msk):
    """Volume-weighted zonal mean of *fld* over the cells selected by *msk*.

    Cell volumes ``hFacC*rA*drF*msk`` and *fld* are flattened with flat2d.
    Two hard-coded index boxes (the Mediterranean, per the original comment)
    get zero volume. The weighted mean is taken along the last axis. Rows
    with zero total volume give 0, and every exact 0 is masked in the
    returned ``numpy.ma.MaskedArray``.
    """
    dvol = flat2d(ds.hFacC*ds.rA*ds.drF*msk)
    # mask the Med (index ranges of the flattened grid)
    dvol[:,200:217,33:80]=0
    dvol[:,217:222,40:60]=0
    ra = dvol.sum(axis=-1)
    # empty rows: dividing by inf yields 0, which is masked below.
    # np.inf, not np.Inf (that alias was removed in NumPy 2.0)
    ra[ra==0]=np.inf
    fldz = (flat2d(fld)*dvol).sum(axis=-1)/ra
    return np.ma.masked_array(fldz,fldz==0)

def zonal_sum(fld):
    """Zonal integral of a scalar field: flatten with flat2d, then sum over the last axis."""
    flattened = flat2d(fld)
    return flattened.sum(axis=-1)

def zonal_lat_bin(clat,res=1.):
    """Return the distinct latitude bins of *clat*, shifted down by half a bin.

    Latitudes are rounded to the nearest multiple of *res*. The sorted unique
    values minus ``0.5*res`` are returned as a 1-D numpy array. *clat* must
    expose ``.values`` (e.g. an xarray DataArray).
    """
    rounded = (np.round(clat / res) * res).values
    return np.unique(rounded.ravel()) - 0.5 * res
    
def zonal_sum_bin(data,clat,res=1.):
    """Sum the positive values of *data* in latitude bins of width *res*.

    Non-positive values are masked with ``where`` before grouping. Bins are
    labelled by *clat* rounded to the nearest multiple of *res*.
    """
    bins = np.round(clat / res) * res
    positive_only = data.where(data > 0.)
    return positive_only.groupby(bins).sum()
    
def calc_flux_divergence(dl, grid=None):
    """Horizontal divergence of the layer thickness transports (not divided by area).

    The transports ``LaUH1RHO*dyG`` and ``LaVH1RHO*dxG`` are differenced with
    ``grid.diff_2d_vector`` (boundary='fill'), and the X and Y differences
    are summed.

    Parameters
    ----------
    dl : Dataset
        Must provide LaUH1RHO, LaVH1RHO, dxG and dyG.
    grid : xgcm.Grid, optional
        Grid used for the differences. When omitted, the module-level ``grd``
        is used (the original behaviour). It must then exist, e.g.
        ``grd = xgcm.Grid(dl, periodic=False, face_connections=face_connections)``.
    """
    if grid is None:
        grid = grd
    # layer flux in two directions
    flxx = dl.LaUH1RHO*dl.dyG
    flxy = dl.LaVH1RHO*dl.dxG
    # difference in the x and y directions
    diff_flx = grid.diff_2d_vector({'X': flxx, 'Y': flxy}, boundary='fill')
    # divergence
    return diff_flx['X'] + diff_flx['Y']

def calc_wflux_dia(dl):
    """Vertical ('dia') flux in layer space from the horizontal flux divergence.

    The divergence from calc_flux_divergence is integrated cumulatively along
    ``l1_c``. The first value of the vertical coordinate (``Zl``, or ``Z``
    if ``Zl`` is absent) selects the branch:

    * ``z[0] > 0`` (p-coords): cumulative sum in index order, then the last
      entry along the first dimension is set to 0.
    * otherwise: cumulative sum from the last ``l1_c`` index backwards, with
      a sign change (flux assumed 0 beyond the last level).

    Raises
    ------
    AttributeError
        If *dl* has neither ``Zl`` nor ``Z``.
    """
    flx_div = calc_flux_divergence(dl)

    # determine vertical coordinate; fail clearly instead of leaving z unbound
    zcoord = getattr(dl, 'Zl', None)
    if zcoord is None:
        zcoord = getattr(dl, 'Z', None)
    if zcoord is None:
        raise AttributeError("no z-coordinate: dl has neither 'Zl' nor 'Z'")
    z = zcoord.values

    if z[0]>0: # p-coords
        # compute wflux at w-points (below c-points),
        # integrating (from the bottom up) cumulatively,
        wflx = flx_div.cumsum(dim='l1_c')
        # NOTE(review): positional index assumes l1_c is the first dimension
        wflx[-1,:] = 0.
    else:
        # compute wflux at w-points (above c-points) by reversing the k-axis,
        # integrating (now from the bottom up) cumulatively,
        # assuming wflx=0 at n+1
        wflux = -flx_div.reindex(l1_c=flx_div.l1_c[::-1]).cumsum(dim='l1_c')
        # and reverse the k-axis again
        wflx = wflux.reindex(l1_c=wflux.l1_c[::-1])
    return wflx

def calc_std_xr(fld2,fld):
    """Standard deviation sqrt(<x^2> - <x>^2) for xarray (or pandas) objects.

    Variances that are not strictly positive (round-off negatives, zeros,
    NaNs) are replaced by 0 before taking the square root.
    """
    variance = fld2 - np.square(fld)
    return np.sqrt(variance.where(variance > 0., other=0.))

def calc_std_np(fld2,fld):
    """Standard deviation sqrt(<x^2> - <x>^2) for numpy arrays.

    Negative variances (from round-off) are clipped to 0 *before* the square
    root, so no "invalid value in sqrt" RuntimeWarning is raised. NaNs
    propagate. Returns a numpy array.
    """
    var = fld2-fld**2
    # np.where(var<0, 0, np.sqrt(var)) evaluated sqrt on negatives too
    return np.asarray(np.sqrt(np.maximum(var, 0.)))

In [None]:
def get_next_line_value(f):
    """Advance the line iterator *f* by one line and return its last token as a float.

    An 'E' exponent marker is normalised to 'e' before conversion.
    """
    last_token = next(f).split()[-1]
    return float(last_token.replace('E', 'e'))

def get_parms (fname):
    """Read selected model parameters from an MITgcm output text file.

    Each parameter is recognised by a label line, and its value is the last
    token of the FOLLOWING line (via get_next_line_value). The exception is
    ``startDate``: it is everything after '=' on the ``startDate_1`` line,
    returned as raw text including any trailing characters.

    Returns
    -------
    tuple
        (mondt, m2rUnit, rhoConst, rhoIce, gravity*gravitySign, startDate)

    Raises
    ------
    ValueError
        If any of the entries is not found in *fname*. Previously this
        failed with an UnboundLocalError.
    """
    mondt = m2rUnit = rhoConst = rhoIce = None
    gravity = gravitySign = startDate = None
    with open(fname) as f:
        for line in f:
            if '/* Monitor output interval ( s ). */' in line:
                mondt = get_next_line_value(f)
            elif 'mass2rUnit' in line:
                m2rUnit = get_next_line_value(f)
            elif '/* Reference density (Boussinesq)  ( kg/m^3 ) */' in line:
                rhoConst = get_next_line_value(f)
            elif '/* density of sea ice (kg/m3) */' in line:
                rhoIce = get_next_line_value(f)
            elif 'startDate_1' in line:
                startDate = line.split('=')[-1]
            elif '/* Gravitational acceleration ( m/s^2 ) */' in line:
                gravity = get_next_line_value(f)
            elif 'gravity orientation relative to vertical coordinate' in line:
                gravitySign = get_next_line_value(f)

    found = {'mondt': mondt, 'mass2rUnit': m2rUnit, 'rhoConst': rhoConst,
             'rhoIce': rhoIce, 'gravity': gravity,
             'gravitySign': gravitySign, 'startDate_1': startDate}
    missing = [name for name, value in found.items() if value is None]
    if missing:
        raise ValueError('%s: parameter(s) not found: %s'
                         % (fname, ', '.join(missing)))
    return mondt, m2rUnit, rhoConst, rhoIce, gravity*gravitySign, startDate

def get_output (fnames, mystring):
    """Collect a monitor variable and its time stamps from MITgcm text output.

    Parameters
    ----------
    fnames : iterable of str
        Files to scan, in order. Files that cannot be opened are reported
        and skipped.
    mystring : str
        Substring identifying the variable's lines. Its first three
        characters select the time label: 'exf' -> 'exf_time_sec',
        'sea' -> 'seaice_time_sec', anything else -> 'time_secondsf'.
        The value is the last token of a line, with 'D' exponents accepted.

    Returns
    -------
    timevs, myvars : numpy.ndarray
        Unique times in ascending order and the matching values. When a time
        occurs more than once, the value read LAST wins. For files given in
        chronological order, that means a pickup run overrides the
        overlapping output of the previous (potentially crashed) run.
    """
    if mystring[:3]=='exf':
        timename = 'exf_time_sec'
    elif mystring[:3]=='sea':
        timename = 'seaice_time_sec'
    else:
        timename = 'time_secondsf'

    timev = []
    myvar = []
    for fname in fnames:
        try:
            f = open(fname)
        except OSError:
            print(fname + " does not exist, continuing")
            continue
        # `with` closes the file even if a line fails to parse
        with f:
            for line in f:
                if timename in line:
                    timev.append(float(line.split()[-1].replace('D','e')))
                elif mystring in line:
                    myvar.append(float(line.split()[-1].replace('D','e')))

    # reverse order, so that np.unique's "first occurrence" index below
    # picks the last value read for duplicated times
    timevs = np.asarray(timev[::-1])
    myvars = np.asarray(myvar[::-1])
    # sort ascending and keep one entry per time
    timevs, isort = np.unique(timevs, return_index=True)
    myvars = myvars[isort]

    return timevs, myvars
# done

def correct_jumps(x):
    """Return *x* unchanged: jump correction is currently disabled.

    The former implementation is kept below as a comment for reference.
    """
    # Disabled algorithm: after every step larger than 0.1*std, shift the
    # remainder of the series so that it continues from the previous value.
    #     ii = np.where(np.abs(np.diff(x)) > 0.1*x.std())[0]
    #     y = np.copy(x)
    #     for i in ii:
    #         y[i+1:] = y[i+1:]-y[i+1]+y[i]
    #     return y
    unchanged = x
    return unchanged

def readstats(fname):
    '''
    locals,totals,itrs = readstats(fname)

    Read a diagstats text file into record arrays (or dictionaries).

    Parameters
    ----------
    fname : string
        name of diagstats file to read

    Returns
    -------
    locals : record array or dict of arrays
        local statistics, shape (len(itrs), Nr, 5)
    totals : record array or dict of arrays
        column integrals, shape (len(itrs), 5)
    itrs : dict of lists of int
        iteration numbers found in the file, per field

    Notes
    -----
    - The 5 columns of the resulting arrays are average, std.dev, min, max and total volume.
    - There is a record (or dictionary key) for each field found in the file.
    - With more than one region, totals hold region '0' and locals stack the
      other regions along the first axis (dicts keyed by field).
    - A file without the closing '# records' line (e.g. from a run that is
      still in progress) is read up to its end instead of looping forever.
    '''
    import re

    nstats = 5
    flds = []
    regs = []  # stays empty (single-region layout) if the header lacks 'Regions'
    with open(fname) as f:
        for line in f:
            if line.startswith('# end of header'):
                break

            m = re.match(r'^# ([^:]*) *: *(.*)$', line.rstrip())
            if m:
                var, val = m.groups()
                if var.startswith('Fields'):
                    flds = val.split()
                if var.startswith('Regions'):
                    regs = val.split()

        multi_region = len(regs) > 1
        if multi_region:
            res = {fld: {reg: [] for reg in regs} for fld in flds}
        else:
            res = {fld: [] for fld in flds}
        itrs = {fld: [] for fld in flds}

        line = f.readline()
        # readline() returns '' at EOF: stop there even without '# records'
        while line and not line.startswith('# records'):

            m = re.match(r' field : *([^ ]*) *; Iter = *([0-9]*) *; region # *([0-9]*) ; nb\.Lev = *([0-9]*)', line)
            if not m:
                line = f.readline()
                continue

            fld, itr, reg, nlev = m.groups()
            itrs[fld].append(int(itr))
            nlevs = int(nlev)
            # multi-level fields carry an extra k=0 row (returned as totals)
            if nlevs > 1:
                nlevs = nlevs + 1
            tmp = np.zeros((nlevs, nstats))
            line = f.readline()
            # data block ends at a blank line, the next field header, or EOF
            while not (line.strip() == '' or line.startswith(' field')):
                if not line.startswith(' k'):
                    cols = line.strip().split()
                    k = int(cols[0])
                    tmp[k] = [float(s) for s in cols[1:]]

                line = f.readline()

            if multi_region:
                res[fld][reg].append(tmp)
            else:
                res[fld].append(tmp)
            # `line` (blank or next header) is examined by the outer loop

    if multi_region:
        totals = {fld: np.squeeze(np.array(res[fld]['0'])) for fld in flds}
        local_stats = {}
        for fld in flds:
            per_region = [np.squeeze(np.array(res[fld][reg]))
                          for reg in regs if reg != '0']
            local_stats[fld] = np.array(per_region)
        return local_stats, totals, itrs

    try:
        records = np.rec.fromarrays(
            [np.array(res[fld]) for fld in flds], names=flds)
        return records[:, 1:, ...], records[:, 0, ...], itrs
    except Exception:
        # e.g. fields with different numbers of levels cannot share one
        # record array: fall back to per-field dictionaries
        totals = {fld: np.array(res[fld])[:, 0, ...] for fld in flds}
        local_stats = {fld: np.array(res[fld])[:, 1:, ] for fld in flds}
        return local_stats, totals, itrs


In [None]:
import warnings
def compute_moc_layers(dl,msk):
    """Layer-space overturning: the calc_wflux_dia flux, masked by *msk*, passed to compute_moc."""
    masked_flux = calc_wflux_dia(dl) * msk
    return compute_moc(masked_flux)

def compute_moc(wflux):
    """Overturning streamfunction from a vertical flux field.

    The flux is summed zonally with zonal_sum. It is then accumulated along
    the last axis starting from its far end ("north to south", needed for
    the Atlantic MOC) and negated. Points where the zonal sum is exactly 0
    are set to 0.
    """
    zonal = zonal_sum(wflux)
    # reverse cumulative sum along the last axis, with a sign change
    streamfunction = -(zonal[..., ::-1].cumsum(axis=-1)[..., ::-1])
    streamfunction[zonal == 0] = 0.
    return streamfunction

def compute_layers(dl,msk):
    """Zonal-mean cumulative layer thickness (interface positions) on the flattened grid.

    The X and Y layer thicknesses ``LaHw1RHO``/``LaHs1RHO`` are interpolated
    to cell centres with ``xgcm.Grid.interp_2d_vector`` and averaged. They
    are then masked by *msk*, flattened with flat2d and averaged along the
    last axis over non-zero points only (rows with no such point give 0).
    The cumulative sum along axis 0 is shifted down by one index, so the
    first (top) interface is 0.

    Returns a numpy array whose axis 0 is the layer index.
    """
    # The former vertical-coordinate lookup (Zl/Z -> pCoords) was never
    # used and could raise UnboundLocalError; it has been removed.
    grd = xgcm.Grid(dl, periodic=False, face_connections=face_connections)
    lahc = grd.interp_2d_vector({'X': dl.LaHw1RHO,
                                 'Y': dl.LaHs1RHO},
                                to='center', boundary='fill')
    lath = 0.5*(lahc['X']+lahc['Y'])*msk
    # alternative, not much different: np.maximum(lahc['X'], lahc['Y'])*msk
    flat = flat2d(lath)
    # first do the zonal average over non-zero points
    with warnings.catch_warnings():
        # rows without any non-zero point give an empty mean (NaN + RuntimeWarning)
        warnings.simplefilter("ignore")
        zz = flat.mean(axis=-1, where=flat != 0)
    zz[np.isnan(zz)] = 0
    # then do the vertical integral
    interfaces = zz.cumsum(axis=0)
    # cumsum gives the position of the interface BELOW each layer: push all
    # values down to k+1 so that the top interface is zero
    interfaces = np.roll(interfaces, 1, axis=0)
    interfaces[0, :] = 0
    return interfaces


In [1]:
import pyresample

class LLCMapper:
    """Plot llc fields on a regular lon/lat grid via nearest-neighbour resampling.

    The constructor builds a pyresample mapping from the llc cell centres
    (``ds.XC``, ``ds.YC``) to a cell-centred regular grid with spacing
    *dx* x *dy* degrees. Calling the instance resamples a ``(face, j, i)``
    field onto that grid and draws it with ``pcolormesh`` on a cartopy axis.
    """

    def __init__(self, ds, dx=0.25, dy=0.25):
        # llc cell centres as flat 1-D arrays
        lons_1d = ds.XC.values.ravel()
        lats_1d = ds.YC.values.ravel()

        # source geometry: the irregular llc points
        self.orig_grid = pyresample.geometry.SwathDefinition(lons=lons_1d, lats=lats_1d)

        # target geometry: regular, cell-centred lon/lat points
        lon_tmp = np.arange(-180, 180, dx) + dx/2
        lat_tmp = np.arange(-90, 90, dy) + dy/2
        self.new_grid_lon, self.new_grid_lat = np.meshgrid(lon_tmp, lat_tmp)
        self.new_grid = pyresample.geometry.GridDefinition(lons=self.new_grid_lon,
                                                           lats=self.new_grid_lat)

    def __call__(self, da, ax=None, lon_0=-60, tstr=None, **plt_kwargs):
        """Resample *da* and draw it with ``pcolormesh``.

        Parameters
        ----------
        da : DataArray with dimensions exactly {'face', 'j', 'i'}.
        ax : cartopy GeoAxes. If None, a new Robinson-projection axis
            centred on *lon_0* is created with ``plt.axes``.
        lon_0 : central longitude, used only when *ax* is None.
        tstr : optional title.
        **plt_kwargs : forwarded to ``ax.pcolormesh``.

        Returns
        -------
        (ax, p, gl) : the axis, the pcolormesh artist and the gridliner.
        """
        assert set(da.dims) == set(['face', 'j', 'i']), "da must have dimensions ['face', 'j', 'i']"

        if ax is None:
            # the signature advertises ax=None; previously this crashed on None.pcolormesh
            import matplotlib.pyplot as plt
            ax = plt.axes(projection=cart.crs.Robinson(central_longitude=lon_0))

        # nearest llc value within 100 km (radius in metres); with
        # fill_value=None pyresample masks target points without a neighbour
        field = pyresample.kd_tree.resample_nearest(self.orig_grid, da.values,
                                                    self.new_grid,
                                                    radius_of_influence=100000,
                                                    fill_value=None)

        x, y = self.new_grid_lon, self.new_grid_lat
        p = ax.pcolormesh(x, y, field, transform=cart.crs.PlateCarree(), **plt_kwargs)

        ax.coastlines()
        # landcolor is a module-level setting defined elsewhere in the notebook
        ax.add_feature(cart.feature.LAND, facecolor=landcolor, edgecolor='k',  zorder=3, linewidth=0.3)
        gl = ax.gridlines(zorder=4)

        if tstr is not None:
            ax.set_title('%s'%(tstr))

        return ax, p, gl
    
class LLCinterp:
    """Nearest-neighbour resampling of llc fields onto a regular lon/lat grid.

    The target grid is cell-centred with spacing *dx* x *dy* degrees over
    the whole globe. Calling the instance returns the resampled array; it
    does not plot.
    """

    def __init__(self, ds, dx=0.25, dy=0.25):
        # source: the llc cell centres, flattened
        self.orig_grid = pyresample.geometry.SwathDefinition(
            lons=ds.XC.values.ravel(), lats=ds.YC.values.ravel())

        # target: regular grid of cell centres
        centre_lons = np.arange(-180, 180, dx) + dx/2
        centre_lats = np.arange(-90, 90, dy) + dy/2
        self.new_grid_lon, self.new_grid_lat = np.meshgrid(centre_lons, centre_lats)
        self.new_grid = pyresample.geometry.GridDefinition(
            lons=self.new_grid_lon, lats=self.new_grid_lat)

    def __call__(self, da):
        """Return *da* (dims {'face', 'j', 'i'}) resampled onto the regular grid.

        Uses the nearest llc value within 100 km; target points without one
        are left undetermined (``fill_value=None``).
        """
        expected = set(['face', 'j', 'i'])
        assert set(da.dims) == expected, "da must have dimensions ['face', 'j', 'i']"
        return pyresample.kd_tree.resample_nearest(
            self.orig_grid, da.values, self.new_grid,
            radius_of_influence=100e3, fill_value=None)

In [None]:
import pyresample

class LLCMapper_noland:
    """Like a map plotter for llc fields, but without coastlines or land shading.

    The constructor builds a pyresample mapping from the llc cell centres
    (``ds.XC``, ``ds.YC``) to a cell-centred regular grid with spacing
    *dx* x *dy* degrees. Calling the instance resamples a ``(face, j, i)``
    field and draws it with ``pcolormesh`` plus gridlines only.
    """

    def __init__(self, ds, dx=0.25, dy=0.25):
        # llc cell centres as flat 1-D arrays
        lons_1d = ds.XC.values.ravel()
        lats_1d = ds.YC.values.ravel()

        # source geometry: the irregular llc points
        self.orig_grid = pyresample.geometry.SwathDefinition(lons=lons_1d, lats=lats_1d)

        # target geometry: regular, cell-centred lon/lat points
        lon_tmp = np.arange(-180, 180, dx) + dx/2
        lat_tmp = np.arange(-90, 90, dy) + dy/2
        self.new_grid_lon, self.new_grid_lat = np.meshgrid(lon_tmp, lat_tmp)
        self.new_grid = pyresample.geometry.GridDefinition(lons=self.new_grid_lon,
                                                           lats=self.new_grid_lat)

    def __call__(self, da, ax=None, lon_0=-60, tstr=None, **plt_kwargs):
        """Resample *da* and draw it with ``pcolormesh`` (no coastlines/land).

        Parameters
        ----------
        da : DataArray with dimensions exactly {'face', 'j', 'i'}.
        ax : cartopy GeoAxes. If None, a new Robinson-projection axis
            centred on *lon_0* is created with ``plt.axes``.
        lon_0 : central longitude, used only when *ax* is None.
        tstr : optional title.
        **plt_kwargs : forwarded to ``ax.pcolormesh``.

        Returns
        -------
        (ax, p, gl) : the axis, the pcolormesh artist and the gridliner.
        """
        assert set(da.dims) == set(['face', 'j', 'i']), "da must have dimensions ['face', 'j', 'i']"

        if ax is None:
            # the signature advertises ax=None; previously this crashed on None.pcolormesh
            import matplotlib.pyplot as plt
            ax = plt.axes(projection=cart.crs.Robinson(central_longitude=lon_0))

        # nearest llc value within 100 km (radius in metres); with
        # fill_value=None pyresample masks target points without a neighbour
        field = pyresample.kd_tree.resample_nearest(self.orig_grid, da.values,
                                                    self.new_grid,
                                                    radius_of_influence=100000,
                                                    fill_value=None)

        x, y = self.new_grid_lon, self.new_grid_lat
        p = ax.pcolormesh(x, y, field, transform=cart.crs.PlateCarree(), **plt_kwargs)

        # no coastlines / land feature in this variant
        gl = ax.gridlines(zorder=4)

        if tstr is not None:
            ax.set_title('%s'%(tstr))

        return ax, p, gl
    
class LLCinterp:
    """Nearest-neighbour resampling of llc fields onto a regular lon/lat grid.

    The target grid is cell-centred with spacing *dx* x *dy* degrees over
    the whole globe. Calling the instance returns the resampled array; it
    does not plot.
    """

    def __init__(self, ds, dx=0.25, dy=0.25):
        # source: the llc cell centres, flattened
        self.orig_grid = pyresample.geometry.SwathDefinition(
            lons=ds.XC.values.ravel(), lats=ds.YC.values.ravel())

        # target: regular grid of cell centres
        centre_lons = np.arange(-180, 180, dx) + dx/2
        centre_lats = np.arange(-90, 90, dy) + dy/2
        self.new_grid_lon, self.new_grid_lat = np.meshgrid(centre_lons, centre_lats)
        self.new_grid = pyresample.geometry.GridDefinition(
            lons=self.new_grid_lon, lats=self.new_grid_lat)

    def __call__(self, da):
        """Return *da* (dims {'face', 'j', 'i'}) resampled onto the regular grid.

        Uses the nearest llc value within 100 km; target points without one
        are left undetermined (``fill_value=None``).
        """
        expected = set(['face', 'j', 'i'])
        assert set(da.dims) == expected, "da must have dimensions ['face', 'j', 'i']"
        return pyresample.kd_tree.resample_nearest(
            self.orig_grid, da.values, self.new_grid,
            radius_of_influence=100e3, fill_value=None)

In [None]:
import pyresample

class LLCMapper_wog:
    """Map plotter for llc fields without gridlines ("wog").

    The constructor builds a pyresample mapping from the llc cell centres
    (``ds.XC``, ``ds.YC``) to a cell-centred regular grid with spacing
    *dx* x *dy* degrees. Calling the instance resamples a ``(face, j, i)``
    field and draws it with ``pcolormesh``, coastlines and land shading.
    """

    def __init__(self, ds, dx=0.25, dy=0.25):
        # llc cell centres as flat 1-D arrays
        lons_1d = ds.XC.values.ravel()
        lats_1d = ds.YC.values.ravel()

        # source geometry: the irregular llc points
        self.orig_grid = pyresample.geometry.SwathDefinition(lons=lons_1d, lats=lats_1d)

        # target geometry: regular, cell-centred lon/lat points
        lon_tmp = np.arange(-180, 180, dx) + dx/2
        lat_tmp = np.arange(-90, 90, dy) + dy/2
        self.new_grid_lon, self.new_grid_lat = np.meshgrid(lon_tmp, lat_tmp)
        self.new_grid = pyresample.geometry.GridDefinition(lons=self.new_grid_lon,
                                                           lats=self.new_grid_lat)

    def __call__(self, da, ax=None, lon_0=-60, tstr=None, **plt_kwargs):
        """Resample *da* and draw it with ``pcolormesh`` (no gridlines).

        Parameters
        ----------
        da : DataArray with dimensions exactly {'face', 'j', 'i'}.
        ax : cartopy GeoAxes. If None, a new Robinson-projection axis
            centred on *lon_0* is created with ``plt.axes``.
        lon_0 : central longitude, used only when *ax* is None.
        tstr : optional title.
        **plt_kwargs : forwarded to ``ax.pcolormesh``.

        Returns
        -------
        (ax, p, gl) : the axis, the pcolormesh artist and None (no
            gridlines are drawn). Returning None keeps the same 3-tuple as
            the other mapper classes; previously ``gl`` was undefined and
            every call raised NameError.
        """
        assert set(da.dims) == set(['face', 'j', 'i']), "da must have dimensions ['face', 'j', 'i']"

        if ax is None:
            # the signature advertises ax=None; previously this crashed on None.pcolormesh
            import matplotlib.pyplot as plt
            ax = plt.axes(projection=cart.crs.Robinson(central_longitude=lon_0))

        # nearest llc value within 100 km (radius in metres); with
        # fill_value=None pyresample masks target points without a neighbour
        field = pyresample.kd_tree.resample_nearest(self.orig_grid, da.values,
                                                    self.new_grid,
                                                    radius_of_influence=100000,
                                                    fill_value=None)

        x, y = self.new_grid_lon, self.new_grid_lat
        p = ax.pcolormesh(x, y, field, transform=cart.crs.PlateCarree(), **plt_kwargs)

        ax.coastlines()
        # landcolor is a module-level setting defined elsewhere in the notebook
        ax.add_feature(cart.feature.LAND, facecolor=landcolor, edgecolor='k',  zorder=3, linewidth=0.3)
        gl = None  # no gridlines in this variant

        if tstr is not None:
            ax.set_title('%s'%(tstr))

        return ax, p, gl
    
class LLCinterp:
    """Nearest-neighbour resampling of llc fields onto a regular lon/lat grid.

    The target grid is cell-centred with spacing *dx* x *dy* degrees over
    the whole globe. Calling the instance returns the resampled array; it
    does not plot.
    """

    def __init__(self, ds, dx=0.25, dy=0.25):
        # source: the llc cell centres, flattened
        self.orig_grid = pyresample.geometry.SwathDefinition(
            lons=ds.XC.values.ravel(), lats=ds.YC.values.ravel())

        # target: regular grid of cell centres
        centre_lons = np.arange(-180, 180, dx) + dx/2
        centre_lats = np.arange(-90, 90, dy) + dy/2
        self.new_grid_lon, self.new_grid_lat = np.meshgrid(centre_lons, centre_lats)
        self.new_grid = pyresample.geometry.GridDefinition(
            lons=self.new_grid_lon, lats=self.new_grid_lat)

    def __call__(self, da):
        """Return *da* (dims {'face', 'j', 'i'}) resampled onto the regular grid.

        Uses the nearest llc value within 100 km; target points without one
        are left undetermined (``fill_value=None``).
        """
        expected = set(['face', 'j', 'i'])
        assert set(da.dims) == expected, "da must have dimensions ['face', 'j', 'i']"
        return pyresample.kd_tree.resample_nearest(
            self.orig_grid, da.values, self.new_grid,
            radius_of_influence=100e3, fill_value=None)

In [None]:
import pyresample

class LLCMapper_2:
    """Map plotter for llc fields using filled contours (``contourf``).

    The constructor builds a pyresample mapping from the llc cell centres
    (``ds.XC``, ``ds.YC``) to a cell-centred regular grid with spacing
    *dx* x *dy* degrees. Calling the instance resamples a ``(face, j, i)``
    field and draws it with ``contourf``, coastlines, land and gridlines.
    """

    def __init__(self, ds, dx=0.25, dy=0.25):
        # llc cell centres as flat 1-D arrays
        lons_1d = ds.XC.values.ravel()
        lats_1d = ds.YC.values.ravel()

        # source geometry: the irregular llc points
        self.orig_grid = pyresample.geometry.SwathDefinition(lons=lons_1d, lats=lats_1d)

        # target geometry: regular, cell-centred lon/lat points
        lon_tmp = np.arange(-180, 180, dx) + dx/2
        lat_tmp = np.arange(-90, 90, dy) + dy/2
        self.new_grid_lon, self.new_grid_lat = np.meshgrid(lon_tmp, lat_tmp)
        self.new_grid = pyresample.geometry.GridDefinition(lons=self.new_grid_lon,
                                                           lats=self.new_grid_lat)

    def __call__(self, da, ax=None, lon_0=-60, tstr=None, **plt_kwargs):
        """Resample *da* and draw it with ``contourf``.

        Parameters
        ----------
        da : DataArray with dimensions exactly {'face', 'j', 'i'}.
        ax : cartopy GeoAxes. If None, a new Robinson-projection axis
            centred on *lon_0* is created with ``plt.axes``.
        lon_0 : central longitude, used only when *ax* is None.
        tstr : optional title.
        **plt_kwargs : forwarded to ``ax.contourf``.

        Returns
        -------
        (ax, p, gl) : the axis, the contour set and the gridliner.
        """
        assert set(da.dims) == set(['face', 'j', 'i']), "da must have dimensions ['face', 'j', 'i']"

        if ax is None:
            # the signature advertises ax=None; previously this crashed on None.contourf
            import matplotlib.pyplot as plt
            ax = plt.axes(projection=cart.crs.Robinson(central_longitude=lon_0))

        # nearest llc value within 100 km (radius in metres); with
        # fill_value=None pyresample masks target points without a neighbour
        field = pyresample.kd_tree.resample_nearest(self.orig_grid, da.values,
                                                    self.new_grid,
                                                    radius_of_influence=100000,
                                                    fill_value=None)

        x, y = self.new_grid_lon, self.new_grid_lat
        p = ax.contourf(x, y, field, transform=cart.crs.PlateCarree(), **plt_kwargs)

        ax.coastlines()
        # landcolor is a module-level setting defined elsewhere in the notebook
        ax.add_feature(cart.feature.LAND, facecolor=landcolor, edgecolor='k',  zorder=3, linewidth=.3)
        gl = ax.gridlines(zorder=4)

        if tstr is not None:
            ax.set_title('%s'%(tstr))

        return ax, p, gl
    
class LLCinterp:
    """Nearest-neighbour resampling of llc fields onto a regular lon/lat grid.

    The target grid is cell-centred with spacing *dx* x *dy* degrees over
    the whole globe. Calling the instance returns the resampled array; it
    does not plot.
    """

    def __init__(self, ds, dx=0.25, dy=0.25):
        # source: the llc cell centres, flattened
        self.orig_grid = pyresample.geometry.SwathDefinition(
            lons=ds.XC.values.ravel(), lats=ds.YC.values.ravel())

        # target: regular grid of cell centres
        centre_lons = np.arange(-180, 180, dx) + dx/2
        centre_lats = np.arange(-90, 90, dy) + dy/2
        self.new_grid_lon, self.new_grid_lat = np.meshgrid(centre_lons, centre_lats)
        self.new_grid = pyresample.geometry.GridDefinition(
            lons=self.new_grid_lon, lats=self.new_grid_lat)

    def __call__(self, da):
        """Return *da* (dims {'face', 'j', 'i'}) resampled onto the regular grid.

        Uses the nearest llc value within 100 km; target points without one
        are left undetermined (``fill_value=None``).
        """
        expected = set(['face', 'j', 'i'])
        assert set(da.dims) == expected, "da must have dimensions ['face', 'j', 'i']"
        return pyresample.kd_tree.resample_nearest(
            self.orig_grid, da.values, self.new_grid,
            radius_of_influence=100e3, fill_value=None)

In [None]:
import pyresample

class LLCMapper_3:
    """Map plotter for llc fields using contour lines (``contour``).

    The constructor builds a pyresample mapping from the llc cell centres
    (``ds.XC``, ``ds.YC``) to a cell-centred regular grid with spacing
    *dx* x *dy* degrees. Calling the instance resamples a ``(face, j, i)``
    field and draws it with ``contour``, coastlines, land and gridlines.
    """

    def __init__(self, ds, dx=0.25, dy=0.25):
        # llc cell centres as flat 1-D arrays
        lons_1d = ds.XC.values.ravel()
        lats_1d = ds.YC.values.ravel()

        # source geometry: the irregular llc points
        self.orig_grid = pyresample.geometry.SwathDefinition(lons=lons_1d, lats=lats_1d)

        # target geometry: regular, cell-centred lon/lat points
        lon_tmp = np.arange(-180, 180, dx) + dx/2
        lat_tmp = np.arange(-90, 90, dy) + dy/2
        self.new_grid_lon, self.new_grid_lat = np.meshgrid(lon_tmp, lat_tmp)
        self.new_grid = pyresample.geometry.GridDefinition(lons=self.new_grid_lon,
                                                           lats=self.new_grid_lat)

    def __call__(self, da, ax=None, lon_0=-60, tstr=None, **plt_kwargs):
        """Resample *da* and draw it with ``contour``.

        Parameters
        ----------
        da : DataArray with dimensions exactly {'face', 'j', 'i'}.
        ax : cartopy GeoAxes. If None, a new Robinson-projection axis
            centred on *lon_0* is created with ``plt.axes``.
        lon_0 : central longitude, used only when *ax* is None.
        tstr : optional title.
        **plt_kwargs : forwarded to ``ax.contour``.

        Returns
        -------
        (ax, p, gl) : the axis, the contour set and the gridliner.
        """
        assert set(da.dims) == set(['face', 'j', 'i']), "da must have dimensions ['face', 'j', 'i']"

        if ax is None:
            # the signature advertises ax=None; previously this crashed on None.contour
            import matplotlib.pyplot as plt
            ax = plt.axes(projection=cart.crs.Robinson(central_longitude=lon_0))

        # nearest llc value within 100 km (radius in metres); with
        # fill_value=None pyresample masks target points without a neighbour
        field = pyresample.kd_tree.resample_nearest(self.orig_grid, da.values,
                                                    self.new_grid,
                                                    radius_of_influence=100000,
                                                    fill_value=None)

        x, y = self.new_grid_lon, self.new_grid_lat
        p = ax.contour(x, y, field, transform=cart.crs.PlateCarree(), **plt_kwargs)

        ax.coastlines()
        # landcolor is a module-level setting defined elsewhere in the notebook
        ax.add_feature(cart.feature.LAND, facecolor=landcolor, edgecolor='k',  zorder=3, linewidth=.3)
        gl = ax.gridlines(zorder=4)

        if tstr is not None:
            ax.set_title('%s'%(tstr))

        return ax, p, gl
    
class LLCinterp:
    """Nearest-neighbour resampling of llc fields onto a regular lon/lat grid.

    The target grid is cell-centred with spacing *dx* x *dy* degrees over
    the whole globe. Calling the instance returns the resampled array; it
    does not plot.
    """

    def __init__(self, ds, dx=0.25, dy=0.25):
        # source: the llc cell centres, flattened
        self.orig_grid = pyresample.geometry.SwathDefinition(
            lons=ds.XC.values.ravel(), lats=ds.YC.values.ravel())

        # target: regular grid of cell centres
        centre_lons = np.arange(-180, 180, dx) + dx/2
        centre_lats = np.arange(-90, 90, dy) + dy/2
        self.new_grid_lon, self.new_grid_lat = np.meshgrid(centre_lons, centre_lats)
        self.new_grid = pyresample.geometry.GridDefinition(
            lons=self.new_grid_lon, lats=self.new_grid_lat)

    def __call__(self, da):
        """Return *da* (dims {'face', 'j', 'i'}) resampled onto the regular grid.

        Uses the nearest llc value within 100 km; target points without one
        are left undetermined (``fill_value=None``).
        """
        expected = set(['face', 'j', 'i'])
        assert set(da.dims) == expected, "da must have dimensions ['face', 'j', 'i']"
        return pyresample.kd_tree.resample_nearest(
            self.orig_grid, da.values, self.new_grid,
            radius_of_influence=100e3, fill_value=None)