# Multiple Instance Learning (MIL) for DeepFake Detection

This project was conceived in a conversation with Chris Farnan at PaigeAI. The idea was to detect DeepFakes with MIL, the technique PaigeAI used successfully to find tumorous slides among the thousands of samples available to them. This script prepares the data for use with the training script.

In [None]:
#importing the required modules

import os
import cv2
import json
from PIL import Image
import matplotlib.pyplot as plt
from skimage.measure import compare_ssim
import imutils

In [None]:
import numpy as np

#Global NumPy seed so any NumPy-based randomness in this notebook is reproducible.
#(The train/test split below is made reproducible by its own random_state=42, not by this call.)
np.random.seed(42)

In [None]:
#More than 475GB of data was available with the metadata in 10 json files. 
#We were only able to use a subset as there was not enough space available on the hard drive to store all the data
json_filenames=['../train_sample_videos/metadata.json','../train_sample_videos/metadata_prt2.json']
#json_filenames=['../train_sample_videos/metadata.json','../train_sample_videos/metadata_prt2.json','../train_sample_videos/metadata_prt3.json']
train_dir='../train_sample_videos/'
lstSlides=[]
lstTargets=[]
#Mirror of lstSlides used only for O(1) "already listed?" checks; a list scan per
#metadata entry is quadratic over large metadata files.
_seen_slides=set()

#Collect every REAL video (target 0, listed once) and every FAKE video (target 1)
#whose original (REAL) video is present in train_dir -- the mosaic step needs both.
for json_filename in json_filenames:
    with open(json_filename) as json_file:
        data=json.load(json_file)
    for Filename, meta in data.items():
        if meta['label']=='FAKE':
            # Skip fakes whose original was not downloaded.
            if os.path.exists(train_dir+meta['original']):
                lstSlides.append(Filename)
                lstTargets.append(1)
                _seen_slides.add(Filename)
        elif Filename not in _seen_slides:
            lstSlides.append(Filename)
            lstTargets.append(0)
            _seen_slides.add(Filename)

In [None]:
from sklearn.model_selection import train_test_split

#Split the video list 90% train / 10% test with scikit-learn, stratified on the
#labels so both splits keep the same FAKE/REAL ratio; random_state fixes the split.
(lstSlides_train, lstSlides_test,
 lstTargets_train, lstTargets_test) = train_test_split(
    lstSlides,
    lstTargets,
    test_size=0.1,
    random_state=42,
    stratify=lstTargets,
)

In [None]:
#Number of FAKE videos (target 1) in the training split -- REAL videos (target 0) are not counted
lstTargets_train.count(1)

In [None]:
#Number of FAKE videos (target 1) in the test split -- REAL videos (target 0) are not counted
lstTargets_test.count(1)

In [None]:
#function to find fake area inside two frames when provided with the real and fake image frame
def fn_find_fake_area(img_real, img_fake, nDisplayNum=0, threshold=0.8):
    """Locate the largest region where a fake frame differs from its real counterpart.

    A full SSIM map is computed between the two images. Pixels whose local
    SSIM is below ``threshold`` form a binary mask, and the largest external
    contour of that mask is taken as the manipulated area.

    Args:
        img_real: real frame (the caller in this notebook passes greyscale frames).
        img_fake: fake frame; must have the same shape as ``img_real``.
        nDisplayNum: if > 0, plot the inputs, the SSIM map and the mask for debugging.
        threshold: local-SSIM cut-off below which a pixel counts as "fake".
            The default 0.8 is the value that was previously hard-coded.

    Returns:
        (x, y): the top-left corner (not the centre) of the bounding box of
        the largest low-SSIM contour. Returns (None, None) when the shapes
        differ or no contour is found.
    """
    if img_real.shape != img_fake.shape:
        print("ERROR: Img dimensions do not match...")
        return (None, None)

    (score, diff) = compare_ssim(img_real, img_fake, full=True)
    # Binary mask (0/255) of pixels whose local SSIM falls below the threshold;
    # built as uint8 directly because findContours needs an 8-bit image.
    mask = np.zeros(img_real.shape, dtype=np.uint8)
    mask[diff < threshold] = 255

    #Display some intermediate results
    if nDisplayNum > 0:
        for img in (img_real, img_fake, diff, mask):
            plt.figure()
            plt.imshow(img, cmap='gray')

    #Find contours with fake area; keep only the largest one
    cnts = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = imutils.grab_contours(cnts)
    if len(cnts) == 0:
        return (None, None)
    c = max(cnts, key=cv2.contourArea)
    (x, y, _w, _h) = cv2.boundingRect(c)
    return (x, y)

