# Loading the data as a time series

In [None]:
import numpy as np
from numpy.array_api import linspace
# import mne
from scipy.io import loadmat
from scipy.signal import decimate, butter, filtfilt
import pandas as pd
import matplotlib.pyplot as plt
import importlib
from eeg_utils import *
%matplotlib inline

In [None]:
# Load the recording through the project helper `mat_to_dataframe` (star-imported from eeg_utils).
# NOTE(review): the folder name carries subject id "0cGdk9" while the file name carries
# "i4oK0F" — confirm the path points at the intended subject.
mat = mat_to_dataframe("sub-0cGdk9_HoldL_MedOff_run1_LFP_Hilbert/sub_i4oK0F_HoldL_MedOff_run1_LFP_Hilbert.mat")

In [None]:
# Unpack the helper's result as (dataframe, left channel, right channel).
# NOTE(review): left_lfp / right_lfp are re-assigned from df columns in the next cell —
# confirm whether the unpacked pair is still needed.
df, left_lfp, right_lfp = mat

In [None]:
# Select the two LFP channels from the dataframe and take a first look at them.
right_lfp = df['LFP-right-23']
left_lfp = df['LFP-left-78']
print(right_lfp)
print(left_lfp)

# `df` has already been indexed above, so an existence check on it would never
# be False here; the plot therefore runs unconditionally.
plot_slice = 5000  # number of leading samples to draw
print(f"\nPlotting first {plot_slice} samples...")

# NOTE(review): layout=(2, 1) assumes the dataframe holds at most two columns — confirm.
df.iloc[:plot_slice].plot(
    subplots=True,   # Plot each channel separately
    layout=(2, 1),   # Arrange in 2 rows, 1 column
    grid=True,
    title=f"LFP Time Series (First {plot_slice} Samples)",
    figsize=(15, 6)  # Width, Height in inches
)
plt.xlabel(df.index.name)
plt.tight_layout()
plt.show()

In [None]:
# Band-pass both channels to 4-48 Hz.
# Note: per the original author, the filter helper hands back np.array, not pd.Series.
fs = 2000
lowcut = 4
highcut = 48

# Both channels share one set of filter settings.
bandpass_settings = dict(lowcut=lowcut, highcut=highcut, fs=fs, order=5)

left_filtered = butter_bandpass_filter(data=left_lfp, **bandpass_settings)
right_filtered = butter_bandpass_filter(data=right_lfp, **bandpass_settings)

In [None]:
# Plot the band-pass filtered data, one channel per row on a shared time axis
time = np.arange(left_lfp.size) / fs

fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(15, 6), sharex=True)

axes[0].set_title(f"Band-pass filtered signals  ({lowcut}-{highcut}Hz)")
# Legend labels reuse the dataframe column names the channels were read from.
axes[0].plot(time, right_filtered[:time.size], label="LFP-right-23")
axes[1].plot(time, left_filtered[:time.size], label="LFP-left-78", color="orange")

# Style each row explicitly; pyplot-level calls would only touch the last axes.
for ax in axes:
    ax.set_ylabel("Amplitude")
    ax.legend()
    ax.grid(True)
axes[-1].set_xlabel("Time (seconds)")

plt.tight_layout()
plt.show()

In [None]:
# Downsampling the signal

original_fs = 2000
target_fs = 100  # chosen with the Nyquist limit in mind and as a divisor of original_fs

# The decimation factor has to be a whole number, so reject rate pairs that
# leave a remainder.
q, _leftover = divmod(original_fs, target_fs)
if _leftover:
    raise ValueError("Original fs must be an integer multiple of Target fs.")

## Check which filters work better. IIR is not good.

What kind of filtering does `decimate` perform? Maybe I do not need to filter the signal myself at all; perhaps a high-pass alone would be enough. Decimate first, then high-pass.

In [None]:
# Downsample the left channel by the integer factor q, using decimate's FIR
# option with zero_phase enabled.
left = decimate(left_filtered, q, ftype="fir", zero_phase=True)

# Compare sample counts before and after as a sanity check.
n_before, n_after = len(left_filtered), len(left)
print(f"Original number of samples: {n_before}")
print(f"Downsampled number of samples: {n_after}")
print(f"New sampling rate: {target_fs} Hz")

In [None]:
# Downsample the right channel by the integer factor q, using decimate's FIR
# option with zero_phase enabled.
right = decimate(right_filtered, q, ftype="fir", zero_phase=True)

