In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
from skg import nsphere_fit
import matplotlib.collections as mcoll
import matplotlib.path as mpath
import scipy.io as sio
import json
from tqdm import tqdm


def colorline(
    x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0),
        linewidth=3, alpha=1.0):
    """
    http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
    http://matplotlib.org/examples/pylab_examples/multicolored_line.html
    Plot a colored line with coordinates x and y
    Optionally specify colors in the array z
    Optionally specify a colormap, a norm function and a line width
    """

    # Default colors equally spaced on [0,1]:
    if z is None:
        z = np.linspace(0.0, 1.0, len(x))

    # Special case if a single number:
    if not hasattr(z, "__iter__"):  # to check for numerical input -- this is a hack
        z = np.array([z])

    z = np.asarray(z)

    segments = make_segments(x, y)
    lc = mcoll.LineCollection(segments, array=z, cmap=cmap, norm=norm,
                              linewidth=linewidth, alpha=alpha)

    ax = plt.gca()
    ax.add_collection(lc)

    return lc


def make_segments(x, y):
    """
    Create list of line segments from x and y coordinates, in the correct format
    for LineCollection: an array of the form numlines x (points per line) x 2 (x
    and y) array
    """

    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    return segments

def generate_arc(center, radius, start_angle, end_angle, num_points):
    angles = np.linspace(start_angle, end_angle, num_points)
    x = center[0] + radius * np.cos(angles)
    y = center[1] + radius * np.sin(angles)
    return x, y

%matplotlib qt

In [2]:
# PHASE 1
# load background image from mat file
p1_bg = sio.loadmat('20hr-wingless-orcoctrl-tt_phase_1/20hr-wingless-orcoctrl-tt_phase_1-bg.mat')['bg'][0][0].item()[0]
# PHASE 2
# load background image from mat file
p2_bg = sio.loadmat('20hr-wingless-orcoctrl-tt_phase_2/20hr-wingless-orcoctrl-tt_phase_2-bg.mat')['bg'][0][0].item()[0]
# PHASE 3
# load background image from mat file
p3_bg = sio.loadmat('20hr-wingless-orcoctrl-tt_phase_3/20hr-wingless-orcoctrl-tt_phase_3-bg.mat')['bg'][0][0].item()[0]

In [3]:
FPS = 32

In [4]:
# get arena margins
plt.figure()
plt.imshow(p1_bg, cmap='gray')
plt.title('Select 5 points on the margin of the arena')
pts = plt.ginput(5)
pts = np.array(pts)
plt.close()

# TRUE RADIUS
true_radius = 75 # mm

radius, center = nsphere_fit(pts)
print('Center:', center)
print('Radius:', radius)
sf = radius / true_radius

Center: [701.88646304 694.64395999]
Radius: 700.6732079836888


In [5]:
def get_ring(ring_pts):
    outer_ring = ring_pts[:5]
    inner_ring = ring_pts[5:]

    # fit a circle to the points
    outer_radius, outer_center = nsphere_fit(outer_ring)
    inner_radius, inner_center = nsphere_fit(inner_ring)

    # get the average center
    center = (outer_center + inner_center) / 2

    # get the average radius
    radius = (outer_radius + inner_radius) / 2

    # get the average trail width
    trail_width = outer_radius - inner_radius

    # ring properties
    ring_props = {
        'center': center,
        'radius': radius,
        'trail_width': trail_width,
    }
    return ring_props

def draw_ring(ring_props, bg=p2_bg, axis=None, color='r'):
    # draw the estimated circles
    if axis is None:
        fig, ax = plt.subplots()
    else:
        ax = axis

    ax.imshow(bg, cmap='gray')

    circ_inner = plt.Circle(ring_props['center'], ring_props['radius']-ring_props['trail_width']/2, color=color, fill=False)
    circ_outer = plt.Circle(ring_props['center'], ring_props['radius']+ring_props['trail_width']/2, color=color, fill=False)
    ax.add_artist(circ_inner)
    ax.add_artist(circ_outer)

    if axis is None:
        ax.set_aspect('equal')
        plt.title('Estimated ring')
        plt.show()

def ring_coordinates(pos,y_props):
    # get the angle of the point
    position = np.arctan2(pos[1]-y_props['center'][1], pos[0]-y_props['center'][0])
    position = position % (2*np.pi)
    # get the distance from the center
    dist = np.sqrt((pos[0]-y_props['center'][0])**2 + (pos[1]-y_props['center'][1])**2) - y_props['radius']
    return position, dist


In [6]:
recalculate = False

# check if previous ring props are available
try:
    if recalculate:
        raise FileNotFoundError
    
    with open('big_ring_props.json', 'r') as f:
        big_ring_props = json.load(f)
        big_ring_props = {k: np.array(v) if isinstance(v, list) else v for k, v in big_ring_props.items()}

except FileNotFoundError:
    # get ring points
    plt.figure()
    plt.imshow(p1_bg, cmap='gray')
    plt.title('Pick 5 points on the outer edge of the big circle and 5 points on the inner edge of the big circle')
    big_ring_pts = plt.ginput(10)
    plt.close()

    big_ring_props = get_ring(big_ring_pts)

    # save ring props as json
    with open('big_ring_props.json', 'w') as f:
        # convert numpy arrays to lists
        big_ring_props = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in big_ring_props.items()}
        json.dump(big_ring_props, f)

    draw_ring(big_ring_props)

# check if previous ring props are available
try:
    if recalculate:
        raise FileNotFoundError
    
    with open('small_ring_props.json', 'r') as f:
        small_ring_props = json.load(f)
        small_ring_props = {k: np.array(v) if isinstance(v, list) else v for k, v in small_ring_props.items()}
except FileNotFoundError:
    # get ring points
    plt.figure()
    plt.imshow(p3_bg, cmap='gray')
    plt.title('Pick 5 points on the outer edge of the small circle and 5 points on the inner edge of the small circle')
    small_ring_pts = plt.ginput(10)
    plt.close()

    small_ring_props = get_ring(small_ring_pts)

    # save ring props as json
    with open('small_ring_props.json', 'w') as f:
        # convert numpy arrays to lists
        small_ring_props = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in small_ring_props.items()}
        json.dump(small_ring_props, f)

    draw_ring(small_ring_props)

In [7]:
# plot every point within 20 pixels of the rings
fig, ax = plt.subplots()
plt.imshow(p2_bg, cmap='gray')

# create a meshgrid
res = 50
x = np.linspace(0, p2_bg.shape[1], res)
y = np.linspace(0, p2_bg.shape[0], res)
X, Y = np.meshgrid(x, y)
X = X.flatten()
Y = Y.flatten()
points = np.vstack([X, Y]).T

# check if point is within 5 pixels of the big circle
for point in tqdm(points):
    pos, dist = ring_coordinates(point, big_ring_props)
    if abs(dist) < 20:
        plt.scatter(point[0], point[1], color=plt.cm.cool(pos/(2*np.pi)))


# check if point is within 5 pixels of the small circle
for point in tqdm(points):
    pos, dist = ring_coordinates(point, small_ring_props)
    if abs(dist) < 20:
        plt.scatter(point[0], point[1], color=plt.cm.cool(pos/(2*np.pi)))

