<a href="https://colab.research.google.com/github/david-meltzer/LLMs/blob/main/training/colab/SFT_QA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [1]:
%cd drive/MyDrive/LLMs/Fine-tuning/SFT

/content/drive/MyDrive/LLMs/Fine-tuning/SFT


In [None]:
# installations
!pip install detoxify
!pip install optimum
!pip install peft==0.4.0
!pip install accelerate==0.21.0
#!pip install bitsandbytes==0.40.2
!pip install bitsandbytes==0.41.1
!pip install safetensors>=0.3.1
!pip install trl
!pip install wandb
!pip install tokenizers>=0.13.3
!pip install -U transformers

#!pip install transformers -qqq
#!pip install datasets --upgrade -qqq
#!pip install apache-beam -qqq
#!pip install wandb -qqq
#!pip install accelerate -qqq
#!pip install trl -qqq
#!pip install bitsandbytes -qqq
#!pip install peft -qqq

In [9]:
import gc

import os
import torch
from google.colab import runtime
import pandas as pd

import datasets
import accelerate
import transformers
from transformers import (AutoTokenizer,
                          AutoModelForCausalLM,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          BitsAndBytesConfig,
                          TrainerCallback)
import bitsandbytes as bnb
import wandb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datetime import datetime
from huggingface_hub import login

from peft.tuners.lora import LoraLayer

  warn("The installed version of bitsandbytes was compiled without GPU support. "


/usr/local/lib/python3.10/dist-packages/bitsandbytes/libbitsandbytes_cpu.so: undefined symbol: cadam32bit_grad_fp32


In [2]:
from getpass import getpass
hf_token = getpass()
wandb_token = getpass()

··········
··········


In [None]:
from huggingface_hub import login
import wandb

login(hf_token)
wandb.login(key=wandb_token)

# Definitions

## Datasets

In [8]:
# setup collator


def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Human: {example['question'][i]}\n ### Assistant: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts

def sft_collator(tokenizer, response_template = " ### Assistant:"):

    return DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)



#def formatting_prompts_func(example):
#    output_texts = []
#    for i in range(len(example['question'])):
#        text = f"### Question: {example['question'][i]}\n\n### Answer: {example['answer'][i]}"
#        output_texts.append(text)
#    return output_texts
#response_template = " ### Answer:"
#sft_collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

#def sft_collator(tokenizer, response_template = "\n\n### Answer:"):
#    return DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

def combine_question_answer(ds,formatting_func):

    if 'QA' not in ds['train']:
        ds = ds.map(lambda x: {'QA':formatting_func(x)},
                    batched=True)
    return ds

def prepare_dataset(ds,
                    tokenizer,
                    formatting_func,
                    max_seq_length='auto'):

    if max_seq_length == 'auto':
        max_seq_length = tokenizer.model_max_length

    ds = combine_question_answer(ds,formatting_func)

    ds = ds.map(lambda x: {'tokens':tokenizer(x['QA'],
                                              return_length=False)})

    ds = ds.filter(lambda x: len(x['tokens']['input_ids'])<=max_seq_length)

    return ds

## Training

In [9]:
class PeftSavingCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(checkpoint_path)

        if "pytorch_model.bin" in os.listdir(checkpoint_path):
            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))

def prepare_model(checkpoint,
                 target_modules,
                 lora_rank=32,
                 lora_alpha=32,
                 lora_dropout=0.05,
                 bias="none",
                 task_type="CAUSAL_LM",
                 model_type = 'qlora',
                 extra_quant = True):

    if model_type not in {'lora','qlora','full'}:
        raise ValueError('Train type should be "lora", "qlora", or "full".')

    if model_type in {'lora','qlora'}:

        if model_type == 'qlora':

            nf4_config = BitsAndBytesConfig(load_in_4bit=True,
                                    bnb_4bit_quant_type="nf4",
                                    bnb_4bit_use_double_quant = extra_quant,
                                    bnb_4bit_compute_dtype=torch.bfloat16
                                    )

            model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                                quantization_config=nf4_config,
                                                device_map='auto',
                                                torch_dtype = torch.bfloat16
                                                )
        else:
            model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                                 load_in_8bit = extra_quant)

        model = prepare_model_for_kbit_training(model)

        lora_config = LoraConfig(
          r = lora_rank,
          lora_alpha = lora_alpha,
          target_modules = target_modules,
          lora_dropout = lora_dropout,
          bias = bias,
          task_type = task_type
          )

        model = get_peft_model(model, lora_config)
    else:
        model = AutoModelForCausalLM.from_pretrained(checkpoint)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model_name = checkpoint.split('/')[-1]

    if model_type in {'lora','qlora'}:
        model_name += f'_{model_type}'
        model_name += f'_r_{lora_rank}_a_{lora_alpha}'

    return model, tokenizer, model_name


