# CTD on benthic lander - data processing notebook

Disclaimer: this code is still in development and can still contain multiple errors.

The following notebook was developed by the Suspended Material and Seabed Monitoring and Modelling group (SUMO) from the Institute of Natural Sciences, Belgium. It aims to provide a standard protocol for processing data collected in situ with an SBE 19plus V2 CTD instrument mounted on a benthic lander. To use this notebook, import the raw data, run the code cells and inspect the results. The graphs are then automatically saved as PNG files and the processed data as CSV files.

In this notebook, the CTD data processing has been divided into two stages: 
1. *Data import & flagging*: This stage converts the raw data to CSV, which includes the import and pre-processing of the data (i.e. date conversion and data flagging). At the end of this stage, the user can save the full dataset with a flagging system and dates as datetime. 
2. *Statistics computation & visualization*: This stage computes statistical parameters on a clean dataset (after filtering out data above a certain flag value), such as mean particle diameter, D10, D50, D90 and distribution characteristics, and displays the outputs as graphs. The clean dataset as well as the graphs can be saved. 

**Important!** Before starting the processing, make sure that all the necessary packages are installed on your computer, then run the cell below to import everything that is required. To install the packages, run the following line in the prompt (`tkinter` ships with most Python installations and the `yaml` module is provided by the `pyyaml` package):
- pip install pandas numpy matplotlib seaborn scipy pyyaml ipywidgets
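
Before running the import cell, it can help to check which modules are actually available. The snippet below is a minimal sketch using only the standard library; the helper name `find_missing_modules` and the list `REQUIRED_MODULES` are illustrative and not part of the notebook.

```python
import importlib.util

# Import names (not pip names) used by this notebook; "yaml" is provided by PyYAML.
REQUIRED_MODULES = ["tkinter", "pandas", "numpy", "matplotlib",
                    "seaborn", "scipy", "yaml", "ipywidgets"]

def find_missing_modules(module_names):
    """Return the names from module_names that cannot be located for import."""
    return [name for name in module_names if importlib.util.find_spec(name) is None]

missing = find_missing_modules(REQUIRED_MODULES)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules were found.")
```

Any name reported as missing needs to be installed before the import cell will run.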

In [1]:
                                                    ## Cell 01 ##
import os
import ipywidgets as widgets
from IPython.display import display
import shutil
import tkinter as tk
from tkinter import filedialog, messagebox
from tkinter.filedialog import asksaveasfilename
import pandas as pd
from datetime import datetime, timedelta
import numpy as np
from scipy.signal import find_peaks, savgol_filter
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
import yaml 
import sys

# Now you can import the module
import scripts.data_processing as data_processing
                                                        ###

## Stage 1: Data pre-processing

#### Stage 1 - Step 1 : Enter metadata

In this cell, the user is asked to enter the deployment identification code, the name of the sampling location and the coordinates in decimal degrees (XX.XXXXX). Press Enter after each input.
**Important!** Avoid using "/" in the deployment identification code, as it might create issues when saving the data.
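
The inputs are stored exactly as typed, so a quick sanity check can catch a stray "/" or a non-numeric coordinate early. The following is a minimal sketch; `sanitize_code` and `parse_coordinate` are hypothetical helpers, not functions used by this notebook.

```python
def sanitize_code(code):
    """Strip whitespace and replace path separators that would break file names."""
    return code.strip().replace("/", "-").replace("\\", "-")

def parse_coordinate(text, limit):
    """Parse a decimal-degree coordinate; return None if it is not a number within +/- limit."""
    try:
        value = float(text)
    except ValueError:
        return None
    return value if -limit <= value <= limit else None

print(sanitize_code(" T004/A "))          # path separator replaced by "-"
print(parse_coordinate("51.35912", 90))   # valid latitude
print(parse_coordinate("x", 90))          # placeholder input is rejected
```

A latitude would be checked against a limit of 90 and a longitude against 180.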

In [6]:
                                                    ## Cell 02 ##
deployment_code = input("Enter the deployment identification code: ")
location = input("Enter the name of the sampling location: ")
latitude = input("Enter the latitude of the data: ")
longitude = input("Enter the longitude of the data: ")
                                                        ###

Enter the deployment identification code: T004
Enter the name of the sampling location: COD-T-E
Enter the latitude of the data: x
Enter the longitude of the data: x


#### Stage 1 - Step 2 : Data import & formatting

Upon running this cell, a new window should open automatically and invite the user to select the CSV file to be processed. Be aware that this window might open in the background, hidden behind the current window. Once imported, the data are assigned column names, the deployment identification code and the sampling location (as entered by the user in the cell above). In addition, the year, month, day, hour, minute and second columns are converted to jday and datetime values. A new column containing a flag value (set to zero by default at this step) is added at the end of the dataset. 

The dataset will be displayed upon running the cell, enabling the user to check that columns and values were assigned correctly and that the calculated date and time of the data correspond to the actual sampling time.
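
The date conversion described above can be illustrated on a toy table. This is a minimal sketch with made-up values; it relies on `pd.to_datetime` accepting a dataframe whose columns are named `year`, `month`, `day`, `hour`, `minute` and `second`, which is one possible way to build the datetime and not necessarily the one used in the cell.

```python
import pandas as pd

# Toy rows with separate date and time components (values are made up).
df = pd.DataFrame({
    "year": [2024, 2024],
    "month": [1, 3],
    "day": [1, 15],
    "hour": [12, 6],
    "minute": [0, 30],
    "second": [0, 0],
})

# pandas assembles a datetime from columns named year, month, day, hour, minute, second.
df["datetime"] = pd.to_datetime(df[["year", "month", "day", "hour", "minute", "second"]])

# Decimal day of year: 1 January at 12:00 gives 1.5.
df["jday"] = (
    df["datetime"].dt.dayofyear
    + df["datetime"].dt.hour / 24
    + df["datetime"].dt.minute / 1440
    + df["datetime"].dt.second / 86400
)

# Every row starts with the "no quality control" flag.
df["flag"] = 0
print(df[["datetime", "jday", "flag"]])
```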

In [7]:
                                                    ## Cell 03 ##
# Open a window to select input file
root = tk.Tk()
root.withdraw()

# Allow the user to select a folder
base_path = filedialog.askdirectory(title="Select a directory")

if base_path:  # Check if a directory was selected

    if deployment_code == '':
        print("\033[1;31m*** You have not provided a Deployment Code, please enter Deployment Code ***\033[0m")
    else:
        print(f"Deployment Code: {deployment_code}\n")

        # Oxygen sensor alignment is fixed at 2
        oxygen_alignment = 2
        print(f"Oxygen sensor alignment default value: {oxygen_alignment}")

        # Set the cruise working directory
        pyear = '20' + deployment_code[2:4]
        pvessel = deployment_code[0:2]

        filepath = os.path.normpath(os.path.join(base_path))
        print(f"Working directory is: {filepath}")

        backup_server_path = os.path.normpath(
            os.path.join('Z:/2.1 Oceanographic', deployment_code, f"{deployment_code}_processed_CTD_data_BACKUP")
        )
        print(f"Backup directory is: {backup_server_path}")

        # Define directory structure
        directories = {'raw': 'raw_files', 'logs': 'logsheets', 'sbe35_raw': 'SBE35', 'cals': 'cal_samples', 'psa': 'psa', 'out': 'output'}
        directories_out = {'bottle': 'bottle', 'screen_2Hz': 'screen_2Hz', 'all_2Hz': 'all_2Hz', 'plots': 'plots'}

        def dir_path(path, name):
            d = os.path.join(path, name)
            if not os.path.exists(d):
                os.mkdir(d)
            return d

        # Create processing directories
        print("Processing directory folder locations:")
        for k, i in directories.items():
            directories[k] = dir_path(filepath, i)
        for k, i in directories_out.items():
            directories_out[k] = dir_path(directories.get("out"), i)
            print(directories_out[k])

        # Assign specific directory paths
        raw = directories.get("raw", "")
        logs = directories.get("logs", "")
        psa = directories.get("psa", "")
        sbe35_raw = directories.get("sbe35_raw", "")
        bottle = directories_out.get("bottle", "")
        screen_2Hz = directories_out.get("screen_2Hz", "")
        all_2Hz = directories_out.get("all_2Hz", "")
        plots = directories_out.get("plots", "")

        # Check if files exist in the raw_files directory
        if len(os.listdir(raw)) > 0:
            count, countn = 0, 0
            for filename in os.listdir(raw):
                if deployment_code in filename.upper():
                    count += 1
                else:
                    countn += 1

            if countn != 0:
                print(
                    '\n\033[1;31m*** Check filenames. Sea-Bird files present in the raw_files directory do not follow the '
                    'filename convention <DEPLOYMENT_CODE>_CTD<NUMBER> ***\n\033[0m'
                )
            else:
                print(f'\nFiles present in raw_files directory: {count}')
        else:
            print(
                '\n\033[1;31m*** No Sea-Bird files present in the raw_files directory. If running the notebook for the first time '
                'for this deployment, before proceeding copy across the raw SBE files to the "raw_files" folder in the working directory. ***\n\033[0m'
            )

        # Get HEX filenames
        filelist = os.listdir(raw)
        hexfilelist = [item.split('.')[0].upper() for item in filelist if item.endswith(".hex")]
        print(f'\tNumber of HEX files available in deployment folder: {len(hexfilelist)}')

        # Extract metadata from HDR files
        print("\nExtracting cast metadata from the header information for each cast for reference.")
        df_NMEA = data_processing.get_NMEA_from_header(raw, 'hdr')

        # Check all fields populated
        print(f'\tNumber of HDR files available in deployment folder: {len(df_NMEA)}')
        df_missingNMEA = df_NMEA.isnull().sum()

        for item in ['Lat', 'Long', 'Upload Time', 'UTC Time']:
            if df_missingNMEA[item] != 0:
                counts = df_missingNMEA[item]
                print(
                    f"\033[1;31mACTION *** {item} missing in {counts} HDR files *** Ensure {item} entered into logsheet from paper logs for:\033[0m"
                )
                print(df_NMEA[df_NMEA[item].isnull()]['CTD number'].tolist())
            else:
                print(f"\t{item} present in all HDR files")
                
else:
    print('\n\033[1;31m*** No directory selected. ***\n\033[0m')
                                                        ###

Deployment Code: T004

Oxygen sensor alignment default value: 2
Working directory is: D:\Tripod_COD_004_SBE19Plus
Backup directory is: Z:\2.1 Oceanographic\T004\T004_processed_CTD_data_BACKUP
Processing directory folder locations:
D:\Tripod_COD_004_SBE19Plus\output\bottle
D:\Tripod_COD_004_SBE19Plus\output\screen_2Hz
D:\Tripod_COD_004_SBE19Plus\output\all_2Hz
D:\Tripod_COD_004_SBE19Plus\output\plots

*** Check filenames. Sea-Bird files present in the raw_files directory do not follow the filename convention <DEPLOYMENT_CODE>_CTD<NUMBER> ***
	Number of HEX files available in deployment folder: 1

Extracting cast metadata from the header information for each cast for reference.


KeyError: 1

In [None]:
                                                    ## Cell 03 ##
# Open a window to select input file
root = tk.Tk()
root.withdraw()
file_path = filedialog.askopenfilename(title="Select a CSV file", filetypes=[("CSV files", "*.csv")])

if file_path:  
    # Create the output folder only once a file has been selected; an existing folder is deleted and recreated
    directory_path = os.path.dirname(file_path)
    output_directory = f"{directory_path}/CTD-SBE19plusv2-{deployment_code}-{location}-processed"
    if os.path.exists(output_directory):
        shutil.rmtree(output_directory)
        print(f'Existing output folder removed at: {output_directory}')
    os.makedirs(output_directory)
    print(f'Output folder created at: {output_directory}')

    try:
        data = pd.read_csv(file_path, header=None, sep=" ")
        print("File successfully loaded!") 
        # Add column names
        data.columns = ["1.21","1.6","1.89","2.23","2.63","3.11","3.67","4.33","5.11","6.03","7.11","8.39","9.90","11.7","13.8","16.3","19.2","22.7","26.7","31.6","37.2","43.9","51.9","61.2","72.2","85.2","101","119","140","165","195","230","273","324","386","459","laser_transmission_sensor_mW","supply_voltage_V","external_input_1_V","laser_reference_sensor_mW","depth_in_m","temperature_C","year","month","day","hour","minute","second","external_input_2_V","mean_diameter_um","total_volume_concentration_ppm","relative_humidity_%","accelerometer_x","accelerometer_y","accelerometer_z","raw_pressure_most_significant","raw_pressure_least_significant","ambient_light_counts","external_analog_input_3_V","computed_optical_transmission","beam_attenuation_m"]
        # Add deployment identification number, sampling location name, latitude and longitude
        data['deployment'] = deployment_code
        data['location'] = location
        data['latitude'] = latitude
        data['longitude'] = longitude
        for col in data.columns.difference(['deployment', 'location']):
            if pd.api.types.is_numeric_dtype(data[col]) or data[col].dtype == object:
                data[col] = data[col].astype(str).str.replace(r"[^\d.-]", "", regex=True)  # Retain numbers, decimal points, and negatives
                data[col] = pd.to_numeric(data[col], errors='coerce')  # Convert to numeric, invalid entries to NaN
        # Calculate julian day & datetime
        data['year'] = data['year'].astype(str).str.zfill(2)
        data['month'] = data['month'].astype(str).str.zfill(2)
        data['day'] = data['day'].astype(str).str.zfill(2)
        data['hour'] = data['hour'].astype(str).str.zfill(2)
        data['minute'] = data['minute'].astype(str).str.zfill(2)
        data['second'] = data['second'].astype(str).str.zfill(2)
        data['datetime'] = data['year'].astype(str) +"-"+ data['month'].astype(str) +"-"+ data['day'].astype(str)+" "+data['hour'].astype(str) +":"+ data['minute'].astype(str) +":"+ data['second'].astype(str)
        data['datetime'] = pd.to_datetime(data['datetime'], format='%Y-%m-%d %H:%M:%S')
        data['jday'] = data['datetime'].dt.dayofyear + \
               (data['datetime'].dt.hour / 24) + \
               (data['datetime'].dt.minute / 1440) + \
               (data['datetime'].dt.second / 86400)
        data['flag'] = 0
        # Compute sampling start and stop datetime
        sampling_start = data['datetime'].min()
        sampling_stop = data['datetime'].max()
        # Display output
        print(f'Data were collected at {location} between {sampling_start} and {sampling_stop} during deployment {deployment_code}.')
        display(data)
    except Exception as e:
        print(f"Error reading the file: {e}")
else:
    print("No file selected")
                                                        ###

## Complete CTD processing steps

### The cell below performs the following tasks:
- load data from archived SDB processed to 2 Hz CSV file into a dataframe,
- if oxygen sensors are deployed, advance the oxygen voltage channel(s) by the number of seconds determined from the interactive plots above (or the default value),
- at the end of the up-cast, flag the 'cast' where pressure is less than 2 dbar (~height of the rig) to indicate rig breaking surface (flag = 'S'),
- load the down-cast start time (as seconds elapsed since data acquisition started) and flag the 'cast' to indicate the surface soak (flag = 'SS'),
- drop channels labeled 'NotInUse','pumps','Start',
- set temperatures outside range -5 to 40 degrees C as NaN,
- recalculate practical salinities, potential temperature, sigma-theta and sound velocity (EOS-80 toolbox),
- recalculate the oxygen concentration in umol/L and saturation using algorithms defined in SBE Application notes 64 and 64-3,
- save data to CSV file.
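
Two of the steps listed above (flagging the rig breaking the surface and blanking out-of-range temperatures) can be sketched on a toy cast. The column names `pressure_dbar`, `temperature_C` and `cast_flag` are assumptions made for this illustration, and treating every sample after the deepest point as the up-cast is a simplification.

```python
import numpy as np
import pandas as pd

# Toy cast: down to 20 dbar and back up (values are made up).
cast = pd.DataFrame({
    "pressure_dbar": [0.5, 3.0, 10.0, 20.0, 10.0, 3.0, 1.5, 0.8],
    "temperature_C": [12.1, 12.0, 11.8, 99.0, 11.9, 12.0, 12.1, -9.0],
})
cast["cast_flag"] = ""

# Up-cast = samples after the deepest point; flag those shallower than 2 dbar with 'S'.
deepest_position = cast["pressure_dbar"].to_numpy().argmax()
on_upcast = np.arange(len(cast)) > deepest_position
cast.loc[(cast["pressure_dbar"] < 2) & on_upcast, "cast_flag"] = "S"

# Temperatures outside the range -5 to 40 degrees C are replaced by NaN.
out_of_range = ~cast["temperature_C"].between(-5, 40)
cast.loc[out_of_range, "temperature_C"] = np.nan

print(cast)
```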

#### Stage 1 - Step 3: Flagging of the data

In this cell, flagging operations are performed as follows:
1. A first check is done on the **reference laser**: values at or below 0.02 mW indicate that the laser is likely not working properly. The quality flag is set to 1 if the value is above this threshold and to 3 otherwise (see the meaning of the flags below). If more than 20% of the data are at or below the threshold, a warning message is displayed.  
2. Data out of the water (rows where the beam attenuation is zero or negative) are flagged with the value 4. The **beam attenuation** corresponds to the loss of light intensity when a laser beam passes through water, due to both absorption and scattering. Higher attenuation values typically indicate more turbid water. 
3. The **laser transmission** values are checked. This indicates how much of the laser light has passed through the water without being absorbed or scattered. A value of 100% indicates no light loss through the sample (either particle-free water or air), while lower values indicate scattering and absorption (which is expected in seawater). Values above 100%, however, indicate a sensor malfunction or miscalibration and are flagged with the value 3. If most data (>50%) are above 100%, a warning message is generated. 
4. The **optical transmission** is checked: values should range between 0 and 1. Values between 0.98 and 0.995 reflect extremely clear water, meaning a low signal-to-noise ratio, and should be treated with caution (flag 3). Values of 0.995 and above, as well as values of 0.10 and below (very turbid water), are discarded (flag 4); these two limits also cover values outside the 0 to 1 range. 
5. Finally, a flag of 3 is attributed to **outliers** in the total volume concentration, depth, optical transmission and temperature. Each point is compared with the mean calculated over a moving window of 25 points and is considered an outlier when it deviates from that moving average by more than three times the moving standard deviation.
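
The moving-window outlier test in step 5 can be illustrated on a synthetic signal. This is a minimal sketch with made-up data: a smooth series with one artificial spike, checked with the same window size (25) and threshold (3) as described above.

```python
import numpy as np
import pandas as pd

# Synthetic, gently varying signal with one artificial spike at index 30.
values = pd.Series(10 + 0.1 * np.sin(np.arange(60) / 5.0))
values.iloc[30] = 25.0

window_size = 25  # number of points in the centred moving window
threshold = 3     # number of standard deviations

rolling_mean = values.rolling(window=window_size, center=True).mean()
rolling_std = values.rolling(window=window_size, center=True).std()

# A point is an outlier when it deviates from the moving mean by more than 3 moving std.
is_outlier = (values - rolling_mean).abs() > threshold * rolling_std

print(is_outlier[is_outlier].index.tolist())
```

Note that the first and last 12 points have no complete window, so their rolling statistics are NaN and they are never marked as outliers.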

Once the flagging has been performed, a graph allows the user to visualize the quality of the data. The complete flagged dataset is automatically saved in the output directory created next to the input file at the beginning of this notebook. In stage 3, a cell then allows the user to filter out all data with a quality flag of 4 or higher.

Quality flags are defined following the quality flags standards defined by the NERC Environmental Data Service of the British Oceanographic Data Centre (https://vocab.nerc.ac.uk/collection/L20/current/):
- 0: No quality control
- 1: Good value
- 2: Probably good value
- 3: Probably bad value
- 4: Bad value
- 5: Changed value
- 6: Value below detection
- 7: Value in excess
- 8: Interpolated value
- 9: Missing value
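
Because a higher flag means lower quality, a later check should only ever raise a flag, never lower it. The sketch below shows one vectorized way to do this on a toy table, followed by filtering out bad values; `raise_flag` is a hypothetical helper, while the column names and limits are those used in the flagging steps.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    "beam_attenuation_m": [0.0, 1.2, 0.8, 2.5],
    "computed_optical_transmission": [1.00, 0.99, 0.60, 0.05],
    "flag": [1, 1, 1, 1],
})

def raise_flag(flags, condition, value):
    """Raise flags to at least `value` where `condition` holds; never lower them."""
    return flags.where(~condition, np.maximum(flags, value))

t = df["computed_optical_transmission"]
df["flag"] = raise_flag(df["flag"], df["beam_attenuation_m"] <= 0, 4)     # out of the water
df["flag"] = raise_flag(df["flag"], (t > 0.98) & (t < 0.995), 3)          # very clear water
df["flag"] = raise_flag(df["flag"], (t >= 0.995) | (t <= 0.10), 4)        # discard

# Keep only rows whose flag is below 4 ("bad value").
clean = df[df["flag"] < 4]
print(df["flag"].tolist())
print(clean.index.tolist())
```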

In [None]:
                                                    ## Cell 04 ##
# Step 1: Flagging based on reference_laser
data['flag'] = data.apply(lambda row: max(row['flag'], 1) if row['laser_reference_sensor_mW'] > 0.02 else max(row['flag'], 3), axis=1)
reference_below_thsld = (data['laser_reference_sensor_mW'] <= 0.02).sum() / len(data) * 100
if reference_below_thsld > 20:  # If more than 20% of the data are below 0.02 mW
    messagebox.showwarning("Laser should be checked", 
                           f"Warning: {reference_below_thsld:.0f}% of the data have a laser reference value below 0.02 mW. "
                           "Please check laser.")

# Step 2: Beam attenuation flagging
data['flag'] = data.apply(lambda row: max(row['flag'], 4) if row['beam_attenuation_m'] <= 0 else row['flag'], axis=1)

# Step 3: Laser transmission flagging
data['flag'] = data.apply(lambda row: 3 if row['laser_transmission_sensor_mW'] > 100 and row['flag'] < 3 else row['flag'], axis=1)
underwater_data = data[data['beam_attenuation_m'] > 0]
# Share of underwater samples above 100 (0 if there are no underwater samples, to avoid dividing by zero)
above_100_percent = (underwater_data['laser_transmission_sensor_mW'] > 100).mean() * 100 if len(underwater_data) else 0
if above_100_percent > 50:  # If more than 50% of the data are above 100%
    messagebox.showwarning("Recalibration Needed", 
                           f"Warning: {above_100_percent:.0f}% of the underwater data are above 100% transmission. "
                           "Please recalibrate the sensor.")
    
# Step 4: Optical transmission flagging
data['flag'] = data.apply(lambda row: 3 if row['computed_optical_transmission'] > 0.98 and row['computed_optical_transmission'] < 0.995 and row['flag'] < 3 else row['flag'], axis=1)
data['flag'] = data.apply(lambda row: 4 if row['computed_optical_transmission'] >= 0.995 and row['flag'] < 4 else row['flag'], axis=1)
data['flag'] = data.apply(lambda row: 4 if row['computed_optical_transmission'] <= 0.10 and row['flag'] < 4 else row['flag'], axis=1)
    
# Step 5: Outlier detection using rolling mean and standard deviation
columns_to_check = ['total_volume_concentration_ppm', 'temperature_C', 'depth_in_m', 'computed_optical_transmission']
window_size = 25  # Define the window size for rolling calculations
threshold = 3  # Define the threshold for outlier detection

for col in columns_to_check:
    rolling_mean = data[col].rolling(window=window_size, center=True).mean()
    rolling_std = data[col].rolling(window=window_size, center=True).std()
    outlier_column = f'is_outlier_{col}'  # Create a column to flag outliers for each variable
    data[outlier_column] = abs(data[col] - rolling_mean) > (threshold * rolling_std)

    # Update the 'flag' column if the data is an outlier, without lowering an existing higher flag
    data['flag'] = data.apply(
        lambda row: max(row['flag'], 3) if row[outlier_column] else row['flag'], axis=1
    )

# Combine all outlier flags into one column for visualization
data['is_outlier'] = data[[f'is_outlier_{col}' for col in columns_to_check]].any(axis=1)

# Count total outliers
outlier_count = data['is_outlier'].sum()
print(f'Number of outliers detected: {outlier_count}')

# Warning message if quality flags exceed 75%
percentage_high_flags = (data['flag'] >= 3).sum() / len(data) * 100
if percentage_high_flags > 75:  # If more than 75% of the data have a quality flag of 3 or higher
    messagebox.showwarning("Unsatisfactory data quality", 
                           f"Warning: {percentage_high_flags:.0f}% of the data have a quality flag of 3 or higher. "
                           "Data should be used with caution.")

# Output
print('Flagging has been successfully performed on the complete dataset:')
display(data)

# Visualizing outliers for each column
for col in columns_to_check:
    plt.figure(figsize=(8, 4))
    plt.scatter(data.index, data[col], label='', color='blue', s=10, alpha=0.5)
    plt.scatter(
        data[data[f'is_outlier_{col}']].index, 
        data[data[f'is_outlier_{col}']][col], 
        label='Outliers', color='red', marker='x', s=15
    )
    plt.grid(axis='both', which='both', linewidth=0.3)
    plt.title(f'Outlier detection for {col}')
    plt.xlabel('Index')
    plt.ylabel(col.replace('_', ' ').title())
    plt.legend()
    plt.show()

# Visualizing flags
flag_counts = data['flag'].value_counts().sort_index()
colors = {1: 'green', 2: 'yellow', 3: 'orange', 4: 'red'}
fig, ax = plt.subplots()
flag_counts.plot(kind='bar', color=[colors.get(flag, 'blue') for flag in flag_counts.index], ax=ax)
ax.grid(axis='both', which='both', linewidth=0.3)
ax.set_title("Count per flag category")
ax.set_xlabel("Flag")
ax.set_ylabel("Count")
plt.show()

# Save as csv
data_path = f"{output_directory}/{deployment_code}-{location}-LISST200x-full-data.csv"
data.to_csv(data_path, index=False)
print(f"File saved to {data_path}")
                                                        ###

## Stage 3: Statistics computation & visualization

#### Stage 3 - Step 1: Statistics computation

Running this cell calculates a series of statistics on the full dataset, including:
1. **Total volume concentration**: This sums up the volume concentration of each grain size class per row.
2. **Relative volume concentration**: This calculates the percentage that each class represents compared to the total volume concentration.
3. **Mean diameter**: This corresponds to the mean diameter of particles in each row weighted by their volume concentration and normalized by the total volume concentration.
4. **D10, D50, D90 values**: These represent the diameters below which 10%, 50% and 90% of the particle volume is found. The D10 helps characterize the finer fraction of the sample, the D50 corresponds to the median diameter and the D90 characterizes the coarser sediment fraction. These values are calculated from the cumulative distribution.
5. **Span**: It is a measure of the sorting of particle sizes, as a normalized measure of the distribution spread around the median particle size, showing the relative range of the middle 80% of the particle size distribution. Small values indicate good sorting (close to 1) while a larger span indicates poor sorting of the particles.
6. **Standard deviation**: It measures the average dispersion of particle sizes from the mean, a greater standard deviation shows greater variability around the mean with possible outliers or tails. If the span and standard deviation are low, it shows a uniform distribution; if the span is moderate but the standard deviation is high, it shows a moderate spread with outliers; and if the span and standard deviation are high, it shows a wide distribution.
7. **Mode**: The mode is the particle size that has the largest volume of the distribution (peak in the distribution).
8. **Peaks in the distribution**: It identifies peaks in the particle size distribution.
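
The percentile diameters and the span can be illustrated for a single sample. This is a minimal sketch with made-up size classes and volumes; `d_percentile` is a hypothetical helper that returns the first size class at which the cumulative volume reaches the requested percentile, without interpolating between classes.

```python
import numpy as np

# Made-up size classes (µm) and volume concentrations for one sample.
grain_sizes = np.array([2.0, 4.0, 8.0, 16.0, 32.0])
volumes = np.array([1.0, 2.0, 4.0, 2.0, 1.0])

def d_percentile(grain_sizes, volumes, percentile):
    """Return the first size class at which the cumulative volume reaches the percentile."""
    total = volumes.sum()
    if total == 0:
        return np.nan
    cumulative = np.cumsum(volumes)
    return grain_sizes[np.argmax(cumulative >= total * percentile / 100.0)]

d10 = d_percentile(grain_sizes, volumes, 10)
d50 = d_percentile(grain_sizes, volumes, 50)
d90 = d_percentile(grain_sizes, volumes, 90)

# Span: relative width of the middle 80% of the distribution.
span = (d90 - d10) / d50
print(d10, d50, d90, span)
```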

The particle size distribution combined with the cumulative volume concentration is displayed when this cell is run, together with a table of the mean value of each calculated parameter. The graph and the updated complete clean dataframe are automatically saved in the output directory.
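
The volume-weighted mean diameter, standard deviation, mode and peaks can likewise be sketched for one sample. The values below are made up, and `scipy.signal.find_peaks` is called with its default settings, so only strict local maxima away from the two ends of the distribution are reported.

```python
import numpy as np
from scipy.signal import find_peaks

# Made-up size classes (µm) and volume concentrations for one bimodal sample.
grain_sizes = np.array([2.0, 4.0, 8.0, 16.0, 32.0, 64.0])
volumes = np.array([1.0, 3.0, 1.0, 2.0, 5.0, 1.0])

total = volumes.sum()

# Mean diameter weighted by volume concentration and normalized by the total.
mean_diameter = np.sum(grain_sizes * volumes) / total

# Volume-weighted standard deviation around that mean.
std_dev = np.sqrt(np.sum(volumes * (grain_sizes - mean_diameter) ** 2) / total)

# Mode: size class holding the largest volume.
mode = grain_sizes[np.argmax(volumes)]

# Peaks: size classes that are local maxima of the distribution.
peak_indices, _ = find_peaks(volumes)
peak_sizes = grain_sizes[peak_indices]

print(mean_diameter, std_dev, mode, peak_sizes.tolist())
```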

In [None]:
                                                    ## Cell 05 ##  
# Size classes (column names are the bin sizes in µm) and their volume concentrations
volume_concentration_columns = ["1.21","1.6","1.89","2.23","2.63","3.11","3.67","4.33","5.11","6.03","7.11","8.39","9.90","11.7","13.8","16.3","19.2","22.7","26.7","31.6","37.2","43.9","51.9","61.2","72.2","85.2","101","119","140","165","195","230","273","324","386","459"]
average_volume_fractions = data[volume_concentration_columns].mean()
volume_concentrations = data[volume_concentration_columns].values
grain_sizes = np.array([float(col) for col in volume_concentration_columns])

# Calculate the percentage of each class compared to the total volume concentration
for col in volume_concentration_columns:
    data[f'{col}_%'] = ((data[col] / data['total_volume_concentration_ppm']) * 100).astype(float)

# Cumulative volume concentration
cumulative_volumes = np.cumsum(data[volume_concentration_columns].values, axis=1)

# Percentiles calculation function
def calculate_d_percentile(cumulative_volumes, grain_sizes, total_volume, percentile):
    target = total_volume * (percentile / 100.0)
    d_percentile = []
    for i in range(len(total_volume)):
        if total_volume.iloc[i] == 0: 
            d_percentile.append(np.nan)
        else:
            greater_equal_idx = np.argmax(cumulative_volumes[i, :] >= target.iloc[i])
            d_percentile.append(grain_sizes[greater_equal_idx])
    return np.array(d_percentile)

# Calculate D10, D50, D90
data['D10_um'] = calculate_d_percentile(cumulative_volumes, grain_sizes, data['total_volume_concentration_ppm'], 10)
data['D50_um'] = calculate_d_percentile(cumulative_volumes, grain_sizes, data['total_volume_concentration_ppm'], 50)
data['D90_um'] = calculate_d_percentile(cumulative_volumes, grain_sizes, data['total_volume_concentration_ppm'], 90)

# Calculate the span
data['span'] = (data['D90_um'] - data['D10_um']) / data['D50_um']

# Calculate the standard deviation
def calculate_std(grain_sizes, volume_concentrations, mean_diameter):
    variance = np.sum(volume_concentrations * (grain_sizes - mean_diameter[:, None])**2, axis=1) / np.sum(volume_concentrations, axis=1)
    return np.sqrt(variance)
data['std_dev_um'] = calculate_std(grain_sizes, volume_concentrations, data['mean_diameter_um'].values)

# Calculate the mode
def calculate_mode(grain_sizes, volume_concentrations):
    mode_values = []
    for vc in volume_concentrations:
        if np.all(vc == 0):
            mode_values.append(np.nan)
        else:
            mode_index = np.argmax(vc)
            mode_values.append(grain_sizes[mode_index])
    return np.array(mode_values)
data['mode_um'] = calculate_mode(grain_sizes, volume_concentrations)

# Identify peaks in the distribution 
def find_all_peaks(grain_sizes, volume_concentrations):
    all_peaks = []
    for vc in volume_concentrations:
        if np.all(vc == 0):
            all_peaks.append([])
        else:
            peaks, _ = find_peaks(vc)
            peak_sizes = grain_sizes[peaks]
            all_peaks.append(peak_sizes.tolist())
    return all_peaks
data['peaks'] = find_all_peaks(grain_sizes, volume_concentrations)

# Display updated dataframe
print('Statistics have been successfully computed')
display(data)

# Display the histogram of the particle size distribution and the cumulative volume distribution (in red)
total_volume_concentration_per_class = data[volume_concentration_columns].sum(axis=0)
fig, ax1 = plt.subplots(figsize=(10, 6), dpi=300)
ax1.plot(grain_sizes, total_volume_concentration_per_class, color='blue', marker='o', 
         label='Particle size distribution', linestyle='-')
ax1.set_xscale('log')
ax1.set_xlabel('Grain size (µm)')
ax1.set_ylabel('Total volume concentration (µl/l)')
ax1.set_title(f'LISST-200x {location} {deployment_code}')
ax1.grid(axis='both', which='both', linewidth=0.3)
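# Cumulative distribution of the summed classes (NumPy array, kept separate from the per-row cumulative_volumes)
cumulative_total = np.cumsum(total_volume_concentration_per_class.to_numpy())
ax2 = ax1.twinx() 
ax2.plot(grain_sizes, cumulative_total / cumulative_total[-1] * 100, color='red', 
         label='Cumulative Distribution', marker='o')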
ax2.set_ylabel('Cumulative volume (%)', color='red')
ax2.tick_params(axis='y', labelcolor='red')
ax1.axvline(data['D10_um'].mean(), color='green', linestyle='--', label='D10')
ax1.axvline(data['D50_um'].mean(), color='orange', linestyle='--', label='D50')
ax1.axvline(data['D90_um'].mean(), color='purple', linestyle='--', label='D90')
ax1.legend(loc='upper left')
ax2.legend(loc='upper center')
plt.tight_layout()
plt.show()
graph_file_path = f"{output_directory}/{deployment_code}-{location}-LISST200x-full-PSD.png"
fig.savefig(graph_file_path) 

# Display table with the mean of each parameter
mean_values = {
    'Total volume concentration (ppm)': data['total_volume_concentration_ppm'].mean(),
    'Mean diameter (µm)': data['mean_diameter_um'].mean(),
    'D10 (µm)': data['D10_um'].mean(),
    'D50 (µm)': data['D50_um'].mean(),
    'D90 (µm)': data['D90_um'].mean(),
    'Span': data['span'].mean(),
    'Standard deviation (µm)': data['std_dev_um'].mean(),
    'Mode (µm)': data['mode_um'].mean()
}
mean_values_df = pd.DataFrame(mean_values, index=[0])
print('Mean values of calculated parameters for the complete clean dataset:')
display(mean_values_df)

fig, axs = plt.subplots(5, 1, figsize=(10, 12), dpi=300, sharex=True)
axs[0].set_title(f'LISST-200x {location} {deployment_code}')
axs[0].plot(data['datetime'], data['depth_in_m'], color='steelblue', linestyle='-', linewidth=0.4)
axs[0].set_ylabel('Depth (m)')
axs[0].grid(axis='both', which='both', linewidth=0.3)
axs[1].plot(data['datetime'], data['computed_optical_transmission'], color='steelblue', linestyle='-', linewidth=0.4)
axs[1].set_ylabel('Optical transmission (-)')
axs[1].grid(axis='both', which='both', linewidth=0.3)
axs[2].plot(data['datetime'], data['temperature_C'], color='steelblue', linestyle='-', linewidth=0.4)
axs[2].set_ylabel('Temperature (°C)')
axs[2].grid(axis='both', which='both', linewidth=0.3)
axs[3].plot(data['datetime'], data['total_volume_concentration_ppm'], color='steelblue', linestyle='-', linewidth=0.4)
axs[3].set_ylabel('Total volume concentration (ppm)')
axs[3].grid(axis='both', which='both', linewidth=0.3)
axs[4].plot(data['datetime'], data['mean_diameter_um'], color='steelblue', linestyle='-', linewidth=0.4)
axs[4].set_ylabel('Mean diameter (µm)')
axs[4].grid(axis='both', which='both', linewidth=0.3)
axs[4].set_xlabel('')
#axs[4].xaxis.set_major_formatter(mdates.DateFormatter('%D'))
axs[4].xaxis.set_major_locator(mdates.DayLocator(interval=7))
fig.align_ylabels()
plt.tight_layout()
plt.show()
graph_file_path = f"{output_directory}/{deployment_code}-{location}-LISST200x-g-full-graph2.png"
fig.savefig(graph_file_path) 

# Save output
print(f'Graphs are saved to: {output_directory}')
                                                        ###

#### Stage 3 - Step 2: Manual adjustment of acceptable optical transmission, total volume concentration and/or mean diameter values (optional, upon verification on the graph above)

The following cell is optional and does not have to be run if no additional filtering of the data is needed. This decision should be made after a visual inspection of the graphs generated in cell 5 above. The user can define minimum and/or maximum thresholds for five possible fields (optical transmission, total volume concentration, mean diameter, depth and/or temperature); data below the minimum or above the maximum are flagged with the value 4. The full dataset is then saved again.
To ignore a threshold, leave it at -99.99.
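
The threshold logic with the -99.99 sentinel can be sketched without the widgets. This is a minimal illustration on a toy table; `flag_outside_range` is a hypothetical helper and not the function defined in the cell below.

```python
import pandas as pd

IGNORE = -99.99  # sentinel meaning "no threshold on this side"

def flag_outside_range(df, column, minimum=IGNORE, maximum=IGNORE):
    """Return a copy of df with flag set to 4 where column lies outside the given bounds."""
    out = df.copy()
    bad = pd.Series(False, index=out.index)
    if minimum != IGNORE:
        bad |= out[column] < minimum
    if maximum != IGNORE:
        bad |= out[column] > maximum
    out.loc[bad, "flag"] = 4
    return out

df = pd.DataFrame({"temperature_C": [8.0, 12.0, 35.0], "flag": [1, 3, 1]})

# Only a minimum is given, so the maximum side is ignored.
flagged = flag_outside_range(df, "temperature_C", minimum=10.0)
print(flagged["flag"].tolist())
```

Returning a copy keeps the original flags untouched until the user is satisfied with the chosen thresholds.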

In [None]:
                                                    ## Cell 06 ##
# Initialize thresholds
minimum_threshold = -99.99
maximum_threshold = -99.99

# Dropdown for field selection
field_threshold_dropdown = widgets.Dropdown(
    options=['optical transmission', 'total volume concentration', 'mean diameter', 'depth', 'temperature'],
    value=None,
    description='Select field on which to apply a threshold:',
    style={'description_width': 'initial'}
)

# Widgets for threshold inputs (initially hidden)
min_threshold_widget = widgets.FloatText(
    value=-99.99,
    description='Minimum threshold:',
    style={'description_width': 'initial'}
)
max_threshold_widget = widgets.FloatText(
    value=-99.99,
    description='Maximum threshold:',
    style={'description_width': 'initial'}
)

# Button to apply thresholds (initially hidden)
apply_button = widgets.Button(
    description='Apply Thresholds',
    button_style='success'
)

# Container to display widgets dynamically
threshold_widgets_container = widgets.VBox([])

# Function to handle dropdown changes
def on_field_change(change):
    selected_field = change['new']
    if selected_field:
        print(f"Selected field: {selected_field}")
        # Display threshold widgets and apply button
        threshold_widgets_container.children = [min_threshold_widget, max_threshold_widget, apply_button]
    else:
        # Hide threshold widgets and apply button
        threshold_widgets_container.children = []

# Observe dropdown changes
field_threshold_dropdown.observe(on_field_change, names='value')

# Function to apply thresholds
def apply_thresholds(_):
    global minimum_threshold, maximum_threshold
    minimum_threshold = min_threshold_widget.value
    maximum_threshold = max_threshold_widget.value
    print(f"Applying thresholds: Min = {minimum_threshold}, Max = {maximum_threshold}, Field = {field_threshold_dropdown.value}")
    
    # Map each dropdown label to the corresponding dataframe column
    field_columns = {
        'optical transmission': 'computed_optical_transmission',
        'total volume concentration': 'total_volume_concentration_ppm',
        'mean diameter': 'mean_diameter_um',
        'depth': 'depth_in_m',
        'temperature': 'temperature_C',
    }
    column = field_columns.get(field_threshold_dropdown.value)

    # Flag to 4 any value outside the thresholds, unless the row already has a flag of 4 or more
    if column is not None and minimum_threshold != -99.99:
        data.loc[(data[column] < minimum_threshold) & (data['flag'] < 4), 'flag'] = 4
    if column is not None and maximum_threshold != -99.99:
        data.loc[(data[column] > maximum_threshold) & (data['flag'] < 4), 'flag'] = 4

    # Save updated data
    data_path = f"{output_directory}/{deployment_code}-{location}-LISST200x-full-data.csv"
    data.to_csv(data_path, index=False)
    print(f"File saved to {data_path}")
    print(f"Flagging has been successfully performed and the updated dataset has been saved to: {output_directory}")

# Connect button to function
apply_button.on_click(apply_thresholds)

# Display dropdown and container
display(field_threshold_dropdown, threshold_widgets_container)
                                                        ###

#### Stage 3 - Step 3: Removal of suspicious data & creation of a 'clean' dataset

This step allows the user to choose a flag threshold at and above which data are excluded from further analysis. The default value is 4, meaning all data flagged 4 and above are removed from the filtered dataframe. Values considered outliers can also be removed in this cell. The cell then smooths total volume concentration and mean diameter with Savitzky-Golay filters (windows labelled 6h, 12h and 24h) and plots the clean dataset.
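
The filtering rule can be illustrated on a toy dataframe. This is a hedged sketch, not the notebook's own code: the helper name `keep_below_flag` and the toy values are assumptions, and `is_outlier` is assumed to be a boolean column without missing values.

```python
import pandas as pd

def keep_below_flag(df, threshold=4, drop_outliers=False):
    """Hypothetical helper: keep rows whose flag is strictly below threshold."""
    # .copy() makes the result independent of df, so new columns can be added safely
    clean = df[df['flag'] < threshold].copy()
    if drop_outliers:
        # Assumes 'is_outlier' is a boolean column with no missing values
        clean = clean[~clean['is_outlier']]
    return clean

# Toy data (invented values, for illustration only)
toy = pd.DataFrame({
    'mean_diameter_um': [40.0, 300.0, 55.0, 48.0],
    'flag': [1, 3, 4, 2],
    'is_outlier': [False, True, False, False],
})

clean = keep_below_flag(toy)                                   # drops the row flagged 4
clean_no_outliers = keep_below_flag(toy, drop_outliers=True)   # also drops the outlier row
print(len(clean), len(clean_no_outliers))  # → 3 2
```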

In [None]:
                                                    ## Cell 07 ##
try:
    threshold = int(input("Enter the threshold for flag filtration (default is 4): ") or 4)
except ValueError:
    print("Invalid input, defaulting to threshold = 4.")
    threshold = 4

outliers = input("Remove outliers (yes or no): ")

# Apply initial filters
filtered_data = data[
    (data['total_volume_concentration_ppm'] != 0) &
    (data['mean_diameter_um'].notna()) &
    (data['flag'] < threshold)
].copy()  # Copy so that the columns added below do not write to a slice of `data`

# Additional filtering based on outliers
if outliers.lower() == 'yes':
    filtered_data = filtered_data[filtered_data['is_outlier'] != True]
elif outliers.lower() != 'no':
    print("Invalid input for outliers; no outlier filtering applied.")

# Display the filtered data
print("After filtering, the dataset is:")
display(filtered_data)

# Save filtered data as csv
filtered_data_path = f"{output_directory}/{deployment_code}-{location}-LISST200x-clean-data.csv"
filtered_data.to_csv(filtered_data_path, index=False)
print(f"Filtered data saved to {filtered_data_path}")

# Apply Savitzky-Golay filters (cubic polynomial; windows of 359, 719 and 1439 samples, labelled 6h, 12h and 24h)
filtered_data['mean_diameter_filtered_6h'] = savgol_filter(filtered_data['mean_diameter_um'], 359, 3)
filtered_data['mean_diameter_filtered_12h'] = savgol_filter(filtered_data['mean_diameter_um'], 719, 3)
filtered_data['mean_diameter_filtered_24h'] = savgol_filter(filtered_data['mean_diameter_um'], 1439, 3)
filtered_data['total_volume_filtered_6h'] = savgol_filter(filtered_data['total_volume_concentration_ppm'], 359, 3)
filtered_data['total_volume_filtered_12h'] = savgol_filter(filtered_data['total_volume_concentration_ppm'], 719, 3)
filtered_data['total_volume_filtered_24h'] = savgol_filter(filtered_data['total_volume_concentration_ppm'], 1439, 3)

# New graph with clean and filtered data
fig, axs = plt.subplots(5, 1, figsize=(10, 12), dpi=300, sharex=True)
axs[0].set_title(f'LISST-200x {location} {deployment_code}')
axs[0].plot(filtered_data['datetime'], filtered_data['depth_in_m'], color='steelblue', linestyle='-', linewidth=0.4)
axs[0].set_ylabel('Depth (m)')
axs[0].grid(axis='both', which='both', linewidth=0.3)
axs[1].plot(filtered_data['datetime'], filtered_data['computed_optical_transmission'], color='steelblue', linestyle='-', linewidth=0.4)
axs[1].set_ylabel('Optical transmission (%)')
axs[1].grid(axis='both', which='both', linewidth=0.3)
axs[2].plot(filtered_data['datetime'], filtered_data['temperature_C'], color='steelblue', linestyle='-', linewidth=0.4)
axs[2].set_ylabel('Temperature (°C)')
axs[2].grid(axis='both', which='both', linewidth=0.3)
axs[3].plot(filtered_data['datetime'], filtered_data['total_volume_concentration_ppm'], color='steelblue', linestyle='-', linewidth=0.4, label='Original')
axs[3].plot(filtered_data['datetime'], filtered_data['total_volume_filtered_6h'], color='lightblue', linestyle='-', linewidth=0.6, label='6h')
axs[3].plot(filtered_data['datetime'], filtered_data['total_volume_filtered_12h'], color='lime', linestyle='-', linewidth=0.6, label='12h')
axs[3].plot(filtered_data['datetime'], filtered_data['total_volume_filtered_24h'], color='red', linestyle='-', linewidth=0.6, label='24h')
axs[3].set_ylabel('Total volume concentration (µl/l)')
axs[3].grid(axis='both', which='both', linewidth=0.3)
axs[3].legend()
axs[4].plot(filtered_data['datetime'], filtered_data['mean_diameter_um'], color='steelblue', linestyle='-', linewidth=0.4, label='Original')
axs[4].plot(filtered_data['datetime'], filtered_data['mean_diameter_filtered_6h'], color='lightblue', linestyle='-', linewidth=0.6, label='6h')
axs[4].plot(filtered_data['datetime'], filtered_data['mean_diameter_filtered_12h'], color='lime', linestyle='-', linewidth=0.6, label='12h')
axs[4].plot(filtered_data['datetime'], filtered_data['mean_diameter_filtered_24h'], color='red', linestyle='-', linewidth=0.6, label='24h')
axs[4].set_ylabel('Mean diameter (µm)')
axs[4].grid(axis='both', which='both', linewidth=0.3)
axs[4].set_xlabel('')
axs[4].xaxis.set_major_locator(mdates.DayLocator(interval=7))
fig.align_ylabels()
plt.tight_layout()
plt.show()

graph_file_path = f"{output_directory}/{deployment_code}-{location}-LISST200x-g-filtered-graph2.png"
fig.savefig(graph_file_path)
                                                        ###
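
The Savitzky-Golay window lengths used in cell 07 (359, 719 and 1439 samples) correspond to roughly 6, 12 and 24 hours only if the instrument samples once per minute, which is an assumption here and is not stated in the notebook. A hedged sketch of how such a window length could be derived from the sampling interval, and clamped so that it never exceeds the number of available samples, is given below; the function name `savgol_window` is hypothetical.

```python
def savgol_window(duration_hours, sampling_interval_minutes, n_samples, polyorder=3):
    """Hypothetical helper: odd window length (in samples) covering at most duration_hours."""
    window = int(round(duration_hours * 60 / sampling_interval_minutes))
    # With scipy's default mode ('interp'), the window cannot be longer than the series
    window = min(window, n_samples)
    # Use an odd window so that it is centred on the sample being smoothed
    if window % 2 == 0:
        window -= 1
    # The polynomial order must be lower than the window length
    if window <= polyorder:
        raise ValueError("Not enough samples for the requested polynomial order")
    return window

# Assuming 1-minute sampling and a long deployment
print(savgol_window(6, 1, 10000))    # → 359
print(savgol_window(12, 1, 10000))   # → 719
print(savgol_window(24, 1, 10000))   # → 1439
# A short clean dataset (1000 samples) limits the 24h window
print(savgol_window(24, 1, 1000))    # → 999
```

The returned value could then be passed as the window length to `savgol_filter` in place of the hard-coded numbers.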

#### Stage 3 - Step 4: Computation of the full PSD and evolution of the PSD with tides
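
The cell below sums the volume concentration per size class and plots its cumulative distribution. As a minimal sketch of how a cumulative curve relates to percentile diameters such as D10, D50 and D90, the toy example below interpolates the cumulative volume percentage on a logarithmic size axis. The function name `percentile_diameter`, the toy size classes and the interpolation scheme are assumptions; the `D10_um`, `D50_um` and `D90_um` columns used in this notebook may have been computed differently.

```python
import numpy as np

def percentile_diameter(sizes_um, volumes, percent):
    """Hypothetical helper: diameter at which the cumulative volume reaches `percent` %."""
    # Cumulative volume percentage, assigned to each size class
    cumulative = np.cumsum(volumes) * 100.0 / np.sum(volumes)
    # Interpolate in log10(size), since the size classes are logarithmically spaced
    log_d = np.interp(percent, cumulative, np.log10(sizes_um))
    return 10 ** log_d

# Toy distribution with four size classes (invented values, for illustration only)
sizes = np.array([1.0, 10.0, 100.0, 1000.0])
volumes = np.array([1.0, 4.0, 4.0, 1.0])

d10 = percentile_diameter(sizes, volumes, 10)
d50 = percentile_diameter(sizes, volumes, 50)
d90 = percentile_diameter(sizes, volumes, 90)
print(f"D10 = {d10:.1f} µm, D50 = {d50:.1f} µm, D90 = {d90:.1f} µm")
```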

In [None]:
                                                    ## Cell 08 ##
# Total volume concentration per class (for full PSD)
volume_concentration_columns = ["1.21","1.6","1.89","2.23","2.63","3.11","3.67","4.33","5.11","6.03","7.11","8.39","9.90","11.7","13.8","16.3","19.2","22.7","26.7","31.6","37.2","43.9","51.9","61.2","72.2","85.2","101","119","140","165","195","230","273","324","386","459"]
total_volume_concentration_per_class = filtered_data[volume_concentration_columns].sum(axis=0)
average_volume_fractions = filtered_data[volume_concentration_columns].mean()
volume_concentrations = filtered_data[volume_concentration_columns].values
grain_sizes = np.array([float(col) for col in volume_concentration_columns])

# Create figure and axis for full PSD plot
fig, ax1 = plt.subplots(figsize=(10, 6), dpi=300)

# Plot the full Particle Size Distribution (PSD) on a logarithmic scale
ax1.plot(grain_sizes, total_volume_concentration_per_class, color='blue', marker='o', 
         label='Particle size distribution', linestyle='-', markersize=5)
ax1.set_xscale('log')
ax1.set_xlabel('Grain size (µm)', fontsize=12)
ax1.set_ylabel('Total volume concentration (µl/l)', fontsize=12)
ax1.set_title(f'LISST-200x {location} {deployment_code}', fontsize=14)
ax1.grid(axis='both', which='both', linewidth=0.3)

# Calculate cumulative volume distribution
cumulative_volumes = np.cumsum(total_volume_concentration_per_class)

# Create secondary axis for the Cumulative Volume Distribution
ax2 = ax1.twinx()
ax2.plot(grain_sizes, cumulative_volumes / cumulative_volumes.iloc[-1] * 100, color='red', 
         label='Cumulative Distribution', marker='o', linestyle='-', markersize=5)
ax2.set_ylabel('Cumulative volume (%)', color='red', fontsize=12)
ax2.tick_params(axis='y', labelcolor='red')

# Add lines for D10, D50, D90 values
ax1.axvline(filtered_data['D10_um'].mean(), color='green', linestyle='--', label='D10', linewidth=1.2)
ax1.axvline(filtered_data['D50_um'].mean(), color='orange', linestyle='--', label='D50', linewidth=1.2)
ax1.axvline(filtered_data['D90_um'].mean(), color='purple', linestyle='--', label='D90', linewidth=1.2)

# Add legends for the two y-axes
ax1.legend(loc='upper left', fontsize=10)
ax2.legend(loc='upper center', fontsize=10)

# Tight layout for better presentation
plt.tight_layout()

# Show the plot
plt.show()

# Save the graph as PNG
graph_file_path = f"{output_directory}/{deployment_code}-{location}-LISST200x-full-PSD.png"
fig.savefig(graph_file_path)
print(f"Graph saved to {graph_file_path}")

### Plot PSD for High and Low Tide Conditions

def classify_tides_with_12h_window(data):
    # Ensure data is sorted by datetime
    data = data.sort_values(by='datetime').reset_index(drop=True)
    
    # Smooth the depth data with Savitzky-Golay filter
    sampling_interval = (data['datetime'].iloc[1] - data['datetime'].iloc[0]).total_seconds() / 3600  # In hours
    time_interval = timedelta(hours=3)
    window_size = int(time_interval.total_seconds() / (sampling_interval * 3600))  # Approx. 3-hour window
    if window_size % 2 == 0:  # Ensure window size is odd
        window_size += 1
    
    data['smoothed_depth'] = savgol_filter(data['depth_in_m'], window_length=window_size, polyorder=3)
    
    # Apply a 12-hour rolling window to identify high tides (maxima) and low tides (minima)
    rolling_window = int(timedelta(hours=12).total_seconds() / (sampling_interval * 3600))
    if rolling_window % 2 == 0:  # Ensure rolling window size is odd
        rolling_window += 1

    data['rolling_max'] = data['smoothed_depth'].rolling(window=rolling_window, center=True).max()
    data['rolling_min'] = data['smoothed_depth'].rolling(window=rolling_window, center=True).min()
    
    # Identify high and low tides
    data['is_high_tide'] = data['smoothed_depth'] == data['rolling_max']
    data['is_low_tide'] = data['smoothed_depth'] == data['rolling_min']
    
    # Extract high tide and low tide information
    high_tides = data[data['is_high_tide']][['datetime', 'smoothed_depth']].rename(columns={'smoothed_depth': 'high_tide_depth'})
    low_tides = data[data['is_low_tide']][['datetime', 'smoothed_depth']].rename(columns={'smoothed_depth': 'low_tide_depth'})
    
    # Reset index for convenience
    high_tides = high_tides.reset_index(drop=True)
    low_tides = low_tides.reset_index(drop=True)
    
    # Combine high and low tides into one DataFrame
    tides = pd.concat([high_tides, low_tides], keys=['high', 'low']).sort_values(by='datetime')
    
    # Define slack tide offset and buffer
    slack_offset = timedelta(hours=2, minutes=30)
    buffer = timedelta(minutes=10)
    
    slack_tides = []
    ebb_tides = []
    flow_tides = []
    
    # Generate slack tide times and classify ebb and flow
    for i, row in tides.iterrows():
        # After pd.concat every row has both depth columns; the one that is not NaN gives the tide type
        if pd.notna(row['high_tide_depth']):  # High tide
            high_tide = row['datetime']
            slack_tide = high_tide - slack_offset
            slack_tides.append(slack_tide)
            ebb_tides.append((high_tide, slack_tide))
        elif pd.notna(row['low_tide_depth']):  # Low tide
            low_tide = row['datetime']
            slack_tide = low_tide - slack_offset
            slack_tides.append(slack_tide)
            flow_tides.append((low_tide, slack_tide))
    
    # Classify tides for each timestamp in the dataset
    def classify_timestamp(timestamp):
        for high in high_tides['datetime']:
            if high - buffer <= timestamp <= high + buffer:
                return "High Tide"
        for low in low_tides['datetime']:
            if low - buffer <= timestamp <= low + buffer:
                return "Low Tide"
        for slack in slack_tides:
            if slack - buffer <= timestamp <= slack + buffer:
                return "Slack Tide"
        return "Other"
    
    # Apply classification to dataset
    data['expected_tide'] = data['datetime'].apply(classify_timestamp)
    
    return data

# Assuming `filtered_data` contains datetime and depth_in_m columns
filtered_data = classify_tides_with_12h_window(filtered_data)

def summarize_tides_data(filtered_data, output_directory, deployment_code, location):
    # Create a summary table grouped by 'expected_tide' column
    summary_table = filtered_data.groupby('expected_tide').agg(
        mean_total_volume_concentration=('total_volume_concentration_ppm', 'mean'),
        mean_diameter=('mean_diameter_um', 'mean'),
        mean_D10=('D10_um', 'mean'),
        mean_D50=('D50_um', 'mean'),
        mean_D90=('D90_um', 'mean'),
        mean_span=('span', 'mean'),
        mean_std_dev=('std_dev_um', 'mean'),
        mean_mode=('mode_um', 'mean')  # Mean of the per-sample mode
    ).reset_index()

    # Save the summary table to CSV
    summary_table_path = f"{output_directory}/{deployment_code}-{location}-LISST200x-tide-summary.csv"
    summary_table.to_csv(summary_table_path, index=False)
    
    print(f"Summary table saved to: {summary_table_path}")
    return summary_table

# Generate the summary table
summary_table = summarize_tides_data(filtered_data, output_directory, deployment_code, location)
display(summary_table)

# Generate the plot
volume_concentration_columns = ["1.21_%", "1.6_%", "1.89_%", "2.23_%", "2.63_%", "3.11_%", "3.67_%", "4.33_%", "5.11_%", "6.03_%", "7.11_%", "8.39_%", "9.90_%", 
                                "11.7_%", "13.8_%", "16.3_%", "19.2_%", "22.7_%", "26.7_%", "31.6_%", "37.2_%", "43.9_%", "51.9_%", "61.2_%", 
                                "72.2_%", "85.2_%", "101_%", "119_%", "140_%", "165_%", "195_%", "230_%", "273_%", "324_%", "386_%", "459_%"]

# Subset data based on tide types
high_tide_data = filtered_data[filtered_data['expected_tide'] == 'High Tide']
low_tide_data = filtered_data[filtered_data['expected_tide'] == 'Low Tide']
slack_tide_data = filtered_data[filtered_data['expected_tide'] == 'Slack Tide']
other_data = filtered_data[filtered_data['expected_tide'] == 'Other']

# Calculate the mean value of each size class (the "_%" columns) for each tide type
high_tide_concentration = high_tide_data[volume_concentration_columns].mean(axis=0)
low_tide_concentration = low_tide_data[volume_concentration_columns].mean(axis=0)
slack_tide_concentration = slack_tide_data[volume_concentration_columns].mean(axis=0)
other_concentration = other_data[volume_concentration_columns].mean(axis=0)

colors = {
    'High Tide': '#ef3b2c',   # Red
    'Low Tide': '#08306b',  # Dark Blue
    'Slack Tide': '#74c476', # Green
    'Other': 'lightgrey', # Grey
}

# Define line styles and alpha for prominence
line_styles = {
    'High Tide': {'linestyle': '-', 'alpha': 1.0},
    'Low Tide': {'linestyle': '-', 'alpha': 1.0},
    'Slack Tide': {'linestyle': '-', 'alpha': 1.0},
    'Other': {'linestyle': '-', 'alpha': 0.5}
}

# Create the plot
fig, ax1 = plt.subplots(figsize=(10, 6), dpi=300)

# Plot the total volume concentration for each tide type
for tide, concentration, label in zip(
    ['High Tide', 'Low Tide', 'Slack Tide', 'Other'],
    [high_tide_concentration, low_tide_concentration, slack_tide_concentration, other_concentration],
    ['High Tide', 'Low Tide', 'Slack Tide', 'Other']
):
    ax1.plot(
        grain_sizes, 
        concentration, 
        color=colors[label], 
        label=label, 
        marker='o', 
        markersize=1, 
        linewidth=1.0, 
        **line_styles[label]
    )

# Set log scale for x-axis and labels for axes
ax1.set_xscale('log')
ax1.set_xlabel('Grain size (µm)', fontsize=12)
ax1.set_ylabel('Mean volume fraction (%)', fontsize=12)
ax1.set_title(f'LISST-200x {location} {deployment_code}', fontsize=14)
ax1.grid(axis='both', which='both', linewidth=0.3, linestyle='--', color='gray')

# Calculate cumulative volume distribution for each tide type
cumulative_volumes = {
    'High Tide': np.cumsum(high_tide_concentration),
    'Low Tide': np.cumsum(low_tide_concentration),
    'Slack Tide': np.cumsum(slack_tide_concentration),
    'Other': np.cumsum(other_concentration)
}

# Create secondary axis for the Cumulative Volume Distribution
ax2 = ax1.twinx()
for tide, cumulative_volume in cumulative_volumes.items():
    ax2.plot(
        grain_sizes, 
        cumulative_volume / cumulative_volume.iloc[-1] * 100, 
        color=colors[tide], 
        linestyle=':', 
        linewidth=0.8, 
        alpha=line_styles[tide]['alpha'], 
        label=f'Cumulative {tide}'
    )

# Set y-axis label for cumulative volume
ax2.set_ylabel('Cumulative volume (%)', color='k', fontsize=12)
ax2.tick_params(axis='y', labelcolor='k')

# Add legends for the two y-axes
ax1.legend(loc='upper left', fontsize=10, frameon=False)
ax2.legend(loc='upper center', fontsize=10, frameon=False)

# Tight layout for better presentation
plt.tight_layout()

# Show the plot
plt.show()

# Save the graph
graph_file_path = f"{output_directory}/{deployment_code}-{location}-LISST200x-PSD-tidal-conditions.png"
fig.savefig(graph_file_path, bbox_inches='tight')
print(f"Graph saved to {graph_file_path}")

import random

def plot_tide_classified_data(filtered_data):
    # Define colors for each tide type
    tide_colors = {
        "High Tide": "red",
        "Low Tide": "blue",
        "Slack Tide": "green",
        "Other": "grey"
    }
    
    # Randomly select a day from the deployment
    random_day = random.choice(pd.to_datetime(filtered_data['datetime']).dt.date.unique())
    zoomed_data = filtered_data[filtered_data['datetime'].dt.date == random_day]

    # Plotting the zoomed data for the random day
    plt.figure(figsize=(14, 7))
    for tide_type, color in tide_colors.items():
        # Filter data for each tide type
        tide_data = zoomed_data[zoomed_data['expected_tide'] == tide_type]
        plt.scatter(tide_data['datetime'], tide_data['mean_diameter_um'], 
                    color=color, label=tide_type, s=10)
    
    # Formatting the zoomed plot
    plt.title(f'Mean diameter vs time (Zoomed in on {random_day})')
    plt.xlabel('')
    plt.ylabel('Mean diameter (µm)')
    plt.legend()
    plt.grid(True)
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()
    
plot_tide_classified_data(filtered_data)
                                                        ###