## Imports

In [1]:
import os
from pathlib import Path
import numpy as np
from typing import List, Optional, Tuple
from copy import deepcopy
import matplotlib.pyplot as plt
import pandas as pd

import cv2

import torch
from torch.nn import functional as F
from torchmetrics import JaccardIndex
from torchvision.transforms.functional import resize, to_pil_image
import pytorch_lightning as pl

from minerva.models.finetune_adapters import LoRA
from minerva.models.nets.image.sam import Sam
from minerva.data.datasets.supervised_dataset import SimpleDataset
from minerva.data.readers.png_reader import PNGReader
from minerva.data.readers.tiff_reader import TiffReader
from minerva.transforms.transform import _Transform
from minerva.data.readers.reader import _Reader
from minerva.pipelines.lightning_pipeline import SimpleLightningPipeline

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
print("PyTorch Version:", torch.__version__)
print("CUDA Version:", torch.version.cuda)
print("CUDA Available:", torch.cuda.is_available())
print("CUDA Device Count:", torch.cuda.device_count())
print("CUDA Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No CUDA Device")

PyTorch Version: 2.5.1+cu124
CUDA Version: 12.4
CUDA Available: True
CUDA Device Count: 1
CUDA Device Name: NVIDIA GeForce RTX 4090


## Variables

In [3]:
# f3
train_path = "/workspaces/Minerva-Discovery/shared_data/seismic/f3_segmentation/images"
annotation_path = "/workspaces/Minerva-Discovery/shared_data/seismic/f3_segmentation/annotations"

# parihaka
# train_path = "/workspaces/Minerva-Discovery/shared_data/seam_ai_datasets/seam_ai/images"
# annotation_path = "/workspaces/Minerva-Discovery/shared_data/seam_ai_datasets/seam_ai/annotations"

# checkpoints SAM
model_name_experiment = "SAM-ViT_B-original-weights_f3_"
checkpoint_path = "/workspaces/Minerva-Discovery/shared_data/weights_sam/checkpoints_sam/sam_vit_b_01ec64.pth" # vit_b
vit_model = 'vit-b'
# checkpoint_path = "/workspaces/Minerva-Discovery/shared_data/weights_sam/checkpoints_sam/sam_vit_h_4b8939.pth" # vit_h
# vit_model = 'vit-h'

USE_ORIGINAL_WEIGHTS = False
multimask_output=False # if True, return num_classes, else, return the class with mostly iou
num_classes = 3 # num of classes for original Sam
num_points = 10 # num of prompt points

# fine_tuning & adapter
if not USE_ORIGINAL_WEIGHTS:
    num_classes = 6 # num of classes for Sam fine tuned
    vit_model = 'vit-b'
    checkpoint = "/workspaces/Minerva-Discovery/my_experiments/sam_original/notebooks/checkpoints_f3/sam-fine_tuning_&_adapter_1.0-2024-12-17-epoch=98-val_loss=0.01.ckpt"
    # checkpoint = "/workspaces/Minerva-Discovery/my_experiments/sam_original/notebooks/checkpoints_parihaka/sam-fine_tuning_&_adapter_1.0-2024-12-18-epoch=10-val_loss=0.13.ckpt"
    model_name_experiment = f"SAM-ViT_B_fine_tuning_&_adapter_f3_"
    apply_freeze={"prompt_encoder": True, "image_encoder": False, "mask_decoder": True}
    apply_adapter={"mask_decoder": LoRA}

## Transforms

