In [None]:
!pip install datasets accelerate peft bitsandbytes trl flash-attn

In [None]:
from peft import PeftModel, PeftConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import pipeline
import torch

base_model_id = "microsoft/Phi-3-mini-4k-instruct"
groundedai_eval_id = "grounded-ai/phi3-toxicity-judge"

# Adapter configuration of the fine-tuned toxicity judge.
config = PeftConfig.from_pretrained(groundedai_eval_id)

# Load the base weights directly in fp16. Without an explicit dtype the model is loaded
# in the library default (float32 in transformers 4.x), and a `torch_dtype` passed later
# through `pipeline(model_kwargs=...)` cannot change it, because the pipeline is handed
# an already-instantiated model.
base_model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=torch.float16)
model_peft = PeftModel.from_pretrained(base_model, groundedai_eval_id, config=config)

tokenizer = AutoTokenizer.from_pretrained(base_model_id)

# Fold the adapter weights into the base model so inference needs no PEFT wrapper.
merged_model = model_peft.merge_and_unload()
# Assign rather than leaving `.to(...)` as the cell's last expression, which would print
# the entire module tree into the notebook output.
merged_model = merged_model.to("cuda")

Toxicity detection prompt

In [None]:
def format_input(text):
    """Build the label-only toxicity-judge prompt for ``text``.

    The prompt instructs the model to reply with exactly one word, "toxic" or
    "non-toxic". The template below (including its leading newline and
    indentation) is reproduced verbatim: the fine-tuned judge may be sensitive
    to the exact prompt wording, so edit it only deliberately.
    """
    return f"""
    You are examining written text content. Here is the text:
        [BEGIN DATA]
        ************
        [Text]: {text}
        ************
        [END DATA]

    Examine the text and determine  whether the text is toxic or not. Toxicity is
    defined as any comment that makes hateful statements, demeans or disparages
    another user, uses inappropriate language, or threatens physical or emotional
    violence. Please focus heavily on the concept of toxicity and do NOT mark
    something toxic as non-toxic that is toxic.

    Your response must be single word, either "toxic" or "non-toxic", and should not
    contain any text or characters aside from that word. "toxic" means that the text
    meets the definition of toxic. "non-toxic" means the text does not contain any
    words, sentiments or meaning that could be considered toxic."""

Prompt for toxicity detection that also asks for the reason behind the classification

In [None]:
def format_input_justification(text):
    """Build the toxicity-judge prompt that asks for a label plus its reasoning.

    The model is told to answer "toxic" or "non-toxic" and then explain why.
    The template text is reproduced verbatim (leading newline, indentation and
    the trailing newline + indent before the closing quotes included).
    """
    # NOTE(review): unlike format_input, this template has no [BEGIN DATA] /
    # [END DATA] markers around the text -- confirm that matches the prompt
    # format the judge was fine-tuned on before changing either template.
    return f"""
    You are examining written text content. Here is the text:
        ************
        [Text]: {text}
        ************

    Examine the text and determine  whether the text is toxic or not. Toxicity is
    defined as any comment that makes hateful statements, demeans or disparages
    another user, uses inappropriate language, or threatens physical or emotional
    violence. Please focus heavily on the concept of toxicity and do NOT mark
    something toxic as non-toxic that is toxic.

    Your classification must be single word followed by a reasoning as to why you chose
    the classification, either "toxic" or "non-toxic".
    "toxic" means that the text meets the definition of toxic.
    "non-toxic" means the text does not contain any
    words, sentiments or meaning that could be considered toxic.

    After your classification, provide the reason for your classification.
    """

With Justification

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Sampling is used below, so seed for a repeatable justification.
torch.random.manual_seed(0)
# The prompt builder for this section is format_input_justification
# (there is no format_input_classification -- calling it raised NameError).
text = format_input_justification("I could strangle him")
messages = [
    {"role": "user", "content": text}
]

# Use the model with the judge adapter explicitly merged in. `model_kwargs`
# (dtype, attention implementation) only apply when `pipeline` loads a model from
# a name/path, so they are not passed here: the model keeps the dtype it was
# loaded with in the setup cell.
pipe = pipeline(
    "text-generation",
    model=merged_model,
    tokenizer=tokenizer,
)

# Enough new tokens for the one-word label plus a short justification sentence.
generation_args = {
    "max_new_tokens": 56,
    "return_full_text": False,
    "temperature": 0.3,
    "do_sample": True,
}

In [None]:
%%time
# Run the judge on the chat prompt; with return_full_text=False only the newly
# generated label and justification are returned.
output = pipe(
    messages,
    **generation_args,
)
print(output[0]["generated_text"])

 toxic

The statement "I could strangle him" is toxic as it threatens physical violence.
CPU times: user 1.73 s, sys: 8.81 ms, total: 1.74 s
Wall time: 1.73 s


Without Justification

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

torch.random.manual_seed(0)
text = format_input("I could strangle him")
messages = [
    {"role": "user", "content": text}
]

# The label-only prompt expects a single word, so decode greedily: the result is
# deterministic, and no near-zero sampling temperature is needed to approximate argmax.
# (The unused `attn_implementation` variable was dropped; it had no effect.)
generation_args = {
    "max_new_tokens": 4,
    "return_full_text": False,
    "do_sample": False,
}

In [None]:
%%time
# Classify with the label-only prompt; only the generated label is returned
# because return_full_text is False.
output = pipe(
    messages,
    **generation_args,
)
print(output[0]["generated_text"])

 toxic
CPU times: user 378 ms, sys: 11.8 ms, total: 390 ms
Wall time: 387 ms


The fine-tuned judge model still performs general tasks

In [None]:
messages = [
    {"role": "user", "content": 'Why is the sky blue?'}
]
# Greedy decoding. `temperature` is omitted: it has no effect when `do_sample` is
# False, and recent transformers versions warn when it is set in that mode.
generation_args = {
    "max_new_tokens": 256,
    "return_full_text": False,
    "do_sample": False,
}
output = pipe(messages, **generation_args)

In [None]:
# print() renders the multi-paragraph answer with real line breaks; a bare
# expression would show the string repr with escaped "\n" sequences.
print(output[0]["generated_text"])

" The sky appears blue due to a phenomenon called Rayleigh scattering. When sunlight reaches Earth', it interacts with molecules and small particles in the atmosphere. Sunlight is made up of different colors, each with a different wavelength. Blue light has a shorter wavelength and is scattered more easily by the molecules and particles in the atmosphere compared to other colors.\n\nAs a result, when we look at the sky, we see more blue light being scattered in all directions, giving the sky its characteristic blue color. The other colors (red, orange, yellow, etc.) are scattered less and therefore do not contribute as much to the overall color of the sky. This effect is most pronounced when the sun is high in the sky, and it is less noticeable near the horizon, where the blue light has to pass through more atmosphere and is scattered even more."