In [1]:
!pip install transformers



In [2]:
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [3]:
class SuperWeightFinder:
    """Find and evaluate "super weight" candidates in the MLP down projections of a causal LM.

    The loaded model must expose ``model.model.layers[i].mlp.down_proj`` whose ``weight`` is
    indexed as ``[output_dim, hidden_dim]``. A candidate is a dict with at least the keys
    ``'layer'``, ``'row'`` and ``'col'`` addressing one entry of such a weight.
    """

    def __init__(self, model_name: str, device: str = "cuda"):
        """
        Load the tokenizer and model.

        Args:
            model_name: Name or path of the model.
            device: Device to load the model on ("cuda", "cpu", or a specific CUDA device such
                as "cuda:0"). Weights are loaded as float16 when the string contains "cuda",
                float32 otherwise.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map={"": torch.device(device)},
            torch_dtype=torch.float16 if "cuda" in device else torch.float32
        )
        self.model.eval()
        # Activations captured by the hooks from register_hooks(), keyed by layer index.
        self.down_proj_inputs = {}
        self.down_proj_outputs = {}
        # Handles of the currently installed hooks, so they can be removed again.
        self._hook_handles = []

    def register_hooks(self):
        """Register forward hooks that capture the input and output of every layer's down_proj.

        Hooks installed by a previous call are removed first, so repeated calls do not stack
        duplicate hooks. Detached tensors are stored in ``self.down_proj_inputs[layer_idx]`` and
        ``self.down_proj_outputs[layer_idx]`` and are overwritten on every forward pass until
        ``remove_hooks()`` is called.
        """
        self.remove_hooks()
        self.down_proj_inputs = {}
        self.down_proj_outputs = {}

        def make_hook(layer_idx):
            def hook(module, input_tensors, output_tensor):
                self.down_proj_inputs[layer_idx] = input_tensors[0].detach()
                self.down_proj_outputs[layer_idx] = output_tensor.detach()
            return hook

        self._hook_handles = [
            layer.mlp.down_proj.register_forward_hook(make_hook(idx))
            for idx, layer in enumerate(self.model.model.layers)
        ]

    def remove_hooks(self):
        """Detach every hook installed by register_hooks(); already-captured tensors are kept."""
        # getattr: instances created without __init__ have no handle list yet.
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []

    def _set_weight(self, candidate: Dict, value: float) -> float:
        """Write ``value`` into the candidate's down_proj weight entry and return the old value."""
        weight = self.model.model.layers[candidate['layer']].mlp.down_proj.weight
        with torch.no_grad():
            previous = float(weight[candidate['row'], candidate['col']])
            weight[candidate['row'], candidate['col']] = value
        return previous

    def _generate_text(self, inputs, max_length: int) -> str:
        """Run generate() on tokenized ``inputs`` and decode the first returned sequence."""
        with torch.no_grad():
            output_ids = self.model.generate(
                inputs.input_ids,
                # pad_token_id is set to the EOS id, so pass the mask explicitly rather than
                # leaving generate() to infer it from the ids.
                attention_mask=inputs.get("attention_mask"),
                max_length=max_length,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id
            )
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True)

    def verify_super_weight(self, candidate: Dict, test_prompt: str = "Q: トマトは何色ですか？\nA: ",
                            max_length: int = 50) -> Dict:
        """
        Check whether zeroing a candidate weight changes the model's generation.

        The prompt is continued once with the original model and once with the candidate's
        ``down_proj.weight[row, col]`` set to 0. The weight is restored afterwards, even if
        generation raises. An English prompt used earlier in this notebook is
        "My favorite condiment is".

        Args:
            candidate: Dict with ``layer``, ``row`` and ``col`` keys.
            test_prompt: Prompt to continue.
            max_length: Maximum total length in tokens (prompt + generation).

        Returns:
            Dict with ``candidate``, ``original_output``, ``modified_output`` and ``impact``
            (True when the two decoded texts differ).
        """
        inputs = self.tokenizer(test_prompt, return_tensors="pt").to(self.device)

        original_text = self._generate_text(inputs, max_length)

        original_weight = self._set_weight(candidate, 0.0)
        try:
            modified_text = self._generate_text(inputs, max_length)
        finally:
            # Never leave the model ablated, whatever happened during generation.
            self._set_weight(candidate, original_weight)

        return {
            'candidate': candidate,
            'original_output': original_text,
            'modified_output': modified_text,
            'impact': original_text != modified_text
        }

    def find_super_weights(self, sample_text: str, magnitude_threshold: float = 50.0) -> List[Dict]:
        """
        Identify super-weight candidates from down_proj activation spikes.

        This follows the paper's method, refined. For each layer, every weight
        ``down_proj.weight[out, in]`` becomes a candidate when input channel ``in`` and output
        channel ``out`` both exceed ``magnitude_threshold`` in absolute value somewhere in
        ``sample_text``.

        Args:
            sample_text: Text whose activations are analysed.
            magnitude_threshold: Minimum absolute activation for a channel to count as a spike.

        Returns:
            Candidate dicts with keys ``layer``, ``row``, ``col``, ``weight_value``,
            ``input_magnitude``, ``output_magnitude``, ``subsequent_magnitudes`` and
            ``persistence_score``. They are sorted by
            ``input_magnitude * output_magnitude * persistence_score``, largest first.
        """
        self.register_hooks()
        try:
            inputs = self.tokenizer(sample_text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                self.model(**inputs)
        finally:
            # The hooks are only needed for this forward pass. Left installed, they would keep
            # capturing and holding activations during every later generate()/perplexity call.
            self.remove_hooks()

        # Per-layer max |activation| of every channel, taken over the batch and sequence dims.
        input_max = {
            idx: torch.amax(torch.abs(t), dim=(0, 1)) for idx, t in self.down_proj_inputs.items()
        }
        output_max = {
            idx: torch.amax(torch.abs(t), dim=(0, 1)) for idx, t in self.down_proj_outputs.items()
        }
        num_layers = len(self.model.model.layers)

        super_weight_candidates = []
        for layer_idx, in_max in input_max.items():
            out_max = output_max[layer_idx]
            weight_matrix = self.model.model.layers[layer_idx].mlp.down_proj.weight  # [output_dim, hidden_dim]

            # Channels with a large input activation and a large output activation.
            input_peaks = torch.where(in_max > magnitude_threshold)[0]
            output_peaks = torch.where(out_max > magnitude_threshold)[0]
            if len(input_peaks) == 0 or len(output_peaks) == 0:
                continue

            for out_idx in output_peaks:
                # How strongly the same output channel stays active in later layers. This
                # depends only on out_idx, so compute it once per output channel.
                subsequent_magnitudes = [
                    float(output_max[later][out_idx])
                    for later in range(layer_idx + 1, num_layers)
                    if later in input_max
                ]
                persistence_score = np.mean(subsequent_magnitudes) if subsequent_magnitudes else 0

                for in_idx in input_peaks:
                    super_weight_candidates.append({
                        'layer': layer_idx,
                        'row': int(out_idx),
                        'col': int(in_idx),
                        'weight_value': float(weight_matrix[out_idx, in_idx]),
                        'input_magnitude': float(in_max[in_idx]),
                        'output_magnitude': float(out_max[out_idx]),
                        # Each candidate owns its list, so mutating one does not affect others.
                        'subsequent_magnitudes': list(subsequent_magnitudes),
                        'persistence_score': persistence_score
                    })

        # Rank by activation size and persistence in later layers.
        super_weight_candidates.sort(
            key=lambda c: c['input_magnitude'] * c['output_magnitude'] * c['persistence_score'],
            reverse=True
        )
        return super_weight_candidates

    def calculate_perplexity(self, text: str, candidate: Optional[Dict] = None) -> float:
        """
        Compute the perplexity of ``text``, optionally with one candidate weight set to 0.

        Args:
            text: Text to evaluate; it must tokenize to at least two tokens.
            candidate: Super-weight candidate to zero during the computation. If None, the
                unmodified model is used. The weight is restored afterwards, even on error.

        Returns:
            float: exp of the mean next-token cross-entropy over ``text``.

        Raises:
            ValueError: if ``text`` yields fewer than two tokens (no next-token targets).
        """
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)
        if inputs.input_ids.size(-1) < 2:
            raise ValueError("calculate_perplexity needs text of at least two tokens")

        original_weight = None
        if candidate is not None:
            original_weight = self._set_weight(candidate, 0.0)

        try:
            with torch.no_grad():
                # Upcast so the softmax / cross-entropy is not computed in float16.
                logits = self.model(**inputs).logits.float()

                # Position t predicts token t + 1.
                shift_logits = logits[..., :-1, :].reshape(-1, logits.size(-1))
                target_ids = inputs.input_ids[..., 1:].reshape(-1)

                loss = torch.nn.CrossEntropyLoss()(shift_logits, target_ids)
                perplexity = torch.exp(loss).item()
        finally:
            if original_weight is not None:
                self._set_weight(candidate, original_weight)

        return perplexity

    def evaluate_candidates(self, candidates: List[Dict], eval_text: str,
                            max_candidates: Optional[int] = None) -> List[Dict]:
        """
        Measure how much zeroing each candidate weight changes perplexity on ``eval_text``.

        Args:
            candidates: Super-weight candidates, e.g. from ``find_super_weights``.
            eval_text: Text used to compute perplexity.
            max_candidates: If given, only the first ``max_candidates`` candidates are
                evaluated, since each one costs a full forward pass. None evaluates all.

        Returns:
            List[Dict]: Copies of the evaluated candidates with ``base_perplexity``,
            ``modified_perplexity`` and ``perplexity_change_percent`` added. They are sorted by
            absolute perplexity change, largest first.
        """
        selected = candidates if max_candidates is None else candidates[:max_candidates]

        # Perplexity of the unmodified model.
        base_perplexity = self.calculate_perplexity(eval_text)
        print(f"Base perplexity: {base_perplexity:.2f}")

        evaluated_candidates = []
        for i, candidate in enumerate(selected):
            modified_perplexity = self.calculate_perplexity(eval_text, candidate)

            # Perplexity = exp(mean cross-entropy) >= 1, so this division is safe.
            perplexity_change = ((modified_perplexity - base_perplexity) / base_perplexity) * 100

            evaluated_candidates.append({
                **candidate,
                'base_perplexity': base_perplexity,
                'modified_perplexity': modified_perplexity,
                'perplexity_change_percent': perplexity_change
            })

            print(f"Candidate {i+1}/{len(selected)}:")
            print(f"Layer: {candidate['layer']}, Position: ({candidate['row']}, {candidate['col']})")
            print(f"Modified perplexity: {modified_perplexity:.2f}")
            print(f"Perplexity change: {perplexity_change:+.2f}%")
            print("-" * 50)

        evaluated_candidates.sort(key=lambda x: abs(x['perplexity_change_percent']), reverse=True)

        return evaluated_candidates

