From 4d058429bdb4a25a48e99de452f265480a6e5206 Mon Sep 17 00:00:00 2001 From: Sebastian Tuura Date: Tue, 25 Nov 2025 15:01:24 +0100 Subject: [PATCH] removed sam2 -> faster inference performance is similar --- team_assigner/team_assigner.py | 31 +++++------------ tracking/track_players.py | 62 +++------------------------------- 2 files changed, 14 insertions(+), 79 deletions(-) diff --git a/team_assigner/team_assigner.py b/team_assigner/team_assigner.py index 82e226d..126e7e6 100644 --- a/team_assigner/team_assigner.py +++ b/team_assigner/team_assigner.py @@ -17,7 +17,8 @@ def __init__(self, team_B= "DARK-BLUE shirt", history_len = 50, crop_factor = 0.375, - save_imgs = False + save_imgs = False, + crop = False ): self.team_colors = {} self.history_len = history_len @@ -25,10 +26,10 @@ def __init__(self, self.crop_factor = crop_factor self.save_imgs = save_imgs + self.crop = crop self.team_A = team_A self.team_B = team_B - self.sam2 = SAM("sam2.1_b.pt") def load_model(self): self.model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip") @@ -47,30 +48,20 @@ def crop_img(self, pil_image): def get_player_color(self,frame,bbox): image = frame[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])] - results_list = self.sam2(image, bboxes=[0, 0, int(bbox[2])-int(bbox[0]), int(bbox[3])-int(bbox[1])]) - masks_obj = results_list[0].masks # Masks object + pil_image = Image.fromarray(image) # pytorch expects this format - if masks_obj is not None and len(masks_obj) > 0: - mask_tensor = masks_obj.data # torch.Tensor of shape (1, H, W) - mask_numpy = mask_tensor[0].cpu().numpy().astype(np.uint8) - - blurred = cv2.GaussianBlur(image, (21,21), 0) - - masked_img = np.where(mask_numpy[..., None] == 1, image, blurred) - rgb_image = cv2.cvtColor(masked_img, cv2.COLOR_BGR2RGB) # blurred img - pil_image = Image.fromarray(rgb_image) # pytorch expects this format - - # Crop the img around the center. - cropped_pil_image = self.crop_img(pil_image) + if self.crop: + # Crop the img around the center. + pil_image = self.crop_img(pil_image) if self.save_imgs: r = random.randint(1, 1000000) filename = f"masked_{r}.png" - cropped_pil_image.save(os.path.join("imgs/masked", filename)) + pil_image.save(os.path.join("imgs/masked", filename)) team_classes = [self.team_A, self.team_B] - inputs = self.processor(text=team_classes, images=cropped_pil_image, return_tensors="pt", padding=True) + inputs = self.processor(text=team_classes, images=pil_image, return_tensors="pt", padding=True) outputs = self.model(**inputs) logits_per_image = outputs.logits_per_image @@ -93,10 +84,6 @@ def get_team_from_history(self, player_id): def get_player_team(self,frame,player_bbox,player_id): - history = list(self.player_team_cache_history[player_id]) - # if len(history) > self.history_len: - # return self.get_team_from_history(player_id) - player_color = self.get_player_color(frame,player_bbox) self.player_team_cache_history[player_id].append(player_color) team_id = self.get_team_from_history(player_id) diff --git a/tracking/track_players.py b/tracking/track_players.py index 5cfd1dd..84e16a1 100644 --- a/tracking/track_players.py +++ b/tracking/track_players.py @@ -1,17 +1,12 @@ from ultralytics import YOLO import supervision as sv import sys -from ultralytics import SAM -import numpy as np -import random - sys.path.append("../") class PlayerTracker(): def __init__(self, model_path): self.model = YOLO(model_path) self.tracker = sv.ByteTrack() - self.sam2 = SAM("sam2.1_b.pt") def detect_frames(self, vid_frames, batch_size=20, min_conf=0.5): detections = [] @@ -21,61 +16,17 @@ def detect_frames(self, vid_frames, batch_size=20, min_conf=0.5): detections += batch_detections return detections - def mask_to_bbox(self, mask): - ys, xs = np.where(mask > 0) - if len(xs) == 0 or len(ys) == 0: - return [0, 0, mask.shape[1], mask.shape[0]] - - x_min = int(xs.min()) - x_max = int(xs.max()) - y_min = int(ys.min()) - y_max = int(ys.max()) - - return [x_min, y_min, x_max, y_max] - def get_object_tracks(self, vid_frames): - + detections = self.detect_frames(vid_frames) tracks = [] - total = len(vid_frames) - for frame_id, frame in enumerate(vid_frames): - print(f"frame {frame_id}/{total}") - detection = detections[frame_id] + for frame_id, detection in enumerate(detections): class_names = detection.names class_names_inv = {val:key for key,val in class_names.items()} detection_sv = sv.Detections.from_ultralytics(detection) - - refined_bboxes = [] - for det in detection_sv: - bbox_array = det[0] - x1, y1, x2, y2 = map(int, bbox_array) - conf = float(det[2]) - class_id = int(det[3]) - - if class_names_inv['Player'] != class_id: - continue # skip non-player detections - - player_crop = frame[y1:y2, x1:x2] - - # SAM2 mask - results = self.sam2(player_crop, bboxes=[[0,0,x2-x1, y2-y1]]) - mask = results[0].masks.data[0].cpu().numpy().astype(np.uint8) - - x_min_rel, y_min_rel, x_max_rel, y_max_rel = self.mask_to_bbox(mask) - - x1_ref = x1 + x_min_rel - y1_ref = y1 + y_min_rel - x2_ref = x1 + x_max_rel - y2_ref = y1 + y_max_rel - - refined_bboxes.append([x1_ref, y1_ref, x2_ref, y2_ref, conf, class_id]) - - xyxy = np.array([b[:4] for b in refined_bboxes]) - confidences = np.array([b[4] for b in refined_bboxes]) - class_ids = np.array([b[5] for b in refined_bboxes]) - detections_sv = sv.Detections(xyxy=xyxy, confidence=confidences, class_id=class_ids) - detection_with_tracks = self.tracker.update_with_detections(detections_sv) + detection_with_tracks = self.tracker.update_with_detections(detection_sv) + tracks.append({}) for frame_detection in detection_with_tracks: @@ -85,8 +36,5 @@ def get_object_tracks(self, vid_frames): if class_id == class_names_inv["Player"]: tracks[frame_id][track_id] = {"bbox": bbox} - - - return tracks - + return tracks \ No newline at end of file