common.hpp: 16 additions & 0 deletions

@@ -410,6 +410,22 @@ class SpatialTransformer : public GGMLBlock {
     int64_t context_dim = 768;  // hidden_size, 1024 for VERSION_SD2
     bool use_linear     = false;

+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
+        auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
+        if (iter != tensor_storage_map.end()) {
+            int64_t inner_dim = n_head * d_head;
+            if (iter->second.n_dims == 4 && use_linear) {
+                use_linear         = false;
+                blocks["proj_in"]  = std::make_shared<Conv2d>(in_channels, inner_dim, std::pair{1, 1});
+                blocks["proj_out"] = std::make_shared<Conv2d>(inner_dim, in_channels, std::pair{1, 1});
+            } else if (iter->second.n_dims == 2 && !use_linear) {
+                use_linear         = true;
+                blocks["proj_in"]  = std::make_shared<Linear>(in_channels, inner_dim);
+                blocks["proj_out"] = std::make_shared<Linear>(inner_dim, in_channels);
+            }
+        }
+    }
+
 public:
     SpatialTransformer(int64_t in_channels,
                        int64_t n_head,
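Context on the rank check (a note, not part of the diff): a 1x1 Conv2d weight is stored as a 4-D tensor of shape [out_channels, in_channels, 1, 1], while a Linear weight is 2-D, [out_features, in_features]. Inspecting n_dims of the stored proj_out.weight therefore reveals which layer type the checkpoint was exported with, regardless of the use_linear value the model version would normally imply. A minimal sketch of the idea, with TensorStorage as a hypothetical stand-in for the entry type behind String2TensorStorage:

// Illustrative only: how the weight's rank identifies the exported layer
// type. A 1x1 Conv2d weight is 4-D ([out, in, 1, 1]); a Linear weight is
// 2-D ([out, in]).
struct TensorStorage {
    int n_dims;  // hypothetical field mirroring iter->second.n_dims above
};

inline bool checkpoint_uses_linear(const TensorStorage& ts) {
    return ts.n_dims == 2;  // 2-D -> Linear; 4-D -> 1x1 Conv2d
}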
ggml_extend.hpp: 1 addition & 1 deletion

@@ -1926,8 +1926,8 @@ class GGMLBlock {
         if (prefix.size() > 0) {
             prefix = prefix + ".";
         }
-        init_blocks(ctx, tensor_storage_map, prefix);
         init_params(ctx, tensor_storage_map, prefix);
+        init_blocks(ctx, tensor_storage_map, prefix);
     }

     size_t get_params_num() {
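Why the order swap above matters: the init_params hooks added in common.hpp and vae.hpp rewrite entries of `blocks`, so they have to run before init_blocks recursively initializes the children; with the old order, parameters would already have been created for the wrong layer type. A minimal, illustrative sketch of the pattern (assuming init_params is a virtual, no-op-by-default hook on the base block, which is what the overrides above suggest):

// Sketch, not the project's code: if a subclass hook may replace entries
// in `blocks`, the hook must run before the children are initialized.
#include <map>
#include <memory>
#include <string>

struct Block {
    std::map<std::string, std::shared_ptr<Block>> blocks;

    virtual void init_params() {}  // hook: may rewrite `blocks`
    void init_blocks() {
        for (auto& [name, child] : blocks) child->init();
    }
    void init() {
        init_params();  // swap block types first...
        init_blocks();  // ...then initialize the final set of children
    }
    virtual ~Block() = default;
};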
vae.hpp: 19 additions & 0 deletions

@@ -66,6 +66,25 @@ class AttnBlock : public UnaryBlock {
     int64_t in_channels;
     bool use_linear;

+    void init_params(struct ggml_context* ctx, const String2TensorStorage& tensor_storage_map = {}, const std::string prefix = "") {
+        auto iter = tensor_storage_map.find(prefix + "proj_out.weight");
+        if (iter != tensor_storage_map.end()) {
+            if (iter->second.n_dims == 4 && use_linear) {
+                use_linear         = false;
+                blocks["q"]        = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
+                blocks["k"]        = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
+                blocks["v"]        = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
+                blocks["proj_out"] = std::make_shared<Conv2d>(in_channels, in_channels, std::pair{1, 1});
+            } else if (iter->second.n_dims == 2 && !use_linear) {
+                use_linear         = true;
+                blocks["q"]        = std::make_shared<Linear>(in_channels, in_channels);
+                blocks["k"]        = std::make_shared<Linear>(in_channels, in_channels);
+                blocks["v"]        = std::make_shared<Linear>(in_channels, in_channels);
+                blocks["proj_out"] = std::make_shared<Linear>(in_channels, in_channels);
+            }
+        }
+    }
+
 public:
     AttnBlock(int64_t in_channels, bool use_linear)
         : in_channels(in_channels), use_linear(use_linear) {
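Background on why both branches are needed (not from the diff itself, but consistent with it): CompVis-style checkpoints export the VAE attention projections q/k/v/proj_out as 1x1 Conv2d layers (4-D weights), while diffusers-style checkpoints export them as Linear layers (2-D weights), so a loader hard-wired to one layout fails with a shape mismatch on the other. Both hunks apply the same remedy; a hypothetical helper capturing the shared pattern could look like this (names are illustrative; the PR inlines the logic per class):

// Illustrative only: pick the projection block type from the rank of the
// weight found in the checkpoint. 4-D -> 1x1 Conv2d, 2-D -> Linear.
static std::shared_ptr<GGMLBlock> make_proj_block(int64_t in_dim,
                                                  int64_t out_dim,
                                                  int checkpoint_n_dims) {
    if (checkpoint_n_dims == 4) {
        return std::make_shared<Conv2d>(in_dim, out_dim, std::pair{1, 1});
    }
    return std::make_shared<Linear>(in_dim, out_dim);
}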