In [None]:
import numpy as np
import os
import glob
from scipy import signal
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import bat_functions as bf

In [None]:
def save_fig(save_folder, plot_title, fig=None):
    """Write a figure to ``save_folder`` as a 600-dpi PNG named after its title.

    The file name is ``plot_title`` with every space replaced by a hyphen,
    followed by ``.png``.

    save_folder: directory the image is written into
    plot_title: title the file name is derived from
    fig: object whose ``savefig`` method is called; when omitted (or falsy)
        ``plt.savefig`` is called instead
    """
    file_name = '-'.join(plot_title.split(' ')) + '.png'
    target = os.path.join(save_folder, file_name)
    # Prefer the explicitly supplied figure, otherwise fall back to pyplot.
    write = fig.savefig if fig else plt.savefig
    write(target, bbox_inches='tight', dpi=600)

In [None]:
# Directory that figures are written into (passed to save_fig further down).
# It is created here if it does not exist yet; an existing directory is left alone.
plots_save_folder = '.../bats-data/plots/wing-beat-analysis'
os.makedirs(plots_save_folder, exist_ok=True)

In [None]:
def load_all_tracks(day_folders, min_thresh):
    """Load the wingbeat track file of every camera for every day folder.

    For each folder in ``day_folders`` the files matching
    ``<day_folder>/*/long_tracks_min_<min_thresh>_wingbeat.npy`` are loaded.

    day_folders: iterable of day folder paths; the folder's base name is used
        as the date key
    min_thresh: value substituted into the track file name

    Returns a nested dict ``{date: {camera: tracks}}`` where ``camera`` is the
    name of the directory that directly contains the track file and ``tracks``
    is whatever ``np.load`` returns for that file.
    """
    loaded = {}
    for day_folder in day_folders:
        print(day_folder)
        date = os.path.basename(day_folder)
        track_files = sorted(
            glob.glob(
                os.path.join(day_folder,
                             '*',
                             f'long_tracks_min_{min_thresh}_wingbeat.npy')
            )
        )
        loaded[date] = {}
        for track_file in track_files:
            # Take the camera name from the parent directory with os.path so
            # this also works when the path separator is not '/'.
            camera = os.path.basename(os.path.dirname(track_file))
            # NOTE: allow_pickle=True unpickles arbitrary objects, which can
            # execute code -- only load track files from a trusted source.
            loaded[date][camera] = np.load(track_file, allow_pickle=True)
    return loaded


day_folders = sorted(
    glob.glob(
        '.../kasanka-bats/processed/deep-learning/*Nov'
    )
)

min_thresh = 100

all_tracks = load_all_tracks(day_folders, min_thresh)

In [None]:
# Constants for the height estimate.
# NOTE(review): `shift` and FRAME_WIDTH are not used in this cell; presumably
# `shift` is a number of pixels cropped from each side of a 2704-pixel-wide
# frame -- confirm against the code that uses them.
shift = 48
HCONST = 1454.9 # pixels
FRAME_WIDTH = 2704 - (2 * shift)
WINGSPAN = .8 # meters, max extent while flying 

# Add a 'height' entry to every loaded track, in place, computed by
# bf.calculate_height from the track's 'mean_wing' value and the two
# constants above.
# NOTE(review): the plot in the next cell suggests the estimate is
# WINGSPAN * HCONST / mean_wing -- confirm in bat_functions.
for date, day_tracks in all_tracks.items():
    for camera, tracks in day_tracks.items():
        for track in tracks:
            height = bf.calculate_height(track['mean_wing'], HCONST, WINGSPAN)
            track['height'] = height


In [None]:
# Show how WINGSPAN * HCONST / wingspan_in_pixels changes over a range of
# wingspans measured in the image.
constant = WINGSPAN * HCONST
print(constant)
wing_pixels = np.arange(20, 300)
# Pass the pixel values as x explicitly: without them matplotlib plots against
# the array index (0..279), so the x axis would be offset by 20 pixels.
plt.plot(wing_pixels, constant / wing_pixels)
plt.xlabel('wingspan in image (pixels)')
plt.ylabel('WINGSPAN * HCONST / wingspan (meters)')

In [None]:
# Number of tracks loaded across every date and camera.
total = sum(
    len(camera_tracks)
    for per_camera in all_tracks.values()
    for camera_tracks in per_camera.values()
)
print(total)

In [None]:
# Power spectrum of one example track (first track of camera 'BBC' on '16Nov').
power = all_tracks['16Nov']['BBC'][0]['freqs_power']
freq = all_tracks['16Nov']['BBC'][0]['freqs']

