In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
from skg import nsphere_fit
import matplotlib.collections as mcoll
import matplotlib.path as mpath
import scipy.io as sio
import json
from tqdm import tqdm
import os


def colorline(
    x, y, z=None, cmap=plt.get_cmap('copper'), norm=plt.Normalize(0.0, 1.0),
        linewidth=3, alpha=1.0):
    """
    http://nbviewer.ipython.org/github/dpsanders/matplotlib-examples/blob/master/colorline.ipynb
    http://matplotlib.org/examples/pylab_examples/multicolored_line.html
    Plot a colored line with coordinates x and y
    Optionally specify colors in the array z
    Optionally specify a colormap, a norm function and a line width
    """

    # Default colors equally spaced on [0,1]:
    if z is None:
        z = np.linspace(0.0, 1.0, len(x))

    # Special case if a single number:
    if not hasattr(z, "__iter__"):  # to check for numerical input -- this is a hack
        z = np.array([z])

    z = np.asarray(z)

    segments = make_segments(x, y)
    lc = mcoll.LineCollection(segments, array=z, cmap=cmap, norm=norm,
                              linewidth=linewidth, alpha=alpha)

    ax = plt.gca()
    ax.add_collection(lc)

    return lc


def make_segments(x, y):
    """
    Create list of line segments from x and y coordinates, in the correct format
    for LineCollection: an array of the form numlines x (points per line) x 2 (x
    and y) array
    """

    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    return segments

def generate_arc(center, radius, start_angle, end_angle, num_points):
    angles = np.linspace(start_angle, end_angle, num_points)
    x = center[0] + radius * np.cos(angles)
    y = center[1] + radius * np.sin(angles)
    return x, y

# create figure and video directory
if not os.path.exists('figures'):
    os.makedirs('figures')
if not os.path.exists('videos'):
    os.makedirs('videos')

%matplotlib qt

In [2]:
prefix = '20hr-wingless-orcoctrl-tt'
FPS = 10
N_FLIES = 4

In [3]:
# PHASE 1
# load background image from mat file
p1_bg = sio.loadmat(f'{prefix}_phase_1/{prefix}_phase_1-bg.mat')['bg'][0][0].item()[0]
# PHASE 2
# load background image from mat file
p2_bg = sio.loadmat(f'{prefix}_phase_2/{prefix}_phase_2-bg.mat')['bg'][0][0].item()[0]
# PHASE 3
# load background image from mat file
p3_bg = sio.loadmat(f'{prefix}_phase_3/{prefix}_phase_3-bg.mat')['bg'][0][0].item()[0]

In [4]:
recalculate = False
if os.path.exists('arena.json') and not recalculate:
    
    with open('arena.json', 'r') as f:
        arena = json.load(f)

    center = np.array(arena['center'])
    radius = arena['radius']
    sf = arena['sf']

else:
    # get arena margins
    plt.figure()
    plt.imshow(p1_bg, cmap='gray')
    plt.title('Select 5 points on the margin of the arena')
    pts = plt.ginput(5)
    pts = np.array(pts)
    plt.close()

    # TRUE RADIUS
    true_radius = 75 # mm

    radius, center = nsphere_fit(pts)
    print('Center:', center)
    print('Radius:', radius)
    sf = radius / true_radius

    # save the properties of the arena
    arena = {
        'center': center.tolist(),
        'radius': radius,
        'sf': sf
    }

    with open('arena.json', 'w') as f:
        json.dump(arena, f)



In [5]:
def get_ring(ring_pts):
    outer_ring = ring_pts[:5]
    inner_ring = ring_pts[5:]

    # fit a circle to the points
    outer_radius, outer_center = nsphere_fit(outer_ring)
    inner_radius, inner_center = nsphere_fit(inner_ring)

    # get the average center
    center = (outer_center + inner_center) / 2

    # get the average radius
    radius = (outer_radius + inner_radius) / 2

    # get the average trail width
    trail_width = outer_radius - inner_radius

    # ring properties
    ring_props = {
        'center': center,
        'radius': radius,
        'trail_width': trail_width,
    }
    return ring_props

