# InLegalBERT Embedding Generation for Legal Text Classification

This notebook generates embeddings for legal text classification using **InLegalBERT** - a BERT model specifically pre-trained on Indian legal documents.

## Dataset Structure
- **Train**: Files with text and labels (Facts, Reasoning, Arguments of Respondent, Arguments of Petitioner, Decision, Issue)
- **Test**: Files with only text (no labels)
- **Val**: Files with only text (no labels)

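## Output
- One JSON file per split (`train_embeddings_inlegalbert.json`, `test_embeddings_inlegalbert.json`, `val_embeddings_inlegalbert.json`). Each record holds `text`, `vector` (the embedding), `classname` and `classnumber`; the last two are `null` for the unlabelled test/val splits.
- `label_mapping_inlegalbert.json`, mapping class numbers to class names
- All files are written to the embeddings output folder configured in the paths cell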

In [1]:
# Import required libraries
import csv
import gc
import json
import os
import time
import warnings

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
# Hide all Python warnings to keep cell output compact.
# Note: this also hides potentially useful warnings; remove it while debugging.
warnings.filterwarnings('ignore')

# Report the runtime environment
cuda_available = torch.cuda.is_available()
for caption, value in (
    ("Torch version:", torch.__version__),
    ("CUDA available:", cuda_available),
    ("Device:", torch.device("cuda" if cuda_available else "cpu")),
):
    print(caption, value)

  from .autonotebook import tqdm as notebook_tqdm


Torch version: 2.8.0+cu128
CUDA available: False
Device: cpu


In [2]:
# Set device for GPU acceleration if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dataset and output locations. The defaults are the original author's local checkout;
# set NYAYA_DATASET_DIR / NYAYA_EMBEDDINGS_DIR to run elsewhere without editing code.
base_path = os.environ.get(
    "NYAYA_DATASET_DIR",
    "/home/uttam/B.Tech Major Project/nyaya/server/dataset/Hier_BiLSTM_CRF",
)
train_path = os.path.join(base_path, "train")
test_path = os.path.join(base_path, "test")
val_path = os.path.join(base_path, "val", "val")  # validation files sit in a nested val/val folder

# Output path for embeddings
embeddings_output_path = os.environ.get(
    "NYAYA_EMBEDDINGS_DIR",
    "/home/uttam/B.Tech Major Project/nyaya/server/embeddings",
)

print(f"Train path: {train_path}")
print(f"Test path: {test_path}")
print(f"Val path: {val_path}")
print(f"Output path: {embeddings_output_path}")

# Verify the split folders exist. isdir (not exists) so a stray file at one of these
# paths is reported instead of crashing os.listdir.
for path_name, path in [("Train", train_path), ("Test", test_path), ("Val", val_path)]:
    if os.path.isdir(path):
        file_count = sum(1 for f in os.listdir(path) if f.endswith('.txt'))
        print(f"✓ {path_name} directory exists with {file_count} files")
    else:
        print(f"✗ {path_name} directory not found: {path}")

Using device: cpu
Train path: /home/uttam/B.Tech Major Project/nyaya/server/dataset/Hier_BiLSTM_CRF/train
Test path: /home/uttam/B.Tech Major Project/nyaya/server/dataset/Hier_BiLSTM_CRF/test
Val path: /home/uttam/B.Tech Major Project/nyaya/server/dataset/Hier_BiLSTM_CRF/val/val
Output path: /home/uttam/B.Tech Major Project/nyaya/server/embeddings
✓ Train directory exists with 4994 files
✓ Test directory exists with 712 files
✓ Val directory exists with 1424 files


In [None]:
# 📊 Dataset Size Configuration - Adjust these to control dataset size
print("="*60)
print("DATASET SIZE CONFIGURATION")
print("="*60)

# Per-split file limits: an int keeps at most that many files, None processes every file.
TRAIN_FILE_LIMIT = 100   # Process only first 100 training files (was 4994)
TEST_FILE_LIMIT = 50     # Process only first 50 test files (was 712)
VAL_FILE_LIMIT = 50      # Process only first 50 validation files (was 1424)

# Count the files actually present so the summary is correct for any limit, including
# None ("all files"), instead of treating None as 0 files and dividing by a hard-coded total.
split_limits = [
    ("Train", train_path, TRAIN_FILE_LIMIT),
    ("Test", test_path, TEST_FILE_LIMIT),
    ("Val", val_path, VAL_FILE_LIMIT),
]
total_files = 0
expected_files = 0

print("📋 Dataset Limits Configuration:")
for split_name, split_path, limit in split_limits:
    available = (
        sum(1 for f in os.listdir(split_path) if f.endswith('.txt'))
        if os.path.isdir(split_path)
        else 0
    )
    total_files += available
    expected_files += available if limit is None else min(limit, available)
    # `is not None` so that a limit of 0 is shown as 0, not as "No limit"
    print(f"   {split_name} files limit: {limit if limit is not None else 'No limit (all files)'}")
print(f"   Expected total files to process: ~{expected_files}")

print("\n💡 To process full dataset, set all limits to None")
if total_files:
    print(f"💡 Current settings will process ~{100 * expected_files / total_files:.1f}% of total dataset")
else:
    print("💡 No dataset files found - check the dataset paths")

In [None]:
def _list_txt_files(directory_path, file_limit=None):
    """Return the ``.txt`` file names in ``directory_path``, sorted and optionally truncated.

    Sorting makes the ``file_limit`` subset reproducible: ``os.listdir`` returns entries
    in arbitrary, filesystem-dependent order, so "the first N files" would otherwise
    differ between machines and runs.
    """
    files = sorted(f for f in os.listdir(directory_path) if f.endswith('.txt'))
    print(f"Found {len(files)} total files")

    # Apply file limit if specified
    if file_limit is not None:
        files = files[:file_limit]
        print(f"🔄 Processing only first {len(files)} files (limit: {file_limit})")
    else:
        print(f"🔄 Processing all {len(files)} files")
    return files


def _read_tsv(file_path, names):
    """Read one headerless, tab-separated dataset file into a DataFrame with ``names`` columns."""
    return pd.read_csv(
        file_path,
        sep="\t",
        header=None,
        names=names,
        # Never promote the first field (the sentence) to the index, even if a row
        # carries more fields than ``names``.
        index_col=False,
        # Sentences contain literal double quotes; with default CSV quoting a field that
        # starts with '"' can swallow following tabs/newlines.
        quoting=csv.QUOTE_NONE,
        # Only an empty field counts as missing. Sentences such as "NA" or "None" must
        # stay literal strings instead of becoming NaN (and later the string "nan").
        keep_default_na=False,
        na_values=[""],
    )


def load_train_files(directory_path, file_limit=None):
    """Load labelled training files (one ``text<TAB>label`` pair per line).

    Args:
        directory_path: Folder containing the ``.txt`` training files.
        file_limit: Maximum number of files to read (in sorted name order); None reads all.

    Returns:
        DataFrame with ``text`` and ``label`` columns. Text and labels are whitespace-
        stripped, missing labels become the string "None", and rows whose text field is
        empty are dropped. Files that fail to parse are reported and skipped.
    """
    print(f"Loading training files from: {directory_path}")
    files = _list_txt_files(directory_path, file_limit)

    all_dfs = []
    for file_name in tqdm(files, desc="Loading train files"):
        try:
            df = _read_tsv(os.path.join(directory_path, file_name), ["text", "label"])
        except Exception as e:
            # Best-effort: report and skip a malformed file rather than abort the whole load.
            print(f"Error loading {file_name}: {e}")
            continue

        df = df.dropna(subset=["text"])
        if df.empty:
            continue
        # 🔧 Replace missing labels with "None" and normalize label values
        labels = df["label"].fillna("None").astype(str).str.strip()
        all_dfs.append(df.assign(
            text=df["text"].astype(str).str.strip(),
            label=labels.replace({"none": "None", "NONE": "None"}),  # unify casing
        ))

    if not all_dfs:
        print("⚠️ No valid training data found")
        return pd.DataFrame(columns=["text", "label"])

    result_df = pd.concat(all_dfs, ignore_index=True)
    print(f"✅ Successfully loaded {len(result_df)} training samples from {len(all_dfs)} files")
    return result_df


def load_test_val_files(directory_path, file_limit=None):
    """Load unlabelled test/val files (one sentence per line).

    Args:
        directory_path: Folder containing the ``.txt`` files.
        file_limit: Maximum number of files to read (in sorted name order); None reads all.

    Returns:
        DataFrame with a single whitespace-stripped ``text`` column; rows whose text
        field is empty are dropped. Files that fail to parse are reported and skipped.
    """
    print(f"Loading test/val files from: {directory_path}")
    files = _list_txt_files(directory_path, file_limit)

    all_dfs = []
    for file_name in tqdm(files, desc="Loading test/val files"):
        try:
            df = _read_tsv(os.path.join(directory_path, file_name), ["text"])
        except Exception as e:
            # Best-effort: report and skip a malformed file rather than abort the whole load.
            print(f"Error loading {file_name}: {e}")
            continue

        df = df.dropna(subset=["text"])
        if df.empty:
            continue
        # 🔧 Normalize text (strip whitespace)
        all_dfs.append(df.assign(text=df["text"].astype(str).str.strip()))

    if not all_dfs:
        print("⚠️ No valid test/val data found")
        return pd.DataFrame(columns=["text"])

    result_df = pd.concat(all_dfs, ignore_index=True)
    print(f"✅ Successfully loaded {len(result_df)} samples from {len(all_dfs)} files")
    return result_df

In [None]:
# Load every split, honouring the per-split file limits from the configuration cell
print("Loading datasets with configured limits...")
print("="*60)

df_train = load_train_files(train_path, file_limit=TRAIN_FILE_LIMIT)   # text + label
df_test = load_test_val_files(test_path, file_limit=TEST_FILE_LIMIT)   # text only
df_val = load_test_val_files(val_path, file_limit=VAL_FILE_LIMIT)      # text only

# Per-split row counts, printed in a fixed order
split_sizes = {"Train": len(df_train), "Test": len(df_test), "Val": len(df_val)}
print("\n📊 Dataset Summary (After Applying Limits):")
for split_name, n_rows in split_sizes.items():
    print(f"{split_name}: {n_rows} rows")
print(f"Total samples: {sum(split_sizes.values())}")

if df_train.empty:
    print("⚠️ No training data loaded!")
else:
    print("\n📋 Train labels distribution:")
    print(df_train["label"].value_counts())

    print("\n📄 Sample train data:")
    print(df_train.head(3))

    # Mean sentence length in characters
    avg_length = df_train["text"].str.len().mean()
    print(f"\n📏 Average text length: {avg_length:.1f} characters")

Loading datasets...
Loading training files from: /home/uttam/B.Tech Major Project/nyaya/server/dataset/Hier_BiLSTM_CRF/train
Found 4994 files


Loading train files: 100%|██████████| 4994/4994 [00:19<00:00, 260.67it/s]



Successfully loaded 520247 training samples
Loading test/val files from: /home/uttam/B.Tech Major Project/nyaya/server/dataset/Hier_BiLSTM_CRF/test
Found 712 files


Loading test/val files: 100%|██████████| 712/712 [00:03<00:00, 227.54it/s]



Successfully loaded 149868 samples
Loading test/val files from: /home/uttam/B.Tech Major Project/nyaya/server/dataset/Hier_BiLSTM_CRF/val/val
Found 1424 files


Loading test/val files: 100%|██████████| 1424/1424 [00:05<00:00, 271.06it/s]



Successfully loaded 293408 samples

Dataset Summary:
Train: 520247 rows
Test: 149868 rows
Val: 293408 rows

Train labels distribution:
label
Reasoning                  202593
Facts                      170068
Arguments of Petitioner     65032
Arguments of Respondent     50137
Decision                    19599
Issue                       12818
Name: count, dtype: int64

Sample train data:
                                                text  label
0   K. Mathur, J. This appeal is directed against...  Issue
1  Brief facts giving rise to this appeal areThe ...  Facts
2  The case of the complainant respondent was tha...  Facts
3  The respondent complainant held a valid Fire P...  Facts
4  This policy also endorsed to cover risk of flood.  Facts


In [6]:
# Label encoding for training data
if not df_train.empty:
    # Fixed label -> class-number mapping. Keeping it explicit (rather than fitting a
    # LabelEncoder) keeps class numbers identical across runs and dataset subsets.
    label_to_num = {
        'Facts': 0,
        'Reasoning': 1,
        'Arguments of Respondent': 2,
        'Arguments of Petitioner': 3,
        'Decision': 4,
        'Issue': 5,
        'None': 6
    }

    print("Creating label mappings...")

    # Check if all labels in data are in our mapping
    unique_labels = df_train['label'].unique()
    print(f"Unique labels in dataset: {unique_labels}")

    # Unknown labels get new numbers *after* the known ones, in sorted order, so the
    # numbering does not depend on file order. Re-fitting a LabelEncoder here would
    # instead renumber every class alphabetically and silently change what the
    # existing class numbers mean.
    missing_labels = sorted(label for label in unique_labels if label not in label_to_num)
    if missing_labels:
        print(f"Warning: Missing labels in mapping: {missing_labels}")
        next_num = max(label_to_num.values()) + 1
        for offset, label in enumerate(missing_labels):
            label_to_num[label] = next_num + offset
        print("Extending manual mapping with the missing labels")

    # Every label is now in the mapping, so the result has no NaN and can be int
    df_train['label_numeric'] = df_train['label'].map(label_to_num).astype(int)

    label_mapping = {num: label for label, num in label_to_num.items()}  # number -> name
    train_labels = df_train['label'].tolist()
    train_label_numbers = df_train['label_numeric'].tolist()

    print("Using manual mapping:")
    for num, label in label_mapping.items():
        print(f"{num}: {label}")

    print("\nLabel distribution by number:")
    for num, count in df_train['label_numeric'].value_counts().sort_index().items():
        print(f"{num} ({label_mapping[num]}): {count}")
else:
    print("No training data available for label encoding")

Creating label mappings...
Unique labels in dataset: ['Issue' 'Facts' 'Arguments of Petitioner' 'Arguments of Respondent'
 'Reasoning' 'Decision']
Using manual mapping:
0: Facts
1: Reasoning
2: Arguments of Respondent
3: Arguments of Petitioner
4: Decision
5: Issue
6: None

Label distribution by number:
0 (Facts): 170068
1 (Reasoning): 202593
2 (Arguments of Respondent): 50137
3 (Arguments of Petitioner): 65032
4 (Decision): 19599
5 (Issue): 12818


In [7]:
# Load the pre-trained InLegalBERT encoder and its matching tokenizer
print("Loading InLegalBERT model and tokenizer...")
print("This may take a few minutes on first run...")

try:
    # Tokenizer and model come from the same checkpoint id
    tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")
    # .to() and .eval() both return the module, so the calls chain; eval mode
    # disables dropout so repeated runs give the same embeddings.
    model = AutoModel.from_pretrained("law-ai/InLegalBERT").to(device).eval()

    config = model.config
    print(f"✓ InLegalBERT loaded successfully!")
    print(f"✓ Model moved to: {device}")
    print(f"✓ Tokenizer vocabulary size: {tokenizer.vocab_size}")
    print(f"✓ Model max position embeddings: {config.max_position_embeddings}")
    print(f"✓ Hidden size: {config.hidden_size}")

except Exception as e:
    print(f"✗ Error loading InLegalBERT: {e}")
    print("Please ensure you have internet connection and transformers library installed")
    raise

Loading InLegalBERT model and tokenizer...
This may take a few minutes on first run...
✓ InLegalBERT loaded successfully!
✓ Model moved to: cpu
✓ Tokenizer vocabulary size: 30522
✓ Model max position embeddings: 512
✓ Hidden size: 768
✓ InLegalBERT loaded successfully!
✓ Model moved to: cpu
✓ Tokenizer vocabulary size: 30522
✓ Model max position embeddings: 512
✓ Hidden size: 768


In [8]:
def get_bert_embeddings(texts, tokenizer, model, device, max_length=512, batch_size=8, pooling="cls"):
    """
    Generate one fixed-size embedding per text with InLegalBERT.

    Args:
        texts: Sequence of text strings
        tokenizer: InLegalBERT tokenizer (called with padding/truncation, return_tensors='pt')
        model: InLegalBERT model whose output exposes ``last_hidden_state``
        device: torch device (cuda/cpu) the model lives on
        max_length: Maximum sequence length for BERT; longer texts are truncated
        batch_size: Number of texts encoded per forward pass
        pooling: "cls" (default) takes the first token's hidden state; "mean" averages
            the hidden states of non-padding tokens (weighted by the attention mask)

    Returns:
        numpy array of shape (len(texts), hidden_size). An empty ``texts`` yields a
        (0, hidden_size) array instead of failing inside ``np.vstack``.

    Raises:
        ValueError: if ``pooling`` is not "cls" or "mean".
    """
    if pooling not in ("cls", "mean"):
        raise ValueError(f"Unsupported pooling {pooling!r}; expected 'cls' or 'mean'")

    if len(texts) == 0:
        hidden_size = getattr(getattr(model, "config", None), "hidden_size", 0)
        return np.empty((0, hidden_size), dtype=np.float32)

    embeddings = []

    # Process in batches to manage memory
    for start in tqdm(range(0, len(texts), batch_size), desc="Generating embeddings"):
        batch_texts = list(texts[start:start + batch_size])

        # Tokenize batch (padded to the longest text in this batch)
        encoded = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )
        encoded = {key: val.to(device) for key, val in encoded.items()}

        with torch.no_grad():
            hidden = model(**encoded).last_hidden_state  # (batch, seq_len, hidden_size)
            if pooling == "cls":
                # [CLS] token (first position) as the sentence representation
                pooled = hidden[:, 0, :]
            else:
                mask = encoded.get("attention_mask")
                if mask is None:
                    pooled = hidden.mean(dim=1)
                else:
                    # Zero out padding positions so they do not dilute the average
                    weights = mask.unsqueeze(-1).to(hidden.dtype)
                    pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

        embeddings.append(pooled.cpu().numpy())

    # Concatenate all batches
    return np.vstack(embeddings)

print("✓ Embedding function defined")
print("This function will use the [CLS] token representation as sentence embeddings")

✓ Embedding function defined
This function will use the [CLS] token representation as sentence embeddings


In [9]:
def create_json_data_bert(texts, embeddings, labels=None, label_numbers=None, dataset_name=""):
    """Build JSON-serialisable records (text, vector, classname, classnumber) for one split.

    Args:
        texts: Sequence of input strings.
        embeddings: Per-text vectors (e.g. a 2-D numpy array), aligned with ``texts``.
        labels: Optional class names aligned with ``texts``.
        label_numbers: Optional class numbers aligned with ``texts``.
        dataset_name: Name used only in progress messages.

    Returns:
        List of dicts. When either ``labels`` or ``label_numbers`` is None the split is
        treated as unlabelled and ``classname``/``classnumber`` are None.

    Raises:
        ValueError: if ``embeddings`` (or, for labelled data, ``labels``/``label_numbers``)
            do not have exactly one entry per text.
    """
    print(f"Creating JSON data for {dataset_name}...")
    n_texts = len(texts)
    has_labels = labels is not None and label_numbers is not None

    # Fail loudly on misaligned inputs: positional indexing would otherwise raise a bare
    # IndexError part-way through, or silently ignore surplus embeddings/labels.
    if len(embeddings) != n_texts:
        raise ValueError(
            f"Got {len(embeddings)} embeddings for {n_texts} texts in {dataset_name!r}"
        )
    if has_labels and not (len(labels) == len(label_numbers) == n_texts):
        raise ValueError(
            f"Got {len(labels)} labels / {len(label_numbers)} label numbers "
            f"for {n_texts} texts in {dataset_name!r}"
        )

    json_data = []
    for i, (text, vector) in enumerate(zip(texts, embeddings)):
        data_point = {
            "text": text,
            "vector": np.asarray(vector).tolist(),  # plain Python floats for JSON
        }

        if has_labels:
            data_point["classname"] = labels[i]
            data_point["classnumber"] = int(label_numbers[i])
        else:
            data_point["classname"] = None
            data_point["classnumber"] = None

        json_data.append(data_point)

        # Progress indicator
        if (i + 1) % 1000 == 0:
            print(f"  Processed {i + 1}/{n_texts} samples")

    print(f"✓ Created JSON data for {len(json_data)} samples")
    return json_data

print("✓ JSON creation function defined")

✓ JSON creation function defined


In [None]:
# Generate embeddings for all datasets
print("="*60)
print("GENERATING INLEGALBERT EMBEDDINGS")
print("="*60)

# Prepare text data
train_texts = df_train["text"].tolist() if not df_train.empty else []
test_texts = df_test["text"].tolist() if not df_test.empty else []
val_texts = df_val["text"].tolist() if not df_val.empty else []

total_samples = len(train_texts) + len(test_texts) + len(val_texts)

print(f"📋 Texts to process:")
print(f"  Train: {len(train_texts)} texts")
print(f"  Test: {len(test_texts)} texts")
print(f"  Val: {len(val_texts)} texts")
print(f"  Total: {total_samples} texts")

# Configuration for embedding generation
MAX_LENGTH = 512  # BERT's typical max length (InLegalBERT reports 512 max position embeddings)

# Rough wall-clock cost per sample, used only for the time estimate below. The saved CPU
# run with batch size 2 progressed at ~5.9 batches/s, i.e. roughly 0.08 s per sample.
SECONDS_PER_SAMPLE_ESTIMATE = 0.1

# Adjust batch size based on dataset size and device
on_gpu = device.type == 'cuda'
if total_samples <= 500:
    BATCH_SIZE = 8 if on_gpu else 4  # Larger batches for small datasets
elif total_samples <= 2000:
    BATCH_SIZE = 6 if on_gpu else 3  # Medium batches
else:
    BATCH_SIZE = 4 if on_gpu else 2  # Conservative for large datasets

print(f"\n⚙️ Embedding configuration:")
print(f"  Max length: {MAX_LENGTH}")
print(f"  Batch size: {BATCH_SIZE} (auto-adjusted based on dataset size)")
print(f"  Device: {device}")

# Estimate processing time
estimated_time = total_samples * SECONDS_PER_SAMPLE_ESTIMATE
print(f"  Estimated processing time: {estimated_time/60:.1f} minutes")

# Start of the overall embedding run (read again by the validation cell for the total)
start_time = time.time()

GENERATING INLEGALBERT EMBEDDINGS
Texts to process:
  Train: 520247 texts
  Test: 149868 texts
  Val: 293408 texts
  Total: 963523 texts

Embedding configuration:
  Max length: 512
  Batch size: 2
  Device: cpu


In [None]:
# Preview the (already size-limited) training data before the long embedding runs.
# No separate demo pass is needed because the file limits keep the dataset small.
print("🚀 PROCESSING REDUCED DATASET")
print("="*60)


def _loaded_count(var_name):
    """Length of the notebook-level list ``var_name``, or 0 if it has not been created yet."""
    return len(globals()[var_name]) if var_name in globals() else 0


print("📊 Dataset is already reduced to manageable size:")
for split_title, var_name in (("Train", "train_texts"), ("Test", "test_texts"), ("Val", "val_texts")):
    print(f"   {split_title} samples: {_loaded_count(var_name)}")

if 'train_texts' in globals() and train_texts:
    print("\n📄 Sample training texts:")
    for position, (text, label) in enumerate(zip(train_texts[:3], train_labels[:3]), start=1):
        preview = text if len(text) <= 100 else text[:100] + "..."
        print(f"{position}. [{label}] {preview}")

print("\n✅ Ready to process full reduced dataset")
print("   No demo mode needed - dataset size is already optimized")

🧪 DEMO MODE: Processing small sample first
Demo sample size: 10
Demo texts:
1. [Issue]  K. Mathur, J. This appeal is directed against the order passed by the National Consumer Disputes Re...
2. [Facts] Brief facts giving rise to this appeal areThe respondent complainant M s Kiran Combers Spinners file...
3. [Facts] The case of the complainant respondent was that they got their building and stock insured from the U...
4. [Facts] The respondent complainant held a valid Fire Policy for its stock (Building Rs. 25 lakhs, Machinery ...
5. [Facts] This policy also endorsed to cover risk of flood.
6. [Facts] On account of heavy rains and floods in the city, insured property was affected by floods on 24th Ju...
7. [Facts] This incident was reported to the Company on 25th July, 1993 and an FIR was lodged on 27th July, 199...
8. [Facts] The respondentclaimant claimed Rs.20,03,842/ in July, 1993 from the Company.
9. [Facts] Surveyor, namely, M s Vij Engineers Enterprise appointed by the Company ca

Generating embeddings: 100%|██████████| 5/5 [00:00<00:00,  5.84it/s]

✅ Demo embeddings generated!
📊 Shape: (10, 768)
📊 Embedding dimension: 768
📊 Sample embedding (first 5 values): [-0.05631871 -0.25117683  0.42925707 -0.5550534  -0.05736984]
Creating JSON data for demo set...
✓ Created JSON data for 10 samples
✅ Demo JSON data created with 10 samples





In [12]:
# Generate embeddings for training data
# This split gets its own timer, like the test/validation cells: `start_time` from the
# configuration cell keeps running across cells and would also count idle time.
train_start_time = time.time()
if train_texts:
    print(f"\n Processing training data ({len(train_texts)} samples)...")
    train_embeddings = get_bert_embeddings(
        train_texts,
        tokenizer,
        model,
        device,
        max_length=MAX_LENGTH,
        batch_size=BATCH_SIZE
    )
    print(f"✓ Train embeddings shape: {train_embeddings.shape}")

    # Create JSON data for training (labelled split)
    train_json_data = create_json_data_bert(
        train_texts,
        train_embeddings,
        train_labels,
        train_label_numbers,
        "train set"
    )
else:
    print("No training data available")
    train_json_data = []

print(f"\n Training data processing time: {time.time() - train_start_time:.2f} seconds")


 Processing training data (520247 samples)...


Generating embeddings:   0%|          | 786/260124 [02:12<12:07:15,  5.94it/s]



KeyboardInterrupt: 

In [None]:
# Generate embeddings for test data
# The test split is unlabelled, so its records get classname/classnumber = None.
if not test_texts:
    print(" No test data available")
    test_json_data = []
else:
    print(f"\n Processing test data ({len(test_texts)} samples)...")
    test_start_time = time.time()

    test_embeddings = get_bert_embeddings(
        test_texts, tokenizer, model, device,
        max_length=MAX_LENGTH, batch_size=BATCH_SIZE,
    )
    print(f"✓ Test embeddings shape: {test_embeddings.shape}")

    test_json_data = create_json_data_bert(
        test_texts, test_embeddings, dataset_name="test set"
    )

    elapsed = time.time() - test_start_time
    print(f" Test data processing time: {elapsed:.2f} seconds")

In [None]:
# Generate embeddings for validation data
if val_texts:
    print(f"\nProcessing validation data ({len(val_texts)} samples)...")
    val_start_time = time.time()

    val_embeddings = get_bert_embeddings(
        val_texts,
        tokenizer,
        model,
        device,
        max_length=MAX_LENGTH,
        batch_size=BATCH_SIZE
    )
    print(f"✓ Val embeddings shape: {val_embeddings.shape}")

    # Create JSON data for validation (no labels)
    val_json_data = create_json_data_bert(
        val_texts,
        val_embeddings,
        dataset_name="validation set"
    )

    print(f"Validation data processing time: {time.time() - val_start_time:.2f} seconds")
else:
    print(" No validation data available")
    val_json_data = []

# Wall-clock time since `start_time` was set in the configuration cell; this includes
# any pauses between running the cells.
total_time = time.time() - start_time
n_samples = len(train_texts) + len(test_texts) + len(val_texts)
print(f"\n Total embedding generation time: {total_time:.2f} seconds")
# Every split can be empty (e.g. all file limits set to 0): avoid dividing by zero.
if n_samples:
    print(f" Average time per sample: {total_time / n_samples:.3f} seconds")
else:
    print(" Average time per sample: n/a (no samples processed)")

In [None]:
# Save embeddings to JSON files
print("="*60)
print("SAVING EMBEDDINGS TO JSON FILES")
print("="*60)

# Ensure output directory exists
os.makedirs(embeddings_output_path, exist_ok=True)
print(f"📁 Output directory: {embeddings_output_path}")

# None writes compact JSON. With indent=2, every float of every 768-d vector lands on its
# own line, which inflates the files several-fold. Set to 2 for human-readable output.
JSON_INDENT = None


def save_json(payload, file_name):
    """Write ``payload`` as JSON to ``embeddings_output_path/file_name``; return the file path."""
    file_path = os.path.join(embeddings_output_path, file_name)
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=JSON_INDENT)
    return file_path


# Look the split results up defensively: if an embedding cell was interrupted
# (e.g. the long training run), its *_json_data variable may not exist yet.
for file_name, split_var in [
    ('train_embeddings_inlegalbert.json', 'train_json_data'),
    ('test_embeddings_inlegalbert.json', 'test_json_data'),
    ('val_embeddings_inlegalbert.json', 'val_json_data'),
]:
    split_data = globals().get(split_var)
    if split_data:
        saved_path = save_json(split_data, file_name)
        print(f"✅ Saved {file_name} with {len(split_data)} samples")
        print(f"   File size: {os.path.getsize(saved_path) / (1024*1024):.1f} MB")

# Save label mapping for reference (if available)
if 'label_mapping' in globals():
    save_json(label_mapping, 'label_mapping_inlegalbert.json')
    print("✅ Saved label_mapping_inlegalbert.json")

print(f"\n📂 All files saved in: {os.path.abspath(embeddings_output_path)}")

In [None]:
# Display sample JSON structure and summary
print("="*60)
print("SUMMARY AND SAMPLE OUTPUT")
print("="*60)

# Look the split results up defensively: an interrupted embedding cell leaves its
# *_json_data variable undefined.
train_records = globals().get('train_json_data') or []
test_records = globals().get('test_json_data') or []
val_records = globals().get('val_json_data') or []

# Show sample data structure
if train_records:
    print("\n📋 Sample JSON structure (train data):")
    sample = dict(train_records[0])

    # Show only first 5 vector elements for readability
    if 'vector' in sample and len(sample['vector']) > 5:
        original_length = len(sample['vector'])
        sample['vector'] = sample['vector'][:5] + [f'... ({original_length-5} more values)']

    print(json.dumps(sample, indent=2))

# Summary statistics
print("\n📊 FINAL SUMMARY:")
print("✅ InLegalBERT Model: law-ai/InLegalBERT")
if 'model' in globals():  # absent if this cell already ran and released the model
    print(f"✅ Embedding dimension: {model.config.hidden_size}")
print(f"✅ Device used: {device}")

if train_records:
    print(f"✅ train_embeddings_inlegalbert.json: {len(train_records)} samples with labels")
if test_records:
    print(f"✅ test_embeddings_inlegalbert.json: {len(test_records)} samples without labels")
if val_records:
    print(f"✅ val_embeddings_inlegalbert.json: {len(val_records)} samples without labels")
if 'label_mapping' in globals():
    print("✅ label_mapping_inlegalbert.json: Label number to name mapping")

if train_records or test_records or val_records:
    print("\n🎉 InLegalBERT embedding generation completed successfully!")
    print(f"📁 Files are saved in: {embeddings_output_path}")
else:
    print("\n⚠️ No embeddings were generated - nothing to report")

# Cleanup to free memory: drop the references, collect them, then let CUDA release
# its cached blocks (empty_cache cannot free memory still referenced from Python).
if 'model' in globals():
    del model
if 'tokenizer' in globals():
    del tokenizer
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("🧹 Memory cleanup completed")