In [None]:
import os
import time
import threading
from collections import deque
from enum import Enum
import winsound

import cv2
import numpy as np
import mediapipe as mp
from ultralytics import YOLO
import serial
import serial.tools.list_ports

from local_landmark import LocalLandmark
from realsense_camera import RealSenseCamera
from yolo_object import YoloObject
from hand_helper import HandHelper
from object_tracker import ObjectTracker

In [None]:
# Sampling rate and duration of each recorded action sequence.
SAMPLING_FPS = 10
SAMPLING_SECONDS = 1
# Number of frames per sequence. round() rather than a bare int() so float durations
# (e.g. 0.57 s at 100 FPS = 56.999... frames) are not truncated one frame short.
sequence_length = int(round(SAMPLING_SECONDS*SAMPLING_FPS))
# Seconds between two consecutive sampled frames.
frame_interval = 1.0/SAMPLING_FPS

# Root folder for saved data: DATA_FOLDER/<action>/sequence_<n>/{landmarks,objects}/frame_<i>.npy
DATA_FOLDER = "sample_data"

In [None]:
# Frame size requested from the camera; also used to convert MediaPipe's normalised
# landmark coordinates into pixel coordinates.
IMAGE_WIDTH = 640
IMAGE_HEIGHT = 480

# Pairs of hand-landmark indices joined by a line when drawing a hand skeleton.
HAND_CONNECTIONS = ((0, 1), (0, 5), (9, 13), (13, 17), (5, 9), (0, 17), (1, 2), (2, 3), (3, 4), (5, 6), (6, 7), (7, 8),
                    (9, 10), (10, 11), (11, 12), (13, 14), (14, 15), (15, 16), (17, 18), (18, 19), (19, 20))

# YOLO class id -> human-readable object name.
OBJECT_NAMES = {
    0: "small_screw",
    1: "big_screw",
    2: "small_wrench",
    3: "big_wrench",
    4: "cap",
    5: "barrel",
    6: "piston",
    7: "support",
    8: "air_connector",
    9: "nut"
}

# One random colour (3 floats in [0, 255)) per object class. A fixed seed keeps the
# colours identical across kernel restarts instead of changing on every run.
label_colors = [color*255 for color in np.random.default_rng(0).random((len(OBJECT_NAMES), 3))]

In [None]:
# Location of the fine-tuned YOLOv9c detector weights.
yolo_model_path = os.path.join("weights", "yolov9c_fine_tuned.pt")

# MediaPipe Hands solution module; a Hands instance is opened later as a context manager.
mp_hands = mp.solutions.hands

# The object detector is loaded once here and reused for every frame.
yolo_model = YOLO(yolo_model_path)

In [None]:
class Action(Enum):
    """Action classes handled by this notebook.

    The lower-cased member name is used as the sub-folder of DATA_FOLDER in which the
    sequences of that action are stored (e.g. Action.SCREW_WRENCH -> "screw_wrench").
    """

    IDLE = 0
    PICK = 1
    PLACE = 2
    SCREW_WRENCH = 3

In [None]:
def process_landmarks(frame, model):
    """Run a MediaPipe solution on a BGR frame and return its raw results.

    The frame is converted to RGB into a new array (the caller's image is not modified),
    marked read-only, and passed to `model.process`.
    """
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # MediaPipe's examples mark the input non-writeable so it can be passed by reference.
    rgb_frame.flags.writeable = False
    return model.process(rgb_frame)

