In [236]:
from __future__ import print_function

import scipy as sp
import scipy.misc
import os
import ast
import skimage
import imageio
from itertools import islice
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from math import log10
import easydict

from IPython.display import clear_output
from IPython.core.debugger import set_trace

from torchvision.transforms import Resize, ToTensor, Normalize
import re
import torch
from torch import nn 
from torch.nn import functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
import torch.backends.cudnn as cudnn
import torchvision
import torch.nn.functional as F
from skimage import segmentation

from tqdm import tqdm_notebook

from shutil import rmtree
import json
from IPython.display import clear_output
from torch import autograd

from skimage import io
from skimage.feature import canny
from skimage.morphology import dilation, disk
from skimage.color import rgb2gray

from skimage.filters import threshold_otsu, gaussian
from skimage.segmentation import clear_border
from skimage.measure import label, regionprops
from utils import vis_batch, collate_fn, tensor2numpy
from encoder_decoder import MaskDecoder, MaskEncoder

In [6]:
class FashionEdgesDataset(Dataset):
    """Dataset yielding ``(edges, image)`` pairs for a folder of images.

    Every item is built by opening an image, resizing it to 128x128,
    converting it to a tensor (only the first three channels are kept) and
    computing an edge map from it with Canny followed by an optional dilation.

    Parameters
    ----------
    images_fold : str
        Directory that contains the image files.
    attr_file : str, optional
        Text file with one Python-literal dict per line. Each dict must have
        an ``'image'`` key (file name inside ``images_fold``) and an
        ``'attributes'`` key. Blank lines are skipped. When omitted, every
        entry returned by ``os.listdir(images_fold)`` is used and
        ``images2attr`` stays empty.
    check_corrupted : bool
        When True and ``attr_file`` is given, each listed image is opened once
        up front; images rejected by ``_is_appropriate`` are recorded in
        ``corrupted_images`` and left out of the dataset. When False, the same
        check is done lazily in ``__getitem__``, which then returns ``None``
        for rejected images.
        NOTE: with ``check_corrupted=True`` and no ``attr_file`` no check is
        performed at all (neither up front nor lazily).
    return_mask : bool
        Accepted for backward compatibility and stored on the instance; it is
        not used by any method of this class.
    """

    def __init__(self,
                 images_fold,
                 attr_file=None,
                 check_corrupted=False,
                 return_mask=True):

        self.corrupted_images = set()
        self.check_corrupted = check_corrupted
        self.images_fold = images_fold
        self.return_mask = return_mask
        # Always defined, so attribute access is safe even without attr_file.
        self.images2attr = {}

        if attr_file is None:
            self.images_names = os.listdir(self.images_fold)
            return

        with open(attr_file, 'r') as f:
            records = f.readlines()

        for record in tqdm_notebook(records):
            record = record.strip()
            if not record:
                # Tolerate blank / trailing empty lines in the attribute file.
                continue
            dct = ast.literal_eval(record)
            image_name = dct['image']

            if check_corrupted and not self._file_is_appropriate(image_name):
                self.corrupted_images.add(image_name)
                continue

            self.images2attr[image_name] = dct['attributes']

        self.images_names = list(self.images2attr.keys())

    def _file_is_appropriate(self, image_name):
        """Open ``image_name`` from ``images_fold`` and run ``_is_appropriate`` on it."""
        # The context manager closes the file handle as soon as the pixels are read.
        with Image.open(os.path.join(self.images_fold, image_name)) as img:
            return self._is_appropriate(np.array(img))

    def _is_appropriate(self, img, thresh = 10):
        """Return True when the top-left ``thresh`` x ``thresh`` corner of ``img`` is all 255."""
        return np.all(img[:thresh,:thresh] == 255)

    @staticmethod
    def _image2edges(image, low_thresh=0.05, high_thresh=0.3, sigma=0.1, selem=True, d=1.5):
        """Compute an edge map of ``image`` (np.array).

        The image is converted with ``rgb2gray``, passed through ``canny``
        with the given ``sigma`` / thresholds and, when ``selem`` is truthy,
        dilated with ``disk(d)``. Same steps and defaults as the notebook's
        module-level ``image2edges``.
        """
        gray = rgb2gray(image)
        edges = canny(gray, sigma=sigma, low_threshold=low_thresh, high_threshold=high_thresh)
        if selem:
            edges = dilation(edges, disk(d))
        return edges

    def __getitem__(self, idx):
        """Return ``(edges, img)`` for item ``idx``, or ``None`` for a rejected image.

        ``img`` is the 128x128 image tensor (at most 3 channels) and ``edges``
        is a float32 tensor of the edge map with a leading singleton dimension.
        """
        image_name = self.images_names[idx]
        with Image.open(os.path.join(self.images_fold, image_name)) as raw:
            # Lazy check, only when images were not filtered in __init__.
            if not self.check_corrupted and not self._is_appropriate(np.array(raw)):
                return None
            # Resize produces a new, fully loaded image, so it stays valid
            # after the underlying file is closed.
            img = Resize((128, 128))(raw)

        img = ToTensor()(img)
        if img.shape[0] > 3:
            # Drop extra channels (e.g. alpha), keep the first three.
            img = img[:3]

        # tensor2numpy comes from utils; presumably it yields an array layout
        # accepted by rgb2gray — TODO confirm.
        edges = self._image2edges(tensor2numpy(img))

        return torch.tensor(edges, dtype=torch.float32).unsqueeze(0), img

    def __len__(self):
        """Number of images kept in the dataset."""
        return len(self.images_names)

