<a href="https://colab.research.google.com/github/david-meltzer/LLMs/blob/main/SFT_QA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [1]:
%cd drive/MyDrive/LLMs/Fine-tuning

/content/drive/MyDrive/LLMs/Fine-tuning


In [2]:
# installations

!pip install transformers -qqq
!pip install datasets -qqq
!pip install apache-beam -qqq
!pip install wandb -qqq
!pip install accelerate -qqq
!pip install trl -qqq
!pip install bitsandbytes -qqq
!pip install peft -qqq

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m56.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m29.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m106.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m76.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m519.3/519.3 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
import gc

import os
import torch
from google.colab import runtime
import pandas as pd

import datasets
import accelerate
import transformers
from transformers import (AutoTokenizer,
                          AutoModelForCausalLM,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          BitsAndBytesConfig,
                          TrainerCallback)
import bitsandbytes as bnb
import wandb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datetime import datetime


wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

# Datasets

In [4]:
#with wandb.init(project='ELI5_analysis',
#                 entity='ft-llmmm',
#                 job_type='training',
#                 name='SFT_training') as run:
#
#    artifact_wiki_QA = run.use_artifact('ft-llmmm/ELI5_analysis/simple_wiki_QA:v1', type='dataset')
#    artifact_dir_wiki_QA = artifact_wiki_QA.download()
#
#    artifact_ELI5 = run.use_artifact('ft-llmmm/ELI5_analysis/ELI5_cleaned:v2', type='dataset')
#    artifact_dir_ELI5 = artifact_ELI5.download()

In [5]:
artifact_dir_wiki_QA='./artifacts/simple_wiki_QA:v1'
artifact_dir_ELI5='./artifacts/ELI5_cleaned:v2'

In [6]:
simplewiki_QA_ds = datasets.load_dataset("csv",
                                         data_files={"train": artifact_dir_wiki_QA + '/simple_wiki_QA_combined_train.csv',
                                                    "test": artifact_dir_wiki_QA +  '/simple_wiki_QA_combined_test.csv',
                                                    "val": artifact_dir_wiki_QA + '/simple_wiki_QA_combined_validation.csv'
                                        }
)
simplewiki_QA_ds = simplewiki_QA_ds.remove_columns(['id','system_message','prompt_template'])
simplewiki_QA_ds = simplewiki_QA_ds.rename_columns({'trunc_text':'answer'})

simplewiki_QA_ds['validation'] = simplewiki_QA_ds['val']
del simplewiki_QA_ds['val']

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Generating val split: 0 examples [00:00, ? examples/s]

In [7]:
ELI5_ds = datasets.load_from_disk(f'{artifact_dir_ELI5}/ds_SFT')
ELI5_ds = ELI5_ds.flatten()
ELI5_ds = ELI5_ds.remove_columns(['document','q_id','title','selftext','subreddit','url','title_urls','selftext_urls','answers_urls','pref_idxs','dupl_scores_idxs','qu_emb',
                                  'answers.a_id','answers.fkg','answers.fre', 'answers.score'])
ELI5_ds = ELI5_ds.map(lambda x: {'answers.text':list(x['answers.text'])})

ELI5_ds = ELI5_ds.with_format("pandas").map(lambda df:
                                                df.explode("answers.text"),
                                                batched=True)

ELI5_ds = ELI5_ds.with_format(None)

ELI5_ds = ELI5_ds.remove_columns(['__index_level_0__'])
ELI5_ds = ELI5_ds.rename_columns({'answers.text':'answer',
                                  'title_body':'question'})

In [8]:
SFT_QA_dataset = datasets.DatasetDict()

for split in ['train','validation','test']:

    SFT_QA_dataset[split] = datasets.concatenate_datasets([simplewiki_QA_ds[split],
                                                ELI5_ds[split]])

In [9]:
SFT_QA_dataset = SFT_QA_dataset.shuffle(seed=12321)

In [10]:
SFT_QA_dataset.save_to_disk('./data/SFT_QA_ds')

Saving the dataset (0/1 shards):   0%|          | 0/75742 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/2133 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3536 [00:00<?, ? examples/s]

In [11]:
ds_dict = {}
ds_dict['ELI5'] = ELI5_ds
ds_dict['wiki'] = simplewiki_QA_ds
ds_dict['combined'] = SFT_QA_dataset

# Training

## Definitions

In [12]:
class PeftSavingCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        kwargs["model"].save_pretrained(checkpoint_path)

        if "pytorch_model.bin" in os.listdir(checkpoint_path):
            os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))

