In [36]:
# Necessary libraries are installed along with sam2 pip install -e ".[demo]"

In [37]:
import cv2
from shapely.geometry import Polygon

def mask_to_polygon(mask):
    """
    Convert a binary mask into a polygon.
    """
    contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    polygons = []
    for contour in contours:
        if contour.shape[0] > 2:  # Ensure it's a valid polygon with at least 3 points
            polygons.append(contour.squeeze(1).tolist())
    return polygons


In [38]:
import boto3
import rasterio
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO

# Function to convert Sentinel-2 TIFF image to RGB
def load_and_convert_tiff(image_path):
    with rasterio.open(image_path) as src:
        # Read the red, green, and blue bands (B4, B3, B2)
        red_band = src.read(4)
        green_band = src.read(3)
        blue_band = src.read(2)

        # Stack the bands to create an RGB image
        rgb = np.stack([red_band, green_band, blue_band], axis=-1)

        # Normalize to [0, 255]
        rgb_normalized = np.clip((rgb / np.max(rgb)) * 255, 0, 255).astype(np.uint8)

        return rgb_normalized

# Function to process all images in the S3 bucket
def process_all_images_in_bucket(bucket_name, prefix, mask_generator):
    s3 = boto3.client('s3')
    response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
    
    # Iterate through all image objects in the S3 bucket
    for obj in response.get('Contents', []):
        img_path = obj['Key']
        
        # Load the image from S3
        s3_response = s3.get_object(Bucket=bucket_name, Key=img_path)
        image = load_and_convert_tiff(BytesIO(s3_response['Body'].read()))
        
        try:
            # Generate masks for the current image
            masks = mask_generator.generate(image)

            # Visualization
            plt.figure(figsize=(20, 20))
            plt.imshow(image)
            show_anns(masks)
            plt.axis('off')
            plt.show()

            # Save the masks or process them as needed

        except Exception as e:
            print(f"Mask generation failed for {img_path}: {str(e)}. Skipping...")


In [39]:
import json
import boto3
from io import BytesIO
from PIL import Image
import numpy as np

def process_images_and_convert_to_json(bucket_name, prefix, mask_generator):
    s3 = boto3.client('s3')
    response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
    
    # Initialize the root structure
    data = {
        "images": []
    }

    for obj in response.get('Contents', []):
        img_path = obj['Key']
        
        # Load and convert image from S3
        s3_response = s3.get_object(Bucket=bucket_name, Key=img_path)
        image = load_and_convert_tiff(BytesIO(s3_response['Body'].read()))
        
        cleaned_img_name = img_path.replace(prefix, '')

        try:
            # Generate masks for the current image
            masks = mask_generator.generate(image)

            # Convert masks to polygons
            polygons = []
            for mask_dict in masks:
                mask = mask_dict.get("segmentation")  # Assuming the mask is under the 'segmentation' key
                if mask is not None:
                    polygons.extend(mask_to_polygon(mask))
            
            # Prepare JSON structure for this image
            image_entry = {
                "file_name": cleaned_img_name,
                "annotations": []
            }
            
            for polygon in polygons:
                annotation = {
                    "class": "field",
                    "segmentation": [coord for point in polygon for coord in point]
                }
                image_entry["annotations"].append(annotation)

            # Add the image entry to the root structure
            data["images"].append(image_entry)
            
        except Exception as e:
            print(f"Processing failed for {img_path}: {str(e)}. Saving as empty annotations.")
            data["images"].append({
                "file_name": cleaned_img_name,
                "annotations": []
            })
    
    # Save the final JSON structure
    json_output_path = 'final_predictions.json'
    with open(json_output_path, 'w') as f:
        json.dump(data, f, indent=2)
    print(f"Saved final predictions to {json_output_path}")

    return data


In [42]:
# Example usage
import torch
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
bucket_name = 'solafune'
prefix = 'test_images/images/'

# Load the model
sam2_checkpoint = "./segment-anything-2/checkpoints/sam2_hiera_small.pt" # try large sam2_hiera_large.pt
model_cfg = "sam2_hiera_s.yaml" #"sam2_hiera_l.yaml" for large
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")


mask_generator = SAM2AutomaticMaskGenerator(sam2_model)
saved_state = torch.load("model.torch")
mask_generator.model.load_state_dict(saved_state)


# process_all_images_in_bucket(bucket_name, prefix, mask_generator)

# #If saved
# saved_state = torch.load("model_small.torch")
# predictor.model.load_state_dict(saved_state)

final_json_structure = process_images_and_convert_to_json(bucket_name, prefix, mask_generator)


AttributeError: 'SAM2AutomaticMaskGenerator' object has no attribute 'model'

In [19]:
!python validator.py --file_path final_predictions.json

Valid


In [30]:
import json

# Path to your JSON file
file_path = "final_predictions.json"

# Open and read the JSON file
with open(file_path, 'r') as file:
    data = file.read()

# Print the first 2000 characters
print(data[:500])

{
  "images": [
    {
      "file_name": "test_0.tif",
      "annotations": [
        {
          "class": "field",
          "segmentation": [
            705,
            29,
            704,
            30,
            703,
            30,
            702,
            31,
            701,
            31,
            699,
            33,
            698,
            33,
            697,
            34,
            696,
            34,
            692,
            38,
            691,
         


In [33]:
import json

# Path to your JSON file
file_path = "final_predictions.json"

# Load the JSON data
with open(file_path, 'r') as file:
    data = json.load(file)  # This correctly parses the JSON into a Python object

# Specify the file name you want to look for
target_file_name = "test_2.tif"

# Access the list of images
images = data.get("images", [])

# Find and print the data for the specified file_name
for entry in images:
    if entry.get('file_name') == target_file_name:
        print(json.dumps(entry, indent=2))
        break
else:
    print(f"file_name {target_file_name} not found.")


{
  "file_name": "test_2.tif",
  "annotations": [
    {
      "class": "field",
      "segmentation": [
        785,
        354,
        784,
        355,
        783,
        355,
        781,
        357,
        781,
        359,
        780,
        360,
        780,
        361,
        778,
        363,
        767,
        363,
        766,
        364,
        766,
        369,
        765,
        370,
        765,
        395,
        764,
        396,
        764,
        397,
        765,
        398,
        765,
        400,
        766,
        401,
        816,
        401,
        818,
        399,
        818,
        374,
        819,
        373,
        819,
        356,
        818,
        355,
        817,
        355,
        816,
        354
      ]
    },
    {
      "class": "field",
      "segmentation": [
        353,
        515,
        352,
        516,
        352,
        517,
        351,
        518,
        351,
        519,
        352,
        5