### Data post-processing plan
1. Balance the service distribution as much as possible.

 a. Remove 50% of the dialogues whose services are exactly {attraction, restaurant} (done — see the code below).
 
 b. Make sure hotel and taxi dialogues are not left behind (done).
 
2. Remove dialogues containing rogue "<" / ">" brackets.
3. Run a similarity search and remove near-duplicate dialogues.
4. Keep only dialogues with more than 5 lines.
5. Text clustering and further analysis (TBD).


#### Removing 50% of the attraction + restaurant dialogues

In [5]:
import json
import random

# ---- Paths used by the balancing step ----
# Source dialogues (JSONL). Point this at your own generated file.
INPUT_FILE = 'generated_dialogues.jsonl'
# Entries that survive the 50% cut.
OUTPUT_FILE = './post_data/updated_dataset.jsonl'
# Entries that were cut, kept for inspection.
REMOVED_FILE = './post_data/removed_entries.jsonl'

# Only dialogues whose service set is exactly this pair are eligible for removal.
TARGET_SERVICES = set(('attraction', 'restaurant'))

# Fixed seed so the random selection is reproducible across runs.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)


def main():
    """
    Randomly drop half of the dialogues whose services are exactly TARGET_SERVICES.

    Reads INPUT_FILE (JSONL, one dialogue per line) twice. The first pass
    records the line positions of entries whose ``services`` set equals
    TARGET_SERVICES (no more, no less); ``len // 2`` of those positions are
    sampled with a ``random.Random(RANDOM_SEED)`` generator. The second pass
    copies each line verbatim to REMOVED_FILE if its position was sampled and
    to OUTPUT_FILE otherwise, creating their parent directories if needed.

    Selection is keyed on line position, not ``dialogue_id``: a repeated id can
    no longer drag a non-matching entry into the removed file, and a matching
    entry without an id is still eligible. Blank lines are ignored.
    """
    import os  # local import: only this cell needs it

    # First pass: positions of entries with EXACTLY the target services.
    matching_lines = []
    total_entries = 0
    with open(INPUT_FILE, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f):
            if not line.strip():
                continue  # tolerate blank / trailing empty lines
            total_entries += 1
            entry = json.loads(line)
            # `or []` also covers an explicit JSON null.
            services = set(entry.get('services') or [])
            if services == TARGET_SERVICES:
                matching_lines.append(line_no)

    total_matching = len(matching_lines)
    print(f"Total entries in dataset: {total_entries}")
    print(f"Entries with both {TARGET_SERVICES}: {total_matching}")

    if total_matching == 0:
        print("No matching entries found. No removal performed.")
        return

    # Determine number of entries to remove (50%, rounded down)
    num_to_remove = total_matching // 2
    print(f"Number of entries to remove (50%): {num_to_remove}")

    # A dedicated generator keeps the selection reproducible even if the cell
    # is re-run or other code touches the global random state in between.
    rng = random.Random(RANDOM_SEED)
    lines_to_remove = set(rng.sample(matching_lines, num_to_remove))
    print("Selected entries for removal.")

    # The output directory (e.g. ./post_data) may not exist yet.
    for path in (OUTPUT_FILE, REMOVED_FILE):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    # Second pass: route every non-blank line to the kept or removed file.
    updated_count = 0
    removed_count = 0
    with open(INPUT_FILE, 'r', encoding='utf-8') as fin, \
         open(OUTPUT_FILE, 'w', encoding='utf-8') as fout, \
         open(REMOVED_FILE, 'w', encoding='utf-8') as fremoved:
        for line_no, line in enumerate(fin):
            if not line.strip():
                continue
            # Keep one record per line even if the input's last line lacks '\n'.
            record = line if line.endswith('\n') else line + '\n'
            if line_no in lines_to_remove:
                fremoved.write(record)
                removed_count += 1
            else:
                fout.write(record)
                updated_count += 1

    print(f"Updated dataset size: {updated_count}")
    print(f"Removed entries size: {removed_count}")
    print(f"Updated dataset saved to {OUTPUT_FILE}")
    print(f"Removed entries saved to {REMOVED_FILE}")

if __name__ == "__main__":
    main()


