In [2]:
import os
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm

# --- CONFIGURATION ---

# Path passed to `from_pretrained` for both the tokenizer and the classifier.
MODEL_PATH = "/data2/brain2text/lm/modernbert_domain_classifier"
# Input text file, streamed line by line; blank lines are skipped.
INPUT_POOL_FILE = "c4_selection_pool.txt"   # The large remainder file from previous step
# Output file: a "score<TAB>text" header row followed by one row per scored line.
OUTPUT_SCORED_FILE = "c4_scored_sentences.tsv"

# Number of non-empty lines collected before each call to process_batch.
# Author's note (not verified here): ModernBERT is efficient and 512 or 1024
# usually fits in 24GB VRAM, so this value can likely be raised.
BATCH_SIZE = 256  # ModernBERT is efficient; 512 or 1024 usually fits in 24GB VRAM
# Token limit given to the tokenizer as `max_length`; truncation is enabled,
# so longer inputs are cut to this many tokens.
MAX_LENGTH = 128


def process_batch(lines, model, tokenizer, device, f_out):
    """
    Score one batch of sentences and append the results to ``f_out``.

    The sentences are tokenized together (padded to the longest item in the
    batch, truncated to ``MAX_LENGTH`` tokens), run through ``model`` with
    gradients disabled, and the softmax probability of class index 1 is
    written next to each sentence as a tab-separated ``score, text`` row.

    Args:
        lines: List of sentences (str) to score. An empty list is a no-op.
        model: Classifier whose forward output exposes ``.logits`` with at
            least two classes. Index 1 is treated as the "in-domain" class
            (NOTE(review): confirm against the classifier's label mapping).
        tokenizer: Callable that turns ``lines`` into a tensor batch which
            supports ``.to(device)``.
        device: Device the tokenized inputs are moved to; presumably the
            same device the model lives on.
        f_out: Writable text file object that receives the rows.

    Returns:
        None. The only effect is the rows written to ``f_out``.
    """
    # Nothing to score; tokenizers generally reject an empty batch, so
    # return before touching the model or the output file.
    if not lines:
        return

    # 1. Tokenize
    inputs = tokenizer(
        lines,
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    ).to(device)

    # 2. Inference (no gradients = faster / lower memory)
    with torch.no_grad():
        logits = model(**inputs).logits

        # 3. Softmax to get probabilities. Upcast first so half/bfloat16
        # models are normalized in float32 (a no-op for float32 logits).
        probs = F.softmax(logits.float(), dim=-1)
        # Column 1 = probability of "in-domain"; tolist() yields Python
        # floats and, unlike .numpy(), has no dtype restrictions.
        in_domain_scores = probs[:, 1].cpu().tolist()

    # 4. Write to file
    # Format: 0.982300<TAB>This is a sentence.
    rows = []
    for score, text in zip(in_domain_scores, lines):
        # An embedded tab would add a third column and corrupt the
        # two-column TSV, so it is replaced by a plain space.
        clean_text = text.replace("\t", " ")
        # %.6f keeps the file size reasonable but precise.
        rows.append(f"{score:.6f}\t{clean_text}\n")
    f_out.writelines(rows)



# Pick the GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Step 1: load the tokenizer and the classifier from MODEL_PATH ---
print(f"Loading model from {MODEL_PATH}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
model.to(device)
model.eval()  # inference mode: dropout and similar layers are switched off

# --- Step 2: stream the pool file and score it batch by batch ---
print(f"Scoring sentences from {INPUT_POOL_FILE}...")

# Accumulates non-empty lines until BATCH_SIZE of them are available.
batch_lines = []

# The input is opened first, so a missing pool file never truncates an
# existing output file. Undecodable bytes in the input are dropped.
with open(INPUT_POOL_FILE, 'r', encoding='utf-8', errors='ignore') as f_in:
    with open(OUTPUT_SCORED_FILE, 'w', encoding='utf-8') as f_out:
        # Header row so the result loads directly as a two-column table.
        f_out.write("score\ttext\n")

        # No total is supplied, so tqdm shows a running count and rate only.
        for raw_line in tqdm(f_in, desc="Processing"):
            line = raw_line.strip()
            if line:
                batch_lines.append(line)

                # Flush as soon as a full batch has been collected.
                if len(batch_lines) >= BATCH_SIZE:
                    process_batch(batch_lines, model, tokenizer, device, f_out)
                    batch_lines = []

        # Score whatever is left over from the last, partial batch.
        if batch_lines:
            process_batch(batch_lines, model, tokenizer, device, f_out)

print(f"Done! Scored sentences saved to {OUTPUT_SCORED_FILE}")

Using device: cuda
Loading model from /data2/brain2text/lm/modernbert_domain_classifier...
Scoring sentences from c4_selection_pool.txt...


Processing: 11587071it [43:17, 4460.62it/s]


KeyboardInterrupt: 