In [37]:

import pandas as pd
from transformers import MambaForCausalLM
import plotly.express as px

import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

from tqdm import tqdm
import matplotlib.pyplot as plt

from scipy.stats import skew, kurtosis
import plotly.express as px


In [4]:
from scripts.evaluate_model import get_tokenizer_and_model
from scripts.plot_a_vals_distr import collect_and_stack_A_logs

In [5]:
# Load the Mamba '2.8B' checkpoint. The first return value is discarded
# (presumably the tokenizer, judging by the helper's name — not needed here).
_, model = get_tokenizer_and_model("mamba", '2.8B')
# Switch to eval mode; the trailing ';' suppresses the notebook repr output.
model.eval();

The fast path is not available because one of `(selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)` is None. Falling back to the sequential implementation of Mamba, as use_mambapy is set to False. To install follow https://github.com/state-spaces/mamba/#installation and https://github.com/Dao-AILab/causal-conv1d. For the mamba.py backend, follow https://github.com/alxndrTL/mamba.py.
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  5.50it/s]


In [6]:
# Stack every layer's A_log rows into one matrix, with a layer label and a
# within-layer position label for each row.
stacked_A_logs, layer_indices, position_indices = collect_and_stack_A_logs(model)
# Each stacked row needs exactly one (layer, position) label; the enriched
# DataFrame built later zips these arrays together row by row.
assert len(layer_indices) == len(position_indices) == len(stacked_A_logs), (
    'collect_and_stack_A_logs returned misaligned outputs'
)


In [7]:
def compute_features(data, features_to_add, feature_dict=None, n_pca_components=5, random_state=None):
    """Compute per-row summary features of a 2-D array and store them in a dict.

    Names already present in ``feature_dict`` are skipped, so repeated calls
    with the same dict only compute what is missing.

    Args:
        data: 2-D array; every feature is computed per row (``axis=1``).
        features_to_add: iterable of feature names. Supported names are
            'original', 'L1_norm', 'L_infinity_norm', 'top_2_pca',
            'skewness' and 'kurtosis'. Unknown names emit a UserWarning and
            are otherwise ignored.
        feature_dict: dict updated in place; a new dict is created when None.
        n_pca_components: maximum number of PCA components fitted for
            'top_2_pca' (additionally capped by the row and column counts).
        random_state: forwarded to ``PCA`` so randomized solvers are
            reproducible.

    Returns:
        The (possibly newly created) feature dict. 'original' is ``data``
        itself; the norm, skewness and kurtosis features are ``(n_rows, 1)``
        arrays. Requesting 'top_2_pca' also stores 'pca1' and, when at least
        two components were fitted, 'pca2'.
    """
    import warnings

    known = {'original', 'L1_norm', 'L_infinity_norm', 'top_2_pca', 'skewness', 'kurtosis'}
    if feature_dict is None:
        feature_dict = {}
    pending = [f for f in features_to_add if f not in feature_dict]

    unknown = [f for f in pending if f not in known]
    if unknown:
        # A typo would otherwise surface only later as a KeyError downstream.
        warnings.warn(
            f"compute_features: ignoring unknown feature names {unknown} "
            f"('pca1'/'pca2' are produced by 'top_2_pca')",
            stacklevel=2,
        )

    if 'original' in pending:
        feature_dict['original'] = data
    if 'L1_norm' in pending:
        feature_dict['L1_norm'] = np.linalg.norm(data, ord=1, axis=1, keepdims=True)
    if 'L_infinity_norm' in pending:
        feature_dict['L_infinity_norm'] = np.linalg.norm(data, ord=np.inf, axis=1, keepdims=True)
    if 'top_2_pca' in pending:
        # PCA requires n_components <= min(n_samples, n_features).
        n_components = min(n_pca_components, *np.shape(data))
        pca = PCA(n_components=n_components, random_state=random_state)
        pca_features = pca.fit_transform(data)
        # Per row, keep the two component scores with the largest magnitude;
        # which components are selected can differ from row to row.
        top_2_indices = np.argsort(-np.abs(pca_features), axis=1)[:, :2]
        feature_dict['top_2_pca'] = np.take_along_axis(pca_features, top_2_indices, axis=1)
        feature_dict['pca1'] = pca_features[:, 0].reshape(-1, 1)
        if n_components >= 2:
            feature_dict['pca2'] = pca_features[:, 1].reshape(-1, 1)
    if 'skewness' in pending:
        feature_dict['skewness'] = skew(data, axis=1).reshape(-1, 1)
    if 'kurtosis' in pending:
        feature_dict['kurtosis'] = kurtosis(data, axis=1).reshape(-1, 1)
    return feature_dict