Total entries in dataset: 54395
Entries with both {'attraction', 'restaurant'}: 1893
Number of entries to remove (50%): 946
Selected entries for removal.
Updated dataset size: 53449
Removed entries size: 946
Updated dataset saved to ./post_data/updated_dataset.jsonl
Removed entries saved to ./post_data/removed_entries.jsonl


In [6]:
import json
import re
from tqdm import tqdm

# ----------------------------------------
# Step 1: Define the Rogue Bracket Detection Function
# ----------------------------------------

# Compiled once at import time instead of on every call; the checker runs up
# to three times per turn across the whole dataset.
# Alternative 1: '<' not followed by an ASCII letter.
# Alternative 2: '>' not preceded by an ASCII letter.
_ROGUE_BRACKET_RE = re.compile(r'<(?![A-Za-z])|(?<![A-Za-z])>')


def has_rogue_brackets(text):
    """
    Report whether *text* contains a "rogue" angle bracket.

    A '<' is rogue when it is NOT immediately followed by an ASCII letter
    (so '<b' passes, while '< ', '<3', '</' or a trailing '<' are rogue).
    A '>' is rogue when it is NOT immediately preceded by an ASCII letter
    (so 'b>' passes, while ' >', '/>', '">' or a leading '>' are rogue).
    Note that closing tags ('</b>') and self-closing tags ('<br/>') therefore
    count as rogue.

    Args:
        text (str | None): The text to check. ``None`` (e.g. a JSON ``null``
            field) is treated as empty.

    Returns:
        bool: True if at least one rogue '<' or '>' is found, False otherwise.
    """
    if text is None:
        return False
    return _ROGUE_BRACKET_RE.search(text) is not None

# ----------------------------------------
# Step 2: Process the Dataset and Create a Cleaned File
# ----------------------------------------

def clean_dataset(input_file, output_file):
    """
    Copy every JSONL entry whose turns contain no rogue brackets to *output_file*.

    Each non-blank line of *input_file* is parsed as JSON. An entry is dropped
    when any of its turns has a rogue '<' / '>' (per ``has_rogue_brackets``) in
    its ``utterance``, ``intent`` or ``assistant_response`` field; missing or
    JSON ``null`` fields are checked as ''. Kept entries are re-serialized with
    ``ensure_ascii=False``. Lines that are not valid JSON objects are skipped and
    counted separately from bracket rejections. The parent directory of
    *output_file* is created if missing.

    Errors are reported with ``print`` rather than raised (notebook-style best
    effort); after an error the output file may be incomplete.

    Args:
        input_file (str): Path to the input JSONL file.
        output_file (str): Path to the output cleaned JSONL file.
    """
    import os  # local import: only needed to create the output directory

    checked_fields = ('utterance', 'intent', 'assistant_response')

    def _field(turn, key):
        # JSON null comes back as None; treat it like a missing field.
        value = turn.get(key)
        return '' if value is None else value

    total_entries = 0
    malformed_entries = 0
    bracket_entries = 0
    included_entries = 0

    try:
        out_dir = os.path.dirname(output_file)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        with open(input_file, 'r', encoding='utf-8') as fin, \
             open(output_file, 'w', encoding='utf-8') as fout:

            for line_number, line in enumerate(fin, 1):
                if not line.strip():
                    continue  # blank lines are not entries
                total_entries += 1
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    print(f"❌ JSONDecodeError at line {line_number}. Skipping this entry.")
                    malformed_entries += 1
                    continue
                if not isinstance(entry, dict):
                    print(f"❌ Line {line_number} is not a JSON object. Skipping this entry.")
                    malformed_entries += 1
                    continue

                # Short-circuits on the first offending field, like the original loop.
                has_problem = any(
                    has_rogue_brackets(_field(turn, key))
                    for turn in (entry.get('turns') or [])
                    for key in checked_fields
                )

                if has_problem:
                    bracket_entries += 1
                else:
                    fout.write(json.dumps(entry, ensure_ascii=False) + '\n')
                    included_entries += 1

        print(f"📊 Total entries processed: {total_entries}")
        print(f"✅ Entries included in cleaned dataset: {included_entries}")
        print(f"⚠️ Entries excluded due to rogue brackets: {bracket_entries}")
        if malformed_entries:
            print(f"⚠️ Entries excluded due to malformed JSON: {malformed_entries}")
        print(f"📄 Cleaned dataset saved to: {output_file}")

    except FileNotFoundError as e:
        # Report the path that is actually missing, not always the input.
        missing = e.filename or input_file
        print(f"❌ Error: The file '{missing}' does not exist. Please check the filename and try again.")
    except Exception as e:
        print(f"❌ An unexpected error occurred: {e}")

