In [1]:
### Importing all the necessary packages ###
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import cartopy.crs as ccrs
import cartopy
from matplotlib.offsetbox import AnchoredText
import cartopy.feature as cfeature
import scipy.fft as sf
from scipy import signal
from scipy.stats import circmean
from scipy import optimize
from mpl_toolkits.axes_grid1 import make_axes_locatable
import time
import intake


In [None]:
### Bombardi et al., 2019 has put the code on github under the MIT license, which allows us to utilize their code in any way we see fit. 
### We will of course credit Bombardi et al., 2019 in the eventual manuscript.

In [2]:
### Allows us to use dask to speed up some calculations ###
from dask.distributed import Client
client = Client()


In [3]:
client

0,1
Connection method: Cluster object,Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status,

0,1
Dashboard: http://127.0.0.1:8787/status,Workers: 4
Total threads: 8,Total memory: 16.00 GiB
Status: running,Using processes: True

0,1
Comm: tcp://127.0.0.1:60187,Workers: 4
Dashboard: http://127.0.0.1:8787/status,Total threads: 8
Started: Just now,Total memory: 16.00 GiB

0,1
Comm: tcp://127.0.0.1:60200,Total threads: 2
Dashboard: http://127.0.0.1:60204/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:60190,
Local directory: /var/folders/m4/3slgbrxj3z3dm65l82539j6w0000gq/T/dask-scratch-space/worker-yeb0t4qp,Local directory: /var/folders/m4/3slgbrxj3z3dm65l82539j6w0000gq/T/dask-scratch-space/worker-yeb0t4qp

0,1
Comm: tcp://127.0.0.1:60199,Total threads: 2
Dashboard: http://127.0.0.1:60203/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:60191,
Local directory: /var/folders/m4/3slgbrxj3z3dm65l82539j6w0000gq/T/dask-scratch-space/worker-cb_zf4vo,Local directory: /var/folders/m4/3slgbrxj3z3dm65l82539j6w0000gq/T/dask-scratch-space/worker-cb_zf4vo

0,1
Comm: tcp://127.0.0.1:60201,Total threads: 2
Dashboard: http://127.0.0.1:60205/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:60192,
Local directory: /var/folders/m4/3slgbrxj3z3dm65l82539j6w0000gq/T/dask-scratch-space/worker-ramu097l,Local directory: /var/folders/m4/3slgbrxj3z3dm65l82539j6w0000gq/T/dask-scratch-space/worker-ramu097l

0,1
Comm: tcp://127.0.0.1:60198,Total threads: 2
Dashboard: http://127.0.0.1:60202/status,Memory: 4.00 GiB
Nanny: tcp://127.0.0.1:60193,
Local directory: /var/folders/m4/3slgbrxj3z3dm65l82539j6w0000gq/T/dask-scratch-space/worker-hhlk_uc1,Local directory: /var/folders/m4/3slgbrxj3z3dm65l82539j6w0000gq/T/dask-scratch-space/worker-hhlk_uc1


In [4]:
catalog_url = 'https://ncar-cesm-lens.s3-us-west-2.amazonaws.com/catalogs/aws-cesm1-le.json'
col = intake.open_esm_datastore(catalog_url)
col

Unnamed: 0,unique
variable,78
long_name,75
component,5
experiment,4
frequency,6
vertical_levels,3
spatial_domain,5
units,25
start_time,12
end_time,13


In [5]:
col_subset = col.search(frequency=["daily"], component="atm", variable="PRECT",
                        experiment=["20C", "RCP85", "HIST"])

In [6]:
dsets = col_subset.to_dataset_dict(zarr_kwargs={"consolidated": True}, storage_options={"anon": True})
print(f"\nDataset dictionary keys:\n {dsets.keys()}")


--> The keys in the returned dictionary of datasets are constructed as follows:
	'component.experiment.frequency'


  dsets = col_subset.to_dataset_dict(zarr_kwargs={"consolidated": True}, storage_options={"anon": True})



Dataset dictionary keys:
 dict_keys(['atm.HIST.daily', 'atm.20C.daily', 'atm.RCP85.daily'])


In [15]:
ds_HIST = dsets['atm.HIST.daily']['PRECT']
ds_20C = dsets['atm.20C.daily']['PRECT']
ds_RCP85 = dsets['atm.RCP85.daily']['PRECT']


