In [1]:
import torch
import transformers
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model, PeftModel

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    load_in_8bit=True,
    device_map="auto",
    torch_dtype=torch.float16
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
model = prepare_model_for_kbit_training(model)
tokenizer.pad_token = "!"

In [3]:
CUTOFF_LEN = 768
LORA_R = 8
LORA_ALPHA = 2 * LORA_R
LORA_DROPOUT = 0.1

config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"
                    , "down_proj", "lm_head"], #these are the  names for the layers
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, config)

In [5]:
dataset = load_dataset("lmsys/toxic-chat","toxicchat0124") #Interesting dataset with toxic chat.
print("dataset", dataset)
train_data = dataset["train"]

dataset DatasetDict({
    train: Dataset({
        features: ['conv_id', 'user_input', 'model_output', 'human_annotation', 'toxicity', 'jailbreaking', 'openai_moderation'],
        num_rows: 5082
    })
    test: Dataset({
        features: ['conv_id', 'user_input', 'model_output', 'human_annotation', 'toxicity', 'jailbreaking', 'openai_moderation'],
        num_rows: 5083
    })
})


In [6]:
def generate_prompt(user_query):  #The prompt format is taken from the official Mistral huggingface page
  if user_query["model_output"] is not None and user_query["user_input"] is not None:
      p =  "<s> [INST]" + user_query["user_input"] + "[/INST]" +  user_query["model_output"] + "</s>"
      return p
  else:
    p = "<s> [INST]" + "Hello" + "[/INST]" +  "Hello" + "</s>"
    return p

In [7]:
def tokenize(prompt):
    return tokenizer(
        prompt + tokenizer.eos_token,
        truncation=True,
        max_length=CUTOFF_LEN ,
        padding="max_length"
    )


In [8]:
train_data = train_data.shuffle().map(lambda x: tokenize(generate_prompt(x)), remove_columns=['conv_id', 'user_input', 'model_output', 'human_annotation', 'toxicity', 'jailbreaking', 'openai_moderation'])

Map:   0%|          | 0/5082 [00:00<?, ? examples/s]

In [14]:
trainer = Trainer(
    model=model,
    train_dataset=train_data,
    args=TrainingArguments(
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_train_epochs=1,    # 3 or 6 is good
        learning_rate=1e-4,
        logging_steps=50,
        optim="adamw_torch",
        save_strategy="epoch",
        output_dir="mistral-lora-instruct-toxic"
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
model.config.use_cache = False

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [15]:
trainer.train()



Step,Training Loss
50,0.7707


KeyboardInterrupt: 

In [None]:
save_location  = "/home/thsch026/masterarbeit/models/generated/lora/unspec/Mistral_instruct_toxic_chat"


model.save_pretrained(save_location, safetensors=True)
tokenizer.save(save_location)