In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import pathlib, os, time

print("TensorFlow", tf.__version__)


TensorFlow 2.19.0


In [3]:
ds_train, ds_test = tfds.load(
    'stanford_dogs',  # community dataset
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=False)

# CLASS_NAMES = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
NUM_CLASSES = 120

def format_example(image, label):
    image = tf.image.resize(image, [224, 224])
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

BATCH = 32
train_ds = ds_train.map(format_example, num_parallel_calls=tf.data.AUTOTUNE)\
                   .cache().shuffle(1000).batch(BATCH).prefetch(tf.data.AUTOTUNE)
test_ds  = ds_test.map(format_example, num_parallel_calls=tf.data.AUTOTUNE)\
                  .cache().batch(BATCH).prefetch(tf.data.AUTOTUNE)

In [4]:
# 2.  Build lightweight model
# ==========================================
base = tf.keras.applications.MobileNetV2(
    input_shape=(224, 224, 3),
    include_top=False,
    weights='imagenet')
base.trainable = False  # fine-tune later if desired

model = tf.keras.Sequential([
    base,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax')
])

model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

In [None]:
# ==========================================
# 3.  Train quickly (5 epochs is enough for > 90 %)
# ==========================================
EPOCHS = 3
hist = model.fit(train_ds, epochs=EPOCHS, validation_data=test_ds)
test_loss, test_acc = model.evaluate(test_ds, verbose=0)
print(f"Float model accuracy: {test_acc:.3f}")

Epoch 1/3
[1m375/375[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1s/step - accuracy: 0.4592 - loss: 2.4132

In [None]:
 5.  Convert to TensorFlow-Lite
# ==========================================
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("waste_float32.tflite", "wb").write(tflite_model)
print("Float TFLite model size:", len(tflite_model)/1024, "kB")

# ==========================================
# 6.  Post-training quantisation (full-integer)
# ==========================================
def representative_dataset():
    for img_batch, _ in train_ds.take(100):  # 100 batches → ~3200 images
        yield [img_batch]

converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

tflite_quant = converter.convert()
open("waste_int8.tflite", "wb").write(tflite_quant)
print("Quantised TFLite model size:", len(tflite_quant)/1024, "kB")

# ==========================================
# 7.  Accuracy check of quantised model
# ==========================================
interpreter = tf.lite.Interpreter(model_content=tflite_quant)
interpreter.allocate_tensors()
input_idx  = interpreter.get_input_details()[0]['index']
output_idx = interpreter.get_output_details()[0]['index']

correct = 0
total   = 0
for img, lab in test_ds.unbatch().batch(1):
    img_uint8 = tf.cast(img*255, tf.uint8)
    interpreter.set_tensor(input_idx, img_uint8)
    interpreter.invoke()
    pred = interpreter.get_tensor(output_idx).argmax()
    correct += (pred == lab.numpy()[0])
    total += 1
quant_acc = correct / total
print(f"INT8 TFLite accuracy: {quant_acc:.3f}")

# ==========================================
# 8.  Show classification example
# ==========================================
def infer(image):
    img_uint8 = tf.cast(tf.image.resize(image, [224,224])*255, tf.uint8)
    interpreter.set_tensor(input_idx, img_uint8[None])
    interpreter.invoke()
    return interpreter.get_tensor(output_idx)[0]

sample_img, true_lab = next(iter(test_ds.unbatch().batch(1)))
probs = infer(sample_img[0])
plt.imshow(sample_img[0])
plt.title(f"True: {CLASS_NAMES[true_lab[0]]}  Pred: {CLASS_NAMES[np.argmax(probs)]}")
plt.axis('off')
plt.show()
