In [None]:
import argparse
import os

import torch
import pytorchvideo.data
import pytorchvideo.models.resnet

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)

In [None]:
# PATH INFO
# Project root. Set the FLY_PROJ_DIR environment variable to point at the data
# on another machine instead of editing this machine-specific default.
PROJ_DIR = os.environ.get('FLY_PROJ_DIR', '/Users/mpekey/Desktop/FlyVideo')
TRAIN_DATA_PATH = os.path.join(PROJ_DIR, 'FlyTrainingData', 'Train')
VAL_DATA_PATH = os.path.join(PROJ_DIR, 'FlyTrainingData', 'Validation')

# MODEL INFO
MODEL_CHECKPOINT = "MCG-NJU/videomae-base"  # passed to from_pretrained below
BATCH_SIZE = 8

# DATASET INFO
# The position of each label in this list is its integer class id.
class_labels = ['Feeding', 'Grooming', 'Pumping']
label2id = {label: i for i, label in enumerate(class_labels)}
id2label = {i: label for label, i in label2id.items()}

### Creating Model

In [None]:
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor


# Pre-processing settings (image_mean / image_std / size) stored with the checkpoint.
image_processor = VideoMAEImageProcessor.from_pretrained(MODEL_CHECKPOINT)

# The classifier is configured with this notebook's label mapping.
model = VideoMAEForVideoClassification.from_pretrained(
    MODEL_CHECKPOINT,
    id2label=id2label,
    label2id=label2id,
    # Lets loading succeed even if the checkpoint's head was built for a
    # different label set (e.g. an already fine-tuned checkpoint).
    ignore_mismatched_sizes=True,
    num_frames=16,  # frames per clip; 16 is the VideoMAE default
)

#### Model Configurations

In [None]:
# Image Preprocessing

# Per-channel normalisation statistics shipped with the checkpoint.
mean, std = image_processor.image_mean, image_processor.image_std

# image_processor.size is handled in two layouts: a single "shortest_edge"
# (square crop) or explicit "height"/"width" entries.
if "shortest_edge" not in image_processor.size:
    height, width = image_processor.size["height"], image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]

crop_size = (height, width)


# Clip length in seconds covered by `num_frames_to_sample` frames spaced
# `sample_rate` frames apart.
# NOTE(review): assumes the source videos are recorded at `fps` — TODO confirm.
num_frames_to_sample = model.config.num_frames  # 16 for VideoMAE
sample_rate = 8
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps
print('Clip Duration:', clip_duration, 'seconds')

#### Augmentations

**TODO:** separate train and validation transforms (the pipeline below uses random, training-style augmentations only).

In [None]:
# Training-style augmentation applied only to the "video" entry of each sample.
# NOTE: RandomHorizontalFlip(p=0.5) is currently disabled.
basic_transforms = Compose([
    ApplyTransformToKey(
        key="video",
        transform=Compose([
            # Temporal: keep exactly num_frames_to_sample frames per clip.
            UniformTemporalSubsample(num_frames_to_sample),
            # Intensity: scale by 1/255 (assumes raw 0-255 pixel values),
            # then standardise with the checkpoint's mean/std.
            Lambda(lambda x: x / 255.0),
            Normalize(mean, std),
            # Spatial: random rescale of the short side, then a random crop.
            RandomShortSideScale(min_size=256, max_size=320),
            RandomCrop(crop_size),
        ]),
    ),
])

#### Creating Dataset

**TODO:** build separate train and validation datasets (the cell below creates a single dataset).

In [None]:
# Labeled clip dataset for the training split, using the training-style
# (random crop / rescale) `basic_transforms` defined above.
fly_dataset = pytorchvideo.data.labeled_video_dataset(
    data_path=TRAIN_DATA_PATH,
    clip_sampler=pytorchvideo.data.make_clip_sampler('uniform', clip_duration),
    transform=basic_transforms,
    video_path_prefix='',
    decode_audio=False,
)

#### Limit Dataset

To ensure that a constant number of samples is retrieved from the dataset on every epoch, the iterable video dataset is wrapped in `LimitDataset`.

In [None]:
class LimitDataset(torch.utils.data.Dataset):
    """Map-style adapter that draws a fixed number of samples from an iterable dataset.

    ``__getitem__`` ignores ``index`` and returns the next item from an
    iterator over ``dataset``; ``__len__`` reports ``dataset.num_videos``.
    A DataLoader therefore pulls exactly that many samples per epoch.

    The iterator is shared across epochs. When it runs out, it is replaced
    with a fresh ``iter(dataset)``, so later epochs never hit StopIteration.

    Raises:
        RuntimeError: from ``__getitem__`` if ``dataset`` yields no items at all.
    """

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.dataset_iter = iter(dataset)

    def __getitem__(self, index):
        # `index` is intentionally unused: samples are consumed in stream order.
        try:
            return next(self.dataset_iter)
        except StopIteration:
            pass

        # Stream exhausted: start a new pass over the dataset.
        self.dataset_iter = iter(self.dataset)
        try:
            return next(self.dataset_iter)
        except StopIteration:
            # An immediately-empty fresh pass means the dataset is empty.
            raise RuntimeError(
                "LimitDataset: underlying dataset yielded no samples"
            ) from None

    def __len__(self):
        return self.dataset.num_videos

#### Fly Data Module

In [None]:
# TODO: `import pytorch_lightning` here — FlyDataModule and
# VideoClassificationLightningModule below subclass pytorch_lightning classes,
# and no earlier cell in this notebook imports it.

