In [1]:
!pip install wandb

Collecting wandb
  Downloading wandb-0.17.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl.metadata (13 kB)
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-2.12.0-py2.py3-none-any.whl.metadata (9.8 kB)
Collecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB)
Collecting gitdb<5,>=4.0.1 (from gitpython!=3.1.29,>=1.0.0->wandb)
  Downloading gitdb-4.0.11-py3-none-any.whl.metadata (1.2 kB)
Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->gitpython!=3.1.29,>=1.0.0->wandb)
  Downloading smmap-5.0.1-py3-none-any.whl.metadata (4.3 kB)
Downloading wandb-0.17.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_6

In [None]:
import torch
import torchvision
from torchvision.datasets import VOCDetection
from torchvision.transforms import functional as F
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix
import numpy as np
from torchvision.ops import box_iou
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import os
from ultralytics import YOLO
from PIL import Image

# Initialize WandB for experiment tracking
api_key = "9ce954fd827fd8d839648cb3708ff788ad51bafa"
wandb.login(key=api_key)
wandb.init(project="yolov8-project-pretrained", entity="enxo7899")

# Hyperparameters and configuration
hyperparams = {
    "batch_size": 8,            # Increased batch size for faster training if memory allows
    "num_workers": 4,           # Increased workers for faster data loading
    "learning_rate": 0.005,     # Learning rate for optimizer
    "momentum": 0.9,            # Momentum for SGD
    "num_epochs": 3,            # Reduced epochs for faster runs
    "iou_threshold": 0.5,       # IoU threshold for evaluation
}

# Check if GPU is available and use it if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define a mapping from class names to integers
CLASS_NAMES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair",
    "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant",
    "sheep", "sofa", "train", "tvmonitor"
]
CLASS_NAME_TO_IDX = {name: i for i, name in enumerate(CLASS_NAMES)}

# Download and prepare the VOC2012 dataset
class VOCDataset(torch.utils.data.Dataset):
    def __init__(self, root, year, image_set, transforms):
        self.dataset = VOCDetection(root, year=year, image_set=image_set, download=True)
        self.transforms = transforms

    def __getitem__(self, idx):
        image = self.dataset[idx][0]
        target = self.dataset[idx][1]
        target = self._convert_target(target)

        if self.transforms:
            image = self.transforms(image)

        return image, target

    def __len__(self):
        return len(self.dataset)

    def _convert_target(self, target):
        # Convert the target from the VOC format to a format compatible with YOLOv8
        boxes = []
        labels = []
        objects = target['annotation']['object']
        if isinstance(objects, dict):
            objects = [objects]

        for obj in objects:
            bbox = obj['bndbox']
            box = [
                int(bbox['xmin']),
                int(bbox['ymin']),
                int(bbox['xmax']),
                int(bbox['ymax'])
            ]
            boxes.append(box)
            labels.append(CLASS_NAME_TO_IDX[obj['name']])

        return {
            'boxes': torch.as_tensor(boxes, dtype=torch.float32),
            'labels': torch.as_tensor(labels, dtype=torch.int64)
        }

# Define transformations
def get_transform(train):
    transforms = []
    transforms.append(F.to_tensor)
    return torchvision.transforms.Compose(transforms)

# Initialize dataset and dataloaders
full_dataset = VOCDataset(root='./data', year='2012', image_set='train', transforms=get_transform(train=True))
test_dataset = VOCDataset(root='./data', year='2012', image_set='val', transforms=get_transform(train=False))

# Split the dataset for training (70%) and validation (30%)
train_size = int(0.7 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Data loaders
data_loader = DataLoader(train_dataset, batch_size=hyperparams["batch_size"], shuffle=True, num_workers=hyperparams["num_workers"], collate_fn=lambda x: tuple(zip(*x)))
data_loader_val = DataLoader(val_dataset, batch_size=hyperparams["batch_size"], shuffle=False, num_workers=hyperparams["num_workers"], collate_fn=lambda x: tuple(zip(*x)))
data_loader_test = DataLoader(test_dataset, batch_size=hyperparams["batch_size"], shuffle=False, num_workers=hyperparams["num_workers"], collate_fn=lambda x: tuple(zip(*x)))

# Load a pretrained YOLOv8 model
model = YOLO('yolov8n.pt')
model.to(device)

# Function to save a model checkpoint
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    torch.save(state, filename)
    print(f"Checkpoint saved to {filename}")

# Training loop
num_epochs = hyperparams["num_epochs"]
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, (images, targets) in enumerate(data_loader):
        images = [image.to(device) for image in images]
        results = model(images)  # Forward pass with YOLOv8



        # Save checkpoint every epoch
        save_checkpoint({'epoch': epoch + 1, 'state_dict': model.state_dict()}, f"checkpoint_epoch_{epoch+1}.pth.tar")

# Function to calculate evaluation metrics
def calculate_metrics(model, data_loader, device, iou_threshold=hyperparams["iou_threshold"]):
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        for images, targets in data_loader:
            images = [image.to(device) for image in images]
            results = model(images)  # Forward pass with YOLOv8

            for target, output in zip(targets, results):
                true_boxes = target['boxes'].to(device)
                true_labels = target['labels'].cpu().numpy()

                pred_boxes = output.boxes.xyxy.to(device)
                pred_labels = output.boxes.cls.cpu().numpy()
                pred_scores = output.boxes.conf.cpu().numpy()

                # Filter predictions by a score threshold (e.g., 0.5)
                keep = pred_scores > 0.5
                pred_boxes = pred_boxes[keep]
                pred_labels = pred_labels[keep]

                # Calculate IoU between predicted and true boxes
                if len(pred_boxes) > 0 and len(true_boxes) > 0:
                    ious = box_iou(pred_boxes, true_boxes)
                    ious_max, indices = ious.max(dim=1)
                    matched = ious_max > iou_threshold

                    # Filter matched indices for valid IoU
                    matched_true_indices = indices[matched].cpu().numpy()
                    matched_pred_indices = np.where(matched.cpu().numpy())[0]

                    # Collect the matched true and predicted labels
                    y_true.extend(true_labels[matched_true_indices].tolist())
                    y_pred.extend(pred_labels[matched_pred_indices].tolist())

    # Calculate precision, recall, and F1 score
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0)
    conf_matrix = confusion_matrix(y_true, y_pred)

    # Log metrics to WandB
    wandb.log({"precision": precision, "recall": recall, "f1_score": f1})

    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print("Confusion Matrix:")
    print(conf_matrix)

    # Plot the confusion matrix
    plot_confusion_matrix(conf_matrix, CLASS_NAMES)

