In [4]:
import numpy as np
import pandas as pd

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

import utilities

In [5]:
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

Detected kernel version 4.14.287, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [6]:
base_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(base_model_id)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Fine-tune

In [7]:
# load the tables
conversation_df = pd.read_csv("annotated_trees_101.csv")
conversation_df = conversation_df.rename({'Unnamed: 0': "index"}, axis=1)
df_only_bad_tone = pd.read_csv("bad_tone_nodes_with_generated_messages_chat_gpt_3_5_turbo.csv")
df_only_bad_tone["neg branch path"] = df_only_bad_tone["neg branch path"].apply(eval) # converts the list string into list

In [8]:
df_only_bad_tone

Unnamed: 0,index,node_id,tree_id,timestamp,author,text,parent,Aggressive,AgreeBut,AgreeToDisagree,...,RephraseAttack,RequestClarification,Ridicule,Sarcasm,Softening,Sources,ViableTransformation,WQualifiers,neg branch path,generated_moderation
0,34,d5307dl,4rl42j,1467909986,Hq3473,So a marginal increase in safety (say 1% less ...,33,0,0,0,...,0,0,1,0,0,0,0,0,"[34, 33, 32, 31, 29, 28]","Hq3473, please remember to keep the discussion..."
1,38,d533sdz,4rl42j,1467914340,Hq3473,I though bodily autonomy was absolute?,37,0,0,0,...,0,0,0,1,0,0,0,0,"[38, 37, 36, 35, 34, 33, 32, 31, 29, 28]",It seems like the discussion is getting into a...
2,43,d539o5o,4rl42j,1467921339,ThatBelligerentSloth,"<quote>Again, why would even a marginal increa...",42,0,0,0,...,0,0,0,0,0,0,0,0,"[43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 3...","ThatBelligerentSloth, it's important to rememb..."
3,44,d539xnf,4rl42j,1467921636,Hq3473,<quote>Because of bodily autonomy</quote> Righ...,43,0,0,0,...,0,0,0,0,1,0,0,0,"[44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 3...","Hq3473, it seems like you are engaging in a th..."
4,56,d52epx2,4rl42j,1467864800,Slagernicus,I have a feeling that most people who have gon...,55,0,0,0,...,0,0,0,1,0,0,0,0,"[56, 55, 49, 48, 28]","Slagernicus, it's important to note that the t..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1727,10540,dugl4q0,7yf2le,1518989877,icyMG,Education is useless they teach us outdated us...,10508,1,0,0,...,0,0,0,0,0,0,0,0,"[10540, 10508]","icyMG, while it's understandable that you may ..."
1728,10542,duh97vv,7yf2le,1519019073,icyMG,Yeah well the government doesn't really give a...,10541,1,0,0,...,0,0,0,0,0,0,0,0,"[10542, 10541, 10540, 10508]","icyMG, please remember to keep the conversatio..."
1729,10551,duhtqls,7yf2le,1519056308,_busch,"wait. so every teacher, in every school, in ev...",10550,0,0,0,...,0,0,0,1,0,0,0,0,"[10551, 10550, 10508]","_busch, it's important to remember that the co..."
1730,10555,duibhx7,7yf2le,1519073875,_busch,you should teach!,10554,0,0,0,...,0,0,0,1,0,0,0,0,"[10555, 10554, 10553, 10552, 10551, 10550, 10508]","_busch, please remember to keep the conversati..."


In [None]:
# load the tables
pseudo_positive_df = utilities.obtain_pseudo_positive_conversation(conversation_df)
pseudo_positive_df["generated_moderation"] = "No moderation note is required."
pseudo_positive_df

In [10]:
pseudo_positive_df["pseudo positive"] = 1
df_only_bad_tone["pseudo positive"] = 0
df_only_bad_tone = df_only_bad_tone.rename({"neg branch path": "branch_path"}, axis=1)
pseudo_positive_df = pseudo_positive_df.rename({"pseudo positive branch path": "branch_path"}, axis=1)
# sample 1500 pseudo positive messages and concat to
pseudo_positive_samples_df = pseudo_positive_df.sample(1500, random_state=1)

train_and_test_df = pd.concat([df_only_bad_tone, pseudo_positive_samples_df]) # we will suffule them next

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  pseudo_positive_df["pseudo positive"] = 1


In [11]:
train_and_test_df

Unnamed: 0,index,node_id,tree_id,timestamp,author,text,parent,Aggressive,AgreeBut,AgreeToDisagree,...,RequestClarification,Ridicule,Sarcasm,Softening,Sources,ViableTransformation,WQualifiers,branch_path,generated_moderation,pseudo positive
0,34,d5307dl,4rl42j,1467909986,Hq3473,So a marginal increase in safety (say 1% less ...,33,0,0,0,...,0,1,0,0,0,0,0,"[34, 33, 32, 31, 29, 28]","Hq3473, please remember to keep the discussion...",0
1,38,d533sdz,4rl42j,1467914340,Hq3473,I though bodily autonomy was absolute?,37,0,0,0,...,0,0,1,0,0,0,0,"[38, 37, 36, 35, 34, 33, 32, 31, 29, 28]",It seems like the discussion is getting into a...,0
2,43,d539o5o,4rl42j,1467921339,ThatBelligerentSloth,"<quote>Again, why would even a marginal increa...",42,0,0,0,...,0,0,0,0,0,0,0,"[43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 3...","ThatBelligerentSloth, it's important to rememb...",0
3,44,d539xnf,4rl42j,1467921636,Hq3473,<quote>Because of bodily autonomy</quote> Righ...,43,0,0,0,...,0,0,0,1,0,0,0,"[44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 3...","Hq3473, it seems like you are engaging in a th...",0
4,56,d52epx2,4rl42j,1467864800,Slagernicus,I have a feeling that most people who have gon...,55,0,0,0,...,0,0,1,0,0,0,0,"[56, 55, 49, 48, 28]","Slagernicus, it's important to note that the t...",0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2500,2500,dg3l7j7,64ki01,1491872193,potato1,Then all media are government endorsed.,2499,0,0,0,...,0,0,0,0,0,0,0,"[2500, 2499, 2498, 2492, 2490, 2489, 2488, 248...",No moderation note is required.,1
3940,3940,djfkxhx,6jmube,1498507572,Ansuz07,LGBT is not a federally protected class.,3939,0,0,0,...,0,0,0,0,0,0,0,"[3940, 3939, 3938, 3765]",No moderation note is required.,1
4986,4986,dkbe1xi,6nmn0d,1500257560,prettyinpinkpanther1,Im an advocate for American world supremacy ba...,4985,0,0,0,...,0,0,0,0,0,0,0,"[4986, 4985, 4915]",No moderation note is required.,1
10092,10092,dt6hj49,7snmsg,1516820989,OFGhost,What someone labels them has no impact on what...,10091,0,0,0,...,0,0,0,0,0,0,0,"[10092, 10091, 10084, 10083, 10082, 10081, 100...",No moderation note is required.,1


In [12]:
# Generate train dataset
train_data = list()
for i in range(train_and_test_df.shape[0]):
    index = train_and_test_df["index"].iloc[i]
    pseudo_positive_label = train_and_test_df["pseudo positive"].iloc[i]
    prompt = utilities.generate_branch_for_negative_tone_prompt_for_mistral(train_and_test_df, conversation_df, node_index=i, branch_col_name="branch_path")
    label = train_and_test_df["generated_moderation"].iloc[i]
    train_data.append((index, prompt, label, pseudo_positive_label))

dataset_df = pd.DataFrame(train_data, columns=["index", 'prompt', 'label', "pseudo positive"])
dataset_df.to_pickle("train_dataset_include_pseudo_positive.pkl")

In [13]:
from datasets import load_dataset
mod_dataset = load_dataset("pandas", data_files="train_dataset_include_pseudo_positive.pkl", split='train')
mod_dataset

Generating train split: 0 examples [00:00, ? examples/s]

Dataset({
    features: ['index', 'prompt', 'label', 'pseudo positive'],
    num_rows: 3232
})

In [14]:
def generate_mod_prompt(data_point):
    return f"""{data_point['prompt']}\n{data_point['label']}</s>"""


text_column = [generate_mod_prompt(ex) for ex in mod_dataset]
mod_dataset = mod_dataset.add_column("full_prompt", text_column)

mod_dataset = mod_dataset.shuffle(seed=1234)  # Shuffle dataset here
mod_dataset = mod_dataset.map(lambda samples: tokenizer(samples["full_prompt"]), batched=True)
mod_dataset = mod_dataset.train_test_split(test_size=0.2, seed=1234)
train_data = mod_dataset["train"]
test_data = mod_dataset["test"]

Map:   0%|          | 0/3232 [00:00<?, ? examples/s]

In [15]:
mod_dataset

DatasetDict({
    train: Dataset({
        features: ['index', 'prompt', 'label', 'pseudo positive', 'full_prompt', 'input_ids', 'attention_mask'],
        num_rows: 2585
    })
    test: Dataset({
        features: ['index', 'prompt', 'label', 'pseudo positive', 'full_prompt', 'input_ids', 'attention_mask'],
        num_rows: 647
    })
})

In [29]:
from peft import prepare_model_for_kbit_training
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

In [30]:
import bitsandbytes as bnb
def find_all_linear_names(model):
  cls = bnb.nn.Linear4bit #if args.bits == 4 else (bnb.nn.Linear8bitLt if args.bits == 8 else torch.nn.Linear)
  lora_module_names = set()
  for name, module in model.named_modules():
    if isinstance(module, cls):
      names = name.split('.')
      lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names: # needed for 16-bit
      lora_module_names.remove('lm_head')
  return list(lora_module_names)

modules = find_all_linear_names(model)
print(modules)

['v_proj', 'gate_proj', 'q_proj', 'o_proj', 'up_proj', 'down_proj', 'k_proj']


In [31]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=modules,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)

In [32]:
trainable, total = model.get_nb_trainable_parameters()
print(f"Trainable: {trainable} | total: {total} | Percentage: {trainable/total*100:.4f}%")

Trainable: 20971520 | total: 7262703616 | Percentage: 0.2888%


In [33]:
# from huggingface_hub import notebook_login
# notebook_login()

In [34]:
#new code using SFTTrainer
import transformers

from trl import SFTTrainer

tokenizer.pad_token = tokenizer.eos_token
torch.cuda.empty_cache()

trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=test_data,
    dataset_text_field="full_prompt",
    peft_config=lora_config,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        warmup_steps=1,
        max_steps=100,
        learning_rate=2.5e-5,
        logging_steps=1,
        save_steps=10,
        save_strategy="steps",
        eval_steps=10,
        evaluation_strategy="steps",
        do_eval=True,
        output_dir="outputs_mine_include_pseudo_positive",
        optim="paged_adamw_8bit",
        bf16=True,
        
        logging_dir="./logs_mine"
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)



