In [13]:
from typing import List, Tuple, Dict, Optional
import torch
import numpy as np
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification
import torch.nn.functional as F
import matplotlib.pyplot as plt
from dataclasses import dataclass
from tqdm import tqdm
import random
import logging
import os
from nltk.corpus import stopwords
import nltk
nltk.download('stopwords')
# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def clean_token(token: str) -> str:
    """Clean special tokens from GPT-2 tokenizer output."""
    # Remove 'Ġ' character that represents spaces
    return token.replace('Ġ', '')

@dataclass
class Sample:
    """Class for keeping track of a sample with its metadata."""
    text: str
    label: str
    score: Optional[float] = None
    explanation: Optional[List[str]] = None

class ProxyModel:
    """GPT-2 based proxy model as used in the paper."""
    def __init__(self, num_labels: int, device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        logger.info(f"Initializing ProxyModel with {num_labels} labels on {device}")
        self.device = device
        self.num_labels = num_labels
        
        # Initialize tokenizer and model
        try:
            self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
            self.model = GPT2ForSequenceClassification.from_pretrained('gpt2', num_labels=num_labels)
            
            # Add padding token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.model.config.pad_token_id = self.tokenizer.eos_token_id
            
            self.model.to(device)
            self.model.eval()
            
        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

    def predict(self, text: str) -> Dict[str, float]:
        """Generate prediction scores."""
        try:
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.device)
            
            with torch.no_grad():
                outputs = self.model(**inputs)
            
            probs = F.softmax(outputs.logits, dim=-1)[0].cpu()
            return {str(i): float(prob) for i, prob in enumerate(probs)}
            
        except Exception as e:
            logger.error(f"Error in prediction: {str(e)}")
            return {str(i): 0.0 for i in range(self.num_labels)}

class ExplanationGenerator:
    def __init__(self, proxy_model: ProxyModel):
        self.proxy_model = proxy_model

    def vanilla_gradients(self, text: str, label: str) -> np.ndarray:
        """Generate explanations using vanilla gradients."""
        try:
            self.proxy_model.model.zero_grad()
            
            inputs = self.proxy_model.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(self.proxy_model.device)
            
            embeddings = self.proxy_model.model.transformer.wte(inputs["input_ids"])
            embeddings.retain_grad()
            
            outputs = self.proxy_model.model(inputs_embeds=embeddings)
            target_class = int(label)
            
            # Just use gradients without multiplying by input
            outputs.logits[0, target_class].backward()
            
            gradients = embeddings.grad
            if gradients is None:
                return np.zeros(len(inputs["input_ids"][0]))
            
            # Take L2 norm of gradients
            token_attributions = torch.norm(gradients[0], p=2, dim=-1).cpu().numpy()
            tokens = [clean_token(token) for token in self.proxy_model.tokenizer.tokenize(text)]
            token_attributions = [attr for attr, token in zip(token_attributions, tokens) if token and token.lower() not in stopwords.words('english')]
            
            return np.array(token_attributions)
            
            return token_attributions
            
        except Exception as e:
            logger.error(f"Error in vanilla gradients: {str(e)}")
            return np.zeros(1)

    def contrastive_gradients(self, text: str, label: str) -> np.ndarray:
        """Generate contrastive explanations."""
        try:
            self.proxy_model.model.zero_grad()
            
            inputs = self.proxy_model.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True
            ).to(self.proxy_model.device)
            
            # Get predictions first
            with torch.no_grad():
                outputs = self.proxy_model.model(**inputs)
                predicted_class = outputs.logits.argmax(dim=-1).item()
            
            # Get gradients for true class
            embeddings = self.proxy_model.model.transformer.wte(inputs["input_ids"])
            embeddings.retain_grad()
            
            outputs = self.proxy_model.model(inputs_embeds=embeddings)
            target_class = int(label)
            
            # First backward pass with retain_graph=True
            outputs.logits[0, target_class].backward(retain_graph=True)
            true_gradients = embeddings.grad.clone()
            
            # Reset gradients
            self.proxy_model.model.zero_grad()
            embeddings.grad = None
            
            # Second backward pass for predicted class
            outputs = self.proxy_model.model(inputs_embeds=embeddings)
            outputs.logits[0, predicted_class].backward()
            pred_gradients = embeddings.grad.clone()
            
            if true_gradients is None or pred_gradients is None:
                return np.zeros(len(inputs["input_ids"][0]))
            
            # Compute contrastive gradients
            contrastive = true_gradients - pred_gradients
            tokens = [clean_token(token) for token in self.proxy_model.tokenizer.tokenize(text)]
            true_gradients = [attr for attr, token in zip(true_gradients[0], tokens) if token and token.lower() not in stopwords.words('english')]
            pred_gradients = [attr for attr, token in zip(pred_gradients[0], tokens) if token and token.lower() not in stopwords.words('english')]
            
            contrastive = np.array(true_gradients) - np.array(pred_gradients)
            return np.linalg.norm(contrastive, ord=2, axis=-1)
            
        except Exception as e:
            logger.error(f"Error in contrastive gradients: {str(e)}")
            return np.zeros(1)

    def generate_explanation(self, text: str, label: str, method: str = "vanilla") -> np.ndarray:
        """Generate explanations using specified method."""
        if method == "vanilla":
            return self.vanilla_gradients(text, label)
        elif method == "contrastive":
            return self.contrastive_gradients(text, label)
        else:
            logger.warning(f"Unknown method {method}, falling back to vanilla gradients")
            return self.vanilla_gradients(text, label)

