In [1]:
import ast
import torch
import scienceplots
import numpy as np
import pandas as pd
import seaborn as sns
from tqdm import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Slowfast imports
from slowfast.models import build_model
from slowfast.utils.parser import load_config, alt_parse_args
from slowfast.datasets.loader import construct_loader

from temporal_masking_utils import (
    get_feature_map,
    get_weighted_features,
    extract_framewise_features,
    calculate_masking_results,
)
from torchmetrics.functional.classification import multilabel_f1_score
from torchmetrics.functional.classification import multilabel_average_precision
from torchmetrics.functional.classification import multilabel_recall
from torchmetrics.functional.classification import multilabel_precision


# Global matplotlib look-and-feel for every figure in this notebook.
# NOTE(review): the "science" style is presumably registered by the
# `scienceplots` import at the top of the notebook — confirm.
plt.style.use("science")
plt.rcParams["font.family"] = "Times New Roman"



In [2]:
# Locations of the per-split annotation CSVs and of the metadata table.
train_path = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/annotations/standard/fg_only/standard/train.csv"
val_path = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/annotations/standard/fg_only/standard/val.csv"
metadata_path = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/metadata/with_negative_pairing/new_metadata.csv"


def _load_split(csv_path, fg_values):
    """Read one annotation CSV and attach the metadata ``value`` column.

    The CSV's columns are renamed to ``subject_id`` / ``label`` and the frame
    is inner-joined (pandas' default) with ``fg_values`` on
    ``subject_id == subject_id_fg``.
    """
    annotations = pd.read_csv(csv_path)
    annotations.columns = ["subject_id", "label"]
    return annotations.merge(
        fg_values, left_on="subject_id", right_on="subject_id_fg"
    )


metadata = pd.read_csv(metadata_path)

# One stripped, UTF-8 decoded entry per line of each text file.
with open(
    "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/metadata/behaviours.txt",
    "rb",
) as f:
    behaviours = [raw_line.decode("utf-8").strip() for raw_line in f]

with open(
    "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/metadata/segments.txt",
    "rb",
) as f:
    segments = [raw_line.decode("utf-8").strip() for raw_line in f]

# Only these two metadata columns take part in the join.
fg_values = metadata[["subject_id_fg", "value"]]
train_df = _load_split(train_path, fg_values)
val_df = _load_split(val_path, fg_values)

In [3]:
# Load the model
# Experiment config and the checkpoint whose weights are restored below.
path_to_config = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/configs/SLOW_8x8_R50_Local_TEST.yaml"
path_to_ckpt = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/checkpoint_epoch_00100.pyth"

# Drop the last item returned by the parser; only the first remaining one is
# handed to load_config together with the explicit config path.
args = alt_parse_args()[:-1]
cfg = load_config(args[0], path_to_config=path_to_config)

# Build the network, then restore the saved weights (deserialised on the CPU).
model = build_model(cfg)
checkpoint = torch.load(path_to_ckpt, map_location="cpu")
model.load_state_dict(checkpoint["model_state"])

# Inference mode, parameters on the CPU.
model.eval()
model.cpu()

# Projection layer of the model head; applied directly to features later on.
classifier = model.head.projection

In [4]:
# Load the data
# Builds the loader from the config loaded above; "test" is passed as the
# second argument (presumably the split/mode name — confirm against
# slowfast.datasets.loader.construct_loader).
# Alternative kept for reference: dataset = build_dataset("tap", cfg, "train")
loader = construct_loader(cfg, "test")

SAMPLING.BALANCED: False; BALANCE_TYPE: None


In [None]:
# Per-batch accumulators: video names, feature maps and labels.
names = []
feat_maps = []
labels = []

for i, (inputs, label, index, time, meta) in tqdm(enumerate(loader)):
    # Feature map of the first element of `inputs`, with its singleton
    # dimensions squeezed away before it reaches the model.
    clip = inputs[0].squeeze()
    feature_map = get_feature_map(model, clip)

    # Keep everything on the CPU and detached from the autograd graph so the
    # accumulated list does not hold on to device memory.
    names.append(meta["video_name"])
    feat_maps.append(feature_map.detach().cpu())
    labels.append(label)

    # Clear torch cache
    torch.cuda.empty_cache()

In [24]:
# Each collected entry is indexable; keep only its first element.
# NOTE: this cell rebinds its own inputs, so it must be run exactly once
# after the extraction loop.
names = [batch_names[0] for batch_names in names]

