# Import libraries

In [1]:
import wandb
from wandb.keras import WandbCallback
import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split

# Load dataset and split it into train (90%) / test (10%)

In [2]:
# Build the (image path, label) table from the CelebA attribute file.
# NOTE(review): attributes in list_attr_celeba.csv are assumed to be coded
# 1 / -1 (the mapping below turns -1 into 0); confirm against the dataset docs.
CELEBA_DIR = '/kaggle/input/celeba-dataset'
SPLIT_SEED = 7  # same value as defaults['seed'] in the model cell

attrs_raw = pd.read_csv(f'{CELEBA_DIR}/list_attr_celeba.csv')
df = pd.DataFrame({
    'datadir': f'{CELEBA_DIR}/img_align_celeba/img_align_celeba/' + attrs_raw['image_id'].astype(str),
    # 1 where the `Male` attribute is set, 0 otherwise: binary target for the sigmoid head.
    'gender': attrs_raw['Male'].replace(-1, 0),
})
assert df['gender'].isin([0, 1]).all(), 'unexpected values in Male attribute'

# A fixed random_state makes the split reproducible across kernel restarts;
# stratifying keeps the gender ratio identical in both splits.
train_df, test_df = train_test_split(
    df,
    test_size=0.1,
    random_state=SPLIT_SEED,
    stratify=df['gender'],
)

# Model: fine-tune ResNet50V2 for binary gender classification

In [5]:
# Run configuration. wandb.init registers these defaults and exposes them as
# `wandb.config`; the rest of the cell reads hyperparameters from `config`.
defaults = {
    'epochs': 20,
    'batch_size': 128,
    'fc1_num_neurons': 512,
    'fc2_num_neurons': 512,
    'fc3_num_neurons': 512,
    'seed': 7,
    'learning_rate': 3e-4,
    'optimizer': 'adam',
    'hidden_activation': 'relu',
    'output_activation': 'sigmoid',
    'loss_function': 'binary_crossentropy',
    'metrics': ['accuracy'],
}

# NOTE(review): resume=True lets wandb resume an earlier run rather than start
# a fresh one -- confirm that is intended for this training run.
wandb.init(
    config=defaults,
    resume=True,
    name='Pre Train Model resNet50V2',
    project='CelebA resNet50V2 Runs',
    notes='resNet50V2 pretraining for tuned model, 0.1 test split',
)
config = wandb.config

# Seed the global NumPy and TensorFlow RNGs so initialisation of the newly
# created head layers is repeatable (the data generators get `seed=` separately).
np.random.seed(config.seed)
tf.random.set_seed(config.seed)

# Image pipelines: both splits receive only the ResNetV2 input preprocessing
# (no augmentation), so the two generators are configured identically.
IMAGE_SIZE = (224, 224)  # must match the backbone's input_shape in the model definition

datagen_train = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
)
datagen_test = ImageDataGenerator(
    preprocessing_function=tf.keras.applications.resnet_v2.preprocess_input,
)


def _flow_split(datagen, frame, shuffle):
    """Return a batch iterator over one split of the (image path, gender) table.

    Args:
        datagen: ImageDataGenerator that applies the per-image preprocessing.
        frame: DataFrame with a 'datadir' column of image file paths and a
            'gender' column of labels.
        shuffle: whether the iterator shuffles sample order.

    class_mode='raw' passes the 'gender' values through unchanged, matching the
    single-unit sigmoid output and binary cross-entropy loss in the config.
    """
    return datagen.flow_from_dataframe(
        dataframe=frame,
        x_col='datadir',
        y_col='gender',
        batch_size=config.batch_size,
        seed=config.seed,
        shuffle=shuffle,
        class_mode='raw',
        target_size=IMAGE_SIZE,
    )


train_generator = _flow_split(datagen_train, train_df, shuffle=True)
# The evaluation split is only scored, never trained on, so a fixed order is
# sufficient and keeps validation passes deterministic.
test_generator = _flow_split(datagen_test, test_df, shuffle=False)

# Backbone: ImageNet-pretrained ResNet50V2 without its classifier top.
# pooling='avg' reduces the last feature map to a (None, 2048) vector, and the
# whole backbone stays trainable, so it is fine-tuned together with the head.
resNet50V2 = tf.keras.applications.resnet_v2.ResNet50V2(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3),
    pooling='avg',
)
resNet50V2.trainable = True

# Classifier head. Layers are created in a fixed order (three Dense, then four
# BatchNormalization) because Keras derives default layer names from creation order.
fc1, fc2, fc3 = (
    tf.keras.layers.Dense(units, activation=config.hidden_activation)
    for units in (config.fc1_num_neurons, config.fc2_num_neurons, config.fc3_num_neurons)
)
bn1, bn2, bn3, bn4 = (tf.keras.layers.BatchNormalization() for _ in range(4))

model = tf.keras.models.Sequential(
    [
        resNet50V2,
        # No-op on the already 2-D pooled output; kept so the architecture
        # and model.summary() stay unchanged.
        tf.keras.layers.Flatten(),
        bn1, fc1,
        bn2, fc2,
        bn3, fc3,
        bn4,
        tf.keras.layers.Dense(1, activation=config.output_activation),
    ]
)
model.summary()

# Compile with the optimizer named in config['optimizer'] rather than a
# hard-coded class, so the value logged to wandb is the one actually used.
_OPTIMIZERS = {
    'adam': tf.keras.optimizers.Adam,
    'rmsprop': tf.keras.optimizers.RMSprop,
    'sgd': tf.keras.optimizers.SGD,
}
optimizer_cls = _OPTIMIZERS.get(str(config.optimizer).lower())
if optimizer_cls is None:
    raise ValueError(f"Unsupported optimizer in config: {config.optimizer!r}")

model.compile(
    optimizer=optimizer_cls(learning_rate=config.learning_rate),
    loss=config.loss_function,
    metrics=config.metrics,
)

# NOTE(review): the 10% "test" split is also passed as validation data, so its
# scores are not an untouched hold-out estimate.
# `shuffle` is not passed: Keras ignores it for generator/Sequence inputs, and
# train_generator is already created with shuffle=True.
history = model.fit(
    train_generator,
    validation_data=test_generator,
    epochs=config.epochs,
    callbacks=[WandbCallback()],
)
model.save_weights('model_celeba_tune_resNet50V2.h5')
# Mark the wandb run finished and finish uploading its data.
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mcz4042_assignment_2_hdk[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.6 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Found 182339 validated image filenames.
Found 20260 validated image filenames.
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50v2_weights_tf_dim_ordering_tf_kernels_notop.h5
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
resnet50v2 (Functional)      (None, 2048)              23564800  
_________________________________________________________________
flatten (Flatten)            (None, 2048)              0         
_________________________________________________________________
batch_normalization (BatchNo (None, 2048)              8192      
_________________________________________________________________
dense (Dense)                (None, 512)               1049088   
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
______________________________