def extract_hand_landmarks(hand_index, multi_hand_landmarks, multi_handedness, selected_landmarks_indices, extract_depth=False, camera=None):
    """Collect the selected landmarks of one hand from MediaPipe Hands results.

    Args:
        hand_index: handedness class index to look for (0 = right, 1 = left, as used by this notebook).
        multi_hand_landmarks: `results.multi_hand_landmarks` from MediaPipe (may be None).
        multi_handedness: `results.multi_handedness` from MediaPipe (may be None).
        selected_landmarks_indices: landmark indices to extract, in output order.
        extract_depth: when True, sample a depth value for each landmark via `camera.get_depth`.
        camera: object exposing `get_depth(x, y)`; required when `extract_depth` is True.

    Returns:
        A list with one LocalLandmark per selected index. If the requested hand is not
        detected, every entry is a placeholder LocalLandmark(0, 0, 0, 0).
    """
    hand_landmarks = None
    if multi_hand_landmarks and multi_handedness:
        for candidate_landmarks, handedness in zip(multi_hand_landmarks, multi_handedness):
            # First hand whose top classification matches wins; the score is not
            # thresholded (a 0.7 threshold was left disabled).
            if handedness.classification[0].index == hand_index:
                hand_landmarks = candidate_landmarks
                break

    if hand_landmarks is None:
        return [LocalLandmark(0, 0, 0, 0) for _ in selected_landmarks_indices]

    hand_landmarks_list = []
    for i in selected_landmarks_indices:
        lm = LocalLandmark.from_mediapipe_hand_landmark(hand_landmarks.landmark[i])
        if extract_depth:
            # Normalised coordinates can fall outside [0, 1] (e.g. a hand partially out of
            # frame), and exactly 1.0 maps to IMAGE_WIDTH/IMAGE_HEIGHT, one past the last
            # pixel. Clamp to a valid pixel before sampling depth.
            pixel_x = min(max(int(lm.x*IMAGE_WIDTH), 0), IMAGE_WIDTH - 1)
            pixel_y = min(max(int(lm.y*IMAGE_HEIGHT), 0), IMAGE_HEIGHT - 1)
            lm.set_depth(camera.get_depth(pixel_x, pixel_y))
        hand_landmarks_list.append(lm)

    return hand_landmarks_list

def draw_landmarks(frame, landmarks, side, connections):
    """Draw one hand's landmarks and skeleton onto `frame` in place.

    Args:
        frame: BGR image to draw on.
        landmarks: landmarks with normalised `x`/`y` and an `is_empty()` method.
            Nothing is drawn when every landmark is empty.
        side: 0 = right (red), 1 = left (blue); any other value draws nothing.
        connections: (start, end) landmark-index pairs joined by white lines. A pair is
            skipped unless both indices refer to a provided landmark.
    """
    if all(lm.is_empty() for lm in landmarks):
        return

    # OpenCV colours are BGR.
    side_colors = {0: (0, 0, 255), 1: (255, 0, 0)}
    if side not in side_colors:
        return
    landmarks_color = side_colors[side]

    points = [(int(lm.x*IMAGE_WIDTH), int(lm.y*IMAGE_HEIGHT)) for lm in landmarks]
    for point in points:
        cv2.circle(frame, point, 2, landmarks_color, 2)

    n_points = len(points)
    for start, end in connections:
        # Check both endpoints, so any ordering of the pair is safe with a partial landmark list.
        if start < n_points and end < n_points:
            cv2.line(frame, points[start], points[end], (255, 255, 255), 2)

In [None]:
def get_landmarks_from_flattened_array(flattened_landmarks):
    """Split a flattened two-hand landmark vector back into per-hand landmark lists.

    The layout matches what the acquisition loop saves: 21 right-hand landmarks followed by
    21 left-hand landmarks, each read as 4 consecutive values via LocalLandmark.from_np_array.

    Returns:
        (right_hand_landmarks, left_hand_landmarks): two lists of 21 LocalLandmark each.
    """
    values_per_landmark = 4
    landmarks_per_hand = 21

    def read_hand(first_landmark):
        # Landmark k of this hand occupies one fixed-width slice of the flat vector.
        return [
            LocalLandmark.from_np_array(
                flattened_landmarks[(first_landmark + k)*values_per_landmark:(first_landmark + k + 1)*values_per_landmark]
            )
            for k in range(landmarks_per_hand)
        ]

    right_hand_landmarks = read_hand(0)
    left_hand_landmarks = read_hand(landmarks_per_hand)
    return right_hand_landmarks, left_hand_landmarks

