### Notebook to extract labelled digit squares from the sudoku dataset

In [1]:
import os
import cv2
import numpy as np
import pandas as pd
from PIL import Image
from collections import defaultdict
from detect_sudoku import find_grid_in_img, center_image_single_contour, center_image_broken_contours

In [2]:
# Pixels trimmed from every side of a grid cell before the digit is kept,
# so that remnants of the grid lines are cut away.
linewidth = 4
# Side length (px) of the digit crop left after trimming `linewidth` from both sides.
square_size = 28
# Not referenced in the cells shown here; presumably a minimum contour area
# used by the detection helpers — TODO confirm.
contour_area_threshold = 30
# Side length (px) of one full cell in the grid image: digit crop plus a margin on each side.
true_square_size = 2 * linewidth + square_size

In [3]:
def show_wait_destroy(img, window_name='image'):
    """Display *img* in a window for debugging and block until a key is pressed.

    The window is placed at screen position (500, 0) and destroyed as soon as
    any key is pressed.
    """
    x_pos, y_pos = 500, 0
    cv2.imshow(window_name, img)
    cv2.moveWindow(window_name, x_pos, y_pos)
    # A delay of 0 makes waitKey wait indefinitely for a key press.
    cv2.waitKey(0)
    cv2.destroyWindow(window_name)

In [4]:
def _crop_cell(grid, row, col):
    """Return cell (row, col) of *grid* with `linewidth` pixels trimmed from every side."""
    top, left = row * true_square_size, col * true_square_size
    cell = grid[top:top + true_square_size, left:left + true_square_size]
    # Explicit end indices instead of `-linewidth`: with linewidth == 0 a
    # `[0:-0]` slice would be empty, whereas this keeps the whole cell.
    height, width = cell.shape[0], cell.shape[1]
    return cell[linewidth:height - linewidth, linewidth:width - linewidth]


def _center_digit(cell):
    """Return the centred digit crop of *cell*, or None if both centring helpers fail."""
    centered = center_image_single_contour(cell)
    if not isinstance(centered, np.ndarray):
        # Fallback when the single-contour helper does not return an array
        # (presumably digits whose stroke is split into several contours).
        centered = center_image_broken_contours(cell)
    return centered if isinstance(centered, np.ndarray) else None


def save_digits(grid, labels, data_dir, counts=None):
    """Crop every labelled (non-empty) cell of a 9x9 sudoku grid and save it as a JPEG.

    Each cell is `true_square_size` pixels wide in *grid*. It is trimmed by
    `linewidth` on every side and passed through `center_image_single_contour`.
    If that does not yield an array, `center_image_broken_contours` is tried
    instead. The result is written as RGB to ``<data_dir>/<digit>/<n>.jpg``,
    where ``n`` is a running per-digit counter. Cells that cannot be centred
    or converted to an image are skipped. The counter only advances for crops
    that are actually saved, so file numbers stay contiguous.

    Args:
        grid: image array of the located grid, indexed as ``grid[row_px, col_px]``.
        labels: array-like whose ``labels[i][0]`` is a whitespace-separated
            string of the nine digits in grid row ``i``; ``'0'`` marks an empty cell.
        data_dir: output directory containing one sub-folder per digit.
        counts: mapping digit-string -> number of files saved so far (a
            ``defaultdict(int)``). Defaults to the module-level ``digit_count``.

    Raises:
        ValueError: if a label row does not contain exactly nine entries.
    """
    if counts is None:
        counts = digit_count

    for i in range(9):
        # split() without an argument tolerates repeated/trailing whitespace.
        row_labels = labels[i][0].split()
        if len(row_labels) != 9:
            raise ValueError(
                f"expected 9 labels in grid row {i}, got {len(row_labels)}: {labels[i][0]!r}"
            )

        for j, digit in enumerate(row_labels):
            if digit == '0':
                continue  # empty cell, nothing to save

            centered = _center_digit(_crop_cell(grid, i, j))
            if centered is None:
                continue

            try:
                img = Image.fromarray(centered)
            except Exception:
                # Best effort: skip crops Pillow cannot turn into an image, without
                # swallowing KeyboardInterrupt/SystemExit like a bare `except:` would.
                continue

            if img.mode != 'RGB':
                img = img.convert('RGB')

            counts[digit] += 1
            img.save(os.path.join(data_dir, digit, f"{counts[digit]}.jpg"))


In [5]:
# Create one output folder per digit (1-9) for the train and the test split.
curdir = os.getcwd()
for subset in ('train', 'test'):
    for digit in range(1, 10):
        # exist_ok=True replaces the separate exists() check (no check/create race)
        # and makes the cell safe to re-run.
        os.makedirs(os.path.join(curdir, 'sudoku_dataset', subset, str(digit)), exist_ok=True)

In [6]:
def _extract_split(split_name, img_dir, data_dir):
    """Locate the sudoku grid in every ``.jpg`` under *img_dir* and save its labelled digits.

    Each image ``X.jpg`` is paired with a tab-separated label file ``X.dat`` in
    the same directory. Images without a label file, images OpenCV cannot read,
    and images where `find_grid_in_img` raises are skipped. Progress is
    printed in place on a single line.
    """
    # Context manager closes the scandir iterator; the listing is reused for the total.
    with os.scandir(img_dir) as it:
        entries = list(it)
    total = len(entries)

    for idx, entry in enumerate(entries, start=1):
        print(f"\r{split_name}: {idx} / {total}", end="")
        if not (entry.path.endswith(".jpg") and entry.is_file()):
            continue

        info_filename = entry.path[:-3] + 'dat'
        if not os.path.isfile(info_filename):
            continue  # unlabelled image: nothing to save

        img = cv2.imread(entry.path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue  # cv2.imread returns None instead of raising on unreadable files
        try:
            grid = find_grid_in_img(img)
        except Exception:
            # Best effort: skip images where the grid cannot be detected.
            continue

        # skiprows=1 drops the first line and the second becomes the pandas header;
        # presumably the remaining rows are the nine grid rows — TODO confirm format.
        labels = pd.read_table(info_filename, sep="\t", skiprows=1).to_numpy()
        save_digits(grid, labels, data_dir)


# save_digits numbers its output files from the module-level digit_count,
# so the counter is reset before each split.
digit_count = defaultdict(int)
img_dir = curdir + '/sudoku_full_dataset/v1_training/'
data_dir = curdir + '/sudoku_dataset/train/'
_extract_split("train", img_dir, data_dir)

print("")
digit_count = defaultdict(int)
img_dir = curdir + '/sudoku_full_dataset/v1_test/'
data_dir = curdir + '/sudoku_dataset/test/'
_extract_split("test", img_dir, data_dir)

train: 239 / 240
test: 79 / 80