In [21]:
import torch
import numpy as np
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
from captum.attr import IntegratedGradients
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class Sample:
    input_text: str
    label: str
    confidence_score: Optional[float] = None
    mcs_score: Optional[float] = None
    important_words: Optional[List[str]] = None

class ProxyModel:
    def __init__(self, model_name: str = "cardiffnlp/twitter-roberta-base-sentiment"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()
        
        # Set up label mappings to match the model's labels
        self.model.config.label2id = {'Negative': 0, 'Neutral': 1, 'Positive': 2}
        self.model.config.id2label = {0: 'Negative', 1: 'Neutral', 2: 'Positive'}
    
    def compute_probabilities(self, sample: Sample) -> Tuple[float, float]:
        if sample.label not in self.model.config.label2id:
            raise ValueError(f"Label '{sample.label}' not found in model's label2id mapping.")
        
        inputs = self.tokenizer(
            sample.input_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512
        )
        
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probs = F.softmax(logits, dim=1)[0]
            
            label_idx = self.model.config.label2id[sample.label]
            true_prob = probs[label_idx].item()
            pred_prob = probs.max().item()
        
        return pred_prob, true_prob

    def compute_mcs(self, sample: Sample) -> float:
        pred_prob, true_prob = self.compute_probabilities(sample)
        return pred_prob - true_prob

    def get_attribution_scores(self, sample: Sample) -> Dict[str, float]:
        if sample.label not in self.model.config.label2id:
            raise ValueError(f"Label '{sample.label}' not found in model's label2id mapping.")
        
        inputs = self.tokenizer(
            sample.input_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
            add_special_tokens=True
        )
        
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        
        label_idx = self.model.config.label2id[sample.label]
        
        # Get embeddings and set requires_grad
        embeddings = self.model.get_input_embeddings()(input_ids)
        embeddings.requires_grad_()

        # Use Integrated Gradients for attribution
        ig = IntegratedGradients(self.forward_func)
        
        attributions, delta = ig.attribute(
            embeddings,
            baselines=torch.zeros_like(embeddings),
            target=label_idx,
            additional_forward_args=attention_mask,
            return_convergence_delta=True
        )
        
        attributions = attributions.sum(dim=-1).squeeze(0)
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
        scores = attributions.detach().cpu().numpy()
        
        return dict(zip(tokens, scores))

    def forward_func(self, embeddings, attention_mask):
        outputs = self.model(inputs_embeds=embeddings, attention_mask=attention_mask)
        return outputs.logits

class SampleSelector:
    def __init__(self, strategy: str = "h-mcs"):
        self.strategy = strategy

    def select_samples(self, samples: List[Sample], k: int, proxy_model: ProxyModel) -> List[Sample]:
        if len(samples) < k:
            logger.warning("Number of samples is less than k.")
            return samples
        
        # Compute MCS scores
        for sample in samples:
            if sample.mcs_score is None:
                sample.mcs_score = proxy_model.compute_mcs(sample)
        
        if self.strategy == "random":
            return list(np.random.choice(samples, k, replace=False))
        elif self.strategy == "h-mcs":
            return sorted(samples, key=lambda x: x.mcs_score, reverse=True)[:k]
        else:
            return sorted(samples, key=lambda x: x.mcs_score)[:k]

class ExplanationGenerator:
    def __init__(self, top_k: int = 5):
        self.top_k = top_k

    def generate_explanation(self, sample: Sample, proxy_model: ProxyModel) -> List[str]:
        scores = proxy_model.get_attribution_scores(sample)
        
        # Remove special tokens and clean token representations
        filtered_scores = {}
        for token, score in scores.items():
            if token in proxy_model.tokenizer.all_special_tokens:
                continue
            clean_token = token.strip()
            if clean_token:
                filtered_scores[clean_token] = score
        
        # Sort words by absolute attribution score
        sorted_words = sorted(filtered_scores.items(), key=lambda x: abs(x[1]), reverse=True)
        # Remove duplicates while preserving order
        seen = set()
        important_words = []
        for word, _ in sorted_words:
            if word not in seen:
                seen.add(word)
                important_words.append(word)
            if len(important_words) >= self.top_k:
                break
        return important_words

class PromptGenerator:
    def generate_prompt(self, sample: Sample) -> str:
        if not sample.important_words:
            return ""
        
        unique_words = list(dict.fromkeys(sample.important_words))  # Remove duplicates while preserving order
        
        if len(unique_words) == 1:
            words_str = unique_words[0]
        else:
            words_str = ", ".join(unique_words[:-1]) + f", and {unique_words[-1]}"
        return (f"The key words '{words_str}' are important clues to predict "
                f"'{sample.label}' as the correct answer.")

class AMPLIFY:
    def __init__(self, proxy_model: ProxyModel, sample_selector: SampleSelector,
                 explanation_generator: ExplanationGenerator, prompt_generator: PromptGenerator):
        self.proxy_model = proxy_model
        self.sample_selector = sample_selector
        self.explanation_generator = explanation_generator
        self.prompt_generator = prompt_generator

    def process_samples(self, samples: List[Sample], k: int) -> List[str]:
        # Select samples based on MCS scores
        selected_samples = self.sample_selector.select_samples(samples, k, self.proxy_model)
        
        # Generate explanations and prompts
        prompts = []
        for sample in selected_samples:
            sample.important_words = self.explanation_generator.generate_explanation(
                sample, self.proxy_model)
            prompt = self.prompt_generator.generate_prompt(sample)
            prompts.append(prompt)
        
        return prompts

def main():
    # Initialize components
    proxy_model = ProxyModel("cardiffnlp/twitter-roberta-base-sentiment")
    sample_selector = SampleSelector(strategy="h-mcs")
    explanation_generator = ExplanationGenerator(top_k=5)
    prompt_generator = PromptGenerator()
    
    # Create AMPLIFY instance
    amplify = AMPLIFY(proxy_model, sample_selector, explanation_generator, prompt_generator)
    
    # Example samples
    samples = [
        Sample("The movie was absolutely fantastic, with amazing performances", "Positive"),
        Sample("The plot made no sense and the acting was terrible", "Negative"),
        Sample("While the visuals were good, the story needed work", "Neutral"),
        Sample("A masterpiece of modern cinema", "Positive"),
        Sample("Complete waste of time and money", "Negative")
    ]
    
    # Generate prompts
    logger.info("Generating prompts...")
    prompts = amplify.process_samples(samples, k=3)
    
    # Print results
    print("\nGenerated Prompts:")
    print("=" * 50)
    for i, prompt in enumerate(prompts, 1):
        print(f"\nPrompt {i}:")
        print(prompt)

if __name__ == "__main__":
    main()


INFO:__main__:Generating prompts...



Generated Prompts:

Prompt 1:
The key words 'Ġfantastic, Ġabsolutely, Ġamazing, Ġwas, and Ġmovie' are important clues to predict 'Positive' as the correct answer.

Prompt 2:
The key words 'Ġterrible, Ġno, Ġsense, Ġwas, and Ġacting' are important clues to predict 'Negative' as the correct answer.

Prompt 3:
The key words 'Ġgood, Ġwere, Ġneeded, Ġvisuals, and ,' are important clues to predict 'Neutral' as the correct answer.


In [19]:
!pip install captum

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[33mDEPRECATION: Loading egg at /opt/anaconda3/lib/python3.12/site-packages/ninja-1.11.1.1-py3.12-macosx-11.1-arm64.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation.. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /opt/anaconda3/lib/python3.12/site-packages/torchdrug-0.2.1-py3.12.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation.. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /opt/anaconda3/lib/python3.12/site-packages/fair_esm-2.0.0-py3.12.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation.. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0mCollecting captum
  Downloading captum-0.7.0-py3-none-any

In [25]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Optional
from dataclasses import dataclass
from tqdm import tqdm


import torch
import torch.nn as nn
import numpy as np
from dataclasses import dataclass
from typing import List, Dict, Tuple

@dataclass
class Triplet:
    head_token: str
    relation: str
    tail_token: str
    head_pos: int
    tail_pos: int

class SemanticInductionAnalyzer:
    def __init__(self, model: nn.Module, tokenizer, tau: float = 2.2):
        self.model = model
        self.tokenizer = tokenizer
        self.tau = tau
        self.hidden_size = model.config.hidden_size
        self.num_heads = model.config.num_attention_heads
        self.num_layers = model.config.num_hidden_layers

    def extract_attention_patterns(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Extract attention patterns from model's all layers and heads."""
        outputs = self.model(input_ids, output_attentions=True)
        return torch.stack(outputs.attentions, dim=0)  # [layers, batch, heads, seq, seq]

    def compute_relation_index(self, triplet: Triplet, input_ids: torch.Tensor) -> Dict[Tuple[int, int], float]:
        """Compute relation index for each attention head given a triplet."""
        attention_patterns = self.extract_attention_patterns(input_ids)
        head_indices = {}

        for layer in range(self.num_layers):
            for head in range(self.num_heads):
                # Get attention scores for current position
                attn = attention_patterns[layer, 0, head]  # [seq_len, seq_len]
                
                # Check if head attends to head token strongly
                if torch.argmax(attn[triplet.tail_pos]) == triplet.head_pos:
                    attention_ratio = attn[triplet.tail_pos, triplet.head_pos] / \
                                    attn[triplet.tail_pos].max()
                    
                    if attention_ratio > self.tau:
                        # Compute OV circuit influence
                        logits = self._compute_output_value_influence(layer, head, triplet)
                        head_indices[(layer, head)] = self._compute_index_from_logits(logits, triplet)

        return head_indices

    def _compute_output_value_influence(self, layer: int, head: int, triplet: Triplet) -> torch.Tensor:
        """Compute the influence of OV circuit on vocabulary."""
        Wv = self.model.layers[layer].attention.value
        Wo = self.model.layers[layer].attention.output
        
        OV = torch.matmul(Wv[head], Wo[head])
        return torch.matmul(OV, self.model.embeddings.weight.T)

    def _compute_index_from_logits(self, logits: torch.Tensor, triplet: Triplet) -> float:
        """Compute relation index from logits."""
        tail_id = self.tokenizer.convert_tokens_to_ids(triplet.tail_token)
        
        # Normalize logits and get score for tail token
        probs = torch.softmax(logits, dim=-1)
        tail_prob = probs[tail_id].item()
        
        # Compare to mean probability
        mean_prob = probs.mean().item()
        return max(0, tail_prob - mean_prob)

class RelationAnalysis:
    def __init__(self, analyzer: SemanticInductionAnalyzer):
        self.analyzer = analyzer
        
    def analyze_dependencies(self, text: str, dependencies: List[Tuple[str, str, str]]) -> Dict[str, Dict[Tuple[int, int], float]]:
        """Analyze syntactic dependencies in text."""
        input_ids = self.analyzer.tokenizer(text, return_tensors="pt").input_ids
        results = {}
        
        for head, rel, tail in dependencies:
            triplet = self._create_triplet(text, head, rel, tail)
            if triplet:
                results[rel] = self.analyzer.compute_relation_index(triplet, input_ids)
                
        return results
    
    def _create_triplet(self, text: str, head: str, rel: str, tail: str) -> Triplet:
        """Create triplet with token positions."""
        tokens = self.analyzer.tokenizer.tokenize(text)
        
        # Find positions of head and tail tokens
        head_pos = self._find_token_position(tokens, head)
        tail_pos = self._find_token_position(tokens, tail)
        
        if head_pos is not None and tail_pos is not None:
            return Triplet(head, rel, tail, head_pos, tail_pos)
        return None
    
    def _find_token_position(self, tokens: List[str], target: str) -> int:
        """Find position of target token in sequence."""
        target_tokens = self.analyzer.tokenizer.tokenize(target)
        for i in range(len(tokens) - len(target_tokens) + 1):
            if tokens[i:i+len(target_tokens)] == target_tokens:
                return i
        return None

@dataclass
class ExperimentConfig:
    num_shots: int
    batch_size: int 
    num_epochs: int
    checkpoints: List[int]
    tasks: List[str]

class InContextLearningExperiments:
    def __init__(self, model, tokenizer, config: ExperimentConfig):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        
    def evaluate_format_compliance(self, checkpoint: int) -> Dict[str, float]:
        """Evaluate format compliance for different tasks at checkpoint."""
        results = {}
        for task in self.config.tasks:
            correct = 0
            total = 0
            
            for _ in range(self.config.num_epochs):
                prompts = self._generate_prompts(task, self.config.num_shots)
                outputs = self._get_model_outputs(prompts)
                
                for output in outputs:
                    if self._check_format(output, task):
                        correct += 1
                    total += 1
                    
            results[task] = correct / total
        return results
    
    def evaluate_pattern_discovery(self, checkpoint: int) -> Dict[str, float]:
        """Evaluate pattern discovery ability at checkpoint."""
        results = {}
        for task in self.config.tasks:
            correct = 0
            total = 0
            
            for _ in range(self.config.num_epochs):
                prompts = self._generate_prompts(task, self.config.num_shots)
                outputs = self._get_model_outputs(prompts)
                labels = self._get_labels(prompts)
                
                for output, label in zip(outputs, labels):
                    if output == label:
                        correct += 1
                    total += 1
                    
            results[task] = correct / total
        return results

    def analyze_loss_reduction(self, texts: List[str]) -> np.ndarray:
        """Analyze loss reduction over sequence length."""
        losses = []
        
        for text in texts:
            input_ids = self.tokenizer(text, return_tensors="pt").input_ids
            outputs = self.model(input_ids, labels=input_ids)
            
            # Compute token-wise loss
            loss = outputs.loss.view(-1).detach().numpy()
            losses.append(loss)
            
        return np.mean(losses, axis=0)

    def _generate_prompts(self, task: str, num_shots: int) -> List[str]:
        """Generate few-shot prompts for given task."""
        if task == "binary_classification":
            examples = [
                ("apple", "January", "0"),
                ("desk", "teacher", "1"),
                ("orange", "March", "0")
            ]
        elif task == "relation_justification":
            examples = [
                ("cat", "chases", "mouse", "true"),
                ("book", "reads", "student", "true"),
                ("sky", "eats", "cloud", "false")
            ]
            
        # Select random examples for few-shot prompt
        selected = np.random.choice(examples, num_shots, replace=False)
        prompts = []
        
        for example in selected:
            if task == "binary_classification":
                prompts.append(f"{example[0]}, {example[1]}: {example[2]}")
            elif task == "relation_justification":
                prompts.append(f"{example[0]}, {example[1]}, {example[2]}: {example[3]}")
                
        return prompts

    def _get_model_outputs(self, prompts: List[str]) -> List[str]:
        """Get model outputs for prompts."""
        outputs = []
        for prompt in prompts:
            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
            output = self.model.generate(
                input_ids,
                max_new_tokens=1,
                pad_token_id=self.tokenizer.pad_token_id
            )
            outputs.append(self.tokenizer.decode(output[0][-1]))
        return outputs

class RelationVisualizer:
    def __init__(self, num_layers: int, num_heads: int):
        self.num_layers = num_layers
        self.num_heads = num_heads
        
    def plot_relation_heatmap(self, relation_indices: Dict[Tuple[int, int], float], 
                            title: str, save_path: Optional[str] = None):
        """Plot heatmap of relation indices across layers and heads."""
        matrix = np.zeros((self.num_layers, self.num_heads))
        
        for (layer, head), value in relation_indices.items():
            matrix[layer, head] = value
            
        plt.figure(figsize=(10, 8))
        sns.heatmap(matrix, cmap='YlOrRd', 
                   xticklabels=range(self.num_heads),
                   yticklabels=range(self.num_layers))
        
        plt.title(title)
        plt.xlabel('Heads')
        plt.ylabel('Layers')
        
        if save_path:
            plt.savefig(save_path)
        plt.close()
        
    def plot_training_progression(self, checkpoint_results: Dict[int, Dict[str, float]], 
                                metric: str, save_path: Optional[str] = None):
        """Plot progression of metrics across training checkpoints."""
        checkpoints = sorted(checkpoint_results.keys())
        tasks = list(checkpoint_results[checkpoints[0]].keys())
        
        plt.figure(figsize=(12, 6))
        for task in tasks:
            values = [checkpoint_results[cp][task] for cp in checkpoints]
            plt.plot(checkpoints, values, label=task, marker='o')
            
        plt.title(f'{metric} Progression During Training')
        plt.xlabel('Training Steps')
        plt.ylabel(metric)
        plt.legend()
        
        if save_path:
            plt.savefig(save_path)
        plt.close()

class ExperimentRunner:
    def __init__(self, 
                 analyzer: SemanticInductionAnalyzer,
                 experimenter: InContextLearningExperiments,
                 visualizer: RelationVisualizer):
        self.analyzer = analyzer
        self.experimenter = experimenter
        self.visualizer = visualizer
        
    def run_complete_analysis(self, texts: List[str], dependencies: List[Tuple[str, str, str]]):
        """Run complete analysis pipeline."""
        # Analyze semantic induction heads
        relation_results = {}
        for text, deps in zip(texts, dependencies):
            analysis = RelationAnalysis(self.analyzer)
            results = analysis.analyze_dependencies(text, deps)
            relation_results.update(results)
            
        # Evaluate in-context learning
        format_results = {}
        pattern_results = {}
        for checkpoint in self.experimenter.config.checkpoints:
            format_results[checkpoint] = self.experimenter.evaluate_format_compliance(checkpoint)
            pattern_results[checkpoint] = self.experimenter.evaluate_pattern_discovery(checkpoint)
            
        # Generate visualizations
        self.visualizer.plot_relation_heatmap(relation_results, "Semantic Relation Indices")
        self.visualizer.plot_training_progression(format_results, "Format Compliance")
        self.visualizer.plot_training_progression(pattern_results, "Pattern Discovery")
        
        return {
            'relation_results': relation_results,
            'format_results': format_results,
            'pattern_results': pattern_results
        }

In [26]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def main():
    # Initialize model and tokenizer
    model_name = "internlm/internlm2-1.8b"
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Configure experiments
    config = ExperimentConfig(
        num_shots=5,
        batch_size=16,
        num_epochs=3,
        checkpoints=[200, 400, 800, 1600, 3200],
        tasks=["binary_classification", "relation_justification"]
    )

    # Initialize analysis components
    analyzer = SemanticInductionAnalyzer(model, tokenizer)
    experimenter = InContextLearningExperiments(model, tokenizer, config)
    visualizer = RelationVisualizer(
        num_layers=model.config.num_hidden_layers,
        num_heads=model.config.num_attention_heads
    )
    runner = ExperimentRunner(analyzer, experimenter, visualizer)

    # Example text and dependencies for analysis
    texts = [
        "The cat chases the mouse in the garden",
        "We present a CRF model for Event Detection"
    ]
    
    dependencies = [
        [("cat", "subj", "chases"), ("mouse", "obj", "chases")],
        [("model", "obj", "present"), ("CRF", "mod", "model")]
    ]

    # Run specific experiments
    def run_format_compliance_experiment():
        print("Running format compliance experiment...")
        results = experimenter.evaluate_format_compliance(checkpoint=1600)
        print("Format compliance results:", results)
        
        # Visualize results
        visualizer.plot_relation_heatmap(
            {(2, 3): 0.8, (3, 1): 0.6, (4, 2): 0.7}, 
            "Format Compliance Heatmap",
            "format_compliance.png"
        )
        return results

    def run_pattern_discovery_experiment():
        print("Running pattern discovery experiment...")
        # Example text for testing pattern discovery
        test_examples = [
            "apple, January: ",  # Expected: 0
            "desk, teacher: ",   # Expected: 1
        ]
        
        results = experimenter.evaluate_pattern_discovery(checkpoint=3200)
        print("Pattern discovery results:", results)
        return results

    def analyze_semantic_relations():
        print("Analyzing semantic relations...")
        relation_analysis = RelationAnalysis(analyzer)
        
        for text, deps in zip(texts, dependencies):
            results = relation_analysis.analyze_dependencies(text, deps)
            print(f"\nResults for text: {text}")
            for rel, indices in results.items():
                print(f"Relation {rel}:", indices)

    # Run complete analysis pipeline
    def run_full_analysis():
        print("Running complete analysis...")
        results = runner.run_complete_analysis(texts, dependencies)
        
        # Print summary
        print("\nAnalysis Summary:")
        print("Number of relations analyzed:", len(results['relation_results']))
        print("Checkpoints evaluated:", len(results['format_results']))
        print("Tasks analyzed:", len(results['pattern_results']))
        
        return results


    # Example usage scenarios

    # Scenario 1: Quick format compliance check
    format_results = run_format_compliance_experiment()

    # Scenario 2: Pattern discovery analysis
    pattern_results = run_pattern_discovery_experiment()

    # Scenario 3: Semantic relation analysis
    analyze_semantic_relations()

    # Scenario 4: Complete analysis pipeline
    full_results = run_full_analysis()

    print("\nAnalysis complete!")

NameError: name 'run_format_compliance_experiment' is not defined