<a href="https://colab.research.google.com/github/geoaigroup/Aerial-SAM/blob/main/AerialSAM_GEOAI_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## SAM Online Demo: Segment Everything Mode

The Segment Anything Model (SAM) predicts object masks given prompts that indicate the desired object.

Please go to this link:
https://segment-anything.com/demo

And use this image as input:
https://github.com/geoaigroup/Aerial-SAM/blob/main/483.png

## Environment Set-up

If running locally using Jupyter, first install `segment_anything` in your environment following the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. If running on Google Colab, uncomment the setup cells below, set `using_colab = True`, and run them. In Colab, be sure to select 'GPU' under 'Edit' -> 'Notebook settings' -> 'Hardware accelerator'.

In [None]:
#Samples used in this demo are from the WHU Building Dataset: https://paperswithcode.com/dataset/whu-building-dataset
# !wget https://github.com/geoaigroup/Aerial-SAM/raw/main/resources/data.zip
# !wget https://github.com/geoaigroup/Aerial-SAM/raw/main/resources/pred_shapefile.zip
# !unzip data.zip
# !unzip pred_shapefile

In [None]:
# using_colab = True

In [None]:
# if using_colab:
#     import torch
#     import torchvision
#     print("PyTorch version:", torch.__version__)
#     print("Torchvision version:", torchvision.__version__)
#     print("CUDA is available:", torch.cuda.is_available())
#     import sys
#     !{sys.executable} -m pip install opencv-python matplotlib
#     !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

#     !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
#     !pip install geopandas
#     !pip install rasterio
#     !git clone https://github.com/geoaigroup/buildingsSAM.git


In [None]:
#Necessary imports and helper functions for displaying points, boxes, and masks.
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import geopandas as gpd
import os
import json
import glob
from tqdm import tqdm
import shapely.geometry as sg
from shapely import affinity
from shapely.geometry import Point, Polygon
import random
from PIL import Image, ImageDraw
import rasterio
from rasterio.features import geometry_mask
#from metrics import DiceScore,IoUScore
import pandas as pd
import gc
import shutil
import fiona
import json

import utils
from evaluate import cal_scores,matching_algorithm
from pred_SAM import SAM


In [None]:

def _overlay(masks, shape):
    """Sum per-instance masks onto a single zero canvas of ``shape`` for display."""
    canvas = np.zeros(shape)
    for m in masks:
        canvas = canvas + m
    return canvas


def cal_score(gt_tile, pred_tile, mask_shape=None, pred_title="MultiClassUnet CNN"):
    """Match ground-truth and predicted instance masks for one tile, plot both, and score them.

    Parameters
    ----------
    gt_tile, pred_tile : iterable of mask arrays
        Per-instance masks for the ground truth and the prediction (as produced
        by ``utils.convert_polygon_to_mask_batch`` in this notebook).
    mask_shape : tuple of int, optional
        Shape of the overlay canvas used for the figure. Defaults to the shape
        of the first mask in ``gt_tile`` (or ``pred_tile``), falling back to
        ``(512, 512)`` when both are empty.
    pred_title : str, optional
        Title of the prediction panel in the comparison figure.

    Returns
    -------
    list
        The ``mscores`` entries returned by ``matching_algorithm.matching()``,
        followed by one summary dict with the IoU/F1 lists, TP IoUs, FP/FN
        indices, mean IoU/F1, average TP IoU, precision and recall.
    """
    matcher = matching_algorithm(gt_tile, pred_tile)
    (iou_list, f1_scores, tp_pred_indices, tp_gt_indices,
     fp_indices, fn_indices, mscores, precision, recall) = matcher.matching()
    tp_iou_list, avg_tp_iou = matcher.tp_iou(tp_pred_indices, tp_gt_indices)

    # Key order is kept stable because the dict is dumped to JSON per tile.
    score = {
        'iou_list': iou_list,
        'f1_scores': f1_scores,
        'tp_iou_list': tp_iou_list,
        'fp_indices': fp_indices,
        'fn_indices': fn_indices,
        'Mean_iou': np.mean(iou_list, dtype=float),
        'Mean_f1': np.mean(f1_scores, dtype=float),
        # `is not None` instead of `!= None`: `!=` is elementwise for numpy values.
        'avg_tp_iou': float(avg_tp_iou) if avg_tp_iou is not None else 0.0,
        'precision': precision,
        'recall': recall,
    }
    scores_b = list(mscores) + [score]

    # Size the canvas from the masks themselves; a fixed 512x512 canvas fails
    # to broadcast against the 1024x1024 tiles used elsewhere in this notebook.
    if mask_shape is None:
        first = next(iter(gt_tile), None)
        if first is None:
            first = next(iter(pred_tile), None)
        mask_shape = np.shape(first) if first is not None else (512, 512)

    fig, (ax_gt, ax_pred) = plt.subplots(1, 2, figsize=(10, 10))
    ax_gt.imshow(_overlay(gt_tile, mask_shape))
    ax_gt.set_title("GT")
    ax_pred.imshow(_overlay(pred_tile, mask_shape))
    ax_pred.set_title(pred_title)
    plt.show()
    # Release the figure: this is called once per tile inside loops.
    plt.close(fig)

    return scores_b

