# Check GPU resources

In [None]:
# gpu resources: run nvidia-smi to report the GPUs visible from this machine.
# `!` is an IPython/Jupyter shell escape, so this line is not plain Python.
! nvidia-smi

# Running ManTraNet on the OpenMFC dataset

In [None]:
import os
import numpy as np
import pandas as pd
import cv2
import requests
import sys

from PIL import Image
from io import BytesIO
from matplotlib import pyplot

In [None]:
# ManTraNet locations, all relative to the current working directory.
manTraNet_root = './'
manTraNet_srcDir, manTraNet_modelDir = (
    os.path.join(manTraNet_root, subdir) for subdir in ('src', 'pretrained_weights')
)
# Prepend the source directory so its modules (e.g. modelCore) become importable.
sys.path.insert(0, manTraNet_srcDir)

# Set up dataset paths

In [None]:
# OpenMFC dataset layout: <root>/openmfc_data/OpenMFC2020/<part 1 of 27>/{indexes,probe}
mfc_data = os.path.join(manTraNet_root, 'openmfc_data')
openmfc_2020 = os.path.join(mfc_data, 'OpenMFC2020')

# Only the first of the 27 parts is referenced here.
openmfc_2020_p1 = os.path.join(openmfc_2020, 'OpenMFC20_Image_Ver1-part001of27')
openmfc_2020_p1_indexes, openmfc_2020_p1_probe = (
    os.path.join(openmfc_2020_p1, leaf) for leaf in ('indexes', 'probe')
)

# Load the dataset description (index file)

In [None]:
# Path of the IMDL index CSV shipped with part 1 of the dataset.
imdl_index_path = os.path.join(
    openmfc_2020_p1_indexes,
    'OpenMFC20_Image-IMDL-index.csv',
)

In [None]:
df = pd.read_csv(
    imdl_index_path,
    sep='|',  # the index file is pipe-delimited, not comma-delimited
)

In [None]:
# Peek at the first five rows of the index.
df.head(5)

In [None]:
# Distinct task IDs present in the index.
df.loc[:, 'TaskID'].unique()

# Build the dataset list

In [None]:
# File names of the probe images to analyse.
# os.listdir returns entries in arbitrary order, so the list is sorted to make
# the preview loop (which only shows the first few samples) reproducible.
# Hidden entries (e.g. .DS_Store) and sub-directories are skipped because they
# are not images and would make the image readers fail.
dataset = sorted(
    name
    for name in os.listdir(openmfc_2020_p1_probe)
    if not name.startswith('.')
    and os.path.isfile(os.path.join(openmfc_2020_p1_probe, name))
)

# Load a pretrained ManTraNet model

In [None]:
import modelCore

# Pretrained weight set 4 is the default ManTraNet model.
pretrain_index = 4
manTraNet = modelCore.load_pretrain_model_by_index(pretrain_index, manTraNet_modelDir)

In [None]:
# ManTraNet architecture.
# Model.summary() (Keras API, judging by summary(line_length=...) / get_layer)
# prints the table itself and returns None, so it is called directly rather
# than wrapped in print(), which would add a stray "None" line.
manTraNet.summary(line_length=120)

In [None]:
# Image Manipulation Classification Network: the 'Featex' sub-model of ManTraNet.
IMCFeatex = manTraNet.get_layer('Featex')
# summary() prints the table and returns None, so no print() wrapper.
IMCFeatex.summary(line_length=120)

# Test on samples from OpenMFC

In [None]:
from datetime import datetime 
def read_rgb_image( image_file ) :
    """Load ``image_file`` as an RGB array.

    OpenCV decodes colour images in BGR channel order; the last axis is
    reversed to obtain RGB.

    Raises:
        OSError: if OpenCV cannot read or decode the file. ``cv2.imread``
            reports failure by returning None instead of raising, which would
            otherwise surface as an opaque TypeError on the channel slice.
    """
    bgr = cv2.imread( image_file, 1 )  # flag 1 == cv2.IMREAD_COLOR (3-channel)
    if bgr is None:
        raise OSError( 'cv2.imread could not read image: {}'.format( image_file ) )
    return bgr[..., ::-1]
    
