<a href="https://colab.research.google.com/github/manii5228/Aquaculture/blob/main/aquaculture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Attach the user's Google Drive under /content/drive; the cells below read
# the dataset from /content/drive/MyDrive/salinity.nc.
drive.mount("/content/drive")


CNN

In [None]:
import xarray as xr
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.metrics import (mean_squared_error, r2_score, mean_absolute_error,
                             accuracy_score, precision_score, recall_score) # Added classification metrics
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.models import Sequential
from tensorflow.keras.regularizers import l2
import matplotlib.pyplot as plt
import os # Import os to check for file existence

# === Step 1: Load and Preprocess Data ===
def load_and_preprocess(filepath, downsample_factor=4):
    """
    Loads the salinity dataset, preprocesses it by downsampling, flipping,
    imputing missing values, and standardizing.
    Prepares X (input images) and y (target mean salinity for the next step).

    Args:
        filepath (str): Path to the NetCDF file containing salinity data.
        downsample_factor (int): Factor by which to downsample the spatial dimensions.

    Returns:
        tuple: (X, y, scaler) where X is the preprocessed input,
               y is the target output, and scaler is the StandardScaler object.
    """
    # Check if the file exists before attempting to open
    if not os.path.exists(filepath):
        print(f"Error: Dataset file not found at '{filepath}'")
        print("Please ensure the 'salinity.nc' file is in the specified path.")
        # Exit or raise an error as the program cannot proceed without the file
        raise FileNotFoundError(f"Dataset file not found at {filepath}")

    print(f"Loading dataset from: {filepath}")
    ds = xr.open_dataset(filepath, decode_times=False)
    # Select 'SALT' variable, take the first depth layer, and convert to numpy array
    salt = ds['SALT'].values[:, 0, :, :]  # Remove depth dimension

    print(f"Original data shape: {salt.shape}")

    # Downsample and flip vertically to correct orientation
    # Downsampling reduces spatial dimensions by the downsample_factor
    salt = salt[:, ::downsample_factor, ::downsample_factor]
    # Flipping along axis=1 (latitude) to correct map orientation
    salt = np.flip(salt, axis=1)

    print(f"Downsampled and flipped data shape: {salt.shape}")

    # Replace NaNs with interpolation (mean imputation per time slice)
    # Iterate through each time slice and apply SimpleImputer
    print("Imputing missing values...")
    for i in range(len(salt)):
        # Reshape to 2D for SimpleImputer, then reshape back
        imputed_slice = SimpleImputer(strategy='mean').fit_transform(salt[i].reshape(-1, 1)).reshape(salt[i].shape)
        salt[i] = imputed_slice
    print("Missing values imputed.")

    # Standardize globally (not per-slice) for consistent scaling across the entire dataset
    print("Standardizing data...")
    scaler = StandardScaler()
    # Reshape the 3D array (time, lat, lon) into 2D (total_elements, 1) for fitting the scaler
    salt_reshaped = salt.reshape(-1, salt.shape[-1] * salt.shape[-2])
    # Fit and transform the data, then reshape back to original 3D shape
    salt_scaled = scaler.fit_transform(salt_reshaped).reshape(salt.shape)
    print("Data standardization complete.")

    # Prepare X (input features) and y (target labels)
    # X will be the salinity map at time t
    X = salt_scaled[:-1]
    # y will be the mean salinity of the map at time t+1
    y = np.mean(salt_scaled[1:], axis=(1, 2))

    # Add channel dimension for CNN: (batch_size, height, width, channels)
    # For grayscale images, channels = 1
    X = X[..., np.newaxis]
    print(f"Final X shape for CNN: {X.shape}")
    print(f"Final y shape: {y.shape}")
    return X, y, scaler

# Define the path to your dataset.
# IMPORTANT: Ensure 'salinity.nc' is present at this path in your environment.
# If running in Colab, this means it needs to be in your mounted Google Drive.
DATASET_FILEPATH = "/content/drive/MyDrive/salinity.nc"

# === Step 2: Downsample and Split Data ===
try:
    X, y, scaler = load_and_preprocess(
        DATASET_FILEPATH,
        downsample_factor=4  # e.g. a 410x720 grid -> 103x180 (ceil(410/4) = 103)
    )
except FileNotFoundError as e:
    print(e)
    print("Exiting program. Please provide the correct dataset path.")
    # Re-raise so the cell actually stops: in IPython/Colab, exit() does not
    # halt the running cell, so the lines below would run with X undefined
    # (or stale from an earlier run).
    raise

# Split data chronologically into Train/Validation/Test sets
# (Training: 60%, Validation: 20%, Test: 20%).
# Neighbouring samples overlap (y[t] is the mean of the map that is X[t + 1]),
# so a shuffled split would leak test information into training. With
# shuffle=False the oldest 60% trains, the next 20% validates and the newest
# 20% is held out for testing.
print("\nSplitting data into training, validation, and test sets (60/20/20)...")
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, shuffle=False)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, shuffle=False)

print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
print(f"X_val shape: {X_val.shape}, y_val shape: {y_val.shape}")
print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")


# === Step 3: Build CNN Model ===
def build_model(input_shape):
    """
    Builds a Convolutional Neural Network (CNN) model for regression.

    Two Conv2D -> BatchNormalization -> MaxPooling2D stages (L2-regularized
    kernels) feed a small dense head that outputs one value per sample.

    Args:
        input_shape (tuple): Shape of the input data (height, width, channels).

    Returns:
        tf.keras.Model: Sequential model compiled with Adam (lr=0.001), MSE
        loss and an MAE metric.
    """
    print("\nBuilding CNN model...")
    model = Sequential([
        # Declare the input explicitly; passing `input_shape=` to a layer is
        # deprecated in Keras 3 and emits a warning.
        tf.keras.Input(shape=input_shape),

        # First Convolutional Block
        Conv2D(32, (3, 3), activation='relu', kernel_regularizer=l2(0.01), padding='same'),
        BatchNormalization(),  # Normalizes activations of the previous layer
        MaxPooling2D((2, 2)),  # Halves the spatial dimensions

        # Second Convolutional Block
        Conv2D(64, (3, 3), activation='relu', kernel_regularizer=l2(0.01), padding='same'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),

        Flatten(),  # Flattens the 2D feature maps into a 1D vector for the Dense layers

        # Dense (Fully Connected) Layers
        Dense(32, activation='relu'),
        Dropout(0.2),  # Drops 20% of units during training to reduce overfitting
        Dense(1)  # Single linear output unit for regression
    ])

    # Adam optimizer, MSE loss for regression, MAE reported as an extra metric.
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss='mse', metrics=['mae'])
    print("CNN model built and compiled.")
    model.summary()  # Print layer details
    return model

# Instantiate the CNN for one training sample's shape: (height, width, channels).
model = build_model(X_train.shape[1:])

# === Step 4: Train with Early Stopping ===
print("\nStarting model training...")
# Halt once val_loss has failed to improve for 5 consecutive epochs, then roll
# the weights back to the best epoch observed.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# At most 50 epochs of 64-sample mini-batches; (X_val, y_val) is scored after
# every epoch and a progress bar is printed.
history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=64,
    epochs=50,
    verbose=1,
    callbacks=[early_stopping],
    validation_data=(X_val, y_val),
)
print("Model training finished.")


# === Step 5: Evaluate on Test Set ===
print("\nEvaluating model on test set...")
# Predict on the test set; flatten the (n_samples, 1) output to match y_test.
y_pred_raw = model.predict(X_test)
y_pred = y_pred_raw.flatten()

# --- Regression Metrics ---
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print("\n📊 Final Regression Test Metrics:")
print(f"✅ MSE (Mean Squared Error): {mse:.6f}")
print(f"✅ RMSE (Root Mean Squared Error): {rmse:.6f}")
print(f"✅ MAE (Mean Absolute Error): {mae:.6f}")
print(f"✅ R² (R-squared): {r2:.4f}")

# --- Classification Metrics (for Demonstration Purposes Only) ---
print("\n🔄 Classification Metrics (Derived for Demonstration - See Note Below):")

# Binarize both series at one threshold: class 1 ("high salinity") when the
# value exceeds the median of the true test values, otherwise class 0.
classification_threshold = np.median(y_test)
print(f"  (Using classification threshold: {classification_threshold:.4f} based on median of true test values)")

y_test_binary = (y_test > classification_threshold).astype(int)
y_pred_binary = (y_pred > classification_threshold).astype(int)

accuracy = accuracy_score(y_test_binary, y_pred_binary)
# zero_division=0 reports 0 instead of warning when a denominator is zero
# (e.g. no positive predictions for precision).
precision = precision_score(y_test_binary, y_pred_binary, zero_division=0)
recall = recall_score(y_test_binary, y_pred_binary, zero_division=0)

print(f"✅ Accuracy:  {accuracy:.4f}")
print(f"✅ Precision: {precision:.4f}")
print(f"✅ Recall:    {recall:.4f}")

print("\n--- Important Note on Classification Metrics ---")
print("The 'Accuracy', 'Precision', and 'Recall' metrics above are calculated by converting the continuous regression "
      "outputs into binary classes using an arbitrary threshold (the median of true test values in this case).")
print("This transformation is done *only for demonstration* of these metrics. Your model's primary task is regression, "
      "and its performance should be primarily judged by MSE, RMSE, MAE, and R-squared.")
print("Choosing a different threshold would likely change these classification metric values.")

# === Step 6: Visualize Results ===
print("\nGenerating plots for visualization...")
plt.figure(figsize=(14, 6))

# Left panel: true vs. predicted standardized mean salinity per test sample.
plt.subplot(1, 2, 1)
plt.plot(y_test, label='True Mean SALT', marker='o', linestyle='-', color='blue', markersize=4)
plt.plot(y_pred, label='Predicted Mean SALT', marker='x', linestyle='--', color='red', markersize=4)
plt.title("True vs. Predicted Mean Salinity (Standardized)")
plt.xlabel("Time Step Index")
plt.ylabel("Mean Salinity (Standardized)")
plt.legend()
plt.grid(True)

# Right panel: the first test input map, with its channel axis dropped for imshow.
plt.subplot(1, 2, 2)
# NOTE(review): load_and_preprocess already flips the rows, and origin='lower'
# flips the display again; together they equal the unflipped array shown with
# the default origin='upper'. Confirm which of the two is the intended fix.
plt.imshow(X_test[0, :, :, 0], cmap='viridis', origin='lower')
plt.title("Sample Input Salinity Map (Corrected Orientation)")
plt.colorbar(label="Standardized Salinity Value")
plt.xlabel("Longitude Index")
plt.ylabel("Latitude Index")

# Lay out once, after both panels and the colorbar exist.
plt.tight_layout()
plt.show()

print("\nScript execution complete. Please check the plots for visualization.")

CNN-LSTM

In [None]:
import xarray as xr
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.impute import SimpleImputer
from sklearn.metrics import (mean_squared_error, r2_score, mean_absolute_error,
                             accuracy_score, precision_score, recall_score) # Added classification metrics
import tensorflow as tf
from tensorflow.keras.layers import (Conv2D, MaxPooling2D, Flatten,
                                     Dense, Dropout, BatchNormalization,
                                     LSTM, TimeDistributed)
from tensorflow.keras.models import Sequential
from tensorflow.keras.regularizers import l2
import matplotlib.pyplot as plt
import os

