# Feature Importance Scatter Plots

In [4]:
import pandas as pd
import os
import sys
from pathlib import Path
from scipy.stats import pearsonr, spearmanr
import plotly.express as px
from typing import List
from sklearn.manifold import MDS
import plotly.graph_objects as go
import torch
import numpy as np
import seaborn as sns

from transformers import AutoTokenizer, AutoModelForCausalLM, BertTokenizer, BertModel, AutoModel, AutoConfig


module_path = Path("..")
absolute_module_path = module_path.resolve()
if str(absolute_module_path) not in sys.path:
    sys.path.append(str(absolute_module_path))

  from .autonotebook import tqdm as notebook_tqdm


In [42]:
class Model():

    def __init__(self,
                 model_name: str,
    ):
        self.model_name = model_name
        

    @property
    def model(self):
        if "pythia" in self.model_name:
            return self.get_pythia_model(self.model_name.replace("pythia-", ""))
        if "gpt2" in self.model_name:
            return self.get_gpt2_model(self.model_name.replace("gpt2-", ""))
        if "bert" in self.model_name:
            return self.get_bert_model(self.model_name)
        if "Mistral" in self.model_name:
            return self.get_mistral_model(self.model_name.replace("Mistral-", ""))
        if "mamba" in self.model_name:
            return self.get_mamba_model()
        else:
            raise ValueError(f"Unsupported model: {self.model_name}.")
    
    @property
    def tokenizer(self):
        if "bert" in self.model_name:
            tokenizer = BertTokenizer.from_pretrained(self.model_name)
        else:
            tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    @property
    def n_layers(self):
        return self.model.config.num_hidden_layers

    
    def get_layers_to_enumerate(self):
        if 'gpt' in self.model_name:
            return self.model.transformer.h
        elif 'pythia' in self.model_name:
            return self.model.gpt_neox.layers
        elif 'bert' in self.model_name:
            return self.model.encoder.layer
        elif 'Mistral' in self.model_name:
            return self.model.model.layers
        else:
            raise ValueError(f"Unsupported model: {self.model_name}.")


    def get_model(self, model_name):
        model_config = AutoConfig.from_pretrained(
            model_name,
            output_hidden_states=True,
            ignore_mismatched_sizes=True
        )
        model = AutoModel.from_pretrained(
            model_name,
            config=model_config,
            ignore_mismatched_sizes=True
            #trust_remote_code=True,
            #device_map="auto"
        )
        return model

    def get_pythia_model(self, size: str):
        assert size in ["14m", "31m", "70m", "160m", "410m", "1b", "2.8b", "6.9b", "12b"]
        model_name = f"EleutherAI/pythia-{size}"
        return self.get_model(model_name)

    def get_gpt2_model(self, size: str):
        assert size in ["gpt2", "small", "medium", "large", "xl"]
        if (size == "small" or size=="gpt2"):
            model_name = "gpt2"
        else:
            model_name = f"gpt2-{size}"
        return self.get_model(model_name)

    def get_bert_model(self, model_name: str):
        model = BertModel.from_pretrained(model_name)
        return model

    def get_mistral_model(self, size: str):
        assert size in ["7", "7x8"]
        model_name = f"mistralai/Mistral-{size}B-v0.1"
        return self.get_model(model_name)

    def get_mamba_model(self):
        model_name = f"state-spaces/mamba-130m-hf"
        return self.get_model(model_name)

    