def prepare_hyperparameters(model_name,
                            ds_name,
                            evaluation_strategy = 'steps',
                            save_steps = .1,
                            eval_steps = .1,
                            logging_steps = 100,
                            log_level = 'error',
                            report_to = 'wandb',
                            num_train_epochs = 3,
                            lr = 5e-5,
                            warmup_steps = 50,
                            weight_decay = .01,
                            optim = 'adamw_torch_fused',
                            prec = 'fp16',
                            train_batch_size = 8,
                            eval_batch_size = 16,
                            grad_accum = 4,
                            grad_checkpoint = True,
                            group_by_length = True,
                            dataloader_num_workers = 2,
                            save_total_limit = 3):

    training_args = TrainingArguments(
        logging_dir = f'./{model_name}_{ds_name}/logs',
        output_dir= f'./{model_name}_{ds_name}/models',
        evaluation_strategy = evaluation_strategy,
        save_strategy = evaluation_strategy,
        save_steps = save_steps,
        eval_steps = eval_steps,
        logging_steps = logging_steps,
        log_level = log_level,
        report_to = report_to,
        num_train_epochs = num_train_epochs,
        learning_rate = lr,
        warmup_steps = warmup_steps,
        weight_decay = weight_decay,
        optim = optim,
        fp16 = True if prec=='fp16' else False,
        bf16 = True if prec=='bf16' else False,
        per_device_train_batch_size = train_batch_size,
        per_device_eval_batch_size = eval_batch_size,
        gradient_accumulation_steps = grad_accum,
        gradient_checkpointing = grad_checkpoint,
        group_by_length = group_by_length,
        dataloader_num_workers = dataloader_num_workers,
        load_best_model_at_end=True,
        save_total_limit = save_total_limit,
        )

    if report_to == 'wandb':

        %env WANDB_PROJECT = 'SFT_training_dm'

        now = datetime.now()
        time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")

        run_name = f'{model_name}__time_stamp'

        training_args.run_name = run_name



    return training_args

def SFT_train(model,
              tokenizer,
              training_args,
              dataset,
              ds_name,
              dataset_text_field='QA',
              formatting_func = formatting_prompts_func,
              max_seq_length = 'auto',
              packing = False,
              collator = sft_collator,
              preprocess_ds = False
              ):

    if max_seq_length == 'auto':
        max_seq_length = tokenizer.model_max_length

    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    if not packing:
        collator = collator(tokenizer)

    if ds_name == 'ELI5':
        dataset = dataset.filter(lambda x:x['source']=='ELI5')
    elif ds_name == 'simple_wiki':
        dataset = dataset.filter(lambda x:x['source']=='simple_wiki')

    if preprocess_ds:

        dataset = prepare_dataset(dataset,tokenizer,formatting_func)

    sft_trainer = SFTTrainer(
            model,
            training_args,
            max_seq_length=max_seq_length,
            train_dataset=dataset['train'],
            eval_dataset=dataset['validation'],
            dataset_text_field=dataset_text_field,
            data_collator=collator if not packing else None,
            callbacks=[PeftSavingCallback()],
            packing=packing
            )

    sft_trainer.train()

    wandb.finish()

