# In [None]:
import os
import random
import shutil
from pathlib import Path
import csv
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
import time
from tqdm import tqdm

# Fixed seed so the shuffle performed when splitting is repeatable.
random.seed(42)


# 1. Gather file paths
# The source directories are written with forward slashes and normalised
# through pathlib, so every entry uses the native separator of the platform
# the script runs on. With hard-coded backslashes none of the folders is
# found on POSIX systems and the scan silently reports zero files; on
# Windows the resulting strings are unchanged.
fake_dirs = [
    str(Path(folder))
    for folder in (
        "for-2seconds/testing/fake",
        "for-2seconds/training/fake",
        "for-2seconds/validation/fake",
        "release_in_the_wild/fake",
        "generated_audio/fake/common_voices_prompts_from_conformer_fastspeech2_pwg_ljspeech",
        "generated_audio/fake/jsut_multi_band_melgan",
        "generated_audio/fake/jsut_parallel_wavegan",
        "generated_audio/fake/ljspeech_full_band_melgan",
        "generated_audio/fake/ljspeech_hifiGAN",
        "generated_audio/fake/ljspeech_melgan",
        "generated_audio/fake/ljspeech_melgan_large",
        "generated_audio/fake/ljspeech_multi_band_melgan",
        "generated_audio/fake/ljspeech_parallel_wavegan",
        "generated_audio/fake/ljspeech_waveglow",
    )
]

real_dirs = [
    str(Path(folder))
    for folder in (
        "for-2seconds/testing/real",
        "for-2seconds/training/real",
        "for-2seconds/validation/real",
        "common-voices-mozilla/cv-valid-train/wav-files",
    )
]

def get_files_from_directory(folder, ext=".wav"):
    """Recursively list the files under *folder* whose names end in *ext*.

    Used as the per-directory unit of work for the parallel scan.

    Args:
        folder: Directory to scan (str or Path).
        ext: File-name suffix to match, including the leading dot.

    Returns:
        A sorted list of ``Path`` objects for regular files only. The list
        is empty when *folder* does not exist or is not a directory.
    """
    folder_path = Path(folder)
    if not folder_path.is_dir():
        return []
    # rglob matches on name alone, so a directory called "something.wav"
    # would be returned too and later fail to copy; keep regular files only.
    # Sorting removes the dependence on filesystem enumeration order.
    return sorted(p for p in folder_path.rglob(f"*{ext}") if p.is_file())

def get_all_audio_files_parallel(folder_list, ext=".wav", max_workers=None):
    """Collect matching files from several directories using a thread pool.

    Each directory is scanned by ``get_files_from_directory`` in its own
    task. A directory whose scan raises is reported on stdout and skipped.

    Args:
        folder_list: Directories to scan.
        ext: File-name suffix to match, including the leading dot.
        max_workers: Thread count; defaults to one per directory, capped at
            the CPU count.

    Returns:
        A sorted list of file paths as strings. Sorting makes the result
        independent of which scan finishes first, so a seeded shuffle of
        the list is reproducible between runs.
    """
    folder_list = list(folder_list)
    if not folder_list:
        # Nothing to scan; also avoids ThreadPoolExecutor(max_workers=0),
        # which raises ValueError.
        return []

    if max_workers is None:
        max_workers = min(len(folder_list), mp.cpu_count())

    all_files = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_folder = {
            executor.submit(get_files_from_directory, folder, ext): folder
            for folder in folder_list
        }

        # Report each directory as soon as its scan completes.
        for future in as_completed(future_to_folder):
            folder = future_to_folder[future]
            try:
                files = future.result()
            except Exception as e:
                print(f"Error scanning {folder}: {e}")
            else:
                all_files.extend(str(p) for p in files)
                print(f"Scanned {folder}: {len(files)} files")

    all_files.sort()
    return all_files

