In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import torch.nn as nn
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import IA3Config, get_peft_model, TaskType
from tqdm import tqdm
from matplotlib import pyplot as plt

In [2]:
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, prepare_model_for_kbit_training
from trl import SFTTrainer

In [3]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:

quantization_config = transformers.BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", truncation=True, padding=True, padding_side="right")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1", quantization_config=quantization_config)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

model = prepare_model_for_kbit_training(model)

config = LoraConfig(r = 4,
                    lora_alpha=4,
                    target_modules = ["gate", "w1", "w2", "w3"],
                    lora_dropout=0.1
                    )



lora_model = get_peft_model(model, config)

lora_model.print_trainable_parameters()


Loading checkpoint shards:   0%|          | 0/19 [00:00<?, ?it/s]

trainable params: 57,148,416 || all params: 46,759,941,120 || trainable%: 0.1222


In [5]:
print(lora_model)

PeftModel(
  (base_model): LoraModel(
    (model): MixtralForCausalLM(
      (model): MixtralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MixtralDecoderLayer(
            (self_attn): MixtralAttention(
              (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
              (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
              (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
            )
            (block_sparse_moe): MixtralSparseMoeBlock(
              (gate): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=8, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_feature

In [6]:
dataset = load_dataset("kaaiiii/Mixtral_tokenized_data", split="train")
print(dataset)

README.md:   0%|          | 0.00/331 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/30 [00:00<?, ?files/s]

data/train-00000-of-00030.parquet:   0%|          | 0.00/201M [00:00<?, ?B/s]

data/train-00001-of-00030.parquet:   0%|          | 0.00/208M [00:00<?, ?B/s]

data/train-00002-of-00030.parquet:   0%|          | 0.00/184M [00:00<?, ?B/s]

data/train-00003-of-00030.parquet:   0%|          | 0.00/119M [00:00<?, ?B/s]

data/train-00004-of-00030.parquet:   0%|          | 0.00/89.4M [00:00<?, ?B/s]

data/train-00005-of-00030.parquet:   0%|          | 0.00/151M [00:00<?, ?B/s]

data/train-00006-of-00030.parquet:   0%|          | 0.00/201M [00:00<?, ?B/s]

data/train-00007-of-00030.parquet:   0%|          | 0.00/202M [00:00<?, ?B/s]

data/train-00008-of-00030.parquet:   0%|          | 0.00/220M [00:00<?, ?B/s]

data/train-00009-of-00030.parquet:   0%|          | 0.00/218M [00:00<?, ?B/s]

data/train-00010-of-00030.parquet:   0%|          | 0.00/218M [00:00<?, ?B/s]

data/train-00011-of-00030.parquet:   0%|          | 0.00/180M [00:00<?, ?B/s]

data/train-00012-of-00030.parquet:   0%|          | 0.00/237M [00:00<?, ?B/s]

data/train-00013-of-00030.parquet:   0%|          | 0.00/266M [00:00<?, ?B/s]

data/train-00014-of-00030.parquet:   0%|          | 0.00/261M [00:00<?, ?B/s]

data/train-00015-of-00030.parquet:   0%|          | 0.00/307M [00:00<?, ?B/s]

data/train-00016-of-00030.parquet:   0%|          | 0.00/295M [00:00<?, ?B/s]

data/train-00017-of-00030.parquet:   0%|          | 0.00/391M [00:00<?, ?B/s]

data/train-00018-of-00030.parquet:   0%|          | 0.00/378M [00:00<?, ?B/s]

data/train-00019-of-00030.parquet:   0%|          | 0.00/247M [00:00<?, ?B/s]

data/train-00020-of-00030.parquet:   0%|          | 0.00/38.7M [00:00<?, ?B/s]

data/train-00021-of-00030.parquet:   0%|          | 0.00/41.2M [00:00<?, ?B/s]

data/train-00022-of-00030.parquet:   0%|          | 0.00/53.2M [00:00<?, ?B/s]

data/train-00023-of-00030.parquet:   0%|          | 0.00/53.9M [00:00<?, ?B/s]

data/train-00024-of-00030.parquet:   0%|          | 0.00/54.2M [00:00<?, ?B/s]

data/train-00025-of-00030.parquet:   0%|          | 0.00/58.1M [00:00<?, ?B/s]

data/train-00026-of-00030.parquet:   0%|          | 0.00/66.0M [00:00<?, ?B/s]

data/train-00027-of-00030.parquet:   0%|          | 0.00/56.6M [00:00<?, ?B/s]

data/train-00028-of-00030.parquet:   0%|          | 0.00/48.0M [00:00<?, ?B/s]

data/train-00029-of-00030.parquet:   0%|          | 0.00/69.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7667416 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/27 [00:00<?, ?it/s]

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 7667416
})


In [7]:
import os
os.environ["WANDB_PROJECT"] = "Mixtral_fine_tune"
os.environ["WANDB_RUN_NAME"] = "experiment-2"

In [None]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,  # causal LM
)

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

lora_model.enable_input_require_grads()
lora_model.gradient_checkpointing_enable()

training_args = TrainingArguments(
    output_dir="./outputs",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16, 
    max_steps=500,               # try a small number first

    dataloader_num_workers = 4,
    dataloader_pin_memory=True,
    dataloader_prefetch_factor=4,
    dataloader_persistent_workers=True,
    
    learning_rate=1e-5,
    weight_decay=0.03,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    bf16=True,

    logging_steps=5,
    logging_strategy="steps",
    logging_first_step=True,
    disable_tqdm=False,    # ensure progress bar is visible

    save_strategy="steps",
    save_steps=100,
    save_total_limit=10,
    save_safetensors=True,
    push_to_hub=True,
    hub_model_id="kaaiiii/Mixtral_LoRA_v2",

    report_to = "wandb",    
)

trainer = Trainer(
    model=lora_model,          # your LoRA-wrapped model
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

trainer.train()


In [8]:
from trl import SFTConfig
import wandb
import logging
logging.getLogger("trl.trainer.sft_trainer").setLevel(logging.ERROR)

trainer = SFTTrainer(
    model = lora_model,
    train_dataset = dataset,
    processing_class = tokenizer,
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 16,
        packing = True,
        group_by_length = True,
        warmup_ratio = 0.05,
        bf16 = True,
        max_steps=500,
        learning_rate = 5e-5,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 42,
        eval_strategy="no",
        do_eval=False,
        output_dir = "./outputs",
        remove_unused_columns=False,

        save_strategy="steps",
        save_steps=100,
        save_total_limit=10,
        save_safetensors=True,
        push_to_hub=True,
        hub_model_id="kaaiiii/Mixtral_LoRA_v2",

        report_to = "wandb",              
    )
)

torch.cuda.empty_cache()

trainer.train()

Packing train dataset:   0%|          | 0/7667416 [00:00<?, ? examples/s]

KeyboardInterrupt: 