Merge pull request #29 from mim-solutions/add-multiclass-support
Add multiclass support
MichalBrzozowski91 authored Jun 19, 2024
2 parents 866de5c + 85c3afd commit 77e9a75
Showing 31 changed files with 2,283 additions and 2,774 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ __pycache__/
.idea

venv
.venv

belt_nlp.egg-info
dist
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,11 @@
# Version History

## 1.1.0 Jun 19, 2024

* Added multiclass classification.
* Added regression.
* Changed from aggregating probabilities to aggregating logits/scores.

## 1.0.1 Feb 26, 2024

* Add gradient accumulation by @jstremme in https://github.com/mim-solutions/bert_for_longer_texts/pull/23.
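The aggregation change in this CHANGELOG entry is worth a closer look: because softmax is non-linear, applying softmax to averaged chunk logits generally gives a different distribution than averaging per-chunk probabilities. A minimal sketch of the difference (not part of this diff; the tensor values are made up):

```python
import torch

# Two chunks of one document, three classes; values are illustrative only.
chunk_logits = torch.tensor([[4.0, 0.0, 0.0],
                             [0.0, 2.0, 0.0]])

# New behaviour (conceptually): aggregate logits first, then apply softmax.
probs_from_mean_logits = torch.softmax(chunk_logits.mean(dim=0), dim=0)

# Old behaviour (conceptually): apply softmax per chunk, then average probabilities.
mean_of_chunk_probs = torch.softmax(chunk_logits, dim=1).mean(dim=0)

# The two distributions differ, so the predicted class can differ as well.
print(probs_from_mean_logits)
print(mean_of_chunk_probs)
```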
19 changes: 17 additions & 2 deletions README.md
@@ -1,5 +1,7 @@
# **BELT** (**BE**RT For **L**onger **T**exts)

🚀**New in version 1.1.0: support for multiclass classification and regression**. See [the examples](#examples)🚀

## Project description and motivation

### The BELT approach
@@ -59,8 +61,21 @@ It can be either:
To make sure everything works properly, run the command ```pytest tests -rA```. By default, during tests, models are trained on small samples on the CPU.

## Examples
- [fit and predict method for base model](https://github.com/mim-solutions/bert_for_longer_texts/blob/main/notebooks/example_base_model_fit_predict.ipynb)
- [fit and predict method for model with pooling](https://github.com/mim-solutions/bert_for_longer_texts/blob/main/notebooks/example_model_with_pooling_fit_predict.ipynb)

All examples use public datasets from the Hugging Face Hub.

### Binary classification - sentiment prediction on IMDB reviews
- [standard approach](https://github.com/mim-solutions/bert_for_longer_texts/blob/main/notebooks/binary_classification/base.ipynb)
- [belt](https://github.com/mim-solutions/bert_for_longer_texts/blob/main/notebooks/binary_classification/belt.ipynb)

### Multiclass classification - recognizing the authors of Guardian articles
- [standard approach](https://github.com/mim-solutions/bert_for_longer_texts/blob/main/notebooks/multiclass/base.ipynb)
- [belt](https://github.com/mim-solutions/bert_for_longer_texts/blob/main/notebooks/multiclass/belt.ipynb)
- **Notice the effectiveness of the BELT approach here: the test accuracy increased by 10%.**

### Regression - predicting a 1-to-5 rating from reviews on the Polish e-commerce platform Allegro
- [standard approach](https://github.com/mim-solutions/bert_for_longer_texts/blob/main/notebooks/regression/base.ipynb)
- [belt](https://github.com/mim-solutions/bert_for_longer_texts/blob/main/notebooks/regression/belt.ipynb)

## Contributors
The project was created at [MIM AI](https://www.mim.ai/) by:
67 changes: 31 additions & 36 deletions belt_nlp/bert.py
@@ -1,18 +1,19 @@
from __future__ import annotations

from abc import ABC, abstractmethod
import json
from pathlib import Path
from typing import Any, Optional, Union

import torch
from torch import Tensor
from torch.nn import BCELoss, DataParallel, Module, Linear, Sigmoid
from torch.nn import CrossEntropyLoss, DataParallel, Linear, Module, MSELoss
from torch.optim import AdamW, Optimizer
from torch.utils.data import Dataset, RandomSampler, SequentialSampler, DataLoader
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from transformers import AutoModel, AutoTokenizer, BatchEncoding, BertModel, PreTrainedTokenizerBase, RobertaModel


class BertClassifier(ABC):
class BertBase(ABC):
"""
The "device" parameter can have the following values:
- "cpu" - The model will be loaded on CPU.
@@ -30,6 +31,7 @@ class BertClassifier(ABC):
@abstractmethod
def __init__(
self,
num_labels: int,
batch_size: int,
learning_rate: float,
epochs: int,
@@ -40,11 +42,14 @@ def __init__(
device: str = "cuda:0",
many_gpus: bool = False,
):
assert num_labels >= 1, "The num_labels parameter must be at least 1."
self.num_labels = num_labels

if not tokenizer:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
if not neural_network:
bert = AutoModel.from_pretrained(pretrained_model_name_or_path)
neural_network = BertClassifierNN(bert)
neural_network = BertNN(model=bert, num_labels=num_labels)

self.batch_size = batch_size
self.learning_rate = learning_rate
@@ -54,7 +59,7 @@ def __init__(
"batch_size": self.batch_size,
"learning_rate": self.learning_rate,
"epochs": self.epochs,
"accumulation_steps": self.accumulation_steps
"accumulation_steps": self.accumulation_steps,
}
self.device = device
self.many_gpus = many_gpus
@@ -66,7 +71,7 @@ def __init__(
if device.startswith("cuda") and many_gpus:
self.neural_network = DataParallel(self.neural_network)

def fit(self, x_train: list[str], y_train: list[bool], epochs: Optional[int] = None) -> None:
def fit(self, x_train: list[str], y_train: Union[list[bool], list[float]], epochs: Optional[int] = None) -> None:
if not epochs:
epochs = self.epochs
optimizer = AdamW(self.neural_network.parameters(), lr=self.learning_rate)
@@ -79,52 +84,44 @@ def fit(self, x_train: list[str], y_train: list[bool], epochs: Optional[int] = N
for epoch in range(epochs):
self._train_single_epoch(dataloader, optimizer)

def predict(self, x: list[str], batch_size: Optional[int] = None) -> list[tuple[bool, float]]:
if not batch_size:
batch_size = self.batch_size
scores = self.predict_scores(x, batch_size)
classes = [i >= 0.5 for i in scores]
return list(zip(classes, scores))

def predict_classes(self, x: list[str], batch_size: Optional[int] = None) -> list[bool]:
if not batch_size:
batch_size = self.batch_size
scores = self.predict_scores(x, batch_size)
classes = [i >= 0.5 for i in scores]
return classes

def predict_scores(self, x: list[str], batch_size: Optional[int] = None) -> list[float]:
def _predict_logits(self, x: list[str], batch_size: Optional[int] = None) -> Tensor:
"""Returns classification (or regression if num_labels==1) scores (before SoftMax)."""
if not batch_size:
batch_size = self.batch_size
tokens = self._tokenize(x)
dataset = TokenizedDataset(tokens)
dataloader = DataLoader(
dataset, sampler=SequentialSampler(dataset), batch_size=batch_size, collate_fn=self.collate_fn
)
total_predictions = []
total_logits = []

# deactivate dropout layers
self.neural_network.eval()
for step, batch in enumerate(dataloader):
# deactivate autograd
with torch.no_grad():
predictions = self._evaluate_single_batch(batch)
total_predictions.extend(predictions.tolist())
return total_predictions
logits = self._evaluate_single_batch(batch)
total_logits.append(logits)
return torch.cat(total_logits)

@abstractmethod
def _tokenize(self, texts: list[str]) -> BatchEncoding:
pass

def _train_single_epoch(self, dataloader: DataLoader, optimizer: Optimizer) -> None:
self.neural_network.train()
cross_entropy = BCELoss()

for step, batch in enumerate(dataloader):

labels = batch[-1].float().cpu()
predictions = self._evaluate_single_batch(batch)
loss = cross_entropy(predictions, labels) / self.accumulation_steps
if self.num_labels > 1:
labels = batch[-1].long().to(self.device)
loss_function = CrossEntropyLoss()
logits = self._evaluate_single_batch(batch)
loss = loss_function(logits, labels) / self.accumulation_steps
elif self.num_labels == 1:
labels = batch[-1].float().to(self.device)
loss_function = MSELoss()
scores = torch.flatten(self._evaluate_single_batch(batch))
loss = loss_function(scores, labels) / self.accumulation_steps
loss.backward()

if ((step + 1) % self.accumulation_steps == 0) or (step + 1 == len(dataloader)):
@@ -147,7 +144,7 @@ def save(self, model_dir: str) -> None:
torch.save(self.neural_network, model_dir / "model.bin")

@classmethod
def load(cls, model_dir: str, device: str = "cuda:0", many_gpus: bool = False) -> BertClassifier:
def load(cls, model_dir: str, device: str = "cuda:0", many_gpus: bool = False) -> BertBase:
model_dir = Path(model_dir)
with open(file=model_dir / "params.json", mode="r", encoding="utf-8") as file:
params = json.load(file)
@@ -163,22 +160,20 @@ def load(cls, model_dir: str, device: str = "cuda:0", many_gpus: bool = False) -
)


class BertClassifierNN(Module):
def __init__(self, model: Union[BertModel, RobertaModel]):
class BertNN(Module):
def __init__(self, model: Union[BertModel, RobertaModel], num_labels: int):
super().__init__()
self.model = model

# classification head
self.linear = Linear(768, 1)
self.sigmoid = Sigmoid()
self.linear = Linear(768, num_labels)

def forward(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
x = self.model(input_ids, attention_mask)
x = x[0][:, 0, :] # take <s> token (equiv. to [CLS])

# classification head
x = self.linear(x)
x = self.sigmoid(x)
return x


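The key change in `_train_single_epoch` is that the loss function is now selected by `num_labels`: `CrossEntropyLoss` on raw logits for classification, `MSELoss` on flattened scores for regression. A standalone sketch of that dispatch on dummy tensors (shapes follow the diff; the data is made up):

```python
import torch
from torch.nn import CrossEntropyLoss, MSELoss

batch_size, num_labels = 4, 3

# num_labels > 1: logits of shape (batch, num_labels) against integer class ids.
logits = torch.randn(batch_size, num_labels)
class_ids = torch.randint(0, num_labels, (batch_size,))
classification_loss = CrossEntropyLoss()(logits, class_ids)

# num_labels == 1: scores flattened to shape (batch,) against float targets.
scores = torch.flatten(torch.randn(batch_size, 1))
targets = torch.rand(batch_size) * 4 + 1  # e.g. ratings in [1, 5]
regression_loss = MSELoss()(scores, targets)

print(classification_loss.item(), regression_loss.item())
```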
55 changes: 55 additions & 0 deletions belt_nlp/bert_classifier_truncated.py
@@ -0,0 +1,55 @@
from __future__ import annotations

from typing import Optional

from torch import argmax, Tensor
from torch.nn import Module, Softmax
from transformers import PreTrainedTokenizerBase

from belt_nlp.bert_truncated import BertBaseTruncated


class BertClassifierTruncated(BertBaseTruncated):
def __init__(
self,
num_labels: int,
batch_size: int,
learning_rate: float,
epochs: int,
accumulation_steps: int = 1,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
neural_network: Optional[Module] = None,
pretrained_model_name_or_path: Optional[str] = "bert-base-uncased",
device: str = "cuda:0",
many_gpus: bool = False,
):
super().__init__(
num_labels=num_labels,
batch_size=batch_size,
learning_rate=learning_rate,
epochs=epochs,
accumulation_steps=accumulation_steps,
tokenizer=tokenizer,
neural_network=neural_network,
pretrained_model_name_or_path=pretrained_model_name_or_path,
device=device,
many_gpus=many_gpus,
)
additional_classifier_params = {
"num_labels": self.num_labels,
}
self._params.update(additional_classifier_params)

def predict(self, x: list[str], batch_size: Optional[int] = None) -> Tensor:
"""Returns classes."""
logits = super()._predict_logits(x, batch_size)
classes = argmax(logits, dim=1)
return classes

def predict_scores(self, x: list[str], batch_size: Optional[int] = None) -> Tensor:
"""Returns classification probabilities."""
logits = super()._predict_logits(x, batch_size)
softmax = Softmax(dim=1)

probabilities = softmax(logits)
return probabilities
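A hypothetical usage sketch for the new `BertClassifierTruncated` (texts, labels, and hyperparameter values below are made up; only the class, constructor parameters, and method names come from this diff):

```python
from belt_nlp.bert_classifier_truncated import BertClassifierTruncated

texts = ["first example document", "second example document"]  # made-up data
labels = [0, 2]  # integer class ids for a 3-class problem

model = BertClassifierTruncated(
    num_labels=3,
    batch_size=2,
    learning_rate=5e-5,
    epochs=1,
    device="cpu",
)
model.fit(texts, labels)

predicted_classes = model.predict(texts)           # tensor of class ids (argmax over logits)
class_probabilities = model.predict_scores(texts)  # tensor of shape (n_texts, num_labels)
```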
65 changes: 65 additions & 0 deletions belt_nlp/bert_classifier_with_pooling.py
@@ -0,0 +1,65 @@
from __future__ import annotations

from typing import Optional

from torch import argmax, Tensor
from torch.nn import Module, Softmax
from transformers import PreTrainedTokenizerBase

from belt_nlp.bert_with_pooling import BertBaseWithPooling


class BertClassifierWithPooling(BertBaseWithPooling):
def __init__(
self,
num_labels: int,
batch_size: int,
learning_rate: float,
epochs: int,
chunk_size: int,
stride: int,
minimal_chunk_length: int,
pooling_strategy: str = "mean",
accumulation_steps: int = 1,
maximal_text_length: Optional[int] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
neural_network: Optional[Module] = None,
pretrained_model_name_or_path: Optional[str] = "bert-base-uncased",
device: str = "cuda:0",
many_gpus: bool = False,
):
super().__init__(
num_labels=num_labels,
batch_size=batch_size,
learning_rate=learning_rate,
epochs=epochs,
chunk_size=chunk_size,
stride=stride,
minimal_chunk_length=minimal_chunk_length,
pooling_strategy=pooling_strategy,
accumulation_steps=accumulation_steps,
maximal_text_length=maximal_text_length,
tokenizer=tokenizer,
neural_network=neural_network,
pretrained_model_name_or_path=pretrained_model_name_or_path,
device=device,
many_gpus=many_gpus,
)
additional_classifier_params = {
"num_labels": self.num_labels,
}
self._params.update(additional_classifier_params)

def predict(self, x: list[str], batch_size: Optional[int] = None) -> Tensor:
"""Returns classes."""
logits = super()._predict_logits(x, batch_size)
classes = argmax(logits, dim=1)
return classes

def predict_scores(self, x: list[str], batch_size: Optional[int] = None) -> Tensor:
"""Returns classification probabilities."""
logits = super()._predict_logits(x, batch_size)
softmax = Softmax(dim=1)

probabilities = softmax(logits)
return probabilities
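The BELT variant exposes the same classification interface plus the chunking parameters. A hypothetical sketch (data and chunking values are illustrative, not prescribed by this diff):

```python
from belt_nlp.bert_classifier_with_pooling import BertClassifierWithPooling

long_texts = ["a very long article ...", "another long article ..."]  # made-up data
labels = [0, 1]

model = BertClassifierWithPooling(
    num_labels=3,
    batch_size=2,
    learning_rate=5e-5,
    epochs=1,
    chunk_size=510,
    stride=510,
    minimal_chunk_length=510,
    pooling_strategy="mean",
    device="cpu",
)
model.fit(long_texts, labels)
probabilities = model.predict_scores(long_texts)  # per-class probabilities after chunk pooling
```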
41 changes: 41 additions & 0 deletions belt_nlp/bert_regressor_truncated.py
@@ -0,0 +1,41 @@
from __future__ import annotations

from typing import Optional

from torch import Tensor
from torch.nn import Module
from transformers import PreTrainedTokenizerBase

from belt_nlp.bert_truncated import BertBaseTruncated


class BertRegressorTruncated(BertBaseTruncated):
def __init__(
self,
batch_size: int,
learning_rate: float,
epochs: int,
accumulation_steps: int = 1,
tokenizer: Optional[PreTrainedTokenizerBase] = None,
neural_network: Optional[Module] = None,
pretrained_model_name_or_path: Optional[str] = "bert-base-uncased",
device: str = "cuda:0",
many_gpus: bool = False,
):
super().__init__(
num_labels=1,
batch_size=batch_size,
learning_rate=learning_rate,
epochs=epochs,
accumulation_steps=accumulation_steps,
tokenizer=tokenizer,
neural_network=neural_network,
pretrained_model_name_or_path=pretrained_model_name_or_path,
device=device,
many_gpus=many_gpus,
)

def predict(self, x: list[str], batch_size: Optional[int] = None) -> Tensor:
"""Returns regression scores."""
logits = super()._predict_logits(x, batch_size)
return logits
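A hypothetical usage sketch for the new regressor (data and hyperparameters are made up; the class and method names come from this diff):

```python
from belt_nlp.bert_regressor_truncated import BertRegressorTruncated

reviews = ["very disappointing product", "works exactly as advertised"]  # made-up data
ratings = [1.0, 5.0]

model = BertRegressorTruncated(
    batch_size=2,
    learning_rate=5e-5,
    epochs=1,
    device="cpu",
)
model.fit(reviews, ratings)
predicted_ratings = model.predict(reviews)  # tensor of shape (n_texts, 1) with raw scores
```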