<a href="https://colab.research.google.com/github/georgilos/Bert-for-text-classification/blob/main/Active%2BBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#cloning the repository
#https://github.com/rmunro/pytorch_active_learning.git
#https://github.com/georgilos/Bert-for-text-classification.git
!git clone https://github.com/georgilos/Bert-for-text-classification.git

Cloning into 'Bert-for-text-classification'...
remote: Enumerating objects: 151, done.[K
remote: Counting objects: 100% (14/14), done.[K
remote: Compressing objects: 100% (10/10), done.[K
remote: Total 151 (delta 5), reused 11 (delta 4), pack-reused 137 (from 1)[K
Receiving objects: 100% (151/151), 36.81 MiB | 20.20 MiB/s, done.
Resolving deltas: 100% (54/54), done.


In [5]:
%ls


active_learning_basics.py    [0m[01;34mevaluation_data[0m/     READMEa.md        uncertainty_sampling.py
active_learning.py           LICENSE              README.md         [01;34munlabeled_data[0m/
advanced_active_learning.py  [01;34mmodels[0m/              requirements.txt  [01;34mvalidation_data[0m/
diversity_sampling.py        pytorch_clusters.py  [01;34mtraining_data[0m/


In [4]:
#changing directory to repository
#%cd pytorch_active_learning/
%cd Bert-for-text-classification/

[Errno 2] No such file or directory: 'Bert-for-text-classification/'
/content/Bert-for-text-classification


In [None]:
import torch

# Prefer the GPU when this runtime has one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")


Using device: cuda


### Executing `active_learning_basics.py` with the `SimpleTextClassifier` swapped for `bert-base-uncased`

In [12]:
#!/usr/bin/env python

"""INTRODUCTION TO ACTIVE LEARNING

A simple text classification algorithm in PyTorch

This is an open source example to accompany Chapter 2 from the book:
"Human-in-the-Loop Machine Learning"

This example tries to classify news headlines into one of two categories:
  disaster-related
  not disaster-related

It looks for low confidence items and outliers humans should review

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import math
import datetime
import csv
import re
import os
from random import shuffle
from collections import defaultdict
import transformers
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW  # Use PyTorch's AdamW
from transformers import get_linear_schedule_with_warmup
import logging
from sklearn.metrics import f1_score, roc_auc_score
import numpy as np

__author__ = "Robert Munro"
__license__ = "MIT"
__version__ = "1.0.1"

# Define device: the GPU when available, otherwise the CPU. Models and batches
# in this cell are moved to this device with .to(device).
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# Setup logging
# NOTE(review): logging.basicConfig is a no-op when the root logger already has
# handlers (common in Colab/Jupyter). That may be why no logger.info lines show
# up in this cell's output, only the print() calls — confirm.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# settings

minimum_evaluation_items = 1200  # annotate this many randomly sampled items first for evaluation data before creating training data
minimum_training_items = 400  # minimum number of training items before we first train a model

epochs = 3  # number of epochs per training session
# NOTE(review): select_per_epoch is not referenced anywhere in this cell.
select_per_epoch = 200  # number to select per epoch per label
batch_size = 16  # Batch size for training

# `data` is reassigned to the unlabeled pool when the CSVs are loaded below;
# `test_data` is not used in this cell.
data = []
test_data = []

# directories with data (paths of the CSV file for each split, relative to the repo root)
unlabeled_data = "unlabeled_data/unlabeled_data.csv"
evaluation_related_data = "evaluation_data/related.csv"
evaluation_not_related_data = "evaluation_data/not_related.csv"
training_related_data = "training_data/related.csv"
training_not_related_data = "training_data/not_related.csv"
# validation_related_data # not used in this example
# validation_not_related_data # not used in this example

# Shared registry ID -> label. load_data() fills it for every labeled row it
# reads; the sampling and annotation functions skip any ID found here.
already_labeled = {}  # tracking what is already labeled

def load_data(filepath, skip_already_labeled=False):
    """Load items from a CSV file and register the ones that carry a label.

    csv format: [ID, TEXT, LABEL, SAMPLING_STRATEGY, CONFIDENCE]
    Rows with only some of the trailing columns are padded: "" for LABEL and
    SAMPLING_STRATEGY, 0 for CONFIDENCE. Every kept row with a non-empty LABEL
    is recorded in the module-level ``already_labeled`` dict (ID -> label).

    Args:
        filepath: path of the CSV file to read.
        skip_already_labeled: if True, rows whose ID is already in
            ``already_labeled`` are left out of the result.

    Returns:
        A list of rows (lists of at least five columns). If the file does not
        exist, an error is logged and an empty list is returned.
    """
    # Defaults for the optional columns, in order: LABEL, SAMPLING_STRATEGY, CONFIDENCE.
    optional_defaults = ["", "", 0]
    data = []
    try:
        with open(filepath, 'r', encoding='utf-8') as csvfile:
            for row in csv.reader(csvfile):
                if len(row) < 2:
                    # Blank line (csv.reader yields []) or a row without TEXT.
                    # The previous padding loop had no branch for these rows
                    # and never terminated.
                    if row:
                        logger.warning(f"Skipping malformed row in {filepath}: {row}")
                    continue

                if skip_already_labeled and row[0] in already_labeled:
                    continue

                # Append only the missing columns (no-op for rows with 5+ columns).
                row.extend(optional_defaults[len(row) - 2:])
                data.append(row)

                if row[2] != "":
                    already_labeled[row[0]] = str(row[2])
    except FileNotFoundError:
        logger.error(f"File not found: {filepath}")
    return data

def append_data(filepath, data):
    """Append rows to a CSV file, creating the file and its directory if needed.

    Args:
        filepath: destination CSV path.
        data: iterable of rows (lists) to append.

    Characters that cannot be encoded as UTF-8 are replaced rather than raising.
    """
    try:
        directory = os.path.dirname(filepath)
        if directory:
            # Without this, a missing directory (e.g. training_data/) made open()
            # fail and the freshly annotated rows were lost with only a log line.
            os.makedirs(directory, exist_ok=True)
        with open(filepath, 'a', encoding='utf-8', errors='replace', newline='') as csvfile:
            csv.writer(csvfile).writerows(data)
    except FileNotFoundError:
        logger.error(f"File not found: {filepath}")

def write_data(filepath, data):
    """Write rows to a CSV file, replacing its contents.

    The parent directory is created if it does not exist yet.

    Args:
        filepath: destination CSV path.
        data: iterable of rows (lists) to write.

    Characters that cannot be encoded as UTF-8 are replaced rather than raising.
    """
    try:
        directory = os.path.dirname(filepath)
        if directory:
            # A missing directory previously made open() fail and the data was
            # dropped with only a "File not found" log line.
            os.makedirs(directory, exist_ok=True)
        with open(filepath, 'w', encoding='utf-8', errors='replace', newline='') as csvfile:
            csv.writer(csvfile).writerows(data)
    except FileNotFoundError:
        logger.error(f"File not found: {filepath}")

# LOAD ALL UNLABELED, TRAINING, VALIDATION, AND EVALUATION DATA
# Order matters: loading the training and evaluation files first fills
# `already_labeled`, so the unlabeled pool loaded last (with
# skip_already_labeled=True) leaves out every item that already has a label.
# No validation files are loaded in this example.
training_data = load_data(training_related_data) + load_data(training_not_related_data)
training_count = len(training_data)

evaluation_data = load_data(evaluation_related_data) + load_data(evaluation_not_related_data)
evaluation_count = len(evaluation_data)

data = load_data(unlabeled_data, skip_already_labeled=True)

# Prompt printed before each item in get_annotations().
annotation_instructions = (
    "Please type 1 if this message is disaster-related, "
    "or hit Enter if not.\n"
    "Type 2 to go back to the last message, "
    "type d to see detailed definitions, "
    "or type s to save your annotations.\n"
)

# Prompt text for get_annotations()'s end-of-list branch.
last_instruction = (
    "All done!\n"
    "Type 2 to go back to change any labels,\n"
    "or Enter to save your annotations."
)

# Definitions printed when the annotator types "d".
# NOTE(review): these instructions (like the module docstring) describe a
# disaster-related task, while the inference cell further down reports
# predictions in terms of hateful content — confirm which task the labels in
# training_data/ actually encode.
detailed_instructions = (
    "A 'disaster-related' headline is any story about a disaster.\n"
    "It includes:\n"
    "  - human, animal and plant disasters.\n"
    "  - the response to disasters (aid).\n"
    "  - natural disasters and man-made ones like wars.\n"
    "It does not include:\n"
    "  - criminal acts and non-disaster-related police work\n"
    "  - post-response activity like disaster-related memorials.\n\n"
)

def get_annotations(data, default_sampling_strategy="random"):
    """Prompt the annotator for a label for each item and store it in place.

    For each item whose ID is not in ``already_labeled``, the annotator types:
        1      -- label "1"
        Enter  -- (or anything else not listed here) label "0"
        2      -- go back to the previous item
        d      -- print the detailed definitions
        s      -- stop now; items not reached keep their empty label
    After the last item one more prompt is shown, so the annotator can still
    type 2 to go back and change the final labels.

    Keyword arguments:
        data -- a list of unlabeled items where each item is
                [ID, TEXT, LABEL, SAMPLING_STRATEGY, CONFIDENCE]
        default_sampling_strategy -- strategy to use for each item if not already specified

    Returns:
        The same ``data`` list; LABEL (item[2]) and, if empty,
        SAMPLING_STRATEGY (item[3]) are filled in place.
    """
    if not data:
        return data  # nothing to annotate (and no "All done" prompt to show)

    ind = 0
    # `<=` rather than `<`: ind == len(data) must reach the final "All done"
    # prompt below. With `<`, that branch was unreachable and the annotator
    # could not go back from the last item.
    while ind <= len(data):
        if ind < 0:
            ind = 0  # in case you've gone back before the first
        if ind < len(data):
            textid = data[ind][0]
            text = data[ind][1]
            label = data[ind][2]

            if textid in already_labeled:
                print(f"Skipping seen label: {label}")
                ind += 1
            else:
                print(annotation_instructions)
                label_input = input(text + "\n\n> ").strip()

                if label_input == "2":
                    ind -= 1  # go back
                elif label_input == "d":
                    print(detailed_instructions)  # print detailed instructions
                elif label_input == "s":
                    break  # save and exit
                else:
                    # Treat everything other than "1" (including plain Enter) as "0".
                    data[ind][2] = "1" if label_input == "1" else "0"

                    if not data[ind][3]:
                        data[ind][3] = default_sampling_strategy  # add default if none given
                    ind += 1

        else:
            # Last one - give annotator a chance to go back
            print(last_instruction)
            label_input = input("\n\n> ").strip()
            if label_input == "2":
                ind -= 1
            else:
                ind += 1

    return data

class BERTTextClassifier(nn.Module):
    """bert-base-uncased encoder followed by dropout and a linear head.

    The head is applied to BERT's pooled output. The submodule names (bert,
    dropout, classifier) are the prefixes of the saved state-dict keys.

    Args:
        num_labels: number of output classes (size of the logits' last dim).
        max_seq_length: accepted for interface compatibility; not used here.
    """

    def __init__(self, num_labels, max_seq_length):
        super().__init__()
        self.bert = transformers.BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Map token IDs and attention mask to unnormalized class scores."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        features = self.dropout(encoded.pooler_output)
        return self.classifier(features)

class TextDataset(Dataset):
    """Tokenized, fixed-length examples for fine-tuning and evaluation.

    Each element of ``data`` is [ID, TEXT, LABEL, ...]. Elements whose LABEL
    cannot be converted with ``int()`` are skipped, with a warning that is both
    logged and printed.

    Args:
        data: list of items.
        tokenizer: object providing a Hugging Face style ``encode_plus``.
        max_seq_length: every sequence is truncated / padded to this length.
    """

    def __init__(self, data, tokenizer, max_seq_length):
        self.input_ids = []
        self.attention_masks = []
        self.labels = []

        for record in data:
            raw_label = record[2]
            try:
                label = int(raw_label)
            except ValueError:
                logger.warning(f"Skipping item with invalid label: {raw_label}")  # Log the warning
                print(f"Skipping item with invalid label: {raw_label}")  # Print to console
                continue

            encoding = tokenizer.encode_plus(
                record[1],
                add_special_tokens=True,
                max_length=max_seq_length,
                truncation=True,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt',
            )
            # squeeze(0) drops the leading batch dimension; clone() keeps an
            # independent copy of each tensor.
            ids_tensor = encoding['input_ids'].clone().detach().squeeze(0)
            mask_tensor = encoding['attention_mask'].clone().detach().squeeze(0)
            self.input_ids.append(ids_tensor)
            self.attention_masks.append(mask_tensor)
            self.labels.append(label)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        example = {}
        example['input_ids'] = self.input_ids[idx]
        example['attention_mask'] = self.attention_masks[idx]
        example['labels'] = self.labels[idx]
        return example

def train_model(training_data, validation_data, evaluation_data, num_labels=2):
    """Fine-tune a fresh BERTTextClassifier on training_data and save its weights.

    Trains for ``epochs`` epochs, then evaluates F1 / AUC on evaluation_data.

    Args:
        training_data: list of [ID, TEXT, LABEL, ...] items; items whose LABEL is
            not an integer are skipped by TextDataset.
        validation_data: not used yet (callers pass None); kept for interface
            compatibility.
        evaluation_data: held-out items used for the reported metrics.
        num_labels: number of output classes.

    Returns:
        Path of the saved state dict, of the form
        models/<YYYYMMDD>_<HHMMSS>_<fscore>_<auc>_<n_training_items>.params

    Raises:
        ValueError: if no item in training_data has a usable integer label.
    """
    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')

    # Prepare DataLoader
    train_dataset = TextDataset(training_data, tokenizer, max_seq_length=128)
    if len(train_dataset) == 0:
        # Fail with a clear message instead of a ZeroDivisionError when the
        # average loss is computed below.
        raise ValueError("train_model: no training items with a valid integer label")
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    model = BERTTextClassifier(num_labels=num_labels, max_seq_length=128).to(device)
    loss_function = nn.CrossEntropyLoss()

    # Optimizer & Scheduler
    optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)  # Use PyTorch's AdamW
    total_steps = len(train_dataloader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,  # Typically, warmup steps are a small fraction of total steps
        num_training_steps=total_steps
    )

    # Training Loop
    for epoch in range(epochs):
        epoch_num = epoch + 1
        logger.info(f"Epoch: {epoch_num}/{epochs}")
        print(f"Epoch: {epoch_num}/{epochs}")  # Print epoch progress
        model.train()

        total_loss = 0

        for batch in train_dataloader:
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            logits = model(input_ids, attention_mask)

            loss = loss_function(logits, labels)
            total_loss += loss.item()

            loss.backward()

            # Gradient Clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()

        avg_train_loss = total_loss / len(train_dataloader)
        logger.info(f"Average training loss: {avg_train_loss:.3f}")
        print(f"Average training loss: {avg_train_loss:.3f}")  # Print training loss

    fscore, auc = evaluate_model(model, evaluation_data, tokenizer)
    fscore = round(fscore, 3)
    auc = round(auc, 3)

    # Save model to path that is alphanumeric and includes number of items and accuracies in filename.
    # strftime always yields the same layout. The previous str(now) + regex
    # version dropped the "_" separator whenever microseconds were 0 (str()
    # omits the fractional part then), and it used an invalid '\.' escape.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_")
    training_size = "_" + str(len(training_data))
    accuracies = f"{fscore}_{auc}"

    os.makedirs("models", exist_ok=True)  # Ensure the models directory exists
    model_path = f"models/{timestamp}{accuracies}{training_size}.params"

    torch.save(model.state_dict(), model_path)
    logger.info(f"Model saved to: {model_path}")
    print(f"Model saved to: {model_path}")  # Print model save path
    return model_path

def evaluate_model(model, evaluation_data, tokenizer, batch_size=32):
    """Evaluate the model on the held-out evaluation data.

    An item counts as predicted positive (class 1, "disaster-related") when its
    softmax probability for class 1 exceeds 0.5. Items whose LABEL is not an
    integer are skipped by TextDataset.

    Args:
        model: classifier returning logits of shape (batch, num_labels) with
            class index 1 as the positive class; it is switched to eval mode.
        evaluation_data: list of [ID, TEXT, LABEL, ...] items.
        tokenizer: tokenizer passed through to TextDataset.
        batch_size: evaluation batch size.

    Returns:
        [fscore, auc]: F1 for class 1 and ROC AUC of the class-1 probabilities.
        A metric that is undefined (no usable items, or only one class present
        for AUC) is returned as float('nan') with a warning instead of raising.
    """
    model.eval()
    all_labels = []
    all_probs = []

    dataset = TextDataset(evaluation_data, tokenizer, max_seq_length=128)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            logits = model(input_ids, attention_mask)
            probabilities = F.softmax(logits, dim=1)
            all_probs.extend(probabilities[:, 1].cpu().tolist())
            # Labels are only consumed by sklearn on the CPU, so they are not
            # sent to the device and back.
            all_labels.extend(batch['labels'].tolist())

    if not all_labels:
        logger.warning("No evaluation items with a valid label; metrics are undefined.")
        print("No evaluation items with a valid label; metrics are undefined.")
        return [float("nan"), float("nan")]

    # Calculate metrics using sklearn
    fscore = f1_score(all_labels, [1 if p > 0.5 else 0 for p in all_probs])
    if len(set(all_labels)) < 2:
        # roc_auc_score raises ValueError when only one class is present.
        logger.warning("Only one class in the evaluation labels; AUC is undefined.")
        print("Only one class in the evaluation labels; AUC is undefined.")
        auc = float("nan")
    else:
        auc = roc_auc_score(all_labels, all_probs)

    logger.info(f"[fscore, auc] = {fscore}, {auc}")
    print(f"[fscore, auc] = {fscore}, {auc}")  # Print evaluation metrics
    return [fscore, auc]


def get_low_conf_unlabeled(model, unlabeled_data, tokenizer, max_seq_length, number=80, limit=10000):
    """Return the `number` unlabeled items the model is least confident about.

    An item's confidence is max(p, 1 - p), where p is the softmax probability
    of class 1, so it lies in [0.5, 1]. Items are returned lowest confidence
    first; ties keep scoring order.

    Args:
        model: classifier taking (input_ids, attention_mask) and returning logits.
        unlabeled_data: list of [ID, TEXT, LABEL, SAMPLING_STRATEGY, CONFIDENCE]
            items; IDs in ``already_labeled`` are ignored.
        tokenizer: tokenizer with ``encode_plus``.
        max_seq_length: sequences are truncated / padded to this length.
        number: how many items to return.
        limit: score only a random sample of this many items (the list is
            shuffled in place to draw it); -1 scores every item.

    Returns:
        Up to `number` items. Only these items are modified: SAMPLING_STRATEGY
        (item[3]) is set to "low confidence" and CONFIDENCE (item[4]) to the
        computed confidence. Previously every scored item was overwritten,
        which clobbered the strategy of items already picked by another
        sampling strategy.
    """
    if limit == -1:  # Predicting on all data
        logger.info("Get confidences for all unlabeled data (this might take a while)")
        print("Get confidences for all unlabeled data (this might take a while)")
    else:
        # Only apply the model to a limited number of items
        shuffle(unlabeled_data)
        unlabeled_data = unlabeled_data[:limit]

    scored = []  # (confidence, item) pairs in scoring order
    model.eval()
    with torch.no_grad():
        for item in unlabeled_data:
            if item[0] in already_labeled:
                continue

            encoded_input = tokenizer.encode_plus(
                item[1],
                add_special_tokens=True,
                max_length=max_seq_length,
                truncation=True,
                padding='max_length',
                return_attention_mask=True,
                return_tensors='pt'
            )
            input_ids = encoded_input['input_ids'].to(device)
            attention_mask = encoded_input['attention_mask'].to(device)

            logits = model(input_ids, attention_mask)

            # Apply softmax to get probabilities
            prob_related = F.softmax(logits, dim=1)[0][1].item()
            confidence = 1 - prob_related if prob_related < 0.5 else prob_related
            scored.append((confidence, item))

    # Sort by confidence ascending; list.sort is stable, so ties keep scoring order.
    scored.sort(key=lambda pair: pair[0])

    lowest = []
    for confidence, item in scored[:number]:
        item[3] = "low confidence"
        item[4] = confidence
        lowest.append(item)
    return lowest

def get_random_items(unlabeled_data, number=10):
    """Return up to `number` randomly chosen items that are not yet labeled.

    Args:
        unlabeled_data: list of [ID, TEXT, LABEL, SAMPLING_STRATEGY, CONFIDENCE]
            items; it is shuffled in place. IDs in ``already_labeled`` are skipped.
        number: maximum number of items to return.

    Returns:
        The chosen items, each with SAMPLING_STRATEGY (item[3]) set to
        "random_remaining".
    """
    shuffle(unlabeled_data)
    random_items = []
    for item in unlabeled_data:
        # Check the quota before taking an item. With the check after the
        # append, number <= 0 still returned (and relabeled) one item.
        if len(random_items) >= number:
            break
        if item[0] in already_labeled:
            continue
        item[3] = "random_remaining"
        random_items.append(item)

    return random_items

def get_outliers(training_data, unlabeled_data, number=10, max_iterations=1000):
    """Get outliers from unlabeled data in training data

    Returns up to `number` items from unlabeled_data whose words are rarest in
    training_data.

    An item's score is (1 + sum of the training-data counts of its words) /
    (number of words); the lowest score is the strongest outlier. After each
    pick, the picked item's words are added to the counts (assuming it will
    get a label), so later picks favor different vocabulary.

    An item is never returned if its ID is in ``already_labeled``, appears in
    training_data (e.g. items already sampled by another strategy this round),
    or was already picked in this call. Items with no words are ignored.

    Args:
        training_data: items whose words define "known" vocabulary.
        unlabeled_data: candidate items [ID, TEXT, LABEL, SAMPLING_STRATEGY, CONFIDENCE].
        number: maximum number of outliers to return.
        max_iterations: safety cap on selection rounds.

    Returns:
        A list of at most `number` items, each with item[3] set to "outlier".
    """
    total_feature_counts = defaultdict(int)
    for item in training_data:
        for feature in item[1].split():
            total_feature_counts[feature] += 1

    # IDs that must not be returned. Without tracking picked outliers, the
    # same item could be chosen again, because a pick only raises its own
    # score by about 1.
    excluded_ids = {item[0] for item in training_data}

    outliers = []
    iterations = 0
    while len(outliers) < number and iterations < max_iterations:
        iterations += 1
        top_outlier = None
        top_match = float("inf")

        for item in unlabeled_data:
            textid = item[0]
            if textid in already_labeled or textid in excluded_ids:
                continue

            features = item[1].split()
            if not features:
                continue  # empty text: score undefined (was a ZeroDivisionError)

            total_matches = 1  # start at 1 for slight smoothing
            for feature in features:
                if feature in total_feature_counts:
                    total_matches += total_feature_counts[feature]

            ave_matches = total_matches / len(features)
            if ave_matches < top_match:
                top_match = ave_matches
                top_outlier = item

        if top_outlier is None:
            break  # No outlier found

        # Add this outlier to list and update what is 'labeled',
        # assuming this new outlier will get a label
        top_outlier[3] = "outlier"
        outliers.append(top_outlier)
        excluded_ids.add(top_outlier[0])
        for feature in top_outlier[1].split():
            total_feature_counts[feature] += 1

    if len(outliers) < number and iterations >= max_iterations:
        logger.warning(f"Reached maximum iterations ({max_iterations}) without finding enough outliers.")
        print(f"Reached maximum iterations ({max_iterations}) without finding enough outliers.")

    return outliers

def _load_trained_model(model_path):
    """Rebuild a BERTTextClassifier and load the state dict saved at model_path.

    Returns the model on `device`, or None if loading failed (the error is
    logged and printed).
    """
    model = BERTTextClassifier(num_labels=2, max_seq_length=128).to(device)
    try:
        # map_location restores GPU-saved weights on whatever device is active.
        # weights_only is safe for a plain state dict, avoids unpickling arbitrary
        # objects, and silences torch.load's FutureWarning.
        model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    except Exception as e:
        logger.error(f"Error loading model from {model_path}: {e}")
        print(f"Error loading model from {model_path}: {e}")  # Print error
        return None
    return model


def _append_annotations(annotated_data, related_path, not_related_path):
    """Append items labeled "1" to related_path and items labeled "0" to not_related_path.

    Items that were never labeled (e.g. the annotator typed "s" early) are dropped.
    """
    related = [item for item in annotated_data if item[2] == "1"]
    not_related = [item for item in annotated_data if item[2] == "0"]
    append_data(related_path, related)
    append_data(not_related_path, not_related)


def _annotate_random_batch(needed, related_path, not_related_path):
    """Ask for `needed` annotations of randomly chosen unlabeled items and save them."""
    shuffle(data)
    selected_data = data[:needed]
    logger.info(f"{needed} more annotations needed")
    print(f"{needed} more annotations needed")  # Print needed annotations
    _append_annotations(get_annotations(selected_data), related_path, not_related_path)


def main():
    """Run one round of the human-in-the-loop annotation / training cycle.

    Depending on how much labeled data exists (the module-level counts), one of:
      1. annotate random items until there are `minimum_evaluation_items`
         evaluation items;
      2. annotate random items until there are `minimum_training_items`
         training items;
      3. train a model, sample items by active learning (10 random, 80 low
         confidence, 10 outliers) and append their annotations to the
         training data.
    If the training count (as it was before this round) exceeds
    `minimum_training_items`, the data is then reloaded and a model retrained.
    """
    global training_data, training_count, evaluation_data, evaluation_count, data

    tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-uncased')

    if evaluation_count < minimum_evaluation_items:
        # Keep adding to evaluation data first
        logger.info("Creating evaluation data:\n")
        print("Creating evaluation data:\n")  # Print action
        _annotate_random_batch(minimum_evaluation_items - evaluation_count,
                               evaluation_related_data, evaluation_not_related_data)

    elif training_count < minimum_training_items:
        # Let's create our first training data!
        logger.info("Creating initial training data:\n")
        print("Creating initial training data:\n")  # Print action
        _annotate_random_batch(minimum_training_items - training_count,
                               training_related_data, training_not_related_data)

    else:
        # Let's start Active Learning!!

        # Train new model with current training data
        logger.info("Training model with current training data.")
        print("Training model with current training data.")  # Print action

        model_path = train_model(training_data, None, evaluation_data)

        logger.info("Sampling via Active Learning:\n")
        print("Sampling via Active Learning:\n")  # Print action

        model = _load_trained_model(model_path)
        if model is None:
            return

        # Get items per iteration with the following breakdown of strategies.
        # Each strategy only sees items the previous ones did not take, so no
        # item is sampled twice or has its sampling strategy overwritten.
        random_items = get_random_items(data, number=10)
        taken_ids = {item[0] for item in random_items}

        low_conf_candidates = [item for item in data if item[0] not in taken_ids]
        low_confidences = get_low_conf_unlabeled(model, low_conf_candidates, tokenizer, 128, number=80)
        taken_ids.update(item[0] for item in low_confidences)

        outlier_candidates = [item for item in data if item[0] not in taken_ids]
        outliers = get_outliers(training_data + random_items + low_confidences, outlier_candidates, number=10)

        sampled_data = random_items + low_confidences + outliers
        shuffle(sampled_data)

        # Append training data
        _append_annotations(get_annotations(sampled_data),
                            training_related_data, training_not_related_data)

    # NOTE: training_count is still the count from before this round's annotations.
    if training_count > minimum_training_items:
        logger.info("\nRetraining model with new data")
        print("\nRetraining model with new data")  # Print action

        # UPDATE OUR DATA AND (RE)TRAIN MODEL WITH NEWLY ANNOTATED DATA
        training_data = load_data(training_related_data) + load_data(training_not_related_data)
        training_count = len(training_data)

        evaluation_data = load_data(evaluation_related_data) + load_data(evaluation_not_related_data)
        evaluation_count = len(evaluation_data)

        logger.info("Training model with updated training data.")
        print("Training model with updated training data.")  # Print action

        model_path = train_model(training_data, None, evaluation_data)
        model = _load_trained_model(model_path)
        if model is None:
            return

        # Re-evaluate the reloaded checkpoint to confirm it round-trips.
        # evaluate_model logs and prints the metrics itself, so they are not
        # repeated here.
        evaluate_model(model, evaluation_data, tokenizer)
        logger.info(f"Model saved to: {model_path}")
        print(f"Model saved to: {model_path}")  # Print model save path

if __name__ == "__main__":
    main()


Using device: cuda




Training model with current training data.
Epoch: 1/3
Average training loss: 0.561
Epoch: 2/3
Average training loss: 0.368
Epoch: 3/3
Average training loss: 0.266
[fscore, auc] = 0.6735751295336787, 0.8982463659416409
Model saved to: models/20241105_203804_0.674_0.898_500.params
Sampling via Active Learning:



  model.load_state_dict(torch.load(model_path))


Please type 1 if this message is disaster-related, or hit Enter if not.
Type 2 to go back to the last message, type d to see detailed definitions, or type s to save your annotations.

billionaire pissing contest yawning faceyawning faceyawning face

> 1
Please type 1 if this message is disaster-related, or hit Enter if not.
Type 2 to go back to the last message, type d to see detailed definitions, or type s to save your annotations.

peoplecnbprolifedadsmomsprolifebynaturedadmomchildrenhomelandsafganistanitalygermanychinaetc samesexmarriages messespeople causespeopleturnin transgenderpeople wernthomosexualslesbiansortransgenderpeople wersonsdaughters

> 1
Please type 1 if this message is disaster-related, or hit Enter if not.
Type 2 to go back to the last message, type d to see detailed definitions, or type s to save your annotations.

hello and i had to turn off attacking biden for not cleaning up the gop caused mess adequately with no mention of gop caused the mess

> 1
Please type 1



Epoch: 1/3
Average training loss: 0.558
Epoch: 2/3
Average training loss: 0.333
Epoch: 3/3
Average training loss: 0.233
[fscore, auc] = 0.6415094339622641, 0.8966659523554413
Model saved to: models/20241105_204118_0.642_0.897_506.params


  model.load_state_dict(torch.load(model_path))


[fscore, auc] = 0.6415094339622641, 0.8966659523554413
[fscore, auc] = 0.6415094339622641, 0.8966659523554413
Model saved to: models/20241105_204118_0.642_0.897_506.params


## Inference

In [13]:
import torch
from transformers import BertTokenizer, BertModel
import torch.nn as nn # Add this import statement
from torch.nn import functional as F # Add this import statement

class BERTTextClassifier(nn.Module):
    """Inference-cell copy of the training classifier, so this cell can run on its own.

    The submodule names (bert, dropout, classifier) must match the training
    definition, because they are the keys of the saved state dict loaded below.

    Args:
        num_labels: number of output classes.
        max_seq_length: accepted for interface compatibility; not used here.
    """

    def __init__(self, num_labels, max_seq_length):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(0.1)  # Adjust dropout rate as needed
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Map token IDs and attention mask to class logits via BERT's pooled output."""
        bert_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = bert_out[1]  # index 1 is the pooler output
        return self.classifier(self.dropout(pooled))

# Load the tokenizer and the fine-tuned weights saved by the training cell above.
MODEL_PATH = "/content/Bert-for-text-classification/models/20241105_204118_0.642_0.897_506.params"  # Replace with your model path

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BERTTextClassifier(num_labels=2, max_seq_length=128).to(device)
# map_location: the checkpoint may have been saved on a GPU runtime and be loaded
# on a CPU-only one. weights_only: a state dict is plain tensors, so there is no
# need to unpickle arbitrary objects.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.eval()  # disable dropout for inference

def predict(sentence):
    """Classify one sentence with the globally loaded ``model``.

    Args:
        sentence: raw text to classify.

    Returns:
        (predicted_label, prob_related): predicted_label is 1 when the softmax
        probability of class 1 exceeds 0.5, otherwise 0. prob_related is that
        class-1 probability, not the confidence in the predicted class.
    """
    ids, mask = make_feature_vector(sentence, tokenizer, max_seq_length=128)

    model.eval()
    with torch.no_grad():
        scores = model(ids.to(device), mask.to(device))
        prob_related = F.softmax(scores, dim=1)[0][1].item()

    predicted_label = int(prob_related > 0.5)
    return predicted_label, prob_related

# Tokenization helper used by predict(); it mirrors the settings of the training cell.
def make_feature_vector(text, tokenizer, max_seq_length):
    """Tokenize `text` into padded / truncated BERT inputs.

    Args:
        text: raw input string.
        tokenizer: object providing a Hugging Face style ``encode_plus``.
        max_seq_length: sequences are truncated / padded to this many tokens.

    Returns:
        (input_ids, attention_mask) exactly as returned by the tokenizer with
        return_tensors='pt'.
    """
    encoding = tokenizer.encode_plus(
        text,
        add_special_tokens=True,
        max_length=max_seq_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    input_ids = encoding['input_ids']
    attention_mask = encoding['attention_mask']
    return input_ids, attention_mask

# Example usage: classify one sentence and report the confidence in the predicted class.
new_sentence = " Trump is an orange buffoon "
prediction, confidence = predict(new_sentence)

# predict() returns the class-1 probability p. For a class-0 prediction, the
# confidence in that prediction is 1 - p, not p.
if prediction == 1:
    print("The sentence is predicted to be hateful with confidence:", confidence)
else:
    print("The sentence is predicted to not be hateful with confidence:", 1 - confidence)


The sentence is predicted to be hatefull with confidence: 0.5089988708496094


## Saving to Google Drive

In [14]:
# Mount Google Drive at /content/gdrive so the trained checkpoint can be copied
# into /content/gdrive/MyDrive/Saved_models in the next cell.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [15]:
# Save the .params checkpoint from the Colab runtime to Google Drive.
import shutil
from pathlib import Path

checkpoint = Path("/content/Bert-for-text-classification/models/20241105_204118_0.642_0.897_506.params")
drive_models_dir = Path("/content/gdrive/MyDrive/Saved_models")
# Create the folder first: `cp file .../Saved_models` would silently write a
# *file* named Saved_models if the folder did not exist yet.
drive_models_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(checkpoint, drive_models_dir)

### Loading from Google Drive

In [None]:
# Mount Google Drive again (e.g. in a fresh runtime) before copying a saved
# checkpoint back from /content/gdrive/MyDrive/Saved_models.
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
# Copy a .params checkpoint saved on Google Drive into this repository's models/ directory.
import shutil
from pathlib import Path

drive_checkpoint = Path("/content/gdrive/MyDrive/Saved_models/20241010_210511_0.688_0.953_989.params")
# The repository cloned at the top of this notebook is Bert-for-text-classification;
# the old /content/pytorch_active_learning/models/ target does not exist here.
# NOTE(review): confirm this checkpoint was produced by BERTTextClassifier,
# otherwise load_state_dict will fail on mismatched keys.
local_models_dir = Path("/content/Bert-for-text-classification/models")
local_models_dir.mkdir(parents=True, exist_ok=True)
shutil.copy(drive_checkpoint, local_models_dir)
