# In [3]:
import numpy as np
import netCDF4 as nc
#import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as colors
#import data_cleaning as dc
#import pandas as pd
import calendar as cal 


# In [1]:
def rmse(arr,true_arr):
    """Return the root-mean-square error between ``arr`` and ``true_arr``.

    Parameters
    ----------
    arr : array_like
        Estimated values (typically a 1-d array). Lists are accepted;
        masked arrays keep their mask.
    true_arr : array_like or scalar
        Theoretical/true value(s), broadcast against ``arr``.

    Returns
    -------
    numpy.float64
        ``sqrt(mean((arr - true_arr)**2))`` taken over all elements.
    """
    # asanyarray (not asarray) so numpy masked arrays keep their mask.
    diff = np.asanyarray(arr) - true_arr
    # The result is a scalar, so no output buffer is pre-allocated.
    return np.sqrt(np.mean(np.square(diff)))

# In [2]:
def two_moment_adjust(otemp,omean,std_ratio, mmd, dtratio, mean_dtratio):
    """Apply the two-moment (mean shift + std-ratio scaling) adjustment.

    Computes::

        m   = omean - mean_dtratio * mmd
        num = otemp - dtratio * (mmd + (1 - std_ratio) * m)
        den = 1 - dtratio * (1 - std_ratio)

    Works element-wise on scalars or broadcastable arrays.

    Returns
    -------
    tuple
        ``(num, den, num / den)``.
    """
    # Observed mean with the mean-displacement contribution removed.
    shifted_mean = omean - mean_dtratio * mmd
    # Factor used by the std-ratio ("Sratio") scaling.
    spread = 1 - std_ratio
    num = otemp - dtratio * (mmd + spread * shifted_mean)
    den = 1 - dtratio * spread
    # Variant without Sratio scaling would be: num = otemp - dtratio*mmd; den = 1
    adjusted = num / den
    return num, den, adjusted

# def two_moment_adjust(observed_temp, observed_mean, std_ratio, mean_displacement, delta_temp_ratio, mean_delta_temp_ratio):
#     adjusted_mean = observed_mean - mean_delta_temp_ratio*mean_displacement
#     numerator = observed_temp - delta_temp_ratio*((1-std_ratio)*adjusted_mean+mean_displacement)
#     denominator = 1 - delta_temp_ratio*(1-std_ratio)
#     return numerator/denominator

def find_nearest(array, value):
    """Return the index of the element of ``array`` closest to ``value``.

    The search uses ``argmin`` without an axis, so multi-dimensional input
    yields a flat index. Ties resolve to the first occurrence.
    """
    distance = np.abs(np.asarray(array) - value)
    return distance.argmin()


########################################
def get_jan_1st_indices_obs(start_year,end_year):
    """Return the day offset of every January 1st from start_year to end_year.

    Offsets count days from Jan 1 of ``start_year`` (offset 0) on the real
    calendar, so each leap year contributes 366 days. The result is an int
    array of length ``end_year - start_year + 1``.
    """
    n_years = end_year - start_year + 1
    offsets = np.zeros(n_years, dtype=int)
    # Lengths of all years but the last. Their running total gives the
    # start offset of each following year.
    year_lengths = [366 if cal.isleap(yr) else 365 for yr in range(start_year, end_year)]
    if year_lengths:
        offsets[1:] = np.cumsum(year_lengths)
    return offsets

def get_yearday_indcs(start_year,end_year,year_day):
    """Return, for each year, the index of a given day of a 365-day calendar.

    ``year_day`` is a 0-based day-of-year on a no-leap calendar
    (0 = Jan 1, 58 = Feb 28, 59 = Mar 1). In leap years, days from Mar 1
    onward (``year_day >= 59``) are shifted by one so that Feb 29 is
    skipped. Indices are relative to Jan 1 of ``start_year``, as produced
    by ``get_jan_1st_indices_obs``.

    ``year_day=365`` gives the offset of the following Jan 1 in both leap
    and non-leap years. Callers use it as an exclusive slice end.

    Returns an int array of length ``end_year - start_year + 1``.
    """
    jan1_indcs = get_jan_1st_indices_obs(start_year, end_year)
    indcs = np.zeros(end_year - start_year + 1, dtype=int)
    for offset, year in enumerate(range(start_year, end_year + 1)):
        # Step over Feb 29 (0-based day 59 of a leap year).
        leap_shift = 1 if (cal.isleap(year) and year_day >= 59) else 0
        indcs[offset] = jan1_indcs[offset] + year_day + leap_shift
    return indcs