In [26]:
# setup collator
def formatting_prompts_func(example):
    output_texts = []
    for i in range(len(example['question'])):
        text = f"### Question: {example['question'][i]}\n ### Answer: {example['answer'][i]}"
        output_texts.append(text)
    return output_texts
response_template = " ### Answer:"
#sft_collator = DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

def sft_collator(tokenizer, response_template = " ### Answer:"):
    return DataCollatorForCompletionOnlyLM(response_template=response_template, tokenizer=tokenizer)

In [27]:
def prepare_model(checkpoint,
                 target_modules,
                 lora_rank=32,
                 lora_alpha=32,
                 lora_dropout=0.05,
                 bias="none",
                 task_type="CAUSAL_LM",
                 train_type = 'qlora',
                 extra_quant = True):

    if train_type not in {'lora','qlora','full'}:
        raise ValueError('Train type should be "lora", "qlora", or "full".')

    if train_type in {'lora','qlora'}:

        if train_type == 'qlora':

            nf4_config = BitsAndBytesConfig(load_in_4bit=True,
                                    bnb_4bit_quant_type="nf4",
                                    bnb_4bit_use_double_quant = extra_quant,
                                    bnb_4bit_compute_dtype=torch.bfloat16
                                    )

            model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                                quantization_config=nf4_config,
                                                device_map='auto',
                                                torch_dtype = torch.bfloat16
                                                )
        else:
            model = AutoModelForCausalLM.from_pretrained(checkpoint,
                                                 load_in_8bit = extra_quant)

        model = prepare_model_for_kbit_training(model)

        lora_config = LoraConfig(
          r = lora_rank,
          lora_alpha = lora_alpha,
          target_modules = target_modules,
          lora_dropout = lora_dropout,
          bias = bias,
          task_type = task_type
          )

        model = get_peft_model(model, lora_config)
    else:
        model = AutoModelForCausalLM.from_pretrained(checkpoint)



    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token


    return model, tokenizer


In [28]:
def SFT_train(checkpoint,
              dataset_dict = ds_dict,
              ds_name = 'combined',
              max_seq_length = 512,
              target_modules = None,
              lora_rank=32,
              lora_alpha=32,
              lora_dropout=0.05,
              bias="none",
              task_type="CAUSAL_LM",
              train_type = 'qlora',
              extra_quant = True,
              evaluation_strategy = "steps",
              log_level = 'error',
              report_to = 'wandb',
              logging_steps = 100,
              save_strategy = 'steps',
              save_steps = .1,
              num_train_epochs = 3,
              lr = 5e-5,
              warmup_steps = 50,
              weight_decay = .01,
              optim = 'adamw_torch_fused',
              prec = 'fp16',
              grad_accum = 4,
              grad_checkpoint = True,
              group_by_length = True,
              dataloader_num_workers = 4,
              save_total_limit = 3
              ):

    dataset = ds_dict[ds_name]
    model_name = checkpoint.split('/')[-1]

    if report_to == 'wandb':

        %env WANDB_PROJECT='SFT_training_dm'


        run_name = f'{model_name}, r={lora_rank}, alpha = {lora_alpha}'

        now = datetime.now()
        time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")

        run_name += '_'+time_stamp

    training_args = TrainingArguments(
        evaluation_strategy = evaluation_strategy,
        logging_steps = logging_steps,
        save_strategy = save_strategy,
        log_level = log_level,
        report_to = report_to,
        num_train_epochs = num_train_epochs,
        learning_rate = lr,
        warmup_steps = warmup_steps,
        weight_decay = weight_decay,
        optim = optim,
        fp16 = True if prec=='fp16' else False,
        bf16 = True if prec=='bf16' else False,
        gradient_accumulation_steps = grad_accum,
        gradient_checkpointing = grad_checkpoint,
        group_by_length = group_by_length,
        dataloader_num_workers = dataloader_num_workers,
        run_name=run_name,
        load_best_model_at_end=True,
        save_total_limit = save_total_limit,
        logging_dir = f'./{model_name}_{ds_name}/logs',
        output_dir= f'./{model_name}_{ds_name}/models',
        )

    model, tokenizer = prepare_model(checkpoint,
                                     target_modules,
                                     lora_rank=lora_rank,
                                     lora_alpha=lora_alpha,
                                     lora_dropout=lora_dropout,
                                     bias=bias,
                                     task_type=task_type,
                                     train_type = train_type,
                                     extra_quant = extra_quant)

    if grad_checkpoint:
        model.gradient_checkpointing_enable()

    collator = sft_collator(tokenizer)

    sft_trainer = SFTTrainer(
        model,
        training_args,
        max_seq_length=max_seq_length,
        train_dataset=dataset['train'],
        eval_dataset=dataset['validation'],
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        callbacks=[PeftSavingCallback()]
        )

    sft_trainer.train()

    wandb.finish()