In [None]:
class FlyDataModule(pytorch_lightning.LightningDataModule):
    """Lightning data module serving the fly-behaviour train/validation clips.

    Reads from ``args``:
        video_num_subsampled, video_means, video_stds,
        video_min_short_side_scale, video_max_short_side_scale,
        video_crop_size, clip_duration, video_path_prefix, batch_size.
    Data roots are the notebook-level ``TRAIN_DATA_PATH`` and ``VAL_DATA_PATH``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

    def _make_transforms(self, mode: str):
        """Return the full per-sample transform for ``mode``."""
        # torchvision's Compose iterates over its argument, so the single
        # keyed transform must be wrapped in a list.
        return Compose([self._video_transform(mode)])

    def _video_transform(self, mode: str):
        """Build the transform applied to the ``"video"`` entry of each sample.

        Every mode subsamples frames, scales pixel values by 1/255 and
        normalises. ``mode == "train"`` adds a random short-side rescale and a
        random crop; any other mode uses a fixed rescale and a centre crop.
        """
        return ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(self.args.video_num_subsampled),
                    Lambda(lambda x: x / 255.0),
                    Normalize(self.args.video_means, self.args.video_stds),
                ]
                + (
                    [
                        RandomShortSideScale(
                            min_size=self.args.video_min_short_side_scale,
                            max_size=self.args.video_max_short_side_scale,
                        ),
                        RandomCrop(self.args.video_crop_size),
                    ]
                    if mode == "train"
                    else [
                        ShortSideScale(self.args.video_min_short_side_scale),
                        CenterCrop(self.args.video_crop_size),
                    ]
                )
            ),
        )

    def _build_dataset(self, data_path, mode: str):
        """Labeled video dataset rooted at ``data_path``, wrapped in LimitDataset."""
        return LimitDataset(
            pytorchvideo.data.labeled_video_dataset(
                data_path=data_path,
                # As an experiment, a 'random' clip sampler could also be tried.
                clip_sampler=pytorchvideo.data.make_clip_sampler(
                    'uniform', self.args.clip_duration
                ),
                transform=self._make_transforms(mode=mode),
                # Prefix for the video paths; '' may be enough — TODO confirm.
                video_path_prefix=self.args.video_path_prefix,
                decode_audio=False,
            )
        )

    def train_dataloader(self):
        """DataLoader over the training split, with random augmentations."""
        self.train_dataset = self._build_dataset(TRAIN_DATA_PATH, mode="train")
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.args.batch_size,
        )

    def val_dataloader(self):
        """DataLoader over the validation split, with deterministic transforms."""
        self.val_dataset = self._build_dataset(VAL_DATA_PATH, mode="val")
        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.args.batch_size,
        )

#### Lightning Module

In [None]:
class VideoClassificationLightningModule(pytorch_lightning.LightningModule):
    """Lightning wrapper that trains a PyTorchVideo ResNet on fly-behaviour clips.

    Batches are dicts with a ``"video"`` tensor (model input) and a ``"label"``
    tensor of integer class ids. Reads ``lr``, ``momentum``, ``weight_decay``
    and ``max_epochs`` from ``args``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        # Model: one output per behaviour class in `class_labels`.
        self.model = pytorchvideo.models.resnet.create_resnet(
            input_channel=3,
            model_num_class=len(class_labels),
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, stage='train')

    def validation_step(self, batch, batch_idx):
        # Returning the loss is optional in Lightning's validation loop; it is
        # returned here for symmetry with training_step.
        return self._common_step(batch, batch_idx, stage='val')

    def _common_step(self, batch, batch_idx, stage='train'):
        """Compute, log and return the cross-entropy loss for one batch.

        Logs ``{stage}_loss`` and ``{stage}_acc`` (top-1 accuracy).
        """
        X, y = batch['video'], batch['label']

        y_pred = self.model(X)
        loss = torch.nn.functional.cross_entropy(y_pred, y)
        # Top-1 accuracy straight from the logits (softmax does not change argmax).
        acc = (y_pred.argmax(dim=-1) == y).float().mean()

        self.log(f"{stage}_loss", loss)
        self.log(
            f"{stage}_acc", acc, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True
        )
        return loss

    def configure_optimizers(self):
        """SGD with momentum/weight decay and a cosine LR schedule over max_epochs."""
        # torch.optim.Adam has no `momentum` argument (passing one raises
        # TypeError); SGD is used so that args.momentum is honoured.
        # NOTE(review): switch to Adam without `momentum` if Adam was intended.
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.args.lr,
            momentum=self.args.momentum,
            weight_decay=self.args.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, self.args.max_epochs, last_epoch=-1
        )
        return [optimizer], [scheduler]

#### Training

In [None]:
# Settings consumed by FlyDataModule and VideoClassificationLightningModule.
# Video settings reuse the preprocessing cells above; the optimisation values
# are placeholders — TODO tune.
args = argparse.Namespace(
    # data
    video_path_prefix='',
    clip_duration=clip_duration,
    batch_size=BATCH_SIZE,
    video_num_subsampled=num_frames_to_sample,
    video_means=mean,
    video_stds=std,
    video_min_short_side_scale=256,
    video_max_short_side_scale=320,
    video_crop_size=crop_size,
    # optimisation (placeholders)
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
    max_epochs=20,
)

# Construct the Trainer directly; Trainer.from_argparse_args is not available
# in newer Lightning releases.
trainer = pytorch_lightning.Trainer(max_epochs=args.max_epochs)
classification_module = VideoClassificationLightningModule(args)
data_module = FlyDataModule(args)
trainer.fit(classification_module, data_module)