This notebook tries to use the unmaintained `deberta` package, since HuggingFace Transformers does not provide a replaced token detection (RTD) head for DeBERTa.

In [2]:
# conda create -n deberta ipykernel scipy python=3.9
# SKLEARN_ALLOW_DEPRECATED_SKLEARN_PACKAGE_INSTALL=True pip install deberta

In [16]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from typing import List, Tuple
import torch.nn.functional as F

class RTDHead(nn.Module):
    """Per-token binary classification head for replaced token detection (RTD).

    Every token's hidden vector is summed with the hidden vector found at
    sequence position 0, passed through Linear -> LayerNorm -> GELU, and
    projected down to a single logit.

    Args:
        config: Object exposing ``hidden_size`` and ``layer_norm_eps``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size

        # Sub-module names and their creation order are kept stable: they
        # determine the state-dict keys and the order of random initialisation.
        self.transform = nn.Sequential(
            nn.Linear(width, width),
            nn.LayerNorm(width, eps=config.layer_norm_eps),
            nn.GELU(),
        )
        self.classifier = nn.Linear(width, 1)

    def forward(self, hidden_state):
        """Return one logit per token.

        Args:
            hidden_state: Tensor indexed as (batch, sequence, hidden); the
                last dimension must equal ``config.hidden_size``.

        Returns:
            Tensor of logits with the trailing size-1 dimension removed,
            i.e. indexed as (batch, sequence).
        """
        # Position 0 serves as a sequence-level context vector; give it a
        # length-1 sequence axis so it broadcasts over every token.
        context = hidden_state[:, 0].unsqueeze(1)
        transformed = self.transform(context + hidden_state)
        logits = self.classifier(transformed)
        return logits.squeeze(-1)

class RTDModel(nn.Module):
    """Transformer backbone followed by an ``RTDHead``.

    Args:
        model_name: Identifier or path handed to ``AutoModel.from_pretrained``.

    NOTE(review): only the backbone is loaded through ``from_pretrained``.
    The head is constructed fresh in ``__init__`` and nothing in this class
    loads weights into it, so its outputs presumably reflect random
    initialisation unless weights are loaded elsewhere -- confirm.
    """

    def __init__(self, model_name: str):
        super().__init__()
        self.transformer = AutoModel.from_pretrained(model_name)
        # The head sizes itself from the backbone's own configuration.
        self.head = RTDHead(self.transformer.config)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Run the backbone and the RTD head.

        Args:
            input_ids: Forwarded to the backbone as ``input_ids``.
            attention_mask: Forwarded to the backbone as ``attention_mask``.
            token_type_ids: Forwarded to the backbone as ``token_type_ids``.

        Returns:
            dict with keys ``'logits'`` (head output), ``'probs'`` (sigmoid
            of the logits) and ``'hidden_states'`` (the backbone's
            ``last_hidden_state``).
        """
        encoded = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        last_layer = encoded.last_hidden_state
        logits = self.head(last_layer)

        return dict(
            logits=logits,
            probs=torch.sigmoid(logits),
            hidden_states=last_layer,
        )

class RTDScorer:
    """Tokenize text and report a per-token RTD probability.

    Args:
        model_name: Identifier or path used for both the tokenizer
            (``AutoTokenizer.from_pretrained``) and ``RTDModel``.
        device: Torch device string. The model is placed on it and every
            tokenized input is moved to it before the forward pass.
    """

    def __init__(self, model_name: str = "microsoft/deberta-v3-base", device: str = "cuda"):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = RTDModel(model_name).to(device)
        # Scoring is inference-only, so put the model in evaluation mode.
        self.model.eval()

    def get_rtd_scores(self, text: str) -> List[Tuple[str, float]]:
        """Return the RTD probability of every token in ``text``.

        The text is tokenized as a single sequence, run through the model
        without gradient tracking, and each token string is paired with the
        sigmoid probability the model produced for it.

        NOTE(review): an earlier version of this docstring said that higher
        scores mean "original (not replaced)". Which class a high value
        corresponds to depends on how the head weights were trained, and
        that cannot be determined from this file -- confirm.

        Args:
            text: Input text.

        Returns:
            List of ``(token, probability)`` pairs, one per token produced by
            the tokenizer. Probabilities are NumPy scalars.
        """
        inputs = self.tokenizer(text, return_tensors="pt")

        # Follow the configured device instead of a hard-coded "cuda", so a
        # scorer built with e.g. device="cpu" keeps inputs and model together.
        inputs = inputs.to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # A single text was encoded, so index 0 drops the batch dimension.
        probs = outputs['probs'][0]
        tokens = self.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        return list(zip(tokens, probs.cpu().numpy()))

# Demo: score every token of one sentence with a scorer built from the
# RTDScorer defaults (model "microsoft/deberta-v3-base", device "cuda").
scorer = RTDScorer()

text = "The quick brown fox jumps over the lazy dog."
scores = scorer.get_rtd_scores(text)

# Print one "<token> - <probability>" line per token.
for entry in scores:
    token, prob = entry
    print(f"{token} - {prob}")

[CLS] - 0.4169638156890869
▁The - 0.28968536853790283
▁quick - 0.3809061646461487
▁brown - 0.37623000144958496
▁fox - 0.44079703092575073
▁jumps - 0.3352523148059845
▁over - 0.30622223019599915
▁the - 0.3058890998363495
▁lazy - 0.3761269450187683
▁dog - 0.3800089657306671
. - 0.4784148335456848
[SEP] - 0.4149031341075897


In [7]:
from DeBERTa.deberta.apps.models.replac import ReplacedTokenDetectionModel
from transformers import AutoTokenizer
import torch

# Build the RTD model from the upstream DeBERTa package and load local weights.
# NOTE(review): the config path names "deberta_large" while the tokenizer below
# is "deberta-v3-base" -- confirm both match the checkpoint in pytorch_model.bin.
model = ReplacedTokenDetectionModel.load_model(None, '/experiments/language_model/deberta_large.json')
# The model is never moved off the CPU in this cell, so map the checkpoint to
# the CPU as well; this also lets a GPU-saved checkpoint load without CUDA.
model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))
tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-base')

# Named `encoded` rather than `input` so the builtin input() is not shadowed
# for the rest of the notebook session.
encoded = tokenizer(
    'The Chinese sports delegation is the last to enter the 2002 Winter Olympics',
    return_tensors='pt',
)
out = model(
    input_ids=encoded['input_ids'],
    input_mask=encoded['attention_mask'],
    # NOTE(review): 15 hand-written labels with a single 1 at index 11. This
    # presumably assumes the tokenizer yields exactly 15 tokens including the
    # special tokens, and the tensor is 1-D while input_ids carries a batch
    # dimension -- confirm against the model's expected label shape.
    labels=torch.tensor([0,0,0,0,0,0,0,0,0,0,0,1,0,0,0])
)
print(out['logits'])

ModuleNotFoundError: No module named 'DeBERTa.deberta.apps'