# NOT FOR PUBLIC

## HF/WNB

In [1]:
import os

# The cache variables are set BEFORE `huggingface_hub` is imported: the library
# resolves its cache locations from the environment at import time, so assigning
# them after the import (as this cell used to do) has no effect on it.
# `setdefault` keeps a value that was already exported in the shell, and the
# home directory is resolved at runtime instead of hardcoding one user's path.
os.environ.setdefault("HF_HOME", os.path.expanduser("~/.cache/huggingface/"))
os.environ.setdefault(
    "HUGGINGFACE_HUB_CACHE", os.path.join(os.environ["HF_HOME"], "hub")
)

import huggingface_hub

# SECURITY: credentials must never be written into the notebook. The Hugging Face
# token and the W&B API key that were previously hardcoded in this cell have been
# exposed and must be revoked and re-issued.
# Export HF_TOKEN (and WANDB_API_KEY) in the environment before starting the kernel;
# when HF_TOKEN is missing, `login` falls back to its interactive prompt.
HF_AUTH = os.environ.get("HF_TOKEN")
huggingface_hub.login(token=HF_AUTH)

import wandb

# Set the wandb project where this run will be logged
WANDB_PROJECT = "llama2_sft_fomc"
os.environ["WANDB_PROJECT"] = WANDB_PROJECT
# Do not upload trained model checkpoints to wandb (our models are too large)
os.environ["WANDB_LOG_MODEL"] = "false"
# Turn off watch to log faster
os.environ["WANDB_WATCH"] = "false"
os.environ["WANDB_NOTEBOOK_NAME"] = "llama2_sft"
# `wandb.login()` picks up WANDB_API_KEY from the environment (or a stored login)
# and prompts otherwise, so no key needs to live in the notebook.
wandb.login()

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/AD/gmatlin3/.cache/huggingface/token
Login successful


