# Fine-Tuning Mistral

In [1]:
!pip install -q -U transformers bitsandbytes peft datasets accelerate trl

# Loading Tokenizer

In [2]:
from transformers import AutoTokenizer

base_model = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(
    base_model, 
    padding_side = "right",
    add_eos_token = True,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_bos_token, tokenizer.add_eos_token

tokenizer_config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

(True, True)

In [3]:
tokenizer

LlamaTokenizerFast(name_or_path='mistralai/Mistral-7B-v0.1', vocab_size=32000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '</s>'}, clean_up_tokenization_spaces=False),  added_tokens_decoder={
	0: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

# Loading the Model

In [4]:
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16
)

In [5]:
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    load_in_4bit=True,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [6]:
import bitsandbytes 

def find_all_linear_names(model):
    cls = bitsandbytes.nn.Linear4bit #if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])


    # lm_head is often excluded.
    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)


modules = find_all_linear_names(model)

In [7]:
modules

['v_proj', 'o_proj', 'up_proj', 'gate_proj', 'down_proj', 'q_proj', 'k_proj']

# Loading the Dataset

In [8]:
from datasets import load_dataset

dataset_name = "databricks/databricks-dolly-15k"

train_dataset = load_dataset(dataset_name, split="train[0:800]")
eval_dataset = load_dataset(dataset_name, split="train[800:1000]")

Downloading readme:   0%|          | 0.00/8.20k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/13.1M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

# Understanding the Model

In [9]:
train_dataset

Dataset({
    features: ['instruction', 'context', 'response', 'category'],
    num_rows: 800
})

In [10]:
train_dataset.to_pandas()

Unnamed: 0,instruction,context,response,category
0,When did Virgin Australia start operating?,"Virgin Australia, the trading name of Virgin A...",Virgin Australia commenced services on 31 Augu...,closed_qa
1,Which is a species of fish? Tope or Rope,,Tope,classification
2,Why can camels survive for long without water?,,Camels use the fat in their humps to keep them...,open_qa
3,"Alice's parents have three daughters: Amy, Jes...",,The name of the third daughter is Alice,open_qa
4,When was Tomoaki Komorida born?,Komorida was born in Kumamoto Prefecture on Ju...,"Tomoaki Komorida was born on July 10,1981.",closed_qa
...,...,...,...,...
795,Who is the founder of the Communist Party?,,Lenin,open_qa
796,What is gardening?,Gardening is the practice of growing and culti...,Gardening is laying out and caring for a plot ...,information_extraction
797,What are your thoughts of Michael Jackson as a...,,Michael Jackson is acclaimed as the greatest p...,creative_writing
798,What is the largest pollutant?,,Carbon dioxide (CO2) - a greenhouse gas emitte...,general_qa


In [11]:
train_dataset.to_pandas().dtypes

instruction    object
context        object
response       object
category       object
dtype: object

In [12]:
train_dataset.to_pandas().value_counts("category")

category
open_qa                   202
general_qa                132
classification            111
brainstorming              95
closed_qa                  90
information_extraction     68
summarization              63
creative_writing           39
Name: count, dtype: int64

# Generating the Prompt Format

In [13]:
def generate_prompt(sample):
    full_prompt =f"""<s>[INST]{sample['instruction']}
    {f"Here is some context: {sample['context']}" if len(sample["context"]) > 0 else None}
    [/INST] {sample['response']}</s>"""
    return {"text": full_prompt}

In [14]:
train_dataset[0]

{'instruction': 'When did Virgin Australia start operating?',
 'context': "Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.",
 'response': 'Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.',
 'category': 'closed_qa'}

In [15]:
print(generate_prompt(train_dataset[0]))

{'text': "<s>[INST]When did Virgin Australia start operating?\n    Here is some context: Virgin Australia, the trading name of Virgin Australia Airlines Pty Ltd, is an Australian-based airline. It is the largest airline by fleet size to use the Virgin brand. It commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route. It suddenly found itself as a major airline in Australia's domestic market after the collapse of Ansett Australia in September 2001. The airline has since grown to directly serve 32 cities in Australia, from hubs in Brisbane, Melbourne and Sydney.\n    [/INST] Virgin Australia commenced services on 31 August 2000 as Virgin Blue, with two aircraft on a single route.</s>"}