def decode_an_image_array( rgb, manTraNet ) :
    """Run ManTraNet on a single image array and time the prediction.

    Args:
        rgb: image array, assumed H x W x 3 with values in 0..255 (the
            ``/255`` scaling below implies that range).
        manTraNet: model exposing ``predict``; it is called on a batch of one.

    Returns:
        (mask, elapsed): ``mask`` is ``predict(x)[0, ..., 0]``, i.e. channel 0
        of the single batch item; ``elapsed`` is a ``datetime.timedelta``
        covering only the ``predict`` call.
    """
    # Map 0..255 to [-1, 1] and add a leading batch axis of size 1.
    x = np.expand_dims( rgb.astype('float32') / 255. * 2 - 1, axis=0 )
    t0 = datetime.now()
    y = manTraNet.predict(x)[0, ..., 0]
    t1 = datetime.now()
    return y, t1 - t0

def decode_an_image_file( image_file, manTraNet ) :
    """Read ``image_file`` and run ManTraNet on it.

    Returns a 3-tuple ``(rgb, mask, seconds)``: the RGB array produced by
    ``read_rgb_image``, the mask produced by ``decode_an_image_array``, and
    that call's elapsed prediction time converted to float seconds.
    """
    image = read_rgb_image(image_file)
    prediction, elapsed = decode_an_image_array(image, manTraNet)
    return image, prediction, elapsed.total_seconds()

In [None]:
def slice_and_decode(filename, manTraNet, number_tiles=8):
    """Run ManTraNet tile by tile and stitch the per-tile masks back together.

    The image is split with ``image_slicer``; each tile is decoded separately
    via ``decode_an_image_file`` and its predicted mask is placed at that
    tile's number/position/coords before all masks are joined.

    Args:
        filename: path of the image to analyse.
        manTraNet: model passed through to ``decode_an_image_file``.
        number_tiles: number of tiles requested from ``image_slicer.slice``
            (default 8, the previously hard-coded value).

    Returns:
        (mask, total_ptime): ``mask`` is the image returned by
        ``image_slicer.join``; ``total_ptime`` is the sum, in seconds, of the
        per-tile prediction times.
    """
    # image_slicer is used here but not imported in the notebook's import
    # cell; import it locally so this function does not raise NameError.
    import image_slicer

    # With its default save behaviour, image_slicer.slice writes every tile to
    # disk (by default next to ``filename``) and records the path on
    # ``tile.filename``; the loop below reads the tiles back from those files.
    tiles = image_slicer.slice(filename, number_tiles=number_tiles)
    mask_tiles = []
    total_ptime = 0
    try:
        for tile in tiles:
            _, mask, ptime = decode_an_image_file(tile.filename, manTraNet)
            # Clip before the uint8 cast so out-of-range scores cannot wrap around.
            mask_image = Image.fromarray(np.uint8(np.clip(mask, 0, 1) * 255))
            mask_tiles.append(image_slicer.Tile(
                image=mask_image,
                number=tile.number,
                position=tile.position,
                coords=tile.coords,
            ))
            total_ptime += ptime
    finally:
        # Delete the temporary tile files so they do not accumulate in the
        # dataset directory, where a later os.listdir would treat them as samples.
        for tile in tiles:
            tile_path = getattr(tile, 'filename', None)
            if tile_path and os.path.exists(tile_path):
                os.remove(tile_path)
    return image_slicer.join(tuple(mask_tiles)), total_ptime

In [None]:
# modified for OpenMFC: preview ManTraNet (tiled) predictions for the first
# few probe images.
max_samples = 10  # raise this (e.g. to len(dataset)) to test on more images
for sample in dataset[:max_samples]:

    sample_probe = os.path.join(openmfc_2020_p1_probe, sample)

    # Manipulation detection using slicing. For a single whole-image pass use
    # decode_an_image_file(sample_probe, manTraNet) instead.
    mask, ptime = slice_and_decode(sample_probe, manTraNet)

    # show results
    pyplot.figure(figsize=(25, 25))

    pyplot.subplot(2, 2, 1)
    sample_image = pyplot.imread(sample_probe)
    pyplot.imshow(sample_image)
    pyplot.title('Forged Image (ManTra-Net Input)')

    pyplot.subplot(2, 2, 2)
    pyplot.imshow(mask, cmap='gray')
    pyplot.title('Predicted Mask (ManTra-Net Output w/ Slicing)')

    # TODO: subplots 3 and 4 appear reserved for ground-truth masks (they were
    # titled for CG-1050 before this was adapted); OpenMFC reference masks are
    # not loaded yet.

    # Report the actual input shape (previously the literal text 'rgb.shape').
    pyplot.suptitle('Decoded {} of size {} for {:.2f} seconds'.format(
        os.path.basename(sample_probe), sample_image.shape, ptime))

    pyplot.show()
    # Release the figure so memory does not grow with each previewed sample.
    pyplot.close()