In [1]:
from google.colab import drive
DRIVE_MOUNT_POINT = '/content/drive'  # the data/model paths used below all live under this mount
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [2]:
import torchvision.models as models
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import os
import random
import json
from collections import defaultdict
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau

## Dataset of images and keypoint heatmaps (the ResNet network is created and modified further below)

In [3]:
class KeypointHeatmapDataset(Dataset):
    """Dataset of (normalized RGB image, keypoint heatmaps) pairs.

    Files are paired by basename: ``<img_dir>/<name><img_ext>`` goes with
    ``<hm_dir>/<name><hm_ext>``. The part of ``name`` before the first ``_``
    (e.g. ``"05"`` in ``"05_0724"``) is treated as the object id, and a subset of
    ``int(ratio * total)`` samples is drawn so that every object keeps its share
    of the data (largest-remainder rounding).

    Args:
        img_dir: folder containing the input images.
        hm_dir: folder containing the heatmap tensors (read with ``torch.load``).
        img_ext: image file extension.
        hm_ext: heatmap file extension.
        ratio: fraction of the paired samples to keep, in [0, 1]. The default
            ``1`` keeps every sample.
        seed: seed of the subset sampling. A private ``random.Random`` is used,
            so the global ``random`` state is left untouched.
    """

    def __init__(self, img_dir, hm_dir, img_ext=".png", hm_ext=".pt", ratio=1, seed=42):
        self.img_dir = img_dir
        self.hm_dir  = hm_dir
        self.img_ext = img_ext
        self.hm_ext  = hm_ext

        img_basenames = self._list_basenames(self.img_dir, self.img_ext)
        hm_basenames  = self._list_basenames(self.hm_dir, self.hm_ext)
        missing_in_hm  = img_basenames - hm_basenames
        missing_in_img = hm_basenames - img_basenames
        print("Missing heatmaps for images:", missing_in_hm)
        print("Missing images for heatmaps:", missing_in_img)

        # Only images that have a heatmap are usable; keeping the others would make
        # __getitem__ fail with FileNotFoundError partway through an epoch.
        paired = sorted(img_basenames & hm_basenames)
        self.basenames = self._stratified_subset(paired, ratio, seed)

        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # normalize with the ImageNet channel mean/std
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std =[0.229, 0.224, 0.225]
            ),
        ])

    @staticmethod
    def _list_basenames(folder, ext):
        """Return the set of file names in ``folder`` that end with ``ext``, without the extension."""
        return {
            os.path.splitext(f)[0]
            for f in os.listdir(folder)
            if f.endswith(ext)
        }

    @staticmethod
    def _stratified_subset(names, ratio=1, seed=42):
        """Pick ``int(len(names) * ratio)`` names, split across object ids in proportion
        to their size, and return them sorted.

        Each object id (prefix before the first ``_``) first gets ``floor(len * ratio)``
        slots; the remaining slots go to the ids with the largest fractional parts.
        """
        groups = defaultdict(list)
        for name in names:
            groups[name.split("_")[0]].append(name)

        target_n = int(len(names) * ratio)
        floors, remainders = {}, {}
        for obj_id, lst in groups.items():
            raw = len(lst) * ratio
            floors[obj_id] = int(raw)
            remainders[obj_id] = raw - floors[obj_id]

        # Clamp at 0: float rounding could make the floors overshoot target_n, and a
        # negative slice bound ([:-n]) would then hand out extra slots instead of none.
        leftover = max(0, target_n - sum(floors.values()))
        for obj_id in sorted(remainders, key=remainders.get, reverse=True)[:leftover]:
            floors[obj_id] += 1

        rng = random.Random(seed)
        chosen = []
        for obj_id, lst in groups.items():
            chosen.extend(rng.sample(lst, floors[obj_id]))
        return sorted(chosen)

    def __len__(self):
        return len(self.basenames)

    def __getitem__(self, idx):
        """Return ``(image, heatmaps)`` for sample ``idx``.

        With the data used in this notebook the shapes are [3, 256, 256] and
        [50, 64, 64]; in general they follow whatever is stored on disk.
        """
        name = self.basenames[idx]
        img_path = os.path.join(self.img_dir, name + self.img_ext)
        # `with` closes the file handle; convert() returns a new, fully loaded image.
        with Image.open(img_path) as im:
            img = im.convert("RGB")
        img = self.transform(img)
        hm_path = os.path.join(self.hm_dir, name + self.hm_ext)
        # Load onto the CPU: DataLoader worker processes should not create CUDA tensors.
        heatmaps = torch.load(hm_path, map_location="cpu")

        return img, heatmaps