In [16]:
generated_train_dataset = train_dataset.map(
    generate_prompt, remove_columns=list(train_dataset.features))


generated_val_dataset = eval_dataset.map(
    generate_prompt, remove_columns=list(train_dataset.features))

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [17]:
generated_train_dataset

Dataset({
    features: ['text'],
    num_rows: 800
})

In [18]:
generated_train_dataset[5]["text"]

"<s>[INST]If I have more pieces at the time of stalemate, have I won?\n    Here is some context: Stalemate is a situation in chess where the player whose turn it is to move is not in check and has no legal move. Stalemate results in a draw. During the endgame, stalemate is a resource that can enable the player with the inferior position to draw the game rather than lose. In more complex positions, stalemate is much rarer, usually taking the form of a swindle that succeeds only if the superior side is inattentive.[citation needed] Stalemate is also a common theme in endgame studies and other chess problems.\n\nThe outcome of a stalemate was standardized as a draw in the 19th century. Before this standardization, its treatment varied widely, including being deemed a win for the stalemating player, a half-win for that player, or a loss for that player; not being permitted; and resulting in the stalemated player missing a turn. Stalemate rules vary in other games of the chess family.\n    

In [19]:
tokenizer(generated_train_dataset[5]["text"])

{'input_ids': [1, 1, 733, 16289, 28793, 3381, 315, 506, 680, 7769, 438, 272, 727, 302, 341, 282, 366, 380, 28725, 506, 315, 1747, 28804, 13, 2287, 4003, 349, 741, 2758, 28747, 662, 282, 366, 380, 349, 264, 4620, 297, 997, 819, 970, 272, 4385, 4636, 1527, 378, 349, 298, 2318, 349, 459, 297, 1877, 304, 659, 708, 5648, 2318, 28723, 662, 282, 366, 380, 2903, 297, 264, 3924, 28723, 6213, 272, 948, 8835, 28725, 341, 282, 366, 380, 349, 264, 3715, 369, 541, 8234, 272, 4385, 395, 272, 24797, 2840, 298, 3924, 272, 2039, 3210, 821, 6788, 28723, 560, 680, 4630, 9161, 28725, 341, 282, 366, 380, 349, 1188, 408, 283, 263, 28725, 4312, 3344, 272, 1221, 302, 264, 1719, 507, 291, 369, 9481, 28713, 865, 513, 272, 11352, 2081, 349, 297, 1061, 308, 495, 20011, 28717, 5174, 3236, 28793, 662, 282, 366, 380, 349, 835, 264, 3298, 7335, 297, 948, 8835, 7193, 304, 799, 997, 819, 4418, 28723, 13, 13, 1014, 14120, 302, 264, 341, 282, 366, 380, 403, 4787, 1332, 390, 264, 3924, 297, 272, 28705, 28740, 28774, 362, 5

# LoRA Configuration

In [20]:
from peft import prepare_model_for_kbit_training

model.gradient_checkpointing_enable()

model = prepare_model_for_kbit_training(model)

In [21]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [22]:
from peft import LoraConfig, get_peft_model
    
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05, 
    task_type="CAUSAL_LM",
)

In [23]:
from peft import get_peft_model

model = get_peft_model(model, lora_config)

print_trainable_parameters(model)

trainable params: 21260288 || all params: 3773331456 || trainable%: 0.5634354746703705


In [22]:
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): lora.Linear4bit(
                (base_layer): Linea

# Model Training

In [24]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
secret_value_0 = user_secrets.get_secret("HUGGINGFACE_TOKEN")

!huggingface-cli login --token $secret_value_0

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [50]:
from transformers import TrainingArguments

training_arguments = TrainingArguments(
    output_dir="kajol/zephyr_01",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_strategy="steps",
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    max_steps=50,
    evaluation_strategy="steps", 
    eval_steps=25,       
    do_eval=True,               
    report_to="none",
    #push_to_hub =True,
    #hub_model_id ="kajol/zephyr_01",
    #hub_strategy ="end",
    #hub_token ="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"  #always give the write token not the read token both are different
)

In [26]:
from trl import SFTTrainer

# Setting sft parameters
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_arguments,
    train_dataset=generated_train_dataset,
    eval_dataset=generated_val_dataset,
    peft_config=lora_config,
    dataset_text_field="text",   
)



