In [1]:
from typing import Any, List

import pytorch_lightning as pl
import torch
import webdataset as wds
from torch.utils.data import DataLoader
from torchvision import transforms

ModuleNotFoundError: No module named 'torch'

### Usage with PyTorch

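Below is a simple `LightningDataModule` for loading the sharded TAR files with PyTorch Lightning. The global batch size is split evenly across processes (`cfg.optim.batch_size // cfg.compute.world_size`), so the same setup works in DDP modes. Note that the class does not create the datasets or samplers itself: `train_dataset`, `train_sampler`, `valid_dataset` and `valid_sampler` must be assigned (e.g. in `setup()`) before the dataloaders are requested.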

In [None]:
class ataModule(pl.LightningDataModule):
    """LightningDataModule that builds per-process train/validation DataLoaders.

    The global batch size ``cfg.optim.batch_size`` is split evenly across
    ``cfg.compute.world_size`` processes, so each rank loads
    ``local_batch_size`` samples per step.

    NOTE(review): the class name looks truncated (``DataModule``?); it is kept
    unchanged so existing references keep working.
    NOTE(review): ``train_dataset``, ``valid_dataset``, ``train_sampler`` and
    ``valid_sampler`` are read by the dataloader methods but never assigned in
    this class; they must be set (e.g. in ``setup()``) before training.
    NOTE(review): ``BoxMode``, ``trivial_batch_collator`` and
    ``worker_init_reset_seed`` are not imported in this notebook; they look
    like detectron2 helpers — confirm and import them.
    """

    def __init__(self, cfg, train_augs=None, val_augs=None):
        """Store the config and derive the per-process batch size.

        Args:
            cfg: config object; reads ``cfg.optim.batch_size`` and
                ``cfg.compute.world_size`` here, and
                ``cfg.dataloader.num_workers`` when building dataloaders.
            train_augs: list of training transforms. Defaults to a fresh
                ``[transforms.Resize((800, 640))]`` for each instance.
            val_augs: list of validation transforms. Defaults to a fresh
                empty list for each instance.

        Raises:
            AssertionError: if the global batch size is not divisible by the
                world size.
        """
        super().__init__()
        # None sentinels so instances never share (and mutate) one default list.
        if train_augs is None:
            # Resize takes the target size as a single (h, w) sequence; a second
            # positional argument would be interpreted as the interpolation mode.
            # NOTE(review): (800, 640) is assumed to be (height, width) — confirm.
            train_augs = [transforms.Resize((800, 640))]
        if val_augs is None:
            val_augs = []

        self.total_batch_size = cfg.optim.batch_size
        self.world_size = cfg.compute.world_size
        self.cfg = cfg
        self.train_augmentation = train_augs
        self.valid_augmentation = val_augs
        assert self.total_batch_size % self.world_size == 0, (
            f"batch_size ({self.total_batch_size}) must be divisible by "
            f"world_size ({self.world_size})"
        )
        self.local_batch_size = self.total_batch_size // self.world_size

    @staticmethod
    def convert_anns(anns: List[Any]):
        """Convert pose annotations into detectron2-style annotation dicts.

        Args:
            anns: annotation dicts; each must provide ``bbox_2d``, ``bbox_3d``,
                ``cam_t_m2c`` and ``cam_R_m2c``. ``category_id`` is optional
                and defaults to 0.

        Returns:
            A list with one converted dict per input annotation, in order.
        """
        return [
            {
                "bbox": a["bbox_2d"],
                "bbox_mode": BoxMode.XYXY_ABS,
                # NOTE(review): keypoints are taken from bbox_3d (presumably the
                # 3D box corners) — confirm this is intended.
                "keypoints": a["bbox_3d"],
                "category_id": a.get("category_id", 0),
                "translation": a["cam_t_m2c"],
                "rotation": a["cam_R_m2c"],
                "bbox_3d": a["bbox_3d"],
            }
            for a in anns
        ]

    def _make_dataloader(self, sampler, dataset, drop_last):
        """Batch ``sampler`` into ``local_batch_size`` groups and wrap ``dataset``."""
        batch_sampler = torch.utils.data.sampler.BatchSampler(
            sampler, self.local_batch_size, drop_last=drop_last
        )
        return DataLoader(
            dataset,
            num_workers=self.cfg.dataloader.num_workers,
            batch_sampler=batch_sampler,
            collate_fn=trivial_batch_collator,
            worker_init_fn=worker_init_reset_seed,
        )

    def train_dataloader(self):
        """Return the training DataLoader.

        The incomplete last batch is dropped, so every training batch has
        exactly ``local_batch_size`` samples.
        """
        return self._make_dataloader(
            self.train_sampler, self.train_dataset, drop_last=True
        )

    def val_dataloader(self):
        """Return the validation DataLoader.

        The last batch is kept even if it is smaller than ``local_batch_size``,
        so every validation sample is evaluated.
        """
        return self._make_dataloader(
            self.valid_sampler, self.valid_dataset, drop_last=False
        )
