In [1]:

%load_ext autoreload
%autoreload 2
import json, os, cv2
from pathlib import Path
import pandas as pd 
import matplotlib.pyplot as plt 
import numpy as np
import xarray as xr
import seaborn as sns
import h5py as hf
from tqdm import tqdm
from tqdm import tqdm
import itertools 
from scipy.interpolate import interp1d
from scipy import signal
from matplotlib.backends.backend_pdf import PdfPages
from sklearn.cluster import KMeans
import matplotlib.colors as mcolors
import os, fnmatch
from scipy.spatial.distance import cdist
import matplotlib.gridspec as gridspec
from random import sample
from scipy.ndimage import gaussian_filter1d
from matplotlib.collections import LineCollection
from datetime import datetime


import sys
sys.path.insert(0, 'C:/Users/nlab/Documents/GitHub/obstacle_avoidance')

from utils.base_functions import *
from src.utils.auxiliary import flatten_series
from src.utils.path import find
from src.base import BaseInput
from plots.plots import plot_oa


import warnings
warnings.filterwarnings('ignore')

        As PyTorch is not installed, unsupervised identity learning will not be available.
        


In [2]:
def format_frames_oa(vid_path):
    """Read a video file into memory as a stack of grayscale frames.

    Parameters
    --------
    vid_path : str
        Path to the video file (e.g. an .avi) to open with OpenCV.

    Returns
    --------
    all_frames : np.ndarray
        uint8 array of shape (n_frames, height, width) at the video's
        native resolution (no downsampling is applied). If the decoder
        delivers fewer frames than the container header reports, the
        array is truncated to the frames that were actually read, so it
        never contains uninitialised data.
    """
    vidread = cv2.VideoCapture(vid_path)
    try:
        n_frames = max(0, int(vidread.get(cv2.CAP_PROP_FRAME_COUNT)))
        height = int(vidread.get(cv2.CAP_PROP_FRAME_HEIGHT))
        width = int(vidread.get(cv2.CAP_PROP_FRAME_WIDTH))
        # Pre-allocate the full stack once; np.empty leaves memory
        # uninitialised, which is why unread frames are trimmed below.
        all_frames = np.empty([n_frames, height, width], dtype=np.uint8)
        frames_read = 0
        for frame_num in tqdm(range(n_frames)):
            ret, frame = vidread.read()
            if not ret:
                # The reported frame count is only an estimate for some
                # containers; stop at the first failed read.
                break
            # cvtColor already returns uint8, so it is stored directly.
            all_frames[frame_num, :, :] = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            frames_read = frame_num + 1
    finally:
        # Always free the decoder / file handle, even on error.
        vidread.release()
    return all_frames[:frames_read]

def get_row_for_timestamp(df, seek_timestamp):
    """Find the trial whose timestamps include `seek_timestamp`.

    Parameters
    --------
    df : pd.DataFrame
        One row per trial; must have a 'trial_timestamps' column whose
        entries support the `in` operator.
    seek_timestamp : float
        Timestamp to look for.

    Returns
    --------
    row : pd.Series or None
        The first row (in iteration order) containing the timestamp, or
        None when no row contains it.
    """
    matching_rows = (
        trial
        for _, trial in df.iterrows()
        if seek_timestamp in trial['trial_timestamps']
    )
    return next(matching_rows, None)