def draw_ring(ring_props, bg=p2_bg, axis=None, color='r'):
    # draw the estimated circles
    if axis is None:
        fig, ax = plt.subplots()
    else:
        ax = axis

    ax.imshow(bg, cmap='gray')

    circ_inner = plt.Circle(ring_props['center'], ring_props['radius']-ring_props['trail_width']/2, color=color, fill=False)
    circ_outer = plt.Circle(ring_props['center'], ring_props['radius']+ring_props['trail_width']/2, color=color, fill=False)
    ax.add_artist(circ_inner)
    ax.add_artist(circ_outer)

    if axis is None:
        ax.set_aspect('equal')
        plt.title('Estimated ring')
        plt.show()

def ring_coordinates(pos,y_props):
    # get the angle of the point
    position = np.arctan2(pos[1]-y_props['center'][1], pos[0]-y_props['center'][0])
    position = position % (2*np.pi)
    # get the distance from the center
    dist = np.sqrt((pos[0]-y_props['center'][0])**2 + (pos[1]-y_props['center'][1])**2) - y_props['radius']
    return position, dist


In [6]:
recalculate = False

# check if previous ring props are available
try:
    if recalculate:
        raise FileNotFoundError
    
    with open('big_ring_props.json', 'r') as f:
        big_ring_props = json.load(f)
        big_ring_props = {k: np.array(v) if isinstance(v, list) else v for k, v in big_ring_props.items()}

except FileNotFoundError:
    # get ring points
    plt.figure()
    plt.imshow(p1_bg, cmap='gray')
    plt.title('Pick 5 points on the outer edge of the big circle and 5 points on the inner edge of the big circle')
    big_ring_pts = plt.ginput(10)
    plt.close()

    big_ring_props = get_ring(big_ring_pts)

    # save ring props as json
    with open('big_ring_props.json', 'w') as f:
        # convert numpy arrays to lists
        big_ring_props = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in big_ring_props.items()}
        json.dump(big_ring_props, f)

    draw_ring(big_ring_props)

# check if previous ring props are available
try:
    if recalculate:
        raise FileNotFoundError
    
    with open('small_ring_props.json', 'r') as f:
        small_ring_props = json.load(f)
        small_ring_props = {k: np.array(v) if isinstance(v, list) else v for k, v in small_ring_props.items()}
except FileNotFoundError:
    # get ring points
    plt.figure()
    plt.imshow(p3_bg, cmap='gray')
    plt.title('Pick 5 points on the outer edge of the small circle and 5 points on the inner edge of the small circle')
    small_ring_pts = plt.ginput(10)
    plt.close()

    small_ring_props = get_ring(small_ring_pts)

    # save ring props as json
    with open('small_ring_props.json', 'w') as f:
        # convert numpy arrays to lists
        small_ring_props = {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in small_ring_props.items()}
        json.dump(small_ring_props, f)

    draw_ring(small_ring_props)

In [7]:
max_displacement = 50 # mm/s 
max_displacement = max_displacement * sf / FPS

#### JUMP TEST CODE

In [8]:
# get x and y 
for i in range(N_FLIES):
    df = pd.read_csv(f'{prefix}_phase_1/{prefix}_phase_1-trackfeat.csv/fly{i+1}.csv')
    
    x = df['pos x'].values
    y = df['pos y'].values
    # interpolate missing values
    x = np.interp(np.arange(len(x)), np.arange(len(x))[~np.isnan(x)], x[~np.isnan(x)])
    y = np.interp(np.arange(len(y)), np.arange(len(y))[~np.isnan(y)], y[~np.isnan(y)])

    # get difference between consecutive frames
    dx = np.diff(x)
    dy = np.diff(y)
    d = np.sqrt(dx**2 + dy**2)

    plt.plot(x, y, 'k-', lw=0.2)
    # mark every point where d > max_displacement
    points = np.where(d > max_displacement)[0]
    # draw a line between consecutive points
    for i in points:
        plt.plot(x[i:i+2], y[i:i+2], 'r-', lw=1)
plt.gca().set_aspect('equal')
plt.show()

### TRACKLET CONVERSION

