In [6]:
import time
import cv2
import numpy as np
import torch
import torch.nn.functional as F
import onnxruntime as ort
from collections import deque
import concurrent.futures
from rtmlib import Wholebody, draw_skeleton, Body  

def _keypoints_to_xy(keypoints, num_joints: int) -> np.ndarray:
    """Convert one frame's detector output into a float32 (num_joints, 2) array.

    Returns all zeros when there is nothing usable: no detected instance or
    an array layout other than (joints, >=2).
    """
    xy = np.zeros((num_joints, 2), dtype=np.float32)
    kp = np.array(keypoints, dtype=np.float32)
    if kp.ndim == 3:
        # Presumably the leading axis indexes detected people; only the first
        # one is kept. The axis is empty when nobody was detected, in which
        # case indexing [0] would raise, so report "no skeleton" instead.
        if kp.shape[0] == 0:
            return xy
        kp = kp[0]
    if kp.ndim != 2 or kp.shape[1] < 2:
        return xy
    # Keep only x, y and at most num_joints rows; missing joints stay zero.
    kp = kp[:num_joints, :2]
    xy[:kp.shape[0]] = kp
    return xy


def extract_skeleton_from_batch(
    frames_batch,
    skeleton_detector,
    num_joints: int = 17,
    normalize: bool = True
) -> np.ndarray:
    """
    Extract key points for a batch of frames and optionally normalize them.

    Args:
        frames_batch: iterable of frames; each one is passed to
            ``skeleton_detector`` unchanged.
        skeleton_detector: callable returning ``(keypoints, scores)`` for a
            frame. Only ``keypoints`` is used.
        num_joints: number of joints per skeleton in the result; detections
            with more joints are truncated, with fewer are zero-padded.
        normalize: when True, min-max scale x and y independently to [0, 1]
            using the extrema over the whole batch (all frames, all joints).

    Returns:
        Array of shape (T, num_joints, 2) where T is the number of frames.
        Frames on which detection failed or found nobody are all zeros (and
        take part in the normalization as such). An empty batch yields an
        array of shape (0, num_joints, 2).
    """
    skeletons = []
    for frame in frames_batch:
        try:
            keypoints, _scores = skeleton_detector(frame)
        except Exception:
            # Best effort: a frame the detector cannot handle must not abort
            # the whole batch, it simply contributes an empty skeleton.
            kp = np.zeros((num_joints, 2), dtype=np.float32)
        else:
            kp = _keypoints_to_xy(keypoints, num_joints)
        skeletons.append(kp)

    if not skeletons:
        return np.zeros((0, num_joints, 2), dtype=np.float32)

    seq = np.stack(skeletons, axis=0)  # (T, V, 2)
    if normalize:
        mins = seq.min(axis=(0, 1))
        maxs = seq.max(axis=(0, 1))
        # Avoid division by zero when a coordinate is constant over the batch.
        diff = np.where(maxs - mins == 0, 1.0, maxs - mins)
        seq = (seq - mins) / diff
    return seq


def extract_skeletons_from_frames(
    frames,
    skeleton_detector,
    num_joints=17,
    batch_size=16,
    normalize=True,
    frame_skip=2,
    segment_length=8
) -> np.ndarray:
    """
    Build a fixed-length skeleton sequence from a list of frames.

    Frames are processed in chunks of ``batch_size`` through
    ``extract_skeleton_from_batch`` (so normalization, when enabled, is done
    per chunk), every ``frame_skip``-th skeleton is kept, and the result is
    trimmed or padded by repeating the last skeleton to ``segment_length``.

    Args:
        frames: sequence of frames (must support ``len`` and slicing).
        skeleton_detector: detector forwarded to ``extract_skeleton_from_batch``.
        num_joints: joints per skeleton.
        batch_size: number of frames handed to the extractor at once.
        normalize: forwarded to ``extract_skeleton_from_batch``.
        frame_skip: stride used to subsample the extracted skeletons.
        segment_length: number of skeletons in the returned sequence.

    Returns:
        float32 array of shape (segment_length, num_joints, 2). When ``frames``
        is empty the sequence is all zeros.

    Side effects:
        Prints the elapsed time.
    """
    start = time.time()
    skeletons = []
    n = len(frames)
    for i in range(0, n, batch_size):
        batch = frames[i: i + batch_size]
        sk = extract_skeleton_from_batch(batch, skeleton_detector, num_joints, normalize)
        skeletons.extend(sk)
    seq = skeletons[::frame_skip]
    if not seq:
        # No frames at all: there is no last skeleton to repeat, so pad with
        # an empty (all-zero) one instead of failing on seq[-1].
        seq = [np.zeros((num_joints, 2), dtype=np.float32)]
    # pad or trim
    if len(seq) < segment_length:
        seq = seq + [seq[-1]] * (segment_length - len(seq))
    else:
        seq = seq[:segment_length]
    print(f"extract_skeletons: {time.time() - start:.4f}s")
    return np.array(seq, dtype=np.float32)

