# Exp018: Conditional instruction fine-tuning
This experiment instruction-fine-tunes the model on dialog responses from the dataset that already exhibit a given grammar skill, so that it learns to satisfy a single grammatical constraint per response.

In [1]:
from datasets import load_dataset
from dotenv import load_dotenv
load_dotenv()
import os
os.environ['CACHE_DIR'] = f"/scratch/tmp.{os.getenv('SLURM_JOB_ID')}.dglandorf" # speed up model loading
os.environ['WANDB_DIR'] = os.getenv('CACHE_DIR')

from tqdm.notebook import tqdm
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
import math
import re

import pickle
from torch.utils.data import RandomSampler, Subset
import numpy as np
import json
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
import random
import sys
sys.path.append(f'../source')
import helpers
import models
import evaluation
import importlib
#importlib.reload(evaluation)

[nltk_data] Downloading package punkt to
[nltk_data]     /cluster/home/dglandorf/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# params
# Pickle that is read further below as three consecutive objects
# (hit indices, hit sentences, extracts).
out_file = '../data/corpus_classification_all.pkl'
# JSONL file the formatted SFT examples are written to and later loaded from.
preprossed_dataset_file = '../data/SFT_data.jsonl'
# Hugging Face model id of the base model that is quantized and fine-tuned.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
# Constraint numbers to train on; they are matched against the '#' column of `egp`.
nrs = [1180] #[58, 616]#
# One classifier per constraint number, loaded through the project's `models` module.
# NOTE(review): not referenced by any other cell visible in this notebook.
classifiers = {nr: models.load_classifier(nr, "corpus_training") for nr in nrs}
# Marker that closes the instruction part of the prompt; the same string is used
# as the response template of the completion-only data collator.
EOP = "[/INST]"
# Table of grammar rules; the cells below read its '#', 'SubCategory' and
# 'Can-do statement' columns.
egp = helpers.get_egp()

## Prepare dataset

In [3]:
# The pickle stores three objects back-to-back; they have to be read in this order.
with open(out_file, 'rb') as f:
    all_hit_indices, all_hit_sentences, extracts = (pickle.load(f) for _ in range(3))

# One training example per (constraint number, hit index) pair. Each extract is
# indexed positionally: 0 = dialog context, 1 = response, 2 = source.
data = []
for nr in nrs:
    for idx in all_hit_indices[nr]:
        extract = extracts[idx]
        data.append({
            "context": extract[0],
            "response": extract[1],
            "nr": [nr],
            "source": extract[2],
        })

In [4]:
def formatting_func(example, unconstrained=False):
    """Build prompt, completion and full training text for one dialog example.

    Args:
        example: Mapping with the keys 'nr' (constraint numbers that are looked up
            in the '#' column of `egp`), 'context' (list of preceding utterances)
            and 'response' (the reference answer of speaker A).
        unconstrained: If True, the instruction only asks for a response of A and
            does not list any grammatical items.

    Returns:
        dict with 'prompt', 'completion' and 'text', where 'text' is
        prompt + completion + "</s>".
    """
    rules = egp[egp['#'].isin(example['nr'])]
    # "\n" instead of os.linesep: the prompt is model input and must not depend on
    # the platform the notebook runs on.
    constraints = "\n".join("- " + rules['SubCategory'] + ": " + rules['Can-do statement']) # " - " + rules['guideword']

    # Speakers alternate and the response belongs to A, so the last context
    # utterance has to be B. Shifting the labels by the parity of the context
    # length guarantees this; for even-length contexts the labels are unchanged.
    utterances = example["context"]
    offset = len(utterances) % 2
    context = "\n".join(
        ("A" if (i + offset) % 2 == 0 else "B") + ": " + utt
        for i, utt in enumerate(utterances)
    )

    if unconstrained:
        instruction = "Write the response of A."
    else:
        instruction = (
            "Write the response of A and include these grammatical items in the response:\n"
            f"{constraints}"
        )

    # Explicit "\n" escapes keep the significant trailing spaces after "[INST]"
    # and after the end-of-prompt marker visible and safe from editors that strip
    # trailing whitespace.
    prompt = f"[INST] \n{instruction}\nDialog:\n{context} {EOP} \nA: "
    completion = f"{example['response']}"
    return {'prompt': prompt, 'completion': completion, 'text': prompt+completion+"</s>"}

