# Import

In [140]:
import os
import random
import shutil
import subprocess
from functools import partial
from pathlib import Path

import pandas as pd
import tensorflow as tf
import torchvision.transforms.functional as F
from joblib import Parallel, delayed
from PIL import Image
from tqdm.notebook import tqdm

# Config

In [2]:
# Data layout: source images live in data/content/raw; derived image folders
# and split CSVs are written directly under data/content.
dir_data = Path('data')
dir_content = dir_data.joinpath('content')
dir_raw = dir_content.joinpath('raw')

# Utility

In [3]:
def parallel(f, it, n_jobs=6):
    """Apply ``f`` to every item of ``it`` using ``n_jobs`` joblib workers.

    The iterable is materialized into a list before dispatching. This lets tqdm
    show a real total, since a bare generator such as ``Path.rglob`` has no
    length. It also means workers that delete or rewrite files cannot disturb
    a directory walk that is still in progress. Results of ``f`` are discarded,
    so calling cells produce no output.

    Args:
        f: Callable taking a single item.
        it: Iterable of items (e.g. the paths yielded by ``Path.rglob``).
        n_jobs: Number of parallel joblib workers.
    """
    items = list(it)
    Parallel(n_jobs=n_jobs)(delayed(f)(i) for i in tqdm(items))

# Clean

Delete corrupt images (files that PIL cannot read).

In [5]:
def _validate_image(file_img):
    """Delete ``file_img`` if PIL cannot open and fully decode it.

    ``Image.open`` alone only parses the header (it is lazy), so a truncated
    or otherwise damaged file could slip through. ``load()`` forces a full
    decode, which the later convert/resize steps would do anyway.

    Args:
        file_img: ``pathlib.Path`` of a candidate image. Directories, which
            ``rglob('*')`` also yields, are skipped.
    """
    if not file_img.is_file():
        return
    try:
        with Image.open(file_img) as img:
            img.load()
    except Exception:  # not a bare except: KeyboardInterrupt/SystemExit still propagate
        # The ``with`` block has already closed the file here, so unlinking is safe.
        print(f'{file_img} is corrupt; removing...')
        file_img.unlink()

In [6]:
parallel(f=_validate_image, it=dir_raw.rglob('*'))

HBox(children=(FloatProgress(value=0.0, max=123403.0), HTML(value='')))




# Make sure all images are RGB on disk

In [9]:
def _process_img(f):
    """Rewrite image ``f`` in place as 3-channel RGB if it uses any other mode.

    Grayscale/palette images (``L``, ``P``, ``1``, ...) and images with
    extra or different channels (e.g. ``RGBA``, ``CMYK``) are converted with
    ``Image.convert('RGB')``. For ``RGBA`` this drops the alpha channel.
    Images already in ``RGB`` mode are left untouched. The path of every
    rewritten file is printed.

    Args:
        f: ``pathlib.Path`` of an image. Directories, which ``rglob('*')``
            also yields, are skipped.
    """
    if not f.is_file():
        return
    # Open once and close the handle before saving over the same path.
    with Image.open(f) as img:
        if img.mode == 'RGB':
            return
        print(f)
        img_rgb = img.convert('RGB')  # convert() loads the data, so img_rgb outlives the handle
    img_rgb.save(f)

In [10]:
parallel(f=_process_img, it=dir_raw.rglob('*'))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




# Pre-resize and crop images

Resize so the shorter side matches the target size (preserving aspect ratio), center-crop to a square of that size, then save each size to its own folder (`formatted_<size>`).

In [13]:
def _img_pipeline(file_img, sz, dir_imgs_new):
    """Resize, center-crop and save one image into ``dir_imgs_new``.

    ``F.resize`` with an int size scales the shorter side to ``sz`` while
    keeping the aspect ratio. ``F.center_crop`` then cuts a ``sz`` x ``sz``
    square. The result is saved under the source file's name.

    Args:
        file_img: ``pathlib.Path`` of the source image. Directories, which
            ``rglob('*')`` also yields, are skipped.
        sz: Target size in pixels (shorter side after resize, side of the crop).
        dir_imgs_new: Output directory, created if missing. The output is
            flat, so source files sharing a name overwrite each other.
    """
    if not file_img.is_file():
        return
    # Resize inside the ``with`` so the source file handle is closed afterwards;
    # the resized image is a new object that no longer depends on it.
    with Image.open(file_img) as img_src:
        img_resized = F.resize(img_src, sz)
    img_crop = F.center_crop(img_resized, sz)
    # Save image
    dir_imgs_new.mkdir(parents=True, exist_ok=True)
    file_img_new = dir_imgs_new/file_img.name
    img_crop.save(file_img_new)

