# Bias-correct atmospheric conditions

- The correction is a spatially varying field that is constant in time.
- It is calculated from the difference between the ERA5 1979–2015 time mean and the PACE ensemble- and time-mean.

In [1]:
import xarray as xr
import numpy as np
import cmocean
import sys
sys.path.append('/home/users/birgal/')
import pickle
from nemo_python_git.utils import fix_lon_range
from nemo_python_git.interpolation import regrid_array_cf, regrid_operator_cf
from nemo_python_git.forcing import find_cesm2_file
from nemo_python_git.ics_obcs import fill_ocean

In [2]:
# Root directory for this project's data. Other cells build paths by appending
# directly to it (e.g. f'{base_folder}ERA5-forcing/...'), so keep the trailing slash.
base_folder = '/gws/nopw/j04/anthrofail/birgal/NEMO_AIS/'

Start with the thermodynamic variables (TREFHT, QREFHT, FLDS, FSDS, PRECT, PRECS, PSL).

In [None]:
def atm_bias_correct(source, variable, expt='LE2', year_start=1979, year_end=2015, 
                     ensemble_mean_file=None, era5_mean_file=None, out_file=None):
    """Bias-correct one atmospheric variable against the ERA5 time mean.

    Only CESM2 is supported. The CESM2 ensemble/time mean and the ERA5 time mean
    are either read from precomputed NetCDF files or calculated. They are then
    passed to the thermodynamic or the katabatic-wind correction, depending on the
    variable.

    Parameters
    ----------
    source : str
        Forcing dataset to correct; only 'CESM2' is implemented.
    variable : str
        CESM2 variable name. 'TREFHT', 'QREFHT', 'FLDS' and 'FSDS' get the
        thermodynamic correction. 'UBOT' and 'VBOT' get the wind correction.
        Anything else is reported and skipped.
    expt : str
        CESM2 experiment name (default 'LE2').
    year_start, year_end : int
        Inclusive year range used when the time means have to be calculated.
    ensemble_mean_file, era5_mean_file : str or None
        Optional precomputed NetCDF files with the CESM2 ensemble/time mean and the
        ERA5 time mean.
    out_file : str or None
        Output path handed to the correction routine.

    Raises
    ------
    Exception
        If source is not 'CESM2'.
    """
    if source != 'CESM2':
        raise Exception("Bias correction currently only set up to correct CESM2, sorry you'll need to write some more code")

    # Decide which correction applies before doing any (potentially expensive) averaging
    if variable in ['TREFHT','QREFHT','FLDS','FSDS']:
        correction_type = 'thermo'
    elif variable in ['UBOT','VBOT']:
        correction_type = 'wind'
    else:
        print(f'Variable {variable} does not need bias correction. Check that this is true.')
        return

    # Read in ensemble time mean (or calculate it)
    if ensemble_mean_file:
        CESM2_time_mean = xr.open_dataset(ensemble_mean_file)
    else:
        CESM2_time_mean = cesm2_ensemble_time_mean_forcing(expt, variable, year_start=year_start, year_end=year_end)

    # Read in time mean of ERA5 files (or calculate it)
    if era5_mean_file:
        ERA5_time_mean = xr.open_dataset(era5_mean_file)
    else:
        # NOTE(review): era5_time_mean_forcing is not defined in this notebook section;
        # the visible helper is time_mean_forcing(source, variable). Confirm which is intended.
        ERA5_time_mean = era5_time_mean_forcing(expt, variable, year_start=year_start, year_end=year_end)

    # TODO: interpolate both time means to the eANT025 grid; they must be on the
    # same grid to do the correction.

    # NOTE(review): the thermo_correction/katabatic_correction definitions further down
    # take (grid_dir, ...) file-path arguments, which do not match these calls. Confirm.
    if correction_type == 'thermo':
        print('Correcting thermodynamics')
        thermo_correction(variable, CESM2_time_mean, ERA5_time_mean, out_file)
    else:
        print('Correcting katabatic winds')
        katabatic_correction(variable, CESM2_time_mean, ERA5_time_mean, out_file)