# Join the per-batch tensors along the first dimension (torch.cat's default).
feat_maps = torch.cat(feat_maps)
labels = torch.cat(labels)

In [25]:
# Save the feature maps to pickle
# Names, feature maps and labels go into one dictionary serialised with
# torch.save.
feature_dump_path = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/results/r50_e100_fg_few_shot_train.pkl"

with open(feature_dump_path, "wb") as f:
    torch.save({"names": names, "feat_maps": feat_maps, "labels": labels}, f)

In [None]:
# Thresholds forwarded to get_weighted_features; they are also embedded in the
# names of the chunk files written below.
weight_thresh = 0.5
mask_thresh = 0.5


def _save_chunk(iter_idx, fg_names, bg_names, preds, weighted_preds, labels, masks):
    """Stack the buffered per-sample results and write them to one chunk file.

    Args:
        iter_idx: loader index of the last sample in the chunk; it becomes the
            ``iter=`` part of the file name.
        fg_names, bg_names: lists of video names.
        preds, weighted_preds, labels: lists of tensors, stacked with
            ``torch.stack``.
        masks: list of masks, converted with ``np.array``.
    """
    results = {
        "fg_names": np.array(fg_names),
        "bg_names": np.array(bg_names),
        "preds": torch.stack(preds),
        "weighted_preds": torch.stack(weighted_preds),
        "labels": torch.stack(labels),
        "masks": np.array(masks),
    }
    torch.save(
        results,
        f"weighted_features_w-mask_weight_t={weight_thresh}_mask_thresh={mask_thresh}_iter={iter_idx}_split=val.pt",
    )


fg_names, bg_names, preds, weighted_preds, labels, masks = [], [], [], [], [], []

for i, (inputs, label, index, time, meta) in tqdm(enumerate(loader)):
    # Track names
    fg_name = meta["fg_video_name"][0]
    bg_name = meta["bg_video_name"][0]
    fg_names.append(fg_name)
    bg_names.append(bg_name)

    # Extract feature map
    fg_map = get_feature_map(model, inputs["fg_frames"][0][0])
    bg_map = get_feature_map(model, inputs["bg_frames"][0][0])

    # Get frame-wise features
    fg_framewise_features = extract_framewise_features(fg_map)

    # Get frame-wise classifier scores (sigmoid of the classifier output).
    # NOTE(review): nothing in this cell reads fg_framewise_logits; it is kept
    # only in case a later cell inspects the value from the last iteration.
    fg_framewise_logits = []
    for feat in fg_framewise_features:
        fg_framewise_logits.append(torch.sigmoid(classifier(feat).detach()).numpy())
    fg_framewise_logits = np.array(fg_framewise_logits)

    # Extract weighted features
    weighted_features, non_weighted_features, mask = get_weighted_features(
        fg_map,
        bg_map,
        classifier,
        weight_thresh=weight_thresh,
        mask_thresh=mask_thresh,
        weight_features_by_mask=True,
        return_mask=True,
    )
    # Store the mask
    masks.append(mask)

    # Apply the classifier
    pred = torch.sigmoid(classifier(non_weighted_features).squeeze(0))
    weighted_pred = torch.sigmoid(classifier(weighted_features).squeeze(0))

    # Store the results
    preds.append(pred.cpu().detach())
    weighted_preds.append(weighted_pred.cpu().detach())
    labels.append(label.cpu().detach())

    # Periodic checkpoint. Because the test is `i % 100 == 0`, the first chunk
    # (iter=0) holds a single sample and every later chunk holds 100.
    if i % 100 == 0:
        print(f"Processed {i} samples")

        _save_chunk(i, fg_names, bg_names, preds, weighted_preds, labels, masks)

        # Reset the lists
        fg_names, bg_names, preds, weighted_preds, labels, masks = (
            [],
            [],
            [],
            [],
            [],
            [],
        )

        # Free up memory
        torch.cuda.empty_cache()

# Flush the samples processed after the last periodic checkpoint; without this
# they would never reach disk. The file is named after the final loader index,
# so any code that merges chunks by fixed multiples of 100 must also load it.
if fg_names:
    print(f"Processed {i} samples (final partial chunk)")
    _save_chunk(i, fg_names, bg_names, preds, weighted_preds, labels, masks)
    fg_names, bg_names, preds, weighted_preds, labels, masks = [], [], [], [], [], []
    torch.cuda.empty_cache()

