## TODO
* Turn `save_videos_dataset_as_frames` into a DVC pipeline stage
* Extract patches from images
* Run classifiers on those patches
* Run a small LSTM on per-frame RGB means (or a similarly cheap feature); this should be enough to count the number of frames
* Do augmentation
* Re-run the previous classifiers on the augmented data (they should then fail)


In [None]:
# Enable IPython's autoreload extension so edits to imported project modules
# (e.g. `fight_classifier`) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import pandas as pd

from fight_classifier import DATASET_DIR, PROJECT_DIR
from fight_classifier.data.video_to_images import (
    save_videos_dataset_as_frames)

# Where the source videos live and where their extracted frames go.
videos_dir = DATASET_DIR / 'Peliculas/'
frames_dir = DATASET_DIR / 'raw_frames/'

# Load the video index and display it as the cell output.
videos_csv = videos_dir / 'videos.csv'
videos_df = pd.read_csv(videos_csv)
videos_df

In [None]:
# Save every video listed in `videos_df` (read from `videos_dir`) as
# individual frame images under `frames_dir`; keep the returned table.
frames_df = save_videos_dataset_as_frames(
    videos_df=videos_df, videos_dir=videos_dir, frames_dir=frames_dir)


## Evidence of a Clever Hans effect

* TODO: stop using the MobileNet normalization for this experiment

In [None]:
import pytorch_lightning as pl
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights

from fight_classifier.data.image_dataset import ImageDataModule
from fight_classifier.model.hans_model import SmallCnnImageClassifier
from fight_classifier.model.image_based_model import ImageClassifierModule, ProjFromFeatures

# Experiment settings.
SPLIT_COHERENCE_COL = 'fine_category'
BATCH_SIZE = 40

# Re-read the frame index so this cell can run on its own
# (presumably `frames.csv` is produced by the frame-extraction step).
frames_dir = DATASET_DIR / 'raw_frames/'
frames_df = pd.read_csv(frames_dir / 'frames.csv')

# Reuse MobileNetV3-Large's default preprocessing parameters
# (resize, crop, normalization) for the image pipeline.
base_model_weights = MobileNet_V3_Large_Weights.DEFAULT
# NOTE(review): `base_model` is not used elsewhere in this cell; only the
# weights' transforms below feed the data module.
base_model = mobilenet_v3_large(weights=base_model_weights)
base_model.eval()
preprocess = base_model_weights.transforms()
preprocess_kwargs = dict(
    resize_size=preprocess.resize_size[0],
    crop_size=preprocess.crop_size[0],
    mean=preprocess.mean,
    std=preprocess.std,
)

image_data_module = ImageDataModule(
    image_df=frames_df,
    batch_size=BATCH_SIZE,
    preprocess_kwargs=preprocess_kwargs,
    split_coherence_col=SPLIT_COHERENCE_COL,
)

# Alternative classifier to try: SmallCnnImageClassifier(n_layers=2)
classifier = ProjFromFeatures()
classif_module = ImageClassifierModule(classifier=classifier)

# Validation is capped at 300 batches per validation run.
trainer = pl.Trainer(
    default_root_dir=str(PROJECT_DIR),
    limit_val_batches=300,
)
trainer.fit(model=classif_module, datamodule=image_data_module)

In [None]:
import albumentations as A
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch

from fight_classifier.visualization.patch_classification import viz_patch_heatmap
from fight_classifier.visualization import imshow_chw

# Visual check of the resize/crop steps on a single frame: show the raw
# frame, the frame after shortest-side resizing, and a random square crop.
idx = 0

ds = image_data_module.train_dataset
print(ds.resize_size, ds.crop_size)

# NOTE(review): `idx` indexes the full `frames_df`, not the training split
# held by `ds`, so the frame shown is not necessarily a training frame.
image_row = frames_df.iloc[idx]
image_path = image_row[ds.image_path_col]
image = Image.open(image_path)
image_np = np.asarray(image)

# Resize so the shortest side equals `resize_size`, then take a random
# `crop_size` x `crop_size` crop (presumably mirroring the dataset's training
# transforms -- confirm against ImageDataModule).
im1 = A.SmallestMaxSize(max_size=ds.resize_size)(image=image_np)['image']
im2 = A.RandomCrop(height=ds.crop_size, width=ds.crop_size)(image=im1)['image']

# Use `plt.figure()` (creates and activates a pyplot figure), not
# `plt.Figure()`, which only builds a detached Figure that pyplot never
# draws into.
for shown_image in (image_np, im1, im2):
    plt.figure()
    plt.imshow(shown_image)
    plt.show()
