In [8]:
import cv2
import ast
import mmcv
import math
import torch
import torch.nn.functional as F
import torchvision
import pickle as pkl
import pandas as pd
import numpy as np
import seaborn as sns
from einops import rearrange
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

from torchmetrics.functional.classification import (
    multilabel_f1_score,
    multilabel_precision,
    multilabel_recall,
)

from data_utils import results2df
from sklearn.metrics.pairwise import cosine_similarity

# Slowfast imports
from slowfast.models import build_model
from slowfast.utils.parser import load_config, alt_parse_args
from slowfast.datasets.loader import construct_loader

In [None]:
# Seed PyTorch's global RNG so the randomly initialised layers created later in
# this notebook (the ad-hoc Linear and MultiheadAttention modules) are repeatable.
torch.manual_seed(0)

**Generating few-shot annotation file**

In [None]:
# Locations of the foreground-only annotation splits and the clip metadata.
train_path = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/annotations/standard/fg_only/standard/train.csv"
val_path = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/annotations/standard/fg_only/standard/val.csv"
metadata_path = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/metadata/with_negative_pairing/new_metadata.csv"

# NOTE(review): read_csv treats the first row as a header by default and the
# column names are then overwritten. If these CSVs are written without a header
# (as the to_csv calls in this notebook do), the first sample is silently
# dropped — confirm and switch to header=None/names=... if so.
train_df = pd.read_csv(train_path)
train_df.columns = ["subject_id", "label"]

val_df = pd.read_csv(val_path)
val_df.columns = ["subject_id", "label"]

metadata = pd.read_csv(metadata_path)

# One behaviour name per line; read as bytes and decoded explicitly as UTF-8.
with open(
    "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/metadata/behaviours.txt",
    "rb",
) as f:
    behaviours = [raw_line.decode("utf-8").strip() for raw_line in f.readlines()]

# Attach the metadata "value" column to each split by matching the annotation's
# subject_id against the metadata's subject_id_fg (default inner join).
train_df = train_df.merge(
    metadata[["subject_id_fg", "value"]],
    left_on="subject_id",
    right_on="subject_id_fg",
)
val_df = val_df.merge(
    metadata[["subject_id_fg", "value"]],
    left_on="subject_id",
    right_on="subject_id_fg",
)


def is_fs(x, fs_behaviours=("aggression",)):
    """Return True if `x` names at least one few-shot behaviour.

    Args:
        x: Comma-separated behaviour string (e.g. "feeding,aggression").
            Non-string values such as NaN or None are treated as "no
            behaviours" instead of raising on `.split`.
        fs_behaviours: Behaviour names that count as few-shot. Defaults to
            ("aggression",), the value previously hard-coded here.

    Returns:
        bool: True when any entry of `fs_behaviours` exactly equals one of the
        comma-separated tokens of `x`; False otherwise.
    """
    if not isinstance(x, str):
        return False
    # An exact whole-string match is just the single-token case of the split,
    # so one membership test covers both of the original loops.
    tokens = x.split(",")
    return any(behaviour in tokens for behaviour in fs_behaviours)


# Training rows whose behaviour list contains a few-shot behaviour AND whose
# subject id starts with "acp".
# NOTE(review): this expression is not the last statement of the cell, so its
# value is computed and discarded rather than displayed.
train_df[
    (train_df["value"].apply(is_fs)) & (train_df.subject_id.str.startswith("acp"))
][["subject_id", "label", "value"]]

# Validation rows containing a few-shot behaviour (no subject-id prefix filter).
# As the last expression of the cell, this is what gets displayed.
val_df[(val_df["value"].apply(is_fs))][["subject_id", "label", "value"]]

# Disabled export of the selection above to a few-shot annotation file:
# .to_csv(
#     "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/annotations/standard/fg_only/few_shot/aggression/val_aggression.csv",
#     index=False,
#     header=False,
# )

In [24]:
# Validation annotations that pair each foreground clip with a background clip.
val_df_neg = pd.read_csv(
    "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/annotations/standard/with_negative_pairing/standard/val.csv"
)
val_df_neg.columns = ["fg", "bg", "label", "negative", "utm_code"]