# ===================== ONNX Runtime init =====================
# Model file and the execution providers the session is allowed to use.
onnx_path = "hybrid_gcn_timesformer_09.onnx"
providers = ['CPUExecutionProvider']

# Session options: request ONNX Runtime's "extended" graph optimization level.
sess_opts = ort.SessionOptions()
sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# Shared inference session, passed to the prediction helpers below.
ort_session = ort.InferenceSession(
    onnx_path,
    sess_options=sess_opts,
    providers=providers,
)

# ===================== Segment prediction via ONNX =====================
def predict_segment_onnx(
    segment_frames,
    skeleton_detector,
    ort_session,
    segment_length=8,
    num_joints=17,
    normalize=True,
    frame_skip=2,
    batch_size=16
):
    """
    Classify one segment of frames with the ONNX model.

    The frames are turned into a skeleton sequence by
    ``extract_skeletons_from_frames``; the sequence gets a leading batch axis
    and is fed to ``ort_session`` under the input name "x". The session is
    expected to return exactly two logit arrays, each of which is passed
    through a softmax over axis 1.

    Returns:
        ``(label_q, probs_q, label_c, probs_c)``: for each of the two heads,
        the argmax class index of the first batch item and the probability
        array. (The caller's naming suggests "q" is quality and "c" is the
        count/phase head; confirm against the model.)

    Side effects:
        Prints the elapsed time.
    """
    start = time.time()

    skeleton_seq = extract_skeletons_from_frames(
        segment_frames,
        skeleton_detector,
        num_joints=num_joints,
        batch_size=batch_size,
        normalize=normalize,
        frame_skip=frame_skip,
        segment_length=segment_length,
    )

    # Add the batch axis: (T, V, 2) -> (1, T, V, 2).
    model_input = skeleton_seq[np.newaxis, ...].astype(np.float32)
    logits_q, logits_c = ort_session.run(None, {"x": model_input})

    # Same softmax for both heads.
    probs_q, probs_c = (
        F.softmax(torch.from_numpy(logits), dim=1).numpy()
        for logits in (logits_q, logits_c)
    )
    label_q = int(np.argmax(probs_q, axis=1)[0])
    label_c = int(np.argmax(probs_c, axis=1)[0])

    elapsed = time.time() - start
    print(f"predict_segment_onnx: {elapsed:.4f}s")
    return label_q, probs_q, label_c, probs_c

# ===================== Helper for external callers =====================
def process_frame_segment_onnx(
    segment_frames,
    skeleton_detector,
    ort_session,
    **kwargs
):
    """
    Run ``predict_segment_onnx`` and return only the two predicted labels.

    Extra keyword arguments are forwarded to ``predict_segment_onnx``
    unchanged; the probability arrays it returns are discarded.
    """
    label_q, _probs_q, label_c, _probs_c = predict_segment_onnx(
        segment_frames,
        skeleton_detector,
        ort_session,
        **kwargs,
    )
    return label_q, label_c

