In [5]:
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from IPython.display import display, HTML
import json

class SemanticICLFramework:
    """Score a causal LM's attention heads by how they link related entities.

    Wraps a Hugging Face causal LM and its tokenizer. ``analyze_attention_heads``
    runs one forward pass; ``compute_relation_index`` turns the attention maps
    and output-projection weights into a ``(num_layers, num_heads)`` score
    matrix for a set of ``(head, relation, tail)`` triplets.

    NOTE(review): ``get_layer_weight`` hard-codes the GPT-2 module layout
    (``transformer.h.<i>.attn``); other architectures need a different path.
    """

    # Multipliers applied to 'semantic' scores per relation label;
    # labels not listed here use 1.0.
    _SEMANTIC_FACTORS = {
        'Part-of': 1.2,
        'Compare': 1.0,
        'Used-for': 1.1,
        'Feature-of': 1.0,
        'Hyponym-of': 1.3,
        'Evaluate-for': 1.0,
        'Conjunction': 0.9,
    }

    def __init__(self, model_name):
        """Load the pretrained model (attention outputs enabled) and its tokenizer.

        Args:
            model_name: Hugging Face model id or local path, e.g. ``"gpt2"``.
        """
        self.model_name = model_name
        self.model = AutoModelForCausalLM.from_pretrained(model_name, output_attentions=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Reuse EOS as the pad token so ``padding=True`` works for tokenizers
        # that do not define one.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_layer_weight(self, layer_idx, weight_name):
        """Return parameter ``weight_name`` of the attention block in layer ``layer_idx``.

        The submodule path ``transformer.h.<layer_idx>.attn`` matches GPT-2
        style models; for other layouts ``get_submodule`` will fail.
        """
        return self.model.get_submodule(f'transformer.h.{layer_idx}.attn').get_parameter(weight_name)

    def analyze_attention_heads(self, input_ids, attention_mask):
        """Run one forward pass and collect attention maps and output projections.

        Args:
            input_ids: token-id tensor accepted by the model.
            attention_mask: matching attention mask.

        Returns:
            ``(attentions, ov_circuits)``: ``outputs.attentions`` from the model
            (one entry per layer) and a list with each layer's
            ``c_proj.weight`` parameter.
        """
        # Inference only: don't build an autograd graph over the forward pass.
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
        attentions = outputs.attentions

        ov_circuits = [self.get_layer_weight(i, 'c_proj.weight') for i in range(len(attentions))]
        return attentions, ov_circuits

    def compute_relation_index(self, attentions, ov_circuits, triplets, input_ids, relation_type, tau=1e-4):
        """Score every (layer, head) by how strongly it encodes the given triplets.

        For each ``(head, relation, tail)`` triplet whose entity strings occur
        in the decoded tokens, a head's score is

            qk_score * ov_score [* semantic factor for 'semantic']

        ``qk_score`` is the attention weight from the tail position (row) to
        the head-entity position (column). ``ov_score`` is the logit that row
        ``head_idx`` of the layer's ``c_proj.weight`` assigns to the tail
        token's vocabulary id. That logit is taken after projecting onto the
        output embeddings and dividing by its maximum. Triplets whose
        attention weight is at most ``tau`` times the row maximum are skipped.
        Scores are averaged over triplets, then the whole matrix is min-max
        normalised.

        NOTE(review): a single ``c_proj.weight`` row is only a rough proxy for a
        per-head OV circuit.

        Args:
            attentions: per-layer attention tensors indexed
                ``[batch, head, query, key]``; only batch element 0 is used.
            ov_circuits: per-layer ``c_proj.weight`` tensors (see
                ``analyze_attention_heads``).
            triplets: iterable of ``(head, relation, tail)`` strings; entities
                are located by substring match against each decoded token.
            input_ids: token ids; only ``input_ids[0]`` is used.
            relation_type: ``'syntactic'`` or ``'semantic'``.
            tau: relative attention threshold.

        Returns:
            ``np.ndarray`` of shape ``(num_layers, num_heads)`` with values in
            [0, 1]; all zeros when every head scored the same (e.g. no triplet
            matched).

        Raises:
            ValueError: if ``relation_type`` is not recognised.
        """
        # Validate up front so a bad value is reported even when no triplet matches.
        if relation_type not in ('syntactic', 'semantic'):
            raise ValueError(f"Unknown relation type: {relation_type}")

        vocab_matrix = self.model.get_output_embeddings().weight
        num_layers = len(attentions)
        num_heads = attentions[0].size(1)
        relation_indices = np.zeros((num_layers, num_heads))

        token_ids = [int(t) for t in input_ids[0].cpu().numpy()]
        tokens = [self.tokenizer.decode(t) for t in token_ids]

        # Entity positions depend only on the tokens, so resolve them once.
        resolved = []
        for head, rel, tail in triplets:
            head_positions = [i for i, t in enumerate(tokens) if head in t]
            tail_positions = [i for i, t in enumerate(tokens) if tail in t]
            if not head_positions or not tail_positions:
                continue
            # Use the (truncated) average position for multi-token entities.
            head_pos = int(sum(head_positions) / len(head_positions))
            tail_pos = int(sum(tail_positions) / len(tail_positions))
            factor = self.get_semantic_factor(rel) if relation_type == 'semantic' else 1.0
            # The OV projection lives in vocabulary space, so it must be read at
            # the tail token's *id*, not at its sequence position.
            resolved.append((head_pos, tail_pos, token_ids[tail_pos], factor))

        with torch.no_grad():
            for layer_idx, layer_attn in enumerate(attentions):
                layer_weights = layer_attn[0].detach().cpu().numpy()
                ov_layer = ov_circuits[layer_idx]
                for head_idx in range(num_heads):
                    attention_weights = layer_weights[head_idx]
                    ov_influence = None
                    head_scores = []

                    for head_pos, tail_pos, tail_token_id, factor in resolved:
                        qk_score = attention_weights[tail_pos, head_pos]
                        # Apply relative thresholding against the row maximum.
                        if qk_score / np.max(attention_weights[tail_pos, :]) <= tau:
                            continue

                        if ov_influence is None:
                            # Computed lazily and once per head: only heads that
                            # pass the threshold need the vocab projection.
                            ov_influence = torch.matmul(ov_layer[head_idx], vocab_matrix.T)
                            ov_influence = ov_influence / ov_influence.max()

                        ov_score = ov_influence[tail_token_id].item()
                        head_scores.append(qk_score * ov_score * factor)

                    # Mean across triplets (not sum) so scores don't grow with triplet count.
                    if head_scores:
                        relation_indices[layer_idx, head_idx] = np.mean(head_scores)

        return self._minmax_normalize(relation_indices)

    @staticmethod
    def _minmax_normalize(scores):
        """Rescale ``scores`` to [0, 1]; a constant array maps to all zeros (not NaN)."""
        lo = scores.min()
        span = scores.max() - lo
        if span == 0:
            return np.zeros_like(scores, dtype=float)
        return (scores - lo) / span

    def get_semantic_factor(self, relation):
        """Return the score multiplier for ``relation`` (1.0 if unlisted)."""
        return self._SEMANTIC_FACTORS.get(relation, 1.0)

    def visualize_relation_indices(self, relation_indices):
        """Display ``relation_indices`` as an HTML heatmap in the notebook."""
        display(create_heatmap_html(relation_indices))

from IPython.display import HTML, Javascript

def create_heatmap_html(relation_indices, title="Semantic Relationship Analysis"):
    """Render a (layers x heads) score matrix as an interactive HTML heatmap.

    The returned ``HTML`` object carries the markup plus an inline
    ``<script>``. The script draws one row per layer and one cell per head;
    hovering a cell shows its value.

    Args:
        relation_indices: 2-D array-like with a ``tolist()`` method (e.g. a
            NumPy array). Values are expected in [0, 1] and are clamped to that
            range for colouring (NaN is drawn as 0).
        title: heading shown above the heatmap (HTML-escaped).

    Returns:
        ``HTML`` object ready for ``display``.
    """
    import uuid
    from html import escape

    data = relation_indices.tolist()
    # Per-call ids keep re-runs (or several heatmaps on one page) from drawing
    # into another output's container via a shared "heatmap" id.
    uid = uuid.uuid4().hex
    container_id = f"heatmap-{uid}"

    markup = f"""
    <div id="heatmap-container-{uid}" style="width: 800px; font-family: Arial, sans-serif; position: relative;">
        <div style="margin-bottom: 20px;">
            <h2 style="color: #333;">{escape(title)}</h2>
        </div>
        
        <div style="display: flex; gap: 20px;">
            <!-- Heatmap -->
            <div style="flex: 2; position: relative;">
                <div id="{container_id}" style="position: relative; height: 400px;">
                    <!-- Heatmap will be rendered here -->
                </div>
                <div style="position: absolute; bottom: -40px; left: 0; right: 0; text-align: center; color: #333; font-size: 16px; font-weight: bold;">
                    Attention Heads
                </div>
                <div style="position: absolute; top: 50%; left: -40px; transform: rotate(-90deg) translateX(-50%); transform-origin: left; color: #333; font-size: 16px; font-weight: bold;">
                    Layers
                </div>
            </div>
            
            <!-- Explanation Panel -->
            <div style="flex: 1; background: #f5f5f5; padding: 15px; border-radius: 5px;">
                <h3 style="margin-top: 0;">How to Read This Visualization</h3>
                <ul style="padding-left: 20px; color: #444;">
                    <li>Each cell represents an attention head</li>
                    <li>Darker blue indicates stronger relationship encoding</li>
                    <li>Rows represent layers in the transformer</li>
                    <li>Columns show attention heads within each layer</li>
                </ul>
                
                <h3>Key Patterns in In-Context Learning</h3>
                <ul style="padding-left: 20px; color: #444;">
                    <li>Early layers often show basic pattern recognition</li>
                    <li>Middle layers demonstrate more complex feature extraction</li>
                    <li>Later layers may reveal task-specific adaptations</li>
                    <li>Some heads might specialize in particular relationships or tasks</li>
                    <li>Overall pattern typically progresses from general to specific understanding</li>
                </ul>
            </div>
        </div>
    </div>

    <script>
    (function() {{
        function createHeatmap() {{
            const data = {json.dumps(data)};
            const container = document.getElementById('{container_id}');
            if (!container) {{
                console.error('Heatmap container not found');
                return;
            }}

            data.forEach((row, i) => {{
                const rowDiv = document.createElement('div');
                rowDiv.style.display = 'flex';
                rowDiv.style.height = (400 / data.length) + 'px';

                row.forEach((value, j) => {{
                    // Clamp to [0, 1] (NaN -> 0) so every score maps to a valid colour.
                    const v = Math.min(1, Math.max(0, Number(value) || 0));
                    const shade = Math.floor(255 * (1 - v));
                    const cell = document.createElement('div');
                    cell.style.width = (100 / row.length) + '%';
                    cell.style.height = '100%';
                    cell.style.backgroundColor = `rgb(${{shade}}, ${{shade}}, 255)`;
                    cell.style.border = '1px solid white';
                    cell.title = `Layer ${{i+1}}, Head ${{j+1}}: ${{Number(value).toFixed(3)}}`;
                    rowDiv.appendChild(cell);
                }});

                container.appendChild(rowDiv);
            }});
        }}

        // Run after a short delay to ensure DOM is ready
        setTimeout(createHeatmap, 100);
    }})();
    </script>
    """

    return HTML(markup)

# def create_heatmap_html(relation_indices, title="Semantic Relationship Analysis"):
#     data = relation_indices.tolist()
    
#     html = f"""
#     <div id="heatmap-container" style="width: 800px; font-family: Arial, sans-serif;">
#         <div style="margin-bottom: 20px;">
#             <h2 style="color: #333;">{title}</h2>
#             <p style="color: #666;">Analyzing how attention heads encode semantic relationships</p>
#         </div>
        
#         <div style="display: flex; gap: 20px;">
#             <!-- Heatmap -->
#             <div style="flex: 2;">
#                 <div id="heatmap" style="position: relative;">
#                     <!-- Heatmap will be rendered here -->
#                 </div>
#                 <div style="margin-top: 10px; text-align: center; color: #666;">
#                     <div>Attention Heads</div>
#                 </div>
#                 <div style="position: absolute; left: -30px; top: 50%; transform: rotate(-90deg); color: #666;">
#                     Layers
#                 </div>
#             </div>
            
#             <!-- Explanation Panel -->
#             <div style="flex: 1; background: #f5f5f5; padding: 15px; border-radius: 5px;">
#                 <h3 style="margin-top: 0;">How to Read This Visualization</h3>
#                 <ul style="padding-left: 20px; color: #444;">
#                     <li>Each cell represents an attention head</li>
#                     <li>Darker blue indicates stronger relationship encoding</li>
#                     <li>Rows represent layers in the transformer</li>
#                     <li>Columns show attention heads within each layer</li>
#                 </ul>
                
#                 <h3>Key Patterns</h3>
#                 <ul style="padding-left: 20px; color: #444;">
#                     <li>Middle layers often encode more semantic relationships</li>
#                     <li>Some heads specialize in specific relationships</li>
#                     <li>Earlier layers capture simpler relationships</li>
#                     <li>Later layers show more complex patterns</li>
#                 </ul>
#             </div>
#         </div>
#     </div>
#     """

#     js = f"""
#     function createHeatmap() {{
#         const data = {json.dumps(data)};
#         const container = document.getElementById('heatmap');
#         if (!container) {{
#             console.error('Heatmap container not found');
#             return;
#         }}
        
#         data.forEach((row, i) => {{
#             const rowDiv = document.createElement('div');
#             rowDiv.style.display = 'flex';
#             rowDiv.style.height = '20px';
            
#             row.forEach((value, j) => {{
#                 const cell = document.createElement('div');
#                 cell.style.width = '20px';
#                 cell.style.height = '100%';
#                 cell.style.backgroundColor = `rgb(${{Math.floor(255 * (1 - value))}}, ${{Math.floor(255 * (1 - value))}}, 255)`;
#                 cell.style.border = '1px solid white';
#                 cell.title = `Layer ${{i+1}}, Head ${{j+1}}: ${{value.toFixed(3)}}`;
#                 rowDiv.appendChild(cell);
#             }});
            
#             container.appendChild(rowDiv);
#         }});
#     }}

#     // Run after a short delay to ensure DOM is ready
#     setTimeout(createHeatmap, 100);
#     """

#     return HTML(html), Javascript(js)





framework = SemanticICLFramework("gpt2")

# Toy corpus: each sentence is annotated with (head, relation, tail) triplets.
examples = [
    dict(text="The pen is used for writing.", triplets=[("pen", "Used-for", "writing")]),
    dict(text="A cat chases a mouse.", triplets=[("cat", "Chases", "mouse")]),
]

relation_indices_list = []

for example in examples:
    # Encode the sentence as a PyTorch batch and run it through the model.
    tokens = framework.tokenizer(
        example["text"], return_tensors="pt", truncation=True, padding=True
    )
    attentions, ov_circuits = framework.analyze_attention_heads(
        tokens["input_ids"], tokens["attention_mask"]
    )

    # Score every (layer, head) for this example's triplets.
    relation_indices = framework.compute_relation_index(
        attentions,
        ov_circuits,
        example["triplets"],
        tokens["input_ids"],
        relation_type="semantic",
    )
    relation_indices_list.append(relation_indices)

# Element-wise mean over examples gives one (layers x heads) map to render.
mean_relation_indices = np.mean(relation_indices_list, axis=0)
framework.visualize_relation_indices(mean_relation_indices)

In [3]:
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from IPython.display import display, HTML
import json

class SemanticICLFramework:
    """Score a causal LM's attention heads by how they link related entities.

    Wraps a Hugging Face causal LM and its tokenizer. ``analyze_attention_heads``
    runs one forward pass; ``compute_relation_index`` turns the attention maps
    and output-projection weights into a ``(num_layers, num_heads)`` score
    matrix for a set of ``(head, relation, tail)`` triplets.

    NOTE(review): ``get_layer_weight`` hard-codes the GPT-2 module layout
    (``transformer.h.<i>.attn``); other architectures need a different path.
    """

    # Multipliers applied to 'semantic' scores per relation label;
    # labels not listed here use 1.0.
    _SEMANTIC_FACTORS = {
        'Part-of': 1.2,
        'Compare': 1.0,
        'Used-for': 1.1,
        'Feature-of': 1.0,
        'Hyponym-of': 1.3,
        'Evaluate-for': 1.0,
        'Conjunction': 0.9,
    }

    def __init__(self, model_name):
        """Load the pretrained model (attention outputs enabled) and its tokenizer.

        Args:
            model_name: Hugging Face model id or local path, e.g. ``"gpt2"``.
        """
        self.model_name = model_name
        self.model = AutoModelForCausalLM.from_pretrained(model_name, output_attentions=True)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Reuse EOS as the pad token so ``padding=True`` works for tokenizers
        # that do not define one.
        self.tokenizer.pad_token = self.tokenizer.eos_token

    def get_layer_weight(self, layer_idx, weight_name):
        """Return parameter ``weight_name`` of the attention block in layer ``layer_idx``.

        The submodule path ``transformer.h.<layer_idx>.attn`` matches GPT-2
        style models; for other layouts ``get_submodule`` will fail.
        """
        return self.model.get_submodule(f'transformer.h.{layer_idx}.attn').get_parameter(weight_name)

    def analyze_attention_heads(self, input_ids, attention_mask):
        """Run one forward pass and collect attention maps and output projections.

        Args:
            input_ids: token-id tensor accepted by the model.
            attention_mask: matching attention mask.

        Returns:
            ``(attentions, ov_circuits)``: ``outputs.attentions`` from the model
            (one entry per layer) and a list with each layer's
            ``c_proj.weight`` parameter.
        """
        # Inference only: don't build an autograd graph over the forward pass.
        with torch.no_grad():
            outputs = self.model(input_ids, attention_mask=attention_mask, output_attentions=True)
        attentions = outputs.attentions

        ov_circuits = [self.get_layer_weight(i, 'c_proj.weight') for i in range(len(attentions))]
        return attentions, ov_circuits

    def compute_relation_index(self, attentions, ov_circuits, triplets, input_ids, relation_type, tau=1e-4):
        """Score every (layer, head) by how strongly it encodes the given triplets.

        For each ``(head, relation, tail)`` triplet whose entity strings occur
        in the decoded tokens, a head's score is

            qk_score * ov_score [* semantic factor for 'semantic']

        ``qk_score`` is the attention weight from the tail position (row) to
        the head-entity position (column). ``ov_score`` is the logit that row
        ``head_idx`` of the layer's ``c_proj.weight`` assigns to the tail
        token's vocabulary id. That logit is taken after projecting onto the
        output embeddings and dividing by its maximum. Triplets whose
        attention weight is at most ``tau`` times the row maximum are skipped.
        Scores are averaged over triplets, then the whole matrix is min-max
        normalised.

        NOTE(review): a single ``c_proj.weight`` row is only a rough proxy for a
        per-head OV circuit.

        Args:
            attentions: per-layer attention tensors indexed
                ``[batch, head, query, key]``; only batch element 0 is used.
            ov_circuits: per-layer ``c_proj.weight`` tensors (see
                ``analyze_attention_heads``).
            triplets: iterable of ``(head, relation, tail)`` strings; entities
                are located by substring match against each decoded token.
            input_ids: token ids; only ``input_ids[0]`` is used.
            relation_type: ``'syntactic'`` or ``'semantic'``.
            tau: relative attention threshold.

        Returns:
            ``np.ndarray`` of shape ``(num_layers, num_heads)`` with values in
            [0, 1]; all zeros when every head scored the same (e.g. no triplet
            matched).

        Raises:
            ValueError: if ``relation_type`` is not recognised.
        """
        # Validate up front so a bad value is reported even when no triplet matches.
        if relation_type not in ('syntactic', 'semantic'):
            raise ValueError(f"Unknown relation type: {relation_type}")

        vocab_matrix = self.model.get_output_embeddings().weight
        num_layers = len(attentions)
        num_heads = attentions[0].size(1)
        relation_indices = np.zeros((num_layers, num_heads))

        token_ids = [int(t) for t in input_ids[0].cpu().numpy()]
        tokens = [self.tokenizer.decode(t) for t in token_ids]

        # Entity positions depend only on the tokens, so resolve them once.
        resolved = []
        for head, rel, tail in triplets:
            head_positions = [i for i, t in enumerate(tokens) if head in t]
            tail_positions = [i for i, t in enumerate(tokens) if tail in t]
            if not head_positions or not tail_positions:
                continue
            # Use the (truncated) average position for multi-token entities.
            head_pos = int(sum(head_positions) / len(head_positions))
            tail_pos = int(sum(tail_positions) / len(tail_positions))
            factor = self.get_semantic_factor(rel) if relation_type == 'semantic' else 1.0
            # The OV projection lives in vocabulary space, so it must be read at
            # the tail token's *id*, not at its sequence position.
            resolved.append((head_pos, tail_pos, token_ids[tail_pos], factor))

        with torch.no_grad():
            for layer_idx, layer_attn in enumerate(attentions):
                layer_weights = layer_attn[0].detach().cpu().numpy()
                ov_layer = ov_circuits[layer_idx]
                for head_idx in range(num_heads):
                    attention_weights = layer_weights[head_idx]
                    ov_influence = None
                    head_scores = []

                    for head_pos, tail_pos, tail_token_id, factor in resolved:
                        qk_score = attention_weights[tail_pos, head_pos]
                        # Apply relative thresholding against the row maximum.
                        if qk_score / np.max(attention_weights[tail_pos, :]) <= tau:
                            continue

                        if ov_influence is None:
                            # Computed lazily and once per head: only heads that
                            # pass the threshold need the vocab projection.
                            ov_influence = torch.matmul(ov_layer[head_idx], vocab_matrix.T)
                            ov_influence = ov_influence / ov_influence.max()

                        ov_score = ov_influence[tail_token_id].item()
                        head_scores.append(qk_score * ov_score * factor)

                    # Mean across triplets (not sum) so scores don't grow with triplet count.
                    if head_scores:
                        relation_indices[layer_idx, head_idx] = np.mean(head_scores)

        return self._minmax_normalize(relation_indices)

    @staticmethod
    def _minmax_normalize(scores):
        """Rescale ``scores`` to [0, 1]; a constant array maps to all zeros (not NaN)."""
        lo = scores.min()
        span = scores.max() - lo
        if span == 0:
            return np.zeros_like(scores, dtype=float)
        return (scores - lo) / span

    def get_semantic_factor(self, relation):
        """Return the score multiplier for ``relation`` (1.0 if unlisted)."""
        return self._SEMANTIC_FACTORS.get(relation, 1.0)

    def visualize_relation_indices(self, relation_indices):
        """Display the heatmap markup followed by the script that draws it."""
        html, js = create_heatmap_html(relation_indices)
        display(html, js)

# def create_heatmap_html(relation_indices, title="Semantic Relationship Analysis"):
#     data = relation_indices.tolist()
    
#     html = f"""
#     <div id="heatmap-container" style="width: 800px; font-family: Arial, sans-serif;">
#         <div style="margin-bottom: 20px;">
#             <h2 style="color: #333;">{title}</h2>
#             <p style="color: #666;">Analyzing how attention heads encode semantic relationships</p>
#         </div>
        
#         <div style="display: flex; gap: 20px;">
#             <div style="flex: 2;">
#                 <div id="heatmap" style="position: relative;">
#                     <!-- Heatmap will be rendered here -->
#                 </div>
#                 <div style="margin-top: 10px; text-align: center; color: #666;">
#                     <div>Attention Heads</div>
#                 </div>
#                 <div style="position: absolute; left: -30px; top: 50%; transform: rotate(-90deg); color: #666;">
#                     Layers
#                 </div>
#             </div>
            
#             <div style="flex: 1; background: #f5f5f5; padding: 15px; border-radius: 5px;">
#                 <h3 style="margin-top: 0;">How to Read This Visualization</h3>
#                 <ul style="padding-left: 20px; color: #444;">
#                     <li>Each cell represents an attention head</li>
#                     <li>Darker blue indicates stronger relationship encoding</li>
#                     <li>Rows represent layers in the transformer</li>
#                     <li>Columns show attention heads within each layer</li>
#                 </ul>
                
#                 <h3>Key Patterns</h3>
#                 <ul style="padding-left: 20px; color: #444;">
#                     <li>Middle layers often encode more semantic relationships</li>
#                     <li>Some heads specialize in specific relationships</li>
#                     <li>Earlier layers capture simpler relationships</li>
#                     <li>Later layers show more complex patterns</li>
#                 </ul>
#             </div>
#         </div>
#     </div>
#     """

#     js = f"""
#     function createHeatmap() {{
#         const data = {json.dumps(data)};
#         const container = document.getElementById('heatmap');
#         if (!container) {{
#             console.error('Heatmap container not found');
#             return;
#         }}
        
#         data.forEach((row, i) => {{
#             const rowDiv = document.createElement('div');
#             rowDiv.style.display = 'flex';
#             rowDiv.style.height = '20px';
            
#             row.forEach((value, j) => {{
#                 const cell = document.createElement('div');
#                 cell.style.width = '20px';
#                 cell.style.height = '100%';
#                 cell.style.backgroundColor = `rgb(${{Math.floor(255 * (1 - value))}}, ${{Math.floor(255 * (1 - value))}}, 255)`;
#                 cell.style.border = '1px solid white';
#                 cell.title = `Layer ${{i+1}}, Head ${{j+1}}: ${{value.toFixed(3)}}`;
#                 rowDiv.appendChild(cell);
#             }});
            
#             container.appendChild(rowDiv);
#         }});
#     }}

#     createHeatmap();
#     """

#     return HTML(html), Javascript(js)
# # Usage example

from IPython.display import HTML, Javascript

def create_heatmap_html(relation_indices, title="Semantic Relationship Analysis"):
    """Build the markup and the drawing script for a (layers x heads) heatmap.

    Args:
        relation_indices: 2-D array-like with a ``tolist()`` method (e.g. a
            NumPy array). Values are expected in [0, 1] and are clamped to that
            range for colouring (NaN is drawn as 0).
        title: heading shown above the heatmap (HTML-escaped).

    Returns:
        ``(HTML(markup), Javascript(js))``. Display both, in that order, so the
        script can find the container the markup creates.
    """
    import uuid
    from html import escape

    data = relation_indices.tolist()
    # Per-call ids keep re-runs (or several heatmaps on one page) from drawing
    # into another output's container via a shared "heatmap" id.
    uid = uuid.uuid4().hex
    container_id = f"heatmap-{uid}"

    markup = f"""
    <div id="heatmap-container-{uid}" style="width: 800px; font-family: Arial, sans-serif;">
        <div style="margin-bottom: 20px;">
            <h2 style="color: #333;">{escape(title)}</h2>
            <p style="color: #666;">Analyzing how attention heads encode semantic relationships</p>
        </div>
        
        <div style="display: flex; gap: 20px;">
            <!-- Heatmap; position: relative anchors the absolutely positioned "Layers" label here -->
            <div style="flex: 2; position: relative;">
                <div id="{container_id}" style="position: relative;">
                    <!-- Heatmap will be rendered here -->
                </div>
                <div style="margin-top: 10px; text-align: center; color: #666;">
                    <div>Attention Heads</div>
                </div>
                <div style="position: absolute; left: -30px; top: 50%; transform: rotate(-90deg); color: #666;">
                    Layers
                </div>
            </div>
            
            <!-- Explanation Panel -->
            <div style="flex: 1; background: #f5f5f5; padding: 15px; border-radius: 5px;">
                <h3 style="margin-top: 0;">How to Read This Visualization</h3>
                <ul style="padding-left: 20px; color: #444;">
                    <li>Each cell represents an attention head</li>
                    <li>Darker blue indicates stronger relationship encoding</li>
                    <li>Rows represent layers in the transformer</li>
                    <li>Columns show attention heads within each layer</li>
                </ul>
                
                <h3>Key Patterns</h3>
                <ul style="padding-left: 20px; color: #444;">
                    <li>Middle layers often encode more semantic relationships</li>
                    <li>Some heads specialize in specific relationships</li>
                    <li>Earlier layers capture simpler relationships</li>
                    <li>Later layers show more complex patterns</li>
                </ul>
            </div>
        </div>
    </div>
    """

    js = f"""
    (function() {{
        function createHeatmap() {{
            const data = {json.dumps(data)};
            const container = document.getElementById('{container_id}');
            if (!container) {{
                console.error('Heatmap container not found');
                return;
            }}

            data.forEach((row, i) => {{
                const rowDiv = document.createElement('div');
                rowDiv.style.display = 'flex';
                rowDiv.style.height = '20px';

                row.forEach((value, j) => {{
                    // Clamp to [0, 1] (NaN -> 0) so every score maps to a valid colour.
                    const v = Math.min(1, Math.max(0, Number(value) || 0));
                    const shade = Math.floor(255 * (1 - v));
                    const cell = document.createElement('div');
                    cell.style.width = '20px';
                    cell.style.height = '100%';
                    cell.style.backgroundColor = `rgb(${{shade}}, ${{shade}}, 255)`;
                    cell.style.border = '1px solid white';
                    cell.title = `Layer ${{i+1}}, Head ${{j+1}}: ${{Number(value).toFixed(3)}}`;
                    rowDiv.appendChild(cell);
                }});

                container.appendChild(rowDiv);
            }});
        }}

        // Run after a short delay to ensure DOM is ready
        setTimeout(createHeatmap, 100);
    }})();
    """

    return HTML(markup), Javascript(js)





framework = SemanticICLFramework("gpt2")

# Toy corpus: each sentence is annotated with (head, relation, tail) triplets.
examples = [
    dict(text="The pen is used for writing.", triplets=[("pen", "Used-for", "writing")]),
    dict(text="A cat chases a mouse.", triplets=[("cat", "Chases", "mouse")]),
]

relation_indices_list = []

for example in examples:
    # Encode the sentence as a PyTorch batch and run it through the model.
    tokens = framework.tokenizer(
        example["text"], return_tensors="pt", truncation=True, padding=True
    )
    attentions, ov_circuits = framework.analyze_attention_heads(
        tokens["input_ids"], tokens["attention_mask"]
    )

    # Score every (layer, head) for this example's triplets.
    relation_indices = framework.compute_relation_index(
        attentions,
        ov_circuits,
        example["triplets"],
        tokens["input_ids"],
        relation_type="semantic",
    )
    relation_indices_list.append(relation_indices)

# Element-wise mean over examples gives one (layers x heads) map to render.
mean_relation_indices = np.mean(relation_indices_list, axis=0)
framework.visualize_relation_indices(mean_relation_indices)

<IPython.core.display.Javascript object>