In [1]:
import os
import numpy as np
import warnings
import split_extract
from multiprocessing import Pool
from tqdm import tqdm
from params import dresden_images_root, images_db_path, patch_span, \
        patch_num, patches_root, patches_db_path

In [2]:
# Silence all warnings for the rest of the session so the run log stays readable.
warnings.filterwarnings("ignore")

# Make sure the patch output directory exists. exist_ok=True replaces the
# separate exists() check, which could race with another process creating
# the directory between the check and the makedirs() call.
os.makedirs(patches_root, exist_ok=True)

In [3]:
# Load the image database. .item() unwraps the single object stored in the
# saved array; it is indexed by string keys below.
# NOTE(review): allow_pickle=True unpickles on load -- only use with trusted files.
images_db = np.load(images_db_path, allow_pickle=True).item()

img_list = images_db['path']
model_list = np.unique(images_db['brand_model'])

if os.path.exists(patches_db_path):
    # A split already exists on disk: reuse it and print its summary.
    patches_db = np.load(patches_db_path, allow_pickle=True).item()
    train_list, val_list, test_list = (
        patches_db['train'],
        patches_db['val'],
        patches_db['test'],
    )
    split_extract.split_info(train_list, val_list, test_list, model_list)
else:
    # No split on disk yet: compute a fresh one.
    # presumably split() also persists it to patches_db_path -- TODO confirm
    train_list, val_list, test_list = split_extract.split(
        img_list, model_list, patches_db_path
    )

Canon_Ixus55 in training set: 156.
Canon_Ixus55 in validation set: 28.
Canon_Ixus55 in test set: 40.

Canon_Ixus70 in training set: 358.
Canon_Ixus70 in validation set: 96.
Canon_Ixus70 in test set: 113.

Canon_PowerShotA640 in training set: 114.
Canon_PowerShotA640 in validation set: 32.
Canon_PowerShotA640 in test set: 42.



In [4]:
# Map every image path to its brand/model label.
files_labels = dict(zip(images_db['path'], images_db['brand_model']))

# Look up the label of each image, split by split, preserving list order.
train_labels = [files_labels[img] for img in train_list]
val_labels = [files_labels[img] for img in val_list]
test_labels = [files_labels[img] for img in test_list]

In [5]:
print('Collecting image data...')


def _make_job(data_set, img_path, img_brand_model):
    """Build the argument dict passed to split_extract.extract for one image.

    data_set is one of 'train', 'val' or 'test'; the remaining settings are
    taken from the notebook-level configuration imported from params.
    """
    return {'data_set': data_set,
            'img_path': img_path,
            'img_brand_model': img_brand_model,
            'patch_span': patch_span,
            'patch_num': patch_num,
            'patch_root': patches_root,
            'img_root': dresden_images_root
            }


# Membership tests against the raw split lists are linear scans, which makes
# the loop below quadratic in the number of images; sets give O(1) lookups.
train_set = set(train_list)
val_set = set(val_list)

train_imgs_list = []
val_imgs_list = []
test_imgs_list = []

# total= lets tqdm show a real progress bar; zip() alone has no length.
for img_brand_model, img_path in tqdm(
        zip(images_db['brand_model'], images_db['path']),
        total=len(images_db['path'])):

    if img_path in train_set:
        train_imgs_list.append(_make_job('train', img_path, img_brand_model))
    elif img_path in val_set:
        val_imgs_list.append(_make_job('val', img_path, img_brand_model))
    else:
        # Unchanged behaviour: anything outside train/val is treated as test;
        # test_list itself is not consulted here.
        test_imgs_list.append(_make_job('test', img_path, img_brand_model))

# Number of worker processes; lower this on machines with fewer cores.
num_processes = 12

# The context manager shuts the workers down when the block exits, so
# re-running this cell no longer leaks a pool of idle processes. Each map()
# call blocks until all of its results are available.
with Pool(processes=num_processes) as pool:
    print('Extracting training patches...')
    train_paths = pool.map(split_extract.extract, train_imgs_list)
    print('Extracting validation patches...')
    val_paths = pool.map(split_extract.extract, val_imgs_list)
    print('Extracting testing patches...')
    test_paths = pool.map(split_extract.extract, test_imgs_list)
print('Completed.')

979it [00:00, 52436.16it/s]

Collecting image data...





Extracting training patches...
Extracting validation patches...
Extracting testing patches...
Completed.
