In [None]:
'''
author: Felix Hol
date: 2020 Oct 20
content: code to track mosquitoes; several filtering parameters will need tweaking depending on the imaging parameters.
Output is:
1) a pickle containing all detected centroids per frame
2) a pickle containing the centroids linked into tracks over time (frame by frame)
'''

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import itertools as it
import pandas as pd
import pims
import skimage
from skimage import data, io, util
from skimage.measure import label, regionprops
from skimage.morphology import binary_dilation, erosion, dilation, opening, binary_closing, closing, white_tophat, remove_small_objects, disk, black_tophat, skeletonize, convex_hull_image
import scipy
import trackpy as tp
import pylab
import math
from joblib import Parallel, delayed
import multiprocessing
from datetime import datetime
from tqdm import tnrange, tqdm
import pickle
import glob
import cv2 as cv

In [None]:
#### set directories where to get images and where to store output, and specifics of experiment
#### NOTE(review): these are absolute, machine-specific paths -- adjust them for every machine/experiment

# dataDir = '/run/user/1001/gvfs/smb-share:server=gaia.pasteur.fr,share=%40ivi/BITES_BLOOD_BEHAVIOR/P3/200923_KPPTN/ctrl02/'
dataDir = '/mnt/DATA/biteData/P3/200930_KPPTN/denv04/'   #### folder containing the .png frames of one recording
saveDir = '/home/felix/biteData/P3/200930_processed/'    #### folder where the centroid and track pickles are written
species = 'aeg'   #### stored in the 'species' column of the track table
mosAge = 21      #### mosquito age in days
mosDataName = 'aeg_denv04_0930'   #### prefix of the output pickle file names
frames = pims.ImageSequence(dataDir+'/*.png')   #### all .png files in dataDir, opened as one frame sequence

In [None]:
borderToExcludeXl = 250           #### excludes the outer n pixels of the left edge of the frame 
borderToExcludeXr = 140           #### excludes the outer n pixels of the right edge of the frame 
borderToExcludeYt = 100           #### excludes the outer n pixels of the top edge of the frame 
borderToExcludeYb = 220           #### excludes the outer n pixels of the bottom edge of the frame 
startFrame = 500              #### first frame to process
stopFrame = len(frames)             #### frame to stop processing (exclusive, used as range(startFrame, stopFrame))
startFrameBG = startFrame    #### first frame sampled for the background image
stopFrameBG = stopFrame     #### end of the background sampling range, exclusive (usually last frame unless immobile mosquitoes)
numBGframes = 20            #### target number of frames for the background image (the sampling step is derived from it where getBG is called)
mThreshold = 70              #### threshold used to create binary image of mosquitoes after BG subtraction
searchRadius = 450            #### used for tracking: maximum displacement (pixels) allowed between linked detections

In [None]:
# create the output folder if needed (os.chdir fails on a missing directory); the output pickles are written to saveDir
os.makedirs(saveDir, exist_ok=True)
os.chdir(saveDir)
len(frames)

In [None]:
#### create background image (BG image can be updated periodically - usually not necessary)

def getBG(start, stop, step, source=None):
    """Median background image of the inverted frames start, start+step, ... (< stop).

    Parameters
    ----------
    start, stop, step : int
        Frame indices to sample, interpreted exactly like ``range(start, stop, step)``.
    source : indexable sequence of 2-D arrays, optional
        Frames to sample from; defaults to the notebook-level ``frames`` sequence.

    Returns
    -------
    numpy.ndarray (float64)
        Per-pixel median of the sampled inverted frames, same height/width as a frame.

    Raises
    ------
    ValueError
        If the range selects no frames (``range`` itself raises it for ``step == 0``).
    """
    seq = frames if source is None else source
    indices = range(start, stop, step)
    if len(indices) == 0:
        raise ValueError('getBG: range(%d, %d, %d) selects no frames' % (start, stop, step))

    # exactly one slice per sampled frame: the previous version allocated one extra slice that
    # stayed all-zero and was included in the median, biasing the background towards 0
    firstShape = np.shape(seq[indices[0]])
    stack = np.zeros((firstShape[0], firstShape[1], len(indices)))
    for k, idx in enumerate(indices):
        stack[:, :, k] = np.invert(seq[idx])
    return np.median(stack, axis=2)