# === Step 1: Enhanced Data Loading and Preprocessing ===
def load_and_preprocess(filepath, downsample_factor=4, seq_length=3):
    """
    Loads the salinity dataset, preprocesses it, creates temporal sequences,
    and standardizes the data.

    Args:
        filepath (str): Path to the NetCDF file containing salinity data.
        downsample_factor (int): Factor by which to downsample the spatial dimensions.
        seq_length (int): Number of time steps to use as input sequence for prediction.

    Returns:
        tuple: (X, y, scaler) where X is the preprocessed input sequences,
               y is the target mean salinity, and scaler is the StandardScaler object.
    """
    if not os.path.exists(filepath):
        print(f"Error: Dataset file not found at '{filepath}'")
        print("Please ensure the 'salinity.nc' file is in the specified path.")
        raise FileNotFoundError(f"Dataset file not found at {filepath}")

    print(f"Loading dataset from: {filepath}")
    ds = xr.open_dataset(filepath, decode_times=False)
    salt = ds['SALT'].values[:, 0, :, :]

    print(f"Original data shape: {salt.shape}")

    salt = salt[:, ::downsample_factor, ::downsample_factor]
    salt = np.flip(salt, axis=1)

    print(f"Downsampled and flipped data shape: {salt.shape}")

    print("Imputing missing values using median strategy...")
    salt = np.array([SimpleImputer(strategy='median').fit_transform(s.reshape(-1, 1)).reshape(s.shape)
                     for s in salt])
    print("Missing values imputed.")

    print("Standardizing data globally...")
    scaler = StandardScaler()
    salt_flat = salt.reshape(-1, salt.shape[1] * salt.shape[2])
    salt_scaled = scaler.fit_transform(salt_flat).reshape(salt.shape)
    print("Data standardization complete.")

    print(f"Creating temporal sequences with length: {seq_length}...")
    X_seq, y_target = [], []
    for i in range(len(salt_scaled) - seq_length):
        X_seq.append(salt_scaled[i:i+seq_length])
        y_target.append(np.mean(salt_scaled[i+seq_length]))

    X_seq = np.array(X_seq)
    y_target = np.array(y_target)

    X_seq = X_seq[..., np.newaxis]

    print(f"Final X (input sequences) shape: {X_seq.shape}")
    print(f"Final y (target mean salinity) shape: {y_target.shape}")

    if np.var(y_target) < 0.01:
        print("Warning: Target variance is very low. This might indicate that predicting the mean salinity "
              "is too simple or there's not enough variability to learn. Consider predicting "
              "spatial patterns or specific points instead.")

    return X_seq, y_target, scaler

# Define the path to your dataset.
DATASET_FILEPATH = "/content/drive/MyDrive/salinity.nc"
SEQUENCE_LENGTH = 3

# === Step 2: Data Preparation ===
try:
    X, y, scaler = load_and_preprocess(
        DATASET_FILEPATH,
        downsample_factor=4,
        seq_length=SEQUENCE_LENGTH
    )
except FileNotFoundError as e:
    print(e)
    print("Exiting program. Please provide the correct dataset path.")
    # Re-raise so the cell stops: in IPython/Colab, exit() does not halt the
    # running cell, and the split below would silently reuse the X/y left
    # over from the CNN cell.
    raise

print("\nSplitting data into training, validation, and test sets with temporal ordering (60/20/20)...")
train_size = int(0.6 * len(X))
val_size = int(0.2 * len(X))
train_val_end = train_size + val_size

# Fail early instead of training/validating/testing on an empty partition.
if train_size == 0 or val_size == 0 or train_val_end >= len(X):
    raise ValueError(
        f"Too few samples ({len(X)}) for a 60/20/20 split with non-empty partitions."
    )

# Chronological slices: oldest samples train, newest samples test.
X_train, y_train = X[:train_size], y[:train_size]
X_val, y_val = X[train_size:train_val_end], y[train_size:train_val_end]
X_test, y_test = X[train_val_end:], y[train_val_end:]

print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
print(f"X_val shape: {X_val.shape}, y_val shape: {y_val.shape}")
print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")


# === Step 3: Hybrid CNN-LSTM Model ===
def build_model(input_shape):
    """
    Builds a Hybrid CNN-LSTM model for sequence prediction.

    Every frame of an input sequence passes through the same three-stage
    Conv2D -> BatchNormalization -> MaxPooling2D encoder (wrapped in
    TimeDistributed). The flattened per-frame features feed an LSTM whose final
    state drives a dense regression head with a single output.

    Args:
        input_shape (tuple): Shape of one sample, (seq_length, height, width, channels).

    Returns:
        tf.keras.Model: Sequential model compiled with Adam (lr=0.001), MSE
        loss and an MAE metric.
    """
    print("\nBuilding Hybrid CNN-LSTM model...")
    model = Sequential([
        # Explicit input; passing `input_shape=` to a layer is deprecated in Keras 3.
        tf.keras.Input(shape=input_shape),

        # Per-frame convolutional encoder (weights shared across time steps).
        TimeDistributed(Conv2D(32, (3,3), activation='relu', kernel_regularizer=l2(0.001), padding='same')),
        TimeDistributed(BatchNormalization()),
        TimeDistributed(MaxPooling2D((2,2))),

        TimeDistributed(Conv2D(64, (3,3), activation='relu', kernel_regularizer=l2(0.001), padding='same')),
        TimeDistributed(BatchNormalization()),
        TimeDistributed(MaxPooling2D((2,2))),

        TimeDistributed(Conv2D(128, (3,3), activation='relu', kernel_regularizer=l2(0.001), padding='same')),
        TimeDistributed(BatchNormalization()),
        TimeDistributed(MaxPooling2D((2,2))),

        TimeDistributed(Flatten()),

        # Temporal model over the per-frame feature vectors; only the last state is kept.
        LSTM(256, return_sequences=False, kernel_regularizer=l2(0.001)),

        # Regression head.
        Dense(128, activation='relu', kernel_regularizer=l2(0.01)),
        Dropout(0.5),
        Dense(1)
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='mse',
        metrics=['mae']
    )
    print("Hybrid CNN-LSTM model built and compiled.")
    model.summary()
    return model

# One sample is (seq_length, height, width, channels).
model = build_model(X_train.shape[1:])

# === Step 4: Training with Callbacks ===
print("\nStarting model training...")
# EarlyStopping: stop after 15 epochs without val_loss improvement and restore
# the best weights. ReduceLROnPlateau: halve the learning rate after 5 stagnant
# epochs, never going below 1e-5.
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=0.00001),
]

# Up to 100 epochs of 32-sample mini-batches, validated on (X_val, y_val).
history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=32,
    epochs=100,
    verbose=1,
    callbacks=callbacks,
    validation_data=(X_val, y_val),
)
print("Model training finished.")

# === Step 5: Evaluation ===
print("\nEvaluating model on test set...")
# model.predict returns (n_samples, 1); flatten to match y_test.
y_pred_raw = model.predict(X_test)
y_pred = y_pred_raw.flatten()

# --- Regression Metrics ---
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print("\n📊 Regression Test Metrics:")
print(f"✅ MSE (Mean Squared Error): {mse:.6f}")
print(f"✅ RMSE (Root Mean Squared Error): {rmse:.6f}")
print(f"✅ MAE (Mean Absolute Error): {mae:.6f}")
print(f"✅ R² (R-squared): {r2:.4f}")

# --- Classification Metrics (for Demonstration Purposes Only) ---
print("\n🔄 Classification Metrics (Derived for Demonstration - See Note Below):")

# Class 1 when a value exceeds the median of the true test values, else class 0.
classification_threshold = np.median(y_test)
print(f"  (Using classification threshold: {classification_threshold:.4f} based on median of true test values)")

y_test_binary = (y_test > classification_threshold).astype(int)
y_pred_binary = (y_pred > classification_threshold).astype(int)

accuracy = accuracy_score(y_test_binary, y_pred_binary)
# zero_division=0 reports 0 instead of warning when a denominator is zero.
precision = precision_score(y_test_binary, y_pred_binary, zero_division=0)
recall = recall_score(y_test_binary, y_pred_binary, zero_division=0)

print(f"✅ Accuracy:  {accuracy:.4f}")
print(f"✅ Precision: {precision:.4f}")
print(f"✅ Recall:    {recall:.4f}")

print("\n--- Important Note on Classification Metrics ---")
print("The 'Accuracy', 'Precision', and 'Recall' metrics above are calculated by converting the continuous regression "
      "outputs into binary classes using an arbitrary threshold (the median of true test values in this case).")
print("This transformation is done *only for demonstration* of these metrics. Your model's primary task is regression, "
      "and its performance should be primarily judged by MSE, RMSE, MAE, and R-squared.")
print("Choosing a different threshold would likely change these classification metric values.")


# === Step 6: Enhanced Visualization ===
print("\nGenerating plots for visualization...")
# One figure with two panels: predicted-vs-true scatter (left) and both series
# over the chronologically ordered test period (right).
plt.figure(figsize=(15, 6))

# Subplot 1: True vs. Predicted Values Scatter Plot
plt.subplot(1, 2, 1)
plt.scatter(y_test, y_pred, alpha=0.6, color='blue', edgecolors='w', s=50)
# Identity line spanning the joint range of both series.
diag_min = min(y_test.min(), y_pred.min())
diag_max = max(y_test.max(), y_pred.max())
plt.plot([diag_min, diag_max], [diag_min, diag_max], 'r--', label='Perfect Prediction')
plt.xlabel("True Values (Standardized Mean Salinity)", fontsize=12)
plt.ylabel("Predictions (Standardized Mean Salinity)", fontsize=12)
plt.title("Prediction Accuracy: True vs. Predicted Values", fontsize=14)
plt.legend()
plt.grid(True, linestyle='--', alpha=0.7)
plt.axis('equal')

# Subplot 2: True vs. Predicted Mean Salinity (Standardized) over time.
plt.subplot(1, 2, 2)
plt.plot(y_test, label='True Mean SALT', marker='o', linestyle='-', color='blue', markersize=4)
plt.plot(y_pred, label='Predicted Mean SALT', marker='x', linestyle='--', color='red', markersize=4)
plt.title("True vs. Predicted Mean Salinity (Standardized)")
plt.xlabel("Time Step Index")
plt.ylabel("Mean Salinity (Standardized)")
plt.legend()
plt.grid(True)

plt.tight_layout()  # Lay out once both panels exist
plt.show()
print("\nScript execution complete. Please check the console output for metrics and the plots for visualization.")


CNN-UNET

In [None]:
import xarray as xr
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import TimeSeriesSplit
from sklearn.impute import SimpleImputer
from sklearn.metrics import (mean_squared_error, r2_score, mean_absolute_error,
                             accuracy_score, precision_score, recall_score) # Added classification metrics
import tensorflow as tf
from tensorflow.keras.layers import (Conv2D, MaxPooling2D, UpSampling2D,
                                     Concatenate, Input, BatchNormalization, Cropping2D) # Import Cropping2D
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
import matplotlib.pyplot as plt
import os

# === 1. Data Loading with Shape Control ===
def load_data(filepath, target_shape=(100, 180)):
    """
    Loads the salinity dataset, preprocesses it by downsampling, correcting orientation,
    imputing NaNs, and standardizing.

    Args:
        filepath (str): Path to the NetCDF file containing salinity data.
        target_shape (tuple): Desired (height, width) for the spatial dimensions.

    Returns:
        tuple: (salt_data, scaler) where salt_data is the preprocessed and shaped
               salinity data, and scaler is the StandardScaler object.
    """
    if not os.path.exists(filepath):
        print(f"Error: Dataset file not found at '{filepath}'")
        print("Please ensure the 'salinity.nc' file is in the specified path.")
        return None, None # Return None to indicate failure

    print(f"Loading dataset from: {filepath}")
    ds = xr.open_dataset(filepath, decode_times=False)
    salt = ds['SALT'].values[:, 0, :, :]  # Select 'SALT', remove depth dim

    print(f"Original data shape: {salt.shape}")

    # Calculate current dimensions
    h_orig, w_orig = salt.shape[1], salt.shape[2]

    # Calculate downsample factors to reach target_shape
    # This approach handles cases where original dimensions are not direct multiples
    downsample_factor_h = h_orig // target_shape[0]
    downsample_factor_w = w_orig // target_shape[1]

    # Ensure factors are at least 1 to avoid upsampling or division by zero
    downsample_factor_h = max(1, downsample_factor_h)
    downsample_factor_w = max(1, downsample_factor_w)

    # Downsample using calculated factors
    salt = salt[:, ::downsample_factor_h, ::downsample_factor_w]

    # Crop to exactly target_shape if dimensions are slightly larger after downsampling
    # This might happen if original dimensions are not perfect multiples
    if salt.shape[1] > target_shape[0]:
        salt = salt[:, :target_shape[0], :]
    if salt.shape[2] > target_shape[1]:
        salt = salt[:, :, :target_shape[1]]


    salt = np.flip(salt, axis=1)  # Correct orientation (e.g., flip latitude)

    print(f"Downsampled and flipped data shape: {salt.shape}")

    print("Imputing missing values using median strategy...")
    # Apply median imputation slice by slice
    salt_imputed = []
    for s_slice in salt:
        imputer = SimpleImputer(strategy='median')
        # Reshape for imputer (2D array, column-wise imputation) and then reshape back
        imputed_slice = imputer.fit_transform(s_slice.reshape(-1, 1)).reshape(s_slice.shape)
        salt_imputed.append(imputed_slice)
    salt = np.array(salt_imputed)
    print("Missing values imputed.")

    print("Standardizing data globally...")
    scaler = StandardScaler()
    # Reshape the 3D array (time, height, width) to 2D (total_elements, 1) for scaler
    salt_reshaped_for_scaler = salt.reshape(-1, target_shape[0] * target_shape[1])
    salt_scaled = scaler.fit_transform(salt_reshaped_for_scaler)
    # Reshape back to original 3D structure for spatial data
    salt_data = salt_scaled.reshape(-1, *target_shape)

    # Add channel dimension for CNN input (batch, height, width, channels)
    salt_data = salt_data[..., np.newaxis]
    print(f"Final preprocessed data shape: {salt_data.shape}")
    print("Data standardization complete.")

    return salt_data, scaler

