In [1]:
import os
import random
from random import randint

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from src.config import get_parser

In [32]:
class MaskedDataset(torch.utils.data.Dataset):
    """Masked image dataset paired with random crops of a geometry blueprint.

    Items are built from ``<img_dir>/<stem>.png`` and ``<mask_dir>/<stem>.pt``
    for every stem present in both directories. The image is converted to a
    tensor, normalized, turned to grayscale and multiplied by the resized mask;
    from that masked image a resized ``style_img`` and a randomly cropped
    ``img_patch`` are derived. Independently, a random
    ``patch_size`` x ``patch_size`` window is cut from the blueprint ``points``
    and ``normals`` tensors.

    Args:
        config: namespace providing ``data_patch_size``, ``data_image_dir``,
            ``data_mask_dir``, ``data_image_size``, ``image_mean``,
            ``image_std``, ``data_style_img``, ``data_image_resized``,
            ``data_dir`` and ``blueprint``.

    Raises:
        ValueError: if ``data_patch_size`` exceeds either spatial dimension
            (dims 1 and 2) of the blueprint tensors.
    """

    def __init__(self, config):
        self.patch_size = config.data_patch_size
        self.img_dir = config.data_image_dir
        self.mask_dir = config.data_mask_dir

        images = self._list_stems(self.img_dir, '.png')
        masks = self._list_stems(self.mask_dir, '.pt')
        # Sorted so the index -> file mapping is identical across runs; raw set
        # iteration order depends on per-process string hash randomization.
        self.entries = sorted(images & masks)
        if len(self.entries) < len(images):
            print('Missing masks', images - masks)

        self.transform = {
            "image_normed": transforms.Compose([
                transforms.ToTensor(),
                # NOTE(review): the image is not resized while the mask is
                # resized to data_image_size, so `img_normed * mask_resized`
                # presumably relies on images already having that size.
                transforms.Normalize(config.image_mean, config.image_std),
                transforms.Grayscale(),
            ]),
            # interpolate() needs a batched input; presumably masks are stored
            # as (1, C, H, W) so squeeze(0) yields (C, H, W) — TODO confirm.
            "mask": transforms.Lambda(lambda x:
                torch.nn.functional.interpolate(x.float(), size=config.data_image_size,
                                                mode='nearest').squeeze(0)),
            "style_img": transforms.Compose([
                transforms.Resize(config.data_style_img),
            ]),
            "img_patch": transforms.Compose([
                transforms.Resize(config.data_image_resized),
                transforms.RandomCrop(config.data_patch_size),
            ]),
        }

        self.points, self.normals = self._load_blueprint(
            os.path.join(config.data_dir, config.blueprint))
        # Crop origins are drawn along dims 1 and 2 of the blueprint tensors.
        self.wmax = self.points.size(1)
        self.hmax = self.points.size(2)
        # Fail at construction instead of deep inside randint() on first access.
        if self.patch_size > min(self.wmax, self.hmax):
            raise ValueError(
                f'patch_size={self.patch_size} exceeds blueprint extent '
                f'({self.wmax}, {self.hmax})')

    @staticmethod
    def _list_stems(directory, suffix):
        """Return names in `directory` ending with `suffix`, with that suffix removed."""
        # Slice off only the trailing suffix; str.replace would also drop inner
        # occurrences (e.g. 'a.png.png' -> 'a' instead of 'a.png').
        return {name[:-len(suffix)] for name in os.listdir(directory)
                if name.endswith(suffix)}

    @staticmethod
    def _load_blueprint(path):
        """Load ``points`` and ``normals`` from a blueprint archive, dropping the leading axis.

        Assumes `path` is a ``.npz`` archive (np.load then returns an NpzFile,
        which supports the context-manager protocol) — TODO confirm.
        """
        # Arrays are copied into tensors before the archive is closed, so the
        # file handle can be released immediately.
        with np.load(path) as blueprint:
            points = torch.tensor(blueprint['points'])[0]
            normals = torch.tensor(blueprint['normals'])[0]
        return points, normals

    def _random_blueprint_patch(self):
        """Return the same random ``patch_size`` square window of ``points`` and ``normals``."""
        # random.randint includes its upper bound, so the largest origin still
        # keeps the whole window inside the blueprint.
        w = randint(0, self.wmax - self.patch_size)
        h = randint(0, self.hmax - self.patch_size)
        window = (slice(None),
                  slice(w, w + self.patch_size),
                  slice(h, h + self.patch_size))
        return self.points[window], self.normals[window]

    def __len__(self):
        """Number of stems that have both an image and a mask."""
        return len(self.entries)

    def __getitem__(self, idx):
        """Return a dict with ``style_img``, ``img_patch``, ``points`` and ``normals`` for entry `idx`."""
        entry_path = self.entries[idx]
        # The context manager closes the PNG once ToTensor has read its pixels.
        with Image.open(os.path.join(self.img_dir, entry_path + '.png')) as img:
            img_normed = self.transform['image_normed'](img)
        # NOTE: torch.load unpickles the file; only use with trusted mask files.
        mask = torch.load(os.path.join(self.mask_dir, entry_path + '.pt'))

        mask_resized = self.transform['mask'](mask)
        res_masked = img_normed * mask_resized
        style_img = self.transform['style_img'](res_masked)
        img_patch = self.transform['img_patch'](res_masked)

        points, normals = self._random_blueprint_patch()
        return {
            'style_img': style_img,
            'img_patch': img_patch,
            'points': points,
            'normals': normals,
        }


# Seed both RNGs the dataset draws from — Python's `random` (blueprint crop
# origins) and torch (torchvision RandomCrop) — so sampled items are reproducible.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

config = get_parser().parse_args(args=[])
config.data_patch_size = 128  # override the get_parser() value for this experiment
ds = MaskedDataset(config)

In [33]:
# Spatial extent of the blueprint from which crop origins are sampled.
(ds.wmax, ds.hmax)

(256, 256)

In [34]:
# Inspect one sample: print each field name with its tensor shape.
d1 = ds[0]
for key, value in d1.items():
    print(key, value.shape)

107 82
style_img torch.Size([1, 192, 192])
img_patch torch.Size([1, 128, 128])
points torch.Size([3, 128, 128])
normals torch.Size([3, 128, 128])


In [5]:
# Load the blueprint directly to inspect raw shapes. The `with` block closes the
# .npz archive (np.load keeps its file open otherwise) once the arrays have been
# copied into tensors.
with np.load(os.path.join(config.data_dir, config.blueprint)) as blueprint:
    points = torch.tensor(blueprint['points'])[0]
    normals = torch.tensor(blueprint['normals'])[0]
points.shape, normals.shape

(torch.Size([3, 256, 256]), torch.Size([3, 256, 256]))

In [16]:
# Exploratory crop size; not tied to config.data_patch_size.
patch_size = 56
_, wmax, hmax = points.shape
# random.randint includes its upper bound, so `wmax - patch_size` is already the
# largest origin whose window fits; the former `+ 1` could yield a short crop.
w, h = randint(0, wmax - patch_size), randint(0, hmax - patch_size)
window = (slice(None), slice(w, w + patch_size), slice(h, h + patch_size))
assert points[window].shape[1:] == (patch_size, patch_size)
(points[window].shape,
 normals[window].shape)

(torch.Size([3, 56, 56]), torch.Size([3, 56, 56]))

In [20]:
# Patch size currently stored on the config object.
getattr(config, 'data_patch_size')

256

In [12]:
# Spatial extent of the blueprint: dims 1 and 2 of `points`.
_, wmax, hmax = points.size()
(wmax, hmax)

(256, 256)