In [1]:
import os
import json

train_data_path = 'YOUR JSON FILE PATH'

# Load the training JSON file.
# The encoding is pinned to UTF-8 so the result does not depend on the platform's
# default locale encoding; load_squad_data reads the same file as UTF-8 as well.
with open(train_data_path, "r", encoding="utf-8") as f:
    train_data = json.load(f)

In [3]:
import torch
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizerFast, AdamW

In [4]:
# Helper that reads a SQuAD-format dataset from disk
def load_squad_data(filename):
    """Read a SQuAD-style JSON file and return the value stored under its "data" key.

    Args:
        filename: Path of the JSON file; it is read as UTF-8 text.

    Returns:
        Whatever the top-level "data" entry of the parsed JSON holds.

    Raises:
        KeyError: If the parsed JSON has no "data" key.
    """
    with open(filename, mode="r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload["data"]

In [None]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

In [6]:
# Fix the RNG seed to 0 so that repeated runs are reproducible
torch.manual_seed(0)
if device.type == "cuda":
    # On GPU, additionally seed every CUDA device and turn off cuDNN
    # benchmark mode while requesting deterministic cuDNN behaviour.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.cuda.manual_seed_all(0)

In [7]:
# Cap the CUDA caching allocator's split size at 128 MB to reduce memory fragmentation.
# PyTorch takes this limit from the PYTORCH_CUDA_ALLOC_CONF environment variable
# (option `max_split_size_mb`); `torch.backends.cuda` has no `max_split_size_bytes`
# setting, so assigning to it only creates an unused attribute.
# The variable has to be in place before the first CUDA allocation (ideally before
# torch is imported); setdefault keeps any value the user has already exported.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

In [8]:
# Read the SQuAD-format training set found at train_data_path
squad_data = load_squad_data(filename=train_data_path)

In [None]:
# Baseline model with the checkpoint's stock configuration; used for the
# "before fine-tuning" demonstration.
cpu_model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased")

# Model to fine-tune, widened to 1024 positions. Because the position-embedding matrix
# no longer matches the checkpoint's shape, ignore_mismatched_sizes=True makes
# from_pretrained skip it and leave the whole matrix freshly initialised.
gpu_model = DistilBertForQuestionAnswering.from_pretrained(
    "distilbert-base-uncased",
    max_length=1024,
    max_position_embeddings=1024,
    ignore_mismatched_sizes=True,
)

# Put the pretrained position embeddings back for the positions the checkpoint covers,
# so only the additional positions start from random values.
with torch.no_grad():
    pretrained_positions = cpu_model.distilbert.embeddings.position_embeddings.weight
    widened_positions = gpu_model.distilbert.embeddings.position_embeddings.weight
    widened_positions[: pretrained_positions.shape[0]].copy_(pretrained_positions)

gpu_model = gpu_model.to(device)

# `model_max_length` is the tokenizer's length-limit option; `max_length` is a per-call
# argument and has no effect when passed to from_pretrained.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased", model_max_length=1024)

In [14]:
# Sample question: a WAF payload with a script/alert injection (XSS example)
new_question = (
    '\n'
    'get /news/news_video/?pageid=3&mod="><script%20>alert in the case of the corresponding waf device payload, '
    'please write in one sentence or less, which type of attack it corresponds to.'
    '\n'
)

# Context passage that states the rule the answer should be extracted from
context = (
    "\n"
    "if any string appears between script and alert in waf payload, "
    "it is a type of xss (cross site scripting) attack."
    "\n"
)

In [15]:
from transformers import pipeline

# Question-answering pipeline around the baseline CPU model (no fine-tuning applied yet)
answering = pipeline(
    task='question-answering',
    model=cpu_model,
    tokenizer=tokenizer,
)

In [None]:
# Run the pipeline on the sample question/context pair and display its answer
result = answering(context=context, question=new_question)
result

In [17]:
def _locate_answer_tokens(offsets, sequence_ids, start_char, end_char):
    """Map a character span of the context onto (start_token, end_token) indices.

    Args:
        offsets: Per-token (char_start, char_end) pairs from the tokenizer.
        sequence_ids: Per-token sequence id (None for special/padding tokens,
            0 for the question, 1 for the context).
        start_char: Index of the answer's first character in the context.
        end_char: Index one past the answer's last character in the context.

    Returns:
        (start_token, end_token), or (0, 0) when the answer is not fully
        contained in the encoded context (for example after truncation).
    """
    token_start = None
    token_end = None
    for idx, (tok_start, tok_end) in enumerate(offsets):
        # Offsets of question tokens are relative to the question string, so
        # only context tokens may be compared with the answer's character span.
        if sequence_ids[idx] != 1:
            continue
        if tok_start <= start_char < tok_end:
            token_start = idx
        if tok_start < end_char <= tok_end:
            token_end = idx

    # A half-found span would be an inconsistent label; treat it as "no answer".
    if token_start is None or token_end is None:
        return 0, 0
    return token_start, token_end


# Convert a SQuAD-format dataset into DistilBERT input features
def convert_squad_data_to_features(squad_data, tokenizer, max_seq_length):
    """Tokenize every question/context pair and locate its answer span in tokens.

    Args:
        squad_data: Iterable of articles; each has "paragraphs", each paragraph
            has a "context" string and a list of "qas" entries with "question"
            and "answers" (only the first answer is used).
        tokenizer: Callable fast tokenizer returning "input_ids",
            "attention_mask", "token_type_ids", "offset_mapping" and exposing
            ``sequence_ids()`` on its result.
        max_seq_length: Length every encoded example is padded/truncated to.

    Returns:
        A list of tuples
        (input_ids, attention_mask, token_type_ids, start_token, end_token).
        Questions without an answer, or whose answer falls outside the encoded
        context, are labelled (0, 0).
    """
    features = []
    for article in squad_data:
        for paragraph in article["paragraphs"]:
            context = paragraph["context"]
            for qa in paragraph["qas"]:
                question = qa["question"]
                answers = qa.get("answers") or []

                # Encode question + context; anything beyond max_seq_length is cut off.
                encoded_dict = tokenizer(
                    question,
                    context,
                    max_length=max_seq_length,
                    padding="max_length",
                    truncation=True,
                    return_offsets_mapping=True,
                    return_token_type_ids=True,
                )

                if answers:
                    answer_text = answers[0]["text"]
                    start_position = answers[0]["answer_start"]
                    end_position = start_position + len(answer_text)
                    # Translate the character span into token indices.
                    token_start_position, token_end_position = _locate_answer_tokens(
                        encoded_dict["offset_mapping"],
                        encoded_dict.sequence_ids(),
                        start_position,
                        end_position,
                    )
                else:
                    # Unanswerable question: point both labels at the first token.
                    token_start_position, token_end_position = 0, 0

                features.append((
                    encoded_dict["input_ids"],
                    encoded_dict["attention_mask"],
                    encoded_dict["token_type_ids"],
                    token_start_position,
                    token_end_position,
                ))

    return features

In [None]:
# Maximum length of an input sequence, in tokens
max_seq_length = 1024

# Turn the dataset into model-ready features
features = convert_squad_data_to_features(
    squad_data=squad_data,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
)

In [None]:
# Display the configuration of the model that is going to be fine-tuned
gpu_model.config

In [None]:
# NOTE(review): the two config overrides below are disabled. The same values (1024)
# are already passed to from_pretrained when gpu_model is created, so these lines
# look redundant -- consider deleting them.
# gpu_model.config.max_length = max_seq_length
# gpu_model.config.max_position_embeddings = 1024
# Display the model configuration again
gpu_model.config

In [None]:
# Convert each column of the feature tuples into a torch.long tensor
(
    input_ids,
    attention_mask,
    token_type_ids,
    start_positions,
    end_positions,
) = (
    torch.tensor([feature[column] for feature in features], dtype=torch.long)
    for column in range(5)
)

In [None]:
def collate_fn(batch):
    """Stack a list of feature tuples into per-field tensors.

    Args:
        batch: Sequence of tuples laid out as
            (input_ids, attention_mask, token_type_ids, start_position, end_position).

    Returns:
        A 5-tuple of tensors in the same field order, each built with
        ``torch.tensor`` from that field's values across the batch.
    """
    def stack(field):
        # Gather one field from every sample and turn it into a tensor.
        return torch.tensor([sample[field] for sample in batch])

    return stack(0), stack(1), stack(2), stack(3), stack(4)

In [None]:
# Optimizer and learning rate.
# torch.optim.AdamW is used because the AdamW class shipped with transformers is
# deprecated in favour of the PyTorch implementation. eps and weight_decay are set to
# that class's defaults (1e-6 and 0.0) so the update rule stays the same.
optimizer = torch.optim.AdamW(gpu_model.parameters(), lr=2e-5, eps=1e-6, weight_decay=0.0)

In [None]:
from torch.utils.data import DataLoader

batch_size = 2
# Shuffle the training examples every epoch: the features are stored in dataset order
# (article by article), so unshuffled mini-batches would be strongly correlated.
train_dataloader = DataLoader(features, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

In [None]:
# Pull the first mini-batch from the training loader and display it as a sanity check
batch = next(iter(train_dataloader))
batch

In [None]:
# NOTE(review): the line below is disabled and refers to a name (`model`) that is not
# defined in the visible cells -- consider deleting it.
# model = model.to(device)
# Display the device the model to be fine-tuned currently lives on
gpu_model.device

In [None]:
from tqdm.auto import tqdm

# Training loop configuration
train_loss = []  # mean training loss of each epoch
num_epochs = 5

# Training mode only needs to be enabled once, not on every batch.
gpu_model.train()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for batch in tqdm(train_dataloader, desc=f"epoch {epoch + 1}/{num_epochs}"):
        # Move the batch to the model's device. token_type_ids (batch[2]) are
        # intentionally not passed to the model.
        inputs = {
            "input_ids": batch[0].to(device),
            "attention_mask": batch[1].to(device),
            "start_positions": batch[3].to(device),
            "end_positions": batch[4].to(device)
        }

        # Clear gradients before the forward/backward pass, so gradients left over
        # from an interrupted earlier run of this cell cannot leak into the first step.
        optimizer.zero_grad()

        # Forward pass; with start/end positions supplied the model returns its loss.
        outputs = gpu_model(**inputs)
        loss = outputs.loss

        # Backward pass and parameter update
        loss.backward()
        optimizer.step()

        # Accumulate the loss for the epoch average
        epoch_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
    train_loss.append(epoch_loss / max(len(train_dataloader), 1))

In [None]:
# Build the evaluation features and their loader.
# Note: this converts squad_data again, i.e. the same data that the training
# features were built from.
eval_features = convert_squad_data_to_features(
    squad_data=squad_data,
    tokenizer=tokenizer,
    max_seq_length=max_seq_length,
)
eval_dataloader = DataLoader(
    dataset=eval_features,
    batch_size=batch_size,
    collate_fn=collate_fn,
)

In [None]:
# Model evaluation: switch the model to evaluation mode
gpu_model.eval()

In [None]:
import torch.nn as nn

# Cross-entropy criterion (mean reduction, the default) for the start/end logits
loss_fn = nn.CrossEntropyLoss(reduction="mean")

In [None]:
# Mean evaluation loss over all evaluation batches.
# The loss is taken from the model output, exactly as in the training loop, so the
# two numbers are on the same scale: the model reports the average of the start and
# end cross-entropies, whereas summing the two terms by hand gives twice that value.
gpu_model.eval()  # make this cell safe to run on its own / out of order

eval_loss = 0.0
with torch.no_grad():
    for batch in eval_dataloader:
        inputs = {
            "input_ids": batch[0].to(device),
            "attention_mask": batch[1].to(device),
            "start_positions": batch[3].to(device),
            "end_positions": batch[4].to(device)
        }
        outputs = gpu_model(**inputs)
        eval_loss += outputs.loss.item()

# max(..., 1) avoids a ZeroDivisionError when the loader is empty.
eval_loss /= max(len(eval_dataloader), 1)
print("Eval Loss:", eval_loss)

In [None]:
# Second sample question: a WAF payload with a select/from pattern (SQL injection example)
new_question = (
    "\n"
    "get /user?select%20from in the case of the corresponding waf device payload, "
    "please write in one sentence or less, which type of attack it corresponds to."
    "\n"
)

# Context passage that states the rule the answer should be extracted from
context = (
    "\n"
    "if any string appears between select and from in waf payload, "
    "it is a type of sql injection attack."
    "\n"
)

In [None]:
device_index = 0  # index of the GPU device you want to use
# Use the requested GPU when CUDA is usable; otherwise fall back to the CPU so the
# inference cell still runs on machines without a GPU.
if torch.cuda.is_available():
    device = torch.device('cuda', device_index)
else:
    device = torch.device('cpu')

In [None]:
# Question-answering pipeline around gpu_model, placed on the selected device
answering = pipeline(
    task='question-answering',
    model=gpu_model,
    tokenizer=tokenizer,
    device=device,
)

# Ask the new question and display the answer
result = answering(context=context, question=new_question)
result