plt.title('Points within 20 pixels of the rings')

draw_ring(big_ring_props, axis=ax, color='k')
draw_ring(small_ring_props, axis=ax, color='k')

# add colorbar
sm = plt.cm.ScalarMappable(cmap=plt.cm.cool, norm=plt.Normalize(0, 2*np.pi))
sm.set_array([])
plt.colorbar(sm, ticks=[0, np.pi, 2*np.pi], label='Angle', orientation='horizontal', ax=ax)
plt.show()



100%|██████████| 2500/2500 [00:00<00:00, 12580.95it/s]
100%|██████████| 2500/2500 [00:00<00:00, 22615.29it/s]


In [8]:
max_displacement = 50 # mm/s 
max_displacement = max_displacement * sf / FPS

#### JUMP TEST CODE

In [9]:
# get x and y 
for i in range(7):
    df = pd.read_csv(f'20hr-wingless-orcoctrl-tt_phase_1/20hr-wingless-orcoctrl-tt_phase_1-trackfeat.csv/fly{i+1}.csv')
    
    x = df['pos x'].values
    y = df['pos y'].values
    # interpolate missing values
    x = np.interp(np.arange(len(x)), np.arange(len(x))[~np.isnan(x)], x[~np.isnan(x)])
    y = np.interp(np.arange(len(y)), np.arange(len(y))[~np.isnan(y)], y[~np.isnan(y)])

    # get difference between consecutive frames
    dx = np.diff(x)
    dy = np.diff(y)
    d = np.sqrt(dx**2 + dy**2)

    plt.plot(x, y, 'k-', lw=0.2)
    # mark every point where d > max_displacement
    points = np.where(d > max_displacement)[0]
    # draw a line between consecutive points
    for i in points:
        plt.plot(x[i:i+2], y[i:i+2], 'r-', lw=1)
plt.gca().set_aspect('equal')
plt.show()

### TRACKLET CONVERSION

In [10]:
# convert to tracklets
for phase in [1,2,3]:
    tracklets = []
    for i in range(7):
        df = pd.read_csv(f'20hr-wingless-orcoctrl-tt_phase_{phase}/20hr-wingless-orcoctrl-tt_phase_{phase}-trackfeat.csv/fly{i+1}.csv')
        x = df['pos x'].values
        y = df['pos y'].values
        ori = df['ori'].values

        # interpolate missing values
        x = np.interp(np.arange(len(x)), np.arange(len(x))[~np.isnan(x)], x[~np.isnan(x)])
        y = np.interp(np.arange(len(y)), np.arange(len(y))[~np.isnan(y)], y[~np.isnan(y)])
        ori = np.interp(np.arange(len(ori)), np.arange(len(ori))[~np.isnan(ori)], ori[~np.isnan(ori)])
        
        # get difference between consecutive frames
        d = np.sqrt(np.diff(x)**2 + np.diff(y)**2)
        # get points where d > max_displacement
        points = np.where(d > max_displacement)[0]
        start = 0
        for i in points:
            t = pd.DataFrame({'pos x': x[start:i+1], 'pos y': y[start:i+1], 'ori': ori[start:i+1]})
            tracklets.append(t)
            start = i+1
        t = pd.DataFrame({'pos x': x[start:], 'pos y': y[start:], 'ori': ori[start:]})
        tracklets.append(t)

    # filter out tracklets with less than 5*FPS frames
    tracklets = [t for t in tracklets if len(t) >= 5*FPS]

    # add ring and yang coordinates to each tracklet
    i = 0
    for t in tqdm(tracklets):
        t['big ring pos'], t['big ring dist'] = zip(*[ring_coordinates([x, y], big_ring_props) for x, y in zip(t['pos x'], t['pos y'])])
        t['small ring pos'], t['small ring dist'] = zip(*[ring_coordinates([x, y], small_ring_props) for x, y in zip(t['pos x'], t['pos y'])])
        t['track id'] = i
        i += 1

    # combine and save tracklets
    tracklets = pd.concat(tracklets)
    tracklets.to_csv(f'phase_{phase}_tracklets.csv', index=False)



100%|██████████| 328/328 [00:06<00:00, 54.47it/s]
100%|██████████| 76/76 [00:03<00:00, 24.75it/s]
100%|██████████| 530/530 [00:07<00:00, 75.54it/s] 


In [11]:
# plot all tracklets in phases
for phase in [1,2,3]:
    tracklets = pd.read_csv(f'phase_{phase}_tracklets.csv')
    fig, ax = plt.subplots()
    for i in range(tracklets['track id'].max()):
        t = tracklets[tracklets['track id'] == i]
        plt.plot(t['pos x'], t['pos y'], 'k-', lw=0.2)
    if phase == 1:
        draw_ring(big_ring_props, axis=ax, color='r', bg=p1_bg)
    elif phase == 2:
        draw_ring(big_ring_props, axis=ax, color='k', bg=p2_bg)
    else:
        draw_ring(small_ring_props, axis=ax, color='r', bg=p3_bg)
    plt.gca().set_aspect('equal')
    plt.title(f'Phase {phase}')
    plt.show()

In [12]:
encounter_distance = 2*2 # mm body width
encounter_distance = encounter_distance * sf # scale to pixels
encounter_distance = encounter_distance / 2 # convert distance
print(f'Encounter distance: {encounter_distance} pixels')

trail_width = big_ring_props['trail_width']/2

tracklets = pd.read_csv('phase_2_tracklets.csv')
tracklets = [group for _, group in tracklets.groupby('track id')]


Encounter distance: 18.684618879565036 pixels


In [13]:
tracklet = tracklets[0]
ring = 'big'
plt.figure()
plt.plot(tracklet['big ring dist'].values, tracklet['big ring pos'].values, 'k-', lw=0.2)
plt.show()

val = np.abs(np.concatenate([[np.inf], tracklet[ring+' ring dist'].values])) < encounter_distance + trail_width
count = np.cumsum(np.concatenate([np.array([np.nan]),np.diff(val.astype(int))])>0)
mes = val*count
plt.figure()
plt.plot(tracklet['pos x'], tracklet['pos y'], 'k-', lw=0.2)
draw_ring(big_ring_props, color='k', axis=plt.gca())
# draw encounters with scatter
plt.scatter(tracklet['pos x'][mes[1:]>0], tracklet['pos y'][mes[1:]>0], c=mes[mes>0], cmap='viridis')
plt.gca().set_aspect('equal')
plt.show()

In [98]:
big_ring_props['trail_width']

15.972276246425736

In [107]:
encounter_distance = 2*2 # mm body width
encounter_distance = encounter_distance * sf # scale to pixels
encounter_distance = encounter_distance / 2 # convert distance
print(f'Encounter distance: {encounter_distance} pixels')

Encounter distance: 18.651685202142755 pixels