Map:   0%|          | 0/800 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [27]:
model.config.use_cache = False
trainer.train()
#trainer.push_to_hub()

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss
25,1.5111,1.499313
50,1.5216,1.466869




HfHubHTTPError: 403 Client Error: Forbidden for url: https://huggingface.co/api/models/kajol/zephyr_01/preupload/main (Request ID: Root=1-6581930b-71d5e4e92926296a08210583;048a6ea0-90d8-41fd-a0e4-75b2f574ae2f)

Forbidden: you must use a write token to upload to a repository.

In [28]:
model.save_pretrained("mistral")
tokenizer.save_pretrained("mistral")

('mistral/tokenizer_config.json',
 'mistral/special_tokens_map.json',
 'mistral/tokenizer.model',
 'mistral/added_tokens.json',
 'mistral/tokenizer.json')

In [40]:
#push the model to hub
#!huggingface-cli login
model.push_to_hub(repo_id ="kajol/model_01", token="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")
tokenizer.push_to_hub(repo_id="kajol/model_01", token="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")

adapter_model.safetensors:   0%|          | 0.00/609M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/kajol/model_01/commit/f3355f8b7df73c09b64a86b8b830442e93bab958', commit_message='Upload tokenizer', commit_description='', oid='f3355f8b7df73c09b64a86b8b830442e93bab958', pr_url=None, pr_revision=None, pr_num=None)

In [45]:
valid_df=eval_dataset.to_pandas()

In [46]:
valid_df

Unnamed: 0,instruction,context,response,category
0,Give this paragraph about the Alley Cats a cap...,"The group originated in 1987, when a concert c...",Jay Leno and Arsenio Hall,closed_qa
1,Provide a bulleted list of ways to spend less ...,,The following are ways to spend less money\n1....,brainstorming
2,How should I deal with my kid's allergy?,,Find out what your kid is allergic to by condu...,general_qa
3,What does Wittgenstein view as a problem with ...,Wittgenstein clarifies the problem of communic...,"As example of ""Ostensive defining"" is pointing...",information_extraction
4,"Identify the bird from the list: Queensbury, K...",,Kingfisher,classification
...,...,...,...,...
195,Write a haiku about sitting on the shore and w...,,I sit on the shore \nobserving the waves crash...,creative_writing
196,"From the passage provided, extract the names a...","At the dawn as a social science, economics was...",Jean-Baptiste Say - Treatise on Political Econ...,information_extraction
197,Classify each of the following as either a cou...,,"Countries: Sweden, France, India, Portugal\nCi...",classification
198,What is the fastest train in the world?,,"Shanghai Maglev in Shanghai, China",open_qa


In [43]:
from transformers import pipeline
pipe = pipeline(
    "text-generation", 
    model=model, 
    tokenizer = tokenizer, 
    torch_dtype=torch.bfloat16, 
    device_map="auto"
)

The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'ElectraForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FuyuForCausalLM', 'GitForCausalLM', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCausalLM', 'GPTJForCausalLM', 'LlamaForCausalLM', 'MarianForCausalLM', 'MBartForCausalLM', 'MegaForCausalLM', 'MegatronBertForCausalLM', 'MistralForCausalLM', 'MixtralForCausalLM', 'MptForCausalLM', 'MusicgenForCausalLM', 'MvpForCausalLM', 'OpenLlamaForCausalLM', 'OpenAIGPTLMHeadModel', 'OPTForCausalLM', 'PegasusForCausalLM', 'PersimmonF

