In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Once the repo is cloned then:
# Best practice: remove and reinstall fsspec
#!rm -rf /opt/conda/lib/python3.10/site-packages/fsspec*
#!pip install fsspec==2024.6.1 --force-reinstall --no-deps

#pip install -e .
#pip install -e ".[demo]"


# # Also install these to visualize figures
# sudo apt-get update
# sudo apt-get install -y libgl1-mesa-glx
# sudo apt-get install -y libglib2.0-0



In [None]:
!pip install shapely

In [None]:
!pip install s3fs

In [None]:
!pip install rasterio

## Running an example of general segmentation using SAM2

In this example, the model segments everything it finds in the image.


In [None]:
import os
# if using Apple MPS, fall back to CPU for unsupported ops
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# Prefer CUDA, then Apple MPS, and fall back to the CPU otherwise.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

if device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )
elif device.type == "cuda":
    # Enter a bfloat16 autocast context and never exit it, so it applies to the rest of the notebook.
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # Allow TF32 on GPUs with compute capability >= 8 (Ampere and newer), see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

In [None]:
np.random.seed(3)

def show_anns(anns, borders=True):
    """Overlay SAM-style annotations on the current matplotlib axes.

    Annotations are painted largest 'area' first, so smaller masks end up on top.
    Each mask gets a random RGB colour at 0.5 alpha; with ``borders=True`` its
    external contours are also drawn (OpenCV is imported lazily for that).

    Args:
        anns: list of dicts with an 'area' value and a boolean 'segmentation' array.
        borders: whether to draw mask outlines.
    """
    if len(anns) == 0:
        return

    ordered = sorted(anns, key=lambda ann: ann['area'], reverse=True)
    axes = plt.gca()
    axes.set_autoscale_on(False)

    first_seg = ordered[0]['segmentation']
    overlay = np.ones((first_seg.shape[0], first_seg.shape[1], 4))
    overlay[:, :, 3] = 0  # start fully transparent

    if borders:
        import cv2

    for ann in ordered:
        seg = ann['segmentation']
        overlay[seg] = np.concatenate([np.random.random(3), [0.5]])
        if not borders:
            continue
        outlines, _ = cv2.findContours(seg.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Light polygon simplification of each contour.
        outlines = [cv2.approxPolyDP(outline, epsilon=0.01, closed=True) for outline in outlines]
        cv2.drawContours(overlay, outlines, -1, (0, 0, 1, 0.4), thickness=1)

    axes.imshow(overlay)

### Getting the image from the Solafune competition (uploaded to my S3 bucket)

In [None]:
import rasterio
import numpy as np

def load_image_as_array(image_s3_uri):
    """Read a raster (for example a GeoTIFF on S3) into a channels-last NumPy array.

    All bands are read, moved to the last axis, and only the first three are kept
    (fewer if the raster has fewer bands). The raster's dtype is preserved.

    Args:
        image_s3_uri: path/URI accepted by ``rasterio.open``.

    Returns:
        np.ndarray of shape (H, W, C) with C <= 3.
    """
    with rasterio.open(image_s3_uri) as dataset:
        bands_first = dataset.read()                    # (bands, H, W)
        channels_last = bands_first.transpose(1, 2, 0)  # (H, W, bands)
        rgb = channels_last[..., :3]
    return rgb

# Example: load one training tile (train_25) from S3.
image_s3_uri = 's3://solafune/train_images/images/train_25.tif'
image = load_image_as_array(image_s3_uri)

# 'image' is a channels-last NumPy array (H, W, C), C <= 3
print(image.shape)

# Min-max scale all bands to 0..255 and convert to uint8 before handing the image to SAM2.
# Work in float so integer rasters cannot overflow when the minimum is subtracted, and
# map a constant image (max == min) to zeros instead of dividing by zero.
image_float = image.astype(np.float64)
value_range = image_float.max() - image_float.min()
if value_range > 0:
    image_normalized = (image_float - image_float.min()) / value_range * 255
else:
    image_normalized = np.zeros_like(image_float)
image_normalized = image_normalized.astype(np.uint8)

image = image_normalized

In [None]:
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# SAM2 "hiera large" checkpoint together with its matching config file.
sam2_checkpoint = "./segment-anything-2/checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

# Baseline SAM2 model loaded from the checkpoint, wrapped in an automatic
# mask generator that segments everything it finds in an image.
sam2 = build_sam2(
    model_cfg,
    sam2_checkpoint,
    device=device,
    apply_postprocessing=False,
)
mask_generator = SAM2AutomaticMaskGenerator(sam2)

In [None]:
# Run automatic mask generation on the normalized uint8 tile. The result is a list of
# per-mask dicts; each has at least 'segmentation' and 'area' (both are read by show_anns).
masks = mask_generator.generate(image)

In [None]:
# How many masks were generated, and which fields each mask record carries.
print(len(masks))
print(masks[0].keys())  # NOTE(review): raises IndexError if no masks were generated

In [None]:
# Overlay every generated mask on the tile.
fig, ax = plt.subplots(figsize=(20, 20))
ax.imshow(image)
show_anns(masks)  # draws on the current axes, i.e. `ax`
ax.axis('off')
plt.show()

# Fine-tuning the model with Solafune's data

In [None]:
import numpy as np
import rasterio
from shapely.geometry import Polygon
from shapely.ops import transform
import cv2
from PIL import Image, ImageDraw

def polygon_to_mask(polygon, width, height):
    """Rasterize a polygon into a (height, width) float32 tensor of 0s and 1s.

    Args:
        polygon: vertex sequence accepted by PIL's ``ImageDraw.polygon``.
        width: output width in pixels.
        height: output height in pixels.

    Returns:
        torch.Tensor (height, width), 1.0 inside/on the polygon and 0.0 elsewhere.
    """
    canvas = Image.new('L', (width, height), 0)
    painter = ImageDraw.Draw(canvas)
    painter.polygon(polygon, outline=1, fill=1)
    pixels = np.array(canvas)
    return torch.tensor(pixels, dtype=torch.float32)


In [None]:
def load_image(image_s3_uri):
    """Load a raster with ``load_image_as_array`` and min-max scale it to uint8.

    Args:
        image_s3_uri: path/URI accepted by ``rasterio.open`` (e.g. an ``s3://`` GeoTIFF).

    Returns:
        np.ndarray of shape (H, W, C), C <= 3, dtype uint8. The global minimum over
        all pixels and bands maps to 0 and the global maximum to 255. A constant
        image (max == min) is returned as all zeros instead of dividing by zero.
    """
    # Same read as load_image_as_array: all bands, channels last, first three kept.
    raw = load_image_as_array(image_s3_uri)

    # Cast to float before subtracting the minimum so signed integer rasters cannot overflow.
    values = raw.astype(np.float64)
    low, high = values.min(), values.max()
    if high == low:
        return np.zeros(values.shape, dtype=np.uint8)
    return ((values - low) / (high - low) * 255).astype(np.uint8)


def load_annotations(annotation_s3_uri, image_filename):
    """Return the annotation list for one image, or None if the image is not listed.

    Reads the whole JSON file at ``annotation_s3_uri`` through s3fs and returns
    ``entry['annotations']`` for the first entry of ``data['images']`` whose
    ``'file_name'`` equals ``image_filename``.
    """
    fs = s3fs.S3FileSystem()
    with fs.open(annotation_s3_uri, 'r') as handle:
        payload = json.load(handle)

    return next(
        (entry['annotations'] for entry in payload['images'] if entry['file_name'] == image_filename),
        None,
    )
    
def load_original_annotations(annotation_s3_uri):
    """Load and return the entire annotation JSON file from S3.

    Args:
        annotation_s3_uri: S3 URI of the annotation JSON.

    Returns:
        The parsed JSON object. Callers in this notebook read its ``'images'`` list,
        whose entries carry ``'file_name'`` and ``'annotations'``.
    """
    # Imported here so the function works even if the later cell that imports
    # json/s3fs at top level has not been run yet.
    import json
    import s3fs

    fs = s3fs.S3FileSystem()
    with fs.open(annotation_s3_uri, 'r') as f:
        data = json.load(f)
    return data

In [None]:
import torch
from torch.utils.data import Dataset

class SAM2Dataset(Dataset):
    """Dataset of (image, instance-mask stack) pairs read from S3.

    Args:
        image_filenames: image file names relative to ``image_s3_prefix``; dataset
            index ``i`` serves ``image_filenames[i]``.
        annotations: S3 URI of the annotation JSON. Its ``'images'`` list holds
            entries with ``'file_name'`` and ``'annotations'``; each annotation has a
            ``'segmentation'`` polygon.
        image_s3_prefix: S3 "directory" holding the images, e.g. ``s3://bucket/dir``.

    Each item is ``(image_tensor, masks_tensor)``:
        image_tensor: float32 (C, H, W) with values in [0, 1].
        masks_tensor: float32 (N, H, W), one 0/1 mask per annotation; N may be 0.
    """

    def __init__(self, image_filenames, annotations, image_s3_prefix):
        self.image_filenames = image_filenames
        self.annotations = annotations
        self.image_s3_prefix = image_s3_prefix
        # file_name -> annotation list. Filled on first use so the annotation JSON is
        # downloaded once, instead of once per __getitem__ call.
        self._annotation_index = None

    def __len__(self):
        return len(self.image_filenames)

    def _lookup_annotations(self, image_filename):
        """Return the annotation list for ``image_filename`` (first matching entry wins)."""
        if self._annotation_index is None:
            index = {}
            for entry in load_original_annotations(self.annotations)['images']:
                index.setdefault(entry['file_name'], entry['annotations'])
            self._annotation_index = index
        if image_filename not in self._annotation_index:
            raise KeyError(f"{image_filename!r} has no entry in {self.annotations}")
        return self._annotation_index[image_filename]

    def __getitem__(self, idx):
        image_filename = self.image_filenames[idx]
        # S3 keys always use '/', so join explicitly; os.path.join would use a
        # backslash on Windows.
        image_s3_uri = f"{self.image_s3_prefix.rstrip('/')}/{image_filename}"

        image = load_image(image_s3_uri)  # (H, W, C) uint8
        image_tensor = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0

        height, width = image.shape[:2]
        masks = [
            polygon_to_mask(annotation['segmentation'], width, height)
            for annotation in self._lookup_annotations(image_filename)
        ]
        if masks:
            masks_tensor = torch.stack(masks, dim=0)
        else:
            # torch.stack([]) raises, so unannotated images get an empty (0, H, W) stack.
            masks_tensor = torch.zeros((0, height, width), dtype=torch.float32)

        return image_tensor, masks_tensor


In [None]:
# Load Json
import s3fs
import json
# S3 location of the training annotation JSON (an 'images' list with per-image 'file_name' and 'annotations').
train_annotation_s3_uri = 's3://solafune/train_annotation.json'

In [None]:
from torch.utils.data import DataLoader

# Every image listed in the training annotation file becomes a sample.
image_filenames = [
    entry['file_name']
    for entry in load_original_annotations(train_annotation_s3_uri)['images']
]
# image_filenames = ['train_28.tif']  # uncomment for a quick run on specific images

# S3 folder that holds the training tiles.
image_s3_prefix = 's3://solafune/train_images/images'

dataset = SAM2Dataset(image_filenames, train_annotation_s3_uri, image_s3_prefix)

# batch_size=1: images carry different numbers of instance masks, and the default
# collate function needs equally sized mask stacks to batch them.
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)