# Bring in the behaviour string for the foreground clip, then keep only the
# columns needed below.
val_df_neg = val_df_neg.merge(
    metadata[["subject_id_fg", "value"]],
    left_on="fg",
    right_on="subject_id_fg",
)
val_df_neg = val_df_neg[["fg", "bg", "label", "value"]]

# Subset whose behaviour string contains a few-shot behaviour.
is_few_shot_row = val_df_neg["value"].apply(is_fs)
fs_val_df_neg = val_df_neg[is_few_shot_row]

In [28]:
# Label column of the few-shot validation pairs, kept under its own name.
fs_labels = fs_val_df_neg.label

In [22]:
# Write the background clip id and label of every few-shot validation pair to
# the few-shot "background" annotation file (no index column, no header row).
fs_val_df_neg[["bg", "label"]].to_csv(
    "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/annotations/standard/fg_only/few_shot/background/val.csv",
    index=False,
    header=False,
)

**Temporal activation maps**

In [None]:
# Load the model
path_to_config = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/configs/SLOW_8x8_R50_Local.yaml"
path_to_ckpt = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/checkpoint_epoch_00100.pyth"

# alt_parse_args() is a project helper: all but its last element are kept, and
# the first element is passed to load_config together with the YAML path.
args = alt_parse_args()[:-1]
cfg = load_config(
    args[0],
    path_to_config=path_to_config,
)
# NOTE(review): torch.load unpickles the file — only load checkpoints that come
# from a trusted source.
checkpoint = torch.load(path_to_ckpt, map_location="cpu")

# Build the network from the config, restore the weights stored under
# "model_state", then switch to eval mode on the CPU for inference.
model = build_model(cfg)
model.load_state_dict(checkpoint["model_state"])
model.eval()
model.cpu()

# Projection layer of the model head; reused below to score per-frame and
# pooled features directly.
classifier = model.head.projection

# Load the data
loader = construct_loader(cfg, "test")  # dataset = build_dataset("tap", cfg, "train")
# Keep the first batch around for the single-sample experiments below.
inputs, labels, index, time, meta = next(iter(loader))

**Helper functions**

In [13]:
def get_feature_map(model, sample):
    """Run `sample` through stages s1..s5 of `model` and return the first output.

    The sample gets a leading axis via `unsqueeze(0)` and is wrapped in a
    one-element list before being fed to `model.s1`; each stage's output is
    passed to the next. Everything runs under `torch.no_grad()`.

    Returns:
        Element 0 of whatever `model.s5` returns.
    """
    stages = (model.s1, model.s2, model.s3, model.s4, model.s5)
    with torch.no_grad():
        activations = [sample.unsqueeze(0)]
        for stage in stages:
            activations = stage(activations)
        # TODO: Investigate features at earlier layers
        feature_map = activations[0]
    return feature_map


def extract_frame_wise_features(feature_map, t):
    """Average-pool `feature_map` to `t` temporal steps and drop the spatial axes.

    Adaptive 3D average pooling reduces the last three axes to (t, 1, 1); the
    result is then flattened from axis 2 onwards, so those three axes collapse
    into a single axis of length `t`.
    """
    pooled = F.adaptive_avg_pool3d(feature_map, (t, 1, 1))
    return pooled.flatten(start_dim=2)

In [29]:
from tqdm import tqdm

In [None]:
# Collect, for every batch of the loader, its video name(s) and the feature map
# of element [0][0] of the batch inputs.
# NOTE(review): only index [0][0] of each batch is processed while the whole
# `video_name` entry is stored — confirm the loader's batch size is 1.
name, feat_map = [], []
# tqdm wraps the loader directly so the progress bar can read its length. The
# loop variables are deliberately not named `input`/`label`/`time`/`meta`, so
# the builtin `input` is not shadowed and the first-batch variables loaded
# earlier are left intact.
for clip_inputs, _clip_label, _clip_idx, _clip_time, clip_meta in tqdm(loader):
    clip_feature_map = get_feature_map(model, clip_inputs[0][0])
    name.append(clip_meta["video_name"])
    feat_map.append(clip_feature_map)

