In [1]:
import evaluate
import numpy as np
import torch
# NOTE(review): load_metric is unused (metrics come from `evaluate`) and was removed in
# datasets 3.0 — drop it from this line if the import fails on a newer `datasets`.
from datasets import load_dataset, load_metric
from transformers import (
    BertForSequenceClassification,
    BertTokenizerFast,
    Trainer,
    TrainingArguments,
    set_seed,
)




In [None]:
SEED = 42

# Take a random 80% of the IMDB training split, then hold out 20% of that subset
# for evaluation.
# NOTE(review): the raw IMDB "train" split appears to be grouped by label (all
# negatives, then all positives), so a plain "train[:80%]" slice would be
# label-skewed — verify. Shuffling with a fixed seed first yields a random subset
# either way, and seeding train_test_split makes the held-out split reproducible.
imdb_train = load_dataset("imdb", split="train")
subset_size = int(0.8 * len(imdb_train))
dataset = (
    imdb_train
    .shuffle(seed=SEED)
    .select(range(subset_size))
    .train_test_split(test_size=0.2, seed=SEED)
)

# Load the pre-trained tokenizer matching the "bert-base-uncased" model checkpoint.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

def tokenize(batch):
    """Tokenize a batch of raw reviews for BERT.

    Args:
        batch: a batch dict as passed by ``Dataset.map(..., batched=True)``; only
            ``batch["text"]`` (the list of review strings) is read.

    Returns:
        The tokenizer's encoding for the batch (``input_ids``, ``attention_mask``, ...).
    """
    # Pad every example to the same fixed length (the tokenizer's model_max_length),
    # truncating longer reviews. padding=True would only pad to the longest review
    # *within each map() chunk*, so examples from different chunks could have
    # different lengths; the Trainer cell passes no tokenizer/data_collator, so its
    # default collator does not re-pad and stacking such a batch can fail.
    return tokenizer(batch["text"], padding="max_length", truncation=True)

# Tokenize every split batch-wise, then make the datasets yield only the model
# inputs and the label column, as torch tensors.
torch_columns = ['input_ids', 'attention_mask', 'label']
encoded_dataset = dataset.map(tokenize, batched=True)
encoded_dataset.set_format(type='torch', columns=torch_columns)

In [None]:
# Seed Python/NumPy/PyTorch *before* building the model: the 2-way classification
# head on top of BERT is freshly initialised here, while Trainer only seeds at or
# after its own construction, which is too late to make this init reproducible.
# 42 matches TrainingArguments' default `seed`.
set_seed(42)
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Accuracy metric consumed by compute_metrics. evaluate.load replaces the removed
# datasets.load_metric; the module-qualified call avoids a bare, easily shadowed `load`.
metric = evaluate.load("accuracy")

def compute_metrics(pred):
    """Accuracy callback for ``Trainer(compute_metrics=...)``.

    Args:
        pred: a ``(logits, labels)`` pair; Trainer passes an ``EvalPrediction``,
            which unpacks the same way.

    Returns:
        Whatever ``metric.compute`` returns when comparing the arg-max predictions
        with the reference labels.
    """
    scores, gold_labels = pred
    # The highest-scoring class along the last axis is the predicted label.
    return metric.compute(
        predictions=np.argmax(scores, axis=-1),
        references=gold_labels,
    )

In [None]:
# Fine-tuning hyper-parameters.
NUM_EPOCHS = 2
BATCH_SIZE = 8  # per device, shared by training and evaluation

training_args = TrainingArguments(
    output_dir="./result",   # checkpoints; trainer.save_model() also writes here
    logging_dir="./logs",
    logging_steps=10,        # log every 10 update steps
    eval_strategy="epoch",   # evaluate on eval_dataset at the end of each epoch
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
)

In [None]:
# Fine-tune on the train split, evaluating on the held-out split as configured in
# training_args, then persist the result.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["test"],
    compute_metrics=compute_metrics,
)
trainer.train()

# Final evaluation; keep the metrics dict instead of discarding it.
eval_metrics = trainer.evaluate()

# save_model() with no argument writes to training_args.output_dir. No tokenizer is
# attached to this Trainer, so save it explicitly next to the weights; otherwise
# output_dir would hold no tokenizer files for a later from_pretrained() reload.
trainer.save_model()
tokenizer.save_pretrained(training_args.output_dir)

# Display the final evaluation metrics as this cell's output.
eval_metrics

In [None]:
# Directory holding the training checkpoints and the model saved by trainer.save_model().
training_args.output_dir


In [None]:
# Sanity check right after training: classify one hand-written sentence.
text = "I hate it."

# Choose the device once, then keep both the model and its inputs on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Tokenize into PyTorch tensors and move each tensor onto the chosen device.
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

# Forward pass in eval mode (dropout off) with gradient tracking disabled.
model.eval()
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    # Index of the highest-scoring class for the single input row.
    predicted_class = logits.argmax(dim=-1).item()

print("Predicted class:", predicted_class)