In [None]:
import pyabf
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.cluster import KMeans
import pyarrow

In [None]:

# ---------------------------
# Open the ABF recording
# ---------------------------

abf = pyabf.ABF("PATH/TO/FILES/filename.abf")

# No setSweep() call has been made yet, so these are the sweep pyabf
# selects by default when the file is opened.
time_signal, signal = abf.sweepX, abf.sweepY   # time vector (e.g., s), ionic current (e.g., pA)


In [None]:
# Summary of the recording's acquisition metadata.
print("Sampling Rate:", abf.dataRate, "Hz")

print("Number of Sweeps:", abf.sweepCount)

print("Protocol Name:", abf.protocol)

print("Recording Start Time:", abf.abfDateTime)

print("Channel Units:", abf.adcUnits)

# This line used to print abf.adcUnits a second time under the
# "Channel Scaling Factors" label (copy-paste slip).
# NOTE(review): pyabf does not appear to expose a public scaling attribute;
# the per-channel gain it applies to the raw ADC values is kept in the
# private `_dataGain` list. getattr() keeps the cell running if that
# attribute is absent in the installed pyabf version — confirm.
print("Channel Scaling Factors:", getattr(abf, "_dataGain", "not available"))

In [None]:
from pyabf.tools.abfHeaderDisplay import abfInfoPage

# Dump the complete ABF header as text for inspection.
# NOTE(review): if this import fails, check whether the installed pyabf
# ships the module as `pyabf.abfHeaderDisplay` and/or offers the same text
# through the `abf.headerText` property — confirm against the pyabf version in use.
info_page = abfInfoPage(abf)
header_text = info_page.getText()  # This might throw the same error, but worth trying.
print(header_text)

## Single Sweep

In [None]:
# ---------------------------
# Single-sweep overview: ionic current versus time
# ---------------------------
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(time_signal, signal, label="Ionic current")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Current (pA)")
ax.set_title("Ionic Current Trace")
ax.legend()
plt.show()

## Load All Sweeps

In [None]:
# Collect every sweep of the recording into one continuous trace.
# The loop previously ran over range(1), i.e. only sweep 0 was ever loaded.
all_sweeps = []
all_times = []
for sweep in range(abf.sweepCount):
    # absoluteTime=True makes sweepX continue from the sweep's position in the
    # recording instead of restarting at 0, so the concatenated time axis is
    # monotonic and the time-window mask applied later stays meaningful.
    # For sweep 0 (and single-sweep files) the result is unchanged.
    abf.setSweep(sweep, absoluteTime=True)
    all_sweeps.append(abf.sweepY)
    all_times.append(abf.sweepX)

data_all = np.concatenate(all_sweeps)
time_all = np.concatenate(all_times)

# Samples per second; used to convert durations into sample counts.
sample_rate = abf.dataRate

### Masking

In [None]:
# Analysis window: keep samples from t = 35 up to 50 time units before the
# last sample (same units as time_all).
start_time = 35
end_time = time_all[-1] - 50
mask = np.logical_and(time_all >= start_time, time_all <= end_time)

time = time_all[mask]
data = data_all[mask]

### Cluster

In [None]:
# Two-cluster k-means on the current values; each sample is a single feature,
# hence the (n, 1) column layout.
n_clusters = 2
data_c = data.reshape(-1, 1)
print("Length of data", len(data))

# fit() returns the estimator itself, so construction and fitting can be chained.
kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(data_c)

# Per-sample cluster assignment and the two cluster centres.
labels = kmeans.labels_
centroids = kmeans.cluster_centers_
print(centroids)

# Samples belonging to each cluster ...
cluster_1 = data_c[labels == 0]
cluster_2 = data_c[labels == 1]
# ... and their positions within `data`.
cluster1_indices = np.flatnonzero(labels == 0)
cluster2_indices = np.flatnonzero(labels == 1)

In [None]:
# cluster_1 / cluster_2 are (n, 1) column arrays; taking column 0 gives the
# same 1-D values as the former per-element Python loop, without iterating
# over every sample of the recording in the interpreter.
cluster_1_flat = cluster_1[:, 0]
cluster_2_flat = cluster_2[:, 0]

