In [None]:
import numpy as np
import os
import glob
from scipy import signal
import matplotlib.pyplot as plt
import matplotlib as mpl
import bat_functions as bf

In [None]:
# Frames per second of the track recordings: the temporal resolution handed
# to Welch's method when estimating wing-beat spectra.
fps: int = 25

In [None]:
# Destination directory for figures written by save_fig; created up front so
# later saves have somewhere to go.
plots_save_folder = '.../bats-data/plots/wing-beat-analysis'
os.makedirs(name=plots_save_folder, exist_ok=True)

In [None]:
process_raw_tracks = False

# Optional preprocessing: for each day folder and each camera sub-folder,
# filter raw_tracks.npy with bf.threshold_short_tracks (presumably dropping
# tracks shorter than min_thresh frames) and cache the result next to it as
# long_tracks_min_{min_thresh}.npy. Files already cached are skipped.
if process_raw_tracks:
    folders = glob.glob(
        '.../kasanka-bats/processed/deep-learning/*Nov'
    )
    day_folders = sorted(folders)
    min_thresh = 100

    for day_folder in day_folders:
        print(day_folder)

        track_files = sorted(
            glob.glob(os.path.join(day_folder, '*/raw_tracks.npy'))
        )
        for track_file in track_files:
            new_file = os.path.join(os.path.dirname(track_file),
                                    f'long_tracks_min_{min_thresh}.npy')
            if os.path.exists(new_file):
                continue
            # allow_pickle=True unpickles object arrays, which can run
            # arbitrary code from a malicious file; only load trusted,
            # locally produced track files.
            tracks_raw = np.load(track_file, allow_pickle=True)
            tracks = bf.threshold_short_tracks(tracks_raw,
                                               min_length_threshold=min_thresh)

            np.save(new_file, tracks)

In [None]:
def save_fig(save_folder, plot_title, fig=None, dpi=600, file_format='png'):
    """ Save a figure to disk, naming the file after the plot title.

    save_folder: directory to write into (it is not created here).
    plot_title: title of the plot; spaces become '-' in the file name.
    fig: figure object to save. If None, the current pyplot figure is
        saved through plt.savefig.
    dpi: resolution of the saved image.
    file_format: file extension (without the dot) of the saved image.
    """
    plot_name = plot_title.replace(' ', '-')
    file = os.path.join(save_folder, f'{plot_name}.{file_format}')
    # Test for None explicitly: a figure-like object must never be skipped
    # just because it happens to evaluate as falsy.
    saver = plt if fig is None else fig
    saver.savefig(file, bbox_inches='tight', dpi=dpi)

In [None]:
def get_track_wingbeat_freqs(track, fps=25, min_freq=.75, max_nperseg=255,
                             min_peak_height=1):
    """ Calculate peak wing-beat frequencies and their associated power.

    Estimates the power spectral density of track['max_edge'] with Welch's
    method and keeps its local maxima. It then stores the dominant one, as
    selected by bf.get_peak_freq, on the track.

    track: track dict; must already contain 'max_edge'. Modified in place:
        'freqs', 'freqs_power', 'peak_freq' and 'peak_freq_power' are set.
    fps: frames per second - temporal resolution of the track.
    min_freq: minimum frequency for calculating peak_freq.
        Messily segmented tracks often have high power
        close to 0 Hz because actual signal is not clear.
    max_nperseg: upper bound on the Welch segment length. Tracks shorter
        than this use their whole length as one segment, so the frequency
        resolution is fps / min(len(max_edge), max_nperseg).
    min_peak_height: minimum spectral power for a local maximum to be kept.

    A track with an empty 'max_edge' gets empty 'freqs'/'freqs_power' and
    NaN 'peak_freq'/'peak_freq_power' instead of raising inside scipy.
    """
    assert 'max_edge' in track, "Track must have max_edge already computed"

    max_edge = track['max_edge']
    n_samples = len(max_edge)
    if n_samples == 0:
        # welch rejects nperseg=0; record "no frequency found" instead.
        track['freqs'] = np.array([])
        track['freqs_power'] = np.array([])
        track['peak_freq'] = np.nan
        track['peak_freq_power'] = np.nan
        return

    # Short tracks: one segment spanning the whole track. Long tracks:
    # overlapping segments of max_nperseg samples.
    nperseg = min(n_samples, max_nperseg)
    f, p = signal.welch(max_edge, fps, nperseg=nperseg)
    peaks = signal.find_peaks(p, threshold=0, height=min_peak_height)[0]

    track['freqs'] = f[peaks]
    track['freqs_power'] = p[peaks]

    peak_freq, freq_power = bf.get_peak_freq(track['freqs'],
                                             track['freqs_power'],
                                             min_freq)
    track['peak_freq'] = peak_freq
    track['peak_freq_power'] = freq_power
    
