<a href="https://colab.research.google.com/github/dtabuena/Resources/blob/main/Tools/Scroll_Plot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
def scroll_plot(trace, time, sample_rate, wrap_t, t_range=None, row_dim=(5, 0.25), ylims=None,
                color='k', fig_ax=None, yticks=(0,), scale=True, title=None, zmax=None):
    """
    Plot a long 1-D trace as a column of subplots, one row per ``wrap_t`` of time.

    Parameters:
    - trace: 1D array-like of signal values
    - time: 1D array-like of time values matching ``trace`` (assumed evenly
      sampled at ``sample_rate``)
    - sample_rate: samples per unit of ``time``; each row holds
      ``int(wrap_t * sample_rate)`` samples
    - wrap_t: time span of each row, in the units of ``time``
    - t_range: optional [start, end]; only samples with start <= time < end are kept
    - row_dim: (width, height) of one row in inches; used only when ``fig_ax`` is None
    - ylims: optional [ymin, ymax] applied to every row. When omitted it defaults
      to the trace min/max, or to [-max_zv, max_zv] with
      ``max_zv = zmax * std(trace) + mean(trace)`` when ``zmax`` is given and
      ``max(|trace|)`` exceeds that level.
    - color: colour of the trace line
    - fig_ax: optional (fig, axes) to draw into; ``axes`` must provide one Axes per row
    - yticks: y values at which a dotted horizontal guide is drawn on every row
    - scale: draw a red scale bar on the first row
    - title: optional figure title (rows are then laid out below 90% of the height)
    - zmax: optional clipping level in standard deviations (see ``ylims``)

    Returns:
    - fig, ax: the figure and the axes object(s) from ``plt.subplots`` (or ``fig_ax``)

    Raises:
    - ValueError: if no samples remain after ``t_range`` or ``wrap_t * sample_rate < 1``
    """
    trace = np.asarray(trace)
    time = np.asarray(time)
    if t_range is not None:
        in_range = np.logical_and(time >= np.min(t_range), time < np.max(t_range))
        trace = trace[in_range]
        time = time[in_range]
    if len(time) == 0:
        raise ValueError('scroll_plot: no samples to plot (check t_range)')

    len_wrap = int(wrap_t * sample_rate)
    if len_wrap < 1:
        raise ValueError('scroll_plot: wrap_t * sample_rate must be at least 1 sample')

    if ylims is None:
        ylims = [np.min(trace), np.max(trace)]
        if zmax is not None:
            # Clip the shared y-range to +/- max_zv when the trace exceeds it.
            max_zv = (zmax * np.std(trace)) + np.mean(trace)
            if np.max(np.abs(trace)) > max_zv:
                ylims = [-max_zv, max_zv]

    # Wrap into (num_rows, len_wrap): NaN-pad the trace (NaN is not drawn) and
    # keep the clock ticking at the sampling interval for the padded slots.
    num_samples = len(time)
    num_rows = int(np.ceil(num_samples / len_wrap))
    n_pad = num_rows * len_wrap - num_samples
    wrapped_trace = np.concatenate([trace, np.full(n_pad, np.nan)]).reshape(num_rows, len_wrap)
    pad_time = time[-1] + np.arange(1, n_pad + 1) / sample_rate
    wrapped_time = np.concatenate([time, pad_time]).reshape(num_rows, len_wrap)

    if fig_ax is None:
        fig_dim = np.array(row_dim) * np.array((1, num_rows))
        fig, ax = plt.subplots(num_rows, 1, figsize=fig_dim)
    else:
        fig, ax = fig_ax
    # plt.subplots returns a bare Axes when num_rows == 1; index via a flat view.
    axes = np.atleast_1d(ax).ravel()

    if title is not None:
        fig.suptitle(title, color='k')
        fig.subplots_adjust(top=0.9)
    # Rows are stacked top-to-bottom inside [0, top] of the figure. Positions are
    # set after subplots_adjust, which can move axes created by plt.subplots.
    top = 0.9 if title is not None else 1.0
    row_h = top / num_rows
    wid = plt.rcParams['lines.linewidth'] / 2
    for r in range(num_rows):
        row_ax = axes[r]
        row_ax.plot(wrapped_time[r, :], wrapped_trace[r, :], color=color)
        row_ax.set_position([0, top - (r + 1) * row_h, 1, row_h])
        row_ax.set_ylim(ylims)
        row_ax.axis('off')
        row_ax.set_xlim([wrapped_time[r, 0], wrapped_time[r, 0] + wrap_t])
        for y in yticks:
            row_ax.axhline(y, color='k', linewidth=wid, linestyle='dotted')

    if scale:
        # Vertical bar ~1/4 of the y-range (one significant figure), horizontal bar 5% of a row.
        y_scale = log_round(np.diff(ylims)[0] / 4)
        x_scale = wrap_t * .05
        x0 = time[0] + x_scale
        x1 = x0 + x_scale
        y0 = -y_scale - y_scale * .5
        y1 = 0 - y_scale * .5

        scale_color = 'r'
        axes[0].plot([x0, x1], [y0, y0], color=scale_color)
        axes[0].plot([x0, x0], [y0, y1], color=scale_color)
        axes[0].text((x0 + x1) / 2, y0, str(x_scale), color=scale_color, ha='center', va='top')
        axes[0].text(x0, (y0 + y1) / 2, str(y_scale), color=scale_color, ha='right', va='center')

    plt.show()
    return fig, ax

