In [2]:
import os
import json

# Path to the SQuAD-format training data; it is read by load_squad_data further down, so the
# old commented-out json.load snippet that duplicated that logic has been dropped.
train_data_path = 'chat_gpt_context/distilbert_squad_dataset.json'

In [1]:
import torch
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizerFast, AdamW

In [4]:
# Load a SQuAD-format JSON file and hand back its article list.
def load_squad_data(filename):
    """Read the UTF-8 JSON file at ``filename`` and return its top-level "data" entry.

    Raises KeyError when the document has no "data" key.
    """
    with open(filename, "r", encoding="utf-8") as handle:
        document = json.load(handle)
    return document["data"]

In [None]:
# Prefer the GPU when PyTorch can see one; the cell then displays the chosen device.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [6]:
# Seed PyTorch's RNGs so random initialisation and dropout are repeatable across runs.
torch.manual_seed(0)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(0)
    # Turn off cuDNN autotuning and request deterministic kernels (slower, but repeatable).
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [7]:
# Cap the CUDA caching allocator's block splitting (128 MB) to reduce memory fragmentation.
# The previous `torch.backends.cuda.max_split_size_bytes = ...` only attached an unused Python
# attribute; PyTorch reads this setting from the PYTORCH_CUDA_ALLOC_CONF environment variable,
# which must be present before the caching allocator initializes (i.e. before the first GPU
# tensor is created — the model is moved to the GPU only in a later cell).
# setdefault keeps any value the user already exported.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

In [8]:
# Load the SQuAD-format training articles (the tokenizer and models are loaded in the next cell).
squad_data = load_squad_data(filename=train_data_path)

In [None]:
# Stock checkpoint, used only for the baseline pipeline further down.
cpu_model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

# Model to fine-tune with 1024 positions. max_position_embeddings=1024 does not match the
# checkpoint's position table (presumably 512 — verify against the hub config), so
# ignore_mismatched_sizes=True re-initializes the position embeddings instead of loading them.
gpu_model = DistilBertForQuestionAnswering.from_pretrained(
    "distilbert-base-uncased",
    max_length=1024,
    max_position_embeddings=1024,
    ignore_mismatched_sizes=True,
).to(device)

# model_max_length is the tokenizer's actual length limit; the previous `max_length=1024`
# kwarg was not, so the tokenizer kept its default limit and would warn/truncate against it.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased", model_max_length=1024)

In [14]:
new_question = """
GET /admin/exec.php3?cmd=dir%20c:\\ HTTP/1.1¶Connection: Keep-Alive¶Content-Type: application/x-www-form-urlencoded¶Content-Length: 12¶Host: 10.10.123.123¶User-Agent: Mozilla/5.00 (Nikto/2.1.6) (Evasions:None) (Test:003219)¶¶dump_sql=foo In the case of the corresponding payload, please write in one sentence or less, which type it corresponds to.
"""

context = """
If any string appears between select and from or between case and when in payload, it is a type of SQLInjection.¶If any string appears between jndi and dap or between jndi and dns in payload, it is a type of JNDIInjection.¶If any string appears between script and alert in payload, it is a type of CrossSiteScripting.¶If any string appears between rm and rf or between wget and ttp or between chmod and 777 or between cd and tmp before HTTP/1. in payload, it is a type of CommandInjection.¶If any string appears user and agent before zgrab or nmap or dirbuster or ahrefsbot or baiduspider or mj12bot or petalbot or semrushbot or curl/ or masscan or sqlmap or urlgrabber in payload, it is a type of maliciousbot.¶If any string appears wp-login or wp-context or wp-include or wp-config in payload, it is a type of WordPress.¶If any string appears between etc and passwd or between cgi and bin or between cmd and open or between echo and shellshock or between php and echo or between admin and php or between setup and php in payload, it is a type of AttemptAccessAdminPage.

"""

In [None]:
# Character counts of the sample question and context, one per line.
for sample_text in (new_question, context):
    print(len(sample_text))

In [15]:
from transformers import pipeline

# Baseline: the untouched cpu_model wrapped in a question-answering pipeline.
# NOTE: the tokenizer object is shared with later cells that add tokens to it; only gpu_model's
# embeddings are resized, so re-running this baseline after that point is not meaningful.
answering = pipeline(task='question-answering', model=cpu_model, tokenizer=tokenizer)

In [None]:
# Ask the baseline pipeline about the sample payload and display its output.
result = answering(context=context, question=new_question)
result

