In [None]:

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertForSequenceClassification
from torch.optim import AdamW
from transformers import get_scheduler
import torch.nn.functional as F
from tqdm.auto import tqdm
from config import BERT_TRAIN_DATA_PATH

In [None]:
class NewsDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with per-example labels.

    Args:
        encodings: Mapping (anything exposing ``.items()``) from field name to
            an indexable container holding one entry per example.
        labels: Sized, indexable collection with one label per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors plus a ``'labels'`` entry."""
        item = {key: self._as_new_tensor(val[idx]) for key, val in self.encodings.items()}
        # The label is returned exactly as stored; no conversion is applied here.
        item['labels'] = self.labels[idx]
        return item

    def __len__(self):
        """Number of examples, defined by the number of labels."""
        return len(self.labels)

    @staticmethod
    def _as_new_tensor(value):
        """Return an independent tensor copy of ``value``.

        ``torch.tensor(existing_tensor)`` emits a "copy construct" UserWarning
        on every call; ``detach().clone()`` is the documented equivalent and
        stays silent. Non-tensor values (lists, ints, ...) still go through
        ``torch.tensor``.
        """
        if isinstance(value, torch.Tensor):
            return value.detach().clone()
        return torch.tensor(value)

In [None]:
import transformers.tokenization_utils_base

# Load the pre-tokenized train/validation splits stored at BERT_TRAIN_DATA_PATH.
#
# * weights_only=True is spelled out because the safe_globals allowlist is only
#   consulted by the restricted (weights-only) unpickler; on PyTorch versions
#   where that is not the default, the allowlist would otherwise be ignored and
#   the file would be read with unrestricted pickle.
# * map_location="cpu" keeps any stored tensors loadable on machines without
#   CUDA; the training loop moves each batch to the selected device itself.
with torch.serialization.safe_globals([transformers.tokenization_utils_base.BatchEncoding]):
    data = torch.load(BERT_TRAIN_DATA_PATH, map_location="cpu", weights_only=True)

# The checkpoint is indexed as a dict holding encodings and labels per split.
train_encodings = data['train_encodings']
train_labels = data['train_labels']
val_encodings = data['val_encodings']
val_labels = data['val_labels']

In [None]:
# Wrap each split's encodings and labels in a NewsDataset.
train_dataset = NewsDataset(encodings=train_encodings, labels=train_labels)
val_dataset = NewsDataset(encodings=val_encodings, labels=val_labels)

# Batches of 16 for both splits; only the training loader shuffles its examples.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=16)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=16)

In [None]:
# Instantiate the 'bert-base-uncased' checkpoint with a two-label classification setup.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
)

In [None]:
# Step budget: one scheduler/optimizer step per training batch, for every epoch.
num_epochs = 3
num_training_steps = len(train_loader) * num_epochs

# AdamW over every model parameter, starting from a learning rate of 5e-5.
optimizer = AdamW(model.parameters(), lr=5e-5)

# "linear" schedule with no warmup steps, spanning the whole step budget.
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_training_steps=num_training_steps,
    num_warmup_steps=0,
)

In [None]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    # ---- Training pass: one optimizer + scheduler step per batch ----
    model.train()
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for batch in progress_bar:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        # Clear gradients so the next batch starts from zero.
        optimizer.zero_grad()

        progress_bar.set_postfix(loss=loss.item())

    # ---- Validation pass: accuracy over the whole validation loader ----
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for batch in val_loader:
            batch = {name: tensor.to(device) for name, tensor in batch.items()}
            outputs = model(**batch)
            logits = outputs.logits
            # Predicted class = index of the largest logit along the last axis.
            predictions = logits.argmax(dim=-1)
            correct += predictions.eq(batch['labels']).sum().item()
            total += batch['labels'].shape[0]

    accuracy = correct / total
    print(f"Validation accuracy after epoch {epoch+1}: {accuracy:.4f}")