In [21]:
import torch
import numpy as np
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F
from captum.attr import IntegratedGradients
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class Sample:
    input_text: str
    label: str
    confidence_score: Optional[float] = None
    mcs_score: Optional[float] = None
    important_words: Optional[List[str]] = None

class ProxyModel:
    def __init__(self, model_name: str = "cardiffnlp/twitter-roberta-base-sentiment"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
        self.model.eval()
        
        # Set up label mappings to match the model's labels
        self.model.config.label2id = {'Negative': 0, 'Neutral': 1, 'Positive': 2}
        self.model.config.id2label = {0: 'Negative', 1: 'Neutral', 2: 'Positive'}
    
    def compute_probabilities(self, sample: Sample) -> Tuple[float, float]:
        if sample.label not in self.model.config.label2id:
            raise ValueError(f"Label '{sample.label}' not found in model's label2id mapping.")
        
        inputs = self.tokenizer(
            sample.input_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512
        )
        
        with torch.no_grad():
            outputs = self.model(**inputs)
            logits = outputs.logits
            probs = F.softmax(logits, dim=1)[0]
            
            label_idx = self.model.config.label2id[sample.label]
            true_prob = probs[label_idx].item()
            pred_prob = probs.max().item()
        
        return pred_prob, true_prob

    def compute_mcs(self, sample: Sample) -> float:
        pred_prob, true_prob = self.compute_probabilities(sample)
        return pred_prob - true_prob

    def get_attribution_scores(self, sample: Sample) -> Dict[str, float]:
        if sample.label not in self.model.config.label2id:
            raise ValueError(f"Label '{sample.label}' not found in model's label2id mapping.")
        
        inputs = self.tokenizer(
            sample.input_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512,
            add_special_tokens=True
        )
        
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        
        label_idx = self.model.config.label2id[sample.label]
        
        # Get embeddings and set requires_grad
        embeddings = self.model.get_input_embeddings()(input_ids)
        embeddings.requires_grad_()

        # Use Integrated Gradients for attribution
        ig = IntegratedGradients(self.forward_func)
        
        attributions, delta = ig.attribute(
            embeddings,
            baselines=torch.zeros_like(embeddings),
            target=label_idx,
            additional_forward_args=attention_mask,
            return_convergence_delta=True
        )
        
        attributions = attributions.sum(dim=-1).squeeze(0)
        tokens = self.tokenizer.convert_ids_to_tokens(input_ids[0])
        scores = attributions.detach().cpu().numpy()
        
        return dict(zip(tokens, scores))

    def forward_func(self, embeddings, attention_mask):
        outputs = self.model(inputs_embeds=embeddings, attention_mask=attention_mask)
        return outputs.logits

class SampleSelector:
    def __init__(self, strategy: str = "h-mcs"):
        self.strategy = strategy

    def select_samples(self, samples: List[Sample], k: int, proxy_model: ProxyModel) -> List[Sample]:
        if len(samples) < k:
            logger.warning("Number of samples is less than k.")
            return samples
        
        # Compute MCS scores
        for sample in samples:
            if sample.mcs_score is None:
                sample.mcs_score = proxy_model.compute_mcs(sample)
        
        if self.strategy == "random":
            return list(np.random.choice(samples, k, replace=False))
        elif self.strategy == "h-mcs":
            return sorted(samples, key=lambda x: x.mcs_score, reverse=True)[:k]
        else:
            return sorted(samples, key=lambda x: x.mcs_score)[:k]

class ExplanationGenerator:
    def __init__(self, top_k: int = 5):
        self.top_k = top_k

    def generate_explanation(self, sample: Sample, proxy_model: ProxyModel) -> List[str]:
        scores = proxy_model.get_attribution_scores(sample)
        
        # Remove special tokens and clean token representations
        filtered_scores = {}
        for token, score in scores.items():
            if token in proxy_model.tokenizer.all_special_tokens:
                continue
            clean_token = token.strip()
            if clean_token:
                filtered_scores[clean_token] = score
        
        # Sort words by absolute attribution score
        sorted_words = sorted(filtered_scores.items(), key=lambda x: abs(x[1]), reverse=True)
        # Remove duplicates while preserving order
        seen = set()
        important_words = []
        for word, _ in sorted_words:
            if word not in seen:
                seen.add(word)
                important_words.append(word)
            if len(important_words) >= self.top_k:
                break
        return important_words

class PromptGenerator:
    def generate_prompt(self, sample: Sample) -> str:
        if not sample.important_words:
            return ""
        
        unique_words = list(dict.fromkeys(sample.important_words))  # Remove duplicates while preserving order
        
        if len(unique_words) == 1:
            words_str = unique_words[0]
        else:
            words_str = ", ".join(unique_words[:-1]) + f", and {unique_words[-1]}"
        return (f"The key words '{words_str}' are important clues to predict "
                f"'{sample.label}' as the correct answer.")

class AMPLIFY:
    def __init__(self, proxy_model: ProxyModel, sample_selector: SampleSelector,
                 explanation_generator: ExplanationGenerator, prompt_generator: PromptGenerator):
        self.proxy_model = proxy_model
        self.sample_selector = sample_selector
        self.explanation_generator = explanation_generator
        self.prompt_generator = prompt_generator

    def process_samples(self, samples: List[Sample], k: int) -> List[str]:
        # Select samples based on MCS scores
        selected_samples = self.sample_selector.select_samples(samples, k, self.proxy_model)
        
        # Generate explanations and prompts
        prompts = []
        for sample in selected_samples:
            sample.important_words = self.explanation_generator.generate_explanation(
                sample, self.proxy_model)
            prompt = self.prompt_generator.generate_prompt(sample)
            prompts.append(prompt)
        
        return prompts

def main():
    # Initialize components
    proxy_model = ProxyModel("cardiffnlp/twitter-roberta-base-sentiment")
    sample_selector = SampleSelector(strategy="h-mcs")
    explanation_generator = ExplanationGenerator(top_k=5)
    prompt_generator = PromptGenerator()
    
    # Create AMPLIFY instance
    amplify = AMPLIFY(proxy_model, sample_selector, explanation_generator, prompt_generator)
    
    # Example samples
    samples = [
        Sample("The movie was absolutely fantastic, with amazing performances", "Positive"),
        Sample("The plot made no sense and the acting was terrible", "Negative"),
        Sample("While the visuals were good, the story needed work", "Neutral"),
        Sample("A masterpiece of modern cinema", "Positive"),
        Sample("Complete waste of time and money", "Negative")
    ]
    
    # Generate prompts
    logger.info("Generating prompts...")
    prompts = amplify.process_samples(samples, k=3)
    
    # Print results
    print("\nGenerated Prompts:")
    print("=" * 50)
    for i, prompt in enumerate(prompts, 1):
        print(f"\nPrompt {i}:")
        print(prompt)

if __name__ == "__main__":
    main()


INFO:__main__:Generating prompts...



Generated Prompts:

Prompt 1:
The key words 'Ġfantastic, Ġabsolutely, Ġamazing, Ġwas, and Ġmovie' are important clues to predict 'Positive' as the correct answer.

Prompt 2:
The key words 'Ġterrible, Ġno, Ġsense, Ġwas, and Ġacting' are important clues to predict 'Negative' as the correct answer.

Prompt 3:
The key words 'Ġgood, Ġwere, Ġneeded, Ġvisuals, and ,' are important clues to predict 'Neutral' as the correct answer.


In [19]:
!pip install captum

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[33mDEPRECATION: Loading egg at /opt/anaconda3/lib/python3.12/site-packages/ninja-1.11.1.1-py3.12-macosx-11.1-arm64.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation.. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /opt/anaconda3/lib/python3.12/site-packages/torchdrug-0.2.1-py3.12.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation.. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0m[33mDEPRECATION: Loading egg at /opt/anaconda3/lib/python3.12/site-packages/fair_esm-2.0.0-py3.12.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation.. Discussion can be found at https://github.com/pypa/pip/issues/12330[0m[33m
[0mCollecting captum
  Downloading captum-0.7.0-py3-none-any