In [1]:
import torch
import torch.nn as nn

class Seq2SeqEncoder(nn.Module):
    """LSTM encoder half of the seq2seq model.

    Wraps a single batch-first ``nn.LSTM`` and exposes both its per-step
    outputs and its final recurrent state.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int):
        """
        Args:
            input_dim: Size of each input feature vector
            hidden_dim: Size of the LSTM hidden state
            num_layers: Number of stacked LSTM layers
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, input_dim)
        Returns:
            A 2-tuple ``(encoder_outputs, hidden)``:
                encoder_outputs: Top-layer output at every time step,
                    shape (batch_size, seq_len, hidden_dim)
                hidden: Tuple ``(h_n, c_n)`` holding the final hidden and cell states
        """
        # nn.LSTM already returns (outputs, (h_n, c_n)); pass it through unchanged.
        return self.lstm(x)


class Seq2SeqDecoder(nn.Module):
    """LSTM decoder half of the seq2seq model.

    Runs a batch-first ``nn.LSTM`` over its input, starting from a supplied
    recurrent state, then projects every time step through a linear layer.
    """

    def __init__(self, hidden_dim: int, output_dim: int, num_layers: int):
        """
        Args:
            hidden_dim: Width of both the decoder input and the LSTM hidden state
            output_dim: Width of each projected output vector
            num_layers: Number of stacked LSTM layers (must equal the number of
                layers in the state passed to ``forward``)
        """
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x, hidden):
        """
        Args:
            x: Decoder input of shape (batch_size, seq_len, hidden_dim)
            hidden: Initial state tuple ``(h_0, c_0)``, e.g. the encoder's ``(h_n, c_n)``
        Returns:
            A 2-tuple ``(outputs, hidden)``:
                outputs: Projected sequence, shape (batch_size, seq_len, output_dim)
                hidden: Final ``(h_n, c_n)`` state of the decoder LSTM
        """
        recurrent_out, final_state = self.lstm(x, hidden)
        return self.fc(recurrent_out), final_state


class Seq2SeqModel(nn.Module):
    """Encoder followed by decoder, mapping one vector sequence onto another.

    The decoder is not autoregressive here: it consumes the encoder's per-step
    outputs as its input sequence and starts from the encoder's final state.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int):
        """
        Args:
            input_dim: Width of each input vector
            hidden_dim: Hidden state size shared by encoder and decoder
            output_dim: Width of each output vector
            num_layers: Number of LSTM layers in both encoder and decoder
        """
        super().__init__()
        self.encoder = Seq2SeqEncoder(
            input_dim=input_dim, hidden_dim=hidden_dim, num_layers=num_layers
        )
        self.decoder = Seq2SeqDecoder(
            hidden_dim=hidden_dim, output_dim=output_dim, num_layers=num_layers
        )

    def forward(self, x):
        """
        Args:
            x: Input sequence of shape (batch_size, seq_len, input_dim)
        Returns:
            Tensor of shape (batch_size, seq_len, output_dim)
        """
        per_step_outputs, final_state = self.encoder(x)

        # The encoder's step outputs double as the decoder's input sequence;
        # the decoder's own final state is not needed by callers.
        decoded, _ = self.decoder(per_step_outputs, final_state)
        return decoded

In [2]:
# Pick the compute device once for the whole notebook.
# Apple's Metal backend stays the first choice (the notebook's original target),
# but fall back to CUDA and then CPU so the cells also run on other machines
# instead of failing on the first `.to(device)`.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [3]:
from datasets import load_dataset
def load_data(max_length=512):
    """Load the WikiText-2 (raw) training split and filter its text lines.

    Args:
        max_length: Maximum number of whitespace-separated words a line may
            contain to be kept (this counts words, not model tokens)

    Returns:
        list: The kept lines, in dataset order. Blank / whitespace-only lines
        and lines longer than ``max_length`` words are dropped.
    """
    print("Loading dataset from HuggingFace...")
    train_split = load_dataset("wikitext", "wikitext-2-raw-v1")['train']

    kept = []
    for line in train_split['text']:
        if not line.strip():
            # Blank separator line.
            continue
        if len(line.split()) > max_length:
            continue
        kept.append(line)

    print(f"Processed {len(kept)} samples after filtering")
    return kept