In [16]:
ds_20C

Unnamed: 0,Array,Chunk
Bytes,258.65 GiB,121.50 MiB
Shape,"(40, 31390, 192, 288)","(1, 576, 192, 288)"
Dask graph,2200 chunks in 2 graph layers,2200 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 258.65 GiB 121.50 MiB Shape (40, 31390, 192, 288) (1, 576, 192, 288) Dask graph 2200 chunks in 2 graph layers Data type float32 numpy.ndarray",40  1  288  192  31390,

Unnamed: 0,Array,Chunk
Bytes,258.65 GiB,121.50 MiB
Shape,"(40, 31390, 192, 288)","(1, 576, 192, 288)"
Dask graph,2200 chunks in 2 graph layers,2200 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [23]:
### Peru Domain ###
# Bounding box for the region of interest. 360 is added to the negative
# (western) longitudes; presumably the dataset uses a 0-360 longitude
# convention -- TODO confirm against ds_20C.lon.
min_lon = -83+360
min_lat = -18.0
max_lon = -67+360
max_lat = 0.0


# Cut the same latitude/longitude box out of each of the three experiments.
# The slices run min -> max, so this assumes both coordinates are stored in
# ascending order -- TODO confirm.
subset20C = ds_20C.sel(lat=slice(min_lat,max_lat), lon=slice(min_lon,max_lon))
subsetRCP85 = ds_RCP85.sel(lat=slice(min_lat,max_lat), lon=slice(min_lon,max_lon))
subsetHIST = ds_HIST.sel(lat=slice(min_lat,max_lat), lon=slice(min_lon,max_lon))


In [24]:
### Inspect the Peru subset of the 20C experiment ###
# `cropped_ds` is not defined in any cell shown in this notebook, so this cell
# fails on a fresh kernel. `subset20C` (defined in the previous cell) is the
# cropped 20C array and is displayed instead.
subset20C

Unnamed: 0,Array,Chunk
Bytes,1.16 GiB,555.75 kiB
Shape,"(40, 31390, 19, 13)","(1, 576, 19, 13)"
Dask graph,2200 chunks in 3 graph layers,2200 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 1.16 GiB 555.75 kiB Shape (40, 31390, 19, 13) (1, 576, 19, 13) Dask graph 2200 chunks in 3 graph layers Data type float32 numpy.ndarray",40  1  13  19  31390,

Unnamed: 0,Array,Chunk
Bytes,1.16 GiB,555.75 kiB
Shape,"(40, 31390, 19, 13)","(1, 576, 19, 13)"
Dask graph,2200 chunks in 3 graph layers,2200 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
### Data Directory For Narwhal ###
data_dir = '/data/deluge/reanalysis/REANALYSIS/ERA5/2D/daily/precip/'

In [None]:
### Opening all of the ERA5 daily precipitation data. The 'time': -1 entry loads the data such that each chunk holds the entire time series, but only 25x25 latitude/longitude points.
ds = xr.open_mfdataset(data_dir+'*.nc',parallel=True, chunks={'latitude': 25, 'longitude': 25, 'time': -1})


In [None]:
### Subsetting the data so we grab only North America.
subset=ds.sel(latitude=slice(50,30), longitude=slice((360.0-125),(360.0-65.0)))


In [None]:
### Showing the advantage of using Dask to perform operations ###
### Time without dask:  319 seconds.  
### Time with dask: 155 seconds.


start = time.time()



ds_subset = subset.load()
end = time.time()
print(end - start)


In [None]:
ds_subset

In [None]:
### Changing data to mm and removing leap days ###
# Scale into a new array rather than overwriting ds_subset['precip'].values in
# place: the in-place version multiplied the data by 1000 again every time the
# cell was re-run. (Arithmetic drops attrs by default, so a stale `units`
# attribute is not carried over.)
precip = ds_subset['precip'] * 1000.0
precip = precip.sel(time=~((precip.time.dt.month == 2) & (precip.time.dt.day == 29)))

### Calculating the mean precipitation for each grid point
annual_mean_precip = precip.mean(dim='time')

### Calculating the mean annual cycle ###
# NOTE: dayofyear is taken from the calendar date, so after dropping Feb 29
# leap years still run 1..366 (with day 60 missing) while other years run 1..365.
annual_precip_cycle = precip.groupby('time.dayofyear').mean(dim='time')