In [None]:
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

# A second, independent SAM2 instance that the training loop receives; `sam2` (wrapped by
# `mask_generator`) is kept as the baseline for the comparison at the end.
model = build_sam2(model_cfg, sam2_checkpoint, device=device, apply_postprocessing=False)


In [None]:
def convert_to_binary_masks(predicted_masks):
    """Turn SAM2 mask records into plain 0/1 uint8 arrays.

    Args:
        predicted_masks: iterable of dicts whose 'segmentation' value is a boolean array.

    Returns:
        list of np.ndarray (uint8), one per record and in the same order.
    """
    return [record['segmentation'].astype(np.uint8) for record in predicted_masks]

In [None]:
import s3fs
import json

# Sanity check: show one batch's image next to its ground-truth masks.
# Distinct variable names keep this cell from overwriting the global `image`
# (the uint8 array given to SAM2) and `masks` (SAM2's generated masks).
for batch_idx, (sample_images, sample_masks) in enumerate(data_loader):
    print(f"Batch {batch_idx + 1}:")
    print(f" - Image shape before permute: {sample_images.shape}")
    print(f" - Masks shape: {sample_masks.shape}")

    # Batches are [B, C, H, W]; matplotlib expects [H, W, C].
    sample_image = sample_images[0].permute(1, 2, 0)
    print(f" - Image shape after permute: {sample_image.shape}")
    sample_image_np = sample_image.detach().numpy()

    # Summing the instance masks gives a single map; overlapping instances show as larger values.
    combined_mask = sample_masks[0].sum(dim=0).numpy()

    fig, (ax_image, ax_overlay) = plt.subplots(1, 2, figsize=(12, 6))
    ax_image.imshow(sample_image_np)
    ax_image.set_title("Image")
    ax_image.axis('off')

    ax_overlay.imshow(sample_image_np)
    ax_overlay.imshow(combined_mask, alpha=0.5, cmap='jet')
    ax_overlay.set_title("Masks Overlay")
    ax_overlay.axis('off')

    plt.show()

    # One batch is enough for a visual check.
    break