In [19]:
# Pick the clip to analyse from the first batch loaded above (inputs[0] is
# indexed per sample), together with its label.
sample_index = 0
label = labels[sample_index]
sample = inputs[0][sample_index]
# Last clip of the same batch, used below as the "negative" sample.
negative_sample = inputs[0][-1]

# Disabled experiment: crop the last axis of the sample to 0:224.
# sample = sample[
#     :,
#     :,
#     0:224,
# ] # TODO: Investigate timestamp

In [None]:
# Feature maps for the selected sample and the negative sample.
feature_map = get_feature_map(model, sample)
negative_feature_map = get_feature_map(model, negative_sample)
print(f"Feature map shape: {feature_map.shape}")
# Shape after pooling the last three axes down to a single (1, 1, 1) cell.
print(f"Video-level map shape: {F.adaptive_avg_pool3d(feature_map, (1, 1, 1)).shape}")

In [9]:
# Pool both feature maps to 16 temporal steps with no spatial extent.
framewise_features = extract_frame_wise_features(feature_map, t=16)
negative_framewise_features = extract_frame_wise_features(negative_feature_map, t=16)

In [None]:
# Frame-by-frame cosine similarity of the sample with itself.
# NOTE(review): this cell rebinds `framewise_features`, so re-running it
# transposes the matrix again — run it exactly once after the pooling cell.
normalise = True
framewise_features = framewise_features.squeeze(0).T  # rows become frames
if normalise:
    # Unit L2 norm per row, so the matrix product below yields cosines.
    framewise_features = F.normalize(framewise_features, p=2, dim=1)
print(framewise_features.shape, framewise_features.T.shape)
framewise_cosine_similarity = framewise_features.mm(framewise_features.T)

In [12]:
# Compute frame-wise cosine similarity between the sample and the negative sample.
# The negative features get the same treatment as the sample features above:
# drop the leading axis, transpose so rows are frames, L2-normalise each row.
negative_framewise_features = negative_framewise_features.squeeze(0)
negative_framewise_features = negative_framewise_features.T

if normalise:
    negative_framewise_features = F.normalize(negative_framewise_features, p=2, dim=1)

# Both operands are (frames, features), so the second one must be transposed
# for the product to be defined; the result is (sample frames, negative frames).
# Multiplying the two untransposed matrices raises a shape-mismatch error.
negative_framewise_cosine_similarity = torch.mm(
    framewise_features, negative_framewise_features.T
)

In [None]:
# Project the per-frame features to 128 dimensions with a freshly initialised
# (untrained) linear layer, then repeat the cosine-similarity computation.
linear_projection = torch.nn.Linear(framewise_features.shape[1], 128)

reduced_framewise_features = linear_projection(framewise_features)
if normalise:
    reduced_framewise_features = F.normalize(reduced_framewise_features, p=2, dim=1)

reduced_framewise_cosine_similarity = reduced_framewise_features.mm(
    reduced_framewise_features.T
)

# Heatmap of the similarity matrix in the reduced space.
plt.figure(figsize=(10, 10))
similarity_values = reduced_framewise_cosine_similarity.detach().numpy()
sns.heatmap(
    similarity_values,
    xticklabels=False,
    yticklabels=False,
    cmap="viridis",
    annot=True,
)

In [13]:
# Score every frame with the classifier head. The stored values are the
# sigmoid of the classifier output (despite the variable's "logits" name).
frame_wise_logits = np.array(
    [torch.sigmoid(classifier(frame).detach()).numpy() for frame in framewise_features]
)

In [14]:
# Video-level features: average over the last three axes of the feature map,
# then flatten everything after the first axis.
video_level_features = F.adaptive_avg_pool3d(feature_map, (1, 1, 1)).flatten(
    start_dim=1
)

# Video-level scores: sigmoid of the classifier output, as a NumPy array.
video_level_logits = np.array(
    torch.sigmoid(classifier(video_level_features)).detach().numpy()
)

In [None]:
# Heatmap of the frame-by-frame cosine similarity matrix.
fig, ax = plt.subplots(figsize=(12, 12))
sns.heatmap(
    framewise_cosine_similarity.detach().numpy(), cmap="viridis", annot=True, ax=ax
)
ax.set_title("Frame-wise cosine similarity")
plt.show()

In [None]:
# Heatmap of per-frame scores: one row per behaviour, one column per frame.
fig, ax = plt.subplots(figsize=(12, 8), dpi=100)
sns.heatmap(
    frame_wise_logits.T, cmap="viridis", annot=True, yticklabels=behaviours, ax=ax
)
ax.set_title("Frame-wise logits")
plt.show()

In [None]:
# Three panels side by side: per-frame scores, video-level scores, label.
fig, ax = plt.subplots(
    1, 3, figsize=(20, 10), gridspec_kw={"width_ratios": [3, 0.75, 0.75]}
)
frame_ax, video_ax, label_ax = ax

# Wide panel carries the behaviour names and the colour bar.
sns.heatmap(
    frame_wise_logits.T,
    cmap="viridis",
    annot=True,
    yticklabels=behaviours,
    ax=frame_ax,
)
# The two narrow panels share the same settings and omit the colour bar.
narrow_panels = (
    (video_ax, video_level_logits.T),
    (label_ax, label.unsqueeze(0).numpy().T),
)
for panel_ax, panel_values in narrow_panels:
    sns.heatmap(panel_values, cmap="viridis", annot=True, ax=panel_ax, cbar=False)

frame_ax.title.set_text("Masked Frame-wise Logits")
video_ax.title.set_text("Video-level Logits")
label_ax.title.set_text("Label")

plt.show()

In [17]:
# Number of leading frames that make up the "sub-video".
num_subframes = 8
subframe_features = framewise_features[:num_subframes]

# Subframe logits: sigmoid of the classifier output for each selected frame.
subframe_logits = np.array(
    [torch.sigmoid(classifier(frame).detach()).numpy() for frame in subframe_features]
)

# Adaptive pool over subframe features: average across frames while keeping the
# feature axis at full width. The width is read from the tensor rather than
# hard-coded, so the cell also works for backbones whose feature size is not 2048.
feature_dim = subframe_features.shape[-1]
subframe_features = subframe_features.unsqueeze(0)
subframe_features = F.adaptive_avg_pool2d(subframe_features, (1, feature_dim))[0]

# Apply classifier to the frame-averaged sub-video features.
subvideo_logits = torch.sigmoid(classifier(subframe_features)).detach().numpy()

In [None]:
# Four panels side by side: sub-video per-frame scores, sub-video pooled
# scores, full-video scores, label.
fig, ax = plt.subplots(
    1, 4, figsize=(20, 10), gridspec_kw={"width_ratios": [3, 0.75, 0.75, 0.75]}
)
subframe_ax, subvideo_ax, video_ax, label_ax = ax

# Wide panel carries the behaviour names and the colour bar.
sns.heatmap(
    subframe_logits.T,
    cmap="viridis",
    annot=True,
    yticklabels=behaviours,
    ax=subframe_ax,
)
# The narrow panels share the same settings and omit the colour bar.
narrow_panels = (
    (subvideo_ax, subvideo_logits.T),
    (video_ax, video_level_logits.T),
    (label_ax, label.unsqueeze(0).numpy().T),
)
for panel_ax, panel_values in narrow_panels:
    sns.heatmap(panel_values, cmap="viridis", annot=True, ax=panel_ax, cbar=False)

subframe_ax.title.set_text("Masked Frame-wise Logits")
subvideo_ax.title.set_text("Sub-video Logits")
video_ax.title.set_text("Video-level Logits")
label_ax.title.set_text("Label")

plt.show()

In [17]:
# Swap the first two axes of the clip (pattern "c t w h -> t c w h") so the
# leading axis indexes the images handed to make_grid below.
# NOTE(review): this rebinds `sample`, so the cell is not idempotent —
# re-running it swaps the axes back, and any later call that expects the
# original layout (e.g. get_feature_map) would receive the swapped one.
sample = rearrange(sample, "c t w h -> t c w h")


def plot_video(clip, nrow=8):
    """Tile `clip` into an image grid with `nrow` images per row and show it.

    The grid is built with `make_grid`, converted to a PIL image, and displayed
    through that image's `show()` method.
    """
    to_pil = torchvision.transforms.ToPILImage()
    to_pil(make_grid(clip, nrow=nrow)).show()

