In [1]:
import torch
import torch.nn as nn

class Seq2SeqEncoder(nn.Module):
    """LSTM encoder: turns an input sequence into per-step outputs plus final LSTM states."""

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int):
        """
        Args:
            input_dim: Size of each input feature vector
            hidden_dim: Size of the LSTM hidden (and cell) state
            num_layers: Number of stacked LSTM layers
        """
        super().__init__()
        # batch_first=True -> every tensor is laid out as (batch, seq, feature)
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, x):
        """
        Run the stacked LSTM over ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, input_dim)

        Returns:
            ``(encoder_outputs, (h_n, c_n))``: ``encoder_outputs`` is
            (batch_size, seq_len, hidden_dim) holding the last layer's output at
            every step; ``h_n``/``c_n`` are (num_layers, batch_size, hidden_dim).
        """
        return self.lstm(x)


class Seq2SeqDecoder(nn.Module):
    """LSTM decoder followed by a linear read-out to the output feature size."""

    def __init__(self, hidden_dim: int, output_dim: int, num_layers: int):
        """
        Args:
            hidden_dim: Width of the decoder inputs and of the LSTM state
            output_dim: Width of each produced output vector (same as input_dim
                when used for reconstruction)
            num_layers: Number of stacked LSTM layers; must match the encoder so
                its final (h_n, c_n) can be used as the initial state
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x, hidden):
        """
        Decode ``x`` starting from the LSTM state ``hidden``.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_dim)
            hidden: Tuple ``(h_0, c_0)``, e.g. the encoder's final states

        Returns:
            ``(outputs, (h_n, c_n))`` where ``outputs`` is
            (batch_size, seq_len, output_dim).
        """
        decoded, new_hidden = self.lstm(x, hidden)
        return self.fc(decoded), new_hidden


class Seq2SeqModel(nn.Module):
    """Encoder-decoder LSTM assembled from Seq2SeqEncoder and Seq2SeqDecoder."""

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int, output_dim=None):
        """
        Combines encoder and decoder into a seq2seq model.

        Args:
            input_dim: Dimensionality of input vectors
            hidden_dim: Hidden state size shared by encoder and decoder
            num_layers: Number of LSTM layers; used for both halves so the
                encoder's final (h_n, c_n) can initialize the decoder
            output_dim: Dimensionality of output vectors. ``None`` (default)
                means ``input_dim``, i.e. reconstruction. Set it explicitly to
                map into a space of a different width (e.g. 3072 -> 2048).
        """
        super().__init__()
        if output_dim is None:
            output_dim = input_dim
        self.encoder = Seq2SeqEncoder(input_dim, hidden_dim, num_layers)
        self.decoder = Seq2SeqDecoder(hidden_dim, output_dim, num_layers)

    def forward(self, x):
        """
        Args:
            x: Input sequence of shape (batch_size, seq_len, input_dim)
        Returns:
            output: Sequence of shape (batch_size, seq_len, output_dim)
        """
        encoder_outputs, hidden = self.encoder(x)

        # The decoder consumes the encoder's per-step outputs (not a shifted
        # target sequence) and starts from the encoder's final (h_n, c_n).
        output, _ = self.decoder(encoder_outputs, hidden)
        return output

In [2]:
from datasets import load_dataset
def load_data(max_length=512):
    """Return the non-blank wikitext-2 (raw) training lines with at most ``max_length`` words.

    Note that ``max_length`` counts whitespace-separated words, not tokenizer
    tokens.
    """
    print("Loading dataset from HuggingFace...")
    corpus = load_dataset("wikitext", "wikitext-2-raw-v1")

    kept = []
    for line in corpus['train']['text']:
        if not line.strip():
            continue  # blank / whitespace-only line
        if len(line.split()) > max_length:
            continue  # too many words
        kept.append(line)

    print(f"Processed {len(kept)} samples after filtering")
    return kept

In [3]:
# Non-blank wikitext-2 training lines with at most 512 whitespace-separated words
texts = load_data()

Loading dataset from HuggingFace...
Processed 23758 samples after filtering


In [7]:
# Sanity-check the corpus: number of kept lines and one arbitrary sample line
print(len(texts))
print(texts[1000])

23758
 After Deconstruction and Ghost , Townsend announced a new album , Casualties of Cool , with which he started to work after the release of Epicloud . The album features Ché Aimee Dorval ( from Ki ) on vocals and Morgan Ågren on drums . Townsend described the album sounds like " haunted Johnny Cash songs " and " late night music " , highlighting it will be different than anything he has done before . Townsend referred the music of the album to be " closest to his heart " at this point of his life , and that it is an important and satisfying project he doesn 't want to rush . 



In [10]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class LlamaIntermediateLayerExtractor:
    """Capture the output of one decoder layer of a Llama 3.2 1B and a 3B model.

    During :meth:`extract_intermediate_representations` a forward hook is
    attached to ``model.model.layers[layer]`` of each model, and the tensor that
    layer emits is collected for every input chunk. The hooks are always
    removed before the method returns.
    """

    def __init__(
        self,
        model_1b_path,
        model_3b_path,
        tokenizer_1b_path="tokenizers/meta-llama/llama-3.2-1B",
        tokenizer_3b_path="tokenizers/meta-llama/llama-3.2-3B",
        layer_1b=9,
        layer_3b=17,
    ):
        """
        Initialize extractor for Llama 3.2 1b and 3b models

        Args:
            model_1b_path (str): Path or HuggingFace model ID for 1b model
            model_3b_path (str): Path or HuggingFace model ID for 3b model
            tokenizer_1b_path (str): Path or HuggingFace ID for the 1b tokenizer.
                Loaded independently of ``model_1b_path``; when loading models
                from the Hub, pass the same ID so model and tokenizer match.
            tokenizer_3b_path (str): Same as above, for the 3b tokenizer.
            layer_1b (int): 0-based index of the 1b decoder layer to hook
                (default 9, i.e. the 10th layer).
            layer_3b (int): 0-based index of the 3b decoder layer to hook
                (default 17, i.e. the 18th layer).
        """
        self.model_1b = AutoModelForCausalLM.from_pretrained(model_1b_path)
        self.model_3b = AutoModelForCausalLM.from_pretrained(model_3b_path)

        self.tokenizer_1b = AutoTokenizer.from_pretrained(tokenizer_1b_path)
        self.tokenizer_3b = AutoTokenizer.from_pretrained(tokenizer_3b_path)

        self.layer_1b = layer_1b
        self.layer_3b = layer_3b

        # Inference only (e.g. disables dropout)
        self.model_1b.eval()
        self.model_3b.eval()

        # Written by the forward hooks while extracting
        self.intermediate_output_1b = None
        self.intermediate_output_3b = None
        # Hook handles; None whenever no hook is registered
        self.hook_1b = None
        self.hook_3b = None

    @staticmethod
    def _hidden_states(output):
        """Return the hidden-state tensor from a decoder layer's forward output."""
        # Some transformers releases return a tuple (hidden_states, ...) from each
        # decoder layer, others return the tensor itself. Indexing a bare tensor
        # with [0] would silently drop the batch dimension.
        return output[0] if isinstance(output, (tuple, list)) else output

    def _register_1b_hook(self):
        """Register a forward hook on ``model_1b.model.layers[self.layer_1b]``."""
        def hook(module, input, output):
            self.intermediate_output_1b = self._hidden_states(output)

        # Assumes the decoder layers live in model.model.layers (true for Llama)
        target_layer = self.model_1b.model.layers[self.layer_1b]
        self.hook_1b = target_layer.register_forward_hook(hook)

    def _register_3b_hook(self):
        """Register a forward hook on ``model_3b.model.layers[self.layer_3b]``."""
        def hook(module, input, output):
            self.intermediate_output_3b = self._hidden_states(output)

        target_layer = self.model_3b.model.layers[self.layer_3b]
        self.hook_3b = target_layer.register_forward_hook(hook)

    def extract_intermediate_representations(self, text_chunks, max_length=512):
        """
        Extract intermediate representations for given text chunks

        Args:
            text_chunks (list): List of text chunks to process
            max_length (int): Maximum token length to process (longer inputs are
                truncated by the tokenizers)

        Returns:
            tuple: (list of 1b layer outputs, list of 3b layer outputs), one
            detached tensor per chunk whose hook fired, in input order
        """
        self.intermediate_output_1b = None
        self.intermediate_output_3b = None
        self.hook_1b = None
        self.hook_3b = None

        repr_1b_list = []
        repr_3b_list = []

        try:
            # Registered inside the try so a failure registering the second hook
            # (e.g. a model with too few layers) still removes the first one.
            self._register_1b_hook()
            self._register_3b_hook()

            with torch.no_grad():
                for chunk in text_chunks:
                    # Reset per chunk so a hook that does not fire cannot leak the
                    # previous chunk's activations into this chunk's result.
                    self.intermediate_output_1b = None
                    self.intermediate_output_3b = None

                    inputs_1b = self.tokenizer_1b(
                        chunk,
                        return_tensors='pt',
                        truncation=True,
                        max_length=max_length
                    )
                    inputs_3b = self.tokenizer_3b(
                        chunk,
                        return_tensors='pt',
                        truncation=True,
                        max_length=max_length
                    )

                    # The forward passes are only run to trigger the hooks
                    self.model_1b(**inputs_1b)
                    self.model_3b(**inputs_3b)

                    if self.intermediate_output_1b is not None:
                        repr_1b_list.append(self.intermediate_output_1b.detach())

                    if self.intermediate_output_3b is not None:
                        repr_3b_list.append(self.intermediate_output_3b.detach())

        finally:
            # Remove whichever hooks were actually registered
            for handle in (self.hook_1b, self.hook_3b):
                if handle is not None:
                    handle.remove()
            self.hook_1b = None
            self.hook_3b = None

        return repr_1b_list, repr_3b_list

    def align_representations(self, repr_1b, repr_3b, eps=1e-8):
        """
        Standardize paired intermediate representations.

        Each tensor is independently shifted and scaled to zero mean and unit
        standard deviation, computed over all of its elements. No projection
        between the two hidden sizes is performed. A linear projection, PCA or
        embedding matching would be needed for a true alignment.

        Args:
            repr_1b (list): Intermediate representations from 1b model
            repr_3b (list): Intermediate representations from 3b model
            eps (float): Lower bound on the standard deviation, so a constant
                tensor yields zeros instead of NaN

        Returns:
            tuple: (normalized 1b list, normalized 3b list), in input order

        Raises:
            AssertionError: if the two lists differ in length
        """
        assert len(repr_1b) == len(repr_3b), "Representations must be paired"

        def _standardize(t):
            return (t - t.mean()) / t.std().clamp_min(eps)

        aligned_repr_1b = [_standardize(r1) for r1 in repr_1b]
        aligned_repr_3b = [_standardize(r3) for r3 in repr_3b]

        return aligned_repr_1b, aligned_repr_3b

