## TODO
* Make `save_videos_dataset_as_frames` a DVC pipeline
* Extract patches from the images
* Run classifiers on those patches
* Run a mini-LSTM on the per-frame RGB mean (or a similar summary); that should be enough to count the number of frames
* Add data augmentation
* Re-run the previous classifiers with augmentation (they should fail)


In [None]:
# IPython magics: load the autoreload extension and, in mode 2, reload all
# imported modules before each cell runs, so edits to the project package
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path

import pandas as pd

from fight_classifier import DATASET_DIR, PROJECT_DIR
from fight_classifier.data.video_to_images import save_videos_dataset_as_frames

# Where the frame images are (or will be) written, and where the source
# videos and their index CSV live.
frames_dir = DATASET_DIR / 'raw_frames/'
videos_dir = DATASET_DIR / 'Peliculas/'

# Load the video index shipped with the dataset and display it.
videos_csv_path = videos_dir / 'videos.csv'
videos_df = pd.read_csv(videos_csv_path)
videos_df

In [None]:
# Dump every video listed in videos_df (read from videos_dir) as individual
# frame images under frames_dir.
# NOTE(review): the return value is presumably a DataFrame describing the
# written frames (a later cell reads frames_dir / 'frames.csv') -- confirm.
frames_df = save_videos_dataset_as_frames(
    videos_df=videos_df,
    videos_dir=videos_dir,
    frames_dir=frames_dir,
)


## Evidence of a Clever Hans effect

* TODO: stop using MobileNet normalization for this

In [None]:
import pytorch_lightning as pl
from torchvision.models import MobileNet_V3_Large_Weights, mobilenet_v3_large

from fight_classifier.data.image_dataset import ImageDataModule
from fight_classifier.model.hans_model import SmallCnnImageClassifier
from fight_classifier.model.image_based_model import ImageClassifierModule

BATCH_SIZE = 10
# Column passed to the data module as `split_coherence_col`; presumably rows
# sharing a value are kept in the same split -- confirm in ImageDataModule.
SPLIT_COHERENCE_COL = 'fine_category'

# Re-read the frame index from disk so this cell does not depend on the
# frame-extraction cell having run in the same session.
frames_dir = DATASET_DIR / 'raw_frames/'
frames_csv = str(frames_dir / 'frames.csv')
frames_df = pd.read_csv(frames_csv)

# The pretrained MobileNetV3 is loaded to get its preprocessing transform.
# Its resize/crop/normalization parameters are forwarded to the data module.
base_model_weights = MobileNet_V3_Large_Weights.DEFAULT
base_model = mobilenet_v3_large(weights=base_model_weights)
base_model.eval()
preprocess = base_model_weights.transforms()
preprocess_kwargs = dict(
    resize_size=preprocess.resize_size[0],
    crop_size=preprocess.crop_size[0],
    mean=preprocess.mean,
    std=preprocess.std,
)

image_data_module = ImageDataModule(
    image_df=frames_df,
    batch_size=BATCH_SIZE,
    preprocess_kwargs=preprocess_kwargs,
    split_coherence_col=SPLIT_COHERENCE_COL,
)

# A deliberately tiny CNN (a single layer). Presumably, if even this model
# scores well, it is exploiting a shortcut rather than recognizing fights.
classifier = SmallCnnImageClassifier(n_layers=1)
classif_module = ImageClassifierModule(classifier=classifier)

# Validate every 500 training batches, on at most 300 validation batches.
trainer = pl.Trainer(
    default_root_dir=str(PROJECT_DIR),
    val_check_interval=500,
    limit_val_batches=300,
)
trainer.fit(model=classif_module, datamodule=image_data_module)

In [None]:
import matplotlib.pyplot as plt

# Visualize the first two filters of the trained classifier's `lay1` layer.
# NOTE(review): assumes `lay1.weight` is a conv kernel shaped
# (out_channels, in_channels, kH, kW). imshow can only draw each slice if it
# is 2-D or channels-last -- TODO confirm against SmallCnnImageClassifier.
trained_weights = classif_module.classifier.state_dict()
# .cpu() so matplotlib can convert the tensor even when training ran on a GPU
# (a no-op for tensors already on the CPU).
weight1 = trained_weights['lay1.weight'].cpu()

for filter_idx in range(2):
    # plt.figure() creates and activates a new pyplot figure. plt.Figure()
    # would only build a detached Figure object that imshow never draws on.
    plt.figure()
    plt.imshow(weight1[filter_idx])
    plt.show()

In [None]:
import seaborn as sns

# Seaborn's "vlag" palette returned as a continuous colormap object
# (as_cmap=True) rather than a list of colors. It is not used by the
# visible cells yet; presumably intended for plotting signed weights.
palette = sns.color_palette(palette="vlag", as_cmap=True)

In [None]:
import torch
classifier.eval()

# Display the first training sample. The tensor is channels-first
# (C, H, W), as implied by the min/max over dims (1, 2) and the permute below.
for img, groundtruth in image_data_module.train_dataset:
    print(img.shape)
    # Stretch each channel independently to [0, 1] so the (normalized)
    # tensor is displayable. clamp_min prevents 0/0 -> NaN when a channel is
    # constant.
    channel_min = torch.amin(img, dim=(1, 2), keepdim=True)
    rescaled_img = img - channel_min
    channel_range = torch.amax(rescaled_img, dim=(1, 2), keepdim=True)
    rescaled_img = rescaled_img / channel_range.clamp_min(1e-8)
    print(rescaled_img.min(), rescaled_img.max())
    # plt.figure() (not the plt.Figure class) opens a new pyplot figure.
    plt.figure()
    plt.imshow(torch.permute(rescaled_img, (1, 2, 0)))
    plt.title(f'groundtruth: {groundtruth}')
    plt.show()
    break

In [None]:
# Inspect the attributes available on the MobileNet preprocessing transform
# (dir() already returns names alphabetically, so sorted() keeps the output).
sorted(dir(preprocess))

In [None]:
# Resize target of the preprocessing transform; its first element is what
# preprocess_kwargs['resize_size'] uses.
getattr(preprocess, 'resize_size')

In [None]:
import numpy as np
import torch
from PIL import Image

# Synthetic 232x232 RGB probe image: white interior with a 3-pixel black
# border, used to see what the preprocessing transform does to an image.
a = np.zeros((232, 232, 3), dtype=np.uint8)
a[3:-3, 3:-3] = 255
processed_a = preprocess(Image.fromarray(a))

# Undo the normalization on the *processed* tensor (x * std + mean). The
# previous version applied it to the raw uint8 array `a`, giving values up to
# ~59 that imshow merely clipped. mean/std come from the transform itself
# (as in preprocess_kwargs) instead of being hard-coded.
# permute assumes the transform returns a channels-first (C, H, W) tensor.
norm_mean = torch.as_tensor(preprocess.mean).view(-1, 1, 1)
norm_std = torch.as_tensor(preprocess.std).view(-1, 1, 1)
renormalized_a = (processed_a * norm_std + norm_mean).permute(1, 2, 0)
# Clamp away tiny float round-off outside [0, 1] before display.
renormalized_a = renormalized_a.clamp(0.0, 1.0)

plt.figure()
plt.imshow(renormalized_a)
plt.show()