In [2]:
import torch
import truecase
from transformers import AutoTokenizer

def get_candidate_tokens(tokenizer, context: str, candidate: str):
    """
    Truecases the full text and returns ONLY the token IDs belonging to the candidate word.
    
    Args:
        tokenizer: HuggingFace tokenizer.
        context: The context string (e.g., "he is a member of the royal").
        candidate: The candidate word (e.g., "heirs").
    
    Returns:
        candidate_ids: List[int] of token IDs for the candidate.
        true_full_text: The string after truecasing (for debugging).
    """
    
    # 1. Construct the raw full text
    # We enforce a space between context and candidate
    raw_full_text = f"{context} {candidate}"
    
    # 2. Apply Truecase to the full text
    # This fixes "royal irish" -> "Royal Irish"
    true_full_text = truecase.get_true_case(raw_full_text)
    
    # 3. Determine the Split Point
    # We need to find where the 'context' ends in the truecased version.
    # We truecase the context separately to measure its length.
    true_context = truecase.get_true_case(context)
    
    # The candidate generally starts immediately after the context.
    # Note: truecase usually strips trailing spaces, so the split point is the length of true_context.
    # Example: "Royal" (len 5). Full: "Royal Irish". " Irish" starts at index 5.
    split_char_idx = len(true_context)

    # 4. Tokenize with Offsets
    # return_offsets_mapping gives (start_char, end_char) for each token
    inputs = tokenizer(
        true_full_text, 
        return_tensors="pt", 
        return_offsets_mapping=True,
        add_special_tokens=True 
    )
    
    input_ids = inputs.input_ids[0]
    offsets = inputs.offset_mapping[0]
    
    candidate_ids = []
    
    for i, (start, end) in enumerate(offsets):
        # Skip special tokens (like BOS/EOS which often have 0,0 offsets)
        if start == end == 0:
            continue
            
        # LOGIC: If the token starts AT or AFTER our context ended, it's part of the candidate.
        if start >= split_char_idx:
            candidate_ids.append(input_ids[i].item())
            
    return candidate_ids, true_full_text

# ==========================================
# TEST BLOCK
# ==========================================
if __name__ == "__main__":
    # Setup
    print("Loading Tokenizer...")
    # Using gpt2 as a proxy for OPT/Gemma/Falcon (space-sensitive)
    tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba2-2.7b") 
    
    test_cases = [
        # Case 1: The "Royal" Fix
        # Context is lowercase "royal", but truecase should make it "Royal", 
        # which changes the probability of "Irish" vs "heirs".
        ("he is a member of the royal", "irish"),
        
        # Case 2: Standard lowercase continuation
        # Context implies common noun, truecase should keep it lowercase (mostly).
        ("the prince and his", "heirs"),
        
        # Case 3: Proper Noun in candidate
        ("i live in", "new york"),
    ]

    print(f"\n{'='*10} RUNNING TESTS {'='*10}")

    for ctx, cand in test_cases:
        ids, text = get_candidate_tokens(tokenizer, ctx, cand)
        
        print(f"\nInput:     '{ctx}' + '{cand}'")
        print(f"Truecased: '{text}'")
        print(f"Token IDs: {ids}")
        print(f"Decoded:   {tokenizer.decode(ids)}")
        
        # Verification Logic
        # We check if the decoded tokens match the candidate (ignoring casing differences)
        decoded = tokenizer.decode(ids).strip().lower()
        if decoded == cand.lower():
            print("✅ VERIFIED: Candidate extracted correctly.")
        else:
            print(f"❌ MISMATCH: Expected '{cand}', got '{decoded}'")

Loading Tokenizer...


AttributeError: 'NoneType' object has no attribute 'endswith'

In [4]:
tokenizer("He is also a member of the royal", return_tensors="pt", add_special_tokens=True)

{'input_ids': tensor([[    2,  2209,   563,   992,   496,  4374,   529,   506, 19833]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [None]:
2, 2209,  563,  992,  496, 4374,  529,  506, 1759
2, 2209,  563,  992,  496, 4374,  529,  506, 19833

In [23]:
tokenizer.decode(torch.tensor([2, 18047,  1288,   496,  1494, 23957,  3004, tokenizer.eos_token_id]))

'<bos>Has such a high clay content<eos>'

In [10]:
attention_mask.shape

NameError: name 'attention_mask' is not defined