# Loading and Analysing Pre-Trained Sparse Autoencoders

## Imports & Installs

## Set Up

In [4]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import plotly.express as px
import pandas as pd
import json
import numpy as np
import math
import gc
import pandas as pd
import random
import shutil
import networkx as nx

from collections import Counter
from functools import partial
from tqdm import tqdm
from faker import Faker

import torch
torch.set_grad_enabled(False);
from openai import AzureOpenAI
from datasets import load_dataset  
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig

import transformer_lens
from transformer_lens import utils
from transformer_lens import HookedTransformer
from transformer_lens.utils import tokenize_and_concatenate

from sae_lens import SAE
from sae_lens.config import DTYPE_MAP, LOCAL_SAE_MODEL_PATH
from sae_lens.analysis.neuronpedia_integration import get_neuronpedia_quick_list

from sae_vis.data_config_classes import SaeVisConfig
from sae_vis.data_storing_fns import SaeVisData

from causallearn.search.ConstraintBased.FCI import fci
from causallearn.search.ConstraintBased.PC import pc
from causallearn.search.ScoreBased.ExactSearch import bic_exact_search
from causallearn.utils.cit import kci
from causallearn.utils.GraphUtils import GraphUtils
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
from causallearn.utils.cit import fisherz

In [4]:
import os
import pandas as pd
import numpy as np


answer_valuebench_features_csv_gemma_train = os.path.join('useful_data',"ans_gemma_train_formal.csv")
data_csv_gemma_train = pd.read_csv(answer_valuebench_features_csv_gemma_train)

answer_valuebench_features_csv_gemma_test = os.path.join('useful_data',"ans_gemma_test_formal.csv")
data_csv_gemma_test = pd.read_csv(answer_valuebench_features_csv_gemma_test)

def get_data_new_diff(data_csv_train, modelname):
    pathname = 'value_dims_rsd_' + modelname
    stat_csv_23 = pathname + '/23_stat.csv'
    data_new_diff_count_total = pd.DataFrame()

    os.makedirs(pathname, exist_ok=True)
    for column in data_csv_train.columns:
        if column == 'player_name' or column == 'steer_dim' or column == 'stds' or column =='scstds' or column.endswith(':scstd'):
            continue
        value_csv = pathname + '/' + column +'.csv'
        data_new = data_csv_train.pivot(index='steer_dim', columns='player_name', values=column)
        data_new_scstd = data_csv_train.pivot(index='steer_dim', columns='player_name', values=column+':scstd')
        data_save = data_new.astype(str) + '±' + data_new_scstd.astype(str) #problems here: the scstd is not the std for the score, but fore the changed score
        data_save.to_csv(value_csv)

        # data_new_diff = data_save.copy()
        # for col in data_new.columns:
        #     data_new_diff[col] = data_new[col].apply(lambda x: x.split('±')[0])
        # data_new_diff = data_new_diff.astype(float)
        # data_new_diff = data_new_diff - data_new_diff[data_new_diff.index.isnull()].iloc[0]

        #For each row count the number of cells that are higher, lower, or equal than 0
        # data_new_diff_count_higher = data_new_diff.apply(lambda x: x.apply(lambda y: 1 if y > 0 else 0))
        # data_new_diff_count_higher = data_new_diff_count_higher.sum(axis=1)
        # data_new_diff_count_lower = data_new_diff.apply(lambda x: x.apply(lambda y: 1 if y < 0 else 0))
        # data_new_diff_count_lower = data_new_diff_count_lower.sum(axis=1)
        # data_new_diff_count_equal = data_new_diff.apply(lambda x: x.apply(lambda y: 1 if y == 0 else 0))
        # data_new_diff_count_equal = data_new_diff_count_equal.sum(axis=1)
        

        # data_new_diff_count_higher = data_new_diff.apply(lambda x: x.apply(lambda y: 1 if float(y.split('±')[0]) > 0 else 0))
        # data_new_diff_count_higher = data_new_diff_count_higher.sum(axis=1)
        # data_new_diff_count_lower = data_new_diff.apply(lambda x: x.apply(lambda y: 1 if float(y.split('±')[0]) < 0 else 0))
        # data_new_diff_count_lower = data_new_diff_count_lower.sum(axis=1)
        # data_new_diff_count_equal = data_new_diff.apply(lambda x: x.apply(lambda y: 1 if float(y.split('±')[0]) == 0 else 0))
        # data_new_diff_count_equal = data_new_diff_count_equal.sum(axis=1)
        


        #calculate the difference between the score and the score of the first player while keeping the scstd
        data_new_diff = data_new - data_new[data_new.index.isnull()].iloc[0]
        data_new_diff = data_new_diff.astype(str) + '±' + data_new_scstd.astype(str)
        
        data_new_diff_count_higher = data_new_diff.apply(lambda x: x.apply(lambda y: 1 if float(y.split('±')[0]) > 0 and float(y.split('±')[1]) >= -2 else 0))
        data_new_diff_count_higher = data_new_diff_count_higher.sum(axis=1)
        data_new_diff_count_lower = data_new_diff.apply(lambda x: x.apply(lambda y: 1 if float(y.split('±')[0]) < 0 and float(y.split('±')[1]) >= -2 else 0))
        data_new_diff_count_lower = data_new_diff_count_lower.sum(axis=1)
        data_new_diff_count_equal = data_new_diff.apply(lambda x: x.apply(lambda y: 1 if float(y.split('±')[0]) == 0 or float(y.split('±')[1]) == -2 else 0))
        data_new_diff_count_equal = data_new_diff_count_equal.sum(axis=1)

        
        
        #put theses counts as strings in one cell
        data_new_diff_count = data_new_diff_count_higher.astype(str) + '/' + data_new_diff_count_lower.astype(str) + '/' + data_new_diff_count_equal.astype(str)
        #Merge to the total table
        data_new_diff_count_total[column] = data_new_diff_count

    data_new_diff_count_total.to_csv(stat_csv_23)

get_data_new_diff(data_csv_gemma_train, 'gemma')
get_data_new_diff(data_csv_gemma_test, 'gemmatest')






In [5]:
threshold_ss = 0.65
threshold_maintain = 0.85
threshold_non = 0.2
threshold_judge = 0