def generate_unique_filename(filepath, destination_dir):
    """Return a path inside *destination_dir* that does not exist yet.

    The source file's own name is tried first; if it is taken, ``_1``,
    ``_2``, ... are appended to the stem until a free name is found.

    Nothing is created on disk, so the answer is only a snapshot: another
    thread may claim the same name right after this returns. Concurrent
    callers must reserve the path themselves, as ``copy_file_safe`` does.

    Args:
        filepath: Source file whose name is reused (str or Path).
        destination_dir: Target directory (str or Path).

    Returns:
        A ``Path`` in *destination_dir* that did not exist when checked.
    """
    destination_dir = Path(destination_dir)
    path = Path(filepath)
    name = path.stem
    ext = path.suffix
    counter = 1

    new_path = destination_dir / f"{name}{ext}"
    while new_path.exists():
        new_path = destination_dir / f"{name}_{counter}{ext}"
        counter += 1

    return new_path


def _reserve_unique_destination(src, dst_dir):
    """Atomically claim a free destination path for *src* in *dst_dir*."""
    while True:
        candidate = generate_unique_filename(src, dst_dir)
        try:
            # Mode 'x' is exclusive creation: it fails if the file already
            # exists, so exactly one thread can win a given name.
            with open(candidate, 'xb'):
                pass
        except FileExistsError:
            # Another thread took this name between the check and the
            # create; ask for the next free one.
            continue
        return candidate


def copy_file_safe(src_path, dst_dir):
    """Copy one file into *dst_dir* under a collision-free name.

    The destination name is reserved atomically before copying, so files
    that share a name and are copied from several threads at once never
    overwrite each other.

    Args:
        src_path: Path of the file to copy.
        dst_dir: Target directory (str or Path); its name is used as the
            label.

    Returns:
        On success ``{'filepath', 'label', 'success': True}`` where
        ``filepath`` is the copy's path relative to the grandparent of
        *dst_dir*. On failure ``{'filepath', 'label', 'success': False,
        'error'}`` where ``filepath`` is the original *src_path*. No
        exception from the copy itself is propagated.
    """
    dst_dir = Path(dst_dir)
    try:
        src = Path(src_path)
        dst = _reserve_unique_destination(src, dst_dir)

        try:
            shutil.copy2(src, dst)
        except Exception:
            # Do not leave the empty placeholder behind: it would be
            # counted as a dataset file by later checks.
            try:
                dst.unlink()
            except OSError:
                pass
            raise

        return {
            'filepath': str(dst.relative_to(dst_dir.parent.parent)),
            'label': dst_dir.name,
            'success': True
        }
    except Exception as e:
        return {
            'filepath': src_path,  # Report the original source path on failure
            'label': dst_dir.name,
            'success': False,
            'error': str(e)
        }

def copy_files_parallel(file_paths, destination_dir, max_workers=None):
    """Copy many files into *destination_dir* using a thread pool.

    Each file is handed to ``copy_file_safe``; a tqdm progress bar advances
    as copies finish. The destination directory is created if needed.

    Args:
        file_paths: Source file paths to copy.
        destination_dir: Target directory (str or Path).
        max_workers: Thread count; defaults to twice the CPU count, capped
            at 32.

    Returns:
        A ``(successful, failed)`` tuple. ``successful`` holds manifest rows
        (``{'filepath', 'label'}``); ``failed`` holds the full result dicts
        of the copies that did not succeed, including their ``'error'``.
    """
    destination_dir = Path(destination_dir)
    if max_workers is None:
        max_workers = min(32, mp.cpu_count() * 2)

    destination_dir.mkdir(parents=True, exist_ok=True)

    # Progress label: show the path relative to the 'dataset' root when the
    # destination lives there, otherwise the path as given. Computed before
    # any work is submitted so a foreign destination cannot abort the call
    # after the copies have already run.
    try:
        shown_dir = destination_dir.relative_to(Path('dataset'))
    except ValueError:
        shown_dir = destination_dir
    desc = f"Copying to {shown_dir}"

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(copy_file_safe, src_path, destination_dir)
            for src_path in file_paths
        ]
        results = [
            future.result()
            for future in tqdm(as_completed(futures), total=len(futures), desc=desc)
        ]

    # Successful copies are reduced to manifest rows; failures keep every
    # field so the caller can log the error text.
    successful_results = [
        {'filepath': r['filepath'], 'label': r['label']}
        for r in results
        if r['success']
    ]
    failed_results = [r for r in results if not r['success']]

    return successful_results, failed_results