In [9]:
# convert to tracklets
for phase in [1,2,3]:
    tracklets = []
    for i in range(N_FLIES):
        df = pd.read_csv(f'{prefix}_phase_{phase}/{prefix}_phase_{phase}-trackfeat.csv/fly{i+1}.csv')
        x = df['pos x'].values
        y = df['pos y'].values
        ori = df['ori'].values
        frame = df.index.values

        # interpolate missing values
        x = np.interp(np.arange(len(x)), np.arange(len(x))[~np.isnan(x)], x[~np.isnan(x)])
        y = np.interp(np.arange(len(y)), np.arange(len(y))[~np.isnan(y)], y[~np.isnan(y)])
        ori = np.interp(np.arange(len(ori)), np.arange(len(ori))[~np.isnan(ori)], ori[~np.isnan(ori)])
        
        # get difference between consecutive frames
        d = np.sqrt(np.diff(x)**2 + np.diff(y)**2)
        # get points where d > max_displacement
        points = np.where(d > max_displacement)[0]
        start = 0
        for i in points:
            t = pd.DataFrame({'pos x': x[start:i+1], 'pos y': y[start:i+1], 'ori': ori[start:i+1], 'frame': frame[start:i+1]})
            tracklets.append(t)
            start = i+1
        t = pd.DataFrame({'pos x': x[start:], 'pos y': y[start:], 'ori': ori[start:], 'frame': frame[start:]})
        tracklets.append(t)

    # filter out tracklets with less than 5*FPS frames
    tracklets = [t for t in tracklets if len(t) >= 5*FPS]

    # add ring and yang coordinates to each tracklet
    i = 0
    for t in tqdm(tracklets):
        t['big ring pos'], t['big ring dist'] = zip(*[ring_coordinates([x, y], big_ring_props) for x, y in zip(t['pos x'], t['pos y'])])
        t['small ring pos'], t['small ring dist'] = zip(*[ring_coordinates([x, y], small_ring_props) for x, y in zip(t['pos x'], t['pos y'])])
        t['track id'] = i
        i += 1

    # combine and save tracklets
    tracklets = pd.concat(tracklets)
    tracklets.to_csv(f'phase_{phase}_tracklets.csv', index=False)



100%|██████████| 187/187 [00:03<00:00, 55.61it/s]
100%|██████████| 31/31 [00:01<00:00, 17.43it/s]
100%|██████████| 408/408 [00:03<00:00, 117.31it/s]


In [10]:
# plot all tracklets in phases
for phase in [1,2,3]:
    tracklets = pd.read_csv(f'phase_{phase}_tracklets.csv')
    fig, ax = plt.subplots()
    for i in range(tracklets['track id'].max()):
        t = tracklets[tracklets['track id'] == i]
        plt.plot(t['pos x'], t['pos y'], 'w-', lw=0.2)
    if phase == 1:
        draw_ring(big_ring_props, axis=ax, color='r', bg=p1_bg)
    elif phase == 2:
        draw_ring(big_ring_props, axis=ax, color='k', bg=p2_bg)
    else:
        draw_ring(small_ring_props, axis=ax, color='r', bg=p3_bg)
    plt.gca().set_aspect('equal')
    plt.title(f'Phase {phase}')
    # save figure
    plt.savefig(f'figures/phase_{phase}_tracklets.png')
    plt.show()

In [11]:
# plot all start and end points of tracklets
for phase in [1,2,3]:
    tracklets = pd.read_csv(f'phase_{phase}_tracklets.csv')
    fig, ax = plt.subplots(1, 2)
    # draw rings
    if phase == 1:
        draw_ring(big_ring_props, axis=ax[0], color='r', bg=p1_bg)
        draw_ring(big_ring_props, axis=ax[1], color='r', bg=p1_bg)
    elif phase == 2:
        draw_ring(big_ring_props, axis=ax[0], color='k', bg=p2_bg)
        draw_ring(big_ring_props, axis=ax[1], color='k', bg=p2_bg)
    else:
        draw_ring(small_ring_props, axis=ax[0], color='r', bg=p3_bg)
        draw_ring(small_ring_props, axis=ax[1], color='r', bg=p3_bg)
    for i in range(tracklets['track id'].max()):
        t = tracklets[tracklets['track id'] == i]
        ax[0].scatter(t.iloc[0]['pos x'], t.iloc[0]['pos y'], color=plt.cm.rainbow(i/tracklets['track id'].max()), s=1)
        ax[1].scatter(t.iloc[-1]['pos x'], t.iloc[-1]['pos y'], color=plt.cm.rainbow(i/tracklets['track id'].max()), s=1)
    ax[0].set_title('Start points')
    ax[1].set_title('End points')
    ax[0].set_aspect('equal')
    ax[1].set_aspect('equal')
    plt.suptitle(f'Phase {phase}')
    plt.tight_layout()
    # save figure
    plt.savefig(f'figures/phase_{phase}_start_end.png')
    plt.show()


