In [1]:
from groundingdino.util.inference import load_model, load_image, predict
import cv2
import torch
import csv
from ultralytics import SAM
from pathlib import Path
import time as t
from PIL import Image, ImageDraw
import numpy as np
from PyQt5 import QtWidgets, QtGui, QtCore
import sys
import os
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
import re
import tempfile
import warnings
warnings.filterwarnings("ignore")

In [2]:
def clean_labels(boxes, max_area):
    """Drop boxes whose width * height is not strictly below ``max_area``.

    Args:
        boxes: Tensor-like object exposing ``tolist()``; in each row the
            values at index 2 and 3 are multiplied together as the box area.
        max_area: Exclusive upper bound on that area for a row to be kept.

    Returns:
        A ``torch.FloatTensor`` holding the kept rows, or the original
        ``boxes`` object itself when no row passes the filter.
    """
    kept_rows = [row for row in boxes.tolist() if row[2] * row[3] < max_area]
    # Fall back to the unfiltered input rather than returning an empty result.
    if not kept_rows:
        return boxes
    return torch.FloatTensor(kept_rows)

def load_dino_model(model_size='swint'):
    """Load a GroundingDINO model for the requested backbone size.

    Args:
        model_size: ``'swint'`` (default) or ``'swinb'``.

    Returns:
        Whatever ``load_model(config_path, checkpoint_path)`` returns.

    Raises:
        ValueError: If ``model_size`` is not one of the supported names.
            (Previously this fell through both branches and crashed with an
            ``UnboundLocalError`` on ``config_path``.)
    """
    # NOTE(review): these are machine-specific absolute paths; consider moving
    # them to a configuration cell so the notebook runs on other machines.
    model_files = {
        'swint': (
            r"C:\Users\cmull\DataspellProjects\AutoAnnotate\GroundingDINO\groundingdino\config\GroundingDINO_SwinT_OGC.py",
            r"C:\Users\cmull\DataspellProjects\AutoAnnotate\GroundingDINO\weights\groundingdino_swint_ogc.pth",
        ),
        'swinb': (
            r"C:\Users\cmull\DataspellProjects\AutoAnnotate\GroundingDINO\groundingdino\config\GroundingDINO_SwinB_cfg.py",
            r"C:\Users\cmull\DataspellProjects\AutoAnnotate\GroundingDINO\weights\groundingdino_swinb_cogcoor.pth",
        ),
    }
    if model_size not in model_files:
        raise ValueError(
            f"Unsupported model_size {model_size!r}; expected one of {sorted(model_files)}"
        )

    config_path, checkpoint_path = model_files[model_size]
    return load_model(config_path, checkpoint_path)

def run_dino_from_model(model, img_path, prompt, box_threshold, text_threshold, maxarea=0.7, save_dir="DINO-labels"):
    """Run GroundingDINO on one image, save the boxes as a label file, and
    return the boxes in absolute pixel coordinates.

    Args:
        model: Model object passed straight to ``predict``.
        img_path: Path of the image to annotate.
        prompt: Text caption passed to ``predict``.
        box_threshold: Forwarded to ``predict`` as ``box_threshold``.
        text_threshold: Forwarded to ``predict`` as ``text_threshold``.
        maxarea: Area limit forwarded to ``clean_labels``.
        save_dir: Directory that receives ``<image stem>.txt``; created if it
            does not exist yet.

    Returns:
        List of ``[x1, y1, x2, y2]`` boxes scaled by the image width/height.
    """
    image_source, image = load_image(img_path)
    # Only the boxes are used; the other two values returned by predict are ignored.
    boxes, _, _ = predict(model=model, image=image, caption=prompt,
                          box_threshold=box_threshold, text_threshold=text_threshold)

    # Boxes are treated as normalized (cx, cy, w, h) and converted to pixel xyxy.
    img_height, img_width = cv2.imread(img_path).shape[:2]
    box_rows = clean_labels(boxes, maxarea).tolist()
    absolute_boxes = [[(cx - (w / 2)) * img_width,
                       (cy - (h / 2)) * img_height,
                       (cx + (w / 2)) * img_width,
                       (cy + (h / 2)) * img_height] for cx, cy, w, h in box_rows]

    # Make sure the output directory exists; open() below would otherwise fail
    # with FileNotFoundError on a fresh working directory.
    os.makedirs(save_dir, exist_ok=True)

    # Each label row is "0 cx cy w h" (class id 0 prepended), space separated.
    label_rows = [[0, *row] for row in box_rows]
    with open(f'{save_dir}/{os.path.splitext(os.path.basename(img_path))[0]}.txt', 'w', newline='') as csvfile:
        csv.writer(csvfile, delimiter=' ').writerows(label_rows)
    return absolute_boxes

def save_masks(sam_results, output_dir):
    """Write the first result's normalized mask polygons to a text label file.

    The file is ``<output_dir>/<stem of sam_results[0].path>.txt``. Every
    non-empty polygon becomes one line: ``0`` followed by its flattened
    coordinates, space separated.
    """
    first = sam_results[0]
    polygons = first.masks.xyn
    label_stem = Path(output_dir) / Path(first.path).stem
    with open(f"{label_stem}.txt", "w") as f:
        for polygon in polygons:
            # Skip masks that produced no polygon points.
            if len(polygon) == 0:
                continue
            coords = " ".join(str(value) for value in polygon.reshape(-1).tolist())
            f.write("0 " + coords + "\n")

def run_image(DINO, img_dir, output_dir, prompt, conf, box_threshold, save_dir):
    """Detect boxes with GroundingDINO, segment them with SAM, and save the masks.

    Args:
        DINO: Loaded GroundingDINO model.
        img_dir: Path of the image to process (used as a file path despite the name).
        output_dir: Directory for the mask label file; created if missing.
        prompt: Detection caption.
        conf: Forwarded to ``run_dino_from_model`` as its ``box_threshold``.
        box_threshold: Forwarded to ``run_dino_from_model`` as its ``maxarea``
            (sixth positional argument), not as a box threshold.
        save_dir: Directory for the DINO box labels; created if missing.

    Returns:
        The SAM results object.
    """
    sam_model = "sam2_t.pt"
    start = t.time()

    # Create both output locations up front, announcing each one that is new.
    for folder in (save_dir, output_dir):
        if not os.path.exists(folder):
            print(f"{folder} does not exist, creating")
            os.makedirs(folder, exist_ok=True)

    # The DINO text threshold is fixed at 0.1 here.
    boxes = run_dino_from_model(DINO, img_dir, prompt, conf, 0.1, box_threshold, save_dir=save_dir)
    segmenter = SAM(sam_model)
    sam_results = segmenter(img_dir, model=sam_model, bboxes=boxes, verbose=False)
    save_masks(sam_results, output_dir)

    print(f"Completed in: {t.time() - start} seconds, masks saved in {output_dir}")
    return sam_results

def adjust_masks(sam_results):
    """Return the first result's mask tensor as a NumPy array.

    The previous implementation moved axis 0 to the end and then straight
    back again, which is a no-op; the array is now returned directly with the
    same layout (described as (N, H, W) by the original inline comments).

    Args:
        sam_results: Sequence whose first element exposes ``masks.data`` with
            ``.cpu().numpy()``.

    Returns:
        ``numpy.ndarray`` with the same shape and values as ``masks.data``.
    """
    return sam_results[0].masks.data.cpu().numpy()

def overlay_with_borders(image, mask, color, thickness=2):
    """Outline the external contours of ``mask`` on ``image``.

    Args:
        image: Image passed to ``cv2.drawContours`` and returned.
        mask: Array that is multiplied by 255 and cast to ``uint8`` before
            contour extraction.
        color: Contour colour passed to ``cv2.drawContours``.
        thickness: Contour line thickness.

    Returns:
        The same ``image`` object that was passed in.
    """
    scaled_mask = (mask * 255).astype(np.uint8)

    # Only outer contours are requested (RETR_EXTERNAL); hierarchy is unused.
    outlines, _hierarchy = cv2.findContours(
        scaled_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )

    # -1 selects every contour in the list.
    cv2.drawContours(image, outlines, -1, color, thickness)
    return image

In [3]:
def draw_boxes_on_image(image, boxes):
    """
    Render magenta rectangles for ``boxes`` onto a copy of ``image``.

    Args:
        image (np.ndarray): Source image, converted BGR -> RGB for drawing.
        boxes (list): Boxes as [x1, y1, x2, y2] in absolute coordinates.

    Returns:
        np.ndarray: New image (converted back RGB -> BGR) with the boxes drawn.
        Coordinates are clamped to the image bounds; boxes that end up with
        zero or negative width/height are skipped.
    """
    canvas = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    pen = ImageDraw.Draw(canvas)
    canvas_width, canvas_height = canvas.size
    max_x = canvas_width - 1
    max_y = canvas_height - 1

    for left, top, right, bottom in boxes:
        # Clamp every corner into [0, max] along its axis.
        left = max(0, min(left, max_x))
        top = max(0, min(top, max_y))
        right = max(0, min(right, max_x))
        bottom = max(0, min(bottom, max_y))

        # Degenerate boxes are not drawn.
        if not (right > left and bottom > top):
            continue
        pen.rectangle([left, top, right, bottom], outline=(255, 0, 255), width=4)

    return cv2.cvtColor(np.array(canvas), cv2.COLOR_RGB2BGR)