### Day-to-day spread (standard deviation) of the annual cycle ###
annual_precip_cycle_var = precip.groupby('time.dayofyear').std(dim='time')

### Create the annual daily precip anomalies ###
# Each day's anomaly is taken against the mean of its own calendar year.
annual_means = precip.groupby('time.year').mean(dim='time')

anomalies = precip.groupby('time.year') - annual_means

### Climatological annual cycle expressed as a departure from the long-term mean ###
climDailyMeanAnomaly = annual_precip_cycle - annual_mean_precip

In [None]:
### Code acquired from Bombardi et al. (2019), released on GitHub under the MIT license ###

In [None]:
"""
Function that calculates the Fourier coefficients and the explained variance of the
first N harmonics of a time series
Input:
   tseries: input time series
   nmodes : number of harmonics to retain (N)
   coefa  : Array with N (or 'nmodes') elements
   coefb  : Array with N (or 'nmodes') elements
   hvar   : Array with N (or 'nmodes') elements
   missval: Flag value for missing data
Output:
   coefa: Array of A coefficients of the first N harmonics
   coefb: Array of B coefficients of the first N harmonics
   hvar : Array of explained variance of the first N harmonics
"""



In [None]:
### Run the harmonics function over the data, one grid point at a time ###
### Slightly faster than a loop but still slow ###
# NOTE(review): `fourier1` is not defined in any cell shown here; it must exist
# in the kernel before this cell runs. The annual cycle is fully loaded into
# memory first and no `dask=` option is passed to apply_ufunc.
# "dayofyear" is the core dimension handed to fourier1 and is excluded from
# the result; vectorize=True makes apply_ufunc loop over the other dimensions.
start_wet2 = xr.apply_ufunc(
fourier1,
annual_precip_cycle.load(),
input_core_dims=[["dayofyear"]],
exclude_dims=set(["dayofyear"]),
vectorize=True,
)


In [None]:
start_wet2.plot()

In [None]:
array_jday = anomalies.time.dt.dayofyear
input_data = anomalies.assign_coords(jday=("time",array_jday.data))

In [None]:
start_wet2.to_netcdf('algorithm_start.nc')

In [None]:
### Used for testing individual areas ###
lat = 43.5
lon = 270.5
data = input_data.sel(latitude=lat, longitude=lon).values
data_time = input_data.sel(latitude=lat, longitude=lon).time.values
data_jday = input_data.sel(latitude=lat, longitude=lon).jday.values

start_test = start_wet2.sel(latitude=lat, longitude=lon).values

In [None]:

#========================================================================
#  Subroutine that calculates the beginning date of the rainy season
# for a time series of precipitation
# nyrs   --> integer for the number of years in the input dataset
# tot    --> total number of points for one year of data (365)
# mtot   --> total number of points in the whole precipitation dataset
# jday   --> an array of Julian days
# day    --> an array of days
# month  --> an array of months
# year   --> an array of years
# jstart --> Julian day of the climatological date when the calculation
#            should start
# precip --> a time series of precipitation anomalies (against mean annual daily)
# npass  --> integer for the number of "passes" for the smoothing of the
#            time series of accumulated precipitation anomalies
#========================================================================#def rainyseason_onset(nyrs,ytot,jday,day,month,year,jstart,precip,sjday,sday,smonth,syear,curve):
#def rainyseason_onset(anomaly_ds,start_wet):
#
### From the apply_ufunc for testing ###
data = input_data.sel(latitude=lat, longitude=lon).values
data_time = input_data.sel(latitude=lat, longitude=lon).time.values
data_jday = input_data.sel(latitude=lat, longitude=lon).jday.values
start_test = start_wet2.sel(latitude=lat, longitude=lon).values




In [None]:
input_data.sel(latitude=lat, longitude=lon, time=slice('1959','1960')).plot(figsize=(28,10))
input_data.sel(latitude=lat, longitude=lon, time=slice('1959','1960')).cumsum().plot()

In [None]:
test_onset = onset_LM01(data,data_time,start_test)

In [None]:
test_onset

In [None]:
start_wet2.sel(latitude=lat,longitude=lon)

In [None]:
test_demise[1:]- test_onset[:-1]

