In [None]:
# default_exp image_downloader

# Image downloader

> Helper class for downloading images from Google Image Search. Adapted from the fastai repository ([link](https://github.com/fastai/fastai/blob/eb6b2eab34cc5a65e338df1cec91fb7296981048/fastai/widgets/image_downloader.py)).

In [None]:
# export
# fastai is an optional dependency: the module still imports without it, but the
# names pulled in by the star imports below will be missing at call time.
try:
    from fastai.core import *
    from fastai.vision.data import *
except ImportError:
    # Only a failed import means "not installed"; other errors (and
    # KeyboardInterrupt/SystemExit) must propagate instead of being mislabeled.
    print("fastai not installed. To use the ImageDownloader widget install fastai first.")
from urllib.parse import quote
from bs4 import BeautifulSoup
import time

In [None]:
# exports

class ImageDownloader:
    """Collect images for a dataset by scraping the Google Image Search results page.

    Images for a class are written to ``data_path/dataset_name/class_name``.
    """

    def __init__(self, data_path, dataset_name):
        """The ImageDownloader helps download images from the google image search page.

        Args:
            data_path: Root directory that holds all datasets (created if missing).
            dataset_name: Sub-directory of `data_path` for this dataset (created if missing).
        """
        self._path = Path(data_path)
        # Must use `self._path`; there is no public `path` attribute on this class.
        self._dataset_path = self._path/dataset_name
        self._dataset_name = dataset_name

        os.makedirs(self._path, exist_ok=True)
        os.makedirs(self._dataset_path, exist_ok=True)

    def add_images_to_class(self, class_name, google_query, n_images=1000):
        """Add new images to the image class with a Google search query.

        Args:
            class_name: Name of the class folder inside the dataset directory.
            google_query: Search term sent to Google Image Search.
            n_images: Upper bound on the number of image links to download; it also
                controls how far the results page is scrolled.
        """
        class_path = self._dataset_path/class_name
        url = _search_url(google_query)
        html = self.get_google_image_html(url, n_images=n_images)
        img_urls = self.get_img_urls_from_html(html)
        print(f'{len(img_urls)} image links found on Google image search for the query "{google_query}".')
        # Honour the requested limit; scrolling may reveal more links than asked for.
        img_fnames = _download_images(class_path, img_urls[:max(n_images, 0)])
        print(f'{len(img_fnames)} images now available in class {class_name}.')

    def get_google_image_html(self, url, n_images=1000):
        """Get the html code of the Google Image Search.

        Args:
            url: Google Image Search URL to load.
            n_images: Number of images wanted; the page is scrolled
                ``n_images // 100 + 1`` times (presumably each scroll reveals
                roughly 100 more results -- TODO confirm).

        Returns:
            The page source of the results page after scrolling.

        Raises:
            Whatever `webdriver.Chrome` raises when the driver cannot be started
            (re-raised after printing a hint).
        """
        # NOTE(review): `webdriver` (selenium) is not imported in the visible import
        # cell -- confirm it is in scope before this method is called.
        options = webdriver.ChromeOptions()
        options.add_argument("--headless")
        options.add_argument('--no-sandbox')
        try:
            # `options=` is accepted by Selenium 3.8+ and 4.x; `chrome_options=` was removed.
            driver = webdriver.Chrome(options=options)
        except Exception:
            print("""Error initializing chromedriver. 
                  Check if it's in your path by running `which chromedriver`""")
            # Without a driver nothing below can work; surface the real error
            # instead of failing later on an unbound `driver`.
            raise

        try:
            driver.set_window_size(1440, 900)
            driver.get(url)
            old_height = 0
            for _ in range(n_images // 100 + 1):
                driver.execute_script("window.scrollTo(0, document.body.scrollHeight);")
                # Randomised pause so the page has time to load more results.
                time.sleep(1.0 + random.random())
                new_height = driver.execute_script("return document.body.scrollHeight")
                if new_height == old_height:
                    # The page stopped growing: look for the "Show more results" button.
                    # The string locator "xpath" works on both Selenium 3 and 4.
                    buttons = driver.find_elements(
                        "xpath", "//input[@type='button' and @value='Show more results']")
                    if buttons:
                        try:
                            buttons[0].click()
                        except Exception:
                            # Best effort: a button that cannot be clicked (hidden,
                            # stale, ...) just means no further results get loaded.
                            pass
                old_height = new_height
            return driver.page_source
        finally:
            # Always shut the headless browser down so no Chrome process is leaked.
            driver.quit()

    def get_img_urls_from_html(self, html):
        """Return the `data-src` value of every `<img>` tag in `html` that has one."""
        bs = BeautifulSoup(html, 'html.parser')
        return [tag['data-src'] for tag in bs.find_all('img') if tag.has_attr('data-src')]

def _download_images(label_path:PathOrStr, img_urls:list, max_workers:int=defaults.cpus, timeout:int=4) -> FilePathList:
    """
    Downloads every URL in `img_urls` to `label_path`.
    If the directory doesn't exist, it'll be created automatically.
    Uses `parallel` to speed things up in `max_workers` when the system has enough CPU cores.
    If something doesn't work, try setting up `max_workers=0` to debug.

    Returns whatever `get_image_files` finds in `label_path` after the downloads.
    """
    # Normalise once: the per-image worker joins file names with `/`, which a
    # plain `str` path does not support.
    label_path = Path(label_path)
    os.makedirs(label_path, exist_ok=True)
    download_one = partial(_download_single_image, label_path, timeout=timeout)
    parallel(download_one, img_urls, max_workers=max_workers)
    return get_image_files(label_path)

def _download_single_image(label_path:Path, img_url:str, i:int, timeout:int=4) -> None:
    """
    Downloads a single image from Google Search results to `label_path`.

    `img_url` is the image URL (a string). The file name is the text between the
    first `%` and the following `&` of the URL, plus `.png`; when the URL has no
    such segment a hash of the URL is used instead.
    `i` is just an iteration number `int` and is not used.
    """
    import hashlib  # function-scope import: only needed for the fallback file name

    parts = img_url.split('%')
    stem = parts[1].split('&')[0] if len(parts) > 1 else ''
    # A missing/empty segment or one containing a path separator cannot be used
    # as a file name; fall back to a stable, unique name derived from the URL so
    # one odd link does not abort the whole batch.
    if not stem or '/' in stem or '\\' in stem:
        stem = hashlib.sha1(img_url.encode('utf-8')).hexdigest()
    download_url(img_url, Path(label_path)/(stem + '.png'), timeout=timeout)
    
def _search_url(search_term:str, size:str='>400*300', format:str='jpg') -> str:
    "Build the Google Images Search URL for `search_term`, with `size`/`format` passed through `_url_params`."
    # Pieces are listed in the order they appear in the final URL.
    pieces = [
        'https://www.google.com/search?q=',
        quote(search_term),
        '&espv=2&biw=1366&bih=667&site=webhp&source=lnms&tbm=isch',
        _url_params(size, format),
        '&sa=X&ei=XosDVaCXD8TasATItgE&ved=0CAcQ_AUoAg',
    ]
    return ''.join(pieces)