In [37]:
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

class SemanticICLFramework:
    """
    Probe a causal language model's attention for relational (head, relation, tail)
    structure and relate it to simple in-context-learning (ICL) metrics.

    Construction downloads the model, tokenizer and dataset from the Hugging Face Hub
    and preprocesses each split into ``{"text", "tokens", "triplets"}`` records.
    """

    def __init__(self, model_name, dataset_name, max_examples=None):
        """
        Initialize the framework with a model and dataset.

        Parameters:
            model_name (str): The Hugging Face model name.
            dataset_name (str): The dataset name (from Hugging Face).
            max_examples (int, optional): Preprocess only the first ``max_examples``
                examples of each split. ``None`` (default) preprocesses everything.
        """
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.max_examples = max_examples

        # Load model and tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(model_name, output_attentions=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Reuse EOS as the pad token so ``padding=True`` works for tokenizers that
        # do not define one.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load dataset
        self.dataset = load_dataset(dataset_name)

        # Preprocess dataset (eagerly, at construction time)
        self.processed_data = self.preprocess_dataset()

    def preprocess_dataset(self):
        """
        Preprocess the dataset by tokenizing and extracting triplets.

        Honours ``self.max_examples`` (when set) as a per-split cap.

        Returns:
            dict: Maps each split name to a list of preprocessed examples.
        """
        from itertools import islice

        limit = getattr(self, "max_examples", None)
        processed = {}
        for split in self.dataset:
            examples = self.dataset[split]
            if limit is not None:
                examples = islice(examples, limit)
            processed[split] = [self.preprocess_example(example) for example in examples]
        return processed

    def preprocess_example(self, example):
        """
        Preprocess a single example by tokenizing text and extracting triplets.

        Parameters:
            example (dict): A single example from the dataset.

        Returns:
            dict: ``{"text": str, "tokens": tokenizer output, "triplets": list}``.
        """
        text = self.extract_text(example)
        tokens = self.tokenizer(text, truncation=True, padding=True)
        triplets = self.extract_triplets(example)
        return {"text": text, "tokens": tokens, "triplets": triplets}

    def extract_text(self, example):
        """
        Extract text from the dataset example.

        Parameters:
            example (dict): A single example from the dataset.

        Returns:
            str: Extracted text.

        Raises:
            ValueError: If ``self.dataset_name`` is not one of the supported datasets.
        """
        if self.dataset_name == "thunlp/few_rel":
            # FewRel stores the sentence pre-tokenized as a list of words.
            return " ".join(example.get("tokens", []))
        elif self.dataset_name == "Babelscape/rebel-dataset":
            return example.get("context", "")
        elif self.dataset_name == "thu-coai/kd_conv_with_kb":
            return example.get("content", "")
        else:
            raise ValueError("Unsupported dataset for text extraction.")

    def extract_triplets(self, example):
        """
        Extract triplets (head, relation, tail) from the example.

        Parameters:
            example (dict): A single example from the dataset.

        Returns:
            list: List of ``(head, relation, tail)`` triplets; empty for unknown datasets.
        """
        if self.dataset_name == "thunlp/few_rel":
            head = example.get("head", {}).get("text", "")
            tail = example.get("tail", {}).get("text", "")
            relation = example.get("relation", "")
            return [(head, relation, tail)]
        elif self.dataset_name == "Babelscape/rebel-dataset":
            raw = example.get("triplets", [])
            # REBEL stores triplets as one linearized string; iterating it directly
            # would yield characters, so parse it into tuples.
            if isinstance(raw, str):
                return self._parse_rebel_triplets(raw)
            return list(raw) if raw else []
        elif self.dataset_name == "thu-coai/kd_conv_with_kb":
            # NOTE(review): field names assumed from the original code — verify
            # against the dataset card.
            head = example.get("name", "")
            relation = example.get("attrname", "")
            tail = example.get("attrvalue", "")
            return [(head, relation, tail)]
        else:
            return []

    @staticmethod
    def _parse_rebel_triplets(linearized):
        """
        Parse REBEL's ``<triplet> head <subj> tail <obj> relation`` linearization.

        A head may be followed by several ``<subj> tail <obj> relation`` groups.

        Returns:
            list: ``(head, relation, tail)`` tuples in order of appearance.
        """
        triplets = []
        parts = {"head": [], "tail": [], "relation": []}
        state = None

        def flush():
            # Emit the current triplet only once all three parts are populated.
            if parts["head"] and parts["tail"] and parts["relation"]:
                triplets.append(tuple(" ".join(parts[key]) for key in ("head", "relation", "tail")))

        for token in linearized.split():
            if token == "<triplet>":
                flush()
                parts = {"head": [], "tail": [], "relation": []}
                state = "head"
            elif token == "<subj>":
                flush()
                parts["tail"], parts["relation"] = [], []
                state = "tail"
            elif token == "<obj>":
                parts["relation"] = []
                state = "relation"
            elif state is not None:
                parts[state].append(token)
        flush()
        return triplets

    def _attention_weights(self, layer, layer_idx):
        """
        Return ``(W_q, W_k, W_v, W_o, conv1d_layout)`` for one transformer layer.

        ``conv1d_layout`` is True when weights are stored ``(in, out)`` (GPT-2's
        ``Conv1D``), and False for ``nn.Linear``'s ``(out, in)`` layout.

        Raises:
            ValueError: If the layer structure is not recognised.
        """
        if hasattr(layer, "attn") and hasattr(layer.attn, "c_attn"):  # GPT-2 style fused QKV
            attn = layer.attn
            if type(attn.c_attn).__name__ != "Conv1D":
                raise ValueError(f"Unsupported fused c_attn layout at layer {layer_idx}")
            qkv_proj = attn.c_attn.weight  # (d_model, 3 * d_model), columns = [Q | K | V]
            d_model = qkv_proj.shape[1] // 3
            return (
                qkv_proj[:, :d_model],
                qkv_proj[:, d_model:2 * d_model],
                qkv_proj[:, 2 * d_model:],
                attn.c_proj.weight,
                True,
            )
        if hasattr(layer, "attention"):  # BERT-like models
            attn = layer.attention.self
            return (attn.query.weight, attn.key.weight, attn.value.weight,
                    layer.attention.output.dense.weight, False)
        if hasattr(layer, "self_attn"):  # BART/OPT (out_proj) or Llama-style (o_proj)
            attn = layer.self_attn
            out = attn.out_proj if hasattr(attn, "out_proj") else attn.o_proj
            return (attn.q_proj.weight, attn.k_proj.weight, attn.v_proj.weight, out.weight, False)
        raise ValueError(f"Unsupported layer structure at layer {layer_idx}")

    def analyze_attention_heads(self, input_ids, attention_mask):
        """
        Analyze attention heads for Query-Key and Output-Value circuits across various model architectures.

        Circuits are computed from weights only (biases ignored) and combine all
        heads of a layer into one ``d_model x d_model`` matrix each.

        Parameters:
            input_ids (torch.Tensor): Tokenized input IDs.
            attention_mask (torch.Tensor): Attention mask for padding.

        Returns:
            tuple: (qk_circuits, ov_circuits, attentions). ``qk_circuits[l]`` is
            ``W_QK`` such that the pre-softmax score is ``x_i @ W_QK @ x_j``;
            ``ov_circuits[l]`` maps a residual vector to the layer's output
            contribution; ``attentions`` is the model's per-layer attention tuple.

        Raises:
            ValueError: On unsupported layer structures or mismatched projection shapes.
        """
        with torch.no_grad():  # analysis only; no autograd graph needed
            outputs = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
            attentions = outputs.attentions  # one tensor per layer

            qk_circuits, ov_circuits = [], []
            for layer_idx, layer in enumerate(self.get_layers(self.model)):
                q_proj, k_proj, v_proj, o_proj, conv1d_layout = self._attention_weights(layer, layer_idx)

                if conv1d_layout:
                    # Conv1D: y = x @ W, so q.k = x W_q W_k^T x'^T and OV = W_v W_o.
                    qk_left, qk_right = q_proj, k_proj.T
                    ov_left, ov_right = v_proj, o_proj
                else:
                    # nn.Linear: y = x @ W^T, so q.k = x W_q^T W_k x'^T and OV = W_v^T W_o^T.
                    qk_left, qk_right = q_proj.T, k_proj
                    ov_left, ov_right = v_proj.T, o_proj.T

                # Ensure dimensions match for matrix multiplication (e.g. grouped-query
                # attention gives K/V projections a different width than Q).
                if qk_left.shape[1] != qk_right.shape[0] or ov_left.shape[1] != ov_right.shape[0]:
                    raise ValueError(
                        f"Matrix dimension mismatch in layer {layer_idx}: q {tuple(q_proj.shape)}, "
                        f"k {tuple(k_proj.shape)}, v {tuple(v_proj.shape)}, o {tuple(o_proj.shape)}"
                    )

                qk_circuits.append(qk_left @ qk_right)
                ov_circuits.append(ov_left @ ov_right)

        return qk_circuits, ov_circuits, attentions

    def get_layers(self, model):
        """
        Dynamically retrieve the layers of the model based on its architecture.

        Parameters:
            model (transformers.PreTrainedModel): The Hugging Face model.

        Returns:
            list: List of model layers.

        Raises:
            ValueError: If no known layer container is found.
        """
        if hasattr(model, "transformer") and hasattr(model.transformer, "h"):  # GPT-2 style
            return model.transformer.h
        if hasattr(model, "encoder") and hasattr(model.encoder, "layer"):  # BERT style
            return model.encoder.layer
        if hasattr(model, "encoder") and hasattr(model.encoder, "layers"):
            return model.encoder.layers
        inner = getattr(model, "model", None)
        if inner is not None:
            if hasattr(inner, "layers"):  # e.g. Llama/Mistral-style decoders
                return inner.layers
            decoder = getattr(inner, "decoder", None)
            if decoder is not None and hasattr(decoder, "layers"):  # e.g. OPT / BART causal-LM wrappers
                return decoder.layers
        raise ValueError("Unsupported model architecture")

    def _find_token_span(self, text, ids):
        """
        Return the positions of the first occurrence of ``text``'s tokens in ``ids``.

        BPE tokenizers encode a word differently with and without a leading space,
        so both forms are tried. Returns None when the entity is not found.
        """
        if not text:
            return None
        for candidate in (" " + text, text):
            pattern = self.tokenizer.encode(candidate, add_special_tokens=False)
            width = len(pattern)
            if width == 0:
                continue
            for start in range(len(ids) - width + 1):
                if ids[start:start + width] == pattern:
                    return list(range(start, start + width))
        return None

    @staticmethod
    def _span_attention(layer_attn, head_pos, tail_pos):
        """
        Mean attention between two token spans in one layer, averaged over heads.

        ``layer_attn`` is ``(batch, heads, seq, seq)`` (first sequence used) or
        ``(heads, seq, seq)``. Under a causal mask a query only sees earlier keys,
        so the later entity is used as the query and the earlier one as the key.
        """
        if layer_attn.ndim == 4:
            layer_attn = layer_attn[0]
        if tail_pos[0] > head_pos[0]:
            query_pos, key_pos = tail_pos, head_pos
        else:
            query_pos, key_pos = head_pos, tail_pos
        block = layer_attn[:, query_pos][:, :, key_pos]
        return float(block.mean())

    def compute_relation_index(self, attentions, triplets, input_ids):
        """
        Compute relation indices for attention heads.

        Parameters:
            attentions: Per-layer attention weights from the model.
            triplets: Extracted triplets (head, relation, tail).
            input_ids: Tokenized input IDs for the input sequence (first row used).

        Returns:
            float: Mean span-to-span attention over all layers and all triplets whose
            head and tail were located in the text; 0.0 if none were located.
        """
        ids = input_ids[0]
        ids = ids.tolist() if hasattr(ids, "tolist") else list(ids)

        spans = []
        for head, _, tail in triplets:
            head_pos = self._find_token_span(head, ids)
            tail_pos = self._find_token_span(tail, ids)
            if head_pos is None or tail_pos is None:
                # Skip triplets whose entities are not found in the text
                continue
            spans.append((head_pos, tail_pos))

        relation_indices = [
            self._span_attention(layer_attn, head_pos, tail_pos)
            for layer_attn in attentions
            for head_pos, tail_pos in spans
        ]
        return float(np.mean(relation_indices)) if relation_indices else 0.0

    def monitor_icl(self, dataset, steps=100):
        """
        Monitor In-Context Learning (ICL) abilities over time.

        Parameters:
            dataset: Preprocessed dataset (sliceable list of ``{"text": ...}`` records).
            steps (int): Number of examples to evaluate.

        Returns:
            dict: Per-example lists:
                ``loss_reduction`` - the model's LM loss (raw loss; key name kept for
                compatibility), ``format_compliance`` - fraction of correct next-token
                predictions, ``pattern_discovery`` - correct next-token predictions
                per sequence.
        """
        loss_reduction, format_compliance, pattern_discovery = [], [], []

        for example in dataset[:steps]:
            tokens = self.tokenizer(example["text"], return_tensors="pt", truncation=True, padding=True)
            labels = tokens["input_ids"]
            with torch.no_grad():  # evaluation only
                outputs = self.model(**tokens, labels=labels)

            # Loss on this example
            loss_reduction.append(outputs.loss.item())

            # Logits at position t predict token t + 1, so compare against shifted labels.
            predictions = torch.argmax(outputs.logits[:, :-1], dim=-1)
            matches = (predictions == labels[:, 1:]).float()

            # Format compliance: fraction of correct next-token predictions
            format_compliance.append(matches.mean().item())

            # Pattern discovery: correct next-token predictions per sequence
            pattern_discovery.append(matches.sum().item() / len(labels))

        return {
            "loss_reduction": loss_reduction,
            "format_compliance": format_compliance,
            "pattern_discovery": pattern_discovery,
        }

    def correlate_attention_with_icl(self, relation_indices, icl_metrics):
        """
        Compute correlation between attention head behaviors and ICL metrics.

        Parameters:
            relation_indices: Relation index values (one per example).
            icl_metrics: ICL metric values (e.g. loss), aligned with ``relation_indices``.

        Returns:
            float: Pearson correlation coefficient, or nan when it is undefined
            (fewer than two points or a constant series).

        Raises:
            ValueError: If the two series have different lengths.
        """
        indices = np.asarray(relation_indices, dtype=float)
        metrics = np.asarray(icl_metrics, dtype=float)
        if indices.shape != metrics.shape:
            raise ValueError(
                f"Expected one relation index per ICL metric, got {indices.size} and {metrics.size}"
            )
        if indices.size < 2 or np.std(indices) == 0 or np.std(metrics) == 0:
            return float("nan")
        return float(np.corrcoef(indices, metrics)[0, 1])


