# In [None]:
from segment_anything import sam_model_registry, SamPredictor

import torch

# Path to the SAM ViT-H checkpoint (placeholder: replace with the real location).
sam_checkpoint = r"path\to\sam_vit_h_4b8939.pth"
# Registry key selecting the SAM backbone; must match the checkpoint above.
model_type = "vit_h"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
# Fall back to the CPU when no CUDA device is available, instead of failing
# later when the model is moved with `sam.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"


def show_mask(mask, ax, random_color=False):
    """Draw *mask* on *ax* as a translucent (alpha 0.6) colour layer.

    Args:
        mask: Array whose last two dimensions are (height, width); any leading
            dimensions must have size 1 so it can be reshaped to (h, w, 1).
        ax: Matplotlib-like axes object exposing ``imshow``.
        random_color: When True, use a random RGB colour; otherwise a fixed
            blue (30, 144, 255).
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        rgb = np.array([30 / 255, 144 / 255, 255 / 255])
    rgba = np.append(rgb, 0.6)

    height, width = mask.shape[-2:]
    # Broadcast the (h, w, 1) mask against the (1, 1, 4) colour to get RGBA.
    layer = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(layer)
    
def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on *ax* as star markers.

    Args:
        coords: (N, 2) array of (x, y) point coordinates.
        labels: Length-N array; points labelled 1 are drawn green, points
            labelled 0 are drawn red.
        ax: Matplotlib-like axes object exposing ``scatter``.
        marker_size: Marker area passed to ``scatter`` as ``s``.
    """
    common = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    # Positive points first, then negative ones (same draw order as before).
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == wanted_label]
        ax.scatter(chosen[:, 0], chosen[:, 1], color=colour, **common)
    
