## Tutorial para usar SAMLoRA (modelo do SAM com LoRA)
- No Minerva-Dev, mude para a branch "141-feature-request-add-sam-segment-anything-model-to-minerva"
- Execute as células abaixo em ordem; elas devem funcionar.

### Modificações nesse notebook
- Nos caminhos (dataset e weights), removi o `Path`, pois uso a pathlib diretamente nas classes que precisam dela
- Não uso transforms, então comentei
- Não uso ParihakaModule, então comentei
- Adicionei meu próprio Dataset e Module
- No evaluate model, mudei para pegar corretamente o retorno das masks que o SAM retorna e apliquei o softmax (como sempre fiz)
    - OBS: nessa célula comentei a parte que gera prints e plots por 2 motivos. Primeiro, como eu faço patches, todas as imagens plotadas seriam patches e não seções do volume. Segundo, como os patches são 255x255, há muitas amostras e o processo é demorado; se eu mandar plotar tudo, o VS Code trava.

In [1]:
import os
from pathlib import Path
from typing import Optional
import numpy as np
import lightning as L
import torch
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader

from minerva.models.nets.image.segment_anything.sam_lora import SAMLoRA
from minerva.data.datasets.supervised_dataset import SimpleDataset
from minerva.data.readers.png_reader import PNGReader
from minerva.data.readers.tiff_reader import TiffReader
from minerva.transforms.transform import _Transform, TransformPipeline
from minerva.data.readers.reader import _Reader

from torchmetrics import JaccardIndex

import tqdm

from typing import List, Optional, Tuple

import gc



## Variables

In [2]:
# Run configuration shared by the cells below.
# model_name and dataset_name are only used to build the output directory path.
model_name = "SAM-ViT_B"
dataset_name = "seam_ai"
# Root folders of the images and of the annotations (absolute, machine-specific paths).
image_dir = "/workspaces/Minerva-Discovery/shared_data/seam_ai_datasets/seam_ai/images"
annotations_dir = "/workspaces/Minerva-Discovery/shared_data/seam_ai_datasets/seam_ai/annotations"
# image_height is used as the patch size and as the model image size further down;
# image_width and image_channels are only referenced from commented-out code in this notebook.
image_height = 255
image_width = 255
image_channels = 1
batch_size = 1      # Please set this to 1
num_classes = 6
# NOTE(review): predict_on_partition is not referenced by any other cell visible in this notebook.
predict_on_partition = "test"
ckpt_path = "/workspaces/Minerva-Discovery/my_experiments/sam_v1/checkpoints/final_train-raio-1.0-2024-11-27-epoch=11-val_loss=0.41.ckpt"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Only referenced from the commented-out plotting code in the evaluation cell.
print_crosslines_every = 50
# Results are written under ./logs/<dataset_name>/<model_name>/ (created if missing).
output_dir = Path(f"./logs/{dataset_name}/{model_name}/")
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Output directory: {output_dir}")

Output directory: logs/seam_ai/SAM-ViT_B


## Helper functions

