<a href="https://colab.research.google.com/github/geoaigroup/geotils/blob/main/ToBeChecked/Demo_SAMGhandour.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## SAM Online Demo: Segment everything Mode

## Environment Set-up

If running locally using Jupyter, first install `segment_anything` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. If running from Google Colab, run the set-up cells below: they mount Google Drive, install `rasterio` and `segment_anything`, and download the SAM ViT-H checkpoint. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'.

In [None]:
# Colab only: mount Google Drive at /content/drive so files stored there are
# reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install rasterio, then clone the Segment Anything repository and install it
# in editable mode from inside the clone.
!pip install rasterio
!git clone https://github.com/facebookresearch/segment-anything
%cd /content/segment-anything
!pip install -e .
# Step back out of the cloned repository.
%cd ..


In [None]:
# Download the SAM ViT-H checkpoint (sam_vit_h_4b8939.pth) into the current working directory.
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

In [None]:
#Necessary imports and helper functions for displaying points, boxes, and masks.
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import geopandas as gpd
import os
import json
import glob
from tqdm import tqdm
import shapely.geometry as sg
from shapely import affinity
from shapely.geometry import Point, Polygon
import random
from PIL import Image, ImageDraw
import rasterio
from rasterio.features import geometry_mask
from shapely.geometry import shape

#from metrics import DiceScore,IoUScore
import pandas as pd
import gc
import shutil
import fiona
import json

import sys
sys.path.append("/content/")
from evaluate import cal_scores, matching_algorithm
import SAMutils as utils
from pred_SAM import SAM


In [None]:
def cal_score(gt_tile, pred_tile):
    """Score one tile's predicted masks against its ground truth and plot both.

    Runs ``matching_algorithm`` on the two mask collections, collects the
    resulting metrics into a summary dict, and displays the summed
    ground-truth masks next to the summed predicted masks.

    Args:
        gt_tile: Iterable of ground-truth masks for the tile. Each mask is
            added onto a 512x512 canvas, so it must broadcast against that
            shape.
        pred_tile: Iterable of predicted masks, under the same constraint.

    Returns:
        list: every entry of the matcher's ``mscores`` followed by one
        summary dict with the keys ``iou_list``, ``f1_scores``,
        ``tp_iou_list``, ``fp_indices``, ``fn_indices``, ``Mean_iou``,
        ``Mean_f1``, ``avg_tp_iou``, ``precision`` and ``recall``.
    """
    matcher = matching_algorithm(gt_tile, pred_tile)
    (iou_list, f1_scores, tp_pred_indices, tp_gt_indices,
     fp_indices, fn_indices, mscores, precision, recall) = matcher.matching()
    tp_iou_list, avg_tp_iou = matcher.tp_iou(tp_pred_indices, tp_gt_indices)

    score = {
        'iou_list': iou_list,
        'f1_scores': f1_scores,
        'tp_iou_list': tp_iou_list,
        'fp_indices': fp_indices,
        'fn_indices': fn_indices,
        'Mean_iou': np.mean(iou_list, dtype=float),
        'Mean_f1': np.mean(f1_scores, dtype=float),
        # Identity check: `!= None` would be evaluated element-wise if the
        # matcher ever returned a NumPy value here.
        'avg_tp_iou': float(avg_tp_iou) if avg_tp_iou is not None else 0.0,
        'precision': precision,
        'recall': recall,
    }

    # Per-match scores first, the tile-level summary last.
    scores_b = list(mscores)
    scores_b.append(score)

    # Collapse each mask collection into a single image for visual comparison.
    # NOTE(review): the canvas is hard-coded to 512x512 — confirm all tiles use that size.
    gtmask = np.zeros((512, 512))
    predmask = np.zeros((512, 512))
    for g in gt_tile:
        gtmask = g + gtmask
    for p in pred_tile:
        predmask = p + predmask

    fig, ax = plt.subplots(1, 2, figsize=(10, 10))
    ax = ax.ravel()
    ax[0].imshow(gtmask)
    ax[0].set_title("GT")
    ax[1].imshow(predmask)
    ax[1].set_title("MultiClassUnet CNN")
    plt.show()

    return scores_b