In [8]:
# Fresh cache for compute_features; re-running this cell forces recomputation.
enriched_features = dict()

In [9]:
# Fill the cache with every supported feature of the stacked A_log rows;
# features already cached are not recomputed.
requested_features = ['original', 'L1_norm', 'L_infinity_norm', 'top_2_pca', 'skewness', 'kurtosis']
enriched_features = compute_features(
    stacked_A_logs,
    features_to_add=requested_features,
    feature_dict=enriched_features,
)

In [34]:
enriched_features_names = ['L1_norm', 'L_infinity_norm', 'skewness', 'kurtosis', 'pca1', 'pca2']
# One row per stacked A_log row: each scalar feature (stored as an (n, 1)
# column) is flattened, then the layer / position labels are appended.
feature_columns = {name: enriched_features[name].flatten() for name in enriched_features_names}
label_columns = {
    'Layer Index': layer_indices,
    'Layer Index str': ['Layer ' + str(idx) for idx in layer_indices],
    'Position Index': position_indices,
}
enriched_df = pd.DataFrame({**feature_columns, **label_columns})


In [14]:
# Persist the enriched features (without the RangeIndex) for later analysis.
enriched_df.to_csv(path_or_buf='A_features.csv', index=False)

In [1]:

def plot_feature_interactions(data, feature1, feature2):
    """Scatter two feature columns against each other, one colour per layer.

    Args:
        data: DataFrame containing ``feature1`` and ``feature2`` plus the
            columns 'Layer Index', 'Layer Index str', 'Position Index',
            'L1_norm', 'L_infinity_norm', 'skewness', 'kurtosis', 'pca1' and
            'pca2' (shown on hover).
        feature1: column plotted on the x axis.
        feature2: column plotted on the y axis.

    Returns:
        The plotly Figure (not displayed).
    """
    # Order legend entries and colours by numeric layer index. Without
    # category_orders plotly uses first-appearance order, so the colour ramp
    # would track depth only if the rows happened to be sorted by layer.
    layer_labels = (
        data[['Layer Index', 'Layer Index str']]
        .drop_duplicates()
        .sort_values('Layer Index')['Layer Index str']
        .tolist()
    )
    n_layers = len(layer_labels)
    # Evenly spaced samples of the Plasma colour scale, one per layer.
    colorscale = px.colors.sample_colorscale(
        px.colors.sequential.Plasma,
        [i / n_layers for i in range(n_layers)],
    )

    fig = px.scatter(
        data,
        x=feature1, y=feature2,
        color='Layer Index str',
        category_orders={'Layer Index str': layer_labels},
        color_discrete_sequence=colorscale,
        hover_data=['Position Index', 'Layer Index', 'L1_norm', 'L_infinity_norm', 'skewness', 'kurtosis', 'pca1', 'pca2'],
        opacity=0.1,
        title=f'{feature1} vs {feature2}',
    )
    return fig


In [None]:
# L1 vs L-infinity norm of every A_log row, coloured by layer.
fig = plot_feature_interactions(enriched_df, feature1='L1_norm', feature2='L_infinity_norm')
fig.show()

