In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import numpy as np
import pandas as pd

import torch
from transformers import AutoTokenizer, ModernBertForSequenceClassification
from torch.utils.data import DataLoader, Dataset

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report, f1_score

!pip install tqdm -q
from tqdm import tqdm

In [3]:
file_path = '/content/drive/My Drive/metadata/4b.test.json'
df = pd.read_json(file_path)
len(df)

1963

In [4]:
df.head()

Unnamed: 0,case_id,filtered_issues,docket_entries,case_name,case_status,case_state,court_name,case_ongoing
2877,8615,"[Disparate Treatment, Direct Suit on Merits]",For docket number 34472 of case 'EEOC v. YALE ...,"EEOC v. YALE E. KEY, INC. fdba COBRA INDUSTRIE...",Coding Complete,New Mexico,District of New Mexico,No
233,11518,"[Disparate Treatment, Pay / Benefits]",For docket number 31495 of case 'Walker v. Mon...,Walker v. Monsanto Co. Pension Plan,Coding Complete,Illinois,Southern District of Illinois,No
2131,10664,"[National origin discrimination, Disparate Tre...",For docket number 32358 of case 'Latin America...,Latin American Law Enforcement Association v. ...,Approved,California,Central District of California,No
4328,8338,"[Sex discrimination, Female, Disparate Treatme...",For docket number 34771 of case 'EEOC v. UNITE...,EEOC v. UNITED RENTAL HOMES,Coding Complete,North Carolina,Eastern District of North Carolina,No
1323,10650,"[Disparate Treatment, Pay / Benefits]",For docket number 32373 of case 'Breedlove v. ...,Breedlove v. Tele-Trip Co. Inc.,Approved,Illinois,Northern District of Illinois,No


In [5]:
# Load tokenizer and model
model_name = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Create label binarizer
mlb = MultiLabelBinarizer()
binary_labels = mlb.fit_transform(df["filtered_issues"])

# Load model with correct number of labels
model = ModernBertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(mlb.classes_),
    problem_type="multi_label_classification"
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

model.eval()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda


ModernBertForSequenceClassification(
  (model): ModernBertModel(
    (embeddings): ModernBertEmbeddings(
      (tok_embeddings): Embedding(50368, 768, padding_idx=50283)
      (norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (layers): ModuleList(
      (0): ModernBertEncoderLayer(
        (attn_norm): Identity()
        (attn): ModernBertAttention(
          (Wqkv): Linear(in_features=768, out_features=2304, bias=False)
          (rotary_emb): ModernBertRotaryEmbedding()
          (Wo): Linear(in_features=768, out_features=768, bias=False)
          (out_drop): Identity()
        )
        (mlp_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): ModernBertMLP(
          (Wi): Linear(in_features=768, out_features=2304, bias=False)
          (act): GELUActivation()
          (drop): Dropout(p=0.0, inplace=False)
          (Wo): Linear(in_features=1152, out_features=768, bias=False)
        )
      

In [6]:
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        item = {key: val.squeeze(0) for key, val in encoding.items()}
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return item

# Create dataset and dataloader
dataset = TextDataset(df["docket_entries"].tolist(), binary_labels, tokenizer)
dataloader = DataLoader(dataset, batch_size=64)

In [7]:
all_preds = []
all_labels = []

with torch.no_grad():
    for batch in tqdm(dataloader, desc="Running Inference"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        preds = (torch.sigmoid(logits) > 0.5).int()

        all_preds.append(preds.cpu())     # Move back to CPU for metrics
        all_labels.append(labels.int().cpu())


Running Inference: 100%|██████████| 31/31 [02:39<00:00,  5.13s/it]


In [8]:
# Combine all batches
y_pred = torch.cat(all_preds).cpu().numpy()
y_true = torch.cat(all_labels).cpu().numpy()

In [9]:
print(classification_report(y_true, y_pred, target_names=mlb.classes_))

                                                                                            precision    recall  f1-score   support

                                                      Access to lawyers or judicial system       0.05      0.63      0.09        87
                                                       Assault/abuse by staff (facilities)       0.00      0.00      0.00        89
                                                                                     Black       0.14      0.89      0.24       262
                                                                Classification / placement       0.05      0.99      0.09        98
Conditions of Employment (including assignment, transfer, hours, working conditions, etc.)       0.03      0.01      0.02       160
                                                                 Conditions of confinement       0.40      0.02      0.03       132
                                                                     Consti

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
