In [None]:
"""
This notebook contains code to test the orthogonality of expert specialization.
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
import numpy as np
import scipy
import cupy
import cuml

import importlib
import gc
import pickle

from tqdm import tqdm
from termcolor import colored
import plotly.express as px

from utils.memory import check_memory, clear_all_cuda_memory
from utils.quantize import compare_bf16_fp16_batched

# Notebook-wide settings: target CUDA device string and a fixed seed value.
# NOTE(review): neither name is referenced in the cells visible here — confirm they are used further down.
main_device = 'cuda:0'
seed = 1234
# Project helpers from utils.memory; presumably they free cached GPU memory and then report usage
# (confirm against utils/memory.py). Called here so the notebook starts from a clean GPU state.
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model
"""
model_id = 'allenai/OLMoE-1B-7B-0125-Instruct'
model_prefix = 'olmoe'  # also used as the sub-directory name under data/ when loading saved activations

# Tokenizer: no automatic BOS/EOS insertion, left-side padding.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    add_eos_token = False,
    add_bos_token = False,
    padding_side = 'left',
    trust_remote_code = True
)

# Model: weights loaded as bfloat16, then moved to the current CUDA device and switched to eval mode.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype = torch.bfloat16,
    trust_remote_code = True
)
model = model.cuda().eval()

In [None]:
"""
Load dataset
"""
def load_data(model_prefix):
    """
    Load the saved activations and their metadata for one model from `data/{model_prefix}/`.

    Params:
        @model_prefix: Name of the sub-directory under `data/` holding the saved files.

    Returns:
        A 6-tuple, in this order:
        - the object stored in `all-pre-mlp-hidden-states.pt`
        - the object stored in `all-expert-outputs.pt`
        - metadata['sample_df']
        - metadata['topk_df']
        - metadata['all_pre_mlp_hidden_states_layers']
        - metadata['all_expert_outputs_layers']
    """
    pre_mlp_hs = torch.load(f'data/{model_prefix}/all-pre-mlp-hidden-states.pt')
    expert_outputs = torch.load(f'data/{model_prefix}/all-expert-outputs.pt')

    # NOTE: pickle.load can execute arbitrary code; only point this at metadata files you produced yourself.
    with open(f'data/{model_prefix}/metadata.pkl', 'rb') as handle:
        meta = pickle.load(handle)

    # Metadata entries are returned in this fixed order, after the two loaded objects.
    meta_keys = ('sample_df', 'topk_df', 'all_pre_mlp_hidden_states_layers', 'all_expert_outputs_layers')
    return (pre_mlp_hs, expert_outputs, *(meta[key] for key in meta_keys))

# Unpack in the order load_data returns: the two loaded activation objects, the two metadata
# frames, then the layer lists used later to key the activations by layer.
(
    all_pre_mlp_hs_import,
    all_expert_outputs_import,
    sample_df_import,
    topk_df_import,
    act_map,
    expert_map
) = load_data(model_prefix)

In [None]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
# Tag every row with two integer group ids:
#   sample_ix - one id per distinct (batch_ix, sequence_ix, token_ix)
#   seq_id    - one id per distinct (batch_ix, sequence_ix)
sample_df_raw = (
    sample_df_import
    .assign(
        sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup(),
        seq_id = lambda df: df.groupby(['batch_ix', 'sequence_ix']).ngroup()
    )
    .reset_index()
)

# Attach sample_ix to the top-k rows via the three positional keys, then drop those keys
# so topk_df is addressed by sample_ix only.
topk_df = (
    topk_df_import
    .merge(
        sample_df_raw[['sample_ix', 'batch_ix', 'sequence_ix', 'token_ix']],
        how = 'inner',
        on = ['sequence_ix', 'token_ix', 'batch_ix']
    )
    .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])
)

# Same for the sample frame; token_ix is kept, batch/sequence position is replaced by seq_id.
sample_df = sample_df_raw.drop(columns = ['batch_ix', 'sequence_ix'])

# The imported frames are no longer needed once the cleaned versions exist.
del sample_df_import, topk_df_import
display(topk_df)
display(sample_df)

In [None]:
"""
Convert activations to fp16 (for compatibility) and split them into per-layer dicts
"""
# Pre-MLP hidden states: cast to fp16, run the project's bf16-vs-fp16 comparison on the pair,
# then release the imported copy.
all_pre_mlp_hs = all_pre_mlp_hs_import.half()
compare_bf16_fp16_batched(all_pre_mlp_hs_import, all_pre_mlp_hs)
del all_pre_mlp_hs_import
# Re-key by layer: slot save_ix along dim 1 belongs to layer act_map[save_ix].
all_pre_mlp_hs = dict(
    (layer_ix, all_pre_mlp_hs[:, save_ix, :])
    for save_ix, layer_ix in enumerate(act_map)
)

# Expert outputs: same treatment, keyed through expert_map.
all_expert_outputs = all_expert_outputs_import.half()
compare_bf16_fp16_batched(all_expert_outputs_import, all_expert_outputs)
del all_expert_outputs_import
all_expert_outputs = dict(
    (layer_ix, all_expert_outputs[:, save_ix, :, :])
    for save_ix, layer_ix in enumerate(expert_map)
)

gc.collect()

## Compare activation clusters of single expert

In [None]:
"""
Visualize clusters
"""

# Layer / expert whose routed tokens we want to inspect.
layer_ix = 10
expert_id = 24

# Keep only the top-k rows for the chosen layer and expert. Both operands of `&` must be
# boolean Series (one value per row); the previous version passed an already-filtered
# DataFrame as the second operand, which does not produce a row mask.
topk_df\
    .pipe(lambda df: df[(df['layer_ix'] == layer_ix) & (df['expert_id'] == expert_id)])

## Compare outputs of single expert