Is it possible to load weights from VarStore even after applying some operations? #838

Open
skyser2003 opened this issue Jan 17, 2024 · 0 comments


Hello, I'm trying to load weights from a pretrained model.safetensors file,
and I have a question (or maybe a feature request).

This is part of the code I wrote.

// vs: VarStore

let qkv_weight = Tensor::cat(
    &[
        (vs / "q_proj").var("weight", &embed_shape, nn::Init::Const(0.0)),
        (vs / "k_proj").var("weight", &embed_shape, nn::Init::Const(0.0)),
        (vs / "v_proj").var("weight", &embed_shape, nn::Init::Const(0.0)),
    ],
    0,
);

let qkv = Linear {
    ws: qkv_weight,
    bs: None,
};

let out = Linear {
    ws: (vs / "out_proj").var("weight", &embed_shape, nn::Init::Const(0.0)),
    bs: None,
};

The model has linear weights named q_proj, k_proj, v_proj, and out_proj.
The first three are concatenated as one linear layer, and out_proj is kept alone.

As you might expect, even after loading the weights from the model file,
all values of qkv_weight remain 0.0, while out_proj is loaded successfully.

This may look obvious, because the first three proj vars are merged into a new tensor by the Tensor::cat operation,
which means the variable information is all gone.
But I've also implemented the same weight loading code using huggingface's candle framework,
and on candle, qkv_weight successfully loaded values from the model file.

I guess the q_proj, k_proj, v_proj concatenation can be done by first loading the weights
and then applying the Tensor::cat operation, but that breaks the consistency of weight loading,
where other variables can load their weights just by calling vs.load.
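
To make the ordering problem concrete, here is a toy sketch that uses only the Rust standard library. It does not call tch at all: ToyStore, Var, cat, and the checkpoint contents are hypothetical stand-ins, built on the assumption that loading copies values in place into the variables registered under each name, while concatenation copies the current values into fresh storage that the store knows nothing about.

```rust
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

// Hypothetical stand-in for a tensor whose storage can be updated in place.
type Var = Rc<RefCell<Vec<f32>>>;

// Hypothetical stand-in for a variable store: it only tracks named variables.
#[derive(Default)]
struct ToyStore {
    vars: HashMap<String, Var>,
}

impl ToyStore {
    // Register a zero-initialised variable under `name` and return a handle to it.
    fn var(&mut self, name: &str, len: usize) -> Var {
        let v: Var = Rc::new(RefCell::new(vec![0.0; len]));
        self.vars.insert(name.to_string(), Rc::clone(&v));
        v
    }

    // Copy checkpoint values in place into the variables registered by name.
    fn load(&self, checkpoint: &HashMap<String, Vec<f32>>) {
        for (name, values) in checkpoint {
            if let Some(var) = self.vars.get(name) {
                let mut dst = var.borrow_mut();
                if dst.len() == values.len() {
                    dst.copy_from_slice(values);
                }
            }
        }
    }
}

// Concatenation copies the current values into new, unregistered storage.
fn cat(parts: &[&Var]) -> Vec<f32> {
    let mut out = Vec::new();
    for p in parts {
        out.extend_from_slice(&p.borrow());
    }
    out
}

// Made-up checkpoint contents, for illustration only.
fn checkpoint() -> HashMap<String, Vec<f32>> {
    let mut ckpt = HashMap::new();
    ckpt.insert("q_proj.weight".to_string(), vec![1.0, 1.0]);
    ckpt.insert("k_proj.weight".to_string(), vec![2.0, 2.0]);
    ckpt.insert("v_proj.weight".to_string(), vec![3.0, 3.0]);
    ckpt
}

// Same order as my code above: concatenate first, load afterwards.
fn cat_then_load() -> Vec<f32> {
    let mut vs = ToyStore::default();
    let q = vs.var("q_proj.weight", 2);
    let k = vs.var("k_proj.weight", 2);
    let v = vs.var("v_proj.weight", 2);
    let qkv = cat(&[&q, &k, &v]);
    vs.load(&checkpoint());
    qkv // still the zeros copied before the load
}

// Workaround order: load first, concatenate afterwards.
fn load_then_cat() -> Vec<f32> {
    let mut vs = ToyStore::default();
    let q = vs.var("q_proj.weight", 2);
    let k = vs.var("k_proj.weight", 2);
    let v = vs.var("v_proj.weight", 2);
    vs.load(&checkpoint());
    cat(&[&q, &k, &v])
}

fn main() {
    println!("cat then load: {:?}", cat_then_load());
    println!("load then cat: {:?}", load_then_cat());
}
```

In this toy model the first order yields six zeros and the second yields the checkpoint values, which matches what I see with qkv_weight. If the real VarStore works the same way, the concatenation has to happen after vs.load (or inside the forward pass on the three separate variables) for the loaded values to show up.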

So, to sum it up, is there any way to load the weights from VarStore even after applying ops on the variables,
or is there any better way to load the variables from VarStore?