In [4]:
def optimize_prompts(prompts_file, gt_path, img_dir, save_file, threshold, DINO):
    """Rank the prompts in ``prompts_file`` by mean IoU over a directory of images.

    For every prompt (one per line of ``prompts_file``), DINO is run on each
    file in ``img_dir``, the labels are written to the inference directory,
    and ``process_files`` scores them against ``gt_path``.

    Args:
        prompts_file: Text file with one prompt per line.
        gt_path: Directory of ground-truth label files.
        img_dir: Directory of images to run inference on.
        save_file: Path that receives one ``str((prompt, stats))`` line per prompt.
        threshold: Area limit used both for DINO box cleaning and by
            ``process_files``.
        DINO: Loaded GroundingDINO model.

    Returns:
        List of ``(prompt, {'iou_scores': mean_iou})`` sorted best first.
    """
    inf_path = fr"GroundingDINO\DINO-labels"
    os.makedirs(inf_path, exist_ok=True)

    with open(prompts_file, 'r') as file:
        result_dict = {line.strip(): {} for line in file}

    box_threshold = 0.3
    text_threshold = 0.1

    for prompt in result_dict:
        print(f'Trying prompt: "{prompt}"')
        for fname in os.listdir(img_dir):
            # Fixes: the old call passed the string 'swint' positionally into
            # the ``maxarea`` slot, and wrote labels to the default save_dir
            # instead of the directory that is scored below.
            run_dino_from_model(DINO, os.path.join(img_dir, fname), prompt,
                                box_threshold, text_threshold,
                                maxarea=threshold, save_dir=inf_path)

        metrics = process_files(inf_path, gt_path, threshold=threshold)
        result_dict[prompt]['iou_scores'] = np.mean(metrics['iou_scores'])

    results = sorted(result_dict.items(), key=lambda item: item[1]['iou_scores'], reverse=True)
    print(results)

    with open(save_file, 'w') as output:
        for prompt_stats in results:
            output.write(str(prompt_stats) + '\n')

    return results

def prompt_optimizer(prompts_file, gt_path, img_path, save_file, threshold, DINO):
    """Rank the prompts in ``prompts_file`` by IoU on a single image.

    Args:
        prompts_file: Text file with one prompt per line.
        gt_path: Ground-truth label file for ``img_path``.
        img_path: Image to run inference on.
        save_file: Path that receives one ``str((prompt, stats))`` line per prompt.
        threshold: Area limit used for DINO box cleaning and by ``process_file``.
        DINO: Loaded GroundingDINO model.

    Returns:
        List of ``(prompt, {'iou_scores': mean_iou})`` sorted best first.
    """
    print('entered prompt optimizer')
    # NOTE(review): machine-specific absolute path; consider making it configurable.
    inf_path = r"C:\Users\cmull\DataspellProjects\AutoAnnotate\GUI and Pipeline\DINO-labels"
    os.makedirs(inf_path, exist_ok=True)

    # The label file DINO writes for this image; it is re-read after every prompt.
    predicted_mask_file = os.path.join(inf_path, f"{os.path.splitext(os.path.basename(img_path))[0]}.txt")

    with open(prompts_file, 'r') as file:
        result_dict = {x.strip(): {} for x in file}

    for prompt in result_dict:
        print(f'Trying prompt: "{prompt}"')

        # Fix: write the labels into ``inf_path``. Previously the default
        # relative "DINO-labels" directory was used, so the file scored below
        # was only the fresh prediction when the working directory happened to
        # be the parent of ``inf_path``.
        run_dino_from_model(DINO, img_path, prompt, box_threshold=0.3, text_threshold=0.1,
                            maxarea=threshold, save_dir=inf_path)

        metrics = process_file(predicted_mask_file, gt_path, threshold)
        result_dict[prompt]['iou_scores'] = np.mean(metrics['iou_scores'])

    results = sorted(result_dict.items(), key=lambda a: a[1]['iou_scores'], reverse=True)
    print("Results:", results)

    with open(save_file, 'w') as output:
        for prompt_stats in results:
            output.write(str(prompt_stats) + '\n')
    return results

