In [None]:
import csv
import os
import string
from time import process_time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from sklearn.utils import shuffle
from tensorflow import keras
from tensorflow.keras.layers import (
    Input, Conv2D, MaxPool2D,
    BatchNormalization, Bidirectional, LSTM,
    Softmax
)

In [None]:
def load_fn_contents(fn_obj, contents, loc):
    """Append the ``name`` and ``content`` columns of a CSV file to two lists.

    Parameters
    ----------
    fn_obj : list
        Receives the ``name`` value of every row; extended in place.
    contents : list
        Receives the ``content`` value of every row; extended in place.
    loc : str or path-like
        Path of a comma-delimited CSV file whose header row contains at
        least the columns ``name`` and ``content``.

    Returns
    -------
    tuple
        ``(fn_obj, contents)`` -- the same two list objects that were passed in.

    Raises
    ------
    KeyError
        If the header has no ``name`` or ``content`` column.
    """
    # newline="" hands line-ending handling to the csv module, which the csv
    # documentation requires so that quoted fields containing line breaks are
    # read back unchanged.
    with open(loc, "r", newline="") as f:
        reader = csv.DictReader(f, delimiter=',')
        for row in reader:
            fn_obj.append(row['name'])
            contents.append(row['content'])

    return fn_obj, contents

In [None]:
def load_data(loc, filenames, size=(784, 32)):
    """Load image files into one array of equally sized 3-channel images.

    Parameters
    ----------
    loc : str
        Directory that contains the image files.
    filenames : iterable of str
        File names, relative to ``loc``, in the order they are loaded.
    size : tuple of int, optional
        ``(width, height)`` every image is resized to. Defaults to
        ``(784, 32)``.

    Returns
    -------
    numpy.ndarray
        For a non-empty ``filenames``: array of shape
        ``(n_images, height, width, 3)`` with the resized images.
    """
    data = []

    start_time = process_time()

    print("Process: loading images into memory ....")

    for count, fn in enumerate(filenames):
        if count % 100 == 0:
            print("Loading {} images...".format(count))
        with Image.open(f'{loc}/{fn}', 'r') as img:
            # Image.ANTIALIAS was removed in Pillow 10; Image.LANCZOS is the
            # same filter under its current name.
            resized = img.resize(size, Image.LANCZOS)
            # Convert rather than slice channels so grayscale / palette files
            # also yield exactly three channels. For RGBA input this drops
            # the alpha channel, like the former `[:, :, :3]` slice.
            data.append(np.asarray(resized.convert("RGB")))

    data = np.asarray(data)
    print(f"No of Images loaded :{data.shape[0]}")

    finish_time = process_time()
    print(f"Information: load images into memory took { round(finish_time-start_time, 2) } seconds")

    return data

In [None]:
# Read the training CSV: the `name` column goes to train_filenames, the
# `content` column (the label text of each image) to train_contents.
train_filenames, train_contents = load_fn_contents([], [], '../input/ronelov2/data.csv')

In [None]:
# Load every training image listed in train_filenames into one array.
X = load_data('../input/ronelov2/train', train_filenames)

In [None]:
def show(idx, data, label):
  """Display ``data[idx]`` as an image with ``label[idx]`` as its title."""
  picture = data[idx]
  plt.imshow(picture)
  caption = label[idx]
  plt.title(caption)

# Preview one training sample together with its label text.
show(4000, X, train_contents)

In [None]:
# Alphabet of recognisable characters; a character's position in this string
# is its class index (index 0 is the space).
symbols = f" {string.ascii_letters}{string.digits}.,*&!@~():`^[]';|-/$#?%"

# Number of character slots per image (second axis of the one-hot labels).
# It has to match the number of time steps OCRModel emits: the 784-pixel
# width is halved by four (2, 2) poolings, 784 / 2**4 = 49.
MAX_CHAR = 49
SYMBOLS_COUNT = len(symbols)
IMG_COUNT = len(X)
CHANNELS = 3

print(f"Characters : {symbols}")
print(f"No of chars : {SYMBOLS_COUNT}")
# This line used to be labelled "No of chars" although it prints the image count.
print(f"No of images : {IMG_COUNT}")

In [None]:
# One-hot encode the training labels: Y[i, t, k] == 1 when character t of
# label i is symbols[k]. Slots past the end of a label stay all-zero.
y_shape = (IMG_COUNT, MAX_CHAR, SYMBOLS_COUNT)
Y = np.zeros(shape=y_shape)

for example_no, words in enumerate(train_contents):  # index, sentence
    for letter_no, letter in enumerate(words):  # iterate through sentence
        symbol_no = symbols.find(letter)  # -1 when the character is unknown
        # Explicit checks replace the former bare `except:` so that only the
        # two expected problems (unknown character, position outside the
        # array) are reported; any other error now surfaces.
        fits = example_no < IMG_COUNT and letter_no < MAX_CHAR
        if symbol_no < 0 or not fits:
            print(letter, end=" ")
            continue
        Y[example_no, letter_no, symbol_no] = 1

In [None]:
channel = 3
def reshape(data):
  """Reshape ``data`` to ``(d0, d1, d2, channel)``, taking d0-d2 from its own shape."""
  batch, rows, cols = data.shape[0], data.shape[1], data.shape[2]
  target_shape = (batch, rows, cols, channel)
  return np.reshape(data, target_shape)

# load_data already returns 3-channel images, so this is expected to leave the
# shape unchanged; the trailing expression displays the resulting shape.
X = reshape(X)
X.shape

In [None]:
# Sanity check of the encoding: recover the class index of the first character
# of the first label from its one-hot row and print it next to the raw label.
idx = np.where(Y[0][0] == 1)[0][0]
print(idx)
print(train_contents[0])
print(symbols[idx])

In [None]:
def OCRModel(input_shape=None, n_symbols=None):
    """Build and compile the CNN + bidirectional-LSTM text recogniser.

    Six 3x3 convolutions interleaved with max-pooling are followed by batch
    normalisation; the resulting feature map is read column by column by two
    bidirectional LSTM layers, and every time step is mapped to a softmax
    distribution over the symbol classes.

    Parameters
    ----------
    input_shape : tuple of int, optional
        ``(height, width, channels)`` of the input images. Defaults to the
        shape of the module-level training array ``X`` with ``CHANNELS``
        channels. The five height-halving poolings must reduce the height to
        1 (true for the 32-pixel-high images used here).
    n_symbols : int, optional
        Number of output classes per time step. Defaults to ``len(symbols)``.

    Returns
    -------
    keras.Model
        Model compiled with categorical cross-entropy loss, the Adam
        optimizer and the accuracy metric. Its output has shape
        ``(batch, steps, n_symbols)`` where ``steps`` is the input width
        after the four width-halving poolings (49 for a width of 784).
    """
    if input_shape is None:
        input_shape = (X.shape[1], X.shape[2], CHANNELS)
    if n_symbols is None:
        n_symbols = len(symbols)

    img   = Input(input_shape)
    conv1 = Conv2D(16,(3,3), activation='relu', padding='same')(img)
    mp1   = MaxPool2D((2,2), padding='same')(conv1)
    conv2 = Conv2D(32,(3,3), activation='relu', padding='same')(mp1)
    mp2   = MaxPool2D((2,2), padding='same')(conv2)
    conv3 = Conv2D(64,(3,3), activation='relu', padding='same')(mp2)
    mp3   = MaxPool2D((2,2), padding='same')(conv3)
    conv4 = Conv2D(128,(3,3), activation='relu', padding='same')(mp3)
    mp4   = MaxPool2D((2,2), padding='same')(conv4)
    conv5 = Conv2D(256,(3,3), activation='relu', padding='same')(mp4)
    # (2, 1) pooling halves only the height, keeping the time axis intact.
    mp5   = MaxPool2D((2,1), padding='same')(conv5)
    conv6 = Conv2D(256,(3,3), activation='relu', padding='same')(mp5)
    # The former `mp7` / `conv7` layers were never connected to the model
    # output (batch norm has always been applied to conv6), so they are gone.
    bn1   = BatchNormalization()(conv6)

    # Drop the height axis to obtain a (steps, features) sequence. Reshape
    # layers replace tf.squeeze / tf.expand_dims because raw tf.* ops on
    # symbolic Keras tensors are rejected by Keras 3. Using the static step
    # count makes Reshape fail loudly if the height did not collapse to 1.
    steps = bn1.shape[2] if bn1.shape[2] is not None else -1
    sq    = keras.layers.Reshape((steps, bn1.shape[3]))(bn1)

    rn1   = Bidirectional(LSTM(256,return_sequences=True))(sq)
    rn2   = Bidirectional(LSTM(256,return_sequences=True))(rn1)

    # Re-insert a unit axis so a Conv2D can act as the per-step classifier.
    exd     = keras.layers.Reshape((steps, 1, rn2.shape[-1]))(rn2)
    # NOTE(review): relu in front of the softmax clips the logits at zero --
    # kept as in the original architecture; confirm this is intended.
    mapping = Conv2D(n_symbols,(2,2), activation='relu',padding='same')(exd)
    mapping = keras.layers.Reshape((steps, n_symbols))(mapping)
    mapping = Softmax()(mapping)

    model   = keras.Model(img,mapping)
    model.compile(loss='categorical_crossentropy',
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])

    return model


In [None]:
# redirect_stdout is only needed by the commented-out "write summary to file"
# variant below.
from contextlib import redirect_stdout

# Build a model instance to inspect its architecture; the training cell
# further down constructs its own fresh instance.
OCR = OCRModel()

# with open('modelsummary.txt', 'w+') as f:
#   with redirect_stdout(f):
#     OCR.summary()
OCR.summary()

In [None]:
class AccCallback(tf.keras.callbacks.Callback):
  """Stop training as soon as the training accuracy exceeds a threshold.

  Parameters
  ----------
  threshold : float, optional
      Value of the ``accuracy`` metric above which training is stopped.
      Defaults to ``0.998``.
  """

  def __init__(self, threshold=0.998):
    super().__init__()
    self.threshold = threshold

  def on_epoch_end(self, epoch, logs=None):
    """Check the epoch's ``accuracy`` entry in ``logs`` and stop if it is high enough."""
    # `logs=None` avoids a shared mutable default argument. A missing
    # metric is skipped instead of raising TypeError on `None > float`.
    accuracy = (logs or {}).get('accuracy')
    if accuracy is not None and accuracy > self.threshold:
      print(f"Accuracy has reached {self.threshold:.1%}")
      self.model.stop_training = True

callback = AccCallback()

In [None]:
# from tensorflow.keras.preprocessing.image import ImageDataGenerator

# train_datagen = ImageDataGenerator(
#     rotation_range = 10,
#     validation_split=0.2
# )

In [None]:
# Train a freshly built model for up to 30 epochs, holding out 15% of the
# samples for validation; `callback` may end training early.
# NOTE(review): process_time() measures CPU time of this process, not
# wall-clock time, so the reported "seconds" are CPU seconds.
start_time = process_time()
print(f"Process: training models...")

OCR = OCRModel()
history = OCR.fit(X, Y, 
                  validation_split = 0.15,
                  epochs = 30,
                  shuffle=True,
                  callbacks=[callback])

finish_time = process_time()
print(f"Information: training OCR model took {finish_time-start_time} seconds...")

In [None]:
# OCR.save('ronelov2-2.0', save_format='tf')

In [None]:
# Read the test CSV: file names go to test_filenames, label texts to test_contents.
test_filenames, test_contents = load_fn_contents([], [], '../input/ronelov2/test/data.csv')

In [None]:
# Peek at the first test file name as read from the CSV.
test_filenames[0]

In [None]:
from shutil import copytree

# Copy the test set into the working directory so its files can be renamed
# in the next step. copytree refuses to write into an existing directory, so
# a re-run of this cell reuses the earlier copy instead of crashing.
try:
    copytree('../input/ronelov2/test', '/kaggle/working/test')
except FileExistsError:
    print("'/kaggle/working/test' already exists - keeping the existing copy")

In [None]:
# Rename upper-case ".PNG" extensions in the working copy to ".png".
# Requires `os`, which is imported in the notebook's import cell.
test_dir = '/kaggle/working/test/'
for file in os.listdir(test_dir):
    if file.endswith('.PNG'):
        # file[:-4] strips the 4-character ".PNG" suffix.
        os.rename(os.path.join(test_dir, file),
                  os.path.join(test_dir, file[:-4] + '.png'))

In [None]:
ls

In [None]:
# OCR.save('ronelov2-2.0.h5', save_format='h5')

In [None]:
!ls ronelov2-2.0/variables

In [None]:
# Load the test images from the working copy (the one whose ".PNG" files were
# renamed to ".png").
TEST = load_data('/kaggle/working/test', test_filenames)

In [None]:
# Display the shape of the loaded test array.
TEST.shape

In [None]:
# One-hot encode the test labels: test_one_hot[i, t, k] == 1 when character t
# of label i is symbols[k]. Slots past the end of a label stay all-zero.
test_one_hot_shape = (len(TEST), MAX_CHAR, SYMBOLS_COUNT)
test_one_hot = np.zeros(shape=test_one_hot_shape)

for example_no, words in enumerate(test_contents):  # index, sentence
    for letter_no, letter in enumerate(words):  # iterate through sentence
        symbol_no = symbols.find(letter)  # -1 when the character is unknown
        # Explicit checks replace the former bare `except:` so that only the
        # two expected problems (unknown character, position outside the
        # array) are reported; any other error now surfaces.
        fits = example_no < test_one_hot_shape[0] and letter_no < MAX_CHAR
        if symbol_no < 0 or not fits:
            print(letter, end=" ")
            continue
        test_one_hot[example_no, letter_no, symbol_no] = 1

In [None]:
# Evaluate the trained model on the test set.
# NOTE(review): `loss` and `acc` are rebound to the per-epoch training history
# lists in the plotting cell at the end of the notebook.
loss, acc = OCR.evaluate(TEST, test_one_hot, verbose=2)

In [None]:
# start_time = process_time()
# print(f"Process: predicting {len(TEST)} images...")
# pred = OCR.predict(TEST)

# finish_time = process_time()
# print(f"Information: predicting {len(TEST)} images took {finish_time-start_time} seconds...")

In [None]:
# Run the model on the *training* images X (not TEST); `pred` is what the
# following cells decode and compare with train_contents.
# NOTE(review): process_time() measures CPU time, not wall-clock time.
start_time = process_time()
print(f"Process: predicting {len(X)} images...")
pred = OCR.predict(X)

finish_time = process_time()
print(f"Information: predicting {len(X)} images took {finish_time-start_time} seconds...")

In [None]:
# Exact-match accuracy of `pred` against the training labels: a sample counts
# as correct when its decoded text equals the label (ignoring surrounding
# whitespace).
# Most likely class per time step, computed once for all samples.
predicted_ids = np.argmax(pred, axis=-1)

count = 0
for idx in range(len(pred)):
  c = "".join(symbols[k] for k in predicted_ids[idx][:MAX_CHAR])
  if c.strip() == train_contents[idx].strip():
    count += 1

print(count)
if len(pred):
  # Scale to percent *before* rounding: `round(ratio, 2) * 100` can print
  # float noise such as 56.99999999999999.
  print(f"Correctly predicted: { round(100 * count / len(pred), 2) }%" )

  # Show the last sample of the loop as a visual spot check.
  print("predicted:",c.strip())
  print("\nOriginal:",train_contents[idx])
  plt.imshow(X[idx][:,:,:3])
else:
  print("Correctly predicted: n/a (no predictions)")

In [None]:
# Raw per-time-step class probabilities predicted for sample 100.
pred[100]

In [None]:
def decode_prediction(sample_probs):
    """Return the text for one sample: the most likely symbol at every time step."""
    return "".join(symbols[k] for k in np.argmax(sample_probs, axis=-1))


# Inspect a single training sample: decoded prediction, label and image.
index = 5002
c = decode_prediction(pred[index])
print("predicted:",c.strip())
print("\nOrignal:",train_contents[index])
plt.imshow(X[index][:,:,:3])

#### index = 20
# Same sample again, this time showing only the first colour channel.
print(len(pred[0]))
c = decode_prediction(pred[index])
print("predicted:",c.strip())
# `contents` was never defined at notebook level (NameError); the label list
# for X is train_contents.
print("\nOrignal:",train_contents[index])
# Separate figure so this image does not draw over the one above.
plt.figure()
plt.imshow(X[index][:,:,0])

In [None]:
# Per-epoch metrics recorded by OCR.fit.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs = range(len(acc))

# Each figure is created *before* drawing into it; calling plt.figure() after
# the plot commands (as before) only produced additional empty figures.
plt.figure()
plt.plot(epochs, acc, 'r', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc=0)
plt.show()

plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc=0)
plt.show()