# In [None]:
import cv2
import numpy as np
from PIL import Image
from datetime import datetime
import matplotlib.pyplot as plt
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.utils import ImageReader
from reportlab.lib import colors
from reportlab.platypus import Table, TableStyle
import os

# HSV bounds are (H, S, V) triples for cv2.inRange on images converted with
# cv2.COLOR_RGB2HSV; for 8-bit images OpenCV stores hue in 0-179.

# Candidate tooth pixels: hues 10-170, i.e. everything except the red band at
# both ends of the hue circle (presumably lips/gums).
lower_teeth = np.array((10, 50, 50))
upper_teeth = np.array((170, 255, 255))

# Dark-yellow band that is subtracted from the tooth mask.
# NOTE(review): V tops out at 200 here, while brightness_threshold below keeps
# only V > 235, so this band can never overlap the final mask.
lower_dark_yellow = np.array((20, 100, 100))
upper_dark_yellow = np.array((30, 255, 200))

# Skin-tone band that is subtracted from the tooth mask.
lower_skin = np.array((0, 20, 70))
upper_skin = np.array((20, 255, 255))

# A pixel's V channel must exceed this to count toward the average colour.
brightness_threshold = 235

# Function to calculate the teeth color score based on RGB brightness and yellowness/whiteness

def calculate_teeth_color_score(hex_colors):
    """Score hex colours by how bright and white (vs. yellow) they look.

    Each string is decoded from its first six hex digits (a leading ``#`` is
    optional) into (R, G, B). Pure black is skipped, because the rest of this
    script emits ``#000000`` for images with no usable pixels. Every remaining
    colour is scored as::

        0.5 * luminance% - 0.3 * yellowness% + 0.2 * whiteness%

    Here luminance is ``0.299 R + 0.587 G + 0.114 B``, yellowness is
    ``max(0, R - B)`` and whiteness is the channel mean. Each term is
    expressed as a percentage of 255.

    Args:
        hex_colors: iterable of hex colour strings such as ``'#aabbcc'``.

    Returns:
        The mean score over the non-black colours, or ``None`` when there are
        none.

    Raises:
        ValueError: if a string does not contain six hex digits.
    """
    per_color = []
    for code in hex_colors:
        digits = code.lstrip('#')
        red, green, blue = (int(digits[k:k + 2], 16) for k in range(0, 6, 2))
        if (red, green, blue) == (0, 0, 0):
            # Placeholder for an image where nothing survived masking.
            continue

        luminance = 0.299 * red + 0.587 * green + 0.114 * blue
        yellow = max(0, red - blue)
        mean_channel = (red + green + blue) / 3

        per_color.append(
            (0.5 * (luminance / 255) * 100)
            - (0.3 * (yellow / 255) * 100)
            + (0.2 * (mean_channel / 255) * 100)
        )

    if not per_color:
        return None
    return sum(per_color) / len(per_color)