def full_training(
    checkpoint,
    dataset,
    ds_name,
    target_modules=None,
    dataset_text_field="QA",
    max_seq_length = 'auto',
    lora_rank=32,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    model_type = 'qlora',
    extra_quant = True,
    evaluation_strategy = 'steps',
    save_steps = 0.1,
    eval_steps = 0.1,
    logging_steps = 100,
    log_level = 'error',
    report_to = 'wandb',
    num_train_epochs = 3,
    lr = 5e-5,
    warmup_steps = 50,
    weight_decay = .01,
    optim = 'adamw_torch_fused',
    prec = 'fp16',
    train_batch_size = 8,
    eval_batch_size = 16,
    grad_accum = 4,
    grad_checkpoint = True,
    group_by_length = True,
    dataloader_num_workers = 2,
    save_total_limit = 3,
    wandb_report = 'SFT_training_dm',
    packing = False,
    collator = sft_collator,
    preprocess_ds = False
    ):

    model, tokenizer, model_name = prepare_model(checkpoint,
                                                 target_modules,
                                                lora_rank=lora_rank,
                                                lora_alpha=lora_alpha,
                                                lora_dropout=lora_dropout,
                                                bias=bias,
                                                task_type=task_type,
                                                model_type = model_type,
                                                extra_quant = extra_quant)

    training_args = prepare_hyperparameters(model_name,
                            ds_name,
                            evaluation_strategy =evaluation_strategy,
                            save_steps = save_steps,
                            eval_steps = eval_steps,
                            logging_steps = logging_steps,
                            log_level = log_level,
                            report_to = report_to,
                            num_train_epochs = num_train_epochs,
                            lr = lr,
                            warmup_steps = warmup_steps,
                            weight_decay = weight_decay,
                            optim = optim,
                            prec = prec,
                            train_batch_size = train_batch_size,
                            eval_batch_size = eval_batch_size,
                            grad_accum = grad_accum,
                            grad_checkpoint = grad_checkpoint,
                            group_by_length = group_by_length,
                            dataloader_num_workers = dataloader_num_workers,
                            save_total_limit = save_total_limit,
                            wandb_report = wandb_report
                            )

    SFT_train(model,
              tokenizer,
              training_args,
              dataset = dataset,
              dataset_text_field=dataset_text_field,
              ds_name = ds_name,
              max_seq_length = max_seq_length,
              packing = packing,
              collator = collator,
              preprocess_ds = preprocess_ds
              )



# Making Datasets

## Download datasets

In [None]:
with wandb.init(project='ELI5_analysis',
                 entity='ft-llmmm',
                 job_type='training',
                 name='SFT_training') as run:

    artifact_wiki_QA = run.use_artifact('ft-llmmm/ELI5_analysis/simple_wiki_QA:v2', type='dataset')
    artifact_dir_wiki_QA = artifact_wiki_QA.download()

    artifact_ELI5 = run.use_artifact('ft-llmmm/ELI5_analysis/ELI5_cleaned:v4', type='dataset')
    artifact_dir_ELI5 = artifact_ELI5.download()

In [102]:
artifact_dir_wiki_QA='./artifacts/simple_wiki_QA:v2'
artifact_dir_ELI5='./artifacts/ELI5_cleaned:v4'

In [103]:
simplewiki_QA_ds = datasets.load_dataset("csv",
                                         data_files={"train": artifact_dir_wiki_QA + '/simple_wiki_QA_combined_train.csv',
                                                    "test": artifact_dir_wiki_QA +  '/simple_wiki_QA_combined_test.csv',
                                                    "val": artifact_dir_wiki_QA + '/simple_wiki_QA_combined_validation.csv'
                                        }
)
simplewiki_QA_ds = simplewiki_QA_ds.remove_columns(['id','system_message','prompt_template'])
simplewiki_QA_ds = simplewiki_QA_ds.rename_columns({'trunc_text':'answer'})

simplewiki_QA_ds['validation'] = simplewiki_QA_ds['val']
del simplewiki_QA_ds['val']

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Generating val split: 0 examples [00:00, ? examples/s]

In [104]:
for split in simplewiki_QA_ds:
    dset_source = datasets.Dataset.from_dict({'source':['simple_wiki']*len(simplewiki_QA_ds[split])})
    simplewiki_QA_ds[split] = datasets.concatenate_datasets([simplewiki_QA_ds[split],dset_source],axis=1)

