In [1]:
!pip install transformers
!pip install datasets
!pip install torch torchvision torchaudio

Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (

In [2]:
import os
import pandas as pd
from argparse import ArgumentParser
from tqdm import tqdm
from pathlib import Path

import torch as ch
import torch.nn as nn
from torch.utils.data import DataLoader

# Huggingface imports
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    default_data_collator,
)

# Configuration
GLUE_TASK_TO_KEYS = {
    "qnli": ("question", "sentence"),
    "mnli": ("premise", "hypothesis"),
    "mrpc": ("sentence1", "sentence2"),
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    "stsb": ("sentence1", "sentence2"),
    "wnli": ("sentence1", "sentence2"),
}


# Adjust dataset size as needed for Colab
TRAIN_SET_SIZE = 50_000  # Reduced for Colab memory constraints
VAL_SET_SIZE = 5_463


class SequenceClassificationModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.config = AutoConfig.from_pretrained(
            "gchhablani/bert-base-cased-finetuned-qnli",
            num_labels=2,
            finetuning_task="qnli",
            attn_implementation="eager",
        )

        self.model = AutoModelForSequenceClassification.from_pretrained(
            "gchhablani/bert-base-cased-finetuned-qnli", config=self.config, ignore_mismatched_sizes=False
        )

        # Check if GPU is available
        self.device = "cuda" if ch.cuda.is_available() else "cpu"
        self.model.eval().to(self.device)

    def forward(self, input_ids, token_type_ids, attention_mask):
        return self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        ).logits


def get_dataset(split, inds=None):
    raw_datasets = load_dataset("glue", "qnli")
    sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS["qnli"]

    tokenizer = AutoTokenizer.from_pretrained("gchhablani/bert-base-cased-finetuned-qnli", use_fast=True)

    def preprocess_function(examples):
        args = (examples[sentence1_key], examples[sentence2_key])
        return tokenizer(*args, padding="max_length", max_length=128, truncation=True)

    raw_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        desc="Running tokenizer on dataset",
    )

    if split == "train":
        ds = raw_datasets["train"]
    else:
        ds = raw_datasets["validation"]
    return ds


def init_model(ckpt_path=None):
    model = SequenceClassificationModel()
    if ckpt_path and os.path.exists(ckpt_path):
        sd = ch.load(ckpt_path, map_location=model.device)
        model.model.load_state_dict(sd)
    return model


def init_loaders(batch_size=16):
    ds_train = get_dataset("train").select(range(TRAIN_SET_SIZE))
    ds_val = get_dataset("val").select(range(VAL_SET_SIZE))
    return (
        DataLoader(
            ds_train,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=default_data_collator,
        ),
        DataLoader(
            ds_val,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=default_data_collator,
        ),
    )


def process_batch(batch, device):
    return [
        x.to(device)
        for x in [
            batch["input_ids"],
            batch["token_type_ids"],
            batch["attention_mask"],
            batch["labels"],
        ]
    ]

In [3]:
def evaluate_model(model, data_loader):
    correct = 0
    total = 0

    data = {
        'input': [],
        'prediction': [],
        'label': []
    }

    with ch.no_grad():
        for batch in data_loader:
            input_ids, token_type_ids, attention_mask, labels = process_batch(batch, model.device)

            logits = model(input_ids, token_type_ids, attention_mask)
            predictions = ch.argmax(logits, dim=-1)

            correct += (predictions == labels).sum().item()
            total += labels.size(0)

            tokenizer = AutoTokenizer.from_pretrained("gchhablani/bert-base-cased-finetuned-qnli", use_fast=True)
            input_text = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
            labels = labels.cpu().numpy()
            predictions = predictions.cpu().numpy()

            for i in range(len(input_text)):
                print(f"Input: {input_text[i]}")
                print(f"Prediction: {predictions[i].item()}")
                print(f"Label: {labels[i].item()}")

                data['input'].append(input_text[i])
                data['prediction'].append(predictions[i].item())
                data['label'].append(labels[i].item())

                print()

    df = pd.DataFrame(data)
    df.to_csv('predictions.csv', index=False)

    accuracy = correct / total
    return accuracy

if __name__ == "__main__":
    model = init_model()
    train_loader, val_loader = init_loaders(batch_size=16)

    print("Evaluating validation set...")
    val_accuracy = evaluate_model(model, val_loader)
    print(f"\nValidation Accuracy: {val_accuracy:.4f}")


config.json:   0%|          | 0.00/895 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/433M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/872k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/877k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/104743 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5463 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5463 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/320 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Running tokenizer on dataset:   0%|          | 0/104743 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/5463 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/5463 [00:00<?, ? examples/s]

Running tokenizer on dataset:   0%|          | 0/5463 [00:00<?, ? examples/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Label: 1

Input: What are the stages in a compound engine called? These stages were called expansions, with double and triple expansion engines being common, especially in shipping where efficiency was important to reduce the weight of coal carried.
Prediction: 0
Label: 0

Input: After the merger between ABC and Capital Cities, who became the vice president of ABC broadcasting? The merger between ABC and Capital Cities received federal approval on September 5, 1985.
Prediction: 1
Label: 1

Input: What is the weather type of Mallee and upper Wimmera? The Mallee and upper Wimmera are Victoria ' s warmest regions with hot winds blowing from nearby semi - deserts.
Prediction: 0
Label: 0

Input: Who registered the most sacks on the team this season? The Panthers defense gave up just 308 points, ranking sixth in the league, while also leading the NFL in interceptions with 24 and boasting four Pro Bowl selections.
Prediction: 1


In [4]:
def predict_custom_text(model, tokenizer, sentence1, sentence2=None):
    """
    Predicts the class of the given custom text input using the model.

    Args:
        model (SequenceClassificationModel): The sequence classification model.
        tokenizer (AutoTokenizer): The tokenizer used for preprocessing text.
        sentence1 (str): The first text input (e.g., question or sentence).
        sentence2 (str, optional): The second text input (e.g., sentence or hypothesis). Defaults to None.

    Returns:
        int: Predicted class label (e.g., 0 or 1).
    """
    # Tokenize the input
    inputs = tokenizer(
        sentence1,
        sentence2,
        padding="max_length",
        max_length=128,
        truncation=True,
        return_tensors="pt",
    )

    # Move inputs to the same device as the model
    inputs = {key: val.to(model.device) for key, val in inputs.items()}

    # Perform inference
    with ch.no_grad():
        logits = model(**inputs)
        predictions = ch.argmax(logits, dim=1)  # Get the predicted class

    return predictions.item()


if __name__ == "__main__":
    # Initialize model and tokenizer
    model = init_model()
    tokenizer = AutoTokenizer.from_pretrained("gchhablani/bert-base-cased-finetuned-qnli")

    # Input custom text
    custom_text1 = "Who discovered this and where did they come from?"
    custom_text2 = "The development of this fertile soil allowed agriculture and silviculture in the previously hostile environment ; meaning that large portions of the Amazon rainforest are probably the result of centuries of human management, rather than naturally occurring as has previously been supposed."

    # Get prediction
    prediction = predict_custom_text(model, tokenizer, custom_text1, custom_text2)
    print(f"Predicted label: {prediction}")


Predicted label: 1