def unconstrain_item(item):
    """Replace the prompt fields of `item` with their unconstrained variant.

    The item is modified in place and also returned, which is the shape
    `Dataset.map` expects from its mapping function.
    """
    unconstrained_fields = formatting_func(item, unconstrained=True)
    for key, value in unconstrained_fields.items():
        item[key] = value
    return item
    
# Add the prompt/completion/text fields to every example (in place) and store
# the result as JSON lines, one example per line.
with open(preprossed_dataset_file, 'w') as sft_file:
    for item in tqdm(data):
        prompt_fields = formatting_func(item)
        item.update(prompt_fields)
        print(json.dumps(item), file=sft_file)

  0%|          | 0/2042 [00:00<?, ?it/s]

### Load dataset

In [5]:
dataset = load_dataset('json', data_files=preprossed_dataset_file, split='train', cache_dir=os.getenv('CACHE_DIR'))
# Fixed seed so the 100 held-out examples are the same after every kernel restart.
# Without it, evaluating a previously saved adapter could silently score examples
# that were part of its training split.
train_test_split = dataset.train_test_split(test_size=100, seed=42)
train_dataset, test_dataset = train_test_split['train'], train_test_split['test']

# Copy of the test split whose prompts carry no grammatical constraint.
unconstrained = test_dataset.map(unconstrain_item)

Generating train split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

## Load and prepare base model

In [6]:
# Load the base model 4-bit quantized (NF4) with float16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    cache_dir=os.getenv('CACHE_DIR'),
)
model.config.use_cache = False

# Tokenizer pads on the right by default; `generate` switches the side temporarily.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side="right",
    trust_remote_code=True,
    cache_dir=os.getenv('CACHE_DIR'),
)
# Register a dedicated '[PAD]' token and grow the embedding matrix to the new
# vocabulary size.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.pad_token = '[PAD]'
model.resize_token_embeddings(len(tokenizer))

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.46k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

Embedding(32001, 4096)

### Inference with current model

In [7]:
def extract_response(output):
    """Return the generated text of speaker A, cut off before a following B turn.

    The previous pattern ``(.*)(\\nB:)?`` was meant to stop at the next "B:" turn,
    but since ``.`` does not match a newline it always returned only the first
    line, and the optional group never influenced the result. Splitting on the
    literal marker implements the intended behavior and keeps multi-line
    responses intact.

    Args:
        output: Decoded continuation produced by the model.

    Returns:
        Everything before the first "\\nB:"; the whole string if no B turn follows.
    """
    return output.split("\nB:", 1)[0]
    
def generate(prompts, max_new_tokens=128, batch_size=32, verbose=False):
    """Generate one response per prompt with the global `model` and `tokenizer`.

    Args:
        prompts: List of prompt strings.
        max_new_tokens: Maximum number of tokens generated per prompt.
        batch_size: Number of prompts processed per forward pass.
        verbose: If True, print the decoded outputs of every batch.

    Returns:
        A list with one extracted response per prompt. For exactly one prompt the
        single response string is returned instead of a list.
    """
    model.eval()
    # Stop on the regular EOS token as well as on the added pad token; taking the
    # ids from the tokenizer replaces the former hard-coded 2 and 32000.
    eos_id = tokenizer.eos_token_id
    stop_ids = [token_id for token_id in (eos_id, tokenizer.pad_token_id) if token_id is not None]

    outputs = []
    # Prompts are padded on the left so that generation continues directly after them.
    tokenizer.padding_side = "left"
    try:
        for start in tqdm(range(0, len(prompts), batch_size), total=math.ceil(len(prompts)/batch_size)):
            batch = prompts[start:start + batch_size]
            model_input = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
            with torch.no_grad():
                token_ids = model.generate(**model_input, max_new_tokens=max_new_tokens, pad_token_id=eos_id, eos_token_id=stop_ids)
            # Drop the prompt tokens and decode only the newly generated part.
            prompt_length = model_input['input_ids'].shape[1]
            decoded = tokenizer.batch_decode(token_ids[:, prompt_length:], skip_special_tokens=True)
            outputs += decoded
            # Print exactly this batch (the old slice overlapped with the previous one).
            if verbose: print(decoded)
    finally:
        # Always restore right padding, also when generation is interrupted;
        # otherwise the training collator would continue with left padding.
        tokenizer.padding_side = "right"
    responses = [extract_response(output) for output in outputs]
    return responses[0] if len(responses)==1 else responses

