In [1]:
import torch
import tqdm.notebook as tqdm
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForSequenceClassification, get_linear_schedule_with_warmup
from torch.optim import AdamW

In [2]:
%uv pip install kagglehub fastparquet
import kagglehub
import pandas as pd

[2mUsing Python 3.12.6 environment at: /usr/local[0m
[2mAudited [1m2 packages[0m [2min 10ms[0m[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
import pandas as pd

# Relative parquet paths of each SST-2 split inside the Hugging Face dataset repo.
splits = dict(
    train="data/train-00000-of-00001.parquet",
    validation="data/validation-00000-of-00001.parquet",
    test="data/test-00000-of-00001.parquet",
)
sst2_root = "hf://datasets/stanfordnlp/sst2/"

# The "idx" column is dropped from both frames; only the remaining columns are kept.
train = pd.read_parquet(f"{sst2_root}{splits['train']}").drop(columns="idx")

# NOTE: the frame named `test` is read from the *validation* split, not the
# "test" entry above — presumably because the public test split has no usable
# labels; confirm before reporting these numbers as test accuracy.
test = pd.read_parquet(f"{sst2_root}{splits['validation']}").drop(columns="idx")

In [4]:
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text per item for classification.

    Tokenization happens lazily in ``__getitem__``; every example is padded or
    truncated to exactly ``max_len`` tokens so the default collate function can
    stack items into a batch.

    Args:
        texts: Indexable sequence of raw strings.
        labels: Indexable sequence of integer class ids, aligned with ``texts``.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding="max_length", truncation=True, return_tensors="pt")`` that
            returns a mapping with ``"input_ids"`` and ``"attention_mask"``
            tensors of shape ``(1, max_len)``.
        max_len: Number of tokens each example is padded/truncated to.
    """

    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of examples (one per text)."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the ``idx``-th text and return model-ready tensors.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"`` of shape
            ``(max_len,)`` and ``"labels"`` as a 0-d ``torch.long`` tensor.
        """
        encoding = self.tokenizer(
            self.texts[idx],
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Drop only the leading batch axis added by return_tensors="pt".
        # A bare .squeeze() would also remove the sequence axis when
        # max_len == 1, producing 0-d tensors and mis-shaped batches.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }

In [5]:
# Single checkpoint name shared by the tokenizer and the classifier.
checkpoint = "bert-base-uncased"

# Prefer the GPU when one is visible to torch, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

tokenizer = BertTokenizer.from_pretrained(checkpoint)

# Two output labels for binary classification; the model is moved to `device`
# straight after loading.
model = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=2).to(device)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
# Tokenize every evaluation sentence as one padded batch and move it to `device`.
# Truncation is capped at the model's own position-embedding limit: the previous
# hard-coded max_length=1024 is larger than BERT supports, so an over-long
# sentence would have passed the tokenizer and then failed inside the model.
test_inputs = tokenizer(
    test["sentence"].tolist(),
    return_tensors="pt",
    truncation=True,
    padding=True,
    max_length=model.config.max_position_embeddings,
).to(device)

In [7]:
# Baseline accuracy before fine-tuning.
# Explicitly switch to eval mode so dropout is off even when this cell is
# re-run after the training cell has put the model into train mode; without
# this the baseline number depends on hidden kernel state.
model.eval()
with torch.no_grad():
    test_outputs = model(**test_inputs)
    # Predicted class = index of the largest logit per example.
    test_predictions = torch.argmax(test_outputs.logits, dim=1)
    comp = pd.DataFrame({"true": test["label"], "pred": test_predictions.tolist()})
    print("accuracy", (comp["true"] == comp["pred"]).mean())

accuracy 0.49655963302752293


In [10]:
# Number of full passes over the training data.
epochs = 1

# Wrap the training sentences/labels; each item is tokenized on access.
dataset = TextDataset(
    texts=train["sentence"].tolist(),
    labels=train["label"].tolist(),
    tokenizer=tokenizer,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=32)

# AdamW over all model parameters, paired with a linear schedule that has no
# warmup steps and spans every optimizer step of the run.
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)
total_steps = len(dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps,
)

In [11]:
# Fine-tune the classifier: one optimizer step and one scheduler step per batch.
model.train()
for epoch in range(epochs):
    total_loss = 0
    for batch in tqdm.tqdm(dataloader):
        # Move exactly the three tensors the model consumes onto the device.
        features = {
            key: batch[key].to(device)
            for key in ("input_ids", "attention_mask", "labels")
        }

        optimizer.zero_grad()
        # Passing labels makes the model return the loss alongside the logits.
        loss = model(**features).loss

        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        # Accumulate the scalar loss for the per-epoch average.
        total_loss += loss.item()

    mean_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")

  0%|          | 0/2105 [00:00<?, ?it/s]

Epoch 1/1, Loss: 0.1878


In [12]:
# Accuracy after fine-tuning, measured on the same pre-tokenized batch.
model.eval()  # disable dropout for evaluation
with torch.no_grad():
    test_outputs = model(**test_inputs)
    # Predicted class = index of the largest logit per example.
    test_predictions = test_outputs.logits.argmax(dim=1)

# Pair gold labels with predictions and report the fraction that match.
comp = pd.DataFrame({"true": test["label"], "pred": test_predictions.tolist()})
accuracy = comp["true"].eq(comp["pred"]).mean()
print("accuracy", accuracy)

accuracy 0.9185779816513762