Map:   0%|          | 0/2585 [00:00<?, ? examples/s]

Map:   0%|          | 0/647 [00:00<?, ? examples/s]

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 4.14.287, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [35]:
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()



Step,Training Loss,Validation Loss
10,2.0745,2.026882
20,2.0239,1.89122
30,1.8561,1.810342
40,1.6688,1.77167
50,1.8315,1.74497
60,1.9474,1.724498
70,1.8078,1.706986
80,1.6463,1.694251
90,1.8337,1.686669
100,1.6581,1.683793




TrainOutput(global_step=100, training_loss=1.8404169750213624, metrics={'train_runtime': 4305.2209, 'train_samples_per_second': 0.186, 'train_steps_per_second': 0.023, 'total_flos': 3.48009359597568e+16, 'train_loss': 1.8404169750213624, 'epoch': 0.31})

Load fined tuned model

In [36]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

base_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,  # Mistral, same as before
    quantization_config=bnb_config,  # Same quantization config as before
    device_map="auto",
    trust_remote_code=True,
)

eval_tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, trust_remote_code=True)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [37]:
from peft import PeftModel

ft_model = PeftModel.from_pretrained(base_model, "outputs_mine_include_pseudo_positive/checkpoint-100")

In [18]:
def get_completion(prompt, model, tokenizer):
    model_inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to("cuda")

    generated_ids = model.generate(**model_inputs, max_new_tokens=300, do_sample=True, pad_token_id=tokenizer.eos_token_id)
    decoded = tokenizer.batch_decode(generated_ids)
    
    return decoded[0]