# Function to calculate average color and related metrics
def calculate_average_color(image_paths, lower_color=None, upper_color=None, lower_exclude=None, upper_exclude=None, lower_skin=None, upper_skin=None, reflection_threshold=240, brightness_threshold=200):
    """Average the colour of the pixels that survive HSV masking, per image.

    When both ``lower_color`` and ``upper_color`` are given, a pixel is kept
    only if all of the following hold:

    - its HSV value lies inside ``[lower_color, upper_color]``;
    - it lies outside ``[lower_exclude, upper_exclude]``, when both are given;
    - it lies outside ``[lower_skin, upper_skin]``, when both are given;
    - ``brightness_threshold < V < reflection_threshold``.

    Otherwise every pixel of the image is used and none of the filters apply.

    Images whose mode is not RGB or RGBA (grayscale, palette, ...) are
    converted to RGB first so the HSV conversion accepts them.

    Args:
        image_paths: paths of the images to analyse.
        lower_color, upper_color: inclusive HSV bounds of the pixels to keep.
        lower_exclude, upper_exclude: inclusive HSV bounds to remove.
        lower_skin, upper_skin: inclusive HSV bounds to remove.
        reflection_threshold: pixels with V >= this are dropped.
        brightness_threshold: pixels with V <= this are dropped.

    Returns:
        Five lists, parallel to ``image_paths``:

        - average colours as float arrays with one value per image channel
          (4 for RGBA inputs), or ``[0, 0, 0]`` when no pixel survived;
        - the same averages as truncated ``'#rrggbb'`` strings
          (``'#000000'`` when no pixel survived);
        - the percentage of pixels omitted;
        - the number of pixels used;
        - the number of pixels omitted.
    """
    average_colors, average_colors_hex, omitted_percentages, pixels_used, pixels_omitted = [], [], [], [], []

    for image_path in image_paths:
        # Open one image at a time and release its file handle immediately.
        with Image.open(image_path) as image:
            # cv2.COLOR_RGB2HSV only accepts 3- or 4-channel input. RGB/RGBA
            # are kept as-is so RGBA inputs still yield 4-channel averages.
            if image.mode not in ("RGB", "RGBA"):
                image = image.convert("RGB")
            image_cv = np.array(image)

        total_pixels = image_cv.shape[0] * image_cv.shape[1]

        if lower_color is not None and upper_color is not None:
            hsv_image = cv2.cvtColor(image_cv, cv2.COLOR_RGB2HSV)
            mask = cv2.inRange(hsv_image, lower_color, upper_color) > 0

            if lower_exclude is not None and upper_exclude is not None:
                mask &= cv2.inRange(hsv_image, lower_exclude, upper_exclude) == 0

            if lower_skin is not None and upper_skin is not None:
                mask &= cv2.inRange(hsv_image, lower_skin, upper_skin) == 0

            value = hsv_image[:, :, 2]
            mask &= value < reflection_threshold  # drop specular highlights
            mask &= value > brightness_threshold  # drop dim pixels

            masked_pixels = image_cv[mask]
        else:
            # Flatten to (num_pixels, channels) so the mean below is per channel
            # over all pixels (the raw 3-D array would be averaged per column).
            masked_pixels = image_cv.reshape(-1, image_cv.shape[-1])

        used = len(masked_pixels)
        if used > 0:
            average_color = np.mean(masked_pixels, axis=0)
            average_color_hex = '#%02x%02x%02x' % tuple(average_color.astype(int)[:3])
        else:
            # Sentinel for "no usable pixels"; the report treats #000000 as invalid.
            average_color = np.array([0, 0, 0])
            average_color_hex = "#000000"
        average_colors.append(average_color)
        average_colors_hex.append(average_color_hex)

        omitted = total_pixels - used
        omitted_percentages.append((omitted / total_pixels) * 100 if total_pixels else 0.0)
        pixels_used.append(used)
        pixels_omitted.append(omitted)

    return average_colors, average_colors_hex, omitted_percentages, pixels_used, pixels_omitted

# Function to create PDF report
def create_pdf_report(pdf_file_name, image_paths, average_colors_hex, omitted_percentages, pixels_used, pixels_omitted, average_color_all_hex, average_color_all_int):
    """Write the teeth-colour analysis to a multi-page PDF.

    Pages, in order:

    1. title;
    2. contents;
    3. per-image data table (continued onto further pages when it is too
       tall, with the header row repeated);
    4. summary;
    5. image attachments (continued as needed).

    Args:
        pdf_file_name: output path of the PDF.
        image_paths: analysed image files, in the same order as the per-image
            lists below.
        average_colors_hex: per-image average colour as ``'#rrggbb'``;
            ``'#000000'`` marks an image with no usable pixels.
        omitted_percentages: per-image percentage of pixels excluded by masking.
        pixels_used: per-image count of pixels that were averaged.
        pixels_omitted: per-image count of pixels that were excluded.
        average_color_all_hex: overall average colour as ``'#rrggbb'``.
        average_color_all_int: overall average colour as an (R, G, B) int
            sequence in 0-255.
    """
    c = canvas.Canvas(pdf_file_name, pagesize=letter)
    width, height = letter

    _draw_title_page(c, height)
    _draw_contents_page(c, height)

    table_data = [["Image", "Average Color", "Omitted Percentage", "Pixels Used", "Pixels Omitted"]]
    for hex_color, image_path, omitted_percentage, used, omitted in zip(average_colors_hex, image_paths, omitted_percentages, pixels_used, pixels_omitted):
        table_data.append([os.path.basename(image_path), hex_color, f"{omitted_percentage:.2f}%", used, omitted])
    _draw_data_table(c, width, height, table_data)

    _draw_summary_page(c, height, image_paths, average_colors_hex, omitted_percentages, pixels_used, pixels_omitted, average_color_all_hex, average_color_all_int)
    _draw_attachments(c, width, height, image_paths)

    c.save()
    print(f"Results have been written to {pdf_file_name}")


