In [None]:
import random
import sys
sys.path.insert(0, './..')

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BertForSequenceClassification
import pytorch_lightning as pl

from utils import load_dataset, train

In [None]:
# Load the raw rows through the project helper `load_dataset` (from utils).
# NOTE(review): the file name suggests SQL-injection samples — not verified here.
dataset_path = "../dataset/sqli1.csv"
dataset = load_dataset(dataset_path)
dataset_size = len(dataset)

In [None]:
# Each row is indexed as (query_text, label). The classifier below is built with
# num_labels=2, so every label must be 0 or 1; check it here rather than failing
# later inside the loss computation.
queries = [row[0] for row in dataset]
labels = [int(row[1]) for row in dataset]
assert len(queries) == len(labels) == dataset_size
assert set(labels) <= {0, 1}, f"unexpected label values: {sorted(set(labels) - {0, 1})}"

In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# Tokenizer and a BERT classifier with a two-way output head, moved to `device`.
pretrained_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
model = BertForSequenceClassification.from_pretrained(
    pretrained_name,
    num_labels=2,
).to(device)

In [None]:
# Baseline accuracy of the classifier before any fine-tuning.
# Inference is done in mini-batches (one forward pass over the whole dataset can
# exhaust device memory), with truncation so over-long queries cannot exceed the
# model's maximum input length.
# `labels` stays a plain list; the device tensor gets its own name so this cell is
# idempotent and later cells still see Python ints.
eval_batch_size = 64
labels_tensor = torch.tensor(labels, device=device)

logits_chunks = []
model.eval()
with torch.no_grad():
    for start in range(0, len(queries), eval_batch_size):
        batch_encoding = tokenizer(
            queries[start:start + eval_batch_size],
            padding='longest',
            truncation=True,
            return_tensors='pt',
        )
        batch_encoding = {k: v.to(device) for k, v in batch_encoding.items()}
        logits_chunks.append(model(**batch_encoding).logits)

scores = torch.cat(logits_chunks)
labels_predicted = scores.argmax(-1)
num_correct = (labels_predicted == labels_tensor).sum().item()
accuracy = num_correct / labels_tensor.size(0)

print("# scores:")
print(scores.size())
print("# predicted labels:")
print(labels_predicted)
print("# accuracy:")
print(accuracy)

In [None]:
max_length = 128
# Seed for the train/val/test split and the training-batch order, so reruns
# produce the same split (the unseeded shuffle gave a different split every run).
split_seed = 42

# Tokenise all queries in one batched call; each example gets the same tensors as
# tokenising it alone, but this is much faster than a per-query loop.
encodings = tokenizer(
    queries,
    max_length=max_length,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
)

dataset_for_loader = []
for i in range(dataset_size):
    encoding = {k: v[i] for k, v in encodings.items()}
    # int(...) accepts a Python int or a 0-d tensor, so this works even if an
    # earlier cell rebound `labels` to a tensor.
    encoding['labels'] = torch.tensor(int(labels[i]))
    dataset_for_loader.append(encoding)

random.Random(split_seed).shuffle(dataset_for_loader)

# 60% train / 20% validation / remainder test.
n = len(dataset_for_loader)
n_train = int(0.6*n)
n_val = int(0.2*n)

dataset_train = dataset_for_loader[:n_train]
dataset_val = dataset_for_loader[n_train:n_train+n_val]
dataset_test = dataset_for_loader[n_train+n_val:]

dataloader_train = DataLoader(
    dataset_train,
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(split_seed),
)
dataloader_val = DataLoader(dataset_val, batch_size=256)
dataloader_test = DataLoader(dataset_test, batch_size=256)

