# check gpu resources

In [None]:
# gpu resources
# `!` is the IPython/Jupyter shell escape, so this cell only works inside a notebook;
# it runs the `nvidia-smi` command line tool and shows its report in the cell output
! nvidia-smi

# testing ManTraNet on different datasets

In [None]:
import os
import numpy as np
import pandas as pd
import cv2
import requests
import sys
import image_slicer

from PIL import Image
from io import BytesIO
from matplotlib import pyplot

In [None]:
# Locations inside the ManTraNet checkout (relative to the working directory)
manTraNet_root = './'
manTraNet_modelDir = os.path.join(manTraNet_root, 'pretrained_weights')
manTraNet_srcDir = os.path.join(manTraNet_root, 'src')

# Put the source directory first on the import path so that modules stored
# there can be imported by later cells
sys.path.insert(0, manTraNet_srcDir)

# setup dataset paths

In [None]:
# sample ManTraNet datasets included in repo
manTraNet_dataDir = os.path.join( manTraNet_root, 'data' )
sample_file = os.path.join( manTraNet_dataDir, 'samplePairs.csv' )
print(sample_file)
assert os.path.isfile( sample_file ), "ERROR: can NOT find sample data, check `manTraNet_root`"
# Each non-empty row of the CSV becomes one list of comma-separated fields.
# Blank rows are skipped: they would otherwise turn into [''] entries, which
# os.path.join() later resolves to the data directory itself.
with open( sample_file ) as IN :
    sample_pairs = [ line.strip().split(',') for line in IN if line.strip() ]
L = len(sample_pairs)
print("INFO: in total, load", L, "samples")
    
def get_a_random_pair() :
    """Pick one row of `sample_pairs` at random and return its file paths.

    The row index is drawn with `np.random.randint(0, L)`.

    Returns:
        tuple of str: every field of the chosen row joined onto
        `manTraNet_dataDir`, in the order the fields appear in the row.
    """
    idx = np.random.randint(0,L)
    # Build a tuple rather than a one-shot generator: callers can still unpack
    # or iterate it, and can additionally index it, take len() and reuse it.
    return tuple( os.path.join( manTraNet_dataDir, this ) for this in sample_pairs[idx] )

In [None]:
# CG-1050 dataset
# layout: <manTraNet_root>/openmfc_data/CG_1050/{DESCRIPTION,MASK,ORIGINAL,TAMPERED}
mfc_data = os.path.join(manTraNet_root, 'openmfc_data')
cg_1050 = os.path.join(mfc_data, 'CG_1050')

# one path per sub-directory of the dataset root, unpacked in the listed order
cg_1050_description, cg_1050_mask, cg_1050_original, cg_1050_tampered = (
    os.path.join(cg_1050, sub_dir)
    for sub_dir in ('DESCRIPTION', 'MASK', 'ORIGINAL', 'TAMPERED')
)

# load dataset description

In [None]:
# read the CSV describing the CG-1050 files into a DataFrame
description_path = os.path.join(cg_1050_description, 'Dataset_description_v2.csv')
df = pd.read_csv(filepath_or_buffer=description_path)

In [None]:
# preview the first rows of the description table (rendered as the cell output)
df.head()

In [None]:
# show every description row whose original photo name is 'Im_12'
df.loc[df['PHOTO NAME (Original)'] == 'Im_12']

# build dataset list

In [None]:
# Derive "<folder>/<file name>" columns for the tampered images and their masks.
# Each entry is (new column, folder column, file-name column).
for _target, _folder, _photo in (
    ("tamper_path", "FOLDER NAME", "PHOTO NAME (Tampered)"),
    ("mask_path", "FOLDER NAME.1", "PHOTO NAME (Mask)"),
):
    df[_target] = df[_folder] + '/' + df[_photo]

In [None]:
class Sample:
    """Plain record for one original image and its tampered versions.

    The constructor stores its arguments unchanged on the instance:
        name:     name of the original photo
        tampered: tampered-image entries
        masks:    mask entries
        methods:  tampering-method entries
    """

    def __init__(self, name, tampered, masks, methods):
        self.name, self.tampered = name, tampered
        self.masks, self.methods = masks, methods

