In [1]:
import os
import numpy as np
import pandas as pd
import warnings
import split_extract
from multiprocessing import Pool
from tqdm import tqdm
from params import dresden_images_root, train_csv_path, patch_span, \
        patch_num, patches_root, patches_db_path

In [2]:
warnings.filterwarnings("ignore")

if not os.path.exists(patches_root):
    os.makedirs(patches_root)

In [3]:
images_db = pd.read_csv(train_csv_path)

model_list = np.unique(images_db['brand_model'])
img_list = images_db['path']

if not os.path.exists(patches_db_path):
    train_list, val_list, test_list = split_extract.split(img_list, model_list, patches_db_path)
else:
    patches_db = np.load(patches_db_path, allow_pickle=True).item()
    train_list = patches_db['train']
    val_list = patches_db['val']
    test_list = patches_db['test']
    split_extract.split_info(train_list, val_list, test_list, model_list)

Canon_Ixus70 in training set: 352.
Canon_Ixus70 in validation set: 103.
Canon_Ixus70 in test set: 112.

Nikon_D200 in training set: 490.
Nikon_D200 in validation set: 113.
Nikon_D200 in test set: 149.

Olympus_mju-1050SW in training set: 669.
Olympus_mju-1050SW in validation set: 161.
Olympus_mju-1050SW in test set: 210.



In [4]:
print('Collecting image data...')
train_imgs_list = []
val_imgs_list = []
test_imgs_list = []

for img_brand_model,img_path in \
    tqdm(zip(images_db['brand_model'], images_db['path'])):
                   
    if img_path in train_list:
        train_imgs_list += [{'data_set':'train',
                           'img_path':img_path,
                           'img_brand_model':img_brand_model,
                           'patch_span':patch_span,
                           'patch_num':patch_num,
                           'patch_root': patches_root,
                           'img_root': dresden_images_root
                           }]
        
    elif img_path in val_list:
        val_imgs_list += [{'data_set':'val',
                           'img_path':img_path,
                           'img_brand_model':img_brand_model,
                           'patch_span':patch_span,
                           'patch_num':patch_num,
                           'patch_root': patches_root,
                           'img_root': dresden_images_root
                           }]
        
    else:
        test_imgs_list += [{'data_set':'test',
                           'img_path':img_path,
                           'img_brand_model':img_brand_model,
                           'patch_span':patch_span,
                           'patch_num':patch_num,
                           'patch_root': patches_root,
                           'img_root': dresden_images_root
                           }]

# num_processes = 12
num_processes = 4
pool = Pool(processes=num_processes)

print('Extracting training patches...')
train_paths = pool.map(split_extract.extract, train_imgs_list)
print('Extracting validation patches...')
val_paths = pool.map(split_extract.extract, val_imgs_list)
print('Extracting testing patches...')
test_paths = pool.map(split_extract.extract, test_imgs_list)
print('Completed.')

Collecting image data...


NameError: name 'images_db' is not defined