In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
from skg import nsphere_fit
import matplotlib.collections as mcoll
import matplotlib.path as mpath
import scipy.io as sio
import json
from tqdm import tqdm
import os


def colorline(
    x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0),
        linewidth=3, alpha=1.0):
    """
    http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
    http://matplotlib.org/examples/pylab_examples/multicolored_line.html
    Plot a colored line with coordinates x and y
    Optionally specify colors in the array z
    Optionally specify a colormap, a norm function and a line width
    """

    # Default colors equally spaced on [0,1]:
    if z is None:
        z = np.linspace(0.0, 1.0, len(x))

    # Special case if a single number:
    if not hasattr(z, "__iter__"):  # to check for numerical input -- this is a hack
        z = np.array([z])

    z = np.asarray(z)

    segments = make_segments(x, y)
    lc = mcoll.LineCollection(segments, array=z, cmap=cmap, norm=norm,
                              linewidth=linewidth, alpha=alpha)

    ax = plt.gca()
    ax.add_collection(lc)

    return lc


def make_segments(x, y):
    """
    Create list of line segments from x and y coordinates, in the correct format
    for LineCollection: an array of the form numlines x (points per line) x 2 (x
    and y) array
    """

    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    return segments

def generate_arc(center, radius, start_angle, end_angle, num_points):
    angles = np.linspace(start_angle, end_angle, num_points)
    x = center[0] + radius * np.cos(angles)
    y = center[1] + radius * np.sin(angles)
    return x, y

# create figure and video directory
if not os.path.exists('figures'):
    os.makedirs('figures')
if not os.path.exists('videos'):
    os.makedirs('videos')

%matplotlib qt

In [16]:
FPS = 10
N_FLIES = 4
prefix = '20hr-wingless-orco-yy'

In [17]:
# PHASE 1
# load background image from mat file
p1_bg = sio.loadmat(f'{prefix}_phase_1/{prefix}_phase_1-bg.mat')['bg'][0][0].item()[0]
# PHASE 2
# load background image from mat file
p2_bg = sio.loadmat(f'{prefix}_phase_2/{prefix}_phase_2-bg.mat')['bg'][0][0].item()[0]

In [18]:
recalculate = False
if os.path.exists('arena.json') and not recalculate:
    
    with open('arena.json', 'r') as f:
        arena = json.load(f)

    center = np.array(arena['center'])
    radius = arena['radius']
    sf = arena['sf']

else:
    # get arena margins
    plt.figure()
    plt.imshow(p1_bg, cmap='gray')
    plt.title('Select 5 points on the margin of the arena')
    pts = plt.ginput(5)
    pts = np.array(pts)
    plt.close()

    # TRUE RADIUS
    true_radius = 75 # mm

    radius, center = nsphere_fit(pts)
    print('Center:', center)
    print('Radius:', radius)
    sf = radius / true_radius

    # save the properties of the arena
    arena = {
        'center': center.tolist(),
        'radius': radius,
        'sf': sf
    }

    with open('arena.json', 'w') as f:
        json.dump(arena, f)

In [19]:
from scipy.optimize import minimize

def nearest_arc_distance(x, y, center, radius, start_angle, end_angle):
    start_angle = start_angle % (2 * np.pi)
    end_angle = end_angle % (2 * np.pi)
    if start_angle>end_angle:
        end_angle += 2 * np.pi
    def distance(x, y, center, radius, angle):
        x0, y0 = center
        x1 = x0 + radius * np.cos(angle)
        y1 = y0 + radius * np.sin(angle)
        return np.sqrt((x - x1)**2 + (y - y1)**2)
    # minimize the distance to the arc withing the given angle range
    res = minimize(lambda angle: distance(x, y, center, radius, angle), (start_angle + end_angle) / 2, bounds=[(start_angle, end_angle)])
    angle = res.x
    d = distance(x, y, center, radius, angle)
    return d, angle


In [20]:
def get_y(ying_pts):
    ying_1 = np.array(ying_pts[:5])
    ying_2 = np.array(ying_pts[5:10])
    ying_3 = np.array(ying_pts[10:])

    # fit circles on three sets of points
    ying_radius_1, ying_center_1 = nsphere_fit(ying_1)
    ying_radius_2, ying_center_2 = nsphere_fit(ying_2)
    ying_radius_3, ying_center_3 = nsphere_fit(ying_3)

    # find the best estimate of the ying circle
    ying_inner_radius = np.mean([
        ying_radius_2,
        ying_radius_3,
        np.linalg.norm(ying_center_2 - ying_center_3) / 2,
        ying_radius_1/2
    ])
    ying_outer_radius = np.mean([
        ying_radius_2*2,
        ying_radius_3*2,
        np.linalg.norm(ying_center_2 - ying_center_3),
        ying_radius_1
    ])
    ying_center_estimate = np.mean([
        ying_center_2,
        ying_center_3,
        ying_center_1
    ], axis=0)
    ying1_angle = np.arctan2(ying_center_2[1] - ying_center_estimate[1], ying_center_2[0] - ying_center_estimate[0])
    ying2_angle = np.arctan2(ying_center_3[1] - ying_center_estimate[1], ying_center_3[0] - ying_center_estimate[0])
    inner_ying_1_center = np.mean([
        ying_center_estimate + ying_inner_radius * np.array([np.cos(ying1_angle), np.sin(ying1_angle)]),
        ying_center_estimate + ying_inner_radius * np.array([np.cos(ying2_angle+np.pi), np.sin(ying2_angle+np.pi)])
    ], axis=0)
    inner_ying_2_center = np.mean([
        ying_center_estimate + ying_inner_radius * np.array([np.cos(ying1_angle+np.pi), np.sin(ying1_angle+np.pi)]),
        ying_center_estimate + ying_inner_radius * np.array([np.cos(ying2_angle), np.sin(ying2_angle)])
    ], axis=0)
    ying1_angle = np.arctan2(inner_ying_1_center[1] - ying_center_estimate[1], inner_ying_1_center[0] - ying_center_estimate[0])
    ying2_angle = np.arctan2(inner_ying_2_center[1] - ying_center_estimate[1], inner_ying_2_center[0] - ying_center_estimate[0])

    # corner point
    corner = np.mean([
        ying_center_estimate + ying_outer_radius * np.array([np.cos(ying1_angle+np.pi), np.sin(ying1_angle+np.pi)]),
        ying_center_estimate + ying_outer_radius * np.array([np.cos(ying2_angle), np.sin(ying2_angle)]),
        ying_1[0],
    ], axis=0)

    # ying props
    ying_props = {
        'center': ying_center_estimate,
        'inner_radius': ying_inner_radius,
        'outer_radius': ying_outer_radius,
        'inner_circle_1_center': inner_ying_1_center,
        'inner_circle_2_center': inner_ying_2_center,
        'corner': corner
    }
    return ying_props

def draw_y(y_props, show_basics=True, bg=p1_bg, axis=None, color='r'):

    # draw the estimated circles
    if axis is None:
        fig, ax = plt.subplots()
    else:
        ax = axis
    
    ax.imshow(bg, cmap='gray')

    if show_basics:
        ax.scatter(y_props['inner_circle_1_center'][0], y_props['inner_circle_1_center'][1], color='g')
        ax.scatter(y_props['inner_circle_2_center'][0], y_props['inner_circle_2_center'][1], color='g')
        ax.scatter(y_props['center'][0], y_props['center'][1], color='g')
        ax.scatter(y_props['corner'][0], y_props['corner'][1], color='r')
        circ_inner = plt.Circle(y_props['inner_circle_1_center'], y_props['inner_radius'], color='g', fill=False)
        circ_outer = plt.Circle(y_props['inner_circle_2_center'], y_props['inner_radius'], color='g', fill=False)
        circ = plt.Circle(y_props['center'], y_props['outer_radius'], color='g', fill=False)
        ax.add_artist(circ_inner)
        ax.add_artist(circ_outer)
        ax.add_artist(circ)

    # draw the corner arc
    corner1_angle = np.arctan2(y_props['corner'][1] - y_props['center'][1], y_props['corner'][0] - y_props['center'][0])
    corner2_angle = np.arctan2(y_props['corner'][1] - y_props['center'][1], y_props['corner'][0] - y_props['center'][0]) + np.pi
    x, y = generate_arc(y_props['center'], y_props['outer_radius'], corner1_angle, corner2_angle, 100)
    ax.plot(x, y, color='r')

    # draw the inner circles
    x, y = generate_arc(y_props['inner_circle_1_center'], y_props['inner_radius'], corner1_angle+np.pi, corner2_angle+np.pi, 100)
    ax.plot(x, y, color='r')
    x, y = generate_arc(y_props['inner_circle_2_center'], y_props['inner_radius'], corner1_angle, corner2_angle, 100)
    ax.plot(x, y, color='r')

    if axis is None:
        ax.set_aspect('equal')
        plt.title('Estimated circles')
        plt.show()

def y_coordinates_1(pos,y_props,distance_resolution=200):

    corner1_angle = np.arctan2(y_props['corner'][1] - y_props['center'][1], y_props['corner'][0] - y_props['center'][0])
    corner2_angle = np.arctan2(y_props['corner'][1] - y_props['center'][1], y_props['corner'][0] - y_props['center'][0]) + np.pi
    
    # get distance from each arc
    dist_outer = np.linalg.norm(pos - y_props['center']) - y_props['outer_radius']
    pos_outer = (np.arctan2(pos[1] - y_props['center'][1], pos[0] - y_props['center'][0])) % (2*np.pi) - (corner1_angle % (2*np.pi))
    pos_outer = pos_outer % (2*np.pi) # make sure it's between 0 and 2pi
    if pos_outer > np.pi:
        dist_outer = np.inf

    dist_inner_1 = np.linalg.norm(pos - y_props['inner_circle_1_center']) - y_props['inner_radius']
    pos_inner_1 = (np.arctan2(pos[1] - y_props['inner_circle_1_center'][1], pos[0] - y_props['inner_circle_1_center'][0]))% (2*np.pi) - (corner2_angle % (2*np.pi))
    pos_inner_1 = pos_inner_1 % (2*np.pi) # make sure it's between 0 and 2pi
    if pos_inner_1 > np.pi:
        dist_inner_1 = np.inf

    dist_inner_2 = np.linalg.norm(pos - y_props['inner_circle_2_center']) - y_props['inner_radius']
    pos_inner_2 = (corner2_angle)%(2*np.pi) - np.arctan2(pos[1] - y_props['inner_circle_2_center'][1], pos[0] - y_props['inner_circle_2_center'][0])%(2*np.pi)
    pos_inner_2 = pos_inner_2 % (2*np.pi) # make sure it's between 0 and 2pi
    if pos_inner_2 > np.pi:
        dist_inner_2 = np.inf

    argdist = np.argmin([np.abs(dist_outer), np.abs(dist_inner_1), np.abs(dist_inner_2)])
    if argdist == 0:
        dist = dist_outer
        position = pos_outer
    elif argdist == 1:
        dist = dist_inner_1
        position = pos_inner_1/2 + np.pi
    else:
        dist = dist_inner_2
        position = pos_inner_2/2 + 1.5*np.pi
    if dist == np.inf:
        position = np.nan
    
    return position, dist

def y_coordinates_2(pos,y_props,distance_resolution=200):

    corner1_angle = np.arctan2(y_props['corner'][1] - y_props['center'][1], y_props['corner'][0] - y_props['center'][0])
    corner2_angle = np.arctan2(y_props['corner'][1] - y_props['center'][1], y_props['corner'][0] - y_props['center'][0]) + np.pi
    
    # get distance from each arc
    # sample distance_resolution points on the arc
    arc_points = np.linspace(corner1_angle, corner2_angle, distance_resolution)
    arc_points = np.array([y_props['center'][0] + y_props['outer_radius']*np.cos(arc_points), y_props['center'][1] + y_props['outer_radius']*np.sin(arc_points)]).T
    dist_outer = np.min(np.linalg.norm(arc_points - pos, axis=1))

    arc_points = np.linspace(corner2_angle, corner2_angle+np.pi, distance_resolution)
    arc_points = np.array([y_props['inner_circle_1_center'][0] + y_props['inner_radius']*np.cos(arc_points), y_props['inner_circle_1_center'][1] + y_props['inner_radius']*np.sin(arc_points)]).T
    dist_inner_1 = np.min(np.linalg.norm(arc_points - pos, axis=1))

    arc_points = np.linspace(corner1_angle, corner1_angle+np.pi, distance_resolution)
    arc_points = np.array([y_props['inner_circle_2_center'][0] + y_props['inner_radius']*np.cos(arc_points), y_props['inner_circle_2_center'][1] + y_props['inner_radius']*np.sin(arc_points)]).T
    dist_inner_2 = np.min(np.linalg.norm(arc_points - pos, axis=1))

    argdist = np.argmin([np.abs(dist_outer), np.abs(dist_inner_1), np.abs(dist_inner_2)])

    if argdist == 0:
        dist = dist_outer
        position = (np.arctan2(pos[1] - y_props['center'][1], pos[0] - y_props['center'][0])% (2*np.pi) - (corner1_angle % (2*np.pi))) % (2*np.pi)
    elif argdist == 1:
        dist = dist_inner_1
        position = ((np.arctan2(pos[1] - y_props['inner_circle_1_center'][1], pos[0] - y_props['inner_circle_1_center'][0])% (2*np.pi) - (corner2_angle % (2*np.pi))) % (2*np.pi)) / 2 + np.pi
    else:
        dist = dist_inner_2
        position = (((corner1_angle + np.pi) % (2*np.pi) - np.arctan2(pos[1] - y_props['inner_circle_2_center'][1], pos[0] - y_props['inner_circle_2_center'][0])%(2*np.pi)) % (2*np.pi)) / 2 + 1.5*np.pi
    
    # determine the sign of the distance
    if np.linalg.norm(pos - y_props['inner_circle_1_center']) < y_props['inner_radius']:
        # MUST BE INSIDE
        dist = dist
    elif np.linalg.norm(pos - y_props['inner_circle_2_center']) < y_props['inner_radius']:
        # MUST BE OUTSIDE
        dist = -dist
    else:
        if np.linalg.norm(pos - y_props['center']) > y_props['outer_radius']:
            # MUST BE OUTSIDE
            dist = -dist
        else:
            angle = (np.arctan2(pos[1] - y_props['center'][1], pos[0] - y_props['center'][0]) - corner1_angle) % (2*np.pi)
            # if the point is within the corner1 to corner2 arc, then it's inside
            if angle > 0 and angle < np.pi:
                # MUST BE INSIDE
                dist = dist
            else:
                dist = -dist
            
    if dist == np.inf:
        position = np.nan
    
    return position, dist

