In [None]:
# Modified raster plot: one page per time bin, one row per cluster
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import warnings
warnings.filterwarnings('ignore')

# Set matplotlib parameters
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10
plt.rcParams['axes.linewidth'] = 1.2

# Spike times in spike_inf['time'] are divided by this rate to obtain seconds.
# NOTE(review): 20 kHz is the divisor that used to be hard-coded in this cell —
# confirm it matches the acquisition settings.
sampling_rate = 20000.0  # samples per second

# Each PDF page covers one window of this length. Windows are laid out from the
# first to the last spike in spike_inf (trigger_time is not used in this cell).
time_bin_duration = 50.0  # seconds
bin_width_samples = time_bin_duration * sampling_rate

output_pdf_path = "/media/ubuntu/sda/mouse_test/processed_results/raster_overall_20250909_Janus_1_250909_144332.pdf"

if spike_inf.empty:
    # min()/max() of an empty column are NaN, which would otherwise surface as
    # an obscure "cannot convert float NaN to integer" error further down.
    raise ValueError("spike_inf is empty: there are no spikes to plot")

# Get all unique clusters (one raster row each, row 0 at the bottom)
clusters = sorted(spike_inf['cluster_id'].unique())

# Calculate time bins based on the data range
min_time = spike_inf['time'].min()
max_time = spike_inf['time'].max()
total_duration = (max_time - min_time) / sampling_rate  # Convert to seconds
# Always emit at least one page, even if every spike shares one timestamp
num_bins = max(1, int(np.ceil(total_duration / time_bin_duration)))

print(f"Creating raster plots: {len(clusters)} clusters, {num_bins} time bins")
print(f"Time range: {min_time/sampling_rate:.1f}s - {max_time/sampling_rate:.1f}s")
print(f"Bin duration: {time_bin_duration}s")
print(f"Clusters: {clusters}")

