In [1]:
import os
from getpass import getpass

import tensorflow as tf
import mlflow
import mlflow.keras
from sklearn.metrics import precision_score, recall_score, accuracy_score
import numpy as np
from keras.preprocessing.image import ImageDataGenerator

In [2]:
# Training data pipeline: rescale pixels to [0, 1] and apply light random
# augmentation (shear, zoom, horizontal flip).
# The dataset root is read from RICE_DATA_DIR when set, so the notebook is not
# tied to one machine; the default is the original local path.
# A fixed seed is passed to the iterator so its shuffling is reproducible.
DATA_DIR = os.environ.get(
    'RICE_DATA_DIR', 'E:/Deep Learning/TENSORFLOW/rice_image_detection/artifacts/data')
SEED = 42

train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)
training_set = train_datagen.flow_from_directory(
    os.path.join(DATA_DIR, 'train'),
    target_size=(64, 64),
    batch_size=32,
    seed=SEED)

Found 60000 images belonging to 5 classes.


In [3]:
# Test data pipeline: rescale only, no augmentation.
# shuffle=False keeps batch order aligned with `test_set.classes` and
# `test_set.filenames`. Any prediction-vs-label comparison (e.g. sklearn
# metrics on `cnn.predict(test_set)`) depends on that alignment; validation
# loss/accuracy are unaffected by order.
# The dataset root is read from RICE_DATA_DIR when set; the default is the
# original local path.
TEST_DIR = os.path.join(
    os.environ.get('RICE_DATA_DIR', 'E:/Deep Learning/TENSORFLOW/rice_image_detection/artifacts/data'),
    'test', 'test')

test_datagen = ImageDataGenerator(rescale=1./255)
test_set = test_datagen.flow_from_directory(
    TEST_DIR,
    target_size=(64, 64),
    batch_size=32,
    shuffle=False)

Found 15000 images belonging to 5 classes.


In [4]:
# Point MLflow at the DagsHub tracking server.
# Credentials come from the environment (or an interactive prompt) rather than
# being hard-coded: a token stored in a notebook leaks to anyone who can read
# the file or its history. Any token that was ever committed here should be
# revoked and rotated.
MLFLOW_TRACKING_URI = "https://dagshub.com/karmakaragradwip02/rice_image_detection_cnn.mlflow"
os.environ['MLFLOW_TRACKING_URI'] = MLFLOW_TRACKING_URI
os.environ.setdefault('MLFLOW_TRACKING_USERNAME', 'karmakaragradwip02')
if not os.environ.get('MLFLOW_TRACKING_PASSWORD'):
    os.environ['MLFLOW_TRACKING_PASSWORD'] = getpass('DagsHub token (MLFLOW_TRACKING_PASSWORD): ')

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment("CNN Classifier")

<Experiment: artifact_location='mlflow-artifacts:/3d9fa9ea225b44c187c7461de0ea5637', creation_time=1718721506518, experiment_id='0', last_update_time=1718721506518, lifecycle_stage='active', name='CNN Classifier', tags={}>

In [5]:
# Optimisation hyper-parameters, consumed by the model definition and logged
# to MLflow by the training cell.
learning_rate = 1e-3  # Adam step size
weight_decay = 1e-4   # L2 regularisation factor for the conv/dense kernels

In [6]:
# Small two-block CNN for 64x64 RGB inputs, with L2 weight decay on every
# conv/dense kernel.
# The classes are mutually exclusive and the loss is categorical cross-entropy,
# so the output layer must be softmax. A sigmoid output gives independent
# per-class scores that do not form a probability distribution, which
# mis-specifies categorical cross-entropy.
# The output width comes from the training iterator rather than a magic number,
# so it always matches the number of class folders found on disk.
num_classes = training_set.num_classes

cnn = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation='relu', input_shape=[64, 64, 3],
                           kernel_regularizer=tf.keras.regularizers.l2(weight_decay)),
    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),
    tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(weight_decay)),
    tf.keras.layers.MaxPool2D(pool_size=2, strides=2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=128, activation='relu',
                          kernel_regularizer=tf.keras.regularizers.l2(weight_decay)),
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

cnn.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

In [7]:
# Train the CNN inside an MLflow run. The run logs hyper-parameters,
# per-epoch training curves, final test-set accuracy/precision/recall and the
# trained Keras model.
EPOCHS = 20

# Metrics compare predictions against ground-truth labels, so they need an
# iterator whose order matches `.classes`. A shuffled `test_set` would scramble
# that pairing, so build a dedicated non-shuffled iterator over the same test
# directory with the same preprocessing settings.
eval_set = test_datagen.flow_from_directory(
    test_set.directory,
    target_size=test_set.target_size,
    batch_size=test_set.batch_size,
    shuffle=False)

# The context manager ends the run itself (and marks it failed if an exception
# escapes), so there is no try/except/finally. Errors surface instead of being
# printed and swallowed, which would leave a half-logged run.
with mlflow.start_run() as run:
    mlflow.log_params({
        'weight_decay': weight_decay,
        'learning_rate': learning_rate,
        'epochs': EPOCHS,
    })

    history = cnn.fit(x=training_set, validation_data=test_set, epochs=EPOCHS)

    # `.classes` already holds one integer class index per image (1-D), so no
    # argmax is applied to the labels; only the per-class scores are argmax-ed.
    y_pred = np.argmax(cnn.predict(eval_set), axis=1)
    y_true = eval_set.classes

    mlflow.log_metrics({
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='macro'),
        'recall': recall_score(y_true, y_pred, average='macro'),
    })

    # Log the entire training history, one step per completed epoch.
    curves = history.history
    for epoch in range(len(curves['accuracy'])):
        mlflow.log_metrics({
            'train_accuracy': curves['accuracy'][epoch],
            'val_accuracy': curves['val_accuracy'][epoch],
            'train_loss': curves['loss'][epoch],
            'val_loss': curves['val_loss'][epoch],
        }, step=epoch)

    # The model is a Keras model, so use the Keras flavor (imported at the top),
    # not the PyTorch one.
    mlflow.keras.log_model(cnn, "model")

Epoch 1/20
  93/1875 [>.............................] - ETA: 25:39 - loss: 0.6773 - accuracy: 0.7312

KeyboardInterrupt: 