# In [39]:
import face_recognition

class FaceFinder:
    """Detect faces in a sequence of frames and group them into tracks.

    Detection and encoding are delegated to the ``face_recognition``
    package. Grouping is purely geometric: a detection in a later frame
    joins an existing track when its bounding box covers a large enough
    fraction of the track's most recent bounding box.

    Bounding boxes are handled as ``(top, right, bottom, left)`` tuples
    throughout this class.
    """

    # Sentinel id for a detection that has not been assigned to a track yet.
    UNASSIGNED = -1

    def __init__(self, upsample=1, grouping_area_thresh=0.6, grouping_frame_thresh=12, verbose=True):
        """Configure detection and grouping.

        Args:
            upsample: Passed to ``face_recognition.face_locations`` as
                ``number_of_times_to_upsample``.
            grouping_area_thresh: Minimum ratio of intersection area to
                the reference box area for two detections to be treated
                as the same face. The default is 0.6, the value this
                class previously hard-coded regardless of the argument.
            grouping_frame_thresh: Number of consecutive frames a track
                may go unmatched before propagation gives up.
            verbose: When true, print progress every 1000 frames.
        """
        self.upsample = upsample
        self.area_thresh = grouping_area_thresh
        self.frame_thresh = grouping_frame_thresh
        self.verbose = verbose

    def find_faces(self, frames):
        """Detect faces in ``frames`` and group them into tracks.

        Returns:
            The ``(face_count, faces_by_frame, faces_by_id)`` tuple
            produced by ``_group_faces``.
        """
        face_data = self._find_faces_and_encodings(frames)
        return self._group_faces(face_data)

    def _find_faces_and_encodings(self, frames):
        """Run detection and encoding on every frame.

        Args:
            frames: Sized sequence of images accepted by
                ``face_recognition`` (``len(frames)`` is used for the
                progress message).

        Returns:
            A list with one ``[frame_index, faces]`` entry per frame,
            where ``faces`` is a list of mutable
            ``[location, encoding, face_id]`` records and every
            ``face_id`` starts out as ``UNASSIGNED``.
        """
        face_data = []

        for i, f in enumerate(frames):
            if self.verbose and i % 1000 == 0:
                print(f'Processing frame {i} of {len(frames)}')

            face_locations = face_recognition.face_locations(
                f, number_of_times_to_upsample=self.upsample, model="cnn")
            face_encodings = face_recognition.face_encodings(f, face_locations, num_jitters=1)
            face_ids = [self.UNASSIGNED] * len(face_encodings)

            # Records are lists (not tuples) so the id can be filled in later.
            face_data.append([i, [list(z) for z in zip(face_locations, face_encodings, face_ids)]])
        return face_data

    def _group_faces(self, face_data):
        """Assign a track id to every detection in ``face_data``.

        Frames are visited in order. Each still-unassigned detection
        starts a new track, which is then propagated forward through the
        following frames. ``face_data`` is modified in place.

        Args:
            face_data: Structure returned by ``_find_faces_and_encodings``.
                Each entry's frame index is used as its position in the
                list, so entries must be stored at their own index.

        Returns:
            ``(face_count, faces_by_frame, faces_by_id)`` where
            ``faces_by_frame`` is ``face_data`` itself and
            ``faces_by_id`` maps each track id to a list of
            ``(frame_index, (location, encoding, face_id))`` tuples.
        """
        face_count = 0
        for frame_index, frame_face_data in face_data:
            for face in frame_face_data:
                face_location = face[0]
                if face[-1] != self.UNASSIGNED:
                    continue

                face_id = face_count
                face_count += 1
                face[-1] = face_id
                self._propagate_face_id(
                    face_id, frame_index, face_location, face_data,
                    area_thresh=self.area_thresh, frame_thresh=self.frame_thresh)

        faces_by_frame = face_data
        faces_by_id = {}
        for frame_index, frame_face_data in face_data:
            for face_location, face_encoding, face_id in frame_face_data:
                faces_by_id.setdefault(face_id, []).append(
                    (frame_index, (face_location, face_encoding, face_id)))

        return face_count, faces_by_frame, faces_by_id

    def _propagate_face_id(self, face_id, start_index, start_bounds, face_data, area_thresh=0.6, frame_thresh=12):
        """Carry ``face_id`` forward to overlapping detections in later frames.

        Starting at the frame after ``start_index``, the first unassigned
        detection that sufficiently overlaps the track's latest bounds is
        given ``face_id`` and becomes the new reference box. Scanning
        runs through the final frame and stops early once more than
        ``frame_thresh`` consecutive frames had no match.

        Args:
            face_id: Track id to assign.
            start_index: Position in ``face_data`` of the frame where the
                track starts.
            start_bounds: Bounding box of the track in that frame.
            face_data: Structure returned by ``_find_faces_and_encodings``;
                modified in place.
            area_thresh: Minimum overlap ratio, see
                ``_find_overlapping_faces``.
            frame_thresh: Allowed number of consecutive unmatched frames.
        """
        missed_frames = 0
        prev_bounds = start_bounds
        # Scan up to and including the last frame.
        for i in range(start_index + 1, len(face_data)):
            if missed_frames > frame_thresh:
                break

            frame_faces = face_data[i][1]
            candidates = [(f[0], f[-1]) for f in frame_faces]
            overlaps = self._find_overlapping_faces(prev_bounds, candidates, area_thresh)
            # Only the first overlapping detection in a frame joins the track.
            match = next((j for j, hit in enumerate(overlaps) if hit), None)

            if match is None:
                missed_frames += 1
                continue

            missed_frames = 0
            prev_bounds = candidates[match][0]
            frame_faces[match][-1] = face_id

    def _find_overlapping_faces(self, ref_bounds, candidates, area_thresh):
        """Flag which candidates overlap ``ref_bounds`` enough to match.

        The overlap measure is intersection area divided by the area of
        ``ref_bounds`` (not intersection-over-union).

        Args:
            ref_bounds: Reference bounding box.
            candidates: Iterable of ``(bounds, face_id)`` pairs.
            area_thresh: Minimum overlap ratio for a match.

        Returns:
            A list of booleans, one per candidate. Candidates that
            already have an id are always ``False``.
        """
        ref_area = self._rect_area(ref_bounds)
        if ref_area <= 0:
            # A degenerate reference box has no area to overlap; returning
            # early also avoids dividing by zero below.
            return [False for _ in candidates]

        overlaps = []
        for candidate_bounds, candidate_face_id in candidates:
            if candidate_face_id != self.UNASSIGNED:
                overlaps.append(False)
                continue

            intersection_bounds = self._rect_intersection(ref_bounds, candidate_bounds)
            overlaps.append(
                intersection_bounds is not None
                and self._rect_area(intersection_bounds) / ref_area >= area_thresh)
        return overlaps

    def _rect_intersection(self, r1, r2):
        """Return the intersection of two boxes, or ``None`` if disjoint.

        Boxes that merely touch along an edge are not treated as
        disjoint; they yield a zero-area rectangle.
        """
        top1, right1, bottom1, left1 = r1
        top2, right2, bottom2, left2 = r2

        if left1 > right2 or right1 < left2 or top1 > bottom2 or bottom1 < top2:
            return None

        top, bottom = max(top1, top2), min(bottom1, bottom2)
        left, right = max(left1, left2), min(right1, right2)

        return top, right, bottom, left

    def _rect_area(self, r):
        """Return the area of a ``(top, right, bottom, left)`` box."""
        top, right, bottom, left = r
        return (right - left) * (bottom - top)

32556
