# Welcome to DeepFake-Buster

In [None]:
from ipywidgets import *
from fastai.vision.all import *
from face_detection.BlazeFaceDetector import BlazeFaceDetector
import albumentations as A
import cv2
import imageio
from custom import *
import PIL
from matplotlib import cm

In [None]:
class AlbumentationsTransform(DisplayedTransform):
    "Wrap an `albumentations` pipeline so fastai can apply it to a `PILImage`."
    # fastai convention: split_idx=0 restricts the transform to the training split;
    # transforms are sorted by `order` before being applied.
    order = 2
    split_idx = 0

    def __init__(self, train_aug):
        # fastcore's `store_attr` saves the `train_aug` argument onto `self`.
        store_attr()

    def encodes(self, img: PILImage):
        # albumentations works on numpy arrays and returns a dict keyed by 'image'.
        augmented = self.train_aug(image=np.array(img))
        return PILImage.create(augmented['image'])

def get_item_tfms(size, blur_limit, var_limit, quality_lower, quality_upper, num_holes, hole_size):
    """Build the item transforms: a fastai `Resize` followed by random image corruptions.

    The corruptions (motion blur, Gaussian noise, JPEG compression, cutout) are
    combined with `A.Compose`, and each one is given probability 0.3.

    NOTE(review): `A.JpegCompression`, `A.Cutout` and `GaussNoise(var_limit=...)`
    are deprecated in recent albumentations releases; pin the version or migrate.
    """
    corruptions = [
        A.MotionBlur(blur_limit=blur_limit, p=0.3),
        A.GaussNoise(var_limit=var_limit, p=0.3),
        A.JpegCompression(quality_lower=quality_lower, quality_upper=quality_upper, p=0.3),
        A.Cutout(num_holes=num_holes, max_h_size=hole_size, max_w_size=hole_size, p=0.3),
    ]
    pipeline = A.Compose(corruptions)
    return [Resize(size), AlbumentationsTransform(pipeline)]

In [None]:
# Device handed to the BlazeFace detector below.
device = torch.device(type="cpu")

In [None]:
# Exported fastai Learner used to classify face crops. `load_learner` unpickles
# the file (so only load trusted exports), and custom classes it references,
# such as `AlbumentationsTransform`, must already be defined.
LEARNER_PATH = 'dw_xrn18_sa_se_mixup_prog_prune_fp16_KD_1.pkl'
detect_learn = load_learner(LEARNER_PATH)

In [None]:
# BlazeFace face detector built from local weight and anchor files.
detector = BlazeFaceDetector(
    weights='face_detection/blazeface.pth',
    anchors='face_detection/anchors.npy',
    device=device,
)

In [None]:
def read_frames_sample(vid, n_frames=16):
    """Decode an in-memory video and return up to `n_frames` evenly spaced frames.

    Args:
        vid: raw bytes of the video file (here, the uploaded file's contents).
        n_frames: maximum number of frames to return.

    Returns:
        The sampled frames, as yielded by the reader's `iter_data()`, in video
        order. A video shorter than `n_frames` yields all of its frames.
    """
    import io

    # The `with` block closes the ffmpeg reader (and its subprocess) even on error.
    with imageio.get_reader(io.BytesIO(vid), 'ffmpeg') as reader:
        n_total = int(reader.count_frames())
        sample = np.linspace(0, n_total - 1, min(n_frames, n_total)).astype(int).tolist()
        wanted = set(sample)
        # Keep only the sampled frames instead of materializing the whole video.
        kept = {}
        for idx, frame in enumerate(reader.iter_data()):
            if idx in wanted:
                kept[idx] = frame
                if len(kept) == len(wanted):
                    break
    # Skip any sampled index the decoder never produced rather than crashing.
    return [kept[s] for s in sample if s in kept]

In [None]:
def predict_all_frames(frames):
    """Run the face detector on every frame, keeping only non-empty results.

    NOTE(review): frames with no detection are dropped, so the returned list is
    not index-aligned with `frames` whenever some frame has no face.
    """
    per_frame = (detector.detect(frame) for frame in frames)
    return [found for found in per_frame if len(found) > 0]

