## Write a video containing outline, midline and area of all exported shapes
Given a data folder containing `*_posture_fishX.npz` files, this cell saves a movie showing the outlines and midlines of the exported postures, any holes inside these shapes, and the enclosed area of each shape. The example data here is a fish school.

In [None]:
# Import necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import cv2 as cv
import io
from PIL import Image
import glob
import os
import matplotlib
matplotlib.style.use('dark_background')

# Define a function to load and preprocess data from npz files
def _split_by_lengths(points, lengths):
    """Split a flat point array into consecutive chunks of the given lengths."""
    # np.split expects the chunk boundaries, i.e. the running sum of every
    # length except the last one.
    boundaries = np.cumsum(np.asarray(lengths)[:-1]).astype(int)
    return np.split(points, boundaries.tolist(), axis=0)


def _decode_holes(frames, hole_counts, hole_points):
    """Decode the flat hole arrays into ``{frame: [hole_point_array, ...]}``.

    ``hole_counts`` is read as a stream: for every frame one entry giving the
    number of holes, followed by one entry per hole giving its point count.
    """
    holes = {}
    count_index = 0
    point_index = 0

    for frame in frames:
        holes[frame] = []
        num_holes = hole_counts[count_index]
        count_index += 1

        for _ in range(num_holes):
            num_points = hole_counts[count_index]
            count_index += 1

            end = point_index + num_points
            if end > len(hole_points):
                raise IndexError(f"Error: index {end} is out of bounds for hole_points with size {len(hole_points)}")

            holes[frame].append(hole_points[point_index:end])
            point_index = end

    return holes


def load_and_preprocess_data(folder_path):
    """Load every ``*_posture*.npz`` file in ``folder_path``.

    For each file the per-frame outline and midline points are shifted by that
    frame's ``offset``; hole points are stored exactly as read from the file
    (no offset is added to them).

    Returns:
        A tuple ``(data, screen, input_shape, output_shape, fps, min_frame,
        max_frame)`` where

        * ``data`` maps file path -> ``{"holes", "midline", "outline"}``, each
          of which maps frame -> points,
        * ``screen`` is ``[x_min, y_min, x_max, y_max]`` derived from the
          offsets (minimum lowered by 10, maximum scaled by 1.1),
        * ``input_shape`` is the width/height of ``screen``,
        * ``output_shape`` is ``(1280, height)`` keeping that aspect ratio,
        * ``fps`` is the fixed frame rate used for the video,
        * ``min_frame`` / ``max_frame`` are the extreme frame numbers seen.

    Raises:
        ValueError: if no matching file is found.
        IndexError: if ``hole_counts`` refers to more points than
            ``hole_points`` contains.
    """
    files = sorted(glob.glob(os.path.join(folder_path, "*_posture*.npz")))
    if not files:
        # Without any file the screen bounds stay infinite and the aspect
        # ratio below would become NaN; fail early with a readable message.
        raise ValueError(f"No '*_posture*.npz' files found in {folder_path!r}")

    data = {}
    min_frame = np.inf
    max_frame = -np.inf
    screen = [np.inf, np.inf, -np.inf, -np.inf]

    for f in files:
        with np.load(f) as npz:
            offset = npz["offset"]
            frames = npz["frames"]

            min_frame = min(min_frame, frames.min())
            max_frame = max(max_frame, frames.max())

            offset_x, offset_y = offset.T[0], offset.T[1]
            screen[0] = min(screen[0], offset_x.min())
            screen[1] = min(screen[1], offset_y.min())
            screen[2] = max(screen[2], offset_x.max())
            screen[3] = max(screen[3], offset_y.max())

            midline = {}
            midline_points = npz["midline_points"]
            if midline_points.ndim == 2:
                # Flat (N, 2) storage: midline_lengths says how many rows
                # belong to each frame.
                chunks = _split_by_lengths(midline_points, npz["midline_lengths"])
                for frame, chunk, off in zip(frames, chunks, offset):
                    midline[frame] = chunk + off
            else:
                # Otherwise the first axis is iterated as one entry per frame.
                for mpt, off, frame in zip(midline_points, offset, frames):
                    midline[frame] = mpt + off

            outline = {}
            chunks = _split_by_lengths(npz["outline_points"], npz["outline_lengths"])
            for frame, chunk, off in zip(frames, chunks, offset):
                outline[frame] = chunk + off

            holes = _decode_holes(frames, npz["hole_counts"].astype(int), npz["hole_points"])

        data[f] = {"holes": holes, "midline": midline, "outline": outline}

    # Pad the bounding box a little so shapes are not drawn on the border.
    screen[0] -= 10
    screen[1] -= 10
    screen[2] *= 1.1
    screen[3] *= 1.1
    input_shape = (screen[2] - screen[0], screen[3] - screen[1])
    output_width = 1280
    # Adjust output resolution to maintain aspect ratio
    output_shape = (output_width, int(output_width * input_shape[1] / input_shape[0]))
    fps = 40.0

    return data, screen, input_shape, output_shape, fps, min_frame, max_frame

import matplotlib.patches as patches
import matplotlib.cm as cm

def calculate_polygon_area(points):
    """Return the unsigned area of a polygon via the shoelace formula.

    ``points`` is indexed as ``points[:, 0]`` (x) and ``points[:, 1]`` (y),
    one vertex per row.
    """
    xs, ys = points[:, 0], points[:, 1]
    # Pair every vertex with its predecessor; np.roll wraps around so the
    # ring is closed implicitly.
    prev_xs = np.roll(xs, 1)
    prev_ys = np.roll(ys, 1)
    twice_signed_area = np.dot(xs, prev_ys) - np.dot(ys, prev_xs)
    return 0.5 * np.abs(twice_signed_area)