def get_table1(data_csv_train, data_csv_test, stat_csv_23):
    data_new_diff_count_total = pd.read_csv(stat_csv_23)

    table1_columns = data_new_diff_count_total['steer_dim'].unique()
    table1_columns = table1_columns[~np.isnan(table1_columns)]
    value_dims = data_new_diff_count_total.columns[1:]
    table1 = pd.DataFrame(columns=table1_columns, index=value_dims)


    players_list_train = data_csv_train['player_name'].unique()
    #players_list_train = players_list_local[~pd.isnull(players_list_local)]

    players_list_test = data_csv_test['player_name'].unique()
    players_list_test = players_list_test[~pd.isnull(players_list_test)]

    standard_data = data_csv_test[data_csv_test['steer_dim'].isnull()]

    for steer_dim in table1_columns:
        assert not np.isnan(steer_dim)
        print(steer_dim)

        steer_dim_row = data_new_diff_count_total[data_new_diff_count_total['steer_dim'] == steer_dim]
        stimulated_dims = []
        suppressed_dims = []
        maintained_dims = []
        non_suppressed_dims = []
        non_stimulated_dims = []
        uncontrolled_dims = []

        for column in value_dims:
            assert column != 'steer_dim'
            #split cell by /
            counts = steer_dim_row[column].values[0].split('/')   
            if int(counts[0]) / len(players_list_train) > threshold_ss:
                stimulated_dims.append(column)
            elif int(counts[1]) / len(players_list_train) > threshold_ss:
                suppressed_dims.append(column)
            elif int(counts[2]) / len(players_list_train) > threshold_maintain:
                maintained_dims.append(column)
            elif int(counts[1]) / len(players_list_train) < threshold_non:
                non_suppressed_dims.append(column)
            elif int(counts[0]) / len(players_list_train) < threshold_non:
                non_stimulated_dims.append(column)
            else:
                uncontrolled_dims.append(column)

        steer_dim_data = data_csv_test[data_csv_test['steer_dim'] == steer_dim]
        for value_dim in stimulated_dims:
            count_correct_steer = 0
            for player_name in players_list_test:
                steered_player_data = steer_dim_data[steer_dim_data['player_name'] == player_name][value_dim].values[0]
                standard_player_data = standard_data[standard_data['player_name'] == player_name][value_dim].values[0]
                if steered_player_data - standard_player_data > threshold_judge:
                    count_correct_steer += 1
            print(value_dim, 'SITMULATE', count_correct_steer / len(players_list_test), count_correct_steer)
            #edit the table
            table1.loc[value_dim, steer_dim] = 'STIMULATE,' + str(count_correct_steer / len(players_list_test))

        for value_dim in suppressed_dims:
            count_correct_steer = 0
            for player_name in players_list_test:
                steered_player_data = steer_dim_data[steer_dim_data['player_name'] == player_name][value_dim].values[0]
                standard_player_data = standard_data[standard_data['player_name'] == player_name][value_dim].values[0]
                if -(steered_player_data - standard_player_data) > threshold_judge:
                    count_correct_steer += 1
            print(value_dim, 'SUPPRESS', count_correct_steer / len(players_list_test), count_correct_steer)
            table1.loc[value_dim, steer_dim] = 'SUPPRESS,' + str(count_correct_steer / len(players_list_test))
            
        for value_dim in non_suppressed_dims:
            count_correct_steer = 0
            for player_name in players_list_test:
                steered_player_data = steer_dim_data[steer_dim_data['player_name'] == player_name][value_dim].values[0]
                standard_player_data = standard_data[standard_data['player_name'] == player_name][value_dim].values[0]
                if steered_player_data - standard_player_data >= -threshold_judge:
                    count_correct_steer += 1
            print(value_dim, 'NON_SUPPRESS', count_correct_steer / len(players_list_test), count_correct_steer)
            table1.loc[value_dim, steer_dim] = 'NON_SUPPRESS,' + str(count_correct_steer / len(players_list_test))

        for value_dim in non_stimulated_dims:
            count_correct_steer = 0
            for player_name in players_list_test:
                steered_player_data = steer_dim_data[steer_dim_data['player_name'] == player_name][value_dim].values[0]
                standard_player_data = standard_data[standard_data['player_name'] == player_name][value_dim].values[0]
                if steered_player_data - standard_player_data <= threshold_judge:
                    count_correct_steer += 1
            print(value_dim, 'NON_STIMULATE', count_correct_steer / len(players_list_test), count_correct_steer)   
            table1.loc[value_dim, steer_dim] = 'NON_STIMULATE,' + str(count_correct_steer / len(players_list_test))
        
        for value_dim in maintained_dims:
            count_correct_steer = 0
            for player_name in players_list_test:
                steered_player_data = steer_dim_data[steer_dim_data['player_name'] == player_name][value_dim].values[0]
                standard_player_data = standard_data[standard_data['player_name'] == player_name][value_dim].values[0]
                if abs(steered_player_data - standard_player_data) <= threshold_judge:
                    count_correct_steer += 1
            print(value_dim, 'MAINTAIN', count_correct_steer / len(players_list_test), count_correct_steer)
            table1.loc[value_dim, steer_dim] = 'MAINTAIN,' + str(count_correct_steer / len(players_list_test))
    return table1


table1_gemma = get_table1(data_csv_gemma_train, data_csv_gemma_test, 'value_dims_rsd_gemma/23_stat.csv')

428.0
Anxiety Disorder NON_SUPPRESS 0.68 17
Economic NON_SUPPRESS 0.96 24
Organization NON_SUPPRESS 0.88 22
Political NON_SUPPRESS 1.0 25
Positive coping NON_SUPPRESS 0.8 20
Resilience NON_SUPPRESS 0.96 24
Theoretical NON_SUPPRESS 1.0 25
Uncertainty Avoidance NON_SUPPRESS 0.76 19
Achievement NON_STIMULATE 0.6 15
Aesthetic NON_STIMULATE 0.8 20
Breadth of Interest NON_STIMULATE 0.92 23
Religious NON_STIMULATE 0.96 24
Social NON_STIMULATE 0.8 20
Social Complexity NON_STIMULATE 0.6 15
Understanding NON_STIMULATE 0.88 22
Social Cynicism MAINTAIN 0.4 10
1025.0
Economic NON_SUPPRESS 0.92 23
Organization NON_SUPPRESS 1.0 25
Positive coping NON_SUPPRESS 0.8 20
Religious NON_SUPPRESS 1.0 25
Resilience NON_SUPPRESS 0.96 24
Uncertainty Avoidance NON_SUPPRESS 0.6 15
Understanding NON_SUPPRESS 0.84 21
Achievement NON_STIMULATE 0.8 20
Aesthetic NON_STIMULATE 0.88 22
Breadth of Interest NON_STIMULATE 0.96 24
Empathy NON_STIMULATE 0.88 22
Political NON_STIMULATE 0.96 24
Social Complexity NON_STIMULATE 

