In [1]:
import time
import os
import gc
import glob

import numpy as np
import pandas as pd
import xarray as xr

import dask
import dask.array as da
from dask import delayed, compute

import cftime
import pop_tools

In [2]:
start_time = time.time()

### INITIALISATION ###

# Load data and group csv file by member
filename = f"change_point_indices_1.0_40_20.csv"
path = os.path.join(os.environ['HOME'], 'phase1_CONDA/publishable_code')
file = os.path.join(path, filename)
df = pd.read_csv(file)
#####              #####
##### INTERVENTION #####
#####              #####
df_first_three_rows = df.iloc[:3]
df_first_three_rows = df_first_three_rows.groupby('Member')
#####              #####
##### INTERVENTION #####
#####     END      #####
#####              #####
#grouped = df.groupby('Member')

# Extract the variables from the filename
filename = os.path.basename(file)
parts = filename.replace('change_point_indices_', '').replace('.csv', '').split('_')
threshold_multiple = float(parts[0])
P1_len = int(parts[1])
P2_len = int(parts[2])

# set up mask
grid_name = 'POP_gx1v7'
region_defs = {
    'subzero_Atlantic':[
        {'match': {'REGION_MASK': [6]}, 'bounds': {'TLAT': [10.0, 70.0], 'TLONG': [260.0, 360.0]}}
    ],
    'superzero_Atlantic':[
        {'match': {'REGION_MASK': [6]}, 'bounds': {'TLAT': [10.0, 70.0], 'TLONG': [0, 20.0]}}
    ],
    'Mediterranean': [
        {'match': {'REGION_MASK': [7]}}
    ],
    'LabradorSea': [
        {'match': {'REGION_MASK': [8]}, 'bounds': {'TLAT': [10.0, 70.0]}}
    ],
        'NordicSea': [
        {'match': {'REGION_MASK': [9]}, 'bounds': {'TLAT': [10.0, 70.0]}}
    ]
}
NA_mask = pop_tools.region_mask_3d(grid_name, region_defs=region_defs, mask_name='North Atlantic Mask')
NA_mask = NA_mask.sum('region')
NA_mask = NA_mask.roll(nlon=-100)

# set up paths
base_path = '/Data/gfi/share/ModData/CESM2_LENS2/ocean/monthly/'
temporary_path = '/Data/skd/scratch/innag3580/comp/temporary/'
final_path = '/Data/skd/scratch/innag3580/comp/composites/'

# easy variables
variables = ['TEMP', 
             'SALT', 'VVEL', 'SHF', 'HMXL', 'TAUX', 'TAUY']#, 'SIGMA_2']
base_name = ['temp', 
             'salt', 'vvel', 'shf', 'hmxl', 'taux', 'tauy']#, 'dens']
decrease_save_name = [f"decrease_{name}_{threshold_multiple}_{P1_len}_{P2_len}.nc" for name in base_name]
increase_save_name = [f"increase_{name}_{threshold_multiple}_{P1_len}_{P2_len}.nc" for name in base_name]

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [6]:
### COMPUTATION ###

#def standardise_time(ds):
#####              #####
##### INTERVENTION #####
#####              #####
def standardise_time(ds, time):
    print('standardise_time start')
    print(time.time())
    print('')
#####              #####
##### INTERVENTION #####
#####     END      #####
#####              #####
    ds['time'] = xr.decode_cf(ds, use_cftime=True).time
    if isinstance(ds.time.values[0], cftime._cftime.DatetimeNoLeap):
        time_as_datetime64 = np.array([pd.Timestamp(str(dt)).to_datetime64() for dt in ds.time.values])
        ds['time'] = xr.DataArray(time_as_datetime64, dims='time')
    return ds

#def DJFM_average(ds):
#####              #####
##### INTERVENTION #####
#####              #####
def DJFM_average(ds, time):
    print('DJFM_average start')
    print(time.time())
    print('')
#####              #####
##### INTERVENTION #####
#####     END      #####
#####              #####
    ds_first_FM  = ds.isel(time=slice(0,2)).coarsen(time=2, boundary='trim').mean()
    ds_DJFM = ds.isel(time=slice(2, None)).coarsen(time=4, boundary='trim').mean()
    ds_combined = xr.concat([ds_first_FM, ds_DJFM], dim='time')
    return ds_combined