def plot_frame(vid_arr, timestamps, df, seek_frame, return_as_array=False):
    """Draw one video frame with the tracked points of its trial overlaid.

    Parameters
    --------
    vid_arr : np.ndarray
        Frame stack indexed as vid_arr[frame, :, :].
    timestamps : sequence
        Per-frame timestamps; timestamps[seek_frame] is looked up in the
        'trial_timestamps' column of `df`.
    df : pd.DataFrame
        One row per trial with array-like columns 'trial_timestamps',
        'nose_x/y', 'leftear_x/y', 'rightear_x/y' and
        'obstacle{TL,TR,BR,BL}_x/y'.
    seek_frame : int
        Index of the frame to draw.
    return_as_array : bool
        If True, return the rendered figure instead of showing it.

    Returns
    --------
    frame_as_array : np.ndarray or None
        With return_as_array=True: the rendered figure as a
        (height, width, 3) uint8 RGB array, or, when the frame belongs
        to no trial, an all-zero array with the shape of one video frame
        (note that this differs in shape and dtype from the rendered case).
        With return_as_array=False: None (the figure is shown instead).
    """
    seek_timestamp = timestamps[seek_frame]
    row = get_row_for_timestamp(df, seek_timestamp)

    if row is None:
        # Frame lies outside every trial: emit a blank image and stop.
        blank = np.zeros(np.shape(vid_arr[0]))
        if return_as_array:
            return blank
        plt.figure()
        plt.imshow(blank, cmap='gray')
        plt.show()
        # Without this return the code below would index into None.
        return None

    # np.asarray makes the element-wise comparison and tuple indexing work
    # whether the cells hold lists or arrays.
    trial_times = np.asarray(row['trial_timestamps'])
    match_idx = np.where(trial_times == seek_timestamp)
    x1 = np.asarray(row['nose_x'])[match_idx]
    y1 = np.asarray(row['nose_y'])[match_idx]

    frame = vid_arr[seek_frame, :, :]
    fig = plt.figure()
    plt.imshow(frame, cmap='gray')
    plt.plot(x1, y1, '-')

    # Number of samples in this trial that precede the current frame; the
    # overlays below show the trajectory up to (not including) this frame.
    n_prev = match_idx[0][0]
    plt.plot(row['nose_x'][:n_prev], row['nose_y'][:n_prev], 'r.')
    plt.plot(row['leftear_x'][:n_prev], row['leftear_y'][:n_prev], 'g.')
    plt.plot(row['rightear_x'][:n_prev], row['rightear_y'][:n_prev], 'g.')
    # Closed outline TL -> TR -> BR -> BL -> TL; one outline per past sample.
    corners = ('TL', 'TR', 'BR', 'BL', 'TL')
    plt.plot([row['obstacle' + c + '_x'][:n_prev] for c in corners],
             [row['obstacle' + c + '_y'][:n_prev] for c in corners],
             color='blue')

    if not return_as_array:
        plt.show()
        return None

    fig.canvas.draw()
    # buffer_rgba() replaces the deprecated tostring_rgb()/np.fromstring
    # pair; drop the alpha channel and copy so the array outlives the figure.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    frame_as_array = rgba[..., :3].copy()
    # Close this specific figure so long loops do not accumulate figures.
    plt.close(fig)
    return frame_as_array
def plot_all_trials(vid_arr, timestamps, df, vid_savepath, fps=60.0,
                    frame_size=(640, 480), start_frame=3600*5, n_frames=3600):
    """Render annotated frames with plot_frame and write them to a video.

    Parameters
    --------
    vid_arr : np.ndarray
        Frame stack passed through to plot_frame.
    timestamps : sequence
        Per-frame timestamps passed through to plot_frame.
    df : pd.DataFrame
        Trial table passed through to plot_frame.
    vid_savepath : str
        Path of the output video (XVID codec).
    fps : float
        Frame rate of the output video.
    frame_size : tuple of int
        (width, height) of the output video.
    start_frame : int
        First frame index to render. The default, together with
        `n_frames`, reproduces the original fixed window of 3600 frames
        starting at frame 18000.
    n_frames : int
        Maximum number of frames to render; the window is clipped to the
        length of `vid_arr` and `timestamps`.

    Raises
    --------
    IOError
        If the output video cannot be opened for writing.
    """
    frame_size = tuple(frame_size)
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out_vid = cv2.VideoWriter(vid_savepath, fourcc, fps, frame_size)
    if not out_vid.isOpened():
        out_vid.release()
        raise IOError('Could not open video for writing: ' + str(vid_savepath))

    # Never index past the available frames or timestamps.
    stop_frame = min(start_frame + n_frames, len(vid_arr), len(timestamps))
    try:
        for seek_frame in tqdm(range(start_frame, stop_frame)):
            frame = plot_frame(vid_arr, timestamps, df, seek_frame, return_as_array=True)
            frame = np.asarray(frame).astype('uint8')
            if frame.ndim == 2:
                # Blank placeholder frames come back single-channel.
                frame = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)
            else:
                # Matplotlib renders RGB; OpenCV writers expect BGR.
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            # VideoWriter silently drops frames whose size does not match.
            if (frame.shape[1], frame.shape[0]) != frame_size:
                frame = cv2.resize(frame, frame_size)
            out_vid.write(frame)
    finally:
        # Finalise the file even if rendering fails part-way through.
        out_vid.release()
