# Install

In [None]:
!pip install pycwt

Collecting pycwt
  Downloading pycwt-0.4.0b0-py3-none-any.whl (753 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m753.5/753.5 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting numpy<2,>=1.24 (from pycwt)
  Downloading numpy-1.26.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.2/18.2 MB[0m [31m43.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: numpy, pycwt
  Attempting uninstall: numpy
    Found existing installation: numpy 1.23.5
    Uninstalling numpy-1.23.5:
      Successfully uninstalled numpy-1.23.5
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
lida 0.0.10 requires fastapi, which is not installed.
lida 0.0.10 requires kaleido, which is not installed.
lida 0.0.10 requires python-multipart, which is not installed.

# Imports

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import glob
import os
import numpy as np
import string
ALPHA = string.ascii_letters
import math
import matplotlib.colors
import time
import matplotlib.dates as mdates
import re
import matplotlib.animation as animation

from IPython.display import HTML

from scipy import stats

from scipy.signal import find_peaks, peak_widths, peak_prominences
import seaborn as sns

from typing import List, Tuple
import gc

from matplotlib.image import NonUniformImage
from matplotlib.ticker import ScalarFormatter, FuncFormatter

from tqdm import tqdm
import pycwt
import pycwt as wavelet
from pycwt.helpers import (ar1,get_cache_dir, rednoise)
from pycwt.mothers import DOG, MexicanHat, Morlet, Paul
from multiprocessing import Pool
from concurrent.futures import ThreadPoolExecutor
import matplotlib.patches as mpatches


# Functions

## Calculating Functions

### General

In [None]:
def create_SIF_dataframe(text_files: List[str], start_date: str = '2023-05-14', end_date: str = '2023', timezone: str = 'US/Central') -> pd.DataFrame:
    """
    Creates a DataFrame for the Solar-Induced Fluorescence (SIF) data as well as other indices computed with the SIF software GUI.
    Also adds derived columns: scaled PRI ('PRI_s'), vapour-pressure deficit ('VPD') and clearness index ('CI').

    Parameters:
        text_files (List[str]): Candidate file paths. The first path containing 'ALL_INDEX' is read as a
            ';'-separated table in which '#N/D' marks missing values.
        start_date (str): Start of the kept date range (inclusive, pandas partial-string slicing). Defaults to '2023-05-14'.
        end_date (str): End of the kept date range (inclusive; '2023' keeps everything through the end of 2023). Defaults to '2023'.
        timezone (str): Timezone the UTC timestamps are converted to. Defaults to 'US/Central'.

    Returns:
        pd.DataFrame: The SIF table indexed by local 'datetime', sorted, with 'PRI_s', 'VPD' and 'CI' added.

    Raises:
        FileNotFoundError: If no path in ``text_files`` contains 'ALL_INDEX'.
    """
    index_files = [file for file in text_files if 'ALL_INDEX' in file]
    if not index_files:
        raise FileNotFoundError("No file containing 'ALL_INDEX' in its path was found in text_files.")
    SIF_data_df = pd.read_csv(index_files[0], sep=';', na_values=[r'#N/D'])

    # Parse the UTC timestamps; repeated timestamps are shifted by 0, 1, 2, ... minutes
    # so every row gets a unique index value.
    SIF_data_df['datetime [UTC]'] = pd.to_datetime(SIF_data_df['datetime [UTC]'], utc=True)
    SIF_data_df['datetime [UTC]'] += pd.to_timedelta(SIF_data_df.groupby('datetime [UTC]').cumcount(), unit='m')

    # Convert UTC to the local timezone and use it as the index
    SIF_data_df.index = SIF_data_df['datetime [UTC]'].dt.tz_convert(timezone)
    SIF_data_df.index.name = 'datetime'

    # Keep only the requested date range; copy so the column assignments below act on
    # an independent frame rather than a slice of the sorted one.
    SIF_data_df = SIF_data_df.sort_index().loc[start_date: end_date].copy()

    # Rescale PRI from [-1, 1] to [0, 1].
    SIF_data_df['PRI_s'] = (SIF_data_df['PRI'] + 1) / 2

    # VPD = (1 - RH) * es, where es = 0.611 * exp(17.27 T / (T + 237.3)) with T in degrees C.
    # NOTE(review): 'h2 [%]' is used as a 0-1 fraction here; if the column really holds
    # percent values this should be (1 - h2 / 100). Confirm the instrument's units.
    temperature = SIF_data_df['temp4 [C]']
    SIF_data_df['VPD'] = (1 - SIF_data_df['h2 [%]']) * .611 * np.exp((17.27 * temperature) / (237.3 + temperature))

    # Clearness index: measured incoming PAR divided by a top-of-atmosphere estimate
    # 1367 * (1 + 0.033 cos(2 pi doy / 365)) * cos(SZA), scaled by 0.389
    # (presumably the PAR share of broadband irradiance — confirm).
    Rg_vis = SIF_data_df['PAR inc [W m-2]']
    R0_vis = 1367 * (1 + .033 * np.cos(2 * np.pi * SIF_data_df['doy.dayfract'] / 365)) * np.cos(SIF_data_df['SZA'] * (np.pi / 180)) * .389
    SIF_data_df['CI'] = Rg_vis / R0_vis

    return SIF_data_df

In [None]:
def reshape_timeseries_to_day_col(dataframe: pd.DataFrame, column_names: List[str]) -> pd.DataFrame:
    """
    Pivot selected columns of a time-indexed DataFrame into a time-of-day x date table.

    Rows are times of day, and columns are (column name, date) pairs, which makes it easy
    to compare the same clock time across days. Missing values are then
    linearly interpolated down each column (i.e. across times within a day).

    Parameters:
        dataframe (pd.DataFrame): Input data with a DatetimeIndex.
        column_names (List[str]): Columns to include in the pivot.

    Returns:
        pd.DataFrame: Interpolated wide table indexed by time of day.
    """
    wide = dataframe[column_names].copy()
    wide.index = [wide.index.time, wide.index.date]
    return wide.unstack().interpolate()

In [None]:
def calculate_multiscale_rolling_correlations(dataframe: pd.DataFrame, column1: str, column2: str, start: int, end: int, steps: int) -> pd.DataFrame:
    """
    Compute centred rolling correlations between two columns over a range of window lengths.

    Window lengths (in seconds) are ``steps`` geometrically spaced values from ``start``
    to ``end``, truncated to integers. This gives a multi-scale view similar in spirit
    to a wavelet analysis.

    Parameters:
        dataframe (pd.DataFrame): Input data; time-based windows require a DatetimeIndex.
        column1, column2 (str): The two columns to correlate.
        start, end (int): Smallest and largest window length in seconds.
        steps (int): Number of window lengths.

    Returns:
        pd.DataFrame: One column per window length (the integer seconds), same index as ``dataframe``.
    """
    window_seconds = np.geomspace(start, end, steps).astype(int)
    pair = dataframe[[column1, column2]]
    first, second = pair[column1], pair[column2]

    per_window = [
        first.rolling(window=f"{seconds}s", center=True).corr(second)
        for seconds in window_seconds
    ]

    combined = pd.concat(per_window, axis=1)
    combined.columns = window_seconds
    return combined

In [None]:
def filter_text_files(text_files: List[str], *conditions: str) -> List[str]:
    """
    Keep the paths that contain every non-empty condition as a substring.

    Empty (falsy) conditions are ignored, so with no usable conditions every path is kept.
    Input order is preserved.

    :param text_files: Paths to filter
    :param conditions: Substrings that must all appear in a kept path
    :return: The matching paths
    """
    required = [needle for needle in conditions if needle]
    return [path for path in text_files if all(needle in path for needle in required)]


In [None]:
def _to_numeric_or_unchanged(column: pd.Series) -> pd.Series:
    """Convert a column to numbers (downcast to the smallest integer type), or return it unchanged if it cannot be parsed."""
    # Equivalent to pd.to_numeric(..., errors='ignore'), which is deprecated since pandas 2.2.
    try:
        return pd.to_numeric(column, downcast='integer')
    except (ValueError, TypeError):
        return column


def process_raw_spectrometer_file(file_path: str) -> pd.DataFrame:
    """
    Processes a raw spectrometer file and returns a DataFrame.

    The file is expected to hold ';'-separated lines. Lines starting with a digit are cycle
    headers whose fields 1 and 2 are the date (YYMMDD) and time (HHMMSS). Lines starting
    with a letter are data arrays whose first field is the array name.

    Parameters:
        file_path (str): The path to the raw spectrometer file.

    Returns:
        pd.DataFrame: One row per (cycle timestamp, array name), indexed by the MultiIndex
        ('date-time', 'array'), with one column per array position. The last position is
        dropped. Columns are numeric where parseable.
    """
    # read_csv's default ',' separator leaves each ';'-separated line in column 0
    # (assumes the raw file contains no commas).
    raw_data = pd.read_csv(file_path, header=None, low_memory=True)

    # Extract header data (rows beginning with a digit)
    header_data = raw_data.loc[raw_data[0].str[0].str.isdigit(), :][0].str.split(';', expand=True)

    # Parse timestamp from header data and move it onto the row after each header,
    # which is the first array row of that cycle.
    date_time = pd.to_datetime(header_data[1].astype(str) + header_data[2].astype(str), format='%y%m%d%H%M%S').rename('date-time')
    date_time.index += 1

    # Split data arrays (rows beginning with a letter)
    array_data = raw_data[0].str.split(';', expand=True)
    data_arrays = array_data[array_data[0].str.startswith(tuple(string.ascii_letters))]

    # Combine timestamp and data arrays, then forward-fill each cycle's timestamp onto all
    # of its array rows. Assign the result back explicitly: in-place fillna on an .iloc
    # column selection is chained assignment and silently does nothing under Copy-on-Write.
    fluo_array_df = pd.concat([date_time, data_arrays], axis=1).sort_index()
    fluo_array_df['date-time'] = fluo_array_df['date-time'].ffill()

    # Index by (timestamp, array name) and transpose so array positions become rows;
    # drop the last position (NOTE(review): presumably the empty field after a trailing ';').
    fluo_array_df = fluo_array_df.set_index(['date-time', 0]).T
    fluo_array_df = fluo_array_df.drop(fluo_array_df.tail(1).index)

    # Name the array level, transpose back and convert values to numbers where possible.
    fluo_array_df.columns.rename({0: 'array'}, inplace=True)
    fluo_array_df = fluo_array_df.T
    fluo_array_df = fluo_array_df.apply(_to_numeric_or_unchanged)

    return fluo_array_df


def process_raw_spectrometer_files(file_paths: List[str], output_directory: str):
    """
    Run ``process_raw_spectrometer_file`` on each path and save every result as CSV.

    The result for the path at position ``k`` of ``file_paths`` is written to
    ``{output_directory}/processed_file_{k}.csv``.

    Parameters:
        file_paths (List[str]): Raw spectrometer files to process.
        output_directory (str): Directory the CSV files are written to.

    Returns:
        None
    """
    for position, source_path in enumerate(file_paths):
        destination = f"{output_directory}/processed_file_{position}.csv"
        process_raw_spectrometer_file(source_path).to_csv(destination)


In [None]:
def combine_processed_raw_spectrometer_files(file_paths):
    """
    Stack several processed spectrometer CSV files into one DataFrame.

    Each file is read with its first two columns as a MultiIndex and its first row as
    the header; the frames are concatenated row-wise in the given order.

    Parameters:
    file_paths (list): Paths of the processed CSV files.

    Returns:
    pd.DataFrame: All rows from all files.
    """
    frames = [pd.read_csv(path, index_col=[0, 1], header=[0]) for path in file_paths]
    return pd.concat(frames)


In [None]:
def red_noise_fill_gaps(series):
    """
    Resample a series to 1-minute means and fill the resulting gaps.

    If ``series.name == 'doy.dayfract'``, no filling is done. Instead, a fresh fractional
    day of year (``dayofyear + hour/24 + minute/1440 + second/86400``) is computed from
    the resampled timestamps and returned as an Index, not a Series.

    Otherwise:
      * the first value returned by pycwt's ``ar1`` (presumably the lag-1
        autocorrelation) is estimated from the non-missing original samples;
      * the 1-minute means are linearly interpolated with ``limit=2``, so at most two
        consecutive missing minutes are filled (for longer gaps pandas fills the first
        two and leaves the rest);
      * every value still missing is replaced by the value at the same position of a
        red-noise series of the same length generated with ``rednoise``.

    Parameters:
    series (pd.Series): Data with a DatetimeIndex. Its ``name`` selects the branch.

    Returns:
    pd.Index or pd.Series: See above.
    """
    if series.name == 'doy.dayfract':
        # Resample to 1-minute bins and rebuild the fractional day of year from the bin times.
        stamps = series.resample('1min').mean().index
        day_fraction = stamps.hour / 24 + stamps.minute / (24 * 60) + stamps.second / (24 * 60 * 60)
        return stamps.dayofyear + day_fraction

    # Estimate the AR(1) coefficient on the original (non-resampled) observations.
    al1 = ar1(series.dropna().values)[0]

    # '1min' rather than the '1T' alias, which is deprecated in recent pandas.
    series_interpolated = series.resample('1min').mean().interpolate(method='linear', limit=2)

    # NOTE(review): the noise is not rescaled to the series' mean/std, so long gaps are
    # filled with values on rednoise's own scale — confirm inputs are standardised.
    noise = rednoise(len(series_interpolated), al1, 1)

    return series_interpolated.where(series_interpolated.notna(), noise)

In [None]:
def df_from_processed_spectrometer_files(text_files, condition1, condition2):
    """
    Build one time-indexed DataFrame from processed spectrometer tables.

    Used to process radiance and reflectance data. Files are selected with
    ``filter_text_files(text_files, condition1, condition2)``. Each file is read as a
    ';'-separated table ('#N/D' marks missing values) and transposed, so every row label
    carries an HH_MM_SS time stamp, which is parsed as UTC. The file's parent directory
    name is used as the date (YYMMDD).

    :param text_files: List of text file paths to read data from
    :param condition1 (str): First substring a path must contain (ignored if empty)
    :param condition2 (str): Second substring a path must contain (ignored if empty)
    :return: DataFrame indexed by a US/Central-localised 'datetime', sorted ascending
    """
    # Filter the text files based on the given conditions
    filtered_text_files = filter_text_files(text_files, condition1, condition2)

    # Read and transpose each file; record the date from its parent directory name.
    df = pd.concat(
        (
            pd.read_csv(f, sep=';', header=[0], index_col=[0], na_values=[r'#N/D'])
            .T
            .assign(date=f.rsplit('/', 2)[-2])
            for f in filtered_text_files
        ),
        axis=0,
    )

    # Parse the HH_MM_SS stamp in each row label as a UTC time (it lands on 1900-01-01).
    df['time'] = pd.to_datetime(df.index, format='%H_%M_%S', utc=True, exact=False)
    df['date'] = pd.to_datetime(df['date'], format='%y%m%d', utc=False)

    # Convert the UTC clock time to US/Central on the measurement date, not on 1900-01-01:
    # in 1900 Chicago had no daylight saving, so converting the bare time always used CST
    # and summer (CDT) stamps came out one hour early.
    # NOTE(review): the resulting local clock time is attached to the folder date, so a
    # stamp after UTC midnight but before local midnight keeps the folder's date —
    # confirm whether folder names are UTC or local dates.
    time_of_day = df['time'] - df['time'].dt.normalize()
    utc_stamp = df['date'].dt.tz_localize('UTC') + time_of_day
    local_clock = utc_stamp.dt.tz_convert('US/Central').dt.strftime('%H:%M:%S')
    df['date_time'] = df['date'] + pd.to_timedelta(local_clock)

    # Shift repeated timestamps by 0, 1, 2, ... minutes so the index is unique.
    df['date_time'] += pd.to_timedelta(df.groupby('date_time').cumcount(), unit='m')

    # Set the index of the DataFrame to the 'date_time' column and rename it to 'datetime'
    df.index = df['date_time']
    df.index.name = 'datetime'

    # Drop helper columns and sort by time
    df = df.drop(['date', 'time', 'date_time'], axis=1).sort_index()

    # Attach the US/Central timezone to the local wall-clock index
    df.index = df.index.tz_localize('US/Central')

    return df


In [None]:
def check_for_skipped_proccessed_spectrometer_files(up_text_files, path):
    """
    Find raw FLUO spectrometer files that have no processed CSV counterpart.

    Raw files are the paths in ``up_text_files`` that are non-empty, contain neither
    'nightlog' nor '/F', and are larger than 100 KiB. Each one is keyed as
    "<parent dir>/<file name up to the first '.'>".

    Processed files are the CSVs found recursively under ``path`` whose path contains
    both 'Incoming' and 'FLUO'. Each one is keyed as "<parent dir>/<text after the last
    '_' up to the first '.'>".

    :param up_text_files: List of raw text file paths to check
    :param path: Directory searched recursively for processed CSV files
    :return: List of raw file paths with no matching processed file, or None if there are none
    """
    # Raw FLUO files over 100 KiB (falsy entries are skipped before getsize is called)
    files_over_100 = [
        f for f in up_text_files
        if f and 'nightlog' not in f and '/F' not in f and os.path.getsize(f) > (1024 * 100)
    ]
    raw_keys = [f.rsplit('/', 2)[-2] + "/" + f.rsplit('/', 1)[-1].split('.')[0] for f in files_over_100]

    # Keys of the processed files (a set for O(1) membership tests)
    csv_files = glob.glob(path + "/**/*.csv", recursive=True)
    processed_keys = {
        f.rsplit('/', 2)[-2] + "/" + f.rsplit('_', 1)[-1].split('.')[0]
        for f in csv_files
        if 'Incoming' in f and 'FLUO' in f
    }

    # Pair each raw path with its own key. Substring-matching keys against every path
    # also matched e.g. '230514/1' inside '.../230514/12.txt' and could emit duplicates.
    skipped_paths = [f for f, key in zip(files_over_100, raw_keys) if key not in processed_keys]
    print(len(skipped_paths), 'skipped files')
    return skipped_paths or None

In [None]:
# Column positions (after splitting a raw line on ';') that are dropped after renaming,
# per log type. NOTE(review): these appear to be the label fields between the value
# fields — confirm against a raw file.
drop_indexs_FLUO = list(range(4, 21, 2)) + list(range(23, 52, 2))
drop_indexs_FULL = list(range(4, 13, 2)) + list(range(15, 34, 2))
drop_indexs_NL = list(range(2, 13, 2))

# Column position -> column name for FLUO log lines.
rename_dict_FLUO = {
    5: 'IT_WR[us]=',
    7: 'IT_VEG[us]=',
    9: 'cycle_duration[ms]=',
    11: 'QEpro_Frame[C]=',
    13: 'QEpro_CCD[C]=',
    15: 'chamber_temp[C]=',
    17: 'chamber_humidity=',
    19: 'mainboard_temp[C]=',
    21: 'mainboard_humidity=',
    24: 'GPS_TIME_UTC=',
    26: 'GPS_date=',
    28: 'GPS_lat=',
    30: 'GPS_lon=',
    32: 'GPS_alt=',
    34: 'GPS_prec=',
    36: 'voltage=',
    38: 'gps_CPU=',
    40: 'wr_CPU=',
    42: 'veg_CPU=',
    44: 'cooling_active=',
    46: 'heating_active=',
    48: 'Temp0',
    50: 'Temp1',
    52: 'Temp2',
    0: 'Cycle Number',
    1: 'Date',
    2: 'Time',
    3: 'Mode',
}

# Column position -> column name for FULL log lines.
rename_dict_FULL = {
    5: 'IT_WR[us]=',
    7: 'IT_VEG[us]=',
    9: 'cycle_duration[ms]=',
    11: 'mainboard_temp[C]=',
    13: 'mainboard_humidity=',
    16: 'GPS_TIME_UTC=',
    18: 'GPS_date=',
    20: 'GPS_lat=',
    22: 'GPS_lon=',
    24: 'GPS_alt=',
    26: 'GPS_prec=',
    28: 'voltage=',
    30: 'gps_CPU=',
    32: 'wr_CPU=',
    34: 'veg_CPU=',
    0: 'Cycle Number',
    1: 'Date',
    2: 'Time',
    3: 'Mode',
}

# Column position -> column name for nightlog lines.
rename_dict_NL = {
    3: 'mainboard_temp[C]=',
    5: 'chamber_temp[C]=',
    7: 'mainboard_humidity=',
    9: 'chamber_humidity=',
    11: 'voltage=',
    13: 'On at=',
    0: 'Date',
    1: 'Time',
}

# Lookup tables keyed by the file_type accepted by process_raw_spectrometer_sensor_data.
rename_dicts_dict = dict(
    FLUO=rename_dict_FLUO,
    FULL=rename_dict_FULL,
    nightlog=rename_dict_NL,
)
drop_index_dict = dict(
    FLUO=drop_indexs_FLUO,
    FULL=drop_indexs_FULL,
    nightlog=drop_indexs_NL,
)

In [None]:
def process_raw_spectrometer_sensor_data(text_files, file_type):
    """
    Processes raw sensor data from spectrometer files.

    Each line is split on ';' into positional columns 0-53. The columns are then renamed
    and pruned using ``rename_dicts_dict[file_type]`` and ``drop_index_dict[file_type]``.

    :param text_files: List of text file paths containing raw sensor data
    :param file_type: Type of file to process. Can be 'FLUO', 'FULL', or 'nightlog'
    :return: DataFrame indexed by a US/Central-localised 'datetime', sorted, with a
        'Filename' column ("<parent dir>/<file name>") naming each row's source
    :raises ValueError: If ``file_type`` is not one of the supported values
    """
    # Select the files belonging to this log type from their paths
    if file_type == 'FLUO':
        files = [f for f in text_files if 'nightlog' not in f and '/F' not in f]
    elif file_type == 'FULL':
        files = [f for f in text_files if 'nightlog' not in f and '/F' in f]
    elif file_type == 'nightlog':
        files = [f for f in text_files if 'nightlog' in f]
    else:
        raise ValueError(f"file_type must be 'FLUO', 'FULL' or 'nightlog', got {file_type!r}")

    # Read and concatenate all files. ignore_index=True gives every row a unique label:
    # with per-file labels (0, 1, 2, ... repeated) the groupby(level=0) mask below merged
    # same-numbered rows of different files, so one row containing '=' removed its
    # counterparts from every other file.
    df = pd.concat(
        [
            pd.read_csv(f, header=None).assign(Filename=f.split('/')[-2] + '/' + f.split('/')[-1])
            for f in files
        ],
        ignore_index=True,
    )

    # Drop lines starting with 'W', 'D' or 'V', then split the rest on ';' into columns 0-53
    df = df[~df.iloc[:, 0].str.startswith(('W', 'D', 'V'))].copy()
    df.loc[:, list(range(54))] = df[0].str.split(';', expand=True)

    # Rename and drop columns based on file type
    df = df.rename(columns=rename_dicts_dict[file_type]).drop(columns=drop_index_dict[file_type])

    # Drop every row in which any field still contains '='
    pattern = r'[=]'
    mask = df.stack().str.contains(pattern).groupby(level=0).any()
    df = df[~mask].copy()

    # Parse the HHMMSS time field (it lands on 1900-01-01)
    df['Time'] = pd.to_datetime(
        df.Time,
        format='%H%M%S',
        utc=False,
        exact=False,
    )

    # Date and Time are treated as US/Central wall-clock values (localized below), so the
    # clock time is attached to the date unchanged.
    df['Date'] = pd.to_datetime(df['Date'], format='%y%m%d', utc=False)
    df['date_time'] = df['Date'] + pd.to_timedelta(df['Time'].dt.strftime('%H:%M:%S'))

    # Shift repeated timestamps by 0, 1, 2, ... minutes so the index is unique
    df['date_time'] += pd.to_timedelta(df.groupby('date_time').cumcount(), unit='m')

    # Set the index to the datetime column and localize to US/Central
    df.index = df['date_time']
    df.index.name = 'datetime'
    df.index = df.index.tz_localize('US/Central')

    # Drop helper columns and sort the DataFrame by index
    df = df.drop(['Date', 'Time', 'date_time'], axis=1).sort_index()

    return df

In [None]:
def convert_seconds_to_labels(seconds):
    """
    Turn durations in seconds into human-readable labels.

    Values below 60 are kept as-is with an " s" suffix. Larger values are divided by the
    largest of 86400 (days), 3600 (hours) or 60 (mins) that does not exceed them and
    rounded with ``np.round``. Values matching no rule (e.g. NaN) produce no label.

    Parameters:
    seconds (list): Durations in seconds.

    Returns:
    list: The labels, in input order.
    """
    units = ((86400, "days"), (3600, "hours"), (60, "mins"))
    labels = []
    for value in seconds:
        if value < 60:
            labels.append(f"{value} s")
            continue
        for divisor, unit in units:
            if value >= divisor:
                labels.append(f"{np.round(value / divisor)} {unit}")
                break
    return labels


In [None]:
def get_peak_info(row):
    """
    Characterise the most prominent dip in a spectrum row.

    The values are negated so that dips become peaks. Among the peaks found by
    ``find_peaks``, the one with the largest prominence is selected.

    Parameters:
        row (pd.Series): Spectrum values indexed by wavelength.

    Returns:
        list: [FWHM, peak wavelength, spectral resolution]. FWHM is the ``peak_widths``
        width at half prominence converted to wavelength units, and the spectral resolution
        is the mean spacing of the index. The first two values are rounded to 3 decimals.
    """
    wavelengths = row.index.values
    inverted = -row.values.flatten()

    peak_positions, _ = find_peaks(inverted)
    strongest = peak_prominences(inverted, peak_positions)[0].argmax()
    sample_widths = peak_widths(inverted, peak_positions)[0]

    resolution = np.diff(wavelengths).mean()
    fwhm = np.round(sample_widths[strongest] * resolution, 3)
    centre = np.round(wavelengths[peak_positions[strongest]], 3)
    return [fwhm, centre, resolution]

In [None]:
def reclassify_values_to_colors(values, colors):
    """
    Map values (e.g. rolling correlations) to colours by range.

    Bins: < -0.6 -> colors[0]; [-0.6, -0.2) -> colors[1]; [-0.2, 0.2) -> colors[2];
    [0.2, 0.6) -> colors[3]; >= 0.6 -> colors[4] when a fifth colour is supplied.
    Anything else (values >= 0.6 with only four colours, NaN) stays '#ffffff'.

    Parameters:
    values (np.array): The array of values to reclassify.
    colors (list): Four or five colour strings, ordered from the lowest to the highest bin.

    Returns:
    np.array: A string array of the same shape as ``values`` holding the colours.
    """
    # Size the string dtype to the longest colour: np.full(shape, '#ffffff') alone
    # creates '<U7', which silently truncated names such as 'lightblue'.
    width = max(len(str(color)) for color in (*colors, '#ffffff'))
    result = np.full(values.shape, '#ffffff', dtype=f'<U{width}')

    result[values < -0.6] = colors[0]
    result[(values >= -0.6) & (values < -0.2)] = colors[1]
    result[(values >= -0.2) & (values < 0.2)] = colors[2]
    result[(values >= 0.2) & (values < 0.6)] = colors[3]
    # The top bin was previously never coloured; only use it when a colour is given,
    # so four-colour callers keep their old output.
    if len(colors) > 4:
        result[values >= 0.6] = colors[4]

    return result


In [None]:
def fill_gaps_within_days(dataframe, timezone='America/Chicago'):
    """
    Fills gaps within days in a DataFrame.

    The data are pivoted to a time-of-day x (column, date) table and linearly
    interpolated down each day with ``limit_area='inside'``. Only NaNs with valid values
    on both sides on the same day are filled; leading/trailing NaNs of a day stay NaN.
    The table is then stacked back into a long, timezone-localised series of rows.

    Parameters:
    dataframe (pd.DataFrame): Data with a DatetimeIndex.
    timezone (str): The timezone used to localize the rebuilt index. Defaults to 'America/Chicago'.

    Returns:
    pd.DataFrame: A DataFrame with gaps filled within days, sorted by time. Every
    (time of day, date) combination present in the pivot is included.
    """
    wide = dataframe.copy()
    wide.index = [wide.index.time, wide.index.date]
    wide = wide.unstack().interpolate(limit_area='inside')

    # Stack back keeping all-NaN rows; future_stack (pandas >= 2.1) replaces the
    # deprecated dropna=False, which older pandas still needs.
    try:
        long_df = wide.stack(future_stack=True)
    except TypeError:
        long_df = wide.stack(dropna=False)

    # Rebuild the timestamp as date + time-of-day arithmetically instead of parsing
    # "HH:MM:SS YYYY-MM-DD" strings, which relied on format inference/dateutil fallback.
    times = long_df.index.get_level_values(0)
    dates = long_df.index.get_level_values(1)
    naive_index = pd.to_datetime(dates.astype(str)) + pd.to_timedelta(times.astype(str))
    long_df.index = naive_index.tz_localize(timezone)

    return long_df.sort_index()


In [None]:
def split_timeseries_into_days(dataframe):
    """
    Pivot a time-indexed DataFrame into a time-of-day x date table (no interpolation).

    Parameters:
    dataframe (pd.DataFrame): Data with a DatetimeIndex.

    Returns:
    pd.DataFrame: Rows are times of day; columns are (original column, date) pairs.
    """
    day_index = pd.MultiIndex.from_arrays([dataframe.index.time, dataframe.index.date])
    return dataframe.set_axis(day_index, axis=0).unstack()


In [None]:
def analyze_regular_data_chunks(dataframe, min_length=40):
    """
    Print statistics about runs of complete (NaN-free) rows in a DataFrame.

    After sorting by index, consecutive rows are grouped into runs that are either all
    complete or all incomplete. Complete runs with at least ``min_length`` rows are
    counted as chunks.

    Parameters:
    dataframe (pd.DataFrame): The DataFrame to analyze.
    min_length (int): The minimum length of a chunk to consider. Defaults to 40.

    Returns:
    None. Prints the average, smallest and largest chunk length and the chunk count.
    If no chunk qualifies, prints a message and a count of 0 instead.
    """
    dataframe = dataframe.sort_index()

    # A new run id starts whenever a row switches between complete and incomplete.
    has_gap = dataframe.isna().any(axis=1)
    run_ids = has_gap.ne(has_gap.shift()).cumsum()

    chunks = [
        chunk for _, chunk in dataframe.groupby(run_ids)
        if not chunk.isna().any().any() and len(chunk) >= min_length
    ]

    # Previously an empty result crashed with ZeroDivisionError / min() of an empty sequence.
    if not chunks:
        print(f"No complete chunks of at least {min_length} rows found.")
        print("Total number of chunks: 0")
        return

    lengths = [len(chunk) for chunk in chunks]
    print(f"Average length of the chunks: {sum(lengths) / len(lengths)}")
    print(f"Smallest chunk length: {min(lengths)}")
    print(f"Largest chunk length: {max(lengths)}")
    print(f"Total number of chunks: {len(chunks)}")


In [None]:
def get_clipped_values(series, lower_quantile=0.05, upper_quantile=0.95):
    """
    Locate the points that lie strictly outside the given quantile band.

    Parameters:
    series (pd.Series): The input data series.
    lower_quantile (float): Quantile for the lower threshold. Defaults to 0.05.
    upper_quantile (float): Quantile for the upper threshold. Defaults to 0.95.

    Returns:
    list: [(index, values) below the lower threshold, (index, values) above the upper threshold].
    """
    low = series.quantile(lower_quantile)
    high = series.quantile(upper_quantile)

    below = series < low
    above = series > high

    return [
        (series.index[below], series[below]),
        (series.index[above], series[above]),
    ]

### Wavelet Coherence

In [None]:
def wct(
    y1,
    y2,
    dt,
    dj=1 / 12,
    s0=-1,
    J=-1,
    sig=True,
    significance_level=0.95,
    wavelet_type="morlet",
    normalize=True,
    partial = True,
    **kwargs,
):
    """Wavelet coherence transform (WCT).

    The WCT finds regions in time frequency space where the two time
    series co-vary, but do not necessarily have high power.

    Adapted from the PyCWT package; additionally returns the complex
    smoothed coherency ``R``.

    Parameters
    ----------
    y1, y2 : numpy.ndarray, list
        Input signals. Must have the same number of samples.
    dt : float
        Sample spacing.
    dj : float, optional
        Spacing between discrete scales. Default value is 1/12.
        Smaller values will result in better scale resolution, but
        slower calculation and plot.
    s0 : float, optional
        Smallest scale of the wavelet. Default (-1) means
        ``2 * dt / wavelet.flambda()``.
    J : float, optional
        Number of scales less one. Scales range from s0 up to
        s0 * 2**(J * dj), which gives a total of (J + 1) scales.
        Default (-1) means ``round(log2(N * dt / s0) / dj)``.
    sig : bool, optional
        Set to compute significance. Default is True.
    significance_level : float, optional
        Significance level to use. Default is 0.95.
    wavelet_type : optional
        Mother wavelet, resolved with ``_check_parameter_wavelet``.
        Default is "morlet".
    normalize : bool, optional
        If True, each signal is centred and divided by its standard
        deviation before the CWT.
    partial : bool, optional
        Currently unused; accepted for backward compatibility.
    **kwargs
        Forwarded to ``wct_significance``.

    Returns
    -------
    WCT : numpy.ndarray
        Squared magnitude of the smoothed coherency, ``abs(R) ** 2``.
    aWCT : numpy.ndarray
        Phase angle (radians) of the unsmoothed cross-wavelet spectrum
        ``W1 * conj(W2)``.
    coi : numpy.ndarray
        Cone of influence, which is a vector of N points containing
        the maximum Fourier period of useful information at that
        particular time. Periods greater than those are subject to
        edge effects.
    freq : numpy.ndarray
        Vector of Fourier equivalent frequencies (in 1 / time units).
    sig : numpy.ndarray
        Significance levels as a function of scale if ``sig=True`` when
        called, otherwise ``array([0])``.
    R : numpy.ndarray
        Complex smoothed coherency ``S12 / sqrt(S1 * S2)``.

    Raises
    ------
    ValueError
        If ``y1`` and ``y2`` have different sizes.

    See also
    --------
    cwt, xwt

    """
    wavelet_type = _check_parameter_wavelet(wavelet_type)

    # Convert first: the J default below reads .size, which plain lists lack.
    y1 = np.asarray(y1)
    y2 = np.asarray(y2)
    if y1.size != y2.size:
        raise ValueError("The two time series must have the same size.")

    # Checking some input parameters
    if s0 == -1:
        # Smallest scale, from twice the sample spacing and the wavelet's flambda()
        s0 = 2 * dt / wavelet_type.flambda()
    if J == -1:
        # Number of scales
        J = int(np.round(np.log2(y1.size * dt / s0) / dj))

    # Calculates the standard deviation of both input signals.
    std1 = y1.std()
    std2 = y2.std()
    # Normalizes both signals, if appropriate.
    if normalize:
        y1_normal = (y1 - y1.mean()) / std1
        y2_normal = (y2 - y2.mean()) / std2
    else:
        y1_normal = y1
        y2_normal = y2

    # Calculates the CWT of the time-series making sure the same parameters
    # are used in both calculations.
    _kwargs = dict(dj=dj, s0=s0, J=J, wavelet=wavelet_type)
    W1, sj, freq, coi, _, _ = wavelet.cwt(y1_normal, dt, **_kwargs)
    W2, sj, freq, coi, _, _ = wavelet.cwt(y2_normal, dt, **_kwargs)

    scales1 = np.ones([1, y1.size]) * sj[:, None]
    scales2 = np.ones([1, y2.size]) * sj[:, None]

    # Smooth the scale-normalised wavelet power spectra.
    S1 = wavelet_type.smooth(np.abs(W1) ** 2 / scales1, dt, dj, sj)
    S2 = wavelet_type.smooth(np.abs(W2) ** 2 / scales2, dt, dj, sj)

    # Now the wavelet transform coherence
    W12 = W1 * W2.conj()
    scales = np.ones([1, y1.size]) * sj[:, None]
    S12 = wavelet_type.smooth(W12 / scales, dt, dj, sj)
    R = S12 / np.sqrt(S1 * S2)
    R2 = np.abs(R) ** 2
    WCT = R2
    # Phase is taken from the unsmoothed cross spectrum.
    aWCT = np.angle(W12)

    # Calculates the significance using Monte Carlo simulations with 95%
    # confidence as a function of scale.
    if sig:
        a1, b1, c1 = ar1(y1)
        a2, b2, c2 = ar1(y2)

        # NOTE(review): pycwt's own wct_significance names this keyword `wavelet`,
        # not `wavelet_type`; if the name below resolves to pycwt's function this call
        # raises TypeError — confirm which wct_significance is in scope.
        sig = wct_significance(
            a1,
            a2,
            dt=dt,
            dj=dj,
            s0=s0,
            J=J,
            significance_level=significance_level,
            wavelet_type=wavelet_type,
            **kwargs,
        )
    else:
        sig = np.asarray([0])

    return WCT, aWCT, coi, freq, sig, R

In [None]:
def _circular_mean_std_deg(cos_sum, sin_sum, count):
    """Return (circular mean, circular std) in degrees from summed cos/sin and a sample count.

    Works element-wise on pandas Series as well as on scalars.
    """
    resultant = np.sqrt(cos_sum ** 2 + sin_sum ** 2)
    # Floating-point rounding can push R/count marginally above 1, which would make
    # -2*log(.) negative and the square root NaN; the true value is at most 1.
    mean_resultant = np.minimum(resultant / count, 1.0)
    mean_deg = np.arctan2(sin_sum, cos_sum) * 180 / np.pi
    std_deg = np.sqrt(-2 * np.log(mean_resultant)) * 180 / np.pi
    return mean_deg, std_deg


def calc_WCT_time_lag(frequencies: np.ndarray, cone_of_influence: np.ndarray, significance: np.ndarray, wavelet_coherence_angle: np.ndarray) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Summarise significant wavelet-coherence phase angles as per-period time lags.

    A cell is kept only if it is significant (``significance >= 1``) and its
    period does not exceed the cone of influence at that time step. For every
    remaining period the circular mean and circular standard deviation of the
    phase angle are computed and converted into a time lag.

    Parameters:
    frequencies (np.ndarray): 1-D frequency array, one entry per scale. Periods are
        converted to hours by dividing 1/frequency by 3600, so 1/s is assumed.
    cone_of_influence (np.ndarray): 1-D cone of influence (period units), one entry per time step.
    significance (np.ndarray): (scale, time) array; cells with values >= 1 count as significant.
    wavelet_coherence_angle (np.ndarray): (scale, time) phase angles in radians.

    Returns:
    time_lag_df (pd.DataFrame): indexed by period (hours); columns 'mean_time_lag' and
        'std_time_lag' (hours), 'mean_angle' and 'std_angle' (degrees), 'num_data_points'.
    phase_df (pd.DataFrame): retained phase values (radians) shifted into [0, 2*pi);
        periods without any retained cell are dropped.

    Also prints the circular mean +- std of all retained angles (degrees).

    References:
    - https://sites.google.com/a/glaciology.net/grinsted/wavelet-coherence/faq
    - Application of the cross wavelet transform and wavelet coherence to geophysical time series
    """
    periods = 1 / frequencies
    # period / coi > 1 marks cells affected by edge effects; those are discarded below.
    edge_affected = np.outer(periods, 1. / cone_of_influence) > 1
    periods_hours = periods / (60 * 60)

    # Keep only significant phase values that lie outside the edge-affected region.
    phase = np.where(significance >= 1, wavelet_coherence_angle, np.nan)
    phase = np.where(~edge_affected, phase, np.nan)

    phase_df = pd.DataFrame(phase)
    phase_df.index = periods_hours
    phase_df.index.name = 'period (hours)'
    phase_df.dropna(how='all', axis=0, inplace=True)

    # Per-period circular statistics (NaN cells are skipped by the pandas sums).
    count = phase_df.count(axis=1)
    circular_mean, circular_std = _circular_mean_std_deg(
        np.cos(phase_df).sum(axis=1), np.sin(phase_df).sum(axis=1), count)

    # A phase of 360 degrees corresponds to a lag of one full period.
    time_lag_mean = (circular_mean / 360 * phase_df.index).rename('mean_time_lag')
    time_lag_std = (circular_std / 360 * phase_df.index).rename('std_time_lag')

    time_lag_df = pd.concat([time_lag_mean,
                             time_lag_std,
                             circular_mean.rename('mean_angle'),
                             circular_std.rename('std_angle'),
                             count.rename('num_data_points')], axis=1)

    # Shift angles into [0, 2*pi); this does not change any circular statistic.
    phase_df[phase_df < 0] += 2 * np.pi

    # Circular statistics over every retained cell.
    circular_mean_all, circular_std_all = _circular_mean_std_deg(
        np.cos(phase_df).sum().sum(), np.sin(phase_df).sum().sum(), count.sum())

    print(f'average phase angle: {circular_mean_all} +- {circular_std_all}')

    return time_lag_df, phase_df




In [None]:
def calc_wave_coherence(wave1, wave2, dt, min_period=60, max_period=3600, sig=False,normalize=True, scales_per_octave=12):
    """
    Calculate wavelet coherence between wave1 and wave2 using pycwt.
    modified source:
    https://seankmartin.github.io/Claustrum_Experiment/html/bvmpc/lfp_coherence.html#bvmpc.lfp_coherence.calc_coherence

    Parameters
    ----------
    wave1 : np.ndarray
        The values of the first waveform.
    wave2 : np.ndarray
        The values of the second waveform.
    dt : float
        Time in seconds between data values.
    min_period : float
        Minimum temporal scale to look at wave coherence, in seconds. Values
        below the Nyquist period (2 * dt) are raised to 2 * dt.
    max_period : float
        Maximum temporal scale to look at wave coherence, in seconds. -1 lets
        ``wct`` choose the number of scales itself.
    sig : bool, default False
        Whether the Monte Carlo significance of the coherence is calculated.
    normalize : bool, default True
        Whether both series are standardised (zero mean, unit std) before the CWT.
    scales_per_octave : int
        How many scale subdivisions between each octave (doubling of min_period).

    Returns
    -------
    WCT, freq, coi, sig, aWCT, R
        WCT - 2D (scale, time) array with squared coherence, clipped to [0, 1]
        freq - 1D array with the frequencies the wavelets were calculated at
        coi - 1D array with the cone of influence for each time step (period units)
        sig - WCT divided by the per-scale significance threshold (>= 1 is
              significant), or the length-1 placeholder returned by ``wct``
              when ``sig`` is False
        aWCT - 2D array with the same shape as WCT holding phase angles (radians)
        R - 2D complex array with the smoothed coherency (before squaring)

    Raises
    ------
    ValueError
        If max_period is smaller than the (Nyquist-adjusted) min_period.
    """
    wavelet_type = _check_parameter_wavelet("morlet")
    want_significance = bool(sig)

    # The smallest resolvable period is the Nyquist period.
    min_period = max(min_period, 2 * dt)
    min_scale = min_period / wavelet_type.flambda()
    if max_period == -1:
        J = -1
    else:
        if max_period < min_period:
            # Otherwise J would be negative and, if it happened to be -1, wct would
            # silently treat it as "choose the number of scales automatically".
            raise ValueError(
                f'max_period ({max_period}) must be >= min_period ({min_period})')
        max_scale = max_period / wavelet_type.flambda()
        J = math.ceil(np.log2(max_scale / min_scale) * scales_per_octave)
        print('number of scales:', J)
    dj = 1.0 / scales_per_octave

    print("Calculating coherence...")
    start_time = time.time()
    WCT, aWCT, coi, freq, sig, R = wct(
        wave1, wave2, dt,  # Fixed params
        dj=dj, s0=min_scale, J=J, sig=want_significance, normalize=normalize)
    if want_significance:
        # wct returns one threshold per scale; express WCT relative to it.
        sig = WCT / sig[:, None]
    print("Time Taken: %s s" % (time.time() - start_time))

    if np.max(WCT) > 1 or np.min(WCT) < 0:
        print('WCT was out of range: min {},max {}'.format(
            np.min(WCT), np.max(WCT)))
        WCT = np.clip(WCT, 0, 1)

    return WCT, freq, coi, sig, aWCT, R

In [None]:
def calc_partial_wave_coherence(x, y, z,dt, min_period=60,
                                max_period=3600, sig=False,normalize=True,
                                scales_per_octave=12,significance_level=.95):
    """
    Calculate the partial wavelet coherence between x and y with z removed.

    The complex smoothed coherencies Rxy, Rxz and Ryz are obtained from ``wct``
    and combined as
    ``(Rxy - Rxz * conj(Ryz)) / sqrt(1 - |Rxz|^2) / sqrt(1 - |Ryz|^2)``.
    modified source:
    https://seankmartin.github.io/Claustrum_Experiment/html/bvmpc/lfp_coherence.html#bvmpc.lfp_coherence.calc_coherence

    Parameters
    ----------
    x, y : np.ndarray
        The two series whose partial coherence is computed.
    z : np.ndarray
        The series whose influence is removed.
    dt : float
        Time in seconds between data values.
    min_period : float
        Minimum temporal scale, in seconds. Values below 2 * dt are raised to 2 * dt.
    max_period : float
        Maximum temporal scale, in seconds. -1 lets ``wct`` choose the number of scales.
    sig : bool, default False
        Whether a Monte Carlo significance threshold is calculated.
    normalize : bool, default True
        Whether the series are standardised before each CWT.
    scales_per_octave : int
        How many scale subdivisions between each octave (doubling of min_period).
    significance_level : float, default 0.95
        Significance level passed to ``wct_significance``.

    Returns
    -------
    pWCT, freq, coi, sig_matrix, apWCT
        pWCT - 2D (scale, time) array with squared partial coherence, clipped to [0, 1]
        freq - 1D array with the frequencies the wavelets were calculated at
        coi - 1D array with the cone of influence for each time step
        sig_matrix - pWCT divided by the per-scale threshold (>= 1 significant),
                     or the placeholder ``[0]`` when ``sig`` is False.
                     NOTE(review): the threshold comes from ordinary coherence of
                     AR1 noise matching x and y; z is not taken into account.
        apWCT - 2D array with the partial-coherence phase angles (radians)

    Raises
    ------
    ValueError
        If max_period is smaller than the (Nyquist-adjusted) min_period.
    """
    wavelet_type = _check_parameter_wavelet("morlet")

    # The smallest resolvable period is the Nyquist period.
    min_period = max(min_period, 2 * dt)
    min_scale = min_period / wavelet_type.flambda()
    if max_period == -1:
        J = -1
    else:
        if max_period < min_period:
            # A negative J (notably -1) would otherwise be misread by wct as "auto".
            raise ValueError(
                f'max_period ({max_period}) must be >= min_period ({min_period})')
        max_scale = max_period / wavelet_type.flambda()
        J = math.ceil(np.log2(max_scale / min_scale) * scales_per_octave)
        print('number of scales:', J)
    dj = 1.0 / scales_per_octave

    print("Calculating partial coherence...")
    start_time = time.time()

    scale_kwargs = dict(dj=dj, s0=min_scale, J=J, sig=False, normalize=normalize)
    _, _, coi, freq, _, Rxy = wct(x, y, dt, **scale_kwargs)
    print('Rxy done')
    _, _, _, _, _, Rxz = wct(x, z, dt, **scale_kwargs)
    print('Rxz done')
    _, _, _, _, _, Ryz = wct(y, z, dt, **scale_kwargs)
    print('Ryz done')

    RPxyz = (Rxy - Rxz * Ryz.conj()) / np.sqrt(1 - abs(Rxz) ** 2) / np.sqrt(1 - abs(Ryz) ** 2)
    pWCT = abs(RPxyz) ** 2
    apWCT = np.angle(RPxyz)

    if sig:
        alx, _, _ = ar1(x)
        aly, _, _ = ar1(y)
        sig_scales = wct_significance(alx, aly,
                                      dt=dt, dj=dj, s0=min_scale, J=J,
                                      significance_level=significance_level,
                                      wavelet_type=wavelet_type)
        print('signifigance finished')
        sig_matrix = pWCT / sig_scales[:, None]
    else:
        sig_matrix = [0]
    print("Time Taken: %s s" % (time.time() - start_time))

    if np.max(pWCT) > 1 or np.min(pWCT) < 0:
        print('pWCT was out of range: min {},max {}'.format(
            np.min(pWCT), np.max(pWCT)))
        pWCT = np.clip(pWCT, 0, 1)

    return pWCT, freq, coi, sig_matrix, apWCT

In [None]:
def coherence_to_df(WCT, freq, coi, sig, aWCT, index):
    """
    Wrap wavelet-coherence outputs in pandas DataFrames indexed by time.

    Parameters
    ----------
    WCT, aWCT : np.ndarray
        (scale, time) arrays; they are transposed so rows are time steps.
    freq : np.ndarray
        One frequency per scale; columns are labelled with the period 1/freq.
    coi : array-like
        One cone-of-influence value per time step.
    sig : array-like or None
        (scale, time) significance array. None, or an all-zero value such as
        the ``[0]`` placeholder, means significance was not computed.
    index : array-like
        Row labels (e.g. timestamps), one per time step.

    Returns
    -------
    WCT_df, aWCT_df, coi_df, sig_df
        ``sig_df`` is None when no significance is available.
    """
    period = 1 / freq

    WCT_df = pd.DataFrame(WCT.T, index=index, columns=period)
    aWCT_df = pd.DataFrame(aWCT.T, index=index, columns=period)
    coi_df = pd.DataFrame(coi, index=index)

    sig_df = None
    if sig is not None:
        # asarray lets list-valued significance matrices work as well.
        sig_arr = np.asarray(sig)
        if not (sig_arr == 0).all():
            sig_df = pd.DataFrame(sig_arr.T, index=index, columns=period)

    return WCT_df, aWCT_df, coi_df, sig_df

In [None]:
def process_scale_for_wc_mc(args):
    """
    Histogram one scale (row) of a masked squared-coherence array.

    ``args`` is ``(s, R2, nbins)``. Row ``s`` of the masked array ``R2`` is
    clipped into [0, 1), mapped to ``nbins`` equal-width bins, and only the
    unmasked cells are counted. Returns a float array of length ``nbins``.
    """
    scale_idx, coherence_sq, n_bins = args
    upper = 1 - np.finfo(float).eps  # keeps a value of 1.0 inside the last bin
    row = np.clip(coherence_sq[scale_idx, :], 0, upper)
    bin_idx = np.floor(row * n_bins).astype(int)
    histogram = np.zeros(n_bins)
    keep = ~bin_idx.mask
    np.add.at(histogram, bin_idx[keep], 1)
    return histogram

In [None]:
def wct_monte_carlo_simulation(args):
    """
    Run one Monte Carlo realisation for the wavelet-coherence significance test.

    ``args`` is ``(N, al1, al2, dt, dj, s0, J, wavelet_type, nbins, maxscale,
    outsidecoi, scales)``, as assembled by ``wct_significance``. Two red-noise
    series of length N with lag-1 coefficients al1 / al2 are generated. Their
    smoothed squared coherence is computed and, for scales 0 .. maxscale-1,
    the cells outside the COI are histogrammed into ``nbins`` bins.

    Returns
    -------
    np.ma.MaskedArray
        (J + 1, nbins) array of bin counts; rows >= maxscale stay zero.
    """
    N, al1, al2, dt, dj, s0, J, wavelet_type, nbins, maxscale, outsidecoi, scales = args

    # Two red-noise signals with lag-1 autoregressive coefficients al1 and al2.
    noise1 = wavelet.rednoise(N, al1, 1)
    noise2 = wavelet.rednoise(N, al2, 1)

    cwt_kwargs = dict(dt=dt, dj=dj, s0=s0, J=J, wavelet=wavelet_type)
    nW1, sj, _, _, _, _ = wavelet.cwt(noise1, **cwt_kwargs)
    nW2 = wavelet.cwt(noise2, **cwt_kwargs)[0]
    del noise1, noise2  # free memory early; these workers can be large

    nW12 = nW1 * nW2.conj()

    # Smooth the wavelet spectra and the cross spectrum, then form the coherence.
    S1 = wavelet_type.smooth(np.abs(nW1) ** 2 / scales, dt, dj, sj)
    S2 = wavelet_type.smooth(np.abs(nW2) ** 2 / scales, dt, dj, sj)
    S12 = wavelet_type.smooth(nW12 / scales, dt, dj, sj)
    del nW1, nW2

    # Cells inside the cone of influence are masked out of the histogram.
    R2 = np.ma.array(np.abs(S12) ** 2 / (S1 * S2), mask=~outsidecoi)

    wlc_local = np.ma.zeros([J + 1, nbins])

    # Histogram each scale in a thread pool, a batch of scales at a time to bound memory.
    batch_size = int(maxscale / 300) + 1  # adjust based on memory constraints
    with ThreadPoolExecutor() as executor:
        for start in range(0, maxscale, batch_size):
            # Clamp the last batch so it never reaches past maxscale (matching the
            # range(maxscale) used by wct_significance and staying within J + 1 rows).
            batch = range(start, min(start + batch_size, maxscale))
            results = list(executor.map(process_scale_for_wc_mc,
                                        [(s, R2, nbins) for s in batch]))
            for s, counts in zip(batch, results):
                wlc_local[s, :] += counts
            del results

    gc.collect()
    return wlc_local

In [None]:
def _check_parameter_wavelet(wavelet):
    """
    Resolve a mother-wavelet specification to a wavelet instance.

    A string ("morlet", "paul", "dog" or "mexicanhat") is mapped to a new
    instance of the matching pycwt mother wavelet built with its default
    parameters. Any non-string value is returned unchanged, on the assumption
    that it is already a wavelet instance.

    Raises
    ------
    KeyError
        If ``wavelet`` is a string that is not one of the names above.
    """
    if not isinstance(wavelet, str):
        return wavelet
    mothers = {"morlet": Morlet, "paul": Paul, "dog": DOG, "mexicanhat": MexicanHat}
    try:
        mother = mothers[wavelet]
    except KeyError:
        raise KeyError(
            f"Unknown wavelet {wavelet!r}; expected one of {sorted(mothers)}") from None
    return mother()

In [None]:
def wct_significance(
    al1,
    al2,
    dt,
    dj,
    s0,
    J,
    significance_level=0.95,
    wavelet_type="morlet",
    mc_count=300,
    progress=True,
    cache=True,
    processes=13,
):
    """Wavelet coherence transform significance.

    Calculates the WCT significance threshold per scale using Monte Carlo
    simulations of two red-noise series.

    Parameters
    ----------
    al1, al2: float
        Lag-1 autoregressive coefficients of both time series.
    dt : float
        Sample spacing.
    dj : float
        Spacing between discrete scales.
    s0 : float
        Smallest scale of the wavelet.
    J : int
        Number of scales less one. Scales range from s0 up to
        s0 * 2**(J * dj), which gives a total of (J + 1) scales.
    significance_level : float, optional
        Significance level to use. Default is 0.95.
    wavelet_type : str or wavelet instance, optional
        Mother wavelet (name or pycwt instance). Default is "morlet".
    mc_count : integer, optional
        Number of Monte Carlo simulations. Default is 300.
    progress : bool, optional
        If `True` (default), shows a progress bar on screen.
    cache : bool, optional
        If `True` (default), loads/saves the result from/to the pycwt cache dir.
    processes : int, optional
        Number of worker processes for the simulations. Default is 13.

    Returns
    -------
    sig95 : np.ndarray
        Length J + 1 array of squared-coherence thresholds. Scales
        ``s < maxscale`` (maxscale = last scale with any point outside the
        COI) get the Monte Carlo threshold. Of the remaining scales, those with
        points outside the COI are NaN and the rest are 0.
    """
    wavelet_type = _check_parameter_wavelet(wavelet_type)
    if cache:
        # Quantise the AR1 coefficients in Fisher-z space, following Grinsted's
        # wtcsignif: round(atanh(ar1) * 4). The former arctanh(ar1 * 4) was NaN
        # for |ar1| > 0.25, so all typical series shared one cache entry.
        aa = np.round(np.arctanh(np.array([al1, al2])) * 4)
        aa = np.abs(aa) + 0.5 * (aa < 0)
        # The significance level is part of the key; otherwise a cached result
        # for one level would be returned for every other level.
        cache_file = "wct_sig_{:0.5f}_{:0.5f}_{:0.5f}_{:0.5f}_{:d}_{}_{:0.5f}".format(
            aa[0], aa[1], dj, s0 / dt, J, wavelet_type.name, significance_level
        )
        cache_path = "{}/{}.gz".format(get_cache_dir(), cache_file)
        try:
            dat = np.loadtxt(cache_path, unpack=True)
            print("NOTE: WCT significance loaded from cache.\n")
            return dat
        except IOError:
            pass

    print("Calculating wavelet coherence significance")

    # Choose N so that the largest scale has at least some part outside the COI.
    ms = s0 * (2 ** (J * dj)) / dt
    N = int(np.ceil(ms * 6))
    noise1 = rednoise(N, al1, 1)
    _, sj, freq, coi, _, _ = wavelet.cwt(noise1, dt=dt, dj=dj, s0=s0, J=J, wavelet=wavelet_type)

    period = np.ones([1, N]) / freq[:, None]
    coi = np.ones([J + 1, 1]) * coi[None, :]
    outsidecoi = period <= coi
    scales = np.ones([1, N]) * sj[:, None]
    has_outside = outsidecoi.any(axis=1)
    sig95 = np.zeros(J + 1)
    # Index of the largest scale with any point outside the COI.
    maxscale = np.flatnonzero(has_outside)[-1]
    sig95[has_outside] = np.nan

    nbins = 1000
    wlc = np.ma.zeros([J + 1, nbins])

    # Accumulate each simulation's histogram as it arrives instead of holding
    # all mc_count results in memory at once.
    args = (N, al1, al2, dt, dj, s0, J, wavelet_type, nbins, maxscale, outsidecoi,
            scales)
    with Pool(processes, maxtasksperchild=1) as pool:
        for result in tqdm(pool.imap(wct_monte_carlo_simulation, [args] * mc_count,
                                     chunksize=1),
                           total=mc_count, disable=not progress):
            wlc += result

    # Determine the significance from the percentile of the coherence counters.
    wlc.mask = wlc.data == 0.0
    R2y = (np.arange(nbins) + 0.5) / nbins
    for s in range(maxscale):
        sel = ~wlc[s, :].mask
        P = wlc[s, sel].data.cumsum()
        P = (P - 0.5) / P[-1]
        sig95[s] = np.interp(significance_level, P, R2y[sel])

    if cache:
        # Save the result to avoid repeating the computation in the future.
        np.savetxt(cache_path, sig95)

    return sig95


In [None]:
def wct_on_regular_chunks(df, wave1, wave2, dt=60, min_period=None, chunk_limit=40,length_limit=.25,calc_sig=False):
    """
    Run wavelet coherence separately on every gap-free chunk of ``df``.

    Rows are grouped into consecutive runs, depending on whether any column is
    NaN. Each NaN-free run with at least ``chunk_limit`` rows is passed to
    ``calc_wave_coherence`` with ``max_period = len(chunk) * dt * length_limit``.
    Chunks for which the calculation fails are reported and skipped.

    Parameters
    ----------
    df : pd.DataFrame
        Data containing the ``wave1`` and ``wave2`` columns.
    wave1, wave2 : str
        Column names of the two series.
    dt : float
        Sample spacing in seconds.
    min_period : float, optional
        Minimum period in seconds; defaults to 2 * dt.
    chunk_limit : int
        Minimum number of rows for a chunk to be processed.
    length_limit : float
        Fraction of the chunk duration used as the maximum period.
    calc_sig : bool
        Whether to compute Monte Carlo significance.

    Returns
    -------
    combined_WCT, combined_aWCT, combined_coi, combined_sig
        Time-indexed DataFrames concatenated over all chunks; ``combined_sig``
        is None when no significance was computed.

    Raises
    ------
    ValueError
        If no chunk could be processed.
    """
    # True for rows containing any NaN; each change of state starts a new group.
    isna = df.isna().any(axis=1)
    groups = isna.ne(isna.shift()).cumsum()
    min_period = min_period or 2 * dt

    WCT_results = []
    aWCT_results = []
    coi_results = []
    sig_results = []
    al1_list = []
    al2_list = []
    for _, chunk in df.groupby(groups):
        if chunk.isna().any().any() or len(chunk) < chunk_limit:
            continue
        max_period = len(chunk) * dt * length_limit

        try:
            WCT, freq, coi, sig, aWCT, _ = calc_wave_coherence(
                chunk[wave1], chunk[wave2], dt, min_period, max_period, sig=calc_sig)
        except Exception as err:
            print(f'skipping chunk starting at {chunk.index[0]}: {err!r}')
            continue

        WCT_df, aWCT_df, coi_df, sig_df = coherence_to_df(
            WCT, freq, coi, sig, aWCT, chunk.index)

        # The lag-1 coefficients are only summarised below, so a failure is not
        # fatal. Append both or neither so the two lists stay paired.
        try:
            al1 = ar1(chunk[wave1])[0]
            al2 = ar1(chunk[wave2])[0]
        except Exception:
            pass
        else:
            al1_list.append(al1)
            al2_list.append(al2)

        WCT_results.append(WCT_df)
        aWCT_results.append(aWCT_df)
        coi_results.append(coi_df)
        if sig_df is not None:
            sig_results.append(sig_df)

    if not WCT_results:
        raise ValueError(
            f'No NaN-free chunk with at least {chunk_limit} rows could be processed.')

    combined_WCT = pd.concat(WCT_results)
    combined_aWCT = pd.concat(aWCT_results)
    combined_coi = pd.concat(coi_results)
    combined_sig = pd.concat(sig_results) if sig_results else None
    print('wave one lag 1 coef avg:',np.mean(al1_list),' std dev:',np.std(al1_list))
    print('wave two lag 1 coef avg:',np.mean(al2_list),' std dev:',np.std(al2_list))
    return combined_WCT, combined_aWCT, combined_coi, combined_sig


##Plotting Functions

###General

In [None]:
def plot_SIF_vs_PRI(df, title, save_fig=None, filepath=None):
    """
    Plot SIF_B / SIF_A against scaled PRI, with six ancillary panels below.

    Args:
        df (pd.DataFrame): Must contain the columns 'SIF_B_sfm [mW m-2nm-1sr-1]',
            'SIF_A_sfm [mW m-2nm-1sr-1]' and 'PRI_s'. The columns at positions
            3-8 are drawn, in order, in the six lower panels.
        title (str): Appended to the main title and used in the output filename.
        save_fig (optional): Any value other than None saves the figure as a PNG.
        filepath (str, optional): Prefix prepended verbatim to the filename when saving.
    """
    fig, axes = plt.subplots(nrows=7, ncols=1, figsize=(10, 9), sharex=True,
                             gridspec_kw={'height_ratios': [7, 1, 1, 1, 1, 1, 1]},
                             squeeze=True)
    main_ax = axes[0]
    color_cycle = plt.rcParams["axes.prop_cycle"]()

    # Top panel: both SIF retrievals on the left axis, PRI on the right axis.
    radiance_label = 'Spectral Radiance [mW m-2 nm-1 sr-1]'
    df['SIF_B_sfm [mW m-2nm-1sr-1]'].plot(legend=False,
                                         ylabel=radiance_label,
                                         title='SIF versus PRI scaled \n ' + title,
                                         ax=main_ax)
    df['SIF_A_sfm [mW m-2nm-1sr-1]'].plot(legend=False,
                                         ylabel=radiance_label,
                                         ax=main_ax)
    df['PRI_s'].plot(secondary_y=True, c='purple', legend=False,
                     mark_right=True, ylabel='PRI_s', ax=main_ax)

    main_ax.legend(main_ax.get_lines() + main_ax.right_ax.get_lines(),
                   ['SIF_B_sfm', 'SIF_A_sfm', 'PRI_s (right)'],
                   loc='upper left')

    # (annotation, y-axis units, y-limits) for each lower panel.
    panel_specs = [
        ('PAR inc', 'W m-2', [1, 450]),
        ('PAR ref', 'W m-2', [1, 55]),
        ('APAR ', 'umol \n m-2 s-1', [1, 1800]),
        ('temp', 'C', [10, 28]),
        ('humidity', '%', [1, 95]),
        ('Clearness Index', '', [0, .9]),
    ]
    for offset, (panel_ax, (annotation, units, limits)) in enumerate(zip(axes[1:], panel_specs)):
        line_color = next(color_cycle)["color"]
        df.iloc[:, offset + 3].plot(ax=panel_ax, c=line_color)
        panel_ax.set_ylabel(units, rotation=0)
        panel_ax.yaxis.set_label_coords(-.05, .5)
        panel_ax.annotate(annotation, xy=(0.5, 0.8), xycoords='axes fraction',
                          fontsize=8, ha="center")
        panel_ax.yaxis.tick_right()
        panel_ax.set(ylim=limits)

    plt.subplots_adjust(hspace=0.02)

    if save_fig is None:
        return
    out_name = 'SIF_vs_PRI_' + title.replace(" ", "_") + '.png'
    if filepath is not None:
        out_name = filepath + out_name
    fig.savefig(out_name)

In [None]:
def _split_label_units(label):
    """Split 'name [units]' before the '[' into (name, units); otherwise return (label, label)."""
    try:
        name, units = re.split(r'(?=\[)', label)
    except (ValueError, TypeError):
        # No '[' (or more than one), or a non-string label.
        return label, label
    return name, units


def plot_rolling_corr(df, var1, var2, win_len, resample_len, method='quantile',q=.95,save_fig=None,filepath=None,ax=None):
    """
    Plot two smoothed variables and shade the background by their rolling correlation.

    Both variables are smoothed with a centred rolling window of ``win_len``
    using ``method``. Their rolling correlation over the same window is then
    drawn as colour bands behind the lines.

    Args:
        df (pd.DataFrame): Time-indexed data containing ``var1`` and ``var2``.
        var1 (str): First variable (left axis). A '[units]' suffix is used as the axis label.
        var2 (str): Second variable (right axis).
        win_len (str): Offset string for the rolling smoothing / correlation window.
        resample_len (str): Offset string; only used in the saved filename.
        method (str): 'mean', 'median' or 'quantile'.
        q (float): Quantile used when method == 'quantile'.
        save_fig (optional): Any value other than None saves the figure as a PNG.
        filepath (str, optional): Prefix prepended verbatim to the filename when saving.
        ax (matplotlib Axes, optional): Axes to draw on; a new figure is created if None.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(11, 5))
    else:
        # Needed so save_fig also works with caller-supplied axes.
        fig = ax.get_figure()

    var1_name, var1_units = _split_label_units(var1)
    var2_name, var2_units = _split_label_units(var2)

    # Smooth both variables with a centred rolling window.
    rolling = df[[var1, var2]].rolling(window=win_len, center=True)
    if method == 'mean':
        rolling_df = rolling.mean(numeric_only=True)
    elif method == 'median':
        rolling_df = rolling.median(numeric_only=True)
    elif method == 'quantile':
        rolling_df = rolling.quantile(q, numeric_only=True)
        method = method + ' ' + str(q)
    else:
        raise ValueError(f"method must be 'mean', 'median' or 'quantile', got {method!r}")
    corr_df = rolling_df[var2].rolling(window=win_len, center=True).corr(rolling_df[var1],numeric_only=True)
    # The plotted series are the rolling-smoothed values.
    resampled_df = rolling_df

    title_str = var1_name.replace("_", " ") + ' vs. ' + var2_name.replace("_", " ") + ' (' + win_len + ' ' + method + ')'
    ax.set_title(title_str)

    resampled_df[var1].plot(
        legend=False,
        ylabel=var1_units,
        secondary_y=False,
        ax=ax,
        c='#0F6003')
    resampled_df[var2].plot(
        c='#540360',
        legend=False,
        mark_right=True,
        ylabel=var2_units,
        secondary_y=True,
        ax=ax)
    ax.set_ylim(resampled_df[var1].min(), resampled_df[var1].max())
    ax.right_ax.set_ylim(resampled_df[var2].min(), resampled_df[var2].max())

    # One colour per correlation interval (see bounds below), from -1 to 1.
    colors = ['#ca0020', '#f4a582', '#f7f7f7', '#92c5de', '#0571b0'][::-1]
    cmap = matplotlib.colors.ListedColormap(colors)
    coor_colors = reclassify_values_to_colors(corr_df.values,colors)

    max_value = resampled_df.max().max()

    print("Filling in correlation background...")
    for color in colors:
        start_time = time.time()
        boolean_series = (coor_colors == color)
        # Also fill the sample after each run so step='pre' bands have no gaps.
        shifted_series = np.roll(boolean_series, 1)
        shifted_series[0] = False
        ax.fill_between(x=corr_df.index,
                        y1=0,
                        y2=max_value,
                        where=(boolean_series | shifted_series),
                        facecolor=color,
                        step='pre',
                        alpha=.5)
        print(color + " time Taken: %s s" % (time.time() - start_time))

    bounds = [-1, -0.6, -0.2, 0.2, 0.6, 1]
    norm = plt.Normalize(vmin=-1, vmax=1)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm,)
    sm.set_array([])
    cbar = plt.colorbar(sm, boundaries=bounds, ticks=bounds, ax=plt.gca())
    cbar.set_label(win_len + ' Rolling Correlation')
    cbar.ax.set_position([0.9, 0.15, 0.05, 0.7])
    cbar.set_alpha(.5)
    cbar.draw_all()

    lines = ax.get_lines() + ax.right_ax.get_lines()
    ax.legend(lines, [
                   var1_name.replace("_", " "),
                   var2_name.replace("_", " ")
                   ],
               loc='upper left')

    if save_fig is not None:
        out_name = 'rolling_corr_' + win_len + '-' + var1_name + '_' + var2_name + '_' + resample_len + '.png'
        if filepath is not None:
            out_name = filepath + out_name
        fig.savefig(out_name)



In [None]:
def plot_absorption_spectrum(radiance, date, time, df_name, save_fig=False):
    """
    Plot the Ha, O2-B and O2-A absorption windows of one spectrum.

    In each window the most prominent dip of the spectrum is located and its
    FWHM and wavelength are annotated, together with the mean spectral
    sampling interval (SSI).

    Parameters:
        radiance (pd.DataFrame): Radiance values. After transposing, rows are
            matched against wavelength prefixes and columns are sliced by
            'date time' (presumably rows = timestamps, columns = wavelengths
            in the original frame).
        date (str): Date in the format 'YYYY-MM-DD'.
        time (str): Time in the format 'HH:MM:SS'.
        df_name (str): Name used in the title and output filename.
        save_fig (bool): Whether to save the figure as a PNG. Defaults to False.

    Returns:
        None
    """
    # (first wavelength prefix, last wavelength prefix, band name)
    bands = [
        ('653.', '659.', 'Ha'),
        ('686.', '692.', 'O2-B'),
        ('754.', '770.', 'O2-A'),
    ]
    stamp = date + ' ' + time

    fig, ax = plt.subplots(1, 3, sharey=True)
    fig.set_size_inches(8.5, 4.5)
    spectra = radiance.T

    for band_ax, (lo, hi, band_name) in zip(ax, bands):
        first_wl = spectra.filter(like=lo, axis=0).index[0]
        last_wl = spectra.filter(like=hi, axis=0).index[-1]
        window = spectra.loc[first_wl:last_wl]
        snapshot = window.loc[:, stamp:stamp]
        snapshot.plot(ax=band_ax, label=None, legend=False, color='darkgreen')

        # Negate so absorption dips become peaks for scipy's peak finder.
        depth = -snapshot.values.flatten()
        wavelengths = snapshot.index.values

        peaks, _ = find_peaks(depth)
        deepest = peak_prominences(depth, peaks)[0].argmax()
        widths = peak_widths(depth, peaks)
        spec_res = np.diff(wavelengths).mean()
        FWHM = np.round(widths[0][deepest] * spec_res, 3)
        peak_wl = np.round(wavelengths[peaks[deepest]], 3)

        band_ax.set(xlabel='WL (nm)', ylabel='[W m-2 nm-1 sr-1]',
                    title=band_name + ', SSI: ' + str(np.round(spec_res, 3)) + 'nm',
                    xticks=[float(lo), float(hi)])
        band_ax.text(0.5, 0.02, s='FWHM: ' + str(FWHM), fontsize=10,
                     transform=band_ax.transAxes)
        band_ax.text(0.45, 0.08, s='peak: ' + str(peak_wl), fontsize=10,
                     transform=band_ax.transAxes)
        band_ax.plot(peak_wl, -depth[peaks[deepest]], 'ro', label='peak')

    fig.suptitle(f'Absorption Spectrum for {df_name} on {date} at {time}')

    # Shared legend to the right of the panels (skips the spectrum line itself).
    handles, labels = ax[0].get_legend_handles_labels()
    fig.legend(handles=handles[1:], labels=labels[1:], loc='center right')
    plt.subplots_adjust(right=0.85)

    if save_fig:
        fig.savefig(f'absorption_spectrum_{df_name}_{date}_{time.replace(":", "-")}.png')

    return None


In [None]:
def plot_daily_data_PAR(SIF_data_df, column_name, y_label, title, file_name):
    """
    Plot one small panel per day in a calendar-like (one week per row) grid.

    Each day's line is coloured by that day's total incoming PAR
    ('PAR inc [W m-2]' summed over the day's samples). A legend of the PAR
    classes, in millions of that summed unit, is drawn under the grid.

    Args:
        SIF_data_df (pd.DataFrame): Input data with a DatetimeIndex. Must contain
            'PAR inc [W m-2]' and *column_name*.
        column_name (str): The name of the column in SIF_data_df to plot.
        y_label (str): The label for the figure-wide y axis.
        title (str): The figure title (a fixed 'May 16th - June 28th' line is appended).
        file_name (str): The name of the file to save the plot to.
    """
    # Colour for each PAR class (0 = highest daily PAR ... 4 = lowest)
    colors = {0: 'r', 1: 'b', 2: 'g', 3: 'purple', 4: 'orange'}

    # Reshape PAR into a (time of day x date) table and interpolate gaps
    PAR_daily = SIF_data_df['PAR inc [W m-2]'].copy()
    PAR_daily.index = [PAR_daily.index.time, PAR_daily.index.date]
    PAR_daily = PAR_daily.unstack().interpolate()
    daily_par_total = PAR_daily.sum()

    # Maximum daily PAR total and its representation in millions
    max_par = daily_par_total.max()
    max_par_str = str(np.round(max_par/1000000,1))

    # Rank days by total incoming PAR. The outer edges are open: fixed edges of
    # 500000 and max left darker days unclassified (NaN -> KeyError in the colour
    # lookup) and raised on non-monotonic bins whenever max <= 4.5e6.
    sunny_rank = pd.cut(
        daily_par_total,
        bins=[-np.inf, 1500000, 2500000, 3500000, 4500000, np.inf],
        labels=[4, 3, 2, 1, 0],
    )

    # Reshape the column of interest the same way
    d1 = SIF_data_df[column_name].copy()
    d1.index = [d1.index.time, d1.index.date]
    d1 = d1.unstack().interpolate()

    # Share the y axis only when the daily ranges are comparable
    ranges = d1.apply(lambda col: col.max()-col.min())
    range_ratio = ranges.max() / ranges.min()
    sharey = range_ratio < 4

    # Number of days between the previous Sunday and the first day of data
    # (weekday(): Monday=0 ... Sunday=6)
    first_date = pd.to_datetime(d1.columns[0])
    first_sunday_index = (first_date.weekday() - 6) % 7

    # Pad with empty days so the first row starts on a Sunday
    blank_days = pd.DataFrame(index=d1.index, columns=pd.date_range(first_date - pd.Timedelta(days=first_sunday_index), first_date - pd.Timedelta(days=1)))
    d1 = pd.concat([blank_days, d1], axis=1)

    # One colour per column of d1. The empty padding days come first, so give
    # them a placeholder; a shorter list would be cycled by pandas and shift
    # every real day's colour.
    day_colors = ['k'] * len(blank_days.columns) + [colors[rank] for rank in sunny_rank]

    num_rows = int(np.ceil(len(d1.columns) / 7))

    # Calendar grid of subplots, one per day
    axes = d1.plot(legend=0,
            subplots=True,
            layout=(num_rows,7),
            figsize=(10,(10/7)*num_rows),
            sharey=sharey,
            sharex=True,
            color=day_colors,
                  );

    plt.setp(axes, xticks=[])

    weekday_map = {0: 'SUN', 1: 'MON', 2: 'TUE', 3: 'WED',
                   4: 'THU', 5: 'FRI', 6: 'SAT'}
    for i in range(7):
        for ax in axes[:, i]:
            ax.set_xlabel(weekday_map[i])

    # Date label per grid cell; trailing cells beyond the data get no label
    day_of_month = pd.to_datetime(d1.columns).strftime('%b-%d').to_list()
    day_of_month += [''] * (axes.size - len(day_of_month))

    def y_fmt(x, pos):
        # Keep y-tick labels to at most ~3 characters
        s = f"{x:.2g}"
        if len(s) > 3:
            s = f"{x:.1g}"
        return s

    for n, ax in enumerate(axes.flatten()):
        ax.text(.50, 1.02, str(day_of_month[n]), transform=ax.transAxes,
                size=9,
                )

        # Draw y ticks inside the panel so labels don't collide with neighbours
        ax.tick_params(axis='y', direction="in",which='major', labelsize=8,
                       pad=-15,
                       )
        ax.tick_params(axis='y', direction="in",which='minor', labelsize=3,
                       )

        # Mathtext scientific notation for wide ranges, compact labels otherwise
        y_min, y_max = ax.get_ylim()
        if y_max - y_min > 100:
            ax.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
        else:
            ax.yaxis.set_major_formatter(FuncFormatter(y_fmt))

    fig = axes[-1,-1].get_figure()

    fig.suptitle('Total incoming PAR for the day (megawatt m-2)',y=0.05);

    fig.text(x=0.50,y= 0.90,s=title+'\nMay 16th - June 28th', ha='center',color='black', fontsize=14)

    fig.text(0.06, 0.5,y_label , ha='center', va='center', rotation='vertical')

    # Legend entries in the same order as the handles below (rank 4 -> 0)
    labels = [
        '< 1.5',
        '1.5 - 2.5',
        '2.5 - 3.5',
        '3.5 - 4.5',
        '4.5 - ' + max_par_str,
              ]
    handles = [plt.Rectangle((0,0),1,1,color=colors[i]) for i in [4,3,2,1,0]]

    plt.figlegend(handles, labels,
                  loc = 'lower center',
                  ncol=5,
                  labelspacing=0.,
                 );

    # Leave room at the bottom for the suptitle and legend
    fig.subplots_adjust(bottom=0.1)

    fig.savefig(file_name)


In [None]:
def plot_sunny_cloudy_pairs_rolling_corr(df, var1,var2,win_len,lower=None,upper=None):
    """
    Plot sunny/cloudy day pairs shaded by the rolling correlation of two variables.

    For each calendar month, the day with the largest absolute change in daily
    median 'CI' relative to the day before is selected, together with that
    previous day. Only complete pairs are kept. Within each pair, the day with
    the higher median 'CI' is labelled 'Sunny day' and the other 'Cloudy day'.

    Each selected day gets one row of subplots:
    - var1 is drawn on the left axis and var2 on a secondary axis.
    - Both variables have points with |z-score| >= 3 removed and are clipped at 0.
    - The background is shaded by the rolling correlation between var1 and var2,
      binned into the colorbar's five classes.

    Args:
    df (pd.DataFrame): Input data with a DatetimeIndex. Must contain 'CI',
        'PAR inc [W m-2]', var1 and var2.
    var1 (str): Column plotted on the primary (left) axis.
    var2 (str): Column plotted on the secondary (right) axis.
    win_len (str): Offset window for ``Series.rolling`` (e.g. '2H'). It is also
        shown in the colorbar label.
    lower, upper: Unused; kept for backward compatibility.

    Returns:
    None. The function draws one row of subplots per selected day.

    Raises:
    ValueError: If no complete day pair can be found.
    """
    def _base_name(column):
        # Column name without its trailing '[units]' part (kept whole otherwise)
        try:
            name, _ = re.split(r'(?=[[])', column)
        except (TypeError, ValueError):
            name = column
        return name

    # Day-to-day change in daily median CI; the largest change per month marks
    # a transition between a sunny and a cloudy day
    daily_median = df['CI'].resample('1D').median()
    gradient = daily_median.diff().abs()
    max_gradient_day = gradient.groupby(gradient.index.to_period('M')).idxmax()

    frames = []
    sunny_labels = []
    for _, day in max_gradient_day.items():
        # Months without any valid day-to-day change give NaN/NaT
        if pd.isna(day):
            continue
        prev_key = str((day - pd.DateOffset(days=1)).date())
        # Keep only complete pairs so the labels stay aligned with the columns
        if prev_key not in df.index:
            continue
        day_data = df.loc[str(day.date())]
        prev_day_data = df.loc[prev_key]
        frames.extend([day_data, prev_day_data])
        # Columns end up sorted by date (previous day first), so the previous
        # day's label goes first
        prev_label = 'Sunny day' if prev_day_data['CI'].median() > day_data['CI'].median() else 'Cloudy day'
        day_label = 'Sunny day' if prev_label == 'Cloudy day' else 'Cloudy day'
        sunny_labels.extend([prev_label, day_label])
    if not frames:
        raise ValueError('No sunny/cloudy day pairs found in df')
    result_df = pd.concat(frames)

    # (time of day x (variable, date)) table of the selected days
    d1 = result_df[[var1,var2,'PAR inc [W m-2]']].copy()
    d1.index = [d1.index.time, d1.index.date]
    d1 = d1.unstack().interpolate()
    num_plots = len(d1[var1].columns)
    # squeeze=False keeps axs indexable even when only one day is plotted
    fig, axs = plt.subplots(num_plots, figsize=(15, 30), sharex=True, squeeze=False)
    axs = axs[:, 0]
    day_of_month = pd.to_datetime(d1[var1].columns).strftime('%b-%d').to_list()
    var1_name = _base_name(var1)
    var2_name = _base_name(var2)

    # Five correlation classes; after the reversal -1 maps to blue, +1 to red
    colors = ['#ca0020', '#f4a582', '#f7f7f7', '#92c5de', '#0571b0'][::-1]
    cmap = matplotlib.colors.ListedColormap(colors)
    # min_periods belongs to rolling(): Rolling.corr() ignored it in pandas 1.x
    # and rejects it in pandas 2.x
    corr_df = df[var1].rolling(window=win_len, center=True, min_periods=30).corr(df[var2])
    corr_df.index = [corr_df.index.time, corr_df.index.date]
    corr_df = corr_df.unstack().interpolate()

    for i in range(num_plots):
        var1_col = d1[var1].iloc[:,i]
        var1_data = var1_col[(np.abs(stats.zscore(var1_col,nan_policy='omit')) <  3)]
        var1_data = var1_data.clip(lower=0)
        var1_data.plot(
            legend=True,
            ylabel=var1_name,
            label=var1_name,
            ax=axs[i])
        var2_col = d1[var2].iloc[:,i]
        var2_data = var2_col[(np.abs(stats.zscore(var2_col,nan_policy='omit')) <  3)]
        var2_data = var2_data.clip(lower=0)
        var2_data.plot(
            secondary_y=True,
            c='purple',
            legend=True,
            mark_right=True,
            ylabel=var2_name,
            ax=axs[i],
            label=var2_name)
        axs[i].set_title(str(day_of_month[i])+' - '+sunny_labels[i])

        # Height of the shaded correlation band
        max_value = max(var1_data.max(), var2_data.max())

        # Map each correlation value to its colour class
        day_corr = corr_df.loc[:, var1_col.name]
        corr_colors = reclassify_values_to_colors(day_corr.values, colors)
        for color in colors:
            in_class = (corr_colors == color)
            # Also include the sample after each run so step='post' fills
            # reach the end of the run
            extended = np.roll(in_class, 1)
            extended[0] = False
            axs[i].fill_between(x=day_corr.index,
                        y1=0,
                        y2=max_value,
                        where=(in_class | extended),
                        facecolor=color,
                        step='post',
                        alpha=.5)

    plt.tight_layout()

    fig.suptitle('Comparing Sunny and Cloudy Diurnal Patterns Across the Corn Season \n2-minute average \n' + 'Variables: ' +var1_name + ' vs. ' + var2_name,
                 ha='left',x=0.02,y=1.02,fontsize=18,fontweight='bold')

    # Horizontal colorbar for the five correlation classes
    bounds = [-1, -0.6, -0.2, 0.2, 0.6, 1]
    norm = plt.Normalize(vmin=-1, vmax=1)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm,)
    sm.set_array([])

    cbar = plt.colorbar(sm, boundaries=bounds, ticks=bounds, ax=axs[-1],
                        orientation='horizontal',
                        shrink=.01)

    cbar.set_label(f'{win_len} Rolling Correlation',labelpad=-35,fontsize=14,fontweight='bold')

    cbar.ax.set_position([0.52, 1., 0.5, 0.01])  # Adjust these values as needed

    cbar.set_alpha(.5)

    # Colorbar.draw_all() was deprecated in Matplotlib 3.6 and removed in 3.8;
    # fall back to its (private) successor so the new alpha is rendered
    redraw = getattr(cbar, 'draw_all', None) or getattr(cbar, '_draw_all', None)
    if redraw is not None:
        redraw()
    return