plt.figure()
# `use_line_collection` is no longer passed: it was deprecated in
# Matplotlib 3.6 and removed in 3.8, where passing it raises a TypeError.
# Line-collection drawing has been the default since Matplotlib 3.3.
plt.stem(freq, abs(power))
plt.xlabel('frequency')
plt.ylabel('|power|')

In [None]:
# NOTE(review): this cell used to read `camera_freqs` from whatever the
# violin-plot cell further down had last left in the kernel, so it failed on
# a fresh "Restart & Run All". The frequencies are now taken explicitly from
# the same example camera as the spectrum cell above -- change the date /
# camera here if a different one was intended.
camera_freqs = [
    track_info['peak_freq'] for track_info in all_tracks['16Nov']['BBC']
]

# Drop tracks without a peak frequency, then show the values both in track
# order and sorted by size on the same axes.
subset = np.array(camera_freqs)
subset = subset[~np.isnan(subset)]
plt.plot(subset, label='track order')
sorted_subset = sorted(subset)
plt.plot(sorted_subset, label='sorted')
plt.ylabel('peak wingbeat frequency')
plt.legend()

In [None]:
# Cameras to show for each day, in the order they should appear on the x axis.
camera_names = {'16Nov':['NotChyniangale', 'Chyniangale',
                         'BBC', 'FibweParking', 'FibwePublic',
                         'MusoleTower', 'MusolePath', 'MusoleParking',
                         'Sunset', 'Puku'],
                '17Nov': ['NotChyniangale', 'Chyniangale',
                         'BBC', 'FibweParking2', 'FibwePublic',
                         'MusoleTower', 'MusolePath2', 'MusoleParking',
                         'Sunset', 'Puku'],
                '18Nov': ['NotChyniangale', 'Chyniangale',
                         'BBC', 'FibweParking', 'FibwePublic',
                         'MusoleTower', 'MusoleParking',
                         'Sunset', 'Puku'],
                '19Nov': ['NotChyniangale', 'Chyniangale',
                         'BBC', 'FibweParking', 'FibwePublic',
                         'MusoleTower', 'MusolePath', 'MusoleParking',
                         'Sunset', 'Puku'],
                '20Nov': ['NotChyniangale', 'Chyniangale',
                         'BBC', 'FibweParking', 'FibwePublic',
                         'MusoleTower', 'MusoleParking',
                         'Sunset', 'Puku'],
               }

# Fraction of the sorted peak frequencies kept per camera: the lowest and
# highest (1 - percent) share of the values are trimmed before plotting.
percent = .999

# One figure per day: a violin of the trimmed peak wingbeat frequencies for
# each camera, saved to plots_save_folder.
for date, day_camera_names in camera_names.items():

    peak_freqs = []
    xs = []
    for camera in day_camera_names:
        tracks = all_tracks[date][camera]
        camera_freqs = np.array(
            [track_info['peak_freq'] for track_info in tracks], dtype=float
        )
        # Tracks without a detected peak carry NaN; drop them before trimming.
        camera_freqs = camera_freqs[~np.isnan(camera_freqs)]
        sorted_freqs = np.sort(camera_freqs)
        n_freqs = len(sorted_freqs)
        core_freqs = sorted_freqs[int(n_freqs*(1-percent)):int(n_freqs*percent)]
        peak_freqs.extend(core_freqs)
        # Label every value with its camera name rather than an integer
        # position, so the tick labels cannot drift out of step with the
        # violins when a camera contributes no values.
        xs.extend([camera] * len(core_freqs))

    fig, ax1 = plt.subplots(1, 1)
    xs = np.array(xs)
    peak_freqs = np.array(peak_freqs)
    sns.violinplot(x=xs, y=peak_freqs, order=day_camera_names, ax=ax1)
    ax1.tick_params(axis='x', labelrotation=90)
    ax1.set_ylabel('peak wingbeat frequency')

    title = f"{date} wing beat info for all cameras {percent} percent"
    save_fig(plots_save_folder, title, fig)

