# 🧠 Clause Classification Training
Fine-tune LegalBERT as a multi-label clause classifier on clauses extracted from the CUAD dataset (`CUADv1.json`).

In [37]:
import numpy
import transformers
import datasets

# Report the installed version of each core library so runs can be reproduced.
for tag, lib in (("✅ NumPy:", numpy), ("✅ Transformers:", transformers), ("✅ Datasets:", datasets)):
    print(tag, lib.__version__)


✅ NumPy: 2.0.2
✅ Transformers: 4.53.0
✅ Datasets: 2.14.4


In [38]:
import json
import pandas as pd
from collections import defaultdict

# ✅ Step 1: Load CUAD and extract clause → cleaned labels mapping
def load_multi_label_clauses(json_path):
    """Build a clause -> labels table from a SQuAD-style CUAD JSON file.

    Every question whose text contains "related to" contributes a label: the
    lower-cased text between "related to" and the next "that", with double
    quotes removed. Each non-empty answer text of that question is treated as
    a clause carrying that label; identical clause texts are merged so one
    clause can carry several labels.

    Args:
        json_path: Path to a JSON file shaped like
            {"data": [{"paragraphs": [{"qas": [{"question": ..., "answers": [...]}]}]}]}.

    Returns:
        pandas.DataFrame with columns "clause_text" (str) and "labels"
        (alphabetically sorted list of str). The columns are present even
        when no clause was found.
    """
    # Explicit encoding so the read does not depend on the platform default.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    clause_label_map = defaultdict(set)

    for item in data["data"]:
        for para in item["paragraphs"]:
            for qa in para["qas"]:
                # 🔧 Extract and clean label
                raw_label = qa["question"].lower().strip()
                if "related to" not in raw_label:
                    continue
                label = raw_label.split("related to")[1].split("that")[0].strip()
                label = label.replace('"', '').strip()  # remove quotes
                for answer in qa["answers"]:
                    text = answer.get("text", "").strip()
                    if text:
                        clause_label_map[text].add(label)

    # Sort each label set: set iteration order for strings varies between
    # interpreter runs, which would make the saved CSV non-reproducible.
    rows = [{"clause_text": k, "labels": sorted(v)} for k, v in clause_label_map.items()]
    # Name the columns explicitly so an empty result still has the expected schema.
    return pd.DataFrame(rows, columns=["clause_text", "labels"])

# ✅ Build the clause table, keep only clauses that carry a label, and persist it
df = load_multi_label_clauses("/content/CUADv1.json")
has_label = df["labels"].map(len) > 0
df = df.loc[has_label].reset_index(drop=True)
df.to_csv("/content/clauses_multilabel_raw.csv", index=False)

# ✅ Show the first clause and its labels as a sanity check
first_row = df.iloc[0]
print(first_row["clause_text"])
print(first_row["labels"])


DISTRIBUTOR AGREEMENT
['document name']


In [39]:
# Peek at the label lists of the first ten clauses.
print(df["labels"].iloc[:10])


0                      [document name]
1                            [parties]
2                            [parties]
3                            [parties]
4                            [parties]
5                            [parties]
6                     [agreement date]
7    [effective date, expiration date]
8                     [effective date]
9                       [renewal term]
Name: labels, dtype: object


In [40]:
from sklearn.preprocessing import MultiLabelBinarizer

# ✅ 1. Labels retained for training (everything else is dropped)
keep_labels = [
    "governing law",
    "audit rights",
    "cap on liability",
    "revenue/profit sharing",
    "license grant",
    "termination for convenience",
    "post-termination services",
    "insurance",
    "minimum commitment",
    "anti-assignment"
]

# ✅ 2. Normalise every label, then keep only those listed in keep_labels.
# NOTE: df["labels"] is overwritten in place below, so this cell must be
# re-run together with the loading cell (the labels become arrays in step 3).
df["labels"] = df["labels"].map(
    lambda names: [
        cleaned
        for cleaned in (name.lower().strip().replace('"', '') for name in names)
        if cleaned in keep_labels
    ]
)
has_kept_label = df["labels"].map(len) > 0
df = df.loc[has_kept_label].reset_index(drop=True)

# ✅ 3. Turn each label list into a multi-hot row (one column per label)
mlb = MultiLabelBinarizer()
multi_hot = mlb.fit_transform(df["labels"])
df["labels"] = list(multi_hot)

# ✅ 4. Remember the column order chosen by the binarizer
label_names = mlb.classes_.tolist()
print(f"✅ Final Labels: {len(label_names)}")
print(f"✅ Final Clauses: {len(df)}")


✅ Final Labels: 10
✅ Final Clauses: 5231


In [41]:
import numpy as np

# Column-wise sum of the multi-hot matrix = number of clauses per label.
label_counts = np.asarray(multi_hot).sum(axis=0)
for class_name, n_clauses in zip(mlb.classes_, label_counts):
    print(f"{class_name:40s} → {n_clauses}")