#def prepare_ds_member(var, member_id):
#####              #####
##### INTERVENTION #####
#####              #####
def prepare_ds_member(var, member_id, time):
    print('prepare_ds_member start')
    print(time.time())
    print('')
#####              #####
##### INTERVENTION #####
#####     END      #####
#####              #####
    file_pattern = os.path.join(base_path, var, f'*BHIST*LE2-{member_id}*.nc')
    file_paths = sorted(glob.glob(file_pattern))
    datasets = []
    for file in file_paths:
        ds_member = xr.open_dataset(file, chunks={'time': 12})
        #ds_member = standardise_time(ds_member)
        #####              #####
        ##### INTERVENTION #####
        #####              #####
        ds_member = standardise_time(ds_member, time)
        #####              #####
        ##### INTERVENTION #####
        #####     END      #####
        #####              #####
        
        ds_member = ds_member.sel(time=ds_member['time.month'].isin([12, 1, 2, 3]))
        datasets.append(ds_member)
    #with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    #    ds_member = xr.merge(datasets)
    #    #ds_member = DJFM_average(ds_member)
    #    #####              #####
    #    ##### INTERVENTION #####
    #    #####              #####
    #    ds_member = DJFM_average(ds_member, time)
    #    #####              #####
    #    ##### INTERVENTION #####
    #    #####     END      #####
    #    #####              #####
    
    
    #####              #####
    ##### INTERVENTION #####
    #####              #####
    ds_member = xr.merge(datasets)
    #ds_member = DJFM_average(ds_member)
    #####              #####
    ##### INTERVENTION #####
    #####              #####
    ds_member = DJFM_average(ds_member, time)
    #####              #####
    ##### INTERVENTION #####
    #####     END      #####
    #####              #####

    #####              #####
    ##### INTERVENTION #####
    #####     END      #####
    #####              #####
    ds_member = ds_member.roll(nlon=-100).where(NA_mask == 1)
    return ds_member   

In [41]:
var = 'TEMP'
member_id = '1001.001'

In [64]:
def standardise_time(ds):
    ds['time'] = xr.decode_cf(ds, use_cftime=True).time
    if isinstance(ds.time.values[0], cftime._cftime.DatetimeNoLeap):
        time_as_datetime64 = np.array([pd.Timestamp(str(dt)).to_datetime64() for dt in ds.time.values])
        ds['time'] = xr.DataArray(time_as_datetime64, dims='time')
    return ds

def DJFM_average(ds):
    numeric_vars = {k: v for k, v in ds.data_vars.items() if np.issubdtype(v.dtype, np.number)}
    ds_numeric = xr.Dataset(numeric_vars, coords=ds.coords)
    
    ds_first_FM  = ds_numeric.isel(time=slice(0,2)).coarsen(time=2, boundary='trim').mean()
    ds_DJFM = ds_numeric.isel(time=slice(2, None)).coarsen(time=4, boundary='trim').mean()
    ds_combined = xr.concat([ds_first_FM, ds_DJFM], dim='time')
    return ds_combined

def prepare_ds_member(var, member_id):
    print('prepare_ds_member start')
    print('')
    file_pattern = os.path.join(base_path, var, f'*BHIST*LE2-{member_id}*.nc')
    file_paths = sorted(glob.glob(file_pattern))
    
    ds_member = xr.open_mfdataset(file_paths, chunks={'time': 120}, preprocess=standardise_time)

    ds_member = ds_member.sel(time=ds_member['time.month'].isin([12, 1, 2, 3]))
    ds_member = DJFM_average(ds_member)
    ds_member = ds_member.roll(nlon=-100).where(NA_mask == 1)
    
    return ds_member 

In [65]:
ds_member = prepare_ds_member(var, member_id)

prepare_ds_member start



    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array[indexer]

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    ...     array[indexer]
  return self.array[key]
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': False}):
    ...     array.reshape(shape)

To avoid creating the large chunks, set the option
    >>> with dask.config.set(**{'array.slicing.split_large_chunks': True}):
    >>> array.reshape(shape, limit='128 MiB')
  ds_DJFM = ds_numeric.isel(time=slice(2, None)).coarsen(time=4, boundary='trim').mean()