def y_coordinates(pos, y_props):
    corner1_angle = np.arctan2(y_props['corner'][1] - y_props['center'][1], y_props['corner'][0] - y_props['center'][0])
    corner2_angle = np.arctan2(y_props['corner'][1] - y_props['center'][1], y_props['corner'][0] - y_props['center'][0]) + np.pi
    
    # outer arc
    dist_outer, angle_outer = nearest_arc_distance(pos[0], pos[1], y_props['center'], y_props['outer_radius'], corner1_angle, corner2_angle)
    angle_outer = (angle_outer % (2*np.pi) - (corner1_angle % (2*np.pi))) % (2*np.pi)

    # inner arc 1
    dist_inner_1, angle_inner_1 = nearest_arc_distance(pos[0], pos[1], y_props['inner_circle_1_center'], y_props['inner_radius'], corner2_angle, corner2_angle+np.pi)
    angle_inner_1 = ((angle_inner_1 % (2*np.pi) - (corner2_angle % (2*np.pi))) % (2*np.pi)) / 2 + np.pi

    # inner arc 2
    dist_inner_2, angle_inner_2 = nearest_arc_distance(pos[0], pos[1], y_props['inner_circle_2_center'], y_props['inner_radius'], corner1_angle, corner1_angle+np.pi)
    angle_inner_2 = (((corner1_angle + np.pi) % (2*np.pi) - (angle_inner_2 % (2*np.pi))) % (2*np.pi)) / 2 + 1.5*np.pi

    # determine the smallest distance
    argdist = np.argmin([np.abs(dist_outer), np.abs(dist_inner_1), np.abs(dist_inner_2)])

    if argdist == 0:
        dist = dist_outer
        position = angle_outer
    elif argdist == 1:
        dist = dist_inner_1
        position = angle_inner_1
    else:
        dist = dist_inner_2
        position = angle_inner_2

    # determine the sign of the distance
    if np.linalg.norm(pos - y_props['inner_circle_1_center']) < y_props['inner_radius']:
        # MUST BE INSIDE
        dist = dist
    elif np.linalg.norm(pos - y_props['inner_circle_2_center']) < y_props['inner_radius']:
        # MUST BE OUTSIDE
        dist = -dist
    else:
        if np.linalg.norm(pos - y_props['center']) > y_props['outer_radius']:
            # MUST BE OUTSIDE
            dist = -dist
        else:
            angle = (np.arctan2(pos[1] - y_props['center'][1], pos[0] - y_props['center'][0]) - corner1_angle) % (2*np.pi)
            # if the point is within the corner1 to corner2 arc, then it's inside
            if angle > 0 and angle < np.pi:
                # MUST BE INSIDE
                dist = dist
            else:
                dist = -dist

    if dist == np.inf:
        position = np.nan

    return position.item(), dist.item()

In [21]:
recalculate = False

# check if previous ying props are available
try:
    if recalculate:
        raise FileNotFoundError
    
    with open('ying_props.json', 'r') as f:
        ying_props = json.load(f)
        ying_props = {k: np.array(v) if isinstance(v, list) else v for k, v in ying_props.items()}

except FileNotFoundError:
    # get ying points
    plt.figure()
    plt.imshow(p2_bg, cmap='gray')
    plt.title('Starting from the corner, select 15 points on the ying circle, 5 on each arc (outer, smooth inner, sharp inner)')
    ying_pts = plt.ginput(15)
    plt.close()

    ying_props = get_y(ying_pts)

    # save ying props as json
    with open('ying_props.json', 'w') as f:
        # convert numpy arrays to lists
        ying_props = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in ying_props.items()}
        json.dump(ying_props, f)

# check if previous yang props are available
try:
    if recalculate:
        raise FileNotFoundError
    
    with open('yang_props.json', 'r') as f:
        yang_props = json.load(f)
        yang_props = {k: np.array(v) if isinstance(v, list) else v for k, v in yang_props.items()}
        
except FileNotFoundError:
    # get yang points
    plt.figure()
    plt.imshow(p2_bg, cmap='gray')
    plt.title('Starting from the corner, select 15 points on the yang circle, 5 on each arc (outer, smooth inner, sharp inner)')
    yang_pts = plt.ginput(15)
    plt.close()

    yang_props = get_y(yang_pts)

    # save yang props as json
    with open('yang_props.json', 'w') as f:
        # convert numpy arrays to lists
        yang_props = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in yang_props.items()}
        json.dump(yang_props, f)

In [22]:
max_speed = 50 # mm/s 
max_displacement = max_speed * sf / FPS

### JUMP TEST CODE

In [23]:
# get x and y 
for i in range(N_FLIES):
    df = pd.read_csv(f'{prefix}_phase_1/{prefix}_phase_1-trackfeat.csv/fly{i+1}.csv')
    
    
    x = df['pos x'].values
    y = df['pos y'].values
    # interpolate missing values
    x = np.interp(np.arange(len(x)), np.arange(len(x))[~np.isnan(x)], x[~np.isnan(x)])
    y = np.interp(np.arange(len(y)), np.arange(len(y))[~np.isnan(y)], y[~np.isnan(y)])

    # get difference between consecutive frames
    dx = np.diff(x)
    dy = np.diff(y)
    d = np.sqrt(dx**2 + dy**2)

    plt.plot(x, y, 'k-', lw=0.2)
    # mark every point where d > max_displacement
    points = np.where(d > max_displacement)[0]
    # draw a line between consecutive points
    for i in points:
        plt.plot(x[i:i+2], y[i:i+2], 'r-', lw=1)
plt.gca().set_aspect('equal')
plt.show()

### TRACKLET CONVERSION

In [24]:
recalculate = False

# convert to tracklets
for phase in [1,2]:
    if os.path.exists(f'phase_{phase}_tracklets.csv') and not recalculate:
        continue
    
    tracklets = []
    for i in range(N_FLIES):
        df = pd.read_csv(f'{prefix}_phase_{phase}/{prefix}_phase_{phase}-trackfeat.csv/fly{i+1}.csv')
        x = df['pos x'].values
        y = df['pos y'].values
        ori = df['ori'].values
        frame = df.index.values

        # interpolate missing values
        x = np.interp(np.arange(len(x)), np.arange(len(x))[~np.isnan(x)], x[~np.isnan(x)])
        y = np.interp(np.arange(len(y)), np.arange(len(y))[~np.isnan(y)], y[~np.isnan(y)])
        ori = np.interp(np.arange(len(ori)), np.arange(len(ori))[~np.isnan(ori)], ori[~np.isnan(ori)])
        
        # get difference between consecutive frames
        d = np.sqrt(np.diff(x)**2 + np.diff(y)**2)
        # get points where d > max_displacement
        points = np.where(d > max_displacement)[0]
        start = 0
        for i in points:
            t = pd.DataFrame({'pos x': x[start:i+1], 'pos y': y[start:i+1], 'ori': ori[start:i+1], 'frame': frame[start:i+1]})
            tracklets.append(t)
            start = i+1
        t = pd.DataFrame({'pos x': x[start:], 'pos y': y[start:], 'ori': ori[start:], 'frame': frame[start:]})
        tracklets.append(t)

    # filter out tracklets with less than 5*FPS frames
    tracklets = [t for t in tracklets if len(t) >= 5*FPS]

    # add ying and yang coordinates to each tracklet
    i = 0
    for t in tqdm(tracklets):
        t['ying pos'], t['ying dist'] = zip(*[y_coordinates([x, y], ying_props) for x, y in zip(t['pos x'], t['pos y'])])
        t['yang pos'], t['yang dist'] = zip(*[y_coordinates([x, y], yang_props) for x, y in zip(t['pos x'], t['pos y'])])
        t['track id'] = i
        i += 1

    # combine and save tracklets
    tracklets = pd.concat(tracklets)
    tracklets.to_csv(f'phase_{phase}_tracklets.csv', index=False)

100%|██████████| 7/7 [06:12<00:00, 53.19s/it]
100%|██████████| 21/21 [08:20<00:00, 23.84s/it]


In [25]:
# plot all tracklets in phases
for phase in [1,2]:
    tracklets = pd.read_csv(f'phase_{phase}_tracklets.csv')
    fig, ax = plt.subplots()
    for i in range(tracklets['track id'].max()):
        t = tracklets[tracklets['track id'] == i]
        plt.plot(t['pos x'], t['pos y'], 'w-', lw=0.2)
    if phase == 1:
        draw_y(ying_props, show_basics=False, axis=ax, bg=p1_bg, color='r')
        draw_y(yang_props, show_basics=False, axis=ax, bg=p1_bg, color='b')
    else:
        draw_y(ying_props, show_basics=False, axis=ax, bg=p2_bg, color='r')
        draw_y(yang_props, show_basics=False, axis=ax, bg=p2_bg, color='b')
    plt.gca().set_aspect('equal')
    plt.title(f'Phase {phase}')
    # save figure
    plt.savefig(f'figures/phase_{phase}_tracklets.png')
    plt.show()

In [26]:
# plot all start and end points of tracklets
for phase in [1,2]:
    tracklets = pd.read_csv(f'phase_{phase}_tracklets.csv')
    fig, ax = plt.subplots(1, 2)
    if phase == 1:
        draw_y(ying_props, show_basics=False, axis=ax[0], bg=p1_bg, color='r')
        draw_y(yang_props, show_basics=False, axis=ax[0], bg=p1_bg, color='b')
        draw_y(ying_props, show_basics=False, axis=ax[1], bg=p1_bg, color='r')
        draw_y(yang_props, show_basics=False, axis=ax[1], bg=p1_bg, color='b')
    else:
        draw_y(ying_props, show_basics=False, axis=ax[0], bg=p2_bg, color='r')
        draw_y(yang_props, show_basics=False, axis=ax[0], bg=p2_bg, color='b')
        draw_y(ying_props, show_basics=False, axis=ax[1], bg=p2_bg, color='r')
        draw_y(yang_props, show_basics=False, axis=ax[1], bg=p2_bg, color='b')
    for i in range(tracklets['track id'].max()):
        t = tracklets[tracklets['track id'] == i]
        ax[0].scatter(t.iloc[0]['pos x'], t.iloc[0]['pos y'], color=plt.cm.rainbow(i/tracklets['track id'].max()), s=1)
        ax[1].scatter(t.iloc[-1]['pos x'], t.iloc[-1]['pos y'], color=plt.cm.rainbow(i/tracklets['track id'].max()), s=1)
    ax[0].set_title('Start points')
    ax[1].set_title('End points')
    ax[0].set_aspect('equal')
    ax[1].set_aspect('equal')
    plt.suptitle(f'Phase {phase}')
    plt.tight_layout()
    # save figure
    plt.savefig(f'figures/phase_{phase}_start_end.png')
    plt.show()


In [27]:
encounter_distance = 5*2 # mm body width
encounter_distance = encounter_distance * sf # scale to pixels
encounter_distance = encounter_distance / 2 # convert distance
print(f'Encounter distance: {encounter_distance} pixels = {encounter_distance/sf:.2f} mm = {encounter_distance/sf*0.0393701:.2f} inches')

ying_trail_width = 0.05 # inches
ying_trail_width = ying_trail_width / 0.0393701 # convert to mm
ying_trail_width = ying_trail_width * sf # scale to pixels
ying_trail_width = ying_trail_width / 2 # convert distance
print(f'Ying trail distance: {ying_trail_width} pixels = {ying_trail_width/sf:.2f} mm = {ying_trail_width/sf*0.0393701:.2f} inches')

yang_trail_width = 0.2 # inches
yang_trail_width = yang_trail_width / 0.0393701 # convert to mm
yang_trail_width = yang_trail_width * sf # scale to pixels
yang_trail_width = yang_trail_width / 2 # convert distance
print(f'Yang trail distance: {yang_trail_width} pixels = {yang_trail_width/sf:.2f} mm = {yang_trail_width/sf*0.0393701:.2f} inches')

Encounter distance: 46.68609528255587 pixels = 5.00 mm = 0.20 inches
Ying trail distance: 5.92913089915391 pixels = 0.63 mm = 0.03 inches
Yang trail distance: 23.71652359661564 pixels = 2.54 mm = 0.10 inches


In [28]:

fig, ax = plt.subplots()
plt.imshow(p2_bg, cmap='gray')

# create a meshgrid
res = 50
x = np.linspace(0, p2_bg.shape[1], res)
y = np.linspace(0, p2_bg.shape[0], res)
X, Y = np.meshgrid(x, y)
X = X.flatten()
Y = Y.flatten()
points = np.vstack([X, Y]).T

max_distance = 100

# check if point is within the ying circle
for point in tqdm(points):
    pos, dist = y_coordinates(point, ying_props)
    if abs(dist) < encounter_distance+ying_trail_width:
        # plt.scatter(point[0], point[1], color=plt.cm.cool(pos/(2*np.pi)), s=1)
        distance = (dist+max_distance)/(2*max_distance)
        plt.scatter(point[0], point[1], color=plt.cm.cool(distance), s=1)


# check if point is within the yang circle
for point in tqdm(points):
    pos, dist = y_coordinates(point, yang_props)
    if abs(dist) < encounter_distance+yang_trail_width:
        # plt.scatter(point[0], point[1], color=plt.cm.cool(pos/(2*np.pi)), s=1)
        distance = (dist+max_distance)/(2*max_distance)
        plt.scatter(point[0], point[1], color=plt.cm.cool(distance), s=1)

plt.title('Points within range of ying and yang')

# add colorbar
# sm = plt.cm.ScalarMappable(cmap=plt.cm.cool, norm=plt.Normalize(0, 2*np.pi))
sm = plt.cm.ScalarMappable(cmap=plt.cm.cool, norm=plt.Normalize(-800, 800))
sm.set_array([])
plt.colorbar(sm, ticks=[0, np.pi, 2*np.pi], label='Angle', orientation='horizontal', ax=ax)
plt.gca().set_aspect('equal')
plt.savefig('figures/ying_yang_targeted_points.png')
plt.show()



  0%|          | 0/2500 [00:00<?, ?it/s]

100%|██████████| 2500/2500 [00:04<00:00, 619.56it/s]
100%|██████████| 2500/2500 [00:04<00:00, 607.81it/s]


In [29]:
# plot the ying distance vs ying position
tracklets = pd.read_csv('phase_2_tracklets.csv')
tracklets = [t.reset_index(drop=True) for _,t in tracklets.groupby('track id')]
for i in range(len(tracklets)):
    t = tracklets[i]
    plt.plot(t['ying dist'], t['ying pos'], '-', lw=0.2)

plt.show()