In [17]:
# Convert SQuAD-format records into fixed-length DistilBERT question-answering features.
def _context_sequence_ids(encoded):
    """Return per-token segment ids for one encoded (question, context) pair; 1 marks context.

    Uses ``BatchEncoding.sequence_ids()`` when available (fast tokenizers: None for special
    tokens, 0 for question, 1 for context), otherwise falls back to ``token_type_ids``.
    """
    if callable(getattr(encoded, "sequence_ids", None)):
        return encoded.sequence_ids()
    return encoded["token_type_ids"]


def convert_squad_data_to_features(squad_data, tokenizer, max_seq_length, verbose=False):
    """Tokenize every (question, context) pair and locate its answer span in token space.

    Args:
        squad_data: list of SQuAD "data" articles; each has "paragraphs", each paragraph has a
            "context" string and a "qas" list of {"id", "question", "answers"} dicts.
        tokenizer: tokenizer callable supporting ``return_offsets_mapping`` and
            ``return_token_type_ids`` (e.g. DistilBertTokenizerFast).
        max_seq_length: every encoded pair is padded/truncated to exactly this many tokens.
        verbose: print one debugging line per example (off by default — the old unconditional
            prints flooded the notebook output).

    Returns:
        list of ``(input_ids, attention_mask, token_type_ids, start_token, end_token)`` tuples.
        If an example has no answer, or its answer was truncated out of the context window,
        both token positions are 0 (the [CLS] slot).
    """
    features = []
    for article in squad_data:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                question = qa["question"]
                answers = qa.get("answers") or []

                encoded_dict = tokenizer(question, context, max_length=max_seq_length, padding="max_length",
                                         # truncate pairs longer than max_seq_length tokens
                                         truncation=True, return_offsets_mapping=True, return_token_type_ids=True)

                token_start_position = 0
                token_end_position = 0
                if answers:
                    answer_text = answers[0]["text"]
                    start_position = answers[0]["answer_start"]
                    end_position = start_position + len(answer_text)  # exclusive char index

                    # Offsets of question tokens are relative to the question string, so only
                    # context tokens (segment id 1) may be compared with these char positions.
                    sequence_ids = _context_sequence_ids(encoded_dict)
                    for i, ((tok_start, tok_end), seq_id) in enumerate(
                            zip(encoded_dict["offset_mapping"], sequence_ids)):
                        if seq_id != 1:
                            continue
                        if tok_start <= start_position < tok_end:
                            token_start_position = i
                        if tok_start < end_position <= tok_end:
                            token_end_position = i

                    # A partially visible or inverted span is unusable: point both at [CLS].
                    if (not token_start_position or not token_end_position
                            or token_end_position < token_start_position):
                        token_start_position = token_end_position = 0

                if verbose:
                    print(qa.get("id"), question, answers[:1], token_start_position, token_end_position)

                features.append((encoded_dict["input_ids"], encoded_dict["attention_mask"],
                                 encoded_dict["token_type_ids"], token_start_position, token_end_position))

    return features

