# Train four different ML models with PyTorch to test linearity and independence

1. Train four different ML algorithms on all CMIP6 Global Climate Model large ensembles (LEs) with >15 members
2. Train four different ML algorithms on the multi-model CMIP6 ensemble (n=41), using the 1st/2nd/3rd members as training/validation/test members
3. Remove one variable at a time from the linear model, trained on the first 75% of the LE members
4. Remove one variable at a time from the linear model, trained on the CMIP6 1st members  





In [1]:
import torch
import torch.nn as nn
import numpy as np
import scipy.stats as stats
import xarray as xr
import matplotlib.pyplot as plt
import datetime
import dask
import glob
import os
import re
import pickle
#log the time at which this notebook was started
print(datetime.datetime.now().isoformat(sep=' '))

2023-10-17 09:54:05.834879


In [2]:
#seed BOTH random number generators for reproducibility: torch for the model
#weight initialisation, and numpy because the optional white noise features
#in load_CVDP are drawn with np.random.normal
torch.manual_seed(0);
np.random.seed(0)

#use the first GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cpu


## Load data and define useful functions and constants

In [40]:
#load useful data such as the GCM names and their DOIs
CMIP6_info = xr.open_dataset(
    '/glade/work/cwpowell/low-frequency-variability/raw_data/CMIP6_info/'\
    +'CMIP6_modeling_center_members_doi.nc'
)

#loop through all GCMs and make a dictionary of the realizations (members)
#common to both the regional SIC data and the CVDP data
good_GCM_mem = {}
for GCM in CMIP6_info['model'].values:
    try:
        #list the members in the regional SIC file; all members in this file
        #have already gone through quality control to check for no nan or 0
        #values
        with xr.open_dataset(
            '/glade/work/cwpowell/low-frequency-variability/input_data/'\
            +f'Regional_SIC_detrended_lowpass_filter_{GCM}_1920_2014.nc'
        ) as SIC:
            SIC_mems = SIC['member'].values
        # alternative SIC file:
        #     f'Regional_SIC_bandpass_2_20_year_{GCM}_1920_2014.nc'

        #list the CVDP members whose AMO has no nan values at any time; members
        #with any nan AMO value are not used.
        #NB: xr.ufuncs was removed from xarray (v2022.06), np.isnan dispatches
        #to xarray through __array_ufunc__ and returns a DataArray
        with xr.open_dataset(
            '/glade/work/cwpowell/low-frequency-variability/input_data/'\
            +f'CVDP_standardized_linear_detrended_1920_2014_historical_{GCM}.nc'
        ) as CVDP:
            CVDP_mems = CVDP['member'].where(
                ~np.isnan(CVDP['AMO']).max('time'), drop=True).values

    except FileNotFoundError:
        print(GCM, 'Either SIC or CVDP data missing')
        continue

    #keep the members with both good SIC and CVDP data, if there is at least 1
    #(np.intersect1d already returns sorted, unique values)
    common_mems = np.intersect1d(CVDP_mems, SIC_mems)
    if len(common_mems) > 0:
        good_GCM_mem[GCM] = common_mems
    else:
        print(GCM, 'No members with good SIC and CVDP data')
    
#remove CESM2 from the dictionary as it is replaced with the CESM2-LENS data
good_GCM_mem.pop('CESM2', None);

#save the member data
# with open('/glade/work/cwpowell/low-frequency-variability/raw_data/CMIP6_info/'\
#           +'CMIP6_members_CVDP_and_SIC.pickle', 'wb') as handle:
#     pickle.dump(good_GCM_mem, handle, protocol=pickle.HIGHEST_PROTOCOL)

AWI-CM-1-1-MR Either SIC or CVDP data missing
AWI-ESM-1-1-LR Either SIC or CVDP data missing
CAS-ESM2-0 Either SIC or CVDP data missing
CanESM5-1 No members with good SIC and CVDP data
E3SM-2-0 No members with good SIC and CVDP data
FGOALS-f3-L Either SIC or CVDP data missing
FGOALS-g3 Either SIC or CVDP data missing
GISS-E2-1-G-CC Either SIC or CVDP data missing
GISS-E3-G Either SIC or CVDP data missing
IITM-ESM Either SIC or CVDP data missing
KACE-1-0-G Either SIC or CVDP data missing
MCM-UA-1-0 Either SIC or CVDP data missing


In [41]:
#define the train/validation/test split for the large ensembles with more
#than 15 members: the first ceil(75%) members train, the last floor(10%)
#test, and the members in between validate
train_valid_test = [0.75, 0.15, 0.10]

LE_train_mem = {}
LE_valid_mem = {}
LE_test_mem  = {}
LE_GCM_list = []
for GCM in np.sort(list(good_GCM_mem.keys())):
    
    n_mem = len(good_GCM_mem[GCM])
    
    if n_mem > 15:
        LE_GCM_list.append(GCM)
        train_n = int(np.ceil(n_mem*train_valid_test[0]))
        test_n  = int(np.floor(n_mem*train_valid_test[2]))
        #index of the first test member; slicing with -test_n instead would
        #give an empty validation set and make EVERY member a test member
        #whenever test_n rounds down to 0
        valid_end = n_mem - test_n

        LE_train_mem[GCM] = good_GCM_mem[GCM][:train_n]
        LE_valid_mem[GCM] = good_GCM_mem[GCM][train_n:valid_end]
        LE_test_mem[GCM]  = good_GCM_mem[GCM][valid_end:]
        
LE_GCM_list = np.sort(LE_GCM_list)