In [None]:
def get_cropped_faces(frame, detections):
    """Cut every detected box out of `frame` and return the crops in order.

    The first four entries of each detection are read as x_min, y_min, x_max,
    y_max. They are truncated to int; the minimums are clamped at 0 and the
    maximums at the frame's width/height. Crops are slices (views) of `frame`.
    """
    height, width = frame.shape[0], frame.shape[1]
    crops = []
    for box in detections:
        left = max(0, int(box[0]))
        top = max(0, int(box[1]))
        right = min(width, int(box[2]))
        bottom = min(height, int(box[3]))
        crops.append(frame[top:bottom, left:right])
    return crops

In [None]:
def get_preds(faces):
    """Average, over `faces`, the probability the model assigns to class index 0.

    Makes one `detect_learn.predict` call (and forward pass) per face. An empty
    `faces` gives `np.mean([])`, i.e. NaN with a RuntimeWarning.

    NOTE(review): the original comment calls this the "realness" of a face, yet
    the caller labels `pred > 0.5` as 'FAKE'. Confirm which class sits at index 0
    of `detect_learn.dls.vocab`.
    """
    scores = []
    for face in faces:
        probabilities = detect_learn.predict(face)[2]
        scores.append(probabilities[0])
    return np.mean(scores)

In [None]:
# Controls: upload a video, then ask for the verdict and for the attention maps.
btn_upload = widgets.FileUpload()
btn_run = widgets.Button(description='Get Prediction')
btn_att = widgets.Button(description='Show Attention')
# Display area for thumbnails / overlays, plus a text label for the verdict.
out_pl = widgets.Output()
lbl_pred = widgets.Label()

# Verdict shown for the outcome of `pred > 0.5` (False -> 0, True -> 1).
label = {0: 'TRUE', 1: 'FAKE'}

In [None]:
def _update_children(change):
    "Observer: give each child of the changed box a `flex` of '0 0 auto' unless one is already set."
    for child in change['owner'].children:
        child_layout = child.layout
        if child_layout.flex:
            continue
        child_layout.flex = '0 0 auto'

In [None]:
def carousel(children=(), **layout):
    "Return a flex `Box` that lays `children` out in one horizontally scrolling row."
    base = dict(overflow='scroll hidden', flex_flow='row', display='flex')
    box = Box([], layout=merge(base, layout))
    # Observe first, so assigning `children` below triggers the default-flex fix-up.
    box.observe(_update_children, names='children')
    box.children = children
    return box

In [None]:
def widget(im, *args, **layout):
    "Wrap anything IPython can `display` in an `Output` widget; `args`/`layout` are merged into its layout."
    out = Output(layout=merge(*args, layout))
    with out:
        display(im)
    return out

In [None]:
def on_click(change):
    """Upload handler: sample frames, crop every detected face, and show thumbnails.

    Stores the crops in the module-level `faces` list, which the prediction and
    attention handlers read.
    """
    global faces
    if not btn_upload.data:
        # The observer also fires when the upload data is cleared.
        return
    frames = read_frames_sample(btn_upload.data[-1])
    found = []
    for frame in frames:
        # Detect per frame so each frame is cropped with its *own* boxes.
        # (`predict_all_frames` drops face-less frames, so indexing its result
        # by frame position can pair a frame with another frame's detections.)
        detection = detector.detect(frame)
        if len(detection) == 0:
            continue
        # Keep every face as a separate crop. Stacking crops of different sizes
        # into one array (the old `np.array(...).squeeze()`) fails.
        found.extend(crop for crop in get_cropped_faces(frame, detection) if crop.size > 0)
    faces = found
    lbl_pred.value = ''  # the previous verdict no longer applies
    out_pl.clear_output()
    if not faces:
        with out_pl: print('No face detected in the uploaded video.')
        return
    widg = carousel(width='100%')
    ims = [PILImage.create(f).to_thumb(256, 256).convert('RGBA') for f in faces]
    widg.children = [widget(im) for im in ims]
    with out_pl: display(widg)


btn_upload.observe(on_click, names=['data'])

In [None]:
class Hook:
    "Keep a detached copy of a module's output each time the module runs forward."

    def __init__(self):
        self.stored = []

    def hook_func(self, m, i, o):
        # Forward-hook callback signature: (module, inputs, output).
        snapshot = o.detach().clone()
        self.stored.append(snapshot)

