llm-factor: migrate to candle
#2755
Conversation
Given this is a breaking change, I'd suggest adding the 3.0 label.
@radu-matei I do not believe I can add labels in this repository.
The test failure does not seem to be related?
crates/llm-local/src/utils.rs (outdated)
let json: serde_json::Value =
    serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
let weight_map = match json.get("weight_map") {
    None => candle::bail!("no weight map in {json_file:?}"),
    Some(serde_json::Value::Object(map)) => map,
    Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
};
You can replace this with:

#[derive(Deserialize)]
struct SafeTensorsJson {
    weight_map: HashMap<String, String>,
}

let json: SafeTensorsJson = serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
I reverted this change and the other because it was leading to an odd error: the returned vector contained the same entries repeated several times, which meant the same files were being loaded over and over and led to consuming large amounts of memory.
crates/llm-local/src/utils.rs (outdated)
for value in weight_map.values() {
    if let Some(file) = value.as_str() {
        safetensors_files.insert(file.to_string());
    }
}
let safetensors_files = safetensors_files
    .iter()
    .map(|v| model_dir.join(v))
    .collect::<Vec<_>>();
Suggested change:

safetensors_files.extend(weight_map.values().map(|v| model_dir.join(v)));
This assumes no need to call as_str because of the suggested change above.
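Putting the two suggestions together, here is a minimal sketch of how the index JSON could be parsed and the safetensors paths collected. The function name, parameters, and the HashSet-based deduplication are illustrative (not the PR's actual code); the deduplication is there because several tensors usually map to the same shard file, which relates to the duplicate-loading issue mentioned above.

use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::path::{Path, PathBuf};

use serde::Deserialize;

// Expected shape of the safetensors index JSON (only the field we need).
#[derive(Deserialize)]
struct SafeTensorsJson {
    weight_map: HashMap<String, String>,
}

// Collect the safetensors files referenced by the index file in `model_dir`.
fn safetensors_paths(model_dir: &Path, index_file: &Path) -> candle::Result<Vec<PathBuf>> {
    let json_file = File::open(index_file).map_err(candle::Error::wrap)?;
    let json: SafeTensorsJson =
        serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
    // Deduplicate: many tensor names point at the same shard file.
    let files: HashSet<&String> = json.weight_map.values().collect();
    Ok(files.into_iter().map(|f| model_dir.join(f)).collect())
}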
crates/llm-local/src/lib.rs (outdated)
}

#[async_trait]
trait CachedInferencingModel: Send + Sync {
Can we document this trait? What about it makes it Cached? Are implementors required to cache results, or does it just happen that the current implementors do?
I'm fine with keeping the name, but I personally find the name CachedInferencingModel confusing when implementors aren't required to cache anything. InferencingModel seems like a more appropriate name.
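For illustration, a documented trait header along the lines of what is being asked for might look like the following. The method and its signature are placeholders, not Spin's actual API.

use async_trait::async_trait;

/// A locally loaded model that can serve inferencing requests.
///
/// Implementors are not required to cache anything; whether weights or
/// results are kept in memory between calls is an implementation detail
/// of each backend.
#[async_trait]
trait InferencingModel: Send + Sync {
    /// Generate a completion for `prompt`. (Placeholder signature.)
    async fn infer(&self, prompt: &str) -> anyhow::Result<String>;
}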
match self.tokenizer.decode(tokens, true) {
    Ok(str) => Ok(str),
    Err(err) => anyhow::bail!("cannot decode: {err}"),
}
Suggested change:

self.tokenizer.decode(tokens, true).context("failed to decode token stream")
It does look like I cannot do this because tokenizer.decode returns a Result<String, Box<dyn Error + Send + Sync>>, which does not seem to be suitable to use context on(?)
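One possible workaround, not discussed in the thread and only a sketch (the StreamDecoder type is hypothetical): convert the boxed error into an anyhow::Error first, then add context.

use anyhow::{Context, Result};
use tokenizers::Tokenizer;

struct StreamDecoder {
    tokenizer: Tokenizer,
}

impl StreamDecoder {
    fn decode(&self, tokens: &[u32]) -> Result<String> {
        self.tokenizer
            .decode(tokens, true)
            // decode returns Result<String, Box<dyn Error + Send + Sync>>,
            // which .context() can't be called on directly, so convert the
            // boxed error into an anyhow::Error first.
            .map_err(anyhow::Error::msg)
            .context("failed to decode token stream")
    }
}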
};
self.tokens.push(token);
let text = self.decode(&self.tokens[self.prev_index..])?;
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
I don't fully understand what this check is supposed to be doing. Why do we care about the length of the next text vs the previous, and why do we care whether the last character is alphanumeric?
The length check is to see if we have any new tokens. The alphanumeric check is supposed to check whether we have a valid token to decode. That is what I gather from the Python function the docs link to.
The Python code is dealing with unfinished UTF-8 byte sequences, which is not possible at this point in the Rust code: Rust chars are guaranteed to be valid UTF-8. The check for alphanumeric chars is checking that the character is A-Z | a-z | 0-9, which does seem to be what we want.
The Tokenizer::decode function returns Strings, so I'm guessing the tokenizer crate is somehow taking care of byte sequences that aren't valid UTF-8?
Here is the relevant Rust version this is borrowed from:
https://github.com/huggingface/candle/blob/6eea45a761fc1636b5e8012d02bdaa93321652ca/candle-examples/src/token_output_stream.rs#L43
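For reference, here is a condensed sketch of the streaming-decode logic being discussed, adapted from the linked candle example (struct and method names follow that example; details such as the constructor and error handling are simplified). The idea is that new text is emitted only once the decoded string has grown and ends in an alphanumeric character; a partially decoded multi-byte sequence renders as the replacement character, which is not alphanumeric, so it is held back until more tokens arrive.

use anyhow::{Context, Result};
use tokenizers::Tokenizer;

/// Streams decoded text token by token (simplified from candle's
/// token_output_stream.rs).
struct TokenOutputStream {
    tokenizer: Tokenizer,
    tokens: Vec<u32>,
    prev_index: usize,
    current_index: usize,
}

impl TokenOutputStream {
    fn new(tokenizer: Tokenizer) -> Self {
        Self { tokenizer, tokens: Vec::new(), prev_index: 0, current_index: 0 }
    }

    fn decode(&self, tokens: &[u32]) -> Result<String> {
        self.tokenizer
            .decode(tokens, true)
            .map_err(anyhow::Error::msg)
            .context("cannot decode")
    }

    /// Push one token and return any newly completed text.
    fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        let prev_text = if self.tokens.is_empty() {
            String::new()
        } else {
            self.decode(&self.tokens[self.prev_index..self.current_index])?
        };
        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;
        // Emit only once the decoded text has grown and ends in an
        // alphanumeric character; otherwise keep buffering tokens.
        if text.len() > prev_text.len()
            && text.chars().last().is_some_and(|c| c.is_alphanumeric())
        {
            let new_text = text[prev_text.len()..].to_string();
            self.prev_index = self.current_index;
            self.current_index = self.tokens.len();
            Ok(Some(new_text))
        } else {
            Ok(None)
        }
    }
}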
🎉
This PR replaces the dependency on rustformers/llm with huggingface/candle. This allows us to run newer models like Llama 3(.1). It now requires models to be in the safetensors format.

This PR also removes the concept of well-known models, which ensures a consistent directory structure for all models. The rationale is that, with this change, the only group of models initially supported is the Llama family.

Closes #2735