class FeatureImportanceScatterPlot():

    def __init__(self,
                 models_names: list[str],
                 dataset: str,
                 take_activation_from: str = None,
                 fi_data_path: str = '/scratch2/jsalle/MLEM/experiments',
                 save_folder_path: str = '/scratch2/jsalle/MLEM/feature_importance_scatter_plots',
    ):

        assert take_activation_from in ['last-token', 'mean']
        self.models_names = models_names
        self.dataset = dataset
        self.take_activation_from = take_activation_from
        self.fi_data_path = fi_data_path
        self.save_folder_path = save_folder_path

        self.folder_path_for_dataset = os.path.join(save_folder_path, dataset)
        if not os.path.exists(self.folder_path_for_dataset):
            os.makedirs(self.folder_path_for_dataset)
            print(f"Folder '{self.folder_path_for_dataset}' created.")
        else:
            print(f"Folder '{self.folder_path_for_dataset}' already exists.")

    @property
    def models(self):
        """ List of all models defined in self.model_names.
        """
        return [Model(model_name).model for model_name in self.model_names]


    def get_feature_importances(self):
        """ Get the feature importances for all models in self.model_names.
        """

        fi_df = pd.DataFrame()
        for model_name in self.models_names:
            n_layers = Model(model_name).n_layers

            for layer_id in range(1, n_layers+1):
                fi_current_layer = pd.read_csv(os.path.join(self.fi_data_path, 
                                            self.dataset, 
                                            model_name, 
                                            'spe_tok', 
                                            'analysis', 
                                            self.take_activation_from,
                                            'euclidean/zscore_False',
                                            f'layer_{layer_id}',
                                            'min_max_True',
                                            'feature_importance_conditional_False.csv'))
                fi_df = pd.concat([fi_df, fi_current_layer])

        fi_scatter_df = pd.DataFrame()
        for feature in fi_df['Feature'].unique():
            temp = {}
            temp['feature'] = feature
            for model in self.models_names:
                feature_df = fi_df[(fi_df['Feature'] == feature) & (fi_df['_model'] == model)]
                try:
                    temp.update({f"{model}_{i}": np.log(feature_df[feature_df['layer'] == i]['importance'].values[0]) for i in range(1, Model(model).n_layers+1)})
                except: # If the feature importance = 0
                    temp.update({f"{model}_{i}": feature_df[feature_df['layer'] == i]['importance'].values[0] for i in range(1, Model(model).n_layers+1)})
            
            temp_df = pd.DataFrame(temp, index=[0])
            fi_scatter_df = pd.concat([fi_scatter_df, temp_df], ignore_index=True)

        fi_scatter_df.to_csv(os.path.join(self.folder_path_for_dataset,
                                        f'{self.take_activation_from}_fi_scatter.csv'), index=False)

        self.fi_scatter_df = fi_scatter_df # Contains log(feature_importance) for each feature in each layer
        return fi_scatter_df

    
    def pairwise_correlation_matrix(self):

        data = []
        layer_names = self.fi_scatter_df.columns.tolist()
        layer_names.remove('feature')
        num_layers = len(layer_names)

        for i in range(num_layers):
            for j in range(i + 1, num_layers):
                layer1 = layer_names[i]
                layer2 = layer_names[j]
                vec1 = self.fi_scatter_df[layer1].to_numpy()
                vec2 = self.fi_scatter_df[layer2].to_numpy()
                correlation, _ = spearmanr(vec1, vec2)
                print(f"{layer1} - {layer2}: {correlation}")
                
                data.append([layer1, layer2, correlation])

        self.pairwise_correlation_df = pd.DataFrame(data, columns=['layer1', 'layer2', 'correlation'])

        return self.pairwise_correlation_df


    def plot_fi_scatter(self,
                        layer_id1: int,
                        layer_id2: int):

        fi_scatter_df = self.fi_scatter_df
        # Give one color per feature
        colors = sns.color_palette("colorblind")
        feature_names = sorted(fi_scatter_df["feature"].unique().tolist(), reverse=True)
        feature_color_mapping = {feature: colors[i] for i, feature in enumerate(feature_names)}
        feature_color_mapping_hex = {feature: f'rgb({int(color[0]*255)}, {int(color[1]*255)}, {int(color[2]*255)})' for feature, color in feature_color_mapping.items()}
        
        new_features_names = {'lexical_n_letters': 'word length',
                              'grammatical_number': 'grammatical number',
                              'grammatical_gender': 'grammatical gender',
                              'grammatical_pos': 'part-of-speech',
                              'grammatical_tense': 'tense',
                              'grammatical_person': 'person',
                              'sequential_posit_from_start': 'word position',
                              'sentential_question': 'question'}
        
        for index, row in fi_scatter_df.iterrows():
            feature = row['feature']
            if feature in new_features_names:
                fi_scatter_df.loc[index, 'feature'] = new_features_names[feature]

        for col in fi_scatter_df.columns:
            renamed_bert_columns = {}
            if 'bert' in col:
                layer_number = int(col.split('_')[-1])
                renamed_bert_columns[col] = f'bert_{layer_number}'
            fi_scatter_df.rename(columns=renamed_bert_columns, inplace=True)

        fi_scatter_df['max_value_pre'] = fi_scatter_df[[layer_id1, layer_id2]].max(axis=1) 
        fi_scatter_df['max_value'] = fi_scatter_df['max_value_pre'] + fi_scatter_df['max_value_pre'].min() + 1

        fig = px.scatter(fi_scatter_df, 
                         x=layer_id1, 
                         y=layer_id2, 
                         color='feature',
                         color_discrete_map=feature_color_mapping_hex,
                         size='max_value')

        
        min_ = min(fi_scatter_df[layer_id1].min(), fi_scatter_df[layer_id2].min())
        max_ = max(fi_scatter_df[layer_id1].max(), fi_scatter_df[layer_id2].max())

        fig.update_layout(
            xaxis=dict(
                range=[min_-0.01, 
                       max_+0.01],
                zeroline=False,
                showline=True,
                mirror=True,
                ticks='outside',
                showgrid=False,
            ),
            yaxis=dict(
                range=[min_-0.01,
                       max_+0.01],
                zeroline=False,
                showline=True,
                mirror=True,
                ticks='outside',
                showgrid=False,
            ),
            plot_bgcolor='white',
            width=600, 
            height=500
        )

        # Add diagonal line
        fig.add_shape(
            type='line',
            xref="paper", yref="paper",
            x0=0,
            y0=0,
            x1=1,
            y1=1,
            line=dict(
                color='grey',
                dash='dash'
            )
        )

        vec1 = self.fi_scatter_df[layer_id1].to_numpy()
        vec2 = self.fi_scatter_df[layer_id2].to_numpy()
        correlation, _ = spearmanr(vec1, vec2)
        print('Spearman coeff:', correlation)

        return fig
    

    def plot_pairwise_correlation_matrix(self):
        def sort_key(layer_name):
            model_name, layer_id = layer_name.split('_')
            return (model_name, int(layer_id))
        
        layer_names = self.fi_scatter_df.columns.tolist()
        layer_names.remove('feature')
        df = self.pairwise_correlation_df

        layer_names_sorted = sorted(layer_names, key=sort_key)
        layer_order = {layer: i for i, layer in enumerate(layer_names_sorted)}

        df['layer1_ordered'] = df['layer1'].map(layer_order)
        df['layer2_ordered'] = df['layer2'].map(layer_order)
        df = df.sort_values(by=['layer1_ordered', 
                                'layer2_ordered']).drop(columns=['layer1_ordered', 'layer2_ordered'])

        pivot_df = df.pivot(index='layer1', columns='layer2', values='correlation')
        pivot_df = pivot_df.T.combine_first(pivot_df).fillna(1)
        pivot_df = pivot_df.reindex(index=layer_names_sorted, columns=layer_names_sorted)
        fig = px.imshow(pivot_df, text_auto='.2f', color_continuous_scale='Viridis', 
                        labels=dict(x="Layer", y="Layer", color="Correlation"))

        fig.update_layout(title='Correlation Between ', xaxis_title='Layer', yaxis_title='Layer')

        pivot_df.to_csv(os.path.join(self.folder_path_for_dataset,
                                    f'{self.take_activation_from}_pairwise_scatter_corr.csv'),
                            index=False)

        return fig