In [86]:
N_CROSSINGS_FOR_ENCOUNTER = 3
# create a subfolder for N_CROSSINGS_FOR_ENCOUNTER
if not os.path.exists(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}'):
    os.makedirs(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}')
if not os.path.exists(f'videos/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}'):
    os.makedirs(f'videos/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}')
if not os.path.exists(f'Crossings_{N_CROSSINGS_FOR_ENCOUNTER}'):
    os.makedirs(f'Crossings_{N_CROSSINGS_FOR_ENCOUNTER}')

In [87]:
def get_crossings(encounter, ying_or_yang, min_deviation):
    crossings = np.abs(np.diff(np.sign(encounter[ying_or_yang+' dist'])))
    # distance traveled between crossings must be atleast the trail width
    encounters_between_crossings = []
    crossing_points = np.where(crossings > 0)[0]
    if len(crossing_points) == 0:
        return 0, []
    encounters_between_crossings.append(encounter[:crossing_points[0]])
    for i in range(len(crossing_points)-1):
        encounters_between_crossings.append(encounter[crossing_points[i]:crossing_points[i+1]])
    # get the maximum distance travelled away from the center
    furthest_distance = np.array([np.max(np.abs(encounter[ying_or_yang+' dist'])) for encounter in encounters_between_crossings])
    # get the crossing points where the distance travelled is atleast the trail width
    crossing_points = np.array(crossing_points)[np.where(furthest_distance > min_deviation)[0]]
    n_crossings = len(crossing_points)
    return n_crossings, crossing_points+1

# def get_crossings(encounter, ying_or_yang, min_deviation):
#     crossings_left = np.abs(np.diff(np.sign(encounter[ying_or_yang+' dist']+min_deviation)))
#     crossings_right = np.abs(np.diff(np.sign(encounter[ying_or_yang+' dist']-min_deviation)))
#     # any of the crossings must be true
#     crossings = np.logical_or(crossings_left, crossings_right)
#     # distance traveled between crossings must be atleast the trail width
#     encounters_between_crossings = []
#     crossing_points = np.where(crossings > 0)[0]
#     if len(crossing_points) == 0:
#         return 0, []
#     encounters_between_crossings.append(encounter[:crossing_points[0]])
#     for i in range(len(crossing_points)-1):
#         encounters_between_crossings.append(encounter[crossing_points[i]:crossing_points[i+1]])
#     # get the maximum distance travelled away from the center
#     furthest_distance = np.array([np.max(np.abs(encounter[ying_or_yang+' dist'])) for encounter in encounters_between_crossings])
#     # get the crossing points where the distance travelled is atleast the trail width
#     crossing_points = np.array(crossing_points)[np.where(furthest_distance > min_deviation)[0]]
#     n_crossings = len(crossing_points)
#     return n_crossings, crossing_points+1

In [88]:
# # test the function
# tracklets = pd.read_csv('phase_2_tracklets.csv')
# encounter = tracklets[tracklets['track id'] == 15].reset_index(drop=True)
# n_crossings, crossing_points = get_crossings(encounter, 'yang', ying_trail_width)
# fig, ax = plt.subplots()
# draw_y(ying_props, show_basics=False, axis=ax, bg=p2_bg, color='r')
# draw_y(yang_props, show_basics=False, axis=ax, bg=p2_bg, color='b')
# ax.plot(encounter['pos x'], encounter['pos y'], 'k-', lw=0.2)
# # color by ying distance
# ax.scatter(encounter['pos x'], encounter['pos y'], c=encounter['ying dist'], cmap='cool', s=0.5)
# ax.scatter(encounter['pos x'][crossing_points], encounter['pos y'][crossing_points], color='b', s=20)
# plt.gca().set_aspect('equal')
# plt.show()

