In [2]:
import sys
import time
sys.path.insert(0, './..')

import transformers
from transformers import AutoTokenizer

from utils import load_dataset, calculate_classification_metric, plot_inference_time_histogram, write_classification_metric

In [None]:
# Load the evaluation set from a path relative to the notebook's working
# directory. The evaluation loop further down unpacks every item of `dataset`
# as a (query, label) pair.
dataset = load_dataset('./../dataset/test.csv')

In [None]:
# Name of the checkpoint handed to `from_pretrained` for both the tokenizer
# and the causal LM.
# NOTE: the same variable `model` is rebound to the loaded model object in the
# loading cell, so this cell must be re-run before that one is executed again.
model = 'meta-llama/Llama-2-7b-chat-hf'

In [None]:
# At this point `model` still holds the checkpoint name (a string). The last
# line rebinds it to the loaded model object, so re-run the cell that assigns
# the checkpoint name before executing this cell a second time.
#
# The tokenizer is loaded without `device_map`: that argument controls where
# model weights are placed and has no meaning for a tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model)
# device_map='auto' delegates weight placement to the library.
model = transformers.AutoModelForCausalLM.from_pretrained(model, device_map='auto')

In [None]:
# Confusion-matrix tallies; presumably filled in by
# calculate_classification_metric after the loop (helper not visible here).
classification_counts = {
    'true_positives': 0,
    'false_positives': 0,
    'true_negatives': 0,
    'false_negatives': 0
}
inference_time = []  # wall-clock seconds spent in model.generate, one per query
labels = []          # ground-truth labels as ints
raw_queries = []     # the query strings, in dataset order
preds = []           # 0, 1, or 'N/A' when the answer could not be parsed

# Generation budget for the answer only. `max_new_tokens` is used instead of
# `max_length` because `max_length` also counts the prompt tokens, which left
# little or no room to generate for long queries.
MAX_NEW_TOKENS = 128

# Instruction appended after every "string: ..." line (few-shot and real query).
_INSTRUCTION = (
    ' Distinguish if the string causes SQL injection as a web security specialist.'
    ' If the string can cause SQL injection, return 1. If not, return 0.'
    ' Return for answer. Do not include an explanation in your answer. Return:'
)

# One-shot example (a classic tautology injection answered with "1"); it is
# identical for every query, so it is built once outside the loop.
_FEW_SHOT_MESSAGES = [
    {
        'role': 'user',
        'content': 'string: \' or \'a\' = \'a \n.' + _INSTRUCTION
    },
    {
        'role': 'assistant',
        'content': '1'
    },
]


def _build_prompt(query):
    """Return the chat messages: the one-shot example followed by `query`."""
    return _FEW_SHOT_MESSAGES + [
        {
            'role': 'user',
            'content': f'string: {query}.' + _INSTRUCTION
        }
    ]


def _parse_prediction(generated_text):
    """Map the model's answer text to 0, 1, or 'N/A'.

    Only the last non-whitespace character is inspected. A trailing '.' is
    counted as 0 (heuristic kept from the original notebook); anything other
    than '0' or '1' — including an empty answer — yields 'N/A'.
    """
    answer = generated_text.strip()
    if not answer:
        return 'N/A'
    last_char = answer[-1]
    if last_char == '.':
        return 0
    if last_char in ('0', '1'):
        return int(last_char)
    return 'N/A'


for query, label in dataset:
    prompt = _build_prompt(query)

    # The token ids must live on the same device as the model before generate().
    # assumes apply_chat_template returns a tensor of ids here — TODO confirm
    # for the installed transformers version.
    tokenized_prompt = tokenizer.apply_chat_template(
        prompt, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # perf_counter is monotonic and meant for measuring intervals.
    start_time = time.perf_counter()
    outputs = model.generate(
        tokenized_prompt,
        do_sample=True,
        max_new_tokens=MAX_NEW_TOKENS,
        temperature=0.1,
        pad_token_id=tokenizer.eos_token_id,
    )
    inference_time.append(time.perf_counter() - start_time)

    # generate() returns prompt + completion for causal LMs; decode only the
    # completion so the parser never looks at the prompt text.
    prompt_length = tokenized_prompt.shape[-1]
    response = tokenizer.decode(outputs[0, prompt_length:], skip_special_tokens=True)

    labels.append(int(label))
    raw_queries.append(query)
    preds.append(_parse_prediction(response))

calculate_classification_metric(classification_counts, labels, preds, raw_queries)
write_classification_metric(classification_counts, inference_time)
plot_inference_time_histogram(inference_time)