In [None]:
import cv2
import numpy as np
import tempfile
import os
from ultralytics import YOLO
from deep_sort_realtime.deepsort_tracker import DeepSort
from mmaction.apis import init_recognizer, inference_recognizer
from mmengine.dataset import Compose

# Load the models: a YOLO detector for players and an action recognizer
# for per-player clips.
object_detector = YOLO("../object_detection/runs/detect/bod_v1/weights/best.pt")
# Alternative TSM configuration, kept for reference:
# config_file = "./models/tsm/5-4-third/tsm_multisubjects.py"
config_file = "./models/x3d/x3d_s_multisubjects.py"
checkpoint_file = "./models/x3d/best_45.pth"
model = init_recognizer(config_file, checkpoint_file, device='cuda:0')

# Initialise the multi-object tracker
tracker = DeepSort(max_age=50, n_init=5, nms_max_overlap=1.0)

# Parameters
clip_len = 8  # Number of frames per clip sent to the recognizer
MAX_PLAYERS = 10  # Size of the pool of fixed player IDs
MAX_FRAMES_WITHOUT_DETECTION = 3  # Tracks staler than this are ignored

# Mapping from tracker IDs to fixed IDs in 1..MAX_PLAYERS
id_mapping = {}
lost_ids = set(range(1, MAX_PLAYERS + 1))  # Fixed IDs currently free
active_ids = set()  # Fixed IDs currently in use

# Per-player buffers of cropped frames, keyed by fixed ID
player_buffers = {}

# Detected shots, stored as (frame index, x, y)
shot_positions = []

# Video input and output
video_path = "../clips/ClipLF1.mp4"
output_path = "../output/ClipLF1_action_output.mp4"

cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Fail early instead of silently producing an empty video and CSV.
    raise OSError(f"No se pudo abrir el video: {video_path}")

# Keep the fractional frame rate (e.g. 29.97): truncating it to an int makes
# the output play at a different speed than the input. Fall back to 30 when
# the container does not report a frame rate.
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# The writer does not create missing directories; make sure the target exists
# so neither the video nor the CSV written later is lost.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

fourcc = cv2.VideoWriter_fourcc(*"mp4v")
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

frame_idx = 0  # Frame counter

def create_temp_video(frames, fps=30):
    """Write a list of frames to a temporary MP4 file.

    Args:
        frames: Non-empty sequence of images; the height and width of the
            first one define the video size.
        fps: Frame rate used for the temporary video.

    Returns:
        A ``(temp_video_path, temp_dir)`` tuple. The caller is responsible
        for deleting both once the clip is no longer needed.

    Raises:
        ValueError: If ``frames`` is empty.
        RuntimeError: If the video writer cannot be opened.
    """
    # Validate before touching the filesystem so a bad call leaks nothing.
    if len(frames) == 0:
        raise ValueError("frames must contain at least one frame")

    h, w = frames[0].shape[:2]

    temp_dir = tempfile.mkdtemp()
    temp_video_path = os.path.join(temp_dir, 'temp_clip.mp4')

    # Configure the video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(temp_video_path, fourcc, fps, (w, h))

    written = False
    try:
        if not writer.isOpened():
            raise RuntimeError(f"could not open video writer for {temp_video_path}")
        for frame in frames:
            writer.write(frame)
        written = True
    finally:
        writer.release()
        if not written:
            # Do not leave a half-written clip or an orphan directory behind.
            if os.path.exists(temp_video_path):
                os.remove(temp_video_path)
            if os.path.exists(temp_dir):
                os.rmdir(temp_dir)

    return temp_video_path, temp_dir