In [None]:
#Function to write the PIL image
def write_image(filename, frame):
    """Save ``frame`` to the path ``filename`` by calling ``frame.save(filename)``.

    ``frame`` is expected to be a PIL image (or any object with a compatible
    ``save`` method). Nothing is returned.
    """
    save_to = frame.save
    save_to(filename)

In [None]:
import json
import numpy as np
from PIL import Image
import torch

#Each real/fake video pair becomes one mosaic image: crops from the real and the
#fake video are tiled next to each other, approximating the input the PaigeAI MIL
#code expects (a pathology "slide" made of many tiles). Every mosaic holds as many
#real AND fake crops as possible, so each one plays the role of a tumour slide.
Frames_per_row = 2 * 5  # tiles per mosaic row (5 real/fake pairs)
Max_Frame = 2 * 50  # tiles per mosaic: 50 real + 50 fake crops; kept small for lack of disk space

mpg_inputdir = '../train_sample_videos/'  # NOTE(review): not referenced by the code in this chunk
#mpg_inputdir='../difficult_videos/'
#This function processes the videos to create a Mosaic
def _clamp_window(center, length, size=224):
    """Return (start, end) of a ``size``-pixel window centred on ``center``.

    The window is shifted (not shrunk) so it lies inside ``[0, length)``,
    which requires ``length >= size``.
    """
    start = center - size // 2
    end = start + size
    if start < 0:
        start, end = 0, size
    elif end > length:
        start, end = length - size, length
    return start, end


