In [1]:
import os
import numpy as np
import fnmatch
import warnings
from multiprocessing import Pool
from tqdm import tqdm
from skimage.util.shape import view_as_blocks
from skimage import io
from params import dresden_images_root, images_db_path, patch_span, \
        patch_num, patches_root, train_db_path, test_db_path

In [2]:
warnings.filterwarnings("ignore")

In [3]:
# Load the image index (a dict pickled with np.save). allow_pickle is only safe
# because images_db_path is a project-generated, trusted file.
images_db = np.load(images_db_path, allow_pickle=True).item()

model_list = np.unique(images_db['brand_model'])

# Hold out a fixed fraction of images for testing. os.listdir returns entries in
# arbitrary order, so sort first and shuffle with a seeded generator to make the
# split reproducible across kernels and machines.
TEST_FRACTION = 0.2
SPLIT_SEED = 0

img_list = sorted(os.listdir(dresden_images_root))
num_test = int(len(img_list) * TEST_FRACTION)
num_train = len(img_list) - num_test

shuffle_list = np.random.RandomState(SPLIT_SEED).permutation(img_list)

train_list = shuffle_list[:num_train].tolist()
test_list = shuffle_list[num_train:].tolist()

# The two lists must partition the image directory.
assert not set(train_list) & set(test_list)
assert len(train_list) + len(test_list) == len(img_list)

# Report how many images of each camera model fell into each split.
for model in model_list:
    tmp_list = fnmatch.filter(train_list, model + '*')
    print("{} in training set: {}.".format(model, len(tmp_list)))
    tmp_list = fnmatch.filter(test_list, model + '*')
    print("{} in test set: {}.\n".format(model, len(tmp_list)))

Agfa_DC-504 in training set: 139.
Agfa_DC-504 in test set: 30.

Agfa_DC-733s in training set: 214.
Agfa_DC-733s in test set: 67.

Agfa_DC-830i in training set: 289.
Agfa_DC-830i in test set: 74.

Agfa_Sensor505-x in training set: 141.
Agfa_Sensor505-x in test set: 31.

Agfa_Sensor530s in training set: 295.
Agfa_Sensor530s in test set: 77.

Canon_Ixus55 in training set: 187.
Canon_Ixus55 in test set: 37.

Canon_Ixus70 in training set: 456.
Canon_Ixus70 in test set: 111.

Canon_PowerShotA640 in training set: 148.
Canon_PowerShotA640 in test set: 40.



In [4]:
def patchify(img_name, patch_span, pacth_size=(256, 256)):
    """Read an image and tile its central square crop into non-overlapping patches.

    Parameters
    ----------
    img_name : str
        Path of the image file, passed to ``skimage.io.imread``.
    patch_span : int
        Side length, in pixels, of the square region cropped around the image
        centre. It must be a multiple of each ``pacth_size`` entry, otherwise
        ``view_as_blocks`` raises ``ValueError``.
    pacth_size : tuple of int, optional
        ``(height, width)`` of each patch; defaults to ``(256, 256)``. The
        misspelled name is kept so existing keyword callers keep working.

    Returns
    -------
    numpy.ndarray
        A 4-D view of shape ``(rows, cols, height, width)`` over one channel of
        the crop: channel index 1 (green for RGB input) of a multi-channel
        image, or the image itself when it is already 2-D.

    Raises
    ------
    IOError
        If the image cannot be read into a NumPy array.
    ValueError
        If the image is smaller than ``patch_span`` in either dimension, or
        the crop cannot be tiled exactly by ``pacth_size``.
    """
    img = io.imread(img_name)
    if img is None or not isinstance(img, np.ndarray):
        # Previously this printed args['img_path'] (undefined here -> NameError)
        # and then fell through to img.shape; fail with a clear message instead.
        raise IOError('Unable to read the image: {:}'.format(img_name))

    half_span = patch_span / 2
    center = np.divide(img.shape[:2], 2).astype(int)
    start = np.subtract(center, half_span).astype(int)
    end = np.add(center, half_span).astype(int)
    # A negative start would wrap around in slicing and silently give a wrong crop.
    if np.any(start < 0) or np.any(end > np.asarray(img.shape[:2])):
        raise ValueError(
            'Image {:} of size {:} is smaller than patch_span={:}'.format(
                img_name, img.shape[:2], patch_span))
    sub_img = img[start[0]:end[0], start[1]:end[1]]

    # Keep a single channel: index 1 for multi-channel images (as before);
    # grayscale images have no channel axis and are used directly.
    channel = sub_img if sub_img.ndim == 2 else sub_img[:, :, 1]
    # view_as_blocks requires block_shape to be a tuple.
    return view_as_blocks(channel, tuple(pacth_size))

