In [1]:
import os
import numpy as np
import fnmatch
import warnings
from multiprocessing import Pool
from tqdm import tqdm
from skimage.util.shape import view_as_blocks
from skimage import io
from params import dresden_images_root, images_db_path, patch_span, patch_num, patches_root, patches_db_path

In [2]:
warnings.filterwarnings("ignore")

In [3]:
# The image database is a dict (it is indexed by keys such as 'brand_model',
# 'path', 'shot' below); .item() unwraps it from the array np.load returns.
# NOTE(review): allow_pickle=True can execute arbitrary code from the file —
# only point images_db_path at a trusted, locally generated database.
images_db = np.load(file=images_db_path, allow_pickle=True).item()

In [4]:
# Sorted, de-duplicated camera identifiers (e.g. "Agfa_DC-504") in the database.
model_list = np.unique(ar=images_db['brand_model'])

In [5]:
# os.listdir returns entries in arbitrary, filesystem-dependent order; sorting
# makes the seeded shuffle in the next cell reproducible across machines/runs.
# NOTE(review): every directory entry is treated as an image — hidden or
# non-image files in dresden_images_root would end up in the split.
img_list = sorted(os.listdir(dresden_images_root))
test_fraction = 0.2  # share of images held out for the test set
num_test = int(len(img_list) * test_fraction)
num_train = len(img_list) - num_test

In [6]:
# A fixed seed makes the train/test split identical across kernel restarts
# (given a deterministic img_list order). A dedicated RandomState is used so
# NumPy's global RNG state is left untouched.
split_seed = 0
shuffle_list = np.random.RandomState(split_seed).permutation(img_list)

train_list = shuffle_list[:num_train].tolist()
test_list = shuffle_list[num_train:].tolist()

In [7]:
# Image file names start with "<brand_model>_" (e.g. "Agfa_DC-504_0_1...").
# Matching on the trailing "_" prevents a model whose name is a prefix of
# another model's name from also counting that other model's images; plain
# startswith also avoids fnmatch wildcard/case-folding surprises.
for model in model_list:
    prefix = model + '_'
    n_train = sum(name.startswith(prefix) for name in train_list)
    n_test = sum(name.startswith(prefix) for name in test_list)
    print("{} in training set: {}.".format(model, n_train))
    print("{} in test set: {}.\n".format(model, n_test))

Agfa_DC-504 in training set: 134.
Agfa_DC-504 in test set: 35.

Agfa_DC-733s in training set: 227.
Agfa_DC-733s in test set: 54.

Agfa_DC-830i in training set: 290.
Agfa_DC-830i in test set: 73.

Agfa_Sensor505-x in training set: 139.
Agfa_Sensor505-x in test set: 33.

Agfa_Sensor530s in training set: 297.
Agfa_Sensor530s in test set: 75.

Canon_Ixus55 in training set: 177.
Canon_Ixus55 in test set: 47.

Canon_Ixus70 in training set: 458.
Canon_Ixus70 in test set: 109.

Canon_PowerShotA640 in training set: 147.
Canon_PowerShotA640 in test set: 41.



In [8]:
def patchify(img_name, patch_span, pacth_size=(256, 256)):
    """Read an image and tile its central region into non-overlapping patches.

    A ``patch_span`` x ``patch_span`` window centred on the image is cropped,
    reduced to a single channel (channel index 1 — green for RGB — of a
    multi-channel image; a 2-D grayscale image is used as-is) and split into
    blocks of shape ``pacth_size`` with ``skimage.util.view_as_blocks``.

    Args:
        img_name: Path of the image file to read.
        patch_span: Side length, in pixels, of the central crop. The crop must
            be divisible by ``pacth_size`` along each axis.
        pacth_size: (rows, cols) of each patch. The misspelled name is kept
            for backward compatibility with keyword callers.

    Returns:
        4-D array of shape (n_rows, n_cols, pacth_size[0], pacth_size[1]).

    Raises:
        IOError: if the image could not be read.
        ValueError: if the central crop does not fit inside the image, or the
            crop is not divisible by ``pacth_size``.
    """
    img = io.imread(img_name)
    if img is None or not isinstance(img, np.ndarray):
        raise IOError('Unable to read the image: {:}'.format(img_name))

    center = np.divide(img.shape[:2], 2).astype(int)
    start = np.subtract(center, patch_span / 2).astype(int)
    end = np.add(center, patch_span / 2).astype(int)
    # A negative start would silently wrap around in the slice below and
    # yield a wrong crop instead of an error.
    if np.any(start < 0) or np.any(end > img.shape[:2]):
        raise ValueError('Image {} of shape {} is too small for a central crop of '
                         '{} pixels'.format(img_name, img.shape[:2], patch_span))
    sub_img = img[start[0]:end[0], start[1]:end[1]]
    if sub_img.ndim == 3:
        sub_img = sub_img[:, :, 1]
    return view_as_blocks(sub_img, tuple(pacth_size))