# Compare sample counts before and after as a sanity check.
n_before, n_after = len(right_filtered), len(right)
print(f"Original number of samples: {n_before}")
print(f"Downsampled number of samples: {n_after}")
print(f"New sampling rate: {target_fs} Hz")

In [None]:
# Plot the downsampled channels. The number of plotted samples is derived once
# from plot_duration and reused for both the time axis and the data slices, so
# x and y always have the same length.
plot_duration = 1000 # seconds

# Never ask for more samples than either channel actually holds.
n_plot = min(int(plot_duration * target_fs), len(right), len(left))
# Sample k sits at k / target_fs seconds.
time_dwn = np.arange(n_plot) / target_fs

fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(15, 6), sharex=True)

axes[0].set_title(f"Band-pass filtered signals ({lowcut}-{highcut}Hz) at {target_fs}Hz")
# Legend labels reuse the dataframe column names the channels were read from.
axes[0].plot(time_dwn, right[:n_plot], label="LFP-right-23")
axes[1].plot(time_dwn, left[:n_plot], label="LFP-left-78", color="orange")

# Style each row explicitly; pyplot-level calls would only touch the last axes.
for ax in axes:
    ax.set_ylabel("Amplitude")
    ax.grid(True)
    ax.legend()
axes[-1].set_xlabel("Time (seconds)")

plt.tight_layout()
plt.show()

In [None]:
# Show how many samples each downsampled channel yields for a 5000-sample slice
# (right first, then left).
for channel in (right, left):
    print(len(channel[:5000]))

In [None]:
# Define the sampling rate of your downsampled signal
fs_downsampled = 100 # Hz

# Start times (seconds) of the hold windows and their common length
hold_onsets_sec = [314, 416, 522, 617, 707]
hold_window_sec = 60


def _cut_windows(signal, onsets_sec, window_sec, fs):
    """Return one slice of `signal` per onset, each `window_sec` seconds long.

    Indices are clamped against `signal` itself, so a window that would run
    past the end of the recording is shortened rather than mis-indexed.
    """
    windows = []
    for onset in onsets_sec:
        start_index = max(0, int(onset * fs))
        end_index = min(len(signal), int((onset + window_sec) * fs))
        windows.append(signal[start_index:end_index])
    return windows


# Later cells read both right_hold and left_hold, so both are built here.
# NOTE(review): the same onsets are applied to both channels — confirm.
right_hold = _cut_windows(right, hold_onsets_sec, hold_window_sec, fs_downsampled)
left_hold = _cut_windows(left, hold_onsets_sec, hold_window_sec, fs_downsampled)

In [None]:
# Variables the later cells expect to exist at this point
# (hold-task and resting-state windows for each channel):
# right_hold
# left_hold
# right_resting
# left_resting

In [None]:
import pickle
from pathlib import Path

# Define file paths
left_diagrams_path = "./i4oK0F/medOff_left_hold.pkl"
right_diagrams_path = "./i4oK0F/medOff_right_hold.pkl"

# Save the hold windows. The tuple is built first, so a missing variable fails
# before any file is opened; the target folder is created on first use.
for label, path, windows in (
    ("left_hold", left_diagrams_path, left_hold),
    ("right_hold", right_diagrams_path, right_hold),
):
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(windows, f)
    # Name what was actually written (hold windows, not diagrams).
    print(f"Saved {label} to {path}")


# Doing the TDA magic

In [None]:
from gtda.time_series import SingleTakensEmbedding
from gtda.time_series import TakensEmbedding
import itertools

# --- Plotting Libraries ---
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.io as pio
from gtda.plotting import plot_point_cloud, plot_diagram, plot_heatmap
from gtda.homology import VietorisRipsPersistence
from gtda.diagrams import PersistenceEntropy, PairwiseDistance, PersistenceLandscape, BettiCurve, HeatKernel, PersistenceImage, Silhouette, Scaler
import seaborn as sns


# --- Set Plotting Themes to Light Mode ---
# Both plotting backends are configured here so that plotly-based figures and
# matplotlib figures use matching light styling.
pio.templates.default = "plotly_white" # For plotly and giotto-tda plots
plt.style.use('default') # For matplotlib plots


In [None]:
# Reload the local helper module so edits to it are picked up without
# restarting the kernel, then repeat the star-import so the notebook's names
# are rebound to the reloaded definitions.
import eeg_utils
importlib.reload(eeg_utils)
from eeg_utils import *

