In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
%pip install nupunkt -q

import ast
from tqdm import tqdm

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from nupunkt import sent_tokenize

import torch
from transformers import AutoTokenizer, AutoModel

In [3]:
MODEL_NAME = "answerdotai/ModernBERT-base"

TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME)
MODEL = AutoModel.from_pretrained(MODEL_NAME)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [4]:
df = pd.read_json('/content/drive/MyDrive/metadata/GENDER.json')

df.head()

Unnamed: 0,case_id,clean_content,chunks,num_chunks,label
0,14,"The case ""Rooks & U.S. v. Herring"" from Alabam...","['The case ""Rooks & U.S. v. Herring"" from Alab...",1,0
1,70,"The case ""U.S. v. Dooly County"" from Georgia w...","['The case ""U.S. v. Dooly County"" from Georgia...",1,0
2,148,"The case ""Ahrens v. Thomas"" from Missouri was ...","['The case ""Ahrens v. Thomas"" from Missouri wa...",5,0
3,303,"The case ""Steven L. v. Kern County"" from Calif...","['The case ""Steven L. v. Kern County"" from Cal...",1,0
4,313,"The case ""Florida D.J.J. v. C.A"" from Florida ...","['The case ""Florida D.J.J. v. C.A"" from Flori...",1,0


In [5]:
len(df)

240

In [6]:
MODEL.to('cuda')

BATCH_SIZE = 8

for index, row in tqdm(df.iterrows(), total=len(df)):
    case_id = row["case_id"]
    #print(f"Processing: ", index)

    chunks = ast.literal_eval(row["chunks"])

    all_embeddings = []

    for i in range(0, len(chunks), BATCH_SIZE):
        batch = chunks[i:i + BATCH_SIZE]

        # Tokenize and move inputs to GPU
        inputs = TOKENIZER(batch, max_length=8192, padding=True, truncation=True, return_tensors='pt')
        inputs = {k: v.to('cuda') for k, v in inputs.items()}

        # Run model on batch
        with torch.no_grad():
            outputs = MODEL(**inputs)
            last_hidden_state = outputs.last_hidden_state  # shape: [batch, seq_len, hidden_dim]
            #print("lhs: ", last_hidden_state.shape)

            mean_embeddings = last_hidden_state.mean(dim=1)  # shape: [batch, hidden_dim]
            all_embeddings.append(mean_embeddings.cpu())

        torch.cuda.empty_cache()

    # Combine all embeddings: [n_chunks, hidden_dim]
    all_embeddings_tensor = torch.cat(all_embeddings, dim=0)
    final_embedding = all_embeddings_tensor.mean(dim=0).numpy()
    #print("embed: ", final_embedding.shape)

    df.at[index, 'mean_embedding'] = final_embedding.astype(object)

    torch.cuda.empty_cache()

100%|██████████| 240/240 [48:26<00:00, 12.11s/it]


In [7]:
#df["max_embedding"] = df["max_embedding"].apply(lambda x: x.tolist() if isinstance(x, np.ndarray) else x)
df["mean_embedding"] = df["mean_embedding"].apply(lambda x: x.tolist() if isinstance(x, np.ndarray) else x)

df.to_json('/content/drive/MyDrive/metadata/embeddings/GENDER_embeddings.json')

In [8]:
df.head()

Unnamed: 0,case_id,clean_content,chunks,num_chunks,label,mean_embedding
0,14,"The case ""Rooks & U.S. v. Herring"" from Alabam...","['The case ""Rooks & U.S. v. Herring"" from Alab...",1,0,"[-0.8460231423377991, -0.30086496472358704, 0...."
1,70,"The case ""U.S. v. Dooly County"" from Georgia w...","['The case ""U.S. v. Dooly County"" from Georgia...",1,0,"[-0.47160565853118896, -0.1695367842912674, -0..."
2,148,"The case ""Ahrens v. Thomas"" from Missouri was ...","['The case ""Ahrens v. Thomas"" from Missouri wa...",5,0,"[-0.7494961619377136, -0.4172031879425049, 0.1..."
3,303,"The case ""Steven L. v. Kern County"" from Calif...","['The case ""Steven L. v. Kern County"" from Cal...",1,0,"[-0.6845493912696838, -0.4140678346157074, 0.0..."
4,313,"The case ""Florida D.J.J. v. C.A"" from Florida ...","['The case ""Florida D.J.J. v. C.A"" from Flori...",1,0,"[-0.5295823216438293, -0.17789021134376526, 0...."
