In [1]:
from transformers import BertForSequenceClassification, BertTokenizer

model = BertForSequenceClassification.from_pretrained(
    'SZTAKI-HLT/hubert-base-cc',
    num_labels=2
)

tokenizer = BertTokenizer.from_pretrained('SZTAKI-HLT/hubert-base-cc')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at SZTAKI-HLT/hubert-base-cc and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
from peft import LoraConfig, TaskType

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
)


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home/boa/.conda/envs/ai/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so
CUDA SETUP: CUDA runtime path found: /home/boa/.conda/envs/ai/lib/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 117
CUDA SETUP: Loading binary /home/boa/.conda/envs/ai/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda117.so...


Either way, this might cause trouble in the future:
If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.
  warn(msg)


In [3]:
from peft import get_peft_model

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 592,900 || all params: 111,211,012 || trainable%: 0.5331306579603825


In [4]:
from datasets import load_dataset

dataset = load_dataset("boapps/kmdb_classification")

Found cached dataset parquet (/home/boa/.cache/huggingface/datasets/boapps___parquet/boapps--kmdb_classification-4003d65da9c3e34a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
dataset = dataset.map(lambda row: {'td': row['title']+'\n'+row['description']})

Loading cached processed dataset at /home/boa/.cache/huggingface/datasets/boapps___parquet/boapps--kmdb_classification-4003d65da9c3e34a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-c6514a33fc6a449f.arrow
Loading cached processed dataset at /home/boa/.cache/huggingface/datasets/boapps___parquet/boapps--kmdb_classification-4003d65da9c3e34a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-5cb1afa7c832e235.arrow
Loading cached processed dataset at /home/boa/.cache/huggingface/datasets/boapps___parquet/boapps--kmdb_classification-4003d65da9c3e34a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-f803453116857360.arrow


In [7]:
def tokenize_function(examples):
    return tokenizer(examples["td"], padding="max_length", truncation=True, max_length=512)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

Loading cached processed dataset at /home/boa/.cache/huggingface/datasets/boapps___parquet/boapps--kmdb_classification-4003d65da9c3e34a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-d941a3221c8f97e0.arrow
Loading cached processed dataset at /home/boa/.cache/huggingface/datasets/boapps___parquet/boapps--kmdb_classification-4003d65da9c3e34a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-944862f9e6c970c0.arrow
Loading cached processed dataset at /home/boa/.cache/huggingface/datasets/boapps___parquet/boapps--kmdb_classification-4003d65da9c3e34a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-c161745cc338a64d.arrow


In [8]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, _, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall
    }

In [9]:
from transformers import Trainer, TrainingArguments

batch_size=16

training_args = TrainingArguments(
    output_dir="hubert-classification-v11",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    weight_decay=0.01,
    load_best_model_at_end=True,
    logging_steps=10,
    eval_steps=100,    
    save_steps=100,
    save_total_limit=40,
    save_strategy='steps',
    evaluation_strategy='steps',
    learning_rate=0.0005,
    warmup_steps=400,
    num_train_epochs=2,
)

trainer = Trainer(
    model=model,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    args=training_args,
    compute_metrics=compute_metrics,
)

trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mboapps[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Accuracy,Precision,Recall
100,0.3916,0.319337,0.882505,0.871882,0.905505
200,0.218,0.186451,0.927345,0.900412,0.966147
300,0.1846,0.20357,0.923241,0.883962,0.979982
400,0.1945,0.178305,0.932361,0.897844,0.980571
500,0.1343,0.155824,0.943,0.931961,0.95967
600,0.1595,0.157132,0.944368,0.931398,0.963203
700,0.1685,0.152402,0.942392,0.926753,0.964675
800,0.16,0.141432,0.947408,0.944623,0.954077
900,0.1313,0.164281,0.941784,0.912425,0.981454
1000,0.1605,0.139825,0.94756,0.946199,0.952605


TrainOutput(global_step=1428, training_loss=0.19073975966925047, metrics={'train_runtime': 5914.8267, 'train_samples_per_second': 15.447, 'train_steps_per_second': 0.241, 'total_flos': 2.4205386012893184e+16, 'train_loss': 0.19073975966925047, 'epoch': 2.0})

In [11]:
trainer.evaluate(eval_dataset=tokenized_datasets['test'])

{'eval_loss': 0.12962540984153748,
 'eval_accuracy': 0.9522884882108184,
 'eval_precision': 0.9497076023391813,
 'eval_recall': 0.9497076023391813,
 'eval_runtime': 70.5257,
 'eval_samples_per_second': 51.116,
 'eval_steps_per_second': 3.205,
 'epoch': 2.0}

In [12]:
test_pos = tokenized_datasets['test'].filter(lambda row: row['label'] == 1)
test_neg = tokenized_datasets['test'].filter(lambda row: row['label'] == 0)

Loading cached processed dataset at /home/boa/.cache/huggingface/datasets/boapps___parquet/boapps--kmdb_classification-4003d65da9c3e34a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-ae0fb8807f5534a2.arrow
Loading cached processed dataset at /home/boa/.cache/huggingface/datasets/boapps___parquet/boapps--kmdb_classification-4003d65da9c3e34a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-3404e518452c8c4a.arrow


In [13]:
len(test_pos), len(test_neg)

(1710, 1895)

In [14]:
test_pos = test_pos.select(range(len(test_neg)/10))
test_neg = test_neg.select(range(len(test_neg)))

In [15]:
from datasets import concatenate_datasets

test_set = concatenate_datasets([test_pos, test_neg]).shuffle(seed=42)

Loading cached shuffled indices for dataset at /home/boa/.cache/huggingface/datasets/boapps___parquet/boapps--kmdb_classification-4003d65da9c3e34a/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-a072ea5a47f9fe89.arrow


In [16]:
trainer.evaluate(eval_dataset=test_set)

{'eval_loss': 0.12052392959594727,
 'eval_accuracy': 0.9551312649164678,
 'eval_precision': 0.6906474820143885,
 'eval_recall': 0.96,
 'eval_runtime': 41.0559,
 'eval_samples_per_second': 51.028,
 'eval_steps_per_second': 3.191,
 'epoch': 2.0}

In [25]:
merged_model = model.merge_and_unload()

In [47]:
merged_model.push_to_hub("boapps/kmdb_classification_model")

pytorch_model.bin:   0%|          | 0.00/443M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/boapps/kmdb_classification_model/commit/fecad43f99be1f8baac7177741f791d0eccfc757', commit_message='Upload BertForSequenceClassification', commit_description='', oid='fecad43f99be1f8baac7177741f791d0eccfc757', pr_url=None, pr_revision=None, pr_num=None)