Skip to content

Commit

Permalink
switch to deepset model and pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
deadbits committed Sep 14, 2023
1 parent 5bed20e commit 5489c55
Showing 1 changed file with 30 additions and 37 deletions.
67 changes: 30 additions & 37 deletions vigil/scanners/transformer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
import uuid
import logging

from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
TextClassificationPipeline,
)
from transformers import pipeline

from vigil.schema import ModelMatch
from vigil.schema import BaseScanner
Expand All @@ -18,57 +14,54 @@ class TransformerScanner(BaseScanner):
def __init__(self, config_dict: dict):
self.name = 'scanner:transformer'
self.model_name = config_dict['model']
self.threshold = config_dict['threshold']
self.threshold = float(config_dict['threshold'])

try:
self.model = AutoModelForSequenceClassification.from_pretrained(
self.model_name
)
self.pipeline = pipeline('text-classification', model=self.model_name)
logger.info(f'[{self.name}] Model loaded: {self.model_name}')
except Exception as err:
logger.error(f'[{self.name}] Failed to load model: {err}')

try:
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
logger.info(f'[{self.name}] Tokenizer loaded: {self.model_name}')
except Exception as err:
logger.error(f'[{self.name}] Failed to load tokenizer: {err}')

self.pipeline = TextClassificationPipeline(
model=self.model, tokenizer=self.tokenizer
)
logger.info(f'[{self.name}] Pipeline loaded: {self.model_name}')
logger.info(f'[{self.name}] Scanner loaded: {self.model_name}')

def analyze(self, input_data: str, scan_uuid: uuid.uuid4) -> list:
logger.info(f'[{self.name}] Performing scan; id="{scan_uuid}"')

score = 0.0
results = []
results, hits = [], []

if input_data.strip() == '':
logger.info(f'[{self.name}] No input data; id={scan_uuid}')
return results

try:
result = self.pipeline(
input_data,
truncation=True,
max_length=self.tokenizer.model_max_length
)
score = round(
result[0]['score'] if result[0]['label'] == 'INJECTION' else 1 - result[0]['score'], 2
hits = self.pipeline(
input_data
)
except Exception as err:
logger.error(f'[{self.name}] Pipeline error: {err}')
return results

if score > float(self.threshold):
logger.info(f'[{self.name}] Detected prompt injection; score={score} threshold={self.threshold} id={scan_uuid}')
results.append(
ModelMatch(
model_name=self.model_name,
score=score,
threshold=self.threshold,
)
)
if len(hits) > 0:
for rec in hits:
if rec['label'] == 'INJECTION':
if rec['score'] > self.threshold:
logger.info(f'[{self.name}] Detected prompt injection; score={rec["score"]} threshold={self.threshold} id={scan_uuid}')
else:
logger.info(
f'[{self.name}] Detected prompt injection below threshold (may warrant manual review); \
score={rec["score"]} threshold={self.threshold} id={scan_uuid}'
)

results.append(
ModelMatch(
model_name=self.model_name,
score=rec['score'],
label=rec['label'],
threshold=self.threshold,
)
)

else:
logger.info(f'[{self.name}] No hits returned by model; id={scan_uuid}')

return results

0 comments on commit 5489c55

Please sign in to comment.