def process_mask_arrays(predicted_mask_array, ground_truth_mask_array):
    """Compute pixel-level metrics between a predicted and a ground-truth mask.

    Both masks are binarized with a threshold of 127 (values above become
    foreground). If the shapes differ, the predicted mask is resized to the
    ground-truth shape with nearest-neighbour interpolation first.

    Args:
        predicted_mask_array: Predicted mask (as produced by ``draw_boxes``,
            i.e. uint8 with 0/255 values).
        ground_truth_mask_array: Ground-truth mask in the same format.

    Returns:
        Dict with single-element lists under ``iou_scores``,
        ``precision_scores``, ``recall_scores``, ``f1_scores``, ``mcc_scores``
        and ``specificity_scores``. IoU is 0 when both masks are empty
        (previously this divided by zero and produced NaN).
    """
    if predicted_mask_array.shape != ground_truth_mask_array.shape:
        # cv2.resize takes (width, height), hence shape[1] before shape[0].
        predicted_mask_array = cv2.resize(
            predicted_mask_array,
            (ground_truth_mask_array.shape[1], ground_truth_mask_array.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )

    _, predicted_mask_bin = cv2.threshold(predicted_mask_array, 127, 255, cv2.THRESH_BINARY)
    _, ground_truth_mask_bin = cv2.threshold(ground_truth_mask_array, 127, 255, cv2.THRESH_BINARY)

    # Work on boolean foreground maps.
    pred_fg = (predicted_mask_bin / 255) == 1
    gt_fg = (ground_truth_mask_bin / 255) == 1

    tp = np.float64(np.sum(np.logical_and(pred_fg, gt_fg)))
    tn = np.float64(np.sum(np.logical_and(~pred_fg, ~gt_fg)))
    fp = np.float64(np.sum(np.logical_and(pred_fg, ~gt_fg)))
    fn = np.float64(np.sum(np.logical_and(~pred_fg, gt_fg)))

    # Intersection is tp; union is every pixel that is foreground in either mask.
    union_px = tp + fp + fn
    iou = tp / union_px if union_px > 0 else 0

    precision, recall, f1, mcc, specificity = calculate_metrics(tp, fp, fn, tn)

    return {
        'iou_scores': [iou],
        'precision_scores': [precision],
        'recall_scores': [recall],
        'f1_scores': [f1],
        'mcc_scores': [mcc],
        'specificity_scores': [specificity],
    }

def draw_boxes(boxes, image_dim=(1280, 720)):
    """
    Rasterize absolute boxes into a single-channel mask.

    Parameters:
    boxes (list): Boxes in absolute xyxy coordinates.
    image_dim (tuple): Size of the mask as (width, height).

    Returns:
    np.array: uint8 array that is 255 inside every box and 0 elsewhere.
    """
    # Start from an all-black 8-bit greyscale canvas.
    canvas = Image.new('L', image_dim, 0)
    painter = ImageDraw.Draw(canvas)

    for rect in boxes:
        painter.rectangle(rect, fill=255)

    return np.array(canvas, dtype=np.uint8)

def _score_confidence(DINO, img_path, prompt, box_threshold, text_threshold,
                      max_area, image_dim, gt_mask, save_dir):
    """Run DINO once at ``box_threshold`` and return the mean IoU of its box mask vs ``gt_mask``."""
    boxes = run_dino_from_model(DINO, img_path, prompt, box_threshold, text_threshold,
                                maxarea=max_area, save_dir=save_dir)
    pred_mask = draw_boxes(boxes, image_dim)
    return np.mean(process_mask_arrays(pred_mask, gt_mask)['iou_scores'])


def confidence_optimizer(prompt, DINO, gt_path, img_path, threshold):
    """Search for the DINO box threshold that maximizes IoU on one image.

    A coarse sweep (0.0 to 0.9, step 0.1) is followed by a fine sweep
    (best +/- 0.1, step 0.01, clamped to [0, 1]).

    Args:
        prompt: Detection caption.
        DINO: Loaded GroundingDINO model.
        gt_path: Ground-truth label file for the image.
        img_path: Image to run inference on.
        threshold: Area limit forwarded to DINO box cleaning as ``maxarea``.

    Returns:
        ``(best_iou, best_conf)``.

    Raises:
        FileNotFoundError: If ``img_path`` cannot be read as an image.
    """
    # NOTE(review): machine-specific absolute path; consider making it configurable.
    inf_path = r"C:\Users\cmull\DataspellProjects\AutoAnnotate\GUI and Pipeline\DINO-labels"
    os.makedirs(inf_path, exist_ok=True)

    image = cv2.imread(img_path)
    if image is None:
        # cv2.imread returns None instead of raising on an unreadable file.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    image_dim = (image.shape[1], image.shape[0])  # (width, height)

    # The ground-truth mask does not depend on the confidence; build it once.
    gt_mask = read_and_draw_boxes_from_file(gt_path)

    best_iou = 0
    best_conf = 0

    # Step 1: coarse sweep from 0.0 to 0.9 in steps of 0.1
    for conf in np.arange(0.0, 0.91, 0.1):
        iou = _score_confidence(DINO, img_path, prompt, conf, 0.1,
                                threshold, image_dim, gt_mask, inf_path)
        print("P1")
        print(f"[Precision 1] Confidence: {conf:.1f}, IoU: {iou:.4f}")

        if iou > best_iou:
            best_iou = iou
            best_conf = conf

    print(f"Best from Precision 1: Confidence = {best_conf:.1f}, IoU = {best_iou:.4f}")

    # Step 2: fine sweep around the coarse optimum, kept inside [0, 1] so no
    # inference is spent on out-of-range thresholds.
    lower = max(0.0, best_conf - 0.1)
    upper = min(1.0, best_conf + 0.1)
    step = 0.01

    # NOTE(review): the fine sweep uses text_threshold 0.01 while the coarse
    # sweep uses 0.1 (kept as in the original) - confirm this is intentional.
    for conf in np.arange(lower, upper + step, step):
        iou = _score_confidence(DINO, img_path, prompt, conf, 0.01,
                                threshold, image_dim, gt_mask, inf_path)
        print('P2')
        print(f"[Precision 2] Confidence: {conf:.2f}, IoU: {iou:.4f}")

        if iou > best_iou:
            best_iou = iou
            best_conf = conf

    print(f"Final Best: Confidence = {best_conf:.2f}, IoU = {best_iou:.4f}")
    return best_iou, best_conf


def read_and_draw_boxes_from_file(file_path, image_dim=(1280, 720)):
    """Rasterize a label file of normalized boxes into a binary mask.

    Each non-blank line must hold five numbers: ``class_id cx cy w h`` with
    the box values normalized to [0, 1]; they are scaled by ``image_dim``.
    Blank lines are skipped (they previously crashed the five-value unpack).

    Args:
        file_path: Path of the label file.
        image_dim: Mask size as ``(width, height)``.

    Returns:
        uint8 array that is 255 inside every box and 0 elsewhere.

    Raises:
        ValueError: If a non-blank line does not contain exactly five numbers.
    """
    img_w, img_h = image_dim
    canvas = Image.new('L', image_dim, 0)
    pen = ImageDraw.Draw(canvas)

    with open(file_path, 'r') as file:
        for line in file:
            fields = line.split()
            if not fields:
                continue
            class_id, x, y, width, height = map(float, fields)
            # Centre/size -> corner coordinates in pixels.
            pen.rectangle([(x - (width / 2)) * img_w,
                           (y - (height / 2)) * img_h,
                           (x + (width / 2)) * img_w,
                           (y + (height / 2)) * img_h], fill=255)

    return np.array(canvas, dtype=np.uint8)

def multi_optimizer(img_dir, gt_label_dir, DINO, prompts, threshold=0.9):
    """Try every prompt and return the one with the best IoU plus its confidence.

    NOTE(review): a second ``multi_optimizer`` (with an extra ``callback``
    parameter) is defined later in this same cell; when the cell runs top to
    bottom, that later definition replaces this one.

    Args:
        img_dir: Forwarded to ``confidence_optimizer`` as its image path.
        gt_label_dir: Forwarded to ``confidence_optimizer`` as its ground-truth path.
        DINO: Loaded GroundingDINO model.
        prompts: Iterable of prompt strings.
        threshold: Forwarded to ``confidence_optimizer``.

    Returns:
        ``(best_prompt, best_conf)``; ``("", 0)`` if no prompt scores above 0.
    """
    print("entered multi_optimizer")
    t.sleep(2)
    start = t.time()
    best_prompt, best_conf, best_iou = "", 0, 0

    for prompt in prompts:
        print(f"Trying prompt: '{prompt}'")
        iou, conf = confidence_optimizer(prompt, DINO, gt_label_dir, img_dir, threshold)
        # Strict comparison: on ties the earlier prompt stays the winner.
        if iou > best_iou:
            best_prompt, best_conf, best_iou = prompt, conf, iou
        print(f"So far: best prompt is '{best_prompt}', conf is {best_conf}, resulting in {best_iou} IOU)")

    print(f"\n\n\n\n\nFinal Result: best prompt is '{best_prompt}', conf is {best_conf}, resulting in {best_iou} IOU)")
    print(f"final time: {t.time() - start}")
    return best_prompt, best_conf

def process_file(predicted_mask_file, ground_truth_mask_file, threshold):
    """Score one predicted label file against its ground-truth label file.

    Side effect: ``predicted_mask_file`` is first rewritten in place by
    ``clean_labels_from_file`` using ``threshold``.

    Args:
        predicted_mask_file: Path of the predicted box label file.
        ground_truth_mask_file: Path of the ground-truth box label file.
        threshold: Area limit passed to ``clean_labels_from_file``.

    Returns:
        Dict with single-element lists under ``iou_scores``,
        ``precision_scores``, ``recall_scores``, ``f1_scores``, ``mcc_scores``
        and ``specificity_scores``. IoU is 0 when both masks are empty
        (previously this divided by zero and produced NaN).
    """
    clean_labels_from_file(predicted_mask_file, threshold)
    predicted_mask = read_and_draw_boxes_from_file(predicted_mask_file)
    ground_truth_mask = read_and_draw_boxes_from_file(ground_truth_mask_file)

    # Binarize: values above 127 are foreground.
    _, predicted_mask_bin = cv2.threshold(predicted_mask, 127, 255, cv2.THRESH_BINARY)
    _, ground_truth_mask_bin = cv2.threshold(ground_truth_mask, 127, 255, cv2.THRESH_BINARY)

    pred_fg = (predicted_mask_bin / 255) == 1
    gt_fg = (ground_truth_mask_bin / 255) == 1

    tp = np.float64(np.sum(np.logical_and(pred_fg, gt_fg)))
    tn = np.float64(np.sum(np.logical_and(~pred_fg, ~gt_fg)))
    fp = np.float64(np.sum(np.logical_and(pred_fg, ~gt_fg)))
    fn = np.float64(np.sum(np.logical_and(~pred_fg, gt_fg)))

    # Intersection is tp; union is every pixel that is foreground in either mask.
    union_px = tp + fp + fn
    iou = tp / union_px if union_px > 0 else 0

    precision, recall, f1, mcc, specificity = calculate_metrics(tp, fp, fn, tn)

    return {
        'iou_scores': [iou],
        'precision_scores': [precision],
        'recall_scores': [recall],
        'f1_scores': [f1],
        'mcc_scores': [mcc],
        'specificity_scores': [specificity],
    }

def optimize_confidence(prompt, DINO, gt_path, img_dir, threshold):
    """Search for the DINO box threshold with the best mean IoU over a directory.

    The search runs up to five passes; pass ``p`` tries confidences on a grid
    of spacing ``10 ** -p`` between the current lower and upper bounds, then
    narrows the bounds to ``best_conf +/- 10 ** -p`` for the next pass.

    Args:
        prompt: Detection caption.
        DINO: Loaded GroundingDINO model.
        gt_path: Ground-truth directory passed to ``process_files``.
        img_dir: Directory whose files are each run through DINO.
        threshold: Passed to ``process_files``.

    Returns:
        ``(best_iou, best_conf)``.
    """
    # NOTE(review): run_dino_from_model below is called without save_dir, so it
    # writes to its default directory, while process_files reads from inf_path.
    # These only coincide if the working directory makes them the same folder -
    # confirm.
    inf_path = r"C:\Users\cmull\DataspellProjects\AutoAnnotate\GroundingDINO\DINO-labels"
    best_iou = 0
    best_conf = 0
    # number of decimal points in confidence
    final_precision = 5
    ubound = 0.9
    lbound = 0.0
    # precision takes the values 1..final_precision.
    for precision in [x + 1 for x in range(final_precision)]:
        # Count of non-improving confidences seen in this pass (never reset
        # within the pass).
        esc = 0
        # Integer grid between the bounds at the current precision; int()
        # truncates, and the upper bound itself is excluded by range().
        for conf in [x / (10 ** precision) for x in
                     range(int(lbound * (10 ** precision)), int(ubound * (10 ** precision)))]:
            for fname in os.listdir(img_dir):
                prompt = prompt  # no-op self-assignment
                box_threshold = conf
                text_threshold = 0.01
                run_dino_from_model(DINO, os.path.join(img_dir, fname), prompt, box_threshold, text_threshold)
            metrics = process_files(inf_path, gt_path, threshold)
            iou = np.mean(metrics['iou_scores'])
            if iou > best_iou:
                best_iou = iou
                best_conf = conf
            else:
                esc += 1
                # Give up on this pass after more than 2 * precision
                # non-improving confidences; the progress print below is
                # skipped for the iteration that breaks.
                if esc > 2 * precision:
                    break

            print(f"confidence: {conf}, IOU: {iou} (best: {best_iou})")
        print(f"Best IOU at p{precision} is {best_iou} with confidence = {best_conf}")
        # Narrow the search window around the current best, clamped to [0, 0.9].
        lbound = max(0, best_conf - (1 / (10 ** precision)))
        ubound = min(0.9, best_conf + (1 / (10 ** precision)))

        # NOTE(review): best_conf is always below 0.9, while the right-hand side
        # is at least 20 once precision >= 2, so this early return looks
        # unreachable - confirm the intended condition.
        if (best_conf > (0.2 * (10 ** precision))) and precision >= 2:
            print(f"Final Result: Best IOU is {best_iou} with confidence = {best_conf}")
            return best_iou, best_conf

    return best_iou, best_conf

def multi_optimizer(img_dir, gt_label_dir, DINO, prompts, threshold=0.9, callback=None):
    """Try every prompt and return the one with the best IoU plus its confidence.

    Args:
        img_dir: Forwarded to ``confidence_optimizer`` as its image path.
        gt_label_dir: Forwarded to ``confidence_optimizer`` as its ground-truth path.
        DINO: Loaded GroundingDINO model.
        prompts: Sequence of prompt strings.
        threshold: Forwarded to ``confidence_optimizer``.
        callback: Optional ``callback(prompt, index, total)`` invoked before
            each prompt is evaluated.

    Returns:
        ``(best_prompt, best_conf)``; ``("", 0)`` if no prompt scores above 0.
    """
    start = t.time()
    best_prompt, best_conf, best_iou = "", 0, 0

    for index, prompt in enumerate(prompts):
        if callback:
            callback(prompt, index, len(prompts))
        iou, conf = confidence_optimizer(prompt, DINO, gt_label_dir, img_dir, threshold)
        # Strict comparison: on ties the earlier prompt stays the winner.
        if iou > best_iou:
            best_prompt, best_conf, best_iou = prompt, conf, iou

    print(f"\nFinal Best: prompt = '{best_prompt}', conf = {best_conf}, IOU = {best_iou}")
    print(f"final time: {t.time() - start}")
    return best_prompt, best_conf

def calculate_metrics(tp, fp, fn, tn):
    """Derive precision, recall, F1, MCC and specificity from confusion counts.

    Any metric whose denominator is not positive is reported as 0.

    Returns:
        Tuple ``(precision, recall, f1, mcc, specificity)``.
    """
    predicted_pos = tp + fp
    actual_pos = tp + fn
    actual_neg = tn + fp

    precision = tp / predicted_pos if predicted_pos > 0 else 0
    recall = tp / actual_pos if actual_pos > 0 else 0

    pr_sum = precision + recall
    f1 = 2 * (precision * recall) / pr_sum if pr_sum > 0 else 0

    # Matthews correlation coefficient; the denominator is computed once.
    mcc_denominator = np.sqrt(predicted_pos * actual_pos * actual_neg * (tn + fn))
    mcc = ((tp * tn) - (fp * fn)) / mcc_denominator if mcc_denominator > 0 else 0

    specificity = tn / actual_neg if actual_neg > 0 else 0
    return precision, recall, f1, mcc, specificity

def pixel_accuracy(predicted, ground_truth):
    """Return the fraction of positions where the two masks are equal.

    The denominator is the product of the first two dimensions of
    ``predicted``.
    """
    matches = (predicted == ground_truth)
    pixel_count = predicted.shape[0] * predicted.shape[1]
    return np.sum(matches) / pixel_count

def read_and_draw_boxes(results, image_dim=(1280, 720)):
    """Rasterize normalized ``class cx cy w h`` boxes into a binary mask.

    The previous implementation iterated ``results.boxes`` while appending to
    that same sequence, and failed outright when given a file path (which is
    how ``process_files`` calls it). It also wrote a debug ``test.jpg`` on
    every call; that write has been removed.

    Args:
        results: Either a path (``str`` / ``os.PathLike``) to a label file with
            one box per line, or an object whose ``boxes`` attribute is an
            iterable of entries. Each entry is a ``"class cx cy w h"`` string
            or a 5-item sequence of numbers.
            NOTE(review): the exact type of ``results.boxes`` entries is not
            visible in this file - confirm against the caller.
        image_dim: Mask size as ``(width, height)``.

    Returns:
        uint8 array that is 255 inside every box and 0 elsewhere.
    """
    if isinstance(results, (str, os.PathLike)):
        with open(results, 'r') as label_file:
            entries = [line for line in label_file if line.strip()]
    else:
        # Snapshot the entries so the source sequence is never mutated.
        entries = list(results.boxes)

    img_w, img_h = image_dim
    canvas = Image.new('L', image_dim, 0)
    pen = ImageDraw.Draw(canvas)

    for entry in entries:
        fields = entry.split() if isinstance(entry, str) else entry
        class_id, x, y, width, height = map(float, fields)
        # Centre/size -> corner coordinates in pixels.
        pen.rectangle([(x - (width / 2)) * img_w,
                       (y - (height / 2)) * img_h,
                       (x + (width / 2)) * img_w,
                       (y + (height / 2)) * img_h], fill=255)

    return np.array(canvas, dtype=np.uint8)

def calculate_pixel_metrics(mask1, mask2):
    """
    Intersection-over-union of two masks, treating non-zero pixels as set.

    Returns 0 when neither mask has any set pixel.
    """
    union_count = np.logical_or(mask1, mask2).sum()
    if union_count == 0:
        return 0
    overlap_count = np.logical_and(mask1, mask2).sum()
    return overlap_count / union_count

def clean_labels_from_file(file_path, cleaning_threshold=0.6):
    """Rewrite a label file in place, keeping only sufficiently small boxes.

    Files with zero or one line are left untouched. Otherwise every line is
    parsed as ``class_id x y width height`` and kept only when
    ``width * height`` is strictly below ``cleaning_threshold``; the file is
    then overwritten with the kept lines (possibly none).
    """
    with open(file_path, 'r') as f:
        lines = f.readlines()

    # Nothing to filter when the file holds at most one box.
    if len(lines) <= 1:
        return

    def _box_area(line):
        _class_id, _x, _y, width, height = map(float, line.strip().split())
        return width * height

    # All lines are parsed before the file is reopened for writing.
    kept_lines = [line for line in lines if _box_area(line) < cleaning_threshold]

    with open(file_path, 'w') as f:
        f.writelines(kept_lines)

def process_files(predicted_mask_dir, ground_truth_mask_dir, threshold):
    """Score every predicted label file against its ground-truth counterpart.

    Fixes relative to the previous version:
      * the predicted directory is listed (it used to list the ground-truth
        directory, so a missing prediction crashed and the "missing ground
        truth" branch could never trigger);
      * masks are rendered with ``read_and_draw_boxes_from_file``, which takes
        a path (``read_and_draw_boxes`` was being handed a path it could not
        handle);
      * the resize step with swapped width/height is gone - both masks are
        already rendered at the same default size;
      * IoU is 0 instead of NaN when both masks are empty.

    Side effect: each predicted file is rewritten in place by
    ``clean_labels_from_file``.

    Args:
        predicted_mask_dir: Directory of predicted label files.
        ground_truth_mask_dir: Directory of ground-truth ``.txt`` label files,
            matched by file stem.
        threshold: Area limit passed to ``clean_labels_from_file``.

    Returns:
        Dict of per-file score lists under ``iou_scores``,
        ``pixel_accuracies``, ``precision_scores``, ``recall_scores``,
        ``f1_scores``, ``mcc_scores`` and ``specificity_scores``. A prediction
        without ground truth contributes 0 to every list.
    """
    metrics = {
        'iou_scores': [],
        'pixel_accuracies': [],
        'precision_scores': [],
        'recall_scores': [],
        'f1_scores': [],
        'mcc_scores': [],
        'specificity_scores': []
    }

    for fname in os.listdir(predicted_mask_dir):
        predicted_mask_path = os.path.join(predicted_mask_dir, fname)
        ground_truth_mask_path = os.path.join(ground_truth_mask_dir, os.path.splitext(fname)[0] + '.txt')

        if not os.path.exists(ground_truth_mask_path):
            # No ground truth for this prediction: score it as a total miss.
            for scores in metrics.values():
                scores.append(0)
            continue

        clean_labels_from_file(predicted_mask_path, threshold)
        predicted_mask = read_and_draw_boxes_from_file(predicted_mask_path)
        ground_truth_mask = read_and_draw_boxes_from_file(ground_truth_mask_path)

        # Binarize: values above 127 are foreground, then scale to 0/1.
        _, predicted_mask_bin = cv2.threshold(predicted_mask, 127, 255, cv2.THRESH_BINARY)
        _, ground_truth_mask_bin = cv2.threshold(ground_truth_mask, 127, 255, cv2.THRESH_BINARY)
        predicted_mask_bin = predicted_mask_bin / 255
        ground_truth_mask_bin = ground_truth_mask_bin / 255

        tp = np.float64(np.sum(np.logical_and(predicted_mask_bin == 1, ground_truth_mask_bin == 1)))
        tn = np.float64(np.sum(np.logical_and(predicted_mask_bin == 0, ground_truth_mask_bin == 0)))
        fp = np.float64(np.sum(np.logical_and(predicted_mask_bin == 1, ground_truth_mask_bin == 0)))
        fn = np.float64(np.sum(np.logical_and(predicted_mask_bin == 0, ground_truth_mask_bin == 1)))

        # Intersection is tp; union is every pixel that is foreground in either mask.
        union_px = tp + fp + fn
        metrics['iou_scores'].append(tp / union_px if union_px > 0 else 0)
        metrics['pixel_accuracies'].append(pixel_accuracy(predicted_mask_bin, ground_truth_mask_bin))

        precision, recall, f1, mcc, specificity = calculate_metrics(tp, fp, fn, tn)
        metrics['precision_scores'].append(precision)
        metrics['recall_scores'].append(recall)
        metrics['f1_scores'].append(f1)
        metrics['mcc_scores'].append(mcc)
        metrics['specificity_scores'].append(specificity)

    return metrics

In [5]:
class LLMWorker(QtCore.QObject):
    """Worker object that loads the vision LLM and its processor.

    Progress text is reported through ``log``. ``finished`` is emitted exactly
    once per ``run`` call: with ``(model, processor)`` on success, or with
    ``(None, None)`` if anything raised while loading.
    """

    # Carries the loaded model and processor (or None, None on failure).
    finished = QtCore.pyqtSignal(object, object)
    # Carries log text for display.
    log = QtCore.pyqtSignal(str)

    def run(self):
        """Load the model weights and processor, reporting each step via ``log``."""
        report = self.log.emit
        try:
            report("Loading LLaMA model...\n")
            checkpoint = "meta-llama/Llama-3.2-11B-Vision-Instruct"

            report("Step 1: Loading model weights...\n")
            from transformers import MllamaForConditionalGeneration
            model = MllamaForConditionalGeneration.from_pretrained(
                checkpoint, torch_dtype=torch.bfloat16, device_map="auto"
            )

            report("Step 2: Loading processor...\n")
            from transformers import AutoProcessor
            processor = AutoProcessor.from_pretrained(checkpoint)

            report("Step 3: Tying weights...\n")
            model.tie_weights()

            report("Model loaded successfully.\n")
            self.finished.emit(model, processor)
        except Exception as e:
            # Any failure is reported to the log and signalled as (None, None)
            # so listeners are still notified that loading has ended.
            report(f"Error loading model: {str(e)}\n")
            self.finished.emit(None, None)

class SplashScreen(QtWidgets.QWidget):
    """Frameless splash window that loads the LLM in a background thread.

    Shows a logo and a live log while ``LLMWorker`` runs. When loading
    succeeds, ``MainWindow`` is opened one second later and the splash closes.

    NOTE(review): when loading fails (``model_ready`` receives ``None``), the
    error stays visible in the log but nothing closes this frameless window -
    confirm the intended failure behaviour.
    """

    def __init__(self):
        super().__init__()
        self.setWindowFlags(QtCore.Qt.FramelessWindowHint)
        self.setAttribute(QtCore.Qt.WA_TranslucentBackground)
        self.model = None
        self.processor = None
        self.init_ui()
        self.start_model_loading()

    def init_ui(self):
        """Build the logo + log layout, size the window and centre it."""
        layout = QtWidgets.QVBoxLayout()
        self.setStyleSheet("background-color: black;")  # Splash screen background

        # Logo, scaled to a width of 400 px
        label = QtWidgets.QLabel()
        pixmap = QtGui.QPixmap("AMS_Logo_Final_Removed.png")
        label.setPixmap(pixmap.scaledToWidth(400, QtCore.Qt.SmoothTransformation))
        label.setAlignment(QtCore.Qt.AlignCenter)
        label.setStyleSheet("background: transparent;")  # Ensures QLabel doesn't add its own background
        layout.addWidget(label)

        # Live log output
        self.log_box = QtWidgets.QPlainTextEdit()
        self.log_box.setReadOnly(True)
        self.log_box.setStyleSheet("color: white; background-color: #111; font-size: 18px;")
        layout.addWidget(self.log_box)

        self.setLayout(layout)
        self.resize(800, 600)
        self.center()

    def center(self):
        """Move the window so it is centred on the primary screen."""
        frameGm = self.frameGeometry()
        screen = QtWidgets.QApplication.primaryScreen()
        centerPoint = screen.geometry().center()
        frameGm.moveCenter(centerPoint)
        self.move(frameGm.topLeft())

    def start_model_loading(self):
        """Start ``LLMWorker.run`` on a dedicated ``QThread``."""
        # Stored as ``loader_thread`` rather than ``thread`` so the inherited
        # ``QObject.thread()`` method is not shadowed on this instance.
        self.loader_thread = QtCore.QThread()
        self.worker = LLMWorker()
        self.worker.moveToThread(self.loader_thread)

        self.loader_thread.started.connect(self.worker.run)
        self.worker.finished.connect(self.model_ready)
        self.worker.log.connect(self.append_log)
        # Tear-down chain: worker done -> stop thread -> delete both objects.
        self.worker.finished.connect(self.loader_thread.quit)
        self.worker.finished.connect(self.worker.deleteLater)
        self.loader_thread.finished.connect(self.loader_thread.deleteLater)

        self.loader_thread.start()

    def append_log(self, text):
        """Append ``text`` to the on-screen log."""
        self.log_box.appendPlainText(text)

    def model_ready(self, model, processor):
        """Store the loaded objects and, if both are present, schedule the main window."""
        self.model = model
        self.processor = processor
        if model and processor:
            QtCore.QTimer.singleShot(1000, self.show_main_window)  # Wait 1s before launching main

    def show_main_window(self):
        """Open ``MainWindow`` with the loaded model/processor and close the splash."""
        self.main_window = MainWindow(self.model, self.processor)
        self.main_window.show()
        self.close()


class MainWindow(QtWidgets.QWidget):
    """Full-screen landing window offering manual or automated tuning.

    Holds the already-loaded model and processor and passes them on to
    whichever tuning window the user opens.
    """

    def __init__(self, model, processor):
        super().__init__()
        self.model = model
        self.processor = processor
        self.init_ui()

    def _make_button(self, caption, style, size, on_click):
        """Create a fixed-size styled push button wired to ``on_click``."""
        button = QtWidgets.QPushButton(caption)
        button.setStyleSheet(style)
        button.setFixedSize(*size)
        button.clicked.connect(on_click)
        return button

    def init_ui(self):
        """Build the exit button and the two centred mode-selection buttons."""
        self.setWindowTitle("Prompt and Confidence Tuning")
        self.showFullScreen()
        self.setStyleSheet("background-color: #454545;")

        layout = QtWidgets.QVBoxLayout()

        exit_btn = self._make_button(
            "Exit",
            "background-color: #e93636; color: white; font-size: 24px;",
            (200, 100),
            self.close,
        )
        layout.addWidget(exit_btn, alignment=QtCore.Qt.AlignTop | QtCore.Qt.AlignLeft)

        # Mode buttons, stacked vertically in the order listed here.
        mode_buttons = (
            ("Manual Prompt and Confidence Tuning",
             "background-color: #4f82ff; color: white; font-size: 36px;",
             self.select_manual),
            ("Automated Prompt and Confidence Tuning",
             "background-color: #e93636; color: white; font-size: 36px;",
             self.select_automated),
        )
        button_layout = QtWidgets.QVBoxLayout()
        for caption, style, handler in mode_buttons:
            mode_btn = self._make_button(caption, style, (800, 150), handler)
            button_layout.addWidget(mode_btn, alignment=QtCore.Qt.AlignCenter)

        # Stretches above and below keep the mode buttons vertically centred.
        layout.addStretch()
        layout.addLayout(button_layout)
        layout.addStretch()
        layout.setAlignment(button_layout, QtCore.Qt.AlignCenter)
        self.setLayout(layout)

    def select_manual(self):
        """Open the manual tuning window and hide this one."""
        self.manual_window = ManualWindow(self.model, self.processor)
        self.manual_window.show()
        self.hide()

    def select_automated(self):
        """Open the automated tuning window and hide this one."""
        self.automated_window = AutomatedWindow(self.model, self.processor)
        self.automated_window.show()
        self.hide()


class ManualWindow(QtWidgets.QWidget):
    """Screen for hand-tuning a detection prompt and thresholds on a folder of images.

    The user selects an image folder, types a prompt, adjusts the two sliders
    and presses Enter to preview the detections for the current image, either
    as bounding boxes or as segmentation outlines. "Auto Annotate Remaining"
    calls ``run_dino_from_model`` on every loaded image with the current
    settings and ``save_dir`` set to the chosen output folder.
    """

    # File extensions (lower-case) treated as images when a folder is loaded.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
    # Stylesheet shared by the pop-up message boxes of this window.
    MESSAGE_STYLE = "QLabel { color: black; font-size: 24px; } QMessageBox { background-color: white; }"

    def __init__(self, model, processor):
        super().__init__()
        # All state is assigned before the UI is built so that no signal handler
        # can ever observe a partially initialised window.
        self.model = model
        self.processor = processor
        self.current_image_index = 0
        self.images = []
        self.output_folder = ""  # Path of the output folder; empty until selected
        self.DINO = load_dino_model()
        self.init_ui()

    def _show_message(self, text, style=None):
        """Show a modal message box with ``text``; ``style`` overrides the default stylesheet."""
        message_box = QtWidgets.QMessageBox()
        message_box.setStyleSheet(style or self.MESSAGE_STYLE)
        message_box.setText(text)
        message_box.exec_()

    def init_ui(self):
        """Build the two-column layout: controls on the left, image preview on the right."""
        self.setWindowTitle("Manual Prompt and Confidence Tuning")
        self.showFullScreen()
        self.setStyleSheet("background-color: #454545;")

        # Main horizontal layout
        main_layout = QtWidgets.QHBoxLayout()

        # Left vertical layout for controls
        left_layout = QtWidgets.QVBoxLayout()

        # Right vertical layout for image
        right_layout = QtWidgets.QVBoxLayout()
        right_layout.setSpacing(0)
        right_layout.setContentsMargins(0, 0, 0, 0)

        # Back button
        back_btn = QtWidgets.QPushButton("Back")
        back_btn.setStyleSheet("background-color: #e93636; color: white; font-size: 24px;")
        back_btn.setFixedSize(200, 100)
        back_btn.clicked.connect(self.go_back)
        left_layout.addWidget(back_btn, alignment=QtCore.Qt.AlignTop | QtCore.Qt.AlignLeft)

        # Folder selection button (right-aligned within the left column)
        folder_btn = QtWidgets.QPushButton("Select Image Folder")
        folder_btn.setStyleSheet("background-color: #4f82ff; color: white; font-size: 24px;")
        folder_btn.setFixedSize(400, 100)
        folder_btn.clicked.connect(self.select_folder)
        left_layout.addWidget(folder_btn, alignment=QtCore.Qt.AlignRight)

        # Output folder selection button (right-aligned within the left column)
        output_folder_btn = QtWidgets.QPushButton("Select Output Folder")
        output_folder_btn.setStyleSheet("background-color: #4f82ff; color: white; font-size: 24px;")
        output_folder_btn.setFixedSize(400, 100)
        output_folder_btn.clicked.connect(self.select_output_folder)
        left_layout.addWidget(output_folder_btn, alignment=QtCore.Qt.AlignRight)

        # Prompt label
        prompt_label = QtWidgets.QLabel("Enter Prompt:")
        prompt_label.setStyleSheet("color: white; font-size: 24px;")
        left_layout.addWidget(prompt_label, alignment=QtCore.Qt.AlignLeft)

        # Prompt entry
        self.prompt_entry = QtWidgets.QLineEdit()
        self.prompt_entry.setStyleSheet("font-size: 24px; color: white; background-color: black;")
        self.prompt_entry.setFixedHeight(50)
        self.prompt_entry.setFixedWidth(800)
        left_layout.addWidget(self.prompt_entry, alignment=QtCore.Qt.AlignLeft)

        # Confidence label (kept on self so the slider callback can update it)
        self.confidence_label = QtWidgets.QLabel("Confidence: 50")
        self.confidence_label.setStyleSheet("color: white; font-size: 24px;")
        left_layout.addWidget(self.confidence_label, alignment=QtCore.Qt.AlignLeft)

        # Confidence slider: value / 100 is passed as the detector's box_threshold
        self.confidence_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
        self.confidence_slider.setRange(0, 100)
        self.confidence_slider.setValue(50)
        self.confidence_slider.setStyleSheet("font-size: 24px;")
        self.confidence_slider.setFixedSize(800, 50)
        left_layout.addWidget(self.confidence_slider, alignment=QtCore.Qt.AlignLeft)
        self.confidence_slider.valueChanged.connect(self.update_confidence_value)

        # Box threshold label (kept on self so the slider callback can update it)
        self.box_threshold_label = QtWidgets.QLabel("Box Threshold: 90")
        self.box_threshold_label.setStyleSheet("color: white; font-size: 24px;")
        left_layout.addWidget(self.box_threshold_label, alignment=QtCore.Qt.AlignLeft)

        # Box threshold slider. Despite its label, value / 100 is used as the
        # maximum box area (the maxarea argument of run_dino_from_model).
        self.box_threshold_slider = QtWidgets.QSlider(QtCore.Qt.Horizontal)
        self.box_threshold_slider.setRange(0, 100)
        self.box_threshold_slider.setValue(90)
        self.box_threshold_slider.setStyleSheet("font-size: 24px;")
        self.box_threshold_slider.setFixedSize(800, 50)
        left_layout.addWidget(self.box_threshold_slider, alignment=QtCore.Qt.AlignLeft)
        self.box_threshold_slider.valueChanged.connect(self.update_box_threshold_value)

        # Bottom buttons layout
        bottom_buttons_layout = QtWidgets.QHBoxLayout()

        # Next Image button
        next_btn = QtWidgets.QPushButton("Next Image")
        next_btn.setStyleSheet("background-color: green; color: white; font-size: 24px;")
        next_btn.setFixedSize(400, 100)
        next_btn.clicked.connect(self.next_image)
        bottom_buttons_layout.addWidget(next_btn, alignment=QtCore.Qt.AlignLeft)

        # Stretch keeps the button pinned to the left edge of the row
        bottom_buttons_layout.addStretch()

        # Add bottom buttons to the left layout
        left_layout.addLayout(bottom_buttons_layout)

        # Add left layout to main layout
        main_layout.addLayout(left_layout)

        # Image label
        self.image_label = QtWidgets.QLabel()
        self.image_label.setFixedSize(800, 600)
        self.image_label.setStyleSheet("border: 1px #454545;")
        self.image_label.setAlignment(QtCore.Qt.AlignCenter)
        right_layout.addWidget(self.image_label, alignment=QtCore.Qt.AlignCenter)

        # Display-mode checkboxes under the image
        self.checkbox_layout = QtWidgets.QHBoxLayout()
        self.box_checkbox = QtWidgets.QCheckBox("Bounding Box")
        self.mask_checkbox = QtWidgets.QCheckBox("Segmentation")
        self.checkbox_layout.addWidget(self.box_checkbox)
        self.checkbox_layout.addWidget(self.mask_checkbox)
        self.box_checkbox.setStyleSheet("color: white; font-size: 24px;")
        self.mask_checkbox.setStyleSheet("color: white; font-size: 24px;")

        # Each checkbox forces the other into the opposite state, so once either
        # has been toggled exactly one of them is checked.
        self.box_checkbox.toggled.connect(lambda: self.mask_checkbox.setChecked(not self.box_checkbox.isChecked()))
        self.mask_checkbox.toggled.connect(lambda: self.box_checkbox.setChecked(not self.mask_checkbox.isChecked()))
        right_layout.addLayout(self.checkbox_layout)

        # Auto Annotate Remaining button
        auto_annotate_btn = QtWidgets.QPushButton("Auto Annotate Remaining")
        auto_annotate_btn.setStyleSheet("background-color: red; color: white; font-size: 24px;")
        auto_annotate_btn.setFixedSize(400, 100)
        auto_annotate_btn.clicked.connect(self.auto_annotate_remaining)
        right_layout.addWidget(auto_annotate_btn, alignment=QtCore.Qt.AlignRight)

        # Add right layout to main layout
        main_layout.addLayout(right_layout)

        self.setLayout(main_layout)

    def keyPressEvent(self, event):
        """Run the model on Enter/Return; hand every other key to the base class."""
        if event.key() in (QtCore.Qt.Key_Return, QtCore.Qt.Key_Enter):
            if self.prompt_entry.text().strip():
                self.display_predictions()
            else:
                self._show_message("Please enter a prompt before running the model.")
        else:
            super().keyPressEvent(event)

    def display_predictions(self):
        """Run the detector on the current image and show the result in the selected mode."""
        prompt = self.prompt_entry.text()
        if not prompt.strip():
            self._show_message("Please enter a prompt to run the model.")
            return
        # Enter can be pressed before any folder was loaded; indexing an empty
        # list would raise IndexError.
        if not self.images:
            self._show_message("No images loaded.")
            return

        confidence = self.confidence_slider.value() / 100
        max_area = self.box_threshold_slider.value() / 100
        image_path = self.images[self.current_image_index]

        if self.box_checkbox.isChecked():
            self.display_boxes_with_borders(self.DINO, image_path, prompt, confidence, max_area, output_path=self.output_folder)
        elif self.mask_checkbox.isChecked():
            self.display_masks_with_borders(self.DINO, image_path, prompt, confidence, max_area, output_path=self.output_folder)
        else:
            self._show_message("Please select a display mode: Bounding Box or Segmentation.")

    def display_boxes_with_borders(self, DINO, image_path, prompt, confidence, max_area, output_path):
        """
        Display the image with the bounding boxes returned by run_dino_from_model drawn over it.
        """
        img = cv2.imread(image_path)
        # cv2.imread returns None instead of raising when the file cannot be read.
        if img is None:
            self._show_message(f"Could not read image: {image_path}")
            return

        # 0.1 is the fixed text threshold used throughout this window.
        absolute_boxes = run_dino_from_model(DINO, image_path, prompt, confidence, 0.1, max_area, save_dir=output_path)
        img_with_boxes = draw_boxes_on_image(img, absolute_boxes)
        self.show_result_image(img_with_boxes)

    def display_masks_with_borders(self, DINO, image_path, prompt, confidence, max_area, output_path):
        """Display the image with an outline drawn for every mask returned by run_image."""
        img = cv2.imread(image_path)
        # cv2.imread returns None instead of raising when the file cannot be read.
        if img is None:
            self._show_message(f"Could not read image: {image_path}")
            return

        sam_results = run_image(DINO, image_path, "", prompt, confidence, box_threshold=max_area, save_dir= output_path)
        masks = adjust_masks(sam_results)

        # Draw on a copy so the freshly loaded image stays untouched.
        image_with_borders = np.copy(img)
        for mask_i in masks:
            image_with_borders = overlay_with_borders(image_with_borders, mask_i, color=(255, 0, 255), thickness=2)

        self.show_result_image(image_with_borders)

    def show_result_image(self, cv2_image):
        """Convert a BGR OpenCV image to a pixmap and show it in the preview label."""
        rgb_image = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)

        height, width, _ = rgb_image.shape
        bytes_per_line = 3 * width
        qt_image = QtGui.QImage(rgb_image.data, width, height, bytes_per_line, QtGui.QImage.Format_RGB888)
        pixmap = QtGui.QPixmap.fromImage(qt_image)

        self.image_label.setPixmap(pixmap.scaled(self.image_label.size(), QtCore.Qt.KeepAspectRatio))

    def go_back(self):
        """Return to the main menu and close this window."""
        self.main_window = MainWindow(self.model, self.processor)
        self.main_window.show()
        self.close()

    def select_folder(self):
        """Ask for an image folder, load its images and show the first one."""
        options = QtWidgets.QFileDialog.Options()
        dialog = QtWidgets.QFileDialog(self, "Select Image Folder", options=options)
        dialog.setFileMode(QtWidgets.QFileDialog.Directory)
        dialog.setOption(QtWidgets.QFileDialog.ShowDirsOnly, True)
        dialog.setStyleSheet("QWidget { background-color: white; color: black; }")
        dialog.setOption(QtWidgets.QFileDialog.ReadOnly, True)

        if dialog.exec_() != QtWidgets.QDialog.Accepted:
            return
        folder = dialog.selectedFiles()[0]
        if not folder:
            return

        # os.listdir returns entries in arbitrary order; sort so that stepping
        # through the images is deterministic.
        self.images = sorted(
            os.path.join(folder, f)
            for f in os.listdir(folder)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        if not self.images:
            self._show_message("The selected folder does not contain any images.")
            return

        self.current_image_index = 0
        self.display_image(self.images[self.current_image_index])

        # Preview immediately when a prompt has already been typed.
        if self.prompt_entry.text().strip():
            self.display_predictions()

    def select_output_folder(self):
        """Ask for the output folder and store it in ``self.output_folder``."""
        options = QtWidgets.QFileDialog.Options()
        dialog = QtWidgets.QFileDialog(self, "Select Output Folder", options=options)
        dialog.setFileMode(QtWidgets.QFileDialog.Directory)
        dialog.setOption(QtWidgets.QFileDialog.ShowDirsOnly, True)
        dialog.setStyleSheet("QWidget { background-color: white; color: black; }")

        if dialog.exec_() == QtWidgets.QDialog.Accepted:
            self.output_folder = dialog.selectedFiles()[0]
            if self.output_folder:
                self._show_message(f"Output folder selected: {self.output_folder}")

    def display_image(self, image_path):
        """Show the unannotated image at ``image_path`` scaled to the preview label."""
        pixmap = QtGui.QPixmap(image_path)
        pixmap = pixmap.scaled(self.image_label.size(), QtCore.Qt.KeepAspectRatio, QtCore.Qt.SmoothTransformation)
        self.image_label.setPixmap(pixmap)

    def next_image(self):
        """Advance to the next image (wrapping around) and run predictions on it."""
        if not self.images:
            self._show_message("No images loaded.")
            return

        self.current_image_index = (self.current_image_index + 1) % len(self.images)
        self.display_image(self.images[self.current_image_index])
        self.display_predictions()

    def auto_annotate_remaining(self):
        """Run the detector on every loaded image, saving into the output folder.

        Nothing is processed unless images are loaded, a prompt is entered and
        an output folder is set; the user is asked for the folder if needed
        and cancelling that dialog aborts the run.
        """
        if not self.images:
            self._show_message("No images loaded.")
            return

        prompt = self.prompt_entry.text()
        if not prompt.strip():
            self._show_message("Please enter a prompt to run the model.")
            return

        if not self.output_folder:
            self.select_output_folder()
            # The check must precede the loop: otherwise a cancelled dialog
            # would still annotate every image with an empty save_dir.
            if not self.output_folder:
                return

        confidence = self.confidence_slider.value() / 100
        max_area = self.box_threshold_slider.value() / 100
        for image in self.images:
            print(image)
            run_dino_from_model(self.DINO, image, prompt, confidence, 0.1, max_area, save_dir=self.output_folder)

        self._show_message(
            "Annotations saved to the output folder.",
            style="QLabel { color: white; font-size: 24px; } QMessageBox { background-color: black; }",
        )

    def update_confidence_value(self, value):
        """Mirror the confidence slider position in its label."""
        self.confidence_label.setText(f"Confidence: {value}")

    def update_box_threshold_value(self, value):
        """Mirror the box-threshold slider position in its label."""
        self.box_threshold_label.setText(f"Box Threshold: {value}")

class AutomatedWindow(QtWidgets.QWidget):
    """Screen that tunes prompt and confidence automatically, then labels a folder.

    Workflow: select a label folder, an image folder and an output folder,
    then either load a prompt list from a file or generate prompts from a
    sample image. "Start Annotation" picks the label file with the most lines
    as reference, calls ``prompt_optimizer`` and ``multi_optimizer`` on it and
    finally calls ``run_image`` on every image in the image folder.
    """

    # File extensions (lower-case) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, model, processor):
        super().__init__()
        self.generated_prompts = []  # Prompts returned by generate_prompts
        self.loaded_prompt_file = ""  # Path of a user-supplied prompt list
        # Folder selections start empty so that "Start Annotation" can detect a
        # missing selection instead of failing with AttributeError.
        self.images = []
        self.image_folder = ""
        self.output_folder = ""
        self.label_folder = ""
        self.label_files = []
        self.prompt_buttons_added = False
        self.DINO = load_dino_model(model_size="swint")
        self.model = model
        self.processor = processor
        self.init_ui()

    def show_message(self, title, text, level="info"):
        """Show a modal, dark-themed message box.

        ``level`` selects the icon: "info", "warn" or "error"; any other value
        shows no icon.
        """
        msg = QtWidgets.QMessageBox(self)
        msg.setWindowTitle(title)
        msg.setText(text)
        msg.setStyleSheet(
            "QLabel { color: white; font-size: 20px; }"
            "QMessageBox { background-color: #222; }"
            "QPushButton { background-color: #444; color: white; font-size: 18px; padding: 8px; }"
        )
        icons = {
            "info": QtWidgets.QMessageBox.Information,
            "warn": QtWidgets.QMessageBox.Warning,
            "error": QtWidgets.QMessageBox.Critical,
        }
        if level in icons:
            msg.setIcon(icons[level])
        msg.exec_()

    def _make_button(self, text, background, handler, size=(400, 100)):
        """Create a fixed-size push button with white 24px text connected to ``handler``."""
        button = QtWidgets.QPushButton(text)
        button.setStyleSheet(f"background-color: {background}; color: white; font-size: 24px;")
        button.setFixedSize(*size)
        button.clicked.connect(handler)
        return button

    def _make_path_label(self):
        """Create the initially empty white label shown under a folder button."""
        label = QtWidgets.QLabel("")
        label.setStyleSheet("color: white; font-size: 24px;")
        return label

    def init_ui(self):
        """Build the layout: folder pickers left, prompt controls right, status and start below."""
        self.setWindowTitle("Automated Prompt and Confidence Tuning")
        self.showFullScreen()
        self.setStyleSheet("background-color: black;")

        layout = QtWidgets.QVBoxLayout()

        back_btn = self._make_button("Back", "grey", self.go_back, size=(200, 100))
        layout.addWidget(back_btn, alignment=QtCore.Qt.AlignTop | QtCore.Qt.AlignLeft)

        # Left layout for folder selection: each button is followed by a label
        # that shows the chosen path.
        self.left_layout = QtWidgets.QVBoxLayout()

        label_btn = self._make_button("Select Label Folder", "#4f82ff", self.select_label_folder)
        self.left_layout.addWidget(label_btn, alignment=QtCore.Qt.AlignTop)
        self.labelled_folder_label = self._make_path_label()
        self.left_layout.addWidget(self.labelled_folder_label, alignment=QtCore.Qt.AlignTop)

        unannotated_btn = self._make_button("Select Image Folder", "#4f82ff", self.select_image_folder)
        self.left_layout.addWidget(unannotated_btn, alignment=QtCore.Qt.AlignTop)
        self.image_folder_label = self._make_path_label()
        self.left_layout.addWidget(self.image_folder_label, alignment=QtCore.Qt.AlignTop)

        output_btn = self._make_button("Select Output Folder", "#4f82ff", self.select_output_folder)
        self.left_layout.addWidget(output_btn, alignment=QtCore.Qt.AlignTop)
        self.output_folder_label = self._make_path_label()
        self.left_layout.addWidget(self.output_folder_label, alignment=QtCore.Qt.AlignTop)

        # Right layout for prompt selection
        self.right_layout = QtWidgets.QVBoxLayout()
        prompt_select_btn = self._make_button("Prompt Selection", "#4f82ff", self.prompt_selection)
        self.right_layout.addWidget(prompt_select_btn, alignment=QtCore.Qt.AlignTop)

        # Bottom layout: status line above the start button
        self.bottom_layout = QtWidgets.QVBoxLayout()
        self.status_label = QtWidgets.QLabel("Status: Ready")
        self.status_label.setStyleSheet("font-size: 24px; color: white;")
        self.status_label.setAlignment(QtCore.Qt.AlignTop)
        self.status_label.setWordWrap(True)
        self.bottom_layout.addWidget(self.status_label, alignment=QtCore.Qt.AlignCenter)
        start_btn = self._make_button("Start Annotation", "green", self.perform_automatic_annotation)
        self.bottom_layout.addWidget(start_btn, alignment=QtCore.Qt.AlignCenter)

        main_layout = QtWidgets.QHBoxLayout()
        main_layout.addLayout(self.left_layout)
        main_layout.addLayout(self.right_layout)

        layout.addLayout(main_layout)
        layout.addLayout(self.bottom_layout)
        self.setLayout(layout)

    def go_back(self):
        """Return to the main menu and close this window."""
        self.main_window = MainWindow(self.model, self.processor)
        self.main_window.show()
        self.close()

    def _pick_directory(self, title, read_only=False):
        """Show a directory chooser and return the chosen path, or "" if cancelled."""
        options = QtWidgets.QFileDialog.Options()
        dialog = QtWidgets.QFileDialog(self, title, options=options)
        dialog.setFileMode(QtWidgets.QFileDialog.Directory)
        dialog.setOption(QtWidgets.QFileDialog.ShowDirsOnly, True)
        if read_only:
            dialog.setOption(QtWidgets.QFileDialog.ReadOnly, True)
        # The window itself is black; give the dialog a readable light theme.
        dialog.setStyleSheet("QWidget { background-color: white; color: black; }")

        if dialog.exec_() != QtWidgets.QDialog.Accepted:
            return ""
        selected = dialog.selectedFiles()
        return selected[0] if selected else ""

    def select_image_folder(self):
        """Ask for the folder holding the images to annotate."""
        folder = self._pick_directory("Select Image Folder", read_only=True)
        if not folder:
            return
        self.images = [
            os.path.join(folder, f)
            for f in os.listdir(folder)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        ]
        self.image_folder = folder
        self.image_folder_label.setText(f"Image Folder: {self.image_folder}")

    def select_output_folder(self):
        """Ask for the folder that receives the annotation output."""
        folder = self._pick_directory("Select Output Folder")
        if not folder:
            return
        self.output_folder = folder
        self.output_folder_label.setText(f"Output Folder: {self.output_folder}")

    def select_label_folder(self):
        """Ask for the folder holding the reference ``.txt`` label files."""
        folder = self._pick_directory("Select Label Folder", read_only=True)
        if not folder:
            return
        label_files = [os.path.join(folder, f) for f in os.listdir(folder) if f.lower().endswith('.txt')]
        if not label_files:
            self.show_message(
                "No Label Files",
                "The selected folder does not contain any .txt label files.",
                level="warn",
            )
            return
        # Commit the selection only once it is known to be usable.
        self.label_files = label_files
        self.label_folder = folder
        self.labelled_folder_label.setText(f"Label Folder: {self.label_folder}")

    def update_status(self, text):
        """Set the status line and process pending events so it repaints during long work."""
        self.status_label.setText(text)
        QtWidgets.QApplication.processEvents()

    def _find_reference_image(self, stem):
        """Return the path of the image in the image folder named ``stem``, or None.

        Any extension from IMAGE_EXTENSIONS is accepted, compared
        case-insensitively.
        """
        for name in sorted(os.listdir(self.image_folder)):
            base, ext = os.path.splitext(name)
            if base == stem and ext.lower() in self.IMAGE_EXTENSIONS:
                return os.path.join(self.image_folder, name)
        return None

    def perform_automatic_annotation(self):
        """Optimise prompt and confidence on a reference image, then label every image.

        The reference is the label file with the most lines together with the
        image of the same name. Every precondition (folders, prompts,
        reference files) is checked up front and reported through
        ``show_message``.
        """
        required = (
            ("label folder", self.label_folder),
            ("image folder", self.image_folder),
            ("output folder", self.output_folder),
        )
        missing = [name for name, value in required if not value]
        if missing:
            self.show_message("Error", "Please select the " + ", ".join(missing) + " first.", level="error")
            return
        if not self.generated_prompts and not self.loaded_prompt_file:
            self.show_message("Error", "No prompts loaded or generated. Please select or generate prompts.", level="error")
            return

        self.update_status("Sorting label files...")
        sorted_txt_files = sort_largest_file(self.label_folder)
        if not sorted_txt_files:
            self.show_message("Error", "The label folder does not contain any .txt label files.", level="error")
            self.update_status("Status: Ready")
            return

        reference_name = sorted_txt_files[0]
        reference_txt = os.path.join(self.label_folder, reference_name)
        reference_stem = os.path.splitext(reference_name)[0]
        reference_image = self._find_reference_image(reference_stem)
        if reference_image is None:
            self.show_message(
                "Error",
                f"No image named '{reference_stem}' was found in the image folder.",
                level="error",
            )
            self.update_status("Status: Ready")
            return

        # Generated prompts are handed to prompt_optimizer through a temporary
        # file; a user-supplied prompt list is used in place.
        temp_prompt_file = None
        if self.generated_prompts:
            with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix=".txt") as tmp:
                for prompt in self.generated_prompts:
                    tmp.write(prompt + "\n")
                temp_prompt_file = tmp.name
            prompts_path = temp_prompt_file
        else:
            prompts_path = self.loaded_prompt_file

        try:
            self.update_status("Optimizing prompts for best IoU...")
            prompt_result = prompt_optimizer(
                prompts_file=prompts_path,
                gt_path=reference_txt,
                img_path=reference_image,
                save_file="best.txt",
                threshold=0.8,
                DINO=self.DINO
            )
        finally:
            # delete=False leaves the file on disk, so remove it here. Removal
            # is best effort: a failure must not mask the real outcome.
            if temp_prompt_file is not None:
                try:
                    os.remove(temp_prompt_file)
                except OSError:
                    pass

        # Only the two best prompts go on to confidence tuning.
        top2 = [result[0] for result in prompt_result][:2]
        if not top2:
            self.show_message("Error", "Prompt optimisation returned no candidates.", level="error")
            self.update_status("Status: Ready")
            return

        self.update_status("Refining confidence scores...")
        best_prompt, best_conf = multi_optimizer(
            img_dir=reference_image,
            gt_label_dir=reference_txt,
            DINO=self.DINO,
            prompts=top2,
            threshold=0.8,
            callback=lambda prompt, i, total: self.update_status(f"📌 Confidence tuning: '{prompt}' ({i+1}/{total})"),
        )

        self.update_status("Starting labelling of images...")

        segments_dir = os.path.join(self.output_folder, "segments")
        boxes_dir = os.path.join(self.output_folder, "boxes")
        os.makedirs(segments_dir, exist_ok=True)
        os.makedirs(boxes_dir, exist_ok=True)

        image_files = sorted(
            f for f in os.listdir(self.image_folder)
            if f.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        total_images = len(image_files)

        for i, image_file in enumerate(image_files):
            image_path = os.path.join(self.image_folder, image_file)
            run_image(self.DINO, img_dir=image_path, output_dir=segments_dir, prompt=best_prompt, conf=best_conf, box_threshold=0.8, save_dir=boxes_dir)
            self.update_status(f"📸 Labelling image {i + 1} of {total_images}")

        self.update_status(f"✅ LABELLING COMPLETE\nOutput saved to: {self.output_folder}")

    def prompt_selection(self):
        """Reveal the two prompt-source buttons; repeated clicks do nothing."""
        if self.prompt_buttons_added:
            return

        list_prompts_btn = self._make_button("List of Prompts", "#4f82ff", self.handle_list_of_prompts)
        self.right_layout.addWidget(list_prompts_btn, alignment=QtCore.Qt.AlignTop)

        generate_prompts_btn = self._make_button("Generate Prompts", "#4f82ff", self.handle_generate_prompts)
        self.right_layout.addWidget(generate_prompts_btn, alignment=QtCore.Qt.AlignTop)

        self.prompt_buttons_added = True

    def handle_list_of_prompts(self):
        """Let the user pick a prompt list file and preview its first ten prompts.

        The selection replaces any generated prompts, but only after the file
        was read successfully and contained at least one prompt.
        """
        options = QtWidgets.QFileDialog.Options()
        options |= QtWidgets.QFileDialog.DontUseNativeDialog
        dialog = QtWidgets.QFileDialog(self, "Select Prompts List", "", "Text Files (*.txt);;CSV Files (*.csv)", options=options)
        dialog.setStyleSheet("QWidget { background-color: white; color: black; }")
        if dialog.exec_() != QtWidgets.QDialog.Accepted:
            return

        chosen_file = dialog.selectedFiles()[0]
        try:
            with open(chosen_file, "r", encoding="utf-8") as f:
                loaded = [line.strip() for line in f if line.strip()]
        except (OSError, UnicodeError) as e:
            self.show_message("Error Reading File", str(e), level="error")
            return

        if not loaded:
            self.show_message("Empty Prompt File", "The selected file does not contain any prompts.", level="warn")
            return

        self.loaded_prompt_file = chosen_file
        self.generated_prompts = []
        self.show_message("Loaded Prompts", "\n".join(loaded[:10]))  # Preview top 10

    def handle_generate_prompts(self):
        """Ask for a sample image and an object name, then generate prompts for it."""
        options = QtWidgets.QFileDialog.Options()
        dialog = QtWidgets.QFileDialog(self, "Select Sample Image", "", "Image Files (*.png *.jpg *.jpeg)", options=options)
        dialog.setStyleSheet("QWidget { background-color: white; color: black; }")
        if dialog.exec_() != QtWidgets.QDialog.Accepted:
            return

        sample_image_path = dialog.selectedFiles()[0]
        if not sample_image_path:
            return

        manual_entry, ok = QtWidgets.QInputDialog.getText(self, "Object to Describe", "Enter the object in the image:")
        if not (ok and manual_entry):
            return

        prompts = generate_prompts(sample_image_path, manual_entry, self.model, self.processor)
        if not prompts:
            self.show_message("Prompt Generation Failed", "No prompts were returned.", level="warn")
            return

        # Generated prompts take over from any previously loaded prompt file.
        self.generated_prompts = prompts
        self.loaded_prompt_file = ""
        self.show_message("Generated Prompts", "\n".join(prompts))

def sort_largest_file(folder_path):
    """Return the ``.txt`` file names in ``folder_path`` ordered by line count, largest first.

    The extension is matched case-insensitively and only regular files are
    considered. Files with the same number of lines are ordered by name so the
    result does not depend on directory listing order.

    Args:
        folder_path: Directory to scan (not recursive).

    Returns:
        list[str]: File names (not full paths); empty if there are no
        ``.txt`` files.

    Raises:
        OSError: If the folder cannot be listed or a file cannot be opened.
    """
    line_counts = {}
    for file_name in os.listdir(folder_path):
        if not file_name.lower().endswith('.txt'):
            continue
        file_path = os.path.join(folder_path, file_name)
        # A directory can also be called "something.txt"; open() would fail on it.
        if not os.path.isfile(file_path):
            continue
        # Only the number of lines matters, so undecodable bytes are replaced
        # rather than allowed to abort the scan.
        with open(file_path, 'r', errors='replace') as handle:
            line_counts[file_name] = sum(1 for _ in handle)

    # Negative count sorts descending; the name breaks ties deterministically.
    return sorted(line_counts, key=lambda name: (-line_counts[name], name))

def extract_descriptions(response):
    """Pull the individual prompt descriptions out of a raw model response.

    The response is split on newlines. Blank lines and lines containing any of
    the marker words (compared in lower case) are dropped, a leading list
    number such as ``1.``, ``2)`` or ``3-`` is stripped, and whatever text
    remains is returned in its original order.
    """
    skip_markers = ("user", "assistant", "describe", "text & image output")
    leading_number = re.compile(r"^\s*\d+[\.\)\-]\s*")

    def is_candidate(text):
        # Keep non-empty lines that carry none of the marker words.
        lowered = text.lower()
        return bool(text) and not any(marker in lowered for marker in skip_markers)

    trimmed = (raw_line.strip() for raw_line in response.split("\n"))
    unnumbered = (leading_number.sub("", text) for text in trimmed if is_candidate(text))
    # A line consisting only of its number is empty by now and is discarded.
    return [text for text in unnumbered if text]

def generate_prompts(image_path, manual_entry, model, processor):
    """Ask the vision-language model for short detection prompts for an object.

    Args:
        image_path: Path of the sample image shown to the model.
        manual_entry: User-supplied name of the object to describe; it is
            interpolated into the instruction text.
        model: Generation model; this function uses its `.device` attribute
            and its `.generate(...)` method.
        processor: Matching processor; this function uses
            `apply_chat_template`, calls it on text + image, and uses `decode`.

    Returns:
        list[str]: The descriptions parsed from the decoded response by
        `extract_descriptions`, or an empty list if any step raises (the
        error is printed, not propagated).
    """
    try:
        # Close the file handle deterministically; convert() returns a new
        # in-memory image, so it stays usable after the file is closed.
        with Image.open(image_path) as source_image:
            raw_image = source_image.convert("RGB")
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {
                        "type": "text",
                        "text": (
                            f"Describe the {manual_entry} of the image in 3 words maximum for prompt use in a zero-shot detection model, "
                            "and give 5 separate entries, each separated by a new line, and its own separate descriptor of the target. "
                            "Number each prompt. Then simply new line. Strictly the prompts, no other response is required. "
                            "Use visual description of the target in the image only. Do not duplicate responses."
                        ),
                    },
                ],
            }
        ]
        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
        inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(model.device)
        # temperature / top_p only take effect in sampling mode. Request it
        # explicitly so they are not silently ignored when the checkpoint's
        # generation config defaults to greedy decoding.
        output = model.generate(
            **inputs,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            max_new_tokens=512,
        )
        # The decoded text is handed to extract_descriptions as-is; that
        # helper is responsible for filtering out non-description lines.
        response = processor.decode(output[0], skip_special_tokens=True)
        return extract_descriptions(response)
    except Exception as e:
        # Deliberate best-effort: the caller treats [] as "generation failed".
        print(f"Error generating prompts: {e}")
        return []

