In [12]:
import flash
import torch
from flash.core.data.utils import download_data
from flash.video import VideoClassificationData, VideoClassifier

In [13]:
from torch import Tensor
import kornia.augmentation as K
def normalize(x: Tensor) -> Tensor:
    return x / 255.0

In [14]:
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
    UniformCropVideo,
)
from torchvision.transforms import Compose, CenterCrop
from torchvision.transforms import RandomCrop
from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys
from typing import Callable
from flash.core.utilities.imports import (
    _KORNIA_AVAILABLE,
    _PYTORCHVIDEO_AVAILABLE,
    requires,
)
class TransformDataModule(InputTransform):
    image_size: int = 256
    temporal_sub_sample: int = 16  # This is the only change in our custom transform
    mean: Tensor = torch.tensor([0.45, 0.45, 0.45])
    std: Tensor = torch.tensor([0.225, 0.225, 0.225])
    data_format: str = "BCTHW"
    same_on_frame: bool = False

    def per_sample_transform(self) -> Callable:
        per_sample_transform = [CenterCrop(self.image_size)]

        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [UniformTemporalSubsample(self.temporal_sub_sample), normalize]
                        + per_sample_transform
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def train_per_sample_transform(self) -> Callable:
        per_sample_transform = [RandomCrop(self.image_size, pad_if_needed=True)]

        return Compose(
            [
                ApplyToKeys(
                    DataKeys.INPUT,
                    Compose(
                        [UniformTemporalSubsample(self.temporal_sub_sample), normalize]
                        + per_sample_transform
                    ),
                ),
                ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
            ]
        )

    def per_batch_transform_on_device(self) -> Callable:
        return ApplyToKeys(
            DataKeys.INPUT,
            K.VideoSequential(
                K.Normalize(self.mean, self.std),
                data_format=self.data_format,
                same_on_frame=self.same_on_frame,
            ),
        )

In [15]:
datamodule = VideoClassificationData.from_folders(
    train_folder="pen_dataset/train",
    val_folder="pen_dataset/val",
    clip_sampler="uniform",
    clip_duration=3,
    decode_audio=False,
    batch_size=1,
    transform=TransformDataModule(),  # The custom transform is given to the datamodule's transform argument

)

  rank_zero_deprecation(


In [21]:
model = VideoClassifier(backbone="x3d_xs", labels=datamodule.labels, pretrained=True)


  rank_zero_warn(
Using 'x3d_xs' provided by Facebook Research/PyTorchVideo (https://github.com/facebookresearch/pytorchvideo).


In [22]:
trainer = flash.Trainer(max_epochs=3,accelerator="gpu", devices=1,)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type       | Params
---------------------------------------------
0 | train_metrics | ModuleDict | 0     
1 | val_metrics   | ModuleDict | 0     
2 | test_metrics  | ModuleDict | 0     
3 | backbone      | Net        | 3.8 M 
4 | head          | Sequential | 802   
---------------------------------------------
32.2 K    Trainable params
3.8 M     Non-trainable params
3.8 M     Total params
15.180    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [25]:
datamodule = VideoClassificationData.from_folders(predict_folder="predict", batch_size=1)
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 22it [00:00, ?it/s]

[['unsrcew_back'], ['unsrcew_back']]


In [23]:
trainer.save_checkpoint("video_classification.pt")
