In [54]:
# Import third-party libraries
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from sklearn.feature_selection import mutual_info_regression
from matplotlib.gridspec import GridSpec
from matplotlib.gridspec import GridSpecFromSubplotSpec

# Import local modules
from recovar import RepresentationLearningSingleAutoencoder, RepresentationLearningDenoisingSingleAutoencoder
from directory import get_checkpoint_path
from config import BATCH_SIZE, N_CHANNELS
from kfold_environment import KFoldEnvironment

In [55]:
# -------------------------------
# Configuration and Setup
# -------------------------------

# Experiment name identifier
EXP_NAME = "exp_test"

# Choose the representation learning model class
REPRESENTATION_LEARNING_MODEL_CLASS = RepresentationLearningSingleAutoencoder

# Specify training and testing datasets ('stead' or 'instance')
TRAIN_DATASET = "stead"
TEST_DATASET = "stead"

# Number of training epochs
EPOCH = 6

# Data split identifier
SPLIT = 0

In [71]:
# -------------------------------
# Define Utility Functions
# -------------------------------

def gaussian_1d(x, mean, std_dev):
    """
    Compute the 1D Gaussian (Normal) probability density function.

    Parameters:
    - x (float or np.ndarray): The input value(s).
    - mean (float): Mean of the distribution.
    - std_dev (float): Standard deviation of the distribution.

    Returns:
    - float or np.ndarray: Probability density at x.
    """
    coefficient = 1.0 / (std_dev * np.sqrt(2 * np.pi))
    exponent = -0.5 * ((x - mean) / std_dev) ** 2
    return coefficient * np.exp(exponent)

def multivariate_gaussian(x, mean, cov):
    """
    Compute the Multivariate Gaussian probability density function.

    Parameters:
    - x (np.ndarray): Input vector of shape (k,).
    - mean (np.ndarray): Mean vector of shape (k,).
    - cov (np.ndarray): Covariance matrix of shape (k, k).

    Returns:
    - float: Probability density at point x.
    """
    k = mean.shape[0]
    x_m = x - mean
    try:
        inv_cov = np.linalg.inv(cov)
        det_cov = np.linalg.det(cov)
    except np.linalg.LinAlgError:
        raise ValueError("Covariance matrix must be invertible")

    norm_const = 1.0 / (np.power((2 * np.pi), k / 2) * np.sqrt(det_cov))
    exponent = -0.5 * np.einsum("bnc, cnm, bmc->bc", x_m, inv_cov, x_m)
    return norm_const[None, :] * np.exp(exponent)

def select_elements_along_last_axis(tensor, num_elements=9, random_seed=0):
    """
    Selects elements along the last axis of a 3D tensor.

    Parameters:
    - tensor (np.ndarray): Input tensor of shape (A, B, C).
    - num_elements (int): Number of elements to select along the last axis. Default is 9.
    - indices (list or np.ndarray, optional): Specific indices to select along the last axis. 
                                             If None, selects randomly without replacement.
    - random_seed (int, optional): Seed for random selection to ensure reproducibility.

    Returns:
    - selected_tensor (np.ndarray): Tensor with selected elements along the last axis. Shape (A, B, num_elements).
    """
    A, B, C = tensor.shape
    
    if random_seed is not None:
        np.random.seed(random_seed)  # For reproducibility
    selected_indices = np.random.choice(C, size=num_elements, replace=False)
    selected_indices = np.array(selected_indices)
    selected_indices = selected_indices.astype(np.int32)
    selected_tensor = tensor[:, :, selected_indices]
    
    return selected_tensor

def estimate_mutual_information(x, y):
    """
    Estimate mutual information between two variables.

    Parameters:
    - x (np.ndarray): Feature array.
    - y (np.ndarray): Target array.

    Returns:
    - np.ndarray: Mutual information scores.
    """
    x = np.expand_dims(x, axis=-1)
    mi = mutual_info_regression(x, y)
    mi = np.squeeze(mi)
    return mi

def estimate_pairwise_mutual_informations(y, channels, timesteps):
    """
    Estimate pairwise mutual informations for all channel and timestep combinations.

    Parameters:
    - y (np.ndarray): Input data array.
    - channels (int): Number of channels.
    - timesteps (int): Number of timesteps.

    Returns:
    - np.ndarray: Tensor containing mutual information values.
    """
    mi = np.zeros(shape=[channels, timesteps, timesteps])

    for c in range(channels):
        for i in range(timesteps):
            for j in range(timesteps):
                if i < j:
                    mi[c, i, j] = estimate_mutual_information(y[:, i, c], y[:, j, c])

    # Make the mutual information matrix symmetric
    mi = mi + np.transpose(mi, axes=[0, 2, 1])
    return mi

