add segformer
jimexist committed Feb 3, 2024
1 parent 9e824ec commit 513ae54
Showing 4 changed files with 848 additions and 0 deletions.
18 changes: 18 additions & 0 deletions candle-examples/examples/segformer/README.md
@@ -0,0 +1,18 @@
# candle-segformer

- [Hugging Face SegFormer model card][segformer]
- [`mit-b0` - An encoder-only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A model fine-tuned for semantic segmentation on ADE20K][ade512]

[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512

## How to run

```bash
# run image classification task
cargo run --example segformer -- classify <image_file>
# run semantic segmentation task
cargo run --example segformer -- segment <image_file>
```
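
Judging from the clap definitions in `main.rs`, both subcommands also accept a `--model-name` flag for picking a different Hugging Face checkpoint, and a `--cpu` flag (placed before the subcommand) forces CPU execution; a quick sketch:

```bash
# pick a checkpoint explicitly (shown here with the segmentation default)
cargo run --example segformer -- segment --model-name nvidia/segformer-b0-finetuned-ade-512-512 <image_file>
# run the classification task on the CPU
cargo run --example segformer -- --cpu classify <image_file>
```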
101 changes: 101 additions & 0 deletions candle-examples/examples/segformer/main.rs
@@ -0,0 +1,101 @@
use candle::Device;
use candle::Module;
use candle_nn::VarBuilder;
use candle_transformers::models::segformer::{
    Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use std::path::PathBuf;

#[derive(Parser)]
#[clap(about, version, long_about = None)]
struct CliArgs {
    #[arg(long, help = "use cpu")]
    cpu: bool,
    #[command(subcommand)]
    command: Commands,
}
#[derive(Args, Debug)]
struct SegmentationArgs {
    #[arg(
        long,
        help = "name of the huggingface hub model",
        default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
    )]
    model_name: String,
    #[arg(help = "path to image as input")]
    image: PathBuf,
}

#[derive(Args, Debug)]
struct ClassificationArgs {
    #[arg(
        long,
        help = "name of the huggingface hub model",
        default_value = "paolinox/segformer-finetuned-food101"
    )]
    model_name: String,
    #[arg(help = "path to image as input")]
    image: PathBuf,
}

#[derive(Subcommand, Debug)]
enum Commands {
    Segment(SegmentationArgs),
    Classify(ClassificationArgs),
}

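/// Downloads `model.safetensors` and `config.json` for the given model from the
/// Hugging Face Hub, returning a `VarBuilder` over the weights and the parsed `Config`.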
fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
println!("loading model {} via huggingface hub", model_name);
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name.clone());
let model_file = api.get("model.safetensors")?;
println!("model {} downloaded and loaded", model_name);
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
let config = std::fs::read_to_string(api.get("config.json")?)?;
let config: Config = serde_json::from_str(&config)?;
println!("{:?}", config);
Ok((vb, config))
}

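/// Runs semantic segmentation on a single 224x224 image; the logits come out at
/// 1/4 of the input resolution, i.e. with shape [1, num_labels, height/4, width/4].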
fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
    let image = candle_examples::imagenet::load_image224(args.image)?
        .unsqueeze(0)?
        .to_device(device)?;
    let (vb, config) = get_vb_and_config(args.model_name, device)?;
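    // the ADE20K-finetuned checkpoints predict 150 semantic classes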
    let num_labels = 150;
    let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
    let segmentations = model.forward(&image)?;
    println!(
        "segmentation result shape {:?} which should match [1, num_labels, height/4, width/4]",
        segmentations.shape()
    );
    // let image = segmentations.squeeze(0)?.i(2)?.to_vec2::<f32>()?;
    // println!("image {:?}", image);
    Ok(())
}

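/// Runs image classification and prints the softmax probabilities over the class labels.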
fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
    let image = candle_examples::imagenet::load_image224(args.image)?
        .unsqueeze(0)?
        .to_device(device)?;
    let (vb, config) = get_vb_and_config(args.model_name, device)?;
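    // num_labels is hard-coded here and has to match the classifier head of the chosen checkpoint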
    let num_labels = 7;
    let model = ImageClassificationModel::new(&config, num_labels, vb)?;
    let classification = model.forward(&image)?;
    let classification = candle_nn::ops::softmax_last_dim(&classification)?;
    println!("classification {:?}", classification.to_vec2::<f32>()?);
    Ok(())
}

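/// Entry point: selects the device (CPU if `--cpu` is set, otherwise an available
/// accelerator) and dispatches to the requested subcommand.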
pub fn main() -> anyhow::Result<()> {
    let args = CliArgs::parse();
    let device = candle_examples::device(args.cpu)?;
    if let Commands::Segment(args) = args.command {
        segmentation_task(args, &device)?
    } else if let Commands::Classify(args) = args.command {
        classification_task(args, &device)?
    }
    Ok(())
}
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
@@ -30,6 +30,7 @@ pub mod quantized_stable_lm;
pub mod quantized_t5;
pub mod repvgg;
pub mod resnet;
pub mod segformer;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
728 changes: 728 additions & 0 deletions candle-transformers/src/models/segformer.rs
Large diffs are not rendered by default.