In [125]:
# get encounters in each phase
for phase in [1,2,3]:
    tracklets = pd.read_csv(f'phase_{phase}_tracklets.csv')
    tracklets = [group for _, group in tracklets.groupby('track id')]

    if phase == 1:
        props = big_ring_props
        ring = 'big'
        trail_width = big_ring_props['trail_width']/2
    elif phase == 2:
        props = big_ring_props
        ring = 'big'
        trail_width = big_ring_props['trail_width']/2
    else:
        props = small_ring_props
        ring = 'small'
        trail_width = small_ring_props['trail_width']/2
        
    # get encounters
    encounters = []
    for tracklet in tracklets:
        val = np.abs(np.concatenate([[np.inf], tracklet[ring+' ring dist'].values])) < encounter_distance + trail_width
        count = np.cumsum(np.concatenate([np.array([np.nan]),np.diff(val.astype(int))])>0)
        mes = val*count
        
        # get every continuous segment where the fly is within encounter_distance of the ring circle
        encounters += [group for m, group in tracklet.groupby(mes[1:]) if m != 0]
    # filter for encounters with atleast 1 sign change in the ring position
    encounters = [encounter for encounter in encounters if np.abs(np.diff(np.sign(encounter[ring+' ring dist']))).sum() > 3]
    
    print(f'Phase {phase}: {len(encounters)} encounters')
    
    # plot all encounters
    fig, ax = plt.subplots()
    for encounter in encounters:
        plt.plot(encounter['pos x'], encounter['pos y'], '-', lw=0.5, color=plt.cm.cool(np.random.rand()))
    if phase == 1:
        draw_ring(big_ring_props, axis=ax, color='r', bg=p1_bg)
    elif phase == 2:
        draw_ring(big_ring_props, axis=ax, color='k', bg=p2_bg)
    else:
        draw_ring(small_ring_props, axis=ax, color='r', bg=p3_bg)
    plt.gca().set_aspect('equal')
    plt.title(f'Phase {phase}')
    plt.show()

    # plot in ring space
    fig, ax = plt.subplots()

    max_distances = []
    for encounter in encounters:
        position = (np.unwrap(encounter[ring+' ring pos'])*props['radius']/sf)
        # find point of first crossing
        crossing = np.argmax(np.abs(np.diff(np.sign(encounter[ring+' ring dist'].values))))
        position = position - position[crossing]
        dists = [np.abs(np.max(position[crossing:])),np.abs(np.min(position[crossing:]))]
        if np.argmax(dists) == 1:
            position = -position
        distance = encounter[ring+' ring dist']/sf
        max_distances.append(np.max(dists))
        ax.plot(distance, position, 'r-', lw=1, alpha=0.5)

    ax.set_xlabel('Distance from ring center (mm)')
    ax.set_ylabel('Position on ring (mm)')
    ax.set_title(f'Phase {phase}')
    ax.set_aspect('equal')
    plt.show()

    # sort by max distance
    encounters = [encounters[i].reset_index(drop=True) for i in np.argsort(max_distances)]
    max_distances = np.sort(max_distances)

    # save ring encounters after adding an id
    for i, encounter in enumerate(encounters):
        encounter['id'] = i
        encounter['max_distance'] = max_distances[i]
    ring_encounters_CSV = pd.concat(encounters)
    ring_encounters_CSV.to_csv(f'encounters_phase_{phase}.csv', index=False)
        

Phase 1: 22 encounters
Phase 2: 6 encounters
Phase 3: 9 encounters


In [480]:
# create a video of top encounters
import skvideo.io
import os
from tqdm import tqdm

phase = 2
N = 15
# load background image in cv2
bg = cv2.cvtColor((eval(f"p{phase}_bg")*255).astype(np.uint8), cv2.COLOR_GRAY2RGB)

# create a video for each ring encounter
for i in range(N):
    # get x and y
    x = eval(f'ring_encounters_phase_{phase}[len(ring_encounters_phase_{phase})-{i}-1]["pos x"].values')
    y = eval(f'ring_encounters_phase_{phase}[len(ring_encounters_phase_{phase})-{i}-1]["pos y"].values')
    ori = eval(f'ring_encounters_phase_{phase}[len(ring_encounters_phase_{phase})-{i}-1]["ori"].values')
    writer = skvideo.io.FFmpegWriter('videos/ring_phase_{}_encounter_{}.mp4'.format(phase, i))
    for j in tqdm(range(len(x))):

        frame = bg.copy()
        cv2.circle(frame, (int(x[j]), int(y[j])), 5, (255,0,0), -1)
        # draw heading
        x2 = x[j] + 30*np.cos(ori[j])
        y2 = y[j] + 30*np.sin(ori[j])
        cv2.line(frame, (int(x[j]), int(y[j])), (int(x2), int(y2)), (0,255,0), 1)
        for k in range(j):
            cv2.circle(frame, (int(x[k]), int(y[k])), 1, (0,0,255), -1)

        writer.writeFrame(frame)
    writer.close()

for i in range(N):
    # get x and y
    x = eval(f'yang_encounters_phase_{phase}[len(yang_encounters_phase_{phase})-{i}-1]["pos x"].values')
    y = eval(f'yang_encounters_phase_{phase}[len(yang_encounters_phase_{phase})-{i}-1]["pos y"].values')
    ori = eval(f'yang_encounters_phase_{phase}[len(yang_encounters_phase_{phase})-{i}-1]["ori"].values')
    writer = skvideo.io.FFmpegWriter('videos/yang_phase_{}_encounter_{}.mp4'.format(phase, i))
    for j in tqdm(range(len(x))):

        frame = bg.copy()
        cv2.circle(frame, (int(x[j]), int(y[j])), 5, (255,0,0), -1)
        # draw heading
        x2 = x[j] + 30*np.cos(ori[j])
        y2 = y[j] + 30*np.sin(ori[j])
        cv2.line(frame, (int(x[j]), int(y[j])), (int(x2), int(y2)), (0,255,0), 1)
        for k in range(j):
            cv2.circle(frame, (int(x[k]), int(y[k])), 1, (0,0,255), -1)

        writer.writeFrame(frame)
    writer.close()

100%|██████████| 52/52 [00:00<00:00, 115.62it/s]
100%|██████████| 50/50 [00:00<00:00, 173.49it/s]
100%|██████████| 71/71 [00:00<00:00, 73.44it/s] 
100%|██████████| 70/70 [00:00<00:00, 111.55it/s]
100%|██████████| 35/35 [00:00<00:00, 147.66it/s]
100%|██████████| 37/37 [00:00<00:00, 150.03it/s]
100%|██████████| 95/95 [00:00<00:00, 112.46it/s]
100%|██████████| 33/33 [00:00<00:00, 150.31it/s]
100%|██████████| 34/34 [00:00<00:00, 156.49it/s]
100%|██████████| 47/47 [00:00<00:00, 156.10it/s]
100%|██████████| 31/31 [00:00<00:00, 155.15it/s]
100%|██████████| 46/46 [00:00<00:00, 150.49it/s]
100%|██████████| 44/44 [00:00<00:00, 129.10it/s]
100%|██████████| 45/45 [00:00<00:00, 169.14it/s]
100%|██████████| 60/60 [00:00<00:00, 105.68it/s]
100%|██████████| 102/102 [00:00<00:00, 120.27it/s]
100%|██████████| 74/74 [00:00<00:00, 119.51it/s]
100%|██████████| 284/284 [00:02<00:00, 102.78it/s]
100%|██████████| 99/99 [00:00<00:00, 117.67it/s]
100%|██████████| 69/69 [00:00<00:00, 120.83it/s]
100%|██████████|

