We will now define the dataset class for our image dataset, which consists of images with keypoint annotations. When computing heatmaps (one for each keypoint present in an image), we will also make heatmaps for the lines present in each image's annotation JSON.

In [None]:
from torch.utils.data.dataset import Dataset
from typing import List, Tuple, Callable, Optional
from pathlib import Path

import torch
import numpy as np
import cv2
import os
import json
import matplotlib.pyplot as plt
import hydra
import sys

In [None]:
from omegaconf import OmegaConf, DictConfig
# Load the top-level training configuration (path is relative to the working directory).
cfg = OmegaConf.load('train_config.yaml')
# Load the separate HRNet model configuration file.
hrnet_cfg = OmegaConf.load('model_config/hrnet_w48.yaml')
# Attach the HRNet settings to the model section of the training config.
cfg.model.params.nn_module.hrnet_config = hrnet_cfg

In [None]:
# Put the current working directory at the front of the import path (only once)
# so that the `metamodel` import below can resolve — presumably a project-local
# module; confirm.
if str(Path.cwd()) not in sys.path: sys.path.insert(0, str(Path.cwd()))
import metamodel

In [None]:
class HRNetDataset(Dataset):
    """Dataset of ``.jpg`` images paired with same-named ``.json`` annotations.

    Each annotation file is expected to hold a ``shapes`` list whose entries
    carry ``shape_type``, ``label`` and ``points`` (this looks like the LabelMe
    JSON layout — confirm). ``point`` shapes become keypoints indexed by their
    integer label; ``linestrip`` shapes are returned as raw line annotations.

    Args:
        dataset_folder: Directory that is searched (non-recursively) for
            ``*.jpg`` files. Images without a sibling ``.json`` are skipped.
        transform: Optional callable applied to the finished sample dict.
        num_keypoints: Number of keypoint slots in the flat keypoint array.
        img_size: Target ``(width, height)`` every image is resized to.
        margin: Stored on the instance; not used by the code in this class.
        max_images: Upper bound on the number of images kept. The default of
            ``1`` preserves the previous hard-coded behaviour (a single image);
            pass ``None`` to use every annotated image in the folder.
    """

    def __init__(
        self,
        dataset_folder: str,
        transform: Optional[Callable] = None,
        num_keypoints: int = 30,
        img_size: Tuple[int, int] = (960, 540),
        margin: float = 0.0,
        max_images: Optional[int] = 1,
    ):
        super().__init__()
        self.dataset_folder = dataset_folder
        self.num_keypoints = num_keypoints
        self.transform = transform
        self.img_size = img_size
        self.margin = margin
        # Sort so that indices (and the subset selected by ``max_images``) are
        # reproducible; glob order is otherwise filesystem-dependent.
        annotated = sorted(
            p for p in Path(dataset_folder).glob('*.jpg')
            if p.with_suffix('.json').exists()
        )
        self.img_paths = annotated if max_images is None else annotated[:max_images]

    def __len__(self):
        """Return the number of annotated images kept by the dataset."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Load, resize and (optionally) transform the sample at ``idx``.

        Returns:
            dict with keys ``image``, ``keypoints``, ``img_idx``, ``mask``,
            ``img_name`` and ``lines`` (or whatever ``transform`` returns when
            a transform is configured).

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        img_path = self.img_paths[idx]
        image = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f'Could not read image: {img_path}')
        annot_path = img_path.with_suffix('.json')
        keypoints, mask, lines = self._annot2keypoints(annot_path)
        image, keypoints = self._resize_img_and_kpts(image, keypoints)
        sample = dict(image=image, keypoints=keypoints, img_idx=idx, mask=mask,
                      img_name=img_path.name, lines=lines)
        # The transform operates on the sample dict, so it can only run once
        # the dict exists (it was previously invoked before ``sample`` was set).
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def _annot2keypoints(self, annot_path):
        """Parse an annotation file into keypoints, a mask and line shapes.

        Returns:
            keypoints: float32 array of length ``num_keypoints * 3`` laid out
                as ``(x, y, flag)`` triplets. Annotated keypoints get their
                coordinates and flag 1; missing ones stay at ``(-1, -1)`` with
                flag 0.
            mask: int array of length ``num_keypoints``; 0 where a keypoint is
                annotated, 1 where it is missing.
            lines: list of ``{'label', 'points'}`` dicts, one per linestrip.
        """
        with open(annot_path) as f:
            data = json.load(f)
        kpts_dict, lines = {}, []
        for shape in data['shapes']:
            if shape['shape_type'] == 'point':
                kpts_dict[int(shape['label'])] = shape['points'][0]
            elif shape['shape_type'] == 'linestrip':
                lines.append(dict(label=shape['label'], points=shape['points']))
        keypoints = np.full(self.num_keypoints * 3, -1, dtype=np.float32)
        mask = np.ones(self.num_keypoints, dtype=int)
        # Labels outside [0, num_keypoints) are ignored because only the
        # configured slots are visited.
        for i in range(self.num_keypoints):
            if i in kpts_dict:
                keypoints[i*3:i*3+2] = kpts_dict[i]
                keypoints[i*3+2] = 1
                mask[i] = 0
            else:
                keypoints[i*3+2] = 0
        return keypoints, mask, lines

    def _resize_img_and_kpts(self, image, keypoints):
        """Resize ``image`` to ``img_size`` and rescale annotated keypoints.

        Only keypoints whose flag is positive are rescaled. ``keypoints`` is
        modified in place and also returned alongside the resized image.
        """
        h, w = image.shape[:2]
        tw, th = self.img_size
        scale_h, scale_w = th/h, tw/w
        resized_img = cv2.resize(image, (tw, th))
        for i in range(self.num_keypoints):
            if keypoints[i*3+2] > 0:
                keypoints[i*3] = keypoints[i*3]*scale_w
                keypoints[i*3+1] = keypoints[i*3+1]*scale_h
        return resized_img, keypoints


In [None]:
def gaussian(x: torch.Tensor, mu: torch.Tensor, sigma: float) -> torch.Tensor:
    """Evaluate unnormalised 1D Gaussians (peak value 1.0) on a grid.

    Args:
        x (torch.Tensor): Sample positions along the axis, shape (X,).
        mu (torch.Tensor): Gaussian centres, shape (B, N).
        sigma (float): Standard deviation, in the same units as ``x``.

    Returns:
        torch.Tensor: One Gaussian per centre, shape (B, N, X).
    """
    # A trailing singleton axis on the centres lets them broadcast against x.
    centres = mu.unsqueeze(-1)
    standardised = (x - centres) / sigma
    exponent = -(standardised ** 2) / 2.0
    return torch.exp(exponent)

def create_heatmaps(keypoints: torch.Tensor, sigma: float, pred_size: Tuple[int, int] = (540, 960)) -> torch.Tensor:
    """Render one 2D Gaussian heatmap per keypoint.

    Args:
        keypoints (torch.Tensor): Indexed as ``[:, :, 0..2]`` for x, y and a
            visibility flag, i.e. shape (B, N, 3).
        sigma (float): Standard deviation of the Gaussians, in pixels.
        pred_size (Tuple[int, int]): Heatmap size as (height, width).

    Returns:
        torch.Tensor: Heatmaps of shape (B, N, H, W); maps of keypoints whose
        visibility flag is not positive are zero everywhere.
    """
    height, width = pred_size
    device = keypoints.device

    xs = keypoints[:, :, 0]
    ys = keypoints[:, :, 1]
    visibility = keypoints[:, :, 2]

    cols = torch.arange(0, width, device=device, dtype=torch.float32)
    rows = torch.arange(0, height, device=device, dtype=torch.float32)

    # Separable construction: a 2D Gaussian is the outer product of two 1D ones.
    profile_x = gaussian(cols, xs, sigma)
    profile_y = gaussian(rows, ys, sigma)
    heatmaps = torch.einsum("BNW, BNH -> BNHW", profile_x, profile_y)

    # Expand the (B, N) visibility flags to (B, N, 1, 1) so they broadcast
    # over the spatial axes, then blank out non-visible keypoints.
    visible = (visibility > 0).unsqueeze(-1).unsqueeze(-1)
    zero = torch.tensor(0.0, device=device)
    return torch.where(visible, heatmaps, zero)

Why do you need to unsqueeze vis two times? Walk me through the shapes step by step.

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->

Let's trace through the shapes step by step:

Starting shapes:
- `keypoints`: (B, N, 3) where B=batch size, N=num_keypoints
- `x, y, vis`: each is (B, N) after slicing
- `x_range`: (W,) and `y_range`: (H,)

After gaussian function:
- `gauss_x`: (B, N, W) - Gaussian values for each keypoint along width
- `gauss_y`: (B, N, H) - Gaussian values for each keypoint along height

After einsum:
- `heatmaps`: (B, N, H, W) - 2D Gaussian heatmap for each keypoint

Now for the masking with `vis`:
- `vis`: (B, N) - visibility flag for each keypoint
- Need to broadcast `vis` to match `heatmaps` shape (B, N, H, W)
- First `unsqueeze(-1)`: (B, N, 1) - adds dimension for H
- Second `unsqueeze(-1)`: (B, N, 1, 1) - adds dimension for W
- Now `(vis > 0).unsqueeze(-1).unsqueeze(-1)` broadcasts correctly to (B, N, H, W)

The two unsqueezes are needed because `heatmaps` has 4 dimensions and `vis` only has 2, so we need to add 2 singleton dimensions to enable broadcasting across both the H and W dimensions.

In [None]:
from torch.utils.data import DataLoader, ConcatDataset, default_collate
from typing import List, Optional, Callable
from transforms import *

# Sample keys that are stacked into batch tensors by the default collate.
collate_objs = ['keypoints', 'image', 'img_idx', 'mask']


def custom_collate(batch):
    """Collate a list of sample dicts into one batch dict.

    Keys listed in ``collate_objs`` go through ``default_collate``; image
    names are gathered into a plain list. Any other key (for example
    ``lines``) is left out of the batch.

    Args:
        batch: List of sample dicts, each containing at least ``img_name``.

    Returns:
        dict with the default-collated entries plus ``img_name``.
    """
    tensor_fields = []
    for sample in batch:
        tensor_fields.append(
            {key: value for key, value in sample.items() if key in collate_objs}
        )
    collated = dict(default_collate(tensor_fields))
    collated['img_name'] = [sample['img_name'] for sample in batch]
    return collated

def train_transform(brightness: Tuple[float, float] = (0.8, 1.2),
                    color: Tuple[float, float] = (0.8, 1.2),
                    contrast: Tuple[float, float] = (0.8, 1.2),
                    gauss_noise_sigma: float = 30.0,
                    prob: float = 0.5):
    """Build the training-time augmentation pipeline.

    The pipeline applies, in order: a colour augmentation and Gaussian noise
    (each wrapped in ``UseWithProb`` with probability ``prob``), followed by
    ``ToTensor``.

    Args:
        brightness: Passed to ``ColorAugment`` as ``brightness``.
        color: Passed to ``ColorAugment`` as ``color``.
        contrast: Passed to ``ColorAugment`` as ``contrast``.
        gauss_noise_sigma: Passed to ``GaussNoise``.
        prob: Probability handed to both ``UseWithProb`` wrappers.

    Returns:
        The composed transform (a ``ComposeTransform`` instance).
    """
    color_step = UseWithProb(
        ColorAugment(brightness=brightness, color=color, contrast=contrast),
        prob,
    )
    noise_step = UseWithProb(GaussNoise(gauss_noise_sigma), prob)
    # Horizontal flipping is currently disabled:
    # UseWithProb(Flip(), 0.5),
    steps = [color_step, noise_step, ToTensor()]
    return ComposeTransform(steps)
    
def get_loader(dataset_paths: List[str], data_params: DictConfig,
               transform: Optional[Callable] = None,
               shuffle: bool = True) -> DataLoader:
    """Create a DataLoader over the concatenation of several dataset folders.

    Args:
        dataset_paths: Folders, each wrapped in its own ``HRNetDataset``.
        data_params: Config providing ``num_keypoints``, ``margin``,
            ``batch_size``, ``num_workers`` and ``pin_memory``.
        transform: Optional sample transform forwarded to every dataset.
        shuffle: Whether the loader shuffles. When False, the batch size is
            doubled.

    Returns:
        DataLoader that batches samples with ``custom_collate``.
    """
    per_folder = [
        HRNetDataset(folder, transform=transform,
                     num_keypoints=data_params.num_keypoints,
                     margin=data_params.margin)
        for folder in dataset_paths
    ]
    combined = ConcatDataset(per_folder)
    # Non-shuffled loaders use twice the configured batch size.
    batch_multiplier = 1 if shuffle else 2
    return DataLoader(
        combined,
        batch_size=data_params.batch_size * batch_multiplier,
        num_workers=data_params.num_workers,
        pin_memory=data_params.pin_memory,
        shuffle=shuffle,
        collate_fn=custom_collate,
    )

In [None]:
def plot_heatmap_on_img(img_tensor, heatmap_tensor):
    """Overlay a colour-mapped heatmap on an image and display it.

    Args:
        img_tensor: Image tensor; indexed as (height, width, ...) and cast to
            uint8 before blending.
        heatmap_tensor: 2D heatmap tensor with values nominally in [0, 1].
            Values outside that range are clipped before colour mapping.

    Returns:
        None. The overlay is shown with matplotlib.
    """
    img = img_tensor.detach().cpu().numpy()
    heatmap = heatmap_tensor.detach().cpu().numpy()
    # cv2.resize takes the target size as (width, height).
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    # Summed per-keypoint heatmaps can exceed 1.0 where Gaussians overlap;
    # without clipping, the uint8 cast below would wrap around and produce
    # misleading colours.
    heatmap = np.clip(heatmap, 0.0, 1.0)
    heatmap_colored = cv2.applyColorMap((heatmap * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heatmap_colored = cv2.cvtColor(heatmap_colored, cv2.COLOR_BGR2RGB)
    # NOTE(review): the heatmap is converted to RGB here, but the image is
    # blended as-is. If the image comes straight from cv2.imread it is BGR and
    # will be shown with red/blue swapped — confirm the expected channel order.
    overlay = cv2.addWeighted(img.astype(np.uint8), 0.6, heatmap_colored, 0.4, 0)
    plt.imshow(overlay)
    plt.show()

In [None]:
# Build a loader without augmentations and visualise the target heatmaps
# on top of every image it yields.
train_loader = get_loader(cfg.data.train, cfg.data_params, None, True)
dl = iter(train_loader)

for batch in dl:
    # Iterate over the samples actually present: a batch can be smaller than
    # the configured batch size (last batch, or a dataset with few images),
    # and indexing past its end would raise IndexError.
    for idx in range(len(batch['img_name'])):
        # Reshape the flat (num_keypoints * 3,) vector to (1, num_keypoints, 3)
        # so it matches the batched layout create_heatmaps indexes.
        img, keypoints, mask = batch['image'][idx], batch['keypoints'][idx].reshape(-1, cfg.data_params.num_keypoints, 3), batch['mask'][idx]
        heatmaps = create_heatmaps(keypoints, 2)
        # Append an extra channel: 1 minus the per-pixel maximum over keypoints.
        heatmaps = torch.cat(
                [heatmaps, (1.0 - torch.max(heatmaps, dim=1, keepdim=True)[0])], 1)
        # Sum the keypoint channels only (the appended channel is dropped).
        maps = torch.sum(heatmaps[0][:-1], 0)
        plot_heatmap_on_img(img, maps)


In [None]:
# Instantiate the model described by the `model` section of the config.
model = hydra.utils.instantiate(cfg.model)
# Build the training augmentation pipeline from the configured parameters.
aug_params = cfg.data_params.augmentations
train_trns = train_transform(
    brightness=aug_params.brightness,
    color=aug_params.color,
    contrast=aug_params.contrast,
    gauss_noise_sigma=aug_params.gauss_noise_sigma,
    prob=aug_params.prob
)
# NOTE(review): test_transform is not defined in the visible cells —
# presumably it comes from `from transforms import *`; confirm. val_trns is
# only consumed by the commented-out validation loader below.
val_trns = test_transform()
# Training loader with augmentations and shuffling enabled.
train_loader = get_loader(cfg.data.train, cfg.data_params,
                            train_trns, True)
# val_loader = get_loader(cfg.data.val, cfg.data_params, val_trns, False)
# experiment_name = cfg.metadata.experiment_name
# run_name = cfg.metadata.run_name
# save_dir = f'./experiments/{experiment_name}_{run_name}'
# callbacks = [
#     Checkpoint(save_dir, max_saves=3, file_format='save-{epoch:03d}.pth',
#                 save_after_exception=True, optimizer_state=True, period=2),
#     LoggingToFile(os.path.join(save_dir, 'log.txt')),
# ]

# pretrain_path = cfg.model.params.pretrain
# if pretrain_path is not None:
#     if os.path.exists(pretrain_path):
#         model_pretrain = load_model(pretrain_path,
#                                     device=cfg.model.params.device)
#         if cfg.train_params.load_compatible:
#             model_pretrain = load_model(pretrain_path,
#                                         device=cfg.model.params.device)
#             model = load_compatible_weights(model_pretrain, model)
#         else:
#             model = load_model(pretrain_path,
#                                 device=cfg.model.params.device)
#         model.set_lr(cfg.model.params.optimizer.lr)
#     else:
#         raise ValueError(f'Pretrain {pretrain_path} does not exist')
# # Model may need tuning to find the optimal one for the particular model
# if cfg.train_params.use_compile:
#     model.nn_module = compile(model.nn_module)
# model.fit(train_loader, val_loader=val_loader, metrics_on_train=False,
#             num_epochs=cfg.train_params.max_epochs,
#             callbacks=callbacks)

  self.scaler = torch.cuda.amp.GradScaler() if self.amp else None
  super().__init__(
