In [1]:
import numpy as np
import os
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision.transforms import transforms
import torch.nn as nn

In [2]:
# Input directories of the 2AFC data: reference images, the two candidate
# image sets (p0 / p1) and the per-sample judgement files.
ref_dir_path, p0_dir_path, p1_dir_path, judge_dir_path = (
    'dataset/2afc/train/mix/ref/',
    'dataset/2afc/train/mix/p0/',
    'dataset/2afc/train/mix/p1/',
    'dataset/2afc/train/mix/judge/',
)

# Target size handed to the resize transform, and the DataLoader batch size.
img_size, batch_size = 64, 32

In [3]:
class TwoAFCDataset(Dataset):
    """Dataset that pairs files from four parallel directories by sorted position.

    Sample ``i`` is built from the ``i``-th entry of the sorted path list of
    each directory (``ref``, ``p0``, ``p1`` images and a ``judge`` ``.npy``
    file), so all four directories must contain the same number of files.
    """

    def __init__(self, ref_dir, p0_dir, p1_dir, judge_dir, transform=None):
        """Collect and sort the file paths of every input directory.

        Args:
            ref_dir: directory searched (recursively) for reference images.
            p0_dir: directory searched for the ``p0`` images.
            p1_dir: directory searched for the ``p1`` images.
            judge_dir: directory searched for ``.npy`` judgement files.
            transform: optional callable applied to each of the three images.

        Raises:
            ValueError: if the directories do not hold the same number of files.
        """
        self.ref_paths = sorted(get_paths(ref_dir, mode='img'))
        self.p0_paths = sorted(get_paths(p0_dir, mode='img'))
        self.p1_paths = sorted(get_paths(p1_dir, mode='img'))
        self.judge_paths = sorted(get_paths(judge_dir, mode='np'))

        self.transform = transform

        # Samples are matched purely by index, so a count mismatch would
        # silently pair unrelated files (or fail late with an IndexError).
        counts = {
            'ref': len(self.ref_paths),
            'p0': len(self.p0_paths),
            'p1': len(self.p1_paths),
            'judge': len(self.judge_paths),
        }
        if len(set(counts.values())) > 1:
            raise ValueError(
                'TwoAFCDataset directories hold different numbers of files: %r' % (counts,))

    def __len__(self):
        """Return the number of samples (one per reference image)."""
        return len(self.ref_paths)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` with PIL and return an RGB copy, closing the file afterwards."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            dict with keys ``'p0'``, ``'p1'``, ``'ref'`` (the three images, passed
            through ``transform`` when one was given) and ``'judge'`` (the
            object returned by ``np.load`` for the matching judgement file).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        p0_img = self._load_rgb(self.p0_paths[idx])
        p1_img = self._load_rgb(self.p1_paths[idx])
        ref_img = self._load_rgb(self.ref_paths[idx])

        if self.transform is not None:
            p0_img = self.transform(p0_img)
            p1_img = self.transform(p1_img)
            ref_img = self.transform(ref_img)

        judge_img = np.load(self.judge_paths[idx])

        return {'p0': p0_img, 'p1': p1_img, 'ref': ref_img, 'judge': judge_img}

In [4]:
# (lower-case, upper-case) spellings of every image suffix that is accepted.
_IMG_SUFFIX_PAIRS = (
    ('.jpg', '.JPG'),
    ('.jpeg', '.JPEG'),
    ('.png', '.PNG'),
    ('.ppm', '.PPM'),
    ('.bmp', '.BMP'),
)

# Flat list of image file-name suffixes, in the order of the pairs above.
IMG_EXTENSIONS = [suffix for pair in _IMG_SUFFIX_PAIRS for suffix in pair]

# File-name suffixes of NumPy array files.
NP_EXTENSIONS = ['.npy']

def get_paths(dir_path, mode='img'):
    """Recursively list the files under ``dir_path`` that match ``mode``.

    Args:
        dir_path: root directory to walk (sub-directories are included).
        mode: forwarded to ``is_right_file`` to pick the accepted extensions.

    Returns:
        A sorted list of file paths (each is ``os.path.join(root, filename)``).
        The list is empty when nothing matches or ``dir_path`` does not exist.
    """
    paths = [
        os.path.join(root, filename)
        for root, _, filenames in os.walk(dir_path)
        for filename in filenames
        if is_right_file(filename, mode=mode)
    ]
    # os.walk yields files in filesystem order; sort the complete paths so the
    # result is deterministic across machines and runs.
    paths.sort()
    return paths

def is_right_file(filename, mode='img'):
    """Tell whether ``filename`` ends with one of the extensions selected by ``mode``.

    ``mode == 'img'`` checks against ``IMG_EXTENSIONS``; every other value
    checks against ``NP_EXTENSIONS``. The comparison is case-sensitive.
    """
    accepted = IMG_EXTENSIONS if mode == 'img' else NP_EXTENSIONS
    # str.endswith accepts a tuple of candidate suffixes.
    return filename.endswith(tuple(accepted))

In [5]:
# Preprocessing applied to every image (ref, p0 and p1) before batching.
data_transform = transforms.Compose([
    # Give (height, width) explicitly: Resize(int) only matches the shorter edge
    # and keeps the aspect ratio, so non-square images would produce tensors of
    # different shapes that cannot be stacked into one batch.
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 per channel.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = TwoAFCDataset(
    ref_dir_path,
    p0_dir_path,
    p1_dir_path,
    judge_dir_path,
    transform=data_transform,
)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    # Use at most 10 workers, but never more than the machine has CPUs.
    num_workers=min(10, os.cpu_count() or 1),
)

In [None]:
class LPIPS(nn.Module):
    """Distance between two image batches computed from trunk-network activations.

    Activations of the trunk (``pn.vgg16``, ``pn.alexnet`` or ``pn.squeezenet``)
    are normalised per layer, squared differences are taken, and the per-layer
    results are reduced and summed into one value.

    NOTE(review): ``pn``, ``ScalingLayer``, ``NetLinLayer``, ``upsample``,
    ``spatial_average`` and a module-level ``lpips`` (used for
    ``lpips.normalize_tensor``) are not defined or imported in the part of the
    notebook visible here -- confirm they are in scope before running this cell.
    """

    def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
                 pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True):
        """Build the trunk, the optional linear layers, and optionally load weights.

        Args:
            pretrained: when True (and ``lpips`` is True) load a state dict from
                ``model_path`` (non-strictly) -- i.e. the linear weights.
            net: trunk type: ``'vgg'``/``'vgg16'``, ``'alex'`` or ``'squeeze'``.
            version: ``'0.1'`` passes inputs through the scaling layer; any other
                value feeds them to the trunk unscaled.
            lpips: True adds a linear calibration layer per trunk layer; False is
                the baseline that just sums the channel differences.
            spatial: True upsamples each per-layer map to the input H x W instead
                of spatially averaging it.
            pnet_rand: True builds the trunk with ``pretrained=False``.
            pnet_tune: passed to the trunk as ``requires_grad``.
            use_dropout: forwarded to every ``NetLinLayer``.
            model_path: weights file; when None a path relative to the file that
                defines this class is used.
            eval_mode: call ``self.eval()`` at the end of construction.
            verbose: print setup / loading messages.

        Raises:
            ValueError: if ``net`` is not one of the supported trunk types.
        """
        super(LPIPS, self).__init__()
        if verbose:
            print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]' %
                  ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))

        self.pnet_type = net
        self.pnet_tune = pnet_tune
        self.pnet_rand = pnet_rand
        self.spatial = spatial
        self.lpips = lpips  # false means baseline of just averaging all layers
        self.version = version
        self.scaling_layer = ScalingLayer()

        if self.pnet_type in ['vgg', 'vgg16']:
            net_type = pn.vgg16
            self.chns = [64, 128, 256, 512, 512]
        elif self.pnet_type == 'alex':
            net_type = pn.alexnet
            self.chns = [64, 192, 384, 256, 256]
        elif self.pnet_type == 'squeeze':
            net_type = pn.squeezenet
            self.chns = [64, 128, 256, 384, 384, 512, 512]  # 7 layers for squeezenet
        else:
            # Fail with a clear message instead of an AttributeError on self.chns.
            raise ValueError(
                'Unsupported net [%s]; expected one of: vgg, vgg16, alex, squeeze' % net)
        self.L = len(self.chns)

        self.net = net_type(pretrained=not self.pnet_rand,
                            requires_grad=self.pnet_tune)

        if lpips:
            # One linear layer per trunk layer. Each is also registered as
            # ``lin<k>`` so the state-dict keys stay ``lin0...``, ``lin1...``, etc.
            lin_layers = []
            for layer_idx, channels in enumerate(self.chns):
                lin_layer = NetLinLayer(channels, use_dropout=use_dropout)
                setattr(self, 'lin%d' % layer_idx, lin_layer)
                lin_layers.append(lin_layer)
            self.lins = nn.ModuleList(lin_layers)

            if pretrained:
                if model_path is None:
                    import inspect
                    # NOTE(review): this resolves relative to the file defining the
                    # class; inside a notebook that is unlikely to be a directory
                    # holding a ``weights`` folder -- pass ``model_path`` explicitly.
                    model_path = os.path.abspath(os.path.join(inspect.getfile(
                        self.__init__), '..', 'weights/v%s/%s.pth' % (version, net)))

                if verbose:
                    print('Loading model from: %s' % model_path)
                # NOTE(review): torch.load unpickles the file -- only load
                # checkpoints from a trusted source.
                self.load_state_dict(torch.load(
                    model_path, map_location='cpu'), strict=False)

        if eval_mode:
            self.eval()

    def forward(self, in0, in1, retPerLayer=False, normalize=False):
        """Compute the distance between ``in0`` and ``in1``.

        Args:
            in0, in1: the two input batches; ``in0.shape[2:]`` is used as the
                output size when ``spatial`` is on.
            retPerLayer: when True also return the list of per-layer results.
            normalize: turn on if the inputs are in [0, 1] so they are rescaled
                to [-1, +1] first.

        Returns:
            ``val`` (sum of the per-layer results), or ``(val, res)`` when
            ``retPerLayer`` is True, where ``res`` is the per-layer list.
        """
        # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
        if normalize:
            in0 = 2 * in0 - 1
            in1 = 2 * in1 - 1

        # v0.0 - original release had a bug, where input was not scaled
        if self.version == '0.1':
            in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
        else:
            in0_input, in1_input = in0, in1
        outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)

        # Squared difference of the normalised activations, one entry per layer.
        diffs = []
        for kk in range(self.L):
            feat0 = lpips.normalize_tensor(outs0[kk])
            feat1 = lpips.normalize_tensor(outs1[kk])
            diffs.append((feat0 - feat1) ** 2)

        if self.lpips:
            if self.spatial:
                res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:])
                       for kk in range(self.L)]
            else:
                res = [spatial_average(self.lins[kk](
                    diffs[kk]), keepdim=True) for kk in range(self.L)]
        else:
            if self.spatial:
                res = [upsample(diffs[kk].sum(dim=1, keepdim=True),
                                out_HW=in0.shape[2:]) for kk in range(self.L)]
            else:
                res = [spatial_average(diffs[kk].sum(
                    dim=1, keepdim=True), keepdim=True) for kk in range(self.L)]

        # Accumulate out of place: an in-place ``+=`` on ``res[0]`` would
        # overwrite the first per-layer entry handed back via ``retPerLayer``.
        val = res[0]
        for layer_res in res[1:]:
            val = val + layer_res

        if retPerLayer:
            return val, res
        else:
            return val

In [6]:
# Sanity check: fetch only the first batch from the loader, print it, and stop.
for data in dataloader:
    print(data)
    break

{'p0': tensor([[[[ 0.4353,  0.4431,  0.4510,  ...,  0.5451,  0.5529,  0.5529],
          [ 0.4667,  0.4353,  0.4431,  ...,  0.5059,  0.5059,  0.5137],
          [ 0.4431,  0.4353,  0.4353,  ...,  0.5216,  0.5059,  0.4980],
          ...,
          [-0.1373, -0.1294, -0.1373,  ...,  0.7176,  0.7176,  0.6784],
          [-0.2314, -0.2078, -0.1922,  ...,  0.6627,  0.6863,  0.6706],
          [-0.2941, -0.2471, -0.2392,  ...,  0.6392,  0.6471,  0.6392]],

         [[ 0.4902,  0.4902,  0.5059,  ...,  0.4980,  0.5137,  0.5137],
          [ 0.4824,  0.4902,  0.5137,  ...,  0.5137,  0.5137,  0.5059],
          [ 0.4902,  0.4667,  0.4667,  ...,  0.5216,  0.5216,  0.5059],
          ...,
          [-0.1451, -0.1294, -0.1922,  ...,  0.6627,  0.6627,  0.6314],
          [-0.2471, -0.2235, -0.2471,  ...,  0.6471,  0.6392,  0.6235],
          [-0.3176, -0.3020, -0.2863,  ...,  0.6157,  0.5922,  0.6078]],

         [[ 0.4118,  0.4275,  0.4275,  ...,  0.4196,  0.4431,  0.4510],
          [ 0.4039,  0.