def create_video(data, screen, output_shape, fps, min_frame, max_frame, frame_range=None):
    """Render one figure per frame and write the frames to 'output.avi'.

    Args:
        data: mapping file -> {"outline", "midline", "holes"}, each mapping
            frame -> points (the structure built by load_and_preprocess_data).
        screen: [x_min, y_min, x_max, y_max] used as axis limits.
        output_shape: (width, height) of the video in pixels.
        fps: frame rate of the written video.
        min_frame, max_frame: inclusive range of frame numbers to consider.
        frame_range: optional (first, last) pair restricting the frames
            that are rendered.
    """
    output_shape = tuple(int(v) for v in output_shape)
    fourcc = cv.VideoWriter_fourcc(*'MJPG')
    filename = "output.avi"
    out = cv.VideoWriter(filename, fourcc, fps, output_shape, True)

    print("Writing video '" + filename + "' with frames", min_frame, "-", max_frame)

    dpi = 250 / 4
    cv.destroyAllWindows()

    frames = np.arange(min_frame, max_frame + 1)
    if frame_range:
        frames = frames[np.logical_and(frames >= frame_range[0], frames <= frame_range[1])]

    plt.style.use('dark_background')
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # recent matplotlib releases.
    colormap = plt.get_cmap('cool')

    # Report progress roughly every 10% of the range; never let the step be
    # zero, which would make the modulo below raise for short ranges.
    progress_step = 1
    if len(frames):
        progress_step = max(1, int((frames.max() - frames.min()) * 0.1))

    try:
        for chosen_frame in frames:
            fig, ax = plt.subplots(figsize=(output_shape[0] / dpi, output_shape[1] / dpi), dpi=dpi)
            fig.set_tight_layout(True)

            for key in data:
                if chosen_frame not in data[key]["outline"]:
                    continue

                outline = np.array(data[key]["outline"][chosen_frame])
                ax.scatter(outline.T[0], outline.T[1], label="outline", s=1.5, color=colormap(0.6))

                # A frame may have an outline but no midline entry.
                m = data[key]["midline"].get(chosen_frame)
                if m is not None:
                    ax.scatter(m.T[0], m.T[1], label="midline", s=1.5, color=colormap(0.9))

                # Draw holes
                hole_area = 0
                for hole in data[key]["holes"].get(chosen_frame, []):
                    ax.scatter(hole.T[0], hole.T[1], label="hole", s=1, color='grey')
                    hole_area += calculate_polygon_area(hole)

                # Calculate outline area and net area
                outline_area = calculate_polygon_area(outline)
                net_area = outline_area - hole_area

                # Display the area description near the outline
                centroid_x, centroid_y = outline.mean(axis=0)
                ax.text(centroid_x, centroid_y, f"Area: {outline_area:.0f}px²\nNet Area: {net_area:.0f}px²",
                        fontsize=8, color='lightgrey',
                        bbox=dict(facecolor='black', alpha=0.7, edgecolor='white', boxstyle='round,pad=0.5'))

            ax.set_xlim(screen[0], screen[2])
            ax.set_ylim(screen[3], screen[1])  # Invert y-axis

            # Rasterise the figure through an in-memory PNG.
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            im = np.array(Image.open(buf).convert("RGB"), dtype=np.uint8)
            buf.close()
            plt.close(fig)

            if int(chosen_frame) % progress_step == 0:
                print(chosen_frame, "/", frames.max())

            if (im.shape[1], im.shape[0]) != output_shape:
                # The writer only accepts frames of the size it was opened
                # with, so bring off-size renders to the expected size.
                print("different shape", output_shape, im.shape)
                im = cv.resize(im, output_shape)

            # matplotlib renders RGB, OpenCV's writer expects BGR.
            im = cv.cvtColor(im, cv.COLOR_RGB2BGR)
            out.write(im)
            cv.imshow("movie", im)
            cv.waitKey(1)
    finally:
        # Always finalise the file, even if rendering fails midway.
        out.release()

    print("Video creation complete.")

# Example usage: load every posture file from the data folder, then render
# the selected frame range into a video.
folder_path = "/Users/tristan/Videos/maldives/data/"
frame_range = [1000, 2500]  # Example frame range, adjust as needed

(data, screen, input_shape, output_shape,
 fps, min_frame, max_frame) = load_and_preprocess_data(folder_path)
create_video(data, screen, output_shape, fps, min_frame, max_frame, frame_range=frame_range)

## Add the individuals from a separate dataset
This section loads the results of a separate tracking run in which individuals (here: sharks, read from `*_sharks*.npz` files) were tracked, and draws them on top of the school outlines.

In [None]:
# Import necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import cv2 as cv
import io
from PIL import Image
import glob
import os
import pandas as pd
import matplotlib
matplotlib.style.use('dark_background')

# Define a function to load and preprocess data from npz files
def _split_by_lengths(points, lengths):
    """Split a flat point array into consecutive chunks of the given lengths."""
    # np.split expects the chunk boundaries, i.e. the running sum of every
    # length except the last one.
    boundaries = np.cumsum(np.asarray(lengths)[:-1]).astype(int)
    return np.split(points, boundaries.tolist(), axis=0)


def _decode_holes(frames, hole_counts, hole_points):
    """Decode the flat hole arrays into ``{frame: [hole_point_array, ...]}``.

    ``hole_counts`` is read as a stream: for every frame one entry giving the
    number of holes, followed by one entry per hole giving its point count.
    """
    holes = {}
    count_index = 0
    point_index = 0

    for frame in frames:
        holes[frame] = []
        num_holes = hole_counts[count_index]
        count_index += 1

        for _ in range(num_holes):
            num_points = hole_counts[count_index]
            count_index += 1

            end = point_index + num_points
            if end > len(hole_points):
                raise IndexError(f"Error: index {end} is out of bounds for hole_points with size {len(hole_points)}")

            holes[frame].append(hole_points[point_index:end])
            point_index = end

    return holes


def load_and_preprocess_data(folder_path):
    """Load every ``*_posture*.npz`` file in ``folder_path``.

    For each file the per-frame outline and midline points are shifted by that
    frame's ``offset``; hole points are stored exactly as read from the file
    (no offset is added to them).

    Returns:
        A tuple ``(data, screen, input_shape, output_shape, fps, min_frame,
        max_frame)`` where

        * ``data`` maps file path -> ``{"holes", "midline", "outline"}``, each
          of which maps frame -> points,
        * ``screen`` is ``[x_min, y_min, x_max, y_max]`` derived from the
          offsets (minimum lowered by 10, maximum scaled by 1.1),
        * ``input_shape`` is the width/height of ``screen``,
        * ``output_shape`` is ``(1280, height)`` keeping that aspect ratio,
        * ``fps`` is the fixed frame rate used for the video,
        * ``min_frame`` / ``max_frame`` are the extreme frame numbers seen.

    Raises:
        ValueError: if no matching file is found.
        IndexError: if ``hole_counts`` refers to more points than
            ``hole_points`` contains.
    """
    files = sorted(glob.glob(os.path.join(folder_path, "*_posture*.npz")))
    if not files:
        # Without any file the screen bounds stay infinite and the aspect
        # ratio below would become NaN; fail early with a readable message.
        raise ValueError(f"No '*_posture*.npz' files found in {folder_path!r}")

    data = {}
    min_frame = np.inf
    max_frame = -np.inf
    screen = [np.inf, np.inf, -np.inf, -np.inf]

    for f in files:
        with np.load(f) as npz:
            offset = npz["offset"]
            frames = npz["frames"]

            min_frame = min(min_frame, frames.min())
            max_frame = max(max_frame, frames.max())

            offset_x, offset_y = offset.T[0], offset.T[1]
            screen[0] = min(screen[0], offset_x.min())
            screen[1] = min(screen[1], offset_y.min())
            screen[2] = max(screen[2], offset_x.max())
            screen[3] = max(screen[3], offset_y.max())

            midline = {}
            midline_points = npz["midline_points"]
            if midline_points.ndim == 2:
                # Flat (N, 2) storage: midline_lengths says how many rows
                # belong to each frame.
                chunks = _split_by_lengths(midline_points, npz["midline_lengths"])
                for frame, chunk, off in zip(frames, chunks, offset):
                    midline[frame] = chunk + off
            else:
                # Otherwise the first axis is iterated as one entry per frame.
                for mpt, off, frame in zip(midline_points, offset, frames):
                    midline[frame] = mpt + off

            outline = {}
            chunks = _split_by_lengths(npz["outline_points"], npz["outline_lengths"])
            for frame, chunk, off in zip(frames, chunks, offset):
                outline[frame] = chunk + off

            holes = _decode_holes(frames, npz["hole_counts"].astype(int), npz["hole_points"])

        data[f] = {"holes": holes, "midline": midline, "outline": outline}

    # Pad the bounding box a little so shapes are not drawn on the border.
    screen[0] -= 10
    screen[1] -= 10
    screen[2] *= 1.1
    screen[3] *= 1.1
    input_shape = (screen[2] - screen[0], screen[3] - screen[1])
    output_width = 1280
    # Adjust output resolution to maintain aspect ratio
    output_shape = (output_width, int(output_width * input_shape[1] / input_shape[0]))
    fps = 50.0

    return data, screen, input_shape, output_shape, fps, min_frame, max_frame

# Define a function to load and preprocess sharks data from npz files
def load_sharks_data(folder_path, cm_per_pixel=0.02):
    """Load all ``*_sharks*.npz`` files in ``folder_path`` into one DataFrame.

    Each file contributes one row per entry of its ``frame`` array with the
    columns ``frame``, ``X``, ``Y``, ``ID`` followed by every ``poseX*`` and
    then every ``poseY*`` array found in the file.

    Args:
        folder_path: directory that is searched for the files.
        cm_per_pixel: divisor applied to ``X#pcentroid`` / ``Y#pcentroid``.
            The pose columns are copied without scaling.

    Returns:
        A DataFrame with a fresh 0..N-1 index. If no file matches, an empty
        DataFrame that still has the ``frame``, ``X``, ``Y`` and ``ID``
        columns is returned.

    Raises:
        ValueError: if the text between the last "fish" and ".npz" in a file
            name is not an integer.
    """
    files = sorted(glob.glob(os.path.join(folder_path, "*_sharks*.npz")))
    per_file = []

    for f in files:
        # The individual's ID is the number between "fish" and ".npz".
        shark_id = int(os.path.basename(f).split("fish")[-1].split(".npz")[0])

        with np.load(f) as npz:
            frames = npz["frame"]
            n = len(frames)
            columns = {
                "frame": frames,
                "X": npz["X#pcentroid"][:n] / cm_per_pixel,
                "Y": npz["Y#pcentroid"][:n] / cm_per_pixel,
                "ID": np.full(n, shark_id, dtype=int),
            }
            # Keep the original column order: all poseX* first, then poseY*.
            # NOTE(review): assumes each pose array is 1-D with one value per
            # frame - confirm against the exporter.
            pose_keys = [k for k in npz.files if k.startswith("poseX")]
            pose_keys += [k for k in npz.files if k.startswith("poseY")]
            for key in pose_keys:
                columns[key] = npz[key][:n]

        per_file.append(pd.DataFrame(columns))

    if per_file:
        sharks_df = pd.concat(per_file, ignore_index=True)
    else:
        # Keep the core columns so that callers can still select them.
        sharks_df = pd.DataFrame(columns=["frame", "X", "Y", "ID"])

    print("Sharks DataFrame Head:\n", sharks_df.head())
    return sharks_df

import matplotlib.patches as patches
import matplotlib.cm as cm

def calculate_polygon_area(points):
    """Shoelace formula: unsigned area of the polygon whose vertices are the rows of ``points``."""
    col_x = points[:, 0]
    col_y = points[:, 1]
    # Each vertex is combined with the previous one (np.roll wraps around).
    forward = np.dot(col_x, np.roll(col_y, 1))
    backward = np.dot(col_y, np.roll(col_x, 1))
    return 0.5 * np.abs(forward - backward)

import matplotlib.lines as mlines

def _draw_shark(ax, row, skeleton_edges, color):
    """Draw one shark row: centroid dot, skeleton edges and an ID label."""
    ax.scatter(row['X'], row['Y'], color=color, s=10, label='shark position')

    for start, end in skeleton_edges:
        xs = [row[f"poseX{start}"], row[f"poseX{end}"]]
        ys = [row[f"poseY{start}"], row[f"poseY{end}"]]
        # Skip edges with a missing (non-finite) keypoint.
        if not np.all(np.isfinite(xs + ys)):
            continue

        ax.add_line(mlines.Line2D(xs, ys, color=color, linewidth=1))
        ax.scatter(xs, ys, s=5, color=color, alpha=0.6)

    # Display the ID near the centroid of the skeleton (mean of the valid
    # start keypoints of all edges).
    valid_pose_points = [(row[f"poseX{start}"], row[f"poseY{start}"]) for start, _ in skeleton_edges
                         if np.isfinite(row[f"poseX{start}"]) and np.isfinite(row[f"poseY{start}"])]
    if valid_pose_points:
        centroid_x, centroid_y = np.mean(valid_pose_points, axis=0)
        label_x, label_y = row['X'] - 25, row['Y'] - 65
        ax.text(label_x, label_y, f"ID: {int(row['ID'])}",
                fontsize=8, color='lightgrey',
                bbox=dict(facecolor='black', alpha=0.7, edgecolor='white', boxstyle='round,pad=0.5'),
                font='monospace', clip_on=True)
        ax.plot([label_x, centroid_x], [label_y, centroid_y], color='white', linewidth=1)


