In [None]:
import sys
import collections
sys.path.insert(0, '../')

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchtext

from utils import load_dataset, train

In [None]:
# Select the compute device: CUDA wins when present, otherwise Apple MPS,
# otherwise the CPU. Both availability probes run up front, MPS first.
has_mps = torch.backends.mps.is_available()
has_cuda = torch.cuda.is_available()

if has_cuda:
    device = torch.device("cuda")
elif has_mps:
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# Load the dataset from CSV through the project helper in utils.
# NOTE(review): create_vocab and process_batch unpack every item as a
# (query, label) pair -- confirm that is what load_dataset returns.
dataset = load_dataset("../dataset/sqliv2.csv")
# Number of items in the dataset; handed to train() further down.
dataset_size = len(dataset)

In [None]:
# torchtext's "basic_english" tokenizer; create_vocab and process_str use it
# to split each query into tokens.
tokenizer = torchtext.data.get_tokenizer("basic_english")

def create_vocab(dataset):
    """Build a character-level vocabulary from the queries in ``dataset``.

    Each query is split with the module-level ``tokenizer`` and the characters
    of every resulting token are counted.

    Args:
        dataset: iterable of ``(query, label)`` pairs; the label is ignored.

    Returns:
        The vocabulary object produced by ``torchtext.vocab.vocab`` from the
        character counts, with ``min_freq=1``.
    """
    char_counts = collections.Counter()
    for query, _label in dataset:
        for token in tokenizer(query):
            # Feeding a string to Counter.update counts its characters one by one.
            char_counts.update(token)
    return torchtext.vocab.vocab(char_counts, min_freq=1)

# Character vocabulary built from the whole dataset.
vocab = create_vocab(dataset)
# Number of vocabulary entries; CGRUClassifier uses it to size its embedding table.
vocab_size = len(vocab)

In [None]:
# Keyword -> digit-string code table used by hash().
# "delect" is the key the original code used (a misspelling); it is kept so
# nothing that relied on it changes. "delete" is the spelling that actually
# occurs in SQL and maps to the same code.
_KEYWORD_CODES = {
    "and": "001",
    "or": "19",
    "xp_": "483",
    "substr": "1082",
    "utl": "292",
    "benchmark": "9282",
    "shutdown": "0902",
    "hex": "422",
    "sqlmap": "4990",
    "md5": "520",
    "select": "507",
    "union": "612",
    "drop": "629",
    "delete": "923",
    "delect": "923",
    "concat": "309",
    "orderby": "981",
    "exec": "015",
}


def hash(str):
    """Replace a known keyword token with its fixed digit-string code.

    The function name and parameter name are kept from the original for
    compatibility; note that they shadow the builtins ``hash`` and ``str``
    for every cell that runs after this one.

    Args:
        str: a single token. Matching is exact and case-sensitive.

    Returns:
        The code string from the keyword table when the token is listed
        there; otherwise the token itself, unchanged.
    """
    return _KEYWORD_CODES.get(str, str)


# Fixed number of character ids per encoded query; process_str truncates or
# pads every query to exactly this length.
query_length = 60
# Number of samples per mini-batch handed to the DataLoader.
batch_size = 100

def process_str(query):
    """Encode one query as a fixed-length list of character ids.

    The query is tokenized, each token is passed through the notebook's
    ``hash`` keyword substitution, and every character of the resulting tokens
    is looked up in the module-level ``vocab``. The id list is then cut or
    padded to exactly ``query_length`` entries.

    Args:
        query: the raw query string.

    Returns:
        A list of ``query_length`` integer ids.

    Raises:
        KeyError: if a character (including a digit introduced by the keyword
            substitution) is not present in the vocabulary.
    """
    # Fetch the string-to-index mapping once per call instead of once per
    # character; the lookups themselves are unchanged.
    stoi = vocab.get_stoi()

    query_chars = [
        stoi[char]
        for token in tokenizer(query)
        for char in hash(token)
    ]

    # Cut off anything beyond the maximum length.
    query_chars = query_chars[:query_length]

    # Fill the remainder up to the maximum length.
    # NOTE(review): the filler is the id of the real character "q", so padding
    # cannot be told apart from literal q's in a query -- confirm this is intended.
    missing = query_length - len(query_chars)
    if missing > 0:
        query_chars.extend([stoi["q"]] * missing)

    return query_chars

