# Template for Zero-shot/Few-shot Classification with FLAN through Low-Rank Adapters

In [1]:
# IMPORTANT: add the utils folder to the system path first so that the helper modules below can be imported
import os
import sys

module_dir = os.getcwd()
sys.path.append(os.path.abspath(os.path.join(module_dir, os.pardir, "utils")))

### Import all Modules

In [2]:
import gc
import os
import time
import random
import argparse

import wandb
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq)

from dataload_utils import load_full_dataset, load_dataset_task_prompt_mappings




### Setup Arguments and Data

In the following code block, you set several key parameters that define the behavior and environment of your fine-tuning process:

1. **WandB Project Name (`WANDB_PROJECT_NAME`)**: The name of the Weights & Biases (WandB) project where your run will be logged. WandB is a tool for tracking experiments, visualizing data, and sharing results; setting the project name keeps all metrics, outputs, and logs from your runs together in one project for easy access and comparison. Choose a name that reflects the experiment. If you leave it as an empty string, the run is not tracked on WandB.

2. **Model Name (`MODEL_NAME`)**: The FLAN-T5 checkpoint (and therefore the model size) you want to fine-tune. This notebook was run and tested with `google/flan-t5-xl`, which we found to offer the best trade-off between the compute required to run the model and the accuracy of its predictions. The full list of models is available at: https://huggingface.co/docs/transformers/model_doc/flan-t5


In [3]:
# Specs WandB and Which Model you want to fine-tune
WANDB_PROJECT_NAME = "FLAN_template_1"  # leave empty to disable WandB logging
MODEL_NAME = 'google/flan-t5-xl'

In the next code block, you set the configuration variables that control how inference is run. They define the task, the data, and the model's behavior during training and evaluation.

1. **Task (`task`)**: The type of task to run inference on, identified by an integer (for example 1, 2, 3). Choose one of the predefined values listed in the code; each one maps to a specific NLP task in the dataset-task mappings.

2. **Dataset (`dataset`)**: Choose the dataset on which you want to run inference. Like tasks, datasets are identified by integers, and each number corresponds to a different dataset. Ensure that the dataset selected is relevant to the task at hand.

3. **Output Directory (`output_dir`)**: The directory where generated outputs are stored. Predictions are written to a `predictions/<model name>` subfolder of this directory.

4. **Random Seed (`seed`)**: Setting a random seed ensures that the results are reproducible. By using the same seed, you can achieve the same outcomes on repeated runs under identical conditions.

5. **Data Directory (`data_dir`)**: Specify the path to the directory containing the datasets you plan to use for training and evaluation.

6. **Label Usage (`not_use_full_labels`)**: This boolean variable determines whether to use the full label descriptions or abbreviated labels during training and inference. Setting it to `False` means full labels will be used.

7. **Dataset-Task Mappings File Path (`dataset_task_mappings_fp`)**: Define the path to the file containing mappings between datasets and tasks. This file is crucial for ensuring the correct dataset is used for the specified task.

8. **Number of Epochs (`n_epochs`)**: The number of training epochs, where an epoch is one complete pass through the entire training dataset. This is only used when fine-tuning and is not set in the configuration cell below.

9. **Maximum prompt length (`max_prompt_len`)**: The maximum prompt length in tokens; longer inputs are truncated (the tokenizer is configured to drop tokens from the start of the prompt). Longer sequences require more compute and memory, so use the shortest length that still captures the text.

10. **Batch size (`batch_size`)**: The number of examples in each training and validation batch. A larger batch makes learning more stable but needs more memory, since a whole batch must fit on one device. For FLAN-T5-XL, we reached an effective batch size of 8 by using a batch size of 4 and accumulating gradients over 2 batches (see `gradient_accumulation_steps` below).

11. **Gradient accumulation steps (`gradient_accumulation_steps`)**: When this is larger than 1, the model weights are not updated after every batch; instead, gradients are accumulated over _n_ batches and a single update is made. This lets the model learn from a larger effective batch (_batch size_ × _gradient accumulation steps_) than fits on one device. It only applies to fine-tuning and is not set in the configuration cell below.
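
The arithmetic behind gradient accumulation can be checked on a toy problem. The sketch below is illustrative only and is not part of this template; it assumes a one-parameter linear model with a mean-squared-error loss and compares the gradient of one batch of 8 with the sum of two micro-batch gradients of 4, each scaled by 1/2.

```python
def mse_grad(w, xs, ys):
    """Gradient of the mean squared error of y_hat = w * x with respect to w."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def accumulated_grad(w, xs, ys, micro_batch_size, accumulation_steps):
    """Sum micro-batch gradients, each scaled by 1 / accumulation_steps."""
    total = 0.0
    for step in range(accumulation_steps):
        start = step * micro_batch_size
        micro_x = xs[start:start + micro_batch_size]
        micro_y = ys[start:start + micro_batch_size]
        # Scaling keeps the result equal to the mean over the full batch
        total += mse_grad(w, micro_x, micro_y) / accumulation_steps
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 14.1, 15.8]
w = 0.5

micro_batch_size, accumulation_steps = 4, 2
effective_batch = micro_batch_size * accumulation_steps

full = mse_grad(w, xs, ys)
accumulated = accumulated_grad(w, xs, ys, micro_batch_size, accumulation_steps)
print(effective_batch, round(full, 6), round(accumulated, 6))
```

The two gradients agree only when the loss is a mean and all micro-batches have the same size. In PyTorch the same effect is usually obtained by dividing each micro-batch loss by the number of accumulation steps before calling `loss.backward()`, and calling `optimizer.step()` once per accumulation cycle.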

**Customizing for Your Own Tasks:**
If you plan to run a custom task or use a dataset that is not predefined, you will need to make modifications to the `utils_src` file. This file contains all mappings for different datasets and tasks. Adding your custom task or dataset involves defining the new task or dataset number and specifying its characteristics and mappings in the `utils_src` file. This ensures that your custom task or dataset integrates seamlessly with the existing framework for training and inference.
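
The prompt-loading cell further down reads four columns from the dataset-task mappings file: `label_column`, `labelset` (a comma-separated string), `zero_shot_prompt` and `few_shot_prompt`; prompts are filled with `str.format(text=...)`, so they need a `{text}` placeholder. The sketch below shows what a custom row might provide and how the notebook turns it into a label list and a filled prompt. The `dataset_num`/`task_num` columns, label names and prompt wording are hypothetical; check `load_dataset_task_prompt_mappings` for how rows are actually looked up.

```python
import csv
import io

# Hypothetical mappings file with one custom row; dataset_num/task_num are illustrative
MAPPINGS_CSV = '''dataset_num,task_num,label_column,labelset,zero_shot_prompt,few_shot_prompt
5,7,my_label,"Positive, Negative, Neutral","Classify the tweet as Positive, Negative or Neutral.
Tweet: {text}
Answer:","(few-shot examples would go here)
Tweet: {text}
Answer:"
'''

rows = list(csv.DictReader(io.StringIO(MAPPINGS_CSV)))
row = next(r for r in rows if r["dataset_num"] == "5" and r["task_num"] == "7")

# Same parsing as the notebook: split on commas and strip whitespace
example_labelset = [label.strip() for label in row["labelset"].split(",")]

# Prompts are filled with str.format(text=...), so any other literal
# braces in a prompt must be doubled ({{ and }})
example_prompt = row["zero_shot_prompt"]
filled = example_prompt.format(text="Great news for the city today!")

print(row["label_column"], example_labelset)
print(filled)
```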


In [4]:
# Configuration Variables

# Type of task to run inference on
task = 2  # Choices: [1,2,3,4,5,6]

# Dataset to run inference on
dataset = 1  # Choices: [1, 2, 3, 4]

# Size of the sample to generate
sample_size = '50'  # Choices: ['50','100','250','500','1000','1500']

# Path to the directory to store the generated samples
output_dir = '../../data'

# Random seed to use
seed = 2019

# Path to the directory containing the datasets
data_dir = '../../data'

# Path to the directory where the models are stored
model_dir = "../../models"

# Set to True to use abbreviated labels instead of the full label descriptions
not_use_full_labels = False

# Path to the dataset-task mappings file
dataset_task_mappings_fp = os.path.normpath(os.path.join(module_dir, '..', '..','dataset_task_mappings.csv'))

# Maximum prompt length (in tokens) passed to the model; longer prompts are truncated (check the model documentation for the supported maximum)
max_prompt_len = 4096

# Batch size (we fine-tuned with a batch size of 4 and 2 gradient accumulation steps, i.e. an effective batch size of 8)
batch_size = 4

# Optional run name; leave empty to use the default name built below
run_name = ""

# Set to True to use the few-shot prompt instead of the zero-shot prompt
few_shot = False

In [5]:
dataset_name = f'ds_{dataset}__task_{task}_full'

### Define Utility Functions

In [6]:
def compute_metrics(eval_preds, tokenizer, metric):
    logits, labels = eval_preds
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return result

In [7]:
def preprocess_function(tokenizer, prompt, df, label_column, max_length: int = 4096, padding: str | bool = False):
    # Fill the prompt template with each text; `df` is a batch (dict of columns) passed in by `datasets.map(batched=True)`
    inputs = [prompt.format(text=text_i) for text_i in df["text"]]
    model_inputs = tokenizer(inputs, max_length=max_length, padding=padding, truncation=True)

    labels = tokenizer(
        text_target=df[label_column],
        padding=padding,
        max_length=max_length,
        truncation=True,
    )

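    # Replace pad tokens in the labels with -100 so the loss ignores them
    # (this only changes anything when `padding` is enabled)
    labels["input_ids"] = [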
        [(l if l != tokenizer.pad_token_id else -100) for l in label]
        for label in labels["input_ids"]
    ]

    model_inputs["labels"] = labels["input_ids"]

    return model_inputs
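
The tokenizer created further down uses `truncation_side="left"`, so an over-long prompt loses tokens from its start and the end of the prompt (for example, an instruction placed after `{text}`) is kept. Label ids only contain pad tokens when `padding` is enabled; those are the positions `preprocess_function` rewrites to -100. The sketch below roughly imitates both steps on plain integer lists. It assumes T5-style conventions (pad id 0, a single end-of-sequence id 1 appended after truncation) and does not call the real tokenizer.

```python
PAD_ID = 0           # T5-style pad id (assumption for this sketch)
EOS_ID = 1           # T5-style end-of-sequence id (assumption for this sketch)
IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss


def truncate_left(token_ids, max_length):
    """Keep the last max_length - 1 content tokens, then append EOS."""
    budget = max_length - 1  # reserve one position for the EOS token
    kept = token_ids[max(len(token_ids) - budget, 0):]
    return kept + [EOS_ID]


def mask_label_padding(label_ids):
    """Replace pad ids with -100, as preprocess_function does for labels."""
    return [IGNORE_INDEX if t == PAD_ID else t for t in label_ids]


prompt_ids = [11, 12, 13, 14, 15, 16, 17, 18]
print(truncate_left(prompt_ids, max_length=5))

padded_labels = [42, EOS_ID, PAD_ID, PAD_ID]
print(mask_label_padding(padded_labels))
```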


In [8]:
def set_all_seeds(seed: int = 123):
    # tf.random.set_seed(123)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Set a fixed value for the hash seed
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Set seed with the `transformers` library
    # set_seed(seed)

In [9]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

## Main Implementation

In [10]:
set_all_seeds(seed)

In [11]:
exp_name = run_name if run_name != '' else f'{MODEL_NAME}_ds_{dataset}_task_{int(task)}_sample_0_prompt_max_len_{max_prompt_len}_batch_size_{batch_size}'


if few_shot:
    exp_name += '_few_shot'
# Initialize the Weights and Biases run
if WANDB_PROJECT_NAME != "":
    wandb.init(
        # set the wandb project where this run will be logged
        project=WANDB_PROJECT_NAME,
        name=exp_name,
        # track hyperparameters and run metadata
        config={
            "model": MODEL_NAME,
            "dataset": dataset,
            "task": task,
            "max_prompt_len": max_prompt_len,
            "batch_size": batch_size
        }
    )

print('Running exp:', exp_name)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: maria-korobeynikova. Use `wandb login --relogin` to force relogin


Running exp: google/flan-t5-xl_ds_1_task_2_sample_0_prompt_max_len_4096_batch_size_4


### Load Data and the prompt

In [12]:
prompt_col = 'few_shot_prompt' if few_shot else 'zero_shot_prompt'

dataset_idx, dataset_task_mappings = load_dataset_task_prompt_mappings(
    dataset_num=dataset, task_num=task, dataset_task_mappings_fp=dataset_task_mappings_fp)

# Get information specific to the dataset and the prompt
label_column = dataset_task_mappings.loc[dataset_idx, "label_column"]
labelset = dataset_task_mappings.loc[dataset_idx, "labelset"].split(",")
labelset = [label.strip() for label in labelset]
prompt = dataset_task_mappings.loc[dataset_idx, prompt_col]

datasets = load_full_dataset(
        data_dir=data_dir, dataset_name=dataset_name, task_num=task,
        label_column=label_column, labelset=labelset, full_label=False)

if WANDB_PROJECT_NAME != "":
    # Log the prompt template as a file artifact in wandb
    prompts_artifact = wandb.Artifact('prompts', type='prompts')
    with prompts_artifact.new_file('prompt.txt', mode='w') as f:
        f.write(prompt)
    wandb.run.log_artifact(prompts_artifact)

### Define the model, tokenizers, data collator

In [13]:
# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, truncation_side="left")
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, load_in_8bit=True, device_map="auto")

# Preprocess training and validation sets
unnecessary_cols = datasets['eval'].column_names

tokenized_dataset = datasets.map(
    lambda x:
    preprocess_function(tokenizer, prompt=prompt, df=x, label_column=label_column,
                        max_length=max_prompt_len, padding=False),
    batched=True, remove_columns=unnecessary_cols)

# We want to ignore tokenizer pad token in the loss
label_pad_token_id = -100

# Data collator
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=batch_size,
    padding='longest'
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.83s/it]
100%|██████████| 1/1 [00:00<00:00,  4.09ba/s]
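
`DataCollatorForSeq2Seq` pads each batch on the fly. With `padding='longest'`, inputs are padded to the longest sequence in the batch, rounded up to a multiple of `pad_to_multiple_of` (here `batch_size`, i.e. 4; this rounds the sequence length, not the number of examples); the attention mask marks real tokens with 1, and labels are padded with `label_pad_token_id` (-100). The pure-Python imitation below is a sketch of that behaviour, assuming right-side padding and pad id 0; it is not the library's code.

```python
import math

PAD_ID = 0            # T5-style pad id (assumption for this sketch)
LABEL_PAD_ID = -100   # matches label_pad_token_id in the notebook


def round_up(length, multiple):
    """Round a sequence length up to the nearest multiple (no-op if multiple is None)."""
    if not multiple:
        return length
    return math.ceil(length / multiple) * multiple


def collate(features, pad_to_multiple_of=None):
    """Right-pad input_ids, build attention_mask, and pad labels with -100."""
    max_in = round_up(max(len(f["input_ids"]) for f in features), pad_to_multiple_of)
    max_lab = round_up(max(len(f["labels"]) for f in features), pad_to_multiple_of)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_in - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [PAD_ID] * n_pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        batch["labels"].append(f["labels"] + [LABEL_PAD_ID] * (max_lab - len(f["labels"])))
    return batch


example_features = [
    {"input_ids": [5, 6, 7, 1], "labels": [42, 1]},
    {"input_ids": [8, 9, 10, 11, 12, 1], "labels": [43, 44, 1]},
]
example_batch = collate(example_features, pad_to_multiple_of=4)
for key, rows in example_batch.items():
    print(key, rows)
```

Because padded input positions can only be told apart through `attention_mask`, that mask should be passed to `model.generate` together with `input_ids`.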


### Prepare the model for training

In [14]:
# PEFT: prepare the 8-bit model for training and add LoRA adapters to the attention q and v projections
model = prepare_model_for_int8_training(model)
lora_config = LoraConfig(
    r=16, lora_alpha=32, target_modules=["q", "v"], lora_dropout=0.05, bias="none", task_type="SEQ_2_SEQ_LM"
)
model = get_peft_model(model, lora_config)
print_trainable_parameters(model)



trainable params: 9437184 || all params: 2859194368 || trainable%: 0.33006444422319176
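
The trainable-parameter count printed above can be reproduced by hand. LoRA adds two matrices per adapted linear layer, A (r × d_in) and B (d_out × r), i.e. r·(d_in + d_out) parameters. The sketch assumes FLAN-T5-XL's public configuration (d_model = 2048, 32 heads × d_kv 64 = 2048, 24 encoder and 24 decoder blocks) and that `target_modules=["q", "v"]` matches the q and v projections of every encoder self-attention, decoder self-attention and decoder cross-attention block. With `lora_alpha=32` and `r=16`, the LoRA update is also scaled by alpha / r = 2, which does not change the count.

```python
def lora_params(d_in, d_out, r):
    """Parameters of one LoRA adapter: A is r x d_in, B is d_out x r."""
    return r * d_in + d_out * r


# FLAN-T5-XL configuration values (assumed from the public model config)
d_model = 2048
inner_dim = 32 * 64          # num_heads * d_kv
encoder_layers = 24
decoder_layers = 24
r = 16

# Attention blocks containing "q" and "v" projections:
# one self-attention per encoder layer, self- plus cross-attention per decoder layer
attention_blocks = encoder_layers + 2 * decoder_layers
adapted_linears = attention_blocks * 2   # one q and one v projection per block

per_linear = lora_params(d_model, inner_dim, r)
trainable = adapted_linears * per_linear

all_params = 2_859_194_368   # "all params" as printed by print_trainable_parameters
print(trainable, f"{100 * trainable / all_params:.4f}%")
```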


### Run Predictions

In [16]:
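predictions_dir = os.path.join(output_dir, 'predictions', MODEL_NAME.replace("/", "_"))
os.makedirs(predictions_dir, exist_ok=True)

# No shuffling, so predictions stay in the same order as the evaluation rows
dataloader = DataLoader(tokenized_dataset['eval'], batch_size=batch_size, collate_fn=data_collator)

model.eval()  # disable dropout (including LoRA dropout) during generation
with torch.no_grad():
    predictions_out = []
    for i, batch in enumerate(dataloader):
        # Pass the attention mask so the encoder ignores padded positions
        outputs = model.generate(
            input_ids=batch['input_ids'].cuda(),
            attention_mask=batch['attention_mask'].cuda(),
        )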

        generated_text_minibatch = tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )

        predictions_out += generated_text_minibatch

        if i == 0:
            print("Sample prediction: ")
            print(predictions_out[0])

eval_df = pd.read_csv(os.path.join(data_dir, f'{dataset_name}.csv'))

eval_df['prediction'] = predictions_out
print(eval_df.head())
eval_df.to_csv(os.path.join(predictions_dir, f'{exp_name.replace("/", "_")}.csv'))

Sample prediction: 
A
             status_id                  Date  \
0  1274037131636285443  2020-06-19T00:00:00Z   
1  1236472965983657986  2020-03-08T00:00:00Z   
2  1248195673771528194  2020-04-09T00:00:00Z   
3  1318029215149731840  2020-10-19T00:00:00Z   
4  1300648503299973122  2020-09-01T00:00:00Z   

                                                text problem_solution_ra  \
0  The First Amendment binds the government, not ...             Neither   
1  hi, my main acc @DEMINATIONIST is currently su...             Neither   
2  After a report showed a surge in misinformatio...            Solution   
3  Mass report on these accounts.  They stay in o...            Solution   
4  Newest National Notables video with George Mag...             Neither   

  prediction  
0          A  
1    NEUTRAL  
2          B  
3          B  
4    NEUTRAL  
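
The generated strings are free text, so they are not guaranteed to belong to `labelset`: in the sample above the model answers with `A`, `B` or `NEUTRAL`, while the `problem_solution_ra` column holds labels such as `Neither` and `Solution`. Before computing metrics it helps to measure how many outputs are valid labels at all. The sketch below is a hypothetical post-processing step, not part of this notebook; it normalises case and whitespace, counts out-of-labelset outputs and reports accuracy. The example values are illustrative, and any mapping from letter answers to label names would have to come from the prompt, so none is assumed here.

```python
from collections import Counter


def normalise(label):
    """Lower-case and strip a label so 'Solution ' and 'solution' compare equal."""
    return label.strip().lower()


def score_predictions(predictions, references, labelset):
    """Return accuracy, the share of outputs inside labelset, and the invalid outputs."""
    valid = {normalise(label) for label in labelset}
    preds = [normalise(p) for p in predictions]
    refs = [normalise(r) for r in references]
    correct = sum(p == r for p, r in zip(preds, refs))
    invalid = Counter(p for p in preds if p not in valid)
    return {
        "accuracy": correct / len(refs),
        "valid_rate": 1 - sum(invalid.values()) / len(preds),
        "invalid_outputs": dict(invalid),
    }


# Illustrative values only; in the notebook these would be
# eval_df['prediction'], eval_df[label_column] and labelset
example_labelset = ["Problem", "Solution", "Neither"]
example_predictions = ["Solution", "neither", "A", "Problem", "NEUTRAL"]
example_references = ["Solution", "Neither", "Solution", "Neither", "Neither"]
print(score_predictions(example_predictions, example_references, example_labelset))
```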


### Terminate WandB

In [17]:
if WANDB_PROJECT_NAME != "":
    wandb.finish()