# === 2. U-Net with Shape Matching ===
def build_unet(input_shape):
    """
    Builds a U-Net model for image-to-image regression.
    The U-Net architecture is designed to capture features at multiple scales
    and produce a segmentation-like output (a predicted map in this case).

    Args:
        input_shape (tuple): Shape of the input images (height, width, channels).

    Returns:
        tf.keras.Model: Compiled U-Net model.
    """
    inputs = Input(input_shape)

    # Encoder (Downsampling Path)
    # Block 1
    c1 = Conv2D(32, (3,3), activation='relu', padding='same')(inputs)
    c1 = BatchNormalization()(c1)
    c1 = Conv2D(32, (3,3), activation='relu', padding='same')(c1)
    c1 = BatchNormalization()(c1)
    p1 = MaxPooling2D((2,2))(c1)
    print(f"Shape after Block 1: {p1.shape}")

    # Block 2
    c2 = Conv2D(64, (3,3), activation='relu', padding='same')(p1)
    c2 = BatchNormalization()(c2)
    c2 = Conv2D(64, (3,3), activation='relu', padding='same')(c2)
    c2 = BatchNormalization()(c2)
    p2 = MaxPooling2D((2,2))(c2)
    print(f"Shape after Block 2: {p2.shape}")


    # Block 3
    c3 = Conv2D(128, (3,3), activation='relu', padding='same')(p2)
    c3 = BatchNormalization()(c3)
    c3 = Conv2D(128, (3,3), activation='relu', padding='same')(c3)
    c3 = BatchNormalization()(c3)
    p3 = MaxPooling2D((2,2))(c3)
    print(f"Shape after Block 3: {p3.shape}")


    # Bridge (Bottleneck)
    c4 = Conv2D(256, (3,3), activation='relu', padding='same')(p3)
    c4 = BatchNormalization()(c4)
    c4 = Conv2D(256, (3,3), activation='relu', padding='same')(c4)
    c4 = BatchNormalization()(c4)
    print(f"Shape after Bridge: {c4.shape}")

    # Decoder (Upsampling Path)
    # Block 5
    u5 = UpSampling2D((2,2))(c4)
    print(f"Shape after UpSampling 1: {u5.shape}")

    # Crop c3 to match the spatial dimensions of u5 before concatenation
    shape_diff_h_c3_u5 = c3.shape[1] - u5.shape[1]
    shape_diff_w_c3_u5 = c3.shape[2] - u5.shape[2]
    crop_top_c3_u5 = shape_diff_h_c3_u5 // 2
    crop_bottom_c3_u5 = shape_diff_h_c3_u5 - crop_top_c3_u5
    crop_left_c3_u5 = shape_diff_w_c3_u5 // 2
    crop_right_c3_u5 = shape_diff_w_c3_u5 - crop_left_c3_u5
    c3_cropped = Cropping2D(cropping=((crop_top_c3_u5, crop_bottom_c3_u5), (crop_left_c3_u5, crop_right_c3_u5)))(c3)
    print(f"Shape of c3_cropped for concat 1: {c3_cropped.shape}")
    u5 = Concatenate()([u5, c3_cropped])
    print(f"Shape after Concat 1: {u5.shape}")
    c5 = Conv2D(128, (3,3), activation='relu', padding='same')(u5)
    c5 = BatchNormalization()(c5)
    c5 = Conv2D(128, (3,3), activation='relu', padding='same')(c5)
    c5 = BatchNormalization()(c5)
    print(f"Shape after Conv Block 5: {c5.shape}")


    # Block 6
    u6 = UpSampling2D((2,2))(c5)
    print(f"Shape after UpSampling 2: {u6.shape}")

    # Crop c2 to match the spatial dimensions of u6 before concatenation
    shape_diff_h_c2_u6 = c2.shape[1] - u6.shape[1]
    shape_diff_w_c2_u6 = c2.shape[2] - u6.shape[2]
    crop_top_c2_u6 = shape_diff_h_c2_u6 // 2
    crop_bottom_c2_u6 = shape_diff_h_c2_u6 - crop_top_c2_u6
    crop_left_c2_u6 = shape_diff_w_c2_u6 // 2
    crop_right_c2_u6 = shape_diff_w_c2_u6 - crop_left_c2_u6
    c2_cropped = Cropping2D(cropping=((crop_top_c2_u6, crop_bottom_c2_u6), (crop_left_c2_u6, crop_right_c2_u6)))(c2)
    print(f"Shape of c2_cropped for concat 2: {c2_cropped.shape}")
    u6 = Concatenate()([u6, c2_cropped])
    print(f"Shape after Concat 2: {u6.shape}")
    c6 = Conv2D(64, (3,3), activation='relu', padding='same')(u6)
    c6 = BatchNormalization()(c6)
    c6 = Conv2D(64, (3,3), activation='relu', padding='same')(c6)
    c6 = BatchNormalization()(c6)
    print(f"Shape after Conv Block 6: {c6.shape}")


    # Block 7
    u7 = UpSampling2D((2,2))(c6)
    print(f"Shape after UpSampling 3: {u7.shape}")

    # Crop c1 to match the spatial dimensions of u7 before concatenation
    shape_diff_h_c1_u7 = c1.shape[1] - u7.shape[1]
    shape_diff_w_c1_u7 = c1.shape[2] - u7.shape[2]
    crop_top_c1_u7 = shape_diff_h_c1_u7 // 2
    crop_bottom_c1_u7 = shape_diff_h_c1_u7 - crop_top_c1_u7
    crop_left_c1_u7 = shape_diff_w_c1_u7 // 2
    crop_right_c1_u7 = shape_diff_w_c1_u7 - crop_left_c1_u7
    c1_cropped = Cropping2D(cropping=((crop_top_c1_u7, crop_bottom_c1_u7), (crop_left_c1_u7, crop_right_c1_u7)))(c1)
    print(f"Shape of c1_cropped for concat 3: {c1_cropped.shape}")
    u7 = Concatenate()([u7, c1_cropped])
    print(f"Shape after Concat 3: {u7.shape}")
    c7 = Conv2D(32, (3,3), activation='relu', padding='same')(u7)
    c7 = BatchNormalization()(c7)
    c7 = Conv2D(32, (3,3), activation='relu', padding='same')(c7)
    c7 = BatchNormalization()(c7)
    print(f"Shape after Conv Block 7: {c7.shape}")


    # Output layer: 1x1 convolution to produce a single channel output (for salinity map)
    outputs = Conv2D(1, (1,1), activation='linear')(c7)
    print(f"Shape after Output Layer: {outputs.shape}")


    model = Model(inputs=[inputs], outputs=[outputs])
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0005), loss='mse', metrics=['mae'])
    print("U-Net model built and compiled.")
    model.summary()
    return model

# Define the path to your dataset.
DATASET_FILEPATH = "/content/drive/MyDrive/salinity.nc"
TARGET_IMAGE_SHAPE = (96,176) # Desired height x width for images
SEQUENCE_LENGTH = 3 # Number of past frames to consider for predicting the next frame

# === 3. Data Preparation ===
salt_data, scaler = load_data(DATASET_FILEPATH, target_shape=TARGET_IMAGE_SHAPE)
if salt_data is None:
    raise ValueError("Failed to load data. Please check file path and contents.")

# Create sequences: X is the last frame of a sequence, y is the frame to predict
X, y = [], []
for i in range(len(salt_data) - SEQUENCE_LENGTH):
    # Input X: The last frame of the sequence (t, t+1, t+2 -> use t+2 to predict t+3)
    X.append(salt_data[i + SEQUENCE_LENGTH - 1])
    # Target y: The frame immediately following the sequence (t+3)
    y.append(salt_data[i + SEQUENCE_LENGTH])

X = np.array(X)
y = np.array(y)

print(f"X (input frames) shape: {X.shape}") # Should be (n_samples, H, W, C)
print(f"y (target frames) shape: {y.shape}") # Should be (n_samples, H, W, C)

# Time-series cross-validation
# We use TimeSeriesSplit to ensure no data leakage from future time steps
tscv = TimeSeriesSplit(n_splits=2) # Using 2 splits for demonstration; adjust for more robust evaluation

# Initialize lists to store metrics from each fold
all_mse, all_rmse, all_mae, all_r2 = [], [], [], []
all_accuracy, all_precision, all_recall = [], [], []

