In [None]:
import tensorflow as tf
from tensorflow import keras
import pandas as pd
import numpy as np
import gc
import matplotlib.pyplot as plt

In [None]:
# Load the card metadata table; later cells read its `card_file_name` and
# `type` columns.
df = pd.read_csv("data/metadata.csv")

In [None]:
# Preview only the first rows instead of rendering the whole metadata table.
df.head()

# Drop a couple of bad records

In [None]:
# Card files excluded from the dataset, with the reason where one is known.
bad_card_files = [
    "data/pictures/Mysterious Treasures/Honchkrow_(Mysterious_Treasures_10)",
    "data/pictures/Unified Minds/Umbreon_%26_Darkrai-GX_(Unified_Minds_125)",  # missing on website
    "data/pictures/Base Set/Charizard_(Base_Set_4)",  # special version of normal website
    "data/pictures/Stormfront/Charizard_(Stormfront_103)",  # special version of normal website
]

# Keep every row whose file name is not in the exclusion list.
is_bad_card = df["card_file_name"].isin(bad_card_files)
df = df[~is_bad_card]

# Filter to Pokemon types

In [None]:
# df.type.value_counts()
# Keep only rows whose `type` is one of these eleven values; every other
# value of the column is dropped.
pokemon_types = [
    "Water",
    "Grass",
    "Colorless",
    "Psychic",
    "Fighting",
    "Fire",
    "Lightning",
    "Darkness",
    "Metal",
    "Dragon",
    "Fairy",
]
is_pokemon_type = df["type"].isin(pokemon_types)
df = df[is_pokemon_type]

# Create dataset

In [None]:
def parse_image_file(filename, label = None):
    """Read one image file and turn it into a fixed-size float image tensor.

    The file is decoded as a 3-channel JPEG, cast to float32 and divided by
    255, then centrally cropped and/or zero-padded (not rescaled) to
    260 x 180 (height x width).

    Args:
        filename: Path of the image file to read.
        label: Optional value passed through unchanged, so the function can be
            mapped over `(filename, label)` pairs.

    Returns:
        The image tensor alone when `label` is None, otherwise the tuple
        `(image, label)`.
    """
    raw_bytes = tf.io.read_file(filename)
    decoded = tf.image.decode_jpeg(raw_bytes, channels=3)
    scaled = tf.cast(decoded, tf.float32) / 255.0

    # Crop or pad to the target height (260) and width (180).
    fitted = tf.image.resize_with_crop_or_pad(scaled, 260, 180)

    # Keep only the first three channels, discarding any extra (alpha) channel.
    rgb = fitted[:, :, :3]

    if label is not None:
        return rgb, label
    return rgb

In [None]:
# if 'x' in locals():
#     del x
# x = tf.stack([
#     parse_image_file(f) for f in df["card_file_name"].values
# ], axis=0)



In [None]:
# `classnames` holds the sorted distinct type names; `indices` gives, for each
# row of `df`, the position of its type inside `classnames`.
type_values = df["type"].values
classnames, indices = np.unique(type_values, return_inverse=True)

# One-hot encode the integer class indices.
y = keras.utils.to_categorical(indices)

In [None]:
# Width of the one-hot label matrix and number of metadata rows.
n_classes, n_records = y.shape[-1], df.shape[0]

In [None]:
# all_data = tf.data.Dataset.from_tensor_slices(
#   (x, y)
# ).shuffle(10000)

# del x, y

In [None]:
# File paths only; the images themselves are read and decoded by `map` below.
x = tf.constant(df["card_file_name"].values)

# Shuffle the (path, label) pairs BEFORE decoding so the shuffle buffer holds
# short strings rather than decoded float images, and size the buffer to the
# whole table so the shuffle is uniform.
#
# `reshuffle_each_iteration=False` (with a fixed seed) freezes the order across
# iterations. This matters because the train/validate/test split is built with
# take()/skip() on this dataset: with the default per-iteration reshuffle, each
# pass would draw a different split and records would move between the
# training, validation and test sets from one epoch to the next.
all_data = (
    tf.data.Dataset.from_tensor_slices((x, y))
    .shuffle(buffer_size=len(df), seed=0, reshuffle_each_iteration=False)
    .map(parse_image_file)
)

