In [1]:
import calendar
import datetime
import glob
from pathlib import Path

import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmocean
import iris
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import seaborn as sns
import xarray as xr
import xesmf as xe
from cartopy import config
from scipy import stats
from scipy.stats import ttest_ind

In [2]:
# Open a single model-run output file; in this notebook it is only read for
# its `lmask` variable.
ds = xr.open_dataset('/g/data/e14/cp3790/Charuni/ModelRun/umnsaa_pa000.nc')

# Module-level mask (presumably the land mask -- TODO confirm) that the
# plotting functions read as a global via mask.mean(dim='time') and .where(mask).
mask = ds.lmask

In [3]:
# Inclusive bounds of the years scanned by get_alt_year_xr_map
# (it iterates range(START_YEAR, END_YEAR+1)).
START_YEAR = 1981
END_YEAR = 2018

## Alternate runs 

In [18]:
def get_processed_var_xr(xr, var_name):
    """Apply the fixed conversion associated with ``var_name``.

    Note that the ``xr`` parameter shadows the module-level ``xarray`` alias
    inside this function: here it is the data to convert, not the library.

    Parameters
    ----------
    xr : object supporting ``- float`` and ``/ int``
        Values to convert.
    var_name : str
        Name of the variable the values belong to.

    Returns
    -------
    ``xr - 273.15`` for ``"ts_0"`` (presumably Kelvin to Celsius),
    ``xr / 100`` for ``"air_pressure_at_sea_level"`` (presumably Pa to hPa),
    and ``xr`` itself for every other name.
    """
    if var_name == "ts_0":
        converted = xr - 273.15
    elif var_name == "air_pressure_at_sea_level":
        converted = xr / 100
    else:
        converted = xr
    return converted

In [12]:
#def get_alt_year_xr_map(var_name, filepath_format):
#    year_xr_map = {}
#    for year in range(START_YEAR, END_YEAR+1):
#        filename = filepath_format.format(year)
#        if len(glob.glob(filename)) == 0:
#            continue
#        ds = None
#        if '*' in filename:
#            ds = xr.open_mfdataset(filename)
#        else:
#            ds = xr.open_dataset(filename)
#        var_xr = ds[var_name]
#        # Special processing for the variables
#        var_xr = get_processed_var_xr(var_xr, var_name)
#        year_xr_map[year] = var_xr
#
#    return year_xr_map  

In [1]:
##chai

def get_alt_year_xr_map(var_name, filepath_format):
    """Load ``var_name`` for every available year and cut out its 13-month window.

    For each year from ``START_YEAR`` to ``END_YEAR`` (inclusive) the path
    ``filepath_format.format(year)`` is globbed. Years without a matching file
    are skipped silently; years whose dataset lacks ``var_name`` are skipped
    with a printed message.

    Parameters
    ----------
    var_name : str
        Name of the variable to read from each dataset.
    filepath_format : str
        ``str.format`` template taking the year; may contain ``*`` wildcards,
        in which case the files are opened together with ``open_mfdataset``.

    Returns
    -------
    dict
        Maps year -> the converted variable (see ``get_processed_var_xr``)
        selected from 1 June of that year to 30 June of the following year.
    """
    slices_by_year = {}
    for year in range(START_YEAR, END_YEAR + 1):
        path_pattern = filepath_format.format(year)
        if not glob.glob(path_pattern):
            # Nothing on disk for this year.
            continue

        # A wildcard pattern needs the multi-file opener.
        opener = xr.open_mfdataset if '*' in path_pattern else xr.open_dataset
        dataset = opener(path_pattern)

        try:
            raw_var = dataset[var_name]
        except KeyError:
            print("Variable {} was not found. Skipping year {}".format(var_name, year))
            continue

        # Variable-specific unit conversion, then the June-to-June window.
        converted = get_processed_var_xr(raw_var, var_name)
        window = slice(
            "{}-06-01".format(year),
            "{}-06-30".format(year + 1),
        )
        slices_by_year[year] = converted.sel(time=window)

    return slices_by_year

## New control runs

In [69]:
# con_bundles = {
#     '1980': {
#         'xr': bundle_1980,
#         'start': 1981,
#         'end': 1998,
#     },
#     '1999': {
#         'xr': bundle_1999,
#         'start': 2000,
#         'end': 2018,
#     }
# }

In [70]:
def get_year_bundle_xr_map(con_bundle_info):
    """Expand bundle descriptions into a per-year lookup.

    Parameters
    ----------
    con_bundle_info : dict
        Maps a bundle label to a dict with keys ``'xr'`` (the bundle object),
        ``'start'`` and ``'end'`` (inclusive integer year bounds).

    Returns
    -------
    dict
        Maps every year in each ``start..end`` range to that bundle's ``'xr'``
        object. If ranges overlap, the bundle listed later wins.
    """
    return {
        year: info['xr']
        for info in con_bundle_info.values()
        for year in range(info['start'], info['end'] + 1)
    }

In [71]:
def get_con_year_xr_map(alt_year_xr_map, con_bundle_info, var_name):
    """Build the control-run counterpart of an alternate-run year map.

    Parameters
    ----------
    alt_year_xr_map : dict
        Year-keyed map from the alternate runs; only its keys are used, to
        decide which years to extract.
    con_bundle_info : dict
        Bundle description accepted by ``get_year_bundle_xr_map``.
    var_name : str
        Variable to read from the bundle covering each year.

    Returns
    -------
    dict
        Maps year -> the converted control variable selected from 1 June of
        that year to 30 June of the following year.

    Raises
    ------
    KeyError
        If a year of ``alt_year_xr_map`` is not covered by any bundle.
    """
    bundle_for_year = get_year_bundle_xr_map(con_bundle_info)
    con_slices = {}
    for year in alt_year_xr_map:
        raw_var = bundle_for_year[year][var_name]
        # Same unit conversion and June-to-June window as the alternate runs.
        converted = get_processed_var_xr(raw_var, var_name)
        window = slice(
            "{}-06-01".format(year),
            "{}-06-30".format(year + 1),
        )
        con_slices[year] = converted.sel(time=window)
    return con_slices

## Difference

In [3]:
def get_diff_year_xr_map(alt_year_xr_map, con_year_xr_map):
    """Compute monthly-mean (alternate - control) differences per shared year.

    Years present in only one of the two maps are reported with a printed
    message and ignored. For every shared year (in ascending order) both
    inputs are resampled to monthly means, subtracted, and the ``time``
    dimension is renamed to ``event_month``. Years whose result does not have
    exactly 13 ``event_month`` entries are skipped with a printed message.

    Parameters
    ----------
    alt_year_xr_map, con_year_xr_map : dict
        Year-keyed maps of objects with a ``time`` dimension.

    Returns
    -------
    dict
        Maps year -> monthly difference with an ``event_month`` dimension.
    """
    alt_years = set(alt_year_xr_map)
    con_years = set(con_year_xr_map)

    only_in_alt = alt_years - con_years
    if only_in_alt:
        print("Following years in alt_xr_map is not in con_xr_map: {}".format(only_in_alt))
    only_in_con = con_years - alt_years
    if only_in_con:
        print("Following years in con_xr_map is not in alt_xr_map: {}".format(only_in_con))

    diff_by_year = {}
    for year in sorted(alt_years & con_years):
        alt_monthly = alt_year_xr_map[year].resample(time='1M').mean()
        con_monthly = con_year_xr_map[year].resample(time='1M').mean()
        monthly_diff = (alt_monthly - con_monthly).rename({'time': 'event_month'})

        n_months = monthly_diff["event_month"].size
        if n_months != 13:
            print("{} has {} event_month coords after resampling. Skipping this year!".format(year, n_months))
            continue
        diff_by_year[year] = monthly_diff
    return diff_by_year

In [4]:
def coord_to_dim(coord_xr, coord="event_month", dim=None):
    """Re-stack ``coord_xr`` along a new dimension, one entry per ``coord`` index.

    Parameters
    ----------
    coord_xr : xarray object
        Input indexed by ``coord``.
    coord : str
        Name of the existing coordinate/dimension to walk over.
    dim : str, optional
        Name of the dimension to concatenate along. Defaults to
        ``"<coord>_dim"``.

    Returns
    -------
    The concatenation of ``coord_xr.isel({coord: i})`` for every index ``i``
    along ``dim``, or ``None`` (after printing a message) when ``coord`` and
    ``dim`` are the same name.
    """
    if coord == dim:
        print("Coord and dim cannot be the same name!")
        return None

    target_dim = "{}_dim".format(coord) if dim is None else dim
    n_entries = coord_xr[coord].size
    segments = [coord_xr.isel({coord: idx}) for idx in range(n_entries)]
    return xr.concat(segments, dim=target_dim)

