In [65]:
%matplotlib inline
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
import os
import json
import datetime as dt
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [20, 16]
plt.rcParams['font.size'] = 10
import seaborn as sns
import cv2
import pandas as pd
import numpy as np

from PIL import Image, ImageDraw 
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.layers import Dense, Dropout, Flatten, Activation
from tensorflow.keras.metrics import categorical_accuracy, top_k_categorical_accuracy, categorical_crossentropy
from tensorflow.keras.models import Sequential
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input

In [2]:
# Directory holding the train_k{k}.csv shards read by image_generator_xd.
DP_DIR = '../input/googlequickdrawzips'

# Class names. sorted() is case-sensitive, so the capitalised 'The ...' entries
# come first. A class's integer label is its position in this sorted list.
CATEGORIES = sorted(['airplane', 'alarm clock', 'ambulance', 'angel', 'animal migration', 'ant', 'anvil', 'apple', 'arm', 'asparagus', 'axe', 'backpack', 'banana', 'bandage', 'barn', 'baseball', 'baseball bat', 'basket', 'basketball', 'bat', 'bathtub', 'beach', 'bear', 'beard', 'bed', 'bee', 'belt', 'bench', 'bicycle', 'binoculars', 'bird', 'birthday cake', 'blackberry', 'blueberry', 'book', 'boomerang', 'bottlecap', 'bowtie', 'bracelet', 'brain', 'bread', 'bridge', 'broccoli', 'broom', 'bucket', 'bulldozer', 'bus', 'bush', 'butterfly', 'cactus', 'cake', 'calculator', 'calendar', 'camel', 'camera', 'camouflage', 'campfire', 'candle', 'cannon', 'canoe', 'car', 'carrot', 'castle', 'cat', 'ceiling fan', 'cell phone', 'cello', 'chair', 'chandelier', 'church', 'circle', 'clarinet', 'clock', 'cloud', 'coffee cup', 'compass', 'computer', 'cookie', 'cooler', 'couch', 'cow', 'crab', 'crayon', 'crocodile', 'crown', 'cruise ship', 'cup', 'diamond', 'dishwasher', 'diving board', 'dog', 'dolphin', 'donut', 'door', 'dragon', 'dresser', 'drill', 'drums', 'duck', 'dumbbell', 'ear', 'elbow', 'elephant', 'envelope', 'eraser', 'eye', 'eyeglasses', 'face', 'fan', 'feather', 'fence', 'finger', 'fire hydrant', 'fireplace', 'firetruck', 'fish', 'flamingo', 'flashlight', 'flip flops', 'floor lamp', 'flower', 'flying saucer', 'foot', 'fork', 'frog', 'frying pan', 'garden', 'garden hose', 'giraffe', 'goatee', 'golf club', 'grapes', 'grass', 'guitar', 'hamburger', 'hammer', 'hand', 'harp', 'hat', 'headphones', 'hedgehog', 'helicopter', 'helmet', 'hexagon', 'hockey puck', 'hockey stick', 'horse', 'hospital', 'hot air balloon', 'hot dog', 'hot tub', 'hourglass', 'house', 'house plant', 'hurricane', 'ice cream', 'jacket', 'jail', 'kangaroo', 'key', 'keyboard', 'knee', 'ladder', 'lantern', 'laptop', 'leaf', 'leg', 'light bulb', 'lighthouse', 'lightning', 'line', 'lion', 'lipstick', 'lobster', 'lollipop', 'mailbox', 'map', 'marker', 'matches', 'megaphone', 'mermaid', 'microphone', 'microwave', 
'monkey', 'moon', 'mosquito', 'motorbike', 'mountain', 'mouse', 'moustache', 'mouth', 'mug', 'mushroom', 'nail', 'necklace', 'nose', 'ocean', 'octagon', 'octopus', 'onion', 'oven', 'owl', 'paint can', 'paintbrush', 'palm tree', 'panda', 'pants', 'paper clip', 'parachute', 'parrot', 'passport', 'peanut', 'pear', 'peas', 'pencil', 'penguin', 'piano', 'pickup truck', 'picture frame', 'pig', 'pillow', 'pineapple', 'pizza', 'pliers', 'police car', 'pond', 'pool', 'popsicle', 'postcard', 'potato', 'power outlet', 'purse', 'rabbit', 'raccoon', 'radio', 'rain', 'rainbow', 'rake', 'remote control', 'rhinoceros', 'river', 'roller coaster', 'rollerskates', 'sailboat', 'sandwich', 'saw', 'saxophone', 'school bus', 'scissors', 'scorpion', 'screwdriver', 'sea turtle', 'see saw', 'shark', 'sheep', 'shoe', 'shorts', 'shovel', 'sink', 'skateboard', 'skull', 'skyscraper', 'sleeping bag', 'smiley face', 'snail', 'snake', 'snorkel', 'snowflake', 'snowman', 'soccer ball', 'sock', 'speedboat', 'spider', 'spoon', 'spreadsheet', 'square', 'squiggle', 'squirrel', 'stairs', 'star', 'steak', 'stereo', 'stethoscope', 'stitches', 'stop sign', 'stove', 'strawberry', 'streetlight', 'string bean', 'submarine', 'suitcase', 'sun', 'swan', 'sweater', 'swing set', 'sword', 't-shirt', 'table', 'teapot', 'teddy-bear', 'telephone', 'television', 'tennis racquet', 'tent', 'The Eiffel Tower', 'The Great Wall of China', 'The Mona Lisa', 'tiger', 'toaster', 'toe', 'toilet', 'tooth', 'toothbrush', 'toothpaste', 'tornado', 'tractor', 'traffic light', 'train', 'tree', 'triangle', 'trombone', 'truck', 'trumpet', 'umbrella', 'underwear', 'van', 'vase', 'violin', 'washing machine', 'watermelon', 'waterslide', 'whale', 'wheel', 'windmill', 'wine bottle', 'wine glass', 'wristwatch', 'yoga', 'zebra', 'zigzag'])
BASE_SIZE = 128  # NOTE(review): not referenced by any cell shown in this notebook.
NUM_CSVS = 100   # number of train_k{k}.csv shards, presumably k = 0 .. NUM_CSVS-1 — confirm against DP_DIR
NUM_CLASSES = 340  # width of the model's output layer and of the one-hot labels
# NOTE(review): no RNG seed is set, so shard order (np.random.permutation) differs per run.
# np.random.seed(seed=1123)

