In [None]:
import ast
import mmcv
import math
import torch
import torch.nn.functional as F
import torchvision
import pickle as pkl
import pandas as pd
import numpy as np
import seaborn as sns
from einops import rearrange
import matplotlib.pyplot as plt

from torchmetrics.functional.classification import (
    multilabel_f1_score,
    multilabel_precision,
    multilabel_recall,
)

from data_utils import results2df
from sklearn.metrics.pairwise import cosine_similarity

# Slowfast imports
from slowfast.models import build_model
from slowfast.utils.parser import load_config, alt_parse_args
from slowfast.datasets.loader import construct_loader

In [None]:
# Experiment config and trained checkpoint (absolute, machine-specific paths).
path_to_config = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/configs/SLOW_8x8_R50_Local.yaml"
path_to_ckpt = "/home/dl18206/Desktop/phd/code/personal/facebook/slowfast/checkpoint_epoch_00200.pyth"

# The last element returned by alt_parse_args() is dropped and the first remaining
# entry is handed to load_config. NOTE(review): the meaning of these entries is
# defined inside slowfast — confirm.
args = alt_parse_args()[:-1]
cfg = load_config(args[0], path_to_config=path_to_config)

# Build the network, restore the saved weights (mapped to CPU) and put it in
# evaluation mode on the CPU; the cell displays the resulting model.
model = build_model(cfg)
checkpoint = torch.load(path_to_ckpt, map_location="cpu")
model.load_state_dict(checkpoint["model_state"])
model.eval().cpu()

In [None]:
# The head's projection layer; reused below to score pooled features directly
# (its weight matrix is also used to build class activation maps).
classifier = model.head.projection

In [None]:
# Build the test-split loader and take its first batch.
# (An earlier variant used: dataset = build_dataset("tap", cfg, "train"))
loader = construct_loader(cfg, "test")
sample = next(iter(loader))

In [None]:
# Unpack the batch. NOTE(review): the 5-tuple layout (inputs, labels, index, time, meta)
# is defined by the slowfast loader — confirm against its collate function.
inputs, labels, index, time, meta = sample

In [None]:
# Inspect the batch metadata and labels.
meta, labels

In [None]:
# Run one clip through the backbone stages s1..s5 only (the head is skipped) to get
# the final feature map. The stages take and return a list of pathway tensors, hence
# the list wrapping and the trailing [0].
# NOTE(review): inputs[0][1] presumably selects pathway 0, batch element 1 — confirm
# this is the intended clip.
# no_grad: this is inference only, so don't build an autograd graph for the whole
# backbone pass; the later .detach() calls remain harmless.
with torch.no_grad():
    feature_map = model.s5(
        model.s4(model.s3(model.s2(model.s1([inputs[0][1].unsqueeze(0)]))))
    )[0]

In [None]:
def extract_frame_wise_features(feature_map, t=None):
    """Average-pool a 5-D feature map over space, keeping ``t`` temporal positions.

    Args:
        feature_map: Tensor of shape ``(N, C, T, H, W)``.
        t: Number of temporal positions in the output. ``None`` (default) keeps the
            feature map's own temporal length ``T``. Adaptive pooling to a ``t``
            larger than ``T`` adds no temporal resolution: output slices are
            duplicates/averages of neighbouring input slices.

    Returns:
        Tensor of shape ``(N, C, t)``: one spatially averaged C-vector per
        temporal position.
    """
    if t is None:
        t = feature_map.shape[2]
    spatially_pooled = F.adaptive_avg_pool3d(feature_map, (t, 1, 1))
    # Collapse the two singleton spatial dims, leaving (N, C, t).
    return torch.flatten(spatially_pooled, start_dim=2)


# Per-temporal-slice features, shape (1, C, 16).
# NOTE(review): t=16 is hard-coded; if feature_map.shape[2] is smaller, adaptive
# pooling only duplicates/averages existing slices — confirm the intended length.
frame_wise_features = extract_frame_wise_features(feature_map, t=16)

In [None]:
# Cosine similarity between every pair of temporal slices:
# (1, C, T) -> (T, C), L2-normalise each row, then all pairwise dot products -> (T, T).
# NOTE: this rebinds frame_wise_features to the normalised (T, C) tensor, so re-running
# the cell is not idempotent (it would transpose the already-transposed tensor).
frame_wise_features = F.normalize(frame_wise_features.squeeze(0).T, p=2, dim=1)
frame_wise_cosine_similarity = frame_wise_features @ frame_wise_features.T

In [None]:
# Global average over time and space -> one C-vector per clip, shape (N, C).
video_level_features = F.adaptive_avg_pool3d(feature_map, 1).flatten(1)

In [None]:
# Plot cosine similarity matrix
# plt.figure(figsize=(10, 10))
# sns.heatmap(frame_wise_cosine_similarity.detach().numpy(), cmap="viridis", annot=True)
# plt.title("Frame-wise cosine similarity")
# plt.show()

