In [None]:
"""
This is similar to `./../path-analysis/store-pretrained-model-paths.ipynb`, but also exports hidden states. 
- Use for clustering analysis to compare clusters of hidden states versus expert IDs.
- Compares dense models to MoE clusters.
- Due to the huge dataset cost of storing activations, it's better to simply create them at runtime.
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting
from datasets import load_dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
from termcolor import colored
import importlib 

from utils.memory import check_memory, clear_all_cuda_memory
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df
from utils import pretrained_models

# https://docs.rapids.ai/install/
import umap
import cupy
import cudf
import cuml

import plotly.express as px

# Device that batches are moved to before each forward pass in the export loop further down.
main_device = 'cuda:0'
# Intended as the notebook-wide random seed.
seed = 1234
# Helpers imported from utils.memory; presumably they free cached CUDA memory and report usage — verify against utils/memory.py.
clear_all_cuda_memory()
check_memory()

## Load base model

In [None]:
"""
Load the base tokenizer/model

Architectures supported currently:
- OlMoE architecture, includes OLMoE-1B-7B-0125-Instruct (1B/7B)
- Qwen2MoE architecture, includes Qwen1.5-MoE-A2.7B-Chat (2.7B/14.3B), Qwen2-57B-A14B (14B/57B)
- Deepseek v2 architecture, includes Deepseek-v2-Lite (2.4B/15.7B), Deepseek-v2 (21B/236B)
- Deepseek v3 architecture, includes Deepseek-v3 (37B/671B), Deepseek-R1 (37B/671B), Moonlight-16B-A3B (3B/16B)
"""
# Position of the model to load within the list inside get_model() below (1 = Qwen/Qwen1.5-MoE-A2.7B-Chat).
selected_model_index = 1

def get_model(index):
    """
    Look up one of the supported pretrained MoE checkpoints.

    Params:
        @index: Position of the desired entry in the list of supported models (regular list indexing applies).

    Returns:
        A tuple of (HF model id, short model prefix, architecture key). The architecture key is what the notebook
        uses to pick the matching module under `utils.pretrained_models`.
    """
    supported_models = [
        ('allenai/OLMoE-1B-7B-0125-Instruct', 'olmoe', 'olmoe'),
        ('Qwen/Qwen1.5-MoE-A2.7B-Chat', 'qwen1.5moe', 'qwen2moe'),
        ('deepseek-ai/DeepSeek-V2-Lite', 'dsv2', 'dsv2'),
        ('moonshotai/Moonlight-16B-A3B', 'moonlight', 'dsv3'),
    ]
    hf_id, prefix, architecture = supported_models[index]
    return hf_id, prefix, architecture

# Resolve the checkpoint id, short prefix and architecture key for the selected model.
model_id, model_prefix, model_architecture = get_model(selected_model_index)
# The tokenizer pads on the left and is asked not to add BOS/EOS tokens on its own.
tokenizer = AutoTokenizer.from_pretrained(model_id, add_eos_token = False, add_bos_token = False, padding_side = 'left', trust_remote_code = True)
# Load the weights in bfloat16, move the model to the current CUDA device and put it in eval mode.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, trust_remote_code = True).cuda().eval()

In [None]:
"""
Load reverse-engineered forward pass functions that return topk expert IDs and weights
"""
# Import the architecture-specific module (e.g. utils.pretrained_models.qwen2moe) ...
model_module = importlib.import_module(f"utils.pretrained_models.{model_architecture}")
# ... and fetch its `run_<architecture>_return_topk` forward-pass function by name.
run_model_return_topk = getattr(model_module, f"run_{model_architecture}_return_topk")

@torch.no_grad()
def test_custom_forward_pass(model, pad_token_id):
    """
    Sanity-check the custom forward pass against the stock HF forward pass on two short prompts.

    Runs both passes on the same padded batch, asserts that the logits are exactly equal, and prints the shapes of
    the returned top-k / hidden-state lists together with the LM loss. Gradients are disabled because nothing here
    is trained, so no autograd graph needs to be kept for the two 512-token forward passes.

    Params:
        @model: The loaded causal LM; inputs are moved to `model.device`.
        @pad_token_id: Token id whose positions are excluded from the loss (mapped to the label -100).

    Raises:
        AssertionError if the custom logits differ from the original ones, or if the top-k id and weight lists
        have different lengths.
    """
    inputs = tokenizer(['Hi! I am a dog and I like to bark', 'Vegetables are good for'], return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 512).to(model.device)
    original_results = model(**inputs)
    custom_results = run_model_return_topk(model, inputs['input_ids'], inputs['attention_mask'], return_hidden_states = True)
    assert torch.equal(original_results.logits, custom_results['logits']), 'Error in custom forward'
    assert len(custom_results['all_topk_experts']) == len(custom_results['all_topk_weights']), 'Length of topk IDs and weights not equal'
    print(f"Length of topk: {len(custom_results['all_topk_experts'])}")
    print(f"Topk size: {custom_results['all_topk_experts'][0].shape}")
    # NOTE(review): the two prints below read row index 1 of the first list entry, although the label says "First token" — confirm which row is intended.
    print(f"First token topk IDs: {custom_results['all_topk_experts'][0][1,]}")
    print(f"First token topk weights: {custom_results['all_topk_weights'][0][1,]}")
    # Pad positions get the label -100 so they are excluded from the loss
    labels = torch.where(inputs['input_ids'] == pad_token_id, torch.tensor(-100), inputs['input_ids'])
    loss = ForCausalLMLoss(custom_results['logits'], labels, model.config.vocab_size).detach().cpu().item()
    print(f"LM loss: {loss}")
    print(f"Hidden states layers (pre-mlp, post-layer): {len(custom_results['all_pre_mlp_hidden_states'])} | {len(custom_results['all_hidden_states'])}")
    print(f"Hidden state size (pre-mlp, post-layer): {(custom_results['all_pre_mlp_hidden_states'][0].shape)} | {(custom_results['all_hidden_states'][0].shape)}")

# Run the check once at load time; it raises AssertionError if the custom forward pass disagrees with the stock one.
test_custom_forward_pass(model, tokenizer.pad_token_id)

## Get dataset

In [None]:
"""
Load dataset - C4 mix (en/zh/es)
"""
def load_raw_ds():
    """
    Build the raw multilingual text sample from the C4 validation split.

    Streams the `en`, `zh` and `es` configs of `allenai/c4`, shuffles each stream with a fixed seed and buffer,
    and takes up to 500 English, 100 Chinese and 100 Spanish documents.

    Returns:
        A list of raw text strings: the English samples first, then Chinese, then Spanish.
    """
    def open_stream(config_name):
        return load_dataset('allenai/c4', config_name, split = 'validation', streaming = True).shuffle(seed = 123, buffer_size = 100_000)

    def get_data(ds, n_samples):
        # zip stops once `n_samples` items were taken or the stream runs dry, whichever comes first
        return [sample['text'] for _, sample in zip(range(0, n_samples), ds)]

    # Open all three streams before reading from any of them
    streams_with_quota = [(open_stream(config_name), quota) for config_name, quota in (('en', 500), ('zh', 100), ('es', 100))]

    texts = []
    for stream, quota in streams_with_quota:
        texts += get_data(stream, quota)

    return texts


# List of raw document strings (up to 700) that feeds the dataset/dataloader below.
raw_data = load_raw_ds()

In [None]:
""" 
Load dataset into a dataloader. The dataloader returns the original tokens - this is important for BPE tokenizers as otherwise it's difficult to reconstruct the correct string later!
"""
from torch.utils.data import Dataset, DataLoader

class ReconstructableTextDataset(Dataset):
    """
    Torch dataset that keeps, for every token position, the exact substring of the source text it was cut from.
    """

    def __init__(self, text_dataset, tokenizer, max_length):
        """
        Creates a dataset object that also returns a B x N list of the original tokens in the same position as the input ids.

        Params:
            @text_dataset: A list of B samples of text dataset.
            @tokenizer: A HF tokenizer object; it is called with `return_offsets_mapping = True`.
            @max_length: The length N that every sample is padded/truncated to.
        """
        tokenized = tokenizer(text_dataset, add_special_tokens = False, max_length = max_length, padding = 'max_length', truncation = True, return_offsets_mapping = True, return_tensors = 'pt')

        self.input_ids = tokenized['input_ids']
        self.attention_mask = tokenized['attention_mask']
        self.offset_mapping = tokenized['offset_mapping']
        self.original_tokens = self.get_original_tokens(text_dataset)

    def get_original_tokens(self, text_dataset):
        """
        Return the original tokens associated with each B x N position. This is important for reconstructing the original text when BPE tokenizers are used.

        Reads the character offsets from `self.offset_mapping` (B x N x 2, as produced by
        `tokenizer(..., return_offsets_mapping = True)`), so it must be called after that attribute is set.

        Params:
            @text_dataset: The list of B source texts that were tokenized, in the same order.

        Returns:
            A list of length B, each with length N, containing the original substring for the token at the same position of input_ids.
            Positions whose offsets are (0, 0) - e.g. padding - hold an empty string.
        """
        # Convert the whole offsets tensor to nested Python lists once instead of indexing the tensor per token
        offsets_per_sample = self.offset_mapping.tolist()

        all_token_substrings = []
        for i, sample_offsets in enumerate(offsets_per_sample):
            source_text = text_dataset[i]
            all_token_substrings.append([
                "" if (start_char == 0 and end_char == 0) else source_text[start_char:end_char]
                for start_char, end_char in sample_offsets
            ])

        return all_token_substrings

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {'input_ids': self.input_ids[idx], 'attention_mask': self.attention_mask[idx], 'original_tokens': self.original_tokens[idx]}
    
def collate_fn(batch):
    """
    Collate dataset items into a batch while leaving `original_tokens` as a plain list of lists.

    Params:
        @batch: A list of items as returned by the dataset's `__getitem__`.

    Returns:
        A dict with stacked `input_ids` and `attention_mask` tensors, plus `original_tokens` as a list with one
        token list per sample.
    """
    collated = {
        tensor_key: torch.stack([item[tensor_key] for item in batch], dim = 0)
        for tensor_key in ('input_ids', 'attention_mask')
    }
    # Strings cannot be stacked, so the per-sample token lists are simply gathered
    collated['original_tokens'] = [item['original_tokens'] for item in batch]
    return collated

# Tokenize every raw document to 1024 tokens up front and serve them in their original order (shuffle = False),
# 8 sequences per batch. The custom collate_fn keeps `original_tokens` as a list of per-sample token lists.
test_dl = DataLoader(
    ReconstructableTextDataset(raw_data, tokenizer, max_length = 1024),
    batch_size = 8,
    shuffle = False,
    collate_fn = collate_fn
)

## Get expert selections + export

In [None]:
""" 
Run forward passes + export data