## The current state of the data
- Band-pass filtered to 4-48 Hz
- Downsampled to 100 Hz (an FIR filter is used for anti-aliasing)
- Resting-state data is split into 5 non-overlapping windows of 60 seconds, stored in `left_resting` and `right_resting`
- Hold-state data already occurs in separate windows within the series; these are stored in `left_hold` and `right_hold`

In [None]:
# Settings for the Takens embedding.
# NOTE(review): with parameters_type="search" the two max_* values are presumably
# upper bounds for the search rather than fixed values — confirm against gtda docs.
stride = 1
max_time_delay = 10
max_embedding_dim = 10

embedder_settings = {
    "parameters_type": "search",
    "dimension": max_embedding_dim,
    "time_delay": max_time_delay,
    "stride": stride,
    "n_jobs": -1,
}
embedder = SingleTakensEmbedding(**embedder_settings)

In [None]:
def _embed_windows(windows):
    """Run `fit_embedder` on every window and return the results in order.

    Prints a 1-based "Slice N:" header before each window, matching the
    progress output of the original per-condition loops.
    """
    embedded = []
    for number, window in enumerate(windows, start=1):
        print(f"Slice {number}:")
        embedded.append(fit_embedder(embedder, window))
    return embedded


# Each condition is embedded from its OWN windows; in particular the left hold
# embeddings come from left_hold (not left_resting).
left_resting_embeddings = _embed_windows(left_resting)
left_hold_embeddings = _embed_windows(left_hold)
right_resting_embeddings = _embed_windows(right_resting)
right_hold_embeddings = _embed_windows(right_hold)


In [None]:
# PCA with 3 components; the following cells feed its output to
# plot_point_cloud to visualise embeddings.
from sklearn.decomposition import PCA
pca = PCA(n_components=3)

In [None]:
# `left_embedded` is not defined in the visible notebook; project the first
# left resting-state embedding instead.
# NOTE(review): confirm slice 0 is the window that was meant to be visualised.
left_embedded_pca = pca.fit_transform(left_resting_embeddings[0])
plot_point_cloud(left_embedded_pca)

In [None]:
# `right_embedded` is not defined in the visible notebook; project the first
# right resting-state embedding instead.
# NOTE(review): confirm slice 0 is the window that was meant to be visualised.
right_embedded_pca = pca.fit_transform(right_resting_embeddings[0])
plot_point_cloud(right_embedded_pca)

In [None]:
# Length of the first embedding of each condition, in the order:
# left resting, left hold, right resting, right hold.
for embeddings in (
    left_resting_embeddings,
    left_hold_embeddings,
    right_resting_embeddings,
    right_hold_embeddings,
):
    print(len(embeddings[0]))

In [None]:
 # Assuming you have your VietorisRipsPersistence object initialized
homology_dims = [0, 1, 2]
persistence = VietorisRipsPersistence(homology_dimensions=homology_dims, n_jobs=-1)

n_slices = 2       # how many windows per condition are processed
max_points = 3000  # cap on the number of embedded points per window


def _persistence_diagrams(embeddings, label):
    """Compute one persistence diagram per embedding.

    Only the first `n_slices` embeddings are used, and each one is truncated
    to its first `max_points` points before the computation. (Chaining
    `[:2][:3000]` on the list only slices the list twice and never limits the
    points of an individual embedding.)
    """
    diagrams = []
    for i, embedding in enumerate(embeddings[:n_slices]):
        print(f"Calculating persistence for {label} slice {i+1}...")

        # Truncate, then reshape the 2D embedding to 3D (1, n_points, n_dimensions)
        embedding_3d = embedding[:max_points][None, :, :]

        diagrams.append(persistence.fit_transform(embedding_3d))
    return diagrams


# --- Process Left Embeddings ---
print("--- Processing Left Channel Slices ---")
left_resting_diagrams = _persistence_diagrams(left_resting_embeddings, "left resting")
left_hold_diagrams = _persistence_diagrams(left_hold_embeddings, "left hold")

# --- Process Right Embeddings ---
print("\n--- Processing Right Channel Slices ---")
right_resting_diagrams = _persistence_diagrams(right_resting_embeddings, "right resting")
right_hold_diagrams = _persistence_diagrams(right_hold_embeddings, "right hold")