In [71]:
ELI5_ds = datasets.load_from_disk(f'{artifact_dir_ELI5}/ds_SFT')
ELI5_ds = ELI5_ds.flatten()
ELI5_ds = ELI5_ds.remove_columns(['document','q_id','title','selftext','subreddit','url','title_urls','selftext_urls','answers_urls','pref_idxs','dupl_scores_idxs','qu_emb',
                                  'answers.a_id','answers.fkg','answers.fre','answers.score'])
ELI5_ds = ELI5_ds.map(lambda x: {'answers.text':list(x['answers.text'])})

ELI5_ds = ELI5_ds.with_format("pandas").map(lambda df:
                                                df.explode("answers.text"),
                                                batched=True)

ELI5_ds = ELI5_ds.with_format(None)

ELI5_ds = ELI5_ds.remove_columns(['__index_level_0__'])
ELI5_ds = ELI5_ds.rename_columns({'answers.text':'answer',
                                  'title_body':'question'})

Map:   0%|          | 0/41820 [00:00<?, ? examples/s]

Map:   0%|          | 0/787 [00:00<?, ? examples/s]

Map:   0%|          | 0/2219 [00:00<?, ? examples/s]

Map:   0%|          | 0/41820 [00:00<?, ? examples/s]

Map:   0%|          | 0/787 [00:00<?, ? examples/s]

Map:   0%|          | 0/2219 [00:00<?, ? examples/s]

In [72]:
for split in ELI5_ds:
    dset_source = datasets.Dataset.from_dict({'source':['ELI5']*len(ELI5_ds[split])})
    ELI5_ds[split] = datasets.concatenate_datasets([ELI5_ds[split],dset_source],axis=1)

## Detoxify ELI5

In [None]:
from detoxify import Detoxify

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
detoxify_model = Detoxify('unbiased')
detoxify_model.model.to(device)

ELI5_ds = ELI5_ds.map(lambda x: detoxify_model.predict(x['answer']),
                                                  batched=True,batch_size=64
                      )

In [84]:
ELI5_ds.save_to_disk('../data/ELI5_toxic_scores')

Saving the dataset (0/1 shards):   0%|          | 0/45538 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1124 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2509 [00:00<?, ? examples/s]

In [6]:
ELI5_ds = datasets.load_from_disk('../data/ELI5_toxic_scores')

In [99]:
metrics=['toxicity', 'severe_toxicity',
         'obscene', 'identity_attack',
         'insult', 'threat', 'sexual_explicit']

ELI5_non_toxic = ELI5_ds.filter(lambda x: all(x[metric]<=.1
                                              for metric in metrics))

ELI5_non_toxic = ELI5_non_toxic.remove_columns([col for col in ELI5_non_toxic['train'].features if
                                                col not in ['answer','question']])

ELI5_non_toxic.save_to_disk('../data/ELI5_non_toxic')

Saving the dataset (0/1 shards):   0%|          | 0/42214 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/964 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2301 [00:00<?, ? examples/s]

## Combine Datasets

In [117]:
SFT_QA_dataset = datasets.DatasetDict()

for split in ['train','validation','test']:

    SFT_QA_dataset[split] = datasets.concatenate_datasets([simplewiki_QA_ds[split],
                                                ELI5_non_toxic[split]])

In [118]:
SFT_QA_dataset = SFT_QA_dataset.shuffle(seed=12321)

In [119]:
SFT_QA_dataset = combine_question_answer(SFT_QA_dataset,
                                         formatting_prompts_func)

Map:   0%|          | 0/72214 [00:00<?, ? examples/s]

Map:   0%|          | 0/1964 [00:00<?, ? examples/s]

Map:   0%|          | 0/3301 [00:00<?, ? examples/s]

In [120]:
SFT_QA_dataset.save_to_disk('../data/SFT_QA_ds')

Saving the dataset (0/1 shards):   0%|          | 0/72214 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1964 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3301 [00:00<?, ? examples/s]

In [121]:
now = datetime.now()
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
with wandb.init(project='ELI5_analysis',
                entity='ft-llmmm',
                job_type='upload_data',
                name=f'SFT_QA_dataset_{time_stamp}') as run:

    clean_data_art = wandb.Artifact('combined_dataset', 'dataset')
    clean_data_art.add_dir('../data/SFT_QA_ds')
    run.log_artifact(clean_data_art)

