<a href="https://colab.research.google.com/github/kvamsi7/vads-prevalent-safety-llm/blob/main/notebooks/Phase_2_Latent_Filter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import json
from collections import Counter
import pandas as pd

In [8]:
def load_dataset(file_path):
    """Read a JSON file and return the parsed object.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the contents are not valid JSON.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # encoding, which would otherwise make the result machine-dependent.
    with open(file_path, "r", encoding="utf-8") as handle:
        return json.load(handle)

In [104]:
def extract_latent_features(data):
    """Collect per-prompt activation scores from the latent sections.

    For every prompt in ``data`` the sections ``"unsafe_latent_info"`` and
    ``"safe_latent_data"`` are inspected, in that order. Each section that is
    present and contains a ``"latents"`` mapping contributes one dict of
    ``{feature: act_score}`` (a missing ``act_score`` becomes ``0``).

    Args:
        data: Mapping of prompt ID to that prompt's record.

    Returns:
        Dict mapping each prompt ID to a list holding zero, one or two
        feature dicts, unsafe before safe when both exist.
    """
    # Order matters: callers rely on the unsafe section being appended first.
    section_names = ("unsafe_latent_info", "safe_latent_data")

    per_prompt = {}
    for prompt_id, record in data.items():
        collected = []
        for section in section_names:
            if section not in record or "latents" not in record[section]:
                continue
            latents = record[section]["latents"]
            collected.append(
                {feature: info.get('act_score', 0) for feature, info in latents.items()}
            )
        per_prompt[prompt_id] = collected
    return per_prompt


In [105]:
# Input JSON with the per-prompt latent information.
# NOTE(review): absolute path under /content — presumably a Colab session
# upload; confirm the file location before running in another environment.
file_path = "/content/dataset_latent_autointep_dataset_v2_info.json"
data = load_dataset(file_path)

In [120]:
# Map each prompt ID to its list of {feature: act_score} dicts.
prompt_features = extract_latent_features(data)

In [127]:
# Each value in prompt_features holds the unsafe feature dict first and the
# safe feature dict second, but either one may be absent. Positions 0 and 1
# are only unambiguous when both are present, so keep just the prompts with
# exactly two entries: a single-entry list would raise IndexError on [1], and
# its [0] could be either kind.
unsafe_prompts_features = [
    {'ID': prompt_id, 'Features': features_list[0]}
    for prompt_id, features_list in prompt_features.items()
    if len(features_list) == 2
]
safe_prompts_features = [
    {'ID': prompt_id, 'Features': features_list[1]}
    for prompt_id, features_list in prompt_features.items()
    if len(features_list) == 2
]

# One row per prompt: the prompt ID and its {feature: act_score} dict.
df_harmful_features = pd.DataFrame(unsafe_prompts_features, columns=["ID", "Features"])
df_harmless_features = pd.DataFrame(safe_prompts_features, columns=["ID", "Features"])

         ID                                           Features
0  prompt_1  {'6631': 2528.1, '11527': 189.59, '994': 65.3,...
1  prompt_2  {'6631': 2565.0, '11527': 225.16, '8684': 260....
2  prompt_3  {'6631': 2486.13, '7933': 102.38, '4817': 188....
3  prompt_4  {'6631': 2545.57, '7576': 103.2, '11092': 79.3...
4  prompt_5  {'6631': 2690.91, '11795': 178.62, '1392': 107...


In [146]:
def get_common_features(df1, df2):
  """List, per prompt ID, the feature keys present in both frames.

  The frames are joined on ``ID`` (default ``merge`` join, so only IDs found
  in both are kept). For each joined row the keys of the two ``Features``
  dicts are intersected; a value that is not a dict contributes no keys.

  Args:
    df1: DataFrame with columns ``ID`` and ``Features``.
    df2: DataFrame with columns ``ID`` and ``Features``.

  Returns:
    DataFrame with columns ``ID`` and ``common_keys``, where ``common_keys``
    is a list of the shared keys in the order they appear in ``df1``'s dict.
  """
  merged_df = df1.merge(df2, on='ID', suffixes=('_df1', '_df2'))

  def get_common_keys(features_df1, features_df2):
      # Non-dict cells (e.g. missing values) share nothing with the other side.
      if not (isinstance(features_df1, dict) and isinstance(features_df2, dict)):
          return []
      # Walk df1's dict instead of intersecting sets so the result order is
      # deterministic across runs (set iteration order of strings is not).
      return [key for key in features_df1 if key in features_df2]

  common_keys = [
      get_common_keys(left, right)
      for left, right in zip(merged_df['Features_df1'], merged_df['Features_df2'])
  ]
  # Build the column explicitly rather than via DataFrame.apply(axis=1): on an
  # empty merge, apply returns a DataFrame and the column assignment fails.
  merged_df['common_keys'] = pd.Series(common_keys, index=merged_df.index, dtype=object)
  return merged_df[['ID', 'common_keys']]


In [148]:
# For each prompt ID in both frames, the feature keys shared by its harmful and harmless dicts.
common_features_df = get_common_features(df_harmful_features,df_harmless_features)

In [151]:
# Save the result next to the notebook's working directory; index=False omits the row index.
common_features_df.to_csv('common_features.csv', index=False)

In [150]:
# Number of rows in the result (one per merged prompt ID).
len(common_features_df)

85