def add_wingbeat_info_to_tracks(tracks, fps=25, min_freq=.75,
                                remove_contours=False):
    """ Annotate every track with its dominant wing-beat frequency.

    Each track first gets any missing intermediate fields ('rects',
    'max_edge', 'mean_wing'). Then get_track_wingbeat_freqs adds the
    frequency fields. Tracks are modified in place; nothing is returned.

    tracks: iterable of track dicts
    fps: frames per second - temporal resolution of tracks
    min_freq: minimum frequency for calculating peak_freq.
        Messily segmented tracks often have high power close to 0 Hz
        because the actual signal is not clear.
    remove_contours: if True, drop each track's raw 'contour' entry after
        its rects are available, to save memory.
    """
    for trk in tracks:
        # Rects are obtained before the contour is dropped (presumably
        # bf.get_rects reads the contour).
        if 'rects' not in trk:
            trk['rects'] = bf.get_rects(trk)
        if remove_contours:
            trk.pop('contour', None)

        if 'max_edge' not in trk:
            # NaN-aware max along axis 1 of rects, one value per row
            # (presumably per frame - confirm against bf.get_rects).
            trk['max_edge'] = np.nanmax(trk['rects'], axis=1)
        if 'mean_wing' not in trk:
            trk['mean_wing'] = bf.get_wingspan(trk)

        get_track_wingbeat_freqs(trk, fps=fps, min_freq=min_freq)

In [None]:
process_long_tracks = True
remove_contours = True
overwrite = False

if process_long_tracks:

    folders = glob.glob(
        '.../kasanka-bats/processed/deep-learning/*Nov'
    )

    save_files = True
    day_folders = sorted(folders)
    min_thresh = 100

    # all_tracks[date][camera] -> loaded tracks with wing-beat info added.
    all_tracks = {}
    # Only the first day (and only the Chyniangale camera) is processed.
    for day_folder in day_folders[:1]:
        print(day_folder)

        date = os.path.basename(day_folder)
        track_files = sorted(
            glob.glob(
                os.path.join(day_folder,
                             f'Chyniangale/long_tracks_min_{min_thresh}.npy'))
        )
        all_tracks[date] = {}
        for track_file in track_files:
            # The camera name is the track file's parent directory; use
            # os.path so this also works with non-'/' separators.
            camera = os.path.basename(os.path.dirname(track_file))
            print(camera)
            # allow_pickle=True can run code from a malicious file; only
            # load trusted, locally produced track files.
            tracks = np.load(track_file, allow_pickle=True)
            add_wingbeat_info_to_tracks(tracks,
                                        fps=fps, min_freq=.75,
                                        remove_contours=remove_contours)
            all_tracks[date][camera] = tracks
            if save_files:
                new_file = os.path.join(
                    os.path.dirname(track_file),
                    f'long_tracks_min_{min_thresh}_wingbeat.npy')
                if not os.path.exists(new_file) or overwrite:
                    np.save(new_file, tracks)

In [None]:
# Inspect which fields the first loaded track carries after
# add_wingbeat_info_to_tracks has run on it.
tracks[0].keys()
    

In [None]:
# Peak wing-beat frequencies of the tracks whose peak lies in [3, 4) Hz.
peak_freqs = [
    trk['peak_freq']
    for trk in tracks
    if 3 <= trk['peak_freq'] < 4
]

In [None]:
# Observed peak frequencies (3-4 Hz) versus the discrete frequencies Welch's
# method can report. nperseg = min(track length, 255), and long tracks were
# thresholded at min_thresh = 100 frames, so bin k of any track sits at
# k * fps / n for some n in 100..255.
peak_freqs = np.around(np.array(peak_freqs), 5)
unique_freqs = np.unique(peak_freqs)
print(unique_freqs.shape)

freq_band = (3, 4)


def _band_bin_freqs(nperseg):
    """Bin frequencies k * fps / nperseg (k = 1..nperseg) inside freq_band."""
    # Multiply-then-divide (rather than a cumulative sum of fps / nperseg)
    # avoids rounding drift, so bins at exactly 3 Hz stay inside the band.
    candidates = np.arange(1, nperseg + 1) * fps / nperseg
    return candidates[(candidates >= freq_band[0]) & (candidates < freq_band[1])]


