In [1]:
import evaluate
import numpy as np
import torch
from datasets import load_from_disk
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
from transformers import DataCollatorWithPadding
from transformers import TrainingArguments, Trainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint identifier shared by the tokenizer here and the classification
# model loaded further down, so both come from the same pretrained weights.
model_id = "distilbert/distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [3]:
def preprocess_function(examples):
    """Tokenize the ``"text"`` column of a batch of examples.

    Args:
        examples: Mapping with a ``"text"`` entry holding the raw texts.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those texts,
        called with ``truncation=True`` so over-long sequences are cut.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True)
    return encoded

# Read back the labeled dataset written by the data preparation step.
dataset = load_from_disk(dataset_path="./data/mail_dataset_labeled")

# Tokenize the loaded dataset; batched=True hands preprocess_function
# a batch of rows per call rather than a single row.
tokenized_dataset = dataset.map(function=preprocess_function, batched=True)

Map: 100%|██████████| 43/43 [00:00<00:00, 8560.62 examples/s]


In [4]:
# Collator built around the tokenizer; handed to the Trainer as `data_collator`
# to pad tokenized examples when they are assembled into batches.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [5]:
# Both lookup tables are derived from one ordered list of class names, so the
# id -> name and name -> id directions cannot disagree with each other.
id2label = dict(enumerate(["IN_Bank", "IN_School", "US_Bank", "US_School"]))
label2id = {name: idx for idx, name in id2label.items()}

In [6]:
# Load the pretrained checkpoint with a sequence-classification head.
# The head size is taken from the label mapping rather than a literal, so adding
# or removing a class in `label2id` cannot leave `num_labels` out of sync.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Accuracy metric object; `compute_metrics` calls its `.compute(...)` during evaluation.
accuracy = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``accuracy`` metric.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` holds one
            row of per-class scores per example; ``labels`` holds the true ids.

    Returns:
        The result of ``accuracy.compute`` for the predicted ids vs. ``labels``.
    """
    scores, labels = eval_pred
    # The predicted class is the column with the highest score in each row.
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=labels)

In [62]:
training_args = TrainingArguments(
    # Where checkpoints are written
    output_dir="trainer_output/my_mail_classifier",
    # Optimisation settings
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Evaluate and checkpoint once per epoch, then restore the best checkpoint
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Keep the run local: no Hub upload, no logging integrations
    report_to="none",
    push_to_hub=False,
)

In [63]:
# Wire the model, training configuration, data splits, collator and metric
# function into a single Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    # `tokenizer=` is deprecated on Trainer (it triggers a FutureWarning);
    # `processing_class=` is its replacement and takes the same tokenizer object.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

  trainer = Trainer(


In [64]:
# Run fine-tuning; the returned TrainOutput is displayed as the cell result.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.231592,0.97619
2,No log,0.243119,0.97619
3,No log,0.246381,0.97619


TrainOutput(global_step=126, training_loss=0.0001627967825957707, metrics={'train_runtime': 19.3209, 'train_samples_per_second': 52.016, 'train_steps_per_second': 6.521, 'total_flos': 13832470040592.0, 'train_loss': 0.0001627967825957707, 'epoch': 3.0})

In [65]:
# Directory the fine-tuned model is saved to; the inference pipeline further
# down loads the model from this same path.
model_path = "models/my_mail_classifier"
trainer.save_model(model_path)

# Test the model

In [12]:
# Class names indexed by integer label id, in the same order as the dataset's
# ClassLabel feature.
# NOTE(review): the identifier is misspelled ("lable"); it is kept as-is because
# the evaluation cell looks names up through this exact variable.
lable_names = ["IN_Bank", "IN_School", "US_Bank", "US_School"]

In [13]:
# Inspect the test split's schema (column names and their feature types).
dataset["test"].features

{'text': Value(dtype='string', id=None),
 'label': ClassLabel(names=['IN_Bank', 'IN_School', 'US_Bank', 'US_School'], id=None)}

In [14]:
# Plain list of the raw message texts from the test split, in dataset order.
mail_summaries = list(map(lambda row: row["text"], dataset["test"]))

In [15]:
from transformers import pipeline

In [66]:
# Pick the best available accelerator instead of hard-coding Apple's "mps"
# backend, so this cell also runs on CUDA machines and on CPU-only machines.
# On Apple Silicon this still resolves to "mps", matching the previous behavior.
if torch.cuda.is_available():
    inference_device = "cuda"
elif torch.backends.mps.is_available():
    inference_device = "mps"
else:
    inference_device = "cpu"

# Load the fine-tuned model saved at `model_path` as a text-classification pipeline.
trained_model = pipeline("text-classification", model=model_path, device_map=inference_device)

Device set to use mps


In [None]:
# Score the saved classifier on the held-out test split, one message at a time,
# counting every message whose predicted label differs from its true label.
test_split = dataset["test"]
total_messages = len(test_split)

Error_counter = 0
for summary in test_split:
    # Truncate over-long inputs the same way training did (`preprocess_function`
    # tokenizes with truncation=True); without this, a message longer than the
    # model's maximum sequence length would fail at inference time.
    result = trained_model(summary["text"], truncation=True)
    # The first entry of the result carries the predicted label.
    largest_label = result[0]["label"]
    # The dataset stores labels as integer ids; map the id back to its name.
    actual_label = lable_names[summary["label"]]
    if largest_label != actual_label:
        Error_counter += 1

# Guard the ratio so an empty test split reports NaN instead of raising
# ZeroDivisionError.
accuracy_pct = (1 - Error_counter / total_messages) * 100 if total_messages else float("nan")

print(f"Total messages: {total_messages}")
print(f"Total Error: {Error_counter}")
print("-"*50)
print(f"Accuracy: {accuracy_pct}%")

Total messages: 43
Total Error: 0
--------------------------------------------------
Accuracy: 100.0%
