In [None]:
import tensorflow as tf
import wandb
from kaggle_secrets import UserSecretsClient
from wandb.keras import WandbCallback

# Read the W&B API key from Kaggle's secret store instead of hardcoding it.
user_secrets = UserSecretsClient()
personal_key_for_api = user_secrets.get_secret("WANDB_API_KEY")

# Log in through the Python API rather than `!wandb login $key`: the shell form
# places the secret on a subprocess command line (visible to other processes
# and shell logs) and is not valid plain Python outside IPython.
wandb.login(key=personal_key_for_api)
wandb.init(project="leaf_disease_classification")

In [None]:
# Root of the Kaggle "New Plant Diseases" dataset; the `train` and `valid`
# split directories live directly beneath it.
path = (
    '../input/new-plant-diseases-dataset/'
    'New Plant Diseases Dataset(Augmented)/'
    'New Plant Diseases Dataset(Augmented)'
)

In [None]:
# Load the train/valid splits from the image directory tree under `path`.
# image_dataset_from_directory assigns integer labels from the alphanumerically
# sorted class subdirectory names, so both splits must contain the same folders.
IMAGE_SIZE = (224, 224)
SPLIT_SEED = 123


def load_split(split):
    """Load one split directory (e.g. 'train', 'valid') under `path`.

    Args:
        split: name of the subdirectory of `path` holding one folder per class.

    Returns:
        The batched dataset produced by image_dataset_from_directory, with
        images resized to IMAGE_SIZE and integer labels.
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        path + '/' + split,
        seed=SPLIT_SEED,
        image_size=IMAGE_SIZE,
    )


train_data = load_split('train')
validation_data = load_split('valid')

# Guard against silently misaligned labels: if the class folders differ between
# splits, the same integer label would refer to different classes.
assert train_data.class_names == validation_data.class_names, (
    "train/valid class folders differ; integer labels would not line up"
)

In [None]:
# Number of output classes; must equal the number of class subdirectories in
# the training split (38 assumed for this dataset — TODO confirm against
# train_data.class_names).
num_classe = 38


class CNNClassifier(tf.keras.Model):
    """EfficientNetB0 backbone followed by global pooling and a dense softmax head.

    Args:
        num_classes: size of the softmax output layer (defaults to num_classe).
        image_shape: (height, width, channels) of the input images.
        hidden_units: width of the hidden ReLU dense layer.
    """

    def __init__(self, num_classes=num_classe, image_shape=(224, 224, 3), hidden_units=256):
        super().__init__()
        # ImageNet-pretrained backbone without its original classifier head.
        # NOTE(review): Keras EfficientNet docs state these models expect raw
        # [0, 255] pixel inputs (rescaling is built in) — confirm for the
        # installed TF version.
        self.base = tf.keras.applications.EfficientNetB0(
            weights="imagenet", include_top=False, input_shape=image_shape
        )
        # New classification head. Attribute names are kept unchanged so
        # existing checkpoints still map onto the same variables.
        self.global_average_pooling = tf.keras.layers.GlobalAveragePooling2D()
        self.fc1 = tf.keras.layers.Dense(hidden_units, activation="relu")
        self.fc2 = tf.keras.layers.Dense(num_classes, activation="softmax")

    def call(self, inputs, training=None):
        """Return per-class probabilities for a batch of images.

        `training` is forwarded to the backbone explicitly so its
        BatchNormalization/Dropout layers follow the caller's mode.
        """
        features = self.base(inputs, training=training)
        pooled = self.global_average_pooling(features)
        hidden = self.fc1(pooled)
        return self.fc2(hidden)
    

model = CNNClassifier()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0005)
# Integer labels + softmax probabilities -> sparse CE with from_logits=False.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")
model.compile(
    optimizer=optimizer,
    loss=loss_object,
    metrics=[accuracy],
)

# Keras logs validation metrics as "val_<metric name>". The metric above is
# named "accuracy", so the key to monitor is "val_accuracy". mode="max" is
# explicit because higher accuracy is better.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_accuracy",
    mode="max",
    factor=0.5,
    patience=1,
    verbose=1,
)
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    mode="max",
    patience=2,
    verbose=1,
    restore_best_weights=True,
)
wandb_callback = WandbCallback(
    monitor='val_loss',
    log_weights=True,
    log_evaluation=True,
    validation_steps=5,
)

callbacks = [reduce_lr, early_stop, wandb_callback]

EPOCHS = 10
# No `batch_size` here: the tf.data datasets are already batched when they are
# built (image_dataset_from_directory), and Keras ignores or rejects
# `batch_size` for dataset inputs.
history = model.fit(
    train_data,
    epochs=EPOCHS,
    validation_data=validation_data,
    callbacks=callbacks,
)



        

In [None]:
# Sanity check: pull one batch and show the image tensor shape and its labels.
image, label = next(iter(train_data.take(1)))
print(image.shape)
print(label)