In [1]:
import os
import numpy as np
import fnmatch
import warnings
from multiprocessing import Pool
from tqdm import tqdm
from skimage.util.shape import view_as_blocks
from skimage import io
from params import dresden_images_root, images_db_path, patch_span, \
        patch_num, patches_root, train_db_path, test_db_path

In [2]:
warnings.filterwarnings("ignore")

In [3]:
images_db = np.load(images_db_path, allow_pickle=True).item()

model_list = np.unique(images_db['brand_model'])

In [4]:
img_list = images_db['path']
num_test = int(len(img_list) * 0.2)
num_train = int(len(img_list) - num_test)
num_val = int(num_train * 0.2)
num_train = num_train - num_val

In [5]:
shuffle_list = np.random.permutation(img_list)

train_list = shuffle_list[0:num_train].tolist()
val_list = shuffle_list[num_train:num_train + num_val].tolist()
test_list = shuffle_list[num_train + num_val:].tolist()

In [6]:
for model in model_list:
    tmp_list = fnmatch.filter(train_list, model + '*')
    print("{} in training set: {}.".format(model, len(tmp_list)))
    tmp_list = fnmatch.filter(val_list, model + '*')
    print("{} in validation set: {}.".format(model, len(tmp_list)))
    tmp_list = fnmatch.filter(test_list, model + '*')
    print("{} in test set: {}.\n".format(model, len(tmp_list)))

Canon_Ixus55 in training set: 146.
Canon_Ixus55 in validation set: 38.
Canon_Ixus55 in test set: 40.

Canon_Ixus70 in training set: 359.
Canon_Ixus70 in validation set: 85.
Canon_Ixus70 in test set: 123.

Canon_PowerShotA640 in training set: 123.
Canon_PowerShotA640 in validation set: 33.
Canon_PowerShotA640 in test set: 32.



In [7]:
def patchify(img_name, patch_span, pacth_size=(256, 256)):
    img = io.imread(img_name)
    if img is None or not isinstance(img, np.ndarray):
        print('Unable to read the image: {:}'.format(args['img_path']))

    center = np.divide(img.shape[:2], 2).astype(int)
    start = np.subtract(center, patch_span/2).astype(int)
    end = np.add(center, patch_span/2).astype(int)
    sub_img = img[start[0]:end[0], start[1]:end[1]]
    patches = view_as_blocks(sub_img[:, :, 1], (256, 256))
    return patches

In [8]:
def extract_and_save(args):
    # 'Agfa_DC-504/Agfa_DC-504_0_1_00.png' for example,
    # last part is the patch idex.
    output_rel_paths = [os.path.join(args['data_set'], args['img_brand_model'],
                        os.path.splitext(os.path.split(args['img_path'])[-1])[0]+'_'+'{:02}'.format(patch_idx) + '.png')\
                        for patch_idx in range(args['patch_num'])]
    read_img = False
    for out_path in output_rel_paths:
        out_fullpath = os.path.join(args['patch_root'], out_path)
        # if there is no this path, then we have to read images
        if not os.path.exists(out_fullpath):
            read_img = True
            break
    if read_img:
        img_name = os.path.join(args['img_root'], args['img_path'])
        patches = patchify(img_name, args['patch_span']).reshape((-1, 256, 256))
        
        for out_path, patch in zip(output_rel_paths, patches):
            out_fullpath = os.path.join(args['patch_root'],out_path)
            # the diretory of the patches images
            out_fulldir = os.path.split(out_fullpath)[0]
            if not os.path.exists(out_fulldir):
                os.makedirs(out_fulldir)
            if not os.path.exists(out_fullpath):
                io.imsave(out_fullpath, patch)

    return output_rel_paths

In [9]:
files_labels = dict(zip(images_db['path'], images_db['brand_model']))

train_labels = []
val_labels = []
test_labels = []

for path in train_list:
    train_labels += [files_labels[path]]
    
for path in val_list:
    val_labels += [files_labels[path]]
    
for path in test_list:
    test_labels += [files_labels[path]]

In [11]:
print('Collecting image data...')

train_imgs_list = []
val_imgs_list = []
test_imgs_list = []

for img_brand_model,img_path in \
    tqdm(zip(images_db['brand_model'], images_db['path'])):
                   
    if img_path in train_list:
        train_imgs_list += [{'data_set':'train',
                           'img_path':img_path,
                           'img_brand_model':img_brand_model,
                           'patch_span':patch_span,
                           'patch_num':patch_num,
                           'patch_root': patches_root,
                           'img_root': dresden_images_root
                           }]
        
    elif img_path in val_list:
        val_imgs_list += [{'data_set':'val',
                           'img_path':img_path,
                           'img_brand_model':img_brand_model,
                           'patch_span':patch_span,
                           'patch_num':patch_num,
                           'patch_root': patches_root,
                           'img_root': dresden_images_root
                           }]
        
    else:
        test_imgs_list += [{'data_set':'test',
                           'img_path':img_path,
                           'img_brand_model':img_brand_model,
                           'patch_span':patch_span,
                           'patch_num':patch_num,
                           'patch_root': patches_root,
                           'img_root': dresden_images_root
                           }]
    
print('Extracting patches...')

num_processes = 12
pool = Pool(processes=num_processes)
train_paths = pool.map(extract_and_save, train_imgs_list)
val_paths = pool.map(extract_and_save, val_imgs_list)
test_paths = pool.map(extract_and_save, test_imgs_list)

# # Create patches dataset
# print('Creating patches dataset...')
# train_dataset = dict()
# train_dataset['path'] = []
# train_dataset['labels'] = []

# for patch_rel_paths, img_labels in tqdm(zip(train_paths, train_labels)):
#     for patch_rel_path in patch_rel_paths:
#         train_dataset['path'] += [patch_rel_path]
#         train_dataset['labels'] += [img_labels]

# train_dataset['path'] = np.asarray(train_dataset['path']).flatten()
# train_dataset['shot'] = np.asarray(train_dataset['labels']).flatten()
        
# test_dataset = dict()
# test_dataset['path'] = []
# test_dataset['labels'] = []

# for patch_rel_paths, img_labels in tqdm(zip(test_paths, test_labels)):
#     for patch_rel_path in patch_rel_paths:
#         test_dataset['path'] += [patch_rel_path]
#         test_dataset['labels'] += [img_labels]

# test_dataset['path'] = np.asarray(test_dataset['path']).flatten()
# test_dataset['labels'] = np.asarray(test_dataset['labels']).flatten()

# print('Saving training patches dataset to: {:}'.format(train_db_path))
# np.save(train_db_path, train_dataset)
# print('Saving testing patches dataset to: {:}'.format(test_db_path))
# np.save(test_db_path, test_dataset)

print('Completed.')

979it [00:00, 70851.93it/s]

Collecting image data...
Extracting patches...





FileExistsError: [Errno 17] File exists: 'patches/test/Canon_Ixus55'