class SampleSelector:
    """Selects samples based on Misclassification Confidence Score."""
    def __init__(self, proxy_model: ProxyModel, top_k: int):
        self.proxy_model = proxy_model
        self.top_k = top_k

    def compute_mcs(self, text: str, true_label: str) -> float:
        """Compute Misclassification Confidence Score."""
        try:
            predictions = self.proxy_model.predict(text)
            true_score = predictions.get(true_label, 0)
            max_false_score = max(
                (score for label, score in predictions.items() if label != true_label),
                default=0
            )
            return max_false_score - true_score
        except Exception as e:
            logger.error(f"Error computing MCS: {str(e)}")
            return 0.0

    def select_samples(self, samples: List[Tuple[str, str]]) -> List[Sample]:
        """Select top-k samples based on MCS."""
        scored_samples = []
        
        for text, label in tqdm(samples, desc="Computing MCS scores"):
            try:
                mcs = self.compute_mcs(text, label)
                scored_samples.append(Sample(text=text, label=label, score=mcs))
            except Exception as e:
                logger.error(f"Error processing sample: {str(e)}")
                continue
        
        scored_samples.sort(key=lambda x: x.score if x.score is not None else float('-inf'), reverse=True)
        return scored_samples[:self.top_k]

class AMPLIFY:
    """Main AMPLIFY framework implementation."""
    def __init__(
        self,
        num_labels: int,
        top_k_samples: int = 3,
        top_k_keywords: int = 5,
        device: str = "cuda" if torch.cuda.is_available() else "cpu"
    ):
        logger.info("Initializing AMPLIFY framework")
        self.num_labels = num_labels
        self.device = device
        self.top_k_samples = top_k_samples
        self.top_k_keywords = top_k_keywords
        
        # Initialize components
        self.proxy_model = ProxyModel(num_labels=num_labels, device=device)
        self.explanation_generator = ExplanationGenerator(self.proxy_model)
        self.sample_selector = SampleSelector(self.proxy_model, top_k_samples)

    def visualize_scores(self, samples: List[Sample], save_path: Optional[str] = None):
        """Visualize MCS distribution."""
        scores = [s.score for s in samples if s.score is not None]
        if not scores:
            logger.warning("No scores to visualize")
            return

        plt.figure(figsize=(10, 6))
        plt.hist(scores, bins=20, color='blue', alpha=0.7)
        plt.title("Misclassification Confidence Score Distribution")
        plt.xlabel("MCS")
        plt.ylabel("Frequency")
        plt.grid(True)
        
        if save_path:
            plt.savefig(save_path)
            logger.info(f"Plot saved to {save_path}")
        else:
            plt.show()
        plt.close()

    def generate_prompt(self, samples: List[Sample]) -> str:
        """Generate the complete prompt with explanations."""
        prompt_parts = []
        
        for sample in tqdm(samples, desc="Generating explanations"):
            try:
                # Generate attributions
                attributions = self.explanation_generator.generate_explanation(
                    sample.text,
                    sample.label
                )
                
                # Get tokens and match lengths
                tokens = self.proxy_model.tokenizer.tokenize(sample.text)
                min_len = min(len(tokens), len(attributions))
                tokens = tokens[:min_len]
                attributions = attributions[:min_len]
                
                # Get top keywords and clean them
                token_scores = list(zip(tokens, attributions))
                token_scores.sort(key=lambda x: x[1], reverse=True)
                top_keywords = [clean_token(token) for token, _ in token_scores[:self.top_k_keywords]]
                
                # Store explanation
                sample.explanation = top_keywords
                
                # Format prompt
                rationale = (
                    f"The key words: {', '.join(top_keywords)} are crucial clues "
                    f"for predicting {sample.label} as the correct answer."
                )
                prompt_part = (
                    f"Input: {sample.text}\n"
                    f"Rationale: {rationale}\n"
                    f"Label: {sample.label}\n"
                )
                prompt_parts.append(prompt_part)
                
            except Exception as e:
                logger.error(f"Error generating prompt part: {str(e)}")
                continue
        
        return "\n".join(prompt_parts)

    def run(
        self,
        validation_set: List[Tuple[str, str]],
        visualize: bool = True,
        save_plot: Optional[str] = None
    ) -> str:
        """Execute the complete AMPLIFY pipeline."""
        try:
            logger.info("Starting AMPLIFY pipeline")
            
            # Step 1: Select samples
            logger.info("Selecting samples using MCS")
            selected_samples = self.sample_selector.select_samples(validation_set)
            
            # Step 2: Visualize if requested
            if visualize:
                logger.info("Generating visualization")
                self.visualize_scores(selected_samples, save_plot)
            
            # Step 3: Generate prompt
            logger.info("Generating final prompt")
            prompt = self.generate_prompt(selected_samples)
            
            return prompt
            
        except Exception as e:
            logger.error(f"Error in AMPLIFY pipeline: {str(e)}")
            raise