[34m[1mwandb[0m: Adding directory to artifact (./../data/SFT_QA_ds)... Done. 9.5s


In [122]:
SFT_QA_dataset['train'][0]

{'question': '- why clothes that are hang to dry and not dry within a certain number of hours stink? does it mean that the clothes were not properly washed by the machine to start?\n',
 'answer': "It usually means that some bacteria or mold started growing. Even if you washed the clothes thoroughly, there might be some bacteria or mold spores in the water, or the machine, or even in the air. By airing the laundry, and drying it quickly enough, you stop the mold/bacteria from growing (and making smelly chemicals in the process). Clothes that haven't been washed properly still smell of body odor (not that moldy smell). I've heard a good soak in diluted vinegar can help with either smell, if the fabric can't handle a hot wash.",
 'source': None,
 'QA': "### Human: - why clothes that are hang to dry and not dry within a certain number of hours stink? does it mean that the clothes were not properly washed by the machine to start?\n\n ### Assistant: It usually means that some bacteria or mol

## Tokenizing

### GPT-2

In [123]:
tok = AutoTokenizer.from_pretrained('distilgpt2')
GPT2_QA_tokenized = prepare_dataset(SFT_QA_dataset,tok,formatting_prompts_func)
GPT2_QA_tokenized.save_to_disk('./data/GPT2_QA_tokenized')

now = datetime.now()
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
with wandb.init(project='ELI5_analysis',
                entity='ft-llmmm',
                job_type='upload_data',
                name=f'GPT2_QA_tokenized_dataset_{time_stamp}') as run:

    clean_data_art = wandb.Artifact('GPT2_QA_tokenized', 'dataset')
    clean_data_art.add_dir('./data/GPT2_QA_tokenized')
    run.log_artifact(clean_data_art)

Map:   0%|          | 0/72214 [00:00<?, ? examples/s]

Map:   0%|          | 0/1964 [00:00<?, ? examples/s]

Map:   0%|          | 0/3301 [00:00<?, ? examples/s]

Map:   0%|          | 0/72214 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1446 > 1024). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/1964 [00:00<?, ? examples/s]

Map:   0%|          | 0/3301 [00:00<?, ? examples/s]

Filter:   0%|          | 0/72214 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1964 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3301 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/71554 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1951 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3262 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Adding directory to artifact (./data/GPT2_QA_tokenized)... Done. 2.7s


### Llama

In [4]:
from transformers import AutoTokenizer
import datasets

In [5]:
SFT_QA_dataset = datasets.load_from_disk('../data/SFT_QA_ds')

In [6]:
model_id = "meta-llama/Llama-2-7b-hf"
model_name = model_id.split('/')[-1]
llama_tokenizer = AutoTokenizer.from_pretrained(model_id)
llama_tokenizer.pad_token = llama_tokenizer.eos_token

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [7]:
SFT_QA_dataset_llama = SFT_QA_dataset.map(lambda x :
                                    llama_tokenizer(x['QA']))

SFT_QA_dataset_llama = SFT_QA_dataset_llama.map(lambda x: {'length':len(x['input_ids'])})

SFT_QA_dataset_llama.save_to_disk('../data/SFT_QA_dataset_llama')

Map:   0%|          | 0/72214 [00:00<?, ? examples/s]

Map:   0%|          | 0/1964 [00:00<?, ? examples/s]

Map:   0%|          | 0/3301 [00:00<?, ? examples/s]

Map:   0%|          | 0/72214 [00:00<?, ? examples/s]

Map:   0%|          | 0/1964 [00:00<?, ? examples/s]

Map:   0%|          | 0/3301 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/72214 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1964 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3301 [00:00<?, ? examples/s]

In [None]:
with wandb.init(project='ELI5_analysis',
                entity='ft-llmmm',
                job_type='upload_data',
                name=f'llama_QA_tokenized_dataset_clean') as run:

    clean_data_art = wandb.Artifact('llama_QA_tokenized', 'dataset')
    clean_data_art.add_dir('../data/SFT_QA_dataset_llama')
    run.log_artifact(clean_data_art)

[34m[1mwandb[0m: Adding directory to artifact (./../data/SFT_QA_dataset_llama)... Done. 2.2s


In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Experiments

## distil-GPT2

In [None]:
full_training('distilgpt2',dataset=SFT_QA_dataset,prec=None,ds_name = 'combined')

# Llama

In [None]:
SFT_QA_dataset = datasets.load_from_disk('../data/SFT_QA_ds')

ds_wiki = SFT_QA_dataset.filter(lambda x:
                                x['source']=='simple_wiki')

In [None]:
ds_wiki=ds_wiki.remove_columns(['source','QA'])

In [None]:
model_id = "meta-llama/Llama-2-7b-hf" # sharded weights
model_name = model_id.split('/')[-1]
llama_tokenizer = AutoTokenizer.from_pretrained(model_id)
llama_tokenizer.pad_token = llama_tokenizer.eos_token

Downloading (…)okenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

In [None]:
%env WANDB_LOG_MODEL='end'

env: WANDB_LOG_MODEL='end'


In [None]:
def find_all_linear_names(model):
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, bnb.nn.Linear4bit):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)

def create_peft_model(model,
                      r=64,
                      lora_alpha=16,
                      lora_dropout=0.1,
                      bias='none',
                      task_type='CAUSAL_LM',
                      gradient_checkpointing=True,
                      bf16=True):

    # prepare int-4 model for training
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=gradient_checkpointing
    )
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # get lora target modules
    modules = find_all_linear_names(model)
    print(f"Found {len(modules)} modules to quantize: {modules}")

    peft_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=modules,
        lora_dropout=lora_dropout,
        bias=bias,
        task_type=task_type,
    )

    model = get_peft_model(model, peft_config)

    # pre-process the model by upcasting the layer norms in float 32 for
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            if bf16:
                module = module.to(torch.bfloat16)
        if "norm" in name:
            module = module.to(torch.float32)
        if "lm_head" in name or "embed_tokens" in name:
            if hasattr(module, "weight"):
                if bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)

    model.print_trainable_parameters()
    return model

def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Human: {example['question'][i]}\n ### Assistant: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts

def sft_collator(tokenizer, response_template = "### Assistant:"):
    return DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

def training_function(model_id,
                      dataset,
                      hf_token,
                      wandb_token,
                      ds_name,
                      r=64,
                      lora_alpha=16,
                      lora_dropout=0.1,
                      bias='none',
                      task_type='CAUSAL_LM',
                      max_seq_length=512,
                      epochs = 1,
                      max_steps = -1,
                      gradient_checkpointing = True,
                      lr=2e-4,
                      weight_decay=.01,
                      per_device_train_batch_size=16,
                      per_device_eval_batch_size=16,
                      gradient_accumulation_steps=4,
                      optim='paged_adamw_32bit',
                      warmup_ratio=0.03,
                      group_by_length=True,
                      dataloader_num_workers=2,
                      logging_steps=10,
                      save_total_limit=3,
                      save_strategy='steps',
                      save_steps =.2,
                      eval_steps=.2,
                      load_best_model_at_end=True,
                      project_name='SFT_training_dm',
                      entity='ft-llmmm',
                      torch_compile=False
                      ):
    # set seed
    #torch._dynamo.config.verbose=True
    #torch._dynamo.config.suppress_errors = True
    now = datetime.now()
    time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")

    model_name = model_id.split('/')[-1]
    model_name = f'{model_name}_{ds_name}_r_{r}_alpha_{lora_alpha}'

    run_name = f'{model_name}_{time_stamp}'

    if torch.cuda.get_device_capability()[0] == 8:
        bf16=True,
        fp16=False
    else:
        bf16=False
        fp16=True

    #dataset = datasets.load_from_disk(args.dataset_path)
    # load model from the hub with a bnb config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        use_cache=False
        if gradient_checkpointing
        else True,  # this is needed for gradient checkpointing
        device_map="auto",
        quantization_config=bnb_config,
        #use_auth_token=hf_token
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        use_auth_token=hf_token
    )

    tokenizer.pad_token = tokenizer.eos_token

    model = create_peft_model(model,
                      r=r,
                      lora_alpha=lora_alpha,
                      lora_dropout=lora_dropout,
                      bias=bias,
                      task_type=task_type,
                      gradient_checkpointing=gradient_checkpointing,
                      bf16=bf16)

    #model = model.to_bettertransformer()


    with wandb.init(project='SFT_Training_dm',
                 entity='ft-llmmm',
                 job_type='SFT_training',
                 name=run_name) as run:

        output_dir = f'./{model_name}_{ds_name}/models'
        training_args = TrainingArguments(
            logging_dir = f'./{model_name}_{ds_name}/logs',
            output_dir= output_dir,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            bf16=bf16,  # Use BF16 if available
            fp16=fp16,
            learning_rate=lr,
            num_train_epochs=epochs,
            max_steps = max_steps,
            gradient_checkpointing=gradient_checkpointing,
            optim=optim,
            warmup_ratio=warmup_ratio,
            weight_decay = weight_decay,
            gradient_accumulation_steps=gradient_accumulation_steps,
            group_by_length=group_by_length,
            # logging strategies
            logging_strategy="steps",
            logging_steps=logging_steps,
            #evaluation_strategy = evaluation_strategy,
            save_strategy=save_strategy,
            evaluation_strategy = save_strategy,
            save_steps = save_steps,
            eval_steps = eval_steps,
#           log_level = 'error',
            hub_token=hf_token,
            report_to='wandb' if wandb_token else None,
            #report_to=None,
            #dataloader_num_workers = dataloader_num_workers,
            load_best_model_at_end=load_best_model_at_end,
            save_total_limit = save_total_limit,
            remove_unused_columns=False,
            disable_tqdm=False,
            torch_compile=torch_compile
            #max_grad_norm=0.3
        )

        collator=sft_collator(tokenizer)


        trainer = SFTTrainer(
            model,
            training_args,
            max_seq_length = max_seq_length,
            train_dataset = dataset['train'],
            eval_dataset = dataset['validation'],
            tokenizer=tokenizer,
            formatting_func=formatting_prompts_func,
            packing=False,
            data_collator=collator
            )

        # Start training
        trainer.train()

        outputs=trainer.evaluate()
        trainer.save_model(output_dir)

        run.log({"Performance-data": wandb.Table(dataframe=
                                                pd.DataFrame(outputs, index=["Performance"]))})

        model.push_to_hub('dhmeltzer/'+model_name)
        tokenizer.push_to_hub('dhmeltzer/'+model_name)

        trained_model_art=wandb.Artifact(model_name,type='model')
        trained_model_art.metadata={"hub_id":'dhmeltzer/'+model_name}

        #return trainer

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
model = training_function(model_id='meta-llama/Llama-2-7b-hf',
                  dataset=ds_wiki,
                  hf_token=hf_token,
                  wandb_token=wandb_token,
                  ds_name='wiki',
                  per_device_train_batch_size=1,
                  per_device_eval_batch_size=1,
                  gradient_checkpointing = False,
                  torch_compile=False)