# Main training loop

##### Using a panoptic loss, both to perform better on the challenge and to compare multiple masks.

In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm

class PanopticLoss(nn.Module):
    """Instance-matching IoU loss: ``1 - mean IoU`` over matched mask pairs.

    For each image in the batch, every target mask is paired with the predicted
    mask of highest IoU. The pair is kept only if that IoU exceeds ``iou_threshold``.
    The per-image loss is:

    - ``1 - mean(IoU of kept pairs)`` when at least one pair is kept;
    - ``1`` when the image has target masks but none is matched;
    - ``0`` when the image has no target masks.

    The result is the mean over the batch.

    NOTE: inputs are binarised with ``> 0.5``, so the loss carries no gradient with
    respect to the predictions. Unmatched predictions (false positives) are not
    penalised. Once at least one target is matched, unmatched targets are ignored too.

    Args:
        iou_threshold: minimum IoU for a prediction/target pair to count as a match.
        verbose: print shapes and intermediate values (debugging aid).
    """

    def __init__(self, iou_threshold=0.5, verbose=False):
        super(PanopticLoss, self).__init__()
        self.iou_threshold = iou_threshold
        self.verbose = verbose

    def _log(self, message):
        """Print ``message`` only when ``verbose`` is enabled."""
        if self.verbose:
            print(message)

    def forward(self, predicted_masks, target_masks):
        """Compute the batch loss.

        Args:
            predicted_masks: tensor indexed as [batch, mask, ...]; values > 0.5 are foreground.
            target_masks: tensor indexed as [batch, mask, ...] with the same trailing size.

        Returns:
            0-dim float tensor.
        """
        predicted_masks = predicted_masks > 0.5
        target_masks = target_masks > 0.5

        self._log(f"Predicted Masks Shape: {predicted_masks.shape}")
        self._log(f"Target Masks Shape: {target_masks.shape}")

        batch_size = predicted_masks.shape[0]
        if batch_size == 0:
            return torch.tensor(0.0, device=predicted_masks.device)

        losses = []
        for i in range(batch_size):
            pred = predicted_masks[i]
            target = target_masks[i]
            self._log(f"Batch {i + 1}: Pred Shape: {pred.shape}, Target Shape: {target.shape}")

            matched_pred, matched_target = self.match_masks(pred, target)
            self._log(f"Matched Pred Shape: {matched_pred.shape}, Matched Target Shape: {matched_target.shape}")

            if matched_pred.shape[0] > 0:
                # Add a singleton dim so calculate_iou reduces each flattened pair on its
                # own (one IoU per pair) instead of pooling every pair into one IoU.
                iou_scores = self.calculate_iou(matched_pred.unsqueeze(1), matched_target.unsqueeze(1))
                self._log(f"IoU Scores: {iou_scores}")
                loss = 1 - iou_scores.mean()
            elif target.shape[0] == 0:
                # Nothing to detect in this image.
                loss = torch.tensor(0.0, device=pred.device)
            else:
                # Targets exist but none was matched: worst possible score.
                loss = torch.tensor(1.0, device=pred.device)

            self._log(f"Loss for Batch {i + 1}: {loss.item()}")
            losses.append(loss)

        final_loss = torch.stack(losses).mean()
        self._log(f"Final Loss: {final_loss.item()}")
        return final_loss

    def match_masks(self, pred, target):
        """Pair each target mask with its best-IoU predicted mask.

        Args:
            pred: boolean masks indexed as [mask, ...].
            target: boolean masks indexed as [mask, ...] with the same trailing size.

        Returns:
            (matched_pred, matched_target): flattened boolean tensors of shape
            (num_matches, num_pixels). Both have zero rows when nothing matches.
        """
        # flatten(1), unlike view(n, -1), also works when there are zero masks.
        pred = pred.flatten(1)
        target = target.flatten(1)
        self._log(f"Flattened Pred Shape: {pred.shape}")
        self._log(f"Flattened Target Shape: {target.shape}")

        matched_pred = []
        matched_target = []
        for t in target:
            best_iou = 0
            best_pred = None
            for p in pred:
                intersection = torch.logical_and(p, t).float().sum()
                union = torch.logical_or(p, t).float().sum()
                iou = (intersection + 1e-6) / (union + 1e-6)
                if iou > best_iou:
                    best_iou = iou
                    best_pred = p
            if best_pred is not None and best_iou > self.iou_threshold:
                matched_pred.append(best_pred)
                matched_target.append(t)

        self._log(f"Number of Matched Pairs: {len(matched_pred)}")
        if not matched_pred:
            # Return empty stacks rather than zero-filled masks: two empty masks would
            # score IoU 1 after smoothing, reporting a total miss as a perfect match.
            return pred[:0], target[:0]
        return torch.stack(matched_pred), torch.stack(matched_target)

    def calculate_iou(self, pred, target):
        """Smoothed IoU per leading index, reducing over dims 1 and 2.

        2-D inputs are treated as a single item (a leading dim of 1 is added).
        Returns a 1-D float tensor with one IoU per leading index.
        """
        if pred.dim() == 2:
            pred = pred.unsqueeze(0)
        if target.dim() == 2:
            target = target.unsqueeze(0)

        intersection = (pred & target).float().sum((1, 2))
        union = (pred | target).float().sum((1, 2))
        iou = (intersection + 1e-6) / (union + 1e-6)
        return iou


