# Training an object detector from FathomNet data using Detectron2
In this notebook, we'll:
1. Generate a dataset from FathomNet using the `fathomnet-generate` script
2. Split it into train/test sets
3. Train an object detection model (via Detectron2) on the train set
4. Evaluate the model on the test set
5. Visualize some predictions

## Setup
First, we need to install some relevant libraries and import them. We'll also define some constants that we'll use throughout the notebook.

In [None]:
%pip install fathomnet torch torchvision pillow 'git+https://github.com/facebookresearch/detectron2.git'

In [None]:
import gc
import random
from pathlib import Path

import numpy as np
from detectron2.model_zoo import get_config_file, get_checkpoint_url
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader
from detectron2.data.datasets import register_coco_instances
from detectron2.config import get_cfg
from detectron2.engine import DefaultTrainer, DefaultPredictor
from detectron2.evaluation import COCOEvaluator, inference_on_dataset
from detectron2.utils.visualizer import Visualizer
from PIL import Image
from torch import cuda

In [None]:
CONCEPTS = [
    'Aegina citrea',
]

FULL_DATASET_NAME = 'demo_dataset'
TRAIN_DATASET_NAME = FULL_DATASET_NAME + '_train'
TEST_DATASET_NAME = FULL_DATASET_NAME + '_test'

DATASET_DIR = Path(FULL_DATASET_NAME)
IMAGE_DIR = DATASET_DIR / 'images'

# Stringify for use in shell commands
CONCEPTS_STR = "'" + ','.join(CONCEPTS) + "'"
DATASET_DIR_STR = "'" + str(DATASET_DIR) + "'"
IMAGE_DIR_STR = "'" + str(IMAGE_DIR) + "'"
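
The manual quoting above works as long as no concept name or path contains an apostrophe. As a hedged alternative, the standard library's `shlex.quote` escapes arbitrary strings for the shell. `quote_for_shell` is a hypothetical helper introduced here for illustration, and the second concept name is made up to exercise the escaping.

```python
import shlex

def quote_for_shell(values):
    """Join values with commas and quote the result as one shell argument."""
    return shlex.quote(','.join(values))

# A name with a space is wrapped in single quotes
print(quote_for_shell(['Aegina citrea']))

# A (made-up) name containing an apostrophe is escaped rather than breaking the command
print(quote_for_shell(['Aegina citrea', "hypothetical 'quoted' name"]))
```

Under this sketch, `CONCEPTS_STR = quote_for_shell(CONCEPTS)` would be a drop-in replacement for the manual version.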

## 1. Generate a dataset from FathomNet
This command invokes the `fathomnet-generate` script to query FathomNet for images and annotations of the concepts listed above. It downloads the images into `IMAGE_DIR` and writes the annotations (formatted as COCO JSON) into `DATASET_DIR`.

In [None]:
!fathomnet-generate --format coco --concepts $CONCEPTS_STR --img-download $IMAGE_DIR_STR --output $DATASET_DIR_STR
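
Before training, it can help to check what was actually downloaded. The sketch below counts images and per-category annotations in a COCO-style dictionary. `summarize_coco` is a hypothetical helper, and the `example` dictionary is made-up data standing in for the generated file.

```python
from collections import Counter

def summarize_coco(coco):
    """Return (number of images, {category name: annotation count}) for a COCO dict."""
    id_to_name = {c['id']: c['name'] for c in coco['categories']}
    counts = Counter(id_to_name[a['category_id']] for a in coco['annotations'])
    return len(coco['images']), dict(counts)

# Tiny in-memory stand-in for the generated annotations (values are illustrative only)
example = {
    'images': [{'id': 1, 'file_name': 'a.png'}, {'id': 2, 'file_name': 'b.png'}],
    'categories': [{'id': 1, 'name': 'Aegina citrea'}],
    'annotations': [
        {'id': 1, 'image_id': 1, 'category_id': 1, 'bbox': [0, 0, 10, 10]},
        {'id': 2, 'image_id': 1, 'category_id': 1, 'bbox': [5, 5, 10, 10]},
        {'id': 3, 'image_id': 2, 'category_id': 1, 'bbox': [2, 2, 8, 8]},
    ],
}

print(summarize_coco(example))  # (2, {'Aegina citrea': 3})
```

On the real data, assuming the annotations are written to `dataset.json` inside `DATASET_DIR` (the path used in the next section), the call would be `summarize_coco(json.loads((DATASET_DIR / 'dataset.json').read_text()))` after `import json`.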

## 2. Split the dataset into train/test sets
Detectron2 uses an internal registry of datasets. We'll load our original COCO dataset into the registry, and then split it into train/test sets.

In [None]:
# If dataset(s) are already registered, remove them so they can be re-registered
for name in (FULL_DATASET_NAME, TRAIN_DATASET_NAME, TEST_DATASET_NAME):
    if name in DatasetCatalog.list():
        DatasetCatalog.remove(name)
    if name in MetadataCatalog.list():
        MetadataCatalog.remove(name)

# Register the full dataset
register_coco_instances(FULL_DATASET_NAME, {}, str(DATASET_DIR / 'dataset.json'), str(IMAGE_DIR))

# Split the dataset into train and test sets
dataset_dicts = DatasetCatalog.get(FULL_DATASET_NAME)
split = int(len(dataset_dicts) * 0.8)
random.shuffle(dataset_dicts)  # Shuffle the dataset to ensure randomness
train_dicts = dataset_dicts[:split]
test_dicts = dataset_dicts[split:]

# Register the train and test sets
full_metadata = MetadataCatalog.get(FULL_DATASET_NAME)
DatasetCatalog.register(TRAIN_DATASET_NAME, lambda: train_dicts)
MetadataCatalog.get(TRAIN_DATASET_NAME).set(thing_classes=full_metadata.thing_classes)
DatasetCatalog.register(TEST_DATASET_NAME, lambda: test_dicts)
MetadataCatalog.get(TEST_DATASET_NAME).set(thing_classes=full_metadata.thing_classes)

train_metadata = MetadataCatalog.get(TRAIN_DATASET_NAME)
test_metadata = MetadataCatalog.get(TEST_DATASET_NAME)

print(f'Train dataset has {len(train_dicts)} images')
print(f'Test dataset has {len(test_dicts)} images')
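
The split above changes on every run because the shuffle is unseeded. If a repeatable split is wanted, one option is a dedicated seeded generator, sketched below. `split_items`, `train_fraction`, and `seed` are hypothetical names, not part of the original notebook.

```python
import random

def split_items(items, train_fraction=0.8, seed=0):
    """Shuffle a copy of items with a fixed seed and split it into (train, test)."""
    shuffled = list(items)                 # copy, so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)  # private generator, global random state unaffected
    cut = int(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]

train, test = split_items(range(10))
print(len(train), len(test))  # 8 2
```

The same function could be applied to `dataset_dicts` in place of the shuffle-and-slice above. Using a private `random.Random` instance keeps the split reproducible without affecting other uses of the `random` module.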

## 3. Train an object detection model


### Configure
Now we'll set some configuration options for training. We'll start from a RetinaNet model pre-trained on COCO (the `retinanet_R_50_FPN_3x` configuration from the Detectron2 model zoo) and fine-tune it for 200 iterations.

In [None]:
BACKBONE = 'COCO-Detection/retinanet_R_50_FPN_3x.yaml'                          # The base network to use

# Define the training configuration
cfg = get_cfg()
cfg.merge_from_file(get_config_file(BACKBONE))                                  # Retrieve the base network configuration
cfg.DATASETS.TRAIN = (TRAIN_DATASET_NAME,)                                      # Tell it what to use for training
cfg.DATASETS.TEST = (TEST_DATASET_NAME,)                                        # Tell it what to use for testing (the held-out split)
cfg.DATALOADER.NUM_WORKERS = 2                                                  # The number of worker processes used to load and preprocess images
cfg.MODEL.WEIGHTS = get_checkpoint_url(BACKBONE)                                # Retrieve the weights for the desired base network
cfg.SOLVER.IMS_PER_BATCH = 2                                                    # How many images to give the GPU at a time. Make this bigger if you have a more powerful card
cfg.SOLVER.BASE_LR = 0.0025                                                     # The learning rate: how far to move the weights at each update
cfg.SOLVER.MAX_ITER = 200                                                       # How many training iterations to run (one batch of images per iteration)
cfg.SOLVER.STEPS = []                                                           # When to change the learning rate. We are not making adjustments for this small training run
cfg.MODEL.RETINANET.NUM_CLASSES = len(train_metadata.thing_classes)             # The number of output classes

# Create the output directory
Path(cfg.OUTPUT_DIR).mkdir(exist_ok=True, parents=True)
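
`cfg.SOLVER.MAX_ITER` counts training iterations, and each iteration processes `cfg.SOLVER.IMS_PER_BATCH` images, so the number of passes over the training set depends on its size. The arithmetic is sketched below. `approx_epochs` is a hypothetical helper, and the figure of 100 training images is an assumed value for illustration.

```python
def approx_epochs(max_iter, ims_per_batch, num_train_images):
    """Approximate number of passes over the training set for a given schedule."""
    return max_iter * ims_per_batch / num_train_images

# 200 iterations x 2 images per batch over an assumed 100 training images
print(approx_epochs(200, 2, 100))  # 4.0
```

With the values in this notebook, `approx_epochs(cfg.SOLVER.MAX_ITER, cfg.SOLVER.IMS_PER_BATCH, len(train_dicts))` gives the figure for the actual split.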

### Run training
Now we'll train the model. This can take a while. When training finishes, the trainer writes the final weights to `model_final.pth` in `cfg.OUTPUT_DIR`, so we can load them later for prediction.

In [None]:
# Clean up memory
gc.collect()
cuda.empty_cache()

# Spin up a trainer
trainer = DefaultTrainer(cfg)
trainer.resume_or_load(resume=False)

# Train! This may take a while, depending on the configuration
trainer.train()

## 4. Evaluate the model
Now that we have a trained object detector, let's evaluate it on the test set.

In [None]:
# Evaluate the model on the test set
# (passing cfg to COCOEvaluator is deprecated; keyword arguments are used instead)
evaluator = COCOEvaluator(TEST_DATASET_NAME, distributed=False, output_dir=cfg.OUTPUT_DIR)
val_loader = build_detection_test_loader(cfg, TEST_DATASET_NAME)
stats = inference_on_dataset(trainer.model, val_loader, evaluator)
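
The COCO-style average precision figures reported by the evaluator are built on intersection over union (IoU): a predicted box counts as a match only if its IoU with a ground-truth box reaches a threshold (the headline AP averages over thresholds from 0.50 to 0.95). The function below is a minimal sketch of IoU for two boxes in `(x1, y1, x2, y2)` form. `box_iou` is a hypothetical helper for illustration, not the evaluator's own code.

```python
def box_iou(a, b):
    """Intersection over union of two boxes given as (x1, y1, x2, y2)."""
    # Corners of the overlapping rectangle (empty if the boxes do not overlap)
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)

    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

# Two 2x2 boxes overlapping in a 1x1 square: IoU = 1 / (4 + 4 - 1)
print(box_iou((0, 0, 2, 2), (1, 1, 3, 3)))
```

Identical boxes give an IoU of 1.0 and disjoint boxes give 0.0, so a threshold of 0.5 requires a prediction to overlap its target substantially.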

## 5. Visualize some predictions
Finally, we'll visualize some predictions from the model. We'll load the model from disk and run it on a few images from the test set.

In [None]:
# Load the model weights into the configuration
cfg.MODEL.WEIGHTS = str(Path(cfg.OUTPUT_DIR) / 'model_final.pth')

# Set the confidence threshold (predictions with confidence below this value will be omitted)
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.5

# Spin up a predictor with the provided config
predictor = DefaultPredictor(cfg)

In [None]:
# Get paths to the test images
test_image_paths = [td['file_name'] for td in test_dicts]

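# Pick up to 10 random images from the test set and run them through the model for display
for image_path in random.sample(test_image_paths, min(10, len(test_image_paths))):
    # Load as RGB (some images may be grayscale or carry an alpha channel)
    image_pil = Image.open(image_path).convert('RGB')
    image = np.array(image_pil)

    # Run the image through the model. DefaultPredictor expects the channel order
    # given by cfg.INPUT.FORMAT (BGR by default), so reverse the RGB channels first
    outputs = predictor(np.ascontiguousarray(image[:, :, ::-1]))

    # Render the predictions on the RGB image
    visualizer = Visualizer(image, metadata=test_metadata, scale=0.5)
    output_image = visualizer.draw_instance_predictions(outputs['instances'].to('cpu'))
    output_image_pil = Image.fromarray(output_image.get_image())

    display(output_image_pil)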