# The dataset above returns (normalized image, heatmaps) pairs in the format the ResNet keypoint network expects.


Definition of the dataset and DataLoader

In [36]:
# Data on Google Drive: the cropped/resized images and, per the folder names,
# FPS-sampled keypoint heatmaps rendered with sigma 2.
DATA_ROOT = "/content/drive/MyDrive/MLDL/6D-Pose-Estimation/data"
img_folder = os.path.join(DATA_ROOT, "cropped_resized_data")
hm_folder  = os.path.join(DATA_ROOT, "point_sampling_data/heatmaps_sigma_2/heatmaps_fps")

dataset = KeypointHeatmapDataset(img_dir=img_folder, hm_dir=hm_folder, img_ext=".png", hm_ext=".pt")

# Loader over the full dataset, used below for a quick shape check.
loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)

Missing heatmaps for images: set()
Missing images for heatmaps: {'05_0724', '08_0717', '02_0680', '05_0299', '15_0029', '02_0647'}


In [26]:
# Sanity check: pull one batch from the full-dataset loader and inspect its shapes.
batch = next(iter(loader))
imgs, hms = batch
print(imgs.shape)  # torch.Size([16, 3, 256, 256])
print(hms.shape)   # torch.Size([16, 50, 64, 64])

torch.Size([16, 3, 256, 256])
torch.Size([16, 50, 64, 64])


Take the pretrained ResNet-18, remove its final pooling and fully-connected layers, add a new head of layers, and produce an output of shape [batch_size, num_keypoints, 64, 64].

In [37]:
class HeatmapHead(nn.Module):
    """Upsampling head that turns a backbone feature map into keypoint heatmaps.

    Three transposed convolutions (kernel 4, stride 2, padding 1), each followed by
    BatchNorm and ReLU, double the spatial size three times (8x overall). A final
    1x1 convolution then maps the 256 channels to one heatmap per keypoint. With a
    256x256 image and a backbone that downsamples by 32, the input is 8x8 and the
    output 64x64.

    Args:
        num_keypoints: number of output heatmap channels.
        in_channels: channel count of the incoming feature map.
    """

    def __init__(self, num_keypoints, in_channels):
        super().__init__()
        layers = []
        channels = in_channels
        for _ in range(3):
            # kernel 4 / stride 2 / padding 1 exactly doubles H and W
            layers += [
                nn.ConvTranspose2d(channels, 256, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
            ]
            channels = 256
        # Layers end up at Sequential indices 0..8, so state_dict keys
        # (deconv.0.weight, deconv.1.running_mean, ...) are unchanged.
        self.deconv = nn.Sequential(*layers)
        # 1x1 conv: 256 -> num_keypoints channels, spatial size unchanged
        self.final = nn.Conv2d(256, num_keypoints, kernel_size=1)

    def forward(self, x):
        """[B, in_channels, H, W] -> [B, num_keypoints, 8H, 8W]."""
        upsampled = self.deconv(x)
        return self.final(upsampled)

In [38]:
class KeypointHeatmapNet(nn.Module):
    """ResNet-18 backbone followed by a HeatmapHead.

    The backbone's final average-pool and fully-connected layers are dropped, so it
    returns a spatial feature map. HeatmapHead upsamples that map into one heatmap
    per keypoint.

    Args:
        num_keypoints: number of heatmaps to predict.
        pretrained: load torchvision's ImageNet weights into the backbone (default
            True, as before). Pass False when the whole state_dict will be restored
            from a checkpoint anyway, to skip the weight download.
    """

    def __init__(self, num_keypoints=50, pretrained=True):
        super().__init__()
        backbone = models.resnet18(pretrained=pretrained)
        # Take the channel count from the backbone itself (512 for ResNet-18) rather
        # than hard-coding it, so swapping the ResNet variant cannot mismatch the head.
        in_channels = backbone.fc.in_features
        # children()[:-2] strips the final avgpool and fc layers
        self.backbone = nn.Sequential(*list(backbone.children())[:-2])

        self.head = HeatmapHead(num_keypoints, in_channels=in_channels)

    def forward(self, x):
        # For the 256x256 inputs used here:
        # x [B, 3, 256, 256] -> feat [B, 512, 8, 8] -> heatmaps [B, num_keypoints, 64, 64]
        feat = self.backbone(x)
        heatmaps = self.head(feat)
        return heatmaps

## Defining the model, loss function and optimizer

In [39]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [40]:
# Report which device is in use and, on CUDA, which GPU.
print("Using device:", device)
on_gpu = device.type == "cuda"
if on_gpu:
    gpu_name = torch.cuda.get_device_name(0)
    print("GPU name:", gpu_name)

Using device: cuda
GPU name: Tesla T4


Splitting the data (80% training / 20% validation)

In [41]:
# 80/20 train/validation split; the fixed generator seed makes it reproducible.
n_total = len(dataset)
n_train = int(0.8 * n_total)
n_val = n_total - n_train

split_generator = torch.Generator().manual_seed(8)
train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=split_generator)

