In [1]:
import os

import sys
'''PATH to lpips library for perceptual loss'''
new_path = '/home/cc/stylegan_K1/'
sys.path.append(new_path)

import glob
import argparse
import threading
import six.moves.queue as Queue # pylint: disable=import-error
import traceback
import numpy as np
import tensorflow as tf
import PIL.Image
import dnnlib.tflib as tflib

from training import dataset
import pandas as pd
import shutil as sh
import matplotlib.pyplot as plt

from skimage.transform import resize
from skimage import io

  data = yaml.load(f.read()) or {}


In [2]:
class TFRecordExporter:
    def __init__(self, tfrecord_dir, expected_images, print_progress=True, progress_interval=10):
        self.tfrecord_dir       = tfrecord_dir
        self.tfr_prefix         = os.path.join(self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
        self.expected_images    = expected_images
        self.cur_images         = 0
        self.shape              = None
        self.resolution_log2    = None
        self.tfr_writers        = []
        self.print_progress     = print_progress
        self.progress_interval  = progress_interval

        if self.print_progress:
            print('Creating dataset "%s"' % tfrecord_dir)
        if not os.path.isdir(self.tfrecord_dir):
            os.makedirs(self.tfrecord_dir)
        assert os.path.isdir(self.tfrecord_dir)

    def close(self):
        if self.print_progress:
            print('%-40s\r' % 'Flushing data...', end='', flush=True)
        for tfr_writer in self.tfr_writers:
            tfr_writer.close()
        self.tfr_writers = []
        if self.print_progress:
            print('%-40s\r' % '', end='', flush=True)
            print('Added %d images.' % self.cur_images)

    def choose_shuffled_order(self): # Note: Images and labels must be added in shuffled order.
        order = np.arange(self.expected_images)
        np.random.RandomState(123).shuffle(order)
        return order

    def add_image(self, img):
        if self.print_progress and self.cur_images % self.progress_interval == 0:
            print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True)
        if self.shape is None:
            self.shape = img.shape
            self.resolution_log2 = int(np.log2(self.shape[1]))
            assert self.shape[0] in [1, 3]
            assert self.shape[1] == self.shape[2]
            assert self.shape[1] == 2**self.resolution_log2
            tfr_opt = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE)
            for lod in range(self.resolution_log2 - 1):
                tfr_file = self.tfr_prefix + '-r%02d.tfrecords' % (self.resolution_log2 - lod)
                self.tfr_writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
        assert img.shape == self.shape
        for lod, tfr_writer in enumerate(self.tfr_writers):
            if lod:
                img = img.astype(np.float32)
                img = (img[:, 0::2, 0::2] + img[:, 0::2, 1::2] + img[:, 1::2, 0::2] + img[:, 1::2, 1::2]) * 0.25
            quant = np.rint(img).clip(0, 255).astype(np.uint8)
            ex = tf.train.Example(features=tf.train.Features(feature={
                'shape': tf.train.Feature(int64_list=tf.train.Int64List(value=quant.shape)),
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[quant.tostring()]))}))
            tfr_writer.write(ex.SerializeToString())
        self.cur_images += 1

    def add_labels(self, labels):
        if self.print_progress:
            print('%-40s\r' % 'Saving labels...', end='', flush=True)
        assert labels.shape[0] == self.cur_images
        with open(self.tfr_prefix + '-rxx.labels', 'wb') as f:
            np.save(f, labels.astype(np.float32))

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()


In [8]:
'''Putting all the live samples in one folder for TF record'''
data_pd = pd.read_pickle('/home/cc/Data/Zips/silk_index_df_fixed_crop.p')
data_pd['FilePath']= data_pd.FilePath.apply(lambda x: os.path.join('/home/cc/Data/Images/silk_static',x))
data_live = data_pd[np.logical_or(data_pd.FingerType == 'LIVE', data_pd.FingerType == 'Live')] # only live samples

In [11]:
data_live.head()