# Example usage
def main():
    """Demo: extract and normalize the hooked-layer activations for three short sentences.

    The model paths below are placeholders; replace them with your actual model paths.
    """
    model_paths = (
        "models/meta-llama/llama-3.2-1B",
        "models/meta-llama/llama-3.2-3B",
    )
    demo_chunks = [
        "This is the first text chunk.",
        "Another interesting piece of text goes here.",
        "And a third chunk for good measure.",
    ]

    extractor = LlamaIntermediateLayerExtractor(*model_paths)
    raw_1b, raw_3b = extractor.extract_intermediate_representations(demo_chunks)
    aligned_1b, aligned_3b = extractor.align_representations(raw_1b, raw_3b)

    # Short summary of what was extracted
    print(f"Number of 1B representations: {len(aligned_1b)}")
    print(f"Number of 3B representations: {len(aligned_3b)}")
    print(f"1B representation shape: {aligned_1b[0].shape}")
    print(f"3B representation shape: {aligned_3b[0].shape}")

# In a Jupyter kernel __name__ is "__main__", so this guard does not prevent
# main() from running: executing this cell loads both models and runs the demo.
if __name__ == "__main__":
    main()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Number of 1B representations: 3
Number of 3B representations: 3
1B representation shape: torch.Size([1, 8, 2048])
3B representation shape: torch.Size([1, 8, 3072])


