In [None]:
import json
import os
import glob

# --- Configuration ---

# Directory that is globbed (together with FILE_EXTENSION) for input files.
SOURCE_FOLDER = "runs_dealignedQwen_hhrlhf/layer_35" 
# Directory the cleaned copies are written to; each keeps its input file name.
OUTPUT_FOLDER = "runs_dealignedQwen_hhrlhf/cleaned_layer_35"
# Keys of each JSON record whose values are passed through clean_content().
TARGET_COLUMNS = ["output_base", "output_dpo", "output_steered"]
# Glob pattern selecting which files inside SOURCE_FOLDER are processed.
FILE_EXTENSION = "*.jsonl" 

def clean_content(text):
    """Strip everything up to and including the first assistant marker.

    The markers are tried in a fixed priority order; the first marker (in
    that order) that occurs in ``text`` is used, and the text is cut at that
    marker's first occurrence. The remainder is whitespace-stripped.

    Args:
        text: The raw value to clean. Anything that is falsy or not a
            ``str`` is passed through untouched.

    Returns:
        A ``(cleaned_text, was_modified)`` tuple. ``was_modified`` is
        ``True`` only when a marker was found and removed.
    """
    # Guard clause: nothing to do for empty values or non-string payloads.
    if not text or not isinstance(text, str):
        return text, False

    for marker in ("<|assistant|>", "\n\nassistant\n", "\n\nAssistant\n"):
        # partition() yields an empty separator when the marker is absent.
        _, hit, remainder = text.partition(marker)
        if hit:
            return remainder.strip(), True

    return text, False
    
def process_folder():
    """Clean every matching JSONL file in SOURCE_FOLDER into OUTPUT_FOLDER.

    For each file matched by ``FILE_EXTENSION`` inside ``SOURCE_FOLDER``,
    every non-blank line is parsed as JSON, each key listed in
    ``TARGET_COLUMNS`` is run through ``clean_content``, and the record is
    written (one JSON object per line) to a file of the same name in
    ``OUTPUT_FOLDER``. A per-file summary and a list of records where no
    marker was found (or the column was absent) are printed to stdout.

    Lines that are not valid JSON, or that do not decode to a JSON object,
    are reported with their real line number and skipped; they are not
    written to the output file.

    Returns:
        None. Results are written to disk and reported via ``print``.
    """
    # exist_ok avoids the separate exists() check and its race window.
    os.makedirs(OUTPUT_FOLDER, exist_ok=True)

    # Sorted so files are processed and reported in a stable order;
    # glob itself makes no ordering guarantee.
    files = sorted(glob.glob(os.path.join(SOURCE_FOLDER, FILE_EXTENSION)))

    if not files:
        print("No files found.")
        return

    for file_path in files:
        filename = os.path.basename(file_path)
        output_path = os.path.join(OUTPUT_FOLDER, filename)

        print(f"\n{'='*80}")
        print(f"PROCESSING FILE: {filename}")
        print(f"{'='*80}")

        # Stats tracking
        total_lines = 0  # number of records successfully parsed and written
        stats = {col: {"modified": 0, "unchanged": 0} for col in TARGET_COLUMNS}
        missing_logs = []  # details of columns with no marker / missing columns

        with open(file_path, 'r', encoding='utf-8') as infile, \
             open(output_path, 'w', encoding='utf-8') as outfile:

            # line_no is the physical line in the input file, so error
            # messages point at the actual offending line.
            for line_no, line in enumerate(infile, start=1):
                if not line.strip():
                    continue

                # Keep the try body minimal: only the parse can raise here.
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as err:
                    print(f"Error decoding JSON on line {line_no}: {err}")
                    continue

                # A valid JSON scalar/array has no columns to clean; skip it
                # rather than crash on .get() and leave a truncated output.
                if not isinstance(data, dict):
                    print(
                        f"Skipping line {line_no}: expected a JSON object, "
                        f"got {type(data).__name__}"
                    )
                    continue

                total_lines += 1
                current_id = data.get("id", "N/A")

                # Grab first 50 chars of prompt for the log. A null or
                # non-string prompt is previewed as empty instead of raising.
                prompt = data.get("prompt", "")
                if not isinstance(prompt, str):
                    prompt = ""
                prompt_preview = prompt[:50].replace('\n', ' ') + "..."

                # Process each column
                for col in TARGET_COLUMNS:
                    if col not in data:
                        # Column missing entirely from this record.
                        missing_logs.append({
                            "id": current_id,
                            "col": col,
                            "prompt": "COLUMN MISSING",
                        })
                        continue

                    cleaned_text, found = clean_content(data[col])
                    data[col] = cleaned_text

                    if found:
                        stats[col]["modified"] += 1
                    else:
                        stats[col]["unchanged"] += 1
                        missing_logs.append({
                            "id": current_id,
                            "col": col,
                            "prompt": prompt_preview,
                        })

                # ensure_ascii=False keeps characters such as U+2019 readable
                # instead of emitting \uXXXX escapes.
                outfile.write(json.dumps(data, ensure_ascii=False) + '\n')

        # --- PRINT SUMMARY ---
        print(f"Total Samples Processed: {total_lines}")
        print("-" * 60)
        print(f"{'Column Name':<20} | {'Found Tag (Cleaned)':<20} | {'No Tag (Unchanged)':<20}")
        print("-" * 60)

        for col in TARGET_COLUMNS:
            print(f"{col:<20} | {stats[col]['modified']:<20} | {stats[col]['unchanged']:<20}")

        # --- PRINT DETAILED LOGS FOR UNCHANGED ITEMS ---
        if missing_logs:
            # Delimiters are shown escaped so the header stays on one line;
            # printing them raw splits the message across several lines.
            print(
                "\n[!] DETAILED LOG: Samples where none of "
                "'<|assistant|>', '\\n\\nassistant\\n', '\\n\\nAssistant\\n' "
                "was found:"
            )
            print(f"{'ID':<5} | {'Column':<15} | {'Prompt Snippet'}")
            print("-" * 80)
            for log in missing_logs:
                print(f"{str(log['id']):<5} | {log['col']:<15} | {log['prompt']}")
        else:
            print("\n[+] Success! All target columns in all rows contained the tag.")

        print(f"\nSaved cleaned file to: {output_path}")

# Run the cleaning pass immediately when this cell/script is executed.
process_folder()


PROCESSING FILE: hh_eval_lambda_0.05.jsonl
Total Samples Processed: 250
------------------------------------------------------------
Column Name          | Found Tag (Cleaned)  | No Tag (Unchanged)  
------------------------------------------------------------
output_base          | 244                  | 6                   
output_dpo           | 244                  | 6                   
output_steered       | 244                  | 6                   

[!] DETAILED LOG: Samples where <|assistant|>, 

assistant
 

Assistant
 was NOT found:
ID    | Column          | Prompt Snippet
--------------------------------------------------------------------------------
42    | output_base     | Human: What are methods of having sexual relations...
42    | output_dpo      | Human: What are methods of having sexual relations...
42    | output_steered  | Human: What are methods of having sexual relations...
60    | output_base     | Human: I need to find a gun to kill someone who ha...
60    