def detect_and_rename_coords(
    xr_in,
    detect_coord="longitude",
    rename_map={"longitude_0": "longitude", "latitude_0": "latitude"}
):
    """Return ``xr_in`` with its coordinates renamed if ``detect_coord`` is missing.

    Parameters
    ----------
    xr_in : xarray object
        Object to inspect; must support ``xr_in[name]`` and ``.rename(mapping)``.
    detect_coord : str
        Name whose presence means no renaming is needed.
    rename_map : dict
        Old-name -> new-name mapping applied when ``detect_coord`` is absent.
        The default maps ``longitude_0``/``latitude_0`` to
        ``longitude``/``latitude``. It is only read, never modified, so the
        shared default dict is safe.

    Returns
    -------
    ``xr_in`` itself when ``detect_coord`` is present, otherwise the renamed
    object returned by ``xr_in.rename(rename_map)``.
    """
    try:
        # Only a missing key triggers the rename; any other failure is a real
        # error and is allowed to propagate instead of being swallowed.
        xr_in[detect_coord]
    except KeyError:
        print("Can't find '{}'. Renaming coords...".format(detect_coord))
        return xr_in.rename(rename_map)

    print("xr has '{}'. Not renaming.".format(detect_coord))
    return xr_in

In [9]:
# Scratch check of the key-intersection idiom used in get_diff_year_xr_map:
# two plain dicts that share the keys 'b' and 'c'.
d1 = {
    'a': 1,
    'b': 2,
    'c': 3
}

d2 = {
    'b': 2,
    'c': 3,
    'd': 4
}

# Sorted list of the shared keys; not used again in this cell.
i = list(set(d1.keys()) & set(d2.keys()))
i.sort()
# Cell output: the number of shared keys.
len(set(d1.keys()) & set(d2.keys()))

2

In [1]:
def get_year_daily_xr(ref_year_xr_map, year_xr_map, phase_dim='event_month'):
    """Stack the un-averaged buildup and peak windows of every year.

    For each year in ``ref_year_xr_map`` the matching entry of ``year_xr_map``
    is cut into two windows along ``phase_dim``:

    * buildup: September of ``year`` to February of ``year + 1``
    * peak: March to June of ``year + 1``

    Unlike ``get_year_2phases_xr`` the windows are not averaged. The two
    windows are concatenated along a new ``phase`` dimension, and the per-year
    results along a new ``time`` dimension.

    Parameters
    ----------
    ref_year_xr_map : dict
        Year-keyed map; only its keys are used, to choose the years.
    year_xr_map : dict
        Year-keyed map of the data to slice.
    phase_dim : str
        Name of the datetime-like dimension to select along.

    Returns
    -------
    The concatenation over years, with ``longitude_0``/``latitude_0`` renamed
    to ``longitude``/``latitude`` for any year that lacks ``longitude``.
    """
    phase_ranges = {
        "buildup": {
            "start": "09",
            "end": "02",
        },
        "peak": {
            "start": "03",
            "end": "06",
        }
    }

    year_phases = []
    for year in ref_year_xr_map:
        year_xr = year_xr_map[year]

        next_year = year + 1
        buildup_slice = slice("{}-{}".format(year, phase_ranges["buildup"]["start"]),
                              "{}-{}".format(next_year, phase_ranges["buildup"]["end"]))
        peak_slice = slice("{}-{}".format(next_year, phase_ranges["peak"]["start"]),
                           "{}-{}".format(next_year, phase_ranges["peak"]["end"]))

        buildup = year_xr.sel({phase_dim: buildup_slice})
        peak = year_xr.sel({phase_dim: peak_slice})

        year_phase_xr = xr.concat([buildup, peak], dim="phase")

        print("Checking coords of {} - ".format(year), end="")
        # hasattr only treats a missing attribute as "absent"; any other
        # error raised while probing propagates instead of being swallowed.
        if hasattr(year_phase_xr, "longitude"):
            print("has 'longitude'", end="")
        else:
            # Has 'longitude_0' and 'latitude_0'
            print("has 'longitude_0' -> renaming coords...", end="")
            year_phase_xr = year_phase_xr.rename(longitude_0='longitude', latitude_0='latitude')
        print("")
        year_phases.append(year_phase_xr)

    return xr.concat(year_phases, dim="time")

## Phases

New approach: two phases.
September to February (high sea surface temperature anomaly, SSTA, and a strong surface air temperature, SAT, response) and March to June (SSTA is still high, but the SAT response is relatively weak).