def make_videos(self):
    """Build the annotated trace video for one session.

    Locates the session's .avi and Bonsai timestamp .csv under
    `trial_path`, loads the video as a grayscale frame stack, and writes
    the annotated video next to it via plot_all_trials.

    NOTE(review): `trial_path`, `data` and `s` are not parameters or
    locals of this function and `self` is never used; this looks like a
    method lifted out of a class. As written the three names must exist
    as globals when the function is called — confirm where they should
    come from.
    """
    vid_savepath = os.path.join(trial_path, (data['animal'].iloc[0]+'_'+str(data['date'].iloc[0])+'_'+str(data['task'].iloc[0])+'plot.avi'))
    vid_path = find('*'+str(s['date'])+'*'+s['animal']+'*'+str(s['task'])+'*.avi', trial_path)[0]
    timestamp_path = find('*'+str(s['date'])+'*'+s['animal']+'*'+str(s['task'])+'*_top1_BonsaiTS.csv', trial_path)[0]
    print('formating video frames as array')
    vid_arr = format_frames_oa(vid_path)
    print('plotting video of traces')
    # read_timestamp_file requires the path; it was previously called with
    # no arguments, which raised a TypeError.
    plot_all_trials(vid_arr, read_timestamp_file(timestamp_path), data, vid_savepath)

def read_timestamp_series(s):
    """ Read timestamps as a pd.Series and format time.

    Parameters
    --------
    s : pd.Series
        Timestamps as a Series.
        Either numeric (returned unchanged) or text formatted as
        hours : minutes : seconds . microseconds. Trailing characters that
        do not fit that format (extra fractional digits, a UTC offset) are
        dropped before parsing.

    Returns
    --------
    output_time : np.array
        Returned as the number of seconds that have passed since the
        previous midnight, with microsecond precision, e.g. 700.000000.
        Entries that cannot be parsed are NaN.
    """
    # Numeric input is passed through untouched. Checking for any numeric
    # dtype (not only float64) keeps integer series from being pushed
    # through the string parser and coming back as all-NaN.
    if pd.api.types.is_numeric_dtype(s.dtype):
        return s.values

    fmt = '%H:%M:%S.%f'
    marker = 'unconverted data remains: '
    # Reference point for the subtraction; built once, not per element.
    midnight = datetime.strptime('00:00:00.000000', fmt)

    def _seconds_since_midnight(raw):
        # Convert one timestamp to seconds since midnight, or NaN.
        text = str(raw).strip()
        try:
            parsed = datetime.strptime(text, fmt)
        except ValueError as err:
            # strptime reports any trailing text it could not consume;
            # trim exactly that many characters and parse once more.
            leftover = err.args[0].partition(marker)[2]
            if not leftover:
                return np.nan
            try:
                parsed = datetime.strptime(text[:-len(leftover)], fmt)
            except ValueError:
                return np.nan
        return (parsed - midnight).total_seconds()

    return np.array([_seconds_since_midnight(current_time) for current_time in s])
