In [None]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Model

import numpy as np
import random

import glob

import PIL
from PIL import Image

import time

import os

In [None]:
# List the GPU(s) Colab assigned to this runtime (`nvidia-smi -L` prints one line
# per GPU). Timing results below are only comparable across runs on the same GPU model.

!nvidia-smi -L

# 1. Gather and prepare the dataset

## 1.1 Download and extract the data

In [None]:
# Download the zipped dataset from Google Drive with gdown.
# NOTE(review): presumably this fetches caltech256_subset_resized_cropped256x256.zip,
# the archive unzipped in the next cell — confirm the file name matches.
!gdown https://drive.google.com/u/0/uc?id=1STYsoP85lyKAtarMRuDyTjp89tAbIDM-

In [None]:
# Extract the dataset archive; -o overwrites existing files without prompting,
# so this cell can be re-run without blocking on an interactive question.
!unzip -o caltech256_subset_resized_cropped256x256.zip

In [None]:
# All entries under data/, in sorted (lexicographic) order, minus the last one.
# NOTE(review): the dropped entry is presumably the extra "clutter" class folder — confirm.
candidate_folders = glob.glob("caltech256_subset_resized_cropped256x256/data/*")
candidate_folders.sort()
folder_paths = candidate_folders[:-1]

## 1.2 Split into training, validation and testing data

In [None]:
# Each line of shuffled_labels.txt has the form "<relative image path>,<integer label>".
# File order is preserved, because the split below slices these lists by position.
with open("caltech256_subset_resized_cropped256x256/shuffled_labels.txt") as labels_fh:
    label_records = [entry.strip().split(",") for entry in labels_fh.readlines()]

shuffled_paths = [record_path for record_path, _ in label_records]
shuffled_labels = [int(record_label) for _, record_label in label_records]

In [None]:
# Fractions of the (already shuffled) data used for training and validation;
# whatever remains after those two slices (~20%) becomes the test set.
train_split = 0.6
validation_split = 0.2

total_images = len(shuffled_paths)
num_train_images = int(total_images * train_split)
num_validation_images = int(total_images * validation_split)

# Positional boundaries: [0, train_end) -> train, [train_end, val_end) -> validation,
# [val_end, end) -> test.
train_end = num_train_images
val_end = num_train_images + num_validation_images

train_image_names, validation_image_names, test_image_names = (
    shuffled_paths[:train_end],
    shuffled_paths[train_end:val_end],
    shuffled_paths[val_end:],
)
train_image_labels, validation_image_labels, test_image_labels = (
    np.array(shuffled_labels[:train_end]),
    np.array(shuffled_labels[train_end:val_end]),
    np.array(shuffled_labels[val_end:]),
)

## 1.3 Prepare data for the model

In [None]:
# Side length of the stored images (the dataset folder name says 256x256) and
# side length of the square crops actually fed to the model.
image_size, crop_size = 256, 224

In [None]:
# Load the entire validation set up front (assumes it fits in memory) and
# center-crop every image to crop_size x crop_size, the model's input size.
val_data_dir = "caltech256_subset_resized_cropped256x256/data/"

val_images = []

for filename in validation_image_names:
    # validation_image_names already holds bare relative paths (the label file was
    # split on "," when it was read), so no further parsing is needed here.
    # The context manager closes the file once the pixels are copied; convert("RGB")
    # guarantees 3 channels even for grayscale/palette images, which would otherwise
    # give mismatched array shapes when stacking below.
    with PIL.Image.open(val_data_dir + filename) as img:
        val_images.append(np.array(img.convert("RGB")))

# Offset that centers a crop_size window inside an image_size x image_size image.
crop_offset = (image_size - crop_size) // 2

# np.stack raises a clear error if the images do not all share one shape.
val_images = tf.image.crop_to_bounding_box(
    np.stack(val_images), crop_offset, crop_offset, crop_size, crop_size
)

# Ensure an ndarray without copying when it already is one.
validation_image_labels = np.asarray(validation_image_labels)

# 2. Build and train a model for 10 epochs

In [None]:
def build_model(num_classes=256, learning_rate=0.01, dropout_rate=0.2):
    """Build and compile a transfer-learning classifier on a frozen ResNet50.

    The ImageNet-pretrained ResNet50 backbone is frozen; only the new
    global-average-pool -> dropout -> softmax head is trained.

    Args:
        num_classes: Number of output units in the softmax layer.
        learning_rate: Learning rate for the Adam optimizer.
        dropout_rate: Dropout rate applied to the pooled backbone features.

    Returns:
        A compiled ``Model`` that takes RGB images of shape
        (crop_size, crop_size, 3) with raw 0-255 pixel values and returns class
        probabilities.
    """
    inputs = layers.Input(shape=(crop_size, crop_size, 3))

    # Keras' ResNet50 ImageNet weights expect caffe-style preprocessing
    # (RGB -> BGR plus per-channel ImageNet mean subtraction). The notebook feeds
    # raw np.array(PIL image) pixels, so the preprocessing lives inside the model.
    # That way training and evaluation always see the same transform.
    x = tf.keras.applications.resnet50.preprocess_input(inputs)

    pretrained_resnet_model = tf.keras.applications.resnet50.ResNet50(
        include_top=False, input_shape=(crop_size, crop_size, 3)
    )
    pretrained_resnet_model.trainable = False

    # training=False keeps the frozen backbone's BatchNorm layers in inference mode.
    x = pretrained_resnet_model(x, training=False)

    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(dropout_rate)(x)
    outputs = layers.Dense(num_classes, activation="softmax")(x)

    model = Model(inputs, outputs, name="ResNet")

    opt = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    # Sparse loss: labels are integer class ids (parsed with int() from the label file).
    model.compile(
        optimizer=opt, loss="sparse_categorical_crossentropy", metrics=["accuracy"]
    )

    return model

In [None]:
model = build_model()

num_epochs = 10
batch_size = 128
train_data_dir = "caltech256_subset_resized_cropped256x256/data/"

initial_start = time.time()

for epoch in range(num_epochs):
    epoch_start = time.time()

    # Walk the training set in consecutive chunks of batch_size images, decoding
    # each chunk from disk on the fly.
    for start_index in range(0, len(train_image_names), batch_size):
        end_index = start_index + batch_size

        train_image_names_batch = train_image_names[start_index:end_index]
        label_batch = train_image_labels[start_index:end_index]

        images = []

        for filename in train_image_names_batch:
            # Close the file once decoded. convert("RGB") guarantees the 3 channels
            # that random_crop's size=[..., 3] requires, even for grayscale images.
            with PIL.Image.open(train_data_dir + filename) as img:
                pixels = np.array(img.convert("RGB"))
            # Random crop_size x crop_size window as light data augmentation.
            img_crop = tf.image.random_crop(pixels, size=[crop_size, crop_size, 3])
            images.append(np.array(img_crop))

        images = np.array(images)

        # One optimisation step on this chunk. Reuse batch_size so the slicing above
        # and fit's internal batching can never disagree.
        # NOTE(review): calling fit() per chunk adds per-call overhead; train_on_batch
        # is lighter, but switching would change the timing baseline measured here.
        model.fit(images, label_batch, epochs=1, verbose=0, batch_size=batch_size)

    print("Evaluating Validation Accuracy...")

    model.evaluate(val_images, validation_image_labels)

    epoch_end = time.time()

    print(f"Epoch Time: {epoch_end - epoch_start}")

last_end = time.time()

print(f"Total Time: {last_end - initial_start}")