In [None]:
import os
import numpy as np
from collections import defaultdict
from pathlib import Path

# Impute placeholder ("EMPTY") embedding files.
#
# The script walks OUTPUT_BASE_DIR, loads every regular *.npy embedding, groups
# them by (model directory, question name), and then overwrites each
# *EMPTY.npy placeholder with the element-wise mean of the non-empty
# embeddings collected for the same model and question.

WINDOWS = os.name == 'nt'

# Base directories
if WINDOWS:
    INPUT_DIR = 'W:/Portrait/Embeddings/Portrait Transcripts'
else:
    # Any non-Windows host.
    # NOTE(review): /Volumes is the macOS mount convention -- confirm the
    # mount point if this is also meant to run on Linux.
    INPUT_DIR = '/Volumes/mgialou/Portrait/Embeddings/Portrait Transcripts'

OUTPUT_BASE_DIR = Path(INPUT_DIR) / 'embeddings'

NPY_SUFFIX = '.npy'
EMPTY_MARKER = '_EMPTY'


def question_name_from_file(file_name):
    """Return the question name encoded in an embedding file name.

    The trailing ``.npy`` extension, an optional trailing ``_EMPTY`` marker
    and an optional leading all-digit user-id prefix (e.g. ``"12_"``) are
    removed, in that order.

    Args:
        file_name: Bare file name (no directory), e.g. ``"12_q1_EMPTY.npy"``.

    Returns:
        The remaining question name, e.g. ``"q1"``.
    """
    # Strip only the trailing extension; str.replace() would also remove
    # any ".npy" occurring in the middle of the name.
    if file_name.endswith(NPY_SUFFIX):
        base_name = file_name[:-len(NPY_SUFFIX)]
    else:
        base_name = file_name

    if base_name.endswith(EMPTY_MARKER):
        base_name = base_name[:-len(EMPTY_MARKER)]

    # A leading user id is recognised only when it is purely numeric and
    # followed by an underscore.
    prefix, separator, remainder = base_name.partition('_')
    if separator and prefix.isdigit():
        return remainder
    return base_name


def mean_embedding_or_none(vectors):
    """Return the element-wise mean of ``vectors``, or None if impossible.

    Args:
        vectors: List of NumPy arrays collected for one model/question.

    Returns:
        ``np.mean(vectors, axis=0)`` when the list is non-empty and all
        arrays share one shape; otherwise ``None``.
    """
    if not vectors:
        return None
    # Arrays of different shapes cannot be averaged element-wise; detect
    # this up front instead of relying on NumPy's version-dependent error.
    if len({np.shape(vector) for vector in vectors}) != 1:
        return None
    return np.mean(vectors, axis=0)


total_empty_count = 0
print(f"Scanning for embeddings in: {OUTPUT_BASE_DIR}")
if not OUTPUT_BASE_DIR.is_dir():
    # os.walk() silently yields nothing for a missing directory, which would
    # otherwise look exactly like "no empty files found".
    print(f"Warning: {OUTPUT_BASE_DIR} is not an accessible directory; no embeddings will be found.")

# Data structures to collect all embeddings and track empty ones
# Structure: all_embeddings[model_id][question_base_name] = list_of_np_arrays
all_embeddings = defaultdict(lambda: defaultdict(list))

# Structure: empty_file_paths_for_imputation[model_id][question_base_name] = list_of_file_paths
empty_file_paths_for_imputation = defaultdict(lambda: defaultdict(list))
# Structure: empty_embeddings[user_id][model_id] = list_of_file_paths (reporting only)
empty_embeddings = defaultdict(lambda: defaultdict(list))

# Walk through OUTPUT_BASE_DIR to first collect all embeddings and identify empty ones
print("\nCollecting all embedding data (this might take a while)...")
for root, _, files in os.walk(OUTPUT_BASE_DIR):
    # The directory holding the file names the model; its parent names the user.
    model_id = os.path.basename(root)
    for file in files:
        if not file.endswith(NPY_SUFFIX):
            continue

        file_path = os.path.join(root, file)
        question_filename_without_suffix = question_name_from_file(file)

        if file.endswith('EMPTY.npy'):
            user_id = os.path.basename(os.path.dirname(root))
            empty_embeddings[user_id][model_id].append(file_path)  # For reporting original empty files
            empty_file_paths_for_imputation[model_id][question_filename_without_suffix].append(file_path)
            total_empty_count += 1
            continue

        # Load non-empty embeddings; unreadable files are skipped (best effort).
        try:
            embedding = np.load(file_path)
        except Exception as e:
            print(f"Warning: Could not load embedding from {file_path}. Skipping. Error: {e}")
            continue

        if embedding.size > 0:  # Ignore zero-sized arrays
            # Embeddings are flattened before averaging, so imputed files are
            # always 1-D.
            # NOTE(review): confirm downstream readers accept a 1-D array for
            # imputed files if the real embeddings are stored with more dims.
            all_embeddings[model_id][question_filename_without_suffix].append(embedding.flatten())

# Display identified empty files
print("\nEmpty Embedding Files by User ID and Model (Identified):")
if total_empty_count == 0:
    print("  No EMPTY.npy files found.")
else:
    for user_id, models in empty_embeddings.items():
        print(f"\nUser ID: {user_id}")
        for model_id, empty_paths in models.items():
            print(f"  Model: {model_id}")
            for file_path in empty_paths:
                print(f"    {file_path}")

print(f"\nTotal Empty Embeddings Identified: {total_empty_count}")

# Imputation and Overwriting Logic
print("\nStarting imputation of EMPTY.npy files...")
imputed_count = 0

for model_id, questions in empty_file_paths_for_imputation.items():
    for q_name, empty_files_list in questions.items():
        # .get() avoids inserting empty entries into the defaultdicts as a
        # side effect of merely looking a key up.
        collected = all_embeddings.get(model_id, {}).get(q_name, [])
        if not collected:
            print(f"  Warning: No non-empty embeddings found to calculate mean for model '{model_id}' and question '{q_name}'. "
                  f"Skipping imputation for {len(empty_files_list)} empty files of this type.")
            continue

        # Mean embedding for this specific model and question.
        mean_embedding = mean_embedding_or_none(collected)
        if mean_embedding is None:
            print(f"  Warning: Non-empty embeddings for model '{model_id}' and question '{q_name}' have inconsistent sizes. "
                  f"Skipping imputation for {len(empty_files_list)} empty files of this type.")
            continue

        for empty_file_path in empty_files_list:
            try:
                # Overwrite the EMPTY.npy file with the calculated mean embedding
                np.save(empty_file_path, mean_embedding)
            except Exception as e:
                print(f"  Error imputing {empty_file_path}. Skipping. Error: {e}")
            else:
                print(f"  Imputed and overwrote: {empty_file_path} with mean of {model_id}/{q_name}")
                imputed_count += 1

print(f"\nFinished imputation. Successfully imputed {imputed_count} files.")