In [1]:
import torch
from PIL import Image
import torchvision
import torch.nn as nn
from torchvision import transforms
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os, os.path

In [2]:
# two layer MLP to get the consistency score for each pair of patches
class EVAL(nn.Module):
    def __init__(self, exif_model, exif_dim, middle_dim):
        super(EVAL, self).__init__()
        self.exif_model = exif_model
        self.layer1 = nn.Linear(exif_dim, middle_dim)
        self.layer2 = nn.Linear(middle_dim, 1)

    def forward(self, x1, x2):
        x = self.exif_model(x1, x2)        
        x = self.layer1(x)
        x = F.relu(x)
        x = self.layer2(x)
        
        return x

In [3]:
class EXIF(nn.Module):
    def __init__(self, encoder, n_features, projection_dim):
        super(EXIF, self).__init__()

        self.encoder = encoder
        self.n_features = n_features

        self.projector = nn.Sequential(
            nn.Linear(2 * self.n_features, self.n_features),
            nn.ReLU(),
            nn.Linear(self.n_features, 2 * projection_dim),
            nn.ReLU(),
            nn.Linear(2 * projection_dim, projection_dim),
        )

    def forward(self, x_i, x_j):
        h_i = self.encoder(x_i)
        h_j = self.encoder(x_j)

        c = torch.cat((h_i, h_j), 1)
        z = self.projector(c)
        return torch.sigmoid(z)

In [4]:
# use number of different bit in mask to represent score
def getConsistencyScore(p1, p2):
    return sum(sum(sum(abs(p1-p2))))

In [5]:
class EVALDATASET(Dataset):
    def __init__(self, root_dir, size=20):
        self.root_dir = root_dir
        self.size = size
        self.file_names = [name for name in os.listdir(self.root_dir+'/images')]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        
        file_name = self.file_names[idx].split('.')[0]

        image = Image.open(os.path.join(self.root_dir+'/images/'+file_name+".jpg"))
        mask = Image.open(os.path.join(self.root_dir+'/masks/'+file_name+".png"))
        
        
        image = torchvision.transforms.ToTensor()(image)
        mask = torchvision.transforms.ToTensor()(mask)
                
        size = self.size
        height = image.shape[1]
        width = image.shape[2]
        
        # random pick two left top corner
        x1 = random.randint(0, height-size)
        y1 = random.randint(0, width-size)
        p1 = image[:3, x1:x1+size, y1:y1+size]
        m1 = mask[:, x1:x1+size, y1:y1+size]
        
        x2 = random.randint(0, height-size)
        y2 = random.randint(0, width-size)
        p2 = image[:3, x2:x2+size, y2:y2+size]
        m2 = mask[:, x2:x2+size, y2:y2+size]
        
        sample = {'p1': p1, 'p2':p2, 'score': getConsistencyScore(m1, m2)}
#         return p1, p2, getConsistencyScore(m1, m2)
        return sample

In [6]:
# compare one patch with all other patches
def evaluatePatch(p1, image, size, eval_model):
    score = 0
    height = image.shape[1]
    width = image.shape[2]
    for i in range(height // size):
        for j in range(width // size):
            x = size * i
            y = size * j
            xs = min(x + size, height)
            ys = min(y + size, width)
            p2 = image[:, x:xs, y:ys]
            
            p1 = p1.view(1, 3, size, size)
            p2 = p2.view(1, 3, size, size)
#             print(p1.shape)
#             print(p2.shape)
            score += eval_model(p1, p2)
            
    return score

In [14]:
# calculate consistency map
def getConsistencyMap(image, size, eval_model):
    consistency_map = []
    height = image.shape[1]
    width = image.shape[2]
    print(height, width)
    for i in range(height // size):
        row = []
        for j in range(width // size):
            x = size * i
            y = size * j
            xs = min(x + size, height)
            ys = min(y + size, width)
            p = image[:, x:xs, y:ys]
            row.append(evaluatePatch(p, image, size, eval_model).item())
            print('add',i,j)
        consistency_map.append(row)
    return consistency_map

In [18]:
# check whether an image has been sliced
def ifSliced(consistency_map, threshold):
    return sum(sum(consistency_map)) > threshold

In [19]:
def collate_fn(batch):
    images_a = []
    images_b = []
    labels = []
    for i in batch:
        images_a.append(i['p1'])
        images_b.append(i['p2'])
        labels.append(i['score'])
    return (torch.stack(images_a), torch.stack(images_b)), torch.stack(labels)

In [22]:
# Training
import random

def train(eval_model, size, loss_function):
    total_loss = 0
    counter = 0
    for ((p1, p2), score)  in train_loader:
        optimizer.zero_grad()
        
        output = eval_model(p1, p2) # evaluation model generates a consistency score of a pair of patch
        score = score.view(len(p1[:, 0, 0]), 1)
#         print(output.shape)
#         print(score.shape)
        
        loss = loss_function(output, score)
        total_loss += loss.item()
        
        loss.backward(retain_graph=True)
        optimizer.step()
        counter += 1
        print('loss', loss.item())

train_dataset = EVALDATASET('dataset/label_in_wild', 20)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=10,
    collate_fn=collate_fn
)
    
encoder = torchvision.models.resnet50(pretrained=False)
n_features = 1000
exif_model = EXIF(encoder, n_features, 8)
exif_dim = 8
middle_dim = 3
size = 20
optimizer = torch.optim.Adam(exif_model.parameters(), lr=1e-4)

eval_model = EVAL(exif_model, exif_dim, middle_dim)
loss_function = nn.MSELoss(reduction='mean')

train(eval_model, size, loss_function)
    

loss 16024.634765625
loss 21322.63671875


RuntimeError: stack expects each tensor to be equal size, but got [3, 20, 20] at entry 0 and [1, 20, 20] at entry 3

In [21]:
# Evaluating
size = 600
threshold = 100
eval_model.eval()
def evaluate(eval_model, image, size, threshold):
    consistency_map = getConsistencyMap(image, size, eval_model)
#     if_sliced = ifSliced(consistency_map, threshold)
    
    return consistency_map
    
image = Image.open(os.path.join('dataset/label_in_wild/images/im12_edit1.jpg'))
image = torchvision.transforms.ToTensor()(image)

print(evaluate(eval_model, image, size, threshold))

900 1600
add 0 0
add 0 1
[[-0.6183071136474609, -0.6184686422348022]]