def show_box(box, ax):
    """Outline the box ``[x0, y0, x1, y1]`` on *ax* with a green rectangle."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent fill: outline only
        lw=2,
    )
    ax.add_patch(outline)


def process_image(image, image_dir, result_dir, sam, device):
    """Run box detection on one image and segment every detected box with SAM.

    Boxes come from ``yolov8_detection``. If none are found, the image is just
    displayed. Otherwise each box is used as a SAM prompt and the result is
    saved and shown via ``show_result_image``.

    Args:
        image: File name of the image, relative to *image_dir*; also used as
            the prefix of the output file names.
        image_dir: Directory containing *image*.
        result_dir: Directory that ``show_result_image`` writes into.
        sam: SAM model (as built with ``sam_model_registry`` in this notebook).
        device: Torch device string the model is moved to, e.g. "cuda".

    Raises:
        FileNotFoundError: If OpenCV cannot read the image file.
    """
    path_image = os.path.join(image_dir, image)
    img = cv2.imread(path_image)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f"Could not read image: {path_image}")

    bbox = np.array(yolov8_detection(path_image))
    if len(bbox) == 0:
        print("rien à voir ici")
        show_image(img)
        return

    # Moving the model and computing the image embedding (set_image) do not
    # depend on the box, so do them once per image instead of once per box.
    sam.to(device=device)
    predictor = SamPredictor(sam)
    # cv2.imread returns BGR; set_image assumes RGB unless told otherwise.
    predictor.set_image(img, image_format="BGR")

    for i, bounding in enumerate(bbox):
        masks, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=bounding[None, :],  # shape (1, 4): a single box prompt
            multimask_output=False,
        )
        show_result_image(img, masks[0], bounding, result_dir, image, i)

def show_image(img):
    """Display a BGR image (OpenCV channel order) with matplotlib, axes hidden."""
    rgb_view = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.imshow(rgb_view)
    plt.axis('off')
    plt.show()

def show_result_image(img, mask, bounding, result_dir, image, i):
    """Save the mask and an annotated overlay for one detection, then display it.

    Two PNG files are written into *result_dir*:
    ``{image}_mask_{i}.png`` (the mask as 0/255) and
    ``{image}_result_{i}.png`` (the image with a green tint inside the mask and
    the bounding box drawn).

    Args:
        img: Image in OpenCV's BGR channel order.
        mask: 2-D mask with the same height/width as *img*; truthy pixels are
            treated as part of the object.
        bounding: Box as ``[x0, y0, x1, y1]`` in pixel coordinates.
        result_dir: Output directory; it is not created here.
        image: Source image file name, used as the output file-name prefix.
        i: Detection index, used to keep output names unique per image.
    """
    mask = np.asarray(mask)

    mask_path = os.path.join(result_dir, f"{image}_mask_{i}.png")
    # Write an explicit 8-bit 0/255 image. `mask * 255` would give an int64
    # array for boolean masks and rely on OpenCV's implicit dtype conversion.
    cv2.imwrite(mask_path, np.where(mask, 255, 0).astype(np.uint8))

    mask_color = (0, 255, 0)  # green in OpenCV's BGR order
    mask_alpha = 0.5  # strength of the tint added inside the mask

    # Colour layer that is zero outside the mask. Adding it brightens the
    # green channel of masked pixels (cv2.addWeighted saturates at 255).
    mask_image = (mask_alpha * np.array(mask_color) * mask[:, :, None]).astype(np.uint8)
    result_img = cv2.addWeighted(img, 1.0, mask_image, 1.0, 0)

    cv2.rectangle(
        result_img,
        (int(bounding[0]), int(bounding[1])),
        (int(bounding[2]), int(bounding[3])),
        mask_color,
        2,
    )

    result_path = os.path.join(result_dir, f"{image}_result_{i}.png")
    cv2.imwrite(result_path, result_img)

    # result_img is BGR; matplotlib expects RGB, so convert before displaying
    # (the saved file above stays in OpenCV's BGR convention).
    plt.imshow(cv2.cvtColor(result_img, cv2.COLOR_BGR2RGB))
    plt.axis('off')
    plt.show()

# Create the "result" output directory if it does not exist yet.
result_dir = "result"
# exist_ok=True replaces the check-then-create pattern (os.path.exists
# followed by os.makedirs), which can race with another process.
os.makedirs(result_dir, exist_ok=True)

# Detect and segment every JPEG image, writing outputs into result_dir.
for image_jpg in image_files_jpg:
    print(image_jpg)
    process_image(image_jpg, images_dir, result_dir, sam, device)


# Example inputs: the orthophoto to display and the mask raster overlaid on it.
image_path = r'path\to\Orthophotos\Perigueux.tif'
# NOTE(review): the doubled "Orthophotos" path segment may be a typo; verify.
mask_path = r'path\to\Orthophotos\Orthophotos\result\mask_Perigueux.tif'

import rasterio
import matplotlib.pyplot as plt

import rasterio
import matplotlib.pyplot as plt

def show_geotiff_with_mask(image_path, mask_path):
    """Display a GeoTIFF image with a single-band mask overlaid on top.

    Args:
        image_path: Path to the image raster. If it has at least three bands,
            the first three are shown as RGB; otherwise the first band is
            shown in grayscale.
        mask_path: Path to the mask raster. Only its first band is used, and
            pixels equal to 0 are left transparent so the image stays visible
            outside the mask.
    """
    import numpy as np  # local import keeps this function self-contained

    # Read all bands of the image; rasterio returns (bands, rows, cols).
    with rasterio.open(image_path) as src:
        image = src.read()

    # Read only the first (or only) band of the mask.
    with rasterio.open(mask_path) as src:
        mask = src.read(1)

    if image.shape[0] >= 3:
        # imshow accepts only (H, W), (H, W, 3) or (H, W, 4) arrays, so keep
        # the first three bands (drops e.g. an alpha or NIR band).
        display = image[:3].transpose(1, 2, 0)
        cmap = None
    else:
        display = image[0]
        cmap = 'gray'

    plt.figure(figsize=(10, 10))
    plt.imshow(display, cmap=cmap)
    # Hide background (0) pixels; overlaying the raw mask at alpha 0.5 would
    # tint the entire image, not just the masked region.
    plt.imshow(np.ma.masked_equal(mask, 0), alpha=0.5, cmap='viridis')

    plt.title('Image GeoTIFF RVB avec masque')
    plt.axis('off')  # hide the axes
    plt.show()

# Example usage: show the example orthophoto with its mask overlaid.
show_geotiff_with_mask(image_path=image_path, mask_path=mask_path)

