In [1]:
import torch
import torch.nn as nn
from torch.onnx.symbolic_opset9 import cosine_similarity

In [2]:
class TransformerEmbedder(nn.Module):
    def __init__(self, input_dim, output_dim, layers):
        super(TransformerEmbedder, self).__init__()
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=8, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=layers)
        self.activation = nn.GELU()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        mask = nn.Transformer().generate_square_subsequent_mask(x.shape[1])
        #print(mask.shape)
        x = self.transformer_encoder(x, mask=mask, is_causal=True)
        x = self.activation(x)
        x = self.fc(x)

        return x

In [3]:
device = 'mps'

In [4]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class LlamaIntermediateLayerExtractor:
    def __init__(self, model_1b_path, model_3b_path):
        """
        Initialize extractor for Llama 3.2 1b and 3b models

        Args:
            model_1b_path (str): Path or HuggingFace model ID for 1b model
            model_3b_path (str): Path or HuggingFace model ID for 3b model
        """
        # Load models and tokenizers
        self.model_1b = AutoModelForCausalLM.from_pretrained(model_1b_path).to(device)
        self.model_3b = AutoModelForCausalLM.from_pretrained(model_3b_path).to(device)

        self.tokenizer_1b = AutoTokenizer.from_pretrained("tokenizers/meta-llama/llama-3.2-1B")
        self.tokenizer_3b = AutoTokenizer.from_pretrained("tokenizers/meta-llama/llama-3.2-3B")

        # Set models to evaluation mode
        self.model_1b.eval()
        self.model_3b.eval()

        # Hooks to capture intermediate layer outputs
        self.intermediate_output_1b = None
        self.intermediate_output_3b = None

    def _register_1b_hook(self):
        """Register hook for 1b model's 10th layer"""
        def hook(module, input, output):
            self.intermediate_output_1b = output[0]  # Typically, first element is hidden states

        # Assuming transformer layers are in model.model.layers or similar
        # You might need to adjust this path based on your specific model structure
        target_layer = self.model_1b.model.layers[9]  # 0-indexed, so 10th layer is index 9
        self.hook_1b = target_layer.register_forward_hook(hook)

    def _register_3b_hook(self):
        """Register hook for 3b model's 18th layer"""
        def hook(module, input, output):
            self.intermediate_output_3b = output[0]

        # Adjust this path based on your specific model structure
        target_layer = self.model_3b.model.layers[17]  # 0-indexed, so 18th layer is index 17
        self.hook_3b = target_layer.register_forward_hook(hook)

    def extract_intermediate_representations(self, text_chunks, max_length=512):
        """
        Extract intermediate representations for given text chunks

        Args:
            text_chunks (list): List of text chunks to process
            max_length (int): Maximum token length to process

        Returns:
            tuple: (intermediate representations for 1b, intermediate representations for 3b)
        """
        # Reset intermediate outputs
        self.intermediate_output_1b = None
        self.intermediate_output_3b = None

        # Register hooks
        self._register_1b_hook()
        self._register_3b_hook()

        # Prepare to collect representations
        repr_1b_list = []
        repr_3b_list = []

        try:
            for chunk in text_chunks:
                # Tokenize and process 1b model
                inputs_1b = self.tokenizer_1b(
                    chunk,
                    return_tensors='pt',
                    truncation=True,
                    max_length=max_length
                ).to(device)

                # Tokenize and process 3b model
                inputs_3b = self.tokenizer_3b(
                    chunk,
                    return_tensors='pt',
                    truncation=True,
                    max_length=max_length
                ).to(device)

                # Forward pass to trigger hooks
                with torch.no_grad():
                    _ = self.model_1b(**inputs_1b)
                    _ = self.model_3b(**inputs_3b)

                # Store intermediate representations
                if self.intermediate_output_1b is not None:
                    repr_1b_list.append(self.intermediate_output_1b.detach())

                if self.intermediate_output_3b is not None:
                    repr_3b_list.append(self.intermediate_output_3b.detach())

        finally:
            # Remove hooks to prevent memory leaks
            self.hook_1b.remove()
            self.hook_3b.remove()

        return repr_1b_list, repr_3b_list



