<a href="https://colab.research.google.com/github/kapjaya60/SSD/blob/main/Train_PyTorch_SSD.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import pandas as pd
import os
from PIL import Image, ImageDraw
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms as T
from torchvision.models.detection import fcos_resnet50_fpn
from sklearn.metrics import precision_recall_fscore_support
from torchvision.ops import box_iou

In [2]:
# Download the file from the specified URL
!wget "https://app.roboflow.com/ds/3H7cxrKYsz?key=HVaPUzZi2o" -O dataset.zip

# Unzip the downloaded file
!unzip dataset.zip > /dev/null

--2024-12-31 15:52:19--  https://app.roboflow.com/ds/3H7cxrKYsz?key=HVaPUzZi2o
Resolving app.roboflow.com (app.roboflow.com)... 151.101.1.195, 151.101.65.195, 2620:0:890::100
Connecting to app.roboflow.com (app.roboflow.com)|151.101.1.195|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://storage.googleapis.com/roboflow-platform-regional-exports/zLz9rQfSVg8cZGmLT3uV/BFD1UH2Tp5ibiZQSnXue/11/tensorflow.zip?X-Goog-Algorithm=GOOG4-RSA-SHA256&X-Goog-Credential=481589474394-compute%40developer.gserviceaccount.com%2F20241231%2Fauto%2Fstorage%2Fgoog4_request&X-Goog-Date=20241231T155219Z&X-Goog-Expires=900&X-Goog-SignedHeaders=host&X-Goog-Signature=3b2b7c6412012201bf91a7666638fadc4a2277a8040a62a3b6aae7dc2b0afa26713c25ecc87254ed48192bf1f3d293b422e91075fb89507c9cdc655c1ee70d343de6b2af5fad3d3cf081117aed6e4a27b44e03ffdc0b8a56f80e09b600d4c40b51d8dec05e1590c437307fcd8b71009925c1a46a6a3bd0c3464f5a33281e035fe8d4f97a802fa66ab6b872bb379756f1479c99ee28e366f10b3d41720184

In [3]:
# Load CSV and prepare datasets
train = pd.read_csv('/content/train/_annotations.csv')
valid = pd.read_csv('/content/valid/_annotations.csv')

In [4]:
train_unique_imgs = train.filename.unique()
valid_unique_imgs = valid.filename.unique()

In [5]:
class CustDat(Dataset):
    def __init__(self, df, unique_imgs, root_dir, transform=None):
        self.df = df
        self.unique_imgs = unique_imgs
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.unique_imgs)

    def __getitem__(self, idx):
        image_name = self.unique_imgs[idx]
        # Filter bounding boxes for the current image
        boxes_df = self.df[self.df.filename == image_name][['xmin', 'ymin', 'xmax', 'ymax']]
        boxes = boxes_df.values.astype("float")  # Convert bounding box coordinates to float
        img_path = os.path.join(self.root_dir, image_name)
        img = Image.open(img_path).convert('RGB')  # Open the image

        # Assuming a single class for all boxes (can be adjusted as needed)
        labels = torch.ones((boxes.shape[0]), dtype=torch.int64)

        # Prepare the target dictionary
        target = {
            'boxes': torch.tensor(boxes, dtype=torch.float32),
            'labels': labels
        }

        # Apply any transformations to the image
        if self.transform:
            img = self.transform(img)

        return img, target

In [6]:
# Create datasets
train_dataset = CustDat(df=train, unique_imgs=train_unique_imgs, root_dir='/content/train',
                        transform=T.ToTensor())
valid_dataset = CustDat(df=valid, unique_imgs=valid_unique_imgs, root_dir='/content/valid',
                        transform=T.ToTensor())

# Create dataloaders
train_dl = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=lambda x: tuple(zip(*x)))
valid_dl = DataLoader(valid_dataset, batch_size=1, shuffle=False, collate_fn=lambda x: tuple(zip(*x)))

In [7]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [8]:
model = fcos_resnet50_fpn(weights="DEFAULT")

Downloading: "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth" to /root/.cache/torch/hub/checkpoints/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth
100%|██████████| 160M/160M [00:01<00:00, 153MB/s]


In [9]:
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.0025, momentum=0.9, weight_decay=0.0005)

# Learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
num_epochs = 100

In [10]:
save_dir = "fcos_model"
os.makedirs(save_dir, exist_ok=True)  # Create the directory if it doesn't exist

In [11]:
model.to(device)

