In [1]:

import torch
from torch.utils.data import DataLoader, random_split
from transformers import BertForTokenClassification, AdamW, get_scheduler
from tqdm import tqdm


# ---------------------------------------------------------------------------
# Experiment setup: device, data split, loaders, model, optimizer, scheduler.
# `dataset` and `df` are expected to exist already (defined outside this cell).
# ---------------------------------------------------------------------------

# Use the GPU when torch reports one is available; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# Hold out whatever is left after taking 99% of the samples for training.
train_size = int(0.99 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Only the training loader is shuffled; validation keeps the split's order.
batch_size = 32
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
)

# One output class per distinct label value found across all rows of
# df["labels"] (each row is iterated, so rows are sequences of labels).
num_labels = len({label for row in df["labels"] for label in row})
model = BertForTokenClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=num_labels,
).to(device)

# NOTE(review): num_epochs is 0, so `range(num_epochs)` in the loop below is
# empty and num_training_steps is 0 — confirm this is intentional.
num_epochs = 0
num_training_steps = num_epochs * len(train_dataloader)

# NOTE(review): `AdamW` is the name imported from `transformers` at the top of
# the file; that implementation is reportedly deprecated upstream in favour of
# `torch.optim.AdamW` — verify before upgrading the library.
optimizer = AdamW(model.parameters(), lr=5e-5)

# Linear schedule with no warm-up, spanning every optimizer step of the run.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)


# Fine-tuning loop. Each epoch makes one optimisation pass over
# `train_dataloader`, then one gradient-free pass over `val_dataloader`, and
# prints the mean of the per-batch losses for both passes.
#
# Every batch is indexed with the keys "input_ids", "attention_mask" and
# "labels"; the loss is read from `outputs.loss` on the model's return value.
for epoch in range(num_epochs):

    # ---- training pass ----------------------------------------------------
    model.train()
    train_loss = 0
    loop = tqdm(train_dataloader, leave=True)
    # The bar label depends only on the epoch, so set it once here instead of
    # re-setting (and re-rendering) it on every batch.
    loop.set_description(f"Epoch {epoch}")
    for batch in loop:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

        # Convert the loss tensor to a Python float once and reuse it for both
        # the running total and the progress-bar readout.
        step_loss = loss.item()
        train_loss += step_loss

        # Clear stale gradients, backpropagate, then advance the optimizer and
        # the learning-rate schedule by one step each.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

        loop.set_postfix(train_loss=step_loss)

    # Mean over batches: every batch counts equally regardless of its size.
    avg_train_loss = train_loss / len(train_dataloader)
    print(f"Epoch {epoch}: Average Training Loss = {avg_train_loss:.4f}")

    # ---- validation pass --------------------------------------------------
    model.eval()
    val_loss = 0
    # No gradients are needed while only measuring the loss.
    with torch.no_grad():
        for batch in val_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_dataloader)
    print(f"Epoch {epoch}: Average Validation Loss = {avg_val_loss:.4f}")


KeyboardInterrupt: 