anti-assignment                          → 648
audit rights                             → 642
cap on liability                         → 670
governing law                            → 454
insurance                                → 560
license grant                            → 768
minimum commitment                       → 424
post-termination services                → 450
revenue/profit sharing                   → 416
termination for convenience              → 243


In [42]:
from transformers import AutoTokenizer
from datasets import Dataset, DatasetDict
from sklearn.model_selection import train_test_split

# Tokenizer matching the LegalBERT checkpoint that is fine-tuned below
tokenizer = AutoTokenizer.from_pretrained("nlpaueb/legal-bert-base-uncased")

# Hold out 20% of the clauses for evaluation (fixed seed for a repeatable split)
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    random_state=42,
)

from torch.utils.data import Dataset, DataLoader
import torch

class ClauseDataset(Dataset):
    """Map-style dataset pairing tokenized clause text with its label vector.

    Args:
        df: Frame with a "clause_text" column and a "labels" column holding
            one numeric label vector per row.
        tokenizer: Callable that tokenizes one string and returns a mapping
            of tensors with a leading batch dimension of size 1.
        max_len: Length every sequence is padded or truncated to.
    """

    def __init__(self, df, tokenizer, max_len=256):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.texts = df["clause_text"].tolist()
        self.labels = df["labels"].tolist()

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors="pt",
        )
        # Drop the size-1 batch dimension the tokenizer adds for a single text.
        item = {}
        for key, tensor in encoded.items():
            item[key] = tensor.squeeze(0)
        # Float targets, as stored in the "labels" column for this row.
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float32)
        return item





In [43]:

# Wrap each split in a dataset, then batch it (training batches are shuffled).
train_dataset = ClauseDataset(train_df, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

test_dataset = ClauseDataset(test_df, tokenizer)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)


In [44]:
from transformers import AutoModelForSequenceClassification

import torch

# Use the GPU when one is available and fall back to CPU otherwise, so this
# cell does not crash on a runtime without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# LegalBERT encoder with a fresh classification head: one output per label,
# configured for multi-label classification.
model = AutoModelForSequenceClassification.from_pretrained(
    "nlpaueb/legal-bert-base-uncased",
    num_labels=len(label_names),
    problem_type="multi_label_classification"
).to(device)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at nlpaueb/legal-bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [45]:
from torch.optim import AdamW
from torch.nn import BCEWithLogitsLoss

optimizer = AdamW(model.parameters(), lr=2e-5)
# NOTE: loss_fn is not used by the loop below; the loss is taken from
# outputs.loss, which the model computes itself because "labels" is part of
# every batch. The name is kept so later code referring to it keeps working.
loss_fn = BCEWithLogitsLoss()

# Send batches to whatever device the model currently lives on instead of
# assuming a GPU is present.
device = next(model.parameters()).device

for epoch in range(3):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        inputs = {k: v.to(device) for k, v in batch.items()}

        # Clear gradients before the forward pass so leftovers from an
        # interrupted previous run cannot leak into the first update.
        optimizer.zero_grad()
        outputs = model(**inputs)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean loss over the batches of this epoch
    print(f"Epoch {epoch+1} loss: {total_loss/len(train_loader):.4f}")


Epoch 1 loss: 0.1900
Epoch 2 loss: 0.0529
Epoch 3 loss: 0.0277


In [46]:
from sklearn.metrics import classification_report
import numpy as np

model.eval()
# Follow the model's device rather than assuming CUDA.
device = next(model.parameters()).device
all_preds, all_labels = [], []

with torch.no_grad():
    for batch in test_loader:
        inputs = {k: v.to(device) for k, v in batch.items()}
        logits = model(**inputs).logits
        # Independent sigmoid per label, thresholded at 0.5
        preds = (torch.sigmoid(logits) > 0.5).int().cpu().numpy()
        # Cast the float targets to int so both arrays share one dtype
        labels = inputs["labels"].int().cpu().numpy()

        all_preds.extend(preds)
        all_labels.extend(labels)

# zero_division=0 reports 0 for undefined precision/recall (e.g. a sample with
# no predicted label) without emitting an UndefinedMetricWarning.
print(classification_report(all_labels, all_preds, target_names=label_names, zero_division=0))


                             precision    recall  f1-score   support

            anti-assignment       0.99      0.96      0.97       122
               audit rights       0.94      0.99      0.97       128
           cap on liability       0.99      0.98      0.98       135
              governing law       0.98      0.98      0.98        92
                  insurance       0.98      0.98      0.98       109
              license grant       0.97      0.97      0.97       150
         minimum commitment       0.98      0.91      0.94        92
  post-termination services       0.85      0.86      0.86        87
     revenue/profit sharing       0.97      0.96      0.96        93
termination for convenience       0.98      0.91      0.95        47

                  micro avg       0.96      0.96      0.96      1055
                  macro avg       0.96      0.95      0.96      1055
               weighted avg       0.96      0.96      0.96      1055
                samples avg     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [47]:
from transformers import AutoTokenizer

# Set your output directory
output_dir = "saved_model/multilabel_legalbert"