# Example usage: correlate per-example relation indices with per-example LM loss on FewRel.
# One constant drives both loops so the two series stay aligned for np.corrcoef.
N_EXAMPLES = 100

framework = SemanticICLFramework(model_name="gpt2", dataset_name="thunlp/few_rel")
processed_data = framework.processed_data["train"]

relation_indices = []
with torch.no_grad():  # inference only: don't build autograd graphs for every example
    for example in processed_data[:N_EXAMPLES]:
        tokens = framework.tokenizer(example["text"], return_tensors="pt", truncation=True, padding=True)
        input_ids, attention_mask = tokens["input_ids"], tokens["attention_mask"]
        # QK/OV circuits are not used here; only the attention patterns are needed.
        _, _, attentions = framework.analyze_attention_heads(input_ids, attention_mask)

        # Pass input_ids so triplet entities can be located in the token sequence
        relation_idx = framework.compute_relation_index(attentions, example["triplets"], input_ids)
        relation_indices.append(relation_idx)

    # Correlate with ICL
    icl_metrics = framework.monitor_icl(processed_data, steps=N_EXAMPLES)

correlation = framework.correlate_attention_with_icl(relation_indices, icl_metrics["loss_reduction"])
print(f"Correlation between relation indices and loss reduction: {correlation}")