def _draw_title_page(c, height):
    """Draw the report title and generation timestamp, then end the page."""
    c.setFont("Times-Roman", 30)
    c.drawString(30, height - 50, "Teeth Color Analysis Report")
    c.setFont("Times-Roman", 16)
    c.drawString(30, height - 100, f"Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    c.showPage()


def _draw_contents_page(c, height):
    """Draw the static table of contents, then end the page."""
    c.setFont("Times-Roman", 18)
    c.drawString(30, height - 50, "Contents")
    c.setFont("Times-Roman", 11)
    c.drawString(30, height - 100, "1. Data Results Page")
    c.drawString(30, height - 120, "2. Summary Page")
    c.drawString(30, height - 140, "3. Attachments")
    c.showPage()


def _draw_data_table(c, width, height, table_data):
    """Draw the results table, splitting it across pages as needed, then end the page."""
    # repeatRows=1 makes every split-off continuation repeat the header row.
    table = Table(table_data, repeatRows=1)
    table.setStyle(TableStyle([
        ('BACKGROUND', (0, 0), (-1, 0), colors.grey),
        ('TEXTCOLOR', (0, 0), (-1, 0), colors.whitesmoke),
        ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
        ('FONTNAME', (0, 0), (-1, 0), 'Times-Roman'),
        ('FONTSIZE', (0, 0), (-1, 0), 12),
        ('BOTTOMPADDING', (0, 0), (-1, 0), 12),
        ('BACKGROUND', (0, 1), (-1, -1), colors.beige),
        ('GRID', (0, 0), (-1, -1), 1, colors.black),
    ]))

    margin = 30
    available_width = width - 2 * margin
    table_top = height - 60           # leave room for the page heading
    available_height = table_top - 40  # keep a bottom margin

    c.setFont("Times-Roman", 18)
    c.drawString(margin, height - 30, "Data Results")

    pending = [table]
    while pending:
        part = pending.pop(0)
        _, part_height = part.wrapOn(c, available_width, available_height)
        if part_height > available_height:
            # split() returns [fitting_part, remainder] or [] if it cannot split.
            pieces = part.split(available_width, available_height)
            if len(pieces) > 1:
                pending = list(pieces) + pending
                continue
            # Unsplittable (e.g. a single over-tall row): draw it as-is.
        part.drawOn(c, margin, table_top - part_height)
        if pending:
            c.showPage()
            c.setFont("Times-Roman", 18)
            c.drawString(margin, height - 30, "Data Results (cont.)")

    c.showPage()


def _draw_summary_page(c, height, image_paths, average_colors_hex, omitted_percentages, pixels_used, pixels_omitted, average_color_all_hex, average_color_all_int):
    """Draw the overall colour swatch, teeth colour score and pixel statistics, then end the page."""
    c.setFont("Times-Roman", 18)
    c.drawString(30, height - 30, "Summary")

    # Overall average colour as text plus a filled swatch.
    y_position = height - 60
    c.setFont("Times-Roman", 11)
    c.drawString(30, y_position, f"Average of All Average Colors: {average_color_all_hex}")
    c.drawString(30, y_position - 20, f"RGB: {average_color_all_int[0]}, {average_color_all_int[1]}, {average_color_all_int[2]}")
    y_position -= 40
    c.setFillColorRGB(average_color_all_int[0] / 255, average_color_all_int[1] / 255, average_color_all_int[2] / 255)
    c.rect(30, y_position - 20, 50, 20, fill=1)
    c.setFillColorRGB(0, 0, 0)

    # '#000000' is the placeholder for an image with no usable pixels.
    black_color_count = average_colors_hex.count("#000000")
    if black_color_count > 0:
        y_position -= 40
        c.setFillColor(colors.red)
        c.drawString(30, y_position, f"Warning: {black_color_count} images were invalid, results may be inaccurate or incorrect.")
        c.setFillColor(colors.black)

    # calculate_teeth_color_score returns None when every colour is black.
    teeth_color_score = calculate_teeth_color_score(average_colors_hex)
    score_text = "N/A (no valid images)" if teeth_color_score is None else f"{teeth_color_score:.2f}"

    y_position -= 40
    c.setFont("Times-Bold", 16)
    c.drawString(30, y_position, f"Teeth Color Score: {score_text}")
    c.setFont("Times-Roman", 11)
    c.drawString(30, y_position - 20, "The teeth color score is calculated based on the brightness, yellowness, and whiteness of the average colors. ")
    c.drawString(30, y_position - 40, "The score is typically between 50 and 55, and higher is better.")
    y_position -= 80

    # Average and total statistics.
    c.drawString(30, y_position, f"Number of images processed: {len(image_paths)}")
    c.drawString(30, y_position - 20, f"Average omitted percentage: {np.mean(omitted_percentages):.2f}%")
    c.drawString(30, y_position - 40, f"Average pixels used: {np.mean(pixels_used):.2f}")
    c.drawString(30, y_position - 60, f"Average pixels omitted: {np.mean(pixels_omitted):.2f}")
    c.drawString(30, y_position - 80, f"Total pixels used: {sum(pixels_used)}")
    c.drawString(30, y_position - 100, f"Total pixels omitted: {sum(pixels_omitted)}")

    c.showPage()


def _draw_attachments(c, width, height, image_paths):
    """Draw each image with its file name beneath it, starting new pages as needed (page left open)."""
    c.setFont("Times-Roman", 18)
    c.drawString(30, height - 30, "Attachments")
    c.setFont("Times-Roman", 11)

    image_width = (width - 60) / 3
    image_height = image_width * 0.75
    bottom_margin = 40
    y_position = height - 60

    for image_path in image_paths:
        # Break *before* drawing when the image plus its caption would not fit;
        # this also avoids leaving an empty trailing page.
        if y_position - image_height - 15 < bottom_margin:
            c.showPage()
            c.setFont("Times-Roman", 18)
            c.drawString(30, height - 30, "Attachments (cont.)")
            c.setFont("Times-Roman", 11)
            y_position = height - 60  # start below the heading, not on it

        c.drawImage(ImageReader(image_path), 30, y_position - image_height, width=image_width, height=image_height, preserveAspectRatio=True, mask='auto')
        c.drawString(30, y_position - image_height - 15, os.path.basename(image_path))
        y_position -= image_height + 35  # image + caption + gap

# Function to plot images and masks
def plot_images(image_paths, average_colors, omitted_percentages, pixels_used, pixels_omitted, average_colors_hex=None):
    """Show each image next to its tooth mask and average-colour swatch.

    Each image gets one row of three panels. The mask is rebuilt from the
    module-level HSV bounds: the teeth range minus the dark-yellow and skin
    ranges, restricted to ``brightness_threshold < V < 240``. These are the
    settings ``main`` passes to ``calculate_average_color``, whose default
    reflection threshold is 240.

    Args:
        image_paths: image files to display.
        average_colors: per-image average colour; only the first three (RGB)
            values are drawn.
        omitted_percentages: per-image statistic shown in the caption.
        pixels_used: per-image statistic shown in the caption.
        pixels_omitted: per-image statistic shown in the caption.
        average_colors_hex: optional per-image ``'#rrggbb'`` labels. When
            omitted they are derived from ``average_colors`` by truncating
            each channel to int. Previously this read an undefined global
            and raised ``NameError``.
    """
    if average_colors_hex is None:
        average_colors_hex = ['#%02x%02x%02x' % tuple(np.asarray(color).astype(int)[:3]) for color in average_colors]

    n_rows = len(image_paths)
    if n_rows == 0:
        return

    # squeeze=False keeps `ax` 2-D even for a single image, so ax[i, j] works.
    _fig, ax = plt.subplots(n_rows, 3, figsize=(15, max(15, 5 * n_rows)), squeeze=False)
    rows = zip(average_colors, average_colors_hex, image_paths, omitted_percentages, pixels_used, pixels_omitted)
    for i, (color, hex_color, image_path, omitted_percentage, used, omitted) in enumerate(rows):
        name = os.path.basename(image_path)
        with Image.open(image_path) as image:
            # cv2.COLOR_RGB2HSV only accepts 3- or 4-channel input.
            if image.mode not in ("RGB", "RGBA"):
                image = image.convert("RGB")
            image_cv = np.array(image)

        ax[i, 0].imshow(image_cv)
        ax[i, 0].set_title(f"Original Image\n{name}")
        ax[i, 0].axis('off')

        hsv_image = cv2.cvtColor(image_cv, cv2.COLOR_RGB2HSV)
        value = hsv_image[:, :, 2]
        mask = cv2.inRange(hsv_image, lower_teeth, upper_teeth) > 0
        mask &= cv2.inRange(hsv_image, lower_dark_yellow, upper_dark_yellow) == 0
        mask &= cv2.inRange(hsv_image, lower_skin, upper_skin) == 0
        mask &= value < 240                   # drop specular highlights
        mask &= value > brightness_threshold  # drop dim pixels

        # Fixed vmin/vmax so an all-in or all-out mask is not auto-scaled to black.
        ax[i, 1].imshow(mask.astype(np.uint8), cmap='gray', vmin=0, vmax=1)
        ax[i, 1].set_title(f"Color Range Mask\n{name}\n(White is in range)")
        ax[i, 1].axis('off')

        ax[i, 2].imshow(np.array([[np.asarray(color)[:3]]], dtype=np.uint8))
        ax[i, 2].set_title(f"{hex_color}\n{name}\nAvg color\nOmitted: {omitted_percentage:.2f}%\nUsed: {used}\nOmitted: {omitted}")
        ax[i, 2].axis('off')

    plt.tight_layout()
    plt.show()

# Main function to execute the process
def main():
    """Analyse every PNG/JPEG image in ./images/, write a PDF report and plot the masks.

    At least two images are required. The overall average colour is the
    mean RGB over the images that produced a valid (non-``#000000``) average.
    """
    path = "./images/"
    image_extensions = ('.png', '.jpg', '.jpeg')

    try:
        entries = os.listdir(path)
    except (FileNotFoundError, NotADirectoryError):
        print(f"Image directory {path!r} not found.")
        return

    # Case-insensitive extension match, regular files only. Sorted so the
    # report order is deterministic (os.listdir order is arbitrary).
    image_paths = sorted(
        os.path.join(path, name)
        for name in entries
        if name.lower().endswith(image_extensions) and os.path.isfile(os.path.join(path, name))
    )

    # Check if at least two images are provided, else exit with error message
    if len(image_paths) < 2:
        print("Please provide at least two images to process for accuracy.")
        return

    # Calculate average colors
    average_colors, average_colors_hex, omitted_percentages, pixels_used, pixels_omitted = calculate_average_color(
        image_paths, lower_teeth, upper_teeth, lower_dark_yellow, upper_dark_yellow, lower_skin, upper_skin, brightness_threshold=brightness_threshold
    )

    print("Average Colors:", average_colors)
    # Skip the "#000000" placeholder used for images with no usable pixels, and
    # average only the RGB channels. RGBA (PNG) and RGB (JPEG) results can then
    # be combined, instead of one kind silently discarding the other.
    valid_rgb_colors = [
        np.asarray(color)[:3]
        for color, hex_color in zip(average_colors, average_colors_hex)
        if hex_color != "#000000"
    ]

    if valid_rgb_colors:
        average_color_all = np.mean(valid_rgb_colors, axis=0)
        average_color_all_int = average_color_all.astype(int)
        average_color_all_hex = '#%02x%02x%02x' % tuple(average_color_all_int[:3])
    else:
        print("No valid average colors found to calculate overall average.")
        average_color_all_int = [0, 0, 0]
        average_color_all_hex = "#000000"

    # Generate the PDF file name with the current date and time
    pdf_file_name = f"TCAR_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pdf"

    # Create the PDF report
    create_pdf_report(pdf_file_name, image_paths, average_colors_hex, omitted_percentages, pixels_used, pixels_omitted, average_color_all_hex, average_color_all_int)

    # Plot the images and masks
    plot_images(image_paths, average_colors, omitted_percentages, pixels_used, pixels_omitted)

# Entry point: run the full pipeline (analysis, PDF report, plots) when this
# file is executed directly; importing it as a module does not start it.
if __name__ == "__main__":
    main()
