Skip to content

Commit

Permalink
Enable the new layer-norm.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed May 24, 2024
1 parent 1df2bdd commit 5c44e7a
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
5 changes: 5 additions & 0 deletions candle-nn/src/layer_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,11 @@ impl LayerNorm {

impl Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
if x.is_contiguous() && self.remove_mean {
if let Some(bias) = self.bias.as_ref() {
return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32);
}
}
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
DType::F16 | DType::BF16 => DType::F32,
Expand Down
12 changes: 4 additions & 8 deletions candle-transformers/src/models/phi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,20 @@ impl RotaryEmbedding {
.to_dtype(DType::F32)?
.reshape((cfg.max_position_embeddings, 1))?;
let freqs = t.matmul(&inv_freq)?;
let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
Ok(Self {
dim,
sin: emb.sin()?,
cos: emb.cos()?,
sin: freqs.sin()?,
cos: freqs.cos()?,
})
}

fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
let xs_rot = xs.i((.., .., .., ..self.dim))?;
let xs_rot = xs.i((.., .., .., ..self.dim))?.contiguous()?;
let xs_pass = xs.i((.., .., .., self.dim..))?;
let xs12 = xs_rot.chunk(2, D::Minus1)?;
let (xs1, xs2) = (&xs12[0], &xs12[1]);
let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
let rotate_half = Tensor::cat(&[&xs2.neg()?, xs1], D::Minus1)?;
let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, &c, &s)?;
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
}
}
Expand Down

0 comments on commit 5c44e7a

Please sign in to comment.