def create_video(data, sharks_df, screen, output_shape, fps, min_frame, max_frame, skeleton, frame_range=None):
    """Render school outlines plus tracked sharks per frame into 'output.avi'.

    Args:
        data: mapping file -> {"outline", "midline", "holes"}, each mapping
            frame -> points (the structure built by load_and_preprocess_data).
        sharks_df: DataFrame with the columns ``frame``, ``X``, ``Y``, ``ID``
            and the ``poseX<i>`` / ``poseY<i>`` columns named by ``skeleton``.
        screen: [x_min, y_min, x_max, y_max] used as axis limits.
        output_shape: (width, height) of the video in pixels.
        fps: frame rate of the written video.
        min_frame, max_frame: inclusive range of frame numbers to consider.
        skeleton: ``[name, edges]``; ``edges`` is a list of keypoint index
            pairs that are connected by a line.
        frame_range: optional (first, last) pair restricting the frames
            that are rendered.
    """
    output_shape = tuple(int(v) for v in output_shape)
    fourcc = cv.VideoWriter_fourcc(*'MJPG')
    filename = "output.avi"
    out = cv.VideoWriter(filename, fourcc, fps, output_shape, True)

    print("Writing video '" + filename + "' with frames", min_frame, "-", max_frame)

    dpi = 300 / 4
    cv.destroyAllWindows()

    frames = np.arange(min_frame, max_frame + 1)
    if frame_range:
        frames = frames[np.logical_and(frames >= frame_range[0], frames <= frame_range[1])]

    plt.style.use('dark_background')
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in
    # recent matplotlib releases.
    colormap = plt.get_cmap('cool')

    # Colours are spread over the colormap by ID / max_ID; avoid dividing by
    # zero (or NaN for an empty table).
    max_ID = sharks_df['ID'].max()
    id_scale = max_ID if max_ID > 0 else 1

    # Print progress every ~5% and preview every ~0.5% of the range. For
    # short ranges the truncated step would be 0, so clamp it to 1.
    span = (frames.max() - frames.min()) if len(frames) else 0
    progress_step = max(1, int(span * 0.05))
    preview_step = max(1, int(span * 0.005))

    try:
        for chosen_frame in frames:
            fig, ax = plt.subplots(figsize=(output_shape[0] / dpi, output_shape[1] / dpi), dpi=dpi)
            fig.set_tight_layout(True)

            for key in data:
                if chosen_frame not in data[key]["outline"]:
                    continue

                outline = np.array(data[key]["outline"][chosen_frame])
                ax.scatter(outline.T[0], outline.T[1], label="outline", s=0.05, color=colormap(0.6))

                # Draw holes
                hole_area = 0
                for hole in data[key]["holes"].get(chosen_frame, []):
                    ax.scatter(hole.T[0], hole.T[1], label="hole", s=0.05, color='grey')
                    hole_area += calculate_polygon_area(hole)

                # Calculate outline area and net area
                outline_area = calculate_polygon_area(outline)
                net_area = outline_area - hole_area

                # Display the area description near the outline
                centroid_x, centroid_y = outline.mean(axis=0)
                ax.text(centroid_x, centroid_y, f"Area: {outline_area:.0f}px²\nNet Area: {net_area:.0f}px²",
                        fontsize=8, color='lightgrey',
                        bbox=dict(facecolor='black', alpha=0.7, edgecolor='white', boxstyle='round,pad=0.5'),
                        font='monospace', clip_on=True)

            # Draw sharks data
            sharks_frame_data = sharks_df[sharks_df['frame'] == chosen_frame]
            for _, row in sharks_frame_data.iterrows():
                # Rows without a valid centroid are not drawn at all.
                if not (np.isfinite(row['X']) and np.isfinite(row['Y'])):
                    continue
                _draw_shark(ax, row, skeleton[1], colormap(float(row.ID) / id_scale))

            ax.text(0.5, 0.95, f"Frame: {chosen_frame}", fontsize=8, color='lightgrey',
                    bbox=dict(facecolor='black', alpha=0.7, edgecolor='white', boxstyle='round,pad=0.5'),
                    font='monospace', clip_on=True, transform=ax.transAxes)

            ax.set_xlim(screen[0], screen[2])
            ax.set_ylim(screen[3], screen[1])  # Invert y-axis
            plt.tight_layout()

            # Rasterise the figure through an in-memory PNG.
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            im = np.array(Image.open(buf).convert("RGB"), dtype=np.uint8)
            buf.close()
            plt.close(fig)

            if (im.shape[1], im.shape[0]) != output_shape:
                # Resize instead of dropping the frame, so the video keeps
                # one image per rendered frame.
                print("different shape", output_shape, im.shape)
                im = cv.resize(im, output_shape)

            # matplotlib renders RGB, OpenCV's writer expects BGR.
            im = cv.cvtColor(im, cv.COLOR_RGB2BGR)
            out.write(im)

            if int(chosen_frame) % progress_step == 0:
                print(chosen_frame, "/", frames.max())

            if int(chosen_frame) % preview_step == 0:
                cv.imshow("movie", im)
                cv.waitKey(1)
    finally:
        # Always finalise the file, even if rendering fails midway.
        out.release()

    print("Video creation complete.")

# Skeleton definition: a name plus the keypoint index pairs that are joined
# by a line when a shark's pose is drawn.
skeleton = ["shark", [[0, 2], [1, 2], [3, 2], [2, 4], [4, 5], [5, 6], [6, 7], [7, 8]]]
folder_path = "/Users/tristan/Videos/maldives/data/"
frame_range = [0, 7376]  # Example frame range, adjust as needed

# Load the school postures and the shark tracks, then render both together.
(data, screen, input_shape, output_shape,
 fps, min_frame, max_frame) = load_and_preprocess_data(folder_path)
sharks_df = load_sharks_data(folder_path)
create_video(data, sharks_df, screen, output_shape, fps, min_frame, max_frame, skeleton, frame_range)

In [None]:

# Notebook cell output: display the DataFrame returned by load_sharks_data.
sharks_df

In [None]:
# Notebook cell output: display the dict returned by load_and_preprocess_data.
data

In [None]:
import pandas as pd
from shapely.geometry import Polygon, LineString, Point
import numpy as np
from scipy.ndimage import gaussian_filter1d
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from tqdm import tqdm
import matplotlib.animation as animation

