In [None]:
import os
import torch
import torch.nn as nn
import numpy as np
from torch.nn import Parameter as P
import pretorched
from matplotlib import pyplot as plt
from IPython.display import Video
from pretorched.visualizers import grad_cam, visualize_samples

import models
from data import VideoFolder

# class Normalize(nn.Module):
#     def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225],
#                  shape=(1, -1, 1, 1, 1), rescale=True):
#         super().__init__()
#         self.shape = shape
#         self.mean = P(torch.tensor(mean).view(shape),
#                       requires_grad=False)
#         self.std = P(torch.tensor(std).view(shape),
#                      requires_grad=False)
#         self.rescale = rescale

#     def forward(self, x, rescale=None):
#         rescale = self.rescale if rescale is None else rescale
#         x.div_(255.) if rescale else None
#         return (x - self.mean) / self.std

# --- Configuration -----------------------------------------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
WEIGHT_DIR = 'weights'
video_dir = 'DeepfakeDetection/test_videos'
cam_dir = 'cam_videos'
os.makedirs(cam_dir, exist_ok=True)
checkpoint_file = 'resnet18_dfdc_seg_count-24_init-imagenet-ortho_optim-Ranger_lr-0.001_sched-CosineAnnealingLR_bs-64_best.pth.tar'

# --- Models ------------------------------------------------------------------
# Two-class classifier whose weights are restored from the checkpoint above.
fakenet = pretorched.resnet18(num_classes=2, pretrained=None)
# map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only machine
# (the `device` fallback above); the model is moved to `device` afterwards.
# NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
checkpoint = torch.load(os.path.join(WEIGHT_DIR, checkpoint_file), map_location='cpu')
# Strip the 'module.model.' key prefix (presumably added by a wrapper such as
# DataParallel during training -- confirm) so keys match the bare resnet.
fakenet.load_state_dict({k.replace('module.model.', ''): v
                         for k, v in checkpoint['state_dict'].items()})
del checkpoint  # the full checkpoint dict is no longer needed
fakenet = fakenet.to(device).eval()

# Face detector/cropper; `size` is tied to the classifier's input_size.
facenet = models.FaceModel(size=fakenet.input_size[-1],
                           device=device,
                           margin=100,
                           min_face_size=50,
                           keep_all=True,
                           post_process=False,
                           select_largest=False,
                           chunk_size=150)
gcam = grad_cam.GradCAM(model=fakenet)
norm = models.Normalize().to(device)

In [None]:
# Test-video dataset; `step` presumably subsamples frames (every 3rd) -- confirm
# against data.VideoFolder.
frame_step = 3
dataset = VideoFolder(video_dir, step=frame_step)


In [None]:
# Run Grad-CAM over every test video and write one heatmap-overlay video per input.
# NOTE: only the variables from the *last* iteration (video, faces, output_video, ...)
# survive the loop, and later cells depend on them.
for i in range(len(dataset)):
    filename, video, label = dataset[i]
    # Add a batch dimension and move the clip to the model's device.
    video = video.unsqueeze(0).to(device)
    faces = facenet(video)
    # NOTE(review): if models.Normalize divides by 255 in place (like the commented-out
    # reference implementation near the imports), `faces` is rescaled too, which would
    # explain the `* 255` on raw_image below -- confirm.
    norm_faces = norm(faces)
    # Drop the batch dim and put time first so each frame is one batch item.
    # Assumes faces is (1, C, T, H, W) -> input_faces (T, C, H, W) -- TODO confirm.
    input_faces = norm_faces[0].transpose(0, 1)
    print(input_faces.shape)
    gcam.model.zero_grad()
    probs, idx = gcam.forward(input_faces)
    print(filename, label)
    print(probs, idx)

    # Back-propagate the class at idx[0] for every frame (one target row per frame).
    gcam.backward(idx=idx[0].unsqueeze(0).repeat(input_faces.size(0), 1))
    output = gcam.generate(target_layer='layer4')
    print(output.shape)
    print(faces.shape)
    # (C, T, H, W) -> (T, H, W, C) pixel frames to overlay the heatmaps on.
    raw_image = faces[0].permute(1, 2, 3, 0).detach().cpu().numpy() * 255
    print(raw_image.shape)
    output_video = [grad_cam.apply_heatmap(o, r).astype(np.uint8) for o, r in zip(output, raw_image)]
    # Use only the base name so a filename carrying directory components still
    # lands directly inside cam_dir instead of a non-existent sub-directory.
    output_camfile = os.path.join(cam_dir, 'cam_' + os.path.basename(filename))
    pretorched.data.utils.array_to_video(output_video, output_camfile)

