In [1]:
from datasets import load_from_disk
import torch
import torch.nn as nn

In [8]:
import torch

In [7]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3.5-MoE-instruct", trust_remote_code=True)

Encountered exception while importing flash_attn: No module named 'flash_attn'


ImportError: This modeling file requires the following packages that were not found in your environment: flash_attn. Run `pip install flash_attn`

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

llama_1b = AutoModelForCausalLM.from_pretrained("models/meta-llama/llama-3.2-1B")
llama_3b = AutoModelForCausalLM.from_pretrained("models/meta-llama/llama-3.2-3B")
llama_tokenizer = AutoTokenizer.from_pretrained("tokenizers/meta-llama/llama-3.2-1B")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
wikitext2 = load_from_disk('data/wikitext-2')

In [4]:
llama_1b_truncated = AutoModelForCausalLM.from_pretrained("models/meta-llama/llama-3.2-1B")
llama_1b_truncated.model.layers = llama_1b_truncated.model.layers[10:]

In [5]:

class Encoder(nn.Module):
    def __init__(self, input_dim: int, hidden_dims: list[int], latent_dim: int):
        """
        Args:
            input_dim: Dimension of input vector
            hidden_dims: List of hidden layer dimensions
            latent_dim: Dimension of latent space (encoded vector)
        """
        super().__init__()

        # Build encoder layers
        layers = []
        prev_dim = input_dim

        # Add hidden layers
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim)
            ])
            prev_dim = hidden_dim

        # Add final layer to latent dimension
        layers.append(nn.Linear(prev_dim, latent_dim))

        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.encoder(x)


class Decoder(nn.Module):
    def __init__(self, latent_dim: int, hidden_dims: list[int], output_dim: int):
        """
        Args:
            latent_dim: Dimension of latent space (encoded vector)
            hidden_dims: List of hidden layer dimensions (in reverse order of encoder)
            output_dim: Dimension of output vector
        """
        super().__init__()

        # Build decoder layers
        layers = []
        prev_dim = latent_dim

        # Add hidden layers
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.ReLU(),
                nn.BatchNorm1d(hidden_dim)
            ])
            prev_dim = hidden_dim

        # Add final layer to output dimension
        layers.append(nn.Linear(prev_dim, output_dim))

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.decoder(x)

class VectorEncoderDecoder(nn.Module):
    def __init__(self, input_dim: int, hidden_dims: list[int], latent_dim: int, output_dim: int):
        """
        Args:
            input_dim: Dimension of input vector
            hidden_dims: List of hidden layer dimensions
            latent_dim: Dimension of latent space (encoded vector)
            output_dim: Dimension of output vector
        """
        super().__init__()

        self.encoder = Encoder(input_dim, hidden_dims, latent_dim)
        self.decoder = Decoder(latent_dim, hidden_dims[::-1], output_dim)

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [6]:
class inner_translation_model(nn.Module):
    def __init__(self, src_model, tgt_model, translation_model, tgt_layer, src_layer):
        super().__init__()
        tgt_model.model.layers = tgt_model.model.layers[tgt_layer:]

        self.src_model = src_model
        self.tgt_model = tgt_model
        self.translation_model = translation_model
        self.src_layer = src_layer
        # freeze source, target models
        for param in self.src_model.parameters():
            param.requires_grad = False
        for param in self.tgt_model.parameters():
            param.requires_grad = False

    def forward(self, x):
        model_outs = self.src_model(x, output_hidden_states=True)
        model_outs_src_layer = model_outs.hidden_states[self.src_layer]
        model_outs_src_layer = model_outs_src_layer.squeeze()
        #model_outs_src_layer = model_outs_src_layer.permute(1,0,2)
        print(model_outs_src_layer.shape)
        translation_outs = self.translation_model(model_outs_src_layer)
        translation_outs = translation_outs.unsqueeze(0)
        logits = self.tgt_model(inputs_embeds = translation_outs, use_cache = False).logits
        return translation_outs, logits





In [7]:
ved = VectorEncoderDecoder(3072, [1024, 512], 256, 2048)
itm = inner_translation_model(llama_3b, llama_1b, ved, 10, 18)

In [8]:
import pickle
with open ('data/chunked_wikitext2/train.pkl', 'rb') as f:
    train_data = pickle.load(f)


In [9]:
from transformers import AutoTokenizer
llama_tokenizer = AutoTokenizer.from_pretrained("tokenizers/meta-llama/llama-3.2-1B")
input = llama_tokenizer(train_data[0], return_tensors="pt")

In [10]:
print(input["input_ids"].shape)

torch.Size([1, 4446])


In [11]:
out = itm(input["input_ids"])

torch.Size([4446, 3072])


In [13]:
out[0].shape

torch.Size([1, 4446, 2048])

In [26]:
def loss_fn(intermediate_pred, intermediate_tgt, logits_pred, logits_tgt):
    cosine_loss = nn.CosineEmbeddingLoss()
    cosine_loss_out = cosine_loss(
        intermediate_pred.squeeze(),
        intermediate_tgt.squeeze(),
        torch.ones(intermediate_pred.shape[1]).to(device)
    )

    kl_div = nn.KLDivLoss()
    kl_div_out = kl_div(
        logits_pred.log_softmax(dim=-1),
        logits_tgt.softmax(dim=-1)
    )
    return cosine_loss_out + kl_div_out



In [15]:
import torch.optim as optim
optimizer = optim.Adam(itm.parameters(), lr=0.001)
device = "mps"
epochs = 2

In [16]:
llama_1b = AutoModelForCausalLM.from_pretrained("models/meta-llama/llama-3.2-1B")
llama_1b.to(device)
itm.to(device)

inner_translation_model(
  (src_model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(128256, 3072)
      (layers): ModuleList(
        (0-27): 28 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
            (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
            (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
            (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
            (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
            (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((3072,), eps=1e-05)
          (post_

In [30]:
max_len = 1000
loss_sum = 0
for epoch in range(epochs):
    for i, text in enumerate(train_data):
        input = llama_tokenizer(text, max_length=max_len, return_tensors="pt", truncation=True)

        #input = llama_tokenizer(text, return_tensors="pt")
        input.to(device)
        optimizer.zero_grad()
        intermediate_pred, logits_pred = itm(input["input_ids"])
        true_out = llama_1b(input["input_ids"], output_hidden_states=True)
        intermediate_tgt = true_out.hidden_states[10]
        logits_tgt = true_out.logits
        loss = loss_fn(intermediate_pred, intermediate_tgt, logits_pred, logits_tgt)
        loss_sum += loss.item()

        loss.backward()
        optimizer.step()
        if i % 10 == 0:
            print(f"Epoch: {epoch}, Loss: {loss_sum / 10}")
            loss_sum = 0

torch.Size([1000, 3072])
Epoch: 0, Loss: 0.0358019083738327
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([833, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([13, 3072])
torch.Size([13, 3072])
torch.Size([33, 3072])
torch.Size([18, 3072])
torch.Size([12, 3072])
Epoch: 0, Loss: 0.3406906008720398
torch.Size([27, 3072])
torch.Size([573, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([782, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
Epoch: 0, Loss: 0.3396680772304535
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
Epoch: 0, Loss: 0.33113611936569215
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])
torch.Size([1000, 3072])

RuntimeError: MPS backend out of memory (MPS allocated: 36.92 GB, other allocations: 23.83 GB, max allowed: 61.20 GB). Tried to allocate 489.26 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [31]:
translator = itm.translation_model
torch.save(translator, 'models/translator_later.pth')