# Labels are one-hot encoded with NUM_CLASSES columns, so the class list must match it exactly.
assert len(CATEGORIES) == NUM_CLASSES, 'CATEGORIES and NUM_CLASSES disagree'

# Lower-cased class name -> integer label (position in CATEGORIES).
WORD_INDICES = {cat.lower(): idx for idx, cat in enumerate(CATEGORIES)}
# Lower-casing must not merge two distinct categories into one key.
assert len(WORD_INDICES) == len(CATEGORIES), 'duplicate category names after lower-casing'

SIZE = 128  # model input side length; the original author noted this size gave >70% accuracy

In [3]:
def get_numerical_label(word):
    """Return the integer class label for a category name, ignoring case.

    Unknown names are never guessed: the WORD_INDICES lookup raises KeyError.
    """
    key = word.lower()
    return WORD_INDICES[key]

def top_3_accuracy(y_true, y_pred):
    """Keras metric: top-3 categorical accuracy (true class among the 3 highest scores)."""
    return top_k_categorical_accuracy(y_true, y_pred, k=3)

In [4]:
# MobileNet trained from scratch (no pretrained weights) on single-channel SIZE x SIZE images.
model = MobileNet(
    input_shape=(SIZE, SIZE, 1),
    alpha=0.7,  # width multiplier; the original author noted 0.8 gives > 70%
    weights=None,
    classes=NUM_CLASSES,
    # NOTE(review): Keras documents `pooling` as applying only when include_top=False.
    # include_top is left at its default (True) here, so this argument has no effect.
    pooling='max',
)
# categorical_crossentropy is also tracked as a metric so History exposes it
# (with a val_ twin) for the loss plot further down.
model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy',
              metrics=[categorical_crossentropy, categorical_accuracy])