In [4]:
class ResizeLongestSide:
    """
    Resizes images to the longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy array and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.
        """
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array shape Bx4. Requires the original image size
        in (H, W) format.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        # Expects an image in BCHW format. May not exactly match apply_image.
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], self.target_length
        )
        coords = deepcopy(coords).to(torch.float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.
        """
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

## Predict

In [5]:
if not USE_ORIGINAL_WEIGHTS:
    print("usando Sam ajustado")
    # modelo ajustado
    model = Sam.load_from_checkpoint(
        checkpoint_path=checkpoint,
        vit_type=vit_model,
        num_multimask_outputs=num_classes,
        iou_head_depth=num_classes,
        apply_freeze=apply_freeze,
        apply_adapter=apply_adapter,
        train_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=num_classes)},
        val_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=num_classes)},
        test_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=num_classes)}
    )
else:
    print("usando Sam ORIGINAL")
    # modelo com pesos originais
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print("Using device: ", device)
    model = Sam(
        vit_type=vit_model,
        checkpoint=checkpoint_path,
        num_multimask_outputs=num_classes,
        iou_head_depth=num_classes
    ).to(device)

model.eval()

train_img_reader = TiffReader(Path(train_path) / 'test') # Configura os dados para inferência
train_label_reader = PNGReader(Path(annotation_path) / 'test') # Configura os dados para inferência

miou_metric = JaccardIndex(task="multiclass", num_classes=num_classes) # Inicializando a métrica de mIoU

usando Sam ajustado
Prompt Encoder freeze!
Mask Decoder freeze!
LoRA applied in Mask Decoder!


In [6]:
def plot_all(image, label, pred, diff, score, point_coords, point_labels):
    """
    Plota as imagens lado a lado: imagem original, label, predição, diff.
    Pontos acumulados são exibidos sobre as imagens.
    """
    num_subplots = 4  # Número de subplots: imagem original, label, pred, diff
    plt.clf()
    fig, axes = plt.subplots(1, num_subplots, figsize=(5 * num_subplots, 5))

    # Plot 1: Imagem original
    axes[0].imshow(image, cmap='gray')
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    # Plot 2: Label
    axes[1].imshow(label, cmap='gray')
    axes[1].set_title("Label")
    axes[1].axis('off')

    # Plot 3: Predição acumulada
    axes[2].imshow(pred, cmap='gray')
    axes[2].set_title(f"Pred - Score: {score}")
    axes[2].axis('off')

    # Plot 4: Diferença entre label e pred
    axes[3].imshow(diff, cmap='gray')
    axes[3].set_title("Difference (Label - Pred)")
    axes[3].axis('off')

    # Adiciona os pontos em todas as imagens
    for ax in axes:
        for (x, y), label in zip(point_coords, point_labels):
            color = 'green' if label == 1 else 'red'
            ax.scatter(x, y, color=color, s=50, edgecolors='white')

    plt.tight_layout()
    plt.show()

In [7]:
def set_points():
    # Inicialize os acumuladores como arrays vazios
    accumulated_coords = np.empty((0, 2), dtype=int)  # Nx2 array
    accumulated_labels = np.empty((0,), dtype=int)   # Array de comprimento N
    return accumulated_coords, accumulated_labels

def calculate_center_region(
        accumulated_coords, 
        accumulated_labels, 
        original_size, 
        region: np.array, 
        point_type: str, 
        min_distance: int = 10):
    """
    Calcula o centroide da maior região de pixels brancos de uma imagem binária,
    deslocando horizontalmente o ponto se ele estiver próximo demais dos acumulados.

    Args:
        region (np.array): Imagem binária com a região de interesse (pixels brancos).
        point_type (str): Tipo do ponto ('positive' ou 'negative').
        min_distance (int): Distância mínima permitida entre pontos.

    Returns:
        point_coords (np.ndarray): Array Nx2 de pontos acumulados.
        point_labels (np.ndarray): Array N de rótulos acumulados.
    """
    if not isinstance(region, np.ndarray):
        raise TypeError("region needs to be a NumPy array.")
    
    # Encontrar as componentes conectadas
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(region, connectivity=8)

    if num_labels < 2:  # Apenas fundo e nenhuma região branca
        raise ValueError("No connected white regions found in the binary image.")
    
    # Ignorar o rótulo 0 (fundo), pegar a maior componente conectada
    largest_label = 1 + np.argmax(stats[1:, cv2.CC_STAT_AREA])
    center_x, center_y = centroids[largest_label]
    center_x, center_y = int(center_x), int(center_y)
    new_coords = np.array([[center_x, center_y]])

    # Verificar se o ponto está muito próximo dos anteriores
    if accumulated_coords.shape[0] > 0:
        distances = np.sqrt(np.sum((accumulated_coords - new_coords) ** 2, axis=1))
        if np.any(distances < min_distance):
            # print("Ponto muito próximo ao anterior, deslocando horizontalmente...")
            
            # Tentar deslocar o ponto horizontalmente dentro da região branca
            region_height, region_width = region.shape
            # Tenta deslocar horizontalmente usando o min_distance
            for delta_x in range(min_distance, region_width, min_distance):  # Incrementa em min_distance
                candidate_x_right = center_x + delta_x
                candidate_x_left = center_x - delta_x

                # Verifica primeiro para direita, depois para esquerda
                if candidate_x_right < region_width and region[center_y, candidate_x_right] > 0:
                    center_x = candidate_x_right
                    break
                elif candidate_x_left >= 0 and region[center_y, candidate_x_left] > 0:
                    center_x = candidate_x_left
                    break

            new_coords = np.array([[center_x, center_y]])

    # Definir o rótulo (positivo ou negativo)
    if point_type == 'positive':
        new_labels = np.array([1])
    elif point_type == 'negative':
        new_labels = np.array([0])
    else:
        raise ValueError("Invalid point_type. Must be 'positive' or 'negative'.")

    # Acumular os resultados
    accumulated_coords = np.vstack([accumulated_coords, new_coords])
    accumulated_labels = np.hstack([accumulated_labels, new_labels])

    # convertendo para o formato do Sam()
    transform = ResizeLongestSide(model.model.image_encoder.img_size)
    point_coords = transform.apply_coords(accumulated_coords, original_size)
    coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=model.device)
    labels_torch = torch.as_tensor(accumulated_labels, dtype=torch.int, device=model.device)
    coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]

    return accumulated_coords, accumulated_labels, coords_torch, labels_torch

def calculate_diff_label_pred(label:np.array, pred:np.array):
    """
    Calcula a diferença entre duas imagens binárias e determina se a área externa ou interna é maior.

    Args:
        label (np.array): Imagem binária de referência (label).
        pred (np.array): Imagem binária predita (pred).

    Returns:
        diff_colored (np.array): Imagem colorida representando as diferenças.
        point_type (str): 'negative' se a área externa for maior, 'positive' se a interna for maior.
    """
    if label.shape != pred.shape:
        raise ValueError("Label and Pred images have differents shapes. Check it before call calculate_dif_label_pred() function.")

    # Máscaras para regiões de diferença
    mask_outward = (label > pred)  # Diferença para fora -> Vermelho
    mask_inward = (label < pred)  # Diferença para dentro -> Azul

    area_outward = np.sum(mask_outward)
    area_inward = np.sum(mask_inward)

    diff_binary = teste1 = teste2 = np.zeros(label.shape, dtype=np.uint8) # [H,W]

    # fig, axes = plt.subplots(1, 2, figsize=(15, 5))
    # teste1[mask_outward] = 1
    # axes[0].imshow(teste1)
    # axes[0].set_title('Image 1')
    # axes[0].axis('off')
    # teste2[mask_inward] = 1
    # axes[1].imshow(teste2)
    # axes[1].set_title('Image 2')
    # axes[1].axis('off')
    # plt.tight_layout()
    # plt.show()

    # Comparar as áreas
    if area_outward > area_inward:
        diff_binary[mask_outward] = 1
        point_type = 'positive'
    else:
        diff_binary[mask_inward] = 1
        point_type = 'negative'
    
    return diff_binary, point_type

## Execute predict with prompts

In [8]:
results = pd.DataFrame(columns=['sample_id', 'facie_id', 'accumulated_point', 'iou', 'num_points'])

# for each (image, label)...
for idx, (image, label) in enumerate(zip(train_img_reader, train_label_reader)):
    num_facies = np.unique(label) # num of facies
    # print("facies: ", num_facies)

    # DEBUG
    # fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    # axes[0].imshow(image)
    # axes[1].imshow(label)
    # plt.show()
    
    # for each facie...
    for i, facie in enumerate(num_facies):
        region = np.zeros_like(label, dtype=np.uint8) # [H,W]
        region[label == facie] = 1
        real_label = region

        point_type = 'positive' # first point is positive
        accumulated_coords, accumulated_labels = set_points()

        # for each point...
        for point in range(num_points):
            # calculate center region
            accumulated_coords, accumulated_labels, _, _ = calculate_center_region(
                accumulated_coords, 
                accumulated_labels, 
                image.shape[:2],
                region=region, 
                point_type=point_type)
            
            point_coords = torch.tensor(accumulated_coords).unsqueeze(0).to(model.device)
            point_labels = torch.tensor(accumulated_labels).unsqueeze(0).to(model.device)
            # print(point_coords, point_labels)

            # convertendo para PNG (TODO SEM ISSO O SAM ORIGINAL NAO DÁ BONS RESULTADOS NO F3. NO PARIHAKA NAO TEM DIFERENCA)
            if image.dtype != np.uint8:
                tiff_image = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
            else:
                tiff_image = image
            _, png_img = cv2.imencode('.png', tiff_image)
            decoded_image = cv2.imdecode(np.frombuffer(png_img, np.uint8), cv2.IMREAD_UNCHANGED)

            batch = [{
                'image': torch.from_numpy(decoded_image).permute(2, 0, 1).float().to(model.device),
                'label': torch.from_numpy(label).to(model.device),
                'original_size': decoded_image.shape[:2],
                'point_coords': point_coords,
                'point_labels': point_labels
            }]

            # Inferência
            outputs = model(batch, multimask_output=multimask_output)

            # if multimask_output:
            #     # stack logits 'masks_logits' and 'labels' for loss and metrics function
            #     masks_logits = torch.stack([output['masks_logits'].squeeze(0) for output in outputs])  # [batch_size, num_classes, H, W]
            #     labels = torch.stack([input['label'].squeeze(0) for input in batch])  # [batch_size, H, W]

            #     for i in range(len(batch)):
            #         pred = masks_logits[i].unsqueeze(0)
            #         label = labels[i].unsqueeze(0)
            #         pred = torch.argmax(pred, dim=1, keepdim=True).squeeze(1)
            #         # (DEBUG) visualizar as imagens
            #         pred = pred.squeeze(0).cpu().detach().numpy()  # Converte para numpy para plotar
            #         label = label.squeeze(0).cpu().detach().numpy()  # Converte para numpy para plotar
            #         fig, axs = plt.subplots(1, 2, figsize=(12, 6))  # Cria 2 colunas
            #         axs[0].imshow(pred, cmap='gray')          # Exibe a máscara
            #         axs[0].set_title("Mask Logits")           # Título para a primeira imagem
            #         axs[0].axis('off')                       # Remove os eixos da imagem
            #         axs[1].imshow(label, cmap='gray')        # Exibe os labels
            #         axs[1].set_title("Labels")               # Título para a segunda imagem
            #         axs[1].axis('off')                       # Remove os eixos da imagem
            #         for ax in axs:
            #             for (x, y), label in zip(accumulated_coords, accumulated_labels):
            #                 color = 'green' if label == 1 else 'red'
            #                 ax.scatter(x, y, color=color, s=50, edgecolors='white')

            #         plt.tight_layout()
            #         plt.show()
            
            masks_logits = torch.tensor(outputs[0]['masks'].squeeze()).to(model.device)  # Remover a dimensão extra e mover para GPU
            labels = torch.tensor(real_label).to(model.device)  # Converta para tensor 2D e mova para GPU
            
            diff, new_point_type = calculate_diff_label_pred(label=real_label, pred=masks_logits.squeeze(0).cpu().numpy())

            region = diff # [H,W], atualiza para a proxima regiao
            point_type = new_point_type # 'positive' ou 'negative', atualiza para a proxima regiao

            # Calcular o mIoU para a imagem i
            iou_score = miou_metric.to(model.device)(masks_logits, labels)
            # print("IoU: ", iou_score.item())

            # salvando progresso
            new_row = pd.DataFrame([{
                'sample_id': idx,
                'facie_id': facie,
                'accumulated_point': point + 1,
                'iou': iou_score.item(),
                'num_points': num_points
            }])

            results = pd.concat([results, new_row], ignore_index=True)
            
            # plot_all(
            #     image=image,
            #     label=real_label,
            #     pred=masks_logits.squeeze(0).cpu().numpy(),  # empilha as máscaras geradas
            #     diff=diff,
            #     score=iou_score,
            #     point_coords=accumulated_coords,
            #     point_labels=accumulated_labels
            # )
            
            # break
        # break
    # break
results.to_csv(f'{model_name_experiment}iou_results.csv', index=False)

  masks_logits = torch.tensor(outputs[0]['masks'].squeeze()).to(model.device)  # Remover a dimensão extra e mover para GPU
  results = pd.concat([results, new_row], ignore_index=True)
