diff --git a/candle-examples/examples/segformer/README.md b/candle-examples/examples/segformer/README.md
new file mode 100644
index 000000000..3ea503ee2
--- /dev/null
+++ b/candle-examples/examples/segformer/README.md
@@ -0,0 +1,28 @@
+# candle-segformer
+
+- [HuggingFace Segformer Model Card][segformer]
+- [`mit-b0` - An encoder-only pretrained model][encoder]
+- [`segformer-b0-finetuned-ade-512-512` - A fine-tuned model for segmentation][ade512]
+
+## How to run the example
+
+You can use the example images from this [pull request][pr]: download them and supply the path to an image as the final argument of either command below.
+
+```bash
+# run the image classification task
+cargo run --example segformer classify <path-to-image>
+# run the segmentation task
+cargo run --example segformer segment --output-path <path-to-output-mask> <path-to-image>
+```
+
+Example output for classification:
+
+```text
+classification probabilities [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
+label: hamburger
+```
+
+[pr]: https://github.com/huggingface/candle/pull/1617
+[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
+[encoder]: https://huggingface.co/nvidia/mit-b0
+[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
diff --git a/candle-examples/examples/segformer/assets/labels.json b/candle-examples/examples/segformer/assets/labels.json
new file mode 100644
index 000000000..645d21526
--- /dev/null
+++ b/candle-examples/examples/segformer/assets/labels.json
@@ -0,0 +1,752 @@
+[ + { + "index": 1, + "color": "#787878", + "label": "wall" + }, + { + "index": 2, + "color": "#B47878", + "label": "building;edifice" + }, + { + "index": 3, + "color": "#06E6E6", + "label": "sky" + }, + { + "index": 4, + "color": "#503232", + "label": "floor;flooring" + }, + { + "index": 5, + "color": "#04C803", + "label": "tree" + }, + { + "index": 6, + "color": "#787850", + "label": "ceiling" + }, + { + "index": 7, + "color": "#8C8C8C", + "label": "road;route" + }, + { + "index": 8, + "color": "#CC05FF", + "label": "bed" + }, + { + "index": 9, + "color": "#E6E6E6", + "label": "windowpane;window" + }, + { + "index": 10, + "color": "#04FA07", + "label": "grass" + }, + { + "index": 11, + "color": "#E005FF", + "label": "cabinet" + }, + { + "index": 12, + "color": "#EBFF07", + "label": "sidewalk;pavement" + }, + { + "index": 13, + "color": "#96053D", + "label": "person;individual;someone;somebody;mortal;soul" + }, + { + "index": 14, + "color": "#787846", + "label": "earth;ground" + }, + { + "index": 15, + "color": "#08FF33", + "label": "door;double;door" + }, + { + "index": 16, + "color": "#FF0652", + "label": "table" + }, + { + "index": 17, + "color": "#8FFF8C", + "label": "mountain;mount" + }, + { + "index": 18, + "color": "#CCFF04", + "label": "plant;flora;plant;life" + }, + { + "index": 19, + "color": "#FF3307", + "label": "curtain;drape;drapery;mantle;pall" + }, + { + "index": 20, + "color": "#CC4603", + "label": "chair" + }, + { + "index": 21, + "color": "#0066C8", + "label": "car;auto;automobile;machine;motorcar" + }, + { + "index": 22, + "color": "#3DE6FA", + "label": "water" + }, + { + "index": 23, + "color": "#FF0633", + "label": "painting;picture" + }, + { + "index": 24, + "color": "#0B66FF", + "label": "sofa;couch;lounge" + }, + { + "index": 25, + "color": "#FF0747", + "label": "shelf" + }, + { + "index": 26, + "color": "#FF09E0", + "label": "house" + }, + { + "index": 27, + "color": "#0907E6", + "label": "sea" + }, + { + "index": 28, + "color": "#DCDCDC", + "label": "mirror" + }, + { + "index": 29,
"color": "#FF095C", + "label": "rug;carpet;carpeting" + }, + { + "index": 30, + "color": "#7009FF", + "label": "field" + }, + { + "index": 31, + "color": "#08FFD6", + "label": "armchair" + }, + { + "index": 32, + "color": "#07FFE0", + "label": "seat" + }, + { + "index": 33, + "color": "#FFB806", + "label": "fence;fencing" + }, + { + "index": 34, + "color": "#0AFF47", + "label": "desk" + }, + { + "index": 35, + "color": "#FF290A", + "label": "rock;stone" + }, + { + "index": 36, + "color": "#07FFFF", + "label": "wardrobe;closet;press" + }, + { + "index": 37, + "color": "#E0FF08", + "label": "lamp" + }, + { + "index": 38, + "color": "#6608FF", + "label": "bathtub;bathing;tub;bath;tub" + }, + { + "index": 39, + "color": "#FF3D06", + "label": "railing;rail" + }, + { + "index": 40, + "color": "#FFC207", + "label": "cushion" + }, + { + "index": 41, + "color": "#FF7A08", + "label": "base;pedestal;stand" + }, + { + "index": 42, + "color": "#00FF14", + "label": "box" + }, + { + "index": 43, + "color": "#FF0829", + "label": "column;pillar" + }, + { + "index": 44, + "color": "#FF0599", + "label": "signboard;sign" + }, + { + "index": 45, + "color": "#0633FF", + "label": "chest;of;drawers;chest;bureau;dresser" + }, + { + "index": 46, + "color": "#EB0CFF", + "label": "counter" + }, + { + "index": 47, + "color": "#A09614", + "label": "sand" + }, + { + "index": 48, + "color": "#00A3FF", + "label": "sink" + }, + { + "index": 49, + "color": "#8C8C8C", + "label": "skyscraper" + }, + { + "index": 50, + "color": "#FA0A0F", + "label": "fireplace;hearth;open;fireplace" + }, + { + "index": 51, + "color": "#14FF00", + "label": "refrigerator;icebox" + }, + { + "index": 52, + "color": "#1FFF00", + "label": "grandstand;covered;stand" + }, + { + "index": 53, + "color": "#FF1F00", + "label": "path" + }, + { + "index": 54, + "color": "#FFE000", + "label": "stairs;steps" + }, + { + "index": 55, + "color": "#99FF00", + "label": "runway" + }, + { + "index": 56, + "color": "#0000FF", + "label": "case;display;case;showcase;vitrine" + }, + { + "index": 57, + "color": "#FF4700", + "label": "pool;table;billiard;table;snooker;table" + }, + { + "index": 58, + "color": "#00EBFF", + "label": "pillow" + }, + { + "index": 59, + "color": "#00ADFF", + "label": "screen;door;screen" + }, + { + "index": 60, + "color": "#1F00FF", + "label": "stairway;staircase" + }, + { + "index": 61, + "color": "#0BC8C8", + "label": "river" + }, + { + "index": 62, + "color": "#FF5200", + "label": "bridge;span" + }, + { + "index": 63, + "color": "#00FFF5", + "label": "bookcase" + }, + { + "index": 64, + "color": "#003DFF", + "label": "blind;screen" + }, + { + "index": 65, + "color": "#00FF70", + "label": "coffee;table;cocktail;table" + }, + { + "index": 66, + "color": "#00FF85", + "label": "toilet;can;commode;crapper;pot;potty;stool;throne" + }, + { + "index": 67, + "color": "#FF0000", + "label": "flower" + }, + { + "index": 68, + "color": "#FFA300", + "label": "book" + }, + { + "index": 69, + "color": "#FF6600", + "label": "hill" + }, + { + "index": 70, + "color": "#C2FF00", + "label": "bench" + }, + { + "index": 71, + "color": "#008FFF", + "label": "countertop" + }, + { + "index": 72, + "color": "#33FF00", + "label": "stove;kitchen;stove;range;kitchen;range;cooking;stove" + }, + { + "index": 73, + "color": "#0052FF", + "label": "palm;palm;tree" + }, + { + "index": 74, + "color": "#00FF29", + "label": "kitchen;island" + }, + { + "index": 75, + "color": "#00FFAD", + "label": 
"computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system" + }, + { + "index": 76, + "color": "#0A00FF", + "label": "swivel;chair" + }, + { + "index": 77, + "color": "#ADFF00", + "label": "boat" + }, + { + "index": 78, + "color": "#00FF99", + "label": "bar" + }, + { + "index": 79, + "color": "#FF5C00", + "label": "arcade;machine" + }, + { + "index": 80, + "color": "#FF00FF", + "label": "hovel;hut;hutch;shack;shanty" + }, + { + "index": 81, + "color": "#FF00F5", + "label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle" + }, + { + "index": 82, + "color": "#FF0066", + "label": "towel" + }, + { + "index": 83, + "color": "#FFAD00", + "label": "light;light;source" + }, + { + "index": 84, + "color": "#FF0014", + "label": "truck;motortruck" + }, + { + "index": 85, + "color": "#FFB8B8", + "label": "tower" + }, + { + "index": 86, + "color": "#001FFF", + "label": "chandelier;pendant;pendent" + }, + { + "index": 87, + "color": "#00FF3D", + "label": "awning;sunshade;sunblind" + }, + { + "index": 88, + "color": "#0047FF", + "label": "streetlight;street;lamp" + }, + { + "index": 89, + "color": "#FF00CC", + "label": "booth;cubicle;stall;kiosk" + }, + { + "index": 90, + "color": "#00FFC2", + "label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box" + }, + { + "index": 91, + "color": "#00FF52", + "label": "airplane;aeroplane;plane" + }, + { + "index": 92, + "color": "#000AFF", + "label": "dirt;track" + }, + { + "index": 93, + "color": "#0070FF", + "label": "apparel;wearing;apparel;dress;clothes" + }, + { + "index": 94, + "color": "#3300FF", + "label": "pole" + }, + { + "index": 95, + "color": "#00C2FF", + "label": "land;ground;soil" + }, + { + "index": 96, + "color": "#007AFF", + "label": "bannister;banister;balustrade;balusters;handrail" + }, + { + "index": 97, + "color": "#00FFA3", + "label": "escalator;moving;staircase;moving;stairway" + }, + { + "index": 98, + "color": "#FF9900", + "label": "ottoman;pouf;pouffe;puff;hassock" + }, + { + "index": 99, + "color": "#00FF0A", + "label": "bottle" + }, + { + "index": 100, + "color": "#FF7000", + "label": "buffet;counter;sideboard" + }, + { + "index": 101, + "color": "#8FFF00", + "label": "poster;posting;placard;notice;bill;card" + }, + { + "index": 102, + "color": "#5200FF", + "label": "stage" + }, + { + "index": 103, + "color": "#A3FF00", + "label": "van" + }, + { + "index": 104, + "color": "#FFEB00", + "label": "ship" + }, + { + "index": 105, + "color": "#08B8AA", + "label": "fountain" + }, + { + "index": 106, + "color": "#8500FF", + "label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter" + }, + { + "index": 107, + "color": "#00FF5C", + "label": "canopy" + }, + { + "index": 108, + "color": "#B800FF", + "label": "washer;automatic;washer;washing;machine" + }, + { + "index": 109, + "color": "#FF001F", + "label": "plaything;toy" + }, + { + "index": 110, + "color": "#00B8FF", + "label": "swimming;pool;swimming;bath;natatorium" + }, + { + "index": 111, + "color": "#00D6FF", + "label": "stool" + }, + { + "index": 112, + "color": "#FF0070", + "label": "barrel;cask" + }, + { + "index": 113, + "color": "#5CFF00", + "label": "basket;handbasket" + }, + { + "index": 114, + "color": "#00E0FF", + "label": "waterfall;falls" + }, + { + "index": 115, + "color": "#70E0FF", + "label": "tent;collapsible;shelter" + }, + { + "index": 116, + "color": "#46B8A0", + "label": "bag" + }, + { + "index": 117, + "color": "#A300FF", 
+ "label": "minibike;motorbike" + }, + { + "index": 118, + "color": "#9900FF", + "label": "cradle" + }, + { + "index": 119, + "color": "#47FF00", + "label": "oven" + }, + { + "index": 120, + "color": "#FF00A3", + "label": "ball" + }, + { + "index": 121, + "color": "#FFCC00", + "label": "food;solid;food" + }, + { + "index": 122, + "color": "#FF008F", + "label": "step;stair" + }, + { + "index": 123, + "color": "#00FFEB", + "label": "tank;storage;tank" + }, + { + "index": 124, + "color": "#85FF00", + "label": "trade;name;brand;name;brand;marque" + }, + { + "index": 125, + "color": "#FF00EB", + "label": "microwave;microwave;oven" + }, + { + "index": 126, + "color": "#F500FF", + "label": "pot;flowerpot" + }, + { + "index": 127, + "color": "#FF007A", + "label": "animal;animate;being;beast;brute;creature;fauna" + }, + { + "index": 128, + "color": "#FFF500", + "label": "bicycle;bike;wheel;cycle" + }, + { + "index": 129, + "color": "#0ABED4", + "label": "lake" + }, + { + "index": 130, + "color": "#D6FF00", + "label": "dishwasher;dish;washer;dishwashing;machine" + }, + { + "index": 131, + "color": "#00CCFF", + "label": "screen;silver;screen;projection;screen" + }, + { + "index": 132, + "color": "#1400FF", + "label": "blanket;cover" + }, + { + "index": 133, + "color": "#FFFF00", + "label": "sculpture" + }, + { + "index": 134, + "color": "#0099FF", + "label": "hood;exhaust;hood" + }, + { + "index": 135, + "color": "#0029FF", + "label": "sconce" + }, + { + "index": 136, + "color": "#00FFCC", + "label": "vase" + }, + { + "index": 137, + "color": "#2900FF", + "label": "traffic;light;traffic;signal;stoplight" + }, + { + "index": 138, + "color": "#29FF00", + "label": "tray" + }, + { + "index": 139, + "color": "#AD00FF", + "label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin" + }, + { + "index": 140, + "color": "#00F5FF", + "label": "fan" + }, + { + "index": 141, + "color": "#4700FF", + "label": "pier;wharf;wharfage;dock" + }, + { + "index": 142, + "color": "#7A00FF", + "label": "crt;screen" + }, + { + "index": 143, + "color": "#00FFB8", + "label": "plate" + }, + { + "index": 144, + "color": "#005CFF", + "label": "monitor;monitoring;device" + }, + { + "index": 145, + "color": "#B8FF00", + "label": "bulletin;board;notice;board" + }, + { + "index": 146, + "color": "#0085FF", + "label": "shower" + }, + { + "index": 147, + "color": "#FFD600", + "label": "radiator" + }, + { + "index": 148, + "color": "#19C2C2", + "label": "glass;drinking;glass" + }, + { + "index": 149, + "color": "#66FF00", + "label": "clock" + }, + { + "index": 150, + "color": "#5C00FF", + "label": "flag" + } +]
diff --git a/candle-examples/examples/segformer/main.rs b/candle-examples/examples/segformer/main.rs
new file mode 100644
index 000000000..76c9f30e3
--- /dev/null
+++ b/candle-examples/examples/segformer/main.rs
@@ -0,0 +1,155 @@
+use candle::Device;
+use candle::Module;
+use candle_nn::VarBuilder;
+use candle_transformers::models::segformer::{
+    Config, ImageClassificationModel, SemanticSegmentationModel,
+};
+use clap::{Args, Parser, Subcommand};
+use image::Rgb;
+use imageproc::integral_image::ArrayData;
+use std::collections::HashMap;
+use std::path::PathBuf;
+
+#[derive(Parser)]
+#[clap(about, version, long_about = None)]
+struct CliArgs {
+    #[arg(long, help = "use cpu")]
+    cpu: bool,
+    #[command(subcommand)]
+    command: Commands,
+}
+#[derive(Args, Debug)]
+struct SegmentationArgs {
+    #[arg(
+        long,
+        help = "name of the huggingface hub model",
+        default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
+    )]
+    model_name: String,
+    #[arg(
+        long,
+        help = "path to the label file in json format",
+        default_value = "candle-examples/examples/segformer/assets/labels.json"
+    )]
+    label_path: PathBuf,
+    #[arg(long, help = "path for the output mask image")]
+    output_path: PathBuf,
+    #[arg(help = "path to image as input")]
+    image: PathBuf,
+}
+
+#[derive(Args, Debug)]
+struct ClassificationArgs {
+    #[arg(
+        long,
+        help = "name of the huggingface hub model",
+        default_value = "paolinox/segformer-finetuned-food101"
+    )]
+    model_name: String,
+    #[arg(help = "path to image as input")]
+    image: PathBuf,
+}
+
+#[derive(Subcommand, Debug)]
+enum Commands {
+    Segment(SegmentationArgs),
+    Classify(ClassificationArgs),
+}
+
+fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
+    println!("loading model {} via huggingface hub", model_name);
+    let api = hf_hub::api::sync::Api::new()?;
+    let api = api.model(model_name.clone());
+    let model_file = api.get("model.safetensors")?;
+    println!("model {} downloaded", model_name);
+    let vb =
+        unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
+    let config = std::fs::read_to_string(api.get("config.json")?)?;
+    let config: Config = serde_json::from_str(&config)?;
+    println!("{:?}", config);
+    Ok((vb, config))
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct LabelItem {
+    index: u32,
+    color: String,
+}
+
+fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
+    let label_file = std::fs::read_to_string(&args.label_path)?;
+    let label_items: Vec<LabelItem> = serde_json::from_str(&label_file)?;
+    // the label indices in the json file are 1-based while the model's class ids are 0-based
+    let label_colors: HashMap<u32, Rgb<u8>> = label_items
+        .iter()
+        .map(|x| {
+            (x.index - 1, {
+                let color = x.color.trim_start_matches('#');
+                let r = u8::from_str_radix(&color[0..2], 16).unwrap();
+                let g = u8::from_str_radix(&color[2..4], 16).unwrap();
+                let b = u8::from_str_radix(&color[4..6], 16).unwrap();
+                Rgb([r, g, b])
+            })
+        })
+        .collect();
+
+    let image = candle_examples::imagenet::load_image224(args.image)?
+        .unsqueeze(0)?
+        .to_device(device)?;
+    let (vb, config) = get_vb_and_config(args.model_name, device)?;
+    let num_labels = label_items.len();
+
+    let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
+    let segmentations = model.forward(&image)?;
+
+    // generate a mask image: per-pixel argmax over the class logits, mapped to label colors
+    let mask = &segmentations.squeeze(0)?.argmax(0)?;
+    let (h, w) = mask.dims2()?;
+    let mask = mask.flatten_all()?.to_vec1::<u32>()?;
+    let mask = mask
+        .iter()
+        .flat_map(|x| label_colors[x].data())
+        .collect::<Vec<u8>>();
+    let mask: image::ImageBuffer<Rgb<u8>, Vec<u8>> =
+        image::ImageBuffer::from_raw(w as u32, h as u32, mask).unwrap();
+    // the mask comes out at 1/4 of the input resolution, upscale it 4x
+    let mask = image::DynamicImage::from(mask);
+    let mask = mask.resize_to_fill(
+        w as u32 * 4,
+        h as u32 * 4,
+        image::imageops::FilterType::CatmullRom,
+    );
+    mask.save(args.output_path.clone())?;
+    println!("mask image saved to {:?}", args.output_path);
+    Ok(())
+}
+
+fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
+    // load_image224 resizes the input to 224x224 and applies the ImageNet
+    // mean/std normalization the checkpoint expects; unsqueeze adds the batch dim
+    let image = candle_examples::imagenet::load_image224(args.image)?
+        .unsqueeze(0)?
+        .to_device(device)?;
+    let (vb, config) = get_vb_and_config(args.model_name, device)?;
+    let num_labels = 7; // the default checkpoint predicts 7 food classes
+    let model = ImageClassificationModel::new(&config, num_labels, vb)?;
+    let classification = model.forward(&image)?;
+    let classification = candle_nn::ops::softmax_last_dim(&classification)?;
+    let classification = classification.squeeze(0)?;
+    println!(
+        "classification probabilities {:?}",
+        classification.to_vec1::<f32>()?
+    );
+    let label_id = classification.argmax(0)?.to_scalar::<u32>()?;
+    let label_id = format!("{}", label_id);
+    println!("label: {}", config.id2label[&label_id]);
+    Ok(())
+}
+
+pub fn main() -> anyhow::Result<()> {
+    let args = CliArgs::parse();
+    let device = candle_examples::device(args.cpu)?;
+    if let Commands::Segment(args) = args.command {
+        segmentation_task(args, &device)?
+    } else if let Commands::Classify(args) = args.command {
+        classification_task(args, &device)?
+    }
+    Ok(())
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index a5f03059b..6833bab0f 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -41,6 +41,7 @@ pub mod repvgg;
 pub mod resnet;
 pub mod rwkv_v5;
 pub mod rwkv_v6;
+pub mod segformer;
 pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod stable_lm;
diff --git a/candle-transformers/src/models/segformer.rs b/candle-transformers/src/models/segformer.rs
new file mode 100644
index 000000000..3727e0042
--- /dev/null
+++ b/candle-transformers/src/models/segformer.rs
@@ -0,0 +1,705 @@
+use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
+use candle::{Module, ModuleT, Result, Tensor, D};
+use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, VarBuilder};
+use serde::Deserialize;
+use std::collections::HashMap;
+
+// https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/configuration_segformer.py
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+    #[serde(default)]
+    pub id2label: HashMap<String, String>,
+    pub num_channels: usize,
+    pub num_encoder_blocks: usize,
+    pub depths: Vec<usize>,
+    pub sr_ratios: Vec<usize>,
+    pub hidden_sizes: Vec<usize>,
+    pub patch_sizes: Vec<usize>,
+    pub strides: Vec<usize>,
+    pub num_attention_heads: Vec<usize>,
+    pub mlp_ratios: Vec<usize>,
+    pub hidden_act: candle_nn::Activation,
+    pub layer_norm_eps: f64,
+    pub decoder_hidden_size: usize,
+}
+
+#[derive(Debug, Clone)]
+struct SegformerOverlapPatchEmbeddings {
+    projection: Conv2d,
+    layer_norm: candle_nn::LayerNorm,
+}
+
+impl SegformerOverlapPatchEmbeddings {
+    fn new(
+        config: &Config,
+        patch_size: usize,
+        stride: usize,
+        num_channels: usize,
+        hidden_size: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let projection = conv2d(
+            num_channels,
+            hidden_size,
+            patch_size,
+            Conv2dConfig {
+                stride,
+                padding: patch_size / 2,
+                ..Default::default()
+            },
+            vb.pp("proj"),
+        )?;
+        let layer_norm =
+            candle_nn::layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm"))?;
+        Ok(Self {
+            projection,
+            layer_norm,
+        })
+    }
+}
+
+impl Module for SegformerOverlapPatchEmbeddings {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let embeddings = self.projection.forward(x)?;
+        let shape = embeddings.shape();
+        // [B, C, H, W] -> [B, H * W, C]
+        let embeddings = embeddings.flatten_from(2)?.transpose(1, 2)?;
+        let embeddings = self.layer_norm.forward(&embeddings)?;
+        // [B, H * W, C] -> [B, C, H, W]
+        let embeddings = embeddings.transpose(1, 2)?.reshape(shape)?;
+        Ok(embeddings)
+    }
+}
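+// SegFormer's efficient self-attention: when sr_ratio > 1 the keys and values are
+// spatially downsampled by a strided convolution before attention, which shrinks the
+// K/V sequence length by sr_ratio^2 and keeps attention cheap on high-resolution stages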
+#[derive(Debug, Clone)]
+struct SegformerEfficientSelfAttention {
+    num_attention_heads: usize,
+    attention_head_size: usize,
+    query: Linear,
+    key: Linear,
+    value: Linear,
+    sr: Option<Conv2d>,
+    layer_norm: Option<candle_nn::LayerNorm>,
+}
+
+impl SegformerEfficientSelfAttention {
+    fn new(
+        config: &Config,
+        hidden_size: usize,
+        num_attention_heads: usize,
+        sequence_reduction_ratio: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        if hidden_size % num_attention_heads != 0 {
+            candle::bail!(
+                "The hidden size {} is not a multiple of the number of attention heads {}",
+                hidden_size,
+                num_attention_heads
+            )
+        }
+        let attention_head_size = hidden_size / num_attention_heads;
+        let all_head_size = num_attention_heads * attention_head_size;
+        let query = linear(hidden_size, all_head_size, vb.pp("query"))?;
+        let key = linear(hidden_size, all_head_size, vb.pp("key"))?;
+        let value = linear(hidden_size, all_head_size, vb.pp("value"))?;
+        let (sr, layer_norm) = if sequence_reduction_ratio > 1 {
+            (
+                Some(conv2d(
+                    hidden_size,
+                    hidden_size,
+                    sequence_reduction_ratio,
+                    Conv2dConfig {
+                        stride: sequence_reduction_ratio,
+                        ..Default::default()
+                    },
+                    vb.pp("sr"),
+                )?),
+                Some(candle_nn::layer_norm(
+                    hidden_size,
+                    config.layer_norm_eps,
+                    vb.pp("layer_norm"),
+                )?),
+            )
+        } else {
+            (None, None)
+        };
+        Ok(Self {
+            num_attention_heads,
+            attention_head_size,
+            query,
+            key,
+            value,
+            sr,
+            layer_norm,
+        })
+    }
+
+    fn transpose_for_scores(&self, hidden_states: Tensor) -> Result<Tensor> {
+        let (batch, seq_length, _) = hidden_states.shape().dims3()?;
+        let new_shape = &[
+            batch,
+            seq_length,
+            self.num_attention_heads,
+            self.attention_head_size,
+        ];
+        let hidden_states = hidden_states.reshape(new_shape)?;
+        let hidden_states = hidden_states.permute((0, 2, 1, 3))?;
+        Ok(hidden_states)
+    }
+}
+
+impl Module for SegformerEfficientSelfAttention {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        // [B, C, H, W] -> [B, H * W, C]
+        let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
+        let query = self
+            .transpose_for_scores(self.query.forward(&hidden_states)?)?
+            .contiguous()?;
+        let hidden_states = if let (Some(sr), Some(layer_norm)) = (&self.sr, &self.layer_norm) {
+            let hidden_states = sr.forward(x)?;
+            // [B, C, H, W] -> [B, H * W, C]
+            let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
+            layer_norm.forward(&hidden_states)?
+        } else {
+            // already [B, H * W, C]
+            hidden_states
+        };
+        // standard self-attention
+        let key = self
+            .transpose_for_scores(self.key.forward(&hidden_states)?)?
+            .contiguous()?;
+        let value = self
+            .transpose_for_scores(self.value.forward(&hidden_states)?)?
+            .contiguous()?;
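+        // scaled dot-product attention: softmax(Q K^T / sqrt(d_head)) V, computed per head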
+        let attention_scores =
+            (query.matmul(&key.t()?)? / f64::sqrt(self.attention_head_size as f64))?;
+        let attention_scores = candle_nn::ops::softmax_last_dim(&attention_scores)?;
+        let result = attention_scores.matmul(&value)?;
+        let result = result.permute((0, 2, 1, 3))?.contiguous()?;
+        result.flatten_from(D::Minus2)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerSelfOutput {
+    dense: Linear,
+}
+
+impl SegformerSelfOutput {
+    fn new(hidden_size: usize, vb: VarBuilder) -> Result<Self> {
+        let dense = linear(hidden_size, hidden_size, vb.pp("dense"))?;
+        Ok(Self { dense })
+    }
+}
+
+impl Module for SegformerSelfOutput {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        self.dense.forward(x)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerAttention {
+    attention: SegformerEfficientSelfAttention,
+    output: SegformerSelfOutput,
+}
+
+impl SegformerAttention {
+    fn new(
+        config: &Config,
+        hidden_size: usize,
+        num_attention_heads: usize,
+        sequence_reduction_ratio: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let attention = SegformerEfficientSelfAttention::new(
+            config,
+            hidden_size,
+            num_attention_heads,
+            sequence_reduction_ratio,
+            vb.pp("self"),
+        )?;
+        let output = SegformerSelfOutput::new(hidden_size, vb.pp("output"))?;
+        Ok(Self { attention, output })
+    }
+}
+
+impl Module for SegformerAttention {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let attention_output = self.attention.forward(x)?;
+        self.output.forward(&attention_output)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerDWConv {
+    dw_conv: Conv2d,
+}
+
+impl SegformerDWConv {
+    fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
+        let dw_conv = conv2d(
+            dim,
+            dim,
+            3,
+            Conv2dConfig {
+                stride: 1,
+                padding: 1,
+                groups: dim,
+                ..Default::default()
+            },
+            vb.pp("dwconv"),
+        )?;
+        Ok(Self { dw_conv })
+    }
+}
+
+impl Module for SegformerDWConv {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        self.dw_conv.forward(x)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerMixFFN {
+    dense1: Linear,
+    dw_conv: SegformerDWConv,
+    act: Activation,
+    dense2: Linear,
+}
+
+impl SegformerMixFFN {
+    fn new(
+        config: &Config,
+        in_features: usize,
+        hidden_features: usize,
+        out_features: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let dense1 = linear(in_features, hidden_features, vb.pp("dense1"))?;
+        let dw_conv = SegformerDWConv::new(hidden_features, vb.pp("dwconv"))?;
+        let act = config.hidden_act;
+        let dense2 = linear(hidden_features, out_features, vb.pp("dense2"))?;
+        Ok(Self {
+            dense1,
+            dw_conv,
+            act,
+            dense2,
+        })
+    }
+}
+
+impl Module for SegformerMixFFN {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let (batch, _, height, width) = x.shape().dims4()?;
+        let hidden_states = self
+            .dense1
+            .forward(&x.flatten_from(2)?.permute((0, 2, 1))?)?;
+        let channels = hidden_states.dim(2)?;
+        let hidden_states = self.dw_conv.forward(
+            &hidden_states
+                .permute((0, 2, 1))?
+                .reshape((batch, channels, height, width))?,
+        )?;
+        let hidden_states = self.act.forward(&hidden_states)?;
+        let hidden_states = self
+            .dense2
+            .forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
+        let channels = hidden_states.dim(2)?;
+        hidden_states
+            .permute((0, 2, 1))?
+            .reshape((batch, channels, height, width))
+    }
+}
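+// a pre-norm transformer block: efficient self-attention plus a Mix-FFN (whose 3x3
+// depthwise convolution stands in for explicit positional encodings, as in the
+// SegFormer paper), each wrapped in a residual connection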
+#[derive(Debug, Clone)]
+struct SegformerLayer {
+    layer_norm_1: candle_nn::LayerNorm,
+    attention: SegformerAttention,
+    layer_norm_2: candle_nn::LayerNorm,
+    mlp: SegformerMixFFN,
+}
+
+impl SegformerLayer {
+    fn new(
+        config: &Config,
+        hidden_size: usize,
+        num_attention_heads: usize,
+        sequence_reduction_ratio: usize,
+        mlp_ratio: usize,
+        vb: VarBuilder,
+    ) -> Result<Self> {
+        let layer_norm_1 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_1"))?;
+        let attention = SegformerAttention::new(
+            config,
+            hidden_size,
+            num_attention_heads,
+            sequence_reduction_ratio,
+            vb.pp("attention"),
+        )?;
+        let layer_norm_2 = layer_norm(hidden_size, config.layer_norm_eps, vb.pp("layer_norm_2"))?;
+        let mlp = SegformerMixFFN::new(
+            config,
+            hidden_size,
+            hidden_size * mlp_ratio,
+            hidden_size,
+            vb.pp("mlp"),
+        )?;
+        Ok(Self {
+            layer_norm_1,
+            attention,
+            layer_norm_2,
+            mlp,
+        })
+    }
+}
+
+impl Module for SegformerLayer {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let shape = x.shape().dims4()?;
+        // [B, C, H, W] -> [B, H * W, C]
+        let hidden_states = x.flatten_from(2)?.permute((0, 2, 1))?;
+        let layer_norm_output = self.layer_norm_1.forward(&hidden_states)?;
+        let layer_norm_output = layer_norm_output.permute((0, 2, 1))?.reshape(shape)?;
+        // attention takes in [B, C, H, W] in order to properly do conv2d (and output [B, H * W, C])
+        let attention_output = self.attention.forward(&layer_norm_output)?;
+        let hidden_states = (attention_output + hidden_states)?;
+        let layer_norm_output = self.layer_norm_2.forward(&hidden_states)?;
+        let mlp_output = self
+            .mlp
+            .forward(&layer_norm_output.permute((0, 2, 1))?.reshape(shape)?)?;
+        hidden_states.permute((0, 2, 1))?.reshape(shape)? + mlp_output
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerEncoder {
+    /// config file
+    config: Config,
+    /// a list of embeddings
+    patch_embeddings: Vec<SegformerOverlapPatchEmbeddings>,
+    /// a list of attention blocks, each consisting of layers
+    blocks: Vec<Vec<SegformerLayer>>,
+    /// a final list of layer norms
+    layer_norms: Vec<candle_nn::LayerNorm>,
+}
+
+impl SegformerEncoder {
+    fn new(config: Config, vb: VarBuilder) -> Result<Self> {
+        let mut patch_embeddings = Vec::with_capacity(config.num_encoder_blocks);
+        let mut blocks = Vec::with_capacity(config.num_encoder_blocks);
+        let mut layer_norms = Vec::with_capacity(config.num_encoder_blocks);
+        for i in 0..config.num_encoder_blocks {
+            let patch_size = config.patch_sizes[i];
+            let stride = config.strides[i];
+            let hidden_size = config.hidden_sizes[i];
+            let num_channels = if i == 0 {
+                config.num_channels
+            } else {
+                config.hidden_sizes[i - 1]
+            };
+            patch_embeddings.push(SegformerOverlapPatchEmbeddings::new(
+                &config,
+                patch_size,
+                stride,
+                num_channels,
+                hidden_size,
+                vb.pp(&format!("patch_embeddings.{}", i)),
+            )?);
+            let mut layers = Vec::with_capacity(config.depths[i]);
+            for j in 0..config.depths[i] {
+                let sequence_reduction_ratio = config.sr_ratios[i];
+                let num_attention_heads = config.num_attention_heads[i];
+                let mlp_ratio = config.mlp_ratios[i];
+                layers.push(SegformerLayer::new(
+                    &config,
+                    hidden_size,
+                    num_attention_heads,
+                    sequence_reduction_ratio,
+                    mlp_ratio,
+                    vb.pp(&format!("block.{}.{}", i, j)),
+                )?);
+            }
+            blocks.push(layers);
+            layer_norms.push(layer_norm(
+                hidden_size,
+                config.layer_norm_eps,
+                vb.pp(&format!("layer_norm.{}", i)),
+            )?);
+        }
+        Ok(Self {
+            config,
+            patch_embeddings,
+            blocks,
+            layer_norms,
+        })
+    }
+}
+
+impl ModuleWithHiddenStates for SegformerEncoder {
+    fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
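+        // each stage: downsample through the overlapping patch embedding, run the
+        // stage's transformer layers, layer-norm, and keep the per-stage feature map
+        // so the decode head can fuse multi-scale features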
+        let mut all_hidden_states = Vec::with_capacity(self.config.num_encoder_blocks);
+        let mut hidden_states = x.clone();
+        for i in 0..self.config.num_encoder_blocks {
+            hidden_states = self.patch_embeddings[i].forward(&hidden_states)?;
+            for layer in &self.blocks[i] {
+                hidden_states = layer.forward(&hidden_states)?;
+            }
+            let shape = hidden_states.shape().dims4()?;
+            hidden_states =
+                self.layer_norms[i].forward(&hidden_states.flatten_from(2)?.permute((0, 2, 1))?)?;
+            hidden_states = hidden_states.permute((0, 2, 1))?.reshape(shape)?;
+            all_hidden_states.push(hidden_states.clone());
+        }
+        Ok(all_hidden_states)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerModel {
+    encoder: SegformerEncoder,
+}
+
+impl SegformerModel {
+    fn new(config: &Config, vb: VarBuilder) -> Result<Self> {
+        let encoder = SegformerEncoder::new(config.clone(), vb.pp("encoder"))?;
+        Ok(Self { encoder })
+    }
+}
+
+impl ModuleWithHiddenStates for SegformerModel {
+    fn forward(&self, x: &Tensor) -> Result<Vec<Tensor>> {
+        self.encoder.forward(x)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerMLP {
+    proj: Linear,
+}
+
+impl SegformerMLP {
+    fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
+        let proj = linear(input_dim, config.decoder_hidden_size, vb.pp("proj"))?;
+        Ok(Self { proj })
+    }
+}
+
+impl Module for SegformerMLP {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        self.proj.forward(x)
+    }
+}
+
+#[derive(Debug, Clone)]
+struct SegformerDecodeHead {
+    linear_c: Vec<SegformerMLP>,
+    linear_fuse: candle_nn::Conv2d,
+    batch_norm: candle_nn::BatchNorm,
+    classifier: candle_nn::Conv2d,
+}
+
+impl SegformerDecodeHead {
+    fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
+        let mut linear_c = Vec::with_capacity(config.num_encoder_blocks);
+        for i in 0..config.num_encoder_blocks {
+            let hidden_size = config.hidden_sizes[i];
+            linear_c.push(SegformerMLP::new(
+                config,
+                hidden_size,
+                vb.pp(&format!("linear_c.{}", i)),
+            )?);
+        }
+        let linear_fuse = conv2d_no_bias(
+            config.decoder_hidden_size * config.num_encoder_blocks,
+            config.decoder_hidden_size,
+            1,
+            Conv2dConfig::default(),
+            vb.pp("linear_fuse"),
+        )?;
+        let batch_norm = candle_nn::batch_norm(
+            config.decoder_hidden_size,
+            config.layer_norm_eps,
+            vb.pp("batch_norm"),
+        )?;
+        let classifier = conv2d_no_bias(
+            config.decoder_hidden_size,
+            num_labels,
+            1,
+            Conv2dConfig::default(),
+            vb.pp("classifier"),
+        )?;
+        Ok(Self {
+            linear_c,
+            linear_fuse,
+            batch_norm,
+            classifier,
+        })
+    }
+
+    fn forward(&self, encoder_hidden_states: &[Tensor]) -> Result<Tensor> {
+        if encoder_hidden_states.len() != self.linear_c.len() {
+            candle::bail!(
+                "The number of encoder hidden states {} is not equal to the number of linear layers {}",
+                encoder_hidden_states.len(),
+                self.linear_c.len()
+            )
+        }
+        // the finest-resolution feature map (first stage) sets the upsampling target
+        let (_, _, upsample_height, upsample_width) = encoder_hidden_states[0].shape().dims4()?;
+        let mut hidden_states = Vec::with_capacity(self.linear_c.len());
+        for (hidden_state, mlp) in encoder_hidden_states.iter().zip(&self.linear_c) {
+            let (batch, _, height, width) = hidden_state.shape().dims4()?;
+            let hidden_state = mlp.forward(&hidden_state.flatten_from(2)?.permute((0, 2, 1))?)?;
+            let hidden_state = hidden_state.permute((0, 2, 1))?.reshape((
+                batch,
+                hidden_state.dim(2)?,
+                height,
+                width,
+            ))?;
+            let hidden_state = hidden_state.upsample_nearest2d(upsample_height, upsample_width)?;
+            hidden_states.push(hidden_state);
+        }
+        // concatenate in reverse stage order, matching the reference implementation
+        hidden_states.reverse();
+        let hidden_states = Tensor::cat(&hidden_states, 1)?;
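+        // fuse the concatenated multi-scale features with a 1x1 conv + batch norm +
+        // ReLU, then project down to one logit map per class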
+        let hidden_states = self.linear_fuse.forward(&hidden_states)?;
+        let hidden_states = self.batch_norm.forward_t(&hidden_states, false)?;
+        let hidden_states = hidden_states.relu()?;
+        self.classifier.forward(&hidden_states)
+    }
+}
+
+trait ModuleWithHiddenStates {
+    fn forward(&self, xs: &Tensor) -> Result<Vec<Tensor>>;
+}
+
+#[derive(Debug, Clone)]
+pub struct SemanticSegmentationModel {
+    segformer: SegformerModel,
+    decode_head: SegformerDecodeHead,
+}
+
+impl SemanticSegmentationModel {
+    pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
+        let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
+        let decode_head = SegformerDecodeHead::new(config, num_labels, vb.pp("decode_head"))?;
+        Ok(Self {
+            segformer,
+            decode_head,
+        })
+    }
+}
+
+impl Module for SemanticSegmentationModel {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let hidden_states = self.segformer.forward(x)?;
+        self.decode_head.forward(&hidden_states)
+    }
+}
+
+#[derive(Debug, Clone)]
+pub struct ImageClassificationModel {
+    segformer: SegformerModel,
+    classifier: Linear,
+}
+
+impl ImageClassificationModel {
+    pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
+        let segformer = SegformerModel::new(config, vb.pp("segformer"))?;
+        let classifier = linear(config.decoder_hidden_size, num_labels, vb.pp("classifier"))?;
+        Ok(Self {
+            segformer,
+            classifier,
+        })
+    }
+}
+
+impl Module for ImageClassificationModel {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let all_hidden_states = self.segformer.forward(x)?;
+        let hidden_states = all_hidden_states.last().unwrap();
+        let hidden_states = hidden_states.flatten_from(2)?.permute((0, 2, 1))?;
+        let mean = hidden_states.mean(1)?;
+        self.classifier.forward(&mean)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+
+    #[test]
+    fn test_config_json_load() {
+        let raw_json = r#"{
+            "architectures": ["SegformerForImageClassification"],
+            "attention_probs_dropout_prob": 0.0,
+            "classifier_dropout_prob": 0.1,
+            "decoder_hidden_size": 256,
+            "depths": [2, 2, 2, 2],
+            "downsampling_rates": [1, 4, 8, 16],
+            "drop_path_rate": 0.1,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.0,
+            "hidden_sizes": [32, 64, 160, 256],
+            "image_size": 224,
+            "initializer_range": 0.02,
+            "layer_norm_eps": 1e-06,
+            "mlp_ratios": [4, 4, 4, 4],
+            "model_type": "segformer",
+            "num_attention_heads": [1, 2, 5, 8],
+            "num_channels": 3,
+            "num_encoder_blocks": 4,
+            "patch_sizes": [7, 3, 3, 3],
+            "sr_ratios": [8, 4, 2, 1],
+            "strides": [4, 2, 2, 2],
+            "torch_dtype": "float32",
+            "transformers_version": "4.12.0.dev0"
+        }"#;
+        let config: Config = serde_json::from_str(raw_json).unwrap();
+        assert_eq!(vec![4, 2, 2, 2], config.strides);
+        assert_eq!(1e-6, config.layer_norm_eps);
+    }
+}