<a href="https://colab.research.google.com/github/david-meltzer/LLMs/blob/main/training/SFT_QA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [1]:
%cd drive/MyDrive/LLMs/Fine-tuning

/content/drive/MyDrive/LLMs/Fine-tuning


In [2]:
# installations

!pip install transformers -qqq
!pip install datasets --upgrade -qqq
!pip install apache-beam -qqq
!pip install wandb -qqq
!pip install accelerate -qqq
!pip install trl -qqq
!pip install bitsandbytes -qqq
!pip install peft -qqq

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m31.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m113.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m84.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.3/519.3 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import gc

import os
import torch
from google.colab import runtime
import pandas as pd

import datasets
import accelerate
import transformers
from transformers import (AutoTokenizer,
                          AutoModelForCausalLM,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          BitsAndBytesConfig,
                          TrainerCallback)
import bitsandbytes as bnb
import wandb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datetime import datetime

In [4]:
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

# Definitions

## Datasets

In [5]:
# setup collator
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Question: {example['question'][i]}\n\n### Answer: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts
response_template = " ### Answer:"
#sft_collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

def sft_collator(tokenizer, response_template = "\n\n### Answer:"):
    return DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)


def prepare_dataset(ds,
                    tokenizer,
                    formatting_func,
                    max_seq_length='auto'):

    if max_seq_length == 'auto':
        max_seq_length = tokenizer.model_max_length

    ds = ds.map(lambda x: {'QA':formatting_func(x)},
                batched=True)

    ds = ds.map(lambda x: {'tokens':tokenizer(x['QA'],
                                              return_length=False)})

    ds = ds.filter(lambda x: len(x['tokens']['input_ids'])<=max_seq_length)

    return ds

## Training

In [25]:
class PeftSavingCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(checkpoint_path)

        if "pytorch_model.bin" in os.listdir(checkpoint_path):
            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))

def prepare_model(checkpoint,
                 target_modules,
                 lora_rank=32,
                 lora_alpha=32,
                 lora_dropout=0.05,
                 bias="none",
                 task_type="CAUSAL_LM",
                 model_type = 'qlora',
                 extra_quant = True):

    if model_type not in {'lora','qlora','full'}:
        raise ValueError('Train type should be "lora", "qlora", or "full".')

    if model_type in {'lora','qlora'}:

        if model_type == 'qlora':

            nf4_config = BitsAndBytesConfig(load_in_4bit=True,
                                    bnb_4bit_quant_type="nf4",
                                    bnb_4bit_use_double_quant = extra_quant,
                                    bnb_4bit_compute_dtype=torch.bfloat16
                                    )

            model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                                quantization_config=nf4_config,
                                                device_map='auto',
                                                torch_dtype = torch.bfloat16
                                                )
        else:
            model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                                 load_in_8bit = extra_quant)

        model = prepare_model_for_kbit_training(model)

        lora_config = LoraConfig(
          r = lora_rank,
          lora_alpha = lora_alpha,
          target_modules = target_modules,
          lora_dropout = lora_dropout,
          bias = bias,
          task_type = task_type
          )

        model = get_peft_model(model, lora_config)
    else:
        model = AutoModelForCausalLM.from_pretrained(checkpoint)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model_name = checkpoint.split('/')[-1]

    if model_type in {'lora','qlora'}:
        model_name += f'_{model_type}'
        model_name += f'_r_{lora_rank}_a_{lora_alpha}'

    return model, tokenizer, model_name


