We now define the dataset class for our image dataset, in which each image comes with keypoint annotations. When computing heatmaps (one for each keypoint present in an image), we will also make heatmaps for the lines present in each image's annotation JSON file.

In [None]:
from torch.utils.data.dataset import Dataset
from typing import List, Tuple, Callable, Optional
from pathlib import Path

import torch
import numpy as np
import cv2
import os
import json

In [None]:
from omegaconf import OmegaConf
from omegaconf import DictConfig
# Training configuration; later cells read `cfg.data.train` and
# `cfg.data_params.*` from it.
_config_file = 'train_config.yaml'
cfg = OmegaConf.load(_config_file)

In [None]:
class HRNetDataset(Dataset):
    """Dataset of images with keypoint and line annotations.

    Every ``<name>.jpg`` in ``dataset_folder`` that has a sibling
    ``<name>.json`` annotation file is one sample. The annotation's
    ``shapes`` list is read: entries with ``shape_type == 'point'`` are
    keypoints whose ``label`` is the integer keypoint index, and entries with
    ``shape_type == 'linestrip'`` are collected as lines.

    Args:
        dataset_folder: Directory containing the ``.jpg``/``.json`` pairs.
        transform: Optional callable applied to ``{'image': ndarray}``; it must
            return a dict. It runs before the keypoints are added to the
            sample, so it does not see or update keypoint coordinates.
        num_keypoints: Number of keypoint slots in the output vectors.
        img_size: Stored on the instance; not used inside this class.
        margin: Stored on the instance; not used inside this class.
    """

    def __init__(self, dataset_folder: str, transform: Optional[Callable] = None, num_keypoints: int = 30, img_size: Tuple[int, int] = (960, 540), margin: float = 0.0):
        super().__init__()
        self.dataset_folder = dataset_folder
        self.num_keypoints = num_keypoints
        self.transform = transform
        self.img_size = img_size
        self.margin = margin
        # Sorted so a given index maps to the same file on every run/machine;
        # raw glob order depends on the filesystem.
        self.img_paths = sorted(
            p for p in Path(dataset_folder).glob('*.jpg')
            if p.with_suffix('.json').exists()
        )

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, idx: int) -> dict:
        """Load image ``idx`` and its annotation.

        Returns:
            The (optionally transformed) sample dict, extended with
            ``keypoints``, ``img_idx``, ``mask``, ``img_name`` and ``lines``
            (see ``_annot2keypoints`` for their encoding).

        Raises:
            IOError: If OpenCV cannot read the image file.
        """
        img_path = self.img_paths[idx]
        image = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError(f'Could not read image: {img_path}')
        sample = dict(image=image)
        if self.transform:
            sample = self.transform(sample)
        keypoints, mask, lines = self._annot2keypoints(img_path.with_suffix('.json'))
        sample.update(dict(keypoints=keypoints, img_idx=idx, mask=mask,
                           img_name=img_path.name, lines=lines))
        return sample

    def _annot2keypoints(self, annot_path):
        """Parse one annotation JSON file.

        Returns:
            keypoints: float32 array of length ``3 * num_keypoints`` laid out
                as ``[x0, y0, v0, x1, y1, v1, ...]``. ``v`` is 1 for an
                annotated point and 0 otherwise; missing points keep
                ``x = y = -1``.
            mask: int array of length ``num_keypoints``; 1 marks a missing
                keypoint, 0 an annotated one.
            lines: list of ``{'label': ..., 'points': ...}`` dicts, one per
                ``linestrip`` shape.
        """
        with open(annot_path, encoding='utf-8') as f:
            data = json.load(f)
        kpts_dict, lines = {}, []
        for shape in data['shapes']:
            if shape['shape_type'] == 'point':
                kpts_dict[int(shape['label'])] = shape['points'][0]
            elif shape['shape_type'] == 'linestrip':
                lines.append(dict(label=shape['label'], points=shape['points']))
        keypoints = np.full(self.num_keypoints * 3, -1, dtype=np.float32)
        mask = np.ones(self.num_keypoints, dtype=int)
        for i in range(self.num_keypoints):
            if i in kpts_dict:
                keypoints[i * 3:i * 3 + 2] = kpts_dict[i]
                keypoints[i * 3 + 2] = 1
                mask[i] = 0
            else:
                keypoints[i * 3 + 2] = 0
        return keypoints, mask, lines

In [None]:
def create_heatmaps(keypoints: torch.Tensor, sigma: float,
                    pred_size: Tuple[int, int] = (540, 960)) -> torch.Tensor:
    """Create Gaussian distributions heatmaps for keypoints.

    Each heatmap is drawn on an individual channel.

    Args:
        keypoints (torch.Tensor): A batch (B) of N points. Expected shape is
            either (B, N, 2) with (x, y) per point, or (B, N, 3) with
            (x, y, visibility) per point. With a visibility channel, only
            points whose visibility equals 1 get a heatmap. Without one,
            every point gets a heatmap.
        sigma (float): Standard deviation, passed through to ``gaussian``.
        pred_size (Tuple[int, int]): Size of the 2D Gaussian distribution canvas
            (H, W). Defaults to (540, 960).

    Returns:
        (torch.Tensor): Resulting Gaussian heatmaps: (B, N, H, W). Channels of
        points that are not visible are all zeros.

    """
    h, w = pred_size
    device = keypoints.device
    x = keypoints[:, :, 0]
    y = keypoints[:, :, 1]

    x_range = torch.arange(0, w, device=device, dtype=torch.float32)
    y_range = torch.arange(0, h, device=device, dtype=torch.float32)
    # NOTE(review): `gaussian` is defined outside this section; assumed to
    # return (B, N, len(range)) profiles, as the einsum below requires.
    gauss_x: torch.Tensor = gaussian(x_range, x, sigma)
    gauss_y: torch.Tensor = gaussian(y_range, y, sigma)
    # Separable 2D Gaussian: outer product of the per-axis profiles.
    heatmaps = torch.einsum("BNW, BNH -> BNHW", gauss_x, gauss_y)

    # Only the visibility channel decides visibility; comparing x/y with 1
    # (as a blanket `keypoints == 1` would) is meaningless.
    if keypoints.shape[-1] >= 3:
        visible_points = keypoints[:, :, 2] == 1
    else:
        visible_points = torch.ones(keypoints.shape[:2], dtype=torch.bool,
                                    device=device)
    zero = torch.tensor(0.0, device=device, dtype=heatmaps.dtype)
    # Broadcast the (B, N) mask over the (H, W) canvas.
    heatmaps = torch.where(visible_points[:, :, None, None], heatmaps, zero)
    return heatmaps

In [None]:
from torch.utils.data import DataLoader, ConcatDataset
from torch.utils.data.dataloader import default_collate


def custom_collate(batch: List[dict]) -> dict:
    """Collate a list of ``HRNetDataset`` samples into one batch dict.

    Every field is collated with PyTorch's ``default_collate``: arrays are
    stacked into tensors, and strings become lists. The exception is any
    field whose per-sample value is a Python ``list``, e.g. ``lines``, whose
    length varies from image to image. ``default_collate`` requires equal
    lengths, so such fields are returned as a plain list with one entry per
    sample.
    """
    collated = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        if isinstance(values[0], list):
            collated[key] = values
        else:
            collated[key] = default_collate(values)
    return collated


def get_loader(dataset_paths: List[str], data_params: DictConfig,
               transform: Optional[Callable] = None, shuffle: bool = True)\
        -> DataLoader:
    """Build a DataLoader over one ``HRNetDataset`` per folder.

    Args:
        dataset_paths: Dataset folders to concatenate. A single path string
            is also accepted.
        data_params: Config node providing ``num_keypoints``, ``margin``,
            ``batch_size``, ``num_workers`` and ``pin_memory``.
        transform: Optional transform forwarded to every dataset.
        shuffle: Whether to shuffle. When False, the batch size is doubled.
            NOTE(review): presumably because non-shuffled loaders are used
            for evaluation without gradients; confirm.
    """
    if isinstance(dataset_paths, (str, os.PathLike)):
        # Iterating a bare string would create one "dataset" per character.
        dataset_paths = [dataset_paths]
    datasets = [
        HRNetDataset(dataset_path, transform=transform,
                     num_keypoints=data_params.num_keypoints,
                     margin=data_params.margin)
        for dataset_path in dataset_paths
    ]
    dataset = ConcatDataset(datasets)
    factor = 1 if shuffle else 2
    loader = DataLoader(
        dataset, batch_size=data_params.batch_size * factor,
        num_workers=data_params.num_workers,
        pin_memory=data_params.pin_memory,
        shuffle=shuffle,
        collate_fn=custom_collate)
    return loader

ImportError: cannot import name 'custom_collate' from 'torch.utils.data' (/usr/local/lib/python3.12/site-packages/torch/utils/data/__init__.py)

In [None]:
# Sanity-check the data pipeline: load one sample directly, then draw the
# summed keypoint heatmaps over the images of the training loader.
dataset = HRNetDataset(dataset_folder='./clean_keypoints_dataset/')
sample = dataset[0]
train_loader = get_loader(cfg.data.train, cfg.data_params, None, True)
dl = iter(train_loader)

batch = next(dl)
img, keypoints, mask = batch['image'][0], batch['keypoints'][0].reshape(-1, cfg.data_params.num_keypoints, 3), batch['mask'][0]
print(img.shape, keypoints.shape, mask.shape)

# Continues from the second batch; the first was consumed by next(dl) above.
for batch in dl:
    # The final batch can be smaller than batch_size, so size the loop from the data.
    for idx in range(len(batch['image'])):
        img = batch['image'][idx]
        # (1, num_keypoints, 3): a batch of one image with (x, y, visibility) rows.
        keypoints = batch['keypoints'][idx].reshape(-1, cfg.data_params.num_keypoints, 3)
        mask = batch['mask'][idx]
        heatmaps = create_heatmaps(keypoints, 1.0)
        # Sum the per-keypoint channels of the single image into one map.
        maps = torch.sum(heatmaps[0], 0)
        # NOTE(review): plot_img_keypoints is not defined in this section.
        plot_img_keypoints(img.detach().cpu().numpy(), maps)


NameError: name 'custom_collate' is not defined

6 self.img_paths = filter(sorted([p for p in Path(dataset_folder).glob('*.jpg')]), lambda x: os.path.exists(x.replace('.jpg', '.json')))

TypeError: 'function' object is not iterable

##### 🤖Reply🤖<!-- SOLVEIT_SEPARATOR_7f3a9b2c -->

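The `TypeError` comes from the argument order: `filter` expects `filter(function, iterable)`, but the call passes the list first and the lambda second. Even with the order fixed, two more problems remain. First, `filter` returns a lazy iterator rather than a list, so `len()` and indexing in `__getitem__` would fail. Second, `x.replace('.jpg', '.json')` is not a string replacement on a `Path` (`Path.replace` renames files). Building a list with `Path.with_suffix` avoids all three problems. Here's the fix: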

```python
class HRNetDataset(Dataset):
    """Dataset of images with keypoint and line annotations.

    Every ``<name>.jpg`` in ``dataset_folder`` that has a sibling
    ``<name>.json`` annotation file is one sample. The annotation's
    ``shapes`` list is read: entries with ``shape_type == 'point'`` are
    keypoints whose ``label`` is the integer keypoint index, and entries with
    ``shape_type == 'linestrip'`` are collected as lines.

    Args:
        dataset_folder: Directory containing the ``.jpg``/``.json`` pairs.
        transform: Optional callable applied to ``{'image': ndarray}``; it must
            return a dict. It runs before the keypoints are added to the
            sample, so it does not see or update keypoint coordinates.
        num_keypoints: Number of keypoint slots in the output vectors.
        img_size: Stored on the instance; not used inside this class.
        margin: Stored on the instance; not used inside this class.
    """

    def __init__(self, dataset_folder: str, transform: Optional[Callable] = None, num_keypoints: int = 30, img_size: Tuple[int, int] = (960, 540), margin: float = 0.0):
        super().__init__()
        self.dataset_folder = dataset_folder
        self.num_keypoints = num_keypoints
        self.transform = transform
        self.img_size = img_size
        self.margin = margin
        # Sorted so a given index maps to the same file on every run/machine;
        # raw glob order depends on the filesystem.
        self.img_paths = sorted(
            p for p in Path(dataset_folder).glob('*.jpg')
            if p.with_suffix('.json').exists()
        )

    def __len__(self) -> int:
        return len(self.img_paths)

    def __getitem__(self, idx: int) -> dict:
        """Load image ``idx`` and its annotation.

        Returns:
            The (optionally transformed) sample dict, extended with
            ``keypoints``, ``img_idx``, ``mask``, ``img_name`` and ``lines``
            (see ``_annot2keypoints`` for their encoding).

        Raises:
            IOError: If OpenCV cannot read the image file.
        """
        img_path = self.img_paths[idx]
        image = cv2.imread(str(img_path), cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError(f'Could not read image: {img_path}')
        sample = dict(image=image)
        if self.transform:
            sample = self.transform(sample)
        keypoints, mask, lines = self._annot2keypoints(img_path.with_suffix('.json'))
        sample.update(dict(keypoints=keypoints, img_idx=idx, mask=mask,
                           img_name=img_path.name, lines=lines))
        return sample

    def _annot2keypoints(self, annot_path):
        """Parse one annotation JSON file.

        Returns:
            keypoints: float32 array of length ``3 * num_keypoints`` laid out
                as ``[x0, y0, v0, x1, y1, v1, ...]``. ``v`` is 1 for an
                annotated point and 0 otherwise; missing points keep
                ``x = y = -1``.
            mask: int array of length ``num_keypoints``; 1 marks a missing
                keypoint, 0 an annotated one.
            lines: list of ``{'label': ..., 'points': ...}`` dicts, one per
                ``linestrip`` shape.
        """
        with open(annot_path, encoding='utf-8') as f:
            data = json.load(f)
        kpts_dict, lines = {}, []
        for shape in data['shapes']:
            if shape['shape_type'] == 'point':
                kpts_dict[int(shape['label'])] = shape['points'][0]
            elif shape['shape_type'] == 'linestrip':
                lines.append(dict(label=shape['label'], points=shape['points']))
        keypoints = np.full(self.num_keypoints * 3, -1, dtype=np.float32)
        mask = np.ones(self.num_keypoints, dtype=int)
        for i in range(self.num_keypoints):
            if i in kpts_dict:
                keypoints[i * 3:i * 3 + 2] = kpts_dict[i]
                keypoints[i * 3 + 2] = 1
                mask[i] = 0
            else:
                keypoints[i * 3 + 2] = 0
        return keypoints, mask, lines
```