In [15]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from preprocessing import preprocess_text
import torch

#### Load my model and tokenizer from Hugging Face

In [16]:
# Hugging Face Hub repository id; the sequence-classification weights and the
# matching tokenizer are both fetched from this same repository.
model_path = 'oranne55/qualifier-model4-finetune-pretrained-transformer-for-long-inputs'

# Classifier first, then its tokenizer (same order as the original cell).
model = AutoModelForSequenceClassification.from_pretrained(model_path)

tokenizer = AutoTokenizer.from_pretrained(model_path)

#### The function `predict_long_text_with_preprocess` classifies a list of texts. Each text is first preprocessed and tokenized, and long texts are then split into chunks of 512 tokens in which consecutive chunks overlap by 100 tokens; the overlap preserves the connection between adjacent parts of the text. Each chunk is classified by the model, and a text is labeled "jailbreak" if any of its chunks is classified as "jailbreak", otherwise "benign". The function returns one prediction ("jailbreak" or "benign") per input text, and the model runs on the available device (MPS or CPU).

![Sliding Window Overlapping](sliding_window_overlaping.png)


In [17]:
# Use Apple's Metal (MPS) backend when PyTorch reports it as available;
# otherwise fall back to the CPU. The model's weights are moved there once.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model.to(device)

def _sliding_windows(input_ids, chunk_size, overlap_size):
    """Split a 1-D tensor of token ids into overlapping windows.

    Windows start every ``chunk_size - overlap_size`` tokens and hold at most
    ``chunk_size`` tokens. Splitting stops at the first window that reaches the
    end of the sequence, so no trailing window is emitted that lies entirely
    inside the previous one. An empty sequence yields no windows.
    """
    total = input_ids.size(0)
    if total == 0:
        return []
    stride = chunk_size - overlap_size
    windows = []
    start = 0
    while True:
        windows.append(input_ids[start:start + chunk_size])
        if start + chunk_size >= total:
            return windows
        start += stride


def predict_long_text_with_preprocess(texts, model, tokenizer, chunk_size=512, overlap_size=100):
    """Classify each text as "jailbreak" or "benign" using overlapping token windows.

    Each text is passed through ``preprocess_text`` and tokenized without
    truncation. Its token ids are split into windows of at most ``chunk_size``
    tokens, with consecutive windows sharing ``overlap_size`` tokens. Each
    window is classified by ``model``. A text is labelled "jailbreak" as soon as
    any window's predicted label (looked up in ``model.config.id2label``) is
    "jailbreak"; otherwise it is labelled "benign".

    Args:
        texts: Iterable of raw input strings.
        model: Sequence-classification model whose output has ``.logits`` and
            whose ``config.id2label`` maps class indices to label names. Inputs
            are moved to the device of the model's parameters.
        tokenizer: Tokenizer returning a mapping whose ``"input_ids"`` entry has
            shape (1, seq_len) when called with ``return_tensors="pt"``.
        chunk_size: Maximum number of tokens per window (default 512).
        overlap_size: Number of tokens shared by consecutive windows (default 100).

    Returns:
        A list with one "jailbreak" / "benign" string per input text, in order.

    Raises:
        ValueError: If ``overlap_size`` is not in ``[0, chunk_size)``.
    """
    if not 0 <= overlap_size < chunk_size:
        raise ValueError(
            f"overlap_size must be in [0, chunk_size); got overlap_size={overlap_size}, "
            f"chunk_size={chunk_size}"
        )

    # Follow the model's own device rather than a notebook-level `device` global,
    # so the function works for whatever model the caller passes in.
    device = next(model.parameters()).device

    predictions = []
    with torch.no_grad():  # inference only; no autograd graph needed
        for text in texts:
            text = preprocess_text(text)
            inputs = tokenizer(text, return_tensors="pt", truncation=False)
            # Take row 0 instead of squeeze(): squeeze() collapses a 1-token
            # sequence to a 0-d tensor, which has no length.
            input_ids = inputs["input_ids"][0]

            # NOTE(review): if the tokenizer adds special tokens (the Hugging Face
            # default, e.g. [CLS]/[SEP]), only the first window starts with them
            # and only the last ends with them — TODO confirm this matches how the
            # model was fine-tuned.
            contains_jailbreak = False
            for chunk in _sliding_windows(input_ids, chunk_size, overlap_size):
                logits = model(chunk.unsqueeze(0).to(device)).logits
                prediction = torch.argmax(logits, dim=1).item()
                if model.config.id2label[prediction] == "jailbreak":
                    contains_jailbreak = True
                    break  # one flagged window is enough to flag the whole text

            predictions.append("jailbreak" if contains_jailbreak else "benign")

    return predictions



#### Show example:

In [18]:
# Two short sample inputs run through the full preprocess -> chunk -> classify pipeline.
example_texts = ["This is a example text.", "This is a example2 text."]
predictions = predict_long_text_with_preprocess(
    texts=example_texts,
    model=model,
    tokenizer=tokenizer,
)
print(predictions)

['benign', 'benign']