In [3]:
def plot_images(
    images,
    plot_title=None,
    subplot_titles=None,
    cmaps=None,
    filename=None,
    x_label=None,
    y_label=None,
    height=5,
    width=5,
    show=False
):
    """Plot a sequence of 2D images side by side in a single row.

    Parameters
    ----------
    images : sequence
        Images to draw with ``imshow``; one subplot per image.
    plot_title : str, optional
        Title for the whole figure.
    subplot_titles : sequence of str, optional
        One title per image. Defaults to no titles.
    cmaps : sequence of str, optional
        One colormap name per image. Defaults to "gray" for every image.
    filename : str or path-like, optional
        If given, the figure is saved to this path.
    x_label, y_label : str, optional
        Axis labels applied to every subplot.
    height, width : number
        Size of each subplot; the figure is ``width * len(images)`` wide.
    show : bool
        If True the figure is shown, otherwise it is closed.

    Raises
    ------
    ValueError
        If ``images`` is empty.

    Notes
    -----
    Images, titles and colormaps are paired with ``zip``, so plotting stops at
    the shortest of the three sequences.
    """
    num_images = len(images)
    if num_images == 0:
        raise ValueError("plot_images requires at least one image")

    # squeeze=False always yields a 2D array of Axes, so a single image works
    # the same way as several (otherwise a bare Axes object is returned and
    # cannot be iterated).
    fig, axs = plt.subplots(
        1, num_images, figsize=(width * num_images, height), squeeze=False
    )
    axs = axs.ravel()

    # Set overall plot title if provided
    if plot_title is not None:
        fig.suptitle(plot_title, fontsize=16)

    # Ensure subplot_titles and cmaps are lists with correct lengths
    if subplot_titles is None:
        subplot_titles = [None] * num_images
    if cmaps is None:
        cmaps = ["gray"] * num_images

    # Plot each image in its respective subplot
    for img, ax, title, cmap in zip(images, axs, subplot_titles, cmaps):
        im = ax.imshow(img, cmap=cmap)

        # Set title for each subplot if provided
        if title is not None:
            ax.set_title(title)

        # Add a colorbar for each subplot
        fig.colorbar(im, ax=ax)

        # Set x and y labels if provided
        if x_label:
            ax.set_xlabel(x_label)
        if y_label:
            ax.set_ylabel(y_label)

    # Adjust layout to fit titles, labels, and colorbars
    fig.tight_layout()

    # Save the figure if filename is provided
    if filename is not None:
        fig.savefig(filename, bbox_inches="tight")
        print(f"Figure saved as '{filename}'")

    # Show the plot, or release this specific figure so repeated calls do not
    # accumulate open figures.
    if show:
        plt.show()
    else:
        plt.close(fig)

In [4]:
import hashlib
from pathlib import Path

def hash_file(filepath):
    """Return the SHA-256 hex digest of a file's contents.

    Parameters
    ----------
    filepath : str or pathlib.Path
        Path of the file to hash. Plain strings are accepted as well as
        ``Path`` objects.

    Returns
    -------
    str
        Hexadecimal SHA-256 digest of the file contents.
    """
    hasher = hashlib.sha256()
    # Path(...) is a no-op for Path inputs and lets callers pass plain strings.
    with Path(filepath).open('rb') as file:
        while chunk := file.read(8192):  # Read in 8 KB chunks
            hasher.update(chunk)
    return hasher.hexdigest()

def hash_folder(folder_path):
    """Return a SHA-256 hex digest summarizing a folder's contents and layout.

    Every entry found recursively under ``folder_path`` is visited in sorted
    order. For regular files the hex digest of the file contents is fed into
    the folder hash; for every entry (file or directory) its path relative to
    the folder is fed in as well, so renames and moves change the result.
    """
    root = Path(folder_path)
    entries = sorted(root.rglob("*"))  # files and directories, sorted for a stable order

    digest = hashlib.sha256()
    for entry in tqdm.tqdm(entries, desc="Hashing files..."):
        if entry.is_file():
            content_digest = hash_file(entry)
            digest.update(content_digest.encode("utf-8"))
        relative_name = str(entry.relative_to(root))
        digest.update(relative_name.encode("utf-8"))

    return digest.hexdigest()

In [5]:
# Fingerprint the image and annotation folders so the data used in this run can be identified.
image_folder_hash = hash_folder(image_dir)
annotations_folder_hash = hash_folder(annotations_dir)
# Only the first 8 hex characters of each digest are printed.
print(f"Image folder hash: {image_folder_hash[:8]} and Annotations folder hash: {annotations_folder_hash[:8]}")

Hashing files...: 100%|██████████| 1375/1375 [00:06<00:00, 223.38it/s]
Hashing files...: 100%|██████████| 1375/1375 [00:00<00:00, 1675.09it/s]

Image folder hash: 9799486b and Annotations folder hash: 2566b002





## Transforms (NÃO USADO)

In [6]:
# class PadCrop(_Transform):
#     """Transforms image and pads or crops it to the target size.
#     If the axis is larger than the target size, it will crop the image.
#     If the axis is smaller than the target size, it will pad the image.
#     """

#     def __init__(
#         self,
#         target_h_size: int,
#         target_w_size: int,
#         padding_mode: str = "reflect",
#         seed: int | None = None,
#         constant_values: int = 0,
#     ):
#         """
#         Initializes the transformation with target sizes, padding mode, and RNG seed.

#         Parameters:
#         - target_h_size (int): The target height size.
#         - target_w_size (int): The target width size.
#         - padding_mode (str): The padding mode to use (default is "reflect").
#         - seed (int): Seed for random number generator to make cropping reproducible.
#         """
#         self.target_h_size = target_h_size
#         self.target_w_size = target_w_size
#         self.padding_mode = padding_mode
#         self.rng = np.random.default_rng(
#             seed
#         )  # Random number generator with the provided seed
#         self.constant_values = constant_values

