In [1]:
import os
from typing import Optional, Tuple

import torch
import torchvision
from easydict import EasyDict as edict
from PIL import Image
from torch.utils.data import DataLoader, Dataset

from hybrik.models import builder
from hybrik.utils.config import update_config
from hybrik.utils.presets import SimpleTransform3DSMPLCam
import numpy as np
import cv2


class hbryik_dataset(Dataset):
    """
    Takes in a folder and returns the images and their filenames
    within that folder. No subdirectories with images are returned
    """
    def __init__(self, dir: str, transform: SimpleTransform3DSMPLCam, dim: Tuple[int, int] = (256, 256)) -> None:
        """
        The transform expects an RGB image and a bounding box. The bounding box
        is supposed to be around the person in the image but our images will already
        be cropped. Therefore, we will not use a bounding box over the entire image.
        
        Dim: (width, height)
        
        """
        self.dir = dir
        self.transform = transform
        self.images = [file for file in os.listdir(dir) if file.endswith(('jpg', 'jpeg', 'png'))]
        self.dim = dim
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        """Returns the image, the filename, the bbox and the center of the image"""
        print(os.path.join(self.dir, self.images[idx]))
        image = cv2.cvtColor(cv2.imread(os.path.join(self.dir, self.images[idx])), cv2.COLOR_BGR2RGB)
        # resize the image
        image = cv2.resize(image, self.dim)
        bbox = [0, 0, image.shape[0], image.shape[1]]
        # come back to this if failing but using our own bounding box for now
        image, _bbox, img_center = self.transform.test_transform(image, bbox)
        return image, self.images[idx], np.array(bbox), np.array(img_center)
        


def create_dataloader(
    data_dir: str,
    batch_size: int = 8,
    shuffle: bool = False,
    num_workers: int = 4,
    transform: Optional[torchvision.transforms.Compose] = None,
) -> DataLoader:
    """
    Creates a PyTorch DataLoader for images stored in subdirectories.

    Parameters:
    - data_dir (str): Path to the main directory containing subdirectories of images.
    - batch_size (int): Number of samples per batch to load.
    - shuffle (bool): Whether to shuffle the dataset.
    - num_workers (int): How many subprocesses to use for data loading.
    - transform (torchvision.transforms.Compose): Transformations to apply to the images.

    Returns:
    - DataLoader: PyTorch DataLoader.
    """

    # Load the dataset from the directory with subdirectories
    dataset = hbryik_dataset(dir=data_dir, transform=transform)

    # Create the DataLoader
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )

    return dataloader


cfg_file = 'configs/256x192_adam_lr1e-3-hrw48_cam_2x_w_pw3d_3dhp.yaml'
CKPT = './pretrained_models/hybrik_hrnet.pth'
cfg = update_config(cfg_file)

bbox_3d_shape = getattr(cfg.MODEL, 'BBOX_3D_SHAPE', (2000, 2000, 2000))
bbox_3d_shape = [item * 1e-3 for item in bbox_3d_shape]
dummpy_set = edict({
    'joint_pairs_17': None,
    'joint_pairs_24': None,
    'joint_pairs_29': None,
    'bbox_3d_shape': bbox_3d_shape
})


transformation = SimpleTransform3DSMPLCam(
    dummpy_set, scale_factor=cfg.DATASET.SCALE_FACTOR,
    color_factor=cfg.DATASET.COLOR_FACTOR,
    occlusion=cfg.DATASET.OCCLUSION,
    input_size=cfg.MODEL.IMAGE_SIZE,
    output_size=cfg.MODEL.HEATMAP_SIZE,
    depth_dim=cfg.MODEL.EXTRA.DEPTH_DIM,
    bbox_3d_shape=bbox_3d_shape,
    rot=cfg.DATASET.ROT_FACTOR, sigma=cfg.MODEL.EXTRA.SIGMA,
    train=False, add_dpg=False,
    loss_type=cfg.LOSS['TYPE'])

data_dir = "examples"
dataloader = create_dataloader(data_dir, transform=transformation, num_workers=0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


hybrik_model = builder.build_sppe(cfg.MODEL)

print(f'Loading model from {CKPT}...')
save_dict = torch.load(CKPT, map_location='cpu')
if type(save_dict) == dict:
    model_dict = save_dict['model']
    hybrik_model.load_state_dict(model_dict)
else:
    hybrik_model.load_state_dict(save_dict)
    
    
hybrik_model.eval()

for i, (images, filename, bbox, img_center) in enumerate(dataloader):
    images = images.to(device)
    pose_output = hybrik_model(
        images, flip_test=True,
        bboxes=bbox.to(images.device).float(),
        img_center=img_center.to(images.device).float()
    )

  self.smpl_data = Struct(**pk.load(smpl_file, encoding='latin1'))
  save_dict = torch.load(CKPT, map_location='cpu')


Loading model from ./pretrained_models/hybrik_hrnet.pth...
examples/000000581328.jpg
examples/000000581667.jpg
examples/000000581357.jpg
examples/000000000431.jpg
examples/output_frame_001.png
examples/000000581091.jpg
examples/000000581056.jpg


In [24]:
img_center.shape

torch.Size([7, 2])

In [19]:
bbox

tensor([[  0,   0, 256, 256],
        [  0,   0, 256, 256],
        [  0,   0, 256, 256],
        [  0,   0, 256, 256],
        [  0,   0, 256, 256],
        [  0,   0, 256, 256],
        [  0,   0, 256, 256]])