<a href="https://colab.research.google.com/github/jlb-jlb/ML_Notebooks/blob/main/MPT_7B_Instruct_Fine_Tuning_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install

In [None]:
# Install/upgrade the notebook's dependencies (Colab shell escape).
# Versions are not pinned; the recorded output below resolved transformers 4.29.2,
# einops 0.6.1, datasets 2.12.0 and accelerate 0.19.0.
!pip install --upgrade transformers einops datasets accelerate

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.29.2-py3-none-any.whl (7.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m102.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets
  Downloading datasets-2.12.0-py3-none-any.whl (474 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m474.6/474.6 kB[0m [31m47.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate
  Downloading accelerate-0.19.0-py3-none-any.whl (219 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m219.1/219.1 kB[0m [31m27.3 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.14.1 (from transformers)
  Downloading huggingface_hub-0

# Import

In [None]:
import torch
import transformers
from transformers import TrainingArguments, Trainer
from transformers import AutoModelForSequenceClassification, AdamW
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from typing import Any, Dict, Tuple
import warnings
from threading import Event, Thread
import textwrap
from datasets import concatenate_datasets


import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Prefer the GPU when one is visible, otherwise fall back to CPU.
# NOTE(review): InstructionTextGenerationPipeline.__init__ computes its own device,
# so this module-level value looks unused in the visible cells — confirm before removing.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Mosaic Instruct Model MPT-7B-Instruct

In [None]:
# Section markers and preamble that make up the instruction prompt.
INSTRUCTION_KEY = "### Instruction:"
RESPONSE_KEY = "### Response:"
END_KEY = "### End"
INTRO_BLURB = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request."
)

# Prompt template: one element per line, ending with a newline after the
# response marker. The only placeholder left to fill is "{instruction}".
PROMPT_FOR_GENERATION_FORMAT = "\n".join(
    [
        INTRO_BLURB,
        INSTRUCTION_KEY,
        "{instruction}",
        RESPONSE_KEY,
        "",  # yields the trailing newline
    ]
)


class InstructionTextGenerationPipeline:
    """Causal LM + tokenizer + default generation settings behind a single callable.

    The constructor loads the model and tokenizer with ``from_pretrained``, puts
    the model in eval mode and moves it to CUDA when available (CPU otherwise).
    Calling the instance wraps an instruction in ``PROMPT_FOR_GENERATION_FORMAT``,
    generates a continuation and returns only the newly generated text.

    Attributes:
        model: The loaded causal language model.
        tokenizer: The matching tokenizer (pad token falls back to EOS).
        generate_kwargs: Default keyword arguments passed to ``model.generate``.
    """

    def __init__(
        self,
        model_name,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        use_auth_token=None,
    ) -> None:
        """Load the model/tokenizer and set up default generation arguments.

        Args:
            model_name: Identifier or path passed to ``from_pretrained``.
            torch_dtype: dtype used both when loading and when moving the model.
            trust_remote_code: Forwarded to ``from_pretrained`` for model and tokenizer.
            use_auth_token: Forwarded to ``from_pretrained`` for model and tokenizer.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
            use_auth_token=use_auth_token,
        )

        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
            use_auth_token=use_auth_token,
        )
        if tokenizer.pad_token_id is None:
            warnings.warn(
                "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
            )
            tokenizer.pad_token = tokenizer.eos_token
        # Right padding is what the batch tokenization later in the notebook relies on;
        # the original author left "left" as the noted alternative.
        tokenizer.padding_side = "right"
        self.tokenizer = tokenizer

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval()
        self.model.to(device=device, dtype=torch_dtype)

        self.generate_kwargs = {
            "temperature": 0.5,
            "top_p": 0.92,
            "top_k": 0,
            "max_new_tokens": 512,
            "use_cache": True,
            "do_sample": True,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
            "repetition_penalty": 1.1,  # 1.0 means no penalty, > 1.0 means penalty, 1.2 from CTRL paper
        }

    def format_instruction(self, instruction):
        """Return ``instruction`` embedded in the module-level prompt template."""
        return PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)

    def __call__(self, instruction: str, **generate_kwargs: Any) -> str:
        """Generate a response for ``instruction`` and return it as text.

        Args:
            instruction: The instruction to place into the prompt template.
            **generate_kwargs: Per-call overrides merged on top of
                ``self.generate_kwargs`` before calling ``model.generate``.

        Returns:
            The decoded text of the newly generated tokens only (the prompt
            tokens are sliced off; special tokens are skipped).
        """
        # Single source of truth for prompt construction (same path the
        # streaming helper and the dataset builder use).
        prompt = self.format_instruction(instruction)
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(self.model.device)
        gkw = {**self.generate_kwargs, **generate_kwargs}
        with torch.no_grad():
            output_ids = self.model.generate(input_ids, **gkw)
        # Keep only what follows the prompt in the first (and only) sequence.
        prompt_length = input_ids.shape[1]
        new_tokens = output_ids[0, prompt_length:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True)


In [None]:
# Build the shared pipeline object used by every later cell. The recorded output
# below shows this downloading the model code, weights and tokenizer files.
generate = InstructionTextGenerationPipeline(
    model_name="mosaicml/mpt-7b-instruct",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Token id(s) that the custom stopping criterion compares against.
stop_tokens = ["<|endoftext|>"]
stop_token_ids = generate.tokenizer.convert_tokens_to_ids(stop_tokens)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.23k [00:00<?, ?B/s]

Downloading (…)configuration_mpt.py:   0%|          | 0.00/9.20k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b-instruct:
- configuration_mpt.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)main/modeling_mpt.py:   0%|          | 0.00/18.4k [00:00<?, ?B/s]

Downloading (…)solve/main/blocks.py:   0%|          | 0.00/2.55k [00:00<?, ?B/s]

Downloading (…)resolve/main/norm.py:   0%|          | 0.00/2.56k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b-instruct:
- norm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)ve/main/attention.py:   0%|          | 0.00/17.7k [00:00<?, ?B/s]

Downloading (…)flash_attn_triton.py:   0%|          | 0.00/28.2k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b-instruct:
- flash_attn_triton.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b-instruct:
- attention.py
- flash_attn_triton.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b-instruct:
- blocks.py
- norm.py
- attention.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)meta_init_context.py:   0%|          | 0.00/3.64k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b-instruct:
- meta_init_context.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)in/param_init_fns.py:   0%|          | 0.00/12.6k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b-instruct:
- param_init_fns.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)refixlm_converter.py:   0%|          | 0.00/27.2k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b-instruct:
- hf_prefixlm_converter.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)n/adapt_tokenizer.py:   0%|          | 0.00/1.75k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b-instruct:
- adapt_tokenizer.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/mosaicml/mpt-7b-instruct:
- modeling_mpt.py
- blocks.py
- meta_init_context.py
- param_init_fns.py
- hf_prefixlm_converter.py
- adapt_tokenizer.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)model.bin.index.json:   0%|          | 0.00/16.0k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)l-00001-of-00002.bin:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

Downloading (…)l-00002-of-00002.bin:   0%|          | 0.00/3.36G [00:00<?, ?B/s]

You are using config.init_device='cpu', but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/91.0 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/237 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]



In [None]:
  # Define a custom stopping criteria
class StopOnTokens(StoppingCriteria):
    """Stopping criterion that fires once the newest token is a stop token.

    The set of stop ids is read from the module-level ``stop_token_ids`` on
    every call. Only the first sequence in the batch (row 0) is inspected.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the last token of row 0 equals any id in ``stop_token_ids``."""
        last_token = input_ids[0][-1]
        return any(last_token == stop_id for stop_id in stop_token_ids)

def process_stream(instruction, temperature, top_p, top_k, max_new_tokens):
    """Generate a response for ``instruction`` via a streamer and return the full text.

    Generation runs in a background thread and pushes decoded text into a
    ``TextIteratorStreamer``; this function drains the streamer, waits for the
    worker thread to finish, and returns the concatenated text.

    Args:
        instruction: Instruction text; wrapped with ``generate.format_instruction``.
        temperature: Sampling temperature. Values below 0.1 switch to greedy
            decoding (``do_sample=False``) with temperature set to 0.0.
        top_p: Nucleus-sampling parameter forwarded to ``model.generate``.
        top_k: Top-k parameter forwarded to ``model.generate``.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The generated text (prompt and special tokens excluded by the streamer settings).
    """
    # Tokenize the input
    input_ids = generate.tokenizer(
        generate.format_instruction(instruction), return_tensors="pt"
    ).input_ids
    input_ids = input_ids.to(generate.model.device)

    # Initialize the streamer and stopping criteria
    streamer = TextIteratorStreamer(
        generate.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    stop = StopOnTokens()

    # Near-zero temperatures are treated as a request for deterministic output.
    if temperature < 0.1:
        temperature = 0.0
        do_sample = False
    else:
        do_sample = True

    # Pipeline defaults first, then the per-call settings override them.
    gkw = {
        **generate.generate_kwargs,
        "input_ids": input_ids,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "do_sample": do_sample,
        "top_p": top_p,
        "top_k": top_k,
        "streamer": streamer,
        "stopping_criteria": StoppingCriteriaList([stop]),
    }

    # Run generation in the background so the streamer can be consumed here.
    worker = Thread(target=generate.model.generate, kwargs=gkw)
    worker.start()

    # Iterating the streamer yields text chunks until generation ends.
    response = "".join(streamer)

    # Make sure generate() has fully returned before handing control back, so
    # later cells (e.g. training) never overlap with a still-running worker.
    worker.join()

    return response

In [None]:
def generate_text(instruction, temperature = 0.3, top_p = 0.95, top_k = 0, max_new_tokens = 2000):
    """Generate a response for ``instruction`` and print it wrapped to 100 columns.

    All arguments are forwarded unchanged to ``process_stream``. Nothing is
    returned; the wrapped text is written to stdout.
    """
    print(
        textwrap.fill(
            process_stream(instruction, temperature, top_p, top_k, max_new_tokens),
            width=100,
        )
    )

In [None]:
# Quick smoke test: one short prompt through the streaming path with the default sampling settings.
generate_text('Hello')

Hi there


# Data Processing

In [None]:
# Load the AG News dataset; the recorded output shows a 120000-example train
# split and a 7600-example test split.
dataset = load_dataset('ag_news')
# Extract label names
# Index i of this list is the display name for integer label i.
label_names = dataset['train'].features['label'].names
# Bare expression so the cell displays the list of class names.
label_names

Downloading builder script:   0%|          | 0.00/4.06k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.65k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.95k [00:00<?, ?B/s]

Downloading and preparing dataset ag_news/default to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548...


Downloading data:   0%|          | 0.00/11.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/751k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

Dataset ag_news downloaded and prepared to /root/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

['World', 'Sports', 'Business', 'Sci/Tech']

In [None]:
# Render the label list without quotes, e.g. "[World, Sports, Business, Sci/Tech]".
label_names_str = str(label_names).replace("'","")
prefix = f"Classify the following text into one of the following categories: {label_names_str}\nText: "

num_samples_per_class = 75  # Set this to determine the number of samples per class


def build_class_balanced_subsets(split, samples_per_class):
    """Return one prompt-formatted subset per class, each with ``samples_per_class`` rows.

    For every label index, the first ``samples_per_class`` examples of that
    class are taken from ``split`` and their ``text`` column is rewritten to
    the full training string: instruction prompt (``prefix`` + article text,
    wrapped by ``generate.format_instruction``) followed by the label name.

    Args:
        split: A dataset split with ``label`` and ``text`` columns.
        samples_per_class: Number of examples to keep for each class.

    Returns:
        A list of datasets, one per entry in ``label_names``, in label order.
    """
    subsets = []
    for label in range(len(label_names)):
        # .filter() runs immediately inside this iteration, so capturing the
        # loop variable in the lambda is safe here.
        subset = split.filter(lambda x: x['label'] == label).select(range(samples_per_class))
        subset = subset.map(
            lambda x: {'text': generate.format_instruction(prefix + x['text']) + label_names[x['label']]}
        )
        subsets.append(subset)
    return subsets


# Extract subsets of the training dataset with equal number of samples per class
train_datasets = build_class_balanced_subsets(dataset['train'], num_samples_per_class)
train_dataset = concatenate_datasets(train_datasets)

# Same construction for validation, drawn from the dataset's 'test' split
val_datasets = build_class_balanced_subsets(dataset['test'], num_samples_per_class)
val_dataset = concatenate_datasets(val_datasets)

Filter:   0%|          | 0/120000 [00:00<?, ? examples/s]



Map:   0%|          | 0/75 [00:00<?, ? examples/s]

Filter:   0%|          | 0/120000 [00:00<?, ? examples/s]

Map:   0%|          | 0/75 [00:00<?, ? examples/s]

Filter:   0%|          | 0/120000 [00:00<?, ? examples/s]

Map:   0%|          | 0/75 [00:00<?, ? examples/s]

Filter:   0%|          | 0/120000 [00:00<?, ? examples/s]

Map:   0%|          | 0/75 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7600 [00:00<?, ? examples/s]

Map:   0%|          | 0/75 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7600 [00:00<?, ? examples/s]

Map:   0%|          | 0/75 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7600 [00:00<?, ? examples/s]

Map:   0%|          | 0/75 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7600 [00:00<?, ? examples/s]

Map:   0%|          | 0/75 [00:00<?, ? examples/s]

In [None]:
# Tokenize both splits with identical truncation/padding settings.
# NOTE(review): the class label is the last thing in each text, so any example
# longer than max_length would presumably lose its label to truncation — confirm lengths.
tokenize_kwargs = {"truncation": True, "padding": True, "max_length": 256}
train_encodings = generate.tokenizer(train_dataset['text'], **tokenize_kwargs)
val_encodings = generate.tokenizer(val_dataset['text'], **tokenize_kwargs)

In [None]:
# Convert the encodings to PyTorch datasets
class TextDataset(Dataset):
    """Map-style dataset over tokenizer encodings for causal-LM training.

    ``encodings`` is a mapping from field name to a per-example sequence and
    must contain ``"input_ids"``. Each item is a dict of tensors for one
    example plus a ``"labels"`` entry that is a copy of ``"input_ids"``
    (every position, including any padded ones, is used as a label).
    """

    def __init__(self, encodings):
        """Store the encodings mapping without copying it."""
        self.encodings = encodings

    def __len__(self):
        """Number of examples, taken from the ``"input_ids"`` field."""
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        """Return the tensors for example ``idx`` with ``labels`` mirroring ``input_ids``."""
        item = {}
        for key, values in self.encodings.items():
            item[key] = torch.tensor(values[idx])
        # Independent copy so labels and inputs do not share storage.
        item["labels"] = item["input_ids"].clone()
        return item

In [None]:
# Wrap the tokenizer encodings in TextDataset objects for the Trainer.
# Caution: this rebinds `train_dataset` / `val_dataset` from the prompt-formatted
# text datasets to TextDataset instances, so the tokenization cell (which reads
# train_dataset['text']) can no longer be re-run without rebuilding the text datasets first.
train_dataset = TextDataset(train_encodings)
val_dataset = TextDataset(val_encodings)

## Example Before Fine-Tuning

In [None]:
# Classify one held-out test example with the current model and show the true label.
index = 0
sample = dataset['test'][index]
example_text = prefix + sample['text']
print('Example text:', textwrap.fill(example_text, width=100), "Response:", sep="\n")
# temperature=0 falls below the greedy threshold in process_stream, so decoding is deterministic.
generate_text(example_text, temperature=0)
print("Correct Label:", label_names[sample['label']], sep="\n")

Example text:
Classify the following text into one of the following categories: [World, Sports, Business,
Sci/Tech] Text: Fears for T N pension after talks Unions representing workers at Turner   Newall say
they are 'disappointed' after talks with stricken parent firm Federal Mogul.
Response:
The above paragraph can be classified as follows - World (Fears), Sport(Unions),Business and Science
& Tech
Correct Label:
Business


# Freeze Layers

In [None]:
# Train only the final transformer block; freeze every other parameter
# (embeddings, earlier blocks, final norm). The block is selected by an exact
# name prefix rather than a loose "31" substring, so unrelated parameter names
# that happen to contain those digits can never be left trainable by accident.
last_block_index = len(generate.model.transformer.blocks) - 1
trainable_prefix = f"transformer.blocks.{last_block_index}."

for name, param in generate.model.named_parameters():
    print(f"{name}   Modelsize: {param.numel()/1000**2:.1f}M parameters")
    # Set the flag explicitly in both directions so re-running the cell is idempotent.
    param.requires_grad = name.startswith(trainable_prefix)
    print(name, param.requires_grad)

In [None]:
# Define the training arguments
# Context for the numbers below: the training set holds 4 classes x 75 samples = 300
# examples, and the recorded run further down finished at global_step=950.
# NOTE(review): warmup_steps=500 therefore covers more than half of that run — confirm intended.
# NOTE(review): no evaluation or checkpoint-saving options are set here, so library
# defaults apply — verify they suit a 7B-parameter model on Colab disk.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=50,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
)


In [None]:
# Create the Trainer (training itself is started in the next cell).
# The model is the same object held by the `generate` pipeline, so training
# updates the weights later used for generation in place.
# NOTE(review): eval_dataset is supplied, but training_args sets no evaluation
# schedule — presumably evaluation never runs during training; confirm.
trainer = Trainer(
    model=generate.model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

In [None]:
# Run fine-tuning; as the last expression, the returned TrainOutput is displayed below.
trainer.train()



Step,Training Loss
10,3.0172
20,3.0562
30,2.9797
40,2.8469
50,2.6078
60,2.3609
70,2.1016
80,1.9187
90,1.7055
100,1.4984


TrainOutput(global_step=950, training_loss=0.9318174342105263, metrics={'train_runtime': 239.8278, 'train_samples_per_second': 62.545, 'train_steps_per_second': 3.961, 'total_flos': 9.509450563584e+16, 'train_loss': 0.9318174342105263, 'epoch': 50.0})

In [None]:
# Switch the fine-tuned model back to eval mode before generating again.
# As the last expression, the returned module is displayed (architecture printout below).
generate.model.eval()

MPTForCausalLM(
  (transformer): MPTModel(
    (wte): Embedding(50432, 4096)
    (emb_drop): Dropout(p=0, inplace=False)
    (blocks): ModuleList(
      (0-31): 32 x MPTBlock(
        (norm_1): LPLayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): MultiheadAttention(
          (Wqkv): Linear(in_features=4096, out_features=12288, bias=False)
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (norm_2): LPLayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (ffn): MPTMLP(
          (up_proj): Linear(in_features=4096, out_features=16384, bias=False)
          (act): GELU(approximate='none')
          (down_proj): Linear(in_features=16384, out_features=4096, bias=False)
        )
        (resid_attn_dropout): Dropout(p=0, inplace=False)
        (resid_ffn_dropout): Dropout(p=0, inplace=False)
      )
    )
    (norm_f): LPLayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )
)

## Example After Fine-Tuning

In [None]:
# Re-run the same held-out test example through the fine-tuned model for comparison.
index = 0
sample = dataset['test'][index]
example_text = prefix + sample['text']
print('Example text:', textwrap.fill(example_text, width=100), "Response:", sep="\n")
# temperature=0 falls below the greedy threshold in process_stream, so decoding is deterministic.
generate_text(example_text, temperature=0)
print("Correct Label:", label_names[sample['label']], sep="\n")

Example text:
Classify the following text into one of the following categories: [World, Sports, Business,
Sci/Tech] Text: Fears for T N pension after talks Unions representing workers at Turner   Newall say
they are 'disappointed' after talks with stricken parent firm Federal Mogul.
Response:
Business
Correct Label:
Business
