In [1]:
import torch
from torch import nn
import cv2
from ultralytics import YOLO
import pandas as pd
import numpy as np

In [2]:
# Per-sample input shape: [window_size, 17, 2].
# Batch: 1, frames: window_size, input features: 17 * 2.
class Pnet(nn.Module):
    """LSTM encoder followed by a linear head that emits two values per sequence.

    Args:
        window_size: Number of frames per input sequence. It is stored on the
            instance; the layers themselves do not depend on it.
        num_hiddens: Width of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
    """

    def __init__(self, window_size, num_hiddens, num_layers):
        super().__init__()
        self.window_size = window_size
        # 17 keypoints x 2 coordinates are flattened into one feature vector per frame.
        self.encoder = nn.LSTM(
            input_size=17 * 2,
            hidden_size=num_hiddens,
            num_layers=num_layers,
            bidirectional=False,
        )
        self.decoder = nn.Linear(in_features=num_hiddens, out_features=2)

    def forward(self, x):
        """Project the top LSTM layer's final hidden state through the linear head.

        ``x`` is sequence-first, i.e. (seq_len, batch, 17 * 2); the result is
        (batch, 2).
        """
        _, (h_n, _c_n) = self.encoder(x)
        # h_n holds one final hidden state per layer; index -1 selects the last layer.
        return self.decoder(h_n[-1])

In [3]:
# Build the model: 24-frame windows, 128 hidden units, 2 stacked LSTM layers.
net = Pnet(window_size=24, num_hiddens=128, num_layers=2)

In [4]:
# Frames per sequence window; used below to shape model inputs and the keypoint buffer.
window_size = 24

In [5]:
# Single-sample smoke test: one window of 17 (x, y) keypoints per frame is
# flattened to (window_size, 17 * 2) and given a batch axis of size 1, giving
# (window_size, 1, 34).
# The frame count is taken from window_size (instead of a literal 24) so the
# reshape cannot silently produce a wrong feature width if window_size changes.
x = torch.randn(window_size, 17, 2).reshape(window_size, -1).unsqueeze(1)
res = net(x)

In [6]:
# Show the model output for the single-sample batch built in the previous cell.
res

tensor([[-0.0938, -0.0751]], grad_fn=<AddmmBackward0>)

In [7]:
# Confirm the sequence-first layout fed to the LSTM: (window_size, batch, 17 * 2).
x.shape

torch.Size([24, 1, 34])

In [8]:
# Batched check: 64 sequences of 24 frames should produce one 2-value output each.
x2 = torch.randn(24, 64, 17 * 2)
net(x2).size()

torch.Size([64, 2])

In [9]:
import logging
from ultralytics.utils import LOGGER
# Raise the ultralytics logger threshold so only WARNING and above are emitted.
LOGGER.setLevel(logging.WARNING)

In [10]:
# Load the YOLO pose weights; gen_frames() below reads per-frame keypoints from this model.
detector = YOLO("../models/yolo11n-pose.pt")

In [11]:
import os
import sys
# Make the parent directory (the one that contains src) importable.
sys.path.append(os.path.abspath('..'))

from src.utils.window import Window

In [12]:
# Keypoint window of window_size frames, each with (17, 2) keypoints.
# Prefer the first GPU but fall back to the CPU, so the notebook still runs on
# machines without CUDA instead of failing on the hard-coded "cuda:0" device.
window = Window(
    torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    window_size,
    (17, 2),
)

In [13]:
def gen_frames(video_path='../preprocess/test.mp4'):
    """Extract normalised pose keypoints from a video, one full window at a time.

    Every frame is passed through the module-level ``detector``. Frames with
    no detection are skipped. When several people are detected, only the one
    with the highest box confidence is kept. Keypoints are pushed into the
    module-level ``window``; each time it reports a full window, its contents
    are flattened to ``(window_size, 17 * 2)`` rows and collected.

    Fixes relative to the previous version:
      * Both branches now use normalised coordinates (``xyn``). Previously the
        multi-person branch used pixel coordinates (``xy``), mixing units
        inside the same table.
      * The ``VideoCapture`` handle is always released.
      * Windows are collected in a list and concatenated once, instead of
        re-copying the growing DataFrame on every window.

    Args:
        video_path: Path of the video to read. Defaults to the file the
            notebook originally hard-coded.

    Returns:
        pandas.DataFrame with 34 columns (17 keypoints x 2 coordinates) and
        one row per buffered frame. Frames left in a trailing, incomplete
        window are not included. Empty (0 rows, 34 columns) when no full
        window was produced.
    """
    chunks = []
    cap = cv2.VideoCapture(video_path)
    try:
        while True:
            success, frame = cap.read()
            if not success:
                print('ok')
                break
            res = detector(frame)[0]
            num_people = len(res.boxes.cls)
            if num_people == 0:
                # Nothing detected in this frame; do not feed the window.
                continue
            kp = res.keypoints.xyn
            if num_people > 1:
                # Two or more people detected: keep the highest-confidence one,
                # restoring the leading axis so the shape matches the
                # single-person case.
                idx = res.boxes.conf.argmax(-1).item()
                kp = kp[idx].unsqueeze(0)
            if window.add(kp):
                # window.data is expected to be [window_size, 17, 2].
                # Snapshot it before clearing: on CPU tensors .cpu().numpy()
                # can share memory with the window's buffer.
                flat = window.data.reshape(window_size, -1).cpu().numpy().copy()
                window.clear()
                chunks.append(pd.DataFrame(flat))
    finally:
        cap.release()
    if not chunks:
        return pd.DataFrame(np.empty((0, 17 * 2)))
    return pd.concat(chunks, ignore_index=True)