def calculate_pairwise_cross_correlations(y):
    """
    Calculate pairwise cross-correlations for the input data.

    Parameters:
    - y (np.ndarray): Input data array.

    Returns:
    - np.ndarray: Tensor containing cross-correlation values.
    """
    y = y - np.mean(y, axis=0, keepdims=True)
    y = y / np.std(y, axis=0)
    cov = np.einsum("bni, bmi->inm", y, y) / np.shape(y)[0]
    return cov

def calculate_pairwise_cross_covariances(y):
    """
    Calculate pairwise cross-covariances for the input data.

    Parameters:
    - y (np.ndarray): Input data array.

    Returns:
    - np.ndarray: Tensor containing cross-covariance values.
    """
    y = y - np.mean(y, axis=0, keepdims=True)
    y = y.astype(np.float64)
    cov = np.einsum("bni, bmi->inm", y, y) / np.shape(y)[0]
    return cov

def plot_heatmap_grid(fig, tensor, num_rows, num_columns, subgridspec, range_min, range_max, cmap):
    """
    Plot a grid of heatmaps within a given subgridspec.

    Parameters:
    - fig (matplotlib.figure.Figure): The figure object.
    - tensor (np.ndarray): Tensor containing heatmap data. Shape: (num_heatmaps, H, W)
    - num_rows (int): Number of rows in the heatmap grid.
    - num_columns (int): Number of columns in the heatmap grid.
    - subgridspec (GridSpec): SubGridSpec object defining the heatmap grid layout.
    - range_min (float): Minimum value for the color scale.
    - range_max (float): Maximum value for the color scale.
    - cmap (str): Colormap to use for heatmaps.

    Returns:
    - None
    """
    num_heatmaps = num_rows * num_columns

    for i in range(num_heatmaps):
        row = i // num_columns
        col = i % num_columns

        if i < tensor.shape[0]:
            ax = fig.add_subplot(subgridspec[row, col])
            sns.heatmap(
                tensor[i],
                ax=ax,
                cmap=cmap,
                vmin=range_min,
                vmax=range_max,
                cbar=False,          # Disable individual colorbars
                square=True,
                linewidths=0,        # Remove grid lines
                linecolor='gray'     # Optional: can be omitted when linewidths=0
            )
            ax.set_title(f'Ch {i+1}', fontsize=12)
            ax.set_xlabel('')
            ax.set_ylabel('')
        else:
            # Hide unused subplots
            ax = fig.add_subplot(subgridspec[row, col])
            ax.axis('off')