In [43]:
for dataset in ['svo_long_word_level_offset_simplif']:
    for activation in ['last-token']:

        fi_scatter_plot = FeatureImportanceScatterPlot(models_names=["gpt2", "bert-base-uncased", "mamba"],
                                                    dataset=dataset,
                                                    take_activation_from=activation)

        fi_scatter_df = fi_scatter_plot.get_feature_importances()
        pairwise_correlation = fi_scatter_plot.pairwise_correlation_matrix()
        fi_scatter_plot.plot_pairwise_correlation_matrix()

Folder '/scratch2/jsalle/MLEM/feature_importance_scatter_plots/svo_long_word_level_offset_simplif' already exists.



divide by zero encountered in log


invalid value encountered in log



gpt2_1 - gpt2_2: 0.9047619047619048
gpt2_1 - gpt2_3: 0.880952380952381
gpt2_1 - gpt2_4: 0.9047619047619048
gpt2_1 - gpt2_5: 0.9285714285714287
gpt2_1 - gpt2_6: 0.9047619047619048
gpt2_1 - gpt2_7: 0.8571428571428572
gpt2_1 - gpt2_8: 0.8571428571428572
gpt2_1 - gpt2_9: 0.7380952380952381
gpt2_1 - gpt2_10: 0.7380952380952381
gpt2_1 - gpt2_11: 0.7142857142857144
gpt2_1 - gpt2_12: nan
gpt2_1 - bert-base-uncased_1: 0.880952380952381
gpt2_1 - bert-base-uncased_2: 0.6190476190476191
gpt2_1 - bert-base-uncased_3: 0.8095238095238096
gpt2_1 - bert-base-uncased_4: 0.6190476190476191
gpt2_1 - bert-base-uncased_5: 0.5952380952380953
gpt2_1 - bert-base-uncased_6: 0.523809523809524
gpt2_1 - bert-base-uncased_7: 0.4523809523809524
gpt2_1 - bert-base-uncased_8: 0.28571428571428575
gpt2_1 - bert-base-uncased_9: 0.0
gpt2_1 - bert-base-uncased_10: 0.0
gpt2_1 - bert-base-uncased_11: 0.14285714285714288
gpt2_1 - bert-base-uncased_12: 0.04761904761904763
gpt2_1 - mamba_1: 0.8095238095238096
gpt2_1 - mamba_2: 

