In [0]:
import os
import re
import requests
import random
import json
import glob
import shutil
import multiprocessing as mp
from PIL import Image
from tqdm import tqdm

random.seed(2020)

output_size = 512
num_workers = 8

KEYS = ['PAIR', 'TOPP', 'BOTT', 'LEFT', 'BACK', 'RGHT', 'FRNT']
# only download some keys, because its too large for colab
KEYS_DOWNLOAD = KEYS[:1]

In [0]:
# create folders holding downloaded items
for key in KEYS_DOWNLOAD:
    os.makedirs(key.lower(), exist_ok=True)

# load dataset
if not os.path.exists('dataset.json'):
    shutil.copy('drive/My Drive/dataset.json', '.')

with open('dataset.json', 'r') as f:
    DATASET = json.load(f)

In [0]:
# request using proxies to prevent IP banned

PROXIES = None # list of proxies

def get_proxies():
    """Get all proxies from free-proxy-list.net"""
    res = requests.get('https://free-proxy-list.net/')
    pattern = r'<tr><td>([\d\.]+)<\/td><td>([\d]+)<\/td>'
    proxies = re.findall(pattern, res.text)
    proxies = ['http://{}:{}'.format(ip, port) for (ip, port) in proxies[:20]]
    return proxies

def _request(url):
    """request with proxies"""
    for proxy in random.sample(PROXIES, len(PROXIES)):
        try:
            res = requests.get(url, proxies={'http': proxy})
            return res
        except:
            pass
    # all proxies are dead
    return None

def request(url):
    """request and update new PROXIES"""
    global PROXIES
    if PROXIES is None:
        PROXIES = get_proxies()
    res = _request(url)
    if res:
        return res
    # refresh proxies
    PROXIES = get_proxies()
    res = _request(url)
    return res

In [0]:
def download(idx):
    """download using multiprocessing"""

    if idx >= len(DATASET):
        return
    img_id = list(DATASET.keys())[idx]
    img_data = DATASET[img_id]

    for key, link in img_data['images'].items():
        if key not in KEYS_DOWNLOAD:
            continue
        output = os.path.join(key.lower(), '{}-{}.png'.format(img_id, key.lower()))
        if not os.path.exists(output):
            # download
            res = request(link)
            with open(output, 'wb') as f:
                f.write(res.content)

            # remove broken images
            try:
                im = Image.open(output).convert('RGB')
            except:
                os.remove(output)
                continue

            # resize into output_size
            w, h = im.size
            if (w, h) == (output_size, output_size):
                continue
            bg = Image.new('RGB', (max(im.size), max(im.size)), color=(255, 255, 255))
            w1, h1, = bg.size
            bg.paste(im, ((w1 - w)//2, (h1 - h)//2))
            bg = bg.resize((output_size, output_size))
            bg.save(output)

In [0]:
pool = mp.Pool(num_workers)
for i in tqdm(range((len(list(DATASET.keys())) + num_workers)//num_workers)):
    args = [num_workers*i + j for j in range(num_workers)]
    pool.map(download, args)

  3%|▎         | 238/8995 [03:59<2:23:28,  1.02it/s]

In [0]:
# check if all images are squared (output_size x output_size)
images = glob.glob('*/*.png')
for img in tqdm(images):
    im = Image.open(img)
    if im.size != (output_size, output_size):
        print(img)
        break

In [0]:
# zip and save to drive
save_folder = 'drive/My Drive/UIT-SHOESGAN/'
for key in KEYS_DOWNLOAD:
    shutil.make_archive(key.lower(), format='zip', root_dir=key.lower())
    shutil.move('{}.zip'.format(key.lower()), save_folder)
    shutil.move(key.lower(), save_folder)