[34m[1mwandb[0m: Currently logged in as: [33mglennmatlin[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

# TODOS

- See if `offload_folder="tmp"` can help when loading models?

- `input_ids = torch.tensor(input_ids).long()` # TODO: extra code to ensure that input_ids is a PyTorch tensor ... is unneeded

- How to unset a parameter in the pre-loaded config?
> /home/AD/gmatlin3/.conda/envs/conference/lib/python3.8/site-packages/transformers/generation/configuration_utils.py:362: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `None` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.

# [PUBLIC]<br>Supervised Fine-Tuning of Llama2 on FOMC

## IMPORTS

### Standard Libraries

In [2]:
import os
import gc
import logging
import time
import fire
from pathlib import Path
from functools import partial
from typing import NamedTuple, List, Type
from IPython.display import display
from dataclasses import dataclass, field
from datetime import datetime

### Third-Party Libraries

In [3]:
import uuid
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score
from tqdm.auto import tqdm

### PyTorch and HuggingFace Libraries

In [4]:
import torch
import bitsandbytes as bnb
import evaluate
from datasets import Dataset, DatasetDict, load_dataset
from trl import SFTTrainer
from transformers import logging as hf_logging
from transformers.trainer_callback import TrainerCallback
from transformers import set_seed as transformers_set_seed
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    TrainingArguments
)
from peft import (
    PeftModel,
    AutoPeftModelForCausalLM,
    LoraConfig,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from dataclasses import dataclass, fields, make_dataclass

## FUNCTIONS

### Logging Functions

In [5]:
def generate_uid(id_length=8, dt_format="%y%m%d"):
    """
    Build a short run identifier of the form ``<random>_<date>``.

    Parameters:
    - id_length : int - Number of leading characters kept from a random UUID4 string.
    - dt_format : str - `strftime` pattern applied to the current local time.

    Returns:
    - str - The truncated UUID, an underscore, and the formatted timestamp.
    """
    random_part = str(uuid.uuid4())[:id_length]
    stamp = datetime.now().strftime(dt_format)
    return f"{random_part}_{stamp}"

def create_logger(name="llama2_finetune", level=logging.DEBUG):
    """
    Return the named logger, attaching console and file handlers on first use.

    Parameters:
    - name : str - Name of the logger to fetch or configure.
    - level : int - Level applied to the logger, to both handlers, and to the
      `transformers` logging verbosity.

    Returns:
    - logging.Logger - The configured logger. Calling this again with the same
      name returns the same logger without adding duplicate handlers.
    """
    logger = logging.getLogger(name)
    # Check this logger's OWN handlers. `hasHandlers()` also reports handlers that
    # live on ancestor loggers (e.g. a root handler installed by another library),
    # which would make this function skip configuration entirely.
    if not logger.handlers:
        logger.setLevel(level)
        hf_logging.set_verbosity(level)

        formatter = logging.Formatter("%(name)s - %(levelname)s - %(message)s")
        handlers = (
            logging.StreamHandler(),
            logging.FileHandler("llama2_finetune.log"),
        )
        for handler in handlers:
            handler.setLevel(level)
            handler.setFormatter(formatter)
            logger.addHandler(handler)
    return logger

### Label Functions

In [6]:
# Class id -> class name for the FOMC communication task.
FOMC_COMMUNICATION_MAPPING = {0: "DOVISH", 1: "HAWKISH", 2: "NEUTRAL"}


def decode_label(label_number):
    """
    Translate a class id into its upper-case class name.

    Returns "UNDEFINED" for ids that are not in FOMC_COMMUNICATION_MAPPING.
    """
    return FOMC_COMMUNICATION_MAPPING.get(label_number, "undefined").upper()


def encode_label(label_name):
    """
    Translate a class name into its class id (case-insensitive).

    Returns -1 for names that are not in FOMC_COMMUNICATION_MAPPING.
    """
    name_to_id = {
        name.upper(): number for number, name in FOMC_COMMUNICATION_MAPPING.items()
    }
    # The mapping stores upper-case names, so normalise the query to upper case.
    # (Lower-casing it, as before, could never match and always produced -1.)
    return name_to_id.get(label_name.strip().upper(), -1)


# TODO: have extract_label use our encoding/mapping
def extract_label(text_output, label_list=("DOVISH", "HAWKISH", "NEUTRAL"), E_INST="[/INST]"):
    """
    Extracts the label from the text output from a large language model.

    Parameters:
    - text_output : str - Decoded model output, optionally still containing the prompt.
    - label_list : sequence of str - Candidate labels; the returned value is an
      index into this sequence.
    - E_INST : str - 'End of instruction' marker. Only text after its first
      occurrence is searched; when it is absent the whole text is searched.

    Returns:
    - int - Index of the first entry of `label_list` (in list order) that occurs
      in the response, compared case-insensitively, or -1 when none occurs.
    """
    marker_pos = text_output.find(E_INST)
    if marker_pos == -1:
        # No marker present: search the complete text. Slicing with the -1 that
        # `find` returns would silently drop the first characters of the text.
        response = text_output
    else:
        response = text_output[marker_pos + len(E_INST):]
    # Upper-case both sides for a case-insensitive substring search.
    response = response.strip().upper()

    for index, label in enumerate(label_list):
        if label.upper() in response:
            return index
    # None of the labels were found.
    return -1

### Metrics Functions

In [7]:
def compute_metrics(eval_pred, tokenizer, label_list=("DOVISH", "HAWKISH", "NEUTRAL")):
    """
    Decode generated predictions to text, map them to class ids, and score them.

    Parameters:
    - eval_pred : pair - `(predictions, true_labels)`; the predictions are decoded
      with `tokenizer.batch_decode`, so they are expected to be token id sequences.
    - tokenizer : Tokenizer used to decode the predictions.
    - label_list : sequence of str - Candidate labels forwarded to `extract_label`.

    Returns:
    - dict - "accuracy", "f1_score" and "missing" as computed by `evaluate_predictions`.
    """
    predictions, true_labels = eval_pred
    # Decode the predictions to text (this previously referenced an undefined name).
    decoded_preds = tokenizer.batch_decode(
        sequences=predictions, skip_special_tokens=True
    )
    predicted_labels = [
        extract_label(text, label_list=label_list) for text in decoded_preds
    ]
    accuracy_perc, f1_score_perc, missing_perc = evaluate_predictions(
        true_labels, predicted_labels
    )
    return {
        "accuracy": accuracy_perc,
        "f1_score": f1_score_perc,
        "missing": missing_perc,
    }

def evaluate_predictions(true_labels, predicted_labels):
    """
    Score predicted class ids against the ground truth.

    Parameters:
    - true_labels : Ground-truth class ids.
    - predicted_labels : list - Predicted class ids, where -1 marks "no label found".

    Returns:
    - tuple - (accuracy, weighted F1, share of -1 predictions). Accuracy and F1 are
      the raw scikit-learn scores; the share of -1 predictions is scaled by 100.
    """
    return (
        accuracy_score(true_labels, predicted_labels),
        f1_score(true_labels, predicted_labels, average="weighted"),
        (predicted_labels.count(-1) / len(predicted_labels)) * 100.0,
    )

def log_trainable_parameters(model, logger: logging.Logger):
    """
    Logs the number of trainable parameters in the model.

    Parameters:
    - model : torch.nn.Module - The model to log.
    - logger : logging.Logger - Logger to use for logging the info.

    Returns:
    - dict - "trainable_params", "total_params" and "trainable_percent", so that
      callers which record the result (e.g. `create_model`) get real values
      instead of None.
    """
    trainable_params = 0
    total_params = 0
    # Single pass over the parameters instead of iterating the model twice.
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    # A model without parameters reports 0% rather than dividing by zero.
    trainable_percent = 100 * trainable_params / total_params if total_params else 0.0

    logger.info(
        f"Trainable params: {trainable_params} || "
        f"All params: {total_params} || "
        f"Trainable%: {trainable_percent}"
    )
    return {
        "trainable_params": trainable_params,
        "total_params": total_params,
        "trainable_percent": trainable_percent,
    }


def log_dtypes(model, logger: logging.Logger):
    """
    Logs the data types of the model parameters.

    Parameters:
    - model : torch.nn.Module - The model to log.
    - logger : logging.Logger - Logger to use for logging the info.

    Returns:
    - dict - Maps each parameter dtype to the number of elements stored in that
      dtype, so that callers which record the result get real values instead of None.
    """
    element_counts = {}
    for param in model.parameters():
        element_counts[param.dtype] = element_counts.get(param.dtype, 0) + param.numel()

    total = sum(element_counts.values())

    for dtype, count in element_counts.items():
        # Guard the share computation for the degenerate case of only empty tensors.
        share = 100 * count / total if total else 0.0
        logger.info(f"{dtype}: {count} ({share:.2f}%)")

    return element_counts


def merge_evaluation_results(
    baseline_results: dict, final_results: dict
) -> pd.DataFrame:
    """
    Merge evaluation results for comparison.

    Parameters:
    - baseline_results : dict - Metric name -> value before fine-tuning.
    - final_results : dict - Metric name -> value after fine-tuning.

    Returns:
    - pd.DataFrame - Columns "Metric", "Baseline" and "After Fine-tuning", one row
      per metric. Metrics missing on one side are filled with "N/A". Rows follow
      the key order of `baseline_results`, then any metrics only in `final_results`.
    """
    # dict.fromkeys de-duplicates while keeping insertion order; a set would make
    # the row order differ from run to run.
    metrics = list(dict.fromkeys([*baseline_results, *final_results]))

    return pd.DataFrame(
        {
            "Metric": metrics,
            "Baseline": [baseline_results.get(m, "N/A") for m in metrics],
            "After Fine-tuning": [final_results.get(m, "N/A") for m in metrics],
        }
    )

### Dataset Processing Functions

In [8]:
def load_dataset_split(args, logger, split: str):
    """
    Load one split of the task dataset.

    Parameters:
    - args : Args - Provides `organization` and `task_name`, joined as
      "<organization>/<task_name>" to form the dataset identifier.
    - logger : logging.Logger - Receives a debug message before loading.
    - split : str - Name of the split to return.

    Returns:
    - The requested split, selected from the loaded dataset.
    """
    logger.debug(f"Loading {split} dataset...")
    dataset_id = f"{args.organization}/{args.task_name}"
    all_splits = load_dataset(dataset_id)
    return all_splits[split]


def split_dataset(train_dataset, train_ratio=0.7, seed=42):
    """
    Partition a Hugging Face dataset into a training part and a validation part.

    Parameters:
    - train_dataset: Hugging Face dataset to split
    - train_ratio: Fraction of the data kept for training (strictly between 0 and 1)
    - seed: Seed for reproducibility

    Returns:
    - (train_set, val_set): the "train" and "test" parts produced by
      `train_test_split`, in that order

    Raises:
    - ValueError: if `train_ratio` is not strictly between 0 and 1
    """
    # Reject ratios that would leave one of the two parts empty
    if train_ratio >= 1 or train_ratio <= 0:
        raise ValueError("Train ratio must be between 0 and 1")

    parts = train_dataset.train_test_split(test_size=1 - train_ratio, seed=seed)
    # The validation part is stored under the "test" key of the split result
    return parts["train"], parts["test"]

def _preprocess_dataset_batch_(
    batch,
    args,
    logger: logging.Logger,
    tokenizer: AutoTokenizer,
):
    """
    Add label columns, build the prompt text and tokenize it for one batch.

    Parameters:
    - batch: dict - Batch containing columns as lists.
    - args: Args - Supplies the column names, prompt pieces and tokenizer options.
    - logger: logging.Logger - Accepted for a uniform signature; not used here.
    - tokenizer: AutoTokenizer - Tokenizer for the model.

    Returns:
    - The same batch, extended with the encoded-label, decoded-label and text
      columns plus "input_ids" and "attention_mask".

    Raises:
    - ValueError - If a prompt, a context or a decoded label is empty or whitespace.
    """
    # Copy the raw label column under the configured "encoded" name
    batch[args.encoded_label_field] = batch[args.label_field]
    # Human-readable label for every encoded label
    batch[args.response_field] = [
        decode_label(encoded) for encoded in batch[args.encoded_label_field]
    ]

    # Both prompts must contain something other than whitespace
    if not (args.instruction_prompt.strip() and args.system_prompt.strip()):
        raise ValueError("All prompts (instruction, system) must be non-empty strings.")

    # Every context and every decoded label must be non-blank as well
    contexts_ok = all(text.strip() for text in batch[args.context_field])
    if not (contexts_ok and all(text.strip() for text in batch[args.response_field])):
        raise ValueError("All fields (context, response) must be non-empty strings.")

    # Prompt layout: B_INST, B_SYS, system prompt, E_SYS, instruction, context, E_INST.
    # NOTE(review): the decoded label is not appended to this text - confirm that
    # is intended when this column is used as the training text.
    batch[args.text_field] = [
        "".join(
            (
                args.B_INST,
                args.B_SYS,
                args.system_prompt,
                args.E_SYS,
                args.instruction_prompt,
                context,
                args.E_INST,
            )
        )
        for context in batch[args.context_field]
    ]

    encoding = tokenizer(
        batch[args.text_field],
        max_length=args.max_seq_length,
        truncation=args.truncation,
        padding=args.padding,
    )
    for column in ("input_ids", "attention_mask"):
        batch[column] = encoding[column]

    return batch


def preprocess_dataset(
    args, logger: logging.Logger, tokenizer: AutoTokenizer, dataset: Dataset
):
    """
    Format, tokenize, length-filter and shuffle a dataset for fine-tuning.

    Parameters:
    - args: Args - Supplies formatting options, `max_seq_length` and `seed`.
    - logger: logging.Logger - Receives debug messages for each stage.
    - tokenizer: AutoTokenizer - Tokenizer for the model.
    - dataset: Dataset - Dataset to preprocess.

    Returns:
    - The mapped, filtered and shuffled dataset.
    """
    logger.debug(f"Preprocessing dataset...")

    # Batched mapping is used because new columns are assigned per batch
    batch_fn = partial(
        _preprocess_dataset_batch_,
        args=args,
        logger=logger,
        tokenizer=tokenizer,
    )
    mapped = dataset.map(batch_fn, batched=True)

    logger.debug("Filtering dataset to ensure we are below the maximum sequence length")

    def _within_limit(sample):
        return len(sample["input_ids"]) <= args.max_seq_length

    kept = mapped.filter(_within_limit)

    logger.debug("Shuffling the data using our seed value")
    return kept.shuffle(seed=args.seed)

### Model & Tokenizer Functions

In [9]:
def create_tokenizer(args, logger):
    """
    Load the tokenizer for `args.model_id` and set its padding token.

    Parameters:
    - args: Args - Supplies `model_id` and the `EOS` token string.
    - logger: logging.Logger - Accepted for a uniform signature; not used here.

    Returns:
    - The slow (non-fast) tokenizer with `pad_token` set to `args.EOS`.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(args.model_id, use_fast=False)
    # Pad with the configured end-of-sequence token
    loaded_tokenizer.pad_token = args.EOS
    return loaded_tokenizer


def create_model(args, logger, bnb_config=None, peft_config=None):
    """
    Load the causal language model and optionally quantize it and add PEFT adapters.

    Parameters:
    - args: Args - Supplies `model_id`, `device_map`, `cuda_max_memory` and
      `bnb_compute_dtype` (used as the torch dtype).
    - logger: logging.Logger - Receives debug/info messages about the model.
    - bnb_config: BitsAndBytesConfig or None - When given, the model is loaded quantized.
    - peft_config: PEFT configuration or None - When given, the loaded model is
      wrapped with `get_peft_model`.

    Returns:
    - The loaded (and possibly PEFT-wrapped) model with `use_cache` disabled and
      `pretraining_tp` set to 1 on its config.
    """
    info_data = []

    load_kwargs = {
        "device_map": args.device_map,
        "max_memory": args.cuda_max_memory,
        "torch_dtype": args.bnb_compute_dtype,
    }
    if bnb_config is not None:
        logger.debug("Creating ModelforCausalLM using BitsAndBytes ...")
        # The 4-bit/8-bit switches already live inside `bnb_config`. Passing
        # `load_in_4bit`/`load_in_8bit` next to `quantization_config` is redundant,
        # and newer transformers releases reject the combination with a ValueError.
        load_kwargs["quantization_config"] = bnb_config
    else:
        logger.debug("Creating ModelforCausalLM ...")
    model = AutoModelForCausalLM.from_pretrained(args.model_id, **load_kwargs)

    logger.debug("Logging the model's memory footprint ...")
    info_data.append(["Memory Footprint", model.get_memory_footprint()])
    logger.debug("Logging the model's Dtypes ...")
    info_data.append(["Dtypes init", log_dtypes(model, logger)])

    if peft_config is not None:
        logger.debug("Model Dtypes before applying PEFT config ...")
        info_data.append(["Dtypes Before PEFT Config", log_dtypes(model, logger)])

        model = get_peft_model(model, peft_config)

        logger.debug("Model Dtypes after applying PEFT config ...")
        info_data.append(["Dtypes After PEFT Config", log_dtypes(model, logger)])

        logger.debug("Information about the percentage of trainable parameters...")
        info_data.append(
            ["Trainable Parameters", log_trainable_parameters(model, logger)]
        )

    if bnb_config or peft_config:
        logger.debug(
            "Converting the info_data list into a pandas DataFrame and saving it..."
        )
        summary = pd.DataFrame(info_data, columns=["Info", "Value"])
        logger.debug("\n%s", summary.to_string(index=False))

    model.config.use_cache = False
    model.config.pretraining_tp = 1
    return model

### Configuration Functions

In [10]:
def create_bnb_config(args):
    """
    Configures BitsAndBytes based on the arguments provided.

    Parameters:
    - args: Args - Supplies `load_in_4bit`, `load_in_8bit`, `bnb_use_double_quant`,
      `bnb_quant_type` and `bnb_compute_dtype`.

    Returns:
    - BitsAndBytesConfig - Quantization configuration for model loading.
    """
    # Only options that BitsAndBytesConfig defines are passed. The former
    # `bnb_8bit_use_double_quant`, `bnb_8bit_quant_type` and `bnb_8bit_compute_dtype`
    # arguments are not BitsAndBytesConfig options, so they were dropped.
    return BitsAndBytesConfig(
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit,
        bnb_4bit_use_double_quant=args.bnb_use_double_quant,
        bnb_4bit_quant_type=args.bnb_quant_type,
        bnb_4bit_compute_dtype=args.bnb_compute_dtype,
    )

def create_peft_config(args, modules: List[str]) -> LoraConfig:
    """
    Build the LoRA configuration used for parameter-efficient fine-tuning.

    Parameters:
    - args : Args - Supplies `lora_r`, `lora_alpha` and `lora_dropout`
    - modules : List[str] - Names of the modules that receive adapters

    Returns:
    - LoraConfig - Causal-LM LoRA configuration with bias set to "none"
    """
    lora_settings = {
        "target_modules": modules,
        "r": args.lora_r,
        "lora_alpha": args.lora_alpha,
        "lora_dropout": args.lora_dropout,
        "bias": "none",
        "task_type": TaskType.CAUSAL_LM,
    }
    return LoraConfig(**lora_settings)

def create_generation_config(args) -> GenerationConfig:
    """
    Build a GenerationConfig from the generation-related attributes of `args`.

    Each option listed below is read from the attribute of the same name on `args`
    and passed through unchanged.
    """
    option_names = (
        "max_new_tokens",
        "do_sample",
        "num_beams",
        "temperature",
        "top_p",
        "top_k",
        "penalty_alpha",
        "min_length",
        "use_cache",
        "repetition_penalty",
        "length_penalty",
        "num_return_sequences",
        "early_stopping",
        "remove_invalid_values",
        "no_repeat_ngram_size",
    )
    options = {name: getattr(args, name) for name in option_names}
    return GenerationConfig(**options)
# generation_config.save_pretrained(args.output_dir, "generation_config.json")
# ## You could then use the named generation config file to parameterize generation
# generation_config = GenerationConfig.from_pretrained(args.output_dir, "generation_config.json")
# outputs = model.generate(**inputs, generation_config=generation_config)
# tokenizer.batch_decode(outputs, skip_special_tokens=True)

def create_training_arguments(args) -> TrainingArguments:
    """
    Build the TrainingArguments from the training-related attributes of `args`.

    Each option listed below is read from the attribute of the same name on `args`
    and passed through unchanged; `push_to_hub` is always disabled.
    """
    forwarded_options = (
        "output_dir",
        "fp16",
        "bf16",
        "per_device_train_batch_size",
        "per_device_eval_batch_size",
        "gradient_accumulation_steps",
        "max_grad_norm",
        "weight_decay",
        "optim",
        "learning_rate",
        "lr_scheduler_type",
        "num_train_epochs",
        "max_steps",
        "warmup_ratio",
        "save_safetensors",
        "load_best_model_at_end",
        "evaluation_strategy",
        "logging_dir",
        "report_to",
        "save_strategy",
        "save_steps",
        "logging_strategy",
        "logging_steps",
    )
    options = {name: getattr(args, name) for name in forwarded_options}
    return TrainingArguments(**options, push_to_hub=False)

### Utility Functions

In [11]:
def memory_cleanup():
    """
    Empty VRAM.

    Removes the notebook-level names `trainer`, `model` and `pipe` (when they
    exist), runs the garbage collector, and releases PyTorch's cached CUDA memory.

    Objects that are still referenced from somewhere else (another variable, a
    container, ...) stay alive; only the three global names are dropped here.
    """
    # `del name` inside a function makes `name` a local variable, so the previous
    # implementation raised UnboundLocalError as soon as one of the globals existed.
    # Dropping the entries from the module namespace removes the references.
    namespace = globals()
    for name in ("trainer", "model", "pipe"):
        namespace.pop(name, None)

    # Collect first so the freed tensors can actually be returned by empty_cache().
    gc.collect()
    torch.cuda.empty_cache()
    gc.collect()

def set_seeds(args, logger):
    """
    Seed transformers and torch (CPU and CUDA) with `args.seed`.

    Parameters:
    - args: Args - Supplies the integer `seed`.
    - logger: logging.Logger - Receives a debug message with the seed value.
    """
    seed = args.seed
    logger.debug(f"Setting reproducibility seed: '{seed}'")
    transformers_set_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

def configure_cuda_args(args, logger):
    """
    Configure the parameter arguments using the system's CUDA information.

    Parameters:
    - args: Args - `cuda_n_gpus` and `cuda_max_memory` are filled in when they are None.
    - logger: logging.Logger - Receives debug messages describing what was set.

    Returns:
    - The same `args` object. `cuda_max_memory` becomes a dict mapping each device
      index to "<N>GB", where N is that device's free memory in whole GiB minus a
      2 GiB reserve (never below 0). It stays None when no CUDA device is present.
    """
    if args.cuda_n_gpus is None:
        args.cuda_n_gpus = torch.cuda.device_count()
        logger.debug(f"args.cuda_n_gpus now defined: {args.cuda_n_gpus}")
    else:
        logger.debug("args.cuda_n_gpus already defined.")

    if args.cuda_max_memory is None:
        if args.cuda_n_gpus:
            reserve_gib = 2
            budgets = {}
            for device_index in range(args.cuda_n_gpus):
                # Query every device separately; free memory can differ per GPU.
                free_bytes, _total_bytes = torch.cuda.mem_get_info(device_index)
                free_gib = int(free_bytes / 1024 ** 3)
                budgets[device_index] = f"{max(free_gib - reserve_gib, 0)}GB"
            args.cuda_max_memory = budgets
            logger.debug(f"args.cuda_max_memory now defined: {args.cuda_max_memory}")
        else:
            # Without a CUDA device there is nothing to query.
            logger.debug("No CUDA devices found; args.cuda_max_memory left undefined.")
    else:
        logger.debug("args.cuda_max_memory already defined.")

    return args

def get_max_seq_length(model: Type[torch.nn.Module]) -> int:
    """
    Read the maximum sequence length from the model's config.

    The config attributes "n_positions", "max_position_embeddings" and
    "seq_length" are tried in that order; the first truthy value wins.

    Parameters:
    - model : torch.nn.Module - The model to inspect

    Returns:
    - int - The length found on the config, or 1024 when none of the attributes
      is set to a truthy value
    """
    conf = model.config
    candidates = ("n_positions", "max_position_embeddings", "seq_length")
    values = (getattr(conf, attribute, None) for attribute in candidates)
    found = next((value for value in values if value), None)

    if found:
        print(f"Found max sequence length: {found}")
        return found

    # Fall back to 1024 if no length attribute is found
    fallback = 1024
    print(f"Using default max sequence length: {fallback}")
    return fallback

def setup_output_directory(args):
    """
    Sets up the output directory for saving model checkpoints and other outputs.

    Creates `args.checkpoint_dir` and `args.mergepoint_dir`, including missing
    parent directories. Directories that already exist are left untouched.

    Parameters:
    - args: Args - Supplies `checkpoint_dir` and `mergepoint_dir`, each either a
      `pathlib.Path` or a plain path string.
    """
    for directory in (args.checkpoint_dir, args.mergepoint_dir):
        # Wrap in Path so string values (the type these fields are annotated
        # with) work as well as Path objects.
        Path(directory).mkdir(mode=0o777, parents=True, exist_ok=True)
    

def find_all_linear_names(model: Type[torch.nn.Module], bits: int) -> List[str]:
    """
    Collect the leaf names of the model's linear layers for a given bit width.

    Parameters:
    - model : torch.nn.Module - The model to inspect
    - bits : int - 4 selects bitsandbytes 4-bit linears, 8 selects bitsandbytes
      8-bit linears, any other value selects `torch.nn.Linear`

    Returns:
    - List[str] - Unique last components of the matching module names, with
      "lm_head" excluded; the order of the list is not defined
    """
    # Pick the layer class that corresponds to the requested bit width
    if bits == 4:
        linear_cls = bnb.nn.Linear4bit
    elif bits == 8:
        linear_cls = bnb.nn.Linear8bitLt
    else:
        linear_cls = torch.nn.Linear

    # The last dotted component is the layer's own name (for an undotted name
    # that is simply the whole name)
    leaf_names = {
        qualified_name.split(".")[-1]
        for qualified_name, module in model.named_modules()
        if isinstance(module, linear_cls)
    }

    # Drop 'lm_head' if present (specific to 16-bit scenarios)
    leaf_names.discard("lm_head")

    return list(leaf_names)

### Trainer Functions

In [12]:
def supervised_fine_tuning(args, logger):
    """
    Run LoRA supervised fine-tuning end to end and save the adapters.

    Steps: seed everything, load the (optionally quantized) base model once,
    prepare it for k-bit training, add the LoRA adapters, preprocess and split
    the train dataset 70/30, train with `SFTTrainer`, then save the adapters and
    the tokenizer to `args.checkpoint_dir`.

    Parameters:
    - args: Args - Full run configuration (model, quantization, LoRA, data,
      training and directory settings).
    - logger: logging.Logger - Receives progress messages.

    Raises:
    - Whatever `trainer.train()` raises; the error is logged and re-raised unchanged.
    """
    logger.info("Starting Supervised Fine Tuning ...")
    set_seeds(args, logger)

    # BitsAndBytes Setup
    logger.debug("Creating BitsAndBytesConfig ...")
    bnb_config = create_bnb_config(args)

    # Model & Tokenizer Setup
    logger.debug("Creating Tokenizer ...")
    tokenizer = create_tokenizer(args, logger)
    # The base model is loaded a single time: the same instance is inspected for
    # its linear layers and then wrapped with the adapters.
    logger.debug("Creating Base Model ...")
    model = create_model(args, logger, bnb_config=bnb_config, peft_config=None)
    assert args.max_seq_length == get_max_seq_length(model), (
        "args.max_seq_length does not match the model's maximum sequence length"
    )
    logger.debug("Creating GenerationConfig ...")
    model.generation_config = create_generation_config(args)

    # Parameter Efficient Fine-Tuning Setup
    logger.debug(
        "Get module names for the linear layers where we add LORA adapters..."
    )
    # Match the layer class to the quantization that was actually requested.
    adapter_bits = 4 if args.load_in_4bit else 8 if args.load_in_8bit else 16
    layers_for_adapters = find_all_linear_names(model, adapter_bits)
    logger.debug(f"Layers for Adapters: {layers_for_adapters}")
    logger.debug(
        "Create PEFT config using the adapted layers for PEFT..."
    )
    peft_config = create_peft_config(args, layers_for_adapters)

    # Prepare for K-Bit Training
    # This must run BEFORE the adapters are added: it freezes every parameter of
    # the model it is given, which would include already-attached LoRA weights.
    logger.debug("Using the prepare_model_for_kbit_training method from PEFT...")
    logger.debug(f"Gradient Checkpointing == {args.gradient_checkpointing}")
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=args.gradient_checkpointing
    )
    logger.debug("Model Dtypes after preparing for kbit training ...")
    log_dtypes(model, logger)

    logger.debug("Applying PEFT config ...")
    model = get_peft_model(model, peft_config)
    logger.debug("Model Dtypes after applying PEFT config ...")
    log_dtypes(model, logger)
    logger.debug("Information about the percentage of trainable parameters...")
    log_trainable_parameters(model, logger)

    # Dataset Setup
    logger.debug("Loading train dataset ...")
    train_dataset = load_dataset_split(args=args, logger=logger, split="train")
    logger.debug("Preprocessing train dataset ...")
    train_dataset = preprocess_dataset(
        args=args, logger=logger, tokenizer=tokenizer, dataset=train_dataset
    )
    train_set, eval_set = split_dataset(train_dataset, train_ratio=0.7, seed=args.seed)

    # Supervised Fine-Tuning
    logger.debug("Running Supervised Fine Tuning ...")
    logger.debug("Creating TrainingArguments ...")
    training_arguments = create_training_arguments(args)
    logger.debug("Creating Trainer Callbacks ...")
    callbacks = [PeftSavingCallback()]
    logger.debug("Creating `SFTTrainer` ...")
    # `predict_with_generate` is not an SFTTrainer argument and made the
    # constructor fail, so it is no longer passed.
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_arguments,
        train_dataset=train_set.with_format("torch"),
        eval_dataset=eval_set.with_format("torch"),
        peft_config=peft_config,
        callbacks=callbacks,
        max_seq_length=args.max_seq_length,
        dataset_text_field=args.text_field,
    )

    logger.debug("Trying trainer.train() ...")
    try:
        trainer.train()
    except Exception:
        logger.error("The trainer.train() failed !!!")
        # Bare raise keeps the original exception type and traceback.
        raise

    # Saving-exiting
    trainer.save_state()
    logger.debug("Saving final results ...")
    model = trainer.model
    logger.debug(f"Saving PEFT Model adapters to {args.checkpoint_dir} ...")
    model.save_pretrained(args.checkpoint_dir, safe_serialization=True, save_adapter=True, save_config=True)
    logger.debug(f"saving tokenizer to {args.checkpoint_dir} ...")
    tokenizer.save_pretrained(args.checkpoint_dir)
    
class PeftSavingCallback(TrainerCallback):
    """
    Trainer callback that stores the PEFT adapters next to every checkpoint.
    """

    def on_save(self, args, state, control, **kwargs):
        """
        Save the adapters into "checkpoint-<global_step>" under `args.output_dir`
        and delete a "pytorch_model.bin" file from that directory if one is present.
        """
        weights_file = "pytorch_model.bin"
        step_dir = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")

        peft_model = kwargs["model"]
        peft_model.save_pretrained(step_dir, save_adapter=True, save_config=True)

        if weights_file in os.listdir(step_dir):
            os.remove(os.path.join(step_dir, weights_file))

### Generation Functions

In [13]:
def text_generation(args, logger, model, tokenizer, generation_config,
                    forced_words=None, override_seed: int = None):
    """
    Generate responses for the test split and score the extracted labels.

    Parameters:
    - args: Args - Run configuration. NOTE: `device`, `cuda_n_gpus` and
      `cuda_max_memory` are overwritten with a fixed single-GPU setup.
    - logger: logging.Logger - Receives progress messages and the final metrics.
    - model: Model to generate with, or None to load one from `args.mergepoint_dir`.
    - tokenizer: Tokenizer to use, or None to create one with `create_tokenizer`.
    - generation_config: GenerationConfig to use, or None to build one from `args`.
    - forced_words: Optional words whose token ids are passed to `generate` as
      `force_words_ids`.
    - override_seed: Currently unused. NOTE(review): presumably meant to replace
      `args.seed` for this call - confirm the intended scope before wiring it in.

    Returns:
    - tuple - (test_dataset, generated texts for ALL test samples, true labels,
      predicted labels).

    TODO: 2 gpu text generation https://huggingface.co/docs/accelerate/usage_guides/distributed_inference
    """

    logger.debug("Re-configuring CUDA to single-device ...")
    args.device, args.cuda_n_gpus, args.cuda_max_memory = "cuda:0", 1, {0: "39GB"}
    logger.debug(
        f"Using k={args.cuda_n_gpus} CUDA GPUs with max memory {args.cuda_max_memory}"
    )

    set_seeds(args, logger)

    if model is not None:
        logger.debug("Using model provided to function ...")
    else:
        logger.debug("Creating the Model ...")
        model = AutoModelForCausalLM.from_pretrained(
            args.mergepoint_dir,
            device_map=args.device,
            max_memory=args.cuda_max_memory,
            torch_dtype=args.bnb_compute_dtype,
        )

    if tokenizer is not None:
        logger.debug("Using tokenizer provided to function ...")
    else:
        logger.debug("Creating the Tokenizer ...")
        tokenizer = create_tokenizer(args=args, logger=logger)

    # Load Test Dataset
    test_dataset = load_dataset_split(args, logger, "test")
    test_dataset = preprocess_dataset(args, logger, tokenizer, test_dataset)
    logger.debug(
        f"Creating the Test DataLoader with batch size == {args.per_device_eval_batch_size} ..."
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.per_device_eval_batch_size,
        num_workers=4, pin_memory=True,
    )

    model.eval()
    is_quantized = getattr(model, "is_loaded_in_4bit", False) or getattr(
        model, "is_loaded_in_8bit", False
    )
    if is_quantized:
        # bitsandbytes-quantized models are placed at load time and reject `.to()`.
        logger.debug("Model is quantized; keeping its current device placement")
    else:
        logger.debug(f"Sending the model to device '{args.device}'")
        model.to(args.device)

    # Generating Text from Model
    logger.debug("Generating text ...")
    if generation_config is not None:
        logger.debug("Using the GenerationConfig provided to function ...")
    else:
        logger.debug("Creating GenerationConfig ...")
        generation_config = create_generation_config(args)

    generate_kwargs = {"generation_config": generation_config}
    if forced_words:
        # Only passed when requested; previously the name was undefined (and
        # `generate` crashed) whenever no forced words were given.
        generate_kwargs["force_words_ids"] = tokenizer(
            forced_words, add_special_tokens=False
        ).input_ids

    test_responses = []
    start = time.perf_counter()
    for batch in tqdm(test_dataloader):
        inputs = tokenizer(
            batch[args.text_field],
            padding=args.padding,
            truncation=args.truncation,
            max_length=args.max_seq_length,
            return_tensors="pt",
        )
        inputs.to(args.device)
        with torch.no_grad():
            try:
                generated_ids = model.generate(**inputs, **generate_kwargs)
            except TypeError as e:
                logger.error(f"An error occurred during generation: {e}")
                # Bare raise keeps the original traceback.
                raise
        generated_texts = tokenizer.batch_decode(
            sequences=generated_ids, skip_special_tokens=True
        )
        test_responses.extend(generated_texts)

    e2e_inference_time = (time.perf_counter() - start) * 1000
    logger.debug(f"the inference time is {e2e_inference_time} ms")

    predicted_labels = [extract_label(response) for response in test_responses]
    logger.debug(
        f"Predicted label_encoded counts:\n {pd.Series(predicted_labels).value_counts().to_string()}"
    )
    # Read the column configured on args rather than a hard-coded column name.
    true_labels = test_dataset[args.encoded_label_field]
    logger.debug(
        f"Ground truth label_encoded counts:\n {pd.Series(true_labels).value_counts().to_string()}"
    )
    logger.debug("Evaluating prediction metrics ...")
    accuracy_perc, f1_score_perc, missing_perc = evaluate_predictions(
        true_labels, predicted_labels
    )

    logger.info(f"Accuracy: {accuracy_perc}")
    logger.info(f"F1 Score: {f1_score_perc}")
    logger.info(f"Missing Percent: {missing_perc}")

    # Return the texts of every batch, not just the last one.
    return test_dataset, test_responses, true_labels, predicted_labels

## Parameters

In [14]:
@dataclass
class HuggingfaceParams:
    """Hugging Face Hub related settings."""
    # Hub namespace; load_dataset_split joins it with the task name as
    # "<organization>/<task_name>" to locate the dataset.
    organization: str = "gtfintechlab"
    # NOTE(review): presumably intended for the TOKENIZERS_PARALLELISM environment
    # variable; no code in the visible notebook reads it - confirm.
    tokenizers_parallelism: str = "false"
    # NOTE(review): create_training_arguments passes push_to_hub=False literally
    # and does not read this field - confirm whether it is used elsewhere.
    push_to_hub: bool = False

@dataclass
class TaskParams:
    """Dataset/task settings: dataset name, seed, and the column names used in preprocessing."""
    # Dataset name; second half of "<organization>/<task_name>" in load_dataset_split.
    task_name: str = "fomc_communication"
    # Seed used for set_seeds, dataset shuffling and the train/validation split.
    seed: int = 5768
    # Column holding the sentence that is inserted into the prompt.
    context_field: str = "sentence"
    # Column holding the raw (numeric) label in the source dataset.
    label_field: str = "label"
    # Column that receives a copy of the raw label during preprocessing.
    encoded_label_field: str = "label_encoded"
    # Column that receives the decoded (text) label during preprocessing.
    response_field: str = "label_decoded"
    # Column that receives the formatted prompt text during preprocessing.
    text_field: str = "input_texts"

@dataclass
class ModelParams:
    """Identifies the base checkpoint to load."""
    # Size tag that is interpolated into the default `model_id` below.
    model_parameters: str = "7b"
    # NOTE: this default is computed once, while the class body executes, from the
    # literal default above. Passing a different `model_parameters` to the
    # constructor does NOT change `model_id` or `model_name`.
    model_id: str = f"meta-llama/Llama-2-{model_parameters}-chat-hf"
    # Last path component of the default `model_id`; DirectoryParams reads it as
    # a class attribute when building its default output path.
    model_name: str = model_id.split("/")[-1]

@dataclass
class LoggingParams:
    """Experiment-logging settings forwarded to TrainingArguments."""
    # Passed as `report_to` in create_training_arguments.
    report_to: str = "tensorboard"
    # Passed as `logging_dir` in create_training_arguments; the default is
    # "<home>/tensorboard/logs", resolved when the class is defined.
    logging_dir: str = str(Path.home() / "tensorboard" / "logs")

@dataclass
class DirectoryParams:
    """
    Output locations for a run.

    NOTE: every default below is evaluated once, when the class is defined. All
    instances created in the same kernel session therefore share the same `uid`
    and paths unless values are passed explicitly, and passing only `uid` does
    not update the derived directories.
    """
    # TODO: maybe move uid out of params?
    # Run identifier produced by generate_uid() at class-definition time.
    uid: str = generate_uid()
    # NOTE: annotated as str, but the default values of the three directories are
    # pathlib.Path objects (built with the `/` operator).
    # Root results directory: /fintech_3/glenn/results/<task_name>/<model_name>/<uid>,
    # using the class-level defaults of TaskParams and ModelParams.
    output_dir: str = Path("/fintech_3") / "glenn" / "results" / f"{TaskParams.task_name}" / f"{ModelParams.model_name}" / uid
    # Where supervised_fine_tuning saves the adapters and the tokenizer.
    checkpoint_dir: str = output_dir / "final_checkpoint"
    # Where text_generation loads a model from when none is passed in.
    mergepoint_dir: str = output_dir / "final_merged_checkpoint"

@dataclass
class PromptParams:
    """Prompt text, prompt delimiter strings, and the Hub repository name."""

    # System-level preamble of the prompt.
    system_prompt: str = f"Below is an instruction that describes a task. Write a response that appropriately completes the request."
    # Task instruction; ends with "The sentence: " so the sentence to classify
    # is presumably appended directly after it -- TODO confirm in the prompt builder.
    instruction_prompt: str = f"Discard all the previous instructions. Behave like you are an expert sentence classifier. Classify the following sentence from FOMC into 'HAWKISH', 'DOVISH', or 'NEUTRAL' class. Label 'HAWKISH' if it is corresponding to tightening of the monetary policy, 'DOVISH' if it is corresponding to easing of the monetary policy, or 'NEUTRAL' if the stance is neutral. Provide the label in the first line and provide a short explanation in the second line. The sentence: "
    # Alternative prompt wordings, currently disabled:
    ## system_prompt = f"Discard all previous instructions. Below is an instruction that describes a task. Write a response that appropriately completes the request."
    ## instruction_prompt = f"Discard all the previous instructions. Behave like you are an expert sentence classifier. Classify the following sentence from FOMC into 'HAWKISH', 'DOVISH', or 'NEUTRAL' class. Label 'HAWKISH' if it is corresponding to tightening of the monetary policy, 'DOVISH' if it is corresponding to easing of the monetary policy, or 'NEUTRAL' if the stance is neutral. Provide the label 'HAWKISH', 'DOVISH', or 'NEUTRAL'. The sentence: ",
    ## instruction_prompt = f"Behave like you are an expert sentence classifier. Classify the following sentence from the Federal Open Market Committee into 'HAWKISH', 'DOVISH', or 'NEUTRAL' class. Label 'HAWKISH' if it is corresponding to tightening of the monetary policy. Label 'DOVISH' if it is corresponding to easing of the monetary policy. Label 'NEUTRAL' if the stance is neutral. Provide a single label from the choices 'HAWKISH', 'DOVISH', or 'NEUTRAL' then stop generating text. The sentence: "
    # Delimiter strings wrapped around the instruction and system sections --
    # presumably the Llama-2 chat template markers; TODO confirm against the tokenizer.
    B_INST: str = "[INST]"
    E_INST: str = "[/INST]"
    B_SYS: str = "<<SYS>>\n"
    E_SYS: str = "\n<</SYS>>\n\n"
    # Beginning- and end-of-sequence marker strings.
    BOS: str = "<s>"
    EOS: str = "</s>"
    # Hub repo "<organization>/<model_name>_<task_name>", built once from the
    # default values of the other parameter groups when the class is defined.
    repo_name: str = f"{HuggingfaceParams().organization}/{ModelParams().model_name}_{TaskParams().task_name}"

@dataclass
class QloraParams:
    """LoRA adapter hyperparameters.

    Presumably forwarded to the PEFT LoRA configuration -- TODO confirm where
    these values are consumed.
    """

    # Adapter rank.
    lora_r: int = 8
    # Adapter alpha value.
    lora_alpha: int = 16
    # Dropout rate for the adapter layers.
    lora_dropout: float = 0.1

@dataclass
class SftParams:
    """Supervised fine-tuning settings for sequence length and example packing."""

    # Default maximum sequence length to use
    max_seq_length: int = 4096
    # Pack multiple short examples in the same input sequence to increase efficiency
    packing: bool = True

@dataclass
class CudaParams:
    """Device, precision, and GPU-memory settings."""

    # Deliberately unannotated: this makes it a plain class attribute rather
    # than a dataclass field, so it is not an __init__ parameter.
    compute_dtype = torch.bfloat16
    # Mixed-precision flags; bf16 matches `compute_dtype` above.
    fp16: bool = False
    bf16: bool = True
    # Both start as None; main() calls configure_cuda_args(args, logger) and
    # then logs these two values, so they are presumably filled in there.
    cuda_n_gpus: int = field(default=None) # Determined dynamically at runtime
    cuda_max_memory: str = field(default=None) # Determined dynamically at runtime
    # Device placement strategy and default device string.
    device_map: str = "auto"
    device: str = "cuda:0"
    # Presumably selects the safetensors format when saving -- TODO confirm.
    save_safetensors: bool = True
    
@dataclass
class BitsAndBytesParams:
    """Quantization settings for loading the base model.

    `bnb_mode` and `bnb_compute_dtype` carry no annotation, so they are plain
    class attributes rather than dataclass fields / __init__ parameters.
    """

    # Master switch for quantized loading -- presumably checked by the
    # model-creation code; TODO confirm.
    bnb_mode = True
    # Activate 4-bit precision base model loading
    load_in_4bit: bool = True
    # Activate 8-bit precision base model loading
    load_in_8bit: bool = False
    # Compute dtype for 4-bit base models; read from CudaParams once, at class-definition time
    bnb_compute_dtype = CudaParams().compute_dtype
    # Quantization type (fp4 or nf4)
    bnb_quant_type: str = "nf4"
    # Activate nested quantization for 4-bit base models (double quantization)
    bnb_use_double_quant: bool = True

@dataclass
class GenerationParams:
    """Text-generation settings.

    The defaults (`num_beams=1`, `do_sample=False`) select greedy decoding.
    Field names mirror Hugging Face generation options -- presumably they are
    copied into a generation config elsewhere in the notebook; TODO confirm.
    """

    # Generation Decoding Strategy
    ## Number of beams; 1 means no beam search
    num_beams: int = 1
    # num_beam_groups: int = 1
    ## How many of possible generations are returned
    num_return_sequences: int = 1
    ## Whether to sample the next token; False keeps decoding deterministic
    do_sample: bool = False
    ## The value used to modulate the next token probabilities.
    ## Only used by sample-based generation modes.
    temperature: float = None # 0.0
    ## Contrastive Search Parameters
    penalty_alpha: float = 0
    ## If top_p < 1, only the smallest set of most probable tokens with probabilities
    ## that add up to top_p or higher are kept for generation.
    top_p: float = None # 0.90
    ## If set, only the top_k most probable tokens are kept (e.g. 40)
    top_k: int = None # 40
    # Other Parameters
    ## The maximum numbers of tokens to generate
    max_new_tokens: int = 100
    ## The minimum length of the sequence to be generated, input prompt + min_new_tokens
    min_length: int = None
    ## Whether or not the model should use the past last key/values attentions
    ## (if applicable to the model) to speed up decoding.
    use_cache: bool = True
    ## Parameter for repetition penalty. 1.0 means no penalty.
    repetition_penalty: float = 1.1
    ## Exponential penalty to the length that is used with beam-based generation.
    length_penalty: float = 1.1
    ## Beam-search stopping flag -- presumably has no effect while num_beams == 1; TODO confirm
    early_stopping: bool = False
    ## Output-format switches for generate()
    return_dict_in_generate: bool = True
    output_scores: bool = False
    ## NOTE(review): truncation/padding look like tokenizer options rather than
    ## generation options -- confirm where they are consumed.
    truncation: bool = True
    padding: bool = True
    remove_invalid_values: bool = True
    ## 0 disables the n-gram repetition constraint
    no_repeat_ngram_size: int = 0

@dataclass
class TrainingArgumentsParams:
    """Trainer hyperparameters.

    Field names mirror Hugging Face `TrainingArguments` options -- presumably
    they are forwarded to it by the fine-tuning code; TODO confirm.
    """

    # Number of training epochs
    num_train_epochs: int = 5
    # Batch size per GPU for training
    per_device_train_batch_size: int = 8
    # Batch size per GPU for evaluation
    per_device_eval_batch_size: int = 8
    # Number of steps to accumulate gradients over before each optimizer update
    gradient_accumulation_steps: int = 4
    # Enable gradient checkpointing
    gradient_checkpointing: bool = True
    # Maximum gradient norm (gradient clipping)
    max_grad_norm: float = 0.3
    # Initial learning rate (AdamW optimizer)
    learning_rate: float = 3e-5
    # Weight decay to apply to all layers except bias/LayerNorm
    weight_decay: float = 0.001
    # Optimizer to use
    optim: str = "adamw_bnb_8bit"
    # Learning rate schedule
    lr_scheduler_type: str = "constant"
    # Number of training steps (overrides num_train_epochs)
    max_steps: int = -1
    # Ratio of steps for a linear warmup (from 0 to learning rate)
    # NOTE(review): the schedule above is "constant", not a warmup variant --
    # confirm this ratio actually takes effect.
    warmup_ratio: float = 0.03
    # Group sequences into batches with same length
    # Saves memory and speeds up training considerably
    group_by_length: bool = False
    # Save checkpoint every X updates steps
    # NOTE(review): save_strategy / logging_strategy below are "epoch", so the
    # two step values here look unused -- confirm.
    save_steps: float = 0.1
    # Log every X updates steps
    logging_steps: float = 0.1
    # Reload the best checkpoint once training finishes
    load_best_model_at_end: bool = True
    # Save, log and evaluate once per epoch
    save_strategy: str = "epoch"
    logging_strategy: str = "epoch"
    evaluation_strategy: str = "epoch"
    # Keep progress bars enabled
    disable_tqdm: bool = False
    # NOTE(review): this looks like a seq2seq-trainer option -- confirm the
    # trainer used here accepts it.
    predict_with_generate: bool = True
    torch_compile: bool = False

@dataclass
class Args(
    HuggingfaceParams,
    TaskParams,
    ModelParams,
    LoggingParams,
    DirectoryParams,
    PromptParams,
    QloraParams,
    SftParams,
    CudaParams,
    BitsAndBytesParams,
    TrainingArgumentsParams,
    GenerationParams,
):
    """Single flat namespace that combines every parameter group.

    Adds no fields of its own: it inherits the dataclass fields of all the
    listed groups, plus their unannotated class attributes (which stay class
    attributes and are not `__init__` parameters). The base-class order is
    significant for method resolution and must not be rearranged.
    """

## `main()`

In [15]:
def main(seed: int = None):
    """Run the full pipeline: fine-tune, merge adapters, save, optionally push, evaluate.

    Steps, in order:
      1. Build `Args`, create a logger and fill in the CUDA settings.
      2. Seed the random number generators.
      3. Run supervised fine-tuning (memory is cleaned up afterwards even on failure).
      4. Reload the base model, merge the adapters stored in
         `args.checkpoint_dir`, and save model + tokenizer to `args.mergepoint_dir`.
      5. Optionally push both to the Hub repo `args.repo_name`.
      6. Run text generation on the test set and report metrics.

    Args:
        seed: Optional override for `args.seed`. `None` keeps the default;
            any integer, including 0, replaces it.

    Returns:
        Tuple `(merged_model, tokenizer, test_dataset, generated_texts,
        true_labels, predicted_labels)`.

    Raises:
        Exception: Whatever `setup_output_directory` or
            `supervised_fine_tuning` raised, re-raised unchanged after logging.
    """
    args = Args()
    # `logger` can never already exist among the locals of a fresh call, so
    # it is simply created here.
    logger = create_logger()
    args = configure_cuda_args(args, logger)
    logger.debug(
        f"Using k={args.cuda_n_gpus} CUDA GPUs with max memory {args.cuda_max_memory}"
    )

    # Compare against None so that an explicit seed of 0 is honoured.
    if seed is not None:
        args.seed = seed
        logger.debug(f"Seed value overriden to '{args.seed}'")
    else:
        logger.debug(f"Seed value is '{args.seed}'")
    set_seeds(args, logger)

    try:
        setup_output_directory(args)
        supervised_fine_tuning(args, logger)
    except Exception as e:
        # Log with traceback and re-raise the original exception so its type
        # and stack are preserved for the caller.
        logger.exception(e)
        raise
    finally:
        memory_cleanup()

    logger.debug("Reloading original weights of Model ...")
    base_model = create_model(args=args, logger=logger, bnb_config=None, peft_config=None)
    log_dtypes(logger=logger, model=base_model)

    logger.debug("Merging adapters into weights to create the final Model ...")
    peft_model = PeftModel.from_pretrained(base_model, args.checkpoint_dir)
    merged_model = peft_model.merge_and_unload()
    # Keyword arguments, matching the call above, so the argument order
    # cannot be mixed up.
    log_dtypes(logger=logger, model=merged_model)

    logger.debug(f"Merged Model saving to {args.mergepoint_dir} ...")
    merged_model.save_pretrained(args.mergepoint_dir, safe_serialization=True)

    logger.debug(f"Tokenizer saving to {args.mergepoint_dir} ...")
    tokenizer = create_tokenizer(args, logger)
    tokenizer.save_pretrained(args.mergepoint_dir)

    if args.push_to_hub:
        logger.debug(f"Pushing the Merged PEFT model and tokenizer to hub repo {args.repo_name}")
        merged_model.push_to_hub(args.repo_name, private=True, use_temp_dir=True)
        tokenizer.push_to_hub(args.repo_name, private=True, use_temp_dir=True)

    memory_cleanup()

    # NOTE(review): `generation_config` is not defined inside main(); this only
    # works if a module-level variable of that name exists when main() runs --
    # otherwise it raises NameError. Confirm and build it here if needed.
    # NOTE(review): the model and tokenizer are passed as None -- presumably
    # text_generation() reloads them from args.mergepoint_dir; confirm.
    test_dataset, generated_texts, true_labels, predicted_labels = text_generation(
        args=args,
        logger=logger,
        merged_model=None,
        tokenizer=None,
        generation_config=generation_config,
        forced_words=["DOVISH", "HAWKISH", "NEUTRAL"],
    )

    return merged_model, tokenizer, test_dataset, generated_texts, true_labels, predicted_labels

---
---
---
# RUNNING `main()`

In [None]:
# PyTorch parses PYTORCH_CUDA_ALLOC_CONF as comma-separated "option:value"
# pairs, so the separator has to be a colon rather than "=".
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:100'
os.environ["TOKENIZERS_PARALLELISM"] = HuggingfaceParams.tokenizers_parallelism

merged_model, tokenizer, test_dataset, generated_texts, true_labels, predicted_labels = main()

# Script entry point, disabled while running as a notebook:
# if __name__ == "__main__":
#     fire.Fire(main)

---
---
---
# SANDBOX

---
---
---
# RECYCLING BIN

---
---
---
# GATHERING DATA

FOMC classification performance; Single A20 NVIDIA GPU, Batch Size of 8, BFloat16:
| Model Size | Seed | Tuned | Decoding | Accuracy % | F1 Score % | Missing % | Wall Clock |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 7B | 42 | False | Greedy | 42.54 | 42.81 | 0 | 254s |
| 7B | 5768 | False | Greedy | 42.14 | 42.39 | 0 | 254s |
| 7B | 78516 | False | Greedy | 41.53 | 41.66 | 0 | 255s |
| 7B | 944601 | False | Greedy | 42.54 | 42.76 | 0 | 253s |
| 13B | 42 | False | Greedy | 50.20 | 43.70 | 0 | 452s |


Parameter Comparison: 
| Model Name | Parameter Count (in Millions) | Weights Size (in Gigabytes) |
| --- | --- | --- |
| RoBERTa-base | 125M | 0.5GB |
| RoBERTa-large | 355M | 1.5GB |
| Llama2-chat | 7,000M | 13GB |
| Llama2-chat | 13,000M | 25GB |
| Llama2-chat | 70,000M | 120GB |

Checkpoint Size Comparison:
    
|Llama 2 Checkpoint Size|LoRA (applied to all layers, rank=8)|Traditional fine-tuning|
|---|---|---|
|7B|40MB|13GB|
|13B|65MB|26GB|
|70B|200MB|128GB|

FOMC Communication Dataset:

| metric | total | test | train | fine-tuning | validation |
|---|---|---|---|---|---|
| pct_total | 100% | 20% | 80% | 56% | 24% |
| n_total | 2,480 | 496 | 1,984 | 1,388 | 596 |

Parameter Count During SFT:
| Model ID | Model Count of Parameters | Adapter Count of Parameters | Adapter % of Parameters |
|---|---|---|---|
| Llama2 7B | 6,738,415,616| 159,907,840 | 2.32% |

In the final served version, the adapter weights can be merged into the base model, so no parameters are added at inference time.

Alternatively we can keep the model weights stationary and swap out fine-tuning adapters for each client or scenario

[Text Generation Decoding Strategies](https://huggingface.co/docs/transformers/generation_strategies)

Greedy Sampling: `beams==1; sample==False`

Multinomial Sampling: `beams==1; sample==True`

Beam Search: `beams>1; sample==False`

Beam Search + Multinomial Sampling: `beams>1; sample==True`