In [5]:
def extract_and_save(args):
    """Cut one image into patches and save each as a PNG, skipping finished work.

    Written as a ``Pool.map`` worker, so everything arrives in one dict.

    Parameters
    ----------
    args : dict
        Job description with keys 'data_set', 'img_path' (relative to
        'img_root'), 'img_brand_model', 'patch_span', 'patch_num',
        'patch_root' and 'img_root'.

    Returns
    -------
    list of str
        ``args['patch_num']`` patch paths relative to ``args['patch_root']``,
        laid out as ``<data_set>/<img_brand_model>/<image stem>_<NN>.png``,
        e.g. 'train/Agfa_DC-504/Agfa_DC-504_0_1_00.png' where the trailing
        two digits are the patch index.

    Raises
    ------
    ValueError
        If the image yields fewer patches than ``args['patch_num']`` (the
        returned paths would otherwise point at files that were never written).
    """
    img_stem = os.path.splitext(os.path.basename(args['img_path']))[0]
    out_dir_rel = os.path.join(args['data_set'], args['img_brand_model'])
    output_rel_paths = [
        os.path.join(out_dir_rel, '{}_{:02}.png'.format(img_stem, patch_idx))
        for patch_idx in range(args['patch_num'])
    ]

    # Only decode the image if at least one of its patches is missing on disk.
    missing = any(not os.path.exists(os.path.join(args['patch_root'], p))
                  for p in output_rel_paths)
    if missing:
        img_name = os.path.join(args['img_root'], args['img_path'])
        blocks = patchify(img_name, args['patch_span'])
        # Flatten the (rows, cols, h, w) grid into a sequence of (h, w) patches.
        patches = blocks.reshape((-1,) + blocks.shape[-2:])
        if len(patches) < args['patch_num']:
            raise ValueError('{:} yields {:} patches but patch_num={:}'.format(
                img_name, len(patches), args['patch_num']))

        for out_path, patch in zip(output_rel_paths, patches):
            out_fullpath = os.path.join(args['patch_root'], out_path)
            # exist_ok avoids a FileExistsError race between pool workers that
            # create the same brand/model directory at the same time.
            os.makedirs(os.path.dirname(out_fullpath), exist_ok=True)
            if not os.path.exists(out_fullpath):
                io.imsave(out_fullpath, patch)

    return output_rel_paths

In [6]:
# Lookup table: image file name -> camera brand/model label.
files_labels = dict(zip(images_db['path'], images_db['brand_model']))

# Per-image labels, aligned index-by-index with train_list and test_list
# (a path missing from images_db raises KeyError).
train_labels = [files_labels[path] for path in train_list]
test_labels = [files_labels[path] for path in test_list]

In [8]:
print('Collecting image data...')

# Set lookups are O(1); testing membership in the lists was O(n) per image.
train_set = set(train_list)
test_set = set(test_list)


def _make_job(data_set, img_path, img_brand_model):
    """Return the argument dict that extract_and_save expects for one image."""
    return {'data_set': data_set,
            'img_path': img_path,
            'img_brand_model': img_brand_model,
            'patch_span': patch_span,
            'patch_num': patch_num,
            'patch_root': patches_root,
            'img_root': dresden_images_root}


train_imgs_list = []
test_imgs_list = []
for img_brand_model, img_path in tqdm(zip(images_db['brand_model'], images_db['path']),
                                      total=len(images_db['path'])):
    if img_path in train_set:
        train_imgs_list.append(_make_job('train', img_path, img_brand_model))
    elif img_path in test_set:
        test_imgs_list.append(_make_job('test', img_path, img_brand_model))
    # Any other db entry is not part of the split (it was not listed in
    # dresden_images_root); a plain ``else`` would push it into the test set.

print('Extracting patches...')

# NOTE(review): 12 workers is a fixed choice — tune to the machine. Workers run
# extract_and_save defined in this notebook, which relies on the 'fork' start
# method; confirm before running on platforms that default to 'spawn'.
num_processes = 12
# The context manager cleans up the worker processes; map() preserves input
# order, so train_paths[i] belongs to train_imgs_list[i].
with Pool(processes=num_processes) as pool:
    train_paths = pool.map(extract_and_save, train_imgs_list)
    test_paths = pool.map(extract_and_save, test_imgs_list)

# # Create patches dataset
# # NOTE(review): labels come from each job's 'img_brand_model' so they stay
# # aligned with train_paths/test_paths. Zipping with train_labels/test_labels
# # (ordered like the shuffled split, not like images_db) mislabels patches,
# # and the training labels used to be stored under the key 'shot'.
# print('Creating patches dataset...')
# train_dataset = {'path': [], 'labels': []}
# for patch_rel_paths, job in tqdm(zip(train_paths, train_imgs_list)):
#     train_dataset['path'] += patch_rel_paths
#     train_dataset['labels'] += [job['img_brand_model']] * len(patch_rel_paths)
# train_dataset['path'] = np.asarray(train_dataset['path'])
# train_dataset['labels'] = np.asarray(train_dataset['labels'])
#
# test_dataset = {'path': [], 'labels': []}
# for patch_rel_paths, job in tqdm(zip(test_paths, test_imgs_list)):
#     test_dataset['path'] += patch_rel_paths
#     test_dataset['labels'] += [job['img_brand_model']] * len(patch_rel_paths)
# test_dataset['path'] = np.asarray(test_dataset['path'])
# test_dataset['labels'] = np.asarray(test_dataset['labels'])
#
# print('Saving training patches dataset to: {:}'.format(train_db_path))
# np.save(train_db_path, train_dataset)
# print('Saving testing patches dataset to: {:}'.format(test_db_path))
# np.save(test_db_path, test_dataset)

print('Completed.')

2336it [00:00, 47928.81it/s]

Collecting image data...
Extracting patches...





Completed.