for epoch in range(num_epochs):
    # TRAINING
    model.train()
    epoch_loss = 0
    for imgs, targets in train_dl:
        imgs = [img.to(device) for img in imgs]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Forward pass
        loss_dict = model(imgs, targets)
        loss = sum(v for v in loss_dict.values())

        # Accumulate epoch loss
        epoch_loss += loss.cpu().detach().numpy()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch [{epoch + 1}/{num_epochs}] Loss: {epoch_loss:.4f}")

    # Save model every 10 epochs
    if (epoch + 1) % 10 == 0:
        save_path = os.path.join(save_dir, f"fcos_epoch_{epoch + 1}.pth")
        torch.save(model.state_dict(), save_path)
        print(f"Model saved at {save_path}")


    model.eval()  # Set the model to evaluation mode
    y_true = []
    y_pred = []

    iou_threshold = 0.5

    with torch.no_grad():
        for imgs, targets in valid_dl:
            imgs = [img.to(device) for img in imgs]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            outputs = model(imgs)  # Predictions from the model

            # Validate predictions
            for i, output in enumerate(outputs):
                pred_boxes = output["boxes"].cpu()
                pred_labels = output["labels"].cpu()
                gt_boxes = targets[i]["boxes"].cpu()
                gt_labels = targets[i]["labels"].cpu()

                # IoU matching between ground truths and predictions
                if len(gt_boxes) > 0 and len(pred_boxes) > 0:
                    iou_matrix = box_iou(gt_boxes, pred_boxes)

                    for gt_idx, gt_label in enumerate(gt_labels):
                        max_iou, pred_idx = iou_matrix[gt_idx].max(0)  # Best match for each GT box
                        if max_iou > iou_threshold:
                            y_true.append(gt_label.item())
                            y_pred.append(pred_labels[pred_idx].item())
                        else:
                            # False negative for unmatched ground truths
                            y_true.append(gt_label.item())
                            y_pred.append(0)  # Assume unmatched predictions as class 0

                    # False positives for unmatched predictions
                    matched_pred_indices = iou_matrix.argmax(0)
                    unmatched_preds = set(range(len(pred_boxes))) - set(matched_pred_indices.tolist())
                    for pred_idx in unmatched_preds:
                        y_true.append(0)  # No ground truth
                        y_pred.append(pred_labels[pred_idx].item())
                else:
                    # If no predictions or no ground truths, handle accordingly
                    y_true.extend(gt_labels.tolist())
                    y_pred.extend([0] * len(gt_labels))  # All are unmatched (false negatives)

                    y_true.extend([0] * len(pred_boxes))  # All predictions are unmatched (false positives)
                    y_pred.extend(pred_labels.tolist())

    # Calculate validation metrics
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average="weighted", zero_division=0)

    # Print validation metrics
    print(f"Validation Metrics:")
    print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}")

from google.colab import files
files.download('/content/fcos_model/fcos_epoch_100.pth')

Epoch [1/100] Loss: 48.5726
Validation - Epoch [1/100]:
Precision: 1.0000, Recall: 1.0000, F1-Score: 1.0000
Epoch [2/100] Loss: 24.1583
Validation - Epoch [2/100]:
Precision: 1.0000, Recall: 1.0000, F1-Score: 1.0000
Epoch [3/100] Loss: 20.3566
Validation - Epoch [3/100]:
Precision: 1.0000, Recall: 1.0000, F1-Score: 1.0000
Epoch [4/100] Loss: 17.8091
Validation - Epoch [4/100]:
Precision: 1.0000, Recall: 1.0000, F1-Score: 1.0000
Epoch [5/100] Loss: 16.2551
Validation - Epoch [5/100]:
Precision: 1.0000, Recall: 1.0000, F1-Score: 1.0000
Epoch [6/100] Loss: 15.1979
Validation - Epoch [6/100]:
Precision: 1.0000, Recall: 1.0000, F1-Score: 1.0000
Epoch [7/100] Loss: 14.6852
Validation - Epoch [7/100]:
Precision: 1.0000, Recall: 1.0000, F1-Score: 1.0000
Epoch [8/100] Loss: 13.6862
Validation - Epoch [8/100]:
Precision: 1.0000, Recall: 1.0000, F1-Score: 1.0000
Epoch [9/100] Loss: 13.0521
Validation - Epoch [9/100]:
Precision: 1.0000, Recall: 1.0000, F1-Score: 1.0000
Epoch [10/100] Loss: 12.4944