In [None]:
# Per-frame class probabilities, scored on *un-normalised* frame-wise features.
# The cosine-similarity cell rebinds `frame_wise_features` to L2-normalised rows, which
# changes the scale of the classifier's inputs; the video-level probabilities are
# computed from un-normalised pooled features. Recompute the raw pooled features here
# so both are on the same footing. t=16 matches the frame_wise_features cell.
with torch.no_grad():
    raw_frame_features = extract_frame_wise_features(feature_map, t=16).squeeze(0).T  # (T, C)
    # One batched classifier call; kept as a list of per-frame arrays like before.
    frame_wise_logits = list(torch.sigmoid(classifier(raw_frame_features)).numpy())

In [None]:
# Plot frame-wise logits as heatmap
frame_wise_logits = np.asarray(frame_wise_logits)  # list of per-frame vectors -> (T, num_classes) array
# plt.figure(figsize=(10, 15), dpi=300)
# sns.heatmap(frame_wise_logits.T, cmap="viridis", annot=True)
# plt.title("Frame-wise logits")
# plt.show()

In [None]:
# Video-level class probabilities (sigmoid of the classifier output on the globally
# pooled features), as a NumPy array of shape (N, num_classes).
video_level_logits = classifier(video_level_features).sigmoid().detach().numpy()
# plt.figure(figsize=(1, 10))
# sns.heatmap(video_level_logits.T, cmap="viridis", annot=True)
# plt.title("Video-level logits")
# plt.show()

In [None]:
import cv2

In [None]:
# Decode every frame of the original video into `frames` (OpenCV returns BGR arrays).
video_path = "/home/dl18206/Desktop/phd/data/panaf/panaf_sequence/36070505.mp4"
cap = cv2.VideoCapture(video_path)
# VideoCapture does not raise on a bad path; fail loudly instead of yielding no frames.
if not cap.isOpened():
    raise OSError(f"Could not open video: {video_path}")
frames = []
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        frames.append(frame)
finally:
    # Release the underlying decoder/file handle even if reading fails part-way.
    cap.release()

In [None]:
# Resize frame index 10 of the decoded video to 256x256, the same size returnCAM
# upsamples its maps to, and show the resulting shape.
frame = cv2.resize(frames[10], (256, 256))
frame.shape

In [None]:
# View the resized frame. OpenCV decodes frames in BGR channel order while matplotlib
# expects RGB, so convert a copy for display only; `frame` itself stays BGR for the
# OpenCV-based CAM overlay below.
fig, ax = plt.subplots()
ax.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
ax.axis("off");

In [None]:
# Spatial


def returnCAM(feature_conv, weight_softmax, class_idx, size_upsample=(256, 256)):
    """Compute class activation maps (CAMs) for the requested classes.

    Each CAM is the classifier-weighted sum of the feature channels at every spatial
    position, min-max normalised to the 0..255 range and resized with ``cv2.resize``.

    Args:
        feature_conv: Feature map of shape ``(1, C, H, W)``; a NumPy array or a
            detached CPU torch tensor.
        weight_softmax: Classifier weight matrix of shape ``(num_classes, C)``
            (NumPy array or detached CPU tensor).
        class_idx: Iterable of class indices (Python ints or 0-d integer tensors).
        size_upsample: ``(width, height)`` passed to ``cv2.resize``; 256x256 by default.

    Returns:
        List of ``uint8`` maps, one per entry in ``class_idx``.

    Raises:
        ValueError: If ``feature_conv`` has a batch size other than 1.
    """
    features = np.asarray(feature_conv)
    weights = np.asarray(weight_softmax)
    bz, nc, h, w = features.shape
    if bz != 1:
        raise ValueError(f"returnCAM expects a batch size of 1, got {bz}")
    flat_features = features.reshape(nc, h * w)

    output_cam = []
    for idx in class_idx:
        cam = weights[int(idx)].dot(flat_features).reshape(h, w)
        cam = cam - cam.min()
        peak = cam.max()
        # A spatially constant CAM would give 0/0 = NaN, and casting NaN to uint8 is
        # undefined; emit an all-zero map instead.
        cam_img = cam / peak if peak > 0 else np.zeros_like(cam)
        output_cam.append(cv2.resize(np.uint8(255 * cam_img), size_upsample))
    return output_cam

In [None]:
# Keep only the first temporal slice: (N, C, T, H, W) -> (N, C, H, W).
spatial_map = feature_map.select(2, 0)

In [None]:
# One CAM per output class of the classifier (one per row of its weight matrix),
# instead of a hard-coded class count. Both arrays are passed as NumPy so the
# weighted sum is a plain ndarray product.
cams = returnCAM(
    spatial_map.detach().cpu().numpy(),
    classifier.weight.detach().cpu().numpy(),
    range(classifier.weight.shape[0]),
)

In [None]:
# Render the CAM for class index 10 over the resized frame and save it to CAM.jpg.
# NOTE(review): class index 10 is hard-coded — confirm it is the class of interest.
height, width, _ = frame.shape
heatmap = cv2.applyColorMap(cv2.resize(cams[10], (width, height)), cv2.COLORMAP_JET)
# Blend in float, then round and saturate back to uint8 so the JPEG encoder receives
# 8-bit data explicitly rather than relying on an implicit float64 conversion.
result = np.clip(np.rint(heatmap * 0.3 + frame * 0.5), 0, 255).astype(np.uint8)
# cv2.imwrite signals failure through its return value instead of raising.
if not cv2.imwrite("CAM.jpg", result):
    raise OSError("cv2.imwrite failed to write CAM.jpg")