# Patent Classification with Transformers
This notebook compares two transformer-based approaches for text classification:

1. **Truncation**: Truncate text to fit the model's context window.
2. **Segmentation**: Split the text into overlapping chunks, assign each chunk its document's label, and combine chunk predictions by majority voting.

We'll use the `ccdv/patent-classification` dataset from HuggingFace.

### Setup: Install and import packages

In [2]:
!pip install transformers datasets accelerate wandb -q
import torch
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
from collections import Counter
import pandas as pd
import wandb

In [3]:
# Enable Weights & Biases logging
wandb.login()

# Trainer logs to the active run automatically when report_to="wandb"
wandb.init(project="patent-classification", name="bert-trunc-vs-segment", job_type="training")

wandb: Currently logged in as: easonwangzk (easonwangzk-the-university-of-chicago) to https://api.wandb.ai.


In [4]:
# Use GPU (A100 support)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [5]:
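# Initialize wandb and record the intended hyperparameters.
# Note: this config is only logged; the TrainingArguments below never set learning_rate,
# so Trainer falls back to its default (5e-5) instead of the 2e-5 recorded here.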
wandb.init(
    project="patent-classification",
    config={
        "model_name": "bert-base-uncased",
        "batch_size": 16,
        "learning_rate": 2e-5,
        "epochs": 3,
        "max_length": 512,
        "stride": 256,
        "weight_decay": 0.01,
        "fp16": True
    }
)

### Load dataset & keep top-4 frequent labels

In [6]:
dataset = load_dataset("ccdv/patent-classification")
label_counts = Counter(dataset['train']['label'])
top4_labels = [label for label, _ in label_counts.most_common(4)]

def filter_top_labels(example):
    return example['label'] in top4_labels

train_dataset = dataset['train'].filter(filter_top_labels)
test_dataset = dataset['test'].filter(filter_top_labels)

label_to_id = {label: i for i, label in enumerate(top4_labels)}
train_dataset = train_dataset.map(lambda x: {"label": label_to_id[x["label"]]})
test_dataset = test_dataset.map(lambda x: {"label": label_to_id[x["label"]]})

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/25000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5000 [00:00<?, ? examples/s]

Map:   0%|          | 0/17700 [00:00<?, ? examples/s]

Map:   0%|          | 0/3545 [00:00<?, ? examples/s]
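The filtering and relabeling step above can be checked on a toy label list. This is a minimal sketch with made-up labels: `toy_labels` and `k = 2` are illustrative stand-ins, not values from the dataset.

```python
from collections import Counter

# Hypothetical toy labels standing in for dataset['train']['label'].
toy_labels = [3, 1, 3, 7, 1, 3, 5, 1, 7]
k = 2

# Keep the k most frequent labels, most frequent first.
counts = Counter(toy_labels)
top_labels = [label for label, _ in counts.most_common(k)]

# Map the surviving labels to contiguous ids 0..k-1, which is what a
# classification head with num_labels=k expects.
label_to_id = {label: i for i, label in enumerate(top_labels)}

# Drop examples with other labels and remap the rest.
remapped = [label_to_id[label] for label in toy_labels if label in label_to_id]

print(top_labels)
print(label_to_id)
print(remapped)
```

Remapping matters because the original label ids of the top-4 classes are not guaranteed to be 0 to 3, while the model is created with `num_labels=4`.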

### Method 1: Truncation approach

In [None]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model_trunc = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)

def tokenize_truncate(example):
    return tokenizer(example["text"], padding="max_length", truncation=True, max_length=512)

train_encoded_trunc = train_dataset.map(tokenize_truncate, batched=True)
test_encoded_trunc = test_dataset.map(tokenize_truncate, batched=True)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/17700 [00:00<?, ? examples/s]

Map:   0%|          | 0/3545 [00:00<?, ? examples/s]

In [None]:
args_trunc = TrainingArguments(
    output_dir="./results_trunc",
    eval_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True,
    logging_dir="./logs_trunc",
    logging_steps=100,
    report_to="wandb",
    run_name="truncation_method"
)

trainer_trunc = Trainer(
    model=model_trunc,
    args=args_trunc,
    train_dataset=train_encoded_trunc,
    eval_dataset=test_encoded_trunc,
    tokenizer=tokenizer
)


In [None]:
trainer_trunc.train()

| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1     | 0.5996        | 0.596515        |
| 2     | 0.4207        | 0.597401        |
| 3     | 0.2287        | 0.720172        |


TrainOutput(global_step=3321, training_loss=0.46151663028702106, metrics={'train_runtime': 286.8819, 'train_samples_per_second': 185.094, 'train_steps_per_second': 11.576, 'total_flos': 1.39714479230976e+16, 'train_loss': 0.46151663028702106, 'epoch': 3.0})
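In the log above, training loss keeps falling while validation loss is flat and then rises in epoch 3, which is a common sign of overfitting. A minimal sketch, using only the numbers printed above, that picks the epoch with the lowest validation loss; the variable names are illustrative.

```python
# (epoch, training_loss, validation_loss) copied from the training log above.
history = [
    (1, 0.5996, 0.596515),
    (2, 0.4207, 0.597401),
    (3, 0.2287, 0.720172),
]

# Select the row with the smallest validation loss.
best_epoch, _, best_val = min(history, key=lambda row: row[2])
final_val = history[-1][2]

print(f"best epoch: {best_epoch} (validation loss {best_val:.4f})")
print(f"final-epoch validation loss is {final_val - best_val:.4f} higher")
```

Because the run evaluates whatever weights the model holds after the last epoch, the reported loss corresponds to epoch 3. One option, not used in this notebook, is to set `load_best_model_at_end=True` in `TrainingArguments` so the best checkpoint is restored before evaluation.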

### Method 2: Segmentation approach

In [7]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model_segmented = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4).to(device)

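def segment_text(examples, max_words=512, stride=256):
    """Split each document into overlapping word windows; every chunk inherits the document's label."""
    segmented_texts = []
    segmented_labels = []

    for text, label in zip(examples["text"], examples["label"]):
        # Windows are measured in whitespace-separated words, not BERT tokens.
        words = text.split()
        # Consecutive windows overlap by (max_words - stride) words.
        for i in range(0, len(words), stride):
            chunk = " ".join(words[i:i + max_words])
            if chunk.strip():
                segmented_texts.append(chunk)
                segmented_labels.append(label)

    return {
        "text": segmented_texts,
        "label": segmented_labels
    }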

# Segment the dataset
train_segmented = train_dataset.map(
    segment_text,
    batched=True,
    remove_columns=train_dataset.column_names,
    load_from_cache_file=False
)
test_segmented = test_dataset.map(
    segment_text,
    batched=True,
    remove_columns=test_dataset.column_names,
    load_from_cache_file=False
)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/17700 [00:00<?, ? examples/s]

Map:   0%|          | 0/3545 [00:00<?, ? examples/s]
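To see what the sliding window in `segment_text` produces, the same loop can be run on a six-word toy document. This is a sketch with small illustrative values (`max_words=4`, `stride=2`) in place of the notebook's 512 and 256; `segment_words` is a hypothetical helper that mirrors the inner loop only.

```python
def segment_words(words, max_words, stride):
    """Overlapping word windows, mirroring the inner loop of segment_text."""
    return [words[i:i + max_words] for i in range(0, len(words), stride)]

words = "a b c d e f".split()
chunks = segment_words(words, max_words=4, stride=2)

for chunk in chunks:
    print(chunk)

# Number of chunks is ceil(len(words) / stride), regardless of max_words.
print(len(chunks))
```

Two things stand out. The last window (`['e', 'f']`) is fully contained in the one before it, so the tail of every document is seen twice with nothing new added. And because windows are counted in words while the tokenizer limit is counted in tokens, a full 512-word chunk plus the special tokens always exceeds 512 tokens, so the tokenizer truncates the end of each full chunk.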

In [8]:
# Tokenize each chunk; anything beyond 512 tokens is truncated
def tokenize_truncate(example):
    return tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=512
    )

train_segmented = train_segmented.map(tokenize_truncate, batched=True)
test_segmented = test_segmented.map(tokenize_truncate, batched=True)

Map:   0%|          | 0/252289 [00:00<?, ? examples/s]

Map:   0%|          | 0/50926 [00:00<?, ? examples/s]

In [9]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
args_seg = TrainingArguments(
    output_dir="./results_seg",
    eval_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    fp16=True,
    logging_dir="./logs_seg",
    logging_steps=100,
    report_to="wandb",
    run_name="segmentation_method"
)

trainer_seg = Trainer(
    model=model_segmented,
    args=args_seg,
    train_dataset=train_segmented,
    eval_dataset=test_segmented,
    tokenizer=tokenizer,
    data_collator=data_collator
)


In [10]:
trainer_seg.train()

| Epoch | Training Loss | Validation Loss |
|-------|---------------|-----------------|
| 1     | 0.331         | 0.854471        |
| 2     | 0.1874        | 1.292966        |
| 3     | 0.095         | 1.747027        |


TrainOutput(global_step=47307, training_loss=0.2724853465841697, metrics={'train_runtime': 4236.8312, 'train_samples_per_second': 178.64, 'train_steps_per_second': 11.166, 'total_flos': 1.9914365113391923e+17, 'train_loss': 0.2724853465841697, 'epoch': 3.0})
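The `global_step` values in the two `TrainOutput` lines follow directly from the example counts shown in the `Map` progress lines. A short arithmetic check, assuming a single device and no gradient accumulation (which is what the reported step counts imply):

```python
import math

batch_size = 16
epochs = 3

# Example counts from the Map progress lines above.
n_docs = 17700        # training documents (truncation)
n_segments = 252289   # training segments (segmentation)

# Optimizer steps = batches per epoch (last batch may be partial) x epochs.
steps_trunc = math.ceil(n_docs / batch_size) * epochs
steps_seg = math.ceil(n_segments / batch_size) * epochs

print(steps_trunc, steps_seg)

# Relative cost of segmentation: data volume and wall-clock training time.
print(round(n_segments / n_docs, 2))
print(round(4236.8312 / 286.8819, 2))
```

Segmentation multiplies the training set by roughly 14, and training time grows by about the same factor, so any accuracy comparison between the two methods should be read alongside this cost.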

## Evaluation

In [None]:
trunc_result = trainer_trunc.evaluate()
print("Truncation Result:", trunc_result)

Truncation Result: {'eval_loss': 0.7201722860336304, 'eval_runtime': 7.1299, 'eval_samples_per_second': 497.201, 'eval_steps_per_second': 31.136, 'epoch': 3.0}


### Segmentation model evaluated on `test_encoded_trunc`

In [13]:
segment_result = trainer_seg.evaluate(eval_dataset=test_encoded_trunc)
print("Segmentation Result:", segment_result)

Segmentation Result: {'eval_loss': 1.49589204788208, 'eval_runtime': 6.6184, 'eval_samples_per_second': 535.629, 'eval_steps_per_second': 33.543, 'epoch': 3.0}


### Segmentation model evaluated on `test_segmented`

In [11]:
segment_result = trainer_seg.evaluate()
print("Segmentation Result:", segment_result)

Segmentation Result: {'eval_loss': 1.7470272779464722, 'eval_runtime': 99.851, 'eval_samples_per_second': 510.02, 'eval_steps_per_second': 31.878, 'epoch': 3.0}


## Evaluation Summary: Truncation vs. Segmentation

### Evaluation Results

| Method           | Evaluation Dataset   | `eval_loss` | `eval_runtime` | `samples/sec` | `steps/sec` | Notes                                                        |
|------------------|----------------------|-------------|----------------|---------------|-------------|--------------------------------------------------------------|
| **Truncation**   | `test_encoded_trunc` | **0.720**   | 7.13s          | 497.20        | 31.14       | Lowest loss; documents truncated to 512 tokens               |
| **Segmentation** | `test_encoded_trunc` | 1.496       | 6.62s          | 535.63        | 33.54       | Model trained on segments, evaluated on truncated documents  |
| **Segmentation** | `test_segmented`     | 1.747       | 99.85s         | 510.02        | 31.88       | Segment-level loss over 50,926 chunks, no aggregation        |


### Key Insights

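- **Truncation outperforms segmentation on `eval_loss`**:
  - On the truncated test documents, its loss is 0.720 against 1.496 for the segmentation model; lower `eval_loss` indicates better generalization.
  - Each input is a single, contextually coherent passage from the start of the document.

- **Segmentation leads to worse performance**:
  - Every chunk inherits its document's label, even when the chunk carries little evidence for that class, which introduces **label noise**.
  - Each chunk is classified in isolation, so full-document context is lost.
  - Overlapping windows duplicate text, creating **input redundancy** and a risk of overfitting: validation loss rose from 0.854 to 1.747 over three epochs while training loss fell to 0.095.

- **Evaluation data affects results**:
  - `test_encoded_trunc` scores 3,545 documents, one truncated input each.
  - `test_segmented` scores 50,926 overlapping chunks, so its loss is a per-segment figure and is not directly comparable to a per-document loss.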
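The comparison above rests on `eval_loss` alone, because neither `Trainer` was given a metric function. A minimal sketch of an accuracy metric that could be passed as `compute_metrics=compute_metrics` when constructing a `Trainer`; it was not part of the runs reported here, and the toy logits are made up for illustration.

```python
import numpy as np

def compute_metrics(eval_pred):
    """Accuracy from (logits, labels); Trainer's EvalPrediction unpacks this way."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": float((preds == labels).mean())}

# Toy check with made-up logits for 3 examples and 4 classes.
toy_logits = np.array([[2.0, 0.1, 0.0, 0.0],
                       [0.0, 0.3, 1.5, 0.0],
                       [0.2, 0.1, 0.0, 0.9]])
toy_labels = np.array([0, 2, 1])

result = compute_metrics((toy_logits, toy_labels))
print(result)
```

Accuracy is less sensitive than cross-entropy to overconfident wrong predictions, so it may give a different picture of the gap between the two methods.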


### Recommendations

- Prefer **truncation** when most documents fit within, or only slightly exceed, the model's token limit.
- Use **segmentation** only if:
  - Long documents dominate the dataset.
  - You implement **segment-level prediction aggregation** (e.g., max-pooling or majority voting).
  - Training and evaluation pipelines are **consistent**, i.e. the model is evaluated on inputs prepared the same way as its training data.

In [14]:
wandb.finish()

Run summary:

| Metric                  | Value                  |
|-------------------------|------------------------|
| eval/loss               | 1.49589                |
| eval/runtime            | 6.6184                 |
| eval/samples_per_second | 535.629                |
| eval/steps_per_second   | 33.543                 |
| total_flos              | 1.9914365113391923e+17 |
| train/epoch             | 3.0                    |
| train/global_step       | 47307.0                |
| train/grad_norm         | 0.1868                 |
| train/learning_rate     | 0.0                    |
| train/loss              | 0.095                  |


## Post-processing for Segmentation


When using segmentation-based classification, each document is split into multiple segments, and each segment receives its own prediction. To determine the final label for the entire document, post-processing is required.

### Common Strategies:

- **Majority Voting**: Count the predicted labels of all segments and assign the most frequent one to the whole document.

- **Probability Averaging**: Average the predicted probabilities (e.g., softmax outputs) of all segments and select the class with the highest mean probability.

- **Max Confidence**: Choose the prediction of the segment with the highest confidence as the document label.

- **Learned Aggregation** (optional): Train an additional model to combine segment-level outputs into a final document label.
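The first three strategies can be sketched in a few lines of NumPy. This is an illustrative sketch, not code from the runs above: `aggregate` is a hypothetical helper, and the toy logits are chosen so that the strategies disagree.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax with the usual max-subtraction for numerical stability."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def aggregate(segment_logits):
    """Document label under three strategies, given one row of logits per segment."""
    probs = softmax(np.asarray(segment_logits, dtype=float))
    seg_preds = probs.argmax(axis=-1)
    n_classes = probs.shape[1]
    return {
        # Most frequent segment-level prediction.
        "majority_vote": int(np.bincount(seg_preds, minlength=n_classes).argmax()),
        # Class with the highest mean probability across segments.
        "prob_average": int(probs.mean(axis=0).argmax()),
        # Class of the single most confident segment.
        "max_confidence": int(np.unravel_index(probs.argmax(), probs.shape)[1]),
    }

toy = [[1.0, 0.8, 0.0],   # weakly class 0
       [1.0, 0.9, 0.0],   # weakly class 0
       [0.0, 4.0, 0.0]]   # strongly class 1
result = aggregate(toy)
print(result)
```

Here two lukewarm segments outvote one confident segment under majority voting, while probability averaging and max confidence both follow the confident segment.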

### Recommendation:

In practice, probability averaging is often preferred due to its balance of robustness and sensitivity to confidence levels. Majority voting is a simpler alternative and may be used when probabilities are not available.

If each chunk gets a label, we can use **majority voting** over chunk predictions:

In [None]:
from collections import Counter

def predict_document_class(segment_preds):
    """Return the most frequent segment prediction; ties go to the label seen first."""
    return Counter(segment_preds).most_common(1)[0][0]
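To apply this per document, each segment prediction has to be traced back to its source document. `segment_text` above does not record that mapping, so the sketch below assumes a hypothetical `doc_ids` list aligned with the segment predictions; `document_accuracy` and the toy values are likewise illustrative.

```python
from collections import Counter, defaultdict

def predict_document_class(segment_preds):
    """Most frequent segment prediction; ties go to the label seen first."""
    return Counter(segment_preds).most_common(1)[0][0]

def document_accuracy(doc_ids, segment_preds, doc_labels):
    """Majority-vote each document's segments, then score against doc_labels."""
    by_doc = defaultdict(list)
    for doc_id, pred in zip(doc_ids, segment_preds):
        by_doc[doc_id].append(pred)
    doc_preds = {d: predict_document_class(p) for d, p in by_doc.items()}
    correct = sum(doc_preds[d] == doc_labels[d] for d in doc_preds)
    return doc_preds, correct / len(doc_preds)

# Hypothetical example: 3 documents split into 3, 2, and 1 segments.
doc_ids = [0, 0, 0, 1, 1, 2]
segment_preds = [2, 2, 1, 0, 3, 1]
doc_labels = {0: 2, 1: 3, 2: 1}

doc_preds, accuracy = document_accuracy(doc_ids, segment_preds, doc_labels)
print(doc_preds)
print(accuracy)
```

Document 1 shows the tie case: its two segments disagree, and the first-seen label wins. Scoring at the document level in this way would make the segmentation results comparable with the truncation results, which are already per document.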