In [9]:
def extract_and_save(args):
    """Cut the central patches of one image and save each one as a PNG.

    Intended as the work function for ``Pool.map``. ``args`` is a dict with
    the keys ``img_path``, ``img_brand``, ``img_model``, ``patch_span``,
    ``patch_num``, ``patch_root`` and ``img_root``.

    Patches go to ``<patch_root>/<brand>/<model>/<stem>/<stem>_<NN>.png``,
    e.g. 'Agfa/DC-504/Agfa_DC-504_0_1/Agfa_DC-504_0_1_00.png', where ``NN``
    is the zero-padded patch index. The image is read only when at least one
    expected patch file is missing, and existing files are never overwritten,
    so re-running the extraction is cheap.

    Returns:
        The ``patch_num`` patch paths, relative to ``patch_root``.
    """
    img_stem = os.path.splitext(os.path.basename(args['img_path']))[0]
    patch_dir = os.path.join(args['img_brand'], args['img_model'], img_stem)
    output_rel_paths = [os.path.join(patch_dir, '{}_{:02}.png'.format(img_stem, patch_idx))
                        for patch_idx in range(args['patch_num'])]

    all_present = all(os.path.exists(os.path.join(args['patch_root'], rel_path))
                      for rel_path in output_rel_paths)
    if not all_present:
        img_name = os.path.join(args['img_root'], args['img_path'])
        patches = patchify(img_name, args['patch_span']).reshape((-1, 256, 256))

        # NOTE(review): zip() stops at the shorter sequence — if patchify yields
        # fewer than patch_num patches, the surplus paths are still returned
        # (and recorded in the dataset) but never written.
        for out_path, patch in zip(output_rel_paths, patches):
            out_fullpath = os.path.join(args['patch_root'], out_path)
            os.makedirs(os.path.dirname(out_fullpath), exist_ok=True)
            if not os.path.exists(out_fullpath):
                io.imsave(out_fullpath, patch)

    return output_rel_paths

In [11]:
print('Collecting image data...')
# One work item per database image; each is processed by extract_and_save.
imgs_list = [{'img_path': img_path,
              'img_brand': img_brand,
              'img_model': img_model,
              'patch_span': patch_span,
              'patch_num': patch_num,
              'patch_root': patches_root,
              'img_root': dresden_images_root
              }
             for img_brand, img_model, img_path in
             tqdm(zip(images_db['brand'], images_db['model'], images_db['path']),
                  total=len(images_db['path']))]

print('Extracting patches...')

num_processes = 12
# Using the pool as a context manager shuts the worker processes down once
# map() has returned; otherwise they stay alive for the rest of the kernel session.
# NOTE(review): worker functions defined in a notebook can only be sent to the
# workers under the 'fork' start method; with 'spawn' (macOS/Windows default)
# extract_and_save may need to live in an importable module.
with Pool(processes=num_processes) as pool:
    # map() preserves input order, so patches_paths[i] belongs to images_db row i.
    patches_paths = pool.map(extract_and_save, imgs_list)

# Create patches dataset
print('Creating patches dataset...')
patch_dataset = dict()
patch_dataset['path'] = []
patch_dataset['shot'] = []

for patch_rel_paths, img_shot in tqdm(zip(patches_paths, images_db['shot']),
                                      total=len(patches_paths)):
    # Every patch inherits the 'shot' value of the image it was cut from.
    patch_dataset['path'].extend(patch_rel_paths)
    patch_dataset['shot'].extend([img_shot] * len(patch_rel_paths))

patch_dataset['path'] = np.asarray(patch_dataset['path']).flatten()
patch_dataset['shot'] = np.asarray(patch_dataset['shot']).flatten()
assert len(patch_dataset['path']) == len(patch_dataset['shot']), 'path/shot misaligned'

print('Saving patches dataset to: {:}'.format(patches_db_path))
np.save(patches_db_path, patch_dataset)

print('Completed.')

2336it [00:00, 473465.46it/s]

Collecting image data...
Extracting patches...



2336it [00:00, 169305.77it/s]

Creating patches dataset...





Saving patches dataset to: patches/patch.npy
Completed.
