In [None]:
import random
import sys
sys.path.insert(0, './..')

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BertForSequenceClassification
import pytorch_lightning as pl

from utils import load_dataset, test_model

In [None]:
# Location of the labelled samples, relative to this notebook.
DATASET_PATH = "../dataset/sqli1.csv"

# `load_dataset` is the project helper imported from `utils`; the rows it
# returns are indexed as row[0] (query) and row[1] (label) further down.
dataset = load_dataset(DATASET_PATH)
dataset_size = len(dataset)

In [None]:
# Split every row into its query (column 0) and its integer label (column 1)
# in a single pass over the dataset.
queries = []
labels = []
for row in dataset:
    queries.append(row[0])
    labels.append(int(row[1]))

In [None]:
# Pick the compute device by priority: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

max_length = 128  # every query is padded / truncated to this many tokens
split_seed = 42   # fixed seed so the train/val/test partition is reproducible

# Encode each query into fixed-length token tensors and attach its label under
# the 'labels' key, producing one dict of tensors per example.
loader_dataset = []
for query, label in zip(queries, labels):
    encoded_tokens = tokenizer(
        query,
        max_length=max_length,
        padding='max_length',
        truncation=True
    )
    encoded_tokens['labels'] = label
    loader_dataset.append(
        {k: torch.tensor(v) for k, v in encoded_tokens.items()}
    )

# Shuffle with a dedicated, seeded RNG. An unseeded shuffle yields a different
# split on every kernel restart, so a model exported by an earlier run could be
# scored on a "test" set containing samples it was trained on.
random.Random(split_seed).shuffle(loader_dataset)

# 60% train / 20% validation / remainder (about 20%) test.
n = len(loader_dataset)
n_train = int(0.6*n)
n_val = int(0.2*n)

train_dataset = loader_dataset[:n_train]
val_dataset = loader_dataset[n_train:n_train+n_val]
test_dataset = loader_dataset[n_train+n_val:]

# Only the training loader reshuffles each epoch; evaluation loaders keep a
# fixed order and use a larger batch size.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256)
test_loader = DataLoader(test_dataset, batch_size=256)

In [None]:
class BertClassifier(pl.LightningModule):
    """Lightning wrapper around a ``BertForSequenceClassification`` model.

    Args:
        model_name: Name or path passed to ``from_pretrained``.
        num_labels: Number of output classes for the classification head.
        lr: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, model_name, num_labels, lr):
        super().__init__()
        # Makes the constructor arguments available through ``self.hparams``
        # (``configure_optimizers`` reads ``self.hparams.lr``).
        self.save_hyperparameters()
        self.bert_sc = BertForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )

    def _batch_loss(self, batch):
        """Feed ``batch`` to the wrapped model as keyword arguments and return its loss."""
        return self.bert_sc(**batch).loss

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch, log it as ``train_loss`` and return it."""
        train_loss = self._batch_loss(batch)
        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss for one validation batch and log it as ``val_loss``."""
        self.log('val_loss', self._batch_loss(batch))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using the stored learning rate."""
        learning_rate = self.hparams.lr
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

In [None]:
# Checkpoint callback: writes to model/, watches the logged 'val_loss' and keeps
# only the single checkpoint with the lowest value (weights only).
checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath='model/',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
)

# Train for at most 10 epochs with the checkpoint callback attached.
trainer = pl.Trainer(callbacks=[checkpoint], max_epochs=10)

In [None]:
# Build a two-class classifier on top of bert-base-uncased (learning rate 1e-5)
# and train it on train_loader, validating on val_loader.
model = BertClassifier("bert-base-uncased", num_labels=2, lr=1e-5)
trainer.fit(model, train_loader, val_loader)

In [None]:
# Load the TensorBoard notebook extension and open it with the current
# directory as the log directory.
%load_ext tensorboard
%tensorboard --logdir ./

In [None]:
# Reload the Lightning module from the best checkpoint recorded by the
# ModelCheckpoint callback.
best_model_path = checkpoint.best_model_path
model = BertClassifier.load_from_checkpoint(best_model_path)

# Save the wrapped model in Transformers-compatible format to ./model_transformers
model.bert_sc.save_pretrained('./model_transformers')

In [None]:
# Load the exported weights back as a plain BertForSequenceClassification
# (this rebinds `model`, replacing the Lightning wrapper).
model = BertForSequenceClassification.from_pretrained('./model_transformers')

In [None]:
# Evaluate the reloaded model on the test split with the project's
# `test_model` helper from `utils` (what it reports is defined there, not here).
test_model(model, test_loader, device)

In [None]:
# Count the model's parameters by accumulating the element count of each
# parameter tensor.
num_params = 0
for param in model.parameters():
    num_params += param.numel()
print(f'parameters: {num_params}')