# Optimized InLegalBERT Embedding Pipeline

This notebook implements an **efficient streaming approach** that:
- Processes files in batches instead of loading everything into memory
- Generates embeddings and saves them immediately
- Uses minimal memory footprint
- Supports resume functionality for interrupted processing

## Key Improvements:
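- **Memory Efficient**: Only the current batch of input files is held in memory while embeddings are generated (note that each save re-reads and rewrites the accumulated output file)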
- **Resumable**: Can continue from where it left off
- **Incremental Saving**: Saves results as they're generated
- **Progress Tracking**: Real-time progress monitoring

In [None]:
# Import required libraries
import pandas as pd
import os
import numpy as np
import json
import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
import warnings
import time
import gc
from pathlib import Path

# Keep notebook output readable by hiding library warnings.
warnings.filterwarnings('ignore')

# Report the runtime environment: torch build, CUDA visibility and the
# device that will be picked for inference.
has_cuda = torch.cuda.is_available()
for caption, value in (
    ("Torch version:", torch.__version__),
    ("CUDA available:", has_cuda),
    ("Device:", torch.device("cuda" if has_cuda else "cpu")),
):
    print(caption, value)

In [None]:
# Configuration: compute device, dataset locations, output location and
# batch sizes shared by the rest of the notebook.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Dataset paths (train/test live directly under the base directory,
# validation files sit one level deeper in val/val).
base_path = "/home/uttam/B.Tech Major Project/nyaya/server/dataset/Hier_BiLSTM_CRF"
train_path, test_path = (os.path.join(base_path, split) for split in ("train", "test"))
val_path = os.path.join(base_path, "val", "val")

# Output path for embeddings (created on first run)
embeddings_output_path = "/home/uttam/B.Tech Major Project/nyaya/server/embeddings"
os.makedirs(embeddings_output_path, exist_ok=True)

# Processing configuration
FILES_PER_BATCH = 50  # number of input files loaded per streaming step
EMBEDDING_BATCH_SIZE = 4 if device.type != 'cuda' else 8  # texts per forward pass
MAX_LENGTH = 512  # tokenizer truncation length

print("Configuration:")
for setting, current in (
    ("Files per batch", FILES_PER_BATCH),
    ("Embedding batch size", EMBEDDING_BATCH_SIZE),
    ("Max sequence length", MAX_LENGTH),
    ("Output directory", embeddings_output_path),
):
    print(f"  {setting}: {current}")

In [None]:
# Load the tokenizer and encoder a single time; every later cell reuses
# the module-level `tokenizer` and `model`.
print("Loading InLegalBERT model and tokenizer...")

try:
    tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")
    # Move the weights to the selected device and switch to inference mode.
    model = AutoModel.from_pretrained("law-ai/InLegalBERT").to(device)
    model.eval()

    print("✓ InLegalBERT loaded successfully!")
    print(f"✓ Hidden size: {model.config.hidden_size}")

except Exception as load_error:
    # Report the failure, then propagate it so later cells do not run
    # against a missing model.
    print(f"✗ Error loading InLegalBERT: {load_error}")
    raise

In [None]:
def get_bert_embeddings_batch(texts, tokenizer, model, device, max_length=512, batch_size=8):
    """Embed texts with a BERT-style encoder, one mini-batch at a time.

    Every text is represented by the hidden state at position 0 of the
    model's ``last_hidden_state`` output.

    Args:
        texts: Iterable of strings (list, tuple, NumPy array, pandas Series,
            ...). A single string is treated as one text. ``None`` or an
            empty iterable yields an empty result.
        tokenizer: Callable invoked as ``tokenizer(batch, padding=True,
            truncation=True, max_length=max_length, return_tensors='pt')``
            that returns a mapping of tensors.
        model: Callable invoked with that mapping as keyword arguments; its
            result must expose ``last_hidden_state``.
        device: Device the encoded tensors are moved to before the forward pass.
        max_length: Truncation length handed to the tokenizer.
        batch_size: Number of texts per forward pass; must be >= 1.

    Returns:
        A 2-D NumPy array with one row per text, or an empty 1-D array when
        there are no texts.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    # Materialise the input once so that truthiness and slicing behave the
    # same for lists, tuples, arrays and Series (``not array`` would raise).
    if texts is None:
        texts = []
    elif isinstance(texts, str):
        texts = [texts]
    else:
        texts = list(texts)

    if not texts:
        return np.array([])

    if batch_size < 1:
        # A negative step would silently produce no batches at all.
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    chunks = []
    # Gradients are never needed for embedding extraction.
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            encoded = tokenizer(
                texts[start:start + batch_size],
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors='pt'
            )
            encoded = {key: val.to(device) for key, val in encoded.items()}

            outputs = model(**encoded)
            # Keep only the first-token vector and move it off the device
            # immediately so device memory is released per mini-batch.
            chunks.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())

    return np.vstack(chunks)


def append_to_json_file(data, file_path):
    """Append records to a JSON array file, creating the file if needed.

    The whole file is read, extended and rewritten. The rewrite goes to a
    temporary sibling file that is then renamed over the target, so an
    interrupted run cannot leave a half-written output file behind.

    Args:
        data: Iterable of JSON-serialisable records to append.
        file_path: Path of the JSON file holding a top-level array.

    Returns:
        Total number of records in the file after the append.
    """
    existing_data = []
    if os.path.exists(file_path):
        unreadable = False
        with open(file_path, 'r', encoding='utf-8') as f:
            try:
                existing_data = json.load(f)
            except json.JSONDecodeError:
                unreadable = True

        if unreadable:
            # Starting over is kept as the fallback, but the unreadable file
            # is moved aside instead of being silently overwritten.
            backup_path = f"{file_path}.corrupt"
            os.replace(file_path, backup_path)
            print(f"⚠️ Could not parse {file_path}; moved it to {backup_path} and starting a new file")
            existing_data = []

    existing_data.extend(data)

    # Write to a temporary file first, then swap it into place.
    tmp_path = f"{file_path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(existing_data, f, indent=2)
    os.replace(tmp_path, file_path)

    return len(existing_data)


def get_progress_file(dataset_name):
    """Return the path of the JSON file that tracks progress for a dataset.

    The file is named ``<dataset_name>_progress.json`` and lives in the
    module-level ``embeddings_output_path`` directory.
    """
    progress_filename = f"{dataset_name}_progress.json"
    return os.path.join(embeddings_output_path, progress_filename)


def save_progress(dataset_name, processed_files):
    """Persist the names of the files already processed for a dataset.

    The state is written to a temporary sibling file and renamed over the
    progress file, so an interruption during the write cannot leave a
    truncated progress file that would fail to parse on resume.

    Args:
        dataset_name: Dataset identifier used to derive the progress file path.
        processed_files: Iterable of processed file names.
    """
    progress_file = get_progress_file(dataset_name)
    tmp_file = f"{progress_file}.tmp"
    with open(tmp_file, 'w') as f:
        # list() lets callers pass any iterable (e.g. a set), which json
        # cannot serialise directly.
        json.dump({"processed_files": list(processed_files)}, f)
    os.replace(tmp_file, progress_file)


def load_progress(dataset_name):
    """Return the list of file names recorded as processed for a dataset.

    An empty list is returned when no progress file exists or when the
    file has no ``"processed_files"`` entry.
    """
    progress_path = get_progress_file(dataset_name)
    if not os.path.exists(progress_path):
        return []

    with open(progress_path, 'r') as handle:
        saved_state = json.load(handle)
    return saved_state.get("processed_files", [])

print("✓ Helper functions defined")

In [None]:
def process_dataset_streaming(dataset_path, output_filename, dataset_name, has_labels=True,
                              label_mapping=None, files_per_batch=50):
    """
    Process dataset in streaming fashion: load batch -> generate embeddings -> save -> repeat

    Relies on the module-level ``tokenizer``, ``model``, ``device``,
    ``MAX_LENGTH``, ``EMBEDDING_BATCH_SIZE`` and ``embeddings_output_path``.

    Args:
        dataset_path: Path to dataset directory containing ``.txt`` files
        output_filename: Output JSON filename (inside ``embeddings_output_path``)
        dataset_name: Name for progress tracking
        has_labels: Whether files are tab-separated ``text<TAB>label`` rows
        label_mapping: Dictionary mapping label names to numbers
        files_per_batch: Number of files to process in each batch

    Returns:
        Number of files recorded as processed (including files from earlier,
        resumed runs); 0 when the directory is missing or has no ``.txt`` files.
    """
    if not os.path.exists(dataset_path):
        print(f"⚠️ Dataset path not found: {dataset_path}")
        return 0

    # os.listdir order is unspecified; sorting makes batch composition and
    # output order reproducible between runs.
    all_files = sorted(f for f in os.listdir(dataset_path) if f.endswith('.txt'))
    print(f"📁 Found {len(all_files)} files in {dataset_name}")

    if not all_files:
        print(f"⚠️ No .txt files found in {dataset_path}")
        return 0

    # Load progress and filter already processed files (set lookup keeps the
    # filter linear in the number of files).
    processed_files = load_progress(dataset_name)
    already_done = set(processed_files)
    remaining_files = [f for f in all_files if f not in already_done]

    print(f"📋 Progress: {len(processed_files)} already processed, {len(remaining_files)} remaining")

    if not remaining_files:
        print(f"✅ All files already processed for {dataset_name}")
        return len(processed_files)

    output_file_path = os.path.join(embeddings_output_path, output_filename)
    total_processed = len(processed_files)

    # Process files in batches
    for batch_start in tqdm(range(0, len(remaining_files), files_per_batch),
                            desc=f"Processing {dataset_name} batches"):

        batch_files = remaining_files[batch_start:batch_start + files_per_batch]
        print(f"\n🔄 Processing batch {batch_start//files_per_batch + 1}: {len(batch_files)} files")

        # Load current batch
        batch_texts = []
        batch_labels = []
        batch_label_numbers = []

        for filename in batch_files:
            file_path = os.path.join(dataset_path, filename)
            try:
                if has_labels:
                    df = pd.read_csv(file_path, sep="\t", header=None, names=["text", "label"])
                    if df.empty:
                        continue

                    # Missing labels and case variants of "none" collapse to "None".
                    labels = (
                        df["label"].fillna("None").astype(str).str.strip()
                        .replace({"none": "None", "NONE": "None"})
                        .tolist()
                    )
                    # The tokenizer only accepts strings; empty cells are read
                    # as NaN and numeric-looking cells as numbers, so coerce.
                    texts = df["text"].fillna("").astype(str).tolist()

                    # Map labels to numbers
                    if label_mapping:
                        label_nums = [label_mapping.get(label, 6) for label in labels]  # 6 for unknown
                    else:
                        label_nums = [0] * len(labels)  # Default

                    # Extend the three lists together so they stay aligned.
                    batch_texts.extend(texts)
                    batch_labels.extend(labels)
                    batch_label_numbers.extend(label_nums)
                else:
                    # NOTE(review): if these files actually carry a second
                    # tab-separated column, pandas may treat the first column
                    # as the index — verify against the data.
                    df = pd.read_csv(file_path, sep="\t", header=None, names=["text"])
                    if not df.empty:
                        batch_texts.extend(df["text"].astype(str).str.strip().tolist())

            except Exception as e:
                # Best effort: a file that cannot be read is reported and skipped.
                print(f"⚠️ Error loading {filename}: {e}")
                continue

        if not batch_texts:
            print(f"⚠️ No valid texts in current batch")
            continue

        print(f"   Loaded {len(batch_texts)} text samples")

        # Generate embeddings for current batch
        print(f"   🧠 Generating embeddings...")
        batch_embeddings = get_bert_embeddings_batch(
            batch_texts, tokenizer, model, device, MAX_LENGTH, EMBEDDING_BATCH_SIZE
        )

        if batch_embeddings.size == 0:
            print(f"⚠️ Failed to generate embeddings for batch")
            continue

        print(f"   ✓ Generated embeddings shape: {batch_embeddings.shape}")

        # Create JSON data for current batch
        labelled = bool(has_labels and batch_labels and batch_label_numbers)
        batch_json_data = []
        for idx, (text, vector) in enumerate(zip(batch_texts, batch_embeddings)):
            data_point = {
                "text": text,
                "vector": vector.tolist(),
                "classname": batch_labels[idx] if labelled else None,
                "classnumber": int(batch_label_numbers[idx]) if labelled else None,
            }
            batch_json_data.append(data_point)

        # Save current batch results
        print(f"   💾 Saving {len(batch_json_data)} samples...")
        total_samples = append_to_json_file(batch_json_data, output_file_path)

        # Update progress only after the batch is safely written
        processed_files.extend(batch_files)
        save_progress(dataset_name, processed_files)
        total_processed += len(batch_files)

        print(f"   ✅ Batch saved. Total samples in file: {total_samples}")

        # Free memory before the next batch is loaded
        del batch_texts, batch_embeddings, batch_json_data
        del batch_labels, batch_label_numbers
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"\n✅ {dataset_name} processing completed!")
    print(f"   Total files processed: {total_processed}")

    # Clean up progress file
    # NOTE(review): once the progress file is gone, a later re-run finds no
    # recorded progress and appends every file again to an existing output
    # file — confirm whether the output should be cleared before re-running.
    progress_file = get_progress_file(dataset_name)
    if os.path.exists(progress_file):
        os.remove(progress_file)

    return total_processed

print("✓ Streaming processing function defined")

In [None]:
# Rhetorical-role labels and their class numbers; 'None' (6) also serves as
# the fallback class when the processing functions meet an unknown label.
label_to_num = {
    name: number
    for number, name in enumerate([
        'Facts',
        'Reasoning',
        'Arguments of Respondent',
        'Arguments of Petitioner',
        'Decision',
        'Issue',
        'None',
    ])
}

print("Label mapping:")
for label in label_to_num:
    print(f"  {label_to_num[label]}: {label}")

# Persist the inverse mapping (number -> label) next to the embeddings.
label_file_path = os.path.join(embeddings_output_path, 'label_mapping_inlegalbert.json')
num_to_label = {number: name for name, number in label_to_num.items()}
with open(label_file_path, 'w') as f:
    json.dump(num_to_label, f, indent=2)
print(f"\n✓ Label mapping saved to: {label_file_path}")

In [None]:
# Process all datasets using streaming approach with NPZ format
# NOTE: process_dataset_streaming_npz is defined in a later cell of this
# notebook; that cell must have been executed before this one.
print("="*80)
print("PROCESSING ALL DATASETS WITH STREAMING APPROACH (NPZ FORMAT)")
print("="*80)

start_time = time.time()

# Process training data (with labels)
print("\n🎯 Processing Training Data...")
train_count = process_dataset_streaming_npz(
    dataset_path=train_path,
    output_filename='train_embeddings_inlegalbert.npz',
    dataset_name='train',
    has_labels=True,
    label_mapping=label_to_num,
    files_per_batch=FILES_PER_BATCH
)

# Process test data (no labels)
print("\n🎯 Processing Test Data...")
test_count = process_dataset_streaming_npz(
    dataset_path=test_path,
    output_filename='test_embeddings_inlegalbert.npz',
    dataset_name='test',
    has_labels=False,
    files_per_batch=FILES_PER_BATCH
)

# Process validation data (no labels)
print("\n🎯 Processing Validation Data...")
val_count = process_dataset_streaming_npz(
    dataset_path=val_path,
    output_filename='val_embeddings_inlegalbert.npz',
    dataset_name='val',
    has_labels=False,
    files_per_batch=FILES_PER_BATCH
)

total_time = time.time() - start_time
total_files = train_count + test_count + val_count

print("\n" + "="*80)
print("PROCESSING COMPLETED!")
print("="*80)
print(f"✅ Training files processed: {train_count}")
print(f"✅ Test files processed: {test_count}")
print(f"✅ Validation files processed: {val_count}")
print(f"✅ Total files processed: {total_files}")
print(f"✅ Total processing time: {total_time:.2f} seconds ({total_time/60:.1f} minutes)")
# Every count is 0 when the dataset directories are missing or empty, so
# guard the average against division by zero.
if total_files:
    print(f"✅ Average time per file: {total_time/total_files:.3f} seconds")
else:
    print("⚠️ No files were processed; average time per file is not available")
print(f"✅ Output directory: {embeddings_output_path}")

In [None]:
# Display final file information
print("📂 Generated Files:")
output_files = [
    'train_embeddings_inlegalbert.npz',
    'test_embeddings_inlegalbert.npz',
    'val_embeddings_inlegalbert.npz',
    'label_mapping_inlegalbert.json'
]

total_size_mb = 0
for filename in output_files:
    filepath = os.path.join(embeddings_output_path, filename)
    if not os.path.exists(filepath):
        print(f"  ✗ {filename}: Not found")
        continue

    size_mb = os.path.getsize(filepath) / (1024*1024)
    total_size_mb += size_mb
    print(f"  ✓ {filename}: {size_mb:.1f} MB")

    # Show sample info for NPZ files
    if filename.endswith('.npz'):
        # allow_pickle=True is required because the metadata entry is a
        # pickled dict; only load files produced by this notebook.
        with np.load(filepath, allow_pickle=True) as data:
            # Only the entries that are reported are read; the text array is
            # left untouched so it is not decompressed for nothing.
            vectors = data['vectors']
            metadata = data['metadata'].item()
        print(f"    📊 Samples: {len(vectors)}, Embedding dim: {vectors.shape[1]}")
        print(f"    📅 Created: {metadata.get('created_date', 'Unknown')}")

print(f"\n📊 Total output size: {total_size_mb:.1f} MB")
print(f"💾 NPZ format saves ~90% space compared to JSON!")

# Memory cleanup
print("\n🧹 Final cleanup...")
# pop() instead of `del` so that re-running this cell after the names are
# already gone does not raise NameError.
for _released_name in ("model", "tokenizer"):
    globals().pop(_released_name, None)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("✅ Memory cleanup completed")

In [None]:
# 📖 How to Load and Use NPZ Files
banner = "=" * 60
print(banner)
print("HOW TO LOAD AND USE THE GENERATED NPZ FILES")
print(banner)

def load_embeddings_from_npz(npz_file_path):
    """Read one embeddings archive into a plain dictionary.

    Returns ``None`` (after printing a message) when the file does not
    exist. Otherwise the dictionary holds ``'vectors'``, ``'texts'`` and
    ``'metadata'`` (unwrapped with ``.item()``), plus ``'labels'`` and
    ``'label_numbers'`` when the archive contains a ``'labels'`` entry.
    """
    if not os.path.exists(npz_file_path):
        print(f"File not found: {npz_file_path}")
        return None

    # allow_pickle=True is needed to unwrap the metadata object entry.
    with np.load(npz_file_path, allow_pickle=True) as archive:
        loaded = {}
        loaded['vectors'] = archive['vectors']
        loaded['texts'] = archive['texts']
        loaded['metadata'] = archive['metadata'].item()

        # Label arrays are optional and always read as a pair.
        if 'labels' in archive:
            for key in ('labels', 'label_numbers'):
                loaded[key] = archive[key]

    return loaded

# Example usage: load the training archive (when present) and summarise it.
print("🔍 Example: Loading training embeddings...")
train_npz_path = os.path.join(embeddings_output_path, 'train_embeddings_inlegalbert.npz')

if not os.path.exists(train_npz_path):
    print("⚠️ Training NPZ file not found. Run the processing cells first.")
else:
    # Load the data
    train_data = load_embeddings_from_npz(train_npz_path)
    has_label_arrays = 'labels' in train_data

    print("✓ Loaded training data:")
    print(f"  📊 Vectors shape: {train_data['vectors'].shape}")
    print(f"  📊 Number of texts: {len(train_data['texts'])}")
    print(f"  📊 Has labels: {has_label_arrays}")
    print(f"  📊 Embedding dimension: {train_data['vectors'].shape[1]}")

    if has_label_arrays:
        # One np.unique call provides both the distinct labels and their counts.
        unique_labels, counts = np.unique(train_data['labels'], return_counts=True)
        print(f"  📊 Unique labels: {unique_labels}")
        print("  📊 Label distribution:")
        for label, count in zip(unique_labels, counts):
            print(f"    {label}: {count} samples")

    # Show the first sample
    print("\n📝 Sample data:")
    print(f"  Text: {train_data['texts'][0][:100]}...")
    print(f"  Vector (first 5 values): {train_data['vectors'][0][:5]}")
    if has_label_arrays:
        print(f"  Label: {train_data['labels'][0]}")
        print(f"  Label number: {train_data['label_numbers'][0]}")

    print("\n💡 Usage examples:")
    for hint in (
        "  # Get all embeddings: embeddings = train_data['vectors']",
        "  # Get all texts: texts = train_data['texts']",
        "  # Get all labels: labels = train_data['labels']",
        "  # Convert to PyTorch tensor: torch.tensor(train_data['vectors'])",
        "  # Use for sklearn: from sklearn.model_selection import train_test_split",
    ):
        print(hint)

# List where each archive lives, with its size on disk when it exists.
print("\n📁 File locations:")
for dataset in ['train', 'test', 'val']:
    npz_path = os.path.join(embeddings_output_path, f'{dataset}_embeddings_inlegalbert.npz')
    if not os.path.exists(npz_path):
        print(f"  ✗ {dataset}: Not generated yet")
        continue
    size_mb = os.path.getsize(npz_path) / (1024*1024)
    print(f"  ✓ {dataset}: {npz_path} ({size_mb:.1f} MB)")

print("\n" + "=" * 60)

# 🔍 Understanding the Numbers in InLegalBERT Processing

This section clarifies the different numerical parameters used in the embedding pipeline.

## Key Numbers Explained:

### 1. **512 = Max Sequence Length (Token Limit)**
- **What it is**: Maximum number of subword tokens (including the special `[CLS]` and `[SEP]` tokens) that InLegalBERT can process in a single input
- **Where used**: In tokenization step
- **Why 512**: BERT models have a fixed context window limit
- **Effect**: Longer texts get truncated to 512 tokens

### 2. **768 = Embedding Dimension (Vector Size)**  
- **What it is**: Size of the output embedding vector for each text
- **Where used**: Output of InLegalBERT model
- **Why 768**: Built into InLegalBERT architecture (BERT-base standard)
- **Effect**: Each text becomes a single 768-dimensional vector (this notebook uses the last-layer hidden state of the first, `[CLS]`, token)

### 3. **50 = Files Per Batch**
- **What it is**: Number of files processed together in each batch
- **Where used**: Streaming processing to manage memory
- **Why 50**: Balance between memory usage and efficiency
- **Effect**: Processes 50 files → generates embeddings → saves → repeats

### 4. **8/4 = Embedding Batch Size**
- **What it is**: Number of texts processed simultaneously for embedding generation
- **Where used**: Inside embedding generation function
- **Why 8/4**: GPU memory management (8 for CUDA, 4 for CPU)
- **Effect**: Processes multiple texts in parallel for speed

In [None]:
# 🔬 Demo cell: tokenizes a short legal passage to contrast the 512-token
# input limit with the 768-dimensional output size.
print("=" * 60)
print("DEMONSTRATION: 512 TOKEN LIMIT vs 768 EMBEDDING SIZE")
print("=" * 60)

# Sample legal text to demonstrate tokenization
sample_legal_text = """
The Hon'ble Supreme Court of India in the landmark case of Kesavananda Bharati vs State of Kerala 
held that the basic structure of the Constitution cannot be altered by Parliament through constitutional 
amendments. This doctrine ensures that fundamental principles like democracy, secularism, federalism, 
and judicial review remain intact. The Court emphasized that while Parliament has wide powers to amend 
the Constitution under Article 368, these powers are not unlimited and cannot be used to destroy the 
very essence of the constitutional framework that gives it legitimacy.
"""

# The live demonstration needs the tokenizer from the model-loading cell;
# without it only a textual explanation is printed.
if 'tokenizer' not in globals():
    print("⚠️ Tokenizer not available (run previous cells first)")
    print("\n📝 Explanation without demonstration:")
    print("   • MAX_LENGTH = 512 means text longer than 512 tokens gets truncated")
    print("   • Embedding dimension = 768 means output vector has 768 numbers")
    print("   • Example: 'This is a very long legal document...' → [0.1, -0.2, 0.3, ... 768 numbers]")
else:
    print("🔤 Tokenizing sample legal text...")

    # `tokens` is the raw token list; `encoded` is the model-ready,
    # truncated encoding of the same text.
    tokens = tokenizer.tokenize(sample_legal_text)
    encoded = tokenizer(
        sample_legal_text,
        padding=False,
        truncation=True,
        max_length=512,  # This is our MAX_LENGTH
        return_tensors='pt'
    )

    print(f"📊 Original text length: {len(sample_legal_text)} characters")
    print(f"📊 Number of tokens: {len(tokens)}")
    print(f"📊 Encoded input_ids shape: {encoded['input_ids'].shape}")
    print("📊 Max allowed tokens: 512")

    # Show the two ends of the token sequence
    print(f"\n🔤 First 10 tokens: {tokens[:10]}")
    print(f"🔤 Last 10 tokens: {tokens[-10:]}")

    # The shape summary is only printed while the model is still loaded
    if 'model' in globals():
        print("\n🧠 If we generate embeddings:")
        print(f"   Input shape: (1, {encoded['input_ids'].shape[1]}) tokens")
        print("   Output shape: (1, 768) embedding vector")
        print("   ↳ 512 is INPUT limit, 768 is OUTPUT size")

    print("\n💡 Key Insight:")
    print("   • 512 = Maximum INPUT tokens (text length limit)")
    print("   • 768 = OUTPUT embedding dimension (vector size)")
    print("   • These are completely different concepts!")

print("\n" + "=" * 60)

# 💾 NPZ Format: Space-Efficient Storage

**Problem**: JSON format is very space-inefficient for large numerical arrays.

**Solution**: Use NumPy's NPZ format for compressed binary storage.

## NPZ vs JSON Comparison:

| Format | File Size | Load Speed | Memory Usage | Human Readable |
|--------|-----------|------------|--------------|----------------|
| **JSON** | ~10x larger | Slow | High | ✅ Yes |
| **NPZ** | Compact | Fast | Low | ❌ No |

## NPZ File Structure:
Each NPZ file will contain:
- `vectors`: 2D array of embeddings (samples × 768)
- `texts`: Array of text strings  
- `labels`: Array of class names (for training data)
- `label_numbers`: Array of class numbers (for training data)
- `metadata`: Dictionary with dataset info

In [None]:
# 🔧 Updated Functions for NPZ Format
def append_to_npz_file(batch_vectors, batch_texts, batch_labels, batch_label_numbers, file_path, dataset_name):
    """Append one batch to a compressed NPZ archive, creating it if needed.

    The archive is loaded, extended in memory and rewritten as a whole. The
    rewrite goes to a temporary sibling file that is renamed over the
    target, so an interrupted save cannot corrupt data written earlier.

    Args:
        batch_vectors: 2-D array-like of embeddings, one row per text.
        batch_texts: Sequence of texts matching ``batch_vectors`` row-wise.
        batch_labels: Sequence of class names; empty/``None`` for unlabeled data.
        batch_label_numbers: Sequence of class numbers; empty/``None`` for
            unlabeled data.
        file_path: Archive path. As with ``np.savez_compressed``, ``.npz`` is
            appended when the path does not already end with it.
        dataset_name: Stored in the archive's metadata.

    Returns:
        Total number of samples in the archive after the append.
    """
    # np.savez_compressed adds ".npz" to paths lacking it; resolve the real
    # on-disk name once so the existence check and the write agree
    # (otherwise every call would overwrite instead of append).
    target_path = os.fspath(file_path)
    if not target_path.endswith('.npz'):
        target_path = f"{target_path}.npz"

    batch_vectors = np.asarray(batch_vectors)
    # Explicit length checks: truthiness is ambiguous for NumPy arrays.
    batch_labels = [] if batch_labels is None else batch_labels
    batch_label_numbers = [] if batch_label_numbers is None else batch_label_numbers
    batch_has_labels = len(batch_labels) > 0

    if os.path.exists(target_path):
        # allow_pickle=True permits object-dtype entries; only use this on
        # archives produced by this pipeline.
        with np.load(target_path, allow_pickle=True) as existing:
            stored_keys = set(existing.files)
            existing_vectors = existing['vectors']
            existing_texts = existing['texts']
            # Membership test instead of NpzFile.get(), which older NumPy
            # releases do not provide.
            existing_labels = existing['labels'] if 'labels' in stored_keys else np.array([])
            existing_label_numbers = (
                existing['label_numbers'] if 'label_numbers' in stored_keys else np.array([])
            )

        # Concatenate new data
        all_vectors = np.vstack([existing_vectors, batch_vectors])
        all_texts = np.concatenate([existing_texts, batch_texts])

        if len(existing_labels) > 0 and batch_has_labels:
            all_labels = np.concatenate([existing_labels, batch_labels])
            all_label_numbers = np.concatenate([existing_label_numbers, batch_label_numbers])
        elif batch_has_labels:
            all_labels = np.array(batch_labels)
            all_label_numbers = np.array(batch_label_numbers)
        else:
            all_labels = existing_labels
            all_label_numbers = existing_label_numbers
    else:
        # First time - create new arrays
        all_vectors = batch_vectors
        all_texts = np.array(batch_texts)
        all_labels = np.array(batch_labels) if batch_has_labels else np.array([])
        all_label_numbers = (
            np.array(batch_label_numbers) if len(batch_label_numbers) > 0 else np.array([])
        )

    # Create metadata (stored as a pickled object entry in the archive)
    metadata = {
        'dataset_name': dataset_name,
        'embedding_model': 'law-ai/InLegalBERT',
        'embedding_dim': all_vectors.shape[1],
        'total_samples': len(all_vectors),
        'has_labels': len(all_labels) > 0,
        'created_date': time.strftime('%Y-%m-%d %H:%M:%S')
    }

    # Label arrays are only stored when there are labels at all.
    arrays = {'vectors': all_vectors, 'texts': all_texts}
    if len(all_labels) > 0:
        arrays['labels'] = all_labels
        arrays['label_numbers'] = all_label_numbers
    arrays['metadata'] = metadata

    # Write to a temporary file, then atomically swap it into place.
    tmp_path = f"{target_path}.tmp"
    try:
        with open(tmp_path, 'wb') as handle:
            np.savez_compressed(handle, **arrays)
        os.replace(tmp_path, target_path)
    except BaseException:
        # Do not leave a partial temporary file behind on failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    return len(all_vectors)


def save_labels_to_json(label_mapping, file_path):
    """Write ``label_mapping`` to ``file_path`` as 2-space-indented JSON."""
    with open(file_path, 'w') as out:
        out.write(json.dumps(label_mapping, indent=2))


for notice in (
    "✓ NPZ helper functions defined",
    "✓ NPZ format will save ~90% storage space compared to JSON",
):
    print(notice)

In [None]:
# 🔄 Updated Streaming Processing Function for NPZ Format
def process_dataset_streaming_npz(dataset_path, output_filename, dataset_name, has_labels=True,
                                  label_mapping=None, files_per_batch=50):
    """
    Process dataset in streaming fashion and save to NPZ format

    Relies on the module-level ``tokenizer``, ``model``, ``device``,
    ``MAX_LENGTH``, ``EMBEDDING_BATCH_SIZE`` and ``embeddings_output_path``.

    Args:
        dataset_path: Path to dataset directory containing ``.txt`` files
        output_filename: Output NPZ filename (inside ``embeddings_output_path``)
        dataset_name: Name for progress tracking
        has_labels: Whether files are tab-separated ``text<TAB>label`` rows
        label_mapping: Dictionary mapping label names to numbers
        files_per_batch: Number of files to process in each batch

    Returns:
        Number of files recorded as processed (including files from earlier,
        resumed runs); 0 when the directory is missing or has no ``.txt`` files.
    """
    if not os.path.exists(dataset_path):
        print(f"⚠️ Dataset path not found: {dataset_path}")
        return 0

    # os.listdir order is unspecified; sorting makes batch composition and
    # output order reproducible between runs.
    all_files = sorted(f for f in os.listdir(dataset_path) if f.endswith('.txt'))
    print(f"📁 Found {len(all_files)} files in {dataset_name}")

    if not all_files:
        print(f"⚠️ No .txt files found in {dataset_path}")
        return 0

    # Load progress and filter already processed files (set lookup keeps the
    # filter linear in the number of files).
    processed_files = load_progress(dataset_name)
    already_done = set(processed_files)
    remaining_files = [f for f in all_files if f not in already_done]

    print(f"📋 Progress: {len(processed_files)} already processed, {len(remaining_files)} remaining")

    if not remaining_files:
        print(f"✅ All files already processed for {dataset_name}")
        return len(processed_files)

    output_file_path = os.path.join(embeddings_output_path, output_filename)
    total_processed = len(processed_files)

    # Process files in batches
    for batch_start in tqdm(range(0, len(remaining_files), files_per_batch),
                            desc=f"Processing {dataset_name} batches"):

        batch_files = remaining_files[batch_start:batch_start + files_per_batch]
        # A real newline here: the escaped "\\n" printed a literal backslash-n.
        print(f"\n🔄 Processing batch {batch_start//files_per_batch + 1}: {len(batch_files)} files")

        # Load current batch
        batch_texts = []
        batch_labels = []
        batch_label_numbers = []

        for filename in batch_files:
            file_path = os.path.join(dataset_path, filename)
            try:
                if has_labels:
                    # A real tab character: the escaped "\\t" was a
                    # two-character regex separator, which forces pandas
                    # onto its slower Python parsing engine.
                    df = pd.read_csv(file_path, sep="\t", header=None, names=["text", "label"])
                    if df.empty:
                        continue

                    # Missing labels and case variants of "none" collapse to "None".
                    labels = (
                        df["label"].fillna("None").astype(str).str.strip()
                        .replace({"none": "None", "NONE": "None"})
                        .tolist()
                    )
                    # The tokenizer only accepts strings; empty cells are read
                    # as NaN and numeric-looking cells as numbers, so coerce.
                    texts = df["text"].fillna("").astype(str).tolist()

                    # Map labels to numbers
                    if label_mapping:
                        label_nums = [label_mapping.get(label, 6) for label in labels]  # 6 for unknown
                    else:
                        label_nums = [0] * len(labels)  # Default

                    # Extend the three lists together so they stay aligned.
                    batch_texts.extend(texts)
                    batch_labels.extend(labels)
                    batch_label_numbers.extend(label_nums)
                else:
                    # NOTE(review): if these files actually carry a second
                    # tab-separated column, pandas may treat the first column
                    # as the index — verify against the data.
                    df = pd.read_csv(file_path, sep="\t", header=None, names=["text"])
                    if not df.empty:
                        batch_texts.extend(df["text"].astype(str).str.strip().tolist())

            except Exception as e:
                # Best effort: a file that cannot be read is reported and skipped.
                print(f"⚠️ Error loading {filename}: {e}")
                continue

        if not batch_texts:
            print(f"⚠️ No valid texts in current batch")
            continue

        print(f"   Loaded {len(batch_texts)} text samples")

        # Generate embeddings for current batch
        print(f"   🧠 Generating embeddings...")
        batch_embeddings = get_bert_embeddings_batch(
            batch_texts, tokenizer, model, device, MAX_LENGTH, EMBEDDING_BATCH_SIZE
        )

        if batch_embeddings.size == 0:
            print(f"⚠️ Failed to generate embeddings for batch")
            continue

        print(f"   ✓ Generated embeddings shape: {batch_embeddings.shape}")

        # Save current batch results to NPZ
        print(f"   💾 Saving {len(batch_texts)} samples to NPZ...")
        total_samples = append_to_npz_file(
            batch_embeddings,
            batch_texts,
            batch_labels if has_labels else [],
            batch_label_numbers if has_labels else [],
            output_file_path,
            dataset_name
        )

        # Update progress only after the batch is safely written
        processed_files.extend(batch_files)
        save_progress(dataset_name, processed_files)
        total_processed += len(batch_files)

        print(f"   ✅ Batch saved to NPZ. Total samples in file: {total_samples}")

        # Show current archive size
        if os.path.exists(output_file_path):
            npz_size_mb = os.path.getsize(output_file_path) / (1024*1024)
            print(f"   📊 Current NPZ file size: {npz_size_mb:.1f} MB")

        # Free memory before the next batch is loaded
        del batch_texts, batch_embeddings
        del batch_labels, batch_label_numbers
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"\n✅ {dataset_name} processing completed!")
    print(f"   Total files processed: {total_processed}")

    # Clean up progress file
    # NOTE(review): once the progress file is gone, a later re-run finds no
    # recorded progress and appends every file again to an existing archive
    # — confirm whether the output should be cleared before re-running.
    progress_file = get_progress_file(dataset_name)
    if os.path.exists(progress_file):
        os.remove(progress_file)

    return total_processed

print("✓ NPZ streaming processing function defined")
print("✓ This will save embeddings in compressed NumPy format")