def plot_four_quadrant_heatmaps(mi_tensor_eq, cc_tensor_eq, mi_tensor_noise, cc_tensor_noise,
                                num_rows=3, num_columns=3, path='four_quadrant_heatmaps.png',
                                mi_range_min=0, mi_range_max=10,
                                cc_range_min=0, cc_range_max=10):
    """
    Plot a 2x2 grid where each cell contains a 3x3 grid of heatmaps.

    Layout:
    [0,0]: Mutual Information heatmaps for earthquakes
    [1,0]: Cross-Covariance heatmaps for earthquakes
    [0,1]: Mutual Information heatmaps for noise
    [1,1]: Cross-Covariance heatmaps for noise

    Each cell contains num_rows x num_columns heatmaps.

    Parameters:
    - mi_tensor_eq (np.ndarray): Tensor containing mutual information data for earthquakes. Shape: (9, H, W)
    - cc_tensor_eq (np.ndarray): Tensor containing cross-covariance data for earthquakes. Shape: (9, H, W)
    - mi_tensor_noise (np.ndarray): Tensor containing mutual information data for noise. Shape: (9, H, W)
    - cc_tensor_noise (np.ndarray): Tensor containing cross-covariance data for noise. Shape: (9, H, W)
    - num_rows (int): Number of rows in each heatmap grid.
    - num_columns (int): Number of columns in each heatmap grid.
    - path (str): File path to save the resulting plot.
    - mi_range_min (float): Minimum value for the MI color scale.
    - mi_range_max (float): Maximum value for the MI color scale.
    - cc_range_min (float): Minimum value for the CC color scale.
    - cc_range_max (float): Maximum value for the CC color scale.

    Returns:
    - None
    """
    # Create the main figure
    fig = plt.figure(figsize=(num_columns * 4 * 2, num_rows * 4 * 2))  # Adjust size as needed
    fig.suptitle('Mutual Information and Cross-Covariance Heatmaps: Earthquake vs Noise', fontsize=24, y=0.95)

    # Create a 2x2 GridSpec for the main layout with reduced gaps
    main_gs = GridSpec(2, 2, figure=fig, wspace=0.2, hspace=0.2)  # Decreased gaps between quadrants

    # Define the four quadrants with corresponding tensors
    quadrants = {
        'MI_Eq': mi_tensor_eq,
        'CC_Eq': cc_tensor_eq,
        'MI_Noise': mi_tensor_noise,
        'CC_Noise': cc_tensor_noise
    }

    # Titles for each quadrant
    titles = {
        'MI_Eq': 'Mutual Information - Earthquake',
        'CC_Eq': 'Cross-Covariance - Earthquake',
        'MI_Noise': 'Mutual Information - Noise',
        'CC_Noise': 'Cross-Covariance - Noise'
    }

    # Color maps for MI and CC
    color_maps = {
        'MI_Eq': 'magma',
        'MI_Noise': 'magma',
        'CC_Eq': 'viridis',
        'CC_Noise': 'viridis'
    }

    # Assign color ranges
    color_ranges = {
        'MI_Eq': (mi_range_min, mi_range_max),
        'MI_Noise': (mi_range_min, mi_range_max),
        'CC_Eq': (cc_range_min, cc_range_max),
        'CC_Noise': (cc_range_min, cc_range_max)
    }

    # Plot each quadrant
    for key, tensor in quadrants.items():
        # Determine position in main_gs
        if key == 'MI_Eq':
            main_row, main_col = 0, 0
        elif key == 'CC_Eq':
            main_row, main_col = 1, 0
        elif key == 'MI_Noise':
            main_row, main_col = 0, 1
        elif key == 'CC_Noise':
            main_row, main_col = 1, 1

        # Create subgridspec for the quadrant with increased gaps between heatmaps
        sub_gs = main_gs[main_row, main_col].subgridspec(num_rows, num_columns, wspace=0.25, hspace=0.25)  # Increased gaps between heatmaps

        # Plot the heatmap grid within the subgridspec
        plot_heatmap_grid(fig, tensor, num_rows, num_columns, sub_gs,
                          color_ranges[key][0], color_ranges[key][1],
                          cmap=color_maps[key])

        # Add quadrant title by adding a subplot that covers the entire subgridspec
        ax_main = fig.add_subplot(main_gs[main_row, main_col])
        ax_main.axis('off')  # Hide the main grid cell
        ax_main.set_title(titles[key], fontsize=16, pad=20)

    # Create colorbars for MI and CC
    # Define position for MI colorbar (left, bottom, width, height)
    mi_cbar_ax = fig.add_axes([0.92, 0.6, 0.02, 0.25])  # [left, bottom, width, height]
    mi_norm = mpl.colors.Normalize(vmin=mi_range_min, vmax=mi_range_max)
    mi_cmap = mpl.cm.get_cmap(color_maps['MI_Eq'])
    sm_mi = mpl.cm.ScalarMappable(cmap=mi_cmap, norm=mi_norm)
    sm_mi.set_array([])
    cbar_mi = fig.colorbar(sm_mi, cax=mi_cbar_ax)
    cbar_mi.set_label('Mutual Information', fontsize=12)

    # Define position for CC colorbar
    cc_cbar_ax = fig.add_axes([0.92, 0.175, 0.02, 0.25])  # [left, bottom, width, height]
    cc_norm = mpl.colors.Normalize(vmin=cc_range_min, vmax=cc_range_max)
    cc_cmap = mpl.cm.get_cmap(color_maps['CC_Eq'])
    sm_cc = mpl.cm.ScalarMappable(cmap=cc_cmap, norm=cc_norm)
    sm_cc.set_array([])
    cbar_cc = fig.colorbar(sm_cc, cax=cc_cbar_ax)
    cbar_cc.set_label('Cross-Covariance', fontsize=12)

    # Adjust layout to make room for colorbars
    plt.tight_layout(rect=[0, 0.03, 0.9, 0.95])  # Leave space on the right for colorbar

    # Save the figure and close to free memory
    plt.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)

In [72]:
# -------------------------------
# Initialize K-Fold Environment
# -------------------------------

# Create a K-Fold environment for the specified test dataset
kenv = KFoldEnvironment(TEST_DATASET)