In [90]:
for phase in [1,2]:
    # load tracklets and split by track id
    tracklets = pd.read_csv(f'phase_{phase}_tracklets.csv')
    tracklets = [group for _, group in tracklets.groupby('track id')]

    # get all ying encounters
    ying_encounters = []
    for tracklet in tracklets:
        val = np.abs(np.concatenate([[np.inf], tracklet['ying dist'].values])) < encounter_distance + ying_trail_width
        count = np.cumsum(np.concatenate([np.array([np.nan]),np.diff(val.astype(int))])>0)
        mes = val*count
        # get every continuous segment where the fly is within encounter_distance of the ying circle
        encounters = [group for m, group in tracklet.groupby(mes[1:]) if m != 0]
        # if the pos changes more than np.pi/4 in a single frame,split it
        for e in encounters:
            pos = e['ying pos'].values
            # add a nan to the start of the array
            pos = np.concatenate([np.array([np.nan]), pos])
            diff = np.abs(np.diff(pos))>np.pi/2
            split = np.where(diff)[0]
            if len(split) == 0:
                ying_encounters.append(e)
                continue
            start = 0
            for s in split:
                ying_encounters.append(e.iloc[start:s])
                start = s
            ying_encounters.append(e.iloc[start:])
    # filter for crossing
    n_crossings_ying = [get_crossings(encounter, 'ying', ying_trail_width)[0] for encounter in ying_encounters]
    mask = [n_crossings_ying[i] >= N_CROSSINGS_FOR_ENCOUNTER for i in range(len(n_crossings_ying))]
    ying_encounters = [ying_encounters[i] for i in range(len(ying_encounters)) if mask[i]]
    n_crossings_ying = [n_crossings_ying[i] for i in range(len(n_crossings_ying)) if mask[i]]

    # get all yang encounters
    yang_encounters = []
    for tracklet in tracklets:
        val = np.abs(np.concatenate([[np.inf], tracklet['yang dist'].values])) < encounter_distance + yang_trail_width
        count = np.cumsum(np.concatenate([np.array([np.nan]),np.diff(val.astype(int))])>0)
        mes = val*count
        # get every continuous segment where the fly is within encounter_distance of the yang circle
        encounters = [group for m, group in tracklet.groupby(mes[1:]) if m != 0]
        # if the pos changes more than np.pi in a single frame,split it
        for e in encounters:
            pos = e['yang pos'].values
            # add a nan to the start of the array
            pos = np.concatenate([np.array([np.nan]), pos])
            diff = np.abs(np.diff(pos))>np.pi
            split = np.where(diff)[0]
            if len(split) == 0:
                yang_encounters.append(e)
                continue
            start = 0
            for s in split:
                yang_encounters.append(e.iloc[start:s])
                start = s
            yang_encounters.append(e.iloc[start:])
    # filter for crossing
    n_crossings_yang = [get_crossings(encounter, 'yang', yang_trail_width)[0] for encounter in yang_encounters]
    mask = [n_crossings_yang[i] >= N_CROSSINGS_FOR_ENCOUNTER for i in range(len(n_crossings_yang))]
    yang_encounters = [yang_encounters[i] for i in range(len(yang_encounters)) if mask[i]]
    n_crossings_yang = [n_crossings_yang[i] for i in range(len(n_crossings_yang)) if mask[i]]

    # plot all ying and yang encounters
    plt.figure()
    plt.imshow(eval(f"p{phase}_bg"), cmap='gray')
    for encounter in ying_encounters:
        plt.plot(encounter['pos x'], encounter['pos y'], '-', lw=2, color=plt.cm.winter(np.random.rand()))
    for encounter in yang_encounters:
        plt.plot(encounter['pos x'], encounter['pos y'], '-', lw=2, color=plt.cm.autumn(np.random.rand()))
    plt.title(f"Phase {phase} | YING: {len(ying_encounters)} encounters, YANG: {len(yang_encounters)} encounters")
    plt.savefig(f"figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/ying_yang_encounters_phase_{phase}.png")
    plt.show()

    # plot in ying spcae
    fig, ax = plt.subplots(1, 2, figsize=(5, 5), sharey=True)

    max_distances_ying = []
    displacements_ying = []
    for encounter in ying_encounters:
        position = (encounter['ying pos']*ying_props['outer_radius']/sf).values
        # find point of first crossing
        crossings = get_crossings(encounter, 'ying', ying_trail_width)[1]
        crossing = crossings[0]
        # crossing = np.argmax(np.abs(np.diff(np.sign(encounter['ying dist'].values))))
        position = position - position[crossing]
        dists = [np.abs(np.max(position[crossing:])),np.abs(np.min(position[crossing:]))]
        if np.argmax(dists) == 1:
            position = -position
        distance = encounter['ying dist']/sf
        max_distances_ying.append(np.max(dists))
        displacements_ying.append(np.sum(np.abs(np.diff(position[crossing:]))))
        ax[0].plot(distance, position, '-', lw=2, color=plt.cm.winter(np.random.rand()))
    # plot the trail
    ax[0].axvline(-ying_trail_width/sf, color='r', lw=1)
    ax[0].axvline(ying_trail_width/sf, color='r', lw=1)

    # sort by max distance
    ying_encounters = [ying_encounters[i].reset_index() for i in np.argsort(max_distances_ying)]
    n_crossings_ying = [n_crossings_ying[i] for i in np.argsort(max_distances_ying)]
    displacements_ying = [displacements_ying[i] for i in np.argsort(max_distances_ying)]
    max_distances_ying = np.sort(max_distances_ying)

    # # sort by number of crossings
    # ying_encounters = [ying_encounters[i].reset_index() for i in np.argsort(n_crossings_ying)]
    # max_distances_ying = [max_distances_ying[i] for i in np.argsort(n_crossings_ying)]
    # displacements_ying = [displacements_ying[i] for i in np.argsort(n_crossings_ying)]
    # n_crossings_ying = np.sort(n_crossings_ying)

    # sort by displacement
    ying_encounters = [ying_encounters[i].reset_index() for i in np.argsort(displacements_ying)]
    max_distances_ying = [max_distances_ying[i] for i in np.argsort(displacements_ying)]
    n_crossings_ying = [n_crossings_ying[i] for i in np.argsort(displacements_ying)]
    displacements_ying = np.sort(displacements_ying)

    # save ying encounters after adding an id
    for i, encounter in enumerate(ying_encounters):
        encounter['id'] = i
        encounter['max_distance'] = max_distances_ying[i]
        encounter['n_crossings'] = n_crossings_ying[i]
        encounter['displacement'] = displacements_ying[i]
    if len(ying_encounters) > 0:
        ying_encounters_CSV = pd.concat(ying_encounters)
        ying_encounters_CSV.to_csv(f'Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/ying_encounters_phase_{phase}.csv', index=False)
    # ying_encounters_CSV.to_csv(f'ying_encounters_phase_{phase}.csv', index=False)


    max_distances_yang = []
    displacements_yang = []
    for encounter in yang_encounters:
        position = (encounter['yang pos']*yang_props['outer_radius']/sf).values
        # find point of first crossing
        # crossing = np.argmax(np.abs(np.diff(np.sign(encounter['yang dist'].values))))
        crossings = get_crossings(encounter, 'yang', yang_trail_width)[1]
        crossing = crossings[0]
        position = position - position[crossing]
        dists = [np.abs(np.max(position[crossing:])),np.abs(np.min(position[crossing:]))]
        if np.argmax(dists) == 1:
            position = -position
        distance = encounter['yang dist']/sf
        max_distances_yang.append(np.max(dists))
        displacements_yang.append(np.sum(np.abs(np.diff(position[crossing:]))))
        ax[1].plot(distance, position, '-', lw=2, color=plt.cm.autumn(np.random.rand()))
    # plot the trail
    ax[1].axvline(-yang_trail_width/sf, color='b', lw=1)
    ax[1].axvline(yang_trail_width/sf, color='b', lw=1)

    # sort by max distance
    yang_encounters = [yang_encounters[i].reset_index() for i in np.argsort(max_distances_yang)]
    n_crossings_yang = [n_crossings_yang[i] for i in np.argsort(max_distances_yang)]
    displacements_yang = [displacements_yang[i] for i in np.argsort(max_distances_yang)]
    max_distances_yang = np.sort(max_distances_yang)

    # # sort by number of crossings
    # yang_encounters = [yang_encounters[i].reset_index() for i in np.argsort(n_crossings_yang)]
    # max_distances_yang = [max_distances_yang[i] for i in np.argsort(n_crossings_yang)]
    # displacements_yang = [displacements_yang[i] for i in np.argsort(n_crossings_yang)]
    # n_crossing_yang = np.sort(n_crossings_yang)

    # sort by displacement
    yang_encounters = [yang_encounters[i].reset_index() for i in np.argsort(displacements_yang)]
    max_distances_yang = [max_distances_yang[i] for i in np.argsort(displacements_yang)]
    n_crossings_yang = [n_crossings_yang[i] for i in np.argsort(displacements_yang)]
    displacements_yang = np.sort(displacements_yang)

    # save yang encounters after adding an id
    for i, encounter in enumerate(yang_encounters):
        encounter['id'] = i
        encounter['max_distance'] = max_distances_yang[i]
        encounter['n_crossings'] = n_crossings_yang[i]
        encounter['displacement'] = displacements_yang[i]
    if len(yang_encounters) > 0:
        yang_encounters_CSV = pd.concat(yang_encounters)
        yang_encounters_CSV.to_csv(f'Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/yang_encounters_phase_{phase}.csv', index=False)

    ax[0].set_title('Ying encounters')
    ax[1].set_title('Yang encounters')
    ax[0].set_xlabel('Distance from ying circle (mm)')
    ax[1].set_xlabel('Distance from yang circle (mm)')
    ax[0].set_ylabel('Position on trail')
    ax[1].set_ylabel('Position on trail')
    ax[0].set_aspect('equal')
    ax[1].set_aspect('equal')
    plt.suptitle(f"Phase {phase}")
    plt.tight_layout()
    plt.savefig(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/ying_yang_encounters_space_phase_{phase}.png')
    plt.show()

    # histogram of ying and yang encounters
    if len(ying_encounters) == 0 and len(yang_encounters) == 0:
        continue
    
    bins = np.linspace(np.min([np.min(max_distances_ying), np.min(max_distances_yang)]), np.max([np.max(max_distances_ying), np.max(max_distances_yang)]), 20)
    plt.figure()
    plt.hist(max_distances_ying, bins=bins, color='r', linewidth=2, label='Ying', density=True, histtype='step')
    plt.hist(max_distances_yang, bins=bins, color='b', linewidth=2, label='Yang', density=True, histtype='step')
    plt.xlabel('Max distance from trail')
    plt.ylabel('Number of encounters')
    plt.yscale('log')
    plt.legend()
    plt.title(f'Ying and Yang encounters phase {phase} (max_distance)')
    plt.savefig(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/ying_yang_encounters_histogram_phase_{phase}.png')
    plt.show()

    # plot top 10 ying and yang encounters
    # YING ENCOUNTERS
    for i in range(10):
        # plot:
        # (1) the trajectory in space
        # (2) the trajectory in ying/yang space

        # get the encounter
        if i >= len(ying_encounters):
            break
        encounter = ying_encounters[-i]
        # get the id
        id = encounter['id'].values[0]
        # get the max distance
        max_distance = encounter['max_distance'].values[0]
        # get the number of crossings
        n_crossings = encounter['n_crossings'].values[0]
        # get the displacement
        displacement = encounter['displacement'].values[0]
        # get the crossing points
        crossings = get_crossings(encounter, 'ying', ying_trail_width)[1]

        fig, ax = plt.subplots(1, 2, figsize=(5, 5))
        # plot the trajectory in space
        ax[0].imshow(eval(f"p{phase}_bg"), cmap='gray')
        ax[0].plot(encounter['pos x'], encounter['pos y'], 'k-', lw=2)
        ax[0].scatter(encounter['pos x'][crossings], encounter['pos y'][crossings], color='r', s=20)
        plt.suptitle(f'Ying encounter {id} | Max distance: {max_distance:.2f} mm | Crossings: {n_crossings} | Displacement: {displacement:.2f} mm')
        ax[0].set_aspect('equal')
        # plot the trajectory in ying space
        position = (encounter['ying pos']*ying_props['outer_radius']/sf).values
        # find point of first crossing
        crossing = crossings[0]
        position = position - position[crossing]
        dists = [np.abs(np.max(position[crossing:])),np.abs(np.min(position[crossing:]))]
        if np.argmax(dists) == 1:
            position = -position
        distance = encounter['ying dist']/sf
        ax[1].plot(distance, position, '-', lw=2, color='r')
        ax[1].axvline(-ying_trail_width/sf, color='r', lw=1)
        ax[1].axvline(ying_trail_width/sf, color='r', lw=1)
        ax[1].set_title('Ying space')
        ax[1].set_aspect('equal')
        plt.tight_layout()
        plt.savefig(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/ying_encounter_{id}_phase_{phase}.png')
        plt.close()

    # YANG ENCOUNTERS
    for i in range(10):
        # plot:
        # (1) the trajectory in space
        # (2) the trajectory in ying/yang space

        # get the encounter
        if i >= len(yang_encounters):
            break
        encounter = yang_encounters[-i]
        # get the max distance
        max_distance = encounter['max_distance'].values[0]
        # get the number of crossings
        n_crossings = encounter['n_crossings'].values[0]
        # get the displacement
        displacement = encounter['displacement'].values[0]
        # get the crossing points
        crossings = get_crossings(encounter, 'yang', yang_trail_width)[1]

        fig, ax = plt.subplots(1, 2, figsize=(9, 5))
        # plot the trajectory in space
        ax[0].imshow(eval(f"p{phase}_bg"), cmap='gray')
        ax[0].plot(encounter['pos x'], encounter['pos y'], 'k-', lw=2)
        ax[0].scatter(encounter['pos x'][crossings], encounter['pos y'][crossings], color='b', s=20)
        plt.suptitle(f'Yang encounter {id} | Max distance: {max_distance:.2f} mm | Crossings: {n_crossings} | Displacement: {displacement:.2f} mm')
        ax[0].set_aspect('equal')
        # plot the trajectory in ying space
        position = (encounter['yang pos']*yang_props['outer_radius']/sf).values
        # find point of first crossing
        crossing = crossings[0]
        position = position - position[crossing]
        dists = [np.abs(np.max(position[crossing:])),np.abs(np.min(position[crossing:]))]
        if np.argmax(dists) == 1:
            position = -position
        distance = encounter['yang dist']/sf
        ax[1].plot(distance, position, '-', lw=2, color='b')
        ax[1].axvline(-yang_trail_width/sf, color='b', lw=1)
        ax[1].axvline(yang_trail_width/sf, color='b', lw=1)
        ax[1].set_title('Yang space')
        ax[1].set_aspect('equal')
        plt.tight_layout()
        plt.savefig(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/yang_encounter_{i+1}_phase_{phase}.png', dpi=300)
        plt.close()

    


In [95]:
tracklets = pd.read_csv('phase_2_tracklets.csv')
tracklets = [group.reset_index(drop=True) for _, group in tracklets.groupby('track id')]

In [96]:
start_times = np.array([tracklet['frame'].iloc[0] for tracklet in tracklets])
end_times = np.array([tracklet['frame'].iloc[-1] for tracklet in tracklets])
start_posx = np.array([tracklet['pos x'].iloc[0] for tracklet in tracklets])
start_posy = np.array([tracklet['pos y'].iloc[0] for tracklet in tracklets])
end_posx = np.array([tracklet['pos x'].iloc[-1] for tracklet in tracklets])
end_posy = np.array([tracklet['pos y'].iloc[-1] for tracklet in tracklets])
assigned = np.array([False for _ in range(len(tracklets))])

# distance matrices

time_distance = np.zeros((len(tracklets), len(tracklets)))*np.nan
for i in range(len(tracklets)):
    for j in range(len(tracklets)):
        time_distance[i, j] = end_times[j] - start_times[i]

pos_distance = np.zeros((len(tracklets), len(tracklets)))*np.nan
for i in range(len(tracklets)):
    for j in range(len(tracklets)):
        pos_distance[i, j] = np.sqrt((end_posx[j] - start_posx[i])**2 + (end_posy[j] - start_posy[i])**2)

min_frame = np.min(start_times)
assert min_frame == 0, f"Minimum frame is not 0, but {min_frame}"
max_frame = np.max(end_times)
tracks = np.zeros((max_frame-min_frame+1, N_FLIES))*np.nan

latest_tracks = np.ones(N_FLIES)*np.nan

# first runs
for t in tqdm(range(min_frame, max_frame+1)):

    # get all tracklets that are active (start time <= t <= end time and not assigned)
    active = np.where((start_times <= t) & (end_times >= t) & ~assigned)[0]

    # make sure there are less active tracklets than flies
    assert len(active) <= N_FLIES, f"More active tracklets than flies at frame {t}"
    
    # if there are no active tracklets, CONTINUE
    if len(active) == 0:
        continue
    
    # ELSE, if all latest tracks are NaN, assign each active tracklet to a fly
    if np.all(np.isnan(latest_tracks)):
        for n,i in enumerate(active):
            tracks[start_times[i]:end_times[i]+1, n] = i
            assigned[i] = True
            latest_tracks[n] = i
        continue
    
    premature_break = False
    # ELSE, for each active tracklet
    for i in active:

        if premature_break:
            break

        # find all places where the tracklet could be joined (i.e. nan in tracks)
        joinable = np.where(np.isnan(tracks[t]))[0]
        # if there are no joinable places, alert the user and STOP
        assert len(joinable) > 0, f"No joinable places at frame {t}"

        # get the most recent tracklet for each fly
        latest = latest_tracks[joinable]

        # get the non_nan values for each fly
        non_nan_indices = np.where(~np.isnan(latest))[0]
        non_nan_latest = np.int32(latest[non_nan_indices])
        
        # if there are no non_nan values, assign the tracklet to the first nan value
        if len(non_nan_latest) == 0:
            n = joinable[0]
            tracks[start_times[i]:end_times[i]+1, n] = i
            assigned[i] = True
            latest_tracks[n] = i
            continue

        # get how far each non_nan value is from the current tracklet
        distances = pos_distance[i, non_nan_latest]
        time = -time_distance[i, non_nan_latest]
        # get the speed of each non_nan value
        speeds = distances/time/sf # mm/s

        # get values that are within the speed limit
        valid_indices = np.where(speeds < max_speed)[0]
        if len(valid_indices) == 0:
            # check if there are any nan values in the latest tracks
            nan_indices = np.where(np.isnan(latest))[0]
            if len(nan_indices) == 0:
                print(f"No valid indices at frame {t} for tracklet {i}")
                premature_break = True
                break
            # if there are nan values, assign the tracklet to the first nan value
            n = np.where(latest_tracks == latest[nan_indices[0]])[0][0]
            tracks[start_times[i]:end_times[i]+1, n] = i
            assigned[i] = True
            latest_tracks[n] = i
            continue

        # get the closest non_nan value
        speeds[speeds > max_speed] = np.inf
        closest = latest[non_nan_indices[np.argmin(speeds)]]

        # get the time and distance gaps
        time_gap = -time_distance[i, int(closest)]/FPS
        distance_gap = pos_distance[i, int(closest)]/sf
        if time_gap >1 or distance_gap > 10:
            print(f"Time gap: {time_gap}, Distance gap: {distance_gap}")

        # assign the tracklet to the closest non_nan value
        n = np.where(latest_tracks == closest)[0][0]
        tracks[start_times[i]:end_times[i]+1, n] = i
        assigned[i] = True
        latest_tracks[n] = i

# check if there are any unassigned tracklets
if np.any(~assigned):
    print(f"Unassigned tracklets: {np.where(~assigned)[0]}")
else:
    print("All tracklets assigned")

 37%|███▋      | 42503/115201 [00:00<00:00, 213180.76it/s]

Time gap: 5.5, Distance gap: 80.28942375098319
Time gap: 1.9375, Distance gap: 27.28666527830715
Time gap: 3.84375, Distance gap: 76.18952746084855
Time gap: 0.875, Distance gap: 59.08651430336132
Time gap: 5.25, Distance gap: 19.566242395198632
Time gap: 4.65625, Distance gap: 22.865077211614327
Time gap: 0.96875, Distance gap: 102.37779824361336
Time gap: 4.84375, Distance gap: 52.5392457199027
Time gap: 4.46875, Distance gap: 16.43264205813007
Time gap: 8.71875, Distance gap: 17.09280218371134
Time gap: 3.1875, Distance gap: 52.86529618020839
Time gap: 6.59375, Distance gap: 0.07582782224717532
Time gap: 3.03125, Distance gap: 82.9821435422147
Time gap: 1.53125, Distance gap: 89.56049858433423
Time gap: 0.875, Distance gap: 11.049070537357506
Time gap: 3.4375, Distance gap: 81.87934633251638
Time gap: 1.90625, Distance gap: 46.438877637442644
Time gap: 3.21875, Distance gap: 0.07486629541625944
Time gap: 6.3125, Distance gap: 60.00741166186203
Time gap: 8.53125, Distance gap: 5.2175

 74%|███████▍  | 85469/115201 [00:00<00:00, 214417.46it/s]

Time gap: 4.28125, Distance gap: 26.525123366244614
Time gap: 2.59375, Distance gap: 33.33608844794951
Time gap: 3.09375, Distance gap: 7.015893107613179
Time gap: 3.28125, Distance gap: 47.53092112617884
Time gap: 2.53125, Distance gap: 7.480478578079525
Time gap: 6.0625, Distance gap: 60.43973362863081
Time gap: 2.25, Distance gap: 43.959167930333514
Time gap: 1.84375, Distance gap: 15.58610206247375
Time gap: 5.25, Distance gap: 65.46467034198737
Time gap: 12.65625, Distance gap: 18.48614181390524
Time gap: 0.875, Distance gap: 15.37299034255566
Time gap: 1.25, Distance gap: 14.439691081942115
Time gap: 1.3125, Distance gap: 10.403546814817433
Time gap: 8.0625, Distance gap: 54.82858858227791
Time gap: 7.71875, Distance gap: 76.63598587454234
Time gap: 3.9375, Distance gap: 42.64416829071219
Time gap: 6.53125, Distance gap: 130.4657253614307
Time gap: 1.125, Distance gap: 112.9972139762681
Time gap: 4.84375, Distance gap: 50.61216865724134
Time gap: 2.125, Distance gap: 56.495405103

100%|██████████| 115201/115201 [00:00<00:00, 213163.44it/s]

Time gap: 2.625, Distance gap: 31.68110277926416
Time gap: 4.09375, Distance gap: 59.0627077073906
Time gap: 5.96875, Distance gap: 121.60405562624202
Time gap: 0.78125, Distance gap: 17.363571927995817
Time gap: 1.65625, Distance gap: 17.56553490932373
Time gap: 4.1875, Distance gap: 56.39806296223027
Time gap: 0.34375, Distance gap: 17.06729001538418
Time gap: 7.625, Distance gap: 48.64828676095287
Time gap: 1.0625, Distance gap: 38.62172867784684
Time gap: 0.9375, Distance gap: 17.69150086493752
Time gap: 3.1875, Distance gap: 4.4195978499089295
Time gap: 6.3125, Distance gap: 37.189716467452314
Time gap: 3.21875, Distance gap: 47.44300419663685
Time gap: 4.625, Distance gap: 67.96157985921141
Time gap: 3.75, Distance gap: 29.18350753244227
Time gap: 5.84375, Distance gap: 77.90478187682908
Time gap: 3.125, Distance gap: 17.496056416972362
Time gap: 4.6875, Distance gap: 59.21598309562258
Time gap: 0.375, Distance gap: 34.107136976342495
Time gap: 1.28125, Distance gap: 12.979118540




In [97]:
# reconstruct the tracks
for i in range(N_FLIES):
    # get all tracklets that are assigned to fly i
    assigned_tracklet_ids = np.unique(tracks[:, i])
    print(f"Fly {i+1} has {len(assigned_tracklet_ids)} tracklets assigned to it")
    # get all the tracklets
    assigned_tracklets = [tracklets[int(tracklet_id)] for tracklet_id in assigned_tracklet_ids if not np.isnan(tracklet_id)]
    # assign frame as index
    for j in range(len(assigned_tracklets)):
        assigned_tracklets[j] = assigned_tracklets[j].set_index('frame')
    # concatenate all tracklets
    assigned_tracklets = pd.concat(assigned_tracklets)
    # add unrepresented indices
    all_indices = np.arange(min_frame, max_frame+1)
    missing_indices = np.setdiff1d(all_indices, assigned_tracklets.index)
    missing_tracklets = pd.DataFrame(index=missing_indices)
    for col in assigned_tracklets.columns:
        missing_tracklets[col] = np.nan
    assigned_tracklets = pd.concat([assigned_tracklets, missing_tracklets])
    assigned_tracklets = assigned_tracklets.sort_index().reset_index()
    assigned_tracklets = assigned_tracklets.rename(columns={'index': 'frame'})
    # interpolate missing values
    # for col in assigned_tracklets.columns:
        # assigned_tracklets[col] = assigned_tracklets[col].interpolate()

Fly 1 has 144 tracklets assigned to it
Fly 2 has 118 tracklets assigned to it
Fly 3 has 66 tracklets assigned to it
Fly 4 has 79 tracklets assigned to it
Fly 5 has 131 tracklets assigned to it
Fly 6 has 141 tracklets assigned to it
Fly 7 has 120 tracklets assigned to it


In [99]:
plt.plot(assigned_tracklets['pos x'], assigned_tracklets['pos y'], 'k-', lw=0.2)

[<matplotlib.lines.Line2D at 0x3246add20>]

In [33]:
plt.plot(tracks)
plt.show()

In [93]:
start_times = np.array([tracklet['frame'].iloc[0] for tracklet in tracklets])
end_times = np.array([tracklet['frame'].iloc[-1] for tracklet in tracklets])
start_posx = np.array([tracklet['pos x'].iloc[0] for tracklet in tracklets])
start_posy = np.array([tracklet['pos y'].iloc[0] for tracklet in tracklets])
end_posx = np.array([tracklet['pos x'].iloc[-1] for tracklet in tracklets])
end_posy = np.array([tracklet['pos y'].iloc[-1] for tracklet in tracklets])

# distance matrices
time_distance = np.zeros((len(tracklets), len(tracklets)))*np.nan
for i in range(len(tracklets)):
    for j in range(len(tracklets)):
        time_distance[i, j] = end_times[i] - start_times[j]
time_distance[time_distance < 0] = np.nan

pos_distance = np.zeros((len(tracklets), len(tracklets)))*np.nan
for i in range(len(tracklets)):
    for j in range(len(tracklets)):
        pos_distance[i, j] = np.sqrt((end_posx[i] - start_posx[j])**2 + (end_posy[i] - start_posy[j])**2)

In [97]:
len(tracklets)

792

In [99]:
np.sum(time_distance<5*FPS)

612

In [101]:
pos_distance[time_distance<5*FPS]

array([6.05317292e+02, 3.19368618e+02, 1.02722405e+03, 4.73700598e+02,
       7.86927636e+02, 4.60518188e+02, 3.76165850e+02, 6.21164763e+02,
       4.76709892e+02, 1.10258786e+00, 3.39048883e+02, 1.61910342e+02,
       8.10802787e+02, 3.57931316e+02, 5.21477825e+02, 6.27023114e+02,
       2.53812254e+02, 6.23404135e+02, 3.50498641e+02, 3.17542074e+02,
       7.84386628e+02, 2.38041551e+02, 3.38005917e+02, 8.85659105e+01,
       3.08642613e+02, 9.50905684e+02, 6.43738495e+02, 6.13596006e+02,
       5.15295664e+02, 6.64139677e+02, 3.41678211e+00, 6.73941852e+02,
       7.29711615e+02, 3.47845512e+02, 2.18742475e+02, 5.91889137e+02,
       7.92043632e+02, 7.35968318e+02, 2.99670087e+02, 6.47359086e+02,
       2.79934953e+02, 8.54159180e+02, 1.07813725e+03, 3.20127020e+02,
       1.05706387e+03, 9.83000960e+02, 6.30001105e+02, 9.13824306e+02,
       3.72323784e+00, 2.95764670e+02, 1.51080785e+02, 3.24626847e+02,
       3.35887533e+02, 3.98664258e+02, 5.14397891e+02, 3.39900015e+02,
      

In [86]:
n_flies_found

[5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 5,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 7,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,
 6,


In [70]:
np.sum(start_times<=100)

1

In [61]:
n_flies_found

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,


In [120]:
# phase 1 data
for i in range(7):
    df = pd.read_csv(f'{prefix}_phase_2/{prefix}_phase_2-trackfeat.csv/fly{i+1}.csv')
    plt.plot(df['pos x'], df['pos y'], '-', lw=0.2)
plt.gca().set_aspect('equal')
plt.title('Phase 1')
plt.show()

In [87]:
df.head()

Unnamed: 0,pos x,pos y,ori,major axis len,minor axis len,body area,fg area,img contrast,min fg dist,wing l x,...,leg 6 ang,vel,ang_vel,min_wing_ang,max_wing_ang,mean_wing_length,axis_ratio,fg_body_ratio,contrast,dist_to_wall
0,63.192,874.3,2.3751,21.945,12.038,198.0,210.0,0.42624,507.41,,...,,1.0637,3.5629,,,,1.8229,1.0606,0.42624,3.9332
1,63.348,874.15,2.2637,20.373,11.71,178.0,187.0,0.50142,511.24,,...,,1.0637,3.1116,,,,1.8185,1.0547,0.4882,3.9332
2,63.246,873.41,2.3187,21.277,10.793,175.0,185.0,0.52373,514.05,,...,,1.6116,2.0619,,,,1.8442,1.0586,0.52383,3.9332
3,63.043,873.18,2.2821,20.581,12.146,187.0,200.0,0.54644,511.95,,...,,1.0986,1.0491,,,,1.8071,1.0617,0.5351,3.9666
4,62.989,873.22,2.2853,20.964,11.223,178.0,187.0,0.52379,508.41,,...,,0.14608,0.34918,,,,1.7449,1.0574,0.50905,3.9666


In [121]:
max_displacement = 30 # mm/s 
max_displacement = max_displacement * sf / FPS

In [122]:
# get x and y 
x = df['pos x'].values
y = df['pos y'].values
# interpolate missing values
x = np.interp(np.arange(len(x)), np.arange(len(x))[~np.isnan(x)], x[~np.isnan(x)])
y = np.interp(np.arange(len(y)), np.arange(len(y))[~np.isnan(y)], y[~np.isnan(y)])

# get difference between consecutive frames
dx = np.diff(x)
dy = np.diff(y)
d = np.sqrt(dx**2 + dy**2)

plt.plot(x, y, 'k-', lw=0.2)
# mark every point where d > max_displacement
points = np.where(d > max_displacement)[0]
# draw a line between consecutive points
for i in points:
    plt.plot(x[i:i+2], y[i:i+2], 'r-', lw=0.2)
plt.gca().set_aspect('equal')

In [223]:
# convert to tracklets
tracklets = []
for i in range(7):
    df = pd.read_csv(f'{prefix}_phase_1/{prefix}_phase_1-trackfeat.csv/fly{i+1}.csv')
    x = df['pos x'].values
    y = df['pos y'].values
    ori = df['ori'].values

    # interpolate missing values
    x = np.interp(np.arange(len(x)), np.arange(len(x))[~np.isnan(x)], x[~np.isnan(x)])
    y = np.interp(np.arange(len(y)), np.arange(len(y))[~np.isnan(y)], y[~np.isnan(y)])
    ori = np.interp(np.arange(len(ori)), np.arange(len(ori))[~np.isnan(ori)], ori[~np.isnan(ori)])
    
    # get difference between consecutive frames
    d = np.sqrt(np.diff(x)**2 + np.diff(y)**2)
    # get points where d > max_displacement
    points = np.where(d > max_displacement)[0]
    start = 0
    for i in points:
        t = pd.DataFrame({'pos x': x[start:i+1], 'pos y': y[start:i+1], 'ori': ori[start:i+1]})
        tracklets.append(t)
        start = i+1
    t = pd.DataFrame({'pos x': x[start:], 'pos y': y[start:], 'ori': ori[start:]})
    tracklets.append(t)

# filter out tracklets with less than 5*FPS frames
tracklets = [t for t in tracklets if len(t) >= 5*FPS]

# add ying and yang coordinates to each tracklet
i = 0
for t in tqdm(tracklets):
    t['ying pos'], t['ying dist'] = zip(*[y_coordinates([x, y], ying_props) for x, y in zip(t['pos x'], t['pos y'])])
    t['yang pos'], t['yang dist'] = zip(*[y_coordinates([x, y], yang_props) for x, y in zip(t['pos x'], t['pos y'])])
    t['track id'] = i
    i += 1

# combine and save tracklets
tracklets = pd.concat(tracklets)
tracklets.to_csv('phase_1_tracklets.csv', index=False)

# convert to tracklets
tracklets = []
for i in range(7):
    df = pd.read_csv(f'{prefix}_phase_2/{prefix}_phase_2-trackfeat.csv/fly{i+1}.csv')
    x = df['pos x'].values
    y = df['pos y'].values
    ori = df['ori'].values

    # interpolate missing values
    x = np.interp(np.arange(len(x)), np.arange(len(x))[~np.isnan(x)], x[~np.isnan(x)])
    y = np.interp(np.arange(len(y)), np.arange(len(y))[~np.isnan(y)], y[~np.isnan(y)])
    ori = np.interp(np.arange(len(ori)), np.arange(len(ori))[~np.isnan(ori)], ori[~np.isnan(ori)])
    
    # get difference between consecutive frames
    d = np.sqrt(np.diff(x)**2 + np.diff(y)**2)
    # get points where d > max_displacement
    points = np.where(d > max_displacement)[0]
    start = 0
    for i in points:
        t = pd.DataFrame({'pos x': x[start:i+1], 'pos y': y[start:i+1], 'ori': ori[start:i+1]})
        tracklets.append(t)
        start = i+1
    t = pd.DataFrame({'pos x': x[start:], 'pos y': y[start:], 'ori': ori[start:]})
    tracklets.append(t)

# filter out tracklets with less than 5*FPS frames
tracklets = [t for t in tracklets if len(t) >= 5*FPS]

# add ying and yang coordinates to each tracklet
i = 0
for t in tqdm(tracklets):
    t['ying pos'], t['ying dist'] = zip(*[y_coordinates([x, y], ying_props) for x, y in zip(t['pos x'], t['pos y'])])
    t['yang pos'], t['yang dist'] = zip(*[y_coordinates([x, y], yang_props) for x, y in zip(t['pos x'], t['pos y'])])
    t['track id'] = i
    i += 1

# combine and save tracklets
tracklets = pd.concat(tracklets)
tracklets.to_csv('phase_2_tracklets.csv', index=False)



100%|██████████| 435/435 [00:22<00:00, 19.34it/s]
100%|██████████| 792/792 [00:53<00:00, 14.78it/s]


In [424]:
for phase in [1,2]:
    # load tracklets and split by track id
    tracklets = pd.read_csv(f'phase_{phase}_tracklets.csv')
    tracklets = [group for _, group in tracklets.groupby('track id')]

    encounter_distance = 0.5 * 2 * 1 * sf # 1/2 of 2 * 1mm body widths converted to pixels

    ying_trail_width = 0.5 * 1.5 * sf # 1/2 of 3 mm trail converted to pixels
    yang_trail_width = 0.5 * 5 * sf # 1/2 of 5 mm trail converted to pixels

    print('Encounter distance:', encounter_distance)
    # get all ying encounters
    ying_encounters = []
    for tracklet in tracklets:
        val = np.abs(np.concatenate([[np.inf], tracklet['ying dist'].values])) < encounter_distance + ying_trail_width
        count = np.cumsum(np.concatenate([np.array([np.nan]),np.diff(val.astype(int))])>0)
        mes = val*count
        # get every continuous segment where the fly is within encounter_distance of the ying circle
        encounters = [group for m, group in tracklet.groupby(mes[1:]) if m != 0]
        # if the pos changes more than np.pi in a single frame,split it
        for e in encounters:
            pos = e['ying pos'].values
            # add a nan to the start of the array
            pos = np.concatenate([np.array([np.nan]), pos])
            diff = np.abs(np.diff(pos))>np.pi
            split = np.where(diff)[0]
            if len(split) == 0:
                ying_encounters.append(e)
                continue
            start = 0
            for s in split:
                ying_encounters.append(e.iloc[start:s])
                start = s
            ying_encounters.append(e.iloc[start:])
    # filter for atleast 1 sign change
    ying_encounters = [encounter for encounter in ying_encounters if np.abs(np.diff(np.sign(encounter['ying dist']))).sum() > 1]

    # get all yang encounters
    yang_encounters = []
    for tracklet in tracklets:
        val = np.abs(np.concatenate([[np.inf], tracklet['yang dist'].values])) < encounter_distance + yang_trail_width
        count = np.cumsum(np.concatenate([np.array([np.nan]),np.diff(val.astype(int))])>0)
        mes = val*count
        # get every continuous segment where the fly is within encounter_distance of the yang circle
        encounters = [group for m, group in tracklet.groupby(mes[1:]) if m != 0]
        # if the pos changes more than np.pi in a single frame,split it
        for e in encounters:
            pos = e['yang pos'].values
            # add a nan to the start of the array
            pos = np.concatenate([np.array([np.nan]), pos])
            diff = np.abs(np.diff(pos))>np.pi
            split = np.where(diff)[0]
            if len(split) == 0:
                yang_encounters.append(e)
                continue
            start = 0
            for s in split:
                yang_encounters.append(e.iloc[start:s])
                start = s
            yang_encounters.append(e.iloc[start:])
    # filter for atleast 1 sign change
    yang_encounters = [encounter for encounter in yang_encounters if np.abs(np.diff(np.sign(encounter['yang dist']))).sum() > 1]

    # plot all ying and yang encounters
    plt.figure()
    plt.imshow(p2_bg, cmap='gray')
    for encounter in ying_encounters:
        plt.plot(encounter['pos x'], encounter['pos y'], 'r-', lw=0.5)
    for encounter in yang_encounters:
        plt.plot(encounter['pos x'], encounter['pos y'], 'b-', lw=0.5)
    plt.title(f"Phase {phase} | YING: {len(ying_encounters)} encounters, YANG: {len(yang_encounters)} encounters")
    plt.savefig("figures/" + f"ying_yang_encounters_phase_{phase}.png")
    plt.show()

    # plot in ying spcae
    fig, ax = plt.subplots(1, 2, figsize=(5, 5), sharey=True)

    max_distances_ying = []
    for encounter in ying_encounters:
        position = (encounter['ying pos']*ying_props['outer_radius']/sf).values
        # find point of first crossing
        crossing = np.argmax(np.abs(np.diff(np.sign(encounter['ying dist'].values))))
        position = position - position[crossing]
        dists = [np.abs(np.max(position[crossing:])),np.abs(np.min(position[crossing:]))]
        if np.argmax(dists) == 1:
            position = -position
        distance = encounter['ying dist']/sf
        max_distances_ying.append(np.max(dists))
        ax[0].plot(distance, position, 'r-', lw=1, alpha=0.5)

    # sort by max distance
    ying_encounters = [ying_encounters[i].reset_index(drop=True) for i in np.argsort(max_distances_ying)]
    max_distances_ying = np.sort(max_distances_ying)

    # save ying encounters after adding an id
    for i, encounter in enumerate(ying_encounters):
        encounter['id'] = i
        encounter['max_distance'] = max_distances_ying[i]
    ying_encounters_CSV = pd.concat(ying_encounters)
    ying_encounters_CSV.to_csv(f'ying_encounters_phase_{phase}.csv', index=False)


    max_distances_yang = []
    for encounter in yang_encounters:
        position = (encounter['yang pos']*yang_props['outer_radius']/sf).values
        # find point of first crossing
        crossing = np.argmax(np.abs(np.diff(np.sign(encounter['yang dist'].values))))
        position = position - position[crossing]
        dists = [np.abs(np.max(position[crossing:])),np.abs(np.min(position[crossing:]))]
        if np.argmax(dists) == 1:
            position = -position
        distance = encounter['yang dist']/sf
        max_distances_yang.append(np.max(dists))
        ax[1].plot(distance, position, 'b-', lw=1, alpha=0.5)

    # sort by max distance
    yang_encounters = [yang_encounters[i].reset_index(drop=True) for i in np.argsort(max_distances_yang)]
    max_distances_yang = np.sort(max_distances_yang)

    # save yang encounters after adding an id
    for i, encounter in enumerate(yang_encounters):
        encounter['id'] = i
        encounter['max_distance'] = max_distances_yang[i]
    yang_encounters_CSV = pd.concat(yang_encounters)
    yang_encounters_CSV.to_csv(f'yang_encounters_phase_{phase}.csv', index=False)

    ax[0].set_title('Ying encounters')
    ax[1].set_title('Yang encounters')
    ax[0].set_xlabel('Distance from ying circle (mm)')
    ax[1].set_xlabel('Distance from yang circle (mm)')
    ax[0].set_ylabel('Position on trail')
    ax[1].set_ylabel('Position on trail')
    ax[0].set_aspect('equal')
    ax[1].set_aspect('equal')
    plt.suptitle(f"Phase {phase}")
    plt.tight_layout()
    plt.savefig("figures/" + f"ying_yang_encounters_space_phase_{phase}.png")
    plt.show()

    # histogram of ying and yang encounters
    bins = np.linspace(np.min([np.min(max_distances_ying), np.min(max_distances_yang)]), np.max([np.max(max_distances_ying), np.max(max_distances_yang)]), 20)
    plt.figure()
    plt.hist(max_distances_ying, bins=bins, color='r', linewidth=2, label='Ying', density=True, histtype='step')
    plt.hist(max_distances_yang, bins=bins, color='b', linewidth=2, label='Yang', density=True, histtype='step')
    plt.xlabel('Max distance from trail')
    plt.ylabel('Number of encounters')
    plt.yscale('log')
    plt.legend()
    plt.title(f'Ying and Yang encounters phase {phase}')
    plt.savefig("figures/" + f"ying_yang_encounters_histogram_phase_{phase}.png")
    plt.show()

    # plot top 30
    plt.figure()
    plt.imshow(p2_bg, cmap='gray')
    for encounter in ying_encounters[-30:]:
        plt.plot(encounter['pos x'], encounter['pos y'], '-', lw=1, color = plt.cm.cool(np.random.rand()))
    for encounter in yang_encounters[-30:]:
        plt.plot(encounter['pos x'], encounter['pos y'], '-', lw=1, color = plt.cm.cool(np.random.rand()))
    plt.title(f"Phase {phase}")
    plt.savefig("figures/" + f"ying_yang_encounters_top30_phase_{phase}.png")
    plt.show()




Encounter distance: 9.324972178273994
Encounter distance: 9.324972178273994


In [479]:
# get phase 1 and 2 encounters
ying_encounters_phase_1 = pd.read_csv('ying_encounters_phase_1.csv')
ying_encounters_phase_2 = pd.read_csv('ying_encounters_phase_2.csv')
yang_encounters_phase_1 = pd.read_csv('yang_encounters_phase_1.csv')
yang_encounters_phase_2 = pd.read_csv('yang_encounters_phase_2.csv')

# split by id
ying_encounters_phase_1 = [group for _, group in ying_encounters_phase_1.groupby('id')]
ying_encounters_phase_2 = [group for _, group in ying_encounters_phase_2.groupby('id')]
yang_encounters_phase_1 = [group for _, group in yang_encounters_phase_1.groupby('id')]
yang_encounters_phase_2 = [group for _, group in yang_encounters_phase_2.groupby('id')]




In [452]:
for i in range(50):
    ori = np.unwrap(ying_encounters_phase_2[len(ying_encounters_phase_2)-i-1]['ori'].values)
    vel = np.diff(ori)
    dist = ying_encounters_phase_2[len(ying_encounters_phase_2)-i-1]['ying dist'].values
    # remove close to zero
    mask = np.abs(vel) > np.pi/8
    dist = dist[1:][mask]
    vel = vel[mask]
    plt.scatter(dist, vel, color='r', s=1)



In [480]:
# create a video of top encounters
import skvideo.io
import os
from tqdm import tqdm

phase = 2
N = 15
# load background image in cv2
bg = cv2.cvtColor((eval(f"p{phase}_bg")*255).astype(np.uint8), cv2.COLOR_GRAY2RGB)

# create a video for each ying encounter
for i in range(N):
    # get x and y
    x = eval(f'ying_encounters_phase_{phase}[len(ying_encounters_phase_{phase})-{i}-1]["pos x"].values')
    y = eval(f'ying_encounters_phase_{phase}[len(ying_encounters_phase_{phase})-{i}-1]["pos y"].values')
    ori = eval(f'ying_encounters_phase_{phase}[len(ying_encounters_phase_{phase})-{i}-1]["ori"].values')
    writer = skvideo.io.FFmpegWriter('videos/ying_phase_{}_encounter_{}.mp4'.format(phase, i))
    for j in tqdm(range(len(x))):

        frame = bg.copy()
        cv2.circle(frame, (int(x[j]), int(y[j])), 5, (255,0,0), -1)
        # draw heading
        x2 = x[j] + 30*np.cos(ori[j])
        y2 = y[j] + 30*np.sin(ori[j])
        cv2.line(frame, (int(x[j]), int(y[j])), (int(x2), int(y2)), (0,255,0), 1)
        for k in range(j):
            cv2.circle(frame, (int(x[k]), int(y[k])), 1, (0,0,255), -1)

        writer.writeFrame(frame)
    writer.close()

for i in range(N):
    # get x and y
    x = eval(f'yang_encounters_phase_{phase}[len(yang_encounters_phase_{phase})-{i}-1]["pos x"].values')
    y = eval(f'yang_encounters_phase_{phase}[len(yang_encounters_phase_{phase})-{i}-1]["pos y"].values')
    ori = eval(f'yang_encounters_phase_{phase}[len(yang_encounters_phase_{phase})-{i}-1]["ori"].values')
    writer = skvideo.io.FFmpegWriter('videos/yang_phase_{}_encounter_{}.mp4'.format(phase, i))
    for j in tqdm(range(len(x))):

        frame = bg.copy()
        cv2.circle(frame, (int(x[j]), int(y[j])), 5, (255,0,0), -1)
        # draw heading
        x2 = x[j] + 30*np.cos(ori[j])
        y2 = y[j] + 30*np.sin(ori[j])
        cv2.line(frame, (int(x[j]), int(y[j])), (int(x2), int(y2)), (0,255,0), 1)
        for k in range(j):
            cv2.circle(frame, (int(x[k]), int(y[k])), 1, (0,0,255), -1)

        writer.writeFrame(frame)
    writer.close()

100%|██████████| 52/52 [00:00<00:00, 115.62it/s]
100%|██████████| 50/50 [00:00<00:00, 173.49it/s]
100%|██████████| 71/71 [00:00<00:00, 73.44it/s] 
100%|██████████| 70/70 [00:00<00:00, 111.55it/s]
100%|██████████| 35/35 [00:00<00:00, 147.66it/s]
100%|██████████| 37/37 [00:00<00:00, 150.03it/s]
100%|██████████| 95/95 [00:00<00:00, 112.46it/s]
100%|██████████| 33/33 [00:00<00:00, 150.31it/s]
100%|██████████| 34/34 [00:00<00:00, 156.49it/s]
100%|██████████| 47/47 [00:00<00:00, 156.10it/s]
100%|██████████| 31/31 [00:00<00:00, 155.15it/s]
100%|██████████| 46/46 [00:00<00:00, 150.49it/s]
100%|██████████| 44/44 [00:00<00:00, 129.10it/s]
100%|██████████| 45/45 [00:00<00:00, 169.14it/s]
100%|██████████| 60/60 [00:00<00:00, 105.68it/s]
100%|██████████| 102/102 [00:00<00:00, 120.27it/s]
100%|██████████| 74/74 [00:00<00:00, 119.51it/s]
100%|██████████| 284/284 [00:02<00:00, 102.78it/s]
100%|██████████| 99/99 [00:00<00:00, 117.67it/s]
100%|██████████| 69/69 [00:00<00:00, 120.83it/s]
100%|██████████|

In [468]:
plt.imshow(bg)

<matplotlib.image.AxesImage at 0x32f05e0e0>

In [212]:
# # draw the tracklet in original space
# plt.figure()
# plt.imshow(p2_bg, cmap='gray')

# for tracklet in tracklets[:100]:
#     # project the tracklet on the ying circle
#     ying_pos = []
#     ying_dist = []
#     for i in range(len(tracklet)):
#         pos = np.array([tracklet['pos x'].iloc[i], tracklet['pos y'].iloc[i]])
#         pos, dist = y_coordinates(pos, ying_props)
#         ying_pos.append(pos)
#         ying_dist.append(abs(dist))
#     # color the tracklet based on the position on the ying circle
#     colorline(tracklet['pos x'], tracklet['pos y'], ying_dist, cmap=plt.cm.cool, norm=plt.Normalize(0, 350), linewidth=0.5)

# plt.gca().set_aspect('equal')
# plt.title('Original tracklet')
# plt.show()


In [None]:
encounter_distance = (2 * 3 + 3/2) * sf # 2 * 3mm body widths + 3mm trail converted to pixels

# get all ying encounters
ying_encounters = []
for tracklet in tracklets:
    # project the tracklet on the ying circle using vectorized function
    ying_pos = []
    ying_dist = []
    for i in range(len(tracklet)):
        pos = np.array([tracklet['pos x'].iloc[i], tracklet['pos y'].iloc[i]])
        pos, dist = y_coordinates(pos, ying_props)
        ying_pos.append(pos)
        ying_dist.append(abs(dist))
    
    # check if the tracklet encounters the ying circle
    if np.any(np.array(ying_dist) < encounter_distance):
        ying_encounters.append(tracklet)


In [91]:
# plot all tracklets
plt.figure()
for t in tracklets:
    plt.plot(t[0], t[1], 'k-', lw=0.2)
plt.gca().set_aspect('equal')
plt.title('Tracklets')
plt.show()

In [93]:
plt.figure()
plt.plot(df['vel'])
x = df['pos x']
y = df['pos y']
x = np.interp(np.arange(len(x)), np.where(~np.isnan(x))[0], x[~np.isnan(x)])
y = np.interp(np.arange(len(y)), np.where(~np.isnan(y))[0], y[~np.isnan(y)])
v = np.sqrt(np.diff(x)**2 + np.diff(y)**2)
plt.plot(v)
plt.show()

In [136]:
# find all nans
nan_idx = np.where(np.isnan(df['pos x']))[0]
# find every start and end of a nan sequence
nan_idx = nan_idx[np.where(np.diff(nan_idx) != 0)[0]]
# plot everything outside the nans
plt.figure()
while len(nan_idx) > 0:
    plt.plot(df['pos x'].iloc[:nan_idx[0]], df['pos y'].iloc[:nan_idx[0]], '-', lw=0.2)
    df = df.iloc[nan_idx[0]+1:]
    nan_idx = nan_idx[1:]
plt.gca().set_aspect('equal')



In [144]:
x = np.array(df['pos x'])
y = np.array(df['pos y'])
nan_idx = np.where(np.isnan(x))[0]
# interpolate the nan values
x = np.interp(np.arange(len(x)), np.where(~np.isnan(x))[0], x[~np.isnan(x)])
y = np.interp(np.arange(len(y)), np.where(~np.isnan(y))[0], y[~np.isnan(y)])
# convert to mm
x = x/sf
y = y/sf
# get the velocity
vx = np.gradient(x, 1/FPS)
vy = np.gradient(y, 1/FPS)
v = np.sqrt(vx**2 + vy**2)
plt.figure()
plt.plot(x)
plt.plot(y)
plt.plot(v)
plt.show()

In [162]:
x = np.array(df['pos x'])
y = np.array(df['pos y'])
nan_idx = np.where(np.isnan(x))[0]
# interpolate the nan values
# x = np.interp(np.arange(len(x)), np.where(~np.isnan(x))[0], x[~np.isnan(x)])
# y = np.interp(np.arange(len(y)), np.where(~np.isnan(y))[0], y[~np.isnan(y)])
# convert to mm
x = x#/sf
y = y#/sf
# get the velocity
vx = np.gradient(x, 1/FPS)
vy = np.gradient(y, 1/FPS)
v = np.sqrt(vx**2 + vy**2)

mask = v < 0.5
plt.plot(x, y, 'r-', lw=1, alpha=0.5)
plt.scatter(df['pos x'], df['pos y'], c=df['vel'], cmap='viridis', s=1,zorder=10)
# plt.plot(x[mask], y[mask], '-', lw=0.2)
# plt.scatter(df['pos x'][mask], df['pos y'][mask], c=df['vel'][mask], cmap='viridis')
# plt.colorbar()
plt.gca().set_aspect('equal')


In [112]:
plt.plot(v)

[<matplotlib.lines.Line2D at 0x38b1220b0>]

In [81]:
plt.plot(df['vel'])
# df['vel'] = df['vel']/sf

[<matplotlib.lines.Line2D at 0x37cd2fbb0>]

In [82]:
mask = df['vel'] <10
plt.plot(df['pos x'][mask], df['pos y'][mask], 'r-', lw=0.2)
# plt.scatter(df['pos x'][mask], df['pos y'][mask], c=df['vel'][mask], cmap='viridis')
# plt.colorbar()

[<matplotlib.lines.Line2D at 0x37cdbaa40>]

In [7]:
# get tracklets
def get_tracklets(data, center, radius):
    n_flies = data.shape[1]//6
    tracklets = []
    for i in range(n_flies):
        # get track
        track = data.iloc[:, i*6:i*6+6].copy().reset_index()
        track.columns = ['Frame','ID','x','y','body_crossinglength','body_width','heading']

        # remove nan values
        track.replace(-1, np.nan, inplace=True)
        track.dropna(inplace=True)

        # skip if the track is too short (less than 10 frames)
        if track.shape[0] < 10:
            continue

        # Calculate other variables
        x_ = np.concatenate(([track['x'].iloc[0]], track['x'].values))
        y_ = np.concatenate(([track['y'].iloc[0]], track['y'].values))

        # kinematics
        track['velocity'] = np.sqrt(np.diff(x_)**2 + np.diff(y_)**2) * FPS
        track['motion_direction'] = np.arctan2(np.diff(y_), np.diff(x_))
        track['acceleration'] = np.diff(np.concatenate(([0], track['velocity'].values))) * FPS
        track['angular_velocity'] = np.diff(np.concatenate(([track['heading'].iloc[0]], track['heading'].values))) * FPS
        track['angular_acceleration'] = np.diff(np.concatenate(([0], track['angular_velocity'].values))) * FPS

        # odor trail related variables
        track['distance'] = np.sqrt((track['x'] - center[0])**2 + (track['y'] - center[1])**2) - radius
        track['angle'] = np.rad2deg(np.unwrap(np.arctan2(track['y'] - center[1], track['x'] - center[0])))
        track['angular_distance'] = np.deg2rad(track['angle']) * radius
        track['trail_heading'] = track['heading'] - np.arctan2(track['y'] - center[1], track['x'] - center[0])
        tracklets.append(track)
    return tracklets

def get_segments(tracklet):
    # segment the tracklet into trail and non-trail based on distance from the trail and atleast one crossing
    trail = (np.abs(tracklet['distance']) < sf*0.5).values # trail is within 0.5 inches of the trail
    # find every time the trail is crossed
    crossing = (tracklet['distance']<trail_width*sf).astype(int).values
    crossing = np.abs(np.diff(np.concatenate(([crossing[0]], crossing))))>0
    # find every continuous segment of trail where there is atleast one crossing
    segments = []
    start = 0
    end = 0
    for i in range(1, len(trail)):
        if trail[i] == trail[i-1]:
            continue
        if not trail[i-1]:
            start = i
        else:
            end = i
            if crossing[start:end].sum() > 0:
                segment = tracklet.iloc[start:end].copy().reset_index(drop=True)
                # find the position of the first crossing
                crossing_idx = np.argmax(crossing[start:end])
                # get the interpolated position of the crossing
                crossing_x = np.interp(0, [segment['distance'].iloc[crossing_idx-1], segment['distance'].iloc[crossing_idx]], [segment['x'].iloc[crossing_idx-1], segment['x'].iloc[crossing_idx]])
                crossing_y = np.interp(0, [segment['distance'].iloc[crossing_idx-1], segment['distance'].iloc[crossing_idx]], [segment['y'].iloc[crossing_idx-1], segment['y'].iloc[crossing_idx]])
                crossing_a = np.interp(0, [segment['distance'].iloc[crossing_idx-1], segment['distance'].iloc[crossing_idx]], [segment['angle'].iloc[crossing_idx-1], segment['angle'].iloc[crossing_idx]])
                crossing_ad = np.interp(0, [segment['distance'].iloc[crossing_idx-1], segment['distance'].iloc[crossing_idx]], [segment['angular_distance'].iloc[crossing_idx-1], segment['angular_distance'].iloc[crossing_idx]])
                segment['x_crossing'] = segment['x'] - crossing_x
                segment['y_crossing'] = segment['y'] - crossing_y
                segment['angle_crossing'] = segment['angle'] - crossing_a
                segment['angular_distance_crossing'] = segment['angular_distance'] - crossing_ad
                segment['time_from_crossing'] = (segment['Frame'] - segment['Frame'].iloc[crossing_idx])/FPS
                segment['number_of_crossings'] = crossing[start:end].sum()
                segment['encounter_distance'] = np.max(np.abs(segment['angular_distance_crossing'].values[segment['time_from_crossing'].values>0]))/sf
                segment['time'] = segment['Frame'].iloc[crossing_idx]/FPS
                segments.append(segment)
    return segments


In [5]:
csv_file = '20hr-wingless-orco/20hr-wingless-orco-phase_1.csv'
data = pd.read_csv(csv_file, header=None)
tracklets = get_tracklets(data, center, radius)

In [6]:
fig, ax = plt.subplots()
# plot the inverted image
ax.imshow(1-frame,cmap='gray_r')
# plot the tracklets
for track in tracklets:
    ax.plot(track['x'], track['y'], '-',alpha=0.5, color='black',linewidth=0.1)
# plot the odor trail
circle = plt.Circle(center, radius, color='r', fill=False, linewidth=2)
ax.add_artist(circle)
plt.show()
# calculate a 2d histogram of the x and y coordinates
n_bins = 100
# combine all x and y coordinates
x = np.concatenate([track['x'].values for track in tracklets])
y = np.concatenate([track['y'].values for track in tracklets])
H, xedges, yedges = np.histogram2d(x, y, bins=n_bins)
H = H.T
# log scale
H = np.log(H+1)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
fig, ax = plt.subplots()
im = ax.imshow(H, extent=extent, origin='lower', cmap='hot')
cbar = plt.colorbar(im, ax=ax)
cbar.set_ticks(np.log([1, 10, 100, 1000]))
cbar.set_ticklabels([1, 10, 100, 1000])
# plot the odor trail
circle = plt.Circle(center, radius, color='k', fill=False, linewidth=2,zorder=10)
ax.add_artist(circle)
# plot the tracklets
for track in tracklets:
    ax.plot(track['x'], track['y'], '-',alpha=0.5, color='black',linewidth=0.1)
plt.show()

In [7]:
# project the tracklets into the space of the odor trail
tracklet = tracklets[1]

# plot the tracklet (color by speed)
fig, ax = plt.subplots()
ax.scatter(tracklet['distance'], tracklet['angular_distance'], c=tracklet['distance'], cmap='RdBu', s=1)
ax.plot(tracklet['distance'], tracklet['angular_distance'], '-', color='black', alpha=1, linewidth=0.1)
# plot a line at 0
ax.axvline(0, color='r', linestyle='--')
ax.set_xlabel('Distance from the odor trail (pixels)')
ax.set_ylabel('Angle from the odor trail (degrees)')
# ax.set_aspect('equal')
plt.show()
# also plot in normal space
fig, ax = plt.subplots()
ax.scatter(tracklet['x'], tracklet['y'], c=tracklet['distance'], cmap='RdBu', s=1)
ax.plot(tracklet['x'], tracklet['y'], '-', color='black', alpha=1, linewidth=0.1)
circle = plt.Circle(center, radius, color='r', fill=False)
ax.add_artist(circle)
ax.set_xlabel('x (pixels)')
ax.set_ylabel('y (pixels)')
ax.set_aspect('equal')
plt.show()

In [11]:
# plot the segments
show_heading = False
filter_length = 2 # inches
fig, ax = plt.subplots()
all_segments = []
for tracklet in tracklets:
    segments = get_segments(tracklet)
    for segment in segments:

        if filter_length is not None and segment['encounter_distance'].iloc[0] < filter_length:
            continue
        mod = 1 if np.mean(segment['angular_distance_crossing'].values[segment['time_from_crossing'].values < 0]) < 0 else -1
        ax.scatter(segment['distance']/sf, mod*segment['angular_distance_crossing']/sf, c=segment['time_from_crossing'], cmap='viridis', s=5, zorder=1)
        ax.plot(segment['distance']/sf, mod*segment['angular_distance_crossing']/sf, '-', color='black', alpha=1, linewidth=1, zorder=0)
        # plot the trail heading
        for i in range(len(segment)):
            x = segment['distance'].iloc[i]/sf
            y = mod*segment['angular_distance_crossing'].iloc[i]/sf
            a = mod*segment['trail_heading'].iloc[i]
            if show_heading:
                ax.plot([x, x + 10*np.cos(a)/sf], [y, y + 10*np.sin(a)/sf], '-', color='k', linewidth=0.5)
    all_segments.extend(segments)
ax.axvline(0, color='r', linestyle='--')
ax.set_xlabel('Distance from the odor trail (inches)')
ax.set_ylabel('Arc Distance on the odor trail (inches)')
ax.set_aspect('equal')
plt.show()

In [12]:
def corr(x,y):
    r = np.corrcoef(x,y)[0,1]
    n = len(x)
    r_z = np.arctanh(r)
    se = 1/np.sqrt(n-3)
    z = 1.96
    lo_z, hi_z = r_z-z*se, r_z+z*se
    lo, hi = np.tanh((lo_z, hi_z))
    return r, lo, hi

In [13]:
# plot velocity during encounters
from scipy.stats import sem
from tqdm import tqdm
max_frames_before_encounter = np.max([np.sum(segment['time_from_crossing'].values < 0) for segment in all_segments])
max_frames_after_encounter = np.max([np.sum(segment['time_from_crossing'].values > 0) for segment in all_segments])
fig, ax = plt.subplots()
m,l,h = [],[],[]
for i in tqdm(range(-max_frames_before_encounter, max_frames_after_encounter)):
    vels = []
    # get the velocity of each fly at this time
    for segment in all_segments:
        vels.append(np.log(segment['velocity'].values[np.abs(segment['time_from_crossing'].values - i/FPS) < 1/FPS]/sf+1e-4))
    vels = np.concatenate(vels)
    # get velocity and 95% confidence interval
    m.append(np.mean(vels))
    l.append(np.mean(vels) - 1.96*sem(vels))
    h.append(np.mean(vels) + 1.96*sem(vels))
m,l,h = np.exp(m), np.exp(l), np.exp(h)
ax.plot(np.arange(-max_frames_before_encounter, max_frames_after_encounter)/FPS, m, color='k')
ax.fill_between(np.arange(-max_frames_before_encounter, max_frames_after_encounter)/FPS, l, h, color='gray', alpha=0.5)
ax.axvline(0, color='r', linestyle='--')
ax.set_xlabel('Time from encounter (s)')
ax.set_ylabel('Velocity (inches/s)')
plt.show()

  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  ret = ret.dtype.type(ret / rcount)
 25%|██▍       | 267/1075 [00:04<00:12, 65.31it/s]


KeyboardInterrupt: 

In [14]:
# encounter duration
x,y = [],[]
for segment in all_segments:
    encounter_distance = segment['encounter_distance'].iloc[0]
    color = 'lightgray' if encounter_distance < 2 else 'k'
    encounter_duration = np.log(segment['time_from_crossing'].max())
    prior_velocity = np.mean(segment['velocity'].values[segment['time_from_crossing'].values < 0])/sf
    x.append(prior_velocity)
    y.append(encounter_duration)
    plt.scatter(prior_velocity, encounter_duration, c=color, s=5)
# calculate the correlation along with the 95% confidence interval
r, lo, hi = corr(x,y)
plt.axhline(0, color='r', linestyle='--')
plt.title('r = {:.2f} ({:.2f}, {:.2f})'.format(r, lo, hi))
plt.xlabel('Prior velocity (inches/s)')
plt.ylabel('Log(Encounter duration)')
plt.show()
# encounter distance
x,y = [],[]
plt.figure()
for segment in all_segments:
    encounter_distance = segment['encounter_distance'].iloc[0]
    color = 'lightgray' if encounter_distance < 2 else 'k'
    prior_velocity = np.mean(segment['velocity'].values[segment['time_from_crossing'].values < 0])/sf
    x.append(prior_velocity)
    y.append(encounter_distance)
    plt.scatter(prior_velocity, encounter_distance, c=color, s=5)
# calculate the correlation along with the 95% confidence interval
r, lo, hi = corr(x,y)
plt.axhline(2, color='r', linestyle='--')
plt.title('r = {:.2f} ({:.2f}, {:.2f})'.format(r, lo, hi))
plt.xlabel('Prior velocity (inches/s)')
plt.ylabel('Encounter distance (inches)')
plt.show()

In [15]:
# plot a histogram of the encounter distance
plt.figure()
x = [segment['encounter_distance'].iloc[0] for segment in all_segments]
vals, bins = np.histogram(x, bins=50)
vals = vals/np.sum(vals)
plt.bar(bins[:-1], vals, width=np.diff(bins)[0])
plt.xlabel('Encounter distance (inches)')
plt.ylabel('Frequency')
# expected exponential distribution
from scipy.optimize import curve_fit
def exp(x, a, b):
    return a*np.exp(-b*x)
popt, pcov = curve_fit(exp, bins[:-1], vals)
plt.plot(bins, exp(bins, *popt), 'r--')
plt.show()

# plot encounter distance vs average velocity before crossing
plt.figure()
x,y = [],[]
for segment in all_segments:
    encounter_distance = segment['encounter_distance'].iloc[0]
    prior_velocity = np.mean(segment['velocity'].values[segment['time_from_crossing'].values < 0])/sf
    x.append(prior_velocity)
    y.append(encounter_distance)
plt.scatter(x, y, c='k', s=5)
# calculate the correlation along with the 95% confidence interval
r, lo, hi = corr(x,y)
plt.title('r = {:.2f} ({:.2f}, {:.2f})'.format(r, lo, hi))
plt.xlabel('Prior velocity (inches/s)')
plt.ylabel('Encounter distance (inches)')
plt.show()

# plot encounter distance vs average trail heading before crossing
def tortuosity(x, y):
    distance = np.sum(np.sqrt(np.diff(x)**2 + np.diff(y)**2))
    straight_distance = np.sqrt((x[0]-x[-1])**2 + (y[0]-y[-1])**2)
    return distance/straight_distance

plt.figure()
x,y,s = [],[],[]
for segment in all_segments:
    encounter_distance = segment['encounter_distance'].iloc[0]
    prior_heading = np.mean(segment['trail_heading'].values[segment['time_from_crossing'].values < 0])
    prior_heading = prior_heading % (2*np.pi)
    x.append(prior_heading)
    y.append(encounter_distance)
plt.scatter(x, y, c='k', cmap='viridis', s=1)
plt.xlim([0, 2*np.pi])
plt.xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi], ['0',r'$\frac{\pi}{2}$', r'$\pi$', r'$\frac{3\pi}{2}$', r'$2\pi$'])
plt.xlabel('Prior trail heading')
plt.ylabel('Encounter distance (inches)')
plt.show()


  plt.scatter(x, y, c='k', cmap='viridis', s=1)


In [16]:
# plot points in x,y where the encounter occurs
plt.figure()
plt.imshow(1-frame,cmap='gray_r')
for segment in all_segments:
    encounter_distance = np.max(np.abs(segment['angular_distance_crossing'].values[segment['time_from_crossing'].values>0]))/sf
    if encounter_distance < 2:
        color = 'gray'
    else:
        color = 'r'
    plt.scatter(segment['x'][segment['time_from_crossing'] == 0], segment['y'][segment['time_from_crossing'] == 0], c=color, s=5)
plt.show()

In [38]:
# plot points in x,y where the encounter occurs
plt.figure()
plt.imshow(1-frame,cmap='gray_r')
dists = []
xs = []
ys = []
for segment in all_segments:
    encounter_distance = np.max(np.abs(segment['angular_distance_crossing'].values[segment['time_from_crossing'].values>0]))/sf
    if encounter_distance < 2:
        continue
    xs.append(segment['x'][segment['time_from_crossing'] >= 0].values)
    ys.append(segment['y'][segment['time_from_crossing'] >= 0].values)
    dists.append(encounter_distance)
    plt.plot(segment['x'][segment['time_from_crossing'] >= 0], segment['y'][segment['time_from_crossing'] >= 0], '-', color=plt.cm.Reds(encounter_distance/8), linewidth=0.5)
    # mark the start of the encounter
    plt.scatter(segment['x'][segment['time_from_crossing'] == 0], segment['y'][segment['time_from_crossing'] == 0], c=plt.cm.Reds(encounter_distance/8), s=15)
plt.show()

  plt.scatter(segment['x'][segment['time_from_crossing'] == 0], segment['y'][segment['time_from_crossing'] == 0], c=plt.cm.Reds(encounter_distance/8), s=15)


In [40]:
# plot only top n encounters
N = 10
for i in np.argsort(dists)[-N:]:
    plt.figure()
    plt.imshow(1-frame,cmap='gray_r')
    x = xs[i]
    y = ys[i]
    plt.plot(x, y, '-', color='r', linewidth=0.5)
    # mark the start of the encounter
    plt.scatter(x[0], y[0], c='r', s=15)
    plt.title('D: {:.2f} T = {:d}:{:d}'.format(dists[i], int(all_segments[i]['time'].iloc[0]//60), int(all_segments[i]['time'].iloc[0]%60)))
    plt.show()

In [61]:
video_file = '20hr-wingless-orco/20hr-wingless-orco_phase_1.mp4'
csv_file = '20hr-wingless-orco/20hr-wingless-orco-phase_1.csv'

# load tracklets
data = pd.read_csv(csv_file, header=None)
tracklets = get_tracklets(data, center, radius)
all_segments = []
for tracklet in tracklets:
    segments = get_segments(tracklet)
    all_segments.extend(segments)
# sort the segments by encounter distance
all_segments = sorted(all_segments, key=lambda x: x['encounter_distance'].iloc[0], reverse=True)

# crossings
crossings = [segment['number_of_crossings'].iloc[0] for segment in all_segments]
plt.hist(crossings, bins=range(1,10))
plt.xlabel('Number of crossings')
plt.ylabel('Frequency')
plt.show()

# create video of encounters
import skvideo.io
import os
from tqdm import tqdm

N = 5
# get background image from the video
cap = cv2.VideoCapture(video_file)
for i in range(N):
    segment = all_segments[i]
    x = segment['x'].values
    y = segment['y'].values
    f = segment['Frame'].values
    # set start frame
    start_frame = f[0]
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    writer = skvideo.io.FFmpegWriter('true_encounter_{:d}.mp4'.format(i))
    for j in tqdm(range(len(x))):
        ret, frame = cap.read()
        if not ret:
            break
        # draw on image using opencv
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        cv2.circle(frame, (int(x[j]), int(y[j])), 5, (255,0,0), -1)
        for k in range(j):
            cv2.circle(frame, (int(x[k]), int(y[k])), 1, (0,0,255), -1)
        # draw the trail
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius+0.05*sf), (0,0,0), 1)
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius-0.05*sf), (0,0,0), 1)

        writer.writeFrame(frame)
    writer.close()
cap.release()

100%|██████████| 113/113 [00:01<00:00, 75.88it/s]
100%|██████████| 99/99 [00:01<00:00, 67.91it/s]
100%|██████████| 181/181 [00:03<00:00, 60.17it/s]
100%|██████████| 70/70 [00:01<00:00, 66.71it/s]
100%|██████████| 212/212 [00:03<00:00, 64.44it/s]


In [3]:
video_file = '20hr-wingless-orco/20hr-wingless-orco_phase_2.mp4'
csv_file = '20hr-wingless-orco/20hr-wingless-orco-phase_2.csv'

# load tracklets
data = pd.read_csv(csv_file, header=None)
tracklets = get_tracklets(data, center, radius)
all_segments = []
for tracklet in tracklets:
    segments = get_segments(tracklet)
    all_segments.extend(segments)
# sort the segments by encounter distance
all_segments = sorted(all_segments, key=lambda x: x['encounter_distance'].iloc[0], reverse=True)

# crossings
crossings = [segment['number_of_crossings'].iloc[0] for segment in all_segments]
plt.hist(crossings, bins=range(1,10))
plt.xlabel('Number of crossings')
plt.ylabel('Frequency')
plt.show()

# create video of encounters
import skvideo.io
import os
from tqdm import tqdm

N = 5
# get background image from the video
cap = cv2.VideoCapture(video_file)
for i in range(N):
    segment = all_segments[i]
    x = segment['x'].values
    y = segment['y'].values
    f = segment['Frame'].values
    # set start frame
    start_frame = f[0]
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    writer = skvideo.io.FFmpegWriter('false_encounter_{:d}.mp4'.format(i))
    for j in tqdm(range(len(x))):
        ret, frame = cap.read()
        if not ret:
            break
        # draw on image using opencv
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        cv2.circle(frame, (int(x[j]), int(y[j])), 5, (255,0,0), -1)
        for k in range(j):
            cv2.circle(frame, (int(x[k]), int(y[k])), 1, (0,0,255), -1)
        # draw the trail
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius+0.05*sf), (0,0,0), 1)
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius-0.05*sf), (0,0,0), 1)

        writer.writeFrame(frame)
    writer.close()
