In [1]:
import os
import itertools
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress import TQDMProgressBar
import torchmetrics
import torch.nn.functional as F
import pytorchvideo.data

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)


from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification

  from .autonotebook import tqdm as notebook_tqdm


#### Limit Dataset

 Wraps the video dataset so that a constant number of samples (equal to the number of videos) is retrieved from it in each epoch.

In [2]:
class LimitDataset(torch.utils.data.Dataset):
    """Map-style facade over an iterable dataset with a fixed reported length.

    Items are drawn sequentially from an iterator over the wrapped dataset;
    the index handed to ``__getitem__`` is ignored. ``len()`` reports the
    wrapped dataset's ``num_videos`` attribute.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        # The very same iterator object is chained with itself. Whether the
        # second leg yields anything depends on whether calling ``iter()`` on
        # that object again restarts it.
        source_iter = iter(dataset)
        self.dataset_iter = itertools.chain(source_iter, source_iter)

    def __len__(self):
        return self.dataset.num_videos

    def __getitem__(self, index):
        # ``index`` is unused on purpose: samples come out in iteration order.
        return next(self.dataset_iter)

#### Fly Data Module

In [3]:
class FlyDataModule(pl.LightningDataModule):
    """Lightning data module that serves video clips for training and validation.

    ``args`` is a dict. The keys read by this class are ``train_data_path``,
    ``val_data_path``, ``video_path_prefix``, ``clip_duration``,
    ``batch_size``, ``num_frames_to_sample``, ``video_means``,
    ``video_stds``, ``video_min_short_side_scale``,
    ``video_max_short_side_scale`` and ``crop_size``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

    def _make_transforms(self, mode: str):
        """Return the full sample transform for ``mode`` ("train" or anything else)."""
        return Compose([self._video_transform(mode)])

    def _video_transform(self, mode: str):
        """Build the transform applied to the ``"video"`` entry of each sample.

        Common steps: temporal subsampling to ``num_frames_to_sample`` frames,
        division by 255, and normalisation with ``video_means``/``video_stds``.
        For ``mode == "train"`` these are followed by a random short-side scale
        and a random crop; for any other mode by a fixed short-side scale and a
        centre crop.
        """
        common_steps = [
            UniformTemporalSubsample(self.args["num_frames_to_sample"]),
            Lambda(lambda x: x / 255.0),
            Normalize(self.args["video_means"], self.args["video_stds"]),
        ]
        if mode == "train":
            spatial_steps = [
                RandomShortSideScale(
                    min_size=self.args["video_min_short_side_scale"],
                    max_size=self.args["video_max_short_side_scale"],
                ),
                RandomCrop(self.args["crop_size"]),
            ]
        else:
            spatial_steps = [
                ShortSideScale(self.args["video_min_short_side_scale"]),
                CenterCrop(self.args["crop_size"]),
            ]
        return ApplyTransformToKey(
            key="video",
            transform=Compose(common_steps + spatial_steps),
        )

    def _make_dataset(self, data_path, mode: str):
        """Build the length-limited labeled video dataset for one split."""
        return LimitDataset(
            pytorchvideo.data.labeled_video_dataset(
                data_path=data_path,
                # Uniform clip sampling; 'random' could also be tried as an experiment.
                clip_sampler=pytorchvideo.data.make_clip_sampler(
                    'uniform', self.args["clip_duration"]
                ),
                transform=self._make_transforms(mode=mode),
                video_path_prefix=self.args["video_path_prefix"],  # may be ''
                decode_audio=False,
            )
        )

    def _make_loader(self, dataset):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.args["batch_size"],
        )

    def train_dataloader(self):
        """Create ``self.train_dataset`` and return its DataLoader."""
        self.train_dataset = self._make_dataset(
            self.args["train_data_path"], mode="train"
        )
        return self._make_loader(self.train_dataset)

    def val_dataloader(self):
        """Create ``self.val_dataset`` and return its DataLoader."""
        self.val_dataset = self._make_dataset(
            self.args["val_data_path"], mode="val"
        )
        return self._make_loader(self.val_dataset)

#### Lightning Module

