In [70]:
"""
Analysis toolbox for Schorscher-Petcu et al.
© Browne Lab 2020
https://github.com/browne-lab/throwinglight
Authors: 
Ara Schorscher-Petcu (https://github.com/Ara-SP)
Liam E. Browne (https://github.com/lebrowne)

"""

import os, glob, cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal

def ms_to_frame(ms, fps):
    """Convert a time in milliseconds to a zero-based frame index.

    The result is truncated toward zero (``int``), so for non-negative times a
    moment part-way through a frame maps to that frame.

    Multiplying before dividing keeps integer inputs exact: the previous form
    ``(290 / 1000) * 100`` evaluates to ``28.999999999999996`` in floating
    point and truncated to frame 28 instead of 29.

    Args:
        ms: time in milliseconds.
        fps: video frame rate in frames per second.

    Returns:
        int: the frame index.
    """
    return int(ms * fps / 1000)

def prepare_dirs(path, analysis_type, video_type='.avi', video_prefix='Wide'):
    """Create the output folder for an analysis and list folders holding videos.

    The output folder is ``<dirname(path)>/analysis/<analysis_type>``. It is
    only created when it does not already exist and at least one video file is
    found under ``path``.

    Args:
        path: root folder that is searched recursively for videos.
        analysis_type: name of the analysis sub-folder to create.
        video_type: file-name suffix identifying video files.
        video_prefix: accepted for interface compatibility; it is NOT used
            when selecting folders (any file ending in ``video_type`` counts).

    Returns:
        tuple: ``(success, subdirectories, analysis_type_dir)``. ``success``
        is False when the output folder already exists or no videos were
        found. In both of those cases ``subdirectories`` is an empty list.
        Otherwise it lists every folder under ``path`` (including ``path``)
        that directly contains a video file.
    """
    analysis_type_dir = os.path.join(os.path.dirname(path), 'analysis', analysis_type)

    # Check the one path directly instead of walking the parent's whole tree.
    # Walking is slow for video data and would not descend into a symlinked
    # folder, letting os.makedirs crash below.
    if os.path.isdir(analysis_type_dir):
        print('Stopped as', analysis_type_dir, 'folder already present. Please remove and try again if required.')
        return False, [], analysis_type_dir

    # Use the walk's file list so directories named like videos are ignored.
    subdirectories = [root for root, _dirs, files in os.walk(path)
                      if any(name.endswith(video_type) for name in files)]
    if not subdirectories:
        print('No', video_type, 'files found.')
        return False, subdirectories, analysis_type_dir

    os.makedirs(analysis_type_dir)
    return True, subdirectories, analysis_type_dir


def _read_grey_roi(cap, x_start, x_end, y_start, y_end):
    """Read the next frame from ``cap`` and return its greyscale ROI, or None.

    None is returned when ``cap.read()`` reports failure, so callers can stop
    cleanly instead of passing ``None`` to ``cv2.cvtColor``.
    """
    ret, frame = cap.read()
    if not ret or frame is None:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)[y_start:y_end, x_start:x_end]


def _locate_stimulation(cap, x0, x1, y0, y1, stim_window, sos):
    """Return ``(chamber, frame)`` of the strongest filtered motion peak, or None.

    For every ROI the frames in ``stim_window`` are read. The mean absolute
    frame-to-frame difference is high-pass filtered with ``sos``, and the
    maximum is searched over difference positions 7-11 only. None is returned
    when there are no ROIs or a frame in the window cannot be read.
    """
    traces = []
    for a, b, c, d in zip(x0, x1, y0, y1):
        rois = []
        for frame_idx in stim_window:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_idx))
            roi = _read_grey_roi(cap, a, b, c, d)
            if roi is None:
                return None
            rois.append(roi)
        # int16 so that frame differences can be negative before np.abs.
        stack = np.array(rois).astype(np.int16)
        # Mean rather than sum, since ROIs may be different sizes.
        diff_mean = np.mean(np.abs(np.diff(stack, axis=0)), axis=(1, 2))
        traces.append(signal.sosfilt(sos, diff_mean))
    if not traces:
        return None
    # Search only inside the window. Matching the window maximum against the
    # full array could pick an equal value outside positions 7-11.
    window = np.array(traces)[:, 7:12]
    chamber, col = np.unravel_index(np.argmax(window), window.shape)
    # Difference k compares stim_window[k] with stim_window[k + 1]; report the later frame.
    return int(chamber), int(stim_window[0]) + 7 + int(col) + 1


def _write_motion_energy(video_path, out_path, roi, fps, threshold, binary):
    """Write thresholded frame differences of ``roi`` to a greyscale video.

    ``roi`` is ``(x_start, x_end, y_start, y_end)``. Returns a dict mapping
    frame number ``count`` (from 1) to the summed motion energy between frames
    ``count - 1`` and ``count``. If a frame cannot be read, the trace stops
    there and holds whatever was read so far.
    """
    a, b, c, d = roi
    cap = cv2.VideoCapture(video_path)
    # fourcc 0, frame size (width, height), isColor 0 -> single-channel output.
    writer = cv2.VideoWriter(out_path, 0, fps, (b - a, d - c), 0)
    energies = {}
    try:
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        prev = _read_grey_roi(cap, a, b, c, d)
        if prev is None:
            print('Could not read first frame of \'' + os.path.basename(video_path) + '\'')
            return energies
        for count in range(1, frame_count):
            current = _read_grey_roi(cap, a, b, c, d)
            if current is None:
                print('Could not read frame', count, 'of \'' + os.path.basename(video_path) + '\'')
                break
            diff = cv2.absdiff(prev, current)
            if binary:
                motion_e = cv2.threshold(diff, threshold, 255, cv2.THRESH_BINARY)[1]
            else:
                motion_e = diff * (diff > threshold)
            writer.write(motion_e)
            energies[count] = motion_e.sum()
            prev = current
    finally:
        writer.release()
        cap.release()
    return energies