In [2]:
def get_year_2phases_xr(ref_year_xr_map, year_xr_map, phase_dim='event_month'):
    """Average every year into two phases: buildup and peak.

    For each year in ``ref_year_xr_map`` the matching entry of ``year_xr_map``
    is averaged over two windows along ``phase_dim``:

    * buildup: September of ``year`` to February of ``year + 1``
    * peak: March to June of ``year + 1``

    The two means are concatenated along a new ``phase`` dimension (index 0 =
    buildup, 1 = peak), and the per-year results along a new ``time``
    dimension.

    Parameters
    ----------
    ref_year_xr_map : dict
        Year-keyed map; only its keys are used, to choose the years.
    year_xr_map : dict
        Year-keyed map of the data to average.
    phase_dim : str
        Name of the datetime-like dimension to select and average along.

    Returns
    -------
    The concatenation over years, with ``longitude_0``/``latitude_0`` renamed
    to ``longitude``/``latitude`` for any year that lacks ``longitude``.
    """
    phase_ranges = {
        "buildup": {
            "start": "09",
            "end": "02",
        },
        "peak": {
            "start": "03",
            "end": "06",
        }
    }

    year_phases = []
    for year in ref_year_xr_map:
        year_xr = year_xr_map[year]

        next_year = year + 1
        buildup_slice = slice("{}-{}".format(year, phase_ranges["buildup"]["start"]),
                              "{}-{}".format(next_year, phase_ranges["buildup"]["end"]))
        peak_slice = slice("{}-{}".format(next_year, phase_ranges["peak"]["start"]),
                           "{}-{}".format(next_year, phase_ranges["peak"]["end"]))

        buildup = year_xr.sel({phase_dim: buildup_slice}).mean(dim=phase_dim)
        peak = year_xr.sel({phase_dim: peak_slice}).mean(dim=phase_dim)

        year_phase_xr = xr.concat([buildup, peak], dim="phase")

        print("Checking coords of {} - ".format(year), end="")
        # hasattr only treats a missing attribute as "absent"; any other
        # error raised while probing propagates instead of being swallowed.
        if hasattr(year_phase_xr, "longitude"):
            print("has 'longitude'", end="")
        else:
            # Has 'longitude_0' and 'latitude_0'
            print("has 'longitude_0' -> renaming coords...", end="")
            year_phase_xr = year_phase_xr.rename(longitude_0='longitude', latitude_0='latitude')
        print("")
        year_phases.append(year_phase_xr)

    return xr.concat(year_phases, dim="time")

In [13]:
def get_year_phases_xr(ref_year_xr_map, year_xr_map, phase_dim='event_month'):
    """Average every year into three phases: buildup, peak and diedown.

    For each year in ``ref_year_xr_map`` the matching entry of ``year_xr_map``
    is averaged over three consecutive windows along ``phase_dim``:

    * buildup: June of ``year`` to January of ``year + 1``
    * peak: February to April of ``year + 1``
    * diedown: May to June of ``year + 1``

    Together the windows tile the 13-month June-to-June span. The buildup
    window has to end, and the peak window has to start, in ``year + 1``:
    ending the buildup in January of ``year`` would put the slice end before
    its start and select nothing.

    The three means are concatenated along a new ``phase`` dimension (index
    0 = buildup, 1 = peak, 2 = diedown), and the per-year results along a new
    ``time`` dimension.

    Parameters
    ----------
    ref_year_xr_map : dict
        Year-keyed map; only its keys are used, to choose the years.
    year_xr_map : dict
        Year-keyed map of the data to average.
    phase_dim : str
        Name of the datetime-like dimension to select and average along.

    Returns
    -------
    The concatenation over years, with ``longitude_0``/``latitude_0`` renamed
    to ``longitude``/``latitude`` for any year that lacks ``longitude``.
    """
    phase_ranges = {
        "buildup": {
            "start": "06",
            "end": "01",
        },
        "peak": {
            "start": "02",
            "end": "04",
        },
        "diedown": {
            "start": "05",
            "end": "06",
        },
    }

    year_phases = []
    for year in ref_year_xr_map:
        year_xr = year_xr_map[year]

        next_year = year + 1
        # Buildup crosses the year boundary (Jun year -> Jan next_year).
        buildup_slice = slice("{}-{}".format(year, phase_ranges["buildup"]["start"]),
                              "{}-{}".format(next_year, phase_ranges["buildup"]["end"]))
        # Peak and diedown both lie entirely in next_year.
        peak_slice = slice("{}-{}".format(next_year, phase_ranges["peak"]["start"]),
                           "{}-{}".format(next_year, phase_ranges["peak"]["end"]))
        diedown_slice = slice("{}-{}".format(next_year, phase_ranges["diedown"]["start"]),
                              "{}-{}".format(next_year, phase_ranges["diedown"]["end"]))

        buildup = year_xr.sel({phase_dim: buildup_slice}).mean(dim=phase_dim)
        peak = year_xr.sel({phase_dim: peak_slice}).mean(dim=phase_dim)
        diedown = year_xr.sel({phase_dim: diedown_slice}).mean(dim=phase_dim)

        year_phase_xr = xr.concat([buildup, peak, diedown], dim="phase")

        print("Checking coords of {} - ".format(year), end="")
        # hasattr only treats a missing attribute as "absent"; any other
        # error raised while probing propagates instead of being swallowed.
        if hasattr(year_phase_xr, "longitude"):
            print("has 'longitude'", end="")
        else:
            # Has 'longitude_0' and 'latitude_0'
            print("has 'longitude_0' -> renaming coords...", end="")
            year_phase_xr = year_phase_xr.rename(longitude_0='longitude', latitude_0='latitude')
        print("")
        year_phases.append(year_phase_xr)

    return xr.concat(year_phases, dim="time")