In [468]:
plt.imshow(bg)

<matplotlib.image.AxesImage at 0x32f05e0e0>

In [212]:
# # draw the tracklet in original space
# plt.figure()
# plt.imshow(p2_bg, cmap='gray')

# for tracklet in tracklets[:100]:
#     # project the tracklet on the ring circle
#     ring_pos = []
#     ring_dist = []
#     for i in range(len(tracklet)):
#         pos = np.array([tracklet['pos x'].iloc[i], tracklet['pos y'].iloc[i]])
#         pos, dist = y_coordinates(pos, ring_props)
#         ring_pos.append(pos)
#         ring_dist.append(abs(dist))
#     # color the tracklet based on the position on the ring circle
#     colorline(tracklet['pos x'], tracklet['pos y'], ring_dist, cmap=plt.cm.cool, norm=plt.Normalize(0, 350), linewidth=0.5)

# plt.gca().set_aspect('equal')
# plt.title('Original tracklet')
# plt.show()


In [None]:
encounter_distance = (2 * 3 + 3/2) * sf # 2 * 3mm body widths + 3mm trail converted to pixels

# get all ring encounters
ring_encounters = []
for tracklet in tracklets:
    # project the tracklet on the ring circle using vectorized function
    ring_pos = []
    ring_dist = []
    for i in range(len(tracklet)):
        pos = np.array([tracklet['pos x'].iloc[i], tracklet['pos y'].iloc[i]])
        pos, dist = y_coordinates(pos, ring_props)
        ring_pos.append(pos)
        ring_dist.append(abs(dist))
    
    # check if the tracklet encounters the ring circle
    if np.any(np.array(ring_dist) < encounter_distance):
        ring_encounters.append(tracklet)


In [91]:
# plot all tracklets
plt.figure()
for t in tracklets:
    plt.plot(t[0], t[1], 'k-', lw=0.2)
plt.gca().set_aspect('equal')
plt.title('Tracklets')
plt.show()

In [93]:
plt.figure()
plt.plot(df['vel'])
x = df['pos x']
y = df['pos y']
x = np.interp(np.arange(len(x)), np.where(~np.isnan(x))[0], x[~np.isnan(x)])
y = np.interp(np.arange(len(y)), np.where(~np.isnan(y))[0], y[~np.isnan(y)])
v = np.sqrt(np.diff(x)**2 + np.diff(y)**2)
plt.plot(v)
plt.show()

In [136]:
# find all nans
nan_idx = np.where(np.isnan(df['pos x']))[0]
# find every start and end of a nan sequence
nan_idx = nan_idx[np.where(np.diff(nan_idx) != 0)[0]]
# plot everything outside the nans
plt.figure()
while len(nan_idx) > 0:
    plt.plot(df['pos x'].iloc[:nan_idx[0]], df['pos y'].iloc[:nan_idx[0]], '-', lw=0.2)
    df = df.iloc[nan_idx[0]+1:]
    nan_idx = nan_idx[1:]
plt.gca().set_aspect('equal')



In [144]:
x = np.array(df['pos x'])
y = np.array(df['pos y'])
nan_idx = np.where(np.isnan(x))[0]
# interpolate the nan values
x = np.interp(np.arange(len(x)), np.where(~np.isnan(x))[0], x[~np.isnan(x)])
y = np.interp(np.arange(len(y)), np.where(~np.isnan(y))[0], y[~np.isnan(y)])
# convert to mm
x = x/sf
y = y/sf
# get the velocity
vx = np.gradient(x, 1/FPS)
vy = np.gradient(y, 1/FPS)
v = np.sqrt(vx**2 + vy**2)
plt.figure()
plt.plot(x)
plt.plot(y)
plt.plot(v)
plt.show()

In [162]:
x = np.array(df['pos x'])
y = np.array(df['pos y'])
nan_idx = np.where(np.isnan(x))[0]
# interpolate the nan values
# x = np.interp(np.arange(len(x)), np.where(~np.isnan(x))[0], x[~np.isnan(x)])
# y = np.interp(np.arange(len(y)), np.where(~np.isnan(y))[0], y[~np.isnan(y)])
# convert to mm
x = x#/sf
y = y#/sf
# get the velocity
vx = np.gradient(x, 1/FPS)
vy = np.gradient(y, 1/FPS)
v = np.sqrt(vx**2 + vy**2)

mask = v < 0.5
plt.plot(x, y, 'r-', lw=1, alpha=0.5)
plt.scatter(df['pos x'], df['pos y'], c=df['vel'], cmap='viridis', s=1,zorder=10)
# plt.plot(x[mask], y[mask], '-', lw=0.2)
# plt.scatter(df['pos x'][mask], df['pos y'][mask], c=df['vel'][mask], cmap='viridis')
# plt.colorbar()
plt.gca().set_aspect('equal')


In [112]:
plt.plot(v)

[<matplotlib.lines.Line2D at 0x38b1220b0>]

In [81]:
plt.plot(df['vel'])
# df['vel'] = df['vel']/sf

[<matplotlib.lines.Line2D at 0x37cd2fbb0>]

In [82]:
mask = df['vel'] <10
plt.plot(df['pos x'][mask], df['pos y'][mask], 'r-', lw=0.2)
# plt.scatter(df['pos x'][mask], df['pos y'][mask], c=df['vel'][mask], cmap='viridis')
# plt.colorbar()

[<matplotlib.lines.Line2D at 0x37cdbaa40>]

In [7]:
# get tracklets
def get_tracklets(data, center, radius):
    n_flies = data.shape[1]//6
    tracklets = []
    for i in range(n_flies):
        # get track
        track = data.iloc[:, i*6:i*6+6].copy().reset_index()
        track.columns = ['Frame','ID','x','y','body_crossinglength','body_width','heading']

        # remove nan values
        track.replace(-1, np.nan, inplace=True)
        track.dropna(inplace=True)

        # skip if the track is too short (less than 10 frames)
        if track.shape[0] < 10:
            continue

        # Calculate other variables
        x_ = np.concatenate(([track['x'].iloc[0]], track['x'].values))
        y_ = np.concatenate(([track['y'].iloc[0]], track['y'].values))

        # kinematics
        track['velocity'] = np.sqrt(np.diff(x_)**2 + np.diff(y_)**2) * FPS
        track['motion_direction'] = np.arctan2(np.diff(y_), np.diff(x_))
        track['acceleration'] = np.diff(np.concatenate(([0], track['velocity'].values))) * FPS
        track['angular_velocity'] = np.diff(np.concatenate(([track['heading'].iloc[0]], track['heading'].values))) * FPS
        track['angular_acceleration'] = np.diff(np.concatenate(([0], track['angular_velocity'].values))) * FPS

        # odor trail related variables
        track['distance'] = np.sqrt((track['x'] - center[0])**2 + (track['y'] - center[1])**2) - radius
        track['angle'] = np.rad2deg(np.unwrap(np.arctan2(track['y'] - center[1], track['x'] - center[0])))
        track['angular_distance'] = np.deg2rad(track['angle']) * radius
        track['trail_heading'] = track['heading'] - np.arctan2(track['y'] - center[1], track['x'] - center[0])
        tracklets.append(track)
    return tracklets