In [40]:
# Cluster the rows on the scalar enriched features with k-means.
# The features may sit on very different scales (norms vs. skewness/kurtosis),
# and Euclidean k-means would then be dominated by the largest-magnitude
# column, so standardize each feature to zero mean / unit variance first.
cluster_features = ['L1_norm', 'L_infinity_norm', 'skewness', 'kurtosis', 'pca1', 'pca2']
num_clusters = 2
features_scaled = StandardScaler().fit_transform(enriched_df[cluster_features])
kmeans = KMeans(n_clusters=num_clusters, random_state=42)
kmeans.fit(features_scaled)
enriched_df['Cluster_all_enriched'] = kmeans.labels_
enriched_df['Cluster_all_enriched_str'] = [f'Cluster {i}' for i in kmeans.labels_]

In [None]:

# Scatter L-infinity vs L1 norm of every A_log row. The layer is coloured as a
# discrete category (the string column) so the legend can toggle layers;
# colouring by the integer 'Layer Index' would give a continuous colour bar
# and no clickable legend.
layer_order = (
    enriched_df[['Layer Index', 'Layer Index str']]
    .drop_duplicates()
    .sort_values('Layer Index')['Layer Index str']
    .tolist()
)
fig = px.scatter(
    enriched_df, x='L_infinity_norm', y='L1_norm', color='Layer Index str',
    category_orders={'Layer Index str': layer_order},
    hover_data={'Position Index': True},
    title='Distribution of Different Features',
    labels={'L1_norm': 'L1 Norm Feature', 'L_infinity_norm': 'L Infinity Norm'},
)

# Smaller markers; legend clicks isolate / hide individual layers.
fig.update_traces(marker=dict(size=5), selector=dict(mode='markers'))
fig.update_layout(
    legend=dict(
        title='Layer',
        itemsizing='constant',
        itemclick='toggleothers',  # Toggle visibility of other traces
        itemdoubleclick='toggle'   # Toggle visibility of the clicked trace
    )
)

fig.write_html("feature_distribution.html")
fig.show()


In [None]:
# Save another copy of whichever figure `fig` currently holds.
fig.write_html(file="feature_distribution1.html")

In [None]:

# Step 1: Collect every layer's A_log parameter as a NumPy array.
# .float().cpu() lets .numpy() work for GPU-resident or reduced-precision
# (e.g. bfloat16) parameters; it is a no-op for float32 CPU tensors.
A_logs = [layer.mixer.A_log.detach().float().cpu().numpy() for layer in model.backbone.layers]

# Step 2: Stack the per-layer arrays along axis 0 into a single dataset.
A_logs_all = np.vstack(A_logs)

# Optional: Standardize each column (zero mean, unit variance).
scaler = StandardScaler()
A_logs_all_scaled = scaler.fit_transform(A_logs_all)


In [None]:

# Step 3: k-means on the standardized A_log rows (fit returns the estimator).
num_clusters = 5
kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(A_logs_all_scaled)
labels = kmeans.labels_


In [None]:

# Step 4: Analyze clusters — row indices (into A_logs_all) assigned to cluster 0.
cluster_0_indices = np.flatnonzero(labels == 0)


In [None]:
# Map each row of A_logs_all back to (layer, position within layer).
# Row k belongs to layer k // layer_size at position k % layer_size, which
# assumes every layer's A_log has the same number of rows.
# NOTE: this overwrites the layer_indices / position_indices returned by
# collect_and_stack_A_logs earlier in the notebook.
layer_size = A_logs[0].shape[0]
layer_indices, position_indices = np.divmod(np.arange(len(A_logs) * layer_size), layer_size)

In [None]:
# 2-D projection of the standardized A_log rows for plotting. random_state
# pins the result when PCA's 'auto' solver chooses a randomized SVD (as it
# can for large inputs), consistent with the seeded k-means above.
pca = PCA(n_components=2, random_state=42)
A_logs_pca = pca.fit_transform(A_logs_all_scaled)

In [None]:
# Table of the 2-D PCA projection with cluster / layer / position labels.
pc1, pc2 = A_logs_pca[:, 0], A_logs_pca[:, 1]
df = pd.DataFrame(
    {
        'PCA Component 1': pc1,
        'PCA Component 2': pc2,
        'Cluster': labels,
        'Layer Index': layer_indices,
        'Position Index': position_indices,
    }
)


