In [2]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import AutoTokenizer
from datasets import load_dataset
import evaluate

In [None]:
# Fixed seed so the 80/20 train/test split is identical on every Restart & Run All;
# without it, train_test_split shuffles differently on each run and metrics are not comparable.
SPLIT_SEED = 42

# The CSV is loaded as a single "train" split, then divided into train/test below.
plates_all = load_dataset("csv", data_files="john_final_classified_vanity_plates_with_meaning.csv", split="train")

# `dataset` holds the "train" and "test" splits that later cells index into.
dataset = plates_all.train_test_split(test_size=0.2, seed=SPLIT_SEED)

Generating train split: 0 examples [00:00, ? examples/s]

In [13]:
# Pull both splits out of the split result in one step.
train_dataset, test_dataset = dataset["train"], dataset["test"]

# Show the first training record as a sanity check of the available columns.
train_dataset[0]

{'index': 1134,
 'plate': 'PSYCHPS',
 'meaning': 'AM A PSYCHOLOGIST IN PALM SPRINGS',
 'classification': 4,
 'string_classification': 'Profanity, Drugs, Gang References'}

In [16]:
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased")


def tokenize(plate):
    """Tokenize the "plate" field of a batch of records.

    Called through ``Dataset.map(..., batched=True)``, so ``plate`` is a batch
    of rows and ``plate["plate"]`` holds the plate strings for that batch.
    Every sequence is padded or truncated to exactly 10 token positions.
    """
    plate_texts = plate["plate"]
    encoded = tokenizer(
        plate_texts,
        padding="max_length",
        max_length=10,
        truncation=True,
    )
    return encoded


# Apply the tokenizer to both splits (train first, then test).
train_dataset, test_dataset = (
    split.map(tokenize, batched=True) for split in (train_dataset, test_dataset)
)

Map:   0%|          | 0/1176 [00:00<?, ? examples/s]

Map:   0%|          | 0/294 [00:00<?, ? examples/s]

In [None]:
# Trainer (with the default remove_unused_columns=True) only forwards columns that
# match the model's forward() arguments, plus "label"/"label_ids". A target column
# named "classification" would be dropped, leaving the model with no labels and no
# loss to optimize. Rename it to "labels" so the targets actually reach the model.
# The membership checks keep this cell safe to re-run after the rename has happened.
if "classification" in train_dataset.column_names:
    train_dataset = train_dataset.rename_column("classification", "labels")
if "classification" in test_dataset.column_names:
    test_dataset = test_dataset.rename_column("classification", "labels")

# Return only the model inputs and targets, as torch tensors.
train_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
test_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

{'index': 1134,
 'plate': 'PSYCHPS',
 'meaning': 'AM A PSYCHOLOGIST IN PALM SPRINGS',
 'classification': 4,
 'string_classification': 'Profanity, Drugs, Gang References',
 'input_ids': [101, 8827, 17994, 4523, 102, 0, 0, 0, 0, 0],
 'attention_mask': [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]}

In [None]:

# Size of the classification head.
# NOTE(review): presumably the number of distinct values in the `classification` column — confirm against the CSV.
NUM_CLASSES = 9

# Pretrained DistilBERT encoder with a freshly initialized sequence-classification head.
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert/distilbert-base-uncased",
    num_labels=NUM_CLASSES,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [24]:
metric = evaluate.load("accuracy")


def compute_metrics(p):
    """Compute accuracy for one evaluation pass.

    ``p`` unpacks into ``(predictions, labels)``. The predicted class for each
    example is the argmax over the last axis of the predictions; that is
    compared against the labels with the loaded "accuracy" metric.
    """
    scores, references = p
    predicted_classes = scores.argmax(axis=-1)
    return metric.compute(predictions=predicted_classes, references=references)

In [34]:
! pip install -U transformers


Collecting transformers
  Downloading transformers-4.46.3-py3-none-any.whl.metadata (44 kB)
Downloading transformers-4.46.3-py3-none-any.whl (10.0 MB)
   ---------------------------------------- 0.0/10.0 MB ? eta -:--:--
   ---------------------------------------- 10.0/10.0 MB 89.7 MB/s eta 0:00:00
Installing collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.46.0
    Uninstalling transformers-4.46.0:
      Successfully uninstalled transformers-4.46.0
Successfully installed transformers-4.46.3


In [35]:
# Training configuration: evaluate and checkpoint once per epoch, then reload the
# checkpoint with the best evaluation accuracy when training finishes.
training_args = TrainingArguments(
    output_dir='./results',          # output directory where the model checkpoints and logs are saved
    eval_strategy="epoch",           # evaluate at the end of each epoch (`evaluation_strategy` is the deprecated name)
    save_strategy="epoch",           # must match the evaluation schedule for load_best_model_at_end
    learning_rate=2e-5,              # learning rate
    per_device_train_batch_size=16,  # batch size for training
    per_device_eval_batch_size=32,   # batch size for evaluation
    num_train_epochs=3,              # number of training epochs
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    load_best_model_at_end=True,     # load the best model when finished training
    metric_for_best_model='accuracy', # metric used to evaluate the best model
)

# Wire the model, data splits, and accuracy metric into the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics
)

# Fine-tune; requires `accelerate>=0.26.0` to be installed alongside transformers.
trainer.train()

ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.26.0`: Please run `pip install transformers[torch]` or `pip install 'accelerate>={ACCELERATE_MIN_VERSION}'`