In [1]:
from transformers import AutoTokenizer
from transformers_gemma2.modeling_gemma2 import CustomGemma2ForCausalLM
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoConfig
from transformers_gemma2.modeling_gemma2 import Gemma2ForAutoEncoding  # 先ほど実装したクラス

def test_generate_with_initial_embedding():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # モデル設定のロード（例として "google/gemma2-7b"）
    config = AutoConfig.from_pretrained("google/gemma-2-2b")
    model = Gemma2ForAutoEncoding(config, tokenizer_name_or_path="google/gemma-2-2b")
    model.to(device)
    model.eval()
    
    batch_size = 1
    hidden_size = config.hidden_size
    # ダミーの初期埋め込み（[B, 1, hidden_size]）
    dummy_initial_embedding = torch.randn(batch_size, 1, hidden_size, device=device)
    
    instructions = (
        "この文章の内容を要約してください。",  # instruction_prompt0
        "以上です。"                        # instruction_prompt1
    )
    
    # --- 以下は generate_with_initial_embedding の内部処理例 ---
    # 1. prefix（条件部）のトークン列を作成
    tokenizer = model.tokenizer
    seq0 = tokenizer.encode(instructions[0], add_special_tokens=False) if instructions[0] else []
    seq1 = tokenizer.encode(instructions[1], add_special_tokens=False) if instructions[1] else []
    dummy_token = model.pad_id  # 仮のダミーとして pad_id を利用
    prefix_tokens = seq0 + [dummy_token] + seq1
    prefix_len = len(prefix_tokens)
    output_len = 50
    total_seq_len = prefix_len + output_len
    
    # 2. 生成テンソル（すべて pad で初期化）
    generated = torch.full((batch_size, total_seq_len), model.pad_id, dtype=torch.long, device=device)
    prefix_tensor = torch.tensor(prefix_tokens, dtype=torch.long, device=device)
    generated[:, :prefix_len] = prefix_tensor.unsqueeze(0).expand(batch_size, -1)
    
    normalizer = config.hidden_size ** 0.5
    # 3. prefix 部分の埋め込み取得（dummy の位置は後で初期埋め込みで上書き）
    with torch.no_grad():
        prefix_embeds = model.model.embed_tokens(generated[:, :prefix_len])
        dummy_index = len(seq0)
        prefix_embeds[:, dummy_index, :] = dummy_initial_embedding.squeeze(1)
        prefix_embeds = prefix_embeds * normalizer
        
        # 4. prefix 部分からキャッシュを構築
        outputs = model.model(
            inputs_embeds=prefix_embeds,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        cur_len = prefix_len

        # 5. 自己回帰生成ループ
        while cur_len < total_seq_len:
            last_token_ids = generated[:, cur_len - 1].unsqueeze(1)
            last_embeds = model.model.embed_tokens(last_token_ids) * normalizer
            # ここでバッチサイズ分の position_ids を作成
            position_ids = torch.full((batch_size, 1), cur_len, device=last_embeds.device)
            
            outputs = model.model(
                inputs_embeds=last_embeds,
                past_key_values=past_key_values,
                use_cache=True,
                position_ids=position_ids,  # 各サンプルに対して正しい位置を与える
            )
            logits = model.lm_head(outputs[0])  # [B, 1, vocab_size]
            logits = logits[:, -1, :]  # [B, vocab_size]
            logits = logits / 0.8  # ここでは温度=0.8と仮定
            probs = F.softmax(logits, dim=-1)
            # Top-p / Top-k の簡易サンプリング（必要に応じて調整）
            next_tokens = torch.multinomial(probs, num_samples=1)
            generated[:, cur_len] = next_tokens.squeeze(-1)
            cur_len += 1
            past_key_values = outputs.past_key_values
        my_custom_module
        # 6. EOS があれば切り取り、テキストにデコード
        outputs_text = []
        for i in range(batch_size):
            gen_tokens = generated[i, prefix_len:].tolist()
            if model.eos_id in gen_tokens:
                gen_tokens = gen_tokens[: gen_tokens.index(model.eos_id)]
            text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
            outputs_text.append(text)
    
    print("=== Generated Text ===")
    print(outputs_text)

if __name__ == "__main__":
    test_generate_with_initial_embedding()


  from .autonotebook import tqdm as notebook_tqdm
The 'batch_size' argument of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'max_batch_size' argument instead.
The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [1,0,0], thread: [64,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [1,0,0], thread: [65,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [1,0,0], thread: [66,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:93: operator(): block: [1,0,0], thread: [67,0,0] Assertion `-sizes[i] 

RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasSgemm( handle, opa, opb, m, n, k, &alpha, a, lda, b, ldb, &beta, c, ldc)`