# To also track top-3 accuracy, add top_3_accuracy to `metrics` (adds new History keys).

# summary() prints the table itself and returns None; wrapping it in print()
# would emit a stray "None" line.
model.summary()

In [5]:
def _render_drawing(drawing, out_size, canvas_size=256, line_width=5):
    """Rasterise one decoded drawing into an (out_size, out_size, 1) float array.

    `drawing` is a list of strokes; each stroke is indexed as stroke[0] (x
    coordinates) and stroke[1] (y coordinates). Consecutive points are joined
    with lines of value 0 on a canvas filled with 255, the canvas is resized to
    out_size x out_size, and pixel values are divided by 255.
    """
    image = Image.new("P", (canvas_size, canvas_size), color=255)
    image_draw = ImageDraw.Draw(image)
    for stroke in drawing:
        xs, ys = stroke[0], stroke[1]
        for j in range(len(xs) - 1):
            image_draw.line([xs[j], ys[j], xs[j + 1], ys[j + 1]],
                            fill=0, width=line_width)
    image = image.resize((out_size, out_size))
    pixels = np.array(image) / 255.0
    return pixels.reshape((out_size, out_size, 1))


def image_generator_xd(size, batchsize, ks):
    """Yield (images, one-hot labels) batches forever from the train_k{k}.csv shards.

    Each pass over the data visits the shard ids in `ks` in a fresh random order
    and reads each shard in chunks of `batchsize` rows. The CSV's `drawing`
    column holds a JSON stroke list and its `word` column the category name.

    Args:
        size: NOTE(review): currently ignored. Images are always rendered at the
            module-level SIZE, which matches the model's input layer. Honouring
            `size` would break any caller that passes a value other than SIZE.
        batchsize: rows per yielded batch (the last chunk of a shard may be smaller).
        ks: iterable of shard ids to read.

    Yields:
        (x, y) where x has shape (rows, SIZE, SIZE, 1) and y is one-hot encoded
        with NUM_CLASSES columns.

    Raises:
        KeyError: (from get_numerical_label) if a `word` value is not a known category.
    """
    while True:
        for k in np.random.permutation(ks):
            filename = os.path.join(DP_DIR, 'train_k{}.csv'.format(k))
            for df in pd.read_csv(filename, chunksize=batchsize):
                drawings = df['drawing'].apply(json.loads)
                x = np.array([_render_drawing(drawing, SIZE) for drawing in drawings.values])
                y_num = list(map(get_numerical_label, df.word.values))
                y = keras.utils.to_categorical(y_num, NUM_CLASSES)
                yield (x, y)

In [None]:
# df = pd.read_csv(os.path.join(DP_DIR, 'train_k{}.csv'.format(0)))
# list(map(get_numerical_label, df.word.values))

### Test the Generator

### Training

In [None]:
# Batches per epoch, number of epochs, and rows per batch for training/validation.
STEPS, EPOCHS, BATCH_SIZE = 100, 600, 256

In [None]:
# Training reads shards 0 .. NUM_CSVS-2; the last shard is held out for validation
# so val_* metrics measure generalisation instead of re-scoring training data.
# Assumes DP_DIR contains train_k{NUM_CSVS - 1}.csv — TODO confirm.
train_datagen = image_generator_xd(size=SIZE, batchsize=BATCH_SIZE, ks=range(NUM_CSVS - 1))
val_datagen = image_generator_xd(size=SIZE, batchsize=BATCH_SIZE, ks=[NUM_CSVS - 1])