# Overlaid histograms of the current values assigned to each cluster.
plt.figure()
plt.hist(cluster_1_flat, bins=100, color='blue', alpha=0.7, label='Cluster 1')
plt.hist(cluster_2_flat, bins=100, color='red', alpha=0.7, label='Cluster 2')
plt.xlabel('Values')
plt.ylabel('Frequency')
plt.title('Distribution of Cluster 1 and Cluster 2')
plt.legend()
plt.show()

### Baseline Correction

In [None]:
# Baseline estimate: centred rolling median over a 10-second window.
window_duration_sec = 10
window_size = int(window_duration_sec * sample_rate)  # window length in samples

data_series = pd.Series(data)
# min_periods=1 lets the window shrink at both ends of the trace, so the
# baseline has a value for every sample.
adaptive_baseline = (
    data_series
    .rolling(window=window_size, center=True, min_periods=1)
    .median()
)

# Trace with the slowly varying baseline removed.
corrected_data = data - adaptive_baseline.to_numpy()


In [None]:
# -------------------------------
# Event-detection parameters
# -------------------------------

# Threshold offset relative to the baseline: 1.5 standard deviations below
# the median of the baseline-corrected trace.
threshold = np.median(corrected_data) - 1.5 * corrected_data.std()

# Shortest run that still counts as an event (0.0001 s = 0.1 ms),
# expressed in seconds and then in samples.
min_event_duration = 0.0001
min_event_points = int(min_event_duration * sample_rate)


In [None]:
# Per-sample detection threshold: the rolling baseline shifted by the
# threshold offset. The detection cell reads this variable.
adaptive_threshold = adaptive_baseline + threshold

fig, ax = plt.subplots(figsize=(12, 5))
ax.plot(time, data, label="Original Current", color='blue', alpha=0.5)
ax.plot(time, adaptive_baseline, label="Adaptive Baseline", color='orange', linewidth=2)
ax.plot(time, adaptive_threshold, color='red', linestyle='--', label="Adaptive Threshold")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Current (pA)")
ax.set_title("Ionic Current Trace with Adaptive Baseline and Adaptive Threshold")
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Positions in `data` where the current falls below the per-sample threshold.
below_threshold = np.where(data < adaptive_threshold)[0]

if len(below_threshold) == 0:
    # exit() inside a notebook asks the Jupyter kernel to shut down, which
    # discards every variable computed so far. Raising stops "Run All" at
    # this cell with the same message while keeping the kernel state intact.
    raise RuntimeError("No events detected with the current threshold.")

# Group contiguous indices into individual events
def group_contiguous(indices, max_gap=100):
    """
    Split a sequence of indices into runs.

    Walking the sequence in order, an index joins the current run when it is
    at most max_gap away from the index before it; otherwise it opens a new
    run. Returns a list of lists; an empty input gives an empty list.
    """
    if len(indices) == 0:
        return []

    runs = [[indices[0]]]
    # Compare each index with its predecessor.
    for previous, current in zip(indices[:-1], indices[1:]):
        if current - previous > max_gap:
            runs.append([current])
        else:
            runs[-1].append(current)

    return runs

# Merge sub-threshold samples into candidate events (gaps of up to 10 samples
# are bridged) and keep only the runs with at least min_event_points samples.
event_groups = [
    group
    for group in group_contiguous(below_threshold, max_gap=10)
    if len(group) >= min_event_points
]

# Report how many events survived the duration filter.
print(f"Detected {len(event_groups)} events.")

In [None]:
# Summarise each detected event: time span, duration and deepest current value.
events = []
for group in event_groups:
    # Inclusive sample range covered by this event.
    window = slice(group[0], group[-1] + 1)
    event_time = time[window]
    event_data = data[window]

    events.append({
        "start_time": event_time[0],
        "end_time": event_time[-1],
        "duration": event_time[-1] - event_time[0],
        "amplitude": np.min(event_data),  # minimum current during the event
    })

# One line per event, numbered from 1.
print("Detected events:")
for i, event in enumerate(events, start=1):
    print(
        f"Event {i}: Start = {event['start_time']:.3f} s, "
        f"End = {event['end_time']:.3f} s, Duration = {event['duration']:.3f} s, "
        f"Amplitude = {event['amplitude']:.2f} pA"
    )

