In [None]:
#import gradio as gr
import torch
import torch.nn.functional as F
from facenet_pytorch import MTCNN, InceptionResnetV1
import numpy as np
from PIL import Image
import cv2
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda:0'

# Face detector used to crop the face out of the input image.
# Per the facenet-pytorch documentation:
#   select_largest=False -> when several faces are found, return the one with the
#                           highest detection probability instead of the largest one.
#   post_process=False   -> return the cropped face without the library's pixel
#                           standardization; predict() rescales the crop itself (/ 255).
mtcnn = MTCNN(select_largest=False, post_process=False, device=DEVICE)

# eval() puts the module in inference mode: layers such as dropout and batch
# normalization switch to their inference behavior (they are not removed).
mtcnn = mtcnn.to(DEVICE).eval()

In [None]:
# InceptionResnetV1 backbone, first initialized from the weights pre-trained on
# VGGFace2 (a large-scale face recognition dataset) and given a classification
# head with a single output logit (num_classes=1), i.e. a binary real/fake
# classifier once the logit is passed through a sigmoid.
model = InceptionResnetV1(pretrained="vggface2", classify=True, num_classes=1, device=DEVICE)

# Overwrite the weights with the checkpoint stored next to the notebook. The
# tensors are mapped to the CPU first so the file loads on machines without a
# GPU; the model is moved to DEVICE afterwards.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
checkpoint = torch.load("resnetinceptionv1_epoch_32.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model_state_dict'])

# Inference only: move to the target device and switch to evaluation mode.
model.to(DEVICE).eval()

In [None]:
# predict() classifies the detected face as real or fake and also builds a Grad-CAM
# (Gradient-weighted Class Activation Mapping) overlay that highlights the image regions contributing most to the model's prediction.
#https://medium.com/@mohamedchetoui/grad-cam-gradient-weighted-class-activation-mapping-ffd72742243a

def predict(input_image:Image.Image):
    """Classify the face in ``input_image`` as real or fake and explain it with Grad-CAM.

    The face is cropped with the notebook-level ``mtcnn`` detector, resized to
    256x256, scaled to [0, 1] and scored by the notebook-level ``model``. The
    single output logit is passed through a sigmoid and used as the "fake" score.

    NOTE: a later cell of this notebook defines ``predict`` again; when the cells
    are run in order, that later definition replaces this one.

    Args:
        input_image: PIL image expected to contain a face.

    Returns:
        A tuple ``(confidences, face_with_mask)``. ``confidences`` is a dict with
        the float entries ``'real'`` and ``'fake'`` (``real = 1 - fake``).
        ``face_with_mask`` is a uint8 image array of the 256x256 face crop with
        the Grad-CAM heat map blended on top.

    Raises:
        ValueError: if ``mtcnn`` does not return a face.
    """
    face = mtcnn(input_image)
    if face is None:
        # ValueError is still an Exception, so callers using `except Exception` keep working.
        raise ValueError('No face detected')
    face = face.unsqueeze(0) # add the batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)

    # uint8 copy of the resized crop, used as the base layer of the final overlay.
    # mtcnn was built with post_process=False, so the pixel values are presumably
    # still on the 0-255 scale here (hence the division by 255 below).
    prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
    prev_face = prev_face.astype('uint8')

    # Model input: float32 on DEVICE, scaled to [0, 1].
    face = face.to(DEVICE).to(torch.float32) / 255.0

    # show_cam_on_image expects a float image in [0, 1]. Casting to int at this
    # point would floor almost every pixel to 0 and blend the heat map with a
    # black image, so the values are kept as floats.
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()

    target_layers = [model.block8.branch1[-1]]
    # The model and the input tensor already live on DEVICE, so GradCAM needs no
    # `use_cuda` flag (recent pytorch-grad-cam releases no longer accept it).
    cam = GradCAM(model=model, target_layers=target_layers) #gradient weighted class activation mapping
    targets = [ClassifierOutputTarget(0)]

    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]  # single image in the batch
    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
    face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)

    # Scoring needs no gradients; read the sigmoid output once.
    with torch.no_grad():
        fake_prediction = torch.sigmoid(model(face).squeeze(0)).item()

    confidences = {
        'real': 1 - fake_prediction,
        'fake': fake_prediction
    }
    return confidences, face_with_mask


In [None]:
# interface = gr.Interface(
#     fn=predict,
#     inputs=[
#         gr.inputs.Image(label="Input Image", type="pil")
#     ],
#     outputs=[
#         gr.outputs.Label(label="Class"),
#         gr.outputs.Image(label="Face with Explainability", type="pil")
#     ],
# ).launch()

In [None]:
from PIL import Image

# This cell relies on objects created in earlier cells: the globals `model`,
# `mtcnn` and `DEVICE`, plus the imported names `torch`, `F`, `cv2`, `GradCAM`,
# `ClassifierOutputTarget` and `show_cam_on_image`.
# Run the import, detector and model cells before running this one.

