In [1]:
from matplotlib import pyplot as plt

import xarray as xr
import netCDF4 as nc
import numpy as np

import os

import datetime as dt
import pickle
import h5py

In [2]:
# Run configuration: which models to process and how anomalies should be built.
# Selecting low DIV for train and high DIV for test
# NOTE(review): "DIV" is not defined in the visible cells — spell the acronym out.
models = ['MIROC6', 'CESM2', 'CanESM5', 'MIROC-ES2L', 'MPI-ESM1-2-LR']
# (start, end) dates; load_data slices the yearly means over this window to get the baseline.
ref_period = ('1850-01-01', '1900-01-01')
# Window size load_data would use when coarsening along both lat and lon.
coarsen_factor = 12
time_scale = 'month' # load_data only handles 'year' or 'month'
# NOTE(review): the visible call to load_data passes only the model list and `var`,
# so ref_period, coarsen_factor and time_scale above are not applied there — confirm.

In [3]:
def _compute_anomalies(ds, time_scale, ref_period, cp_anomalies):
    """Turn one member's dataset into anomalies at the requested time scale.

    - 'month': subtract the per-calendar-month mean over the whole record, or
      return `ds` unchanged when `cp_anomalies` is False.
    - 'year': average to yearly means, then subtract the mean of the yearly
      values that fall inside `ref_period`. `cp_anomalies` is not consulted
      in this branch.
    """
    if time_scale == 'month':
        if not cp_anomalies:
            return ds
        climatology = ds.groupby('time.month').mean(dim='time')
        return ds.groupby('time.month') - climatology

    # time_scale == 'year' (validated by the caller)
    # NOTE(review): recent pandas releases appear to prefer the 'YE' alias over 'Y' — confirm
    # against the installed version before changing it.
    ds_yearly = ds.resample(time='1Y').mean()
    mean_ref_period = ds_yearly.sel(time=slice(ref_period[0], ref_period[1])).mean(dim='time')
    return ds_yearly - mean_ref_period


def load_data(models, var='pr', time_scale='month', ref_period=None, path='../../AnchorMultivariateAnalysis/data/ForceSMIP/Training/Amon/psl/ForceSMIP/', coarsen_factor=None, cp_anomalies=True, max_members=25):
    """Load anomaly fields for several models into one dictionary.

    Every file found in `<path>/<model>` is treated as one ensemble member.
    Files are read in sorted name order so that the subset kept by
    `max_members` is the same on every run (`os.listdir` order is arbitrary).

    Parameters
    ----------
    models : iterable of str
        Sub-directory names under `path`, one per model.
    var : str
        Name of the variable to extract from each dataset.
        NOTE(review): the default is 'pr' while the default `path` points at a
        'psl' directory — callers most likely need to pass `var` explicitly.
    time_scale : {'month', 'year'}
        Time resolution of the anomalies.
    ref_period : (start, end) or None
        Bounds of the baseline window; required when `time_scale == 'year'`.
    path : str
        Directory that contains one sub-directory per model.
    coarsen_factor : int or None
        If given, block-average over windows of this size along both `lat`
        and `lon`, trimming any incomplete window at the boundary.
    cp_anomalies : bool
        Only used for monthly data: when False the raw fields are returned.
    max_members : int or None
        Maximum number of files read per model; None reads all of them.

    Returns
    -------
    dict
        `ensemble[model][var]` is an array stacking the members along axis 0.
        `ensemble['time']`, `ensemble['lat']` and `ensemble['lon']` hold the
        coordinates of the first member that was read.

    Raises
    ------
    ValueError
        If `time_scale` is unsupported, or is 'year' without a `ref_period`.
    """
    # Fail early with a clear message instead of a NameError/TypeError deep in the loop.
    if time_scale not in ('month', 'year'):
        raise ValueError("time_scale must be 'month' or 'year', got {!r}".format(time_scale))
    if time_scale == 'year' and ref_period is None:
        raise ValueError("ref_period is required when time_scale == 'year'")

    ensemble = {}
    coords_recorded = False
    for model in models:
        print('## Model {}'.format(model))
        directory = os.path.join(path, model)

        # Deterministic member selection: sort first, then cap.
        files = sorted(os.listdir(directory))
        if max_members is not None:
            files = files[:max_members]

        members = []
        for i, file in enumerate(files, start=1):
            print('File {}/{}'.format(i, len(files)), end='\r')
            file_path = os.path.join(directory, file)
            # The context manager closes the file even if processing raises.
            with xr.open_dataset(file_path) as ds:
                anomalies = _compute_anomalies(ds, time_scale, ref_period, cp_anomalies)

                if coarsen_factor is not None:
                    anomalies = anomalies.coarsen(lat=coarsen_factor, lon=coarsen_factor, boundary='trim').mean()

                # Coordinates are taken once, from the first member read.
                if not coords_recorded:
                    if time_scale == 'year':
                        ensemble['time'] = np.unique(anomalies['time'])
                    else:
                        ensemble['time'] = anomalies['time']
                    ensemble['lat'] = anomalies['lat'].values
                    ensemble['lon'] = anomalies['lon'].values
                    coords_recorded = True

                # `.values` materialises the data while the file is still open.
                members.append(anomalies[var].values)

        ensemble[model] = {var: np.array(members)}
        print()
    return ensemble

