In [1]:
# Model identifier, presumably the one the wandb sweep below was run against (kept for
# provenance; not referenced by the visible cells). Earlier runs used "meta-llama/Llama-2-7b-hf".
MODEL = "meta-llama/Meta-Llama-3.1-8B"

In [4]:
import tqdm
import pandas as pd 
import wandb
import functools


@functools.cache
def _fetch_runs_df(path):
    """Download every run of a wandb project once and cache the resulting frame.

    Private: go through ``get_df_from_wandb``, which hands out copies so this
    cached frame is never mutated by callers.
    """
    api = wandb.Api()

    # Project is specified by <entity/project-name>
    runs = api.runs(path)

    rows = []
    for run in tqdm.tqdm(runs):
        rows.append({
            'Name': run.name,
            'Commit': run.commit,
            # Summary values stored on the run.
            **run.summary._json_dict,
            # Run config, skipping keys that start with '_' (presumably wandb-internal).
            **{k: v for k, v in run.config.items() if not k.startswith('_')},
        })
    return pd.DataFrame(rows)


def get_df_from_wandb(path):
    """Return one row per wandb run of ``path`` as a pandas DataFrame.

    Args:
        path: wandb project, in ``"<entity>/<project-name>"`` form.

    Returns:
        A DataFrame with columns ``Name``, ``Commit``, every summary key and
        every config key not starting with ``'_'``. The download is cached per
        ``path``, but each call returns a fresh *copy*, so callers may add or
        overwrite columns without corrupting the result of later calls.
    """
    return _fetch_runs_df(path).copy()


# Keep the cache-control hook that the previous ``@functools.cache`` version exposed.
get_df_from_wandb.cache_clear = _fetch_runs_df.cache_clear

In [5]:
# One row per wandb run of the single-linear-layer noise sweep.
runs_df = get_df_from_wandb('rock-and-roll/NOISY_MSE_ONE_LINEAR_LAYER_GALQIWI_PPL')

# Build a new frame rather than adding columns to `runs_df`: get_df_from_wandb caches
# its result, so in-place edits here could leak into what later calls return.
data_df = (
    runs_df
    .assign(layer=runs_df['layer_name'], mse=runs_df['relative_mse'])
    [['layer', 'mse', 'wikitext2']]
    .dropna()
    .copy()
)

34504it [04:58, 115.58it/s]                                                     


In [6]:
# Distinct layer names that have at least one run left after the dropna above, sorted.
layers = sorted(data_df['layer'].drop_duplicates())
len(layers)

224

In [None]:
from scipy.stats import linregress

# Fit only low-error runs, where wikitext2 is presumably ~linear in mse.
# The plotting cell below applies its own (different) cut.
MSE_FIT_THRESHOLD = 4 ** -1.7

# Filter once and split by layer instead of rescanning the whole frame for every layer.
fit_rows = data_df[data_df['mse'] < MSE_FIT_THRESHOLD]
fit_rows_by_layer = {layer: rows for layer, rows in fit_rows.groupby('layer')}

# Least-squares slope of wikitext2 w.r.t. mse for each layer, keyed in `layers` order.
scale_by_layer = {}
for layer in layers:
    to_fit = fit_rows_by_layer.get(layer)
    # linregress raises ValueError on empty input or when all x values are identical;
    # record NaN for such layers instead of aborting the whole sweep.
    if to_fit is None or to_fit['mse'].nunique() < 2:
        scale_by_layer[layer] = float('nan')
        continue
    scale_by_layer[layer] = linregress(to_fit['mse'], to_fit['wikitext2']).slope

In [None]:
# Snapshot the current slopes so a later refit (e.g. with another MSE threshold) can be
# compared against them. `dict(...)` copies; plain assignment would only alias the dict.
scale_by_layer_old = dict(scale_by_layer)

In [None]:
import matplotlib.pyplot as plt  # needed here: this cell runs before the cell that imports pyplot

# Overlay the per-layer slope distributions: snapshot (`scale_by_layer_old`) vs. latest fit.
# Slopes >= SCALE_CLIP are left out so the bulk of the distribution stays readable.
SCALE_CLIP = 0.6

fig, ax = plt.subplots()
for label, scales in (('old', scale_by_layer_old), ('new', scale_by_layer)):
    ax.hist([s for s in scales.values() if s < SCALE_CLIP], bins=10, alpha=0.5, label=label)
ax.set(xlabel='fitted slope (wikitext2 per unit mse)', ylabel='number of layers')
ax.legend();

In [None]:
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Show only low-MSE runs (a tighter cut than the one used for fitting).
MSE_PLOT_THRESHOLD = 4 ** -2

cmap = matplotlib.colormaps["plasma"]

# Each fitted line is anchored at the lowest wikitext2 over ALL runs, not at the
# per-layer regression intercept (presumably approximating the uncompressed
# baseline — TODO confirm). Loop-invariant, so compute it once.
wikitext2_floor = data_df['wikitext2'].min()

fig, ax = plt.subplots()
for layer_idx, layer in enumerate(layers):
    # Colour encodes the layer's position in the sorted `layers` list.
    color = cmap(layer_idx / len(layers))
    to_plot = data_df[(data_df['layer'] == layer) & (data_df['mse'] < MSE_PLOT_THRESHOLD)]

    ax.scatter(to_plot['mse'], to_plot['wikitext2'], color=color)

    grid = np.linspace(to_plot['mse'].min(), to_plot['mse'].max(), 10)
    ax.plot(grid, wikitext2_floor + grid * scale_by_layer[layer], color=color)

ax.set(xlabel='mse', ylabel='wikitext2', title='wikitext2 vs. mse per layer (low-mse runs)');

In [None]:
# Per-layer fitted slopes as a table, smallest slope first.
pd.DataFrame({
    'layer': list(scale_by_layer.keys()),
    'scale': list(scale_by_layer.values()),
}).sort_values('scale')

In [None]:
# Distribution of the latest per-layer slopes, leaving out slopes >= 0.6.
clipped_scales = sorted(scale for scale in scale_by_layer.values() if scale < 0.6)
plt.hist(clipped_scales, bins=10)

In [72]:
!git add . && git commit -m 'linear-layer-compression' && git push

[galqiwi 38782d9] linear-layer-compression
 2 files changed, 769 insertions(+)
 create mode 100644 Vladimir/2024-09-02/linear-layer-compression/.ipynb_checkpoints/values-checkpoint.ipynb
 create mode 100644 Vladimir/2024-09-02/linear-layer-compression/values.ipynb
Counting objects: 8, done.
Delta compression using up to 255 threads.
Compressing objects: 100% (8/8), done.
Writing objects: 100% (8/8), 44.77 KiB | 14.92 MiB/s, done.
Total 8 (delta 0), reused 0 (delta 0)
To github.com:galqiwi/linear-layer-compression.git
   528f9c4..38782d9  galqiwi -> galqiwi