In [14]:
import time

In [65]:
class Timer:
    """Accumulating stopwatch, usable directly or as a context manager.

    Each ``start()``/``stop()`` pair adds its elapsed time to a running total,
    which ``t`` exposes in seconds rounded to two decimals. ``clear()`` resets
    the total.

    Elapsed time is measured with ``time.perf_counter()``, a monotonic clock,
    so adjustments to the system clock (NTP sync, manual changes) cannot
    produce negative or skewed durations as they could with ``time.time()``.
    """

    def __init__(self):
        self._duration = 0  # accumulated seconds over all completed runs
        self.s = None       # perf_counter reading at the last start(), or None when idle

    def start(self):
        """Begin (or restart) a measurement."""
        self.s = time.perf_counter()

    def stop(self):
        """End the current measurement and add it to the total.

        Raises:
            AssertionError: if called without a preceding ``start()``.
        """
        assert self.s is not None, "Timer.stop() called before Timer.start()"
        self._duration += time.perf_counter() - self.s
        self.s = None

    def clear(self):
        """Reset the accumulated duration to zero."""
        self._duration = 0

    @property
    def t(self):
        """Accumulated duration in seconds, rounded to two decimals."""
        return round(self._duration, 2)

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so exceptions raised inside the with-block propagate.
        self.stop()


In [66]:
# Time one full pass of keypoint extraction over the video.
timer = Timer()
with timer:
    df = gen_frames()
print(timer.t)

ok
55.22


In [67]:
# Preview the extracted keypoint table: 34 columns (17 keypoints x 2 coordinates).
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,24,25,26,27,28,29,30,31,32,33
0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.537353,0.407614,0.000000,0.000000,...,0.564197,0.519054,0.543527,0.580287,0.563457,0.578616,0.543717,0.639587,0.562427,0.635656
1,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.537425,0.407214,0.000000,0.000000,...,0.563542,0.519085,0.543057,0.579465,0.562760,0.578019,0.543460,0.638636,0.561850,0.635171
2,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.538859,0.407787,0.000000,0.000000,...,0.564851,0.520129,0.542886,0.580259,0.564643,0.579493,0.543680,0.640475,0.564419,0.637672
3,0.517318,0.246722,0.523907,0.236667,0.512631,0.234307,0.533367,0.239292,0.503148,0.233254,...,0.479385,0.447478,0.512144,0.569434,0.461650,0.559254,0.505771,0.685830,0.457112,0.675888
4,0.517317,0.246737,0.523908,0.236682,0.512630,0.234321,0.533374,0.239305,0.503154,0.233266,...,0.479369,0.447484,0.512135,0.569428,0.461645,0.559260,0.505794,0.685802,0.457119,0.675868
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2203,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.544524,0.402096,0.560152,0.399657,...,0.566836,0.521924,0.544077,0.586494,0.566503,0.584986,0.543449,0.646406,0.567786,0.643108
2204,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.542877,0.403972,0.558324,0.401244,...,0.567277,0.523888,0.544689,0.588046,0.565562,0.586897,0.544696,0.648866,0.565362,0.645697
2205,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.541024,0.406075,0.556139,0.402841,...,0.566390,0.525086,0.544390,0.589949,0.564441,0.588446,0.545366,0.650879,0.564767,0.647031
2206,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.539029,0.407792,0.554516,0.403907,...,0.566486,0.523598,0.544152,0.589230,0.564853,0.587150,0.544838,0.648122,0.565586,0.643244