In [None]:
# One three-panel figure per day: peak wingbeat frequency, mean wingspan
# (values below 100 only) and height, each as one violin per camera.
for date, day_tracks in all_tracks.items():
    peak_freqs = []
    xs = []
    wingspans = []
    # NOTE: this rebinds `camera_names` (a dict in the cell above) to a list.
    camera_names = []
    heights = []
    for camera, tracks in day_tracks.items():
        camera_names.append(camera)
        for track in tracks:
            wingspans.append(track['mean_wing'])
            peak_freqs.append(track['peak_freq'])
            heights.append(track['height'])
            # Group by camera name instead of an integer index so seaborn
            # places and labels the violins itself.
            xs.append(camera)
    wingspans = np.array(wingspans)
    xs = np.array(xs)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30,10))
    # x / y are passed by keyword: seaborn 0.12+ no longer accepts them
    # positionally. `order` keeps every camera on the axis even when it has
    # no data left, so the labels always match the violins.
    sns.violinplot(x=xs, y=peak_freqs, order=camera_names, ax=ax1)
    ax1.set_title(f"{date}, peak frequencies")

    small_wing = wingspans < 100
    sns.violinplot(x=xs[small_wing], y=wingspans[small_wing],
                   order=camera_names, ax=ax2)
    ax2.set_title(f"{date}, wingspans")
    ax2.set_ylim(0, 80)

    sns.violinplot(x=xs, y=heights, order=camera_names, ax=ax3)
    ax3.set_title(f"{date}, height")
    ax3.set_ylim(0, 150)

    for ax in (ax1, ax2, ax3):
        ax.tick_params(axis='x', labelrotation=45)

In [None]:
def get_power(raw_freqs, raw_powers, min_freq):
    """ Share of the power above min_freq that sits in the strongest bin.

    Only entries whose frequency is strictly greater than min_freq are
    considered. The result is the largest of their powers divided by the sum
    of their powers.

    raw_freqs: frequencies (array, list or scalar)
    raw_powers: powers associated with each raw freq value
    min_freq: minimum acceptable frequency value

    Returns np.nan when no frequency is above min_freq (this includes NaN
    frequencies) or when the selected powers sum to zero.
    """
    freqs = np.atleast_1d(np.asarray(raw_freqs, dtype=float))
    powers = np.atleast_1d(np.asarray(raw_powers))

    # NaN compares False against min_freq, so NaN frequencies are never
    # selected and need no separate check.
    above_min = freqs > min_freq
    if not np.any(above_min):
        return np.nan

    selected = powers[above_min]
    total_power = np.sum(selected)
    if total_power == 0:
        # Avoid a 0 / 0 division (and its RuntimeWarning) for flat spectra.
        return np.nan

    return np.max(selected) / total_power

In [None]:
# Store on every track, in place, the share of spectral power above min_freq
# that sits in its strongest frequency bin (see get_power).
min_freq = .75
for date, day_tracks in all_tracks.items():
    for camera, tracks in day_tracks.items():
        # Iterate over the tracks themselves: previously the body ran once per
        # camera on a `track` left over from an earlier cell, so almost no
        # track received a 'peak_freq_power' entry.
        for track in tracks:
            # get_power boolean-indexes its inputs, so hand it arrays even if
            # a track stores a list or a scalar placeholder.
            track['peak_freq_power'] = get_power(
                np.atleast_1d(track['freqs']),
                np.atleast_1d(track['freqs_power']),
                min_freq,
            )

In [None]:
# Scatter the peak wingbeat frequency of every '17Nov' / 'MusoleTower' track
# against its height and its first frame, with each point's opacity set by
# the track's 'peak_freq_power'.
for camera, tracks in all_tracks['17Nov'].items():
    if camera != "MusoleTower":
        continue
    wingspans = []
    peak_freqs = []
    first_frames = []
    heights = []
    power = []
    for track in tracks:
        wingspans.append(track['mean_wing'])
        peak_freqs.append(track['peak_freq'])
        first_frames.append(track['first_frame'])
        heights.append(track['height'])
        power.append(track['peak_freq_power'])

    # Matplotlib rejects an alpha array containing NaN or values outside
    # [0, 1], so keep only the tracks with a finite power and clip the rest.
    alphas = np.asarray(power, dtype=float)
    drawable = np.isfinite(alphas)
    alphas = np.clip(alphas[drawable], 0, 1)
    # An empty alpha array is also rejected; fall back to the default then.
    point_alpha = alphas if alphas.size else None

    shown_freqs = np.asarray(peak_freqs)[drawable]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    ax1.scatter(np.asarray(heights)[drawable], shown_freqs, alpha=point_alpha)
    ax1.set_title(f'{camera} height')
    ax1.set_xlabel('height')
    ax1.set_ylabel('peak wingbeat frequency')
    ax2.scatter(np.asarray(first_frames)[drawable], shown_freqs, alpha=point_alpha)
    ax2.set_title(f'{camera} frame')
    ax2.set_xlabel('first frame')
    ax2.set_ylabel('peak wingbeat frequency')
    

In [None]:
# Quick look at the fields stored on a track. This relies on `track` still
# being bound by the most recently executed `for track in ...` loop.
track.keys()