def Calculate_CNN_Results():
    """Score the CNN predictions against the ground truth, tile by tile.

    Reads the notebook-level config:
      * ``pred``       - one vector file of predictions with an ``ImageId`` column,
      * ``orig_shp``   - directory with one ground-truth shapefile entry per tile,
      * ``score_dir``  - where the per-tile ``<name>_score.json`` files are written,
      * ``output_dir`` - passed to ``cal_scores`` for the final aggregation.

    Tiles that already have a score file are skipped, so an interrupted run
    can be resumed. Tiles with unreadable or empty ground truth, or with no
    predictions, are skipped as well. Finally, ``cal_scores(output_dir,
    score_dir).macro_score()`` aggregates all per-tile scores.
    """
    predictions = gpd.read_file(pred)
    os.makedirs(score_dir, exist_ok=True)

    # Every id comes from orig_shp itself, so no per-iteration membership check is needed.
    for name in tqdm(os.listdir(orig_shp)):
        print(name)
        score_path = os.path.join(score_dir, f"{name}_score.json")
        # Literal existence check; glob would misinterpret names containing [ ] * ?.
        if os.path.exists(score_path):
            print("Found")
            continue

        try:
            gt = gpd.read_file(os.path.join(orig_shp, name))
            if len(gt["geometry"]) == 0:
                continue
        except Exception as e:
            # Best effort: report the unreadable tile and move on to the next one.
            print(e)
            continue

        predic = predictions.loc[predictions["ImageId"] == name]
        if len(predic["geometry"]) == 0:
            continue

        gc.collect()

        gt_tile = utils.convert_polygon_to_mask_batch(gt['geometry'])
        pred_tile = utils.convert_polygon_to_mask_batch(predic["geometry"])

        scores_res = cal_score(gt_tile, pred_tile)
        with open(score_path, 'w') as f1:
            json.dump(scores_res, f1)

    scores = cal_scores(output_dir, score_dir)
    scores.macro_score()

In [None]:
def main(CNN="",prompt_type="",sam=None):
    score_list = []
    scores=cal_scores(output_dir,score_dir)
    # ff = gpd.read_file(pred)
    ids = [f for f in os.listdir(pred)]
    for name in tqdm(ids):
        print(name)
        print("Checking")
        flag=0
        if glob.glob(output_dir + "/" + name + "/" + name + ".shp" ) or glob.glob(output_dir + "/" + name + "/" + name + ".png" ):
            print("Found")
            continue

        tile_boxes = []
        image_data=None
        try:
            # image = cv2.imread(images + "/" + name+'.tif')
            # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            with rasterio.open(images + "/" + name+'.tif') as src:
                # Read the image data
                image_data = src.read()
                transform=src.transform
        except Exception as e:
            print(e)
            print(name)

        # if name in os.listdir(orig_shp):
        #         gt = gpd.read_file(orig_shp + "/" + name)
        #         if len(gt["geometry"]) == 0:
        #             continue
        # else:
        #     continue
        if CNN=="multiclassUnet":

            # predic = ff.loc[ff["ImageId"] == name]
            predic = gpd.read_file(pred+"/"+name)
        elif CNN=="DCNN":
            predic = gpd.read_file(pred+"/"+name)

        geo = predic["geometry"]

        if len(geo) == 0:
            continue

        input_point=None
        input_label=None
        input_boxes=None
         
        match prompt_type:
            case "single point":
                input_point,input_label=utils.create_list_points(geo,name)
            case "single + negative points":
                input_point,input_label=utils.create_list_points(geo,name,flag="negative")
                print(input_point)
                print(input_label)
            #Skeleton
            case "skeleton":
                input_point=[]
                input_label=[]
                with open(skeleton_points, 'r') as json_file:
                    data = json.load(json_file)
                matching_items = []
                for item in data:
                    if item['id'] == name:
                        matching_items.append(item)

                input_point=torch.Tensor(matching_items[0]['input_points']).cuda()
                input_label=torch.Tensor(matching_items[0]['input_labels']).cuda().long()

            case "multiple points":
                input_point,input_label=utils.generate_random_points_polygon(geo)
            
            case "multiple points + single point":
                input_point,input_label=utils.generate_random_points_polygon(geo,flag="rep")

            case "multiple points + negative points":
                input_point,input_label=utils.generate_random_points_polygon(geo,flag="negative")
           
            #creating boxes
            case "box":
                input_boxes=[]
                flag=1
                ##for georeferenced polygons
                mask=utils.convert_polygon_to_mask_batch_transform(geo,(1024,1024),transform=transform)
                tile_boxes=utils.create_boxes(mask,shapefile=False)
               
                # tile_boxes=utils.create_boxes(geo,shapefile=False)
                input_boxes=torch.tensor(tile_boxes).cuda()
            case "box + single point":
                input_boxes=[]   
                tile_boxes=utils.create_boxes(geo)
                input_boxes=torch.tensor(tile_boxes).cuda()
                input_point,input_label=utils.create_list_points(geo,name)

            case "box + multiple points":
                input_boxes=[]
                tile_boxes=utils.create_boxes(geo)
                input_boxes=torch.tensor(tile_boxes).cuda()
                input_point,input_label=utils.generate_random_points_polygon(geo)

            case _:
                print("no or wrong prompt entered")
                
        image_data_transpose=image_data.transpose(1,2,0)
        print(image_data_transpose.shape)
        print(image_data_transpose.shape[:2])
        # x = torch.from_numpy(image.transpose(2, 0, 1)).float().cuda()
        # pred_mask=sam.predictSAM(x=x,image=image,input_point=input_point,input_label=input_label,input_boxes=input_boxes,flag=flag)
        x = torch.from_numpy(image_data).float().cuda()
        pred_mask=sam.predictSAM(x=x,image=image_data_transpose,input_point=input_point,input_label=input_label,input_boxes=input_boxes,flag=flag)
        os.makedirs(score_dir, exist_ok=True)
        os.makedirs(output_dir + "/" + f"{name}", exist_ok=True)
        
        utils.save_shp(pred_mask,name,output_dir,image_data_transpose.shape[:2],image_data_transpose,tile_boxes)
        
    #     scores.micro_match_iou(pred_mask,name,gt,score_list,image,input_point,input_label,tile_boxes,geo=geo)
    # scores.macro_score()