In [12]:
encounter_distance = 5*2 # mm body width
encounter_distance = encounter_distance * sf # scale to pixels
encounter_distance = encounter_distance / 2 # convert distance
print(f'Encounter distance: {encounter_distance} pixels = {encounter_distance/sf:.2f} mm = {encounter_distance/sf*0.0393701:.2f} inches')

trail_width = big_ring_props['trail_width']/2
print(f'BIG Trail width: {trail_width} pixels = {trail_width/sf:.2f} mm = {trail_width/sf*0.0393701:.2f} inches')

trail_width = small_ring_props['trail_width']/2
print(f'SMALL Trail width: {trail_width} pixels = {trail_width/sf:.2f} mm = {trail_width/sf*0.0393701:.2f} inches')

Encounter distance: 46.950959124984905 pixels = 5.00 mm = 0.20 inches
BIG Trail width: 7.986138123212868 pixels = 0.85 mm = 0.03 inches
SMALL Trail width: 6.067263341391566 pixels = 0.65 mm = 0.03 inches


In [13]:
# plot every point within the target distance of the rings
fig, ax = plt.subplots()
plt.imshow(p2_bg, cmap='gray')

# create a meshgrid
res = 50
x = np.linspace(0, p2_bg.shape[1], res)
y = np.linspace(0, p2_bg.shape[0], res)
X, Y = np.meshgrid(x, y)
X = X.flatten()
Y = Y.flatten()
points = np.vstack([X, Y]).T

for point in tqdm(points):
    pos, dist = ring_coordinates(point, big_ring_props)
    if abs(dist) < encounter_distance + trail_width:
        plt.scatter(point[0], point[1], color=plt.cm.cool(pos/(2*np.pi)), s=1)


for point in tqdm(points):
    pos, dist = ring_coordinates(point, small_ring_props)
    if abs(dist) < 20:
        plt.scatter(point[0], point[1], color=plt.cm.cool(pos/(2*np.pi)), s=1)

plt.title('Points within the target distance of the rings')

draw_ring(big_ring_props, axis=ax, color='k')
draw_ring(small_ring_props, axis=ax, color='k')

# add colorbar
sm = plt.cm.ScalarMappable(cmap=plt.cm.cool, norm=plt.Normalize(0, 2*np.pi))
sm.set_array([])
plt.colorbar(sm, ticks=[0, np.pi, 2*np.pi], label='Angle', orientation='horizontal', ax=ax)
# save figure
plt.savefig('figures/targeted_points.png')
plt.show()



100%|██████████| 2500/2500 [00:00<00:00, 5134.83it/s]
100%|██████████| 2500/2500 [00:00<00:00, 18441.46it/s]


#### TEST TRACKLET

In [14]:
tracklets = pd.read_csv('phase_2_tracklets.csv')
tracklets = [group for _, group in tracklets.groupby('track id')]

tracklet = tracklets[0]
ring = 'big'
plt.figure()
plt.plot(tracklet['big ring dist'].values, np.unwrap(tracklet['big ring pos'].values), 'k-', lw=0.2)
plt.show()

val = np.abs(np.concatenate([[np.inf], tracklet[ring+' ring dist'].values])) < encounter_distance + trail_width
count = np.cumsum(np.concatenate([np.array([np.nan]),np.diff(val.astype(int))])>0)
mes = val*count
plt.figure()
plt.plot(tracklet['pos x'], tracklet['pos y'], 'k-', lw=0.2)
draw_ring(big_ring_props, color='k', axis=plt.gca())
# draw encounters with scatter
plt.scatter(tracklet['pos x'][mes[1:]>0], tracklet['pos y'][mes[1:]>0], c=mes[mes>0], cmap='viridis')
plt.gca().set_aspect('equal')
plt.show()