In [None]:
# import torch

# class SAM2MaskGenerationFunction(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, img_tensor, mask_generator):
#         # Save the original image tensor for backward pass
#         ctx.save_for_backward(img_tensor)
        
#         # Convert the image tensor to a NumPy array for SAM2
#         img_np = img_tensor.permute(1, 2, 0).cpu().numpy()

#         # Generate masks using SAM2 (non-differentiable operation)
#         masks_output = mask_generator.generate(img_np)

#         # Convert masks to binary mask tensor
#         binary_masks_np = np.array(convert_to_binary_masks(masks_output))
#         binary_masks_tensor = torch.tensor(binary_masks_np, dtype=torch.float32).to(img_tensor.device)
        
#         # Reintroduce a differentiable operation
#         #binary_masks_tensor = binary_masks_tensor * img_tensor.sum() * 0 + binary_masks_tensor

#         return binary_masks_tensor

#     @staticmethod
#     def backward(ctx, grad_output):
#         # Retrieve the saved image tensor
#         img_tensor, = ctx.saved_tensors

#         # Here, sum/average the gradients to reduce the dimensions
#         # For simplicity, we assume reducing across the first dimension (e.g., channels or masks)
#         grad_input = grad_output.sum(dim=0, keepdim=True)  # Adjust this based on how you want to reduce