In [None]:
def get_latex_table1_deprecated(table1, table1_name):
    latex_code = '\\begin{table*}[ht]\n\\caption{Value steering using SAE features for the Gemma-2B-IT model. Expected stimulated values are highlighted in red, along with their actual success rate during testing. Expected suppressed values are marked in Purple. Maintained values are shown in gray. Light red indicates values that are expected to be at least not suppressed, while light purple represents values that are expected to be at least not stimulated. Blank cells correspond to uncontrollable values. The bottom of the table indicates the count of each of the six expected categories and their average success rates.}\n\\label{table: sae-steering-gemma}\n\\begin{center}\n\\scalebox{0.5}{'
    #latex_code = '\\begin{table}[ht]\n\\caption{Value steering using SAE features for the Llama3-8B-IT model.}\n\\label{table: sae-steering-llama}\n\\begin{center}\n'

    latex_code += '\\begin{tabular}{c@{\\hspace{2pt}}' + 'c@{\\hspace{2pt}}' * (len(table1.columns) - 1) + 'c' + '}\n\\toprule\n'
    #transfer table1.columns to a list of str


    steering_features = list(map(str, map(int, table1.columns)))
    latex_code += 'Value & ' + ' & '.join(['\\bf ' + tc for tc in steering_features]) + ' \\\\\n\\hline\n'
    #

    stimulated_dim_avg_success = {sf: [] for sf in steering_features}
    stimulhalf_dim_avg_success = {sf: [] for sf in steering_features}
    suppressed_dim_avg_success = {sf: [] for sf in steering_features}
    supprehalf_dim_avg_success = {sf: [] for sf in steering_features}
    maintained_dim_avg_success = {sf: [] for sf in steering_features}


    uncontroll_dims = {sf: 0 for sf in steering_features}
    stimulated_dims = {sf: 0 for sf in steering_features} 
    suppressed_dims = {sf: 0 for sf in steering_features}
    stimulhalf_dims = {sf: 0 for sf in steering_features}
    supprehalf_dims = {sf: 0 for sf in steering_features}
    maintained_dims = {sf: 0 for sf in steering_features}


    for index, row in table1.iterrows():
        #if value's name (index) is too long, make its font smaller, all value names should be available in 3pt
        if len(index) > 20:
            latex_code += '\\tiny ' + index + ' & '
        else:
            latex_code += '\\small ' + index + ' & '

        for value, sf in zip(row, steering_features):
            if type(value) == str:
                print(value)
                value = value.split(',')
                
                if value[0] == 'STIMULATE':
                    stimulated_dim_avg_success[sf].append(float(value[1]))
                    stimulated_dims[sf] += 1
                    #latex_code += '\\textcolor{red}{\\textbf{$\\uparrow$}}' + ' ' + f"{float(value[1]):.2f}" + ' & '
                    latex_code += '\\colorbox{red!50}' + ' ' + f"{float(value[1]):.2f}" + ' & '
                    #latex_code += f"{float(value[1]):.2f}" + ' & '
                elif value[0] == 'NON_SUPPRESS':
                    stimulhalf_dim_avg_success[sf].append(float(value[1]))
                    stimulhalf_dims[sf] += 1
                    #latex_code += '\\textcolor{magenta}{\\textbf{$\\nearrow$}}' + ' ' + f"{float(value[1]):.2f}" + ' & '
                    latex_code += '\\colorbox{red!20}' + ' ' + f"{float(value[1]):.2f}" + ' & '
                    #latex_code += f"{float(value[1]):.2f}" + ' & '
                elif value[0] == 'SUPPRESS':
                    suppressed_dim_avg_success[sf].append(float(value[1]))
                    suppressed_dims[sf] += 1
                    #latex_code += '\\textcolor{blue}{\\textbf{$\\downarrow$}}' + ' ' + f"{float(value[1]):.2f}" + ' & '
                    latex_code += '\\colorbox{blue!50}' + ' ' + f"{float(value[1]):.2f}" + ' & '
                    #latex_code += f"{float(value[1]):.2f}" + ' & '
                elif value[0] == 'NON_STIMULATE':
                    supprehalf_dim_avg_success[sf].append(float(value[1]))
                    supprehalf_dims[sf] += 1
                    #latex_code += '\\textcolor{cyan}{\\textbf{$\\searrow$}}' + ' ' + f"{float(value[1]):.2f}" + ' & '
                    latex_code += '\\colorbox{blue!20}' + ' ' + f"{float(value[1]):.2f}" + ' & '
                    #latex_code += f"{float(value[1]):.2f}" + ' & '
                elif value[0] == 'MAINTAIN':
                    maintained_dim_avg_success[sf].append(float(value[1]))
                    maintained_dims[sf] += 1
                    #latex_code += '\\textcolor{purple}{\\textbf{-}}' + ' ' + f"{float(value[1]):.2f}" + ' & '
                    latex_code += '\\colorbox{gray!20}' + ' ' + f"{float(value[1]):.2f}" + ' & '
                    #latex_code += f"{float(value[1]):.2f}" + ' & '
                else:
                    raise ValueError('Invalid value')
            else:
                assert np.isnan(value)
                uncontroll_dims[sf] += 1
                #latex_code += '\\textcolor{gray}{-} & '
                latex_code += f"-" + ' & '
                #latex_code += '- & '
        latex_code = latex_code[:-2] + ' \\\\\n'
    latex_code = latex_code + ' \\midrule\n'

    for sf in steering_features:
        stimulated_dim_avg_success[sf] = np.mean(stimulated_dim_avg_success[sf])
        stimulhalf_dim_avg_success[sf] = np.mean(stimulhalf_dim_avg_success[sf])
        suppressed_dim_avg_success[sf] = np.mean(suppressed_dim_avg_success[sf])
        supprehalf_dim_avg_success[sf] = np.mean(supprehalf_dim_avg_success[sf])
        maintained_dim_avg_success[sf] = np.mean(maintained_dim_avg_success[sf])
        
    latex_code += '\\colorbox{red!50} STIMULATE & '
    for sf in steering_features:
        cellcontent = round(stimulated_dim_avg_success[sf],3)
        latex_code += '\\textbf{' + str(stimulated_dims[sf]) + f'({cellcontent})' +'} & '
    latex_code = latex_code[:-2] + ' \\\\\n'

    latex_code += '\\colorbox{blue!50} SUPPRESSED & '
    for sf in steering_features:
        cellcontent = round(suppressed_dim_avg_success[sf],3)
        latex_code += '\\textbf{' + str(suppressed_dims[sf]) + f'({cellcontent})' +'} & '
    latex_code = latex_code[:-2] + ' \\\\\n'

    latex_code += '\\colorbox{red!20} NON-SUPPRESSED & '
    for sf in steering_features:
        cellcontent = round(stimulhalf_dim_avg_success[sf],3)
        latex_code += '\\textbf{' + str(stimulhalf_dims[sf]) + f'({cellcontent})' +'} & '
    latex_code = latex_code[:-2] + ' \\\\\n'

    latex_code +='\\colorbox{blue!20} NON-STIMULATED & '
    for sf in steering_features:
        cellcontent = round(supprehalf_dim_avg_success[sf],3)
        latex_code += '\\textbf{' + str(supprehalf_dims[sf]) + f'({cellcontent})' +'} & '
    latex_code = latex_code[:-2] + ' \\\\\n'

    latex_code += '\\colorbox{gray!20} MAINTAINED & '
    for sf in steering_features:
        cellcontent = round(maintained_dim_avg_success[sf],3)
        latex_code += '\\textbf{' + str(maintained_dims[sf]) + f'({cellcontent})' +'} & '
    latex_code = latex_code[:-2] + ' \\\\\n'

    latex_code += 'UNCONTROLLED & '
    for sf in steering_features:
        latex_code += '\\textbf{' + str(uncontroll_dims[sf]) +'} & '

    latex_code = latex_code[:-2] + ' \\\\\n\\bottomrule\n'
    latex_code += '\\end{tabular}\n}\n\\end{center}\n\\end{table*}'
    print(latex_code)
    #write the latex code to a file
    with open(table1_name+'.tex', 'w') as f:
        f.write(latex_code)

get_latex_table1_deprecated(table1_gemma, 'table1_gemma')