def get_segments(tracklet):
    # segment the tracklet into trail and non-trail based on distance from the trail and atleast one crossing
    trail = (np.abs(tracklet['distance']) < sf*0.5).values # trail is within 0.5 inches of the trail
    # find every time the trail is crossed
    crossing = (tracklet['distance']<trail_width*sf).astype(int).values
    crossing = np.abs(np.diff(np.concatenate(([crossing[0]], crossing))))>0
    # find every continuous segment of trail where there is atleast one crossing
    segments = []
    start = 0
    end = 0
    for i in range(1, len(trail)):
        if trail[i] == trail[i-1]:
            continue
        if not trail[i-1]:
            start = i
        else:
            end = i
            if crossing[start:end].sum() > 0:
                segment = tracklet.iloc[start:end].copy().reset_index(drop=True)
                # find the position of the first crossing
                crossing_idx = np.argmax(crossing[start:end])
                # get the interpolated position of the crossing
                crossing_x = np.interp(0, [segment['distance'].iloc[crossing_idx-1], segment['distance'].iloc[crossing_idx]], [segment['x'].iloc[crossing_idx-1], segment['x'].iloc[crossing_idx]])
                crossing_y = np.interp(0, [segment['distance'].iloc[crossing_idx-1], segment['distance'].iloc[crossing_idx]], [segment['y'].iloc[crossing_idx-1], segment['y'].iloc[crossing_idx]])
                crossing_a = np.interp(0, [segment['distance'].iloc[crossing_idx-1], segment['distance'].iloc[crossing_idx]], [segment['angle'].iloc[crossing_idx-1], segment['angle'].iloc[crossing_idx]])
                crossing_ad = np.interp(0, [segment['distance'].iloc[crossing_idx-1], segment['distance'].iloc[crossing_idx]], [segment['angular_distance'].iloc[crossing_idx-1], segment['angular_distance'].iloc[crossing_idx]])
                segment['x_crossing'] = segment['x'] - crossing_x
                segment['y_crossing'] = segment['y'] - crossing_y
                segment['angle_crossing'] = segment['angle'] - crossing_a
                segment['angular_distance_crossing'] = segment['angular_distance'] - crossing_ad
                segment['time_from_crossing'] = (segment['Frame'] - segment['Frame'].iloc[crossing_idx])/FPS
                segment['number_of_crossings'] = crossing[start:end].sum()
                segment['encounter_distance'] = np.max(np.abs(segment['angular_distance_crossing'].values[segment['time_from_crossing'].values>0]))/sf
                segment['time'] = segment['Frame'].iloc[crossing_idx]/FPS
                segments.append(segment)
    return segments


In [5]:
csv_file = '20hr-wingless-orco/20hr-wingless-orco-phase_1.csv'
data = pd.read_csv(csv_file, header=None)
tracklets = get_tracklets(data, center, radius)

In [6]:
fig, ax = plt.subplots()
# plot the inverted image
ax.imshow(1-frame,cmap='gray_r')
# plot the tracklets
for track in tracklets:
    ax.plot(track['x'], track['y'], '-',alpha=0.5, color='black',linewidth=0.1)
# plot the odor trail
circle = plt.Circle(center, radius, color='r', fill=False, linewidth=2)
ax.add_artist(circle)
plt.show()
# calculate a 2d histogram of the x and y coordinates
n_bins = 100
# combine all x and y coordinates
x = np.concatenate([track['x'].values for track in tracklets])
y = np.concatenate([track['y'].values for track in tracklets])
H, xedges, yedges = np.histogram2d(x, y, bins=n_bins)
H = H.T
# log scale
H = np.log(H+1)
extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]
fig, ax = plt.subplots()
im = ax.imshow(H, extent=extent, origin='lower', cmap='hot')
cbar = plt.colorbar(im, ax=ax)
cbar.set_ticks(np.log([1, 10, 100, 1000]))
cbar.set_ticklabels([1, 10, 100, 1000])
# plot the odor trail
circle = plt.Circle(center, radius, color='k', fill=False, linewidth=2,zorder=10)
ax.add_artist(circle)
# plot the tracklets
for track in tracklets:
    ax.plot(track['x'], track['y'], '-',alpha=0.5, color='black',linewidth=0.1)
plt.show()

In [7]:
# project the tracklets into the space of the odor trail
tracklet = tracklets[1]

# plot the tracklet (color by speed)
fig, ax = plt.subplots()
ax.scatter(tracklet['distance'], tracklet['angular_distance'], c=tracklet['distance'], cmap='RdBu', s=1)
ax.plot(tracklet['distance'], tracklet['angular_distance'], '-', color='black', alpha=1, linewidth=0.1)
# plot a line at 0
ax.axvline(0, color='r', linestyle='--')
ax.set_xlabel('Distance from the odor trail (pixels)')
ax.set_ylabel('Angle from the odor trail (degrees)')
# ax.set_aspect('equal')
plt.show()
# also plot in normal space
fig, ax = plt.subplots()
ax.scatter(tracklet['x'], tracklet['y'], c=tracklet['distance'], cmap='RdBu', s=1)
ax.plot(tracklet['x'], tracklet['y'], '-', color='black', alpha=1, linewidth=0.1)
circle = plt.Circle(center, radius, color='r', fill=False)
ax.add_artist(circle)
ax.set_xlabel('x (pixels)')
ax.set_ylabel('y (pixels)')
ax.set_aspect('equal')
plt.show()

In [11]:
# plot the segments
show_heading = False
filter_length = 2 # inches
fig, ax = plt.subplots()
all_segments = []
for tracklet in tracklets:
    segments = get_segments(tracklet)
    for segment in segments:

        if filter_length is not None and segment['encounter_distance'].iloc[0] < filter_length:
            continue
        mod = 1 if np.mean(segment['angular_distance_crossing'].values[segment['time_from_crossing'].values < 0]) < 0 else -1
        ax.scatter(segment['distance']/sf, mod*segment['angular_distance_crossing']/sf, c=segment['time_from_crossing'], cmap='viridis', s=5, zorder=1)
        ax.plot(segment['distance']/sf, mod*segment['angular_distance_crossing']/sf, '-', color='black', alpha=1, linewidth=1, zorder=0)
        # plot the trail heading
        for i in range(len(segment)):
            x = segment['distance'].iloc[i]/sf
            y = mod*segment['angular_distance_crossing'].iloc[i]/sf
            a = mod*segment['trail_heading'].iloc[i]
            if show_heading:
                ax.plot([x, x + 10*np.cos(a)/sf], [y, y + 10*np.sin(a)/sf], '-', color='k', linewidth=0.5)
    all_segments.extend(segments)
ax.axvline(0, color='r', linestyle='--')
ax.set_xlabel('Distance from the odor trail (inches)')
ax.set_ylabel('Arc Distance on the odor trail (inches)')
ax.set_aspect('equal')
plt.show()

In [12]:
def corr(x,y):
    r = np.corrcoef(x,y)[0,1]
    n = len(x)
    r_z = np.arctanh(r)
    se = 1/np.sqrt(n-3)
    z = 1.96
    lo_z, hi_z = r_z-z*se, r_z+z*se
    lo, hi = np.tanh((lo_z, hi_z))
    return r, lo, hi

