In [1]:
import torch
from PIL import Image
import torchvision
import torch.nn as nn
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [2]:
class EVAL(nn.Module):
    """Two-layer MLP head that scores the consistency of a pair of patches.

    ``exif_model(x1, x2)`` produces a feature tensor for the patch pair. Its last
    dimension must equal ``exif_dim``, because it is fed to
    ``nn.Linear(exif_dim, middle_dim)``. A ReLU and a final
    ``nn.Linear(middle_dim, 1)`` reduce it to one score.

    Args:
        exif_model: callable/module invoked as ``exif_model(x1, x2)``.
        exif_dim: width of the last dimension of ``exif_model``'s output.
        middle_dim: hidden width of the MLP.
    """

    def __init__(self, exif_model, exif_dim, middle_dim):
        super().__init__()
        self.exif_model = exif_model
        self.layer1 = nn.Linear(exif_dim, middle_dim)
        self.layer2 = nn.Linear(middle_dim, 1)

    def forward(self, x1, x2):
        """Return the score for the pair ``(x1, x2)``; the last dimension is 1."""
        features = self.exif_model(x1, x2)
        # F.relu is the functional ReLU; the original F.ReLu does not exist
        hidden = F.relu(self.layer1(features))
        return self.layer2(hidden)

In [3]:
def getConsistencyScore(p1, p2):
    """Measure how much two mask patches disagree.

    For 0/1 masks this is the number of differing pixels. In general it is the
    sum of absolute element-wise differences.

    Args:
        p1, p2: same-shaped masks (torch tensors or numpy arrays).

    Returns:
        A scalar (0-d tensor / numpy scalar): the total absolute difference.

    NOTE(review): with unsigned dtypes (e.g. uint8) ``p1 - p2`` wraps around;
    convert such masks to a signed/float dtype first.
    """
    # .sum() reduces over every axis; the builtin sum() only reduced the first
    # axis and returned a vector for 2-D patches.
    return abs(p1 - p2).sum()

In [9]:
def evaluatePatch(p1, image, size, eval_model):
    """Accumulate ``eval_model``'s score of ``p1`` against every tile of ``image``.

    ``image`` is cut into non-overlapping ``size`` x ``size`` tiles along its
    first two axes, visited row by row. Border rows/columns that do not fill a
    whole tile are skipped. No tile is excluded from the comparison.

    Returns:
        The running total of ``eval_model(p1, tile)`` over all tiles (0 if none).
    """
    n_rows = image.shape[0] // size
    n_cols = image.shape[1] // size
    total = 0
    for row in range(n_rows):
        top = row * size
        for col in range(n_cols):
            left = col * size
            # every visited tile lies fully inside the image, so no clamping
            tile = image[top:top + size, left:left + size]
            total += eval_model(p1, tile)
    return total

In [7]:
def getConsistencyMap(image, size=20, eval_model=None):
    """Build a grid of per-tile consistency scores for ``image``.

    ``image`` is cut into non-overlapping ``size`` x ``size`` tiles along its
    first two axes; leftover border pixels are ignored. Entry ``[i][j]`` is
    ``evaluatePatch`` of tile ``(i, j)`` against every tile. That is
    O(tiles^2) calls to ``eval_model``.

    Args:
        image: array/tensor indexed as ``image[row, col, ...]``.
        size: tile side length in pixels.
        eval_model: pairwise scorer passed to ``evaluatePatch``. It is required.
            It has a ``None`` default only because it follows the defaulted
            ``size``; a required parameter after a defaulted one is a
            SyntaxError.

    Returns:
        A list of ``height // size`` rows, each a list of ``width // size`` scores.

    Raises:
        TypeError: if ``eval_model`` is not supplied.
    """
    if eval_model is None:
        raise TypeError("getConsistencyMap() missing required argument: 'eval_model'")
    tiles_down = image.shape[0] // size
    tiles_across = image.shape[1] // size
    consistency_map = []
    for i in range(tiles_down):
        row = []
        for j in range(tiles_across):
            # tiles always fit inside the image, so no clamping is needed
            patch = image[i * size:(i + 1) * size, j * size:(j + 1) * size]
            row.append(evaluatePatch(patch, image, size, eval_model))
        consistency_map.append(row)
    return consistency_map

In [6]:
def _total_score(values):
    """Sum every entry of a (possibly nested) container of scores."""
    if isinstance(values, (list, tuple)):
        return sum(_total_score(v) for v in values)
    if isinstance(values, (torch.Tensor, np.ndarray)):
        return values.sum()
    return values


def ifSliced(consistency_map, threshold):
    """Decide whether an image is flagged as tampered (presumably "spliced").

    Every score in ``consistency_map`` is summed, and the image is flagged when
    the total exceeds ``threshold``. The map may be the list of rows returned by
    ``getConsistencyMap``, a flat list of scores, or a 2-D tensor/array. The
    builtin ``sum()`` used before raised TypeError on a list of rows.

    Args:
        consistency_map: nested lists/tuples of numbers or tensors, or an array.
        threshold: cut-off on the summed score.

    Returns:
        ``True`` if the total score is greater than ``threshold``.
    """
    return bool(_total_score(consistency_map) > threshold)

