<a href="https://colab.research.google.com/github/david-meltzer/LLMs/blob/main/training/david/SFT/colab_SFT_QA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Dependencies

In [1]:
from google.colab import drive
drive.mount('/content/drive')
%cd drive/MyDrive/LLMs/Fine-tuning/SFT

# installations
#!pip install detoxify

!pip install peft==0.4.0
!pip install bitsandbytes==0.41.1
!pip install safetensors>=0.3.1
!pip install trl
!pip install wandb
!pip install tokenizers>=0.13.3
!pip install -U transformers
!pip install accelerate==0.21.0
!pip install datasets
!pip install -U torch
!pip install evaluate
!pip install rouge_score
!pip install nltk
!pip install bert_score

!python -c "import torch; assert torch.cuda.get_device_capability()[0] >= 8, 'Hardware not supported for Flash Attention'"
!pip install ninja packaging
!pip install flash-attn --no-build-isolation

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/LLMs/Fine-tuning/SFT


In [2]:
import gc

import os
import torch
from google.colab import runtime
import pandas as pd

import datasets
import accelerate
import transformers
from transformers import (AutoTokenizer,
                          AutoModelForCausalLM,
                          Trainer,
                          TrainingArguments,
                          DataCollatorForLanguageModeling,
                          BitsAndBytesConfig,
                          TrainerCallback)
import bitsandbytes as bnb
import wandb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datetime import datetime
from huggingface_hub import login

from peft.tuners.lora import LoraLayer
import evaluate

#from getpass import getpass
#hf_token = getpass()
#wandb_token = getpass()

In [3]:
from getpass import getpass

# Ask for both credentials interactively so they are never written into the
# notebook source. The prompts are labelled because the two hidden inputs are
# otherwise indistinguishable.
hf_token = getpass('Hugging Face token: ')
wandb_token = getpass('Weights & Biases API key: ')

login(hf_token)
wandb.login(key=wandb_token)

··········
··········
Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.


