# Image Augmentation/Generation

In [None]:
import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

Define the dataset directories (one sub-folder per split)

In [None]:
# Dataset layout: a common root folder with one sub-folder per split.
base_dir = "data"
train_dir, validation_dir, test_dir = (
    os.path.join(base_dir, split) for split in ("train", "validation", "test")
)

### Pretrained feature extractor

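We extract features from the images with a ResNet152 model pre-trained on ImageNet. With `include_top=False` the network outputs a spatial feature map instead of class scores; the classification head added below flattens this map into a feature vector and trains a classifier on it.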

In [None]:
base_model = tf.keras.applications.ResNet152(weights="imagenet", include_top=False, input_shape=(150, 150, 3))

Freeze the layers of the base model so that only the new classification head is trained

In [None]:
# Mark every backbone layer as non-trainable; the pretrained weights stay
# fixed and only the layers added on top will be updated during fit().
for backbone_layer in base_model.layers:
    backbone_layer.trainable = False

Next, we add a trainable classification head on top of `base_model` and compile the model

In [None]:
# The classifier head needs exactly one output unit per class.
# flow_from_directory treats every sub-directory of the training folder as a
# class, so counting those sub-directories keeps the head in sync with the
# labels the generators produce. A hard-coded count breaks at fit time as
# soon as the dataset's classes change.
num_classes = sum(
    os.path.isdir(os.path.join(train_dir, entry))
    for entry in os.listdir(train_dir)
)

model = Sequential([
    base_model,                 # frozen ResNet152 feature extractor
    Flatten(),                  # spatial feature map -> 1-D feature vector
    Dense(256, activation='relu'),
    Dropout(0.5),               # regularise the new, trainable head
    Dense(num_classes, activation='softmax'),
])

# Categorical cross-entropy matches the one-hot labels produced by
# class_mode='categorical' in the data generators.
model.compile(optimizer=Adam(), loss='categorical_crossentropy', metrics=['accuracy'])


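Instantiate the data generators, which stream batches of images and one-hot labels from the class sub-directories of each split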

In [None]:
# The frozen ResNet ImageNet weights expect inputs produced by
# `resnet.preprocess_input`: RGB->BGR conversion and per-channel ImageNet
# mean subtraction, with no scaling to [0, 1]. Rescaling by 1/255 would feed
# the backbone values in a range it was never trained on.
preprocess = tf.keras.applications.resnet.preprocess_input

train_datagen = ImageDataGenerator(preprocessing_function=preprocess)
test_datagen = ImageDataGenerator(preprocessing_function=preprocess)

train_generator = train_datagen.flow_from_directory(
    train_dir,  # Training data folder
    target_size=(150, 150),  # must match base_model's input_shape
    batch_size=32,
    class_mode='categorical',
)

# Monitor training on the dedicated validation split. The test split is kept
# unseen so it can still give an unbiased final estimate.
validation_generator = test_datagen.flow_from_directory(
    validation_dir,  # Validation data folder
    target_size=(150, 150),
    batch_size=32,
    class_mode='categorical',
)

With the generators ready, we can train the extended network (the frozen ResNet152 base plus the new classification head)

In [None]:

# Fit the model and save the history.
# `len(generator)` equals ceil(samples / batch_size), so every image, including
# those in the final partial batch, is used once per epoch. Floor division
# silently drops up to batch_size - 1 images per epoch, and it yields 0 steps
# when a split holds fewer images than one batch.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
    epochs=10,
)

# The .h5 extension selects the (legacy) HDF5 save format.
model.save('enhanced_pretrained.h5')

model.summary()

Visualise the training and validation curves

In [None]:
import matplotlib.pyplot as plt

# Side-by-side learning curves: accuracy on the left, loss on the right.
panels = [
    ('accuracy', 'Training Accuracy', 'Validation Accuracy',
     'Training and Validation Accuracy'),
    ('loss', 'Training Loss', 'Validation Loss',
     'Training and Validation Loss'),
]
curves = history.history

plt.figure(figsize=(10, 4))
for panel_index, (metric, train_label, val_label, panel_title) in enumerate(panels, start=1):
    plt.subplot(1, 2, panel_index)
    plt.plot(curves[metric], label=train_label)
    plt.plot(curves['val_' + metric], label=val_label)
    plt.title(panel_title)
    plt.legend()

plt.show()
