In [None]:
import random
import sys
from pathlib import Path

# Add the parent directory to the import path so that the local
# `finetune_sam2` module imported below can be found (presumably it lives
# there — it is not importable from this notebook's own directory).
# `Path.cwd()` is a classmethod; no throwaway `Path()` instance is needed.
parent_dir = Path.cwd().parent
print(parent_dir)

# Guard the append so that re-running this cell does not keep adding
# duplicate entries to `sys.path`.
if str(parent_dir) not in sys.path:
    sys.path.append(str(parent_dir))

In [None]:
from finetune_sam2 import SAMTrainer
from matplotlib import pyplot as plt

In [None]:
# Fine-tuning configuration. Every tunable value lives here instead of being
# scattered through the calls below. The paths are relative strings —
# presumably resolved against the notebook's working directory; confirm
# against SAMTrainer.
DATA_DIR = "../assets/Sam2-Train-Data"
MODEL_CFG = "../sam2/configs/sam2.1/sam2.1_hiera_s.yaml"
BASE_CHECKPOINT = "../checkpoints/sam2.1_hiera_small.pt"
TEST_FRACTION = 0.2
TRAIN_STEPS = 2000
LEARNING_RATE = 1e-4
CHECKPOINT_INTERVAL = 500

trainer = SAMTrainer(
    data_dir=DATA_DIR,
    model_cfg=MODEL_CFG,
    checkpoint_path=BASE_CHECKPOINT,
)

# Prepare the data, holding out TEST_FRACTION of it (the inference cell later
# samples from `trainer.test_data`). Then build the model and train it.
trainer.prepare_data(test_size=TEST_FRACTION)
trainer.initialize_model()
trainer.train(
    steps=TRAIN_STEPS,
    learning_rate=LEARNING_RATE,
    checkpoint_interval=CHECKPOINT_INTERVAL,  # presumably steps between saved checkpoints
)

In [None]:
# Perform inference on one held-out sample and compare the prediction with
# its ground-truth mask.

# Seed the sample choice so that Restart & Run All always shows the same example.
INFERENCE_SEED = 42
# Fine-tuned weights to evaluate. The "2000" in the filename presumably
# corresponds to the training step count (steps=2000 in the training cell);
# keep the two in sync.
FINETUNED_CHECKPOINT = "./models/fine_tuned_sam2_2000.torch"

if not trainer.test_data:
    raise ValueError("Test data is empty. Ensure that the dataset is prepared correctly.")

# Use a private, seeded RNG so the global `random` state is left untouched.
selected_entry = random.Random(INFERENCE_SEED).choice(trainer.test_data)
print(selected_entry)

image_path = selected_entry['image']
mask_path = selected_entry['annotation']
image, mask, segmentation = trainer.inference(
    image_path=image_path,
    mask_path=mask_path,
    checkpoint_path=FINETUNED_CHECKPOINT,
)

# One (title, array, colormap) entry per panel. A colormap of None keeps
# imshow's default.
panels = [
    ('Test Image', image, None),
    ('Original Mask', mask, 'gray'),
    ('Predicted Segmentation', segmentation, 'jet'),
]
fig, axes = plt.subplots(1, len(panels), figsize=(18, 6))
for ax, (title, data, cmap) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(data, cmap=cmap)
    ax.axis('off')

fig.tight_layout()
plt.show()