In [1]:
# import os
# from argparse import ArgumentParser
# from collections import OrderedDict

# import numpy as np
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torchvision
# import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from src.data.masked_dataset import MaskedDataset

In [2]:
class MaskedDataModule(pl.LightningDataModule):
    """Lightning data module that serves a ``MaskedDataset`` for training.

    Args:
        config: parsed options object. ``batch_size`` and ``num_workers`` are
            read here; the whole object is forwarded to ``MaskedDataset``
            when ``setup`` runs.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers

    def prepare_data(self):
        """No-op placeholder for a download step."""
        return None

    def setup(self, stage=None):
        """Build the dataset; ``stage`` ('fit' or 'test') is not used."""
        dataset = MaskedDataset(self.config)
        self.ds = dataset

    def train_dataloader(self):
        """Return a ``DataLoader`` over the dataset created by ``setup``."""
        dataset = self.ds
        loader_options = {
            'batch_size': self.batch_size,
            'num_workers': self.num_workers,
        }
        return DataLoader(dataset, **loader_options)


# Project-local argument parser factory (imported here rather than in the
# import cell at the top of the notebook).
from src.config import get_parser

# Parse with an empty argument list so the parser's defaults are used and the
# notebook kernel's own command-line arguments are not consumed.
config = get_parser().parse_args(args=[])

# Data module under test; the dataset itself is only created by setup().
mdm = MaskedDataModule(config)    

In [3]:
# setup() must run before train_dataloader(): it is what assigns mdm.ds.
mdm.setup()
loader = mdm.train_dataloader()
# Bare expression so the notebook displays the loader's repr.
loader

<torch.utils.data.dataloader.DataLoader at 0x7f35712e1e80>

In [4]:
# Pull the first batch from the loader for inspection.
batch = next(iter(loader))

In [5]:
# Display the whole batch dict; the tensor repr is long, so showing per-key
# shapes would be easier to read.
batch

{'img_patch': tensor([[[[-0.0000, -0.0000, -0.0000,  ...,  0.3046,  0.3104,  0.4369],
           [-0.0000, -0.0000, -0.0000,  ...,  0.4479,  0.3449,  0.3053],
           [-0.0000, -0.0000, -0.0000,  ...,  0.4897,  0.4144,  0.4353],
           ...,
           [-0.0000, -0.0000, -0.0000,  ...,  0.7351,  0.6768,  0.7250],
           [-0.0000, -0.0000, -0.0000,  ...,  0.5012,  0.4815,  0.6377],
           [-0.0000, -0.0000, -0.0000,  ...,  0.8101,  0.1929,  0.4396]]],
 
 
         [[[ 1.1045,  1.1375,  1.1252,  ...,  0.0000,  0.0000,  0.0000],
           [ 1.2657,  1.1972,  1.1085,  ...,  0.0000,  0.0000,  0.0000],
           [ 0.9900,  1.0610,  1.2057,  ...,  0.0000,  0.0000,  0.0000],
           ...,
           [-0.7105, -0.8034, -0.8417,  ...,  0.0000,  0.0000,  0.0000],
           [-0.7429, -0.8028, -0.7839,  ...,  0.0000,  0.0000,  0.0000],
           [-0.7625, -0.6313, -0.5213,  ...,  0.0000,  0.0000,  0.0000]]],
 
 
         [[[-0.0000, -0.0000, -0.0000,  ...,  0.2183,  0.3641,  0.5

In [6]:
# Display the dataset object created by mdm.setup().
mdm.ds

<src.data.masked_dataset.MaskedDataset at 0x7f57d92bffd0>

In [7]:
# Display the sample names the dataset discovered.
# NOTE(review): the recorded output of this cell is an empty list, which would
# mean no image/mask pair was found when it ran - confirm the configured
# image and mask directories.
mdm.ds.entries

[]

In [8]:
import os
import numpy as np
import torch
import torchvision.transforms as transforms

from random import randint
from PIL import Image

class MaskedDataset(torch.utils.data.Dataset):
    """Dataset of masked grayscale image patches plus blueprint patches.

    A sample exists for every name that has both ``<name>.png`` in
    ``config.data_image_dir`` and ``<name>.pt`` in ``config.data_mask_dir``.
    For each sample the image is normalized and converted to grayscale,
    multiplied by the resized mask, resized and randomly cropped. A random
    crop of the blueprint ``points`` and ``normals`` tensors (loaded once from
    ``config.blueprint``) is returned alongside it; the image crop and the
    blueprint crop are sampled independently.

    Args:
        config: object exposing ``data_patch_size``, ``data_image_dir``,
            ``data_mask_dir``, ``data_image_size``, ``data_image_resized``
            and ``blueprint``.
    """

    def __init__(self, config):
        self.patch_size = config.data_patch_size
        self.img_dir = config.data_image_dir
        self.mask_dir = config.data_mask_dir
        self.image_size = config.data_image_size

        images = self._list_stems(self.img_dir, '.png')
        masks = self._list_stems(self.mask_dir, '.pt')

        # Sorted so the index -> sample mapping is the same on every run
        # (set iteration order of strings varies between interpreter runs).
        self.entries = sorted(images & masks)
        if len(self.entries) < len(images):
            print('Missing masks', images - masks)

        self.transform = {
            "image_normed": transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                transforms.Grayscale(),
            ]),
            # A bound method instead of a lambda closure keeps the dataset
            # picklable, which DataLoader workers need on spawn-based platforms.
            "mask": transforms.Lambda(self._resize_mask),
            "img_patch": transforms.Compose([
                transforms.Resize(config.data_image_resized),
                transforms.RandomCrop(config.data_patch_size),
            ]),
        }

        blueprint = np.load(config.blueprint)
        self.points = torch.tensor(blueprint['points'])[0]
        self.normals = torch.tensor(blueprint['normals'])[0]

    @staticmethod
    def _list_stems(directory, extension):
        """Return names in ``directory`` ending with ``extension``, suffix removed."""
        # Slice off only the trailing extension; str.replace would also strip
        # the same text occurring earlier in the file name.
        return {
            name[:-len(extension)]
            for name in os.listdir(directory)
            if name.endswith(extension)
        }

    def _resize_mask(self, mask):
        """Nearest-neighbour resize of ``mask`` to ``image_size``, dropping dim 0."""
        resized = torch.nn.functional.interpolate(
            mask.float(), size=self.image_size, mode='nearest')
        return resized.squeeze(0)

    def _crop_blueprint(self):
        """Return same-offset random ``patch_size`` crops of points and normals."""
        patch_size = self.patch_size
        # Offsets index dims 1 and 2 of both tensors, so bound them by the
        # smaller extent to keep every crop exactly patch_size wide.
        height = min(self.points.shape[1], self.normals.shape[1])
        width = min(self.points.shape[2], self.normals.shape[2])
        # The original upper bound (patch_size) is kept when the blueprint is
        # large enough; it is only lowered when it would run past the edge and
        # yield a truncated patch that cannot be collated into a batch.
        # NOTE(review): with a blueprint larger than 2 * patch_size the region
        # beyond offset patch_size is never sampled - confirm this is intended.
        max_row = max(min(patch_size, height - patch_size), 0)
        max_col = max(min(patch_size, width - patch_size), 0)
        row, col = randint(0, max_row), randint(0, max_col)
        points = self.points[:, row:row + patch_size, col:col + patch_size]
        normals = self.normals[:, row:row + patch_size, col:col + patch_size]
        return points, normals

    def __len__(self):
        """Number of image/mask pairs found."""
        return len(self.entries)

    def __getitem__(self, idx):
        """Return a dict with keys ``img_patch``, ``points`` and ``normals``."""
        entry_path = self.entries[idx]
        # The context manager closes the file handle promptly; convert('RGB')
        # loads the pixels and guarantees the three channels Normalize expects.
        with Image.open(os.path.join(self.img_dir, entry_path + '.png')) as img_file:
            img = img_file.convert('RGB')
        # NOTE: torch.load unpickles its input; only use trusted mask files.
        mask = torch.load(os.path.join(self.mask_dir, entry_path + '.pt'))

        img_normed = self.transform['image_normed'](img)
        mask_resized = self.transform['mask'](mask)
        res_masked = img_normed * mask_resized
        img_patch = self.transform['img_patch'](res_masked)
        points, normals = self._crop_blueprint()

        return {
            'img_patch': img_patch,
            'points': points,
            'normals': normals,
        }

# Instantiate the notebook-local MaskedDataset defined in this cell (it
# shadows the one imported from src.data.masked_dataset).
ds = MaskedDataset(config)
ds

<__main__.MaskedDataset at 0x7fa9b1db4eb0>

In [13]:
#img_dir
# Show the configured image and mask directories.
# NOTE(review): in the recorded output data_image_dir points at a ".../masks"
# directory and data_mask_dir at an ".../images1024x1024" directory, so the
# two settings look swapped - confirm the defaults in src.config.
config.data_image_dir, config.data_mask_dir  

('/home/bobi/Desktop/face-parsing.PyTorch/res/masks',
 '/home/bobi/Desktop/db/ffhq-dataset/images1024x1024')

In [9]:
# Fetch and display the first sample via MaskedDataset.__getitem__.
ds[0]

{'img_patch': tensor([[[-0.0000, -0.0000, -0.0000,  ..., -0.6534, -0.6404, -0.6360],
          [-0.0000, -0.0000, -0.0000,  ..., -0.5448, -0.5795, -0.6404],
          [-0.0000, -0.0000, -0.0000,  ..., -0.3927, -0.4752, -0.6013],
          ...,
          [-0.0000, -0.0000, -0.0000,  ..., -0.0325, -0.4019, -0.4684],
          [-0.0000, -0.0000, -0.0000,  ..., -0.0749, -0.2201, -0.3470],
          [-0.0000, -0.0000, -0.0000,  ...,  0.0255, -0.3249, -0.4041]]]),
 'points': tensor([[[-0.3326, -0.3358, -0.3311,  ...,  0.2255,  0.2630,  0.2407],
          [-0.3317, -0.3332, -0.3324,  ...,  0.2808,  0.2667,  0.2872],
          [-0.3334, -0.3343, -0.3352,  ...,  0.2838,  0.2618,  0.2825],
          ...,
          [ 0.2019,  0.2080,  0.2103,  ...,  0.4487,  0.4480,  0.4467],
          [ 0.2097,  0.2160,  0.2089,  ...,  0.4506,  0.4473,  0.4461],
          [ 0.2097,  0.2145,  0.2132,  ...,  0.4491,  0.4481,  0.4455]],
 
         [[-0.3523, -0.3455, -0.3561,  ..., -0.6635, -0.6900, -0.6912],
     