Merged
lora.hpp: 61 changes (37 additions, 24 deletions)
@@ -1,6 +1,7 @@
 #ifndef __LORA_HPP__
 #define __LORA_HPP__
 
+#include <mutex>
 #include "ggml_extend.hpp"
 
 #define LORA_GRAPH_BASE_SIZE 10240
@@ -115,49 +116,61 @@ struct LoraModel : public GGMLRunner {
         return "lora";
     }
 
-    bool load_from_file(bool filter_tensor = false) {
+    bool load_from_file(bool filter_tensor = false, int n_threads = 0) {
         LOG_INFO("loading LoRA from '%s'", file_path.c_str());
 
         if (load_failed) {
             LOG_ERROR("init lora model loader from file failed: '%s'", file_path.c_str());
             return false;
         }
 
+        std::unordered_map<std::string, TensorStorage> tensors_to_create;
+        std::mutex lora_mutex;
         bool dry_run          = true;
         auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
-            const std::string& name = tensor_storage.name;
-
-            if (filter_tensor && !contains(name, "lora")) {
-                // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str());
-                return true;
-            }
-            // LOG_INFO("lora_tensor %s", name.c_str());
-            for (int i = 0; i < LORA_TYPE_COUNT; i++) {
-                if (name.find(type_fingerprints[i]) != std::string::npos) {
-                    type = (lora_t)i;
-                    break;
-                }
-            }
-
             if (dry_run) {
-                struct ggml_tensor* real = ggml_new_tensor(params_ctx,
-                                                           tensor_storage.type,
-                                                           tensor_storage.n_dims,
-                                                           tensor_storage.ne);
-                lora_tensors[name] = real;
+                const std::string& name = tensor_storage.name;
+
+                if (filter_tensor && !contains(name, "lora")) {
+                    return true;
+                }
+                {
+                    std::lock_guard<std::mutex> lock(lora_mutex);
+                    for (int i = 0; i < LORA_TYPE_COUNT; i++) {
+                        if (name.find(type_fingerprints[i]) != std::string::npos) {
+                            type = (lora_t)i;
+                            break;
+                        }
+                    }
+                    tensors_to_create[name] = tensor_storage;
+                }
             } else {
-                auto real   = lora_tensors[name];
-                *dst_tensor = real;
+                const std::string& name = tensor_storage.name;
+                auto iter               = lora_tensors.find(name);
+                if (iter != lora_tensors.end()) {
+                    *dst_tensor = iter->second;
+                }
             }
 
             return true;
         };
 
-        model_loader.load_tensors(on_new_tensor_cb);
+        model_loader.load_tensors(on_new_tensor_cb, n_threads);
+
+        for (const auto& pair : tensors_to_create) {
+            const auto& name         = pair.first;
+            const auto& ts           = pair.second;
+            struct ggml_tensor* real = ggml_new_tensor(params_ctx,
+                                                       ts.type,
+                                                       ts.n_dims,
+                                                       ts.ne);
+            lora_tensors[name] = real;
+        }
 
         alloc_params_buffer();
         // exit(0);
+
         dry_run = false;
-        model_loader.load_tensors(on_new_tensor_cb);
+        model_loader.load_tensors(on_new_tensor_cb, n_threads);
 
         LOG_DEBUG("lora type: \"%s\"/\"%s\"", lora_downs[type].c_str(), lora_ups[type].c_str());
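The change makes `LoraModel::load_from_file` safe to drive from a multithreaded tensor loader: the new `n_threads` parameter is forwarded to `model_loader.load_tensors`, so the callback may now fire from several threads at once. During the dry run the callback therefore only records each tensor's `TensorStorage` under a mutex, and the actual `ggml_new_tensor` calls are deferred to a single-threaded loop between the two passes, presumably because allocating in a shared ggml context is not guaranteed to be thread-safe. The second pass then resolves names with `find` rather than `operator[]`, so a lookup can never insert into `lora_tensors` concurrently. Below is a minimal, self-contained sketch of that two-pass pattern; every name in it (`TensorMeta`, `load_tensors_mt`, the `int` tensor handles) is an illustrative stand-in, not the stable-diffusion.cpp API.

```cpp
// Illustrative sketch only: TensorMeta, load_tensors_mt, and the int
// "handles" are hypothetical stand-ins, not the stable-diffusion.cpp API.
#include <atomic>
#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>

struct TensorMeta {
    std::string name;
};

// Stand-in for model_loader.load_tensors(cb, n_threads): invokes the
// callback once per tensor, from n_threads worker threads.
void load_tensors_mt(const std::vector<TensorMeta>& metas,
                     const std::function<void(const TensorMeta&)>& cb,
                     int n_threads) {
    std::atomic<size_t> next{0};
    std::vector<std::thread> workers;
    for (int t = 0; t < n_threads; t++) {
        workers.emplace_back([&]() {
            for (size_t i = next++; i < metas.size(); i = next++) {
                cb(metas[i]);
            }
        });
    }
    for (auto& w : workers) {
        w.join();
    }
}

int main() {
    std::vector<TensorMeta> metas = {{"lora_up"}, {"lora_down"}, {"alpha"}};

    std::unordered_map<std::string, TensorMeta> tensors_to_create;  // filled during the dry run
    std::unordered_map<std::string, int> tensors;                   // int as a stand-in for ggml_tensor*
    std::mutex mtx;
    bool dry_run = true;

    auto on_tensor = [&](const TensorMeta& m) {
        if (dry_run) {
            // Pass 1 runs concurrently, so the shared map is only
            // touched while holding the mutex.
            std::lock_guard<std::mutex> lock(mtx);
            tensors_to_create[m.name] = m;
        } else {
            // Pass 2 only reads a map that is no longer mutated, so
            // concurrent find() calls are safe without a lock.
            auto it = tensors.find(m.name);
            if (it != tensors.end()) {
                std::cout << "resolved " << m.name << " -> " << it->second << "\n";
            }
        }
    };

    load_tensors_mt(metas, on_tensor, 4);  // dry run: collect metadata only

    // Tensor creation stays single-threaded, mirroring the for-loop the
    // diff adds between the two load_tensors() calls.
    int next_handle = 0;
    for (const auto& entry : tensors_to_create) {
        tensors[entry.first] = next_handle++;
    }

    dry_run = false;
    load_tensors_mt(metas, on_tensor, 4);  // real pass: resolve by name
}
```

The design point the diff illustrates is that only the metadata-gathering step needs synchronization; once creation happens on one thread and the map stops changing, the second multithreaded pass can read it lock-free.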