
Stacking tensors without same size #56

Closed
aofrancani opened this issue May 26, 2021 · 2 comments

@aofrancani

Hi, I'm following the tutorial Training a PyTorchVideo classification model and I believe I can't load the data correctly.

I'm using Google Colab and my Kinetics400 is in my Google Drive. I've preprocessed Kinetics so that all videos are rescaled to a height of 256 pixels.

My DataLoader is implemented the same way as described in the tutorial:

import os

import pytorch_lightning
import pytorchvideo.data
import torch.utils.data
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    Compose,
    RandomCrop,
    RandomHorizontalFlip,
)


class KineticsDataModule(pytorch_lightning.LightningDataModule):
    """
    This LightningDataModule implementation constructs a PyTorchVideo Kinetics dataset for both
    the train and val partitions. It defines each partition's augmentation and
    preprocessing transforms and configures the PyTorch DataLoaders.
    """

    # Dataset configuration
    _DATA_PATH = '/content/drive/MyDrive/Datasets/Kinetics400/'
    _CLIP_DURATION = 2  # Duration of sampled clip for each video
    _BATCH_SIZE = 8
    _NUM_WORKERS = 8  # Number of parallel processes fetching data


    def train_dataloader(self):
        """
        Create the Kinetics train partition from the list of video labels
        in {self._DATA_PATH}/train.csv. Add transform that subsamples and
        normalizes the video before applying the scale, crop and flip augmentations.
        """
        train_transform = Compose(
            [
            ApplyTransformToKey(
              key="video",
              transform=Compose(
                  [
                    UniformTemporalSubsample(8),
                    Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop(244),
                    RandomHorizontalFlip(p=0.5),
                  ]
                ),
              ),
            ]
        )
        train_dataset = pytorchvideo.data.Kinetics(
              data_path=os.path.join(self._DATA_PATH, "train.csv"),
              clip_sampler=pytorchvideo.data.make_clip_sampler("random", self._CLIP_DURATION),
              transform=train_transform
        )
        return torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self._BATCH_SIZE,
            num_workers=self._NUM_WORKERS,
        )

    def val_dataloader(self):
        """
        Create the Kinetics val partition from the list of video labels
        in {self._DATA_PATH}/val.csv. Add transform that subsamples and
        normalizes the video before applying the scale.
        """
        val_transform = Compose(
            [
            ApplyTransformToKey(
              key="video",
              transform=Compose(
                  [
                    UniformTemporalSubsample(8),
                    Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                  ]
                ),
              ),
            ]
        )
        val_dataset = pytorchvideo.data.Kinetics(
            data_path=os.path.join(self._DATA_PATH, "val.csv"),
            clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", self._CLIP_DURATION),
            transform=val_transform
        )

        return torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self._BATCH_SIZE,
            num_workers=self._NUM_WORKERS,
        )

I built the default ResNet exactly as in the tutorial. Following the tutorial up to the training step, I run a Colab cell containing only train() to call the train() function.

Even though I'm randomly cropping to 244x244 in the transforms, I'm getting the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-2da0ffaf5447> in <module>()
----> 1 train()

13 frames
<ipython-input-6-cd4463cf3c91> in train()
      3   data_module = KineticsDataModule()
      4   trainer = pytorch_lightning.Trainer()
----> 5   trainer.fit(classification_module, data_module)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    456         )
    457 
--> 458         self._run(model)
    459 
    460         assert self.state.stopped

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    754 
    755         # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 756         self.dispatch()
    757 
    758         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    795             self.accelerator.start_predicting(self)
    796         else:
--> 797             self.accelerator.start_training(self)
    798 
    799     def run_stage(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     94 
     95     def start_training(self, trainer: 'pl.Trainer') -> None:
---> 96         self.training_type_plugin.start_training(trainer)
     97 
     98     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    142     def start_training(self, trainer: 'pl.Trainer') -> None:
    143         # double dispatch to initiate the training loop
--> 144         self._results = trainer.run_stage()
    145 
    146     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
    805         if self.predicting:
    806             return self.run_predict()
--> 807         return self.run_train()
    808 
    809     def _pre_training_routine(self):

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    840             self.progress_bar_callback.disable()
    841 
--> 842         self.run_sanity_check(self.lightning_module)
    843 
    844         self.checkpoint_connector.has_trained = False

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_sanity_check(self, ref_model)
   1105 
   1106             # run eval step
-> 1107             self.run_evaluation()
   1108 
   1109             self.on_sanity_check_end()

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in run_evaluation(self, on_epoch)
    947             dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
    948 
--> 949             for batch_idx, batch in enumerate(dataloader):
    950                 if batch is None:
    951                     continue

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in __next__(self)
    515             if self._sampler_iter is None:
    516                 self._reset()
--> 517             data = self._next_data()
    518             self._num_yielded += 1
    519             if self._dataset_kind == _DatasetKind.Iterable and \

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
   1197             else:
   1198                 del self._task_info[idx]
-> 1199                 return self._process_data(data)
   1200 
   1201     def _try_put_index(self):

/usr/local/lib/python3.7/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
   1223         self._try_put_index()
   1224         if isinstance(data, ExceptionWrapper):
-> 1225             data.reraise()
   1226         return data
   1227 

/usr/local/lib/python3.7/dist-packages/torch/_utils.py in reraise(self)
    427             # have message field
    428             raise self.exc_type(message=msg)
--> 429         raise self.exc_type(msg)
    430 
    431 

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/worker.py", line 202, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/fetch.py", line 35, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py", line 73, in default_collate
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py", line 73, in <dictcomp>
    return {key: default_collate([d[key] for d in batch]) for key in elem}
  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [3, 8, 256, 454] at entry 0 and [3, 8, 256, 144] at entry 5

I was expecting something like [3, 8, 244, 244] due to RandomCrop(244) in the DataLoader. What am I missing? Thanks in advance for your help!
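
For context on where the error comes from: PyTorch's default collate function stacks the per-sample "video" tensors into one batch tensor with torch.stack, so it fails as soon as two clips in the same batch have different spatial sizes. Below is a minimal sketch that reproduces the failure mode; the tensor shapes are copied from the traceback, the import path matches the module shown there, and nothing in it comes from PyTorchVideo itself.

import torch
from torch.utils.data._utils.collate import default_collate  # path from the traceback above

# Two samples whose "video" tensors differ only in width,
# like entries 0 and 5 in the error message.
sample_a = {"video": torch.zeros(3, 8, 256, 454), "label": 0}
sample_b = {"video": torch.zeros(3, 8, 256, 144), "label": 1}

# default_collate calls torch.stack on the list of "video" tensors,
# which requires every tensor to have exactly the same shape.
default_collate([sample_a, sample_b])
# -> RuntimeError: stack expects each tensor to be equal size, ...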

@nicklaslund

First of all, @aofrancani, sorry for not answering your question. I have not looked into the problem; if I find some time for it or encounter the same problem myself, I will let you know.

Secondly, as the title of this issue seems appropriate for my request, I will file it as a comment here.

I want to work with video data with clips of different lengths, i.e. a different number of frames. As of now, in default_collate(batch) in .../torch/utils/data/_utils/collate.py, the elements of the batch are combined into a single tensor using torch.stack(batch, 0, out=out). Are there any plans to introduce nested tensors (https://github.com/pytorch/nestedtensor) in the near future, to be able to work with video clips of varying length?
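
Until something like nested tensors lands in the data loading path, one possible workaround is a custom collate_fn that pads every clip to a common temporal length before stacking. The sketch below is only illustrative: it assumes each sample is a dict with a "video" tensor of shape C x T x H x W and an integer "label"; the function name and the zero-padding strategy are not part of PyTorchVideo.

import torch
from torch.nn import functional as F

def pad_video_collate(batch):
    # Pad every clip to the longest temporal length in the batch so that
    # torch.stack succeeds (video tensors are C x T x H x W).
    max_t = max(sample["video"].shape[1] for sample in batch)
    videos = []
    for sample in batch:
        video = sample["video"]
        pad_t = max_t - video.shape[1]
        if pad_t > 0:
            # F.pad pads from the last dimension backwards: (W, W, H, H, T, T)
            video = F.pad(video, (0, 0, 0, 0, 0, pad_t))
        videos.append(video)
    return {
        "video": torch.stack(videos, dim=0),
        "label": torch.tensor([sample["label"] for sample in batch]),
    }

# Usage: pass it to the DataLoader instead of the default collate function, e.g.
# loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=pad_video_collate)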

Thanks in advance and thank you for the library, it is a pleasure to work with.

@aofrancani
Author

aofrancani commented Jun 2, 2021

Hi everyone! I solved it...

  1. The first thing I did was change my val_dataloader, because the problem was in the validation step. I added the ShortSideScale and CenterCrop transforms to fix it:
        val_transform = Compose(
            [
            ApplyTransformToKey(
              key="video",
              transform=Compose(
                  [
                    UniformTemporalSubsample(8),
                    Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                    ShortSideScale(size=256),
                    CenterCrop(244),
                  ]
                ),
              ),
            ]
        )

After changing it, I had another issue:

  File "/usr/local/lib/python3.7/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
    return torch.stack(batch, 0, out=out)
RuntimeError: stack expects each tensor to be equal size, but got [89088] at entry 0 and [88064] at entry 1
  2. Now the problem was in decoding the audio. In the Dataset section of the tutorial, they set decode_audio=False in pytorchvideo.data.Kinetics. However, in the Transforms section this parameter is not shown, so I added it to my code.

Hence, my final code is the following:

import os

import pytorch_lightning
import pytorchvideo.data
import torch.utils.data
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import (
    CenterCrop,
    Compose,
    RandomCrop,
    RandomHorizontalFlip,
)


class KineticsDataModule(pytorch_lightning.LightningDataModule):
    """
    This LightningDataModule implementation constructs a PyTorchVideo Kinetics dataset for both
    the train and val partitions. It defines each partition's augmentation and
    preprocessing transforms and configures the PyTorch DataLoaders.
    """

    # Dataset configuration
    _DATA_PATH = '/content/drive/MyDrive/Datasets/Kinetics400/'
    _CLIP_DURATION = 2  # Duration of sampled clip for each video
    _BATCH_SIZE = 4
    _NUM_WORKERS = 2  # Number of parallel processes fetching data


    def train_dataloader(self):
        """
        Create the Kinetics train partition from the list of video labels
        in {self._DATA_PATH}/train.csv. Add transform that subsamples and
        normalizes the video before applying the scale, crop and flip augmentations.
        """
        train_transform = Compose(
            [
            ApplyTransformToKey(
              key="video",
              transform=Compose(
                  [
                    UniformTemporalSubsample(8),
                    Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                    RandomShortSideScale(min_size=256, max_size=320),
                    RandomCrop(244),
                    RandomHorizontalFlip(p=0.5),
                  ]
                ),
              ),
            ]
        )
        train_dataset = pytorchvideo.data.Kinetics(
              data_path=os.path.join(self._DATA_PATH, "train.csv"),
              clip_sampler=pytorchvideo.data.make_clip_sampler("random", self._CLIP_DURATION),
              decode_audio=False,
              transform=train_transform
        )
        return torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self._BATCH_SIZE,
            num_workers=self._NUM_WORKERS,
        )

    def val_dataloader(self):
        """
        Create the Kinetics val partition from the list of video labels
        in {self._DATA_PATH}/val.csv. Add transform that subsamples and
        normalizes the video before applying the scale.
        """
        val_transform = Compose(
            [
            ApplyTransformToKey(
              key="video",
              transform=Compose(
                  [
                    UniformTemporalSubsample(8),
                    Normalize((0.45, 0.45, 0.45), (0.225, 0.225, 0.225)),
                    ShortSideScale(size=256),
                    CenterCrop(244),
                  ]
                ),
              ),
            ]
        )
        val_dataset = pytorchvideo.data.Kinetics(
            data_path=os.path.join(self._DATA_PATH, "val.csv"),
            clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", self._CLIP_DURATION),
            decode_audio=False,
            transform=val_transform
        )
        return torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self._BATCH_SIZE,
            num_workers=self._NUM_WORKERS,
        )

Since I've solved it, I will close this issue. Thanks for this amazing library!
