# Process Sorted Data and Generate Raster Plots

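This notebook loads the Kilosort/phy spike-sorting output for each session in the `sorted` folder, keeps the clusters labelled `good`, and generates per-session raster plots (one row per good cluster) and smoothed firing-rate plots.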

## Features:
- Load cluster information and spike data from phy_folder_for_kilosort
- Generate cluster_inf and spike_inf DataFrames
- Create raster plots for each good cluster (50 s of data per PDF page)
- Export results as CSV files and PDF plots


In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import glob
from typing import List, Tuple, Dict
import warnings
# NOTE(review): this silences every warning for the whole kernel (pandas, matplotlib,
# elephant, ...), which can hide real problems; consider narrowing the filter.
warnings.filterwarnings('ignore')

# Notebook-wide matplotlib defaults applied to every figure created below.
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 10,
    'axes.linewidth': 1.2,
})


## Data Loading Functions


In [2]:
def load_cluster_info(phy_dir: str) -> pd.DataFrame:
    """
    Read the phy ``cluster_info.tsv`` table of one sorting output.

    Args:
        phy_dir: Path to phy_folder_for_kilosort directory

    Returns:
        DataFrame parsed from the tab-separated cluster_info.tsv; it is
        guaranteed to contain a ``cluster_id`` column.

    Raises:
        FileNotFoundError: if cluster_info.tsv is absent from ``phy_dir``.
        ValueError: if the table has no ``cluster_id`` column.
    """
    tsv_path = os.path.join(phy_dir, 'cluster_info.tsv')
    if os.path.exists(tsv_path):
        table = pd.read_csv(tsv_path, sep='\t')
    else:
        raise FileNotFoundError(f"Cluster info file not found: {tsv_path}")

    # Every downstream step keys on cluster_id, so reject tables without it.
    if 'cluster_id' in table.columns:
        return table
    raise ValueError(f"cluster_id column not found in {tsv_path}")


def load_spike_data(phy_dir: str) -> pd.DataFrame:
    """
    Build a per-spike table from phy's spike_clusters.npy and spike_times.npy.

    Args:
        phy_dir: Path to phy_folder_for_kilosort directory

    Returns:
        DataFrame with one row per spike and integer columns ``cluster_id``
        and ``time`` (the raw values stored in spike_times.npy).

    Raises:
        FileNotFoundError: if either .npy file is missing (clusters checked first).
        ValueError: if the two arrays hold a different number of spikes.
    """
    clusters_file = os.path.join(phy_dir, 'spike_clusters.npy')
    times_file = os.path.join(phy_dir, 'spike_times.npy')

    for required_path, label in ((clusters_file, 'Spike clusters'), (times_file, 'Spike times')):
        if not os.path.exists(required_path):
            raise FileNotFoundError(f"{label} file not found: {required_path}")

    # Collapse any extra axis so both arrays are 1-D, one entry per spike.
    cluster_ids = np.load(clusters_file).ravel()
    sample_times = np.load(times_file).ravel()

    n_ids, n_times = len(cluster_ids), len(sample_times)
    if n_ids != n_times:
        raise ValueError(f"Mismatch in spike_clusters and spike_times lengths: {n_ids} vs {n_times}")

    return pd.DataFrame({'cluster_id': cluster_ids.astype(int), 'time': sample_times.astype(int)})


def get_sample_rate(phy_dir: str, default_rate: float = 20000.0) -> float:
    """
    Get the sampling rate from the ``sample_rate = ...`` line of params.py.

    The file is parsed line by line, never executed. Text after ``#`` is ignored,
    so commented-out assignments are skipped. The value is converted with
    ``float()``, which accepts forms such as ``30000``, ``30000.``, ``3e4`` and
    ``30_000``.

    Args:
        phy_dir: Path to phy_folder_for_kilosort directory
        default_rate: Rate returned (with a printed warning) when params.py is
            missing or holds no parseable ``sample_rate`` assignment.

    Returns:
        Sample rate in Hz
    """
    import re

    params_path = os.path.join(phy_dir, 'params.py')
    if not os.path.exists(params_path):
        print(f"Warning: params.py not found in {phy_dir}, using default sample rate {default_rate:g} Hz")
        return float(default_rate)

    with open(params_path, 'r') as f:
        content = f.read()

    for line in content.splitlines():
        # Drop trailing comments so "# sample_rate = 1" is never picked up.
        code = line.split('#', 1)[0].strip()
        match = re.fullmatch(r'sample_rate\s*=\s*(.+)', code)
        if match is None:
            continue
        try:
            return float(match.group(1).strip())
        except ValueError:
            # An expression float() cannot parse: fall back to the default below.
            break

    print(f"Warning: Could not extract sample_rate from {params_path}, using default {default_rate:g} Hz")
    return float(default_rate)


## Data Processing Functions


In [3]:
def process_single_session(session_dir: str) -> Tuple[pd.DataFrame, pd.DataFrame, Dict]:
    """
    Load one Kilosort/phy session and return its good clusters, all spikes and a summary.

    Args:
        session_dir: Path to session directory (e.g., 20250909_Janus_1_250909_144332);
            it must contain a ``phy_folder_for_kilosort`` sub-directory. A trailing
            path separator is tolerated.

    Returns:
        Tuple of (cluster_inf, spike_inf, metadata):
          * cluster_inf: rows of cluster_info.tsv whose ``group`` is 'good', plus
            ``session`` and ``sample_rate`` columns, with the index reset to 0..n-1.
          * spike_inf: every spike of the session (all clusters, not only good
            ones) with ``cluster_id``, ``time`` (as stored in spike_times.npy),
            ``session``, ``sample_rate`` and ``time_seconds`` (= time / sample_rate).
          * metadata: dict with ``session``, ``sample_rate``, ``total_spikes`` (all
            clusters), ``total_good_spikes`` (good clusters only),
            ``total_clusters`` (good clusters) and ``recording_duration`` (largest
            ``time_seconds``, 0 when there are no spikes).

    Raises:
        FileNotFoundError: if the phy folder or one of its required files is missing.
        KeyError: if cluster_info.tsv has no ``group`` column, so good clusters
            cannot be selected.
    """
    phy_dir = os.path.join(session_dir, 'phy_folder_for_kilosort')

    if not os.path.exists(phy_dir):
        raise FileNotFoundError(f"phy_folder_for_kilosort not found in {session_dir}")

    # Load data
    cluster_inf = load_cluster_info(phy_dir)
    spike_inf = load_spike_data(phy_dir)
    sample_rate = get_sample_rate(phy_dir)

    if 'group' not in cluster_inf.columns:
        raise KeyError(f"'group' column not found in cluster_info.tsv under {phy_dir}; "
                       f"cannot select 'good' clusters")

    # normpath strips a trailing separator, which would otherwise make basename ''.
    session_name = os.path.basename(os.path.normpath(session_dir))

    # Add metadata, then keep only the clusters curated as 'good'.
    cluster_inf['session'] = session_name
    cluster_inf['sample_rate'] = sample_rate
    cluster_inf = cluster_inf[cluster_inf['group'] == 'good'].reset_index(drop=True)

    # Spikes keep all clusters; downstream cells filter to good clusters themselves.
    spike_inf['session'] = session_name
    spike_inf['sample_rate'] = sample_rate
    spike_inf['time_seconds'] = spike_inf['time'] / sample_rate

    good_spike_mask = spike_inf['cluster_id'].isin(cluster_inf['cluster_id'])

    metadata = {
        'session': session_name,
        'sample_rate': sample_rate,
        'total_spikes': len(spike_inf),
        # Matches total_clusters, which only counts good clusters.
        'total_good_spikes': int(good_spike_mask.sum()),
        'total_clusters': len(cluster_inf),
        'recording_duration': spike_inf['time_seconds'].max() if len(spike_inf) > 0 else 0
    }

    return cluster_inf, spike_inf, metadata


def process_all_sessions(sorted_dir: str) -> Tuple[pd.DataFrame, pd.DataFrame, List[Dict]]:
    """
    Run process_single_session on every sub-directory of ``sorted_dir``.

    Sessions are visited in sorted name order. A session that raises is
    reported and skipped, so one broken folder does not abort the batch.

    Args:
        sorted_dir: Path to sorted directory

    Returns:
        Tuple of (all_cluster_inf, all_spike_inf, all_metadata), where the two
        DataFrames are the per-session tables concatenated with a fresh index.

    Raises:
        RuntimeError: if no session could be processed.
    """
    session_names = sorted(
        entry for entry in os.listdir(sorted_dir)
        if os.path.isdir(os.path.join(sorted_dir, entry))
    )

    print(f"Found {len(session_names)} sessions to process:")
    for name in session_names:
        print(f"  - {name}")

    clusters_per_session, spikes_per_session, metadata_per_session = [], [], []
    for name in session_names:
        session_path = os.path.join(sorted_dir, name)
        try:
            print(f"\nProcessing {name}...")
            session_clusters, session_spikes, session_meta = process_single_session(session_path)

            clusters_per_session.append(session_clusters)
            spikes_per_session.append(session_spikes)
            metadata_per_session.append(session_meta)

            print(f"  ✓ Loaded {session_meta['total_clusters']} clusters, {session_meta['total_spikes']} spikes")
            print(f"  ✓ Recording duration: {session_meta['recording_duration']:.2f} seconds")
        except Exception as err:
            # Deliberate best effort: report the failing session and move on.
            print(f"  ✗ Error processing {name}: {err}")

    if not clusters_per_session:
        raise RuntimeError("No sessions were successfully processed")

    merged_clusters = pd.concat(clusters_per_session, ignore_index=True)
    merged_spikes = pd.concat(spikes_per_session, ignore_index=True)

    for summary_line in (
        "\n=== Summary ===",
        f"Total sessions processed: {len(metadata_per_session)}",
        f"Total clusters: {len(merged_clusters)}",
        f"Total spikes: {len(merged_spikes)}",
    ):
        print(summary_line)

    return merged_clusters, merged_spikes, metadata_per_session


## Raster Plot Functions


In [4]:
def create_raster_plot(spike_inf: pd.DataFrame, cluster_inf: pd.DataFrame, 
                      session_name: str, output_path: str, 
                      time_window: float = 50.0, figsize: Tuple[float, float] = (15, 10)):
    """
    Create a multi-page raster plot for all clusters in a session.

    Each PDF page covers ``time_window`` seconds of the session. Each cluster listed
    in ``cluster_inf`` for this session occupies one row, and each of its spikes is
    drawn as a short vertical tick on that row. The page range spans the first to
    last spike of the whole session. Only the listed clusters are drawn and counted.

    Args:
        spike_inf: Spike table with ``session``, ``cluster_id`` and ``time_seconds`` columns
        cluster_inf: Cluster table with ``session`` and ``cluster_id`` columns
        session_name: Name of the session to plot
        output_path: Path for output PDF file
        time_window: Seconds of data per page (must be > 0)
        figsize: Figure size

    Raises:
        ValueError: if ``time_window`` is not positive.
    """
    if time_window <= 0:
        raise ValueError(f"time_window must be positive, got {time_window}")

    # Filter data for this session
    session_spikes = spike_inf[spike_inf['session'] == session_name]
    session_clusters = cluster_inf[cluster_inf['session'] == session_name]

    if len(session_spikes) == 0:
        print(f"No spike data found for session {session_name}")
        return

    cluster_ids = sorted(session_clusters['cluster_id'].unique())
    if not cluster_ids:
        print(f"No clusters found for session {session_name}")
        return

    # Page layout spans the whole recording, judged from every spike in the session.
    min_time = session_spikes['time_seconds'].min()
    max_time = session_spikes['time_seconds'].max()
    total_duration = max_time - min_time
    # Always emit at least one page, even if all spikes share one timestamp.
    num_pages = max(1, int(np.ceil(total_duration / time_window)))

    # Sort each plotted cluster's spike times once; every page is then a binary-search slice.
    plotted_spikes = session_spikes[session_spikes['cluster_id'].isin(cluster_ids)]
    times_by_cluster = {
        cid: np.sort(group['time_seconds'].to_numpy())
        for cid, group in plotted_spikes.groupby('cluster_id')
    }
    no_spikes = np.empty(0)

    print(f"Creating raster plot for {len(cluster_ids)} clusters in session {session_name}...")
    print(f"Total duration: {total_duration:.1f}s, {num_pages} pages ({time_window:g}s per page)")

    with PdfPages(output_path) as pdf:
        for page in range(num_pages):
            page_start = min_time + page * time_window
            page_end = min(page_start + time_window, max_time)
            # Pages are half-open [start, end) except the last, which is closed so the
            # spike at max_time is not dropped.
            end_side = 'right' if page == num_pages - 1 else 'left'

            # Create figure with white background
            fig, ax = plt.subplots(figsize=figsize)
            ax.set_facecolor('white')
            fig.patch.set_facecolor('white')
            ax.set_xlim(page_start, page_end)
            ax.set_ylim(-0.5, len(cluster_ids) - 0.5)

            page_spikes_count = 0
            for row, cluster_id in enumerate(cluster_ids):
                cluster_times = times_by_cluster.get(cluster_id, no_spikes)
                lo = np.searchsorted(cluster_times, page_start, side='left')
                hi = np.searchsorted(cluster_times, page_end, side=end_side)
                spike_times = cluster_times[lo:hi]
                if spike_times.size == 0:
                    continue
                page_spikes_count += spike_times.size

                # vlines takes ymin/ymax in data units, so each tick sits on its own row
                # (axvline's ymin/ymax are axes fractions in [0, 1]). One artist per row.
                ax.vlines(spike_times, row - 0.4, row + 0.4,
                          colors='black', linewidths=0.3, alpha=0.8)

                # Cluster label to the left of the plotting area
                label_text = f"Cluster {cluster_id}\n({len(spike_times)} spikes)"
                ax.text(page_start - (page_end - page_start) * 0.05, row, label_text,
                        verticalalignment='center', horizontalalignment='right',
                        fontsize=8, bbox=dict(boxstyle='round,pad=0.3', facecolor='lightgray', alpha=0.7))

            # Set labels and title
            ax.set_xlabel('Time (seconds)', fontsize=12)
            ax.set_ylabel('Clusters', fontsize=12)
            ax.set_title(f'Session: {session_name} | Page {page+1}/{num_pages} | {page_start:.1f}-{page_end:.1f}s', 
                        fontsize=14, fontweight='bold')

            # One y tick per cluster row
            ax.set_yticks(range(len(cluster_ids)))
            ax.set_yticklabels([f'C{cid}' for cid in cluster_ids])
            ax.grid(True, alpha=0.3, axis='x')

            # Page info banner (spike count covers the plotted clusters only)
            info_text = f"Page {page+1}/{num_pages} | Time: {page_start:.1f}-{page_end:.1f}s | Spikes: {page_spikes_count}"
            ax.text(0.5, 1.02, info_text, transform=ax.transAxes, 
                    horizontalalignment='center', fontsize=10,
                    bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

            fig.tight_layout()
            pdf.savefig(fig, bbox_inches='tight', facecolor='white')
            plt.close(fig)

            print(f"  ✓ Added page {page+1}/{num_pages} ({page_spikes_count} spikes)")

    print(f"Raster plot saved to: {output_path}")


def create_all_raster_pdfs(cluster_inf: pd.DataFrame, spike_inf: pd.DataFrame, 
                          output_dir: str, time_window: float = 50.0):
    """
    Write one raster-plot PDF per session found in ``cluster_inf``.

    Files are named ``raster_plot_<session>.pdf`` inside ``output_dir``, which
    is created if needed. Sessions are handled in sorted order, and each PDF
    is produced by create_raster_plot with the given ``time_window``.

    Args:
        cluster_inf: DataFrame with cluster information (must have a ``session`` column)
        spike_inf: DataFrame with spike data
        output_dir: Directory to save PDF files
        time_window: Time window in seconds for each page
    """
    os.makedirs(output_dir, exist_ok=True)

    jobs = [
        (session, os.path.join(output_dir, f"raster_plot_{session}.pdf"))
        for session in sorted(cluster_inf['session'].unique())
    ]
    print(f"Creating raster plots for {len(jobs)} sessions...")

    for session, pdf_path in jobs:
        create_raster_plot(spike_inf, cluster_inf, session, pdf_path, time_window=time_window)

    print(f"\nAll raster plots saved to: {output_dir}")


## Main Processing Pipeline


In [5]:
# Input/output locations. The defaults are the original absolute paths; set the
# SORTED_DIR and OUTPUT_DIR environment variables to run the notebook elsewhere.
sorted_dir = os.environ.get("SORTED_DIR", "/media/ubuntu/sda/mouse_test/sorted")
output_dir = os.environ.get("OUTPUT_DIR", "/media/ubuntu/sda/mouse_test/processed_results")

# Create output directory
os.makedirs(output_dir, exist_ok=True)

print("=== Processing Sorted Data ===")
print(f"Sorted directory: {sorted_dir}")
print(f"Output directory: {output_dir}")
print()


=== Processing Sorted Data ===
Sorted directory: /media/ubuntu/sda/mouse_test/sorted
Output directory: /media/ubuntu/sda/mouse_test/processed_results



In [6]:
# Load every session under sorted_dir: good-cluster table, full spike table,
# and a list of per-session summary dicts.
cluster_inf, spike_inf, metadata = process_all_sessions(sorted_dir=sorted_dir)


Found 5 sessions to process:
  - 20250909_Janus_1_250909_144332
  - 20250909_Janus_2_250909_145527
  - 20250909_Janus_3_250909_151329
  - 20250910_Janus2_1_250910_110238
  - 20250910_Janus2_2_250910_112659

Processing 20250909_Janus_1_250909_144332...
  ✓ Loaded 15 clusters, 274041 spikes
  ✓ Recording duration: 599.38 seconds

Processing 20250909_Janus_2_250909_145527...
  ✓ Loaded 17 clusters, 499201 spikes
  ✓ Recording duration: 944.69 seconds

Processing 20250909_Janus_3_250909_151329...
  ✓ Loaded 5 clusters, 112999 spikes
  ✓ Recording duration: 298.93 seconds

Processing 20250910_Janus2_1_250910_110238...
  ✓ Loaded 13 clusters, 713816 spikes
  ✓ Recording duration: 1370.29 seconds

Processing 20250910_Janus2_2_250910_112659...
  ✓ Loaded 5 clusters, 143147 spikes
  ✓ Recording duration: 824.95 seconds

=== Summary ===
Total sessions processed: 5
Total clusters: 55
Total spikes: 1743204


In [7]:
# Per-session summary built from the metadata dicts returned by process_all_sessions.
print("\n=== Session Summary ===")
for session_meta in metadata:
    summary_lines = [
        f"Session: {session_meta['session']}",
        f"  Clusters: {session_meta['total_clusters']}",
        f"  Spikes: {session_meta['total_spikes']:,}",
        f"  Duration: {session_meta['recording_duration']:.2f} seconds",
        f"  Sample Rate: {session_meta['sample_rate']} Hz",
        "",  # blank separator line between sessions
    ]
    print("\n".join(summary_lines))

# Cluster counts across all sessions, by curation group and by session.
print("\n=== Cluster Information Summary ===")
print(f"Total clusters across all sessions: {len(cluster_inf)}")
print("Clusters by group:")
if 'group' in cluster_inf.columns:
    for group_label, n_clusters in cluster_inf['group'].value_counts().items():
        print(f"  {group_label}: {n_clusters}")

print("\nClusters by session:")
for session_label, n_clusters in cluster_inf['session'].value_counts().items():
    print(f"  {session_label}: {n_clusters}")



=== Session Summary ===
Session: 20250909_Janus_1_250909_144332
  Clusters: 15
  Spikes: 274,041
  Duration: 599.38 seconds
  Sample Rate: 20000.0 Hz

Session: 20250909_Janus_2_250909_145527
  Clusters: 17
  Spikes: 499,201
  Duration: 944.69 seconds
  Sample Rate: 20000.0 Hz

Session: 20250909_Janus_3_250909_151329
  Clusters: 5
  Spikes: 112,999
  Duration: 298.93 seconds
  Sample Rate: 20000.0 Hz

Session: 20250910_Janus2_1_250910_110238
  Clusters: 13
  Spikes: 713,816
  Duration: 1370.29 seconds
  Sample Rate: 20000.0 Hz

Session: 20250910_Janus2_2_250910_112659
  Clusters: 5
  Spikes: 143,147
  Duration: 824.95 seconds
  Sample Rate: 20000.0 Hz


=== Cluster Information Summary ===
Total clusters across all sessions: 55
Clusters by group:
  good: 55

Clusters by session:
  20250909_Janus_2_250909_145527: 17
  20250909_Janus_1_250909_144332: 15
  20250910_Janus2_1_250910_110238: 13
  20250909_Janus_3_250909_151329: 5
  20250910_Janus2_2_250910_112659: 5


In [8]:
# Write the combined tables to output_dir as CSV (without the DataFrame index).
cluster_csv_path, spike_csv_path = (
    os.path.join(output_dir, f"{table}_inf_all_sessions.csv") for table in ("cluster", "spike")
)
for frame, csv_path in ((cluster_inf, cluster_csv_path), (spike_inf, spike_csv_path)):
    frame.to_csv(csv_path, index=False)

print("\n=== Saved Data Files ===")
print(f"Cluster information: {cluster_csv_path}")
print(f"Spike information: {spike_csv_path}")
print(f"\nCluster info shape: {cluster_inf.shape}")
print(f"Spike info shape: {spike_inf.shape}")



=== Saved Data Files ===
Cluster information: /media/ubuntu/sda/mouse_test/processed_results/cluster_inf_all_sessions.csv
Spike information: /media/ubuntu/sda/mouse_test/processed_results/spike_inf_all_sessions.csv

Cluster info shape: (55, 31)
Spike info shape: (1743204, 5)


In [22]:
# Scratch: all spikes of the first session, plus 15 evenly spaced placeholder
# times from 0 to that session's last spike (not real stimulus triggers).
spike_inf_temp = spike_inf.loc[spike_inf['session'] == '20250909_Janus_1_250909_144332']
trigger_time = np.linspace(start=0, stop=spike_inf_temp['time_seconds'].max(), num=15)

In [32]:
cluster_inf_temp = cluster_inf.loc[cluster_inf['session'] == '20250909_Janus_1_250909_144332']  # scratch: good clusters of the first session


In [60]:
# Display whatever spike table is currently bound to spike_inf_temp. Several cells
# (including the per-session plotting loops) rebind this name, so the result depends
# on which cell ran last.
spike_inf_temp

Unnamed: 0,cluster_id,time,session,sample_rate,time_seconds
0,49,35,20250909_Janus_1_250909_144332,20000.0,0.00175
2,49,202,20250909_Janus_1_250909_144332,20000.0,0.01010
4,17,270,20250909_Janus_1_250909_144332,20000.0,0.01350
5,22,290,20250909_Janus_1_250909_144332,20000.0,0.01450
7,49,303,20250909_Janus_1_250909_144332,20000.0,0.01515
...,...,...,...,...,...
274036,17,11987149,20250909_Janus_1_250909_144332,20000.0,599.35745
274037,8,11987196,20250909_Janus_1_250909_144332,20000.0,599.35980
274038,7,11987301,20250909_Janus_1_250909_144332,20000.0,599.36505
274039,49,11987431,20250909_Janus_1_250909_144332,20000.0,599.37155


In [63]:
# Overview rasters: one PDF per session in output_dir, one page per time window,
# one row per good cluster.
time_bin_duration = 50.0  # seconds per PDF page

for session in cluster_inf['session'].unique():
    cluster_inf_temp = cluster_inf[cluster_inf['session'] == session]

    spike_inf_temp = spike_inf[spike_inf['session'] == session]
    # Keep only spikes belonging to this session's good clusters.
    spike_inf_temp = spike_inf_temp[spike_inf_temp['cluster_id'].isin(cluster_inf_temp['cluster_id'].unique())]
    if spike_inf_temp.empty:
        print(f"No good-cluster spikes in session {session}; skipping raster.")
        continue

    clusters = sorted(spike_inf_temp['cluster_id'].unique())
    # Work in seconds: time_seconds was computed with each session's own sample rate,
    # so nothing here assumes 20 kHz.
    spikes_by_cluster = {cid: grp['time_seconds'] for cid, grp in spike_inf_temp.groupby('cluster_id')}
    n_bins = int(spike_inf_temp['time_seconds'].max() // time_bin_duration) + 1

    with PdfPages(os.path.join(output_dir, f"raster_overall_{session}.pdf")) as pdf:
        for bin_idx in range(n_bins):
            # Half-open window [start, start + duration). It is deliberately not clamped
            # to the last spike, so that spike still lands on the final page.
            bin_start_sec = bin_idx * time_bin_duration
            bin_end_sec = bin_start_sec + time_bin_duration

            fig, ax = plt.subplots(figsize=(12, 7.5))
            ax.set_facecolor('white')
            fig.patch.set_facecolor('white')
            ax.set_xlim(0, time_bin_duration)
            # Set the y range so every cluster row is fully visible.
            ax.set_ylim(-1, len(clusters))

            for i, cluster_id in enumerate(clusters):
                cluster_times = spikes_by_cluster[cluster_id]
                in_window = cluster_times[(cluster_times >= bin_start_sec) & (cluster_times < bin_end_sec)]
                if not in_window.empty:
                    relative_times = in_window.to_numpy() - bin_start_sec
                    # Draw each cluster on its own row (y = i), one eventplot call per row
                    # rather than one artist per spike.
                    ax.eventplot(relative_times,
                                 lineoffsets=i,
                                 linelengths=0.8,  # tick height within the row
                                 linewidths=0.5,
                                 colors='black',
                                 alpha=0.8)

            # Label each row with its cluster id
            ax.set_yticks(range(len(clusters)))
            ax.set_yticklabels([f'C{cid}' for cid in clusters])
            ax.set_xlabel('Time within window (s)')
            ax.set_title(f'{session} | {bin_start_sec:.0f}-{bin_end_sec:.0f} s')

            fig.tight_layout()
            pdf.savefig(fig, bbox_inches='tight', facecolor='white')
            plt.close(fig)

    print("Raster plot saved successfully!")


Raster plot saved successfully!
Raster plot saved successfully!
Raster plot saved successfully!
Raster plot saved successfully!
Raster plot saved successfully!


In [None]:
# Per-image, per-date trial-averaged firing rates, smoothed with a 25 ms Gaussian
# kernel and sampled every 10 ms. Result: image_mean_spike_rate_data[image][date] is
# np.stack of one trial-averaged rate array per neuron.
#
# NOTE(review): this cell has never been executed (In [None]) and cannot run as-is:
#   * GaussianKernel, ms, neo and instantaneous_rate are only imported in a later cell;
#   * `date_order` is not defined anywhere in this notebook;
#   * `trigger_time` is indexed like a DataFrame with 'image', 'date', 'start' and 'end'
#     columns, but the only assignment in this notebook binds it to an np.linspace array;
#   * `spike_inf` here has columns cluster_id/time/session/sample_rate/time_seconds;
#     it has no 'Neuron' or 'date' column.
# TODO: adapt the trigger table and column names to this dataset before running.
image_mean_spike_rate_data = {}
gk = GaussianKernel(25 * ms)

total_duration = 20000  # NOTE(review): not used anywhere in this cell

for image in range(1, 118):  # images 1..117
    image_dict = {}
    
    for date in date_order:
        neuron_data = []
        
        for neuron in spike_inf['Neuron'].unique():
            neuron_df = spike_inf[spike_inf['Neuron'] == neuron]
            # Trials that showed this image on this date. The date is compared as int here
            # but as-is against neuron_df['date'] below; TODO confirm the dtypes agree.
            trigger_time_temp = trigger_time[(trigger_time['image'] == image) 
                                            & (trigger_time['date'] == int(date))]
            
            trial_rates = [] 
            
            for _, row in trigger_time_temp.iterrows():
                # Analysis window: 5000 before the trial start to 10000 after its end,
                # presumably in the same units as spike_inf['time'] (samples); TODO confirm.
                start = row['start'] - 5000
                end = row['end'] + 10000
                
                filtered_spikes = neuron_df[(neuron_df['date'] == date) 
                                          & (neuron_df['time'] >= start)
                                          & (neuron_df['time'] <= end)]
                
                relative_spikes = filtered_spikes['time'] - start
                # NOTE(review): dividing by 10 and tagging the result as ms implies a 10 kHz
                # sample clock, but the sessions in this notebook report 20000 Hz. Values
                # beyond t_stop=2000 would presumably be rejected by neo.SpikeTrain.
                relative_spikes = relative_spikes.values / 10
                temp_spiketrain = neo.SpikeTrain(relative_spikes.astype(int) * ms, t_stop=2000, t_start=0)
                inst_rate = instantaneous_rate(temp_spiketrain, kernel=gk, sampling_period=10*ms).magnitude
                trial_rates.append(inst_rate)
            
            # NOTE(review): when a neuron has no trials, mean_rate is NOT reassigned, so
            # the most recently computed mean (from an earlier neuron, date or image) is
            # appended again, or a NameError is raised if none exists yet.
            if trial_rates:
                mean_rate = np.mean(trial_rates, axis=0)
            
            neuron_data.append(mean_rate)
        
        # One entry per neuron; np.stack requires every mean_rate to have the same shape.
        neuron_data = np.stack(neuron_data)
        image_dict[date] = neuron_data
    
    image_mean_spike_rate_data[image] = image_dict

In [70]:
# Scratch inspection. NOTE(review): `relative_times_sec` is only assigned inside the
# per-cluster loop of the firing-rate cell, so this shows whatever the last loop
# iteration left behind, and raises NameError on a fresh kernel.
relative_times_sec.values

array([1.565000e+01, 3.650000e+01, 1.240500e+02, ..., 4.988320e+04,
       4.998650e+04, 4.999165e+04], shape=(1294,))

In [75]:
import neo
from quantities import ms
from elephant.statistics import instantaneous_rate
from elephant.kernels import GaussianKernel
import numpy as np
from quantities import sec  # seconds unit for the SpikeTrain times

# Smoothed firing-rate traces: one PDF per session in output_dir, one page per time
# window, one coloured and labelled line per good cluster.
time_bin_duration = 50.0  # seconds per PDF page
gk = GaussianKernel(2000 * ms)  # Gaussian smoothing kernel (sigma = 2 s)
cluster_cmap = plt.get_cmap('tab20')  # distinct colour per cluster

for session in cluster_inf['session'].unique():
    cluster_inf_temp = cluster_inf[cluster_inf['session'] == session]

    spike_inf_temp = spike_inf[spike_inf['session'] == session]
    # Keep only spikes belonging to this session's good clusters.
    spike_inf_temp = spike_inf_temp[spike_inf_temp['cluster_id'].isin(cluster_inf_temp['cluster_id'].unique())]
    if spike_inf_temp.empty:
        print(f"No good-cluster spikes in session {session}; skipping firing-rate plot.")
        continue

    clusters = sorted(spike_inf_temp['cluster_id'].unique())
    # Work in seconds: time_seconds was computed with each session's own sample rate,
    # so nothing here assumes 20 kHz.
    spikes_by_cluster = {cid: grp['time_seconds'] for cid, grp in spike_inf_temp.groupby('cluster_id')}
    n_bins = int(spike_inf_temp['time_seconds'].max() // time_bin_duration) + 1

    with PdfPages(os.path.join(output_dir, f"firing_rate_{session}.pdf")) as pdf:
        for bin_idx in range(n_bins):
            # Half-open window [start, start + duration). It is not clamped to the last
            # spike, so that spike is still included on the final page.
            bin_start_sec = bin_idx * time_bin_duration
            bin_end_sec = bin_start_sec + time_bin_duration

            fig, ax = plt.subplots(figsize=(12, 7.5))
            ax.set_facecolor('white')
            fig.patch.set_facecolor('white')
            ax.set_xlim(0, time_bin_duration)

            for i, cluster_id in enumerate(clusters):
                cluster_times = spikes_by_cluster[cluster_id]
                in_window = cluster_times[(cluster_times >= bin_start_sec) & (cluster_times < bin_end_sec)]
                if in_window.empty:
                    continue

                relative_times_sec = in_window - bin_start_sec
                temp_spiketrain = neo.SpikeTrain(relative_times_sec.values * sec,
                                                 t_start=0 * sec, t_stop=time_bin_duration * sec)
                inst_rate = instantaneous_rate(temp_spiketrain, kernel=gk, sampling_period=1000 * ms)

                # Use the signal's own sample times rather than re-spacing them with linspace.
                times = inst_rate.times.rescale(sec).magnitude
                rates = inst_rate.magnitude.flatten()

                ax.plot(times, rates, color=cluster_cmap(i % cluster_cmap.N),
                        linewidth=0.8, alpha=0.8, label=f'C{cluster_id}')

            # Rates are plotted unshifted, so the Hz tick values stay visible.
            ax.set_ylabel('Firing Rate (Hz)')
            ax.set_xlabel('Time within window (s)')
            ax.set_title(f'{session} | {bin_start_sec:.0f}-{bin_end_sec:.0f} s')
            if ax.get_lines():
                ax.legend(loc='upper right', fontsize=7, ncol=2)

            fig.tight_layout()
            pdf.savefig(fig, bbox_inches='tight', facecolor='white')
            plt.close(fig)

print("Firing rate plots saved successfully!")


Firing rate plots saved successfully!


In [74]:
# NOTE(review): `adjusted_rates` is not assigned in any cell of this notebook, so this
# scratch display raises NameError on a fresh kernel (presumably left over from a
# deleted cell).
adjusted_rates

array([63.45622917, 63.63365882, 63.57373194, 63.50425547, 63.54025903,
       63.64183593, 63.75008083, 63.90017355, 64.19781196, 64.71384416,
       65.39737031, 66.04954479, 66.39189852, 66.24291922, 65.66487686,
       64.90140023, 64.1891685 , 63.69458293, 63.56510784, 63.85604459,
       64.35942027, 64.6639831 , 64.49267384, 63.90533768, 63.14042252,
       62.38867951, 61.75685142, 61.30741107, 61.04893821, 60.93253249,
       60.89204079, 60.88120215, 60.87896794, 60.87861366, 60.87857016,
       60.8785657 , 60.8785657 , 60.8785657 , 60.8785657 , 60.8785657 ,
       60.8785657 , 60.8785657 , 60.8785657 , 60.8785657 , 60.8785657 ,
       60.8785657 , 60.8785657 , 60.8785657 , 60.8785657 , 60.8785657 ])