cap.release()

NameError: name 'get_tracklets' is not defined

In [None]:
times = []
distances = []
for segment in all_segments:
    times.append(segment['Frame'][segment['time_from_crossing'] == 0].values[0])
    distances.append(np.max(np.abs(segment['angular_distance_crossing']).values[segment['time_from_crossing'].values>0])/sf)
distances = np.array(distances)
times = np.array(times)
plt.figure()
plt.hist(times, bins=np.arange(0, np.max(times), 30*FPS), color='gray')
plt.hist(times[distances > 2], bins=np.arange(0, np.max(times), 30*FPS), color='red')
plt.xlabel('Frame number')
plt.ylabel('Frequency')
plt.show()

In [73]:
video_file = '20hr-wingless-orco/20hr-wingless-orco_phase_3.mp4'
csv_file = '20hr-wingless-orco/20hr-wingless-orco-phase_3.csv'

# load tracklets
data = pd.read_csv(csv_file, header=None)
tracklets = get_tracklets(data, center, radius)
all_segments = []
for tracklet in tracklets:
    segments = get_segments(tracklet)
    all_segments.extend(segments)
# sort the segments by encounter distance
all_segments = sorted(all_segments, key=lambda x: x['number_of_crossings'].iloc[0], reverse=True)

# create video of encounters
import skvideo.io
import os
from tqdm import tqdm

