In [None]:
# Import necessary libraries
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Load the visual-field archive and the kinematics archive.
# Both files live in the same directory; point DATA_DIR somewhere else to
# run the notebook against another copy of the data.
DATA_DIR = Path('/Users/tristan/Videos/data')
data = np.load(DATA_DIR / 'group_1_visual_field_fish0.npz')
kinematics = np.load(DATA_DIR / 'group_1_fish0.npz')

# Identity and scale are read from the first entry of their arrays.
my_id = kinematics['id'][0]
cm_per_pixel = kinematics['cm_per_pixel'][0]

# Dividing by cm_per_pixel presumably converts X/Y from cm to pixels — TODO confirm units.
headX, headY, headFrame = kinematics['X'] / cm_per_pixel, kinematics['Y'] / cm_per_pixel, kinematics['frame']
print("kinematics: ", kinematics.files, " my_id=", my_id, " cm_per_pixel=", cm_per_pixel)

# Extract the data arrays
depth = data['depth']          # Shape: (len_frames, 2, layers, rays_per_layer)
eye_pos = data['eye_pos']      # Shape: (len_frames, 2, 2)
eye_angle = data['eye_angle']  # Shape: (len_frames, 2)
fish_pos = data['fish_pos']    # Shape: (len_frames, 2)
fish_angle = data['fish_angle']  # Shape: (len_frames,)
frames = data['frames']        # Shape: (len_frames,)
fov_range = data['fov_range']  # Shape: (2,)
# Presumably (len_frames, 2, layers, rays_per_layer): the rendering loop
# indexes `ids` with four indices (frame, eye, layer, ray mask) — TODO confirm.
ids = data['ids']
visible_points = data['visible_points']  # Shape: (len_frames, 2, layers, rays_per_layer, 2)

# Constants
# Sentinel for "no hit": depth entries that are not below this value, and
# visible points whose x equals it, are skipped when rendering.
invalid_value = np.finfo(np.float32).max
fov_start, fov_end = fov_range
len_frames, num_eyes, num_layers, rays_per_layer = depth.shape
vres = num_layers * rays_per_layer  # Total number of rays per eye

# Frame range to render; both ends are inclusive.
frame_start = 0      # First frame number
frame_end = 200      # Last frame number

# Positions inside `frames` whose frame number lies in the requested range
frame_indices = np.where((frames >= frame_start) & (frames <= frame_end))[0]

# Video parameters
video_width = 3700    # Video width in pixels
video_height = 3700   # Video height in pixels
fps = 30              # Frames per second

# Initialize VideoWriter
fourcc = cv2.VideoWriter_fourcc(*'MJPG')  # Codec
out = cv2.VideoWriter('visual_field.mp4', fourcc, fps, (video_width, video_height))

# cv2.VideoWriter does not raise when the codec/container cannot be opened;
# every later write() would silently do nothing. Fail loudly instead.
if not out.isOpened():
    raise RuntimeError(
        "cv2.VideoWriter could not open 'visual_field.mp4' "
        f"(fourcc='MJPG', size={video_width}x{video_height}, fps={fps})"
    )

# Black canvas that is copied for every rendered frame
background = np.zeros((video_height, video_width, 3), dtype=np.uint8)

# Render every selected frame into the video, then finalize the file.
#
# Per frame and eye, hit points are drawn twice so the two sources can be
# compared visually:
#   1. reconstructed from `depth` (eye position + rotated ray * distance),
#   2. read directly from `visible_points`.
# Only layer 0 of both arrays is used.

# The ray directions depend only on the field-of-view limits and the number of
# rays, so they are built once instead of once per frame and eye.
angles = np.linspace(fov_start, fov_end, rays_per_layer)
directions = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # Shape: (rays_per_layer, 2)


def _pixels_inside_frame(positions):
    """Cast (N, 2) positions to int32 pixels and keep those inside the video.

    Returns a list of (x, y) tuples of plain Python ints, in input order,
    ready to be handed to the OpenCV drawing functions.
    """
    pixels = positions.astype(np.int32)
    inside = ((pixels[:, 0] >= 0) & (pixels[:, 0] < video_width) &
              (pixels[:, 1] >= 0) & (pixels[:, 1] < video_height))
    return [(x, y) for x, y in pixels[inside].tolist()]