In [None]:
def plot_sunny_cloudy_pairs(df, var1,var2):
    """
    Plot sunny/cloudy day pairs across the season, plus a seasonal PRI overview.

    The daily median 'CI' is differenced day to day. In every 7-week bin, the
    day with the largest absolute change is selected together with the day
    before it. Only complete pairs are kept. Within a pair, the day with the
    higher median 'CI' is labelled 'Sunny day' and the other 'Cloudy day'.

    Each pair occupies one row, titled from ``row_titles``:
    - var1 is drawn on the left axis and var2 on a secondary axis.
    - Incoming PAR, scaled to var1's range, is shaded grey.
    - Both variables are clipped to their 5-95 % quantiles. The lower bound
      becomes 0 when the 5 % quantile is within 0.05 of zero.
    - Clipped points are marked.

    The top row shows the 14-day rolling 95th percentile of 'PRI_s', with the
    selected pairs highlighted.

    NOTE(review): ``row_titles`` has four entries, so more than four pairs
    would raise IndexError.

    Args:
    df (pd.DataFrame): Input data with a DatetimeIndex. Must contain 'CI',
        'PRI_s', 'PAR inc [W m-2]', var1 and var2.
    var1 (str): Column plotted on the primary (left) axis.
    var2 (str): Column plotted on the secondary (right) axis.

    Returns:
    None. The function draws the figure.

    Raises:
    ValueError: If no complete day pair can be found.
    """
    def _base_name(column):
        # Column name without its trailing '[units]' part (kept whole otherwise)
        try:
            name, _ = re.split(r'(?=[[])', column)
        except (TypeError, ValueError):
            name = column
        return name

    # 14-day rolling 95th percentile of PRI for the season overview panel
    seasonal_df = df['PRI_s'].rolling('14D').quantile(.95,numeric_only=False)

    # Day-to-day change in daily median CI; largest change per 7-week bin
    daily_median = df['CI'].resample('1D').median()
    gradient = daily_median.diff().abs()
    max_gradient_day = gradient.resample('7W').agg(
        lambda x : np.nan if x.count() == 0 else x.idxmax()
    )

    frames = []
    sunny_labels = []
    for _, day in max_gradient_day.items():
        # Empty bins give NaN/NaT and have no day to plot
        if pd.isna(day):
            continue
        prev_key = str((day - pd.DateOffset(days=1)).date())
        # Keep only complete pairs so labels and columns stay aligned
        if prev_key not in df.index:
            continue
        day_data = df.loc[str(day.date())]
        prev_day_data = df.loc[prev_key]
        frames.extend([day_data, prev_day_data])
        # Columns end up sorted by date (previous day first), so the previous
        # day's label goes first
        prev_label = 'Sunny day' if prev_day_data['CI'].median() > day_data['CI'].median() else 'Cloudy day'
        day_label = 'Sunny day' if prev_label == 'Cloudy day' else 'Cloudy day'
        sunny_labels.extend([prev_label, day_label])
    if not frames:
        raise ValueError('No sunny/cloudy day pairs found in df')
    result_df = pd.concat(frames)

    # (time of day x (variable, date)) table of the selected days
    d1 = result_df[[var1,var2,'PAR inc [W m-2]']].copy()
    d1.index = [d1.index.time, d1.index.date]
    d1 = d1.unstack().interpolate()
    num_plots = len(d1[var1].columns) // 2

    fig = plt.figure(constrained_layout=True,figsize=(30,30))

    day_of_month = pd.to_datetime(d1[var1].columns).strftime('%b-%d').to_list()

    var1_name = _base_name(var1)
    var2_name = _base_name(var2)

    par_patch = mpatches.Patch(color='grey', alpha=0.5, label='PAR')

    row_titles = ['Beginning of Season', 'Growth', 'Senescence',  'End of season',]

    # Shades of green per row: rising then falling intensity
    color_gradient = np.concatenate((np.linspace(0.6, 1, len(row_titles)//2), np.linspace(0.8, 0.6, len(row_titles)//2+1)))

    colors = [plt.cm.Greens(x) for x in color_gradient]
    # Row 0 holds the season overview, one row per pair below it
    subfigs = fig.subfigures(nrows=num_plots+1, ncols=1)

    for row, subfig in enumerate(subfigs[1:]):
        subfig.suptitle(row_titles[row], fontsize=26, fontweight='bold', color=colors[row])

        axs = subfig.subplots(nrows=1, ncols=2)
        for i in range(2):
            j = row*2+i
            left_data = d1[var1].iloc[:,j]
            left_data = left_data.clip(upper=left_data.quantile(.95),lower=0 if abs(left_data.quantile(.05)) < .05 else left_data.quantile(.05))
            left_data.plot(
                legend=False,
                ylabel=var1_name,
                label=var1_name,
                ax=axs[i],
                c='C0',
                fontsize=14)
            right_data = d1[var2].iloc[:,j]
            right_data = right_data.clip(upper=right_data.quantile(.95),lower=0 if abs(right_data.quantile(.05)) < .05 else right_data.quantile(.05))
            right_data.plot(
                secondary_y=True,
                c='purple',
                legend=False,
                mark_right=True,
                ylabel=var2_name,
                ax=axs[i],
                label=var2_name,
                fontsize=14)

            # Mark clipped points of var1 on the left axis
            clipped_var1 = get_clipped_values(left_data)
            plot_clipped_markers(axs[i], clipped_var1[0][0], clipped_var1[0][1],color='black') # Below threshold
            plot_clipped_markers(axs[i], clipped_var1[1][0], clipped_var1[1][1],color='black') # Above threshold

            # Mark clipped points of var2 on the right axis
            clipped_var2 = get_clipped_values(right_data)
            plot_clipped_markers(axs[i].right_ax, clipped_var2[0][0], clipped_var2[0][1],color='black') # Below threshold
            plot_clipped_markers(axs[i].right_ax, clipped_var2[1][0], clipped_var2[1][1],color='black',label=None) # Above threshold

            axs[i].set_title(str(day_of_month[j])+' - '+sunny_labels[j],fontsize=28,fontweight='bold')

            # Grey PAR background scaled to the largest |var1| value
            par_data = d1['PAR inc [W m-2]'].iloc[:,j]
            left_min = left_data.min()
            left_max = left_data.max()
            scaled_par_data = par_data / abs(par_data).max() * max(abs(left_min), abs(left_max))
            axs[i].fill_between(x=d1.index,
                        y1=0,
                        y2=scaled_par_data,
                        facecolor='grey',
                        step='post',
                        alpha=.5)
            axs[i].set_ylabel(var1_name,fontsize=28)
            axs[i].right_ax.set_ylabel(var2_name,fontsize=28)

    # Season overview panel
    ax_seasonal = subfigs[0].subplots(nrows=1, ncols=1)
    ax_seasonal.set_position([0.05, 0.15, 0.4, 0.8])
    seasonal_df.plot(ax=ax_seasonal,fontsize=20,c='black')

    ax_seasonal.set_xlabel('Date')
    ax_seasonal.set_ylabel('PRI - scaled and smoothed',fontsize=20)
    subfigs[0].suptitle('Season Overview', fontsize=26, fontweight='bold',
                        color='black',
                        ha='left',y=.95,x=.05)
    subfigs[0].patch.set_facecolor('none')

    # Highlight the first day of each pair on the overview
    for i, date in enumerate(d1[var1].columns[::2]):
        ax_seasonal.axvspan(date, date + pd.DateOffset(days=2), facecolor=colors[i], alpha=0.5)

        ax_seasonal.annotate(row_titles[i].replace(' ','\n'), (date, seasonal_df.min()),
                             fontsize=28,
                             fontweight='bold',
                             textcoords="offset points",
                             xytext=(50,150),
                             ha='left',
                             color=colors[i])

    # Figure legend built from the last row's lines plus the PAR patch
    lines = []
    lines.extend(axs[-1].get_lines())
    lines.extend(axs[-1].right_ax.get_lines())
    labels = [var1_name,var2_name]

    lines.append(par_patch)
    labels.append('Incoming PAR')

    fig.legend(handles=lines, labels=labels, fontsize=48,loc='upper right',bbox_to_anchor=(.9,.95),)

    fig.suptitle('Comparing Sunny and Cloudy Diurnal Patterns Across the Corn Season\n2-minute Average, Data Clipped within 5-95 Quantile For Clarity\nVariables: ' + var1_name + ' vs. ' + var2_name,
                 ha='center',x=.5,y=1.055,fontsize=32,fontweight='bold')


In [None]:
def plot_clipped_markers(axis, x_coordinates, y_coordinates, color='red', marker='x', label='Clipped'):
    """
    Draw small scatter markers at the given points on *axis*.

    Used to flag values that were clipped before plotting.

    Parameters:
    axis (matplotlib.axes.Axes): Target axes; only its ``scatter`` method is called.
    x_coordinates (array-like): Horizontal marker positions.
    y_coordinates (array-like): Vertical marker positions.
    color (str): Marker colour, forwarded as ``c``. Defaults to 'red'.
    marker (str): Matplotlib marker style. Defaults to 'x'.
    label (str): Legend label for the scatter. Defaults to 'Clipped'.

    Returns:
    None
    """
    # Marker size is fixed at 4 (points^2)
    scatter_options = dict(c=color, marker=marker, label=label, s=4)
    axis.scatter(x_coordinates, y_coordinates, **scatter_options)


###Wavelet Coherence Plotting

In [None]:
def calc_and_plot_wc(wave1, wave2, min_period, max_period, df,
                     save_fig=False, scales_per_octave=12, quiv_x=2, quiv_y=20,
                     dt=60,sig=False,cache_file=None,partial=None,wave3=None):
    """
    Calculate and plot the (optionally partial) wavelet coherence between two signals.

    Parameters
    ----------
    wave1 : str
        Column name of the first signal (X) in df.
    wave2 : str
        Column name of the second signal (Y) in df.
    min_period : float
        Minimum period passed to the coherence calculation.
    max_period : float
        Maximum period passed to the coherence calculation.
    df : pd.DataFrame
        Data containing the signals and a 'doy.dayfract' time column, which
        is used as the x axis.
    save_fig : bool or str, optional
        A string is used as the output filename. True saves to
        'wavelet_coherence_<X>_<Y>.png'. False (the default) does not save.
    scales_per_octave : int, optional
        Number of scales per octave. Default is 12.
    quiv_x : float, optional
        Horizontal spacing of the phase arrows, in units of 'doy.dayfract'.
        Default is 2.
    quiv_y : int, optional
        Approximate number of phase arrows along the period axis. Default is 20.
    dt : float, optional
        Sampling interval passed to the coherence calculation. Default is 60.
    sig : bool, optional
        Forwarded to the coherence calculation (presumably toggles the
        significance test).
    cache_file : str, optional
        Text file of per-scale significance levels. If given, the significance
        array becomes WCT divided by those levels.
    partial : any, optional
        If not None, compute partial coherence of X and Y conditioned on wave3.
    wave3 : str, optional
        Column name of the conditioning signal (Z). Required when partial is set.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the coherence plot.
    WCT, freq, coi, sig_array, aWCT
        Coherence, frequencies, cone of influence, significance ratio
        (None if unavailable) and phase angles.

    Raises
    ------
    ValueError
        If partial is set but wave3 is None.
    """
    def _base_name(column):
        # Column name without its trailing '[units]' part (kept whole otherwise)
        try:
            name, _ = re.split(r'(?=[[])', column)
        except (TypeError, ValueError):
            name = column
        return name

    var1_name = _base_name(wave1)
    var2_name = _base_name(wave2)
    var3_name = _base_name(wave3)

    time_var = 'doy.dayfract'
    x = df[wave1].values
    y = df[wave2].values
    t = df[time_var].values
    if partial is not None:
        if wave3 is None:
            raise ValueError('partial wavelet coherence requires wave3 (the conditioning signal)')
        z = df[wave3].values
        WCT, freq, coi, sig_array, aWCT = calc_partial_wave_coherence(x, y, z,
                                                    dt=dt,
                                                    min_period=min_period,
                                                    max_period=max_period,
                                                    sig=sig,
                                                    scales_per_octave=scales_per_octave,
                                                    normalize=True)

        title = "Partial Wavelet Coherence\n X: " + var1_name+ ' Y: ' + var2_name + ' Z: ' + var3_name
    else:
        WCT, freq, coi, sig_array, aWCT, _ = calc_wave_coherence(x, y,
                                                    dt=dt,
                                                    min_period=min_period,
                                                    max_period=max_period,
                                                    sig=sig,
                                                    scales_per_octave=scales_per_octave,
                                                    normalize=True)
        title = "Wavelet Coherence\n X: " + var1_name+ ' Y: ' + var2_name

    fig, ax = plt.subplots()
    # A length-1 significance result is treated as "no significance available"
    if len(sig_array) == 1:
        sig_array = None
    if cache_file is not None:
        sig_levels = np.loadtxt(cache_file, unpack=True)
        sig_array = WCT / sig_levels[:, None]

    # plot_wcohere returns fig=None when given an ax, so keep our own figure
    _, [WCT, t, y_vals] = plot_wcohere(WCT,
                                       t,
                                       freq,
                                       coi=coi,
                                       sig=sig_array,
                                       plot_period=True,
                                       ax=ax,
                                       title=title,
                                       )

    plot_arrows(ax, [WCT, t, y_vals], aWCT=aWCT, quiv_x=quiv_x, quiv_y=quiv_y, all_arrows=False,sig=sig_array)

    if save_fig:
        if isinstance(save_fig, bool):
            fname = f'wavelet_coherence_{str(var1_name).strip()}_{str(var2_name).strip()}.png'
        else:
            fname = save_fig
        fig.savefig(fname, bbox_inches='tight')

    return fig,WCT, freq, coi, sig_array, aWCT


In [None]:
def plot_wcohere(WCT, t, freq, coi=None, sig=None, plot_period=False, ax=None, title="Wavelet Coherence", block=None, mask=None, cax=None,dates=None):
    """
    Plot wavelet coherence using results from calc_wave_coherence.

    Parameters
    ----------
    WCT : 2D numpy array
        Coherence values indexed as (frequency, time); colour limits are [0, 1].
    t : 1D numpy array
        Sample times, one per WCT column.
    freq : 1D numpy array
        Frequencies the wavelets were evaluated at, one per WCT row.
    coi : 1D numpy array, optional
        Cone of influence (period per sample, assumed same length as t).
        Drawn as a white dashed line.
    sig : 2D numpy array, optional
        Significance ratio with WCT's shape. Values between its minimum and 1
        are shaded grey and outlined in black.
    plot_period : bool
        If True the y axis is log2(period), labelled via convert_seconds_to_labels.
        Otherwise it is log2(frequency) in Hz.
    ax : matplotlib Axes, optional
        Axes to plot into. A new figure is created when omitted.
    title : str, default "Wavelet Coherence"
        Title for the graph.
    block : [int, int], optional
        Limit the x axis to t[block[0]] .. t[int(block[1] / dt)].
    mask : boolean array, optional
        Mask applied to WCT before plotting.
    cax : Axes, optional
        Axes to draw the colorbar into.
    dates : sequence of datetime.date, optional
        Labels ('%m-%d') for every 6th unique whole value of t, plus the last.
        Assumes t is in days and aligned with dates — confirm with caller.

    Returns
    -------
    tuple : (fig, [WCT, t, y_vals])
        fig is the created Figure, or None when ax was supplied.
        WCT may be masked; y_vals are the plotted log2 y positions.
    """
    dt = np.mean(np.diff(t))

    # y axis is log2 of the period (1 / freq) or of the frequency
    y_vals = np.log2(1 / freq) if plot_period else np.log2(freq)

    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = None

    if mask is not None:
        WCT = np.ma.array(WCT, mask=mask)

    extent_corr = [t.min(), t.max(), min(y_vals), max(y_vals)]
    im = NonUniformImage(ax, interpolation='bilinear', extent=extent_corr,cmap='jet')
    if plot_period:
        im.set_data(t, y_vals, WCT)
    else:
        # Flip rows so y increases (freq is presumably ordered high -> low)
        im.set_data(t, y_vals[::-1], WCT[::-1, :])
    im.set_clim(0, 1)
    ax.add_artist(im)

    # Cone of influence: periods beyond it are subject to edge effects
    if coi is not None:
        coi_y = np.log2(coi) if plot_period else np.log2(1 / coi)
        ax.plot(t, coi_y,
                'w--',
                linewidth=2,
                )

    if sig is not None:
        finite_sig = sig[~np.isnan(sig)]
        # Contour levels must increase, so only draw when some value is below 1
        if finite_sig.size and finite_sig.min() < 1:
            min_sig = finite_sig.min()
            ax.contourf(t, y_vals,
                sig,
                levels=[min_sig, 1],
                colors='grey',
                extent=extent_corr,
                alpha=0.5,
                )
            ax.contour(t, y_vals,
                       sig,
                       [min_sig, 1],
                       colors='k',
                       linewidths=1,
                       extent=extent_corr,
                       alpha=1,
                       )

    ax.set_ylim(min(y_vals), max(y_vals))
    if block:
        ax.set_xlim(t[block[0]], t[int(block[1] * 1 / dt)])
    else:
        ax.set_xlim(t.min(), t.max())

    y_ticks = np.linspace(min(y_vals), max(y_vals), 8)
    if plot_period:
        y_labels = convert_seconds_to_labels(np.exp2(y_ticks))
        ax.set_ylabel("Period")
    else:
        tick_freqs = np.round(np.exp2(y_ticks), 3)
        # Filter ticks and labels together so their counts always match
        keep = tick_freqs < 60
        y_ticks = y_ticks[keep]
        y_labels = [str(x) for x in tick_freqs[keep]]
        ax.set_ylabel("Frequency (Hz)")
    ax.set_yticks(y_ticks)
    ax.set_yticklabels(y_labels)
    ax.set_title(title)
    ax.set_xlabel("Date")
    if dates is not None:
        ax.set_xticks(np.append(np.unique(np.floor(t))[:-3:6],np.unique(np.floor(t))[-1]))
        ax.set_xticklabels(
            [date.strftime("%m-%d") for date in dates][:-3:6]+\
            [[date.strftime("%m-%d") for date in dates][-1]],
            rotation=90,
            )

    if cax is not None:
        plt.colorbar(im, cax=cax, use_gridspec=False)
    elif fig is not None:
        fig.colorbar(im)
    else:
        plt.colorbar(im, ax=ax, use_gridspec=True)

    return fig, [WCT, t, y_vals]

In [None]:
def plot_arrows(ax, wcohere_pvals, aWCT=None, u=None, v=None, magnitude=None, quiv_x=5, quiv_y=24, all_arrows=False,sig=None):
    """
    Plot phase arrows on a wavelet coherence plot using results from plot_wcohere.

    Source:
    https://seankmartin.github.io/Claustrum_Experiment/html/bvmpc/lfp_coherence.html#bvmpc.lfp_coherence.plot_wave_coherence

    Parameters
    ----------
    ax : matplotlib Axes
        Axes to draw on; only ``ax.quiver`` is called.
    wcohere_pvals : sequence
        ``[WCT, t, y_vals]`` as returned by plot_wcohere. WCT is indexed as
        (y, t); t and y_vals are 1D.
    aWCT : 2D numpy array, optional
        Phase angles (radians) with the same shape as WCT. Arrows point along
        (cos, sin) of the angle. Takes precedence over u/v.
    u : 2D numpy array, optional
        Arrow x components (the cos of the angle); used only when aWCT is None.
    v : 2D numpy array, optional
        Arrow y components (the sin of the angle); used only when aWCT is None.
    magnitude : 2D numpy array, optional
        When given (and all_arrows is False), arrows are drawn only where the
        magnitude exceeds the mean + one standard deviation of its row.
    quiv_x : float
        Horizontal spacing between arrows, in the same units as t.
    quiv_y : float
        Approximate number of arrows along the y axis.
    all_arrows : bool
        If True, draw arrows on the whole subsampled grid. Otherwise draw them
        only at selected points, using the first available criterion:
        high magnitude, then sig > 1, then WCT > 0.5.
    sig : 2D numpy array, optional
        Significance ratio, used for selection when magnitude is None.

    Returns
    -------
    ax

    Raises
    ------
    ValueError
        If neither aWCT nor both u and v are given.
    """
    WCT, t, y_vals = wcohere_pvals

    if aWCT is not None:
        # Zero angle points right; use 0.5 * pi - aWCT to make zero point up
        u, v = np.cos(aWCT), np.sin(aWCT)
    elif u is None or v is None:
        raise ValueError("Must pass aWCT or [u, v]")

    dt = np.mean(np.diff(t))

    # Subsampling steps; clamp to >= 1 so a spacing smaller than the sample
    # interval does not produce an invalid zero slice step
    x_res = max(1, int(1 / dt * quiv_x))
    y_res = max(1, int(np.ceil(len(y_vals) / quiv_y)))

    t_sub = t[::x_res]
    y_sub = y_vals[::y_res]
    u_sub = u[::y_res, ::x_res]
    v_sub = v[::y_res, ::x_res]

    quiver_style = dict(units='height', angles='uv', pivot='mid', linewidth=1,
                        edgecolor='k', scale=30, headwidth=10, headlength=10,
                        headaxislength=5, minshaft=2)

    if all_arrows:
        ax.quiver(t_sub, y_sub, u_sub, v_sub, **quiver_style)
        return ax

    if magnitude is not None:
        # Per-row (frequency) threshold: mean + one standard deviation
        threshold = (magnitude.mean(axis=1, keepdims=True)
                     + magnitude.std(axis=1, keepdims=True))
        selected = magnitude[::y_res, ::x_res] > threshold[::y_res]
    elif sig is not None:
        selected = sig[::y_res, ::x_res] > 1
    else:
        selected = WCT[::y_res, ::x_res] > 0.5

    rows, cols = np.nonzero(selected)
    ax.quiver(t_sub[cols], y_sub[rows],
              u_sub[rows, cols], v_sub[rows, cols], **quiver_style)

    return ax

In [None]:
def plot_heatmap(df, title='Heatmap',coi=None,sig=None):
    """
    Show *df* as a seaborn heatmap, with columns on the y axis and the index on the x axis.

    Args:
        df (pd.DataFrame): Values drawn with fixed colour limits [0, 1]. An index
            with ``strftime`` is turned into 'YYYY-mm-dd HH:MM:SS' labels.
            Column labels are expected to be periods in seconds and are shown
            in minutes; otherwise the y label falls back to 'Time of Day'.
        title (str): Plot title.
        coi (pd.Series or pd.DataFrame, optional): Cone of influence, drawn as a
            white dashed line on the heatmap axes.
            NOTE(review): its values are plotted in raw units against heatmap
            cell positions — confirm the alignment is intended.
        sig (2D array-like, optional): Significance ratio with df's shape.
            Cells where it exceeds 1 are outlined in black.
    """
    df = df.copy()
    if coi is not None:
        coi = coi.copy()

    plt.figure(figsize=(10, 6))

    # Readable labels for datetime-like indexes; other indexes are kept as-is
    try:
        df.index = df.index.strftime('%Y-%m-%d %H:%M:%S')
    except AttributeError:
        pass

    # Transpose so that columns (periods) run along the y axis
    ax = sns.heatmap(df.T, cmap='jet',
                     vmin=0,vmax=1)

    ax.set_ylabel('Period (minutes)')

    # Show y-tick labels (seconds) in whole minutes; non-numeric labels mean
    # the columns are not periods
    try:
        yticklabels = [round(float(label.get_text()) / 60, 0) for label in ax.get_yticklabels()]
        ax.set_yticklabels(yticklabels)
    except ValueError:
        ax.set_ylabel('Time of Day')

    # Put the smallest period at the bottom
    ax.invert_yaxis()

    if coi is not None:
        try:
            coi.index = coi.index.strftime('%Y-%m-%d %H:%M:%S')
        except AttributeError:
            pass
        # Use the coi argument (previously an undefined global `coi_df`)
        coi.plot(color='w',style='--',ax=ax)
    if sig is not None:
        ax.contour(sig.T>1,
                   colors='black',
                   alpha=1,
                   origin='lower'
                   )

    plt.title(title)

    plt.show()

In [None]:
def plot_corr_heatmap(df, title='Heatmap',coi=None,sig=None):
    """
    Show rolling correlations as a heatmap with a 5-class diverging colormap.

    Args:
        df (pd.DataFrame): Correlation values drawn with colour limits [-1, 1].
            An index with ``strftime`` is turned into 'YYYY-mm-dd' labels on the
            x axis. Column labels are expected to be window lengths in seconds
            and are shown in hours; otherwise the y label falls back to
            'Time of Day'.
        title (str): Plot title.
        coi, sig: Accepted for signature parity with plot_heatmap; not used.
    """
    df = df.copy()

    plt.figure(figsize=(10, 6))
    # Five equal classes; after the reversal -1 maps to blue and +1 to red
    colors = ['#ca0020', '#f4a582', '#f7f7f7', '#92c5de', '#0571b0'][::-1]
    cmap = matplotlib.colors.ListedColormap(colors)

    # Readable labels for datetime-like indexes; other indexes are kept as-is
    try:
        df.index = df.index.strftime('%Y-%m-%d')
    except AttributeError:
        pass

    # Transpose so that columns (window lengths) run along the y axis
    ax = sns.heatmap(df.T, cmap=cmap,
                     vmin=-1,vmax=1)

    ax.set_ylabel('Rolling Correlation Window Length (hours)')

    # Show y-tick labels (seconds) in hours, one decimal; non-numeric labels
    # mean the columns are not window lengths
    try:
        yticklabels = [round(float(label.get_text()) / 60/60, 1) for label in ax.get_yticklabels()]
        ax.set_yticklabels(yticklabels)
    except ValueError:
        ax.set_ylabel('Time of Day')

    # Put the shortest window at the bottom
    ax.invert_yaxis()

    plt.title(title)

    plt.show()

In [None]:
def plot_WCT_w_timeseries(df, raw_df=None, waves=None, title="Wavelet Coherence", coi=None, sig=None, plot_raw_data=True):
    """
    Plot wavelet coherence as a time/period image, optionally with raw series.

    Parameters:
    df (pd.DataFrame): Coherence values on a fixed [0, 1] colour scale. The
        index is the (datetime-like) time axis; columns are periods in seconds,
        shown on a log2 y-axis. The frame is sorted by index before plotting.
    raw_df (pd.DataFrame): Raw series to plot; required when ``waves`` is given
        and ``plot_raw_data`` is True.
    waves (list of str): Columns of ``raw_df`` to draw in the second panel. The
        first uses the left axis, the rest a twin right axis (orange). Anything
        from the first '[' on (e.g. a unit ' [mW]') is dropped from the legend
        label and underscores become spaces.
    title (str): Title of the coherence panel.
    coi (pd.DataFrame or pd.Series): Cone of influence in seconds, indexed by
        datetime; drawn as a white dashed line. Periods above it are subject
        to edge effects.
    sig (pd.DataFrame): Significance values with the same (periods, times)
        shape as ``df.values.T``. The band from its minimum up to 1 is shaded
        grey and outlined in black.
    plot_raw_data (bool): When True and ``waves`` is given, use a three-panel
        layout with the raw series in the second panel.

    Raises:
    ValueError: if raw series are requested but ``raw_df`` is None.
    """
    df = df.copy()
    df.sort_index(inplace=True)

    WCT = df.values.T
    periods = df.columns.values
    y_vals = np.log2(periods)

    if plot_raw_data and waves is not None:
        if raw_df is None:
            raise ValueError("raw_df is required when waves are given and plot_raw_data is True")
        fig, axs = plt.subplots(3, sharex=True, figsize=(15,15), gridspec_kw={'hspace': 0})
        ax2 = axs[1].twinx()
        for i, wave in enumerate(waves):
            # Split just before '[' so 'SIF_B [mW]' -> 'SIF_B '; keep the full
            # name when the label does not contain exactly one '['.
            parts = re.split(r'(?=\[)', wave)
            wave_name = parts[0] if len(parts) == 2 else wave
            label = wave_name.replace("_", " ")
            if i == 0:
                axs[1].plot(raw_df.index, raw_df[wave], label=label)
                axs[1].legend(loc="upper left")
            else:
                ax2.plot(raw_df.index, raw_df[wave], label=label, color='tab:orange')
                ax2.legend(loc="upper right")
    else:
        fig, axs = plt.subplots(1, figsize=(15,5))
        axs = [axs]

    # Time axis in matplotlib date numbers, reused by every artist below.
    x_vals = mdates.date2num(df.index)
    x_min = mdates.date2num(df.index.min())
    x_max = mdates.date2num(df.index.max())
    y_min, y_max = min(y_vals), max(y_vals)
    extent_corr = [x_min, x_max, y_min, y_max]

    # Coherence image on the (time, log2 period) grid.
    im = NonUniformImage(axs[0], interpolation='bilinear', extent=extent_corr, cmap='jet')
    im.set_data(x_vals, y_vals, WCT)
    im.set_clim(0, 1)
    axs[0].add_artist(im)

    # Cone of influence as a dashed line in log2(period) space.
    if coi is not None:
        axs[0].plot(mdates.date2num(coi.index), np.log2(coi.values.flatten()),
                    'w--',
                    linewidth=2,
                    )

    if sig is not None:
        valid_sig = sig.values[~np.isnan(sig.values)]
        # Contour levels must be strictly increasing, so only draw when there
        # is at least one non-NaN value below the threshold of 1.
        if valid_sig.size and np.min(valid_sig) < 1:
            min_sig = np.min(valid_sig)
            axs[0].contourf(x_vals, y_vals,
                            sig.values,
                            levels=[min_sig, 1],
                            colors='grey',
                            extent=extent_corr,
                            alpha=0.5,
                            )
            axs[0].contour(x_vals, y_vals,
                           sig.values,
                           [min_sig, 1],
                           colors='k',
                           linewidths=1,
                           extent=extent_corr,
                           alpha=1,
                           )

    axs[0].set_ylim(y_min, y_max)
    axs[0].set_xlim(x_min, x_max)

    # Eight evenly spaced ticks in log2 space, labelled in human-readable time.
    y_ticks = np.linspace(y_min, y_max, 8)
    axs[0].set_yticks(y_ticks)
    axs[0].set_yticklabels(convert_seconds_to_labels(np.exp2(y_ticks)))

    axs[0].set_title(title)

    # One date tick per day, then let matplotlib rotate the labels.
    axs[0].xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    axs[0].xaxis.set_major_locator(mdates.DayLocator(interval=1))
    fig.autofmt_xdate()

    plt.show()


In [None]:
def plot_phase_angle_histogram(phase_data: pd.DataFrame, time_lag_data: pd.DataFrame) -> None:
    """
    Plot one histogram of wavelet-coherence phase angle (in degrees) per row.

    ``phase_data`` is converted from radians to degrees and transposed, so each
    ROW of ``phase_data`` becomes one subplot. Subplots are laid out seven per
    row with shared axes. Each subplot is annotated, by position, with the
    matching row of ``time_lag_data``: "mean +- std" above the axes and the
    row's index value (rounded to 2 decimals) just inside the top.

    Parameters:
    phase_data (pd.DataFrame): Phase angles in radians.
    time_lag_data (pd.DataFrame): Per-subplot summary with columns 'mean_angle'
        and 'std_angle'; its index supplies the period labels.

    Returns:
    None. Nothing is drawn when ``phase_data`` has no rows.
    """
    n_cols = 7

    # Convert once; each column of the transposed frame becomes a subplot.
    phase_deg = phase_data.T * 180 / np.pi
    n_plots = len(phase_deg.columns)
    if n_plots == 0:
        # A (0, 7) layout is invalid for pandas, so there is nothing to draw.
        return

    num_rows = int(np.ceil(n_plots / n_cols))

    axes = phase_deg.plot(kind='hist', subplots=True,
                          layout=(num_rows, n_cols),
                          figsize=(10, (10/7)*num_rows),
                          legend=False,
                          sharey=True,
                          sharex=True,
                          title='Wavelet Coherence Phase Angle \n Period in Hours'
                          )

    # "mean +- std" annotation per row of time_lag_data.
    angle_list = [f"{round(row['mean_angle'], 1)} +- {round(row['std_angle'], 1)}" for _, row in time_lag_data.iterrows()]
    period_list = np.round(time_lag_data.index.values, 2)

    # zip stops at the shortest input, so surplus (empty) layout cells get no text.
    for ax, angle_label, period in zip(axes.flatten(), angle_list, period_list):
        ax.text(.0, 1.02, str(angle_label), transform=ax.transAxes, size=9)
        ax.text(.0, .87, str(period), transform=ax.transAxes, size=9)


In [None]:
def plot_WCT_phase_angle_over_time(freq, coi, sig, aWCT, title='Average Wavelet Coherence Phase Angle over Season \n PRI_s and SIF_B '):
    """
    Plot the circular mean (+/- circular std) of significant coherence phase
    angles at each time step.

    Only cells with ``sig >= 1`` and a period (1/freq) no larger than the cone
    of influence at that time step are kept. Time steps with no kept cells are
    dropped. The mean is shown in degrees wrapped to [0, 360), with reference
    lines at 0/360 (in-phase) and 180 (anti-phase).

    Parameters:
    freq (numpy.ndarray): Frequencies, one per row of ``sig``/``aWCT``.
    coi (numpy.ndarray): Cone of influence (same units as 1/freq), one value
        per time step (column).
    sig (numpy.ndarray): Significance values; cells >= 1 are kept.
    aWCT (numpy.ndarray): Phase angles in radians, shaped like ``sig``.
    title (str): Plot title (defaults to the original PRI_s / SIF_B title).

    Returns:
    None

    Sources:
    - https://sites.google.com/a/glaciology.net/grinsted/wavelet-coherence/faq
    - Application of the cross wavelet transform and wavelet coherence to geophysical time series
    """
    # True where period (1/freq) exceeds the COI at that time step, i.e. the
    # cell is subject to edge effects.
    incoi = np.outer(1/freq, 1./coi) > 1

    # Keep only significant phases outside the edge-affected region.
    sig_phase = np.where(sig >= 1, aWCT, np.nan)
    sig_phase_in_coi = np.where(~incoi, sig_phase, np.nan)

    # One column per time step; drop time steps with nothing left.
    phase_df = pd.DataFrame(sig_phase_in_coi)
    phase_df.dropna(how='all', axis=1, inplace=True)

    # Circular statistics per column: resultant vector (X, Y), its length R,
    # mean direction atan2(Y, X) and circular std sqrt(-2 ln(R/n)).
    X = np.cos(phase_df).sum(axis=0)
    Y = np.sin(phase_df).sum(axis=0)
    R = np.sqrt(X**2 + Y**2)
    count = phase_df.count(axis=0)
    circular_mean = np.arctan2(Y, X)*180/np.pi
    # Rounding can push R/n marginally above 1, which would make the log
    # positive and the square root NaN; cap the ratio at 1.
    mean_resultant_length = (R / count).clip(upper=1.0)
    circular_std = np.sqrt(-2 * np.log(mean_resultant_length))*180/np.pi

    # Wrap the mean from (-180, 180] to [0, 360).
    circular_mean[circular_mean < 0] += 360

    x = circular_mean.index
    y1 = circular_mean - circular_std
    y2 = circular_mean + circular_std

    circular_mean.plot(label='Mean')
    plt.fill_between(x, y1, y2, alpha=0.3, label='Std Dev')

    plt.axhline(180, color='red', linestyle='--', label='Anti-phase')
    plt.axhline(0, color='green', linestyle='--', label='In-phase')
    plt.axhline(360, color='green', linestyle='--')

    plt.title(title)
    plt.legend(loc='best')
    plt.ylabel('Degrees')
    plt.ylim((-5,365))



###Animation

In [None]:
def animate_wavelength_range(df, min_nm, max_nm, frames, title=None, xlabel='nanometers', ylabel=None, linestyle='-', save=False, filename='animation.mp4'):
    """
    Animate the spectra in ``df`` over a wavelength window, one row per frame.

    Parameters:
    df (pd.DataFrame): One spectrum per row; column labels are wavelengths.
    min_nm, max_nm (int): The window runs from the first column whose label
        contains '<min_nm>.' to the last column whose label contains
        '<max_nm>.'; the x-axis is limited to [min_nm, max_nm].
    frames (int): Number of frames (capped at the number of rows). Rows are
        sampled every ``len(df) // frames`` rows, and at least every row.
    title (str): Fixed title. If falsy, each frame is titled
        '<df.index.name> <row label>'.
    xlabel, ylabel (str): Axis labels (ylabel only set when given).
    linestyle (str): Matplotlib line style for the spectrum.
    save (bool): If True, write the animation to ``filename`` and return None.
    filename (str): Output path used when ``save`` is True.

    Returns:
    IPython.display.HTML with an HTML5 video when ``save`` is False, else None.
    """
    # NOTE(review): matching is by substring, so e.g. '500.' also matches
    # '1500.x'; confirm the column range of the spectra this is used with.
    area_of_interest_df = df.T.loc[
        df.T.filter(like=str(min_nm)+'.', axis=0).index[0]:  # starting wl
        df.T.filter(like=str(max_nm)+'.', axis=0).index[-1]  # ending wl
    ].T

    fig, ax = plt.subplots()

    # Never step by 0 rows (which would freeze on the first spectrum) and
    # never request more frames than there are rows to show.
    n_rows = len(area_of_interest_df)
    frames = min(frames, n_rows)
    scale = max(n_rows // frames, 1)

    line, = ax.plot(area_of_interest_df.columns, area_of_interest_df.iloc[0], linestyle=linestyle,)

    # y-range from the selected window itself: mean of the per-wavelength
    # minima up to the mean of the per-wavelength 75th percentiles.
    ax.set_ylim(area_of_interest_df.min().mean(), area_of_interest_df.quantile(.75).mean())
    ax.set_xlim(min_nm, max_nm)
    if ylabel:
        ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    def update(i):
        """Show row ``scale * i`` and update the title for frame ``i``."""
        row = scale * i
        line.set_ydata(area_of_interest_df.iloc[row])
        if title:
            ax.set_title(title)
        else:
            ax.set_title(f"{df.index.name} {area_of_interest_df.index[row]}")
        return (line,)

    ani = animation.FuncAnimation(fig, update, frames=frames, interval=100, blit=False)

    if save:
        ani.save(filename)
    else:
        return HTML(ani.to_html5_video())
