## Fly song notebook

The goal is to identify two distinct motifs in recordings of Drosophila song: *pulse* song and *sine* song.

Although there exist __[complex strategies](https://bmcbiol.biomedcentral.com/articles/10.1186/1741-7007-11-11)__ for performing this analysis, we will use a very simple approach here.

For detection of *pulse song*, we will threshold the audio signal, under the assumption that the magnitude of bouts of pulse song is consistently higher than that of sine song or no song.

For detection of *sine song*, we will compute a short-time Fourier transform (STFT) and evaluate the peak spectral power (and its associated frequency) at each time bin.

For both types of song, we then keep only the bouts that last at least a minimum duration.

In [1]:
# Import packages

import neo
import pywt
import numpy as np
import matplotlib.pyplot as pl
from scipy.ndimage import label
from scipy.signal import stft
%matplotlib notebook

In [2]:
## -- User parameters -- ##
# All durations below are expressed in seconds.

# --- input data ---
data_file = 'flysong.abf'  # recording to analyse (path to the data file)

# --- sine song detection ---
stft_window = 0.025                # length of one STFT window
sine_power_thresh_fxn = np.median  # callable applied to the power spectrum to obtain the power threshold
min_sine_dur = 0.100               # shortest continuous stretch accepted as a sine song bout
within_sine_gap = 0.050            # longest below-threshold stretch tolerated inside one sine bout

# --- pulse song detection ---
min_pulse_dur = 0.100              # shortest continuous stretch accepted as a pulse song bout
within_pulse_gap = 0.050           # longest below-threshold stretch tolerated inside one pulse bout
within_singlepulse_gap = 0.001     # longest below-threshold stretch tolerated inside one individual *pulse*
pulse_magnitude_thresh = 2.0       # threshold for a pulse, in multiples of the signal's standard deviation
pulse_rolling_mean_win = 0.10      # window over which the local mean used for thresholding is computed

In [3]:
# define a function to plot a line coloured by another variable
import matplotlib.colors as mcolors
from matplotlib.colors import ListedColormap, BoundaryNorm
from matplotlib.collections import LineCollection
from matplotlib.lines import Line2D

def label_plot(x, y, c, c_colors=None, labels=None, ax=None, **kwargs):
    """Plot y versus x as a single line whose colour follows the integer label c.

    Parameters
    ----------
    x, y : array-like
        Coordinates of the points of the line (same length).
    c : array-like
        One non-negative integer label per point. Label 0 is drawn in `color`;
        label k (k >= 1) is drawn in `c_colors[k-1]`.
    c_colors : array of RGBA rows, optional
        Colours for labels 1, 2, ... When omitted, one colour per label value
        from 1 up to the largest label in `c` is sampled evenly from `cmap`, so
        a label keeps its colour (and legend entry) even when a smaller label
        happens to be absent from `c`.
    labels : sequence of str, optional
        Legend entries, ordered as [label 0, label 1, ...]. No legend is drawn
        when omitted.
    ax : matplotlib axes, optional
        Axes to draw into; defaults to the current axes.
    **kwargs
        `cmap` (colormap used when `c_colors` is None, default viridis) and
        `color` (colour of label 0, default 'k') are consumed here; every
        remaining keyword (e.g. `linewidth`) is forwarded to `LineCollection`.

    Returns
    -------
    The axes that were drawn into.
    """
    x = np.asarray(x).astype(float)
    y = np.asarray(y).astype(float)
    c = np.asarray(c).astype(int)

    if ax is None:
        ax = pl.gca()
    cmap = kwargs.pop('cmap', pl.cm.viridis)
    if c_colors is None:
        # one colour per label value 1..max(c); identical to "one per unique
        # label" whenever the labels present are contiguous from 1
        n_labels = max(int(c.max()), 0) if c.size else 0
        c_colors = cmap(np.linspace(0, 1, n_labels))

    color = kwargs.pop('color', 'k')
    color = mcolors.to_rgba(color)

    # colormap row 0 is the "no label" colour, row k is the colour of label k
    lcmap = ListedColormap(np.concatenate([[color], c_colors]))
    norm = BoundaryNorm(np.arange(lcmap.N+1), lcmap.N)

    # build one two-point segment per pair of consecutive samples
    points = np.array([x, y]).T[:,None,:]
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    # styling keywords belong to the collection; Axes.add_collection does not accept them
    lc = LineCollection(segments, cmap=lcmap, norm=norm, **kwargs)
    # a segment takes the larger label of its two end points. A bitwise OR
    # would turn a 1 -> 2 transition into label 3, which is not in the colormap.
    lc.set_array(np.maximum(c[1:], c[:-1]))

    ax.add_collection(lc)
    ax.axis('auto')

    # legend
    if labels is not None:
        legend_lines = [Line2D([0], [0], color=lcmap.colors[i], lw=4) for i in range(lcmap.N)]
        ax.legend(legend_lines, labels, loc='best')

    return ax

# define a rolling average function
def rolling_mean(x, win=1, pad=False):
    """Moving average of x over a window of `win` consecutive samples.

    Parameters
    ----------
    x : array-like
        Input samples (treated as a flat sequence).
    win : int
        Window length in samples; must be at least 1.
    pad : bool
        When False the result has ``len(x) - win + 1`` entries (one per full
        window). When True the result is extended by repeating its first and
        last values so that it has the same length as x.

    Returns
    -------
    numpy array of window means.

    Raises
    ------
    ValueError
        If `win` is smaller than 1.
    """
    x = np.asarray(x)
    win = int(win)
    if win < 1:
        raise ValueError('win must be a positive number of samples')
    # accumulate in at least double precision: a running sum kept in float32
    # (a common audio dtype) loses precision after a few thousand samples
    cum = np.cumsum(x, dtype=np.result_type(x.dtype, np.float64))
    # difference of cumulative sums = sum over each window
    cum[win:] = cum[win:] - cum[:-win]
    cum = cum[win-1:] / win
    if pad:
        # win - 1 samples are added in total, split around the window centre
        cum = np.pad(cum, (win // 2, (win + 1) // 2 - 1), mode='edge')
    return cum

In [4]:
# load the data file into Python
# Open the file named by `data_file` with neo's Axon reader and keep the first
# segment of the block it returns.
r = neo.AxonIO(filename=data_file)
block = r.read_block()
segment = block.segments[0]

# Take the first analog signal of that segment as the audio trace; derive the
# sampling period Ts and the sampling rate fs = 1/Ts from it, then replace the
# signal object with a plain array that has singleton dimensions removed.
# NOTE(review): later cells divide durations given in seconds by Ts, so Ts is
# presumably in seconds - confirm the unit of `sampling_period`.
data = segment.analogsignals[0]
Ts = float(data.sampling_period)
fs = 1/Ts
data = data.as_array().squeeze()

# Continuous wavelet transform of the whole recording ('gaus1' wavelet, scales
# 1..128); Ts is passed as the fourth argument, presumably the sampling period
# from which pywt derives `cwt_freq`.
# `cwt` is indexed below as [scale, sample] and `cwt_freq` by scale; both are
# reused later to estimate carrier frequencies.
cwt,cwt_freq = pywt.cwt(data, np.arange(1,129), 'gaus1', Ts)

In [5]:
# process the parameters (trailing underscore in variable name means units of samples as opposed to time)
# each duration in seconds is divided by the sampling period Ts and rounded to the nearest whole sample
win_ = int(round(stft_window/Ts)) # length of STFT window in samples
min_pulse_dur_ = int(round(min_pulse_dur/Ts)) # minimum continuous time window for a pulse song bout in samples
within_pulse_gap_ = int(round(within_pulse_gap/Ts)) # maximum duration of time in samples allowed below threshold inside a single pulse bout
within_singlepulse_gap_ = int(round(within_singlepulse_gap/Ts)) # maximum duration of time in samples allowed below threshold inside a single *pulse*
within_sine_gap_ = int(round(within_sine_gap/Ts)) # maximum duration of time in samples allowed below threshold inside a single sine bout
# honour the user parameter: this used to be hard-coded to 0.300 s, which silently ignored
# pulse_rolling_mean_win (set pulse_rolling_mean_win = 0.300 to reproduce the old behaviour)
pulse_rolling_mean_win_ = int(round(pulse_rolling_mean_win/Ts)) # length of window in samples to compute local mean for thresholding

In [6]:
## -- Pulse song -- ##

# rolling mean of the rectified signal (local baseline), padded to the length of data
rm = rolling_mean(np.abs(data), pulse_rolling_mean_win_, pad=True)

# threshold: rectified signal must exceed its local mean by a multiple of the global standard deviation
putative_pulse = np.abs(data)-rm > pulse_magnitude_thresh*data.std()

# fill gaps within tolerance
# label every run of below-threshold samples, count the samples of each run once with
# bincount, and look the size up per sample (replaces a per-sample Python loop that
# rescanned the whole array for every label)
neg_labs,_ = label(~putative_pulse)
gap_sizes = np.bincount(neg_labs.ravel())
gaps_to_fill = np.flatnonzero((gap_sizes > 0) & (gap_sizes <= within_pulse_gap_))
fill = (gap_sizes <= within_pulse_gap_)[neg_labs]

# NOTE(review): min_pulse_dur_ is computed from the parameters but no minimum-duration
# filter is applied to pulse bouts here - confirm whether one is intended.
is_pulse = putative_pulse | fill

In [7]:
# individual pulse analysis

# threshold (same criterion as for pulse bouts; only the gap tolerance below differs)
putative_singlepulse = np.abs(data)-rm > pulse_magnitude_thresh*data.std()

# fill gaps within tolerance
# sizes of all below-threshold runs are counted once with bincount and looked up per
# sample (replaces a per-sample Python loop that rescanned the array for every label)
neg_labs,_ = label(~putative_singlepulse)
gap_sizes = np.bincount(neg_labs.ravel())
gaps_to_fill = np.flatnonzero((gap_sizes > 0) & (gap_sizes <= within_singlepulse_gap_))
fill = (gap_sizes <= within_singlepulse_gap_)[neg_labs]

is_singlepulse = putative_singlepulse | fill

# wavelet analysis for frequency
pulse_labs,npulses = label(is_singlepulse)
# samples outside any pulse stay NaN
pulse_carrier_freqs = np.full(data.size, np.nan)
# each pulse is a contiguous run of True samples: locate its start/stop once from the
# rising/falling edges instead of building a full-length mask (and copying the CWT) per pulse
pulse_edges = np.diff(np.concatenate(([0], is_singlepulse.astype(np.int8), [0])))
pulse_starts = np.flatnonzero(pulse_edges == 1)
pulse_stops = np.flatnonzero(pulse_edges == -1)
for start, stop in zip(pulse_starts, pulse_stops):
    sub_cwt = cwt[:, start:stop]
    # strongest wavelet response of each scale within the pulse
    maxs = np.max(np.abs(sub_cwt), axis=1)
    peak_freq = cwt_freq[np.argmax(maxs)]
    pulse_carrier_freqs[start:stop] = peak_freq

In [8]:
## -- Sine song -- ##

# perform STFT
freq,t,z = stft(data, fs=1/Ts, nperseg=win_, window='hann')
stft_Ts = np.mean(np.diff(t)) # spacing between consecutive STFT time bins in seconds
min_sine_dur_ = int(round(min_sine_dur/stft_Ts)) # minimum continuous duration for a sine song bout, in STFT time bins
power = np.abs(z)**2 # perhaps just magnitude and not power will be better, not clear yet

# to inspect the spectrogram: pl.pcolormesh(t, freq, power)

# analyze STFT

# power magnitude criteria
thresh = sine_power_thresh_fxn(power)
peak_power = np.max(power, axis=0)
is_high = peak_power > thresh

# frequency criteria: the strongest frequency of a bin must lie inside the sine band
sine_band = (80, 220) # lower / upper bound (exclusive) in the units of freq
peak_freq = freq[np.argmax(power, axis=0)]
is_band = (peak_freq>sine_band[0]) & (peak_freq<sine_band[1])

# combine magnitude and frequency criteria
putative_sine = is_high & is_band

# keep only runs of consecutive STFT bins that are at least min_sine_dur_ bins long
# (run sizes are counted once with bincount instead of rescanning the array per label)
labs,nlab = label(putative_sine)
run_sizes = np.bincount(labs.ravel())
keep_labs = np.flatnonzero(run_sizes >= min_sine_dur_)
keep_labs = keep_labs[keep_labs != 0]
keep_bins = (run_sizes[labs] >= min_sine_dur_) & (labs != 0)

# expand from STFT bins to audio samples: every sample takes the decision of the STFT bin
# whose time t is nearest. (Repeating each bin ceil(n_samples/n_bins) times treats bin k as
# starting at its own time stamp and drifts when n_samples is not a multiple of n_bins.)
sample_times = np.arange(data.size)*Ts
nearest_bin = np.clip(np.rint((sample_times - t[0])/stft_Ts).astype(int), 0, keep_bins.size-1)
putative_sine = keep_bins[nearest_bin]

# exclude pulse song from sine song definition
putative_sine = putative_sine & (~is_pulse) # sine song is only possible if a timepoint is not already considered pulse song

# fill gaps within tolerance
# NOTE(review): short gaps opened by the pulse exclusion above are filled again here, so
# is_sine can still overlap is_pulse - confirm that this is intended.
neg_labs,_ = label(~putative_sine)
gap_sizes = np.bincount(neg_labs.ravel())
gaps_to_fill = np.flatnonzero((gap_sizes > 0) & (gap_sizes <= within_sine_gap_))
fill = (gap_sizes <= within_sine_gap_)[neg_labs]
putative_sine = putative_sine | fill

is_sine = putative_sine

# locate each sine bout (contiguous run of True samples) from its rising/falling edges
song_labs,n_songs = label(is_sine)
bout_edges = np.diff(np.concatenate(([0], is_sine.astype(np.int8), [0])))
bout_starts = np.flatnonzero(bout_edges == 1)
bout_stops = np.flatnonzero(bout_edges == -1)

# estimate 1: FFT of each whole bout. Kept in its own array: it used to be written into
# sine_carrier_freqs and was then completely overwritten by the CWT estimate below, and its
# temporaries clobbered the STFT `power` computed above.
sine_carrier_freqs_fft = np.full(data.size, np.nan)
for start, stop in zip(bout_starts, bout_stops):
    bout = data[start:stop]
    bout_fft = np.fft.fft(bout)
    fft_freq = np.fft.fftfreq(bout.size, Ts)
    bout_power = np.abs(bout_fft)**2
    sine_carrier_freqs_fft[start:stop] = np.abs(fft_freq[np.argmax(bout_power)])

# estimate 2 (the one plotted): strongest CWT scale within fixed windows of winsize samples
winsize = 100
nwin = int(np.ceil(data.size/winsize))
sine_carrier_freqs = np.full(data.size, np.nan)
for i in range(nwin):
    win_slice = slice(i*winsize, i*winsize+winsize)
    if not is_sine[win_slice].any():
        continue # nothing of this window survives the blanking below
    sub_cwt = cwt[:,win_slice]
    maxs = np.max(np.abs(sub_cwt), axis=1)
    sine_carrier_freqs[win_slice] = cwt_freq[np.argmax(maxs)]

sine_carrier_freqs[~is_sine] = np.nan # blank out frequencies for those samples that are not within a sine song

In [9]:
# time stamp of every audio sample
time = Ts*np.arange(len(data))

# per-sample song code: 0 = no song, 1 = individual pulse, 2 = sine (sine is written last, so it wins on overlap)
song = np.zeros_like(data)
for song_code, song_mask in ((1, is_singlepulse), (2, is_sine)):
    song[song_mask] = song_code

# audio trace coloured by song type
ax = label_plot(time, data, song, cmap=pl.cm.winter, labels=['No song','Pulse song','Sine song'])
ax.set_xlabel('Time (sec)')
ax.set_ylabel('Audio amplitude (a.u.)')

# carrier frequencies as unconnected dots on a second y axis sharing the time axis
ax_carrier = ax.twinx()
dot_style = dict(marker='.', linewidth=0)
ax_carrier.plot(time, sine_carrier_freqs, color='springgreen', **dot_style)
ax_carrier.plot(time, pulse_carrier_freqs, color='blue', **dot_style)
ax_carrier.set_ylabel('Carrier frequency (Hz)')

<matplotlib.text.Text at 0x111a2ab00>