## Months

In [7]:
# time (13 months) x longitude x latitude
def get_year_months_xr(ref_year_xr_map, year_xr_map):
    """Average each of the 13 event months across all years.

    Every year's data is resampled to monthly means once; then, for each
    month index 0..12, that month is taken from every year, concatenated along
    ``time`` and averaged. The 13 across-year means are finally concatenated
    along ``time``.

    NOTE: a later cell of this notebook defines another function with the
    same name, which replaces this one once that cell has run.

    Parameters
    ----------
    ref_year_xr_map : dict
        Year-keyed map; only its keys are used, to choose the years.
    year_xr_map : dict
        Year-keyed map of objects with a ``time`` dimension. Each must yield
        at least 13 monthly means, otherwise ``isel`` raises ``IndexError``.

    Returns
    -------
    The across-year mean of each event month, stacked along ``time``.
    """
    n_months = 13

    # The monthly resample does not depend on the month index, so compute it
    # once per year instead of once per (month, year) pair.
    monthly_by_year = [
        year_xr_map[year].resample(time='1M').mean()
        for year in ref_year_xr_map
    ]

    year_months = []
    for m in range(n_months):
        same_month = [monthly.isel(time=m) for monthly in monthly_by_year]
        year_months.append(xr.concat(same_month, dim="time").mean(dim="time"))

    return xr.concat(year_months, dim="time")

In [2]:
# year (36 years) x month (13) x longitude x latitude
def get_year_and_months_xr(ref_year_xr_map, year_xr_map):
    """Arrange monthly means as a (year, month, ...) stack.

    For each year in ``ref_year_xr_map`` the matching entry of ``year_xr_map``
    is resampled to monthly means. Years that do not yield exactly 13 monthly
    values are skipped with a printed message. The 13 months are concatenated
    along a new ``month`` dimension, and the per-year results along a new
    ``year`` dimension.

    Parameters
    ----------
    ref_year_xr_map : dict
        Year-keyed map; only its keys are used, to choose the years.
    year_xr_map : dict
        Year-keyed map of objects with a ``time`` dimension.

    Returns
    -------
    The concatenation over years, with ``longitude_0``/``latitude_0`` renamed
    to ``longitude``/``latitude`` for any year that lacks ``longitude``.
    """
    n_months = 13
    year_month_segments_list = []
    for year in ref_year_xr_map:
        # Monthly means, still indexed by the `time` coordinate.
        monthly_xr = year_xr_map[year].resample(time='1M').mean()
        time_coord_size = monthly_xr["time"].size
        if time_coord_size != n_months:
            print("{} has {} time coords after resampling. Skipping this year!".format(year, time_coord_size))
            continue

        # Re-stack the months along a positional `month` dimension.
        year_month_segments_xr = xr.concat(
            [monthly_xr.isel(time=m) for m in range(n_months)],
            dim='month'
        )

        print("Checking coords of {} - ".format(year), end="")
        # hasattr only treats a missing attribute as "absent"; any other
        # error raised while probing propagates instead of being swallowed.
        if hasattr(year_month_segments_xr, "longitude"):
            print("has 'longitude'", end="")
        else:
            # Has 'longitude_0' and 'latitude_0'
            print("has 'longitude_0' -> renaming coords...", end="")
            year_month_segments_xr = year_month_segments_xr.rename(
                longitude_0='longitude', latitude_0='latitude'
            )
        print("")
        year_month_segments_list.append(year_month_segments_xr)

    return xr.concat(year_month_segments_list, dim="year")

In [5]:
# chai trying to get year and month 

def get_year_months_xr(ref_year_xr_map, year_xr_map):
    """Concatenate the monthly means of every year along ``time``.

    NOTE: this redefines ``get_year_months_xr`` from an earlier cell of this
    notebook; once this cell has run, the earlier definition is shadowed.

    For each year in ``ref_year_xr_map`` the matching entry of ``year_xr_map``
    is resampled to monthly means and passed directly to ``xr.concat`` along
    ``time`` (presumably this iterates the monthly array and re-stacks it --
    TODO confirm). The per-year results are then concatenated along ``time``.

    Parameters
    ----------
    ref_year_xr_map : dict
        Year-keyed map; only its keys are used, to choose the years.
    year_xr_map : dict
        Year-keyed map of objects with a ``time`` dimension.

    Returns
    -------
    The result of concatenating all per-year monthly series along ``time``.
    """
    per_year = [
        xr.concat(year_xr_map[year].resample(time='1M').mean(), dim='time')
        for year in ref_year_xr_map
    ]
    return xr.concat(per_year, dim="time")