In [None]:
import pandas as pd

# Build one record per event, including the raw trace around it.
# Each event is padded with `extra_points` samples of context on both sides,
# so "start_time", "end_time" and "duration" describe the padded window.
extra_points = 300
n_samples = len(data)

event_list = []
for i, group in enumerate(event_groups, start=1):
    # Clamp the padded window to the trace. Without the clamp an event within
    # `extra_points` samples of the start produced a negative start index,
    # which NumPy interprets as "count from the end": the slice came back
    # empty and event_time[-1] below raised IndexError.
    start_idx = max(0, group[0] - extra_points)
    end_idx = min(n_samples - 1, group[-1] + extra_points)

    # Extract the time and current for this event and convert to lists
    event_time = time[start_idx:end_idx + 1].tolist()
    event_current = data[start_idx:end_idx + 1].tolist()

    # Duration of the padded window and amplitude (minimum current)
    duration = event_time[-1] - event_time[0]
    amplitude = min(event_current)

    event_list.append({
        "event_index": i,
        "start_time": event_time[0],
        "end_time": event_time[-1],
        "duration": duration,
        "amplitude": amplitude,
        "time_series": event_time,
        "current_series": event_current
    })

# Convert the list of dictionaries into a pandas DataFrame
df_events = pd.DataFrame(event_list)

In [None]:
#df_events = pd.read_csv("detected_events.csv")

In [None]:
def compute_blockade_percentage(event):
    """Return the blockade depth of an event as a percentage of its baseline.

    `event` must provide a "current_series" sequence. The baseline is the
    median of the first 50 samples (of all samples when the series is not
    longer than that); the depth is the baseline minus the series minimum.
    Returns 0 for series shorter than 10 samples or when the baseline is 0.
    """
    trace = event["current_series"]
    if len(trace) < 10:
        return 0

    # A [:50] slice never overruns, so short traces contribute every sample.
    open_level = np.median(trace[:50])
    if open_level == 0:
        return 0

    return ((open_level - min(trace)) / open_level) * 100

# Attach blockade_percentage to each event record (mutates event_list in place).
for event in event_list:
    event["blockade_percentage"] = compute_blockade_percentage(event)

df_events = pd.DataFrame(event_list)

# Keep events with blockade_percentage >= 5%.
# With zero events the DataFrame has no columns at all, so indexing
# "blockade_percentage" would raise KeyError; pass the empty frame through.
if df_events.empty:
    filtered_events_df = df_events.copy()
else:
    filtered_events_df = df_events[df_events["blockade_percentage"] >= 5].reset_index(drop=True)

# Display the filtered DataFrame
filtered_events_df

In [None]:
import matplotlib.pyplot as plt
import matplotlib.widgets as widgets
import ast
import numpy as np