In [4]:
# Load the filtered WikiText-2 training lines once (default: at most 512 words per line);
# later cells index into this list by position.
texts = load_data()

Loading dataset from HuggingFace...
Processed 23758 samples after filtering


In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class LlamaIntermediateLayerExtractor:
    """Capture one decoder layer's hidden states from each of two causal LMs.

    Forward hooks are attached to a chosen layer of the "1b" and "3b" models
    for the duration of a single extraction call, and removed again afterwards.
    """

    def __init__(
        self,
        model_1b_path,
        model_3b_path,
        tokenizer_1b_path="tokenizers/meta-llama/llama-3.2-1B",
        tokenizer_3b_path="tokenizers/meta-llama/llama-3.2-3B",
        layer_index_1b=9,
        layer_index_3b=17,
        target_device=None,
    ):
        """
        Initialize extractor for Llama 3.2 1b and 3b models

        Args:
            model_1b_path (str): Path or HuggingFace model ID for 1b model
            model_3b_path (str): Path or HuggingFace model ID for 3b model
            tokenizer_1b_path (str): Path or model ID of the 1b tokenizer
            tokenizer_3b_path (str): Path or model ID of the 3b tokenizer
            layer_index_1b (int): 0-based index of the 1b decoder layer to capture
                (default 9, i.e. the 10th layer)
            layer_index_3b (int): 0-based index of the 3b decoder layer to capture
                (default 17, i.e. the 18th layer)
            target_device: Device for models and inputs. When None, the
                notebook-level ``device`` variable is used, so existing
                two-argument calls behave exactly as before.
        """
        self.device = device if target_device is None else target_device
        self.layer_index_1b = layer_index_1b
        self.layer_index_3b = layer_index_3b

        # Load models and tokenizers
        self.model_1b = AutoModelForCausalLM.from_pretrained(model_1b_path).to(self.device)
        self.model_3b = AutoModelForCausalLM.from_pretrained(model_3b_path).to(self.device)

        self.tokenizer_1b = AutoTokenizer.from_pretrained(tokenizer_1b_path)
        self.tokenizer_3b = AutoTokenizer.from_pretrained(tokenizer_3b_path)

        # Set models to evaluation mode
        self.model_1b.eval()
        self.model_3b.eval()

        # Filled in by the forward hooks during extraction
        self.intermediate_output_1b = None
        self.intermediate_output_3b = None

        # Handles of the currently / most recently registered hooks
        self.hook_1b = None
        self.hook_3b = None

    @staticmethod
    def _hidden_states_from(output):
        """Return the hidden-state tensor from a decoder layer's forward output.

        Some layer implementations return a tuple whose first element is the
        hidden states, others return the tensor itself. Indexing a bare tensor
        with [0] would silently drop the batch dimension, so only unpack
        actual tuples/lists.
        """
        if isinstance(output, (tuple, list)):
            return output[0]
        return output

    def _register_1b_hook(self):
        """Register the capture hook on the configured layer of the 1b model"""
        def hook(module, inputs, output):
            self.intermediate_output_1b = self._hidden_states_from(output)

        # Assumes the decoder layers live at model.model.layers; adjust for other architectures
        target_layer = self.model_1b.model.layers[self.layer_index_1b]
        self.hook_1b = target_layer.register_forward_hook(hook)

    def _register_3b_hook(self):
        """Register the capture hook on the configured layer of the 3b model"""
        def hook(module, inputs, output):
            self.intermediate_output_3b = self._hidden_states_from(output)

        # Assumes the decoder layers live at model.model.layers; adjust for other architectures
        target_layer = self.model_3b.model.layers[self.layer_index_3b]
        self.hook_3b = target_layer.register_forward_hook(hook)

    def _remove_hooks(self):
        """Remove whichever hooks have been registered (safe to call repeatedly)."""
        for attr in ("hook_1b", "hook_3b"):
            handle = getattr(self, attr, None)
            if handle is not None:
                handle.remove()

    def extract_intermediate_representations(self, text_chunks, max_length=512):
        """
        Extract intermediate representations for given text chunks

        Args:
            text_chunks (list): List of text chunks to process
            max_length (int): Maximum token length to process (longer inputs are truncated)

        Returns:
            tuple: (intermediate representations for 1b, intermediate representations for 3b).
            Each is a list with one detached tensor per chunk. The tensors stay
            on the model's device, so callers processing many chunks should
            move or consume them promptly.
        """
        repr_1b_list = []
        repr_3b_list = []

        try:
            # Register inside the try block: if the second registration fails,
            # the first hook is still removed by the finally clause.
            self._register_1b_hook()
            self._register_3b_hook()

            for chunk in text_chunks:
                # Reset per chunk so a hook that did not fire can never cause
                # the previous chunk's tensor to be stored a second time.
                self.intermediate_output_1b = None
                self.intermediate_output_3b = None

                # Tokenize for the 1b model
                inputs_1b = self.tokenizer_1b(
                    chunk,
                    return_tensors='pt',
                    truncation=True,
                    max_length=max_length
                ).to(self.device)

                # Tokenize for the 3b model
                inputs_3b = self.tokenizer_3b(
                    chunk,
                    return_tensors='pt',
                    truncation=True,
                    max_length=max_length
                ).to(self.device)

                # Forward pass to trigger hooks; the model outputs themselves are not needed
                with torch.no_grad():
                    self.model_1b(**inputs_1b)
                    self.model_3b(**inputs_3b)

                # Store intermediate representations
                if self.intermediate_output_1b is not None:
                    repr_1b_list.append(self.intermediate_output_1b.detach())

                if self.intermediate_output_3b is not None:
                    repr_3b_list.append(self.intermediate_output_3b.detach())

        finally:
            # Always detach the hooks so they do not fire on later forward passes
            self._remove_hooks()

        return repr_1b_list, repr_3b_list