class FishSchoolData:
    """Per-frame school outlines and holes combined with predator tracks.

    ``data`` maps a frame number to a list of
    ``{"file_id", "outline", "holes"}`` dicts (one per school visible in that
    frame); ``predators`` is a DataFrame with at least the columns ``frame``,
    ``X``, ``Y``, ``ID`` and the ``poseX<i>`` / ``poseY<i>`` columns named by
    ``skeleton``.
    """

    # A dent starts where the smoothed curvature rises above +threshold and
    # ends where it falls below -threshold.
    _CURVATURE_THRESHOLD = 0.002

    def __init__(self):
        self.data = {}
        # Replaced by a DataFrame in load_predator_data.
        self.predators = {}
        # [name, list of keypoint index pairs connected by a line]
        self.skeleton = ["shark", [[0, 2], [1, 2], [3, 2], [2, 4], [4, 5], [5, 6], [6, 7], [7, 8]]]

    def add_frame_data(self, file_id, frame, outline, holes):
        """Append one school (outline + holes) to the list stored for ``frame``."""
        self.data.setdefault(frame, []).append({"file_id": file_id, "outline": outline, "holes": holes})

    def load_school_data(self, data):
        """Import ``{file_id: {"outline": {frame: ...}, "holes": {frame: ...}}}``."""
        for file_id, file_data in data.items():
            holes = file_data['holes']
            for frame, outline in file_data['outline'].items():
                self.add_frame_data(file_id, frame, outline, holes[frame])

    def load_predator_data(self, df):
        """Store a copy of the predator DataFrame."""
        self.predators = df.copy()

    def calculate_curvature(self, points):
        """Return ``(curvature, magnitude)`` arrays for a sequence of 2D points.

        ``curvature`` is ``(x'' y' - x' y'') / (x'^2 + y'^2)^1.5`` using
        ``np.gradient`` for the derivatives; ``magnitude`` is ``y' - x'``.
        Where two neighbouring points coincide the denominator is zero and
        the curvature is NaN/inf.
        """
        x = np.array([p[0] for p in points])
        y = np.array([p[1] for p in points])
        dx = np.gradient(x)
        dy = np.gradient(y)
        ddx = np.gradient(dx)
        ddy = np.gradient(dy)
        # Zero-length steps divide by zero; the NaN/inf result is kept but
        # the numpy warning is silenced.
        with np.errstate(divide='ignore', invalid='ignore'):
            curvature = (ddx * dy - dx * ddy) / np.power(dx ** 2 + dy ** 2, 1.5)
        magnitude = (dy - dx)
        return curvature, magnitude

    @staticmethod
    def _append_if_dent(dent_segments, segment, magnitudes, min_dent_length, dent_magnitude, verbose=False):
        """Append ``segment`` as a LineString if it is long and pronounced enough."""
        # A LineString needs at least two points, whatever min_dent_length is.
        if len(segment) < max(min_dent_length, 2):
            return
        spread = np.max(np.abs(magnitudes)) - np.min(np.abs(magnitudes))
        if spread > dent_magnitude:
            dent_segments.append(LineString(segment))
            if verbose:
                print("maximum curvature: ", spread)

    def detect_dents(self, smoothing_sigma=10, min_dent_length=15, dent_magnitude=1.5, visualize=False):
        """Find dent segments on every stored outline.

        Curvature and magnitude (see calculate_curvature) are smoothed with a
        Gaussian of width ``smoothing_sigma``. A segment is kept when it has
        at least ``min_dent_length`` points and the spread of its absolute
        magnitude exceeds ``dent_magnitude``.

        Returns:
            ``{frame: [LineString, ...]}`` with the dents of all schools in
            that frame.
        """
        threshold = self._CURVATURE_THRESHOLD
        dents = {}
        print("Detecting dents...")
        for frame, schools in tqdm(self.data.items(), desc="Frames"):
            dents[frame] = []
            show = visualize and frame % 1000 == 0

            for school in schools:
                outline = school['outline']
                if len(outline) < 2:
                    # np.gradient needs at least two samples.
                    continue

                curvature, raw_magnitude = self.calculate_curvature(outline)
                smoothed_curvature = gaussian_filter1d(curvature, sigma=smoothing_sigma)
                magnitude = gaussian_filter1d(raw_magnitude, sigma=smoothing_sigma)

                if show:
                    plt.hist(magnitude, bins=100)
                    plt.show()
                    plt.hist(smoothed_curvature, bins=100)
                    plt.show()

                    self.visualize_curvature(frame, outline, curvature, smoothed_curvature)

                dent_segments = []
                segment = []
                segment_magnitudes = []
                inside_dent = False

                for i, curv in enumerate(smoothed_curvature):
                    if not inside_dent:
                        if curv > threshold:
                            segment.append(outline[i])
                            segment_magnitudes.append(magnitude[i])
                            inside_dent = True
                        continue

                    segment.append(outline[i])
                    segment_magnitudes.append(magnitude[i])
                    if curv < -threshold:
                        self._append_if_dent(dent_segments, segment, segment_magnitudes,
                                             min_dent_length, dent_magnitude, verbose=show)
                        segment = []
                        segment_magnitudes = []
                        inside_dent = False

                # The outline ended while a dent was still open.
                if segment:
                    self._append_if_dent(dent_segments, segment, segment_magnitudes,
                                         min_dent_length, dent_magnitude, verbose=show)

                dents[frame].extend(dent_segments)

                if show:
                    self.visualize_dents(frame, outline, dent_segments)

        return dents

    def filter_important_dents(self, dents, threshold_distance=50, visualize=False):
        """Keep the dents that have a predator within ``threshold_distance``.

        The distance between a predator and a dent is the smallest distance
        from the predator's (X, Y) to any vertex of the dent.

        Returns:
            ``{frame: [(frame, predator_id, dent), ...]}`` with one entry per
            (dent, nearby predator) pair.
        """
        important_dents = {}
        print("Filtering important dents...")

        for frame, dent_lines in tqdm(dents.items(), desc="Frames"):
            sub = self.predators[self.predators['frame'] == frame].copy()
            important_dents[frame] = []

            predator_xy = sub[['X', 'Y']].to_numpy(dtype=float)
            predator_ids = sub['ID'].to_numpy()

            if len(predator_xy):
                for dent in dent_lines:
                    dent_coords = np.array(dent.coords).reshape(-1, 2)
                    # (predators, dent vertices, 2) -> closest vertex per predator
                    deltas = dent_coords[None, :, :] - predator_xy[:, None, :]
                    distances = np.linalg.norm(deltas, axis=2).min(axis=1)

                    for predator_id in predator_ids[distances <= threshold_distance]:
                        important_dents[frame].append((frame, predator_id, dent))

            if visualize and frame % 1000 == 0:
                self.visualize_dents_and_predators(frame, sub, dent_lines, important_dents[frame])

        return important_dents

    def draw_predator(self, row, ax, max_ID, colormap=None):
        """Draw one predator row (centroid, skeleton, ID label) onto ``ax``.

        ``colormap`` defaults to matplotlib's 'cool'; the colour is picked at
        ``ID / max_ID``.
        """
        if not (np.isfinite(row['X']) and np.isfinite(row['Y'])):
            return

        if colormap is None:
            # Resolved at call time; matplotlib.cm.get_cmap was removed in
            # recent matplotlib releases.
            colormap = plt.get_cmap('cool')

        id_scale = max_ID if max_ID > 0 else 1
        color = colormap(float(row.ID) / id_scale)

        ax.scatter(row['X'], row['Y'], color=color, s=10, label=None, clip_on=True)
        for start, end in self.skeleton[1]:
            xs = [row[f"poseX{start}"], row[f"poseX{end}"]]
            ys = [row[f"poseY{start}"], row[f"poseY{end}"]]

            # Skip edges with a missing (non-finite) keypoint.
            if not np.all(np.isfinite(xs + ys)):
                continue

            ax.add_line(mlines.Line2D(xs, ys, color=color, linewidth=1))
            ax.scatter(xs, ys, s=1, color=color, alpha=0.6, clip_on=True)

        # Display the ID near the centroid of the skeleton
        valid_pose_points = [(row[f"poseX{start}"], row[f"poseY{start}"]) for start, _ in self.skeleton[1]
                             if np.isfinite(row[f"poseX{start}"]) and np.isfinite(row[f"poseY{start}"])]
        if valid_pose_points:
            centroid_x, centroid_y = np.mean(valid_pose_points, axis=0)
            label_x, label_y = row['X'] - 25, row['Y'] - 155
            ax.text(label_x, label_y, f"ID: {int(row['ID'])}",
                    fontsize=6, color='lightgrey',
                    bbox=dict(facecolor='black', alpha=0.7, edgecolor='white', boxstyle='round,pad=0.5'),
                    font='monospace', clip_on=True)
            ax.plot([label_x, centroid_x], [label_y, centroid_y], color='white', linewidth=1, clip_on=True)

    def visualize_dents_and_predators(self, frame, sub, dent_lines, important_dents, save_to_buffer=False, output_shape=None, limits=None, dpi=150):
        """Plot outlines, dents, enclosing holes and predators of one frame.

        Args:
            frame: frame number to draw.
            sub: unused; the predators of ``frame`` are read from
                ``self.predators``.
            dent_lines: all dents (LineStrings) of the frame, drawn in orange.
            important_dents: ``(frame, predator_id, dent)`` tuples, drawn in red.
            save_to_buffer: if true, return the rendered image as a uint8
                array instead of showing the figure.
            output_shape: optional (width, height) in pixels for the figure.
            limits: optional ``(x_min, x_max, y_min, y_max)`` axis limits; the
                y-axis is inverted.
            dpi: resolution used together with ``output_shape``.
        """
        plt.style.use('dark_background')

        if not output_shape:
            fig, ax = plt.subplots(figsize=(10, 6))
        else:
            fig, ax = plt.subplots(figsize=(output_shape[0] / dpi, output_shape[1] / dpi), dpi=dpi)
            fig.set_tight_layout(True)

        def label_once(label):
            # Only the first line of a kind gets a legend entry.
            return label if label not in [line.get_label() for line in ax.lines] else ""

        schools = self.data.get(frame, [])
        for d in schools:
            outline_array = np.array(d['outline'])
            ax.plot(outline_array[:, 0], outline_array[:, 1], label=label_once("Outline"), color='white', linewidth=0.5)

        IDS = []
        all_dents = []
        for _, ID, dent in important_dents:
            dent_coords = np.array(dent.coords).reshape(-1, 2)
            IDS.append(ID)
            ax.plot(dent_coords[:, 0], dent_coords[:, 1], c='red', linewidth=0.5, label=label_once('Important Dent'))
            all_dents.append(dent)

        for dent in dent_lines:
            if dent not in all_dents:
                dent_coords = np.array(dent.coords).reshape(-1, 2)
                ax.plot(dent_coords[:, 0], dent_coords[:, 1], c='orange', alpha=0.5, linewidth=0.5, label=label_once('Dent'))

        frame_predators = self.predators[self.predators['frame'] == frame]

        # Highlight the hole a predator sits in, for predators that are not
        # already associated with an important dent.
        for _, predator in frame_predators.iterrows():
            if predator['ID'] in IDS:
                continue
            if not (np.isfinite(predator['X']) and np.isfinite(predator['Y'])):
                continue
            position = Point(float(predator['X']), float(predator['Y']))
            for school in schools:
                for hole in school['holes']:
                    if len(hole) < 3:
                        # Too few points to form a polygon.
                        continue
                    if position.within(Polygon(hole)):
                        hole = np.asarray(hole)
                        ax.plot(hole[:, 0], hole[:, 1], label=label_once('Hole'), color='green', linewidth=0.35)
                        break

        max_id = self.predators['ID'].max()
        for _, predator in frame_predators.iterrows():
            self.draw_predator(predator, ax, max_id)

        ax.text(0.5, 0.95, f"Frame: {frame}", fontsize=8, color='lightgrey',
                bbox=dict(facecolor='black', alpha=0.7, edgecolor='white', boxstyle='round,pad=0.5'),
                font='monospace', clip_on=True, transform=ax.transAxes)

        ax.legend(loc='lower right', fontsize=6)

        if limits:
            x_min, x_max, y_min, y_max = limits
            ax.set_xlim(x_min, x_max)
            ax.set_ylim(y_max, y_min)

        if save_to_buffer:
            buf = io.BytesIO()
            fig.savefig(buf, format='png')
            buf.seek(0)
            im = Image.open(buf)
            im = np.array(im).astype(np.uint8)
            buf.close()
            plt.close(fig)
            return im
        else:
            plt.show()

    def create_video(self, important_dents, video_filename="output.avi", threshold_distance=50, fps=5, dents=None):
        """Write one rendered image per frame of ``important_dents`` to a video.

        Args:
            important_dents: ``{frame: [(frame, predator_id, dent), ...]}``.
            video_filename: output file name.
            threshold_distance: unused; kept so existing calls keep working.
            fps: frame rate of the written video.
            dents: optional ``{frame: [LineString, ...]}`` with all dents per
                frame. When omitted, only the important dents are drawn.

        Raises:
            ValueError: if there is no predator with finite coordinates, so
                no output size can be derived.
        """
        # Determine the output video size from the finite predator positions.
        xy = self.predators[['X', 'Y']]
        xy = xy[np.isfinite(xy['X']) & np.isfinite(xy['Y'])]
        if xy.empty:
            raise ValueError("No predator with finite X/Y coordinates; cannot determine the video size")

        x_min, x_max = xy['X'].min(), xy['X'].max()
        y_min, y_max = xy['Y'].min(), xy['Y'].max()
        print(f"Predator X min: {x_min}, X max: {x_max}, Y min: {y_min}, Y max: {y_max}")
        output_shape = (int(x_max - x_min), int(y_max - y_min))
        limits = (x_min, x_max, y_min, y_max)

        # Setup the video writer
        fourcc = cv.VideoWriter_fourcc(*'MJPG')
        out = cv.VideoWriter(video_filename, fourcc, fps, output_shape, True)

        frames = sorted(important_dents.keys())

        try:
            for frame in tqdm(frames, desc="Creating Video"):
                sub = self.predators[self.predators['frame'] == frame].copy()
                important_frame_dents = important_dents[frame]
                if dents is not None:
                    frame_dents = dents[frame]
                else:
                    frame_dents = [dent for _, _, dent in important_frame_dents]

                im = self.visualize_dents_and_predators(frame, sub, frame_dents, important_frame_dents,
                                                        save_to_buffer=True, output_shape=output_shape,
                                                        limits=limits)

                # Drop a possible alpha channel, then convert RGB to BGR for OpenCV
                im = cv.cvtColor(np.ascontiguousarray(im[:, :, :3]), cv.COLOR_RGB2BGR)
                if (im.shape[1], im.shape[0]) != output_shape:
                    # The writer only accepts frames of the size it was opened with.
                    im = cv.resize(im, output_shape)
                out.write(im)

                if frame % 1000 == 0:
                    cv.imshow("movie", im)
                    cv.waitKey(1)

        except KeyboardInterrupt:
            print("Video creation interrupted.")
        finally:
            out.release()

        print(f"Video saved as {video_filename}")

    def visualize_curvature(self, frame, outline, curvature, smoothed_curvature):
        """Plot the smoothed curvature against the outline's x coordinates.

        ``curvature`` (the unsmoothed values) is accepted but not drawn.
        """
        plt.style.use('dark_background')
        fig, ax = plt.subplots(figsize=(10, 6))

        outline_array = np.array(outline)
        ax.plot(outline_array[:, 0], smoothed_curvature, label="Smoothed Curvature", color='red', linewidth=0.5)

        ax.set_title(f"Frame {frame} - Curvature")
        ax.set_aspect('auto')
        plt.legend()
        plt.show()

    def visualize_dents(self, frame, outline, dent_segments):
        """Plot one outline (white) with its detected dents (red)."""
        plt.style.use('dark_background')
        fig, ax = plt.subplots(figsize=(10, 6))

        outline_array = np.array(outline)
        ax.plot(outline_array[:, 0], outline_array[:, 1], label="Outline", color='white', linewidth=0.5)

        for dent in dent_segments:
            dent_array = np.array(dent.coords)
            ax.plot(dent_array[:, 0], dent_array[:, 1], label="Dent", color='red', linewidth=1.5)

        ax.set_title(f"Frame {frame} - Detected Dents")
        ax.set_aspect('equal', 'box')
        plt.show()