for fold, (train_index, test_index) in enumerate(tscv.split(X)):
    print(f"\n=== Fold {fold+1} ===")
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

    print(f"Training shapes - X: {X_train.shape}, y: {y_train.shape}")
    print(f"Testing shapes - X: {X_test.shape}, y: {y_test.shape}")

    # === 4. Model Training ===
    # A fresh U-Net is built for every fold; its input shape is (height, width, channels).
    model = build_unet(X_train.shape[1:])

    # validation_split holds out the last 10% of this fold's training samples for
    # early stopping and LR scheduling, so the test fold plays no part in training.
    history = model.fit(
        X_train, y_train,
        epochs=100,
        batch_size=8,
        validation_split=0.1,
        callbacks=[
            EarlyStopping(patience=20, restore_best_weights=True, monitor='val_loss'),
            ReduceLROnPlateau(factor=0.5, patience=10, monitor='val_loss', min_lr=0.000001)
        ],
        verbose=1
    )

    # === 5. Enhanced Evaluation ===
    print("\nEvaluating model on test set...")
    y_pred = model.predict(X_test)

    # Flatten the image data for metric calculation (needed for sklearn metrics)
    y_test_flat = y_test.reshape(-1)
    y_pred_flat = y_pred.reshape(-1)

    # --- Regression Metrics ---
    mse = mean_squared_error(y_test_flat, y_pred_flat)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_test_flat, y_pred_flat)
    r2 = r2_score(y_test_flat, y_pred_flat)

    all_mse.append(mse)
    all_rmse.append(rmse)
    all_mae.append(mae)
    all_r2.append(r2)

    print(f"\n📊 Fold {fold+1} Regression Test Metrics:")
    print(f"✅ MSE (Mean Squared Error): {mse:.6f}")
    print(f"✅ RMSE (Root Mean Squared Error): {rmse:.6f}")
    print(f"✅ MAE (Mean Absolute Error): {mae:.6f}")
    print(f"✅ R² (R-squared): {r2:.4f}")

    # --- Classification Metrics (for Demonstration Purposes Only) ---
    print(f"\n🔄 Fold {fold+1} Classification Metrics (Derived for Demonstration - See Note Below):")

    # Define a threshold to convert regression output to binary classification.
    # Using the median of the true test values (flattened) as a simple example threshold.
    classification_threshold = np.median(y_test_flat)
    print(f"  (Using classification threshold: {classification_threshold:.4f} based on median of true test values)")

    # Convert true and predicted continuous values to binary classes based on the threshold
    y_test_binary = (y_test_flat > classification_threshold).astype(int)
    y_pred_binary = (y_pred_flat > classification_threshold).astype(int)

    # Calculate classification metrics
    accuracy = accuracy_score(y_test_binary, y_pred_binary)
    precision = precision_score(y_test_binary, y_pred_binary, zero_division=0) # zero_division=0 handles cases with no positive predictions
    recall = recall_score(y_test_binary, y_pred_binary, zero_division=0)

    all_accuracy.append(accuracy)
    all_precision.append(precision)
    all_recall.append(recall)

    print(f"✅ Accuracy:  {accuracy:.4f}")
    print(f"✅ Precision: {precision:.4f}")
    print(f"✅ Recall:    {recall:.4f}")

    print("\n--- Important Note on Classification Metrics ---")
    print("The 'Accuracy', 'Precision', and 'Recall' metrics above are calculated by converting the continuous regression "
          "outputs into binary classes using an arbitrary threshold (the median of true test values in this case).")
    print("This transformation is done *only for demonstration* of these metrics. Your model's primary task is regression, "
          "and its performance should be primarily judged by MSE, RMSE, MAE, and R-squared.")
    print("Choosing a different threshold would likely change these classification metric values.")

    # === 6. Enhanced Visualization ===
    print("\nGenerating plots for visualization...")
    plt.figure(figsize=(18, 12)) # 2x3 grid of diagnostic panels

    # Plot 1: True vs Predicted Plot (flattened)
    plt.subplot(2, 3, 1)
    plt.scatter(y_test_flat, y_pred_flat, alpha=0.1, s=5) # Low alpha and size for a dense scatter
    plt.plot([y_test_flat.min(), y_test_flat.max()], [y_test_flat.min(), y_test_flat.max()], 'r--', label='Perfect Prediction')
    plt.title(f"Fold {fold+1}: True vs Predicted Values\n(R²: {r2:.4f})", fontsize=14)
    plt.xlabel("True Values", fontsize=12)
    plt.ylabel("Predictions", fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()

    # Plot 2: Error Distribution
    plt.subplot(2, 3, 2)
    errors = y_test_flat - y_pred_flat
    plt.hist(errors, bins=50, color='skyblue', edgecolor='black')
    plt.title(f"Fold {fold+1}: Prediction Error Distribution\n(MAE: {mae:.4f})", fontsize=14)
    plt.xlabel("Prediction Error (True - Predicted)", fontsize=12)
    plt.ylabel("Frequency", fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)

    # Plot 3: Sample True Salinity Map
    plt.subplot(2, 3, 3)
    if y_test.shape[0] > 0: # Ensure there's at least one sample
        plt.imshow(y_test[0,:,:,0], cmap='viridis', origin='lower')
        plt.title(f"Fold {fold+1}: Sample True Map", fontsize=14)
        plt.colorbar(label="Standardized Salinity")
        plt.xlabel("Longitude Index")
        plt.ylabel("Latitude Index")
    else:
        plt.text(0.5, 0.5, "No test data for sample map", horizontalalignment='center', verticalalignment='center', transform=plt.gca().transAxes)

    # Plot 4: Sample Predicted Salinity Map
    plt.subplot(2, 3, 4)
    if y_pred.shape[0] > 0: # Ensure there's at least one sample
        plt.imshow(y_pred[0,:,:,0], cmap='viridis', origin='lower')
        plt.title(f"Fold {fold+1}: Sample Predicted Map", fontsize=14)
        plt.colorbar(label="Standardized Salinity")
        plt.xlabel("Longitude Index")
        plt.ylabel("Latitude Index")
    else:
        plt.text(0.5, 0.5, "No test data for sample map", horizontalalignment='center', verticalalignment='center', transform=plt.gca().transAxes)

    # Plot 5: Sample Error Map (True - Predicted)
    plt.subplot(2, 3, 5)
    if y_test.shape[0] > 0 and y_pred.shape[0] > 0: # Ensure samples exist
        error_map = y_test[0,:,:,0] - y_pred[0,:,:,0]
        plt.imshow(error_map, cmap='coolwarm', vmin=-1.0, vmax=1.0, origin='lower') # Fixed vmin/vmax for consistent color scaling across folds
        plt.title(f"Fold {fold+1}: Sample Error Map (True - Pred)", fontsize=14)
        plt.colorbar(label="Standardized Error")
        plt.xlabel("Longitude Index")
        plt.ylabel("Latitude Index")
    else:
        plt.text(0.5, 0.5, "No test data for sample error map", horizontalalignment='center', verticalalignment='center', transform=plt.gca().transAxes)

    # Plot 6: Training / validation loss history for this fold
    plt.subplot(2, 3, 6)
    plt.plot(history.history['loss'], label='Training Loss')
    if 'val_loss' in history.history:
        plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title(f"Fold {fold+1}: Loss History (MSE)", fontsize=14)
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel("Loss", fontsize=12)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()

    plt.tight_layout(pad=3.0)
    plt.show()

# After all folds, print the mean of each metric across folds.
# Each entry: (label, per-fold values, format spec for the printed mean).
print("\n=== Average Metrics Across All Folds ===")
for metric_label, fold_values, spec in (
    ("MSE", all_mse, ".6f"),
    ("RMSE", all_rmse, ".6f"),
    ("MAE", all_mae, ".6f"),
    ("R²", all_r2, ".4f"),
    ("Accuracy", all_accuracy, ".4f"),
    ("Precision", all_precision, ".4f"),
    ("Recall", all_recall, ".4f"),
):
    print(f"Average {metric_label}: {np.mean(fold_values):{spec}}")

print("\nScript execution complete. Please check the console output for metrics and the plots for visualization.")

GRU

In [None]:
import numpy as np
import xarray as xr
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (mean_squared_error, r2_score, mean_absolute_error,
                             accuracy_score, precision_score, recall_score, confusion_matrix)
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Dense, Dropout # Changed LSTM to GRU
from tensorflow.keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt
import os
import seaborn as sns

# === 1. Data Loading and Preprocessing ===
def load_time_series_data(filepath, downsample_factor=4):
    """
    Load and preprocess a mean-salinity time series from a NetCDF file.

    Reads the 'SALT' variable and keeps index 0 of its second axis (the first
    depth layer). It then keeps every ``downsample_factor``-th point along both
    spatial axes and flips the first spatial axis.

    Missing values in each time slice are filled with that slice's median. A
    slice that is entirely NaN has no median of its own, so it is filled with
    the median of all valid values in the dataset instead.

    Each slice is then averaged to a single value, giving a 1D time series for
    GRU input. Finally, the series is standardized.

    Args:
        filepath (str): Path to the NetCDF file.
        downsample_factor (int): Stride used to downsample both spatial dimensions.

    Returns:
        tuple: (time_series, scaler) where time_series is the standardized series
               with shape (n_time_steps, 1), and scaler is the fitted StandardScaler
               used for inverse transformation.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        ValueError: If the SALT data has no valid (non-NaN) values, or the resulting
            time series is empty or all-NaN.
    """
    if not os.path.exists(filepath):
        print(f"Error: Dataset file not found at '{filepath}'")
        print("Please ensure the 'salinity.nc' file is in the specified path.")
        raise FileNotFoundError(f"Dataset file not found at {filepath}")

    print(f"Loading dataset from: {filepath}")
    # .values materialises the variable as an in-memory numpy array, so the
    # file handle can be released as soon as it has been read.
    with xr.open_dataset(filepath, decode_times=False) as ds:
        salt_data = ds['SALT'].values[:, 0, :, :] # Remove depth dimension

    print(f"Original SALT data shape: {salt_data.shape}")

    # Downsample and flip vertically to correct orientation
    salt_data = salt_data[:, ::downsample_factor, ::downsample_factor]
    salt_data = np.flip(salt_data, axis=1)

    print(f"Downsampled and flipped SALT data shape: {salt_data.shape}")

    # Impute missing values on the 3D spatial data first, so that each slice
    # has no NaNs before taking the mean.
    print("Imputing missing values in 3D SALT data using median strategy per slice...")
    valid_mask = ~np.isnan(salt_data)
    if not valid_mask.any():
        raise ValueError("SALT data contains no valid (non-NaN) values, cannot impute missing values.")
    # Fallback for all-NaN slices: SimpleImputer drops a column with no observed
    # values, so its output would be empty and the reshape below would fail.
    global_median = float(np.median(salt_data[valid_mask]))

    imputer = SimpleImputer(strategy='median')
    for i in range(salt_data.shape[0]):
        if not valid_mask[i].any():
            print(f"Warning: Entire slice {i} is NaN. Filling it with the global median ({global_median:.4f}).")
            salt_data[i] = global_median
            continue
        # Treat the flattened slice as one feature column so the median is per slice.
        salt_data[i] = imputer.fit_transform(salt_data[i].reshape(-1, 1)).reshape(salt_data[i].shape)
    print("Missing values imputed in 3D data.")

    # Compute the mean salinity for each time step to get a 1D time series,
    # reshaped to 2D (time steps, features) for StandardScaler.
    time_series = np.nanmean(salt_data, axis=(1, 2)).reshape(-1, 1)

    print(f"Time series shape after mean aggregation: {time_series.shape}")

    # Standardize the data
    print("Standardizing time series data...")
    scaler = StandardScaler()
    if time_series.size == 0 or np.all(np.isnan(time_series)):
        raise ValueError("Time series data is empty or contains only NaNs after preprocessing, cannot standardize.")
    time_series = scaler.fit_transform(time_series)
    print("Data standardization complete.")

    return time_series, scaler

# === 2. GRU Model Architecture ===
def create_gru_model(input_shape, units=(100, 80, 60), dropout_rate=0.2):
    """
    Creates a Sequential stacked-GRU model for time series prediction.

    One GRU layer is added per entry of ``units``, each followed by Dropout.
    Every GRU layer except the last returns full sequences, so the next GRU layer
    receives one vector per time step. The last layer returns only its final
    output, which feeds a single-unit Dense regression head.

    Args:
        input_shape (tuple): Shape of one input sample, (seq_length, n_features).
        units (sequence of int): Number of GRU units for each layer, in order.
            The default builds three layers of 100, 80 and 60 units.
        dropout_rate (float): Dropout rate applied after every GRU layer.

    Returns:
        tensorflow.keras.models.Sequential: Model compiled with the Adam optimizer and MSE loss.

    Raises:
        ValueError: If ``units`` is empty.
    """
    # A tuple default avoids the shared-mutable-default pitfall; any sequence is accepted.
    units = list(units)
    if not units:
        raise ValueError("units must contain at least one layer size.")

    print("\nBuilding GRU model...")
    model = Sequential()

    last_layer = len(units) - 1
    for layer_idx, n_units in enumerate(units):
        # Only the first layer declares the input shape.
        first_layer_kwargs = {"input_shape": input_shape} if layer_idx == 0 else {}
        model.add(GRU(n_units, activation='tanh',
                      return_sequences=layer_idx < last_layer,
                      **first_layer_kwargs))
        model.add(Dropout(dropout_rate))

    # Output layer for regression (single continuous value prediction)
    model.add(Dense(1))

    # Compile the model
    model.compile(optimizer='adam', loss='mse')
    print("GRU model built and compiled.")
    model.summary()
    return model

# === 3. Sequence Creation ===
def create_sequences(data, seq_length):
    """
    Slice a series into sliding input windows and their next-step targets.

    Window k is ``data[k:k + seq_length]`` and its target is
    ``data[k + seq_length]``, for every k that leaves a target available.

    Args:
        data (np.ndarray): The 1D or 2D time series data, time along axis 0.
        seq_length (int): Number of consecutive steps in each input window.

    Returns:
        tuple: (X, y) numpy arrays of windows and targets. Both are empty when
        ``len(data) <= seq_length``.
    """
    window_count = len(data) - seq_length
    windows = [data[start:start + seq_length] for start in range(window_count)]
    targets = [data[start + seq_length] for start in range(window_count)]
    return np.array(windows), np.array(targets)

# === 4. Main Execution ===
if __name__ == "__main__":
    # Configuration
    FILE_PATH = "/content/drive/MyDrive/salinity.nc"
    SEQ_LENGTH = 30
    BATCH_SIZE = 64
    EPOCHS = 200
    # Fraction of the training windows held out for early stopping. Keras takes
    # them from the END of the training arrays (before any shuffling), so the
    # validation data still precedes the test period.
    VALIDATION_SPLIT = 0.1

    # Load data. On failure, raise SystemExit(1) so the run ends with a non-zero
    # exit status instead of reporting success.
    try:
        time_series_data, scaler = load_time_series_data(FILE_PATH)
    except FileNotFoundError as e:
        print(e)
        print("Exiting program. Please provide the correct dataset path.")
        raise SystemExit(1) from e
    except ValueError as e:
        print(f"Data preprocessing failed: {e}")
        print("Exiting program.")
        raise SystemExit(1) from e

    # Create sequences: each X sample is SEQ_LENGTH consecutive steps, y is the next step
    X, y = create_sequences(time_series_data, SEQ_LENGTH)

    # Train-test split (maintaining temporal order)
    split_idx = int(0.8 * len(X))
    X_train, X_test = X[:split_idx], X[split_idx:]
    y_train, y_test = y[:split_idx], y[split_idx:]
    if len(X_train) == 0 or len(X_test) == 0:
        raise ValueError(
            f"Not enough data for a train/test split: {len(time_series_data)} time steps "
            f"yield {len(X)} sequences of length {SEQ_LENGTH}."
        )

    print(f"\nTraining data shape: X_train={X_train.shape}, y_train={y_train.shape}")
    print(f"Test data shape: X_test={X_test.shape}, y_test={y_test.shape}")

    # Create GRU model; the input shape (SEQ_LENGTH, n_features) is taken from the data
    model = create_gru_model(input_shape=X_train.shape[1:])

    # Callbacks
    early_stop = EarlyStopping(monitor='val_loss',
                               patience=15,
                               restore_best_weights=True)

    # Train model. Early stopping selects weights on a held-out slice of the
    # training data, never on the test set that is evaluated below.
    print("\nStarting model training...")
    history = model.fit(
        X_train, y_train,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_split=VALIDATION_SPLIT,
        callbacks=[early_stop],
        verbose=1
    )
    print("Model training finished.")

    # Evaluation
    print("\nEvaluating model on test set...")
    y_pred = model.predict(X_test)

    # Inverse transform to get original scale values for meaningful metrics
    y_test_orig = scaler.inverse_transform(y_test)
    y_pred_orig = scaler.inverse_transform(y_pred)

    # --- Regression Metrics ---
    mse = mean_squared_error(y_test_orig, y_pred_orig)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_test_orig, y_pred_orig)
    r2 = r2_score(y_test_orig, y_pred_orig)

    print("\n📊 Final Regression Test Metrics:")
    print(f"✅ MSE (Mean Squared Error): {mse:.6f}")
    print(f"✅ RMSE (Root Mean Squared Error): {rmse:.6f}")
    print(f"✅ MAE (Mean Absolute Error): {mae:.6f}")
    print(f"✅ R² (R-squared): {r2:.4f}")

    # --- Classification Metrics (for Demonstration Purposes Only) ---
    print("\n🔄 Classification Metrics (Derived for Demonstration - See Note Below):")

    classification_threshold = np.median(y_test_orig)
    print(f"  (Using classification threshold: {classification_threshold:.4f} based on median of true test values)")

    y_test_binary = (y_test_orig > classification_threshold).astype(int).ravel()
    y_pred_binary = (y_pred_orig > classification_threshold).astype(int).ravel()

    accuracy = accuracy_score(y_test_binary, y_pred_binary)
    precision = precision_score(y_test_binary, y_pred_binary, zero_division=0)
    recall = recall_score(y_test_binary, y_pred_binary, zero_division=0)
    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent from both
    # arrays, matching the two tick labels of the heatmap below.
    conf_matrix = confusion_matrix(y_test_binary, y_pred_binary, labels=[0, 1])

    print(f"✅ Accuracy:  {accuracy:.4f}")
    print(f"✅ Precision: {precision:.4f}")
    print(f"✅ Recall:    {recall:.4f}")
    print(f"✅ Confusion Matrix:\n{conf_matrix}")

    print("\n--- Important Note on Classification Metrics ---")
    print("The 'Accuracy', 'Precision', and 'Recall' metrics above are calculated by converting the continuous regression "
          "outputs into binary classes using an arbitrary threshold (the median of true test values in this case).")
    print("This transformation is done *only for demonstration* of these metrics. Your model's primary task is regression, "
          "and its performance should be primarily judged by MSE, RMSE, MAE, and R-squared.")
    print("Choosing a different threshold would likely change these classification metric values.")

    # === 5. Visualization ===
    print("\nGenerating plots for visualization...")
    plt.figure(figsize=(15, 6))

    # Subplot 1: True vs. Predicted Values Plot (Time Series)
    plt.subplot(1, 2, 1)
    plt.plot(y_test_orig, label='True Mean Salinity', color='blue', linestyle='-')
    plt.plot(y_pred_orig, label='Predicted Mean Salinity', color='red', linestyle='--')
    plt.title("True vs. Predicted Mean Salinity (Original Scale)")
    plt.xlabel("Time Step Index")
    plt.ylabel("Mean Salinity")
    plt.legend()
    plt.grid(True)

    # Subplot 2: Loss History
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title('Model Loss History (MSE)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    # Separate plot for Confusion Matrix
    plt.figure(figsize=(7, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Predicted 0', 'Predicted 1'],
                yticklabels=['Actual 0', 'Actual 1'])
    plt.title('Confusion Matrix for Binary Classification')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.show()

    print("\nScript execution complete. Check console for metrics and plots for visualization.")


LSTM

In [None]:
import numpy as np
import xarray as xr
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (mean_squared_error, r2_score, mean_absolute_error,
                             accuracy_score, precision_score, recall_score, confusion_matrix)
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt
import os
import seaborn as sns

# === 1. Data Loading and Preprocessing ===
def load_time_series_data(filepath, downsample_factor=4):
    """
    Load and preprocess a mean-salinity time series from a NetCDF file.

    Reads the 'SALT' variable and keeps index 0 of its second axis (the first
    depth layer). It then keeps every ``downsample_factor``-th point along both
    spatial axes and flips the first spatial axis.

    Missing values in each time slice are filled with that slice's median. A
    slice that is entirely NaN has no median of its own, so it is filled with
    the median of all valid values in the dataset instead.

    Each slice is then averaged to a single value, giving a 1D time series for
    LSTM input. Finally, the series is standardized.

    Args:
        filepath (str): Path to the NetCDF file.
        downsample_factor (int): Stride used to downsample both spatial dimensions.

    Returns:
        tuple: (time_series, scaler) where time_series is the standardized series
               with shape (n_time_steps, 1), and scaler is the fitted StandardScaler
               used for inverse transformation.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        ValueError: If the SALT data has no valid (non-NaN) values, or the resulting
            time series is empty or all-NaN.
    """
    if not os.path.exists(filepath):
        print(f"Error: Dataset file not found at '{filepath}'")
        print("Please ensure the 'salinity.nc' file is in the specified path.")
        raise FileNotFoundError(f"Dataset file not found at {filepath}")

    print(f"Loading dataset from: {filepath}")
    # .values materialises the variable as an in-memory numpy array, so the
    # file handle can be released as soon as it has been read.
    with xr.open_dataset(filepath, decode_times=False) as ds:
        salt_data = ds['SALT'].values[:, 0, :, :] # Remove depth dimension

    print(f"Original SALT data shape: {salt_data.shape}")

    # Downsample and flip vertically to correct orientation
    salt_data = salt_data[:, ::downsample_factor, ::downsample_factor]
    salt_data = np.flip(salt_data, axis=1)

    print(f"Downsampled and flipped SALT data shape: {salt_data.shape}")

    # Impute missing values on the 3D spatial data first, so that each slice
    # has no NaNs before taking the mean.
    print("Imputing missing values in 3D SALT data using median strategy per slice...")
    valid_mask = ~np.isnan(salt_data)
    if not valid_mask.any():
        raise ValueError("SALT data contains no valid (non-NaN) values, cannot impute missing values.")
    # Fallback for all-NaN slices: SimpleImputer drops a column with no observed
    # values, so its output would be empty and the reshape below would fail.
    global_median = float(np.median(salt_data[valid_mask]))

    imputer = SimpleImputer(strategy='median')
    for i in range(salt_data.shape[0]):
        if not valid_mask[i].any():
            print(f"Warning: Entire slice {i} is NaN. Filling it with the global median ({global_median:.4f}).")
            salt_data[i] = global_median
            continue
        # Treat the flattened slice as one feature column so the median is per slice.
        salt_data[i] = imputer.fit_transform(salt_data[i].reshape(-1, 1)).reshape(salt_data[i].shape)
    print("Missing values imputed in 3D data.")

    # Compute the mean salinity for each time step to get a 1D time series,
    # reshaped to 2D (time steps, features) for StandardScaler.
    time_series = np.mean(salt_data, axis=(1, 2)).reshape(-1, 1)

    print(f"Time series shape after mean aggregation: {time_series.shape}")

    # Standardize the data
    print("Standardizing time series data...")
    scaler = StandardScaler()
    if time_series.size == 0 or np.all(np.isnan(time_series)):
        raise ValueError("Time series data is empty or contains only NaNs after preprocessing, cannot standardize.")
    time_series = scaler.fit_transform(time_series)
    print("Data standardization complete.")

    return time_series, scaler

# === 2. LSTM Model Architecture ===
def create_lstm_model(input_shape, units=(100, 80, 60), dropout_rate=0.2, activation='relu'):
    """
    Creates a Sequential stacked-LSTM model for time series prediction.

    One LSTM layer is added per entry of ``units``, each followed by Dropout.
    Every LSTM layer except the last returns full sequences, so the next LSTM
    layer receives one vector per time step. The last layer returns only its
    final output, which feeds a single-unit Dense regression head.

    Args:
        input_shape (tuple): Shape of one input sample, (seq_length, n_features).
        units (sequence of int): Number of LSTM units for each layer, in order.
            The default builds three layers of 100, 80 and 60 units.
        dropout_rate (float): Dropout rate applied after every LSTM layer.
        activation (str): Activation passed to every LSTM layer (default 'relu').
            NOTE(review): the Keras LSTM docs state that the fused cuDNN kernel
            requires activation='tanh'. Confirm before changing, since this alters
            the trained model.

    Returns:
        tensorflow.keras.models.Sequential: Model compiled with the Adam optimizer and MSE loss.

    Raises:
        ValueError: If ``units`` is empty.
    """
    # A tuple default avoids the shared-mutable-default pitfall; any sequence is accepted.
    units = list(units)
    if not units:
        raise ValueError("units must contain at least one layer size.")

    print("\nBuilding LSTM model...")
    model = Sequential()

    last_layer = len(units) - 1
    for layer_idx, n_units in enumerate(units):
        # Only the first layer declares the input shape.
        first_layer_kwargs = {"input_shape": input_shape} if layer_idx == 0 else {}
        model.add(LSTM(n_units, activation=activation,
                       return_sequences=layer_idx < last_layer,
                       **first_layer_kwargs))
        model.add(Dropout(dropout_rate))

    # Output layer for regression (single continuous value prediction)
    model.add(Dense(1))

    # Compile the model
    model.compile(optimizer='adam', loss='mse')
    print("LSTM model built and compiled.")
    model.summary()
    return model

# === 3. Sequence Creation ===
def create_sequences(data, seq_length):
    """
    Pair every run of ``seq_length`` consecutive steps with the step after it.

    For each position p from ``seq_length`` up to ``len(data) - 1``, the input is
    ``data[p - seq_length:p]`` and the target is ``data[p]``.

    Args:
        data (np.ndarray): The 1D or 2D time series data, time along axis 0.
        seq_length (int): Number of past time steps in each input window.

    Returns:
        tuple: (X, y) numpy arrays of input windows and next-step targets. Both
        are empty when the series is not longer than ``seq_length``.
    """
    inputs, targets = [], []
    pos = seq_length
    while pos < len(data):
        inputs.append(data[pos - seq_length:pos])
        targets.append(data[pos])
        pos += 1
    return np.array(inputs), np.array(targets)

# === 4. Main Execution ===
if __name__ == "__main__":
    # Configuration
    FILE_PATH = "/content/drive/MyDrive/salinity.nc"
    SEQ_LENGTH = 30
    BATCH_SIZE = 64
    EPOCHS = 200
    # Fraction of the training windows held out for early stopping. Keras takes
    # them from the END of the training arrays (before any shuffling), so the
    # validation data still precedes the test period.
    VALIDATION_SPLIT = 0.1

    # Load data. On failure, raise SystemExit(1) so the run ends with a non-zero
    # exit status instead of reporting success.
    try:
        time_series_data, scaler = load_time_series_data(FILE_PATH)
    except FileNotFoundError as e:
        print(e)
        print("Exiting program. Please provide the correct dataset path.")
        raise SystemExit(1) from e
    except ValueError as e:
        print(f"Data preprocessing failed: {e}")
        print("Exiting program.")
        raise SystemExit(1) from e

    # Create sequences: each X sample is SEQ_LENGTH consecutive steps, y is the next step
    X, y = create_sequences(time_series_data, SEQ_LENGTH)

    # Train-test split (maintaining temporal order)
    split_idx = int(0.8 * len(X))
    X_train, X_test = X[:split_idx], X[split_idx:]
    y_train, y_test = y[:split_idx], y[split_idx:]
    if len(X_train) == 0 or len(X_test) == 0:
        raise ValueError(
            f"Not enough data for a train/test split: {len(time_series_data)} time steps "
            f"yield {len(X)} sequences of length {SEQ_LENGTH}."
        )

    print(f"\nTraining data shape: X_train={X_train.shape}, y_train={y_train.shape}")
    print(f"Test data shape: X_test={X_test.shape}, y_test={y_test.shape}")

    # Create LSTM model; the input shape (SEQ_LENGTH, n_features) is taken from the data
    model = create_lstm_model(input_shape=X_train.shape[1:])

    # Callbacks
    early_stop = EarlyStopping(monitor='val_loss',
                               patience=15,
                               restore_best_weights=True)

    # Train model. Early stopping selects weights on a held-out slice of the
    # training data, never on the test set that is evaluated below.
    print("\nStarting model training...")
    history = model.fit(
        X_train, y_train,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_split=VALIDATION_SPLIT,
        callbacks=[early_stop],
        verbose=1
    )
    print("Model training finished.")

    # Evaluation
    print("\nEvaluating model on test set...")
    y_pred = model.predict(X_test)

    # Inverse transform to get original scale values for meaningful metrics
    y_test_orig = scaler.inverse_transform(y_test)
    y_pred_orig = scaler.inverse_transform(y_pred)

    # --- Regression Metrics ---
    mse = mean_squared_error(y_test_orig, y_pred_orig)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_test_orig, y_pred_orig)
    r2 = r2_score(y_test_orig, y_pred_orig)

    print("\n📊 Final Regression Test Metrics:")
    print(f"✅ MSE (Mean Squared Error): {mse:.6f}")
    print(f"✅ RMSE (Root Mean Squared Error): {rmse:.6f}")
    print(f"✅ MAE (Mean Absolute Error): {mae:.6f}")
    print(f"✅ R² (R-squared): {r2:.4f}")

    # --- Classification Metrics (for Demonstration Purposes Only) ---
    print("\n🔄 Classification Metrics (Derived for Demonstration - See Note Below):")

    classification_threshold = np.median(y_test_orig)
    print(f"  (Using classification threshold: {classification_threshold:.4f} based on median of true test values)")

    y_test_binary = (y_test_orig > classification_threshold).astype(int).ravel()
    y_pred_binary = (y_pred_orig > classification_threshold).astype(int).ravel()

    accuracy = accuracy_score(y_test_binary, y_pred_binary)
    precision = precision_score(y_test_binary, y_pred_binary, zero_division=0)
    recall = recall_score(y_test_binary, y_pred_binary, zero_division=0)
    # labels=[0, 1] keeps the matrix 2x2 even if a class is absent from both
    # arrays, matching the two tick labels of the heatmap below.
    conf_matrix = confusion_matrix(y_test_binary, y_pred_binary, labels=[0, 1])

    print(f"✅ Accuracy:  {accuracy:.4f}")
    print(f"✅ Precision: {precision:.4f}")
    print(f"✅ Recall:    {recall:.4f}")
    print(f"✅ Confusion Matrix:\n{conf_matrix}")

    print("\n--- Important Note on Classification Metrics ---")
    print("The 'Accuracy', 'Precision', and 'Recall' metrics above are calculated by converting the continuous regression "
          "outputs into binary classes using an arbitrary threshold (the median of true test values in this case).")
    print("This transformation is done *only for demonstration* of these metrics. Your model's primary task is regression, "
          "and its performance should be primarily judged by MSE, RMSE, MAE, and R-squared.")
    print("Choosing a different threshold would likely change these classification metric values.")

    # === 5. Visualization ===
    print("\nGenerating plots for visualization...")
    plt.figure(figsize=(15, 6))

    # Subplot 1: True vs. Predicted Values Plot (Time Series)
    plt.subplot(1, 2, 1)
    plt.plot(y_test_orig, label='True Mean Salinity', color='blue', linestyle='-')
    plt.plot(y_pred_orig, label='Predicted Mean Salinity', color='red', linestyle='--')
    plt.title("True vs. Predicted Mean Salinity (Original Scale)")
    plt.xlabel("Time Step Index")
    plt.ylabel("Mean Salinity")
    plt.legend()
    plt.grid(True)

    # Subplot 2: Loss History
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title('Model Loss History (MSE)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    # Separate plot for Confusion Matrix
    plt.figure(figsize=(7, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Predicted 0', 'Predicted 1'],
                yticklabels=['Actual 0', 'Actual 1'])
    plt.title('Confusion Matrix for Binary Classification')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.show()

    print("\nScript execution complete. Check console for metrics and plots for visualization.")


ANN

In [None]:
import numpy as np
import xarray as xr
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (mean_squared_error, r2_score, mean_absolute_error, # Added mean_absolute_error
                             accuracy_score, precision_score, recall_score, confusion_matrix) # Added classification metrics and confusion_matrix
from tensorflow.keras.models import Sequential # Using tensorflow.keras
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.regularizers import l2
import matplotlib.pyplot as plt
import os
import seaborn as sns # Added for confusion matrix visualization

# === 1. Data Loading and Preprocessing ===
def load_time_series_data(filepath, downsample_factor=4):
    """
    Load and preprocess time series data from NetCDF file.

    Extracts the 'SALT' variable (index 0 of its second axis, treated as the
    depth layer), downsamples both spatial axes, flips the latitude axis,
    median-imputes missing values within each time slice, averages every slice
    into a single value per time step, and standardizes the resulting series.

    Time slices that are entirely NaN cannot be imputed spatially. They are
    left as NaN, and the corresponding time steps are filled with the median
    of the aggregated series.

    Args:
        filepath (str): Path to the NetCDF file.
        downsample_factor (int): Stride used to downsample both spatial axes.

    Returns:
        tuple: (time_series, scaler) where time_series is the standardized
               array of shape (n_time_steps, 1) and scaler is the fitted
               StandardScaler used for inverse transformation.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        ValueError: If the series is empty or entirely NaN after preprocessing.
    """
    if not os.path.exists(filepath):
        print(f"Error: Dataset file not found at '{filepath}'")
        print("Please ensure the 'salinity.nc' file is in the specified path.")
        raise FileNotFoundError(f"Dataset file not found at {filepath}")

    print(f"Loading dataset from: {filepath}")
    # The context manager releases the file handle. np.array() copies the selected
    # layer into an independent, writeable array that outlives the dataset.
    with xr.open_dataset(filepath, decode_times=False) as ds:
        # Keep index 0 of axis 1 (presumably depth) -> presumably (time, lat, lon)
        salt_data = np.array(ds['SALT'].values[:, 0, :, :])

    print(f"Original SALT data shape: {salt_data.shape}")

    # Downsample and flip vertically to correct orientation
    salt_data = salt_data[:, ::downsample_factor, ::downsample_factor]
    salt_data = np.flip(salt_data, axis=1)

    print(f"Downsampled and flipped SALT data shape: {salt_data.shape}")

    # Impute missing values slice by slice, using each slice's own median.
    print("Imputing missing values in 3D SALT data using median strategy per slice...")
    imputer = SimpleImputer(strategy='median')
    for i, time_slice in enumerate(salt_data):
        if np.all(np.isnan(time_slice)):
            # SimpleImputer discards all-NaN columns, so fit_transform would return
            # an empty array and the reshape below would fail. Leave the slice NaN
            # and fill the aggregated value for this step after averaging.
            print(f"Warning: Entire slice {i} is NaN. Skipping spatial imputation for this slice.")
            continue
        salt_data[i] = imputer.fit_transform(time_slice.reshape(-1, 1)).reshape(time_slice.shape)
    print("Missing values imputed in 3D data.")

    # One mean value per time step, as a (time steps, 1) column for StandardScaler.
    # Only the all-NaN slices skipped above can still produce NaN here.
    time_series = np.nanmean(salt_data, axis=(1, 2)).reshape(-1, 1)
    print(f"Time series shape after mean aggregation: {time_series.shape}")

    if time_series.size == 0 or np.all(np.isnan(time_series)):
        raise ValueError("Time series data is empty or contains only NaNs after preprocessing, cannot standardize.")

    missing_steps = np.isnan(time_series).ravel()
    if missing_steps.any():
        print(f"Warning: {int(missing_steps.sum())} time step(s) had no valid data; "
              "filling them with the median of the time series.")
        time_series[missing_steps] = np.nanmedian(time_series)

    # Standardize the data
    print("Standardizing time series data...")
    scaler = StandardScaler()
    time_series = scaler.fit_transform(time_series)
    print("Data standardization complete.")

    return time_series, scaler

# === 2. ANN Model Architecture ===
def create_ann_model(input_shape, units=(256, 128, 64), dropout_rate=0.3, l2_reg=0.01):
    """
    Create and compile a fully connected regression network.

    Each entry in ``units`` adds one ``Dense(relu, L2-regularized) -> Dropout``
    pair, in order, followed by a single linear output unit. The model is
    compiled with the Adam optimizer, MSE loss and an MAE metric.

    Args:
        input_shape (int): Number of input features (the flattened window size).
        units (Sequence[int]): Width of each hidden Dense layer, in order.
            Defaults to ``(256, 128, 64)``. A tuple is used to avoid a shared
            mutable default; any list or tuple of ints is accepted.
        dropout_rate (float): Dropout rate applied after every hidden layer.
        l2_reg (float): L2 kernel-regularization factor for the hidden layers.

    Returns:
        tensorflow.keras.models.Sequential: Compiled ANN model.
    """
    # Local import so this cell's import list does not need to change.
    from tensorflow.keras.layers import Input

    print("\nBuilding ANN model...")
    model = Sequential()
    # An explicit Input layer is the recommended way to declare the input size
    # for Sequential models, instead of passing `input_dim` to the first layer.
    model.add(Input(shape=(input_shape,)))

    # One hidden block per requested width. The original hard-coded exactly three.
    for width in units:
        model.add(Dense(width, activation='relu',
                        kernel_regularizer=l2(l2_reg)))
        model.add(Dropout(dropout_rate))

    # Output layer (linear activation for regression)
    model.add(Dense(1, activation='linear'))

    model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    print("ANN model built and compiled.")
    model.summary()
    return model

# === 3. Modified Sequence Creation for ANN ===
def create_ann_sequences(data, window_size):
    """
    Create sliding windows for an ANN (each window flattened to one row).

    Window ``i`` is ``data[i:i + window_size]`` flattened, and its target is
    ``data[i + window_size]``.

    Args:
        data (array-like): Time series of shape (n_steps,) or (n_steps, ...).
        window_size (int): Number of past time steps used to predict the next one.

    Returns:
        tuple: (X, y) where X has shape
               (n_steps - window_size, window_size * n_features_per_step) and
               y has shape (n_steps - window_size,) + data.shape[1:].
               When the series is not longer than ``window_size`` both arrays
               are empty but keep those trailing dimensions, so callers can
               still read ``X.shape[1]``.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be at least 1, got {window_size}")

    data = np.asarray(data)
    step_shape = data.shape[1:]
    flat_width = window_size * int(np.prod(step_shape))
    n_windows = len(data) - window_size

    if n_windows <= 0:
        # Too short for a single window: return shaped empties rather than
        # 1-D np.array([]), which has no second dimension.
        return (np.empty((0, flat_width), dtype=data.dtype),
                np.empty((0,) + step_shape, dtype=data.dtype))

    X = np.stack([data[i:i + window_size].ravel() for i in range(n_windows)])
    # Targets are exactly the steps following each window. Copy so y does not
    # alias the caller's array.
    y = data[window_size:].copy()
    return X, y

# === 4. Main Execution ===
if __name__ == "__main__":
    # End-to-end ANN run: load the mean-salinity series, window it, train with
    # early stopping, then report regression metrics, demonstration
    # classification metrics and plots on a chronologically held-out test set.

    # Configuration
    FILE_PATH = "/content/drive/MyDrive/salinity.nc"
    WINDOW_SIZE = 20
    BATCH_SIZE = 32
    EPOCHS = 200
    # Fraction of the training windows that Keras holds out for early stopping.
    # Keras takes them from the END of the training arrays (before shuffling),
    # so they precede the test windows in time and never overlap them.
    VALIDATION_SPLIT = 0.2

    # Load data. `raise SystemExit` halts execution both as a script and in a
    # notebook cell. In a notebook, the interactive `exit()` helper may not stop
    # the cell's remaining lines.
    try:
        time_series_data, scaler = load_time_series_data(FILE_PATH)
    except FileNotFoundError as e:
        print(e)
        print("Exiting program. Please provide the correct dataset path.")
        raise SystemExit(1) from e
    except ValueError as e:
        print(f"Data preprocessing failed: {e}")
        print("Exiting program.")
        raise SystemExit(1) from e

    # Create sequences (flattened for ANN)
    X, y = create_ann_sequences(time_series_data, WINDOW_SIZE)

    # Train-test split (maintaining temporal order for time series data)
    split_idx = int(0.8 * len(X))
    X_train, X_test = X[:split_idx], X[split_idx:]
    y_train, y_test = y[:split_idx], y[split_idx:]

    print(f"\nTraining data shape: X_train={X_train.shape}, y_train={y_train.shape}")
    print(f"Test data shape: X_test={X_test.shape}, y_test={y_test.shape}")

    # Input size for the ANN is the flattened window length (X_train.shape[1])
    model = create_ann_model(input_shape=X_train.shape[1])

    # Callbacks
    early_stop = EarlyStopping(monitor='val_loss',
                               patience=20,
                               restore_best_weights=True)

    # restore_best_weights keeps the epoch with the lowest validation loss, so the
    # validation data must be disjoint from the test set. Validating on X_test
    # would leak it into model selection and inflate the test metrics below.
    print("\nStarting model training...")
    history = model.fit(
        X_train, y_train,
        epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        validation_split=VALIDATION_SPLIT,
        callbacks=[early_stop],
        verbose=1,
    )
    print("Model training finished.")

    # Evaluation
    print("\nEvaluating model on test set...")
    y_pred = model.predict(X_test)

    # Back to the original salinity scale for interpretable metrics.
    y_test_orig = scaler.inverse_transform(y_test.reshape(-1, 1))
    y_pred_orig = scaler.inverse_transform(y_pred)  # model.predict already returns 2D

    # --- Regression Metrics ---
    mse = mean_squared_error(y_test_orig, y_pred_orig)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_test_orig, y_pred_orig)
    r2 = r2_score(y_test_orig, y_pred_orig)

    print("\n📊 Final Regression Test Metrics:")
    print(f"✅ MSE (Mean Squared Error): {mse:.6f}")
    print(f"✅ RMSE (Root Mean Squared Error): {rmse:.6f}")
    print(f"✅ MAE (Mean Absolute Error): {mae:.6f}")
    print(f"✅ R² (R-squared): {r2:.4f}")

    # --- Classification Metrics (for Demonstration Purposes Only) ---
    print("\n🔄 Classification Metrics (Derived for Demonstration - See Note Below):")

    # Threshold the regression output at the median of the true test values.
    classification_threshold = np.median(y_test_orig)
    print(f"  (Using classification threshold: {classification_threshold:.4f} based on median of true test values)")

    y_test_binary = (y_test_orig.ravel() > classification_threshold).astype(int)
    y_pred_binary = (y_pred_orig.ravel() > classification_threshold).astype(int)

    accuracy = accuracy_score(y_test_binary, y_pred_binary)
    precision = precision_score(y_test_binary, y_pred_binary, zero_division=0)
    recall = recall_score(y_test_binary, y_pred_binary, zero_division=0)
    # labels=[0, 1] keeps the matrix 2x2, matching the heatmap tick labels below,
    # even when one class is absent from both vectors.
    conf_matrix = confusion_matrix(y_test_binary, y_pred_binary, labels=[0, 1])

    print(f"✅ Accuracy:  {accuracy:.4f}")
    print(f"✅ Precision: {precision:.4f}")
    print(f"✅ Recall:    {recall:.4f}")
    print(f"✅ Confusion Matrix:\n{conf_matrix}")

    print("\n--- Important Note on Classification Metrics ---")
    print("The 'Accuracy', 'Precision', and 'Recall' metrics above are calculated by converting the continuous regression "
          "outputs into binary classes using an arbitrary threshold (the median of true test values in this case).")
    print("This transformation is done *only for demonstration* of these metrics. Your model's primary task is regression, "
          "and its performance should be primarily judged by MSE, RMSE, MAE, and R-squared.")
    print("Choosing a different threshold would likely change these classification metric values.")

    # === 5. Visualization ===
    print("\nGenerating plots for visualization...")
    plt.figure(figsize=(15, 6))

    # Subplot 1: True vs. Predicted Values Plot (Time Series)
    plt.subplot(1, 2, 1)
    plt.plot(y_test_orig.ravel(), label='True Mean Salinity', color='blue', linestyle='-')
    plt.plot(y_pred_orig.ravel(), label='Predicted Mean Salinity', color='red', linestyle='--')
    plt.title("True vs. Predicted Mean Salinity (Original Scale)")
    plt.xlabel("Time Step Index")
    plt.ylabel("Mean Salinity")
    plt.legend()
    plt.grid(True)

    # Subplot 2: Loss History (Keras' training loss includes the L2 penalty terms)
    plt.subplot(1, 2, 2)
    plt.plot(history.history['loss'], label='Training Loss')
    plt.plot(history.history['val_loss'], label='Validation Loss')
    plt.title('Model Loss History (MSE + L2 penalty)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    # Second figure: scatter plot and confusion matrix
    plt.figure(figsize=(15, 6))

    # Subplot 1: actual vs. predicted values with the ideal y = x line
    plt.subplot(1, 2, 1)
    plt.scatter(y_test_orig.ravel(), y_pred_orig.ravel(), alpha=0.6, color='royalblue', edgecolor='white', label='Predictions')
    max_val = max(y_test_orig.max(), y_pred_orig.max())
    min_val = min(y_test_orig.min(), y_pred_orig.min())
    plt.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Ideal Prediction')
    plt.annotate(f'$R^2$ = {r2:.4f}',
                 xy=(0.05, 0.85),
                 xycoords='axes fraction',
                 fontsize=14,
                 bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    plt.title('ANN Model Performance: Actual vs Predicted Salinity', fontsize=14, pad=20)
    plt.xlabel('Actual Salinity Values', fontsize=12)
    plt.ylabel('Predicted Salinity Values', fontsize=12)
    plt.legend(fontsize=12, loc='upper left')
    plt.grid(True, linestyle='--', alpha=0.3)

    # Subplot 2: Confusion Matrix
    plt.subplot(1, 2, 2)
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Predicted 0', 'Predicted 1'],
                yticklabels=['Actual 0', 'Actual 1'])
    plt.title('Confusion Matrix for Binary Classification')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.show()

    print("\nScript execution complete. Check console for metrics and plots for visualization.")


GAM

In [None]:
import numpy as np
import pandas as pd
import xarray as xr
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (mean_squared_error, r2_score, mean_absolute_error,
                             accuracy_score, precision_score, recall_score, confusion_matrix)
from pygam import LinearGAM, s, f
from pygam.terms import TermList
import matplotlib.pyplot as plt
import os
import seaborn as sns
from sklearn.model_selection import train_test_split

# === 1. Data Loading and Preprocessing for GAM ===
def load_and_prepare_data(filepath, n_lags=5):
    """
    Load and preprocess salinity data for GAM.

    Steps:
    - Extract the 'SALT' variable (index 0 of its second axis, treated as the
      depth layer) and average each time step over the remaining spatial axes.
    - Median-impute missing means and standardize the series.
    - Create lag features from the standardized salinity.
    - Add time features. If the 'TIME' coordinate is datetime-like, these are
      day of year and month. Otherwise its raw numeric values are used. If
      'TIME' is absent or its length does not match, the row index is used.

    Args:
        filepath (str): Path to the NetCDF file.
        n_lags (int): Number of lag features to create.

    Returns:
        tuple: (X, y, scaler) where X is the feature DataFrame,
               y is the standardized target array, and scaler is the
               StandardScaler fitted on the mean salinity series.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        ValueError: If 'SALT' is missing or not 4-dimensional, the mean series
            is empty or entirely NaN, or no rows remain after adding lags.
    """
    if not os.path.exists(filepath):
        print(f"Error: Dataset file not found at '{filepath}'")
        print("Please ensure the 'salinity.nc' file is in the specified path.")
        raise FileNotFoundError(f"Dataset file not found at {filepath}")

    print(f"Loading dataset from: {filepath}")
    # Everything needed from the file is extracted inside the context manager,
    # so the file handle is released before the in-memory feature engineering.
    with xr.open_dataset(filepath, decode_times=False) as ds:
        if 'SALT' not in ds:
            raise ValueError("'SALT' variable not found in the dataset.")
        # `[:, 0, :, :]` followed by a mean over axes (1, 2) yields a 1D series
        # only when SALT has exactly 4 dimensions. Extra dimensions would leave a
        # 2D result that cannot become a DataFrame column.
        if ds['SALT'].ndim != 4:
            raise ValueError("'SALT' variable expected to have exactly 4 dimensions (time, depth, lat, lon).")

        salt_data_3d = ds['SALT'].values[:, 0, :, :]
        # np.nanmean ignores NaN grid cells when averaging each time step.
        mean_salinity_ts = np.nanmean(salt_data_3d, axis=(1, 2))

        time_features = {}
        if 'TIME' in ds:
            time_coord = ds['TIME']
            # `.dt` exists only for datetime-like data. Depending on the xarray
            # version, non-datetime data raises AttributeError or TypeError, and
            # hasattr() only catches AttributeError, so probe it explicitly.
            try:
                dt_accessor = time_coord.dt
            except (AttributeError, TypeError):
                dt_accessor = None

            if len(time_coord.values) != len(mean_salinity_ts):
                skipped = "datetime features" if dt_accessor is not None else "numeric time feature"
                print(f"Warning: Length of 'TIME' coordinate does not match salinity data. Skipping {skipped}.")
            elif dt_accessor is not None:
                time_features['day_of_year'] = dt_accessor.dayofyear.values
                time_features['month'] = dt_accessor.month.values
            else:
                # Not datetime-like (e.g. raw numbers with decode_times=False)
                time_features['time_numeric'] = time_coord.values

    # Create a DataFrame for processing
    df = pd.DataFrame({'salinity_value': mean_salinity_ts})

    if df['salinity_value'].isna().all():
        # SimpleImputer would drop an all-NaN column instead of filling it.
        raise ValueError("Mean salinity time series is empty or contains only NaNs.")

    # Impute missing values in the 1D mean salinity time series
    print("Imputing missing values in time series data using median strategy...")
    imputer = SimpleImputer(strategy='median')
    df['salinity_value'] = imputer.fit_transform(df[['salinity_value']])
    print("Missing values imputed.")

    # Standardize the salinity time series
    scaler = StandardScaler()
    df['salinity_scaled'] = scaler.fit_transform(df[['salinity_value']])
    print("Salinity time series standardized.")

    # Lag features from the *standardized* salinity. The target is the current
    # `salinity_scaled`; the features are its lags plus time features.
    print(f"Creating {n_lags} lag features...")
    for i in range(1, n_lags + 1):
        df[f'lag_salinity_{i}'] = df['salinity_scaled'].shift(i)
    print("Lag features created.")

    # Add time features
    for name, values in time_features.items():
        df[name] = values
    if not time_features:
        print("Warning: 'TIME' coordinate not found or not suitable for datetime features. Using numeric index as a fallback time feature.")
        df['time_numeric'] = np.arange(len(df))  # Fallback to numeric index

    # Drop rows with NaNs created by shifting
    initial_rows = len(df)
    df = df.dropna()
    print(f"Dropped {initial_rows - len(df)} rows due to NaNs created by lag features or missing time data.")
    if df.empty:
        raise ValueError(f"No rows remain after creating {n_lags} lag features; "
                         f"the series needs more than {n_lags} time steps.")

    # Define features (X) and target (y)
    y = df['salinity_scaled'].values
    X_columns = [col for col in df.columns
                 if col.startswith('lag_salinity_') or col in ('day_of_year', 'month', 'time_numeric')]
    X = df[X_columns]

    print(f"Features (X) shape: {X.shape}")
    print(f"Target (y) shape: {y.shape}")
    return X, y, scaler

# === 2. GAM Model Creation ===
def create_gam_model(X_df):
    """
    Build an unfitted LinearGAM with one term per column of ``X_df``.

    Term types:
    - ``month`` becomes a factor (categorical) term.
    - Lag columns, ``time_numeric`` and ``day_of_year`` become smoothing-spline
      terms.
    - Any other column also gets a smoothing-spline term, and a warning is
      printed.

    Args:
        X_df (pandas.DataFrame): Feature frame. Only its column names and order
            are used; term ``i`` refers to column ``i`` of the array the model
            is later fitted on.

    Returns:
        pygam.LinearGAM: The unfitted model.

    Raises:
        ValueError: If ``X_df`` has no columns.
    """
    print("\nBuilding GAM model with smoothing and categorical terms...")
    terms = []
    for i, col in enumerate(X_df.columns):
        if col == 'month':
            # pygam's f() takes no n_splines argument (passing one raises
            # TypeError). The factor levels are inferred from the data at fit time.
            terms.append(f(i))
        elif col.startswith('lag_salinity_') or col in ('time_numeric', 'day_of_year'):
            terms.append(s(i))  # Numerical feature with smoothing spline
        else:
            # Fallback for any other unexpected columns, treat as smoothing
            print(f"Warning: Unexpected column '{col}'. Treating as smoothing term.")
            terms.append(s(i))

    if not terms:
        raise ValueError("Cannot build a GAM: the feature DataFrame has no columns.")

    gam = LinearGAM(terms=TermList(*terms))
    print("GAM model terms defined.")
    return gam

# === 3. Main Execution ===
if __name__ == "__main__":
    # End-to-end GAM run: build lag/time features, fit with gridsearch on a
    # chronological training split, then report metrics and plots (including
    # per-term partial dependence) on the held-out test split.

    # Term classes used below to choose how each partial dependence is plotted.
    from pygam.terms import FactorTerm, SplineTerm

    # Configuration
    FILE_PATH = "/content/drive/MyDrive/salinity.nc"
    N_LAGS = 5  # Number of lag features
    TEST_SIZE = 0.2  # 20% for testing

    # Load and prepare data. `raise SystemExit` halts execution both as a script
    # and in a notebook cell. In a notebook, the interactive `exit()` helper may
    # not stop the cell's remaining lines.
    try:
        X, y, scaler = load_and_prepare_data(FILE_PATH, n_lags=N_LAGS)
    except FileNotFoundError as e:
        print(e)
        print("Exiting program. Please provide the correct dataset path.")
        raise SystemExit(1) from e
    except ValueError as e:
        print(f"Data preprocessing failed: {e}")
        print("Exiting program.")
        raise SystemExit(1) from e
    except Exception as e:
        print(f"An unexpected error occurred during data loading: {e}")
        print("Exiting program.")
        raise SystemExit(1) from e

    # Train-test split (shuffle=False keeps temporal order: the test set is the last 20%)
    print(f"\nSplitting data into training and test sets (test_size={TEST_SIZE})...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=TEST_SIZE, shuffle=False
    )
    print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
    print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

    # Create GAM model (the DataFrame is passed for its column names)
    gam = create_gam_model(X_train)

    # Fit model using gridsearch for automatic parameter tuning.
    # pygam expects numpy arrays for X and y.
    print("Fitting GAM model using gridsearch...")
    gam.gridsearch(X_train.values, y_train)
    print("GAM model fitting complete.")

    # statistics_['pseudo_r2'] is a mapping of pseudo R-squared variants computed
    # on the fitted (training) data, not a cross-validation score. Prefer the
    # explained-deviance entry.
    pseudo_r2 = gam.statistics_.get('pseudo_r2')
    if isinstance(pseudo_r2, dict):
        pseudo_r2 = pseudo_r2.get('explained_deviance', next(iter(pseudo_r2.values()), None))
    if pseudo_r2 is None:
        print("Pseudo R-squared not found in gam.statistics_ after gridsearch.")
    else:
        print(f"Training pseudo R-squared (explained deviance): {pseudo_r2:.4f}")

    # Evaluation
    print("\nEvaluating GAM model on test set...")
    y_pred = gam.predict(X_test.values)

    # Back to the original salinity scale for metrics and plotting.
    y_test_orig = scaler.inverse_transform(y_test.reshape(-1, 1)).flatten()
    y_pred_orig = scaler.inverse_transform(y_pred.reshape(-1, 1)).flatten()

    # --- Regression Metrics ---
    mse = mean_squared_error(y_test_orig, y_pred_orig)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_test_orig, y_pred_orig)
    r2 = r2_score(y_test_orig, y_pred_orig)

    print("\n📊 Final Regression Test Metrics:")
    print(f"✅ MSE (Mean Squared Error): {mse:.6f}")
    print(f"✅ RMSE (Root Mean Squared Error): {rmse:.6f}")
    print(f"✅ MAE (Mean Absolute Error): {mae:.6f}")
    print(f"✅ R² (R-squared): {r2:.4f}")

    # --- Classification Metrics (for Demonstration Purposes Only) ---
    print("\n🔄 Classification Metrics (Derived for Demonstration - See Note Below):")

    # Threshold the regression output at the median of the true test values.
    classification_threshold = np.median(y_test_orig)
    print(f"  (Using classification threshold: {classification_threshold:.4f} based on median of true test values)")

    y_test_binary = (y_test_orig > classification_threshold).astype(int)
    y_pred_binary = (y_pred_orig > classification_threshold).astype(int)

    accuracy = accuracy_score(y_test_binary, y_pred_binary)
    precision = precision_score(y_test_binary, y_pred_binary, zero_division=0)
    recall = recall_score(y_test_binary, y_pred_binary, zero_division=0)
    # labels=[0, 1] keeps the matrix 2x2, matching the heatmap tick labels below.
    conf_matrix = confusion_matrix(y_test_binary, y_pred_binary, labels=[0, 1])

    print(f"✅ Accuracy:  {accuracy:.4f}")
    print(f"✅ Precision: {precision:.4f}")
    print(f"✅ Recall:    {recall:.4f}")
    print(f"✅ Confusion Matrix:\n{conf_matrix}")

    print("\n--- Important Note on Classification Metrics ---")
    print("The 'Accuracy', 'Precision', and 'Recall' metrics above are calculated by converting the continuous regression "
          "outputs into binary classes using an arbitrary threshold (the median of true test values in this case).")
    print("This transformation is done *only for demonstration* of these metrics. Your model's primary task is regression, "
          "and its performance should be primarily judged by MSE, RMSE, MAE, and R-squared.")
    print("Choosing a different threshold would likely change these classification metric values.")

    # === 5. Visualization ===
    print("\nGenerating plots for visualization...")

    # Plot 1: True vs. Predicted Values Plot (Time Series - Original Scale)
    plt.figure(figsize=(12, 6))
    plt.plot(y_test_orig, label='True Mean Salinity', color='blue', linestyle='-')
    plt.plot(y_pred_orig, label='Predicted Mean Salinity', color='red', linestyle='--')
    plt.title("True vs. Predicted Mean Salinity (Original Scale)")
    plt.xlabel("Time Step Index")
    plt.ylabel("Mean Salinity")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Plot 2: Scatter plot of actual vs predicted values with R^2
    plt.figure(figsize=(8, 8))
    plt.scatter(y_test_orig, y_pred_orig, alpha=0.6, color='teal', edgecolor='white', label='Predictions')
    max_val = max(y_test_orig.max(), y_pred_orig.max())
    min_val = min(y_test_orig.min(), y_pred_orig.min())
    plt.plot([min_val, max_val], [min_val, max_val], 'r--', lw=2, label='Ideal Prediction')
    plt.annotate(f'$R^2$ = {r2:.4f}',
                 xy=(0.05, 0.85),
                 xycoords='axes fraction',
                 fontsize=14,
                 bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    plt.title('GAM Model Performance: Actual vs Predicted Values', fontsize=14, pad=20)
    plt.xlabel('Actual Values', fontsize=12)
    plt.ylabel('Predicted Values', fontsize=12)
    plt.legend(fontsize=12, loc='upper left')
    plt.grid(True, linestyle='--', alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Plot 3: Confusion Matrix
    plt.figure(figsize=(7, 6))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=['Predicted 0', 'Predicted 1'],
                yticklabels=['Actual 0', 'Actual 1'])
    plt.title('Confusion Matrix for Binary Classification')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.show()

    # Plot 4: Partial dependence plots for GAM
    print("\nGenerating Partial Dependence Plots for GAM terms...")
    try:
        titles = X_train.columns  # term.feature indexes these columns
        n_features = len(titles)
        n_cols = 3  # Max 3 plots per row
        rows = int(np.ceil(n_features / n_cols))

        plt.figure(figsize=(15, rows * 5))
        # Subplot positions come from this counter rather than the term index,
        # so they do not depend on where the intercept term sits in gam.terms.
        plot_idx = 0
        for i, term in enumerate(gam.terms):
            if term.isintercept:
                continue

            feature_index = term.feature
            if feature_index < n_features:
                feature_name = titles[feature_index]
            else:
                feature_name = f"Feature Index {feature_index} (Unknown)"
                print(f"Warning: Feature index {feature_index} out of bounds for titles.")

            # Check FactorTerm first in case it is a subclass of SplineTerm.
            if isinstance(term, FactorTerm):
                # Evaluate once per level seen in training instead of on a dense grid.
                levels = np.unique(X_train.iloc[:, feature_index].to_numpy())
                XX = np.zeros((len(levels), n_features))
                XX[:, feature_index] = levels
                pdep = gam.partial_dependence(term=i, X=XX)

                plot_idx += 1
                plt.subplot(rows, n_cols, plot_idx)
                plt.bar(range(len(levels)), pdep)
                plt.xticks(range(len(levels)), [str(int(level)) for level in levels])
                plt.grid(axis='y', linestyle='--', alpha=0.6)
            elif isinstance(term, SplineTerm):
                XX = gam.generate_X_grid(term=i)
                # With `width`, pygam returns (pdep, intervals). The intervals
                # array holds the lower bound in column 0 and the upper in column 1.
                pdep, confi = gam.partial_dependence(term=i, X=XX, width=0.95)

                plot_idx += 1
                plt.subplot(rows, n_cols, plot_idx)
                plt.plot(XX[:, feature_index], pdep, label='Partial Dependence')
                plt.fill_between(XX[:, feature_index], confi[:, 0], confi[:, 1],
                                 alpha=0.2, label='95% Confidence')
                plt.grid(True, linestyle='--', alpha=0.6)
                plt.legend()
            else:
                # Other term types are not plotted.
                continue

            plt.title(f'Partial Dependence of {feature_name}')
            plt.xlabel(feature_name)
            plt.ylabel('Partial Dependence')

        plt.tight_layout()
        plt.show()

    except Exception as e:
        # Plotting is best-effort: report the problem but keep the metrics above.
        print(f"Error generating partial dependence plots: {e}")

    print("\nScript execution complete. Check console for metrics and plots for visualization.")

In [None]:
!pip install pygam