In [5]:
MODEL_1B_PATH = "models/meta-llama/llama-3.2-1B"
MODEL_3B_PATH = "models/meta-llama/llama-3.2-3B"
extractor = LlamaIntermediateLayerExtractor(MODEL_1B_PATH, MODEL_3B_PATH)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
import wandb
from tqdm import tqdm
wandb.init(
    project="seq2seq interior",
    config= {
        "learning_rate": 0.001,
        "architecture": "transformer",
        "dataset": "wikitext-2",
        "epochs": 1000,
        "loss": "cosine"
    }
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mskimmer[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:
from datasets import load_dataset
def load_data(max_length=512):
    print("Loading dataset from HuggingFace...")
    dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_texts = dataset['train']['text']

    filtered_texts = [text for text in train_texts if text.strip() and len(text.split()) <= max_length]
    print(f"Processed {len(filtered_texts)} samples after filtering")

    return filtered_texts
texts = load_data()

Loading dataset from HuggingFace...
Processed 23758 samples after filtering


In [8]:
import wandb
from tqdm import tqdm
wandb.init(
    project="seq2seq interior",
    config= {
        "learning_rate": 0.001,
        "architecture": "transformer",
        "dataset": "wikitext-2",
        "chunks": 1000,
        "loss": "MSE"
    }
)

In [10]:
embedder = TransformerEmbedder(3072, 2048, 2).to(device)
criterion = nn.MSELoss()
# criterion = nn.CosineEmbeddingLoss()
optimizer = torch.optim.Adam(embedder.parameters(), lr=0.001)
embedder.train()
losses = []

for text in tqdm(texts[:1000]):
    repr_1b, repr_3b = extractor.extract_intermediate_representations([text])
    repr_1b = repr_1b[0].to(device)
    repr_3b = repr_3b[0].to(device)

    optimizer.zero_grad()
    output = embedder(repr_3b)
    loss = criterion(output.squeeze(), repr_1b.squeeze())
    loss.backward()
    optimizer.step()

    wandb.log({"loss": loss.item()})
    losses.append(loss.item())

100%|██████████| 1000/1000 [11:07<00:00,  1.50it/s]


In [11]:
torch.save(embedder, "transformer_embedder_1000.pth")

In [10]:
embedder.eval()
repr_1b, repr_3b = extractor.extract_intermediate_representations([texts[999]])
repr_1b = repr_1b[0].to(device)
repr_3b = repr_3b[0].to(device)
output = embedder(repr_3b)
print(output.shape)

torch.Size([1, 22, 2048])


In [11]:
print(repr_1b.shape)

torch.Size([1, 22, 2048])


In [12]:
print(output)

tensor([[[-1.0502e+00, -2.4207e-01,  2.2764e+00,  ...,  6.0472e-01,
           1.2637e+00,  8.9992e-01],
         [-4.2429e-02, -1.3500e-02, -1.2751e-01,  ..., -4.1418e-02,
          -5.6093e-02, -2.6518e-04],
         [-4.2445e-02, -1.3496e-02, -1.2750e-01,  ..., -4.1411e-02,
          -5.6087e-02, -2.5539e-04],
         ...,
         [-4.2491e-02, -1.3516e-02, -1.2748e-01,  ..., -4.1416e-02,
          -5.6071e-02, -2.3512e-04],
         [-4.2491e-02, -1.3516e-02, -1.2748e-01,  ..., -4.1416e-02,
          -5.6071e-02, -2.3514e-04],
         [-4.2491e-02, -1.3516e-02, -1.2748e-01,  ..., -4.1416e-02,
          -5.6071e-02, -2.3513e-04]]], device='mps:0',
       grad_fn=<LinearBackward0>)


In [13]:
print(repr_1b)

tensor([[[ 0.1367, -0.1160,  0.9267,  ...,  0.4339,  0.7350,  0.5164],
         [-0.1310, -0.1256, -0.2124,  ...,  0.1703, -0.1001,  0.0895],
         [-0.2141, -0.0653, -0.1952,  ..., -0.0515, -0.0457,  0.0152],
         ...,
         [-0.1481, -0.0385, -0.1986,  ..., -0.0626, -0.0716, -0.0598],
         [ 0.0657, -0.0803, -0.0886,  ..., -0.0153, -0.2196, -0.1318],
         [ 0.0767, -0.1609, -0.0287,  ...,  0.0110, -0.0201,  0.0609]]],
       device='mps:0')


In [14]:
print(criterion(output, repr_1b))

tensor(0.0655, device='mps:0', grad_fn=<MseLossBackward0>)


In [15]:
wandb.finish()

0,1
loss,▂▁▂▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
loss,0.0772