def check_existing_splits(base_dir):
    """Report which of the train/val/test splits are already on disk.

    A split counts as present when its directory and its ``<split>.csv``
    manifest both exist under *base_dir* and both the ``fake`` and ``real``
    sub-folders hold at least one ``.wav`` file. Each split found is
    announced on stdout.

    Returns:
        The names of the splits found, in train, val, test order.
    """
    root = Path(base_dir)
    found = []

    for split_name in ('train', 'val', 'test'):
        split_dir = root / split_name
        if not (split_dir.exists() and (root / f'{split_name}.csv').exists()):
            continue

        # Only the files directly inside each class folder are counted.
        fake_count = sum(1 for _ in (split_dir / 'fake').glob('*.wav'))
        real_count = sum(1 for _ in (split_dir / 'real').glob('*.wav'))
        if fake_count == 0 or real_count == 0:
            continue

        found.append(split_name)
        print(f"Found existing {split_name} split: {fake_count} fake, {real_count} real files")

    return found

def split_balanced_dataset(fake_files, real_files, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1):
    """Shuffle both classes and cut them into equally sized train/val/test parts.

    Both classes are truncated to the size of the smaller one, so every
    split holds the same number of fake and real samples. The inputs are
    not modified; shuffled copies are used.

    Args:
        fake_files: Paths of the fake samples.
        real_files: Paths of the real samples.
        train_ratio: Fraction of the balanced count used for training.
        val_ratio: Fraction used for validation.
        test_ratio: Fraction used for testing. When the three ratios sum
            to 1 the test split takes every remaining sample, so nothing
            is lost to rounding; when they sum to less than 1 the test
            split is limited to its own fraction.

    Returns:
        ``{'train': {'fake': [...], 'real': [...]}, 'val': {...},
        'test': {...}}``.
    """
    # Work on copies so the caller's lists keep their original order.
    fake_files = list(fake_files)
    real_files = list(real_files)
    random.shuffle(fake_files)
    random.shuffle(real_files)

    min_count = min(len(fake_files), len(real_files))

    print(f"Using {min_count} files per class for balanced dataset")

    train_size = int(train_ratio * min_count)
    val_size = int(val_ratio * min_count)
    if abs(train_ratio + val_ratio + test_ratio - 1.0) <= 1e-9:
        test_size = min_count - train_size - val_size
    else:
        test_size = int(test_ratio * min_count)

    print(f"Split sizes - Train: {train_size}, Val: {val_size}, Test: {test_size}")

    val_end = train_size + val_size
    # Never read past the balanced prefix, whatever the ratios are.
    test_end = min(val_end + test_size, min_count)

    def _cut(files):
        """Slice one shuffled class into its (train, val, test) parts."""
        balanced = files[:min_count]
        return balanced[:train_size], balanced[train_size:val_end], balanced[val_end:test_end]

    fake_train, fake_val, fake_test = _cut(fake_files)
    real_train, real_val, real_test = _cut(real_files)

    return {
        'train': {'fake': fake_train, 'real': real_train},
        'val': {'fake': fake_val, 'real': real_val},
        'test': {'fake': fake_test, 'real': real_test},
    }

# --- Main Execution ---
# Wall-clock start; the summary at the end of the script reports the time
# elapsed since this point, so it includes scanning as well as copying.
start_time = time.time()

# Scan every configured source directory and collect the matching .wav
# paths as strings, one list per class.
print("Gathering files in parallel...")
fake_files = get_all_audio_files_parallel(fake_dirs)
real_files = get_all_audio_files_parallel(real_dirs)