def get_mod_message_generation_function(model, tokenizer):
    def generate_moderation_message(row):
        print('sample:', row.name)
        answer = get_completion(row['prompt'], model=model, tokenizer=tokenizer)
        splits = answer.split('[/INST]')
        if len(splits) != 2:
            return 'model produced illegal output..'
        return splits[-1].rstrip('</s>')
    
    return generate_moderation_message

In [19]:
train_df, test_df = train_data.to_pandas(), test_data.to_pandas()

In [41]:
fine_tuned_model_moderation = get_mod_message_generation_function(ft_model, eval_tokenizer)
test_df['fine_tuned_model_moderation'] = test_df.apply(fine_tuned_model_moderation, axis=1)

sample: 0


sample: 1
sample: 2
sample: 3
sample: 4
sample: 5
sample: 6
sample: 7
sample: 8
sample: 9
sample: 10
sample: 11
sample: 12
sample: 13
sample: 14
sample: 15
sample: 16
sample: 17
sample: 18
sample: 19
sample: 20
sample: 21
sample: 22
sample: 23
sample: 24
sample: 25
sample: 26
sample: 27
sample: 28
sample: 29
sample: 30
sample: 31
sample: 32
sample: 33
sample: 34
sample: 35
sample: 36
sample: 37
sample: 38
sample: 39
sample: 40
sample: 41
sample: 42
sample: 43
sample: 44
sample: 45
sample: 46
sample: 47
sample: 48
sample: 49
sample: 50
sample: 51
sample: 52
sample: 53
sample: 54
sample: 55
sample: 56
sample: 57
sample: 58
sample: 59
sample: 60
sample: 61
sample: 62
sample: 63
sample: 64
sample: 65
sample: 66
sample: 67
sample: 68
sample: 69
sample: 70
sample: 71
sample: 72
sample: 73
sample: 74
sample: 75
sample: 76
sample: 77
sample: 78
sample: 79
sample: 80
sample: 81
sample: 82
sample: 83
sample: 84
sample: 85
sample: 86
sample: 87
sample: 88
sample: 89
sample: 90
sample: 91
sample: 

