In [1]:
import os
import re #regular expressions
from tqdm import tqdm #progress bar
from shutil import copyfile #letting me copy files
import pathlib

#Tell the Google client libraries where the service account credentials live
#by setting the GOOGLE_APPLICATION_CREDENTIALS environment variable.
#FIXME: replace this with the path of the service account access credentials .json file you downloaded
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = "path/to/model_access_credentials.json"

#you also need to install the google automl client libraries: 
#> pip install google-cloud-automl
from google.cloud import automl_v1beta1

In [1]:
#Folders on your local computer, given relative to this notebook.
#FIXME: point both at the folders of your own project (keep the trailing "/")
img_folder = "./my_images_to_classify/"     #input: the images to classify
cat_img_folder = "./my_classified_images/"  #output: gets one sub-folder per predicted class

In [3]:
#A single prediction client, created once and shared by every request.
#Keep this outside of any per-image loop: opening one client per image
#breaks after a few tens of images.
prediction_client = automl_v1beta1.PredictionServiceClient()

#Default project and model ids, used as the default arguments of get_prediction.
#FIXME: replace both with your own project id and model id
def_project_id = "my_project"
def_model_id = "ABC01234567890123456789"

In [2]:
def get_prediction(content, project_id = def_project_id, model_id = def_model_id, prediction_client = prediction_client, location = 'us-central1'):
    '''Send the contents of one image file to a Google AutoML model.

    Parameters:
        content: the bytes read from an image file.
        project_id: Google Cloud project id (defaults to def_project_id).
        model_id: AutoML model id (defaults to def_model_id).
        prediction_client: client used to make the call (defaults to the shared
            module-level client, so no new client is opened per image).
        location: region part of the model resource name
            (defaults to 'us-central1', the value that used to be hard-coded).

    Returns:
        The prediction object returned by prediction_client.predict.
    '''

    name = 'projects/{}/locations/{}/models/{}'.format(project_id, location, model_id)
    payload = {'image': {'image_bytes': content}}
    params = {}
    # Pass everything by keyword: older releases of the client library take
    # (name, payload, params) positionally, newer ones take a request object
    # first, but both accept these three names as keyword arguments.
    response = prediction_client.predict(name=name, payload=payload, params=params)
    return response  # predict() blocks until the service has answered

In [3]:
def get_classification(prediction):
    '''Return the predicted category of a prediction object as a string.

    The payload of the prediction is rendered as text and the first
    display_name entry found in that text is returned. When the text holds
    no such entry, an empty string is returned instead.
    '''

    match = re.search('display_name: "(.+?)"\n', str(prediction.payload))
    if match is None:
        # the rendered payload has no 'display_name: "XXXXXXXXX"' line
        return ''
    return match.group(1)

In [4]:
def get_class_from_filepath(image_path):
    '''Predict the classification of the image file stored at image_path.

    The file is read in binary mode, its bytes are sent off for prediction,
    and the predicted category is returned as a string.

    Example:
    get_class_from_filepath('./subfolder/my_image.jpg')
    '''

    with open(image_path, 'rb') as image_file:
        image_bytes = image_file.read()

    return get_classification(get_prediction(content=image_bytes))

In [7]:
#Get all file names of JPEG images in the specified directory.
#os.listdir returns names in arbitrary order, so sort them to make runs reproducible.
contents_of_dir = sorted(os.listdir(img_folder))
#Compare the extension case-insensitively so that e.g. "IMG_001.JPG" and "photo.jpeg" are not skipped
jpgs_in_dir = [file for file in contents_of_dir if file.lower().endswith((".jpg", ".jpeg"))]

In [None]:
#Make a dictionary of predictions, keyed by file name: {file: [file, predicted class]}.
#This is where the time is spent, as every image needs its own prediction call.
img_classes = dict()
total_len = len(jpgs_in_dir)

for file in tqdm(jpgs_in_dir):
    #os.path.join rather than string concatenation, so that img_folder
    #works whether or not it ends with a path separator
    image_path = os.path.join(img_folder, file)
    img_classes[file] = [file, get_class_from_filepath(image_path)]

In [None]:
#Collect the predicted class (second element of every img_classes value),
#keeping each class only once by going through a set and back to a list
list_of_classes = list(set(entry[1] for entry in img_classes.values()))

#Create one folder per class, unless it is already there:
for sub_dir in list_of_classes:
    class_folder = pathlib.Path(cat_img_folder+sub_dir)
    class_folder.mkdir(parents=True, exist_ok=True)

In [109]:
#Copy every jpg into the folder named after its predicted class
for key, value in img_classes.items():
    #value is [file name, predicted class]; an empty class puts the file directly in cat_img_folder
    dest_dir = os.path.join(cat_img_folder, value[1])
    #Make sure the destination folder exists, so this cell also works when run on its own
    pathlib.Path(dest_dir).mkdir(parents=True, exist_ok=True)
    #os.path.join rather than string concatenation, so the folders work with or without a trailing separator
    source = os.path.join(img_folder, key)
    destin = os.path.join(dest_dir, key)
    copyfile(source, destin)