def on_click(change):
    """Prediction handler: classify the uploaded faces and show the verdict in `lbl_pred`.

    While predicting, a forward hook on one inner layer of the model captures
    its outputs into the module-level `hook_output`, which the attention handler
    reads afterwards.
    """
    global hook_output
    hook_output = Hook()
    current_faces = globals().get('faces', [])
    if len(current_faces) == 0:
        # Avoid np.mean([]) == NaN being reported as a confident 'TRUE'.
        lbl_pred.value = 'No face available: upload a video containing a face first.'
        return
    hook = detect_learn.model[4][1].convpath[4].conv.register_forward_hook(hook_output.hook_func)
    try:
        pred = get_preds(current_faces)
    finally:
        # Detach the hook. Otherwise every click stacks another hook on the layer,
        # and each stale hook keeps appending to an old `Hook` instance.
        hook.remove()
    pred_label = label[int(pred > 0.5)]
    prob = 1-pred if pred<0.5 else pred
    lbl_pred.value = f'I\'m {100*prob:.02f}% sure that this image is {pred_label} !'


btn_run.on_click(on_click)

In [None]:
def norm(hm):
    """Min-max scale `hm` into [0, 1].

    Accepts anything with `.min()`, `.max()` and elementwise arithmetic (torch
    tensors, numpy arrays). A constant input has zero range. Rather than dividing
    by zero (which produces NaN), it returns floating-point zeros.
    """
    lo = hm.min()
    span = hm.max() - lo
    if span == 0:
        # Multiply by 0.0 so the result is floating point, like the normal branch.
        return (hm - lo) * 0.0
    return (hm - lo) / span

In [None]:
def on_click(change):
    """Attention handler: blend heatmaps of the hooked layer's outputs over the input faces.

    Reads `hook_output` (filled by the prediction handler) and `faces` (set by
    the upload handler). Each heatmap is coloured with matplotlib's magma map,
    blended over its face at alpha 0.5, and shown in a carousel in `out_pl`.
    """
    stored_hook = globals().get('hook_output')
    current_faces = globals().get('faces', [])
    if stored_hook is None or not stored_hook.stored or len(current_faces) == 0:
        out_pl.clear_output()
        with out_pl: print('Run "Get Prediction" on an uploaded video first.')
        return

    # `PIL.Image.BILINEAR` is deprecated in newer Pillow (this cell emitted the
    # warning). Use `Image.Resampling` when it exists and fall back otherwise.
    bilinear = getattr(PIL.Image, 'Resampling', PIL.Image).BILINEAR

    # One heatmap per stored forward pass: take item 0, average over the next
    # axis, then min-max scale and invert.
    # NOTE(review): presumably outputs are (batch, channels, H, W), and
    # `.view(64, 64)` hard-codes the layer's spatial size; confirm.
    hms = [(1-norm(output[0].mean(0)).view(64,64)) for output in stored_hook.stored]
    dl = detect_learn.dls.test_dl(current_faces)
    # NOTE(review): this is the first batch after the DataLoader's batch
    # transforms. If those include `Normalize`, the values are not in [0, 1]
    # and the background image below will be distorted; confirm.
    imss = np.array(next(iter(dl))[0])

    blends = []
    for hm, im in zip(hms, imss):
        fg = PIL.Image.fromarray(np.uint8(cm.magma(hm)*255))
        bg = PIL.Image.fromarray(np.uint8(im.transpose(1,2,0)*255))
        fg = fg.resize(bg.size, bilinear)
        bg = bg.convert('RGBA')
        blends.append(PIL.Image.blend(bg, fg, alpha=0.5))

    widg = carousel(width='100%')
    out_pl.clear_output()
    widg.children = [widget(blend) for blend in blends]
    with out_pl: display(widg)


btn_att.on_click(on_click)

In [None]:
# Stack the prompt, the controls, the verdict label and the output area vertically.
ui = VBox([
    widgets.Label('Select your image!'),
    btn_upload,
    btn_run,
    btn_att,
    lbl_pred,
    out_pl,
])
display(ui)

VBox(children=(Label(value='Select your image!'), FileUpload(value={}, description='Upload'), Button(descripti…

  fg = fg.resize(bg.size, PIL.Image.BILINEAR)