#     def __call__(self, x: np.ndarray) -> np.ndarray:
#         h, w = x.shape[:2]
#         # print(f"-> [{self.__class__.__name__}] x.shape={x.shape}")

#         # Handle height dimension independently: pad if target_h_size > h, else crop
#         if self.target_h_size > h:
#             pad_h = self.target_h_size - h
#             pad_top = pad_h // 2
#             pad_bottom = pad_h - pad_top
#             pad_args = {
#                 "array": x,
#                 "pad_width": (
#                     ((pad_top, pad_bottom), (0, 0), (0, 0))
#                     if len(x.shape) == 3
#                     else ((pad_top, pad_bottom), (0, 0))
#                 ),
#                 "mode": self.padding_mode,
#             }
#             if self.padding_mode == "constant":
#                 pad_args["constant_values"] = self.constant_values

#             x = np.pad(**pad_args)

#         elif self.target_h_size < h:
#             crop_h_start = self.rng.integers(0, h - self.target_h_size + 1)
#             x = x[crop_h_start : crop_h_start + self.target_h_size, ...]

#         # Handle width dimension independently: pad if target_w_size > w, else crop
#         if self.target_w_size > w:
#             pad_w = self.target_w_size - w
#             pad_left = pad_w // 2
#             pad_right = pad_w - pad_left

#             pad_args = {
#                 "array": x,
#                 "pad_width": (
#                     ((0, 0), (pad_left, pad_right), (0, 0))
#                     if len(x.shape) == 3
#                     else ((0, 0), (pad_left, pad_right))
#                 ),
#                 "mode": self.padding_mode,
#             }

#             if self.padding_mode == "constant":
#                 pad_args["constant_values"] = self.constant_values

#             x = np.pad(**pad_args)

#         elif self.target_w_size < w:
#             crop_w_start = self.rng.integers(0, w - self.target_w_size + 1)
#             x = x[:, crop_w_start : crop_w_start + self.target_w_size, ...]

#         # Ensure channel dimension consistency
#         if len(x.shape) == 2:  # For grayscale, add a channel dimension
#             x = np.expand_dims(x, axis=2)

#         # Convert to torch tensor with format C x H x W
#         # output = torch.from_numpy(x).float()
#         x = np.transpose(x, (2, 0, 1))  # Convert to C x H x W format
#         # print(f"[{self.__class__.__name__}] x.shape={x.shape}")
#         # print(f"<- [{self.__class__.__name__}] x.shape={x.shape}")
#         return x

#     def __str__(self) -> str:
#         return f"{self.__class__.__name__}(target_h_size={self.target_h_size}, target_w_size={self.target_w_size})"

#     def __repr__(self) -> str:
#         return str(self)


# class SelectChannel(_Transform):
#     """Perform a channel selection on the input image."""

#     def __init__(self, channel: int, expand_channels: int = None):
#         """
#         Initializes the transformation with the channel to select.

#         Parameters:
#         - channel (int): The channel to select.
#         """
#         self.channel = channel
#         self.expand_channels = expand_channels

#     def __call__(self, x: np.ndarray) -> np.ndarray:
#         x = x[self.channel, ...]
#         if self.expand_channels is not None:
#             x = np.expand_dims(x, axis=self.expand_channels)
#         # print(f"[{self.__class__.__name__}] x.shape={x.shape}")
#         return x

#     def __str__(self) -> str:
#         return f"{self.__class__.__name__}(channel={self.channel})"

#     def __repr__(self) -> str:
#         return str(self)


# class CastTo(_Transform):
#     def __init__(self, dtype: type):
#         """
#         Initializes the transformation with the target data type.

#         Parameters:
#         - dtype (type): The target data type.
#         """
#         self.dtype = dtype

#     def __call__(self, x: np.ndarray) -> np.ndarray:
#         # print(f"[{self.__class__.__name__}] x.shape={x.shape}")
#         return x.astype(self.dtype)

#     def __str__(self) -> str:
#         return f"{self.__class__.__name__}(dtype={self.dtype})"

#     def __repr__(self) -> str:
#         return str(self)


