In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and source edits are picked up without a
# kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
# Plan: classify every image as fight / no-fight, then take the average.
# Variants to try:
## On ImageNet logits
## On ImageNet features
## With fine-tuning
#

In [None]:
from pathlib import Path
import pandas as pd

from fight_classifier import PROJECT_DIR, DATASET_DIR
# Directory holding the raw frames, and the CSV file inside it that is loaded
# below.
frames_dir = DATASET_DIR / 'raw_frames/'
frames_csv_path = frames_dir / 'frames.csv'

# Read the CSV into a DataFrame and show it as the cell output.
frames_df = pd.read_csv(str(frames_csv_path))
frames_df

In [None]:
from torchvision.models import mobilenet_v3_large, MobileNet_V3_Large_Weights
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

# MobileNetV3-Large built with torchvision's default weights and switched to
# evaluation mode (`.eval()` returns the module itself, so it can be chained).
base_model_weights = MobileNet_V3_Large_Weights.DEFAULT
base_model = mobilenet_v3_large(weights=base_model_weights).eval()

# Preprocessing transform shipped with those weights.
preprocess = base_model_weights.transforms()

In [None]:
from fight_classifier.data.image_dataset import (
    ImageDataset, ImageDataModule)
from torch.utils.data import DataLoader, Dataset


# Dataset over every row of `frames_df`: the image path is read from
# `frame_path`, the label from `is_fight`.
image_dataset = ImageDataset(
    preprocess=preprocess,
    groundtruth_col='is_fight',
    image_path_col='frame_path',
    image_df=frames_df,
)

# Shuffled loader over the whole dataset, 10 examples per batch.
image_dataloader = DataLoader(
    dataset=image_dataset,
    batch_size=10,
    shuffle=True,
)

# Data module built from the same frames with the same batch size.
# NOTE(review): `split_coherence_col` presumably keeps frames that share a
# `fine_category` in the same split -- confirm against `ImageDataModule`.
image_data_module = ImageDataModule(
    split_coherence_col='fine_category',
    preprocess=preprocess,
    batch_size=10,
    image_df=frames_df,
)

In [None]:
import pytorch_lightning as pl
from fight_classifier.model.image_based_model import (
    ProjFromFeatures, ImageClassifierModule)
# Classifier head, wrapped in the Lightning module that gets trained.
classifier = ProjFromFeatures()
classif_module = ImageClassifierModule(classifier=classifier)

# Trainer writing its outputs under the project directory. With an integer
# `val_check_interval`, Lightning runs validation every that many training
# batches.
trainer = pl.Trainer(
    val_check_interval=500,
    default_root_dir=str(PROJECT_DIR),
)

# Train using the train/validation loaders provided by the data module.
trainer.fit(classif_module, datamodule=image_data_module)

In [None]:
import matplotlib.pyplot as plt
import torch

# Sanity check: run the ImageNet classifier on a sparse sample of the dataset
# (every 200th example among the first 4000) and print its top-1 prediction.
max_index = 4000
stride = 200

# Index the dataset directly instead of enumerating it and skipping entries:
# only the sampled examples are loaded, rather than all of the first 4000.
sample_indices = range(0, min(max_index, len(image_dataset)), stride)

# NOTE(review): `image_dataset` was constructed with `preprocess=preprocess`;
# confirm that `example['image']` is still the raw image here, otherwise it is
# preprocessed twice.

# Pure inference: disable autograd so no graph is recorded per forward pass.
with torch.no_grad():
    for i in sample_indices:
        example = image_dataset[i]
        # Add a leading batch dimension of size 1 for the model.
        batch = preprocess(example['image']).unsqueeze(0)
        print(example['image'].size, '-->', batch.shape)
        # plt.Figure()
        # plt.imshow(example['image'])
        # plt.show()

        # Softmax over the class dimension, then report the arg-max class.
        prediction = base_model(batch).squeeze(0).softmax(0)
        class_id = prediction.argmax().item()
        score = prediction[class_id].item()
        category_name = base_model_weights.meta["categories"][class_id]
        print(f"{category_name}: {100 * score:.1f}%")

In [None]:
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

# List the graph node names of the base model; these are the names that can
# be requested as extraction points below.
nodes, _ = get_graph_node_names(base_model)
print(nodes)


# Truncated node names are allowed: `create_feature_extractor` picks the last
# node whose name starts with the given prefix.
# The extractor is built from `base_model`, the network loaded earlier in
# this notebook.
feature_extractor = create_feature_extractor(
    base_model,
    return_nodes=[
        'features.16',
        'flatten',
        'classifier.0',
        'classifier.1',
        'classifier.2',
        'classifier.3',
    ])

# `out` will be a dict of Tensors, one entry per requested node. The forward
# pass on a dummy all-zeros input needs no gradients.
with torch.no_grad():
    out = feature_extractor(torch.zeros(1, 3, 32, 32))

In [None]:
# Display the frames DataFrame as the cell output.
frames_df

In [None]:
from torch import Tensor

In [None]:
from fight_classifier import PROJECT_DIR
# Print the value of `PROJECT_DIR` imported from `fight_classifier`.
print(PROJECT_DIR)

In [None]:
# Inspect the type of the preprocessing transform object.
type(preprocess)