In [2]:
import os
import itertools
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks.progress import TQDMProgressBar
import torchmetrics
import torch.nn.functional as F
import pytorchvideo.data

import matplotlib.pyplot as plt
#import wandb

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    Resize,
)


from transformers import AutoImageProcessor, TimesformerForVideoClassification



#### Limit Dataset

 To ensure a constant number of samples are retrieved from the dataset

In [3]:
class LimitDataset(torch.utils.data.Dataset):

    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset
        self.dataset_iter = itertools.chain.from_iterable(
            itertools.repeat(iter(dataset), 2)
        )

    def __getitem__(self, index):
        return next(self.dataset_iter)

    def __len__(self):
        return self.dataset.num_videos

#### Fly Data Module

In [4]:
class FlyDataModule(pl.LightningDataModule):

    def __init__(self, args):
        self.args = args
        super().__init__()

    def _make_transforms(self, mode: str):
        return Compose([self._video_transform(mode)])

    def _video_transform(self, mode: str):
        return ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    UniformTemporalSubsample(self.args["num_frames_to_sample"]),
                    Lambda(lambda x: x / 255.0),
                    Normalize(self.args["video_means"], self.args["video_stds"]),
                ]
                + (
                    [
                        RandomShortSideScale(
                            min_size=self.args["video_min_short_side_scale"],
                            max_size=self.args["video_max_short_side_scale"],
                        ),
                        RandomCrop(self.args["crop_size"]),
                    ]
                    if mode == "train"
                    else [
                        ShortSideScale(self.args["video_min_short_side_scale"]),
                        CenterCrop(self.args["crop_size"]),
                    ]
                )
            ),
        )

    def train_dataloader(self):

        train_transform = self._make_transforms(mode="train")

        self.train_dataset = LimitDataset(
            pytorchvideo.data.labeled_video_dataset(
                data_path=self.args["train_data_path"],
                clip_sampler=pytorchvideo.data.make_clip_sampler('uniform', self.args["clip_duration"]), # Experiment olarak random da denenebilir
                transform=train_transform,
                video_path_prefix=self.args["video_path_prefix"], # could be '' I think
                decode_audio=False
            )
        )

        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.args["batch_size"]
        )

    def val_dataloader(self):
        
        val_transform = self._make_transforms(mode="val")

        self.val_dataset = LimitDataset(
            pytorchvideo.data.labeled_video_dataset(
                data_path=self.args["val_data_path"],
                clip_sampler=pytorchvideo.data.make_clip_sampler('uniform', self.args["clip_duration"]), # Experiment olarak random da denenebilir
                transform=val_transform,
                video_path_prefix=self.args["video_path_prefix"], # could be '' I think
                decode_audio=False
            )
        )

        return torch.utils.data.DataLoader(
            self.val_dataset,
            batch_size=self.args["batch_size"]
        )

#### Lightning Module