In [13]:
# plot velocity during encounters
from scipy.stats import sem
from tqdm import tqdm
max_frames_before_encounter = np.max([np.sum(segment['time_from_crossing'].values < 0) for segment in all_segments])
max_frames_after_encounter = np.max([np.sum(segment['time_from_crossing'].values > 0) for segment in all_segments])
fig, ax = plt.subplots()
m,l,h = [],[],[]
for i in tqdm(range(-max_frames_before_encounter, max_frames_after_encounter)):
    vels = []
    # get the velocity of each fly at this time
    for segment in all_segments:
        vels.append(np.log(segment['velocity'].values[np.abs(segment['time_from_crossing'].values - i/FPS) < 1/FPS]/sf+1e-4))
    vels = np.concatenate(vels)
    # get velocity and 95% confidence interval
    m.append(np.mean(vels))
    l.append(np.mean(vels) - 1.96*sem(vels))
    h.append(np.mean(vels) + 1.96*sem(vels))
m,l,h = np.exp(m), np.exp(l), np.exp(h)
ax.plot(np.arange(-max_frames_before_encounter, max_frames_after_encounter)/FPS, m, color='k')
ax.fill_between(np.arange(-max_frames_before_encounter, max_frames_after_encounter)/FPS, l, h, color='gray', alpha=0.5)
ax.axvline(0, color='r', linestyle='--')
ax.set_xlabel('Time from encounter (s)')
ax.set_ylabel('Velocity (inches/s)')
plt.show()

  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  ret = ret.dtype.type(ret / rcount)
 25%|██▍       | 267/1075 [00:04<00:12, 65.31it/s]


KeyboardInterrupt: 

In [14]:
# encounter duration
x,y = [],[]
for segment in all_segments:
    encounter_distance = segment['encounter_distance'].iloc[0]
    color = 'lightgray' if encounter_distance < 2 else 'k'
    encounter_duration = np.log(segment['time_from_crossing'].max())
    prior_velocity = np.mean(segment['velocity'].values[segment['time_from_crossing'].values < 0])/sf
    x.append(prior_velocity)
    y.append(encounter_duration)
    plt.scatter(prior_velocity, encounter_duration, c=color, s=5)
# calculate the correlation along with the 95% confidence interval
r, lo, hi = corr(x,y)
plt.axhline(0, color='r', linestyle='--')
plt.title('r = {:.2f} ({:.2f}, {:.2f})'.format(r, lo, hi))
plt.xlabel('Prior velocity (inches/s)')
plt.ylabel('Log(Encounter duration)')
plt.show()
# encounter distance
x,y = [],[]
plt.figure()
for segment in all_segments:
    encounter_distance = segment['encounter_distance'].iloc[0]
    color = 'lightgray' if encounter_distance < 2 else 'k'
    prior_velocity = np.mean(segment['velocity'].values[segment['time_from_crossing'].values < 0])/sf
    x.append(prior_velocity)
    y.append(encounter_distance)
    plt.scatter(prior_velocity, encounter_distance, c=color, s=5)
# calculate the correlation along with the 95% confidence interval
r, lo, hi = corr(x,y)
plt.axhline(2, color='r', linestyle='--')
plt.title('r = {:.2f} ({:.2f}, {:.2f})'.format(r, lo, hi))
plt.xlabel('Prior velocity (inches/s)')
plt.ylabel('Encounter distance (inches)')
plt.show()

In [15]:
# plot a histogram of the encounter distance
plt.figure()
x = [segment['encounter_distance'].iloc[0] for segment in all_segments]
vals, bins = np.histogram(x, bins=50)
vals = vals/np.sum(vals)
plt.bar(bins[:-1], vals, width=np.diff(bins)[0])
plt.xlabel('Encounter distance (inches)')
plt.ylabel('Frequency')
# expected exponential distribution
from scipy.optimize import curve_fit
def exp(x, a, b):
    return a*np.exp(-b*x)
popt, pcov = curve_fit(exp, bins[:-1], vals)
plt.plot(bins, exp(bins, *popt), 'r--')
plt.show()

# plot encounter distance vs average velocity before crossing
plt.figure()
x,y = [],[]
for segment in all_segments:
    encounter_distance = segment['encounter_distance'].iloc[0]
    prior_velocity = np.mean(segment['velocity'].values[segment['time_from_crossing'].values < 0])/sf
    x.append(prior_velocity)
    y.append(encounter_distance)
plt.scatter(x, y, c='k', s=5)
# calculate the correlation along with the 95% confidence interval
r, lo, hi = corr(x,y)
plt.title('r = {:.2f} ({:.2f}, {:.2f})'.format(r, lo, hi))
plt.xlabel('Prior velocity (inches/s)')
plt.ylabel('Encounter distance (inches)')
plt.show()

# plot encounter distance vs average trail heading before crossing
def tortuosity(x, y):
    distance = np.sum(np.sqrt(np.diff(x)**2 + np.diff(y)**2))
    straight_distance = np.sqrt((x[0]-x[-1])**2 + (y[0]-y[-1])**2)
    return distance/straight_distance

plt.figure()
x,y,s = [],[],[]
for segment in all_segments:
    encounter_distance = segment['encounter_distance'].iloc[0]
    prior_heading = np.mean(segment['trail_heading'].values[segment['time_from_crossing'].values < 0])
    prior_heading = prior_heading % (2*np.pi)
    x.append(prior_heading)
    y.append(encounter_distance)
plt.scatter(x, y, c='k', cmap='viridis', s=1)
plt.xlim([0, 2*np.pi])
plt.xticks([0, np.pi/2, np.pi, 3*np.pi/2, 2*np.pi], ['0',r'$\frac{\pi}{2}$', r'$\pi$', r'$\frac{3\pi}{2}$', r'$2\pi$'])
plt.xlabel('Prior trail heading')
plt.ylabel('Encounter distance (inches)')
plt.show()


  plt.scatter(x, y, c='k', cmap='viridis', s=1)


In [16]:
# plot points in x,y where the encounter occurs
plt.figure()
plt.imshow(1-frame,cmap='gray_r')
for segment in all_segments:
    encounter_distance = np.max(np.abs(segment['angular_distance_crossing'].values[segment['time_from_crossing'].values>0]))/sf
    if encounter_distance < 2:
        color = 'gray'
    else:
        color = 'r'
    plt.scatter(segment['x'][segment['time_from_crossing'] == 0], segment['y'][segment['time_from_crossing'] == 0], c=color, s=5)
plt.show()

In [38]:
# plot points in x,y where the encounter occurs
plt.figure()
plt.imshow(1-frame,cmap='gray_r')
dists = []
xs = []
ys = []
for segment in all_segments:
    encounter_distance = np.max(np.abs(segment['angular_distance_crossing'].values[segment['time_from_crossing'].values>0]))/sf
    if encounter_distance < 2:
        continue
    xs.append(segment['x'][segment['time_from_crossing'] >= 0].values)
    ys.append(segment['y'][segment['time_from_crossing'] >= 0].values)
    dists.append(encounter_distance)
    plt.plot(segment['x'][segment['time_from_crossing'] >= 0], segment['y'][segment['time_from_crossing'] >= 0], '-', color=plt.cm.Reds(encounter_distance/8), linewidth=0.5)
    # mark the start of the encounter
    plt.scatter(segment['x'][segment['time_from_crossing'] == 0], segment['y'][segment['time_from_crossing'] == 0], c=plt.cm.Reds(encounter_distance/8), s=15)
