In [4]:

import numpy as np
import pandas as pd
import plotly.express as px
import torch
from scipy.stats import kurtosis, skew
from sklearn.cluster import KMeans
from tqdm import tqdm

In [5]:
from scripts.evaluate_model import get_tokenizer_and_model
from scripts.plot_a_vals_distr import collect_and_stack_A_logs
from src.consts import PATHS

In [6]:
_, model = get_tokenizer_and_model("mamba", '2.8B')
model.eval();

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  4.92it/s]


In [7]:
stacked_A_logs, layer_indices, position_indices = collect_and_stack_A_logs(model)
exped_log_A = torch.exp(-torch.exp(torch.from_numpy(stacked_A_logs))).numpy()


In [8]:
def compute_features(data, features_to_add, feature_dict=None):
    if feature_dict is None:
        feature_dict = {}
    features_to_add = [f for f in features_to_add if f not in feature_dict]
    if 'L1_norm' in features_to_add:
        feature_dict['L1_norm'] = np.linalg.norm(data, ord=1, axis=1, keepdims=True)
    if 'L_infinity_norm' in features_to_add:
        feature_dict['L_infinity_norm'] = np.linalg.norm(data, ord=np.inf, axis=1, keepdims=True)
    if 'skewness' in features_to_add:
        feature_dict['skewness'] = skew(data, axis=1).reshape(-1, 1)
    if 'kurtosis' in features_to_add:
        feature_dict['kurtosis'] = kurtosis(data, axis=1).reshape(-1, 1)
    if 'mean' in features_to_add:
        feature_dict['mean'] = np.mean(data, axis=1).reshape(-1, 1)
    if 'median' in features_to_add:
        feature_dict['median'] = np.median(data, axis=1).reshape(-1, 1)
    if 'min' in features_to_add:
        feature_dict['min'] = np.min(data, axis=1).reshape(-1, 1)
    if 'max' in features_to_add:
        feature_dict['max'] = np.max(data, axis=1).reshape(-1, 1)
    if 'std' in features_to_add:
        feature_dict['std'] = np.std(data, axis=1).reshape(-1, 1)
    
    return feature_dict

In [9]:
enriched_features_names = ['L1_norm', 'L_infinity_norm', 'skewness', 'kurtosis', 'mean', 'median', 'min', 'max', 'std']
enriched_features_names_exp = [f'{f}_exp' for f in enriched_features_names]

In [10]:
enriched_features_names = ['L1_norm', 'L_infinity_norm', 'skewness', 'kurtosis', 'mean', 'median', 'min', 'max', 'std']
enriched_features_names_exp = [f'{f}_exp' for f in enriched_features_names]
enriched_features = {}
enriched_features_exp = {}
enriched_features = compute_features(stacked_A_logs, features_to_add=enriched_features_names, feature_dict=enriched_features)
enriched_features_exp = compute_features(exped_log_A, features_to_add=enriched_features_names, feature_dict=enriched_features_exp)

In [11]:
# Create a DataFrame for enriched features
enriched_df = pd.DataFrame({
    **{
        f'{feature}': enriched_features[feature].flatten()
        for feature 
        in enriched_features_names
    },
    **{
        f'{feature}_exp': enriched_features_exp[feature].flatten()
        for feature 
        in enriched_features_names
    },
    'Layer Index': layer_indices,
    'Layer Index str': [f'Layer {i}' for i in layer_indices],
    'Position Index': position_indices
})

enriched_df.to_csv('A_features.csv', index=False)


In [12]:

def plot_feature_interactions(data, feature1, feature2, color_feature='Layer Index str'):
    n_layers = len(data[color_feature].unique())
    colorscale = px.colors.sample_colorscale(
        px.colors.sequential.Plasma, 
        [(i/n_layers) for i in range(n_layers)]
        )
    
    fig = px.scatter(
        data, 
        x=feature1, y=feature2,
        color=color_feature,
        color_discrete_sequence=colorscale,
        hover_data=['Position Index', 'Layer Index'] + enriched_features_names + enriched_features_names_exp,
        opacity=0.1,
        title=f'{feature1} vs {feature2}'
        )
    return fig


