In [1]:
import os
# Set the WANDB_* switches in one place, before any ML library is imported,
# so that whatever reads them later sees these values.
os.environ.update({
    "WANDB_DISABLED": "true",
    "WANDB_MODE": "disabled",
    "WANDB_SILENT": "true",
})

In [2]:
import tensorflow as tf
import os
import glob
from tensorflow.keras.applications import EfficientNetB7
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import ModelCheckpoint

Define Dataset Loading Function

In [3]:
def load_dataset(tfrecord_pattern, batch_size=64, shuffle=True,
                 image_size=(224, 224), num_classes=4,
                 shuffle_buffer=1000, drop_remainder=False):
    """Build a batched ``tf.data`` pipeline of (image, label) pairs from TFRecord shards.

    Each record is expected to hold a JPEG-encoded ``'image'`` (bytes) and a
    fixed-length int64 ``'label'`` vector of ``num_classes`` entries
    (assumed to be one-hot encoded).

    Args:
        tfrecord_pattern: Glob pattern matching the TFRecord files to read.
        batch_size: Number of examples per batch.
        shuffle: If True, the shard order is reshuffled, shards are read
            interleaved, and serialized records pass through a shuffle buffer.
            If False, records are read in sorted-filename order.
        image_size: ``(height, width)`` every image is resized to.
        num_classes: Length of the stored label vector.
        shuffle_buffer: Size of the record-level shuffle buffer.
        drop_remainder: If True, drop the final batch when it is smaller than
            ``batch_size`` (useful when every step needs a full batch, e.g.
            under a distribution strategy).

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches. Images are
        float32 with 3 channels, resized to ``image_size`` and NOT rescaled
        (values stay in the decoded 0-255 range); labels are int64 vectors of
        length ``num_classes``.

    Raises:
        FileNotFoundError: If ``tfrecord_pattern`` matches no files.
    """
    # glob order is filesystem-dependent; sort so the shard order is reproducible.
    files = sorted(glob.glob(tfrecord_pattern))
    if not files:
        # Fail here with a clear message instead of handing Keras an empty dataset.
        raise FileNotFoundError(
            f"No TFRecord files match pattern: {tfrecord_pattern!r}"
        )

    # Schema used to parse each serialized tf.train.Example.
    feature_description = {
        'image': tf.io.FixedLenFeature([], tf.string),          # JPEG bytes
        'label': tf.io.FixedLenFeature([num_classes], tf.int64)  # assumed one-hot
    }

    def _parse_function(proto):
        """Decode one serialized example into a resized image and its label."""
        parsed_example = tf.io.parse_single_example(proto, feature_description)
        image = tf.io.decode_jpeg(parsed_example['image'], channels=3)
        # Bilinear resize converts uint8 pixels to float32 without rescaling.
        image = tf.image.resize(image, image_size)
        return image, parsed_example['label']

    if shuffle:
        # Mix at the shard level first: a record buffer that is small relative
        # to the dataset cannot by itself mix shards that are read sequentially.
        shards = tf.data.Dataset.from_tensor_slices(files).shuffle(len(files))
        records = shards.interleave(
            tf.data.TFRecordDataset,
            num_parallel_calls=tf.data.AUTOTUNE,
        )
        # Shuffle the serialized records (small) rather than decoded images
        # (large) so the buffer stays cheap in memory.
        records = records.shuffle(buffer_size=shuffle_buffer)
    else:
        records = tf.data.TFRecordDataset(files)

    dataset = records.map(_parse_function, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    return dataset.prefetch(tf.data.AUTOTUNE)

# Build the three input pipelines; only the training split is shuffled.
train_dataset = load_dataset(
    "/kaggle/input/oct-dataset-tfrecord/train/train_chunk_*.tfrecord",
    batch_size=64,
    shuffle=True,
)
val_dataset = load_dataset(
    "/kaggle/input/oct-dataset-tfrecord/val/val_chunk_*.tfrecord",
    batch_size=64,
    shuffle=False,
)
test_dataset = load_dataset(
    "/kaggle/input/oct-dataset-tfrecord/test/test_chunk_*.tfrecord",
    batch_size=64,
    shuffle=False,
)

Set up Mirrored Strategy

In [4]:
import tensorflow as tf

# Initialize the strategy for multi-GPU training.
# The model is later built and compiled inside `strategy.scope()`.
strategy = tf.distribute.MirroredStrategy()

# Report how many replicas the strategy will keep in sync.
print(f"Number of devices: {strategy.num_replicas_in_sync}")

Number of devices: 2


Define the Model and Wrap in Strategy Scope

In [5]:
with strategy.scope():
    # Backbone: EfficientNetB7 with ImageNet weights and no classifier top,
    # kept frozen so only the new head is trained.
    base_model = EfficientNetB7(
        input_shape=(224, 224, 3),
        include_top=False,
        weights='imagenet',
    )
    base_model.trainable = False

    # Classification head: global average pooling followed by a 4-way softmax.
    head_model = models.Sequential()
    head_model.add(layers.GlobalAveragePooling2D())
    head_model.add(layers.Dense(4, activation='softmax'))

    # Full network = frozen backbone followed by the trainable head.
    final_model = models.Sequential([base_model, head_model])

    # Labels are fed as length-4 vectors, hence categorical cross-entropy.
    final_model.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )

Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb7_notop.h5
[1m258076736/258076736[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step


Set up Callbacks for Checkpointing

In [6]:
from tensorflow.keras.callbacks import ModelCheckpoint

# Weights-only checkpoints, written only when val_loss reaches a new minimum.
# The epoch number is part of the filename, so each improvement gets its own file.
checkpoint_path = "/kaggle/working/ckpt_{epoch:02d}.weights.h5"
checkpoint_callback = ModelCheckpoint(
    checkpoint_path,
    monitor='val_loss',
    mode='min',
    save_best_only=True,  # Save only the best model
    save_weights_only=True,
    verbose=1,
)

Train the Model

In [7]:
# Keep the returned History object: its `.history` dict holds the per-epoch
# loss/metric values, which would otherwise be lost after this long run
# (the bare call only echoed an uninformative object repr).
history = final_model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=20,  # Start with 20 epochs
    callbacks=[checkpoint_callback],
)

Epoch 1/20
   1305/Unknown [1m750s[0m 508ms/step - accuracy: 0.4231 - loss: nan   

  self.gen.throw(typ, value, traceback)



Epoch 1: val_loss improved from inf to 1.58747, saving model to /kaggle/working/ckpt_01.weights.h5
[1m1305/1305[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m770s[0m 523ms/step - accuracy: 0.4231 - loss: nan - val_accuracy: 0.2500 - val_loss: 1.5875
Epoch 2/20
[1m1305/1305[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 514ms/step - accuracy: 0.4274 - loss: nan   
Epoch 2: val_loss improved from 1.58747 to 1.55020, saving model to /kaggle/working/ckpt_02.weights.h5
[1m1305/1305[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m674s[0m 516ms/step - accuracy: 0.4274 - loss: nan - val_accuracy: 0.2500 - val_loss: 1.5502
Epoch 3/20
[1m1305/1305[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 511ms/step - accuracy: 0.4270 - loss: nan   
Epoch 3: val_loss did not improve from 1.55020
[1m1305/1305[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m668s[0m 512ms/step - accuracy: 0.4270 - loss: nan - val_accuracy: 0.2500 - val_loss: 1.6244
Epoch 4/20
[1m1305/1305[0m [32m━━━━━━

<keras.src.callbacks.history.History at 0x7c93535a8f40>

Evaluate the Model on the Test Dataset

In [8]:
# The model was compiled with a single 'accuracy' metric, so evaluate()
# returns two values here: the loss followed by the accuracy.
test_loss, test_accuracy = final_model.evaluate(test_dataset)
# Print the final aggregated values returned by evaluate().
print(f"Test Loss: {test_loss}")
print(f"Test Accuracy: {test_accuracy}")

[1m16/16[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 523ms/step - accuracy: 0.5369 - loss: nan
Test Loss: 1.6372750997543335
Test Accuracy: 0.26229506731033325