In [None]:
def _render_sequence_frame(data_folder_path, frame_index):
    """Draw the saved landmarks and in-hand objects of one frame onto a black image and return it."""
    frame = np.zeros(shape=(IMAGE_HEIGHT, IMAGE_WIDTH, 3), dtype=np.uint8)

    frame_landmarks_data = np.load(os.path.join(data_folder_path, "landmarks", f"frame_{frame_index}.npy"))
    right_hand_landmarks, left_hand_landmarks = get_landmarks_from_flattened_array(frame_landmarks_data)
    draw_landmarks(frame, right_hand_landmarks, side=0, connections=HAND_CONNECTIONS)
    draw_landmarks(frame, left_hand_landmarks, side=1, connections=HAND_CONNECTIONS)

    frame_objects_data = np.load(os.path.join(data_folder_path, "objects", f"frame_{frame_index}.npy"))
    for frame_object_data in frame_objects_data:
        # Rows are (label_id, hand, x1, y1, x2, y2) with hand 0 = right, 1 = left.
        label_id, in_hand, x1, y1, x2, y2 = map(int, frame_object_data)
        hand_color = (0, 0, 255) if in_hand == 0 else (255, 0, 0)
        cv2.rectangle(frame, (x1, y1), (x2, y2), hand_color, 2)
        cv2.putText(frame, f"{OBJECT_NAMES[label_id]}, in_hand: {in_hand}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, hand_color, 1)

    cv2.putText(frame, str(frame_index), (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
    return frame


def play_video_from_sequence_data(data_folder_path):
    """Replay a saved sequence (hand landmarks + in-hand objects) at SAMPLING_FPS.

    Reads `landmarks/frame_<i>.npy` and `objects/frame_<i>.npy` under `data_folder_path`
    for i in range(sequence_length), as written by save_acquired_data. Press 'q' to stop early.
    All OpenCV windows are closed at the end.
    """
    frame_interval = 1.0/SAMPLING_FPS

    for frame_index in range(sequence_length):
        frame_display_time = time.time()
        frame = _render_sequence_frame(data_folder_path, frame_index)
        cv2.imshow("Video from data", frame)

        # Spend the remainder of the frame slot inside waitKey rather than time.sleep, so the
        # window is actually repainted and key presses are handled while we wait.
        time_to_wait = frame_interval - (time.time() - frame_display_time)
        delay_ms = max(1, int(time_to_wait*1000))
        if cv2.waitKey(delay_ms) & 0xFF == ord('q'):
            break

    cv2.destroyAllWindows()

In [None]:
def save_acquired_data(data_folder_path, frames_landmarks, frames_objects):
    """Save one recorded sequence as per-frame .npy files, then replay it for visual checking.

    Writes `landmarks/frame_<i>.npy` and `objects/frame_<i>.npy` under `data_folder_path`,
    creating the folders as needed, then calls play_video_from_sequence_data on the result.

    Args:
        data_folder_path: destination folder for this sequence.
        frames_landmarks: per-frame flattened landmark arrays.
        frames_objects: per-frame arrays of in-hand object rows.
    """
    landmarks_folder_path = os.path.join(data_folder_path, "landmarks")
    objects_folder_path = os.path.join(data_folder_path, "objects")
    # exist_ok=True replaces the racy "check, then create" pattern.
    os.makedirs(landmarks_folder_path, exist_ok=True)
    os.makedirs(objects_folder_path, exist_ok=True)

    print(f"Saving sequence data in {data_folder_path}")
    for folder_path, frames in ((landmarks_folder_path, frames_landmarks), (objects_folder_path, frames_objects)):
        for i, frame_data in enumerate(frames):
            np.save(os.path.join(folder_path, f"frame_{i}.npy"), frame_data)
    print("Sequence data successfully saved", end="\n\n")

    # Short pause before opening the replay window.
    time.sleep(0.5)
    play_video_from_sequence_data(data_folder_path)

In [None]:
def arduino_connection_thread(recorder):
    """Drive a Recorder from line messages sent by an Arduino over serial (run in a background thread).

    Messages handled (one per line, CRLF-terminated):
      - "white": beep, announce, wait one second, beep again, then start recording.
      - "red": step the recorder's sequence number back by one (never below 0).
      - "pedal_low": stop listening.
    Any other line, including empty reads when the 0.1 s read timeout expires, is ignored.

    The serial port is always closed. recorder.terminate_recording() is always called on exit,
    including when no "Genuino Uno" port is found or the port cannot be opened, so the main
    acquisition loop is never left waiting forever.
    """
    try:
        com_port = None
        for port, desc, _ in serial.tools.list_ports.comports():
            if "Genuino Uno" in desc:
                com_port = port  # last matching port wins, as before

        if com_port is None:
            print("No 'Genuino Uno' serial port found.")
            return

        with serial.Serial(port=com_port, baudrate=9600, timeout=0.1) as conn:
            while True:
                line = conn.readline()
                if line == b"white\r\n":
                    winsound.Beep(500, 500)
                    print(f"Start acquiring data for sequence {recorder.sequence_number} in one second...")
                    time.sleep(1)
                    winsound.Beep(700, 200)
                    recorder.start_recording()
                elif line == b"red\r\n":
                    recorder.sequence_number = max(0, recorder.sequence_number - 1)
                    print(f"Back to sequence number {recorder.sequence_number}")
                elif line == b"pedal_low\r\n":
                    break
    finally:
        recorder.terminate_recording()
        print("Stopping acquisition.")

In [None]:
class Recorder:
    """Recording state shared between the Arduino control thread and the acquisition loop.

    Attributes:
        action_folder: folder in which the sequences of the current action are saved.
        sequence_number: number used for the next saved `sequence_<n>` folder.
        acquiring: True while a sequence is being recorded.
        terminate: set to True to ask the acquisition loop to exit.
        first_frame_acquisition_time: timestamp of the first sampled frame, 0 before it.
        frame_index: number of frames sampled so far in the current sequence.
        frames_landmarks, frames_objects: per-frame buffers bounded to sequence_length.
    """

    def __init__(self, action_folder, sequence_number):
        self.action_folder = action_folder
        self.sequence_number = sequence_number
        self.terminate = False
        self.acquiring = False
        self.frames_landmarks = deque(maxlen=sequence_length)
        self.frames_objects = deque(maxlen=sequence_length)
        self._reset_sequence_state()

    def _reset_sequence_state(self):
        """Rewind the per-sequence counters and empty both frame buffers."""
        self.first_frame_acquisition_time = 0
        self.frame_index = 0
        self.frames_landmarks.clear()
        self.frames_objects.clear()

    def start_recording(self):
        """Begin a fresh sequence: flag acquisition on and discard any previous frames."""
        self.acquiring = True
        self._reset_sequence_state()

    def stop_recording(self):
        """Stop acquiring frames; already-buffered frames are left untouched."""
        self.acquiring = False

    def terminate_recording(self):
        """Ask the acquisition loop to shut down."""
        self.terminate = True

In [None]:
ACQUIRE = False
ACTION = Action.SCREW_WRENCH.name.lower()


def _next_sequence_number(action_folder):
    """Return one past the highest N among `<prefix>_N...` entries of `action_folder` (0 if none).

    Entries whose second '_'-separated field is not an integer (e.g. `.ipynb_checkpoints`,
    `desktop.ini`) are ignored instead of crashing the cell.
    """
    if not os.path.isdir(action_folder):
        return 0
    sequence_numbers = []
    for sequence_name in os.listdir(action_folder):
        try:
            sequence_numbers.append(int(sequence_name.split('_')[1]))
        except (IndexError, ValueError):
            continue  # not a sequence_<N> folder
    return max(sequence_numbers, default=-1) + 1


def _draw_in_hand_objects(image, tracker):
    """Annotate `image` in place with the object currently tracked in each hand.

    The per-hand status lines start at y=60 so they do not overlap the FPS counter drawn at y=30.
    """
    tracked_objects = (tracker.right_hand_tracked_object, tracker.left_hand_tracked_object)
    for hand_side, tracked_object in enumerate(tracked_objects):
        if tracked_object is None:
            continue
        hand_color = (0, 0, 255) if hand_side == 0 else (255, 0, 0)
        yolo_object = tracked_object.yolo_object
        object_name = OBJECT_NAMES[yolo_object.label_id]
        cv2.putText(image, f"{object_name}, visible: {tracked_object.is_visible}", (10, 60 + hand_side*30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, hand_color, 1)
        if tracked_object.is_visible:
            cv2.rectangle(image, (yolo_object.x1, yolo_object.y1), (yolo_object.x2, yolo_object.y2), hand_color, 2)
            cv2.putText(image, f"id: {tracked_object.tracker_id}, {object_name}, conf: {yolo_object.conf:.2f}",
                        (yolo_object.x1, yolo_object.y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, hand_color, 1)


def _in_hand_objects_array(tracker):
    """Return an array with one row (label_id, hand, x1, y1, x2, y2) per in-hand object (hand: 0 = right, 1 = left)."""
    rows = []
    tracked_objects = (tracker.right_hand_tracked_object, tracker.left_hand_tracked_object)
    for hand_side, tracked_object in enumerate(tracked_objects):
        if tracked_object is None:
            continue
        yolo_object = tracked_object.yolo_object
        rows.append(np.array((yolo_object.label_id, hand_side, yolo_object.x1, yolo_object.y1, yolo_object.x2, yolo_object.y2)))
    return np.array(rows)


if ACQUIRE:
    action_folder = os.path.join(DATA_FOLDER, ACTION)
    sequence_number = _next_sequence_number(action_folder)

    recorder = Recorder(action_folder, sequence_number)
    threading.Thread(target=arduino_connection_thread, args=(recorder,)).start()

camera = RealSenseCamera(image_width=IMAGE_WIDTH, image_height=IMAGE_HEIGHT)
camera.connect()

prev_frame_time = 0
new_frame_time = 0

hand_helper = HandHelper(image_width=IMAGE_WIDTH, image_height=IMAGE_HEIGHT)
object_tracker = ObjectTracker(image_width=IMAGE_WIDTH, image_height=IMAGE_HEIGHT)

try:
    with mp_hands.Hands(max_num_hands=2, min_detection_confidence=0.5, min_tracking_confidence=0.5) as hands:
        while True:
            if ACQUIRE and recorder.terminate:
                break
            ret = camera.acquire_frame()
            if not ret:
                break

            color_image = camera.color_image

            hands_results = process_landmarks(color_image, hands)
            object_results = yolo_model(color_image, verbose=False)[0]

            right_hand_landmarks = extract_hand_landmarks(0, hands_results.multi_hand_landmarks, hands_results.multi_handedness, list(range(21)), extract_depth=True, camera=camera)
            left_hand_landmarks = extract_hand_landmarks(1, hands_results.multi_hand_landmarks, hands_results.multi_handedness, list(range(21)), extract_depth=True, camera=camera)
            hand_helper.register_hands_landmarks(right_hand_landmarks, left_hand_landmarks)
            tips_midpoints = hand_helper.get_tips_midpoints()

            # Only confident detections are passed to the tracker.
            seen_yolo_objects = [YoloObject.from_yolo_box_result(box) for box in object_results.boxes if box.conf.item() > 0.75]

            object_tracker.register_seen_objects(seen_yolo_objects, tips_midpoints)
            object_tracker.increment_frame_index()

            # Live preview (only when not acquiring).
            if not ACQUIRE:
                new_frame_time = time.time()
                # Guard against a zero interval when the clock resolution is coarse.
                fps = 1.0/max(new_frame_time - prev_frame_time, 1e-6)
                prev_frame_time = new_frame_time
                cv2.putText(color_image, f"{fps:.2f}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2)

                draw_landmarks(color_image, right_hand_landmarks, side=0, connections=HAND_CONNECTIONS)
                draw_landmarks(color_image, left_hand_landmarks, side=1, connections=HAND_CONNECTIONS)
                _draw_in_hand_objects(color_image, object_tracker)

                cv2.imshow("RealSense", color_image)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

            # Frame data acquisition: sample one frame every frame_interval seconds.
            if ACQUIRE and recorder.acquiring:
                # Frame k is due at first_frame_acquisition_time + k*frame_interval; `lateness` is
                # how far past that slot we are (negative = too early). Before frame 0 is taken
                # first_frame_acquisition_time is 0, so lateness is large and frame 0 is taken at once.
                lateness = time.time() - recorder.first_frame_acquisition_time - recorder.frame_index*frame_interval
                if recorder.first_frame_acquisition_time > 0 and lateness > frame_interval:
                    # A whole sampling slot passed without a usable camera frame.
                    print("Video discarded for too many frames skipped")
                    recorder.stop_recording()
                elif lateness >= 0:
                    hands_landmarks = right_hand_landmarks + left_hand_landmarks
                    flattened_landmarks = np.concatenate([lm.get_np_array() for lm in hands_landmarks])
                    recorder.frames_landmarks.append(flattened_landmarks)
                    recorder.frames_objects.append(_in_hand_objects_array(object_tracker))

                    if recorder.frame_index == 0:
                        recorder.first_frame_acquisition_time = time.time()

                    recorder.frame_index += 1
                    if recorder.frame_index >= sequence_length:
                        data_folder_path = os.path.join(recorder.action_folder, f"sequence_{recorder.sequence_number}")
                        # Save (and replay) in a background thread so camera sampling is not blocked.
                        threading.Thread(target=save_acquired_data, args=(data_folder_path, list(recorder.frames_landmarks), list(recorder.frames_objects))).start()
                        recorder.stop_recording()
                        recorder.sequence_number += 1

finally:
    cv2.destroyAllWindows()
    camera.disconnect()


In [None]:
# Replay one saved sequence for inspection.
# NOTE(review): the original "_mirrored" hint suggests mirrored copies may exist as
# "sequence_30_mirrored" — append that suffix to replay one (confirm such folders exist).
data_folder_path = os.path.join(
    DATA_FOLDER,
    Action.PICK.name.lower(),
    "sequence_30",
)
play_video_from_sequence_data(data_folder_path)

In [None]:
# Clean-up: close any OpenCV windows still open from previous cells.
cv2.destroyAllWindows()