In [4]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from datasets import DatasetDict, Dataset, load_from_disk
from sklearn.model_selection import train_test_split
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM, GenerationConfig

from sklearn.metrics import accuracy_score

from sklearn.preprocessing import LabelEncoder
from transformers import BitsAndBytesConfig
from accelerate import Accelerator
from peft import prepare_model_for_kbit_training, LoraConfig, TaskType, get_peft_model
from transformers import TrainingArguments, AutoConfig, \
    AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig, DataCollatorWithPadding
from peft import (
    PeftConfig,
    PeftModel,
)

In [5]:
path_to_retrieve = "../tokenized_dataset"


In [6]:
dataset_dict = load_from_disk(path_to_retrieve)

In [7]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}"
    )

In [8]:
model_id = "bert-large-uncased"
num_labels=5

In [16]:
config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_CLS
)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [17]:

model = AutoModelForSequenceClassification.from_pretrained(model_id, quantization_config=bnb_config,
                                                           config=AutoConfig.from_pretrained(model_id,
                                                                                             trust_remote_code=True,
                                                                                             num_labels=num_labels),
                                                           trust_remote_code=True)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-large-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

print_trainable_parameters(model)

trainable params: 0 || all params: 183627781 || trainable%: 0.00


In [11]:
def compute_metrics(p):
    logits, labels = p.predictions, p.label_ids
    preds = logits.argmax(axis=-1)
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc}

In [12]:
training_args = TrainingArguments(
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    output_dir='/results',
    num_train_epochs=1,
    evaluation_strategy="steps",
    save_steps=10,
    save_total_limit=2,
    remove_unused_columns=False,
    run_name='run_name',
    logging_dir='/logs',
    logging_steps=10,
    load_best_model_at_end=True,
)

In [13]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["test"],
    compute_metrics=compute_metrics, 
)

In [14]:
trainer.train()


[34m[1mwandb[0m: Currently logged in as: [33mlukemonington3[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112344188909952, max=1.0…

Step,Training Loss,Validation Loss,Accuracy
10,0.868,,0.2


NotImplementedError: You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported

In [16]:
pip freeze

absl-py==2.0.0
accelerate==0.23.0
aiohttp==3.8.5
aiosignal==1.3.1
anyio==4.0.0
appdirs==1.4.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.4.0
async-lru==2.0.4
async-timeout==4.0.3
attrs==23.1.0
Babel==2.12.1
backcall==0.2.0
beautifulsoup4==4.12.2
bitsandbytes==0.41.1
bleach==6.0.0
blinker==1.4
cachetools==5.3.1
certifi==2023.7.22
cffi==1.15.1
charset-normalizer==3.2.0
click==8.1.7
cmake==3.27.5
comm==0.1.4
contourpy==1.1.1
cryptography==3.4.8
cycler==0.11.0
datasets==2.14.5
dbus-python==1.2.18
debugpy==1.8.0
decorator==5.1.1
deepspeed==0.10.3
defusedxml==0.7.1
dill==0.3.7
distro==1.7.0
distro-info==1.1+ubuntu0.1
docker-pycreds==0.4.0
evaluate==0.4.0
exceptiongroup==1.1.3
executing==1.2.0
fastjsonschema==2.18.0
filelock==3.12.4
fonttools==4.42.1
fqdn==1.5.1
frozenlist==1.4.0
fsspec==2023.6.0
gitdb==4.0.10
GitPython==3.1.37
glob2==0.7
google-auth==2.23.1
google-auth-oauthlib==1.0.0
grpcio==1.58.0
hjson==3.1.0
httplib2==0.20.2
huggingface-hub==0.17.3
idna==3