# Function calculates the time-mean over specified year range for mean of all CESM2 ensemble members in the specified experiment
# Input:
# - expt : string of CESM2 experiment name (e.g. 'LE2')
# - variable : string of forcing variable name
# - (optional) year_start : start year for time averaging
# - (optional) year_end   : end year for time averaging (inclusive)
# - (optional) out_file   : path to file to write time mean to NetCDF in case you want to store it
# Returns: the ensemble- and time-mean of variable (only the non-time, non-ensemble dimensions remain)
# Raises ValueError for an experiment with no known ensemble members or an empty year range.
def cesm2_ensemble_time_mean_forcing(expt, variable, year_start=1979, year_end=2015, out_file=None):

    if expt == 'LE2':
        ensemble_members = ['1001.001','1011.001','1021.002','1031.002','1041.003','1051.003','1061.004',
                            '1071.004','1081.005','1091.005','1101.006'] # consider adding this to a central python file to read in
    else:
        raise ValueError(f'Ensemble members are not defined for CESM2 experiment {expt}')
    if year_end < year_start:
        raise ValueError(f'year_end ({year_end}) must not be before year_start ({year_start})')

    # Calculate the ensemble mean for each year
    annual_means = []
    for year in range(year_start, year_end+1):
        # Build the file list directly (list.append returns None, so its result must not be assigned)
        files_to_open = [find_cesm2_file(expt, variable, 'atm', '1d', ens, year) for ens in ensemble_members]
        # Stack the members along a new 'ens' dimension
        files = xr.open_mfdataset(files_to_open, concat_dim='ens', combine='nested')
        # Average over both members and time so only the spatial dimensions remain
        annual_means.append(files[variable].mean(dim=['ens', 'time']))

    # Concatenate all years once, then calculate the time-mean of all ensemble means
    year_mean = xr.concat(annual_means, dim='year')
    time_mean = year_mean.mean(dim='year')
    if out_file:
        time_mean.to_netcdf(out_file)

    return time_mean

    
def time_mean_forcing(source, variable):
    """Return the time mean of an ERA5 forcing variable.

    All files matching ``{base_folder}ERA5-forcing/files/era5_{variable}*`` are
    opened as one dataset, and ``variable`` is averaged over its 'time' dimension.

    Parameters
    ----------
    source : str
        Currently unused; only ERA5 files are read.
    variable : str
        ERA5 variable name. It is used both in the file pattern and as the
        variable name inside the files.

    Returns
    -------
    xarray.DataArray
        The time mean. It may be lazily evaluated, because it comes from
        open_mfdataset.
    """
    # daily forcing files
    ERA5_ds = xr.open_mfdataset(f'{base_folder}ERA5-forcing/files/era5_{variable}*')
    ERA5_mean = ERA5_ds[variable].mean(dim='time')

    return ERA5_mean


def thermo_correction():
    """Placeholder for the thermodynamic correction; returns None.

    A later cell in this notebook redefines thermo_correction with arguments,
    which replaces this stub when the cells run top to bottom.
    """
    # CESM2 variable name -> ERA5 variable name (only TREFHT is filled in so far)
    cesm2_to_era5_names = dict(TREFHT='t2m', FSDS='', FLDS='', QREFHT='')
    return None