try:
    for idx in tqdm(frame_indices):
        # Start from a fresh black canvas for this frame
        frame_img = background.copy()
        frame_ids = []

        # Process each eye (0: left eye, 1: right eye)
        for eye_idx in range(2):
            eye_position = eye_pos[idx, eye_idx]       # Shape: (2,)
            eye_orientation = eye_angle[idx, eye_idx]  # Scalar

            # Depth values of layer 0; entries not below the sentinel are "no hit"
            depth_values = depth[idx, eye_idx, 0, :]   # Shape: (rays_per_layer,)
            valid_mask = depth_values < invalid_value

            # Assuming depth_values are squared distances
            depth_valid = np.sqrt(depth_values[valid_mask])
            directions_valid = directions[valid_mask]

            # Rotate the ray directions by the eye orientation
            rotation_matrix = np.array([[np.cos(eye_orientation), -np.sin(eye_orientation)],
                                        [np.sin(eye_orientation),  np.cos(eye_orientation)]])
            rotated_directions = directions_valid @ rotation_matrix.T

            # First method: hit point = eye position + direction * distance
            hit_points = eye_position + rotated_directions * depth_valid[:, np.newaxis]  # Shape: (N, 2)

            color = (255, 0, 0) if eye_idx == 0 else (0, 0, 255)  # Color for each eye
            eye_pixel = tuple(eye_position.astype(np.int32).tolist())
            for pixel in _pixels_inside_frame(hit_points):
                # Ray from the eye to the hit point, plus the hit point itself
                cv2.line(frame_img, eye_pixel, pixel, color, 1)
                cv2.circle(frame_img, pixel, 1, color, -1)

            # Second method: use the stored visible_points directly
            hit_points = visible_points[idx, eye_idx, 0].reshape(-1, 2)
            valid_points = hit_points[hit_points[:, 0] != invalid_value]

            color = (255, 255, 0) if eye_idx == 0 else (255, 0, 255)  # Different color
            for pixel in _pixels_inside_frame(valid_points):
                cv2.circle(frame_img, pixel, 1, color, -1)

            # IDs hit by the rays that have a valid depth (same mask as method 1)
            valid_ids = ids[idx, eye_idx, 0, valid_mask]
            frame_ids.append(valid_ids)

        # Unique IDs seen by either eye, without the observing fish itself
        frame_ids = np.unique(np.concatenate(frame_ids))
        frame_ids = frame_ids[frame_ids != my_id]

        # Draw eye positions (green), skipping eyes outside the frame
        for eye_idx in range(2):
            for pixel in _pixels_inside_frame(eye_pos[idx, eye_idx].reshape(1, -1)):
                cv2.circle(frame_img, pixel, 5, (0, 255, 0), -1)

        # Draw fish head position (white).
        # NOTE(review): headX/headY come from the kinematics file but are indexed
        # with the visual-field index `idx`; this assumes both files list the same
        # frames in the same order (headFrame is never compared with `frames`) —
        # TODO confirm.
        head_xy = np.array((headX[idx], headY[idx]), dtype=np.float64)
        # A non-finite position cannot be converted to a meaningful pixel; skip it.
        if np.isfinite(head_xy).all():
            head_position = head_xy.astype(np.int32)
            cv2.circle(frame_img, tuple(head_position.tolist()), 5, (255, 255, 255), -1)

        cv2.putText(frame_img, f'Fish ID: {my_id}', (50, 50),
                    cv2.FONT_HERSHEY_DUPLEX, 1, (255, 255, 255), 2)
        # Add frame number text
        cv2.putText(frame_img, f'Frame {frames[idx]}', (50, 80),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
        # Add visible IDs text
        cv2.putText(frame_img, f'Visible IDs: {frame_ids}', (50, 110),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

        # Write the frame to the video
        out.write(frame_img)

        # Optionally, display the frame every 100 frames
        #if idx % 100 == 0:
        #    print(f'Processed frame {idx}/{len(frame_indices)}')
        #    plt.imshow(cv2.cvtColor(frame_img, cv2.COLOR_BGR2RGB))
        #    plt.axis('off')
        #    plt.show()
finally:
    # Release the VideoWriter even when a frame fails, so the file is
    # finalized and the handle is not left open in the kernel.
    out.release()

print('Video saved as visual_field.mp4')

In [None]:
!ffmpeg -i visual_field.mp4 -c:v h264_videotoolbox -profile:v high -crf 20 -pix_fmt yuv420p visual_field_compressed.mp4 -y

In [None]:

from IPython.display import Video

# Build the embedded player for the re-encoded clip; leaving it as the last
# expression of the cell makes the notebook render it as the cell output.
compressed_clip = Video('visual_field_compressed.mp4', embed=True)
compressed_clip

In [None]:
import glob
import numpy as np

# Show which history archives of group 1 exist in the video folder.
history_paths = glob.glob("/Users/tristan/Videos/group_1_*_history.npz")
print(history_paths)

# Open one specific archive, list the arrays it stores, and print its
# "uniquenesses" array. The context manager closes the file afterwards.
with np.load("/Users/tristan/Videos/group_1_weights_-2_history.npz") as history:
    print(history.files)
    print(history["uniquenesses"])