# ===================== Video display with a sliding window =====================
def display_video_sliding_window_onnx(
    video_source,
    skeleton_detector,
    ort_session,
    segment_length=8,
    num_joints=17,
    normalize=True,
    frame_skip=2,
    batch_size=16
):
    """
    Process a video with a sliding window of ``segment_length`` frames.

    For every frame the skeleton is detected and drawn, the push-up counter
    and the processing FPS are overlaid, and the annotated frame is shown in
    a window and appended to 'output_onnx.avi'. Every ``frame_skip`` frames
    (once the window is full) the window is classified with
    ``process_frame_segment_onnx`` and the push-up counter is updated.
    Playback stops at the end of the stream or when 'q' is pressed.

    Args:
        video_source: anything ``cv2.VideoCapture`` accepts (file path,
            device index, ...).
        skeleton_detector: callable returning ``(keypoints, scores)`` for a frame.
        ort_session: ONNX Runtime session used for the segment classifier.
        segment_length, num_joints, normalize, frame_skip, batch_size:
            forwarded to ``process_frame_segment_onnx``; ``segment_length``
            is also the window size and ``frame_skip`` the prediction stride.

    Raises:
        OSError: if the video source cannot be opened.
    """
    cap = cv2.VideoCapture(video_source)
    if not cap.isOpened():
        cap.release()
        raise OSError(f"Cannot open video source: {video_source!r}")

    out = None
    window_name = "ONNX Live"
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Write the output at the source frame rate; fall back to 30 when the
        # backend does not report a usable value (e.g. some cameras).
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        out_fps = source_fps if source_fps > 0 else 30.0
        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        out = cv2.VideoWriter('output_onnx.avi', fourcc, out_fps, (width, height))

        window = deque(maxlen=segment_length)
        last_segment_class = None
        pushup_count = 0
        frame_idx = 0
        prev_time = time.time()
        cv2.namedWindow(window_name, cv2.WINDOW_NORMAL)

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            window.append(frame)
            frame_idx += 1

            # Segment prediction: only once the window is full, and then
            # every frame_skip frames.
            if len(window) == segment_length and (frame_idx - segment_length) % frame_skip == 0:
                q, c = process_frame_segment_onnx(
                    list(window),
                    skeleton_detector,
                    ort_session,
                    segment_length=segment_length,
                    num_joints=num_joints,
                    normalize=normalize,
                    frame_skip=frame_skip,
                    batch_size=batch_size
                )

                # Push-up counting logic.
                # Original note: count classes are assumed to be
                # 1 = top position, 2 = bottom position.
                # NOTE(review): the condition below tests q == 2 and a
                # c transition 0 -> 1, which does not obviously match that
                # note; confirm the class meanings against the model.
                if q == 2 and last_segment_class == 0 and c == 1:
                    pushup_count += 1
                last_segment_class = c

            # FPS of the processing loop (time between consecutive frames).
            curr_time = time.time()
            fps = 1.0 / (curr_time - prev_time) if curr_time != prev_time else 0.0
            prev_time = curr_time

            # Skeleton detection and drawing. Draw on a copy: `frame` is still
            # referenced by `window`, and annotating it in place would feed
            # frames with overlays back into later segment predictions.
            keypoints, scores = skeleton_detector(frame)
            frame_result = draw_skeleton(frame.copy(), keypoints, scores, kpt_thr=0.5)

            # On-frame annotations.
            cv2.putText(frame_result, f"Push-ups: {pushup_count}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.putText(frame_result, f"FPS: {fps:.2f}", (10, 60),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            # Show and record.
            cv2.imshow(window_name, frame_result)
            out.write(frame_result)

            # Quit on 'q'.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always free the capture, finalize the output file and close the
        # window, even if detection or inference raised.
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

# Video file to process; pass 0 instead to capture from the webcam.
video_path = r"C:\Users\jet\Desktop\видосы\norm_v\val\true\true_0005.mp4"

# Pose detector that supplies the key points for both the classifier and the overlay.
skeleton_detector = Body(
    mode='lightweight',
    backend='onnxruntime',
    device='cuda',
)

display_video_sliding_window_onnx(
    video_path,
    skeleton_detector,
    ort_session,
    batch_size=8,
    frame_skip=4,
    normalize=True,
    num_joints=17,
    segment_length=4,
)


load C:\Users\jet\.cache\rtmlib\hub\checkpoints\yolox_tiny_8xb8-300e_humanart-6f3252f9.onnx with onnxruntime backend
load C:\Users\jet\.cache\rtmlib\hub\checkpoints\rtmpose-s_simcc-body7_pt-body7_420e-256x192-acd4a1ef_20230504.onnx with onnxruntime backend
extract_skeletons: 0.0849s
predict_segment_onnx: 0.0859s
extract_skeletons: 0.0843s
predict_segment_onnx: 0.0853s
extract_skeletons: 0.0813s
predict_segment_onnx: 0.0839s
extract_skeletons: 0.0801s
predict_segment_onnx: 0.0811s
extract_skeletons: 0.0796s
predict_segment_onnx: 0.0821s
extract_skeletons: 0.0808s
predict_segment_onnx: 0.0828s
extract_skeletons: 0.1008s
predict_segment_onnx: 0.1018s
extract_skeletons: 0.0940s
predict_segment_onnx: 0.0960s
extract_skeletons: 0.0897s
predict_segment_onnx: 0.0907s
extract_skeletons: 0.0818s
predict_segment_onnx: 0.0838s
extract_skeletons: 0.1117s
predict_segment_onnx: 0.1137s
extract_skeletons: 0.0840s
predict_segment_onnx: 0.0862s
extract_skeletons: 0.0923s
predict_segment_onnx: 0.0943s
ex