with PdfPages(output_pdf_path) as pdf:
    for bin_idx in range(num_bins):
        # Time window for this page, in samples
        bin_start_time = min_time + bin_idx * bin_width_samples
        bin_end_time = min(bin_start_time + bin_width_samples, max_time)

        # Convert to seconds for display
        bin_start_sec = bin_start_time / sampling_rate
        bin_end_sec = bin_end_time / sampling_rate

        # Bins are half-open [start, end) so no spike appears on two pages.
        # The last bin is clamped to max_time, so it must also keep its upper
        # edge or the final spike would never be drawn or counted.
        spike_times = spike_inf['time']
        if bin_idx == num_bins - 1:
            in_bin = (spike_times >= bin_start_time) & (spike_times <= bin_end_time)
        else:
            in_bin = (spike_times >= bin_start_time) & (spike_times < bin_end_time)
        # Filter the full table once per page; clusters are split from this subset
        bin_spikes = spike_inf[in_bin]
        total_spikes_in_bin = len(bin_spikes)

        # Create figure with white background
        fig, ax = plt.subplots(figsize=(12, 7.5))
        try:
            ax.set_facecolor('white')
            fig.patch.set_facecolor('white')

            ax.set_xlim(0, time_bin_duration)
            ax.set_ylim(-0.5, len(clusters) - 0.5)

            # Plot spikes for each cluster
            for i, cluster_id in enumerate(clusters):
                cluster_spikes = bin_spikes[bin_spikes['cluster_id'] == cluster_id]

                if not cluster_spikes.empty:
                    # Spike times relative to the start of this bin, in seconds
                    relative_times = (cluster_spikes['time'] - bin_start_time) / sampling_rate

                    # vlines takes ymin/ymax in DATA coordinates. axvline's
                    # ymin/ymax are axes fractions (0-1), which is why only the
                    # first two rows used to show any spikes.
                    ax.vlines(relative_times.to_numpy(), i - 0.4, i + 0.4,
                              colors='black', linewidth=0.3, alpha=0.8)

                # Add cluster label on the left
                spike_count = len(cluster_spikes)
                label_text = f"Cluster {cluster_id}\n({spike_count} spikes)"

                ax.text(-time_bin_duration * 0.05, i, label_text,
                        verticalalignment='center', horizontalalignment='right',
                        fontsize=8, bbox=dict(boxstyle='round,pad=0.3',
                        facecolor='lightgray', alpha=0.7))

            # Set labels and title
            ax.set_xlabel('Time (seconds)', fontsize=12)
            ax.set_ylabel('Clusters', fontsize=12)
            ax.set_title(f'Time Bin {bin_idx+1}/{num_bins} | {bin_start_sec:.1f}-{bin_end_sec:.1f}s',
                         fontsize=14, fontweight='bold')

            # One tick per cluster row
            ax.set_yticks(range(len(clusters)))
            ax.set_yticklabels([f'C{cid}' for cid in clusters])

            # Add grid
            ax.grid(True, alpha=0.3, axis='x')

            # Add bin info
            info_text = f"Bin {bin_idx+1}/{num_bins} | Time: {bin_start_sec:.1f}-{bin_end_sec:.1f}s | Total Spikes: {total_spikes_in_bin}"
            ax.text(0.5, 1.02, info_text, transform=ax.transAxes,
                    horizontalalignment='center', fontsize=10,
                    bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

            fig.tight_layout()

            # Save page to PDF
            pdf.savefig(fig, bbox_inches='tight', facecolor='white')
        finally:
            # Release the figure even if drawing or saving fails
            plt.close(fig)

        print(f"  ✓ Added bin {bin_idx+1}/{num_bins} ({total_spikes_in_bin} spikes)")

print("Raster plot saved successfully!")


In [None]:
# Alternative version: Using trigger_time for time bins
# This version uses the original trigger_time array to define time bins

def create_raster_by_time_bins(spike_inf, trigger_time, output_path, time_bin_duration=50.0,
                               sampling_rate=20000.0):
    """
    Create raster plot with one page per time bin, one row per cluster.

    The time axis is cut into consecutive windows of ``time_bin_duration``
    seconds. Each window becomes one PDF page on which every cluster gets its
    own row of vertical spike ticks. Windows are half-open ``[start, end)``,
    except the last one, which also includes its end point.

    Args:
        spike_inf: DataFrame with columns ['cluster_id', 'time'] ('time' in samples)
        trigger_time: Array of trigger times (in samples). When it holds more
            than one entry, its first and last entries bound the plotted
            range; otherwise (including None) the range of spike_inf['time']
            is used.
        output_path: Path for output PDF file
        time_bin_duration: Duration of each time bin in seconds
        sampling_rate: Samples per second, used to convert sample indices to
            seconds (default 20000.0)

    Raises:
        ValueError: if time_bin_duration or sampling_rate is not positive, or
            if the data range is needed but spike_inf has no rows.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.backends.backend_pdf import PdfPages

    if time_bin_duration <= 0:
        raise ValueError("time_bin_duration must be positive")
    if sampling_rate <= 0:
        raise ValueError("sampling_rate must be positive")

    # Get all unique clusters (one raster row each, row 0 at the bottom)
    clusters = sorted(spike_inf['cluster_id'].unique())

    # Calculate time bins based on trigger_time or data range
    if trigger_time is not None and len(trigger_time) > 1:
        # np.asarray makes [0] / [-1] positional, so a pandas Series with a
        # non-default index works too
        trigger_samples = np.asarray(trigger_time)
        min_time = trigger_samples[0]
        max_time = trigger_samples[-1]
        print(f"Using trigger_time range: {min_time/sampling_rate:.1f}s - {max_time/sampling_rate:.1f}s")
    else:
        if spike_inf.empty:
            # min()/max() would be NaN and fail later with an obscure error
            raise ValueError("spike_inf is empty and no trigger_time range was given")
        min_time = spike_inf['time'].min()
        max_time = spike_inf['time'].max()
        print(f"Using data range: {min_time/sampling_rate:.1f}s - {max_time/sampling_rate:.1f}s")

    total_duration = (max_time - min_time) / sampling_rate  # Convert to seconds
    # Always emit at least one page, even for a zero-length range
    num_bins = max(1, int(np.ceil(total_duration / time_bin_duration)))
    bin_width_samples = time_bin_duration * sampling_rate

    print(f"Creating raster plots: {len(clusters)} clusters, {num_bins} time bins")
    print(f"Bin duration: {time_bin_duration}s")

    spike_times = spike_inf['time']

    with PdfPages(output_path) as pdf:
        for bin_idx in range(num_bins):
            # Time window for this page, in samples
            bin_start_time = min_time + bin_idx * bin_width_samples
            bin_end_time = min(bin_start_time + bin_width_samples, max_time)

            # Convert to seconds for display
            bin_start_sec = bin_start_time / sampling_rate
            bin_end_sec = bin_end_time / sampling_rate

            # The last bin is clamped to max_time; keep its upper edge so a
            # spike exactly at max_time is not silently dropped.
            if bin_idx == num_bins - 1:
                in_bin = (spike_times >= bin_start_time) & (spike_times <= bin_end_time)
            else:
                in_bin = (spike_times >= bin_start_time) & (spike_times < bin_end_time)
            # Filter the full table once per page; clusters are split from this subset
            bin_spikes = spike_inf[in_bin]
            total_spikes_in_bin = len(bin_spikes)

            # Create figure with white background
            fig, ax = plt.subplots(figsize=(12, 7.5))
            try:
                ax.set_facecolor('white')
                fig.patch.set_facecolor('white')

                ax.set_xlim(0, time_bin_duration)
                ax.set_ylim(-0.5, len(clusters) - 0.5)

                # Plot spikes for each cluster
                for i, cluster_id in enumerate(clusters):
                    cluster_spikes = bin_spikes[bin_spikes['cluster_id'] == cluster_id]

                    if not cluster_spikes.empty:
                        # Spike times relative to the start of this bin, in seconds
                        relative_times = (cluster_spikes['time'] - bin_start_time) / sampling_rate

                        # vlines takes ymin/ymax in DATA coordinates. axvline's
                        # ymin/ymax are axes fractions (0-1), so i±0.4 only
                        # ever rendered spikes for the first two rows.
                        ax.vlines(relative_times.to_numpy(), i - 0.4, i + 0.4,
                                  colors='black', linewidth=0.3, alpha=0.8)

                    # Add cluster label on the left
                    spike_count = len(cluster_spikes)
                    label_text = f"Cluster {cluster_id}\n({spike_count} spikes)"

                    ax.text(-time_bin_duration * 0.05, i, label_text,
                            verticalalignment='center', horizontalalignment='right',
                            fontsize=8, bbox=dict(boxstyle='round,pad=0.3',
                            facecolor='lightgray', alpha=0.7))

                # Set labels and title
                ax.set_xlabel('Time (seconds)', fontsize=12)
                ax.set_ylabel('Clusters', fontsize=12)
                ax.set_title(f'Time Bin {bin_idx+1}/{num_bins} | {bin_start_sec:.1f}-{bin_end_sec:.1f}s',
                             fontsize=14, fontweight='bold')

                # One tick per cluster row
                ax.set_yticks(range(len(clusters)))
                ax.set_yticklabels([f'C{cid}' for cid in clusters])

                # Add grid
                ax.grid(True, alpha=0.3, axis='x')

                # Add bin info
                info_text = f"Bin {bin_idx+1}/{num_bins} | Time: {bin_start_sec:.1f}-{bin_end_sec:.1f}s | Total Spikes: {total_spikes_in_bin}"
                ax.text(0.5, 1.02, info_text, transform=ax.transAxes,
                        horizontalalignment='center', fontsize=10,
                        bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

                fig.tight_layout()

                # Save page to PDF
                pdf.savefig(fig, bbox_inches='tight', facecolor='white')
            finally:
                # Release the figure even if drawing or saving fails
                plt.close(fig)

            print(f"  ✓ Added bin {bin_idx+1}/{num_bins} ({total_spikes_in_bin} spikes)")

    print(f"Raster plot saved to: {output_path}")

# Example usage:
# create_raster_by_time_bins(spike_inf, trigger_time, "output.pdf", time_bin_duration=50.0)


In [None]:
# Example usage of the modified raster plot function
# Assuming you have spike_inf DataFrame and trigger_time array

# Example 1: Using the function with trigger_time
# create_raster_by_time_bins(spike_inf, trigger_time, 
#                           "/media/ubuntu/sda/mouse_test/processed_results/raster_by_bins.pdf", 
#                           time_bin_duration=50.0)

# Example 2: Using the function without trigger_time (uses data range)
# create_raster_by_time_bins(spike_inf, None, 
#                           "/media/ubuntu/sda/mouse_test/processed_results/raster_by_bins_no_trigger.pdf", 
#                           time_bin_duration=50.0)

# Example 3: Different time bin durations
# create_raster_by_time_bins(spike_inf, trigger_time, 
#                           "/media/ubuntu/sda/mouse_test/processed_results/raster_30s_bins.pdf", 
#                           time_bin_duration=30.0)

# Print the readiness banner and the list of key changes in one write
print("\n".join((
    "Modified raster plot functions ready!",
    "Key changes:",
    "1. One page per time bin (instead of one page per neuron)",
    "2. One row per cluster (instead of one row per trial)",
    "3. Time bins can be defined by trigger_time or data range",
    "4. Configurable time bin duration",
    "5. White background with black spike lines",
)))


In [None]:
# Debug version: Simple test to see all clusters
import matplotlib.pyplot as plt
import numpy as np

# Sanity-check figure: one full-width red line per cluster, so a missing row
# is obvious at a glance
fig, ax = plt.subplots(figsize=(12, 8))

# Cluster IDs in ascending order; list position == row on the y-axis
clusters = sorted(spike_inf['cluster_id'].unique())
n_rows = len(clusters)
print(f"Total clusters: {n_rows}")
print(f"Cluster IDs: {clusters}")

# 50-second x-range; half a row of padding below the first and above the last row
ax.set_xlim(0, 50)  # 50 seconds
ax.set_ylim(-0.5, n_rows - 0.5)

label_box = {'boxstyle': 'round', 'facecolor': 'yellow', 'alpha': 0.7}
for i, cluster_id in enumerate(clusters):
    # axhline spans the full axes width by default (xmin=0, xmax=1 are axes fractions)
    ax.axhline(i, color='red', linewidth=2)

    # Name tag just left of the plotting area (x is in data units)
    ax.text(-2, i, f"Cluster {cluster_id}",
            va='center', ha='right',
            fontsize=10, bbox=label_box)

# One tick per row, labelled with the cluster ID
tick_positions = range(n_rows)
tick_names = [f'C{cid}' for cid in clusters]
ax.set_yticks(tick_positions)
ax.set_yticklabels(tick_names)

ax.set_xlabel('Time (seconds)')
ax.set_ylabel('Clusters')
ax.set_title(f'Test Plot - All {n_rows} Clusters Should Be Visible')

plt.tight_layout()
plt.show()

print("If you can see all clusters in this test plot, the issue is elsewhere.")
print("If you only see top and bottom clusters, there's a y-axis scaling issue.")


In [None]:
# Fixed version: Ensure all clusters are visible
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import warnings
warnings.filterwarnings('ignore')

def create_fixed_raster_plot(spike_inf, output_path, time_bin_duration=50.0, sampling_rate=20000.0):
    """
    Create raster plot with proper y-axis scaling to show all clusters.

    The span from the first to the last spike is cut into consecutive windows
    of ``time_bin_duration`` seconds. Each window becomes one PDF page with
    one row of vertical spike ticks per cluster. Windows are half-open
    ``[start, end)``, except the last one, which also includes its end point
    so the final spike is plotted.

    Args:
        spike_inf: DataFrame with columns ['cluster_id', 'time'] ('time' in samples)
        output_path: Path for output PDF file
        time_bin_duration: Duration of each time bin in seconds
        sampling_rate: Samples per second, used to convert sample indices to
            seconds (default 20000.0)

    Raises:
        ValueError: if spike_inf has no rows, or if time_bin_duration or
            sampling_rate is not positive.
    """
    if time_bin_duration <= 0:
        raise ValueError("time_bin_duration must be positive")
    if sampling_rate <= 0:
        raise ValueError("sampling_rate must be positive")
    if spike_inf.empty:
        # min()/max() would be NaN and fail later with an obscure error
        raise ValueError("spike_inf is empty: there are no spikes to plot")

    # Get all unique clusters (one raster row each, row 0 at the bottom)
    clusters = sorted(spike_inf['cluster_id'].unique())

    # Calculate time bins based on the data range
    spike_times = spike_inf['time']
    min_time = spike_times.min()
    max_time = spike_times.max()
    total_duration = (max_time - min_time) / sampling_rate  # Convert to seconds
    # Always emit at least one page, even if every spike shares one timestamp
    num_bins = max(1, int(np.ceil(total_duration / time_bin_duration)))
    bin_width_samples = time_bin_duration * sampling_rate

    print(f"Creating raster plots: {len(clusters)} clusters, {num_bins} time bins")
    print(f"Time range: {min_time/sampling_rate:.1f}s - {max_time/sampling_rate:.1f}s")
    print(f"Bin duration: {time_bin_duration}s")
    print(f"Clusters: {clusters}")

    with PdfPages(output_path) as pdf:
        for bin_idx in range(num_bins):
            # Time window for this page, in samples
            bin_start_time = min_time + bin_idx * bin_width_samples
            bin_end_time = min(bin_start_time + bin_width_samples, max_time)

            # Convert to seconds for display
            bin_start_sec = bin_start_time / sampling_rate
            bin_end_sec = bin_end_time / sampling_rate

            # The last bin is clamped to max_time; keep its upper edge so the
            # spike at max_time is not silently dropped.
            if bin_idx == num_bins - 1:
                in_bin = (spike_times >= bin_start_time) & (spike_times <= bin_end_time)
            else:
                in_bin = (spike_times >= bin_start_time) & (spike_times < bin_end_time)
            # Filter the full table once per page; clusters are split from this subset
            bin_spikes = spike_inf[in_bin]
            total_spikes_in_bin = len(bin_spikes)

            # Create figure with white background
            fig, ax = plt.subplots(figsize=(12, 8))
            try:
                ax.set_facecolor('white')
                fig.patch.set_facecolor('white')

                ax.set_xlim(0, time_bin_duration)
                # Half a row of padding below the first and above the last cluster
                ax.set_ylim(-0.5, len(clusters) - 0.5)

                # Plot spikes for each cluster
                for i, cluster_id in enumerate(clusters):
                    cluster_spikes = bin_spikes[bin_spikes['cluster_id'] == cluster_id]

                    if not cluster_spikes.empty:
                        # Spike times relative to the start of this bin, in seconds
                        relative_times = (cluster_spikes['time'] - bin_start_time) / sampling_rate

                        # This is the actual visibility fix: vlines takes
                        # ymin/ymax in DATA coordinates, while axvline's
                        # ymin/ymax are axes fractions (0-1), so i±0.4 only
                        # ever rendered spikes for the first two rows.
                        ax.vlines(relative_times.to_numpy(), i - 0.4, i + 0.4,
                                  colors='black', linewidth=0.3, alpha=0.8)

                    # Add cluster label on the left
                    spike_count = len(cluster_spikes)
                    label_text = f"Cluster {cluster_id}\n({spike_count} spikes)"

                    ax.text(-time_bin_duration * 0.08, i, label_text,
                            verticalalignment='center', horizontalalignment='right',
                            fontsize=8, bbox=dict(boxstyle='round,pad=0.3',
                            facecolor='lightgray', alpha=0.7))

                # Set labels and title
                ax.set_xlabel('Time (seconds)', fontsize=12)
                ax.set_ylabel('Clusters', fontsize=12)
                ax.set_title(f'Time Bin {bin_idx+1}/{num_bins} | {bin_start_sec:.1f}-{bin_end_sec:.1f}s',
                             fontsize=14, fontweight='bold')

                # One tick per cluster row
                ax.set_yticks(range(len(clusters)))
                ax.set_yticklabels([f'C{cid}' for cid in clusters])

                # Add grid
                ax.grid(True, alpha=0.3, axis='x')

                # Add bin info
                info_text = f"Bin {bin_idx+1}/{num_bins} | Time: {bin_start_sec:.1f}-{bin_end_sec:.1f}s | Total Spikes: {total_spikes_in_bin}"
                ax.text(0.5, 1.02, info_text, transform=ax.transAxes,
                        horizontalalignment='center', fontsize=10,
                        bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

                # Fixed margins leave room on the left for the cluster labels
                fig.subplots_adjust(left=0.15, right=0.95, top=0.9, bottom=0.1)

                # Save page to PDF
                pdf.savefig(fig, bbox_inches='tight', facecolor='white')
            finally:
                # Release the figure even if drawing or saving fails
                plt.close(fig)

            print(f"  ✓ Added bin {bin_idx+1}/{num_bins} ({total_spikes_in_bin} spikes)")

    print(f"Raster plot saved to: {output_path}")

# Test the fixed function
# create_fixed_raster_plot(spike_inf, "/media/ubuntu/sda/mouse_test/processed_results/raster_fixed.pdf", time_bin_duration=50.0)


In [None]:
# Minimal test version - just to verify all clusters are visible
import matplotlib.pyplot as plt

# Cluster IDs in ascending order; list position == row on the y-axis
clusters = sorted(spike_inf['cluster_id'].unique())
n_rows = len(clusters)
print(f"Number of clusters: {n_rows}")

# Bare-bones figure: one blue guide line per cluster and nothing else
fig, ax = plt.subplots(figsize=(10, 6))

# Half a row of padding on each side of the y-range; 50 units wide on x
ax.set_ylim(-0.5, n_rows - 0.5)
ax.set_xlim(0, 50)

for i, cluster_id in enumerate(clusters):
    # axhline spans the full axes width by default (xmin=0, xmax=1 are axes fractions)
    ax.axhline(i, color='blue', linewidth=1)
    # Name tag left of the plotting area (x is in data units)
    ax.text(-5, i, f"C{cluster_id}",
            horizontalalignment='right', verticalalignment='center')

# One tick per row, labelled with the cluster ID
tick_names = [f'C{cid}' for cid in clusters]
ax.set_yticks(range(n_rows))
ax.set_yticklabels(tick_names)

ax.set_xlabel('Time (s)')
ax.set_ylabel('Cluster ID')
ax.set_title(f'Test: All {n_rows} clusters should be visible')

plt.tight_layout()
plt.show()

print("This test plot should show ALL clusters from top to bottom.")
print("If you only see the top and bottom clusters, there's a matplotlib issue.")