# ----------------------------------------
# Step 3: Execute the Cleaning Process
# ----------------------------------------

if __name__ == "__main__":
    # Distinct names on purpose: notebook cells share one namespace, and
    # re-binding INPUT_FILE here would silently change what the balancing
    # cell's main() reads if it is re-run later.
    CLEANING_INPUT_FILE = './post_data/updated_dataset.jsonl'  # output of the balancing step
    CLEANED_FILE = './post_data/cleaned_updated_dataset.jsonl'  # Output filename

    clean_dataset(CLEANING_INPUT_FILE, CLEANED_FILE)


📊 Total entries processed: 53449
✅ Entries included in cleaned dataset: 50723
⚠️ Entries excluded due to rogue brackets: 2726
📄 Cleaned dataset saved to: ./post_data/cleaned_updated_dataset.jsonl


In [1]:
import json
import numpy as np
from sentence_transformers import SentenceTransformer
import faiss
import logging
from tqdm import tqdm
from collections import defaultdict
import hashlib

# Configure the root logger for the deduplication cell: INFO and above, each
# line prefixed as "<time> - <LEVEL> - <message>".
# NOTE(review): basicConfig only takes effect if the root logger has no
# handlers yet — confirm it still applies when cells are re-run.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def compute_dialogue_hash(dialogue):
    """
    Return an MD5 hex digest of a dialogue's turn content, for exact-duplicate detection.

    Only ``turns`` is hashed. For every turn, ``utterance``, ``intent`` and
    ``assistant_response`` are lower-cased and stripped, so dialogues that
    differ only in letter case or surrounding whitespace of those fields hash
    equal. Other top-level keys (ids, ``generated_scenario``, ...) are ignored.
    A missing or JSON ``null`` field is treated as ''.

    MD5 serves only as a content fingerprint here, not for security.
    """
    def _norm(turn, key):
        # `.get(key, '')` would still return None for an explicit null.
        value = turn.get(key)
        return '' if value is None else value.lower().strip()

    dialogue_content = [
        {
            'utterance': _norm(turn, 'utterance'),
            'intent': _norm(turn, 'intent'),
            'response': _norm(turn, 'assistant_response'),
        }
        for turn in dialogue.get('turns', [])
    ]

    # sort_keys gives a stable string regardless of dict insertion order.
    content_str = json.dumps(dialogue_content, sort_keys=True)
    return hashlib.md5(content_str.encode()).hexdigest()

def get_dialogue_text(dialogue):
    """
    Flatten a dialogue into a single string for embedding.

    The result is ``generated_scenario`` (when present and not null) followed
    by one ``"Intent: <intent>. User: <utterance>. Assistant: <response>"``
    segment per turn, all joined with single spaces. Turn fields are stripped,
    and a turn is skipped unless its utterance, intent and assistant_response
    are all non-empty. Missing or JSON ``null`` fields count as empty.
    """
    def _clean(turn, key):
        value = turn.get(key)
        return '' if value is None else value.strip()

    text_parts = []

    # Scenario context, if available (a null scenario would break ' '.join).
    scenario = dialogue.get('generated_scenario')
    if scenario is not None:
        text_parts.append(scenario)

    for turn in dialogue.get('turns', []):
        utterance = _clean(turn, 'utterance')
        intent = _clean(turn, 'intent')
        response = _clean(turn, 'assistant_response')

        if utterance and intent and response:
            text_parts.append(f"Intent: {intent}. User: {utterance}. Assistant: {response}")

    return ' '.join(text_parts)

