# 1. Imports

In [None]:
import os, json, cv2, numpy as np, matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.transforms import functional as F

import albumentations as A # Library for augmentations

In [None]:
import transforms, utils, engine, train
from utils import collate_fn
from engine import train_one_epoch, evaluate

# 2. Augmentations

In [None]:
def train_transform():
    """Build the albumentations pipeline used for training samples.

    Each sample goes through a rotation by a random multiple of 90 degrees
    followed by a brightness/contrast jitter; both steps and the enclosing
    ``Sequential`` are configured with ``p=1``.

    Keypoints are supplied as ``(x, y)`` pairs (format ``'xy'``) and boxes in
    ``'pascal_voc'`` format, with their labels passed through the
    ``bboxes_labels`` field.
    """
    # Random rotation of an image by 90 degrees zero or more times
    rotate = A.RandomRotate90(p=1)
    # Random change of brightness & contrast
    jitter = A.RandomBrightnessContrast(
        brightness_limit=0.3,
        contrast_limit=0.3,
        brightness_by_max=True,
        always_apply=False,
        p=1,
    )
    steps = [A.Sequential([rotate, jitter], p=1)]

    # Keypoint formats: https://albumentations.ai/docs/getting_started/keypoints_augmentation/
    keypoint_params = A.KeypointParams(format='xy')
    # Bboxes should have labels: https://albumentations.ai/docs/getting_started/bounding_boxes_augmentation/
    bbox_params = A.BboxParams(format='pascal_voc', label_fields=['bboxes_labels'])

    return A.Compose(steps, keypoint_params=keypoint_params, bbox_params=bbox_params)

# 3. Dataset class