def calculate_polygon_area(points):
    """Return the area of the polygon described by *points*.

    The points are passed unchanged to ``Polygon`` (imported outside this
    excerpt -- presumably ``shapely.geometry.Polygon``) and its ``area``
    attribute is returned.
    """
    return Polygon(points).area

# Example usage: build a FishSchoolData instance and run dent detection on it.
fish_school_data = FishSchoolData()

# Load school data from the `data` dictionary (defined in an earlier cell)
fish_school_data.load_school_data(data)

# Load predator data from the `sharks_df` DataFrame (defined in an earlier cell)
fish_school_data.load_predator_data(sharks_df)

# Detect dents; visualization is switched off here (visualize=False)
dents = fish_school_data.detect_dents(visualize=False)


In [None]:

# Filter the detected dents down to the important ones, using a
# threshold_distance of 100 and without plotting (visualize=False).
important_dents = fish_school_data.filter_important_dents(dents, visualize=False, threshold_distance=100)
# Disabled: video export through the class method.
#fish_school_data.create_video(important_dents, video_filename="important_dents.mp4")

In [None]:
def create_video(self, dents, important_dents, video_filename="output.avi", threshold_distance=100, fps=50):
    """Render every frame listed in *important_dents* and write it to a video.

    Args:
        self: Object providing ``predators`` (a DataFrame with 'X', 'Y' and
            'frame' columns) and a ``visualize_dents_and_predators`` method
            that returns the rendered frame as an image array.
        dents: Mapping frame -> all dents of that frame (indexed per frame).
        important_dents: Mapping frame -> important dents; its sorted keys
            define which frames are rendered and in which order.
        video_filename: Path of the video file to create.
        threshold_distance: Unused here; kept for interface compatibility.
        fps: Frame rate handed to the video writer.

    Frames whose rendered size differs from the writer's frame size are
    reported on stdout and not written. The writer is always released, also
    when rendering fails or the user interrupts with Ctrl-C.
    """
    # Determine the output video size from the largest finite predator position.
    positions = self.predators[['X', 'Y']]
    positions = positions[(positions['X'] != np.inf) & (positions['Y'] != np.inf)]
    x_max = int(positions['X'].max())
    y_max = int(positions['Y'].max())

    # The rendered area always starts at the origin.
    x_min = y_min = 0
    print(f"Predator X min: {x_min}, X max: {x_max}, Y min: {y_min}, Y max: {y_max}")

    input_shape = (x_max, y_max)

    output_width = 1280
    # Scale the height so the output keeps the aspect ratio of the input area.
    output_shape = (output_width, int(output_width * input_shape[1] / input_shape[0]))

    # Setup the video writer
    fourcc = cv.VideoWriter_fourcc(*'MJPG')
    out = cv.VideoWriter(video_filename, fourcc, fps, output_shape, True)

    frames = sorted(important_dents.keys())

    try:
        for frame in tqdm(frames, desc="Creating Video"):
            sub = self.predators[self.predators['frame'] == frame].copy()
            important_frame_dents = important_dents[frame]
            im = self.visualize_dents_and_predators(frame, sub, dents[frame], important_frame_dents, save_to_buffer=True, limits=(0, x_max, 0, y_max), output_shape=output_shape, dpi=100)

            # Keep the first three channels and convert RGB to BGR for OpenCV.
            # Converting before the size check means the preview below always
            # receives a BGR image, including for frames that are not written.
            # Assumes the renderer returns RGB(A) -- confirm against the method.
            im = cv.cvtColor(im[:, :, 0:3], cv.COLOR_RGB2BGR)

            if (im.shape[1], im.shape[0]) != output_shape:
                print("different shape", output_shape, im.shape)
            else:
                out.write(im)

            # Show a live preview of every frame whose number is a multiple of 10.
            if frame % 10 == 0:
                cv.imshow("movie", im)
                cv.waitKey(1)

    except KeyboardInterrupt:
        print("Video creation interrupted.")
    finally:
        # Release the writer on every exit path so the file is finalized.
        out.release()

    print(f"Video saved as {video_filename}")

