In [1]:
import tensorflow as tf 

In [2]:
# Build the training / validation splits from a single directory tree
# (one sub-folder per class).
# Both subsets must be created with the same directory, split fraction and
# seed, otherwise they can overlap (train/val leakage). Keeping the shared
# options in one dict guarantees the two calls stay in sync.

DATA_DIR = "cell_images"   # path to root directory
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
SEED = 123

_split_kwargs = dict(
    validation_split=0.2,
    seed=SEED,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    # The model ends in one sigmoid unit trained with binary cross-entropy.
    # "binary" yields float32 0/1 labels to match that head, and Keras raises
    # if the directory does not contain exactly two classes.
    label_mode="binary",
)

train_ds = tf.keras.utils.image_dataset_from_directory(
    DATA_DIR, subset="training", **_split_kwargs
)
val_ds = tf.keras.utils.image_dataset_from_directory(
    DATA_DIR, subset="validation", **_split_kwargs
)


Found 27558 files belonging to 2 classes.
Using 22047 files for training.
Found 27558 files belonging to 2 classes.
Using 5511 files for validation.


In [None]:
from tensorflow.keras.applications import MobileNetV3Small, MobileNetV3Large
from tensorflow.keras import layers, models

# Frozen ImageNet MobileNetV3Small backbone + small trainable classifier head.
# (MobileNetV3Large is imported but not used in this cell.)
# The input shape is derived from IMG_SIZE (defined in the dataset cell), so the
# model always matches the image size the datasets produce.
base_model = MobileNetV3Small(
    input_shape=(*IMG_SIZE, 3),   # height, width, 3 channels
    include_top=False,            # drop the ImageNet classifier; our own head is added below
    weights="imagenet",
    # Keep the backbone's built-in input rescaling (the Keras default, made
    # explicit here): the dataset cell applies no normalization of its own,
    # so raw pixel values reach the model.
    include_preprocessing=True,
)
base_model.trainable = False  # feature extractor only; unfreeze later to fine-tune

model = models.Sequential([
    base_model,
    layers.GlobalAveragePooling2D(),
    layers.Dense(128, activation="relu"),
    layers.Dense(1, activation="sigmoid"),  # single-unit output: binary classification
])

model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v3/weights_mobilenet_v3_small_224_1.0_float_no_top_v2.h5
[1m4334752/4334752[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 2us/step


In [4]:
# Train the classifier head on top of the frozen backbone, validating each epoch.
# In the recorded run each epoch took roughly 4-6 minutes.
history = model.fit(
    train_ds,
    epochs=10,
    validation_data=val_ds,
)


Epoch 1/10
[1m689/689[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m348s[0m 495ms/step - accuracy: 0.9353 - loss: 0.1830 - val_accuracy: 0.9541 - val_loss: 0.1264
Epoch 2/10
[1m689/689[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m247s[0m 358ms/step - accuracy: 0.9578 - loss: 0.1313 - val_accuracy: 0.9545 - val_loss: 0.1304
Epoch 3/10
[1m689/689[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m255s[0m 370ms/step - accuracy: 0.9576 - loss: 0.1236 - val_accuracy: 0.9545 - val_loss: 0.1332
Epoch 4/10
[1m689/689[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m254s[0m 369ms/step - accuracy: 0.9588 - loss: 0.1169 - val_accuracy: 0.9559 - val_loss: 0.1218
Epoch 5/10
[1m689/689[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m256s[0m 371ms/step - accuracy: 0.9603 - loss: 0.1130 - val_accuracy: 0.9594 - val_loss: 0.1154
Epoch 6/10
[1m689/689[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m256s[0m 371ms/step - accuracy: 0.9620 - loss: 0.1083 - val_accuracy: 0.9590 - val_loss: 0.1134
Epoc

In [7]:
# Persist the trained model in the native Keras format; the export cell
# reloads it from this same path.
model.save(filepath="malaria_mobilenetv3.keras")

In [10]:
import tensorflow as tf
import numpy as np
from safetensors.numpy import save_file

# Export the trained weights to a safetensors file.
# The model is reloaded from the saved .keras file (rather than reusing the
# in-memory `model`) so the exported tensors come from the saved artifact.
model = tf.keras.models.load_model("malaria_mobilenetv3.keras")

# Collect ALL weights (trainable + non-trainable) of every top-level layer.
# Keys are "<top-level layer name>_<index within layer.weights>". NOTE: this is
# a flat, positional scheme, NOT the original Keras variable names. The nested
# MobileNetV3Small backbone is a single top-level layer of the Sequential
# model, so all of its tensors are keyed only by their position in that
# layer's `.weights` list. A loader must rebuild the same architecture and
# assign tensors in the same order.
weight_dict = {}
for layer in model.layers:
    for i, weight in enumerate(layer.weights):
        name = f"{layer.name}_{i}"
        # Never silently overwrite a tensor on a key collision.
        if name in weight_dict:
            raise ValueError(f"Duplicate tensor name {name!r} in weight export")
        weight_dict[name] = weight.numpy()

# safetensors metadata must be a str -> str mapping.
metadata = {
    "framework": "keras",
    "model_type": "MobileNetV3Small",  # same backbone name as written to config.json
    "task": "malaria_detection"
}
save_file(weight_dict, "malaria_mobilenetV3.safetensors", metadata=metadata)

# `.size` is the element count as a Python int; np.prod of an empty (scalar)
# shape returns the float 1.0, which would break the integer formatting below.
total_params = sum(v.size for v in weight_dict.values())
print(f"✅ Saved {len(weight_dict)} tensors ({total_params:,} total parameters)")


✅ Saved 210 tensors (1,013,105 total parameters)


In [11]:
import json

# Describe the classifier so a loader can rebuild it around the exported
# weights. Values mirror the model-building cell; the image size comes from
# the dataset cell's IMG_SIZE so the two cannot drift apart.
config = {
    "model_type": "MobileNetV3Small",
    "input_shape": [1, *IMG_SIZE, 3],   # [batch, height, width, channels]
    "num_classes": 1,
    "activation": "sigmoid",
    "pooling": "global_average",
    "hidden_units": [128],
    "framework": "keras",
    "pretrained_base": "imagenet",
    "trainable_base": False,
    "classifier_head": {
        "dense_1": {
            "units": 128,
            "activation": "relu"
        },
        "output": {
            "units": 1,
            "activation": "sigmoid"
        }
    },
    # Label order used in training. image_dataset_from_directory labels each
    # image with the index of its class in `class_names`, so the sigmoid
    # output is the score for class_names[1]. Without this, the meaning of the
    # single output unit is not recoverable from the exported artifacts.
    "class_names": list(train_ds.class_names),
}

with open("config.json", "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)