In [None]:
ips_query = """

    SELECT

        IF(INT(RLIKE(payload, 'VCAvY2dpLWJpbi9waHA0') )>0
        OR INT(RLIKE(payload, 'L2NnaS1iaW4v') )>0
        OR INT(RLIKE(payload, 'IC9jZ2ktYmlu') )>0
        OR INT(RLIKE(payload, 'UE9TVCAvY2dpLWJpbi9waHA/') )>0
        OR INT(RLIKE(payload, 'VCAvY2dpLWJpbi9w') )>0
        OR INT(RLIKE(payload, 'ZGllKEBtZDU=') )>0
        OR INT(RLIKE(payload, 'L2FueWZvcm0yL3VwZGF0ZS9hbnlmb3JtMi5pbmk=') )>0
        OR INT(RLIKE(payload, 'Ly5iYXNoX2hpc3Rvcnk=') )>0
        OR INT(RLIKE(payload, 'L2V0Yy9wYXNzd2Q=') )>0
        OR INT(RLIKE(payload, 'QUFBQUFBQUFBQQ==') )>0
        OR INT(RLIKE(payload, 'IG1hc3NjYW4vMS4w') )>0
        OR INT(RLIKE(payload, 'd2dldA==') )>0
        OR INT(RLIKE(payload, 'MjB3YWl0Zm9yJTIwZGVsYXklMjAn') )>0
        OR INT(RLIKE(payload, 'V0FJVEZPUiBERUxBWQ==') )>0
        OR INT(RLIKE(payload, 'ZXhlYw==') )>0
        OR INT(RLIKE(payload, 'Tm9uZQ==') )>0
        OR INT(RLIKE(payload, 'OyB3Z2V0') )>0
        OR INT(RLIKE(payload, 'VXNlci1BZ2VudDogRGlyQnVzdGVy') )>0
        OR INT(RLIKE(payload, 'cGhwIGRpZShAbWQ1') )>0
        OR INT(RLIKE(payload, 'JTI4U0VMRUNUJTIw') )>0
                ,1, 0) AS ips_00001_payload_base64,

        IF(INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'select(.*?)from') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'select(.*?)count') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'select(.*?)distinct') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'union(.*?)select') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'select(.*?)extractvalue(.*?)xmltype') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'from(.*?)generate(.*?)series') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'from(.*?)group(.*?)by') )>0
                ,1, 0) AS ips_00001_payload_sql_comb_01,

        IF(INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'case(.*?)when') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'then(.*?)else') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'like') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'sleep') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'delete') )>0
                ,1, 0) AS ips_00001_payload_sql_comb_02,

        IF(INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'waitfor(.*?)delay') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'db(.*?)sql(.*?)server') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'cast(.*?)chr') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'upper(.*?)xmltype') )>0
                ,1, 0) AS ips_00001_payload_sql_comb_03,

        IF(INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'script(.*?)alert') )>0
        OR INT(RLIKE(LOWER(payload), 'eval') )>0
                ,1, 0) AS ips_00001_payload_xss_comb_01,

        IF(INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'wget(.*?)ttp') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'chmod(.*?)777') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'rm(.*?)rf') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[0],  'cd(.*?)tmp') )>0
                ,1, 0) AS ips_00001_payload_cmd_comb_01,

        IF(INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'jndi(.*?)dap') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '),'jndi(.*?)dns') )>0
                ,1, 0) AS ips_00001_payload_log4j_comb_01,

        IF(INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'etc(.*?)passwd') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'document(.*?)createelement') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'cgi(.*?)bin') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'document(.*?)forms') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'document(.*?)location') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'fckeditor(.*?)filemanager') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'manager(.*?)html') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'current_config(.*?)passwd') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'currentsetting(.*?)htm') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'well(.*?)known') )>0
                ,1, 0) AS ips_00001_payload_word_comb_01,

        IF(INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'bash(.*?)history') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'apache(.*?)struts') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'document(.*?)open') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'backup(.*?)sql') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'robots(.*?)txt') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'sqlexec(.*?)php') )>0
        OR INT(RLIKE(LOWER(payload), 'htaccess') )>0
        OR INT(RLIKE(LOWER(payload), 'htpasswd') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'cgi(.*?)cgi') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'api(.*?)ping') )>0
                ,1, 0) AS ips_00001_payload_word_comb_02,

        IF(INT(RLIKE(LOWER(payload), 'aaaaaaaaaa') )>0
        OR INT(RLIKE(LOWER(payload), 'cacacacaca') )>0
        OR INT(RLIKE(LOWER(payload), 'mozi[\\.]') )>0
        OR INT(RLIKE(LOWER(payload), 'bingbot') )>0
        OR INT(RLIKE(LOWER(payload), 'md5') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'count(.*?)cgi(.*?)http') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'this(.*?)program(.*?)can') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'get(.*?)ping') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'msadc(.*?)dll(.*?)http') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'filename(.*?)asp') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'filename(.*?)jsp') )>0
        OR INT(RLIKE(LOWER(payload), 'powershell'))>0
        OR INT(RLIKE(LOWER(payload), '[\\.]env'))>0
                ,1, 0) AS ips_00001_payload_word_comb_03,

        IF(INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'wp-login') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'wp-content') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'wp-include') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'wp-config') )>0
                ,1, 0) AS ips_00001_payload_wp_comb_01,

        IF(INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'cmd(.*?)open') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'echo(.*?)shellshock') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'php(.*?)echo') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'admin(.*?)php') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'script(.*?)setup(.*?)php') )>0
        OR INT(RLIKE(LOWER(payload), 'phpinfo') )>0
        OR INT(RLIKE(LOWER(payload), 'administrator') )>0
        OR INT(RLIKE(LOWER(payload), 'phpmyadmin') )>0
        OR INT(RLIKE(LOWER(payload), 'access') )>0
        OR INT(RLIKE(LOWER(payload), 'mdb') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'wise(.*?)survey(.*?)admin') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'admin(.*?)serv(.*?)admpw') )>0
        OR INT(RLIKE(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'php(.*?)create(.*?)function') )>0
                ,1, 0) AS ips_00001_payload_word_comb_04,

        IF(INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)zgrab') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)nmap') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)dirbuster') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)ahrefsbot') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)baiduspider') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)mj12bot') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)petalbot') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)curl/') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)semrushbot') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)masscan') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)sqlmap') )>0
        OR INT(RLIKE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'http/1.', 2)[1],  'user(.*?)agent(.*?)urlgrabber(.*?)yum') )>0
                ,1, 0) AS ips_00001_payload_useragent_comb,

        (SIZE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'get(.*?)http/1.')) -1)
            + (SIZE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'post(.*?)http/1.')) -1)
        + (SIZE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'head(.*?)http/1.')) -1)
        + (SIZE(SPLIT(REGEXP_REPLACE(LOWER(payload), '\\n|\\r|\\t', ' '), 'option(.*?)http/1.')) -1)
        AS ips_00001_payload_whitelist
    FROM table

"""

