In [None]:
import tensorflow as tf

In [None]:
# Function to parse the TFRecord file
def parse_example(example_proto, image_shape=(224, 224, 3)):
    """Decode one serialized ``tf.train.Example`` into an ``(image, label)`` pair.

    Each record must store the image as a flat float32 list under ``'image'``
    and a scalar int64 under ``'label'``.

    Args:
        example_proto: Scalar string tensor holding one serialized
            ``tf.train.Example`` (an element of a ``tf.data.TFRecordDataset``).
        image_shape: ``(height, width, channels)`` of the stored image. The
            default is the 224x224x3 layout the rest of this notebook assumes,
            so ``dataset.map(parse_example)`` behaves exactly as before; records
            written at another size can be parsed with e.g.
            ``functools.partial(parse_example, image_shape=(160, 160, 3))``.

    Returns:
        ``(image, label)``: a float32 tensor of shape ``image_shape`` and a
        scalar int64 tensor.
    """
    height, width, channels = image_shape
    # The image was serialized flattened, so the parser needs the total number
    # of values; it is reshaped back to (H, W, C) below.
    feature_description = {
        'image': tf.io.FixedLenFeature([height * width * channels], tf.float32),  # flattened H*W*C floats (RGB by default)
        'label': tf.io.FixedLenFeature([], tf.int64),
    }
    parsed_example = tf.io.parse_single_example(example_proto, feature_description)
    image = tf.reshape(parsed_example['image'], [height, width, channels])
    label = parsed_example['label']
    return image, label

# Load the TFRecord files
def load_dataset(filename):
    """Build a dataset of decoded ``(image, label)`` pairs from a TFRecord file.

    Args:
        filename: Path to the TFRecord file; passed straight through to
            ``tf.data.TFRecordDataset``.

    Returns:
        A ``tf.data.Dataset`` whose records are decoded by ``parse_example``,
        with the decoding parallelism chosen by ``tf.data.AUTOTUNE``.
    """
    return tf.data.TFRecordDataset(filename).map(
        parse_example,
        num_parallel_calls=tf.data.AUTOTUNE,
    )

# Load the training, validation, and test datasets
train_dataset = load_dataset('train_dataset.tfrecord')
val_dataset = load_dataset('val_dataset.tfrecord')
test_dataset = load_dataset('test_dataset.tfrecord')

# Batch the datasets for training and evaluation.
# The training split is shuffled before batching: TFRecordDataset yields records
# in file order, so without this every epoch would see identical batches (and,
# if the file happens to be sorted by label, badly unbalanced ones).
batch_size = 32
# Each parsed example is 224*224*3 float32 values (~0.6 MB), so the shuffle
# buffer costs roughly SHUFFLE_BUFFER_SIZE * 0.6 MB of RAM; larger mixes better.
SHUFFLE_BUFFER_SIZE = 256
SHUFFLE_SEED = 42  # fixed so the per-epoch example order is reproducible
train_dataset = (
    train_dataset
    .shuffle(SHUFFLE_BUFFER_SIZE, seed=SHUFFLE_SEED, reshuffle_each_iteration=True)
    .batch(batch_size)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
# Validation/test order does not change the computed metrics, so no shuffle here.
val_dataset = val_dataset.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)

# Example of building a MobileNet model: ImageNet-pretrained MobileNetV2 without
# its classification head, frozen so that only the new head below is trained.
# NOTE(review): MobileNetV2's ImageNet weights are normally fed inputs passed
# through tf.keras.applications.mobilenet_v2.preprocess_input. Nothing in this
# notebook rescales the parsed images — confirm the TFRecords were written
# already preprocessed.
base_model = tf.keras.applications.MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=(224, 224, 3),
)
base_model.trainable = False

# Classification head: average-pool the feature maps into one vector per image,
# then a single sigmoid unit (assumes binary classification).
global_avg_layer = tf.keras.layers.GlobalAveragePooling2D()
dense_layer = tf.keras.layers.Dense(1, activation='sigmoid')

# Stack backbone -> pooling -> sigmoid output into one model.
model = tf.keras.Sequential()
model.add(base_model)
model.add(global_avg_layer)
model.add(dense_layer)

# Compile the model: binary cross-entropy pairs with the single sigmoid output unit.
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='binary_crossentropy',
              metrics=['accuracy'])

# Train the model. The History object is kept (previously discarded) so the
# per-epoch training/validation loss and accuracy can be inspected or plotted.
EPOCHS = 10
history = model.fit(train_dataset,
                    validation_data=val_dataset,
                    epochs=EPOCHS)

# Evaluate the model on the held-out test split. With one compiled metric,
# evaluate() returns [loss, accuracy]; unpacking instead of indexing makes that
# assumption explicit and raises if more metrics are added to compile().
eval_results = model.evaluate(test_dataset)
test_loss, test_accuracy = eval_results
print(f"Test Loss: {test_loss}, Test Accuracy: {test_accuracy}")