In [1]:
import os

# TensorFlow's native runtime reads TF_CPP_MIN_LOG_LEVEL when it initialises, so the
# variable must be set *before* `import tensorflow`; assigning it after the import
# may have no effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # only display error messages

import io
import json
import random

import ipywidgets as widgets
import matplotlib.pyplot as plt
import requests
import tensorflow as tf
from ipywidgets import HBox, VBox
from PIL import Image

# styles of modern art used for classification; class_names[i] is paired positionally
# with the model's i-th output score, so this order must match the model's output order
class_names = ('Abstract Art', 'Abstract Expressionism', 'Cubism', 'Expressionism', 'Naïve Art (Primitivism)', 'Op Art', 'Pop Art', 'Street art', 'Suprematism', 'Surrealism')

# images are resized to this square dimension (pixels x pixels) before prediction
pixels = 224

MODEL_PATH = "./art-classifier-model/art-classification-multi-class-mobilenet-v3-large-sigmoid-0.1dropout-softmax"
EXAMPLE_ARTISTS_PATH = 'exampleArtists.json'

# load previously trained model
model = tf.keras.models.load_model(MODEL_PATH)

# Example artworks per style. The JSON keys are compared against class_names, which
# contain non-ASCII text ('Naïve Art ...'), so decode explicitly as UTF-8 rather than
# relying on the platform default encoding (e.g. cp1252 on Windows).
with open(EXAMPLE_ARTISTS_PATH, encoding='utf-8') as f:
    # list of (style_name, artworks) pairs; each artwork is presumably a dict with
    # artistName / title / link / imageUrl keys (the fields the example viewer reads)
    exampleArtists = list(json.load(f).items())

Metal device set to: Apple M1



In [15]:
def get_predictions(img_array, top_k=None):
    """Run the classifier on a single-image batch and rank the art styles.

    Args:
        img_array: batch holding exactly one preprocessed image, as accepted by
            ``model.predict`` (the ``[scores]`` unpacking fails for any other
            batch size).
        top_k: optional number of highest-scoring styles to keep; ``None``
            (the default) keeps every style.

    Returns:
        List of ``(style_name, score)`` tuples sorted by score, highest first.
    """
    # One row of scores per input image; unpacking enforces a batch of one.
    [scores] = model.predict(img_array)
    # sorted() is stable, so tied scores keep class_names order.
    ranked = sorted(zip(class_names, scores), key=lambda pair: pair[1], reverse=True)
    # ranked[:None] is the whole list.
    return ranked[:top_k]

def format_predictions(score, artist, top_n=3):
    """Format the top-ranked styles and the uploaded file name as display text.

    Args:
        score: ``(style_name, probability)`` pairs sorted best-first; each
            probability is a fraction rendered as a whole percentage.
        artist: text shown on the trailing "File Name:" line (the upload
            handler passes the uploaded file's name).
        top_n: how many leading entries of ``score`` to include (default 3).
            Fewer lines are produced if ``score`` is shorter.

    Returns:
        A string that starts with a newline, then one "<style> <pct>%" line per
        entry, then a "File Name: ..." line.
    """
    lines = ["{} {:.0f}%".format(name, 100 * prob) for name, prob in score[:top_n]]
    lines.append("File Name: {}".format(artist))
    # Leading newline keeps the blank first line of the printed block.
    return "\n" + "\n".join(lines)

# Output areas: `out` (left pane) shows the uploaded image and its predicted
# styles; `examples` (right pane) shows example artworks for the top styles.
out = widgets.Output()
examples = widgets.Output()

# Single-file JPEG picker; its 'value' trait is observed to trigger classification.
uploader = widgets.FileUpload(multiple=False, accept='image/jpeg')

@out.capture(clear_output=True)
def display_image():
    """Classify the most recently uploaded image and render the result.

    Runs inside the ``out`` widget, whose previous output is cleared first.
    Prints the top predicted styles, shows the uploaded image, then refreshes
    the example panel. If the uploader holds no file, the output area is simply
    left cleared.
    """
    if not uploader.value:
        # Nothing uploaded (e.g. the widget value was cleared); indexing would raise.
        return

    # ipywidgets 7.x layout: {filename: {"content": bytes, ...}}; multiple=False
    # means there is a single entry.
    first_image_filename = next(iter(uploader.value))
    image_file_data = uploader.value[first_image_filename]["content"]
    image = Image.open(io.BytesIO(image_file_data))

    # Grayscale / CMYK / palette JPEGs would give 1 or 4 channels; the classifier
    # presumably expects 3-channel RGB (MobileNetV3 per the model path), and
    # imshow also needs RGB to render such images with correct colors.
    rgb_image = image.convert('RGB')

    # Model expects square inputs of the size it was trained on.
    resized_image = rgb_image.resize((pixels, pixels))

    img_array = tf.keras.utils.img_to_array(resized_image)
    img_array = tf.expand_dims(img_array, 0)  # Create a batch of one
    img_array = img_array / 255  # normalize floats between 0 and 1

    score = get_predictions(img_array)

    print(format_predictions(score, first_image_filename))

    plt.figure(figsize=(5, 5))
    plt.imshow(rgb_image)
    plt.axis('off')
    plt.show()

    get_examples(score)


