# Legal Contract Clause Classification using Stacked LSTM
## CCS 248 – Artificial Neural Networks Final Project
---

## Problem Statement

**Automated Classification of Legal Contract Clauses**

Lawyers spend hours manually reading and categorizing individual contract clauses (e.g., governing law, termination, confidentiality). This project automates that process using deep learning to classify each clause context into predefined legal categories.

## Solution: Stacked Bidirectional LSTM with Attention

Using a 2-layer bidirectional LSTM network plus an attention pooling head:
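- **Bidirectional processing**: each LSTM layer reads the clause forward and backward, so every token's representation sees both left and right context
- **Stacked layers + attention**: the second BiLSTM layer builds on the first layer's features, and attention pooling weights the most informative tokens when forming the clause vector
- **Dropout regularization**: dropout (p = 0.25) after the first LSTM layer and after attention pooling, to reduce overfitting to the training clauses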
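
A minimal shape trace of this stack, assuming the layer sizes used by `LSTMClassifier` later in the notebook (embedding 200, LSTM hidden sizes 128 and 96). The batch size, sequence length and vocabulary size are arbitrary toy values. Each bidirectional layer concatenates forward and backward states, so its output width is twice the hidden size, and attention pooling collapses the time axis into one vector per clause:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, vocab, embed_dim = 2, 7, 50, 200   # toy sizes, not the notebook's data
x = torch.randint(1, vocab, (batch, seq_len))      # token ids; 0 is reserved for padding

embedding = nn.Embedding(vocab + 1, embed_dim, padding_idx=0)
lstm1 = nn.LSTM(embed_dim, 128, batch_first=True, bidirectional=True)
lstm2 = nn.LSTM(128 * 2, 96, batch_first=True, bidirectional=True)
attn = nn.Linear(96 * 2, 1)

h = embedding(x)                                      # (2, 7, 200)
h1, _ = lstm1(h)                                      # (2, 7, 256): forward + backward states
h2, _ = lstm2(h1)                                     # (2, 7, 192)
weights = torch.softmax(torch.tanh(attn(h2)), dim=1)  # (2, 7, 1): one weight per token
context = (h2 * weights).sum(dim=1)                   # (2, 192): one vector per clause

for name, t in [("embedded", h), ("lstm1", h1), ("lstm2", h2), ("weights", weights), ("context", context)]:
    print(f"{name:>8}: {tuple(t.shape)}")
```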

## Dataset

**CUAD v1 clause snippets** (`label_group_xlsx` sheets, flattened to one snippet per row)
- 8,035 unique snippets across 47 clause labels originally (28 XLSX files)
- Filtered to the 20 most frequent clause types, each with at least 5 examples for stable stratification (6,175 snippets)

## Target

**Test Accuracy: 50-60%** (course requirement)

**Evaluation**: Accuracy, macro F1, per-class precision/recall, confusion matrix
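
Accuracy and macro F1 can disagree sharply when a model ignores a class, which is why both are tracked. A toy illustration with scikit-learn (labels invented for the example): a classifier that always predicts the majority class still reaches 67% accuracy, but macro F1 averages in a zero F1 for the class it never predicts:

```python
from sklearn.metrics import accuracy_score, f1_score

y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 0, 0, 0]   # never predicts the minority class

acc = accuracy_score(y_true, y_pred)
# zero_division=0: precision for a never-predicted class counts as 0 instead of warning
macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
print(f"accuracy={acc:.3f}, macro F1={macro_f1:.3f}")   # accuracy=0.667, macro F1=0.400
```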

# 1. Setup

In [45]:
# Core data processing libraries
import numpy as np
import pandas as pd
import json
import os
import re
import ast
from datetime import datetime
from collections import Counter

# Text processing
import string
from typing import List, Dict, Tuple

# PyTorch for deep learning (avoid Keras)
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Scikit-learn for preprocessing and metrics
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report
)

# Set random seeds for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(RANDOM_SEED)

# Display versions
print(f"PyTorch Version: {torch.__version__}")
print(f"NumPy Version: {np.__version__}")
print(f"Pandas Version: {pd.__version__}")
print(f"Using device: {device}")

PyTorch Version: 2.9.1+cpu
NumPy Version: 2.1.3
Pandas Version: 2.2.3
Using device: cpu


# 2. Load Data

In [46]:
# Load clause snippets from CUAD XLSX sheets (label_group_xlsx)
import glob
XLSX_DIR = r"d:\\CodingRelated\\Codes.Ams\\ANNFINAL\\CUAD_v1\\label_group_xlsx"
print(f"Loading XLSX files from: {XLSX_DIR}")

def load_xlsx_snippets(xlsx_dir: str):
    rows = []
    files = glob.glob(os.path.join(xlsx_dir, "*.xlsx"))
    if not files:
        raise FileNotFoundError(f"No .xlsx files found in {xlsx_dir}")
    for path in files:
        df_x = pd.read_excel(path)
        if df_x.empty:
            continue
        clause_cols = [c for c in df_x.columns if c != df_x.columns[0]]
        for _, row in df_x.iterrows():
            for col in clause_cols:
                text = row[col]
                if pd.isna(text):
                    continue
                text = str(text).strip()
                if not text:
                    continue
                rows.append({"context": text, "clause_type": col})
    df_out = pd.DataFrame(rows).drop_duplicates().reset_index(drop=True)
    return df_out, len(files)

df, n_files = load_xlsx_snippets(XLSX_DIR)
print(f"✓ Loaded {len(df)} snippets from {n_files} XLSX files")
print(f"Unique clause types: {df['clause_type'].nunique()}")

Loading XLSX files from: d:\\CodingRelated\\Codes.Ams\\ANNFINAL\\CUAD_v1\\label_group_xlsx
✓ Loaded 8035 snippets from 28 XLSX files
Unique clause types: 47
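
The nested `iterrows` loop in `load_xlsx_snippets` is a wide-to-long reshape, which pandas can also express with `DataFrame.melt`. This is a sketch under the same assumption the loop makes (the first column of each sheet is an identifier, every other column is a clause type); `sheet_to_long` and the toy sheet's column names are hypothetical. Rows come out column by column rather than row by row, and the loader's final `drop_duplicates()` would still apply after concatenating sheets:

```python
import pandas as pd

def sheet_to_long(df_x: pd.DataFrame) -> pd.DataFrame:
    """Wide sheet -> one (context, clause_type) row per non-empty cell."""
    id_col = df_x.columns[0]  # assumed identifier column, skipped like in the original loop
    long_df = (
        df_x.melt(id_vars=id_col, var_name="clause_type", value_name="context")
        .dropna(subset=["context"])
    )
    long_df = long_df.assign(context=long_df["context"].astype(str).str.strip())
    long_df = long_df[long_df["context"] != ""]   # drop whitespace-only cells
    return long_df[["context", "clause_type"]].reset_index(drop=True)

sheet = pd.DataFrame({
    "Filename": ["a.pdf", "b.pdf"],               # invented toy sheet
    "Governing Law": ["Laws of Delaware.", None],
    "Anti-assignment": ["  ", "May not assign."],
})
print(sheet_to_long(sheet))
```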


In [47]:
# Basic dataset overview
print(df.head())
print("\nTop clause counts:")
print(df['clause_type'].value_counts().head(15))

                                             context        clause_type
0  MA may not assign, sell, lease or otherwise tr...    Anti-assignment
1  This Agreement may not be assigned, sold or tr...    Anti-assignment
2  For purposes of the preceding sentence, and wi...  Change of Control
3  Licensee shall not assign or otherwise transfe...    Anti-assignment
4  Licensee shall have the right to assign or sub...    Anti-assignment

Top clause counts:
clause_type
Parties                      505
Parties-Answer               499
Agreement Date               464
Governing Law                435
Agreement Date-Answer        424
Expiration Date              411
Effective Date               384
Anti-assignment              372
Effective Date-Answer        328
Document Name                311
Cap on Liability             275
License Grant                254
Expiration Date-Answer       249
Audit Rights                 214
Post-termination Services    182
Name: count, dtype: int64


In [48]:
# Dataset stats
print(f"Total snippets: {len(df)}")
print(f"Unique clause types: {df['clause_type'].nunique()}")
print(f"Average length (words): {df['context'].apply(lambda x: len(str(x).split())).mean():.1f}")

Total snippets: 8035
Unique clause types: 47
Average length (words): 74.6


In [49]:
# Display first few rows
print("\n" + "="*80)
print("First 5 Rows of Dataset:")
print("="*80)
print(df.head())

# Display basic statistics
print("\n" + "="*80)
print("Dataset Info:")
print("="*80)
df.info()  # prints its summary directly (it returns None, so print() is not needed)


First 5 Rows of Dataset:
                                             context        clause_type
0  MA may not assign, sell, lease or otherwise tr...    Anti-assignment
1  This Agreement may not be assigned, sold or tr...    Anti-assignment
2  For purposes of the preceding sentence, and wi...  Change of Control
3  Licensee shall not assign or otherwise transfe...    Anti-assignment
4  Licensee shall have the right to assign or sub...    Anti-assignment

Dataset Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8035 entries, 0 to 8034
Data columns (total 2 columns):
 #   Column       Non-Null Count  Dtype 
---  ------       --------------  ----- 
 0   context      8035 non-null   object
 1   clause_type  8035 non-null   object
dtypes: object(2)
memory usage: 125.7+ KB


# 3. Data Validation

In [50]:
# Check for missing values
print("Missing values:")
print(df.isnull().sum())

print(f"\nTotal samples: {len(df)}")

Missing values:
context        0
clause_type    0
dtype: int64

Total samples: 8035


In [51]:
# Check class distribution
print("Top 10 clause types:")
print(df['clause_type'].value_counts().head(10))

Top 10 clause types:
clause_type
Parties                  505
Parties-Answer           499
Agreement Date           464
Governing Law            435
Agreement Date-Answer    424
Expiration Date          411
Effective Date           384
Anti-assignment          372
Effective Date-Answer    328
Document Name            311
Name: count, dtype: int64


# 4. Preprocessing

In [52]:
def clean_text(text):
    """Lowercase; replace everything except a-z, whitespace and . , ; : - with spaces (digits included); collapse whitespace."""
    if not isinstance(text, str):
        return ""
    
    text = text.lower()
    text = re.sub(r'[^a-z\s\.,;:\-]', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

# Test
sample = "THIS AGREEMENT is made on January 1, 2020!!!"
print("Before:", sample)
print("After:", clean_text(sample))

Before: THIS AGREEMENT is made on January 1, 2020!!!
After: this agreement is made on january ,
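
Because the character filter keeps only `a-z`, whitespace and `. , ; : -`, every digit and slash is removed. A snippet that is only a numeric date therefore cleans to an empty string, and a written-out date keeps little more than its month name. If the `*-Answer` columns hold short normalized values such as dates (an assumption worth checking on the raw sheets), many of those rows would reach the tokenizer nearly empty, which is one plausible reason the date `-Answer` classes are hard to separate later. A standalone check reusing the same regexes:

```python
import re

def clean_text(text):
    # Same rules as the notebook's clean_text
    if not isinstance(text, str):
        return ""
    text = text.lower()
    text = re.sub(r'[^a-z\s\.,;:\-]', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

for raw in ["2/1/2019", "February 1, 2019", "Delaware"]:
    print(repr(raw), "->", repr(clean_text(raw)))
```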


In [53]:
# Apply cleaning
df['cleaned_text'] = df['context'].apply(clean_text)
print("✓ Cleaned all documents")

✓ Cleaned all documents


# 5. Text Length Analysis

In [54]:
# Use cleaned text directly (clause contexts are already short)
df['sampled_text'] = df['cleaned_text']
print(f"✓ Using {len(df)} clause contexts (no truncation needed)")

✓ Using 8035 clause contexts (no truncation needed)


# 6. Tokenization

In [55]:
class CustomTokenizer:
    """Simple tokenizer - built from scratch"""
    
    def __init__(self, vocab_size=10000):
        self.vocab_size = vocab_size
        self.word_to_index = {"<OOV>": 1}
        self.word_counts = Counter()
        
    def fit_on_texts(self, texts):
        for text in texts:
            self.word_counts.update(str(text).split())
        
        most_common = self.word_counts.most_common(self.vocab_size - 2)
        for idx, (word, _) in enumerate(most_common, start=2):
            self.word_to_index[word] = idx
        
        print(f"Vocabulary size: {len(self.word_to_index)}")
    
    def texts_to_sequences(self, texts):
        sequences = []
        for text in texts:
            seq = [self.word_to_index.get(word, 1) for word in str(text).split()]
            sequences.append(seq)
        return sequences
    
    def get_vocab_size(self):
        return len(self.word_to_index)

# Tokenizer will be built after filtering to top clauses
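
How the index scheme works, on a two-sentence toy corpus with a condensed copy of `CustomTokenizer`: id 0 is never assigned (it is left for padding), id 1 is `<OOV>`, and the `vocab_size - 2` most frequent words take ids from 2 upward, with ties kept in first-seen order. That is why `len(word_to_index)` comes out as `vocab_size - 1` (9,999 for `vocab_size=10000`), and why `LSTMClassifier` sizes its embedding as `vocab_size + 1` (ids 0 through 9,999 need 10,000 rows):

```python
from collections import Counter

class CustomTokenizer:
    """Condensed copy of the notebook's tokenizer (0 = padding, never assigned; 1 = <OOV>)."""
    def __init__(self, vocab_size=10000):
        self.vocab_size = vocab_size
        self.word_to_index = {"<OOV>": 1}
        self.word_counts = Counter()

    def fit_on_texts(self, texts):
        for text in texts:
            self.word_counts.update(str(text).split())
        # vocab_size - 2 word slots, because ids 0 and 1 are already taken
        for idx, (word, _) in enumerate(self.word_counts.most_common(self.vocab_size - 2), start=2):
            self.word_to_index[word] = idx

    def texts_to_sequences(self, texts):
        return [[self.word_to_index.get(w, 1) for w in str(t).split()] for t in texts]

tok = CustomTokenizer(vocab_size=4)   # room for 2 real words
tok.fit_on_texts(["the party shall notify", "the party may terminate"])
print(tok.word_to_index)                                 # {'<OOV>': 1, 'the': 2, 'party': 3}
print(tok.texts_to_sequences(["the party may assign"]))  # [[2, 3, 1, 1]]
```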

# 7. Prepare Data for Training

In [56]:
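def pad_sequences(sequences, maxlen, padding='post', value=0):
    """Pad/truncate sequences to `maxlen`.

    padding='post' keeps the first `maxlen` tokens and pads on the right;
    padding='pre' keeps the last `maxlen` tokens and pads on the left.
    """
    padded = np.full((len(sequences), maxlen), value, dtype=np.int32)  # honour `value`
    for i, seq in enumerate(sequences):
        if len(seq) == 0:
            continue  # empty sequence (text that cleaned to ""): row stays all padding
        if padding == 'post':
            trunc = seq[:maxlen]
            padded[i, :len(trunc)] = trunc
        else:
            trunc = seq[-maxlen:]
            padded[i, -len(trunc):] = trunc
    return padded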

In [57]:
# Select clause types with enough support to stratify
TOP_N = 20
MIN_COUNT = 5
clause_counts = df['clause_type'].value_counts()
filtered_counts = clause_counts[clause_counts >= MIN_COUNT]
top_clauses = filtered_counts.head(TOP_N).index.tolist()
df_filtered = df[df['clause_type'].isin(top_clauses)].copy()

print(f"Using {len(df_filtered)} samples before augmentation")
print(f"Top clause types (min {MIN_COUNT} per class):")
for i, (clause, count) in enumerate(filtered_counts.head(TOP_N).items(), 1):
    print(f"  {i}. {clause[:80]}... ({count} samples)")

ENABLE_AUGMENTATION = True
TARGET_MIN_PER_CLASS = 30  # desired minimum rows per class after augmentation
REPLACE_PROB = 0.25        # probability of replacing a token with a synonym
MAX_AUG_PER_CLASS = 80     # cap on synthetic rows added per class

if ENABLE_AUGMENTATION:
    import random
    random.seed(RANDOM_SEED)  # make synonym sampling reproducible
    try:
        import nltk
        from nltk.corpus import wordnet as wn
        nltk.download('wordnet', quiet=True)
        nltk.download('omw-1.4', quiet=True)
    except Exception as e:
        wn = None
        print(f"NLTK/wordnet not available, skipping synonym augmentation: {e}")

    def get_synonyms(word):
        if wn is None:
            return []
        syns = set()
        for syn in wn.synsets(word):
            for lemma in syn.lemmas():
                candidate = lemma.name().replace('_', ' ').lower()
                if candidate.isalpha() and candidate != word.lower():
                    syns.add(candidate)
        return list(syns)

    def synonym_replace(text, replace_prob=0.2):
        tokens = str(text).split()
        new_tokens = []
        for tok in tokens:
            if random.random() < replace_prob:
                syns = get_synonyms(tok)
                if syns:
                    new_tokens.append(random.choice(syns))
                    continue
            new_tokens.append(tok)
        return " ".join(new_tokens)

    aug_rows = []
    for label, group in df_filtered.groupby('clause_type'):
        current_count = len(group)
        if current_count >= TARGET_MIN_PER_CLASS:
            continue
        needed = min(TARGET_MIN_PER_CLASS - current_count, MAX_AUG_PER_CLASS)
        pool = group['sampled_text'].tolist()
        for i in range(needed):
            base_text = pool[i % len(pool)]
            aug_text = synonym_replace(base_text, replace_prob=REPLACE_PROB)
            aug_rows.append({
                'context': aug_text,
                'clause_type': label,
                'cleaned_text': aug_text,
                'sampled_text': aug_text,
            })

    if aug_rows:
        df_aug = pd.DataFrame(aug_rows)
        df_filtered = pd.concat([df_filtered, df_aug], ignore_index=True)
        print(f"Applied augmentation: +{len(aug_rows)} synthetic rows")
    else:
        print("No augmentation applied (all classes already above target or wordnet unavailable)")

print(f"Total samples after augmentation: {len(df_filtered)}")

# Build the tokenizer on the filtered (and possibly augmented) data, before the train/val/test split;
# vocab_size=10000 keeps the 9,998 most frequent words plus <OOV>
tokenizer = CustomTokenizer(vocab_size=10000)
tokenizer.fit_on_texts(df_filtered['sampled_text'])

# Tokenize filtered data
sequences_filtered = tokenizer.texts_to_sequences(df_filtered['sampled_text'])

# Length stats and padding length
sequence_lengths = [len(seq) for seq in sequences_filtered]
percentile_len = int(np.percentile(sequence_lengths, 85))
MAX_LENGTH = min(percentile_len, 160)
print(f"Sequence length percentile(85th): {percentile_len}")
print(f"Max sequence length used: {MAX_LENGTH} (capped at 160)")

# Pad filtered sequences
X_filtered = pad_sequences(sequences_filtered, maxlen=MAX_LENGTH, padding='post')
print(f"Padded shape (filtered): {X_filtered.shape}")

Using 6175 samples before augmentation
Top clause types (min 5 per class):
  1. Parties... (505 samples)
  2. Parties-Answer... (499 samples)
  3. Agreement Date... (464 samples)
  4. Governing Law... (435 samples)
  5. Agreement Date-Answer... (424 samples)
  6. Expiration Date... (411 samples)
  7. Effective Date... (384 samples)
  8. Anti-assignment... (372 samples)
  9. Effective Date-Answer... (328 samples)
  10. Document Name... (311 samples)
  11. Cap on Liability... (275 samples)
  12. License Grant... (254 samples)
  13. Expiration Date-Answer... (249 samples)
  14. Audit Rights... (214 samples)
  15. Post-termination Services... (182 samples)
  16. Termination for Convenience... (181 samples)
  17. Exclusivity... (180 samples)
  18. Renewal Term... (175 samples)
  19. Revenue-Profit Sharing... (166 samples)
  20. Insurance... (166 samples)
No augmentation applied (all classes already above target or wordnet unavailable)
Total samples after augmentation: 6175
Vocabulary size: 
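
In this run no class fell below `TARGET_MIN_PER_CLASS`, so augmentation was a no-op. If it were triggered, note that it runs before the train/val/test split, so a synthetic paraphrase and its source sentence could land in different splits and inflate test scores. A sketch of the safer ordering (split first, augment only the training fold), with a small hypothetical `SYNONYMS` table standing in for WordNet and `replace_prob=1.0` so the effect is easy to see:

```python
import random

import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical synonym table standing in for WordNet lookups
SYNONYMS = {"party": ["signatory"], "terminate": ["end"], "agreement": ["contract"]}

def synonym_replace(text, replace_prob, rng):
    # Same control flow as the notebook's synonym_replace, with a seeded RNG
    new_tokens = []
    for tok in str(text).split():
        if rng.random() < replace_prob:
            syns = SYNONYMS.get(tok, [])
            if syns:
                new_tokens.append(rng.choice(syns))
                continue
        new_tokens.append(tok)
    return " ".join(new_tokens)

toy = pd.DataFrame({
    "sampled_text": [f"either party may terminate this agreement {i}" for i in range(10)],
    "clause_type": ["Termination"] * 5 + ["Other"] * 5,
})

# 1) Split first ...
train_df, test_df = train_test_split(
    toy, test_size=0.2, random_state=42, stratify=toy["clause_type"]
)

# 2) ... then augment only the training rows
rng = random.Random(42)
aug = train_df.assign(
    sampled_text=[synonym_replace(t, 1.0, rng) for t in train_df["sampled_text"]]
)
train_aug = pd.concat([train_df, aug], ignore_index=True)
print(train_aug.tail(3))
```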

In [58]:
# Diagnostic: OOV rate on filtered sequences
# OOV token id is 1 in the tokenizer
all_tokens = sum(len(seq) for seq in sequences_filtered)
oov_tokens = sum(sum(1 for t in seq if t == 1) for seq in sequences_filtered)
oov_pct = 100 * oov_tokens / max(1, all_tokens)
print(f"OOV tokens: {oov_tokens} / {all_tokens} ({oov_pct:.2f}%)")

OOV tokens: 1746 / 349594 (0.50%)


In [59]:
# Encode labels after filtering
df_filtered = df_filtered.reset_index(drop=True)
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(df_filtered['clause_type'])
num_classes = len(label_encoder.classes_)
print(f"Labels shape: {y_encoded.shape}")
print(f"Classes: {label_encoder.classes_}")

Labels shape: (6175,)
Classes: ['Agreement Date' 'Agreement Date-Answer' 'Anti-assignment' 'Audit Rights'
 'Cap on Liability' 'Document Name' 'Effective Date'
 'Effective Date-Answer' 'Exclusivity' 'Expiration Date'
 'Expiration Date-Answer' 'Governing Law' 'Insurance' 'License Grant'
 'Parties' 'Parties-Answer' 'Post-termination Services' 'Renewal Term'
 'Revenue-Profit Sharing' 'Termination for Convenience']


In [60]:
# Clause counts summary (safe to run after data/filter cells)
if 'df' not in globals() or 'top_clauses' not in globals():
    print("Please run the data load and filtering cells first.")
else:
    print(f"Total clause types loaded: {df['clause_type'].nunique()}")
    print(f"Clause types kept after filtering: {len(top_clauses)}")
    print("Kept clause types:", top_clauses)

Total clause types loaded: 47
Clause types kept after filtering: 20
Kept clause types: ['Parties', 'Parties-Answer', 'Agreement Date', 'Governing Law', 'Agreement Date-Answer', 'Expiration Date', 'Effective Date', 'Anti-assignment', 'Effective Date-Answer', 'Document Name', 'Cap on Liability', 'License Grant', 'Expiration Date-Answer', 'Audit Rights', 'Post-termination Services', 'Termination for Convenience', 'Exclusivity', 'Renewal Term', 'Revenue-Profit Sharing', 'Insurance']


In [61]:
# TF-IDF + Logistic Regression baseline (quick sanity check)
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

texts = df_filtered['sampled_text'].astype(str).tolist()
labels = y_encoded

print('Building TF-IDF matrix...')
vect = TfidfVectorizer(max_features=20000, ngram_range=(1,2))
X_tfidf = vect.fit_transform(texts)

# Split and train a simple linear classifier
X_tr, X_te, y_tr, y_te = train_test_split(X_tfidf, labels, test_size=0.30, random_state=42, stratify=labels)
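# multi_class omitted: it is deprecated in recent scikit-learn, and lbfgs already fits a multinomial model for multiclass targets
clf = LogisticRegression(max_iter=2000, solver='lbfgs')
clf.fit(X_tr, y_tr)
acc = clf.score(X_te, y_te)
print(f"TF-IDF Logistic accuracy (test): {acc:.4f}")

# Print detailed per-class report (zero_division=0: never-predicted classes get precision 0 without a warning)
y_pred = clf.predict(X_te)
print('\nClassification report:')
print(classification_report(y_te, y_pred, digits=4, zero_division=0))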

Building TF-IDF matrix...




TF-IDF Logistic accuracy (test): 0.7890

Classification report:
              precision    recall  f1-score   support

           0     0.6150    0.8849    0.7257       139
           1     0.3884    1.0000    0.5595       127
           2     0.9813    0.9375    0.9589       112
           3     0.9538    0.9688    0.9612        64
           4     0.9630    0.9398    0.9512        83
           5     1.0000    0.9140    0.9551        93
           6     0.5405    0.1739    0.2632       115
           7     0.0000    0.0000    0.0000        98
           8     0.8810    0.6852    0.7708        54
           9     0.7315    0.8862    0.8015       123
          10     1.0000    0.0133    0.0263        75
          11     1.0000    0.9847    0.9923       131
          12     1.0000    0.9600    0.9796        50
          13     0.7978    0.9342    0.8606        76
          14     1.0000    0.9934    0.9967       152
          15     0.9660    0.9467    0.9562       150
          16     

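
The baseline fits `TfidfVectorizer` on all texts before splitting, so its IDF statistics already include the test rows, and it uses its own 70/30 split rather than the 70/15/15 split used for the LSTM. For a sanity check this probably changes little, but wrapping vectorizer and classifier in a scikit-learn `Pipeline` keeps the vectorizer fit on training data only. A toy sketch (texts and labels invented for the example):

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

train_texts = [
    "governing law state of delaware",
    "governed by the laws of new york",
    "may not assign this agreement",
    "shall not assign or transfer",
]
train_labels = ["Governing Law", "Governing Law", "Anti-assignment", "Anti-assignment"]

# fit() runs the vectorizer on train_texts only; predict() reuses that vocabulary and IDF
baseline = make_pipeline(TfidfVectorizer(ngram_range=(1, 2)), LogisticRegression(max_iter=2000))
baseline.fit(train_texts, train_labels)
print(baseline.predict(["the laws of delaware govern"]))
```

The same `baseline` object can be passed straight to `train_test_split` outputs or to `cross_val_score`, and the vectorizer is refit inside each training fold.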


# 8. Train/Val/Test Split

In [62]:
# Split data: 70% train, 15% val, 15% test
X_train, X_temp, y_train, y_temp = train_test_split(
    X_filtered, y_encoded, test_size=0.30, random_state=42, stratify=y_encoded
)

X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.50, random_state=42, stratify=y_temp
)

print(f"Train: {X_train.shape}")
print(f"Val: {X_val.shape}")
print(f"Test: {X_test.shape}")

# Class weights to handle imbalance (toggle with USE_CLASS_WEIGHTS)
class_counts = np.bincount(y_train, minlength=num_classes)
class_weights = 1.0 / (class_counts + 1e-6)
class_weights = class_weights * (num_classes / class_weights.sum())
print("Class counts:", class_counts)
print("Class weights (normalized):", class_weights)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)
USE_CLASS_WEIGHTS = True
USE_SAMPLER = True

class ClauseDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.long)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_dataset = ClauseDataset(X_train, y_train)
val_dataset = ClauseDataset(X_val, y_val)
test_dataset = ClauseDataset(X_test, y_test)


Train: (4322, 99)
Val: (926, 99)
Test: (927, 99)
Class counts: [325 297 260 150 192 218 269 230 126 288 174 304 116 178 353 349 127 123
 116 127]
Class weights (normalized): [0.57139558 0.62526453 0.71424448 1.23802376 0.96720607 0.85185121
 0.69034783 0.8074068  1.47383781 0.64480405 1.06726187 0.61086699
 1.60089279 1.04327845 0.52607242 0.53210191 1.46223279 1.50978507
 1.60089279 1.46223279]
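
A worked example of the weighting formula with two toy classes of 2 and 6 training rows: inverse counts are rescaled so the weights sum to the number of classes. The same weights also drive `WeightedRandomSampler`, which on its own already gives each class an equal expected share of sampled rows. With both `USE_CLASS_WEIGHTS` and `USE_SAMPLER` set to `True`, the loss re-weights those already balanced batches, so in this toy case the rare class ends up carrying three times the loss weight of the common one:

```python
import numpy as np

class_counts = np.array([2, 6])                 # toy per-class training counts
num_classes = len(class_counts)

# Same formula as the notebook: inverse frequency, rescaled to sum to num_classes
class_weights = 1.0 / (class_counts + 1e-6)
class_weights = class_weights * (num_classes / class_weights.sum())
print("class weights:", class_weights.round(3))            # ~[1.5, 0.5]

# WeightedRandomSampler gives every row its class weight, so each class's
# expected share of a sampled epoch is count * weight / total
sample_share = class_counts * class_weights / (class_counts * class_weights).sum()
print("expected sample share:", sample_share.round(3))     # ~[0.5, 0.5]

# A class-weighted loss on top of balanced sampling applies the correction again
loss_share = sample_share * class_weights / (sample_share * class_weights).sum()
print("share of total loss weight:", loss_share.round(3))  # ~[0.75, 0.25]
```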


# 9. Build Model

In [63]:
class LSTMClassifier(nn.Module):
    """Bidirectional stacked LSTM with attention for clause classification"""
    def __init__(self, vocab_size, embed_dim=200, lstm_1=128, lstm_2=96, dropout=0.25, num_classes=10):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size + 1, embed_dim, padding_idx=0)
        self.lstm1 = nn.LSTM(embed_dim, lstm_1, batch_first=True, bidirectional=True)
        self.dropout1 = nn.Dropout(dropout)
        self.lstm2 = nn.LSTM(lstm_1 * 2, lstm_2, batch_first=True, bidirectional=True)
        self.attn = nn.Linear(lstm_2 * 2, 1)
        self.dropout2 = nn.Dropout(dropout)
        self.fc = nn.Linear(lstm_2 * 2, num_classes)
    
    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm1(x)
        x = self.dropout1(x)
        x, _ = self.lstm2(x)
        scores = torch.tanh(self.attn(x))
        weights = torch.softmax(scores, dim=1)
        context = (x * weights).sum(dim=1)
        context = self.dropout2(context)
        return self.fc(context)

VOCAB_SIZE = len(tokenizer.word_to_index)
NUM_CLASSES = num_classes
print(f"Vocab: {VOCAB_SIZE}, Classes: {NUM_CLASSES}, Max length: {MAX_LENGTH}")

Vocab: 9999, Classes: 20, Max length: 99
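
`LSTMClassifier` takes the attention softmax over all `MAX_LENGTH` positions, so with post-padding the pad tokens (id 0) also receive weight, and a clause that cleaned to an empty string is pure padding. Below is a sketch of masked attention pooling that keeps the same `tanh` scoring but gives padding positions zero weight; rows with no real tokens fall back to uniform weights to avoid a softmax over all `-inf`. `MaskedAttentionPool` is a hypothetical module, not what produced the results below; using it would mean passing the token ids to the pooling step and retraining. `torch.nn.utils.rnn.pack_padded_sequence` is the usual complement for keeping the LSTMs themselves from reading padding:

```python
import torch
import torch.nn as nn

class MaskedAttentionPool(nn.Module):
    """tanh-scored attention pooling that gives padding positions (id 0) zero weight."""
    def __init__(self, hidden_dim):
        super().__init__()
        self.attn = nn.Linear(hidden_dim, 1)

    def forward(self, h, token_ids):
        # h: (batch, seq_len, hidden_dim); token_ids: (batch, seq_len)
        scores = torch.tanh(self.attn(h)).squeeze(-1)     # (batch, seq_len)
        mask = token_ids != 0                             # True for real tokens
        scores = scores.masked_fill(~mask, float("-inf"))
        # An all-padding row would be a softmax over all -inf (NaN); use uniform weights instead
        all_pad = ~mask.any(dim=1, keepdim=True)          # (batch, 1)
        scores = scores.masked_fill(all_pad, 0.0)
        weights = torch.softmax(scores, dim=1)            # (batch, seq_len)
        context = (h * weights.unsqueeze(-1)).sum(dim=1)  # (batch, hidden_dim)
        return context, weights

torch.manual_seed(0)
pool = MaskedAttentionPool(hidden_dim=8)
token_ids = torch.tensor([[4, 9, 2, 0, 0],    # 3 real tokens + 2 pads
                          [0, 0, 0, 0, 0]])   # clause that cleaned to nothing
h = torch.randn(2, 5, 8)                      # stand-in for BiLSTM outputs
context, weights = pool(h, token_ids)
print(weights)
```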


# 10. Hyperparameter Tuning Setup

Testing different optimizers as required by the course.

In [64]:
# Configurations to test - tuned for faster convergence with attention
configs = [
    {'opt': 'Adam',    'lr': 0.0008, 'wd': 1e-4, 'batch': 64, 'epochs': 5},
    {'opt': 'Adam',    'lr': 0.0010, 'wd': 1e-4, 'batch': 64, 'epochs': 10},
    {'opt': 'Adam',    'lr': 0.0005, 'wd': 1e-4, 'batch': 64, 'epochs': 5},
    {'opt': 'RMSprop', 'lr': 0.0008, 'wd': 0.0,  'batch': 64, 'epochs': 10},
    {'opt': 'RMSprop', 'lr': 0.0005, 'wd': 0.0,  'batch': 64, 'epochs': 5},
]

print(f"Will test {len(configs)} configurations")

Will test 5 configurations


# 11. Training

In [66]:
def run_epoch(model, loader, criterion, optimizer=None):
    """One pass over `loader`: trains if an optimizer is given, otherwise evaluates."""
    is_train = optimizer is not None
    model.train(is_train)  # train(False) is equivalent to eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    # Disable autograd bookkeeping during evaluation
    with torch.set_grad_enabled(is_train):
        for batch_idx, (X_batch, y_batch) in enumerate(loader):
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            if is_train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # guard against exploding LSTM gradients
                optimizer.step()
            total_loss += loss.item() * X_batch.size(0)
            preds = torch.argmax(outputs, dim=1)
            total_correct += (preds == y_batch).sum().item()
            total_samples += X_batch.size(0)

            # Progress indicator every 50 batches
            if is_train and batch_idx % 50 == 0:
                print(f"  Batch {batch_idx}/{len(loader)}", end='\r')

    avg_loss = total_loss / total_samples
    avg_acc = total_correct / total_samples
    return avg_loss, avg_acc

def save_model_as_h5(model, filepath):
    """Save PyTorch model weights to HDF5 format"""
    import h5py
    state_dict = model.state_dict()
    with h5py.File(filepath, 'w') as f:
        for key, value in state_dict.items():
            f.create_dataset(key, data=value.cpu().numpy())

results = []
models_dir = r'd:\CodingRelated\Codes.Ams\ANNFINAL\trained_models_run5'
os.makedirs(models_dir, exist_ok=True)

for i, cfg in enumerate(configs, 1):
    print(f"\n{'='*60}")
    print(f"Config {i}/{len(configs)}: {cfg['opt']}, LR={cfg['lr']}, WD={cfg['wd']}")
    print('='*60)
    
    model = LSTMClassifier(VOCAB_SIZE, embed_dim=200, num_classes=NUM_CLASSES).to(device)
    print(f"Model created, starting training...")
    
    if cfg['opt'] == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=cfg['lr'], weight_decay=cfg.get('wd', 0.0))
    elif cfg['opt'] == 'RMSprop':
        optimizer = optim.RMSprop(model.parameters(), lr=cfg['lr'], weight_decay=cfg.get('wd', 0.0))
    else:
        optimizer = optim.SGD(model.parameters(), lr=cfg['lr'], momentum=0.9, weight_decay=cfg.get('wd', 0.0))
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor if USE_CLASS_WEIGHTS else None)
    
    if USE_SAMPLER:
        sample_weights = class_weights_tensor.cpu().numpy()[y_train]
        train_sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
        train_loader = DataLoader(train_dataset, batch_size=cfg['batch'], sampler=train_sampler)
    else:
        train_loader = DataLoader(train_dataset, batch_size=cfg['batch'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg['batch'], shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=cfg['batch'], shuffle=False)
    
    print(f"Training batches: {len(train_loader)}, Val batches: {len(val_loader)}")
    
    # Early stopping
    best_val_loss = float('inf')
    patience_counter = 0
    patience = 6
    
    for epoch in range(cfg['epochs']):
        train_loss, train_acc = run_epoch(model, train_loader, criterion, optimizer)
        val_loss, val_acc = run_epoch(model, val_loader, criterion, optimizer=None)
        scheduler.step(val_loss)
        print(f"Epoch {epoch+1}/{cfg['epochs']} - Train loss {train_loss:.4f}, acc {train_acc:.4f} | Val loss {val_loss:.4f}, acc {val_acc:.4f}")
        
        # Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break
    
    # Quick val prediction distribution
    model.eval()
    with torch.no_grad():
        all_val_preds = []
        for Xb, _ in val_loader:
            Xb = Xb.to(device)
            preds = model(Xb).argmax(dim=1).cpu().numpy()
            all_val_preds.extend(preds)
    pred_dist = Counter(all_val_preds)
    print(f"Val pred distribution: {pred_dist}")
    
    # Evaluate
    test_loss, test_acc = run_epoch(model, test_loader, criterion, optimizer=None)
    results.append({
        'config': i,
        'optimizer': cfg['opt'],
        'lr': cfg['lr'],
        'wd': cfg.get('wd', 0.0),
        'batch_size': cfg['batch'],
        'train_acc': train_acc,
        'val_acc': val_acc,
        'test_acc': test_acc
    })
    print(f"Test accuracy: {test_acc:.4f}")
    
    # Save model in both PyTorch (.pt) and HDF5 (.h5) formats
    pt_path = os.path.join(models_dir, f'model_{i}.pt')
    h5_path = os.path.join(models_dir, f'model_{i}.h5')
    torch.save(model.state_dict(), pt_path)
    save_model_as_h5(model, h5_path)
    print(f"Saved: {pt_path} and {h5_path}")
    
    del model
    torch.cuda.empty_cache()

print("\n✓ Training complete!")


Config 1/5: Adam, LR=0.0008, WD=0.0001
Model created, starting training...
Training batches: 68, Val batches: 15
Epoch 1/5 - Train loss 2.4523, acc 0.1564 | Val loss 1.9190, acc 0.2127
Epoch 2/5 - Train loss 1.3793, acc 0.4472 | Val loss 1.3115, acc 0.4708
Epoch 3/5 - Train loss 0.8800, acc 0.6333 | Val loss 0.8761, acc 0.6339
Epoch 4/5 - Train loss 0.5792, acc 0.7411 | Val loss 0.6638, acc 0.7322
Epoch 5/5 - Train loss 0.4481, acc 0.7760 | Val loss 0.6295, acc 0.7149
Val pred distribution: Counter({np.int64(7): 152, np.int64(6): 105, np.int64(15): 81, np.int64(14): 75, np.int64(11): 65, np.int64(9): 58, np.int64(2): 53, np.int64(5): 41, np.int64(4): 40, np.int64(3): 39, np.int64(8): 37, np.int64(16): 32, np.int64(19): 32, np.int64(18): 30, np.int64(13): 29, np.int64(12): 23, np.int64(17): 21, np.int64(0): 13})
Test accuracy: 0.7001
Saved: d:\CodingRelated\Codes.Ams\ANNFINAL\trained_models_run5\model_1.pt and d:\CodingRelated\Codes.Ams\ANNFINAL\trained_models_run5\model_1.h5

Config 2
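
Two details of the early-stopping logic above: `patience = 6` can never trigger within the 5-epoch configurations, and although `best_val_loss` is tracked, the weights that are tested and saved are those from the final epoch rather than the best one. A minimal sketch of restoring the best checkpoint, using a stand-in `nn.Linear` model and hypothetical validation losses in place of real training:

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)                       # stand-in for LSTMClassifier
initial_weight = model.weight.detach().clone()
val_losses = [1.20, 0.85, 0.90, 0.95]         # hypothetical per-epoch validation losses

best_val_loss, best_epoch, best_state = float("inf"), -1, None
for epoch, val_loss in enumerate(val_losses):
    with torch.no_grad():
        model.weight.add_(0.1)                # stand-in for one epoch of updates
    if val_loss < best_val_loss:
        best_val_loss, best_epoch = val_loss, epoch
        # deepcopy: state_dict() returns references that later updates would overwrite
        best_state = copy.deepcopy(model.state_dict())

# Roll back to the best epoch before evaluating on the test set and saving
model.load_state_dict(best_state)
print(f"restored epoch {best_epoch + 1} (val loss {best_val_loss:.2f})")
```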

# 12. Results

In [67]:
# Save results
results_df = pd.DataFrame(results)
results_df.to_csv(r'd:\CodingRelated\Codes.Ams\ANNFINAL\experiment_results_run5.csv', index=False)

print("All Results:")
print(results_df)

All Results:
   config optimizer      lr      wd  batch_size  train_acc   val_acc  test_acc
0       1      Adam  0.0008  0.0001          64   0.776030  0.714903  0.700108
1       2      Adam  0.0010  0.0001          64   0.835956  0.762419  0.723840
2       3      Adam  0.0005  0.0001          64   0.713559  0.671706  0.652643
3       4   RMSprop  0.0008  0.0000          64   0.839658  0.754860  0.743258
4       5   RMSprop  0.0005  0.0000          64   0.801249  0.734341  0.718447


In [68]:
# Best model
best_idx = results_df['test_acc'].idxmax()
best = results_df.iloc[best_idx]

print("="*60)
print("BEST MODEL")
print("="*60)
print(f"Optimizer: {best['optimizer']}")
print(f"Learning Rate: {best['lr']}")
print(f"Test Accuracy: {best['test_acc']:.2%}")

if best['test_acc'] >= 0.50:
    print("\n✓ Meets 50% requirement!")
else:
    print("\n✗ Below 50%")

best_model_path = os.path.join(models_dir, f"model_{best_idx + 1}.pt")

BEST MODEL
Optimizer: RMSprop
Learning Rate: 0.0008
Test Accuracy: 74.33%

✓ Meets 50% requirement!
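
`best_idx` is chosen by `test_acc`, which lets the test set take part in model selection, so the 74.33% figure is somewhat optimistic. Selecting on validation accuracy keeps the test set untouched until the final report. Applied to the accuracies in the results table above, that rule would pick config 2 (Adam, LR 0.001), whose test accuracy is 72.38%, still above the 50% target:

```python
import pandas as pd

# Accuracies copied from the results table above
results_df = pd.DataFrame({
    "config":    [1, 2, 3, 4, 5],
    "optimizer": ["Adam", "Adam", "Adam", "RMSprop", "RMSprop"],
    "lr":        [0.0008, 0.0010, 0.0005, 0.0008, 0.0005],
    "val_acc":   [0.714903, 0.762419, 0.671706, 0.754860, 0.734341],
    "test_acc":  [0.700108, 0.723840, 0.652643, 0.743258, 0.718447],
})

# Select on validation accuracy; read test accuracy only for the chosen config
best = results_df.loc[results_df["val_acc"].idxmax()]
print(f"Config {best['config']}: {best['optimizer']}, LR={best['lr']}, "
      f"val {best['val_acc']:.2%}, test {best['test_acc']:.2%}")
```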


In [69]:
# Artifact paths for this run
ARTIFACTS_DIR = r'd:\CodingRelated\Codes.Ams\ANNFINAL\artifacts_run5'
os.makedirs(ARTIFACTS_DIR, exist_ok=True)

# Persist tokenizer and label encoder classes
with open(os.path.join(ARTIFACTS_DIR, 'tokenizer_word_index.json'), 'w', encoding='utf-8') as f:
    json.dump(tokenizer.word_to_index, f)
np.save(os.path.join(ARTIFACTS_DIR, 'label_classes.npy'), label_encoder.classes_)

print(f"Artifacts directory: {ARTIFACTS_DIR}")

Artifacts directory: d:\CodingRelated\Codes.Ams\ANNFINAL\artifacts_run5
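
These artifacts cover the word index and the label classes, but not `MAX_LENGTH` (99 in this run) or the model's layer sizes, which inference also needs. A sketch of encoding a new clause from saved artifacts, assuming a hypothetical `config.json` that stores the padding length; `encode_clause` is a hypothetical helper that repeats the notebook's cleaning, OOV (id 1) and post-padding rules, and the toy vocabulary is invented:

```python
import json
import os
import re
import tempfile

import numpy as np

def clean_text(text):
    # Same cleaning rules as the notebook
    text = str(text).lower()
    text = re.sub(r'[^a-z\s\.,;:\-]', ' ', text)
    return re.sub(r'\s+', ' ', text).strip()

def encode_clause(text, word_to_index, max_length):
    """Hypothetical helper: clean, map words to ids (1 = <OOV>), truncate/pad at the end with 0."""
    ids = [word_to_index.get(w, 1) for w in clean_text(text).split()][:max_length]
    row = np.zeros((1, max_length), dtype=np.int64)   # int64 matches torch.long
    row[0, :len(ids)] = ids
    return row

with tempfile.TemporaryDirectory() as artifacts_dir:
    # Toy stand-ins for tokenizer_word_index.json and a hypothetical config.json
    with open(os.path.join(artifacts_dir, "tokenizer_word_index.json"), "w", encoding="utf-8") as f:
        json.dump({"<OOV>": 1, "this": 2, "agreement": 3, "governed": 4}, f)
    with open(os.path.join(artifacts_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"max_length": 6}, f)

    # What an inference script would do: reload both, then encode
    with open(os.path.join(artifacts_dir, "tokenizer_word_index.json"), encoding="utf-8") as f:
        word_to_index = json.load(f)
    with open(os.path.join(artifacts_dir, "config.json"), encoding="utf-8") as f:
        max_length = json.load(f)["max_length"]

x = encode_clause("This Agreement is governed by Delaware law.", word_to_index, max_length)
print(x)  # [[2 3 1 4 1 1]]
```

The resulting array can be wrapped with `torch.tensor(x)` and passed to the loaded model; the argmax index then maps back to a clause name through `label_classes.npy`.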


# 13. Model Evaluation

In [70]:
# Load best model (match training embed_dim)
best_model = LSTMClassifier(VOCAB_SIZE, embed_dim=200, num_classes=NUM_CLASSES).to(device)
best_model.load_state_dict(torch.load(best_model_path, map_location=device))
best_model.eval()

# Get predictions
X_test_tensor = torch.tensor(X_test, dtype=torch.long).to(device)
with torch.no_grad():
    y_pred = best_model(X_test_tensor).cpu().numpy()

y_pred_classes = np.argmax(y_pred, axis=1)
y_true_classes = y_test

print(f"Loaded best model from: model_{best_idx + 1}.pt")

Loaded best model from: model_4.pt


In [71]:
# Confusion matrix - save and print
cm = confusion_matrix(y_true_classes, y_pred_classes)
cm_df = pd.DataFrame(cm, index=label_encoder.classes_, columns=label_encoder.classes_)

cm_path = os.path.join(ARTIFACTS_DIR, 'confusion_matrix.csv')
cm_df.to_csv(cm_path)

print("\nConfusion Matrix:")
print(cm)
print(f"\nSaved confusion matrix to: {cm_path}")
print(f"\nAccuracy per class:")
for i, class_name in enumerate(label_encoder.classes_):
    class_acc = cm[i, i] / cm[i].sum() if cm[i].sum() > 0 else 0
    print(f"{class_name}: {class_acc:.2%}")


Confusion Matrix:
[[55  0  0  0  0  1 14  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0 64  0  0  0  0  0  0  0  0  0]
 [ 0  0 52  0  0  0  0  0  0  0  0  0  0  0  0  0  2  0  0  2]
 [ 0  0  0 31  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0]
 [ 0  0  1  0 34  0  0  0  2  0  0  0  0  0  0  0  3  0  1  0]
 [ 1  0  0  0  0 43  1  0  0  0  0  0  0  0  0  1  0  0  0  0]
 [34  0  0  0  0  0 20  0  0  4  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0 49  0  0  0  0  0  0  0  0  0]
 [ 0  0  1  0  0  0  0  0 22  0  0  0  0  4  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  5  0  0 44  0  0  0  1  0  0  1  8  0  3]
 [ 0  0  0  0  0  0  0  0  0  0 37  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0 66  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0 23  0  0  0  1  0  1  0]
 [ 0  0  0  0  1  0  0  0  9  0  0  0  0 27  0  0  1  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0 76  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  
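
The "accuracy per class" printed by this cell is the diagonal count divided by the row total, which is exactly per-class recall. A row whose samples all land in another column (as `Agreement Date-Answer` does, with all 64 going to column 10) therefore scores 0%. A small check against scikit-learn's `recall_score` on invented labels:

```python
import numpy as np
from sklearn.metrics import confusion_matrix, recall_score

y_true = [0, 0, 0, 1, 1, 2, 2, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2, 2, 0, 2]

cm = confusion_matrix(y_true, y_pred)
per_class_acc = np.diag(cm) / cm.sum(axis=1)        # correct / all true samples of each class
recall = recall_score(y_true, y_pred, average=None)  # one recall value per class
print(per_class_acc)
print(recall)
```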

In [72]:
# Classification report - save to artifacts
print("\nClassification Report:")
report = classification_report(
    y_true_classes,
    y_pred_classes,
    target_names=label_encoder.classes_,
    output_dict=True,
    zero_division=0,
)
report_df = pd.DataFrame(report).T
print(report_df)

report_path = os.path.join(ARTIFACTS_DIR, 'classification_report.csv')
report_df.to_csv(report_path)
print(f"\nSaved classification report to: {report_path}")


Classification Report:
                             precision    recall  f1-score     support
Agreement Date                0.611111  0.785714  0.687500   70.000000
Agreement Date-Answer         0.000000  0.000000  0.000000   64.000000
Anti-assignment               0.945455  0.928571  0.936937   56.000000
Audit Rights                  0.911765  0.968750  0.939394   32.000000
Cap on Liability              0.971429  0.829268  0.894737   41.000000
Document Name                 0.977273  0.934783  0.955556   46.000000
Effective Date                0.500000  0.344828  0.408163   58.000000
Effective Date-Answer         0.000000  0.000000  0.000000   49.000000
Exclusivity                   0.578947  0.814815  0.676923   27.000000
Expiration Date               0.846154  0.709677  0.771930   62.000000
Expiration Date-Answer        0.246667  1.000000  0.395722   37.000000
Governing Law                 1.000000  1.000000  1.000000   66.000000
Insurance                     1.000000  0.920000  0.9