N = 5
# get background image from the video
cap = cv2.VideoCapture(video_file)
for i in range(N):
    segment = all_segments[i]
    x = segment['x'].values
    y = segment['y'].values
    f = segment['Frame'].values
    # set start frame
    start_frame = f[0]
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    writer = skvideo.io.FFmpegWriter('close_encounter_{:d}.mp4'.format(i))
    for j in tqdm(range(len(x))):
        ret, frame = cap.read()
        if not ret:
            break
        # draw on image using opencv
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        cv2.circle(frame, (int(x[j]), int(y[j])), 5, (255,0,0), -1)
        for k in range(j):
            cv2.circle(frame, (int(x[k]), int(y[k])), 1, (0,0,255), -1)
        # draw the trail
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius+0.05*sf), (0,0,0), 1)
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius-0.05*sf), (0,0,0), 1)

        writer.writeFrame(frame)
    writer.close()
cap.release()

100%|██████████| 248/248 [00:03<00:00, 64.99it/s]
100%|██████████| 158/158 [00:02<00:00, 63.86it/s]
100%|██████████| 68/68 [00:00<00:00, 81.57it/s] 
100%|██████████| 215/215 [00:03<00:00, 69.26it/s]
100%|██████████| 171/171 [00:02<00:00, 61.90it/s]