In [None]:
# Build one Sample per original image, holding all of its tampered versions.
name_col = 'PHOTO NAME (Original)'
unique_images = list(df[name_col].unique())

dataset = []
# A single groupby pass replaces three full-frame boolean scans per image.
# sort=False yields the groups in order of first appearance (the same order as
# `unique_images`), and rows keep their original order inside each group.
for image, rows in df.groupby(name_col, sort=False):
    sample = Sample(name=image,
                    tampered=list(rows['tamper_path']),
                    masks=list(rows['mask_path']),
                    methods=list(rows['TAMPERING TYPE']))
    dataset.append(sample)

# dataset should have 100 original images (there are more tampered and masks)
assert len(dataset) == 100

In [None]:
# Masks for images 1 to 15 are organized differently and their file names do
# not match the dataset description file, so only images numbered above 15 are
# kept. The number is the second '_'-separated field of the sample name.

dataset = [entry for entry in dataset if int(entry.name.split('_')[1]) > 15]

In [None]:
# inspect the first sample: one attribute per output line
print(dataset[0].name,
      dataset[0].tampered,
      dataset[0].masks,
      dataset[0].methods,
      sep='\n')

In [None]:
# the three per-sample lists should have the same length; print each on its own line
print(*(len(entries) for entries in (dataset[0].tampered,
                                     dataset[0].masks,
                                     dataset[0].methods)),
      sep='\n')

In [None]:
def get_a_random_sample():
    """Draw a random sample from `dataset` and one of its tampered versions.

    Returns:
        tuple: (original_path, mask_path, tampered_path, method), where the
        original path is `<cg_1050_original>/<name>.jpg`, the mask and tampered
        paths are joined onto `cg_1050_mask` / `cg_1050_tampered`, and `method`
        is the matching entry of the sample's `methods` list.
    """
    chosen = dataset[np.random.randint(0, len(dataset))]

    # each tampered image has a mask, so one index addresses masks, tampered
    # and methods alike
    variant = np.random.randint(0, len(chosen.tampered))

    return (
        os.path.join(cg_1050_original, chosen.name + '.jpg'),
        os.path.join(cg_1050_mask, chosen.masks[variant]),
        os.path.join(cg_1050_tampered, chosen.tampered[variant]),
        chosen.methods[variant],
    )

# Load A Pretrained ManTraNet Model

In [None]:
# NOTE(review): `modelCore` is presumably found through `manTraNet_srcDir`, which an
# earlier cell inserts into sys.path — confirm; run that cell first
import modelCore # try running nvidia-smi at the top after the model loads to see model memory requirements
# load the pretrained model with index 4 from the weights directory; what index 4
# selects is defined inside modelCore and is not visible in this notebook
manTraNet = modelCore.load_pretrain_model_by_index( 4, manTraNet_modelDir )

In [None]:
# ManTraNet Architecture
# summary() writes the layer table to the output itself and returns None
# (assuming `manTraNet` is a Keras model, as its summary/get_layer/predict API
# suggests), so wrapping it in print() only appended a stray "None" line.
manTraNet.summary(line_length=120)

In [None]:
# Image Manipulation Classification Network
IMCFeatex = manTraNet.get_layer('Featex')
# summary() writes the layer table to the output itself and returns None
# (assuming 'Featex' is a Keras model nested inside `manTraNet`), so wrapping
# it in print() only appended a stray "None" line.
IMCFeatex.summary(line_length=120)

# test samples from CG-1050

In [None]:
from datetime import datetime 
def read_rgb_image( image_file ) :
    """Load an image file and return it with its channels in RGB order.

    Args:
        image_file: path of the image to read.

    Returns:
        The array decoded by `cv2.imread` with the last (channel) axis
        reversed, i.e. converted from OpenCV's BGR order to RGB.

    Raises:
        OSError: if OpenCV cannot read or decode `image_file`.
    """
    # flag 1 asks OpenCV for a 3-channel colour image
    bgr = cv2.imread( image_file, 1 )
    if bgr is None :
        # imread reports failure by returning None instead of raising, which
        # previously surfaced as an opaque TypeError on the slice below
        raise OSError( "cannot read image file: {}".format( image_file ) )
    return bgr[...,::-1]
    
