In [None]:
import torch
import pandas as pd
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    BitsAndBytesConfig,
    MixtralModel,
    MixtralConfig,
    AutoModelForCausalLM
)
from torch import nn
import collections
import time
import os
from einops import rearrange
import torch.nn.functional as F

In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Tokenizer for the same Mixtral checkpoint that the model cell loads.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1")

In [None]:
# Decoder layer whose MoE block is inspected; the model cell keeps only the
# first `layer + 1` decoder layers.
layer = 15

# Hook kinds understood by getActivation: W2 captures an expert's w2 output,
# GATE captures the router logits.
W2, GATE = 0, 1

# Tokenisation and batching.
sequence_length = 128
batch_size = 6

# MoE routing and expert sizes.
num_experts = 8
num_experts_per_token = 2
expert_hidden_dims = 4096
expert_ffn_dims = 14336

In [None]:
# Load Mixtral in 4-bit (with double quantisation), truncated to the first
# `layer + 1` decoder layers; `layer` is the deepest one the notebook hooks.
checkpoint = "mistralai/Mixtral-8x7B-v0.1"

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

# Start from the checkpoint's own config so every architecture field matches
# the stored weights; a bare MixtralConfig(...) would use library defaults for
# anything not passed here. Only routing top-k and depth are overridden.
config = MixtralConfig.from_pretrained(
    checkpoint,
    num_experts_per_tok=num_experts_per_token,
    num_hidden_layers=layer + 1,
)

model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    quantization_config=double_quant_config,
    attn_implementation="flash_attention_2",
    # FlashAttention-2 requires fp16/bf16; make the dtype explicit.
    torch_dtype=torch.float16,
    config=config,
)

In [None]:
def load_generated_dataset(
    filename, batch_size, dataset_relative_path="../dataset", column="0"
):
    """Load one text column of a CSV file and wrap it in a batching DataLoader.

    Args:
        filename: CSV file name inside ``dataset_relative_path``.
        batch_size: number of texts per batch.
        dataset_relative_path: directory containing the CSV.
        column: header of the column that holds the texts (default ``"0"``).

    Returns:
        A ``torch.utils.data.DataLoader`` whose batches are lists of ``str``.
    """
    frame = pd.read_csv(os.path.join(dataset_relative_path, filename))
    # pandas turns empty cells into NaN floats and all-numeric cells into
    # numbers; drop the missing rows and force the rest to text before they
    # reach the tokenizer.
    texts = frame[column].dropna().astype(str).tolist()
    return torch.utils.data.DataLoader(texts, batch_size=batch_size)


# Text batches read from ../dataset/pile.csv, `batch_size` rows per batch.
dataset = load_generated_dataset(filename="pile.csv", batch_size=batch_size)

In [None]:
import torch
from torch import nn
from tqdm import tqdm

class Autoencoder(nn.Module):
    """Single-hidden-layer autoencoder that reconstructs its input rows.

    Args:
        input_dim: width of each input row, and of the reconstruction.
        hidden_dim: width of the latent code.
    """

    def __init__(self, input_dim=8 * 4096, hidden_dim=4096):
        super().__init__()
        self.encoder = nn.Linear(input_dim, hidden_dim)
        self.decoder = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map ``(..., input_dim)`` inputs to non-negative ``(..., hidden_dim)`` codes."""
        return torch.relu(self.encoder(x))

    def forward(self, x):
        """Return the reconstruction of ``x``; same shape as the input."""
        return self.decoder(self.encode(x))


# Sized to the flattened activations the training cell builds: one row per
# token with num_experts * expert_hidden_dims columns.
autoencoder = Autoencoder(
    input_dim=num_experts * expert_hidden_dims,
    hidden_dim=expert_hidden_dims,  # NOTE(review): latent size is a placeholder choice; tune it
).to(device)
# Both are used by the training step below, so they must be defined.
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# For decoder layer `layer`, capture every expert's routing-weighted output per
# token, flatten them into one row per token, and take one autoencoder training
# step on that matrix for each processed batch.
MAX_BATCHES = 1  # number of batches from `dataset` to process

moe = model.model.layers[layer].block_sparse_moe
hooks = []
mlps = [None] * num_experts  # per-expert w2 outputs written by the W2 hooks
router_logits = None  # router logits written by the GATE hook


def getActivation(expert_idx, type):
    """Return a forward hook that stores a detached copy of the module output.

    Args:
        expert_idx: slot of ``mlps`` written by a W2 hook (unused for GATE).
        type: W2 stores the output into ``mlps[expert_idx]``;
            GATE stores it into the global ``router_logits``.
    """
    def hook(module, input, output):
        global router_logits
        if type == W2:
            mlps[expert_idx] = output.detach()
        elif type == GATE:
            router_logits = output.detach()
    return hook


# padding="max_length" below needs a pad token; reuse EOS. Set once, not per batch.
tokenizer.pad_token = tokenizer.eos_token

# The gate is a single module, so it gets exactly one hook; each expert's w2 gets its own.
hooks.append(moe.gate.register_forward_hook(getActivation(None, GATE)))
for expert_idx in range(num_experts):
    hooks.append(
        moe.experts[expert_idx].w2.register_forward_hook(getActivation(expert_idx, W2))
    )

try:
    for batch_idx, batch in enumerate(tqdm(dataset)):
        if batch_idx >= MAX_BATCHES:
            break

        # Reset captures so a tensor from an earlier batch is never reused.
        mlps = [None] * num_experts
        router_logits = None

        batch_tokens = tokenizer(batch, padding="max_length", truncation=True, max_length=sequence_length, return_tensors="pt").to(device)

        # Only the hooked activations are needed, so skip autograd bookkeeping.
        # Pass the attention mask so real tokens do not attend to padding;
        # no generation follows, so the KV cache is not needed.
        with torch.no_grad():
            output = model(
                input_ids=batch_tokens["input_ids"],
                attention_mask=batch_tokens["attention_mask"],
                use_cache=False,
            )

        if router_logits is None:
            raise RuntimeError("Gate hook did not fire; check that `layer` is a valid layer index")

        # NOTE(review): assumes router_logits is (batch * sequence_length, num_experts).
        num_tokens = router_logits.shape[0]
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, num_experts_per_token, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        # expert_mask[e, k, t] == 1 when token t chose expert e as its k-th pick.
        # NOTE(review): assumes the MoE block feeds each expert's tokens to w2 in
        # this same torch.where order (as transformers' Mixtral block does);
        # confirm for the installed transformers version.
        expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)

        # Size from the captured logits, not batch_size, so a short final batch works.
        activations = torch.zeros(num_experts, num_tokens, expert_hidden_dims, device=device)
        for expert_idx in range(num_experts):
            idx, top_x = torch.where(expert_mask[expert_idx])
            if top_x.shape[0] == 0:
                continue
            activations[expert_idx, top_x] = mlps[expert_idx] * routing_weights[top_x, idx, None]

        activations = rearrange(activations, 'experts sequences hidden -> sequences (experts hidden)')
        print(activations.shape)

        # Pass activations through the autoencoder
        autoencoder.train()
        optimizer.zero_grad()
        reconstructed_activations = autoencoder(activations)
        loss = criterion(reconstructed_activations, activations)
        loss.backward()
        optimizer.step()

        print(f"Autoencoder loss: {loss.item()}")
finally:
    # Detach the hooks even when the forward pass or the training step raises;
    # the exception itself still propagates.
    for hook in hooks:
        hook.remove()