In [11]:
# Merge the per-chunk result files into one file.
# Relies on `weight_thresh` and `mask_thresh` from the extraction cell so the
# file names match the ones that cell wrote.
fg_names, bg_names, preds, weighted_preds, labels, masks = [], [], [], [], [], []

# Chunk indices 0, 100, ..., 1000.
iters = np.arange(0, 1100, 100)

# The loop variable is deliberately not called `iter`: at notebook top level
# that would overwrite the builtin for the rest of the kernel session.
for chunk_iter in iters:
    # NOTE(review): the chunks contain numpy arrays; recent PyTorch releases
    # default torch.load to weights_only=True, which may refuse them — confirm
    # against the installed version.
    results = torch.load(
        f"weighted_features_w-mask_weight_t={weight_thresh}_mask_thresh={mask_thresh}_iter={chunk_iter}_split=val.pt"
    )

    fg_names.append(results["fg_names"])
    bg_names.append(results["bg_names"])
    preds.append(results["preds"])
    weighted_preds.append(results["weighted_preds"])
    labels.append(results["labels"])
    masks.append(results["masks"])

# Concatenate the results: numpy for the name/mask arrays, torch for tensors.
fg_names = np.concatenate(fg_names)
bg_names = np.concatenate(bg_names)
masks = np.concatenate(masks)
preds = torch.cat(preds)
weighted_preds = torch.cat(weighted_preds)
labels = torch.cat(labels)

# Save the results
results = {
    "fg_names": fg_names,
    "bg_names": bg_names,
    "preds": preds,
    "weighted_preds": weighted_preds,
    "labels": labels,
    "masks": masks,
}

torch.save(
    results,
    f"weighted_features_w-mask_weight_t={weight_thresh}_mask_thresh={mask_thresh}_split=val.pt",
)

**Weighted results**

In [4]:
# Weighted results
# Per-label recall / F1 / precision for the plain predictions ("preds") and
# the mask-weighted predictions ("weighted_preds") stored in the merged file.
weighted_result = torch.load(
    "weighted_features_w-mask_weight_t=0.5_mask_thresh=0.5_split=val.pt"
)

# Ground truth shared by every metric. The cast to integer is done once so
# that all six metric calls receive identical targets.
weighted_targets = weighted_result["labels"].squeeze(1).long()


def _per_label_metric(metric_fn, scores):
    """Evaluate ``metric_fn`` for each of the 14 labels, without averaging."""
    return metric_fn(scores, weighted_targets, average="none", num_labels=14)


recall = _per_label_metric(multilabel_recall, weighted_result["preds"])
weighted_recall = _per_label_metric(
    multilabel_recall, weighted_result["weighted_preds"]
)

f1 = _per_label_metric(multilabel_f1_score, weighted_result["preds"])
weighted_f1 = _per_label_metric(
    multilabel_f1_score, weighted_result["weighted_preds"]
)

# NOTE: despite their names, these two hold multilabel_precision, not
# multilabel_average_precision. The names are kept unchanged because later
# cells may reference them.
average_precision = _per_label_metric(
    multilabel_precision, weighted_result["preds"]
)
weighted_average_precision = _per_label_metric(
    multilabel_precision, weighted_result["weighted_preds"]
)


# One row per behaviour; the metric columns start at position 2.
weighted_df = pd.DataFrame(
    {
        "behaviour": behaviours,
        "segment": segments,
        "f1": f1.cpu().numpy(),
        "weighted_f1": weighted_f1.cpu().numpy(),
        "recall": recall.cpu().numpy(),
        "weighted_recall": weighted_recall.cpu().numpy(),
        "precision": average_precision.cpu().numpy(),
        "weighted_precision": weighted_average_precision.cpu().numpy(),
    }
)

# Mean of every metric column, overall and per segment. Keys are the output
# row names; values are the entries of the `segment` column to filter on.
metric_columns = weighted_df.columns[2:]
segment_rows = {"head": "head", "tail": "tail", "fewshot": "few_shot"}

weighted_avg_df = pd.DataFrame(
    {
        "overall": weighted_df[metric_columns].mean(),
        **{
            row_name: weighted_df.loc[
                weighted_df.segment == segment_value, metric_columns
            ].mean()
            for row_name, segment_value in segment_rows.items()
        },
    }
).T

In [None]:
# Display the per-behaviour table: f1 / recall / precision for the plain and the weighted predictions.
weighted_df

In [None]:
# Display the metric means: overall and for the head, tail and few_shot segments.
weighted_avg_df