In [None]:
SFT_train('distilgpt2',prec=None,ds_name = 'wiki')

env: WANDB_PROJECT='SFT_training_dm'


Using pad_token, but it is not set yet.
Using pad_token, but it is not set yet.


Map:   0%|          | 0/30000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]



{'loss': 3.312, 'learning_rate': 4.909453096704093e-05, 'epoch': 0.11}
{'eval_loss': 3.1178946495056152, 'eval_runtime': 13.6785, 'eval_samples_per_second': 73.107, 'eval_steps_per_second': 9.138, 'epoch': 0.11}
{'loss': 3.207, 'learning_rate': 4.7283592901122785e-05, 'epoch': 0.21}
{'eval_loss': 3.069530725479126, 'eval_runtime': 13.3569, 'eval_samples_per_second': 74.868, 'eval_steps_per_second': 9.358, 'epoch': 0.21}
{'loss': 3.1533, 'learning_rate': 4.547265483520464e-05, 'epoch': 0.32}
{'eval_loss': 3.0456104278564453, 'eval_runtime': 13.887, 'eval_samples_per_second': 72.01, 'eval_steps_per_second': 9.001, 'epoch': 0.32}
{'loss': 3.138, 'learning_rate': 4.366171676928649e-05, 'epoch': 0.43}
{'eval_loss': 3.0291614532470703, 'eval_runtime': 13.2463, 'eval_samples_per_second': 75.493, 'eval_steps_per_second': 9.437, 'epoch': 0.43}
{'loss': 3.1143, 'learning_rate': 4.185077870336835e-05, 'epoch': 0.53}
{'eval_loss': 3.0189785957336426, 'eval_runtime': 13.9017, 'eval_samples_per_seco



{'eval_loss': 3.0121889114379883, 'eval_runtime': 13.5573, 'eval_samples_per_second': 73.761, 'eval_steps_per_second': 9.22, 'epoch': 0.64}


In [30]:
tokenizer = AutoTokenizer.from_pretrained('distilgpt2')

In [33]:
tokenizer.encode(' ### Answer:')

[44386, 23998, 25]

In [35]:
tokenizer.decode(torch.tensor([21017, 18233,    25,   317,  2043,    32,   329, 12180,   257, 34875,
         8224,    11,   788,  4585,   616,  6478,   338,  6478,   262,  3738,
          488,  1585,   290,  9482,   465,  3850,   618,   339,  8556,   502,
           30,   198, 18690,    11,  1309,   502,  1577,   345,   617,  4469,
           13,   314,  3111,   329,   257,  3236,    11, 28061,  4009,    13,
          314,   836,   470,   765,   284,  1438,  3891,    11,   475,   345,
         1053,  4753,  2982,   286,   606,    13,  1119,   821, 48583,   287,
        10598,    11,   484,   821,  1107,   656,  8509,   290,  8237,   357,
         4360,   691,   262,  4409,   651,   284,  4144,   262,  8237,     8,
          290,   484,  2192,   423,   257,  8478,   287,   534,  3240,    13,
         7683,   812,  2084,   314,  5717,   257, 34875,  8224,   351,   616,
         9749,    13,   314,  1908,   616,  6478,   357,   732,  1183,   869,
          683,   366,  2348,  4943,   257,  1351,   286,  6957,  1243,   661,
          287,   616,  4009,   547,  1804,   326,   314,  1807,   547,  2642,
           13, 33495,   484,   389, 21608,   661,   503,   286,   511,  1204,
          338, 10653,   416,  6301,   644,  6867,   284,   651,   503,   286,
         7356,  1479,  4116,    11,   475,   314,   423,   340,   319,   922,
         4934,   326,   484,   836,   470,  1682,   670,    13,   314,   635,
         1234,   257,  4866,   286,   262,  1351,   510,   319,   262,  1957,
         4928,  3420,   357, 47057,  7224,   340,   618,   484,   467,   287,
          737,   314,  1612,    11,   314,  1422,   470,   765,  2687,  2073,
          284,  7030,   511,  1637,    11,   826,    30,   887,   314,  2630,
          606,   287,  9133,    11,   523,   691,   661,   508,   547,  1107,
         4451,   561,   307,  1498,   284,  1100,   606,    13,  3894,    11,
          706,   326,  1243,  1611,   286,  1392,   503,   286,  1021,    13,
        43141,   925,   257,  7684,   286,  9088,   286,   262,  1351,    11,
         2130,  2073, 14251,   340,   656,  2679,   290,   925,  1576,  9088,
          326,  7288,   714,  1100,   340,    11,   290,   878,   314,  2993,
          644,   373,  5836,   340,   550,  3750, 14416,    13, 20090,  2582,
          340,   373,   477,  2687,   373,  3375,   546,    13,   978,   550,
         7891,   287,  3638,   351,   465,  6478,   546,   703,   284,  5412,
          477,   428,    11,   290,   340,  2492,   470, 22655,   922,   329,
          502,    13,  1406,   314,  1908,   257,  3850,   284,   978,   338,
         6478,   357,  1616,   338,   869,   683, 50191,   737,   314,  4893,
          644,   262,  1351,   373,   477,   546,    13,   314,  8072,   683,
          314,  1422,   470,   423,   597,  2761,   351,   683,  7620,   393,
          674,  4009,   287,  2276,    11,   314,   655,  1807,   356,   815,
          651,  5755,   286,   262,   661,   508,   547,  6301,   925,   510,
        41746,   684,    13,   314,  1422,   470,   765,   284,   651,   287,
          597,  5876,    11,   345,   760,    30,  2735,    11,   314,   815,
         3068,   326,   612,   389,   257,  1256,   286, 14923,  1016,  1088,
          546,   428,  3516,    13,  4525,    11, 17166, 14923,    13,  4525,
           11,  1714,  4671,   290,  3404,  1611,   286, 14923,    13,   314,
          836,   470,   760,   611,   597,   286,   340,   338,  2081,    11,
          475,   339,  1107,  1595,   470,  1283,   588,   257,  6635,   510,
         5646,  1611,   286,  3516,    11,   345,   760,    30, 21836,    11,
         5729, 50191,  1422,   470,   588,   616,  3850,   780,   339, 12284,
          326,   314,  1282,  1561,   351,   683,   379,   262,  6355, 10043,
           13,  3894,    11,   326,   338,   257,   890,   835,   422,   810,
          314,  2107,    11,   290,   314,  2492,   470,  1654,   314,  2936,
         3338,  1016,   612,  6949,    11,   345,   760,    30,   314,  2982,
          484,  1444]))

'### Question: AITA for filing a whistleblower complaint, then calling my boss\'s boss the Antichrist and burning his letter when he threatened me?\nOk, let me give you some background. I worked for a huge, multinational organization. I don\'t want to name names, but you\'ve definitely heard of them. They\'re headquartered in Rome, they\'re really into bread and wine (but only the employees get to drink the wine) and they probably have a branch in your town. Three years ago I filed a whistleblower complaint with my employer. I sent my boss (we\'ll call him "Al") a list of 95 things people in my organization were doing that I thought were wrong. Mostly they are cheating people out of their life\'s savings by selling what amounts to get out of jail free cards, but I have it on good authority that they don\'t actually work. I also put a copy of the list up on the local church door (everyone sees it when they go in). I mean, I didn\'t want anyone else to waste their money, right? But I wro

In [None]:
### Question: AITA for filing a whistleblower complaint, then calling my boss's boss the Antichrist and burning his letter when he threatened me?
Ok, let me give you some background. I worked for a huge, multinational organization. I don't want to name names, but you've definitely heard of them. They're headquartered in Rome, they're really into bread and wine (but only the employees get to drink the wine) and they probably have a branch in your town. Three years ago I filed a whistleblower complaint with my employer. I sent my boss (we'll call him "Al") a list of 95 things people in my organization were doing that I thought were wrong. Mostly they are cheating people out of their life's savings by selling what amounts to get out of jail free cards, but I have it on good authority that they don't actually work. I also put a copy of the list up on the local church door (everyone sees it when they go in). I mean, I didn't want anyone else to waste their money, right? But I wrote them in Latin, so only people who were really smart would be able to read them. Well, after that things kind of got out of hand. Somebody made a bunch of copies of the list, someone else translated it into German and made enough copies that everybody could read it, and before I knew what was happening it had gone viral. Pretty soon it was all anyone was talking about. Al had gotten in touch with his boss about how to handle all this, and it wasn't sounding good for me. So I sent a letter to Al's boss (let's call him Giovanni). I explained what the list was all about. I promised him I didn't have any problems with him personally or our organization in general, I just thought we should get rid of the people who were selling made up pardons. I didn't want to get in any trouble, you know? Now, I should mention that there are a lot of rumors going around about this guy. Like, nasty rumors. Like, sex parties and stuff kind of rumors. I don't know if any of it's true, but he really doesn't seem like a totally upstanding kind of guy, you know? Anyway, apparently Giovanni didn't like my letter because he demanded that I come talk with him at the corporate headquarters. Well, that's a long way from where I live, and I wasn't sure I felt safe going there anyway, you know? I heard they called