# SSD Lite model on the cows dataset from ICAERUS/France
Most of this is adapted from basic cookiecutter model training in pytorch.

Dataset consists of three areas: Jalogny, Derval and Mauron: these directly correspond to train/test/val sets.
I removed all images without an annotation in them (e.g. just a picture of a field, without a cow)
* train: jalogny: x img, x annotations
* test: derval: x imgs, x annotations
* val: mauron: x imgs, x annotations


In [None]:
!pip install tensorboard

In [3]:
import torch
import torchvision
from torchvision.models.detection import ssdlite320_mobilenet_v3_large
from torchvision.datasets import VOCDetection
from torchvision.transforms.v2 import functional as F
from torchvision.transforms import v2
from torchvision import ops
import torchvision.transforms.v2 as transforms
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import datapoints
from torchvision.ops import generalized_box_iou_loss
from torch.utils import tensorboard
from torchvision.datasets import VisionDataset

In [4]:

class CustomDataset(VisionDataset):
    def __init__(self, images, labels, boxes, transform=None):
        self.images = images
        self.labels = labels
        self.boxes = boxes
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        labels = torch.tensor(self.labels[idx], dtype=torch.long)
        boxes = torch.tensor(self.boxes[idx], dtype = torch.float32)
        target = {'boxes': boxes,
                 'labels': labels}
        
        if self.transform:
            image = self.transform(image)

        return image, target

In [5]:
# in the image_slicing.ipynb, the images are loaded and then tiled (in  a 320x320 grid), placed into a CustomDataset
# with normalized values and all as tensors (), these are then pickled and stored.

def collate_fn(batch):
    return tuple(zip(*batch))

# Create a VOC dataset
train_dataset = torch.load("data/train_set.pkl")
test_dataset = torch.load("data/test_set.pkl")
val_dataset = torch.load("data/val_set.pkl")
# Create a DataLoader for the VOC dataset
batch_size = 8
shuffle = True

#and put them in the loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, collate_fn = collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn = collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn = collate_fn)

In [24]:

from torchvision.ops import box_convert

# Training loop
num_epochs = 500
warmup_epochs = 50
evaluate_every = 10

# Create your SSD Lite model
model = ssdlite320_mobilenet_v3_large(weights= "SSDLite320_MobileNet_V3_Large_Weights.DEFAULT")

# Define your loss function and optimizer
criterion = generalized_box_iou_loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.015, momentum=0.9)
# Set up the cosine annealing learning rate scheduler
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs - warmup_epochs)

writer = tensorboard.SummaryWriter()

# Define the device (CPU or GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
for epoch in range(num_epochs):
    model.train()
    for i, data in enumerate(train_loader):
        images, targets = data        
        
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        #targets = [{k: v.to(device).long() if k == "labels" else v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()

        # Forward pass
        loss_dict = model(images, targets)
        losses = loss_dict["classification"] + loss_dict["bbox_regression"]


        # Backward pass and optimization
        losses.backward()
        optimizer.step()
    print(f"Epoch {epoch + 1}, Batch Loss: {losses}")
        
            # Update the learning rate
    scheduler.step()
        # Log the training loss to Tensorboard
    writer.add_scalar('Loss/train', losses.item(), epoch)
   
    
      # Evaluate on the validation set every 'evaluate_every' epochs
    if epoch % evaluate_every == 0:
        with torch.no_grad():
            val_losses = 0.0
            for images, targets in val_loader:
                images = [img.to(device) for img in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

                # Forward pass
                val_loss_dict = model(images, targets)
                
                val_losses += val_loss_dict["classification"] + loss_dict["bbox_regression"]

            avg_val_loss = val_losses / len(val_loader)

            # Log the validation loss to Tensorboard
            writer.add_scalar('Loss/val', avg_val_loss, epoch)
            print(f"Epoch {epoch + 1}, Validation Loss: {avg_val_loss.item()}")
            torch.save(model.state_dict(), f"models/ssdlite_cows_model_e{epoch}.pth")


    

# Save the trained model
torch.save(model.state_dict(), 'models/ssdlite_cows_model_500e.pth')


Epoch 1, Batch Loss: 2.7929348945617676
Epoch 1, Validation Loss: 2.7929370403289795
Epoch 2, Batch Loss: 2.1547679901123047
Epoch 3, Batch Loss: 2.3239827156066895
Epoch 4, Batch Loss: 1.9026097059249878
Epoch 5, Batch Loss: 2.445168972015381
Epoch 6, Batch Loss: 2.473205327987671
Epoch 7, Batch Loss: 2.8486196994781494
Epoch 8, Batch Loss: 2.6673994064331055
Epoch 9, Batch Loss: 1.0177251100540161
Epoch 10, Batch Loss: 3.4130921363830566
Epoch 11, Batch Loss: 2.5426130294799805
Epoch 11, Validation Loss: 2.542614221572876
Epoch 12, Batch Loss: 1.9889143705368042
Epoch 13, Batch Loss: 2.547889232635498
Epoch 14, Batch Loss: 1.6018751859664917
Epoch 15, Batch Loss: 1.93035089969635
Epoch 16, Batch Loss: 1.292243480682373
Epoch 17, Batch Loss: 1.1661664247512817
Epoch 18, Batch Loss: 1.805835247039795
Epoch 19, Batch Loss: 1.842522382736206
Epoch 20, Batch Loss: 1.6262757778167725
Epoch 21, Batch Loss: 0.8940614461898804
Epoch 21, Validation Loss: 0.8940613865852356
Epoch 22, Batch Loss