Temporal WaSR-T model for maritime obstacle detection via semantic segmentation

lojzezust/WaSR-T

WaSR-T: Temporal Context for Robust Maritime Obstacle Detection

This is the official PyTorch implementation of the WaSR-T network [1]. The repository contains scripts for training and running the network, as well as weights pretrained on the MaSTr1325 [2] and MaSTr1478 datasets.

Our work was presented at the IROS 2022 conference in Kyoto, Japan.

Figure: Comparison between WaSR (single-frame) and WaSR-T (temporal context) on hard examples.

April 2023: A mobile adaptation of WaSR-T has been added.

About WaSR-T

WaSR-T is a temporal extension of the established WaSR model [3] for maritime obstacle detection. It harnesses the temporal context of recent image frames to reduce ambiguity caused by reflections and to improve the overall robustness of predictions.

Figure: WaSR-T architecture.

The target and context (i.e. previous) frames are encoded using a shared encoder network. To extract the temporal context from the past frames, we apply a 3D convolution operation over the temporal dimension. The 3D convolution is able to extract discriminative information about local texture changes over the recent frames. The resulting temporal context is concatenated with the target frame features and passed to the decoder, which produces the final predictions.
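To make the data flow concrete, here is a minimal NumPy sketch of the idea. It is not the repository's implementation. It assumes that the encoder features of the context frames and the target frame are stacked along a temporal axis, and that a single 3D kernel spans all frames in time (so the temporal axis collapses to one slice) while covering a small 3x3 spatial window. The function name `temporal_context`, the kernel size and the channel counts are illustrative assumptions. In PyTorch the same operation would normally be expressed with `torch.nn.Conv3d`.

```python
import numpy as np


def temporal_context(feats: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Naive 3D convolution that collapses the temporal axis.

    feats:  [C, T, H, W] encoder features of T stacked frames (oldest first).
    weight: [C_out, C, T, kh, kw] kernel spanning all T frames, so the output has
            a single temporal slice ("valid" in time); the spatial axes are
            zero-padded ("same" in space, odd kh and kw assumed).
    Returns [C_out, H, W] temporal context features.
    """
    c_out, c_in, kt, kh, kw = weight.shape
    C, T, H, W = feats.shape
    assert (c_in, kt) == (C, T), "kernel must match channel and frame count"
    ph, pw = kh // 2, kw // 2
    padded = np.pad(feats, ((0, 0), (0, 0), (ph, ph), (pw, pw)))
    out = np.zeros((c_out, H, W), dtype=feats.dtype)
    for i in range(H):
        for j in range(W):
            patch = padded[:, :, i:i + kh, j:j + kw]  # [C, T, kh, kw]
            # Weighted sum over channels, frames and the spatial window.
            out[:, i, j] = np.tensordot(weight, patch, axes=4)
    return out


# Hypothetical sizes: 8 feature channels, 5 context frames + 1 target frame.
rng = np.random.default_rng(0)
C, T, H, W = 8, 6, 12, 16
target = rng.standard_normal((C, H, W))
context_frames = rng.standard_normal((C, T - 1, H, W))
stacked = np.concatenate([context_frames, target[:, None]], axis=1)  # [C, T, H, W]
weight = 0.1 * rng.standard_normal((4, C, T, 3, 3))

ctx = temporal_context(stacked, weight)                # [4, H, W]
decoder_input = np.concatenate([target, ctx], axis=0)  # [C + 4, H, W]
```

Because the kernel covers several frames at each spatial location, it can respond to how local texture changes over time, which is the property the paragraph above relies on; the concatenation at the end mirrors passing target features plus temporal context to the decoder.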

Setup

Requirements: Python >= 3.6 (tested on Python 3.8), PyTorch 1.8.1, PyTorch Lightning 1.4.4 (for training)

The required Python libraries can be installed with the following pip command:

pip install -r requirements.txt

Usage

The WaSR-T model used in our experiments (ResNet-101 backbone) can be initialized with the following code.

from wasr_t.wasr_t import wasr_temporal_resnet101

model = wasr_temporal_resnet101(num_classes=3)

The WaSR-T model operates in two modes:

  • sequential: An online mode, useful for inference. Only one frame is processed at a time, and features of the previous frames are stored in a circular buffer. Because frames must be processed one after another, the batch size must be 1. The context buffer is initialized from copies of the first frame in the sequence.
  • unrolled: An offline mode, used during training. Each sample consists of a target frame and the required previous frames. Supports batched processing.

You can switch between the two modes by calling sequential() or unrolled() on the model.

Example of sequential operation:

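import torch

model = model.sequential()
model.eval()  # inference mode (e.g. batch norm uses stored statistics)

model.clear_state() # Clear the temporal buffer of the model
with torch.no_grad():  # gradients are not needed for inference
    for image in sequence:
        # image is a [1,3,H,W] tensor
        output = model({'image': image})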

Note

If you run inference on multiple sequences, you must call clear_state() on the model before moving to a new sequence. Otherwise, the context from the last frames of the previous sequence will be used, which may lead to faulty predictions.
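The sketch below models the sequential-mode buffer in plain Python to show why clear_state() matters. `ContextBuffer` is a hypothetical stand-in, not the class used in the repository, and two details are assumptions: the buffer is filled with copies of the first frame when it is empty, and each frame's features are pushed after its own context has been read.

```python
from collections import deque


class ContextBuffer:
    """Hypothetical model of a circular buffer of past-frame features."""

    def __init__(self, length: int = 5):
        self.length = length
        self.frames = None  # deque of past features, oldest first

    def clear_state(self):
        """Forget all stored context (start of a new sequence)."""
        self.frames = None

    def step(self, features):
        """Return the context for `features`, then push `features` into the buffer."""
        if self.frames is None:
            # Empty buffer: initialize it with copies of the first frame.
            self.frames = deque([features] * self.length, maxlen=self.length)
        context = list(self.frames)
        self.frames.append(features)  # the oldest entry drops out automatically
        return context


buf = ContextBuffer(length=3)
for name in ["a0", "a1", "a2"]:  # sequence A
    buf.step(name)
stale = buf.step("b0")  # no clear_state(): ['a0', 'a1', 'a2'] from sequence A
buf.clear_state()
fresh = buf.step("b0")  # after clear_state(): ['b0', 'b0', 'b0']
```

The `stale` result is exactly the faulty situation described in the note: the first frame of sequence B is interpreted using frames from sequence A.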

Example of unrolled operation:

model = model.unrolled()

# images is a batch of images: [B,3,H,W] tensor, where B is the batch size
# hist_images is a batch of context images: [B,T,3,H,W], where T is the number of context frames used by the network (default 5)
output = model({'image': images, 'hist_images': hist_images})
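One way to build `hist_images` from a stored sequence is sketched below in NumPy. The helper names `context_indices` and `make_unrolled_batch` are hypothetical, and the oldest-first ordering along T as well as the padding with copies of frame 0 (chosen to mirror the sequential buffer initialization) are assumptions rather than documented behavior of the training data loader.

```python
import numpy as np


def context_indices(i: int, T: int = 5) -> list:
    """Indices of the T frames preceding frame i, oldest first.

    Positions before the start of the sequence are replaced by frame 0,
    mirroring a buffer that starts from copies of the first frame.
    """
    return [max(0, j) for j in range(i - T, i)]


def make_unrolled_batch(frames: np.ndarray, targets: list, T: int = 5):
    """frames: [N, 3, H, W] -> (images [B, 3, H, W], hist_images [B, T, 3, H, W])."""
    idx = np.array([context_indices(i, T) for i in targets])  # [B, T]
    return frames[targets], frames[idx]


# Toy sequence of 10 frames with 3 channels and 4x4 pixels.
frames = np.arange(10 * 3 * 4 * 4, dtype=np.float32).reshape(10, 3, 4, 4)
images, hist_images = make_unrolled_batch(frames, targets=[0, 7])
```

Before calling the model, the arrays would be converted to tensors, for example with `torch.from_numpy`.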

Model inference

To run sequential WaSR-T inference on a sequence of image frames, use the predict_sequential.py script.

# export CUDA_VISIBLE_DEVICES=-1 # CPU only
export CUDA_VISIBLE_DEVICES=0 # GPU to use
python predict_sequential.py \
--sequence-dir examples/sequence \
--weights path/to/model/weights.pth \
--output-dir output/predictions

The script processes the images in the --sequence-dir directory in alphabetical order and saves the predictions as color-coded masks in the specified output directory.
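A color-coded mask is typically produced by taking the per-pixel argmax over the class scores and looking each class up in a color palette. The NumPy sketch below shows this. The class order (0 obstacles, 1 water, 2 sky), the RGB values and the function name `colorize` are assumptions for illustration; predict_sequential.py may use different colors.

```python
import numpy as np

# Hypothetical palette: class order and colors are illustrative assumptions.
PALETTE = np.array([
    [247, 195, 37],   # 0: obstacles
    [41, 167, 224],   # 1: water
    [90, 75, 164],    # 2: sky
], dtype=np.uint8)


def colorize(probs: np.ndarray) -> np.ndarray:
    """probs: [C, H, W] class scores -> [H, W, 3] uint8 color mask."""
    labels = probs.argmax(axis=0)  # predicted class index per pixel
    return PALETTE[labels]         # look up one RGB color per pixel
```

The resulting array can be written to disk with an image library, for example `PIL.Image.fromarray(mask).save(path)`.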

If you wish to run inference on a video file, first convert the file to a sequence of images. For example, using ffmpeg:

mkdir sequence_images
ffmpeg -i video.mp4 sequence_images/frame_%05d.jpg
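Because frames are processed in alphabetical order, file names need zero-padded frame numbers (as produced by `%05d` above); otherwise `frame_10` sorts before `frame_2`. The sketch below lists frames the way the text describes. `list_frames` and the accepted extensions are hypothetical, not taken from the script.

```python
import os
import tempfile


def list_frames(sequence_dir, exts=(".jpg", ".jpeg", ".png")):
    """Image file names in `sequence_dir`, sorted alphabetically."""
    return sorted(f for f in os.listdir(sequence_dir) if f.lower().endswith(exts))


# Without zero padding, alphabetical order breaks temporal order:
unpadded = sorted(["frame_1.jpg", "frame_2.jpg", "frame_10.jpg"])
# -> ['frame_1.jpg', 'frame_10.jpg', 'frame_2.jpg']

# With zero padding, alphabetical order equals temporal order:
with tempfile.TemporaryDirectory() as d:
    for name in ["frame_00010.jpg", "frame_00002.jpg", "frame_00001.jpg", "notes.txt"]:
        open(os.path.join(d, name), "w").close()  # create empty placeholder files
    demo = list_frames(d)
# demo -> ['frame_00001.jpg', 'frame_00002.jpg', 'frame_00010.jpg']
```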

Model weights

Currently available pretrained model weights. All models are evaluated on the MODS benchmark [4]. The table reports F1 scores overall (F1) and inside the danger zone (F1D); T is the number of context frames.

| Model | T | Training data | Resolution (W x H) | F1 (%) | F1D (%) | Weights |
|---|---|---|---|---|---|---|
| regular (RN101) | 5 | MaSTr1325 | 512 x 384 | 93.7 | 87.3 | link |
| regular (RN101) | 5 | MaSTr1478 | 512 x 384 | 94.4 | 93.6 | link |
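F1 is the harmonic mean of precision P = TP / (TP + FP) and recall R = TP / (TP + FN). The sketch below computes it from true-positive, false-positive and false-negative counts; how MODS obtains those counts (detection matching, the danger-zone definition) is specified by the benchmark [4] and is not reproduced here. F1D applies the same formula to counts restricted to the danger zone. The function name `f1_score` is our own.

```python
def f1_score(tp: int, fp: int, fn: int) -> float:
    """F1 = 2PR / (P + R), written equivalently as 2TP / (2TP + FP + FN)."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0
```

The table values are these scores expressed as percentages, e.g. an F1 of 0.937 is reported as 93.7.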

Model training

To train your own models, use the train.py script. For example, to reproduce the results of our experiments, follow these steps:

  1. Download and prepare the MaSTr1325 dataset (images and GT masks). Also download the context frames for the MaSTr1325 images (see the Data section below).
  2. Edit the dataset configuration files (configs/mastr1325_train.yaml, configs/mastr1325_val.yaml and configs/mastr153_all.yaml) so that they point to the correct dataset directories.
  3. Use the train.py script to train the network:
export CUDA_VISIBLE_DEVICES=0,1,2,3 # GPUs to use
python train.py \
--train-config configs/mastr1325_train.yaml \
--val-config configs/mastr1325_val.yaml \
--validation \
--model-name my_wasr \
--batch-size 2 \
--epochs 100

Note

Model training requires a large amount of GPU memory (>11 GB per GPU). If you use smaller GPUs, you can reduce the memory consumption by decreasing the number of backbone backpropagation steps (--backbone-grad-steps) or using a smaller context length (--hist-len).

Note

To reproduce training on MaSTr1478 use --additional-train-config configs/mastr153_all.yaml to specify the additional training examples.

Logging and model weights

A log dir with the specified model name will be created inside the output directory. Model checkpoints and training logs will be stored here. At the end of the training the model weights are also exported to a weights.pth file inside this directory.

Logged metrics (loss, validation accuracy, validation IoU) can be inspected using tensorboard.

tensorboard --logdir output/logs/model_name

Data

We extend the MaSTr1325 dataset by providing the context frames (5 preceding frames). We also extend the dataset with additional hard examples to form MaSTr1478.

  • MaSTr1325 context frames: link
  • MaSTr1478 extension data: link

Mobile WaSR-T

Contributed by @playertr

To enable inference on devices with limited memory and compute resources, a lightweight, reduced-resolution version of WaSR-T has been trained. The mobile WaSR-T runs at around 13 FPS on the Jetson Nano embedded platform. Follow the installation instructions for a setup that has been tested on the original 4 GB (pre-Orin) Jetson Nano developer kit.

To use or train the mobile version of WaSR-T, pass the --mobile and --size 256 192 arguments to the training and inference scripts. For example, to run inference with the provided mobile weights using the predict_sequential.py script:

python predict_sequential.py \
--sequence-dir examples/sequence \
--weights path/to/weights.pth \
--output-dir output/predictions \
--mobile \
--size 256 192

We also provide an example script for inference using a GStreamer pipeline. By modifying the pipeline, the live segmentation results can be sent to other destinations for processing.

python predict_gstreamer.py --weights path/to/weights.pth --fp16 --mobile --size 256 192

Mobile weights

Pretrained model weights for the mobile version of WaSR-T. Performance is reported on the MODS benchmark [4].

| Model | T | Training data | Resolution (W x H) | F1 (%) | F1D (%) | Weights |
|---|---|---|---|---|---|---|
| mobile | 5 | MaSTr1325 | 256 x 192 | 84.4 | 70.3 | link |
| mobile | 5 | MaSTr1478 | 256 x 192 | 82.2 | 69.7 | link |

Citation

If you use this code, please cite our paper:

@InProceedings{Zust2022Temporal,
  title={Temporal Context for Robust Maritime Obstacle Detection},
  author={{\v{Z}}ust, Lojze and Kristan, Matej},
  booktitle={2022 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
  year={2022}
}

References

[1] Žust, L., & Kristan, M. (2022). Temporal Context for Robust Maritime Obstacle Detection. 2022 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS).

[2] Bovcon, B., Muhovič, J., Perš, J., & Kristan, M. (2019). The MaSTr1325 dataset for training deep USV obstacle detection models. 2019 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS).

[3] Bovcon, B., & Kristan, M. (2021). WaSR--A Water Segmentation and Refinement Maritime Obstacle Detection Network. IEEE Transactions on Cybernetics.

[4] Bovcon, B., Muhovič, J., Vranac, D., Mozetič, D., Perš, J., & Kristan, M. (2021). MODS -- A USV-oriented object detection and obstacle segmentation benchmark.