In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import IA3Config, get_peft_model, TaskType
from tqdm import tqdm
from matplotlib import pyplot as plt

In [2]:
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTTrainer

In [3]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:

quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", truncation=True, padding=True, padding_side="right")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", quantization_config=quantization_config)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

model = prepare_model_for_kbit_training(model)

config = LoraConfig(r = 4,
                    lora_alpha=4,
                    target_modules = ["gate", "gate_proj", "up_proj", "down_proj"],
                    lora_dropout=0.1
                    )



lora_model = get_peft_model(model, config)

lora_model.print_trainable_parameters()


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

trainable params: 525,312 || all params: 46,703,318,016 || trainable%: 0.0011


In [5]:
dataset = load_dataset("Na0s/sft-ready-Text-Generation-Augmented-Data", split="train")
print(tokenizer.chat_template)
print(dataset)

README.md:   0%|          | 0.00/344 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/22 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/22 [00:00<?, ?files/s]

data/train-00000-of-00022.parquet:   0%|          | 0.00/332M [00:00<?, ?B/s]

data/train-00001-of-00022.parquet:   0%|          | 0.00/323M [00:00<?, ?B/s]

data/train-00002-of-00022.parquet:   0%|          | 0.00/183M [00:00<?, ?B/s]

data/train-00003-of-00022.parquet:   0%|          | 0.00/137M [00:00<?, ?B/s]

data/train-00004-of-00022.parquet:   0%|          | 0.00/329M [00:00<?, ?B/s]

data/train-00005-of-00022.parquet:   0%|          | 0.00/331M [00:00<?, ?B/s]

data/train-00006-of-00022.parquet:   0%|          | 0.00/255M [00:00<?, ?B/s]

data/train-00007-of-00022.parquet:   0%|          | 0.00/256M [00:00<?, ?B/s]

data/train-00008-of-00022.parquet:   0%|          | 0.00/251M [00:00<?, ?B/s]

data/train-00009-of-00022.parquet:   0%|          | 0.00/316M [00:00<?, ?B/s]

data/train-00010-of-00022.parquet:   0%|          | 0.00/347M [00:00<?, ?B/s]

data/train-00011-of-00022.parquet:   0%|          | 0.00/383M [00:00<?, ?B/s]

data/train-00012-of-00022.parquet:   0%|          | 0.00/476M [00:00<?, ?B/s]

data/train-00013-of-00022.parquet:   0%|          | 0.00/594M [00:00<?, ?B/s]

data/train-00014-of-00022.parquet:   0%|          | 0.00/252M [00:00<?, ?B/s]

data/train-00015-of-00022.parquet:   0%|          | 0.00/77.0M [00:00<?, ?B/s]

data/train-00016-of-00022.parquet:   0%|          | 0.00/92.4M [00:00<?, ?B/s]

data/train-00017-of-00022.parquet:   0%|          | 0.00/95.4M [00:00<?, ?B/s]

data/train-00018-of-00022.parquet:   0%|          | 0.00/99.7M [00:00<?, ?B/s]

data/train-00019-of-00022.parquet:   0%|          | 0.00/119M [00:00<?, ?B/s]

data/train-00020-of-00022.parquet:   0%|          | 0.00/98.5M [00:00<?, ?B/s]

data/train-00021-of-00022.parquet:   0%|          | 0.00/109M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7667416 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/22 [00:00<?, ?it/s]

{%- if messages[0]['role'] == 'system' %}
    {%- set system_message = messages[0]['content'] %}
    {%- set loop_messages = messages[1:] %}
{%- else %}
    {%- set loop_messages = messages %}
{%- endif %}

{{- bos_token }}
{%- for message in loop_messages %}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
        {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
    {%- endif %}
    {%- if message['role'] == 'user' %}
        {%- if loop.first and system_message is defined %}
            {{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST]' }}
        {%- else %}
            {{- ' [INST] ' + message['content'] + ' [/INST]' }}
        {%- endif %}
    {%- elif message['role'] == 'assistant' %}
        {{- ' ' + message['content'] + eos_token}}
    {%- else %}
        {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial opt

In [6]:

# def format_chat(example):
#     messages = [
#         {"role": "user", "content": example["prompt"]},
#         {"role": "assistant", "content": example["completion"]},
#     ]
#     example["text"] = tokenizer.apply_chat_template(
#         messages,
#         tokenize=False,
#         add_generation_prompt=False
#     )
#     return example

# tokenized_dataset = dataset.map(format_chat, num_proc=40, remove_columns=["prompt", "completion"])

In [7]:
# print(tokenized_dataset['text'][0])

In [8]:
# dataset.push_to_hub("kaaiiii/Mixtral_tokenized_data")

In [9]:
import os
os.environ["WANDB_PROJECT"] = "Mixtral_fine_tune"
os.environ["WANDB_RUN_NAME"] = "experiment-1"

In [None]:
from trl import SFTConfig
import wandb
import logging
logging.getLogger("trl.trainer.sft_trainer").setLevel(logging.ERROR)

trainer = SFTTrainer(
    model = lora_model,
    train_dataset = dataset,
    processing_class = tokenizer,
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 16,
        packing = True,
        group_by_length = True,
        warmup_ratio = 0.05,
        bf16 = True,
        max_steps=500,
        learning_rate = 1e-4,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 42,
        eval_strategy="no",
        do_eval=False,
        output_dir = "./outputs",
        remove_unused_columns=False,

        save_strategy="steps",
        save_steps=100,
        save_total_limit=10,
        save_safetensors=True,
        push_to_hub=True,
        hub_model_id="kaaiiii/Mixtral_LoRA_v1",

        report_to = "wandb",              
    )
)

torch.cuda.empty_cache()

trainer.train()

Adding EOS to train dataset:   0%|          | 0/7667416 [00:00<?, ? examples/s]