In [None]:
# Build a correction file for a thermodynamic variable, which will add a spatially-varying offset to UKESM/PACE data so that
# it matches ERA5 data in the time-mean.
def thermo_correction (grid_dir, var_name, cmip_file, era5_file, out_file, prec=64):
    """Plot and write the additive offset (ERA5 minus model) for var_name.

    Both files are read with read_netcdf. A 2D difference is shown as a single
    map. Otherwise the first 12 entries along the leading axis are shown as
    monthly panels, plus an unweighted mean over that axis. The difference is
    then written with write_binary at precision prec.
    """
    grid = Grid(grid_dir)
    cmip_data, era5_data = [read_netcdf(path, var_name) for path in (cmip_file, era5_file)]
    correction = era5_data - cmip_data

    if len(correction.shape) != 2:
        month_labels = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'June', 'July', 'Aug', 'Sept', 'Oct', 'Nov', 'Dec']
        fig, gs, cax = set_panels('3x4+1C1')
        cmap, vmin, vmax = set_colours(correction, ctype='plusminus')
        # (row, col, month index or None for the annual mean, title); months first, annual last
        panel_specs = [(m // 4 + 1, m % 4, m, label) for m, label in enumerate(month_labels)]
        panel_specs.append((0, 3, None, 'Annual'))
        for row, col, month, label in panel_specs:
            ax = plt.subplot(gs[row, col])
            field = correction[month, :] if month is not None else np.mean(correction, axis=0)
            img = ax.pcolormesh(field, cmap=cmap, vmin=vmin, vmax=vmax)
            ax.set_title(label)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.axis('tight')
        # img now belongs to the annual panel, which was drawn last
        plt.colorbar(img, cax=cax, orientation='horizontal')
        plt.text(0.05, 0.95, var_name+' correction', transform=fig.transFigure, fontsize=20, ha='left', va='top')
        finished_plot(fig, fig_name=var_name+'_correction.png')
    else:
        latlon_plot(correction, grid, ctype='plusminus', figsize=(10,6))
    write_binary(correction, out_file, prec=prec)

In [None]:
# Read thermodynamic forcing (atemp, aqh, precip, swdown, lwdown) from a given atmospheric dataset (source='ERA5' or 'PACE').
# Time-average, ensemble-average (if PACE) and interpolate to the MITgcm grid. Save the output to a NetCDF file.
# This will be used to create spatially-varying, time-constant bias correction files in the functions katabatic_correction
# and thermo_correction. Can also set monthly_clim=True to get monthly climatology instead of constant in time.
# Arguments:
# - source       : 'ERA5' or 'PACE'
# - mit_grid_dir : directory passed to Grid to build the MITgcm grid
# - out_file     : path of the NetCDF file to write
# - in_dir       : (optional) directory of the input binary forcing files; defaults to the BAS server paths below
# - start_year, end_year : inclusive year range to average over (end_year must be set)
# - monthly_clim : if True, save a 12-month climatology instead of a single time mean
# Prints an error and calls sys.exit() for an invalid source or year range.
# NOTE(review): only thermodynamic variables are written here, while katabatic_correction reads 'uwind'/'vwind'.
def process_forcing_for_correction (source, mit_grid_dir, out_file, in_dir=None, start_year=1979, end_year=None, monthly_clim=False):

    # Set parameters based on source dataset
    if source == 'ERA5':
        if in_dir is None:
            # Path on BAS servers
            in_dir = '/data/oceans_input/processed_input_data/ERA5/'
        file_head = 'ERA5_'
        gtype = ['t', 't', 't', 't', 't']
        per_day = 4
    elif source == 'PACE':
        if in_dir is None:
            # Path on BAS servers
            in_dir = '/data/oceans_input/processed_input_data/CESM/PACE_new/'
        file_head = 'PACE_ens'
        num_ens = 20
        # Ensemble member excluded from the ensemble mean below (presumably no data is available for it)
        missing_ens = 13
        var_names_in = ['TREFHT', 'QBOT', 'PRECT', 'FSDS', 'FLDS']
        monthly = [False, False, False, True, True]
        gtype = ['t', 't', 't', 't', 't']
    else:
        print('Error (process_forcing_for_correction): invalid source ' + source)
        sys.exit()
    
    # Set parameters based on variable type
    var_names = ['atemp', 'aqh', 'precip', 'swdown', 'lwdown']
    units = ['degC', '1', 'm/s', 'W/m^2', 'W/m^2']
    # Check end_year is defined
    if end_year is None:
        print('Error (process_forcing_for_correction): must set end_year. Typically use 2014 for WSFRIS and 2013 for PACE.')
        sys.exit()
    # An empty year range would leave nothing to average
    if end_year < start_year:
        print('Error (process_forcing_for_correction): end_year must not be before start_year.')
        sys.exit()

    mit_grid_dir = real_dir(mit_grid_dir)
    in_dir = real_dir(in_dir)

    print('Building grids')
    if source == 'ERA5':
        forcing_grid = ERA5Grid()
    elif source == 'PACE':
        forcing_grid = CAMGrid()
    mit_grid = Grid(mit_grid_dir)

    if monthly_clim:
        dim_code = 'xyt'
    else:
        dim_code = 'xy'
    ncfile = NCfile(out_file, mit_grid, dim_code)

    # Loop over variables
    for n in range(len(var_names)):
        print('Processing variable ' + var_names[n])
        # Read the data, time-integrating as we go
        data = None
        num_time = 0

        if source == 'ERA5':
            # Loop over years
            for year in range(start_year, end_year+1):
                file_path = in_dir + file_head + var_names[n] + '_' + str(year)
                data_tmp = read_binary(file_path, [forcing_grid.nx, forcing_grid.ny], 'xyt')
                if monthly_clim:
                    # Average over each month
                    data_sum = np.zeros([12, data_tmp.shape[1], data_tmp.shape[2]])
                    t = 0
                    for m in range(12):
                        nt = days_per_month(m+1, year)*per_day
                        data_sum[m,:] = np.mean(data_tmp[t:t+nt,:], axis=0)
                        t += nt
                    num_time += 1  # in years
                else:
                    # Integrate over entire year
                    data_sum = np.sum(data_tmp, axis=0)
                    num_time += data_tmp.shape[0]  # in timesteps
                if data is None:
                    data = data_sum
                else:
                    data += data_sum
                    
        elif source == 'PACE':
            # Loop over years
            for year in range(start_year, end_year+1):
                # Loop over ensemble members
                data_tmp = None
                num_ens_tmp = 0
                for ens in range(1, num_ens+1):
                    if ens == missing_ens:
                        # Skip the excluded member; num_ens_tmp counts only the members actually used
                        continue
                    file_path = in_dir + file_head + str(ens).zfill(2) + '_' + var_names_in[n] + '_' + str(year)
                    data_tmp_ens = read_binary(file_path, [forcing_grid.nx, forcing_grid.ny], 'xyt')
                    if data_tmp is None:
                        data_tmp = data_tmp_ens
                    else:
                        data_tmp += data_tmp_ens
                    num_ens_tmp += 1
                # Ensemble mean for this year
                data_tmp /= num_ens_tmp
                # Now accumulate time integral                    
                if monthly_clim:
                    data_sum = np.zeros([12, data_tmp.shape[1], data_tmp.shape[2]])
                    t = 0
                    for m in range(12):
                        if monthly[n]:
                            # Already have monthly averages
                            data_sum[m,:] = data_tmp[m,:]
                        else:
                            ndays = days_per_month(m+1, year, allow_leap=False)
                            data_sum[m,:] = np.mean(data_tmp[t:t+ndays,:], axis=0)
                            t += ndays
                    num_time += 1
                else:
                    if monthly[n]:
                        # Have to weight monthly averages
                        for m in range(12):
                            ndays = days_per_month(m+1, year, allow_leap=False)
                            data_tmp[m,:] *= ndays
                            num_time += ndays
                    else:
                        num_time += data_tmp.shape[0]
                    data_sum = np.sum(data_tmp, axis=0)
                if data is None:
                    data = data_sum
                else:
                    data += data_sum

        # Now convert from time-integral to time-average
        data /= num_time

        forcing_lon, forcing_lat = forcing_grid.get_lon_lat(gtype=gtype[n], dim=1)
        # Get longitude in the range -180 to 180, then split and rearrange so it's monotonically increasing        
        forcing_lon = fix_lon_range(forcing_lon)
        i_split = np.nonzero(forcing_lon < 0)[0][0]
        forcing_lon = split_longitude(forcing_lon, i_split)
        data = split_longitude(data, i_split)
        # Now interpolate to MITgcm tracer grid        
        mit_lon, mit_lat = mit_grid.get_lon_lat(gtype='t', dim=1)
        print('Interpolating')
        if monthly_clim:
            data_interp = np.empty([12, mit_grid.ny, mit_grid.nx])
            for m in range(12):
                print('...month ' + str(m+1))
                data_interp[m,:] = interp_reg_xy(forcing_lon, forcing_lat, data[m,:], mit_lon, mit_lat)
        else:
            data_interp = interp_reg_xy(forcing_lon, forcing_lat, data, mit_lon, mit_lat)
        print('Saving to ' + out_file)
        ncfile.add_variable(var_names[n], data_interp, dim_code, units=units[n])

    ncfile.close()

In [None]:
# Build katabatic correction files which scale and rotate the winds in a band around the coast. The arguments cmip_file and
# era5_file are the outputs of process_forcing_for_correction, for UKESM/PACE and ERA5 respectively.
# Update 13 March 2020: Can set bounds on region in domain to apply this correction to. For example, in PAS 
# can set xmin=-90 to only correct in the eastern part of the domain. 
def katabatic_correction (grid_dir, cmip_file, era5_file, out_file_scale, out_file_rotate, scale_dist=150., scale_cap=3, xmin=None, xmax=None, ymin=None, ymax=None, prec=64):
    """Write wind scaling and rotation fields that taper off with distance from the coast.

    grid_dir                        : directory passed to Grid to build the model grid
    cmip_file, era5_file            : NetCDF files with 'uwind' and 'vwind' for the model and ERA5
    out_file_scale, out_file_rotate : binary outputs for the scaling factor and the rotation angle (radians)
    scale_dist                      : distance offshore (km) over which the correction tapers to none
    scale_cap                       : upper limit on the ERA5/model wind-speed ratio
    xmin, xmax, ymin, ymax          : lon/lat bounds on the coastal points used (default: whole grid)
    prec                            : precision passed to write_binary

    Raises ValueError if no coastal point lies within the bounds.
    NOTE(review): process_forcing_for_correction, as written in this notebook, only saves
    thermodynamic variables; confirm where the 'uwind'/'vwind' fields come from.
    """
    var_names = ['uwind', 'vwind']
    # Radius for smoothing
    sigma = 2

    print('Building grid')
    grid = Grid(grid_dir)
    print('Selecting coastal points')
    coast_mask = grid.get_coast_mask(ignore_iceberg=True)
    lon_coast = grid.lon_2d[coast_mask].ravel()
    lat_coast = grid.lat_2d[coast_mask].ravel()
    if xmin is None:
        xmin = np.amin(grid.lon_2d)
    if xmax is None:
        xmax = np.amax(grid.lon_2d)
    if ymin is None:
        ymin = np.amin(grid.lat_2d)
    if ymax is None:
        ymax = np.amax(grid.lat_2d)

    print('Calculating winds in polar coordinates')
    magnitudes = []
    angles = []
    for fname in [cmip_file, era5_file]:
        u = read_netcdf(fname, var_names[0])
        v = read_netcdf(fname, var_names[1])
        magnitudes.append(np.sqrt(u**2 + v**2))
        angle = np.arctan2(v,u)
        angles.append(angle)

    print('Calculating corrections')
    # Take minimum of the ratio of ERA5 to CMIP wind magnitude, and the scale cap
    scale = np.minimum(magnitudes[1]/magnitudes[0], scale_cap)
    # Smooth and mask the land and ice shelf
    scale = mask_land_ice(smooth_xy(scale, sigma=sigma), grid)
    # Take difference in angles
    rotate = angles[1] - angles[0]
    # Both angles are in [-pi, pi], so one wrap brings the difference back into [-pi, pi]
    index = rotate < -np.pi
    rotate[index] += 2*np.pi
    index = rotate > np.pi
    rotate[index] -= 2*np.pi
    # Smoothing would be weird with the periodic angle, so just mask
    rotate = mask_land_ice(rotate, grid)

    print('Calculating distance from the coast')
    # Keep only the coastal points inside the requested bounds
    in_bounds = ~((lon_coast < xmin) | (lon_coast > xmax) | (lat_coast < ymin) | (lat_coast > ymax))
    if not np.any(in_bounds):
        # Otherwise min_dist would stay None and the taper below would fail obscurely
        raise ValueError('Error (katabatic_correction): no coastal points within the bounds xmin/xmax/ymin/ymax')
    min_dist = None
    # Loop over the selected coastal points
    for lon_pt, lat_pt in zip(lon_coast[in_bounds], lat_coast[in_bounds]):
        # Calculate distance of every point in the model grid to this specific coastal point, in km
        dist_to_pt = dist_btw_points([lon_pt, lat_pt], [grid.lon_2d, grid.lat_2d])*1e-3
        if min_dist is None:
            # Initialise the array
            min_dist = dist_to_pt
        else:
            # Figure out which cells have this coastal point as the closest one yet, and update the array
            index = dist_to_pt < min_dist
            min_dist[index] = dist_to_pt[index]

    print('Tapering function offshore')
    # Cosine function moving from scaling factor to 1 over distance of scale_dist km offshore
    scale_tapered = (min_dist < scale_dist)*(scale - 1)*np.cos(np.pi/2*min_dist/scale_dist) + 1
    # For the rotation, move from scaling factor to 0
    rotate_tapered = (min_dist < scale_dist)*rotate*np.cos(np.pi/2*min_dist/scale_dist)    

    print('Plotting')
    data_to_plot = [min_dist, scale_tapered, rotate_tapered]
    titles = ['Distance to coast (km)', 'Scaling factor', 'Rotation factor']
    ctype = ['basic', 'ratio', 'plusminus']
    fig_names = ['min_dist.png', 'scale.png', 'rotate.png']
    for i in range(len(data_to_plot)):
        # Plot each field twice: once without a file name, once saved to fig_names[i]
        for fig_name in [None, fig_names[i]]:
            latlon_plot(data_to_plot[i], grid, ctype=ctype[i], include_shelf=False, title=titles[i], figsize=(10,6), fig_name=fig_name)

    print('Writing to file')
    fields = [scale_tapered, rotate_tapered]
    out_files = [out_file_scale, out_file_rotate]
    for n in range(len(fields)):
        # Replace mask with zeros
        mask = fields[n].mask
        data = fields[n].data
        data[mask] = 0
        write_binary(data, out_files[n], prec=prec)

In [None]:
# Build a correction file for a thermodynamic variable, which will add a spatially-varying offset to UKESM/PACE data so that
# it matches ERA5 data in the time-mean.
def thermo_correction (grid_dir, var_name, cmip_file, era5_file, out_file, prec=64):
    """Plot and write the additive ERA5-minus-model offset for one variable.

    var_name is read from cmip_file and era5_file, and the model field is
    subtracted from the ERA5 field. A 2D result is shown on one map. Otherwise
    the first 12 entries along the leading axis are drawn as monthly panels,
    plus their unweighted mean as an 'Annual' panel. The offset is always
    written with write_binary at precision prec.
    """
    grid = Grid(grid_dir)
    fields = {}
    for key, path in (('cmip', cmip_file), ('era5', era5_file)):
        fields[key] = read_netcdf(path, var_name)
    diff = fields['era5'] - fields['cmip']

    if len(diff.shape) == 2:
        # Constant in time: a single map is enough
        latlon_plot(diff, grid, ctype='plusminus', figsize=(10,6))
    else:
        month_names = 'Jan Feb Mar Apr May June July Aug Sept Oct Nov Dec'.split()
        fig, gs, cax = set_panels('3x4+1C1')
        cmap, vmin, vmax = set_colours(diff, ctype='plusminus')

        def draw_panel(spec, field, title):
            # One undecorated panel on the shared colour scale; returns the mesh for the colour bar
            panel = plt.subplot(spec)
            mesh = panel.pcolormesh(field, cmap=cmap, vmin=vmin, vmax=vmax)
            panel.set_title(title)
            panel.set_xticks([])
            panel.set_yticks([])
            panel.axis('tight')
            return mesh

        # Months occupy rows 1 and below, four per row
        for idx, name in enumerate(month_names):
            row, col = divmod(idx, 4)
            draw_panel(gs[row+1, col], diff[idx,:], name)
        # The annual mean sits in row 0, column 3 and is drawn last
        mesh = draw_panel(gs[0,3], np.mean(diff, axis=0), 'Annual')
        plt.colorbar(mesh, cax=cax, orientation='horizontal')
        plt.text(0.05, 0.95, var_name+' correction', transform=fig.transFigure, fontsize=20, ha='left', va='top')
        finished_plot(fig, fig_name=var_name+'_correction.png')
    write_binary(diff, out_file, prec=prec)