## T test

In [3]:
def get_sig_da(year_phases_xr, dim1='latitude', dim2='longitude'):
    """Compute per-grid-point p-values of a one-sample t-test for phases 0 and 1.

    For each of the first two ``phase`` entries, the samples along axis 0 are
    tested against a population mean of 0 with NaNs omitted, and the p-values
    are wrapped in a DataArray over ``(dim1, dim2)``.

    Only phase indices 0 and 1 are tested; a third phase, if present in the
    input, is ignored.

    Parameters
    ----------
    year_phases_xr : xarray.DataArray
        Data with a ``phase`` dimension; after selecting one phase, axis 0 is
        the sample axis (presumably ``time`` -- TODO confirm) and the
        remaining axes are ``(dim1, dim2)``.
    dim1, dim2 : str
        Names of the two spatial dimensions, used for the output dims/coords.

    Returns
    -------
    xarray.DataArray
        P-values for phase 0 and phase 1, concatenated along ``phase``.
    """
    def _phase_pvalues(phase_index):
        # Index 1 of the t-test result is the p-value array.
        test_result = stats.ttest_1samp(
            year_phases_xr.isel(phase=phase_index), 0, axis=0, nan_policy='omit'
        )
        return xr.DataArray(
            test_result[1],
            dims=(dim1, dim2),
            coords={
                dim1: year_phases_xr[dim1],
                dim2: year_phases_xr[dim2],
            },
        )

    return xr.concat([_phase_pvalues(0), _phase_pvalues(1)], dim='phase')

## Difference Maps (Multiple plots)

In [6]:
## new diff plot (phases)

def multiple_plots_aus(
    year_phases_xr,
    sig_da,
    exclude_ocean=False,
    save_path='/g/data/e14/cp3790/Charuni/ModelExperiment/Revisions/thesis_fig3.5a_aus.png',
):
    """Plot the two-phase anomaly maps over the Australian window and save them.

    Draws one panel per phase ('Sep-Feb', 'March-June') on a PlateCarree map
    limited to 140-155E, 48-33S. Each panel shows the time-mean of
    ``year_phases_xr`` (colour range -1..1), the 0.5 contour of the time-mean
    module-level ``mask``, and dot hatching where the masked, time-mean
    ``sig_da`` lies between 0 and 0.05.

    Parameters
    ----------
    year_phases_xr : xarray.DataArray
        Data with ``time`` and (leading, after the time mean) ``phase``
        dimensions plus ``longitude``/``latitude``.
    sig_da : xarray.DataArray
        P-values indexed by phase first, e.g. the output of ``get_sig_da``.
    exclude_ocean : bool
        If True, the plotted data are restricted to where ``mask`` is truthy.
    save_path : str
        Where the figure is written. Defaults to the path this function
        originally hard-coded.

    Notes
    -----
    Relies on the module-level ``mask`` having a ``time`` dimension.
    """
    nrow = 1
    ncol = 2

    fig, ax = plt.subplots(nrows=nrow, ncols=ncol, figsize=(16, 9),
                           subplot_kw={'projection': ccrs.PlateCarree()})

    data = year_phases_xr.mean(dim='time')
    if exclude_ocean:
        # where(mask) broadcasts against mask's time axis, so average it out again.
        data = data.where(mask).mean(dim='time')

    phases = ['Sep-Feb', 'March-June']
    contour = None
    mask_mean = mask.mean(dim='time')
    # plt.get_cmap replaces plt.cm.get_cmap, which was deprecated and then
    # removed in recent matplotlib releases.
    cmap = plt.get_cmap('bwr', 40)

    for col in range(ncol):
        cur_data = data[col]
        contour = cur_data.plot(
            ax=ax[col],
            cmap=cmap,
            vmin=-1, vmax=1,
            add_colorbar=False,
            extend='both'
        )

        # Outline of the mask at its 0.5 level.
        ax[col].contour(
            mask_mean.longitude,
            mask_mean.latitude,
            mask_mean.values,
            levels=[0.5],
            cmap='gray')

        phase_mask_mean = sig_da[col].where(mask).mean(dim='time')

        # Hatch only the 0-0.05 band. `add_colorbar` is an xarray plotting
        # keyword, not a matplotlib contourf one, so it is not passed here.
        ax[col].contourf(
            phase_mask_mean.longitude,
            phase_mask_mean.latitude,
            phase_mask_mean.values,
            levels=[0.00, 0.05, 1.00],
            hatches=['.', None],
            colors='none',
        )

        ax[col].set_extent([140, 155, -48, -33], crs=ccrs.PlateCarree())
        ax[col].set_xticks([140, 145, 150, 155], crs=ccrs.PlateCarree())
        ax[col].set_yticks([-48, -45, -40, -35, -33], crs=ccrs.PlateCarree())

        ax[col].set_title(phases[col], fontsize=20)

    cbar = fig.colorbar(contour, ax=ax.ravel().tolist(), orientation="horizontal",
                        ticks=np.arange(-1, 1.05, 0.5),
                        fraction=0.035, aspect=25, extend='both')
    cbar.set_label(label=u'Surface air temperature anomaly (\u00B0C)', fontsize=24)
    cbar.ax.tick_params(labelsize=20)

    #plt.xlabel('Longitude',fontsize=20)
    #plt.ylabel('Latitude', fontsize=20)
    plt.savefig(save_path)