In [8]:
# Pick a random held-out example and show its full training text
# (prompt followed by the reference answer).
# For manual experiments, take e.g. train_dataset[10] instead, override
# example['nr'] (such as [58, 616]) and refresh the fields with
# example.update(formatting_func(example)), optionally with unconstrained=True.
example = random.choice(test_dataset)
print(example['text'])

# Response of the currently loaded model to the same prompt.
generate([example['prompt']])

[INST] 
Write the response of A and include these grammatical items in the response:
- negation: Can form negative imperatives of main verbs with 'don't' + main verb. ► Clauses: imperatives
Dialog:
A: Well, your word overrules the file, sir. One moment, please.
B: I knew you'd see it my way.
A: Sir, I deleted the $ 10, but I had to add a $ 2 service charge to your bill.
B: Am I in the Twilight Zone? You're charging me for a movie I never saw? [/INST] 
A: Please don't blame me, sir. Blame the computer programmer.</s>


  0%|          | 0/1 [00:00<?, ?it/s]

"I'm sorry, I didn't mean to cause any confusion, but I can't allow that change in the file, sir. And I must inform you that I accidentally deleted the $10, but there is a $2 service charge added to your bill instead."

## Evaluate outputs

In [9]:
def calc_metrics(contexts, outputs, constraints, eval_quality=False):
    """Score generated responses for constraint satisfaction, distinctness and quality.

    Args:
        contexts: Dialog contexts, aligned with `outputs`.
        outputs: Generated (or reference) responses.
        constraints: Per response, the constraint number(s) it should satisfy;
            either a scalar or a list of numbers.
        eval_quality: If True, additionally query every metric in
            `evaluation.gpt_metrics` for every response.

    Returns:
        Tuple (scores, distinct, quality): the mean constraint satisfaction per
        response, one distinct-n value per unique constraint combination (ordered
        by combination) and a dict mapping metric name to per-response values
        (empty when `eval_quality` is False).
    """
    scores = [
        np.mean(evaluation.detector.constraint_satisfaction(output, constraint))
        for output, constraint in zip(outputs, constraints)
    ]

    # Group the outputs by their full constraint combination. The former
    # np.unique(constraints) flattened list-valued constraints into single
    # numbers, which only worked for one-element lists and raised an
    # "ambiguous truth value" error for real combinations such as [58, 616].
    outputs_by_combination = {}
    for output, constraint in zip(outputs, constraints):
        combination = tuple(np.atleast_1d(constraint).tolist())
        outputs_by_combination.setdefault(combination, []).append(output)
    distinct = [
        evaluation.calculate_distinct_n(outputs_by_combination[combination])
        for combination in sorted(outputs_by_combination)
    ]

    quality = {}
    if eval_quality:
        for metric in tqdm(evaluation.gpt_metrics.keys(), desc="Metrics", total=len(evaluation.gpt_metrics)):
            quality[metric] = [
                evaluation.get_single_response_metric(metric, context, output)
                for context, output in tqdm(zip(contexts, outputs), desc="Responses", total=len(outputs))
            ]
    return scores, distinct, quality