def _select_semantic_duplicates(neighbor_ids, neighbor_scores, similarity_threshold):
    """
    Greedy pass over k-NN results that picks which rows to drop as near-duplicates.

    Rows are visited in order. A row that is not already marked is kept,
    unless one of its above-threshold neighbours was kept earlier; in that
    case the row itself is marked instead. When a row is kept, its
    above-threshold neighbours are marked. A kept row is never marked later,
    so every marked row was matched against at least one kept row.

    Args:
        neighbor_ids: Per-row neighbour ids as returned by a FAISS ``search``.
            Negative ids, which FAISS uses to pad result slots it could not
            fill, are ignored.
        neighbor_scores: Per-row similarity scores aligned with *neighbor_ids*.
        similarity_threshold: Minimum score for two rows to count as duplicates.

    Returns:
        set[int]: Row positions to drop.
    """
    kept = set()
    duplicates = set()
    rows = zip(neighbor_ids, neighbor_scores)
    for i, (row_ids, row_scores) in enumerate(
            tqdm(rows, total=len(neighbor_ids), desc="Finding semantic duplicates")):
        if i in duplicates:
            continue
        similar = [int(j) for j, score in zip(row_ids, row_scores)
                   if j >= 0 and j != i and score >= similarity_threshold]
        if any(j in kept for j in similar):
            # Already covered by an earlier kept row; never demote that row.
            duplicates.add(i)
            continue
        kept.add(i)
        duplicates.update(j for j in similar if j not in kept)
    return duplicates


def find_duplicates(dialogues, similarity_threshold=0.9, batch_size=8, nlist=1000, nprobe=50,
                    top_k=50, model_name='dunzhang/stella_en_400M_v5', device="cuda"):
    """
    Identify duplicate dialogues using a two-stage approach.

    1. Exact matching via content hashing (``compute_dialogue_hash``); the
       first occurrence of each hash is kept.
    2. Semantic similarity: dialogues surviving stage 1 are embedded with a
       SentenceTransformer, L2-normalised (so inner product = cosine
       similarity) and searched in a FAISS IVF-Flat index. Near-duplicates
       are then chosen greedily in dataset order, so for each group at least
       one dialogue is kept.

    Args:
        dialogues: List of dialogue dictionaries.
        similarity_threshold: Minimum cosine similarity for two dialogues to be
            semantic duplicates (default: 0.9).
        batch_size: Batch size for embedding generation (default: 8).
        nlist: Number of IVF clusters (default: 1000). Capped at the number of
            embedded dialogues so small inputs can still be trained.
        nprobe: Number of clusters searched per query (default: 50); capped at
            the effective nlist.
        top_k: Neighbours retrieved per dialogue (default: 50). Near-duplicates
            beyond this many are not seen from that dialogue's own row.
        model_name: SentenceTransformer model id
            (default: 'dunzhang/stella_en_400M_v5').
        device: Device for the embedding model (default: "cuda").

    Returns:
        set[int]: Indices into *dialogues* that should be removed.
    """
    logging.info("Starting duplicate detection...")

    # Stage 1: exact duplicates via content hashing.
    first_seen = {}
    exact_duplicates = set()
    for idx, dialogue in enumerate(tqdm(dialogues, desc="Finding exact duplicates")):
        dialogue_hash = compute_dialogue_hash(dialogue)
        if dialogue_hash in first_seen:
            exact_duplicates.add(idx)
        else:
            first_seen[dialogue_hash] = idx

    logging.info(f"Found {len(exact_duplicates)} exact duplicates")

    # Stage 2: only dialogues that were not exact duplicates are embedded.
    unique_indices = [i for i in range(len(dialogues)) if i not in exact_duplicates]
    if not unique_indices:
        return exact_duplicates
    texts = [get_dialogue_text(dialogues[i]) for i in unique_indices]

    model = SentenceTransformer(model_name, device=device, trust_remote_code=True)
    # encode() batches internally; no need for a manual batching loop.
    embeddings = model.encode(texts, batch_size=batch_size,
                              show_progress_bar=True, convert_to_numpy=True)
    embeddings = np.ascontiguousarray(embeddings, dtype='float32')
    faiss.normalize_L2(embeddings)

    num_vectors, dimension = embeddings.shape
    # IVF training cannot have more clusters than training vectors.
    effective_nlist = max(1, min(nlist, num_vectors))
    quantizer = faiss.IndexFlatIP(dimension)  # must stay referenced while index is used
    index = faiss.IndexIVFFlat(quantizer, dimension, effective_nlist, faiss.METRIC_INNER_PRODUCT)
    index.train(embeddings)
    index.add(embeddings)
    index.nprobe = min(nprobe, effective_nlist)

    k = min(top_k, num_vectors)
    scores, neighbor_ids = index.search(embeddings, k)

    semantic_positions = _select_semantic_duplicates(neighbor_ids, scores, similarity_threshold)
    # Map positions in the embedded subset back to indices in `dialogues`.
    semantic_duplicates = {unique_indices[p] for p in semantic_positions}

    all_duplicates = exact_duplicates | semantic_duplicates

    logging.info(f"Found {len(semantic_duplicates)} semantic duplicates")
    logging.info(f"Total duplicates found: {len(all_duplicates)}")

    return all_duplicates

