bwconrad/decoder-denoising

Decoder Denoising Pretraining for Semantic Segmentation

PyTorch reimplementation of "Decoder Denoising Pretraining for Semantic Segmentation".

Requirements

  • Python 3.8+
  • pip install -r requirements

Usage

To perform decoder denoising pretraining on a U-Net with a ResNet-50 encoder, run:

python train.py --gpus 1 --max_epochs 100 --data.root path/to/data/ --model.arch unet --model.encoder resnet50 
  • --model.arch can be one of unet, unetplusplus, manet, linknet, fpn, pspnet, deeplabv3, deeplabv3plus, pan.
  • --model.encoder can be any encoder supported by segmentation_models_pytorch (see that library's list of available encoders).
  • configs/ contains example configuration files which can be run with python train.py --config path/to/config.
  • Run python train.py --help to get descriptions for all the options.
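The README does not spell out the training objective. As a rough sketch of the idea in the cited paper (not necessarily what train.py does), the network receives an image corrupted with Gaussian noise and is trained to predict that noise, so its output has as many channels as the input image. The function name `denoising_loss`, the noise level `sigma=0.2`, and the tiny stand-in network are all assumptions made for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def denoising_loss(model: nn.Module, images: torch.Tensor, sigma: float = 0.2) -> torch.Tensor:
    """MSE between predicted and actual noise (assumed objective, hypothetical helper)."""
    noise = torch.randn_like(images)   # Gaussian noise with the same shape as the batch
    noisy = images + sigma * noise     # corrupt the clean images
    pred = model(noisy)                # the network tries to recover the noise
    return F.mse_loss(pred, noise)


# Stand-in for a segmentation network with 3 output channels (illustration only)
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 3, kernel_size=3, padding=1),
)

images = torch.rand(2, 3, 32, 32)      # fake batch of RGB images in [0, 1]
loss = denoising_loss(toy_model, images)
loss.backward()                        # gradients flow to every layer of the stand-in
```

Under this reading, the pretraining output channels equal the image channels, which would explain why the loading example further down uses classes=3.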

Using a Pretrained Model

Model weights can be extracted from a pretraining checkpoint file by running:

python scripts/extract_model_weights.py -c path/to/checkpoint/file
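The internals of the script are not shown here. As a minimal sketch of what this kind of extraction usually involves, the example below assumes a PyTorch Lightning style checkpoint: a dict with a `state_dict` entry whose model keys carry a prefix. The prefix `net.`, the helper `strip_prefix`, and the toy checkpoint contents are hypothetical, and strings stand in for tensors so the snippet runs without any dependencies.

```python
from typing import Any, Dict


def strip_prefix(state_dict: Dict[str, Any], prefix: str = "net.") -> Dict[str, Any]:
    """Keep only the entries under `prefix` and drop the prefix from their keys."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Toy stand-in for a training checkpoint (assumed layout, strings instead of tensors)
checkpoint = {
    "epoch": 99,
    "state_dict": {
        "net.encoder.conv1.weight": "tensor-a",
        "net.decoder.blocks.0.conv1.0.weight": "tensor-b",
        "loss_fn.weight": "tensor-c",  # not part of the model, so it is dropped
    },
}

weights = strip_prefix(checkpoint["state_dict"])
print(sorted(weights))
# ['decoder.blocks.0.conv1.0.weight', 'encoder.conv1.weight']
```

With a real checkpoint, the resulting dict would be saved with torch.save so that its keys match what model.load_state_dict expects.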

You can then initialize a segmentation model from these weights as follows (example for a U-Net with a ResNet-50 encoder):

import segmentation_models_pytorch as smp
import torch
import torch.nn as nn

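# Load the extracted weights onto the CPU, regardless of the device they were saved from
weights = torch.load("weights.pt", map_location="cpu")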

model = smp.create_model(
    "unet",
    encoder_name="resnet50",
    in_channels=3,
    classes=3,  # Must match the number of output channels used during pretraining
    encoder_weights=None,
)

model.load_state_dict(weights, strict=True)

# Replace segmentation head for fine-tuning
in_channels = model.segmentation_head[0].in_channels
num_classes = 10
model.segmentation_head[0] = nn.Conv2d(in_channels, num_classes, kernel_size=3, padding=1)

Citation

@inproceedings{brempong2022denoising,
  title={Denoising Pretraining for Semantic Segmentation},
  author={Brempong, Emmanuel Asiedu and Kornblith, Simon and Chen, Ting and Parmar, Niki and Minderer, Matthias and Norouzi, Mohammad},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={4175--4186},
  year={2022}
}
