In [9]:
import ast
import html
import io
import json
import os
import random

import ipywidgets as widgets
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import requests
import tensorflow as tf
from ipywidgets import VBox, Layout, Box
from PIL import Image

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # only display error messages
# NOTE(review): TensorFlow's C++ logger presumably reads this variable when it
# first logs, so setting it *after* `import tensorflow` may not silence messages
# emitted during the import itself — consider moving it above that import.

# Styles of modern art used for classification. get_predictions zips these with
# the model's output vector, so the order presumably has to match the model's
# output units — do not reorder.
class_names = (
    'Abstract Art',
    'Abstract Expressionism',
    'Cubism',
    'Expressionism',
    'Naïve Art (Primitivism)',
    'Op Art',
    'Pop Art',
    'Street art',
    'Suprematism',
    'Surrealism',
)

# images are resized to this square dimension before being fed to the model
pixels = 224

# load previously trained model
model = tf.keras.models.load_model("./art-classifier-model/art-classification-multi-class-mobilenet-v3-large-sigmoid-0.1dropout-softmax")

# Example artworks as a list of (style name, list of artist dicts) pairs; the
# style names are matched against class_names when showing examples.
# JSON is UTF-8 by specification, so decode explicitly. Otherwise non-ASCII keys
# such as 'Naïve Art (Primitivism)' would depend on the platform's default
# encoding and could fail to match.
with open('exampleArtists.json', encoding='utf-8') as artists_file:
    exampleArtists = list(json.load(artists_file).items())

In [28]:
def get_predictions(img_array):
    """Run the model on a single-image batch and rank every art style.

    Args:
        img_array: a batch holding exactly one preprocessed image. The
            ``[rank] = ...`` unpacking fails for any other batch size.

    Returns:
        A list of ``(class_name, score)`` tuples, one per entry of
        ``class_names``, sorted from highest to lowest score.
    """
    [rank] = model.predict(img_array)
    # zip pairs each style name with its score (stopping at the shorter input).
    # sorted() already returns a fresh list, so no extra list()/slice is needed.
    return sorted(zip(class_names, rank), key=lambda pair: pair[1], reverse=True)

def format_predictions(score, artist, top_n=3):
    """Render the top-ranked styles plus a caption as an HTML fragment.

    Args:
        score: ``(class_name, probability)`` tuples sorted best-first, as
            produced by the prediction code. Each probability is multiplied by
            100 and shown as a whole percentage.
        artist: caption shown under the predictions. display_image passes the
            uploaded file name, which is user-controlled, so it is HTML-escaped.
        top_n: how many leading entries of ``score`` to show. The default of 3
            matches the original behaviour; fewer entries are shown if
            ``score`` is shorter.

    Returns:
        An HTML string with one bold paragraph per style, followed by a
        paragraph holding the caption.
    """
    # Each <p> is closed and wraps its own <b>. The old markup opened <b>
    # around several unclosed <p> tags, which is invalid HTML.
    parts = [
        "<p><b>{} {:.0f}%</b></p>".format(html.escape(str(name)), 100 * prob)
        for name, prob in score[:top_n]
    ]
    parts.append("<p>{}</p>".format(html.escape(str(artist))))
    return "".join(parts)

# Two output areas: the capture-decorated callbacks below render into them.
# The upload control accepts a single JPEG file and triggers those callbacks.
predictions, examples = widgets.Output(), widgets.Output()
uploader = widgets.FileUpload(multiple=False, accept='image/jpeg')

def _read_upload(value):
    """Return ``(filename, raw_bytes)`` for the first uploaded file, or None if empty.

    Accepts both FileUpload value layouts:
    - ipywidgets 7: a dict mapping file name -> {"content": bytes, ...}.
    - ipywidgets 8 (per its migration guide): a tuple of dicts with "name" and
      "content" keys.
    """
    if not value:
        return None
    if isinstance(value, dict):
        name = next(iter(value))
        return name, value[name]["content"]
    first = value[0]
    return first["name"], first["content"]


def _to_model_input(rgb_image):
    """Resize an RGB PIL image and turn it into a normalised single-image batch."""
    # Model expects specific model dimensions
    resized = rgb_image.resize((pixels, pixels))
    batch = tf.expand_dims(tf.keras.utils.img_to_array(resized), 0)  # create a batch
    return batch / 255  # normalize floats between 0 and 1


def _donut_figure(score, top_n=3):
    """Build a donut chart of the top ``top_n`` styles plus an "Other" slice."""
    top = score[:top_n]
    labels = [name for name, _ in top] + ["Other"]
    values = [prob for _, prob in top]
    # The remainder makes the slices add up to 1. Clamp it so float rounding,
    # or scores that don't sum to 1, can never produce a negative slice.
    values.append(max(0.0, 1.0 - float(sum(values))))
    fig = go.Figure(data=[go.Pie(labels=labels, values=values, textinfo='label+percent')])
    fig.update_traces(hole=.5)
    fig.update_layout(
        width=375,
        autosize=True,
        showlegend=False,
        annotations=[dict(text='Predictions', font_size=20, showarrow=False)])
    return fig