In [47]:
instruction=valid_df["instruction"][3]
cntext=valid_df["context"][3]
response=valid_df["response"][3]

In [58]:
model.eval()
sequences = pipe(
    f"""<s>[INST]{instruction}{f"Here is some context: {cntext}"}[/INST] ['response']:</s>""" ,
    do_sample=True,
    max_new_tokens=100, 
    temperature=0.7, 
    top_k=50, 
    top_p=0.95,
    num_return_sequences=1,
)
print(sequences[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


<s>[INST]What does Wittgenstein view as a problem with learning a language using "ostensive defining"?Here is some context: Wittgenstein clarifies the problem of communicating using a human language when he discusses learning a language by "ostensive defining." For example, if one wanted to teach someone that a pencil was called a "pencil" and pointed to a pencil and said, "pencil," how does the listener know that what one is trying to convey is that the thing in front of me (e.g., the entire pencil) is called a "pencil"? Isn't it possible that the listener would associate "pencil" with "wood"? Maybe the listener would associate the word "pencil" with "round" instead (as pencils are, usually, in fact, round!). Wittgenstein writes regarding several possible "interpretations" which may arise after such a lesson. The student may interpret your pointing at a pencil and saying "pencil" to mean the following: (1) This is a pencil; (2) This is round; (3) This is wood; (4) This is one; (5) Thi

In [59]:
response

'As example of "Ostensive defining" is pointing to an object and saying it\'s name. Wittgenstein sees "ostensive defining" as a problem of communication because simply linking a word to an object carries with it ambiguity. For instance, if someone pointed to a pencil and said, "pencil", the listener may not know exactly what the word refers to. It may refer to the pencil itself, the shape of the pencil, the material of the pencil, the quantity of the pencil, or any other quality of the pencil.'

In [41]:
instruct_model = AutoModelForCausalLM.from_pretrained("kajol/model_01",
                                                       torch_dtype=torch.bfloat16)

adapter_config.json:   0%|          | 0.00/669 [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

adapter_model.safetensors:   0%|          | 0.00/609M [00:00<?, ?B/s]

In [42]:
instruct_tokenizer = AutoTokenizer.from_pretrained(
    "kajol/model_01", 
    padding_side = "right",
    add_eos_token = True,
)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_bos_token, tokenizer.add_eos_token

tokenizer_config.json:   0%|          | 0.00/995 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/437 [00:00<?, ?B/s]

(True, True)

In [48]:
new_pipe = pipeline(
    "text-generation", 
    model=instruct_model, 
    tokenizer = instruct_tokenizer, 
    torch_dtype=torch.bfloat16, 
    device_map="cuda"
)

In [49]:
#instruct_model.eval()
next_sequences = new_pipe(
    f"""<s>[INST]{instruction}
    {f"Here is some context: {cntext}"}[/INST]</s>""" ,
    
    do_sample=True,
    max_new_tokens=256, 
    temperature=0.7, 
    top_k=50, 
    top_p=0.95,
    num_return_sequences=1,
)
print(next_sequences[0]['generated_text'])

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


<s>[INST]What does Wittgenstein view as a problem with learning a language using "ostensive defining"?
    Here is some context: Wittgenstein clarifies the problem of communicating using a human language when he discusses learning a language by "ostensive defining." For example, if one wanted to teach someone that a pencil was called a "pencil" and pointed to a pencil and said, "pencil," how does the listener know that what one is trying to convey is that the thing in front of me (e.g., the entire pencil) is called a "pencil"? Isn't it possible that the listener would associate "pencil" with "wood"? Maybe the listener would associate the word "pencil" with "round" instead (as pencils are, usually, in fact, round!). Wittgenstein writes regarding several possible "interpretations" which may arise after such a lesson. The student may interpret your pointing at a pencil and saying "pencil" to mean the following: (1) This is a pencil; (2) This is round; (3) This is wood; (4) This is one; (5