In [1]:
import os
import numpy as np
import torch
import torch.nn as nn

from transformers import BertTokenizer, BertForSequenceClassification
import pytorch as pytorch_utils

In [2]:
# Tokenizer loaded from the same pretrained Polish BERT checkpoint id used for the model.
tokenizer_checkpoint = "dkleczek/bert-base-polish-cased-v1"
tokenizer = BertTokenizer.from_pretrained(tokenizer_checkpoint)

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Text-classification dataset backed by two line-aligned UTF-8 files.

    Line ``i`` of the text file is one example and line ``i`` of the label
    file is its integer class id. Both files are read fully into memory when
    the dataset is constructed.

    Args:
        text_path: File with one example text per line. Defaults to
            ``data/training_set_clean_only_text.txt``.
        label_path: File with one integer label per line. Defaults to
            ``data/training_set_clean_only_tags.txt``.

    Raises:
        ValueError: If a label line cannot be parsed as an integer, or if the
            two files do not contain the same number of lines.
    """

    DEFAULT_TEXT_PATH = os.path.join("data", "training_set_clean_only_text.txt")
    DEFAULT_LABEL_PATH = os.path.join("data", "training_set_clean_only_tags.txt")

    def __init__(self, text_path=None, label_path=None):
        if text_path is None:
            text_path = self.DEFAULT_TEXT_PATH
        if label_path is None:
            label_path = self.DEFAULT_LABEL_PATH

        with open(text_path, encoding="utf-8") as f:
            self.texts = [line.strip() for line in f]
        with open(label_path, encoding="utf-8") as f:
            self.labels = [int(line.strip()) for line in f]

        # Explicit check instead of `assert`, which is removed under `python -O`.
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"Text/label count mismatch: {len(self.texts)} texts in "
                f"{text_path!r} vs {len(self.labels)} labels in {label_path!r}"
            )

    def __getitem__(self, idx):
        """Return the ``(text, label)`` pair stored at position ``idx``."""
        return self.texts[idx], self.labels[idx]

    def __len__(self):
        """Return the number of examples."""
        return len(self.texts)
    

In [None]:
class DataLoader(torch.utils.data.DataLoader):
    """Batch loader for the text-classification dataset.

    Calling ``DataLoader()`` with no arguments builds a fresh ``Dataset`` and
    yields shuffled batches of 32 tokenized examples using one worker process.

    Args:
        dataset: Dataset of ``(text, label)`` pairs. When ``None`` a new
            ``Dataset()`` is created.
        batch_size: Number of examples per batch.
        num_workers: Number of loader worker processes.
        shuffle: Whether to reshuffle the examples every epoch.
        drop_last: Whether to drop a final batch smaller than ``batch_size``.
    """

    def __init__(self, dataset=None, batch_size=32, num_workers=1,
                 shuffle=True, drop_last=False):
        if dataset is None:
            dataset = Dataset()
        print(len(dataset))
        # `_collate_fn` and `tokenizer` are module-level names looked up when
        # the loader is constructed, not when this class is defined.
        # NOTE(review): the collate function is a nested closure; with
        # num_workers > 0 on platforms that spawn workers it may fail to
        # pickle. Pass num_workers=0 there.
        super().__init__(
            dataset=dataset,
            collate_fn=_collate_fn(tokenizer),
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            drop_last=drop_last,
        )

In [None]:
def _collate_fn(tokenizer, max_length=32):
    """Build a ``collate_fn`` that tokenizes a batch of ``(text, label)`` pairs.

    Args:
        tokenizer: Callable tokenizer accepting a list of strings together
            with the ``padding`` / ``truncation`` / ``return_tensors``
            keyword arguments used below, and returning a mapping with
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        max_length: Every sequence is padded or truncated to this many tokens.

    Returns:
        A function mapping a list of ``(text, label)`` pairs to a dict with
        keys ``"input_ids"``, ``"attention_masks"`` and ``"targets"``.
    """
    def _make_batch(datapoints) -> dict:
        texts = [text for text, _ in datapoints]
        labels = [label for _, label in datapoints]
        # One batched tokenizer call replaces a call per example followed by
        # concatenation; padding to max_length keeps every row the same width.
        encoding = tokenizer(
            texts,
            add_special_tokens=True,
            max_length=max_length,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"],
            "attention_masks": encoding["attention_mask"],
            # Class-index targets for CrossEntropyLoss must be int64. Going
            # through np.array yields int32 where NumPy's default integer is
            # 32-bit, so the dtype is pinned explicitly.
            "targets": torch.tensor(labels, dtype=torch.long),
        }
    return _make_batch

In [None]:
# Build the training loader; constructing it also builds the dataset and prints its size.
dataloader = DataLoader()

In [None]:
# Number of target classes for the sequence-classification head.
NUM_LABELS = 3

# Ask from_pretrained for a 3-class head instead of swapping `model.classifier`
# afterwards. This keeps `model.num_labels` and `model.config.num_labels` in
# sync with the head and avoids hard-coding the encoder's hidden size.
model = BertForSequenceClassification.from_pretrained(
    "dkleczek/bert-base-polish-cased-v1",
    num_labels=NUM_LABELS,
)

In [None]:
# Optimizer settings passed to the project-local `pytorch_utils` helper.
# NOTE(review): 1e-03 is a large initial learning rate when every BERT weight
# is being updated — confirm this is intended.
optimizer_kwargs = {
    "optimizer_name": "adam",
    "init_lr": 1e-03,
    "weight_decay": 0,
}
optimizer = pytorch_utils.create_optimizer(
    params=model.parameters(),
    **optimizer_kwargs,
)

# Learning-rate schedule settings.
# NOTE(review): presumably `milestones` are fractions of `num_iterations` at
# which the learning rate is multiplied by `gamma` — verify against the helper.
scheduler_kwargs = {
    "num_iterations": 30000,
    "gamma": 1e-1,
    "milestones": [0.4, 0.7, 0.9],
}
lr_scheduler = pytorch_utils.create_lr_scheduler(
    optimizer=optimizer,
    **scheduler_kwargs,
)

In [None]:
criterion = nn.CrossEntropyLoss()

# from_pretrained returns the model in evaluation mode (dropout disabled), so
# switch to training mode before optimizing.
model.train()

for batch in dataloader:
    # Clear gradients left over from the previous step before the new forward pass.
    optimizer.zero_grad()
    logits = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_masks"],
    ).logits
    loss = criterion(logits, batch["targets"])
    loss.backward()
    optimizer.step()
    # The scheduler is advanced once per batch, not once per epoch.
    lr_scheduler.step()