In [18]:
# Show the rearranged clip as a grid of frames.
plot_video(sample)

In [20]:
# Single-head attention layer with freshly initialised (untrained) weights.
multi_head_attention = torch.nn.MultiheadAttention(
    embed_dim=2048, num_heads=1, batch_first=True
)

# Keys/values: every frame of the sample. Query: the single frame at index 14.
# Both get a leading batch axis because the layer is batch-first.
kv = framewise_features.unsqueeze(0)
q = framewise_features[14:15, :].unsqueeze(0)

output, attention_weights = multi_head_attention(q, kv, kv)

In [None]:
# Sanity check: shapes of the query and of the key/value tensors.
q.shape, kv.shape

In [None]:
# Raw attention weights returned by the layer for the single query.
attention_weights

In [None]:
# Attention weights as a column: one row per sequence position, one column for
# the single query.
fig, ax = plt.subplots(figsize=(6, 4))
attention_column = attention_weights.squeeze(0).T.detach().numpy()
sns.heatmap(
    attention_column,
    cmap="viridis",
    annot=True,
    fmt=".2f",
    cbar=True,
    ax=ax,
)
ax.set_title("Attention Weights Visualization")
ax.set_xlabel("Query")
ax.set_ylabel("Sequence Position")
plt.show()

In [None]:
# In-place squeeze of axis 0 on both tensors (a no-op unless that axis has
# size 1), then display the resulting shapes.
# NOTE(review): squeeze_ mutates the tensors themselves, not just this cell's view.
negative_framewise_features.squeeze_(0).shape, framewise_features.squeeze_(0).shape

In [None]:
# Shapes of the negative and sample frame-wise feature matrices.
negative_framewise_features.shape, framewise_features.shape

In [23]:
# NOTE(review): transposes the negative features again. An earlier cell already
# transposed them once, so after this line their layout no longer matches
# `framewise_features`; re-running the cell flips it back. Confirm this is intended.
negative_framewise_features = negative_framewise_features.T

In [None]:
# Cosine sim for temporal masking: compare one reference frame (index 12 of the
# sample) against every frame of the sample.
background_embedding = framewise_features[
    12:13, :
]  # negative_framewise_features[10:11:, :]

cosine_sim = cosine_similarity(
    background_embedding.detach().numpy(), framewise_features.detach().numpy()
)

# Normalize cosine similarity between 0 and 1 (min-max). If every value is
# identical the range is zero; fall back to all zeros instead of dividing by
# zero and filling the heatmap with NaNs.
sim_min = np.min(cosine_sim)
sim_range = np.max(cosine_sim) - sim_min
if sim_range > 0:
    cosine_sim = (cosine_sim - sim_min) / sim_range
else:
    cosine_sim = np.zeros_like(cosine_sim)

# cosine_sim has a single row (the query frame) and one column per frame of the
# sample, so sequence position runs along x and the query along y.
plt.figure(figsize=(12, 2))
sns.heatmap(cosine_sim, cmap="viridis", annot=True, fmt=".2f")
plt.title("Cosine Similarity with Temporal Masking")
plt.xlabel("Sequence Position")
plt.ylabel("Query")
plt.show()

In [None]:
# Shapes of the two arrays passed to cosine_similarity above.
background_embedding.detach().numpy().shape, framewise_features.detach().numpy().shape

In [None]:
# Per-frame scores computed earlier (sigmoid of the classifier output).
frame_wise_logits

In [14]:
# Saved few-shot results for the foreground and the background clips; the
# cells below read their `feat_map` and `name` columns.
# NOTE(review): read_pickle executes arbitrary code embedded in the file —
# only load pickles produced by a trusted run.
fg_fewshot = pd.read_pickle(
    "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/results/r50_e100_fg_few_shot.pkl"
)
bg_fewshot = pd.read_pickle(
    "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/dataset/results/r50_e100_bg_few_shot.pkl"
)

In [15]:
# First foreground/background feature-map pair of the saved results.
fg_sample = fg_fewshot.feat_map.iloc[0]
bg_sample = bg_fewshot.feat_map.iloc[0]