def _show_roi_frame(video_path, roi, frame_idx):
    """Display the greyscale ROI of frame ``frame_idx`` with ``plt.imshow``, if readable."""
    a, b, c, d = roi
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        grey_roi = _read_grey_roi(cap, a, b, c, d)
    finally:
        cap.release()
    # Check the read result BEFORE using the frame (previously checked after cvtColor).
    if grey_roi is not None:
        plt.imshow(grey_roi)


def global_motion_energy(path, x0, x1, y0, y1, stimulation_time_ms, fps,
                         expected_frames, mod_file, threshold=5, binary=True,
                         video_type='.avi', video_prefix='Global', show_image=True):
    """Compute frame-to-frame motion energy in the stimulated chamber of each video.

    For every video found by ``prepare_dirs`` whose name starts with
    ``video_prefix`` and ends with ``video_type``:

    1. Skip it unless it has exactly ``expected_frames`` frames.
    2. Around the expected stimulation time, find the ROI (chamber) and frame
       with the largest high-pass-filtered motion peak.
    3. For that ROI, write a video of the thresholded frame differences and
       record the per-frame motion energy.

    Args:
        path: root folder searched for videos (see ``prepare_dirs``).
        x0, x1: per-ROI column bounds (slice ``[x0:x1]``).
        y0, y1: per-ROI row bounds (slice ``[y0:y1]``).
        stimulation_time_ms: expected stimulation time in milliseconds.
        fps: video frame rate.
        expected_frames: required frame count; other videos are skipped.
        mod_file: suffix appended to each output video's file stem.
        threshold: pixel-difference threshold.
        binary: if True, differences above ``threshold`` become 255. Otherwise
            they keep their value and the rest are zeroed.
        video_type: video file-name suffix.
        video_prefix: video file-name prefix.
        show_image: if True, ``plt.imshow`` the ROI at the detected frame.

    Outputs (in ``<dirname(path)>/analysis/global_motion_energy``):
        ``<video stem><mod_file>.avi`` per video, and ``motion_energy.csv``
        with one column per video (file name) indexed by frame number from 1.
        Videos with the same name in different sub-folders overwrite each other.

    Raises:
        ValueError: if the frames inspected around the stimulation
            (10 before to 4 after) fall outside a video of ``expected_frames``.
    """
    stimulation_frame = ms_to_frame(stimulation_time_ms, fps)
    # Frames inspected to locate the stimulation: 10 before to 4 after.
    stim_window = np.arange(stimulation_frame - 10, stimulation_frame + 5, dtype=int)
    if stim_window[0] < 0 or stim_window[-1] >= expected_frames:
        raise ValueError(
            'Stimulation window (frames %d-%d) lies outside a %d-frame video'
            % (stim_window[0], stim_window[-1], expected_frames))

    analysis_type = 'global_motion_energy'
    success, subdirectories, analysis_type_dir = prepare_dirs(path, analysis_type, video_type, video_prefix)

    if success:
        # 9th-order Butterworth high-pass, cut-off 0.02 * fps Hz (same for every video).
        sos = signal.butter(9, 0.02 * fps, 'high', fs=fps, output='sos')
        traces = {}
        for subdirectory in subdirectories:
            video_list = [vid for vid in os.listdir(subdirectory)
                          if vid.endswith(video_type) and vid.startswith(video_prefix)]
            for vid in video_list:
                video_path = os.path.join(subdirectory, vid)
                cap = cv2.VideoCapture(video_path)
                try:
                    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    if frame_count != expected_frames:
                        print('Expected frame number not in \'' + vid + '\'')
                        continue
                    located = _locate_stimulation(cap, x0, x1, y0, y1, stim_window, sos)
                finally:
                    cap.release()

                if located is None:
                    print('Error determining stimulation chamber')
                    continue
                stim_chamber, measured_stimulation_frame = located
                roi = (x0[stim_chamber], x1[stim_chamber], y0[stim_chamber], y1[stim_chamber])
                out_path = os.path.join(analysis_type_dir, os.path.splitext(vid)[0] + mod_file + '.avi')
                energies = _write_motion_energy(video_path, out_path, roi, fps, threshold, binary)
                if energies:
                    traces[os.path.basename(vid)] = pd.Series(energies, dtype=float)

                if show_image:
                    _show_roi_frame(video_path, roi, measured_stimulation_frame)

        # Written once after all sub-folders, so earlier sub-folders' results
        # are not overwritten by later ones.
        pd.DataFrame(traces).to_csv(os.path.join(analysis_type_dir, 'motion_energy.csv'), index=True)
    print('Analysis complete')