In [4]:
# Create the finder. Other checkpoints tried in this notebook (assign one to MODEL_NAME):
#   "mistralai/Mistral-7B-v0.1", "meta-llama/Llama-2-7b-hf", "meta-llama/Llama-3.2-3B",
#   "meta-llama/Llama-3.1-8B", "sbintuitions/sarashina2-7b"
MODEL_NAME = "llm-jp/llm-jp-3-3.7b"
# Fall back to CPU (float32 weights, much slower) when no GPU is available.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
finder = SuperWeightFinder(MODEL_NAME, device=DEVICE)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
# Search for super-weight candidates using the activations on a Japanese sample text.
# (English prompt used earlier: "The highest mountain in Japan is")
test_text = "富士山は山梨県と静岡県に跨る活火山である。標高3776.12 m、日本最高峰の独立峰で、その優美な風貌は日本国外でも日本の象徴として広く知られている。"
candidates = finder.find_super_weights(sample_text=test_text)

In [6]:
# Text on which perplexity is measured with and without each candidate weight.
eval_text = """ドメイン特化のLLMが一つのトレンドで、医療や金融では盛んに開発が行われていることもあり、モデルやデータセット、ベンチマークなどは充実しています。"""

# Evaluate every candidate; the result is sorted by absolute perplexity change.
evaluated_candidates = finder.evaluate_candidates(candidates, eval_text)