# class SwapAxes(_Transform):
#     def __init__(self, source_axis: int, target_axis: int):
#         """
#         Initializes the transformation with the source and target axes.

#         Parameters:
#         - source_axis (int): The source axis to swap.
#         - target_axis (int): The target axis to swap.
#         """
#         self.source_axis = source_axis
#         self.target_axis = target_axis

#     def __call__(self, x: np.ndarray) -> np.ndarray:
#         x = np.swapaxes(x, self.source_axis, self.target_axis)
#         # print(f"[{self.__class__.__name__}] x.shape={x.shape}")
#         return x

#     def __str__(self) -> str:
#         return f"{self.__class__.__name__}(source_axis={self.source_axis}, target_axis={self.target_axis})"

#     def __repr__(self) -> str:
#         return str(self)


# class RepeatChannel(_Transform):
#     def __init__(self, repeats: int, axis: int):
#         """
#         Initializes the transformation with the number of repeats.

#         Parameters:
#         - repeats (int): The number of repeats.
#         - axis (int): The axis to repeat.
#         """
#         self.repeats = repeats
#         self.axis = axis

#     def __call__(self, x: np.ndarray) -> np.ndarray:
#         x = np.repeat(x, self.repeats, axis=self.axis)
#         # print(f"[{self.__class__.__name__}] x.shape={x.shape}")
#         return x

#     def __str__(self) -> str:
#         return f"{self.__class__.__name__}(repeats={self.repeats}, axis={self.axis})"

#     def __repr__(self) -> str:
#         return str(self)


# class ExpandDims(_Transform):
#     def __init__(self, axis: int):
#         """
#         Initializes the transformation with the axis to expand.

#         Parameters:
#         - axis (int): The axis to expand.
#         """
#         self.axis = axis

#     def __call__(self, x: np.ndarray) -> np.ndarray:
#         x = np.expand_dims(x, axis=self.axis)
#         # print(f"[{self.__class__.__name__}] x.shape={x.shape}")
#         return x

#     def __str__(self) -> str:
#         return f"{self.__class__.__name__}(axis={self.axis})"

#     def __repr__(self) -> str:
#         return str(self)

## Custom Dataset (Filipe)