def log_round(x):
    """
    Round ``x`` to one significant figure (used for scale-bar labels).

    Values whose rounded magnitude is >= 1 are returned as ``int`` so that
    ``str()`` of the result has no trailing ``.0``; smaller values are returned
    as floats. Zero returns 0 and negative inputs keep their sign.

    Raises:
    - ValueError: if ``x`` is NaN or infinite
    """
    if not np.isfinite(x):
        raise ValueError(f'log_round expects a finite number, got {x!r}')
    if x == 0:
        return 0
    # Decimal places that keep exactly one significant digit; negative values
    # round to tens, hundreds, ...
    decimals = -int(np.floor(np.log10(abs(x))))
    x_rnd = np.round(x, decimals)
    if abs(x_rnd) >= 1:
        # A one-significant-figure value >= 1 is a whole number, so int() is exact.
        x_rnd = int(x_rnd)
    return x_rnd

# Example call, kept inside a string literal so it is not executed as code.
"""
_ = scroll_plot(raw_lfp[:,0],time,sample_rate,wrap_t=1,t_range=[0,2],title='Title',zmax=3)
"""

In [None]:
def scroll_plot(trace, time, sample_rate, wrap_t, t_range=None, row_dim=(5, 0.25), ylims=None, color='k', fig_ax=None, yticks=(0,), scale=True, title=None, zmax=None):
    """
    Plot a long 1-D trace as stacked rows of ``wrap_t`` duration on a single axes.

    Each row is the next ``wrap_t``-long segment of the trace, re-zeroed in time
    and shifted down by a fixed vertical step (``y_step``) relative to the row above.

    Parameters:
    - trace: 1D array-like of signal values
    - time: 1D array-like of time values matching trace (assumed evenly sampled)
    - sample_rate: samples per unit of ``time``; each row holds
      ``int(wrap_t * sample_rate)`` samples
    - wrap_t: time span of each row, in the units of ``time``
    - t_range: optional [start, end]; only samples with start <= time < end are kept
    - row_dim: (width, height) of one row in inches; used only when fig_ax is None
    - ylims: optional [ymin, ymax]; its span (ymax - ymin) sets the row spacing
    - color: colour of the trace line
    - fig_ax: optional (fig, ax) pair to draw into (ax must be a single Axes)
    - yticks: y values (trace units) at which a dotted guide is drawn on every row
    - scale: whether to add a red scale bar next to the first row
    - title: optional figure title (the axes then fills the lower 90% of the figure)
    - zmax: optional; when ylims is not given, the row spacing is
      ``2 * zmax * std(trace)`` instead of the full min-to-max range, so
      excursions larger than that overlap neighbouring rows

    Returns:
    - fig, ax: Matplotlib figure and axes objects

    Raises:
    - ValueError: if no samples remain after t_range, or wrap_t * sample_rate < 1
    """
    trace = np.asarray(trace)
    time = np.asarray(time)

    # Restrict data to the specified time range, if provided
    if t_range is not None:
        in_range = np.logical_and(time >= np.min(t_range), time < np.max(t_range))
        trace = trace[in_range]
        time = time[in_range]
    if len(time) == 0:
        raise ValueError('scroll_plot: no samples to plot (check t_range)')

    # Samples per row and number of rows needed for the whole trace
    len_wrap = int(wrap_t * sample_rate)
    if len_wrap < 1:
        raise ValueError('scroll_plot: wrap_t * sample_rate must be at least 1 sample')
    num_samples = len(time)
    num_rows = int(np.ceil(num_samples / len_wrap))
    n_pad = num_rows * len_wrap - num_samples

    # Wrap into (num_rows, len_wrap). The trace is NaN-padded (not drawn); time
    # keeps ticking at 1/sample_rate so the last row's guide lines stay within
    # the row instead of jumping far to the right.
    wrapped_trace = np.concatenate([trace, np.full(n_pad, np.nan)]).reshape(num_rows, len_wrap)
    pad_time = time[-1] + np.arange(1, n_pad + 1) / sample_rate
    wrapped_time = np.concatenate([time, pad_time]).reshape(num_rows, len_wrap)
    # Re-zero each row so all rows share the x range starting at 0
    wrapped_time = wrapped_time - wrapped_time[:, [0]]

    # Set up the figure and axes if not provided
    if fig_ax is None:
        fig_dim = np.array(row_dim) * np.array((1, num_rows))
        fig, ax = plt.subplots(1, 1, figsize=fig_dim)
    else:
        fig, ax = fig_ax

    # Vertical distance between consecutive rows, in trace units
    if ylims is not None:
        y_step = np.max(ylims) - np.min(ylims)
    elif zmax is not None:
        # Width of a +/- zmax standard-deviation band
        y_step = 2 * zmax * np.std(trace)
    else:
        y_step = np.max(trace) - np.min(trace)

    # Row r is drawn r * y_step below row 0
    row_offsets = -np.arange(num_rows) * y_step
    y_shift_mat = np.repeat(row_offsets[:, np.newaxis], len_wrap, axis=1)
    ax.plot(wrapped_time.T, (wrapped_trace + y_shift_mat).T, color=color)
    ax.axis('off')

    # Add a title if specified
    if title is not None:
        fig.suptitle(title, color='k')
        fig.subplots_adjust(top=0.9)
    # Fill the figure (below the title, if any); set after subplots_adjust,
    # which can move axes created by plt.subplots.
    ax.set_position([0, 0, 1, 0.9 if title is not None else 1])

    # Draw a scale bar if scale=True
    if scale:
        y_scale = log_round(y_step / 4)
        x_scale = wrap_t * 0.05
        # Rows are re-zeroed, so the bar is placed relative to t=0, not time[0]
        x0 = x_scale
        x1 = x0 + x_scale
        y0 = -y_scale - y_scale * 0.5
        y1 = 0 - y_scale * 0.5

        # Plot scale bar in red on the first row
        scale_color = 'r'
        ax.plot([x0, x1], [y0, y0], color=scale_color)
        ax.plot([x0, x0], [y0, y1], color=scale_color)
        ax.text((x0 + x1) / 2, y0, str(x_scale), color=scale_color, ha='center', va='top')
        ax.text(x0, (y0 + y1) / 2, str(y_scale), color=scale_color, ha='right', va='center')

    # Add dotted horizontal lines at specified y-tick positions on every row
    wid = plt.rcParams['lines.linewidth'] / 2
    for y in yticks:
        ax.plot(wrapped_time.T, (y + y_shift_mat).T, color='k', linewidth=wid, linestyle='dotted')

    return fig, ax

def log_round(x):
    """
    Round ``x`` to one significant figure, for use in scale bar labels.

    Values whose rounded magnitude is >= 1 are returned as ``int`` so that
    ``str()`` of the result has no trailing ``.0``; smaller values are returned
    as floats. Zero returns 0 and negative inputs keep their sign.

    Raises:
    - ValueError: if ``x`` is NaN or infinite
    """
    if not np.isfinite(x):
        raise ValueError(f'log_round expects a finite number, got {x!r}')
    if x == 0:
        return 0
    # Decimal places that keep exactly one significant digit; negative values
    # round to tens, hundreds, ...
    decimals = -int(np.floor(np.log10(abs(x))))
    x_rnd = np.round(x, decimals)
    if abs(x_rnd) >= 1:
        # A one-significant-figure value >= 1 is a whole number, so int() is exact.
        x_rnd = int(x_rnd)
    return x_rnd

# Example: first channel of raw_lfp over t_range [0, 60), 10 time units per row.
# raw_lfp, time and sample_rate are not defined here; presumably they come
# from earlier notebook cells.
_ = scroll_plot(
    raw_lfp[:, 0],
    time,
    sample_rate,
    wrap_t=10,
    t_range=[0, 60],
    title='Title',
    zmax=3,
)