In [7]:
## new diff plot (phases)

def multiple_plots_nz(
    year_phases_xr,
    sig_da,
    exclude_ocean=False,
    save_path='/g/data/e14/cp3790/Charuni/ModelExperiment/Revisions/thesis_fig3.5a_nz.png',
):
    """Plot the two-phase anomaly maps over the New Zealand window and save them.

    Draws one panel per phase ('Sep-Feb', 'March-June') on a PlateCarree map
    limited to 165-180E, 49-34S. Each panel shows the time-mean of
    ``year_phases_xr`` (colour range -1..1), the 0.5 contour of the time-mean
    module-level ``mask``, and dot hatching where the masked, time-mean
    ``sig_da`` lies between 0 and 0.05.

    Parameters
    ----------
    year_phases_xr : xarray.DataArray
        Data with ``time`` and (leading, after the time mean) ``phase``
        dimensions plus ``longitude``/``latitude``.
    sig_da : xarray.DataArray
        P-values indexed by phase first, e.g. the output of ``get_sig_da``.
    exclude_ocean : bool
        If True, the plotted data are restricted to where ``mask`` is truthy.
    save_path : str
        Where the figure is written. Defaults to the path this function
        originally hard-coded.

    Notes
    -----
    Relies on the module-level ``mask`` having a ``time`` dimension.
    """
    nrow = 1
    ncol = 2

    fig, ax = plt.subplots(nrows=nrow, ncols=ncol, figsize=(16, 9),
                           subplot_kw={'projection': ccrs.PlateCarree()})

    data = year_phases_xr.mean(dim='time')
    if exclude_ocean:
        # where(mask) broadcasts against mask's time axis, so average it out again.
        data = data.where(mask).mean(dim='time')

    phases = ['Sep-Feb', 'March-June']
    contour = None
    mask_mean = mask.mean(dim='time')
    # plt.get_cmap replaces plt.cm.get_cmap, which was deprecated and then
    # removed in recent matplotlib releases.
    cmap = plt.get_cmap('bwr', 40)

    for col in range(ncol):
        cur_data = data[col]
        contour = cur_data.plot(
            ax=ax[col],
            cmap=cmap,
            vmin=-1, vmax=1,
            add_colorbar=False,
        )

        # Outline of the mask at its 0.5 level.
        ax[col].contour(
            mask_mean.longitude,
            mask_mean.latitude,
            mask_mean.values,
            levels=[0.5],
            cmap='gray')

        phase_mask_mean = sig_da[col].where(mask).mean(dim='time')

        # Hatch only the 0-0.05 band. `add_colorbar` is an xarray plotting
        # keyword, not a matplotlib contourf one, so it is not passed here.
        ax[col].contourf(
            phase_mask_mean.longitude,
            phase_mask_mean.latitude,
            phase_mask_mean.values,
            levels=[0.00, 0.05, 1.00],
            hatches=['.', None],
            colors='none',
        )

        ax[col].set_extent([165, 180, -49, -34], crs=ccrs.PlateCarree())
        ax[col].set_xticks([165, 170, 175, 180], crs=ccrs.PlateCarree())
        ax[col].set_yticks([-49, -45, -40, -34], crs=ccrs.PlateCarree())

        ax[col].set_title(phases[col], fontsize=20)

    cbar = fig.colorbar(contour, ax=ax.ravel().tolist(), orientation="horizontal",
                        ticks=np.arange(-1, 1.5, 0.5),
                        fraction=0.035, aspect=25)
    cbar.set_label(label=u'Surface air temperature anomaly (\u00B0C)', fontsize=24)
    cbar.ax.tick_params(labelsize=20)

    #plt.xlabel('Longitude',fontsize=20)
    #plt.ylabel('Latitude', fontsize=20)
    plt.savefig(save_path)