In [7]:
class SupervisedDatasetPatches(SimpleDataset):
    """Supervised dataset that serves square patches cut from image/label pairs.

    The first reader provides the input images, indexed as (H, W, C); the
    second provides the label maps, indexed as (H, W). Every image is tiled
    on a regular grid with the configured patch size and stride, and each
    dataset item is one (input patch, label patch) pair.

    Note: ``transforms`` is forwarded to the parent class but is not applied
    by ``__getitem__`` of this class.
    """

    def __init__(self, readers: List[_Reader], transforms: Optional[_Transform] = None, patch_size: int = 255, stride: int = 32):
        """Adds support for splitting images into patches.

        Parameters
        ----------
        readers: List[_Reader]
            List of data readers. It must contain exactly 2 readers.
            The first reader for the input data and the second reader for the
            target data.
        transforms: Optional[_Transform]
            Optional data transformation pipeline.
        patch_size: int
            Size of the patches into which the images will be divided.
        stride: int
            Stride used to extract patches from images.
        Raises
        -------
            AssertionError: If the number of readers is not exactly 2.
            ValueError: If ``patch_size`` or ``stride`` is not positive.
        """
        super().__init__(readers, transforms)

        # Validate before doing any work that depends on the readers.
        assert (
            len(self.readers) == 2
        ), "SupervisedDatasetPatches requires exactly 2 readers"
        if patch_size <= 0 or stride <= 0:
            raise ValueError(
                f"patch_size and stride must be positive, got patch_size={patch_size}, stride={stride}"
            )

        self.patch_size = patch_size
        self.stride = stride
        self._patch_indices = []
        self._precompute_patch_indices()

    def _precompute_patch_indices(self):
        """Precomputes the (image index, patch index) pair of every patch."""
        self._patch_indices = []
        for img_idx in range(len(self.readers[0])):
            # Read the image only to get its size and derive the patch grid.
            image = self.readers[0][img_idx]
            h, w = image.shape[:2]
            # Clamp to zero: for an image smaller than the patch both counts
            # can be negative, and their product would be a bogus positive.
            num_patches_h = max(0, (h - self.patch_size) // self.stride + 1)
            num_patches_w = max(0, (w - self.patch_size) // self.stride + 1)
            self._patch_indices.extend(
                (img_idx, patch_idx)
                for patch_idx in range(num_patches_h * num_patches_w)
            )

    def _extract_single_patch(self, data, patch_idx, patch_size=255, stride=32, img_type='image'):
        """Cut one patch out of ``data``.

        Parameters
        ----------
        data : np.ndarray
            Full image (H, W, C) when ``img_type == 'image'``, otherwise a
            label map (H, W).
        patch_idx : int
            Row-major index of the patch inside the image's patch grid.
        patch_size : int
            Side length of the patch.
        stride : int
            Step between consecutive patches.
        img_type : str
            'image' for input data; any other value is treated as a label.

        Returns
        -------
        np.ndarray
            float32 patch in (C, H, W) layout for images, int64 patch in
            (H, W) layout for labels.
        """
        if img_type == 'image':  # input images are (h, w, c)
            _, w, _ = data.shape
        else:  # labels are (h, w)
            _, w = data.shape
        num_patches_w = (w - patch_size) // stride + 1
        row = patch_idx // num_patches_w  # patch row in the grid
        col = patch_idx % num_patches_w  # patch column in the grid
        i, j = row * stride, col * stride  # top-left pixel of the patch
        patch = data[i:i + patch_size, j:j + patch_size]
        if img_type == 'image':
            return patch.transpose(2, 0, 1).astype(np.float32)  # (C H W)
        else:
            return patch.astype(np.int64)

    def __len__(self):
        """Returns the total number of patches."""
        return len(self._patch_indices)

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        """Load data and return a single (input patch, label patch) pair."""
        img_idx, patch_idx = self._patch_indices[index]
        input_data = self.readers[0][img_idx]
        target_data = self.readers[1][img_idx]

        # Use the same patch size and stride the index was built with, so the
        # patch grid here matches the one in _precompute_patch_indices.
        input_patch = self._extract_single_patch(
            input_data, patch_idx, self.patch_size, self.stride, img_type='image'
        )
        target_patch = self._extract_single_patch(
            target_data, patch_idx, self.patch_size, self.stride, img_type='label'
        )
        return input_patch, target_patch

## Parihaka DataModule (Filipe)

In [8]:
class PatchingModule(L.LightningDataModule):
    """LightningDataModule that serves patches from TIFF images and PNG labels.

    Each partition ("train", "val", "test") is read from
    ``<train_path>/<partition>`` with ``TiffReader`` and from
    ``<annotations_path>/<partition>`` with ``PNGReader``. Every image is
    min-max normalized individually and kept in memory as a list; patches are
    produced by ``SupervisedDatasetPatches``.
    """

    def __init__(
        self,
        train_path: str,
        annotations_path: str,
        patch_size: int = 255,
        stride: int = 32,
        batch_size: int = 8,
        transforms: Optional[_Transform] = None,
        num_workers: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        train_path : str
            Root folder of the images; despite the name it must contain the
            "train", "val" and "test" sub-folders.
        annotations_path : str
            Root folder of the labels, with the same sub-folders.
        patch_size : int
            Side length of the extracted patches.
        stride : int
            Step between consecutive patches.
        batch_size : int
            Batch size used by every dataloader.
        transforms : Optional[_Transform]
            Forwarded to ``SupervisedDatasetPatches``.
        num_workers : Optional[int]
            Number of dataloader workers. ``None`` means "use all CPUs";
            ``0`` is honoured and loads data in the main process.
        """
        super().__init__()
        self.train_path = Path(train_path)
        self.annotations_path = Path(annotations_path)
        self.transforms = transforms
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.stride = stride
        # Only fall back when the caller gave nothing: 0 is a valid value.
        # os.cpu_count() may return None, in which case 0 workers are used.
        if num_workers is None:
            num_workers = os.cpu_count() or 0
        self.num_workers = num_workers

        self.datasets = {}

    def normalize_data(self, data, target_min=-1, target_max=1):
        """Function responsible for normalizing images in the range (-1,1)

        Parameters
        ----------
        data : np.ndarray
            Sample (image), with 3 channels
        target_min : int
            Min value of target to normalize data.
        target_max : int
            Max value of target to normalize data.

        Returns
        -------
        np.ndarray
            Sample (image) normalized. A constant image (max == min) maps to
            ``target_min`` everywhere instead of producing NaNs.
        """
        data = np.asarray(data)
        # Integer images would wrap around in the arithmetic below.
        if not np.issubdtype(data.dtype, np.floating):
            data = data.astype(np.float32)
        data_min, data_max = data.min(), data.max()
        value_range = data_max - data_min
        if value_range == 0:
            value_range = 1  # avoid 0/0 on constant images
        return target_min + (data - data_min) * (target_max - target_min) / value_range

    def _build_dataset(self, partition: str):
        """Read, normalize and wrap one partition in a patch dataset."""
        image_reader = [
            self.normalize_data(image)
            for image in TiffReader(self.train_path / partition)
        ]
        label_reader = PNGReader(self.annotations_path / partition)
        return SupervisedDatasetPatches(
            readers=[image_reader, label_reader],
            transforms=self.transforms,
            patch_size=self.patch_size,
            stride=self.stride,
        )

    def _make_dataloader(self, partition: str, shuffle: bool, drop_last: bool):
        """Create the DataLoader of an already set-up partition."""
        return DataLoader(
            self.datasets[partition],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            pin_memory=True,
            drop_last=drop_last,
        )

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        "fit" builds "train" and "val", "validate" builds "val" only, and
        "test" / "predict" both build the test partition and share one
        dataset object.

        Raises
        ------
        ValueError
            If ``stage`` is not one of "fit", "validate", "test", "predict".
        """
        if stage == "fit":
            self.datasets["train"] = self._build_dataset("train")
            self.datasets["val"] = self._build_dataset("val")

        elif stage == "validate":
            self.datasets["val"] = self._build_dataset("val")

        elif stage == "test" or stage == "predict":
            test_dataset = self._build_dataset("test")
            self.datasets["test"] = test_dataset
            self.datasets["predict"] = test_dataset

        else:
            raise ValueError(f"Invalid stage: {stage}")

    def train_dataloader(self):
        """Shuffled training loader; the last incomplete batch is dropped."""
        return self._make_dataloader("train", shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Validation loader, in dataset order."""
        return self._make_dataloader("val", shuffle=False, drop_last=False)

    def test_dataloader(self):
        """Test loader, in dataset order."""
        return self._make_dataloader("test", shuffle=False, drop_last=False)

    def predict_dataloader(self):
        """Prediction loader, in dataset order."""
        return self._make_dataloader("predict", shuffle=False, drop_last=False)

    def __str__(self) -> str:
        return f"""DataModule
        Data: {self.train_path}
        Annotations: {self.annotations_path}
        Batch size: {self.batch_size}"""

    def __repr__(self) -> str:
        return str(self)

## Parihaka DataModule Definition (NÃO USADO)

In [9]:
# class GenericParihakaDataModule(L.LightningDataModule):
#     class Identity(_Transform):
#         def __call__(self, x: np.ndarray) -> np.ndarray:
#             return x
    
#     def __init__(
#         self,
#         root_data_dir: str,
#         root_annotation_dir: str,
#         image_transforms: TransformPipeline,
#         label_transforms: TransformPipeline,
#         batch_size: int = 1,
#         num_workers: Optional[int] = None,
#         predict_on: str = "test",
#     ):
#         assert predict_on in ["test", "val", "train"]
        
#         super().__init__()
#         self.root_data_dir = Path(root_data_dir)
#         self.root_annotation_dir = Path(root_annotation_dir)
#         self.image_transforms = image_transforms or self.Identity()
#         self.label_transforms = label_transforms or self.Identity()
#         self.batch_size = batch_size
#         self.num_workers = (
#             num_workers if num_workers is not None else os.cpu_count()
#         )
#         self.predict_on = predict_on
#         self.datasets = {}

#     def _create_dataset(self, partition: str):
#         img_reader = TiffReader(str(self.root_data_dir / partition))
#         label_reader = PNGReader(str(self.root_annotation_dir / partition))
#         return SimpleDataset(
#             readers=[img_reader, label_reader],
#             transforms=[self.image_transforms, self.label_transforms],
#         )
        
#     def _get_dataloader(self, partition: str, shuffle: bool):
#         return DataLoader(
#             self.datasets[partition],
#             batch_size=self.batch_size,
#             num_workers=self.num_workers, # type: ignore
#             shuffle=shuffle,
#         )

#     def setup(self, stage=None):        
#         if stage == "fit":
#             self.datasets["train"] = self._create_dataset("train")
#             self.datasets["val"] = self._create_dataset("val")
#         elif stage == "test":
#             self.datasets["test"] = self._create_dataset("test")
#         elif stage == "predict":
#             self.datasets["predict"] = self._create_dataset(self.predict_on)
#         else:
#             raise ValueError(f"Invalid stage: {stage}")

#     def train_dataloader(self):
#         return self._get_dataloader("train", shuffle=True)

#     def val_dataloader(self):
#         return self._get_dataloader("val", shuffle=False)

#     def test_dataloader(self):
#         return self._get_dataloader("test", shuffle=False)

#     def predict_dataloader(self):
#         return self._get_dataloader("predict", shuffle=False)

#     def __str__(self) -> str:
#         return f"""DataModule
#     Data: {self.root_data_dir}
#     Annotations: {self.root_annotation_dir}
#     Batch size: {self.batch_size}"""

#     def __repr__(self) -> str:
#         return str(self)
    

# Helper functions (if needed)
def get_train_dataloader(data_module):
    """Run the "fit" setup on ``data_module`` and return its training dataloader."""
    stage = "fit"
    data_module.setup(stage)
    loader = data_module.train_dataloader()
    return loader

def get_val_dataloader(data_module):
    """Run the "fit" setup on ``data_module`` and return its validation dataloader."""
    stage = "fit"
    data_module.setup(stage)
    loader = data_module.val_dataloader()
    return loader

def get_test_dataloader(data_module):
    data_module.setup("test")
    return data_module.test_dataloader()

def get_predict_dataloader(data_module):
    """Run the "predict" setup on ``data_module`` and return its prediction dataloader."""
    stage = "predict"
    data_module.setup(stage)
    loader = data_module.predict_dataloader()
    return loader

## Transforms instantiation (NÃO USADO)

In [10]:
# image_transforms = []
# image_transforms.append(SwapAxes(0, -1))
# image_transforms.append(SelectChannel(0))
# image_transforms.append(SwapAxes(0, 1))
# image_transforms.append(PadCrop(image_height, image_width, padding_mode="reflect", seed=42))
# if image_channels > 1:
#     image_transforms.append(RepeatChannel(image_channels, axis=0))
# image_transforms.append(CastTo(np.float32))


# label_transforms = []
# label_transforms.append(PadCrop(image_height, image_width, padding_mode="reflect", seed=42))
# label_transforms.append(CastTo(np.float32))

# print(f"Image transforms: {image_transforms}")
# print(f"Label transforms: {label_transforms}")

## Defining Data Module

In [11]:
# data_module = GenericParihakaDataModule(
#     root_data_dir=str(image_dir),
#     root_annotation_dir=str(annotations_dir),
#     image_transforms=TransformPipeline(image_transforms),
#     label_transforms=TransformPipeline(label_transforms),
#     batch_size=batch_size,
# )

# Build the patch-based data module used in the rest of the notebook.
# Patches are square with side image_height; stride=32 matches the default of PatchingModule.
data_module = PatchingModule(
    train_path=image_dir,
    annotations_path=annotations_dir,
    patch_size=image_height,
    stride=32,
    batch_size=batch_size
)

# Last expression of the cell: displays the module summary.
data_module

DataModule
        Data: /workspaces/Minerva-Discovery/shared_data/seam_ai_datasets/seam_ai/images
        Annotations: /workspaces/Minerva-Discovery/shared_data/seam_ai_datasets/seam_ai/annotations
        Batch size: 1

In [12]:
# Sanity check: fetch one training batch (this runs the "fit" setup) and inspect its shapes.
train_batch_x, train_batch_y = next(iter(get_train_dataloader(data_module)))
print(f"Train batch X shape: {train_batch_x.shape}")
print(f"Train batch Y shape: {train_batch_y.shape}")

Train batch X shape: torch.Size([1, 3, 255, 255])
Train batch Y shape: torch.Size([1, 255, 255])


In [13]:
# Report the batch layout: size, channels, height and width (the message itself is in Portuguese).
print(f"O Batch (de tamanho {train_batch_x.shape[0]}) possui: {train_batch_x.shape[1]} canais, {train_batch_x.shape[2]} altura e {train_batch_x.shape[3]} largura.")

O Batch (de tamanho 1) possui: 3 canais, 255 altura e 255 largura.


# Define and load model here

In [14]:
# class DummyModel(torch.nn.Module):
#     def __init__(self, num_classes: int):
#         super().__init__()
#         self.num_classes = num_classes

#     def forward(self, x):
#         y_hat = torch.randint_like(x, 0, self.num_classes)
#         y_hat = torch.randn(x.shape[0], self.num_classes, x.shape[2], x.shape[3], device=x.device)
#         return y_hat

def get_model(ckpt_path) -> torch.nn.Module:
    """Load a SAMLoRA model from the checkpoint at ``ckpt_path``.

    The image size and the number of classes come from the notebook-level
    ``image_height`` and ``num_classes`` variables.
    """
    load_kwargs = dict(
        checkpoint_path=ckpt_path,
        image_size=image_height,
        # One less than num_classes: per the original author's note, SAM adds
        # +1 internally for the background, giving 6 classes in total.
        num_classes=num_classes - 1,
        alpha=1,
        rank=4,
    )
    return SAMLoRA.load_from_checkpoint(**load_kwargs)

# Instantiate the model from the checkpoint configured in the variables cell.
model = get_model(ckpt_path)

Focal loss alpha=0.25, will shrink the impact in background


## Evaluate model

In [15]:
# Evaluate the model on the test partition, one patch at a time.
# The IoU is computed per patch and collected in metric_values; the mean over
# all patches is taken in the next cell.
data_module.setup("test")

# Use the device chosen in the variables cell instead of assuming CUDA.
miou_metric = JaccardIndex(task="multiclass", num_classes=num_classes).to(
    device
)
metric_values = []

model.eval()
model = model.to(device)

curent_index = 0

test_loader = data_module.test_dataloader()

# Inference only: disabling autograd avoids building a graph for every patch.
with torch.no_grad():
    # tqdm wraps the loader itself so the progress bar knows the total.
    for batch_x, batch_y in tqdm.tqdm(test_loader, desc="Testing", leave=True):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        batch_y_hat = model.forward(batch_x, multimask_output=True, image_size=model.image_size)
        # Class scores are read from the 'masks' entry of the model output;
        # softmax over dim 1 followed by argmax yields one class id per pixel.
        probs = torch.softmax(batch_y_hat['masks'], dim=1)
        batch_y_hat = torch.argmax(probs, dim=1).squeeze(1)
        # print(batch_x.shape, batch_y.shape, batch_y_hat.shape)

        for x, y, y_hat in zip(batch_x, batch_y, batch_y_hat):
            # print(x.shape, y.shape, y_hat.shape)
            curent_index += 1
            res = miou_metric(
                y_hat.unsqueeze(0), y.unsqueeze(0)
            ).item()  # re-add batch dimension (unsqueeze(0))
            metric_values.append(res)

            # Plotting is intentionally disabled (see the notes at the top of
            # the notebook): the samples are patches and there are too many.
            # if curent_index % print_crosslines_every == 0:
            #     x = x.permute(1, 2, 0).squeeze(0).cpu().numpy()
            #     y = y.squeeze(0).cpu().numpy()
            #     y_hat = y_hat.squeeze(0).cpu().numpy()
            #     diff = (y != y_hat).astype(np.int32)
            #     print(f"Crossline {curent_index} MIOU: {res}")
            #     print(
            #         f"x.shape={x.shape}, y.shape={y.shape}, y_hat.shape={y_hat.shape}, diff.shape={diff.shape}"
            #     )
            #
            #     plot_images(
            #         images=[x, y, y_hat, diff],
            #         subplot_titles=[
            #             "Input",
            #             "Ground Truth",
            #             "Prediction",
            #             "Difference",
            #         ],
            #         cmaps=["seismic", "Accent", "Accent", "gray"],
            #         plot_title=f"{model_name} Parihaka Segmentation (crossline {curent_index}). MIOU={res:.3f}",
            #         filename=f"{output_dir}/segmentation_{curent_index}.png",
            #         show=True,
            #     )

Testing: 52800it [14:44, 59.72it/s]


In [16]:
# Average the per-patch IoU values, report the result and persist it in the output directory.
mean_iou = np.mean(metric_values)
print(f"Mean IoU: {mean_iou:.4f}")
(output_dir / "mean_iou.txt").write_text(f"{mean_iou:.4f}")

Mean IoU: 0.5579