def predict(input_image:Image.Image):
    """Classify the face in ``input_image`` as real or fake and explain it with Grad-CAM.

    The face is cropped with the notebook-level ``mtcnn`` detector, resized to
    256x256, scaled to [0, 1] and scored by the notebook-level ``model``. The
    single output logit is passed through a sigmoid and used as the "fake" score.

    NOTE: an earlier cell of this notebook also defines ``predict``; when the
    cells are run in order, this later definition is the one in effect.

    Args:
        input_image: PIL image expected to contain a face.

    Returns:
        A tuple ``(confidences, face_with_mask)``. ``confidences`` is a dict with
        the float entries ``'real'`` and ``'fake'`` (``real = 1 - fake``).
        ``face_with_mask`` is a uint8 image array of the 256x256 face crop with
        the Grad-CAM heat map blended on top.

    Raises:
        ValueError: if ``mtcnn`` does not return a face.
    """
    face = mtcnn(input_image)
    if face is None:
        # ValueError is still an Exception, so callers using `except Exception` keep working.
        raise ValueError('No face detected')
    face = face.unsqueeze(0) # add the batch dimension
    face = F.interpolate(face, size=(256, 256), mode='bilinear', align_corners=False)

    # uint8 copy of the resized crop, used as the base layer of the final overlay.
    # mtcnn was built with post_process=False, so the pixel values are presumably
    # still on the 0-255 scale here (hence the division by 255 below).
    prev_face = face.squeeze(0).permute(1, 2, 0).cpu().detach().int().numpy()
    prev_face = prev_face.astype('uint8')

    # Model input: float32 on DEVICE, scaled to [0, 1].
    face = face.to(DEVICE).to(torch.float32) / 255.0

    # show_cam_on_image expects a float image in [0, 1]. Casting to int at this
    # point would floor almost every pixel to 0 and blend the heat map with a
    # black image, so the values are kept as floats.
    face_image_to_plot = face.squeeze(0).permute(1, 2, 0).cpu().detach().numpy()

    target_layers = [model.block8.branch1[-1]]
    # The model and the input tensor already live on DEVICE, so GradCAM needs no
    # `use_cuda` flag (recent pytorch-grad-cam releases no longer accept it).
    cam = GradCAM(model=model, target_layers=target_layers)
    targets = [ClassifierOutputTarget(0)]

    grayscale_cam = cam(input_tensor=face, targets=targets, eigen_smooth=True)
    grayscale_cam = grayscale_cam[0, :]  # single image in the batch
    visualization = show_cam_on_image(face_image_to_plot, grayscale_cam, use_rgb=True)
    face_with_mask = cv2.addWeighted(prev_face, 1, visualization, 0.5, 0)

    # Scoring needs no gradients; read the sigmoid output once.
    with torch.no_grad():
        fake_prediction = torch.sigmoid(model(face).squeeze(0)).item()

    confidences = {
        'real': 1 - fake_prediction,
        'fake': fake_prediction
    }
    return confidences, face_with_mask


def classify_image(image_path):
    """Run ``predict`` on the image file at ``image_path`` and print the verdict.

    Prints the winning label ("real" or "fake") and its confidence as a
    percentage. If ``predict`` raises (for example when no face is detected),
    the error message is printed instead.

    Args:
        image_path: path of the image file to classify.

    Returns:
        None in every case; the result is only printed.
    """
    # Decode inside a context manager so the file handle is closed right away.
    # convert() returns an in-memory copy that stays valid after the file is
    # closed, and forcing RGB lets RGBA / grayscale / palette images reach the
    # detector as 3-channel images.
    with Image.open(image_path) as image_file:
        input_image = image_file.convert("RGB")

    try:
        confidences, _ = predict(input_image)

        # Determine the final prediction based on the confidence scores
        # (a tie is reported as "fake").
        if confidences['real'] > confidences['fake']:
            prediction = "real"
        else:
            prediction = "fake"
        confidence_percentage = confidences[prediction] * 100

        print("Prediction:", prediction)
        print("Confidence:", confidence_percentage)

    except Exception as e:
        # Best-effort reporting: keep the notebook running and show what went wrong.
        print("Error:", str(e))
        return None

# Example run: classify one local image and print the result.
# Inside a Jupyter kernel __name__ is "__main__", so this runs whenever the cell is executed.
if __name__ == "__main__":
    # NOTE(review): absolute, machine-specific Windows path; point this at an image
    # that exists on your machine before running the cell.
    image_path = r"C:\Users\kaush\Downloads\WhatsApp Image 2023-01-07 at 9.51.50 PM.jpeg"  # image path
    classify_image(image_path)
    #https://www.nytimes.com/2023/01/22/business/media/deepfake-regulation-difficulty.html