# Report the three candidates whose ablation changes perplexity the most.
print("\nTop 3 most influential Super weights:")
for i, candidate in enumerate(evaluated_candidates[:3]):
    report_lines = [
        f"\nCandidate {i+1}:",
        f"Layer: {candidate['layer']}",
        f"Position: ({candidate['row']}, {candidate['col']})",
        f"Weight value: {candidate['weight_value']:.4f}",
        f"Perplexity impact: {candidate['perplexity_change_percent']:+.2f}%",
        f"Base perplexity: {candidate['base_perplexity']:.2f}",
        f"Modified perplexity: {candidate['modified_perplexity']:.2f}",
    ]
    print("\n".join(report_lines))

Base perplexity: 63.09
Candidate 1/265:
Layer: 26, Position: (1161, 2736)
Modified perplexity: 68.19
Perplexity change: +8.07%
--------------------------------------------------
Candidate 2/265:
Layer: 26, Position: (1264, 2736)
Modified perplexity: 63.09
Perplexity change: +0.00%
--------------------------------------------------
Candidate 3/265:
Layer: 7, Position: (1161, 546)
Modified perplexity: 75.50
Perplexity change: +19.66%
--------------------------------------------------
Candidate 4/265:
Layer: 7, Position: (1264, 546)
Modified perplexity: 72.06
Perplexity change: +14.21%
--------------------------------------------------
Candidate 5/265:
Layer: 26, Position: (1161, 3464)
Modified perplexity: 62.84
Perplexity change: -0.40%
--------------------------------------------------
Candidate 6/265:
Layer: 7, Position: (1161, 5535)
Modified perplexity: 59.50
Perplexity change: -5.70%
--------------------------------------------------
Candidate 7/265:
Layer: 26, Position: (1264, 3464)

