In [9]:
import cv2
import h5py
import numpy as np
from tqdm import tqdm

# Number of frames to extract from each split's video.
TRAIN_COUNT = 20_400
TEST_COUNT = 10_768

# Geometry of one decoded frame: height, width, channels.
FRAME_HEIGHT, FRAME_WIDTH, CHANNELS = 480, 640, 3
FRAME_SHAPE = (FRAME_HEIGHT, FRAME_WIDTH, CHANNELS)

def save_frames(
    filename, vid_path, data_len, lab_path=None
):
    """Decode up to ``data_len`` frames of a video into an HDF5 file.

    Frame ``i`` is stored as dataset ``'X<i>'`` with shape ``FRAME_SHAPE``.
    When ``lab_path`` is given, entry ``i`` of that text file (loaded with
    ``np.loadtxt``) is stored as dataset ``'y<i>'``.

    Args:
        filename: output HDF5 path; opened in mode ``'w'`` (truncated).
        vid_path: input video path, read with ``cv2.VideoCapture``.
        data_len: maximum number of frames to write; writing stops early if
            the video runs out of frames.
        lab_path: optional path to a label file for ``np.loadtxt``.

    Raises:
        IOError: if the video cannot be opened.
    """
    # Load labels before opening the output so a bad label file does not
    # leave a freshly truncated HDF5 file behind.
    lab = np.loadtxt(lab_path) if lab_path is not None else None

    cap = cv2.VideoCapture(vid_path)
    if not cap.isOpened():
        raise IOError(f'could not open video: {vid_path}')

    try:
        with h5py.File(filename, 'w') as hf:
            # Context manager guarantees the bar is closed and its final
            # count is rendered.
            with tqdm(total=data_len, position=0, leave=True) as pbar:
                # Count frames locally instead of querying the capture
                # position; stop on the first failed read rather than
                # writing a `None` frame.
                for frame_id in range(data_len):
                    ret, frame = cap.read()
                    if not ret:
                        break

                    # save frames
                    hf.create_dataset(
                        name='X'+str(frame_id),
                        data=frame,
                        shape=FRAME_SHAPE
                    )

                    # save labels if available
                    if lab is not None:
                        hf.create_dataset(
                            name='y'+str(frame_id),
                            data=lab[frame_id]
                        )

                    pbar.update()
    finally:
        # Release the capture even if writing fails part-way.
        cap.release()

# Training split: frames plus per-frame labels.
save_frames(
    filename='./data/train.h5',
    vid_path='./data/train.mp4',
    data_len=TRAIN_COUNT,
    lab_path='./data/train.txt',
)

# Test split: frames only (no label file is passed).
save_frames(
    filename='./data/test.h5',
    vid_path='./data/test.mp4',
    data_len=TEST_COUNT,
)

100%|█████████▉| 10758/10768 [00:50<00:00, 258.18it/s]

In [9]:
import cv2
import h5py
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class SpeedDataset(Dataset):
    """Map-style dataset over an HDF5 file of per-frame datasets.

    Sample ``i`` is read from key ``'X<i>'`` (frame) and ``'y<i>'`` (label),
    the layout written by ``save_frames``.
    """
    def __init__(self, filename, transform=None):
        super(SpeedDataset, self).__init__()

        self.file = h5py.File(filename, 'r')
        self.transform = transform
        # Sample count, computed lazily on the first call to __len__.
        self._length = None

    def __len__(self):
        """Number of samples, i.e. the count of ``'X<i>'`` keys.

        ``len(self.file.keys())`` would also count the ``'y<i>'`` label keys,
        doubling the length of a labelled file and making the DataLoader
        request indices that have no frame.
        """
        if self._length is None:
            self._length = sum(
                1 for key in self.file.keys() if key.startswith('X')
            )
        return self._length

    def __getitem__(self, idx):
        """Return ``(X, y)`` for sample ``idx``.

        ``X`` is the stored frame as a numpy array, or, when a transform is
        set, ``transform(X)`` permuted with ``(1, 2, 0)``. ``y`` is the stored
        label as a numpy array.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # retrieve from h5py
        X = np.array(self.file['X'+str(idx)])
        y = np.array(self.file['y'+str(idx)])

        # apply transform; the permute moves axis 0 to the end, presumably
        # undoing a channel-first transform output (e.g. ToTensor) so the
        # frame is channel-last again — TODO confirm for other transforms
        if self.transform is not None:
            X = self.transform(X).permute(1, 2, 0)

        return X, y

    def close(self):
        """Close the underlying HDF5 file."""
        self.file.close()

# Training frames, iterated in index order in the main process.
dataset = SpeedDataset(
    filename='./data/train.h5',
    transform=transforms.Compose([transforms.ToTensor()]),
)
dataloader = DataLoader(dataset, batch_size=100, shuffle=False, num_workers=0)

def preview_frames(loader, window='train', delay_ms=1, quit_key='q'):
    """Show every image from ``loader`` in an OpenCV window.

    ``loader`` yields ``(images, labels)`` batches; each image is displayed
    via ``cv2.imshow(window, img.numpy())``. Pressing ``quit_key`` in the
    window stops the whole preview, not just the current batch. The window
    is destroyed once on exit, including when interrupted.
    """
    try:
        for images, _labels in loader:
            for img in images:
                cv2.imshow(window, img.numpy())
                # `return` leaves both loops; a bare `break` would only skip
                # ahead to the next batch.
                if cv2.waitKey(delay_ms) & 0xFF == ord(quit_key):
                    return
    finally:
        cv2.destroyAllWindows()


# Preview the training frames; press 'q' in the window to stop.
preview_frames(dataloader)

KeyboardInterrupt: 