@predictions.capture(clear_output=True)
def display_image():
    """Classify the uploaded image and render the results.

    Renders, into the ``predictions`` output widget (cleared on each call):
    - the top styles with the file name,
    - the image itself,
    - a donut chart of the predictions.
    It then calls ``get_examples``, which renders into its own output widget.
    Does nothing if ``uploader`` holds no file.
    """
    upload = _read_upload(uploader.value)
    if upload is None:
        return
    filename, content = upload

    # Grayscale or CMYK JPEGs would otherwise give 1- or 4-channel arrays. The
    # model (a MobileNetV3, per its path) presumably expects 3 channels.
    # Showing the RGB version also stops grayscale images being drawn with a
    # false-colour colormap.
    rgb = Image.open(io.BytesIO(content)).convert('RGB')

    score = get_predictions(_to_model_input(rgb))

    display(
        widgets.HTML(value="<h2>Art Influences</h2>"),
        widgets.HTML(format_predictions(score, filename))
    )

    plt.imshow(rgb)
    plt.axis('off')
    plt.show()

    # Print Prediction Visualization, Pie-chart-donut
    _donut_figure(score).show()

    get_examples(score)


@examples.capture(clear_output=True)
def get_examples(score, top_n=3, timeout=10):
    """Show one randomly chosen WikiArt example for each top-ranked style.

    Args:
        score: ``(class_name, probability)`` tuples sorted best-first, as
            passed by display_image. Only the first ``top_n`` are used.
        top_n: number of styles to show examples for (default 3).
        timeout: seconds to wait for each example image download.

    Output goes to the ``examples`` widget, which is cleared first.
    """
    display(widgets.HTML("<h2>Examples</h2>"))

    # exampleArtists is a list of (style, artists) pairs. As with a linear scan
    # that keeps the last match, a later duplicate style wins.
    artists_by_style = dict(exampleArtists)

    for style, _ in score[:top_n]:
        candidates = artists_by_style.get(style)
        if not candidates:
            # random.choice on an empty container would raise; report it instead.
            display(widgets.HTML(
                "<p>No examples available for <b>{}</b>.</p>".format(html.escape(str(style)))))
            continue
        artist = random.choice(candidates)

        # Quote the href and escape every field. Otherwise a link with spaces,
        # or a title/name containing quotes, '<' or '&', would break the markup.
        caption = widgets.HTML(
            value="""
            <a href="{}" target='_blank'>
            <b>{}</b><br>
            <i>"{}"</i> by {} @ WikiArt.org</a>""".format(
                html.escape(str(artist["link"])),
                html.escape(str(style)),
                html.escape(str(artist["title"])),
                html.escape(str(artist["artistName"])),
            ))
        display(caption)

        # Without a timeout requests can wait indefinitely on a stalled server.
        response = requests.get(artist["imageUrl"], timeout=timeout)
        # Surface HTTP errors directly rather than as an opaque PIL decode error.
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
        plt.figure(figsize=(10, 10))
        plt.imshow(image, interpolation='nearest')
        plt.axis('off')
        plt.show()
    

def on_file_uploaded(change):
    """Observer for ``uploader.value``: classify the newly uploaded image.

    Args:
        change: the traitlets change notification. ``change['new']`` is the new
            value of ``uploader.value``.
    """
    # The observer also fires if the value is ever cleared, and there is
    # nothing to classify then.
    if not change['new']:
        return
    display_image()
    # NOTE(review): `_counter` is a private FileUpload trait. It presumably
    # resets the upload count shown on the button (ipywidgets 7). On versions
    # without it, this is a plain attribute assignment — confirm.
    uploader._counter = 0

# Re-run the classifier each time the upload widget's value changes.
uploader.observe(on_file_uploaded, names='value')

# Modern Art Style Classifier

The influences of modern art are all around us. Upload a JPEG image and this machine learning model will discover which modern art styles influence it.

_Uploaded images are discarded after each use. Source code is available on [GitHub](https://github.com/todgru/art-classifier-ml-deploy). Deployed with [MyBinder.org](https://mybinder.org)._


In [29]:
# Flex container: the two output panes sit side by side and wrap onto
# separate rows when the notebook is too narrow.
box_layout = Layout(
    width='100%',
    display='flex',
    flex_flow='wrap',
)

# Last expression is rendered: upload control on top, results underneath.
VBox([
    uploader,
    Box(layout=box_layout, children=[predictions, examples]),
])

VBox(children=(FileUpload(value={}, accept='image/jpeg', description='Upload'), Box(children=(Output(), Output…