# Frame-wise features for both clips: pool to 16 temporal steps, drop the
# leading axis and transpose so that rows are frames.
fg_framewise_features = extract_frame_wise_features(fg_sample, t=16).squeeze(0).T
bg_framewise_features = extract_frame_wise_features(bg_sample, t=16).squeeze(0).T

# Frame-wise scores: sigmoid of the classifier output for each frame.
fg_frame_wise_logits = np.array(
    [
        torch.sigmoid(classifier(frame).detach()).numpy()
        for frame in fg_framewise_features.squeeze(0)
    ]
)
bg_frame_wise_logits = np.array(
    [
        torch.sigmoid(classifier(frame).detach()).numpy()
        for frame in bg_framewise_features.squeeze(0)
    ]
)

# Video-level scores: average the whole feature map, flatten, classify, sigmoid.
fg_video_level_features = torch.flatten(
    F.adaptive_avg_pool3d(fg_sample, (1, 1, 1)), start_dim=1
)
fg_video_level_logits = np.array(
    torch.sigmoid(classifier(fg_video_level_features)).detach().numpy()
)

bg_video_level_features = torch.flatten(
    F.adaptive_avg_pool3d(bg_sample, (1, 1, 1)), start_dim=1
)
bg_video_level_logits = np.array(
    torch.sigmoid(classifier(bg_video_level_features)).detach().numpy()
)

In [None]:
# Per-frame scores of the foreground clip (left) and the background clip (right).
fig, ax = plt.subplots(1, 2, figsize=(20, 10), gridspec_kw={"width_ratios": [1, 1]})
fg_ax, bg_ax = ax

for panel_ax, panel_values in ((fg_ax, fg_frame_wise_logits.T), (bg_ax, bg_frame_wise_logits.T)):
    sns.heatmap(
        panel_values,
        cmap="viridis",
        annot=True,
        yticklabels=behaviours,
        ax=panel_ax,
    )

fg_ax.title.set_text(f"FG Frame-wise Logits: {fg_fewshot.name.iloc[0][0]}")
bg_ax.title.set_text(f"BG Frame-wise Logits: {bg_fewshot.name.iloc[0][0]}")

In [None]:
# Video-level scores of the foreground clip (left, with behaviour names) and
# the background clip (right).
fig, ax = plt.subplots(1, 2, figsize=(10, 10), gridspec_kw={"width_ratios": [1, 1]})
fg_ax, bg_ax = ax

sns.heatmap(
    fg_video_level_logits.T,
    cmap="viridis",
    annot=True,
    ax=fg_ax,
    cbar=False,
    yticklabels=behaviours,
)
sns.heatmap(
    bg_video_level_logits.T,
    cmap="viridis",
    annot=True,
    ax=bg_ax,
    cbar=False,
)

fg_ax.title.set_text(f"FG Video-level Logits: {fg_fewshot.name.iloc[0][0]}")
bg_ax.title.set_text(f"BG Video-level Logits: {bg_fewshot.name.iloc[0][0]}")

In [18]:
def _frame_wise_feature_rows(feat_map, t=16):
    """Pool `feat_map` to `t` temporal steps and return one feature row per step."""
    return extract_frame_wise_features(feat_map, t=t).squeeze(0).T


def _frame_wise_scores(frame_rows):
    """Sigmoid of the classifier output for every row, stacked into one array."""
    return np.array(
        [torch.sigmoid(classifier(row).detach()).numpy() for row in frame_rows]
    )


def _video_level_scores(feat_map):
    """Sigmoid of the classifier output for the fully average-pooled feature map."""
    pooled = torch.flatten(F.adaptive_avg_pool3d(feat_map, (1, 1, 1)), start_dim=1)
    return np.array(torch.sigmoid(classifier(pooled)).detach().numpy())


