In [None]:
!pip install segment-anything
!pip install opencv-python matplotlib



In [None]:
import json
import os
import sys
import urllib.request

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from google.colab import drive
from IPython.display import display, clear_output
from segment_anything import sam_model_registry, SamPredictor

In [None]:
# Mount Google Drive at /content/drive so the image batches under MyDrive are readable
# and saliency maps / point files can be written back to Drive.
drive.mount(mountpoint='/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Report library versions and whether CUDA is available; process_images picks
# "cuda" as the SAM device only when it is.
print("\n".join(
    f"{label} {value}"
    for label, value in (
        ("PyTorch version:", torch.__version__),
        ("Torchvision version:", torchvision.__version__),
        ("CUDA is available:", torch.cuda.is_available()),
    )
))

PyTorch version: 2.4.0+cpu
Torchvision version: 0.19.0+cpu
CUDA is available: False


In [None]:
# Download the SAM ViT-H checkpoint if it is not already present.
# The file is fetched into a temporary ".part" file and only renamed to its final name
# once the download has finished: an interrupted download therefore never leaves a
# truncated checkpoint behind that the existence check would accept on the next run.
# urlretrieve raises on HTTP/network errors (unlike `!wget`, whose failure went unnoticed).
SAM_CHECKPOINT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
SAM_CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"
if not os.path.exists(SAM_CHECKPOINT_PATH):
    partial_path = SAM_CHECKPOINT_PATH + ".part"
    print(f"Downloading {SAM_CHECKPOINT_URL} ...")
    urllib.request.urlretrieve(SAM_CHECKPOINT_URL, partial_path)
    os.replace(partial_path, SAM_CHECKPOINT_PATH)
    print(f"Saved checkpoint to {SAM_CHECKPOINT_PATH}")

In [None]:
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent coloured overlay.

    Args:
        mask: Array whose last two dimensions are (height, width); it is reshaped to
            (height, width, 1), so it must hold exactly height*width elements.
        ax: Matplotlib axes to draw on.
        random_color: When True, use a random RGB colour (drawn with ``np.random.random``);
            otherwise use a fixed blue (30, 144, 255). Alpha is 0.6 either way.
    """
    rgb = np.random.random(3) if random_color else np.array([30 / 255, 144 / 255, 255 / 255])
    rgba = np.append(rgb, 0.6)
    height, width = mask.shape[-2:]
    # Broadcast the per-pixel mask value against the RGBA colour -> (height, width, 4).
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 as green stars, then label 0 as red stars.

    Args:
        coords: (N, 2) array of (x, y) coordinates.
        labels: Length-N array of labels; only entries equal to 1 or 0 are drawn.
        ax: Matplotlib axes to draw on.
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='*',
                   s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    """Outline the box ``(x0, y0, x1, y1)`` on ``ax`` with an unfilled green rectangle.

    Args:
        box: Indexable of four numbers; elements 0-1 are one corner, 2-3 the opposite one.
        ax: Matplotlib axes that receives the rectangle patch.
    """
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                            edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)

In [None]:
def load_and_display_images(folder_path, extensions=(".jpg", ".jpeg", ".png"), show=True):
    """Load every image file in ``folder_path`` as an RGB array, optionally plotting each.

    Files are visited in sorted filename order, because ``os.listdir`` order is arbitrary
    and the interactive annotation session should be reproducible. Extensions are
    matched case-insensitively, so ``IMG_01.JPG`` is loaded as well as ``img_01.jpg``.
    Files OpenCV cannot decode (``cv2.imread`` returns None) are reported and skipped
    instead of crashing in ``cv2.cvtColor``.

    Args:
        folder_path: Directory to scan (not recursive).
        extensions: Filename suffixes to accept; compared case-insensitively.
        show: If True (default), display each loaded image in its own 10x10 figure.

    Returns:
        List of ``(filename, rgb_image)`` tuples in sorted filename order.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    images = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith(suffixes):
            continue
        img_path = os.path.join(folder_path, filename)
        bgr = cv2.imread(img_path)
        if bgr is None:
            # imread reports unreadable/corrupt files by returning None rather than raising.
            print(f"Warning: could not read {img_path}; skipping.")
            continue
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        images.append((filename, img))
        if show:
            plt.figure(figsize=(10, 10))
            plt.imshow(img)
            plt.title(filename)
            plt.axis('on')
            plt.show()
    return images

def choose_json_file(json_path):
    """Ask whether to reuse previously saved annotation points stored at ``json_path``.

    The answer is compared after stripping whitespace and lower-casing, and both "yes"
    and "y" are accepted. This matters because any other answer makes ``process_images``
    start from an empty dict and rewrite this same file, discarding the saved points.

    Args:
        json_path: Path of the per-batch points JSON file.

    Returns:
        ``(data, True)`` with the parsed JSON when the file exists and the user agrees;
        ``({}, False)`` when the file is missing or the user declines.
    """
    print(f"Checking for JSON file at: {json_path}")
    if not os.path.exists(json_path):
        print(f"No existing JSON file found at {json_path}. A new file will be created.")
        return {}, False

    print(f"JSON file found at {json_path}.")
    answer = input("Use existing file? (yes/no): ").strip().lower()
    if answer not in ("yes", "y"):
        return {}, False
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f), True

