In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF

class DoubleConv(nn.Module):
    """Two stacked stages of 3x3 conv (stride 1, padding 1, no bias) -> BatchNorm -> ReLU.

    Spatial size is preserved; channels go ``in_channels -> out_channels`` in the
    first stage and stay at ``out_channels`` in the second. All layers live in a
    single ``nn.Sequential`` stored as ``self.conv``, so state_dict keys are
    ``conv.0.*``, ``conv.1.*``, ``conv.3.*`` and ``conv.4.*`` (the ReLUs at
    indices 2 and 5 have no parameters).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # Stage 1 reads in_channels; stage 2 reads the out_channels produced by stage 1.
        for source_channels in (in_channels, out_channels):
            layers += [
                # Bias is omitted because the following BatchNorm adds its own shift.
                nn.Conv2d(source_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*layers)

    def forward(self, feature_map):
        """Apply both conv-BN-ReLU stages to ``feature_map``."""
        return self.conv(feature_map)

class UNET(nn.Module):
    """U-Net producing one raw score map per output channel.

    Encoder: one ``DoubleConv`` per entry of ``features``, each followed by 2x2
    max-pooling. Bottleneck: a ``DoubleConv`` that doubles ``features[-1]``.
    Decoder: for each width in reverse order, a stride-2 transposed conv
    upsamples, the matching encoder output is concatenated along dim 1, and a
    ``DoubleConv`` fuses the result. A final 1x1 conv maps ``features[0]``
    channels to ``out_channels``; no activation is applied to the output.

    Attribute names (``downs``, ``ups``, ``bottleneck``, ``final_conv``) define
    the state_dict keys that saved checkpoints rely on.
    """

    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512],):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: each stage consumes the channel count produced by the previous one.
        # (features is only read here, never mutated, so the list default is not shared state.)
        channels_in = in_channels
        for width in features:
            self.downs.append(DoubleConv(channels_in, width))
            channels_in = width

        # Decoder: ups alternates [upsample, fuse] pairs, deepest stage first.
        for width in reversed(features):
            upsample = nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2)
            self.ups.append(upsample)
            self.ups.append(DoubleConv(width * 2, width))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, feature_map):
        """Run the encoder, bottleneck and decoder; return raw scores from ``final_conv``."""
        skips = []
        for encoder in self.downs:
            feature_map = encoder(feature_map)
            skips.append(feature_map)
            feature_map = self.pool(feature_map)

        feature_map = self.bottleneck(feature_map)

        # Pair every (upsample, fuse) module pair with the encoder output of the same depth.
        stages = zip(self.ups[::2], self.ups[1::2], reversed(skips))
        for upsample, fuse, skip in stages:
            feature_map = upsample(feature_map)
            # Pooling floors odd sizes, so the upsampled map can be one pixel short;
            # resize it to the skip's spatial size before concatenating.
            if feature_map.shape != skip.shape:
                feature_map = TF.resize(feature_map, size=skip.shape[2:])
            feature_map = fuse(torch.cat((skip, feature_map), dim=1))

        return self.final_conv(feature_map)

In [2]:
import functools

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

# Resolution the segmentation networks run at.
IMAGE_HEIGHT = 160  # 1371 originally
IMAGE_WIDTH = 240  # 1376 originally


@functools.lru_cache(maxsize=None)
def _load_unet(checkpoint_path):
    """Build a UNET and load the weights saved at ``checkpoint_path`` (CPU, eval mode).

    The checkpoint must be a dict holding the model weights under ``'state_dict'``.
    The loaded model is cached per path, so each checkpoint is read from disk
    once per session instead of on every call.
    """
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model = UNET(in_channels=3, out_channels=1)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    return model


def _predict_mask(image, checkpoint_path, threshold=0.5):
    """Segment ``image`` with the UNET stored at ``checkpoint_path``.

    Args:
        image: PIL image of any mode; it is converted to RGB before inference.
        checkpoint_path: checkpoint file accepted by ``_load_unet``.
        threshold: probability above which a pixel is marked as foreground.

    Returns:
        PIL.Image: 8-bit mask of size IMAGE_WIDTH x IMAGE_HEIGHT where
        foreground pixels are 255 and background pixels are 0.
    """
    transform = transforms.Compose([
        transforms.Resize((IMAGE_HEIGHT, IMAGE_WIDTH)),
        transforms.ToTensor(),
    ])
    # The network expects 3 input channels. RGBA, greyscale or palette images
    # would otherwise yield a 4- or 1-channel tensor and fail in the first conv.
    input_tensor = transform(image.convert("RGB")).unsqueeze(0)

    model = _load_unet(checkpoint_path)
    with torch.no_grad():
        logits = model(input_tensor)

    # The model outputs logits; sigmoid converts them to per-pixel probabilities.
    probabilities = torch.sigmoid(logits).squeeze().cpu().numpy()
    mask = (probabilities > threshold).astype(np.uint8)
    return Image.fromarray(mask * 255)


def getDisk(image, checkpoint_path="disk_checkpoint.pth.tar", threshold=0.5):
    """Return the binary disc mask (255 = disc, 0 = background) for ``image``.

    Args:
        image: PIL image to segment.
        checkpoint_path: checkpoint of the disc segmentation UNET.
        threshold: probability cut-off for the disc class.
    """
    return _predict_mask(image, checkpoint_path, threshold)


def getCup(image, checkpoint_path="cup_checkpoint.pth.tar", threshold=0.5):
    """Return the binary cup mask (255 = cup, 0 = background) for ``image``.

    Args:
        image: PIL image to segment.
        checkpoint_path: checkpoint of the cup segmentation UNET.
        threshold: probability cut-off for the cup class.
    """
    return _predict_mask(image, checkpoint_path, threshold)


def findCDR(image):
    """Estimate the cup-to-disc ratio (CDR) of ``image`` and build a visual overlay.

    The disc and cup masks come from ``getDisk`` and ``getCup``. CDR is computed
    as an area ratio: cup-mask pixels divided by disc-mask pixels. The value is
    also stored in the module-level global ``cdr``, which the display code reads.

    Args:
        image: PIL image passed unchanged to ``getDisk`` and ``getCup``.

    Returns:
        PIL.Image: mask-resolution image where disc pixels are grey (128), disc
        pixels that are also cup pixels are white (255), and everything else is
        black (0).

    Raises:
        ValueError: if the disc mask is empty, because the ratio is then undefined.
    """
    global cdr

    disk_array = np.array(getDisk(image))
    cup_array = np.array(getCup(image))

    # Both masks mark foreground pixels with 255.
    in_disk = disk_array == 255
    in_cup = cup_array == 255

    disk_area = np.count_nonzero(in_disk)
    if disk_area == 0:
        # Without a disc the ratio is undefined; report it instead of dividing by zero.
        raise ValueError("No optic disc detected in the image; cannot compute the cup-to-disc ratio.")

    cdr = np.count_nonzero(in_cup) / disk_area

    # Overlay: disc -> grey, disc pixels also inside the cup -> white.
    overlay = disk_array.copy()
    overlay[in_disk] = 128
    overlay[in_disk & in_cup] = 255

    return Image.fromarray(overlay)
    

In [4]:
import tkinter as tk
from tkinter import filedialog, messagebox, Toplevel
from PIL import Image, ImageTk

# Function to display the processed image in a new window
def display_processed_image(original_image, processed_image):
    """Open a child window showing both images and the CDR verdict.

    Each image is shown under a heading, resized to 350x250. The CDR value is
    read from the global ``cdr``. Values below 0.5 are labelled as showing no
    indication of glaucoma; anything else is flagged as an enlarged cup.
    """
    window = Toplevel(root)
    window.title("Digital Image Processing")
    window.geometry("1000x1000")
    window.configure(bg="#007ACC")

    def show_section(caption, picture):
        # Heading label first, then the picture scaled to a fixed thumbnail size.
        heading = tk.Label(window, text=caption, font=('Georgia', 24), bg='#007ACC', fg='black')
        heading.pack(pady=5)
        thumbnail = ImageTk.PhotoImage(picture.resize((350, 250), Image.Resampling.LANCZOS))
        holder = tk.Label(window, image=thumbnail, bg="#87CEEB")
        # Keep a Python reference so the PhotoImage is not garbage-collected.
        holder.image = thumbnail
        holder.pack()

    show_section('Original Image', original_image)
    show_section('Processed Image', processed_image)

    verdict = " (No Indication of Glaucoma)" if cdr < 0.5 else " (Enlarged Cup. You May have Glaucoma)"
    summary = tk.Label(window, text="CDR: " + str(round(cdr, 2)) + verdict, font=('Georgia', 16), bg='#007ACC', fg='black')
    summary.pack(pady=5)

# Function to load an image
def upload_image():
    """Let the user pick an image file and store it in the global ``original_image``.

    Cancelling the dialog leaves ``original_image`` unchanged. Files that Pillow
    cannot open or decode are reported in an error dialog instead of raising
    inside the Tk callback. On failure the previous image is kept.
    """
    global original_image
    file_path = filedialog.askopenfilename()
    if not file_path:
        return
    try:
        image = Image.open(file_path)
        # Decode now so corrupt or truncated files are reported here, not later.
        image.load()
    except OSError as exc:  # also covers PIL.UnidentifiedImageError
        messagebox.showerror("Upload Error", f"Could not open image:\n{exc}")
        return
    original_image = image

# Function to process the image
def process_image():
    """Run the CDR pipeline on the uploaded image and open the results window.

    On success, the overlay is stored in the global ``processed_image``. An error
    dialog is shown instead of raising when no image has been uploaded yet, or
    when ``findCDR`` cannot compute a ratio (e.g. an empty disc mask).
    """
    global processed_image

    # original_image only exists once an upload has succeeded.
    image = globals().get("original_image")
    if image is None:
        messagebox.showerror("Process Error", "Please upload an image first.")
        return

    try:
        result = findCDR(image)
    except (ValueError, ZeroDivisionError) as exc:
        messagebox.showerror("Process Error", f"Could not compute the cup-to-disc ratio:\n{exc}")
        return

    processed_image = result
    display_processed_image(image, processed_image)

# Function to save the processed image
def save_image():
    """Ask for a destination path and save the global ``processed_image`` there.

    Shows an error dialog when nothing has been processed yet, or when Pillow
    cannot write the file (e.g. unsupported extension, unwritable location).
    Cancelling the dialog does nothing.
    """
    # processed_image is only defined after a successful process_image run.
    image = globals().get("processed_image")
    if image is None:
        messagebox.showerror("Save Error", "There is no image to save.")
        return

    file_path = filedialog.asksaveasfilename(defaultextension=".png")
    if not file_path:
        return
    try:
        image.save(file_path)
    except (OSError, ValueError) as exc:
        messagebox.showerror("Save Error", f"Could not save image:\n{exc}")

# Application state shared by the button callbacks. Both start as None so the
# callbacks can tell "nothing uploaded / processed yet" apart from a real image.
original_image = None
processed_image = None

# Main window setup
root = tk.Tk()
root.title("Digital Image Processing")
root.geometry("1000x1000")

# Background image setup; the PhotoImage is kept in a module global so Tk can keep drawing it.
bg_img = Image.open('backgroundimage.png')
bg_img = bg_img.resize((1000, 1000), Image.Resampling.LANCZOS)
bg_img_tk = ImageTk.PhotoImage(bg_img)
bg_image = tk.Label(root, image=bg_img_tk)
bg_image.place(relwidth=1, relheight=1)

# Logo setup
fast_img = Image.open('fastlogo.png')
fast_img = fast_img.resize((100, 100), Image.Resampling.LANCZOS)
fast_img_tk = ImageTk.PhotoImage(fast_img)
fast_logo_label = tk.Label(root, image=fast_img_tk, bg="#007ACC")
fast_logo_label.place(x=10, y=10)

# Title label
title_label = tk.Label(root, text='Digital Image Processing', font=('Georgia', 24), bg='#007ACC', fg='black')
title_label.place(relx=0.5, rely=0.1, anchor='center')

# Upload Image Button
upload_button = tk.Button(root, text="Upload Image", command=upload_image, font=('Georgia', 16), bg='#007ACC', fg='white')
upload_button.place(relx=0.5, rely=0.3, anchor='center')

# Process Image Button
process_button = tk.Button(root, text="Process Image", command=process_image, font=('Georgia', 16), bg='#FFA500', fg='white')
process_button.place(relx=0.5, rely=0.4, anchor='center')

# Save Image Button
save_button = tk.Button(root, text="Save Image", command=save_image, font=('Georgia', 16), bg='#00CC44', fg='white')
save_button.place(relx=0.5, rely=0.5, anchor='center')

# Run the main application loop (blocks until the window is closed)
root.mainloop()