In [7]:
def infer(model, tokenizer, text, max_length: int = 32):
    """Generate a continuation of ``text`` and return the decoded sequence.

    Args:
        model: Causal LM exposing ``device`` and ``generate``.
        tokenizer: Matching tokenizer; its EOS id is also used as the padding id.
        text: Prompt to continue.
        max_length: Maximum total length in tokens (prompt + generated), passed to generate().

    Returns:
        The first generated sequence decoded with special tokens skipped. In this notebook's
        outputs, the decoded text includes the prompt.
    """
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids,
            # pad_token_id equals the EOS id here, so pass the mask explicitly rather than
            # relying on generate() to infer it from the ids.
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
            num_return_sequences=1,
            pad_token_id=tokenizer.eos_token_id
        )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [8]:
def verify_super_weight(model, tokenizer, layer_idx, row, col, text):
    """Print generations for ``text`` before and after zeroing one down_proj weight.

    ``model.model.layers[layer_idx].mlp.down_proj.weight[row, col]`` is set to 0 for the
    second generation. It is restored afterwards, even if that generation raises, so the
    model is never left ablated.

    Args:
        model: Causal LM whose layers expose ``mlp.down_proj``.
        tokenizer: Matching tokenizer.
        layer_idx: Index of the decoder layer.
        row: Output-channel index into the down_proj weight.
        col: Input-channel index into the down_proj weight.
        text: Prompt to continue.
    """
    original_output = infer(model, tokenizer, text)

    weight = model.model.layers[layer_idx].mlp.down_proj.weight
    original_weight = float(weight[row, col])
    with torch.no_grad():
        weight[row, col] = 0.0
    try:
        sw_output = infer(model, tokenizer, text)
    finally:
        with torch.no_grad():
            weight[row, col] = original_weight

    print("Original output:", original_output)
    print("Modified output:", sw_output)

In [11]:
# Zero layer 1 down_proj weight [2938, 7633] and compare the generations.
verify_super_weight(
    finder.model,
    finder.tokenizer,
    layer_idx=1,
    row=2938,
    col=7633,
    text="日本で一番高い山は、",
)

Original output: 日本で一番高い山は、富士山です。

富士山は標高3776mで、日本で一番高い山です。

富士山は日本で一番高い
Modified output: 日本で一番高い山は、日本で一番高い山は、日本で一番高い山は、日本で一番高い山は、日本で一番高い山は、日本


In [12]:
# Zero layer 1 down_proj weight [1161, 546] and compare the generations.
# NOTE(review): position (1161, 546) was reported as a *layer 7* candidate in the evaluation
# output above, and zeroing it in layer 1 leaves the generation unchanged. Confirm whether
# layer_idx=1 is intended here.
verify_super_weight(
    finder.model,
    finder.tokenizer,
    layer_idx=1,
    row=1161,
    col=546,
    text="日本で一番高い山は、",
)

Original output: 日本で一番高い山は、富士山です。

富士山は標高3776mで、日本で一番高い山です。

富士山は日本で一番高い
Modified output: 日本で一番高い山は、富士山です。

富士山は標高3776mで、日本で一番高い山です。

富士山は日本で一番高い
