In [1]:
import torch
import cv2
import numpy as np
import os
from PIL import Image
from torch.utils.data import Dataset, DataLoader

# --- Key change: __getitem__ now returns a detection target dict (boxes, labels, ...) with one box per pedestrian ---
class PennFudanDatasetV2(Dataset):
    """Penn-Fudan pedestrian dataset yielding one detection target per image.

    ``root`` must contain ``PNGImages/`` and ``PedMasks/``. Images and masks are
    paired by their position in the sorted directory listings, so both folders
    must hold the same number of files in matching order.

    Each item is ``(img, target)``:
        img:    the RGB image, passed through ``transform`` when one is given.
        target: dict with
            ``boxes``    float32 tensor of shape [N, 4] as (xmin, ymin, xmax, ymax),
            ``labels``   int64 tensor of N ones (single "person" class),
            ``image_id`` int64 tensor ``[idx]``,
            ``area``     float32 tensor of N box areas,
            ``iscrowd``  int64 tensor of N zeros,
        where N is the number of instances found in the mask (possibly 0).
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        image_dir = os.path.join(root, "PNGImages")
        mask_dir = os.path.join(root, "PedMasks")
        self.images = sorted(os.path.join(image_dir, f) for f in os.listdir(image_dir))
        self.masks = sorted(os.path.join(mask_dir, f) for f in os.listdir(mask_dir))
        # Pairing is purely positional, so a count mismatch would silently
        # misalign every pair after the missing file.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images but {len(self.masks)} masks under {root!r}"
            )

    def __getitem__(self, idx):
        img_path = self.images[idx]
        mask_path = self.masks[idx]
        # Context managers close the file handles; convert()/np.array() load the
        # pixel data before the underlying file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = np.array(raw_mask)

        # Each instance is encoded as a distinct pixel value; the original code
        # treated the smallest value (0) as background. Filtering on != 0 instead
        # of dropping the first unique value avoids discarding a real instance
        # when a mask happens to contain no background pixel.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]
        num_objs = len(obj_ids)

        # One boolean mask per instance, stacked along a new leading axis.
        masks = mask == obj_ids[:, None, None]

        # Tight bounding box around each instance's pixels.
        boxes = []
        for instance_mask in masks:
            ys, xs = np.where(instance_mask)
            boxes.append([int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())])

        # reshape keeps shape (0, 4) when there are no instances, so the column
        # slicing used for `area` below does not fail on a 1-D empty tensor.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # There is only one class (person)
        labels = torch.ones((num_objs,), dtype=torch.int64)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        # Target dictionary in the format torchvision detection models consume
        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transform:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self.images)

In [2]:
import torchvision.transforms as transforms

def collate_fn(batch):
    """Regroup a batch of ``(image, target)`` samples into ``(images, targets)``.

    The default collate tries to stack tensors, which breaks when targets carry
    a different number of boxes per image. Here each field is simply gathered
    into a tuple, e.g. ``[(a, x), (b, y)] -> ((a, b), (x, y))``.
    """
    per_field = zip(*batch)
    return tuple(tuple(values) for values in per_field)

# Image preprocessing: PIL image -> tensor (torchvision ToTensor); no resizing
# or normalization here, since the detection model handles that internally.
transform = transforms.Compose([transforms.ToTensor()])

In [3]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_detection_model(num_classes, weights=None):
    """Build a torchvision Faster R-CNN (ResNet-50 FPN) with a ``num_classes`` box head.

    Args:
        num_classes: number of output classes *including* background
            (e.g. 2 for person + background).
        weights: detection weights forwarded to
            ``torchvision.models.detection.fasterrcnn_resnet50_fpn``. The default
            ``None`` keeps the previous behaviour (no pretrained detector); pass
            e.g. ``"DEFAULT"`` to fine-tune from COCO-pretrained weights, the
            usual approach on a small dataset.

    Returns:
        The model with ``roi_heads.box_predictor`` replaced by a freshly
        initialised ``FastRCNNPredictor(in_features, num_classes)``.

    Note:
        ``weights=None`` is not a true from-scratch model: the builder's
        ``weights_backbone`` argument is left at its default, which in current
        torchvision loads ImageNet-pretrained ResNet-50 backbone weights.
    """
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)

    # Width of the features entering the box classifier.
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # Swap in a predictor sized for our classes (replaces the default head).
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    return model

In [4]:
# --- Setup ---
SEED = 42            # seeds torch's RNG: weight init + DataLoader shuffle order
NUM_EPOCHS = 10
MODEL_PATH = 'multi_person_detector.pth'

torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 2  # 1 class (person) + background
model = get_detection_model(num_classes).to(device)

# Optimizer over parameters that still require gradients
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Dataset and DataLoader (collate_fn keeps variable-sized targets as tuples)
dataset = PennFudanDatasetV2(root='PennFudanPed', transform=transform)
data_loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

# --- Training Loop ---
print("--- Starting Multi-Person Model Training ---")
for epoch in range(NUM_EPOCHS):
    model.train()
    epoch_loss = 0.0
    for images, targets in data_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In train mode the detection model returns a dict of loss terms.
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())

        # Fail loudly instead of continuing to train on NaN/inf gradients.
        if not torch.isfinite(losses):
            loss_terms = {name: value.item() for name, value in loss_dict.items()}
            raise RuntimeError(
                f"Non-finite loss at epoch {epoch + 1}: {losses.item()} ({loss_terms})"
            )

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        epoch_loss += losses.item()

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {epoch_loss/len(data_loader):.4f}")

print("--- Training Finished ---")

# Save the trained model
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved to {MODEL_PATH}")

--- Starting Multi-Person Model Training ---
Epoch 1/10, Loss: 0.4268
Epoch 2/10, Loss: 0.2359
Epoch 3/10, Loss: 0.2040
Epoch 4/10, Loss: 0.1883
Epoch 5/10, Loss: 0.1827
Epoch 6/10, Loss: 0.1626
Epoch 7/10, Loss: 0.1388
Epoch 8/10, Loss: 0.1367
Epoch 9/10, Loss: 0.1284
Epoch 10/10, Loss: 0.1145
--- Training Finished ---
Model saved to multi_person_detector.pth


In [5]:
# --- Video Processing ---
from collections import OrderedDict, defaultdict
# Re-use the CentroidTracker class from your previous code
# (No changes needed to the CentroidTracker itself)
class CentroidTracker:
    """Assign persistent integer IDs to per-frame boxes by nearest-centroid matching.

    ``update`` is called once per frame with that frame's boxes
    ``(x1, y1, x2, y2)``. Each box is reduced to its integer centroid and
    greedily matched to the closest existing track. Unmatched boxes start new
    tracks. A track left unmatched for more than ``max_disappeared``
    consecutive updates is deregistered.

    Args:
        max_disappeared: consecutive unmatched updates a track survives before
            it is dropped.
        max_distance: optional upper bound on the centroid distance (in box
            coordinate units) for a match. ``None`` (default) accepts any
            distance, as before. A finite value makes a far-away detection
            start a new track instead of inheriting an existing ID.
    """

    def __init__(self, max_disappeared=50, max_distance=None):
        self.next_object_id = 0
        self.objects = OrderedDict()      # object_id -> last matched centroid (x, y)
        self.disappeared = OrderedDict()  # object_id -> consecutive unmatched updates
        self.max_disappeared = max_disappeared
        self.max_distance = max_distance

    def register(self, centroid):
        """Start a new track at ``centroid`` with the next free ID."""
        self.objects[self.next_object_id] = centroid
        self.disappeared[self.next_object_id] = 0
        self.next_object_id += 1

    def deregister(self, object_id):
        """Forget a track entirely."""
        del self.objects[object_id]
        del self.disappeared[object_id]

    def _age_out(self, object_ids):
        """Count one more miss for each ID and drop tracks past ``max_disappeared``."""
        for object_id in object_ids:
            self.disappeared[object_id] += 1
            if self.disappeared[object_id] > self.max_disappeared:
                self.deregister(object_id)

    def update(self, rects):
        """Match this frame's boxes to tracks and return ``{object_id: centroid}``.

        The returned mapping is the tracker's own ``objects`` dict. It also
        contains tracks not seen this frame; their ``disappeared`` count is > 0
        and their centroid is the last matched one.
        """
        if len(rects) == 0:
            self._age_out(list(self.disappeared.keys()))
            return self.objects

        input_centroids = np.zeros((len(rects), 2), dtype=int)
        for i, (x1, y1, x2, y2) in enumerate(rects):
            input_centroids[i] = (int((x1 + x2) / 2.0), int((y1 + y2) / 2.0))

        if len(self.objects) == 0:
            for centroid in input_centroids:
                self.register(centroid)
            return self.objects

        object_ids = list(self.objects.keys())
        object_centroids = np.array(list(self.objects.values()))
        # D[r, c] = distance between existing track r and new centroid c.
        D = np.linalg.norm(object_centroids[:, None] - input_centroids[None, :], axis=2)

        # Greedy matching: visit tracks from closest to farthest best match,
        # each claiming its nearest centroid unless already taken.
        rows = D.min(axis=1).argsort()
        cols = D.argmin(axis=1)[rows]
        used_rows, used_cols = set(), set()
        for row, col in zip(rows, cols):
            if row in used_rows or col in used_cols:
                continue
            if self.max_distance is not None and D[row, col] > self.max_distance:
                continue
            object_id = object_ids[row]
            self.objects[object_id] = input_centroids[col]
            self.disappeared[object_id] = 0
            used_rows.add(row)
            used_cols.add(col)

        unused_rows = sorted(set(range(D.shape[0])) - used_rows)
        self._age_out([object_ids[row] for row in unused_rows])

        # Sorted so new IDs are assigned in box order, deterministically.
        for col in sorted(set(range(D.shape[1])) - used_cols):
            self.register(input_centroids[col])
        return self.objects

# --- Load the trained multi-person model ---
MODEL_PATH = 'multi_person_detector.pth'
VIDEO_SOURCE = "16.avi"   # input video file
SCORE_THRESHOLD = 0.7     # minimum detection score passed to the tracker

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.eval()

tracker = CentroidTracker(max_disappeared=50)
cap = cv2.VideoCapture(VIDEO_SOURCE)
if not cap.isOpened():
    raise FileNotFoundError(f"Could not open video source: {VIDEO_SOURCE!r}")

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # Prepare frame for model (OpenCV frames are BGR; the model was trained on RGB)
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        img_tensor = transform(img).to(device)

        with torch.no_grad():
            prediction = model([img_tensor])

        boxes = prediction[0]['boxes'].cpu().numpy()
        scores = prediction[0]['scores'].cpu().numpy()

        # Keep only confident detections
        rects = boxes[scores > SCORE_THRESHOLD].astype(int)

        objects = tracker.update(rects)

        # A track matched (or registered) this frame has disappeared == 0 and
        # stores exactly the integer centroid of its box, computed with the same
        # formula as below. Index boxes by that centroid so each ID is drawn on
        # its own box. Tracks not seen this frame hold a stale centroid and are
        # skipped.
        rect_by_centroid = {}
        for x1, y1, x2, y2 in rects:
            key = (int((x1 + x2) / 2.0), int((y1 + y2) / 2.0))
            rect_by_centroid.setdefault(key, (int(x1), int(y1), int(x2), int(y2)))

        for object_id, centroid in objects.items():
            if tracker.disappeared[object_id] != 0:
                continue
            cx, cy = int(centroid[0]), int(centroid[1])
            rect = rect_by_centroid.get((cx, cy))
            if rect is None:
                continue
            x1, y1, x2, y2 = rect
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, f"ID {object_id}", (cx - 10, cy - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

        cv2.imshow('Multi-Person Tracking', frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Release the capture and windows even if inference raises or is interrupted.
    cap.release()
    cv2.destroyAllWindows()