Base model

In [20]:
base_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(base_model_id)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [21]:
base_model_generation = get_mod_message_generation_function(model, tokenizer)
test_df['base_model_moderation'] = test_df.apply(base_model_generation, axis=1)

sample: 0
sample: 1
sample: 2
sample: 3
sample: 4
sample: 5
sample: 6
sample: 7
sample: 8
sample: 9
sample: 10
sample: 11
sample: 12
sample: 13
sample: 14
sample: 15
sample: 16
sample: 17
sample: 18
sample: 19
sample: 20
sample: 21
sample: 22
sample: 23
sample: 24
sample: 25
sample: 26
sample: 27
sample: 28
sample: 29
sample: 30
sample: 31
sample: 32
sample: 33
sample: 34
sample: 35
sample: 36
sample: 37
sample: 38
sample: 39
sample: 40
sample: 41
sample: 42
sample: 43
sample: 44
sample: 45
sample: 46
sample: 47
sample: 48
sample: 49
sample: 50
sample: 51
sample: 52
sample: 53
sample: 54
sample: 55
sample: 56
sample: 57
sample: 58
sample: 59
sample: 60
sample: 61
sample: 62
sample: 63
sample: 64
sample: 65
sample: 66
sample: 67
sample: 68
sample: 69
sample: 70
sample: 71
sample: 72
sample: 73
sample: 74
sample: 75
sample: 76
sample: 77
sample: 78
sample: 79
sample: 80
sample: 81
sample: 82
sample: 83
sample: 84
sample: 85
sample: 86
sample: 87
sample: 88
sample: 89
sample: 90
sample: 9

In [None]:
test_df = test_df[["index", "pseudo positive", "prompt", "label", "base_model_moderation", "fine_tuned_model_moderation"]]
test_df.to_csv("test_with_fine_tuned_moderations_mine_include_pseudo_positive.csv", index=False)