In [1]:
import asyncio
from openai import AsyncOpenAI
from sklearn.metrics import f1_score
from scipy.stats import pearsonr
from scipy import sparse
from scipy.sparse import csr_array
import numpy as np
import pandas as pd
import json
import re
from tqdm.asyncio import tqdm as async_tqdm
from datasets import load_dataset

In [2]:
MODEL_DIR = "models/4096_8_-1/"
INTERP_GPT = "gpt-4o-2024-08-06"
INTERP_PATH = f"{MODEL_DIR}/interpreter_responses_{INTERP_GPT}_v2.json"
DATA_PATH = "jam963/indigeneity_fr"
NUM_POS = 3
NUM_NEG = 3
BATCH_SIZE = 64
RNG = np.random.default_rng(seed=910)

In [3]:
df = pd.DataFrame(load_dataset(DATA_PATH, split="train"))

Downloading readme: 100%|██████████████████████████████████████████████████████████████████████████████████████████████| 1.31k/1.31k [00:00<00:00, 6.32kB/s]
Downloading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████| 4.69G/4.69G [03:23<00:00, 23.0MB/s]
Generating train split: 100%|██████████████████████████████████████████████████████████████████████████████| 210305/210305 [00:28<00:00, 7297.70 examples/s]


In [3]:
all_sparse_embeddings = sparse.load_npz(f"{MODEL_DIR}/sparse_embeddings.npz")

with open(INTERP_PATH, "r") as f:
    raw_interpretations = json.load(f)

with open("predictor_prompt_fr.txt", "r") as f:
    base_prompt = f.read()

In [4]:
def extract_prompt_response(response):
    pattern = r'\b\s*(?:FINAL)\s*:?\s*\*?\s*(.*?)(?:\n|$)'
    
    matches = list(re.finditer(pattern, response))
    
    if matches:
        last_match = matches[-1]
        extracted_text = last_match.group(1).replace("*", "").strip()
        
        return extracted_text
    else:
        return None

In [5]:
interpretations = {int(k): extract_prompt_response(v) 
                   for k, v in raw_interpretations.items()}

In [6]:
interpretations

{0: 'Résistance historique indigène contre envahisseurs étrangers',
 1: 'Relations coloniales, pouvoir, législation, démographie indigène',
 2: 'Élevage de races animales indigènes locales',
 3: 'Plantes et technologies traditionnelles locales.',
 4: 'Contraste culturel géographique, authenticité informations locales',
 5: 'Bois local utilisé en charpenterie',
 6: 'Contexte colonial, interactions économiques et politiques.',
 7: 'Contexte militaire, juridique, pouvoir colonial, indigènes.',
 8: 'Constructions indigènes historiques et archéologiques spécifiques',
 9: 'Interactions indigènes avec structures coloniales européennes.',
 10: 'Distinction ethnique historique en Afrique coloniale',
 11: 'Contexte de domination coloniale des indigènes.',
 12: 'Contexte militaire révolution haïtienne, armée indigène.',
 13: 'Comparaison scientifique indigène vs exotique.',
 14: 'Compétence locale, transformation, influence culturelle.',
 15: 'Contexte militaire colonial Afrique du Nord.',
 16: '

In [7]:
with open(f"{MODEL_DIR}/interpretations_{INTERP_GPT}.json", "w") as f: 
    json.dump(interpretations, f)

In [8]:
def get_samples_for_predictor(all_sparse_embeddings, df, 
                              feature_idx, 
                              num_pos=NUM_POS, 
                              num_neg=NUM_NEG):
    activations = all_sparse_embeddings[:, [feature_idx]].toarray().flatten()
    positive_indices = np.where(activations > 0)[0]
    zero_indices = np.where(activations == 0)[0]
    
    pos_samples = RNG.choice(positive_indices, size=min(num_pos, len(positive_indices)), replace=False)
    neg_samples = RNG.choice(zero_indices, size=min(num_neg, len(zero_indices)), replace=False)
    
    pos_sentences = df.iloc[pos_samples]["sentence"].tolist()
    neg_sentences = df.iloc[neg_samples]["sentence"].tolist()
    
    return pos_sentences, neg_sentences


def format_predictor_prompt(description, text, base_prompt):
    return base_prompt.format(description=description, text=text)


def prepare_prompts(feature_descriptions, all_sparse_embeddings, df, base_prompt):
    all_prompts = []
    all_labels = []
    feature_map = []

    for feature, description in feature_descriptions:
        pos_sentences, neg_sentences = get_samples_for_predictor(all_sparse_embeddings, df, int(feature))
        all_sentences = pos_sentences + neg_sentences
        labels = [1] * len(pos_sentences) + [-1] * len(neg_sentences)
        
        prompts = [format_predictor_prompt(description, sentence, base_prompt) for sentence in all_sentences]
        
        all_prompts.extend(prompts)
        all_labels.extend(labels)
        feature_map.extend([feature] * len(all_sentences))
    
    return all_prompts, all_labels, feature_map


async def run_inference(client, prompts):
    async def single_prediction(prompt):
        try:
            completion = await client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {"role": "system", "content": "Vous êtes un assistant utile."},
                    {"role": "user", "content": prompt}
                ]
            )
            return completion.choices[0].message.content
        except Exception as e:
            print(f"Error in API call: {str(e)}")
            return None

    results = []
    for i in async_tqdm(range(0, len(prompts), BATCH_SIZE)):
        batch = prompts[i:i+BATCH_SIZE]
        batch_results = await asyncio.gather(*[single_prediction(prompt) for prompt in batch])
        results.extend(batch_results)
    
    return results


