In [21]:
import torch
import torchvision
from torchvision.models.detection import keypointrcnn_resnet50_fpn
import torch
from torchvision.models.detection import keypointrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image
import os
import json
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import cv2

In [19]:
# Log the PyTorch version next to the outputs; the `weights=` API below needs torchvision >= 0.13.
print(torch.__version__)
# Keypoint R-CNN (ResNet-50 FPN backbone) with pretrained COCO person-keypoint weights.
# `weights="DEFAULT"` replaces the deprecated `pretrained=True`; per the torchvision docs
# both select the default COCO checkpoint, but the new form avoids the deprecation warning
# (and the eventual removal of `pretrained`).
model = keypointrcnn_resnet50_fpn(weights="DEFAULT")

2.1.0


In [None]:
input_dir = "heatmaps/dataset/train"
output_json = "heatmaps/dataset/annotations/keypoints_train.json"
visualize = False  # True -> plot every image with its detected keypoints
MIN_SCORE = 0.8  # minimum detection score for a person's keypoints to be kept
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")  # matched case-insensitively


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model must live on the same device as its inputs. Moving only the image
# tensor (as before) fails with a device-mismatch error whenever CUDA is available.
model.to(device)
model.eval()

os.makedirs(os.path.dirname(output_json), exist_ok=True)

# image filename -> {"keypoints": [person, ...]}, where each person is a list of
# [x, y] integer pixel coordinates (fractional parts truncated by int()).
annotations = {}
colors = ['r', 'g', 'b', 'y', 'c', 'm']  # cycled per person when visualizing

# os.listdir returns entries in arbitrary order; sorting makes the processing
# order (and therefore the JSON key order) reproducible across runs/machines.
for filename in sorted(os.listdir(input_dir)):
    if not filename.lower().endswith(IMAGE_EXTENSIONS):
        continue
    path = os.path.join(input_dir, filename)
    # The context manager releases the file handle; convert() yields an in-memory copy.
    with Image.open(path) as raw_img:
        img = raw_img.convert("RGB")
    img_tensor = F.to_tensor(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(img_tensor)[0]

    # Per the torchvision docs, output["keypoints"] is (N, K, 3) as [x, y, visibility]
    # and output["scores"] is (N,). Keep confident detections and only their x/y columns.
    keep = output["scores"] >= MIN_SCORE
    person_xy = output["keypoints"][keep][:, :, :2].cpu().numpy().tolist()
    keypoints_all = [[[int(x), int(y)] for x, y in person] for person in person_xy]

    annotations[filename] = {"keypoints": keypoints_all}

    if visualize:
        fig, ax = plt.subplots(1)
        ax.imshow(img)
        for idx, person in enumerate(keypoints_all):
            for x, y in person:
                ax.add_patch(patches.Circle((x, y), 3, color=colors[idx % len(colors)]))
        ax.set_title(filename)
        ax.axis("off")
        plt.show()
        plt.close(fig)  # don't accumulate open figures over a large dataset

# Save the keypoints
with open(output_json, "w") as f:
    json.dump(annotations, f, indent=2)


In [22]:
def generate_combined_heatmap(image_shape, keypoints, sigma=20):
    """Render a set of keypoints into a single Gaussian heatmap.

    Every keypoint that lies inside the image becomes an impulse blurred with
    ``cv2.GaussianBlur`` (standard deviation ``sigma`` pixels, kernel size derived
    from sigma) and rescaled so its own maximum is 1. The blobs are summed, and the
    total is clipped to [0, 1], so overlapping keypoints saturate instead of exceeding 1.

    Args:
        image_shape: Shape of the output array; element 0 is the height and
            element 1 the width (e.g. ``(height, width)``).
        keypoints: Iterable of ``(x, y)`` pixel coordinates. Points left of / above
            the origin or at/after the last column/row are ignored.
        sigma: Gaussian standard deviation in pixels.

    Returns:
        float32 array of shape ``image_shape`` with values in [0, 1].
    """
    def _inside(px, py):
        # Reject anything before the origin or past the last column / row.
        return not (px < 0 or py < 0 or px >= image_shape[1] or py >= image_shape[0])

    combined = np.zeros(image_shape, dtype=np.float32)
    visible_points = ((px, py) for px, py in keypoints if _inside(px, py))

    for px, py in visible_points:
        impulse = np.zeros(image_shape, dtype=np.float32)
        impulse[int(py), int(px)] = 1
        blob = cv2.GaussianBlur(impulse, (0, 0), sigma)
        # Per-blob normalisation gives every keypoint a peak of 1 regardless of sigma.
        combined += blob / blob.max()

    return np.clip(combined, 0, 1)

def overlay_and_save_heatmap(image_path, heatmap, output_path, alpha=0.5):
    """Blend a JET-coloured heatmap over an image and save the result.

    Args:
        image_path: Background image; converted to RGB.
        heatmap: 2-D array of intensities, expected in [0, 1]. Values outside that
            range are clipped. It is resized with ``cv2.resize`` (default bilinear
            interpolation) to the image size.
        output_path: Destination file; matplotlib infers the format from the extension.
        alpha: Heatmap weight in the blend (0 -> image only, 1 -> heatmap only).
    """
    # The context manager closes the file; convert() yields an in-memory RGB copy.
    with Image.open(image_path) as img:
        img_np = np.asarray(img.convert("RGB"), dtype=np.float32) / 255.0
    height, width = img_np.shape[:2]

    # cv2.resize expects dsize as (width, height).
    heatmap_resized = cv2.resize(heatmap, (width, height))
    # Clip before the uint8 cast: out-of-range floats do not saturate when cast
    # to uint8 (e.g. 1.2 or -0.1 would wrap) and would be painted the wrong colour.
    heatmap_u8 = (255 * np.clip(heatmap_resized, 0.0, 1.0)).astype(np.uint8)
    heatmap_color = cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR; matplotlib expects RGB.
    heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB) / 255.0

    # Clip guards against alpha values outside [0, 1].
    overlay = np.clip((1 - alpha) * img_np + alpha * heatmap_color, 0, 1)
    plt.imsave(output_path, overlay)

