Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 32 additions & 32 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1892,24 +1892,25 @@ SDVersion ModelLoader::get_sd_version() {
return VERSION_COUNT;
}

ggml_type ModelLoader::get_sd_wtype() {
std::map<ggml_type, uint32_t> ModelLoader::get_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
}

if (ggml_is_quantized(tensor_storage.type)) {
return tensor_storage.type;
}

if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
return tensor_storage.type;
auto iter = wtype_stat.find(tensor_storage.type);
if (iter != wtype_stat.end()) {
iter->second++;
} else {
wtype_stat[tensor_storage.type] = 1;
}
}
return GGML_TYPE_COUNT;
return wtype_stat;
}

ggml_type ModelLoader::get_conditioner_wtype() {
std::map<ggml_type, uint32_t> ModelLoader::get_conditioner_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
Expand All @@ -1922,18 +1923,18 @@ ggml_type ModelLoader::get_conditioner_wtype() {
continue;
}

if (ggml_is_quantized(tensor_storage.type)) {
return tensor_storage.type;
}

if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
return tensor_storage.type;
auto iter = wtype_stat.find(tensor_storage.type);
if (iter != wtype_stat.end()) {
iter->second++;
} else {
wtype_stat[tensor_storage.type] = 1;
}
}
return GGML_TYPE_COUNT;
return wtype_stat;
}

ggml_type ModelLoader::get_diffusion_model_wtype() {
std::map<ggml_type, uint32_t> ModelLoader::get_diffusion_model_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
Expand All @@ -1943,18 +1944,18 @@ ggml_type ModelLoader::get_diffusion_model_wtype() {
continue;
}

if (ggml_is_quantized(tensor_storage.type)) {
return tensor_storage.type;
}

if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
return tensor_storage.type;
auto iter = wtype_stat.find(tensor_storage.type);
if (iter != wtype_stat.end()) {
iter->second++;
} else {
wtype_stat[tensor_storage.type] = 1;
}
}
return GGML_TYPE_COUNT;
return wtype_stat;
}

ggml_type ModelLoader::get_vae_wtype() {
std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
std::map<ggml_type, uint32_t> wtype_stat;
for (auto& tensor_storage : tensor_storages) {
if (is_unused_tensor(tensor_storage.name)) {
continue;
Expand All @@ -1965,15 +1966,14 @@ ggml_type ModelLoader::get_vae_wtype() {
continue;
}

if (ggml_is_quantized(tensor_storage.type)) {
return tensor_storage.type;
}

if (tensor_should_be_converted(tensor_storage, GGML_TYPE_Q4_K)) {
return tensor_storage.type;
auto iter = wtype_stat.find(tensor_storage.type);
if (iter != wtype_stat.end()) {
iter->second++;
} else {
wtype_stat[tensor_storage.type] = 1;
}
}
return GGML_TYPE_COUNT;
return wtype_stat;
}

void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
Expand Down
8 changes: 4 additions & 4 deletions model.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,10 @@ class ModelLoader {
bool init_from_file(const std::string& file_path, const std::string& prefix = "");
bool model_is_unet();
SDVersion get_sd_version();
ggml_type get_sd_wtype();
ggml_type get_conditioner_wtype();
ggml_type get_diffusion_model_wtype();
ggml_type get_vae_wtype();
std::map<ggml_type, uint32_t> get_wtype_stat();
std::map<ggml_type, uint32_t> get_conditioner_wtype_stat();
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
void set_wtype_override(ggml_type wtype, std::string prefix = "");
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
Expand Down
59 changes: 24 additions & 35 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,6 @@ class StableDiffusionGGML {
ggml_backend_t clip_backend = NULL;
ggml_backend_t control_net_backend = NULL;
ggml_backend_t vae_backend = NULL;
ggml_type model_wtype = GGML_TYPE_COUNT;
ggml_type conditioner_wtype = GGML_TYPE_COUNT;
ggml_type diffusion_model_wtype = GGML_TYPE_COUNT;
ggml_type vae_wtype = GGML_TYPE_COUNT;

SDVersion version;
bool vae_decode_only = false;
Expand Down Expand Up @@ -294,37 +290,33 @@ class StableDiffusionGGML {
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
? (ggml_type)sd_ctx_params->wtype
: GGML_TYPE_COUNT;
if (wtype == GGML_TYPE_COUNT) {
model_wtype = model_loader.get_sd_wtype();
if (model_wtype == GGML_TYPE_COUNT) {
model_wtype = GGML_TYPE_F32;
LOG_WARN("can not get mode wtype frome weight, use f32");
}
conditioner_wtype = model_loader.get_conditioner_wtype();
if (conditioner_wtype == GGML_TYPE_COUNT) {
conditioner_wtype = wtype;
}
diffusion_model_wtype = model_loader.get_diffusion_model_wtype();
if (diffusion_model_wtype == GGML_TYPE_COUNT) {
diffusion_model_wtype = wtype;
}
vae_wtype = model_loader.get_vae_wtype();

if (vae_wtype == GGML_TYPE_COUNT) {
vae_wtype = wtype;
}
} else {
model_wtype = wtype;
conditioner_wtype = wtype;
diffusion_model_wtype = wtype;
vae_wtype = wtype;
if (wtype != GGML_TYPE_COUNT) {
model_loader.set_wtype_override(wtype);
}

LOG_INFO("Weight type: %s", ggml_type_name(model_wtype));
LOG_INFO("Conditioner weight type: %s", ggml_type_name(conditioner_wtype));
LOG_INFO("Diffusion model weight type: %s", ggml_type_name(diffusion_model_wtype));
LOG_INFO("VAE weight type: %s", ggml_type_name(vae_wtype));
std::map<ggml_type, uint32_t> wtype_stat = model_loader.get_wtype_stat();
std::map<ggml_type, uint32_t> conditioner_wtype_stat = model_loader.get_conditioner_wtype_stat();
std::map<ggml_type, uint32_t> diffusion_model_wtype_stat = model_loader.get_diffusion_model_wtype_stat();
std::map<ggml_type, uint32_t> vae_wtype_stat = model_loader.get_vae_wtype_stat();

auto wtype_stat_to_str = [](const std::map<ggml_type, uint32_t>& m, int key_width = 8, int value_width = 5) -> std::string {
std::ostringstream oss;
bool first = true;
for (const auto& [type, count] : m) {
if (!first)
oss << "|";
first = false;
oss << std::right << std::setw(key_width) << ggml_type_name(type)
<< ": "
<< std::left << std::setw(value_width) << count;
}
return oss.str();
};

LOG_INFO("Weight type stat: %s", wtype_stat_to_str(wtype_stat).c_str());
LOG_INFO("Conditioner weight type stat: %s", wtype_stat_to_str(conditioner_wtype_stat).c_str());
LOG_INFO("Diffusion model weight type stat: %s", wtype_stat_to_str(diffusion_model_wtype_stat).c_str());
LOG_INFO("VAE weight type stat: %s", wtype_stat_to_str(vae_wtype_stat).c_str());

LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor));

Expand Down Expand Up @@ -938,9 +930,6 @@ class StableDiffusionGGML {
}

void apply_loras(const std::unordered_map<std::string, float>& lora_state) {
if (lora_state.size() > 0 && model_wtype != GGML_TYPE_F16 && model_wtype != GGML_TYPE_F32) {
LOG_WARN("In quantized models when applying LoRA, the images have poor quality.");
}
std::unordered_map<std::string, float> lora_state_diff;
for (auto& kv : lora_state) {
const std::string& lora_name = kv.first;
Expand Down
Loading