def get_feb29_indices(start_year,end_year):
    """Return the day indices of every Feb 29 from start_year to end_year.

    Indices are relative to Jan 1 of ``start_year`` (see
    ``get_jan_1st_indices_obs``). Feb 29 is 0-based day 59 of a leap year.
    Non-leap years contribute nothing, so the int array may be empty.
    """
    year_starts = get_jan_1st_indices_obs(start_year, end_year)
    leap_days = [
        year_starts[k] + 59
        for k, yr in enumerate(range(start_year, end_year + 1))
        if cal.isleap(yr)
    ]
    return np.array(leap_days, dtype=int)

def calc_annual_mean(data,start_year,end_year):
    """Average 1-d daily ``data`` over each calendar year.

    ``data[0]`` must be Jan 1 of ``start_year``, and days must be
    consecutive on the real calendar. Each year's slice runs from its Jan 1
    up to (excluding) the next Jan 1, so leap years include Feb 29.

    Returns a float array with one mean per year, start_year..end_year.
    """
    year_starts = get_jan_1st_indices_obs(start_year, end_year)
    year_stops = get_yearday_indcs(start_year,end_year,365)
    means = np.zeros(end_year - start_year + 1)
    for k, (lo, hi) in enumerate(zip(year_starts, year_stops)):
        means[k] = data[lo:hi].mean()
    return means
    
def calc_monthly_mean(data,start_year,end_year):
    """Return one mean per calendar year of 1-d daily ``data``.

    NOTE(review): despite its name, this computes exactly what
    ``calc_annual_mean`` computes: per-year means over day slices relative
    to Jan 1 of ``start_year``. It delegates to that function rather than
    duplicating its body, so behaviour is unchanged. A real monthly mean
    would need a different implementation.
    """
    return calc_annual_mean(data, start_year, end_year)

def get_model_means():
    """Read the smoothed PI and IC model-mean fields from netCDF files.

    Reads ``smoothed_pimeans`` from the file at the module-level
    ``pim_path`` and ``smoothed_icmeans`` from ``icm_path``. The arrays are
    presumably (365, 19, 42) = (day, lat, lon) — NOTE(review): confirm
    against the files.

    Returns
    -------
    dict
        ``{"ic": <smoothed IC means>, "pi": <smoothed PI means>}``.
    """
    # ``with`` closes each file even if a read fails. Slicing copies the
    # data into memory, so it stays valid after the file is closed.
    with nc.Dataset(pim_path) as pim:
        smooth_pim = pim.variables['smoothed_pimeans'][:,:,:]
    with nc.Dataset(icm_path) as icm:
        smooth_icm = icm.variables['smoothed_icmeans'][:,:,:]
    return {"ic": smooth_icm, "pi": smooth_pim}

def get_conus_latlon():
    """Read the latitude and longitude vectors stored in the PI-means file.

    Reads variables ``lats`` and ``lons`` from the file at the module-level
    ``pim_path``. Presumably this is the CONUS grid (per the function name),
    with 19 lats and 42 lons — NOTE(review): confirm.

    Returns
    -------
    dict
        ``{"lats": <latitudes>, "lons": <longitudes>}``.
    """
    # ``with`` closes the file even if a read fails.
    with nc.Dataset(pim_path) as pim2:
        lats = pim2.variables['lats'][:]
        lons = pim2.variables['lons'][:]
    return {"lats":lats,"lons":lons}


def get_model_stds():
    """Read the smoothed PI and IC model standard-deviation fields.

    Reads ``smoothed_pistdevs`` from the file at the module-level
    ``pis_path`` and ``smoothed_icstdevs`` from ``ics_path``. The arrays are
    presumably (365, 19, 42) = (day, lat, lon) — NOTE(review): confirm.

    Returns
    -------
    dict
        ``{"ic": <smoothed IC stdevs>, "pi": <smoothed PI stdevs>}``.
    """
    # ``with`` closes each file even if a read fails.
    with nc.Dataset(pis_path) as pis:
        smooth_pis = pis.variables['smoothed_pistdevs'][:,:,:]
    with nc.Dataset(ics_path) as ics:
        smooth_ics = ics.variables['smoothed_icstdevs'][:,:,:]
    return {"ic": smooth_ics, "pi": smooth_pis}

def get_tdratio(anomaly_name):
    """Return a netCDF anomaly series divided by ``tot_icmean - tot_pimean``.

    Reads the whole variable ``anomaly_name`` from the file at the
    module-level ``gta_path``. It divides by the difference of the
    module-level globals ``tot_icmean`` and ``tot_pimean``, which are
    defined elsewhere. The variable is presumably a global temperature
    anomaly — NOTE(review): confirm.
    """
    # ``with`` closes the file even if the read fails.
    with nc.Dataset(gta_path) as ga:
        anomaly = ga.variables[anomaly_name][:]
    return anomaly/(tot_icmean-tot_pimean)
    