[34m[1mwandb[0m: Currently logged in as: [33mdmeltzer[0m ([33mft-llmmm[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


True

# Definitions

## Datasets

In [4]:
# setup collator

def formatting_prompts_func(example):
    """Render a batch of question/answer pairs as SFT training strings.

    Args:
        example: batch mapping with parallel ``'question'`` and ``'answer'``
            sequences.

    Returns:
        A list with one ``"### Human: ...\\n ### Assistant: ..."`` string per
        question, in input order.
    """
    questions = example['question']
    return [
        f"### Human: {questions[idx]}\n ### Assistant: {example['answer'][idx]}"
        for idx in range(len(questions))
    ]

def sft_collator(tokenizer, response_template = " ### Assistant:"):
    """Build a completion-only data collator for the given tokenizer.

    Args:
        tokenizer: tokenizer handed to ``DataCollatorForCompletionOnlyLM``.
        response_template: marker string passed to the collator as
            ``response_template``.

    Returns:
        The configured ``DataCollatorForCompletionOnlyLM`` instance.
    """
    collator = DataCollatorForCompletionOnlyLM(
        response_template=response_template,
        tokenizer=tokenizer,
    )
    return collator

def combine_question_answer(ds,formatting_func):
    """Add a ``'QA'`` column holding the formatted prompt, unless it already exists.

    Args:
        ds: dataset dict with a ``'train'`` split; must support ``.map``.
        formatting_func: batched function returning one formatted string per
            example (e.g. ``formatting_prompts_func``).

    Returns:
        ``ds`` unchanged when ``'QA'`` is already a column of the train split,
        otherwise the mapped dataset with the new ``'QA'`` column.
    """
    # Test against the column names. `'QA' in ds['train']` checks membership
    # among the rows, not the columns, so the guard never skipped the map.
    if 'QA' not in ds['train'].column_names:
        ds = ds.map(lambda x: {'QA':formatting_func(x)},
                    batched=True)
    return ds

def prepare_dataset(ds,
                    tokenizer,
                    formatting_func,
                    max_seq_length='auto'):
    """Format, tokenize and length-filter a QA dataset.

    Args:
        ds: dataset dict passed through ``combine_question_answer``.
        tokenizer: callable tokenizer; its ``model_max_length`` is used as the
            limit when ``max_seq_length`` is ``'auto'``.
        formatting_func: batched prompt formatter forwarded to
            ``combine_question_answer``.
        max_seq_length: maximum number of input ids to keep, or ``'auto'``.

    Returns:
        The dataset with a ``'tokens'`` column, restricted to examples whose
        ``input_ids`` length does not exceed the limit.
    """
    limit = tokenizer.model_max_length if max_seq_length == 'auto' else max_seq_length

    def _tokenize(row):
        # One example per call (the map below is not batched).
        return {'tokens': tokenizer(row['QA'], return_length=False)}

    def _fits(row):
        return len(row['tokens']['input_ids']) <= limit

    with_text = combine_question_answer(ds, formatting_func)
    return with_text.map(_tokenize).filter(_fits)

# Form Datasets

## Download datasets

In [None]:
# Fetch the two source datasets (simple-wiki QA and cleaned ELI5) as W&B
# dataset artifacts; the directory returned by each download() is kept for the
# loading cells below.
with wandb.init(project='ELI5_analysis',
                 entity='ft-llmmm',
                 job_type='training',
                 name='SFT_training') as run:

    artifact_wiki_QA = run.use_artifact('ft-llmmm/ELI5_analysis/simple_wiki_QA:latest', type='dataset')
    artifact_dir_wiki_QA = artifact_wiki_QA.download()

    artifact_ELI5 = run.use_artifact('ft-llmmm/ELI5_analysis/ELI5_cleaned:latest', type='dataset')
    artifact_dir_ELI5 = artifact_ELI5.download()

[34m[1mwandb[0m: Currently logged in as: [33mdmeltzer[0m ([33mft-llmmm[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: \ 1 of 3 files downloaded...[34m[1mwandb[0m:   3 of 3 files downloaded.  
[34m[1mwandb[0m: Downloading large artifact ELI5_cleaned:latest, 1379.04MB. 24 files... 
[34m[1mwandb[0m:   24 of 24 files downloaded.  
Done. 0:0:32.0


VBox(children=(Label(value='0.002 MB of 0.010 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.164206…

In [None]:
# Point at already-downloaded artifact versions so the notebook can be resumed
# without re-downloading.
# NOTE(review): this overrides whatever `:latest` resolved to in the download
# cell above; confirm v4 / v5 are the intended versions.
artifact_dir_wiki_QA='./artifacts/simple_wiki_QA:v4'
artifact_dir_ELI5='./artifacts/ELI5_cleaned:v5'

In [None]:
# Load the three simple-wiki QA CSV splits from the artifact directory.
simplewiki_QA_ds = datasets.load_dataset(
    "csv",
    data_files={
        "train": f"{artifact_dir_wiki_QA}/simple_wiki_QA_combined_train.csv",
        "test": f"{artifact_dir_wiki_QA}/simple_wiki_QA_combined_test.csv",
        "val": f"{artifact_dir_wiki_QA}/simple_wiki_QA_combined_validation.csv",
    },
)

# Drop bookkeeping columns and expose the text under the shared 'answer' name.
simplewiki_QA_ds = (
    simplewiki_QA_ds
    .remove_columns(['id','system_message','prompt_template'])
    .rename_columns({'trunc_text':'answer'})
)

# Rename the held-out split from 'val' to 'validation'.
simplewiki_QA_ds['validation'] = simplewiki_QA_ds.pop('val')

In [None]:
simplewiki_QA_ds['train'][0]

{'Unnamed: 0': 0,
 'question': 'What was the impact of the 2005 Kashmir Earthquake?',
 'answer': 'The 2005 Kashmir Earthquake (also known as the Great Pakistan earthquake) was a major earthquake centered in Pakistan-administered Kashmir and in Khyber Pakhtunkhwa near the city of Muzaffarabad. It occurred at 08:52:37 Pakistan Standard Time (03:52:37 UTC) on 8 October 2002 87,351 peoples died, 75,266 peoples injuried, and 2.4 million people were left homeless. Kashmir, Pakistan, and Southern part of India were all affected.',
 'source': 'simple_wiki'}

In [None]:
# Tag every simple-wiki row with its provenance.
for split in simplewiki_QA_ds:
    simplewiki_QA_ds[split] = simplewiki_QA_ds[split].add_column(
        'source', ['simple_wiki'] * len(simplewiki_QA_ds[split]))

In [None]:
# Load the cleaned ELI5 dataset and reshape it to one (question, answer) row per answer.
ELI5_ds = datasets.load_from_disk(f'{artifact_dir_ELI5}/ds_SFT')
# flatten() exposes the nested answer fields as 'answers.*' columns.
ELI5_ds = ELI5_ds.flatten()
# Keep only the question text ('title_body') and the answer texts.
ELI5_ds = ELI5_ds.remove_columns(['document','q_id','title','selftext','subreddit','url','title_urls','selftext_urls','answers_urls','pref_idxs','dupl_scores_idxs','qu_emb',
                                  'answers.a_id','answers.fkg','answers.fre','answers.score'])
# Materialise the answers as a plain Python list before exploding.
ELI5_ds = ELI5_ds.map(lambda x: {'answers.text':list(x['answers.text'])})

# Explode in pandas format: each answer in the list becomes its own row.
ELI5_ds = ELI5_ds.with_format("pandas").map(lambda df:
                                                df.explode("answers.text"),
                                                batched=True)

# Back to the default (Python object) output format.
ELI5_ds = ELI5_ds.with_format(None)

# '__index_level_0__' is presumably the pandas index carried over by explode -- TODO confirm.
ELI5_ds = ELI5_ds.remove_columns(['__index_level_0__'])
# Align column names with the simple-wiki dataset.
ELI5_ds = ELI5_ds.rename_columns({'answers.text':'answer',
                                  'title_body':'question'})

In [None]:
# Tag every ELI5 row with its provenance.
for split in ELI5_ds:
    ELI5_ds[split] = ELI5_ds[split].add_column(
        'source', ['ELI5'] * len(ELI5_ds[split]))

## Detoxify ELI5

In [None]:
# Install Detoxify and load its 'unbiased' checkpoint for toxicity scoring.
# NOTE(review): upgrading torch / transformers mid-session may require a
# runtime restart before the new versions are actually used -- confirm.
!pip install detoxify
!pip install -U torch
!pip install -U transformers
from detoxify import Detoxify

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
detoxify_model = Detoxify('unbiased')
detoxify_model.model.to(device)

In [None]:
# Score every ELI5 answer with Detoxify, 64 answers per call. The mapping
# returned by predict() is merged into the dataset as new columns
# (presumably one per toxicity metric -- see the filter cell below).
ELI5_ds = ELI5_ds.map(lambda x: detoxify_model.predict(x['answer']),
                                                  batched=True,batch_size=64
                      )

In [None]:
# Checkpoint the scored dataset so the Detoxify pass need not be repeated.
ELI5_ds.save_to_disk('../data/ELI5_toxic_scores')

In [None]:
# Resume point: reload the scored ELI5 dataset from disk.
ELI5_ds = datasets.load_from_disk('../data/ELI5_toxic_scores')

In [None]:
# Toxicity score columns that must all stay at or below the threshold.
metrics=['toxicity', 'severe_toxicity',
         'obscene', 'identity_attack',
         'insult', 'threat', 'sexual_explicit']

# Keep an answer only if every toxicity score is <= 0.1.
ELI5_non_toxic = ELI5_ds.filter(lambda x: all(x[metric]<=.1
                                              for metric in metrics))

# Drop the score columns but keep 'source'. Removing it discarded the 'ELI5'
# provenance label added earlier, so filters such as `x['source']=='ELI5'`
# matched nothing on the combined dataset.
ELI5_non_toxic = ELI5_non_toxic.remove_columns([col for col in ELI5_non_toxic['train'].features if
                                                col not in ['answer','question','source']])

ELI5_non_toxic.save_to_disk('../data/ELI5_non_toxic')

## Combine Datasets

In [None]:
# Stack the simple-wiki and filtered ELI5 rows, split by split.
ELI5_non_toxic = datasets.load_from_disk('../data/ELI5_non_toxic')
SFT_QA_dataset = datasets.DatasetDict()

for split in ('train', 'validation', 'test'):
    per_source = [simplewiki_QA_ds[split], ELI5_non_toxic[split]]
    SFT_QA_dataset[split] = datasets.concatenate_datasets(per_source)

In [None]:
# Interleave the two sources; the fixed seed keeps the order reproducible.
SFT_QA_dataset = SFT_QA_dataset.shuffle(seed=12321)

In [None]:
# Add the formatted "### Human / ### Assistant" prompt as the 'QA' column.
SFT_QA_dataset = combine_question_answer(SFT_QA_dataset,
                                         formatting_prompts_func)

Map:   0%|          | 0/107468 [00:00<?, ? examples/s]

Map:   0%|          | 0/5955 [00:00<?, ? examples/s]

Map:   0%|          | 0/7298 [00:00<?, ? examples/s]

In [None]:
# Drop the 'Unnamed: 0' column inherited from the simple-wiki CSV files.
SFT_QA_dataset = SFT_QA_dataset.remove_columns('Unnamed: 0')

In [None]:
# Persist the combined, formatted dataset; later sections reload it from here.
SFT_QA_dataset.save_to_disk('../data/SFT_QA_ds')

Saving the dataset (0/1 shards):   0%|          | 0/107468 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5955 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7298 [00:00<?, ? examples/s]

In [None]:
# Upload the combined dataset directory to W&B as the 'combined_dataset'
# artifact, under a timestamped run name.
now = datetime.now()
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
with wandb.init(project='ELI5_analysis',
                entity='ft-llmmm',
                job_type='upload_data',
                name=f'SFT_QA_dataset_{time_stamp}') as run:

    clean_data_art = wandb.Artifact('combined_dataset', 'dataset')
    clean_data_art.add_dir('../data/SFT_QA_ds')
    run.log_artifact(clean_data_art)

[34m[1mwandb[0m: Adding directory to artifact (./../data/SFT_QA_ds)... Done. 2.4s


In [None]:
# Reload the saved dataset (sanity check); `ds` is reassigned in later sections.
ds = datasets.load_from_disk('../data/SFT_QA_ds')

In [None]:
107468+5955+7298

120721

## Tokenizing

### GPT-2

In [None]:
# Resume point for GPT-2 tokenization: reload the combined dataset.
SFT_QA_dataset = datasets.load_from_disk('../data/SFT_QA_ds')

In [None]:
# Tokenize with the distilgpt2 tokenizer, drop examples longer than the
# tokenizer's model_max_length, then save and upload the result.
tok = AutoTokenizer.from_pretrained('distilgpt2')
GPT2_QA_tokenized = prepare_dataset(SFT_QA_dataset,tok,formatting_prompts_func)
# NOTE(review): this writes under './data', while the other datasets in this
# notebook live under '../data' -- confirm the location is intentional.
GPT2_QA_tokenized.save_to_disk('./data/GPT2_QA_tokenized')

now = datetime.now()
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")
with wandb.init(project='ELI5_analysis',
                entity='ft-llmmm',
                job_type='upload_data',
                name=f'GPT2_QA_tokenized_dataset_{time_stamp}') as run:

    clean_data_art = wandb.Artifact('GPT2_QA_tokenized', 'dataset')
    clean_data_art.add_dir('./data/GPT2_QA_tokenized')
    run.log_artifact(clean_data_art)

Downloading (…)lve/main/config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Map:   0%|          | 0/107468 [00:00<?, ? examples/s]

Map:   0%|          | 0/5955 [00:00<?, ? examples/s]

Map:   0%|          | 0/7298 [00:00<?, ? examples/s]

Map:   0%|          | 0/107468 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1791 > 1024). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/5955 [00:00<?, ? examples/s]

Map:   0%|          | 0/7298 [00:00<?, ? examples/s]

Filter:   0%|          | 0/107468 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5955 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7298 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/106806 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5942 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7259 [00:00<?, ? examples/s]

[34m[1mwandb[0m: Adding directory to artifact (./data/GPT2_QA_tokenized)... Done. 2.3s


### Llama

In [None]:
from transformers import AutoTokenizer
import datasets

In [None]:
# Resume point for Llama tokenization: reload the combined dataset.
SFT_QA_dataset = datasets.load_from_disk('../data/SFT_QA_ds')

In [None]:
# Upgrade `tokenizers` before loading the Llama tokenizer.
# NOTE(review): unpinned upgrade; the resulting version is not recorded here.
!pip install -U tokenizers



In [None]:
# Load the Llama-2 7B tokenizer and reuse EOS as the padding token.
model_id = "meta-llama/Llama-2-7b-hf"
# Short model name: the part after the organisation prefix.
model_name = model_id.split('/')[-1]
llama_tokenizer = AutoTokenizer.from_pretrained(model_id)
llama_tokenizer.pad_token = llama_tokenizer.eos_token

In [None]:
# Tokenize every formatted 'QA' string (one example per call; the map is not
# batched); the tokenizer outputs become dataset columns.
SFT_QA_dataset_llama = SFT_QA_dataset.map(lambda x :
                                    llama_tokenizer(x['QA']))

# Record the token count per example for the length-based filtering below.
SFT_QA_dataset_llama = SFT_QA_dataset_llama.map(lambda x: {'length':len(x['input_ids'])})

SFT_QA_dataset_llama.save_to_disk('../data/SFT_QA_dataset_llama')

Map:   0%|          | 0/107468 [00:00<?, ? examples/s]

Map:   0%|          | 0/5955 [00:00<?, ? examples/s]

Map:   0%|          | 0/7298 [00:00<?, ? examples/s]

Map:   0%|          | 0/107468 [00:00<?, ? examples/s]

Map:   0%|          | 0/5955 [00:00<?, ? examples/s]

Map:   0%|          | 0/7298 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/107468 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5955 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7298 [00:00<?, ? examples/s]

In [None]:
# Upload the Llama-tokenized dataset to W&B as the 'llama_QA_tokenized' artifact.
with wandb.init(project='ELI5_analysis',
                entity='ft-llmmm',
                job_type='upload_data',
                name=f'llama_QA_tokenized_dataset_clean') as run:

    clean_data_art = wandb.Artifact('llama_QA_tokenized', 'dataset')
    clean_data_art.add_dir('../data/SFT_QA_dataset_llama')
    run.log_artifact(clean_data_art)

[34m[1mwandb[0m: Adding directory to artifact (./../data/SFT_QA_dataset_llama)... Done. 1.8s


In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Reload the Llama-tokenized dataset to inspect sequence lengths per source.
ds_llama = datasets.load_from_disk('../data/SFT_QA_dataset_llama')

In [None]:
# Split the tokenized dataset by provenance using the 'source' column.
ds_llama_wiki = ds_llama.filter(lambda x: x['source']=='simple_wiki')
# NOTE(review): later cells select ELI5 rows with `!= 'simple_wiki'` instead
# of `== 'ELI5'` -- confirm which label the ELI5 rows actually carry.
ds_llama_eli5 = ds_llama.filter(lambda x: x['source']=='ELI5')

Filter:   0%|          | 0/72214 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1964 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3301 [00:00<?, ? examples/s]

Filter:   0%|          | 0/72214 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1964 [00:00<?, ? examples/s]

Filter:   0%|          | 0/3301 [00:00<?, ? examples/s]

In [None]:
# Print the longest tokenized example per split, ELI5 first and then wiki.
for key in ('train', 'validation', 'test'):
    for label, subset in (('ELI5', ds_llama_eli5), ('wiki', ds_llama_wiki)):
        longest = max(subset[key]['length'])
        print(f"max length in split {key} for {label} is: {longest}")

max length in split train for ELI5 is: 3250
max length in split train for wiki is: 937
max length in split validation for ELI5 is: 2434
max length in split validation for wiki is: 645
max length in split test for ELI5 is: 3975
max length in split test for wiki is: 550


In [None]:
# Resume point: reload the Llama-tokenized dataset before length filtering.
SFT_QA_dataset_llama = datasets.load_from_disk('../data/SFT_QA_dataset_llama')

In [None]:
# Build two length-capped variants: at most 1024 and at most 2048 tokens per example.
SFT_QA_dataset_llama_1024 = SFT_QA_dataset_llama.filter(lambda x:x['length']<=1024)
SFT_QA_dataset_llama_2048 = SFT_QA_dataset_llama.filter(lambda x:x['length']<=2048)

Filter:   0%|          | 0/107468 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5955 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7298 [00:00<?, ? examples/s]

Filter:   0%|          | 0/107468 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5955 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7298 [00:00<?, ? examples/s]

In [None]:
# Persist both length-capped variants for upload and training.
SFT_QA_dataset_llama_1024.save_to_disk('../data/llama_tokenized_1024')
SFT_QA_dataset_llama_2048.save_to_disk('../data/llama_tokenized_2048')

Saving the dataset (0/1 shards):   0%|          | 0/106557 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5939 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7247 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/107388 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/5954 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/7287 [00:00<?, ? examples/s]

In [None]:
# Upload both length-capped datasets to W&B within a single run, one artifact each.
with wandb.init(project='ELI5_analysis',
                entity='ft-llmmm',
                job_type='upload_data',
                name=f'llama_QA_tokenized_dataset_clean_short') as run:

    # Examples with at most 1024 tokens.
    clean_data_art_1024 = wandb.Artifact('llama_QA_tokenized_1024', 'dataset')
    clean_data_art_1024.add_dir('../data/llama_tokenized_1024')
    run.log_artifact(clean_data_art_1024)

    # Examples with at most 2048 tokens.
    clean_data_art_2048 = wandb.Artifact('llama_QA_tokenized_2048', 'dataset')
    clean_data_art_2048.add_dir('../data/llama_tokenized_2048')
    run.log_artifact(clean_data_art_2048)

In [None]:
# Reload the 1024-token variant to compare the two sources' sizes.
ds=datasets.load_from_disk('../data/llama_tokenized_1024')

In [None]:
# Split by provenance. ELI5 rows are selected as "everything that is not
# simple_wiki", so the split does not depend on the exact ELI5 label.
ds_wiki = ds.filter(lambda x:x['source']=='simple_wiki')
ds_eli5 = ds.filter(lambda x:x['source']!='simple_wiki')

Filter:   0%|          | 0/106557 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5939 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7247 [00:00<?, ? examples/s]

Filter:   0%|          | 0/106557 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5939 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7247 [00:00<?, ? examples/s]

In [None]:
import numpy as np

In [None]:
ds_wiki['train']

Dataset({
    features: ['question', 'answer', 'source', 'QA', 'input_ids', 'attention_mask', 'length'],
    num_rows: 65252
})

In [None]:
ds_eli5['train']

Dataset({
    features: ['question', 'answer', 'source', 'QA', 'input_ids', 'attention_mask', 'length'],
    num_rows: 41305
})

In [None]:
np.sum(ds_wiki['train']['length'])

9748885

In [None]:
np.sum(ds_eli5['train']['length'])

9909595

# Training Experiments

In [None]:
# Download the 1024-token tokenized dataset artifact for the training experiments.
# NOTE(review): the run opened here is not finished in this cell (no `with`
# block and no run.finish() call) -- confirm it is closed elsewhere.
import wandb
run = wandb.init(project='SFT_training_dm',
                 entity='ft-llmmm')

artifact = run.use_artifact(
    'ft-llmmm/ELI5_analysis/llama_QA_tokenized_1024:latest',
    type='dataset')
artifact_dir = artifact.download()

In [None]:
# Extract the simple-wiki subset from artifact version v0 and save it locally.
# NOTE(review): the next cell reads version v1 of the same artifact; confirm
# that v0 is the intended input here.
ds_combined_1024 = datasets.load_from_disk(
    './artifacts/llama_QA_tokenized_1024:v0')
ds_wiki_1024 = ds_combined_1024.filter(lambda x:
                                       x['source']=='simple_wiki')

ds_wiki_1024.save_to_disk('./data/ds_wiki_1024')

In [None]:
# Load artifact version v1 and split it into its simple-wiki and ELI5 parts.
ds_full = datasets.load_from_disk(
    './artifacts/llama_QA_tokenized_1024:v1')

ds_wiki_1024_full = ds_full.filter(
    lambda x: x['source']=='simple_wiki')

# Everything that is not simple_wiki is treated as ELI5.
ds_eli5_1024 = ds_full.filter(
    lambda x: x['source']!='simple_wiki')

Filter:   0%|          | 0/106557 [00:00<?, ? examples/s]

Filter:   0%|          | 0/5939 [00:00<?, ? examples/s]

Filter:   0%|          | 0/7247 [00:00<?, ? examples/s]

In [None]:
# Save the per-source training sets read by the run_clm.py launches below.
ds_wiki_1024_full.save_to_disk('./data/ds_wiki_1024_full')
ds_eli5_1024.save_to_disk('./data/ds_eli5_1024')

In [None]:
# Configuration for the combined (ELI5 + simple-wiki, <=1024 tokens) Llama-2 13B run.
model_id = "meta-llama/Llama-2-13b-hf" # sharded weights
dataset_path = './artifacts/llama_QA_tokenized_1024:v1'
ds_name = 'eli5-wiki-1024'

# Timestamp used to make the run name unique.
now = datetime.now()
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")

#model_name = model_id.replace('/','-')
# Short model name: the part after the organisation prefix.
model_name = model_id.split('/')[-1]
#ds_name = dataset_path.split('/')[-1].replace('llama','combined_large').replace(':','-')

#ds_name = dataset_path.split('/')[-1]
# Checkpoints go to <model>_<dataset>/models, logs to a 'logs' subfolder.
output_dir = f'./{model_name}_{ds_name}/models'
logging_dir = f'{output_dir}/logs'

# The 'resumed_' prefix marks this as a continuation of an earlier run.
run_name = f'resumed_{ds_name}_{time_stamp}'
# NOTE: `optim` is not interpolated into the launch command in the next cell,
# which passes the optimizer name as a literal.
optim = 'paged_adamw_8bit'

from pathlib import Path
Path(output_dir).mkdir(parents=True, exist_ok=True)
Path(logging_dir).mkdir(parents=True, exist_ok=True)

# Hub repository name passed to the training script.
repo_id = f'{model_name}-{ds_name}'

In [None]:
# Resume training on the combined dataset from a checkpoint
# (--resume_from_checkpoint 1), with W&B reporting disabled.
# NOTE(review): hf_token / wandb_token are interpolated into the command line,
# and run_clm.py echoes its parsed arguments (see the `args is Namespace(...)`
# outputs of the later launch cells), so both secrets end up in the saved
# notebook output. Consider passing them through environment variables.
!python ./run_clm.py \
--output_dir {output_dir} \
--logging_dir {logging_dir} \
--model_id {model_id} \
--dataset_path {dataset_path} \
--run_name {run_name} \
--repo_id {repo_id} \
--report_to_wandb 0 \
--epochs 1 \
--max_steps -1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16 \
--gradient_accumulation_steps 8 \
--lr 2e-4 \
--entity 'ft-llmmm' \
--project_name 'SFT_training_dm' \
--hub_strategy 'every_save' \
--torch_compile 0 \
--gradient_checkpointing 1 \
--optim 'paged_adamw_8bit' \
--group_by_length 1 \
--hf_token {hf_token} \
--wandb_token {wandb_token} \
--use_flash_attention 1 \
--logging_steps 10 \
--resume_from_checkpoint 1 \
--auto_find_batch_size 0

In [None]:
# Train Llama-2 13B on the simple-wiki subset only (fresh run, W&B reporting on).
model_id = "meta-llama/Llama-2-13b-hf" # sharded weights
model_name = model_id.split('/')[-1]

# The dataset name is derived from the last path component.
dataset_path = './data/ds_wiki_1024_full'
ds_name = dataset_path.split('/')[-1]

# Timestamp used to make the run name unique.
now = datetime.now()
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")

#model_name = model_id.replace('/','-')
#ds_name = dataset_path.split('/')[-1].replace('llama','combined_large').replace(':','-')

output_dir = f'./{model_name}_{ds_name}/models'
logging_dir = f'{output_dir}/logs'

run_name = f'{ds_name}_{time_stamp}'
# NOTE: `optim` is not interpolated into the command below, which passes the
# optimizer name as a literal.
optim = 'paged_adamw_8bit'

from pathlib import Path
Path(output_dir).mkdir(parents=True, exist_ok=True)
Path(logging_dir).mkdir(parents=True, exist_ok=True)

repo_id = f'{model_name}-{ds_name}'

# NOTE(review): the tokens are passed as command-line arguments and the script
# prints its parsed arguments, so they appear in this cell's output. Clear the
# output before sharing the notebook and rotate any token that was exposed.
!python ./run_clm.py \
--output_dir {output_dir} \
--logging_dir {logging_dir} \
--model_id {model_id} \
--dataset_path {dataset_path} \
--run_name {run_name} \
--repo_id {repo_id} \
--report_to_wandb 1 \
--epochs 1 \
--max_steps -1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16 \
--gradient_accumulation_steps 8 \
--lr 2e-4 \
--entity 'ft-llmmm' \
--project_name 'SFT_training_dm' \
--hub_strategy 'every_save' \
--torch_compile 0 \
--gradient_checkpointing 1 \
--optim 'paged_adamw_8bit' \
--group_by_length 1 \
--hf_token {hf_token} \
--wandb_token {wandb_token} \
--use_flash_attention 1 \
--logging_steps 10 \
--resume_from_checkpoint 0 \
--auto_find_batch_size 0

args is Namespace(model_id='meta-llama/Llama-2-13b-hf', repo_id='Llama-2-13b-hf-ds_wiki_1024_full', hub_strategy='every_save', output_dir='./Llama-2-13b-hf_ds_wiki_1024_full/models', output_data_dir=None, dataset_path='./data/ds_wiki_1024_full', hf_token='<REDACTED>', report_to_wandb=1, wandb_token='<REDACTED>', epochs=1, max_steps=-1, per_device_train_batch_size=16, per_device_eval_batch_size=16, gradient_accumulation_steps=8, max_seq_length=4096, logging_steps=10, optim='paged_adamw_8bit', lr=0.0002, lora_r=64, lora_alpha=16, weight_decay=0.1, lora_dropout=0.1, load_in_4bit=1, load_in_8bit=0, use_peft=1, gradient_checkpointing=1, bf16=1, group_by_length=1, merge_weights=0, seed=42, warmup_ratio=0.03, project_name='SFT_training_dm', entity='ft-llmmm', run_name='ds_wiki_1024_full_09.05.23-18.50.22', load_best_model_at_end=1, use_sagemaker=1, torch_compile=0, use_flash_attention=1, resume_from_checkpoint=0, auto_find_batch_size=0)

In [None]:
# Train Llama-2 13B on the ELI5 subset only (W&B reporting on).
model_id = "meta-llama/Llama-2-13b-hf" # sharded weights
model_name = model_id.split('/')[-1]

# The dataset name is derived from the last path component.
dataset_path = './data/ds_eli5_1024'
ds_name = dataset_path.split('/')[-1]

# Timestamp used to make the run name unique.
now = datetime.now()
time_stamp = now.strftime("%m.%d.%y-%H.%M.%S")

#model_name = model_id.replace('/','-')
#ds_name = dataset_path.split('/')[-1].replace('llama','combined_large').replace(':','-')

output_dir = f'./{model_name}_{ds_name}/models'
logging_dir = f'{output_dir}/logs'

run_name = f'{model_name}_{ds_name}_{time_stamp}'
# NOTE: `optim` is not interpolated into the command below, which passes the
# optimizer name as a literal.
optim = 'paged_adamw_8bit'

from pathlib import Path
Path(output_dir).mkdir(parents=True, exist_ok=True)
Path(logging_dir).mkdir(parents=True, exist_ok=True)

repo_id = f'{model_name}-{ds_name}'

# No --resume_from_checkpoint flag is passed here, unlike the other launches.
# NOTE(review): the tokens are passed as command-line arguments and the script
# prints its parsed arguments, so they appear in this cell's output.
# NOTE(review): the captured output lists unparsed 'extra' arguments starting
# with '--loggin'; `--logging_dir` may not be an option run_clm.py
# recognises -- confirm against the script.
!python ./run_clm.py \
--output_dir {output_dir} \
--logging_dir {logging_dir} \
--model_id {model_id} \
--dataset_path {dataset_path} \
--run_name {run_name} \
--repo_id {repo_id} \
--report_to_wandb 1 \
--epochs 1 \
--max_steps -1 \
--per_device_train_batch_size 16 \
--per_device_eval_batch_size 16 \
--gradient_accumulation_steps 8 \
--lr 2e-4 \
--entity 'ft-llmmm' \
--project_name 'SFT_training_dm' \
--hub_strategy 'every_save' \
--torch_compile 0 \
--gradient_checkpointing 1 \
--optim 'paged_adamw_8bit' \
--group_by_length 1 \
--hf_token {hf_token} \
--wandb_token {wandb_token} \
--use_flash_attention 1 \
--logging_steps 10 \
--auto_find_batch_size 0

args is Namespace(model_id='meta-llama/Llama-2-13b-hf', repo_id='Llama-2-13b-hf-ds_eli5_1024', hub_strategy='every_save', output_dir='./Llama-2-13b-hf_ds_eli5_1024/models', output_data_dir=None, dataset_path='./data/ds_eli5_1024', hf_token='<REDACTED>', report_to_wandb=1, wandb_token='<REDACTED>', epochs=1, max_steps=-1, per_device_train_batch_size=16, per_device_eval_batch_size=16, gradient_accumulation_steps=8, max_seq_length=4096, logging_steps=10, optim='paged_adamw_8bit', lr=0.0002, lora_r=64, lora_alpha=16, weight_decay=0.1, lora_dropout=0.1, load_in_4bit=1, load_in_8bit=0, use_peft=1, gradient_checkpointing=1, bf16=1, group_by_length=1, merge_weights=0, seed=42, warmup_ratio=0.03, project_name='SFT_training_dm', entity='ft-llmmm', run_name='ds_eli5_1024_09.06.23-00.19.39', load_best_model_at_end=1, use_sagemaker=1, torch_compile=0, use_flash_attention=1, resume_from_checkpoint=0, auto_find_batch_size=0)
extra is ['--loggin

In [None]:
# Release the Colab runtime once training has finished. Kernel state is lost,
# so the Inference section starts from a fresh session.
from google.colab import runtime
runtime.unassign()

# Inference

### Computing Predictions

In [5]:
from huggingface_hub import login
from collections import defaultdict
from transformers import AutoTokenizer
from tqdm import tqdm
from peft import PeftModel
import pickle
import os
import pandas as pd
from transformers import pipeline

In [6]:
def inference_formatting(example):
    """Wrap a question in the prompt template, leaving the assistant turn open.

    Args:
        example: the question text.

    Returns:
        ``"### Human: <question>\\n ### Assistant:"``, the same prefix that
        ``formatting_prompts_func`` produces before the answer.
    """
    prompt = "### Human: {}\n ### Assistant:".format(example)
    return prompt

def generate_examples(model,
                      tokenizer,
                      data,
                      padding=True):
    """Sample one completion for every prompt in ``data['prompt']``.

    Args:
        model: model exposing ``generate``, ``device`` and
            ``config.eos_token_id``.
        tokenizer: tokenizer used to encode the prompts and decode the output.
        data: mapping whose ``'prompt'`` entry holds the prompt strings.
        padding: forwarded to the tokenizer call.

    Returns:
        A list of decoded strings, one per sequence returned by
        ``model.generate``, with special tokens stripped.
    """
    generation_config = transformers.GenerationConfig(num_beams = 1,
                                                      max_new_tokens = 256,
                                                      do_sample = True,
                                                      temperature = .6,
                                                      top_p = 0.9,
                                                      repetition_penalty = 1.2,
                                                      )

    prompts = data['prompt']

    # Move the inputs to the model's device instead of hard-coding 'cuda'.
    # The local is named `encoded` so the builtin `input` is not shadowed.
    encoded = tokenizer(prompts, return_tensors = 'pt', padding = padding).to(model.device)

    output_ids = model.generate(input_ids = encoded['input_ids'],
                                attention_mask = encoded['attention_mask'],
                                generation_config = generation_config,
                                # EOS doubles as the padding token during generation.
                                pad_token_id = model.config.eos_token_id,
                                )

    return [tokenizer.decode(ids, skip_special_tokens = True) for ids in output_ids]

def generate_df_predictions(model_ids,
                            ds,
                            output_dir,
                            batch_size=16,
                            seed = 50,
                            size = 100,
                            padding=True):
    """Generate and score predictions for every (model, dataset) pair.

    For each model, ``size`` shuffled validation examples are drawn from each
    dataset, completions are generated in batches, and ROUGE and BERTScore are
    computed against the ``'QA'`` column.

    Args:
        model_ids: iterable of ``(base_model, model_id)`` pairs. When
            ``base_model`` is truthy it is loaded and ``model_id`` is applied
            on top as a PEFT adapter; otherwise ``model_id`` is loaded directly.
        ds: mapping from dataset name to a dataset dict with a
            ``'validation'`` split containing ``'question'`` and ``'QA'``.
        output_dir: directory for the CSV summaries (created if missing).
        batch_size: number of prompts per generation call.
        seed: shuffle seed used when sampling the evaluation subset.
        size: number of validation examples per dataset (capped at the split size).
        padding: forwarded to the tokenizer.

    Returns:
        ``(df_preds, df_rouge, df_bert)`` DataFrames whose columns are keyed by
        ``(model_name, ds_name)``.

    Side effects:
        Writes one pickle per (model, dataset) to ``./val_results_new_merge``
        (a list of per-batch prediction lists) and three CSV files to
        ``output_dir``.
    """
    # The 4-bit flag belongs in the quantization config; recent transformers
    # releases reject `load_in_4bit=True` passed together with
    # `quantization_config`.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    # The metrics were referenced but never defined; load them once up front.
    rouge = evaluate.load('rouge')
    bertscore = evaluate.load('bertscore')

    pickle_dir = './val_results_new_merge'
    os.makedirs(pickle_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    # Accumulated across all models; previously reset on every model.
    predictions = {}
    rouge_scores = {}
    bert_scores = {}

    for base_model, model_id in model_ids:
        model_name = model_id.split('/')[-1]
        print(f'working on model {model_name}')
        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # EOS as padding, padded on the left so each prompt ends where
        # generation begins.
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        if base_model:
            model = AutoModelForCausalLM.from_pretrained(
                base_model,
                device_map="auto",
                quantization_config=bnb_config
                )

            model = PeftModel.from_pretrained(model = model,
                            model_id = model_id,
                            torch_dtype = torch.bfloat16,
                            is_trainable = False)
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                device_map="auto",
                quantization_config=bnb_config
                )

        model.eval()

        for ds_name in ds:
            # Sample before building prompts so only the selected rows are mapped.
            ds_small = ds[ds_name]['validation'].shuffle(seed=seed)
            ds_small = ds_small.select(range(min(size, len(ds_small))))
            ds_small = ds_small.map(lambda x: {'prompt':inference_formatting(x['question'])})

            print(f'working on dataset {ds_name}')

            batches = []
            for k in tqdm(range(0,len(ds_small),batch_size)):
                batches.append(generate_examples(model,tokenizer, ds_small[k:k+batch_size],padding=padding))

            # The pickle keeps the per-batch layout; scoring needs a flat list.
            with open(f'{pickle_dir}/{model_name}_{ds_name}.pkl', 'wb') as f:
                pickle.dump(batches, f)

            flat_predictions = [text for batch in batches for text in batch]
            predictions[(model_name,ds_name)] = flat_predictions

            # Score while this dataset's references are in scope (previously
            # done after the loop, against a stale `ds_name` and an undefined
            # `results`).
            references = ds_small['QA']
            rouge_scores[(model_name,ds_name)] = rouge.compute(
                predictions = flat_predictions,
                references = references
            )

            bert_scores[(model_name,ds_name)] = bertscore.compute(
                predictions = flat_predictions,
                references = references,
                lang='en')

        # Release GPU memory before the next model is loaded.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    # Series tolerate columns of different length (e.g. a split smaller than `size`).
    df_preds = pd.DataFrame({key: pd.Series(values) for key, values in predictions.items()})
    df_rouge = pd.DataFrame(rouge_scores)
    df_bert = pd.DataFrame(bert_scores)

    df_preds.to_csv(output_dir+'/predictions.csv',index=False)
    # The index holds the metric names, so it is kept for the score tables.
    df_rouge.to_csv(output_dir+'/rouge.csv')
    df_bert.to_csv(output_dir+'/bertscore.csv')

    return df_preds, df_rouge, df_bert


In [7]:
# Evaluation sets keyed by name: the full tokenized dataset plus its two
# per-source subsets.
ds = {}
ds['full'] = datasets.load_from_disk('../data/SFT_QA_dataset_llama')
ds['wiki'] = ds['full'].filter(lambda x: x['source']=='simple_wiki')
# Everything that is not simple_wiki is treated as ELI5.
ds['eli5'] = ds['full'].filter(lambda x: x['source']!='simple_wiki')

In [8]:
# (base_model, model_id) pairs consumed by generate_df_predictions. A base of
# None means model_id is loaded directly; otherwise model_id is applied to the
# base model as a PEFT adapter.
model_ids = []
model_ids.append((None,'meta-llama/Llama-2-7b-hf'))
model_ids.append((None,'dhmeltzer/llama-7b-SFT_ds_eli5_1024_r_64_alpha_16_merged'))
model_ids.append((None,'dhmeltzer/llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16_merged'))
model_ids.append((None,'dhmeltzer/llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16_merged'))
# 13B variants (base model + adapter), currently disabled.
#model_ids.append((None,'meta-llama/Llama-2-13b-hf'))
#model_ids.append(('meta-llama/Llama-2-13b-hf', 'dhmeltzer/Llama-2-13b-hf-ds_wiki_1024_full_r_64_alpha_16'))
#model_ids.append(('meta-llama/Llama-2-13b-hf', 'dhmeltzer/Llama-2-13b-hf-ds_eli5_1024_r_64_alpha_16'))
#model_ids.append(('meta-llama/Llama-2-13b-hf', 'dhmeltzer/Llama-2-13b-hf-eli5-wiki-1024_r_64_alpha_16'))

In [11]:
%ls

'=0.13.3'
'=0.3.1'
 [0m[01;34martifacts[0m/
 [01;34mdata[0m/
 [01;34mdistilgpt2_ds_wiki_1024_full[0m/
 [01;34mEleutherAI-pythia-70m-deduped_ds_wiki_1024_full[0m/
 [01;34mLlama-2-13b-hf_ds_eli5_1024[0m/
 [01;34mLlama-2-13b-hf_ds_wiki_1024_full[0m/
 [01;34mLlama-2-13b-hf_eli5-wiki-1024[0m/
 [01;34mLlama-2-7b-hf_combined_large_QA_tokenized_1024-v1[0m/
 [01;34mLlama-2-7b-hf_combined_tokenized_1024[0m/
 [01;34mLlama-2-7b-hf_ds_wiki_1024[0m/
 [01;34mLlama-2-7b-hf_ds_wiki_1024_full[0m/
 [01;34mLlama-2-7b-hf_wiki_r_64_alpha_16_wiki[0m/
 [01;34mNone[0m/
 [01;34mNousResearch-Llama-2-7b-hf_ds_wiki_1024_full[0m/
 [01;34mresults[0m/
 run_clm.py
 SFT_QA.ipynb
 [01;34mutils[0m/
 [01;34mval_results[0m/
 [01;34mval_results_new_merge[0m/
 [01;34mwandb[0m/


In [None]:
# Run inference for every entry in `model_ids` over the datasets in `ds`,
# writing results under './llama-2-7b-inference'. The recorded output below
# shows roughly 12 s per batch at batch_size=2.
generate_df_predictions(model_ids,
                        ds,
                        './llama-2-7b-inference',
                        batch_size=2,
                        padding=True)

working on model Llama-2-7b-hf


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

working on dataset full


100%|██████████| 50/50 [09:59<00:00, 12.00s/it]


Map:   0%|          | 0/4991 [00:00<?, ? examples/s]

working on dataset wiki


 82%|████████▏ | 41/50 [08:07<01:48, 12.06s/it]

In [17]:
import transformers
import gc
# Collect unreachable Python objects first, then ask PyTorch to release its
# cached (unused) CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [68]:

# Ad-hoc generation check. `model`, `tokenizer` and `input` are not defined in
# this cell and must already exist in the kernel; note that `input` shadows the
# Python builtin of the same name.
# NOTE(review): the recorded output is only the prompt text, i.e. no visible
# continuation was decoded — confirm the generation settings.
output_ids = model.generate(input_ids = input['input_ids'],
                            attention_mask = input['attention_mask'],
                            #pad_token_id = model.config.eos_token_id,
                            )

# Decode each generated sequence back to text, dropping special tokens.
predictions =  [tokenizer.decode(ids, skip_special_tokens = True) for ids in output_ids]

predictions

['complete this sentence: I am a ']

In [74]:
# NOTE(review): leftover debugging — this dumps the whole kernel namespace,
# including the `_ih` input history, into the saved notebook output.
locals()

{'__name__': '__main__',
 '__doc__': 'Automatically created module for IPython interactive environment',
 '__package__': None,
 '__loader__': None,
 '__spec__': None,
 '__builtin__': <module 'builtins' (built-in)>,
 '__builtins__': <module 'builtins' (built-in)>,
 '_ih': ['',
  'from google.colab import drive\ndrive.mount(\'/content/drive\')\nget_ipython().run_line_magic(\'cd\', \'drive/MyDrive/LLMs/Fine-tuning/SFT\')\n\n# installations\n#!pip install detoxify\n\nget_ipython().system(\'pip install peft==0.4.0\')\nget_ipython().system(\'pip install bitsandbytes==0.41.1\')\nget_ipython().system(\'pip install safetensors>=0.3.1\')\nget_ipython().system(\'pip install trl\')\nget_ipython().system(\'pip install wandb\')\nget_ipython().system(\'pip install tokenizers>=0.13.3\')\nget_ipython().system(\'pip install -U transformers\')\nget_ipython().system(\'pip install accelerate==0.21.0\')\nget_ipython().system(\'pip install datasets\')\nget_ipython().system(\'pip install -U torch\')\nget_ipyt

### Fixing results formatting

In [None]:
df_results = pd.read_csv('./data/df_predictions')

In [None]:
# Rebuild a {(model_name, ds_name): [prediction, ...]} mapping from the raw
# predictions CSV, skipping its first column.
results = defaultdict(list)

for col in df_results.columns[1:]:
    # Keep only the text before the first '.' (a label without '.' is kept
    # whole). Presumably this strips the '.1', '.2' suffixes pandas adds to
    # repeated column names — confirm.
    model_name = col.partition('.')[0]
    # Row 0 of each column stores the dataset name.
    ds_name = df_results[col][0]

    # Rows 1-7 each store the text form of a list of predictions.
    # NOTE(review): eval() executes whatever text is in the CSV cell;
    # ast.literal_eval would be the safer parser if the cells only ever hold
    # list literals — confirm before switching.
    for i in range(1, 8):
        results[(model_name, ds_name)].extend(eval(df_results[col][i]))

# One row per (model, dataset) pair, one column per prediction.
df_results_fixed = pd.DataFrame(results).T

In [None]:
df_results_fixed.to_csv('./results/df_results_fixed.csv')

In [None]:
run = wandb.init(entity='ft-llmmm',project='inference')
run.log({'Val_Predictions':wandb.Table(dataframe=df_results_fixed.iloc[:,:99])})

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666997038333117, max=1.0)…

In [None]:
wandb.finish()

In [None]:
df_results_fixed=pd.read_csv('./results/df_results_fixed.csv')

In [None]:
df_results_fixed.iloc[i:i+1,:]

Unnamed: 0,model,dataset,0,1,2,3,4,5,6,7,...,90,91,92,93,94,95,96,97,98,99
0,llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16,full,### Human: What was the purpose of Apollo 10?\...,"### Human: Who is Hervé Barulea, also known as...",### Human: Who was Danny Murphy and what were ...,### Human: Who was David Azulai and when did h...,### Human: What is a song and what are some di...,### Human: What is the origin of the Chenab Ri...,### Human: Who was Richard Pryor and what awar...,### Human: What is the defense mechanism of th...,...,### Human: Who was Carlos Amadeu?\n ### Assist...,### Human: What is Pau and where is it located...,### Human: What is the black sapote also known...,### Human: What is the title of the movie Litt...,### Human: Who are the Twelve Olympians in Gre...,### Human: Who is Michael Blunden?\n ### Assis...,### Human: What is a damson and how is it diff...,### Human: Who was Nat King Cole?\n ### Assist...,### Human: What is the official currency of Si...,### Human: What is the Klondike?\n ### Assista...


In [None]:
series_pred['model'].values[0]

'llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16'

In [None]:
# Log one W&B table per row of `df_results_fixed`, named after that row's
# model and dataset. Iterating over the frame's actual length (rather than a
# hard-coded 12) avoids an IndexError when there are fewer rows and avoids
# silently skipping rows when there are more.
# NOTE(review): `run` must be an active W&B run; an earlier cell calls
# wandb.finish() — confirm the run is re-initialised before this cell.
for i in range(len(df_results_fixed)):
    # Slice (not scalar index) so the row stays a one-row DataFrame.
    series_pred = df_results_fixed.iloc[i:i+1,:]
    model_name = series_pred['model'].values[0]
    ds_name = series_pred['dataset'].values[0]

    # Drop the first two columns before logging (presumably the
    # 'model'/'dataset' labels — confirm the column order).
    run.log({f'{model_name}_{ds_name}_eval':wandb.Table(dataframe=series_pred.iloc[:1,2:])})

In [None]:
df_results_fixed= df_results_fixed.rename(columns={'Unnamed: 0':'model'})

In [None]:
# Use the two unnamed leading CSV columns as a two-level row index.
# NOTE(review): the cell just before this one renames 'Unnamed: 0' to 'model';
# if both cells run in order this lookup of 'Unnamed: 0' would presumably
# fail — confirm the intended execution order.
df_results_fixed = df_results_fixed.set_index(['Unnamed: 0','Unnamed: 1'])

In [None]:
!pip install huggingface_hub
from huggingface_hub import login
login()

In [None]:
llama_tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf')

In [None]:
# For every cell: keep the text after the last 'Assistant:' marker and replace
# it with the number of token ids the Llama tokenizer produces for it.
df_results_split = df_results_fixed.applymap(
    lambda text: len(llama_tokenizer(text.split('Assistant:')[-1])['input_ids'])
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
df_results_split.agg(func=['mean','min','max'],axis=1)

Unnamed: 0_level_0,Unnamed: 1_level_0,mean,min,max
Unnamed: 0,Unnamed: 1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16,full,512.01,366.0,514.0
llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16,full,505.15,312.0,515.0
llama-7b-SFT_ds_eli5_1024_r_64_alpha_16,full,513.95,513.0,515.0
llama2-7b,full,513.98,513.0,515.0
llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16,wiki,511.28,378.0,516.0
llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16,wiki,510.96,350.0,516.0
llama-7b-SFT_ds_eli5_1024_r_64_alpha_16,wiki,511.45,261.0,514.0
llama2-7b,wiki,513.95,513.0,514.0
llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16,eli5,513.39,462.0,514.0
llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16,eli5,502.45,334.0,515.0


In [None]:
# Hub repo ids of the fine-tuned adapters.
adapter_model_ids = [
    'dhmeltzer/llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16',
    'dhmeltzer/llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16',
    'dhmeltzer/llama-7b-SFT_ds_eli5_1024_r_64_alpha_16',
]

# Short names: the part of each repo id after the last '/', followed by the
# label used for the base model.
model_names = [repo_id.rsplit('/', 1)[-1] for repo_id in adapter_model_ids] + ['llama2-7b']

### Rouge

In [None]:
rouge = evaluate.load('rouge')

In [None]:
# ROUGE for every (model, dataset) pair. References are the 'QA' field of the
# first 100 validation examples after shuffling with seed 50.
rouge_scores = {}

for model_name in model_names:
    print(f'working on model {model_name}')
    for ds_name in ds:
        print(f'working on dataset {ds_name}')

        data = ds[ds_name]['validation'].shuffle(seed=50)
        data_small = data.select(range(100))

        # Only score pairs that have not been computed yet.
        if (model_name, ds_name) not in rouge_scores:
            rouge_scores[(model_name, ds_name)] = rouge.compute(
                predictions=results[model_name, ds_name],
                references=data_small['QA'],
            )

In [None]:
# Long checkpoint name -> short display label. Later cells pass this to
# DataFrame.rename to relabel the fine-tuned models in the result tables.
indices_relabel = {
    'llama-7b-SFT_ds_eli5_1024_r_64_alpha_16':'llama2-7b-eli5',
    'llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16':'llama2-7b-wiki',
    'llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16':'llama2-7b-eli5-wiki'
}

In [None]:
# Per-dataset ROUGE table: one row per model (short labels, fixed order),
# one column per entry of the score dict returned for that model.
df_dict = {}

for ds_name in ['eli5','wiki','full']:
    df_dict[ds_name] = (
        pd.concat(
            [pd.Series(rouge_scores[(model_name, ds_name)]) for model_name in model_names],
            axis=1,
            keys=model_names,  # label each model's column before transposing
        )
        .T
        .rename(index=indices_relabel)
        .loc[['llama2-7b',
              'llama2-7b-eli5',
              'llama2-7b-wiki',
              'llama2-7b-eli5-wiki']]
    )

In [None]:
run = wandb.init(entity='ft-llmmm',
                 project='inference')

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
for ds_name in ['eli5','wiki','full']:
    table = wandb.Table(dataframe=df_dict[ds_name].reset_index())
    run.log({f'llama2-7b_{ds_name}':table})

In [None]:
# The 'rouge1' row of the score table, unstacked into a model x dataset grid.
rouge1 = pd.DataFrame(rouge_scores).loc['rouge1'].unstack()

# Same grid with rows in a fixed model order, short model labels, and the
# dataset columns ordered eli5, wiki, full.
rouge1_v2 = (
    rouge1
    .reindex(['llama2-7b',
              'llama-7b-SFT_ds_eli5_1024_r_64_alpha_16',
              'llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16',
              'llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16'])
    .rename(index=indices_relabel)
    .loc[:, ['eli5','wiki','full']]
)

In [None]:
# Reshape all ROUGE results into a single table with one row per model.
# NOTE(review): the per-step shapes described below assume `rouge_scores` maps
# (model, dataset) tuples to {metric: score} dicts — confirm.
# Step 1: stack the innermost column level into the row index, then relabel
# the remaining (model) columns with the short names.
df_combined = pd.DataFrame(rouge_scores).stack().rename(columns=indices_relabel)
# Step 2: swap the two row-index levels, move the inner one into the columns,
# and keep the rows in eli5, wiki, full order.
df_combined = df_combined.swaplevel().unstack().loc[['eli5','wiki','full']]
# Step 3: transpose and unstack once more so the rows are indexed by model.
df_combined = df_combined.T.unstack()
# Step 4: fix the row order: base model first, then the fine-tuned variants.
df_combined = df_combined.loc[['llama2-7b',
                               'llama2-7b-eli5',
                               'llama2-7b-wiki',
                               'llama2-7b-eli5-wiki']]
#df_combined.stack().stack().swaplevel(i=0,j=2).unstack().unstack()

In [None]:
!pip install huggingface_hub
from huggingface_hub import login
login()

In [None]:
df_combined.to_csv('./results/df_combined.csv')

### Bert-Score

In [None]:
import numpy as np

In [None]:
bertscore = evaluate.load("bertscore")

Downloading builder script:   0%|          | 0.00/7.95k [00:00<?, ?B/s]

In [None]:
# BERTScore for every (model, dataset) pair. References are the 'QA' field of
# the first 100 validation examples after shuffling with seed 50.
bert_scores = {}

for model_name in model_names:
    print(f'working on model {model_name}')
    for ds_name in ds:
        print(f'working on dataset {ds_name}')

        data = ds[ds_name]['validation'].shuffle(seed=50)
        data_small = data.select(range(100))

        bert_scores[(model_name, ds_name)] = bertscore.compute(
            predictions=results[model_name, ds_name],
            references=data_small['QA'],
            lang='en',
        )

In [None]:
df_bert_scores_v0 = pd.DataFrame(bert_scores)
df_bert_scores_v0.to_csv('./results/df_bert_scores_v0.csv')

In [None]:
# Long checkpoint name -> short display label.
models_relabel = {
    'llama-7b-SFT_ds_eli5_1024_r_64_alpha_16': 'llama2-7b-eli5',
    'llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16': 'llama2-7b-wiki',
    'llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16': 'llama2-7b-eli5-wiki',
}

def fix_names(model_name):
    """Return the short display label for a raw model label.

    Everything from the first '.' onward is discarded (presumably the '.1',
    '.2' suffixes pandas appends to repeated column names — confirm). The
    remaining name is translated through `models_relabel`; names without an
    entry there are returned as they are.
    """
    base_name = model_name.partition('.')[0]
    return models_relabel.get(base_name, base_name)

In [None]:
# Reload the saved BERTScore CSV and transpose it, so each column of the file
# becomes a row.
df_bert_scores_v0 = pd.read_csv('./results/df_bert_scores_v0.csv').T
# NOTE(review): assumes the transposed frame has exactly these five columns in
# this order — confirm against the CSV layout.
df_bert_scores_v0.columns=['dataset','precision','recall','f1','hashcode']
# Drop the first row and the last ('hashcode') column; reset_index() moves the
# row labels into a regular column named 'index'.
df_bert_scores_v0 = df_bert_scores_v0.iloc[1:,:-1].reset_index()
# Normalise the labels in 'index' to the short model names.
df_bert_scores_v0['index'] = df_bert_scores_v0['index'].apply(fix_names)
#df_bert_scores_v0 = df_bert_scores_v0.set_index(['index','dataset'])

In [None]:
# Per-dataset BERTScore summary: one row per model (fixed order) with the mean
# precision, recall and f1.
bert_score_summary = {}
for ds_name in ['eli5','wiki','full']:
    bert_score_summary[ds_name] = (
        df_bert_scores_v0
        .loc[df_bert_scores_v0['dataset'] == ds_name,
             ['index','precision','recall','f1']]
        .set_index('index')
        .loc[['llama2-7b',
              'llama2-7b-eli5',
              'llama2-7b-wiki',
              'llama2-7b-eli5-wiki']]
        # Each cell holds the text form of a list of scores; parse and average.
        # NOTE(review): eval() executes whatever text was read from the CSV;
        # ast.literal_eval would be the safer parser if the cells only ever
        # hold list literals — confirm before switching.
        .applymap(lambda scores: np.mean(eval(scores)))
    )

In [None]:
run = wandb.init(entity='ft-llmmm',
                 project='inference')

for ds_name in ['eli5','wiki','full']:
    table = wandb.Table(dataframe=bert_score_summary[ds_name].reset_index())
    run.log({f'llama2-7b_bertscore_{ds_name}':table})

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
run.finish()

VBox(children=(Label(value='0.005 MB of 0.016 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.299669…

# Merging Weights

In [None]:
import torch
import peft
import json
import shutil
from peft.utils import _get_submodules
import os
import bitsandbytes as bnb
from bitsandbytes.functional import dequantize_4bit
from peft import PeftModel
from transformers import (AutoModelForCausalLM,
                          LlamaForCausalLM,
                          LlamaTokenizer,
                          BitsAndBytesConfig,
                          AutoTokenizer
)
import gc
import copy
from getpass import getpass

In [None]:
def dequantize_model(model, tokenizer, dtype=torch.bfloat16, device="cuda"):
    """
    Replace every bitsandbytes 4-bit linear layer in `model` with a plain
    `torch.nn.Linear` holding the dequantized weights.

    'model': model containing `bnb.nn.Linear4bit` layers; it is modified in place.
    'tokenizer': the model's corresponding hf's tokenizer (kept for interface
        compatibility; it is not used here).
    'dtype': dtype of the rebuilt layers (the dtype the model was trained using)
    'device': device to move the rebuilt layers to

    Returns the same `model` object, with `is_loaded_in_4bit` set to False.
    """

    cls = bnb.nn.Linear4bit

    with torch.no_grad():
        # Snapshot the module list first: layers are swapped out below, and we
        # do not want to replace modules underneath a live traversal.
        for name, module in list(model.named_modules()):
            if isinstance(module, cls):
                print(f"Dequantizing `{name}`...")
                # Work on a copy so the quantized layer's own state is untouched.
                quant_state = copy.deepcopy(module.weight.quant_state)

                # NOTE(review): assumes the list-style quant_state of the pinned
                # bitsandbytes release, where index 2 holds the dtype — confirm
                # before upgrading bitsandbytes.
                quant_state[2] = dtype

                weights = dequantize_4bit(module.weight.data, quant_state=quant_state, quant_type="nf4").to(dtype)

                # Carry an existing bias over instead of silently dropping it.
                old_bias = getattr(module, "bias", None)
                has_bias = old_bias is not None

                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=has_bias, dtype=dtype)
                new_module.weight = torch.nn.Parameter(weights)
                if has_bias:
                    new_module.bias = torch.nn.Parameter(old_bias.data.to(dtype))
                new_module.to(device=device, dtype=dtype)

                # Swap the rebuilt layer into its parent under the same name.
                parent, target, target_name = _get_submodules(model, name)
                setattr(parent, target_name, new_module)

        # a hack, setting this to avoid hf's saving error because hf
        # itself does not support saving a model that is registered to be loaded in 4bit.
        model.is_loaded_in_4bit = False
        return model

def merge_weights(base_model_id,
                  adapter_model_id,
                  hf_token,
                  dtype=torch.bfloat16,
                  device="cuda"):
    """
    Merge a PEFT adapter into its base model and push the result to the Hub.

    Steps: load the base model in 4-bit NF4, dequantize its linear layers with
    `dequantize_model`, attach the adapter, merge it into the base weights, and
    push the merged model and tokenizer to `<adapter_model_id>_merged`.

    'base_model_id': Hub id of the base model
    'adapter_model_id': Hub id of the adapter; also the prefix of the target repo
    'hf_token': token used when downloading the base model and tokenizer
    'dtype': compute/load dtype, also used for the dequantized layers
    'device': device the dequantized layers are moved to

    Returns None.
    """

    repo_id = adapter_model_id+'_merged'

    quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    # 4-bit loading is requested through `quantization_config` only; also
    # passing `load_in_4bit=True` as a keyword is redundant and is rejected by
    # newer transformers releases.
    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        torch_dtype=dtype,
        quantization_config=quantization_config,
        device_map={"": 0},
        use_auth_token=hf_token
        )

    tok = AutoTokenizer.from_pretrained(base_model_id,
                                         use_auth_token=hf_token
                                        )
    # Honour the caller's dtype/device instead of falling back to the defaults.
    model = dequantize_model(model, tok, dtype=dtype, device=device)
    model = PeftModel.from_pretrained(model = model, model_id = adapter_model_id)
    model = model.merge_and_unload()

    model.push_to_hub(repo_id,safe_serialization=True)
    tok.push_to_hub(repo_id)

In [None]:
# Merge each 7B adapter into the Llama-2-7b base model and push the result.
base_model_id = 'meta-llama/Llama-2-7b-hf'
adapter_models = [
    'dhmeltzer/llama-7b-SFT_eli5_wiki65k_1024_r_64_alpha_16',
    'dhmeltzer/llama-7b-SFT_ds_wiki65k_1024_r_64_alpha_16',
    'dhmeltzer/llama-7b-SFT_ds_eli5_1024_r_64_alpha_16',
]

for adapter_model in adapter_models:
    merge_weights(
        base_model_id=base_model_id,
        adapter_model_id=adapter_model,
        hf_token=hf_token,
        dtype=torch.bfloat16,
        device="cuda",
    )

In [None]:
# Merge the 13B eli5+wiki adapter into the Llama-2-13b base model and push it.
base_model_id = 'meta-llama/Llama-2-13b-hf'
adapter_model = 'dhmeltzer/Llama-2-13b-hf-eli5-wiki-1024_r_64_alpha_16'

merge_weights(
    base_model_id=base_model_id,
    adapter_model_id=adapter_model,
    hf_token=hf_token,
    dtype=torch.bfloat16,
    device="cuda",
)