In [2]:
import os
import json
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from PIL import Image
from tqdm import tqdm

Class - MicroplasticDataset

In [3]:
class MicroplasticDataset(torch.utils.data.Dataset):
     def __init__(self, root, annFile, transforms=None):
        self.root = root
        self.transforms = transforms
        self.images, self.annotations = self.load_annotations(annFile)

     def load_annotations(self, annFile):
        # Load annotations from COCO format
        with open(annFile) as f:
            data = json.load(f)
        images = []
        annotations = []
        for image in data['images']:
            images.append(image['file_name'])
            annotations.append([ann for ann in data['annotations'] if ann['image_id'] == image['id']])
        return images, annotations

     def __getitem__(self, idx):
        img_path = os.path.join(self.root, self.images[idx])
        image = Image.open(img_path).convert("RGB")  # Ensure image is in RGB format

        # Load the target annotations
        target = self.annotations[idx]
        boxes = []
        labels = []

        for ann in target:
            x, y, width, height = ann['bbox']
            boxes.append([x, y, x + width, y + height])  # Convert to [x1, y1, x2, y2]
            labels.append(ann['category_id'])  # Make sure to map category IDs to your classes

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)

        target = {"boxes": boxes, "labels": labels}

        if self.transforms is not None:
            image = self.transforms(image)

        return image, target

     def __len__(self):
        return len(self.images)

transform = transforms.Compose([
    transforms.ToTensor()
])

Load Dataset

In [4]:
train_dataset = MicroplasticDataset(
    root="/Users/bipashaamohanty/Documents/projects/microplastics-detection/microplastic-dataset-roboflow/train",
    annFile="/Users/bipashaamohanty/Documents/projects/microplastics-detection/microplastic-dataset-roboflow/train/_annotations.coco.json",
    transforms=transform
)

val_dataset = MicroplasticDataset(
    root="/Users/bipashaamohanty/Documents/projects/microplastics-detection/microplastic-dataset-roboflow/valid",
    annFile="/Users/bipashaamohanty/Documents/projects/microplastics-detection/microplastic-dataset-roboflow/valid/_annotations.coco.json",
    transforms=transform
)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

# Define the device for Mac M3 or any compatible Apple Silicon
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Create a custom Faster R-CNN model
backbone = resnet_fpn_backbone('resnet50', pretrained=True)
model = torchvision.models.detection.FasterRCNN(backbone, num_classes=2)
# Move model to device
model.to(device)

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)



Training Loop

In [5]:
num_epochs = 10  # Set the number of epochs
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    for images, targets in tqdm(train_loader):
        # Move images and targets to the appropriate device
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        # Forward pass, loss calculation, and optimization steps
        model.train()
        optimizer.zero_grad()
        #loss_dict = model(images, targets)
        #losses = sum(loss for loss in loss_dict.values())
        #losses.backward()
        #optimizer.step()
    # Save the model
    torch.save(model.state_dict(), "microplastic_detector.pth")

Epoch 1/10


100%|██████████| 715/715 [00:05<00:00, 137.24it/s]


Epoch 2/10


100%|██████████| 715/715 [00:04<00:00, 156.17it/s]


Epoch 3/10


100%|██████████| 715/715 [00:04<00:00, 145.76it/s]


Epoch 4/10


100%|██████████| 715/715 [00:04<00:00, 151.01it/s]


Epoch 5/10


100%|██████████| 715/715 [00:04<00:00, 153.18it/s]


Epoch 6/10


100%|██████████| 715/715 [00:04<00:00, 153.07it/s]


Epoch 7/10


100%|██████████| 715/715 [00:04<00:00, 152.57it/s]


Epoch 8/10


100%|██████████| 715/715 [00:04<00:00, 152.55it/s]


Epoch 9/10


100%|██████████| 715/715 [00:04<00:00, 152.68it/s]


Epoch 10/10


100%|██████████| 715/715 [00:04<00:00, 151.22it/s]