In [None]:
class BertForSequenceClassification_pl(pl.LightningModule):
    """Lightning wrapper that fine-tunes a Hugging Face sequence classifier.

    Args:
        model: the wrapped classifier. Its forward is called with the batch dict
            and must return an object with ``.loss`` (when ``labels`` is in the
            batch) and ``.logits``.
        lr: learning rate for the Adam optimizer.

    Batches are dicts of tensors that include a ``labels`` entry.

    Note: ``save_hyperparameters()`` records both constructor arguments, including
    the whole ``model`` object, in the checkpoint. This is presumably what allows
    ``load_from_checkpoint(path)`` to be called without passing ``model`` again,
    at the cost of larger checkpoint files.
    """

    def __init__(self, model, lr):
        super().__init__()
        self.save_hyperparameters()

        self.bert_sc = model

    def training_step(self, batch, batch_idx):
        """Forward with labels and return the model's loss for backprop."""
        output = self.bert_sc(**batch)
        loss = output.loss
        # Pass batch_size explicitly; inferring it from a dict batch is ambiguous.
        self.log('train_loss', loss, batch_size=batch['labels'].size(0))
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss that ModelCheckpoint monitors."""
        output = self.bert_sc(**batch)
        self.log(
            'val_loss',
            output.loss,
            batch_size=batch['labels'].size(0),
            prog_bar=True,
        )

    def test_step(self, batch, batch_idx):
        """Log per-batch accuracy, which Lightning averages over the test epoch."""
        labels = batch['labels']
        # Feed only the model inputs; the batch dict itself is left unmodified.
        inputs = {k: v for k, v in batch.items() if k != 'labels'}
        output = self.bert_sc(**inputs)
        labels_predicted = output.logits.argmax(-1)
        # Keep the metric as a tensor (no per-step host sync). With an explicit
        # batch_size, the epoch mean is weighted correctly even for a smaller
        # final batch.
        accuracy = (labels_predicted == labels).float().mean()
        self.log('accuracy', accuracy, batch_size=labels.size(0))

    def configure_optimizers(self):
        """Adam over all parameters at the learning rate given to __init__."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [None]:
# Keep only the single checkpoint with the lowest val_loss (weights only) under
# model/, and train for a fixed number of epochs.
checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath='model/',
    monitor='val_loss',
    mode='min',
    save_weights_only=True,
    save_top_k=1,
)

trainer = pl.Trainer(callbacks=[checkpoint], max_epochs=10)

In [None]:
# Keep `model` bound to the Hugging Face classifier. Rebinding it to the Lightning
# wrapper would make a re-run of this cell wrap a LightningModule inside another one.
learning_rate = 1e-5
model_pl = BertForSequenceClassification_pl(model, lr=learning_rate)
trainer.fit(model_pl, dataloader_train, dataloader_val)

In [None]:
# Show the logged training/validation curves inline. `./` is scanned recursively.
# NOTE(review): assumes Lightning's default TensorBoard logger (which writes under
# ./lightning_logs) is the active logger — confirm.
%load_ext tensorboard
%tensorboard --logdir ./

In [None]:
# Evaluate the best checkpoint (lowest val_loss) on the held-out test split.
# ckpt_path='best' states explicitly what Lightning would otherwise choose
# implicitly, with a warning.
test_results = trainer.test(dataloaders=dataloader_test, ckpt_path='best')
print(f'Accuracy: {test_results[0]["accuracy"]:.2f}')

In [None]:
best_model_path = checkpoint.best_model_path
# ModelCheckpoint leaves best_model_path empty when no checkpoint was written
# (e.g. validation never ran). Fail clearly instead of with an obscure loader error.
if not best_model_path:
    raise RuntimeError(
        "ModelCheckpoint has no saved checkpoint; run trainer.fit with validation first."
    )
model = BertForSequenceClassification_pl.load_from_checkpoint(
    best_model_path
)

# Save the wrapped Transformers-format model to ./model_transformers so it can be
# reloaded with plain `from_pretrained`, without Lightning.
model.bert_sc.save_pretrained('./model_transformers')

In [None]:
# Reload the exported classifier with plain transformers; no Lightning wrapper needed.
export_dir = './model_transformers'
bert_sc = BertForSequenceClassification.from_pretrained(export_dir)