## Fine-tuning of WizardLM-13B on Legal Diversity

Install and Load Required Libraries

In [1]:
! pip3 install -q -U transformers
! pip3 install -q -U datasets
! pip3 install -q -U peft
! pip3 install -q -U trl
! pip3 install -q -U auto-gptq
! pip3 install -q -U optimum
! pip3 install -q -U bitsandbytes

In [2]:
import os
os.environ['TRANSFORMERS_CACHE'] = '/home/kmb85/rds/hpc-work/huggingface'

In [3]:
import transformers
import torch
from datasets import load_dataset
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training,
    get_peft_model
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from trl import SFTTrainer

  from .autonotebook import tqdm as notebook_tqdm


### Load WizardLM-13B and Tokenizer

In [3]:
model_name_or_path = "WizardLM/WizardLM-13B-V1.2"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4b_quant_type='nf4',
    torch_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    use_safetensors=True,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
    token=""
)



In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
tokenizer.pad_token=tokenizer.eos_token

In [5]:
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model)

### Load LoRA Adapter

In [6]:
config = LoraConfig(
    r=32,
    lora_alpha=16,
    bias="none",
    task_type="CASUAL_LM",
)

In [7]:
model=get_peft_model(model, config)

### Dataset preparation

In [4]:
diversity_names = ['diversity_1', 'diversity_2', 'diversity_3', 'diversity_4', 'diversity_5', 'diversity_6']

In [5]:
diversity_datasets_train = {}
diversity_datasets_test = {}

diversity_datasets_train['diversity_1_1']  = load_dataset('nguha/legalbench', 'diversity_1')['train']
diversity_datasets_train['diversity_1_2']  = load_dataset('nguha/legalbench', 'diversity_1')['test']

for diversity_name in diversity_names[1:]:
    diversity_datasets_train[diversity_name] = load_dataset('nguha/legalbench', diversity_name)['train']
    diversity_datasets_test[diversity_name] = load_dataset('nguha/legalbench', diversity_name)['test']

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [6]:
from datasets import concatenate_datasets

combined_diversity_dataset_train = concatenate_datasets(diversity_datasets_train.values())
combined_diversity_dataset_test = concatenate_datasets(diversity_datasets_test.values())

In [7]:
combined_diversity_dataset_train[0]

{'aic_is_met': 'False',
 'answer': 'No',
 'index': '0',
 'parties_are_diverse': 'False',
 'text': 'Evelyn is from Hawaii. Charlotte is from Hawaii. Evelyn sues Charlotte for negligence for $20,000.'}

In [8]:
DEFAULT_PROMPT = "Diversity jurisdiction gives federal courts the ability to hear cases between parties that are “citizens” of different states. It specifies that state claims may be brought in federal court provided those claims involve citizens of different states for certain minimum amounts. Below is a court case, answer if diversity jurisdiction exists only with 'Yes' or 'No'. Answer only based on the provided infromation:"

def generate_train_prompt(data_point):
    text = data_point[ 'text']
    answer = data_point[ 'answer']
    text = f'{DEFAULT_PROMPT}\n###Court case:\n{text}\n###Output:\n{answer}'
    return {'text': text, 'labels': answer}

In [9]:
train_dataset = combined_diversity_dataset_train.shuffle().map(generate_train_prompt)

Map: 100%|██████████| 336/336 [00:00<00:00, 8444.70 examples/s]


In [10]:
def generate_test_prompt(data_point):
    text = data_point[ 'text']
    answer = data_point[ 'answer']
    text = f'{DEFAULT_PROMPT}\n###Court case:\n{text}\n###Output:\n'
    return {'text': text}

In [11]:
test_dataset = combined_diversity_dataset_test.shuffle().map(generate_test_prompt)

Map: 100%|██████████| 1500/1500 [00:00<00:00, 6196.82 examples/s]


### Training

In [16]:
training_args = transformers.TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=0.001,
    fp16=True,
    num_train_epochs=8,
    save_strategy="epoch",
    save_safetensors=True,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    output_dir='./experiments',
    remove_unused_columns=False,
    warmup_ratio=0.05,
    logging_strategy='epoch',
    label_names=['labels'],
    group_by_length=True
)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    args=training_args,
    tokenizer=tokenizer,
    dataset_text_field='text',
    peft_config=config,
    max_seq_length=4096
)

