In [None]:
# Install and import dependencies, and set environment variables.

%pip install accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.4.7

import os
import torch
from datasets import load_dataset, DatasetDict
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, pipeline, logging
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
import torch
import gc
import transformers

# Configure PyTorch's CUDA caching allocator through its environment variable.
# NOTE(review): presumably this has to be set before the first CUDA allocation
# in the process to take effect — confirm against the PyTorch memory-management docs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:1024"

In [25]:
# Load and split the dataset.

def load_custom_dataset():
    """Load the hand-labeled JSON file and partition it by each row's 'split' field.

    Returns:
        A plain dict with the keys 'train', 'val' and 'test' (in that order).
        Each value is the result of calling ``.filter(...)`` on the object
        returned by ``load_dataset``, keeping only the rows whose 'split'
        field equals that key. Callers in this notebook index each value
        with ``['train']`` to reach the rows.
    """
    raw = load_dataset('json', data_files='all_hand_labeled_split.json')

    def belongs_to(split_name):
        # Bind the split name at creation time so each predicate tests its own split.
        return lambda example: example['split'] == split_name

    return {name: raw.filter(belongs_to(name)) for name in ('train', 'val', 'test')}

# Module-level dataset dict; later cells read it when building the
# SplitManager and when selecting the training rows for the trainer.
dataset = load_custom_dataset()

In [None]:
# Debugging tool: reduce the dataset size.

# n = 1

# reduced_dataset = {}
# for split_name, dataset_dict in dataset.items():
#     if split_name == 'train':
#       reduced_data = dataset_dict['train']
#     else:
#       reduced_data = dataset_dict['train'].select(range(n))

#     reduced_dataset[split_name] = DatasetDict({'train': reduced_data})

# dataset = reduced_dataset

# print(dataset)

In [27]:
# Define string-matching methods to extract prompts and classifications from texts.

def convert_label(text):
    """Map the trailing word of `text` to a classification label.

    The last 11 characters are compared against "compliance" and the last 10
    against "rejection", after stripping whitespace, lower-casing and removing
    periods. Each window is one character longer than the word itself, which
    leaves room for a single leading space or a trailing period.

    Returns:
        'complied', 'rejected', or None when neither word is found at the end.
    """
    def normalized_tail(width):
        return text[-width:].strip().lower().replace('.', '')

    if normalized_tail(11) == "compliance":
        return 'complied'
    if normalized_tail(10) == "rejection":
        return 'rejected'
    return None

def extract_prompt(text):
    """Return `text` with its trailing classification word cut off.

    The number of characters removed matches the window that `convert_label`
    inspects: 11 for a 'complied' label, 10 for a 'rejected' label.

    Returns:
        The truncated text, or the sentinel 'FAILED TO EXTRACT PROMPT' when
        `convert_label` recognises no label.
    """
    chars_to_drop = {"complied": 11, "rejected": 10}.get(convert_label(text))
    if chars_to_drop is None:
        return 'FAILED TO EXTRACT PROMPT'
    return text[:-chars_to_drop]

def extract_classification(text):
    """Return the classification found at the end of `text`.

    Returns:
        'complied' or 'rejected' when `convert_label` recognises the ending.
        Otherwise the text and the normalised 11-character tail are printed
        for debugging and a failure sentinel string is returned.
    """
    label = convert_label(text)
    if label in ("complied", "rejected"):
        return label

    # Unrecognised ending: dump the text plus the value that was compared.
    print('FAILED TO EXTRACT CLASSIFICATION BELOW')
    print(text)
    comparator = text[-11:].strip().lower().replace('.', '')
    print('FAILED TO EXTRACT CLASSIFICATION ABOVE, COMPARATOR:', comparator)
    # NOTE(review): the sentinel says "PROMPT" although this is the
    # classification path; kept as-is in case anything matches on it.
    return 'FAILED TO EXTRACT PROMPT'

In [28]:
# Define how to prompt the model.