def calc_ann_global_mean(data,spinup,nyears,var='tas'):
    """Return annual means of area-weighted global means of monthly data.

    ``data`` is indexed by ``'lat'``, ``'lat_bnds'`` and ``var``.
    ``data[var][i, :, :]`` is read per time step, and the weights are applied
    along its first axis. This assumes (time, lat, lon) ordering —
    TODO confirm.

    Latitude weights are cos(lat) * (upper_bnd - lower_bnd), normalized to
    sum to one. The longitude mean is unweighted. The first ``spinup`` years
    (12 steps each) are skipped, then ``nyears`` years are averaged. Weight
    sums are printed as a diagnostic.

    Returns a float array of length ``nyears``.
    """
    bounds = data['lat_bnds']
    band_width = np.deg2rad(bounds[:,1]) - np.deg2rad(bounds[:,0])
    wts = np.cos(np.deg2rad(data['lat'])) * band_width
    print(np.sum(wts))
    wts = wts / np.sum(wts)
    print('weights=', np.sum(wts))

    first_month = spinup * 12
    monthly = np.zeros(nyears * 12)
    for step in range(nyears * 12):
        field = data[var][first_month + step, :, :]
        monthly[step] = np.average(field, axis=0, weights=wts).mean()
    # Collapse each block of 12 consecutive months into one annual value.
    return np.array([monthly[12 * yr:12 * (yr + 1)].mean() for yr in range(nyears)])

    
def calc_global_mean(data,start_year,end_year,var='tasmax'):
    """Return annual means of area-weighted global means of daily data.

    ``data[var][i, :, :]`` is read per day, with ``i = 0`` on Jan 1 of
    ``start_year``. Latitude weights cos(lat) * (upper_bnd - lower_bnd) are
    normalized to sum to one and applied along that slice's first axis.
    This assumes (time, lat, lon) ordering — TODO confirm. The longitude mean
    is unweighted. Weight sums and the last daily value are printed as
    diagnostics.

    Returns one mean per calendar year (see ``calc_annual_mean``).
    """
    bounds = data['lat_bnds']
    band_width = np.deg2rad(bounds[:,1]) - np.deg2rad(bounds[:,0])
    wts = np.cos(np.deg2rad(data['lat'])) * band_width
    print(np.sum(wts))
    wts = wts / np.sum(wts)
    print('sum of weights=', np.sum(wts))

    # Day range: Jan 1 of start_year up to (excluding) Jan 1 after end_year.
    first_day = get_jan_1st_indices_obs(start_year, end_year)[0]
    last_day = get_yearday_indcs(start_year, end_year, 365)[-1]
    daily = np.zeros(last_day - first_day)
    for day in range(first_day, last_day):
        day_mean = np.average(data[var][day,:,:], axis=0, weights=wts).mean()
        daily[day - first_day] = day_mean
    print(day_mean)

    return calc_annual_mean(daily, start_year, end_year)

def calc_regional_amean(data,start_year,end_year,lat0,lat1,lon0,lon1,var='tasmax'):
    """Return annual means of area-weighted daily means over a lat/lon box.

    The box corners are snapped to the nearest grid points with
    ``find_nearest``, and both snapped end points are included. The corner
    order does not matter. Longitude wrap-around (a box crossing the grid
    seam) is not handled.

    Parameters
    ----------
    data : mapping
        Indexed by ``'lat'``, ``'lat_bnds'``, ``'lon'`` and ``var``.
        ``data[var][i, lat_slice, lon_slice]`` is read per day, with
        ``i = 0`` on Jan 1 of ``start_year``.
    start_year, end_year : int
        Inclusive year range.
    lat0, lat1, lon0, lon1 : float
        Box bounds in the units of ``data['lat']`` / ``data['lon']``.
    var : str, optional
        Variable to average. Defaults to ``'tasmax'``.

    Returns
    -------
    numpy.ndarray
        One mean per calendar year (see ``calc_annual_mean``).
    """
    lat_bnds = data['lat_bnds']
    lats = data['lat']
    lons = data['lon']
    print(lats)
    # Snap the box to grid indices. Sorting keeps the slices non-empty
    # whatever the argument order or coordinate direction.
    lt0, lt1 = sorted((find_nearest(lats,lat0), find_nearest(lats,lat1)))
    ln0, ln1 = sorted((find_nearest(lons,lon0), find_nearest(lons,lon1)))
    print(lt0,lt1,ln0,ln1)
    print('Computing regional annual mean for the following box:')
    print('lats=',lats[lt0],lats[lt1],'lons=',lons[ln0],lons[ln1])
    # +1 so the slices include the end points printed above.
    lat_sl = slice(lt0, lt1 + 1)
    lon_sl = slice(ln0, ln1 + 1)

    lower_bounds = np.deg2rad(lat_bnds[lat_sl,0])
    upper_bounds = np.deg2rad(lat_bnds[lat_sl,1])
    lat_rad = np.deg2rad(lats[lat_sl])
    weights = np.cos(lat_rad)*(upper_bounds-lower_bounds)
    print(np.sum(weights))
    weights = weights/np.sum(weights)
    print('sum of weights=',np.sum(weights))

    # Day range: Jan 1 of start_year up to (excluding) Jan 1 after end_year.
    start = get_jan_1st_indices_obs(start_year, end_year)[0]
    stop = get_yearday_indcs(start_year,end_year,365)[-1]
    temp = np.zeros(stop-start)
    for i in range(start,stop):
        box = data[var][i,lat_sl,lon_sl]
        spatial_mean = np.average(box, axis=0, weights=weights).mean()
        temp[i-start] = spatial_mean
    print(spatial_mean)

    return calc_annual_mean(temp,start_year,end_year)


