In [None]:
import os
import pandas as pd
import numpy as np

import tensorflow as tf
from tensorflow.keras.utils import image_dataset_from_directory

### Pre-processing

In [None]:
# Dataset root: image_dataset_from_directory infers one class per sub-directory.
TRAIN_DIR = './Rice_Image_Dataset/Train/'

# Target image size (pixels) passed to the dataset loader and the model.
IMAGE_HEIGHT = IMAGE_WIDTH = 256

# Number of images per training batch.
BATCH_SIZE = 256

In [None]:
# Load the training images as a batched tf.data.Dataset.
# Keras documents image_size as (height, width); every image is resized to it.
# Labels are inferred from the sub-directory names and, with the default
# label_mode='int', yielded as integer class indices.
train_ds = image_dataset_from_directory(
    directory=TRAIN_DIR,
    batch_size=BATCH_SIZE,
    image_size=(IMAGE_HEIGHT, IMAGE_WIDTH),
)

Found 50000 files belonging to 5 classes.


In [None]:
# One output unit per class: class_names holds the label names that
# image_dataset_from_directory inferred from the sub-directory names.
NUM_CLASSES = len(
    train_ds.class_names
)

### Training

In [None]:
# Import needed libraries

from tensorflow import keras
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, InputLayer, Rescaling

In [None]:
# Define architecture

def build_cnn(image_width, image_height, num_classes):
    """Build a small convolutional classifier for 3-channel images.

    Args:
        image_width: Width in pixels of the input images.
        image_height: Height in pixels of the input images.
        num_classes: Number of target classes; size of the output layer.

    Returns:
        An uncompiled ``Sequential`` model mapping a batch of
        ``(image_height, image_width, 3)`` images to ``num_classes``
        unnormalised logits. There is no softmax, so the loss must be
        configured with ``from_logits=True``.
    """
    model = Sequential()

    # Keras image tensors are channels-last: (height, width, channels).
    model.add(InputLayer(input_shape=(image_height, image_width, 3)))

    # Scale pixel values by 1/255 (assumes 0-255 inputs, as produced by
    # image_dataset_from_directory without rescaling).
    model.add(Rescaling(1./255))

    # Feature extractor: two 3x3 convolutions, then halve the spatial size.
    model.add(Conv2D(16, 3, padding='same', activation='elu'))
    model.add(Conv2D(32, 3, padding='same', activation='elu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    # Classifier head: one raw logit per class (no activation); the predicted
    # label is the argmax over these logits.
    model.add(Flatten())
    model.add(Dense(num_classes))

    return model

In [None]:
# Instantiate the (still uncompiled) model for the configured image size and class count.
cnn = build_cnn(
    num_classes=NUM_CLASSES,
    image_height=IMAGE_HEIGHT,
    image_width=IMAGE_WIDTH,
)

In [None]:
# Configure the optimizer, loss and metrics (no training happens here).
# The final Dense layer has no activation, so the loss is told it receives raw
# logits; SparseCategoricalCrossentropy expects integer class labels.
cnn.compile(
    loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
EPOCHS = 1

# Save the model weights to a checkpoint at the end of every epoch (ModelCheckpoint's
# default save_freq='epoch') so a run can be restored with cnn.load_weights(checkpoint_path).
# NOTE(review): Keras 3 requires weights-only checkpoint paths to end in '.weights.h5';
# this '.ckpt' path relies on Keras 2 / tf.keras behaviour.
checkpoint_path = "training_1/cp.ckpt"
# Directory holding the checkpoint files (e.g. for tf.train.latest_checkpoint); unused below.
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Keep the returned History so per-epoch loss/accuracy can be inspected or plotted,
# rather than only displaying its repr as the cell output.
history = cnn.fit(train_ds, epochs=EPOCHS, callbacks=[cp_callback])


Epoch 1: saving model to training_1/cp.ckpt


<keras.callbacks.History at 0x7f92c200bcd0>

In [None]:
# Persist the whole model (architecture and weights) to a single file; the
# '.h5' extension selects the (legacy) HDF5 format.
cnn.save(filepath='model.h5')

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=9aa6fced-121f-49c6-aa5e-ed34b7b6710c' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>