Downloading (…)lve/main/config.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

Downloading (…)fetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

Downloading (…)of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]



Found 7 modules to quantize: ['k_proj', 'down_proj', 'v_proj', 'up_proj', 'gate_proj', 'o_proj', 'q_proj']
trainable params: 159,907,840 || all params: 3,660,320,768 || trainable%: 4.368683788535114


The BetterTransformer implementation does not support padding during training, as the fused kernels do not support attention masks. Beware that passing padded batched data during training may result in unexpected outputs. Please refer to https://huggingface.co/docs/optimum/bettertransformer/overview for more details.


In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
training_function(model_id='meta-llama/Llama-2-7b-hf',
                  dataset=ds_wiki,
                  hf_token=hf_token,
                  wandb_token=wandb_token,
                  ds_name='wiki',
                  per_device_train_batch_size=50,
                  per_device_eval_batch_size=100)

In [None]:
from google.colab import runtime
runtime.unassign()

# Inference

In [None]:
from transformers import GenerationConfig
from pprint import pprint
from peft import PeftModel
from transformers import pipeline

In [None]:
model_id = "meta-llama/Llama-2-7b-hf"

bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        quantization_config=bnb_config,
        #use_auth_token=hf_token
    )

model = PeftModel.from_pretrained(model,
    'dhmeltzer/Llama-2-7b-hf_wiki_r_64_alpha_16',
    is_trainable=False
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
generation_config = GenerationConfig(
    num_beams=1,
    max_length=1000,
    do_sample=True,
    temperature=1,
    top_k=100,
    top_p=.8
)
llama_tokenizer = AutoTokenizer.from_pretrained(model_id)
model.generation_config = generation_config
llama_tokenizer.pad_token = llama_tokenizer.eos_token

In [None]:
def inference_formatting(example):
    return f"### Human: {example}\n ### Assistant: "

SFT_QA_dataset = datasets.load_from_disk('../data/SFT_QA_ds')

ds_wiki = SFT_QA_dataset.filter(lambda x:
                                x['source']=='simple_wiki')

df_validation = pd.DataFrame(ds_wiki['validation'])

df_val_qus = df_validation['question'].map(lambda x: inference_formatting(x))
df_val_qus=df_val_qus.to_list()

text_generator = pipeline('text-generation',model=model,
                          tokenizer=llama_tokenizer,
                          batch_size=32)

In [None]:
df_val_qus[1]

'### Question: What is Ornitholestes and when did it live?\n ### Answer: '

In [None]:
inputs = llama_tokenizer(df_val_qus[1], return_tensors="pt").to('cuda');
outputs = model.generate(**inputs,max_length=2000);
pprint(llama_tokenizer.batch_decode(outputs, skip_special_tokens=True))

['### Question: What is Ornitholestes and when did it live?\n'
 ' ### Answer: \n'
 'Ornitholestes (the bird robber) is a genus of ornithischian dinosaur from '
 'the Late Jurassic of North America. It lived 150 million years ago. It is '
 'the oldest known dinosaur with a known nesting site. It had long forelimbs '
 'and very short hind limbs. Its skeleton suggests that it was a fast runner. '
 'Its tail is long and heavy. The tail is probably adapted for balance rather '
 'than speed. The bones of the arms and legs of Ornitholestes are the same '
 'shape as those of birds. This may mean that the Ornitholestes is closely '
 'related to birds. \n'
 'It lived during the Toarcian age, about 150 million years ago. It was one of '
 'the last dinosaurs to exist. The remains were found in the Morrison '
 'Formation of Wyoming. In 2010, fossils of this dinosaur were found in the '
 'Black Peak Formation of Montana. The fossils suggest that Ornitholestes may '
 'have had feathers, and that some

In [None]:
question = 'What is Bodmin?';

inputs = llama_tokenizer(f'### Question: {question}\n ### Answer: ', return_tensors="pt").to('cuda');
outputs = model.generate(**inputs,max_length=500);
pprint(llama_tokenizer.batch_decode(outputs, skip_special_tokens=True))

['### Question: What is Bodmin?\n'
 ' ### Answer: \n'
 'Bodmin is a town and civil parish in Cornwall, England. It is about north of '
 "Plymouth. The parish's population in 2011 was 21,347. Bodmin Moor is in the "
 'parish. The town has many old buildings, including the 17th century '
 'Jailhouse and the parish church, which was built in the 14th century. The '
 "town is the setting for Daphne du Maurier's novel Rebecca. The town has an "
 'industrial past. The first modern railway in Cornwall, the West Cornwall '
 'Railway, opened in 1834. It ran between Plymouth and Truro, with a station '
 'at Bodmin. The railway was later extended to Penzance. A major slate quarry '
 'opened in the area in the 1860s. It closed in 1980. Bodmin has a prison. The '
 'Bodmin and Wenford Railway operates steam trains on a section of the former '
 'West Cornwall Railway. Bodmin has a rugby union club, Bodmin RFC. Bodmin is '
 'also home to the Cornish Gorsedh Kernow. It is one of several bodies that '
 

In [None]:
predictions=text_generator(ds_inf['qus'],
                           max_length=100,
                           temperature=1,
                           top_k=50,
                           top_p=.9)