@@ -2331,32 +2331,177 @@ def set_gguf_parameters(self):
 class SNACDecModel(Model):
     model_arch = gguf.MODEL_ARCH.SNAC_DEC
 
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, Tensor]]:
-        del bid  # unused
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._dummy_added = False
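+        # The flag above makes modify_tensors() emit the placeholder
+        # token_embd.weight exactly once, even though it is called per tensor.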
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[Tuple[str, Tensor]]:
+        """Convert nested PyTorch tensor names to a flat GGUF naming scheme for decoder tensors."""
+        del bid  # unused
+
+        # Add the dummy token_embd.weight only once
+        if not self._dummy_added:
+            # torch is imported at module level; a 4096 x 8 zero tensor stands in for the vocab embedding
+            dummy_tok_embd = torch.zeros((4096, 8), dtype=torch.float16)
+            logger.info(f"Adding dummy tensor: token_embd.weight, shape: {list(dummy_tok_embd.shape)}")
+            yield ("token_embd.weight", dummy_tok_embd)
+            self._dummy_added = True
+        original_name = name
+
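+        # PyTorch weight-norm parametrizations split each weight into original0
+        # (the magnitude "g") and original1 (the direction "v"); below they map
+        # to separate ".scale" and ".weight" GGUF tensors.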
+        if name.startswith("quantizer.quantizers."):
+            match = re.match(r"quantizer\.quantizers\.(\d+)\.(codebook\.weight|out_proj\.bias|out_proj\.parametrizations\.weight\.original[0-1])", name)
+            if match:
+                q_idx = int(match.group(1))
+                tensor_type = match.group(2)
+                if tensor_type == "codebook.weight":
+                    new_name = f"quantizer.{q_idx}.codebook"
+                elif tensor_type == "out_proj.parametrizations.weight.original0":
+                    new_name = f"quantizer.{q_idx}.out_proj.scale"
+                elif tensor_type == "out_proj.parametrizations.weight.original1":
+                    new_name = f"quantizer.{q_idx}.out_proj.weight"
+                elif tensor_type == "out_proj.bias":
+                    new_name = f"quantizer.{q_idx}.out_proj.bias"
+
+                logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}")
+                yield (new_name, data_torch)
+            else:
+                logger.warning(f"Could not parse quantizer tensor from: {original_name}")
+            return
 
-        logger.debug(f"Processing tensor: {name}")
+        # Skip non-decoder tensors (except quantizers, which were handled above)
+        if not name.startswith("decoder."):
+            logger.debug(f"Skipping non-decoder tensor: {original_name}")
+            return
 
-        if (name.startswith("decoder.") or
-            re.match(r"quantizer\.quantizers\.\d+\.codebook\.weight", name) or
-            re.match(r"quantizer\.quantizers\.\d+\.out_proj\..*", name)):
-            logger.info(f"{name} -> {data_torch.shape}")
-            return [(name, data_torch)]
-        else:
-            logger.debug(f"Skipping {name!r}")
-            return []
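+        # Everything from here on is a decoder.* tensor: strip the prefix and
+        # dispatch on the top-level layer index in "model.<idx>.".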
+        base = name[8:]  # Remove 'decoder.'
+        parts = base.split(".")
+
+        if base.startswith("model.0."):
+            logger.info(f"Skipping incompatible decoder layer 0 tensor: {original_name}")
+            return  # Explicitly skip this layer
+
+        # Layer 1: Second Conv
+        if base.startswith("model.1."):
+            if "bias" in name and "parametrizations" not in name:
+                new_name = "decoder.1.conv2.bias"
+            elif "parametrizations.weight.original0" in name:
+                new_name = "decoder.1.conv2.scale"
+            elif "parametrizations.weight.original1" in name:
+                new_name = "decoder.1.conv2.weight"
+            else:
+                logger.warning(f"Unhandled layer 1 tensor: {original_name}")
+                return
+            logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}")
+            yield (new_name, data_torch)
+            return
+
+        # Layers 2–5: DecoderBlocks
+        if "model." in base and "block" in base:
+            try:
+                layer_idx = int(parts[1])  # e.g. '2' from 'model.2'
+                if layer_idx not in {2, 3, 4, 5}:
+                    logger.debug(f"Skipping non-DecoderBlock layer {layer_idx}: {original_name}")
+                    return
+                block_idx = int(parts[3])  # e.g. '1' from 'block.1'
+                new_base = f"decoder.{layer_idx}.block.{block_idx}"
+
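+                # e.g. decoder.model.2.block.3.block.1.parametrizations.weight.original1
+                #      -> decoder.2.block.3.res.conv1.weight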
+                if block_idx == 0:  # Snake1d
+                    if "alpha" in name:
+                        new_name = f"{new_base}.alpha"
+                    else:
+                        logger.error(f"Expected 'alpha' in {original_name}")
+                        return
+                elif block_idx == 1:  # Transpose Conv
+                    if "bias" in name and "parametrizations" not in name:
+                        new_name = f"{new_base}.trans.bias"
+                    elif "parametrizations.weight.original0" in name:
+                        new_name = f"{new_base}.trans.scale"
+                    elif "parametrizations.weight.original1" in name:
+                        new_name = f"{new_base}.trans.weight"
+                    else:
+                        logger.error(f"Unhandled tensor in block 1: {original_name}")
+                        return
+                elif block_idx == 2:  # Noise Block
+                    if "linear.parametrizations.weight.original0" in name:
+                        new_name = f"{new_base}.noise.scale"
+                    elif "linear.parametrizations.weight.original1" in name:
+                        new_name = f"{new_base}.noise.weight"
+                    else:
+                        logger.error(f"Unhandled tensor in block 2: {original_name}")
+                        return
+                elif block_idx in {3, 4, 5}:  # Residual Units
+                    res_base = f"{new_base}.res"
+                    if "block.0.alpha" in name:
+                        new_name = f"{res_base}.snake1.alpha"
+                    elif "block.1.bias" in name:
+                        new_name = f"{res_base}.conv1.bias"
+                    elif "block.1.parametrizations.weight.original0" in name:
+                        new_name = f"{res_base}.conv1.scale"
+                    elif "block.1.parametrizations.weight.original1" in name:
+                        new_name = f"{res_base}.conv1.weight"
+                    elif "block.2.alpha" in name:
+                        new_name = f"{res_base}.snake2.alpha"
+                    elif "block.3.bias" in name:
+                        new_name = f"{res_base}.conv2.bias"
+                    elif "block.3.parametrizations.weight.original0" in name:
+                        new_name = f"{res_base}.conv2.scale"
+                    elif "block.3.parametrizations.weight.original1" in name:
+                        new_name = f"{res_base}.conv2.weight"
+                    else:
+                        logger.error(f"Unhandled tensor in residual unit: {original_name}")
+                        return
+                else:
+                    logger.error(f"Unhandled block index {block_idx} in layer {layer_idx}: {original_name}")
+                    return
+
+                logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}")
+                yield (new_name, data_torch)
+                return
+
+            except (IndexError, ValueError) as e:
+                logger.error(f"Failed to parse tensor {original_name}: {e}")
+                return
+
+        # Layer 6: Snake1d
+        if base == "model.6.alpha":
+            new_name = "decoder.6.alpha"
+            logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}")
+            yield (new_name, data_torch)
+            return
+
+        # Layer 7: Final Conv
+        if base.startswith("model.7."):
+            if "bias" in name and "parametrizations" not in name:
+                new_name = "decoder.7.conv.bias"
+            elif "parametrizations.weight.original0" in name:
+                new_name = "decoder.7.conv.scale"
+            elif "parametrizations.weight.original1" in name:
+                new_name = "decoder.7.conv.weight"
+            else:
+                logger.warning(f"Unhandled layer 7 tensor: {original_name}")
+                return
+            logger.info(f"Mapping {original_name} -> {new_name}, shape: {list(data_torch.shape)}")
+            yield (new_name, data_torch)
+            return
+
+        logger.warning(f"Tensor {original_name} not mapped to any layer")
+        return
 
     def set_vocab(self):
         self._set_vocab_none()
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
-        self.gguf_writer.add_vocab_size(self.hparams["codebook_size"])
-        self.gguf_writer.add_quantizer_count(len(self.hparams["vq_strides"]))
-        self.gguf_writer.add_features_length(self.hparams["codebook_dim"])
-        self.gguf_writer.add_quantizer_strides(self.hparams["vq_strides"])
-        self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"])
-        self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"])
-        self.gguf_writer.add_decoder_channel_dims(self.hparams["decoder_channel_dims"])
+        self.gguf_writer.add_vocab_size(4096)  # TODO: Fix
+        self.gguf_writer.add_uint32("snac.quantizer.codebook_size", self.hparams["codebook_size"])
+        self.gguf_writer.add_uint32("snac.quantizer.codebook_dim", self.hparams["codebook_dim"])
+        self.gguf_writer.add_embedding_length(self.hparams["decoder_dim"])  # 1024
+        self.gguf_writer.add_decoder_upsample_rates(self.hparams["decoder_rates"])  # [8, 8, 4, 2]
+        self.gguf_writer.add_uint32("n_layers", 8)
+        self.gguf_writer.add_array("decoder_channel_dims", [768, 1024, 512, 256, 128, 64, 1])
+        self.gguf_writer.add_array("vq_strides", self.hparams["vq_strides"])
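+        # codebook_size, codebook_dim, decoder_dim, decoder_rates and vq_strides
+        # are read from the model's config.json via self.hparams.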
 
 @Model.register("Qwen2MoeForCausalLM")
 class Qwen2MoeModel(Model):