In [None]:
# Both callbacks watch *validation* accuracy, so the LR schedule and the saved
# "best" checkpoint reflect generalisation rather than fit to the training batches.
callbacks = [
    ReduceLROnPlateau(monitor='val_categorical_accuracy', factor=0.75, patience=3, min_delta=0.001,
                      mode='max', min_lr=1e-5, verbose=1),
    # Only weights are saved, so reloading requires building the same architecture first.
    ModelCheckpoint('model_2.h5', monitor='val_categorical_accuracy', mode='max', save_best_only=True,
                    save_weights_only=True),
]

# One History object per fit() call; the plotting cell concatenates them.
hists = []
hist = model.fit(
    train_datagen, steps_per_epoch=STEPS,
    epochs=EPOCHS, verbose=1,
    validation_data=val_datagen, validation_steps=STEPS,
    callbacks=callbacks,
)
hists.append(hist)

In [None]:
# print(hist.df)

In [None]:
# Stack every fit() call's History into one frame indexed by 1-based epoch number.
hist_df = pd.concat([pd.DataFrame(h.history) for h in hists], sort=True)
hist_df.index = np.arange(1, len(hist_df) + 1)

fig, axs = plt.subplots(nrows=2, sharex=True, figsize=(16, 10))
# (axis, History metric key, label). Plotting inside a loop also keeps
# ast_node_interactivity="all" from echoing a repr for every Axes call.
panels = [
    (axs[0], 'categorical_accuracy', 'Accuracy'),
    (axs[1], 'categorical_crossentropy', 'Loss'),
]
for ax, metric, name in panels:
    ax.plot(hist_df['val_' + metric], lw=5, label='Validation ' + name)
    ax.plot(hist_df[metric], lw=5, label='Training ' + name)
    ax.set_ylabel(name)
    ax.set_xlabel('Epoch')
    ax.grid()
    ax.legend(loc=0)
fig.savefig('hist.png', dpi=300)
plt.show();

## Loading trained weights

In [6]:
# Restore trained weights instead of retraining. The file name matches the
# ModelCheckpoint target used during training; presumably that checkpoint,
# re-uploaded as an input dataset.
weights_path = '../input/google-quick-draw-challenge-compressed-ds/model_2.h5'
model.load_weights(weights_path)

### Testing

In [7]:
# Evaluate on the last shard, which the training generator (shards 0 .. NUM_CSVS-2)
# never reads. `size` is SIZE to match the model's input shape.
test_datagen = image_generator_xd(size=SIZE, batchsize=256, ks=[NUM_CSVS - 1])

In [8]:
# Pull a single batch from the test generator and score it with the model.
x_test, Y_test = next(test_datagen)
Y_pred = model.predict(x_test)

In [9]:
# print(np.argmax(Y_test, axis=1))
# print(np.argmax(Y_pred, axis=1))

In [10]:
# Collapse one-hot targets and score vectors to class indices, then count matches.
Y_test_arg = np.argmax(Y_test, axis=1)
Y_pred_arg = np.argmax(Y_pred, axis=1)
correct = sum(1 for a, b in zip(Y_test_arg, Y_pred_arg) if a == b)
total = len(Y_test_arg)

In [11]:
# Raw counts first, then accuracy as a percentage of the batch.
print('Correct:', correct)
print('Total:', total)
accuracy_pct = correct/total*100
print('-- Accuracy:', accuracy_pct)

In [None]:
# Visually spot-check the first 50 test samples, in random order: true vs. predicted class.
for sample in np.random.permutation(50):
    true_name = CATEGORIES[Y_test_arg[sample]]
    pred_name = CATEGORIES[Y_pred_arg[sample]]
    plt.title('Ans={}, Pred={}'.format(true_name, pred_name))
    plt.imshow(x_test[sample], cmap='gray')
    plt.show()

## My Data: evaluating on my own hand-drawn images