In [12]:
# Training
import random


def train(eval_model, size, loss_function, optimizer=None, dataloader=None):
    """Run one training pass of ``eval_model`` over ``dataloader``.

    For every ``(image, mask)`` item, two random ``size`` x ``size`` patches are
    cropped from the image, together with the matching mask crops. The model's
    score for the patch pair is regressed onto ``getConsistencyScore`` of the two
    mask crops.

    Args:
        eval_model: module called as ``eval_model(p1, p2)``.
        size: side length of the square patches.
        loss_function: criterion called as ``loss_function(output, target)``.
        optimizer: optimizer over ``eval_model``'s parameters. If omitted, the
            notebook-global ``optimizer`` is used, as before.
        dataloader: iterable of ``(image, mask)`` pairs. If omitted, the
            notebook-global ``train_dataloader`` is used, as before.

    Returns:
        Mean loss over the processed items (0.0 if ``dataloader`` is empty).

    Raises:
        ValueError: if no optimizer is passed and no global ``optimizer`` exists.
    """
    if optimizer is None:
        # legacy fallback: the original cell read a global named `optimizer`
        optimizer = globals().get("optimizer")
        if optimizer is None:
            raise ValueError("train() needs an optimizer (none passed, no global `optimizer`)")
    if dataloader is None:
        dataloader = train_dataloader  # NOTE(review): not defined anywhere in this notebook
    eval_model.train()

    total_loss = 0.0
    n_steps = 0
    for image, mask in dataloader:
        optimizer.zero_grad()
        # NOTE(review): dims 0/1 are treated as height/width, as in
        # getConsistencyMap. A batching DataLoader prepends a batch dimension,
        # which would break this indexing. TODO confirm train_dataloader yields
        # single (image, mask) samples.
        height, width = image.shape[0], image.shape[1]

        # random top-left corners of the two patches
        x1 = random.randint(0, height - size)
        y1 = random.randint(0, width - size)
        p1 = image[x1:x1 + size, y1:y1 + size]
        m1 = mask[x1:x1 + size, y1:y1 + size]

        x2 = random.randint(0, height - size)
        y2 = random.randint(0, width - size)
        p2 = image[x2:x2 + size, y2:y2 + size]
        m2 = mask[x2:x2 + size, y2:y2 + size]

        output = eval_model(p1, p2)
        # MSE-style criteria need a floating target on the output's device
        target = torch.as_tensor(
            getConsistencyScore(m1, m2), dtype=output.dtype, device=output.device
        )

        # one fresh loss per step: accumulating `loss +=` kept earlier graphs
        # attached and made the second backward() fail
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_steps += 1

    return total_loss / n_steps if n_steps else 0.0
    
        
SEED = 0
random.seed(SEED)        # patch corners in train() come from `random`
torch.manual_seed(SEED)  # weight initialisation

exif_model = EXIF()  # NOTE(review): EXIF is neither defined nor imported in this notebook
exif_dim = 10        # must match the last dimension of exif_model's output (sizes EVAL.layer1)
middle_dim = 3       # hidden width of EVAL's MLP
size = 20            # patch side length in pixels

eval_model = EVAL(exif_model, exif_dim, middle_dim)
loss_function = nn.MSELoss(reduction='sum')
# train() uses a global `optimizer` when none is passed; it was never created
optimizer = torch.optim.Adam(eval_model.parameters())

# NOTE(review): train() also relies on a global `train_dataloader` that this
# notebook never builds
train(eval_model, size, loss_function)
    

In [13]:
# Evaluating
size = 20
threshold = 100  # NOTE(review): ad-hoc cut-off on the summed consistency map; TODO calibrate


def evaluate(eval_model, image, size, threshold):
    """Decide whether ``image`` is flagged as tampered according to ``eval_model``.

    Builds the tile consistency map with ``getConsistencyMap`` and thresholds
    its total with ``ifSliced``.

    Args:
        eval_model: pairwise patch scorer (e.g. an ``EVAL`` instance).
        image: image indexed as ``image[row, col, ...]``.
        size: tile side length in pixels.
        threshold: cut-off passed to ``ifSliced``.

    Returns:
        The ``ifSliced`` verdict. The original computed it and then returned None.
    """
    eval_model.eval()  # inference mode, for any dropout/batch-norm the model may contain
    # no gradients are needed, and the map makes O(tiles^2) model calls
    with torch.no_grad():
        consistency_map = getConsistencyMap(image, size, eval_model)
    return ifSliced(consistency_map, threshold)


evaluate(eval_model, image, size, threshold)  # NOTE(review): `image` is never defined in this notebook