# Template for Zero-shot and Few-shot Classification with LLaMA through Low-Rank Adapters

### Import all Modules

In [1]:
# IMPORTANT: first add the utils folder to the system path so that the helper modules (e.g. dataload_utils) can be imported
import os
import sys

module_dir = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(module_dir, os.pardir, "utils")))

In [2]:
import tqdm
from transformers import (AutoTokenizer,
                          LlamaForCausalLM, BitsAndBytesConfig, GenerationConfig)
import pandas as pd
import torch
import numpy as np
import random
import wandb

from dataload_utils import load_full_dataset, load_dataset_task_prompt_mappings



### Setup Arguments and Data

In the following code block, you set up several key parameters that define the behavior and environment of your inference run:

1. **WandB Project Name (`WANDB_PROJECT_NAME`)**: The name of the Weights & Biases (WandB) project where your run will be logged. WandB is a tool for tracking experiments, visualizing data, and sharing results. Setting the project name keeps all metrics, outputs, and logs from your runs organized under a single project for easy access and comparison. Choose a name that reflects the nature of your experiment. If you set it to an empty string, the run is not tracked on WandB.

2. **Model Name (`MODEL_NAME`)**: Select the LLaMA model (and size) that you want to run. This notebook was run and tested on `meta-llama/Llama-2-70b-chat-hf`, `meta-llama/Llama-2-13b-chat-hf`, and `OASST-LLAMA 30b` (no longer available on Hugging Face).


In [3]:
WANDB_PROJECT_NAME = "llama2_annotations_llm_comparison"
# Name of the model to run (this notebook was tested on LLAMA-2 70b, LLAMA-2 13b, and OASST-LLAMA 30b)
MODEL_NAME = "meta-llama/Llama-2-70b-chat-hf"

To run LLaMA-2 models, you need to request access on the Hugging Face model page (https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) and create an access token. You can then either paste the token directly into the notebook (not recommended if you share the repository on GitHub) or, as done here, store it in `hf_token.txt` and make sure that file is listed in `.gitignore`.

In [4]:
with open(os.path.join(module_dir, "hf_token.txt"), "r") as file:
    hf_token = file.read().strip()
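If you prefer not to keep a token file on every machine, a small helper can fall back to an environment variable. The sketch below is hypothetical (`read_hf_token` is not part of this repository); it assumes the token is either in a text file or in the `HF_TOKEN` environment variable, which recent versions of `huggingface_hub` also recognize on their own.

```python
import os
from typing import Optional


def read_hf_token(token_path: str, env_var: str = "HF_TOKEN") -> Optional[str]:
    """Return the Hugging Face token from `token_path`, else from `env_var`, else None."""
    if os.path.isfile(token_path):
        with open(token_path, "r") as file:
            token = file.read().strip()
        if token:
            return token
    # Fall back to the environment variable (None if it is unset or empty)
    return os.environ.get(env_var) or None
```

For example, `hf_token = read_hf_token(os.path.join(module_dir, "hf_token.txt"))` behaves like the cell above when the file exists.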

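This optional cell is only needed if your default Hugging Face cache location does not have enough storage to hold the LLaMA weights; otherwise, you can skip it. Keep in mind that `transformers` reads these environment variables when it is imported, so they only take effect if they are set before the import (for example, by running this cell first after restarting the kernel). Also, the model-loading cell below passes `cache_dir="../cache"` explicitly, which takes precedence over the environment variables; change that argument if the weights should be stored elsewhere.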

In [5]:
#cache_location = os.environ['HF_HOME']
cache_location = "your/path/to/large/storage"

os.environ['TRANSFORMERS_CACHE'] = cache_location
os.environ['HF_HOME'] = cache_location

In the next code block, you set up the configuration variables that control how inference is run. They define the task, the data, and the model's behavior during inference.

1. **Task (`task`)**: The type of task to run inference on, given as an integer ID (e.g., 1, 2, 3). It must be one of the tasks defined in the dataset-task mappings file (see `dataset_task_mappings_fp`).

2. **Dataset (`dataset`)**: Choose the dataset on which you want to run inference. Like tasks, datasets are identified by integers, and each number corresponds to a different dataset. Ensure that the dataset selected is relevant to the task at hand.

3. **Output Directory (`output_dir`)**: Path to the directory where the generated predictions are stored (in a `predictions/<model name>` subfolder).

4. **Model Directory (`model_dir`)**: Path to the directory where you want to store models; it should have sufficient storage. Note that `model_dir` is not used by the inference cells below, which cache the model weights via `cache_dir`.

5. **Llama-2 Prompt (`use_llama2_prompt`)**: Keep `True` when running LLaMA-2 models, so that prompts follow the LLaMA-2 chat format with separate system and user messages. Set it to `False` when running OASST models.

6. **Division Line for User Message (`system_user_prompt_division_line`)**: Number of lines at the end of the prompt that make up the user message; the preceding lines form the system prompt. Relevant only for LLaMA-2 models.

7. **Random Seed (`seed`)**: Setting a random seed ensures that the results are reproducible. By using the same seed, you can achieve the same outcomes on repeated runs under identical conditions.

8. **Data Directory (`data_dir`)**: Specify the path to the directory containing the datasets you plan to use for evaluation.

9. **Label Usage (`not_use_full_labels`)**: This boolean determines whether to use the full label descriptions or abbreviated labels during inference. Setting it to `False` means full labels are used.

10. **Dataset-Task Mappings File Path (`dataset_task_mappings_fp`)**: Define the path to the file containing mappings between datasets and tasks. This file is crucial for ensuring the correct dataset is used for the specified task.

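11. **Maximum prompt length (`max_prompt_len`)**: Maximum length of the prompt, in tokens; longer inputs are truncated (from the left, since the tokenizer is loaded with `truncation_side="left"`). Longer input sequences require more compute, so the shortest length that still captures the text is recommended. The prompt and the generated tokens must fit together in the model's context window (4,096 tokens for LLaMA-2).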

12. **Few-shot indicator (`few_shot`)**: Set to `True` to run few-shot inference with the prompt in the `few_shot_prompt` column of the dataset-task mappings file; otherwise the `zero_shot_prompt` column is used.

13. **Batch size (`batch_size`)**: Number of observations in each training and validation batch. This applies to fine-tuning only; it is not set in this inference notebook, which processes one example at a time. A larger batch size requires more memory, since one batch must fit on one machine, but makes learning more stable. We found that for FLAN-XL, an effective batch size of 8 was possible by using a batch size of 4 and accumulating the results of 2 batches (see `gradient_accumulation_steps` below).

14. **Gradient accumulation steps (`gradient_accumulation_steps`)**: Also fine-tuning only. When it is larger than 1, the weights are not updated after every batch; instead, gradients are accumulated over _n_ batches before a single update. This lets the model train on a larger global batch (_batch size_ × _gradient accumulation steps_) than fits on one machine.

15. **Run name (`run_name`)**: Optional run name to store in WandB. We recommend leaving it empty, which generates an automatic run name from the relevant run parameters for easier tracking.

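16. **Maximum output length (`max_new_tokens`)**: Maximum number of new tokens LLaMA may generate. For data labelling, where the expected answer is a short label, there is little reason to make it much longer; a higher limit mainly adds generation time when the model keeps explaining its answer.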

17. **Text generation parameters (`temp`, `top_p`, `top_k`)**: The values below were taken from the Stanford Alpaca LoRA repository: https://github.com/tloen/alpaca-lora/blob/main/generate.py.
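The maximum prompt length and the maximum output length interact: for a decoder-only model, the prompt and the generated tokens share one context window. The sketch below uses a hypothetical helper, `prompt_budget`, and assumes a 4,096-token window (the LLaMA-2 context length); it checks that `max_prompt_len + max_new_tokens` fits and returns a prompt length that leaves room for generation.

```python
def prompt_budget(max_prompt_len: int, max_new_tokens: int, context_window: int = 4096) -> int:
    """Return the largest usable prompt length (in tokens) that leaves room for generation.

    Raises ValueError if max_new_tokens alone does not fit in the context window.
    """
    if max_new_tokens >= context_window:
        raise ValueError("max_new_tokens must be smaller than the context window")
    budget = context_window - max_new_tokens
    if max_prompt_len > budget:
        print(f"max_prompt_len={max_prompt_len} leaves no room for {max_new_tokens} new tokens; "
              f"consider max_prompt_len <= {budget}")
    return min(max_prompt_len, budget)
```

With the values used in this notebook, `prompt_budget(4096, 100)` prints a warning and returns 3996. In practice, short inputs such as tweets rarely come close to either limit.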


In [6]:
# Configuration Variables

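# Type of task to run inference on (integer ID defined in the dataset-task mappings file)
task = 1

# Dataset to run inference on (integer ID defined in the dataset-task mappings file)
dataset = 1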

# Path to the directory to store the generated predictions
output_dir = '../../data'

# Path to the directory to store the models (make sure this location is included in the .gitignore if using GitHub)
model_dir = '../../data'

#If using LLAMA2, keep True. Set False only for OASST-LLAMA.
use_llama2_prompt = True

#This is relevant for LLAMA2, where the system and user message are separated by context tokens. You should count how many lines your user message takes (in this case, 3)
system_user_prompt_division_line = 3

# Random seed to use
seed = 2019

# Path to the directory containing the datasets
data_dir = '../../data'

# If True, use abbreviated labels instead of the full label words
not_use_full_labels = False

# Path to the dataset-task mappings file
dataset_task_mappings_fp = os.path.normpath(os.path.join(module_dir, '..', '..', 'dataset_task_mappings.csv'))

# Maximum prompt length (in tokens) taken as input; longer prompts are truncated.
# LLaMA-2 has a 4,096-token context window shared by the prompt and the generated tokens.
max_prompt_len = 4096

#Zero or few-shot binary variable
few_shot = False

# Optional run name; leave empty to use the default name built below
run_name = ""

# Maximum number of new tokens to generate
max_new_tokens = 100

# Text generation parameters (values taken from the Stanford Alpaca LoRA repository: https://github.com/tloen/alpaca-lora/blob/main/generate.py)
temp = 0.05
top_p = 0.75
top_k = 40
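The generation parameters set above (`temp`, `top_p`, `top_k`) control sampling in `model.generate`. To make their effect concrete, the NumPy sketch below applies temperature scaling, then top-k, then top-p (nucleus) filtering to one vector of logits. `filter_distribution` is a hypothetical illustration, not the `transformers` implementation, whose tie handling and internal details may differ. It shows why `temp = 0.05` makes generation close to greedy for a labelling task.

```python
import numpy as np


def filter_distribution(logits, temperature=0.05, top_k=40, top_p=0.75):
    """Return a next-token distribution after temperature, top-k, and top-p filtering."""
    # Temperature: dividing by a small value sharpens the distribution toward the top logit
    scaled = np.asarray(logits, dtype=np.float64) / temperature

    # Top-k: drop every token whose logit is below the k-th largest
    if top_k is not None and top_k < scaled.size:
        kth_largest = np.sort(scaled)[-top_k]
        scaled = np.where(scaled < kth_largest, -np.inf, scaled)

    # Softmax (subtracting the max for numerical stability); masked tokens get probability 0
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()

    # Top-p: keep the smallest set of most likely tokens whose cumulative probability reaches top_p
    order = np.argsort(probs)[::-1]
    cumulative = np.cumsum(probs[order])
    n_keep = int(np.searchsorted(cumulative, top_p)) + 1
    keep = order[:n_keep]

    filtered = np.zeros_like(probs)
    filtered[keep] = probs[keep]
    return filtered / filtered.sum()


# With temp = 0.05, a gap of 1.0 between logits becomes a gap of 20, so sampling is nearly greedy
print(filter_distribution([2.0, 1.0, 0.5]))
# With temperature 1.0, top-k and top-p decide which tokens remain candidates
print(filter_distribution([2.0, 1.0, 0.5, -1.0], temperature=1.0, top_k=3, top_p=0.75))
```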

In [7]:
dataset_name = f'ds_{dataset}__task_{task}_eval_set'

**Customizing for Your Own Tasks:**
If you plan to run a custom task or use a dataset that is not predefined, you will need to make modifications to the `label_utils` file. This file contains all mappings for different datasets and tasks. Adding your custom task or dataset involves defining the new task or dataset number and specifying its characteristics and mappings in the `label_utils` file. This ensures that your custom task or dataset integrates seamlessly with the existing framework for training and inference.

### Define Utility Functions

In [8]:
def set_all_seeds(seed: int = 123):
    """Seed Python, NumPy, and PyTorch (CPU and all GPUs) for reproducible runs."""
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # When running on the cuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Fix the hash seed (this only affects Python processes started after this point,
    # not the hash randomization of the current interpreter)
    os.environ["PYTHONHASHSEED"] = str(seed)
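A quick sanity check for seeding is to reseed and compare draws. The sketch below uses a hypothetical `seed_python_and_numpy` helper that covers only Python's `random` module and NumPy (the PyTorch calls are left out so it runs without a GPU); reseeding with the same value reproduces the same sequence. It does not show that GPU sampling in `model.generate` is bit-for-bit reproducible, which can also depend on hardware and library versions.

```python
import random

import numpy as np


def seed_python_and_numpy(seed: int) -> None:
    """Seed Python's and NumPy's global random number generators."""
    random.seed(seed)
    np.random.seed(seed)


def sample_draws(seed: int, n: int = 5):
    """Reseed, then return n draws from Python's RNG and n draws from NumPy's RNG."""
    seed_python_and_numpy(seed)
    return [random.random() for _ in range(n)], np.random.rand(n).tolist()


# Reseeding with the same value reproduces the same draws
print(sample_draws(2019) == sample_draws(2019))  # prints True
```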

## Main Implementation

In [9]:
set_all_seeds(seed)

In [10]:
exp_name = run_name if run_name != '' else f'{MODEL_NAME}_ds_{dataset}_task_{int(task)}_sample_{0}_prompt_max_len_{max_prompt_len}'
if few_shot:
    exp_name += "few_shot"
exp_name = exp_name.replace('.', '_')

# Initialize the Weights and Biases run
if WANDB_PROJECT_NAME != "":
    wandb.init(
        # set the wandb project where this run will be logged
        project=WANDB_PROJECT_NAME,
        name=exp_name,
        # track hyperparameters and run metadata
        config={
            "model": MODEL_NAME,
            "dataset": dataset,
            "task": task,
            "max_prompt_len": max_prompt_len
        }
    )

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmaria-korobeynikova[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
if not_use_full_labels:
    exp_name += '_label_abbreviation'
    labelset_col = 'labelset'
else:
    labelset_col = 'labelset_fullword'

In [12]:
print('Running exp:', exp_name)

Running exp: meta-llama/Llama-2-70b-chat-hf_ds_1_task_1_sample_0_prompt_max_len_4096
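The experiment name is assembled across two cells above. Gathered into one hypothetical function, the logic looks roughly like this sketch; it writes the few-shot suffix with a leading underscore (`_few_shot`). The `/` in the model name is kept here and only replaced when the CSV file name is built.

```python
def build_exp_name(model_name: str, dataset: int, task: int, max_prompt_len: int,
                   few_shot: bool = False, use_abbreviated_labels: bool = False,
                   run_name: str = "") -> str:
    """Rebuild the experiment name used for the WandB run and the output CSV."""
    if run_name:
        name = run_name
    else:
        name = f"{model_name}_ds_{dataset}_task_{int(task)}_sample_0_prompt_max_len_{max_prompt_len}"
    if few_shot:
        name += "_few_shot"
    # Dots are replaced before the label suffix is added, as in the cells above
    name = name.replace(".", "_")
    if use_abbreviated_labels:
        name += "_label_abbreviation"
    return name


name = build_exp_name("meta-llama/Llama-2-70b-chat-hf", 1, 1, 4096)
print(name)
print(name.replace("/", "_") + ".csv")  # file name used when exporting predictions
```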


### Load Data and the prompt

In [13]:
prompt_col = 'few_shot_prompt' if few_shot else 'zero_shot_prompt'
# Load the prompt
dataset_idx, dataset_task_mappings = load_dataset_task_prompt_mappings(
    dataset_num=dataset, task_num=task, dataset_task_mappings_fp=dataset_task_mappings_fp)

# Get information specific to the dataset
label_column = dataset_task_mappings.loc[dataset_idx, "label_column"]
labelset = dataset_task_mappings.loc[dataset_idx, labelset_col].split("; ")
labelset = [label.strip() for label in labelset]
prompt = dataset_task_mappings.loc[dataset_idx, prompt_col]

# Get the system or instruction prompt and the user prompt format
system_prompt = ('\n'.join(prompt.split('\n')[:-system_user_prompt_division_line])).strip()
user_prompt_format = ('\n'.join(prompt.split('\n')[-system_user_prompt_division_line:])).strip()

# Log the system prompt and user_prompt_format as files in wandb
if WANDB_PROJECT_NAME != "":
    prompts_artifact = wandb.Artifact('prompts', type='prompts')
    with prompts_artifact.new_file('system_prompt.txt', mode='w') as f:
        f.write(system_prompt)
    with prompts_artifact.new_file('user_prompt_format.txt', mode='w') as f:
        f.write(user_prompt_format)
    wandb.run.log_artifact(prompts_artifact)

# Load the train and eval datasets with the full prompt format
print(f'label_columns: {label_column}')
print(f'labelset: {labelset}')

datasets = load_full_dataset(
    data_dir=data_dir, dataset_name=dataset_name, task_num = task,
    label_column=label_column, labelset=labelset, full_label=not not_use_full_labels, system_prompt=system_prompt, user_prompt_format=user_prompt_format,
    llama_2=use_llama2_prompt)

label_columns: relevant_ra
labelset: ['RELEVANT', 'IRRELEVANT']
loading ../../data/ds_1__task_1_eval_set.csv
../../data/ds_1__task_1_eval_set.csv
dataset has the following cols Index(['status_id', 'Date', 'text', 'relevant_ra'], dtype='object')
The label_column is: relevant_ra
loading ../../data/ds_1__task_1_eval_set.csv
../../data/ds_1__task_1_eval_set.csv
dataset has the following cols Index(['status_id', 'Date', 'text', 'relevant_ra'], dtype='object')
The label_column is: relevant_ra


In [15]:
print(f"Eval set example with completion ({len(datasets['eval'])} rows): ")
print("-" * 50 + '\n')
print(datasets["eval"]["text"][0])
print('\n\n')

print(f"Eval set without completion ({len(datasets['eval_wo_completion'])} rows): ")
print("-" * 50 + '\n')
print(datasets["eval_wo_completion"]["text"][0])
print('\n\n')

Eval set example with completion (387 rows): 
--------------------------------------------------

<s>[INST] <<SYS>>
“Content moderation” refers to the practice of screening and monitoring content posted by users on social media sites to determine if the content should be published or not, based on specific rules and guidelines.

I will ask you to classify a tweet as RELEVANT or IRRELEVANT to the content moderation:

A: Tweet is RELEVANT if it includes: social media platforms’ content moderation rules and practices, censorship, governments’ regulation of online content moderation, and/or mild forms of content moderation like flagging, shadowbanning, or account suspension.

B: Tweet is IRRELEVANT if they do not refer to content moderation, as defined above. This would include, for example, a tweet by Trump that Twitter has labeled his tweet as “disputed”, or a tweet claiming that something is false.
<</SYS>>

Now, is the following tweet RELEVANT or IRRELEVANT to content moder
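The system/user split and the Llama-2 chat wrapping visible in the example above can be sketched as two small functions. Both are hypothetical: the repository's `load_full_dataset` builds the actual strings, and its exact whitespace and completion handling may differ. The template assumed here is the standard single-turn Llama-2 chat format `<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]`.

```python
from typing import Tuple


def split_prompt(prompt: str, division_line: int) -> Tuple[str, str]:
    """Split a prompt into (system, user): the last `division_line` lines are the user part."""
    if division_line < 1:
        raise ValueError("division_line must be at least 1")
    lines = prompt.split("\n")
    system = "\n".join(lines[:-division_line]).strip()
    user = "\n".join(lines[-division_line:]).strip()
    return system, user


def format_llama2_prompt(system: str, user: str) -> str:
    """Wrap a system and a user message in the single-turn Llama-2 chat template."""
    return f"<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n{user} [/INST]"


prompt = "You label tweets.\nAnswer with one word.\nTweet: {text}\nLabel:"
system, user = split_prompt(prompt, division_line=2)
print(format_llama2_prompt(system, user))
```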

### Define the model, tokenizers, data collator

In [16]:
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, truncation_side="left", use_fast=False, token=hf_token, cache_dir = "../cache")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load the model
use_4bit = True                         # Activate 4-bit precision base model loading
bnb_4bit_compute_dtype = "float16"      # Compute dtype for 4-bit base models
bnb_4bit_quant_type = "nf4"             # Quantization type (fp4 or nf4)
use_nested_quant = False                # Activate nested quantization for 4-bit base models (double quantization)

compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=use_4bit,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=use_nested_quant,
)

model = LlamaForCausalLM.from_pretrained(MODEL_NAME, quantization_config=bnb_config,
                                            device_map="auto", token=hf_token, cache_dir = "../cache")

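The 4-bit setting above substantially reduces the memory needed for the model weights. As rough arithmetic, a hypothetical helper `weight_memory_gb` compares float16 and 4-bit storage; it assumes about 70 billion parameters and decimal gigabytes, and it counts weights only (quantization constants, activations, the KV cache, and CUDA overhead are not included).

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Approximate memory for the model weights alone, in decimal GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


# Rough comparison for a model with about 70 billion parameters
for dtype_name, bits in [("float16", 16), ("4-bit nf4", 4)]:
    print(f"{dtype_name}: ~{weight_memory_gb(70e9, bits):.0f} GB")
```

This prints about 140 GB for float16 and about 35 GB for 4-bit weights, i.e. roughly a quarter of the float16 footprint.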

### Run Predictions

In [17]:

# Default params from alpaca-lora generate script (commonly used)
generation_config = GenerationConfig(
    temperature=temp,
    top_p=top_p,
    top_k=top_k,
    do_sample=True,
    max_new_tokens=max_new_tokens
)

with torch.no_grad():
    predictions_out = []
    eval_texts = datasets["eval_wo_completion"]["text"]
    for i, input_text_i in tqdm.tqdm(enumerate(eval_texts), total=len(eval_texts)):
        # Tokenize the prompt (truncated from the left to max_prompt_len tokens)
        tokenized_text_i = tokenizer(
            input_text_i,
            padding=False,
            max_length=max_prompt_len,
            truncation=True,
            return_tensors="pt"
        ).to(model.device)

        # Generate the completion (the decoded output also contains the prompt)
        outputs = model.generate(
            input_ids=tokenized_text_i["input_ids"],
            attention_mask=tokenized_text_i["attention_mask"],
            generation_config=generation_config
        )

        generated_text_minibatch = tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

        predictions_out += generated_text_minibatch

        if i == 0:
            print("Sample prediction: ")
            print(predictions_out[0])

1it [00:12, 12.82s/it]

Sample prediction: 
[INST] <<SYS>>
“Content moderation” refers to the practice of screening and monitoring content posted by users on social media sites to determine if the content should be published or not, based on specific rules and guidelines.

I will ask you to classify a tweet as RELEVANT or IRRELEVANT to the content moderation:

A: Tweet is RELEVANT if it includes: social media platforms’ content moderation rules and practices, censorship, governments’ regulation of online content moderation, and/or mild forms of content moderation like flagging, shadowbanning, or account suspension.

B: Tweet is IRRELEVANT if they do not refer to content moderation, as defined above. This would include, for example, a tweet by Trump that Twitter has labeled his tweet as “disputed”, or a tweet claiming that something is false.
<</SYS>>

Now, is the following tweet RELEVANT or IRRELEVANT to content moderation? @jennahasredhair Aww ok didn't know sexy, yes I will report and block that
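The decoded predictions contain the echoed prompt, which itself mentions both labels, so the label has to be read from the text that follows the prompt. A minimal post-processing sketch follows; `extract_label` is a hypothetical helper that assumes the decoded text still contains the Llama-2 `[/INST]` marker closing the user message (it returns `None` otherwise). It matches whole words only, so `RELEVANT` is not found inside `IRRELEVANT`. An alternative is to decode only the new tokens, e.g. `outputs[:, tokenized_text_i["input_ids"].shape[1]:]`, since `generate` returns the prompt tokens followed by the generated ones for decoder-only models.

```python
import re
from typing import Optional, Sequence


def extract_label(decoded: str, labelset: Sequence[str], marker: str = "[/INST]") -> Optional[str]:
    """Return the first label from `labelset` that appears as a whole word after `marker`."""
    if marker not in decoded:
        return None  # cannot separate the answer from the echoed prompt
    answer = decoded.rsplit(marker, 1)[1]
    # Longest labels first as an extra safeguard; \b already prevents substring matches
    alternatives = sorted(labelset, key=len, reverse=True)
    pattern = r"\b(" + "|".join(re.escape(label) for label in alternatives) + r")\b"
    match = re.search(pattern, answer, flags=re.IGNORECASE)
    if match is None:
        return None
    # Map case-insensitive matches (e.g. "Relevant") back to the canonical label
    by_upper = {label.upper(): label for label in labelset}
    return by_upper[match.group(1).upper()]


labels = ["RELEVANT", "IRRELEVANT"]
print(extract_label("[INST] RELEVANT or IRRELEVANT? tweet [/INST] IRRELEVANT.", labels))
```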

154it [26:15, 10.23s/it]


KeyboardInterrupt: 

In [21]:
os.path.join(output_dir, 'predictions', MODEL_NAME.replace("/", "_"))

'../../data/predictions/meta-llama_Llama-2-70b-chat-hf'

In [None]:
# Load the eval CSV and attach the predictions (one per row)
predictions_dir = os.path.join(output_dir, 'predictions', MODEL_NAME.replace("/", "_"))
os.makedirs(predictions_dir, exist_ok=True)
eval_df = pd.read_csv(os.path.join(data_dir, f"{dataset_name}.csv"))
eval_df['prediction'] = predictions_out

In [20]:
os.path.join(predictions_dir, f'{exp_name.replace("/", "_")}.csv')

'../../data/predictions/Llama-2-70b-chat-hf/meta-llama_Llama-2-70b-chat-hf_ds_1_task_1_sample_0_prompt_max_len_4096/meta-llama_Llama-2-70b-chat-hf_ds_1_task_1_sample_0_prompt_max_len_4096.csv'

In [None]:
#export to csv
eval_df.to_csv(os.path.join(predictions_dir, f'{exp_name.replace("/", "_")}.csv'))

### Terminate WandB

In [None]:
if WANDB_PROJECT_NAME != "":
    wandb.finish()