diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 7984834188c..01ec22aa3cc 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2331,32 +2331,177 @@ def set_gguf_parameters(self): class SNACDecModel(Model): model_arch = gguf.MODEL_ARCH.SNAC_DEC - def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, Tensor]]: - del bid # unused + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._dummy_added = False + + def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, torch.Tensor]]: + """Convert nested PyTorch tensor names to a flat GGUF naming scheme for decoder tensors.""" + del bid # Unused + + # Add dummy token_embd.weight only once + if not self._dummy_added: + import torch + dummy_tok_embd = torch.zeros((4096, 8), dtype=torch.float16) + dummy_tok_embd = dummy_tok_embd.view(4096, 8) + logger.info(f"Adding dummy tensor: token_embd.weight, shape: {list(dummy_tok_embd.shape)}") + yield ("token_embd.weight", dummy_tok_embd) + self._dummy_added = True # Mark as added + + original_name = name + + if name.startswith("quantizer.quantizers."): + match = re.match(r"quantizer\.quantizers\.(\d+)\.(codebook\.weight|out_proj\.bias|out_proj\.parametrizations\.weight\.original[0-1])", name) + if match: + q_idx = int(match.group(1)) + tensor_type = match.group(2) + if tensor_type == "codebook.weight": + new_name = f"quantizer.{q_idx}.codebook" + elif tensor_type == "out_proj.parametrizations.weight.original0": + new_name = f"quantizer.{q_idx}.out_proj.scale" + elif tensor_type == "out_proj.parametrizations.weight.original1": + new_name = f"quantizer.{q_idx}.out_proj.weight" + elif tensor_type == "out_proj.bias": + new_name = f"quantizer.{q_idx}.out_proj.bias" + + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + else: + logger.warning(f"Could not parse quantizer tensor from: {original_name}") + return - logger.debug(f"Processing tensor: {name}") + # Skip non-decoder tensors (except quantizers, which were handled above) + if not name.startswith("decoder."): + logger.debug(f"Skipping non-decoder tensor: {original_name}") + return - if (name.startswith("decoder.") or - re.match(r"quantizer\.quantizers\.\d+\.codebook\.weight", name) or - re.match(r"quantizer\.quantizers\.\d+\.out_proj\..*", name)): - logger.info(f"{name} -> {data_torch.shape}") - return [(name, data_torch)] - else: - logger.debug(f"Skipping {name!r}") - return [] + base = name[8:] # Remove 'decoder.' + parts = base.split(".") + + if base.startswith("model.0."): + logger.info(f"Skipping incompatible decoder layer 0 tensor: {original_name}") + return # Explicitly skip this layer + + # Layer 1: Second Conv + if base.startswith("model.1."): + if "bias" in name and "parametrizations" not in name: + new_name = "decoder.1.conv2.bias" + elif "parametrizations.weight.original0" in name: + new_name = "decoder.1.conv2.scale" + elif "parametrizations.weight.original1" in name: + new_name = "decoder.1.conv2.weight" + else: + logger.warning(f"Unhandled layer 1 tensor: {original_name}") + return + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + # Layers 2–5: DecoderBlocks + if "model." 
in base and "block" in base: + try: + layer_idx = int(parts[1]) # e.g., '2' from 'model.2' + if layer_idx not in {2, 3, 4, 5}: + logger.debug(f"Skipping non-DecoderBlock layer {layer_idx}: {original_name}") + return + block_idx = int(parts[3]) # e.g., '1' from 'block.1' + new_base = f"decoder.{layer_idx}.block.{block_idx}" + + if block_idx == 0: # Snake1d + if "alpha" in name: + new_name = f"{new_base}.alpha" + else: + logger.error(f"Expected 'alpha' in {original_name}") + return + elif block_idx == 1: # Transpose Conv + if "bias" in name and "parametrizations" not in name: + new_name = f"{new_base}.trans.bias" + elif "parametrizations.weight.original0" in name: + new_name = f"{new_base}.trans.scale" + elif "parametrizations.weight.original1" in name: + new_name = f"{new_base}.trans.weight" + else: + logger.error(f"Unhandled tensor in block 1: {original_name}") + return + elif block_idx == 2: # Noise Block + if "linear.parametrizations.weight.original0" in name: + new_name = f"{new_base}.noise.scale" + elif "linear.parametrizations.weight.original1" in name: + new_name = f"{new_base}.noise.weight" + else: + logger.error(f"Unhandled tensor in block 2: {original_name}") + return + elif block_idx in {3, 4, 5}: # Residual Units + res_base = f"{new_base}.res" + if "block.0.alpha" in name: + new_name = f"{res_base}.snake1.alpha" + elif "block.1.bias" in name: + new_name = f"{res_base}.conv1.bias" + elif "block.1.parametrizations.weight.original0" in name: + new_name = f"{res_base}.conv1.scale" + elif "block.1.parametrizations.weight.original1" in name: + new_name = f"{res_base}.conv1.weight" + elif "block.2.alpha" in name: + new_name = f"{res_base}.snake2.alpha" + elif "block.3.bias" in name: + new_name = f"{res_base}.conv2.bias" + elif "block.3.parametrizations.weight.original0" in name: + new_name = f"{res_base}.conv2.scale" + elif "block.3.parametrizations.weight.original1" in name: + new_name = f"{res_base}.conv2.weight" + else: + logger.error(f"Unhandled tensor in residual unit: {original_name}") + return + else: + logger.error(f"Unhandled block index {block_idx} in layer {layer_idx}: {original_name}") + return + + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + except (IndexError, ValueError) as e: + logger.error(f"Failed to parse tensor {original_name}: {e}") + return + + # Layer 6: Snake1d + if base == "model.6.alpha": + new_name = "decoder.6.alpha" + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + # Layer 7: Final Conv + if base.startswith("model.7."): + if "bias" in name and "parametrizations" not in name: + new_name = "decoder.7.conv.bias" + elif "parametrizations.weight.original0" in name: + new_name = "decoder.7.conv.scale" + elif "parametrizations.weight.original1" in name: + new_name = "decoder.7.conv.weight" + else: + logger.warning(f"Unhandled layer 7 tensor: {original_name}") + return + logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}") + yield (new_name, data_torch) + return + + logger.warning(f"Tensor {original_name} not mapped to any layer") + return def set_vocab(self): self._set_vocab_none() def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_vocab_size(self.hparams["codebook_size"]) - self.gguf_writer.add_quantizer_count(len(self.hparams["vq_strides"])) - self.gguf_writer.add_features_length(self.hparams["codebook_dim"]) - 
self.gguf_writer.add_quantizer_strides(self.hparams["vq_strides"])
-        self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"])
-        self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"])
-        self.gguf_writer.add_decoder_channel_dims(self.hparams["decoder_channel_dims"])
+        self.gguf_writer.add_vocab_size(4096)  # TODO: Fix
+        self.gguf_writer.add_uint32("snac.quantizer.codebook_size", self.hparams["codebook_size"])
+        self.gguf_writer.add_uint32("snac.quantizer.codebook_dim", self.hparams["codebook_dim"])
+        self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"])  # 1024
+        self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"])  # [8, 8, 4, 2]
+        self.gguf_writer.add_uint32("n_layers", 8)
+        self.gguf_writer.add_array("decoder_channel_dims", [768, 1024, 512, 256, 128, 64, 1])
+        self.gguf_writer.add_array("vq_strides", self.hparams["vq_strides"])
 
 
 @Model.register("Qwen2MoeForCausalLM")
 class Qwen2MoeModel(Model):
diff --git a/examples/tts/CMakeLists.txt b/examples/tts/CMakeLists.txt
index c72bd814c3b..42f95df7387 100644
--- a/examples/tts/CMakeLists.txt
+++ b/examples/tts/CMakeLists.txt
@@ -3,3 +3,9 @@ add_executable(${TARGET} tts.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
 target_compile_features(${TARGET} PRIVATE cxx_std_17)
+
+set(TARGET llama-orpheus-tts)
+add_executable(${TARGET} orpheus-tts.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/tts/orpheus-tts.cpp b/examples/tts/orpheus-tts.cpp
new file mode 100644
index 00000000000..622ec46fde0
--- /dev/null
+++ b/examples/tts/orpheus-tts.cpp
@@ -0,0 +1,344 @@
+#include "common.h"
+#include "llama.h"
+#include "llama-impl.h"
+#include "log.h"
+#include "arg.h"
+#include "sampling.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+std::vector<llama_token> redistribute_codes(const std::vector<llama_token> & raw_codes) {
+    std::vector<llama_token> snac_codes;
+    for (size_t i = 0; i < raw_codes.size(); i += 7) {
+        // Ensure we have a full frame (7 codes)
+        if (i + 6 >= raw_codes.size()) break;
+
+        // Frame offsets (per notebook)
+        snac_codes.push_back(raw_codes[i]);           // Codebook 0 (no offset)
+        snac_codes.push_back(raw_codes[i+1] - 4096);  // Codebook 1
+        snac_codes.push_back(raw_codes[i+2] - 8192);  // Codebook 2
+        snac_codes.push_back(raw_codes[i+3] - 12288); // Codebook 2
+        snac_codes.push_back(raw_codes[i+4] - 16384); // Codebook 1
+        snac_codes.push_back(raw_codes[i+5] - 20480); // Codebook 2
+        snac_codes.push_back(raw_codes[i+6] - 24576); // Codebook 2
+    }
+    return snac_codes;
+}
+
+static std::vector<float> embd_to_audio(
+        const float * embd,
+        const int n_codes,
+        const int n_embd,
+        const int n_thread);
+static bool save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate);
+static void fill_hann_window(int length, bool periodic, float * output);
+static void irfft(int n, const float * inp_cplx, float * out_real);
+static void fold(const std::vector<float> & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float> & output);
+
+static void print_usage(int /*argc*/, char **argv) {
+    LOG("\nexample usage:\n");
+    LOG("\n %s -m model.gguf -mv vocoder.gguf -p \"Hello world\"\n", argv[0]);
+    LOG("\n");
+}
+
+static void prompt_add(std::vector<llama_token> & prompt, const llama_vocab * vocab, const std::string & txt, bool add_special, bool parse_special) {
+    auto tmp = common_tokenize(vocab, txt, add_special, parse_special);
+    prompt.insert(prompt.end(), tmp.begin(), tmp.end());
+}
+
+
+// Include embd_to_audio and save_wav16 from tts.cpp (for now)
+static std::vector<float> embd_to_audio(
+        const float * embd,
+        const int n_codes,
+        const int n_embd,
+        const int n_thread) {
+    const int n_fft = 1280;
+    const int n_hop = 320;
+    const int n_win = 1280;
+    const int n_pad = (n_win - n_hop)/2;
+    const int n_out = (n_codes - 1)*n_hop + n_win;
+
+    std::vector<float> hann(n_fft);
+    fill_hann_window(hann.size(), true, hann.data());
+
+    int n_spec = n_embd*n_codes;
+
+    std::vector<float> E (n_spec);
+    std::vector<float> S (n_spec);
+    std::vector<float> ST(n_spec);
+
+    for (int l = 0; l < n_codes; ++l) {
+        for (int k = 0; k < n_embd; ++k) {
+            E[k*n_codes + l] = embd[l*n_embd + k];
+        }
+    }
+
+    for (int k = 0; k < n_embd/2; ++k) {
+        for (int l = 0; l < n_codes; ++l) {
+            float mag = E[(k           )*n_codes + l];
+            float phi = E[(k + n_embd/2)*n_codes + l];
+            mag = exp(mag);
+            if (mag > 1e2) {
+                mag = 1e2;
+            }
+            S[2*(k*n_codes + l) + 0] = mag*cosf(phi);
+            S[2*(k*n_codes + l) + 1] = mag*sinf(phi);
+        }
+    }
+
+    for (int l = 0; l < n_codes; ++l) {
+        for (int k = 0; k < n_embd/2; ++k) {
+            ST[l*n_embd + 2*k + 0] = S[2*(k*n_codes + l) + 0];
+            ST[l*n_embd + 2*k + 1] = S[2*(k*n_codes + l) + 1];
+        }
+    }
+
+    std::vector<float> res  (n_codes*n_fft);
+    std::vector<float> hann2(n_codes*n_fft);
+
+    std::vector<std::thread> workers(n_thread);
+    for (int i = 0; i < n_thread; ++i) {
+        workers[i] = std::thread([&, i]() {
+            for (int l = i; l < n_codes; l += n_thread) {
+                irfft(n_fft, ST.data() + l*n_embd, res.data() + l*n_fft);
+                for (int j = 0; j < n_fft; ++j) {
+                    res  [l*n_fft + j] *= hann[j];
+                    hann2[l*n_fft + j]  = hann[j] * hann[j];
+                }
+            }
+        });
+    }
+    for (int i = 0; i < n_thread; ++i) {
+        workers[i].join();
+    }
+
+    std::vector<float> audio;
+    std::vector<float> env;
+
+    fold(res,   n_out, n_win, n_hop, n_pad, audio);
+    fold(hann2, n_out, n_win, n_hop, n_pad, env);
+
+    for (size_t i = 0; i < audio.size(); ++i) {
+        audio[i] /= env[i];
+    }
+
+    return audio;
+}
+
+static bool save_wav16(const std::string & fname, const std::vector<float> & data, int sample_rate) {
+    std::ofstream file(fname, std::ios::binary);
+    if (!file) {
+        LOG_ERR("%s: Failed to open file '%s' for writing.\n", __func__, fname.c_str());
+        return false;
+    }
+
+    struct wav_header {
+        char riff[4] = {'R', 'I', 'F', 'F'};
+        uint32_t chunk_size;
+        char wave[4] = {'W', 'A', 'V', 'E'};
+        char fmt[4] = {'f', 'm', 't', ' '};
+        uint32_t fmt_chunk_size = 16;
+        uint16_t audio_format = 1; // PCM
+        uint16_t num_channels = 1; // Mono
+        uint32_t sample_rate;
+        uint32_t byte_rate;
+        uint16_t block_align;
+        uint16_t bits_per_sample = 16;
+        char data[4] = {'d', 'a', 't', 'a'};
+        uint32_t data_size;
+    } header;
+
+    header.sample_rate = sample_rate;
+    header.byte_rate = header.sample_rate * header.num_channels * (header.bits_per_sample / 8);
+    header.block_align = header.num_channels * (header.bits_per_sample / 8);
+    header.data_size = data.size() * (header.bits_per_sample / 8);
+    header.chunk_size = 36 + header.data_size;
+
+    file.write(reinterpret_cast<const char *>(&header), sizeof(header));
+
+    for (const auto & sample : data) {
+        int16_t pcm_sample = static_cast<int16_t>(std::clamp(sample * 32767.0, -32768.0, 32767.0));
+        file.write(reinterpret_cast<const char *>(&pcm_sample), sizeof(pcm_sample));
+    }
+
+    return file.good();
+}
+
+// Supporting functions from tts.cpp (for embd_to_audio)
+static void fill_hann_window(int length, bool periodic, float * output) {
+    int offset = -1;
+    if (periodic) {
+        offset = 0;
+    }
+    for (int i = 0; i < length; i++) {
+        output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset)));
+    }
+}
+
+static void twiddle(float * real, float * imag, int k, int N) {
+    float angle = 2 * M_PI * k / N;
+    *real = cos(angle);
+    *imag = sin(angle);
+}
+
+static void irfft(int n, const float * inp_cplx, float * out_real) {
+    int N = n / 2 + 1;
+
+    std::vector<float> real_input(N);
+    std::vector<float> imag_input(N);
+    for (int i = 0; i < N; ++i) {
+        real_input[i] = inp_cplx[2 * i];
+        imag_input[i] = inp_cplx[2 * i + 1];
+    }
+
+    std::vector<float> real_output(n);
+    std::vector<float> imag_output(n);
+
+    for (int k = 0; k < n; ++k) {
+        real_output[k] = 0.0f;
+        imag_output[k] = 0.0f;
+        for (int m = 0; m < N; ++m) {
+            float twiddle_real;
+            float twiddle_imag;
+
+            twiddle(&twiddle_real, &twiddle_imag, k * m, n);
+
+            real_output[k] += real_input[m] * twiddle_real - imag_input[m] * twiddle_imag;
+            imag_output[k] += real_input[m] * twiddle_imag + imag_input[m] * twiddle_real;
+        }
+    }
+
+    for (int i = 0; i < n; ++i) {
+        out_real[i] = real_output[i] / N;
+    }
+}
+
+static void fold(const std::vector<float> & data, int64_t n_out, int64_t n_win, int64_t n_hop, int64_t n_pad, std::vector<float> & output) {
+    int64_t output_height = n_out;
+    int64_t kernel_w      = n_win;
+    int64_t stride_w      = n_hop;
+    int64_t width         = n_out;
+
+    output.resize(width, 0.0f);
+
+    int64_t col_idx = 0;
+    for (int64_t w_col = 0; w_col < width; ++w_col) {
+        int64_t start = w_col * stride_w - n_pad;
+        int64_t end   = start + kernel_w;
+
+        for (int64_t w_im = start; w_im < end; ++w_im) {
+            if (w_im >= 0 && w_im < output_height && col_idx < (int64_t) data.size()) {
+                output[w_im] += data[col_idx];
+            }
+            col_idx++;
+        }
+    }
+
+    output.resize(n_out - 2 * n_pad);
+}
+
+int main(int argc, char **argv) {
+    common_params params;
+
+    params.model         = "models/orpheus-3b-0.1-ft-q4_k_m.gguf";
+    params.vocoder.model = "models/snac-vocab.gguf";
+    params.out_file      = "output.wav";
+
+    params.n_predict         = 1200;
+    params.sampling.top_k    = 4;
+    params.sampling.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
+    params.n_batch           = 4096;
+
+    common_init();
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    common_init_result orpheus_init_ttc = common_init_from_params(params);
+
+    llama_model   * model_ttc = NULL;
+    llama_context * ctx_ttc   = NULL;
+
+    model_ttc = orpheus_init_ttc.model.get();
+    ctx_ttc   = orpheus_init_ttc.context.get();
+
+    const llama_vocab * vocab = llama_model_get_vocab(model_ttc);
+
+    common_sampler * sampler = common_sampler_init(model_ttc, params.sampling);
+    if (!sampler) {
+        LOG_ERR("Failed to initialize sampler\n");
+        return 1;
+    }
+
+    // Construct prompt: <|startofhuman|> tara: [prompt] <|eot_id|> <|endofhuman|>
+    std::vector<llama_token> tokens;
+    tokens.push_back(128259);                              // <|startofhuman|>
+    prompt_add(tokens, vocab, "tara: ", false, true);      // Voice prefix
+    prompt_add(tokens, vocab, params.prompt, false, true); // User prompt
+    prompt_add(tokens, vocab, "", false, true);            // Emotion tag
+    tokens.push_back(128009);                              // <|eot_id|>
+    tokens.push_back(128260);                              // <|endofhuman|>
+
+
+    llama_model   * model_cts = NULL;
+    llama_context * ctx_cts   = NULL;
+
+    params.model   = params.vocoder.model;
+    params.n_batch = 2;
+
+    params.embedding = true;
+    // disable warmup, SNAC doesn't care about BOS or EOS tokens
+    params.warmup = false;
+
+    common_init_result snac_init_cts = common_init_from_params(params);
+    LOG_INF("SNAC model loaded: %s\n", params.model.c_str());
+
+    model_cts = snac_init_cts.model.get();
+    ctx_cts   = snac_init_cts.context.get();
+
+    std::vector<llama_token> speech_codes = {100, 4200, 8500, 12500, 16500, 21000, 25000,
+                                             200, 4300, 8600, 12600, 16600, 21111, 25100};
+
+    std::vector<llama_token> snac_codes = redistribute_codes(speech_codes);
+
+    const int n_codes    = speech_codes.size();
+    const int batch_size = n_codes;
+
+    llama_batch batch = llama_batch_init(batch_size, 0, 1);
+
+    for (int i = 0; i < n_codes; ++i) {
+        common_batch_add(batch, snac_codes[i], i, {0}, true);
+    }
+
+    LOG_INF("Batch before decode: n_tokens = %d\n", batch.n_tokens);
+    GGML_ASSERT(batch.n_tokens == n_codes);
+
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx_cts, batch) != 0) {
+        LOG_ERR("Failed to decode SNAC batch\n");
+        return 1;
+    }
+    llama_synchronize(ctx_cts);
+
+    LOG_INF("SNAC decode completed\n");
+
+    llama_batch_free(batch);
+    llama_backend_free();
+    return 0;
+}
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 15e86da8960..19fb2319f85 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -174,9 +174,7 @@ class ConvNext:
         BLOCK_COUNT = "{arch}.convnext.block_count"
 
     class AudioCodec:
-        QUANTIZER_COUNT        = "{arch}.audio_codec.quantizer_count"
-        CODEBOOK_DIM           = "{arch}.audio_codec.codebook_dim"
-        QUANTIZER_STRIDES      = "{arch}.audio_codec.quantizer_strides"
+        #CODEBOOK_DIM          = "{arch}.audio_codec.codebook_dim"
         DECODER_UPSAMPLE_RATES = "{arch}.audio_codec.decoder_upsample_rates"
         DECODER_CHANNEL_DIMS   = "{arch}.audio_codec.decoder_channel_dims"
 
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 9cf3a55702a..f1924358d9c 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -887,12 +887,6 @@ def add_remove_extra_whitespaces(self, value: bool) -> None:
     def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
         self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
 
-    def add_quantizer_count(self, count: int) -> None:
-        self.add_uint32(Keys.AudioCodec.QUANTIZER_COUNT.format(arch=self.arch), count)
-
-    def add_quantizer_strides(self, strides: Sequence[int]) -> None:
-        self.add_array(Keys.AudioCodec.QUANTIZER_STRIDES.format(arch=self.arch), strides)
-
     def add_decoder_upsample_rates(self, rates: Sequence[int]) -> None:
         self.add_array(Keys.AudioCodec.DECODER_UPSAMPLE_RATES.format(arch=self.arch), rates)
 
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 9debb56cc80..b90214d2555 100644
--- a/src/llama-arch.cpp
+++ b/src/llama-arch.cpp
@@ -65,6 +65,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_GRANITE_MOE,      "granitemoe"       },
     { LLM_ARCH_CHAMELEON,        "chameleon"        },
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
+    { LLM_ARCH_SNAC_DEC,         "snac-dec"         },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -1391,6 +1392,55 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
             { LLM_TENSOR_POS_NET_ATTN_OUT, "posnet.%d.attn_output" },
         },
     },
+    {
+        LLM_ARCH_SNAC_DEC,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_CODEBOOK,        "quantizer.%d.codebook" },
+            { LLM_TENSOR_CODEBOOK_PROJ_B, "quantizer.%d.out_proj.bias" },
+            { LLM_TENSOR_CODEBOOK_PROJ_S, "quantizer.%d.out_proj.scale" },
+            { LLM_TENSOR_CODEBOOK_PROJ_W, "quantizer.%d.out_proj.weight" },
+
+            { LLM_TENSOR_CONV_W2,         "decoder.1.conv2.weight" },
+            { LLM_TENSOR_CONV_S2,         "decoder.1.conv2.scale" },
+            { LLM_TENSOR_CONV_B2,         "decoder.1.conv2.bias" },
+            { LLM_TENSOR_BLOCK_ALPHA,     "decoder.%d.block.0.alpha" },
+            { LLM_TENSOR_TRANS_W,         "decoder.%d.block.1.trans.weight" },
+            { LLM_TENSOR_TRANS_S,         "decoder.%d.block.1.trans.scale" },
+            { LLM_TENSOR_TRANS_B,
"decoder.%d.block.1.trans.bias" }, + { LLM_TENSOR_NOISE_W, "decoder.%d.block.2.noise.weight" }, + { LLM_TENSOR_NOISE_S, "decoder.%d.block.2.noise.scale" }, + // Residual Units + { LLM_TENSOR_RES_SNAKE1_A, "decoder.%d.block.3.res.snake1.alpha" }, + { LLM_TENSOR_RES_CONV1_W, "decoder.%d.block.3.res.conv1.weight" }, + { LLM_TENSOR_RES_CONV1_S, "decoder.%d.block.3.res.conv1.scale" }, + { LLM_TENSOR_RES_CONV1_B, "decoder.%d.block.3.res.conv1.bias" }, + { LLM_TENSOR_RES_SNAKE2_A, "decoder.%d.block.3.res.snake2.alpha" }, + { LLM_TENSOR_RES_CONV2_W, "decoder.%d.block.3.res.conv2.weight" }, + { LLM_TENSOR_RES_CONV2_S, "decoder.%d.block.3.res.conv2.scale" }, + { LLM_TENSOR_RES_CONV2_B, "decoder.%d.block.3.res.conv2.bias" }, + { LLM_TENSOR_RES_SNAKE1_A_B4, "decoder.%d.block.4.res.snake1.alpha" }, + { LLM_TENSOR_RES_CONV1_W_B4, "decoder.%d.block.4.res.conv1.weight" }, + { LLM_TENSOR_RES_CONV1_S_B4, "decoder.%d.block.4.res.conv1.scale" }, + { LLM_TENSOR_RES_CONV1_B_B4, "decoder.%d.block.4.res.conv1.bias" }, + { LLM_TENSOR_RES_SNAKE2_A_B4, "decoder.%d.block.4.res.snake2.alpha" }, + { LLM_TENSOR_RES_CONV2_W_B4, "decoder.%d.block.4.res.conv2.weight" }, + { LLM_TENSOR_RES_CONV2_S_B4, "decoder.%d.block.4.res.conv2.scale" }, + { LLM_TENSOR_RES_CONV2_B_B4, "decoder.%d.block.4.res.conv2.bias" }, + { LLM_TENSOR_RES_SNAKE1_A_B5, "decoder.%d.block.5.res.snake1.alpha" }, + { LLM_TENSOR_RES_CONV1_W_B5, "decoder.%d.block.5.res.conv1.weight" }, + { LLM_TENSOR_RES_CONV1_S_B5, "decoder.%d.block.5.res.conv1.scale" }, + { LLM_TENSOR_RES_CONV1_B_B5, "decoder.%d.block.5.res.conv1.bias" }, + { LLM_TENSOR_RES_SNAKE2_A_B5, "decoder.%d.block.5.res.snake2.alpha" }, + { LLM_TENSOR_RES_CONV2_W_B5, "decoder.%d.block.5.res.conv2.weight" }, + { LLM_TENSOR_RES_CONV2_S_B5, "decoder.%d.block.5.res.conv2.scale" }, + { LLM_TENSOR_RES_CONV2_B_B5, "decoder.%d.block.5.res.conv2.bias" }, + { LLM_TENSOR_ALPHA, "decoder.6.alpha" }, + { LLM_TENSOR_CONV_W7, "decoder.7.conv.weight" }, + { LLM_TENSOR_CONV_S7, "decoder.7.conv.scale" }, + { LLM_TENSOR_CONV_B7, "decoder.7.conv.bias" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -1552,8 +1602,53 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CONVNEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CONVNEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + + { LLM_TENSOR_CONV_B2, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_CONV_S2, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CONV_W2, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_BLOCK_ALPHA, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_TRANS_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_TRANS_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_TRANS_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_NOISE_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_NOISE_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT } }, + { LLM_TENSOR_RES_SNAKE1_A, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV1_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE2_A, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV2_S, { LLM_TENSOR_LAYER_REPEATING, 
GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE1_A_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_B_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV1_S_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_W_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE2_A_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_B_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV2_S_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_W_B4, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE1_A_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_B_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV1_S_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV1_W_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_RES_SNAKE2_A_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_B_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_RES_CONV2_S_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_RES_CONV2_W_B5, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + { LLM_TENSOR_ALPHA, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CONV_B7, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_CONV_S7, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CONV_W7, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, + + { LLM_TENSOR_CODEBOOK, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS } }, + { LLM_TENSOR_CODEBOOK_PROJ_B, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD } }, + { LLM_TENSOR_CODEBOOK_PROJ_S, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL } }, + { LLM_TENSOR_CODEBOOK_PROJ_W, { LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL } }, }; + + LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} std::string LLM_KV::operator()(llm_kv kv) const { @@ -1563,6 +1658,7 @@ std::string LLM_KV::operator()(llm_kv kv) const { std::string LLM_TN_IMPL::str() const { if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + fprintf(stderr, "LLM_TN_IMPL::str: tensor enum %d not found in map for arch %d\n", (int)tensor, (int)arch); return "__missing__"; } diff --git a/src/llama-arch.h b/src/llama-arch.h index a28815d8a14..5d649d045cc 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -69,6 +69,7 @@ enum llm_arch { LLM_ARCH_GRANITE_MOE, LLM_ARCH_CHAMELEON, LLM_ARCH_WAVTOKENIZER_DEC, + LLM_ARCH_SNAC_DEC, LLM_ARCH_UNKNOWN, }; @@ -201,6 +202,11 @@ enum llm_kv { LLM_KV_CONVNEXT_EMBEDDING_LENGTH, LLM_KV_CONVNEXT_BLOCK_COUNT, + LLM_KV_QUANTIZER_COUNT, + LLM_KV_QUANTIZER_STRIDES, + LLM_KV_DECODER_UPSAMPLE_RATES, + LLM_KV_DECODER_CHANNEL_DIMS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, @@ -346,6 +352,52 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + + LLM_TENSOR_CONV_B1, + LLM_TENSOR_CONV_S1, + LLM_TENSOR_CONV_W1, + LLM_TENSOR_CONV_B2, + LLM_TENSOR_CONV_S2, + LLM_TENSOR_CONV_W2, + LLM_TENSOR_BLOCK_ALPHA, + LLM_TENSOR_TRANS_B, + LLM_TENSOR_TRANS_S, + LLM_TENSOR_TRANS_W, + LLM_TENSOR_NOISE_S, + LLM_TENSOR_NOISE_W, + LLM_TENSOR_RES_SNAKE1_A, + LLM_TENSOR_RES_CONV1_B, + LLM_TENSOR_RES_CONV1_S, + LLM_TENSOR_RES_CONV1_W, + LLM_TENSOR_RES_SNAKE2_A, + LLM_TENSOR_RES_CONV2_B, + 
LLM_TENSOR_RES_CONV2_S, + LLM_TENSOR_RES_CONV2_W, + LLM_TENSOR_RES_SNAKE1_A_B4, + LLM_TENSOR_RES_CONV1_B_B4, + LLM_TENSOR_RES_CONV1_S_B4, + LLM_TENSOR_RES_CONV1_W_B4, + LLM_TENSOR_RES_SNAKE2_A_B4, + LLM_TENSOR_RES_CONV2_B_B4, + LLM_TENSOR_RES_CONV2_S_B4, + LLM_TENSOR_RES_CONV2_W_B4, + LLM_TENSOR_RES_SNAKE1_A_B5, + LLM_TENSOR_RES_CONV1_B_B5, + LLM_TENSOR_RES_CONV1_S_B5, + LLM_TENSOR_RES_CONV1_W_B5, + LLM_TENSOR_RES_SNAKE2_A_B5, + LLM_TENSOR_RES_CONV2_B_B5, + LLM_TENSOR_RES_CONV2_S_B5, + LLM_TENSOR_RES_CONV2_W_B5, + LLM_TENSOR_ALPHA, + LLM_TENSOR_CONV_B7, + LLM_TENSOR_CONV_S7, + LLM_TENSOR_CONV_W7, + + LLM_TENSOR_CODEBOOK, + LLM_TENSOR_CODEBOOK_PROJ_B, + LLM_TENSOR_CODEBOOK_PROJ_S, + LLM_TENSOR_CODEBOOK_PROJ_W, }; enum llm_tensor_layer { diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5bec63e2e79..ca4adaa781c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -312,7 +312,9 @@ llama_context::llama_context( // reserve pp graph first so that buffers are only allocated once { + LLAMA_LOG_DEBUG("here 3\n"); llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; + auto * gf = graph_init(); graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT); if (!ggml_backend_sched_reserve(sched.get(), gf)) { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0bd40174438..f0a8b1071dc 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -985,6 +985,8 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { } } else { inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); + LLAMA_LOG_DEBUG("build_inp_embd: inp->embd shape = [%ld, %ld, %ld, %ld]\n", + inp->embd->ne[0], inp->embd->ne[1], inp->embd->ne[2], inp->embd->ne[3]); ggml_set_input(inp->embd); cur = inp->embd; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index bb17ba86dc2..ed0f9ab98d8 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -70,6 +70,10 @@ struct llama_hparams { float f_attn_logit_softcapping = 50.0f; float f_final_logit_softcapping = 30.0f; + // for SNAC vocoder + std::array upsample_rates; + std::array n_channels; + // for RWKV uint32_t rescale_every_n_layers = 0; uint32_t time_mix_extra_dim = 0; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp index 05d58ad90eb..4730b8c6a5d 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp @@ -745,6 +745,8 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri } } if (!is_ok) { + fprintf(stderr, "check_tensor_dims: name=%s, expected=%s, got=%s\n", + name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(cur).c_str()); throw std::runtime_error( format("%s: tensor '%s' has wrong shape; expected %s, got %s", __func__, name.c_str(), diff --git a/src/llama-model.cpp b/src/llama-model.cpp index cd7e0a0c4db..e711b286848 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1317,6 +1317,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); } break; + case LLM_ARCH_SNAC_DEC: + { + hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; // From decoder_channel_dims + hparams.upsample_rates = {8, 8, 4, 2}; + hparams.n_embd = 768; + hparams.n_layer = 8; + + // Dummy KV cache params to satisfy llama.cpp + for (uint32_t i = 0; i < 7; ++i) { // n_total_layers = 8 + hparams.n_head_arr[i] 
= 1; + hparams.n_head_kv_arr[i] = 1; + } + hparams.n_embd_head_k = 1; + hparams.n_embd_head_v = 1; + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -1473,13 +1488,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { ggml_backend_buffer_type_t first_moved_to_buft = nullptr; auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list & ne, int flags) -> ggml_tensor * { - ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); + std::string tn_str = tn.str(); + ggml_tensor * t_meta = ml.get_tensor_meta(tn_str.c_str()); if (!t_meta) { if (flags & TENSOR_NOT_REQUIRED) { return nullptr; } - throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); + return nullptr; + //throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); } // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops @@ -1574,6 +1591,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { return t; } } + fprintf(stderr, "create_tensor: Creating '%s' with ne=[%ld, %ld, %ld]\n", + tn_str.c_str(), ne.begin()[0], ne.begin()[1], ne.begin()[2]); return ml.create_tensor(ctx, tn, ne, flags); }; @@ -3686,7 +3705,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); } - // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); } @@ -3694,6 +3712,138 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); } break; + case LLM_ARCH_SNAC_DEC: + { + // TODO: Magic numbers everwhere + const int64_t n_total_layers = hparams.n_layer; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {8, 4096, 1}, 0); + + hparams.n_channels = {768, 1024, 512, 256, 128, 64, 1}; + + // Quantizer projection tensors (0, 1, 2) + for (int qid = 0; qid < 3; ++qid) { + fprintf(stderr, "%s: Loading quantizer %d tensors\n", __func__, qid); + // Bias: [768, 1, 1, 1] -> {768} + codebook_proj_b[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK_PROJ_B, qid, -1), {768, 1, 1}, 0); + // Scale: [1, 1, 768, 1] -> {768} + codebook_proj_s[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK_PROJ_S, qid, -1), {1, 1, 768}, 0); + // Weight: [1, 8, 768, 1] -> {8, 768} + codebook_proj_w[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK_PROJ_W, qid, -1), {1, 8, 768}, 0); + + codebook[qid] = create_tensor(tn(LLM_TENSOR_CODEBOOK, qid, -1), {8, 4096, 1, 1}, 0); + } + + // Decoder tensors + for (int i = 1; i < n_total_layers; ++i) { // Loop from i = 0 to 7 + auto & layer = layers[i]; + + // Calculate n_in and n_out for the current layer i + const int64_t n_in = (i == 0) ? 1 : ((i == 7) ? hparams.n_channels[i-2] /* 64 */ : hparams.n_channels[i-1]); + const int64_t n_out = (i == 7) ? 
hparams.n_channels[i-1] /* 1 */ : hparams.n_channels[i]; + + fprintf(stderr, "%s: Layer %d: Starting (n_in=%lld, n_out=%lld)\n", __func__, i, n_in, n_out); + + if (i == 1) { // --- Layer 1: Conv2 --- + layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV_W2, i, -1), {1, n_in, n_out}, 0); + layer.conv_s = create_tensor(tn(LLM_TENSOR_CONV_S2, i, -1), {1, 1, n_out}, 0); + layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV_B2, i, -1), {n_out}, 0); + } + else if (i >= 2 && i <= 5) { // --- Layers 2-5: Blocks --- + const int n_blocks = 6; + layer.decoder_blocks.resize(n_blocks); + + for (int bid = 0; bid < n_blocks; ++bid) { + LLAMA_LOG_DEBUG("%s: Layer %d, Block %d: Starting\n", __func__, i, bid); + + switch (bid) { + case 0: // Block 0: Alpha + layer.decoder_blocks[bid].alpha = create_tensor(tn(LLM_TENSOR_BLOCK_ALPHA, i, bid), {1, n_in, 1}, 0); + break; + case 1: // Block 1: Transition + { + int64_t trans_dim; + if (i == 2) trans_dim = 16; + else if (i == 3) trans_dim = 16; + else if (i == 4) trans_dim = 8; + else trans_dim = 4; // Assumed for i == 5 + LLAMA_LOG_DEBUG("%s: Layer %d, Block %d: Using trans_dim = %lld\n", __func__, i, bid, trans_dim); + layer.decoder_blocks[bid].up_weight = create_tensor(tn(LLM_TENSOR_TRANS_W, i, bid), {trans_dim, n_out, n_in}, 0); + layer.decoder_blocks[bid].up_scale = create_tensor(tn(LLM_TENSOR_TRANS_S, i, bid), {1, 1, n_in}, 0); + layer.decoder_blocks[bid].up_bias = create_tensor(tn(LLM_TENSOR_TRANS_B, i, bid), {n_out}, 0); + if (!layer.decoder_blocks[bid].up_bias) { + LLAMA_LOG_DEBUG("Failed to create decoder.%d.block.%d.trans.bias\n", i, bid); + } + } + break; + case 2: + { + LLAMA_LOG_DEBUG("%s: Layer %d, Block %d: Loading noise tensors\n", __func__, i, bid); + layer.decoder_blocks[bid].noise_w = create_tensor(tn(LLM_TENSOR_NOISE_W, i, bid), {1, n_out, n_out}, 0); + layer.decoder_blocks[bid].noise_s = create_tensor(tn(LLM_TENSOR_NOISE_S, i, bid), {1, 1, n_out}, 0); + } + break; + case 3: // Block 3: Residual Unit 1 + { + int res_unit_idx = 0; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; + res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A, i, bid), {1, n_out, 1}, 0); + res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W, i, bid), {7, 1, n_out}, 0); + res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S, i, bid), {1, 1, n_out}, 0); + res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B, i, bid), {n_out}, 0); + res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A, i, bid), {1, n_out, 1}, 0); + res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W, i, bid), {1, n_out, n_out}, 0); + res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S, i, bid), {1, 1, n_out}, 0); + res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B, i, bid), {n_out}, 0); + } + break; + case 4: // Block 4: Residual Unit 2 + { + int res_unit_idx = 1; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; + res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B4, i, bid), {1, n_out, 1}, 0); + res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B4, i, bid), {7, 1, n_out}, 0); + res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B4, i, bid), {1, 1, n_out}, 0); + res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B4, i, bid), {n_out}, 0); + res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B4, i, bid), {1, n_out, 1}, 0); + res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B4, i, bid), {1, n_out, n_out}, 0); + res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B4, 
i, bid), {1, 1, n_out}, 0); + res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B4, i, bid), {n_out}, 0); + } + break; + case 5: // Block 5: Residual Unit 3 + { + int res_unit_idx = 2; auto & res_unit = layer.decoder_blocks[bid].res_units[res_unit_idx]; + res_unit.alpha1 = create_tensor(tn(LLM_TENSOR_RES_SNAKE1_A_B5, i, bid), {1, n_out, 1}, 0); + res_unit.conv1_w = create_tensor(tn(LLM_TENSOR_RES_CONV1_W_B5, i, bid), {7, 1, n_out}, 0); + res_unit.conv1_s = create_tensor(tn(LLM_TENSOR_RES_CONV1_S_B5, i, bid), {1, 1, n_out}, 0); + res_unit.conv1_b = create_tensor(tn(LLM_TENSOR_RES_CONV1_B_B5, i, bid), {n_out}, 0); + res_unit.alpha2 = create_tensor(tn(LLM_TENSOR_RES_SNAKE2_A_B5, i, bid), {1, n_out, 1}, 0); + res_unit.conv2_w = create_tensor(tn(LLM_TENSOR_RES_CONV2_W_B5, i, bid), {1, n_out, n_out}, 0); + res_unit.conv2_s = create_tensor(tn(LLM_TENSOR_RES_CONV2_S_B5, i, bid), {1, 1, n_out}, 0); + res_unit.conv2_b = create_tensor(tn(LLM_TENSOR_RES_CONV2_B_B5, i, bid), {n_out}, 0); + } + break; + default: + fprintf(stderr, "%s: ERROR: Unexpected block id %d in layer %d\n", __func__, bid, i); + return false; // Or handle error appropriately + } + fprintf(stderr, "%s: Layer %d, Block %d: Finished\n", __func__, i, bid); + } // End block loop + } + else if (i == 6) { // --- Layer 6: Alpha --- + layer.alpha = create_tensor(tn(LLM_TENSOR_ALPHA, i, -1), {1, n_in, 1}, 0); + } + else if (i == 7) { // --- Layer 7: Final Conv --- + layer.conv_w = create_tensor(tn(LLM_TENSOR_CONV_W7, i, -1), {7, n_in, n_out}, 0); + layer.conv_s = create_tensor(tn(LLM_TENSOR_CONV_S7, i, -1), {1, 1, n_out}, 0); + layer.conv_b = create_tensor(tn(LLM_TENSOR_CONV_B7, i, -1), {n_out}, 0); + } + else { // Should not happen + fprintf(stderr, "%s: ERROR: Unexpected layer index %d\n", __func__, i); + return false; // Or handle error appropriately + } + fprintf(stderr, "%s: Layer %d: Finished\n", __func__, i); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -11597,6 +11747,288 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { } }; +// struct llm_build_snac_dec : public llm_graph_context { + +// llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { +// LLAMA_LOG_INFO("Raw ubatch.n_tokens = %d\n", ubatch.n_tokens); +// for (int i = 0; i < std::min(20, (int)ubatch.n_tokens); ++i) { +// LLAMA_LOG_INFO("%d ", ubatch.token[i]); +// } +// LLAMA_LOG("\n"); +// LLAMA_LOG_DEBUG("%s: Entering constructor, model.layers.size() = %zu\n", __func__, model.layers.size()); +// ggml_tensor * cur; +// ggml_tensor * inpL; + +// // TODO: probalby just get raw codes +// //cur = build_inp_embd(model.tok_embd); +// //LLAMA_LOG_INFO("After build_inp_embd: shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + +// // hack, hardcode expected SNAC input at first conv layer +// cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, 64, 1, 1); // [channels, seq_len, 1, 1] +// ggml_set_input(cur); +// LLAMA_LOG_INFO("hardcoded shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + +// // end hack + +// // Log input tokens before processing +// LLAMA_LOG_INFO("%s: ubatch.n_tokens = %u\n", __func__, ubatch.n_tokens); +// LLAMA_LOG_WARN("%s: Input tokens from ubatch = ", __func__); +// for (uint32_t i = 0; i < ubatch.n_tokens && i < 20; ++i) { +// LLAMA_LOG_INFO("%d ", ubatch.token[i]); +// } +// if (ubatch.n_tokens > 20) LLAMA_LOG_INFO("..."); +// LLAMA_LOG("\n"); + +// // ggml_tensor * 
layer_1; +// // ggml_tensor * layer_2; +// // ggml_tensor * layer_3; +// //redistribute_codes(cur, &layer_1, &layer_2, &layer_3); + +// // Log the redistributed layers +// //log_tensor("Layer 1", layer_1); +// //log_tensor("Layer 2", layer_2); +// //log_tensor("Layer 3", layer_3); + +// for (uint32_t il = 1; il < model.layers.size(); ++il) { +// const auto & layer = model.layers[il]; + +// LLAMA_LOG_DEBUG("%s: Layer %u: Starting, cur = %p\n", __func__, il, cur); + +// if (il == 1) { // pointwise +// LLAMA_LOG_INFO("%s: Layer %u: Pointwise conv, conv_w = %p, conv_s = %p, conv_b = %p\n", +// __func__, il, layer.conv_w, layer.conv_s, layer.conv_b); +// LLAMA_LOG_INFO("Before transpose, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); +// cur = ggml_transpose(ctx0, cur); // [768, 512] -> [512, 768] +// LLAMA_LOG_INFO("After transpose, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); +// cur = apply_conv1d(cur, layer.conv_w, layer.conv_s, layer.conv_b, 1, 0); +// LLAMA_LOG_INFO("%s: Layer %u: After pointwise conv, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// } else if (il == model.layers.size() - 1) { +// LLAMA_LOG_INFO("%s: Layer %u: Final layer, alpha = %p, conv_w = %p, conv_s = %p, conv_b = %p\n", +// __func__, il, layer.alpha, layer.conv_w, layer.conv_s, layer.conv_b); +// cur = ggml_snake(ctx0, cur, layer.alpha); +// LLAMA_LOG_INFO("%s: Layer %u: After ggml_snake, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// cur = apply_conv1d(cur, layer.conv_w, layer.conv_s, layer.conv_b, 1, 3); +// LLAMA_LOG_INFO("%s: Layer %u: After final conv, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// cur = ggml_tanh(ctx0, cur); +// LLAMA_LOG_INFO("%s: Layer %u: After ggml_tanh, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? 
cur->ne[3] : -1); +// } else { +// // Layers 2-5: Decoder Blocks (1024 -> 512 -> 256 -> 128 -> 64) +// const int stride = hparams.upsample_rates[il - 2]; // 8 for il = 2 +// const int padding = stride; + +// // Block 0: Snake activation +// const auto & block0 = layer.decoder_blocks[0]; +// LLAMA_LOG_DEBUG("%s: Layer %u: Block 0, alpha = %p\n", __func__, il, block0.alpha); +// cur = ggml_snake(ctx0, cur, block0.alpha); +// LLAMA_LOG_DEBUG("%s: Layer %u: After ggml_snake, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + +// // Block 1: Transposed convolution +// const auto & block1 = layer.decoder_blocks[1]; +// LLAMA_LOG_DEBUG("%s: Layer %u: Block 1, stride = %d, up_weight = %p, up_scale = %p, up_bias = %p\n", +// __func__, il, stride, block1.up_weight, block1.up_scale, block1.up_bias); + +// cur = apply_conv1d_transpose(cur, block1.up_weight, block1.up_scale, block1.up_bias, stride, padding); +// LLAMA_LOG_DEBUG("%s: Layer %u: After conv1d_transpose, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, cur, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + +// // Residual Units (3 per block) +// for (int j = 0; j < 3; ++j) { +// const auto & ru = block1.res_units[j]; +// ggml_tensor * inpL = cur; +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: Starting, inpL = %p, alpha1 = %p, conv1_w = %p, conv1_s = %p, conv1_b = %p\n", +// __func__, il, j, inpL, ru.alpha1, ru.conv1_w, ru.conv1_s, ru.conv1_b); + +// cur = ggml_snake(ctx0, cur, ru.alpha1); +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_snake (alpha1), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// int dilation = (j == 0) ? 1 : (j == 1) ? 3 : 9; +// int padding = 3 * dilation; // Kernel 7, dilated padding = (7-1)/2 * dilation +// cur = apply_conv1d(cur, ru.conv1_w, ru.conv1_s, ru.conv1_b, 1, padding); +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After conv1d (conv1), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); + +// // pw +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: Pointwise, alpha2 = %p, conv2_w = %p, conv2_s = %p, conv2_b = %p\n", +// __func__, il, j, ru.alpha2, ru.conv2_w, ru.conv2_s, ru.conv2_b); +// cur = ggml_snake(ctx0, cur, ru.alpha2); +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_snake (alpha2), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// cur = apply_conv1d(cur, ru.conv2_w, ru.conv2_s, ru.conv2_b, 1, 0); +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After conv1d (conv2), cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); + +// // residual +// cur = ggml_add(ctx0, cur, inpL); +// LLAMA_LOG_DEBUG("%s: Layer %u, ResUnit %d: After ggml_add, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, il, j, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? 
cur->ne[3] : -1); +// } +// } +// LLAMA_LOG_DEBUG("%s: Layer %u: Finished, cur = %p\n", __func__, il, cur); +// } + +// int64_t target_samples = 24000; // TODO: magic number +// LLAMA_LOG_DEBUG("%s: Trimming output, cur = %p, target_samples = %ld, cur->ne[0] = %ld\n", +// __func__, cur, target_samples, cur ? cur->ne[0] : -1); +// if (cur->ne[0] > target_samples) { +// cur = ggml_get_rows(ctx0, cur, ggml_new_i32(ctx0, target_samples)); +// LLAMA_LOG_DEBUG("%s: After ggml_get_rows, cur = %p, shape = [%ld, %ld, %ld, %ld]\n", +// __func__, cur, cur ? cur->ne[0] : -1, cur ? cur->ne[1] : -1, cur ? cur->ne[2] : -1, cur ? cur->ne[3] : -1); +// } + +// LLAMA_LOG_DEBUG("%s: Setting result_embd, cur = %p\n", __func__, cur); +// cb(cur, "result_embd", -1); +// res->t_embd = cur; + +// LLAMA_LOG_DEBUG("%s: Building forward graph, cur = %p\n", __func__, cur); +// ggml_build_forward_expand(gf, cur); +// LLAMA_LOG_DEBUG("%s: Graph build completed\n", __func__); +// } + +// // TODO: move these somewhere else +// private: +// // Helper to log tensor contents +// void log_tensor(const char * name, ggml_tensor * tensor) { +// if (!tensor) { +// LLAMA_LOG_INFO("%s: %s is null\n", __func__, name); +// return; +// } +// LLAMA_LOG_DEBUG("%s: %s shape = [%ld, %ld, %ld, %ld], first 20 elements = ", +// __func__, name, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); +// int n_elements = ggml_nelements(tensor); +// float * data = (float *)tensor->data; +// for (int i = 0; i < std::min(20, n_elements); ++i) { +// LLAMA_LOG_DEBUG("%.2f ", data[i]); +// } +// if (n_elements > 20) LLAMA_LOG_DEBUG("..."); +// LLAMA_LOG_DEBUG("\n"); +// } + +// void redistribute_codes(ggml_tensor * input, ggml_tensor ** layer_1, ggml_tensor ** layer_2, ggml_tensor ** layer_3) { +// int64_t n_codes = input->ne[1]; // Assuming input is [n_embd, n_tokens, 1, 1] +// int64_t n_frames = n_codes / 7; +// if (n_codes % 7 != 0) { +// LLAMA_LOG_ERROR("%s: Input codes length %ld is not a multiple of 7\n", __func__, n_codes); +// *layer_1 = *layer_2 = *layer_3 = nullptr; +// return; +// } + +// int64_t n_layer_1 = n_frames; // 1 code per frame +// int64_t n_layer_2 = n_frames * 2; // 2 codes per frame +// int64_t n_layer_3 = n_frames * 4; // 4 codes per frame + +// // Indices for each layer +// std::vector idx_layer_1(n_layer_1); +// std::vector idx_layer_2(n_layer_2); +// std::vector idx_layer_3(n_layer_3); + +// for (int64_t i = 0; i < n_frames; ++i) { +// int64_t base_idx = i * 7; +// idx_layer_1[i] = base_idx + 0; // No offset +// idx_layer_2[i * 2] = base_idx + 1; // Offset -4096 +// idx_layer_2[i * 2 + 1] = base_idx + 4; // Offset -16384 +// idx_layer_3[i * 4] = base_idx + 2; // Offset -8192 +// idx_layer_3[i * 4 + 1] = base_idx + 3; // Offset -12288 +// idx_layer_3[i * 4 + 2] = base_idx + 5; // Offset -20480 +// idx_layer_3[i * 4 + 3] = base_idx + 6; // Offset -24576 +// } + +// // Create index tensors +// ggml_tensor * idx_1 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_1); +// ggml_tensor * idx_2 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_2); +// ggml_tensor * idx_3 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_layer_3); + +// memcpy(idx_1->data, idx_layer_1.data(), n_layer_1 * sizeof(int32_t)); +// memcpy(idx_2->data, idx_layer_2.data(), n_layer_2 * sizeof(int32_t)); +// memcpy(idx_3->data, idx_layer_3.data(), n_layer_3 * sizeof(int32_t)); + +// // Extract layers using ggml_get_rows +// *layer_1 = ggml_get_rows(ctx0, input, idx_1); +// *layer_2 = ggml_get_rows(ctx0, input, idx_2); +// *layer_3 = ggml_get_rows(ctx0, 
input, idx_3); + +// // Apply offsets +// *layer_2 = ggml_add(ctx0, *layer_2, ggml_new_f32(ctx0, -4096.0f)); // Simplified; we'll refine offsets later +// *layer_3 = ggml_add(ctx0, *layer_3, ggml_new_f32(ctx0, -8192.0f)); // Simplified for now +// } + +// ggml_tensor * apply_conv1d(ggml_tensor * input, ggml_tensor * conv_w, ggml_tensor * conv_scale, ggml_tensor * conv_b, +// int stride, int padding) { +// ggml_tensor * w_final = normalize_weight(conv_w, conv_scale); +// ggml_tensor * cur = ggml_conv_1d_ph(ctx0, w_final, input, stride, padding); +// if (conv_b) { +// ggml_tensor* bias_reshaped = ggml_reshape_3d(ctx0, conv_b, 1, 1024, 1); +// cur = ggml_add(ctx0, cur, bias_reshaped); +// } +// return cur; +// } + +// ggml_tensor * apply_conv1d_transpose(ggml_tensor * input, ggml_tensor * up_weight, ggml_tensor * up_scale, ggml_tensor * up_bias, int stride, int padding) { +// // Normalize weights (temporary fix for up_scale shape mismatch) +// if (up_scale->ne[2] != up_weight->ne[1]) { // 1024 != 512 +// LLAMA_LOG_WARN("up_scale channels (%ld) don’t match output channels (%ld), expected behavior may vary\n", up_scale->ne[2], up_weight->ne[1]); +// // Ideally reshape up_scale to [1, 1, 512, 1], but no reshape; proceed with warning +// } +// ggml_tensor * w_final = normalize_weight(up_weight, up_scale); +// LLAMA_LOG_INFO("After normalize weight: w_final shape = [%ld, %ld, %ld, %ld]\n", +// w_final->ne[0], w_final->ne[1], w_final->ne[2], w_final->ne[3]); + +// ggml_tensor * cur = ggml_conv_transpose_1d(ctx0, w_final, input, stride, 0, 1); +// LLAMA_LOG_INFO("After ggml_conv_transpose_1d = [%ld, %ld, %ld, %ld]\n", +// cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + +// if (up_bias) { +// // up_bias is [512, 1, 1, 1]; need [4104, 512, 1, 1] for ggml_add +// LLAMA_LOG_INFO("entering up_bias block. Before ggml_repeat, cur shape = [%ld, %ld, %ld, %ld]\n", cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); +// LLAMA_LOG_INFO("Before ggml_repeat, up_bias shape = [%ld, %ld, %ld, %ld]\n", up_bias->ne[0], up_bias->ne[1], up_bias->ne[2], up_bias->ne[3]); +// ggml_tensor * bias_repeated = ggml_repeat(ctx0, up_bias, cur); +// LLAMA_LOG_DEBUG("Repeated up_bias to shape = [%ld, %ld, %ld, %ld]\n", +// bias_repeated->ne[0], bias_repeated->ne[1], bias_repeated->ne[2], bias_repeated->ne[3]); +// cur = ggml_add(ctx0, cur, bias_repeated); +// LLAMA_LOG_DEBUG("After bias add: cur shape = [%ld, %ld, %ld, %ld]\n", +// cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); +// } +// return cur; +// } + +// // w_final = scale * (w / || w ||) +// ggml_tensor * normalize_weight(ggml_tensor * w, ggml_tensor * scale) { +// ggml_tensor * norm = ggml_norm(ctx0, w, 1e-5f); // 1e-8f ? +// ggml_tensor * w_normalized = ggml_div(ctx0, w, norm); +// ggml_tensor * w_final = ggml_mul(ctx0, w_normalized, scale); +// return w_final; +// } +// }; + +// TODO: Placeholder +struct llm_build_snac_dec : public llm_graph_context { + + llm_build_snac_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + + // TODO: Remove + LLAMA_LOG_INFO("Raw ubatch.n_tokens = %d\n", ubatch.n_tokens); + for (int i = 0; i < std::min(20, (int)ubatch.n_tokens); ++i) { + LLAMA_LOG_INFO("%d ", ubatch.token[i]); + } + LLAMA_LOG("\n"); + ggml_tensor * cur; + + // TODO: Hack. 
Implement codebook lookups and out_proj + cur = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 768, 64, 1, 1); + ggml_set_input(cur); + // end hack + + LLAMA_LOG_DEBUG("%s: Setting result_embd, cur = %p\n", __func__, cur); + cb(cur, "result_embd", -1); + res->t_embd = cur; + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory() const { llama_memory_i * res; @@ -11868,6 +12300,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_SNAC_DEC: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } @@ -11976,6 +12412,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_SNAC_DEC: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values diff --git a/src/llama-model.h b/src/llama-model.h index a9da1215abb..5e636b0b3b3 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -137,6 +137,28 @@ struct llama_layer_convnext { struct ggml_tensor * gamma = nullptr; }; +struct llama_layer_snac_dec_block { + struct ggml_tensor * alpha = nullptr; + + struct ggml_tensor * up_weight = nullptr; + struct ggml_tensor * up_scale = nullptr; + struct ggml_tensor * up_bias = nullptr; + + struct ggml_tensor * noise_w = nullptr; + struct ggml_tensor * noise_s = nullptr; + + struct { + struct ggml_tensor * alpha1 = nullptr; + struct ggml_tensor * conv1_w = nullptr; + struct ggml_tensor * conv1_s = nullptr; + struct ggml_tensor * conv1_b = nullptr; + struct ggml_tensor * alpha2 = nullptr; + struct ggml_tensor * conv2_w = nullptr; + struct ggml_tensor * conv2_s = nullptr; + struct ggml_tensor * conv2_b = nullptr; + } res_units[3]; +}; + struct llama_layer { // normalization struct ggml_tensor * attn_norm = nullptr; @@ -304,6 +326,13 @@ struct llama_layer { struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; + + struct ggml_tensor * conv_w = nullptr; + struct ggml_tensor * conv_s = nullptr; + struct ggml_tensor * conv_b = nullptr; + struct ggml_tensor * alpha = nullptr; + + std::vector decoder_blocks; }; struct llama_model { @@ -336,6 +365,13 @@ struct llama_model { struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; + + // TODO: structify + ggml_tensor * codebook[3]; + ggml_tensor * codebook_proj_b[3]; // Array for quantizer 0, 1, 2 bias + ggml_tensor * codebook_proj_s[3]; // Array for quantizer 0, 1, 2 scale + ggml_tensor * codebook_proj_w[3]; + std::vector layers; llama_model_params params;
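
Reviewer note on the original0/original1 mapping used above (a sketch, not part of the patch): SNAC's convolutions are wrapped in PyTorch weight normalization, so the converter maps parametrizations.weight.original0 to a *.scale tensor (the per-output-channel magnitude g) and original1 to the *.weight tensor (the direction v); the effective weight that the commented-out normalize_weight() helper tries to rebuild is g * v / ||v||, with the norm taken per output channel. A minimal NumPy check of that assumption, with illustrative shapes only:

import numpy as np

# original0 -> "scale" (g, shape (n_out, 1, 1)); original1 -> "weight" (v, shape (n_out, n_in, k)).
# Under weight_norm the effective conv weight is W = g * v / ||v||, norm over all dims except dim 0.
def recombine_weight_norm(g: np.ndarray, v: np.ndarray) -> np.ndarray:
    norm = np.sqrt((v ** 2).sum(axis=(1, 2), keepdims=True))
    return g.reshape(-1, 1, 1) * v / norm

v = np.random.randn(4, 3, 7).astype(np.float32)  # toy direction tensor
g = np.random.randn(4, 1, 1).astype(np.float32)  # toy magnitude tensor
print(recombine_weight_norm(g, v).shape)         # (4, 3, 7)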