In [6]:
# MODEL_1B_PATH = "models/meta-llama/llama-3.2-1B"
# MODEL_3B_PATH = "models/meta-llama/llama-3.2-3B"
#
# # Text chunks to process
# text_chunks = [
#     "This is the first text chunk.",
#     "Another interesting piece of text goes here.",
#     "And a third chunk for good measure."
# ]
# extractor = LlamaIntermediateLayerExtractor(MODEL_1B_PATH, MODEL_3B_PATH)
#
# # Extract representations
# repr_1b, repr_3b = extractor.extract_intermediate_representations(text_chunks)
# print(len(repr_1b), len(repr_3b))
# print(repr_1b[0].shape, repr_3b[0].shape)

In [7]:
#print(repr_1b[1].shape, repr_3b[1].shape)

In [8]:
from tqdm import tqdm

In [9]:
import wandb
# Start a Weights & Biases run; the training cell reports its windowed loss to it via wandb.log.
# NOTE: the config below is descriptive metadata only. The training cell hard-codes its own
# learning rate, loss and iteration count, so the two must be kept in sync by hand.
# "epochs": 1000 corresponds to the 1000 single-text optimisation steps of the training loop
# (range(1000), one text per step), not to 1000 passes over the dataset.
wandb.init(
    project="seq2seq interior",
    config= {
        "learning_rate": 0.001,
        "architecture": "LSTM",
        "dataset": "wikitext-2",
        "epochs": 1000,
        "loss": "CosineEmbeddingLoss"
    }
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mskimmer[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Locations passed to AutoModelForCausalLM.from_pretrained for the two Llama 3.2 models.
MODEL_1B_PATH = "models/meta-llama/llama-3.2-1B"
MODEL_3B_PATH = "models/meta-llama/llama-3.2-3B"
# Loads both models (moved to `device`, eval mode) plus their tokenizers; this is the expensive step.
extractor = LlamaIntermediateLayerExtractor(MODEL_1B_PATH, MODEL_3B_PATH)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# Smoke test: run a single prompt through both models and inspect the captured 3B hidden states.
# The variable is called `prompt` rather than `input` so the built-in input() is not shadowed
# for the rest of the kernel session.
prompt = "What is the meaning of life?"
repr_1b, repr_3b = extractor.extract_intermediate_representations([prompt])
print(repr_3b[0])
# Sum of all activations: a cheap scalar fingerprint for comparing runs.
print(repr_3b[0].sum())

tensor([[[-0.0601, -0.0200,  0.8885,  ...,  0.3583,  0.5194,  0.2792],
         [ 0.2183, -0.0816, -0.4049,  ..., -0.2152,  0.1284,  0.4417],
         [-0.1304, -0.2036, -0.3451,  ..., -0.2903,  0.1702, -0.1489],
         ...,
         [ 0.2089, -0.3997, -0.1467,  ..., -0.1077,  0.0199, -0.1851],
         [-0.1151, -0.2551, -0.1698,  ..., -0.0078, -0.0934, -0.1084],
         [ 0.0509, -0.1615,  0.1318,  ..., -0.1788, -0.2703, -0.4363]]],
       device='mps:0')
tensor(-628.8185, device='mps:0')


In [11]:
# Train the LSTM seq2seq model to map the 3B model's captured hidden states onto the
# 1B model's captured hidden states, one text per optimisation step.

NUM_STEPS = 1000   # number of single-text optimisation steps
LOG_EVERY = 5      # log the mean loss of every LOG_EVERY consecutive steps

# input_dim=3072 must match the width of the 3B representations fed in below,
# output_dim=2048 the width of the 1B representations used as the target.
seq2seqmodel = Seq2SeqModel(3072, 2048, 2048, 1).to(device)
#criterion = nn.MSELoss()
criterion = nn.CosineEmbeddingLoss()
optimizer = torch.optim.Adam(seq2seqmodel.parameters(), lr=0.001)

# `extractor` from the earlier cell is reused; constructing it again here would load
# both Llama models a second time.

seq2seqmodel.train()
average_5_loss = 0
# Never index past the end of `texts` if fewer than NUM_STEPS samples survived filtering.
for i in tqdm(range(min(NUM_STEPS, len(texts)))):

    text = [texts[i]]

    # The Llama models are frozen feature extractors: no gradients needed.
    # The returned tensors are already detached.
    with torch.no_grad():
        repr_1b, repr_3b = extractor.extract_intermediate_representations(text)

    optimizer.zero_grad()

    out = seq2seqmodel(repr_3b[0])

    # Flatten to (num_tokens, dim). Unlike a bare .squeeze(), this keeps the token axis
    # even when a text tokenizes to a single position.
    pred = out.reshape(-1, out.size(-1))
    ref = repr_1b[0].reshape(-1, repr_1b[0].size(-1))
    # Target +1 for every token pair: pull each predicted vector towards its reference.
    target = torch.ones(pred.size(0), device=device)

    loss = criterion(pred, ref, target)
    loss.backward()
    optimizer.step()

    average_5_loss += loss.item()
    # (i + 1) so that each logged value really is the mean of LOG_EVERY losses;
    # `i % 5 == 0` fired at i == 0 after one step and divided that single loss by 5.
    if (i + 1) % LOG_EVERY == 0:
        wandb.log({"loss": average_5_loss / LOG_EVERY})
        average_5_loss = 0

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [14:52<00:00,  1.12it/s]


In [11]:
# Display the module tree to confirm the layer sizes the model was built with.
seq2seqmodel

Seq2SeqModel(
  (encoder): Seq2SeqEncoder(
    (lstm): LSTM(3072, 2048, batch_first=True)
  )
  (decoder): Seq2SeqDecoder(
    (lstm): LSTM(2048, 2048, batch_first=True)
    (fc): Linear(in_features=2048, out_features=2048, bias=True)
  )
)

In [12]:
# Save the trained model.
# NOTE: passing the module itself pickles the whole object, not just its weights, so
# loading the file requires Seq2SeqModel / Seq2SeqEncoder / Seq2SeqDecoder to be defined
# (under the same names) first. Saving seq2seqmodel.state_dict() is the more portable
# alternative, but would change the file format.
torch.save(seq2seqmodel, "seq2seqmodel_1000_cosine.pth")

In [13]:
# Spot-check a sample: index 722 is below 1000, so it was one of the texts fed to the training loop.
print(texts[722])

 In the Guerrero district of Chihuahua , Pascual Orozco attacked Federal troops and sent dead soldiers ' clothing back to Díaz with the message , " Ahí te van las hojas , mándame más tamales " ( " Here are the wrappers , send me more tamales . " ) He then began operations which threatened Ciudad Juárez . Additionally , political support for Madero 's rebellion came from Abraham González , who accepted the Plan of San Luis Potosí . 



In [14]:
# Spot-check sample 967: the filter in load_data only drops blank or over-long lines,
# so short section-heading lines are kept as samples too.
print(texts[967])


 = = Biography = = 



In [15]:
# Spot-check samples 914-917 (this cell and the next three): consecutive entries of `texts`.
print(texts[914])


 Since no inscriptions on any of the island have been discovered , the ancient history of the island is conjectural , at best . Pandavas , the heroes of the Hindu epic Mahabharata , and Banasura , the demon devotee of Shiva , are both credited with building temples or cut caves to live . Local tradition holds that the caves are not man @-@ made . 



In [16]:
# Next consecutive sample (915).
print(texts[915])


 The Elephanta caves are " of unknown date and attribution " . Art historians have dated the caves in the range of late 5th to late 8th century AD . Archaeological excavations have unearthed a few Kshatrapa coins dated to 4th century AD . The known history is traced only to the defeat of Mauryan rulers of Konkan by the Badami Chalukyas emperor Pulakesi II ( 609 – 642 ) in a naval battle , in 635 AD . Elephanta was then called Puri or Purika , and served as the capital of the Konkan Mauryas . Some historians attribute the caves to the Konkan Mauryas , dating them to the mid @-@ 6th century , though others refute this claim saying a relatively small kingdom like the Konkan Mauryas could not undertake " an almost superhuman excavation effort , " which was needed to carve the rock temples from solid rock and could not have the skilled labor to produce such " high quality " sculpture . 



In [17]:
# Next consecutive sample (916).
print(texts[916])


 Some other historians attribute the construction to the Kalachuris ( late 5th to 6th century ) , who may have had a feudal relationship with the Konkan Mauryas . In an era where polytheism was prevalent , the Elephanta main cave dedicates the monotheism of the Pashupata Shaivism sect , a sect to which Kalachuris as well as Konkan Mauryas belonged . 



In [18]:
# Next consecutive sample (917).
print(texts[917])


 The Chalukyas , who defeated the Kalachuris as well as the Konkan Mauryas , are also believed by some to be creators of the main cave , in the mid @-@ 7th century . The Rashtrakutas are the last claimants to the creation of the main cave , approximated to the early 7th to late 8th century . The Elephanta Shiva cave resembles in some aspects the 8th @-@ century Rashtrakuta rock @-@ temple Kailash at Ellora . The Trimurti of Elephanta showing the three faces of Shiva is akin to the Trimurti of Brahma , Vishnu and Mahesh ( Shiva ) , which was the royal insignia of the Rashtrakutas . The Nataraja and Ardhanarishvara sculptures are also attributed to the Rashtrakutas . 

