Skip to content

Commit

Permalink
Support for the new Qwen2 models. (#2257)
Browse files Browse the repository at this point in the history
* Support for the new Qwen2 models.

* Add more models.
  • Loading branch information
LaurentMazare committed Jun 7, 2024
1 parent b9fac7e commit 54ff971
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
36 changes: 26 additions & 10 deletions candle-examples/examples/qwen/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,14 @@ enum WhichModel {
W72b,
#[value(name = "moe-a2.7b")]
MoeA27b,
#[value(name = "2-0.5b")]
W2_0_5b,
#[value(name = "2-1.5b")]
W2_1_5b,
#[value(name = "2-7b")]
W2_7b,
#[value(name = "2-72b")]
W2_72b,
}

#[derive(Parser, Debug)]
Expand Down Expand Up @@ -234,16 +242,20 @@ fn main() -> Result<()> {
let model_id = match args.model_id {
Some(model_id) => model_id,
None => {
let size = match args.model {
WhichModel::W0_5b => "0.5B",
WhichModel::W1_8b => "1.8B",
WhichModel::W4b => "4B",
WhichModel::W7b => "7B",
WhichModel::W14b => "14B",
WhichModel::W72b => "72B",
WhichModel::MoeA27b => "MoE-A2.7B",
let (version, size) = match args.model {
WhichModel::W2_0_5b => ("2", "0.5B"),
WhichModel::W2_1_5b => ("2", "1.5B"),
WhichModel::W2_7b => ("2", "7B"),
WhichModel::W2_72b => ("2", "72B"),
WhichModel::W0_5b => ("1.5", "0.5B"),
WhichModel::W1_8b => ("1.5", "1.8B"),
WhichModel::W4b => ("1.5", "4B"),
WhichModel::W7b => ("1.5", "7B"),
WhichModel::W14b => ("1.5", "14B"),
WhichModel::W72b => ("1.5", "72B"),
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
};
format!("Qwen/Qwen1.5-{size}")
format!("Qwen/Qwen{version}-{size}")
}
};
let repo = api.repo(Repo::with_revision(
Expand All @@ -261,11 +273,15 @@ fn main() -> Result<()> {
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => match args.model {
WhichModel::W0_5b | WhichModel::W1_8b => vec![repo.get("model.safetensors")?],
WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
vec![repo.get("model.safetensors")?]
}
WhichModel::W4b
| WhichModel::W7b
| WhichModel::W2_7b
| WhichModel::W14b
| WhichModel::W72b
| WhichModel::W2_72b
| WhichModel::MoeA27b => {
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
}
Expand Down
8 changes: 6 additions & 2 deletions candle-transformers/src/models/qwen2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,12 @@ pub struct ModelForCausalLM {

impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let base_model = Model::new(cfg, vb)?;
let base_model = Model::new(cfg, vb.clone())?;
let lm_head = if vb.contains_tensor("lm_head") {
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
} else {
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)
};
Ok(Self {
base_model,
lm_head,
Expand Down

0 comments on commit 54ff971

Please sign in to comment.