# Calculate the feedbacks in single model abrupt Solar change experiments by using the radiative kernel 

# Reference: 50 year running mean

## GFDL model version

>Reference: [Soden et al. (2008)](https://doi.org/10.1175/2007JCLI2110.1)

Original code written by Chenggong Wang 
> Available on [Github](https://github.com/ChenggongWang/Radiative_Response_with_Radiative_Kernel)

Modified by Maya V. Chung Sept 11 2024

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import subprocess 
import time
from scipy import stats
sstart_time0 = time.time()
import netCDF4

In [2]:
# # download r3k functions on the fly from github
# # comment this line if you manually downloaded it
# ! wget https://raw.githubusercontent.com/ChenggongWang/Radiative_Response_with_Radiative_Kernel/main/Radiative_Repsonse_with_Raditive_kernel.py -O Radiative_Repsonse_with_Raditive_kernel.py

In [3]:
# Import the radiative-kernel helper module as r3k.
# NOTE: the module name (including its spelling) must match the file name used
# in the download cell above, so it is intentionally left as-is.
import Radiative_Repsonse_with_Raditive_kernel as r3k
# Compile the module's @njit functions up front (prints a progress message).
r3k.compile_njit_functions()

@njit Functions compiling ...  | finished!


# Data required and preprocess

- 10 Monthly mean variables [CMIP6 standard output](https://pcmdi.llnl.gov/mips/cmip3/variableList.html): 
`ta, hus, ts, rlut, rsdt, rsut, rsutcs, rlutcs, rsus, rsds`

- Control (12-month climatology) and perturbation (n x 12 monthly data) data are needed.
If you have already computed anomaly data, create a control dataset filled with zeros.
- `ta` and `hus` are 4d `(time, plev, lat, lon)` variables, while others are 3d `(time, lat, lon)` variables.
`time` should be nx12 month (n years >= 1 year). 

- __Since the data and the kernel are discrete, we need to regrid them to the same resolution.__
My practice: bring (regrid) the data to the kernel grid.

### __Download example data (4GB)__

> from: https://drive.google.com/drive/folders/1wtsZ4-kRXmNe2MXb2-czs4osTKASGa0x?usp=sharing
>
> or: https://tigress-web.princeton.edu/~cw55/share_data/r3k_example_data.tar (extract files: tar -xf xxx.tar)
>
> Make sure the data and kernel files are in the path: ./data/

>1. GFDL kernel file
>2. GFDL-CM4 piControl
>3. 4xCO2 experiment


# functions

In [4]:
def get_years(expname, directory=None):
    """Return the ``'YYYY-YYYY'`` year range found in a matching file name.

    Every file in ``directory`` whose name contains ``expname`` is inspected
    for a ``dddd-dddd`` pattern (e.g. ``'0101-1187'``).  If several files
    match, the range from the last one visited is returned, in the order
    given by ``os.listdir``.

    Parameters
    ----------
    expname : str
        Substring that identifies the experiment in the file names.
    directory : str, optional
        Directory to scan.  When omitted, ``./{model}_regridded/`` is used,
        where ``model`` is the notebook-level global variable (this is the
        original behaviour, so existing one-argument calls are unchanged).

    Returns
    -------
    str
        The matched year range, e.g. ``'0101-0310'``.

    Raises
    ------
    FileNotFoundError
        If no file containing ``expname`` carries a year range.  Previously
        this case surfaced as an UnboundLocalError at the ``return``.
    """
    import re
    import os

    if directory is None:
        # Fall back to the notebook global `model`, as the original code did.
        directory = f"./{model}_regridded/"

    year_range = None

    # Loop through files in the directory
    for file_name in os.listdir(directory):
        if expname not in file_name:
            continue
        # Search for the '0101-1187' pattern
        match = re.search(r"\d{4}-\d{4}", file_name)
        if match:
            year_range = match.group()
        else:
            print(f"File: {file_name}, Year range not found.")

    if year_range is None:
        raise FileNotFoundError(
            f"No file in {directory!r} contains {expname!r} together with a "
            "'YYYY-YYYY' year range."
        )

    return year_range

### get the years for loading in files later ###


In [5]:
def load_4D_var(var, model, exp, CTRLexp):
    """Open the multi-file 4D data for one variable (perturbation and control).

    Files are matched with the glob patterns
    ``./temp_4D_{model}/{var}.{exp}.mon.*.nc.2x2.5`` (perturbation) and the
    same pattern with ``CTRLexp`` (control).

    Returns
    -------
    (pert_da, CTRL_da)
        The two DataArrays, each renamed to ``var``.
    """
    patterns = (
        f"./temp_4D_{model}/{var}.{exp}.mon.*.nc.2x2.5",
        f"./temp_4D_{model}/{var}.{CTRLexp}.mon.*.nc.2x2.5",
    )

    # Open both datasets first (perturbation, then control) ...
    opened = [
        xr.open_mfdataset(pattern, combine="by_coords", parallel=True)
        for pattern in patterns
    ]

    # ... then pull out the data variable.  The files store it under xarray's
    # placeholder name, so it is renamed to the requested variable name.
    pert_da, CTRL_da = [
        ds["__xarray_dataarray_variable__"].rename(var) for ds in opened
    ]

    return pert_da, CTRL_da

In [6]:
def exp_to_pctstring(exp):
    """Map an experiment name to its signed solar-change percent string.

    The name is scanned for the tags below in the listed order; the percent
    string of the first tag that occurs anywhere in ``exp`` is returned.
    ``None`` is returned when no tag is present.
    """
    # (tag, percent string) pairs; the order defines which tag wins when an
    # experiment name happens to contain more than one.
    ordered_tags = (
        ('p6', '+6'), ('m6', '-6'),
        ('p4', '+4'), ('m4', '-4'),
        ('p2', '+2'), ('m2', '-2'),
        ('p1', '+1'), ('m1', '-1'),
    )

    return next((pct for tag, pct in ordered_tags if tag in exp), None)

### function for reading in control and experiment data

In [7]:
def read_postprocessed_data(model,exp,CTRLexp):
    """
    Load the ten model output fields for one perturbation run and its control.

    Model output names (temp, sphum, t_surf, olr, ...) are mapped to the names
    ta hus ts rlut rsdt rsut rlutcs rsutcs rsus rsds, which are used as the
    dictionary keys.

    * ``temp`` and ``sphum`` are read with ``load_4D_var``; their vertical
      coordinate ``pfull`` is renamed to ``plev`` and tagged with units 'mb'.
    * All other fields are read from
      ``./{model}_regridded/{var}.{experiment}.mon.{years}.nc.2x2.5``, with
      ``{years}`` looked up by ``get_years``.
    * If the control has more time steps than the perturbation, it is cut to
      the first ``da_pert.time.size`` steps.

    Return:
        var_cont: dict {name: DataArray} for the control experiment
        var_pert: dict {name: DataArray} for the perturbation experiment
    """
    native_names = "temp sphum t_surf olr  swdn_toa swup_toa olr_clr swup_toa_clr swup_sfc swdn_sfc ".split()
    output_names = "ta   hus   ts     rlut rsdt     rsut     rlutcs  rsutcs       rsus     rsds     ".split()

    def _pfull_to_plev(da):
        # rename the vertical coordinate and record its units
        da = da.rename({'pfull': 'plev'})
        da.coords['plev'].attrs['units'] = 'mb'
        return da

    var_cont, var_pert = {}, {}

    for var, out_name in zip(native_names, output_names):

        if var in ('temp', 'sphum'):
            # 4D fields: chunked, interpolated multi-file data
            da_pert, da_cont = load_4D_var(var, model, exp, CTRLexp)
            da_pert = _pfull_to_plev(da_pert)
            da_cont = _pfull_to_plev(da_cont)
        else:
            # other fields: one file per experiment, named by its year range
            pertyears = get_years(exp)
            CTRLyears = get_years(CTRLexp)

            da_cont = xr.open_dataarray(
                f"./{model}_regridded/{var}.{CTRLexp}.mon.{CTRLyears}.nc.2x2.5",
                engine='netcdf4',
            )
            da_pert = xr.open_dataarray(
                f"./{model}_regridded/{var}.{exp}.mon.{pertyears}.nc.2x2.5",
                engine='netcdf4',
            )

        # keep only as many control time steps as the perturbation run has
        n_pert = da_pert.time.size
        if da_cont.time.size > n_pert:
            da_cont = da_cont.isel(time=slice(0, n_pert))

        var_cont[out_name] = da_cont
        var_pert[out_name] = da_pert

    return var_cont, var_pert

# read kernel data
def read_kernel_file(rk_source='GFDL'):
    """Open the TOA radiative-kernel dataset for ``rk_source``.

    The file ``./radiative_kernels/kernels_TOA_<rk_source>_CMIP6-standard.nc``
    is opened without decoding its time axis.  For the 'GFDL' kernel the
    ``time`` dimension is renamed to ``month`` and labelled 1..12; any other
    source is returned as opened.
    """
    kernel_path = "./radiative_kernels/kernels_TOA_" + rk_source + "_CMIP6-standard.nc"
    kernels = xr.open_dataset(kernel_path, decode_times=False)

    if rk_source != 'GFDL':
        return kernels

    kernels = kernels.rename({'time': 'month'})
    kernels.coords['month'] = np.arange(1, 13)
    return kernels

### Function for decomposing the radiative response with radiative kernels
* Relative to Control climatology calculated over 50 years in a running window updated every 5 years
* Saves to netcdf in 50-year chunks (~1 GB, 2-3 minutes each)

In [8]:
def _control_monthly_climatology(var_cont, time_slice):
    """Return, per variable, the calendar-month mean of the control over ``time_slice``."""
    return {key: da.isel(time=time_slice).groupby('time.month').mean(dim='time')
            for key, da in var_cont.items()}


def decomp_rk_window_clim_save(new_RK,var_pert,var_cont,model,exp,CTRLexp):
    """Decompose the radiative response year by year and save it in 50-year files.

    For every year of the perturbation run, that year's 12 months are passed to
    ``r3k.decompose_dR_rk_toa_core`` together with a 12-month control
    climatology and the kernel ``new_RK``.  The control climatology is:

    * the mean of the first 50 control years while within 25 years of the start;
    * a 50-year window centred on the current year, refreshed whenever the
      year index is a multiple of 5, in the interior of the run;
    * the mean of the last 50 control years within 25 years of the end.

    Results are written to
    ``./rk_decomp/rk.decomp.{model}.{pct}%solar.50yrref.{first}-{last}.nc``,
    one file per chunk of 50 years.  A chunk is skipped when a file with the
    same start year already exists, so an interrupted run can be resumed.

    Parameters
    ----------
    new_RK : kernel dataset handed unchanged to ``r3k.decompose_dR_rk_toa_core``.
    var_pert, var_cont : dict of DataArray
        Perturbation and control fields with a ``time`` dimension.  Indexing
        assumes 12 consecutive monthly steps per year; ``var_pert`` must
        contain the key ``'ta'``.
    model, exp : str
        Used to build the output file name (``exp`` via ``exp_to_pctstring``).
    CTRLexp : str
        Unused; kept so existing calls keep working.

    Returns
    -------
    None
    """
    import glob
    import os
    import time

    years = np.unique(var_pert['ta'].time.dt.year)  # ex. array([ 101,  102,  103, ..., 1998, 1999, 2000])
    n_years = len(years)
    pctstring = exp_to_pctstring(exp)
    chunk_size = 50   # years per output file; must stay a multiple of 5 (see window refresh below)
    half_window = 25  # half-width, in years, of the 50-year control reference window

    print(years)

    # Create the output directory up front so a long computation cannot be
    # lost to a missing directory at save time.
    os.makedirs('./rk_decomp', exist_ok=True)

    # Climatology of the last 50 control years; identical for every year near
    # the end of the run, so it is built once, on first use.
    tail_clim = None

    for i in range(0, n_years, chunk_size):

        start_time = time.time()

        start_year = years[i] # get start year of chunk (101, 151, etc.)
        print(f'Start year: {start_year}')

        # The end year of an existing file is not known in advance
        # (the final chunk may be shorter), hence the wildcard.
        file_pattern = f'./rk_decomp/rk.decomp.{model}.{pctstring}%solar.50yrref.{start_year}-*.nc'
        matching_files = glob.glob(file_pattern)

        if matching_files:
            print(f'Already exists: {matching_files}')
            continue
        print('computing...')

        chunk_results = []  # one decomposition result per year of this chunk

        # min(...) stops the final chunk at the end of the run
        for iyear in range(i, min(i + chunk_size, n_years)):

            current_year = years[iyear]
            print(f'Current year: {current_year}')

            # get that year's perturbation data
            var_pert_temp = {key: da.isel(time=slice(iyear*12, iyear*12+12)) for key, da in var_pert.items()}

            if years[0] + half_window <= current_year <= years[-1] - half_window:
                # Interior of the run: centred 50-year window, refreshed every
                # 5 years.  Chunk starts are multiples of 5, so a resumed chunk
                # always refreshes on its first year.
                if iyear % 5 == 0:
                    var_cont_mean = _control_monthly_climatology(
                        var_cont, slice((iyear - half_window)*12, (iyear + half_window)*12))

            elif current_year < years[0] + half_window:
                # Start of the run: average of the first 50 years, computed once.
                if current_year == years[0]:
                    var_cont_mean = _control_monthly_climatology(var_cont, slice(0, 50*12))

            else:
                # End of the run (current_year > years[-1] - 25):
                # average of the last 50 years.
                if tail_clim is None:
                    tail_clim = _control_monthly_climatology(var_cont, slice(-50*12, None))
                var_cont_mean = tail_clim

            # Decompose this year's response relative to the window climatology
            ds_rk = r3k.decompose_dR_rk_toa_core(var_pert_temp,
                                                 var_cont_mean, # 12-month climatology calculated over window
                                                 new_RK) # use the interpolated kernel
            chunk_results.append(ds_rk)

        # Concatenate each field once per chunk instead of growing it pairwise.
        appended_dict = {key: xr.concat([res[key] for res in chunk_results], dim='time')
                         for key in chunk_results[0]}

        # save the chunk, named by the years it actually contains
        yr1, yr2 = int(appended_dict['dR_rh_lw'].time.dt.year.min()), int(appended_dict['dR_rh_lw'].time.dt.year.max())
        fileo = f'./rk_decomp/rk.decomp.{model}.{pctstring}%solar.50yrref.{yr1}-{yr2}.nc'
        xr.Dataset(appended_dict).to_netcdf(fileo)
        print(f'saved {fileo}')

        elapsed_time = (time.time() - start_time) / 60
        print(f"Elapsed time: {elapsed_time:.2f} minutes")

# main section

In [9]:
def get_exps(model):
    """Return the perturbation experiment names and control name for ``model``.

    Parameters
    ----------
    model : str
        Either ``'FLOR'`` or ``'CM2.1p1'``.

    Returns
    -------
    (exp, CTRLexp)
        ``exp`` is a list of eight experiment names ordered
        +6, +4, +2, +1, -1, -2, -4, -6 percent; ``CTRLexp`` is the name of the
        matching control experiment.

    Raises
    ------
    ValueError
        If ``model`` is not one of the supported names.  Previously this case
        failed with an UnboundLocalError at the ``return``.
    """
    # Solar-change tags shared by both models (p = plus, m = minus).
    tags = ('p6', 'p4', 'p2', 'p1', 'm1', 'm2', 'm4', 'm6')

    if model == 'FLOR':
        exp = [f'{tag}p0sol_CTL1860_tigercpu_intelmpi_18_576PE' for tag in tags]
        CTRLexp = 'CTL1860_newdiag_tigercpu_intelmpi_18_576PE'

    elif model == 'CM2.1p1':
        exp = [f'CTL1860_{tag}pctSolar_tigercpu_intelmpi_18_80PE' for tag in tags]
        CTRLexp = 'CTL1860_tigercpu_intelmpi_18_80PE'

    else:
        raise ValueError(
            f"Unknown model {model!r}; expected 'FLOR' or 'CM2.1p1'."
        )

    return exp, CTRLexp

In [10]:
# Driver: for every model and every solar-perturbation experiment, load the
# data, interpolate the kernel to the data's pressure levels, and write the
# kernel decomposition to ./rk_decomp/.
# each 50-year chunk ~2.5 min each

models = ['CM2.1p1','FLOR'] 

f_RK =  read_kernel_file('GFDL') # read in kernel file

# NOTE: the loop variable must be called `model`; get_years() reads it as a
# global to locate the ./{model}_regridded/ directory.
for model in models:

    experiments, CTRLexp = get_exps(model)

    for exp in experiments:

        print(model, exp)
        
        ##### read data in; control, perturbation #####
        var_cont, var_pert = read_postprocessed_data(model, exp, CTRLexp)
        
        #print(var_cont['ta'].shape)
        
        ##### interpolate vertical levels (kernel --> data) #####
        # plev_ori is only referenced by the commented-out debug print below
        plev_ori = f_RK.coords['plev']
        plev_solar = var_pert['ta'].plev.values
        plev_solar = np.asanyarray(plev_solar)
        #print(plev_solar,plev_ori)
        
        # interpolate the kernel onto the pressure levels of the model data
        new_RK=f_RK.interp(plev=plev_solar)
        
        # # check with plots 
        # new_RK.lw_ta.mean(dim=['month','lon']).plot(cmap=plt.cm.jet, levels=np.arange(-0.6,0,0.05),figsize=(3,3))
        # f_RK.lw_ta.mean(dim=['month','lon']).plot(cmap=plt.cm.jet, levels=np.arange(-0.6,0,0.05),figsize=(3,3))
        
        # compute and save the decomposition (chunks already on disk are skipped)
        decomp_rk_window_clim_save(new_RK,var_pert,var_cont,model,exp,CTRLexp)

CM2.1p1 CTL1860_p6pctSolar_tigercpu_intelmpi_18_80PE
[101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226
 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262
 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280
 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298
 299 300 301 302 303 304 305 306 307 308 309 310]
Start year: 101
Already exists: ['./rk_decomp/rk.decomp.CM2.1p1.+6%solar.50yrref.101-150.nc']


### Continue to 2_compute_plot_total_feedback and 2_plot_feedback_decomp_GFDL.