# InLegalBERT Embedding Generation for Legal Text Classification

This notebook generates embeddings for legal text classification using **InLegalBERT** - a BERT model specifically pre-trained on Indian legal documents.

## Dataset Structure
- **Train**: Files with text and labels (Facts, Reasoning, Arguments of Respondent, Arguments of Petitioner, Decision, Issue); rows without a label are assigned the class `None`
- **Test**: Files with only text (no labels)
- **Val**: Files with only text (no labels)

## Output
- JSON files with text, embeddings, class names, and class numbers
- Saved to the output directory configured below (`/kaggle/working/`)

In [2]:
# Import required libraries
import json
import os
import time
import warnings

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel
warnings.filterwarnings('ignore')

# Record the PyTorch build and the accelerator this run will use.
print(f"Torch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"Device: {torch.device('cuda' if torch.cuda.is_available() else 'cpu')}")

Torch version: 2.6.0+cu124
CUDA available: True
Device: cuda


In [3]:
# Choose the compute device once; later cells reuse `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Input directories, one per dataset split (note the nested "val/val" folder).
base_path = "/kaggle/input/dataset/Hier_BiLSTM_CRF"
split_dirs = {
    "Train": os.path.join(base_path, "train"),
    "Test": os.path.join(base_path, "test"),
    "Val": os.path.join(base_path, "val", "val"),
}
train_path = split_dirs["Train"]
test_path = split_dirs["Test"]
val_path = split_dirs["Val"]

# Directory that receives the generated JSON files.
embeddings_output_path = "/kaggle/working/"

for split_name, split_dir in split_dirs.items():
    print(f"{split_name} path: {split_dir}")
print(f"Output path: {embeddings_output_path}")

# Report, per split, whether the directory is present and how many .txt files it holds.
for split_name, split_dir in split_dirs.items():
    if not os.path.exists(split_dir):
        print(f"✗ {split_name} directory not found: {split_dir}")
        continue
    txt_total = sum(1 for entry in os.listdir(split_dir) if entry.endswith('.txt'))
    print(f"✓ {split_name} directory exists with {txt_total} files")

Using device: cuda
Train path: /kaggle/input/dataset/Hier_BiLSTM_CRF/train
Test path: /kaggle/input/dataset/Hier_BiLSTM_CRF/test
Val path: /kaggle/input/dataset/Hier_BiLSTM_CRF/val/val
Output path: /kaggle/working/
✓ Train directory exists with 4994 files
✓ Test directory exists with 712 files
✓ Val directory exists with 1424 files


In [4]:
def load_train_files(directory_path):
    """Load every ``.txt`` file in a directory as tab-separated (text, label) rows.

    Files are read in sorted file-name order so the row order of the result is
    reproducible across runs and machines (``os.listdir`` gives no ordering
    guarantee).

    Normalisation applied to each file:
      * ``text`` is cast to ``str`` and stripped, matching ``load_test_val_files``;
      * missing labels become the string ``"None"``;
      * labels are stripped and any casing of ``"none"`` is unified to ``"None"``.

    A file that cannot be read or normalised is reported and skipped.

    Args:
        directory_path: Directory containing the training ``.txt`` files.

    Returns:
        A DataFrame with columns ``text`` and ``label`` and a fresh 0..n-1
        index; an empty DataFrame with those columns if nothing could be loaded.

    Raises:
        OSError: If ``directory_path`` cannot be listed.
    """
    all_dfs = []

    print(f"Loading training files from: {directory_path}")
    # Sort: the concatenated row order must not depend on filesystem order.
    files = sorted(f for f in os.listdir(directory_path) if f.endswith('.txt'))
    print(f"Found {len(files)} files")

    for file_name in tqdm(files, desc="Loading train files"):
        file_path = os.path.join(directory_path, file_name)
        try:
            df = pd.read_csv(file_path, sep="\t", header=None, names=["text", "label"])
            if not df.empty:
                # Same text clean-up as the test/val loader, so all splits are
                # tokenized from identically prepared strings.
                df["text"] = df["text"].astype(str).str.strip()
                # Missing labels are read as NaN; give them an explicit class.
                labels = df["label"].fillna("None").astype(str).str.strip()
                # Unify every casing of "none" ("none", "NONE", ...) to "None".
                df["label"] = labels.where(labels.str.lower() != "none", "None")
                all_dfs.append(df)
        except Exception as e:
            # Best effort: one malformed file must not abort the whole load.
            print(f"Error loading {file_name}: {e}")

    if all_dfs:
        result_df = pd.concat(all_dfs, ignore_index=True)
        print(f"✅ Successfully loaded {len(result_df)} training samples")
        return result_df
    else:
        print("⚠️ No valid training data found")
        return pd.DataFrame(columns=["text", "label"])


def load_test_val_files(directory_path):
    """Load every ``.txt`` file in a directory as unlabeled text rows.

    Only the first tab-separated column of each file is read. Without
    ``usecols`` pandas would, for a file that has more columns than the single
    name supplied, use the leading column(s) as the index and store the *last*
    column under ``text``.

    Files are read in sorted file-name order so the row order of the result is
    reproducible (``os.listdir`` gives no ordering guarantee).

    A file that cannot be read is reported and skipped.

    Args:
        directory_path: Directory containing the test/validation ``.txt`` files.

    Returns:
        A DataFrame with a single ``text`` column (values cast to ``str`` and
        stripped) and a fresh 0..n-1 index; an empty DataFrame with that column
        if nothing could be loaded.

    Raises:
        OSError: If ``directory_path`` cannot be listed.
    """
    all_dfs = []

    print(f"Loading test/val files from: {directory_path}")
    # Sort: the output rows carry no file identifier, so their order must be stable.
    files = sorted(f for f in os.listdir(directory_path) if f.endswith('.txt'))
    print(f"Found {len(files)} files")

    for file_name in tqdm(files, desc="Loading test/val files"):
        file_path = os.path.join(directory_path, file_name)
        try:
            # usecols=[0] pins "text" to the first column even if extra columns exist.
            df = pd.read_csv(file_path, sep="\t", header=None, names=["text"], usecols=[0])
            if not df.empty:
                # Normalize text (strip surrounding whitespace)
                df["text"] = df["text"].astype(str).str.strip()
                all_dfs.append(df)
        except Exception as e:
            # Best effort: one malformed file must not abort the whole load.
            print(f"Error loading {file_name}: {e}")

    if all_dfs:
        result_df = pd.concat(all_dfs, ignore_index=True)
        print(f"✅ Successfully loaded {len(result_df)} samples")
        return result_df
    else:
        print("⚠️ No valid test/val data found")
        return pd.DataFrame(columns=["text"])


In [5]:
# Read the three splits: train carries labels, test and val are text-only.
print("Loading datasets...")
df_train = load_train_files(train_path)
df_test, df_val = (load_test_val_files(split_dir) for split_dir in (test_path, val_path))

# Row counts per split
print("\nDataset Summary:")
for split_label, frame in (("Train", df_train), ("Test", df_test), ("Val", df_val)):
    print(f"{split_label}: {len(frame)} rows")

# Class balance and a preview, only when training rows were found.
if df_train.empty:
    print("No training data loaded!")
else:
    print("\nTrain labels distribution:")
    print(df_train["label"].value_counts())

    print("\nSample train data:")
    print(df_train.head())

Loading datasets...
Loading training files from: /kaggle/input/dataset/Hier_BiLSTM_CRF/train
Found 4994 files


Loading train files: 100%|██████████| 4994/4994 [00:28<00:00, 173.35it/s]


✅ Successfully loaded 1123832 training samples
Loading test/val files from: /kaggle/input/dataset/Hier_BiLSTM_CRF/test
Found 712 files


Loading test/val files: 100%|██████████| 712/712 [00:03<00:00, 199.46it/s]


✅ Successfully loaded 149868 samples
Loading test/val files from: /kaggle/input/dataset/Hier_BiLSTM_CRF/val/val
Found 1424 files


Loading test/val files: 100%|██████████| 1424/1424 [00:07<00:00, 199.47it/s]


✅ Successfully loaded 293408 samples

Dataset Summary:
Train: 1123832 rows
Test: 149868 rows
Val: 293408 rows

Train labels distribution:
label
None                       603585
Reasoning                  202593
Facts                      170068
Arguments of Petitioner     65032
Arguments of Respondent     50137
Decision                    19599
Issue                       12818
Name: count, dtype: int64

Sample train data:
                                                text  label
0   M. JOSEPH, J. The appeal is directed against ...  Issue
1  The chargesheet came to be filed on the basis ...  Facts
2  The appellant was Director of Mines and Geolog...  Facts
3  There was a partnership firm by the name M s A...  Facts
4  The offences are alleged to revolve around the...  Facts


In [6]:
# Label encoding for training data
#
# Produces three module-level names used by later cells:
#   train_labels         - list of class names, one per training row
#   train_label_numbers  - list of integer class ids, aligned with train_labels
#   label_mapping        - {class id: class name}
if not df_train.empty:
    # Canonical class ids. These numbers end up in the exported JSON, so they
    # must stay stable regardless of which labels occur in a given run.
    label_to_num = {
        'Facts': 0,
        'Reasoning': 1,
        'Arguments of Respondent': 2,
        'Arguments of Petitioner': 3,
        'Decision': 4,
        'Issue': 5,
        'None': 6
    }

    print("Creating label mappings...")

    # Check if all labels in data are in our mapping
    unique_labels = df_train['label'].unique()
    print(f"Unique labels in dataset: {unique_labels}")

    missing_labels = sorted(label for label in unique_labels if label not in label_to_num)
    if missing_labels:
        print(f"Warning: Missing labels in mapping: {missing_labels}")

        # Append unseen labels after the canonical ids instead of re-encoding
        # everything: a LabelEncoder fit would renumber *all* classes
        # alphabetically and silently contradict the mapping above.
        next_id = max(label_to_num.values()) + 1
        for offset, label in enumerate(missing_labels):
            label_to_num[label] = next_id + offset

        print("Using manual mapping extended with unseen labels:")
    else:
        print("Using manual mapping:")

    df_train['label_numeric'] = df_train['label'].map(label_to_num)

    label_mapping = {num: label for label, num in label_to_num.items()}  # Reverse mapping
    train_labels = df_train['label'].tolist()
    train_label_numbers = df_train['label_numeric'].tolist()

    for num, label in label_mapping.items():
        print(f"{num}: {label}")

    print(f"\nLabel distribution by number:")
    for num, count in df_train['label_numeric'].value_counts().sort_index().items():
        print(f"{num} ({label_mapping[num]}): {count}")
else:
    print("No training data available for label encoding")
    # Keep the names defined so later cells can slice them without a NameError.
    train_labels = []
    train_label_numbers = []

Creating label mappings...
Unique labels in dataset: ['Issue' 'Facts' 'None' 'Arguments of Petitioner'
 'Arguments of Respondent' 'Reasoning' 'Decision']
Using manual mapping:
0: Facts
1: Reasoning
2: Arguments of Respondent
3: Arguments of Petitioner
4: Decision
5: Issue
6: None

Label distribution by number:
0 (Facts): 170068
1 (Reasoning): 202593
2 (Arguments of Respondent): 50137
3 (Arguments of Petitioner): 65032
4 (Decision): 19599
5 (Issue): 12818
6 (None): 603585


In [7]:
# Fetch the InLegalBERT tokenizer and encoder, then place the encoder on `device`.
print("Loading InLegalBERT model and tokenizer...")
print("This may take a few minutes on first run...")

try:
    tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")

    # Load the weights, move them to the selected device, and switch to
    # evaluation mode in one chain (both calls return the module itself).
    model = AutoModel.from_pretrained("law-ai/InLegalBERT").to(device).eval()

    print("✓ InLegalBERT loaded successfully!")
    print(f"✓ Model moved to: {device}")
    print(f"✓ Tokenizer vocabulary size: {tokenizer.vocab_size}")
    print(f"✓ Model max position embeddings: {model.config.max_position_embeddings}")
    print(f"✓ Hidden size: {model.config.hidden_size}")

except Exception as load_error:
    # Explain the likely cause, then re-raise so the notebook stops here.
    print(f"✗ Error loading InLegalBERT: {load_error}")
    print("Please ensure you have internet connection and transformers library installed")
    raise

Loading InLegalBERT model and tokenizer...
This may take a few minutes on first run...


tokenizer_config.json:   0%|          | 0.00/516 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/671 [00:00<?, ?B/s]

2025-09-02 16:05:20.035920: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1756829120.226958      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756829120.284460      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


pytorch_model.bin:   0%|          | 0.00/534M [00:00<?, ?B/s]

✓ InLegalBERT loaded successfully!
✓ Model moved to: cuda
✓ Tokenizer vocabulary size: 30522
✓ Model max position embeddings: 512
✓ Hidden size: 768


model.safetensors:   0%|          | 0.00/534M [00:00<?, ?B/s]

In [8]:
def get_bert_embeddings(texts, tokenizer, model, device, max_length=512, batch_size=8):
    """
    Generate one embedding per text from the first-token ([CLS]) hidden state.

    Texts are processed in consecutive batches. Each batch is tokenized with
    padding to the longest item in the batch and truncation to ``max_length``,
    moved to ``device``, and run through ``model`` with gradients disabled.
    The function does not change the model's train/eval mode.

    Args:
        texts: Sequence of text strings (list, tuple, pandas Series, ...)
        tokenizer: Callable returning a mapping of tensors for a list of strings
        model: Callable whose output exposes ``last_hidden_state``
        device: torch device (cuda/cpu) the inputs are moved to
        max_length: Maximum sequence length passed to the tokenizer
        batch_size: Number of texts per forward pass; must be >= 1

    Returns:
        numpy array of embeddings (texts x hidden_size). For empty ``texts`` an
        array with zero rows is returned; its width is
        ``model.config.hidden_size`` when available, else 0.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # np.vstack([]) raises, so handle the no-input case explicitly.
    if len(texts) == 0:
        hidden_size = getattr(getattr(model, "config", None), "hidden_size", 0)
        return np.empty((0, hidden_size), dtype=np.float32)

    embeddings = []

    # Process in batches to manage memory
    for i in tqdm(range(0, len(texts), batch_size), desc="Generating embeddings"):
        # Materialise as a plain list so any sliceable sequence is accepted.
        batch_texts = list(texts[i:i + batch_size])

        # Tokenize batch
        encoded = tokenizer(
            batch_texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt'
        )

        # Move to device
        encoded = {key: val.to(device) for key, val in encoded.items()}

        # Forward pass without building the autograd graph
        with torch.no_grad():
            outputs = model(**encoded)

        # Use [CLS] token embedding (first token) as sentence representation
        cls_embeddings = outputs.last_hidden_state[:, 0, :]  # Shape: (batch, hidden_size)

        # Move back to CPU and convert to numpy
        embeddings.append(cls_embeddings.cpu().numpy())

    # Concatenate all batches
    return np.vstack(embeddings)

print("✓ Embedding function defined")
print("This function will use the [CLS] token representation as sentence embeddings")

✓ Embedding function defined
This function will use the [CLS] token representation as sentence embeddings


In [9]:
def create_json_data_bert(texts, embeddings, labels=None, label_numbers=None, dataset_name="",
                          progress_every=1000):
    """Create JSON-serialisable records with text, vector, classname, classnumber.

    Args:
        texts: Sequence of text strings.
        embeddings: Sequence of per-text vectors, aligned with ``texts``. Rows
            with a ``tolist()`` method (e.g. numpy arrays) are converted with
            it; other iterables are converted with ``list()``.
        labels: Optional sequence of class names aligned with ``texts``.
        label_numbers: Optional sequence of class ids aligned with ``texts``.
            Labels are written only when both ``labels`` and ``label_numbers``
            are given; otherwise both fields are ``None``.
        dataset_name: Name used in the progress messages.
        progress_every: Print a progress line after every this many records;
            ``None`` or ``0`` disables the per-record progress output.

    Returns:
        List of dicts with keys ``text``, ``vector``, ``classname`` and
        ``classnumber``, one per input text.

    Raises:
        ValueError: If ``embeddings`` (or the label sequences, when used) do
            not have the same length as ``texts``.
    """
    total = len(texts)
    if len(embeddings) != total:
        raise ValueError(
            f"texts and embeddings differ in length: {total} vs {len(embeddings)}"
        )

    # Decide once whether records are labeled; this does not vary per row.
    has_labels = labels is not None and label_numbers is not None
    if has_labels and (len(labels) != total or len(label_numbers) != total):
        raise ValueError(
            f"labels/label_numbers must match texts in length: "
            f"{len(labels)}/{len(label_numbers)} vs {total}"
        )

    print(f"Creating JSON data for {dataset_name}...")
    json_data = []

    for i in range(total):
        vector = embeddings[i]
        data_point = {
            "text": texts[i],
            # JSON cannot encode numpy arrays; convert each row to a plain list.
            "vector": vector.tolist() if hasattr(vector, "tolist") else list(vector),
        }

        if has_labels:
            data_point["classname"] = labels[i]
            data_point["classnumber"] = int(label_numbers[i])
        else:
            data_point["classname"] = None
            data_point["classnumber"] = None

        json_data.append(data_point)

        # Progress indicator
        if progress_every and (i + 1) % progress_every == 0:
            print(f"  Processed {i + 1}/{total} samples")

    print(f"✓ Created JSON data for {len(json_data)} samples")
    return json_data

print("✓ JSON creation function defined")

✓ JSON creation function defined


In [10]:
# Generate embeddings for all datasets
print("="*60)
print("GENERATING INLEGALBERT EMBEDDINGS")
print("="*60)

# Prepare text data (plain Python lists; empty when a split has no rows)
train_texts = df_train["text"].tolist() if not df_train.empty else []
test_texts = df_test["text"].tolist() if not df_test.empty else []
val_texts = df_val["text"].tolist() if not df_val.empty else []

print(f"Texts to process:")
print(f"  Train: {len(train_texts)} texts")
print(f"  Test: {len(test_texts)} texts") 
print(f"  Val: {len(val_texts)} texts")
print(f"  Total: {len(train_texts) + len(test_texts) + len(val_texts)} texts")

# Configuration for embedding generation
MAX_LENGTH = 512  # BERT's typical max length
BATCH_SIZE = 4 if device.type == 'cuda' else 2  # Smaller batch size to avoid memory issues

print(f"\nEmbedding configuration:")
print(f"  Max length: {MAX_LENGTH}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Device: {device}")

# Start timing for the whole embedding run; read again after the last split
# to report total and per-sample time.
start_time = time.time()

GENERATING INLEGALBERT EMBEDDINGS
Texts to process:
  Train: 1123832 texts
  Test: 149868 texts
  Val: 293408 texts
  Total: 1567108 texts

Embedding configuration:
  Max length: 512
  Batch size: 4
  Device: cuda


In [10]:
# DEMO: Generate embeddings for a small sample first (for testing)
#
# Smoke test of the full pipeline (tokenize -> encode -> JSON records) on a few
# training rows before committing to the long full-dataset run.
print("🧪 DEMO MODE: Processing small sample first")
print("="*60)

DEMO_SIZE = 10  # Upper bound on the number of training samples used for the demo

if train_texts:
    demo_train_texts = train_texts[:DEMO_SIZE]
    demo_train_labels = train_labels[:DEMO_SIZE] 
    demo_train_label_numbers = train_label_numbers[:DEMO_SIZE]

    # Report the real sample count: fewer than DEMO_SIZE rows may be available.
    print(f"Demo sample size: {len(demo_train_texts)}")
    print("Demo texts:")
    for i, (text, label) in enumerate(zip(demo_train_texts, demo_train_labels)):
        preview = text[:100] + "..." if len(text) > 100 else text
        print(f"{i+1}. [{label}] {preview}")

    print(f"\n🔄 Generating embeddings for demo sample...")
    demo_embeddings = get_bert_embeddings(
        demo_train_texts, 
        tokenizer, 
        model, 
        device, 
        max_length=MAX_LENGTH, 
        batch_size=2  # fixed small batch so the demo exercises more than one batch
    )

    print(f"✅ Demo embeddings generated!")
    print(f"📊 Shape: {demo_embeddings.shape}")
    print(f"📊 Embedding dimension: {demo_embeddings.shape[1]}")
    print(f"📊 Sample embedding (first 5 values): {demo_embeddings[0][:5]}")

    # Create demo JSON
    demo_json_data = create_json_data_bert(
        demo_train_texts, 
        demo_embeddings, 
        demo_train_labels, 
        demo_train_label_numbers, 
        "demo set"
    )

    print(f"✅ Demo JSON data created with {len(demo_json_data)} samples")
else:
    # Same empty-data guard as the other cells: nothing to demo.
    print("No training data available - skipping demo")
    demo_json_data = []

🧪 DEMO MODE: Processing small sample first
Demo sample size: 10
Demo texts:
1. [Issue]  M. JOSEPH, J. The appeal is directed against the Order of the High Court setting aside the Order pa...
2. [Facts] The chargesheet came to be filed on the basis of a FIR dated 01.10.2011.
3. [Facts] The appellant was Director of Mines and Geology in the State of Karnataka at the relevant time.
4. [Facts] There was a partnership firm by the name M s Associated Mineral Company (AMC, for short).
5. [Facts] The offences are alleged to revolve around the affairs of the said firm.
6. [Facts] First accused is the husband of the second accused.
7. [Facts] They became partners of the firm (AMC) in 2009.
8. [Facts] Appellant was arrayed as the third accused.
9. [Facts] There was reference in the chargesheet to a conspiracy between the first accused and the second accu...
10. [Facts] It is alleged, inter alia, that they obtained an undated letter from one Shri K.M. Vishwanath, the E...

🔄 Generating embeddings 


Generating embeddings:   0%|          | 0/5 [00:00<?, ?it/s][A
Generating embeddings: 100%|██████████| 5/5 [00:00<00:00, 10.96it/s][A

✅ Demo embeddings generated!
📊 Shape: (10, 768)
📊 Embedding dimension: 768
📊 Sample embedding (first 5 values): [ 0.4694947  -0.09240394  0.19579029  0.00546941 -0.14047483]
Creating JSON data for demo set...
✓ Created JSON data for 10 samples
✅ Demo JSON data created with 10 samples





In [None]:
# Generate embeddings for training data
#
# Uses its own timer, like the test and validation cells, so the reported
# duration covers this cell only and not the demo run or idle time since the
# configuration cell set `start_time`.
train_start_time = time.time()

if train_texts:
    print(f"\n Processing training data ({len(train_texts)} samples)...")
    train_embeddings = get_bert_embeddings(
        train_texts, 
        tokenizer, 
        model, 
        device, 
        max_length=MAX_LENGTH, 
        batch_size=BATCH_SIZE
    )
    print(f"✓ Train embeddings shape: {train_embeddings.shape}")
    
    # Create JSON data for training (text + vector + class name + class id)
    train_json_data = create_json_data_bert(
        train_texts, 
        train_embeddings, 
        train_labels, 
        train_label_numbers, 
        "train set"
    )
else:
    print("No training data available")
    train_json_data = []

print(f"\n Training data processing time: {time.time() - train_start_time:.2f} seconds")


 Processing training data (1123832 samples)...


Generating embeddings: 100%|██████████| 280958/280958 [1:03:40<00:00, 73.54it/s]


✓ Train embeddings shape: (1123832, 768)
Creating JSON data for train set...
  Processed 1000/1123832 samples
  Processed 2000/1123832 samples
  Processed 3000/1123832 samples
  Processed 4000/1123832 samples
  Processed 5000/1123832 samples
  Processed 6000/1123832 samples
  Processed 7000/1123832 samples
  Processed 8000/1123832 samples
  Processed 9000/1123832 samples
  Processed 10000/1123832 samples
  Processed 11000/1123832 samples
  Processed 12000/1123832 samples
  Processed 13000/1123832 samples
  Processed 14000/1123832 samples
  Processed 15000/1123832 samples
  Processed 16000/1123832 samples
  Processed 17000/1123832 samples
  Processed 18000/1123832 samples
  Processed 19000/1123832 samples
  Processed 20000/1123832 samples
  Processed 21000/1123832 samples
  Processed 22000/1123832 samples
  Processed 23000/1123832 samples
  Processed 24000/1123832 samples
  Processed 25000/1123832 samples
  Processed 26000/1123832 samples
  Processed 27000/1123832 samples
  Processed 28

In [None]:
# Embed the test split and wrap the vectors in unlabeled JSON records.
if not test_texts:
    print(" No test data available")
    test_json_data = []
else:
    print(f"\n Processing test data ({len(test_texts)} samples)...")
    test_start_time = time.time()

    test_embeddings = get_bert_embeddings(
        texts=test_texts,
        tokenizer=tokenizer,
        model=model,
        device=device,
        max_length=MAX_LENGTH,
        batch_size=BATCH_SIZE,
    )
    print(f"✓ Test embeddings shape: {test_embeddings.shape}")

    # No labels exist for this split, so classname/classnumber stay None.
    test_json_data = create_json_data_bert(
        texts=test_texts,
        embeddings=test_embeddings,
        dataset_name="test set",
    )

    print(f" Test data processing time: {time.time() - test_start_time:.2f} seconds")

In [None]:
# Generate embeddings for validation data, then report totals for the whole run
if val_texts:
    print(f"\nProcessing validation data ({len(val_texts)} samples)...")
    val_start_time = time.time()
    
    val_embeddings = get_bert_embeddings(
        val_texts, 
        tokenizer, 
        model, 
        device, 
        max_length=MAX_LENGTH, 
        batch_size=BATCH_SIZE
    )
    print(f"✓ Val embeddings shape: {val_embeddings.shape}")
    
    # Create JSON data for validation (no labels)
    val_json_data = create_json_data_bert(
        val_texts, 
        val_embeddings, 
        dataset_name="validation set"
    )
    
    print(f"Validation data processing time: {time.time() - val_start_time:.2f} seconds")
else:
    print(" No validation data available")
    val_json_data = []

# `start_time` was set in the embedding configuration cell.
total_time = time.time() - start_time
total_samples = len(train_texts) + len(test_texts) + len(val_texts)
print(f"\n Total embedding generation time: {total_time:.2f} seconds")
if total_samples:
    print(f" Average time per sample: {total_time / total_samples:.3f} seconds")
else:
    # All three splits empty: avoid dividing by zero.
    print(" Average time per sample: n/a (no samples processed)")

In [None]:
# Save embeddings to JSON files
print("="*60)
print("SAVING EMBEDDINGS TO JSON FILES")
print("="*60)

# Ensure output directory exists
os.makedirs(embeddings_output_path, exist_ok=True)
print(f"📁 Output directory: {embeddings_output_path}")

# The test/validation records are unlabeled; writing them is opt-in.
# The default (False) keeps the previous behavior of saving only the training
# file, where these two saves were commented out.
SAVE_UNLABELED_SPLITS = False


def _save_json_records(records, file_name):
    """Write `records` to `file_name` in the output directory and report its size.

    Returns the path of the written file.
    """
    file_path = os.path.join(embeddings_output_path, file_name)
    with open(file_path, 'w') as f:
        # Compact separators instead of indent=2: with indentation every
        # vector component lands on its own padded line, which inflates the
        # file considerably. The parsed content is identical.
        json.dump(records, f, separators=(",", ":"))
    print(f"Saved {file_name} with {len(records)} samples")
    print(f"   File size: {os.path.getsize(file_path) / (1024*1024):.1f} MB")
    return file_path


# Save training embeddings
if train_json_data:
    train_file_path = _save_json_records(train_json_data, 'train_embeddings_inlegalbert.json')

if SAVE_UNLABELED_SPLITS:
    # Save test embeddings
    if test_json_data:
        test_file_path = _save_json_records(test_json_data, 'test_embeddings_inlegalbert.json')

    # Save validation embeddings
    if val_json_data:
        val_file_path = _save_json_records(val_json_data, 'val_embeddings_inlegalbert.json')

# Save label mapping for reference (if available).
# The file is small, so it stays indented for readability; JSON stores the
# integer class ids as string keys.
if 'label_mapping' in locals():
    label_file_path = os.path.join(embeddings_output_path, 'label_mapping_inlegalbert.json')
    with open(label_file_path, 'w') as f:
        json.dump(label_mapping, f, indent=2)
    print(f"✅ Saved label_mapping_inlegalbert.json")

print(f"\n📂 All files saved in: {os.path.abspath(embeddings_output_path)}")