In [None]:
video_file = '20hr-wingless-orco/20hr-wingless-orco_phase_1.mp4'
csv_file = '20hr-wingless-orco/20hr-wingless-orco-phase_1.csv'

# load tracklets
data = pd.read_csv(csv_file, header=None)
tracklets = get_tracklets(data, center, radius)
all_segments = []
for tracklet in tracklets:
    segments = get_segments(tracklet)
    all_segments.extend(segments)
# sort the segments by encounter distance
all_segments = sorted(all_segments, key=lambda x: x['encounter_distance'].iloc[0], reverse=True)

In [47]:
# get every 6th column of data
data = pd.read_csv(file, header=None)

n_flies = 4
n_window = 5 # events
max_dist = 50 # pixel
max_time = 10*60*10 # 10 minutes * 60 seconds * 10 frames/second

# get tracklets
ncols = data.shape[1]
assert ncols % 6 == 0, 'Error: ncols is not a multiple of 6, check Ctrax output'
ntracks =   ncols // 6
tracklets = []
for i in range(ntracks):
    track = data.iloc[:, i*6:i*6+6].copy().reset_index()
    track.columns = ['Frame','ID','x','y','body_length','body_width','heading']
    track.replace(-1, np.nan, inplace=True)
    track.dropna(inplace=True)
    tracklets.append(track)