def get_colors_list(vmin0,vmax0,thresh0,thresh1,binsize):
    """Build a ``seismic`` colour list whose [thresh0, thresh1] band is white.

    The data range [vmin0, vmax0] is mapped onto 100 RGBA rows:
    - values below ``thresh0`` sample the lower half of ``plt.cm.seismic``
      (0 to 0.5);
    - values above ``thresh1`` sample the upper half (0.5 to 1);
    - the band between the thresholds is filled with the midpoint colour,
      which is white for seismic.

    The thresholds must lie inside [vmin0, vmax0].

    Returns
    -------
    colors_list : ndarray
        Stacked RGBA rows, usable with a ``ListedColormap``.
    bins : ndarray
        ``np.arange(vmin0, vmax0 + binsize, binsize)`` (also printed).
    """
    span = vmax0 - vmin0
    frac_lo = (thresh0 - vmin0) / span
    frac_hi = (thresh1 - vmin0) / span
    # Row counts of the three segments, out of 100 rows in total.
    n_below = int(frac_lo * 100)
    n_above = int((1 - frac_hi) * 100)
    n_band = 100 - n_below - n_above
    seismic = plt.cm.seismic
    segments = (
        seismic(np.linspace(0, 0.5, n_below)),
        seismic(np.ones(n_band) * 0.5),
        seismic(np.linspace(0.5, 1, n_above)),
    )
    colors_list = np.vstack(segments)
    bins = np.arange(vmin0, vmax0 + binsize, binsize)
    print('bins=', bins)
    return colors_list, bins

def get_colors_list0(vmin0,vmax0,thresh0,thresh1,binsize):
    """Alias of ``get_colors_list``, kept for existing callers.

    This used to carry an identical copy of ``get_colors_list``'s body. It
    now delegates, so the two cannot drift apart.

    Returns
    -------
    tuple
        ``(colors_list, bins)``: stacked RGBA rows with the
        [thresh0, thresh1] band white, and
        ``np.arange(vmin0, vmax0 + binsize, binsize)``.
    """
    return get_colors_list(vmin0, vmax0, thresh0, thresh1, binsize)
    
def cmap_map(function, cmap):
    """Return a new colormap made by applying ``function`` to ``cmap``'s colours.

    ``function`` receives one RGB vector (length-3 array) at a time and must
    return a length-3 vector.

    ``cmap`` must expose ``_segmentdata`` with ``(x, y0, y1)`` rows for each of
    'red', 'green' and 'blue'. A ``LinearSegmentedColormap`` built from lists
    does this; callable segment data is not supported.

    ``function`` is evaluated only at the union of all channels' break
    points. Any discontinuity in ``cmap`` (``y0 != y1`` at a break point) is
    lost in the result.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        A colormap named ``'colormap'`` with 1024 entries.
    """
    # matplotlib.pyplot does not provide the matplotlib.colors module as an
    # attribute, so import the class explicitly.
    from matplotlib.colors import LinearSegmentedColormap

    channels = ('red', 'green', 'blue')
    cdict = cmap._segmentdata
    # First get the x positions where each channel's segments start or end.
    step_dict = {key: [row[0] for row in cdict[key]] for key in channels}
    step_list = np.array(list(set(sum(step_dict.values(), []))))
    # Sample the original map at every break point (RGB only, no alpha),
    # then apply the transform to each sampled colour.
    old_LUT = np.array([np.array(cmap(step)[0:3]) for step in step_list])
    new_LUT = np.array([function(rgb) for rgb in old_LUT])
    # Build a minimal segment definition: keep this channel's own break
    # points plus any point where the transform changed its value.
    new_cdict = {}
    for i, key in enumerate(channels):
        this_cdict = {}
        for j, step in enumerate(step_list):
            if step in step_dict[key] or new_LUT[j, i] != old_LUT[j, i]:
                this_cdict[step] = new_LUT[j, i]
        # Continuous segments: y0 == y1 at every break point.
        new_cdict[key] = sorted((x, y, y) for x, y in this_cdict.items())
    return LinearSegmentedColormap('colormap', new_cdict, 1024)


    
    