In [None]:
test_onset[:-1]

In [None]:
test_demise[1:]

In [None]:
n=0
input_data.sel(latitude=lat, longitude=lon, time=slice(test_onset[:-1][n]-50,test_demise[1:][n]+50)).plot(figsize=(28,10))
plt.axvline(test_onset[:-1][n], color='green')
plt.axvline(test_demise[1:][n],color='brown')

In [None]:
def onset_dunning(data, data_time, start_test):
    """Placeholder for a further onset method that has not been written yet.

    The original cell contained only the ``def`` line, which is a syntax error.
    NOTE(review): the name suggests the method of Dunning et al. -- confirm
    before implementing.

    Raises
    ------
    NotImplementedError
        Always, until the method is implemented.
    """
    raise NotImplementedError("onset_dunning has not been implemented yet")

In [None]:
data_time[0] - (data_time[0] - 40)

In [None]:
0 <= 0

In [None]:
def onset_LM01(data, time, startWet):
    """Locate the yearly onset date from accumulated anomalies.

    For every occurrence of day-of-year ``startWet`` in ``time``, the anomalies
    of that day and the following samples (180 in total, fewer when the series
    ends sooner) are accumulated, and the date at which the running sum reaches
    its minimum is stored as that year's onset.

    NOTE(review): "LM01" presumably refers to Liebmann and Marengo (2001) --
    confirm.

    Parameters
    ----------
    data : 1-D array
        Anomaly series aligned sample-by-sample with ``time``.
    time : array-like of datetimes
        Anything ``pd.DatetimeIndex`` accepts; assumed to be in chronological
        order.
    startWet : number
        1-indexed day of year at which each yearly search begins. Values <= 0
        are replaced by 1.

    Returns
    -------
    numpy.ndarray of datetime64[D]
        One entry per distinct calendar year in ``time`` (ascending). A year
        that does not contain day ``startWet`` keeps ``NaT``.
    """
    timeIndex = pd.DatetimeIndex(time)
    if startWet <= 0:
        startWet = 1  # dayofyear is 1-indexed

    timeValues = timeIndex.values
    sampleYear = np.asarray(timeIndex.year)
    years = np.unique(sampleYear)
    onsetDate = np.full(len(years), np.datetime64('NaT'), dtype='datetime64[D]')

    for start in np.where(timeIndex.dayofyear == startWet)[0]:
        # Slicing truncates automatically when fewer than 180 samples remain.
        runningSum = np.cumsum(data[start:start + 180])
        onset = start + np.argmin(runningSum)
        # Store by calendar year, not by occurrence count: a year without the
        # start day (e.g. day 60 in leap years once Feb 29 has been removed)
        # must not shift every later result into the wrong slot.
        slot = np.searchsorted(years, sampleYear[start])
        onsetDate[slot] = timeValues[onset]
    return onsetDate


In [None]:
def demise_LM01(data, time, startWet):
    """Locate the yearly demise date from backward-accumulated anomalies.

    For every occurrence of day-of-year ``startWet`` in ``time``, the anomalies
    are accumulated backwards in time over 180 samples (starting with the
    start day itself), and the date at which that running sum reaches its
    minimum is stored as the demise belonging to that start.

    NOTE(review): "LM01" presumably refers to Liebmann and Marengo (2001) --
    confirm.

    Parameters
    ----------
    data : 1-D array
        Anomaly series aligned sample-by-sample with ``time``.
    time : array-like of datetimes
        Anything ``pd.DatetimeIndex`` accepts; assumed to be in chronological
        order.
    startWet : number
        1-indexed day of year from which each backward search begins. Values
        <= 0 are replaced by 1.

    Returns
    -------
    numpy.ndarray of datetime64[D]
        One entry per distinct calendar year in ``time`` (ascending), indexed
        by the year of the start day. Entries stay ``NaT`` when the year has
        no day ``startWet`` or when fewer than 181 samples precede the start.
    """
    timeIndex = pd.DatetimeIndex(time)
    if startWet <= 0:
        startWet = 1  # dayofyear is 1-indexed

    timeValues = timeIndex.values
    sampleYear = np.asarray(timeIndex.year)
    years = np.unique(sampleYear)
    demiseDate = np.full(len(years), np.datetime64('NaT'), dtype='datetime64[D]')

    for start in np.where(timeIndex.dayofyear == startWet)[0]:
        stop = start - 180
        if stop <= 0:
            # Not enough history before this start; a negative stop would also
            # make the reversed slice wrap around.
            continue
        runningSum = np.cumsum(data[start:stop:-1])
        demise = start - np.argmin(runningSum)  # walk back from the start day
        # Store by calendar year, not by occurrence count, so a year without
        # the start day does not shift later results into the wrong slot.
        slot = np.searchsorted(years, sampleYear[start])
        demiseDate[slot] = timeValues[demise]
    return demiseDate

