In [None]:
# Load the IMDB sentiment dataset (a DatasetDict keyed by split name).
from datasets import load_dataset
import torch

data_set = load_dataset(path='imdb')

In [None]:
from datasets import DatasetDict
from transformers import BertTokenizerFast

# Only the 'train' and 'test' splits are consumed by later cells, so any other
# split load_dataset returned (IMDB also ships an unlabelled 'unsupervised'
# split) is skipped instead of being tokenized for nothing.
labelled_splits = DatasetDict({split: data_set[split] for split in ('train', 'test')})

# The fast (Rust-backed) tokenizer produces the same encodings as BertTokenizer
# for this checkpoint but is much quicker inside a batched `map`.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')


def tokenizer_func(examples):
    """Tokenize one batch of reviews for `Dataset.map(batched=True)`.

    Args:
        examples: batch dict from `map`; its 'text' entry is a list of strings.

    Returns:
        The tokenizer's encoding for the batch, every sequence padded and
        truncated to the tokenizer's maximum length.
    """
    return tokenizer(examples['text'], padding='max_length', truncation=True)


tokenized_datasets = labelled_splits.map(tokenizer_func, batched=True)

In [None]:
from torch.utils.data import DataLoader

SAMPLE_SEED = 42   # seed for the shuffle that picks the subsample
N_SAMPLES = 1000   # examples kept per split to keep the run short

# Derive a new name instead of rebinding `tokenized_datasets`: rebinding made a
# second run of this cell fail, because 'text' was already removed and 'label'
# already renamed. `with_format` returns a copy rather than mutating in place.
model_ready_datasets = (
    tokenized_datasets
    .remove_columns(['text'])
    .rename_column('label', 'labels')  # forward() takes `labels` to compute `outputs.loss`
    .with_format('torch')
)
small_train_dataset = model_ready_datasets['train'].shuffle(seed=SAMPLE_SEED).select(range(N_SAMPLES))
small_eval_dataset = model_ready_datasets['test'].shuffle(seed=SAMPLE_SEED).select(range(N_SAMPLES))
train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
eval_dataloader = DataLoader(small_eval_dataset, batch_size=1)

In [None]:
import torch
from transformers import BertForSequenceClassification
from transformers import get_scheduler
from transformers import set_seed
from torch.optim import AdamW

# Seed the python/numpy/torch RNGs so that, when the notebook runs top to
# bottom, the newly initialised classification head and the training
# DataLoader's shuffle order are reproducible.
TRAIN_SEED = 42
set_seed(TRAIN_SEED)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
# Move the model before building the optimizer. Assigning the result keeps the
# cell from echoing the whole model architecture as its output.
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
# Linear schedule with no warmup, stepped once per batch for num_training_steps.
lr_scheduler = get_scheduler(name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps)

In [None]:
from tqdm.auto import tqdm

# NOTE: re-running this cell without re-running the model/optimizer setup cell
# continues training from an lr_scheduler that has already reached its final step.
model.train()
# The context manager closes the bar when training finishes or raises.
with tqdm(total=num_training_steps, desc="training") as progress_bar:
    for epoch in range(num_epochs):
        for batch in train_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            # Show the latest batch loss so a diverging run is visible while it trains.
            progress_bar.set_postfix(epoch=epoch, loss=f"{loss.item():.4f}")

In [None]:
# Predicted class id for every evaluation example, in DataLoader order.
output = []
model.eval()
# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        # Read logits by name rather than by tuple position; the position
        # shifts depending on whether `labels` is in the batch.
        logits = model(**batch).logits
        # extend(...tolist()) handles any eval batch size; .item() only works for 1.
        output.extend(torch.argmax(logits, dim=-1).tolist())

In [None]:
from sklearn import metrics

# Fraction of evaluation examples whose predicted class equals the gold label.
gold_labels = small_eval_dataset['labels']
score = metrics.accuracy_score(y_true=gold_labels, y_pred=output)
score