def plot_confusion_matrix(conf_matrix, class_names):
    plt.figure(figsize=(12, 10))
    sns.set(font_scale=1.2)
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.tight_layout()
    plt.show()

# Calculate and print metrics on the validation dataset
calculate_metrics(model, data_loader_val, device)

# Finish WandB logging
wandb.finish()

# Inference function for a single image
def inference(model, image_path, device):
    model.eval()
    image = Image.open(image_path)
    transform = get_transform(train=False)
    image = transform(image).to(device).unsqueeze(0)

    with torch.no_grad():
        results = model(image)

    # Extract predictions
    pred_boxes = results.boxes.xyxy.cpu().numpy()
    pred_labels = results.boxes.cls.cpu().numpy()
    pred_scores = results.boxes.conf.cpu().numpy()

    # Display results
    for box, label, score in zip(pred_boxes, pred_labels, pred_scores):
        print(f"Label: {CLASS_NAMES[int(label)]}, Score: {score}, Box: {box}")

    return results

# Example usage of inference
# inference_results = inference(model, './data/VOCdevkit/VOC2012/JPEGImages/2007_000032.jpg', device)




VBox(children=(Label(value='2.238 MB of 2.238 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
lr/pg0,▁
lr/pg1,▁
lr/pg2,▁
metrics/mAP50(B),▁
metrics/mAP50-95(B),▁
metrics/precision(B),▁
metrics/recall(B),▁
model/GFLOPs,▁
model/parameters,▁
model/speed_PyTorch(ms),▁

0,1
lr/pg0,0.00333
lr/pg1,0.00333
lr/pg2,0.00333
metrics/mAP50(B),0.44227
metrics/mAP50-95(B),0.30404
metrics/precision(B),0.55582
metrics/recall(B),0.4211
model/GFLOPs,8.858
model/parameters,3157200.0
model/speed_PyTorch(ms),1.507


Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
Extracting ./data/VOCtrainval_11-May-2012.tar to ./data
Using downloaded and verified file: ./data/VOCtrainval_11-May-2012.tar
Extracting ./data/VOCtrainval_11-May-2012.tar to ./data
[34m[1mengine/trainer: [0mtask=detect, mode=train, model=yolov8n.pt, data=coco.yaml, epochs=100, time=None, patience=100, batch=16, imgsz=640, save=True, save_period=-1, cache=False, device=cuda:0, workers=8, project=None, name=train2, exist_ok=False, pretrained=True, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=10, resume=False, amp=True, fraction=1.0, profile=False, freeze=None, multi_scale=False, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnostic_nms=False,

[34m[1mtrain: [0mScanning /content/datasets/coco/labels/train2017.cache... 117266 images, 1021 backgrounds, 0 corrupt: 100%|██████████| 118287/118287 [00:00<?, ?it/s]


[34m[1malbumentations: [0mBlur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01), CLAHE(p=0.01, clip_limit=(1, 4.0), tile_grid_size=(8, 8))


os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
[34m[1mval: [0mScanning /content/datasets/coco/labels/val2017.cache... 4952 images, 48 backgrounds, 0 corrupt: 100%|██████████| 5000/5000 [00:00<?, ?it/s]


Plotting labels to runs/detect/train2/labels.jpg... 
[34m[1moptimizer:[0m 'optimizer=auto' found, ignoring 'lr0=0.01' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... 
[34m[1moptimizer:[0m SGD(lr=0.01, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0005), 63 bias(decay=0.0)
[34m[1mTensorBoard: [0mmodel graph visualization added ✅
Image sizes 640 train, 640 val
Using 8 dataloader workers
Logging results to [1mruns/detect/train2[0m
Starting training for 100 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      1/100       4.7G      1.133      1.413      1.181        204        640: 100%|██████████| 7393/7393 [36:15<00:00,  3.40it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 157/157 [00:27<00:00,  5.76it/s]


                   all       5000      36335      0.556      0.421      0.442      0.304

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      2/100      5.71G      1.229      1.647      1.243        184        640:  62%|██████▏   | 4584/7393 [21:15<11:20,  4.13it/s]