#OK, nice job. Now let's make another form of the latex table. This time, the rows will be the steering features and the columns will be the value dimensions. 
#The cells will contain the success rate of steering the value dimension using the steering feature.
#To avoid making the table too wide, the string of value dimensions will be rotated 90 degrees.
#Let's begin
def get_latex_table1_rotate_deprecated(table1, table1_name):
    latex_code = '\\begin{table*}[ht]\n\\caption{Value steering using SAE features for the Gemma-2B-IT model. Expected stimulated values are highlighted in red, along with their actual success rate during testing. Expected suppressed values are marked in Purple. Maintained values are shown in gray. Light red indicates values that are expected to be at least not suppressed, while light purple represents values that are expected to be at least not stimulated. Blank cells correspond to uncontrollable values. The bottom of the table indicates the count of each of the six expected categories and their average success rates.}\n\\label{table: sae-steering-gemma}\n\\begin{center}\n\\scalebox{0.5}{'

    latex_code += '\\begin{tabular}{' + 'c@{\\hspace{1.5pt}}|' * (len(table1.index) + 6) + 'c' + '}\n\\toprule\n'
    #latex_code += '\\begin{tabular}{' + 'c|' * len(table1.index) + 'c' + '}\n\\toprule\n'
    #transfer table1.columns to a list of str

    value_dims = list(map(str, table1.index))
    steering_features = table1.columns
    latex_code += 'Value & ' + ' & '.join(['\\rotatebox{90}{\\bf ' + tc +'}' for tc in value_dims]) + ' &\\rotatebox{90} {\\bf STIMULATE} & \\rotatebox{90} {\\bf SUPPRESS} & \\rotatebox{90} {\\bf NON-SUPPRESS}& \\rotatebox{90} {\\bf NON-STIMULATE} & \\rotatebox{90} {\\bf MAINTAIN} & \\rotatebox{90} {\\bf UNCONTROLLED} \\\\\n\\hline\n'
    


    for sf in steering_features:
        stimulated_dim_avg_success = []
        stimulhalf_dim_avg_success = []
        suppressed_dim_avg_success = []
        supprehalf_dim_avg_success = []
        maintained_dim_avg_success = []


        uncontroll_dims = 0
        stimulated_dims = 0
        suppressed_dims = 0
        stimulhalf_dims = 0
        supprehalf_dims = 0
        maintained_dims = 0

        latex_code += '\\small ' + str(sf) + ' & '
        for vd in value_dims:
            value = table1.loc[vd, sf]
            if type(value) == str:
                value = value.split(',')
                if value[0] == 'STIMULATE':
                    stimulated_dim_avg_success.append(float(value[1]))
                    stimulated_dims += 1
                    latex_code += '\\colorbox{red!50}' + '{' + f"{float(value[1]):.2f}" + '} & '
                elif value[0] == 'NON_SUPPRESS':
                    stimulhalf_dim_avg_success.append(float(value[1]))
                    stimulhalf_dims += 1
                    latex_code += '\\colorbox{red!20}' + '{' + f"{float(value[1]):.2f}" + '} & '
                elif value[0] == 'SUPPRESS':
                    suppressed_dim_avg_success.append(float(value[1]))
                    suppressed_dims += 1
                    latex_code += '\\colorbox{blue!50}' + '{' + f"{float(value[1]):.2f}" + '} & '
                elif value[0] == 'NON_STIMULATE':
                    supprehalf_dim_avg_success.append(float(value[1]))
                    supprehalf_dims += 1
                    latex_code += '\\colorbox{blue!20}' + '{' + f"{float(value[1]):.2f}" + '} & '
                elif value[0] == 'MAINTAIN':
                    maintained_dim_avg_success.append(float(value[1]))
                    maintained_dims += 1
                    latex_code += '\\colorbox{gray!20}' + '{' + f"{float(value[1]):.2f}" + '} & '
                else:
                    raise ValueError('Invalid value')
            else:
                assert np.isnan(value)
                uncontroll_dims += 1
                latex_code += f"-" + ' & '
        latex_code = latex_code[:-2] + ' & ' + str(stimulated_dims) + f'({round(np.mean(stimulated_dim_avg_success),3)})' + ' & ' + str(suppressed_dims) + f'({round(np.mean(suppressed_dim_avg_success),3)})' + ' & ' + str(stimulhalf_dims) + f'({round(np.mean(stimulhalf_dim_avg_success),3)})'  + ' & ' + str(supprehalf_dims) + f'({round(np.mean(supprehalf_dim_avg_success),3)})' + ' & ' + str(maintained_dims) + f'({round(np.mean(maintained_dim_avg_success),3)})' + ' & ' + str(uncontroll_dims) + ' \\\\\n'

        latex_code = latex_code[:-2] + ' \\\\\n'
    latex_code = latex_code + ' \\midrule\n'
    
    latex_code = latex_code + ' \\\\\n\\bottomrule\n'



    latex_code += '\\end{tabular}\n}\n\\end{center}\n\\end{table*}'
    print(latex_code)
    #write the latex code to a file
    with open(table1_name+'.tex', 'w') as f:
        f.write(latex_code)
    
get_latex_table1_rotate_deprecated(table1_gemma, 'table1_gemma_rotate')


NON_STIMULATE,0.6
NON_STIMULATE,0.8
NON_STIMULATE,0.92
NON_STIMULATE,0.92
NON_STIMULATE,0.92
NON_SUPPRESS,0.88
NON_SUPPRESS,0.96
NON_STIMULATE,0.96
NON_SUPPRESS,0.84
NON_STIMULATE,0.92
NON_SUPPRESS,0.84
NON_SUPPRESS,0.92
NON_STIMULATE,0.8
NON_SUPPRESS,0.84
NON_STIMULATE,0.8
NON_SUPPRESS,0.76
NON_SUPPRESS,0.64
NON_SUPPRESS,0.96
NON_STIMULATE,0.8
NON_STIMULATE,0.88
NON_SUPPRESS,0.76
NON_SUPPRESS,0.96
NON_STIMULATE,0.64
NON_STIMULATE,0.96
NON_STIMULATE,0.96
NON_STIMULATE,0.92
NON_STIMULATE,0.52
NON_STIMULATE,0.84
NON_STIMULATE,0.92
NON_STIMULATE,0.48
NON_STIMULATE,0.52
NON_STIMULATE,0.92
NON_STIMULATE,0.64
NON_STIMULATE,0.84
NON_SUPPRESS,0.96
NON_SUPPRESS,1.0
NON_STIMULATE,0.52
NON_SUPPRESS,0.64
NON_STIMULATE,1.0
NON_SUPPRESS,0.68
STIMULATE,0.8
STIMULATE,0.88
NON_SUPPRESS,0.4
NON_STIMULATE,0.4
NON_STIMULATE,0.92
NON_STIMULATE,0.96
SUPPRESS,0.04
NON_STIMULATE,0.96
NON_STIMULATE,0.92
NON_STIMULATE,1.0
NON_STIMULATE,1.0
NON_STIMULATE,0.52
NON_STIMULATE,0.6
NON_STIMULATE,0.76
NON_STIMULATE,0.

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [6]:
threshold_ss = 0.65
threshold_maintain = 0.85
threshold_non = 0.2
threshold_judge = 0


