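# Fine-tuning SAM on the Parihaka Dataset

This notebook fine-tunes the Segment Anything Model (SAM, ViT-B backbone) for 6-class seismic facies segmentation on the Parihaka (SEAM AI) dataset, using Minerva and PyTorch Lightning.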

## Imports

In [1]:
import os
import numpy as np
import random
from typing import List, Optional, Tuple
from common import get_data_module, get_trainer_pipeline
from functools import partial
from pathlib import Path

import torch
from torch.utils.data import DataLoader
from torchmetrics import JaccardIndex
import lightning as L

from minerva.models.nets.image.sam import Sam
from minerva.data.datasets.supervised_dataset import SimpleDataset
from minerva.data.readers.png_reader import PNGReader
from minerva.data.readers.tiff_reader import TiffReader
from minerva.transforms.transform import _Transform
from minerva.data.readers.reader import _Reader

  from .autonotebook import tqdm as notebook_tqdm


## Variables

In [2]:
# Base directory with the shared datasets. Override it with the SHARED_DATA_DIR
# environment variable instead of editing the paths below.
shared_data_dir = Path(os.environ.get("SHARED_DATA_DIR", "/workspaces/Minerva-Discovery/shared_data"))

root_data_dir = str(shared_data_dir / "seam_ai_datasets" / "seam_ai" / "images")             # TIFF images, one sub-folder per split (train/val/test)
root_annotation_dir = str(shared_data_dir / "seam_ai_datasets" / "seam_ai" / "annotations")  # PNG labels, one sub-folder per split (train/val/test)
img_size = (1008, 784)          # (height, width) every image/label is padded to (padding only, no cropping)
model_name = "sam"              # Model name (identifier used in the log directory)
dataset_name = "seam_ai"        # Dataset name (identifier used in the log directory)
single_channel = False          # If True, train with single-channel images instead of 3 channels (currently not used by the DataModule below)

seed = 42                       # Seed for reproducibility
# Seed every RNG now: the sanity-check cell draws a shuffled batch before the
# training pipeline seeds anything itself.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

log_dir = "./logs"              # Directory to save logs
batch_size = 1                  # Batch size
num_epochs = 100                # Number of epochs to train
is_debug = False                # If True, only 3 batches are processed for 3 epochs
accelerator = "gpu"             # "cpu" or "gpu"
devices = 1                     # Number of GPUs

## Data Module

In [3]:
class Padding(_Transform):
    """Pad a 2-D (H, W) or 3-D (H, W, C) array on the bottom/right up to a target size.

    The result is a channels-first float32 ``torch.Tensor`` of shape
    (C, max(H, target_h_size), max(W, target_w_size)). 2-D inputs get a single
    channel (C=1). Axes that already reach the target size are left as they
    are; nothing is cropped.
    """

    def __init__(self, target_h_size: int, target_w_size: int, mode: str = "reflect"):
        """
        Parameters
        ----------
        target_h_size : int
            Minimum output height.
        target_w_size : int
            Minimum output width.
        mode : str, default "reflect"
            Padding mode forwarded to ``numpy.pad``.
        """
        self.target_h_size = target_h_size
        self.target_w_size = target_w_size
        self.mode = mode

    def __call__(self, x: np.ndarray) -> torch.Tensor:
        """Pad ``x`` and return it as a (C, H, W) float32 tensor.

        Raises
        ------
        ValueError
            If ``x`` is neither 2-D nor 3-D.
        """
        x = np.asarray(x)
        if x.ndim == 2:
            # 2-D inputs (e.g. label masks) get a trailing channel axis so both
            # ranks share one padding path.
            x = x[:, :, np.newaxis]
        elif x.ndim != 3:
            raise ValueError(f"Padding expects a 2-D or 3-D array, got shape {x.shape}")

        h, w = x.shape[:2]
        pad_h = max(0, self.target_h_size - h)
        pad_w = max(0, self.target_w_size - w)
        padded = np.pad(x, ((0, pad_h), (0, pad_w), (0, 0)), mode=self.mode)

        # HWC -> CHW is done in NumPy before converting. Calling np.transpose on
        # a torch.Tensor only works through NumPy's __array_wrap__ fallback.
        chw = np.ascontiguousarray(padded.transpose(2, 0, 1))
        return torch.from_numpy(chw).float()

In [4]:
""" class for create dataset with SAM pattern """
class DatasetForSAM(SimpleDataset):
    """Dataset that packs an (image, label) pair into the dict format consumed by ``Sam``."""

    def __init__(
            self,
            readers: List[_Reader],
            transforms: Optional[_Transform] = None,
            transform_coords_input: Optional[dict] = None,
            multimask_output: bool = True,
    ):
        """
        Parameters
        ----------
        readers: List[_Reader]
            Exactly 2 readers. The first yields the input images and the second
            the target labels; the same index refers to the same sample.
        transforms: Optional[_Transform]
            Optional transformation pipeline, passed to ``SimpleDataset``. The
            per-reader transforms it exposes are applied in ``__getitem__``.
        transform_coords_input: Optional[dict]
            Reserved for prompt-based training and currently unused; it is only
            stored on the instance. Intended keys:
                point_coords (np.ndarray or None): Nx2 array of point prompts, each (X, Y) in pixels.
                point_labels (np.ndarray or None): length-N array of point labels; 1 = foreground, 0 = background.
        multimask_output: bool
            Copied into every sample under the ``'multimask_output'`` key.

        Raises
        ------
        AssertionError
            If ``readers`` does not hold exactly 2 readers.
        """
        super().__init__(readers, transforms)
        # Kept for the future prompt-based pipeline (see the TODO in __getitem__).
        self.transform_coords_input = transform_coords_input
        self.multimask_output = multimask_output

        assert (
            len(self.readers) == 2
        ), "DatasetForSAM requires exactly 2 readers (image and label)"

    def __getitem__(self, index: int) -> dict:
        """
        Load sample ``index`` and return it as a dict in SAM's input format.

        Keys
        ----
        'image': output of the first reader after its transform (expected 3xHxW).
        'label': output of the second reader after its transform.
        'original_size': (H, W) read from dims 1 and 2 of the *transformed*
            image. The post-transform shape must be used here; using the
            pre-transform size was observed to raise an error.
        'multimask_output': the flag given at construction.

        The prompt keys SAM accepts ('point_coords', 'point_labels', 'boxes',
        'mask_inputs') are not produced yet.
        """
        samples = []
        for reader, transform in zip(self.readers, self.transforms):
            sample = reader[index]
            if transform is not None:
                sample = transform(sample)
            samples.append(sample)
        image, label = samples[0], samples[1]

        # TODO: prompt-based training. This means:
        #   - generate random points on the facies;
        #   - map them with transform_coords_input['point_coords'].apply_coords(...);
        #   - add 'point_coords' / 'point_labels' as tensors.
        # Only add these keys once real values exist; setting them to None was
        # observed to break the DataLoader.
        return {
            "image": image,
            "label": label,
            "original_size": (int(image.shape[1]), int(image.shape[2])),
            "multimask_output": self.multimask_output,
        }

In [5]:
""" class for create data module """
class DataModule(L.LightningDataModule):
    """Build ``DatasetForSAM`` splits from a TIFF image tree and a PNG label tree.

    Expected layout, with one sub-directory per split::

        <train_path>/{train,val,test}/        TIFF images (read with TiffReader)
        <annotations_path>/{train,val,test}/  PNG labels  (read with PNGReader)

    Batches are returned as plain lists of per-sample dicts
    (see ``custom_collate_fn``).
    """

    def __init__(
        self,
        train_path: str,
        annotations_path: str,
        transforms: Optional[_Transform] = None,
        transform_coords_input: Optional[dict] = None,
        multimask_output: bool = True,
        batch_size: int = 1,
        data_ratio: float = 1.0,
        num_workers: Optional[int] = None,
        seed: Optional[int] = None,
    ):
        """
        Parameters
        ----------
        train_path : str
            Root of the image tree (contains ``train``, ``val`` and ``test``).
        annotations_path : str
            Root of the label tree (same sub-directories).
        transforms : Optional[_Transform]
            Passed to every ``DatasetForSAM``.
        transform_coords_input : Optional[dict]
            Passed to every ``DatasetForSAM`` (reserved for prompts).
        multimask_output : bool
            Passed to every ``DatasetForSAM``.
        batch_size : int
            Batch size of every DataLoader.
        data_ratio : float
            Fraction of the *training* split to keep, as a random subset.
            Values >= 1 keep every sample. Must be > 0.
        num_workers : Optional[int]
            DataLoader worker processes. Defaults to ``os.cpu_count()``, or 0
            when the CPU count is unknown.
        seed : Optional[int]
            Seed for the ``data_ratio`` subsampling. ``None`` uses the global
            ``random`` module state.

        Raises
        ------
        ValueError
            If ``data_ratio`` is not positive.
        """
        super().__init__()
        if data_ratio <= 0:
            raise ValueError(f"data_ratio must be > 0, got {data_ratio}")
        self.train_path = Path(train_path)
        self.annotations_path = Path(annotations_path)
        self.transforms = transforms
        self.transform_coords_input = transform_coords_input
        self.multimask_output = multimask_output
        self.batch_size = batch_size
        self.data_ratio = data_ratio
        # os.cpu_count() can return None; DataLoader needs an int.
        self.num_workers = (
            num_workers if num_workers is not None else (os.cpu_count() or 0)
        )
        self.subsample_seed = seed

        self.datasets = {}

    def setup(self, stage: Optional[str] = None):
        """Create the datasets needed for ``stage``.

        Each stage builds:
        - ``"fit"``: the train and val datasets;
        - ``"validate"``: the val dataset;
        - ``"test"`` and ``"predict"``: the test split, stored under both keys;
        - ``None``: every dataset.

        Raises
        ------
        ValueError
            For any other stage name.
        """
        if stage not in (None, "fit", "validate", "test", "predict"):
            raise ValueError(f"Invalid stage: {stage}")

        if stage in (None, "fit"):
            self.datasets["train"] = self._build_dataset("train", subsample=True)
        if stage in (None, "fit", "validate"):
            self.datasets["val"] = self._build_dataset("val")
        if stage in (None, "test", "predict"):
            test_dataset = self._build_dataset("test")
            self.datasets["test"] = test_dataset
            self.datasets["predict"] = test_dataset

    def _build_dataset(self, split: str, subsample: bool = False):
        """Create the DatasetForSAM for ``split``, optionally keeping a random ``data_ratio`` subset."""
        from torch.utils.data import Subset

        img_reader = TiffReader(self.train_path / split)
        label_reader = PNGReader(self.annotations_path / split)
        dataset = DatasetForSAM(
            readers=[img_reader, label_reader],
            transforms=self.transforms,
            transform_coords_input=self.transform_coords_input,
            multimask_output=self.multimask_output,
        )
        if not subsample:
            return dataset

        num_total = len(img_reader)
        num_keep = int(num_total * self.data_ratio)
        if num_keep >= num_total:
            return dataset
        # Keep at least one sample so tiny ratios don't produce an empty, unloadable split.
        num_keep = max(1, num_keep)
        rng = random if self.subsample_seed is None else random.Random(self.subsample_seed)
        indices = rng.sample(range(num_total), num_keep)
        # Subset indexes lazily, so only the samples actually requested are read from disk.
        return Subset(dataset, indices)

    def custom_collate_fn(self, batch):
        """
        Custom collate function for DataLoader: return the batch as a list of
        per-sample dictionaries, without stacking.
        """
        return batch

    def _make_dataloader(self, split: str, shuffle: bool) -> DataLoader:
        """Shared DataLoader factory for every split."""
        return DataLoader(
            self.datasets[split],
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            shuffle=shuffle,
            collate_fn=self.custom_collate_fn,
        )

    def train_dataloader(self):
        return self._make_dataloader("train", shuffle=True)

    def val_dataloader(self):
        return self._make_dataloader("val", shuffle=False)

    def test_dataloader(self):
        return self._make_dataloader("test", shuffle=False)

    def predict_dataloader(self):
        return self._make_dataloader("predict", shuffle=False)

In [6]:
# TODO: figure out how to plug the SAM-specific dataset (DatasetForSAM) into
# `common.get_data_module`. The notebook could then use that shared helper
# (root dirs, img_size, batch_size, seed, single_channel) instead of the
# notebook-local `DataModule`.
target_h, target_w = img_size
sam_padding = Padding(target_h, target_w)  # pads every image/label up to img_size

data_module = DataModule(
    train_path=root_data_dir,
    annotations_path=root_annotation_dir,
    transforms=sam_padding,
    multimask_output=True,
    batch_size=batch_size,
    data_ratio=1.0,  # use the full training split
)

data_module

<__main__.DataModule at 0x7fdb21f1ef50>

In [7]:
# Sanity check: build the fit-stage datasets and pull one training batch.
# The custom collate function keeps the batch as a list of per-sample dicts,
# because Sam's forward takes a list of dicts and does the batching itself.
# `train_batch[0]` is therefore the first sample, holding 'image' and 'label'.
data_module.setup("fit")
train_batch = next(iter(data_module.train_dataloader()))
first_sample = train_batch[0]
train_batch_x, train_batch_y = first_sample["image"], first_sample["label"]

# Every sample must come out channels-first at exactly img_size; Padding only
# pads and never crops, so larger raw images would fail here.
assert 1 <= len(train_batch) <= batch_size, len(train_batch)
assert tuple(train_batch_x.shape[-2:]) == tuple(img_size), train_batch_x.shape
assert train_batch_y.shape[-2:] == train_batch_x.shape[-2:], (train_batch_x.shape, train_batch_y.shape)

len(train_batch), train_batch_x.shape, train_batch_y.shape

(1, torch.Size([3, 1008, 784]), torch.Size([1, 1008, 784]))

## Create and Load the Model

In [8]:
# Number of facies classes. The metrics and SAM's multimask head must agree on it.
num_classes = 6

# Pretrained SAM ViT-B weights. Override with the SAM_CHECKPOINT environment
# variable instead of editing this absolute path.
sam_checkpoint = os.environ.get(
    "SAM_CHECKPOINT",
    "/workspaces/Minerva-Discovery/shared_data/weights_sam/checkpoints_sam/sam_vit_b_01ec64.pth",
)

# Each split gets its own JaccardIndex instance so that per-split metric state
# is never shared.
model = Sam(
    train_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=num_classes)},
    val_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=num_classes)},
    test_metrics={"mIoU": JaccardIndex(task="multiclass", num_classes=num_classes)},
    vit_type="vit-b",
    checkpoint=sam_checkpoint,
    num_multimask_outputs=num_classes,  # presumably one output mask per class (default: 3)
    # NOTE(review): in SAM this is the depth of the IoU-prediction MLP, not a
    # class count. Confirm 6 is intentional and not meant to mirror num_classes.
    iou_head_depth=6,  # default: 3
)

model

  state_dict = torch.load(f)


Error when load original weights. Applying now remaping.
Prompt Encoder freeze!


Sam(
  (loss_fn): CrossEntropyLoss()
  (model): _SAM(
    (image_encoder): ImageEncoderViT(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (blocks): ModuleList(
        (0-11): 12 x Block(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attn): Attention(
            (qkv): Linear(in_features=768, out_features=2304, bias=True)
            (proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (lin1): Linear(in_features=768, out_features=3072, bias=True)
            (lin2): Linear(in_features=3072, out_features=768, bias=True)
            (act): GELU(approximate='none')
          )
        )
      )
      (neck): Sequential(
        (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): LayerNorm2d()
        (2): Conv2d(256, 256,

## Pipeline

In [9]:
# All trainer settings come from the configuration cell at the top of the notebook.
trainer_kwargs = dict(
    model_name=model_name,
    dataset_name=dataset_name,
    log_dir=log_dir,
    num_epochs=num_epochs,
    accelerator=accelerator,
    devices=devices,
    is_debug=is_debug,
    seed=seed,
)
pipeline = get_trainer_pipeline(model=model, **trainer_kwargs)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Seed set to 42


Log directory set to: /workspaces/Minerva-Discovery/Minerva-Dev/docs/notebooks/examples/seismic/facies_classification/parihaka/logs/sam/seam_ai


In [10]:
# Fit the model on data_module's train split, validating on its val split, for
# num_epochs epochs. Logs and checkpoints are written under
# log_dir/model_name/dataset_name.
pipeline.run(data_module, task="fit")

/usr/local/lib/python3.10/dist-packages/lightning/fabric/loggers/csv_logs.py:268: Experiment logs directory ./logs/sam/seam_ai exists and is not empty. Previous log files in this directory will be deleted when the new ones are saved!
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Pipeline info saved at: /workspaces/Minerva-Discovery/Minerva-Dev/docs/notebooks/examples/seismic/facies_classification/parihaka/logs/sam/seam_ai/run_2024-12-17-18-02-4075ce386f01cd43d590e08c09257d2dc7.yaml



  | Name    | Type             | Params | Mode 
-----------------------------------------------------
0 | loss_fn | CrossEntropyLoss | 0      | train
1 | model   | _SAM             | 94.4 M | train
-----------------------------------------------------
94.3 M    Trainable params
6.2 K     Non-trainable params
94.4 M    Total params
377.415   Total estimated model params size (MB)
257       Modules in train mode
0         Modules in eval mode


                                                                           

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/utilities/data.py:78: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 3. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.


Epoch 99: 100%|██████████| 1121/1121 [03:03<00:00,  6.11it/s, v_num=m_ai, train_loss_step=0.00339, train_mIoU_step=0.993, val_loss_step=0.693, val_mIoU_step=0.818, val_loss_epoch=0.438, val_mIoU_epoch=0.862, train_loss_epoch=0.00291, train_mIoU_epoch=0.993]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 1121/1121 [03:09<00:00,  5.92it/s, v_num=m_ai, train_loss_step=0.00339, train_mIoU_step=0.993, val_loss_step=0.693, val_mIoU_step=0.818, val_loss_epoch=0.438, val_mIoU_epoch=0.862, train_loss_epoch=0.00291, train_mIoU_epoch=0.993]
Pipeline info saved at: /workspaces/Minerva-Discovery/Minerva-Dev/docs/notebooks/examples/seismic/facies_classification/parihaka/logs/sam/seam_ai/run_2024-12-17-18-02-4075ce386f01cd43d590e08c09257d2dc7.yaml


In [11]:
# Path of the most recent checkpoint written by the trainer's checkpoint callback.
last_checkpoint = pipeline.trainer.checkpoint_callback.last_model_path
print(f"Checkpoint saved at {last_checkpoint}")

Checkpoint saved at ./logs/sam/seam_ai/checkpoints/last.ckpt