def read_timestamp_file(timestamp_path, position_data_length=None, force_timestamp_shift=False):
    """ Read timestamps from a .csv file.

    Parameters
    --------
    timestamp_path : str
        Path to a header-less .csv file with one timestamp per line.
    position_data_length : None or int
        Number of timesteps in data from deeplabcut. This is used to
        determine whether or not the number of timestamps is too short
        for the number of video frames.
        Eyecam and Worldcam will have half the number of timestamps as
        the number of frames, since they are acquired as an interlaced
        video and deinterlaced in analysis. To fix this, timestamps need
        to be interpolated.
    force_timestamp_shift : bool
        If True, always interpolate the timestamps, regardless of
        `position_data_length`.

    Returns
    --------
    camT : np.array
        Timestamps as returned by read_timestamp_series (seconds since
        midnight for text input), interpolated when requested.
    """
    # squeeze('columns') always yields a Series for a one-column file; a
    # bare squeeze() would collapse a single-row file to a scalar.
    s = pd.read_csv(timestamp_path, encoding='utf-8', engine='c', header=None).squeeze('columns')
    # Drop a leading 0 entry. Positional access is used so this does not
    # depend on the index labels.
    if len(s) > 0 and s.iloc[0] == 0:
        s = s.iloc[1:]
    camT = read_timestamp_series(s)
    # NOTE(review): interp_timestamps is not defined in this file; presumably
    # it comes from `from utils.base_functions import *` — verify.
    # auto check if vids were deinterlaced
    if position_data_length is not None:
        if position_data_length > len(camT):
            camT = interp_timestamps(camT, use_medstep=False)
    # force the times to be shifted if the user is sure it should be done
    if force_timestamp_shift is True:
        camT = interp_timestamps(camT, use_medstep=False)
    return camT


In [3]:
# Load the top-camera recording into memory as a grayscale frame stack.
# NOTE(review): this video is from session 042723 / G8CKLN, while the tracking
# table and timestamp file loaded in the following cells are from
# 042623 / G8CKRT — confirm the overlay is drawn on the intended recording.
vid_path = r"D:\obstacle_avoidance\recordings\042723\G8CKLN\oa\042723_G8CKLN_control_Rig2_oa_top1.avi"

all_frames = format_frames_oa(vid_path)

100%|██████████| 38476/38476 [00:58<00:00, 655.97it/s]


In [4]:
# Per-trial tracking table for session 042623 / G8CKRT (one row per trial).
df =pd.read_hdf(r"D:\obstacle_avoidance\recordings\042623\G8CKRT\oa\non_obstacleG8CKRT_042623_oa.h5")

In [73]:
# Display the trial table.
# NOTE(review): this cell's execution count is out of sequence with its
# neighbours, so the shown output may not match a fresh Restart & Run All.
df