In [19]:
class VideoClassificationLightningModule(pl.LightningModule):
    def __init__(self, model, args):
        super().__init__()

        self.args = args
        self.model = model
        self.dataloader_length = 0
        self.classes = ['Feeding', 'Grooming', 'Pumping']

        self.save_hyperparameters("args")
        
        # For logging outputs
        self.epoch_logits = []
        self.epoch_incorrect_samples = None

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        return self._common_step(batch, batch_idx, stage='train')

    def validation_step(self, batch, batch_idx):
        self._common_step(batch, batch_idx, stage='val')
    
    def _common_step(self, batch, batch_idx, stage='train'):
        X, y = batch['video'], batch['label']

        output = self(X.permute(0, 2, 1, 3, 4)) # (8, 3, 16, 224, 224) -> (8, 16, 3, 224, 224)

        loss = F.cross_entropy(output.logits, y)
        acc = torchmetrics.functional.accuracy(output.logits, y, task="multiclass", num_classes=3)

        self.log(
            f"{stage}_loss", loss.item(), batch_size=self.args["batch_size"], on_step=False, on_epoch=True, prog_bar=True
        )
        self.log(
            f"{stage}_acc", acc, batch_size=self.args["batch_size"], on_step=False, on_epoch=True, prog_bar=True
        )
        if stage == 'train':
            return loss
        elif stage == 'val':
            _, predictions = torch.max(output.logits, dim=1)
            incorrect_samples = X.permute(0, 2, 1, 3, 4)[predictions != y]  # Get incorrect samples
            self.epoch_logits.extend(output.logits)
            if self.epoch_incorrect_samples is None:
                self.epoch_incorrect_samples = incorrect_samples[0]
                print(len(incorrect_samples))

    def on_validation_epoch_end(self):
        #dummy_input = torch.zeros((1, 8, 3, 224, 224), device=self.device)
        #model_filename = "model_ckpt.onnx"
        #torch.onnx.export(self, dummy_input, model_filename, opset_version=11)
        #artifact = wandb.Artifact(name="model.ckpt", type="model")
        #artifact.add_file(model_filename)
        #self.logger.experiment.log_artifact(artifact)

        flattened_logits = torch.flatten(torch.cat(self.epoch_logits))
        incorrect_samples = self.epoch_incorrect_samples

        print('false_predictions:', incorrect_samples.shape)
        print('logits:', flattened_logits)
        print('global_step:', self.global_step)
        
        self.epoch_logits.clear()
        self.epoch_incorrect_samples=None

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.args["lr"],
        )
        return [optimizer]

#### Config

In [6]:
def create_preprocessor_config(model, image_processor, sample_rate=8, fps=30):

    mean = image_processor.image_mean
    std = image_processor.image_std

    if "shortest_edge" in image_processor.size:
        height = width = image_processor.size["shortest_edge"]
    else:
        height = image_processor.size["height"]
        width = image_processor.size["width"]

    crop_size = (height, width)

    num_frames_to_sample = model.config.num_frames # 16 for VideoMAE
    clip_duration = num_frames_to_sample * sample_rate / fps
    print('Clip Duration:', clip_duration, 'seconds')

    return {
        "image_mean" : mean,
        "image_std" : std,
        "crop_size" : crop_size,
        "num_frames_to_sample" : num_frames_to_sample,
        "clip_duration": clip_duration
    }

In [7]:
def get_timesformer_model(ckpt, label2id, id2label, num_frames):
    return TimesformerForVideoClassification.from_pretrained(
        ckpt,
        label2id=label2id,
        id2label=id2label,
        ignore_mismatched_sizes=True,  # provide this in case you're planning to fine-tune an already fine-tuned checkpoint
        num_frames = num_frames
    )

In [9]:
# PATH INFO
PROJ_DIR = '/cta/users/mpekey/FlyVideo'
TRAIN_DATA_PATH = os.path.join(PROJ_DIR, 'FlyTrainingData', 'Train')
VAL_DATA_PATH = os.path.join(PROJ_DIR, 'FlyTrainingData', 'Validation')


# DATASET INFO
class_labels = ['Feeding', 'Grooming', 'Pumping']
label2id = {label: i for i, label in enumerate(class_labels)}
id2label = {i: label for label, i in label2id.items()}

image_processor = AutoImageProcessor.from_pretrained("MCG-NJU/videomae-base")
model = get_timesformer_model(ckpt="facebook/timesformer-base-finetuned-k400",
                                  label2id=label2id,
                                  id2label=id2label,
                                  num_frames=8)

# Freeze the model    
for param in model.timesformer.parameters():
    param.requires_grad = False


Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.


