In [1]:
import torch
from torch import nn
from torch.nn import functional as F

from typing import Dict, List, Optional, Tuple, Union
import random
import math
import re
import time

from transformer_with_hidden import *

## Lens

In [2]:
class Lens:
    """
    Implementation of the logit and tuned lens method for visualizing intermediate layer predictions
    in transformer models.
    
    The logit and tuned lens allows us to decode hidden states at each layer of a transformer
    using the unembedding matrix to observe how predictions evolve through the network.
    """
    
    def __init__(self, model: nn.Module):
        """
        Initialize the LogitLens.
        
        Args:
            model: The transformer model
        """
        self.model = model
        self.unembed_weight = self.model.decoder.weight # weight of the Linear module in the implementation of the course
    
    def logit_lens(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """
        Apply the standard logit lens to a hidden state.
        
        Args:
            hidden_state: Hidden state from an intermediate layer
            
        Returns:
            Logits obtained by projecting the hidden state through the unembedding matrix
        """
        # Project through the unembedding matrix
        logits = F.linear(hidden_state, self.unembed_weight)
        return logits
    
    def tuned_lens(self, 
                   hidden_state: torch.Tensor, 
                   translator: nn.Module) -> torch.Tensor:
        """
        Apply the tuned logit lens to a hidden state using a learned translator.
        
        Args:
            hidden_state: Hidden state from an intermediate layer
            translator: A learned affine transformation module
            
        Returns:
            Logits obtained by applying the translator and then projecting through
            the unembedding matrix
        """
        # Apply the translator
        translated_state = translator(hidden_state)
        
        # Project through the unembedding matrix
        logits = F.linear(translated_state, self.unembed_weight)
        return logits

In [3]:
class TranslatorModule(nn.Module):
    """A learned affine transformation for the tuned lens."""
    
    def __init__(self, hidden_size: int):
        """
        Initialize the translator module.
        
        Args:
            hidden_size: Size of the hidden state
        """
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the translator to a hidden state."""
        return self.linear(x)

In [4]:
class TransformerWithLens(nn.Module):
    """
    A wrapper for a transformer model that captures intermediate hidden states
    and applies the lens method.
    """
    
    def __init__(self, 
                 transformer_model: nn.Module, 
                 num_layers: int,
                 hidden_size: int,
                 
                 use_tuned_lens: bool = False):
        """
        Initialize the wrapper.
        
        Args:
            transformer_model: The transformer model to wrap
            num_layers: Number of layers in the transformer
            hidden_size: Dimensionality of the hidden states
            use_tuned_lens: Whether to use the tuned lens with learned translators
        """
        super().__init__()
        self.transformer = transformer_model
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.use_tuned_lens = use_tuned_lens
            
        # Initialize the logit lens
        self.logit_lens = Lens(transformer_model)
        
        # Initialize one translator per layer for tuned lens if needed
        if use_tuned_lens:
            self.translators = nn.ModuleList([
                TranslatorModule(hidden_size) for _ in range(num_layers)
            ])
    
    def forward(self, 
                inputs: Dict[str, torch.Tensor]
                ) -> Dict[str, Union[torch.Tensor, List[torch.Tensor]]]:
        """
        Forward pass that includes logit lens visualizations.
        
        Args:
            inputs: Input tensors for the transformer model
            
        Returns:
            A dictionary containing:
                - 'output': The original model output
                - 'logit_lens_outputs': List of logit lens outputs for each layer
                - 'tuned_lens_outputs': List of tuned lens outputs for each layer (if enabled)
        """
        # Get the original model output and the hidden states
        outputs, hidden_states = self.transformer(inputs)
                
        # Apply logit lens to each hidden state
        logit_lens_outputs = [
            self.logit_lens.logit_lens(hidden_state)
            for hidden_state in hidden_states
        ]
        
        # Apply tuned lens if enabled
        tuned_lens_outputs = None
        if self.use_tuned_lens:
            tuned_lens_outputs = [
                self.logit_lens.tuned_lens(hidden_states[i], self.translators[i])
                for i in range(self.num_layers)
            ]
        
        return {
            'output': outputs,
            'logit_lens_outputs': logit_lens_outputs,
            'tuned_lens_outputs': tuned_lens_outputs
        }

In [5]:
def visualize_predictions(logits: torch.Tensor, 
                          tokenizer, 
                          top_k: int = 5) -> List[Tuple[str, float]]:
    """
    Visualize top-k predictions from logits.
    
    Args:
        logits: The logits tensor
        tokenizer: The tokenizer for converting token IDs to strings
        top_k: Number of top predictions to return
        
    Returns:
        List of (token, probability) tuples for the top-k predictions
    """
    # Get probabilities
    probs = F.softmax(logits, dim=-1)
    
    # Get top-k predictions
    values, indices = torch.topk(probs, top_k)
    indices, values = indices.flatten().tolist(), values.flatten().tolist()
    
    # Convert to list of (token, probability) tuples
    predictions = []
    for i, idx in enumerate(indices):
        token = tokenizer.decode([idx])
        probability = values[i]
        predictions.append((token, probability))
    
    return predictions

In [6]:
def demonstrate_logit_lens(model, tokenizer, input_text: str):
    """
    Demonstrate the logit lens by visualizing predictions for all intermediate layers.
    
    Args:
        model: The transformer model
        tokenizer: The tokenizer
        input_text: The input text to process
    """
    # Prepare inputs
    inputs = torch.tensor(tokenizer.encode(input_text)).view((-1,1))
    
    # Create the wrapper
    wrapper = TransformerWithLens(model, num_layers=len(model.encoder.layers), hidden_size=model.ninp)
    
    # Forward pass
    outputs = wrapper(inputs)
    
    # Get logit lens outputs
    logit_lens_outputs = outputs['logit_lens_outputs']

    # Visualize predictions for all layers and the final prediction
    for i, layer_outputs in enumerate(logit_lens_outputs):
        predictions = visualize_predictions(layer_outputs[-1, :, :].squeeze(1), tokenizer)
        print(f"Layer {i+1} predictions:")
        for token, prob in predictions:
            print(f"  {token}: {prob:.4f}")
            
    predictions = visualize_predictions(outputs['output'][-1, :, :].squeeze(1), tokenizer)
    print(f"Final prediction:")
    for token, prob in predictions:
        print(f"  {token}: {prob:.4f}")

In [15]:
def demonstrate_tuned_lens(model, tokenizer, input_text: str):
    """
    Demonstrate the tuned lens by visualizing predictions for all intermediate layers.
    
    Args:
        model: The transformer model
        tokenizer: The tokenizer
        input_text: The input text to process
    """
    # Prepare inputs
    inputs = torch.tensor(tokenizer.encode(input_text)).view((-1,1))
    
    # Create the wrapper, attention pour les tuned lens ça ne maarchera pas comme ça, 
    # il faudra en entraîner un
    wrapper = TransformerWithLens(model, num_layers=len(model.encoder.layers), hidden_size=model.ninp, use_tuned_lens=True)
    
    # Forward pass
    outputs = wrapper(inputs)
    
    # Get tuned lens outputs
    tuned_lens_outputs = outputs['tuned_lens_outputs']

    # Visualize predictions for all layers and the final prediction
    for i, layer_outputs in enumerate(tuned_lens_outputs):
        predictions = visualize_predictions(layer_outputs[-1, :, :].squeeze(1), tokenizer)
        print(f"Layer {i+1} predictions:")
        for token, prob in predictions:
            print(f"  {token}: {prob:.4f}")
            
    predictions = visualize_predictions(outputs['output'][-1, :, :].squeeze(1), tokenizer)
    print(f"Final prediction:")
    for token, prob in predictions:
        print(f"  {token}: {prob:.4f}")

## Utility functions for tokenization

In [7]:
pad_token="[PAD]"
eos_token="[EOS]"

In [8]:
class character_level_tokenizer:
    """
    character-level
    """
    def __init__(self):
        self.vocab = [str(x) for x in range(10)] + ["+", "="] + [pad_token, eos_token]
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"
    
    def clean(self, text):
        """
        removes all characters not in the vocabulary
        """
        out = re.sub(self.pattern, "", text)
        return out

    def pre_tokenization(self, text):
        """
        character-level
        """
        return [c for c in text]

    def encode(self, text):
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        return "".join([self.id_to_token[x] for x in token_list])

In [9]:
tokenizer = character_level_tokenizer()

## Example usage

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [11]:
model = torch.load('arithmetic.pt', weights_only=False, map_location='cpu')
model.to(device)
model.eval()

TransformerModelWithHidden(
  (encoder): TransformerEncoderWithHidden(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): Linear(in_features=128, out_features=14, bias=True)
  (input_emb): Embedding(14, 128)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [26]:
demonstrate_logit_lens(model, tokenizer, input_text='123+182=')

Layer 1 predictions:
  [PAD]: 0.2270
  7: 0.1506
  8: 0.0996
  [EOS]: 0.0832
  9: 0.0672
Layer 2 predictions:
  7: 0.3376
  8: 0.1695
  [PAD]: 0.0855
  6: 0.0848
  1: 0.0807
Layer 3 predictions:
  7: 0.2038
  8: 0.1282
  6: 0.1246
  9: 0.1180
  1: 0.1065
Layer 4 predictions:
  7: 0.1626
  8: 0.1480
  6: 0.1340
  9: 0.1031
  3: 0.0791
Layer 5 predictions:
  4: 0.2591
  3: 0.1762
  2: 0.0974
  6: 0.0768
  5: 0.0687
Layer 6 predictions:
  3: 0.4060
  2: 0.2325
  4: 0.1261
  1: 0.0725
  9: 0.0508
Layer 7 predictions:
  2: 0.6465
  3: 0.3199
  1: 0.0130
  4: 0.0059
  9: 0.0046
Layer 8 predictions:
  2: 0.4921
  3: 0.4891
  1: 0.0089
  4: 0.0070
  9: 0.0008
Final prediction:
  3: 0.4946
  2: 0.4900
  1: 0.0082
  4: 0.0055
  9: 0.0005


In [30]:
# untrained attempt of tuned lens, just to wheck it works
demonstrate_tuned_lens(model, tokenizer, input_text='123+182=')

Layer 1 predictions:
  7: 0.1528
  =: 0.1143
  6: 0.1049
  5: 0.1047
  8: 0.0851
Layer 2 predictions:
  1: 0.1070
  7: 0.1036
  +: 0.0983
  0: 0.0941
  2: 0.0925
Layer 3 predictions:
  [PAD]: 0.1261
  5: 0.1106
  6: 0.1099
  4: 0.0989
  0: 0.0835
Layer 4 predictions:
  9: 0.1484
  8: 0.1175
  7: 0.1028
  6: 0.0923
  5: 0.0858
Layer 5 predictions:
  =: 0.1911
  +: 0.1415
  [EOS]: 0.0878
  7: 0.0740
  9: 0.0735
Layer 6 predictions:
  9: 0.1438
  [PAD]: 0.1322
  =: 0.1047
  8: 0.1047
  +: 0.0992
Layer 7 predictions:
  =: 0.1707
  +: 0.1541
  3: 0.1203
  [EOS]: 0.1080
  2: 0.0884
Layer 8 predictions:
  [EOS]: 0.2078
  2: 0.1442
  1: 0.1148
  0: 0.0867
  3: 0.0844
Final prediction:
  3: 0.4946
  2: 0.4900
  1: 0.0082
  4: 0.0055
  9: 0.0005


In [31]:
wrapper = TransformerWithLens(model, num_layers=len(model.encoder.layers), hidden_size=model.ninp, use_tuned_lens=True)

In [33]:
wrapper.translators

ModuleList(
  (0-7): 8 x TranslatorModule(
    (linear): Linear(in_features=128, out_features=128, bias=True)
  )
)