In [None]:
from scipy.io import loadmat
import numpy as np
import pandas as pd

from pathlib import Path
import matplotlib.pylab as plt

from scipy.ndimage import uniform_filter1d
from skvideo.io import vread, vwrite
from PIL import ImageDraw, Image
from matplotlib import colors

# RGB tuples of matplotlib's 'tab20' palette; the analysis cell indexes it per fly
cmap = plt.get_cmap('tab20').colors

In [None]:
# thresholds for stop analysis
# thresh_stop / thresh_walk: stop / walk periods shorter than this many samples are
#   re-labelled as the opposite state by the analysis cell
# thresh_vel: smoothed frame-to-frame displacement below which a sample counts as
#   "stopped" (same units as the tracked coordinates — presumably pixels; confirm)
thresh_stop, thresh_walk = 5, 5
thresh_vel = 0.25

# define paths to analyse
# Batch alternative (disabled): this glob used to be assigned to `ps` and was then
# immediately overwritten by the explicit list below, so it never had any effect.
# To analyse a whole directory tree, enable it *instead of* the explicit list; it is
# wrapped in sorted() so that `ps` is a printable, re-iterable list, not a one-shot generator.
# ps = sorted(Path(r'Y:\Neha sapkal\flat chamber- sucrose- 2-17-23-29.34\new 4-20-23\re tracked-2').glob(r'**/*-track.mat'))
ps = [Path('./starved-bb/starved-bb-track.mat')]
print(ps)

In [None]:
# run
#
# For every '*-track.mat' in `ps` this cell
#   1. loads the track, the border points ('coordinates.txt'), the first frame
#      ('frame.png') and the video,
#   2. per fly: splits the trajectory into left / right of the border line, saves a
#      trajectory plot, detects stops from the smoothed velocity, annotates the stop
#      frames in the video and saves a velocity plot,
#   3. writes the annotated video and a per-fly summary table ('right_left.csv').


def _bool_runs(mask):
    """Yield ``(start, end, value)`` for each run of equal consecutive values.

    ``mask`` is a 1-D boolean sequence; ``end`` is exclusive. An empty mask
    yields nothing.
    """
    mask = np.asarray(mask, dtype=bool)
    if mask.size == 0:
        return
    edges = np.flatnonzero(mask[1:] != mask[:-1]) + 1
    starts = np.concatenate(([0], edges))
    ends = np.concatenate((edges, [mask.size]))
    for start, end in zip(starts, ends):
        yield int(start), int(end), bool(mask[start])


def _merge_short_bouts(stop, min_walk, min_stop):
    """Return a cleaned copy of the boolean ``stop`` mask.

    First every walk run (False) shorter than ``min_walk`` is set to stop, then the
    runs are recomputed and every stop run (True) shorter than ``min_stop`` is set
    to walk. The input is not modified.
    """
    stop = np.array(stop, dtype=bool)

    # runs are materialised before the mask is modified
    for start, end, is_stop in list(_bool_runs(stop)):
        if not is_stop and (end - start) < min_walk:
            stop[start:end] = True

    for start, end, is_stop in list(_bool_runs(stop)):
        if is_stop and (end - start) < min_stop:
            stop[start:end] = False

    return stop


# how stops are marked in the video: True -> running stop number as text,
# False -> filled square in the fly's colour
number = True