# Built here instead of reusing a later cell's variable, so the cell works
# under Restart & Run All.
sampling_freqs = np.concatenate(
    [_band_bin_freqs(n) for n in range(100, 256)]
)

fig, ax = plt.subplots()
hist_info = ax.hist(peak_freqs, bins=200, density=True, range=freq_band)
ax.hist(sampling_freqs, bins=200, density=True, alpha=.7, range=freq_band)

# Reference lines at the bins of the longest segment length (nperseg = 255).
length_freqs = _band_bin_freqs(255)
print(length_freqs)
for line_freq in length_freqs:
    ax.axvline(line_freq, ls='--')

ax.set_xlabel('Frequency')
ax.set_ylabel('Track density')

title = 'sampling derived frequency peak origins'

save_fig(plots_save_folder, title, fig=fig)

In [None]:
# Per-bin densities and bin edges of the observed peak-frequency histogram.
counts, bins = hist_info[:2]

In [None]:
# Edges of the most populated bin of the peak-frequency histogram.
bin_ind = np.argmax(counts)
min_bin_val = bins[bin_ind]
max_bin_val = bins[bin_ind + 1]

In [None]:
# Non-empty tracks whose peak frequency falls in the most populated bin.
focal_tracks = []
for t in tracks:
    if min_bin_val <= t['peak_freq'] < max_bin_val and len(t['max_edge']) > 0:
        focal_tracks.append(t)

print(len(focal_tracks))

In [None]:
# Peak frequency of the last track in `tracks`. Indexed explicitly instead of
# reusing a loop variable leaked from another cell, so the output does not
# depend on which cell ran most recently.
print(tracks[-1]['peak_freq'])

In [None]:
# Finest and coarsest Welch frequency resolution (Hz) for these tracks:
# nperseg is capped at 255 samples, and long tracks were thresholded at
# 100 frames.
fps / 255, fps / 100

In [None]:
# focal_peak = [t['first_frame'] for t in focal_tracks]
# focal_peak
# plt.scatter(focal_peak, np.arange(len(focal_peak)))

In [None]:
# Gather every frequency bin Welch's method yields for each focal track,
# using the same segment-length rule as get_track_wingbeat_freqs
# (nperseg = track length, capped at 255).
possible_frequencies = []
for t_num, t in enumerate(focal_tracks):
    nperseg = min(len(t['max_edge']), 255)
    f, p = signal.welch(t['max_edge'], fps, nperseg=nperseg)
    possible_frequencies.extend(f)


In [None]:
unique = np.unique(possible_frequencies)
# Distinct bin frequencies of the focal tracks that fall in [3, 4) Hz.
threes = unique[np.logical_and(unique >= 3, unique < 4)]
threes.shape

In [None]:
# Display the distinct focal-track bin frequencies found in [3, 4) Hz.
threes

In [None]:
# Candidate Welch segment lengths, 100 through 255 inclusive.
lengths = np.arange(100, 255 + 1)

In [None]:
# Every bin frequency in [3, 4) Hz that some segment length in `lengths` can
# produce: bin k of a length-n segment sits at k * fps / n. Computing that
# product directly, instead of cumulatively summing fps / n, avoids float
# drift that can push bins lying exactly at 3 or 4 Hz across the band edges.
all_freqs = []

for length in lengths:
    measured_freqs = np.arange(1, length + 1) * fps / length
    all_freqs.extend(measured_freqs[(measured_freqs < 4) & (measured_freqs >= 3)])

In [None]:

# Histogram of every reachable Welch bin frequency in the 3-4 Hz band.
# The trailing ';' suppresses the return value; assigning it to `_` would
# shadow IPython's last-output variable.
plt.figure()
plt.hist(all_freqs, bins=200)
plt.xlabel('Frequency')
plt.ylabel('Count');

In [None]:

# Distinct theoretical bin frequencies after rounding to 5 decimals, the
# same rounding applied to the observed peak frequencies.
unique = np.unique(np.round(np.asarray(all_freqs), decimals=5))
print(unique.shape)
plt.hist(unique, bins=100)

In [None]:
# NumPy array of the theoretical bin frequencies for vectorised inspection.
freqs = np.asarray(all_freqs)

In [None]:
# Number of distinct theoretical bin frequencies, without any rounding.
np.unique(freqs).shape

In [None]:
# Scratch check: np.unique collapses repeated identical float values to one.
np.unique(np.ones(10)*1.1)