# Run the prototype above with `fish_school_data` standing in for `self`.
# NOTE(review): the writer is opened with the MJPG fourcc while the file name
# ends in .mp4 -- confirm this codec/container pair works with the local OpenCV build.
create_video(fish_school_data, dents, important_dents, video_filename="important_dents.mp4")

In [None]:
def visualize_dents_and_predators(self, frame, sub, dent_lines, important_dents, threshold_distance):
    """Plot the outlines, important dents and predators of one frame.

    Args:
        self: Object providing ``data`` (frame -> list of dicts with an
            'outline' entry) and ``predators`` (DataFrame with 'frame', 'ID',
            'X' and 'Y' columns).
        frame: Frame to draw.
        sub: Unused; the predators are re-selected from ``self.predators``.
        dent_lines: Unused; kept for interface compatibility.
        important_dents: Iterable of ``(frame, predator_id, dent)`` tuples;
            each ``dent`` must expose ``.coords``.
        threshold_distance: Unused; kept for interface compatibility.

    Predators whose ID appears in *important_dents* are drawn larger and in
    red, all others smaller and in blue.
    """
    plt.style.use('dark_background')
    fig, ax = plt.subplots(figsize=(10, 6))

    seen_labels = set()

    def legend_label(text):
        """Return *text* on first use and "" afterwards (one legend entry each)."""
        # Tracked explicitly instead of inspecting ``ax.lines``: scatter
        # artists are stored in ``ax.collections``, so an ``ax.lines`` lookup
        # never finds them and every predator would get its own legend entry.
        if text in seen_labels:
            return ""
        seen_labels.add(text)
        return text

    for d in self.data[frame]:
        outline_array = np.array(d['outline'])
        ax.plot(outline_array[:, 0], outline_array[:, 1], label=legend_label("Outline"), color='white', linewidth=0.5)

    matched_ids = set()
    for _, ID, dent in important_dents:
        dent_coords = np.array(dent.coords).reshape(-1, 2)
        matched_ids.add(ID)
        ax.plot(dent_coords[:, 0], dent_coords[:, 1], c='red', linewidth=0.5, label=legend_label('Important Dent'))

    for _, predator in self.predators[self.predators['frame'] == frame].iterrows():
        if predator['ID'] in matched_ids:
            # Separate legend entry so the red marker is not labelled like the blue one.
            ax.scatter(predator['X'], predator['Y'], label=legend_label('Important Predator'), s=5, color='red')
        else:
            ax.scatter(predator['X'], predator['Y'], label=legend_label('Predator'), s=3, color='blue')

    ax.set_title(f'Frame {frame} - Dents and Predators')
    ax.set_xlabel('X Coordinate')
    ax.set_ylabel('Y Coordinate')
    ax.legend()
    plt.show()