Note the bulk of compute time will be spent on exporting the CSV, not handling the forward passes.
"""

@torch.no_grad()
def run_and_export_topk(model, dl: DataLoader, layers_to_keep: list[int], max_batches: None | int = None):
    """
    Run forward passes on given model and store the intermediate hidden layers as well as topks

    Params:
        @model: The model to run forward passes on. Should return a dict with keys `logits`, `all_topk_experts`, `all_topk_weights`, and
          `all_pre_mlp_hidden_states`.
        @dl: The dataloader which returns `input_ids`, `attention_mask`, and `original_tokens`.
        @layers_to_keep: A list of layers for which to filter `topk_df` and `all_pre_mlp_hidden_states` (see returned object description).
        @max_batches: The max number of batches to run.

    Returns:
        A dict with keys:
        - `sample_df`: A sample (token)-level dataframe with corresponding input token ID, output token ID, and input token text (removes masked tokens)
        - `topk_df`: A sample (token) x layer_ix x topk_ix level dataframe that gives the expert ID selected at each sample-layer-topk (removes masked_tokens)
        - `all_pre_mlp_hidden_states`: A tensor of size n_samples x layers_to_keep x D return the hidden state for each retained layers. Each 
            n_sample corresponds to a row of sample_df.
    """
    all_pre_mlp_hidden_states = []
    sample_dfs = []
    topk_dfs = []

    # Show the number of batches that will actually run when `max_batches` caps the loop
    n_batches = len(dl) if max_batches is None else min(len(dl), max_batches)

    for batch_ix, batch in tqdm(enumerate(dl), total = n_batches):

        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)
        original_tokens = batch['original_tokens']

        output = run_model_return_topk(model, input_ids, attention_mask, return_hidden_states = True)

        # Check no bugs by validating output/perplexity
        if batch_ix == 0:
            loss = ForCausalLMLoss(output['logits'], torch.where(input_ids == tokenizer.pad_token_id, torch.tensor(-100), input_ids), model.config.vocab_size).detach().cpu().item()
            for i in range(min(2, input_ids.size(0))):
                # Pick the real tokens through the mask. Slicing `[:n_valid]` only works for right-padded rows, but the
                # tokenizer pads on the left (and the next token below is read from the last position).
                decoded_input = tokenizer.decode(input_ids[i][attention_mask[i].bool()], skip_special_tokens = True)
                next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                print(decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = True), 'green'))
            print("PPL:", torch.exp(torch.tensor(loss)).item())
        
        original_tokens_df = pd.DataFrame(
            [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)], 
            columns = ['sequence_ix', 'token_ix', 'token']
        )

        # Create sample (token) level dataframe
        sample_df =\
            convert_outputs_to_df(input_ids, attention_mask, output['logits'])\
            .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
            .assign(batch_ix = batch_ix)

        # Create topk x layer_ix x sample level dataframe
        topk_df =\
            convert_topk_to_df(input_ids, attention_mask, output['all_topk_experts'], output['all_topk_weights'])\
            .assign(batch_ix = batch_ix, weight = lambda df: df['weight'])\
            .drop(columns = 'token_id')\
            .pipe(lambda df: df[df['layer_ix'].isin(layers_to_keep)])
        
        sample_dfs.append(sample_df)
        topk_dfs.append(topk_df)

        # Store pre-MLP hidden states - per the original note, the fwd pass returns an n_layers list of BN x D tensors; stack to
        # BN x n_layers x D, then keep only unmasked positions and the requested layers
        valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
        all_pre_mlp_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep, :])

        if max_batches is not None and batch_ix + 1 >= max_batches:
            break

    return {'sample_df': pd.concat(sample_dfs), 'topk_df': pd.concat(topk_dfs), 'all_pre_mlp_hidden_states': torch.cat(all_pre_mlp_hidden_states, dim = 0)}

# Full pass over the dataloader (no batch cap), retaining layers 0-5 for both the top-k table and the hidden states.
res = run_and_export_topk(model, test_dl, layers_to_keep = list(range(0, 6)), max_batches = None)

## Prep for Clustering

In [None]:
"""
Let's clean up the mappings here. We'll get everything to a sample_ix level first.
"""
# Number every (batch, sequence, token) triple -> sample_ix, and every (batch, sequence) pair -> seq_id
sample_df_raw = (
    res['sample_df']
    .assign(
        sample_ix = lambda df: df.groupby(['batch_ix', 'sequence_ix', 'token_ix']).ngroup(),
        seq_id = lambda df: df.groupby(['batch_ix', 'sequence_ix']).ngroup()
    )
    .reset_index()
)

# Attach sample_ix to the top-k rows, then drop the positional keys it replaces
topk_df = (
    res['topk_df']
    .merge(
        sample_df_raw[['sample_ix', 'batch_ix', 'sequence_ix', 'token_ix']],
        how = 'inner',
        on = ['sequence_ix', 'token_ix', 'batch_ix']
    )
    .drop(columns = ['sequence_ix', 'token_ix', 'batch_ix'])
)

sample_df = sample_df_raw.drop(columns = ['batch_ix', 'sequence_ix'])

display(topk_df)
display(sample_df)

In [None]:
# Hidden states for the clustering cells below; per run_and_export_topk's docstring this is n_samples x layers_to_keep x D.
all_pre_mlp_hs = res['all_pre_mlp_hidden_states']
all_pre_mlp_hs.shape 

## Clustering

In [None]:
""" 
Cross-layer Topk = 1 Clusters
"""
# Keep the topk_ix == 1 expert of every token/layer and pivot to one row per token with one `layer_<ix>_id` column per layer
topk_wide = (
    topk_df
    .pipe(lambda df: df[df['topk_ix'] == 1])
    .merge(sample_df[['sample_ix', 'token']], on = 'sample_ix', how = 'inner')
    .pivot(index = ['sample_ix', 'token'], columns = 'layer_ix', values = 'expert')
    .rename(columns = lambda c: f'layer_{c}_id')
    .reset_index()
)

# Expert load in layer 1
display(topk_wide.groupby('layer_1_id', as_index = False).agg(n_samples = ('token', 'size')).sort_values(by = 'n_samples', ascending = False))

# Routes across layers 2-5 that were taken by at least 5 tokens, with up to 10 example tokens each
cross_layer_routes = (
    topk_wide
    .groupby(['layer_2_id', 'layer_3_id', 'layer_4_id', 'layer_5_id'], as_index = False)
    .agg(
        n_samples = ('token', 'size'),
        samples = ('token', lambda s: s.sample(n = min(len(s), 10), random_state = seed).tolist())
    )
    .pipe(lambda df: df[df['n_samples'] >= 5])
)

# Cap the sample size so that fewer than 25 surviving routes do not raise; seeded so reruns show the same rows
cross_layer_routes.sample(n = min(len(cross_layer_routes), 25), random_state = seed)

In [None]:
"""
Within layer clusters
"""
def _pivot_layer_topk(layer_ix: int, column_prefix: str) -> pd.DataFrame:
    """
    Pivot the top-k experts of one layer to one row per token with a `<column_prefix>_<topk_ix>_id` column per top-k slot.
    """
    return (
        topk_df
        .pipe(lambda df: df[df['layer_ix'] == layer_ix])
        .merge(sample_df[['sample_ix', 'token']], on = 'sample_ix', how = 'inner')
        .pivot(index = ['sample_ix', 'token'], columns = 'topk_ix', values = 'expert')
        .rename(columns = lambda c: f'{column_prefix}_{c}_id')
        .reset_index()
    )

def _sample_expert_routes(wide_df: pd.DataFrame, group_cols: list[str], token_col: str = 'token', min_samples: int = 5, n_routes: int = 25) -> pd.DataFrame:
    """
    Count tokens per expert combination in `group_cols`, keep combinations with at least `min_samples` tokens, and
    return up to `n_routes` of them (seeded) together with up to 10 example tokens each.
    """
    routes = (
        wide_df
        .groupby(group_cols, as_index = False)
        .agg(
            n_samples = (token_col, 'size'),
            samples = (token_col, lambda s: s.sample(n = min(len(s), 10), random_state = seed).tolist())
        )
        .pipe(lambda df: df[df['n_samples'] >= min_samples])
    )
    # Cap the sample size so that fewer than `n_routes` surviving combinations do not raise
    return routes.sample(n = min(len(routes), n_routes), random_state = seed)

# Test single layer (layer_ix 4; columns are prefixed `topk_l1_`)
sl_wide = _pivot_layer_topk(4, 'topk_l1')
display(_sample_expert_routes(sl_wide, ['topk_l1_1_id', 'topk_l1_2_id', 'topk_l1_3_id', 'topk_l1_4_id']))

# Same for the following layer (layer_ix 5; columns are prefixed `topk_l2_`)
sl2_wide = _pivot_layer_topk(5, 'topk_l2')
display(_sample_expert_routes(sl2_wide, ['topk_l2_1_id', 'topk_l2_2_id', 'topk_l2_3_id', 'topk_l2_4_id']))

# Top-2 experts of both layers combined; after the merge the token column of the left frame is called `token_x`
display(_sample_expert_routes(
    sl_wide.merge(sl2_wide, on = 'sample_ix', how = 'inner'),
    ['topk_l1_1_id', 'topk_l1_2_id', 'topk_l2_1_id', 'topk_l2_2_id'],
    token_col = 'token_x'
))

In [None]:
"""
K-Means (note - returns imbalanced clusters)
""" 
def cluster_kmeans(layer_hs: torch.Tensor, n_clusters = 64):
    """
    Cluster the hidden states of one layer with cuML K-Means (fixed random_state, up to 1000 iterations).

    Params:
        @layer_hs: A n_token_samples x D tensor for a single layer
        @n_clusters: The number of clusters to return

    Returns:
        A list of length n_token_samples of cluster ids
    """
    features = cupy.asarray(layer_hs.to(torch.float32))
    kmeans_options = {'n_clusters': n_clusters, 'max_iter': 1000, 'random_state': 123, 'verbose': True}
    estimator = cuml.cluster.KMeans(**kmeans_options)
    estimator.fit(features)
    # labels_ holds one cluster id per fitted row
    return estimator.labels_.tolist()

# One K-Means run (64 clusters) per retained layer
kmeans_res = [
    dict(cluster_method = 'kmeans', layer_ix = layer_ix, cluster_ids = cluster_kmeans(layer_hs, 64))
    for layer_ix, layer_hs in tqdm(enumerate(all_pre_mlp_hs.unbind(dim = 1)))
]

# One `layer_<ix>_id` column per layer, placed side by side with the token-level metadata
per_layer_frames = [pd.DataFrame({f"layer_{r['layer_ix']}_id": r['cluster_ids']}) for r in kmeans_res]
joined_df = pd.concat([pd.concat(per_layer_frames, axis = 1), sample_df], axis = 1)

# Cluster sizes for two of the layers
for cluster_col in ('layer_1_id', 'layer_3_id'):
    display(joined_df.groupby(cluster_col, as_index = False).agg(n_samples = ('token', 'size')).sort_values(by = 'n_samples', ascending = False))

# Cluster "routes" across layers 2-5 that contain at least 5 tokens, with up to 10 example tokens each
routes = (
    joined_df
    .groupby(['layer_2_id', 'layer_3_id', 'layer_4_id', 'layer_5_id'], as_index = False)
    .agg(
        n_samples = ('token', 'size'),
        samples = ('token', lambda s: s.sample(n = min(len(s), 10)).tolist())
    )
    .pipe(lambda df: df[df['n_samples'] >= 5])
)

display(routes.head(25))

In [None]:
""" 
Test decomp methods
"""
def reduce_pca(layer_hs: torch.Tensor, n_components = 2, fit_samples: None | int = 10_000):
    """
    Project the hidden states of one layer with cuML PCA.

    Params:
        @layer_hs: A n_token_samples x D tensor for a single layer
        @n_components: The number of principal components to keep
        @fit_samples: Fit on a random subset of at most this many rows (drawn without replacement, fixed seed); if falsy, fit on all rows

    Returns:
        The transformed rows of `layer_hs` (all of them, not only the fitted subset), copied back to the host.
    """
    # https://docs.rapids.ai/api/cuml/stable/api/#principal-component-analysis
    features = cupy.asarray(layer_hs.to(torch.float32))
    n_rows = features.shape[0]

    if not fit_samples:
        fit_rows = list(range(0, n_rows))
    else:
        fit_rows = np.random.default_rng(123).choice(n_rows, min(n_rows, fit_samples), replace = False)

    pca = cuml.PCA(iterated_power = 20, n_components = n_components, verbose = True)
    pca.fit(features[fit_rows, :])

    print(f'Cumulative variance ratio: {np.cumsum(pca.explained_variance_ratio_)[-1]}')
    print(f'Max feature mean: {np.max(pca.mean_)} | Min feature mean: {np.min(pca.mean_)}')

    return cupy.asnumpy(pca.transform(features))

# Try PCA on the first retained layer: 100 components, fitted on at most 500k rows.
pca_test = reduce_pca(all_pre_mlp_hs.unbind(dim = 1)[0], 100, 500_000)
# Plot the first two components for a random 5000-row sample, highlighting the token ' of'.
px.scatter(
    pd.concat([pd.DataFrame({'d1': pca_test[:, 0], 'd2': pca_test[:, 1]}), sample_df.head(pca_test.shape[0])], axis = 1)\
        .sample(5000)
        .assign(is_of = lambda df: np.where(df['token'] == ' of', 1, 0)),
    x = 'd1', y = 'd2', color = 'is_of', hover_data = ['token']
).show()
clear_all_cuda_memory()

# 100-component PCA projection for every retained layer.
pca_100 = [reduce_pca(layer_hs, 100, 500_000) for layer_hs in tqdm(all_pre_mlp_hs.unbind(dim = 1))]

In [None]:
def reduce_umap(layer_hs: torch.Tensor, n_components = 2, metric = 'euclidean', fit_samples: None | int = 10_000):
    """
    Embed the hidden states of one layer with cuML UMAP.

    Params:
        @layer_hs: A n_token_samples x D tensor for a single layer
        @n_components: The dimensionality of the embedding
        @metric: The distance metric passed to UMAP (euclidean, cosine, manhattan, l2, hamming)
        @fit_samples: Fit on a random subset of at most this many rows (drawn without replacement, fixed seed); if falsy, fit on all rows

    Returns:
        The embedding of every row of `layer_hs` (not only the fitted subset), copied back to the host.
    """
    # https://docs.rapids.ai/api/cuml/stable/api/#umap
    features = cupy.asarray(layer_hs.to(torch.float32))
    n_rows = features.shape[0]

    if not fit_samples:
        fit_rows = list(range(0, n_rows))
    else:
        fit_rows = np.random.default_rng(123).choice(n_rows, min(n_rows, fit_samples), replace = False)

    umap_options = {
        'n_components': n_components,
        'n_neighbors': 15, # 15 for default, smaller = more local data preserved [2 - 100]
        'metric': metric, # euclidean, cosine, manhattan, l2, hamming
        'min_dist': 0.1, # 0.1 by default, effective distance between embedded points
        'n_epochs': 250, # 200 by default for large datasets
        'random_state': None, # Allow parallelism
        'verbose': True
    }
    reducer = cuml.UMAP(**umap_options)
    reducer.fit(features[fit_rows, :])

    return cupy.asnumpy(reducer.transform(features))

# Try UMAP on the first retained layer: 100 components, euclidean metric, fitted on at most 500k rows.
umap_test = reduce_umap(all_pre_mlp_hs.unbind(dim = 1)[0], 100, 'euclidean', 500_000) # 300k = 2min
# Plot the first two embedding dimensions for a random 5000-row sample, highlighting the token ' of'.
px.scatter(
    pd.concat([pd.DataFrame({'d1': umap_test[:, 0], 'd2': umap_test[:, 1]}), sample_df.head(umap_test.shape[0])], axis = 1)\
        .sample(5000)
        .assign(is_of = lambda df: np.where(df['token'] == ' of', 1, 0)),
    x = 'd1', y = 'd2', color = 'is_of', hover_data = ['token']
).show()
clear_all_cuda_memory()

# 100-dimensional UMAP embeddings for every retained layer, once per metric.
umap_euc_100 = [reduce_umap(layer_hs, 100, 'euclidean', 500_000) for layer_hs in tqdm(all_pre_mlp_hs.unbind(dim = 1))]
umap_cos_100 = [reduce_umap(layer_hs, 100, 'cosine', 500_000) for layer_hs in tqdm(all_pre_mlp_hs.unbind(dim = 1))]

In [None]:
"""
DBScan
"""
def cluster_dbscan(layer_hs: torch.Tensor, fit_samples: None | int = 10_000):
    """
    Cluster the hidden states of one layer with cuML DBSCAN.

    DBSCAN only labels the points it was fitted on, so when a subset is fitted the remaining rows are reported
    as -1 (the same label DBSCAN gives to noise points).

    Params:
        @layer_hs: A n_token_samples x D tensor for a single layer
        @fit_samples: Fit on a random subset of at most this many rows (drawn without replacement, fixed seed); if falsy, fit on all rows

    Returns:
        A list of length n_token_samples of cluster ids, with -1 for rows that are noise or were not part of the fitted subset.
    """
    # https://docs.rapids.ai/api/cuml/stable/api/#dbscan
    hs_cupy = cupy.asarray(layer_hs.to(torch.float32))
    n_total = hs_cupy.shape[0]

    if fit_samples:
        # Same subsampling scheme as reduce_pca / reduce_umap: at most `fit_samples` distinct rows
        subset_indices = np.random.default_rng(123).choice(n_total, min(n_total, fit_samples), replace = False) # 50k = 3min
    else:
        subset_indices = np.arange(n_total)

    dbscan_model = cuml.cluster.DBSCAN(
        metric = 'euclidean', # Or cosine
        min_samples = 5, # Number of samples st the group can be considered a core point
        verbose = True
    )
    subset_labels = cupy.asnumpy(dbscan_model.fit_predict(hs_cupy[subset_indices, :])) # shape = (n_fitted,)

    # Scatter the fitted labels back to their original row positions
    cluster_labels = np.full(n_total, -1, dtype = np.int64)
    cluster_labels[subset_indices] = subset_labels

    print(f"Values unassigned to clusters: {int((cluster_labels == -1).sum())}/{len(cluster_labels)}")

    return cluster_labels.tolist()

# Smoke test on the first retained layer with fit_samples = 100.
# NOTE(review): as the bare last expression of the cell, the whole returned label list is rendered as output — consider assigning it instead.
cluster_dbscan(all_pre_mlp_hs.unbind(dim = 1)[0], 100)

In [None]:
# NOTE(review): `asdgasdg` is not defined in any cell visible here, so this cell raises NameError. It may be a deliberate
# stop that keeps "Run All" from reaching the scratch cells below — confirm, then delete it or replace it with an explicit raise.
asdgasdg

In [None]:
"""
UMAP Testing
"""

def reduce_umap(layer_hs: torch.Tensor, n_components = 2, metric = 'euclidean', fit_samples: None | int = 50_000):
    """
    CPU (umap-learn) variant of the UMAP reduction, for comparison with the cuML-based version defined earlier.

    The signature is kept call-compatible with that earlier `reduce_umap`, because this definition replaces it in
    the notebook namespace once the cell has run.

    Params:
        @layer_hs: A n_token_samples x D tensor for a single layer
        @n_components: The dimensionality of the embedding
        @metric: The distance metric passed to UMAP (euclidean, cosine, manhattan, l2, hamming)
        @fit_samples: Fit on a random subset of at most this many rows (drawn without replacement, fixed seed); if falsy, fit on all rows

    Returns:
        A numpy array with the embedding of every row of `layer_hs`, so that row i lines up with row i of `sample_df`.
    """
    hs_layer_np = layer_hs.to(torch.float32).cpu().numpy()
    n_rows = hs_layer_np.shape[0]

    if fit_samples:
        subset_indices = np.random.default_rng(123).choice(n_rows, min(n_rows, fit_samples), replace = False) # 50k = 3min
    else:
        subset_indices = np.arange(n_rows)

    reducer = umap.UMAP(
        n_components = n_components,
        n_neighbors = 15, # 15 for default, smaller = more local data preserved [2 - 100]
        metric = metric, # euclidean, cosine, manhattan, l2, hamming
        min_dist = 0.1, # 0.1 by default, effective distance between embedded points
        n_epochs = 200, # 200 by default for large datasets
        random_state = None # Allow parallelism
    )
    reducer.fit(hs_layer_np[subset_indices])

    # Embed all rows (not just the fitted subset) so the result can be joined to sample_df positionally
    return reducer.transform(hs_layer_np)


embed = reduce_umap(all_pre_mlp_hs.unbind(dim = 1)[0])

umap_plot_df = pd.concat(
    [pd.DataFrame({'d1': embed[:, 0], 'd2': embed[:, 1]}), sample_df.head(embed.shape[0])],
    axis = 1
)

px.scatter(
    umap_plot_df
        .sample(n = min(len(umap_plot_df), 5000))
        .assign(
            is_of = lambda df: np.where(df['token'] == ' of', 1, 0)
        ),
    x = 'd1',
    y = 'd2',
    color = 'is_of',
    hover_data = ['token']
)

In [None]:
"""
UMAP -> 100 + HDBSCAN
"""
import umap

def cluster_umap_to_hdbscan(layer_hs: torch.Tensor, umap_dim: int = 100, fit_samples: None | int = 10_000):
    """
    Reduce the hidden states of one layer with cuML UMAP, then cluster the reduced points with cuML HDBSCAN.

    Only the fitted subset is clustered; rows outside of it are reported as -1 (the same label HDBSCAN gives to
    noise points).

    Params:
        @layer_hs: A n_token_samples x D tensor for a single layer
        @umap_dim: The dimensionality of the UMAP embedding that is clustered; if falsy, the raw hidden states are clustered
        @fit_samples: Fit on a random subset of at most this many rows (drawn without replacement, fixed seed); if falsy, fit on all rows

    Returns:
        A list of length n_token_samples of cluster ids, with -1 for rows that are noise or were not part of the fitted subset.
    """
    # https://docs.rapids.ai/api/cuml/stable/api/#hdbscan
    hs_cupy = cupy.asarray(layer_hs.to(torch.float32))
    n_total = hs_cupy.shape[0]

    if fit_samples:
        # Same subsampling scheme as reduce_pca / reduce_umap: at most `fit_samples` distinct rows
        subset_indices = np.random.default_rng(123).choice(n_total, min(n_total, fit_samples), replace = False) # 50k = 3min
    else:
        subset_indices = np.arange(n_total)

    n_fit = len(subset_indices)
    fit_points = hs_cupy[subset_indices, :]

    if umap_dim:
        # Same UMAP settings as the cuML-based reduce_umap defined earlier in the notebook
        umap_model = cuml.UMAP(
            n_components = umap_dim,
            n_neighbors = 15,
            metric = 'euclidean',
            min_dist = 0.1,
            n_epochs = 250,
            random_state = None,
            verbose = True
        )
        fit_points = umap_model.fit_transform(fit_points)

    hdbscan_model = cuml.cluster.HDBSCAN(
        min_cluster_size = max(2, n_fit // (64 * 100)), # Min 1/100 of the size of 64 uniform clusters (at least 2)
        max_cluster_size = (n_fit * 100) // 64, # Max 100x the size of 64 uniform clusters; kept an integer
        metric = 'euclidean',
        min_samples = 1,
        verbose = True
    )
    subset_labels = cupy.asnumpy(hdbscan_model.fit_predict(fit_points)) # shape = (n_fitted,)

    # Scatter the fitted labels back to their original row positions
    cluster_labels = np.full(n_total, -1, dtype = np.int64)
    cluster_labels[subset_indices] = subset_labels

    print(f"Values unassigned to clusters: {int((cluster_labels == -1).sum())}/{len(cluster_labels)}")

    return cluster_labels.tolist()

hdbscan_labels = cluster_umap_to_hdbscan(all_pre_mlp_hs.unbind(dim = 1)[0])
# Show cluster sizes instead of rendering the full per-token label list
pd.Series(hdbscan_labels).value_counts().head(25)

In [None]:
from cuml.manifold.umap import UMAP

# NOTE(review): scratch cell. `hs_by_layer` is not defined in any cell visible here, so this fails on a fresh kernel —
# presumably it was a per-layer list of hidden states (compare `all_pre_mlp_hs.unbind(dim = 1)`); confirm or remove.
hs_for_layer = hs_by_layer[0]

hs_layer_np = hs_for_layer.to(torch.float32).numpy()
# NOTE(review): requires at least 50,000 rows because the draw is without replacement, and it is unseeded.
subset_indicies = np.random.choice(list(range(0, hs_layer_np.shape[0])), size = 50_000, replace = False) # 50k = 3min
# NOTE(review): the cell imports cuml's `UMAP` just above but instantiates `umap.UMAP` here; `data_on_host` looks like
# a cuML option — confirm which implementation is intended.
reducer = umap.UMAP(
    n_components = 10, 
    n_neighbors = 15, 
    min_dist = 0.1, 
    n_epochs = 100,
    data_on_host = True
    )

# NOTE(review): `embed` receives whatever `fit` returns; the trailing comment expects an embedding array — verify (`fit_transform`?).
embed = reducer.fit(hs_layer_np[subset_indicies])  # shape = (n_samples, n_components)
embed_all = reducer.transform(hs_layer_np)


In [None]:
# NOTE(review): scratch cell. `hdbscan` is not imported in any cell visible here, so this raises NameError on a fresh kernel.
# Clusters the `embed_all` embedding produced by the previous cell.
clusterer = hdbscan.HDBSCAN(
    min_cluster_size = 500,
    min_samples = 100,
    metric = 'euclidean' # cosine
    ) # https://stackoverflow.com/questions/67898039/hdbscan-difference-between-parameters
labels = clusterer.fit_predict(embed_all)


In [None]:
import plotly.express as px

# NOTE(review): scratch cell. `embedding` is not defined in any cell visible here (earlier cells assign `embed` / `embed_all`) — confirm which is meant.
# Plots the first two embedding dimensions for a random 5000-row sample, highlighting the token ' of'.
px.scatter(
    pd.concat(
        [
            pd.DataFrame({'d1': embedding[:, 0], 'd2': embedding[:, 1]}),
            sample_df.head(embedding.shape[0])
        ],
        axis = 1
    ).sample(5000)
    .assign(
        is_of = lambda df: np.where(df['token'] == ' of', 1, 0)
    ),
    x = 'd1',
    y = 'd2',
    color = 'is_of',
    hover_data = ['token']
)

In [None]:
# NOTE(review): scratch cell. `embedding` is not defined in any cell visible here. This displays the joined frame built
# for the scatter plot above (first two embedding dimensions next to the token metadata).
pd.concat(
    [
        pd.DataFrame({'d1': embedding[:, 0], 'd2': embedding[:, 1]}),
        sample_df.head(embedding.shape[0])
    ],
    axis = 1
    )

## Compare to Dense Models