# Initial Setup

In [None]:
import os
import shutil

# Fail fast when the generated dataset directory is missing; the message points
# the user at the 'SetupDataset' step that produces it. FileNotFoundError is a
# subclass of Exception, so existing broad handlers still catch it.
if not os.path.isdir('gen_data_set/'):
    raise FileNotFoundError("Generated dataset not found, pls run 'SetupDataset' first.")

# Start each run from an empty OUTPUT/ directory so files saved by a previous
# run are not mixed with this run's results.
if os.path.exists('OUTPUT/'):
    shutil.rmtree('OUTPUT/')

os.makedirs('OUTPUT/')

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense
from tensorflow.keras.optimizers import Adam

# Loading dataset

In [None]:
# Import from tensorflow.keras, like every other Keras name in this notebook,
# so the generators and the model come from the same Keras implementation.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Training pipeline: multiply pixel values by 1/255 (assumes 8-bit images —
# confirm against the dataset) and apply random brightness, shear, shift,
# rotation, flip and zoom augmentation. Pixels exposed by a transform are
# filled with the nearest valid value.
train_gen = ImageDataGenerator(
    rescale=1. / 255,
    brightness_range=[0.2, 1.0],
    shear_range=0.2,
    fill_mode='nearest',
    width_shift_range=0.2,
    rotation_range=40,
    height_shift_range=0.2,
    horizontal_flip=True,
    zoom_range=0.2,
)

# Evaluation pipeline: same rescaling, no augmentation.
test_gen = ImageDataGenerator(rescale=1. / 255)

# Both iterators resize images to 256x256 and yield batches of 32 with binary
# labels; the class_mode='binary' choice presumably means exactly two class
# sub-directories under each split — verify against gen_data_set/.
training_data = train_gen.flow_from_directory(
    'gen_data_set/train',
    target_size=(256, 256),
    batch_size=32,
    class_mode='binary',
)

test_data = test_gen.flow_from_directory(
    'gen_data_set/test',
    target_size=(256, 256),
    batch_size=32,
    class_mode='binary',
)

# Training

In [None]:
# Input shape for the backbone: (height, width, channels). The 256x256 spatial
# size mirrors the target_size given to flow_from_directory.
IMAGE_SHAPE = (256, 256, 3)

# Step counts for fit()/evaluate(): one full pass over each directory
# iterator, as measured by len() on the iterator.
steps_per_epoch = len(training_data)
validation_steps = len(test_data)

In [None]:
def tune_model(pretrain_model, *, epochs=45, freeze_base=False):
    """Attach a binary classification head to a backbone, train it, and return it.

    The head is global average pooling, then Dense(256, relu), Dense(128, relu)
    and a single sigmoid unit. The model is compiled with the 'adam' optimizer
    and binary cross-entropy, tracking accuracy. It is trained on the
    module-level ``training_data`` for ``steps_per_epoch`` steps per epoch and
    validated on ``test_data`` for ``validation_steps`` steps after each epoch.

    Args:
        pretrain_model: Model placed first in the stack as the feature
            extractor. In this notebook it is ResNet152V2 built with
            ``include_top=False``.
        epochs: Number of training epochs. Defaults to 45, the value that was
            previously hard-coded.
        freeze_base: If True, set ``pretrain_model.trainable = False`` before
            compiling, so only the new head's weights are updated. This
            mutates the passed-in model. Defaults to False, which trains every
            backbone layer as the original implementation did.

    Returns:
        The compiled and trained ``Sequential`` classifier.
    """
    if freeze_base:
        # Must be set before compile(); changing `trainable` afterwards has
        # no effect until the model is recompiled.
        pretrain_model.trainable = False

    classifier = Sequential([
        pretrain_model,
        # Collapse the backbone's spatial feature map to one vector per image.
        GlobalAveragePooling2D(),
        Dense(256, activation='relu'),
        Dense(128, activation='relu'),
        # Single sigmoid unit to match class_mode='binary' labels.
        Dense(1, activation='sigmoid'),
    ])

    classifier.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

    classifier.fit(training_data, epochs=epochs, steps_per_epoch=steps_per_epoch, workers=1,
                   validation_data=test_data, validation_steps=validation_steps)

    return classifier

## ResNet

In [None]:
from tensorflow.keras.applications import ResNet152V2

# ImageNet-pretrained ResNet152V2 without its classification top, so a custom
# head can be stacked on it; input size matches the data generators.
res_net = ResNet152V2(weights='imagenet',
                      include_top=False,
                      input_shape=IMAGE_SHAPE)

In [None]:
# Build and train the binary classifier on top of the ResNet152V2 backbone.
resnet_classifier = tune_model(pretrain_model=res_net)

In [None]:
# Persist the trained classifier; the .h5 extension selects Keras' HDF5 format.
resnet_classifier.save(filepath="OUTPUT/ResnetModel.h5")

In [None]:
# evaluate() yields [loss, accuracy] because 'accuracy' is the only compiled
# metric; the loss is discarded.
# NOTE(review): test_data is also passed as validation_data during training,
# so use a separate held-out split if validation results drive any tuning.
_, accuracy = resnet_classifier.evaluate(x=test_data, steps=validation_steps)

In [None]:
# Report the test-set accuracy computed above.
print(f'ResnetModel accuracy : {accuracy}')