# The cells below work on one list per channel; build them as resting windows
# followed by hold windows.
# NOTE(review): confirm this ordering is the one the downstream plots assume.
left_diagrams = left_resting_diagrams + left_hold_diagrams
right_diagrams = right_resting_diagrams + right_hold_diagrams

In [None]:
# left_diagrams and right_diagrams are lists of persistence diagrams.
# Inspect the first one:
print("Example of a calculated diagram's shape:")
print(left_diagrams[0].shape)

# Plot one of them; the message is derived from the index so the text always
# names the slice that is actually drawn.
slice_to_plot = 3
print(f"Plotting the diagram for slice {slice_to_plot + 1} of the left channel:")

plot_diagram(left_diagrams[slice_to_plot][0])

In [None]:
# --- Grid geometry: left channel in the upper rows, right channel below ---
n_cols = 5
largest_channel = max(len(left_diagrams), len(right_diagrams))
# Keep the 4x5 layout for up to 10 diagrams per channel; add rows beyond that.
rows_per_channel = max(2, -(-largest_channel // n_cols))
slots_per_channel = rows_per_channel * n_cols


def _channel_titles(prefix, count):
    """Return subplot titles for one channel, padded with "" to fill its rows.

    Padding keeps the right-channel titles on the right-channel rows even when
    the left channel has fewer diagrams than grid slots.
    """
    titles = [f"{prefix} Slice {i+1}" for i in range(count)]
    return titles + [""] * (slots_per_channel - count)


# --- Create Subplot Titles ---
left_titles = _channel_titles("Left", len(left_diagrams))
right_titles = _channel_titles("Right", len(right_diagrams))

# --- Create the Figure with Subplots ---
fig = make_subplots(
    rows=2 * rows_per_channel,
    cols=n_cols,
    subplot_titles=(left_titles + right_titles)
)

# Define colors for the different homology dimensions
dim_colors = {
    0: 'blue',
    1: 'red',
    2: 'green',
    3: 'purple'
}

# Homology dimensions that already own a legend entry (one entry per dimension
# for the whole figure).
legend_dims = set()


def _add_channel_diagrams(diagrams, first_row):
    """Draw every diagram of one channel, starting at grid row `first_row`."""
    for i, diagram_3d in enumerate(diagrams):
        row = first_row + i // n_cols
        col = i % n_cols + 1

        diagram_2d = diagram_3d[0]
        finite_rows = np.isfinite(diagram_2d[:, 1])

        for dim, color in sorted(dim_colors.items()):
            keep = finite_rows & (diagram_2d[:, 2] == dim)
            if not keep.any():
                # Nothing to draw for this dimension in this diagram.
                continue

            fig.add_trace(
                go.Scatter(
                    x=diagram_2d[keep, 0],
                    y=diagram_2d[keep, 1],
                    mode='markers',
                    marker_color=color,
                    name=f'H{dim}',
                    legendgroup=f'H{dim}',
                    showlegend=dim not in legend_dims
                ),
                row=row,
                col=col
            )
            legend_dims.add(dim)

        # Diagonal extent from the birth/death columns only; the third column
        # holds the homology dimension and must not influence the axis range.
        max_val = diagram_2d[finite_rows, :2].max() if finite_rows.any() else 1
        fig.add_shape(
            type="line", x0=0, y0=0, x1=max_val, y1=max_val,
            line=dict(color="black", width=1, dash="dash"),
            row=row, col=col
        )


# --- Add Traces for Left Channel Diagrams (upper rows) ---
_add_channel_diagrams(left_diagrams, first_row=1)

# --- Add Traces for Right Channel Diagrams (lower rows) ---
_add_channel_diagrams(right_diagrams, first_row=rows_per_channel + 1)

# --- Update Layout and Axis Titles ---
fig.update_layout(
    height=300 * 2 * rows_per_channel,  # 1200 for the default 4 rows
    width=1600,
    title_text="Persistence Diagrams for All Slices",
    title_x=0.5,
    legend_title_text='Homology Dimension',
    plot_bgcolor='white'
)

fig.update_xaxes(title_text="Birth", showgrid=True, gridwidth=1, gridcolor='lightgray')
fig.update_yaxes(title_text="Death", showgrid=True, gridwidth=1, gridcolor='lightgray')

fig.show()

# Save and Load Persistence Diagrams

In [None]:
import pickle
from pathlib import Path

# Define file paths
left_diagrams_path = "./Saved_Data/medOff_left_diagrams.pkl"
right_diagrams_path = "./Saved_Data/medOff_right_diagrams.pkl"

# Save the diagrams. The output folder is created on first use so the cell
# does not fail on a fresh checkout.
for label, path, diagrams in (
    ("left_diagrams", left_diagrams_path, left_diagrams),
    ("right_diagrams", right_diagrams_path, right_diagrams),
):
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(diagrams, f)
    print(f"Saved {label} to {path}")


# Load Persistence Diagrams (run this cell if you want to load saved data)

In [None]:
import pickle

# Define file paths. These carry the same "medOff_" prefix as the names used
# by the save cell of this notebook, so a save followed by a load round-trips.
left_diagrams_path = "./Saved_Data/medOff_left_diagrams.pkl"
right_diagrams_path = "./Saved_Data/medOff_right_diagrams.pkl"

# Load the diagrams
# NOTE: unpickling can execute arbitrary code — only load files this notebook
# wrote itself.
try:
    with open(left_diagrams_path, "rb") as f:
        left_diagrams = pickle.load(f)
    print(f"Loaded left_diagrams from {left_diagrams_path}")

    with open(right_diagrams_path, "rb") as f:
        right_diagrams = pickle.load(f)
    print(f"Loaded right_diagrams from {right_diagrams_path}")
except FileNotFoundError:
    print("One or both diagram files not found. Please ensure they have been saved.")


# Feature Extraction

## Persistence Entropy

In [None]:
# Persistence entropy of every left-channel diagram, printed slice by slice.
PE = PersistenceEntropy()
left_pe_features = [PE.fit_transform(diagram) for diagram in left_diagrams]
for number, entropy in enumerate(left_pe_features, start=1):
    print(f"Slice {number}: {entropy}")

In [None]:
# Persistence entropy of every right-channel diagram, printed slice by slice.
PE = PersistenceEntropy()
right_pe_features = [PE.fit_transform(diagram) for diagram in right_diagrams]
for number, entropy in enumerate(right_pe_features, start=1):
    print(f"Slice {number}: {entropy}")

## Other Small Features

In [None]:
# Summary features from the project helper, one entry per left-channel diagram.
left_sm_features = [
    extract_features(diagram, homology_dimensions=homology_dims, verbose=True)
    for diagram in left_diagrams
]

In [None]:
# Summary features from the project helper, one entry per right-channel diagram.
right_sm_features = [
    extract_features(diagram, homology_dimensions=homology_dims, verbose=True)
    for diagram in right_diagrams
]

## Persistence Landscapes

In [None]:
# Persistence landscape of the first left-channel diagram only, using the
# transformer's default settings.
PL = PersistenceLandscape()
left_pl = PL.fit_transform(left_diagrams[0])


In [None]:
# Persistence landscapes for every left-channel diagram.
# The transformer is fitted anew on each diagram in turn.
PL = PersistenceLandscape()

left_pl_features = [PL.fit_transform(diagram) for diagram in left_diagrams]
print(f"Shape of the PL of each slice: {left_pl_features[0].shape}")

for number, landscape in enumerate(left_pl_features, start=1):
    print(f"Slice {number}: {landscape}")


In [None]:
# PL was last fitted on the final left diagram, and plot() draws against the
# sampling grid stored by the most recent fit. Re-fit on the diagram being
# shown so the x-axis belongs to the plotted landscape.
PL.plot(PL.fit_transform(left_diagrams[0]))

In [None]:
# Persistence landscapes for every right-channel diagram.
# The transformer is fitted anew on each diagram in turn.
PL = PersistenceLandscape()

right_pl_features = [PL.fit_transform(diagram) for diagram in right_diagrams]
print(f"Shape of the PL of each slice: {right_pl_features[0].shape}")

for number, landscape in enumerate(right_pl_features, start=1):
    print(f"Slice {number}: {landscape}")


In [None]:
# PL was last fitted on the final right diagram, and plot() draws against the
# sampling grid stored by the most recent fit. Re-fit on the diagram being
# shown so the x-axis belongs to the plotted landscape.
PL.plot(PL.fit_transform(right_diagrams[0]))

## Betti Curves

In [None]:
# Betti curves for every left-channel diagram.
BC = BettiCurve()

left_bcs = [BC.fit_transform(diagram) for diagram in left_diagrams]

# After the loop BC holds the sampling grid of the LAST diagram, and plot()
# draws against that stored grid. Re-fit on the diagram being shown so the
# x-axis matches the plotted curve.
bc_plot_idx = 2
BC.plot(BC.fit_transform(left_diagrams[bc_plot_idx]))


In [None]:
# Betti curves for every right-channel diagram.
BC = BettiCurve()

right_bcs = [BC.fit_transform(diagram) for diagram in right_diagrams]

# After the loop BC holds the sampling grid of the LAST diagram, and plot()
# draws against that stored grid. Re-fit on the diagram being shown so the
# x-axis matches the plotted curve.
bc_plot_idx = 2
BC.plot(BC.fit_transform(right_diagrams[bc_plot_idx]))

## Heat Kernel

In [None]:
# Heat-kernel representation of the first left-channel diagram only, using
# the transformer's default settings.
HK = HeatKernel()

left_hk = HK.fit_transform(left_diagrams[0])
# Bare expression: its value is shown as the cell output.
left_hk.shape

In [None]:
# Plot the heat-kernel output computed above with the plot method's defaults.
HK.plot(left_hk)

In [None]:
# Same plot for channel index 2.
# NOTE(review): presumably index 2 corresponds to H2 given homology_dims = [0, 1, 2] — confirm.
HK.plot(left_hk, homology_dimension_idx=2)

## Calculating Pairwise Distance
Here's an example of how to use `PairwiseDistance` to calculate the distance between two of your persistence diagrams.

## First pad the diagrams into suitable dimensions

In [None]:
# Take element 0 of every left-channel diagram (dropping the leading singleton
# axis), then hand the list to the project's padding helper.
left_all_diagrams = [diagram[0] for diagram in left_diagrams]
left_diagrams_padded = pad_diagrams(left_all_diagrams)

print(
    f"Number of diagrams: {len(left_all_diagrams)}",
    f"Padded array shape: {left_diagrams_padded.shape}",
    sep="\n",
)

In [None]:
# Take element 0 of every right-channel diagram (dropping the leading singleton
# axis), then hand the list to the project's padding helper.
right_all_diagrams = [diagram[0] for diagram in right_diagrams]
right_diagrams_padded = pad_diagrams(right_all_diagrams)

print(
    f"Number of diagrams: {len(right_all_diagrams)}",
    f"Padded array shape: {right_diagrams_padded.shape}",
    sep="\n",
)

## Now scale the diagrams with respect to a metric

In [None]:
# Rescale the padded left-channel diagrams; the scaler is fitted on these
# diagrams only.
# NOTE(review): the right channel gets its own separately fitted scaler, so the
# two channels may end up on different scales — confirm that is intended.
scaler = Scaler(metric='wasserstein')  # or try 'landscape', 'silhouette'
left_diagrams_scaled = scaler.fit_transform(left_diagrams_padded)

In [None]:
# Rescale the padded right-channel diagrams; the scaler is fitted on these
# diagrams only.
# NOTE(review): the left channel gets its own separately fitted scaler, so the
# two channels may end up on different scales — confirm that is intended.
scaler = Scaler(metric='wasserstein')  # or try 'landscape', 'silhouette'
right_diagrams_scaled = scaler.fit_transform(right_diagrams_padded)

## Wasserstein Distance

In [None]:
# All-pairs Wasserstein distances between the scaled left-channel diagrams.
distance_metric = "wasserstein"
pwise_dist = PairwiseDistance(metric=distance_metric)
left_dist_mx = pwise_dist.fit_transform(left_diagrams_scaled)

print(
    f"Distance matrix shape: {left_dist_mx.shape}",
    f"Distance matrix:\n{left_dist_mx}",
    sep="\n",
)

In [None]:
# Annotated heatmap of the left-channel distance matrix held in left_dist_mx.
heatmap_fig, heatmap_ax = plt.subplots(figsize=(8, 6))
sns.heatmap(left_dist_mx, annot=True, fmt='.2f', cmap='viridis', ax=heatmap_ax)
heatmap_ax.set_title('Pairwise Wasserstein Distances of left diagrams with Wasserstein Scaling')
heatmap_ax.set_xlabel('Diagram Index')
heatmap_ax.set_ylabel('Diagram Index')
plt.show()

In [None]:
# All-pairs Wasserstein distances between the scaled right-channel diagrams.
distance_metric = "wasserstein"
pwise_dist = PairwiseDistance(metric=distance_metric)
right_dist_mx = pwise_dist.fit_transform(right_diagrams_scaled)

print(
    f"Distance matrix shape: {right_dist_mx.shape}",
    f"Distance matrix:\n{right_dist_mx}",
    sep="\n",
)

In [None]:
# Annotated heatmap of the right-channel distance matrix held in right_dist_mx.
plt.figure(figsize=(8, 6))
sns.heatmap(right_dist_mx, annot=True, fmt='.2f', cmap='viridis')
# The title names the channel, matching the other distance heatmaps, so the
# figures can be told apart once exported.
plt.title('Pairwise Wasserstein Distances of right diagrams with Wasserstein Scaling')
plt.xlabel('Diagram Index')
plt.ylabel('Diagram Index')
plt.show()

In [None]:
import pickle
from pathlib import Path

# Define file paths
left_diagrams_path = "./Saved_Data/medOff_left_wasserstein.pkl"
right_diagrams_path = "./Saved_Data/medOff_right_wasserstein.pkl"

# Save the Wasserstein distance matrices, creating the output folder on first use.
# NOTE: left_dist_mx / right_dist_mx are re-bound by the bottleneck cells, so
# run this cell before them.
for label, path, dist_mx in (
    ("left Wasserstein distance matrix", left_diagrams_path, left_dist_mx),
    ("right Wasserstein distance matrix", right_diagrams_path, right_dist_mx),
):
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(dist_mx, f)
    # Name what was actually written (a distance matrix, not diagrams).
    print(f"Saved {label} to {path}")

## Bottleneck Distance

In [None]:
# All-pairs bottleneck distances between the scaled left-channel diagrams.
# Note: this re-binds left_dist_mx, replacing the Wasserstein matrix.
distance_metric = "bottleneck"
pwise_dist = PairwiseDistance(metric=distance_metric)
left_dist_mx = pwise_dist.fit_transform(left_diagrams_scaled)

print(
    f"Distance matrix shape: {left_dist_mx.shape}",
    f"Distance matrix:\n{left_dist_mx}",
    sep="\n",
)

In [None]:
# Annotated heatmap of the left-channel distance matrix held in left_dist_mx.
heatmap_fig, heatmap_ax = plt.subplots(figsize=(8, 6))
sns.heatmap(left_dist_mx, annot=True, fmt='.2f', cmap='viridis', ax=heatmap_ax)
heatmap_ax.set_title('Pairwise Bottleneck Distances of left diagrams with Wasserstein Scaling')
heatmap_ax.set_xlabel('Diagram Index')
heatmap_ax.set_ylabel('Diagram Index')
plt.show()

In [None]:
# All-pairs bottleneck distances between the scaled right-channel diagrams.
# Note: this re-binds right_dist_mx, replacing the Wasserstein matrix.
distance_metric = "bottleneck"
pwise_dist = PairwiseDistance(metric=distance_metric)
right_dist_mx = pwise_dist.fit_transform(right_diagrams_scaled)

print(
    f"Distance matrix shape: {right_dist_mx.shape}",
    f"Distance matrix:\n{right_dist_mx}",
    sep="\n",
)

In [None]:
# Annotated heatmap of the right-channel distance matrix held in right_dist_mx.
heatmap_fig, heatmap_ax = plt.subplots(figsize=(8, 6))
sns.heatmap(right_dist_mx, annot=True, fmt='.2f', cmap='viridis', ax=heatmap_ax)
heatmap_ax.set_title('Pairwise Bottleneck Distances of right diagrams with Wasserstein Scaling')
heatmap_ax.set_xlabel('Diagram Index')
heatmap_ax.set_ylabel('Diagram Index')
plt.show()

In [None]:
import pickle
from pathlib import Path

# Define file paths
left_diagrams_path = "./Saved_Data/medOff_left_bottleneck.pkl"
right_diagrams_path = "./Saved_Data/medOff_right_bottleneck.pkl"

# Save the bottleneck distance matrices, creating the output folder on first use.
for label, path, dist_mx in (
    ("left bottleneck distance matrix", left_diagrams_path, left_dist_mx),
    ("right bottleneck distance matrix", right_diagrams_path, right_dist_mx),
):
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(dist_mx, f)
    # Name what was actually written (a distance matrix, not diagrams).
    print(f"Saved {label} to {path}")