def get_table1_new(data_csv_train, data_csv_test, stat_csv_23):
    data_new_diff_count_total = pd.read_csv(stat_csv_23)

    table1_columns = data_new_diff_count_total['steer_dim'].unique()
    table1_columns = table1_columns[~np.isnan(table1_columns)]
    value_dims = data_new_diff_count_total.columns[1:]
    table1 = pd.DataFrame(columns=table1_columns, index=value_dims)


    players_list_train = data_csv_train['player_name'].unique()
    #players_list_train = players_list_local[~pd.isnull(players_list_local)]

    players_list_test = data_csv_test['player_name'].unique()
    players_list_test = players_list_test[~pd.isnull(players_list_test)]

    standard_data = data_csv_test[data_csv_test['steer_dim'].isnull()]

    for steer_dim in table1_columns:
        assert not np.isnan(steer_dim)
        print(steer_dim)

        steer_dim_row = data_new_diff_count_total[data_new_diff_count_total['steer_dim'] == steer_dim]
        steer_dim_data = data_csv_test[data_csv_test['steer_dim'] == steer_dim]

        for column in value_dims:
            assert column != 'steer_dim'
            #split cell by /
            counts = steer_dim_row[column].values[0].split('/')   

            simu = int(counts[0])
            supp = int(counts[1])
            main = int(counts[2])
            
            simu_test = 0
            supp_test = 0
            main_test = 0
            
            for player_name in players_list_test:
                steered_player_data = steer_dim_data[steer_dim_data['player_name'] == player_name][column].values[0]
                standard_player_data = standard_data[standard_data['player_name'] == player_name][column].values[0]
                if steered_player_data - standard_player_data > threshold_judge:
                    simu_test += 1
                elif -(steered_player_data - standard_player_data) > threshold_judge:
                    supp_test += 1
                elif abs(steered_player_data - standard_player_data) <= threshold_judge:
                    main_test += 1
                else:
                    raise ValueError('Invalid answer')
            table1.loc[column, steer_dim] = str(simu) + '/' + str(supp) + '/' + str(main) + '/' + str(simu_test) + '/' + str(supp_test) + '/' + str(main_test)
    return table1

table1_gemma = get_table1_new(data_csv_gemma_train, data_csv_gemma_test, 'value_dims_rsd_gemma/23_stat.csv')

428.0
1025.0
1312.0
1341.0
1975.0
2221.0
2965.0
3183.0
3402.0
4752.0
6188.0
6216.0
6619.0
6884.0
7502.0
8387.0
10096.0
10454.0
10605.0
11712.0
12703.0
14049.0
14185.0
14351.0


In [7]:
def get_latex_table1_rotate_new(table1, table1_name):
    latex_code = '\\begin{table*}[ht]\n\\caption{Value steering using SAE features for the Gemma-2B-IT model. Expected stimulated values are highlighted in red, along with their actual success rate during testing. Expected suppressed values are marked in Purple. Maintained values are shown in gray. Light red indicates values that are expected to be at least not suppressed, while light purple represents values that are expected to be at least not stimulated. Blank cells correspond to uncontrollable values. The bottom of the table indicates the count of each of the six expected categories and their average success rates.}\n\\label{table: sae-steering-gemma}\n\\begin{center}\n\\scalebox{0.5}{'

    latex_code += '\\begin{tabular}{' + 'c@{\\hspace{1.5pt}}|' * (len(table1.index) + 1) + 'c' + '}\n\\toprule\n'
    #latex_code += '\\begin{tabular}{' + 'c|' * len(table1.index) + 'c' + '}\n\\toprule\n'
    #transfer table1.columns to a list of str

    value_dims = list(map(str, table1.index))
    steering_features = table1.columns
    latex_code += '\\rotatebox{90}{Value} & ' + ' & '.join(['\\rotatebox{90}{\\bf ' + tc +'}' for tc in value_dims]) + '&\\rotatebox{90}{\\bf AVG}' + ' \\\\\n\\hline\n'
    # &\\rotatebox{90} {\\bf STIMULATE} & \\rotatebox{90} {\\bf SUPPRESS} & \\rotatebox{90} {\\bf NON-SUPPRESS}& \\rotatebox{90} {\\bf NON-STIMULATE} & \\rotatebox{90} {\\bf MAINTAIN} & \\rotatebox{90} {\\bf UNCONTROLLED}


    for sf in steering_features:
        cosines = []

        latex_code += '\\small ' + str(int(sf)) + ' & '
        for vd in value_dims:
            value = table1.loc[vd, sf]
            if type(value) == str:
                #split value by / and trans each part to int
                value = value.split('/')
                value = list(map(int, value))
                #value = [red, blue, transparency, red_test, blue_test, transparency_test]
                traindata = value[:3]
                testdata = value[3:]
                #compute the cosine similarity between the two vectors (first normalize them)
                traindata_p = np.array(traindata) / np.sum(traindata)
                testdata_p = np.array(testdata) / np.sum(testdata)

                if vd.startswith('Anxiety') and sf == 1312:
                    pass
                #cosine similarity
                similarity = np.dot(traindata_p, testdata_p) / (np.linalg.norm(traindata_p) * np.linalg.norm(testdata_p))
                cosines.append(similarity)


                coeff = 1
                transratio = value[2] / (value[0] + value[1] + value[2])
                redratio = (value[0] / (value[0] + value[1])) 
                blueratio = (value[1] / (value[0] + value[1]))  
                redratio_t = redratio * (1 - transratio) 
                blueratio_t = blueratio * (1 - transratio)

                rr = int(255 * (1 - blueratio_t) * redratio_t + 255 * (1 - redratio_t) * (1 - blueratio_t)) / 255.0
                gg = int(255 * (1 - redratio_t) * blueratio_t + 255 * (1 - blueratio_t) * (1 - redratio_t)) / 255.0
                bb = int(255 * (1 - redratio_t) * (1 - blueratio_t))  / 255.0

                #latex_code += '\\colorbox{red!' + str(redratio_t) + '!green!' + str(blueratio_t) + '}' + '{' + f"{float(similarity):.2f}" + '} & '
                #latex_code += '\\colorbox[rgb]{' + str(redratio_t) + ',' + str(blueratio_t) + ',0' + '}' + '{' + f"{float(similarity):.2f}" + '} & '
                latex_code += '\\colorbox[rgb]{' + str(round(rr, 2)) + ',' + str(round(gg, 2)) + ',' + str(round(bb, 2)) + '}' + '{' + f"{float(similarity):.2f}" + '} & '
                #latex_code += '\\colorbox[rgb]{' + str(round(bb, 2)) + ',' + str(round(rr, 2)) + ',' + str(round(gg, 2)) + '}' + '{' + f"{float(similarity):.2f}" + '} & '
            

                # if random.random() < 0.5:
                #     latex_code += '\\colorbox[rgb]{1,0,0}' + '{' + f"{float(similarity):.2f}" + '} & '
                # else:
                #     latex_code += '\\colorbox[rgb]{0.14,0.78,0}' + '{' + f"{float(similarity):.2f}" + '} & '

        # latex_code = latex_code[:-2] + ' & ' + str(stimulated_dims) + f'({round(np.mean(stimulated_dim_avg_success),3)})' + ' & ' + str(suppressed_dims) + f'({round(np.mean(suppressed_dim_avg_success),3)})' + ' & ' + str(stimulhalf_dims) + f'({round(np.mean(stimulhalf_dim_avg_success),3)})'  + ' & ' + str(supprehalf_dims) + f'({round(np.mean(supprehalf_dim_avg_success),3)})' + ' & ' + str(maintained_dims) + f'({round(np.mean(maintained_dim_avg_success),3)})' + ' & ' + str(uncontroll_dims) + ' \\\\\n'

        latex_code = latex_code[:-2] + ' & ' + f"{np.mean(cosines):.2f}" + ' \\\\\n'
    #latex_code = latex_code + ' \\midrule\n'
    latex_code = latex_code + ' \\\\\n\\bottomrule\n'



    latex_code += '\\end{tabular}\n}\n\\end{center}\n\\end{table*}'
    print(latex_code)
    #write the latex code to a file
    with open(table1_name+'.tex', 'w') as f:
        f.write(latex_code)
    