Instantiate the DataLoaders, model, loss function, optimizer and learning-rate scheduler

In [42]:
#definition of a loss function
import torch.nn.functional as F

def focal_heatmap_loss(pred, gt, alpha=2.0, gamma=4.0, eps=1e-6):
    """Focal-style binary cross-entropy between predicted logits and target heatmaps.

    Per element, with p = sigmoid(pred):
        pos = -alpha * (1 - p)**gamma * gt * log(p)
        neg = -(1 - gt) * log(1 - p)
    and the loss is the mean of (pos + neg) over all elements.

    Args:
        pred: raw network output (logits), shape [B, K, H, W].
        gt: ground-truth heatmaps with values in [0, 1], same shape as ``pred``.
        alpha: weight of the positive term.
        gamma: focusing exponent; down-weights pixels already predicted confidently.
        eps: accepted for backward compatibility but no longer used. The
            log-probabilities are computed with ``logsigmoid``, which is finite
            for every logit.

    Returns:
        Scalar tensor.
    """
    # 1) turn the raw output into a confidence in (0, 1)
    p = torch.sigmoid(pred)

    # log(p) = logsigmoid(pred) and log(1 - p) = logsigmoid(-pred), computed stably.
    # The former log(p + eps) form saturated once |pred| exceeded roughly 17 (float32).
    # There sigmoid rounds to exactly 0/1 and the gradient w.r.t. pred vanishes,
    # so confidently wrong pixels stopped being corrected.
    log_p = F.logsigmoid(pred)
    log_one_minus_p = F.logsigmoid(-pred)

    # 2) positive term: emphasizes pixels where gt is high
    pos = -alpha * (1 - p) ** gamma * gt * log_p

    # 3) negative term: background
    neg = -(1 - gt) * log_one_minus_p

    return (pos + neg).mean()

In [35]:
import os
# Number of CPU cores on this runtime; DataLoader `num_workers` should not exceed it.
os.cpu_count()

2

In [43]:
# Don't request more worker processes than CPU cores (this Colab VM reports 2).
# Extra workers only trigger a DataLoader warning and add overhead.
num_workers = min(4, os.cpu_count() or 1)
# Pinned host memory speeds up host->GPU copies; it only helps when training on CUDA.
pin_memory = device.type == "cuda"

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True,  num_workers=num_workers, pin_memory=pin_memory)
val_loader   = DataLoader(val_ds,   batch_size=16, shuffle=False, num_workers=num_workers, pin_memory=pin_memory)

model = KeypointHeatmapNet(num_keypoints=50).to(device)
criterion = focal_heatmap_loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
# Halve the LR after 5 epochs without validation improvement. `verbose=` is deprecated
# in recent PyTorch releases; read the current LR with scheduler.get_last_lr() instead.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

# Stop training after this many epochs without validation improvement.
early_stop_patience = 10

In [44]:
# Batches per epoch: the train loader's count is printed, the val loader's is the cell output.
print(len(train_loader))
len(val_loader)

711


178

Smoke test on one batch (forward pass and loss computation, to catch shape/device errors before training)

In [45]:
# Forward pass + loss on the batch fetched by the shape check above, to surface
# shape/device errors before the long training run.
model.train()
imgs_dev, hms_dev = imgs.to(device), hms.to(device)
preds = model(imgs_dev)
loss = criterion(preds, hms_dev)
print("Smoke test - loss:", loss.item())

Smoke test - loss: 0.7046369910240173


In [46]:
from tqdm import tqdm

# Per-epoch average losses, kept for plotting the learning curves.
train_losses = []
val_losses   = []

num_epochs = 25
best_val_loss = float("inf")
# Epochs since the last validation improvement. Initialized up front so the counter
# exists even if epoch 1 does not improve (e.g. a NaN loss).
no_improve = 0
best_ckpt_path = "best_fps_model_50_keypoints_focal_hm.pth"