def prepare_hyperparameters(model_name,
                            ds_name = 'combined',
                            evaluation_strategy = 'steps',
                            save_steps = .1,
                            eval_steps = .1,
                            logging_steps = 100,
                            log_level = 'error',
                            report_to = 'wandb',
                            num_train_epochs = 3,
                            lr = 5e-5,
                            warmup_steps = 50,
                            weight_decay = .01,
                            optim = 'adamw_torch_fused',
                            prec = 'fp16',
                            train_batch_size = 8,
                            eval_batch_size = 16,
                            grad_accum = 4,
                            grad_checkpoint = True,
                            group_by_length = True,
                            dataloader_num_workers = 2,
                            save_total_limit = 3,
                            wandb_report = 'SFT_training_dm'):

    training_args = TrainingArguments(
        logging_dir = f'./{model_name}_{ds_name}/logs',
        output_dir= f'./{model_name}_{ds_name}/models',
        evaluation_strategy = evaluation_strategy,
        save_strategy = evaluation_strategy,
        save_steps = save_steps,
        eval_steps = eval_steps,
        logging_steps = logging_steps,
        log_level = log_level,
        report_to = report_to,
        num_train_epochs = num_train_epochs,
        learning_rate = lr,
        warmup_steps = warmup_steps,
        weight_decay = weight_decay,
        optim = optim,
        fp16 = True if prec=='fp16' else False,
        bf16 = True if prec=='bf16' else False,
        per_device_train_batch_size = train_batch_size,
        per_device_eval_batch_size = eval_batch_size,
        gradient_accumulation_steps = grad_accum,
        gradient_checkpointing = grad_checkpoint,
        group_by_length = group_by_length,
        dataloader_num_workers = dataloader_num_workers,
        load_best_model_at_end=True,
        save_total_limit = save_total_limit,
        )

    if report_to == 'wandb':

        %env WANDB_PROJECT = wandb_report

        now = datetime.now()
        time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")

        run_name = f'{model_name}__time_stamp'

        training_args.run_name = run_name



    return training_args

def SFT_train(model,
              tokenizer,
              training_args,
              dataset,
              dataset_text_field='QA',
              formatting_func = formatting_prompts_func,
              ds_name = 'combined',
              max_seq_length = 'auto',
              packing = False,
              collator = sft_collator,
              preprocess_ds = False
              ):

    if max_seq_length == 'auto':
        max_seq_length = tokenizer.model_max_length

    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    collator = collator(tokenizer)

    if ds_name == 'ELI5':
        dataset = dataset.filter(lambda x:x['source']=='ELI5')
    elif ds_name == 'simple_wiki':
        dataset = dataset.filter(lambda x:x['source']=='simple_wiki')

    if preprocess_ds:
        dataset = prepare_dataset(dataset,tokenizer,formatting_func)

        sft_trainer = SFTTrainer(
            model,
            training_args,
            max_seq_length=max_seq_length,
            train_dataset=dataset['train'],
            eval_dataset=dataset['validation'],
            dataset_text_field=dataset_text_field,
            data_collator=collator,
            callbacks=[PeftSavingCallback()],
            packing=packing
            )

    else:
        sft_trainer = SFTTrainer(
            model,
            training_args,
            max_seq_length=max_seq_length,
            train_dataset=dataset['train'],
            eval_dataset=dataset['validation'],
            formatting_func=formatting_prompts_func,
            data_collator=collator,
            callbacks=[PeftSavingCallback()],
            packing=packing
            )

    sft_trainer.train()

    wandb.finish()