In [6]:
class VideoClassificationLightningModule(pl.LightningModule):
    """Lightning wrapper that trains and validates a video classification model.

    Args:
        model: Module whose call returns an object exposing ``.logits``.
        args: Dict of hyper-parameters; ``lr``, ``weight_decay`` and
            ``max_epochs`` are read in ``configure_optimizers``.
    """

    def __init__(self, model, args):
        # Initialise the base module before assigning any attributes.
        super().__init__()
        self.args = args
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` unchanged (no axis permutation here)."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch``."""
        return self._common_step(batch, batch_idx, stage='train')

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accuracy for ``batch``; returns nothing."""
        self._common_step(batch, batch_idx, stage='val')

    def _common_step(self, batch, batch_idx, stage='train'):
        """Compute and log cross-entropy loss and accuracy for one batch.

        Reads ``batch['video']`` and ``batch['label']``. Metrics are logged as
        ``{stage}_loss`` and ``{stage}_acc``. The loss is returned only when
        ``stage == 'train'``; otherwise ``None`` is returned.
        """
        X, y = batch['video'], batch['label']

        # Swap axes 1 and 2 before the forward pass,
        # e.g. (8, 3, 16, 224, 224) -> (8, 16, 3, 224, 224).
        logits = self(X.permute(0, 2, 1, 3, 4)).logits

        loss = F.cross_entropy(logits, y)
        # Take the class count from the logits instead of hard-coding it, so
        # the metric stays correct if the label set changes.
        acc = torchmetrics.functional.accuracy(
            logits, y, task="multiclass", num_classes=logits.shape[-1]
        )

        # Pass the batch size explicitly: the batch is a dict, so leaving it to
        # be inferred for epoch-level averaging is ambiguous.
        batch_size = X.shape[0]
        # Log a detached tensor rather than ``loss.item()`` to avoid a
        # host/device synchronisation on every step.
        self.log(
            f"{stage}_loss", loss.detach(),
            on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size,
        )
        self.log(
            f"{stage}_acc", acc,
            on_step=True, on_epoch=True, prog_bar=True, batch_size=batch_size,
        )
        if stage == 'train':
            return loss
        return None

    def configure_optimizers(self):
        """Return an Adam optimizer and a cosine-annealing LR scheduler.

        The scheduler's period is ``args["max_epochs"]``.
        """
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.args["lr"],
            weight_decay=self.args["weight_decay"],
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.args["max_epochs"], last_epoch=-1
        )
        return [optimizer], [scheduler]

#### Config

In [7]:
def create_preprocessor_config(model, image_processor, sample_rate=8, fps=30):
    """Collect preprocessing settings from a model and its image processor.

    Args:
        model: Object exposing ``config.num_frames``.
        image_processor: Object exposing ``image_mean``, ``image_std`` and a
            ``size`` mapping holding either ``"shortest_edge"`` or both
            ``"height"`` and ``"width"``.
        sample_rate: Multiplier applied to the frame count when computing the
            clip duration.
        fps: Divisor used when computing the clip duration.

    Returns:
        Dict with keys ``image_mean``, ``image_std``, ``crop_size``
        (``(height, width)`` tuple), ``num_frames_to_sample``,
        ``clip_duration`` (``num_frames * sample_rate / fps``) and
        ``sample_rate``. The clip duration is also printed.
    """
    channel_mean = image_processor.image_mean
    channel_std = image_processor.image_std

    size_spec = image_processor.size
    if "shortest_edge" not in size_spec:
        target_hw = (size_spec["height"], size_spec["width"])
    else:
        # A single edge length describes a square crop.
        edge = size_spec["shortest_edge"]
        target_hw = (edge, edge)

    frame_count = model.config.num_frames  # 16 for VideoMAE
    duration = frame_count * sample_rate / fps
    print('Clip Duration:', duration, 'seconds')

    config = {}
    config["image_mean"] = channel_mean
    config["image_std"] = channel_std
    config["crop_size"] = target_hw
    config["num_frames_to_sample"] = frame_count
    config["clip_duration"] = duration
    config["sample_rate"] = sample_rate
    return config

In [8]:
# REPRODUCIBILITY
# Seed the random number generators before the model is instantiated so that
# any randomly initialised weights and the data augmentations are repeatable.
SEED = 42
pl.seed_everything(SEED, workers=True)

# PATH INFO
# Project root. Defaults to the original local path; set the FLY_PROJ_DIR
# environment variable to run the notebook on another machine.
PROJ_DIR = os.environ.get('FLY_PROJ_DIR', '/Users/mpekey/Desktop/FlyVideo')
TRAIN_DATA_PATH = os.path.join(PROJ_DIR, 'FlyTrainingData', 'Train')
VAL_DATA_PATH = os.path.join(PROJ_DIR, 'FlyTrainingData', 'Validation')

# MODEL INFO
MODEL_CHECKPOINT = "MCG-NJU/videomae-base"

# DATASET INFO
class_labels = ['Feeding', 'Grooming', 'Pumping']
label2id = {label: i for i, label in enumerate(class_labels)}
id2label = {i: label for label, i in label2id.items()}

image_processor = VideoMAEImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = VideoMAEForVideoClassification.from_pretrained(
    MODEL_CHECKPOINT,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
    num_frames = 16 # Default is 16
)

model_args = create_preprocessor_config(model, image_processor, sample_rate=8, fps=30)

args = {
    # Data
    "train_data_path" : TRAIN_DATA_PATH,
    "val_data_path" : VAL_DATA_PATH,
    # NOTE(review): the logged epoch losses for this run (train 15.3, val 36.7)
    # are far above ln(3) ~= 1.10, the chance level for 3 classes. lr=0.1 looks
    # too high for Adam; consider something around 1e-3. TODO confirm.
    "lr" : 0.1,
    "weight_decay" : 1e-4,
    "max_epochs" : 1,
    "batch_size" : 8,
    "video_path_prefix" : '',
    "video_min_short_side_scale" : 256,
    "video_max_short_side_scale" : 320,
    # Values derived from the model / image processor
    "clip_duration" : model_args["clip_duration"],
    "crop_size" : model_args["crop_size"],
    "num_frames_to_sample": model_args["num_frames_to_sample"],
    "video_means" : model_args["image_mean"],
    "video_stds" : model_args["image_std"]
}

# Freeze the backbone (`model.videomae`) so its parameters receive no
# gradients; parameters outside it stay trainable.
for param in model.videomae.parameters():
    param.requires_grad = False

Some weights of the model checkpoint at MCG-NJU/videomae-base were not used when initializing VideoMAEForVideoClassification: ['decoder.decoder_layers.0.layernorm_before.weight', 'decoder.decoder_layers.2.intermediate.dense.weight', 'mask_token', 'decoder.decoder_layers.3.intermediate.dense.weight', 'decoder.decoder_layers.0.attention.attention.q_bias', 'decoder.decoder_layers.1.output.dense.weight', 'decoder.decoder_layers.3.attention.attention.q_bias', 'decoder.decoder_layers.2.layernorm_before.weight', 'decoder.decoder_layers.2.output.dense.bias', 'decoder.decoder_layers.2.attention.output.dense.bias', 'decoder.decoder_layers.1.attention.attention.value.weight', 'decoder.decoder_layers.0.attention.output.dense.weight', 'decoder.decoder_layers.0.output.dense.weight', 'decoder.decoder_layers.1.attention.attention.v_bias', 'decoder.decoder_layers.2.attention.output.dense.weight', 'decoder.decoder_layers.2.layernorm_before.bias', 'decoder.decoder_layers.2.attention.attention.q_bias', 'd

Clip Duration: 4.266666666666667 seconds


#### Training

In [9]:
# Build the Lightning module and data module, then hand both to the trainer.
data_module = FlyDataModule(args)
classification_module = VideoClassificationLightningModule(model, args)

# `devices` is not passed, so the Trainer's own default applies.
trainer = pl.Trainer(
    accelerator="auto",
    callbacks=[TQDMProgressBar(refresh_rate=8)],
    max_epochs=args["max_epochs"],
)
trainer.fit(classification_module, data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type                           | Params
---------------------------------------------------------
0 | model | VideoMAEForVideoClassification | 86.2 M
---------------------------------------------------------
2.3 K     Trainable params
86.2 M    Non-trainable params
86.2 M    Total params
344.918   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]



                                                                           

  rank_zero_warn(


Epoch 0: 100%|██████████| 80/80 [1:49:03<00:00, 81.80s/it, v_num=0, train_loss_step=3.65e-5, train_acc_step=1.000]



Epoch 0: 100%|██████████| 80/80 [2:30:40<00:00, 113.01s/it, v_num=0, train_loss_step=3.65e-5, train_acc_step=1.000, val_loss_step=46.60, val_acc_step=0.000, val_loss_epoch=36.70, val_acc_epoch=0.260, train_loss_epoch=15.30, train_acc_epoch=0.454]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 80/80 [2:30:41<00:00, 113.02s/it, v_num=0, train_loss_step=3.65e-5, train_acc_step=1.000, val_loss_step=46.60, val_acc_step=0.000, val_loss_epoch=36.70, val_acc_epoch=0.260, train_loss_epoch=15.30, train_acc_epoch=0.454]