# Split dataset

In [None]:
train_frac = .7
batch_size = 50

# Record counts for the training and validation splits; the test split
# receives whatever remains after those two.
n_train = int(train_frac * n_records)
n_validate = int(.5 * (1 - train_frac) * n_records)

# Take the training records off the front, then split the rest.
train = all_data.take(n_train).batch(batch_size)
remaining_data = all_data.skip(n_train)

validate = remaining_data.take(n_validate).batch(batch_size)
test = remaining_data.skip(n_validate).batch(batch_size)

In [None]:
# Print the labels of one test batch as a sanity check.
# The loop variables are deliberately not called `x` / `y`: those names hold
# the notebook-level file paths and one-hot labels, and overwriting them here
# would break a later re-run of the cell that builds `all_data`.
for image_batch, label_batch in test.take(1):
    print(label_batch)

# Model

In [None]:
# Small CNN classifier: two strided convolution + max-pool stages followed by
# a dense head that ends in a softmax over `n_classes` outputs.
model = keras.Sequential()

# Convolutional stages.
for _stage in range(2):
    model.add(keras.layers.Conv2D(filters=10, kernel_size=5, strides=2))
    model.add(keras.layers.MaxPool2D())

# Dense head.
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(units=30, activation="relu"))
model.add(keras.layers.Dense(units=30, activation="relu"))
model.add(keras.layers.Dense(units=n_classes, activation="softmax"))

In [None]:
# Training configuration: Nadam optimizer, categorical cross-entropy loss
# (labels are one-hot), and categorical accuracy as the reported metric.
optimizer = keras.optimizers.Nadam(learning_rate=.001)
loss_fn = keras.losses.CategoricalCrossentropy()

model.compile(
    optimizer=optimizer,
    loss=loss_fn,
    metrics=[keras.metrics.categorical_accuracy],
)

In [None]:
try:
    # `shuffle` is not passed: Keras ignores that argument when the input is a
    # tf.data.Dataset, so the order of batches is decided by the dataset
    # pipeline alone. The returned History object is kept so the per-epoch
    # loss and metrics can be inspected or plotted afterwards.
    history = model.fit(train, validation_data=validate, epochs=5)
finally:
    # Run a garbage-collection pass whether training finishes, fails or is
    # interrupted.
    gc.collect()

In [None]:
# Explicit garbage-collection pass; as the last expression of the cell, the
# value returned by gc.collect() is shown as the cell output.
gc.collect()

# Test a couple of records

In [None]:
def label_converter(label):
    """Return the index of the largest entry along the last axis of `label`.

    Args:
        label: Tensor-like whose last axis holds per-class scores, such as the
            model output used in `plot_and_predict`.

    Returns:
        A tensor of indices with the last axis reduced away.
    """
    class_index = tf.argmax(label, axis=-1)
    return class_index

In [None]:
def plot_and_predict(img):
    """Draw a single image with the model's predicted class name as its title.

    Uses the notebook-level `model`, `classnames` and `label_converter`.

    Args:
        img: A batch holding exactly one image (leading dimension of size 1).

    Raises:
        AssertionError: If the leading dimension of `img` is not 1.
    """
    batch_len = img.shape[0]
    assert batch_len == 1, "Please only provide a single image at a time"

    # Model output -> index of the highest-scoring class -> class name.
    predicted_index = label_converter(model.predict(img))[0]
    predicted_name = classnames[predicted_index]

    plt.imshow(img[0])
    plt.title(predicted_name)
    

In [None]:
# Take one batch from `train` and show every image in it with its predicted
# class, one figure per image.
for vis_x, _ in train.take(1):
    n_images = vis_x.shape[0]
    for start in range(n_images):
        # Slice rather than index so the leading batch dimension of size 1 is kept.
        single_image = vis_x[start:start + 1]
        plot_and_predict(single_image)
        plt.show()