In [None]:
import warnings
# Ignore every UserWarning from this point on (presumably to keep cell output readable).
warnings.simplefilter(action='ignore', category=UserWarning)

In [None]:
from duckduckgo_search import DDGS # pip install duckduckgo_search
from fastdownload import download_url # conda install fastdownload
from fastcore.all import * # conda install -c fastai fastai or conda install pytorch torchvision -c pytorch

from fastai.vision.all import *

from urllib.error import HTTPError

In [None]:
def search_images(term, max_images=2):
    """Run a DuckDuckGo image search for `term`.

    Prints the search term, then asks `DDGS().images` for at most
    `max_images` results and returns whatever that call returns.
    Callers in this notebook read the picture URL from each result's
    'image' key.
    """
    print(f"Searching for '{term}'")
    client = DDGS()
    results = client.images(keywords=term, max_results=max_images)
    return results

In [None]:
# Look up one bird picture, save it to disk and preview it as a thumbnail.
bird_dest = 'data/bird.jpg'
# The search relies on DuckDuckGo; if it errors, just run the cell again.
bird_urls = search_images('bird photos', max_images=1)
print(bird_urls)

# The first search result carries the picture address under its 'image' key.
first_bird_url = bird_urls[0]['image']
download_url(first_bird_url, bird_dest, show_progress=True)

im = Image.open(bird_dest)
im.to_thumb(256,256)

In [None]:
# Look up one forest picture, save it to disk and preview it as a thumbnail.
forest_dest = 'data/forest.jpg'
forest_urls = search_images('forest photos', max_images=1)

# The first search result carries the picture address under its 'image' key.
first_forest_url = forest_urls[0]['image']
download_url(first_forest_url, forest_dest, show_progress=True)

im = Image.open(forest_dest)
im.to_thumb(256,256)

In [None]:
from time import sleep

# Categories to collect; each one gets its own sub-folder under `path`,
# and the folder name is later used as the label.
searches = 'forest','bird'
path = Path('data/bird_or_not')

# Each category is searched once per variant to get more varied pictures.
search_variants = 'photo', 'sun photo', 'shade photo'

for search in searches:
    dest = (path/search)
    dest.mkdir(exist_ok=True, parents=True)

    for variant in search_variants:
        # Search one variant of the term and download every hit into the category folder.
        images = search_images(f'{search} {variant}', max_images=30)
        url_list = [image['image'] for image in images]
        download_images(dest, urls=url_list)
        # Pause after every search (including the last one for a category, so the
        # next category's first search is also spaced out) to avoid over-loading the server.
        sleep(10)

    # Resize all the images in the folder, in place.
    resize_images(dest, max_size=400, dest=dest)

In [None]:
# Remove the image files that verify_images reports as failed (did not download correctly).
path = Path('data/bird_or_not')
image_files = get_image_files(path)
failed = verify_images(image_files)
for broken_file in failed:
    broken_file.unlink()
# Show how many files were removed.
len(failed)

In [None]:
path = Path('data/bird_or_not')

# Describe how the image folders become the training set and the validation set.
bird_or_not_block = DataBlock(
    # Input is images, output is categories (bird / forest).
    blocks=(ImageBlock, CategoryBlock),
    # Collect the image files found under `path`.
    get_items=get_image_files,
    # Hold out 20% of the items for validation; the fixed seed keeps the split repeatable.
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    # Label each item with the name of its parent folder.
    get_y=parent_label,
    # Resize each image to 192 by squishing (not cropping) before training.
    item_tfms=[Resize(192, method='squish')],
)

# Build the loaders; see https://docs.fast.ai/data.block.html#datablock.dataloaders
dataloaders = bird_or_not_block.dataloaders(path, bs=32, verbose=True)

In [None]:
# Preview up to six items from the prepared data as a sanity check before training.
dataloaders.show_batch(max_n=6)

In [None]:
# Train and tune our model.
# Build a learner on the resnet18 architecture (a widely used, fast computer-vision model) and track the error rate.
learn = vision_learner(dataloaders, resnet18, metrics=error_rate)
# Fine-tune for 3 epochs; fastai applies its best practices for fine-tuning a pre-trained model.
learn.fine_tune(3)

In [None]:
# Use our model by passing it the first picture that we downloaded.
category,_,probs = learn.predict(PILImage.create('data/bird.jpg'))

# Find the position of the 'bird' class in the model's vocabulary rather than
# assuming it is index 0; the order depends on which category names are used.
bird_idx = learn.dls.vocab.o2i['bird']

print(f"This is a: {category}.")
print(f"Probability it's a bird: {probs[bird_idx]:.4f}")

In [None]:
# Show the full probability output returned by the prediction above.
print(probs)