def Calculate_CNN_Results():
    """Score the CNN predictions for every ground-truth file and aggregate them.

    For each entry in ``orig_shp`` that does not yet have a score file, the
    ground-truth geometries and the predictions whose ``ImageId`` equals the
    file name are rasterised with ``utils.convert_polygon_to_mask_batch``,
    scored with ``cal_score``, and written to
    ``<score_dir>/<name>_score.json``. Finally ``cal_scores(...).macro_score()``
    is run over the score directory.

    Relies on the module-level names ``pred``, ``orig_shp``, ``score_dir`` and
    ``output_dir``, which must be defined before this function is called.

    Returns:
        None
    """
    ff = gpd.read_file(pred)
    os.makedirs(score_dir, exist_ok=True)

    for name in tqdm(os.listdir(orig_shp)):
        print(name)

        # Exact-path check: a glob pattern would misread '[', '*' or '?' in a name.
        score_path = f"{score_dir}/{name}_score.json"
        if os.path.exists(score_path):
            print("Found")
            continue

        # Unreadable or empty ground-truth files are skipped, not fatal.
        try:
            gt = gpd.read_file(orig_shp + "/" + name)
            if len(gt["geometry"]) == 0:
                continue
        except Exception as e:
            print(e)
            continue

        predic = ff.loc[ff["ImageId"] == name]
        if len(predic["geometry"]) == 0:
            continue

        gc.collect()

        gt_tile = utils.convert_polygon_to_mask_batch(gt['geometry'])
        pred_tile = utils.convert_polygon_to_mask_batch(predic["geometry"])

        scores_res = cal_score(gt_tile, pred_tile)

        with open(score_path, 'w') as f1:
            json.dump(scores_res, f1)

    scores = cal_scores(output_dir, score_dir)
    scores.macro_score()

In [None]:
def main(image, pred_poly):
    """Prompt SAM with boxes built from ``pred_poly`` and return its prediction.

    Args:
        image: Image patch as a NumPy array; it is transposed with
            ``(2, 0, 1)`` before being sent to the GPU
            (assumes H x W x C layout — TODO confirm).
        pred_poly: Geometries handed to ``utils.create_boxes`` to obtain the
            box prompts.

    Returns:
        Whatever ``sam.predictSAM`` returns for this patch.

    Requires a CUDA device and the module-level ``sam`` object.
    """
    # Only box prompts are used: point prompts are passed as None and flag=1.
    # NOTE(review): the meaning of `flag` is defined inside SAM.predictSAM — confirm.
    boxes = utils.create_boxes(pred_poly)
    box_prompts = torch.tensor(boxes).cuda()

    image_chw = torch.from_numpy(image.transpose(2, 0, 1)).float().cuda()

    return sam.predictSAM(
        x=image_chw,
        image=image,
        input_point=None,
        input_label=None,
        input_boxes=box_prompts,
        flag=1,
    )


In [None]:
import rasterio as rio
from old_utils import *    #not using the latets version with onnx
import sys
sys.path.append("/content/segment-anything")

# ---------------------------------------------------------------------------
# Sliding-window SAM inference over every .tif in `in_dir`.
#
# For each image: read bands 1-3, downsample, pad, then walk overlapping
# windows. In every window the prior prediction mask is vectorised into
# polygons, `main` turns them into box prompts for SAM, and the returned
# masks are written into `full_mask` as instance ids that are unique per image.
#
# NOTE(review): `full_mask`, `rrp_info`, `raster_file` and `filename` are
# overwritten for each file and the export cells run after this loop, so only
# the last .tif processed is exported.
# ---------------------------------------------------------------------------
in_dir = "/content/drive/MyDrive/"
out_dir = r'/content/results'
checkpoint = "/content/sam_vit_h_4b8939.pth"

sam = SAM(checkpoint)

DOWN_SAMPLING = 3   # integer factor by which each image is shrunk before inference
PATCH_SIZE = 1024   # side length of one inference window, in pixels
STRIDE_SIZE = 512   # step between windows; smaller than PATCH_SIZE, so windows overlap

directory = os.fsencode(in_dir)

