<a href="https://colab.research.google.com/github/dandrnic/git_test/blob/main/pkmn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw
import pathlib
import sys

import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' #silence some of the tf warnings

from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential


In [None]:

#Download and explore the dataset

def get_local_data(file_path, val_split=0.2, batch_size = 32, height = 180, width = 180):
	#returns training and validation datasets from a local dataset

	data_dir = pathlib.Path(file_path)
	print('Number of JPEG in dataset:',len(list(data_dir.glob('*/*.jpg')))) #print the total amount of images
	train_ds = tf.keras.preprocessing.image_dataset_from_directory(
		data_dir,
		validation_split = val_split, #100-val_split% used for training, val_split% for validation
		subset="training",
		seed=710,
		image_size=(height, width),
		batch_size=batch_size)

	val_ds = tf.keras.preprocessing.image_dataset_from_directory(
		data_dir,
		validation_split = val_split,
		subset="validation",
		seed=710,
		image_size=(height, width),
		batch_size=batch_size)

	class_names = train_ds.class_names
	print('Pokemon in set:', class_names[:3], ' ... ', class_names[-3:])

	return train_ds, val_ds

In [None]:
def configure_ds(train_ds, val_ds):
	AUTOTUNE = tf.data.experimental.AUTOTUNE
	train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
	val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
	return train_ds, val_ds


In [None]:

def create_model(num_classes = 5, height = 180, width = 180):

	#add data augmentation for more accurate results
	data_augmentation = keras.Sequential(
		[layers.experimental.preprocessing.RandomFlip("horizontal",
		    input_shape=(height,width,3)),
			layers.experimental.preprocessing.RandomRotation(0.1),
			layers.experimental.preprocessing.RandomZoom(0.1)]
		)

	model = Sequential([

		data_augmentation,
		#normalize model with rescaling
		layers.experimental.preprocessing.Rescaling(1./255, input_shape=(height, width, 3)),

		layers.Conv2D(16, 3, padding='same', activation='relu'),
		layers.MaxPooling2D(),
		layers.Conv2D(32, 3, padding='same', activation='relu'),
		layers.MaxPooling2D(),
		#layers.Conv2D(64, 3, padding='same', activation='relu'), #uncomment for entire data set
		#layers.MaxPooling2D(), #uncomment for entire data set
		layers.Dropout(0.15),
		layers.Flatten(),
		layers.Dense(128, activation='relu'), #uncomment for entire data set
		layers.Dense(num_classes)
	])

	#compile the model
	model.compile(optimizer='adam',
		loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
		metrics=['accuracy'])

	return model

In [None]:
def train_model(model, train_ds, val_ds, epochs= 12):
	history = model.fit(
		train_ds,
		validation_data=val_ds,
		epochs=epochs
		)
	return history, epochs


In [None]:
def make_prediction(model, picture_path, height = 180, width = 180):

	img = keras.preprocessing.image.load_img(
    	picture_path, target_size=(height, width)
	)
	img_array = keras.preprocessing.image.img_to_array(img)
	img_array = tf.expand_dims(img_array, 0) # Create a batch

	predictions = model.predict(img_array)
	score = tf.nn.softmax(predictions[0])

	result_string = "Likely {} with {:.2f}% confidence.".format(class_names[np.argmax(score)], 100 * np.max(score))
	#written by Shaun Miller
	img_to_print = Image.open(picture_path)
	draw = ImageDraw.Draw(img_to_print)
	font = ImageFont.truetype("arial.ttf", 100)
	draw.text((0, 0),result_string,(255,255,255),font=font)
	img_to_print.show()


In [None]:
ls

arial.ttf                     joshua-dunlop-mewtwo.jpg    preview_mewtwo.png
[0m[01;34mdataset[0m/                      joshua-dunlop-pikachu.jpg   preview_pikachu.png
[01;34mdataset_popular[0m/              joshua-dunlop-squirtle.jpg  preview.png
joshua-dunlop-bulbasaur.jpg   main.py                     README.md
joshua-dunlop-charmander.jpg  neildluffy_Pikachu.jpg      requirements.txt


In [None]:
if __name__ == "__main__":

	#Retreive dataset
	train_ds, val_ds = get_local_data('dataset_popular')
	class_names = train_ds.class_names

	#Train the Model
	train_ds, val_ds = configure_ds(train_ds, val_ds)
	model = create_model()
	print(model.summary())
	history, epochs = train_model(model, train_ds, val_ds)

	#Test the Model
	for picture_path in sys.argv[1:]: make_prediction(model, picture_path = picture_path)

Number of JPEG in dataset: 27
Found 65 files belonging to 5 classes.
Using 52 files for training.
Found 65 files belonging to 5 classes.
Using 13 files for validation.
Pokemon in set: ['Bulbasaur', 'Charmander', 'Mewtwo']  ...  ['Mewtwo', 'Pikachu', 'Squirtle']
Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 sequential_2 (Sequential)   (None, 180, 180, 3)       0         
                                                                 
 rescaling_1 (Rescaling)     (None, 180, 180, 3)       0         
                                                                 
 conv2d_2 (Conv2D)           (None, 180, 180, 16)      448       
                                                                 
 max_pooling2d_2 (MaxPooling  (None, 90, 90, 16)       0         
 2D)                                                             
                                                        

FileNotFoundError: ignored

In [None]:
import os

cwd = os.getcwd()  # Get the current working directory (cwd)
files = os.listdir(cwd)  # Get all the files in that directory
print("Files in %r: %s" % (cwd, files))

Files in '/content/drive/MyDrive/WhosThatPokemon-main': ['neildluffy_Pikachu.jpg', 'preview.png', 'arial.ttf', 'requirements.txt', 'joshua-dunlop-squirtle.jpg', 'joshua-dunlop-charmander.jpg', 'README.md', 'joshua-dunlop-pikachu.jpg', 'main.py', 'joshua-dunlop-bulbasaur.jpg', 'joshua-dunlop-mewtwo.jpg', '.DS_Store', '.ipynb_checkpoints', 'dataset_popular', 'dataset', 'preview_pikachu.png', 'preview_mewtwo.png']