# Save model (weights + config). Without this the directory only contains
# tokenizer files and the later from_pretrained() reload cannot find a model.
model.save_pretrained(output_dir)

# Save tokenizer
tokenizer.save_pretrained(output_dir)


('saved_model/multilabel_legalbert/tokenizer_config.json',
 'saved_model/multilabel_legalbert/special_tokens_map.json',
 'saved_model/multilabel_legalbert/vocab.txt',
 'saved_model/multilabel_legalbert/added_tokens.json',
 'saved_model/multilabel_legalbert/tokenizer.json')

In [49]:
import json

# Store the binarizer's class order next to the model so predictions can be
# mapped back to label names later.
label_names = mlb.classes_.tolist()  # From MultiLabelBinarizer
label_file = f"{output_dir}/label_names.json"
with open(label_file, "w") as fh:
    fh.write(json.dumps(label_names))


In [50]:
import shap
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import numpy as np

import json

# ✅ Load model + tokenizer
model_path = "/content/saved_model/multilabel_legalbert"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()

# ✅ Get label names
# Prefer the list saved next to the model so the names always match the order
# the classifier was trained with; fall back to the built-in list only when
# that file is missing.
try:
    with open(f"{model_path}/label_names.json", "r", encoding="utf-8") as f:
        label_names = json.load(f)
except FileNotFoundError:
    label_names = [
        "anti-assignment",
        "audit rights",
        "cap on liability",
        "governing law",
        "insurance",
        "license grant",
        "minimum commitment",
        "post-termination services",
        "revenue/profit sharing",
        "termination for convenience",
    ]

# A mismatch here would silently attach wrong names to the model outputs.
if len(label_names) != model.config.num_labels:
    raise ValueError(
        f"Got {len(label_names)} label names but the model has "
        f"{model.config.num_labels} outputs"
    )

# ✅ Prediction wrapper
def predict_proba(texts):
    """Return per-label sigmoid probabilities for a batch of texts.

    Uses the module-level `tokenizer` and `model`.

    Args:
        texts: A single string, a sequence of strings, or a NumPy array of
            strings (the form SHAP passes in). Arrays are flattened first.

    Returns:
        NumPy array with one row per text and one column per model output.
    """
    if isinstance(texts, str):
        # A bare string would otherwise be iterated character by character.
        texts = [texts]
    elif isinstance(texts, np.ndarray):
        # ravel() really flattens (and also copes with 0-d arrays).
        texts = texts.ravel().tolist()

    # Just to be safe: convert all items to str
    texts = [str(t) for t in texts]

    # Tokenize
    inputs = tokenizer(texts, padding=True, truncation=True, max_length=256, return_tensors="pt")
    # Put the tensors on the same device as the model.
    inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}

    # Predict
    with torch.no_grad():
        outputs = model(**inputs)
        probs = torch.sigmoid(outputs.logits).cpu().numpy()

    return probs



# ✅ Explainer built from the prediction wrapper, with the tokenizer passed as the second argument
explainer = shap.Explainer(predict_proba, tokenizer)

# ✅ Clause to explain
sample_clause = "Channel Partner accepts iPass as the exclusive provider to Channel Partner for all services of the nature of the Services. In no event may Channel Partner resell or otherwise provide the Service to any third party for purposes of further down channel resale of the Services, absent iPass notice and consent."

# ✅ Attributions for the single sample
shap_values = explainer([sample_clause])

# ✅ One text plot per label: first sample, all tokens, i-th output column
i = 0
while i < len(label_names):
    label = label_names[i]
    print(f"\n🔍 Explaining label: {label}")
    shap.plots.text(shap_values[0, :, i])
    i += 1



  0%|          | 0/498 [00:00<?, ?it/s]

PartitionExplainer explainer: 2it [00:56, 56.06s/it]               


🔍 Explaining label: anti-assignment






🔍 Explaining label: audit rights



🔍 Explaining label: cap on liability



🔍 Explaining label: exclusivity



🔍 Explaining label: governing law



🔍 Explaining label: insurance



🔍 Explaining label: license grant



🔍 Explaining label: minimum commitment



🔍 Explaining label: post-termination services



🔍 Explaining label: revenue/profit sharing


In [52]:
# Recursively zip the saved model directory into a single archive under /content.
!zip -r /content/multilabel_legalbert.zip /content/saved_model/multilabel_legalbert



  adding: content/saved_model/multilabel_legalbert/ (stored 0%)
  adding: content/saved_model/multilabel_legalbert/label_names.json (deflated 35%)
  adding: content/saved_model/multilabel_legalbert/special_tokens_map.json (deflated 42%)
  adding: content/saved_model/multilabel_legalbert/vocab.txt (deflated 51%)
  adding: content/saved_model/multilabel_legalbert/tokenizer.json (deflated 71%)
  adding: content/saved_model/multilabel_legalbert/model.safetensors (deflated 7%)
  adding: content/saved_model/multilabel_legalbert/tokenizer_config.json (deflated 75%)
  adding: content/saved_model/multilabel_legalbert/config.json (deflated 58%)
