In [6]:
%load_ext autoreload
%autoreload 2

# Logit PAG

GenerationMixin.generate 이 좀 어질어질하네.

되도록 건드리지 않는 것을 추구해보자

In [7]:
# define hook
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from pag.utils import get_module_by_name

model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def perturb_attn_map(module, input, kwargs, output, ):
    layer_past = kwargs['layer_past']
    use_cache = kwargs['use_cache']
    head_mask = kwargs['head_mask']
    attention_mask = kwargs['attention_mask']
    output_attentions = kwargs['output_attentions']

    hidden_states = input[0]
    
    query, key, value = module.c_attn(hidden_states).split(module.split_size, dim=2)

    query = module._split_heads(query, module.num_heads, module.head_dim)
    key = module._split_heads(key, module.num_heads, module.head_dim)
    value = module._split_heads(value, module.num_heads, module.head_dim)
    
    if layer_past is not None:
        past_key, past_value = layer_past
        key = torch.cat((past_key, key), dim=-2)
        value = torch.cat((past_value, value), dim=-2)

    if use_cache is True:
        present = (key, value)
    else:
        present = None

    # make perturbed attn_mask
    l_k, l_q = key.size(2), query.size(2)
    attention_mask = torch.zeros(l_k, l_k) + torch.eye(l_k)
    attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min
    attention_mask = attention_mask[None, None, :l_q, :l_k] # |b, head, l, l|
    
    if module.reorder_and_upcast_attn:
        attn_output, attn_weights = module._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
    else:
        attn_output, attn_weights = module._attn(query, key, value, attention_mask, head_mask)

    attn_output = module._merge_heads(attn_output, module.num_heads, module.head_dim)
    attn_output = module.c_proj(attn_output)
    attn_output = module.resid_dropout(attn_output)

    outputs = (attn_output, present)
    if output_attentions:
        outputs += (attn_weights,)

    return outputs  # a, present, (attentions)

In [8]:
# monkey patch forward path
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union

from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
)

def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
        `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
        are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
    """
    global guidance_traj, hook_layer_list, guidance_scale
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    ########
    # base #
    ########
    transformer_outputs = self.transformer(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states_o = transformer_outputs[0]

    # Set device for model parallelism
    if self.model_parallel:
        torch.cuda.set_device(self.transformer.first_device)
        hidden_states_o = hidden_states_o.to(self.lm_head.weight.device)

    lm_logits_o = self.lm_head(hidden_states_o)

    #######
    # PAG #
    #######
    # 1. hook (make attn map identity)
    handles = []
    for layer_name in hook_layer_list:
        module = get_module_by_name(self, layer_name)
        handle = module.register_forward_hook(perturb_attn_map, with_kwargs=True)
        handles.append(handle)

    # 2. compute logit 
    transformer_outputs_p = self.transformer(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states_p = transformer_outputs_p[0]

    if self.model_parallel:
        torch.cuda.set_device(self.transformer.first_device)
        hidden_states_p = hidden_states_p.to(self.lm_head.weight.device)

    lm_logits_p = self.lm_head(hidden_states_p)

    # 3. remove hook
    for handle in handles:
        handle.remove()

    # 4. guidance
    lm_logits = lm_logits_o + guidance_scale * (lm_logits_o - lm_logits_p)
    guidance_traj.append(lm_logits_o - lm_logits_p)

    if not return_dict:
        output = (lm_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return CausalLMOutputWithCrossAttentions(
        loss=None,
        logits=lm_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
        cross_attentions=transformer_outputs.cross_attentions,
    )

# monkey patch (ref;  https://discuss.pytorch.org/t/monkey-patching-the-forward-pass-of-an-nn-module/176095)
# old_forward = model.forward
import types
model.forward = types.MethodType(forward, model)

### PAG 

In [9]:
hook_layer_list = [
    # 'transformer.h.0.attn', 
    # 'transformer.h.1.attn', 
    'transformer.h.2.attn', 
    'transformer.h.3.attn', 
    'transformer.h.4.attn', 
    'transformer.h.5.attn', 
    # 'transformer.h.6.attn', 
    # 'transformer.h.7.attn', 
    # 'transformer.h.8.attn', 
    # 'transformer.h.9.attn', 
    # 'transformer.h.10.attn', 
    # 'transformer.h.11.attn', 
]

tokenizer.pad_token_id = tokenizer.eos_token_id
for guidance_scale in range(15):
    guidance_traj = []

    inputs = tokenizer(["Today is"], return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=15, return_dict_in_generate=True, output_scores=True)

    # generated result
    print()
    print(f"### Generated text - guidance_scale {guidance_scale} ###")
    print(f'{tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)}')

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 0 ###
Today is the day when we can all be proud of our country and our values.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 1 ###
Today is Election Day and the Republican Party is in a state of flux. The party


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 2 ###
Today is Election Day and the Republican Party is in a state of flux. The party


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 3 ###
Today is Election Day and Donald Trump is the presumptive nominee. What will happen to jobs


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 4 ###
Today is Election Day and Donald Trump is the presumptive nominee. Do not underestimate him.


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 5 ###
Today is Election Day and Donald Trump is the presumptive nominee. Do not ever confuse good


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 6 ###
Today is Election Day and Mitt and I both lost because of it————" exclaimed Mitt


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 7 ###
Today is Election Day and Mitt and I both lost because people thought I was crazy for


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 8 ###
Today is Election Day and Mitt and I both lost because people saw me winning because people


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 9 ###
Today is Election Day again and again throughout Southern andcentral America voting peacefully intoulsicular


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 10 ###
Today is Election Eve and Jill continue her historic opposition to unaccountability in politics today afternoon


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 11 ###
Today is Election Eve and Jill continue her historic opposition to unaccountability in politics today afternoon


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 12 ###
Today is Election Season when integrity advocates stockpile stinkiolDOSDOSEmpty(), stinkivia


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



### Generated text - guidance_scale 13 ###
Today is Election Season when integrity advocates stockpile stinkiolDOS executable's withoutifying Unifiedavis

### Generated text - guidance_scale 14 ###
Today is Election Season when integrity advocates stockpile stinkiolDOS executable's withoutifying Unifiedavis


In [10]:
# Perplexity 계산 함수
# def calculate_perplexity(model_original, model_modified, tokenizer, text, guidance_scale):
#     encoded_text = tokenizer.encode(text, return_tensors='pt')
#     attention_mask = torch.ones_like(encoded_text)
    
#     with torch.no_grad():
#         outputs_original = model_original(encoded_text, attention_mask=attention_mask, labels=encoded_text)
#         loss_original = outputs_original.loss
        
#         outputs_modified = model_modified(encoded_text, attention_mask=attention_mask)
#         logits_modified = outputs_modified.logits
        
#         logits = outputs_original.logits + guidance_scale * (outputs_original.logits - logits_modified)
#         loss_modified = torch.nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), encoded_text.view(-1))
        
#         perplexity = torch.exp((loss_original + loss_modified) / 2)
    
#     return perplexity.item()

In [11]:
# Perplexity 계산 및 출력
# perplexity = calculate_perplexity(model_original, model_modified, tokenizer, generated_text, guidance_scale)
# print(f"Perplexity: {perplexity:.2f}")