In [2]:
#####################################################################################################################
# FUNCTION THAT CONVERTS SLEAP .SLP FILES TO:
#  - *.csv
#  - *_traces_inferences.npy - which contains the x, y location for all the features detected
#  - *_chain_ids.npy - which contains all the track ids of the sleap output for each group of features
# 
# Note 1: Previous function is sleap-track CLI
#  
# Note 2: (Optional) Following notebook is 9_train_cnn_classifier pytorch if a trained CNN is not available
#
# Note 3: Next is 9_predict_using_cnn_classifer
#
#
#


import matplotlib
#matplotlib.use('Agg')
%matplotlib tk
%autosave 180
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import matplotlib.cm as cm
from matplotlib import gridspec

import numpy as np
import os
import shutil
import cv2
from tqdm import trange

#import glob2

import h5py
#import hdf5storage
import csv



Autosaving every 180 seconds


In [19]:

def csv_to_traces(fname):

    ''' Convert csv to the basic 2 files needed to generate traces for all animals
        1. traces_reassembled.npz contains tracesx, tracesy which correspond to the
        n animals selected. tracesx has dimensions [n_features*n_animals, n_frames]
        - For sleap output this is the .csv data reshaped from [n_detections,n_features]
          to [ ..., n_frames]
          filling in the list of animals whenever there is an animal detected in a particular
          frame. 

        2. chain_id.npy which has shape [n_frames, n_animals] and contains the track_id
        for each of the animal features sets in reassembled.
    '''

    import csv
    from tqdm import tqdm
    import numpy as np

    # LOAD .CSV DATA FIRST
    data_tracks = []
    with open(fname) as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=',')
        ctr=0
        for row in tqdm(csv_reader):
            data_tracks.append([])
            for item in row:
                if item=='':
                    item = 1E10
                elif 'track' in item:
                    item = item.replace('track_','')
                data_tracks[ctr].append(item)
            ctr+=1
    data_tracks = np.array(data_tracks[1:])
    #print (data_tracks[0])

    # CONVERT DATA TO FLOAT32 and replace missing data with Nan
    #print (data_tracks.shape)
    data_tracks = np.float32(data_tracks)
    #print (data_tracks)

    idx = np.where(data_tracks==1E10)
    data_tracks[idx]=np.nan
    #print (data_tracks[0])
    #print ("Final tracks data shape: ", data_tracks.shape)
    
    
    #########################################################################
    ###### CONVERT data_tracks to inference and chain.npy files #############
    #########################################################################
    
    n_frames = int(np.max(data_tracks[:,1]))+1
    #print ("example data-tracks: ", data_tracks[240000:240002])
    #print ("Final tracks data shape: ", data_tracks.shape, 
    #       "  # frames: ", n_frames)


    # initialize array to hold track ids: [n_frames, track_id]
    n_animals=4+4
    n_features=14
    traces_reassembled = np.zeros((n_frames,n_animals*n_features,2), 'float32')+np.nan
    #print ("traces reassembled: ", traces_reassembled.shape)

    # initizlied chain ids
    chain_ids = np.zeros((n_frames, n_animals))+np.nan
    #print ("chain_ids: ", chain_ids.shape)

    # placeholder for data
    sleap_locs = np.zeros((n_features, 2), 'float32')

    # loop over all tracks filling in the data in the 
    ctr = 0
    from tqdm import trange
    for k in trange(data_tracks.shape[0]):
        track_id = int(data_tracks[k,0])
        frame_id = int(data_tracks[k,1])
        #print ("index: ", k, " track_id: ", track_id,
        #       "  frame_id: ", frame_id, 
        #       " data: ", data_tracks[k])
        sleap_locs[:,0] = data_tracks[k,2:][::2]  # grab the x vals
        sleap_locs[:,1] = data_tracks[k,2:][1::2] # grab the y vals

        # find the next empty slot and enter the chain id info
        loc = 0
        while True:
            if np.isnan(chain_ids[frame_id,loc]):
                break
            loc+=1

        # insert the track_id in the chain
        chain_ids[frame_id,loc] = track_id

        # insert the feature locations 
        traces_reassembled[frame_id,
                           loc*n_features:(loc+1)*n_features] = sleap_locs

    # save data
    np.save(fname[:-4]+'_chain_ids.npy', chain_ids)
    np.save(fname[:-4]+'_traces_reassembled.npy', traces_reassembled)

    return chain_ids, traces_reassembled