In [None]:
### compute the background image from ~numBGframes frames spread uniformly over startFrameBG - stopFrameBG
### check resulting image to make sure non-moving mosquitoes did not become part of BG image
# the step is based on the length of the sampling window (not on stopFrameBG alone) so that about
# numBGframes frames are used even when startFrameBG > 0; it is at least 1 so short windows still work
bgStep = max(1, (stopFrameBG - startFrameBG) // numBGframes)
BG = getBG(startFrameBG, stopFrameBG, bgStep)
fig, ax = plt.subplots(figsize=(18, 12))
ax.imshow(BG)

In [None]:
''' IF (and only if) there are immobile mosquitoes (=mosquitoes that do not move for the majority of frames) 
that cannot be removed by selecting a different range of the dataset, these can be blurred in the background 
image such that they are included in the tracking dataset (this may not be desirable for dead mosquitoes).

'''

blurX = [1450] #### X coordinate of all mosquitoes to be blurred in BG
blurY = [1150] #### Y coordinate of all mosquitoes to be blurred in BG
blurHW = 100   #### half-width (pixels) of the square that is blurred most strongly
blurEdge = 10  #### the square is then enlarged by blurEdge and 2*blurEdge and blurred more softly
numBlur = 5    #### the inner square is blurred numBlur - 1 times (range(1, numBlur))
sigma = 20     #### Gaussian sigma for the inner square; the enlarged squares use sigma/2 and sigma/4

assert len(blurX) == len(blurY), 'blurX and blurY must list the same number of mosquitoes'


def blurSquare(img, cy, cx, halfWidth, s):
    """Gaussian-blur (sigma ``s``), in place, the square of half-width ``halfWidth`` centred on
    row ``cy`` / column ``cx`` of ``img``. The square is clipped to the image, so centres near
    the border no longer produce negative (wrapping) slice starts."""
    y0, y1 = max(cy - halfWidth, 0), min(cy + halfWidth, img.shape[0])
    x0, x1 = max(cx - halfWidth, 0), min(cx + halfWidth, img.shape[1])
    img[y0:y1, x0:x1] = skimage.filters.gaussian(img[y0:y1, x0:x1], s)


# start from a copy of the background image computed above: the same getBG call used to be
# repeated here (an expensive median); copying also keeps BG itself untouched
BGblur = BG.copy()

for cx, cy in zip(blurX, blurY):
    for _ in range(1, numBlur):
        blurSquare(BGblur, cy, cx, blurHW, sigma)
    # soften the transition between the blurred square and the untouched background
    blurSquare(BGblur, cy, cx, blurHW + blurEdge, sigma / 2)
    blurSquare(BGblur, cy, cx, blurHW + 2 * blurEdge, sigma / 4)

In [None]:
#### ONLY if using locally blurred BG image: inspect the blurred background
fig, ax = plt.subplots(figsize=(18, 12))
ax.imshow(BGblur)

In [None]:
def trackMosq2(i, mThreshold, BG, borderToExcludeXl, borderToExcludeXr, borderToExcludeYt, borderToExcludeYb,
               minArea=800, maxArea=40000, maxMajorAxis=300):
    """Detect mosquito centroids in frame ``i`` of the notebook-level ``frames`` sequence.

    The inverted frame minus the background ``BG`` is thresholded at ``mThreshold``, cleaned up
    morphologically and split into connected regions. Regions that are too large, too small,
    too elongated, or whose centroid lies in an excluded border margin are discarded.

    Parameters
    ----------
    i : int
        Index of the frame to process.
    mThreshold : float
        Threshold applied to the background-subtracted image.
    BG : numpy.ndarray
        Background image (e.g. from ``getBG``), same height/width as a frame.
    borderToExcludeXl, borderToExcludeXr, borderToExcludeYt, borderToExcludeYb : int
        Width in pixels of the left / right / top / bottom margins; detections whose centroid
        falls inside a margin are dropped.
    minArea, maxArea : int, optional
        Regions with an area below ``minArea`` or above ``maxArea`` pixels are dropped.
    maxMajorAxis : float, optional
        Regions whose major-axis length exceeds this value are dropped.

    Returns
    -------
    centroidsF : numpy.ndarray, shape (n, 3), float
        One row ``(y, x, i)`` per detection.
    numCents : int
        Number of detections ``n``.
    """
    frameSize = frames[1].shape
    selem1 = disk(6)   # structuring element for the binary closing
    selem2 = disk(1)   # structuring element for the erosions

    # background subtraction on the inverted frame, in float so negative differences are kept
    A = np.asarray(np.invert(frames[i]), dtype=float)
    B = A - BG
    # NOTE(review): the offset is removed only when every pixel is already positive; confirm
    # this is intended (rather than shifting when B.min() < 0)
    Bm = B - B.min() if B.min() > 0 else B

    # binary mask: threshold, drop small specks, erode, drop what is still small, then close
    Bt = Bm > mThreshold
    Bts = remove_small_objects(Bt, min_size=300)
    Be = erosion(Bts, selem2)
    Bf = remove_small_objects(Be, min_size=200)
    Bc = binary_closing(Bf, selem1)

    # grey-level image restricted to the mask, eroded and smoothed before labelling
    C = B * Bc
    eroded = erosion(C, selem2)
    eroded = skimage.filters.gaussian(eroded, 4)
    eroded[eroded < 0] = 0
    erL = label(eroded > 0)

    # filter out objects that are way too large / small / elongated or near the excluded edges;
    # all rejected labels are cleared in one pass instead of one full-image scan per test
    rejected = [
        props.label
        for props in regionprops(erL, C)
        if props.area > maxArea
        or props.area < minArea
        or props.major_axis_length > maxMajorAxis
        or props.centroid[0] < borderToExcludeYt
        or props.centroid[1] < borderToExcludeXl
        or props.centroid[0] > frameSize[0] - borderToExcludeYb
        or props.centroid[1] > frameSize[1] - borderToExcludeXr
    ]
    erL[np.isin(erL, rejected)] = 0

    # centroids of the surviving regions; reshape keeps a (0, 2) array when nothing was found
    erLf = label(erL > 0)
    centroids = np.array([props.centroid for props in regionprops(erLf, C)], dtype=float).reshape(-1, 2)
    # builtin int: the np.int alias was removed in NumPy 1.24
    frameNoCen = np.full((centroids.shape[0], 1), i, dtype=int)
    centroidsF = np.hstack((centroids, frameNoCen))
    numCents = centroidsF.shape[0]
    return centroidsF, numCents

In [None]:
#### testing: run this cell on a few frames and check that every mosquito gets a red dot
# NOTE(review): this check subtracts BGblur, whereas the full detection run below subtracts BG
i = 5880
centroidsF, numCents = trackMosq2(i, mThreshold, BGblur,
                                  borderToExcludeXl, borderToExcludeXr,
                                  borderToExcludeYt, borderToExcludeYb)
fig, ax = plt.subplots(figsize=(18, 12))
ax.imshow(frames[i], cmap='gray')
ax.plot(centroidsF[:, 1], centroidsF[:, 0], 'r.')

In [None]:
#### detect centroids in every frame (in parallel), save them, then link them into tracks
num_cores = multiprocessing.cpu_count()

# NOTE(review): the detection-check cell uses BGblur; the full run uses BG. Set this to BGblur
# if immobile mosquitoes were blurred out of the background image.
bgForTracking = BG

print('detecting centriods of mosquitoes in frames ' + str(startFrame) + ' - ' + str(stopFrame) + ' using ' + str(num_cores) + ' cores')

results = Parallel(n_jobs=num_cores)(
    delayed(trackMosq2)(i, mThreshold, bgForTracking, borderToExcludeXl, borderToExcludeXr, borderToExcludeYt, borderToExcludeYb)
    for i in tqdm(range(startFrame, stopFrame)))

# stack all per-frame (y, x, frame) arrays at once. The old version seeded the stack with a
# dummy (0, 0, frame 0) row that ended up in the pickle and was fed to the linker, where it
# could be linked to real detections; the empty (0, 3) array only keeps vstack valid.
centroidsAllT = np.vstack([np.zeros((0, 3))] + [res[0] for res in results])

centroidPickleName = os.path.join(saveDir, mosDataName + '_centroids.pkl')

with open(centroidPickleName, 'wb') as f:
    pickle.dump(centroidsAllT, f)

### convert centroids to dataframe (integer frame numbers for trackpy)

df_cenAllT = pd.DataFrame(centroidsAllT, columns=['y', 'x', 'frame']).astype({'frame': int})

### track (link mosquitoes across frames). Adjust search radius depending on how crowded the images are

print('linking centroids through time using search radius: ' + str(searchRadius))

tFull = tp.link_df(df_cenAllT, searchRadius, memory=3)

tFull['species'] = species
tFull['age'] = mosAge

trackPickleName = os.path.join(saveDir, mosDataName + '_tracks.pkl')
tFull.to_pickle(trackPickleName)

print('output saved at ' + saveDir)