def process_batch(batch):
    """Collate a list of ``(query, label)`` samples into two tensors.

    Args:
        batch: iterable of ``(query, label)`` pairs.

    Returns:
        A tuple ``(queries, labels)`` of ``torch.LongTensor``: the encoded
        queries produced by ``process_str`` and the labels converted with
        ``int``.
    """
    # Encode and convert in a single pass so each sample is handled in order.
    encoded_pairs = [(process_str(query), int(label)) for query, label in batch]
    queries = [chars for chars, _ in encoded_pairs]
    labels = [label for _, label in encoded_pairs]

    return (torch.LongTensor(queries),
            torch.LongTensor(labels))

# Shuffled mini-batch loader over the dataset; process_batch turns each batch
# of samples into a pair of tensors.
train_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=process_batch, shuffle=True)

In [None]:
class CGRUClassifier(nn.Module):
    """Two parallel convolutions over embedded token ids, followed by a GRU.

    With no arguments the network is built exactly as before: an embedding of
    size 30 over the notebook-level ``vocab_size``, two 24-channel
    convolutions (window heights 2 and 4), a GRU with 24 hidden units,
    dropout 0.5 and a 2-way linear output.

    Args:
        num_embeddings: size of the embedding table. ``None`` (the default)
            falls back to the notebook-level ``vocab_size``.
        embedding_dim: embedding width; also the width of both conv kernels.
        conv_channels: output channels of each convolution.
        hidden_size: GRU hidden size.
        num_classes: number of output logits.
        dropout: dropout probability applied before the output layer.
    """

    def __init__(self, num_embeddings=None, embedding_dim=30, conv_channels=24,
                 hidden_size=24, num_classes=2, dropout=0.5):
        super().__init__()
        if num_embeddings is None:
            # Original behaviour: size the table from the notebook's vocabulary.
            num_embeddings = vocab_size

        # Layers are created in the same order as before so that parameter
        # initialisation under a fixed seed and the state_dict keys are unchanged.
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.conv1 = nn.Conv2d(1, conv_channels, kernel_size=(2, embedding_dim))
        # Height 4 with one row of padding on each side gives the same output
        # length as conv1 (sequence length - 1), so the branches can be joined.
        self.conv2 = nn.Conv2d(1, conv_channels, kernel_size=(4, embedding_dim), padding=(1, 0))
        self.gru = nn.GRU(2 * conv_channels, hidden_size, batch_first=True, bias=True)
        # NOTE: attribute name is misspelled ("dropuot"); kept as-is so any
        # existing reference to it keeps working.
        self.dropuot = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of token-id sequences.

        Args:
            x: integer tensor of token ids, expected shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        # (batch, seq, emb) -> (batch, 1, seq, emb): one input channel for Conv2d.
        embedded = self.embedding(x).unsqueeze(1)

        # Each kernel spans the full embedding width, so the last dimension
        # collapses to 1; drop it and put the sequence axis before channels.
        branch_a = F.relu(self.conv1(embedded)).squeeze(-1).permute(0, 2, 1)
        branch_b = F.relu(self.conv2(embedded)).squeeze(-1).permute(0, 2, 1)
        features = torch.cat((branch_a, branch_b), -1)

        # Keep only the final hidden state of the GRU.
        _, hidden = self.gru(features)
        hidden = F.relu(hidden.squeeze(0))
        hidden = self.dropuot(hidden)

        return self.fc(hidden)

# Instantiate the classifier and move it to the selected device.
network = CGRUClassifier().to(device)

In [None]:
# Single source of truth for the SGD learning rate, so the value recorded in
# the hyperparameter dict cannot drift from the one the optimizer really uses.
learning_rate = 0.01

hyperparameters = {
    "learning_rate": learning_rate,
    "epoch": 30,
    "optimizer": optim.SGD(network.parameters(), lr=learning_rate),
    "lr_scheduler": {
        "step_size": 5,
        "gamma": 0.5,
    },
    "loss_fn": nn.CrossEntropyLoss(),
}

# NOTE(review): the meaning of the positional 20, and how it relates to the
# "epoch": 30 entry above, is defined by utils.train -- confirm against that helper.
loss, accurancy = train(network, train_loader, device, dataset_size, 20, hyperparameters)
print(f"loss={loss}, accurancy={accurancy}")

# Save the trained weights (state_dict only) next to the notebook.
torch.save(network.state_dict(), 'model.pth')