In [10]:
def compute_metrics(eval_preds, verbose=False, n=25, datasets=None, eval_quality=False, ground_truth=False):
    """Evaluate generations on random samples of the given dataset splits.

    Args:
        eval_preds: Unused; only present so the function can be passed to the
            trainer as `compute_metrics`.
        verbose: If True, print the first prompt, every reference/generation pair
            and the raw metric values.
        n: Number of examples sampled per split.
        datasets: Mapping from split name to dataset. Defaults to the current
            global `train_dataset` and `test_dataset`.
        eval_quality: Forwarded to `calc_metrics` to also compute quality metrics.
        ground_truth: If True, score the reference completions instead of
            generating new responses.

    Returns:
        dict with '<name>_constraint', '<name>_distinct' and, if quality metrics
        were computed, '<name>_<metric>' for every split.
    """
    if datasets is None:
        # Resolved at call time so that a re-created split is picked up; a dict
        # default would be frozen to the splits that existed at definition time.
        datasets = {"train": train_dataset, "test": test_dataset}

    results = {}
    for name, ds in datasets.items():
        # Sample from the split itself. Previously the sampled positions were used
        # to index the global, unsplit `dataset`, so "train" and "test" metrics
        # were not computed on their respective splits.
        indices = list(RandomSampler(ds, num_samples=n))
        subset = ds[indices]
        if verbose: print(subset['prompt'][0])
        outputs = subset['completion'] if ground_truth else generate(subset['prompt'])
        # `generate` returns a bare string for a single prompt; normalize to a list.
        if isinstance(outputs, str):
            outputs = [outputs]
        scores, distinct, quality = calc_metrics(subset['context'], outputs, subset['nr'], eval_quality)
        if verbose:
            for truth, output in zip(subset['completion'], outputs):
                print(f"Truth: {truth}")
                print(f"Gener: {output}")
            print(f"Grammar detected: {scores}")
            print(f"Distinctiveness per constraint {distinct}")
            print(f"Quality: {quality}")
        print(list(zip(outputs,scores))[:10])

        results.update({f"{name}_constraint": np.mean(scores)})
        results.update({f"{name}_{metric}": np.mean(quality[metric]) for metric in quality.keys()})
        results.update({f"{name}_distinct": np.mean(distinct)})
    return results

# test: score the reference completions of 25 test examples (no generation),
# including the quality metrics.
compute_metrics(
    [],
    verbose=False,
    n=25,
    datasets={"test": test_dataset},
    eval_quality=True,
    ground_truth=True,
)

Metrics:   0%|          | 0/4 [00:00<?, ?it/s]

Responses:   0%|          | 0/25 [00:00<?, ?it/s]

Responses:   0%|          | 0/25 [00:00<?, ?it/s]

Responses:   0%|          | 0/25 [00:00<?, ?it/s]

Responses:   0%|          | 0/25 [00:00<?, ?it/s]

[("Just a moment. Don't worry. Where are you now?", 1.0), ("Don't mention it. See you.", 1.0), ("Oh, uh, well, uh... well I couldn't find the second page of the recipe, but don't worry. I have plenty of experience around the house. Plenty of experience in cooking.", 1.0), ("No problem and please don't hesitate to call again if you have any other questions.", 1.0), ("Don't get upset over it. At least you have another week to go.", 1.0), ("I will bring the diary. OK, this afternoon you have a meeting with your accountant at 5:00 PM. On Wednesday you are going to London. Don't forget your train leaves at 9:30 AM.", 1.0), ("Don't be afraid. She isn't going to bite you.", 1.0), ("Don't get your hopes up.", 1.0), ("Hum, don't say that. Maybe you can do something useful for me.", 1.0), ("Dude, don't equivocate. A theory only becomes a theory after withstanding rigorous testing. You slept through class, didn't you? ", 1.0)]


{'test_constraint': 1.0,
 'test_Appropriateness': 2.72,
 'test_Relevance': 2.6,
 'test_Content Richness': 2.36,
 'test_Grammatical Correctness': 3.92,
 'test_distinct': 0.8263358778625954}

## Fine-tuning

In [None]:
# LoRA adapter configuration: rank 64 with alpha 16 and dropout 0.1, applied to
# the attention projections, the MLP projections and the LM head.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",   # attention
        "gate_proj", "up_proj", "down_proj",      # MLP
        "lm_head",
    ],
)

In [None]:
training_arguments = TrainingArguments(
    # --- bookkeeping -------------------------------------------------------
    output_dir="../models/mistral_FT",
    report_to="wandb",
    run_name="gctg",
    logging_steps=5,
    #save_steps=25,
    # --- schedule ----------------------------------------------------------
    num_train_epochs=1,
    max_steps=-1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    group_by_length=True,
    # --- optimisation ------------------------------------------------------
    optim="paged_adamw_32bit",
    learning_rate=1e-4,
    lr_scheduler_type="linear",
    warmup_ratio=0.03,
    weight_decay=0.001,
    max_grad_norm=0.3,
    fp16=False,
    bf16=False,
    # --- evaluation --------------------------------------------------------
    evaluation_strategy="steps",
    eval_steps=50,
    per_device_eval_batch_size=4,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    #load_best_model_at_end=True,
)

