In [2]:
!pip install transformers
!pip install datasets
!pip install accelerate
!pip install peft
!pip install numpy
!pip install transformers[torch]

Collecting datasets
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-16.1.0-cp310-cp310-manylinux_2_28_x86_64.whl (40.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.8/40.8 MB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m
Collecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.3-py3-none-any.whl (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.9/64.9 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1

In [19]:
!pip show accelerate

Name: accelerate
Version: 0.31.0
Summary: Accelerate
Home-page: https://github.com/huggingface/accelerate
Author: The HuggingFace team
Author-email: zach.mueller@huggingface.co
License: Apache
Location: /usr/local/lib/python3.10/dist-packages
Requires: huggingface-hub, numpy, packaging, psutil, pyyaml, safetensors, torch
Required-by: peft


In [3]:
!pip install accelerate -U



In [15]:
from transformers import DistilBertTokenizerFast, TrainingArguments, Trainer, DistilBertForSequenceClassification
from datasets import load_dataset, DatasetDict
import numpy as np
from peft import IA3Config, get_peft_model, TaskType

def truncate(example):
    return {
        'text': " ".join(example['text'].split()[:50]),
        'label': example['label']
    }

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": np.mean(predictions == labels)}

imdb_dataset = load_dataset("stanfordnlp/imdb")
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-cased', num_labels=2)

ia3_config = IA3Config(
    task_type=TaskType.SEQ_CLS,
    target_modules=["q_lin", "k_lin", "v_lin", "out_lin", "lin1", "lin2"],
    feedforward_modules=["lin1", "lin2"]
)

model = get_peft_model(model, ia3_config)

small_imdb_dataset = DatasetDict(
    train=imdb_dataset['train'].shuffle(seed=1111).select(range(128)).map(truncate),
    val=imdb_dataset['train'].shuffle(seed=1111).select(range(128, 160)).map(truncate),
)

small_tokenized_dataset = small_imdb_dataset.map(
    lambda example: tokenizer(example['text'], truncation=True),
    batched=True,
    batch_size=16
)

arguments = TrainingArguments(
    output_dir="sample_hf_trainer",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    load_best_model_at_end=True,
    seed=224,
    optim="adamw_hf"
)

trainer = Trainer(
    model=model,
    args=arguments,
    train_dataset=small_tokenized_dataset['train'],
    eval_dataset=small_tokenized_dataset['val'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()
results = trainer.predict(small_tokenized_dataset['val'])
results

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.697723,0.4375
2,No log,0.696609,0.4375




PredictionOutput(predictions=array([[-0.09719563, -0.15349784],
       [-0.14710617, -0.18148091],
       [-0.11700655, -0.15873723],
       [-0.13189137, -0.19151086],
       [-0.15204352, -0.18720746],
       [-0.13153008, -0.17497538],
       [-0.07965463, -0.15512015],
       [-0.12697507, -0.19158605],
       [-0.12144058, -0.15663816],
       [-0.13493493, -0.18435718],
       [-0.14328067, -0.19742095],
       [-0.08444747, -0.14368232],
       [-0.10571112, -0.15034856],
       [-0.11617639, -0.16605538],
       [-0.09816663, -0.17631444],
       [-0.09142522, -0.18747406],
       [-0.08075761, -0.19491714],
       [-0.12301181, -0.20776519],
       [-0.12184913, -0.2025443 ],
       [-0.12927362, -0.20524333],
       [-0.10466611, -0.1838705 ],
       [-0.10823513, -0.17322236],
       [-0.0987185 , -0.15366799],
       [-0.10128648, -0.1669116 ],
       [-0.11656439, -0.16470113],
       [-0.13633755, -0.17981276],
       [-0.1075277 , -0.1975142 ],
       [-0.11031017, -0.16