# In [None]:
from src.YoloSAM.scripts.train_sam import TrainSAM
from src.YoloSAM.utils.dataset import SAMDataset
from src.YoloSAM.utils.config import SAMFinetuneConfig, SAMDatasetConfig

# ---------------------------------------------------------------------------
# Fine-tuning configuration.
# A deliberately small run (1 epoch, batch of 2, W&B logging disabled) that
# is useful for checking that the whole pipeline executes end to end.
# ---------------------------------------------------------------------------
finetune_config = SAMFinetuneConfig(
    # Model and hardware.
    device='cuda',
    model_type='vit_b',
    sam_path='../checkpoints/sam_vit_b_01ec64.pth',
    # Optimisation schedule.
    num_epochs=1,
    batch_size=2,
    learning_rate=1e-5,
    weight_decay=1e-4,
    num_workers=0,
    # Remaining training knobs; their exact meaning is defined by
    # SAMFinetuneConfig / TrainSAM (presumably loss weights and a smoothing
    # width — confirm there).
    lambda_bce=0.2,
    lambda_kl=0.2,
    sigma=1,
    # Experiment tracking ('disabled' keeps this run off W&B).
    wandb_project='SAM_finetune',
    wandb_name='test_run',
    wandb_mode='disabled',
)

# ---------------------------------------------------------------------------
# Dataset configuration.
# Settings that the train and val splits must agree on live in one place, so
# the two configs cannot silently drift apart.
# ---------------------------------------------------------------------------
shared_dataset_kwargs = dict(
    # NOTE(review): `remove_nonscar` reads like a scar-segmentation option;
    # confirm it is meaningful for the DRIVE dataset.
    remove_nonscar=True,
    sample_size=2,
    image_size=1024,
    # Prompt sources: only the YOLO-based prompt is enabled. Random point
    # prompts and box prompts (both derived from the mask) are switched off.
    point_prompt=False,
    box_prompt=False,
    yolo_prompt=True,
    # YOLO detector checkpoint and inference settings.
    yolo_model_path='../checkpoints/yolo11n.pt',
    yolo_conf_threshold=0.25,   # detection confidence threshold
    yolo_iou_threshold=0.45,    # IoU threshold used by YOLO
    yolo_imgsz=640,             # YOLO input image size
)

# Training split: the only split that spells out the augmentation switches
# (both off here).
train_dataset_config = SAMDatasetConfig(
    dataset_path='../../../datasets/DRIVE/train',
    enable_direction_aug=False,
    enable_size_aug=False,
    train=True,
    **shared_dataset_kwargs,
)

# Validation split: augmentation flags are intentionally not passed, so
# SAMDatasetConfig's own defaults apply.
val_dataset_config = SAMDatasetConfig(
    dataset_path='../../../datasets/DRIVE/val',
    train=False,
    **shared_dataset_kwargs,
)

# Materialise the datasets (train first, then val) and launch training for
# the configured number of epochs.
train_dataset = SAMDataset(train_dataset_config)
val_dataset = SAMDataset(val_dataset_config)

trainer = TrainSAM(finetune_config, train_dataset, val_dataset)
trainer.train(finetune_config.num_epochs)