In [6]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# GPT-2 모델과 토크나이저 로드
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 원본 모델 로드
model_original = AutoModelForCausalLM.from_pretrained(model_name)

# Attention map 수정을 위한 모델 복사
model_modified = AutoModelForCausalLM.from_pretrained(model_name)

# Attention map을 수정할 layer 번호 설정
modify_layers = [2, 4]

# Attention map을 identity matrix로 설정하는 함수
def modify_attention(module, input, output):
    # module이 MultiheadAttention인 경우에만 수정
    if isinstance(module, torch.nn.MultiheadAttention):
        # layer 번호 가져오기
        layer_num = int(module.layer_num)
        # 수정할 layer인 경우
        if layer_num in modify_layers:
            attention_probs = output[1]
            batch_size, num_heads, seq_length, _ = attention_probs.shape
            print(f"Layer {layer_num} - Attention shape: {attention_probs.shape}")
            # Attention shape 확인
            assert attention_probs.shape[-1] == seq_length, f"Attention shape mismatch: {attention_probs.shape}"
            # Attention map을 identity matrix로 설정
            attention_probs.data = torch.eye(seq_length).expand(batch_size, num_heads, -1, -1)
    return output

# Hook 등록 함수
def register_hooks(model):
    # 모든 module에 대해 iterate
    for name, module in model.named_modules():
        # module이 MultiheadAttention인 경우 hook 등록
        if isinstance(module, torch.nn.MultiheadAttention):
            layer_num = name.split(".")[2]
            module.layer_num = layer_num
            module.register_forward_hook(modify_attention)

# 수정된 모델에 hook 등록
register_hooks(model_modified)

In [7]:
# 텍스트 생성 함수
def generate_text(input_text, guidance_scale):
    # 입력 텍스트 인코딩
    input_ids = tokenizer.encode(input_text, return_tensors="pt")
    attention_mask = torch.ones_like(input_ids)
    
    # Attention map 수정 후 logit_p 계산
    with torch.no_grad():
        outputs = model_modified(input_ids, attention_mask=attention_mask)
        logit_p = outputs.logits
        
    # 원래 attention으로 logit_o 계산
    with torch.no_grad():
        outputs = model_original(input_ids, attention_mask=attention_mask)
        logit_o = outputs.logits
        
    # logit 계산
    logit = logit_o + guidance_scale * (logit_o - logit_p)
    print(f"logit shape: {logit.shape}")
    print(f"logit: {logit}")
    
    # 텍스트 생성
    output = model_original.generate(
        input_ids,
        max_length=50,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
        attention_mask=attention_mask,
        logits_processor=[
            lambda input_ids, scores: scores + guidance_scale * (scores - logit[:, -1, :])
        ],
        output_scores=True,
        return_dict_in_generate=True,
        top_p=0.9
    )
    
    print(f"Generated text: {tokenizer.decode(output.sequences[0], skip_special_tokens=True)}")
    print(f"Processed logits: {output.scores[0].shape}")
    print(f"Processed logits: {output.scores[0]}")
    
    return tokenizer.decode(output.sequences[0], skip_special_tokens=True)

In [8]:
# 텍스트 생성 및 출력
input_text = "Hello, how are you? I'm doing well."
guidance_scale = 6.0
generated_text = generate_text(input_text, guidance_scale)
print("Generated Text:")
print(generated_text)

logit shape: torch.Size([1, 11, 50257])
logit: tensor([[[ -35.2362,  -35.3266,  -38.9754,  ...,  -44.4645,  -43.9975,
           -36.4580],
         [-112.6171, -114.5832, -116.5724,  ..., -119.0128, -118.8059,
          -111.6917],
         [-116.7137, -117.5931, -123.1624,  ..., -125.6588, -125.2527,
          -119.3150],
         ...,
         [-107.5247, -109.3616, -113.5464,  ..., -115.4737, -118.1396,
          -112.0031],
         [ -82.1815,  -84.9542,  -93.1059,  ...,  -98.1118,  -98.1967,
           -88.8256],
         [-147.8117, -146.8510, -149.7308,  ..., -160.3841, -160.5273,
          -143.6174]]])




Generated text: Hello, how are you? I'm doing well. I intend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend
Processed logits: torch.Size([1, 50257])
Processed logits: tensor([[-147.8117, -146.8510, -149.7308,  ..., -160.3841, -160.5273,
         -143.6174]])
Generated Text:
Hello, how are you? I'm doing well. I intend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend tremend


In [9]:
# Perplexity 계산 함수
def calculate_perplexity(model_original, model_modified, tokenizer, text, guidance_scale):
    encoded_text = tokenizer.encode(text, return_tensors='pt')
    attention_mask = torch.ones_like(encoded_text)
    
    with torch.no_grad():
        outputs_original = model_original(encoded_text, attention_mask=attention_mask, labels=encoded_text)
        loss_original = outputs_original.loss
        
        outputs_modified = model_modified(encoded_text, attention_mask=attention_mask)
        logits_modified = outputs_modified.logits
        
        logits = outputs_original.logits + guidance_scale * (outputs_original.logits - logits_modified)
        loss_modified = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), encoded_text.view(-1))
        
        perplexity = torch.exp((loss_original + loss_modified) / 2)
    
    return perplexity.item()

In [10]:
# Perplexity 계산 및 출력
perplexity = calculate_perplexity(model_original, model_modified, tokenizer, generated_text, guidance_scale)
print(f"Perplexity: {perplexity:.2f}")

Perplexity: 535.28