from IPython.display import HTML, display
import json

def create_token_visualization(samples_with_scores):
    """Create an interactive HTML visualization for token scores."""
    
    html_template = '''
    <div style="font-family: Arial, sans-serif; max-width: 900px; margin: 20px auto;">
        <style>
            .token-container {
                margin: 10px 0;
                padding: 15px;
                border: 1px solid #ddd;
                border-radius: 8px;
            }
            .token {
                display: inline-block;
                margin: 2px;
                padding: 5px 10px;
                border-radius: 4px;
                position: relative;
                cursor: pointer;
            }
            .token:hover .tooltip {
                display: block;
            }
            .tooltip {
                display: none;
                position: absolute;
                bottom: 100%;
                left: 50%;
                transform: translateX(-50%);
                background-color: #333;
                color: white;
                padding: 4px 8px;
                border-radius: 4px;
                font-size: 12px;
                white-space: nowrap;
                z-index: 100;
            }
            .prompt-container {
                margin-top: 20px;
                padding: 15px;
                background-color: #f8f9fa;
                border-radius: 8px;
            }
            .section-title {
                font-weight: bold;
                margin: 10px 0;
                color: #2c5282;
            }
        </style>
        
        <div class="section-title">Token Explanation Score Visualization</div>
        
        <div id="visualization-container">
            {% for sample in samples %}
            <div class="token-container">
                <div style="margin-bottom: 10px;"><b>Sample {{ loop.index }}:</b> Label = {{ sample.label }}</div>
                <div>
                    {% for token in sample.token_scores %}
                    <span class="token" 
                          style="background-color: rgb({{ 255 - token.score * 255 }}, {{ 255 - token.score * 255 }}, 255);">
                        {{ token.token }}
                        <span class="tooltip">Score: {{ "%.3f"|format(token.score) }}</span>
                    </span>
                    {% endfor %}
                </div>
            </div>
            {% endfor %}
        </div>

        <div class="section-title">Generated Training Prompt</div>
        <div class="prompt-container">
            <pre style="margin: 0; white-space: pre-wrap;">{{ prompt }}</pre>
        </div>
    </div>
    '''
    
    # Generate prompt
    prompt_parts = []
    for sample in samples_with_scores:
        # Get top 5 tokens by score
        top_tokens = sorted(sample['token_scores'], key=lambda x: x['score'], reverse=True)[:5]
        top_token_texts = [t['token'] for t in top_tokens]
        
        prompt_part = (
            f"Input: {sample['text']}\n"
            f"Rationale: The key words: {', '.join(top_token_texts)} are crucial clues "
            f"for predicting {sample['label']} as the correct answer.\n"
            f"Label: {sample['label']}"
        )
        prompt_parts.append(prompt_part)
    
    final_prompt = "\n\n".join(prompt_parts)
    
    # Convert scores to 0-1 range if needed
    for sample in samples_with_scores:
        scores = [t['score'] for t in sample['token_scores']]
        max_score = max(abs(min(scores)), abs(max(scores)))
        for token in sample['token_scores']:
            token['score'] = abs(token['score']) / max_score if max_score != 0 else 0

    # Replace template variables
    from jinja2 import Template
    template = Template(html_template)
    html_content = template.render(
        samples=samples_with_scores,
        prompt=final_prompt
    )
    
    return HTML(html_content)