def func(fg_sample, bg_sample, feats=False, logits=False):
    """Compute frame-wise features and classifier scores for an fg/bg pair.

    Relies on the notebook-level `classifier` and `extract_frame_wise_features`.

    Args:
        fg_sample: Feature map of the foreground clip.
        bg_sample: Feature map of the background clip.
        feats: If True, return only the frame-wise features. Takes precedence
            over `logits`.
        logits: If True (and `feats` is False), return only the scores.

    Returns:
        feats=True:  (fg_framewise_features, bg_framewise_features)
        logits=True: (fg_frame_wise_logits, bg_frame_wise_logits,
                      fg_video_level_logits, bg_video_level_logits)
        otherwise:   the two features followed by the four score arrays.
        The "logits" are sigmoid outputs returned as NumPy arrays.
    """
    fg_framewise_features = _frame_wise_feature_rows(fg_sample)
    bg_framewise_features = _frame_wise_feature_rows(bg_sample)

    if feats:
        # Features only: skip the classifier passes entirely.
        return fg_framewise_features, bg_framewise_features

    fg_frame_wise_logits = _frame_wise_scores(fg_framewise_features)
    bg_frame_wise_logits = _frame_wise_scores(bg_framewise_features)
    fg_video_level_logits = _video_level_scores(fg_sample)
    bg_video_level_logits = _video_level_scores(bg_sample)

    if logits:
        return (
            fg_frame_wise_logits,
            bg_frame_wise_logits,
            fg_video_level_logits,
            bg_video_level_logits,
        )
    return (
        fg_framewise_features,
        bg_framewise_features,
        fg_frame_wise_logits,
        bg_frame_wise_logits,
        fg_video_level_logits,
        bg_video_level_logits,
    )

In [None]:
# One figure per foreground/background pair: frame-wise scores for both clips,
# video-level scores for both clips, and the ground-truth label.
# NOTE(review): zip stops at the shortest input, so a length mismatch between
# the saved results and `fs_val_df_neg` silently drops or misaligns pairs.
# The loop variable is `label_str` (not `label`) so the label tensor selected
# earlier in the notebook is not overwritten by a string.
for fg_sample, bg_sample, fg_name, bg_name, label_str in zip(
    fg_fewshot.feat_map,
    bg_fewshot.feat_map,
    fg_fewshot.name,
    bg_fewshot.name,
    fs_val_df_neg.label,
):
    # A single call with neither flag set returns features and scores together.
    (
        fg_framewise_features,
        bg_framewise_features,
        fg_frame_wise_logits,
        bg_frame_wise_logits,
        fg_video_level_logits,
        bg_video_level_logits,
    ) = func(fg_sample, bg_sample)

    fig, ax = plt.subplots(
        1, 5, figsize=(20, 10), gridspec_kw={"width_ratios": [1, 1, 0.25, 0.25, 0.25]}
    )
    sns.heatmap(
        fg_frame_wise_logits.T,
        cmap="viridis",
        annot=True,
        yticklabels=behaviours,
        ax=ax[0],
    )
    sns.heatmap(
        bg_frame_wise_logits.T,
        cmap="viridis",
        annot=True,
        ax=ax[1],
    )

    sns.heatmap(
        fg_video_level_logits.T,
        cmap="viridis",
        annot=True,
        ax=ax[2],
        cbar=False,
    )

    sns.heatmap(
        bg_video_level_logits.T,
        cmap="viridis",
        annot=True,
        ax=ax[3],
        cbar=False,
    )

    # The label is stored as a string literal; parse it and shape it as a
    # single column so it lines up with the score panels.
    sns.heatmap(
        np.expand_dims(np.squeeze(np.array(ast.literal_eval(label_str))), axis=0).T,
        cmap="viridis",
        annot=True,
        ax=ax[4],
        cbar=False,
    )

    # Panels 0 and 1 show frame-wise scores (they were previously titled
    # "Video-level"); the remaining panels now get titles as well.
    ax[0].title.set_text(f"FG Frame-wise Logits: {fg_name}")
    ax[1].title.set_text(f"BG Frame-wise Logits: {bg_name}")
    ax[2].title.set_text("FG Video-level Logits")
    ax[3].title.set_text("BG Video-level Logits")
    ax[4].title.set_text("Label")

    plt.show()
    # Release the figure so a long run does not accumulate open figures.
    plt.close(fig)

In [30]:
import ast

In [1]:
# TODO: create temporal mask using the background sample

In [None]:
# np array from (B,) to (1, B)