Skip to content

Commit

Permalink
add segformer
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist committed Jan 24, 2024
1 parent fd7c856 commit 0418879
Show file tree
Hide file tree
Showing 4 changed files with 671 additions and 0 deletions.
9 changes: 9 additions & 0 deletions candle-examples/examples/segformer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# candle-segformer

- [HuggingFace Segformer Model Card][segformer]
- [`mit-b0` - An encoder only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512]

[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
44 changes: 44 additions & 0 deletions candle-examples/examples/segformer/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use std::path::PathBuf;

use candle::Module;
use candle_nn::VarBuilder;
use candle_transformers::models::segformer;
use clap::Parser;

#[derive(Parser)]
#[clap(about, version, long_about = None)]
struct Args {
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
)]
model_name: String,
#[arg(long, help = "path to image as input")]
image: PathBuf,
#[arg(long, help = "use cpu")]
cpu: bool,
}

pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(args.model_name);
let model_file = api.get("model.safetensors")?;
println!("model downloaded and loaded");
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, &device)? };
let config = Default::default();
let num_labels = 150;
let model = segformer::SemanticSegmentationModel::new(&config, num_labels, vb)?;
let input = image.unsqueeze(0)?;
let segmentations = model.forward(&input)?;
println!(
"segmentation result shape {:?} which should match [1, num_labels, height/4, width/4]",
segmentations.shape()
);
Ok(())
}
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod quantized_stable_lm;
pub mod quantized_t5;
pub mod repvgg;
pub mod resnet;
pub mod segformer;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
Expand Down
Loading

0 comments on commit 0418879

Please sign in to comment.