def plot_events_interactive_jupyter(filtered_events_df):
    """Browse events one at a time in Jupyter with an ipywidgets slider.

    Parameters
    ----------
    filtered_events_df : pandas.DataFrame
        One row per event with the columns "event_index", "time_series",
        "current_series", "start_time" and "end_time". "blockade_percentage"
        is optional and shown as 0 when absent. The two series columns may
        hold sequences or their string representation, which is parsed with
        ast.literal_eval.

    Returns
    -------
    None
        The function only prints and draws. When the frame is empty, or when
        ipywidgets cannot be imported, a message is printed and nothing is
        plotted.
    """
    if filtered_events_df.empty:
        print("No events to plot.")
        return

    # ipywidgets is an optional dependency. Only its import is guarded: the
    # former `except ImportError:` clause wrapped the whole function and had
    # no body, which made the entire cell a syntax error.
    try:
        from ipywidgets import interact, IntSlider
    except ImportError:
        print("ipywidgets is not installed; call "
              "plot_single_event_manual(filtered_events_df, event_idx) instead.")
        return

    # The slider moves over positions in this sorted list, not over raw event
    # indices, because filtering leaves gaps in "event_index".
    available_event_indices = sorted(filtered_events_df["event_index"].tolist())

    def plot_single_event(slider_position):
        """Plot a single event based on the slider position."""
        if slider_position < 0 or slider_position >= len(available_event_indices):
            print(f"Invalid slider position: {slider_position}")
            return

        event_idx = available_event_indices[slider_position]

        event_row = filtered_events_df[filtered_events_df["event_index"] == event_idx]
        if event_row.empty:
            print(f"Event index {event_idx} not found in filtered_events_df.")
            return

        event_row = event_row.iloc[0]
        time_series = event_row["time_series"]
        current_series = event_row["current_series"]
        start_time = event_row["start_time"]
        end_time = event_row["end_time"]
        blockade_percentage = event_row.get("blockade_percentage", 0)

        # Series stored as text (e.g. after a CSV round trip) are parsed back
        # into Python lists.
        if isinstance(time_series, str):
            time_series = ast.literal_eval(time_series)
        if isinstance(current_series, str):
            current_series = ast.literal_eval(current_series)

        plt.figure(figsize=(12, 6))
        plt.plot(time_series, current_series, label=f"Event {event_idx}", color="blue", linewidth=1.5)

        plt.xlabel("Time (s)")
        # Unit label aligned with the other plots and printouts in this
        # notebook, which all report the current in pA.
        plt.ylabel("Current (pA)")
        plt.title(f"Event {event_idx} ({slider_position+1}/{len(available_event_indices)}) - Blockade: {blockade_percentage:.1f}% (Time {start_time:.3f} to {end_time:.3f} s)")
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    interact(plot_single_event,
             slider_position=IntSlider(
                 value=0,
                 min=0,
                 max=len(available_event_indices) - 1,
                 step=1,
                 description='Event:',
                 continuous_update=False,  # Only update when slider is released
                 layout={'width': '400px'}
             ))

    print(f"Use the slider to navigate between {len(available_event_indices)} available events")
    print(f"Event indices range: {available_event_indices[0]} to {available_event_indices[-1]}")
    print(f"Available event indices: {available_event_indices[:10]}{'...' if len(available_event_indices) > 10 else ''}")

def plot_single_event_manual(filtered_events_df, event_idx):
    """Plot the current trace of one event, selected by its "event_index".

    Parameters
    ----------
    filtered_events_df : pandas.DataFrame
        One row per event with the columns "event_index", "time_series",
        "current_series", "start_time" and "end_time". "blockade_percentage"
        is optional and shown as 0 when absent. The two series columns may
        hold sequences or their string representation, which is parsed with
        ast.literal_eval.
    event_idx
        Value to look up in the "event_index" column.

    Returns
    -------
    None
        When no row matches, the available indices are printed and nothing
        is plotted.
    """
    event_row = filtered_events_df[filtered_events_df["event_index"] == event_idx]
    if event_row.empty:
        print(f"Event index {event_idx} not found in filtered_events_df.")
        available_indices = sorted(filtered_events_df["event_index"].tolist())
        print(f"Available event indices: {available_indices}")
        return

    event_row = event_row.iloc[0]
    time_series = event_row["time_series"]
    current_series = event_row["current_series"]
    start_time = event_row["start_time"]
    end_time = event_row["end_time"]
    blockade_percentage = event_row.get("blockade_percentage", 0)

    # Series stored as text (e.g. after a CSV round trip) are parsed back
    # into Python lists.
    if isinstance(time_series, str):
        time_series = ast.literal_eval(time_series)
    if isinstance(current_series, str):
        current_series = ast.literal_eval(current_series)

    plt.figure(figsize=(12, 6))
    plt.plot(time_series, current_series, label=f"Event {event_idx}", color="blue", linewidth=1.5)

    plt.xlabel("Time (s)")
    # Unit label aligned with the other plots and printouts in this notebook,
    # which all report the current in pA.
    plt.ylabel("Current (pA)")
    plt.title(f"Event {event_idx} - Blockade: {blockade_percentage:.1f}% (Time {start_time:.3f} to {end_time:.3f} s)")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

In [None]:
# Open the slider-driven event browser on the events kept by the blockade filter.
plot_events_interactive_jupyter(filtered_events_df)

In [None]:
# Display the filtered events table.
filtered_events_df