# Example usage in notebook:
def visualize_amplify_results(framework, samples):
    """Create visualization from AMPLIFY results."""
    samples_with_scores = []
    
    for text, label in samples:
        # Get attributions
        attributions = framework.explanation_generator.generate_explanation(text, label)
        
        # Get tokens
        tokens = framework.proxy_model.tokenizer.tokenize(text)
        min_len = min(len(tokens), len(attributions))
        
        # Create token scores
        token_scores = [
            {
                "token": token.replace('Ġ', ''),
                "score": float(attributions[i])
            }
            for i, token in enumerate(tokens[:min_len])
        ]
        
        samples_with_scores.append({
            "text": text,
            "label": label,
            "token_scores": token_scores
        })
    
    return create_token_visualization(samples_with_scores)


def main():
    """Example usage of AMPLIFY framework with interactive visualization."""
    # Set random seed for reproducibility
    torch.manual_seed(42)
    random.seed(42)
    np.random.seed(42)

    # Example data (Snarks task from the paper)
    validation_data = [
        ("Is this sarcastic? 'What a great day!' he said while getting soaked in the rain.", "1"),
        ("Nice work showing up an hour late!", "1"),
        ("The weather is beautiful today.", "0"),
        ("I love waiting in long lines!", "1"),
        ("This is the best day ever!", "1"),
    ]

    try:
        # Initialize AMPLIFY
        framework = AMPLIFY(
            num_labels=2,  # Binary classification
            top_k_samples=3,
            top_k_keywords=3
        )

        # Generate prompt
        prompt = framework.run(
            validation_set=validation_data,
            visualize=True,
            save_plot="outputs/mcs_distribution.png"
        )

        # Create and display interactive visualization
        visualization = visualize_amplify_results(framework, validation_data)
        display(visualization)

        # Print raw results
        print("\nRaw Generated Prompt:")
        print("-" * 50)
        print(prompt)

        # Also save the prompt to a file
        with open("outputs/generated_prompt.txt", "w") as f:
            f.write(prompt)
        
        return {
            'framework': framework,
            'prompt': prompt,
            'validation_data': validation_data
        }

    except Exception as e:
        logger.error(f"Error in main: {str(e)}")
        raise

# For Jupyter notebook usage
if __name__ == "__main__":
    results = main()

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/lihongxuan/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
2024-12-04 21:57:44,906 - INFO - Initializing AMPLIFY framework
2024-12-04 21:57:44,906 - INFO - Initializing ProxyModel with 2 labels on cpu
Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
2024-12-04 21:57:46,259 - INFO - Starting AMPLIFY pipeline
2024-12-04 21:57:46,259 - INFO - Selecting samples using MCS
Computing MCS scores: 100%|██████████| 5/5 [00:00<00:00, 38.10it/s]
2024-12-04 21:57:46,392 - INFO - Generating visualization
2024-12-04 21:57:46,472 - INFO - Plot saved to outputs/mcs_distribution.png
2024-12-04 21:57:46,472 - INFO - Generating final prompt
Generating explanations: 100%|██████████| 3/3 [00:00<00:00,  4.33it/s]



Raw Generated Prompt:
--------------------------------------------------
Input: Nice work showing up an hour late!
Rationale: The key words: Nice, hour, work are crucial clues for predicting 1 as the correct answer.
Label: 1

Input: I love waiting in long lines!
Rationale: The key words: I, long, love are crucial clues for predicting 1 as the correct answer.
Label: 1

Input: This is the best day ever!
Rationale: The key words: best, is, This are crucial clues for predicting 1 as the correct answer.
Label: 1



[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/lihongxuan/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


True