# Retrieve metadata for training, validation, and testing splits
train_metadata, validation_metadata, test_metadata = kenv.get_split_metadata(SPLIT)

# Retrieve data generators for training, validation, and testing
train_gen, validation_gen, test_gen, __ = kenv.get_generators(SPLIT)

In [58]:
# -------------------------------
# Model Initialization and Loading
# -------------------------------

# Initialize the representation learning model
model = REPRESENTATION_LEARNING_MODEL_CLASS()
model.compile()

# Perform a forward pass with random input to initialize model weights
model(np.random.normal(size=[BATCH_SIZE, 3000, N_CHANNELS]))

# Construct the checkpoint path for the model weights
cp_path = get_checkpoint_path(
    EXP_NAME,
    REPRESENTATION_LEARNING_MODEL_CLASS().name,
    TRAIN_DATASET,
    SPLIT,
    EPOCH
)

# Load the pre-trained weights into the model
model.load_weights(cp_path)

In [59]:
# -------------------------------
# Data Preparation
# -------------------------------

# Initialize lists to hold batches of data
X = []
Y = []

# Iterate over the first 24 batches from the test generator
for i in range(24):
    x_batch, y_batch = test_gen.__getitem__(i)
    X.append(x_batch)
    Y.append(y_batch)

# Concatenate all batches into single numpy arrays
X = np.concatenate(X, axis=0)
Y = np.concatenate(Y, axis=0)

In [60]:
# -------------------------------
# Data Filtering
# -------------------------------

# Create boolean masks based on the target values
mask_eq = Y > 0.5    # Mask for earthquake waveforms
mask_no = Y <= 0.5   # Mask for noise waveforms

# Separate the datasets into earthquake and noise
X_eq = X[mask_eq]
X_no = X[mask_no]

# Limit the datasets to the first 512 samples each
X_eq = X_eq[0:512]
X_no = X_no[0:512]

In [66]:
# -------------------------------
# Feature Extraction
# -------------------------------

# Pass the earthquake and noise data through the model to obtain features
F_eq, __ = model(X_eq)
F_no, __ = model(X_no)

F_eq = F_eq.numpy()
F_no = F_no.numpy()

F_eq = select_elements_along_last_axis(F_eq, num_elements=9)
F_no = select_elements_along_last_axis(F_no, num_elements=9)

In [67]:
# -------------------------------
# Mutual Information Calculation
# -------------------------------

# Estimate pairwise mutual informations for earthquake and noise features
mi_eq = estimate_pairwise_mutual_informations(F_eq, 9, 94)
mi_no = estimate_pairwise_mutual_informations(F_no, 9, 94)

In [68]:
# -------------------------------
# Covariance and Correlation Calculation
# -------------------------------

# Calculate pairwise cross-covariances for earthquake and noise features
cov_eq = calculate_pairwise_cross_covariances(F_eq)
cov_no = calculate_pairwise_cross_covariances(F_no)

In [69]:
# -------------------------------
# Determine Plotting Ranges
# -------------------------------

# Calculate the range for mutual information plots based on combined data
mi_range = 5.0 * np.std(np.concatenate([mi_eq, mi_no], axis=0))

# Calculate the range for covariance plots based on combined data
cov_range = 5.0 * np.std(np.concatenate([cov_eq, cov_no], axis=0))

In [73]:
# -------------------------------
# Plotting Heatmaps
# -------------------------------
# Plot mutual information heatmaps for earthquake data
plot_four_quadrant_heatmaps(mi_eq, cov_eq, mi_no, cov_no,
                            num_rows=3, num_columns=3, path='four_quadrant_heatmaps.png',
                            mi_range_min=0, mi_range_max=mi_range,
                            cc_range_min=-cov_range, cc_range_max=cov_range)

# Plot cross-covariance heatmaps for earthquake data
#plot_cross_covariance_tensor(
#    cov_eq,
#    num_rows=3,
#    num_columns=3,
#    path="eq_covariances.png",
#    range_min=-cov_range,
#    range_max=cov_range
#)

# Plot cross-covariance heatmaps for noise data
#plot_cross_covariance_tensor(
#    cov_no,
#    num_rows=3,
#    num_columns=3,
#    path="no_covariances.png",
#    range_min=-cov_range,
#    range_max=cov_range
#)

  mi_cmap = mpl.cm.get_cmap(color_maps['MI_Eq'])
  cc_cmap = mpl.cm.get_cmap(color_maps['CC_Eq'])
  plt.tight_layout(rect=[0, 0.03, 0.9, 0.95])  # Leave space on the right for colorbar