@examples.capture(clear_output=True)
def get_examples(score):
    """Show one random example artwork for each of the top three predicted styles.

    Runs inside the ``examples`` widget, whose previous output is cleared first.
    For each style it picks a random entry from ``exampleArtists``, prints the
    artist and title, displays a "More info" link, and downloads and plots the
    image. A style with no examples, or an image that fails to download or
    decode, is reported instead of aborting the remaining examples.

    Args:
        score: ``(style_name, probability)`` pairs sorted best-first.
    """
    import html  # escape JSON-sourced text before inserting it into HTML

    # exampleArtists is a list of (style, artworks) pairs; dict() keeps the last
    # entry if a style repeats.
    artworks_by_style = dict(exampleArtists)

    for style, _ in score[:3]:
        candidates = artworks_by_style.get(style)
        if not candidates:
            # random.choice on an empty collection raises IndexError.
            print("No example art available for style: {}".format(style))
            continue
        artist = random.choice(candidates)

        print("Example art for style: {}".format(style))
        print("Artist: {}".format(artist["artistName"]))
        print("Title: {}".format(artist["title"]))
        # Quote the href and escape both values so URLs or names containing
        # spaces, quotes or '<' cannot break the markup.
        link = widgets.HTML(
            value="<a href=\"{}\" target='_blank'>More info about {}</a>".format(
                html.escape(artist["link"], quote=True),
                html.escape(artist["artistName"]),
            ))
        display(link)

        try:
            # Without a timeout a stalled server would block this callback indefinitely.
            response = requests.get(artist["imageUrl"], timeout=30)
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))
        except (requests.RequestException, OSError) as err:
            # OSError covers PIL's UnidentifiedImageError for non-image responses.
            print("Could not load example image: {}".format(err))
            continue

        plt.figure(figsize=(10, 10))
        plt.imshow(image, interpolation='nearest')
        plt.axis('off')
        plt.show()
    

def on_file_uploaded(change):
    """Observer for the uploader's ``value`` trait: classify the new upload.

    ``change`` is the traitlets change notification. It is not inspected,
    because display_image reads the uploader widget directly.
    """
    display_image()
    # Private ipywidgets 7.x attribute; presumably resets the upload count shown
    # on the button. TODO confirm this still exists when upgrading ipywidgets.
    uploader._counter = 0

# Run classification whenever the uploader's value changes (i.e. a file is uploaded).
uploader.observe(on_file_uploaded, names='value')

# Modern Art Style Classifier

The influences of modern art are all around us. Use this machine learning model to discover the influences of modern art in your image.

_Images are discarded after each use. Source code is available on [GitHub](https://github.com/todgru/art-classifier-ml-deploy). Deployed with [MyBinder.org](https://mybinder.org)._


In [16]:
from ipywidgets import Layout, Button, Box

def _column_layout(width, **extra):
    """Flex-column Layout with the given CSS width; extra Layout kwargs pass through."""
    return Layout(display='flex', flex_flow='column', align_items='unset', width=width, **extra)


# Outer container: full-size flex column stacking the upload row above the image row.
box_layout = _column_layout('100%', height='100%')

# Left pane (uploaded image + predictions) gets 30%, right pane (examples) 70%.
inner_left = _column_layout('30%')
inner_right = _column_layout('70%')

left_box = VBox([out])
right_box = VBox([examples])

top_row = HBox([uploader])
images_display_row = HBox([
    Box(children=[left_box], layout=inner_left),
    Box(children=[right_box], layout=inner_right),
])

items = [top_row, images_display_row]
box = Box(children=items, layout=box_layout)
box

Box(children=(HBox(children=(FileUpload(value={}, accept='image/jpeg', description='Upload'),)), HBox(children…