In [14]:
def _parallel_img_pipeline(dir_imgs, sz):
    """Run ``_img_pipeline`` in parallel on every path found under ``dir_imgs``.

    Output goes to ``dir_content/formatted_<sz>``.
    """
    dir_out = dir_content/f'formatted_{sz}'
    worker = partial(_img_pipeline, sz=sz, dir_imgs_new=dir_out)
    parallel(worker, dir_imgs.rglob('*'))

In [15]:
_parallel_img_pipeline(dir_imgs=dir_raw, sz=96)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [16]:
_parallel_img_pipeline(dir_imgs=dir_raw, sz=256)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




# Create trn/val split

In [29]:
# Reproducible split: list regular files only, sort them (rglob order depends
# on the filesystem), then shuffle with a fixed seed.
SEED_SPLIT = 42
FRAC_TRN = 0.975
files_img = sorted(f.name for f in dir_raw.rglob('*') if f.is_file())
# Later steps address images by bare file name in flat folders, so names must be unique.
assert len(set(files_img)) == len(files_img), 'duplicate image file names under dir_raw'
random.Random(SEED_SPLIT).shuffle(files_img)
idx_split = int(FRAC_TRN*len(files_img))
files_trn, files_val = files_img[:idx_split], files_img[idx_split:]
df_trn, df_val = pd.DataFrame({'name': files_trn}), pd.DataFrame({'name': files_val})
file_csv_trn, file_csv_val = dir_content/'trn.csv', dir_content/'val.csv'
df_trn.to_csv(file_csv_trn, index=False)
df_val.to_csv(file_csv_val, index=False)

# Arrange images in the folder layout DALI's FileReader expects

FileReader takes each image's label from the name of its parent folder. No labels are used here, so all images go into a single dummy folder (`imgs`).

In [34]:
def _format_pipeline(dir_imgs, t):
    dir_imgs_new = (dir_imgs/t/'imgs')
    dir_imgs_new.mkdir(parents=True, exist_ok=True)
    df = pd.read_csv(dir_imgs.parent/f'{t}.csv')
    for n in df.name:
        shutil.move((dir_imgs/n).as_posix(), dir_imgs_new)

In [35]:
_format_pipeline(dir_imgs=dir_content/'formatted_96', t='trn')
_format_pipeline(dir_imgs=dir_content/'formatted_96', t='val')

In [36]:
_format_pipeline(dir_imgs=dir_content/'formatted_256', t='trn')
_format_pipeline(dir_imgs=dir_content/'formatted_256', t='val')

# Also create TFRecords (with index files) for DALI

In [123]:
def _bytes_feature(value):
    """Wrap a single bytes ``value`` in a ``tf.train.Feature`` (bytes_list)."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)

In [143]:
def imgs2tfrecord(files_img, file_tfrecord):
    """Write each image's raw bytes into a TFRecord, then build its index file.

    Each record is a ``tf.train.Example`` with a single ``'encoded'`` bytes
    feature holding the file contents unchanged (no decoding or re-encoding).
    After writing, ``tfrecord2idx`` is run to create ``<stem>.idx`` next to
    the TFRecord.

    Args:
        files_img: Iterable of image paths.
        file_tfrecord: ``pathlib.Path`` of the TFRecord to write.

    Raises:
        subprocess.CalledProcessError: if ``tfrecord2idx`` exits non-zero.
        FileNotFoundError: if ``tfrecord2idx`` cannot be found on ``PATH``.
    """
    with tf.io.TFRecordWriter(file_tfrecord.as_posix()) as writer:
        for file_img in files_img:
            # Close each handle immediately instead of relying on garbage collection.
            with open(file_img, 'rb') as fh:
                bytes_img = fh.read()
            example = tf.train.Example(features=tf.train.Features(feature={
                'encoded': _bytes_feature(bytes_img)
            }))
            writer.write(example.SerializeToString())
    file_idx = file_tfrecord.with_suffix('.idx')
    # Argument list instead of an os.system shell string: paths with spaces or
    # shell metacharacters pass through safely, and a failing index build
    # raises instead of being silently ignored.
    subprocess.run(['tfrecord2idx', str(file_tfrecord), str(file_idx)], check=True)

In [144]:
def _imgs2tfrecord_pipeline(dir_imgs, t):
    """Pack every image under ``dir_imgs/<t>/imgs`` into ``dir_imgs/<t>.tfrecord``."""
    dir_split_imgs = dir_imgs/t/'imgs'
    file_out = dir_imgs/f'{t}.tfrecord'
    imgs2tfrecord(dir_split_imgs.rglob('*'), file_out)

In [145]:
_imgs2tfrecord_pipeline(dir_imgs=dir_content/'formatted_96', t='trn')
_imgs2tfrecord_pipeline(dir_imgs=dir_content/'formatted_96', t='val')

In [146]:
_imgs2tfrecord_pipeline(dir_imgs=dir_content/'formatted_256', t='trn')
_imgs2tfrecord_pipeline(dir_imgs=dir_content/'formatted_256', t='val')