Unnamed: 0,first_poke,second_poke,trial_timestamps,trial_vidframes,nose_x,nose_y,leftear_x,leftear_y,rightear_x,rightear_y,...,gt_obstacleTR_x_cm,gt_obstacleTR_y_cm,gt_obstacleBR_x_cm,gt_obstacleBR_y_cm,gt_obstacleBL_x_cm,gt_obstacleBL_y_cm,gt_obstacle_cen_x,gt_obstacle_cen_x_cm,gt_obstacle_cen_y,gt_obstacle_cen_y_cm
0,40698.778227,40702.272345,"[40698.785305, 40698.802355, 40698.819417, 406...","[128, 129, 130, 131, 132, 133, 134, 135, 136, ...","[119.03721618652344, 117.72010040283203, 117.4...","[253.92242431640625, 252.53662109375, 251.7605...","[127.06047821044922, 126.20441436767578, 125.6...","[267.3447265625, 266.4164733886719, 266.759521...","[144.4882049560547, 143.71347045898438, 143.39...","[263.9754638671875, 263.71661376953125, 263.48...",...,36.233807,15.091118,37.129423,30.622867,31.605736,30.394298,384.480251,34.063662,257.619523,22.824227
1,40702.272345,40708.994444,"[40702.287142, 40702.303296, 40702.32, 40702.3...","[[338, 339, 340, 341, 342, 343, 344, 345, 346,...","[688.3577880859375, 688.5303344726562, 688.782...","[264.744873046875, 264.13262939453125, 263.727...","[660.4657592773438, 660.3155517578125, 660.205...","[254.3142852783203, 254.03196716308594, 254.63...","[668.826416015625, 669.2583618164062, 669.1931...","[272.9864196777344, 272.223876953125, 272.6505...",...,36.140056,15.297044,37.110390,30.619161,31.598483,30.388308,384.068202,34.027156,258.590953,22.910292
2,40708.994444,40713.402188,"[40709.00544, 40709.022105, 40709.039603, 4070...","[741, 742, 743, 744, 745, 746, 747, 748, 749, ...","[116.86168670654297, 116.06183624267578, 116.7...","[252.99224853515625, 252.9968719482422, 253.15...","[127.65352630615234, 126.72145080566406, 127.4...","[263.6208190917969, 263.3009033203125, 263.743...","[142.7602081298828, 143.74696350097656, 142.63...","[261.71258544921875, 260.7822265625, 261.87390...",...,41.632134,15.262945,40.908704,30.319068,41.728244,30.205052,454.381502,40.256679,257.125518,22.780459
3,40713.402188,40716.769356,"[40713.40718, 40713.424473, 40713.440409, 4071...","[[1005, 1006, 1007, 1008, 1009, 1010, 1011, 10...","[687.7032470703125, 688.4337158203125, 687.266...","[261.4231262207031, 260.4730529785156, 258.301...","[658.1459350585938, 658.788330078125, 658.0382...","[256.355224609375, 256.5988464355469, 257.0325...","[667.3782958984375, 667.34130859375, 667.77917...","[271.54107666015625, 271.8910827636719, 272.48...",...,41.643506,15.082714,38.274890,30.317088,43.765537,29.971986,452.916021,40.126842,256.126865,22.691982
4,40716.769356,40720.513331,"[40716.774912, 40716.791564, 40716.809164, 407...","[1207, 1208, 1209, 1210, 1211, 1212, 1213, 121...","[114.26569366455078, 114.50706481933594, 114.2...","[258.09423828125, 258.5579528808594, 258.31008...","[143.2292938232422, 143.2358856201172, 143.447...","[263.84075927734375, 263.87969970703125, 263.8...","[138.84149169921875, 136.95347595214844, 136.9...","[246.88626098632812, 245.94223022460938, 245.6...",...,41.636270,15.260638,42.068115,30.308188,39.862307,30.336948,452.399584,40.081087,257.457911,22.809908
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
93,41176.075622,41179.786956,"[41176.083558, 41176.099942, 41176.116595, 411...","[[28756, 28757, 28758, 28759, 28760, 28761, 28...","[686.3865966796875, 685.968994140625, 685.8830...","[268.870849609375, 269.9952697753906, 269.8724...","[671.664306640625, 671.3069458007812, 671.5077...","[250.68890380859375, 251.02703857421875, 251.6...","[657.53564453125, 657.6210327148438, 657.82305...","[254.1944580078125, 254.33132934570312, 254.55...",...,27.084549,8.453716,29.404227,23.382812,23.763502,23.833395,289.587755,25.656505,182.403823,16.160368
94,41179.786956,41184.477209,"[41179.801177, 41179.817856, 41179.835596, 411...","[28979, 28980, 28981, 28982, 28983, 28984, 289...","[115.21121978759766, 114.4676513671875, 114.43...","[253.8466033935547, 253.18275451660156, 252.81...","[127.78558349609375, 126.84893035888672, 125.1...","[262.06793212890625, 262.48504638671875, 263.4...","[143.24462890625, 141.9678192138672, 141.86824...","[261.5128173828125, 260.94317626953125, 260.28...",...,43.666641,20.048019,44.189528,34.832997,38.953713,34.938653,467.743009,41.440463,308.896273,27.367175
97,41217.751744,41225.178496,"[41217.764134, 41217.781145, 41217.798387, 412...","[[31256, 31257, 31258, 31259, 31260, 31261, 31...","[687.0308837890625, 686.3541259765625, 686.781...","[268.8574523925781, 269.9048767089844, 269.739...","[672.8463745117188, 672.149169921875, 672.0062...","[250.84864807128906, 251.80108642578125, 252.3...","[656.5009765625, 655.5339965820312, 654.999511...","[257.91259765625, 257.5879211425781, 257.51635...",...,36.773050,20.302607,37.204447,35.312360,31.851704,35.147147,389.093487,34.472379,312.154863,27.655875
99,41242.657484,41247.974579,"[41242.672768, 41242.689548, 41242.706854, 412...","[[32750, 32751, 32752, 32753, 32754, 32755, 32...","[687.149658203125, 686.8919067382812, 687.1539...","[269.35308837890625, 269.54248046875, 269.6818...","[672.1484985351562, 672.6808471679688, 673.192...","[254.5059814453125, 255.00009155273438, 255.14...","[655.3718872070312, 656.4122314453125, 657.615...","[262.2472229003906, 262.4425354003906, 262.767...",...,36.722359,20.195884,37.187203,35.311093,31.850340,35.136131,388.886347,34.454027,311.449654,27.593396


