Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 9 additions & 22 deletions team_assigner/team_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,19 @@ def __init__(self,
team_B= "DARK-BLUE shirt",
history_len = 50,
crop_factor = 0.375,
save_imgs = False
save_imgs = False,
crop = False
):
self.team_colors = {}
self.history_len = history_len
self.player_team_cache_history = defaultdict(lambda: deque(maxlen=history_len))

self.crop_factor = crop_factor
self.save_imgs = save_imgs
self.crop = crop

self.team_A = team_A
self.team_B = team_B
self.sam2 = SAM("sam2.1_b.pt")

def load_model(self):
self.model = CLIPModel.from_pretrained("patrickjohncyh/fashion-clip")
Expand All @@ -47,30 +48,20 @@ def crop_img(self, pil_image):
def get_player_color(self,frame,bbox):
image = frame[int(bbox[1]):int(bbox[3]), int(bbox[0]):int(bbox[2])]

results_list = self.sam2(image, bboxes=[0, 0, int(bbox[2])-int(bbox[0]), int(bbox[3])-int(bbox[1])])
masks_obj = results_list[0].masks # Masks object
pil_image = Image.fromarray(image) # pytorch expects this format

if masks_obj is not None and len(masks_obj) > 0:
mask_tensor = masks_obj.data # torch.Tensor of shape (1, H, W)
mask_numpy = mask_tensor[0].cpu().numpy().astype(np.uint8)

blurred = cv2.GaussianBlur(image, (21,21), 0)

masked_img = np.where(mask_numpy[..., None] == 1, image, blurred)
rgb_image = cv2.cvtColor(masked_img, cv2.COLOR_BGR2RGB) # blurred img
pil_image = Image.fromarray(rgb_image) # pytorch expects this format

# Crop the img around the center.
cropped_pil_image = self.crop_img(pil_image)
if self.crop:
# Crop the img around the center.
pil_image = self.crop_img(pil_image)

if self.save_imgs:
r = random.randint(1, 1000000)
filename = f"masked_{r}.png"
cropped_pil_image.save(os.path.join("imgs/masked", filename))
pil_image.save(os.path.join("imgs/masked", filename))

team_classes = [self.team_A, self.team_B]

inputs = self.processor(text=team_classes, images=cropped_pil_image, return_tensors="pt", padding=True)
inputs = self.processor(text=team_classes, images=pil_image, return_tensors="pt", padding=True)

outputs = self.model(**inputs)
logits_per_image = outputs.logits_per_image
Expand All @@ -93,10 +84,6 @@ def get_team_from_history(self, player_id):

def get_player_team(self,frame,player_bbox,player_id):

history = list(self.player_team_cache_history[player_id])
# if len(history) > self.history_len:
# return self.get_team_from_history(player_id)

player_color = self.get_player_color(frame,player_bbox)
self.player_team_cache_history[player_id].append(player_color)
team_id = self.get_team_from_history(player_id)
Expand Down
62 changes: 5 additions & 57 deletions tracking/track_players.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from ultralytics import YOLO
import supervision as sv
import sys
from ultralytics import SAM
import numpy as np
import random

sys.path.append("../")

class PlayerTracker():
def __init__(self, model_path):
    """Build the models this tracker relies on.

    Args:
        model_path: Weights path handed to ``YOLO``.
    """
    # Detection model created from the supplied weights.
    detector = YOLO(model_path)
    self.model = detector

    # Tracker fed with per-frame detections via update_with_detections().
    byte_tracker = sv.ByteTrack()
    self.tracker = byte_tracker

    # Segmentation model loaded from the "sam2.1_b.pt" checkpoint.
    segmenter = SAM("sam2.1_b.pt")
    self.sam2 = segmenter

def detect_frames(self, vid_frames, batch_size=20, min_conf=0.5):
detections = []
Expand All @@ -21,61 +16,17 @@ def detect_frames(self, vid_frames, batch_size=20, min_conf=0.5):
detections += batch_detections
return detections

def mask_to_bbox(self, mask):
    """Return the tight bounding box around the non-zero pixels of ``mask``.

    Args:
        mask: 2-D array; entries greater than 0 count as foreground.

    Returns:
        ``[x_min, y_min, x_max, y_max]`` as plain ints. The max edges are
        exclusive, so the box can be used directly as slice bounds
        (``mask[y_min:y_max, x_min:x_max]``). When the mask has no
        foreground pixel, the full extent ``[0, 0, width, height]`` is
        returned.
    """
    ys, xs = np.where(mask > 0)

    # np.where yields index arrays of equal length, so one check is enough.
    if xs.size == 0:
        return [0, 0, mask.shape[1], mask.shape[0]]

    # +1 turns the largest foreground index into an exclusive edge, which
    # matches the convention of the empty-mask fallback above and keeps the
    # last foreground row/column inside the box when it is used for slicing.
    return [
        int(xs.min()),
        int(ys.min()),
        int(xs.max()) + 1,
        int(ys.max()) + 1,
    ]

def get_object_tracks(self, vid_frames):

detections = self.detect_frames(vid_frames)

tracks = []
total = len(vid_frames)
for frame_id, frame in enumerate(vid_frames):
print(f"frame {frame_id}/{total}")
detection = detections[frame_id]
for frame_id, detection in enumerate(detections):
class_names = detection.names
class_names_inv = {val:key for key,val in class_names.items()}
detection_sv = sv.Detections.from_ultralytics(detection)

refined_bboxes = []
for det in detection_sv:
bbox_array = det[0]
x1, y1, x2, y2 = map(int, bbox_array)
conf = float(det[2])
class_id = int(det[3])

if class_names_inv['Player'] != class_id:
continue # skip non-player detections

player_crop = frame[y1:y2, x1:x2]

# SAM2 mask
results = self.sam2(player_crop, bboxes=[[0,0,x2-x1, y2-y1]])
mask = results[0].masks.data[0].cpu().numpy().astype(np.uint8)

x_min_rel, y_min_rel, x_max_rel, y_max_rel = self.mask_to_bbox(mask)

x1_ref = x1 + x_min_rel
y1_ref = y1 + y_min_rel
x2_ref = x1 + x_max_rel
y2_ref = y1 + y_max_rel

refined_bboxes.append([x1_ref, y1_ref, x2_ref, y2_ref, conf, class_id])

xyxy = np.array([b[:4] for b in refined_bboxes])
confidences = np.array([b[4] for b in refined_bboxes])
class_ids = np.array([b[5] for b in refined_bboxes])
detections_sv = sv.Detections(xyxy=xyxy, confidence=confidences, class_id=class_ids)
detection_with_tracks = self.tracker.update_with_detections(detections_sv)
detection_with_tracks = self.tracker.update_with_detections(detection_sv)

tracks.append({})

for frame_detection in detection_with_tracks:
Expand All @@ -85,8 +36,5 @@ def get_object_tracks(self, vid_frames):

if class_id == class_names_inv["Player"]:
tracks[frame_id][track_id] = {"bbox": bbox}


return tracks


return tracks