In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

class SemanticICLFramework:
    """
    Load a causal LM and a relation-extraction dataset, then inspect and visualise
    the model's attention patterns with respect to (head, relation, tail) triplets.
    """

    def __init__(self, model_name, dataset_name, max_examples=None):
        """
        Initialize the framework with a model and dataset.

        Parameters:
            model_name (str): The Hugging Face model name.
            dataset_name (str): The dataset name (from Hugging Face).
            max_examples (int, optional): Preprocess only the first ``max_examples``
                examples of each split. ``None`` (default) preprocesses everything.
        """
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.max_examples = max_examples

        # Load model and tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(model_name, output_attentions=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Reuse EOS as the pad token so ``padding=True`` works for tokenizers that
        # do not define one.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Load dataset
        self.dataset = load_dataset(dataset_name)

        # Preprocess dataset (eagerly, at construction time)
        self.processed_data = self.preprocess_dataset()

    def preprocess_dataset(self):
        """
        Preprocess the dataset by tokenizing and extracting triplets.

        Honours ``self.max_examples`` (when set) as a per-split cap.

        Returns:
            dict: Maps each split name to a list of preprocessed examples.
        """
        from itertools import islice

        limit = getattr(self, "max_examples", None)
        processed = {}
        for split in self.dataset:
            examples = self.dataset[split]
            if limit is not None:
                examples = islice(examples, limit)
            processed[split] = [self.preprocess_example(example) for example in examples]
        return processed

    def preprocess_example(self, example):
        """
        Preprocess a single example by tokenizing text and extracting triplets.

        Returns:
            dict: ``{"text": str, "tokens": tokenizer output, "triplets": list}``.
        """
        text = self.extract_text(example)
        tokens = self.tokenizer(text, truncation=True, padding=True)
        triplets = self.extract_triplets(example)
        return {"text": text, "tokens": tokens, "triplets": triplets}

    def extract_text(self, example):
        """
        Extract text from the dataset example.

        Raises:
            ValueError: If ``self.dataset_name`` is not one of the supported datasets.
        """
        if self.dataset_name == "thunlp/few_rel":
            # FewRel stores the sentence pre-tokenized as a list of words.
            return " ".join(example.get("tokens", []))
        elif self.dataset_name == "Babelscape/rebel-dataset":
            return example.get("context", "")
        elif self.dataset_name == "thu-coai/kd_conv_with_kb":
            return example.get("content", "")
        else:
            raise ValueError("Unsupported dataset for text extraction.")

    def extract_triplets(self, example):
        """
        Extract triplets (head, relation, tail) from the example.

        Returns:
            list: ``(head, relation, tail)`` triplets; empty for unknown datasets.
        """
        if self.dataset_name == "thunlp/few_rel":
            head = example.get("head", {}).get("text", "")
            tail = example.get("tail", {}).get("text", "")
            relation = example.get("relation", "")
            return [(head, relation, tail)]
        elif self.dataset_name == "Babelscape/rebel-dataset":
            raw = example.get("triplets", [])
            # REBEL stores triplets as one linearized string; unpacking it directly
            # would iterate characters, so parse it into tuples.
            if isinstance(raw, str):
                return self._parse_rebel_triplets(raw)
            return list(raw) if raw else []
        elif self.dataset_name == "thu-coai/kd_conv_with_kb":
            # NOTE(review): field names assumed from the original code — verify
            # against the dataset card.
            head = example.get("name", "")
            relation = example.get("attrname", "")
            tail = example.get("attrvalue", "")
            return [(head, relation, tail)]
        else:
            return []

    @staticmethod
    def _parse_rebel_triplets(linearized):
        """
        Parse REBEL's ``<triplet> head <subj> tail <obj> relation`` linearization.

        A head may be followed by several ``<subj> tail <obj> relation`` groups.
        """
        triplets = []
        parts = {"head": [], "tail": [], "relation": []}
        state = None

        def flush():
            # Emit the current triplet only once all three parts are populated.
            if parts["head"] and parts["tail"] and parts["relation"]:
                triplets.append(tuple(" ".join(parts[key]) for key in ("head", "relation", "tail")))

        for token in linearized.split():
            if token == "<triplet>":
                flush()
                parts = {"head": [], "tail": [], "relation": []}
                state = "head"
            elif token == "<subj>":
                flush()
                parts["tail"], parts["relation"] = [], []
                state = "tail"
            elif token == "<obj>":
                parts["relation"] = []
                state = "relation"
            elif state is not None:
                parts[state].append(token)
        flush()
        return triplets

    @staticmethod
    def _to_numpy(array):
        """Convert a torch tensor (possibly requiring grad or on GPU) or array-like to numpy."""
        if hasattr(array, "detach"):
            return array.detach().cpu().numpy()
        return np.asarray(array)

    def analyze_attention_heads(self, input_ids, attention_mask):
        """
        Run the model and return its per-layer attention tensors.

        Returns:
            tuple: ``outputs.attentions`` from the model (one tensor per layer).
        """
        with torch.no_grad():  # analysis only; keeps tensors plottable / numpy-convertible
            outputs = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
        return outputs.attentions

    def analyze_qk_ov_circuits(self, attentions):
        """
        Summarise each layer's attention pattern averaged over heads.

        NOTE(review): only attention *patterns* are available here (not projection
        weights), so the "QK" and "OV" entries hold the same head-averaged pattern.

        Returns:
            dict: ``{"QK": {layer_idx: array}, "OV": {layer_idx: array}}``, each array
            shaped ``(batch, seq_len, seq_len)``.
        """
        qk_insights, ov_insights = {}, {}
        for layer_idx, attn_layer in enumerate(attentions):
            head_avg = self._to_numpy(attn_layer).mean(axis=1)  # axis 1 = attention heads
            qk_insights[layer_idx] = head_avg
            ov_insights[layer_idx] = head_avg.copy()

        return {"QK": qk_insights, "OV": ov_insights}

    def layer_head_analysis(self, attentions):
        """
        Perform layer-wise and head-specific analysis of attention matrices.

        Returns:
            dict: ``layer_idx -> list`` of per-head ``(seq_len, seq_len)`` numpy arrays,
            each averaged over the batch dimension.
        """
        layer_head_patterns = {}
        for layer_idx, layer_attn in enumerate(attentions):
            per_head = self._to_numpy(layer_attn).mean(axis=0)  # (heads, seq, seq)
            layer_head_patterns[layer_idx] = list(per_head)
        return layer_head_patterns

    def _find_token_span(self, text, ids):
        """
        Return the positions of the first occurrence of ``text``'s tokens in ``ids``.

        BPE tokenizers encode a word differently with and without a leading space,
        so both forms are tried. Returns None when the entity is not found.
        """
        if not text:
            return None
        for candidate in (" " + text, text):
            pattern = self.tokenizer.encode(candidate, add_special_tokens=False)
            width = len(pattern)
            if width == 0:
                continue
            for start in range(len(ids) - width + 1):
                if ids[start:start + width] == pattern:
                    return list(range(start, start + width))
        return None

    @staticmethod
    def _span_attention(layer_attn, head_pos, tail_pos):
        """
        Mean attention between two token spans in one layer, averaged over heads.

        ``layer_attn`` is ``(batch, heads, seq, seq)`` (first sequence used) or
        ``(heads, seq, seq)``. Under a causal mask a query only sees earlier keys,
        so the later entity is used as the query and the earlier one as the key.
        """
        if layer_attn.ndim == 4:
            layer_attn = layer_attn[0]
        if tail_pos[0] > head_pos[0]:
            query_pos, key_pos = tail_pos, head_pos
        else:
            query_pos, key_pos = head_pos, tail_pos
        block = layer_attn[:, query_pos][:, :, key_pos]
        return float(block.mean())

    def attribute_attention(self, attentions, triplets, input_ids):
        """
        Attribute attention scores to triplet components.

        Returns:
            list: ``(triplet, score)`` pairs for triplets whose head and tail were found
            in the text; ``score`` is the span-to-span attention averaged over layers.
        """
        ids = input_ids[0]
        ids = ids.tolist() if hasattr(ids, "tolist") else list(ids)

        attribution_scores = []
        for triplet in triplets:
            head, _relation, tail = triplet
            head_pos = self._find_token_span(head, ids)
            tail_pos = self._find_token_span(tail, ids)
            if head_pos is None or tail_pos is None:
                continue
            layer_scores = [self._span_attention(layer_attn, head_pos, tail_pos) for layer_attn in attentions]
            attribution_scores.append((triplet, float(np.mean(layer_scores))))

        return attribution_scores

    def visualize_layer_dynamics(self, layer_head_patterns):
        """
        Visualize attention dynamics across layers and heads (one figure per layer).

        Each line is a head's attention averaged over query positions, i.e. the mean
        attention each key position receives.
        """
        for layer_idx, heads in layer_head_patterns.items():
            fig, ax = plt.subplots(figsize=(10, 6))
            for head_idx, head_attn in enumerate(heads):
                ax.plot(self._to_numpy(head_attn).mean(axis=0), label=f"Head {head_idx}")
            ax.set_title(f"Layer {layer_idx} Attention Dynamics")
            ax.set_xlabel("Tokens")
            ax.set_ylabel("Attention Score")
            ax.legend()
            plt.show()
            plt.close(fig)  # avoid accumulating open figures when plotting many layers

    def correlate_attention_with_metrics(self, relation_indices, icl_metrics):
        """
        Compute and visualize correlation between attention and ICL metrics.

        Returns:
            float: Pearson correlation, or nan when undefined (fewer than two points
            or a constant series).

        Raises:
            ValueError: If the two series have different lengths.
        """
        indices = np.asarray(relation_indices, dtype=float)
        losses = np.asarray(icl_metrics["loss_reduction"], dtype=float)
        if indices.shape != losses.shape:
            raise ValueError(
                f"Expected one relation index per ICL metric, got {indices.size} and {losses.size}"
            )
        if indices.size < 2 or np.std(indices) == 0 or np.std(losses) == 0:
            correlation = float("nan")
        else:
            correlation = float(np.corrcoef(indices, losses)[0, 1])

        fig, ax = plt.subplots(figsize=(8, 6))
        ax.scatter(indices, losses, alpha=0.7)
        ax.set_title("Correlation Between Attention and Loss Reduction")
        ax.set_xlabel("Relation Index")
        ax.set_ylabel("Loss Reduction")
        ax.grid(True)
        plt.show()
        plt.close(fig)

        return correlation