In [None]:
def onset_bombardi(data, data_time, startWet):
    """Yearly onset date: the sample after the minimum of accumulated anomalies.

    For every occurrence of day-of-year ``startWet`` the anomalies are
    accumulated over a window of ``int(365 / 2)`` samples (shortened so that it
    never includes the final sample of the series). The onset is the sample
    immediately after the first minimum of that running sum, provided it still
    lies before the end of the window.

    Parameters
    ----------
    data : 1-D array
        Anomaly series aligned sample-by-sample with ``data_time``.
    data_time : array-like of datetimes
        Anything ``pd.DatetimeIndex`` accepts; assumed to be in chronological
        order.
    startWet : number
        1-indexed day of year at which each yearly accumulation begins. Values
        <= 0 are replaced by 1.

    Returns
    -------
    numpy.ndarray of datetime64[D]
        One entry per distinct calendar year in ``data_time`` (ascending).
        Entries stay ``NaT`` when the year has no day ``startWet``, when the
        window is empty or contains NaN, or when the minimum falls on the last
        sample of the window.
    """
    halfYear = int(365 / 2)
    if startWet <= 0:
        startWet = 1  # dayofyear is 1-indexed

    timeIndex = pd.DatetimeIndex(data_time)
    timeValues = timeIndex.values
    sampleYear = np.asarray(timeIndex.year)
    years = np.unique(sampleYear)
    dataLength = len(data)
    sdate = np.full(len(years), np.datetime64('NaT'), dtype='datetime64[D]')

    for start in np.where(timeIndex.dayofyear == startWet)[0]:
        end = min(start + halfYear, dataLength - 1)
        runningSum = np.cumsum(data[start:end])
        if runningSum.size == 0:
            # Start day is the final sample: nothing to accumulate. The
            # original swallowed the resulting ValueError and then read an
            # undefined/stale result.
            continue
        minima = np.flatnonzero(runningSum == runningSum.min())
        if minima.size == 0:
            continue  # NaN in the window: the minimum is undefined
        onset = minima[0] + start + 1
        if onset < end:
            # The original always wrote to index -1 because its year counter
            # was never incremented; store by calendar year instead.
            slot = np.searchsorted(years, sampleYear[start])
            sdate[slot] = timeValues[onset]
    return sdate





#========================================================================
#                             End of subroutine
#========================================================================

In [None]:
def demise_calculation(data, data_time, startWet):
    """Yearly demise date, found by running the onset search backwards in time.

    The series and its time axis are reversed. For every occurrence of
    day-of-year ``startWet`` the anomalies are then accumulated over a window
    of ``int(365 / 2)`` samples (shortened so that it never includes the final
    sample of the reversed series). The demise is the sample immediately after
    the first minimum of that running sum in the reversed series, i.e. the day
    before it in calendar order, provided it still lies before the end of the
    window.

    Parameters
    ----------
    data : 1-D array
        Anomaly series aligned sample-by-sample with ``data_time``.
    data_time : array of datetimes
        Must support ``[::-1]`` and be accepted by ``pd.DatetimeIndex``;
        assumed to be in chronological order.
    startWet : number
        1-indexed day of year from which each backward accumulation begins.
        Values <= 0 are replaced by 1.

    Returns
    -------
    numpy.ndarray of datetime64[D]
        One entry per distinct calendar year in ``data_time`` (ascending),
        indexed by the year of the start day. Entries stay ``NaT`` when the
        year has no day ``startWet``, when the window is empty or contains NaN,
        or when the minimum falls on the last sample of the window.
    """
    halfYear = int(365 / 2)
    if startWet <= 0:
        startWet = 1  # dayofyear is 1-indexed

    # Reverse both series so the forward accumulation walks back in time.
    reversedData = data[::-1]
    timeIndex = pd.DatetimeIndex(data_time[::-1])
    timeValues = timeIndex.values
    sampleYear = np.asarray(timeIndex.year)
    years = np.unique(sampleYear)  # ascending, regardless of the reversal
    dataLength = len(reversedData)
    sdate = np.full(len(years), np.datetime64('NaT'), dtype='datetime64[D]')

    for start in np.where(timeIndex.dayofyear == startWet)[0]:
        end = min(start + halfYear, dataLength - 1)
        runningSum = np.cumsum(reversedData[start:end])
        if runningSum.size == 0:
            # Start day is the final sample of the reversed series. The
            # original swallowed the resulting ValueError and then read an
            # undefined/stale result.
            continue
        minima = np.flatnonzero(runningSum == runningSum.min())
        if minima.size == 0:
            continue  # NaN in the window: the minimum is undefined
        demise = minima[0] + start + 1
        if demise < end:
            # Store by calendar year. Filling slots by occurrence and reversing
            # at the end misplaced results whenever a year had no start day.
            slot = np.searchsorted(years, sampleYear[start])
            sdate[slot] = timeValues[demise]
    return sdate