def save_deduplicated_dataset(dialogues, duplicates, output_unique_path, output_duplicates_path):
    """
    Split *dialogues* into two JSONL files: kept dialogues and duplicates.

    Both files keep the input order. Each dialogue is written as one line via
    ``json.dumps(..., ensure_ascii=False)``.

    Args:
        dialogues: List of dialogue dictionaries.
        duplicates: Iterable of indices (set, list, ...) of dialogues to treat
            as duplicates. Indices outside ``range(len(dialogues))`` are ignored.
        output_unique_path: File path to save unique dialogues.
        output_duplicates_path: File path to save duplicate dialogues.
    """
    duplicate_set = set(duplicates)  # O(1) membership even when a list is passed
    unique_count = 0
    duplicate_count = 0

    # Stream straight to disk instead of holding two copies of the dataset.
    with open(output_unique_path, 'w', encoding='utf-8') as f_unique, \
         open(output_duplicates_path, 'w', encoding='utf-8') as f_duplicates:
        for i, dialogue in enumerate(dialogues):
            json_line = json.dumps(dialogue, ensure_ascii=False) + '\n'
            if i in duplicate_set:
                f_duplicates.write(json_line)
                duplicate_count += 1
            else:
                f_unique.write(json_line)
                unique_count += 1

    logging.info(f"Original dialogues: {len(dialogues)}")
    logging.info(f"Unique dialogues: {unique_count}")
    # Report what was actually written, so stray indices cannot inflate the count.
    logging.info(f"Removed duplicates: {duplicate_count}")
    logging.info(f"Deduplicated dataset saved to {output_unique_path}")
    logging.info(f"Duplicates saved to {output_duplicates_path}")

def main():
    """Load the cleaned dataset, detect duplicates, and write unique/duplicate JSONL files."""
    input_path = './post_data/cleaned_updated_dataset.jsonl'
    output_unique_path = 'stella_deduplicated_dataset.jsonl'
    output_duplicates_path = 'stella_duplicates_dataset.jsonl'  # duplicates get their own file

    logging.info(f"Loading dialogues from {input_path}")
    dialogues = []
    with open(input_path, 'r', encoding='utf-8') as source:
        line_number = 0
        for raw in source:
            line_number += 1
            try:
                parsed = json.loads(raw.strip())
            except json.JSONDecodeError as err:
                # Unparseable lines are logged and left out of the dataset.
                logging.error(f"Error parsing line {line_number}: {err}")
                continue
            dialogues.append(parsed)

    save_deduplicated_dataset(
        dialogues,
        find_duplicates(dialogues),
        output_unique_path,
        output_duplicates_path,
    )


if __name__ == "__main__":
    main()


  from tqdm.autonotebook import tqdm, trange
2024-11-27 22:14:04,442 - INFO - Loading dialogues from ./post_data/cleaned_updated_dataset.jsonl
2024-11-27 22:14:06,293 - INFO - Starting duplicate detection...
Finding exact duplicates: 100%|██████████| 50723/50723 [00:01<00:00, 35046.98it/s]
2024-11-27 22:14:07,743 - INFO - Found 0 exact duplicates
2024-11-27 22:14:07,754 - INFO - Load pretrained SentenceTransformer: dunzhang/stella_en_400M_v5
Some weights of the model checkpoint at dunzhang/stella_en_400M_v5 were not used when initializing NewModel: ['new.pooler.dense.bias', 'new.pooler.dense.weight']
- This IS expected if you are initializing NewModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing NewModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassif