In [1]:
import os
from collections import namedtuple

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import Dataset
from torchvision import models
from torchvision.transforms import transforms

In [2]:
# Directories handed to TwoAFCDataset below; each one is scanned recursively
# by get_paths and the resulting file lists are paired up by sorted position.
ref_dir_path = 'dataset/2afc/train/mix/ref/'
p0_dir_path = 'dataset/2afc/train/mix/p0/'
p1_dir_path = 'dataset/2afc/train/mix/p1/'
# Scanned for '.npy' files (mode='np') rather than images.
judge_dir_path = 'dataset/2afc/train/mix/judge/'

# Size argument given to transforms.Resize in the preprocessing pipeline.
img_size = 64
# Number of samples per batch produced by the DataLoader.
batch_size = 32

In [3]:
class TwoAFCDataset(Dataset):
    """Dataset of (p0, p1, ref, judge) samples read from four directories.

    Files are discovered with ``get_paths`` and sorted, then matched purely by
    position: sample ``i`` is the i-th image of each image directory together
    with the i-th ``.npy`` file of the judge directory.

    Args:
        ref_dir: directory scanned for reference images.
        p0_dir: directory scanned for the first candidate images.
        p1_dir: directory scanned for the second candidate images.
        judge_dir: directory scanned for ``.npy`` files.
        transform: optional callable applied to each of the three images.

    Raises:
        ValueError: if the four directories do not yield the same number of
            files (positional pairing would silently mix unrelated samples).
    """

    def __init__(self, ref_dir, p0_dir, p1_dir, judge_dir, transform=None):
        self.ref_paths = sorted(get_paths(ref_dir, mode='img'))
        self.p0_paths = sorted(get_paths(p0_dir, mode='img'))
        self.p1_paths = sorted(get_paths(p1_dir, mode='img'))
        self.judge_paths = sorted(get_paths(judge_dir, mode='np'))

        self.transform = transform

        # Samples are paired by index, so unequal counts mean misaligned data.
        counts = {
            'ref': len(self.ref_paths),
            'p0': len(self.p0_paths),
            'p1': len(self.p1_paths),
            'judge': len(self.judge_paths),
        }
        if len(set(counts.values())) > 1:
            raise ValueError('Mismatched number of files per directory: %r' % (counts,))

    def __len__(self):
        """Return the number of samples (one per reference image)."""
        return len(self.ref_paths)

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an RGB copy, and close the underlying file."""
        # Image.open is lazy and keeps the file handle open; convert() loads the
        # pixel data into a new image, so the handle can be released right away.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with keys 'p0', 'p1', 'ref', 'judge'.

        The images are transformed (when a transform was given) in the order
        p0, p1, ref; 'judge' is whatever ``np.load`` returns for the matching
        ``.npy`` file.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        p0_img = self._load_rgb(self.p0_paths[idx])
        p1_img = self._load_rgb(self.p1_paths[idx])
        ref_img = self._load_rgb(self.ref_paths[idx])

        if self.transform:
            p0_img = self.transform(p0_img)
            p1_img = self.transform(p1_img)
            ref_img = self.transform(ref_img)

        judge_img = np.load(self.judge_paths[idx])

        return {'p0': p0_img, 'p1': p1_img, 'ref': ref_img, 'judge': judge_img}

In [4]:
# File-name suffixes accepted when scanning for images (mode='img').
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

# File-name suffixes accepted for any other mode (used with mode='np').
NP_EXTENSIONS = ['.npy', ]


def get_paths(dir_path, mode='img'):
    """Recursively collect files under ``dir_path`` whose suffix matches ``mode``.

    Args:
        dir_path: root directory to walk.
        mode: 'img' selects IMG_EXTENSIONS; any other value selects
            NP_EXTENSIONS (see ``is_right_file``).

    Returns:
        A list of paths ordered by directory, then by file name within each
        directory. A missing or empty directory yields an empty list.
    """
    paths = []
    for root, _, filenames in sorted(os.walk(dir_path)):
        # os.walk reports file names in arbitrary, filesystem-dependent order;
        # sort them so the result is deterministic across machines.
        paths.extend(
            os.path.join(root, filename)
            for filename in sorted(filenames)
            if is_right_file(filename, mode=mode)
        )
    return paths


def is_right_file(filename, mode='img'):
    """Return True when ``filename`` ends with an extension accepted for ``mode``.

    The comparison ignores case, so suffixes such as '.Jpg' or '.NPY' match as
    well. Any ``mode`` other than 'img' is checked against NP_EXTENSIONS.
    """
    extensions = IMG_EXTENSIONS if mode == 'img' else NP_EXTENSIONS
    lowered = filename.lower()
    return any(lowered.endswith(extension.lower()) for extension in extensions)

In [5]:
# Preprocessing applied to every p0 / p1 / ref image: resize, convert to a
# tensor, then normalise each of the three channels with mean 0.5 and std 0.5.
# NOTE(review): Resize is given a single int, not an (h, w) pair — confirm the
# inputs are square, otherwise samples in a batch may differ in size.
data_transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Index the four directories configured above.
dataset = TwoAFCDataset(
    ref_dir=ref_dir_path,
    p0_dir=p0_dir_path,
    p1_dir=p1_dir_path,
    judge_dir=judge_dir_path,
    transform=data_transform,
)

# Shuffled mini-batches, loaded by 10 worker processes.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=10)

In [None]:
class VGG(nn.Module):
    """VGG-16 feature extractor split into five consecutive slices.

    ``forward`` runs the slices in sequence and returns the output of each one
    as a named tuple with fields relu1_2, relu2_2, relu3_3, relu4_3, relu5_3.

    Args:
        requires_grad: when False (default), every parameter is frozen.
        pretrained: forwarded to ``models.vgg16(pretrained=...)``.
            NOTE(review): newer torchvision releases prefer a ``weights=``
            argument — confirm against the installed version.
    """

    # Result type of forward(). Built once here instead of on every call, so
    # all outputs share a single class and no per-call class creation is paid.
    _Outputs = namedtuple('VggOutputs', ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])

    def __init__(self, requires_grad=False, pretrained=True):
        super(VGG, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        self.N_slices = 5
        # Sub-modules keep their index within vgg16().features as their name,
        # so state-dict keys look like 'slice2.5.weight'.
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(23, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, x):
        """Return the five intermediate activations for input ``x``.

        Each slice consumes the previous slice's output; the result is a
        ``VggOutputs`` named tuple holding the output of slice1..slice5.
        """
        h_relu1_2 = self.slice1(x)
        h_relu2_2 = self.slice2(h_relu1_2)
        h_relu3_3 = self.slice3(h_relu2_2)
        h_relu4_3 = self.slice4(h_relu3_3)
        h_relu5_3 = self.slice5(h_relu4_3)

        return self._Outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)

In [None]:
class LinearLayer(nn.Module):
    """Bias-free 1x1 convolution, optionally preceded by dropout.

    Args:
        chn_in: number of input channels.
        chn_out: number of output channels (default 1).
        use_dropout: when True, an ``nn.Dropout()`` is placed in front of the
            convolution inside ``self.model``.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super(LinearLayer, self).__init__()

        conv = nn.Conv2d(chn_in, chn_out, kernel_size=(1, 1), stride=(1, 1), padding=0, bias=False)
        # The convolution is always last, so its index in the Sequential (and
        # therefore its state-dict key) is 1 with dropout and 0 without.
        if use_dropout:
            stages = [nn.Dropout(), conv]
        else:
            stages = [conv]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the (dropout +) 1x1 convolution stack to ``x``."""
        return self.model(x)

In [None]:
def normalize_tensor(in_feat, eps=1e-10):
    """Scale ``in_feat`` to (approximately) unit L2 norm along dimension 1.

    The norm is computed over dim 1 with ``keepdim=True`` and ``eps`` is added
    to it before dividing, so an all-zero vector maps to zeros instead of NaN.
    """
    squared_sum = in_feat.pow(2).sum(dim=1, keepdim=True)
    magnitude = squared_sum.sqrt()
    return in_feat / (magnitude + eps)

In [None]:
class LPIPS(nn.Module):
    """Distance between two image batches computed from VGG features.

    Both inputs are passed through the ``VGG`` backbone; at each of its five
    outputs the features are normalised with ``normalize_tensor``, subtracted
    and squared, passed through a per-layer ``LinearLayer``, and reduced with
    ``spatial_average`` (or resized with ``upsample`` when ``spatial`` is
    set). The per-layer results are summed.

    ``normalize_tensor``, ``spatial_average`` and ``upsample`` are looked up at
    module level; the latter two must be defined elsewhere in the notebook.
    """

    def __init__(self, pretrained=True, net='alex', spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True,
                 model_path=None, eval_mode=True, verbose=True, version='0.1'):
        """
        Args:
            pretrained: when True, load a state dict from ``model_path`` with
                ``strict=False``.
            net: stored as ``pnet_type`` and used in the default weights file
                name. NOTE(review): the backbone built here is always ``VGG``
                with VGG channel counts, whatever this value is — confirm the
                default 'alex' is intended.
            spatial: when True, ``forward`` uses ``upsample`` on each layer
                result instead of ``spatial_average``.
            pnet_rand: when True, the backbone is built with
                ``pretrained=False``.
            pnet_tune: forwarded as the backbone's ``requires_grad``.
            use_dropout: forwarded to every ``LinearLayer``.
            model_path: weights file to load; when None a path of the form
                ``weights/v<version>/<net>.pth`` next to the file reported by
                ``inspect.getfile`` is used. NOTE(review): inside a notebook
                that location is unlikely to exist — pass ``model_path``.
            eval_mode: when True, switch the module to eval mode after setup.
            verbose: when True, print the path the weights are loaded from.
            version: version string used in the default weights path.
                NOTE(review): '0.1' is assumed to match the weights layout —
                confirm.
        """
        super(LPIPS, self).__init__()

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial

        # Channel count of each of the five backbone outputs.
        self.chns = [64, 128, 256, 512, 512]

        self.L = len(self.chns)

        self.net = VGG(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)

        # Kept as individual attributes so the state-dict keys stay
        # 'lin0.model...' .. 'lin4.model...'; self.lins is only an index-able
        # view over the same modules.
        self.lin0 = LinearLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = LinearLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = LinearLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = LinearLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = LinearLayer(self.chns[4], use_dropout=use_dropout)
        self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]

        if pretrained:
            if model_path is None:
                import inspect
                model_path = os.path.abspath(
                    os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))

            if verbose:
                print('Loading model from: %s' % model_path)
            self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)

        if eval_mode:
            self.eval()

    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        """Return the distance between ``in0`` and ``in1``.

        Args:
            in0, in1: image batches fed to the backbone.
            retPerLayer: when True, return ``(total, per_layer_list)`` instead
                of only the total.
            normalize: turn on if the inputs are in [0, 1] so they are
                rescaled to [-1, +1] before entering the backbone.
        """
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        outs0, outs1 = self.net(in0), self.net(in1)

        diffs = []
        for kk in range(self.L):
            feats0, feats1 = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
            diffs.append((feats0 - feats1) ** 2)

        if self.spatial:
            res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
        else:
            res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]

        # Accumulate out of place: an in-place `+=` on res[0] would overwrite
        # the first per-layer entry with the running total.
        val = res[0]
        for layer_res in res[1:]:
            val = val + layer_res

        if retPerLayer:
            return val, res
        else:
            return val