In [6]:
import torch
import matplotlib.pyplot as plt
from ultralytics import YOLO
from datasets import YOLODataset
from torch.utils.data import DataLoader
from collections import Counter

In [None]:
# Root directory of the dataset
class_names = ['fish', 'jellyfish', 'penguin', 'puffin', 'shark', 'starfish', 'stingray']
num_classes = len(class_names)
root_dir = 'datasets/aquarium-data-cots/aquarium_pretrain'

# Create datasets
train_dataset = YOLODataset(root_dir, split='train', num_classes=num_classes)
valid_dataset = YOLODataset(root_dir, split='valid', num_classes=num_classes)
test_dataset = YOLODataset(root_dir, split='test', num_classes=num_classes)

# Custom sampler for balanced class representation
train_sampler = train_dataset.get_sampler()  

# Create data loaders with the custom collate function to track class counts
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(valid_dataset)}")
print(f"Test samples: {len(test_dataset)}")


## Visualization

In [8]:
def visualize_augmented_images(dataset, num_samples=5):
    fig, axes = plt.subplots(1, num_samples, figsize=(15, 15))
    
    for i in range(num_samples):
        image, labels, bboxes = dataset[i]
        
        image = image.permute(1, 2, 0).numpy()  # Change from CxHxW to HxWxC, tensor to numpy for plotting
        
        ax = axes[i]
        ax.imshow(image)
        
        # Draw bounding boxes
        for bbox in bboxes:
            xmin, ymin, xmax, ymax, class_id = bbox
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, 
                                      edgecolor='red', facecolor='none', linewidth=2))
            ax.text(xmin, ymin, f'Class {int(class_id)}', color='white', fontsize=10, 
                    bbox=dict(facecolor='red', alpha=0.5))
        
        ax.axis('off') 
    
    plt.show()

In [None]:
visualize_augmented_images(train_dataset, num_samples=5)

## Model

In [10]:
# Define the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = YOLO('yolov5su.pt')
model = model.to(device)

## Train

In [None]:
num_epochs = 25

# YOLO will assess the model on the validation set in this training block

model.train(
    data='datasets/aquarium-data-cots/aquarium_pretrain/data.yaml', 
    epochs=num_epochs, 
    imgsz=640,  
    save_period=1,  # Save model every epoch
    save_dir='runs/train', 
    
)

## Inference on Test Set

In [14]:
data_yaml_path = 'datasets/aquarium-data-cots/aquarium_pretrain/data.yaml'

# Run inference on the test set

results = model.val(
    data='datasets/aquarium-data-cots/aquarium_pretrain/data.yaml', 
    conf=0.5,              # Confidence threshold for predictions
    save_json=True,        # Save results in COCO JSON format
    save_txt=True,         # Save predictions as YOLO-format .txt files
    split = 'test',
    project='runs',
    name='test'
)

Ultralytics 8.3.27 🚀 Python-3.10.15 torch-2.5.1+cu124 CUDA:0 (NVIDIA L4, 22478MiB)


[34m[1mval: [0mScanning /home/ubuntu/cs230_hannah/cs230_project/datasets/aquarium-data-cots/aquarium_pretrain/test/labels.cache... 63 images, 0 backgrounds, 0 corrupt: 100%|██████████| 63/63 [00:00<?, ?it/s]
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 4/4 [00:01<00:00,  2.13it/s]


                   all         63        584      0.913      0.648      0.794       0.56
                  fish         30        249      0.821      0.643      0.755      0.525
             jellyfish         11        154      0.887      0.818      0.871      0.627
               penguin          7         82      0.907      0.598      0.768      0.403
                puffin          6         35      0.812      0.371      0.614      0.361
                 shark         14         38      0.964      0.711       0.85      0.673
              starfish          5         11          1      0.727      0.864      0.649
              stingray         10         15          1      0.667      0.833      0.685
Speed: 0.4ms preprocess, 2.7ms inference, 0.0ms loss, 1.9ms postprocess per image
Saving runs/test/predictions.json...
Results saved to [1mruns/test[0m


## Inference on Video

In [15]:
model = YOLO('runs/detect/train5 (newest)/weights/best.pt')

In [None]:
results = model.predict(
    source='/home/ubuntu/cs230_hannah/cs230_project/datasets/aquarium-data-cots/aquarium_pretrain/aquarium.mp4',  
    conf=0.5,                    # Confidence threshold for detections
    save=True,                   # Save the output video with annotations
    save_txt=True,               # Save detections in YOLO-format text files
    save_conf=True,              # Save confidence scores in text files
    project='runs',              # Directory to save results
    name='video_detection',      # Subdirectory name
    exist_ok=True,               # Overwrite existing results if directory exists
)