for epoch in range(1, num_epochs+1):
    # ---- training ----
    model.train()
    running_train_loss = 0.0

    for imgs, gt_maps in tqdm(train_loader, desc=f"Epoch {epoch} - Training", leave=False):
        imgs, gt_maps = imgs.to(device), gt_maps.to(device)

        # forward
        preds = model(imgs)
        loss  = criterion(preds, gt_maps)

        # backward + step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; weighting by batch size makes the epoch value
        # a per-sample average even when the last batch is smaller.
        running_train_loss += loss.item() * imgs.size(0)

    epoch_train_loss = running_train_loss / len(train_loader.dataset)
    train_losses.append(epoch_train_loss)

    # ---- validation ----
    model.eval()
    running_val_loss = 0.0
    with torch.no_grad():
        for imgs, gt_maps in tqdm(val_loader, desc=f"Epoch {epoch} - Validation", leave=False):
            imgs, gt_maps = imgs.to(device), gt_maps.to(device)
            preds = model(imgs)
            running_val_loss += criterion(preds, gt_maps).item() * imgs.size(0)

    epoch_val_loss = running_val_loss / len(val_loader.dataset)
    val_losses.append(epoch_val_loss)

    # Report before a possible early-stopping break so the last epoch is logged too.
    print(f"Epoch {epoch:02d}/{num_epochs}  "
          f"Train Loss: {epoch_train_loss:.4f}  "
          f"Val Loss:   {epoch_val_loss:.4f}")

    # ReduceLROnPlateau: lower the learning rate when the validation loss stalls.
    scheduler.step(epoch_val_loss)

    # Keep the best checkpoint and update the early-stopping counter.
    if epoch_val_loss < best_val_loss:
        best_val_loss = epoch_val_loss
        no_improve = 0
        torch.save(model.state_dict(), best_ckpt_path)
    else:
        no_improve += 1

    print(f"Finished Epoch {epoch}")

    if no_improve >= early_stop_patience:
        print(f"Early stopping at epoch {epoch} (no improvement for {early_stop_patience} epochs)")
        break

print("Training finished!")



Epoch 01/25  Train Loss: 0.1479  Val Loss:   0.0609
Finished Epoch 1




Epoch 02/25  Train Loss: 0.0415  Val Loss:   0.0280
Finished Epoch 2




Epoch 03/25  Train Loss: 0.0227  Val Loss:   0.0204
Finished Epoch 3




Epoch 04/25  Train Loss: 0.0175  Val Loss:   0.0175
Finished Epoch 4




Epoch 05/25  Train Loss: 0.0150  Val Loss:   0.0157
Finished Epoch 5




Epoch 06/25  Train Loss: 0.0135  Val Loss:   0.0146
Finished Epoch 6




Epoch 07/25  Train Loss: 0.0125  Val Loss:   0.0138
Finished Epoch 7




Epoch 08/25  Train Loss: 0.0118  Val Loss:   0.0133
Finished Epoch 8




Epoch 09/25  Train Loss: 0.0113  Val Loss:   0.0132
Finished Epoch 9




Epoch 10/25  Train Loss: 0.0110  Val Loss:   0.0127
Finished Epoch 10




Epoch 11/25  Train Loss: 0.0108  Val Loss:   0.0126
Finished Epoch 11




Epoch 12/25  Train Loss: 0.0107  Val Loss:   0.0125
Finished Epoch 12




Epoch 13/25  Train Loss: 0.0106  Val Loss:   0.0123
Finished Epoch 13




Epoch 14/25  Train Loss: 0.0105  Val Loss:   0.0123
Finished Epoch 14




Epoch 15/25  Train Loss: 0.0104  Val Loss:   0.0123
Finished Epoch 15




Epoch 16/25  Train Loss: 0.0104  Val Loss:   0.0123
Finished Epoch 16




Epoch 17/25  Train Loss: 0.0103  Val Loss:   0.0123
Finished Epoch 17




Epoch 18/25  Train Loss: 0.0102  Val Loss:   0.0121
Finished Epoch 18




Epoch 19/25  Train Loss: 0.0102  Val Loss:   0.0121
Finished Epoch 19




Epoch 20/25  Train Loss: 0.0102  Val Loss:   0.0123
Finished Epoch 20




Epoch 21/25  Train Loss: 0.0101  Val Loss:   0.0121
Finished Epoch 21




Epoch 22/25  Train Loss: 0.0101  Val Loss:   0.0123
Finished Epoch 22




Epoch 23/25  Train Loss: 0.0101  Val Loss:   0.0120
Finished Epoch 23