#         # Optionally, expand/reduce dimensions to match input tensor
#         grad_input = grad_input.expand_as(img_tensor)

#         return grad_input, None  # Return gradients for img_tensor and None for mask_generator

# # Example usage remains the same as before.


In [None]:
import torch
import torch.nn as nn

class SimpleBCELoss(nn.Module):
    """Binary cross-entropy between collapsed predictions and the union of targets.

    The predicted masks are summed over dim 1 (the mask axis), and the sums are used
    directly as logits. The target masks are summed over dim 1 and thresholded at
    ``> 0`` into a 0/1 "any instance present" map. Both are reshaped to (batch, -1)
    and scored with ``nn.BCEWithLogitsLoss`` (default mean reduction). The
    ``grad_fn`` of both inputs and of the loss are printed for debugging.
    """

    def __init__(self):
        super().__init__()
        self.bce_loss = nn.BCEWithLogitsLoss()

    def forward(self, predicted_masks, target_masks):
        # Debug output: are the inputs part of an autograd graph?
        print("Initial predicted_masks grad_fn:", predicted_masks.grad_fn)
        print("Initial target_masks grad_fn:", target_masks.grad_fn)

        collapsed_pred = predicted_masks.sum(dim=1)
        collapsed_target = target_masks.sum(dim=1)

        logits = collapsed_pred.reshape(collapsed_pred.size(0), -1)
        labels = collapsed_target.gt(0).float().reshape(collapsed_target.size(0), -1)

        result = self.bce_loss(logits, labels)
        print("Loss grad_fn:", result.grad_fn)
        return result