In [42]:
def image2edges(image, low_thresh=0.05, high_thresh=0.3, sigma=0.1, selem=True, d = 1.5):
    '''
    Build an edge map for ``image`` (np.array).

    The image is converted with ``rgb2gray`` and run through ``canny`` using
    ``sigma``, ``low_thresh`` and ``high_thresh``. When ``selem`` is truthy the
    detected edges are additionally dilated with ``disk(d)``; otherwise the
    raw Canny output is returned.
    '''
    grayscale = rgb2gray(image)
    edge_mask = canny(
        grayscale,
        sigma=sigma,
        low_threshold=low_thresh,
        high_threshold=high_thresh,
    )

    # Guard clause: no structuring element requested, keep the thin edges.
    if not selem:
        return edge_mask

    footprint = disk(d)
    return dilation(edge_mask, footprint)

In [41]:

# Instantiate torchvision's VGG-19 with pretrained weights (the weights file
# is downloaded on first use, as the cell output below shows).
# NOTE(review): the variable is called `mode`, which looks like a typo for
# `model` — confirm against later cells before renaming.
mode = models.vgg19(pretrained=True)

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /home/ibulygin/.torch/models/vgg19-dcbb9e9d.pth
100%|██████████| 574673361/574673361 [00:12<00:00, 45817192.87it/s]


In [58]:
# Directory with the CP-VTON images processed by this cell.
CP_VTON_DIR = './cp-vton/'
names = os.listdir(CP_VTON_DIR)

# Build the per-image tensors used by later cells. This loop must stay active:
# `edges` is consumed by the flood-fill cell further down, and with the loop
# commented out a fresh-kernel "Run All" fails with NameError.
edges = []   # one float32 tensor per image: edge map with a leading singleton dim
images = []  # one tensor per image, as produced by ToTensor
for name in tqdm_notebook(names):
    img = plt.imread(os.path.join(CP_VTON_DIR, name))
    images.append(ToTensor()(img))
    edge = image2edges(img)
    # Canny/dilation output is cast to float before building the tensor.
    edges.append(torch.tensor(edge.astype(float), dtype=torch.float32).unsqueeze(0))
    

HBox(children=(IntProgress(value=0, max=14221), HTML(value='')))




In [165]:
# Derive one mask per edge map by smoothing it and flood-filling from a corner.
# `edges` must already be defined by an earlier cell (an iterable of tensors
# that support `[0].numpy()`).
shapes = []
for image in tqdm_notebook(edges):
    # Take the first entry along dim 0 and convert it to NumPy; presumably each
    # item is a (1, H, W) edge tensor — TODO confirm.
    image = image[0].numpy()
    # The unpacking requires a 2-D array; h and w are not used further in this cell.
    h,w = image.shape
    # Gaussian smoothing (sigma=0.2) followed by a flood from pixel (0, 0).
    # With no tolerance given, flood selects pixels connected to the seed that
    # share its value — presumably the outer background; confirm.
    shape = segmentation.flood(gaussian(image,0.2),
                                  seed_point=(0, 0))
    
    shapes.append(shape)
    
#     fig, axes = plt.subplots(ncols=2, nrows=1)
#     axes[0].imshow(image)
#     axes[0].set_title('Edges')
#     axes[1].imshow(shape)
#     axes[1].set_title('After Gaussian filter and Flood')
    
#     axes[0].set_xticks([])
#     axes[0].set_yticks([])
#     axes[1].set_xticks([])
#     axes[1].set_yticks([])
    
    
#     plt.show()

HBox(children=(IntProgress(value=0, max=14221), HTML(value='')))