Map: 100%|██████████| 336/336 [00:00<00:00, 8367.99 examples/s]
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [17]:
model.config.use_cache = False
trainer.state.log_history = True
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mkmb85[0m ([33mcam_kiril[0m). Use [1m`wandb login --relogin`[0m to force relogin


You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.




Step,Training Loss
84,0.8705
168,0.2104
252,0.1885
336,0.1825
420,0.1776
504,0.1699
588,0.1594
672,0.1468


Checkpoint destination directory ./experiments/checkpoint-84 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-168 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-252 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-336 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-420 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-504 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./experiments/checkpoint-588 already exists and is non-empty.Saving will proceed

TrainOutput(global_step=672, training_loss=0.26319746885980877, metrics={'train_runtime': 1445.7804, 'train_samples_per_second': 1.859, 'train_steps_per_second': 0.465, 'total_flos': 2.789797400027136e+16, 'train_loss': 0.26319746885980877, 'epoch': 8.0})

### Save the fine-tuned model

In [18]:
model.save_pretrained(f'WizardLM-13B_legal_Diversity_8_epochs')

### Evaluate fine-tuned model

First, we need to manually set up text-generation-webui and run its' server.

In [12]:
import requests

url = "http://127.0.0.1:5000/api/v1/generate"

In [13]:
test_dataset = test_dataset.shuffle(seed=42)

In [14]:
request = {
    'max_new_tokens': 100,
    'temperature': 0.1,
    'repetition_penalty': 1,
    'top_p': 0.7,
    'stopping_strings': ['\n', '###']
}
headers = {'Content-Type': 'application/json'}

In [15]:
import ast

def extract_after_output(text):
    index = text.find('output')
    if index != -1:
        return text[index + len('output'):]
    else:
        return text

In [16]:
total_correct = 0
num_samples = 1000

In [16]:
for i in range(num_samples):
    request['prompt'] = test_dataset[i]['text']
    response = requests.post(url, json=request)
    prediction = extract_after_output(ast.literal_eval(response.text)["results"][0]['text'].lower())
    if test_dataset[i]['answer'].lower() in prediction:
        total_correct+=1
    elif test_dataset[i]['answer'].lower() == 'yes' and 'diversity jurisdiction exists' in prediction:
        total_correct+=1
    elif test_dataset[i]['answer'].lower() == 'no' and 'diversity jurisdiction does not exists' in prediction:
        total_correct+=1

In [17]:
correct_percentage = (total_correct / num_samples) * 100
print(f'Correctness percentage {correct_percentage}%')

Correctness percentage 48.699999999999996%


### Evaluate RAG Model

In [22]:
import requests

url_rag = "https://b1b6-131-111-184-110.ngrok-free.app/search"
url = "http://127.0.0.1:5000/api/v1/generate"

payload = {
    "text": combined_diversity_dataset_test[0]['text'],
    "number_documents": 10,
    "collection": "legal_diversity"
}

In [23]:
combined_diversity_dataset_test = combined_diversity_dataset_test.shuffle(seed=42)

In [24]:
total_correct = 0
num_samples = 1000

In [25]:
request = {
    'max_new_tokens': 100,
    'temperature': 0.1,
    'repetition_penalty': 1,
    'top_p': 0.7,
    'stopping_strings': ['\n', '###']
}
headers = {'Content-Type': 'application/json'}

In [26]:
def generate_rag_prompt(data_point):
    text = data_point[ 'text']
    answer = data_point[ 'answer']
    text = f'###Court case:\n{text}\n###Output:\n{answer}\n'
    return text

In [27]:
for i in range(num_samples):
    request['prompt'] = DEFAULT_PROMPT+'\n'

    payload['text'] = combined_diversity_dataset_test[i]['text']

    response_rag = requests.get(url_rag, json=payload)
    data_rag = response_rag.json()

    for record in data_rag:
        request['prompt'] += generate_rag_prompt(record)

    request['prompt'] += f'###Court case:\n{combined_diversity_dataset_test[i]["text"]}\n###Output:\n'
    response = requests.post(url, json=request)
    prediction = extract_after_output(ast.literal_eval(response.text)["results"][0]['text'].lower())
    if combined_diversity_dataset_test[i]['answer'].lower() in prediction:
        total_correct+=1
    elif combined_diversity_dataset_test[i]['answer'].lower() == 'yes' and 'diversity jurisdiction exists' in prediction:
        total_correct+=1
    elif combined_diversity_dataset_test[i]['answer'].lower() == 'no' and 'diversity jurisdiction does not exists' in prediction:
        total_correct+=1

In [28]:
correct_percentage = (total_correct / num_samples) * 100
print(f'Correctness percentage {correct_percentage}%')

Correctness percentage 51.4%