def full_training(
    checkpoint,
    dataset,
    target_modules=None,
    dataset_text_field="QA",
    max_seq_length = 'auto',
    ds_name = 'combined',
    lora_rank=32,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    model_type = 'qlora',
    extra_quant = True,
    evaluation_strategy = 'steps',
    save_steps = .1,
    eval_steps = .1,
    logging_steps = 100,
    log_level = 'error',
    report_to = 'wandb',
    num_train_epochs = 3,
    lr = 5e-5,
    warmup_steps = 50,
    weight_decay = .01,
    optim = 'adamw_torch_fused',
    prec = 'fp16',
    train_batch_size = 8,
    eval_batch_size = 16,
    grad_accum = 4,
    grad_checkpoint = True,
    group_by_length = True,
    dataloader_num_workers = 2,
    save_total_limit = 3,
    wandb_report = 'SFT_training_dm',
    packing = False,
    collator = sft_collator,
    preprocess_ds = False
    ):

    model, tokenizer, model_name = prepare_model(checkpoint,
                                                 target_modules,
                                                lora_rank=lora_rank,
                                                lora_alpha=lora_alpha,
                                                lora_dropout=lora_dropout,
                                                bias=bias,
                                                task_type=task_type,
                                                model_type = model_type,
                                                extra_quant = extra_quant)

    training_args = prepare_hyperparameters(model_name,
                            ds_name,
                            evaluation_strategy =evaluation_strategy,
                            save_steps = save_steps,
                            eval_steps = eval_steps,
                            logging_steps = logging_steps,
                            log_level = log_level,
                            report_to = report_to,
                            num_train_epochs = num_train_epochs,
                            lr = lr,
                            warmup_steps = warmup_steps,
                            weight_decay = weight_decay,
                            optim = optim,
                            prec = prec,
                            train_batch_size = train_batch_size,
                            eval_batch_size = eval_batch_size,
                            grad_accum = grad_accum,
                            grad_checkpoint = grad_checkpoint,
                            group_by_length = group_by_length,
                            dataloader_num_workers = dataloader_num_workers,
                            save_total_limit = save_total_limit,
                            wandb_report = wandb_report
                            )

    SFT_train(model,
              tokenizer,
              training_args,
              dataset = dataset,
              dataset_text_field=dataset_text_field,
              ds_name = ds_name,
              max_seq_length = max_seq_length,
              packing = packing,
              collator = collator,
              preprocess_ds = preprocess_ds
              )



# Datasets

## Download and Combine Datasets

In [7]:
#with wandb.init(project='ELI5_analysis',
#                 entity='ft-llmmm',
#                 job_type='training',
#                 name='SFT_training') as run:
#
#    artifact_wiki_QA = run.use_artifact('ft-llmmm/ELI5_analysis/simple_wiki_QA:v1', type='dataset')
#    artifact_dir_wiki_QA = artifact_wiki_QA.download()
#
#    artifact_ELI5 = run.use_artifact('ft-llmmm/ELI5_analysis/ELI5_cleaned:v2', type='dataset')
#    artifact_dir_ELI5 = artifact_ELI5.download()

In [8]:
artifact_dir_wiki_QA='./artifacts/simple_wiki_QA:v1'
artifact_dir_ELI5='./artifacts/ELI5_cleaned:v2'

In [9]:
simplewiki_QA_ds = datasets.load_dataset("csv",
                                         data_files={"train": artifact_dir_wiki_QA + '/simple_wiki_QA_combined_train.csv',
                                                    "test": artifact_dir_wiki_QA +  '/simple_wiki_QA_combined_test.csv',
                                                    "val": artifact_dir_wiki_QA + '/simple_wiki_QA_combined_validation.csv'
                                        }
)
simplewiki_QA_ds = simplewiki_QA_ds.remove_columns(['id','system_message','prompt_template'])
simplewiki_QA_ds = simplewiki_QA_ds.rename_columns({'trunc_text':'answer'})

simplewiki_QA_ds['validation'] = simplewiki_QA_ds['val']
del simplewiki_QA_ds['val']

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Generating val split: 0 examples [00:00, ? examples/s]

In [10]:
for split in simplewiki_QA_ds:
    dset_source = datasets.Dataset.from_dict({'source':['simple_wiki']*len(simplewiki_QA_ds[split])})
    simplewiki_QA_ds[split] = datasets.concatenate_datasets([simplewiki_QA_ds[split],dset_source],axis=1)

In [11]:
ELI5_ds = datasets.load_from_disk(f'{artifact_dir_ELI5}/ds_SFT')
ELI5_ds = ELI5_ds.flatten()
ELI5_ds = ELI5_ds.remove_columns(['document','q_id','title','selftext','subreddit','url','title_urls','selftext_urls','answers_urls','pref_idxs','dupl_scores_idxs','qu_emb',
                                  'answers.a_id','answers.fkg','answers.fre', 'answers.score'])
ELI5_ds = ELI5_ds.map(lambda x: {'answers.text':list(x['answers.text'])})