In [None]:
import re

# Slice the rule section out of ips_query, from the ips_00001_payload_base64 alias through the
# ips_00001_payload_useragent_comb alias (re.S lets '.' cross newlines). Then strip the repeated
# newline-pattern literal and the SPLIT(..., 'http/1.', 2) arguments (plain str.replace, not
# regex) so that, apart from empty/space leftovers, only rule literals remain quoted.
attack_query = (
    re.findall(r'ips_00001_payload_base64.*?ips_00001_payload_useragent_comb', ips_query, re.S)[0]
    .replace('\\n|\\r|\\t', '')
    .replace("'http/1.', 2", '')
)

# Every remaining single-quoted literal is a rule pattern, except the '' and ' ' leftovers of
# the REGEXP_REPLACE(..., '', ' ') calls.
ai_field = [
    literal
    for literal in re.findall(r'\'(.*?)\'', attack_query)
    if literal not in ('', ' ')
]

In [None]:
# Inspect the rule literals extracted from ips_query.
ai_field

In [None]:
# Expand each rule literal into candidate vocabulary terms:
#   * every literal fragment between '(.*?)' wildcards, for any number of wildcards — the old
#     per-count comprehensions lost the third fragment of 4-part patterns such as
#     'user(.*?)agent(.*?)urlgrabber(.*?)yum' (i.e. 'urlgrabber');
#   * each whole literal with the escaped-dot class '[\.]' written as a plain '.'
#     (e.g. 'mozi[\.]' -> 'mozi.').
# Terms still containing regex syntax are dropped, duplicates removed, and the result sorted so
# its order (and therefore the ids later assigned to added tokens) is stable across kernel
# restarts; list(set(...)) order changes with Python's per-process string hash seed.
fragment_terms = [part for field in ai_field for part in field.split('(.*?)')]
dotted_terms = [field.replace('[\\.]', '.') for field in ai_field]
ai_list_split = sorted({
    term
    for term in fragment_terms + dotted_terms
    if term and '(.*?)' not in term and '[\\.]' not in term
})

In [None]:
# Number of distinct candidate terms derived from the SQL rules.
len(ai_list_split)

In [None]:
# Inspect the rule-derived candidate terms.
ai_list_split

In [None]:
import pandas as pd

# CSV version of the dataset (same basename as train_data_path — presumably the same records;
# verify). Its distinct lower-cased answers also become candidate tokens.
df = pd.read_csv('chat_gpt_context/distilbert_squad_dataset.csv')
# dropna(): a blank answer cell would otherwise survive as float NaN, and astype(str) keeps
# numeric-looking answers (e.g. 777) from turning into NaN under .str.lower(); either would
# later reach tokenizer.add_tokens, which expects strings.
answer_list = df['answer'].dropna().astype(str).str.lower().unique().tolist()
answer_list

In [None]:
# Same count as shown a few cells earlier; repeated here next to len(answer_list) for comparison.
len(ai_list_split)

In [None]:
# Number of distinct lower-cased answers read from the CSV.
len(answer_list)

In [None]:
# Union of rule-derived terms and dataset answers. Only non-empty strings are kept (a NaN from
# the CSV would break add_tokens later) and the list is sorted so its order is reproducible.
ai_list = sorted({term for term in ai_list_split + answer_list if isinstance(term, str) and term})
len(ai_list)

In [None]:
# Keep only the candidate terms that the tokenizer's vocabulary does not already contain.
known_tokens = tokenizer.vocab.keys()
new_tokens = {term for term in ai_list if term not in known_tokens}
len(new_tokens)