In [None]:
def train_model_with_feedback(model, data_loader, mask_generator, num_epochs=5, print_interval=1):
    """Run an Adam loop over ``data_loader``, printing per-batch and per-epoch loss.

    For every image, ``mask_generator.generate`` produces instance masks. These are
    turned into a float tensor of binary masks and compared with the ground-truth
    masks by ``SimpleBCELoss``.

    IMPORTANT: the predicted masks are rebuilt from NumPy arrays, so the loss has no
    autograd path back to ``model``'s parameters and ``optimizer.step()`` cannot
    change the weights. A RuntimeWarning is emitted (once) when no parameter received
    a gradient. Real fine-tuning needs a differentiable forward pass through the model.

    Args:
        model: torch module whose parameters are handed to the optimizer.
        data_loader: yields (images, masks); as produced by SAM2Dataset, images are
            (B, C, H, W) in [0, 1] and masks are (B, N, H, W).
        mask_generator: object whose ``generate(np.ndarray)`` returns a list of dicts
            with a boolean 'segmentation' array (e.g. SAM2AutomaticMaskGenerator).
        num_epochs: number of passes over ``data_loader``.
        print_interval: print the batch loss every ``print_interval`` batches.
    """
    import warnings

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # Alternative: PanopticLoss(iou_threshold=0.5). It thresholds its inputs, so it is not differentiable either.
    criterion = SimpleBCELoss()
    model.train()

    num_batches = len(data_loader)
    warned_no_grad = False

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for i, (images, masks) in enumerate(tqdm(data_loader, total=num_batches)):
            # Images are only converted to NumPy for the generator; just the targets go to `device`.
            masks = masks.to(device)
            height, width = masks.shape[-2:]

            optimizer.zero_grad()

            masks_pred = []
            for img in images:
                img_np = img.permute(1, 2, 0).cpu().numpy()  # (C, H, W) -> (H, W, C)
                binary_masks = convert_to_binary_masks(mask_generator.generate(img_np))
                if binary_masks:
                    pred = torch.tensor(np.array(binary_masks), dtype=torch.float32)
                else:
                    # Nothing segmented: an empty (0, H, W) stack keeps the spatial dims
                    # so the loss still lines up with the targets.
                    pred = torch.zeros((0, height, width), dtype=torch.float32)
                masks_pred.append(pred.to(device))
            masks_pred = torch.stack(masks_pred)

            loss = criterion(masks_pred, masks)
            # Only back-propagate when the loss is actually part of an autograd graph.
            if loss.requires_grad:
                loss.backward()
            if not warned_no_grad and all(p.grad is None for p in model.parameters()):
                warnings.warn(
                    "No parameter of `model` received a gradient: the predicted masks are not "
                    "connected to the model through autograd, so optimizer.step() leaves the "
                    "weights unchanged.",
                    RuntimeWarning,
                )
                warned_no_grad = True
            optimizer.step()
            epoch_loss += loss.item()

            # Print loss for every `print_interval` batches
            if (i + 1) % print_interval == 0:
                print(f"Batch [{i + 1}/{num_batches}], Loss: {loss.item():.4f}")

        print(f"Epoch [{epoch + 1}/{num_epochs}], Average Loss: {epoch_loss / max(num_batches, 1):.4f}")

    print("Training completed.")