In [None]:
class ClassDataset(Dataset):
    """Keypoint-detection dataset read from ``<root>/images`` and ``<root>/annotations``.

    An image is used only when a ``<stem>.json`` annotation with the same
    stem exists. Each annotation file is a JSON object with two keys:

    * ``'bboxes'``    - one ``[x_min, y_min, x_max, y_max]`` list per object
    * ``'keypoints'`` - per object, a list of ``[x, y, visibility]`` triples

    Args:
        root: Folder containing the ``images`` and ``annotations`` sub-folders.
        transform: Optional callable invoked as
            ``transform(image=..., bboxes=..., bboxes_labels=..., keypoints=...)``
            that returns a mapping with ``'image'``, ``'bboxes'`` and
            ``'keypoints'`` entries (the pipeline from ``train_transform`` in
            this file is used that way).
        demo: When true, ``__getitem__`` additionally returns the untransformed
            image and target so both versions can be compared.
    """

    def __init__(self, root, transform=None, demo=False):
        self.root = root
        self.transform = transform
        self.demo = demo

        self.img_folder = os.path.join(root, "images")
        self.annot_folder = os.path.join(root, "annotations")

        # One pass over the image folder: stem -> full file name (with suffix).
        # Hidden files (names starting with ".") are ignored.
        image_files = {
            os.path.splitext(f)[0]: f
            for f in os.listdir(self.img_folder)
            if not f.startswith('.')
        }
        # Only ".json" files can be paired, because __getitem__ always opens
        # "<stem>.json"; any other file type here would pair and then fail to load.
        annotation_stems = {
            os.path.splitext(f)[0]
            for f in os.listdir(self.annot_folder)
            if not f.startswith('.') and f.endswith('.json')
        }

        # Keep only the stems that have both an image and an annotation.
        self.valid_names = sorted(image_files.keys() & annotation_stems)

        # Print out the actual quantity loaded for easy debugging
        print(f"ðŸ“‚ loading {root} ...")
        print(f"   - Original image: {len(image_files)}, Original annotation: {len(annotation_stems)}")
        print(f"   - âœ… Effective pairing: {len(self.valid_names)}")

        # Stem -> image file name, restricted to the usable pairs.
        self.filename_map = {name: image_files[name] for name in self.valid_names}

    @staticmethod
    def _build_target(bboxes, keypoints, idx):
        """Pack boxes and keypoints into the target dict returned with each image."""
        boxes = torch.as_tensor(bboxes, dtype=torch.float32)
        if boxes.numel() == 0:
            # An empty list becomes a 1-D tensor; give it a column axis so the
            # area computation below still works for images without objects.
            boxes = boxes.reshape(0, 4)
        return {
            "boxes": boxes,
            "labels": torch.ones(len(boxes), dtype=torch.int64),  # single foreground class
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros(len(boxes), dtype=torch.int64),
            # NOTE(review): for an image without objects this is an empty 1-D
            # tensor; confirm the consumer accepts that shape.
            "keypoints": torch.as_tensor(keypoints, dtype=torch.float32),
        }

    def _apply_transform(self, image, bboxes, keypoints):
        """Run ``self.transform`` and re-attach each keypoint's visibility flag."""
        labels = ['Tower'] * len(bboxes)  # category name, one per box
        # The transform receives a flat list of (x, y) pairs without visibility.
        flat_xy = [kp[0:2] for obj in keypoints for kp in obj]

        transformed = self.transform(image=image, bboxes=bboxes, bboxes_labels=labels, keypoints=flat_xy)
        if not keypoints:
            return transformed['image'], transformed['bboxes'], []

        # Regroup the flat list per object; the keypoint count is taken from
        # the first object of the annotation.
        num_kps = len(keypoints[0])
        regrouped = np.reshape(np.array(transformed['keypoints']), (-1, num_kps, 2)).tolist()

        restored = []
        for o_idx, obj_xy in enumerate(regrouped):
            source = keypoints[o_idx]
            restored.append([
                # Falls back to visibility 1 if the source object has fewer keypoints.
                xy + [source[k_idx][2] if k_idx < len(source) else 1]
                for k_idx, xy in enumerate(obj_xy)
            ])
        return transformed['image'], transformed['bboxes'], restored

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(img, target)``, or ``(img, target, img_original, target_original)``
            when ``demo`` is true. Images are tensors produced by
            ``F.to_tensor``; targets are dicts with the keys ``boxes``,
            ``labels``, ``image_id``, ``area``, ``iscrowd`` and ``keypoints``.

        Raises:
            OSError: If ``cv2.imread`` cannot read the image file.
        """
        base_name = self.valid_names[idx]
        img_path = os.path.join(self.img_folder, self.filename_map[base_name])
        annotations_path = os.path.join(self.annot_folder, base_name + ".json")

        img_bgr = cv2.imread(img_path)
        if img_bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"cv2.imread could not read image: {img_path}")
        img_original = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

        with open(annotations_path, encoding="utf-8") as f:
            data = json.load(f)
        bboxes_original = data['bboxes']
        keypoints_original = data['keypoints']

        if self.transform:
            img, bboxes, keypoints = self._apply_transform(img_original, bboxes_original, keypoints_original)
        else:
            img, bboxes, keypoints = img_original, bboxes_original, keypoints_original

        target = self._build_target(bboxes, keypoints, idx)
        img = F.to_tensor(img)

        if not self.demo:
            return img, target

        # The untransformed copy is only needed for side-by-side inspection.
        target_original = self._build_target(bboxes_original, keypoints_original, idx)
        return img, target, F.to_tensor(img_original), target_original

    def __len__(self):
        # Number of image/annotation pairs found at construction time.
        return len(self.valid_names)

# 4. Visualizing a random item from the dataset

In [None]:
# Root of the training split; ClassDataset looks for "images" and "annotations" below it.
KEYPOINTS_FOLDER_TRAIN = 'pgtt_keypoints_dataset_imgs/drum/train'

# demo=True makes each item a 4-tuple: (image, target, original image, original target).
dataset = ClassDataset(KEYPOINTS_FOLDER_TRAIN, transform=train_transform(), demo=True)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

# Draw one batch (shuffle=True, so a different sample on each fresh iterator).
iterator = iter(data_loader)
batch = next(iterator)

# batch[1] / batch[3] are indexed as the transformed / original targets below;
# presumably collate_fn regroups the batch per field -- confirm in utils.collate_fn.
print("Original targets:\n", batch[3], "\n\n")
print("Transformed targets:\n", batch[1])

In [None]:
def visualize(image, bboxes, keypoints, image_original=None, bboxes_original=None, keypoints_original=None):
    """Draw boxes and paired keypoints on an image and show it with matplotlib.

    Keypoints of one object are consumed two at a time: within each pair the
    point with the smaller x is labelled ``l<level>`` and the other
    ``r<level>``, where ``level`` counts pairs starting at 1. A trailing
    unpaired keypoint is not drawn.

    Args:
        image: Image array to annotate (the callers in this file pass uint8
            RGB arrays). It is copied first, so the caller's array is not
            modified.
        bboxes: Iterable of ``[x1, y1, x2, y2]`` boxes, or ``None``.
        keypoints: Iterable with one list of ``[x, y]`` points per object, or
            ``None``. Callers in this file pass integer pixel coordinates.
        image_original: Optional second image. When given, a side-by-side
            "Original image" / "Transformed image" figure is shown instead of
            a single image.
        bboxes_original: Boxes for ``image_original`` (``None`` means none).
        keypoints_original: Keypoints for ``image_original`` (``None`` means none).
    """
    fontsize = 18
    color_blue = (0, 0, 255)
    color_outline = (255, 255, 255)
    color_box = (0, 255, 0)

    font_type = cv2.FONT_HERSHEY_DUPLEX
    font_scale = 0.5
    font_thickness = 1

    def draw_labelled_point(img, point, label):
        # Filled dot plus its label; the label is drawn twice (thick light
        # pass, then thin blue pass) so it stays readable on any background.
        img = cv2.circle(img, point, 3, color_blue, -1)
        cv2.putText(img, label, point, font_type, font_scale, color_outline, font_thickness + 2, cv2.LINE_AA)
        cv2.putText(img, label, point, font_type, font_scale, color_blue, font_thickness, cv2.LINE_AA)
        return img

    def draw_pairs(img, kps):
        # Step through consecutive pairs (0,1), (2,3), ...; stopping at
        # len(kps) - 1 skips a trailing unpaired point.
        for i in range(0, len(kps) - 1, 2):
            pt1, pt2 = tuple(kps[i]), tuple(kps[i + 1])
            level = i // 2 + 1

            # The point with the smaller x is the left one; ties keep the
            # second point as "left", as before.
            left_pt, right_pt = (pt1, pt2) if pt1[0] < pt2[0] else (pt2, pt1)

            img = draw_labelled_point(img, left_pt, f"l{level}")
            img = draw_labelled_point(img, right_pt, f"r{level}")
        return img

    def annotate(img, boxes, kps_per_object):
        # Work on a single copy so the input is never drawn on, even when
        # there are no boxes.
        canvas = img.copy()
        for bbox in (() if boxes is None else boxes):
            canvas = cv2.rectangle(canvas, (bbox[0], bbox[1]), (bbox[2], bbox[3]), color_box, 2)
        for kps in (() if kps_per_object is None else kps_per_object):
            canvas = draw_pairs(canvas, kps)
        return canvas

    image = annotate(image, bboxes, keypoints)

    if image_original is None:
        # Single-image view. Without an original image there is nothing to
        # compare against, whatever the other *_original arguments hold.
        plt.figure(figsize=(20,20))
        plt.imshow(image)
        plt.axis('off')
        plt.show()
        return

    # Comparison display for training set
    image_original = annotate(image_original, bboxes_original, keypoints_original)

    f, ax = plt.subplots(1, 2, figsize=(40, 20))
    ax[0].imshow(image_original)
    ax[0].set_title('Original image', fontsize=fontsize)
    ax[1].imshow(image)
    ax[1].set_title('Transformed image', fontsize=fontsize)
    plt.show()

# Turn the first sample of the demo batch into plain arrays/lists for drawing:
# channel-first tensors become HxWxC uint8 images, boxes and keypoints become
# integer lists, and only the (x, y) part of every keypoint is kept.
image = (batch[0][0].permute(1,2,0).numpy() * 255).astype(np.uint8)
bboxes = batch[1][0]['boxes'].detach().cpu().numpy().astype(np.int32).tolist()
keypoints = [
    [kp[:2] for kp in kps]
    for kps in batch[1][0]['keypoints'].detach().cpu().numpy().astype(np.int32).tolist()
]

# Same conversion for the untransformed image and target.
image_original = (batch[2][0].permute(1,2,0).numpy() * 255).astype(np.uint8)
bboxes_original = batch[3][0]['boxes'].detach().cpu().numpy().astype(np.int32).tolist()
keypoints_original = [
    [kp[:2] for kp in kps]
    for kps in batch[3][0]['keypoints'].detach().cpu().numpy().astype(np.int32).tolist()
]

visualize(image, bboxes, keypoints, image_original, bboxes_original, keypoints_original)

# 5. Training

In [None]:
def get_model(num_keypoints, weights_path=None):
    """Create a Keypoint R-CNN (ResNet-50 FPN) with a custom anchor generator.

    Args:
        num_keypoints: Number of keypoints predicted per detected object.
        weights_path: Optional path to a ``state_dict`` file saved with
            ``torch.save(model.state_dict(), ...)``. When given, it is loaded
            into the freshly built model.

    Returns:
        The model, configured with ``num_classes=2``. It is not moved to any
        device here; callers do that themselves.
    """
    # Custom anchors: five sizes combined with seven aspect ratios to handle
    # various object scales and ratios.
    anchor_generator = AnchorGenerator(
        sizes=(32, 64, 128, 256, 512),
        aspect_ratios=(0.25, 0.5, 0.75, 1.0, 2.0, 3.0, 4.0),
    )

    # Detection heads start untrained (pretrained=False); only the backbone
    # is requested with pretrained weights.
    model = torchvision.models.detection.keypointrcnn_resnet50_fpn(pretrained=False,
                                                                   pretrained_backbone=True,
                                                                   num_keypoints=num_keypoints,
                                                                   num_classes=2,
                                                                   rpn_anchor_generator=anchor_generator)

    if weights_path:
        # map_location="cpu" lets a checkpoint saved from a CUDA model load on
        # a machine without a GPU; load_state_dict then copies the values into
        # the model's own parameters.
        # NOTE: torch.load unpickles the file -- only load checkpoints from a
        # trusted source.
        state_dict = torch.load(weights_path, map_location="cpu")
        model.load_state_dict(state_dict)

    return model

In [None]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

KEYPOINTS_FOLDER_TRAIN = 'pgtt_keypoints_dataset_imgs/drum/train'
KEYPOINTS_FOLDER_TEST = 'pgtt_keypoints_dataset_imgs/drum/test'

# Augmentation is applied to the training split only; the test split is read as-is.
dataset_train = ClassDataset(KEYPOINTS_FOLDER_TRAIN, transform=train_transform(), demo=False)
dataset_test = ClassDataset(KEYPOINTS_FOLDER_TEST, transform=None, demo=False)

data_loader_train = DataLoader(dataset_train, batch_size=4, shuffle=True, collate_fn=collate_fn)
data_loader_test = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)

# 22 keypoints per object; no weights_path, so the model is built without loading a checkpoint.
model = get_model(num_keypoints = 22)
model.to(device)

# Optimise only the parameters that require gradients.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.002, momentum=0.9, weight_decay=0.0005)
# The scheduler is stepped once per epoch below, so the learning rate is
# multiplied by 0.3 after every 8 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=8, gamma=0.3)
num_epochs = 20

# One pass per epoch: train, advance the LR schedule, then evaluate on the test split.
for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=1000)
    lr_scheduler.step()
    evaluate(model, data_loader_test, device)
    
# Save model weights after training
torch.save(model.state_dict(), 'keypointsrcnn_weights.pth')

# 6. Visualizing model predictions

In [None]:
# Fresh iterator over the test loader (shuffle=False, so it starts at the first test sample).
iterator = iter(data_loader_test)

In [None]:
# Fetch the next test batch and move its images to the training device.
images, targets = next(iterator)
images = [img.to(device) for img in images]

# Switch to evaluation mode and run the forward pass without tracking gradients.
model.to(device)
model.eval()
with torch.no_grad():
    output = model(images)

print("Predictions: \n", output)

In [None]:
# Convert the first image of the batch to an HxWxC uint8 array for drawing.
image = (images[0].permute(1,2,0).detach().cpu().numpy() * 255).astype(np.uint8)
scores = output[0]['scores'].detach().cpu().numpy()

# Indexes of boxes with scores > 0.7
high_scores_idxs = np.nonzero(scores > 0.7)[0].tolist()
# Indexes (into the score-filtered subset) left after applying NMS (iou_threshold=0.3)
post_nms_idxs = torchvision.ops.nms(
    output[0]['boxes'][high_scores_idxs],
    output[0]['scores'][high_scores_idxs],
    0.3,
).cpu().numpy()

# Keep the surviving detections as integer pixel coordinates; only (x, y) of each keypoint.
keypoints = [
    [[int(v) for v in kp[:2]] for kp in kps]
    for kps in output[0]['keypoints'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
]

bboxes = [
    [int(v) for v in bbox.tolist()]
    for bbox in output[0]['boxes'][high_scores_idxs][post_nms_idxs].detach().cpu().numpy()
]

visualize(image, bboxes, keypoints)