def extract_prediction(response):
    if response is None:
        return None
    try:
        import re
        pattern = r'PREDICTION\s*:\s*([-]?[0-9]*\.?[0-9]+)'
        match = re.search(pattern, response)
        if match:
            return float(match.group(1))
        else:
            print(f"No prediction found in response: {response}")
            return None
    except Exception as e:
        print(f"Error extracting prediction: {str(e)}")
        return None


def analyze_results(predictions, true_labels, feature_map):
    feature_validations = {}
    for feature in set(feature_map):
        feature_predictions = [p for p, f in zip(predictions, feature_map) if f == feature]
        feature_labels = [l for l, f in zip(true_labels, feature_map) if f == feature]
        
        processed_predictions = [0 if p is None else p for p in feature_predictions]
        
        try:
            correlation, p_value = pearsonr(feature_labels, processed_predictions)

            binary_predictions = [1 if p > 0 else -1 for p in processed_predictions]
            f1 = f1_score(feature_labels, binary_predictions, average='binary')
            
            accuracy = np.mean([1 if p == l else 0 for p, l in zip(binary_predictions, feature_labels)])
            
            feature_validations[feature] = {
                'correlation': correlation,
                'p_value': p_value,
                'f1_score': f1,
                'accuracy': accuracy,
                'total_samples': len(feature_labels),
                'none_responses': feature_predictions.count(None)
            }
        except Exception as e:
            print(f"Error calculating metrics for feature {feature}: {str(e)}")
    
    return feature_validations

In [9]:
client = AsyncOpenAI(
        organization="org-Ksqkwzk8Pm1pgpC4K1lftzQT", 
        project="proj_JNGzr42oLNfUh0XwQjeqpfLm"
    )

feature_descriptions = list(interpretations.items())

print("Preparing prompts...")
all_prompts, all_labels, feature_map = prepare_prompts(feature_descriptions, all_sparse_embeddings, df, base_prompt)

print("Running inference...")
raw_predictions = await run_inference(client, all_prompts)

print("Extracting predictions...")
predictions = [extract_prediction(resp) for resp in raw_predictions]

print("Analyzing results...")
feature_validations = analyze_results(predictions, all_labels, feature_map)

validation_df = pd.DataFrame.from_dict(feature_validations, orient='index')
validation_df['description'] = validation_df.index.map(dict(feature_descriptions))
validation_df.to_csv(f'{MODEL_DIR}predictions_{INTERP_GPT}.csv')
print("Saved final results")

Preparing prompts...
Running inference...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 384/384 [1:31:36<00:00, 14.31s/it]
  correlation, p_value = pearsonr(feature_labels, processed_predictions)


Extracting predictions...
No prediction found in response: Pour évaluer si le neurone s'activera sur le texte donné, nous devons examiner à la fois la description du neurone et le contenu du texte.

1. **Analyse de la description du neurone** : Nous savons que ce neurone est activé par des textes faisant référence au « statut administratif au titre indigène ». Cela implique une discussion liée à un cadre légal, administratif ou politique concernant les populations ou personnes souvent considérées comme « indigènes » dans un contexte spécifique.

2. **Analyse du texte** : Le texte fourni mentionne que « la plupart des indigènes aiment à adopter un autre nom à cette occasion, le nom d'un disciple ou d'un prophète ». Cette phrase évoque une tradition ou une pratique culturelle des personnes désignées comme « indigènes », mais elle ne traite pas du statut administratif ou légal lié à ces personnes. Il s'agit plutôt d'une dimension socioculturelle, sans référence explicite à un cadre admini

  correlation, p_value = pearsonr(feature_labels, processed_predictions)
  correlation, p_value = pearsonr(feature_labels, processed_predictions)
  correlation, p_value = pearsonr(feature_labels, processed_predictions)
  correlation, p_value = pearsonr(feature_labels, processed_predictions)
  correlation, p_value = pearsonr(feature_labels, processed_predictions)
  correlation, p_value = pearsonr(feature_labels, processed_predictions)
  correlation, p_value = pearsonr(feature_labels, processed_predictions)
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  correlation, p_value = pearsonr(feature_labels, processed_predictions)
  correlation, p_value = pearsonr(feature_labels, processed_predictions)
  correlation, p_value = pearsonr(feature_labels, processed_predictions)
  correlation, p_value = pearsonr(feature_labels, processed_predictions)
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  correlation, p_value = pearsonr(feature_labels, p

Saved final results