#add in the CMIP6 member number codes
CVDP_CMIP6_xr = xr.open_dataset(
    '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    +'CVDP_standardized_linear_detrended_1970_2014_historical_CMIP6.nc')

SIC_CMIP6_xr = xr.open_dataset(
    '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    +'Regional_SIC_detrended_lowpass_filter_CMIP6_1970_2014.nc')

# CVDP_CMIP6_xr = xr.open_dataset(
#     '/glade/work/cwpowell/low-frequency-variability/input_data/'\
#     'CVDP_standardized_highpass_filt_40_yr_1920_2014_historical_CMIP6.nc')

# SIC_CMIP6_xr = xr.open_dataset(
#     '/glade/work/cwpowell/low-frequency-variability/input_data/'\
#     'Regional_SIC_bandpass_2_40_year_CMIP6_1920_2014.nc')

#CMIP6 member codes: training 1000-1999, validation 2000-2999, test >= 10000
LE_train_mem['CMIP6'] = CVDP_CMIP6_xr['member'].sel(
    member=slice(1000,1999)).values
LE_valid_mem['CMIP6'] = CVDP_CMIP6_xr['member'].sel(
    member=slice(2000,2999)).values
LE_test_mem['CMIP6']  = CVDP_CMIP6_xr['member'].sel(
    member=slice(10000,None)).values

#add in the CMIP6 member number codes for 30+ members
CVDP_CMIP6_30_xr = xr.open_dataset(
    '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    +'CVDP_standardized_linear_detrended_1920_2014_historical_CMIP6_30.nc')

SIC_CMIP6_30_xr = xr.open_dataset(
    '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    +'Regional_SIC_detrended_lowpass_filter_CMIP6_30_1920_2014.nc')

# CVDP_CMIP6_30_xr = xr.open_dataset(
#     '/glade/work/cwpowell/low-frequency-variability/input_data/'\
#     +'CVDP_standardized_highpass_filt_40_yr_1920_2014_historical_CMIP6_30.nc')

# SIC_CMIP6_30_xr = xr.open_dataset(
#     '/glade/work/cwpowell/low-frequency-variability/input_data/'\
#     +'Regional_SIC_bandpass_2_40_year_CMIP6_30_1920_2014.nc')

#CMIP6_30 member codes: training 10000-19999, validation 20000-29999,
#test >= 30000
LE_train_mem['CMIP6_30'] = CVDP_CMIP6_30_xr['member'].sel(
    member=slice(10000,19999)).values
LE_valid_mem['CMIP6_30'] = CVDP_CMIP6_30_xr['member'].sel(
    member=slice(20000,29999)).values
LE_test_mem['CMIP6_30']  = CVDP_CMIP6_30_xr['member'].sel(
    member=slice(30000,None)).values

#add in the PI Control member number codes
CVDP_PI_xr = xr.open_dataset(
    '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    +'CVDP_standardized_PI_Control_MMLE_500_first_3_train.nc')

SIC_PI_xr = xr.open_dataset(
    '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    +'Regional_SIC_lowpass_filter_PI_Control_MMLE_500_first_3_train.nc')

#PI Control member codes: training 1000-9999, validation 20000-29999,
#test >= 30000
LE_train_mem['PI_500'] = CVDP_PI_xr['member'].sel(
    member=slice(1000,9999)).values
LE_valid_mem['PI_500'] = CVDP_PI_xr['member'].sel(
    member=slice(20000,29999)).values
LE_test_mem['PI_500']  = CVDP_PI_xr['member'].sel(
    member=slice(30000,None)).values

In [5]:
#generate a list of variable and season names
CVDP_sample = xr.open_dataset('/glade/work/cwpowell/low-frequency-variability/'\
    +'input_data/CVDP_standardized_linear_detrended_1920_'\
    +'2014_historical_CanESM5.nc')

var_month_list = []
for i in CVDP_sample.to_array()['variable'].drop_sel(
    variable=['AMOC','NINO12','NINO3','NINO4']).values:
    for month_num in [1,4,7,10]:
        var_month_list.append(str(i)+'_'+str(month_num))
        
var_month_list.append('RAND_1')
var_month_list.append('RAND_4')
var_month_list.append('RAND_7')
var_month_list.append('RAND_10')

In [14]:
# for i, var_month_name in enumerate(var_month_list):
#     print(i,var_month_name)

# Define functions for loading feature and target data, as well as training the 4 ML models

In [52]:
def _stack_member_year_features(season_arrays):
    '''
    Concatenate per-season CVDP DataArrays and flatten them to a 2D tensor.

    Parameters
    ----------
    season_arrays: list of xarray.DataArray,
        One DataArray per season, as built in load_CVDP: each has a scalar
        'month' coordinate and 'variable', 'member' and 'year' dimensions.

    Returns
    ----------
    PyTorch tensor with shape ([member x year], [variable x month]).
    '''
    stacked = xr.concat(season_arrays, 'month').stack(
        member_time=('member','year')).stack(
        var_month=('variable','month'))
    return torch.Tensor(stacked.values)


def load_CVDP(model_name, month_, lag_, start_end_yr, extra_drop=None, 
              white_noise=None):
    '''
    Load the training, validation and testing data of the climate modes of
    variability (features), corresponding to the time period of the sea ice
    data (target) and the lag time. Optionally, remove certain climate modes
    and/or include a white noise variable.

    NOTE: the CVDP file is currently hard-coded to the PI-Control multi-model
    file 'CVDP_standardized_PI_Control_MMLE_500_first_3_train.nc'; model_name
    is only used to look up the member lists in LE_train_mem, LE_valid_mem
    and LE_test_mem.
    
    Parameters
    ----------
    model_name: str,
        Key into LE_train_mem/LE_valid_mem/LE_test_mem (a GCM name, 'CMIP6',
        'CMIP6_30' or 'PI_500').
    month_: int,
        The month of the target sea ice concentration. Seasons starting in or
        after this month are lagged one extra year, e.g. with lag_=2 and
        October sea ice the seasonal CVDP data is lagged 2 years for DJF, MAM
        and JJA, but 3 years for SON.
    lag_: int,
        The number of years the CVDP data is offset before the sea ice data.
    start_end_yr: list, length 2 with integers,
        The start and end years (inclusive) for the sea ice concentration data.
    extra_drop: None or list of str,
        If None, no additional variables are dropped. Otherwise the variables
        listed are removed from the CVDP data.
    white_noise: None or bool,
        If True, add a variable 'RAND' of values drawn from a standard normal
        distribution (np.random.normal) for the 4 seasons, all members and all
        years.
    
    Returns
    ----------
    CVDP_train, CVDP_valid, CVDP_test: PyTorch tensors,
        The stacked CVDP for the training, validation and testing members,
        each with shape ([member x year], [variable x month]); e.g. for the
        first 75% of the 65 CanESM5 members and 1941-2014 there are
        49 x 74 = 3626 rows.
    '''
    CVDP_path = '/glade/work/cwpowell/low-frequency-variability/input_data/'\
        +'CVDP_standardized_PI_Control_MMLE_500_first_3_train.nc'
    #historical LE alternatives (these also need AMOC, NINO12, NINO3 and NINO4
    #dropped, see below):
    #  f'CVDP_standardized_linear_detrended_1920_2014_historical_{model_name}.nc'
    #  f'CVDP_standardized_highpass_filt_40_yr_1920_2014_historical_{model_name}.nc'

    with xr.open_dataset(CVDP_path) as CVDP_file:
        CVDP_year_month = CVDP_file.to_array('variable').sortby('time')

        #split the monthly series into one array per season-start month,
        #each indexed by (integer) year
        month_separate = []
        for season_month in [1,4,7,10]:
            temp_data = CVDP_year_month.sel(
                time=CVDP_year_month['time.month']==season_month)
            temp_data['time'] = np.arange(1920,2015)
            # temp_data['time'] = np.arange(1970,2015)
            month_separate.append(temp_data)

        CVDP_data = xr.concat(month_separate, dim='month')
        CVDP_data['month'] = [1,4,7,10]
        CVDP_data = CVDP_data.rename({'time':'year'})

        #for the historical (non PI-Control) data, which still contains them:
        # CVDP_data = CVDP_data.drop_sel(variable=['AMOC','NINO12','NINO3','NINO4'])
        if extra_drop is not None: #drop extra variables from the CVDP data
            CVDP_data = CVDP_data.drop_sel(variable=extra_drop)

        CVDP_data = CVDP_data.sortby('member')
        n_years = start_end_yr[1]-start_end_yr[0]+1

        #now stack the CVDP data into X members and years, Y features
        CVDP_train, CVDP_valid, CVDP_test = [], [], []
        for lag_season in [1,4,7,10]:
            #seasons starting in or after the target month are lagged 1 more year
            extra_year = 1 if lag_season >= month_ else 0
            first_yr = start_end_yr[0]-lag_-extra_year

            #the 'year' coordinate holds integers, so slice with integers
            CVDP_month_data = CVDP_data.sel(month=lag_season).sel(
                year=slice(first_yr, first_yr+n_years-1))
            #relabel years as 0..n_years-1 so all seasons share one axis
            CVDP_month_data['year'] = np.arange(0, n_years)

            #now, optionally add in white noise as 4 seasons of a new variable
            if white_noise:
                white_noise_month = (
                    CVDP_month_data.copy().isel(variable=0) * 0 + np.random.normal(
                        loc=0, scale=1, size=(len(CVDP_month_data['member']),
                                              len(CVDP_month_data['year'])))
                )

                white_noise_month['variable'] = 'RAND'

                CVDP_month_data = xr.concat(
                    (CVDP_month_data, white_noise_month), dim='variable'
                )

            CVDP_train.append(CVDP_month_data.sel(member=LE_train_mem[model_name]))
            CVDP_valid.append(CVDP_month_data.sel(member=LE_valid_mem[model_name]))
            CVDP_test.append(CVDP_month_data.sel(member=LE_test_mem[model_name]))

        return(_stack_member_year_features(CVDP_train),
               _stack_member_year_features(CVDP_valid),
               _stack_member_year_features(CVDP_test))

In [51]:
def load_SIC(model_name, month_, region_, start_end_yr, 
             ensemble_detrend=None
            ):
    '''
    Load the sea ice concentration anomalies (targets) for a specific GCM,
    month, region and time period, split into training, validation and test
    members.

    NOTE: when ensemble_detrend is None/False the SIC file is currently
    hard-coded to the PI-Control multi-model file; model_name is then only
    used to look up the member lists in LE_train_mem, LE_valid_mem and
    LE_test_mem.
    
    Parameters
    ----------
    model_name: str,
        The name of the GCM (or 'CMIP6', 'CMIP6_30', 'PI_500') which provides
        a sufficiently large ensemble.
    month_: int,
        The month of the target sea ice concentration anomalies. 
    region_: int,
        The region number for the sea ice concentration anomalies.
    start_end_yr: list, length 2 with integers,
        The start and end years (inclusive) for the sea ice concentration 
        anomalies.
    ensemble_detrend: None, bool,
        If True, load 'Regional_SIC_bandpass_2_40_year_{model_name}_1920_2014.nc'
        (presumably ensemble-mean detrended, per the keyword name - confirm);
        if None or False, load the lowpass filtered PI-Control file.
    
    Returns
    ----------
    target_train, target_valid, target_test: PyTorch tensors,
        1D tensors of the sea ice concentration anomalies for the training,
        validation and test members, stacked member-major over (member, year)
        in the same way as the rows returned by load_CVDP, i.e. of length
        [member x year], e.g. 49 x 74 = 3626 for the first 75% of the 65
        CanESM5 members over 1941-2014.
    '''
    input_dir = '/glade/work/cwpowell/low-frequency-variability/input_data/'
    #load SIC targets with the desired type of filtering/detrending
    if ensemble_detrend:
        SIC_file = f'Regional_SIC_bandpass_2_40_year_{model_name}_1920_2014.nc'
    else:
        #historical LE alternative:
        #  f'Regional_SIC_detrended_lowpass_filter_{model_name}_1920_2014.nc'
        SIC_file = 'Regional_SIC_lowpass_filter_PI_Control_MMLE_500_first_3_train.nc'

    with xr.open_dataset(input_dir + SIC_file) as SIC_data:
        #select the analysis years, target month and region once, sorted by
        #member
        SIC_sel = SIC_data['SIC'].sortby('member').sel(
            year=slice(str(start_end_yr[0]), str(start_end_yr[1]))).sel(
            month=month_, region=region_)

        #per member group: stack (member, year) to 1D -> np.ndarray -> tensor
        target_train, target_valid, target_test = (
            torch.from_numpy(
                SIC_sel.sel(member=members).stack(
                    member_time=('member','year')).values)
            for members in (LE_train_mem[model_name],
                            LE_valid_mem[model_name],
                            LE_test_mem[model_name])
        )
    
    return(target_train, target_valid, target_test)

In [47]:
def _build_ML_model(ML_i, n_features, n_hidden):
    '''
    Build one of the 4 untrained PyTorch architectures (ML_i = 0-3):
    0 - linear, no bias; 1 - linear, no bias, followed by ReLU;
    2 - 3 linear layers of width n_hidden with no activations;
    3 - 3 linear layers of width n_hidden with ReLU between them.
    '''
    if ML_i == 0:
        return nn.Sequential(
            nn.Linear(n_features,1,bias=False)
        )
    if ML_i == 1:
        return nn.Sequential(
            nn.Linear(n_features,1,bias=False), nn.ReLU()
        )
    if ML_i == 2:
        return nn.Sequential(
            nn.Linear(n_features,n_hidden), 
            nn.Linear(n_hidden,n_hidden),
            nn.Linear(n_hidden,1)
        )
    if ML_i == 3:
        return nn.Sequential(
            nn.Linear(n_features,n_hidden), nn.ReLU(),
            nn.Linear(n_hidden,n_hidden), nn.ReLU(),
            nn.Linear(n_hidden,1)
        )
    raise ValueError(f'ML_i must be 0, 1, 2 or 3, not {ML_i}')


def _train_best_epoch(ML_model_use, learn_rate, loss_fn, n_epoch, CVDP_train,
                      target_train, CVDP_valid, target_valid,
                      record_weights=False):
    '''
    Train a model with full-batch Adam for n_epoch epochs and return the
    results of the epoch with the lowest validation loss.

    Returns
    ----------
    best_r: float,
        Pearson r between the validation predictions and targets at the epoch
        with the lowest validation loss.
    best_weights: np.ndarray or None,
        If record_weights, a copy of row 0 of the first layer's weights after
        that epoch's update, otherwise None.
    '''
    optimizer = torch.optim.Adam(
        params=ML_model_use.parameters(), 
        lr=learn_rate
    )

    valid_r, valid_loss, epoch_weights = [], [], []
    for _ in range(n_epoch):
        ##### train the model (one full-batch step) #####
        prediction = ML_model_use(CVDP_train)
        optimizer.zero_grad() #reset the gradients to zeros
        loss = loss_fn(prediction[:,0].double(), target_train)
        loss.backward()
        optimizer.step()

        if record_weights:
            #.copy() is essential: .detach().numpy() shares memory with the
            #parameter, so without it every entry would hold the final weights
            epoch_weights.append(
                ML_model_use[0].weight[0,:].detach().numpy().copy())

        ##### validate the model #####
        with torch.no_grad():
            p_val = ML_model_use(CVDP_valid)
            valid_loss.append(
                loss_fn(p_val[:,0].double(), target_valid).item())
            valid_r.append(
                np.corrcoef(p_val[:,0].detach().numpy(),
                            target_valid.detach().numpy()
                           )[1][0]
            )

    #select the epoch with the lowest validation loss
    best_epoch = int(np.argmin(valid_loss))
    best_weights = epoch_weights[best_epoch] if record_weights else None
    return valid_r[best_epoch], best_weights


def train_4_ML_for_LE(model_name, month_, region_list, lag_list, start_end_yr, 
                      n_epoch, learn_rates, white_noise_=None,
                      ens_bool=None, loss_fcn_=None, n_neurons_=None):
    '''
    Train the 4 machine learning models for a given large ensemble on a 
    specified month of sea ice concentration data, separately for each of a
    set of regions and lags.

    For every (region, lag) pair each model is trained with full-batch Adam
    for n_epoch epochs and the epoch with the lowest validation loss is
    selected: its validation Pearson r is stored for all 4 models and, for
    the linear model (architecture 0), its weights are stored as well.
    
    Parameters
    ----------
    model_name: str,
        The name of the GCM (or 'CMIP6', 'CMIP6_30', 'PI_500') which provides
        a sufficiently large ensemble.
    month_: int,
        The month of sea ice concentration anomalies on which to train the 
        machine learning models. 
    region_list: list of ints,
        The regions of sea ice anomalies to train the models on separately.
    lag_list: list or array of ints,
        The lags (years) on which to separately train the models.
    start_end_yr: list, length 2 with integers,
        The start and end years (inclusive) for the sea ice concentration data.
    n_epoch: int, 
        Number of epochs with which to train all of the models.
    learn_rates: list of 4 floats,
        The learning rate for each of the 4 machine learning architectures.
    white_noise_: None or bool,
        If True, add a white noise variable (4 seasons) to the features, see
        load_CVDP.
    ens_bool: None or bool,
        Passed to load_SIC as ensemble_detrend.
    loss_fcn_: None or callable,
        The loss function. If None, the notebook-level global `loss_fcn` is
        used (the metadata below describes it as L1).
    n_neurons_: None or int,
        Width of the hidden layers of architectures 2 and 3. If None, the
        notebook-level global `n_neurons` is used.
                      
    Returns
    ----------
    r_values_xr: xarray.Dataset,
        'r_value' (region, lag, ML_model): the validation Pearson r of the 4
        machine learning models at their lowest-validation-loss epoch.
    all_1_weights_xr: xarray.Dataset,
        'weights' (region, lag, mode_month): the weights of the linear model
        at its lowest-validation-loss epoch.

    Raises
    ----------
    ValueError,
        If the number of loaded features does not equal n_features.
    '''
    #fall back to the notebook-level globals so existing calls behave as before
    loss_fn = loss_fcn_ if loss_fcn_ is not None else loss_fcn
    n_hidden = n_neurons_ if n_neurons_ is not None else n_neurons

    if white_noise_:
        n_features = 15*4 #15 variables (including RAND) x 4 seasons
    else:
        # n_features = 14*4 #for using 4 seasons for all variables
        # n_features = 50 #for using annual value of NPI and IPO
        n_features = 36 #9 variables x 4 seasons (SAM, IOD, NPI, NAM, AMM removed)
        
    if model_name in ['CMIP6', 'CMIP6_30', 'PI_500']:
        doi_model = '10.5194/gmd-9-1937-2016'
    else:
        doi_model = CMIP6_info['doi'].sel(model=model_name).values
    
    all_1_weights = np.empty(
        [len(region_list), len(lag_list), n_features], dtype=float)

    all_r_values = np.empty(
        [len(region_list), len(lag_list), 4], dtype=float)

    for region_i, region_ in enumerate(region_list):

        for lag_i, lag_ in enumerate(lag_list):
            #load the feature and target data for the correct model, month,
            #region and lag
            CVDP_train, CVDP_valid, CVDP_test = load_CVDP(
                model_name, month_, lag_, start_end_yr, 
                extra_drop=None, white_noise=white_noise_
            )
            #optional feature removal for the non PI-Control data, e.g. to
            #remove NAM, SAM, IOD, NPI, AMM (use CVDP_train[:,0:5],
            #CVDP_train[:,8:16] instead of [:,0:16] for just DJF IPO):
            # CVDP_train = torch.cat(
            #     (CVDP_train[:,0:16], CVDP_train[:,20:24], 
            #      CVDP_train[:,36:48], CVDP_train[:,52:]), dim=1
            # )
            # CVDP_valid = torch.cat(
            #     (CVDP_valid[:,0:16], CVDP_valid[:,20:24],
            #      CVDP_valid[:,36:48], CVDP_valid[:,52:]), dim=1
            # )
            #for the PI control data, which was reduced already:
            # CVDP_train = torch.cat((CVDP_train[:,0:5], CVDP_train[:,8:]), dim=1)
            # CVDP_valid = torch.cat((CVDP_valid[:,0:5], CVDP_valid[:,8:]), dim=1)

            if CVDP_train.shape[1] != n_features:
                raise ValueError(
                    f'load_CVDP returned {CVDP_train.shape[1]} features but '
                    f'n_features is {n_features}; update n_features (and '
                    'var_month_use) to match the data')

            target_train, target_valid, target_test = load_SIC(
                model_name, month_, region_, start_end_yr, 
                ensemble_detrend=ens_bool,
            )

            for ML_i in range(4):
                ML_model_use = _build_ML_model(ML_i, n_features, n_hidden)
                best_r, best_weights = _train_best_epoch(
                    ML_model_use, learn_rates[ML_i], loss_fn, n_epoch,
                    CVDP_train, target_train, CVDP_valid, target_valid,
                    record_weights=(ML_i == 0),
                )
                all_r_values[region_i][lag_i][ML_i] = best_r

                if ML_i == 0: #weights of the linear model at its best epoch
                    all_1_weights[region_i][lag_i] = np.ravel(best_weights)

    if model_name == 'CMIP6':
        model_name = 'CMIP6 multi-model large ensemble'
    
    #save the r values to NetCDF every model and month
    r_values_xr = xr.Dataset(
        data_vars = {
            'r_value':(['region', 'lag', 'ML_model'], all_r_values)},
        coords = {
            'region':region_list, 'lag':lag_list, 'ML_model':range(4)},
    )

    r_values_xr_attrs = {
        'Description': 'Pearson correlation coefficient for validation '\
            +'data (15%) of the availible members of the global climate '\
            +f'model {model_name}. 4 different machine learning models '\
            +'fit the features of seasonal climate modes computed by '\
            +'the Climate Variability Diagnostics Package with 2 year '\
            +'lowpass filtered sea ice concentration. These climate modes are '\
            +f'as follows: AMO, IPO, NINO34, PDO, ATN, '\
            +'NPO, PNA, NAO, TAS. The ML models are '\
            +'trained at different lag times of 1-20 years and for each '\
            +'region (regions are defined by NSIDC MASIE-NH Version 1 '\
            +'(doi:10.7265/N5GT5K3K). The ML_model dimension refers to '\
            +'the 4 different ML architectures with PyTorch all using L1 '\
            +f'loss function and Adam optimizer and {n_epoch} '\
            +'epochs:'+'\n'+'1 - nn.Sequential'\
            +f'(nn.Linear({n_features},1,bias=False)), learning rate = '\
            +f'{learn_rates[0]}.'+'\n'+'2 - nn.Sequential(nn.Linear('\
            +f'{n_features},1,bias=False), nn.ReLU()), learning rate = '\
            +f'{learn_rates[1]}.'+'\n3 - nn.Sequential(nn.Linear('\
            +f'{n_features},{n_hidden}), nn.Linear({n_hidden},'\
            +f'{n_hidden}), nn.Linear({n_hidden},1)), learning rate = '\
            +f'{learn_rates[2]}.'+'\n'+'4 - nn.Sequential(nn.Linear('\
            +f'{n_features},{n_hidden}), nn.ReLU(),nn.Linear({n_hidden}'\
            +f',{n_hidden}), nn.ReLU(), nn.Linear({n_hidden},1)), '\
            +f'learning rate {learn_rates[3]}.',
        #datetime.utcnow() is deprecated (Python 3.12); same formatted output
        'Timestamp'  : str(datetime.datetime.now(datetime.timezone.utc).strftime(
            "%H:%M UTC %a %Y-%m-%d")),
        'Data source': f'CMIP6 global climate model {model_name}, '\
            +f'doi:{doi_model} sea ice concentration and climate modes '\
            +'calculated by the Climate Variability Diagnostics Package '\
            +'(doi:10.1002/2014EO490002)',
        'Analysis'   : 'https://github.com/chrisrwp/low-frequency-'\
            +'variability/blob/main/neural_network/4_PyTorch_model_'\
            +'configurations.ipynb',
    }

    r_values_xr.attrs = r_values_xr_attrs

    #save the linear 1 layer neural network weights to NetCDF
    if white_noise_:
        var_month_use = np.array(var_month_list).copy()
    else:
        #alternatives: all 14 variables -> var_month_list[:-4]; 9 variables
        #with a single DJF IPO value -> also delete indices 5,6,7
        #9 variables with ALL IPO seasons (AMM, IOD, NPI, NAM, SAM removed):
        var_month_use = np.delete(np.array(var_month_list[:-4]).copy(),
            [16,17,18,19,24,25,26,27,28,29,30,31,32,33,34,35,48,49,50,51])
        
    all_1_weights_xr = xr.Dataset(
        data_vars = {
            'weights':(['region', 'lag', 'mode_month'], all_1_weights),
        },
        coords = {'region':region_list, 'lag':lag_list,
                  'mode_month':var_month_use,
        },
    )

    all_1_weights_xr_attrs = r_values_xr_attrs.copy()
    all_1_weights_xr_attrs['Description'] = 'Weights of the '\
        +'linear model fit for the validation data (15%) of the availible '\
        +f'members of the global climate model {model_name}. {n_features} '\
        +'features of seasonal climate modes computed by the Climate '\
        +'Variability Diagnostics Package with 2 year lowpass filtered sea '\
        +'ice concentration. These climate modes are as follows: AMO, IPO, '\
        +'NINO34, PDO, ATN, NPO, PNA, NAO, TAS. '\
        +'The model is trained at different lag times '\
        +'of 1-20 years and for each region (regions are defined by NSIDC '\
        +'MASIE-NH Version 1 (doi:10.7265/N5GT5K3K). The model uses '\
        +f'PyTorch with a L1 loss function, {n_epoch} epochs, Adam'\
        +f' optimizer and is defined by nn.Linear({n_features},1,'\
        +f'bias=False)), learning rate = {learn_rates[0]}.'

    all_1_weights_xr.attrs = all_1_weights_xr_attrs
        
    return(r_values_xr, all_1_weights_xr)

# 1. Train the 4 ML models on the first 75% of LE members and validate on the next 15%

In [11]:
#use this cell for running each of the individual GCM LEs
loss_fcn  = torch.nn.L1Loss() #keep all models sparse
n_neurons = 6 

for month__ in [1,2,3,4,5,6,7,8,10,11,12]:
    print(datetime.datetime.now(), month__)

    for model_name in LE_GCM_list[:-1]:
        print(datetime.datetime.now(), model_name)
        
        r_values_xr, all_1_weights_xr = train_4_ML_for_LE(
            model_name, month_=month__, region_list=[1,2,3,4,5,6,11], 
            lag_list=np.arange(1,21), start_end_yr=[1941,2014], n_epoch=2000, 
            learn_rates=[5e-4, 5e-4, 1e-4, 2e-4], white_noise_=False,
            ens_bool=False,
        )    

        r_values_xr.to_netcdf(
            '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
            +f'validation_r_values_4ML_{model_name}_month_'\
            +f'{str(month__).zfill(2)}_var_9_all_IPO_lowpass_filt.nc')
        all_1_weights_xr.to_netcdf(
            '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
            +f'weights_linear_{model_name}_month_{str(month__).zfill(2)}_'\
            +'var_9_all_IPO_lowpass_filt.nc')
    

2023-10-17 10:10:19.898648 9
2023-10-17 10:10:19.898708 ACCESS-ESM1-5
2023-10-17 10:33:58.374795 CESM2-LENS
2023-10-17 10:58:37.787520 CNRM-CM6-1
2023-10-17 11:21:01.013167 CanESM5
2023-10-17 11:47:02.387114 EC-Earth3
2023-10-17 12:09:38.576934 GISS-E2-1-G
2023-10-17 12:33:38.798835 GISS-E2-1-H
2023-10-17 12:56:16.382415 IPSL-CM6A-LR
2023-10-17 13:19:33.358826 MIROC-ES2L
2023-10-17 14:07:05.955468 MPI-ESM1-2-LR
2023-10-17 14:30:02.273150 NorCPM1


In [None]:
# loss_fcn  = torch.nn.L1Loss() #keep all models sparse
# n_neurons = 8

# for month__ in [8,9]:
#     print(datetime.datetime.now(), month__)

#     for model_name in LE_GCM_list:
#         print(datetime.datetime.now(), model_name)
#         #N.B. changed learn rates, should invetigate further to find best values
#         #for now the learn rates are doubled and so is the epoch
#         r_values_xr, all_1_weights_xr = train_4_ML_for_LE(
#             model_name, month_=month__, region_list=[2,11], 
#             lag_list=np.arange(1,21), start_end_yr=[1941,2014], n_epoch=4000, 
#             learn_rates=[1e-3, 1e-3, 2e-4, 4e-4], white_noise_=False,
#             ens_bool=True,
#         )    

#         r_values_xr.to_netcdf(
#             '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
#             +f'validation_r_values_4ML_{model_name}_month_'\
#             +f'{str(month__).zfill(2)}_var_14_annual_IPO_NPI_ens_detrend.nc')
#         all_1_weights_xr.to_netcdf(
#             '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
#             +f'weights_linear_{model_name}_month_{str(month__).zfill(2)}_'\
#             +'var_14_annual_IPO_NPI_ens_detrend.nc')
    

# 2. Train the 4 ML models on all CMIP6 GCMs, using the 1st/2nd/3rd+ (or 1-7/8-9/10+) members for training/validation/testing

### Make the CMIP6 CVDP and SIC files as if 'CMIP6' were a single GCM, in which the 1st member of each GCM is a different (training) member, the 2nd a validation member and the rest test members

In [25]:
#obtain the GCMs with at least 3 usable members, so that each GCM can
#contribute 1 training, 1 validation and at least 1 test member
CMIP6_GCM_list = []
for GCM in np.sort(list(good_GCM_mem.keys())):    
    n_mem = len(good_GCM_mem[GCM])
    if n_mem > 2:
        CMIP6_GCM_list.append(GCM)
        
#gather all of the members together (saved to NetCDF in the next cell)
CVDP_CMIP6 = []
SIC_CMIP6 = []

#new member labels, one block per GCM (GCM number = position in
#CMIP6_GCM_list):
#  training   : 1000 + GCM number
#  validation : 2000 + GCM number
#  testing    : 10000 + 1000*GCM number + 0, 1, 2, ...
#the test labels must start at 10000 to agree with the metadata written with
#the saved files ('10000+ are the test members', GCM = first 2 digits - 10)
train_mem_i = 1000
valid_mem_i = 2000
test_mem_i  = 10000

for GCM in CMIP6_GCM_list:
    #the 1st member goes to training, the 2nd to validation and all
    #remaining members to testing
    CVDP_data = xr.open_dataset(
        '/glade/work/cwpowell/low-frequency-variability/input_data/'\
        +'CVDP_standardized_linear_detrended_1920_2014_historical_'\
        +f'{GCM}.nc'
    )
    # CVDP_data = xr.open_dataset(
    #     '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    #     +'CVDP_standardized_highpass_filt_40_yr_1920_2014_historical_'\
    #     +f'{GCM}.nc'
    # )
    
    CVDP_data = CVDP_data.sel(member=good_GCM_mem[GCM])
    
    #labels: [train, valid, test_0, test_1, ...]
    new_mem_list = np.arange(test_mem_i, test_mem_i+len(CVDP_data['member'])-2)
    new_mem_list = np.insert(new_mem_list, 0, [train_mem_i, valid_mem_i])
    
    CVDP_data['member'] = new_mem_list
    CVDP_CMIP6.append(CVDP_data)
    
    #apply the same relabelling to the SIC data (same members, same order)
    SIC_data = xr.open_dataset(
        '/glade/work/cwpowell/low-frequency-variability/input_data/'\
        +f'Regional_SIC_detrended_lowpass_filter_{GCM}_1920_2014.nc'
    )
    # SIC_data = xr.open_dataset(
    #     '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    #     +f'Regional_SIC_bandpass_2_40_year_{GCM}_1920_2014.nc'
    # )
    
    SIC_data = SIC_data.sel(member=good_GCM_mem[GCM])
    
    SIC_data['member'] = new_mem_list   
    SIC_CMIP6.append(SIC_data)
    
    #move on to the next GCM's label block
    train_mem_i += 1
    valid_mem_i += 1
    test_mem_i += 1000

In [29]:
#save this CMIP6 data to NetCDF and include metadata

#GCM numbering generated from CMIP6_GCM_list (the GCMs, in order, whose members
#were relabelled) so the metadata cannot drift out of sync with the data
CMIP6_GCM_numbering = ', '.join(
    f'{GCM_i} {GCM}' for GCM_i, GCM in enumerate(CMIP6_GCM_list))

#description of the member labels, shared by the CVDP and SIC metadata
CMIP6_member_txt = (
    'Seasonal data for 1920-2014. The members can be decoded as follows: '
    '1000-1999 are the training members, 2000-2999 are the validation '
    'members, and 10000+ are the test members. The GCM is encoded as '
    'the last 2 digits for the training and validation members, and as '
    'the first 2 digits minus 10 for the testing members. Note there is '
    'always 1 member from each GCM for training and validation, and all '
    'remaining members are in the testing dataset. The GCM numbering '
    f'refers to the following: {CMIP6_GCM_numbering}. All members are '
    'sorted alphabetically before being divided into the three groups'
)

analysis_url = (
    'https://github.com/chrisrwp/low-frequency-variability/blob/main/'
    'neural_network/Train_4_ML_Models.ipynb'
)

CVDP_CMIP6_xr = xr.concat(CVDP_CMIP6, dim='member').sortby('member')

#the inputs are the linearly detrended, standardized CVDP files
CVDP_CMIP6_xr.attrs = {
    'Description' : 'Linearly detrended and standardized variables '
        'from the CVDP (Climate Variability Diagnostics Package) for all '
        'CMIP6 global climate models with at least 3 available members. '
        + CMIP6_member_txt,
    'Units' :'standardized values',
    'Timestamp' : str(datetime.datetime.utcnow().strftime(
        "%H:%M UTC %a %Y-%m-%d")),
    'Data source': 'CMIP6 historical simulations, computed by CVDP '
        'doi: 10.1002/2014EO490002.',
    'Analysis'   : analysis_url,
}
    
CVDP_CMIP6_xr.to_netcdf(
    '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    +'CVDP_standardized_linear_detrended_1920_2014_historical_CMIP6.nc')

SIC_CMIP6_xr = xr.concat(SIC_CMIP6, dim='member').sortby('member')

#the inputs are the detrended, lowpass filtered regional SIC files
SIC_CMIP6_xr.attrs = {
    'Description' : 'Detrended and lowpass filtered regional '
        'average sea ice concentration (SIC) in % for all CMIP6 global '
        'climate models with at least 3 available members. '
        + CMIP6_member_txt,
    'Timestamp' : str(datetime.datetime.utcnow().strftime(
        "%H:%M UTC %a %Y-%m-%d")),
    'Data source': 'CMIP6 historical simulations',
    'Analysis'   : analysis_url,
}
    
SIC_CMIP6_xr.to_netcdf(
    '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    +'Regional_SIC_detrended_lowpass_filter_CMIP6_1920_2014.nc')

### Now do the same for GCMs with 30+ members. NorCPM1 has poor validation scores, but note that it is still included in the GCM list below

In [34]:
#obtain the train/validate/test members for the GCMs with 30+ members
#NOTE(review): NorCPM1 is included here even though it has poor validation
#scores - confirm whether it should be dropped
CMIP6_GCM_30_list = [
    'ACCESS-ESM1-5', 'CESM2-LENS', 'CanESM5', 'GISS-E2-1-G', 'IPSL-CM6A-LR', 
    'MIROC-ES2L', 'MIROC6', 'MPI-ESM1-2-LR', 'NorCPM1',
]

#number of members of each GCM used for training and for validation (taken
#from the start of the member list); all remaining members are test members
n_train_30 = 23
n_valid_30 = 5
        
#gather all of the members together (saved to NetCDF in the next cell)
CVDP_CMIP6_30 = []
SIC_CMIP6_30 = []

#member labels: training 10000+, validation 20000+, testing 30000+, with a
#1000 offset per GCM
train_mem_i = 10000
valid_mem_i = 20000
test_mem_i  = 30000

for GCM in CMIP6_GCM_30_list:
    CVDP_data = xr.open_dataset(
        '/glade/work/cwpowell/low-frequency-variability/input_data/'\
        +'CVDP_standardized_linear_detrended_1920_2014_historical_'\
        +f'{GCM}.nc'
    )
    # CVDP_data = xr.open_dataset(
    #     '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    #     +'CVDP_standardized_highpass_filt_40_yr_1920_2014_historical_'\
    #     +f'{GCM}.nc'
    # )
    
    CVDP_data = CVDP_data.sel(member=good_GCM_mem[GCM])
    
    #fail clearly rather than with an opaque coordinate-length mismatch
    n_mem_GCM = len(CVDP_data['member'])
    if n_mem_GCM < n_train_30 + n_valid_30:
        raise ValueError(
            f'{GCM} has {n_mem_GCM} usable members, but at least '
            f'{n_train_30 + n_valid_30} are needed for training and validation')
    
    #labels: [train x n_train_30, valid x n_valid_30, test x remaining]
    new_mem_list = np.concatenate((
        np.arange(train_mem_i, train_mem_i + n_train_30),
        np.arange(valid_mem_i, valid_mem_i + n_valid_30),
        np.arange(test_mem_i,
                  test_mem_i + n_mem_GCM - n_train_30 - n_valid_30),
    ))
    
    CVDP_data['member'] = new_mem_list
    CVDP_CMIP6_30.append(CVDP_data)
    
    #apply the same relabelling to the SIC data (same members, same order)
    SIC_data = xr.open_dataset(
        '/glade/work/cwpowell/low-frequency-variability/input_data/'\
        +f'Regional_SIC_detrended_lowpass_filter_{GCM}_1920_2014.nc'
    )
#     SIC_data = xr.open_dataset(
#         '/glade/work/cwpowell/low-frequency-variability/input_data/'\
#         +f'Regional_SIC_bandpass_2_40_year_{GCM}_1920_2014.nc'
#     )
    
    SIC_data = SIC_data.sel(member=good_GCM_mem[GCM])
    
    SIC_data['member'] = new_mem_list   
    SIC_CMIP6_30.append(SIC_data)
    
    #move on to the next GCM's label block
    train_mem_i += 1000
    valid_mem_i += 1000
    test_mem_i  += 1000

In [35]:
#save the 30+ member CMIP6 data to NetCDF and include metadata

#GCM numbering generated from CMIP6_GCM_30_list (the GCMs, in order, whose
#members were relabelled) so the metadata cannot drift out of sync
CMIP6_30_GCM_numbering = ', '.join(
    f'{GCM_i}:{GCM}' for GCM_i, GCM in enumerate(CMIP6_GCM_30_list))

#description of the member labels, shared by the CVDP and SIC metadata
CMIP6_30_member_txt = (
    'Seasonal data for 1920-2014. The members can be decoded as follows: '
    '10000-19999 are the training members, 20000-29999 are the '
    'validation members, and 30000+ are the test members. The member '
    'number is encoded as the last 2 digits, with the second digit '
    f'indicating the GCM as follows, {CMIP6_30_GCM_numbering}. All members '
    'are sorted alphabetically before being divided into the three groups.'
)

analysis_url = (
    'https://github.com/chrisrwp/low-frequency-variability/blob/main/'
    'neural_network/Train_4_ML_Models.ipynb'
)

CVDP_CMIP6_30_xr = xr.concat(CVDP_CMIP6_30, dim='member').sortby('member')

#the inputs are the linearly detrended, standardized CVDP files
CVDP_CMIP6_30_xr.attrs = {
    'Description' : 'Linearly detrended and standardized variables '
        'from the CVDP (Climate Variability Diagnostics Package) for all '
        'CMIP6 global climate models with at least 30 available members. '
        + CMIP6_30_member_txt,
    'Units' :'standardized values',
    'Timestamp' : str(datetime.datetime.utcnow().strftime(
        "%H:%M UTC %a %Y-%m-%d")),
    'Data source': 'CMIP6 historical simulations, computed by CVDP '
        'doi: 10.1002/2014EO490002.',
    'Analysis'   : analysis_url,
}

#file name follows the linear-detrended naming of the inputs (and of the
#'CMIP6' file), not the 40-year highpass naming that did not match the data
CVDP_CMIP6_30_xr.to_netcdf(
    '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    +'CVDP_standardized_linear_detrended_1920_2014_historical_CMIP6_30.nc')

SIC_CMIP6_30_xr = xr.concat(SIC_CMIP6_30, dim='member').sortby('member')

#the inputs are the detrended, lowpass filtered regional SIC files
SIC_CMIP6_30_xr.attrs = {
    'Description' : 'Detrended and lowpass filtered regional '
        'average sea ice concentration (SIC) in % for all CMIP6 global '
        'climate models with at least 30 available members. '
        + CMIP6_30_member_txt,
    'Timestamp' : str(datetime.datetime.utcnow().strftime(
        "%H:%M UTC %a %Y-%m-%d")),
    'Data source': 'CMIP6 historical simulations',
    'Analysis'   : analysis_url,
}

#file name follows the detrended lowpass naming of the inputs (and of the
#'CMIP6' file), not the 2-40 year bandpass naming that did not match the data
SIC_CMIP6_30_xr.to_netcdf(
    '/glade/work/cwpowell/low-frequency-variability/input_data/'\
    +'Regional_SIC_detrended_lowpass_filter_CMIP6_30_1920_2014.nc')

## Now train the 4 ML models on the CMIP6 data

In [53]:
#train the 4 ML models on the 'PI_500' data for September only, for regions
#1-6 and 11, lags of 1-20 years and the years 1991-2014, then save the
#validation r values and the linear model weights
#NOTE(review): the 'MMLE 3+' print label and the commented-out CMIP6 file
#names suggest this cell was originally for the CMIP6 multi-model ensemble
#(1st members of all GCMs with 3+ members) - confirm the intended data
#N.B. need to rerun LE_train/valid/test_mem values near top for 1970-2014
loss_fcn = torch.nn.L1Loss()  #L1 loss, intended to keep all models sparse
n_neurons = 6

PyTorch_dir = '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'

for month__ in [9]:  #all months: np.arange(1,13)
    print(datetime.datetime.now(), month__, 'MMLE 3+')

    r_values_xr, all_1_weights_xr = train_4_ML_for_LE(
        'PI_500',
        month_=month__,
        region_list=[1, 2, 3, 4, 5, 6, 11],
        lag_list=np.arange(1, 21),
        start_end_yr=[1991, 2014],
        n_epoch=2000,
        learn_rates=[5e-4, 5e-4, 1e-4, 2e-4],
        white_noise_=False,
        ens_bool=False,
    )

    #'MM_var_9_all_IPO_lowpass_filt_74_year_mem.nc', MM = zero-padded month
    file_tail = (str(month__).zfill(2)
                 + '_var_9_all_IPO_lowpass_filt_74_year_mem.nc')

    r_values_xr.to_netcdf(
        PyTorch_dir
        + 'validation_r_values_4ML_PI_500_first_3_mem_train_month_'
        + file_tail)
    all_1_weights_xr.to_netcdf(
        PyTorch_dir
        + 'weights_linear_PI_500_first_3_mem_train_month_'
        + file_tail)
    
    # r_values_xr.to_netcdf(
    #     '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
    #     'validation_r_values_4ML_CMIP6_month_'\
    #     f'{str(month__).zfill(2)}_var_9_all_IPO_lowpass_filt_1970_2014.nc')
    # all_1_weights_xr.to_netcdf(
    #     '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
    #     'weights_linear_CMIP6_month_'\
    #     f'{str(month__).zfill(2)}_var_9_all_IPO_lowpass_filt_1970_2014.nc')

2023-10-18 12:21:29.372173 9 MMLE 3+


In [None]:
#train the 4 ML models on the 'CMIP6_30' data (1st-23rd members of the GCMs
#with 30+ members used for training) for every month, starting in August
#NOTE(review): an earlier note said the epochs were raised to 4000 for this
#smaller dataset, but n_epoch below is 2000 - confirm which is intended
loss_fcn = torch.nn.L1Loss()  #L1 loss, intended to keep all models sparse
n_neurons = 6

PyTorch_dir = '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'

for month__ in [*range(8, 13), *range(1, 8)]:  #Aug..Dec, then Jan..Jul
    print(datetime.datetime.now(), month__)

    r_values_xr, all_1_weights_xr = train_4_ML_for_LE(
        'CMIP6_30',
        month_=month__,
        region_list=[1, 2, 3, 4, 5, 6, 11],
        lag_list=np.arange(1, 21),
        start_end_yr=[1941, 2014],
        n_epoch=2000,
        learn_rates=[5e-4, 5e-4, 1e-4, 2e-4],
        white_noise_=False,
        ens_bool=False,
    )

    #'month_MM_var_9_all_IPO_lowpass_filt.nc', MM = zero-padded month
    file_tail = ('month_' + str(month__).zfill(2)
                 + '_var_9_all_IPO_lowpass_filt.nc')

    r_values_xr.to_netcdf(
        PyTorch_dir + 'validation_r_values_4ML_CMIP6_30_' + file_tail)
    all_1_weights_xr.to_netcdf(
        PyTorch_dir + 'weights_linear_CMIP6_30_' + file_tail)


## Test the CMIP6-trained linear model with the 3rd and later members

In [15]:
#load the CMIP6-trained linear model weights for each month and stack them
#along a new 'month' dimension (months 1-12)
CMIP6_weights = []
for month_ in np.arange(1,13):
    weights_file = (
        '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'
        f'weights_linear_CMIP6_month_{str(month_).zfill(2)}_'
        'var_9_all_IPO_lowpass_filt.nc'
    )
    all_1_weights_xr = xr.open_dataset(weights_file)
    CMIP6_weights.append(all_1_weights_xr['weights'])

CMIP6_weights = xr.concat(CMIP6_weights, dim='month')
CMIP6_weights['month'] = np.arange(1,13)

In [16]:
#evaluate the CMIP6-trained linear model on the CMIP6 test members (3rd and
#later members of each GCM): one Pearson r per month, lag, region and member
n_yr = 2014 - 1941 + 1 #number of years per member for 1941-2014 (74)

month_r_vals = []
for month_ in np.arange(1,13):
    print(datetime.datetime.now(), month_)
    lag_r_vals = []
    for lag_ in np.arange(1,21):

        CVDP_train, CVDP_valid, CVDP_test = load_CVDP(
            'CMIP6', month_, lag_, [1941,2014], 
            extra_drop=None, white_noise=False,
        )
        
        # CVDP_test = torch.cat(
        #         (CVDP_test[:,0:5], CVDP_test[:,8:29], CVDP_test[:,32:]),
        #         dim=1)
        
        #remove feature columns 16-19, 24-35 and 48-51 so the number of
        #features matches the trained weights
        CVDP_test = torch.cat(
                (CVDP_test[:,0:16], CVDP_test[:,20:24], 
                 CVDP_test[:,36:48], CVDP_test[:,52:]), dim=1
            )
        features_test = np.array(CVDP_test)

        #the test rows are the members stacked one after another, n_yr rows
        #each; derive the member count from the data instead of hard-coding it
        n_test_mem, n_extra_rows = divmod(features_test.shape[0], n_yr)
        if n_extra_rows:
            raise ValueError(
                f'{features_test.shape[0]} test rows is not a multiple of '
                f'{n_yr} years per member')
        
        region_r_vals = []
        for region_ in [1,2,3,4,5,6,11]:
            target_train, target_valid, target_test = load_SIC(
                 'CMIP6', month_, region_, [1941,2014], 
            )
            
            weights_sel = CMIP6_weights.sel(
                month=month_, region=region_, lag=lag_).values
            #linear model (no bias) prediction; broadcasting the weights over
            #all rows gives the same products as tiling them
            prediction = np.sum(
                features_test * weights_sel[np.newaxis, :], axis=1)
            
            #correlate prediction and target separately for each test member
            all_test_mem = []
            for i in range(n_test_mem):
                all_test_mem.append(
                    np.corrcoef(prediction[i*n_yr:(i+1)*n_yr], 
                                target_test[i*n_yr:(i+1)*n_yr])[0][1])
                
            region_r_vals.append(all_test_mem)
            
        lag_r_vals.append(region_r_vals)
        
    month_r_vals.append(lag_r_vals)
    
month_r_vals_xr = xr.Dataset(
    data_vars = {'r_value':(['month','lag','region','member'], month_r_vals)},
    coords = {'month':np.arange(1,13), 'lag':np.arange(1,21), 
              'region':[1,2,3,4,5,6,11], 
              #test members carry the largest labels, so they are the last
              #members of the member-sorted CMIP6 SIC data
              'member':SIC_CMIP6_xr['member'][-n_test_mem:].values,
             },
    )

2023-10-18 09:52:28.814195 1
2023-10-18 09:52:44.578743 2
2023-10-18 09:53:00.811190 3
2023-10-18 09:53:17.174374 4
2023-10-18 09:53:33.406937 5
2023-10-18 09:53:49.687625 6
2023-10-18 09:54:05.995534 7
2023-10-18 09:54:22.277564 8
2023-10-18 09:54:38.816618 9
2023-10-18 09:54:55.091383 10
2023-10-18 09:55:11.279469 11
2023-10-18 09:55:27.298995 12


In [18]:
#group the CMIP6 test r values by GCM, using the same GCM selection (3+
#members, sorted) as when the CMIP6 members were relabelled
CMIP6_GCM_list = []
for GCM in np.sort(list(good_GCM_mem.keys())):    
    n_mem = len(good_GCM_mem[GCM])
    if n_mem > 2:
        CMIP6_GCM_list.append(GCM)

#test members of the i-th GCM are labelled <first test label> + 1000*i + n
#(10000 + 1000*i + n with the labelling described in the saved CMIP6
#metadata). The member coordinate is integer-valued, so the slice bounds must
#be integers; string bounds such as '10000' are not compared numerically
first_test_mem = int(month_r_vals_xr['member'].min())

r_val_by_GCM = []
for GCM_i, GCM in enumerate(CMIP6_GCM_list):
    begin_mem = first_test_mem + GCM_i*1000
    end_mem   = begin_mem + 999
    
    data = month_r_vals_xr.sel(member=slice(begin_mem, end_mem))
    #renumber this GCM's test members from 0
    data['member'] = np.arange(0,len(data['member']))
    r_val_by_GCM.append(data)

r_val_by_GCM = xr.concat(r_val_by_GCM, dim='model_name')
r_val_by_GCM['model_name'] = CMIP6_GCM_list

n_CMIP6_GCM = len(CMIP6_GCM_list)
r_val_by_GCM.attrs = {
    'Description': 'Pearson correlation coefficient for test data (3rd and '\
        +f'later members) of {n_CMIP6_GCM} global climate models as follows: '\
        +f'{CMIP6_GCM_list}. Models were trained and validated on the 1st and '\
        +f'2nd members of the same {n_CMIP6_GCM} GCMs. The model features were '\
        +'the following modes of variability: AMO, IPO, NINO34, PDO, ATN, '\
        +'NPO, PNA, NAO, TAS. The targets were 2 year lowpass '\
        +'filtered regional Arctic sea ice concentration anomalies for lag '\
        +'times of 1-20 years. Regions are defined by NSIDC MASIE-NH Version'\
        +' 1 (doi:10.7265/N5GT5K3K). The climate modes are computed by the '\
        +'Climate Variability Diagnostics Package (CVDP). The model was '\
        +'trained using an L1 loss function and Adam optimizer with 2000 '\
        +'epochs: nn.Sequential(nn.Linear(36,1,bias=False)), with a '\
        +'learning rate of 5e-4.',
    'Timestamp'  : str(datetime.datetime.utcnow().strftime(
        "%H:%M UTC %a %Y-%m-%d")),
    'Data source': 'CMIP6 global climate models for historical simulations '\
        +'with sea ice concentration output, climate modes calculated by CVDP '\
        +'(doi:10.1002/2014EO490002)',
    'Analysis'   : 'https://github.com/chrisrwp/low-frequency-'\
        +'variability/blob/main/neural_network/Train_4_ML_Models.ipynb',
}

r_val_by_GCM.to_netcdf(
    '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
    +'test_r_values_linear_CMIP6_var_9_all_IPO.nc'
)

## Now do the same for the test members for the LEs. 'Perfect model' test

In [20]:
#load the linear model weights of each large ensemble (every GCM in
#LE_GCM_list except the last), stacked along a new 'month' dimension (1-12)
LE_weights = {}
for GCM in LE_GCM_list[:-1]:
    all_months_LE = []
    for month_ in np.arange(1,13):
        weights_file = (
            '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'
            f'weights_linear_{GCM}_month_{str(month_).zfill(2)}_'
            'var_9_all_IPO_lowpass_filt.nc'
        )
        all_1_weights_xr = xr.open_dataset(weights_file)
        all_months_LE.append(all_1_weights_xr['weights'])

    all_months_LE = xr.concat(all_months_LE, dim='month')
    all_months_LE['month'] = np.arange(1,13)
    LE_weights[GCM] = all_months_LE

In [21]:
#'perfect model' test: for each large ensemble, correlate the predictions of
#its own linear model with the target SIC for every one of its test members
n_yr = 2014 - 1941 + 1 #number of years per member for 1941-2014 (74)

LE_test = {}
for GCM in LE_GCM_list[:-1]:
    print(datetime.datetime.now(), GCM)
    month_r_vals = []
    for month_ in np.arange(1,13):
        
        lag_r_vals = []
        for lag_ in np.arange(1,21):

            CVDP_train, CVDP_valid, CVDP_test = load_CVDP(
                GCM, month_, lag_, [1941,2014], 
                extra_drop=None, white_noise=False,
            )

            # CVDP_test = torch.cat(
            #     (CVDP_test[:,0:5], CVDP_test[:,8:29], CVDP_test[:,32:]), dim=1
            # )
            #remove feature columns 16-19, 24-35 and 48-51 so the number of
            #features matches the trained weights
            CVDP_test = torch.cat(
                (CVDP_test[:,0:16], CVDP_test[:,20:24], 
                 CVDP_test[:,36:48], CVDP_test[:,52:]), dim=1
            )
            features_test = np.array(CVDP_test)

            #the test rows are the members stacked one after another, n_yr
            #rows each; refuse partial members instead of silently dropping
            n_test_mem, n_extra_rows = divmod(features_test.shape[0], n_yr)
            if n_extra_rows:
                raise ValueError(
                    f'{GCM}: {features_test.shape[0]} test rows is not a '
                    f'multiple of {n_yr} years per member')

            region_r_vals = []
            for region_ in [1,2,3,4,5,6,11]:
                target_train, target_valid, target_test = load_SIC(
                     GCM, month_, region_, [1941,2014], 
                )

                weights_sel = LE_weights[GCM].sel(
                    month=month_, region=region_, lag=lag_).values
                #linear model (no bias) prediction; broadcasting the weights
                #over all rows gives the same products as tiling them
                prediction = np.sum(
                    features_test * weights_sel[np.newaxis, :], axis=1)

                #correlate prediction and target separately for each member
                all_test_mem = []
                for i in range(n_test_mem):
                    all_test_mem.append(
                        np.corrcoef(prediction[i*n_yr:(i+1)*n_yr], 
                                    target_test[i*n_yr:(i+1)*n_yr])[0][1])

                region_r_vals.append(all_test_mem)

            lag_r_vals.append(region_r_vals)

        month_r_vals.append(lag_r_vals)

    month_r_vals_xr = xr.DataArray(
        data = month_r_vals,
        dims = ['month', 'lag', 'region', 'member'],
        coords = {
            'month':np.arange(1,13), 
            'lag':np.arange(1,21), 
            'region':[1,2,3,4,5,6,11],
            #integer member index 0..n_test_mem-1 (previously float values)
            'member':np.arange(n_test_mem),
        },
    )
    
    LE_test[GCM] = month_r_vals_xr

2023-10-18 10:04:59.979893 ACCESS-ESM1-5
2023-10-18 10:05:38.250376 CESM2-LENS
2023-10-18 10:06:16.970812 CNRM-CM6-1
2023-10-18 10:06:53.347515 CanESM5
2023-10-18 10:07:30.345063 EC-Earth3
2023-10-18 10:08:06.183012 GISS-E2-1-G
2023-10-18 10:08:45.140309 GISS-E2-1-H
2023-10-18 10:09:20.158155 IPSL-CM6A-LR
2023-10-18 10:09:56.412165 MIROC-ES2L
2023-10-18 10:10:32.358459 MIROC6
2023-10-18 10:11:09.897982 MPI-ESM1-2-LR
2023-10-18 10:11:46.024587 NorCPM1


In [27]:
#combine the per-GCM test r values (one data variable per large ensemble)
LE_test = xr.Dataset(LE_test)

n_LE_GCM = len(LE_test.data_vars)
LE_test.attrs = {
    'Description': 'Pearson correlation coefficient for test data (final 10% '\
        +f'of members) from {n_LE_GCM} GCM large ensembles as per the data '\
        +'variables. The linear model was trained and validated on the first '\
        +'75% and the following 15% of members of the same GCM. The model '\
        +'features were the following '\
        +'modes of variability: AMO, IPO, NINO34, PDO, ATN, '\
        +'NPO, PNA, NAO, TAS. The targets were 2 year lowpass '\
        +'filtered regional Arctic sea ice concentration anomalies for lag '\
        +'times of 1-20 years. Regions are defined by NSIDC MASIE-NH Version'\
        +' 1 (doi:10.7265/N5GT5K3K). The climate modes are computed by the '\
        +'Climate Variability Diagnostics Package (CVDP). The model was '\
        +'trained using an L1 loss function and Adam optimizer with 2000 '\
        +'epochs: nn.Sequential(nn.Linear(36,1,bias=False)), with a '\
        +'learning rate of 5e-4.',
    'Timestamp'  : str(datetime.datetime.utcnow().strftime(
        "%H:%M UTC %a %Y-%m-%d")),
    'Data source': 'CMIP6 global climate models for historical simulations '\
        +'with sea ice concentration output, climate modes calculated by CVDP '\
        +'(doi:10.1002/2014EO490002)',
    'Analysis'   : 'https://github.com/chrisrwp/low-frequency-'\
        +'variability/blob/main/neural_network/Train_4_ML_Models.ipynb',
}

LE_test.to_netcdf(
    '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
    +'test_r_values_linear_LEs_var_9_all_IPO.nc'
)

# 3. Remove one variable at a time from the LE linear models, retrain on the useful variables

In [274]:
def train_linear_model_remove(model_name, month_, region_list, lag_list, 
                              start_end_yr, n_epoch, learn_rate, 
                              extra_drop_=None, white_noise_=None):
    '''
    Train a single-layer linear model (no bias) for one GCM and one month, 
    separately for every requested region and lag, optionally with one
    feature variable removed.
    
    Parameters
    ----------
    model_name: str,
        The name of the GCM (or 'CMIP6'/'CMIP6_30') whose data are loaded by
        load_CVDP and load_SIC.
    month_: int,
        The month of sea ice concentration anomalies on which to train.
    region_list: list of ints,
        The regional sea ice anomalies to train a separate model on.
    lag_list: list of ints,
        The lags (years) to train a separate model on.
    start_end_yr: list, length 2 with integers,
        The start and end years (inclusive) for the sea ice concentration data.
    n_epoch: int, 
        Number of training epochs for every region/lag model (must be >= 1).
    learn_rate: float,
        Learning rate for the Adam optimizer.
    extra_drop_: None, or a string or list of strings,
        Passed to load_CVDP as extra_drop. When truthy, the first column whose
        name in var_list_use starts with extra_drop_ (np.char.find == 0) is
        also removed from the training and validation features.
    white_noise_: None or bool,
        Passed to load_CVDP as white_noise. When truthy, the last 3 feature
        columns are removed after loading.
                      
    Returns
    ----------
    r_values_xr: xarray.Dataset,
        The highest validation Pearson correlation over all epochs, per region
        and lag.
    all_1_weights_xr: xarray.Dataset,
        The linear model weights from that same best epoch, per region, lag
        and feature ('mode_month').

    Raises
    ------
    ValueError
        If n_epoch < 1, or if the number of feature columns left after the
        column removal differs from the expected n_features.
    '''
    if n_epoch < 1:
        raise ValueError(f'n_epoch must be at least 1, got {n_epoch}')

    #expected number of features after the columns are removed below
    fewer_var = 0 if white_noise_ else 1
    if extra_drop_ is not None:
        n_features = int(51 - len(extra_drop_) - fewer_var)
    else:
        n_features = int(51 - fewer_var)
    
    all_1_weights = np.empty([len(region_list), len(lag_list), n_features], 
                             dtype=float)
    all_r_values = np.empty([len(region_list), len(lag_list)], dtype=float)

    for region_i, region_ in enumerate(region_list):
        
        for lag_i, lag_ in enumerate(lag_list):
            
            #load the feature and target data for the correct model, month,
            #region and lag
            CVDP_train, CVDP_valid, CVDP_test = load_CVDP(
                model_name, month_, lag_, start_end_yr, 
                extra_drop=extra_drop_, white_noise=white_noise_
            )
            
            #drop columns 5-7 and 29-31, the same indices removed from
            #var_month_list when var_list_use is built, so columns line up
            #with var_list_use
            CVDP_train = torch.cat(
                (CVDP_train[:,0:5], CVDP_train[:,8:29], CVDP_train[:,32:]),
                dim=1)
            CVDP_valid = torch.cat(
                (CVDP_valid[:,0:5], CVDP_valid[:,8:29], CVDP_valid[:,32:]),
                dim=1)
            
            if extra_drop_:
                #remove the first variable whose name starts with extra_drop_
                #NOTE(review): extra_drop_ is also passed to load_CVDP above;
                #if load_CVDP already drops it, this removes a second column
                #- confirm
                remove_i = np.where(
                    np.char.find(var_list_use, extra_drop_)==0)[0][0]
            
                CVDP_train = torch.cat(
                    (CVDP_train[:,0:remove_i], CVDP_train[:,remove_i+1:]), dim=1)
                CVDP_valid = torch.cat(
                    (CVDP_valid[:,0:remove_i], CVDP_valid[:,remove_i+1:]), dim=1)
                        
            if white_noise_:
                CVDP_train = CVDP_train[:,:-3]
                CVDP_valid = CVDP_valid[:,:-3]

            #fail with a clear message instead of a matmul shape error
            if CVDP_train.shape[1] != n_features:
                raise ValueError(
                    f'expected {n_features} feature columns after removal, '
                    f'got {CVDP_train.shape[1]}')
            
            target_train, target_valid, target_test = load_SIC(
                model_name, month_, region_, start_end_yr, 
            )

            ML_model_use = nn.Sequential(nn.Linear(n_features,1,bias=False))
            optimizer = torch.optim.Adam(
                params=ML_model_use.parameters(), 
                lr=learn_rate
            )

            ##### train the model, validating after every epoch #####
            valid_r = []
            epoch_weights = []
            for epoch in range(n_epoch):
                prediction = ML_model_use(CVDP_train)
                optimizer.zero_grad() #reset the gradients to zeros
                loss = loss_fcn(prediction[:,0].double(), target_train)
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    p_val = ML_model_use(CVDP_valid)
                valid_r.append(np.corrcoef(p_val[:,0].detach().numpy(),
                                           target_valid.detach().numpy())[1][0])

                #snapshot this epoch's weights; .numpy() shares memory with
                #the parameter, which later optimizer steps update in place,
                #so a copy is needed (storing the module itself would only
                #ever give the final-epoch weights)
                epoch_weights.append(
                    ML_model_use[0].weight[0,:].detach().numpy().copy())

            #keep the epoch with the highest validation r, and its weights
            best_epoch = np.argmax(valid_r)
            all_r_values[region_i][lag_i] = valid_r[best_epoch]
            all_1_weights[region_i][lag_i] = np.ravel(epoch_weights[best_epoch])
        
    #names of the features that remain, for the 'mode_month' coordinate
    if (white_noise_ is None) and (extra_drop_ is None):
        mode_month_list = var_list_use[:-1]
    elif (len(extra_drop_) == 1) and (white_noise_ == True):
        mode_month_list = np.delete(np.array(var_list_use).copy(),remove_i)
    else:
        extra_drop_inc_rand = np.append(extra_drop_, 'RAND')
        mode_month_list = var_month_list.copy()
        for i in extra_drop_inc_rand:
            mode_month_list = [
                item for item in mode_month_list if i not in item]

    #best validation r values per region and lag
    r_values_xr = xr.Dataset(
        data_vars = {
            'r_value':(['region', 'lag'], all_r_values)},
        coords = {
            'region':region_list, 'lag':lag_list},
    )

    #the linear 1 layer neural network weights from the best epochs
    all_1_weights_xr = xr.Dataset(
        data_vars = {
            'weights':(['region', 'lag', 'mode_month'], all_1_weights),
        },
        coords = {'region':region_list, 'lag':lag_list, 
                  'mode_month':mode_month_list,
        },
    )
        
    return(r_values_xr, all_1_weights_xr)

## Remove one variable at a time from the LE models

In [7]:
#months and GCMs on which to retrain the linear model for each region, based
#on the best 5 year lag r2 above persistence. The month for each region is
#the one from the best CMIP6 values and is shared by all listed GCMs.
#Each row: (region key, month, GCMs)
region_month_GCMs = [
    ('1',  9, ['CanESM5', 'CESM2-LENS', 'MIROC6', 'GISS-E2-1-G',
               'IPSL-CM6A-LR', 'GISS-E2-1-H']),
    ('2', 10, ['CanESM5', 'CESM2-LENS', 'MIROC6', 'GISS-E2-1-G',
               'IPSL-CM6A-LR', 'MIROC-ES2L']),
    ('3',  9, ['CanESM5', 'MIROC6', 'ACCESS-ESM1-5', 'IPSL-CM6A-LR',
               'MIROC-ES2L']),
    ('4', 10, ['CanESM5', 'IPSL-CM6A-LR']),
    ('5', 10, ['CanESM5', 'CESM2-LENS', 'MIROC6', 'MPI-ESM1-2-LR',
               'GISS-E2-1-H']),
    ('6',  1, ['CanESM5', 'GISS-E2-1-H']),
    ('11', 8, ['CanESM5', 'MIROC6', 'GISS-E2-1-G', 'ACCESS-ESM1-5']),
]

#one month entry per GCM, so both dictionaries always have equal lengths
good_GCM_months = {
    region: [month] * len(GCMs) for region, month, GCMs in region_month_GCMs
}
good_GCM_names = {
    region: list(GCMs) for region, month, GCMs in region_month_GCMs
}


In [278]:
#an L1 loss keeps all of the fitted linear models sparse
loss_fcn = torch.nn.L1Loss()

#predictors for the drop-one experiment: var_month_list without its last three
#entries and without positions 5-7 and 29-31
var_list_use = np.delete(np.array(var_month_list[:-3]).copy(), [5,6,7,29,30,31])

for region_key in [1,2,3,4,5,6,11]:
    print(datetime.datetime.now(), region_key)
    region_str = str(region_key)

    for GCM_i, GCM in enumerate(good_GCM_names[region_str]):
        print(datetime.datetime.now(), GCM)

        doi_model = CMIP6_info['doi'].sel(model=GCM).values
        month__ = good_GCM_months[region_str][GCM_i]

        #both output files share everything after their prefix
        file_tail = (f'{GCM}_month_{str(month__).zfill(2)}_'
                     f'var_15_drop_1_season_region_{region_key}.nc')
        r_values_path = ('/glade/work/cwpowell/low-frequency-variability/'
                         'PyTorch_models/validation_r_values_linear_'
                         + file_tail)
        weights_path = ('/glade/work/cwpowell/low-frequency-variability/'
                        'PyTorch_models/weights_linear_' + file_tail)

        #an existing r-value file means this GCM/region was already processed
        if glob.glob(r_values_path):
            continue

        r_values_list, weights_list = [], []
        for drop_var in var_list_use:
            print(datetime.datetime.now(), drop_var)

            #'RAND_1' is the reference run: no variable is dropped and no
            #white-noise flag is passed
            reference_run = (drop_var == 'RAND_1')
            r_values_xr, all_weights_xr = train_linear_model_remove(
                model_name=GCM, month_=month__, region_list=[region_key],
                lag_list=np.arange(1,21), start_end_yr=[1941,2014],
                n_epoch=2000, learn_rate=5e-4,
                extra_drop_=None if reference_run else [drop_var],
                white_noise_=None if reference_run else True,
            )

            r_values_list.append(r_values_xr)
            weights_list.append(all_weights_xr)

        #stack the drop-one experiments along a new 'drop_var' dimension
        r_values_save = xr.concat((r_values_list), dim='drop_var')
        r_values_save['drop_var'] = var_list_use

        weights_save = xr.concat((weights_list), dim='drop_var')
        weights_save['drop_var'] = var_list_use

        r_values_attrs = {
            'Description': (
                'Pearson correlation coefficient for validation '
                f'data (15%) of members from the global climate model {GCM}. '
                '14 of the 15 seasonal (except NPI and IPO) modes '
                'of variability as computed by the Climate Variability '
                'Diagnostics Package (CVDP) with 2 year lowpass filtered '
                'sea ice concentration. These climate modes are as follows: '
                'AMO, IPO, NINO34, PDO, AMM, ATN, IOD, NPI, NAM, NPO, PNA, '
                'NAO, SAM, TAS, RAND. The linear model is trained at '
                'different lag times of 1-20 years and for each '
                'region (regions are defined by NSIDC MASIE-NH Version 1 '
                '(doi:10.7265/N5GT5K3K). Using an L1 loss function and Adam '
                'optimizer with 2000 epochs: nn.Sequential(nn.Linear(50,1,'
                'bias=False)), learning rate = 5e-4.'
            ),
            'Timestamp': str(datetime.datetime.utcnow().strftime(
                "%H:%M UTC %a %Y-%m-%d")),
            'Data source': (
                f'CMIP6 global climate model {GCM}, doi:'
                f'{doi_model} for historical sea ice concentration '
                'simulation, climate modes calculated by CVDP '
                '(doi:10.1002/2014EO490002)'
            ),
            'Analysis': (
                'https://github.com/chrisrwp/low-frequency-'
                'variability/blob/main/neural_network/Train_4_ML_Models.ipynb'
            ),
        }

        r_values_save.attrs = r_values_attrs
        r_values_save.to_netcdf(r_values_path)

        #same attributes as the r values, with a weights-specific description
        weights_attrs = {
            **r_values_attrs,
            'Description': (
                'Linear weights including 14 of the '
                '15 seasonal (except NPI and IPO) modes of variability as '
                'computed by the Climate Variability '
                'Diagnostics Package (CVDP) with 2 year lowpass filtered sea ice '
                'concentration. The linear model is trained/validated on the '
                'first 75/15% of the members from the global climate model '
                f'{GCM}. The climate modes are as follows: AMO, IPO, NINO34, '
                'PDO, AMM, ATN, IOD, NPI, NAM, NPO, PNA, NAO, SAM, TAS, RAND. '
                'The linear model is trained at different lag times of 1-20 '
                'years and for each region (regions are defined by NSIDC '
                'MASIE-NH Version 1 (doi:10.7265/N5GT5K3K). Using an L1 loss '
                'function and Adam optimizer with 2000 epochs: '
                'nn.Sequential(nn.Linear(50,1,bias=False)), learning rate = 5e-4.'
            ),
        }

        weights_save.attrs = weights_attrs
        weights_save.to_netcdf(weights_path)

2023-03-09 13:03:10.320037 1
2023-03-09 13:03:10.320097 CanESM5
2023-03-09 13:03:10.321794 CESM2-LENS
2023-03-09 13:03:10.322632 MIROC6
2023-03-09 13:03:10.323275 GISS-E2-1-G
2023-03-09 13:03:10.323908 IPSL-CM6A-LR
2023-03-09 13:03:10.324530 GISS-E2-1-H
2023-03-09 13:03:10.325166 2
2023-03-09 13:03:10.325207 CanESM5
2023-03-09 13:03:10.325816 CESM2-LENS
2023-03-09 13:03:10.326445 MIROC6
2023-03-09 13:03:10.327074 GISS-E2-1-G
2023-03-09 13:03:10.327697 IPSL-CM6A-LR
2023-03-09 13:03:10.328353 MIROC-ES2L
2023-03-09 13:03:10.329039 3
2023-03-09 13:03:10.329082 CanESM5
2023-03-09 13:03:10.329746 MIROC6
2023-03-09 13:03:10.330453 ACCESS-ESM1-5
2023-03-09 13:03:10.331127 IPSL-CM6A-LR
2023-03-09 13:03:10.331775 MIROC-ES2L
2023-03-09 13:03:10.337396 4
2023-03-09 13:03:10.337441 CanESM5
2023-03-09 13:03:10.338098 IPSL-CM6A-LR
2023-03-09 13:03:10.338724 5
2023-03-09 13:03:10.338765 CanESM5
2023-03-09 13:03:10.339385 CESM2-LENS
2023-03-09 13:03:10.340020 MIROC6
2023-03-09 13:03:10.340645 MPI-ESM1-

### Remove all variables with a worse effect on skill than a random variable for the LEs

In [8]:
#read the chosen season for each mode of variability; as used below the
#structure is {str(region): {variable name: month of that season}}
best_season_file = (
    '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'
    'best_season_MMLE_3_lags_2_20_using_drop_1.pickle'
)
with open(best_season_file, 'rb') as pickle_handle:
    best_seasons_dict = pickle.load(pickle_handle)

In [9]:
def train_linear_useful_season(
    model_name, month_, region_, best_seasons, lag_list, start_end_yr,
    n_epoch, learn_rate):
    """Train a bias-free linear model on one chosen season per climate mode.

    For every lag in ``lag_list`` the CVDP variables listed in
    ``best_seasons[str(region_)]`` (each sampled in its own season) are used
    to predict the sea ice concentration of ``region_`` in ``month_``.
    Members are split with the globals ``LE_train_mem`` / ``LE_valid_mem``,
    the target comes from ``load_SIC`` and the loss is the global
    ``loss_fcn``.

    Parameters
    ----------
    model_name : str
        GCM name used in the CVDP file name and member dictionaries.
    month_ : int
        Target month of the sea ice concentration.
    region_ : int
        Region number; ``str(region_)`` indexes ``best_seasons``.
    best_seasons : dict
        ``{str(region): {variable name: season month (1, 4, 7 or 10)}}``.
    lag_list : sequence of int
        Lags (years) between predictors and target.
    start_end_yr : [int, int]
        First and last target years.
    n_epoch : int
        Number of Adam optimisation steps per lag.
    learn_rate : float
        Adam learning rate.

    Returns
    -------
    (xarray.Dataset, xarray.Dataset)
        Validation r value per lag, and the linear weights per lag and
        variable, both taken from the epoch with the lowest validation loss.
    """
    var_list_to_use = list(best_seasons[str(region_)].keys())

    #one input feature per selected variable
    n_features = len(var_list_to_use)

    all_1_weights = np.empty([len(lag_list), n_features], dtype=float)
    all_r_values = np.empty([len(lag_list)], dtype=float)

    #the CVDP data do not depend on the lag: read them once and close the file
    with xr.open_dataset(
        '/glade/work/cwpowell/low-frequency-variability/input_data/'\
        +'CVDP_standardized_linear_detrended_1920_2014_historical_'\
        +f'{model_name}.nc'
    ) as CVDP_file:
        CVDP_year_month = CVDP_file.to_array('variable').sortby('time').load()

    #split the time series into the four seasonal months, indexed by year
    month_seperate = []
    for i in [1,4,7,10]:
        temp_data = CVDP_year_month.sel(
            time=CVDP_year_month['time.month']==i)
        temp_data['time'] = np.arange(1920,2015)
        month_seperate.append(temp_data)

    CVDP_data = xr.concat((month_seperate), dim='month')
    CVDP_data['month'] = [1,4,7,10]
    CVDP_data = CVDP_data.rename({'time':'year'}).sortby('member')

    #the target does not depend on the lag either
    target_train, target_valid, _ = load_SIC(
        model_name, month_, region_, start_end_yr,
    )

    for lag_i, lag_ in enumerate(lag_list):

        #stack the CVDP data into X members and years, Y features
        train_parts = []
        valid_parts = []
        for var_ in var_list_to_use:
            lag_season = best_seasons[str(region_)][var_]

            #a predictor season at or after the target month is taken one
            #year earlier (presumably so the predictor precedes the target)
            extra_year = 1 if lag_season >= month_ else 0

            CVDP_month_data = CVDP_data.sel(
                month=lag_season).sel(variable=var_).sel(
                year=slice(str(start_end_yr[0]-lag_-extra_year), 
                           str(start_end_yr[1]-lag_-extra_year)))
            CVDP_month_data['year'] = np.arange(
                0,start_end_yr[1]-start_end_yr[0]+1)

            train_parts.append(CVDP_month_data.sel(
                member=LE_train_mem[model_name]))
            valid_parts.append(CVDP_month_data.sel(
                member=LE_valid_mem[model_name]))

        CVDP_train = torch.Tensor(xr.concat((train_parts),'variable').stack(
            member_time=('member','year')).transpose().values)
        CVDP_valid = torch.Tensor(xr.concat((valid_parts),'variable').stack(
            member_time=('member','year')).transpose().values)

        #prepare to train the linear model
        ML_model_use = nn.Sequential(nn.Linear(n_features,1,bias=False))
        optimizer = torch.optim.Adam(
            params=ML_model_use.parameters(), lr=learn_rate
        )

        ##### train the model #####
        valid_r, valid_loss, epoch_weights = [], [], []
        for epoch in range(n_epoch):
            # TRAIN
            prediction = ML_model_use(CVDP_train)
            optimizer.zero_grad() #reset the gradients to zeros
            loss = loss_fcn(prediction[:,0].double(), target_train)
            loss.backward()
            optimizer.step()

            #copy the weights: the module is updated in place, so storing
            #the module itself would only ever give the final weights
            epoch_weights.append(
                ML_model_use[0].weight[0,:].detach().numpy().copy())

            ##### validate the model #####
            with torch.no_grad():
                p_val = ML_model_use(CVDP_valid)
                loss_val = loss_fcn(p_val[:,0].double(), target_valid)

                valid_loss.append(loss_val.item())
                valid_r.append(np.corrcoef(p_val[:,0].detach().numpy(),
                                target_valid.detach().numpy())[1][0])

        #select the epoch with the lowest validation loss
        best_epoch = int(np.argmin(valid_loss))
        all_r_values[lag_i] = valid_r[best_epoch]
        all_1_weights[lag_i] = epoch_weights[best_epoch]

    #save the r values
    r_values_xr = xr.Dataset(
        data_vars = {'r_value':(['lag'], all_r_values)},
        coords = {'lag':lag_list},
    )

    #save the linear model weights to NetCDF
    all_1_weights_xr = xr.Dataset(
        data_vars = {'weights':(['lag', 'mode_month'], all_1_weights), },
        coords = {'lag':lag_list, 'mode_month':var_list_to_use,},
    )

    return(r_values_xr, all_1_weights_xr)

In [12]:
loss_fcn  = torch.nn.L1Loss() #keep all models sparse

month_names = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 
               'August', 'September', 'October', 'November', 'December']

#training settings; also quoted in the NetCDF metadata below so the stored
#description always matches what was actually run
n_epoch_use = 2000
learn_rate_use = 2e-4

#NOTE(review): region 1 is not in this list (the printed log also starts at
#region 2); add it back to rerun every region
for region_key in [2,3,4,5,6,11]:
    print(datetime.datetime.now(), region_key)

    #every GCM of a region uses the same month in good_GCM_months
    month__ = good_GCM_months[str(region_key)][0]
    best_seasons_region = best_seasons_dict[str(region_key)]
    #one linear-model input per selected variable
    n_features = len(best_seasons_region)

    GCM_r_vals = []
    GCM_weights = []
    for GCM in good_GCM_names[str(region_key)]:
        print(datetime.datetime.now(), GCM)

        r_values_xr, all_1_weights_xr = train_linear_useful_season(
            model_name=GCM, month_=month__, region_=region_key, 
            best_seasons=best_seasons_dict, lag_list=np.arange(1,21), 
            start_end_yr=[1941,2014], n_epoch=n_epoch_use,
            learn_rate=learn_rate_use,
        )

        GCM_r_vals.append(r_values_xr)
        GCM_weights.append(all_1_weights_xr)

    region_r_vals = xr.concat((GCM_r_vals), dim='model_name')
    region_r_vals['model_name'] = good_GCM_names[str(region_key)]

    region_weights = xr.concat((GCM_weights), dim='model_name')
    region_weights['model_name'] = good_GCM_names[str(region_key)]

    training_text = (
        'Training uses an L1 loss function and Adam optimizer with '
        f'{n_epoch_use} epochs: nn.Sequential(nn.Linear({n_features},1,'
        f'bias=False)), learning rate = {learn_rate_use:g}.'
    )

    r_values_attrs = {
        'Description': 'Pearson correlation coefficient for validation '\
            +'data (15%) of members from the global climate models as '\
            +f'follows {good_GCM_names[str(region_key)]} for '\
            +f'{month_names[month__-1]}, with the highest predictability '\
            +'season of each mode of variability chosen: '\
            +f'{best_seasons_region}. Modes computed by the '\
            +'Climate Variability Diagnostics Package (CVDP) with 2 year '\
            +'lowpass filtered sea ice concentration. The linear model is '\
            +'trained at different lag times of 1-20 years, the region is '\
            +f'{region_key} as defined by NSIDC MASIE-NH Version 1 '\
            +'(doi:10.7265/N5GT5K3K). '+training_text,
        'Timestamp'  : str(datetime.datetime.utcnow().strftime(
            "%H:%M UTC %a %Y-%m-%d")),
        'Data source': 'CMIP6 global climate model historical simulations '\
            +'for historical sea ice concentration, with climate modes '\
            +'calculated by CVDP (doi:10.1002/2014EO490002)',
        'Analysis'   : 'https://github.com/chrisrwp/low-frequency-'\
            +'variability/blob/main/neural_network/Train_4_ML_Models.ipynb',
    }

    region_r_vals.attrs = r_values_attrs
    #NOTE: unlike the weights file below, this name has no '_' between the
    #month and 'best_season_only'; kept as-is so existing readers still work
    region_r_vals.to_netcdf(
        '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
        +f'validation_r_values_linear_LEs_region_{region_key}_month_{month__}'\
        +'best_season_only.nc'
    )

    weights_attrs = r_values_attrs.copy()
    weights_attrs['Description'] = 'Linear weights of the useful seasons of '\
        +'modes of variability as computed by the Climate Variability '\
        +'Diagnostics Package (CVDP) as follows: '\
        +f'{best_seasons_region}. The target is 2 year lowpass '\
        +f'filtered sea ice concentrations for {month_names[month__-1]}. '\
        +'The linear model is trained/validated on the first 75/15% of the '\
        +'members from the global climate models '\
        +f'{good_GCM_names[str(region_key)]}. The linear model is trained at '\
        +f'different lag times of 1-20 years and for region {region_key} as '\
        +'defined by NSIDC MASIE-NH Version 1 (doi:10.7265/N5GT5K3K). '\
        +training_text

    region_weights.attrs = weights_attrs
    region_weights.to_netcdf(
        '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
        +f'weights_linear_LEs_region_{region_key}_month_{month__}_'\
        +'best_season_only.nc'
    )


2023-03-12 08:34:44.098146 2
2023-03-12 08:34:44.098260 CanESM5
2023-03-12 08:40:07.686749 CESM2-LENS
2023-03-12 08:45:32.221758 MIROC6
2023-03-12 08:51:00.016699 GISS-E2-1-G
2023-03-12 08:56:26.582114 IPSL-CM6A-LR
2023-03-12 09:01:49.879249 MIROC-ES2L
2023-03-12 09:07:13.078281 3
2023-03-12 09:07:13.078495 CanESM5
2023-03-12 09:12:37.781233 MIROC6
2023-03-12 09:18:02.065041 ACCESS-ESM1-5
2023-03-12 09:23:25.024493 IPSL-CM6A-LR
2023-03-12 09:28:48.349601 MIROC-ES2L
2023-03-12 09:34:09.434481 4
2023-03-12 09:34:09.434700 CanESM5
2023-03-12 09:39:32.117619 IPSL-CM6A-LR
2023-03-12 09:44:52.727347 5
2023-03-12 09:44:52.727524 CanESM5
2023-03-12 09:50:16.965656 CESM2-LENS
2023-03-12 09:55:38.767928 MIROC6
2023-03-12 10:01:01.103115 MPI-ESM1-2-LR
2023-03-12 10:06:21.618735 GISS-E2-1-H
2023-03-12 10:11:42.766342 6
2023-03-12 10:11:42.766537 CanESM5
2023-03-12 10:17:06.473921 GISS-E2-1-H
2023-03-12 10:22:28.301052 11
2023-03-12 10:22:28.301229 CanESM5
2023-03-12 10:27:50.396612 MIROC6
2023-03-

In [18]:
# loss_fcn  = torch.nn.L1Loss() #keep all models sparse
# lag_range = np.arange(5,11)

# month_names = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 
#                'August', 'September', 'October', 'November', 'December']

# for region_key in [1,2,3,4,5,6,11]:
#     print(datetime.datetime.now(), region_key)
    
#     GCM_r_vals = []
#     GCM_weights = []
#     drop_dict = {}
#     for GCM_i, GCM in enumerate(good_GCM_names[str(region_key)]):
        
#         doi_model = CMIP6_info['doi'].sel(model=GCM).values

#         month__ = good_GCM_months[str(region_key)][GCM_i]
        
#         r_diff = r_vals_LE_drop1.sel(region=region_key).sel(lag=lag_range).mean(
#             'lag').sel(model_name=GCM)**2 - r_vals_LE_drop1.sel(
#             region=region_key).sel(lag=lag_range).mean('lag').sel(
#             model_name=GCM).sel(drop_var='RAND')**2

#         drop_list = r_diff['drop_var'].where(r_diff>0, drop=True).values
#         drop_dict[GCM] = drop_list
#         print(datetime.datetime.now(), GCM, drop_list)
        
#         r_values_xr, all_weights_xr = train_linear_model_remove(
#             model_name=GCM, month_=month__, region_list=[region_key], 
#             lag_list=np.arange(1,21), start_end_yr=[1941,2014], 
#             n_epoch=2000, learn_rate=5e-4, extra_drop_=drop_list,
#             white_noise_=False,
#         )  
        
#         GCM_r_vals.append(r_values_xr)
#         GCM_weights.append(all_weights_xr)
        
#     region_r_vals = xr.concat((GCM_r_vals), dim='model_name')
#     region_r_vals['model_name'] = good_GCM_names[str(region_key)]
    
#     region_weights = xr.concat((GCM_weights), dim='model_name')
#     region_weights['model_name'] = good_GCM_names[str(region_key)]

#     r_values_attrs = {
#         'Description': 'Pearson correlation coefficient for validation '\
#             +'data (15%) of members from the global climate models as '\
#             +f'follows {good_GCM_names[str(region_key)]} for '\
#             +f'{month_names[month__-1]}, with modes of '\
#             +f'variability yielding lower predictive skill than a random '\
#             +f'variable dropped as follows: {drop_dict}. The full list of 15 '\
#             +'variables is: AMO, IPO, NINO34, PDO, AMM, ATN, IOD, NPI, NAM, '\
#             +'NPO, PNA, NAO, SAM, TAS, as computed by the Climate Variability '\
#             +'Diagnostics Package (CVDP) with 2 year lowpass filtered '\
#             +'sea ice concentration. The linear model is trained at '\
#             +f'different lag times of 1-20 years, the region is {region_key} '\
#             +'as defined by NSIDC MASIE-NH Version 1 (doi:10.7265/N5GT5K3K). '\
#             +'Training uses an L1 loss function and Adam optimizer with 2000 '\
#             +'epochs: nn.Sequential(nn.Linear(n_features,1,bias=False)), '\
#             +'learning rate = 5e-4.',
#         'Timestamp'  : str(datetime.datetime.utcnow().strftime(
#             "%H:%M UTC %a %Y-%m-%d")),
#         'Data source': f'CMIP6 global climate model historical simulations '\
#             +'for historical sea ice concentration, with climate modes '\
#             +'calculated by CVDP (doi:10.1002/2014EO490002)',
#         'Analysis'   : 'https://github.com/chrisrwp/low-frequency-'\
#             +'variability/blob/main/neural_network/Train_4_ML_Models.ipynb',
#     }

#     region_r_vals.attrs = r_values_attrs
#     region_r_vals.to_netcdf(
#         '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
#         +f'validation_r_values_linear_LEs_region_{region_key}_month_{month__}'\
#         +f'dropped_useless_vars.nc'
#     )


#     weights_attrs = r_values_attrs.copy()
#     weights_attrs['Description'] = 'Linear weights of the useful modes of '\
#         +'variability as computed by the Climate Variability Diagnoistics'\
#         +'Package (CVDP) with 2 year lowpass filtered sea ice concentration '\
#         +f'for {month_names[month__-1]}. The linear '\
#         +f'model is trained/validated on the first 75/15% of the members '\
#         +f'from the global climate models {good_GCM_names[str(region_key)]}. '\
#         +'Climate modes which do not yield a higher predictive skill than a '\
#         +f'random variable are removed as follows {drop_dict}. The full list '\
#         +'of variables is as follows: AMO, IPO, NINO34, PDO, AMM, ATN, IOD, '\
#         +'NPI, NAM, NPO, PNA, NAO, SAM, TAS. The linear model is trained at '\
#         +f'different lag times of 1-20 years and for region {region_key} as'\
#         +'defined by NSIDC MASIE-NH Version 1 (doi:10.7265/N5GT5K3K). '\
#         +'Training uses an L1 loss function and Adam optimizer with 2000 '\
#         +'epochs: nn.Sequential(nn.Linear(n_features,1,bias=False)), '\
#         +'learning rate = 5e-4.'

#     region_weights.attrs = weights_attrs
#     region_weights.to_netcdf(
#         '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
#         +f'weights_linear_LEs_region_{region_key}_month_{month__}_'\
#         +f'dropped_useless_vars.nc'
#     )

# 4. Remove one variable at a time from the CMIP6 multi-model large ensemble (MMLE; GCMs with 3+ members), then retrain on the useful variables

In [284]:
loss_fcn  = torch.nn.L1Loss() #keep all models sparse

#predictors for the drop-one experiment: var_month_list without its last three
#entries and without positions 5-7 and 29-31
var_list_use = np.delete(np.array(var_month_list[:-3]).copy(),[5,6,7,29,30,31])

for region_key in [1,2,3,4,5,6,11]:
    print(datetime.datetime.now(), region_key)

    #the month lists in good_GCM_months are constant within a region
    month__ = good_GCM_months[str(region_key)][0]

    #both output files share everything after their prefix
    file_tail = (f'CMIP6_month_{str(month__).zfill(2)}_'
                 f'var_15_drop_1_seasonal_region_{region_key}.nc')
    r_values_path = ('/glade/work/cwpowell/low-frequency-variability/'
                     'PyTorch_models/validation_r_values_linear_' + file_tail)
    weights_path = ('/glade/work/cwpowell/low-frequency-variability/'
                    'PyTorch_models/weights_linear_' + file_tail)

    #skip this one if already run this analysis
    if os.path.exists(r_values_path):
        continue

    r_values_list = []
    weights_list  = []

    for drop_var in var_list_use:
        print(datetime.datetime.now(), drop_var)

        #'RAND_1' is the reference run in which no variable is dropped
        reference_run = (drop_var == 'RAND_1')
        r_values_xr, all_weights_xr = train_linear_model_remove(
            'CMIP6', month_=month__, region_list=[region_key], 
            lag_list=np.arange(1,21), start_end_yr=[1941,2014], 
            n_epoch=2000, learn_rate=5e-4,
            extra_drop_=None if reference_run else [drop_var],
            white_noise_=None if reference_run else True,
        )

        r_values_list.append(r_values_xr)
        weights_list.append(all_weights_xr)

    r_values_save = xr.concat((r_values_list), dim='drop_var')
    r_values_save['drop_var'] = var_list_use

    weights_save = xr.concat((weights_list), dim='drop_var')
    weights_save['drop_var'] = var_list_use

    #NOTE(review): the notebook header quotes n=41 GCMs and the LE drop-one
    #cell quotes nn.Linear(50,...) for the same variable list; confirm the
    #"42" and "51" below
    r_values_attrs = {
        'Description': 'Pearson correlation coefficient for validation '\
            +'data of the second members from 42 CMIP6 GCMs. Training data '\
            +'was the first member. 14 of the 15 modes of variability as '\
            +'computed by the Climate Variability Diagnostics Package (CVDP) '\
            +'with 2 year lowpass filtered sea ice concentration. These '\
            +'climate modes are as follows: AMO, IPO, NINO34, PDO, AMM, ATN, '\
            +'IOD, NPI, NAM, NPO, PNA, NAO, SAM, TAS, RAND. The linear model '\
            +'is trained at different lag times of 1-20 years and for each '\
            +'region (regions are defined by NSIDC MASIE-NH Version 1 '\
            +'(doi:10.7265/N5GT5K3K). Using an L1 loss function and Adam '\
            +'optimizer with 2000 epochs: nn.Sequential(nn.Linear(51,1,'\
            +'bias=False)), learning rate = 5e-4.',
        'Timestamp'  : str(datetime.datetime.utcnow().strftime(
            "%H:%M UTC %a %Y-%m-%d")),
        'Data source': 'CMIP6 global climate model multi-model ensemble, '\
            +'doi:10.5194/gmd-9-1937-2016 for historical sea ice '\
            +'concentration simulation, climate modes calculated by CVDP '\
            +'(doi:10.1002/2014EO490002)',
        'Analysis'   : 'https://github.com/chrisrwp/low-frequency-'\
            +'variability/blob/main/neural_network/Train_4_ML_Models.ipynb',
    }

    r_values_save.attrs = r_values_attrs
    r_values_save.to_netcdf(r_values_path)

    weights_attrs = r_values_attrs.copy()
    weights_attrs['Description'] = 'Linear weights including 14 of the 15 '\
        +'modes of variability as computed by the Climate Variability '\
        +'Diagnostics Package (CVDP) with 2 year lowpass filtered sea ice '\
        +'concentration. The linear model is trained on the 1st member of 42 '\
        +'CMIP6 GCMs and validated on the second members. The climate modes '\
        +'are as follows: AMO, IPO, NINO34, PDO, AMM, ATN, IOD, NPI, NAM, '\
        +'NPO, PNA, NAO, SAM, TAS, RAND. The linear model is trained at '\
        +'different lag times of 1-20 years and for each region (regions are '\
        +'defined by NSIDC MASIE-NH Version 1 (doi:10.7265/N5GT5K3K). Using '\
        +'an L1 loss function and Adam optimizer with 2000 epochs: '\
        +'nn.Sequential(nn.Linear(51,1,bias=False)), learning rate = 5e-4.'

    weights_save.attrs = weights_attrs
    weights_save.to_netcdf(weights_path)

2023-03-10 09:54:57.682755 1
2023-03-10 09:54:57.683428 2
2023-03-10 09:54:57.683508 3
2023-03-10 09:54:57.683550 4
2023-03-10 09:54:57.683619 AMO_1
2023-03-10 09:55:51.048695 AMO_4
2023-03-10 09:56:44.936684 AMO_7
2023-03-10 09:57:38.676356 AMO_10
2023-03-10 09:58:32.428096 IPO_1
2023-03-10 09:59:25.946433 NINO34_1
2023-03-10 10:00:22.085388 NINO34_4
2023-03-10 10:01:18.912533 NINO34_7
2023-03-10 10:02:11.911358 NINO34_10
2023-03-10 10:03:04.731529 PDO_1
2023-03-10 10:03:57.569299 PDO_4
2023-03-10 10:04:50.442948 PDO_7
2023-03-10 10:05:43.267633 PDO_10
2023-03-10 10:06:35.969264 AMM_1
2023-03-10 10:07:28.636652 AMM_4
2023-03-10 10:08:21.662317 AMM_7
2023-03-10 10:09:15.293885 AMM_10
2023-03-10 10:10:10.025805 ATN_1
2023-03-10 10:11:03.525362 ATN_4
2023-03-10 10:11:56.319964 ATN_7
2023-03-10 10:12:50.402928 ATN_10
2023-03-10 10:13:43.677340 IOD_1
2023-03-10 10:14:40.468639 IOD_4
2023-03-10 10:15:39.367538 IOD_7
2023-03-10 10:16:38.130180 IOD_10
2023-03-10 10:17:36.977863 NPI_1
2023-03-

## Retrain the CMIP6 linear model with only the useful variables

In [28]:
#load the drop-one validation r values of the CMIP6 linear model for every
#month; each file is read into memory and then closed
r_vals_CMIP6_drop1 = []
for month_ in np.arange(1,13):

    with xr.open_dataset(
        '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
        +f'validation_r_values_linear_CMIP6_month_{str(month_).zfill(2)}'\
        +'_var_15_drop_1.nc'
    ) as r_values:
        r_vals_CMIP6_drop1.append(r_values['r_value'].load())

r_vals_CMIP6_drop1 = xr.concat((r_vals_CMIP6_drop1), dim='month')
r_vals_CMIP6_drop1['month'] = np.arange(1,13)

In [37]:
#names (sorted) of the GCMs with more than two good members, i.e. enough for
#separate train/validation/test members
CMIP6_GCM_list = [
    GCM for GCM in np.sort(list(good_GCM_mem.keys()))
    if len(good_GCM_mem[GCM]) > 2
]

In [42]:
loss_fcn  = torch.nn.L1Loss() #keep all models sparse
lag_range = np.arange(5,11)

month_names = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 
               'August', 'September', 'October', 'November', 'December']

#select the months where multi-model ensemble has highest r2 above persistence
good_CMIP6_months = [9, 8, 10, 10, 10, 9, 8]

CMIP6_r_vals  = []
CMIP6_weights = []
#dropped variables per region (keyed by str(region)); start empty so that no
#entries from earlier cells end up in the saved metadata
drop_dict = {}
for region_i, region_key in enumerate([1,2,3,4,5,6,11]):
    print(datetime.datetime.now(), region_key)

    month__ = good_CMIP6_months[region_i]

    #validation r averaged over lags 5-10 for every drop-one experiment
    mean_r = r_vals_CMIP6_drop1.sel(region=region_key).sel(month=month__).sel(
        lag=lag_range).mean('lag')

    #r2 with a variable dropped minus r2 with the random variable dropped:
    #positive means removing the variable costs less skill than removing noise
    r_diff = mean_r**2 - mean_r.sel(drop_var='RAND')**2

    drop_list = r_diff['drop_var'].where(r_diff>0, drop=True).values
    drop_dict[str(region_key)] = drop_list
    print(datetime.datetime.now(), drop_list)

    r_values_xr, all_weights_xr = train_linear_model_remove(
        model_name='CMIP6', month_=month__, region_list=[region_key], 
        lag_list=np.arange(1,21), start_end_yr=[1941,2014], 
        n_epoch=2000, learn_rate=5e-4, extra_drop_=drop_list,
        white_noise_=False,
    )  

    CMIP6_r_vals.append(r_values_xr)
    CMIP6_weights.append(all_weights_xr)

region_r_vals = xr.concat((CMIP6_r_vals), dim='region')
region_r_vals['region'] = [1,2,3,4,5,6,11]

region_weights = xr.concat((CMIP6_weights), dim='region')
region_weights['region'] = [1,2,3,4,5,6,11]

r_values_attrs = {
    'Description': 'Pearson correlation coefficient for validation '\
        +'data from all second ensemble members from 42 CMIP6 Global Climate '\
        +f'Models, train on first members. The full list of GMCs is as '\
        +f'follows: {CMIP6_GCM_list}, for regions [1,2,3,4,5,6,11] with '\
        +f'months corresponding to {good_GCM_months}. The modes of '\
        +f'variability yielding lower predictive skill than a random variable '\
        +f'are dropped as follows: {drop_dict}. The full list of 15 '\
        +'variables is: AMO, IPO, NINO34, PDO, AMM, ATN, IOD, NPI, NAM, '\
        +'NPO, PNA, NAO, SAM, TAS, as computed by the Climate Variability '\
        +'Diagnostics Package (CVDP) with 2 year lowpass filtered '\
        +'sea ice concentration. The linear model is trained at '\
        +f'different lag times of 1-20 years, the regions are defined '\
        +'by NSIDC MASIE-NH Version 1 (doi:10.7265/N5GT5K3K). '\
        +'Training uses an L1 loss function and Adam optimizer with 2000 '\
        +'epochs: nn.Sequential(nn.Linear(n_features,1,bias=False)), '\
        +'learning rate = 5e-4.',
    'Timestamp'  : str(datetime.datetime.utcnow().strftime(
        "%H:%M UTC %a %Y-%m-%d")),
    'Data source': f'CMIP6 global climate model historical simulations '\
        +'for historical sea ice concentration, with climate modes '\
        +'calculated by CVDP (doi:10.1002/2014EO490002)',
    'Analysis'   : 'https://github.com/chrisrwp/low-frequency-'\
        +'variability/blob/main/neural_network/Train_4_ML_Models.ipynb',
}

region_r_vals.attrs = r_values_attrs
region_r_vals.to_netcdf(
    '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
    +f'validation_r_values_linear_CMIP6_all_regions_best_months_'\
    +f'dropped_useless_vars.nc'
)


weights_attrs = r_values_attrs.copy()
weights_attrs['Description'] = 'Linear weights of the useful modes of '\
    +'variability as computed by the Climate Variability Diagnoistics'\
    +'Package (CVDP) with 2 year lowpass filtered sea ice concentration '\
    +f'for all regions [1,2,3,4,5,6,11] for months {good_GCM_months}. '\
    +f'The linear model is trained/validated on the first/second of all '\
    +f'members from the global climate models as follows {CMIP6_GCM_list}. '\
    +'Climate modes which do not yield a higher predictive skill than a '\
    +f'random variable are removed as follows {drop_dict}. The full list '\
    +'of variables is as follows: AMO, IPO, NINO34, PDO, AMM, ATN, IOD, '\
    +'NPI, NAM, NPO, PNA, NAO, SAM, TAS. The linear model is trained at '\
    +f'different lag times of 1-20 years and for region {region_key} as'\
    +'defined by NSIDC MASIE-NH Version 1 (doi:10.7265/N5GT5K3K). '\
    +'Training uses an L1 loss function and Adam optimizer with 2000 '\
    +'epochs: nn.Sequential(nn.Linear(n_features,1,bias=False)), '\
    +'learning rate = 5e-4.'

region_weights.attrs = weights_attrs
region_weights.to_netcdf(
    '/glade/work/cwpowell/low-frequency-variability/PyTorch_models/'\
    +f'weights_linear_CMIP6_all_regions_best_months_'\
    +f'dropped_useless_vars.nc'
)