In [None]:
# Train `model` with a mask generator built around `model` itself. The earlier
# `mask_generator` wraps the separately built `sam2`, so its predictions would never
# reflect anything that happens to `model`.
# NOTE: the generated masks are converted to NumPy binary masks before the loss, so
# the loss is not connected to `model`'s parameters through autograd.
train_mask_generator = SAM2AutomaticMaskGenerator(model)
train_model_with_feedback(model, data_loader, train_mask_generator, num_epochs=5, print_interval=1)

In [None]:
# Save only the parameters (state_dict) of `model`; restore them with model.load_state_dict.
torch.save(model.state_dict(), 'model_weights.pth')


## Test the results

In [None]:
# Reload the saved state_dict. weights_only=True restricts unpickling to tensors and
# plain containers (all a state_dict needs), so a tampered file cannot run arbitrary
# code. map_location places the tensors on the notebook's device.
model.load_state_dict(torch.load('model_weights.pth', map_location=device, weights_only=True))
model.eval()  # inference mode for the comparison below


In [None]:
# Test the trained model on a different tile (train_0).
# load_image reads the first three bands and min-max scales them to uint8.
image_s3_uri = 's3://solafune/train_images/images/train_0.tif'
image = load_image(image_s3_uri)

# Build a mask generator around the trained `model`; `mask_generator` still wraps the baseline `sam2`.
new_mask_generator = SAM2AutomaticMaskGenerator(model)

In [None]:
# Segment the same tile with the baseline generator (wraps `sam2`) and the one built from `model`.
masks = mask_generator.generate(image)
new_masks = new_mask_generator.generate(image)

In [None]:
# Masks from the baseline generator on the test tile.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
show_anns(masks)  # draws on the current axes, i.e. `ax`
ax.axis('off')
plt.show()

In [None]:
# Masks from the generator built on `model` for the same tile.
fig, ax = plt.subplots(figsize=(10, 10))
ax.imshow(image)
show_anns(new_masks)  # draws on the current axes, i.e. `ax`
ax.axis('off')
plt.show()