In [2]:
import xarray as xr

In [26]:
class xarray_IO:
    """
    Basic xarray interface for netcdf I/O.

    Wraps a single ``xarray.Dataset`` stored in ``self.ds``. The dataset is
    either opened from an existing file or created empty. Helpers add
    dimensions, variables and attributes before the dataset is written out.

    Parameters
    ----------
    dfile : str or path-like, optional
        File opened with ``xr.open_dataset``. If None, an empty
        ``xr.Dataset`` is created instead.
    engine : str, default 'netcdf4'
        Backend engine passed to ``Dataset.to_netcdf`` by ``write_netcdf``.
        It is NOT used when opening ``dfile``; xarray picks that engine itself.
    FV : float, default 1e20
        Written as the ``_FillValue`` encoding of every dimension/variable
        created through ``create_dimension`` / ``create_variable``.

    The instance can be used as a context manager. Leaving the ``with``
    block calls ``close()``.
    """

    def __init__(self, dfile=None, engine='netcdf4', FV=1.e20):
        if dfile is not None:
            # The file handle stays open (data is read lazily) until close().
            self.ds = xr.open_dataset(dfile)
        else:
            self.ds = xr.Dataset()

        self.engine = engine
        self._FV = FV

    def close(self):
        """Release any file handle held by ``self.ds``."""
        self.ds.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never suppress an exception raised inside the with-block.
        return False

    def copy_variable(self, dsin, field):
        """
        Copy variable ``field`` from ``dsin`` into ``self.ds``.

        The copy keeps the coordinates, attributes and encoding of the
        DataArray returned by ``dsin[field]``.
        """
        self.ds[field] = dsin[field]

    def create_dimension(self, var, field, **attributes):
        """
        Add a 1-D dimension coordinate named ``field`` with values ``var``.

        Parameters
        ----------
        var : array-like
            Coordinate values (e.g. latitudes).
        field : str
            Name used for both the dimension and its coordinate variable.
        **attributes
            netCDF attributes to attach (e.g. ``units='degrees_north'``).
        """
        self.ds[field] = xr.DataArray(var, coords=[(field, var)])
        self._apply_fill_value_and_attributes(var, field, attributes)

    def create_variable(self, var, field, dims, **attributes):
        """
        Add data variable ``field`` holding ``var`` laid out along ``dims``.

        Parameters
        ----------
        var : array-like or DataArray
            Data values. A DataArray keeps its coordinates.
        field : str
            Name of the new variable.
        dims : sequence of str
            Dimension names, one per axis of ``var``.
        **attributes
            netCDF attributes to attach (e.g. ``units='m'``).
        """
        self.ds[field] = xr.DataArray(var, dims=dims)
        self._apply_fill_value_and_attributes(var, field, attributes)

    def _apply_fill_value_and_attributes(self, var, field, attributes):
        """Set the configured ``_FillValue`` encoding and the attributes on ``field``."""
        # ``self.ds[field]`` shares its encoding/attrs dicts with the stored
        # variable, so mutating them here persists in ``self.ds``.
        self.ds[field].encoding['_FillValue'] = self._FV
        self.add_attributes(var, field, **attributes)

    def add_attributes(self, var, field, **attributes):
        """
        Set netCDF attributes on variable ``field`` of ``self.ds``.

        ``var`` is not used. It is kept so existing callers stay compatible.
        If ``field`` is not in the dataset, an error message is printed and
        nothing is changed (best effort, no exception).
        """
        try:
            target = self.ds[field]
        except KeyError:
            print('-- ERROR: variable {} is not defined'.format(field))
            return
        target.attrs.update(attributes)

    def write_netcdf(self, ofile):
        """Write ``self.ds`` to ``ofile``, overwriting it, using ``self.engine``."""
        self.ds.to_netcdf(path=ofile, mode='w', engine=self.engine)
        

In [30]:
import numpy as np

# --- Configuration -----------------------------------------------------------
# Alternative input: dfile = '../data/ERAInt.surf_geopot.0.75x0.75.nc'
dfile = '../data/ERAInt.t2m.ltm.0.75x0.75.nc'
ofile = './test.nc'
# True -> write the synthetic sin(r) test field instead of ERA-Interim t2m.
USE_SYNTHETIC_EXAMPLE = False
KELVIN_OFFSET = 273.15  # degC = K - 273.15


def add_synthetic_example(writer, ny=50, nx=100):
    """Add 'lat'/'lon' dimensions and a sin(r) 'Z' test field to ``writer``."""
    lat = np.linspace(-90., 90., ny)
    lon = np.linspace(0., 360., nx)

    y = np.linspace(-2. * np.pi, 2. * np.pi, ny)
    x = np.linspace(-2. * np.pi, 2. * np.pi, nx)
    X, Y = np.meshgrid(x, y)
    Z = np.sin(np.sqrt(X**2 + Y**2))

    writer.create_dimension(lat, 'lat',
                            units='degrees_north',
                            long_name='latitude')
    writer.create_dimension(lon, 'lon',
                            units='degrees_east',
                            long_name='longitude')
    writer.create_variable(Z, 'Z', dims=('lat', 'lon'),
                           units='m',
                           long_name='wave perturbation')


# --- Build the output dataset -------------------------------------------------
dsin = xarray_IO(dfile)
ds = xarray_IO()

if USE_SYNTHETIC_EXAMPLE:
    add_synthetic_example(ds)
else:
    t2m = dsin.ds['t2m']
    ds.copy_variable(dsin.ds, 't2m')

    # The Kelvin -> Celsius shift is only valid if the input really is in K.
    src_units = t2m.attrs.get('units')
    if src_units not in (None, 'K'):
        raise ValueError('expected t2m in K, got units={!r}'.format(src_units))

    # Arithmetic drops attrs, so carry over the descriptive (unit-free) ones.
    degc_attrs = {k: t2m.attrs[k] for k in ('long_name', 'standard_name')
                  if k in t2m.attrs}
    degc_attrs['units'] = 'degC'
    # Reuse the source dims rather than hard-coding their names/order.
    ds.create_variable(t2m - KELVIN_OFFSET, 't2m_degC', t2m.dims, **degc_attrs)

ds.write_netcdf(ofile)
# The lazily-read input data has now been written; release the file handle.
dsin.ds.close()