ELI5_ds = ELI5_ds.with_format("pandas").map(lambda df:
                                                df.explode("answers.text"),
                                                batched=True)

ELI5_ds = ELI5_ds.with_format(None)

ELI5_ds = ELI5_ds.remove_columns(['__index_level_0__'])
ELI5_ds = ELI5_ds.rename_columns({'answers.text':'answer',
                                  'title_body':'question'})

In [12]:
for split in ELI5_ds:
    dset_source = datasets.Dataset.from_dict({'source':['ELI5']*len(ELI5_ds[split])})
    ELI5_ds[split] = datasets.concatenate_datasets([ELI5_ds[split],dset_source],axis=1)

In [13]:
SFT_QA_dataset = datasets.DatasetDict()

for split in ['train','validation','test']:

    SFT_QA_dataset[split] = datasets.concatenate_datasets([simplewiki_QA_ds[split],
                                                ELI5_ds[split]])

In [14]:
SFT_QA_dataset = SFT_QA_dataset.shuffle(seed=12321)

In [24]:
SFT_QA_dataset.save_to_disk('./data/SFT_QA_ds')

Saving the dataset (0/1 shards):   0%|          | 0/75742 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2133 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3536 [00:00<?, ? examples/s]

In [None]:
now = datetime.now()
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
with wandb.init(project='ELI5_analysis',
                entity='ft-llmmm',
                job_type='upload_data',
                name=f'combined_dataset_{time_stamp}') as run:

    clean_data_art = wandb.Artifact('combined_dataset', 'dataset')
    clean_data_art.add_dir('./data/SFT_QA_ds')
    run.log_artifact(clean_data_art)

## Instruction Formatting

In [16]:
tok = AutoTokenizer.from_pretrained('distilgpt2')

Downloading (…)lve/main/config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [18]:
GPT2_QA_tokenized = prepare_dataset(SFT_QA_dataset,tok,formatting_prompts_func)

Map:   0%|          | 0/75742 [00:00<?, ? examples/s]

Map:   0%|          | 0/2133 [00:00<?, ? examples/s]

Map:   0%|          | 0/3536 [00:00<?, ? examples/s]

Map:   0%|          | 0/75742 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1068 > 1024). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/2133 [00:00<?, ? examples/s]

Map:   0%|          | 0/3536 [00:00<?, ? examples/s]

Filter:   0%|          | 0/75742 [00:00<?, ? examples/s]

Filter:   0%|          | 0/2133 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3536 [00:00<?, ? examples/s]

In [19]:
GPT2_QA_tokenized.save_to_disk('./data/GPT2_QA_tokenized')

Saving the dataset (0/1 shards):   0%|          | 0/75032 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2119 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3485 [00:00<?, ? examples/s]

In [20]:
now = datetime.now()
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
with wandb.init(project='ELI5_analysis',
                entity='ft-llmmm',
                job_type='upload_data',
                name=f'GPT2_QA_tokenized_dataset_{time_stamp}') as run:

    clean_data_art = wandb.Artifact('GPT2_QA_tokenized', 'dataset')
    clean_data_art.add_dir('./data/GPT2_QA_tokenized')
    run.log_artifact(clean_data_art)

[34m[1mwandb[0m: Currently logged in as: [33mdmeltzer[0m ([33mft-llmmm[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Adding directory to artifact (./data/GPT2_QA_tokenized)... Done. 2.0s


In [23]:
GPT2_QA_tokenized['train'][0]['tokens'].keys()

dict_keys(['attention_mask', 'input_ids'])

# Training

# Experiments

In [38]:
full_training('distilgpt2',dataset_dict=ds_dict,prec=None, ds_name = 'ELI5')

Using pad_token, but it is not set yet.


env: WANDB_PROJECT=wandb_report


Using pad_token, but it is not set yet.


Map:   0%|          | 0/45742 [00:00<?, ? examples/s]

Map:   0%|          | 0/1133 [00:00<?, ? examples/s]

RuntimeError: ignored