Skip to content

Commit

Permalink
add semantic segmentation model
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist committed Jan 24, 2024
1 parent faea6c6 commit 35874e1
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 3 deletions.
2 changes: 1 addition & 1 deletion candle-examples/examples/segformer/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ pub fn main() -> anyhow::Result<()> {
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, &device)? };
let config = Default::default();
let num_labels = 1000;
let model = segformer::ImageClassificationModel::new(&config, num_labels, vb)?;
let model = segformer::SemanticSegmentationModel::new(&config, num_labels, vb)?;
Ok(())
}
113 changes: 111 additions & 2 deletions candle-transformers/src/models/segformer.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::models::with_tracing::{conv2d, linear, Conv2d, Linear};
use candle::{Module, Result, Tensor, D};
use candle_nn::{layer_norm, Activation, Conv2dConfig, Dropout, VarBuilder};
use candle::{Module, ModuleT, Result, Tensor, D};
use candle_nn::{conv2d_no_bias, layer_norm, Activation, Conv2dConfig, Dropout, VarBuilder};

// https://github.com/huggingface/transformers/blob/main/src/transformers/models/segformer/configuration_segformer.py
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -494,3 +494,112 @@ impl Module for ImageClassificationModel {
self.classifier.forward(&hidden_states)
}
}

#[derive(Debug)]
struct SegformerMLP {
proj: Linear,
}

impl SegformerMLP {
fn new(config: &Config, input_dim: usize, vb: VarBuilder) -> Result<Self> {
let proj = linear(input_dim, config.decoder_hidden_size, vb.pp("linear"))?;
Ok(Self { proj })
}
}

impl Module for SegformerMLP {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.proj.forward(x)
}
}

#[derive(Debug)]
struct SegformerDecodeHead {
linear_c: Vec<SegformerMLP>,
linear_fuse: candle_nn::Conv2d,
batch_norm: candle_nn::BatchNorm,
dropout: Dropout,
classifier: candle_nn::Conv2d,
}

impl SegformerDecodeHead {
fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
let mut linear_c = Vec::with_capacity(config.num_encoder_blocks);
for i in 0..config.num_encoder_blocks {
let hidden_size = config.hidden_sizes[i];
linear_c.push(SegformerMLP::new(
config,
hidden_size,
vb.pp(&format!("mlps.{}", i)),
)?);
}
let linear_fuse = conv2d_no_bias(
config.decoder_hidden_size * config.num_encoder_blocks,
config.decoder_hidden_size,
1,
Conv2dConfig::default(),
vb.pp("linear_fuse"),
)?;
let batch_norm = candle_nn::batch_norm(
config.decoder_hidden_size,
config.layer_norm_eps,
vb.pp("batch_norm"),
)?;
let dropout = Dropout::new(config.classifier_dropout_prob);
let classifier = conv2d_no_bias(
config.decoder_hidden_size,
num_labels,
1,
Conv2dConfig::default(),
vb.pp("classifier"),
)?;
Ok(Self {
linear_c,
linear_fuse,
batch_norm,
dropout,
classifier,
})
}
}

impl Module for SegformerDecodeHead {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let mut hidden_states = Vec::with_capacity(self.linear_c.len());
for i in 0..self.linear_c.len() {
hidden_states.push(self.linear_c[i].forward(&x)?);
}
let hidden_states = Tensor::cat(&hidden_states, D::Minus1)?;
let hidden_states = self.linear_fuse.forward(&hidden_states)?;
let hidden_states = self.batch_norm.forward_t(&hidden_states, false)?;
let hidden_states = hidden_states.relu()?;
let hidden_states = self.dropout.forward(&hidden_states, false)?;
// logits are of shape (batch_size, num_labels, height/4, width/4)
let logits = self.classifier.forward(&hidden_states)?;
Ok(logits)
}
}

#[derive(Debug)]
pub struct SemanticSegmentationModel {
encoder: SegformerEncoder,
decode_head: SegformerDecodeHead,
}

impl SemanticSegmentationModel {
pub fn new(config: &Config, num_labels: usize, vb: VarBuilder) -> Result<Self> {
let encoder = SegformerEncoder::new(config.clone(), vb.pp("encoder"))?;
let decode_head = SegformerDecodeHead::new(config, num_labels, vb.pp("decode_head"))?;
Ok(Self {
encoder,
decode_head,
})
}
}

impl Module for SemanticSegmentationModel {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let hidden_states = self.encoder.forward(x)?;
self.decode_head.forward(&hidden_states)
}
}

0 comments on commit 35874e1

Please sign in to comment.