In [None]:
# CELL 1 (do not change)
"""import statements and boiler plate code"""
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
import tensorflow as tf
from tensorflow import Tensor
from tensorflow.keras.layers import Input, ReLU, LeakyReLU, Dense, Conv2D, Flatten, MaxPool2D, Dropout, BatchNormalization, AveragePooling2D, Add
from tensorflow.keras.models import Model
from keras import layers
from PIL import Image
from keras.callbacks import ModelCheckpoint





# CELL 2 preprocessing
# Augmentation plan for the model's input pipeline (implemented in create_res_net):
#     - random flip
#     - random hue, max delta = 0.3 (TODO: not implemented yet)
#     - random brightness, max factor = 0.3
#     - random contrast, factor in [0.9, 1.1]

# Load the dataset and build the image tensor.
# NOTE(review): pickle.load can execute arbitrary code -- only load a trusted pokemon.pkl.
with open('pokemon.pkl', 'rb') as f:
    df = pickle.load(f)                 # dataset dataframe
target = list(df.pop('Type1'))

IMAGE_SHAPE = (120, 120, 3)
# ImagePath values point into a Colab checkout; strip that prefix so the paths
# resolve relative to the current working directory.
_COLAB_PREFIX = "/content/workshop-f22/pokemon-dataset/"

# Number of images (rows) in the dataframe.
number_images = len(df)
# float32 is what Keras trains in; np.empty's float64 default would double the memory.
img_array = np.empty(shape=(number_images, *IMAGE_SHAPE), dtype=np.float32)

for idx, element in enumerate(df['ImagePath']):
    path = element.replace(_COLAB_PREFIX, '')
    # The context manager closes the file; convert('RGB') drops alpha/palette modes.
    with Image.open(path) as image:
        pixels = np.asarray(image.convert('RGB'))
    # Fail loudly on a wrongly-sized image instead of silently reshaping its pixels.
    if pixels.shape != IMAGE_SHAPE:
        raise ValueError(f"{path}: expected image shape {IMAGE_SHAPE}, got {pixels.shape}")
    img_array[idx] = pixels

print(img_array.shape)


# One-hot encode the Type1 labels.
# np.unique returns the classes sorted, so the class -> column mapping is identical
# on every run (iterating a set of strings depends on hash randomization and is not).
categories, _label_codes = np.unique(target, return_inverse=True)
n_categories = len(categories)
# Row k of the identity matrix is the one-hot vector for class k.
ohe_labels = np.eye(n_categories)[_label_codes]


# ReLU + batch-normalization pipeline shared by the stem and the residual blocks.
def relu_bn(inputs: Tensor) -> Tensor:
    """Apply a ReLU activation, then batch normalization.

    Args:
        inputs: Output tensor of the preceding layer.

    Returns:
        The batch-normalized ReLU activations.
    """
    activated = ReLU()(inputs)
    return BatchNormalization()(activated)

# Residual block (can be changed later on).
def residual_block(x: Tensor, downsample: bool, filters: int, kernel_size: int = 3) -> Tensor:
    """Two-convolution residual block: conv -> ReLU/BN -> conv, added to a shortcut.

    Args:
        x: Input feature map (assumed channels-last, the Keras default used in this file).
        downsample: If True, the first convolution and the shortcut use stride 2,
            halving the spatial size.
        filters: Number of output channels of both convolutions.
        kernel_size: Kernel size of the two main-path convolutions.

    Returns:
        ReLU/BN applied to (shortcut + main path).
    """
    stride = 2 if downsample else 1
    y = Conv2D(kernel_size=kernel_size,
               strides=stride,
               filters=filters,
               padding="same")(x)
    y = relu_bn(y)
    y = Conv2D(kernel_size=kernel_size,
               strides=1,
               filters=filters,
               padding="same")(y)
    # Add() needs both branches to have the same shape: project the shortcut with a
    # 1x1 conv when the block downsamples OR changes the channel count.
    if downsample or x.shape[-1] != filters:
        x = Conv2D(kernel_size=1, strides=stride, filters=filters, padding="same")(x)
    out = Add()([x, y])
    out = relu_bn(out)
    return out

def create_res_net():
    """Build and compile the ResNet-style classifier for the Pokemon images.

    Reads the module-level ``img_array`` (input height/width) and ``categories``
    (number of output classes).

    Returns:
        A ``Model`` compiled with Adam, categorical cross-entropy and accuracy.
    """
    inputs = Input(shape=(img_array.shape[1], img_array.shape[2], 3))

    # Data augmentation. tf.keras.layers is used throughout so every layer comes
    # from the same Keras as the tensorflow.keras imports above.
    # Pixels are raw [0, 255] values, matching RandomBrightness' default value_range.
    t = tf.keras.layers.RandomFlip("horizontal_and_vertical")(inputs)
    # factor=0.1 -> contrast factor drawn from [1 - 0.1, 1 + 0.1] = [0.9, 1.1].
    # (The second positional argument is NOT an upper bound, so pass one factor.)
    t = tf.keras.layers.RandomContrast(0.1)(t)
    # factor=0.3 -> brightness delta drawn from [-0.3, 0.3] of the value range.
    t = tf.keras.layers.RandomBrightness(0.3)(t)
    # TODO: random hue (max delta 0.3) is planned; tf.keras has no built-in hue layer.

    # Stem (subject to change).
    num_filters = 64
    t = Conv2D(kernel_size=3,
               strides=1,
               filters=num_filters,
               padding="same")(t)
    t = relu_bn(t)

    # Residual stages (subject to change): every stage after the first opens with a
    # stride-2 block, and the filter count doubles from one stage to the next.
    num_blocks_list = [2, 5, 5, 2]
    for stage, num_blocks in enumerate(num_blocks_list):
        for block in range(num_blocks):
            t = residual_block(t, downsample=(block == 0 and stage != 0), filters=num_filters)
        num_filters *= 2

    # Classification head (subject to change).
    t = AveragePooling2D(2)(t)
    t = Flatten()(t)
    outputs = Dense(len(categories), activation='softmax')(t)

    model = Model(inputs, outputs)
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model



# Build the network, then print its per-layer summary.
model = create_res_net()
model.summary()



In [None]:
# Fit the model.
# Keras' validation_split takes the LAST 20% of the arrays as passed, before any
# shuffling, so shuffle once (fixed seed for reproducibility) to get a random
# train/validation split rather than one determined by the dataframe's row order.
# Fancy indexing makes shuffled copies of the arrays (temporary extra memory).
_order = np.random.default_rng(seed=0).permutation(len(img_array))
model.fit(
    x=img_array[_order],
    y=ohe_labels[_order],
    epochs=38,
    verbose=1,
    validation_split=0.2,
    batch_size=128,
)