def decode_an_image_array( rgb, manTraNet ) :
    """Run the model on a single image array and time the prediction.

    Args:
        rgb: image array; values are converted to float32 and rescaled as
            value/255*2-1 (so 0 maps to -1 and 255 maps to 1).
        manTraNet: object exposing a `predict` method.

    Returns:
        (mask, elapsed): `mask` is index 0 along both the first and the last
        axis of the prediction; `elapsed` is the `datetime.timedelta` measured
        around the `predict` call.
    """
    scaled = rgb.astype('float32')/255.*2-1
    # predict() receives a batch, so prepend a length-1 axis
    batch = scaled[np.newaxis, ...]
    started = datetime.now()
    mask = manTraNet.predict( batch )[0,...,0]
    elapsed = datetime.now() - started
    return mask, elapsed

def decode_an_image_file( image_file, manTraNet ) :
    """Read an image file and run the model on it.

    Args:
        image_file: path handed to `read_rgb_image`.
        manTraNet: model handed to `decode_an_image_array`.

    Returns:
        (rgb, mask, seconds): the loaded image, the predicted mask and the
        prediction time converted to seconds.
    """
    image = read_rgb_image( image_file )
    predicted_mask, elapsed = decode_an_image_array( image, manTraNet )
    return image, predicted_mask, elapsed.total_seconds()

In [None]:
def slice_and_decode(filename, manTraNet, number_tiles=8):
    """Run the model tile by tile and stitch the per-tile masks together.

    Args:
        filename: path of the image to analyse.
        manTraNet: model handed to `decode_an_image_array`.
        number_tiles: number of tiles requested from `image_slicer.slice`;
            defaults to 8, the value that used to be hard-coded.

    Returns:
        (mask, seconds): the image assembled by `image_slicer.join` from the
        per-tile masks, and the summed prediction time in seconds.
    """
    # save=False keeps the tiles in memory. The default writes one tile file
    # per call next to the source image, which litters the dataset directory
    # and fails when that directory is read-only.
    tiles = image_slicer.slice(filename, number_tiles=number_tiles, save=False)
    mask_tiles = []
    total_ptime = 0
    for tile in tiles:
        # same pixels the saved tile would have given back: 3-channel RGB
        rgb = np.asarray(tile.image.convert('RGB'))
        mask, elapsed = decode_an_image_array(rgb, manTraNet)
        # scale to 8-bit for PIL (assumes mask values lie in [0, 1] — TODO confirm)
        mask_image = Image.fromarray(np.uint8(mask * 255))
        # reuse the tile's number/position/coords so join() puts the mask back
        # where the tile came from
        mask_tiles.append(image_slicer.Tile(image=mask_image, number=tile.number,
                                            position=tile.position, coords=tile.coords))
        total_ptime += elapsed.total_seconds()
    res = image_slicer.join(tuple(mask_tiles))
    return res, total_ptime

In [None]:
# modified for CG-1050 dataset

num_samples = 1 # increase to test on more images

for k in range( num_samples ) :
    # draw a random (original, mask, tampered, method) sample
    orig_file, mask_file, tampered_file, tamper_method = get_a_random_sample()

    # reference images: the untouched original and the tampered version
    ori = read_rgb_image( orig_file )
    rgb = read_rgb_image( tampered_file )

    # manipulation detection using the slice method; the whole-image
    # alternative is decode_an_image_file( tampered_file, manTraNet )
    mask, ptime = slice_and_decode( tampered_file, manTraNet )

    # show results as a 2x2 grid: (image, colormap, title) per panel
    pyplot.figure( figsize=(25,25) )
    actual_mask = pyplot.imread( mask_file )
    panels = (
        ( ori, None, 'Original Image' ),
        ( rgb, None, 'Forged Image (ManTra-Net Input)' ),
        ( mask, 'gray', 'Predicted Mask (ManTra-Net Output)' ),
        ( actual_mask, 'gray', 'Actual Mask (CG-1050)' ),
    )
    for slot, ( picture, colormap, caption ) in enumerate( panels, start=1 ) :
        pyplot.subplot( 2, 2, slot )
        pyplot.imshow( picture, cmap=colormap )
        pyplot.title( caption )

    pyplot.suptitle('Decoded {} of size {} for {:.2f} seconds'.format( os.path.basename( tampered_file ), rgb.shape, ptime ) )

    pyplot.show()