def filter_important_dents(self, dents, threshold_distance=50, visualize=False):
    """Select, per frame, the dents that lie close to a predator.

    Args:
        self: Object providing ``predators`` (DataFrame with 'frame', 'ID',
            'X' and 'Y' columns).
        dents: Mapping frame -> iterable of dents; each dent must expose
            ``.coords`` convertible to an (n, 2) array.
        threshold_distance: A dent counts as important for a predator when
            its closest point is at most this far from the predator.
        visualize: When true, ``visualize_dents_and_predators`` is called
            for every frame.

    Returns:
        dict mapping every frame of *dents* to a list of
        ``(frame, predator_id, dent)`` tuples, one per (dent, predator) pair
        within the threshold.
    """
    important_dents = {}
    print("Filtering important dents...")

    for frame, dent_lines in tqdm(dents.items(), desc="Frames"):
        sub = self.predators[self.predators['frame'] == frame].copy()
        frame_matches = important_dents[frame] = []

        # The predator positions do not depend on the dent: extract them once
        # per frame as an (n_predators, 2) float array.
        predator_xy = sub[['X', 'Y']].values.astype(float)

        # Without predators nothing can match, so the dent loop is skipped.
        if predator_xy.shape[0] > 0:
            for dent in dent_lines:
                dent_coords = np.array(dent.coords).reshape(-1, 2)

                # Broadcast to (n_predators, n_points, 2) and keep, for each
                # predator, the distance to its closest dent point.
                deltas = dent_coords[None, :, :] - predator_xy[:, None, :]
                sub["distance"] = np.linalg.norm(deltas, axis=2).min(axis=1)

                # iterrows is kept so predator.ID has the same type as before.
                for _, predator in sub[sub["distance"] <= threshold_distance].iterrows():
                    frame_matches.append((frame, predator.ID, dent))

        if visualize:
            visualize_dents_and_predators(self, frame, sub, dent_lines, important_dents[frame], threshold_distance)

    return important_dents

# Run the prototype above with plotting enabled; the returned dict is not
# assigned here (as the last expression of a cell it is only echoed).
filter_important_dents(fish_school_data, dents, threshold_distance=50, visualize=True)

In [None]:
import numpy as np
# Inline variant of the dent filter: for every frame, collect the
# (frame, predator ID, dent) tuples whose dent comes within
# `threshold_distance` of a predator, printing each match.
# NOTE: this rebinds the notebook-level name `important_dents`.
important_dents = {}
print("Filtering important dents...")
threshold_distance = 50
for frame, dent_lines in tqdm(dents.items(), desc="Frames"):
    # Predators recorded for this frame (copied so a column can be added).
    sub = fish_school_data.predators[fish_school_data.predators['frame'] == frame].copy()
    important_dents[frame] = []
    for dent in dent_lines:
        #print(np.array(dent.coords.xy).reshape(2, -1).T)
        # `coords.xy` yields the x values and the y values separately;
        # reshape + transpose turns them into an (n_points, 2) float array.
        dent_coords = np.array(dent.coords.xy).reshape(2, -1).T.astype(float)

        # One entry per predator: distance to the closest point of this dent.
        distances = []

        for _, predator in sub.iterrows():
            predator_coords = predator[['X', 'Y']].values.astype(float)
            # Add a leading axis -> shape (1, 2) so it broadcasts over the dent points.
            predator_coords = predator_coords[None, :]
            assert predator_coords.shape[-1] == 2 and len(predator_coords.shape) == 2

            v = dent_coords - predator_coords
            #print(type(v), v.shape, v.dtype, np.linalg.norm(v, axis=1))

            # Ensure the shape of `v` is appropriate for summing along axis 1
            if v.shape[1] != 2:
                raise ValueError("Unexpected shape for `v`: {}".format(v.shape))

            distances.append(np.linalg.norm(v, axis=1).min())
        # Overwritten for every dent: holds the distances to the current dent only.
        sub["distance"] = distances
        #print(sub.distance)

        #sub["distance"] = np.linalg.norm(sub[['X', 'Y']].values - dent_coords, axis=1)
        # Record one tuple per predator that is close enough to this dent.
        for _, predator in sub[sub["distance"] <= threshold_distance].iterrows():
            o = (frame, predator.ID, dent)
            important_dents[frame].append(o)
            print(o)

#important_dents

In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm

def plot_dents_and_predators_combined(dents, fish_school_data, threshold_distance=50):
    """Show one figure per frame with all dents and the predators around them.

    Args:
        dents: Mapping frame -> iterable of dents; each dent must expose
            ``.coords.xy``.
        fish_school_data: Object providing ``predators`` (DataFrame with
            'frame', 'X' and 'Y' columns).
        threshold_distance: A predator is highlighted in red when its
            distance to the closest point of *any* dent of the frame is at
            most this value; all other predators are drawn in green.

    The distance is the minimum over all dents of the frame, so the result
    does not depend on the order of the dents. Frames without dents are
    drawn with every predator unhighlighted.
    """
    print("Filtering important dents...")
    for frame, dent_lines in tqdm(dents.items(), desc="Frames"):
        sub = fish_school_data.predators[fish_school_data.predators['frame'] == frame].copy()
        predator_xy = sub[['X', 'Y']].values.astype(float)

        # Closest dent distance found so far, per predator; infinity means
        # "no dent seen", which never passes the threshold test below.
        nearest = np.full(predator_xy.shape[0], np.inf)

        fig, ax = plt.subplots()
        seen_labels = set()

        def legend_label(text):
            """Return *text* on first use and "" afterwards (one legend entry each)."""
            if text in seen_labels:
                return ""
            seen_labels.add(text)
            return text

        # Plot all dents (outlines)
        for dent in dent_lines:
            dent_coords = np.array(dent.coords.xy).reshape(2, -1).T.astype(float)
            ax.plot(dent_coords[:, 0], dent_coords[:, 1], 'bo-', label=legend_label('Dent Outline'))

            if predator_xy.shape[0] > 0 and dent_coords.shape[0] > 0:
                # (n_predators, n_points, 2) differences -> closest point per predator.
                deltas = dent_coords[None, :, :] - predator_xy[:, None, :]
                to_this_dent = np.linalg.norm(deltas, axis=2).min(axis=1)
                nearest = np.minimum(nearest, to_this_dent)

        sub["distance"] = nearest

        # Highlight important areas (predators within threshold distance)
        for _, predator in sub.iterrows():
            if predator["distance"] <= threshold_distance:
                ax.plot(predator['X'], predator['Y'], 'ro', label=legend_label('Important Predator'))
            else:
                ax.plot(predator['X'], predator['Y'], 'go', label=legend_label('Predator'))

        ax.set_title(f'Frame {frame}')
        ax.set_xlabel('X Coordinate')
        ax.set_ylabel('Y Coordinate')
        ax.legend()
        plt.show()

# Requires `dents` and `fish_school_data` to be defined by earlier cells;
# the call opens one figure per frame contained in `dents`.
plot_dents_and_predators_combined(dents, fish_school_data, threshold_distance=50)