Epoch 24/25  Train Loss: 0.0100  Val Loss:   0.0121
Finished Epoch 24




Epoch 25/25  Train Loss: 0.0100  Val Loss:   0.0120
Finished Epoch 25
Training finished!


Save the model

In [47]:
# Persist the weights currently held by `model` to Drive.
# NOTE(review): right after training these are the last-epoch weights; the training
# loop separately saved the best-validation weights to
# "best_fps_model_50_keypoints_focal_hm.pth" in the working directory.
SAVE_PATH = "/content/drive/MyDrive/MLDL/6D-Pose-Estimation/models/resnet/resnet18_fps_model_50_keypoints_focalhm.pth"
final_state = model.state_dict()
torch.save(final_state, SAVE_PATH)

In [None]:
# #for visualizations
# train_losses = []
# val_losses   = []

# num_epochs = 25
# best_val_loss = float("inf")

# for epoch in range(1, num_epochs+1):
#     model.train()
#     running_train_loss = 0.0

#     for imgs, gt_maps in train_loader:
#         imgs, gt_maps = imgs.to(device), gt_maps.to(device)

#         # forward
#         preds = model(imgs)
#         loss  = criterion(preds, gt_maps)

#         # backward + step
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()

#         running_train_loss += loss.item() * imgs.size(0)

#     epoch_train_loss = running_train_loss / len(train_loader.dataset)
#     train_losses.append(epoch_train_loss)

#     model.eval()
#     running_val_loss = 0.0
#     with torch.no_grad():
#         for imgs, gt_maps in val_loader:
#             imgs, gt_maps = imgs.to(device), gt_maps.to(device)
#             preds = model(imgs)
#             running_val_loss += criterion(preds, gt_maps).item() * imgs.size(0)

#     epoch_val_loss = running_val_loss / len(val_loader.dataset)
#     val_losses.append(epoch_val_loss)

#     #early stop
#     scheduler.step(epoch_val_loss)

#     if epoch_val_loss < best_val_loss:
#         best_val_loss = epoch_val_loss
#         no_improve = 0
#         torch.save(model.state_dict(), "best_cps_model_50_keypoints_focal_hm.pth")
#     else:
#         no_improve += 1

#     if no_improve >= early_stop_patience:
#         print(f"Early stopping at epoch {epoch} (no improvement for {early_stop_patience} epochs)")
#         break

#     print(f"Epoch {epoch:02d}/{num_epochs}  "
#           f"Train Loss: {epoch_train_loss:.4f}  "
#           f"Val Loss:   {epoch_val_loss:.4f}")

#     if epoch_val_loss < best_val_loss:
#         best_val_loss = epoch_val_loss
#         torch.save(model.state_dict(), "best_cps_model_50_keypoints_focal_hm.pth")

#     print(f"Finished Epoch {epoch}")

# print("Training finished!")

In [None]:
# # --- 1) Plot train i val loss-a ---
# epochs = range(1, len(train_losses) + 1)

# plt.figure(figsize=(8,5))
# plt.plot(epochs, train_losses, label='Train Loss')
# plt.plot(epochs, val_losses,   label='Val Loss')
# plt.xlabel('Epoch')
# plt.ylabel('Loss')
# plt.title('Training vs. Validation Loss')
# plt.legend()
# plt.grid(True)
# plt.show()

# model.eval()

# mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
# std  = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)

# imgs, gt_maps = next(iter(val_loader))
# imgs, gt_maps = imgs[:3], gt_maps[:3]

# with torch.no_grad():
#     preds = model(imgs.to(device)).cpu()

# imgs = imgs.cpu()

import numpy as np

def get_keypoints_from_heatmaps(hm_np):
    """Locate the peak of every heatmap channel.

    Args:
        hm_np: numpy array of shape [K, H, W].

    Returns:
        Array of shape [K, 2] holding (x, y) = (column, row) of each channel's
        maximum, in heatmap pixel coordinates (ties resolve to the first maximum in
        row-major order). For K == 0 an empty 1-D array is returned.
    """
    _, _, width = hm_np.shape
    # flat argmax -> (row, col)
    peaks = (divmod(np.argmax(channel), width) for channel in hm_np)
    return np.array([(col, row) for row, col in peaks])

scale = 4           # 64 → 256: heatmap pixel -> image pixel
offset = scale / 2  # centre of a heatmap cell, the same convention get_all_keypoints uses

# Build everything the plot needs here. The setup lines above are commented out, and
# after training `imgs`/`preds` hold the last validation batch on the GPU, which
# breaks the CPU-side denormalization and .numpy() calls below.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std  = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

