In [1]:
import os
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import pandas as pd

# --- Configuration -----------------------------------------------------------
# Input image size. Keras orders image shapes as (height, width, channels) and
# flow_from_directory's `target_size` as (height, width), so height comes first
# everywhere below (the original passed width first, which only worked because
# the image is square).
img_width, img_height = 224, 224

# Number of classes the classifier head is built for; verified against the
# class sub-directories actually found on disk before any training happens.
num_classes = 26

# Batch size and number of training epochs.
batch_size = 32
epochs = 10

# Dataset roots; flow_from_directory expects one sub-directory per class.
train_data_dir = "train-file"
test_data_dir = "test-file"

# --- Data --------------------------------------------------------------------
# Training images: rescaled to [0, 1] plus shear/zoom/flip augmentation.
# NOTE(review): the ImageNet VGG16 weights were trained with
# `tensorflow.keras.applications.vgg16.preprocess_input` (BGR, mean-subtracted),
# not [0, 1] scaling. Switching would likely give better frozen features, but it
# changes the input contract of the saved model, so any inference code must be
# updated in lockstep.
# NOTE(review): horizontal_flip mirrors training images; if the 26 classes are
# characters (presumably A-Z, given the OCR model name - confirm), mirrored
# glyphs may hurt accuracy.
train_datagen = ImageDataGenerator(rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)

# Evaluation images: same rescaling, no augmentation.
test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
)
# shuffle=False keeps validation deterministic (every epoch scores the same
# images in the same order) and keeps predictions aligned with
# `test_generator.classes` for later per-class analysis.
test_generator = test_datagen.flow_from_directory(
    test_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False,
)

# Fail fast, before downloading weights or training, if the data on disk does
# not match what the classifier head assumes.
assert train_generator.num_classes == num_classes, (
    f"expected {num_classes} classes, found {train_generator.num_classes} in {train_data_dir!r}"
)
# Label index i must denote the same class in both splits; otherwise validation
# accuracy is computed against mismatched labels without any error.
assert test_generator.class_indices == train_generator.class_indices, (
    "train and test directories contain different class sub-directories"
)

# --- Model -------------------------------------------------------------------
# ImageNet-pretrained VGG16 convolutional base without its original classifier.
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(img_height, img_width, 3))

# Freeze the whole pretrained base (recursively freezes every layer) so only
# the new head below is trained.
base_model.trainable = False

# New classification head: flatten the conv feature map, one hidden ReLU layer
# of 256 units, then a softmax over `num_classes`.
x = base_model.output
x = Flatten()(x)
x = Dense(256, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)

model = Model(inputs=base_model.input, outputs=predictions)

# categorical_crossentropy matches the one-hot labels from class_mode='categorical'.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# --- Training ----------------------------------------------------------------
# len(generator) == ceil(samples / batch_size), so the final partial batch is
# included; `samples // batch_size` skips up to batch_size - 1 images per epoch
# (with shuffle=False, always the same trailing validation images).
# The History object is kept so loss/accuracy curves can be inspected later.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=epochs,
    validation_data=test_generator,
    validation_steps=len(test_generator),
)

Found 7254 images belonging to 26 classes.
Found 1820 images belonging to 26 classes.
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x2e311750940>

In [2]:
# Serialize the trained model to disk; the `.h5` suffix selects Keras' HDF5
# format. Note that the file is overwritten if it already exists.
model.save(filepath='ocr_model.h5')