In [25]:
def process_dataset(images_folder, annotation_path, output_folder, sigma=20):
    """Render a keypoint heatmap and a colour overlay for every annotated image.

    The annotation JSON maps each image name to ``{"keypoints": [person, ...]}``,
    where every person is a list of ``[x, y]`` points. All people of an image are
    merged into one heatmap via ``generate_combined_heatmap``. Two files are
    written into ``output_folder`` for each image:

    * ``<stem>_heatmap.npy``: the raw heatmap as saved by ``np.save``;
    * ``<stem>_overlay.png``: the heatmap blended over the source image.

    Args:
        images_folder: Directory containing the images named in the annotations.
        annotation_path: Path of the JSON annotation file.
        output_folder: Destination directory; created if missing.
        sigma: Gaussian standard deviation (pixels) for every keypoint blob.

    Images listed in the annotations but missing on disk are reported and skipped.
    """
    os.makedirs(output_folder, exist_ok=True)

    with open(annotation_path, 'r') as f:
        annotations = json.load(f)

    for image_name, data in annotations.items():
        # Flatten [person][keypoint] -> [keypoint]: all people share one heatmap.
        keypoints = [kp for person in data["keypoints"] for kp in person]
        image_path = os.path.join(images_folder, image_name)

        # Only the size is needed here (no pixel decode); the overlay step reads
        # the pixels itself. The try covers just the open, so the message below
        # cannot be triggered by unrelated FileNotFoundErrors from the save steps.
        try:
            with Image.open(image_path) as img:
                image_shape = img.size[::-1]  # PIL gives (width, height) -> (height, width)
        except FileNotFoundError:
            print(f"Image not found: {image_path}")
            continue

        heatmap = generate_combined_heatmap(image_shape, keypoints, sigma)

        # splitext works for every extension (.jpg/.jpeg/.png). str.replace('.png', ...)
        # left non-PNG names unchanged (-> "x.jpg.npy" and an overlay named "x.jpg")
        # and also matched '.png' in the middle of a name.
        stem = os.path.splitext(image_name)[0]

        # Save raw .npy heatmap
        np.save(os.path.join(output_folder, f"{stem}_heatmap.npy"), heatmap)

        # Save overlay as image
        overlay_path = os.path.join(output_folder, f"{stem}_overlay.png")
        overlay_and_save_heatmap(image_path, heatmap, overlay_path)

    print(f"✔ All heatmaps saved in: {output_folder}")

In [27]:
# Turn the keypoint annotations into per-image heatmaps (.npy) and overlays (.png).
process_dataset(
    annotation_path="heatmaps/dataset/annotations/keypoints_train.json",
    images_folder="heatmaps/dataset/train/",
    output_folder="heatmaps/dataset/heatmaps",
    sigma=20,  # Gaussian std-dev in pixels for every keypoint blob
)