In [1]:
import os

import cv2 as cv
import numpy as np
import tensorflow as tf

from helpers.utils import show_image, normalize, preprocess_image, detect_sudoku, extract_cells, sobel_gradients, get_binary_labels, get_digit_labels

In [17]:
# Loaded Keras models keyed by file path, so the (slow) model load happens
# once per kernel session instead of once per processed image.
_DIGIT_MODEL_CACHE = {}


def _load_digit_model(model_path='saved_model/model.h5'):
    """Return the Keras model stored at `model_path`, loading it at most once."""
    if model_path not in _DIGIT_MODEL_CACHE:
        _DIGIT_MODEL_CACHE[model_path] = tf.keras.models.load_model(model_path)
    return _DIGIT_MODEL_CACHE[model_path]


def extract_information(img, model=None):
    """Locate the sudoku board in a photo and classify its cells.

    Parameters
    ----------
    img : image as returned by ``cv.imread`` (converted from BGR to gray
        below, so a 3-channel BGR image is expected).
    model : optional, already-loaded Keras digit classifier passed to
        ``get_digit_labels``. When omitted, 'saved_model/model.h5' is loaded
        on first use and reused on later calls.

    Returns
    -------
    tuple
        ``(binary_labels, digit_labels)`` exactly as returned by
        ``get_binary_labels`` and ``get_digit_labels``. Presumably 81 entries
        each, since callers reshape them to 9x9 (TODO confirm in helpers.utils).

    Raises
    ------
    ValueError
        If `img` is None (e.g. ``cv.imread`` could not read the file).
    """
    if img is None:
        raise ValueError("extract_information received no image (img is None)")

    board = detect_sudoku(img)  # cropped sudoku: varying size, colored
    board = cv.resize(board, (500, 500))  # fixed size
    board = cv.cvtColor(board, cv.COLOR_BGR2GRAY)  # fixed size, grayscale
    board = normalize(board)

    binary_labels = get_binary_labels(board)

    # Reuse a caller-supplied or cached model instead of reloading per image.
    if model is None:
        model = _load_digit_model()
    digit_labels = get_digit_labels(board, model)

    return binary_labels, digit_labels



In [3]:
def _binary_char(value):
    """Map a binary cell label to 'o' when it equals 0, otherwise 'x'."""
    return 'o' if value == 0 else 'x'


def _digit_char(value):
    """Map a digit label to its string form, writing '0' as 'o'."""
    text = str(value)
    return 'o' if text == '0' else text


def _write_grid(path, grid, cell_to_char):
    """Write `grid` one row per line (no trailing newline), mapping cells via `cell_to_char`."""
    text = '\n'.join(''.join(cell_to_char(cell) for cell in row) for row in grid)
    with open(path, 'w') as fh:  # `with` guarantees the handle is closed on errors
        fh.write(text)


def get_results(input_dir, output_dir, number_of_samples, subdir='clasic'):
    """Run `extract_information` on images 01.jpg..NN.jpg and write predictions.

    For each sample ``i`` in ``1..number_of_samples``, it reads
    ``{input_dir}/{i:02d}.jpg`` and writes two 9-line text grids into
    ``{output_dir}/{subdir}/``:

    * ``{i}_predicted.txt``: 'o' where the binary label is 0, 'x' otherwise.
    * ``{i}_bonus_predicted.txt``: the digit labels, with 0 written as 'o'.

    Parameters
    ----------
    input_dir : directory holding the numbered input images.
    output_dir : root directory for the predictions. The `subdir` folder is
        created if missing.
    number_of_samples : how many consecutively numbered images to process.
    subdir : output subfolder name (default 'clasic').

    Raises
    ------
    FileNotFoundError
        If an input image cannot be read (``cv.imread`` returned None).
    """
    target_dir = f'{output_dir}/{subdir}'
    os.makedirs(target_dir, exist_ok=True)

    for i in range(1, number_of_samples + 1):
        image_path = f"{input_dir}/{i:02d}.jpg"  # 01.jpg ... 09.jpg, 10.jpg ...
        img = cv.imread(image_path)
        if img is None:
            # cv.imread does not raise on a missing/corrupt file; fail loudly here.
            raise FileNotFoundError(f"could not read image: {image_path}")

        binary_labels, digit_labels = extract_information(img)
        binary_grid = np.reshape(np.array(binary_labels), (9, 9))
        digit_grid = np.reshape(np.array(digit_labels), (9, 9))

        _write_grid(f'{target_dir}/{i}_predicted.txt', binary_grid, _binary_char)
        _write_grid(f'{target_dir}/{i}_bonus_predicted.txt', digit_grid, _digit_char)

In [18]:
# Run the pipeline over the classic training images (numbered 01.jpg, 02.jpg, ...)
# and write the predicted grids under the results/ directory.
number_of_samples = 20  # number of input images
output_dir = 'results'
input_dir = 'datasets/antrenare/clasic'
get_results(
    input_dir=input_dir,
    output_dir=output_dir,
    number_of_samples=number_of_samples,
)