for p_mat in ps:

    # files
    p_dir = p_mat.parent.parent
    p_txt = p_dir / 'coordinates.txt'
    p_png = p_dir / 'frame.png'
    p_avi = p_dir / (p_mat.name.replace('-track.mat', '') + '.avi')

    p_avi_annot = p_avi.parent / '{}_annot.avi'.format(p_avi.with_suffix('').name)
    # not used in this cell; kept in case other cells rely on it
    p_png2 = p_avi.parent / '{}_vel_stop.png'.format(p_avi.with_suffix('').name)

    # load tracking data (first two features of every fly are used as x, y)
    m = loadmat(p_mat, squeeze_me=True, struct_as_record=False)
    data = np.asarray(vars(m['trk'])['data'])
    if data.ndim == 2:
        # squeeze_me=True drops singleton dimensions, so a single-fly recording
        # presumably arrives without the leading fly axis — restore it
        data = data[np.newaxis]
    data = data[:, :, [0, 1]]

    # load border coordinates
    pnt = np.loadtxt(p_txt).astype(int)

    # load first frame
    img = Image.open(p_png)

    # load video
    vid = vread(str(p_avi))

    # define line separating left and right: rasterise the polyline through `pnt`
    # and store, for every integer y on it, the x of the border (last segment wins).
    # This is identical for all flies, so it is computed once per recording.
    segments = []
    line = dict()
    for p1, p2 in zip(pnt, pnt[1:]):
        dx = p2[0] - p1[0]
        dy = p2[1] - p1[1]
        # use absolute extents so segments running up / left are sampled densely too
        pxl = int(max(abs(dx), abs(dy)))

        seg_x = np.linspace(p1[0], p2[0], pxl + 1).astype(int)
        seg_y = np.linspace(p1[1], p2[1], pxl + 1).astype(int)

        segments.append((seg_x, seg_y))
        line.update(zip(seg_y.tolist(), seg_x.tolist()))

    records = []
    for i, d in enumerate(data):

        fig, ax = plt.subplots()

        # plot photo
        ax.imshow(img)
        # plot line defining points
        ax.scatter(pnt[:, 0], pnt[:, 1], zorder=99, color='k')
        # plot line definition
        for seg_x, seg_y in segments:
            ax.scatter(seg_x, seg_y, marker=',', s=1, zorder=98, color='k')

        # drop nan frames, but remember which video frame every kept sample came from
        fnan = np.isnan(d).any(axis=1)
        frame_idx = np.flatnonzero(~fnan)
        d = d[~fnan]

        x = d[:, 0]
        y = d[:, 1]

        # masks for left and right of line
        # NOTE: raises KeyError if a fly's y lies outside the y-range covered by the border
        border = np.array([line[int(v)] for v in y], dtype=float)
        bl = border >= x
        br = border < x

        # plot trajectory separated by left and right
        # (modulo keeps the palette index valid for more flies than colour pairs)
        ax.scatter(x[br], y[br], marker=',', s=1, color=cmap[(2*i + 1) % len(cmap)])
        ax.scatter(x[bl], y[bl], marker=',', s=1, color=cmap[(2*i) % len(cmap)])

        fig.savefig(p_dir / 'right_left_fly_{}.png'.format(i+1))
        plt.close(fig)

        # collect summary values for this fly
        rec = {
            'fly': i + 1,
            'frames_left': bl.sum(),
            'frames_right': br.sum(),
            'frames_left/right': bl.sum() / br.sum(),
            'dropped_frames': fnan.sum(),
        }

        # count stops
        fig, ax = plt.subplots(figsize=(100, 4))

        rgb = tuple([int(c*255) for c in colors.to_rgb('C{}'.format(i))])

        # velocity sample k is the displacement between kept samples k and k+1;
        # it is plotted at the video frame of sample k.
        # NOTE: displacement across dropped frames is not normalised by the gap length.
        frames_vel = frame_idx[:-1]

        vel = np.linalg.norm(np.diff(d, axis=0), axis=1)
        ax.plot(frames_vel, vel, c='C0', label='raw trace')

        vel = np.linalg.norm(np.diff(uniform_filter1d(d, 15, axis=0), axis=0), axis=1)
        ax.plot(frames_vel, vel, c='C1', label='smoothed')

        # stop = smoothed velocity below threshold, with short walk / stop periods merged
        stop = _merge_short_bouts(vel < thresh_vel, thresh_walk, thresh_stop)

        # count stops and write to video
        n_r, n_l = 0, 0
        f_r, f_l = 0, 0
        for start, end, is_stop in _bool_runs(stop):
            if not is_stop:
                continue

            # a stop is assigned to the side the fly is on at its first sample
            if br[start]:
                n_r += 1
                f_r += end - start
            else:
                n_l += 1
                f_l += end - start

            for f in range(start, end):
                px, py = d[f].astype(int)
                # map the sample index back to the original video frame
                frame = frame_idx[f]

                if number:
                    # separate name: `img` must stay the first frame for the next fly's plot
                    frame_img = Image.fromarray(vid[frame])
                    draw = ImageDraw.Draw(frame_img)
                    draw.text((px, py), str(n_r + n_l), rgb)
                    vid[frame] = np.array(frame_img)
                else:
                    w = 3
                    vid[frame, py-w:py+w, px-w:px+w, :] = rgb

        rec['stop_frames_left'] = f_l
        rec['stop_frames_right'] = f_r

        rec['n_stop_left'] = n_l
        rec['n_stop_right'] = n_r
        records.append(rec)

        print(n_r, n_l)

        # mark stop samples below the traces
        x_stop = frames_vel[stop]
        ax.scatter(x_stop, np.full(x_stop.shape, -0.25), marker='.', color='k')

        ax.axhline(thresh_vel, c='gray', ls='--')

        ax.axhline(0, c='k', lw=0.5)
        ax.set_xlabel('frame')
        ax.set_ylabel('velocity')
        ax.set_ylim(-1, 7)
        # ax.set_xlim(0, 500)
        # 1-based fly number, consistent with the file name and the 'fly' column
        ax.set_title('fly {} | stops L: {} R: {}'.format(i + 1, n_l, n_r))
        ax.legend()

        fig.savefig(p_dir / 'stops_fly_{}.png'.format(i+1))
        plt.close(fig)

    # save the video with the stop annotations of all flies
    vwrite(str(p_avi_annot), vid)

    # one row per fly; everything starts as float, count columns are cast to int below
    df = pd.DataFrame(records, dtype=float)

    df['fraction_stop_frames_left'] = df['stop_frames_left'] / df['frames_left']
    df['fraction_stop_frames_right'] = df['stop_frames_right'] / df['frames_right']

    df['stops_per_frame_left'] = df['n_stop_left'] / df['frames_left']
    df['stops_per_frame_right'] = df['n_stop_right'] / df['frames_right']

    cols = ['fly', 'frames_left', 'frames_right', 'dropped_frames', 'n_stop_left', 'n_stop_right']
    # astype returns a new frame; `df.loc[:, cols] = ...` would write in place and keep float dtype
    df = df.astype({c: int for c in cols})
    df.to_csv(p_dir / 'right_left.csv')