In [None]:
### Recorded runtime: about 500 seconds ###

### 423 seconds without Dask... Need more optimization... ###
# Wall-clock timing of the apply_ufunc call below.
start = time.time()
"the code you want to test stays here"




# Run demise_LM01 for every grid point: "time" is the core dimension passed to
# the function (anomaly series and its time axis), start_wet2 supplies one
# scalar per grid point, and the result gets a new "year" dimension.
# vectorize=True makes apply_ufunc loop the function over the remaining dims.
demise_LM01_test = xr.apply_ufunc(
    demise_LM01,
    anomalies,
    anomalies.time,
    start_wet2,
    input_core_dims=[["time"],["time"],[]],
    exclude_dims=set(["time"]),
    output_core_dims=[["year"]],
    vectorize=True,
    dask = 'parallelized',
    #output_dtypes = 'datetime64[D]',
    #output_sizes={"data_jday": 71},
)
end = time.time()
print(end - start)

In [None]:
### Recorded runtime: about 500 seconds ###
# Wall-clock timing of the apply_ufunc call below.
start = time.time()
"the code you want to test stays here"




# Run onset_LM01 for every grid point: "time" is the core dimension passed to
# the function (anomaly series and its time axis), start_wet2 supplies one
# scalar per grid point, and the result gets a new "year" dimension.
onset_LM01_test = xr.apply_ufunc(
    onset_LM01,
    anomalies,
    anomalies.time,
    start_wet2,
    input_core_dims=[["time"],["time"],[]],
    exclude_dims=set(["time"]),
    output_core_dims=[["year"]],
    vectorize=True,
    dask = 'parallelized',
    #output_dtypes = 'datetime64[D]',
    #output_sizes={"data_jday": 71},
)
end = time.time()
print(end - start)

In [None]:
### Recorded runtime: about 500 seconds ###
# Wall-clock timing of the apply_ufunc call below.
start = time.time()
"the code you want to test stays here"




# Run onset_bombardi for every grid point: "time" is the core dimension passed
# to the function (anomaly series and its time axis), start_wet2 supplies one
# scalar per grid point, and the result gets a new "year" dimension.
onset_bombardi_test = xr.apply_ufunc(
    onset_bombardi,
    anomalies,
    anomalies.time,
    start_wet2,
    input_core_dims=[["time"],["time"],[]],
    exclude_dims=set(["time"]),
    output_core_dims=[["year"]],
    vectorize=True,
    dask = 'parallelized',
    #output_dtypes = 'datetime64[D]',
    #output_sizes={"data_jday": 71},
)
end = time.time()
print(end - start)

In [None]:
### Recorded runtime: about 500 seconds ###
# Wall-clock timing of the apply_ufunc call below.
start = time.time()
"the code you want to test stays here"