Unnamed: 0.1,Unnamed: 0,CaptureID,Attempt,CaptureID.1,FilePath,Finger,FingerType,PathF1,PathF10,PathF2,...,set_type_fold_2,set_type_fold_3,set_type_fold_4,set_type_fold_5,set_type_fold_6,set_type_fold_7,set_type_fold_8,set_type_fold_9,x_center,y_center
0,0,ds1_1,,ds1_1,/home/cc/Data/Images/silk_static/dev_set_1/Sta...,Left Little,LIVE,dev_set_1/Sequences/live/3932164900003/2000020...,,dev_set_1/Sequences/live/3932164900003/2000020...,...,tr,tr,tr,tr,tr,v,tr,tr,300.0,405.0
1,1,ds1_2,,ds1_2,/home/cc/Data/Images/silk_static/dev_set_1/Sta...,Left Little,LIVE,dev_set_1/Sequences/live/3932164900003/2000020...,,dev_set_1/Sequences/live/3932164900003/2000020...,...,tr,tr,tr,tr,tr,v,tr,tr,272.0,369.0
2,2,ds1_3,,ds1_3,/home/cc/Data/Images/silk_static/dev_set_1/Sta...,Right Little,LIVE,dev_set_1/Sequences/live/3932164900003/2000020...,,dev_set_1/Sequences/live/3932164900003/2000020...,...,tr,tr,tr,tr,tr,v,tr,tr,261.0,426.0
3,3,ds1_4,,ds1_4,/home/cc/Data/Images/silk_static/dev_set_1/Sta...,Right Little,LIVE,dev_set_1/Sequences/live/3932164900003/2000020...,,dev_set_1/Sequences/live/3932164900003/2000020...,...,tr,tr,tr,tr,tr,v,tr,tr,238.0,444.0
4,4,ds1_5,,ds1_5,/home/cc/Data/Images/silk_static/dev_set_1/Sta...,Left Ring,LIVE,dev_set_1/Sequences/live/3932164900003/2000020...,,dev_set_1/Sequences/live/3932164900003/2000020...,...,tr,tr,tr,tr,tr,v,tr,tr,245.0,437.0


In [144]:
data_live.shape

(45478, 36)

In [None]:
images_loc = '/home/cc/Data/Allinone'
if not os.path.exists(images_loc):
    os.mkdir(images_loc)
image_rgb = np.zeros((512,512,3), dtype=np.uint8)
for i in range(len(data_live)):  
    splits = data_live.iloc[i,4].split('/')
    '''Gray to RGB for style loss we have to do this'''
    img = io.imread(data_live.iloc[i,4])
    image_resized = resize(img, (400,300), preserve_range=True, clip=True).astype(np.uint8)
    image_paded = np.pad(image_resized, ([56,56], [106,106]), 'constant', constant_values=(255, 255))
    image_rgb[:,:,0] = image_paded
    image_rgb[:,:,1] = image_paded
    image_rgb[:,:,2] = image_paded
    #image_resized = resize(image_rgb, (512,512,3), preserve_range=True, clip=True).astype(np.uint8)
    io.imsave(os.path.join(images_loc,splits[-1]), image_rgb)
    #plt.imsave(os.path.join('/home/cc/Data/Allinone',splits[-1]), image_resized)

  warn("The default mode, 'constant', will be changed to 'reflect' in "
  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  w

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s

  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)
  warn('%s is a low contrast image' % fname)


In [146]:
'''Cleaning the dataset - We have messed up images!!!!'''
image_filenames = sorted(glob.glob(os.path.join(images_loc, '*')))
Errors_channel = []
Errors_resulution = []
cleaned_filenames = []
for i in image_filenames:
    img = np.asarray(PIL.Image.open(i))
    if img.shape[2] !=3: 
        Errors_channel.append(i)
    elif img.shape[1] != 512:
        Errors_resulution.append(i)
    else:
        cleaned_filenames.append(i)

In [147]:
len(image_filenames)

45478

In [148]:
len(cleaned_filenames)

45478

In [149]:
Errors_channel

[]

In [150]:
Errors_resulution

[]

In [151]:
def create_from_images(tfrecord_dir, image_dir, shuffle):
    print('Loading images from "%s"' % image_dir)
    image_filenames = sorted(glob.glob(os.path.join(image_dir, '*')))
    if len(image_filenames) == 0:
        error('No input images found')

    img = np.asarray(PIL.Image.open(image_filenames[0]))
    resolution = img.shape[0]
    channels = img.shape[2] if img.ndim == 3 else 1
    if img.shape[1] != resolution:
        error('Input images must have the same width and height')
    if resolution != 2 ** int(np.floor(np.log2(resolution))):
        error('Input image resolution must be a power-of-two')
    if channels not in [1, 3]:
        error('Input images must be stored as RGB or grayscale')
    
    
    with TFRecordExporter(tfrecord_dir, len(image_filenames)) as tfr:
        order = tfr.choose_shuffled_order() if shuffle else np.arange(len(image_filenames))
        for idx in range(order.size):
            img = np.asarray(PIL.Image.open(image_filenames[order[idx]]))
            if channels == 1:
                img = img[np.newaxis, :, :] # HW => CHW
            else:
                img = img.transpose([2, 0, 1]) # HWC => CHW
            tfr.add_image(img)

In [152]:
if not os.path.exists('/home/cc/Data/Silk_512'):
    os.mkdir('/home/cc/Data/Silk_512')

create_from_images('/home/cc/Data/Silk_512', images_loc , False)

Loading images from "/home/cc/Data/Allinone"
Creating dataset "/home/cc/Data/Silk_512"
Added 45478 images.                     