get_latex_table1_rotate_new(table1_gemma, 'table1_gemma_rotate')


\begin{table*}[ht]
\caption{Value steering using SAE features for the Gemma-2B-IT model. Expected stimulated values are highlighted in red, along with their actual success rate during testing. Expected suppressed values are marked in Purple. Maintained values are shown in gray. Light red indicates values that are expected to be at least not suppressed, while light purple represents values that are expected to be at least not stimulated. Blank cells correspond to uncontrollable values. The bottom of the table indicates the count of each of the six expected categories and their average success rates.}
\label{table: sae-steering-gemma}
\begin{center}
\scalebox{0.5}{\begin{tabular}{c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspace{1.5pt}}|c@{\hspac

In [None]:
def get_valid_d_columns_deprecated(answer_valuebench_features_csv):
    data_csv = pd.read_csv(answer_valuebench_features_csv)
    digits = [str(d) for d in range(10)]
    d_columns = [d for d in data_csv.columns if d[0] in digits]
    d_data = data_csv[d_columns]
    stds = d_data.std()
    avgs = d_data.mean()
    std_avg = stds/avgs
    #d_columns_valid = [d for d in d_columns if avgs[d] > 1]
    d_columns_valid = d_columns
    return d_columns_valid

def deal_with_csv(data_csv, pdy_name, v_inference, v_showongraph, row_num, method='pc', dummy_steered_dim=False): 
    # data_csv = pd.read_csv(answer_valuebench_features_csv)
    # v_columns_all = [v for v in data_csv.columns if (v not in ['player_name', 'steer_dim', 'stds']) and (not v.endswith(':scstd'))]
    # if v_inference == 'ALL':
    #     v_columns_inference = v_columns_all
    # else:
    #     for v in v_inference:
    #         if v not in v_columns_all:
    #             raise ValueError('Invalid v_inference')
    #     v_columns_inference = v_inference

    v_columns_inference = v_inference

    if v_showongraph == 'ALL':
        v_columns_showgraph = v_columns_inference
    else:
        for v in v_showongraph:
            if v not in v_columns_inference:
                raise ValueError('Invalid v_showongraph')
        v_columns_showgraph = v_showongraph

    if dummy_steered_dim:
        steer_dim_dummies = pd.get_dummies(data_csv['steer_dim'], prefix='steer_dim') * 1
        data = pd.concat([data_csv, steer_dim_dummies], axis=1)
        v_columns_inference_total = v_columns_inference + list(steer_dim_dummies.columns) 
        v_columns_showgraph_total = v_columns_showgraph + list(steer_dim_dummies.columns)
    else:
        data = data_csv
        v_columns_inference_total = v_columns_inference
        v_columns_showgraph_total = v_columns_showgraph
    
    data = data[v_columns_inference_total].to_numpy()    
    
    if type(row_num) == int:
        rows = np.random.choice(data.shape[0], row_num, replace=False)
        data = data[rows]
    else:
        assert row_num == 'ALL'

    if dummy_steered_dim:
        edges_total = causal_inference(data, v_columns_inference_total, pdy_name, method, noise_augument=None, prior_source_set=list(steer_dim_dummies.columns))
    else:
        edges_total = causal_inference(data, v_columns_inference_total, pdy_name, method, noise_augument=10)
    
    edges_sfs = []
    steer_dims = data_csv['steer_dim'].unique()
    for steer_dim in steer_dims:
        print(steer_dim)
        if np.isnan(steer_dim):
            data = data_csv[data_csv['steer_dim'].isnull()][v_columns_inference].to_numpy()
        else:
            data = data_csv[data_csv['steer_dim'] == steer_dim][v_columns_inference].to_numpy()
        sfedge = causal_inference(data, v_columns_inference, pdy_name.replace('.png', f'_{steer_dim}.png'), method, noise_augument=10)
        edges_sfs.append(sfedge)

    return edges_total, edges_sfs

def causal_inference(data, ci_dimensions, pdy_name, method, noise_augument=None, prior_source_set=None):
    print(data.shape)
    
    #0 is the mean of the normal distribution you are choosing from, and 0.01 is the standard deviation of this distribution.
    #scale the data for several times by adding noise
    if noise_augument:
        data = np.tile(data, (noise_augument, 1))
        noise = np.random.normal(0, 0.00001, data.shape)
        data = data + noise

    if method == 'pc':
        #g = pc(data, 0.0005, uc_rule=0, rule_priority=2, node_names=ci_dimensions)
        g = pc(data, 0.0005, node_names=ci_dimensions)
        
        if prior_source_set:
            bk = BackgroundKnowledge()
            nodes = g.G.get_nodes()
            for node1 in nodes:
                for node2 in nodes:
                    if node1.name in prior_source_set and node2.name in prior_source_set and node1.name != node2.name:
                        bk = bk.add_forbidden_by_node(node1, node2)
            #g = pc(data, 0.0005, uc_rule=0, rule_priority=2, node_names=ci_dimensions, background_knowledge=bk)
            g = pc(data, 0.0005, node_names=ci_dimensions, background_knowledge=bk)
            
        graph = g.G

        edges = []
        for n1 in range(len(graph.nodes)):
            assert graph.nodes[n1].name == ci_dimensions[n1]
            for n2 in range(n1+1, len(graph.nodes)):
                # if n1 == n2:
                #     continue
                if graph.graph[n1][n2] == -1 and graph.graph[n2][n1] == 1:
                    edges.append([graph.nodes[n1].name, graph.nodes[n2].name, 1, 'single-arrow'])
                elif graph.graph[n1][n2] == 1 and graph.graph[n2][n1] == -1:
                    edges.append([graph.nodes[n2].name, graph.nodes[n1].name, 1, 'single-arrow']) 
                elif graph.graph[n1][n2] == -1 and graph.graph[n2][n1] == -1:
                    edges.append([graph.nodes[n1].name, graph.nodes[n2].name, 1, 'no-arrow'])
                elif graph.graph[n1][n2] == 1 and graph.graph[n2][n1] == 1:
                    edges.append([graph.nodes[n1].name, graph.nodes[n2].name, 1, 'double-arrow'])
                else:
                    if not (graph.graph[n1][n2] == 0 and graph.graph[n2][n1] == 0):
                        raise ValueError('Invalid edge')
    else:
        raise ValueError('Invalid method')
    
    columns_concerned_vis = [label.replace(':','-') for label in ci_dimensions]
    pdy = GraphUtils.to_pydot(graph, labels=columns_concerned_vis)
    pdy.write_png(pdy_name)

    return edges


#data_csv = data_csv[data_csv['player_name'].notnull()]

v_inference_gemma = [v for v in data_csv_gemma_train.columns if (v not in ['player_name', 'steer_dim', 'stds', 'scstds']) and (not v.endswith(':scstd'))]
# v_inference_llama = [v for v in data_csv_llama_train.columns if (v not in ['player_name', 'steer_dim', 'stds', 'scstds']) and (not v.endswith(':scstd'))]
# assert v_inference_gemma == v_inference_llama
v_inference = v_inference_gemma

#v_inference = ['Affiliation', 'Assertiveness', 'Behavioral Inhibition System', 'Breadth of Interest', 'Complexity', 'Dependence', 'Depth', 'Emotional Expression', 'Emotional Processing', 'Empathy', 'Extraversion', 'Imagination', 'Nurturance', 'Perspective Taking', 'Social Withdrawal', 'Positive Expressivity', 'Preference for Order and Structure', 'Privacy', 'Psychosocial flourishing', 'Reflection']
#v_inference = ['Affiliation', 'Assertiveness', 'Behavioral Inhibition System', 'Breadth of Interest', 'Complexity', 'Dependence', 'Depth', 'Emotional Expression', 'Emotional Processing', 'Empathy', 'Extraversion', ]

if os.path.exists('value_causal_graph_gemma'):
    shutil.rmtree('value_causal_graph_gemma')
os.makedirs('value_causal_graph_gemma', exist_ok=True)
edges_gemma_total, edges_gemma_sfs = deal_with_csv(data_csv_gemma_train, "value_causal_graph_gemma/total.png", v_inference, 'ALL', 'ALL', 'pc', False)

# if os.path.exists('value_causal_graph_llama'):
#     shutil.rmtree('value_causal_graph_llama')
# os.makedirs('value_causal_graph_llama', exist_ok=True)
# edges_llama_total, edges_llama_sfs = deal_with_csv(data_csv_llama_train, "value_causal_graph_llama/total.png", v_inference, 'ALL', 'ALL', 'pc', False)

edges_standard_json = json.load(open('value_graph_smallset_triplets.json'))
edges_standard = []
for edge in edges_standard_json:
    if edge[1] == '-->':
        edges_standard.append([edge[0], edge[2], 1, 'single-arrow'])
    elif edge[1] == 'o--o':
        edges_standard.append([edge[0], edge[2], 1, 'double-arrow'])
    else:
        raise ValueError('Invalid edge')


In [None]:
def check_zero_double_arrow(edges):
    double_arrow_edges = [edge for edge in edges if edge[3] == 'double-arrow']
    zero_arrow_edges = [edge for edge in edges if edge[3] == 'no-arrow']
    if double_arrow_edges:
        raise ValueError('Double arrow:', double_arrow_edges)
    if zero_arrow_edges:
        raise ValueError('Zero arrow:', zero_arrow_edges)

def dealwith_zero_double_duplicated_arrow(edges):
    double_arrow_edges = [edge for edge in edges if edge[3] == 'double-arrow']
    zero_arrow_edges = [edge for edge in edges if edge[3] == 'no-arrow']
    print('Double arrow:', double_arrow_edges)
    print('Zero arrow:', zero_arrow_edges)
    print('Dealwith zero and double arrow edges')
    print('----------------------')
    
    new_edges = []
    for edge in edges:
        if edge[3] == 'double-arrow' or edge[3] == 'no-arrow':
            if [edge[0], edge[1], edge[2], 'single-arrow'] not in new_edges:
                new_edges.append([edge[0], edge[1], edge[2], 'single-arrow'])
            if [edge[1], edge[0], edge[2], 'single-arrow'] not in new_edges:
                new_edges.append([edge[1], edge[0], edge[2], 'single-arrow'])
        else:
            if edge not in new_edges:
                new_edges.append(edge)
    return new_edges



def check_dag(edges):
    nxg = nx.DiGraph()
    for edge in edges:
        if edge[3] == 'single-arrow':
            nxg.add_edge(edge[0], edge[1])
    if not nx.is_directed_acyclic_graph(nxg):
        cycles = list(nx.simple_cycles(nxg))
        raise ValueError('Cycle:', cycles)

def get_all_subsequent_nodes(edges, node):
    #check_zero_double_arrow(edges)

    subsequent_nodes = set()
    subsequent_nodes.add(node)
    while True:
        subsequent_nodes_len = len(subsequent_nodes)
        for edge in edges:
            if edge[0] in subsequent_nodes:
                subsequent_nodes.add(edge[1])
        if len(subsequent_nodes) == subsequent_nodes_len:
            break
    subsequent_nodes.remove(node)
    return subsequent_nodes

def write_table2(edges, data_scorechange, mean_scorechange_related, num_related, mean_scorechange_unrelated, num_unrelated):
    for column in data_scorechange.columns:
        print(column)
        #related_columns_real1 = data_scorechange[data_scorechange[column] > 0].mean().abs().sort_values()
        #related_columns_real2 = data_scorechange[data_scorechange[column] < -0].mean().abs().sort_values()
        related_columns_real = data_scorechange[data_scorechange[column] != 0].abs().mean().sort_values()
        related_columns_ideal = get_all_subsequent_nodes(edges, column)

        related_scabs = []
        unrelated_scabs = []
        for related_column in related_columns_real.index:
            if related_column in related_columns_ideal:
                related_scabs.append(related_columns_real[related_column])
            elif related_column != column:
                unrelated_scabs.append(related_columns_real[related_column])
            else:
                assert related_column == column
        #     print(related_column, related_columns_real[related_column], related_column in related_columns_ideal)
        # print('~~~')
        
        print('Related:', np.mean([vdsc for vdsc in related_scabs if not np.isnan(vdsc)]), len(related_scabs))
        print('Unrelated:', np.mean([vdsc for vdsc in unrelated_scabs if not np.isnan(vdsc)]), len(unrelated_scabs))
        pd_result_table2.loc[mean_scorechange_related, column] = np.mean([vdsc for vdsc in related_scabs if not np.isnan(vdsc)])
        pd_result_table2.loc[num_related, column] = len(related_scabs)
        pd_result_table2.loc[mean_scorechange_unrelated, column] = np.mean([vdsc for vdsc in unrelated_scabs if not np.isnan(vdsc)])
        pd_result_table2.loc[num_unrelated, column] = len(unrelated_scabs)
        print('----------------------')




pd_result_table2 = pd.DataFrame(columns=v_inference)

# edges_standard = dealwith_zero_double_arrow(edges_standard)
edges_standard = [
    ['Emotional Processing', 'Emotional Expression', 1, 'single-arrow'],
    ['Emotional Processing', 'Psychosocial Flourishing', 1, 'single-arrow'],
    ['Perspective Taking', 'Sympathy', 1, 'single-arrow'],
    ['Perspective Taking', 'Empathy', 1, 'double-arrow'],
    ['Perspective Taking', 'Nurturance', 1, 'double-arrow'],
    ['Sociability', 'Extraversion', 1, 'double-arrow'],
    ['Sociability', 'Warmth', 1, 'double-arrow'],
    ['Sociability', 'Positive Expressivity', 1, 'double-arrow'],
    ['Dependence', 'Nurturance', 1, 'single-arrow'],
    ['Psychosocial Flourishing', 'Satisfaction with life', 1, 'single-arrow'],
    ['Psychosocial Flourishing', 'Nurturance', 1, 'single-arrow'],
    ['Extraversion', 'Positive Expressivity', 1, 'single-arrow'],
    ['Extraversion', 'Social Confidence', 1, 'single-arrow'],
    ['Extraversion', 'Social', 1, 'double-arrow'],
    ['Affiliation', 'Empathy', 1, 'double-arrow'],
    ['Affiliation', 'Social', 1, 'double-arrow'],
    ['Understanding', 'Empathy', 1, 'double-arrow'],
    ['Understanding', 'Reflection', 1, 'double-arrow'],
    ['Understanding', 'Depth', 1, 'single-arrow'],
    ['Understanding', 'Theoretical', 1, 'double-arrow'],
    ['Sympathy', 'Nurturance', 1, 'single-arrow'],
    ['Warmth', 'Empathy', 1, 'single-arrow'], 
    ['Warmth', 'Nurturance', 1, 'double-arrow'],
    ['Warmth', 'Positive Expressivity', 1, 'single-arrow'],
    ['Warmth', 'Social', 1, 'single-arrow'], 
    ['Empathy', 'Tenderness', 1, 'double-arrow'],
    ['Empathy', 'Nurturance', 1, 'double-arrow'], 
    ['Positive Expressivity', 'Social', 1, 'double-arrow'],
]

data_gemma_nosteer = data_csv_gemma_test[data_csv_gemma_test['steer_dim'].isnull()][data_csv_gemma_test['player_name'].notnull()]
data_gemma_nosteer = data_gemma_nosteer[v_inference + ['player_name']]
data_gemma_nosteer = data_gemma_nosteer.set_index('player_name')
data_gemma_nosteer = data_gemma_nosteer.astype(float)
data_gemma_scorechange = data_gemma_nosteer - data_gemma_nosteer.iloc[0]

edges_gemma_sfs0 = edges_gemma_sfs[0]
#edges_gemma_sfs0 = dealwith_zero_double_arrow(edges_gemma_sfs[0])
# for end_node in ['Affiliation', 'Breadth of Interest', 'Dependence']:#  'Behavioral Inhibition System'  'Nurturance'
#     for start_node in ['Poise', 'Social Confidence', 'Preference for Order and Structure']:#[,  , 'Assertiveness']:
#         edges_gemma_sfs0.append([end_node, start_node, 1, 'single-arrow'])

write_table2(edges_gemma_sfs0, data_gemma_scorechange,  'mean_scorechange_related_ours_gemma', 'num_related_ours_gemma', 'mean_scorechange_unrelated_ours_gemma', 'num_unrelated_ours_gemma')
write_table2(edges_standard, data_gemma_scorechange, 'mean_scorechange_related_standard_gemma', 'num_related_standard_gemma', 'mean_scorechange_unrelated_standard_gemma', 'num_unrelated_standard_gemma')


# data_llama_nosteer = data_csv_llama_test[data_csv_llama_test['steer_dim'].isnull()][data_csv_llama_test['player_name'].notnull()]
# data_llama_nosteer = data_llama_nosteer[v_inference + ['player_name']]
# data_llama_nosteer = data_llama_nosteer.set_index('player_name')
# data_llama_nosteer = data_llama_nosteer.astype(float)
# data_llama_scorechange = data_llama_nosteer - data_llama_nosteer.iloc[0]

# #edges_llama_sfs0 = dealwith_zero_double_arrow(edges_llama_sfs[0])
# edges_llama_sfs0 = edges_llama_sfs[0]
# write_table2(edges_llama_sfs0, data_llama_scorechange,  'mean_scorechange_related_ours_llama', 'num_related_ours_llama', 'mean_scorechange_unrelated_ours_llama', 'num_unrelated_ours_llama')
# write_table2(edges_standard, data_llama_scorechange, 'mean_scorechange_related_standard_llama', 'num_related_standard_llama', 'mean_scorechange_unrelated_standard_llama', 'num_unrelated_standard_llama')

pd_result_table2.to_csv('table2.csv')


#

In [None]:
#print the table2 in latex
#rows are for each values dimensions
#columns are in form num_related_ours(mean_scorechange_related_ours), num_unrelated_ours(mean_scorechange_unrelated_ours), num_related_standard(mean_scorechange_related_standard), num_unrelated_standard(mean_scorechange_unrelated_standard)
#the values are the number of related values, the mean of the score change of related values, the number of unrelated values, the mean of the score change of unrelated values
#the values are rounded to 3 decimal places
#the values are in the form number(mean)
#the values are in the form of number(mean)
pd_result_table2 = pd.read_csv('table2.csv', index_col=0)
latex_code = '\\begin{table}[ht]\n\\caption{The mean of the score change of related values, the number of related values, the mean of the score change of unrelated values, and the number of unrelated values.}\n\\label{table: scorechange}\n\\begin{center}\n'
#latex_code += '\\begin{tabular}{c@{\\hspace{2pt}}' + 'c@{\\hspace{2pt}}' * (len(pd_result_table2.columns) - 1) + 'c' + '}\n\\toprule\n'
latex_code += '\\begin{tabular}{c@{\\hspace{2pt}}|' + 'c@{\\hspace{2pt}}' * 4 +'|' + 'c@{\\hspace{2pt}}' * 4 + '}\n\\toprule\n'
latex_code += 'Value & \\multicolumn{4}{c|}{\\bf \\small Gemma-2B-IT} & \\multicolumn{4}{c}{\\bf \\small Llama3-8B-IT}\\\\\n\\hline\n'
latex_code += 'Dimensions & \\multicolumn{2}{c|}{\\bf \\tiny Our causal graph} & \\multicolumn{2}{c|}{\\bf \\tiny Our causal graph} & \\multicolumn{2}{c|}{\\bf \\tiny Our causal graph} & \\multicolumn{2}{c}{\\bf \\tiny Our causal graph}  \\\\\n\\hline\n'
latex_code += 'Score change & \\multicolumn{1}{c}{\\bf \\tiny Expected} & \\multicolumn{1}{c|}{\\bf \\tiny Unexpected} & \\multicolumn{1}{c}{\\bf \\tiny Expected} & \\multicolumn{1}{c|}{\\bf \\tiny Unexpected} & \\multicolumn{1}{c}{\\bf \\tiny Expected} & \\multicolumn{1}{c|}{\\bf \\tiny Unexpected} & \\multicolumn{1}{c}{\\bf \\tiny Expected} & \\multicolumn{1}{c}{\\bf \\tiny Unexpected}\\\\\n\\hline\n'
#each row in latex is a column in the dataframe
for column in pd_result_table2.columns:
    latex_code += '\\small ' + column + ' & '
    for index in pd_result_table2.index:
        if index.startswith('mean'):
            latex_code += str(round(pd_result_table2.loc[index, column], 2)) + ' & '

    latex_code = latex_code[:-2] + ' \\\\\n'
latex_code += '\\bottomrule\n\\end{tabular}\n\\end{center}\n\\end{table}'
print(latex_code)
#write the latex code to a file
with open('table2.tex', 'w') as f:
    f.write(latex_code)

In [30]:

#steer_dims = ['nan', 1312, 1341, 2221, 3183, 6619, 7502, 8387, 10096, 14049]

nodes = {}
for entity in v_inference:
    nodes[entity] = os.path.join('valuebench','value_questions_' + entity + '.html'),
# for feature in data_csv.['steer_dim'].unique()[1:]:
#     nodes[feature] = 'https://www.neuronpedia.org/' + sae.cfg.model_name +'/' + str(sae.cfg.hook_layer) + '-res-jb/' + str(feature)

edges = {
    'gemma': edges_gemma_sfs0,
    #'llama': edges_llama_sfs0,
    'standard': edges_standard
}

json_object = {
    'nodes': nodes,
    'edges': edges
    }

json.dump(json_object, open('data1.json', 'w'))

In [1]:
import ipywidgets as widgets