def prompt_model(model, tokenizer, prompts):
    """Generate a completion for every prompt, overwriting `prompts` in place.

    The model is switched to eval mode and `model.config.use_cache` is forced
    to True for the duration of the call; the previous `use_cache` value is
    restored afterwards, including when an exception propagates.

    Args:
        model: Model handed to the text-generation pipeline. Must expose
            `config.use_cache` and `eval()`.
        tokenizer: Tokenizer handed to the pipeline.
        prompts: Mutable list of prompt strings. Each entry is REPLACED by the
            pipeline's 'generated_text' for that prompt, or by
            'ERROR GETTING MODEL RESPONSE' if generation raised.

    Returns:
        The same list object that was passed in as `prompts`.
    """
    previous_use_cache = model.config.use_cache
    model.config.use_cache = True

    try:
        model.eval()
        pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=2500)
        with torch.no_grad():
            for i, prompt in enumerate(prompts):
                print(f'Processing prompt {i+1} of {len(prompts)}...')
                try:
                    prompts[i] = pipe(prompt)[0]['generated_text']
                except Exception as exc:
                    # Best-effort per prompt: record the failure and continue.
                    # Catching Exception (not a bare except) lets
                    # KeyboardInterrupt / SystemExit stop the loop.
                    print(f'Error getting model response for prompt {i+1}: {exc!r}')
                    prompts[i] = 'ERROR GETTING MODEL RESPONSE'
    finally:
        # Always hand the model back with its original cache setting.
        model.config.use_cache = previous_use_cache

    return prompts

In [23]:
# Define a class for tracking results for a given subset of the data.

class Result:
    """Responses, predictions and accuracy metrics for one list of prompts."""

    def __init__(self, X):
        """Store the prompts; every metric stays None until calculate_performance runs.

        Args:
            X: List of prompt strings. It is passed to `prompt_model`, which
                overwrites its entries in place.
        """
        self.X = X

        self.responses = None
        self.predictions = None
        self.num_invalid_predictions = None
        self.num_mismatches = None
        self.accuracy = None
        # Declared here so the attributes exist before any evaluation has run.
        self.accuracy_excluding_invalid_responses = None
        self.accuracy_complied = None
        self.accuracy_rejected = None

    @staticmethod
    def _print_examples(header, examples, limit=2):
        """Print `header` followed by at most `limit` examples, each with a separator."""
        print(header)
        for example in examples[:limit]:
            print(example)
            print('--------------------------------------')
        print()

    def calculate_performance(self, y_true, model, tokenizer):
        """Prompt the model with `self.X` and score the extracted classifications.

        A prediction is "valid" when it is 'complied' or 'rejected'. Sets:
            responses, predictions, num_invalid_predictions, num_mismatches,
            accuracy (invalid predictions count as errors),
            accuracy_excluding_invalid_responses, accuracy_complied and
            accuracy_rejected (the last three over valid predictions only).
        Any metric whose denominator is zero is reported as 1.

        Args:
            y_true: True labels, paired with the predictions by position.
            model: Passed through to `prompt_model`.
            tokenizer: Passed through to `prompt_model`.
        """
        self.responses = prompt_model(model, tokenizer, self.X)
        self.predictions = [extract_classification(response) for response in self.responses]

        # Keep only predictions that carry a recognised label, with their index.
        valid_predictions = [
            (pred, true, idx)
            for idx, (pred, true) in enumerate(zip(self.predictions, y_true))
            if pred in ('complied', 'rejected')
        ]
        total = len(self.predictions)
        self.num_invalid_predictions = total - len(valid_predictions)

        self.num_mismatches = sum(1 for pred, true, _ in valid_predictions if pred != true)
        if total > 0:
            self.accuracy = 1 - ((self.num_invalid_predictions + self.num_mismatches) / total)
        else:
            # No prompts at all: avoid dividing by zero and use the same
            # "empty denominator means 1" convention as the metrics below.
            self.accuracy = 1

        if len(valid_predictions) > 0:
            self.accuracy_excluding_invalid_responses = 1 - (self.num_mismatches / len(valid_predictions))
        else:
            self.accuracy_excluding_invalid_responses = 1

        # Per-class accuracy over valid predictions only.
        num_complied_true = sum(1 for _, true, _ in valid_predictions if true == 'complied')
        num_rejected_true = sum(1 for _, true, _ in valid_predictions if true == 'rejected')
        errors_complied = sum(1 for pred, true, _ in valid_predictions if true == 'complied' and pred != true)
        errors_rejected = sum(1 for pred, true, _ in valid_predictions if true == 'rejected' and pred != true)

        self.accuracy_complied = 1 - (errors_complied / num_complied_true if num_complied_true > 0 else 0)
        self.accuracy_rejected = 1 - (errors_rejected / num_rejected_true if num_rejected_true > 0 else 0)

        # Show the model responses themselves. Indexing self.responses (rather
        # than self.X) stays correct whether or not prompt_model returns the
        # same list object it was given.
        wrong_rejections = [self.responses[idx] for pred, true, idx in valid_predictions if true == 'rejected' and pred != true]
        self._print_examples(
            "\n>>>>>>>>>>>>>>>>>>>>>>>\nExamples of responses that are labeled 'rejected' but the model predicted 'complied':",
            wrong_rejections,
        )

        wrong_compliances = [self.responses[idx] for pred, true, idx in valid_predictions if true == 'complied' and pred != true]
        self._print_examples(
            "\n>>>>>>>>>>>>>>>>>>>>>>>\nExamples of responses that are labeled 'complied' but the model predicted 'rejected':",
            wrong_compliances,
        )