plt.show()

  plt.scatter(segment['x'][segment['time_from_crossing'] == 0], segment['y'][segment['time_from_crossing'] == 0], c=plt.cm.Reds(encounter_distance/8), s=15)


In [40]:
# plot only top n encounters
N = 10
for i in np.argsort(dists)[-N:]:
    plt.figure()
    plt.imshow(1-frame,cmap='gray_r')
    x = xs[i]
    y = ys[i]
    plt.plot(x, y, '-', color='r', linewidth=0.5)
    # mark the start of the encounter
    plt.scatter(x[0], y[0], c='r', s=15)
    plt.title('D: {:.2f} T = {:d}:{:d}'.format(dists[i], int(all_segments[i]['time'].iloc[0]//60), int(all_segments[i]['time'].iloc[0]%60)))
    plt.show()

In [61]:
video_file = '20hr-wingless-orco/20hr-wingless-orco_phase_1.mp4'
csv_file = '20hr-wingless-orco/20hr-wingless-orco-phase_1.csv'

# load tracklets
data = pd.read_csv(csv_file, header=None)
tracklets = get_tracklets(data, center, radius)
all_segments = []
for tracklet in tracklets:
    segments = get_segments(tracklet)
    all_segments.extend(segments)
# sort the segments by encounter distance
all_segments = sorted(all_segments, key=lambda x: x['encounter_distance'].iloc[0], reverse=True)

# crossings
crossings = [segment['number_of_crossings'].iloc[0] for segment in all_segments]
plt.hist(crossings, bins=range(1,10))
plt.xlabel('Number of crossings')
plt.ylabel('Frequency')
plt.show()

# create video of encounters
import skvideo.io
import os
from tqdm import tqdm

N = 5
# get background image from the video
cap = cv2.VideoCapture(video_file)
for i in range(N):
    segment = all_segments[i]
    x = segment['x'].values
    y = segment['y'].values
    f = segment['Frame'].values
    # set start frame
    start_frame = f[0]
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    writer = skvideo.io.FFmpegWriter('true_encounter_{:d}.mp4'.format(i))
    for j in tqdm(range(len(x))):
        ret, frame = cap.read()
        if not ret:
            break
        # draw on image using opencv
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        cv2.circle(frame, (int(x[j]), int(y[j])), 5, (255,0,0), -1)
        for k in range(j):
            cv2.circle(frame, (int(x[k]), int(y[k])), 1, (0,0,255), -1)
        # draw the trail
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius+0.05*sf), (0,0,0), 1)
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius-0.05*sf), (0,0,0), 1)

        writer.writeFrame(frame)
    writer.close()
cap.release()

100%|██████████| 113/113 [00:01<00:00, 75.88it/s]
100%|██████████| 99/99 [00:01<00:00, 67.91it/s]
100%|██████████| 181/181 [00:03<00:00, 60.17it/s]
100%|██████████| 70/70 [00:01<00:00, 66.71it/s]
100%|██████████| 212/212 [00:03<00:00, 64.44it/s]


In [3]:
video_file = '20hr-wingless-orco/20hr-wingless-orco_phase_2.mp4'
csv_file = '20hr-wingless-orco/20hr-wingless-orco-phase_2.csv'

# load tracklets
data = pd.read_csv(csv_file, header=None)
tracklets = get_tracklets(data, center, radius)
all_segments = []
for tracklet in tracklets:
    segments = get_segments(tracklet)
    all_segments.extend(segments)
# sort the segments by encounter distance
all_segments = sorted(all_segments, key=lambda x: x['encounter_distance'].iloc[0], reverse=True)

# crossings
crossings = [segment['number_of_crossings'].iloc[0] for segment in all_segments]
plt.hist(crossings, bins=range(1,10))
plt.xlabel('Number of crossings')
plt.ylabel('Frequency')
plt.show()

# create video of encounters
import skvideo.io
import os
from tqdm import tqdm

N = 5
# get background image from the video
cap = cv2.VideoCapture(video_file)
for i in range(N):
    segment = all_segments[i]
    x = segment['x'].values
    y = segment['y'].values
    f = segment['Frame'].values
    # set start frame
    start_frame = f[0]
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    writer = skvideo.io.FFmpegWriter('false_encounter_{:d}.mp4'.format(i))
    for j in tqdm(range(len(x))):
        ret, frame = cap.read()
        if not ret:
            break
        # draw on image using opencv
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        cv2.circle(frame, (int(x[j]), int(y[j])), 5, (255,0,0), -1)
        for k in range(j):
            cv2.circle(frame, (int(x[k]), int(y[k])), 1, (0,0,255), -1)
        # draw the trail
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius+0.05*sf), (0,0,0), 1)
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius-0.05*sf), (0,0,0), 1)

        writer.writeFrame(frame)
    writer.close()
cap.release()

NameError: name 'get_tracklets' is not defined

In [None]:
times = []
distances = []
for segment in all_segments:
    times.append(segment['Frame'][segment['time_from_crossing'] == 0].values[0])
    distances.append(np.max(np.abs(segment['angular_distance_crossing']).values[segment['time_from_crossing'].values>0])/sf)
distances = np.array(distances)
times = np.array(times)
plt.figure()
plt.hist(times, bins=np.arange(0, np.max(times), 30*FPS), color='gray')
plt.hist(times[distances > 2], bins=np.arange(0, np.max(times), 30*FPS), color='red')
plt.xlabel('Frame number')
plt.ylabel('Frequency')
plt.show()

In [73]:
video_file = '20hr-wingless-orco/20hr-wingless-orco_phase_3.mp4'
csv_file = '20hr-wingless-orco/20hr-wingless-orco-phase_3.csv'

# load tracklets
data = pd.read_csv(csv_file, header=None)
tracklets = get_tracklets(data, center, radius)
all_segments = []
for tracklet in tracklets:
    segments = get_segments(tracklet)
    all_segments.extend(segments)
# sort the segments by encounter distance
all_segments = sorted(all_segments, key=lambda x: x['number_of_crossings'].iloc[0], reverse=True)

# create video of encounters
import skvideo.io
import os
from tqdm import tqdm

N = 5
# get background image from the video
cap = cv2.VideoCapture(video_file)
for i in range(N):
    segment = all_segments[i]
    x = segment['x'].values
    y = segment['y'].values
    f = segment['Frame'].values
    # set start frame
    start_frame = f[0]
    cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
    writer = skvideo.io.FFmpegWriter('close_encounter_{:d}.mp4'.format(i))
    for j in tqdm(range(len(x))):
        ret, frame = cap.read()
        if not ret:
            break
        # draw on image using opencv
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        cv2.circle(frame, (int(x[j]), int(y[j])), 5, (255,0,0), -1)
        for k in range(j):
            cv2.circle(frame, (int(x[k]), int(y[k])), 1, (0,0,255), -1)
        # draw the trail
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius+0.05*sf), (0,0,0), 1)
        cv2.circle(frame, (int(center[0]), int(center[1])), int(radius-0.05*sf), (0,0,0), 1)

        writer.writeFrame(frame)
    writer.close()