# Minimum top-class score required to accept a prediction as an action
confidence_threshold = 0.80

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        frame_idx += 1
        frame_h, frame_w = frame.shape[:2]

        # Detect players in the current frame
        detections = []
        object_detection_results = object_detector(frame)

        for result in object_detection_results[0].boxes.data.tolist():
            x1, y1, x2, y2, conf, cls = result
            cls = int(cls)

            if object_detector.names[cls] == "player" and conf > 0.5:
                # Boxes are handed to the tracker as [left, top, width, height]
                detections.append(([x1, y1, x2 - x1, y2 - y1], conf, cls))

        # Update the tracker and keep only confirmed, recently seen tracks
        tracking_results = tracker.update_tracks(detections, frame=frame)
        confirmed_tracks = [
            track for track in tracking_results
            if track.is_confirmed() and track.time_since_update <= MAX_FRAMES_WITHOUT_DETECTION
        ]

        for track in confirmed_tracks:
            original_id = track.track_id

            if original_id in id_mapping:
                fixed_id = id_mapping[original_id]
                # The slot may have been released during a short dropout of
                # this track; take it back so nobody else can be given it.
                lost_ids.discard(fixed_id)
            elif lost_ids:
                fixed_id = lost_ids.pop()
                # Forget any earlier track that used this slot, otherwise it
                # could come back and share the fixed ID with the new owner.
                for stale_id in [tid for tid, fid in id_mapping.items() if fid == fixed_id]:
                    del id_mapping[stale_id]
                id_mapping[original_id] = fixed_id
            else:
                continue  # No fixed IDs available
            active_ids.add(fixed_id)

            x1, y1, x2, y2 = track.to_ltrb()
            position = (int((x1 + x2) / 2), int(y2))

            # The tracker's box is not guaranteed to lie inside the frame.
            # Clamp it: negative indices would wrap around in the slice and an
            # empty crop makes cv2.resize raise.
            crop_x1, crop_y1 = max(int(x1), 0), max(int(y1), 0)
            crop_x2, crop_y2 = min(int(x2), frame_w), min(int(y2), frame_h)
            if crop_x2 <= crop_x1 or crop_y2 <= crop_y1:
                continue  # Nothing of this player is visible in the frame

            # Extract the player crop and store it at the recognizer's input size
            player_crop = cv2.resize(frame[crop_y1:crop_y2, crop_x1:crop_x2], (224, 224))
            clip_frames = player_buffers.setdefault(fixed_id, [])
            clip_frames.append(player_crop)

            # Once the buffer holds a full clip, run action recognition on it
            if len(clip_frames) == clip_len:
                temp_video_path, temp_dir = create_temp_video(clip_frames, fps)

                try:
                    result = inference_recognizer(model, temp_video_path)

                    scores = result.pred_score.tolist()
                    max_score = max(scores)

                    if max_score > confidence_threshold:
                        predicted_class = scores.index(max_score)
                    else:
                        predicted_class = -1  # "No action" class

                    print(f"Predicción para ID {fixed_id} en frame {frame_idx}: {predicted_class}")

                    if predicted_class in {1, 2}:  # Shot or layup detected
                        action_name = "Tiro" if predicted_class == 1 else "Bandeja"
                        print(f"🔴 {action_name} detectado en frame {frame_idx}, ID {fixed_id}")

                        # Record where the shot was taken
                        shot_positions.append((frame_idx, position[0], position[1]))

                finally:
                    # Remove the temporary clip and its directory
                    if os.path.exists(temp_video_path):
                        os.remove(temp_video_path)
                    if os.path.exists(temp_dir):
                        os.rmdir(temp_dir)

                # Keep the second half of the clip so consecutive clips overlap by 50%
                player_buffers[fixed_id] = clip_frames[-clip_len//2:]

            # Draw the player's bounding box and fixed ID
            cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)
            cv2.putText(frame, f"ID: {fixed_id}", (int(x1), int(y1) - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)

        # Return the fixed IDs that were not seen in this frame to the pool
        active_now = {id_mapping[track.track_id] for track in confirmed_tracks if track.track_id in id_mapping}
        for released_id in active_ids - active_now:
            lost_ids.add(released_id)
            # Drop buffered crops so a later owner of this ID does not get a
            # clip that mixes two players or spans a gap in time.
            player_buffers.pop(released_id, None)
        active_ids = active_now

        # Write and show the processed frame
        out.write(frame)
        cv2.imshow('Video', frame)

        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
finally:
    # Release resources even if processing fails, so the output video is finalised
    cap.release()
    out.release()
    cv2.destroyAllWindows()

# Save the detected shot positions as rows of (frame, x, y)
np.savetxt("../output/shot_positions.csv", np.array(shot_positions, dtype=int).reshape(-1, 3), delimiter=",", fmt="%d")

print("Posiciones de tiros guardadas en shot_positions.csv")
print(f"Video procesado guardado en {output_path}")

APROXIMACIÓN CON YOLO11

Comprobación de PyTorch y uso de la GPU

In [1]:
import torch

# Report the PyTorch build and whether a CUDA device can be used
print(f"Versión de PyTorch: {torch.__version__}")
print(f"Versión de CUDA en PyTorch: {torch.version.cuda}")
print(f"¿CUDA está disponible?: {torch.cuda.is_available()}")
# get_device_name raises when no CUDA device is present, so only query it
# when CUDA is actually available.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))

Versión de PyTorch: 2.5.1+cu121
Versión de CUDA en PyTorch: 12.1
¿CUDA está disponible?: True
NVIDIA GeForce RTX 3060 Laptop GPU


Entrenamiento de las redes

In [None]:
from ultralytics import YOLO

# Base weights the training run starts from
model = YOLO("yolo11n.pt")

results = model.train(
    # Task and data
    task="detect",
    data="./datasets/bsr_v1/data.yaml",
    imgsz=1024,
    # Hardware and schedule
    device=0,
    batch=8,
    epochs=100,
    patience=25,
    # Optimizer settings
    optimizer='SGD',
    lr0=0.01,
    lrf=0.1,
    weight_decay=0.0005,
)

In [2]:
def getColor(class_name):
    """Return the BGR drawing colour for a detection class.

    Classes without an assigned colour are drawn in black.
    """
    palette = dict([
        ('player', (255, 0, 0)),        # Blue
        ('basketball', (0, 165, 255)),  # Orange
        ('rim', (0, 0, 255)),           # Red
        ('made-shot', (0, 255, 0)),     # Green
    ])
    try:
        return palette[class_name]
    except KeyError:
        return (0, 0, 0)

def drawBBox(frame, x1, y1, x2, y2, label, class_name):
    """Draw a class-coloured box on ``frame`` with a white label above it.

    ``(x1, y1)`` and ``(x2, y2)`` are opposite corners of the box; the
    coordinates are truncated to integers before drawing.
    """
    box_color = getColor(class_name)
    left, top = int(x1), int(y1)
    right, bottom = int(x2), int(y2)

    cv2.rectangle(frame, (left, top), (right, bottom), box_color, 2)
    # The label sits 10 px above the box's top-left corner
    cv2.putText(
        frame,
        label,
        (left, top - 10),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        (255, 255, 255),
        2,
    )

def drawPosition(frame, position, position_label):
    """Mark ``position`` on ``frame`` with a filled red ellipse and a label.

    ``position`` is an ``(x, y)`` pair; both values are truncated to integers.
    """
    px, py = int(position[0]), int(position[1])
    marker_color = (0, 0, 255)  # Red in BGR

    # Flat filled ellipse centred on the position
    cv2.ellipse(frame, (px, py), (9, 3), 0, 0, 360, marker_color, -1)
    # Text placed below and to the left of the marker
    cv2.putText(
        frame,
        position_label,
        (px - 50, py + 25),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.5,
        marker_color,
        2,
    )

In [3]:
import cv2
from ultralytics import YOLO

# Three models run on every frame: shot detector, object detector and pose estimator
shoot_detector = YOLO("./runs/detect/bsr_v1/weights/best.pt")
object_detector = YOLO("../object_detection/runs/detect/bod_v1/weights/best.pt")
pose_model = YOLO("yolo11n-pose.pt")

video_path = "../clips/ClipLF1.mp4"
output_path = "../output/ClipLF1_action_output.mp4"

cap = cv2.VideoCapture(video_path)
fourcc = cv2.VideoWriter_fourcc(*"avc1")
# Keep the fractional frame rate (e.g. 29.97) so the output plays at the same
# speed as the input; fall back to 30 when the container reports none.
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

# Display names of the 17 keypoints, indexed by their position in a detection
keypoint_names = [
    'Nariz', 'Ojo izquierdo', 'Ojo derecho', 'Oreja izquierda', 'Oreja derecha',
    'Hombro izquierdo', 'Hombro derecho', 'Codo izquierdo', 'Codo derecho',
    'Muñeca izquierda', 'Muñeca derecha', 'Cadera izquierda', 'Cadera derecha',
    'Rodilla izquierda', 'Rodilla derecha', 'Tobillo izquierdo', 'Tobillo derecho'
]

# Minimum keypoint confidence required to draw it
pose_threshold = 0.3

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        shoot_detection_results = shoot_detector(frame)
        object_detection_results = object_detector(frame)
        pose_model_results = pose_model(frame)

        # Shot detections: each row is (x1, y1, x2, y2, confidence, class)
        for result in shoot_detection_results[0].boxes.data.tolist():
            x1, y1, x2, y2, conf, cls = result
            cls = int(cls)

            # Keep only confident detections of the wanted class
            if shoot_detector.names[cls] in ['shoot'] and conf > 0.7:

                # The position is the midpoint of the box's bottom edge
                position = (int((x1 + x2) / 2), y2)
                position_label = f"x:{int(position[0])} y:{int(position[1])}"

                label = f"{shoot_detector.names[cls]} {conf:.2f}"

                drawBBox(frame, x1, y1, x2, y2, label, shoot_detector.names[cls])

                drawPosition(frame, position, position_label)

        # Object detections (players, ball, rim, made shots)
        for result in object_detection_results[0].boxes.data.tolist():
            x1, y1, x2, y2, conf, cls = result
            cls = int(cls)

            if object_detector.names[cls] in ['player', 'basketball', 'rim', 'made-shot'] and conf > 0.5:
                label = f"{object_detector.names[cls]} {conf:.2f}"
                drawBBox(frame, x1, y1, x2, y2, label, object_detector.names[cls])

        # Pose detections. The attribute can exist and still be None, so check
        # the value rather than only its presence.
        pose_keypoints = None
        if len(pose_model_results) > 0:
            pose_keypoints = getattr(pose_model_results[0], 'keypoints', None)

        if pose_keypoints is not None:
            for det in pose_keypoints.data.tolist():
                # Each detection is a list of keypoints in [x, y, score] form
                for idx, kp in enumerate(det):
                    x, y, kp_conf = kp
                    if kp_conf > pose_threshold:
                        # Small filled circle at the keypoint, with its name next to it
                        cv2.circle(frame, (int(x), int(y)), radius=3, color=(0, 255, 0), thickness=-1)
                        cv2.putText(frame, keypoint_names[idx], (int(x)+5, int(y)+5),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.3, (0, 255, 0), 1)
        else:
            print('No se encontraron detecciones de pose o el atributo keypoints no está disponible.')

        # Write the annotated frame; without this the output video stays empty
        out.write(frame)

        # Show the processed frame on screen
        cv2.imshow('Resultados', frame)

        # Press 'q' to stop early (use cv2.waitKey(0) here to step frame by frame)
        key = cv2.waitKey(1) & 0xFF
        if key == ord('q'):
            break
finally:
    # Release resources even if processing fails, so the output video is finalised
    cap.release()
    out.release()
    cv2.destroyAllWindows()

print(f"Video procesado guardado en {output_path}")



0: 576x1024 (no detections), 33.9ms
Speed: 10.8ms preprocess, 33.9ms inference, 1.4ms postprocess per image at shape (1, 3, 576, 1024)

0: 576x1024 10 players, 1 rim, 36.4ms
Speed: 11.5ms preprocess, 36.4ms inference, 3.1ms postprocess per image at shape (1, 3, 576, 1024)

0: 384x640 (no detections), 48.4ms
Speed: 4.2ms preprocess, 48.4ms inference, 1.4ms postprocess per image at shape (1, 3, 384, 640)

0: 576x1024 (no detections), 34.8ms
Speed: 10.4ms preprocess, 34.8ms inference, 1.3ms postprocess per image at shape (1, 3, 576, 1024)

0: 576x1024 10 players, 1 rim, 41.7ms
Speed: 10.4ms preprocess, 41.7ms inference, 3.6ms postprocess per image at shape (1, 3, 576, 1024)

0: 384x640 (no detections), 44.5ms
Speed: 4.9ms preprocess, 44.5ms inference, 2.3ms postprocess per image at shape (1, 3, 384, 640)

0: 576x1024 (no detections), 33.0ms
Speed: 9.8ms preprocess, 33.0ms inference, 1.6ms postprocess per image at shape (1, 3, 576, 1024)

0: 576x1024 10 players, 1 rim, 30.5ms
Speed: 9.5ms