In [29]:
# Define a class for managing the train/val/test results.

class SplitManager:
    """Labels, prompts and per-method results for a single dataset split."""

    def __init__(self, split, dataset):
        """Collect the true labels and one Result per prompting method.

        Args:
            split: Split name used to index `dataset` (e.g. 'test').
            dataset: Mapping where `dataset[split]['train']` iterates over rows
                exposing the fields "tone", "zero_shot_instruction",
                "few_shot_instruction" and "CoT_instruction".
        """
        self.split = split

        def column(field):
            # A new list on every call: each Result must own its prompts,
            # because prompting overwrites them in place.
            return [row[field] for row in dataset[split]['train']]

        self.y = column("tone")

        prompt_field_by_method = (
            ('zero-shot',  "zero_shot_instruction"),
            ('few-shot',   "few_shot_instruction"),
            ('CoT',        "CoT_instruction"),
            ('finetuning', "zero_shot_instruction"),
        )
        self.results = {
            method: Result(column(field))
            for method, field in prompt_field_by_method
        }

    def print_result(self, result, shot, is_finetuned):
        """Print the metrics stored on `result` for this split.

        Args:
            result: Object carrying the computed metric attributes.
            shot: Name of the prompting method, shown in the header line.
            is_finetuned: Chooses "FINE-TUNED" or "BASE" for the header line.
        """
        model_description = "FINE-TUNED" if is_finetuned else "BASE"

        print(f'{self.split} SET, {model_description} MODEL, {shot}')
        n_items = len(self.y)
        print(f"% invalid items:                      {result.num_invalid_predictions / n_items}")
        print(f"% non-matching items:                 {result.num_mismatches / n_items}")
        print(f"Accuracy including invalid responses: {result.accuracy}")
        print(f"Accuracy on compliances:              {result.accuracy_complied}")
        print(f"Accuracy on rejections:               {result.accuracy_rejected}")
        print(f"Overall accuracy:                     {result.accuracy_excluding_invalid_responses}")

    def calculate_performance_for_prompting_method(self, method, model, tokenizer, is_finetuned):
        """Score `method` against this split's labels, then print the outcome."""
        chosen = self.results[method]
        chosen.calculate_performance(self.y, model, tokenizer)
        self.print_result(chosen, method, is_finetuned)

In [30]:
# Define how to "replenish" the dataset (as prompts are overwritten in place by their responses).

def replenish_dataset():
    """Reload the dataset and rebuild the module-level test SplitManager.

    Prompting overwrites each prompt list in place with the model's responses,
    so the prompts have to be rebuilt before a method is evaluated again.

    The `global` declaration is what makes this take effect: without it the
    two assignments only create function locals and the module-level
    `dataset` / `test_split_manager` keep their already-overwritten contents.
    """
    global dataset, test_split_manager
    dataset = load_custom_dataset()
    test_split_manager = SplitManager('test', dataset)

# Module-level manager for the 'test' split; execute_prompting_method reads
# this global when it runs an evaluation.
test_split_manager = SplitManager('test', dataset)

