In [1]:
import torch
from ultralytics import YOLO
from datasets import YOLODataset
from torch.utils.data import DataLoader

In [2]:
# Root directory of the dataset
class_names = ['fish', 'jellyfish', 'penguin', 'puffin', 'shark', 'starfish', 'stingray']
num_classes = len(class_names)
root_dir = 'datasets/aquarium-data-cots/aquarium_pretrain'

# Create datasets
train_dataset = YOLODataset(root_dir, split='train', num_classes=num_classes)
valid_dataset = YOLODataset(root_dir, split='valid', num_classes=num_classes)
test_dataset = YOLODataset(root_dir, split='test', num_classes=num_classes)

# Create data loaders
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(valid_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Training samples: 448
Validation samples: 127
Test samples: 63


## Model

In [3]:
# Define the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = YOLO('yolov5su.pt')
model = model.to(device)

## Train

In [None]:
num_epochs = 10

model.train(
    data='datasets/aquarium-data-cots/aquarium_pretrain/data.yaml', 
    epochs=num_epochs, 
    imgsz=640,  
    save_period=1,  # Save model every epoch
    save_dir='runs/train',  # Directory to save training results
)
 

## Validation Set

In [8]:
data_yaml_path = 'datasets/aquarium-data-cots/aquarium_pretrain/data.yaml'

# Run validation (inference) on the validation set
results = model.val(
    data=data_yaml_path,  # Path to the data.yaml file
    conf=0.5,              # Confidence threshold for predictions
    save_json=True,        # Optionally save results in COCO JSON format
    save_txt=True,         # Optionally save predictions as YOLO-format .txt files
)


[34m[1mval: [0mScanning /home/ubuntu/cs230_proj/datasets/aquarium-data-cots/aquarium_pretrain/valid/labels.cache... 127 images, 0 backgrounds, 0 corrupt: 100%|██████████| 127/127 [00:00<?, ?it/s]


                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██████████| 8/8 [00:01<00:00,  6.12it/s]


                   all        127        909      0.143       0.01     0.0764     0.0614
                person         63        459          0          0          0          0
               bicycle          9        155          0          0          0          0
                   car         17        104          0          0          0          0
            motorcycle         15         74          0          0          0          0
              airplane         28         57          1     0.0702      0.535       0.43
                   bus         17         27          0          0          0          0
                 train         23         33          0          0          0          0
Speed: 1.3ms preprocess, 2.6ms inference, 0.0ms loss, 0.4ms postprocess per image
Saving runs/detect/val2/predictions.json...
Results saved to [1mruns/detect/val2[0m


AttributeError: 'DetMetrics' object has no attribute 'pred'. See valid attributes below.

    Utility class for computing detection metrics such as precision, recall, and mean average precision (mAP) of an
    object detection model.

    Args:
        save_dir (Path): A path to the directory where the output plots will be saved. Defaults to current directory.
        plot (bool): A flag that indicates whether to plot precision-recall curves for each class. Defaults to False.
        on_plot (func): An optional callback to pass plots path and data when they are rendered. Defaults to None.
        names (dict of str): A dict of strings that represents the names of the classes. Defaults to an empty tuple.

    Attributes:
        save_dir (Path): A path to the directory where the output plots will be saved.
        plot (bool): A flag that indicates whether to plot the precision-recall curves for each class.
        on_plot (func): An optional callback to pass plots path and data when they are rendered.
        names (dict of str): A dict of strings that represents the names of the classes.
        box (Metric): An instance of the Metric class for storing the results of the detection metrics.
        speed (dict): A dictionary for storing the execution time of different parts of the detection process.

    Methods:
        process(tp, conf, pred_cls, target_cls): Updates the metric results with the latest batch of predictions.
        keys: Returns a list of keys for accessing the computed detection metrics.
        mean_results: Returns a list of mean values for the computed detection metrics.
        class_result(i): Returns a list of values for the computed detection metrics for a specific class.
        maps: Returns a dictionary of mean average precision (mAP) values for different IoU thresholds.
        fitness: Computes the fitness score based on the computed detection metrics.
        ap_class_index: Returns a list of class indices sorted by their average precision (AP) values.
        results_dict: Returns a dictionary that maps detection metric keys to their computed values.
        curves: TODO
        curves_results: TODO
    

# Test Set