In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased", clean_up_tokenization_spaces=True)
model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=16,
    bias="none",
    task_type=TaskType.SEQ_CLS,
)

model.add_adapter(lora_config)
#model.enable_adapters()

In [3]:
from datasets import load_dataset

dataset = load_dataset("yelp_review_full")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

In [4]:
from transformers import TrainingArguments, Trainer
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    print("logits", logits[0:10])
    predictions = np.argmax(logits, axis=-1)
    print("predictions:", predictions[0:10])
    print("references:", labels[0:10])
    return metric.compute(predictions=predictions, references=labels)

training_args = TrainingArguments(output_dir="test_trainer/lora", eval_strategy="epoch")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)




In [5]:
%%time
trainer.train()

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,1.618,0.223
2,No log,1.609405,0.233
3,No log,1.607791,0.239


logits [[ 0.39363524  0.29100445 -0.04504333  0.43847725  0.01565461]
 [ 0.42400023  0.21148613 -0.0370684   0.50766057 -0.03084594]
 [ 0.38680276  0.46344084 -0.13962984  0.39054924 -0.13189423]
 [ 0.45406678  0.33147046 -0.12487298  0.44643724 -0.03454848]
 [ 0.2730037   0.531505   -0.2034405   0.20129198 -0.13594182]
 [ 0.44231802  0.2841633  -0.05341129  0.4303223   0.02512021]
 [ 0.38275272  0.2538352  -0.06050872  0.5045726  -0.00234294]
 [ 0.3631025   0.5363487  -0.17832117  0.29067066 -0.11127138]
 [ 0.41329414  0.39264446 -0.08661754  0.3705333  -0.10966914]
 [ 0.40771407  0.28945887 -0.08046947  0.42775768 -0.02572537]]
predictions: [3 3 1 0 1 0 3 1 0 3]
references: [2 4 1 4 3 4 2 3 2 3]
logits [[ 0.34906644  0.20688966 -0.01459301  0.47373378  0.02061399]
 [ 0.32351455  0.14338163 -0.01218935  0.5487954  -0.03273686]
 [ 0.38811567  0.4080621  -0.10435721  0.4109976  -0.10348926]
 [ 0.41623548  0.18446991 -0.05991444  0.53515345 -0.00660224]
 [ 0.2795762   0.4445521  -0.14152



logits [[ 3.31981033e-01  1.94623396e-01 -9.03977454e-03  4.74694252e-01
   2.50058025e-02]
 [ 3.05108160e-01  1.33043319e-01 -6.16460666e-03  5.51549494e-01
  -2.83561386e-02]
 [ 3.87138367e-01  3.94619644e-01 -9.75806713e-02  4.14026886e-01
  -9.34744179e-02]
 [ 3.98077846e-01  1.73085019e-01 -5.51330447e-02  5.38627028e-01
  -3.21989134e-03]
 [ 2.78395981e-01  4.23725784e-01 -1.30469427e-01  2.73007244e-01
  -7.26308450e-02]
 [ 3.53058130e-01  1.54832795e-01 -1.26879774e-02  4.89814460e-01
   2.09762380e-02]
 [ 3.24496835e-01  2.10919946e-01 -3.90627533e-02  5.09377539e-01
   9.03703645e-03]
 [ 3.61483723e-01  4.00231421e-01 -9.24250782e-02  3.35646391e-01
  -2.46070325e-02]
 [ 4.51947093e-01  2.74258733e-01 -1.57921948e-02  4.12747145e-01
  -1.28965825e-02]
 [ 3.64738941e-01  1.78727508e-01 -3.66871990e-02  4.75953192e-01
   2.20775604e-04]]
predictions: [3 3 3 3 1 3 3 1 0 3]
references: [2 4 1 4 3 4 2 3 2 3]
CPU times: total: 1min 9s
Wall time: 1min 32s


TrainOutput(global_step=375, training_loss=1.6336121419270833, metrics={'train_runtime': 91.7912, 'train_samples_per_second': 32.683, 'train_steps_per_second': 4.085, 'total_flos': 811097699328000.0, 'train_loss': 1.6336121419270833, 'epoch': 3.0})