In [None]:
"""
Imports
"""
from functools import cached_property

import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objects as go
import torch
from datasets import load_dataset
from scipy import stats
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from utils.memory import check_memory, profile_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df

main_device = 'cuda:0'
seed = 1234
# Seed the global torch / numpy RNGs with the configured value (the C4 shuffle below passes its own seed).
torch.manual_seed(seed)
np.random.seed(seed)

clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base model
"""
hf_model_id = 'allenai/OLMoE-1B-7B-0125-Instruct'

# Left padding; no automatic BOS/EOS tokens are requested.
tokenizer = AutoTokenizer.from_pretrained(hf_model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left')
# Place the model on the configured device so it matches where the batches are sent later.
# NOTE(review): trust_remote_code executes code from the Hub repo; confirm it is still required
# for the installed transformers version.
model = AutoModelForCausalLM.from_pretrained(hf_model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).to(main_device).eval()

## Get dataset

In [None]:
ds = (
    load_dataset("allenai/c4", 'en', split = 'validation', streaming = True)
    .shuffle(seed = 42, buffer_size = 10_000_000)
)
ds_iter = iter(ds)

# Collect up to 1M document texts, stopping early if the stream runs dry.
c4_raw = []
while len(c4_raw) < 1_000_000:
    sample = next(ds_iter, None)
    if sample is None:
        break
    c4_raw.append(sample['text'])

In [None]:
# Show the document count and a short preview rather than the full list, which can hold up to
# 1M documents and would flood the notebook output.
len(c4_raw), c4_raw[:3]

In [None]:
from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    """
    Map-style dataset over a tokenizer batch, yielding one sequence per index.

    Params:
        @tokenizer_output: mapping with 'input_ids' and 'attention_mask', both indexable by row
            (e.g. the tensors returned by a tokenizer called with return_tensors = 'pt').
        @decode_tokenizer: tokenizer used to turn ids back into text; defaults to the
            notebook-level `tokenizer` so existing callers keep working.

    Each item is a dict with 'tokens' (the decoded text of that sequence; special tokens are not
    skipped, so padding markers may appear), 'input_ids' and 'attention_mask'.
    """

    def __init__(self, tokenizer_output, decode_tokenizer = None):
        self.input_ids = tokenizer_output['input_ids']
        self.attention_mask = tokenizer_output['attention_mask']
        # Decoding is deferred: decoding every padded sequence up front is slow, and typically
        # only a handful of batches are actually read.
        self._decode_tokenizer = tokenizer if decode_tokenizer is None else decode_tokenizer

    @cached_property
    def tokens(self):
        """Decoded text of every sequence, computed on first access and then cached."""
        return self._decode_tokenizer.batch_decode(self.input_ids)

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            'tokens': self._decode_tokenizer.decode(self.input_ids[idx]),
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attention_mask[idx],
        }

# Optional cap on how many documents to tokenize. Padding to max_length materializes dense
# (n_docs, 512) id and mask tensors for every document, so capping this saves a lot of memory
# when only a few batches are consumed downstream. None keeps the full corpus.
MAX_TOKENIZE_DOCS = None
c4_to_tokenize = c4_raw if MAX_TOKENIZE_DOCS is None else c4_raw[:MAX_TOKENIZE_DOCS]

res = tokenizer(c4_to_tokenize, add_special_tokens = False, max_length = 512, padding = 'max_length', truncation = True, return_tensors = 'pt')
c4_dl = DataLoader(TextDataset(res), batch_size = 8, shuffle = False)

In [None]:
# Peek at the first collated batch from the DataLoader (a dict keyed 'tokens', 'input_ids',
# 'attention_mask') as a sanity check.
next(iter(c4_dl))

In [None]:
def _routed_moe_forward(mlp, hidden_state):
    """
    Recompute a sparse-MoE block's forward pass, also returning the router's top-k choices.

    Params:
        @mlp: the layer's MoE block; must expose `gate`, `experts`, `top_k`, `num_experts`
            and optionally `norm_topk_prob`.
        @hidden_state: (batch, seq, hidden) input to the block.

    Returns:
        (output of shape (batch, seq, hidden),
         selected_experts of shape (batch * seq, top_k), sorted by descending weight,
         routing_weights of shape (batch * seq, top_k), cast to the hidden-state dtype)
    """
    batch_size, sequence_length, hidden_dim = hidden_state.shape
    flat_hidden = hidden_state.view(-1, hidden_dim)
    # router_logits: (batch * sequence_length, n_experts)
    router_logits = mlp.gate(flat_hidden)

    routing_weights = torch.nn.functional.softmax(router_logits, dim = 1, dtype = torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, mlp.top_k, dim = -1, sorted = True)
    # NOTE(review): the upstream transformers OLMoE sparse-MoE block renormalizes the kept
    # probabilities when `norm_topk_prob` is set; blocks without the flag are treated as False.
    if getattr(mlp, 'norm_topk_prob', False):
        routing_weights = routing_weights / routing_weights.sum(dim = -1, keepdim = True)
    routing_weights = routing_weights.to(flat_hidden.dtype)

    final_hidden_states = torch.zeros(
        (batch_size * sequence_length, hidden_dim), dtype = hidden_state.dtype, device = hidden_state.device
    )
    # expert_mask: (n_experts, top_k, batch * sequence_length)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes = mlp.num_experts).permute(2, 1, 0)

    for expert_idx in range(mlp.num_experts):
        # slot: which of the top-k positions chose this expert; token_idx: which flattened token
        slot, token_idx = torch.where(expert_mask[expert_idx])
        if token_idx.numel() == 0:
            # No token routed here; index_add_ would be a no-op anyway.
            continue
        expert_layer = mlp.experts[expert_idx]
        current_state = flat_hidden[None, token_idx].reshape(-1, hidden_dim)
        current_hidden_states = expert_layer(current_state) * routing_weights[token_idx, slot, None]
        final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(flat_hidden.dtype))

    output = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
    return output, selected_experts, routing_weights


@torch.no_grad()
def run_model_return_topk(input_ids, attention_mask, lm = None):
    """
    Run a manual forward pass of an OLMoE-style causal LM, recording each layer's router top-k picks.

    The decoder loop applies pre-norm self-attention and an MoE block with residuals. The MoE
    block is recomputed inline rather than via `layer.mlp(...)`, so that the selected experts and
    their weights can be captured.

    Params:
        @input_ids: (batch, seq) token ids.
        @attention_mask: mask aligned with `input_ids`, passed to the model's causal-mask builder.
        @lm: the causal LM to run; defaults to the notebook-level `model`.

    Returns:
        A dict with
        - 'logits': `lm.lm_head` applied to the final normed hidden state;
        - 'all_topk_experts': one CPU tensor per layer, shape (batch * seq, top_k), expert ids
          sorted by descending routing weight;
        - 'all_topk_weights': matching float32 CPU tensors holding the top-k routing weights.
    """
    # Resolved here so the global is only needed when no model is passed in.
    lm = model if lm is None else lm
    decoder = lm.model

    input_embeds = decoder.embed_tokens(input_ids)

    cache_position = torch.arange(0, input_embeds.shape[1], device = input_embeds.device)
    position_ids = cache_position.unsqueeze(0)
    # NOTE(review): private transformers API; its signature has changed between releases.
    causal_mask = decoder._update_causal_mask(attention_mask, input_embeds, cache_position, None, None)

    hidden_state = input_embeds
    position_embeddings = decoder.rotary_emb(hidden_state, position_ids)

    all_topk_experts = []
    all_topk_weights = []
    for layer in decoder.layers:
        # Self-attention sub-block (pre-norm + residual)
        residual = hidden_state
        hidden_state = layer.input_layernorm(hidden_state)
        hidden_state, _, _ = layer.self_attn(hidden_states = hidden_state, attention_mask = causal_mask, position_ids = position_ids, position_embeddings = position_embeddings)
        hidden_state = residual + hidden_state

        # MoE sub-block (pre-norm + residual), recomputed so router choices are observable
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)
        moe_output, selected_experts, routing_weights = _routed_moe_forward(layer.mlp, hidden_state)
        hidden_state = residual + moe_output

        all_topk_experts.append(selected_experts.detach().cpu())
        all_topk_weights.append(routing_weights.detach().cpu().to(torch.float32))

    hidden_state = decoder.norm(hidden_state)
    logits = lm.lm_head(hidden_state)
    return {'logits': logits, 'all_topk_experts': all_topk_experts, 'all_topk_weights': all_topk_weights}


# Number of DataLoader batches to push through the instrumented forward pass.
MAX_BATCHES = 200

topk_dfs = []
# Wrap the DataLoader itself (not an enumerate object) so tqdm knows the total and can show progress.
n_batches = min(MAX_BATCHES, len(c4_dl))
for batch_ix, batch in enumerate(tqdm(c4_dl, total = n_batches)):
    input_ids = batch['input_ids'].to(main_device)
    attention_mask = batch['attention_mask'].to(main_device)

    output = run_model_return_topk(input_ids, attention_mask)

    # Drop rows whose token is the pad id and round weights to 3 decimals.
    # NOTE(review): filtering on token_id also drops any real token sharing pad_token_id
    # (e.g. if pad and EOS share an id); filtering on attention_mask would be stricter.
    topk_df = (
        convert_topk_to_df(input_ids, output['all_topk_experts'], output['all_topk_weights'])
        .pipe(lambda df: df[df['token_id'] != tokenizer.pad_token_id])
        .assign(weight = lambda df: df['weight'].round(3))
    )
    topk_dfs.append(topk_df)

    # Stop right after the last wanted batch, so no extra batch is fetched.
    if batch_ix + 1 >= MAX_BATCHES:
        break


In [None]:
# Stack the per-batch routing frames and save them without the pandas row index.
(
    pd.concat(topk_dfs)
    .to_csv('olmoe_clustering.csv', index = False)
)

In [None]:
from utils.vocab import export_vocab_as_csv

# Presumably writes the tokenizer's vocabulary to 'olmoe_vocab.csv'; the column layout is
# defined by utils.vocab.export_vocab_as_csv.
export_vocab_as_csv(tokenizer, 'olmoe_vocab.csv')

In [None]:
# (Intentionally empty: a stray bare name here referenced an undefined variable and raised
# NameError on Restart & Run All.)

In [None]:
# Draw a few more raw documents from the shared stream for ad-hoc inspection.
# ds_iter may already be exhausted by the c4_raw loading loop (which itself stops on exhaustion),
# so stop gracefully instead of raising StopIteration.
N_EXTRA_SAMPLES = 1

chunk = []
for _ in tqdm(range(0, N_EXTRA_SAMPLES)):
    sample = next(ds_iter, None)
    if sample is None:
        break
    chunk.append(sample['text'])

chunk