cap.release()

100%|██████████| 248/248 [00:03<00:00, 64.99it/s]
100%|██████████| 158/158 [00:02<00:00, 63.86it/s]
100%|██████████| 68/68 [00:00<00:00, 81.57it/s] 
100%|██████████| 215/215 [00:03<00:00, 69.26it/s]
100%|██████████| 171/171 [00:02<00:00, 61.90it/s]


In [None]:
video_file = '20hr-wingless-orco/20hr-wingless-orco_phase_1.mp4'
csv_file = '20hr-wingless-orco/20hr-wingless-orco-phase_1.csv'

# load tracklets
data = pd.read_csv(csv_file, header=None)
tracklets = get_tracklets(data, center, radius)
all_segments = []
for tracklet in tracklets:
    segments = get_segments(tracklet)
    all_segments.extend(segments)
# sort the segments by encounter distance
all_segments = sorted(all_segments, key=lambda x: x['encounter_distance'].iloc[0], reverse=True)

In [47]:
# get every 6th column of data
data = pd.read_csv(file, header=None)

n_flies = 4
n_window = 5 # events
max_dist = 50 # pixel
max_time = 10*60*10 # 10 minutes * 60 seconds * 10 frames/second

# get tracklets
ncols = data.shape[1]
assert ncols % 6 == 0, 'Error: ncols is not a multiple of 6, check Ctrax output'
ntracks =   ncols // 6
tracklets = []
for i in range(ntracks):
    track = data.iloc[:, i*6:i*6+6].copy().reset_index()
    track.columns = ['Frame','ID','x','y','body_length','body_width','heading']
    track.replace(-1, np.nan, inplace=True)
    track.dropna(inplace=True)
    tracklets.append(track)

IDs = data.iloc[:, 0::6].values + 1
IDs = np.vstack([np.zeros(IDs.shape[1]), IDs]) # add a row of zeros at the beginning
births = np.diff(IDs, axis=0) > 0
deaths = np.diff(IDs, axis=0) < 0
Xs = data.iloc[:, 1::6].values
Ys = data.iloc[:, 2::6].values


In [50]:
def euclidean_distance(x1, y1, x2, y2):
    return np.sqrt((x1 - x2)**2 + (y1 - y2)**2)

clean_tracklets = []

i = -1
current_size = len(tracklets)

while len(tracklets) > n_flies:
    i += 1 % len(tracklets)
    # if we have gone through all the tracklets, break
    if i == 0:
        if current_size == len(tracklets):
            break
        current_size = len(tracklets)
    track = tracklets[i]
    while True:
        last_frame = track['Frame'].values[-1]
        # get the next few births
        next_births = np.where(births.sum(axis=1) > 0)[0]
        next_births = next_births[np.logical_and(next_births > last_frame, next_births < last_frame + max_time)][:n_window]
        print('Next births:', next_births)
        # if there are no more births, break
        if len(next_births) == 0:
            break
        # get tracklets that are born in the next few frames
        birth_ids = np.where(births[next_births, :])[1]
        # if multiple birth ids per frame, stop
        if len(birth_ids) > len(next_births):
            break
        print('Birth IDs:', birth_ids)
        distances = []
        for i, birth_id in enumerate(birth_ids):
            x = Xs[next_births[i], birth_id]
            y = Ys[next_births[i], birth_id]
            distance = euclidean_distance(x, y, track['x'].values[-1], track['y'].values[-1])
            distances.append(distance)
        # filter out the ones that are too far
        distances = np.array(distances)
        birth_ids = birth_ids[distances < max_dist]
        # interpolate  with the first tracklet
        X = 0
        success = False
        for X in range(len(birth_ids)):
            track2 = tracklets[birth_ids[X]]
            if len(track2) > 0:
                success = True
                break
        if not success:
            break
        n_frames = next_births[0] - last_frame
        if n_frames != 0:
            print('Interpolating', n_frames, 'frames')
            x = np.linspace(track['x'].values[-1], track2['x'].values[0], n_frames)
            y = np.linspace(track['y'].values[-1], track2['y'].values[0], n_frames)
            body_length = np.linspace(track['body_length'].values[-1], track2['body_length'].values[0], n_frames)
            body_width = np.linspace(track['body_width'].values[-1], track2['body_width'].values[0], n_frames)
            heading = np.linspace(track['heading'].values[-1], track2['heading'].values[0], n_frames)
            frames = np.arange(last_frame+1, next_births[0]+1)
            temp = pd.DataFrame({'Frame':frames, 'x':x, 'y':y, 'body_length':body_length, 'body_width':body_width, 'heading':heading})
            # merge the two tracklets along with the interpolated frames
            track = pd.concat([track, temp, track2], ignore_index=True)
        else:
            track = pd.concat([track, track2], ignore_index=True)
        # remove the tracklet that was merged
        tracklets.pop(birth_ids[0])
        # update births and deaths
        births[next_births[0], birth_ids[0]] = False
        deaths[next_births[0], birth_ids[0]] = False
        # update IDs
        IDs[next_births[0], birth_ids[0]] = 0
        # update Xs and Ys
        Xs[next_births[0], birth_ids[0]] = np.nan
        Ys[next_births[0], birth_ids[0]] = np.nan
    clean_tracklets.append(track)
    # plot the merged tracklet
    plt.plot(track['x'], track['y'], '-',alpha=0.5,linewidth=0.5)
    

Next births: [5092 6082 7236 9683 9926]
Birth IDs: [ 7  8  9 10 11]
Interpolating 3 frames
Next births: [10040 11469 12995 13064 13446]
Birth IDs: [12 13 14 15 16]
Interpolating 3 frames
Next births: [15515 16367 16528 18301 19896]
Birth IDs: [17 18 19 20 21]
Interpolating 3 frames
Next births: [19896 20587 20774 21397 22282]
Birth IDs: [21 22 23 24 25]
Interpolating 2 frames
Next births: [22282 24179 25045 26008 27106]
Birth IDs: [25 26 27 28 29]
Interpolating 5 frames
Next births: [30724 31721 32102 32923 33415]
Birth IDs: [32 33 34 35 36]
Interpolating 2 frames
Next births: [34843 35002 35143 35958 35960]
Birth IDs: [38 39 40 58 66]
Interpolating 3 frames
Next births: []
Next births: [2478 6082 7236]
Birth IDs: [6 8 9]
Interpolating 3 frames
Next births: [ 9926 11469 12995 13064 13446]
Birth IDs: [11 13 14 15 16]
Interpolating 4 frames
Next births: [16367 16528 18301 20587 20774]
Birth IDs: [18 19 20 22 23]
Interpolating 2 frames
Next births: []
Next births: [6082 7236]
Birth IDs: [

IndexError: list index out of range

In [51]:
len(clean_tracklets)

82

In [40]:
plt.plot(np.sum(IDs>0,axis=1))

[<matplotlib.lines.Line2D at 0x31271ca90>]