In [44]:
fig = fi_scatter_plot.plot_fi_scatter(layer_id1='gpt2_7', layer_id2='bert_1')
fig.show()

ValueError: 
    Invalid element(s) received for the 'size' property of scatter.marker
        Invalid elements include: [-15.894783790141567]

    The 'size' property is a number and may be specified as:
      - An int or float in the interval [0, inf]
      - A tuple, list, or one-dimensional numpy array of the above

In [45]:
fi_scatter_plot = FeatureImportanceScatterPlot(models_names=["gpt2"],
                                               dataset='svo_long_word_level_offset_simplif',
                                               take_activation_from='last-token')

fi_scatter_df = fi_scatter_plot.get_feature_importances()
pairwise_correlation = fi_scatter_plot.pairwise_correlation_matrix()
fi_scatter_plot.plot_pairwise_correlation_matrix()

Folder '/scratch2/jsalle/MLEM/feature_importance_scatter_plots/svo_long_word_level_offset_simplif' already exists.


In [52]:
fi_scatter_df['max_value_pre']

0   -7.229649
1   -8.296893
2   -5.874147
3   -4.979906
4   -5.981000
5   -9.665135
6   -5.685040
7   -9.149987
Name: max_value_pre, dtype: float64

In [57]:
fi_scatter_df['now'] = fi_scatter_df['max_value_pre'] - fi_scatter_df['max_value_pre'].min() + 1

In [58]:
fi_scatter_df['now']

0    2.435487
1    1.368243
2    3.790989
3    4.685230
4    3.684135
5    0.000000
6    3.980095
7    0.515148
Name: now, dtype: float64