In [23]:
def get_crossings(encounter, ring, min_deviation):
    crossings = np.abs(np.diff(np.sign(encounter[ring+' ring dist'])))
    # distance traveled between crossings must be atleast the trail width
    encounters_between_crossings = []
    crossing_points = np.where(crossings > 0)[0]
    if len(crossing_points) == 0:
        return 0, []
    encounters_between_crossings.append(encounter[:crossing_points[0]])
    for i in range(len(crossing_points)-1):
        encounters_between_crossings.append(encounter[crossing_points[i]:crossing_points[i+1]])
    # get the maximum distance travelled away from the center
    furthest_distance = np.array([np.max(np.abs(encounter[ring+' ring dist'])) for encounter in encounters_between_crossings])
    # get the crossing points where the distance travelled is atleast the trail width
    crossing_points = np.array(crossing_points)[np.where(furthest_distance > min_deviation)[0]]
    n_crossings = len(crossing_points)
    return n_crossings, crossing_points


N_CROSSINGS_FOR_ENCOUNTER = 1
# create a subfolder for N_CROSSINGS_FOR_ENCOUNTER
if not os.path.exists(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}'):
    os.makedirs(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}')
if not os.path.exists(f'videos/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}'):
    os.makedirs(f'videos/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}')
if not os.path.exists(f'Crossings_{N_CROSSINGS_FOR_ENCOUNTER}'):
    os.makedirs(f'Crossings_{N_CROSSINGS_FOR_ENCOUNTER}')

In [24]:

# get encounters in each phase
for phase in [1,2,3]:
    tracklets = pd.read_csv(f'phase_{phase}_tracklets.csv')
    tracklets = [group for _, group in tracklets.groupby('track id')]

    if phase == 1:
        props = big_ring_props
        ring = 'big'
        trail_width = big_ring_props['trail_width']/2
    elif phase == 2:
        props = big_ring_props
        ring = 'big'
        trail_width = big_ring_props['trail_width']/2
    else:
        props = small_ring_props
        ring = 'small'
        trail_width = small_ring_props['trail_width']/2
    
    print(f'Phase {phase}: Encounter distance: {encounter_distance} pixels, Trail width: {trail_width} pixels')
    # get encounters
    encounters = []
    for tracklet in tracklets:
        val = np.abs(np.concatenate([[np.inf], tracklet[ring+' ring dist'].values])) < encounter_distance + trail_width
        count = np.cumsum(np.concatenate([np.array([np.nan]),np.diff(val.astype(int))])>0)
        mes = val*count
        
        # get every continuous segment where the fly is within encounter_distance of the ring circle
        encounters += [group for m, group in tracklet.groupby(mes[1:]) if m != 0]
    # filter for encounters with atleast 1 sign change in the ring position
    n_crossings= [get_crossings(encounter, ring, trail_width)[0] for encounter in encounters]
    mask = [n_crossings[i] >= N_CROSSINGS_FOR_ENCOUNTER for i in range(len(encounters))]
    encounters = [encounter for i,encounter in enumerate(encounters) if mask[i]]
    n_crossings = [n_crossings[i] for i in range(len(n_crossings)) if mask[i]]

    print(f'Phase {phase}: {len(encounters)} encounters')
    
    # plot all encounters
    fig, ax = plt.subplots()
    for encounter in encounters:
        plt.plot(encounter['pos x'], encounter['pos y'], '-', lw=0.5, color=plt.cm.cool(np.random.rand()))
    if phase == 1:
        draw_ring(big_ring_props, axis=ax, color='r', bg=p1_bg)
    elif phase == 2:
        draw_ring(big_ring_props, axis=ax, color='k', bg=p2_bg)
    else:
        draw_ring(small_ring_props, axis=ax, color='r', bg=p3_bg)
    plt.gca().set_aspect('equal')
    plt.title(f'Phase {phase}')
    # save figure
    plt.savefig(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/phase_{phase}_encounters.png')
    plt.show()

    # plot in ring space
    fig, ax = plt.subplots()

    max_distances = []
    for encounter in encounters:
        position = (np.unwrap(encounter[ring+' ring pos'])*props['radius']/sf)
        # find point of first crossing
        crossing = np.argmax(np.abs(np.diff(np.sign(encounter[ring+' ring dist'].values))))
        position = position - position[crossing]
        dists = [np.abs(np.max(position[crossing:])),np.abs(np.min(position[crossing:]))]
        if np.argmax(dists) == 1:
            position = -position
        distance = encounter[ring+' ring dist']/sf
        max_distances.append(np.max(dists))
        ax.plot(distance, position, 'r-', lw=1, alpha=0.5)
        # plot the trail width
    ax.axvline(-trail_width/sf, color='k', alpha=0.5, ls='--')
    ax.axvline(trail_width/sf, color='k', alpha=0.5, ls='--')


    ax.set_xlabel('Distance from ring center (mm)')
    ax.set_ylabel('Position on ring (mm)')
    ax.set_title(f'Phase {phase}')
    ax.set_aspect('equal')
    # save figure
    plt.savefig(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/phase_{phase}_encounters_unwrapped.png')
    plt.show()

    # sort by max distance
    encounters = [encounters[i].reset_index() for i in np.argsort(max_distances)]
    n_crossings = [n_crossings[i] for i in np.argsort(max_distances)]
    max_distances = np.sort(max_distances)

    # # sort by number of crossings
    # encounters = [encounters[i] for i in np.argsort(n_crossings)]
    # max_distances = [max_distances[i] for i in np.argsort(n_crossings)]
    # n_crossings = np.sort(n_crossings)

    # save ring encounters after adding an id
    for i, encounter in enumerate(encounters):
        encounter['id'] = i
        encounter['max_distance'] = max_distances[i]
        encounter['n_crossings'] = n_crossings[i]
    ring_encounters_CSV = pd.concat(encounters) if len(encounters) > 0 else pd.DataFrame(columns=['id', 'max_distance', 'n_crossings'])
    ring_encounters_CSV.to_csv(f'Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/encounters_phase_{phase}.csv', index=False)

    # plot the distribution of max distances
    plt.figure()
    plt.hist(max_distances, bins=20, density=True)
    plt.xlabel('Max distance (mm)')
    plt.ylabel('Frequency')
    plt.title(f'Phase {phase}')
    # save figure
    plt.savefig(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/phase_{phase}_max_distances.png')
    plt.show()
        

Phase 1: Encounter distance: 46.950959124984905 pixels, Trail width: 7.986138123212868 pixels
Phase 1: 108 encounters
Phase 2: Encounter distance: 46.950959124984905 pixels, Trail width: 7.986138123212868 pixels
Phase 2: 33 encounters
Phase 3: Encounter distance: 46.950959124984905 pixels, Trail width: 6.067263341391566 pixels
Phase 3: 74 encounters


In [25]:
for phase in [1,2,3]:
    if phase == 1:
        props = big_ring_props
        ring = 'big'
        trail_width = big_ring_props['trail_width']/2
    elif phase == 2:
        props = big_ring_props
        ring = 'big'
        trail_width = big_ring_props['trail_width']/2
    else:
        props = small_ring_props
        ring = 'small'
        trail_width = small_ring_props['trail_width']/2

    # get ring encounters for phase 1
    encounters = pd.read_csv(f'Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/encounters_phase_{phase}.csv')
    if len(encounters) == 0:
        print(f'No encounters in phase {phase}')
        continue
    encounters = [group.reset_index(drop=True) for _, group in encounters.groupby('id')]
    max_distances = [encounter['max_distance'].values[0] for encounter in encounters]
    n_crossings = [encounter['n_crossings'].values[0] for encounter in encounters]

    # plot the top 10 in different plots (ring space)
    fig, ax = plt.subplots(2, 5, sharex=True, sharey=True)
    for i, encounter in enumerate(encounters[-10:][::-1]):
        position = (np.unwrap(encounter[ring+' ring pos'])*props['radius']/sf)
        # find point of first crossing
        _, crossings = get_crossings(encounter, ring, trail_width)
        crossing = crossings[0]
        position = position - position[crossing]
        dists = [np.abs(np.max(position[crossing:])),np.abs(np.min(position[crossing:]))]
        if np.argmax(dists) == 1:
            position = -position
        distance = encounter[ring+' ring dist']/sf
        ax.flatten()[i].plot(distance, position, 'r-', lw=1, alpha=0.5)
        ax.flatten()[i].set_title(f'Encounter {i+1}')
        ax.flatten()[i].set_aspect('equal')
        # draw the trail width
        ax.flatten()[i].axvline(-trail_width/sf, color='k', ls='--', lw=0.5)
        ax.flatten()[i].axvline(trail_width/sf, color='k', ls='--', lw=0.5)
    # hide axis with no plot
    for i in range(10-len(encounters)):
        ax.flatten()[-i-1].axis('off')
    plt.suptitle(f'Top 10 encounters in phase {phase}')
    plt.tight_layout()
    plt.savefig(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/top_10_encounters_phase_{phase}_unwrapped.png')
    plt.show()

    # plot the top 10 in different plots (xy space)
    fig, ax = plt.subplots(2, 5, sharex=True, sharey=True)
    for i, encounter in enumerate(encounters[-10:][::-1]):
        ax.flatten()[i].plot(encounter['pos x'], encounter['pos y'], 'w-', lw=0.5)
        ax.flatten()[i].set_title(f'Encounter {i+1}')
        ax.flatten()[i].set_aspect('equal')
        draw_ring(props, color='k', axis=ax.flatten()[i])
    # hide axis with no plot
    for i in range(10-len(encounters)):
        ax.flatten()[-i-1].axis('off')
    plt.suptitle(f'Top 10 encounters in phase {phase}')
    plt.tight_layout()
    plt.savefig(f'figures/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/top_10_encounters_phase_{phase}.png')
    plt.show()


In [26]:
# create a joint video of all top 10 encounters
import skvideo.io
import os
from tqdm import tqdm

for phase in [1,2,3]:

    if phase == 1:
        props = big_ring_props
        ring = 'big'
        trail_width = big_ring_props['trail_width']/2
        bg = p1_bg
    elif phase == 2:
        props = big_ring_props
        ring = 'big'
        trail_width = big_ring_props['trail_width']/2
        bg = p2_bg
    else:
        props = small_ring_props
        ring = 'small'
        trail_width = small_ring_props['trail_width']/2
        bg = p3_bg

    encounters = pd.read_csv(f'Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/encounters_phase_{phase}.csv')
    if len(encounters) == 0:
        print(f'No encounters in phase {phase}')
        continue
    encounters = [group.reset_index(drop=True) for _, group in encounters.groupby('id')]
    # reverse the order
    encounters = encounters[::-1]

    xs = [encounter['pos x'].values for encounter in encounters]
    ys = [encounter['pos y'].values for encounter in encounters]
    on_trail = [encounter[ring+' ring dist'].values<trail_width for encounter in encounters]

    # get background image
    bg = cv2.cvtColor((bg*255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
    width, height = bg.shape[1], bg.shape[0]


    writer = skvideo.io.FFmpegWriter(f'videos/Crossings_{N_CROSSINGS_FOR_ENCOUNTER}/top_10_encounters_phase_{phase}.mp4', outputdict={'-r': str(FPS)})
    for j in tqdm(range(10)):
        if j >= len(encounters):
            break
        for i in range(len(encounters[j])):
            frame = bg.copy()
            if on_trail[j][i]:
                color = (255, 0, 0)
            else:
                color = (255, 255, 255)
            # draw a trail
            for k in range(i):
                if on_trail[j][k]:
                    cv2.line(frame, (int(xs[j][k]), int(ys[j][k])), (int(xs[j][k+1]), int(ys[j][k+1])), (255, 0, 0), 1)
                else:
                    cv2.line(frame, (int(xs[j][k]), int(ys[j][k])), (int(xs[j][k+1]), int(ys[j][k+1])), (255, 255, 255), 1)
            cv2.circle(frame, (int(xs[j][i]), int(ys[j][i])), 5, color, -1)
            # add text
            cv2.putText(frame, f'Encounter {j+1}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1)
            cv2.putText(frame, f'Time: {i/FPS:.2f} s', (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
            writer.writeFrame(frame)
        # wait for 2 second at the end of each encounter
        for _ in range(2*FPS):
            writer.writeFrame(frame)
    writer.close()



100%|██████████| 10/10 [00:22<00:00,  2.25s/it]
100%|██████████| 10/10 [00:49<00:00,  4.93s/it]
100%|██████████| 10/10 [00:08<00:00,  1.15it/s]