model.eval()
imgs, gt_maps = next(iter(val_loader))
imgs, gt_maps = imgs[:3], gt_maps[:3]
with torch.no_grad():
    preds = model(imgs.to(device)).cpu()

# For each of the first (up to) 3 validation samples
for i in range(len(imgs)):
    # Denormalize and convert to H×W×C for matplotlib. Clamping removes the tiny
    # out-of-range values of the round trip, which imshow would warn about.
    img = (imgs[i] * std + mean).clamp(0, 1).permute(1, 2, 0).numpy()

    # Peak (x, y) per channel on the 64×64 grid, mapped to 256×256 pixel centres.
    true_pts_256 = get_keypoints_from_heatmaps(gt_maps[i].numpy()) * scale + offset
    pred_pts_256 = get_keypoints_from_heatmaps(preds[i].numpy()) * scale + offset

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.imshow(img)
    ax.scatter(true_pts_256[:, 0], true_pts_256[:, 1],
               c='lime', marker='o', label='GT')
    ax.scatter(pred_pts_256[:, 0], pred_pts_256[:, 1],
               c='red', marker='x', label='Pred')
    ax.legend(loc='lower right')
    ax.axis('off')
    plt.show()
    plt.close(fig)


In [None]:
# One figure per predicted heatmap channel of the first image in `preds`.
for k, hm in enumerate(preds[0]):
    fig, ax = plt.subplots()
    # .cpu(): `preds` may still live on the GPU (e.g. straight after the training loop).
    ax.imshow(hm.detach().cpu().numpy(), cmap='viridis')
    ax.set_title(f"Keypoint {k}")
    plt.show()
    plt.close(fig)

In [None]:
# preds.shape == (B, K, H, W)
batch_idx   = 1    # position within the batch (1 = the second image)
keypoint_id = 17   # which keypoint heatmap to display

# .cpu(): `preds` may still be on the GPU (e.g. straight after the training loop).
single_hm = preds[batch_idx, keypoint_id].detach().cpu().numpy()  # shape (H, W), 64x64 here
fig, ax = plt.subplots(figsize=(4, 4))
im = ax.imshow(single_hm, cmap='viridis')
ax.set_title(f"Image {batch_idx}, Keypoint {keypoint_id}")
fig.colorbar(im, ax=ax)
plt.show()

# Exporting all predicted keypoints to a JSON file