# Run demise_calculation for every grid point: "time" is the core dimension
# passed to the function (anomaly series and its time axis), start_wet2
# supplies one scalar per grid point, and the result gets a new "year" dimension.
demise_bombardi_test = xr.apply_ufunc(
    demise_calculation,
    anomalies,
    anomalies.time,
    start_wet2,
    input_core_dims=[["time"],["time"],[]],
    exclude_dims=set(["time"]),
    output_core_dims=[["year"]],
    vectorize=True,
    dask = 'parallelized',
    #output_dtypes = 'datetime64[D]',
    #output_sizes={"data_jday": 71},
)
end = time.time()
print(end - start)

In [None]:
# These are plain name bindings, not copies: setting .name / .coords below
# also changes onset_LM01_test and demise_LM01_test.
onset_data = onset_LM01_test
demise_data = demise_LM01_test
onset_data.name = 'onset_date'
demise_data.name = 'demise_date'
# Label the "year" dimension with 70 year-start timestamps beginning in 1951.
# NOTE(review): this assumes the "year" dimension has exactly 70 entries and
# that the first one is 1951 -- confirm against the input time range.
onset_data.coords['year'] = pd.date_range("1951", periods=70, freq='YS')
demise_data.coords['year'] = pd.date_range("1951", periods=70, freq='YS')

In [None]:
onset_data.isel(latitude = 20, longitude = 30)

In [None]:
test_onset = xr.merge([onset_data,demise_data])

In [None]:
onset_LM01_test.isel(latitude = 20, longitude = 30)[:-1] - demise_LM01_test.isel(latitude = 20, longitude = 30)[1:] 

In [None]:
onset_LM01_test = onset_LM01_test.isel(year=slice(0,70))

In [None]:
demise_LM01_test = demise_LM01_test.isel(year=slice(1,71))

In [None]:
test_onset['demise_doy'] = test_onset['demise_date'].dt.dayofyear
test_onset['onset_doy'] = test_onset['onset_date'].dt.dayofyear

In [None]:
test_onset.to_netcdf('OnsetDemise_ERA5.nc')

In [None]:
test_onset['onset_doy'].sel(year='2011').plot()

In [None]:
demise.to_netcdf('wetseason.demise.era5.nc')

In [None]:
onset.to_netcdf('wetseason.onset.era5.nc')

In [None]:
(test_mean - demise['demise_date'].mean(dim='year')).plot()

In [None]:
## TODO: Need to reverse the output array ###

demise['demise_date'].mean(dim='year').plot(figsize=(13,8))

In [None]:
test6['onset_date'].mean(dim='data_jday').plot(figsize=(13,8))

In [None]:
# Natural Earth state/province boundary lines at 10m scale.
# NOTE(review): this feature is created but not added to the axes in this
# cell; cfeature.STATES is used below instead.
states_provinces = cfeature.NaturalEarthFeature(
    category='cultural',
    name='admin_1_states_provinces_lines',
    scale='10m',
    facecolor='none')
# Lambert conformal projection centred on 95W, 45N for the map axes.
map_proj = ccrs.LambertConformal(central_longitude=-95, central_latitude=45)
#cmap = mpl.cm.RdBu_r


# Single-panel figure; the data are given in lat/lon, hence the PlateCarree
# transform. The colorbar is drawn manually further down.
f, ax1 = plt.subplots(1, 1, figsize=(10, 13), dpi=600, subplot_kw={'projection': map_proj})
# Day of year of the onset date for the year at positional index 50.
p = onset_data.dt.dayofyear.isel(year=50).plot.pcolormesh(ax=ax1,transform=ccrs.PlateCarree(), add_colorbar=False, cmap='viridis')


### Setting 1st plot parameters ###
ax1.coastlines(color='grey')
ax1.add_feature(cartopy.feature.BORDERS, color='black')
ax1.add_feature(cfeature.STATES, edgecolor='black')
#ax1.set_xticks(np.arange(-180,181, 40))
#ax1.set_yticks(np.arange(-90,91,15))
ax1.set_xlabel('Longitude')
ax1.set_ylabel('Latitude')
# Panel label "a" in the upper-left corner.
at = AnchoredText("a",
                      loc='upper left', prop=dict(size=8), frameon=True,)
at.patch.set_boxstyle("round,pad=0.,rounding_size=0.2")
ax1.add_artist(at)
# Attach a plain matplotlib axes to the right of the map for the colorbar.
divider = make_axes_locatable(ax1)
cax = divider.append_axes("right", size="5%", pad=0.05, axes_class=plt.Axes)
plt.colorbar(p, cax=cax)