In [13]:

# Perform clustering with KMeans
for num_clusters in [2, 3, 4, 5]:
    kmeans = KMeans(n_clusters=num_clusters, random_state=42)
    kmeans.fit(enriched_df[enriched_features_names + enriched_features_names_exp])
    enriched_df[f'Cluster_{num_clusters}_enriched'] = kmeans.labels_
    enriched_df[f'Cluster_{num_clusters}_enriched_str'] = [f'Cluster {i}' for i in kmeans.labels_]
from itertools import combinations


In [41]:
existing = set([x.name[:-5] for x in (PATHS.RUNS_DIR / 'feature_interactions/').glob('*.html')])

In [50]:
all = set()
for feature1, feature2 in combinations(enriched_features_names + enriched_features_names_exp, 2):
    name = f'{feature1}_vs_{feature2}_layer_colored'
    all.add(name)
    for num_clusters in [2, 3, 4, 5]:
        name = f'{feature1}_vs_{feature2}_cluster_{num_clusters}_colored'
        all.add(name)
    

In [51]:
len(all), len(existing), len(all - existing)

(765, 331, 434)

In [None]:
remaining = all - existing
# Create and save a figure for each interaction of features
for feature1, feature2 in combinations(enriched_features_names + enriched_features_names_exp, 2):
    p = PATHS.RUNS_DIR / f'feature_interactions/{feature1}_vs_{feature2}_layer_colored.html'
    if not p.exists():
        fig = plot_feature_interactions(enriched_df, feature1, feature2)
        fig.write_html(p)
        remaining.remove(p.name[:-5])
        print(len(remaining))
        
    
    for num_clusters in [2, 3, 4, 5]:
        p = PATHS.RUNS_DIR / f'feature_interactions/{feature1}_vs_{feature2}_cluster_{num_clusters}_colored.html'
        if not p.exists():
            fig = plot_feature_interactions(enriched_df, feature1, feature2, color_feature=f'Cluster_{num_clusters}_enriched_str')
            fig.write_html(p)
            remaining.remove(p.name[:-5])
            print(len(remaining))

433
432
431
430
429
428
427
426
425
424
423
422
421
420
419
418
417
416
415
414
413
412
411
410
409
408
407
406
405
404
403
402
401
400
399
398
397
396
395
394
393
392
391
390
389
388
387
386
385
384
383
382
381
380
379
378
377
376
375
374
373
372
371
370
369
368
367
366
365
364
363
362
361
360
359
358
357
356
355
354
353
352
351
350
349
348
347
346
345
344
343
342
341
340
339
338
337
336
335
334
333
332
331
330
329
328
327
326
325
324
323
322
321
320
319
318
317
316
315
314
313
312
311
310
309
308
307
306
305
304
303
302
301
300
299
298
297
296
295
294
293
292
291
290
289
288
287
286
285
284
283
282
281
280
279
278
277
276
275
274
273
272
271
270
269
268
267
266
265
264
263
262
261
260
259
258
257
256
255
254
253
252
251
250
249
248
247
246
245
244
243
242
241
240
239
238
237
236
235
234
233
232
231
230
229
228
227
226
225
224
223
222
221
220
219
218
217
216
215
214
213
212
211
210
209
208
207
206
205
204
203
202
201
200
199
198
197
196
195
194
193
192
191
190
189
188
187
186
185
184


In [3]:
import os
import zipfile

# ZIP the files in /home/yandex/DL20232024a/nirendy/repos/ssm_analysis/runs/feature_interactions
zipf = zipfile.ZipFile('feature_interactions.zip', 'w', zipfile.ZIP_DEFLATED)
for root, _, files in os.walk('/home/yandex/DL20232024a/nirendy/repos/ssm_analysis/runs/feature_interactions'):
    for file in tqdm(files, desc="Zipping files"):
         zipf.write(os.path.join(root, file), file)
        
zipf.close()

Zipping files:  33%|███▎      | 108/329 [22:02<46:46, 12.70s/it] 