In [20]:
##################################################################
####### .CSV TO traces_inferences.npy and chain_id.npy files #####
##################################################################
# SELECT THE .CSV FILE
fname = '/media/cat/7e3d5af3-7d7b-424d-bdd5-eb995a4a0c62/dan/cohort1/march_9/2020-3-9_12_14_22_815059_compressed/video.mp4.analysis.csv'

chain_ids, traces_reassembled = csv_to_traces(fname)
print ("chain_ids: ", chain_ids.shape, chain_ids[:5])
print ("traces_reassembled: ", traces_reassembled.shape, traces_reassembled[:5])



289788it [00:04, 61364.38it/s]
100%|██████████| 289787/289787 [00:04<00:00, 60148.73it/s]


chain_ids:  (89987, 8) [[ 0.  1.  2.  3. nan nan nan nan]
 [ 0.  1.  2.  3. nan nan nan nan]
 [ 0.  1.  2.  3. nan nan nan nan]
 [ 0.  1.  2.  3. nan nan nan nan]
 [ 0.  2.  3. nan nan nan nan nan]]
traces_reassembled:  (89987, 112, 2) [[[nan nan]
  [nan nan]
  [nan nan]
  ...
  [nan nan]
  [nan nan]
  [nan nan]]

 [[nan nan]
  [nan nan]
  [nan nan]
  ...
  [nan nan]
  [nan nan]
  [nan nan]]

 [[nan nan]
  [nan nan]
  [nan nan]
  ...
  [nan nan]
  [nan nan]
  [nan nan]]

 [[nan nan]
  [nan nan]
  [nan nan]
  ...
  [nan nan]
  [nan nan]
  [nan nan]]

 [[nan nan]
  [nan nan]
  [nan nan]
  ...
  [nan nan]
  [nan nan]
  [nan nan]]]


In [18]:
print (traces_reassembled.shape)

(89987, 112, 2)


In [44]:
######################################################
######## LOOP OVER the CSV DATA FILE #################
######################################################




  2%|▏         | 5968/289787 [00:00<00:04, 59673.51it/s]

example data-tracks:  [[ 2395.      72536.              nan         nan         nan         nan
          nan         nan         nan         nan         nan         nan
    281.0457    164.48834   271.1698    172.6713    264.94714   192.30534
    259.1043    210.41785   243.06102   226.46506   280.88947   258.35266
          nan         nan         nan         nan   279.23987   292.46088]
 [ 2402.      72536.        823.093     261.11435   799.1196    283.18332
    838.86835   291.26627   789.022     313.1874    829.20447   324.9848
    808.9319    331.14444   797.3017    363.0731    792.94495   392.97885
          nan         nan         nan         nan         nan         nan
    895.0874    538.9115          nan         nan         nan         nan]]
Final tracks data shape:  (289787, 30)   # frames:  89987
traces reassembled:  (89987, 112, 2)
chain_ids:  (89987, 8)


100%|██████████| 289787/289787 [00:04<00:00, 59768.04it/s]


[[0.000e+00 1.000e+00 2.000e+00 ...       nan       nan       nan]
 [0.000e+00 1.000e+00 2.000e+00 ...       nan       nan       nan]
 [0.000e+00 1.000e+00 2.000e+00 ...       nan       nan       nan]
 ...
 [2.848e+03 2.862e+03       nan ...       nan       nan       nan]
 [2.848e+03 2.862e+03       nan ...       nan       nan       nan]
 [2.848e+03 2.862e+03       nan ...       nan       nan       nan]]