IDs = data.iloc[:, 0::6].values + 1
IDs = np.vstack([np.zeros(IDs.shape[1]), IDs]) # add a row of zeros at the beginning
births = np.diff(IDs, axis=0) > 0
deaths = np.diff(IDs, axis=0) < 0
Xs = data.iloc[:, 1::6].values
Ys = data.iloc[:, 2::6].values


In [50]:
def euclidean_distance(x1, y1, x2, y2):
    return np.sqrt((x1 - x2)**2 + (y1 - y2)**2)

clean_tracklets = []

i = -1
current_size = len(tracklets)

while len(tracklets) > n_flies:
    i += 1 % len(tracklets)
    # if we have gone through all the tracklets, break
    if i == 0:
        if current_size == len(tracklets):
            break
        current_size = len(tracklets)
    track = tracklets[i]
    while True:
        last_frame = track['Frame'].values[-1]
        # get the next few births
        next_births = np.where(births.sum(axis=1) > 0)[0]
        next_births = next_births[np.logical_and(next_births > last_frame, next_births < last_frame + max_time)][:n_window]
        print('Next births:', next_births)
        # if there are no more births, break
        if len(next_births) == 0:
            break
        # get tracklets that are born in the next few frames
        birth_ids = np.where(births[next_births, :])[1]
        # if multiple birth ids per frame, stop
        if len(birth_ids) > len(next_births):
            break
        print('Birth IDs:', birth_ids)
        distances = []
        for i, birth_id in enumerate(birth_ids):
            x = Xs[next_births[i], birth_id]
            y = Ys[next_births[i], birth_id]
            distance = euclidean_distance(x, y, track['x'].values[-1], track['y'].values[-1])
            distances.append(distance)
        # filter out the ones that are too far
        distances = np.array(distances)
        birth_ids = birth_ids[distances < max_dist]
        # interpolate  with the first tracklet
        X = 0
        success = False
        for X in range(len(birth_ids)):
            track2 = tracklets[birth_ids[X]]
            if len(track2) > 0:
                success = True
                break
        if not success:
            break
        n_frames = next_births[0] - last_frame
        if n_frames != 0:
            print('Interpolating', n_frames, 'frames')
            x = np.linspace(track['x'].values[-1], track2['x'].values[0], n_frames)
            y = np.linspace(track['y'].values[-1], track2['y'].values[0], n_frames)
            body_length = np.linspace(track['body_length'].values[-1], track2['body_length'].values[0], n_frames)
            body_width = np.linspace(track['body_width'].values[-1], track2['body_width'].values[0], n_frames)
            heading = np.linspace(track['heading'].values[-1], track2['heading'].values[0], n_frames)
            frames = np.arange(last_frame+1, next_births[0]+1)
            temp = pd.DataFrame({'Frame':frames, 'x':x, 'y':y, 'body_length':body_length, 'body_width':body_width, 'heading':heading})
            # merge the two tracklets along with the interpolated frames
            track = pd.concat([track, temp, track2], ignore_index=True)
        else:
            track = pd.concat([track, track2], ignore_index=True)
        # remove the tracklet that was merged
        tracklets.pop(birth_ids[0])
        # update births and deaths
        births[next_births[0], birth_ids[0]] = False
        deaths[next_births[0], birth_ids[0]] = False
        # update IDs
        IDs[next_births[0], birth_ids[0]] = 0
        # update Xs and Ys
        Xs[next_births[0], birth_ids[0]] = np.nan
        Ys[next_births[0], birth_ids[0]] = np.nan
    clean_tracklets.append(track)
    # plot the merged tracklet
    plt.plot(track['x'], track['y'], '-',alpha=0.5,linewidth=0.5)
    

Next births: [5092 6082 7236 9683 9926]
Birth IDs: [ 7  8  9 10 11]
Interpolating 3 frames
Next births: [10040 11469 12995 13064 13446]
Birth IDs: [12 13 14 15 16]
Interpolating 3 frames
Next births: [15515 16367 16528 18301 19896]
Birth IDs: [17 18 19 20 21]
Interpolating 3 frames
Next births: [19896 20587 20774 21397 22282]
Birth IDs: [21 22 23 24 25]
Interpolating 2 frames
Next births: [22282 24179 25045 26008 27106]
Birth IDs: [25 26 27 28 29]
Interpolating 5 frames
Next births: [30724 31721 32102 32923 33415]
Birth IDs: [32 33 34 35 36]
Interpolating 2 frames
Next births: [34843 35002 35143 35958 35960]
Birth IDs: [38 39 40 58 66]
Interpolating 3 frames
Next births: []
Next births: [2478 6082 7236]
Birth IDs: [6 8 9]
Interpolating 3 frames
Next births: [ 9926 11469 12995 13064 13446]
Birth IDs: [11 13 14 15 16]
Interpolating 4 frames
Next births: [16367 16528 18301 20587 20774]
Birth IDs: [18 19 20 22 23]
Interpolating 2 frames
Next births: []
Next births: [6082 7236]
Birth IDs: [

IndexError: list index out of range

In [51]:
len(clean_tracklets)

82

In [40]:
plt.plot(np.sum(IDs>0,axis=1))

[<matplotlib.lines.Line2D at 0x31271ca90>]