In [5]:
# Per-frame Bonsai timestamps of the top camera for session 042623 / G8CKRT.
timestamps = read_timestamp_file(r"D:\obstacle_avoidance\recordings\042623\G8CKRT\oa\042623_G8CKRT_control_Rig2_oa_top1_BonsaiTS.csv")

In [6]:

# Render the annotated overlay video; this is slow because every frame is
# drawn as a separate matplotlib figure.
plot_all_trials(all_frames, timestamps, df, r'D:\obstacle_avoidance\recordings\G8CK\042623_G8CKRTLTplot.avi')



 74%|███████▍  | 2674/3600 [32:02<15:51,  1.03s/it]  

In [22]:
# Inspect the per-trial x-coordinate traces of the obstacle's top-left corner.
df.obstacleTL_x

0      [317.44390869140625, 317.42626953125, 317.5698...
1      [317.295654296875, 317.2867736816406, 317.5597...
2      [317.5199890136719, 317.2187805175781, 317.514...
3      [411.0384826660156, 411.0067443847656, 410.857...
4      [410.9383544921875, 410.82421875, 410.94079589...
                             ...                        
224    [293.1347351074219, 293.0476379394531, 292.737...
226    [393.1718444824219, 393.0086364746094, 392.896...
228    [392.7950134277344, 393.14984130859375, 393.15...
229    [407.4093322753906, 407.34405517578125, 407.52...
230    [407.3516845703125, 407.5133972167969, 407.452...
Name: obstacleTL_x, Length: 195, dtype: object

In [29]:
import cv2
import random

def extract_frames_from_video(video_path, output_path, num_frames):
    """Save randomly chosen frames of a video as JPEG images.

    Frames are read sequentially and the selected ones are written to
    `output_path` as frame_0.jpg, frame_1.jpg, ... in order of appearance.

    Parameters
    --------
    video_path : str
        Path to the video file.
    output_path : str
        Directory for the images; created if it does not exist.
    num_frames : int
        Number of frames to extract. Clipped to the number of frames the
        video reports, so asking for more than are available no longer
        raises.

    Returns
    --------
    None
        Progress and errors are reported with print().
    """
    video = cv2.VideoCapture(video_path)

    if not video.isOpened():
        print("Error opening video file")
        video.release()
        return

    try:
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))

        # random.sample raises ValueError when asked for more items than
        # exist, so clamp the request first. A set gives O(1) membership.
        n_wanted = max(0, min(num_frames, total_frames))
        wanted = set(random.sample(range(total_frames), n_wanted))
        last_wanted = max(wanted) if wanted else -1

        # cv2.imwrite fails silently when the directory is missing.
        os.makedirs(output_path, exist_ok=True)

        frame_count = 0
        frame_index = 0
        # Stop as soon as the last selected frame has been passed.
        while frame_index <= last_wanted:
            ret, frame = video.read()
            if not ret:
                # End of stream (or decode error) before the last index.
                break

            if frame_index in wanted:
                frame_output_path = f"{output_path}/frame_{frame_count}.jpg"
                # Only count frames that were really written to disk.
                if cv2.imwrite(frame_output_path, frame):
                    frame_count += 1
                else:
                    print(f"Could not write {frame_output_path}")

            frame_index += 1
    finally:
        video.release()

    print(f"Extracted {frame_count} frames from the video")


Error opening video file


In [32]:

# Example usage: save 5 randomly chosen frames from a test recording as
# JPEG images in the output directory.
video_path = r"D:\obstacle_avoidance\recordings\041723\test\oa\041723_test_control_Rig2_oa_top1.avi"
output_path = r"D:\obstacle_avoidance\recordings\G8CK\vid_images"
num_frames = 5

extract_frames_from_video(video_path, output_path, num_frames)

Extracted 5 frames from the video