In [None]:
# Inspect the terms that will be added to the tokenizer.
new_tokens

In [None]:
# Register every missing term as a whole token. sorted() makes the ids the new tokens receive
# independent of set iteration order, which otherwise changes between kernel sessions.
# The cell displays the number of tokens actually added.
tokenizer.add_tokens(sorted(new_tokens))

In [None]:
# Grow gpu_model's token-embedding matrix to cover the newly added tokens. The new rows are
# freshly initialized rather than pretrained, so they only become meaningful after fine-tuning.
gpu_model.resize_token_embeddings(new_num_tokens=len(tokenizer))

In [None]:
# Maximum encoded (question + context) length in tokens; gpu_model was loaded with 1024 positions.
max_seq_length = 1024

# Turn the SQuAD records into padded model-ready feature tuples.
features = convert_squad_data_to_features(
    squad_data,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
)

In [None]:
# Inspect gpu_model's configuration (max_position_embeddings was set to 1024 at load time).
gpu_model.config

In [None]:
# Leftover experiment, kept commented out: both values were already passed as from_pretrained
# kwargs when gpu_model was loaded, so no manual override is needed.
# gpu_model.config.max_length = max_seq_length
# gpu_model.config.max_position_embeddings = 1024
gpu_model.config

In [None]:
# Stack each feature column into a LongTensor for inspection. The DataLoader below consumes
# `features` directly and tensorizes per batch in collate_fn, so these appear unused later.
input_ids, attention_mask, token_type_ids, start_positions, end_positions = (
    torch.tensor([feature[column] for feature in features], dtype=torch.long)
    for column in range(5)
)

In [None]:
def collate_fn(batch):
    """Transpose a list of feature tuples into a tuple of five batched tensors.

    Each element of ``batch`` is ``(input_ids, attention_mask, token_type_ids,
    start_position, end_position)``; the result keeps that field order, with each field's
    per-example values stacked along dim 0.
    """
    return tuple(
        torch.tensor([sample[field] for sample in batch])
        for field in range(5)
    )

In [None]:
# Optimizer and learning rate. transformers.AdamW is deprecated in favour of
# torch.optim.AdamW; eps=1e-6 and weight_decay=0.0 keep the old transformers defaults
# (torch's own defaults are eps=1e-8, weight_decay=0.01).
optimizer = torch.optim.AdamW(gpu_model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0)

In [None]:
from torch.utils.data import DataLoader

batch_size = 2
# Shuffle training examples every epoch: features are built article -> paragraph -> question in
# file order, so unshuffled batches are highly correlated. The seeded generator keeps the
# shuffle order reproducible across restarts.
train_dataloader = DataLoader(
    features,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    generator=torch.Generator().manual_seed(0),
)

In [None]:
# Peek at one collated batch to sanity-check tensor shapes and dtypes.
# Note: the training loop below reuses (and overwrites) the name `batch`.
batch = next(iter(train_dataloader))
batch

In [None]:
# Confirm which device gpu_model's parameters are on (it was moved with .to(device) when it
# was loaded, so a separate `model = model.to(device)` step is unnecessary).
gpu_model.device

In [None]:
from tqdm.auto import tqdm

# Fine-tune gpu_model; train_loss collects the mean batch loss of each epoch.
train_loss = []
num_epochs = 5

model_device = gpu_model.device  # send each batch wherever the model actually lives
gpu_model.train()  # enable dropout once, instead of on every batch

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in tqdm(train_dataloader, desc=f"epoch {epoch + 1}/{num_epochs}"):
        # batch[2] (token_type_ids) is not passed: DistilBertForQuestionAnswering's forward
        # does not take token_type_ids.
        inputs = {
            "input_ids": batch[0].to(model_device),
            "attention_mask": batch[1].to(model_device),
            "start_positions": batch[3].to(model_device),
            "end_positions": batch[4].to(model_device),
        }

        # Forward pass; with start/end positions supplied the QA head returns a loss.
        outputs = gpu_model(**inputs)
        loss = outputs.loss

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss += loss.item()

    # max(..., 1) guards against an empty dataloader.
    train_loss.append(epoch_loss / max(len(train_dataloader), 1))
    print(f"epoch {epoch + 1}: mean train loss {train_loss[-1]:.4f}")

