In [None]:
import sys
import time
sys.path.insert(0, './..')

import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, BertForSequenceClassification
import pytorch_lightning as pl

from utils import count_model_parameters, load_dataset, plot_loss, test_model

In [None]:
# Read the three CSV splits with the project helper, in train / validation / test order.
train_dataset, val_dataset, test_dataset = map(
    load_dataset,
    ("../dataset/train.csv", "../dataset/validation.csv", "../dataset/test.csv"),
)

In [None]:
# Pick the compute device by priority: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# Every query is padded / truncated to this many tokens in create_loader_dataset.
max_length = 128
# Tokenizer loaded from the "bert-base-uncased" pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def create_loader_dataset(dataset, test=False):
    """Tokenize each record of ``dataset`` into a dict of tensors placed on ``device``.

    Args:
        dataset: Sized iterable of records. Element 0 of a record is passed to the
            tokenizer as the query; element 1 is converted with ``int`` and used
            as the label.
        test: When True, the untokenized query is additionally stored under the
            ``'raw_queries'`` key of every sample.

    Returns:
        A list with one dict per record. Each dict holds the tokenizer outputs and
        a ``'labels'`` entry as tensors on the module-level ``device`` (plus the
        raw query when ``test`` is True).
    """
    texts = [record[0] for record in dataset]
    targets = [int(record[1]) for record in dataset]

    samples = []
    for text, target in zip(texts, targets):
        # Fixed-length encoding so samples can be stacked into batches later.
        encoding = tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True
        )
        sample = {
            name: torch.tensor(value).to(device)
            for name, value in encoding.items()
        }
        sample['labels'] = torch.tensor(target).to(device)
        if test:
            # Stored as-is (not a tensor).
            sample['raw_queries'] = text
        samples.append(sample)

    return samples

In [None]:
# Tokenize each split once up front; only the test split keeps the raw query text.
loader_train_dataset = create_loader_dataset(train_dataset)
loader_val_dataset = create_loader_dataset(val_dataset)
loader_test_dataset = create_loader_dataset(test_dataset, test=True)

# Training batches are reshuffled every epoch; validation and test keep dataset order.
train_loader = DataLoader(loader_train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(loader_val_dataset, batch_size=256)
# One sample per batch for the test split.
test_loader = DataLoader(loader_test_dataset, batch_size=1)

In [None]:
class BertClassifier(pl.LightningModule):
    """LightningModule that fine-tunes ``BertForSequenceClassification``.

    Besides logging ``train_loss`` / ``val_loss`` through Lightning, the module keeps
    its own loss history for plotting: ``report_count`` holds the number of training
    samples seen at each recording and ``loss_values`` the loss of the batch that
    triggered it.

    Args:
        model_name: Name or path handed to ``BertForSequenceClassification.from_pretrained``.
        num_labels: Number of output classes of the classification head.
        lr: Learning rate of the Adam optimizer.
    """

    def __init__(self, model_name, num_labels, lr):
        super().__init__()
        # Makes the constructor arguments available as self.hparams
        # (self.hparams.lr is read in configure_optimizers).
        self.save_hyperparameters()
        self.bert_sc = BertForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_labels
        )
        self.count = 0          # running number of training samples seen
        self.report_freq = 90   # record the loss roughly every this many samples
        self.loss_values = []   # recorded training losses
        self.report_count = []  # value of self.count at each recording

    def training_step(self, batch, batch_idx=None):
        """Run one training batch, record the loss periodically, and return it.

        ``batch_idx`` is accepted to match Lightning's documented hook signature
        and is not used.
        """
        output = self.bert_sc(**batch)
        loss = output.loss

        seen_before = self.count
        self.count += len(batch['input_ids'])
        # The count advances by the batch size, which need not divide report_freq,
        # so an exact `count % report_freq == 0` test would almost never fire.
        # Record whenever the count crosses a multiple of report_freq instead.
        if seen_before // self.report_freq != self.count // self.report_freq:
            self.loss_values.append(loss.item())
            self.report_count.append(self.count)

        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx=None):
        """Compute and log the loss of one validation batch as ``val_loss``.

        ``batch_idx`` is accepted to match Lightning's documented hook signature
        and is not used.
        """
        output = self.bert_sc(**batch)
        val_loss = output.loss
        self.log('val_loss', val_loss)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using the configured ``lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [None]:
# Keep only the single checkpoint (weights only) with the lowest logged 'val_loss'.
checkpoint = pl.callbacks.ModelCheckpoint(
    dirpath='model/',
    monitor='val_loss',
    mode='min',
    save_top_k=1,
    save_weights_only=True,
)
# Train for at most three epochs with the checkpoint callback attached.
trainer = pl.Trainer(callbacks=[checkpoint], max_epochs=3)

In [None]:
model = BertClassifier('bert-base-uncased', num_labels=2, lr=1e-5)

# Measure training duration with a monotonic clock: unlike time.time(), it cannot
# jump if the system clock is adjusted while training runs.
start_time = time.perf_counter()
trainer.fit(model, train_loader, val_loader)
end_time = time.perf_counter()

# Append the elapsed seconds to the shared results file.
# NOTE(review): 'reslut.txt' looks like a typo of 'result.txt'; the name is kept
# because other code may append to or read the same file -- confirm before renaming.
with open('reslut.txt', 'a') as f:
    f.write(f'training time: {end_time - start_time}\n')
    f.write('\n')

# Plot the loss values recorded by the module against the sample counts at which
# they were recorded.
plot_loss(model.report_count, model.loss_values)

In [None]:
# Path of the best checkpoint tracked by the ModelCheckpoint callback.
best_model_path = checkpoint.best_model_path
# Rebuild the Lightning module from that checkpoint ...
model = BertClassifier.load_from_checkpoint(checkpoint_path=best_model_path)
# ... and export only the wrapped Transformers model in its own on-disk format.
model.bert_sc.save_pretrained(save_directory='./model_transformers')

In [None]:
# Reload the exported weights as a plain Transformers model, then move it to `device`.
model = BertForSequenceClassification.from_pretrained('./model_transformers')
model = model.to(device)

In [None]:
# Report the parameter count of the reloaded model via the project helper in utils.
# NOTE(review): presumably prints or returns the count -- confirm against utils.
count_model_parameters(model)

In [None]:
# Evaluate the reloaded model on the test split via the project helper in utils.
# test_loader yields one sample per batch, each including its 'raw_queries' text.
test_model(model, test_loader)