In [None]:
# Scatter the PCA projection coloured by k-means cluster. Cluster ids are
# categorical: passing them as strings gives a discrete, toggleable legend
# instead of the continuous colour bar an integer column would produce.
cluster_order = [str(c) for c in sorted(df['Cluster'].unique())]
fig = px.scatter(
    df.assign(Cluster=df['Cluster'].astype(str)),
    x='PCA Component 1', y='PCA Component 2', color='Cluster',
    category_orders={'Cluster': cluster_order},
    hover_data={'Layer Index': True, 'Position Index': True},
    opacity=0.7,  # Add alpha to points
    title='Clustering of A_log Matrices (PCA Reduced)',
)
fig.update_layout(legend_title_text='Cluster')
fig.show()


In [None]:
# Long-format table of every A_log entry: one record per (layer i, row j, column k),
# in the same layer -> row -> column order as a nested loop over the tensor.
# `vals` was previously never initialised, so this cell raised NameError.
vals = {'layer': [], 'val': [], 'ssm_id': [], 'ssm_id+layer': []}
for i, layer in enumerate(tqdm(model.backbone.layers, desc='collecting A_log')):
    # One host copy per layer instead of one .item() call per element.
    # .float() is exact for reduced-precision inputs and makes .numpy() valid.
    A_log = layer.mixer.A_log.detach().float().cpu().numpy()
    n_rows, n_cols = A_log.shape
    row_ids = np.repeat(np.arange(n_rows), n_cols).tolist()
    vals['layer'].extend([i] * A_log.size)
    vals['val'].extend(A_log.ravel().tolist())
    vals['ssm_id'].extend(row_ids)
    vals['ssm_id+layer'].extend(f'{i}:{j}' for j in row_ids)


In [None]:

# Size label used only in the output file names below (previously undefined).
# Keep it in sync with the size passed to get_tokenizer_and_model.
model_size = '2.8B'

df = pd.DataFrame(vals)

# Per-(ssm_id, layer) summary of the raw A_log values.
# NOTE(review): 'val_max' holds the MEDIAN, not the maximum, although the
# file names say "min_max" — confirm which statistic is intended.
min_max_vals = (
    df.groupby(['ssm_id', 'layer'])['val']
    .agg(val_min='min', val_max='median')
    .reset_index()
)
# Layers as strings so plotly treats them as discrete categories.
min_max_vals['layer'] = min_max_vals['layer'].astype(str)
n_layers = min_max_vals['layer'].nunique()
colorscale = px.colors.sample_colorscale(px.colors.sequential.Plasma, [i / n_layers for i in range(n_layers)])


def _scatter_and_save(frame, path, colors):
    """Scatter val_min vs val_max coloured by layer, write it to `path` as HTML and return it."""
    scatter = px.scatter(frame, x='val_min', y='val_max', color='layer', color_discrete_sequence=colors)
    scatter.update_layout(height=1000, width=1000)
    scatter.write_html(path)
    return scatter


# 1) Raw A_log statistics.
fig = _scatter_and_save(min_max_vals, f"a_vals_min_max_categ_{model_size}.html", colorscale)

# 2) -exp(.) of each statistic (presumably A = -exp(A_log) as in Mamba — confirm).
#    -exp is decreasing, so the 'val_min' column now holds each group's largest value.
neg_exp_vals = min_max_vals.assign(
    val_min=-np.exp(min_max_vals['val_min']),
    val_max=-np.exp(min_max_vals['val_max']),
)
fig = _scatter_and_save(neg_exp_vals, f"a_vals_min_max_categ_{model_size}_exp.html", colorscale)

# 3) exp applied on top of step 2, i.e. exp(-exp(.)) of each raw statistic.
exp_neg_exp_vals = neg_exp_vals.assign(
    val_min=np.exp(neg_exp_vals['val_min']),
    val_max=np.exp(neg_exp_vals['val_max']),
)
fig = _scatter_and_save(exp_neg_exp_vals, f"a_vals_min_max_categ_{model_size}_exp_exp.html", colorscale)