In [None]:
# Evaluation data. No held-out split exists yet, so this re-reads the *training* file: the
# resulting eval loss measures fit, not generalization. Point eval_data_path at a separate
# SQuAD-format file to get a real validation score.
eval_data_path = train_data_path  # TODO: replace with a held-out SQuAD-format JSON file
eval_squad_data = load_squad_data(eval_data_path)
eval_features = convert_squad_data_to_features(eval_squad_data, tokenizer, max_seq_length)
eval_dataloader = DataLoader(eval_features, batch_size=batch_size, collate_fn=collate_fn)

In [None]:
# Model evaluation: switch gpu_model to inference mode (disables dropout). .eval() returns the
# model itself, so this cell also displays the model's repr.
gpu_model.eval()

In [None]:
import torch.nn as nn

# Classification loss over token positions, used to score start/end logits during evaluation.
loss_fn = nn.CrossEntropyLoss(reduction="mean")

In [None]:
# Mean evaluation loss over eval_dataloader on the same scale as train_loss: the Hugging Face QA
# heads report (start_loss + end_loss) / 2 as their loss, so the two cross-entropies are
# averaged here too (the previous sum reported twice the training-scale value).
gpu_model.eval()  # guarantee dropout is off even if this cell is re-run right after training
model_device = gpu_model.device
eval_loss = 0.0
with torch.no_grad():
    for batch in eval_dataloader:
        inputs = {
            "input_ids": batch[0].to(model_device),
            "attention_mask": batch[1].to(model_device),
        }
        batch_start = batch[3].to(model_device)
        batch_end = batch[4].to(model_device)

        outputs = gpu_model(**inputs)
        batch_loss = (loss_fn(outputs.start_logits, batch_start)
                      + loss_fn(outputs.end_logits, batch_end)) / 2

        eval_loss += batch_loss.item()

eval_loss /= max(len(eval_dataloader), 1)  # guard against an empty dataloader
print("Eval Loss:", eval_loss)

In [None]:
new_question = """
GET /admin/exec.php3?cmd=dir%20c:\\ HTTP/1.1¶Connection: Keep-Alive¶Content-Type: application/x-www-form-urlencoded¶Content-Length: 12¶Host: 10.10.123.123¶User-Agent: Mozilla/5.00 (Nikto/2.1.6) (Evasions:None) (Test:003219)¶¶dump_sql=foo In the case of the corresponding payload, please write in one sentence or less, which type it corresponds to.
"""

context = """
If any string appears between select and from or between case and when in payload, it is a type of SQLInjection.¶If any string appears between jndi and dap or between jndi and dns in payload, it is a type of JNDIInjection.¶If any string appears between script and alert in payload, it is a type of CrossSiteScripting.¶If any string appears between rm and rf or between wget and ttp or between chmod and 777 or between cd and tmp before HTTP/1. in payload, it is a type of CommandInjection.¶If any string appears user and agent before zgrab or nmap or dirbuster or ahrefsbot or baiduspider or mj12bot or petalbot or semrushbot or curl/ or masscan or sqlmap or urlgrabber in payload, it is a type of maliciousbot.¶If any string appears wp-login or wp-context or wp-include or wp-config in payload, it is a type of WordPress.¶If any string appears between etc and passwd or between cgi and bin or between cmd and open or between echo and shellshock or between php and echo or between admin and php or between setup and php in payload, it is a type of AttemptAccessAdminPage.

"""

In [None]:
# Device for the inference cells below: the chosen GPU when CUDA is available, otherwise the
# CPU, so the notebook still runs on CPU-only machines instead of failing on first GPU use.
device_index = 0  # index of the GPU device you want to use
device = torch.device('cuda', device_index) if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Inference through the Hugging Face pipeline API, using the fine-tuned model and the
# extended tokenizer.
answering = pipeline(
    task='question-answering',
    model=gpu_model,
    tokenizer=tokenizer,
    device=device,
)
result = answering(question=new_question, context=context)
result

In [None]:
# Manual inference with plain PyTorch: run the model and decode the highest-scoring span.
gpu_model.eval()
# Move the encoded inputs to wherever gpu_model lives; leaving them on the CPU while the model
# is on CUDA raises a device-mismatch error.
inputs = tokenizer(new_question, context, return_tensors="pt").to(gpu_model.device)
with torch.no_grad():
    outputs = gpu_model(**inputs)

answer_start_index = int(torch.argmax(outputs.start_logits))
answer_end_index = int(torch.argmax(outputs.end_logits))
# Start and end are picked independently; clamp so an end before the start still yields the
# start token instead of an empty span.
answer_end_index = max(answer_end_index, answer_start_index)

predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
tokenizer.decode(predict_answer_tokens.tolist())