# In [1]:
import cv2
import os
import numpy as np
import xml.etree.ElementTree as ET

class Preprocessing:
    def __init__(self):
        """Create the helper objects and initialise the per-page state.

        The helpers are created with ``__new__``, so their ``__init__`` is not
        run. NOTE(review): this presumably works because CreateXMLFile,
        ReadTxtFile and Controller only provide methods and need no instance
        state; confirm against their definitions.
        """
        self.create_xml_file_class = CreateXMLFile.__new__(CreateXMLFile)
        self.read_txt_file_class = ReadTxtFile.__new__(ReadTxtFile)
        self.controller_class = Controller.__new__(Controller)
        # Segment type chosen in the classification dialog ('' = none chosen).
        self.segment_type = ''
        # State that page_setup (re)assigns for every page. It is declared here
        # so the attributes exist from construction.
        self.new_page_flag = False
        self.image_to_crop = None
        
    def page_setup(self, dataset_directory, final_dataset_directory, transcription_directory, image_name, image,
                   line_threshold, number_threshold, name_threshold, margin_threshold):
        """Interactively segment one page image and write its XML description.

        The page is shown full screen. In a loop the user draws a region with
        ``cv2.selectROI`` and classifies it in the dialog opened by
        ``create_interface``. Each classified region is:

        * cropped and saved as ``<stem>_<type>_<index>.jpg`` under
          ``<final_dataset_directory>/<prefix>/paragraphs/``, where ``prefix``
          is the part of ``image_name`` before the first ``_``;
        * recorded as a ``RegionRefIndexed`` + ``TextRegion`` pair in the XML
          tree;
        * paired with the next piece of the transcription text.

        The loop ends when "New Page" is clicked or while Esc is held
        (``keyboard.is_pressed``). The XML is then written to
        ``<final_dataset_directory>/<prefix>/xml_pages/<stem>.xml``.

        Args:
            dataset_directory: Not used by this method.
            final_dataset_directory: Root of the output dataset.
            transcription_directory: Folder containing ``<stem>.txt``.
            image_name: Image file name. Its last four characters are treated
                as the extension, and ``<stem>`` is the rest.
            image: Page image (``image.shape`` is unpacked as height, width,
                channels). It is shown by ``selectROI``, and ``get_lines``
                draws line boxes onto it in place.
            line_threshold: Stored as ``self.line_threshold``. Intended as the
                minimum contour width for paragraph lines.
            number_threshold: Not used by this method.
            name_threshold: Not used by this method.
            margin_threshold: Stored as ``self.margin_threshold``. Intended as
                the minimum contour width for margin lines.

        NOTE(review): ``keyboard`` and the helper classes are not imported in
        this cell; confirm they are imported elsewhere in the notebook.
        """
        height, width, _ = image.shape
        # Unannotated copy: crops are taken from here, while line boxes are
        # drawn onto ``image`` itself.
        self.image_to_crop = image.copy()
        # Stored on the instance so later steps can read them without
        # threading them through every call.
        self.line_threshold = line_threshold
        self.margin_threshold = margin_threshold

        # Show the page at its native size, then switch the window to full screen.
        cv2.namedWindow(image_name, cv2.WINDOW_NORMAL)
        cv2.resizeWindow(image_name, width, height)
        cv2.imshow(image_name, image)
        cv2.setWindowProperty(image_name, cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)

        image_stem = image_name[:-4]
        output_prefix = image_name.split("_")[0]
        transcription_path = os.path.join(transcription_directory, image_stem + '.txt')

        # A <page> tag in the transcription marks a double-page scan.
        if self.read_txt_file_class.search_tag(transcription_path, '<page>'):
            page_type = 'Double page'
        else:
            page_type = 'Single page'

        # Transcription text not yet assigned; each classified segment consumes part of it.
        txt_file_content = self.read_txt_file_class.read_file(transcription_path)

        # Root <Page> element with the page-level attributes.
        xml_file = self.create_xml_file_class
        root = xml_file.create_element('Page')
        xml_file.create_attrib(root, 'Name', image_name)
        xml_file.create_attrib(root, 'Type', page_type)
        xml_file.create_attrib(root, 'ImageHeight', str(height))
        xml_file.create_attrib(root, 'ImageWidth', str(width))
        reading_order = xml_file.create_subelement(root, 'ReadingOrder')

        save_directory_path = os.path.join(final_dataset_directory, output_prefix + '/paragraphs/')
        segment_index_number = 0  # 1-based index of the last recorded region

        self.new_page_flag = False
        while not self.new_page_flag and not keyboard.is_pressed("esc"):
            # Let the user draw a region of interest on the page.
            segment_coordinates = cv2.selectROI(image_name, image, False, False)
            segment_x, segment_y, segment_w, segment_h = (int(v) for v in segment_coordinates)
            if segment_w == 0 or segment_h == 0:
                # Cancelled (selectROI returns an all-zero box) or empty selection.
                # An empty crop cannot be saved or segmented, so ask again.
                continue
            segment = self.image_to_crop[segment_y:segment_y + segment_h,
                                         segment_x:segment_x + segment_w]

            # Ask for the segment type. Reset it first so that closing the
            # dialog without choosing does not reuse the previous segment's type.
            self.segment_type = ''
            self.create_interface()
            if self.new_page_flag:
                break
            if not self.segment_type:
                continue  # dialog closed without a type: skip this selection

            segment_index_number += 1

            # Reading-order entry and the region itself.
            region_ref_indexed = xml_file.create_subelement(reading_order, 'RegionRefIndexed')
            xml_file.create_attrib(region_ref_indexed, 'Index', str(segment_index_number))
            xml_file.create_attrib(region_ref_indexed, 'Type', self.segment_type)

            text_region = xml_file.create_subelement(root, 'TextRegion')
            xml_file.create_attrib(text_region, 'Index', str(segment_index_number))
            xml_file.create_attrib(text_region, 'Type', self.segment_type)
            text_region_coords = xml_file.create_subelement(text_region, 'Coords')
            xml_file.create_attrib(text_region_coords, 'Points', str(segment_coordinates))
            text_region_text = xml_file.create_subelement(text_region, 'TextEquiv')
            text_region_text_unicode = xml_file.create_subelement(text_region_text, 'Unicode')

            # Save the cropped segment image.
            self.controller_class.create_directory(save_directory_path)
            self.controller_class.save_image(
                os.path.join(save_directory_path,
                             image_stem + f'_{self.segment_type}_' + str(segment_index_number) + '.jpg'),
                segment)

            # Attach the matching transcription text and consume it.
            if self.segment_type == 'Number':
                txt_file_content = self.get_number_info(txt_file_content, text_region_text_unicode)
            elif self.segment_type == 'Name':
                txt_file_content = self.get_name_info(txt_file_content, text_region_text_unicode)
            elif self.segment_type in ('Paragraph', 'Margin'):
                tag = '<text>' if self.segment_type == 'Paragraph' else '<margin>'
                txt_file_content = self.get_text_info(tag, txt_file_content, text_region_text_unicode, text_region,
                                                      segment, segment_coordinates, segment_index_number,
                                                      save_directory_path, image_name, image)

        # Wait for a key press, then close all OpenCV windows.
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        save_xml_path = os.path.join(final_dataset_directory, output_prefix + '/xml_pages/')
        self.controller_class.create_directory(save_xml_path)
        xml_file.create_final_xml_file(save_xml_path, image_stem + '.xml', root)

    def get_text_info(self, tag, txt_file_content, text_region_text_unicode, text_region, segment, segment_coordinates, segment_index_number, save_directory_path, image_name, image):
        txt_text = None  # Initialize txt_text to None
        # Add paragraph/margin text
        start_index = txt_file_content.find(tag)
        end_index = txt_file_content.find('<\\' + tag[1:], start_index)
        if start_index != -1 and end_index != -1:
            start_index += len(tag)
            txt_text = txt_file_content[start_index:end_index].strip()
            self.create_xml_file_class.add_text(text_region_text_unicode, txt_text)
            # Get lines coordinates
            if tag == "<margin>":
                lines_coordinates = self.get_lines(image, segment, segment_coordinates, segment_index_number, save_directory_path, image_name, margin_threshold, padding=5)
            elif tag == "<text>":
                lines_coordinates = self.get_lines(image, segment, segment_coordinates, segment_index_number, save_directory_path, image_name, line_threshold, padding=10)

            # Split the paragraph to lines
            lines = txt_text.splitlines()
            for id, (line, line_coords) in enumerate(zip(lines, lines_coordinates), start=1):
                text_line = self.create_xml_file_class.create_subelement(text_region, 'TextLine')
                self.create_xml_file_class.create_attrib(text_line, 'ID', str(id))
                text_line_coords = self.create_xml_file_class.create_subelement(text_line, 'Coords')
                self.create_xml_file_class.create_attrib(text_line_coords, 'Points', str(line_coords))
                text_line_text = self.create_xml_file_class.create_subelement(text_line, 'TextEquiv')
                text_line_text_unicode = self.create_xml_file_class.create_subelement(text_line_text, 'Unicode')
                self.create_xml_file_class.add_text(text_line_text_unicode, line)
            print ('text lines ID are:  ', id)

            # Remove text from the txt file after appending it to the xml file
            if txt_text is not None:
                txt_file_content = txt_file_content.replace(tag, "", 1)
                txt_file_content = txt_file_content.replace(txt_text, "", 1)
                txt_file_content = txt_file_content.replace('<\\' + tag[1:], "", 1)
        return txt_file_content
    
    def remove_and_get_first_line(self, txt_file_content, text_region_text_unicode):
        txt_file_lines = txt_file_content.splitlines()
        while txt_file_lines:
            first_line = txt_file_lines.pop(0).strip()
            if not first_line.isspace():  # Check if the line is not empty
                self.create_xml_file_class.add_text(text_region_text_unicode, first_line)
                break
        return '\n'.join(txt_file_lines)

    def get_name_info(self, txt_file_content, text_region_text_unicode):
        txt_file_content = self.remove_and_get_first_line(txt_file_content, text_region_text_unicode)
        return txt_file_content
    
    def get_number_info(self, txt_file_content, text_region_text_unicode):
        begin_tag = '<begin>'
        if begin_tag in txt_file_content:
            txt_file_content = txt_file_content.split(begin_tag, 1)[1].lstrip('\n')
        txt_file_content = self.remove_and_get_first_line(txt_file_content, text_region_text_unicode)
        return txt_file_content

    def get_lines(self, image, segment, segment_coordinates, segment_index_number, save_directory_path, image_name, line_threshold, padding):
        """Detect text lines in a segment, save each as an image, and return their boxes.

        Lines are found by blurring, adaptive thresholding and dilating the
        segment with a wide horizontal kernel, so that each text line merges
        into one blob. Blobs are then filtered and split:

        * a blob is dropped if its width is not greater than
          ``line_threshold`` or its height is not greater than 25 px;
        * ``split_overlapped_text_lines`` splits blobs taller than 160 px
          in two.

        Each line crop is taken from ``self.image_to_crop`` (not from the
        annotated ``image``), extended by ``padding`` pixels above, below and
        to the right. It is saved as
        ``<stem>_<type>_<segment index>_line_<n>.jpg`` in
        ``<save_directory_path>/lines/``. A coloured rectangle and the line
        number are drawn on ``image`` in place.

        Args:
            image: Page image to annotate (modified in place).
            segment: Cropped segment (converted with COLOR_BGR2GRAY).
            segment_coordinates: ``(x, y, w, h)`` of the segment on the page.
            segment_index_number: Region index used in the file names.
            save_directory_path: Folder that receives the ``lines/`` sub-folder.
            image_name: Page file name; the last four characters are dropped.
            line_threshold: Minimum (exclusive) blob width in pixels.
            padding: Extra pixels added around each saved line crop.

        Returns:
            A list of ``(x, y, w, h)`` line boxes in page coordinates, unpadded,
            in the order returned by ``split_overlapped_text_lines``.
        """
        save_line_directory_path = os.path.join(save_directory_path, 'lines/')
        self.controller_class.create_directory(save_line_directory_path)

        # Merge the characters of each line into a single blob.
        gray = cv2.cvtColor(segment, cv2.COLOR_BGR2GRAY)
        blur_img = cv2.GaussianBlur(gray, (101, 51), 61)
        thresh = cv2.adaptiveThreshold(blur_img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 71, 2)
        kernel = np.ones((1, 205), np.uint8)
        img_dilation = cv2.dilate(thresh, kernel, iterations=1)

        ctrs, _ = cv2.findContours(img_dilation.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        def _is_line(ctr):
            # Keep blobs wide enough to be a line and taller than 25 px (noise filter).
            _, _, blob_w, blob_h = cv2.boundingRect(ctr)
            return blob_w > line_threshold and blob_h > 25

        ctrs = [ctr for ctr in ctrs if _is_line(ctr)]
        # Overlapped text lines: split over-tall blobs into two lines.
        ctrs = self.split_overlapped_text_lines(ctrs)
        print ('text lines IMAGES are:  ', len(ctrs))

        x, y, _, _ = segment_coordinates
        # One random colour per line. Sizing the palette to the number of
        # lines means indexing can never run past it.
        colors = np.random.randint(0, 150, size=(len(ctrs), 3))
        lines_coordinates = []
        for line_index_number, (ctr, color) in enumerate(zip(ctrs, colors), start=1):
            # Bounding box, shifted from segment to page coordinates.
            x_, y_, w_, h_ = cv2.boundingRect(ctr)
            x_ += x
            y_ += y
            lines_coordinates.append((x_, y_, w_, h_))

            # Clamp the padded top edge at 0. A negative start index would make
            # NumPy slice from the end of the array and give an empty/wrong crop.
            top = max(y_ - padding, 0)
            roi = self.image_to_crop[top:y_ + h_ + padding, x_:x_ + w_ + padding]

            # Optional skew correction (disabled):
            # roi = self.text_skew_correction(roi)

            self.controller_class.save_image(
                os.path.join(save_line_directory_path,
                             f'{image_name[:-4]}_{self.segment_type}_{segment_index_number}_line_{line_index_number}.jpg'),
                roi)

            # Visual feedback on the page: box plus line number.
            # The font colour (255, 0, 0) is blue in OpenCV's BGR order.
            cv2.rectangle(image, (x_, y_ - padding), (x_ + w_ + padding, y_ + h_ + padding), color.tolist(), 5)
            text_position = (x_ + w_ // 2, y_ + h_ // 2)
            cv2.putText(image, str(line_index_number), text_position, cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 0), 5)

        return lines_coordinates
   
    
    
        
        
        
    def text_skew_correction(self, text_line):
        """Rotate a text-line image to reduce its skew.

        The line is binarised and dilated horizontally with the same
        parameters as ``get_lines``. The contour with the largest area is then
        fitted with a minimum-area rotated rectangle. Finally the image is
        rotated about its centre by that rectangle's angle, clamped to
        +/-1.5 degrees.

        NOTE(review): the angle range reported by ``cv2.minAreaRect`` differs
        between OpenCV versions, so a level line may be reported near 0 or
        near +/-90 degrees. Near-90 values saturate the clamp at +/-1.5.
        Confirm against the installed OpenCV before re-enabling this step.

        Args:
            text_line: Line image (converted with COLOR_BGR2GRAY).

        Returns:
            The rotated image, with the same size as the input. If no contour
            is found, the input is returned unchanged.
        """
        gray = cv2.cvtColor(text_line, cv2.COLOR_BGR2GRAY)
        # Blur, binarise and dilate horizontally to connect the text components.
        blur_img = cv2.GaussianBlur(gray, (101, 51), 61)
        thresh = cv2.adaptiveThreshold(blur_img, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 71, 2)
        kernel = np.ones((1, 205), np.uint8)
        img_dilation = cv2.dilate(thresh, kernel, iterations=1)

        ctrs, _ = cv2.findContours(img_dilation.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not ctrs:
            # Nothing to measure (e.g. a blank crop); max() would raise ValueError.
            return text_line

        # The skew angle comes from the rotated bounding box of the largest contour.
        primary_contour = max(ctrs, key=cv2.contourArea)
        angle = cv2.minAreaRect(primary_contour)[-1]
        print ('angle is: ', angle)

        # Limit the angle to avoid rotating the line towards vertical.
        max_allowed_angle = 1.5
        if abs(angle) > max_allowed_angle:
            angle = np.sign(angle) * max_allowed_angle

        rows, cols = text_line.shape[:2]
        M = cv2.getRotationMatrix2D((cols // 2, rows // 2), angle, 1)
        return cv2.warpAffine(text_line, M, (cols, rows), flags=cv2.INTER_CUBIC, borderMode=cv2.BORDER_WRAP)
    
    # Overlapped text lines: Filter contours with height above 160 and split it into two lines
    def split_overlapped_text_lines(self, ctrs, max_line_height=160):
        """Split contours that are too tall to be a single text line.

        A contour whose bounding box is taller than ``max_line_height`` pixels
        is treated as two merged lines. It is replaced by two copies whose
        y coordinates are clipped at the box's vertical midpoint: one covers
        the lower half and one the upper half. Other contours are kept as-is.

        Each split contour is appended bottom-half then top-half, and the whole
        list is reversed before returning. NOTE(review): the resulting
        top-to-bottom order relies on the order in which ``cv2.findContours``
        produced the input; confirm.

        Args:
            ctrs: Contours from ``cv2.findContours``. The last axis holds the
                point coordinates, and index 1 (y) is clipped.
            max_line_height: Height above which a contour is split (default 160 px).

        Returns:
            A new list of contours.
        """
        ctrs_new = []
        for contour in ctrs:
            _, y, _, h = cv2.boundingRect(contour)
            if h <= max_line_height:
                ctrs_new.append(contour)
                continue
            middle_y = y + h // 2
            # Clip y at the midpoint to obtain the two halves.
            contour_bottom = contour.copy()
            contour_bottom[:, :, 1] = contour[:, :, 1].clip(middle_y, None)
            contour_top = contour.copy()
            contour_top[:, :, 1] = contour[:, :, 1].clip(None, middle_y)
            ctrs_new.append(contour_bottom)
            ctrs_new.append(contour_top)
        ctrs_new.reverse()
        return ctrs_new







    # Interface
    def new_entry(self, root):
        pass
        #self.segment_type = 'New Enter' 
        #print("New Entry button clicked")
        #root.destroy()
    
    # Interface
    def new_page(self, root):
        self.new_page_flag = True
        print("New page button clicked")
        root.destroy()
    
    # Interface
    def name(self, root):
        self.segment_type = 'Name' 
        print("Name button clicked")
        root.destroy()
        
    # Interface
    def number(self, root):
        self.segment_type = 'Number' 
        print("Number button clicked")
        root.destroy()
        
    # Interface
    def paragraph(self, root):
        self.segment_type = 'Paragraph' 
        print("Paragraph button clicked")
        root.destroy()
        
    # Interface
    def margin(self, root):
        self.segment_type = 'Margin'
        print("Margin button clicked")
        root.destroy()
    
    # Interface
    def create_interface(self):
        """Show the segment-classification dialog and block until it closes.

        Builds a 280x470 Tk window centred on the screen, with one button per
        choice stacked in a single column. Each button except "New Entry"
        records the choice on ``self`` and destroys the window, which lets
        ``mainloop`` return to the caller.

        NOTE(review): ``tk`` and ``font`` (tkinter.font) are not imported in
        this cell; confirm they are imported elsewhere in the notebook.
        """
        root = tk.Tk()
        root.title("Segment Classification")

        # Centre the window on the screen.
        window_width, window_height = 280, 470
        left = (root.winfo_screenwidth() - window_width) // 2
        top = (root.winfo_screenheight() - window_height) // 2
        root.geometry(f"{window_width}x{window_height}+{left}+{top}")

        custom_font = font.Font(size=16)

        # (label, handler, Button options overriding the defaults), top to bottom.
        button_specs = (
            ("New Entry", self.new_entry, {}),
            ("Number", self.number, {'bg': 'blue'}),
            ("Name", self.name, {'bg': 'green'}),
            ("Paragraph", self.paragraph, {'bg': 'yellow'}),
            ("Margin", self.margin, {'bg': 'orange'}),
            ("New Page", self.new_page, {'width': 20, 'bg': 'red'}),
        )
        for row, (label, handler, extra_options) in enumerate(button_specs):
            options = {'font': custom_font, 'width': 15, 'height': 2, **extra_options}
            # Default argument binds this iteration's handler (avoids late binding).
            button = tk.Button(root, text=label, command=lambda handler=handler: handler(root), **options)
            button.grid(row=row, column=0, padx=10, pady=5)

        root.mainloop()