In [20]:
import torch
import matplotlib.pyplot as plt
from ultralytics import YOLO
from datasets import YOLODataset
from torch.utils.data import DataLoader

In [21]:
# Dataset location on disk and the class vocabulary it is labelled with
root_dir = 'datasets/aquarium-data-cots/aquarium_pretrain'
class_names = ['fish', 'jellyfish', 'penguin', 'puffin', 'shark', 'starfish', 'stingray']
num_classes = len(class_names)

# Build one YOLODataset per split, in train / valid / test order
train_dataset, valid_dataset, test_dataset = (
    YOLODataset(root_dir, split=split_name, num_classes=num_classes)
    for split_name in ('train', 'valid', 'test')
)

# Wrap every split in a DataLoader; only the training split is shuffled
batch_size = 32
loader_options = {'batch_size': batch_size, 'num_workers': 4}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_options)

# Report how many samples each split contains
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(valid_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Training samples: 448
Validation samples: 127
Test samples: 63


## Visualization

In [None]:
def visualize_augmented_images(dataset, num_samples=5):
    """Show the first ``num_samples`` items of ``dataset`` in one row, with boxes drawn.

    Args:
        dataset: Sized, indexable collection whose items unpack as
            ``(image, labels, bboxes)``. ``image`` is a CxHxW torch tensor and
            each entry of ``bboxes`` unpacks as
            ``(xmin, ymin, xmax, ymax, class_id)``.
            NOTE(review): box coordinates are drawn as-is, so they are
            presumably in pixel units of the image -- confirm against the
            dataset class.
        num_samples: Maximum number of items to show. Clamped to
            ``len(dataset)``; nothing is drawn when that leaves zero or fewer.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Clamp so a dataset smaller than the request does not fail on dataset[i].
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    # squeeze=False keeps `axes` a 2-D array even for a single column; without
    # it plt.subplots returns a bare Axes when num_samples == 1 and indexing
    # into it fails.
    _, axes = plt.subplots(1, num_samples, figsize=(15, 15), squeeze=False)

    for i, ax in enumerate(axes[0]):
        image, _labels, bboxes = dataset[i]

        # CxHxW tensor -> HxWxC numpy array for imshow. detach()/cpu() let
        # tensors that track gradients or live on an accelerator be converted.
        image = image.detach().cpu().permute(1, 2, 0).numpy()
        ax.imshow(image)

        # Draw each bounding box as a red outline with its class id as a label
        for bbox in bboxes:
            xmin, ymin, xmax, ymax, class_id = bbox
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       edgecolor='red', facecolor='none', linewidth=2))
            ax.text(xmin, ymin, f'Class {int(class_id)}', color='white', fontsize=10,
                    bbox=dict(facecolor='red', alpha=0.5))

        ax.axis('off')

    plt.show()

In [None]:
# Preview the first five items of the training split with their boxes drawn
visualize_augmented_images(dataset=train_dataset, num_samples=5)

## Model

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the YOLO model from the 'yolov5su.pt' weights file
model = YOLO('yolov5su.pt')

# Move the model in place instead of rebinding `model` to the return value,
# so the notebook does not depend on `.to()` returning the wrapper itself.
# The trailing `;` stops the notebook from echoing the returned object.
model.to(device);

## Train

In [None]:
num_epochs = 25

# Training configuration, handed to model.train() as keyword arguments.
# A hyperparameter file can be supplied through an extra 'hyp' entry, e.g.
# 'datasets/aquarium-data-cots/aquarium_pretrain/hyp.yaml' (currently unused).
train_args = {
    'data': 'datasets/aquarium-data-cots/aquarium_pretrain/data.yaml',
    'epochs': num_epochs,
    'imgsz': 640,
    'save_period': 1,  # write a checkpoint after every epoch
    # NOTE(review): `save_dir` does not appear among the documented Ultralytics
    # train arguments (the output location is normally set with `project` and
    # `name`) -- confirm this value is actually honoured.
    'save_dir': 'runs/train',
}

model.train(**train_args)

## Validation Set

In [None]:
# Dataset description file that model.val() is pointed at
data_yaml_path = 'datasets/aquarium-data-cots/aquarium_pretrain/data.yaml'

# Evaluate the model on the dataset described by data.yaml and keep the
# returned object in `results`.
# NOTE(review): conf=0.5 looks much stricter than the Ultralytics validation
# default, so the reported metrics may not be comparable with runs that use
# default settings -- confirm this threshold is intended.
results = model.val(
    save_txt=True,        # also write predictions as YOLO-format .txt files
    save_json=True,       # also write predictions as COCO-style JSON
    conf=0.5,             # minimum confidence for a prediction to be kept
    data=data_yaml_path,
)