In [None]:
# Compute the loss only on the completion, i.e. on everything after the
# end-of-prompt marker. The marker is taken from `EOP`, the same constant that
# `formatting_func` puts into the prompt, so the two cannot drift apart (a
# template that does not occur in the text would leave no tokens to train on).
collator = DataCollatorForCompletionOnlyLM(response_template=EOP, tokenizer=tokenizer)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
    data_collator=collator,
    # Called at every evaluation step; generates responses and scores them.
    compute_metrics=compute_metrics
    #neftune_noise_alpha=5,
)

In [14]:
# Start fine-tuning with the SFTTrainer configured above.
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdo-gl[0m ([33mdomgla[0m). Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Train Constraint,Train Distinct,Test Constraint,Test Distinct
50,1.2773,1.572276,1.0,0.531381,0.96,0.56338
100,1.2724,1.492942,1.0,0.668016,0.96,0.673307
150,0.923,1.422863,1.0,0.623656,1.0,0.565371
200,0.9454,1.354905,1.0,0.487671,0.96,0.633205


  0%|          | 0/1 [00:00<?, ?it/s]

[("Don't worry, sir. We'll have you home in no time.", 1.0), ("Don't forget to give my regards to Gene.", 1.0), ("Don't do this to me, don't do this to yourself.", 1.0), ("Don't worry about it.", 1.0), ("Don't worry. I'll get you something nice.", 1.0), ("don't worry. I'll give you some aspirin for that.", 1.0), ("Don't forget to sing the songs in the right key.", 1.0), ("don't worry. I'll explain everything. ", 1.0), ("Don't worry. I don't smoke pot.", 1.0), ("Don't worry.", 1.0)]


  0%|          | 0/1 [00:00<?, ?it/s]

[("Don't worry, I'll be back to work tomorrow. ", 1.0), ("Don't worry.  I'm sure he won't mind.", 1.0), ("Don't worry. I'll be there on time.", 1.0), ("Don't worry about it. I'll take it.", 1.0), ("Don't worry about it.", 1.0), ("Don't forget the plutocracy", 1.0), ("Don't forget to brush your tongue.", 1.0), ("Don't worry.", 1.0), ('100% fresh', 0.0), ("Don't worry, we'll find something. ", 1.0)]


  0%|          | 0/1 [00:00<?, ?it/s]

[("Don't forget about the President of Zimbabwe. He's the wealthiest.", 1.0), ("Don't worry. We'll be there on time.", 1.0), ("Don't mention it.", 1.0), ("Don't worry.  I won't tell anyone.", 1.0), ("Don't worry, I'll remember.", 1.0), ("Don't worry. I'll get him something.", 1.0), ("Don't worry. I'll tell them what to do.", 1.0), ("Don't push me around. I'll do it myself.", 1.0), ("Don't get me wrong, I'm not saying that they don't deserve it. I'm just saying that it's amazing to think about.", 1.0), ("Don't worry about it. Google is a great company.", 1.0)]


  0%|          | 0/1 [00:00<?, ?it/s]

[("Don't replace it. I'll take it. ", 1.0), ("Don't worry about it. We'll have plenty of food.", 1.0), ("Don't give up hope. There are plenty of other candidates.", 1.0), ("Don't mention it.", 1.0), ("Don't worry about it.", 1.0), ("Don't do this to me, baby. ", 1.0), ("don't worry. He's a nice guy.", 1.0), ("Don't get your hopes up.", 1.0), ("Don't forget the ice cream.", 1.0), ("Don't say that. ", 1.0)]


  0%|          | 0/1 [00:00<?, ?it/s]

[("Don't worry. We will have a good time.", 1.0), ("Don't worry, you can see it on youtube.", 1.0), ("Don't worry about that.  I'm sure you can swim.  I'm sure you can do anything you set your mind to.", 1.0), ("Don't worry. This community is very safe.", 1.0), ("Don't be so excited. We should prepare for the picnic first.", 1.0), ("Don't worry about it.", 1.0), ("Don't be angry with me. I have a bad memory.", 1.0), ("Don't forget about the police horses. They are used to chase criminals.", 1.0), ("Don't worry about it. I'll be fine.", 1.0), ("Don't worry about it.", 1.0)]


  0%|          | 0/1 [00:00<?, ?it/s]

[("Don't be ridiculous! I'm not going to cry. I'm not going to let you go. ", 1.0), ("Don't blame me. I'm just a victim of my own taste buds. ", 1.0), ("Don't worry about it. I am sure you will find it.", 1.0), (".... Don't worry about it. We'll get married in the backyard.", 1.0), ("Don't worry about it, I'm sure it's just a beer brand", 1.0), ("Don't worry about it. I'll talk to her.", 1.0), ("Don't worry, you will find someone else.", 1.0), ("Don't worry. She will like you.", 1.0), ("Don't worry. I'll give you a prescription.", 1.0), ("Don't worry. We'll find something.", 1.0)]


  0%|          | 0/1 [00:00<?, ?it/s]

[("Don't be so mean. I'm just not into cartoons.", 1.0), ("Don't worry. You'll be fine.", 1.0), ("Don't be so literal.", 1.0), ("Don't mention it.", 1.0), ("Don't worry. It's nothing serious. You have a common cold.", 1.0), ("Don't worry about it, I'm sure you'll be fine!", 1.0), ("Don't forget that it's also a great reminder that you should never judge a book by its cover.", 1.0), ("Don't mention it.", 1.0), ("Don't worry. I'll take care of it.", 1.0), ("Don't be so down. I'm sure you can pass the exams.", 1.0)]


  0%|          | 0/1 [00:00<?, ?it/s]

[("Don't forget to check out the other 1000+ radio stations on the app.", 1.0), ("don't worry. I'll help you with that. ", 1.0), ("Don't be so cynical.", 1.0), ("Don't worry. I'll see what I can do. Let me call my friend, John. He's a plumber. He might be able to help you.", 1.0), ("Don't worry. He's not here yet.", 1.0), ("....prudish? Don't worry. I'll help you pick something out.", 1.0), ("Don't mention it.", 1.0), ("Don't worry. I'm sure we'll find someone eventually.", 1.0), ("Don't ever change. ", 1.0), ("Don't worry. We'll be late.", 1.0)]


KeyboardInterrupt: 

In [16]:
# Evaluate the current model on 100 sampled test examples (constraint
# satisfaction and distinctness only).
compute_metrics(
    [],
    verbose=False,
    datasets={"test": test_dataset},
    n=100,
    eval_quality=False,
)

  0%|          | 0/4 [00:00<?, ?it/s]

[('17 year old girl, that is impressive, I wonder if she is still playing, I know that Babe Ruth is one of the best players in the history of the game, he was a great player', tensor(True)), ('I know that he was the only losing coach in the history of the University of Kansas.  I wonder if he was the only losing coach in the history of the world.', tensor(True)), ('I did not know that.  I wonder if he was the richest president in the world.  I know that the top three wealthiest presidents in American history were JFK, Washington, and Jefferson.', tensor(True)), ("10 times the size of the Earth! That's huge! I wonder if the sun is the biggest star in the galaxy.", tensor(True)), ("1 million people live there, it's the 4th largest city in the country", tensor(True)), ("I'm going to go. I'm the best player in the team.", tensor(True)), ('100% they are the largest software company in the world and the largest personal computer company in the world.', tensor(True)), ('I agree.  Did you know

{'test_success_59': 0.7900000214576721, 'test_distinct': 0.39714714714714716}

In [18]:
#trainer.save_model("../models/mistral_FT_2")

In [21]:
# Reload the tokenizer and attach the fine-tuned adapter to `model`.
# NOTE(review): `bnb_config` is rebuilt here but is only referenced by the
# commented-out model load below, so it has no effect in this cell.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=False,
    bnb_4bit_quant_type="nf4"
)
# NOTE(review): with the base-model load commented out, this cell depends on a
# `model` object left over from earlier cells. If that object was already passed
# to the SFTTrainer together with a peft_config, it presumably already carries
# adapter layers - confirm, or reload the base model on a fresh kernel first.
#model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config, cache_dir=os.getenv('CACHE_DIR'), device_map="auto")
#model.config.use_cache = False
# Same tokenizer setup as in the base-model cell (adds the '[PAD]' token).
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, cache_dir=os.getenv('CACHE_DIR'), padding_side="right")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
tokenizer.pad_token = '[PAD]'
model.resize_token_embeddings(len(tokenizer))
# NOTE(review): the adapter is read from the trainer's output_dir
# ("../models/mistral_FT"), while the commented-out save_model call above targets
# "../models/mistral_FT_2" - verify which directory actually holds the adapter.
model = PeftModel.from_pretrained(model, "../models/mistral_FT")