def create_saliency_map(mask, output_path, size=(1920, 1080)):
    """Write ``mask`` to ``output_path`` as an 8-bit grayscale image resized to ``size``.

    Mask values are multiplied by 255 and cast to uint8, so a boolean mask becomes a
    0/255 image. The map is stretched to ``size`` regardless of the mask's aspect ratio.
    Lossy formats such as JPEG may introduce intermediate grey values at mask edges.

    Args:
        mask: 2-D mask array (or array-like).
        output_path: Destination file; OpenCV chooses the encoder from its extension.
        size: Output ``(width, height)`` in OpenCV's order. Defaults to 1920x1080.

    Raises:
        OSError: If ``cv2.imwrite`` reports failure (it returns False instead of
            raising, e.g. when the target directory does not exist).
    """
    saliency = (np.asarray(mask) * 255).astype(np.uint8)
    # Nearest-neighbour resampling keeps the resized map binary (no blended grey levels).
    saliency = cv2.resize(saliency, tuple(size), interpolation=cv2.INTER_NEAREST)
    if not cv2.imwrite(output_path, saliency):
        raise OSError(f"Failed to write saliency map to {output_path}")


In [None]:
def _prompt_points(filename):
    """Ask for positive prompt points; return an (N, 2) float array, or None if the user skips."""
    while True:
        raw = input(f"Enter points for {filename} as x1,y1 x2,y2 ... (or 'skip' to skip this image): ").strip()
        if raw.lower() == 'skip':
            return None
        try:
            points = np.array([list(map(float, p.split(','))) for p in raw.split()])
        except ValueError:
            # Non-numeric values, or ragged input such as "1,2 3" on recent NumPy.
            points = None
        if points is not None and points.ndim == 2 and points.shape[0] > 0 and points.shape[1] == 2:
            return points
        print("Could not parse points; use the form x1,y1 x2,y2 (e.g. 120,340 400,210).")


def _ask_mask_choice(n_masks):
    """Return the user's chosen 0-based mask index (validated), or None if they typed 'redo'."""
    while True:
        choice = input("Enter the number of the best mask (or 'redo' to input new points): ").strip()
        if choice.lower() == 'redo':
            return None
        # Reject "0" explicitly: int("0") - 1 == -1 would silently select the last mask.
        if choice.isdigit() and 1 <= int(choice) <= n_masks:
            return int(choice) - 1
        print(f"Please enter a number from 1 to {n_masks}, or 'redo'.")


def _predict_masks(predictor, image, points):
    """Run SAM on positive-only ``points``, plot every candidate mask, and return the masks.

    Assumes ``predictor.set_image(image)`` has already been called for ``image``.
    """
    labels = np.ones(len(points))
    masks, scores, _logits = predictor.predict(
        point_coords=points,
        point_labels=labels,
        multimask_output=True,
    )
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(points, labels, plt.gca())
        plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
        plt.axis('off')
        plt.show()
    return masks


def _save_saliency(filename, mask, saliency_folder):
    """Write ``mask`` as ``<saliency_folder>/<stem>_Saliency.jpeg`` and report the path."""
    saliency_filename = f"{os.path.splitext(filename)[0]}_Saliency.jpeg"
    saliency_path = os.path.join(saliency_folder, saliency_filename)
    create_saliency_map(mask, saliency_path)
    print(f"Saliency map saved to {saliency_path}")


def _write_json_atomic(data, json_path):
    """Dump ``data`` to ``json_path`` via a temp file + rename so an interrupted write cannot corrupt it."""
    tmp_path = json_path + ".tmp"
    with open(tmp_path, 'w') as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_path, json_path)