# Example usage: per-example triplet attention attribution on a few REBEL examples.
N_EXAMPLES = 5

framework = SemanticICLFramework(model_name="gpt2", dataset_name="Babelscape/rebel-dataset")
processed_data = framework.processed_data["train"]

relation_indices = []
with torch.no_grad():  # inference only; also keeps attention tensors plottable
    for example_idx, example in enumerate(processed_data[:N_EXAMPLES]):
        tokens = framework.tokenizer(example["text"], return_tensors="pt", truncation=True, padding=True)
        input_ids, attention_mask = tokens["input_ids"], tokens["attention_mask"]
        attentions = framework.analyze_attention_heads(input_ids, attention_mask)

        # Visualize attention for the first example only (one figure per layer)
        if example_idx == 0:
            framework.visualize_layer_dynamics(framework.layer_head_analysis(attentions))

        # Compute triplet attribution
        attribution = framework.attribute_attention(attentions, example["triplets"], input_ids)
        print(f"Attention Attribution: {attribution}")

        # One relation index per example: mean attribution over the located triplets.
        scores = [float(score) for _, score in attribution]
        relation_indices.append(float(np.mean(scores)) if scores else 0.0)

# Example Correlation Visualization
# Placeholder ICL metrics (hard-coded, not measured): one value per example above.
icl_metrics = {"loss_reduction": [0.1, 0.2, 0.15, 0.18, 0.12]}
assert len(icl_metrics["loss_reduction"]) == len(relation_indices), "need one ICL metric per example"
correlation = framework.correlate_attention_with_metrics(relation_indices, icl_metrics)
print(f"Correlation: {correlation}")


ModuleNotFoundError: No module named 'datasets'

In [2]:
!pip install datasets

Collecting datasets
  Using cached datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-18.1.0-cp312-cp312-macosx_12_0_arm64.whl.metadata (3.3 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp312-cp312-macosx_11_0_arm64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Using cached datasets-3.1.0-py3-none-any.whl (480 kB)
Downloading multiprocess-0.70.16-py312-none-any.whl (146 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m146.7/146.7 kB[0m [31m901.8 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading pyarrow-18.1.0-cp312-cp312-macosx_12_0_arm64.whl (29.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.5/29.5 MB[0m [31m34.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading xxhash-3.5.0-cp312-cp312-macosx_11_0_arm64.whl (30 kB)
Installing collecte