Some weights of TimesformerForVideoClassification were not initialized from the model checkpoint at facebook/timesformer-base-finetuned-k400 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([400, 768]) in the checkpoint and torch.Size([3, 768]) in the model instantiated
- classifier.bias: found shape torch.Size([400]) in the checkpoint and torch.Size([3]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
PROJ_DIR = '/cta/users/mpekey/FlyVideo'
TRAIN_DATA_PATH = os.path.join(PROJ_DIR, 'FlyTrainingData', 'Train')
VAL_DATA_PATH = os.path.join(PROJ_DIR, 'FlyTrainingData', 'Validation')

In [14]:
# Create Arguments
model_args = create_preprocessor_config(model, 
                                        image_processor, 
                                        sample_rate=16, 
                                        fps=30)

args = {
    # Data
    "train_data_path" : TRAIN_DATA_PATH,
    "val_data_path" : VAL_DATA_PATH,
    "lr" : 0.001,
    "max_epochs" : 1,
    "batch_size" : 16,
    "video_path_prefix" : '',
    "video_min_short_side_scale" : 256,
    "video_max_short_side_scale" : 320,
    "clip_duration" : model_args["clip_duration"],
    "crop_size" : model_args["crop_size"],
    "num_frames_to_sample": model_args["num_frames_to_sample"],
    "video_means" : model_args["image_mean"],
    "video_stds" : model_args["image_std"],
    "sample_rate": 16,
    "fps":30,
    "num_frames":8
}

Clip Duration: 4.266666666666667 seconds


#### Training

In [20]:
trainer = pl.Trainer(
        max_epochs=args['max_epochs'],
        callbacks=[TQDMProgressBar(refresh_rate=args['batch_size'])],
        accelerator="gpu" if torch.cuda.is_available() else "auto",
        devices=1 if torch.cuda.is_available() else None,
        log_every_n_steps=40
    )
classification_module = VideoClassificationLightningModule(model, args)
data_module = FlyDataModule(args)
trainer.fit(classification_module, data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type                              | Params
------------------------------------------------------------
0 | model | TimesformerForVideoClassification | 121 M 
------------------------------------------------------------
2.3 K     Trainable params
121 M     Non-trainable params
121 M     Total params
485.044   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]13
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:36<00:00, 18.27s/it]false_predictions: torch.Size([8, 3, 224, 224])
logits: tensor([ 0.6479,  0.8090, -0.2530,  0.5467,  0.6009, -0.0518,  0.4856,  0.3367,
         0.1013,  0.4477,  0.1497,  0.1963,  0.3785,  0.3198,  0.6934,  0.5756,
         0.4119, -0.2066,  0.4649,  0.5531, -0.4300,  0.8923,  0.3700, -0.2273,
         0.8465,  0.3662, -0.0679,  0.7400,  0.2328,  0.1603,  0.5228,  0.4122,
        -0.0802,  0.4988,  0.4221, -0.0790,  0.5141,  0.3994, -0.1032,  0.5178,
         0.4667, -0.0360,  0.4729,  0.4105, -0.1231,  0.4547,  0.3870, -0.0892,
         0.6912,  0.4379, -0.6753,  0.8330,  0.1755, -0.3654,  0.8428,  0.0978,
        -0.3756, -0.2428,  0.3205, -0.3050, -0.2836,  0.3562, -0.3126, -0.2763,
         0.3484, -0.3535, -0.4278,  0.2953, -0.3705, -0.3642,  0.3027, -0.3779,
         0.2304,  0.7050,  0.0158,  0.1370,  0.4829,  0.1189,  0.3913,  0.3473,

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 40/40 [34:28<00:00, 51.71s/it, v_num=4, val_loss=0.685, val_acc=0.750, train_loss=0.852, train_acc=0.619]


In [21]:
classification_module.hparams

"args": {'train_data_path': '/cta/users/mpekey/FlyVideo/FlyTrainingData/Train', 'val_data_path': '/cta/users/mpekey/FlyVideo/FlyTrainingData/Validation', 'lr': 0.001, 'max_epochs': 1, 'batch_size': 16, 'video_path_prefix': '', 'video_min_short_side_scale': 256, 'video_max_short_side_scale': 320, 'clip_duration': 4.266666666666667, 'crop_size': (224, 224), 'num_frames_to_sample': 8, 'video_means': [0.485, 0.456, 0.406], 'video_stds': [0.229, 0.224, 0.225], 'sample_rate': 16, 'fps': 30, 'num_frames': 8}