In [None]:
from ipywidgets import interact
import ipywidgets as widgets

# Frame count of the last CAM-overlay video from the loop; bounds the slider below.
print(num_frames := len(output_video))

In [None]:
# Slider bounds given as a (min, max) tuple are inclusive, so the upper bound must be
# the last valid index; clamp at 0 so an empty video does not produce max < min.
@interact(frame=(0, max(num_frames - 1, 0)))
def show_frame(frame=1):
    """Display frame number `frame` of the last CAM-overlay video (`output_video`)."""
    plt.imshow(output_video[frame])
    plt.minorticks_off()
    plt.tight_layout()
    # Render explicitly so the figure appears inside the interact output area.
    plt.show()

In [None]:
# Re-export the last CAM-overlay video to a fixed path for the inline preview below.
preview_file = 'test_video.mp4'
pretorched.data.utils.array_to_video(output_video, preview_file)

In [None]:
from IPython.display import Video

In [None]:
# Preview the exported video. display() is required because only a cell's *last*
# expression is rendered automatically, and this cell continues below.
from IPython.display import Video, display
display(Video('test_video.mp4'))

# Reload a fresh sample: the `video` left over from the loop already has a batch
# dimension, so unsqueezing it again would add a second, spurious one.
sample_idx = 1
filename, video, label = dataset[sample_idx]
video = video.unsqueeze(0)

In [None]:
# Crop and normalise faces for a single clip, then keep only its first frame.
video = video.to(device)
faces = facenet(video)
norm_faces = norm(faces)
for tensor in (video, faces, norm_faces):
    print(tensor.shape)
# Swap the first two axes of the un-batched faces, then slice a one-frame batch.
input_faces = norm_faces[0].transpose(0, 1)[:1]
print(input_faces.shape)

In [None]:
# Grad-CAM on the one-frame batch: score classes, back-propagate the class at idx[0],
# and read the activation map from 'layer4'.
probs, idx = gcam.forward(input_faces)

print(f'probs: {probs.shape}')
print(f'ids: {idx.shape}')

first_idx = idx[0]
gcam.backward(idx=first_idx)
output = gcam.generate(target_layer='layer4')

In [None]:
import cv2
# Overlay the single-frame Grad-CAM map on its face crop.
# First frame of the first clip as (H, W, C); assumes faces is (1, C, T, H, W) -- TODO confirm.
raw_image = faces[0, :, 0].permute(1, 2, 0).detach().cpu().numpy() * 255
h, w, _ = raw_image.shape
# Fresh names (instead of reassigning `output`) keep this cell idempotent:
# re-running it no longer re-colourises an already-blended image.
cam = cv2.resize(output, (w, h))
heatmap = cv2.applyColorMap(np.uint8(cam * 255.0), cv2.COLORMAP_JET)
# applyColorMap returns BGR; convert to RGB so it matches raw_image as shown by matplotlib
# (assumes raw_image is RGB, consistent with the ImageNet mean/std order -- confirm).
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
# Builtin `float` instead of the `np.float` alias, which was removed in NumPy 1.24.
blended = 0.2 * heatmap.astype(float) + 0.8 * raw_image.astype(float)
peak = blended.max()
# Rescale to [0, 255]; guard against division by zero on an all-black image.
output_image = blended / peak * 255.0 if peak > 0 else blended
# output_image = grad_cam.apply_heatmap(output, raw_image)

In [None]:
# Sanity-check the blended overlay's dimensions.
print(np.shape(output_image))


In [None]:
# Show the face crop and its Grad-CAM overlay side by side. Calling imshow twice on
# the same axes would hide the first image completely behind the second.
fig, (ax_raw, ax_cam) = plt.subplots(1, 2, figsize=(8, 4))
ax_raw.imshow(raw_image.astype(np.uint8))
ax_raw.set_title('face crop')
ax_cam.imshow(output_image.astype(np.uint8))
ax_cam.set_title('Grad-CAM overlay')
for ax in (ax_raw, ax_cam):
    ax.axis('off')
plt.show()