Skip to content

Commit

Permalink
add segformer
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist committed Mar 1, 2024
1 parent 6980774 commit cf33bf2
Show file tree
Hide file tree
Showing 5 changed files with 838 additions and 0 deletions.
23 changes: 23 additions & 0 deletions candle-examples/examples/segformer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# candle-segformer

- [HuggingFace Segformer Model Card][segformer]
- [`mit-b0` - An encoder only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A fine tuned model for segmentation][ade512]

[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512

## How to run

```bash
# run image classification task
cargo run --example segformer classify candle-examples/examples/segformer/assets/burger.jpg
```

Example output:

```text
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
label: hamburger
```
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
109 changes: 109 additions & 0 deletions candle-examples/examples/segformer/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
use candle::Device;
use candle::Module;
use candle_nn::VarBuilder;
use candle_transformers::models::segformer::{
Config, ImageClassificationModel, SemanticSegmentationModel,
};
use clap::{Args, Parser, Subcommand};
use std::path::PathBuf;

#[derive(Parser)]
#[clap(about, version, long_about = None)]
struct CliArgs {
#[arg(long, help = "use cpu")]
cpu: bool,
#[command(subcommand)]
command: Commands,
}
#[derive(Args, Debug)]
struct SegmentationArgs {
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
)]
model_name: String,
#[arg(help = "path to image as input")]
image: PathBuf,
}

#[derive(Args, Debug)]
struct ClassificationArgs {
#[arg(
long,
help = "name of the huggingface hub model",
default_value = "paolinox/segformer-finetuned-food101"
)]
model_name: String,
#[arg(help = "path to image as input")]
image: PathBuf,
}

#[derive(Subcommand, Debug)]
enum Commands {
Segment(SegmentationArgs),
Classify(ClassificationArgs),
}

fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
println!("loading model {} via huggingface hub", model_name);
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(model_name.clone());
let model_file = api.get("model.safetensors")?;
println!("model {} downloaded and loaded", model_name);
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
let config = std::fs::read_to_string(api.get("config.json")?)?;
let config: Config = serde_json::from_str(&config)?;
println!("{:?}", config);
Ok((vb, config))
}

fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
let image = candle_examples::imagenet::load_image224(args.image)?
.unsqueeze(0)?
.to_device(device)?;
let (vb, config) = get_vb_and_config(args.model_name, device)?;
let num_labels = 150;
let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
let segmentations = model.forward(&image)?;
println!(
"segmentation result shape {:?} which should match [1, num_labels, height/4, width/4]",
segmentations.shape()
);
let labels = segmentations.squeeze(0)?.argmax(0)?;
let labels = labels.to_vec2::<u32>()?;
println!("labels {:?}", labels);
Ok(())
}

fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
let image = candle_examples::imagenet::load_image224(args.image)?
.unsqueeze(0)?
.to_device(device)?;
let (vb, config) = get_vb_and_config(args.model_name, device)?;
let num_labels = 7;
let model = ImageClassificationModel::new(&config, num_labels, vb)?;
let classification = model.forward(&image)?;
let classification = candle_nn::ops::softmax_last_dim(&classification)?;
let classification = classification.squeeze(0)?;
println!(
"classification logits {:?}",
classification.to_vec1::<f32>()?
);
let label_id = classification.argmax(0)?.to_scalar::<u32>()?;
let label_id = format!("{}", label_id);
println!("label: {}", config.id2label[&label_id]);
Ok(())
}

pub fn main() -> anyhow::Result<()> {
let args = CliArgs::parse();
let device = candle_examples::device(args.cpu)?;
if let Commands::Segment(args) = args.command {
segmentation_task(args, &device)?
} else if let Commands::Classify(args) = args.command {
classification_task(args, &device)?
}
Ok(())
}
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub mod repvgg;
pub mod resnet;
pub mod rwkv_v5;
pub mod rwkv_v6;
pub mod segformer;
pub mod segment_anything;
pub mod stable_diffusion;
pub mod stable_lm;
Expand Down
Loading

0 comments on commit cf33bf2

Please sign in to comment.