print(f"\nTotal fake files found: {len(fake_files)}")
print(f"Total real files found: {len(real_files)}")

# Balance the two classes and cut them into train/val/test using the
# function's default 0.8 / 0.1 / 0.1 ratios.
print("\nCreating balanced splits...")
splits = split_balanced_dataset(fake_files, real_files)

# Output root: files go to dataset/<split>/<label>/ and each split gets a
# dataset/<split>.csv manifest. Both paths are relative to the current
# working directory.
base_dir = Path('dataset')
log_file = Path('failed_copies.log')  # Failed copies are appended here

# Clear the log file at the start of a new run
if log_file.exists():
    log_file.unlink()

# Splits reported here are skipped by the copy loop below, which lets an
# earlier run's finished splits be reused.
print(f"\nChecking for existing splits... (Failed filepaths will be logged to '{log_file}')")
completed_splits = check_existing_splits(base_dir)

if completed_splits:
    print(f"Found existing splits: {', '.join(completed_splits)}")

print("\nCreating directory structure and copying files...")
# Same formula copy_files_parallel uses for its default worker count.
print(f"Using up to {min(32, mp.cpu_count() * 2)} threads for file copying...")

# Copy each split's files into dataset/<split>/<label>/ and write the
# split's manifest once both classes have been processed.
for split_name, classes in splits.items():
    if split_name in completed_splits:
        print(f"\nSkipping {split_name} split (already exists)...")
        continue

    print(f"\nProcessing {split_name} split...")
    all_manifest_rows = []

    for label, files in classes.items():
        out_dir = base_dir / split_name / label

        successful_copies, failed_copies = copy_files_parallel(files, out_dir)
        all_manifest_rows += successful_copies

        if not failed_copies:
            continue

        # Record every failed source path with its error under a header
        # naming the split and label it belongs to.
        print(f"  └─ ⚠️  Warning: {len(failed_copies)} files failed to copy for this batch.")
        with open(log_file, 'a', encoding='utf-8') as log_handle:
            log_handle.write(f"\n--- Failures for split='{split_name}', label='{label}' ---\n")
            log_handle.writelines(
                f"{failure['filepath']} | Error: {failure['error']}\n"
                for failure in failed_copies
            )

    # Only successfully copied files are listed in the manifest.
    manifest_path = base_dir / f'{split_name}.csv'
    with open(manifest_path, 'w', newline='', encoding='utf-8') as manifest_handle:
        manifest_writer = csv.DictWriter(manifest_handle, fieldnames=['filepath', 'label'])
        manifest_writer.writeheader()
        manifest_writer.writerows(all_manifest_rows)

    print(f"  -> Created manifest: {manifest_path}")
    print(f"  -> Total files in {split_name}: {len(all_manifest_rows)}")

end_time = time.time()
total_time = end_time - start_time

print("\nDataset creation complete!")
print(f"Total time: {total_time:.2f} seconds")

# Final verification and summary: recount the .wav files actually present
# on disk for every split directory that exists.
print("\nVerification:")
total_files_copied = 0
for split_name in ['train', 'val', 'test']:
    split_dir = base_dir / split_name
    if not split_dir.exists():
        continue

    fake_count = sum(1 for _ in (split_dir / 'fake').glob('*.wav'))
    real_count = sum(1 for _ in (split_dir / 'real').glob('*.wav'))
    print(f"{split_name}: {fake_count} fake, {real_count} real (balanced: {fake_count == real_count})")

    # Splits that were skipped as already complete were not copied during
    # this run, so they must not count towards this run's throughput.
    if split_name not in completed_splits:
        total_files_copied += fake_count + real_count

if log_file.exists():
    print(f"\n❗️ A log of failed file copies was created at: {log_file.resolve()}")

# Guard against a zero elapsed time before dividing.
if total_time > 0:
    print(f"\nPerformance: {total_files_copied / total_time:.1f} files/second")