# Script: prepare training data for square-occupancy classification
Script that creates training images for the CNN that classifies each square as empty or occupied.

For each square, the script creates:
- Image named as "<dataset image no.>_<square coord>.png"
- Text file "<dataset image no.>_<square coord>.txt" containing the true label

In [None]:
import os, glob
import cv2
from FEN import FEN
from chessboard_detection import *

Change `input_img_path` to point to your dataset folder (default: `./../input/`).

Change `dst_dir_path` to the folder where you want the output written.

Glob patterns (shell-style wildcards, not regular expressions):
- path/1** : files whose name starts with '1' (e.g. 1000 to 1999)
- path/** : all files

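`rewrite`: set to `True` to reprocess images that already have output in the output directory.
`rewrite_errors`: set to `True` to retry the images recorded in errors.txt by a previous run.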

In [None]:
# Glob pattern selecting the dataset files (images and their .json labels).
input_img_path = './../input/**'
# Folder receiving one .png/.txt pair per square.
dst_dir_path = './../output/training_squares/'

# Reprocessing switches.
rewrite = True          # ignore output that already exists
rewrite_errors = True   # forget images that failed in a previous run

# Share of the candidate images to process (1 = all of them).
dataset_percentage = 1
save_output = True

In [None]:
import ast

# Image numbers that failed in a previous run. errors.txt is only written at the
# end of this notebook, so it does not exist on the very first run; start with
# an empty list instead of crashing with FileNotFoundError.
if os.path.isfile("errors.txt"):
    with open("errors.txt", "r") as file:
        # literal_eval only evaluates Python literals, so the file cannot run code.
        error_list = ast.literal_eval(file.read())
else:
    error_list = []

In [None]:
# Discard the recorded failures so those images are attempted again.
if rewrite_errors:
    error_list = list()

In [None]:
os.makedirs(dst_dir_path, exist_ok=True)

# Image numbers that already have output files. Outputs are named
# "<image no.>_<square coord>.*", so the number is the part of the file stem
# before the first underscore. os.path.basename strips the directory on every
# OS (split('\\') only worked on Windows paths).
# Previously this list was unconditionally reset to [] afterwards, so
# rewrite=False never skipped anything.
if rewrite:
    already_processed = []
else:
    already_processed = sorted({
        os.path.splitext(os.path.basename(filename))[0].split('_')[0]
        for filename in glob.glob(f'{dst_dir_path}**')
    })

# Candidates: every non-JSON input whose image number has neither been
# processed yet nor failed in a previous run (set for O(1) membership tests).
all_input_files = glob.glob(input_img_path)
_skip_numbers = set(already_processed) | set(error_list)
to_be_processed = [
    filename for filename in all_input_files
    if '.json' not in filename
    and os.path.splitext(os.path.basename(filename))[0] not in _skip_numbers
]

In [None]:
# Quota of new images: the requested share of the candidates minus the images
# that are already done.
cut_index = int(len(to_be_processed) * dataset_percentage - len(already_processed))

# A negative quota means the requested share is already covered. Previously
# to_be_processed[:cut_index] then silently dropped files from the end of the
# list, and `exceeded` was left undefined in the normal case.
exceeded = cut_index < 0

# Sort BEFORE slicing so the selected subset is deterministic and does not
# depend on the arbitrary order returned by glob.
to_be_processed = sorted(to_be_processed)[:max(cut_index, 0)]

print(f"Already processed: {len(already_processed)}")
print(f"To be processed: {len(to_be_processed)}")

## Data Creation

In [None]:
# Board detector (MaskRCNN_board) that board_detection() uses below; its
# weights come from a checkpoint on disk, mapped onto the chosen device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
weights_path = './../maskRCNN_epoch_4_.pth'

model = MaskRCNN_board()
model.model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)
model.eval()  # switch to inference mode
print('')

In [None]:
num_of_to_be_processed = len(to_be_processed)
num_of_already_processed = 0
last_progress_percentage_shown = 0

for in_process in to_be_processed:
    # Image number = file name without directory or extension.
    # os.path.basename works on every OS, unlike split('\\').
    imgnumber = os.path.splitext(os.path.basename(in_process))[0]

    # Print a progress line each time progress has advanced by more than 10 points.
    progress_percentage = num_of_already_processed / num_of_to_be_processed * 100
    if progress_percentage - last_progress_percentage_shown > 10:
        last_progress_percentage_shown = progress_percentage
        print(f"########################### progress: {progress_percentage}% ###########################\n")

    # Skip non-files, non-.png files, images without a .json label, and images
    # that already have output.
    if not os.path.isfile(in_process):
        continue
    if not in_process.lower().endswith(".png"):
        continue
    print(f"{in_process}...", end=' ')
    if not os.path.isfile(os.path.splitext(in_process)[0] + '.json'):
        print(f"Not found related json({in_process})")
        continue
    if imgnumber in already_processed:
        num_of_already_processed += 1
        print(f"Already processed, skipped({in_process})")
        continue

    try:
        # Load the true labels inside the try so that a broken label file is
        # recorded as an error instead of aborting the whole run.
        truth = FEN(os.path.splitext(in_process)[0])
        true_fen, true_pos, viewpoint = truth.fen, truth.pieces, truth.view

        # First pass preprocessing: detect and warp the board.
        warpedBoardImg = board_detection(in_process, old_version=False, model=model)
        if warpedBoardImg is None:
            print("Skipped (Error in warping)")
            raise Exception("Warping Error")

        # Second pass preprocessing: split the warped board into squares.
        grid_squares = grid_detection(warpedBoardImg, viewpoint)
        if grid_squares is None:
            print("Skipped (Error in grid)")
            raise Exception("Grid Error")

        # Insert the true piece label (or 'empty') as the 3rd column.
        # NOTE(review): the resulting column layout depends on grid_detection's
        # output shape; confirm that grid_squares[:, 1:4] below really yields
        # (coord, piece, image).
        grid_squares = np.column_stack((grid_squares[:, :2],
                                        [true_pos.get(coord, 'empty') for coord in grid_squares[:, 1]],
                                        grid_squares[:, -2:]
                                        ))

        for square_coord, piece, square_img in grid_squares[:, 1:4]:
            output_filename = f'{dst_dir_path}{imgnumber}_{square_coord}'

            # .png: the square image
            cv2.imwrite(output_filename + '.png', square_img)

            # .txt: the true label for that square
            with open(f'{output_filename}.txt', 'w') as f:
                f.write(true_pos.get(square_coord, 'empty'))

        print('Done')

    except Exception as err:
        # Catch Exception rather than a bare except so Ctrl-C / SystemExit still
        # stop the loop. The "_error" marker lets the cleanup cell delete any
        # square files written before the failure.
        print(f"Error: {err}")
        with open(f'{dst_dir_path}{imgnumber}_error.txt', 'w') as f:
            f.write('Error somewhere')

    # Counted exactly once per attempted image (previously counted twice on
    # warping/grid errors).
    num_of_already_processed += 1
        

Delete the square files created by images that raised errors

In [None]:
# Image numbers that have an "<image no.>_error.txt" marker in the output dir.
# os.path.basename strips the directory on every OS (split('\\') only
# worked on Windows paths).
errors = sorted({
    os.path.splitext(os.path.basename(filename))[0].split('_')[0]
    for filename in glob.glob(f'{dst_dir_path}**error**')
})
print(f"Found {len(errors)} to be deleted")

for error in errors:
    # Match only "<error>_*": the old pattern "*<error>*" also deleted files of
    # other images whose number merely contains this one (e.g. "100" -> "1000").
    for filename in glob.glob(f"{dst_dir_path}{glob.escape(error)}_*"):
        os.remove(filename)

print(f"errors deleted: {errors}")

In [None]:
# Image numbers that produced at least one output file.
already_processed = list({
    os.path.splitext(os.path.basename(filename))[0].split('_')[0]
    for filename in glob.glob(f'{dst_dir_path}**')
})

# Every input .png image without any output counts as an error. Derive the
# expected numbers from the input files themselves: the former
# range(len(all_input_files)//2) assumed images numbered "0000".."N/2-1", which
# is wrong for subsets such as the '1**' pattern.
input_numbers = {
    os.path.splitext(os.path.basename(filename))[0]
    for filename in all_input_files
    if filename.lower().endswith('.png')
}
total_errors = sorted(input_numbers - set(already_processed))

with open("errors.txt", "w") as f:
    print(f"In total {len(total_errors)} errors, save in errors.txt")
    f.write(str(total_errors))

Update the list of image numbers that raised an error

In [None]:
# Merge previously recorded failures with the error markers found in this run
# and persist them. The merge is computed once and sorted, so the file content
# is deterministic (a bare set has arbitrary order).
total_errors = sorted(set(error_list + errors))
with open("errors.txt", "w") as f:
    print(f"In total {len(total_errors)} errors, save in errors.txt")
    f.write(str(total_errors))