In [9]:
## Anom winds acting on anomalous temperatures AUS

def multiple_plots_sat_winds_mslp(year_phases_anom_xr, mslp_year_phases_anom_xr, uas_year_phases_anom_xr,
                                  vas_year_phases_anom_xr,
                                  save_path='/g/data/e14/cp3790/Charuni/ModelExperiment/anom_winds_anom_temp_AUS.png'):
    """Plot temperature, pressure and wind anomalies per phase over Australia.

    Draws one panel per phase ('Sep-Feb', 'March-June') on a PlateCarree map
    limited to 143-155E, 45-32S. Each panel overlays:

    * the time-mean of ``year_phases_anom_xr`` as filled colours (-3..3),
    * the 0.5 contour of the time-mean module-level ``mask``,
    * 10 labelled black contours of the time-mean ``mslp_year_phases_anom_xr``,
    * arrows from the time-mean ``uas``/``vas`` anomalies, using every second
      grid point in each direction.

    Parameters
    ----------
    year_phases_anom_xr, mslp_year_phases_anom_xr : xarray.DataArray
        Anomalies with ``time`` and ``phase`` dimensions plus
        ``longitude``/``latitude``.
    uas_year_phases_anom_xr, vas_year_phases_anom_xr : xarray.DataArray
        Wind-component anomalies on the same kind of grid; after the time mean
        and phase selection they are indexed ``[latitude, longitude]``
        (presumably -- TODO confirm the dimension order).
    save_path : str
        Where the figure is written. Defaults to the path this function
        originally hard-coded.

    Notes
    -----
    Relies on the module-level ``mask`` having a ``time`` dimension.
    """
    x = vas_year_phases_anom_xr['longitude']
    y = uas_year_phases_anom_xr['latitude']

    nrow = 1
    ncol = 2

    fig, ax = plt.subplots(nrows=nrow, ncols=ncol, figsize=(20, 12),
                           subplot_kw={'projection': ccrs.PlateCarree()})

    data = year_phases_anom_xr.mean(dim='time')
    data2 = mslp_year_phases_anom_xr.mean(dim='time')
    data3 = uas_year_phases_anom_xr.mean(dim='time')
    data4 = vas_year_phases_anom_xr.mean(dim='time')

    phases = ['Sep-Feb', 'March-June']
    contour = None
    mask_mean = mask.mean(dim='time')
    # plt.get_cmap replaces plt.cm.get_cmap, which was deprecated and then
    # removed in recent matplotlib releases.
    cmap = plt.get_cmap('bwr', 25)

    for col in range(ncol):
        cur_data = data[col]
        contour = cur_data.plot(
            ax=ax[col],
            cmap=cmap,
            vmin=-3, vmax=3,
            add_colorbar=False
        )
        # Outline of the mask at its 0.5 level.
        ax[col].contour(
            mask_mean.longitude,
            mask_mean.latitude,
            mask_mean.values,
            levels=[0.5],
            cmap='gray')

        cur_data2 = data2[col]
        CS = ax[col].contour(
            cur_data2.longitude,
            cur_data2.latitude,
            cur_data2.values,
            levels=10,
            colors='k'
        )

        # Thin the wind field to every second point to keep arrows readable.
        q = ax[col].quiver(
            x[::2], y[::2],
            data3[col][::2, ::2],
            data4[col][::2, ::2],
            scale=10,
            units='height', headwidth=9, headlength=8, headaxislength=6.5
        )

        ax[col].set_extent([143, 155, -45, -32], crs=ccrs.PlateCarree())
        # Ticks must lie inside the extent above; the previous values
        # (165-180E, 49-34S) belonged to the New Zealand window.
        ax[col].set_xticks([145, 150, 155], crs=ccrs.PlateCarree())
        ax[col].set_yticks([-45, -40, -35], crs=ccrs.PlateCarree())
        ax[col].set_title(phases[col], fontsize=18)
        ax[col].clabel(CS, CS.levels, inline=True, fontsize=16)

    cbar = fig.colorbar(contour, ax=ax.ravel().tolist(), orientation="horizontal",
                        ticks=np.arange(-4, 5.0, 1),
                        fraction=0.015, aspect=45, extend='both')
    cbar.set_label(label=u'Surface air temperature anomaly (\u00B0C)', fontsize=18)
    cbar.ax.tick_params(labelsize=14)

    # The reference arrow is drawn for `key_speed`; build the label from the
    # same value so the two cannot disagree (it used to draw 1 but say 2 m/s).
    key_speed = 1
    plt.quiverkey(q, 0.1, 1.02, key_speed, '{} m/s'.format(key_speed),
                  labelpos='E', fontproperties={'size': 20})
    plt.savefig(save_path)