<a href="https://colab.research.google.com/github/krupaltisgaonkar/pytorch-ssd/blob/main/SSD_MobileNet_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SSD MobileNet Training in PyTorch
This notebook walks through the steps to train an SSD MobileNet model using PyTorch.

## Make sure you are using a GPU

Under **Runtime → Change runtime type**, select a GPU as the hardware accelerator.

## Install and Import Necessary Libraries

In [None]:
!pip install torch torchvision pycocotools
!pip install --upgrade protobuf

from google.colab import drive
from google.colab import files
import zipfile
import os
from PIL import Image

# Mount Google Drive at /content/drive; the Drive dataset path used in
# "Option 1" below points inside this mount point.
print("Mounting Google Drive...")
drive.mount('/content/drive')

## Upload Dataset

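Make sure your dataset is set up like this. The zip archive should contain `images/`, `labels/`, and a `classes.txt` file (one class name per line) at its top level, because later cells read `data/images`, `data/labels`, and `data/classes.txt` after extraction: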

```
dataset/
├── images/
│   ├── image1.png
├── labels/
│   ├── image1.txt
```

### Option 1: Use Google Drive

This option expects your dataset to be stored as a zip file in Google Drive.

In [None]:
# Location of the zipped dataset inside the mounted Drive.
drive_dataset_path = "/content/drive/MyDrive/dataset/YOLO/dataset.zip"  # Replace with your Google Drive dataset path
if os.path.exists(drive_dataset_path):
    # Unpack the whole archive into the local data/ directory.
    with zipfile.ZipFile(drive_dataset_path, 'r') as zip_ref:
        zip_ref.extractall("data")
    print("Dataset extracted from Google Drive to: data/")
else:
    # Report the miss instead of silently doing nothing, so a wrong path is
    # noticed here rather than as a confusing failure in a later cell.
    print(f"Dataset not found at {drive_dataset_path}; fix the path or use Option 2 (manual upload).")

Dataset extracted from Google Drive to: data/


### Option 2: Upload Manually

In [None]:
print("Upload your zipped dataset...")
uploaded = files.upload()

# Unpack every uploaded archive into data/; nothing happens when no file
# was uploaded.
for archive_name in (uploaded or {}):
    with zipfile.ZipFile(archive_name, 'r') as archive:
        archive.extractall("data")
    print(f"Dataset extracted to: data/")

## Resize Dataset

### Resize Images

You will have to resize your images to 640 by 640 pixels to ensure accuracy and faster training.

In [None]:
def resize_images(input_dir, output_dir, new_size=(640, 640), extensions=(".jpg", ".png")):
    """Resize every matching image in ``input_dir`` and save it to ``output_dir``.

    Each resized image keeps its original file name.

    Args:
        input_dir: Directory containing the source images.
        output_dir: Directory the resized copies are written to; created if
            it does not exist.
        new_size: Target size passed to ``Image.resize`` (width, height).
        extensions: File-name suffixes to process. Matching is
            case-insensitive so files such as ``photo.JPG`` are not
            silently skipped.
    """
    os.makedirs(output_dir, exist_ok=True)
    suffixes = tuple(ext.lower() for ext in extensions)
    for filename in os.listdir(input_dir):
        if not filename.lower().endswith(suffixes):
            continue
        # The context manager releases the file handle even if resize/save raises.
        with Image.open(os.path.join(input_dir, filename)) as img:
            img.resize(new_size).save(os.path.join(output_dir, filename))
    print(f"Images resized and saved to {output_dir}.")

# Resize the extracted dataset images into custom_dataset/images
resize_images(input_dir="data/images", output_dir="custom_dataset/images")

### Resize Labels

The bounding-box labels must also be rescaled so that they match the resized images.

In [None]:
# Download the label-resizing helper script to /content/resize_labels.py.
# NOTE(review): this URL points at the "pytorch" repository, while the links in
# this notebook point at "pytorch-ssd" - confirm the URL is the intended one.
!wget -O /content/resize_labels.py https://raw.githubusercontent.com/krupaltisgaonkar/pytorch/refs/heads/main/scripts/resize_labels.py

# Run the helper on the original labels/images and write the new labels to
# custom_dataset/labels.
# NOTE(review): the script is invoked by its relative name, so this assumes
# the working directory is /content (where it was downloaded) - confirm.
!python resize_labels.py --input_label_dir data/labels \
                 --input_image_dir data/images \
                 --output_label_dir custom_dataset/labels \
                 --new_size 640

## Set Dataset Paths

In [None]:
# Dataset locations used by the cells below.
dataset_root = 'custom_dataset/'
# os.path.join avoids the doubled separator ("custom_dataset//images") that
# plain string formatting produces when the root already ends with "/".
image_dir = os.path.join(dataset_root, 'images')
label_dir = os.path.join(dataset_root, 'labels')
# The class list is read from the extracted dataset, not the resized copy.
classes_file = 'data/classes.txt'

## Read Classes

In [None]:
# Read class labels: one class name per line. Blank lines (for example a
# trailing empty line) are skipped so they are not counted as classes.
with open(classes_file, 'r') as f:
    class_labels = [name for name in (line.strip() for line in f) if name]
if not class_labels:
    # An empty class list would otherwise build a background-only model.
    raise ValueError(f"No class names found in {classes_file}")
n_classes = len(class_labels)
print(f"Classes: {class_labels}, Total: {n_classes}")

Classes: ['fish'], Total: 1


## Define Custom Dataset Class

In [None]:
import os
import torch
from PIL import Image

class SSDDataset(torch.utils.data.Dataset):
    """Detection dataset pairing each image with a YOLO-style ``.txt`` label file.

    Every non-empty label line has the form
    ``class_id x_center y_center width height``; boxes are converted to
    ``[x_min, y_min, x_max, y_max]``.

    NOTE(review): coordinates are passed through unscaled, so this assumes the
    label files already hold absolute pixel values for the stored images
    (torchvision detection models expect pixel coordinates) - confirm against
    the output of resize_labels.py.

    Args:
        image_dir: Directory containing the images.
        label_dir: Directory containing one ``<image stem>.txt`` per image.
        transforms: Optional callable applied to the RGB PIL image.
        label_offset: Value added to every class id read from disk. The
            default of 1 maps 0-based YOLO ids onto 1..N because torchvision
            detection models reserve class 0 for background. Pass 0 if the
            label files are already 1-based.

    Raises:
        FileNotFoundError: If an image has no matching label file.
    """

    def __init__(self, image_dir, label_dir, transforms=None, label_offset=1):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transforms = transforms
        self.label_offset = label_offset
        self.image_files = sorted(os.listdir(image_dir))
        # Pair by file stem instead of by position in two independently sorted
        # listings: an extra or missing file in either directory would
        # otherwise shift every following image onto the wrong label.
        self.label_files = [os.path.splitext(name)[0] + ".txt" for name in self.image_files]
        missing = [
            name for name in self.label_files
            if not os.path.isfile(os.path.join(label_dir, name))
        ]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} image(s) in {image_dir} have no label file in "
                f"{label_dir}, e.g. {missing[0]}"
            )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``."""
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # convert() reads the pixel data, so the file can be closed afterwards.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        target = self._load_target(idx)

        if self.transforms:
            image = self.transforms(image)

        return image, target

    def _load_target(self, idx):
        """Parse the label file of sample ``idx`` into a ``boxes``/``labels`` dict."""
        label_path = os.path.join(self.label_dir, self.label_files[idx])
        boxes = []
        labels = []
        with open(label_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines
                if len(fields) != 5:
                    raise ValueError(
                        f"{label_path}:{line_no}: expected 5 values, got {len(fields)}"
                    )
                class_id = int(float(fields[0]))
                x_center, y_center, width, height = map(float, fields[1:])
                boxes.append([
                    x_center - width / 2,
                    y_center - height / 2,
                    x_center + width / 2,
                    y_center + height / 2,
                ])
                labels.append(class_id + self.label_offset)

        # reshape keeps an image without objects at shape (0, 4); a plain
        # tensor([]) would have shape (0,), which detection models reject.
        return {
            'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.tensor(labels, dtype=torch.int64),
        }

## Define Transformations and Load Dataset

In [None]:
from torchvision import transforms

# Per-image preprocessing: convert the PIL image to a tensor, then resize it
# to 640x640.
# NOTE(review): Resize here changes only the image, not the boxes - presumably
# safe because the images were already resized to 640x640 earlier; confirm.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize((640, 640))]
)

dataset = SSDDataset(image_dir, label_dir, transforms=transform)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    # Regroup the list of (image, target) samples into a tuple of images and
    # a tuple of targets instead of stacking them.
    collate_fn=lambda batch: tuple(zip(*batch)),
)

## Load Pretrained SSD MobileNet

In [None]:
from torchvision.models.detection import ssdlite320_mobilenet_v3_large
import torch

# Load the pre-trained model
model = ssdlite320_mobilenet_v3_large(weights="DEFAULT")

# Torchvision detection models reserve class index 0 for background, so the
# head needs one more output class than the dataset has.
num_classes = n_classes + 1  # background + dataset classes

# The classification head keeps one prediction block per feature map in
# `module_list`. Each block ends in a Conv2d that emits
# (anchors per location * number of classes) channels, and `num_columns` is
# the class count the head uses to reshape those outputs. Every block and
# `num_columns` have to change together: replacing only the last block, or
# emitting just `num_classes` channels, makes the reshape in the head's
# forward pass fail.
classification_head = model.head.classification_head
old_num_classes = classification_head.num_columns

for block in classification_head.module_list:
    final_conv_layer = block[1]
    # Recover the anchor count of this feature map from the existing layer.
    anchors_per_location = final_conv_layer.out_channels // old_num_classes
    block[1] = torch.nn.Conv2d(
        in_channels=final_conv_layer.in_channels,
        out_channels=anchors_per_location * num_classes,
        kernel_size=final_conv_layer.kernel_size,
        stride=final_conv_layer.stride,
        padding=final_conv_layer.padding
    )

classification_head.num_columns = num_classes

# Now, the model is updated with the correct number of output classes


## Training Loop

In [None]:
import torch
from torch.optim import SGD

# Train on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# SGD with momentum and weight decay over all model parameters
optimizer = SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

# One pass over the dataloader per epoch
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for images, targets in dataloader:
        # Move the batch onto the training device
        images = [picture.to(device) for picture in images]
        targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in targets
        ]

        # Called with targets, the model returns a dict of losses; the
        # training objective is their sum.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Accumulate the summed batch loss for the epoch report
        total_loss += losses.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss}")

## Save the Model

In [None]:
# Save the trained weights (state_dict only) under the dataset root.
# os.path.join avoids the doubled separator that string formatting produced
# because dataset_root already ends with "/".
os.makedirs(dataset_root, exist_ok=True)
model_save_path = os.path.join(dataset_root, 'ssd_mobilenet.pth')
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

## Evaluate the Model

In [None]:
import matplotlib.pyplot as plt

def visualize_predictions(image, predictions, score_threshold=0.5):
    """Show ``image`` with the predicted boxes and class names drawn on top.

    Args:
        image: Image tensor; it is permuted with ``(1, 2, 0)`` before being
            passed to ``imshow``.
        predictions: Dict with ``'boxes'`` (rows of x_min, y_min, x_max,
            y_max) and ``'labels'`` tensors, and optionally ``'scores'``.
        score_threshold: Detections whose score is below this value are not
            drawn. Ignored when ``predictions`` has no ``'scores'`` entry.
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(image.permute(1, 2, 0).cpu().numpy())
    ax = plt.gca()

    # tolist() yields plain Python numbers regardless of the tensor's device;
    # matplotlib cannot consume tensors that live on the GPU.
    boxes = predictions['boxes'].tolist()
    labels = predictions['labels'].tolist()
    if 'scores' in predictions:
        scores = predictions['scores'].tolist()
    else:
        scores = [None] * len(boxes)

    for (x_min, y_min, x_max, y_max), label, score in zip(boxes, labels, scores):
        if score is not None and score < score_threshold:
            continue
        ax.add_patch(
            plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                          fill=False, edgecolor='red', linewidth=2)
        )
        # Model labels are 1-based (0 is background); class_labels is 0-based.
        class_index = label - 1
        if 0 <= class_index < len(class_labels):
            name = class_labels[class_index]
        else:
            name = str(label)
        plt.text(x_min, y_min, name, color='blue', fontsize=12)
    plt.show()

# Take the first dataset sample; its ground-truth target is not needed here
image, _ = dataset[0]

# Switch the model to evaluation mode and predict without tracking gradients.
# The model receives a one-image batch, so it returns exactly one result.
model.eval()
with torch.no_grad():
    (predictions,) = model([image.to(device)])

visualize_predictions(image, predictions)

## Issues

If you run into any issues or errors, please go to the <a href="https://github.com/krupaltisgaonkar/pytorch-ssd">GitHub page</a> and file an issue.