In [48]:
def get_all_keypoints(preds, hmap_size=64, img_size=256):
    """Convert a batch of heatmaps into keypoint coordinates in image pixels.

    Every channel's argmax cell (first maximum in row-major order) is mapped to the
    centre of that cell in an ``img_size`` image: ``coord * (img_size / hmap_size)
    + half a cell``.

    Args:
        preds: tensor of shape [B, K, H, W] (any device).
        hmap_size: heatmap side length used to compute the scale factor.
        img_size: target image side length.

    Returns:
        Nested list [B][K] of ``[x, y]`` float pairs.
    """
    _, _, _, width = preds.shape
    step = img_size / hmap_size
    half = step / 2
    # [B, K] flat index of each channel's maximum
    peak_idx = preds.cpu().flatten(2).argmax(dim=2)
    return [
        [[(pos % width) * step + half, (pos // width) * step + half] for pos in row]
        for row in peak_idx.tolist()
    ]

# Export every predicted keypoint (all K per image) of the validation split to JSON.
# val_loader wraps a Subset of KeypointHeatmapDataset: Subset.indices maps a position in
# the split to the dataset index, whose basename becomes the JSON key.
subset = val_loader.dataset
orig_ds = subset.dataset
indices = subset.indices
batch_size = val_loader.batch_size

# The position bookkeeping below is only valid if samples come out in Subset order.
assert isinstance(val_loader.sampler, torch.utils.data.SequentialSampler), \
    "val_loader must iterate in order (shuffle=False)"

# Inference mode: BatchNorm must use its running statistics. Set it explicitly instead
# of relying on the mode some earlier cell left the model in.
model.eval()

predictions = {}
seen = 0  # samples processed so far == subset position of the next sample
with torch.no_grad():
    for imgs, _ in val_loader:
        preds = model(imgs.to(device)).cpu()
        for pos_in_batch, pts in enumerate(get_all_keypoints(preds)):
            img_id = orig_ds.basenames[indices[seen + pos_in_batch]]
            predictions[img_id] = pts
        seen += imgs.size(0)

out_dir = "/content/drive/MyDrive/MLDL/6D-Pose-Estimation/data/predicted_key_points"
os.makedirs(out_dir, exist_ok=True)
out_path = os.path.join(out_dir, "2D_predicted_resnet18_keypoints50_fps_focalloss.json")

with open(out_path, "w") as f:
    json.dump(predictions, f, indent=2, sort_keys=True)

print(f"Saved {len(predictions)} entries to {out_path}")

Saved 2843 entries to /content/drive/MyDrive/MLDL/6D-Pose-Estimation/data/predicted_key_points/2D_predicted_resnet18_keypoints50_fps_focalloss.json


# Junk methods

## Choose the best 4 key points

In [None]:
def select_topk_keypoints_from_heatmaps(preds, topk=4, hmap_size=64, img_size=256):
    """Keep the ``topk`` channels with the highest peak value for every image.

    Args:
        preds: tensor [B, K, H, W] with H == W == ``hmap_size``.
        topk: number of channels to keep per image.
        hmap_size: heatmap side length (asserted against the input).
        img_size: image side length the coordinates are scaled to.

    Returns:
        One dict per image with 'indices' (selected channel ids), 'scores' (their peak
        values) and 'coords2d' ((x, y) cell centres in image pixels), ordered by score.
    """
    _, num_kp, h, w = preds.shape
    assert h == w == hmap_size
    step = img_size / hmap_size
    half = step / 2

    out = []
    for sample in preds.cpu():  # iterate over the batch dimension
        peak_vals = sample.view(num_kp, -1).max(dim=1).values
        best_vals, best_idx = torch.topk(peak_vals, topk, largest=True)

        coords = []
        for ch in best_idx:
            flat_pos = sample[ch].view(-1).argmax().item()
            y, x = divmod(flat_pos, w)
            coords.append((x * step + half, y * step + half))

        out.append({
            'indices': best_idx.tolist(),
            'scores':  best_vals.tolist(),
            'coords2d': coords
        })
    return out

In [None]:
import torch
import json

def select_topk_keypoints_from_heatmaps(preds, topk=4, hmap_size=64, img_size=256):
    """Pick the ``topk`` most confident keypoints of every image in a batch.

    NOTE(review): a function with this name is also defined in an earlier cell;
    whichever cell ran last is the one in effect.

    Args:
        preds: tensor [B, K, H, W] with H == W == ``hmap_size``.
        topk: number of keypoints to keep per image.
        hmap_size: heatmap side length (asserted against the input).
        img_size: image side length the coordinates are scaled to.

    Returns:
        List with one dict per image: 'indices' (channel ids sorted by peak value),
        'scores' (those peak values) and 'coords2d' ((x, y) in image pixels).
    """
    n_batch, n_kp, height, width = preds.shape
    assert height == width == hmap_size
    px_per_cell = img_size / hmap_size
    half_cell = px_per_cell / 2

    def _peak_xy(channel_map):
        # flat argmax -> (row, col) -> centre of that cell in image pixels
        row, col = divmod(channel_map.view(-1).argmax().item(), width)
        return (col * px_per_cell + half_cell, row * px_per_cell + half_cell)

    batch_cpu = preds.cpu()
    results = []
    for b in range(n_batch):
        maps = batch_cpu[b]
        peak_per_channel = maps.view(n_kp, -1).max(dim=1).values
        top_scores, top_channels = torch.topk(peak_per_channel, topk, largest=True)
        results.append({
            'indices': top_channels.tolist(),
            'scores':  top_scores.tolist(),
            'coords2d': [_peak_xy(maps[ch]) for ch in top_channels]
        })
    return results

# ——————————————————————————————————————————————
# Load trained weights, predict on the validation split and keep the 4 most confident
# keypoints of every image.
# NOTE(review): "best_model.pth" is resolved relative to the working directory. The
# training loop above saves "best_fps_model_50_keypoints_focal_hm.pth", and a later cell
# saves ".../6D-Pose-Estimation/best_model.pth" on Drive — confirm which file is meant.
ckpt_state = torch.load("best_model.pth", map_location=device)
model.load_state_dict(ckpt_state)
model.to(device)
model.eval()

predicted = {}

# val_loader.dataset is a Subset of KeypointHeatmapDataset; Subset.indices maps a
# position in the split to the original dataset index (and hence to its basename).
subset: torch.utils.data.Subset = val_loader.dataset
orig_ds        = subset.dataset
subset_indices = subset.indices
batch_size     = val_loader.batch_size

for batch_idx, (imgs, _) in enumerate(val_loader):
    imgs = imgs.to(device)
    with torch.no_grad():
        preds = model(imgs).cpu()

    for i, item in enumerate(select_topk_keypoints_from_heatmaps(preds, topk=4)):
        # subset position -> original index -> basename used as the JSON key
        img_id = orig_ds.basenames[subset_indices[batch_idx * batch_size + i]]
        predicted[img_id] = item["coords2d"]

out_path = "/content/drive/MyDrive/MLDL/6D-Pose-Estimation/data/predicted_key_points/2D_predicted_key_points_fps.json"
with open(out_path, "w") as f:
    json.dump(predicted, f, indent=2)

print(f"Saved {len(predicted)} entries to {out_path}")


In [None]:
def keypoints_from_heatmaps(preds, hmap_size=64, img_size=256):
    """Return every channel's peak location and value for each image in a batch.

    Args:
        preds: tensor [B, K, H, W] with H == W == ``hmap_size``.
        hmap_size: heatmap side length (asserted against the input).
        img_size: image side length the coordinates are scaled to.

    Returns:
        List with one dict per image: 'indices' (0..K-1), 'scores' (peak value of
        each channel) and 'coords2d' ((x, y) of each peak's cell centre in image pixels).
    """
    _, n_kp, height, width = preds.shape
    assert height == width == hmap_size
    step = img_size / hmap_size
    half = step / 2

    out = []
    for sample in preds.cpu():  # one [K, H, W] stack per image
        points, scores = [], []
        for channel in sample:
            flat = channel.view(-1)
            pos = flat.argmax().item()
            scores.append(flat[pos].item())
            row, col = divmod(pos, width)
            points.append((col * step + half, row * step + half))

        out.append({
            'indices': list(range(n_kp)),
            'scores':  scores,
            'coords2d': points
        })
    return out


In [None]:
# ——————————————————————————————————————————————
# Rebuild the network and load trained weights.
# NOTE(review): `checkpoint_path` is only assigned in cells further down ("After
# training"); run one of them first, otherwise this raises NameError. Those cells
# point to a "cps" checkpoint, while the output file below is named "..._fps_50.json"
# — confirm the pairing.
model = KeypointHeatmapNet(num_keypoints=50).to(device)
loaded_state = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(loaded_state)
model.eval()

predicted = {}

# val_loader.dataset is a Subset of KeypointHeatmapDataset; Subset.indices maps a
# position in the split to the original dataset index (and hence to its basename).
subset: torch.utils.data.Subset = val_loader.dataset
orig_ds        = subset.dataset
subset_indices = subset.indices
batch_size     = val_loader.batch_size

# Collect ALL K keypoints (not only the top 4) for every validation image.
for batch_idx, (imgs, _) in enumerate(val_loader):
    imgs = imgs.to(device)
    with torch.no_grad():
        preds = model(imgs).cpu()

    for i, item in enumerate(keypoints_from_heatmaps(preds)):
        # subset position -> original index -> basename used as the JSON key
        img_id = orig_ds.basenames[subset_indices[batch_idx * batch_size + i]]
        # TODO: also store the per-keypoint confidences (item["scores"]) here.
        predicted[img_id] = item["coords2d"]

out_path = "/content/drive/MyDrive/MLDL/6D-Pose-Estimation/data/predicted_key_points/2D_predicted_key_points_fps_50.json"
with open(out_path, "w") as f:
    json.dump(predicted, f, indent=2)

print(f"Saved {len(predicted)} entries to {out_path}")


Saved 2843 entries to /content/drive/MyDrive/MLDL/6D-Pose-Estimation/data/predicted_key_points/2D_predicted_key_points_fps_50.json


In [None]:
# Save the weights currently held by `model` to Drive.
SAVE_PATH = "/content/drive/MyDrive/MLDL/6D-Pose-Estimation/best_model.pth"
current_weights = model.state_dict()
torch.save(current_weights, SAVE_PATH)


# After training

In [None]:
# Checkpoint to evaluate (judging by the file name: ResNet-18, 50 keypoints, focal heatmap loss).
checkpoint_path = os.path.join(
    "/content/drive/MyDrive/MLDL/6D-Pose-Estimation/models/resnet",
    "resnet18_cps_model_50_keypoints_focalhm.pth",
)


In [None]:
# Rebuild the network on the available device and restore trained weights for inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint_path = "/content/drive/MyDrive/MLDL/6D-Pose-Estimation/models/resnet/resnet18_cps_model_50_keypoints_focalhm.pth"
model = KeypointHeatmapNet(num_keypoints=50).to(device)
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict)
model.eval()