In [2]:
import re
import sys
import time
sys.path.insert(0, './..')

import transformers
from transformers import AutoTokenizer

from utils import load_dataset, calculate_classification_metric, plot_inference_time_histogram, write_classification_metric

In [None]:
# Evaluation split; the classification loop further down iterates it as (query, label) pairs.
test_csv_path = './../dataset/test.csv'
dataset = load_dataset(test_csv_path)

In [None]:
# Checkpoint identifier handed to `from_pretrained` in the next cell.
# NOTE: the next cell rebinds this same name `model` to the loaded model object,
# so after that cell has run `model` is no longer a string.
model = 'meta-llama/Llama-2-7b-chat-hf'

In [None]:
# `model` is the checkpoint id (str) the first time this cell runs, but this cell
# rebinds it to the loaded model object. Resolve the id either way so the cell can
# be re-run without restarting the kernel.
checkpoint = model if isinstance(model, str) else model.config.name_or_path

# `device_map` is a model-loading option; it is not needed for the tokenizer.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# device_map='auto' lets the library decide where to place the model weights.
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, device_map='auto')

In [None]:
# Confusion-matrix counters; passed to calculate_classification_metric / write_classification_metric below.
classification_counts = {
    'true_positives': 0,
    'false_positives': 0,
    'true_negatives': 0,
    'false_negatives': 0
}
# Parallel per-sample records: index k of each list refers to the same dataset row.
inference_time = []
labels = []
raw_queries = []
preds = []

config = transformers.GenerationConfig(do_sample=True, max_length=500, temperature=0.1, pad_token_id=tokenizer.eos_token_id, max_time=5)

# One-shot conversation: the instructions plus a worked example (user turn + expected
# assistant answer). It is never mutated; each query gets its own copy with one extra user turn.
prompt = [
    {
        'role': 'user',
        'content': '\
            <string>\' or \'a\' = \'a</string>\
            Distinguish if the string causes SQL injection as a web security specialist.\
            If the string can cause SQL injection, return 1. If not, return 0.\
            Return only result for answer. Do not contain explanation of result.\
            Answer in the form of <response>1</response> or <response>0</response>.'
    },
    {
        'role': 'assistant',
        'content': '<response>1</response>'
    }
]

response_pattern = re.compile(r'<response>(\d+)</response>')

for query, label in dataset:
    # Build a fresh message list instead of append/pop on the shared prompt, so an
    # exception in one iteration cannot leave a stray user turn behind.
    messages = prompt + [
        {
            'role': 'user',
            'content': f'<string>{query}</string>'
        }
    ]
    tokenized_prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
    # The model was loaded with device_map='auto'; inputs must live on the model's device.
    tokenized_prompt = tokenized_prompt.to(model.device)

    # perf_counter is monotonic, which makes it the right clock for measuring intervals.
    start_time = time.perf_counter()
    outputs = model.generate(tokenized_prompt, generation_config=config)
    inference_time.append(time.perf_counter() - start_time)

    # For a causal LM, generate() returns the prompt tokens followed by the continuation.
    # The prompt itself contains '<response>1</response>', so decoding the full sequence
    # would make the regex match the example and predict 1 for every query.
    # Decode only the newly generated tokens.
    generated_tokens = outputs[:, tokenized_prompt.shape[-1]:]
    response = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    # re.search returns None when the model did not answer in the requested format
    # (e.g. generation was cut off by max_time); record a single 'N/A' for that sample
    # so preds stays aligned with labels and raw_queries.
    match = response_pattern.search(response)
    result = int(match.group(1)) if match else 'N/A'

    labels.append(int(label))
    raw_queries.append(query)
    preds.append(result)

calculate_classification_metric(classification_counts, labels, preds, raw_queries)
write_classification_metric(classification_counts, inference_time)
plot_inference_time_histogram(inference_time)