def save_data(data, model='CanESM5', var='tas', data_path='../data/', name_adder='', pkl=False):
    """Write one model's ensemble array to disk.

    Parameters
    ----------
    data : dict
        Must provide `data[model][var]` (an array whose first axis indexes the
        members) and, for NetCDF output, `data['time']`, `data['lat']` and
        `data['lon']`.
    model, var : str
        Keys into `data`; they also form the file name `<model>_<var>`.
    data_path : str
        Output directory.
    name_adder : str
        Suffix appended to the file stem, before the extension.
    pkl : bool
        Despite the name, True writes an HDF5 file (`.h5`) with a single
        dataset called 'data'; False (default) writes a NetCDF file (`.nc`).

    Notes
    -----
    - The NetCDF time variable stores integer positions 0..len(time)-1, not
      calendar dates, and the `units` attributes are placeholder strings.
    - The HDF5 file name now follows `var`; it was previously always
      `<model>_tas`, whatever variable was being saved.
    """
    values = data[model][var]
    n_members = values.shape[0]
    stem = os.path.join(data_path, '{}_{}'.format(model, var) + name_adder)

    if pkl:
        # Context manager guarantees the file is closed even if the write fails.
        with h5py.File(stem + '.h5', 'w') as hf:
            hf.create_dataset('data', data=values)
        return

    n_times = len(data['time'])
    with nc.Dataset(stem + '.nc', 'w') as f:
        # Define dimensions
        f.createDimension('n_members', n_members)
        f.createDimension('time', n_times)
        f.createDimension('lat', len(data['lat']))
        f.createDimension('lon', len(data['lon']))

        # Create variables
        members_var = f.createVariable('n_members', 'i4', ('n_members',))
        time_var = f.createVariable('time', 'f8', ('time',))
        lat_var = f.createVariable('lat', 'f4', ('lat',))
        lon_var = f.createVariable('lon', 'f4', ('lon',))
        data_var = f.createVariable(var, 'f4', ('n_members', 'time', 'lat', 'lon'))

        # Assign data to variables
        members_var[:] = np.arange(n_members)
        # Index-based time axis sized from the data rather than a fixed 1716 steps.
        time_var[:] = np.arange(n_times)
        lat_var[:] = data['lat']
        lon_var[:] = data['lon']
        data_var[:] = values

        # Placeholder unit attributes
        members_var.units = 'member index'
        time_var.units = 'time units'
        lat_var.units = 'latitude units'
        lon_var.units = 'longitude units'
        data_var.units = '{} units'.format(var)

In [4]:
# Process one model at a time: load its 'psl' ensemble with load_data, then write it
# to disk with save_data (NetCDF output, since pkl is left at its default of False).
# NOTE(review): only `var` is passed, so load_data runs with its own defaults
# (time_scale='month', coarsen_factor=None, ref_period=None); the ref_period,
# coarsen_factor and time_scale values set in the configuration cell are not used
# here — confirm this is intended.
for model in models :
    data = load_data([model], var='psl')
    save_data(data, model=model, var='psl')

## Model MIROC6
File 25/50
## Model CESM2
File 25/50
## Model CanESM5
File 25/25
## Model MIROC-ES2L
File 25/30
## Model MPI-ESM1-2-LR
File 24/24
