## Fine-tuning DistilBERT to predict whether a text is a valid English sentence

In [39]:
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PrinterCallback,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    set_seed,
)

In [40]:

class SentenceDataset(Dataset):
    """Map-style dataset of tokenized sentences with binary validity labels.

    Each example is tokenized lazily in ``__getitem__`` and padded/truncated to
    ``max_length``, so every item has the same sequence length.

    Args:
        texts: Sequence of raw strings.
        labels: Sequence of integer labels aligned positionally with ``texts``
            (in this notebook 1 = valid sentence, 0 = invalid/gibberish).
        processor: Tokenizer callable (e.g. a Hugging Face tokenizer) accepting
            ``truncation``, ``padding``, ``max_length`` and ``return_tensors``.
        max_length: Fixed sequence length after padding/truncation.

    Raises:
        ValueError: If ``texts`` and ``labels`` differ in length.
    """

    def __init__(self, texts, labels, processor, max_length=128):
        if len(texts) != len(labels):
            # Fail fast instead of an IndexError (or silently unused labels) later.
            raise ValueError(
                f"texts and labels must have the same length "
                f"(got {len(texts)} and {len(labels)})"
            )
        self.texts = texts
        self.labels = labels
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        enc = self.processor(
            self.texts[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a leading batch axis of size 1. Drop only that
        # axis: a bare squeeze() would also collapse a length-1 sequence axis.
        item = {k: v.squeeze(0) for k, v in enc.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item


def train_classifier(
    train_texts,
    train_labels,
    val_texts=None,
    val_labels=None,
    model_name: str = 'distilbert-base-uncased',
    output_dir: str = './sentence_classifier',
    epochs: int = 3,
    batch_size: int = 16,
    learning_rate: float = 2e-5,
    seed: int = 42,
    report_to: str = 'wandb',
):
    """
    Fine-tune a pretrained transformer for binary sentence-validity classification.

    Args:
        train_texts: Training sentences.
        train_labels: Integer labels aligned with ``train_texts`` (1 = valid, 0 = invalid).
        val_texts: Optional validation sentences.
        val_labels: Optional validation labels. The validation split is used
            (one evaluation after training) only when both are given.
        model_name: Hugging Face model id or local path.
        output_dir: Directory for checkpoints; logs go to ``{output_dir}/logs``.
        epochs: Number of training epochs.
        batch_size: Per-device batch size for both training and evaluation.
        learning_rate: Optimizer learning rate.
        seed: Random seed. It is applied before the classification head is
            initialised and is also passed to ``TrainingArguments``, so runs are
            reproducible.
        report_to: Logging integration for ``TrainingArguments.report_to``.
            Pass ``'none'`` to disable external experiment tracking.

    Returns:
        The trained ``Trainer`` instance.
    """
    # Trainer calls set_seed only inside its own __init__, which runs after
    # from_pretrained has already randomly initialised the new classifier head.
    # Seed here so the head initialisation is reproducible too.
    set_seed(seed)

    processor = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=2
    )

    train_dataset = SentenceDataset(train_texts, train_labels, processor)
    eval_dataset = None
    if val_texts is not None and val_labels is not None:
        eval_dataset = SentenceDataset(val_texts, val_labels, processor)

    # No eval_strategy is configured: evaluation runs once, explicitly, after
    # training (see below). A bare eval_steps setting would have no effect.
    args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        logging_steps=100,
        save_steps=500,
        logging_dir=f'{output_dir}/logs',
        seed=seed,
        report_to=report_to,
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        # `processing_class` replaces the deprecated `tokenizer` argument; it is
        # also what predict_validity reads back via trainer.processing_class.
        processing_class=processor,
        callbacks=[PrinterCallback],
    )

    trainer.train()
    # Single evaluation pass on the validation split, if one was provided.
    if eval_dataset is not None:
        trainer.evaluate(eval_dataset)
    return trainer


def predict_validity(texts, trainer, max_length: int = 128):
    """
    Predict validity (0 = invalid, 1 = valid) for each text using a trained Trainer.

    Args:
        texts: List of strings. A single string is accepted and treated as a
            one-element list.
        trainer: Trained ``Trainer``. Its ``processing_class`` must be the
            tokenizer used during training.
        max_length: Padding/truncation length. It should match the length used
            at training time (``SentenceDataset`` defaults to 128).

    Returns:
        A list of ints with one prediction per input text, or ``[]`` for empty input.
    """
    if isinstance(texts, str):
        # A bare string would otherwise be iterated character by character.
        texts = [texts]
    if len(texts) == 0:
        # Nothing to classify; don't call Trainer.predict on an empty dataset.
        return []

    class _PredictionDataset(Dataset):
        """Unlabelled dataset that tokenizes each text when it is accessed."""

        def __init__(self, texts, processor, max_length):
            self.texts = texts
            self.processor = processor
            self.max_length = max_length

        def __len__(self):
            return len(self.texts)

        def __getitem__(self, idx):
            enc = self.processor(
                self.texts[idx],
                truncation=True,
                padding='max_length',
                max_length=self.max_length,
                return_tensors='pt'
            )
            # Drop only the batch axis added by return_tensors='pt'.
            return {k: v.squeeze(0) for k, v in enc.items()}

    pred_dataset = _PredictionDataset(texts, trainer.processing_class, max_length)
    outputs = trainer.predict(pred_dataset)
    # predictions are presumably the logits, shape (n_texts, num_labels);
    # the class index with the highest logit is the predicted label.
    preds = np.argmax(outputs.predictions, axis=1)
    return preds.tolist()



In [41]:
# Positive class (label 1): short English utterances, from full sentences
# down to single words and fragments.
valid_texts = [
    "I love reading books.", "Hello!",
    "The quick brown fox jumps over the lazy dog.", "How are you?",
    "Good morning.", "Absolutely!",
    "Sure.", "Thanks a lot.",
    "Yes", "Nice work",
    "Okay", "This is interesting.",
    "What time is it?", "Running fast is fun.",
    "He plays the piano.", "good",
    "Wonderful day", "Lunch at noon",
    "Be careful!", "Do it now",
]

# Negative class (label 0): keyboard mashing, digits, punctuation runs and emoji.
invalid_texts = [
    "asdfgh", "qwertyuiop",
    "loremipsum", "12345",
    "!!!???", "___--",
    "afg78gdf", "hjkl hjkl",
    "blahblahblah", "zxcvb asdfg",
    ".....", ",,,,,",
    "@#$%^&*()", "rrrrrrrrrrrrrrr",
    "wtrhysd", "👾👾👾",
    "dsklfjsd", "yt!op",
    "xxxxxxxxxx", "999999",
]

# Both classes concatenated; labels line up positionally with train_texts.
train_texts = [*valid_texts, *invalid_texts]
train_labels = [1 for _ in valid_texts] + [0 for _ in invalid_texts]


In [42]:
# Tiny held-out pair: one real sentence and one keyboard-mash string.
val_texts = ["Hello world!", "ghjk lkjh hjkl"]
val_labels = [1, 0]

trainer = train_classifier(
    train_texts=train_texts,
    train_labels=train_labels,
    val_texts=val_texts,
    val_labels=val_labels,
    model_name='distilbert-base-uncased',
    epochs=20,
)

# Spot-check the fine-tuned model on two unseen inputs.
test_sentences = ["This is a test.", "qwerty 12345"]
preds = predict_validity(test_sentences, trainer)
for sent, p in zip(test_sentences, preds):
    verdict = 'Valid' if p == 1 else 'Invalid'
    print(f'"{sent}": {verdict}')

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Step,Training Loss


{'train_runtime': 11.2985, 'train_samples_per_second': 70.806, 'train_steps_per_second': 5.31, 'train_loss': 0.26260852813720703, 'epoch': 20.0}


{'eval_loss': 0.04613985866308212, 'eval_runtime': 0.0148, 'eval_samples_per_second': 134.781, 'eval_steps_per_second': 67.39, 'epoch': 20.0}
"This is a test.": Valid
"qwerty 12345": Invalid


In [52]:
# Probe: a single keyboard-mash token that is not in the training data.
predict_validity(texts=["fasfsaffasfasf"], trainer=trainer)

[0]

In [53]:
# Probe: a single real English word with no surrounding sentence.
predict_validity(texts=["poison"], trainer=trainer)

[1]