[[      nan       nan]
 [      nan       nan]
 [      nan       nan]
 [      nan       nan]
 [      nan       nan]
 [908.91315 164.58955]
 [900.74615 148.29443]
 [880.7421  136.36438]
 [852.91156 128.35283]
 [      nan       nan]
 [      nan       nan]
 [      nan       nan]
 [      nan       nan]
 [      nan       nan]
 [711.9241  136.06058]
 [723.92346 147.80606]
 [      nan       nan]
 [746.0039  151.94868]
 [754.04016 122.08743]
 [755.91144 136.23624]
 [772.03925 141.93369]
 [791.9203  143.96324]
 [807.9962  142.30421]
 [824.0479  140.11755]
 [835.9752  130.20096]
 [      nan      

In [20]:
# CONVERT DATA TO FLOAT32 and replace missing data with Nan
print (data.shape)
data = np.float32(data)
print (data)

idx = np.where(data==1E10)
data[idx]=np.nan
print (data)


(289787, 30)
[[0.0000000e+00 0.0000000e+00 1.0000000e+10 ... 1.0000000e+10
  1.0000000e+10 1.0000000e+10]
 [1.0000000e+00 0.0000000e+00 7.1192407e+02 ... 1.0000000e+10
  1.0000000e+10 1.0000000e+10]
 [2.0000000e+00 0.0000000e+00 1.0000000e+10 ... 5.8466187e+02
  1.0000000e+10 1.0000000e+10]
 ...
 [2.8620000e+03 8.9985000e+04 1.0000000e+10 ... 1.0000000e+10
  1.0000000e+10 1.0000000e+10]
 [2.8480000e+03 8.9986000e+04 2.4971579e+02 ... 1.0000000e+10
  1.0000000e+10 1.0000000e+10]
 [2.8620000e+03 8.9986000e+04 1.0000000e+10 ... 1.0000000e+10
  1.0000000e+10 1.0000000e+10]]


[[0.0000000e+00 0.0000000e+00           nan ...           nan
            nan           nan]
 [1.0000000e+00 0.0000000e+00 7.1192407e+02 ...           nan
            nan           nan]
 [2.0000000e+00 0.0000000e+00           nan ... 5.8466187e+02
            nan           nan]
 ...
 [2.8620000e+03 8.9985000e+04           nan ...           nan
            nan           nan]
 [2.8480000e+03 8.9986000e+04 2.4971579e+02 ...           nan
            nan           nan]
 [2.8620000e+03 8.9986000e+04           nan ...           nan
            nan           nan]]


In [2]:
import csv

traces =[]
with open('/media/cat/7e3d5af3-7d7b-424d-bdd5-eb995a4a0c62/dan/cohort1/2020-3-9_12_14_22_815059_compressed/video.mp4.analysis.csv', mode='r') as infile:
    reader = csv.reader(infile)
    with open('coors_new.csv', mode='w') as outfile:
        writer = csv.writer(outfile)
        for row in reader:
            traces.append(row)
        
#print (len(mydict))

In [3]:
traces.pop(0)

print (traces[0])
print (len(valid_frame_idxs))
print (len(traces))

['track_0', '0', '', '', '', '', '', '', '', '', '', '', '908.9131469726562', '164.5895538330078', '900.7461547851562', '148.29443359375', '880.7421264648438', '136.3643798828125', '852.9115600585938', '128.3528289794922', '', '', '', '', '', '', '', '', '', '']


NameError: name 'valid_frame_idxs' is not defined

In [6]:
#########################################################################
######################### MAKE VIDEOS ###################################
#########################################################################
import cv2
import matplotlib

#          pup1     pup2    female  male
colors_4= ['orange','green', 'blue', 'red', 'cyan','white','yellow','pink']

video_name = '/media/cat/7e3d5af3-7d7b-424d-bdd5-eb995a4a0c62/dan/cohort1/2020-3-9_12_14_22_815059_compressed/2020-3-9_12_14_22_815059_compressed.avi'
original_vid = cv2.VideoCapture(video_name)


# SELECT VIDEO SIZE
size_vid = np.array([1280,1024])
scale = 1
dot_size = 8//scale

# SET START AND END TIMES
start = 0
end = 299+1
#end = len(valid_frame_idxs)
#end = 3000
print ("START: ", start, "   END: ", end)

# SELECT VIDEOS OUT
#out_dir = '/media/cat/4TBSSD/dan/march_2/madeline_dlc/2020-3-9_08_18_49_128168/'
fname_out = video_name[:-4]+"_corrected_"+str(start)+"_"+str(end)+".mp4"
fourcc = cv2.VideoWriter_fourcc('M','P','E','G')
video_out = cv2.VideoWriter(fname_out,fourcc, 25, (size_vid[0]//scale,size_vid[1]//scale), True)


original_vid.set(cv2.CAP_PROP_POS_FRAMES, start)

font = cv2.FONT_HERSHEY_PLAIN

ctr_sleap = 0
for n in trange(start,end, 1):
    ret, frame = original_vid.read()
    #print (n, frame.shape)
    cv2.putText(frame, str(n), (50, 100), font, 5, (255, 255, 0), 5)
    frame = frame[::scale, ::scale]
    
    # LOAD ALL LAbELS AT THIS TRACES
    while True:
        temp = traces[ctr_sleap]
        #print (temp)
        frame_id = int(temp[1])
        track_id = int(temp[0].replace('track_',''))
        
        # TRACK TEXT
        track_text = True
        #print (locs1)
        
        # EXIT WHEN HITTING NEXT FRAME INFO
        if frame_id>n:
            break
        
        # 
        locs1 = np.array(temp[2:])  # LOAD x,y locations

        y_array = locs1[::2]
        #print (x_array)
        x_array = locs1[1::2]
        
        for k in range(len(x_array)):
            x = x_array[k]
            y = y_array[k]
            
            if x=='' or y=='':
                continue
            else:
                x=int(float(x))//scale
                y=int(float(y))//scale

                if track_text:
                    cv2.putText(frame, str(track_id), (y,x), font, 5, (255, 255, 0), 5)
                    track_text=False
                
                frame[x-dot_size:x+dot_size,y-dot_size:y+dot_size]= (np.float32(
                    matplotlib.colors.to_rgb(colors_4[track_id%8]))*255.).astype('uint8')
                #print (colors_4[k])
                #frame[y-dot_size:y+dot_size,x-dot_size:x+dot_size]= (np.float32(
                #    matplotlib.colors.to_rgb(colors_4[z//14]))*255.).astype('uint8')
                
        ctr_sleap+=1

    #print ("")
    video_out.write(frame)

    #print ("")

video_out.release()
original_vid.release()
#cv2.destroyAllWindows()

  0%|          | 0/300 [00:00<?, ?it/s]

START:  0    END:  300


100%|██████████| 300/300 [00:07<00:00, 41.71it/s]


In [5]:
import cv2
import matplotlib
matplotlib.use('Agg')


#          pup1     pup2    female  male
colors_4= ['orange','green', 'blue', 'red', 'cyan','white','yellow','pink']

video_name = '/media/cat/4TBSSD/dan/march_2/sleap_talmo/2020-3-9_12_14_22_815059_compressed.avi'
original_vid = cv2.VideoCapture(video_name)


# SELECT VIDEO SIZE
size_vid = np.array([1280,1024])
scale = 1
dot_size = 8//scale

# SET START AND END TIMES
start = 0
end = 299+1
end = len(valid_frame_idxs)
end = 3000
start=225
end=start+100

print ("START: ", start, "   END: ", end)

# SELECT VIDEOS OUT
#out_dir = '/media/cat/4TBSSD/dan/march_2/madeline_dlc/2020-3-9_08_18_49_128168/'
fname_out = video_name[:-4]+"_corrected_"+str(start)+"_"+str(end)+".mp4"
fourcc = cv2.VideoWriter_fourcc('M','P','E','G')
video_out = cv2.VideoWriter(fname_out,fourcc, 25, (size_vid[0]//scale,size_vid[1]//scale), True)


original_vid.set(cv2.CAP_PROP_POS_FRAMES, start)

font = cv2.FONT_HERSHEY_PLAIN

ctr_sleap = 0
ctr_show = 0

fig=plt.figure(figsize=(20,20))

for n in trange(start,end, 1):
    ret, frame = original_vid.read()
    #print (n, frame.shape)
    cv2.putText(frame, str(n), (50, 100), font, 5, (255, 255, 0), 5)
    frame = frame[::scale, ::scale]
    
    # LOAD ALL LAbELS AT THIS TRACES
    while True:
        temp = traces[ctr_sleap]
        #print (temp)
        frame_id = int(temp[1])
        track_id = int(temp[0].replace('track_',''))
        
#         if frame_id !=n:
#             ctr_sleap+=1

#             continue
        
        # TRACK TEXT
        track_text = True
        #print (locs1)
        
        # EXIT WHEN HITTING NEXT FRAME INFO
        if frame_id>n:
            break
        
        if frame_id<n:
            ctr_sleap+=1
            continue
        
        # 
        locs1 = np.array(temp[2:])  # LOAD x,y locations

        y_array = locs1[::2]
        #print (x_array)
        x_array = locs1[1::2]
        
        for k in range(len(x_array)):
            x = x_array[k]
            y = y_array[k]
            
            if x=='' or y=='':
                continue
            else:
                x=int(float(x))//scale
                y=int(float(y))//scale

                if track_text:
                    cv2.putText(frame, str(track_id), (y,x), font, 5, (255, 255, 0), 5)
                    track_text=False
                
                frame[x-dot_size:x+dot_size,y-dot_size:y+dot_size]= (np.float32(
                    matplotlib.colors.to_rgb(colors_4[track_id%8]))*255.).astype('uint8')

        
        ctr_sleap+=1

    ax=plt.subplot(2,5,ctr_show+1)
    plt.imshow(frame)
    plt.title(str(n))
    
    ctr_show+=1
fig.savefig('/home/cat/fig2.png',dpi=300)   

NameError: name 'valid_frame_idxs' is not defined