In [46]:
def load_data_from_dir(dir_path, show_img=True, size=None):
    """Load labelled images from `dir_path` as model-ready arrays plus one-hot labels.

    Each file name must start with its category followed by '_' (e.g. ``cat_1.png``);
    the text before the first '_' is the label. Images are read as grayscale,
    resized to `size` x `size`, and divided by 255. Files are processed in sorted
    name order because os.listdir order is arbitrary.

    Args:
        dir_path: directory that contains only such image files.
        show_img: if True, display each image as it is loaded.
        size: output side length; defaults to the module-level SIZE (the model input size).

    Returns:
        (x, y): x has shape (n, size, size, 1); y is one-hot with NUM_CLASSES columns.

    Raises:
        ValueError: if a file name has no '_' or a file cannot be decoded as an image.
        KeyError: (from get_numerical_label) if a label is not a known category.
    """
    if size is None:
        size = SIZE
    images = []
    labels = []
    for img_name in sorted(os.listdir(dir_path)):
        label, sep, _ = img_name.partition('_')
        if not sep:
            raise ValueError("cannot derive a label from {!r}: expected '<category>_...'".format(img_name))
        full_img_path = os.path.join(dir_path, img_name)
        image = cv2.imread(full_img_path, cv2.IMREAD_GRAYSCALE)
        # cv2.imread reports unreadable / non-image files by returning None rather than raising.
        if image is None:
            raise ValueError('could not read {!r} as an image'.format(full_img_path))
        image = cv2.resize(image, (size, size)) / 255.0

        if show_img:
            plt.imshow(image, cmap='gray')
            plt.show()

        images.append(image.reshape((size, size, 1)))
        labels.append(label)

    # reshape keeps the 4-D layout even when the directory is empty.
    x = np.array(images).reshape((-1, size, size, 1))
    y_num = [get_numerical_label(label) for label in labels]
    y = keras.utils.to_categorical(y_num, NUM_CLASSES)
    return (x, y)

In [None]:
# Own hand-drawn images, loaded without displaying each one.
x_my, y_my = load_data_from_dir('../input/googlequickdrawhandrawn', show_img=False)

In [48]:
# Shapes of the hand-drawn set. The original referenced undefined module-level
# `x`/`y` (they are locals of other functions) and failed on a fresh kernel.
print(x_my.shape, y_my.shape)

In [49]:
# Class scores for every hand-drawn image.
y_my_pred = model.predict(x=x_my)

In [68]:
def calc(x, y, y_pred):
    """Print accuracy of `y_pred` against `y`, then plot each image with its true/predicted class.

    Args:
        x: images indexed along axis 0; each x[i] is shown with a gray colormap.
        y: one-hot ground truth, shape (n, num_classes).
        y_pred: model scores, shape (n, num_classes); the argmax is the prediction.

    Raises:
        ValueError: if `x`, `y` and `y_pred` do not all hold the same number of samples.
    """
    y_arg = np.argmax(y, axis=1)
    y_pred_arg = np.argmax(y_pred, axis=1)
    if len(y_arg) != len(y_pred_arg):
        raise ValueError('y has {} samples but y_pred has {}'.format(len(y_arg), len(y_pred_arg)))
    if x.shape[0] != len(y_arg):
        raise ValueError('x has {} images but y has {} labels'.format(x.shape[0], len(y_arg)))

    correct = int(np.count_nonzero(y_arg == y_pred_arg))
    total = len(y_pred_arg)

    print('Correct:', correct)
    print('Total:', total)
    if total == 0:
        # Nothing to score or plot; avoids ZeroDivisionError and an empty subplot grid.
        return
    print('-- Accuracy:', correct/total*100)

    plt_cols = 4
    plt_rows = (total + plt_cols - 1) // plt_cols  # ceil(total / plt_cols)

    for idx in range(total):
        plt.subplot(plt_rows, plt_cols, idx + 1)
        plt.title('Ans={}, Pred={}'.format(CATEGORIES[y_arg[idx]], CATEGORIES[y_pred_arg[idx]]))
        plt.axis('off')
        # Drop a trailing channel axis of size 1 so imshow gets a 2-D array.
        plt.imshow(np.squeeze(x[idx]), cmap='gray')

    plt.show()

In [69]:
# Score and visualise the model on the hand-drawn set.
calc(x=x_my, y=y_my, y_pred=y_my_pred)