In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import torch
from torch.utils.data import DataLoader
import torchvision
import matplotlib.pyplot as plt
import dataset
from model import ResNet18YOLOv1
from loss import YOLOv1Loss
from tqdm import tqdm

# Load PASCAL VOC 2007 Dataset

In [2]:
# original dataset
pascal_voc_train = torchvision.datasets.VOCDetection(
    root="data",
    year="2007",
    image_set="train",
    download=False
)

In [3]:
# augment dataset for YOLOv1: resize and normalize image and convert bounding boxes from annotations to tensors
voc_train = dataset.PascalVOC(pascal_voc=pascal_voc_train)

TRANSFORMING PASCAL VOC


In [4]:
BATCH_SIZE = 64

In [5]:
train_dataloader = DataLoader(voc_train, batch_size=BATCH_SIZE, shuffle=True)

## Device

In [6]:
DEVICE = "cpu"
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    DEVICE = torch.device("mps")

DEVICE

device(type='mps')

## Hyperparameters

In [7]:
S = 7
B = 2
C = 20
lambda_coord = 5.0
lambda_noobj = 0.5

In [8]:
yolo = ResNet18YOLOv1(S=S, B=B, C=C).to(DEVICE)

In [9]:
yolo_loss = YOLOv1Loss(S=S, B=B, C=C, lambda_coord=lambda_coord, lambda_noobj=lambda_noobj)
optimizer = torch.optim.SGD(yolo.parameters(), lr=1e-3, weight_decay=0.0005, momentum=0.9)

In [10]:
EPOCHS = 1

In [11]:
yolo.train()
losses = []

for epoch in range(EPOCHS):
    total_loss = 0
    
    for i, (X, Y) in enumerate(tqdm(train_dataloader, desc=f"Epoch {epoch + 1} of {EPOCHS}", leave=True)):
        X = X.to(DEVICE)
        Y = Y.to(DEVICE)
        
        pred = yolo(X)
        # print(pred.shape)
        
        loss = yolo_loss(pred, Y)
        
        # backpropagation
        # loss.backward()
        # optimizer.step()
        # optimizer.zero_grad()
        
        total_loss += loss.item()
    
    loss = total_loss / len(train_dataloader)
    losses.append(loss)
    
    print(f"Epoch [{epoch + 1}/{EPOCHS}], Train Loss: {loss}")

Epoch 1 of 1: 100%|█████████████████████████████| 40/40 [02:49<00:00,  4.23s/it]

Epoch [1/1], Train Loss: 31.714705514907838



