# Mask R-CNN for Ultrasound Microrobot Detection

This notebook demonstrates how to use a custom dataset with Mask R-CNN to detect microrobots from ultrasound images. 

The dataset is assumed to have the following structure:

```
root_dir/
    images/
         train/   # contains .png images
         test/    # contains .png images
    labels/
         train/   # contains .txt files with bounding box labels
         test/    # contains .txt files with bounding box labels
```

Each label file is a text file with a single line like:

```
0 0.569076 0.381246 0.115152 0.130603
```

where the first value is a placeholder, and the next four are normalized `[x_center, y_center, width, height]` values. These are converted to absolute `[xmin, ymin, xmax, ymax]` coordinates during data loading.

In [None]:
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

## Custom Dataset for Mask R-CNN

This dataset class loads an image and its corresponding label file, converts the normalized bounding box `[x_center, y_center, width, height]` to absolute coordinates `[xmin, ymin, xmax, ymax]` and returns a target dictionary as expected by Mask R-CNN.

In [None]:
class USMicrorobotDetectionDataset(Dataset):
    """Ultrasound microrobot frames paired with YOLO-style bounding-box labels.

    Images are read from ``<root_dir>/images/<split>`` and labels from
    ``<root_dir>/labels/<split>``; a label file shares its image's base name
    and carries the extension ``.txt``.  Every non-blank label line holds five
    whitespace-separated values: a leading value that is ignored, followed by
    normalized ``x_center y_center width height``.

    Each sample is ``(image, target)``.  ``target`` is a dict with ``boxes``
    (absolute ``[xmin, ymin, xmax, ymax]``), ``labels``, ``masks``,
    ``image_id``, ``area`` and ``iscrowd``.  Because the label files only
    describe boxes, each mask is the filled rectangle of its box.
    """

    def __init__(self, root_dir, split='train', transforms=None):
        """
        Args:
            root_dir (str): Path to the dataset root directory.
            split (str): 'train' or 'test'.
            transforms: (callable, optional) Transforms to apply to the image.
                The target is not transformed, so geometric transforms would
                leave boxes and masks misaligned with the image.
        """
        self.root_dir = root_dir
        self.split = split
        self.transforms = transforms
        self.images_dir = os.path.join(root_dir, 'images', split)
        self.labels_dir = os.path.join(root_dir, 'labels', split)
        # Sorted so that a given index always maps to the same file.
        self.image_files = sorted(
            f for f in os.listdir(self.images_dir) if f.lower().endswith('.png')
        )

    def __len__(self):
        """Return the number of PNG images found for this split."""
        return len(self.image_files)

    @staticmethod
    def _read_boxes(label_path, width, height):
        """Parse a label file into absolute, image-clamped boxes.

        Args:
            label_path (str): Path to the label text file.
            width (int): Image width in pixels.
            height (int): Image height in pixels.

        Returns:
            list[list[float]]: One ``[xmin, ymin, xmax, ymax]`` entry per
            usable label line; empty when the file has no usable lines.

        Raises:
            ValueError: If a non-blank line does not hold exactly five
                values, or a value is not a number.
        """
        boxes = []
        with open(label_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines / trailing newlines
                if len(fields) != 5:
                    raise ValueError(
                        f"{label_path}:{line_no}: expected 5 values, got {len(fields)}"
                    )
                # Drop the first value, keeping only the 4 bounding box values.
                x_center, y_center, w_norm, h_norm = (float(v) for v in fields[1:])
                half_w = w_norm * width / 2
                half_h = h_norm * height / 2
                # Clamp to the image so the box (and its mask) stays in bounds.
                xmin = min(max(x_center * width - half_w, 0.0), float(width))
                ymin = min(max(y_center * height - half_h, 0.0), float(height))
                xmax = min(max(x_center * width + half_w, 0.0), float(width))
                ymax = min(max(y_center * height + half_h, 0.0), float(height))
                # Skip zero-area boxes; they cannot describe an object.
                if xmax > xmin and ymax > ymin:
                    boxes.append([xmin, ymin, xmax, ymax])
        return boxes

    @staticmethod
    def _box_masks(boxes, height, width):
        """Build one filled-rectangle mask per box.

        Args:
            boxes (torch.Tensor): ``(N, 4)`` absolute boxes inside the image.
            height (int): Image height in pixels.
            width (int): Image width in pixels.

        Returns:
            torch.Tensor: ``(N, height, width)`` uint8 tensor, 1 inside the
            box (rounded outwards to whole pixels) and 0 elsewhere.
        """
        masks = torch.zeros((boxes.shape[0], height, width), dtype=torch.uint8)
        lows = boxes[:, :2].floor().long().tolist()
        highs = boxes[:, 2:].ceil().long().tolist()
        for mask, (x0, y0), (x1, y1) in zip(masks, lows, highs):
            mask[y0:y1, x0:x1] = 1  # `mask` is a view into `masks`
        return masks

    def __getitem__(self, idx):
        """Load the image at ``idx`` together with its detection target.

        Returns:
            tuple: ``(image, target)``; ``image`` is the RGB image after
            ``self.transforms`` (if any), ``target`` the dict described in
            the class docstring.

        Raises:
            FileNotFoundError: If the image has no matching label file.
            ValueError: If the label file is malformed.
        """
        file_name = self.image_files[idx]

        # Load image; the context manager releases the file handle promptly.
        with Image.open(os.path.join(self.images_dir, file_name)) as raw:
            image = raw.convert('RGB')  # convert grayscale to RGB
        width, height = image.size

        # Load corresponding label file
        label_file = os.path.splitext(file_name)[0] + '.txt'
        box_list = self._read_boxes(
            os.path.join(self.labels_dir, label_file), width, height
        )

        # reshape keeps the (N, 4) layout even when there are no boxes.
        boxes = torch.as_tensor(box_list, dtype=torch.float32).reshape(-1, 4)
        num_boxes = boxes.shape[0]

        # Create target dictionary
        target = {
            'boxes': boxes,
            # assuming one class: microrobot
            'labels': torch.ones((num_boxes,), dtype=torch.int64),
            'masks': self._box_masks(boxes, height, width),
            'image_id': torch.tensor([idx]),
            'area': (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]),
            'iscrowd': torch.zeros((num_boxes,), dtype=torch.int64),
        }

        if self.transforms is not None:
            image = self.transforms(image)

        return image, target

# Quick sanity check: build the training split and report how many samples it holds.
root_dir = 'UsMicroMagSet-main/flagella'  # update with your dataset path
dataset = USMicrorobotDetectionDataset(
    root_dir,
    split='train',
    transforms=T.Compose([T.ToTensor()]),
)
print('Number of training samples:', len(dataset))

## Data Loaders

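We use a custom collate function for detection tasks: torchvision detection models expect a list of image tensors and a list of target dictionaries, so each batch is kept as a tuple of images and a tuple of targets instead of being stacked into single tensors.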

In [None]:
def collate_fn(batch):
    """Regroup a list of samples into one tuple per sample field.

    A batch of ``(image, target)`` pairs becomes ``(images, targets)``,
    where each element is a tuple; nothing is stacked into a single tensor.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# One dataset per split; images are only converted to tensors.
train_dataset = USMicrorobotDetectionDataset(
    root_dir,
    split='train',
    transforms=T.Compose([T.ToTensor()]),
)
test_dataset = USMicrorobotDetectionDataset(
    root_dir,
    split='test',
    transforms=T.Compose([T.ToTensor()]),
)

# Shuffle only the training data; the test order stays fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=collate_fn,
)

print('Train samples:', len(train_dataset), '| Test samples:', len(test_dataset))

## Model Setup: Mask R-CNN

We load a pre-trained Mask R-CNN and modify it for our one-class (microrobot) detection.

torchvision detection models reserve class index 0 for the background, so a single foreground class (microrobot, label 1) requires `num_classes = 2`.

In [None]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# Load a pre-trained Mask R-CNN model (ResNet-50 FPN backbone).
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour
# of `weights=`; kept as-is for compatibility with older installs.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)

# Number of classes: background and microrobot
num_classes = 2

# Replace the box predictor so it scores `num_classes` classes instead of
# the classes of the pre-training dataset.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Replace the mask predictor as well; otherwise the mask head keeps one
# output channel per pre-training class and disagrees with the box head.
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

# NOTE: in training mode torchvision's Mask R-CNN reads `target['masks']`,
# so every training target must provide per-instance masks.

model.to(device)
print('Model loaded on:', device)

## Training Loop

This is a basic training loop for Mask R-CNN. Note that training detection models can be resource intensive, so adjust the number of epochs and batch sizes as needed.

In [None]:
import torch.optim as optim

# Define the optimizer over the trainable parameters only.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)

num_epochs = 10

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    epoch_loss = 0.0
    for imgs, targets in train_loader:
        imgs = [img.to(device) for img in imgs]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In training mode the model returns a dict of loss terms.
        loss_dict = model(imgs, targets)
        losses = sum(loss_dict.values())
        epoch_loss += losses.item()

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

    print(f"Epoch {epoch} Loss: {epoch_loss/len(train_loader):.4f}")

    # ---- Test-loss pass ----
    # torchvision detection models return the loss dict only in training
    # mode; after model.eval() they return a list of predictions, so calling
    # `.values()` on the result would fail. The model therefore stays in
    # training mode here, and torch.no_grad() ensures no gradients are
    # computed and no weights are updated.
    # NOTE(review): presumably the pre-trained backbone uses frozen batch
    # norm, so this pass does not alter normalization statistics - confirm
    # if the backbone is changed.
    model.train()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for imgs, targets in test_loader:
            imgs = [img.to(device) for img in imgs]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(imgs, targets)
            losses = sum(loss_dict.values())
            # Weight by batch size so a smaller last batch is not over-counted.
            total_loss += losses.item() * len(imgs)
            total_samples += len(imgs)
    if total_samples:  # skip the report instead of dividing by zero
        print(f"Epoch {epoch} | Test Loss: {total_loss/total_samples:.4f}")

## Save the Model

After training, you can save your model for future inference.

In [None]:
# Persist only the learned weights (state dict), not the whole model object.
torch.save(
    obj=model.state_dict(),
    f='mask_rcnn_microrobot.pth',
)
print('Model saved!')