In [None]:
# Load the base model and its tokenizer.

# Identifier of the pretrained chat model that is evaluated and fine-tuned below.
base_model = "NousResearch/Llama-2-7b-chat-hf"

# 4-bit quantization configuration (NF4 quantization, float16 compute, no double quantization).
compute_dtype = torch.float16
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False,
)

# Load the model with 4-bit precision.
# NOTE(review): device_map={"": 0} presumably places the whole model on GPU 0 — confirm.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=quant_config,
    device_map={"": 0},
)
model.config.use_cache = False
model.config.pretraining_tp = 1

# Load the tokenizer; the EOS token doubles as the padding token and padding
# is applied on the right.
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [10]:
# Define a method for executing a given prompting method.

def execute_prompting_method(method, is_finetuned):
    """Evaluate one prompting method on the test split and print its metrics.

    Uses the module-level `model`, `tokenizer` and `test_split_manager`.

    Args:
        method: Key into the split manager's results (e.g. 'zero-shot').
        is_finetuned: Forwarded to the report to label the model as
            fine-tuned or base.
    """
    replenish_dataset()
    gc.collect()

    for manager in (test_split_manager,):
        print('\n========================================================\n')
        manager.calculate_performance_for_prompting_method(method, model, tokenizer, is_finetuned)

        # Re-run the extraction over every response. Its return value is not
        # used here; the call is kept because extract_classification prints
        # diagnostics for responses it cannot classify.
        for response in manager.results[method].responses:
            extract_classification(response)

In [None]:
# Get zero-shot performance for the non-finetuned model.
# Relies on the module-level `model`, `tokenizer` and `test_split_manager`
# created in earlier cells.

execute_prompting_method('zero-shot', is_finetuned=False)

In [None]:
# Get few-shot performance for the non-finetuned model.
# Relies on the module-level `model`, `tokenizer` and `test_split_manager`
# created in earlier cells.

execute_prompting_method('few-shot', is_finetuned=False)

In [None]:
# Get chain-of-thought performance for the non-finetuned model.
# Relies on the module-level `model`, `tokenizer` and `test_split_manager`
# created in earlier cells.

execute_prompting_method('CoT', is_finetuned=False)

In [16]:
# Force clean the PyTorch cache before configuring training.
# (This cell only runs garbage collection and empties the CUDA cache; it does
# not load any data.)

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Configure training details.

# PEFT: Parameter-Efficient Fine-Tuning (LoRA adapter configuration).
peft_params = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM"
)

# Trainer hyperparameters: one epoch, per-device batch size 1, no gradient
# accumulation, constant learning-rate schedule, TensorBoard reporting, and
# ./results as the output directory.
training_params = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=50,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    report_to="tensorboard",
    remove_unused_columns=False,
    gradient_checkpointing=True
)

# Supervised fine-tuning on the rows of the 'train' split, using each row's
# "zero_shot_instruction" field as the training text.
# NOTE(review): confirm that this field contains the expected answer as well
# as the prompt; otherwise the model is never shown the target label.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset['train']['train'],
    peft_config=peft_params,
    dataset_text_field="zero_shot_instruction",
    max_seq_length=2048,
    tokenizer=tokenizer,
    args=training_params
)

In [18]:
# Force clean the PyTorch cache right before training starts: run Python
# garbage collection, then empty the CUDA cache.

gc.collect()
torch.cuda.empty_cache()

In [None]:
# Train the model with the trainer configured above (LoRA adapters on the
# 4-bit quantized base model).

trainer.train()

In [None]:
# Get zero-shot performance for the fine-tuned model.
# NOTE(review): this reuses the 'zero-shot' Result that the base-model run
# already used; prompting overwrites prompts in place, so confirm the prompts
# are fresh before trusting these numbers.
# NOTE(review): evaluation uses the module-level `model`; presumably training
# attached the LoRA layers to that same object — confirm, or evaluate
# `trainer.model` instead.

execute_prompting_method('zero-shot', is_finetuned=True)

In [None]:
from tensorboard import notebook
# Launch TensorBoard inside the notebook, pointed at the run logs under the
# trainer's output directory.
log_dir = "results/runs"
notebook.start(f"--logdir {log_dir} --port 4000")