In [None]:
# Paths -- edit these to match your environment.
# The checkpoint can be overridden with the SAM_CHECKPOINT environment variable
# instead of editing a machine-specific absolute path.
checkpoint = os.environ.get("SAM_CHECKPOINT", "/home/jamada/jupyterlab/models/sam_vit_h_4b8939.pth")
images = "data/images_fragmented/n1"
# Only read by Calculate_CNN_Results() and the "skeleton" prompt type; defined
# here so those code paths do not fail with a NameError.
orig_shp = "data/pred_shapefile"
skeleton_points = "data/points.json"
pred = "data/fragmented_shapefiles_n1_1024/n1"
output_dir = "data/output"

score_dir = "data/scores"


# Get Multiclass Unet initial results. Note: Calculate_CNN_Results() expects
# `pred` to be a single prediction file with an ImageId column, whereas main()
# below expects a directory of per-tile files.
# Calculate_CNN_Results()

# Load the SAM model once; later cells reuse it.
sam = SAM(checkpoint)

# Run SAM prediction on the Multiclass Unet polygons with box prompts.
main(CNN="multiclassUnet", prompt_type="box", sam=sam)


In [None]:
# Deliberate stop for "Run All": the cells below are exploratory and meant to be run individually.
raise RuntimeError("Intentional stop: the cells below are exploratory; run them individually.")

In [None]:
# Sanity check: load the same GeoTIFF with OpenCV (converted from BGR to RGB)
# and with rasterio, then move rasterio's band axis last so both are channels-last.
image = cv2.cvtColor(cv2.imread('data/images/n1.tif'), cv2.COLOR_BGR2RGB)
with rasterio.open('data/images/n1.tif') as src:
    image_data = src.read()

print("beforeraster", image_data.shape)
image_data = image_data.transpose(1, 2, 0)
print("raster", image_data.shape)

In [None]:
#for D-linkNet model
pred = "data/DCNN_pred_shapefile"
sam=SAM(checkpoint)
main(CNN="DCNN",prompt_type="box",sam=sam)

In [None]:
from rasterio.plot import show

# Georeferenced tile to inspect (replace with your actual file path).
image_path = 'data/images/n1_0.tif'

# Grab the dataset metadata (including the affine transform) and all bands.
with rasterio.open(image_path) as src:
    metadata = src.meta
    image_data = src.read()

# Passing the transform makes rasterio draw the tile in map coordinates.
show(image_data, transform=metadata['transform'])

print("Print metadata")
print(metadata)


In [None]:
# Extract data/n1.zip into data/n1 (IPython shell escape; needs the `unzip` command).
!unzip data/n1.zip -d data/n1