def process_images(images, json_path, saliency_folder, use_existing_json, existing_data):
    """Interactively segment each image with SAM and save one saliency map per image.

    For every ``(filename, image)``:

    * If ``use_existing_json`` is True and points were saved for ``filename``, those
      points are plotted and used to predict masks first. The user picks a mask, or
      types 'redo' to enter new points for *this image only*.
    * Otherwise the user types positive points (or 'skip'). SAM returns several
      candidate masks, which are plotted with their scores. The user picks one by
      number, or types 'redo' to enter different points.

    The chosen mask is written to ``<saliency_folder>/<stem>_Saliency.jpeg``. Newly
    entered points are stored in the points dict, which is rewritten to ``json_path``
    after every image handled on the new-points path (including skipped images).

    Args:
        images: Iterable of ``(filename, rgb_image)`` pairs.
        json_path: File the points dict is saved to.
        saliency_folder: Existing directory for the saliency images.
        use_existing_json: Whether to start from ``existing_data``.
        existing_data: Previously saved ``{filename: [[x, y], ...]}``; used and mutated
            in place when ``use_existing_json`` is True, ignored otherwise.

    Returns:
        The points dict ``{filename: [[x, y], ...]}``.

    Note:
        The SAM ViT-H model is loaded from "sam_vit_h_4b8939.pth" on every call.
    """
    image_points = existing_data if use_existing_json else {}

    sam_checkpoint = "sam_vit_h_4b8939.pth"
    model_type = "vit_h"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    predictor = SamPredictor(sam)

    for filename, image in images:
        # set_image runs the costly image encoder; do it at most once per image and
        # reuse the embedding for every predict()/redo below.
        embedded = False

        # Decided per image: a 'redo' here must not disable saved points for later images.
        if use_existing_json and filename in image_points:
            print(f"Using existing points for {filename}")
            input_points = np.array(image_points[filename])

            plt.figure(figsize=(10, 10))
            plt.imshow(image)
            show_points(input_points, np.ones(len(input_points)), plt.gca())
            plt.title(f"Existing points for {filename}")
            plt.axis('off')
            plt.show()

            predictor.set_image(image)
            embedded = True
            masks = _predict_masks(predictor, image, input_points)
            best = _ask_mask_choice(len(masks))
            if best is not None:
                _save_saliency(filename, masks[best], saliency_folder)
                continue
            # 'redo': fall through and collect new points for this image.

        while True:
            input_points = _prompt_points(filename)
            if input_points is None:  # user typed 'skip'
                break
            if not embedded:
                predictor.set_image(image)
                embedded = True
            masks = _predict_masks(predictor, image, input_points)
            best = _ask_mask_choice(len(masks))
            if best is None:  # user typed 'redo'
                continue
            image_points[filename] = input_points.tolist()
            _save_saliency(filename, masks[best], saliency_folder)
            break

        # Save after each image so progress survives a crash or an abandoned session.
        _write_json_atomic(image_points, json_path)
        print(f"Points saved to {json_path}")

    return image_points

In [None]:
def process_all_batches(base_folder, batch_folders=('Batch_Seven',)):
    """Run the interactive SAM annotation workflow for each batch folder under ``base_folder``.

    For each batch this loads the images from ``<base_folder>/<batch>``, asks whether to
    reuse that batch's points JSON, and writes saliency maps to
    ``<base_folder>/Saliency/<batch>``. A batch whose folder does not exist is skipped
    with a warning, and no saliency folder is created for it.

    Args:
        base_folder: Root directory containing one sub-folder per batch.
        batch_folders: Names of the batch sub-folders to process, in order. The default
            ``('Batch_Seven',)`` is the previously hard-coded batch; pass another sequence
            to process other batches without editing this function.
    """
    for batch in batch_folders:
        folder_path = os.path.join(base_folder, batch)

        # Batch One predates the per-batch naming scheme and keeps its original filename.
        if batch == 'Batch_One':
            json_filename = 'image_points.json'
        else:
            json_filename = f'image_points_{batch.lower()}.json'

        json_path = os.path.join(folder_path, json_filename)
        saliency_folder = os.path.join(base_folder, 'Saliency', batch)

        print(f"\nProcessing {batch}")
        print(f"Folder path: {folder_path}")
        print(f"JSON path: {json_path}")
        print(f"Saliency folder: {saliency_folder}")

        if not os.path.exists(folder_path):
            print(f"Warning: Folder {folder_path} does not exist. Skipping this batch.")
            continue

        os.makedirs(saliency_folder, exist_ok=True)

        images = load_and_display_images(folder_path)
        print(f"Total images loaded: {len(images)}")

        existing_data, use_existing_json = choose_json_file(json_path)

        final_image_points = process_images(images, json_path, saliency_folder, use_existing_json, existing_data)
        print(f"Final image points for {batch}:")
        print(json.dumps(final_image_points, indent=2))


In [None]:
# Main execution: point the workflow at the image root on the mounted Drive and run
# the interactive annotation for the configured batches.
base_folder = '/content/drive/MyDrive/Data/images/'
print("Base folder: " + base_folder)
process_all_batches(base_folder)