### Features

In [None]:
import numpy as np
import pandas as pd
from sklearn.mixture import GaussianMixture

def estimate_levels_gmm(signal, max_states=6):
    """Estimate the number of discrete current levels in a trace.

    Fits 1-D Gaussian mixtures with 1..max_states components and keeps the
    one with the lowest BIC.

    Parameters
    ----------
    signal : array-like
        Current samples; flattened to one feature per sample.
    max_states : int, default 6
        Largest number of mixture components to try.

    Returns
    -------
    best_k : int
        Component count of the lowest-BIC model.
    level_values : numpy.ndarray
        Means of that model's components, sorted ascending.

    Raises
    ------
    ValueError
        Propagated from GaussianMixture.fit when the signal has too few
        samples to fit even a single component.
    """
    X = np.asarray(signal, dtype=float).reshape(-1, 1)

    # GaussianMixture rejects n_components > n_samples, so never try more
    # components than there are samples. Always try at least one, so that
    # best_gmm cannot stay None when max_states < 1.
    k_max = max(1, min(max_states, X.shape[0]))

    best_bic = np.inf
    best_k = 1
    best_gmm = None

    for k in range(1, k_max + 1):
        gmm = GaussianMixture(n_components=k, covariance_type='full', random_state=0)
        gmm.fit(X)
        bic = gmm.bic(X)
        # Keep the first fit unconditionally so a non-finite BIC cannot leave
        # best_gmm unset.
        if best_gmm is None or bic < best_bic:
            best_bic = bic
            best_k = k
            best_gmm = gmm

    level_values = np.sort(best_gmm.means_.flatten())
    return best_k, level_values

# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use
# whichever the installed NumPy provides.
_trapezoid = getattr(np, "trapezoid", None) or np.trapz

# Add per-event summary features to every record in event_list (in place).
for event in event_list:
    event_time = np.array(event["time_series"])
    event_current = np.array(event["current_series"])
    n_points = len(event_current)

    if n_points > 0:
        # Baseline here is the first sample of the padded event window.
        baseline = event_current[0]
        mean_current = event_current.mean()
        std_current = event_current.std()
        # Integral of the current over time (trapezoidal rule).
        area = _trapezoid(event_current, event_time)
        peak_current = event_current.max()

        num_level, level_values = estimate_levels_gmm(event_current, max_states=8)
        # Coefficient of variation; undefined when the mean is exactly 0.
        volatility = std_current / np.abs(mean_current) if mean_current != 0 else np.nan
    else:
        baseline = mean_current = std_current = area = peak_current = np.nan
        num_level = np.nan
        level_values = []
        volatility = np.nan

    # A gradient needs at least two samples.
    if n_points > 1:
        slopes = np.gradient(event_current, event_time)
        max_slope = slopes.max()
        min_slope = slopes.min()
    else:
        max_slope = min_slope = np.nan

    event.update({
        "baseline": baseline,
        "mean_current": mean_current,
        "std_current": std_current,
        "area": area,
        "peak_current": peak_current,
        "max_slope": max_slope,
        "min_slope": min_slope,
        "n_points": n_points,
        "num_level": num_level,
        # np.asarray() makes .tolist() valid for both the ndarray returned by
        # estimate_levels_gmm and the plain list used for empty events, where
        # calling .tolist() directly raised AttributeError.
        "level_values": np.asarray(level_values).tolist(),
        "Noise": volatility
    })

# Build DataFrame
# NOTE(review): this rebuilds filtered_events_df from *every* record in
# event_list, so the earlier blockade_percentage >= 5 filter is no longer
# applied to the table that is displayed and saved afterwards. Confirm that
# this is intended.
filtered_events_df = pd.DataFrame(event_list)


In [None]:
# Display the events table rebuilt from event_list with the feature columns added.
filtered_events_df

In [None]:
# Write the feature table to CSV. No `index` argument is passed, so the
# pandas default applies and the DataFrame index is written as the first column.
filtered_events_df.to_csv("detected_events_features.csv")

In [None]:
# Write the same table to a snappy-compressed Parquet file.
filtered_events_df.to_parquet('Lamdaa_with_features.parquet', compression='snappy')