for file in os.listdir(directory):
    filename = os.fsdecode(file)
    if not filename.endswith(".tif"):
        continue

    # Left open on purpose: later cells read raster_file.transform and raster_file.crs.
    raster_file = rio.open(os.path.join(in_dir, filename))
    full_img = raster_file.read([1, 2, 3]).transpose(1, 2, 0)

    # Prior prediction mask used to derive the box prompts. Opened with a
    # context manager so the dataset handle is released once it has been read.
    with rio.open('/content/results_n0.tif') as mask_geotif:
        predic_mask = mask_geotif.read(1)

    HEIGHT_orig, WIDTH_orig = full_img.shape[:2]
    full_img = cv2.resize(full_img, (WIDTH_orig//DOWN_SAMPLING, HEIGHT_orig//DOWN_SAMPLING))

    # For single-band (gray) imagery, read one band and expand it to RGB with
    # cv2.cvtColor(..., cv2.COLOR_GRAY2RGB) instead of reading bands 1-3.

    # NOTE(review): full_img is downsampled but predic_mask is not; the window
    # slices below only line up if results_n0.tif is already stored at the
    # downsampled resolution — confirm.
    full_img, rrp_info = ratio_resize_pad(full_img, ratio=None, div=1024)
    full_predic_mask, mask_rrp_info = ratio_resize_pad(predic_mask, ratio=None, div=1024)

    HEIGHT, WIDTH = full_img.shape[:2]

    # NaN marks pixels no window has written yet. float32 represents integer
    # ids exactly up to 2**24.
    full_mask = np.full((HEIGHT, WIDTH), np.nan, dtype=np.float32)

    # Largest instance id handed out so far for this image.
    M = 0
    for hs in tqdm(range(0, HEIGHT, STRIDE_SIZE)):
        for ws in range(0, WIDTH, STRIDE_SIZE):
            he = hs + PATCH_SIZE
            we = ws + PATCH_SIZE
            patch = full_img[hs:he, ws:we, :]
            patch_mask = full_predic_mask[hs:he, ws:we]

            # Vectorise the non-zero regions of the prior mask into polygons.
            geometry = [
                shape(shapedict)
                for shapedict, value in rasterio.features.shapes(patch_mask)
                if value != 0
            ]
            patch_gdf = gpd.GeoDataFrame({'geometry': geometry})

            if len(patch_gdf) == 0:
                # Nothing to prompt SAM with: the window is background.
                full_mask[hs:he, ws:we] = 0
                continue

            with torch.no_grad():
                y_pred = main(patch, patch_gdf)

            # int32 rather than int16: ids accumulate over the whole image and
            # int16 silently wraps once more than 32767 instances are seen.
            y_pred = y_pred.detach().cpu().long().numpy()[:, 0, :, :].astype(np.int32)

            n_patch, _, _ = y_pred.shape
            b_ids = (np.arange(n_patch, dtype=np.int32) + 1)[:, np.newaxis, np.newaxis]

            # Foreground = covered by at least one predicted mask.
            y_pred_mask = (y_pred.sum(axis=0) > 0).astype(np.int32)

            # Give each mask its 1-based local id, flatten (the highest id wins
            # where masks overlap), then shift by M so ids stay unique.
            y_pred *= b_ids
            y_pred = y_pred.max(axis=0) + M * y_pred_mask

            # Never let the offset move backwards: a window in which SAM
            # predicted nothing has max() == 0 and would otherwise reset M,
            # making later windows reuse ids that were already assigned.
            M = max(M, int(y_pred.max()))

            # Windows overlap, so this overwrites what earlier windows wrote there.
            full_mask[hs:he, ws:we] = y_pred

            gc.collect()
            torch.cuda.empty_cache()

# A disabled post-processing/export variant used to sit here as a bare
# triple-quoted string; as the last expression of the cell it was echoed as the
# cell's output. In short it removed the padding via rrp_info['pads'], resized
# the mask back to (WIDTH_orig, HEIGHT_orig), ran
# post_process(full_mask, thresh=0.5, thresh_b=0.5, mina=100, mina_b=50),
# and exported the instances dissolved by 'id' to f'{out_dir}/{filename}'.

In [None]:
# Remove the padding that ratio_resize_pad added, using the pad sizes it
# reported in rrp_info.
h,w = full_mask.shape[:2]
t,b,l,r = rrp_info['pads']  # used below as top, bottom, left, right pad widths
orig_size = rrp_info['orig_size']  # not used in this cell

# NOTE(review): full_mask is overwritten in place, so running this cell a
# second time crops the already-cropped mask again.
full_mask = full_mask[t:h-b,l:w-r]

In [None]:
# Vectorise the instance mask into georeferenced polygons and write them out.

# The mask is on the grid of the image that was downsampled before inference,
# which is coarser than the source raster. Rescale the raster's affine
# transform to the mask's pixel size; using raster_file.transform unchanged
# would place the polygons at the wrong scale.
mask_height, mask_width = full_mask.shape[:2]
mask_transform = raster_file.transform * raster_file.transform.scale(
    raster_file.width / mask_width,
    raster_file.height / mask_height,
)

shapes = rasterio.features.shapes(full_mask, mask=None, transform=mask_transform)

# Keep every region except background (value 0).
geometry = [shape(shapedict) for shapedict, value in shapes if value != 0]

full_gdf = gpd.GeoDataFrame({'geometry': geometry}, crs=raster_file.crs)

# Nothing else in the notebook creates the output directory.
os.makedirs(out_dir, exist_ok=True)

# NOTE(review): `filename` still ends in ".tif" although vector data is
# written here — confirm this output name is intended.
full_gdf.to_file(f'{out_dir}/{filename}')

In [None]:
# NOTE(review): this always raises ZeroDivisionError. Presumably a deliberate
# stop so that "Run All" halts here — confirm, or remove before sharing.
5/0