def main():
    """Show the splash screen and run the Qt event loop until the GUI exits.

    Terminates via `sys.exit` with the event loop's return code, so when this
    is called from a notebook cell the cell ends with a `SystemExit`.
    """
    # Qt allows a single QApplication per process. When this is re-run in a
    # kernel that already created one, reuse it instead of constructing another.
    app = QtWidgets.QApplication.instance()
    if app is None:
        app = QtWidgets.QApplication(sys.argv)
    splash = SplashScreen()
    splash.show()
    sys.exit(app.exec_())

In [6]:
# Launch the GUI. main() finishes with sys.exit(app.exec_()), so this cell
# ends with a SystemExit once the Qt event loop returns.
main()

The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

final text_encoder_type: bert-base-uncased
entered prompt optimizer
Trying prompt: "Blueberries."
Trying prompt: "Blue fruit."
Trying prompt: "Small round berries."
Trying prompt: "Blueberries on bush."
Trying prompt: "Berries in field."
Results: [('Small round berries.', {'iou_scores': 0.45482138687835627}), ('Berries in field.', {'iou_scores': 0.4423861357131589}), ('Blueberries.', {'iou_scores': 0.43472537776335246}), ('Blue fruit.', {'iou_scores': 0.41393969165659006}), ('Blueberries on bush.', {'iou_scores': 0.03911040479154131})]
P1
[Precision 1] Confidence: 0.0, IoU: 0.0295
P1
[Precision 1] Confidence: 0.1, IoU: 0.0580
P1
[Precision 1] Confidence: 0.2, IoU: 0.4576
P1
[Precision 1] Confidence: 0.3, IoU: 0.4562
P1
[Precision 1] Confidence: 0.4, IoU: 0.4887
P1
[Precision 1] Confidence: 0.5, IoU: 0.3360
P1
[Precision 1] Confidence: 0.6, IoU: 0.1313
P1
[Precision 1] Confidence: 0.7, IoU: 0.0000
P1
[Precision 1] Confidence: 0.8, IoU: 0.0000
P1
[Precision 1] Confidence: 0.9, IoU: 0.000

SystemExit: 0