def _build_mosaic(real_filename_path, fake_filename_path, strOutFilename, bDispFlg=False):
    """Tile co-located real/fake crops from a video pair into a mosaic and save it.

    Frames are read in lockstep from both videos. For each frame pair in which
    ``fn_find_fake_area`` finds a differing region, a 224x224 crop around that
    point is taken from both frames and pasted into the mosaic (real tile, then
    fake tile). When ``Max_Frame`` tiles have been placed the mosaic is written
    to ``strOutFilename``.

    Returns:
        (lstIndices, lstTileTargets): the top-left (x, y) of every tile and its
        label (0 = real crop, 1 = fake crop), if the mosaic was completed and
        written. Returns None if either video ran out of frames (or could not
        be opened) first.
    """
    tile = 224
    n_rows = -(-Max_Frame // Frames_per_row)  # ceil, so a partial last row still fits
    mosaic_frm = Image.new('RGB', (Frames_per_row * tile, n_rows * tile))
    lstIndices = []
    lstTileTargets = []

    real_mpg_File = cv2.VideoCapture(real_filename_path)
    fake_mpg_File = cv2.VideoCapture(fake_filename_path)
    try:
        while fake_mpg_File.isOpened() and real_mpg_File.isOpened():
            ret_real, bgr_real = real_mpg_File.read()
            ret_fake, bgr_fake = fake_mpg_File.read()
            # Check the read status BEFORE touching the frames: at end of stream
            # the frame is None and cvtColor would raise.
            if not (ret_real and ret_fake):
                return None

            # SSIM runs on greyscale taken directly from OpenCV's BGR frames;
            # the RGB copies are what get pasted into the PIL mosaic.
            frame_real_grey = cv2.cvtColor(bgr_real, cv2.COLOR_BGR2GRAY)
            frame_fake_grey = cv2.cvtColor(bgr_fake, cv2.COLOR_BGR2GRAY)
            frame_real = cv2.cvtColor(bgr_real, cv2.COLOR_BGR2RGB)
            frame_fake = cv2.cvtColor(bgr_fake, cv2.COLOR_BGR2RGB)
            nR, nC = frame_real.shape[:2]
            if nR < tile or nC < tile:
                print("Frame smaller than a tile, skipping video:", nR, nC)
                return None

            (ptx_fk, pty_fk) = fn_find_fake_area(frame_real_grey, frame_fake_grey)
            print("Fake pts:", ptx_fk, pty_fk)
            if ptx_fk is None or pty_fk is None:
                continue

            #Window of tile x tile pixels around the detected point, shifted
            #inwards when the point is too close to a frame edge
            pty_strt, pty_end = _clamp_window(pty_fk, nR, tile)
            ptx_strt, ptx_end = _clamp_window(ptx_fk, nC, tile)
            print("Fake pts:", ptx_strt, ptx_end, pty_strt, pty_end)

            im_pil_rl = Image.fromarray(frame_real[pty_strt:pty_end, ptx_strt:ptx_end])
            im_pil_fk = Image.fromarray(frame_fake[pty_strt:pty_end, ptx_strt:ptx_end])

            #Place the real crop then the fake crop; positions come from the tile
            #index, so this also works when Frames_per_row is odd
            for im_pil, target in ((im_pil_rl, 0), (im_pil_fk, 1)):
                nTile = len(lstIndices)
                nXIndex = (nTile % Frames_per_row) * tile
                nYIndex = (nTile // Frames_per_row) * tile
                mosaic_frm.paste(im_pil, (nXIndex, nYIndex))
                lstIndices.append((nXIndex, nYIndex))
                lstTileTargets.append(target)
            print(*lstIndices[-2])

            if bDispFlg:
                plt.figure()
                plt.imshow(im_pil_rl)

            if len(lstIndices) < Max_Frame: #Process only upto a maximum of frames defined
                print("****Frame:", len(lstIndices))
                continue
            print("Writing file:", strOutFilename)
            write_image(strOutFilename, mosaic_frm)
            return lstIndices, lstTileTargets
        return None
    finally:
        # Always free the decoders, including on early return or exception.
        real_mpg_File.release()
        fake_mpg_File.release()


def process_video_mosaic_ssim(json_filename_lst, lstFiles, output_dir=None,
                              intermed_path="MIL_data_dict_train_intermed"):
    """Build one mosaic PNG per selected FAKE video and return an MIL library dict.

    A mosaic is built for every video that meets all of these conditions:
    it is listed in the metadata files, it appears in ``lstFiles``, it is
    labelled FAKE, and its original (REAL) video exists in ``train_dir``.
    Paired crops around the region where the two videos differ are tiled
    into the mosaic and saved under ``output_dir``. Videos that do not yield
    ``Max_Frame`` tiles produce no mosaic and contribute nothing to the
    returned lists, so all per-slide lists stay aligned with "slides".

    Args:
        json_filename_lst: metadata JSON paths; each maps a video filename to a
            dict with a 'label' key and, for FAKE videos, an 'original' key.
        lstFiles: video filenames to process (e.g. one side of the split).
        output_dir: where mosaics are written; created (with parents) if
            missing. Defaults to 'MosaicFiles_Strat_RandSeed42_<Max_Frame>/'.
        intermed_path: checkpoint path; the current dict is re-saved there with
            torch.save after each new mosaic. Pass a distinct path when
            processing a second split so the first split's checkpoint is kept.

    Returns:
        dict with keys:
            "slides": mosaic paths.
            "grid": per slide, the list of tile top-left (x, y) coordinates.
            "targets": per-TILE labels (0 = real crop, 1 = fake crop),
                concatenated over slides.
            "mult": 0 per slide.
            "level": 1 per slide.
        NOTE(review): targets are per tile, not per slide -- confirm this is
        what the training script expects.
    """
    lstSlides = []
    lstGrid = []
    lstTargets = []
    lstMult = []
    lstLevel = []
    # The dict holds references to the lists above, so saving it always
    # reflects the current progress.
    resDict = {
        "slides": lstSlides,
        "grid": lstGrid,
        "targets": lstTargets,
        "mult": lstMult,
        "level": lstLevel
    }

    if output_dir is None:
        output_dir = 'MosaicFiles_Strat_RandSeed42_' + str(Max_Frame) + '/'
    os.makedirs(output_dir, exist_ok=True)

    wanted = set(lstFiles)  # O(1) membership instead of scanning the list per entry

    #processing the files listed in the json files, can be skipped if the filenames are stored in the first instance
    for json_filename in json_filename_lst:
        with open(json_filename) as json_file:
            data = json.load(json_file)
        for Filename, meta in data.items():
            print('===================================================')
            print(Filename, '->', meta)

            #Processing testing and training separately for only the files selected earlier
            if Filename not in wanted:
                print("Skipping as not in list...")
                continue
            print("Processing:", Filename, meta)

            # Only FAKE videos whose original is on disk become mosaics.
            if meta['label'] != 'FAKE':
                continue
            real_filename_path = train_dir + meta['original']
            if not os.path.exists(real_filename_path):
                continue

            strOutFilename = os.path.join(output_dir, os.path.splitext(Filename)[0] + '.png')
            result = _build_mosaic(real_filename_path, train_dir + Filename, strOutFilename)
            if result is None:
                continue

            lstIndices, lstTileTargets = result
            lstSlides.append(strOutFilename)
            lstGrid.append(lstIndices)
            lstTargets.extend(lstTileTargets)
            lstMult.append(0)
            lstLevel.append(1)
            torch.save(resDict, intermed_path)

    print(lstSlides)
    print(lstGrid)
    print(lstTargets)

    return resDict

In [None]:
#Build the training-split mosaics and their MIL library dictionary
dictMILTrain = process_video_mosaic_ssim(json_filenames, lstFiles=lstSlides_train)

In [None]:
torch.save(dictMILTrain, f="MIL_data_dict_train")  # persist the training MIL library

In [None]:
#Build the test-split mosaics and their MIL library dictionary
dictMILTest = process_video_mosaic_ssim(json_filenames, lstFiles=lstSlides_test)

In [None]:
torch.save(dictMILTest, f="MIL_data_dict_test")  # persist the test MIL library