In [11]:
# Single prompt whose hooked activations are reused by the split-model check below
q = "What is the meaning of life?"
extractor = LlamaIntermediateLayerExtractor(
    model_1b_path="meta-llama/llama-3.2-1B",
    model_3b_path="meta-llama/llama-3.2-3B",
)
repr_1b, repr_3b = extractor.extract_intermediate_representations(text_chunks=[q])

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [13]:
# Separate copy of the 1B model; a later cell slices its model.layers in place so it
# only runs the decoder layers after the extractor's hook point ("_s" presumably = split).
llama_1b_s = AutoModelForCausalLM.from_pretrained("meta-llama/llama-3.2-1B")
llama_1b_s.model.layers

ModuleList(
  (0-15): 16 x LlamaDecoderLayer(
    (self_attn): LlamaSdpaAttention(
      (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
      (k_proj): Linear(in_features=2048, out_features=512, bias=False)
      (v_proj): Linear(in_features=2048, out_features=512, bias=False)
      (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (mlp): LlamaMLP(
      (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
      (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
      (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
      (act_fn): SiLU()
    )
    (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
    (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
  )
)

In [14]:
# Turn llama_1b_s into the second half of the 1B model: keep decoder layers 10..end,
# which consume the output of layers[9] captured by the extractor's hook.
FIRST_KEPT_LAYER = 10
# Guarded so re-running this cell does not slice the already-truncated stack again
# (a second unguarded run would leave no decoder layers at all).
if len(llama_1b_s.model.layers) == llama_1b_s.config.num_hidden_layers:
    llama_1b_s.model.layers = llama_1b_s.model.layers[FIRST_KEPT_LAYER:]

In [15]:
# Unmodified 1B model: full-depth reference for the split model's logits
llama_1b = AutoModelForCausalLM.from_pretrained("meta-llama/llama-3.2-1B")

In [19]:
# Reference forward pass of the full 1B model on the same prompt `q`.
# NOTE(review): this tokenizer is loaded from the Hub ID, while the extractor's 1B
# tokenizer comes from "tokenizers/meta-llama/llama-3.2-1B"; the logits comparison
# assumes both produce identical token ids — confirm.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/llama-3.2-1B")
tokenized = tokenizer.tokenize(q)  # NOTE(review): not used by any visible cell
out = llama_1b(**tokenizer(q, return_tensors="pt"))

In [22]:
# Expected (1, num_prompt_tokens, vocab_size); the recorded run shows (1, 8, 128256)
out.logits.shape

torch.Size([1, 8, 128256])

In [29]:
# Full-model reference logits, to compare against the split model's inter_out.logits
out.logits

tensor([[[ 7.0544,  9.0268, 13.3233,  ..., -3.7595, -3.7596, -3.7597],
         [11.6842,  8.4082,  7.3702,  ..., -0.2775, -0.2776, -0.2776],
         [ 8.2786,  9.0717,  6.5711,  ..., -0.1754, -0.1750, -0.1752],
         ...,
         [ 7.0296, 10.0401,  5.8869,  ..., -0.8870, -0.8870, -0.8872],
         [15.6388, 12.5399, 11.7565,  ...,  0.9495,  0.9491,  0.9493],
         [10.9356,  7.8670,  8.9850,  ..., -0.4730, -0.4730, -0.4729]]],
       grad_fn=<UnsafeViewBackward0>)

In [26]:
# Resume the forward pass from the layers[9] activations captured by the extractor's 1B hook
inter_out = llama_1b_s(inputs_embeds=repr_1b[0])

In [28]:
# The split model (layers 10.. fed the captured layers[9] output) must reproduce the
# full model's logits; check this numerically instead of comparing printouts by eye.
torch.testing.assert_close(inter_out.logits, out.logits, rtol=1e-4, atol=1e-4)
inter_out.logits

tensor([[[ 7.0544,  9.0268, 13.3233,  ..., -3.7595, -3.7596, -3.7597],
         [11.6842,  8.4082,  7.3702,  ..., -0.2775, -0.2776, -0.2776],
         [ 8.2786,  9.0717,  6.5711,  ..., -0.1754, -0.1750, -0.1752],
         ...,
         [ 7.0296, 10.0401,  5.8869,  ..., -0.8870, -0.8870, -0.8872],
         [15.6388, 12.5399, 11.7565,  ...,  0.9495,  0.9491,  0.9493],
         [10.9356,  7.8670,  8.9850,  ..., -0.4730, -0.4730, -0.4729]]],
       grad_fn=<UnsafeViewBackward0>)

In [30]:
# NOTE(review): Seq2SeqModel's decoder output width equals input_dim (3072, the 3B hidden
# size), whereas the 1B hidden states printed above are 2048-wide. If this model is meant
# to map 3B activations into the 1B space, its output width must be 2048 — confirm intent.
seq2seqmodel = Seq2SeqModel(input_dim=3072, hidden_dim=2048, num_layers=1)

In [None]:
for sentence in texts:
